diff --git a/.gitignore b/.gitignore index 2766717..b75032c 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,4 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ .vscode/settings.json +.vscode/launch.json diff --git a/docs/API/utils/type_of_target.md b/docs/API/utils/type_of_target.md new file mode 100644 index 0000000..7798080 --- /dev/null +++ b/docs/API/utils/type_of_target.md @@ -0,0 +1 @@ +::: sklearo.utils.type_of_target diff --git a/mkdocs.yml b/mkdocs.yml index f4083c1..35980ac 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -41,9 +41,6 @@ plugins: show_root_full_path: true show_symbol_type_heading: true show_symbol_type_toc: true - # signature_crossrefs: true - # show_signature_annotations: true - # summary: true markdown_extensions: - pymdownx.highlight: diff --git a/sklearo/encoding/base.py b/sklearo/encoding/base.py index 297d152..50d6434 100644 --- a/sklearo/encoding/base.py +++ b/sklearo/encoding/base.py @@ -4,79 +4,28 @@ from narwhals.typing import IntoFrameT from sklearo.base import BaseTransformer -from sklearo.validation import check_if_fitted class BaseOneToOneEncoder(BaseTransformer): - @nw.narwhalify - @check_if_fitted - def transform(self, X: IntoFrameT) -> IntoFrameT: - """Transform the data. - Args: - X (DataFrame): The input data. - """ - X = self._handle_missing_values(X) - unseen_per_col = {} - for column, mapping in self.encoding_map_.items(): - uniques = X[column].unique() - unseen_cats = uniques.filter( - ( - ~uniques.is_in(next(iter(mapping.values())).keys()) - & ~uniques.is_null() - ) - ).to_list() - if unseen_cats: - unseen_per_col[column] = unseen_cats - - if unseen_per_col: - if self.unseen == "raise": + def _handle_missing_values(self, X: IntoFrameT) -> IntoFrameT: + if self.missing_values == "ignore": + return X + if self.missing_values == "raise": + if max(X[self.columns_].null_count().row(0)) > 0: raise ValueError( - f"Unseen categories {unseen_per_col} found during transform. " - "Please handle unseen categories for example by using a RareLabelEncoder. " - "Alternatively, set unseen to 'ignore'." + f"Some columns have missing values. " + "Please handle missing values before encoding or set " + "missing_values to either 'ignore' or 'encode'." ) - else: - warnings.warn( - f"Unseen categories {unseen_per_col} found during transform. " - "Please handle unseen categories for example by using a RareLabelEncoder. " - f"These categories will be encoded as {self.fill_value_unseen}." - ) - - X_out = X.with_columns( - nw.col(column) - .replace_strict( - { - **mapping, - **{cat: self.fill_value_unseen for cat in unseen_cats}, - } + return X + if self.missing_values == "encode": + # fillna does not work with categorical columns, so we use this + # workaround + return X.with_columns( + nw.when(nw.col(column).is_null()) + .then(nw.lit("MISSING")) + .otherwise(nw.col(column)) + .alias(column) + for column in self.columns_ ) - .alias( - f"{column}" - if self.is_binary_target_ - else f"{column}_WOE_class_{class_}" - ) - for column, classes_mapping in self.encoding_map_.items() - for class_, mapping in classes_mapping.items() - ) - - # In case of binary target, the original columns are replaced with the encoded columns. - # If it is not a binary target, the original columns need to be dropped before returning. - if not self.is_binary_target_: - X_out = X_out.drop(*self.columns_) - - return X_out - - @check_if_fitted - def get_feature_names_out(self) -> list[str]: - """Get the feature names after encoding.""" - if self.is_binary_target_: - return self.feature_names_in_ - else: - return [ - feat for feat in self.feature_names_in_ if feat not in self.columns_ - ] + [ - f"{column}_WOE_class_{class_}" - for column, classes_mapping in self.encoding_map_.items() - for class_ in classes_mapping - ] diff --git a/sklearo/encoding/target.py b/sklearo/encoding/target.py new file mode 100644 index 0000000..bc1862c --- /dev/null +++ b/sklearo/encoding/target.py @@ -0,0 +1,236 @@ +import math +import warnings +from collections import defaultdict +from typing import Any, Literal, Sequence + +import narwhals as nw +from narwhals.typing import IntoFrameT, IntoSeriesT +from pydantic import validate_call + +from sklearo.encoding.base import BaseOneToOneEncoder +from sklearo.utils import infer_type_of_target, select_columns +from sklearo.validation import check_if_fitted, check_X_y + + +class TargetEncoder(BaseOneToOneEncoder): + + @validate_call(config=dict(arbitrary_types_allowed=True)) + def __init__( + self, + columns: Sequence[nw.typing.DTypes | str] | str = ( + nw.Categorical, + nw.String, + ), + underrepresented_categories: Literal["raise", "fill"] = "raise", + fill_values_underrepresented: Sequence[int | float | None] = ( + -999.0, + 999.0, + ), + unseen: Literal["raise", "ignore"] = "raise", + fill_value_unseen: int | float | None | Literal["mean"] = "mean", + missing_values: Literal["encode", "ignore", "raise"] = "encode", + type_of_target: Literal["auto", "binary", "multiclass", "continuous"] = "auto", + ) -> None: + self.columns = columns + self.underrepresented_categories = underrepresented_categories + self.missing_values = missing_values + self.fill_values_underrepresented = fill_values_underrepresented or (None, None) + self.unseen = unseen + self.fill_value_unseen = fill_value_unseen + self.type_of_target = type_of_target + + def _calculate_mean_target( + self, x_y: IntoFrameT, target_cols: Sequence[str], column: str + ) -> dict: + mean_target_all_categories = ( + x_y.group_by(column) + .agg(nw.col(target_col).mean() for target_col in target_cols) + .rows(named=True) + ) + + if len(target_cols) == 1: + mean_target = {} + [target_column_name] = target_cols + for mean_target_per_category in mean_target_all_categories: + mean_target[mean_target_per_category[column]] = ( + mean_target_per_category[target_column_name] + ) + else: + mean_target = defaultdict(dict) + for target_column in target_cols: + class_ = target_column.split("_")[-1] + for mean_target_per_category in mean_target_all_categories: + mean_target[class_][mean_target_per_category[column]] = ( + mean_target_per_category[target_column] + ) + mean_target = dict(mean_target) + + return mean_target + + @nw.narwhalify + @check_X_y + def fit(self, X: IntoFrameT, y: IntoSeriesT) -> "TargetEncoder": + """Fit the encoder. + + Args: + X (DataFrame): The input data. + y (Series): The target variable. + """ + + self.columns_ = list(select_columns(X, self.columns)) + if not self.columns_: + return self + + X = self._handle_missing_values(X) + + if self.type_of_target == "auto": + self.type_of_target_ = infer_type_of_target(y) + else: + self.type_of_target_ = self.type_of_target + + if self.type_of_target_ == "binary": + unique_classes = sorted(y.unique().to_list()) + try: + greatest_class_as_int = int(unique_classes[1]) + except ValueError: + self.is_zero_one_target_ = False + else: + if greatest_class_as_int == 1: + self.is_zero_one_target_ = True + else: + self.is_zero_one_target_ = False + + if not self.is_zero_one_target_: + y = y.replace_strict({unique_classes[0]: 0, unique_classes[1]: 1}) + + else: + self.is_zero_one_target_ = False + + X = X[self.columns_] + + if "target" in X.columns: + target_col_name = "__target__" + + else: + target_col_name = "target" + + X_y = X.with_columns(**{target_col_name: y}) + + if self.type_of_target_ == "multiclass": + unique_classes = y.unique().sort().to_list() + self.unique_classes_ = unique_classes + + X_y = X_y.with_columns( + nw.when(nw.col(target_col_name) == class_) + .then(1) + .otherwise(0) + .alias(f"{target_col_name}_is_class_{class_}") + for class_ in unique_classes + ) + target_cols = [ + f"{target_col_name}_is_class_{class_}" for class_ in unique_classes + ] + + if self.unseen == "fill" and self.fill_value_unseen == "mean": + mean_targets = [X_y[target_cols].mean().rows(named=True)] + mean_target_per_class = {} + for target_col, class_ in zip(target_cols, unique_classes): + mean_target_per_class[class_] = mean_targets[target_col] + self.mean_target_ = mean_target_per_class + + else: + target_cols = [target_col_name] + if self.unseen == "fill" and self.fill_value_unseen == "mean": + self.mean_target_ = X_y[target_col_name].mean() + + self.encoding_map_ = {} + for column in self.columns_: + self.encoding_map_[column] = self._calculate_mean_target( + X_y[target_cols + [column]], target_cols=target_cols, column=column + ) + + self.feature_names_in_ = list(X.columns) + return self + + def _transform_binary_continuous( + self, X: nw.DataFrame, unseen_per_col: dict + ) -> IntoFrameT: + fill_value_unseen = ( + self.fill_value_unseen + if self.fill_value_unseen != "mean" + else self.mean_target_ + ) + return X.with_columns( + nw.col(column).replace_strict( + { + **mapping, + **{ + cat: fill_value_unseen for cat in unseen_per_col.get(column, []) + }, + } + ) + for column, mapping in self.encoding_map_.items() + ) + + def _transform_multiclass( + self, X: nw.DataFrame, unseen_per_col: dict + ) -> IntoFrameT: + fill_value_unseen = ( + {class_: self.fill_value_unseen for class_ in self.unique_classes_} + if self.fill_value_unseen != "mean" + else self.mean_target_ + ) + return X.with_columns( + nw.col(column).replace_strict( + { + **mapping, + **{ + cat: fill_value_unseen[class_] + for cat in unseen_per_col.get(column, []) + }, + } + ) + for column, class_mapping in self.encoding_map_.items() + for class_, mapping in class_mapping.items() + ) + + @nw.narwhalify + @check_if_fitted + def transform(self, X: IntoFrameT) -> IntoFrameT: + """Transform the data. + + Args: + X (DataFrame): The input data. + """ + X = self._handle_missing_values(X) + unseen_per_col = {} + for column, mapping in self.encoding_map_.items(): + uniques = X[column].unique() + unseen_cats = uniques.filter( + ( + ~uniques.is_in(next(iter(mapping.values())).keys()) + & ~uniques.is_null() + ) + ).to_list() + if unseen_cats: + unseen_per_col[column] = unseen_cats + + if unseen_per_col: + if self.unseen == "raise": + raise ValueError( + f"Unseen categories {unseen_per_col} found during transform. " + "Please handle unseen categories for example by using a RareLabelEncoder. " + "Alternatively, set unseen to 'ignore'." + ) + else: + warnings.warn( + f"Unseen categories {unseen_per_col} found during transform. " + "Please handle unseen categories for example by using a RareLabelEncoder. " + f"These categories will be encoded as {self.fill_value_unseen}." + ) + + if self.type_of_target_ in ("binary", "continuous"): + return self._transform_binary_continuous(X, unseen_per_col) + + else: # multiclass + return self._transform_multiclass(X, unseen_per_col) diff --git a/sklearo/encoding/woe.py b/sklearo/encoding/woe.py index 05bbb41..d16f6a4 100644 --- a/sklearo/encoding/woe.py +++ b/sklearo/encoding/woe.py @@ -9,7 +9,7 @@ from sklearo.encoding.base import BaseOneToOneEncoder from sklearo.utils import select_columns -from sklearo.validation import check_X_y +from sklearo.validation import check_if_fitted, check_type_of_target, check_X_y class WOEEncoder(BaseOneToOneEncoder): @@ -96,7 +96,7 @@ class WOEEncoder(BaseOneToOneEncoder): columns_ (list[str]): List of columns to be encoded, learned during fit. encoding_map_ (dict[str, dict[str, float]]): Nested dictionary mapping columns to their WOE values for each class, learned during fit. - is_binary_target_ (bool): Whether the target variable is binary (exactly 0 or 1) or not, + is_zero_one_target_ (bool): Whether the target variable is exactly 0 or 1 or not, learned during fit. feature_names_in_ (list[str]): List of feature names seen during fit. @@ -149,28 +149,6 @@ def __init__( self.unseen = unseen self.fill_value_unseen = fill_value_unseen - def _handle_missing_values(self, X: IntoFrameT) -> IntoFrameT: - if self.missing_values == "ignore": - return X - if self.missing_values == "raise": - if max(X[self.columns_].null_count().row(0)) > 0: - raise ValueError( - f"Some columns have missing values. " - "Please handle missing values before encoding or set " - "missing_values to either 'ignore' or 'encode'." - ) - return X - if self.missing_values == "encode": - # fillna does not work with categorical columns, so we use this - # workaround - return X.with_columns( - nw.when(nw.col(column).is_null()) - .then(nw.lit("MISSING")) - .otherwise(nw.col(column)) - .alias(column) - for column in self.columns_ - ) - def _calculate_woe( self, x: IntoSeriesT, y: IntoSeriesT, unique_classes: list[Any] ) -> dict[str, dict[str, float | int | None]]: @@ -310,6 +288,7 @@ def _calculate_woe( @nw.narwhalify @check_X_y + @check_type_of_target("binary", "multiclass") def fit(self, X: IntoFrameT, y: IntoSeriesT) -> "WOEEncoder": """Fit the encoder. @@ -318,35 +297,111 @@ def fit(self, X: IntoFrameT, y: IntoSeriesT) -> "WOEEncoder": y (Series): The target variable. """ - self.columns_ = list(select_columns(X, self.columns)) X = self._handle_missing_values(X) + + self.feature_names_in_ = list(X.columns) + self.columns_ = list(select_columns(X, self.columns)) self.encoding_map_ = {} + self.is_zero_one_target_ = False + unique_classes = sorted(y.unique().to_list()) + self.unqiue_classes_ = unique_classes if not self.columns_: return self - unique_classes = sorted(y.unique().to_list()) - self.unqiue_classes_ = unique_classes - if len(unique_classes) == 2: unique_classes = [unique_classes[1]] try: greatest_class_as_int = int(unique_classes[0]) except ValueError: - self.is_binary_target_ = False + self.is_zero_one_target_ = False else: if greatest_class_as_int == 1: - self.is_binary_target_ = True + self.is_zero_one_target_ = True else: - self.is_binary_target_ = False + self.is_zero_one_target_ = False else: - self.is_binary_target_ = False + self.is_zero_one_target_ = False for column in self.columns_: self.encoding_map_[column] = self._calculate_woe( X[column], y, unique_classes ) - self.feature_names_in_ = list(X.columns) return self + + @nw.narwhalify + @check_if_fitted + def transform(self, X: IntoFrameT) -> IntoFrameT: + """Transform the data. + + Args: + X (DataFrame): The input data. + """ + X = self._handle_missing_values(X) + unseen_per_col = {} + + for column, mapping in self.encoding_map_.items(): + uniques = X[column].unique() + unseen_cats = uniques.filter( + ( + ~uniques.is_in(next(iter(mapping.values())).keys()) + & ~uniques.is_null() + ) + ).to_list() + if unseen_cats: + unseen_per_col[column] = unseen_cats + + if unseen_per_col: + if self.unseen == "raise": + raise ValueError( + f"Unseen categories {unseen_per_col} found during transform. " + "Please handle unseen categories for example by using a RareLabelEncoder. " + "Alternatively, set unseen to 'ignore'." + ) + else: + warnings.warn( + f"Unseen categories {unseen_per_col} found during transform. " + "Please handle unseen categories for example by using a RareLabelEncoder. " + f"These categories will be encoded as {self.fill_value_unseen}." + ) + + X_out = X.with_columns( + nw.col(column) + .replace_strict( + { + **mapping, + **{ + cat: self.fill_value_unseen + for cat in unseen_per_col.get(column, []) + }, + } + ) + .alias( + column if self.is_zero_one_target_ else f"{column}_WOE_class_{class_}" + ) + for column, classes_mapping in self.encoding_map_.items() + for class_, mapping in classes_mapping.items() + ) + + # In case of binary target, the original columns are replaced with the encoded columns. + # If it is not a binary target, the original columns need to be dropped before returning. + if not self.is_zero_one_target_: + X_out = X_out.drop(*self.columns_) + + return X_out + + @check_if_fitted + def get_feature_names_out(self) -> list[str]: + """Get the feature names after encoding.""" + if self.is_zero_one_target_: + return self.feature_names_in_ + else: + return [ + feat for feat in self.feature_names_in_ if feat not in self.columns_ + ] + [ + f"{column}_WOE_class_{class_}" + for column, classes_mapping in self.encoding_map_.items() + for class_ in classes_mapping + ] diff --git a/sklearo/utils.py b/sklearo/utils.py index 070cd8d..2f4ab37 100644 --- a/sklearo/utils.py +++ b/sklearo/utils.py @@ -3,22 +3,35 @@ from typing import Sequence import narwhals as nw -from narwhals.typing import IntoFrameT +from narwhals.typing import IntoSeriesT +INTEGER_DTYPES = [ + nw.Int8, + nw.Int16, + nw.Int32, + nw.Int64, + nw.UInt8, + nw.UInt16, + nw.UInt32, + nw.UInt64, +] -def select_columns_by_regex_pattern(df: IntoFrameT, pattern: str): +FLOAT_DTYPES = [nw.Float32, nw.Float64] + + +def select_columns_by_regex_pattern(df: nw.DataFrame, pattern: str): for column in df.columns: if re.search(pattern, column): yield column -def select_columns_by_types(df: IntoFrameT, dtypes: list[nw.dtypes.DType]): +def select_columns_by_types(df: nw.DataFrame, dtypes: list[nw.dtypes.DType]): for column, dtype in zip(df.schema.names(), df.schema.dtypes()): if dtype in dtypes: yield column -def select_columns(df: IntoFrameT, columns: Sequence[nw.typing.DTypes | str] | str): +def select_columns(df: nw.DataFrame, columns: Sequence[nw.typing.DTypes | str] | str): if isinstance(columns, str): yield from select_columns_by_regex_pattern(df, columns) @@ -29,3 +42,80 @@ def select_columns(df: IntoFrameT, columns: Sequence[nw.typing.DTypes | str] | s yield from columns else: raise ValueError("Invalid columns type") + + +@nw.narwhalify +def infer_type_of_target(y: IntoSeriesT) -> str: + """Infer the type of target variable based on the input series. + + This function determines the type of target variable based on the unique values and data type + of the input series. + + Args: + y (nw.Series): The target variable series. + + Returns: + str: The inferred type of target variable, which can be one of the following: + + - `"binary"`: Returned when the target variable contains exactly two unique values and + is of an integer, boolean, string or categorical data type or it's floating point with + no decimal digits (e.g. `[0.0, 1.0]`). + - `"multiclass"`: Returned when the target variable has more than two unique values and + is of an integer, boolean, string or categorical data type or it's floating point with + no decimal digits (e.g. `[0.0, 1.0, 2.0]`). In case of floating point data type, the + unique values should be consecutive integers. + - `"continuous"`: Returned when the target variable is of a floating-point data type and + contains at least one non-integer value or the unique values are not consecutive + integers. + - `"unknown"`: Returned when the input series is none of the above types. + + Examples: + >>> type_of_target(pd.Series([1, 2, 3]) + "multiclass" + >>> type_of_target(pd.Series([1, 2, 1]) + "binary" + >>> type_of_target(pd.Series([1, 2, 4]) + "multiclass" + >>> type_of_target(pd.Series(["a", "b", "c"]) + "multiclass" + >>> type_of_target(pd.Series(["a", "b", "a"]) + "binary" + >>> type_of_target(pd.Series([1.0, 2.0, 3.5]) + "continuous" + >>> type_of_target(pd.Series([1.0, 3.5, 3.5]) + "continuous" + >>> type_of_target(pd.Series([1.0, 2.0, 4.0]) + "continuous" + >>> type_of_target(pd.Series([1.0, 4.0, 4.0]) + "binary" + >>> type_of_target(pd.Series([1.0, 2.0, 3.0]) + "multiclass" + >>> type_of_target(pd.Series([1.0, 2.0, 1.0]) + "binary" + + """ + if y.dtype == nw.Boolean: + return "binary" + + if y.dtype in INTEGER_DTYPES or y.dtype in (nw.String, nw.Categorical): + if len(y.unique().to_list()) == 2: + return "binary" + else: + return "multiclass" + + if y.dtype in FLOAT_DTYPES: + if (y % 1 != 0).any(): + return "continuous" + + else: + unique = y.unique() + if len(unique.to_list()) == 2: + return "binary" + sorted_unique = unique.sort() + labels_diff = sorted_unique - sorted_unique.shift(1) + if labels_diff.max() == labels_diff.min() == 1.0: + return "multiclass" + else: + return "continuous" + + return "unknown" diff --git a/sklearo/validation.py b/sklearo/validation.py index 4eb0d2c..9db3eb8 100644 --- a/sklearo/validation.py +++ b/sklearo/validation.py @@ -1,5 +1,7 @@ from functools import wraps +from sklearo.utils import infer_type_of_target + def check_X_y(func): @wraps(func) @@ -32,3 +34,22 @@ def wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) return wrapper + + +def check_type_of_target(*allowed_types_of_target): + def decorator(func): + @wraps(func) + def wrapper(self, X, y, *args, **kwargs): + inferred_type_of_target = infer_type_of_target(y) + if inferred_type_of_target not in allowed_types_of_target: + raise ValueError( + f"{self.__class__.__name__} supports the following types of target: " + f"{allowed_types_of_target}, but the inferred type of target was " + f"type_of_target={inferred_type_of_target}. To know more on how target is " + "inferred please refer to the documentation of sklearo.utils.type_of_target." + ) + return func(self, X, y, *args, **kwargs) + + return wrapper + + return decorator diff --git a/tests/encoding/test_woe.py b/tests/encoding/test_woe.py index 111cd4f..e978f09 100644 --- a/tests/encoding/test_woe.py +++ b/tests/encoding/test_woe.py @@ -34,7 +34,7 @@ def test_woe_encoder_fit_binary(self, binary_class_data, DataFrame): assert encoder.columns_ == ["category"] assert "category" in encoder.encoding_map_ - assert encoder.is_binary_target_ is True + assert encoder.is_zero_one_target_ is True def test_woe_encoder_fit_multiclass_non_int_target( self, binary_class_data, DataFrame @@ -45,7 +45,7 @@ def test_woe_encoder_fit_multiclass_non_int_target( assert encoder.columns_ == ["target"] assert "target" in encoder.encoding_map_ - assert encoder.is_binary_target_ is False + assert encoder.is_zero_one_target_ is False transformed_data = encoder.transform(binary_class_data[["target"]]) np.testing.assert_allclose( @@ -71,7 +71,7 @@ def test_woe_encoder_fit_binary_non_int_target(self, multi_class_data, DataFrame assert encoder.columns_ == ["target"] assert "target" in encoder.encoding_map_ - assert encoder.is_binary_target_ is False + assert encoder.is_zero_one_target_ is False transformed_data = encoder.transform(multi_class_data[["target"]]) @@ -109,7 +109,7 @@ def test_woe_encoder_fit_binary_non_int_target_classes_1_and_2( assert encoder.columns_ == ["category"] assert "category" in encoder.encoding_map_ - assert encoder.is_binary_target_ is False + assert encoder.is_zero_one_target_ is False transformed_data = encoder.transform(binary_class_data[["category"]]) @@ -138,7 +138,7 @@ def test_woe_encoder_fit_with_target_in_X_binary( assert encoder.columns_ == ["category", "target"] assert "category" in encoder.encoding_map_ - assert encoder.is_binary_target_ is True + assert encoder.is_zero_one_target_ is True def test_woe_encoder_fit_with_target_in_X_multi_class( self, multi_class_data, DataFrame @@ -153,7 +153,7 @@ def test_woe_encoder_fit_with_target_in_X_multi_class( assert encoder.columns_ == ["category", "target"] assert "category" in encoder.encoding_map_ - assert encoder.is_binary_target_ is False + assert encoder.is_zero_one_target_ is False def test_woe_encoder_fit_with_target_in_X_multi_class_raise_underrepresented( self, multi_class_data, DataFrame @@ -181,7 +181,7 @@ def test_woe_encoder_fit_multi_class(self, multi_class_data, DataFrame): assert encoder.columns_ == ["category"] assert "category" in encoder.encoding_map_ - assert encoder.is_binary_target_ is False + assert encoder.is_zero_one_target_ is False def test_woe_encoder_transform_binary(self, binary_class_data, DataFrame): binary_class_data = DataFrame(binary_class_data) diff --git a/tests/test_utils.py b/tests/test_utils.py index 20c1864..2441633 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,6 +4,7 @@ import pytest from sklearo.utils import ( + infer_type_of_target, select_columns, select_columns_by_regex_pattern, select_columns_by_types, @@ -76,3 +77,32 @@ def test_select_columns_invalid_type(self, sample_data, DataFrame): df = DataFrame(sample_data) with pytest.raises(ValueError, match="Invalid columns type"): list(select_columns(df, [1, 2])) + + +@pytest.mark.parametrize("Series", [pd.Series, pl.Series], ids=["pandas", "polars"]) +class TestTypeOfTarget: + + @pytest.mark.parametrize( + "data, expected", + [ + ([1, 2, 3], "multiclass"), + ([1, 2, 1], "binary"), + ([1, 2, 4], "multiclass"), + (["a", "b", "c"], "multiclass"), + (["a", "b", "a"], "binary"), + ([1.0, 2.0, 3.5], "continuous"), + ([1.0, 2.0, 4.0], "continuous"), + ([1.0, 3.5, 3.5], "continuous"), + ([1.0, 2.0, 3.0], "multiclass"), + ([1.0, 2.0, 1.0], "binary"), + ([1.0, 4.0, 4.0], "binary"), + ], + ) + def test_type_of_target(self, Series, data, expected): + series = Series(data) + assert infer_type_of_target(series) == expected + + def test_type_of_target_unknown(self, Series): + data = [None, None, None] + series = Series(data) + assert infer_type_of_target(series) == "unknown"