[FeatureStore] Support SQL DBs as source and online target (mlrun#2869)
* init sql

* lint

* time_fields

* .

* .

* lint

* test with mysql

* lint

* .

* .

* review

* env var

* .

* mlconf

* lint

* support pandas

* .
davesh0812 authored Jan 19, 2023
1 parent e828af1 commit 4311f75
Showing 10 changed files with 881 additions and 18 deletions.
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -19,3 +19,4 @@ scikit-learn~=1.0
# needed for frameworks tests
lightgbm~=3.0
xgboost~=1.1
sqlalchemy_utils~=0.39.0
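
The sqlalchemy_utils dev dependency is presumably what the new MySQL-backed tests use to provision and tear down a scratch database. A minimal sketch of that pattern, assuming a hypothetical local test database:

    import sqlalchemy
    from sqlalchemy_utils import create_database, database_exists, drop_database

    db_url = "mysql+pymysql://user:pass@localhost:3306/mlrun_test"  # hypothetical test DB
    engine = sqlalchemy.create_engine(db_url)
    if not database_exists(engine.url):
        create_database(engine.url)  # provision the scratch database for the test run
    # ... run SQL source/target tests against db_url ...
    drop_database(engine.url)  # clean up afterwards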
1 change: 1 addition & 0 deletions mlrun/api/crud/client_spec.py
@@ -44,6 +44,7 @@ def get_client_spec(self):
generate_artifact_target_path_from_artifact_hash=config.artifacts.generate_target_path_from_artifact_hash,
redis_url=config.redis.url,
redis_type=config.redis.type,
sql_url=config.sql.url,
# These don't have a default value, but we don't send them if they are not set to allow the client to know
# when to use server value and when to use client value (server only if set). Since their default value is
# empty and not set is also empty we can use the same _get_config_value_if_not_default
1 change: 1 addition & 0 deletions mlrun/api/schemas/client_spec.py
@@ -56,6 +56,7 @@ class ClientSpec(pydantic.BaseModel):
function: typing.Optional[Function]
redis_url: typing.Optional[str]
redis_type: typing.Optional[str]
sql_url: typing.Optional[str]

# ce_mode is deprecated, we will use the full ce config instead and ce_mode will be removed in 1.6.0
ce_mode: typing.Optional[str]
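With sql_url exposed on the client spec, a client can fall back to the server-side value when no local SQL URL is configured. A minimal sketch, assuming a hypothetical payload from the client-spec endpoint (pydantic v1 parse_obj):

    payload = {"sql_url": "mysql+pymysql://user:pass@host:3306/db"}  # hypothetical response
    spec = ClientSpec.parse_obj(payload)
    if spec.sql_url:
        # prefer the server value only when the client has no local override
        sql_url = spec.sql_url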
3 changes: 3 additions & 0 deletions mlrun/config.py
@@ -117,6 +117,9 @@
"url": "",
"type": "standalone", # deprecated.
},
"sql": {
"url": "",
},
"v3io_framesd": "http://framesd:8080",
"datastore": {"async_source_mode": "disabled"},
# default node selector to be applied to all functions - json string base64 encoded format
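MLRun maps nested configuration keys to environment variables using a double-underscore separator, so the new sql.url entry corresponds to MLRUN_SQL__URL — the same variable the SQLSource below falls back to. A minimal sketch (the URL is hypothetical):

    import os

    # must be set before mlrun is imported, since config is loaded from the environment
    os.environ["MLRUN_SQL__URL"] = "mysql+pymysql://user:pass@host:3306/db"

    import mlrun

    assert mlrun.mlconf.sql.url == os.environ["MLRUN_SQL__URL"]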
111 changes: 111 additions & 0 deletions mlrun/datastore/sources.py
@@ -19,6 +19,7 @@
from datetime import datetime
from typing import Dict, List, Optional, Union

import pandas as pd
import v3io
import v3io.dataplane
from nuclio import KafkaTrigger
@@ -858,6 +859,115 @@ def add_nuclio_trigger(self, function):
return func


class SQLSource(BaseSourceDriver):
    kind = "sqldb"
    support_storey = True
    support_spark = False

    def __init__(
        self,
        name: str = "",
        chunksize: int = None,
        key_field: str = None,
        time_field: str = None,
        schedule: str = None,
        start_time: Optional[Union[datetime, str]] = None,
        end_time: Optional[Union[datetime, str]] = None,
        db_url: str = None,
        table_name: str = None,
        spark_options: dict = None,
        time_fields: List[str] = None,
    ):
"""
Reads SqlDB as input source for a flow.
example::
db_path = "mysql+pymysql://<username>:<password>@<host>:<port>/<db_name>"
source = SqlDBSource(
collection_name='source_name', db_path=self.db, key_field='key'
)
:param name: source name
:param chunksize: number of rows per chunk (default large single chunk)
:param key_field: the column to be used as the key for the collection.
:param time_field: the column to be parsed as timestamp for events. Defaults to None
:param start_time: filters out data before this time
:param end_time: filters out data after this time
:param schedule: string to configure scheduling of the ingestion job.
For example '*/30 * * * *' will
cause the job to run every 30 minutes
:param db_url: url string connection to sql database.
If not set, the MLRUN_SQL__URL environment variable will be used.
:param table_name: the name of the collection to access,
from the current database
:param spark_options: additional spark read options
:param time_fields : all the field to be parsed as timestamp.
"""

        db_url = db_url or mlrun.mlconf.sql.url
        if not db_url:
            raise mlrun.errors.MLRunInvalidArgumentError(
                "db_url must be specified, either via the db_url argument "
                "or the MLRUN_SQL__URL environment variable"
            )
        attrs = {
            "chunksize": chunksize,
            "spark_options": spark_options,
            "table_name": table_name,
            "db_path": db_url,
            "time_fields": time_fields,
        }
        attrs = {key: value for key, value in attrs.items() if value is not None}
        super().__init__(
            name,
            attributes=attrs,
            key_field=key_field,
            time_field=time_field,
            schedule=schedule,
            start_time=start_time,
            end_time=end_time,
        )

    def to_dataframe(self):
        import sqlalchemy as db

        db_path = self.attributes.get("db_path")
        table_name = self.attributes.get("table_name")
        if not (table_name and db_path):
            raise mlrun.errors.MLRunInvalidArgumentError(
                "table_name and db_path attributes must be specified"
            )
        query = self.attributes.get("query") or f"SELECT * FROM {table_name}"
        engine = db.create_engine(db_path)
        with engine.connect() as con:
            return pd.read_sql(
                query,
                con=con,
                chunksize=self.attributes.get("chunksize"),
                parse_dates=self.attributes.get("time_fields"),
            )

def to_step(self, key_field=None, time_field=None, context=None):
import storey

attributes = self.attributes or {}
if context:
attributes["context"] = context

return storey.SQLSource(
key_field=self.key_field or key_field,
time_field=self.time_field or time_field,
end_filter=self.end_time,
start_filter=self.start_time,
filter_column=self.time_field or time_field,
**attributes,
)
pass

def is_iterator(self):
return True if self.attributes.get("chunksize") else False


# map of sources (exclude DF source which is not serializable)
source_kind_to_driver = {
"": BaseSourceDriver,
@@ -869,4 +979,5 @@ def add_nuclio_trigger(self, function):
CustomSource.kind: CustomSource,
BigQuerySource.kind: BigQuerySource,
SnowflakeSource.kind: SnowflakeSource,
SQLSource.kind: SQLSource,
}
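
Taken together, the new source reads like any other feature-store source. A minimal usage sketch of pulling a table into a dataframe (connection URL, table, and column names are hypothetical):

    from mlrun.datastore.sources import SQLSource

    source = SQLSource(
        name="my_sql_source",
        db_url="mysql+pymysql://user:pass@host:3306/db",  # or set MLRUN_SQL__URL
        table_name="measurements",
        key_field="patient_id",
        time_fields=["timestamp"],  # parsed as timestamps by pandas.read_sql
    )
    df = source.to_dataframe()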
