Merge branch 'main' into 667-reorganize
loostrum committed Mar 20, 2024
2 parents f4aa7a1 + 452838f commit 46ce7dc
Showing 18 changed files with 927 additions and 226 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 1.3.0
current_version = 1.4.0

[comment]
comment = The contents of this file cannot be merged with that of setup.cfg until https://github.com/c4urself/bump2version/issues/185 is resolved
2 changes: 1 addition & 1 deletion CITATION.cff
@@ -49,7 +49,7 @@ authors:
name-particle: "van der"

doi: 10.5281/zenodo.5801485
version: "1.3.0"
version: "1.4.0"
repository-code: "https://github.com/dianna-ai/dianna"
keywords:
- XAI
6 changes: 3 additions & 3 deletions README.md
@@ -290,15 +290,15 @@ And here are links to notebooks showing how we created our models on the benchma

| Models | Generation |
| :-------------------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Coffee model | [Coffee model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/coffee/generate_model.ipynb) |
| [Coffee model](https://zenodo.org/records/10579458) | [Coffee model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/coffee/generate_model.ipynb) |
| [Season prediction model](https://zenodo.org/record/7543883) | [Season prediction model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/season_prediction/generate_model.ipynb) |

### Tabular

| Models | Generation |
| :-------------------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Penguin model (classification) | [Penguin model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/penguin_species/generate_model.ipynb) |
| Sunshine hours prediction model (regression) | [Sunshine hours prediction model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/sunshine_prediction/generate_model.ipynb) |
| [Penguin model (classification)](https://zenodo.org/records/10580743) | [Penguin model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/penguin_species/generate_model.ipynb) |
| [Sunshine hours prediction model (regression)](https://zenodo.org/records/10580833) | [Sunshine hours prediction model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/sunshine_prediction/generate_model.ipynb) |


**_We envision the birth of the ONNX Scientific models zoo soon..._**
52 changes: 25 additions & 27 deletions dianna/__init__.py
@@ -22,17 +22,17 @@
"""
import importlib
import logging
import warnings
from . import utils

logging.getLogger(__name__).addHandler(logging.NullHandler())

__author__ = 'DIANNA Team'
__email__ = '[email protected]'
__version__ = '1.3.0'
__version__ = '1.4.0'


def explain_timeseries(model_or_function, input_timeseries, method, labels, **kwargs):
def explain_timeseries(model_or_function, input_timeseries, method, labels,
**kwargs):
"""Explain timeseries data given a model and a chosen method.
Args:
@@ -49,15 +49,13 @@ def explain_timeseries(model_or_function, input_timeseries, method, labels, **kw
"""
explainer = _get_explainer(method, kwargs, modality='Timeseries')
explain_timeseries_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs
)
explainer.explain, kwargs)
for key in explain_timeseries_kwargs.keys():
kwargs.pop(key)
if kwargs:
warnings.warn(message = f'Please note the following kwargs are not being used: {kwargs}')
return explainer.explain(
model_or_function, input_timeseries, labels, **explain_timeseries_kwargs
)
raise TypeError(f'Error due to following unused kwargs: {kwargs}')
return explainer.explain(model_or_function, input_timeseries, labels,
**explain_timeseries_kwargs)


def explain_image(model_or_function, input_image, method, labels, **kwargs):
@@ -80,18 +78,17 @@ def explain_image(model_or_function, input_image, method, labels, **kwargs):
from onnx_tf.backend import prepare # noqa: F401
explainer = _get_explainer(method, kwargs, modality='Image')
explain_image_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs
)
explainer.explain, kwargs)
for key in explain_image_kwargs.keys():
kwargs.pop(key)
if kwargs:
warnings.warn(message = f'Please note the following kwargs are not being used: {kwargs}')
return explainer.explain(
model_or_function, input_image, labels, **explain_image_kwargs
)
raise TypeError(f'Error due to following unused kwargs: {kwargs}')
return explainer.explain(model_or_function, input_image, labels,
**explain_image_kwargs)


def explain_text(model_or_function, input_text, tokenizer, method, labels, **kwargs):
def explain_text(model_or_function, input_text, tokenizer, method, labels,
**kwargs):
"""Explain text (input_text) given a model and a chosen method.
Args:
@@ -109,12 +106,11 @@ def explain_text(model_or_function, input_text, tokenizer, method, labels, **kwa
"""
explainer = _get_explainer(method, kwargs, modality='Text')
explain_text_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs
)
explainer.explain, kwargs)
for key in explain_text_kwargs.keys():
kwargs.pop(key)
if kwargs:
warnings.warn(message = f'Please note the following kwargs are not being used: {kwargs}')
raise TypeError(f'Error due to following unused kwargs: {kwargs}')
return explainer.explain(
model_or_function=model_or_function,
input_text=input_text,
@@ -124,7 +120,11 @@ def explain_text(model_or_function, input_text, tokenizer, method, labels, **kwa
)


def explain_tabular(model_or_function, input_tabular, method, labels=(1, ), **kwargs):
def explain_tabular(model_or_function,
input_tabular,
method,
labels=(1, ),
**kwargs):
"""Explain tabular (input_text) given a model and a chosen method.
Args:
@@ -140,24 +140,23 @@ def explain_tabular(model_or_function, input_tabular, method, labels=(1, ), **kw
"""
explainer = _get_explainer(method, kwargs, modality='Tabular')
explain_tabular_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs
)
explainer.explain, kwargs)
for key in explain_tabular_kwargs.keys():
kwargs.pop(key)
if kwargs:
warnings.warn(message = f'Please note the following kwargs are not being used: {kwargs}')
raise TypeError(f'Error due to following unused kwargs: {kwargs}')
return explainer.explain(
model_or_function=model_or_function,
input_tabular=input_tabular,
labels=labels,
**explain_tabular_kwargs,
)


def _get_explainer(method, kwargs, modality):
try:
method_submodule = importlib.import_module(
f'dianna.methods.{method.lower()}_{modality.lower()}'
)
f'dianna.methods.{method.lower()}_{modality.lower()}')
except ImportError as err:
raise ValueError(
f'Method {method.lower()}_{modality.lower()} does not exist'
@@ -169,8 +168,7 @@ def _get_explainer(method, kwargs, modality):
f'Data modality {modality} is not available for method {method.upper()}'
) from err
method_kwargs = utils.get_kwargs_applicable_to_function(
method_class.__init__, kwargs
)
method_class.__init__, kwargs)
# Remove used kwargs from list of kwargs passed to the function.
for key in method_kwargs.keys():
kwargs.pop(key)
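Across all four `explain_*` entry points in `dianna/__init__.py`, keyword arguments that no explainer consumes now raise a `TypeError` instead of only emitting a warning. Below is a minimal sketch of what that looks like at a call site, assuming DIANNA and its ONNX backend dependencies are installed; the toy callable model, the input shape, and the `bogus_option` name are made up for illustration and are not part of this commit:

```python
import numpy as np
import dianna

def toy_model(batch):
    """Stand-in for model_or_function: takes a batch of images and
    returns a score per class, shape (batch_size, 2)."""
    scores = batch.mean(axis=(1, 2, 3))
    return np.stack([scores, 1.0 - scores], axis=1)

image = np.random.random((32, 32, 3)).astype(np.float32)

try:
    # 'bogus_option' is not a parameter of the RISE explainer, so after this
    # commit the call fails loudly instead of silently ignoring the kwarg.
    # The error is raised before the model is ever invoked.
    dianna.explain_image(toy_model, image, method='RISE', labels=[0],
                         n_masks=50, bogus_option=123)
except TypeError as err:
    print(err)  # Error due to following unused kwargs: {'bogus_option': 123}
```

Failing fast here is arguably safer than the previous `warnings.warn`: a typo in a tuning parameter (for example `n_mask` instead of `n_masks`) can no longer be silently dropped.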
Binary file removed dianna/data/bee.jpg
10 changes: 5 additions & 5 deletions dianna/methods/rise_image.py
@@ -1,6 +1,6 @@
import numpy as np
from dianna import utils
from dianna.utils.maskers import generate_masks_for_images
from dianna.utils.maskers import generate_interpolated_float_masks_for_image
from dianna.utils.predict import make_predictions
from dianna.utils.rise_utils import normalize

@@ -60,8 +60,8 @@ def explain(self, model_or_function, input_data, labels, batch_size=100):
# data shape without batch axis and channel axis
img_shape = input_data.shape[1:3]
# Expose masks for to make user inspection possible
self.masks = generate_masks_for_images(img_shape, self.n_masks,
active_p_keep, self.feature_res)
self.masks = generate_interpolated_float_masks_for_image(
img_shape, active_p_keep, self.n_masks, self.feature_res)

# Make sure multiplication is being done for correct axes
masked = input_data * self.masks
@@ -117,8 +117,8 @@ def _determine_p_keep(self, input_data, runner, n_masks=100):

def _calculate_max_class_std(self, p_keep, runner, input_data, n_masks):
img_shape = input_data.shape[1:3]
masks = generate_masks_for_images(img_shape, n_masks, p_keep,
self.feature_res)
masks = generate_interpolated_float_masks_for_image(
img_shape, p_keep, n_masks, self.feature_res)
masked = input_data * masks
predictions = make_predictions(masked, runner, batch_size=50)
std_per_class = predictions.std(axis=0)
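`rise_image.py` now builds its masks with `generate_interpolated_float_masks_for_image`, whose argument order is image shape, `p_keep`, number of masks, feature resolution. The commit does not show that helper's body, so the snippet below is only a conceptual sketch of the standard RISE recipe such a function usually follows (a coarse random binary grid, upsampled smoothly and cropped at a random offset to give float masks in [0, 1]); it is not the actual `dianna.utils.maskers` implementation:

```python
import numpy as np
from scipy.ndimage import zoom

def interpolated_float_masks(img_shape, p_keep, n_masks, feature_res, seed=None):
    """Conceptual RISE-style masks: low-resolution Bernoulli grid, bilinear
    upsampling, then a random crop so mask edges do not align with a fixed grid."""
    rng = np.random.default_rng(seed)
    cell_h = int(np.ceil(img_shape[0] / feature_res))
    cell_w = int(np.ceil(img_shape[1] / feature_res))
    up_h, up_w = (feature_res + 1) * cell_h, (feature_res + 1) * cell_w
    masks = np.empty((n_masks, *img_shape), dtype=np.float32)
    for i in range(n_masks):
        # 1 means "keep this region", drawn with probability p_keep.
        grid = (rng.random((feature_res, feature_res)) < p_keep).astype(np.float32)
        # Bilinear upsampling to slightly larger than the image.
        big = zoom(grid, (up_h / feature_res, up_w / feature_res), order=1)
        # Random offset of up to one grid cell before cropping to image size.
        dy, dx = rng.integers(0, cell_h), rng.integers(0, cell_w)
        masks[i] = big[dy:dy + img_shape[0], dx:dx + img_shape[1]]
    return masks

masks = interpolated_float_masks((28, 28), p_keep=0.5, n_masks=10, feature_res=8)
print(masks.shape, masks.min(), masks.max())  # (10, 28, 28), values within [0, 1]
```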
53 changes: 36 additions & 17 deletions dianna/methods/rise_timeseries.py
@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np
from dianna import utils
from dianna.utils.maskers import generate_masks
from dianna.utils.maskers import mask_data
@@ -8,25 +10,37 @@
class RISETimeseries:
"""RISE implementation for timeseries adapted from the image version of RISE."""

def __init__(self,
n_masks=1000,
feature_res=8,
p_keep=0.5,
preprocess_function=None):
def __init__(
self,
n_masks: int = 1000,
feature_res: int = 8,
p_keep: float = 0.5,
preprocess_function: Optional[callable] = None,
keep_masks: bool = False,
keep_masked_data: bool = False,
keep_predictions: bool = False,
) -> np.ndarray:
"""RISE initializer.
Args:
n_masks (int): Number of masks to generate.
feature_res (int): Resolution of features in masks.
p_keep (float): Fraction of input data to keep in each mask (Default: auto-tune this value).
preprocess_function (callable, optional): Function to preprocess input data with
n_masks: Number of masks to generate.
feature_res: Resolution of features in masks.
p_keep: Fraction of input data to keep in each mask (Default: auto-tune this value).
preprocess_function: Function to preprocess input data with
keep_masks: keep masks in memory for the user to inspect
keep_masked_data: keep masked data in memory for the user to inspect
keep_predictions: keep model predictions in memory for the user to inspect
"""
self.n_masks = n_masks
self.feature_res = feature_res
self.p_keep = p_keep
self.preprocess_function = preprocess_function
self.masks = None
self.masked = None
self.predictions = None
self.keep_masks = keep_masks
self.keep_masked_data = keep_masked_data
self.keep_predictions = keep_predictions

def explain(self,
model_or_function,
@@ -47,20 +61,25 @@ def explain(self,
labels (Iterable(int)): Labels to be explained
mask_type: Masking strategy for masked values. Choose from 'mean' or a callable(input_timeseries)
Returns:
Explanation heatmap for each class (np.ndarray).
"""
runner = utils.get_function(
model_or_function, preprocess_function=self.preprocess_function)
self.masks = generate_masks(input_timeseries,
number_of_masks=self.n_masks,
p_keep=self.p_keep)
masked = mask_data(input_timeseries, self.masks, mask_type=mask_type)

self.predictions = make_predictions(masked, runner, batch_size)
n_labels = self.predictions.shape[1]
masks = generate_masks(input_timeseries,
number_of_masks=self.n_masks,
feature_res=self.feature_res,
p_keep=self.p_keep)
self.masks = masks if self.keep_masks else None
masked = mask_data(input_timeseries, masks, mask_type=mask_type)
self.masked = masked if self.keep_masked_data else None
predictions = make_predictions(masked, runner, batch_size)
self.predictions = predictions if self.keep_predictions else None
n_labels = predictions.shape[1]

saliency = self.predictions.T.dot(self.masks.reshape(
self.n_masks, -1)).reshape(n_labels, *input_timeseries.shape)
saliency = predictions.T.dot(masks.reshape(self.n_masks, -1)).reshape(
n_labels, *input_timeseries.shape)
selected_saliency = saliency[labels]
return normalize(selected_saliency, self.n_masks, self.p_keep)
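The timeseries explainer gains three opt-in flags, `keep_masks`, `keep_masked_data` and `keep_predictions`: intermediate arrays are now stored on the instance only when explicitly requested, instead of `self.masks` and `self.predictions` always being kept. A usage sketch follows; the toy model, the `(n_steps, 1)` input layout and the keyword-argument call style are assumptions for illustration, not something this commit shows:

```python
import numpy as np
from dianna.methods.rise_timeseries import RISETimeseries

def toy_model(batch):
    """Stand-in classifier: scores two classes from the mean of each masked series."""
    score = batch.mean(axis=tuple(range(1, batch.ndim)))
    return np.stack([score, -score], axis=1)

series = np.sin(np.linspace(0, 10, 100))[:, np.newaxis]  # assumed (n_steps, n_channels)

explainer = RISETimeseries(n_masks=200, feature_res=8, p_keep=0.5,
                           keep_masks=True, keep_predictions=True)
heatmaps = explainer.explain(toy_model, series, labels=[0], mask_type='mean')

print(heatmaps.shape)               # one heatmap per requested label
print(explainer.masks.shape)        # kept because keep_masks=True
print(explainer.masked)             # None: keep_masked_data defaults to False
print(explainer.predictions.shape)  # kept because keep_predictions=True
```

Only holding on to these arrays on request keeps memory usage bounded for large `n_masks`, while still letting a user inspect exactly which masks and predictions produced a given heatmap.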