Skip to content

Commit

Permalink
add tests for the shap tabular api
Browse files Browse the repository at this point in the history
  • Loading branch information
Yang committed Dec 12, 2023
1 parent 765dd67 commit ebefbe8
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tests/methods/test_shap_tabular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Test LIME tabular method."""
from unittest import TestCase
import numpy as np
import dianna
from dianna.methods.kernelshap_tabular import KERNELSHAPTabular
from tests.utils import run_model


class LIMEOnTabular(TestCase):
"""Suite of LIME tests for the tabular case."""

def test_shap_tabular_classification_correct_output_shape(self):
"""Test the output of explainer."""
training_data = np.random.random((10, 2))
input_data = np.random.random(2)
feature_names = ["feature_1", "feature_2"]
explainer = KERNELSHAPTabular(training_data,
mode ='classification',
feature_names=feature_names,)
exp = explainer.explain(
run_model,
input_data,
)
assert len(exp[0]) == len(feature_names)

def test_shap_tabular_regression_correct_output_shape(self):
"""Test the output of explainer."""
training_data = np.random.random((10, 2))
input_data = np.random.random(2)
feature_names = ["feature_1", "feature_2"]
exp = dianna.explain_tabular(run_model, input_tabular=input_data, method='kernelshap',
mode ='regression', training_data = training_data,
training_data_kmeans = 2, feature_names=feature_names)

assert len(exp) == len(feature_names)

0 comments on commit ebefbe8

Please sign in to comment.