Skip to content

Commit

Permalink
Merge branch 'main' into fix_plots
Browse files Browse the repository at this point in the history
  • Loading branch information
SarahAlidoost committed Mar 27, 2024
2 parents 75e2b9d + 6fe3359 commit 2165854
Show file tree
Hide file tree
Showing 72 changed files with 1,467 additions and 14,519 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 1.3.0
current_version = 1.4.0

[comment]
comment = The contents of this file cannot be merged with that of setup.cfg until https://github.com/c4urself/bump2version/issues/185 is resolved
Expand Down
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ coverage.xml
.tox
*word_vectors.txt.pt

# tutorial model that is downloaded automatically
apertif_frb_dynamic_spectrum_model.onnx

docs/_build

# ide
Expand All @@ -36,4 +39,4 @@ venv3
.python-version

cache/
dashboard/cache/
dashboard/cache/
2 changes: 1 addition & 1 deletion CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ authors:
name-particle: "van der"

doi: 10.5281/zenodo.5801485
version: "1.3.0"
version: "1.4.0"
repository-code: "https://github.com/dianna-ai/dianna"
keywords:
- XAI
Expand Down
43 changes: 21 additions & 22 deletions dianna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@

__author__ = 'DIANNA Team'
__email__ = '[email protected]'
__version__ = '1.3.0'
__version__ = '1.4.0'


def explain_timeseries(model_or_function, input_timeseries, method, labels, **kwargs):
def explain_timeseries(model_or_function, input_timeseries, method, labels,
**kwargs):
"""Explain timeseries data given a model and a chosen method.
Args:
Expand All @@ -48,15 +49,13 @@ def explain_timeseries(model_or_function, input_timeseries, method, labels, **kw
"""
explainer = _get_explainer(method, kwargs, modality='Timeseries')
explain_timeseries_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs
)
explainer.explain, kwargs)
for key in explain_timeseries_kwargs.keys():
kwargs.pop(key)
if kwargs:
raise TypeError(f'Error due to following unused kwargs: {kwargs}')
return explainer.explain(
model_or_function, input_timeseries, labels, **explain_timeseries_kwargs
)
return explainer.explain(model_or_function, input_timeseries, labels,
**explain_timeseries_kwargs)


def explain_image(model_or_function, input_image, method, labels, **kwargs):
Expand All @@ -79,18 +78,17 @@ def explain_image(model_or_function, input_image, method, labels, **kwargs):
from onnx_tf.backend import prepare # noqa: F401
explainer = _get_explainer(method, kwargs, modality='Image')
explain_image_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs
)
explainer.explain, kwargs)
for key in explain_image_kwargs.keys():
kwargs.pop(key)
if kwargs:
raise TypeError(f'Error due to following unused kwargs: {kwargs}')
return explainer.explain(
model_or_function, input_image, labels, **explain_image_kwargs
)
return explainer.explain(model_or_function, input_image, labels,
**explain_image_kwargs)


def explain_text(model_or_function, input_text, tokenizer, method, labels, **kwargs):
def explain_text(model_or_function, input_text, tokenizer, method, labels,
**kwargs):
"""Explain text (input_text) given a model and a chosen method.
Args:
Expand All @@ -108,8 +106,7 @@ def explain_text(model_or_function, input_text, tokenizer, method, labels, **kwa
"""
explainer = _get_explainer(method, kwargs, modality='Text')
explain_text_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs
)
explainer.explain, kwargs)
for key in explain_text_kwargs.keys():
kwargs.pop(key)
if kwargs:
Expand All @@ -123,7 +120,11 @@ def explain_text(model_or_function, input_text, tokenizer, method, labels, **kwa
)


def explain_tabular(model_or_function, input_tabular, method, labels=(1, ), **kwargs):
def explain_tabular(model_or_function,
input_tabular,
method,
labels=(1, ),
**kwargs):
"""Explain tabular (input_text) given a model and a chosen method.
Args:
Expand All @@ -139,8 +140,7 @@ def explain_tabular(model_or_function, input_tabular, method, labels=(1, ), **kw
"""
explainer = _get_explainer(method, kwargs, modality='Tabular')
explain_tabular_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs
)
explainer.explain, kwargs)
for key in explain_tabular_kwargs.keys():
kwargs.pop(key)
if kwargs:
Expand All @@ -152,11 +152,11 @@ def explain_tabular(model_or_function, input_tabular, method, labels=(1, ), **kw
**explain_tabular_kwargs,
)


def _get_explainer(method, kwargs, modality):
try:
method_submodule = importlib.import_module(
f'dianna.methods.{method.lower()}_{modality.lower()}'
)
f'dianna.methods.{method.lower()}_{modality.lower()}')
except ImportError as err:
raise ValueError(
f'Method {method.lower()}_{modality.lower()} does not exist'
Expand All @@ -168,8 +168,7 @@ def _get_explainer(method, kwargs, modality):
f'Data modality {modality} is not available for method {method.upper()}'
) from err
method_kwargs = utils.get_kwargs_applicable_to_function(
method_class.__init__, kwargs
)
method_class.__init__, kwargs)
# Remove used kwargs from list of kwargs passed to the function.
for key in method_kwargs.keys():
kwargs.pop(key)
Expand Down
4 changes: 2 additions & 2 deletions dianna/dashboard/_movie_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import numpy as np
from _shared import data_directory
from _shared import label_directory
from scipy.special import expit as sigmoid
from torchtext.vocab import Vectors
from dianna import utils
Expand All @@ -13,7 +13,7 @@ class MovieReviewsModelRunner:
def __init__(self, model, word_vectors=None, max_filter_size=5):
"""Initializes the class."""
if word_vectors is None:
word_vectors = data_directory / 'movie_reviews_word_vectors.txt'
word_vectors = label_directory / 'movie_reviews_word_vectors.txt'

self.run_model = utils.get_function(model)
self.vocab = Vectors(word_vectors, cache=os.path.dirname(word_vectors))
Expand Down
3 changes: 2 additions & 1 deletion dianna/dashboard/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from importlib.resources import files

data_directory = files('dianna.data')

model_directory = files('dianna.models')
label_directory = files('dianna.labels')

@st.cache_data
def get_base64_of_bin_file(png_file):
Expand Down
6 changes: 4 additions & 2 deletions dianna/dashboard/pages/1_Images.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from _shared import _methods_checkboxes
from _shared import add_sidebar_logo
from _shared import data_directory
from _shared import label_directory
from _shared import model_directory
from dianna.visualization import plot_image

add_sidebar_logo()
Expand Down Expand Up @@ -37,8 +39,8 @@

if load_example:
image_file = (data_directory / 'digit0.png')
image_model_file = (data_directory / 'mnist_model_tf.onnx')
image_label_file = (data_directory / 'labels_mnist.txt')
image_model_file = (model_directory / 'mnist_model_tf.onnx')
image_label_file = (label_directory / 'labels_mnist.txt')

if not (image_file and image_model_file and image_label_file):
st.info('Add your input data in the left panel to continue')
Expand Down
8 changes: 4 additions & 4 deletions dianna/dashboard/pages/2_Text.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from _shared import _get_top_indices_and_labels
from _shared import _methods_checkboxes
from _shared import add_sidebar_logo
from _shared import data_directory
from dianna.visualization.text import highlight_text
from _shared import label_directory
from _shared import model_directory

add_sidebar_logo()

Expand All @@ -35,8 +35,8 @@

if load_example:
text_input = 'The movie started out great but the ending was dissappointing'
text_model_file = data_directory / 'movie_review_model.onnx'
text_label_file = data_directory / 'labels_text.txt'
text_model_file = model_directory / 'movie_review_model.onnx'
text_label_file = label_directory / 'labels_text.txt'

if not (text_input and text_model_file and text_label_file):
st.info('Add your input data in the left panel to continue')
Expand Down
6 changes: 4 additions & 2 deletions dianna/dashboard/pages/3_Time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from _shared import _methods_checkboxes
from _shared import add_sidebar_logo
from _shared import data_directory
from _shared import label_directory
from _shared import model_directory
from _ts_utils import _convert_to_segments
from _ts_utils import open_timeseries
from dianna.visualization import plot_timeseries
Expand All @@ -34,9 +36,9 @@

if load_example:
ts_file = (data_directory / 'weather_data.npy')
ts_model_file = (data_directory /
ts_model_file = (model_directory /
'season_prediction_model_temp_max_binary.onnx')
ts_label_file = (data_directory / 'weather_data_labels.txt')
ts_label_file = (label_directory / 'weather_data_labels.txt')

if not (ts_file and ts_model_file and ts_label_file):
st.info('Add your input data in the left panel to continue')
Expand Down
Binary file added dianna/data/FRB211024.npy
Binary file not shown.
File renamed without changes
Binary file removed dianna/data/bee_2.png
Binary file not shown.
File renamed without changes.
2 changes: 2 additions & 0 deletions dianna/labels/apertif_frb_classes.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Noise
FRB
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
10 changes: 5 additions & 5 deletions dianna/methods/rise_image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from dianna import utils
from dianna.utils.maskers import generate_masks_for_images
from dianna.utils.maskers import generate_interpolated_float_masks_for_image
from dianna.utils.predict import make_predictions
from dianna.utils.rise_utils import normalize

Expand Down Expand Up @@ -60,8 +60,8 @@ def explain(self, model_or_function, input_data, labels, batch_size=100):
# data shape without batch axis and channel axis
img_shape = input_data.shape[1:3]
# Expose masks to make user inspection possible
self.masks = generate_masks_for_images(img_shape, self.n_masks,
active_p_keep, self.feature_res)
self.masks = generate_interpolated_float_masks_for_image(
img_shape, active_p_keep, self.n_masks, self.feature_res)

# Make sure multiplication is being done for correct axes
masked = input_data * self.masks
Expand Down Expand Up @@ -117,8 +117,8 @@ def _determine_p_keep(self, input_data, runner, n_masks=100):

def _calculate_max_class_std(self, p_keep, runner, input_data, n_masks):
img_shape = input_data.shape[1:3]
masks = generate_masks_for_images(img_shape, n_masks, p_keep,
self.feature_res)
masks = generate_interpolated_float_masks_for_image(
img_shape, p_keep, n_masks, self.feature_res)
masked = input_data * masks
predictions = make_predictions(masked, runner, batch_size=50)
std_per_class = predictions.std(axis=0)
Expand Down
53 changes: 36 additions & 17 deletions dianna/methods/rise_timeseries.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np
from dianna import utils
from dianna.utils.maskers import generate_masks
from dianna.utils.maskers import mask_data
Expand All @@ -8,25 +10,37 @@
class RISETimeseries:
"""RISE implementation for timeseries adapted from the image version of RISE."""

def __init__(self,
n_masks=1000,
feature_res=8,
p_keep=0.5,
preprocess_function=None):
def __init__(
self,
n_masks: int = 1000,
feature_res: int = 8,
p_keep: float = 0.5,
preprocess_function: Optional[callable] = None,
keep_masks: bool = False,
keep_masked_data: bool = False,
keep_predictions: bool = False,
) -> np.ndarray:
"""RISE initializer.
Args:
n_masks (int): Number of masks to generate.
feature_res (int): Resolution of features in masks.
p_keep (float): Fraction of input data to keep in each mask (Default: auto-tune this value).
preprocess_function (callable, optional): Function to preprocess input data with
n_masks: Number of masks to generate.
feature_res: Resolution of features in masks.
p_keep: Fraction of input data to keep in each mask (Default: auto-tune this value).
preprocess_function: Function to preprocess input data with
keep_masks: keep masks in memory for the user to inspect
keep_masked_data: keep masked data in memory for the user to inspect
keep_predictions: keep model predictions in memory for the user to inspect
"""
self.n_masks = n_masks
self.feature_res = feature_res
self.p_keep = p_keep
self.preprocess_function = preprocess_function
self.masks = None
self.masked = None
self.predictions = None
self.keep_masks = keep_masks
self.keep_masked_data = keep_masked_data
self.keep_predictions = keep_predictions

def explain(self,
model_or_function,
Expand All @@ -47,20 +61,25 @@ def explain(self,
labels (Iterable(int)): Labels to be explained
mask_type: Masking strategy for masked values. Choose from 'mean' or a callable(input_timeseries)
Returns:
Explanation heatmap for each class (np.ndarray).
"""
runner = utils.get_function(
model_or_function, preprocess_function=self.preprocess_function)
self.masks = generate_masks(input_timeseries,
number_of_masks=self.n_masks,
p_keep=self.p_keep)
masked = mask_data(input_timeseries, self.masks, mask_type=mask_type)

self.predictions = make_predictions(masked, runner, batch_size)
n_labels = self.predictions.shape[1]
masks = generate_masks(input_timeseries,
number_of_masks=self.n_masks,
feature_res=self.feature_res,
p_keep=self.p_keep)
self.masks = masks if self.keep_masks else None
masked = mask_data(input_timeseries, masks, mask_type=mask_type)
self.masked = masked if self.keep_masked_data else None
predictions = make_predictions(masked, runner, batch_size)
self.predictions = predictions if self.keep_predictions else None
n_labels = predictions.shape[1]

saliency = self.predictions.T.dot(self.masks.reshape(
self.n_masks, -1)).reshape(n_labels, *input_timeseries.shape)
saliency = predictions.T.dot(masks.reshape(self.n_masks, -1)).reshape(
n_labels, *input_timeseries.shape)
selected_saliency = saliency[labels]
return normalize(selected_saliency, self.n_masks, self.p_keep)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 2165854

Please sign in to comment.