Skip to content

Commit

Permalink
[Feature Store] Update return_df behaviour in ingest (mlrun#2284)
Browse files Browse the repository at this point in the history
  • Loading branch information
yonishelach authored Aug 23, 2022
1 parent a077155 commit 9f8cb94
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
20 changes: 12 additions & 8 deletions mlrun/feature_store/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import copy
from datetime import datetime
from typing import List, Union
from typing import List, Optional, Union
from urllib.parse import urlparse

import pandas as pd
Expand Down Expand Up @@ -304,7 +304,7 @@ def ingest(
mlrun_context=None,
spark_context=None,
overwrite=None,
) -> pd.DataFrame:
) -> Optional[pd.DataFrame]:
"""Read local DataFrame, file, URL, or source into the feature store
Ingest reads from the source, run the graph transformations, infers metadata and stats
and writes the results to the default of specified targets
Expand Down Expand Up @@ -348,7 +348,7 @@ def ingest(
:param overwrite: delete the targets' data prior to ingestion
(default: True for non scheduled ingest - deletes the targets that are about to be ingested.
False for scheduled ingest - does not delete the target)
:return: if return_df is True, a dataframe will be returned based on the graph
"""
if isinstance(source, pd.DataFrame):
source = _rename_source_dataframe_columns(source)
Expand Down Expand Up @@ -509,6 +509,7 @@ def ingest(
mlrun_context=mlrun_context,
namespace=namespace,
overwrite=overwrite,
return_df=return_df,
)

if isinstance(source, str):
Expand All @@ -527,15 +528,16 @@ def ingest(
infer_stats = InferOptions.get_common_options(
infer_options, InferOptions.all_stats()
)
return_df = return_df or infer_stats != InferOptions.Null
# Check if dataframe is already calculated (for feature set graph):
calculate_df = return_df or infer_stats != InferOptions.Null
featureset.save()

df = init_featureset_graph(
source,
featureset,
namespace,
targets=targets_to_ingest,
return_df=return_df,
return_df=calculate_df,
)
if not InferOptions.get_common_options(
infer_stats, InferOptions.Index
Expand All @@ -556,8 +558,8 @@ def ingest(
target.last_written = source.start_time

_post_ingestion(mlrun_context, featureset, spark_context)

return df
if return_df:
return df


def preview(
Expand Down Expand Up @@ -754,6 +756,7 @@ def _ingest_with_spark(
mlrun_context=None,
namespace=None,
overwrite=None,
return_df=None,
):
created_spark_context = False
try:
Expand Down Expand Up @@ -867,7 +870,8 @@ def _ingest_with_spark(
spark.stop()
# We shouldn't return a dataframe that depends on a stopped context
df = None
return df
if return_df:
return df


def _post_ingestion(context, featureset, spark=None):
Expand Down
29 changes: 29 additions & 0 deletions tests/feature-store/test_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,32 @@ def test_set_targets_with_string():
assert nosql_target.name == "nosql"
assert nosql_target.kind == "nosql"
assert not nosql_target.partitioned


def test_return_df(rundb_mock):
    """Verify fs.ingest returns a DataFrame by default and None when return_df=False."""
    feature_set = fs.FeatureSet("myset", entities=[fs.Entity("ticker")])

    source_df = pd.DataFrame(
        {
            "ticker": ["GOOG", "MSFT"],
            "bid (accepted)": [720.50, 51.95],
            "ask": [720.93, 51.96],
            "with space": [True, False],
        }
    )

    # Stub out all persistence hooks so ingest runs without a real backend.
    feature_set._run_db = rundb_mock
    for stubbed in ("reload", "save", "purge_targets"):
        setattr(feature_set, stubbed, unittest.mock.Mock())

    # With return_df explicitly disabled, ingest should produce nothing.
    no_result = fs.ingest(feature_set, source_df, targets=[DFTarget()], return_df=False)
    assert no_result is None

    # With the default (return_df=True), ingest should hand back a DataFrame.
    default_result = fs.ingest(feature_set, source_df, targets=[DFTarget()])
    assert isinstance(default_result, pd.DataFrame)

0 comments on commit 9f8cb94

Please sign in to comment.