diff --git a/mlrun/datastore/redis.py b/mlrun/datastore/redis.py
index 2e8861035a..6546799bb0 100644
--- a/mlrun/datastore/redis.py
+++ b/mlrun/datastore/redis.py
@@ -147,7 +147,10 @@ def rm(self, key, recursive=False, maxdepth=None):
         if recursive:
             key += "*" if key.endswith("/") else "/*"
-            for key in self.redis.scan_iter(key):
-                self.redis.delete(key)
+            for k in self.redis.scan_iter(key):
+                self.redis.delete(k)
+            key = f"_spark:{key}"
+            for k in self.redis.scan_iter(key):
+                self.redis.delete(k)
         else:
             self.redis.delete(key)
 
diff --git a/mlrun/datastore/targets.py b/mlrun/datastore/targets.py
index 1020acf07e..6fa2693ebf 100644
--- a/mlrun/datastore/targets.py
+++ b/mlrun/datastore/targets.py
@@ -20,6 +20,7 @@
 from collections import Counter
 from copy import copy
 from typing import Any, Dict, List, Optional, Union
+from urllib.parse import urlparse
 
 import pandas as pd
 import sqlalchemy
@@ -684,7 +685,7 @@ def get_spark_options(self, key_column=None, timestamp_key=None, overwrite=True
         # options used in spark.read.load(**options)
         raise NotImplementedError()
 
-    def prepare_spark_df(self, df):
+    def prepare_spark_df(self, df, key_columns):
         return df
 
     def get_dask_options(self):
@@ -998,7 +999,7 @@ def get_spark_options(self, key_column=None, timestamp_key=None, overwrite=True
             "header": "true",
         }
 
-    def prepare_spark_df(self, df):
+    def prepare_spark_df(self, df, key_columns):
         import pyspark.sql.functions as funcs
 
         for col_name, col_type in df.dtypes:
@@ -1120,7 +1121,7 @@ def get_dask_options(self):
     def as_df(self, columns=None, df_module=None, **kwargs):
         raise NotImplementedError()
 
-    def prepare_spark_df(self, df):
+    def prepare_spark_df(self, df, key_columns):
         import pyspark.sql.functions as funcs
 
         for col_name, col_type in df.dtypes:
@@ -1174,7 +1175,7 @@ def get_table_object(self):
 
 class RedisNoSqlTarget(NoSqlBaseTarget):
     kind = TargetTypes.redisnosql
-    support_spark = False
+    support_spark = True
     writer_step_name = "RedisNoSqlTarget"
 
     def get_table_object(self):
@@ -1190,6 +1191,29 @@ def get_table_object(self):
             flush_interval_secs=mlrun.mlconf.feature_store.flush_interval,
         )
 
+    def get_spark_options(self, key_column=None, timestamp_key=None, overwrite=True):
+        parsed_url = urlparse(self.get_target_path())
+        if parsed_url.hostname is None:
+            self.path = mlrun.mlconf.redis.url
+            parsed_url = urlparse(self.get_target_path())
+
+        return {
+            "key.column": "_spark_object_name",
+            "table": "{" + store_path_to_spark(self.get_target_path()),
+            "format": "org.apache.spark.sql.redis",
+            "host": parsed_url.hostname,
+            "port": parsed_url.port if parsed_url.port else "6379",
+            "user": parsed_url.username,
+            "auth": parsed_url.password,
+        }
+
+    def prepare_spark_df(self, df, key_columns):
+        from pyspark.sql.functions import udf
+        from pyspark.sql.types import StringType
+
+        udf1 = udf(lambda x: x + "}:static", StringType())
+        return df.withColumn("_spark_object_name", udf1(key_columns[0]))
+
 
 class StreamTarget(BaseStoreTarget):
     kind = TargetTypes.stream
diff --git a/mlrun/datastore/utils.py b/mlrun/datastore/utils.py
index d417b492ed..9fd7be42fc 100644
--- a/mlrun/datastore/utils.py
+++ b/mlrun/datastore/utils.py
@@ -16,7 +16,11 @@
 
 
 def store_path_to_spark(path):
-    if path.startswith("v3io:///"):
+    if path.startswith("redis://") or path.startswith("rediss://"):
+        url = urlparse(path)
+        if url.path:
+            path = url.path
+    elif path.startswith("v3io:///"):
         path = "v3io:" + path[len("v3io:/") :]
     elif path.startswith("s3://"):
         if path.startswith("s3:///"):
diff --git a/mlrun/feature_store/api.py b/mlrun/feature_store/api.py
index 52caed9cb8..a97d9f4bd6 100644
--- a/mlrun/feature_store/api.py
+++ b/mlrun/feature_store/api.py
@@ -886,7 +886,7 @@ def _ingest_with_spark(
                         df_to_write = df_to_write.withColumn(
                             partition, op(timestamp_col)
                         )
-            df_to_write = target.prepare_spark_df(df_to_write)
+            df_to_write = target.prepare_spark_df(df_to_write, key_columns)
             if overwrite:
                 df_to_write.write.mode("overwrite").save(**spark_options)
             else:
diff --git a/tests/system/feature_store/test_spark_engine.py b/tests/system/feature_store/test_spark_engine.py
index b7a9917794..31ed1544c2 100644
--- a/tests/system/feature_store/test_spark_engine.py
+++ b/tests/system/feature_store/test_spark_engine.py
@@ -28,7 +28,12 @@
 import mlrun.feature_store as fstore
 from mlrun import code_to_function, store_manager
 from mlrun.datastore.sources import CSVSource, ParquetSource
-from mlrun.datastore.targets import CSVTarget, NoSqlTarget, ParquetTarget
+from mlrun.datastore.targets import (
+    CSVTarget,
+    NoSqlTarget,
+    ParquetTarget,
+    RedisNoSqlTarget,
+)
 from mlrun.feature_store import FeatureSet
 from mlrun.feature_store.steps import (
     DateExtractor,
@@ -293,6 +298,50 @@ def test_ingest_to_csv(self):
             read_back_df_storey.sort_index(axis=1)
         )
 
+    def test_ingest_to_redis(self):
+        key = "patient_id"
+        name = "measurements_spark"
+
+        measurements = fstore.FeatureSet(
+            name,
+            entities=[fstore.Entity(key)],
+            timestamp_key="timestamp",
+            engine="spark",
+        )
+        source = ParquetSource("myparquet", path=self.get_remote_pq_source_path())
+        targets = [RedisNoSqlTarget()]
+        measurements.set_targets(targets, with_defaults=False)
+        fstore.ingest(
+            measurements,
+            source,
+            spark_context=self.spark_service,
+            run_config=fstore.RunConfig(False),
+            overwrite=True,
+        )
+        # read the dataframe from the redis back
+        vector = fstore.FeatureVector("myvector", features=[f"{name}.*"])
+        with fstore.get_online_feature_service(vector) as svc:
+            resp = svc.get([{"patient_id": "305-90-1613"}])
+            assert resp == [
+                {
+                    "bad": 95,
+                    "department": "01e9fe31-76de-45f0-9aed-0f94cc97bca0",
+                    "room": 2,
+                    "hr": 220.0,
+                    "hr_is_error": False,
+                    "rr": 25,
+                    "rr_is_error": False,
+                    "spo2": 99,
+                    "spo2_is_error": False,
+                    "movements": 4.614601941071927,
+                    "movements_is_error": False,
+                    "turn_count": 0.3582583538239813,
+                    "turn_count_is_error": False,
+                    "is_in_bed": 1,
+                    "is_in_bed_is_error": False,
+                }
+            ]
+
     # tests that data is filtered by time in scheduled jobs
     @pytest.mark.parametrize("partitioned", [True, False])
     def test_schedule_on_filtered_by_time(self, partitioned):
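
Reviewer note (not part of the patch): below is a minimal, standalone Python sketch of what the new RedisNoSqlTarget Spark path produces. The redis URL in it is a made-up example; in MLRun the path comes from get_target_path(), and the patched store_path_to_spark() reduces a redis:// URL to its path component. As I read the spark-redis connector, each Redis key is composed from the "table" option and the value of the "key.column" column, so the leading "{" added to "table" and the trailing "}:static" appended by prepare_spark_df() appear to wrap <path><entity> in a Redis cluster hash tag with a ":static" suffix outside it.

# Illustrative sketch only -- mirrors RedisNoSqlTarget.get_spark_options() and
# prepare_spark_df() from the patch with plain Python, no Spark required.
# The target URL is a hypothetical example, not an MLRun default.
from urllib.parse import urlparse

target_path = "redis://user:secret@localhost:6379/projects/demo/sets/measurements"
parsed = urlparse(target_path)

spark_options = {
    # spark-redis reads the per-row key from this generated column
    "key.column": "_spark_object_name",
    # "{" + path component (what store_path_to_spark() returns for redis:// URLs)
    "table": "{" + parsed.path,
    "format": "org.apache.spark.sql.redis",
    "host": parsed.hostname,
    "port": parsed.port if parsed.port else "6379",
    "user": parsed.username,
    "auth": parsed.password,
}

# prepare_spark_df() fills _spark_object_name with '<entity value>}:static'
entity_value = "305-90-1613"  # value of the "patient_id" entity column
key_column_value = entity_value + "}:static"

print(spark_options["table"])  # {/projects/demo/sets/measurements
print(key_column_value)        # 305-90-1613}:static

The extra _spark:{key}* scan added to rm() presumably cleans up the bookkeeping/schema keys that the spark-redis connector keeps under a _spark: prefix, so removing a target also removes its Spark-written metadata.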