diff --git a/feature_engine/selection/smart_correlation_selection.py b/feature_engine/selection/smart_correlation_selection.py
index 8fd9afa60..baae9ef4e 100644
--- a/feature_engine/selection/smart_correlation_selection.py
+++ b/feature_engine/selection/smart_correlation_selection.py
@@ -29,6 +29,7 @@
     _check_contains_inf,
     _check_contains_na,
     check_X,
+    check_y,
 )
 from feature_engine.selection.base_selector import BaseSelector
@@ -64,6 +65,7 @@ class SmartCorrelatedSelection(BaseSelector):
     - Feature with the highest cardinality (greatest number of unique values).
     - Feature with the highest variance.
     - Feature with the highest importance according to an estimator.
+    - Feature with the highest correlation with the target variable.
 
     SmartCorrelatedSelection() returns a dataframe containing from each group of
     correlated features, the selected variable, plus all the features that were
@@ -99,8 +101,8 @@ class SmartCorrelatedSelection(BaseSelector):
     {missing_values}
 
     selection_method: str, default= "missing_values"
-        Takes the values "missing_values", "cardinality", "variance" and
-        "model_performance".
+        Takes the values "missing_values", "cardinality", "variance",
+        "model_performance", and "corr_with_target".
 
         **"missing_values"**: keeps the feature from the correlated group with the
         least missing observations.
@@ -115,6 +117,11 @@ class SmartCorrelatedSelection(BaseSelector):
         features in a correlated group and retains the feature with the highest
         importance.
 
+        **"corr_with_target"**: keeps the feature from the correlated group that has
+        the highest correlation with the target variable. The same correlation method
+        defined in the `method` parameter is used to calculate the correlation between
+        the features and the target.
+
     {estimator}
 
     {scoring}
@@ -229,10 +236,12 @@ def __init__(
             "cardinality",
             "variance",
             "model_performance",
+            "corr_with_target",
         ]:
             raise ValueError(
                 "selection_method takes only values 'missing_values', 'cardinality', "
-                f"'variance' or 'model_performance'. Got {selection_method} instead."
+                "'variance', 'model_performance' or 'corr_with_target'. "
+                f"Got {selection_method} instead."
             )
 
         if selection_method == "model_performance" and estimator is None:
@@ -271,7 +280,8 @@ def fit(self, X: pd.DataFrame, y: pd.Series = None):
             The training dataset.
 
         y: pandas series. Default = None
-            y is needed if selection_method == 'model_performance'.
+            y is needed if selection_method == 'model_performance' or
+            'corr_with_target'.
         """
 
         # check input dataframe
@@ -289,9 +299,11 @@ def fit(self, X: pd.DataFrame, y: pd.Series = None):
             _check_contains_na(X, self.variables_)
             _check_contains_inf(X, self.variables_)
 
-        if self.selection_method == "model_performance" and y is None:
+        if (
+            self.selection_method in ["model_performance", "corr_with_target"]
+        ) and y is None:
             raise ValueError(
-                "When `selection_method = 'model_performance'` y is needed to "
+                f"When `selection_method = '{self.selection_method}'` y is needed to "
                 "fit the transformer."
             )
 
@@ -314,6 +326,15 @@ def fit(self, X: pd.DataFrame, y: pd.Series = None):
                 .sort_values(ascending=False)
                 .index.to_list()
             )
+        elif self.selection_method == "corr_with_target":
+            y = check_y(y)
+            features = (
+                X[self.variables_]
+                .corrwith(y, method=self.method)
+                .abs()
+                .sort_values(ascending=False)
+                .index.to_list()
+            )
         else:
             features = sorted(self.variables_)
 
diff --git a/tests/test_selection/test_smart_correlation_selection.py b/tests/test_selection/test_smart_correlation_selection.py
index c7a3254e7..d40f7970b 100644
--- a/tests/test_selection/test_smart_correlation_selection.py
+++ b/tests/test_selection/test_smart_correlation_selection.py
@@ -117,7 +117,7 @@ def test_raises_error_when_threshold_not_permitted(_threshold):
 def test_raises_error_when_selection_method_not_permitted(_method):
     msg = (
         "selection_method takes only values 'missing_values', 'cardinality', "
-        f"'variance' or 'model_performance'. Got {_method} instead."
+        f"'variance', 'model_performance' or 'corr_with_target'. Got {_method} instead."
     )
     with pytest.raises(ValueError) as record:
         SmartCorrelatedSelection(selection_method=_method)
@@ -252,6 +252,21 @@ def test_error_if_select_model_performance_and_y_is_none(df_single):
     assert record.value.args[0] == msg
 
 
+def test_error_if_select_corr_with_target_and_y_is_none(df_single):
+    X, _ = df_single
+
+    transformer = SmartCorrelatedSelection(
+        selection_method="corr_with_target",
+    )
+    msg = (
+        "When `selection_method = 'corr_with_target'` y is needed to fit "
+        "the transformer."
+    )
+    with pytest.raises(ValueError) as record:
+        transformer.fit(X)
+    assert record.value.args[0] == msg
+
+
 def test_selection_method_variance(df_var_car):
     X = df_var_car
 
@@ -434,3 +449,27 @@ def test_smart_correlation_selection_with_groups(df_test_with_groups):
     X_tr = transformer.fit_transform(X, y)
 
     pd.testing.assert_frame_equal(X_tr_expected, X_tr)
+
+
+def test_corr_with_target_single_corr_group(df_single):
+    X, y = df_single
+
+    transformer = SmartCorrelatedSelection(
+        variables=None,
+        method="pearson",
+        threshold=0.8,
+        missing_values="raise",
+        selection_method="corr_with_target",
+    )
+
+    Xt = transformer.fit_transform(X, y)
+
+    # expected result
+    df = X[["var_0", "var_2", "var_3", "var_5"]].copy()
+
+    # test fit attrs
+    assert transformer.correlated_feature_sets_ == [{"var_1", "var_2", "var_4"}]
+    assert transformer.features_to_drop_ == ["var_4", "var_1"]
+    assert transformer.correlated_feature_dict_ == {"var_2": {"var_1", "var_4"}}
+    # test transform output
+    pd.testing.assert_frame_equal(Xt, df)
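
For reviewers, a minimal usage sketch of the new option once this patch is applied. The toy dataset and variable names below are illustrative, not one of the test fixtures:

```python
import pandas as pd
from sklearn.datasets import make_classification
from feature_engine.selection import SmartCorrelatedSelection

# Toy data with a few redundant (correlated) features.
X, y = make_classification(
    n_samples=1000,
    n_features=6,
    n_redundant=2,
    n_clusters_per_class=1,
    class_sep=2,
    random_state=42,
)
X = pd.DataFrame(X, columns=[f"var_{i}" for i in range(6)])
y = pd.Series(y)

# From each group of features correlated above the threshold, keep the one
# with the highest absolute correlation with the target; drop the rest.
sel = SmartCorrelatedSelection(
    method="pearson",
    threshold=0.8,
    selection_method="corr_with_target",
)
Xt = sel.fit_transform(X, y)  # y is required for this selection_method
print(sel.features_to_drop_)
```

Because `corrwith` reuses the `method` parameter, a rank-based choice such as `method="spearman"` is applied consistently to both the feature-feature and feature-target correlations.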