diff --git a/README.md b/README.md
index c835ba91..3491e72b 100644
--- a/README.md
+++ b/README.md
@@ -202,7 +202,8 @@ predictions = pipeline.predict(X_test)
 ```
 
 ## Explaining Hierarchical Classifiers
-Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](https://colab.research.google.com/drive/1wqSl1t_Qn2f62WNZQ48mdB0mNeu1XSF1?usp=sharing), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).
+
+Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](https://colab.research.google.com/drive/1wqSl1t_Qn2f62WNZQ48mdB0mNeu1XSF1?usp=sharing), and [Local classifier per level](https://colab.research.google.com/drive/1VnGlJu-1wSG4wxHXL0Ijf2a7Pu3kklT-?usp=sharing) is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).
 
 ## Step-by-step walk-through
 
diff --git a/docs/examples/plot_lcpl_explainer.py b/docs/examples/plot_lcpl_explainer.py
new file mode 100644
index 00000000..d085c791
--- /dev/null
+++ b/docs/examples/plot_lcpl_explainer.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+"""
+=========================================
+Explaining Local Classifier Per Level
+=========================================
+
+A minimalist example showing how to use the HiClass Explainer to obtain SHAP values for an LCPL model.
+A detailed summary of the Explainer class is given in the Algorithms Overview section for :ref:`Hierarchical Explainability`.
+SHAP values are calculated based on a synthetic platypus diseases dataset that can be downloaded `here `_.
+"""
+from sklearn.ensemble import RandomForestClassifier
+from hiclass import LocalClassifierPerLevel, Explainer
+import shap
+from hiclass.datasets import load_platypus
+
+# Load train and test splits
+X_train, X_test, Y_train, Y_test = load_platypus()
+
+# Use random forest classifiers for every level
+rfc = RandomForestClassifier()
+classifier = LocalClassifierPerLevel(local_classifier=rfc, replace_classifiers=False)
+
+# Train local classifiers per level
+classifier.fit(X_train, Y_train)
+
+# Define Explainer
+explainer = Explainer(classifier, data=X_train, mode="tree")
+explanations = explainer.explain(X_test.values)
+print(explanations)
+
+# Filter the SHAP values corresponding to the 'Covid' class (level 1)
+# and the 'Respiratory' class (level 0)
+
+covid_idx = classifier.predict(X_test)[:, 1] == "Covid"
+
+shap_filter_covid = {"level": 1, "class": "Covid", "sample": covid_idx}
+shap_filter_resp = {"level": 0, "class": "Respiratory", "sample": covid_idx}
+shap_val_covid = explanations.sel(**shap_filter_covid)
+shap_val_resp = explanations.sel(**shap_filter_resp)
+
+
+# Visually compare the mean absolute SHAP values for 'Covid' vs. 'Respiratory' diseases.
+
+# Feature names used to label the plot
+feature_names = X_train.columns.values
+
+# SHAP values for 'Covid'
+shap_values_covid = shap_val_covid.shap_values.values
+
+# SHAP values for 'Respiratory'
+shap_values_resp = shap_val_resp.shap_values.values
+
+shap.summary_plot(
+    [shap_values_covid, shap_values_resp],
+    features=X_test.iloc[covid_idx],
+    feature_names=feature_names,
+    plot_type="bar",
+    class_names=["Covid", "Respiratory"],
+)
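The example above aggregates SHAP values visually with `shap.summary_plot`. As a rough numeric companion, the same comparison can be done directly on the filtered selections. The sketch below is illustrative only, not part of the patch; it assumes the `shap_val_covid`, `shap_val_resp`, and `feature_names` objects defined in `docs/examples/plot_lcpl_explainer.py`, as well as a `(sample, feature)` layout of the selected SHAP values.

```python
import numpy as np

# Illustrative sketch (not part of the PR): rank features by mean |SHAP| for
# the filtered 'Covid' (level 1) and 'Respiratory' (level 0) selections.
# Assumes shap_val_covid, shap_val_resp and feature_names from the example
# above, with SHAP values laid out as (sample, feature).
mean_abs_covid = np.abs(shap_val_covid.shap_values.values).mean(axis=0)
mean_abs_resp = np.abs(shap_val_resp.shap_values.values).mean(axis=0)

# Print features sorted by their importance for the 'Covid' prediction
for name, covid_imp, resp_imp in sorted(
    zip(feature_names, mean_abs_covid, mean_abs_resp),
    key=lambda row: row[1],
    reverse=True,
):
    print(f"{name}: Covid={covid_imp:.4f}, Respiratory={resp_imp:.4f}")
```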
diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py
index 585f4b85..8708527b 100644
--- a/hiclass/Explainer.py
+++ b/hiclass/Explainer.py
@@ -229,6 +229,31 @@ def _get_traversed_nodes_lcpn(self, samples):
 
         return traversals
 
+    def _get_traversed_nodes_lcpl(self, samples):
+        """
+        Return a list of all traversed nodes as per the provided LocalClassifierPerLevel model.
+
+        Parameters
+        ----------
+        samples : array-like
+            Sample data for which to generate traversed nodes.
+
+        Returns
+        -------
+        traversals : list
+            A list of all traversed nodes as per LocalClassifierPerLevel (LCPL) strategy.
+        """
+        traversals = []
+        predictions = self.hierarchical_model.predict(samples)
+        for pred in predictions:
+            traversal_order = []
+            filtered_pred = [p for p in pred if p.strip()]
+            for i in range(1, len(filtered_pred) + 1):
+                node = self.hierarchical_model.separator_.join(filtered_pred[:i])
+                traversal_order.append(node)
+            traversals.append(traversal_order)
+        return traversals
+
     def _calculate_shap_values(self, X):
         """
         Return an xarray.Dataset object for a single sample provided. This dataset is aligned on the `level` attribute.
@@ -244,23 +269,27 @@ def _calculate_shap_values(self, X):
             A single explanation for the prediction of given sample.
         """
         traversed_nodes = []
-        if isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
+        if isinstance(self.hierarchical_model, LocalClassifierPerLevel):
+            traversed_nodes = self._get_traversed_nodes_lcpl(X)[0]
+        elif isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
             traversed_nodes = self._get_traversed_nodes_lcppn(X)[0]
         elif isinstance(self.hierarchical_model, LocalClassifierPerNode):
             traversed_nodes = self._get_traversed_nodes_lcpn(X)[0]
 
         datasets = []
         level = 0
         for node in traversed_nodes:
-            # Skip if node is empty or classifier is not found, can happen in case of imbalanced hierarchies
-            if (
-                node == ""
-                or "classifier" not in self.hierarchical_model.hierarchy_.nodes[node]
+            if node == "" or (
+                ("classifier" not in self.hierarchical_model.hierarchy_.nodes[node])
+                and (not isinstance(self.hierarchical_model, LocalClassifierPerLevel))
             ):
                 continue
-            local_classifier = self.hierarchical_model.hierarchy_.nodes[node][
-                "classifier"
-            ]
+            if isinstance(self.hierarchical_model, LocalClassifierPerLevel):
+                local_classifier = self.hierarchical_model.local_classifiers_[level]
+            else:
+                local_classifier = self.hierarchical_model.hierarchy_.nodes[node][
+                    "classifier"
+                ]
 
             # Create a SHAP explainer for the local classifier
             local_explainer = deepcopy(self.explainer)(local_classifier, self.data)
@@ -283,7 +312,7 @@
                     for label in local_classifier.classes_
                 ]
                 predicted_class = current_node
-            else:
+            elif isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
                 simplified_labels = [
                     label.split(self.hierarchical_model.separator_)[-1]
                     for label in local_classifier.classes_
                 ]
                 predicted_class = (
@@ -293,6 +322,12 @@
                    .flatten()[0]
                    .split(self.hierarchical_model.separator_)[-1]
                )
+            else:
+                simplified_labels = [
+                    label.split(self.hierarchical_model.separator_)[-1]
+                    for label in local_classifier.classes_
+                ]
+                predicted_class = current_node
 
             classes = xr.DataArray(
                 simplified_labels,
@@ -326,7 +361,7 @@
                     "level": level,
                 }
             )
-            level = level + 1
+            level += 1
             datasets.append(local_dataset)
         sample_explanation = xr.concat(datasets, dim="level")
         return sample_explanation
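For LCPL there is no per-node classifier to walk, so `_get_traversed_nodes_lcpl` derives each sample's traversal from its flat per-level prediction by joining label prefixes with the model's `separator_`. The snippet below is a standalone illustration of that prefix-join idea with made-up labels and a made-up separator; it is not HiClass API.

```python
# Standalone illustration (assumed labels and separator, not HiClass API) of
# the prefix-join logic used by _get_traversed_nodes_lcpl: each level's node
# identifier is the separator-joined path from the root down to that level.
def traversal_from_prediction(pred, separator="::"):
    filtered = [p for p in pred if p.strip()]  # drop empty labels from padded rows
    return [separator.join(filtered[:i]) for i in range(1, len(filtered) + 1)]

print(traversal_from_prediction(["Respiratory", "Covid"]))
# ['Respiratory', 'Respiratory::Covid']
```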
diff --git a/tests/test_Explainer.py b/tests/test_Explainer.py
index c1caa5e7..303216f6 100644
--- a/tests/test_Explainer.py
+++ b/tests/test_Explainer.py
@@ -1,7 +1,12 @@
 import numpy as np
 import pytest
 from sklearn.ensemble import RandomForestClassifier
-from hiclass import LocalClassifierPerNode, LocalClassifierPerParentNode, Explainer
+from hiclass import (
+    LocalClassifierPerLevel,
+    LocalClassifierPerParentNode,
+    LocalClassifierPerNode,
+    Explainer,
+)
 
 try:
     import shap
@@ -98,6 +103,26 @@ def test_explainer_tree_lcpn(data, request):
             assert str(explanations["node"][i].data[j]) == y_pred[j]
 
 
+@pytest.mark.skipif(not shap_installed, reason="shap not installed")
+@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
+def test_explainer_tree_lcpl(data, request):
+    rfc = RandomForestClassifier()
+    lcpl = LocalClassifierPerLevel(local_classifier=rfc, replace_classifiers=False)
+
+    x_train, x_test, y_train = request.getfixturevalue(data)
+
+    lcpl.fit(x_train, y_train)
+
+    explainer = Explainer(lcpl, data=x_train, mode="tree")
+    explanations = explainer.explain(x_test)
+    assert explanations is not None
+    y_preds = lcpl.predict(x_test)
+    for i in range(len(x_test)):
+        y_pred = y_preds[i]
+        for j in range(len(y_pred)):
+            assert str(explanations["node"][i].data[j]) == y_pred[j]
+
+
 @pytest.mark.skipif(not shap_installed, reason="shap not installed")
 @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
 def test_traversal_path_lcppn(data, request):
@@ -142,11 +167,30 @@ def test_traversal_path_lcpn(data, request):
             assert label == preds[i][j]
 
 
+@pytest.mark.skipif(not shap_installed, reason="shap not installed")
+@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
+def test_traversal_path_lcpl(data, request):
+    x_train, x_test, y_train = request.getfixturevalue(data)
+    rfc = RandomForestClassifier()
+    lcpl = LocalClassifierPerLevel(local_classifier=rfc, replace_classifiers=False)
+
+    lcpl.fit(x_train, y_train)
+    explainer = Explainer(lcpl, data=x_train, mode="tree")
+    traversals = explainer._get_traversed_nodes_lcpl(x_test)
+    preds = lcpl.predict(x_test)
+    assert len(preds) == len(traversals)
+    for i in range(len(x_test)):
+        for j in range(len(traversals[i])):
+            label = traversals[i][j].split(lcpl.separator_)[-1]
+            assert label == preds[i][j]
+
+
 @pytest.mark.skipif(not shap_installed, reason="shap not installed")
 @pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
 @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
 @pytest.mark.parametrize(
-    "classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
+    "classifier",
+    [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode],
 )
 def test_explain_with_xr(data, request, classifier):
     x_train, x_test, y_train = request.getfixturevalue(data)
@@ -162,7 +206,8 @@
 
 
 @pytest.mark.parametrize(
-    "classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
+    "classifier",
+    [LocalClassifierPerParentNode, LocalClassifierPerLevel, LocalClassifierPerNode],
 )
 def test_imports(classifier):
     x_train = [[76, 12, 49], [88, 63, 31], [5, 42, 24], [17, 90, 55]]
@@ -176,8 +221,10 @@ def test_imports(classifier):
     assert isinstance(explainer.data, np.ndarray)
 
 
+@pytest.mark.skipif(not shap_installed, reason="shap not installed")
 @pytest.mark.parametrize(
-    "classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
+    "classifier",
+    [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode],
 )
 @pytest.mark.parametrize("data", ["explainer_data"])
 @pytest.mark.parametrize("mode", ["linear", "gradient", "deep", "tree", ""])
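With shap installed, one way to exercise only the new LCPL-related tests locally is pytest's keyword filter. This is an illustrative invocation, not part of the patch:

```python
# Illustrative: select the new LCPL tests by keyword; equivalent to running
# `pytest tests/test_Explainer.py -k lcpl` from the repository root.
import pytest

pytest.main(["tests/test_Explainer.py", "-k", "lcpl"])
```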