Skip to content

Commit

Permalink
Merge pull request #832 from dianna-ai/790-add-scientific-use-case-fr…
Browse files Browse the repository at this point in the history
…b-to-dashboard

790 add scientific use case frb to dashboard
  • Loading branch information
laurasootes authored Aug 13, 2024
2 parents 76f9c36 + 0bb0693 commit 93c624e
Show file tree
Hide file tree
Showing 6 changed files with 286 additions and 111 deletions.
8 changes: 8 additions & 0 deletions dianna/dashboard/_models_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ def predict(*, model, ts_data):

@st.cache_data
def _run_rise_timeseries(_model, ts_data, **kwargs):
# convert streamlit kwarg requirement back to dianna kwarg requirement
if "_preprocess_function" in kwargs:
kwargs["preprocess_function"] = kwargs["_preprocess_function"]
del kwargs["_preprocess_function"]

def run_model(ts_data):
return predict(model=_model, ts_data=ts_data)
Expand All @@ -37,6 +41,10 @@ def run_model(ts_data):

@st.cache_data
def _run_lime_timeseries(_model, ts_data, **kwargs):
# convert streamlit kwarg requirement back to dianna kwarg requirement
if "_preprocess_function" in kwargs:
kwargs["preprocess_function"] = kwargs["_preprocess_function"]
del kwargs["_preprocess_function"]

def run_model(ts_data):
return predict(model=_model, ts_data=ts_data)
Expand Down
14 changes: 14 additions & 0 deletions dianna/dashboard/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,17 @@ def _get_top_indices_and_labels(*, predictions, labels):
st.metric('Predicted class', top_labels[0])

return top_indices, top_labels

def reset_method():
# Clear selection
for k in st.session_state.keys():
if '_cb_' in k:
st.session_state[k] = False
if 'params' in k:
st.session_state.pop(k)

def reset_example():
# Clear selection
for k in st.session_state.keys():
if '_load_' in k:
st.session_state.pop(k)
93 changes: 58 additions & 35 deletions dianna/dashboard/pages/Images.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from _shared import _get_top_indices_and_labels
from _shared import _methods_checkboxes
from _shared import add_sidebar_logo
from _shared import data_directory
from _shared import label_directory
from _shared import model_directory
from _shared import reset_example
from _shared import reset_method
from dianna.utils.downloader import download
from dianna.visualization import plot_image

add_sidebar_logo()
Expand All @@ -19,40 +19,63 @@

st.sidebar.header('Input data')

load_example_digits = st.sidebar.checkbox('Load hand-written digits example',
key='Image_digits_example_check')

image_file = st.sidebar.file_uploader('Select image',
type=('png', 'jpg', 'jpeg'),
disabled=load_example_digits)

if image_file:
st.sidebar.image(image_file)

image_model_file = st.sidebar.file_uploader('Select model',
type='onnx',
disabled=load_example_digits)

image_label_file = st.sidebar.file_uploader('Select labels',
type='txt',
disabled=load_example_digits)

if load_example_digits:
image_file = (data_directory / 'digit0.jpg')
image_model_file = (model_directory / 'mnist_model_tf.onnx')
image_label_file = (label_directory / 'labels_mnist.txt')

st.markdown(
"""
This example demonstrates the use of DIANNA on a pretrained binary
[MNIST](https://yann.lecun.com/exdb/mnist/) model using hand-written
digit images. The model predicts for an image of a hand-written 0 or 1,
which of the two it most likely is. This example visualizes the
relevance attributions for each pixel/super-pixel by displaying them on
top of the input image.
"""
input_type = st.sidebar.radio(
label='Select which input to use',
options = ('Use an example', 'Use your own data'),
index = None,
on_change = reset_example,
key = 'Image_input_type'
)

# Use the examples
if input_type == 'Use an example':
load_example = st.sidebar.radio(
label='Load example',
options=('Hand-written digit recognition',),
index = None,
on_change = reset_method,
key='Image_load_example'
)

if load_example == 'Hand-written digit recognition':
image_file = download('digit0.jpg', 'data')
image_model_file = download('mnist_model_tf.onnx', 'model')
image_label_file = download('labels_mnist.txt', 'label')

st.markdown(
"""
This example demonstrates the use of DIANNA on a pretrained binary
[MNIST](https://yann.lecun.com/exdb/mnist/) model using a hand-written digit images.
The model predict for an image of a hand-written 0 or 1, which of the two it most
likely is.
This example visualizes the relevance attributions for each pixel/super-pixel by
displaying them on top of the input image.
"""
)
else:
st.info('Select an example in the left panel to coninue')
st.stop()

# Option to upload your own data
if input_type == 'Use your own data':
load_example = None

image_file = st.sidebar.file_uploader('Select image',
type=('png', 'jpg', 'jpeg'))

if image_file:
st.sidebar.image(image_file)

image_model_file = st.sidebar.file_uploader('Select model',
type='onnx')

image_label_file = st.sidebar.file_uploader('Select labels',
type='txt')

if input_type is None:
st.info('Select which input type to use in the left panel to continue')
st.stop()

if not (image_file and image_model_file and image_label_file):
st.info('Add your input data in the left panel to continue')
st.stop()
Expand Down
70 changes: 46 additions & 24 deletions dianna/dashboard/pages/Text.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from _shared import _get_top_indices_and_labels
from _shared import _methods_checkboxes
from _shared import add_sidebar_logo
from _shared import label_directory
from _shared import model_directory
from _shared import reset_example
from _shared import reset_method
from dianna.utils.downloader import download
from dianna.visualization.text import highlight_text

add_sidebar_logo()
Expand All @@ -18,35 +19,56 @@

st.sidebar.header('Input data')

load_example_moviesentiment = st.sidebar.checkbox('Load movie sentiment example',
key='Text_example_check_moviesentiment')

text_input = st.sidebar.text_input('Input string', disabled=load_example_moviesentiment)

if text_input:
st.sidebar.write(text_input)

text_model_file = st.sidebar.file_uploader('Select model',
type='onnx',
disabled=load_example_moviesentiment)

text_label_file = st.sidebar.file_uploader('Select labels',
type='txt',
disabled=load_example_moviesentiment)

if load_example_moviesentiment:
text_input = 'The movie started out great but the ending was dissappointing'
text_model_file = model_directory / 'movie_review_model.onnx'
text_label_file = label_directory / 'labels_text.txt'

st.markdown(
input_type = st.sidebar.radio(
label='Select which input to use',
options = ('Use an example', 'Use your own data'),
index = None,
on_change = reset_example,
key = 'Text_input_type'
)

# Use the examples
if input_type == 'Use an example':
load_example = st.sidebar.radio(
label='Use example',
options=('Movie sentiment',),
index = None,
on_change = reset_method,
key='Text_example_check_moviesentiment')

if load_example == 'Movie sentiment':
text_input = 'The movie started out great but the ending was dissappointing'
text_model_file = download('movie_review_model.onnx', 'model')
text_label_file = download('labels_text.txt', 'label')

st.markdown(
"""
This example demonstrates the use of DIANNA on the [Stanford Sentiment
Treebank dataset](https://nlp.stanford.edu/sentiment/index.html) which
contains one-sentence movie reviews. A pre-trained neural network
classifier is used, which identifies whether a movie review is positive
or negative.
""")
else:
st.info('Select an example in the left panel to coninue')
st.stop()

# Option to upload your own data
if input_type == 'Use your own data':
text_input = st.sidebar.text_input('Input string')

if text_input:
st.sidebar.write(text_input)

text_model_file = st.sidebar.file_uploader('Select model',
type='onnx')

text_label_file = st.sidebar.file_uploader('Select labels',
type='txt')

if input_type is None:
st.info('Select which input type to use in the left panel to continue')
st.stop()

if not (text_input and text_model_file and text_label_file):
st.info('Add your input data in the left panel to continue')
Expand Down
Loading

0 comments on commit 93c624e

Please sign in to comment.