diff --git a/dev-requirements.txt b/dev-requirements.txt index 479aca5e22..85fc4b9d49 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -19,3 +19,4 @@ scikit-learn~=1.0 # needed for frameworks tests lightgbm~=3.0 xgboost~=1.1 +sqlalchemy_utils~=0.39.0 diff --git a/mlrun/api/crud/client_spec.py b/mlrun/api/crud/client_spec.py index e57d4dcef2..608283e341 100644 --- a/mlrun/api/crud/client_spec.py +++ b/mlrun/api/crud/client_spec.py @@ -44,6 +44,7 @@ def get_client_spec(self): generate_artifact_target_path_from_artifact_hash=config.artifacts.generate_target_path_from_artifact_hash, redis_url=config.redis.url, redis_type=config.redis.type, + sql_url=config.sql.url, # These don't have a default value, but we don't send them if they are not set to allow the client to know # when to use server value and when to use client value (server only if set). Since their default value is # empty and not set is also empty we can use the same _get_config_value_if_not_default diff --git a/mlrun/api/schemas/client_spec.py b/mlrun/api/schemas/client_spec.py index e0d4d943e6..3488bf0e19 100644 --- a/mlrun/api/schemas/client_spec.py +++ b/mlrun/api/schemas/client_spec.py @@ -56,6 +56,7 @@ class ClientSpec(pydantic.BaseModel): function: typing.Optional[Function] redis_url: typing.Optional[str] redis_type: typing.Optional[str] + sql_url: typing.Optional[str] # ce_mode is deprecated, we will use the full ce config instead and ce_mode will be removed in 1.6.0 ce_mode: typing.Optional[str] diff --git a/mlrun/config.py b/mlrun/config.py index 248dfe53ce..ebedcf4dc3 100644 --- a/mlrun/config.py +++ b/mlrun/config.py @@ -117,6 +117,9 @@ "url": "", "type": "standalone", # deprecated. }, + "sql": { + "url": "", + }, "v3io_framesd": "http://framesd:8080", "datastore": {"async_source_mode": "disabled"}, # default node selector to be applied to all functions - json string base64 encoded format diff --git a/mlrun/datastore/sources.py b/mlrun/datastore/sources.py index cb9bf754a9..3c696e60f2 100644 --- a/mlrun/datastore/sources.py +++ b/mlrun/datastore/sources.py @@ -19,6 +19,7 @@ from datetime import datetime from typing import Dict, List, Optional, Union +import pandas as pd import v3io import v3io.dataplane from nuclio import KafkaTrigger @@ -858,6 +859,115 @@ def add_nuclio_trigger(self, function): return func +class SQLSource(BaseSourceDriver): + kind = "sqldb" + support_storey = True + support_spark = False + + def __init__( + self, + name: str = "", + chunksize: int = None, + key_field: str = None, + time_field: str = None, + schedule: str = None, + start_time: Optional[Union[datetime, str]] = None, + end_time: Optional[Union[datetime, str]] = None, + db_url: str = None, + table_name: str = None, + spark_options: dict = None, + time_fields: List[str] = None, + ): + """ + Reads SqlDB as input source for a flow. + example:: + db_path = "mysql+pymysql://:@:/" + source = SqlDBSource( + collection_name='source_name', db_path=self.db, key_field='key' + ) + :param name: source name + :param chunksize: number of rows per chunk (default large single chunk) + :param key_field: the column to be used as the key for the collection. + :param time_field: the column to be parsed as timestamp for events. Defaults to None + :param start_time: filters out data before this time + :param end_time: filters out data after this time + :param schedule: string to configure scheduling of the ingestion job. 
+ For example '*/30 * * * *' will + cause the job to run every 30 minutes + :param db_url: url string connection to sql database. + If not set, the MLRUN_SQL__URL environment variable will be used. + :param table_name: the name of the collection to access, + from the current database + :param spark_options: additional spark read options + :param time_fields : all the field to be parsed as timestamp. + """ + + db_url = db_url or mlrun.mlconf.sql.url + if db_url is None: + raise mlrun.errors.MLRunInvalidArgumentError( + "cannot specify without db_path arg or secret MLRUN_SQL__URL" + ) + attrs = { + "chunksize": chunksize, + "spark_options": spark_options, + "table_name": table_name, + "db_path": db_url, + "time_fields": time_fields, + } + attrs = {key: value for key, value in attrs.items() if value is not None} + super().__init__( + name, + attributes=attrs, + key_field=key_field, + time_field=time_field, + schedule=schedule, + start_time=start_time, + end_time=end_time, + ) + + def to_dataframe(self): + import sqlalchemy as db + + query = self.attributes.get("query", None) + db_path = self.attributes.get("db_path") + table_name = self.attributes.get("table_name") + if not query: + query = f"SELECT * FROM {table_name}" + if table_name and db_path: + engine = db.create_engine(db_path) + with engine.connect() as con: + return pd.read_sql( + query, + con=con, + chunksize=self.attributes.get("chunksize"), + parse_dates=self.attributes.get("time_fields"), + ) + else: + raise mlrun.errors.MLRunInvalidArgumentError( + "table_name and db_name args must be specified" + ) + + def to_step(self, key_field=None, time_field=None, context=None): + import storey + + attributes = self.attributes or {} + if context: + attributes["context"] = context + + return storey.SQLSource( + key_field=self.key_field or key_field, + time_field=self.time_field or time_field, + end_filter=self.end_time, + start_filter=self.start_time, + filter_column=self.time_field or time_field, + **attributes, + ) + pass + + def is_iterator(self): + return True if self.attributes.get("chunksize") else False + + # map of sources (exclude DF source which is not serializable) source_kind_to_driver = { "": BaseSourceDriver, @@ -869,4 +979,5 @@ def add_nuclio_trigger(self, function): CustomSource.kind: CustomSource, BigQuerySource.kind: BigQuerySource, SnowflakeSource.kind: SnowflakeSource, + SQLSource.kind: SQLSource, } diff --git a/mlrun/datastore/targets.py b/mlrun/datastore/targets.py index e29f23203c..1020acf07e 100644 --- a/mlrun/datastore/targets.py +++ b/mlrun/datastore/targets.py @@ -11,16 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import ast +import datetime import os import random import time -import typing import warnings from collections import Counter from copy import copy -from typing import Union +from typing import Any, Dict, List, Optional, Union import pandas as pd +import sqlalchemy import mlrun import mlrun.utils.helpers @@ -45,6 +47,7 @@ class TargetTypes: kafka = "kafka" dataframe = "dataframe" custom = "custom" + sql = "sql" @staticmethod def all(): @@ -58,6 +61,7 @@ def all(): TargetTypes.kafka, TargetTypes.dataframe, TargetTypes.custom, + TargetTypes.sql, ] @@ -376,17 +380,18 @@ def __init__( self, name: str = "", path=None, - attributes: typing.Dict[str, str] = None, + attributes: Dict[str, str] = None, after_step=None, columns=None, partitioned: bool = False, - key_bucketing_number: typing.Optional[int] = None, - partition_cols: typing.Optional[typing.List[str]] = None, - time_partitioning_granularity: typing.Optional[str] = None, + key_bucketing_number: Optional[int] = None, + partition_cols: Optional[List[str]] = None, + time_partitioning_granularity: Optional[str] = None, after_state=None, - max_events: typing.Optional[int] = None, - flush_after_seconds: typing.Optional[int] = None, - storage_options: typing.Dict[str, str] = None, + max_events: Optional[int] = None, + flush_after_seconds: Optional[int] = None, + storage_options: Dict[str, str] = None, + schema: Dict[str, Any] = None, ): super().__init__( self.kind, @@ -401,6 +406,7 @@ def __init__( max_events, flush_after_seconds, after_state, + schema=schema, ) if after_state: warnings.warn( @@ -422,6 +428,7 @@ def __init__( self.max_events = max_events self.flush_after_seconds = flush_after_seconds self.storage_options = storage_options + self.schema = schema or {} self._target = None self._resource = None @@ -470,7 +477,7 @@ def write_dataframe( timestamp_key=None, chunk_id=0, **kwargs, - ) -> typing.Optional[int]: + ) -> Optional[int]: if hasattr(df, "rdd"): options = self.get_spark_options(key_column, timestamp_key) options.update(kwargs) @@ -559,6 +566,7 @@ def from_spec(cls, spec: DataTargetBase, resource=None): driver.name = spec.name driver.path = spec.path driver.attributes = spec.attributes + driver.schema = spec.schema if hasattr(spec, "columns"): driver.columns = spec.columns @@ -719,17 +727,17 @@ def __init__( self, name: str = "", path=None, - attributes: typing.Dict[str, str] = None, + attributes: Dict[str, str] = None, after_step=None, columns=None, partitioned: bool = None, - key_bucketing_number: typing.Optional[int] = None, - partition_cols: typing.Optional[typing.List[str]] = None, - time_partitioning_granularity: typing.Optional[str] = None, + key_bucketing_number: Optional[int] = None, + partition_cols: Optional[List[str]] = None, + time_partitioning_granularity: Optional[str] = None, after_state=None, - max_events: typing.Optional[int] = 10000, - flush_after_seconds: typing.Optional[int] = 900, - storage_options: typing.Dict[str, str] = None, + max_events: Optional[int] = 10000, + flush_after_seconds: Optional[int] = 900, + storage_options: Dict[str, str] = None, ): if after_state: warnings.warn( @@ -1487,6 +1495,280 @@ def as_df( return self._df +class SQLTarget(BaseStoreTarget): + kind = TargetTypes.sql + is_online = True + support_spark = False + support_storey = True + + def __init__( + self, + name: str = "", + path=None, + attributes: Dict[str, str] = None, + after_step=None, + partitioned: bool = False, + key_bucketing_number: Optional[int] = None, + partition_cols: Optional[List[str]] = None, + 
time_partitioning_granularity: Optional[str] = None, + after_state=None, + max_events: Optional[int] = None, + flush_after_seconds: Optional[int] = None, + storage_options: Dict[str, str] = None, + db_url: str = None, + table_name: str = None, + schema: Dict[str, Any] = None, + primary_key_column: str = "", + if_exists: str = "append", + create_table: bool = False, + # create_according_to_data: bool = False, + time_fields: List[str] = None, + varchar_len: int = 50, + ): + """ + Write to SqlDB as output target for a flow. + example:: + db_path = "sqlite:///stockmarket.db" + schema = {'time': datetime.datetime, 'ticker': str, + 'bid': float, 'ask': float, 'ind': int} + target = SqlDBTarget(table_name=f'{name}-tatget', db_path=db_path, create_table=True, + schema=schema, primary_key_column=key) + :param name: + :param path: + :param attributes: + :param after_step: + :param partitioned: + :param key_bucketing_number: + :param partition_cols: + :param time_partitioning_granularity: + :param after_state: + :param max_events: + :param flush_after_seconds: + :param storage_options: + :param db_url: url string connection to sql database. + If not set, the MLRUN_SQL__URL environment variable will + be used. + :param table_name: the name of the table to access, + from the current database + :param schema: the schema of the table (must pass when + create_table=True) + :param primary_key_column: the primary key of the table (must pass always) + :param if_exists: {'fail', 'replace', 'append'}, default 'append' + - fail: If table exists, do nothing. + - replace: If table exists, drop it, recreate it, and insert data. + - append: If table exists, insert data. Create if does not exist. + :param create_table: pass True if you want to create new table named by + table_name with schema on current database. + :param create_according_to_data: (not valid) + :param time_fields : all the field to be parsed as timestamp. + :param varchar_len : the defalut len of the all the varchar column (using if needed to create the table). + """ + create_according_to_data = False # TODO: open for user + db_url = db_url or mlrun.mlconf.sql.url + if db_url is None or table_name is None: + attr = {} + else: + # check for table existence and acts according to the user input + self._primary_key_column = primary_key_column + + attr = { + "table_name": table_name, + "db_path": db_url, + "create_according_to_data": create_according_to_data, + "if_exists": if_exists, + "time_fields": time_fields, + "varchar_len": varchar_len, + } + path = ( + f"mlrunSql://@{db_url}//@{table_name}" + f"//@{str(create_according_to_data)}//@{if_exists}//@{primary_key_column}//@{create_table}" + ) + + if attributes: + attributes.update(attr) + else: + attributes = attr + + super().__init__( + name, + path, + attributes, + after_step, + list(schema.keys()) if schema else None, + partitioned, + key_bucketing_number, + partition_cols, + time_partitioning_granularity, + max_events=max_events, + flush_after_seconds=flush_after_seconds, + storage_options=storage_options, + after_state=after_state, + schema=schema, + ) + + def add_writer_state( + self, graph, after, features, key_columns=None, timestamp_key=None + ): + warnings.warn( + "This method is deprecated. 
Use add_writer_step instead", + # TODO: In 0.7.0 do changes in examples & demos In 0.9.0 remove + PendingDeprecationWarning, + ) + """add storey writer state to graph""" + self.add_writer_step(graph, after, features, key_columns, timestamp_key) + + def get_table_object(self): + from storey import SQLDriver, Table + + (db_path, table_name, _, _, primary_key, _) = self._parse_url() + try: + primary_key = ast.literal_eval(primary_key) + except Exception: + pass + return Table( + f"{db_path}/{table_name}", + SQLDriver(db_path=db_path, primary_key=primary_key), + flush_interval_secs=mlrun.mlconf.feature_store.flush_interval, + ) + + def add_writer_step( + self, + graph, + after, + features, + key_columns=None, + timestamp_key=None, + featureset_status=None, + ): + key_columns = list(key_columns.keys()) + column_list = self._get_column_list( + features=features, timestamp_key=timestamp_key, key_columns=key_columns + ) + table = self._resource.uri + self._create_sql_table() + graph.add_step( + name=self.name or "SqlTarget", + after=after, + graph_shape="cylinder", + class_name="storey.NoSqlTarget", + columns=column_list, + header=True, + table=table, + index_cols=key_columns, + **self.attributes, + ) + + def as_df( + self, + columns=None, + df_module=None, + entities=None, + start_time=None, + end_time=None, + time_column=None, + **kwargs, + ): + db_path, table_name, _, _, _, _ = self._parse_url() + engine = sqlalchemy.create_engine(db_path) + with engine.connect() as conn: + df = pd.read_sql( + f"SELECT * FROM {self.attributes.get('table_name')}", + con=conn, + parse_dates=self.attributes.get("time_fields"), + ) + if self._primary_key_column: + df.set_index(self._primary_key_column, inplace=True) + if columns: + df = df[columns] + return df + + def write_dataframe( + self, df, key_column=None, timestamp_key=None, chunk_id=0, **kwargs + ): + self._create_sql_table() + + if hasattr(df, "rdd"): + raise ValueError("Spark is not supported") + else: + ( + db_path, + table_name, + create_according_to_data, + if_exists, + primary_key, + _, + ) = self._parse_url() + create_according_to_data = bool(create_according_to_data) + engine = sqlalchemy.create_engine( + db_path, + ) + connection = engine.connect() + if create_according_to_data: + # todo : create according to first row. 
+ pass + df.to_sql(table_name, connection, if_exists=if_exists) + + def _parse_url(self): + path = self.path[len("mlrunSql:///") :] + return path.split("//@") + + def purge(self): + pass + + def _create_sql_table(self): + ( + db_path, + table_name, + create_according_to_data, + if_exists, + primary_key, + create_table, + ) = self._parse_url() + try: + primary_key = ast.literal_eval(primary_key) + primary_key_for_check = primary_key + except Exception: + primary_key_for_check = [primary_key] + engine = sqlalchemy.create_engine(db_path) + with engine.connect() as conn: + metadata = sqlalchemy.MetaData() + table_exists = engine.dialect.has_table(conn, table_name) + if not table_exists and not create_table: + raise ValueError(f"Table named {table_name} is not exist") + + elif not table_exists and create_table: + TYPE_TO_SQL_TYPE = { + int: sqlalchemy.Integer, + str: sqlalchemy.String(self.attributes.get("varchar_len")), + datetime.datetime: sqlalchemy.dialects.mysql.DATETIME(fsp=6), + pd.Timestamp: sqlalchemy.dialects.mysql.DATETIME(fsp=6), + bool: sqlalchemy.Boolean, + float: sqlalchemy.Float, + datetime.timedelta: sqlalchemy.Interval, + pd.Timedelta: sqlalchemy.Interval, + } + # creat new table with the given name + columns = [] + for col, col_type in self.schema.items(): + col_type = TYPE_TO_SQL_TYPE.get(col_type) + if col_type is None: + raise TypeError(f"{col_type} unsupported type") + columns.append( + sqlalchemy.Column( + col, col_type, primary_key=(col in primary_key_for_check) + ) + ) + + sqlalchemy.Table(table_name, metadata, *columns) + metadata.create_all(engine) + if_exists = "append" + self.path = ( + f"mlrunSql://@{db_path}//@{table_name}" + f"//@{str(create_according_to_data)}//@{if_exists}//@{primary_key}//@{create_table}" + ) + conn.close() + + kind_to_driver = { TargetTypes.parquet: ParquetTarget, TargetTypes.csv: CSVTarget, @@ -1497,6 +1779,7 @@ def as_df( TargetTypes.kafka: KafkaTarget, TargetTypes.tsdb: TSDBTarget, TargetTypes.custom: CustomTarget, + TargetTypes.sql: SQLTarget, } diff --git a/mlrun/db/httpdb.py b/mlrun/db/httpdb.py index c6dff635ad..427eece2c2 100644 --- a/mlrun/db/httpdb.py +++ b/mlrun/db/httpdb.py @@ -321,6 +321,7 @@ def connect(self, secrets=None): # allow client to set the default partial WA for lack of support of per-target auxiliary options config.redis.type = config.redis.type or server_cfg.get("redis_type") + config.sql.url = config.sql.url or server_cfg.get("sql_url") # These have a default value, therefore local config will always have a value, prioritize the # API value first config.ui.projects_prefix = ( diff --git a/mlrun/model.py b/mlrun/model.py index 53a763360d..a0b71cdbbe 100644 --- a/mlrun/model.py +++ b/mlrun/model.py @@ -21,7 +21,7 @@ from copy import deepcopy from datetime import datetime from os import environ -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import mlrun @@ -1238,6 +1238,7 @@ class DataTargetBase(ModelObj): "flush_after_seconds", "storage_options", "run_id", + "schema", ] # TODO - remove once "after_state" is fully deprecated @@ -1269,6 +1270,7 @@ def __init__( flush_after_seconds: Optional[int] = None, after_state=None, storage_options: Dict[str, str] = None, + schema: Dict[str, Any] = None, ): if after_state: warnings.warn( @@ -1292,6 +1294,7 @@ def __init__( self.flush_after_seconds = flush_after_seconds self.storage_options = storage_options self.run_id = None + self.schema = schema class FeatureSetProducer(ModelObj): diff --git 
a/tests/system/env-template.yml b/tests/system/env-template.yml index 70d728c637..2f8784c1b1 100644 --- a/tests/system/env-template.yml +++ b/tests/system/env-template.yml @@ -44,3 +44,6 @@ MLRUN_SYSTEM_TESTS_KAFKA_BROKERS: # slack webhook for sending system test reports MLRUN_SYSTEM_TESTS_SLACK_WEBHOOK_URL: + +# sql db path string for test_sql_db - e.g. mysql+pymysql://:@:/ +MLRUN_SQL__URL: diff --git a/tests/system/feature_store/test_sql_db.py b/tests/system/feature_store/test_sql_db.py new file mode 100644 index 0000000000..ba1c80068c --- /dev/null +++ b/tests/system/feature_store/test_sql_db.py @@ -0,0 +1,456 @@ +# Copyright 2022 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +from typing import List + +import pandas as pd +import pytest +import sqlalchemy as db + +import mlrun.feature_store as fs +from mlrun.datastore.sources import SQLSource +from mlrun.datastore.targets import SQLTarget +from mlrun.feature_store.steps import OneHotEncoder +from tests.system.base import TestMLRunSystem + + +@pytest.mark.enterprise +class TestFeatureStoreSqlDB(TestMLRunSystem): + project_name = "fs-system-test-sqldb" + + @classmethod + def _init_env_from_file(cls): + env = cls._get_env_from_file() + cls.db = env["MLRUN_SQL__URL"] + if cls.db == "" or cls.db is None: + pytest.skip("Environment variable MLRUN_SQL_DB_PATH_STRING is not defined") + cls.source_collection = "source_collection" + cls.target_collection = "target_collection" + + def custom_setup(self): + self._init_env_from_file() + self.prepare_data() + + def get_data(self, data_name: str): + if data_name == "stocks": + return self.stocks + elif data_name == "quotes": + return self.quotes + elif data_name == "trades": + return self.trades + else: + return None + + @staticmethod + def get_schema(data_name: str): + if data_name == "stocks": + return {"ticker": str, "name": str, "exchange": str} + elif data_name == "quotes": + return { + "time": datetime.datetime, + "ticker": str, + "bid": float, + "ask": float, + "ind": int, + } + elif data_name == "trades": + return { + "time": datetime.datetime, + "ticker": str, + "price": float, + "quantity": int, + "ind": int, + } + else: + return None + + @pytest.fixture(autouse=True) + def run_around_tests(self): + # create db if wasn't exist + from sqlalchemy_utils import create_database, database_exists + + engine = db.create_engine(self.db) + if not database_exists(engine.url): + create_database(engine.url) + + yield + + # drop all the collection on self.db + engine = db.create_engine(self.db) + with engine.connect(): + metadata = db.MetaData() + metadata.reflect(bind=engine) + # and drop them, if they exist + metadata.drop_all(bind=engine, checkfirst=True) + engine.dispose() + + @pytest.mark.parametrize( + "source_name, key, time_fields", + [("stocks", "ticker", None), ("trades", "ind", ["time"])], + ) + @pytest.mark.parametrize("fset_engine", ["pandas", "storey"]) + def test_sql_source_basic( + self, source_name: str, key: str, time_fields: List[str], fset_engine: str + ): + 
from sqlalchemy_utils import create_database, database_exists + + engine = db.create_engine(self.db) + if not database_exists(engine.url): + create_database(engine.url) + with engine.connect() as conn: + origin_df = self.get_data(source_name) + origin_df.to_sql( + source_name, + conn, + if_exists="replace", + index=False, + ) + conn.close() + source = SQLSource( + table_name=source_name, + key_field=key, + time_fields=time_fields, + ) + + feature_set = fs.FeatureSet( + f"fs-{source_name}", entities=[fs.Entity(key)], engine=fset_engine + ) + feature_set.set_targets([]) + df = fs.ingest(feature_set, source=source) + origin_df.set_index(keys=[key], inplace=True) + assert df.equals(origin_df) + + @pytest.mark.parametrize( + "source_name, key, encoder_col", + [ + ("stocks", "ticker", "exchange"), + ("quotes", "ind", "ticker"), + ], + ) + @pytest.mark.parametrize("fset_engine", ["pandas", "storey"]) + def test_sql_source_with_step( + self, source_name: str, key: str, encoder_col: str, fset_engine: str + ): + engine = db.create_engine(self.db) + with engine.connect() as conn: + origin_df = self.get_data(source_name) + origin_df.to_sql( + source_name, + conn, + if_exists="replace", + index=False, + dtype={"time": db.dialects.mysql.DATETIME(fsp=6)} + if source_name == "quotes" + else None, + ) + conn.close() + + # test source + source = SQLSource( + table_name=source_name, + key_field=key, + time_fields=["time"] if source_name == "quotes" else None, + ) + feature_set = fs.FeatureSet( + f"fs-{source_name}", entities=[fs.Entity(key)], engine=fset_engine + ) + one_hot_encoder_mapping = { + encoder_col: list(origin_df[encoder_col].unique()), + } + feature_set.graph.to(OneHotEncoder(mapping=one_hot_encoder_mapping)) + df = fs.ingest(feature_set, source=source) + + # reference source + feature_set_ref = fs.FeatureSet( + f"fs-{source_name}-ref", entities=[fs.Entity(key)], engine=fset_engine + ) + feature_set_ref.graph.to(OneHotEncoder(mapping=one_hot_encoder_mapping)) + df_ref = fs.ingest(feature_set_ref, origin_df) + + assert df.equals(df_ref) + + @pytest.mark.parametrize( + "source_name, key, aggr_col", + [("quotes", "ind", "ask"), ("trades", "ind", "price")], + ) + def test_sql_source_with_aggregation( + self, + source_name: str, + key: str, + aggr_col: str, + ): + engine = db.create_engine(self.db) + with engine.connect() as conn: + origin_df = self.get_data(source_name) + origin_df.to_sql( + source_name, + conn, + if_exists="replace", + index=False, + dtype={"time": db.dialects.mysql.DATETIME(fsp=6)} + if source_name == "quotes" + else None, + ) + conn.close() + + # test source + source = SQLSource(table_name=source_name, key_field=key, time_fields=["time"]) + feature_set = fs.FeatureSet(f"fs-{source_name}", entities=[fs.Entity(key)]) + feature_set.add_aggregation( + aggr_col, ["sum", "max"], "1h", "10m", name=f"{aggr_col}1" + ) + df = fs.ingest(feature_set, source=source) + + # reference source + feature_set_ref = fs.FeatureSet( + f"fs-{source_name}-ref", + entities=[fs.Entity(key)], + ) + feature_set_ref.add_aggregation( + aggr_col, ["sum", "max"], "1h", "10m", name=f"{aggr_col}1" + ) + df_ref = fs.ingest(feature_set_ref, origin_df) + + assert df.equals(df_ref) + + @pytest.mark.parametrize( + "target_name, key", [("stocks", "ticker"), ("quotes", "ind")] + ) + @pytest.mark.parametrize("fset_engine", ["pandas", "storey"]) + def test_sql_target_basic(self, target_name: str, key: str, fset_engine: str): + origin_df = self.get_data(target_name) + schema = self.get_schema(target_name) + + target = 
SQLTarget( + table_name=target_name, + create_table=True, + schema=schema, + primary_key_column=key, + time_fields=["time"], + ) + feature_set = fs.FeatureSet( + f"fs-{target_name}-tr", entities=[fs.Entity(key)], engine=fset_engine + ) + fs.ingest(feature_set, source=origin_df, targets=[target]) + df = target.as_df() + + origin_df.set_index(key, inplace=True) + columns = [*schema.keys()] + columns.remove(key) + df.sort_index(inplace=True), origin_df.sort_index(inplace=True) + + assert df[columns].equals(origin_df[columns]) + + @pytest.mark.parametrize( + "target_name, key", [("stocks", "ticker"), ("trades", "ind")] + ) + @pytest.mark.parametrize("fset_engine", ["pandas", "storey"]) + def test_sql_target_without_create( + self, target_name: str, key: str, fset_engine: str + ): + origin_df = self.get_data(target_name) + schema = self.get_schema(target_name) + engine = db.create_engine(self.db) + with engine.connect() as conn: + metadata = db.MetaData() + self._create(schema, target_name, metadata, engine, key) + conn.close() + + target = SQLTarget( + table_name=target_name, + create_table=False, + primary_key_column=key, + time_fields=["time"] if target_name == "trades" else None, + ) + feature_set = fs.FeatureSet( + f"fs-{target_name}-tr", entities=[fs.Entity(key)], engine=fset_engine + ) + fs.ingest(feature_set, source=origin_df, targets=[target]) + df = target.as_df() + + origin_df.set_index(key, inplace=True) + columns = [*schema.keys()] + columns.remove(key) + df.sort_index(inplace=True), origin_df.sort_index(inplace=True) + + assert df[columns].equals(origin_df[columns]) + + @pytest.mark.parametrize("target_name, key", [("quotes", "ind")]) + @pytest.mark.parametrize("fset_engine", ["pandas", "storey"]) + def test_sql_get_online_feature_basic( + self, target_name: str, key: str, fset_engine + ): + origin_df = self.get_data(target_name) + schema = self.get_schema(target_name) + + target = SQLTarget( + table_name=target_name, + create_table=True, + schema=schema, + primary_key_column=key, + time_fields=["time"], + ) + feature_set = fs.FeatureSet( + f"fs-{target_name}-tr", entities=[fs.Entity(key)], engine=fset_engine + ) + feature_set_ref = fs.FeatureSet( + f"fs-{target_name}-ref", entities=[fs.Entity(key)], engine=fset_engine + ) + fs.ingest(feature_set, source=origin_df, targets=[target]) + fs.ingest(feature_set_ref, source=origin_df) + columns = [*schema.keys()] + columns.remove(key) + + # reference + features_ref = [ + f"fs-{target_name}-ref.{columns[-1]}", + f"fs-{target_name}-ref.{columns[-2]}", + ] + vector = fs.FeatureVector( + f"{target_name}-vec", features_ref, description="my test vector" + ) + service_ref = fs.get_online_feature_service(vector) + ref_output = service_ref.get([{key: 1}], as_list=True) + + # test + features = [ + f"fs-{target_name}-tr.{columns[-1]}", + f"fs-{target_name}-tr.{columns[-2]}", + ] + vector = fs.FeatureVector( + f"{target_name}-vec", features, description="my test vector" + ) + with fs.get_online_feature_service(vector) as svc: + output = svc.get([{key: 1}], as_list=True) + assert ref_output == output + + @pytest.mark.parametrize("name, key", [("stocks", "ticker"), ("trades", "ind")]) + @pytest.mark.parametrize("fset_engine", ["pandas", "storey"]) + def test_sql_source_and_target_basic(self, name: str, key: str, fset_engine: str): + origin_df = self.get_data(name) + schema = self.get_schema(name) + table_name = f"{name}_target" + + engine = db.create_engine(self.db) + with engine.connect() as conn: + origin_df.to_sql(table_name, conn, 
if_exists="replace", index=False) + conn.close() + + source = SQLSource( + table_name=table_name, + key_field=key, + time_fields=["time"] if name == "trades" else None, + ) + + target = SQLTarget( + table_name=table_name, + create_table=True, + schema=schema, + primary_key_column=key, + time_fields=["time"] if name == "trades" else None, + ) + + targets = [target] + feature_set = fs.FeatureSet( + "sample_training_posts", + entities=[fs.Entity(key)], + description="feature set", + engine=fset_engine, + ) + + ingest_df = fs.ingest( + feature_set, + source=source, + targets=targets, + ) + + origin_df.set_index(keys=[key], inplace=True) + assert ingest_df.equals(origin_df) + + def prepare_data(self): + + self.quotes = pd.DataFrame( + { + "time": [ + pd.Timestamp("2016-05-25 13:30:00.023"), + pd.Timestamp("2016-05-25 13:30:00.023"), + pd.Timestamp("2016-05-25 13:30:00.030"), + pd.Timestamp("2016-05-25 13:30:00.041"), + pd.Timestamp("2016-05-25 13:30:00.048"), + pd.Timestamp("2016-05-25 13:30:00.049"), + pd.Timestamp("2016-05-25 13:30:00.072"), + pd.Timestamp("2016-05-25 13:30:00.075"), + ], + "ticker": [ + "GOOG", + "MSFT", + "MSFT", + "MSFT", + "GOOG", + "AAPL", + "GOOG", + "MSFT", + ], + "bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01], + "ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03], + "ind": [1, 2, 3, 4, 5, 6, 7, 8], + } + ) + + self.trades = pd.DataFrame( + { + "time": [ + pd.Timestamp("2016-05-25 13:30:23"), + pd.Timestamp("2016-05-25 13:30:38"), + pd.Timestamp("2016-05-25 13:30:48"), + pd.Timestamp("2016-05-25 13:30:48"), + pd.Timestamp("2016-05-25 13:30:48"), + ], + "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], + "price": [51.95, 51.95, 720.77, 720.92, 98.0], + "quantity": [75, 155, 100, 100, 100], + "ind": [1, 2, 3, 4, 5], + } + ) + + self.stocks = pd.DataFrame( + { + "ticker": ["MSFT", "GOOG", "AAPL"], + "name": ["Microsoft Corporation", "Alphabet Inc", "Apple Inc"], + "exchange": ["NASDAQ", "NASDAQ", "NASDAQ"], + } + ) + + def _create(self, schema, collection_name, metadata, engine, key): + columns = [] + for col, col_type in schema.items(): + if col_type == int: + col_type = db.Integer + elif col_type == str: + col_type = db.String(50) + elif col_type == datetime.timedelta or col_type == pd.Timedelta: + col_type = db.Interval + elif col_type == datetime.datetime or col_type == pd.Timestamp: + col_type = db.dialects.mysql.DATETIME(fsp=6) + elif col_type == bool: + col_type = db.Boolean + elif col_type == float: + col_type = db.Float + else: + raise TypeError(f"{col_type} unsupported type") + columns.append(db.Column(col, col_type, primary_key=(col == key))) + + db.Table(collection_name, metadata, *columns) + metadata.create_all(engine)