Skip to content

Commit

Permalink
Add explainer for local classifier per level #minor (#116)
Browse files Browse the repository at this point in the history
  • Loading branch information
iwan-tee authored Apr 12, 2024
1 parent 2fe8480 commit 6f37990
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 15 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ predictions = pipeline.predict(X_test)
```

## Explaining Hierarchical Classifiers
Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](https://colab.research.google.com/drive/1wqSl1t_Qn2f62WNZQ48mdB0mNeu1XSF1?usp=sharing), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).

Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](https://colab.research.google.com/drive/1wqSl1t_Qn2f62WNZQ48mdB0mNeu1XSF1?usp=sharing), and [Local classifier per level](https://colab.research.google.com/drive/1VnGlJu-1wSG4wxHXL0Ijf2a7Pu3kklT-?usp=sharing) is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).

## Step-by-step walk-through

Expand Down
59 changes: 59 additions & 0 deletions docs/examples/plot_lcpl_explainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
"""
=========================================
Explaining Local Classifier Per Level
=========================================

A minimalist example showing how to use the HiClass Explainer to obtain SHAP values of an LCPL model.
A detailed summary of the Explainer class is given in the Algorithms Overview section for :ref:`Hierarchical Explainability`.
SHAP values are calculated based on a synthetic platypus diseases dataset that can be downloaded `here <https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/3f225c3f80dd8cbb1b6252f6c372a054ec968705/platypus_diseases.csv>`_.
"""
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerLevel, Explainer
import shap
from hiclass.datasets import load_platypus

# Load train and test splits
X_train, X_test, Y_train, Y_test = load_platypus()

# Use random forest classifiers for every level
rfc = RandomForestClassifier()
classifier = LocalClassifierPerLevel(local_classifier=rfc, replace_classifiers=False)

# Train local classifiers per level
classifier.fit(X_train, Y_train)

# Define the Explainer and compute the explanations for the test split
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test.values)
print(explanations)

# Boolean mask of the test samples whose level-1 prediction is 'Covid'
covid_idx = classifier.predict(X_test)[:, 1] == "Covid"

# Filter the Shapley values of those samples for the class 'Covid' (level 1)
# and for the class 'Respiratory' (level 0)
shap_filter_covid = {"level": 1, "class": "Covid", "sample": covid_idx}
shap_filter_resp = {"level": 0, "class": "Respiratory", "sample": covid_idx}
shap_val_covid = explanations.sel(**shap_filter_covid)
shap_val_resp = explanations.sel(**shap_filter_resp)


# Visually compare the SHAP values for 'Covid' vs. 'Respiratory' in a single bar plot.

# Feature names used to label the plot; computed once and reused below
feature_names = X_train.columns.values

# SHAP values for 'Covid'
shap_values_covid = shap_val_covid.shap_values.values

# SHAP values for 'Respiratory'
shap_values_resp = shap_val_resp.shap_values.values

shap.summary_plot(
    [shap_values_covid, shap_values_resp],
    features=X_test.iloc[covid_idx],
    feature_names=feature_names,
    plot_type="bar",
    class_names=["Covid", "Respiratory"],
)
55 changes: 45 additions & 10 deletions hiclass/Explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,31 @@ def _get_traversed_nodes_lcpn(self, samples):

return traversals

def _get_traversed_nodes_lcpl(self, samples):
"""
Return a list of all traversed nodes as per the provided LocalClassifierPerLevel model.
Parameters
----------
samples : array-like
Sample data for which to generate traversed nodes.
Returns
-------
traversals : list
A list of all traversed nodes as per LocalClassifierPerLevel (LCPL) strategy.
"""
traversals = []
predictions = self.hierarchical_model.predict(samples)
for pred in predictions:
traversal_order = []
filtered_pred = [p for p in pred if p.strip()]
for i in range(1, len(filtered_pred) + 1):
node = self.hierarchical_model.separator_.join(filtered_pred[:i])
traversal_order.append(node)
traversals.append(traversal_order)
return traversals

def _calculate_shap_values(self, X):
"""
Return an xarray.Dataset object for a single sample provided. This dataset is aligned on the `level` attribute.
Expand All @@ -244,23 +269,27 @@ def _calculate_shap_values(self, X):
A single explanation for the prediction of given sample.
"""
traversed_nodes = []
if isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
if isinstance(self.hierarchical_model, LocalClassifierPerLevel):
traversed_nodes = self._get_traversed_nodes_lcpl(X)[0]
elif isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
traversed_nodes = self._get_traversed_nodes_lcppn(X)[0]
elif isinstance(self.hierarchical_model, LocalClassifierPerNode):
traversed_nodes = self._get_traversed_nodes_lcpn(X)[0]
datasets = []
level = 0
for node in traversed_nodes:
# Skip if node is empty or classifier is not found, can happen in case of imbalanced hierarchies
if (
node == ""
or "classifier" not in self.hierarchical_model.hierarchy_.nodes[node]
if node == "" or (
("classifier" not in self.hierarchical_model.hierarchy_.nodes[node])
and (not isinstance(self.hierarchical_model, LocalClassifierPerLevel))
):
continue

local_classifier = self.hierarchical_model.hierarchy_.nodes[node][
"classifier"
]
if isinstance(self.hierarchical_model, LocalClassifierPerLevel):
local_classifier = self.hierarchical_model.local_classifiers_[level]
else:
local_classifier = self.hierarchical_model.hierarchy_.nodes[node][
"classifier"
]

# Create a SHAP explainer for the local classifier
local_explainer = deepcopy(self.explainer)(local_classifier, self.data)
Expand All @@ -283,7 +312,7 @@ def _calculate_shap_values(self, X):
for label in local_classifier.classes_
]
predicted_class = current_node
else:
elif isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
simplified_labels = [
label.split(self.hierarchical_model.separator_)[-1]
for label in local_classifier.classes_
Expand All @@ -293,6 +322,12 @@ def _calculate_shap_values(self, X):
.flatten()[0]
.split(self.hierarchical_model.separator_)[-1]
)
else:
simplified_labels = [
label.split(self.hierarchical_model.separator_)[-1]
for label in local_classifier.classes_
]
predicted_class = current_node

classes = xr.DataArray(
simplified_labels,
Expand Down Expand Up @@ -326,7 +361,7 @@ def _calculate_shap_values(self, X):
"level": level,
}
)
level = level + 1
level += 1
datasets.append(local_dataset)
sample_explanation = xr.concat(datasets, dim="level")
return sample_explanation
55 changes: 51 additions & 4 deletions tests/test_Explainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import numpy as np
import pytest
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerNode, LocalClassifierPerParentNode, Explainer
from hiclass import (
LocalClassifierPerLevel,
LocalClassifierPerParentNode,
LocalClassifierPerNode,
Explainer,
)

try:
import shap
Expand Down Expand Up @@ -98,6 +103,26 @@ def test_explainer_tree_lcpn(data, request):
assert str(explanations["node"][i].data[j]) == y_pred[j]


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
def test_explainer_tree_lcpl(data, request):
    """Check that tree-mode explanations of an LCPL model report the predicted nodes."""
    model = LocalClassifierPerLevel(
        local_classifier=RandomForestClassifier(), replace_classifiers=False
    )

    x_train, x_test, y_train = request.getfixturevalue(data)

    model.fit(x_train, y_train)

    explanations = Explainer(model, data=x_train, mode="tree").explain(x_test)
    assert explanations is not None

    # Every "node" entry of the explanation must equal the label the model
    # predicted for the same sample at the same level.
    predicted_paths = model.predict(x_test)
    for sample_idx in range(len(x_test)):
        for level_idx, expected in enumerate(predicted_paths[sample_idx]):
            reported = explanations["node"][sample_idx].data[level_idx]
            assert str(reported) == expected


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
def test_traversal_path_lcppn(data, request):
Expand Down Expand Up @@ -142,11 +167,30 @@ def test_traversal_path_lcpn(data, request):
assert label == preds[i][j]


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
def test_traversal_path_lcpl(data, request):
    """Check that LCPL traversal paths end in the labels predicted at each level."""
    x_train, x_test, y_train = request.getfixturevalue(data)

    model = LocalClassifierPerLevel(
        local_classifier=RandomForestClassifier(), replace_classifiers=False
    )
    model.fit(x_train, y_train)

    explainer = Explainer(model, data=x_train, mode="tree")
    traversals = explainer._get_traversed_nodes_lcpl(x_test)
    preds = model.predict(x_test)
    assert len(preds) == len(traversals)

    # The last component of each traversed node name is the label predicted
    # at that depth; zip stops at the traversal's length, as the index loop did.
    for path, predicted_labels in zip(traversals, preds):
        for node, expected in zip(path, predicted_labels):
            assert node.split(model.separator_)[-1] == expected


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
@pytest.mark.parametrize(
"classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
"classifier",
[LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode],
)
def test_explain_with_xr(data, request, classifier):
x_train, x_test, y_train = request.getfixturevalue(data)
Expand All @@ -162,7 +206,8 @@ def test_explain_with_xr(data, request, classifier):


@pytest.mark.parametrize(
"classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
"classifier",
[LocalClassifierPerParentNode, LocalClassifierPerLevel, LocalClassifierPerNode],
)
def test_imports(classifier):
x_train = [[76, 12, 49], [88, 63, 31], [5, 42, 24], [17, 90, 55]]
Expand All @@ -176,8 +221,10 @@ def test_imports(classifier):
assert isinstance(explainer.data, np.ndarray)


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.parametrize(
"classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
"classifier",
[LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode],
)
@pytest.mark.parametrize("data", ["explainer_data"])
@pytest.mark.parametrize("mode", ["linear", "gradient", "deep", "tree", ""])
Expand Down

0 comments on commit 6f37990

Please sign in to comment.