Skip to content

Commit

Permalink
[Feature Store] Fix Imputer's None types check (mlrun#2941)
Browse files Browse the repository at this point in the history
  • Loading branch information
yonishelach authored Jan 22, 2023
1 parent bbc7d14 commit 4f407d3
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
6 changes: 3 additions & 3 deletions mlrun/feature_store/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,12 @@ def __init__(
:param kwargs: optional kwargs (for storey)
"""
super().__init__(**kwargs)
self.mapping = mapping
self.mapping = mapping or {}
self.method = method
self.default_value = default_value

def _impute(self, feature: str, value):
if value is None:
def _impute(self, feature: str, value: Any):
if pd.isna(value):
return self.mapping.get(feature, self.default_value)
return value

Expand Down
36 changes: 36 additions & 0 deletions tests/feature-store/test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import time
import unittest.mock

import numpy as np
import pandas as pd
import pytest

Expand Down Expand Up @@ -616,6 +617,41 @@ def test_pandas_step_drop_feature(rundb_mock, entities, set_index_before):
)


@pytest.mark.parametrize("engine", ["storey", "pandas"])
def test_imputer_default_value(rundb_mock, engine):
data_with_nones = pd.DataFrame(
{
"id": [1, 2, 3, 4],
"height": [None, 160, pd.NA, np.nan],
"age": [20, pd.NaT, 19, 18],
}
)
# Building graph with Imputer:
feature_set = fstore.FeatureSet(
"fs-default-value",
entities=["id"],
description="feature set with nones",
engine=engine,
)
feature_set.graph.to(Imputer(default_value=1))

# Mocking
output_path = tempfile.TemporaryDirectory()
feature_set._run_db = rundb_mock
feature_set.reload = unittest.mock.Mock()
feature_set.save = unittest.mock.Mock()
feature_set.purge_targets = unittest.mock.Mock()

imputed_df = fstore.ingest(
featureset=feature_set,
source=data_with_nones,
targets=[ParquetTarget(path=f"{output_path.name}/temp.parquet")],
)

# Checking that the ingested dataframe is none-free:
assert not imputed_df.isnull().values.any()


def get_data(with_none=False):
names = ["A", "B", "C", "D", "E"]
ages = [33, 4, 76, 90, 24]
Expand Down

0 comments on commit 4f407d3

Please sign in to comment.