From 091c218ed7079b60607c12301549a6d8e05412cf Mon Sep 17 00:00:00 2001 From: Tim Adams <tim.adams@scai.fraunhofer.de> Date: Mon, 16 Dec 2024 10:46:07 +0100 Subject: [PATCH] Fix: Handle constant columns --- syndat/scores.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/syndat/scores.py b/syndat/scores.py index 2171640..1e6288e 100644 --- a/syndat/scores.py +++ b/syndat/scores.py @@ -203,6 +203,12 @@ def correlation(real: pd.DataFrame, synthetic: pd.DataFrame, score=True) -> floa # Compute numerical correlation only real_numerical = real_encoded.select_dtypes(include=[np.number]) synthetic_numerical = synthetic_encoded.select_dtypes(include=[np.number]) + # Remove constant columns (zero variance) + constant_columns = real_numerical.columns[real_numerical.nunique() <= 1] + if len(constant_columns) > 0: + logger.warning(f'Removing constant columns {constant_columns} for correlation computation.') + real_numerical = real_numerical.drop(columns=constant_columns, errors="ignore") + synthetic_numerical = synthetic_numerical.drop(columns=constant_columns, errors="ignore") # Compute correlation matrices corr_real = real_numerical.corr() corr_synthetic = synthetic_numerical.corr() @@ -214,8 +220,11 @@ def correlation(real: pd.DataFrame, synthetic: pd.DataFrame, score=True) -> floa if not corr_real.drop(columns=one_hot_encoded_columns).empty: corr_real = corr_real.drop(columns=one_hot_encoded_columns) corr_synthetic = corr_synthetic.drop(columns=one_hot_encoded_columns) + # assure both matrices have the same columns + corr_synthetic = corr_synthetic[corr_real] # now compute correlation matrices - norm_diff = np.linalg.norm(corr_real - corr_synthetic) + corr_diff = corr_real - corr_synthetic + norm_diff = np.linalg.norm(corr_diff) norm_real = np.linalg.norm(corr_real) norm_quotient = norm_diff / norm_real if score: