Skip to content

Commit

Permalink
Merge pull request #840 from dianna-ai/fix_822
Browse files Browse the repository at this point in the history
Fix docstring of visualization tabular
  • Loading branch information
cwmeijer authored Aug 14, 2024
2 parents 15f3587 + 75093e6 commit 1274233
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
8 changes: 7 additions & 1 deletion dianna/visualization/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def plot_tabular(
"""Plot feature importance with segments highlighted.
Args:
x (np.ndarray): Array of feature importance scores
x (np.ndarray): 1D array of feature importance scores of one instance
y (List[str]): List of feature names
x_label (str): Label for the x-axis
y_label (str): Label or list of labels for the y-axis
Expand All @@ -32,6 +32,12 @@ def plot_tabular(
Returns:
plt.Figure
"""
# check type and shape of x should be 1D array
if not isinstance(x, np.ndarray):
raise TypeError("x should be a numpy array")
if x.ndim != 1:
raise ValueError("x should be a 1D array")

if not num_features:
num_features = len(x)
abs_values = [abs(i) for i in x]
Expand Down
8 changes: 8 additions & 0 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ def test_plot_tabular(tmpdir):

assert output_path.exists()

def test_plot_tabular_with_ndarray():
"""Test plot tabular data with ndarray."""
x = np.random.rand(5, 3)
y = [f"Feature {i}" for i in range(x.shape[1])]
# check ValueError
with pytest.raises(ValueError):
plot_tabular(x=x, y=y, show_plot=False)


def test_plot_timeseries_univariate(tmpdir, random):
"""Test plot univariate time series."""
Expand Down

0 comments on commit 1274233

Please sign in to comment.