Skip to content

Commit

Permalink
Fix bug with smoothing
Browse files (browse the repository at this point in the history)
  • Loading branch information
ClaudioSalvatoreArcidiacono committed Jan 2, 2025
1 parent 39ed08c commit 1075b1e
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 80 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ classifiers = [
"Operating System :: OS Independent",
]

dependencies = ["narwhals", "pydantic", "scikit-learn"]
dependencies = ["narwhals>=1.20.1", "pydantic", "scikit-learn"]

[project.optional-dependencies]
dev = [
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jupyter_client==8.6.3
jupyter_core==5.7.2
matplotlib-inline==0.1.7
mypy-extensions==1.0.0
narwhals==1.15.2
narwhals==1.20.1
nest-asyncio==1.6.0
nodeenv==1.9.1
numpy==2.1.3
Expand Down
38 changes: 19 additions & 19 deletions sklearo/encoding/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def _calculate_target_statistic(
) -> dict:

if column in (
"category_count",
"sum_target",
"std_target",
"count_per_category",
"sum_target_per_category",
"std_target_per_category",
"smoothing",
"shrinkage",
"smoothed_target",
Expand All @@ -125,18 +125,20 @@ def _calculate_target_statistic(
else:
original_column_name = column

mean_target = x_y[target_col].mean()

x_y_grouped = x_y.group_by(column, drop_null_keys=True).agg(
category_count=nw.col(target_col).count(),
sum_target=nw.col(target_col).sum(),
count_per_category=nw.col(target_col).count(),
sum_target_per_category=nw.col(target_col).sum(),
**(
{"std_target": nw.col(target_col).std()}
{"var_target_per_category": nw.col(target_col).var()}
if self.smooth == "auto"
else {}
),
)
underrepresented_categories = x_y_grouped.filter(nw.col("category_count") == 1)[
column
].to_list()
underrepresented_categories = x_y_grouped.filter(
nw.col("count_per_category") == 1
)[column].to_list()
if underrepresented_categories:
if self.underrepresented_categories == "raise":
raise ValueError(
Expand All @@ -147,7 +149,7 @@ def _calculate_target_statistic(
)
else:
if self.fill_values_underrepresented == "mean":
fill_values_underrepresented = x_y[target_col].mean()
fill_values_underrepresented = mean_target
else:
fill_values_underrepresented = self.fill_values_underrepresented

Expand All @@ -162,25 +164,23 @@ def _calculate_target_statistic(
encoding_dict = {}

if self.smooth == "auto":
target_std = x_y[target_col].std()
var_target = x_y[target_col].var()
x_y_grouped = x_y_grouped.with_columns(
smoothing=nw.col("std_target") / target_std
smoothing=nw.col("var_target_per_category") / var_target
)
else:
x_y_grouped = x_y_grouped.with_columns(smoothing=nw.lit(self.smooth))

categories_encoding_as_list = (
x_y_grouped.with_columns(
shrinkage=nw.col("category_count")
/ (nw.col("category_count") + nw.col("smoothing"))
shrinkage=nw.col("count_per_category")
/ (nw.col("count_per_category") + nw.col("smoothing"))
)
.with_columns(
smoothed_target=nw.col("shrinkage")
* nw.col("sum_target")
/ nw.col("category_count")
+ (1 - nw.col("shrinkage"))
* nw.col("sum_target")
/ nw.col("category_count")
* nw.col("sum_target_per_category")
/ nw.col("count_per_category")
+ (1 - nw.col("shrinkage")) * mean_target
)
.select(column, "smoothed_target")
.rows()
Expand Down
185 changes: 126 additions & 59 deletions tests/encoding/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ def test_target_encoder_unseen_value_fill_unseen_multiclass(

np.testing.assert_allclose(
transformed["category_mean_target_class_1"].to_list(),
[0.4, 0.2, 0.3],
[0.379545, 0.214634, 0.3],
rtol=1e-5,
)
np.testing.assert_allclose(
transformed["category_mean_target_class_2"].to_list(),
[0.2, 0.4, 0.3],
[0.214634, 0.379545, 0.3],
rtol=1e-5,
)
np.testing.assert_allclose(
Expand All @@ -104,7 +104,17 @@ def test_target_encoder_fit_multiclass_non_int_target(
transformed_data = encoder.transform(binary_class_data[["target"]])
np.testing.assert_allclose(
transformed_data["target_mean_target_class_A"].to_list(),
[0.2, 0.5, 0.5, 0.2, 0.2, 0.5, 0.2, 0.2, 0.5],
[
0.218391,
0.458333,
0.458333,
0.218391,
0.218391,
0.458333,
0.218391,
0.218391,
0.458333,
],
rtol=1e-5,
)

Expand All @@ -128,14 +138,14 @@ def test_target_encoder_fit_binary_non_int_target(
np.testing.assert_allclose(
transformed_data["target"].to_list(),
[
0.333333,
0.333333,
0.666667,
0.380952,
0.380952,
0.619048,
0.5,
0.5,
0.333333,
0.666667,
0.666667,
0.380952,
0.619048,
0.619048,
0.5,
0.5,
],
Expand Down Expand Up @@ -166,15 +176,15 @@ def test_target_encoder_fit_binary_non_int_target_classes_1_and_2(
np.testing.assert_allclose(
transformed_data["category"].to_list(),
[
0.666667,
0.666667,
0.666667,
0.333333,
0.333333,
0.333333,
0.333333,
0.333333,
0.333333,
0.603175,
0.603175,
0.603175,
0.365079,
0.365079,
0.365079,
0.365079,
0.365079,
0.365079,
],
rtol=1e-5,
)
Expand Down Expand Up @@ -224,15 +234,15 @@ def test_target_encoder_transform_binary(self, binary_class_data, DataFrame):
transformed = encoder.transform(binary_class_data[["category"]])

expected_values = [
0.333333,
0.333333,
0.333333,
0.666667,
0.666667,
0.666667,
0.666667,
0.666667,
0.666667,
0.396825,
0.396825,
0.396825,
0.634921,
0.634921,
0.634921,
0.634921,
0.634921,
0.634921,
]
np.testing.assert_allclose(
transformed["category"].to_list(), expected_values, rtol=1e-5
Expand All @@ -246,16 +256,16 @@ def test_target_encoder_transform_regression(self, regression_data, DataFrame):
transformed = encoder.transform(regression_data[["category"]])

expected_values = [
250.0,
250.0,
250.0,
250.0,
2250.0,
2250.0,
2250.0,
2250.0,
2250.0,
2250.0,
250.549702,
250.549702,
250.549702,
250.549702,
2082.60129,
2082.60129,
2082.60129,
2082.60129,
2082.60129,
2082.60129,
]
np.testing.assert_allclose(
transformed["category"].to_list(), expected_values, rtol=1e-5
Expand All @@ -277,14 +287,36 @@ def test_target_encoder_transform_multi_class(self, multi_class_data, DataFrame)
np.testing.assert_allclose(
transformed["category_mean_target_class_1"],
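# For class 1 the raw per-category rates are A: 2/5, B: 1/5; the expected
# values below are those rates shrunk toward the overall target mean by smoothing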
[0.4, 0.4, 0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.2, 0.2],
[
0.379545,
0.379545,
0.379545,
0.379545,
0.379545,
0.214634,
0.214634,
0.214634,
0.214634,
0.214634,
],
rtol=1e-5,
)

np.testing.assert_allclose(
transformed["category_mean_target_class_2"],
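# For class 2 the raw per-category rates are A: 1/5, B: 2/5; the expected
# values below are those rates shrunk toward the overall target mean by smoothing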
[0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.4, 0.4, 0.4, 0.4],
[
0.214634,
0.214634,
0.214634,
0.214634,
0.214634,
0.379545,
0.379545,
0.379545,
0.379545,
0.379545,
],
rtol=1e-5,
)

Expand Down Expand Up @@ -348,12 +380,12 @@ def test_target_encoder_handle_missing_values_binary(
0.555556,
0.0,
0.0,
0.666667,
0.666667,
0.666667,
0.666667,
0.666667,
0.666667,
0.634921,
0.634921,
0.634921,
0.634921,
0.634921,
0.634921,
],
rtol=1e-5,
)
Expand All @@ -373,17 +405,39 @@ def test_target_encoder_handle_missing_values_multi_class(
transformed = encoder.transform(multi_class_data[["category"]])
np.testing.assert_allclose(
transformed["category_mean_target_class_1"].to_list(),
[0.3, 0.25, 0.25, 0.25, 0.25, 0.2, 0.2, 0.2, 0.2, 0.2],
[
0.3,
0.260563,
0.260563,
0.260563,
0.260563,
0.214634,
0.214634,
0.214634,
0.214634,
0.214634,
],
rtol=1e-5,
)
np.testing.assert_allclose(
transformed["category_mean_target_class_2"].to_list(),
[0.3, 0.25, 0.25, 0.25, 0.25, 0.4, 0.4, 0.4, 0.4, 0.4],
[
0.3,
0.260563,
0.260563,
0.260563,
0.260563,
0.379545,
0.379545,
0.379545,
0.379545,
0.379545,
],
rtol=1e-5,
)
np.testing.assert_allclose(
transformed["category_mean_target_class_3"].to_list(),
[0.4, 0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.4, 0.4, 0.4],
[0.4, 0.47619, 0.47619, 0.47619, 0.47619, 0.4, 0.4, 0.4, 0.4, 0.4],
rtol=1e-5,
)

Expand All @@ -404,15 +458,15 @@ def test_target_encoder_unnderrepresented_categories_binary_fill_binary_set_valu
np.testing.assert_allclose(
transformed["category"].to_list(),
[
999,
0.0,
0.0,
0.666667,
0.666667,
0.666667,
0.666667,
0.666667,
0.666667,
9.990000e02,
0.000000e00,
0.000000e00,
6.349206e-01,
6.349206e-01,
6.349206e-01,
6.349206e-01,
6.349206e-01,
6.349206e-01,
],
rtol=1e-5,
)
Expand All @@ -427,7 +481,9 @@ def test_target_encoder_unseen_category_binary(self, binary_class_data, DataFram
transformed = encoder.transform(new_data)

np.testing.assert_allclose(
transformed["category"].to_list(), [0.3333333, 0.6666667, -999], rtol=1e-5
transformed["category"].to_list(),
[3.968254e-01, 6.349206e-01, -9.990000e02],
rtol=1e-5,
)

def test_target_encoder_unseen_category_binary_raise(
Expand Down Expand Up @@ -595,6 +651,17 @@ def test_target_encoder_explicitly_set_target_type(

np.testing.assert_allclose(
transformed_data["category"].to_list(),
[2.0, 2.0, 2.0, 2.0, 2.0, 2.2, 2.2, 2.2, 2.2, 2.2],
[
2.02069,
2.02069,
2.02069,
2.02069,
2.02069,
2.184559,
2.184559,
2.184559,
2.184559,
2.184559,
],
rtol=1e-5,
)

0 comments on commit 1075b1e

Please sign in to comment.