Skip to content

Commit

Permalink
feat: add visualization & postprocessing functions
Browse files Browse the repository at this point in the history
  • Loading branch information
tiadams committed Sep 27, 2024
1 parent 5b33fab commit 349ff07
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 0 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
'scikit-learn~=1.5',
'matplotlib~=3.8',
'seaborn~=0.13',
'shap~=0.42.0',
'setuptools>=70.0.0'
],
classifiers=[
Expand Down
108 changes: 108 additions & 0 deletions syndat/postprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns

def normalize_scale(real_df: pd.DataFrame, synthetic_df: pd.DataFrame) -> pd.DataFrame:
    """
    Scales the float columns of the synthetic DataFrame to the min/max range of the corresponding columns in the real DataFrame.

    A column is rescaled only if it is of floating-point type in the synthetic data and exists as a numeric
    (non-boolean) column in the real data. Columns that cannot be rescaled meaningfully are returned unchanged:
    columns missing from the real data, constant synthetic columns (zero range), and columns whose min/max
    is NaN (e.g. all values missing).

    Parameters:
    real_df (pd.DataFrame): The real dataset used as the scaling reference.
    synthetic_df (pd.DataFrame): The synthetic dataset to be scaled. It is not modified.

    Returns:
    pd.DataFrame: A copy of the synthetic dataset with eligible columns adjusted to the real dataset's scale.
    """
    # Work on a copy so the caller's DataFrame is left untouched
    scaled_synthetic_df = synthetic_df.copy()

    for column in synthetic_df.columns:
        # Nothing to scale against if the real data lacks this column
        if column not in real_df.columns:
            continue

        synthetic_col = synthetic_df[column]
        real_col = real_df[column]

        # Only float synthetic columns are rescaled
        if not pd.api.types.is_float_dtype(synthetic_col.dtype):
            continue
        # The reference column must support min/max arithmetic (booleans do not support subtraction)
        if not pd.api.types.is_numeric_dtype(real_col.dtype) or pd.api.types.is_bool_dtype(real_col.dtype):
            continue

        real_min = real_col.min()
        real_max = real_col.max()
        synthetic_min = synthetic_col.min()
        synthetic_max = synthetic_col.max()

        real_range = real_max - real_min
        synthetic_range = synthetic_max - synthetic_min

        # A zero range would divide by zero and a NaN range would wipe the column,
        # so such columns are passed through as they are
        if pd.isna(real_range) or pd.isna(synthetic_range) or synthetic_range == 0:
            continue

        # Min-max rescale: map [synthetic_min, synthetic_max] onto [real_min, real_max]
        scaled_synthetic_df[column] = ((synthetic_col - synthetic_min) / synthetic_range) * real_range + real_min

    return scaled_synthetic_df

def assert_minmax(real: pd.DataFrame, synthetic: pd.DataFrame, method: str = 'clip') -> pd.DataFrame:
    """
    Postprocess the synthetic data by either deleting records that fall outside the min-max range of the real data,
    or adjusting them to fit within the range. Also normalizes -0.0 to 0.0 in float columns to avoid plotting issues.

    Only columns present in both datasets are checked. With method='delete', rows holding NaN in a checked
    column are removed as well, because NaN compares as outside any range.

    Parameters:
    real (pd.DataFrame): The real dataset.
    synthetic (pd.DataFrame): The synthetic dataset. It is not modified.
    method (str): The method to apply. 'delete' to remove records, 'clip' to adjust them.

    Returns:
    pd.DataFrame: The postprocessed synthetic dataset.

    Raises:
    ValueError: If method is neither 'clip' nor 'delete'.
    """
    if method not in ('clip', 'delete'):
        raise ValueError(f"Unknown method '{method}'; expected 'clip' or 'delete'.")

    # Work on a copy so the caller's DataFrame is left untouched
    result = synthetic.copy()

    # Normalize -0.0 to 0.0, restricted to float columns: -0.0 == 0 is True, so both zeros are
    # replaced by +0.0, while integer columns are never touched and therefore keep their dtype
    for column in result.select_dtypes(include='float').columns:
        values = result[column]
        result[column] = values.where(values != 0, 0.0)

    for column in result.columns:
        if column not in real.columns:
            continue

        # Reference range taken from the real data
        min_val = real[column].min()
        max_val = real[column].max()

        if method == 'delete':
            # Keep only the rows whose value lies within the min-max range
            in_range = (result[column] >= min_val) & (result[column] <= max_val)
            result = result[in_range]
        else:
            # Clip the values to be within the min-max range
            result[column] = result[column].clip(lower=min_val, upper=max_val)

    return result

def normalize_float_precision(real: pd.DataFrame, synthetic: pd.DataFrame) -> pd.DataFrame:
    """
    Postprocess the synthetic data to match the precision or step size found in the real data for float columns.

    This function identifies columns in the real dataset that have float data types and determines the precision
    or step size (e.g., 1.0, 0.5, 0.1) used in those columns, taken as the smallest positive gap between
    distinct values. It then rounds the corresponding numeric columns in the synthetic dataset to the nearest
    multiple of that step size. Columns for which no step can be derived (fewer than two distinct finite
    values in the real data) and non-numeric synthetic columns are returned unchanged.

    Parameters:
    real (pd.DataFrame): The real dataset containing float columns.
    synthetic (pd.DataFrame): The synthetic dataset that needs to be adjusted to match the precision of the real data.
        It is not modified.

    Returns:
    pd.DataFrame: A copy of the synthetic dataset with float columns rounded to match the precision or step size of the real data.
    """
    # Work on a copy so the caller's DataFrame is left untouched
    adjusted = synthetic.copy()

    # Select float columns from the real dataset
    float_columns = real.select_dtypes(include='float').columns

    for col in float_columns:
        if col not in adjusted.columns:
            continue
        # Rounding requires arithmetic, so non-numeric synthetic columns are skipped
        if not pd.api.types.is_numeric_dtype(adjusted[col].dtype):
            continue

        # Gaps between consecutive distinct real values, ignoring NaN
        unique_values = real[col].dropna().unique()
        unique_diffs = np.diff(np.sort(unique_values))

        # Keep only usable gaps; infinite values in the data would yield inf/NaN gaps
        usable_diffs = unique_diffs[np.isfinite(unique_diffs) & (unique_diffs > 0)]
        if usable_diffs.size == 0:
            continue

        # The smallest gap is taken as the step size
        step_size = usable_diffs.min()

        # Round the synthetic column to the nearest multiple of the step size
        adjusted[col] = np.round(adjusted[col] / step_size) * step_size

    return adjusted
139 changes: 139 additions & 0 deletions syndat/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import pandas as pd
import seaborn as sns
from pandas.plotting import table
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
import shap # Added


def plot_distributions(real: pandas.DataFrame, synthetic: pandas.DataFrame, store_destination: str) -> None:
Expand Down Expand Up @@ -65,6 +69,141 @@ def plot_correlations(real: pandas.DataFrame, synthetic: pandas.DataFrame, store
fig = ax.get_figure()
fig.savefig(store_destination + "/" + names[idx] + '.png', bbox_inches="tight")

def plot_shap_discrimination(real: pandas.DataFrame, synthetic: pandas.DataFrame) -> None:
    """
    Trains a Random Forest Classifier to discriminate between real and synthetic data and plots SHAP summary values.

    Real records are labelled 1 and synthetic records 0. The classifier is trained on a stratified 80/20 split,
    its AUC on the held-out part is printed, and the SHAP summary plot is drawn for the "real" class (label 1)
    on the held-out part. The input DataFrames are not modified.

    :param real: The real data
    :param synthetic: The synthetic data
    """
    # assign() returns labelled copies, so the caller's DataFrames do not gain a 'label' column
    labelled_real = real.assign(label=1)
    labelled_synthetic = synthetic.assign(label=0)

    # Combine datasets
    combined_data = pd.concat([labelled_real, labelled_synthetic])

    # Separate features and labels
    X = combined_data.drop('label', axis=1)
    y = combined_data['label']

    # Split data into training and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

    # Initialize and train the Random Forest Classifier
    rfc = RandomForestClassifier(n_estimators=100, random_state=42)
    rfc.fit(X_train, y_train)

    # Predict probabilities and calculate AUC score
    y_pred_proba = rfc.predict_proba(X_test)[:, 1]
    auc_score = roc_auc_score(y_test, y_pred_proba)
    print(f'AUC Score: {auc_score}')

    # Compute SHAP values
    explainer = shap.TreeExplainer(rfc)
    shap_values = explainer.shap_values(X_test)

    # The return shape depends on the installed shap version: older releases return a list with
    # one (samples, features) array per class, newer ones a single (samples, features, classes) array.
    # In both cases the values for class 1 (real data) are selected.
    if isinstance(shap_values, list):
        real_class_shap_values = shap_values[1]
    elif getattr(shap_values, 'ndim', 2) == 3:
        real_class_shap_values = shap_values[:, :, 1]
    else:
        real_class_shap_values = shap_values

    # Plot SHAP summary
    shap.summary_plot(real_class_shap_values, X_test)


def plot_categorical_feature(feature: str, real_data: pandas.DataFrame, synthetic_data: pandas.DataFrame) -> None:
    """
    Shows two side-by-side count plots of a categorical feature: real data on the left, synthetic data on the right.

    :param feature: Name of the column to plot
    :param real_data: The real data
    :param synthetic_data: The synthetic data
    """
    plt.figure(figsize=(14, 6))

    # One panel per dataset: (subplot position, data, bar colour, panel title)
    panels = (
        (1, real_data, 'blue', f'Real Data - {feature}'),
        (2, synthetic_data, 'orange', f'Synthetic Data - {feature}'),
    )

    for position, frame, colour, heading in panels:
        plt.subplot(1, 2, position)
        sns.countplot(x=feature, data=frame, color=colour)
        plt.xlabel(feature)
        plt.ylabel('Frequency')
        plt.title(heading)
        # Vertical tick labels keep long category names readable
        plt.xticks(rotation=90)

    plt.tight_layout()
    plt.show()


def plot_numerical_feature(feature: str, real_data: pandas.DataFrame, synthetic_data: pandas.DataFrame) -> None:
    """
    Shows violin plots of a numerical feature for the real and the synthetic dataset next to each other,
    together with a table of summary statistics (mean, median, standard deviation, min, max) for both.

    :param feature: Name of the column to plot
    :param real_data: The real data
    :param synthetic_data: The synthetic data
    """
    stat_names = ['Mean', 'Median', 'Std Dev', 'Min', 'Max']

    def summarize(frame):
        # Values are listed in the same order as stat_names
        values = frame[feature]
        return [values.mean(), values.median(), values.std(), values.min(), values.max()]

    # One row per statistic, one column per dataset
    stats_df = pd.DataFrame({
        'Statistic': stat_names,
        'Real Data': summarize(real_data),
        'Synthetic Data': summarize(synthetic_data)
    })

    plt.figure(figsize=(14, 8))

    # Both violins share the same x-axis limits so they can be compared directly
    lower_limit = min(real_data[feature].min(), synthetic_data[feature].min())
    upper_limit = max(real_data[feature].max(), synthetic_data[feature].max())

    # One panel per dataset: (subplot position, data, violin colour, panel title)
    panels = (
        (1, real_data, 'blue', f'Real Data - {feature}'),
        (2, synthetic_data, 'orange', f'Synthetic Data - {feature}'),
    )

    for position, frame, colour, heading in panels:
        plt.subplot(2, 2, position)
        sns.violinplot(x=frame[feature], color=colour)
        plt.xlabel(feature)
        plt.ylabel('Density')
        plt.title(heading)
        plt.xlim(lower_limit, upper_limit)

    # The statistics table is drawn in a bottom subplot whose axes are hidden
    plt.subplot(2, 1, 2)
    plt.axis('off')
    stats_table = plt.table(cellText=stats_df.values,
                            colLabels=stats_df.columns,
                            rowLabels=stats_df['Statistic'],
                            cellLoc='center',
                            loc='center',
                            bbox=[0.0, -0.5, 1.0, 0.4])
    # Use a fixed font size instead of matplotlib's automatic sizing
    stats_table.auto_set_font_size(False)
    stats_table.set_fontsize(10)
    stats_table.scale(1.2, 1.2)

    plt.tight_layout()
    plt.show()




0 comments on commit 349ff07

Please sign in to comment.