diff --git a/syndat/visualization.py b/syndat/visualization.py
index 8d7c42c..fe91c50 100644
--- a/syndat/visualization.py
+++ b/syndat/visualization.py
@@ -69,8 +69,18 @@ def plot_correlations(real: pandas.DataFrame, synthetic: pandas.DataFrame, store
         fig = ax.get_figure()
         fig.savefig(store_destination + "/" + names[idx] + '.png', bbox_inches="tight")
 
-def plot_shap_discrimination(real: pandas.DataFrame, synthetic: pandas.DataFrame) -> None:
-    # Assuming 'real' and 'synthetic_no_dp' are your datasets and are pandas DataFrames
+def plot_shap_discrimination(real: pd.DataFrame, synthetic: pd.DataFrame, save_path: str = None) -> None:
+    """
+    Generates a SHAP summary plot to illustrate the discrimination between real and synthetic datasets
+    using a Random Forest classifier.
+
+    :param real: The real data
+    :param synthetic: The synthetic data
+    :param save_path: Path to the file where the resulting plot should be saved. If None, the plot will not be saved.
+
+    :return: None
+    """
+    # Assuming 'real' and 'synthetic' are your datasets and are pandas DataFrames
     # Add a label column to each dataset
     real['label'] = 1
     synthetic['label'] = 0
@@ -99,7 +109,16 @@ def plot_shap_discrimination(real: pandas.DataFrame, synthetic: pandas.DataFrame
     shap_values = explainer.shap_values(X_test)
 
     # Plot SHAP summary
-    shap.summary_plot(shap_values[1], X_test)
+    plt.figure()
+    shap.summary_plot(shap_values[1], X_test, show=False)
+
+    # Save the plot if save_path is specified
+    if save_path:
+        plt.savefig(save_path, bbox_inches='tight')
+        print(f"Plot saved to {save_path}")
+
+    # Show the plot
+    plt.show()
 
 
 def plot_categorical_feature(feature: str, real_data: pandas.DataFrame, synthetic_data: pandas.DataFrame) -> None:
diff --git a/tests/test_visualization.py b/tests/test_visualization.py
new file mode 100644
index 0000000..b69c754
--- /dev/null
+++ b/tests/test_visualization.py
@@ -0,0 +1,30 @@
+import unittest
+import pandas as pd
+import numpy as np
+import os
+
+from syndat import plot_shap_discrimination
+
+
+class TestPlotShapDiscrimination(unittest.TestCase):
+
+    def setUp(self):
+        # Create sample data for testing
+        self.real = pd.DataFrame(np.random.normal(size=(100, 5)), columns=[f"feature_{i}" for i in range(5)])
+        self.synthetic = pd.DataFrame(np.random.normal(size=(100, 5)), columns=[f"feature_{i}" for i in range(5)])
+
+        # Define the path where the plot will be temporarily saved
+        self.save_path = "test_shap_plot.png"
+
+    def test_plot_shap_discrimination(self):
+        # Call the function with test data and save_path
+        plot_shap_discrimination(self.real, self.synthetic, save_path=self.save_path)
+
+        # Check if the plot file was created
+        self.assertTrue(os.path.exists(self.save_path), "SHAP plot file was not created.")
+
+    def tearDown(self):
+        # Remove the file if it exists after the test
+        if os.path.exists(self.save_path):
+            os.remove(self.save_path)
+