Skip to content

Commit

Permalink
Merge pull request #19 from SCAI-BIO/fix/corr-nan
Browse files Browse the repository at this point in the history
Fix: add test
  • Loading branch information
tiadams authored Dec 16, 2024
2 parents daa3a47 + 294de30 commit ca27319
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
9 changes: 8 additions & 1 deletion syndat/scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ def correlation(real: pd.DataFrame, synthetic: pd.DataFrame, score=True) -> floa
# Compute numerical correlation only
real_numerical = real_encoded.select_dtypes(include=[np.number])
synthetic_numerical = synthetic_encoded.select_dtypes(include=[np.number])
# Remove constant columns (zero variance)
constant_columns = real_numerical.columns[real_numerical.nunique() <= 1]
if len(constant_columns) > 0:
logger.warning(f'Removing constant columns {constant_columns} for correlation computation.')
real_numerical = real_numerical.drop(columns=constant_columns, errors="ignore")
synthetic_numerical = synthetic_numerical.drop(columns=constant_columns, errors="ignore")
# Compute correlation matrices
corr_real = real_numerical.corr()
corr_synthetic = synthetic_numerical.corr()
Expand All @@ -215,7 +221,8 @@ def correlation(real: pd.DataFrame, synthetic: pd.DataFrame, score=True) -> floa
corr_real = corr_real.drop(columns=one_hot_encoded_columns)
corr_synthetic = corr_synthetic.drop(columns=one_hot_encoded_columns)
# now compute correlation matrices
norm_diff = np.linalg.norm(corr_real - corr_synthetic)
corr_diff = corr_real - corr_synthetic
norm_diff = np.linalg.norm(corr_diff)
norm_real = np.linalg.norm(corr_real)
norm_quotient = norm_diff / norm_real
if score:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,16 @@ def test_correlation_with_only_categorical(self):
# Adjust the expected result based on encoding and correlation of categorical data
self.assertLess(result, 100, "Correlation score should be less than 100 for datasets with different "
"categorical data")

def test_correlation_constant_column(self):
real_data = pd.DataFrame({
'A': [1, 1, 1, 1, 1],
'B': [2, 3, 4, 5, 6]
})
synthetic_data = pd.DataFrame({
'A': [1, 1, 1, 1, 1],
'B': [2, 3, 4, 5, 6]
})
result = correlation(real_data, synthetic_data)
# Depending on the implementation, the correlation might return NaN or handle it explicitly
self.assertFalse(pd.isna(result), "Correlation score should not result in NaN even with a constant column")

0 comments on commit ca27319

Please sign in to comment.