Commit 22dd5b3 (1 parent: 29f5418)
Showing 9 changed files with 388 additions and 1 deletion.
@@ -0,0 +1,25 @@
```make
env-create:
	pip install --upgrade pip
	if ! test -d venv; \
	then \
		echo creating virtual environment; \
		pip install --upgrade virtualenv; \
		python -m venv venv; \
	fi

env-install:
	pip install --upgrade pip
	if test -s requirements.txt; \
	then \
		echo Installing requirements from requirements.txt; \
		pip install -r requirements.txt ; \
		pip install -e . --no-deps ; \
	else \
		echo Installing requirements from pyproject.toml; \
		pip install -e '.[dev]'; \
		pip freeze --exclude-editable > requirements.txt; \
	fi

env-update:
	pip install -e '.[dev]'
	pip freeze --exclude-editable > requirements.txt
```
@@ -1,2 +1,2 @@
```diff
 # sklearo
-A versatile Python package featuring scikit-learn like transformers for feature preprocessing, compatible with all kind of dataframes thanks to narwals.
+A versatile Python package featuring scikit-learn like transformers for feature preprocessing, compatible with all kind of dataframes thanks to narwhals.
```
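The fixed description is also the package's pitch: the same transformer should accept whichever dataframe flavour narwhals understands. As a minimal sketch of that claim (the toy data mirrors the docstring example added later in this commit; nothing here is part of the diff itself), the same fit/transform calls can be run against both a pandas and a polars frame:

```python
# Minimal sketch: one encoder, two dataframe libraries.
# Assumes the WOEEncoder added later in this commit; the data is illustrative only.
import pandas as pd
import polars as pl
from sklearo.encoding import WOEEncoder

data = {
    "category": ["A", "B", "A", "C", "B", "C", "A", "B", "C"],
    "target": [0, 0, 1, 0, 1, 0, 1, 0, 1],
}

for df in (pd.DataFrame(data), pl.DataFrame(data)):
    encoder = WOEEncoder()
    # Fit on the categorical column(s) and the binary target.
    encoder.fit(df[["category"]], df["target"])
    # The transformed output comes back as the same frame type that was passed in.
    print(type(encoder.transform(df[["category"]])))
```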
@@ -0,0 +1,12 @@
````markdown
# Development Guide

## Installing dev dependencies

```bash
# Create a new virtual environment
python -m venv venv
# Activate the virtual environment
source venv/bin/activate
# Install the dependencies
make env-install
```
````
@@ -0,0 +1,51 @@
```toml
[project]
name = "sklearo"
description = "A versatile Python package featuring scikit-learn like transformers for feature preprocessing, compatible with all kind of dataframes thanks to narwhals."
version = "0.1.0"
keywords = ["feature preprocessing", "scikit-learn", "machine learning"]
authors = [
    { name = "Claudio Salvatore Arcidiacono", email = "[email protected]" },
]
readme = "README.md"
requires-python = ">=3.9"
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]

dependencies = ["narwhals", "pydantic"]

[project.optional-dependencies]
dev = ["black", "ruff", "pre-commit", "pytest", "polars"]
doc = ["mkdocs", "mkdocs-material", "mkdocstrings[python]", "mkdocs-jupyter"]
build = ["build", "twine"]

[build-system]
build-backend = "flit_core.buildapi"
requires = ["flit_core >=3.2,<4"]

[project.urls]
"Homepage" = "https://github.com/ClaudioSalvatoreArcidiacono/sklearo"
"Documentation" = "https://claudiosalvatorearcidiacono.github.io/sklearo/"
"Bug Tracker" = "https://github.com/ClaudioSalvatoreArcidiacono/sklearo/issues"

[tool.black]
line-length = 88
include = '\.pyi?$'
exclude = '''
/(
    \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | _build
  | buck-out
  | build
  | dist
)/
'''

[tool.flake8]
max-line-length = 88
```
@@ -0,0 +1,20 @@
```text
black==24.10.0
cfgv==3.4.0
click==8.1.7
distlib==0.3.9
filelock==3.16.1
identify==2.6.3
iniconfig==2.0.0
mypy-extensions==1.0.0
narwhals==1.15.2
nodeenv==1.9.1
packaging==24.2
pathspec==0.12.1
platformdirs==4.3.6
pluggy==1.5.0
polars==1.16.0
pre_commit==4.0.1
pytest==8.3.4
PyYAML==6.0.2
ruff==0.8.2
virtualenv==20.28.0
```
Empty file.
@@ -0,0 +1,3 @@
```python
from .woe import WOEEncoder

__all__ = ["WOEEncoder"]
```
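Assuming this `__init__.py` lives in `sklearo/encoding/`, as the docstring's import in the next file suggests, the re-export means downstream code can import the encoder from the subpackage rather than from the module:

```python
# Both paths resolve to the same class thanks to the re-export above.
from sklearo.encoding import WOEEncoder            # public path used in the examples
from sklearo.encoding.woe import WOEEncoder as _W  # module path, equivalent

assert WOEEncoder is _W
```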
@@ -0,0 +1,250 @@
````python
import narwhals as nw
from narwhals.typing import IntoFrameT, IntoSeriesT
import warnings
import math
from typing import Sequence, Literal, Optional

from pydantic import validate_arguments
from sklearo.utils import select_columns


class WOEEncoder:
    """Weight of Evidence (WOE) Encoder.

    This class provides functionality to encode categorical features using the Weight of
    Evidence (WOE) technique. WOE is commonly used in credit scoring and other binary
    classification problems to transform categorical variables into continuous variables.

    WOE is defined as the natural logarithm of the ratio of the distribution of goods (i.e. the
    negative class, 0) to the distribution of bads (i.e. the positive class, 1) for a
    given category.

    ```
    WOE = ln((% of goods) / (% of bads))
    ```

    The WOE value is positive if the category is more likely to be good (negative class) and
    negative if it is more likely to be bad (positive class). This means that the WOE should be
    inversely correlated to the target variable.

    The WOE encoding is useful for logistic regression and other linear models, as it transforms
    the categorical variables into continuous variables that can be used as input features.

    Args:
        columns (str, list[str], list[nw.typing.DTypes]): list of columns to encode.
            If a single string is passed instead, it is treated as a regular expression pattern
            to match column names. If a list of `narwhals.typing.DTypes` is passed, it will
            select all columns matching the specified dtype. Defaults to [narwhals.Categorical,
            narwhals.String].
        underrepresented_categories (str): Strategy to handle underrepresented categories.
            If 'raise', an error is raised when a category is missing one of the target classes.
            If 'fill', the missing categories are encoded using the fill_values_underrepresented
            values.
        fill_values_underrepresented (list[int, float]): Fill values to use for underrepresented
            categories. The first value is used when there are no goods and the second value
            when there are no bads. Only used when underrepresented_categories is set to 'fill'.
            Optional, defaults to (-999.0, 999.0).
        unseen (str): Strategy to handle unseen categories. If 'raise', an error is raised when
            unseen categories are found. If 'ignore', the unseen categories are encoded with the
            fill_value_unseen.
        fill_value_unseen (int, float): Fill value to use for unseen categories. Only used when
            unseen is set to 'ignore'. Optional, defaults to 0.0.
        missing_values (str): Strategy to handle missing values. If 'encode', missing values are
            initially encoded as 'MISSING' and the WOE is computed as if it were a regular
            category. If 'ignore', missing values are left as is. If 'raise', an error is raised
            when missing values are found.
        suffix (str): Suffix to append to the column names of the encoded columns. If an empty
            string is passed, the original column names are replaced. Optional, defaults to "".

    Attributes:
        columns_ (list): List of columns to be encoded, learned during fit.
        encoding_map_ (dict): Dictionary mapping columns to their WOE values, learned during fit.

    Examples:
        ```python
        import pandas as pd
        from sklearo.encoding import WOEEncoder

        data = {
            "category": ["A", "B", "A", "C", "B", "C", "A", "B", "C"],
            "target": [0, 0, 1, 0, 1, 0, 1, 0, 1],
        }
        df = pd.DataFrame(data)
        encoder = WOEEncoder()
        encoder.fit(df[["category"]], df["target"])
        encoded = encoder.transform(df[["category"]])
        print(encoded)

           category
        0 -0.693147
        1  0.693147
        2 -0.693147
        3  0.693147
        4  0.693147
        5  0.693147
        6 -0.693147
        7  0.693147
        8  0.693147
        ```
    """

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def __init__(
        self,
        columns: Sequence[nw.typing.DTypes | str] | str = (
            nw.Categorical,
            nw.String,
        ),
        underrepresented_categories: Literal["raise", "fill"] = "raise",
        fill_values_underrepresented: Sequence[int | float | None] | None = (
            -999.0,
            999.0,
        ),
        unseen: Literal["raise", "ignore"] = "raise",
        fill_value_unseen: int | float | None = 0.0,
        missing_values: Literal["encode", "ignore", "raise"] = "encode",
        suffix: str = "",
    ) -> None:
        self.columns = columns
        self.underrepresented_categories = underrepresented_categories
        self.missing_values = missing_values
        self.fill_values_underrepresented = fill_values_underrepresented or (None, None)
        self.unseen = unseen
        self.fill_value_unseen = fill_value_unseen
        self.suffix = suffix

    def _handle_missing_values(self, x: IntoSeriesT) -> IntoSeriesT:
        if self.missing_values == "ignore":
            return x
        if self.missing_values == "raise":
            if x.null_count() > 0:
                raise ValueError(
                    f"Column {x.name} has missing values. "
                    "Please handle missing values before encoding or set "
                    "missing_values to either 'ignore' or 'encode'."
                )
        if self.missing_values == "encode":
            return x.fill_null("MISSING")

    def _calculate_woe(
        self, x: IntoSeriesT, y: IntoSeriesT, total_goods: int, total_bads: int
    ) -> dict[str, dict[str, float | int | None]]:
        """Calculate the Weight of Evidence for a column."""

        categories_n_goods_n_bads_dist_ratio = (
            x.to_frame()
            .with_columns(y)
            .group_by(x.name)
            .agg(
                n_total=nw.col(y.name).count(),
                n_bads=nw.col(y.name).sum(),
            )
            .with_columns(n_goods=nw.col("n_total") - nw.col("n_bads"))
            .with_columns(
                perc_goods=nw.col("n_goods") / total_goods,
                perc_bads=nw.col("n_bads") / total_bads,
            )
            .with_columns(dist_ratio=nw.col("perc_bads") / nw.col("perc_goods"))
            .select(x.name, "n_goods", "n_bads", "dist_ratio")
            .rows()
        )
        categories, n_goods, n_bads, dist_ratios = zip(
            *categories_n_goods_n_bads_dist_ratio
        )

        total_goods = sum(n_goods)
        total_bads = sum(n_bads)

        if any(n_good == 0 for n_good in n_goods) or any(
            n_bad == 0 for n_bad in n_bads
        ):
            problematic_categories = [
                cat
                for cat, n_good, n_bad in zip(categories, n_goods, n_bads)
                if n_good == 0 or n_bad == 0
            ]
            msg = (
                f"The categories {problematic_categories} for the column {x.name} "
                "are missing one of the target classes. For WOE to be defined, all categories "
                "should have at least one observation of each target class. Please consider "
                "removing infrequent categories using a RareLabelEncoder"
            )
            if self.underrepresented_categories == "raise":
                raise ValueError(
                    msg + " or by setting underrepresented_categories to 'fill'."
                )
            else:  # fill
                warnings.warn(
                    msg + ". The infrequent categories will be encoded as "
                    f"{self.fill_values_underrepresented[0]} "
                    f"when there are no goods and with {self.fill_values_underrepresented[1]} when "
                    "there are no bads."
                )

        woes = []
        for dist_ratio, n_good, n_bad in zip(dist_ratios, n_goods, n_bads):
            if n_good == 0:
                # means there are only bads
                woes.append(self.fill_values_underrepresented[0])
            elif n_bad == 0:
                # means there are only goods
                woes.append(self.fill_values_underrepresented[1])
            else:
                woes.append(math.log(dist_ratio))

        return dict(zip(categories, woes))

    @nw.narwhalify
    def fit(self, X: IntoFrameT, y: IntoSeriesT) -> "WOEEncoder":
        """Fit the encoder."""

        self.columns_ = select_columns(X, self.columns)
        self.encoding_map_ = {}

        total_bads = y.sum()
        total_goods = y.count() - total_bads
        for column in self.columns_:
            self.encoding_map_[column] = self._calculate_woe(
                self._handle_missing_values(X[column]), y, total_goods, total_bads
            )
        return self

    @nw.narwhalify
    def transform(self, X: IntoFrameT) -> IntoFrameT:
        """Transform the data."""

        unseen_per_col = {}
        for column, mapping in self.encoding_map_.items():
            uniques = X[column].unique()
            unseen_cats = uniques.filter(~uniques.is_in(mapping.keys())).to_list()
            if unseen_cats:
                unseen_per_col[column] = unseen_cats

        if unseen_per_col:
            if self.unseen == "raise":
                raise ValueError(
                    f"Unseen categories {unseen_per_col} found during transform. "
                    "Please handle unseen categories for example by using a RareLabelEncoder. "
                    "Alternatively, set unseen to 'ignore'."
                )
            else:
                warnings.warn(
                    f"Unseen categories {unseen_per_col} found during transform. "
                    "Please handle unseen categories for example by using a RareLabelEncoder. "
                    f"These categories will be encoded as {self.fill_value_unseen}."
                )

        return X.with_columns(
            nw.col(column)
            .pipe(self._handle_missing_values)
            .replace_strict(
                {
                    **mapping,
                    **{cat: self.fill_value_unseen for cat in unseen_cats},
                }
            )
            .alias(f"{column}{self.suffix}")
            for column, mapping in self.encoding_map_.items()
        )
````
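The `columns` argument documented in the docstring accepts three forms: a regex string, a list of explicit column names, or a list of narwhals dtypes (the default). A short, hypothetical sketch of the regex and dtype forms follows; the column names and data are invented and are not part of this commit:

```python
# Hypothetical usage of the column-selection modes documented above.
import polars as pl
import narwhals as nw
from sklearo.encoding import WOEEncoder

df = pl.DataFrame(
    {
        "cat_channel": ["web", "app", "web", "app", "web", "app"],
        "cat_country": ["IT", "IT", "SE", "SE", "IT", "SE"],
        "amount": [10.0, 25.0, 7.5, 12.0, 3.0, 40.0],
        "target": [0, 1, 1, 0, 0, 1],
    }
)

# 1. Select the columns to encode with a regex pattern ...
encoder = WOEEncoder(columns=r"^cat_")
# 2. ... or by dtype (this mirrors the default, (nw.Categorical, nw.String)).
encoder = WOEEncoder(columns=[nw.String])

encoder.fit(df.drop("target"), df["target"])
encoded = encoder.transform(df.drop("target"))
# Only the cat_* columns are encoded; the numeric "amount" column passes through untouched.
print(encoded.columns)
```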
@@ -0,0 +1,26 @@
```python
import re

import narwhals as nw
from narwhals.typing import IntoFrameT


def select_columns_by_regex_pattern(df: IntoFrameT, pattern: str):
    for column in df.columns:
        if re.search(pattern, column):
            yield column


def select_columns_by_types(df: IntoFrameT, dtypes: list[nw.dtypes.DType]):
    for column, dtype in zip(df.schema.names(), df.schema.dtypes()):
        if dtype in dtypes:
            yield column


def select_columns(df: IntoFrameT, columns):
    if isinstance(columns, str):
        yield from select_columns_by_regex_pattern(df, columns)

    if isinstance(columns, (list, tuple)) and columns:
        # Check for plain column names first: issubclass() raises a TypeError when
        # it is handed a string instance instead of a dtype class.
        if isinstance(columns[0], str):
            yield from columns
        elif issubclass(columns[0], nw.dtypes.DType):
            yield from select_columns_by_types(df, columns)
        else:
            raise ValueError("Invalid columns type")
```
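These helpers are what `WOEEncoder.fit` uses to resolve its `columns` argument, and inside `fit` the frame has already been wrapped by `@nw.narwhalify`. A hypothetical direct call, wrapping a frame manually with `nw.from_native` and using invented column names, might look like this:

```python
# Hypothetical direct use of select_columns; frame and column names are invented.
import polars as pl
import narwhals as nw
from sklearo.utils import select_columns

nw_df = nw.from_native(pl.DataFrame({"city": ["Rome", "Oslo"], "size": [1, 2]}))

print(list(select_columns(nw_df, r"^ci")))       # regex pattern   -> ["city"]
print(list(select_columns(nw_df, [nw.String])))  # narwhals dtypes -> ["city"]
print(list(select_columns(nw_df, ["city"])))     # explicit names  -> ["city"]
```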