
Commit

[Spark] Add redis support (mlrun#2930)
* [Spark] Add redis support

* Test and bug fixes

* Address assafb comments
alxtkr77 authored Jan 22, 2023
1 parent 4f407d3 commit 4fbaa05
Showing 5 changed files with 89 additions and 9 deletions.
7 changes: 5 additions & 2 deletions mlrun/datastore/redis.py
@@ -147,7 +147,10 @@ def rm(self, key, recursive=False, maxdepth=None):
 
         if recursive:
             key += "*" if key.endswith("/") else "/*"
-            for key in self.redis.scan_iter(key):
-                self.redis.delete(key)
+            for k in self.redis.scan_iter(key):
+                self.redis.delete(k)
+            key = f"_spark:{key}"
+            for k in self.redis.scan_iter(key):
+                self.redis.delete(k)
         else:
             self.redis.delete(key)
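A note on the hunk above: the loop variable is renamed to k so that key still holds the original pattern when the second scan is built, and that second scan removes keys under a "_spark:" prefix, presumably the namespace used for entries written through the Spark connector. A minimal sketch of the two patterns a recursive rm() now deletes, with a made-up feature-set path and an assumed local Redis (requires the redis-py package):

    # Hypothetical illustration only; the path is invented and the "_spark:"
    # prefix is assumed to be where Spark-written entries live.
    import redis

    r = redis.Redis(host="localhost", port=6379)  # assumed local instance
    base = "projects/p1/featuresets/fs/"
    for pattern in (base + "*", f"_spark:{base}*"):
        for k in r.scan_iter(pattern):
            r.delete(k)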
32 changes: 28 additions & 4 deletions mlrun/datastore/targets.py
@@ -20,6 +20,7 @@
 from collections import Counter
 from copy import copy
 from typing import Any, Dict, List, Optional, Union
+from urllib.parse import urlparse
 
 import pandas as pd
 import sqlalchemy
@@ -684,7 +685,7 @@ def get_spark_options(self, key_column=None, timestamp_key=None, overwrite=True)
         # options used in spark.read.load(**options)
         raise NotImplementedError()
 
-    def prepare_spark_df(self, df):
+    def prepare_spark_df(self, df, key_columns):
         return df
 
     def get_dask_options(self):
@@ -998,7 +999,7 @@ def get_spark_options(self, key_column=None, timestamp_key=None, overwrite=True)
             "header": "true",
         }
 
-    def prepare_spark_df(self, df):
+    def prepare_spark_df(self, df, key_columns):
         import pyspark.sql.functions as funcs
 
         for col_name, col_type in df.dtypes:
@@ -1120,7 +1121,7 @@ def get_dask_options(self):
     def as_df(self, columns=None, df_module=None, **kwargs):
         raise NotImplementedError()
 
-    def prepare_spark_df(self, df):
+    def prepare_spark_df(self, df, key_columns):
         import pyspark.sql.functions as funcs
 
         for col_name, col_type in df.dtypes:
@@ -1174,7 +1175,7 @@ def get_table_object(self):
 
 class RedisNoSqlTarget(NoSqlBaseTarget):
     kind = TargetTypes.redisnosql
-    support_spark = False
+    support_spark = True
     writer_step_name = "RedisNoSqlTarget"
 
     def get_table_object(self):
@@ -1190,6 +1191,29 @@ def get_table_object(self):
             flush_interval_secs=mlrun.mlconf.feature_store.flush_interval,
         )
 
+    def get_spark_options(self, key_column=None, timestamp_key=None, overwrite=True):
+        parsed_url = urlparse(self.get_target_path())
+        if parsed_url.hostname is None:
+            self.path = mlrun.mlconf.redis.url
+            parsed_url = urlparse(self.get_target_path())
+
+        return {
+            "key.column": "_spark_object_name",
+            "table": "{" + store_path_to_spark(self.get_target_path()),
+            "format": "org.apache.spark.sql.redis",
+            "host": parsed_url.hostname,
+            "port": parsed_url.port if parsed_url.port else "6379",
+            "user": parsed_url.username,
+            "auth": parsed_url.password,
+        }
+
+    def prepare_spark_df(self, df, key_columns):
+        from pyspark.sql.functions import udf
+        from pyspark.sql.types import StringType
+
+        udf1 = udf(lambda x: x + "}:static", StringType())
+        return df.withColumn("_spark_object_name", udf1(key_columns[0]))
+
 
 class StreamTarget(BaseStoreTarget):
     kind = TargetTypes.stream
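To make the new RedisNoSqlTarget wiring concrete, here is a hedged walk-through of what get_spark_options() produces for an assumed redis:// target path and how prepare_spark_df() reshapes the key column. The URL, project, and feature-set names are invented, and the final key layout depends on how the spark-redis connector joins the table name and key column, which is an assumption here rather than something stated in the commit:

    # Illustration with made-up values (not part of the commit).
    from urllib.parse import urlparse

    path = "redis://default:secret@redis.example.local:6379/projects/p1/featuresets/fs"
    parsed = urlparse(path)

    spark_options = {
        "key.column": "_spark_object_name",
        "table": "{" + parsed.path,  # store_path_to_spark() keeps only the path for redis:// URLs (see the utils.py change below)
        "format": "org.apache.spark.sql.redis",
        "host": parsed.hostname,     # "redis.example.local"
        "port": parsed.port or "6379",
        "user": parsed.username,     # "default"
        "auth": parsed.password,     # "secret"
    }

    # prepare_spark_df() appends "}:static" to the first entity column, so a row
    # keyed by "305-90-1613" presumably ends up under a Redis key like
    # "{/projects/p1/featuresets/fs:305-90-1613}:static" once the connector joins
    # the table name and key column with ":".
    print(spark_options["table"] + ":" + "305-90-1613" + "}:static")

The leading "{" in the table name and the trailing "}" added by the UDF appear intended to produce the hash-tagged key format that the online Redis driver reads back, which is what the new system test below relies on.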
6 changes: 5 additions & 1 deletion mlrun/datastore/utils.py
@@ -16,7 +16,11 @@
 
 
 def store_path_to_spark(path):
-    if path.startswith("v3io:///"):
+    if path.startswith("redis://") or path.startswith("rediss://"):
+        url = urlparse(path)
+        if url.path:
+            path = url.path
+    elif path.startswith("v3io:///"):
         path = "v3io:" + path[len("v3io:/") :]
     elif path.startswith("s3://"):
         if path.startswith("s3:///"):
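A standalone sketch of the new redis branch in store_path_to_spark(), with made-up URLs; only this branch is reproduced, the v3io/s3 handling is unchanged and omitted:

    from urllib.parse import urlparse

    def redis_path_for_spark(path):  # hypothetical helper mirroring the new branch
        if path.startswith("redis://") or path.startswith("rediss://"):
            url = urlparse(path)
            if url.path:
                path = url.path
        return path

    assert redis_path_for_spark("redis://user:pw@host:6379/projects/p1/featuresets/fs") == "/projects/p1/featuresets/fs"
    assert redis_path_for_spark("rediss://host:6379") == "rediss://host:6379"  # no path component, left as-is

This trimmed path is what feeds the "table" option in get_spark_options() above, which is why the table name starts with the URL path rather than the full redis:// URL.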
2 changes: 1 addition & 1 deletion mlrun/feature_store/api.py
@@ -886,7 +886,7 @@ def _ingest_with_spark(
                 df_to_write = df_to_write.withColumn(
                     partition, op(timestamp_col)
                 )
-            df_to_write = target.prepare_spark_df(df_to_write)
+            df_to_write = target.prepare_spark_df(df_to_write, key_columns)
             if overwrite:
                 df_to_write.write.mode("overwrite").save(**spark_options)
             else:
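For context, a hedged end-to-end sketch of where the two calls in this hunk sit when writing to the new target. It assumes pyspark plus the spark-redis package on the Spark classpath, a reachable Redis server, and that RedisNoSqlTarget resolves a path argument passed to its constructor; in practice these calls are driven by fstore.ingest(), as the new system test below shows:

    # Sketch under the assumptions above; not the supported public API path.
    from pyspark.sql import SparkSession
    from mlrun.datastore.targets import RedisNoSqlTarget

    spark = SparkSession.builder.appName("redis-target-sketch").getOrCreate()
    df = spark.createDataFrame(
        [("305-90-1613", 220.0), ("123-45-6789", 180.0)], ["patient_id", "hr"]
    )

    target = RedisNoSqlTarget(path="redis://localhost:6379/projects/p1/featuresets/fs")
    spark_options = target.get_spark_options()
    df = target.prepare_spark_df(df, key_columns=["patient_id"])

    # "format" is part of spark_options, so DataFrameWriter.save() picks it up as a kwarg
    df.write.mode("overwrite").save(**spark_options)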
51 changes: 50 additions & 1 deletion tests/system/feature_store/test_spark_engine.py
@@ -28,7 +28,12 @@
 import mlrun.feature_store as fstore
 from mlrun import code_to_function, store_manager
 from mlrun.datastore.sources import CSVSource, ParquetSource
-from mlrun.datastore.targets import CSVTarget, NoSqlTarget, ParquetTarget
+from mlrun.datastore.targets import (
+    CSVTarget,
+    NoSqlTarget,
+    ParquetTarget,
+    RedisNoSqlTarget,
+)
 from mlrun.feature_store import FeatureSet
 from mlrun.feature_store.steps import (
     DateExtractor,
@@ -293,6 +298,50 @@ def test_ingest_to_csv(self):
             read_back_df_storey.sort_index(axis=1)
         )
 
+    def test_ingest_to_redis(self):
+        key = "patient_id"
+        name = "measurements_spark"
+
+        measurements = fstore.FeatureSet(
+            name,
+            entities=[fstore.Entity(key)],
+            timestamp_key="timestamp",
+            engine="spark",
+        )
+        source = ParquetSource("myparquet", path=self.get_remote_pq_source_path())
+        targets = [RedisNoSqlTarget()]
+        measurements.set_targets(targets, with_defaults=False)
+        fstore.ingest(
+            measurements,
+            source,
+            spark_context=self.spark_service,
+            run_config=fstore.RunConfig(False),
+            overwrite=True,
+        )
+        # read the dataframe back from redis
+        vector = fstore.FeatureVector("myvector", features=[f"{name}.*"])
+        with fstore.get_online_feature_service(vector) as svc:
+            resp = svc.get([{"patient_id": "305-90-1613"}])
+            assert resp == [
+                {
+                    "bad": 95,
+                    "department": "01e9fe31-76de-45f0-9aed-0f94cc97bca0",
+                    "room": 2,
+                    "hr": 220.0,
+                    "hr_is_error": False,
+                    "rr": 25,
+                    "rr_is_error": False,
+                    "spo2": 99,
+                    "spo2_is_error": False,
+                    "movements": 4.614601941071927,
+                    "movements_is_error": False,
+                    "turn_count": 0.3582583538239813,
+                    "turn_count_is_error": False,
+                    "is_in_bed": 1,
+                    "is_in_bed_is_error": False,
+                }
+            ]
+
     # tests that data is filtered by time in scheduled jobs
     @pytest.mark.parametrize("partitioned", [True, False])
     def test_schedule_on_filtered_by_time(self, partitioned):
