From 091c218ed7079b60607c12301549a6d8e05412cf Mon Sep 17 00:00:00 2001
From: Tim Adams <tim.adams@scai.fraunhofer.de>
Date: Mon, 16 Dec 2024 10:46:07 +0100
Subject: [PATCH] Fix: Handle constant columns

---
 syndat/scores.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/syndat/scores.py b/syndat/scores.py
index 2171640..1e6288e 100644
--- a/syndat/scores.py
+++ b/syndat/scores.py
@@ -203,6 +203,12 @@ def correlation(real: pd.DataFrame, synthetic: pd.DataFrame, score=True) -> floa
     # Compute numerical correlation only
     real_numerical = real_encoded.select_dtypes(include=[np.number])
     synthetic_numerical = synthetic_encoded.select_dtypes(include=[np.number])
+    # Drop columns that are constant in the real data (at most one distinct value) from both frames
+    constant_columns = real_numerical.columns[real_numerical.nunique() <= 1]
+    if len(constant_columns) > 0:
+        logger.warning(f'Removing constant columns {constant_columns} for correlation computation.')
+        real_numerical = real_numerical.drop(columns=constant_columns, errors="ignore")
+        synthetic_numerical = synthetic_numerical.drop(columns=constant_columns, errors="ignore")
     # Compute correlation matrices
     corr_real = real_numerical.corr()
     corr_synthetic = synthetic_numerical.corr()
@@ -214,8 +220,11 @@ def correlation(real: pd.DataFrame, synthetic: pd.DataFrame, score=True) -> floa
     if not corr_real.drop(columns=one_hot_encoded_columns).empty:
         corr_real = corr_real.drop(columns=one_hot_encoded_columns)
         corr_synthetic = corr_synthetic.drop(columns=one_hot_encoded_columns)
+    # ensure both matrices have the same columns in the same order (select by label, not by mask)
+    corr_synthetic = corr_synthetic[corr_real.columns]
     # now compute correlation matrices
-    norm_diff = np.linalg.norm(corr_real - corr_synthetic)
+    corr_diff = corr_real - corr_synthetic
+    norm_diff = np.linalg.norm(corr_diff)
     norm_real = np.linalg.norm(corr_real)
     norm_quotient = norm_diff / norm_real
     if score: