[Datastore] Add KafkaTarget and KafkaOutputStream (mlrun#2015)
Gal Topper authored Jun 12, 2022
1 parent 036e935 commit edbcee0
Showing 15 changed files with 358 additions and 25 deletions.
5 changes: 4 additions & 1 deletion automation/package_test/test.py
@@ -22,6 +22,7 @@ def __init__(self):
"from mlrun.datastore.sources import BigQuerySource"
)
google_cloud_storage_import = "import mlrun.datastore.google_cloud_storage"
targets_import = "import mlrun.datastore.targets"

self._extras_tests_data = {
"": {"import_test_command": f"{basic_import}"},
@@ -45,9 +46,11 @@ def __init__(self):
"[google-cloud-storage]": {
"import_test_command": f"{basic_import}; {google_cloud_storage_import}"
},
# TODO: this won't actually fail if the requirement is missing
"[kafka]": {"import_test_command": f"{basic_import}; {targets_import}"},
"[complete]": {
"import_test_command": f"{basic_import}; {s3_import}; {azure_blob_storage_import}; "
+ f"{azure_key_vault_import}; {google_cloud_storage_import}",
+ f"{azure_key_vault_import}; {google_cloud_storage_import}; {targets_import}",
"perform_vulnerability_check": True,
},
}
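For illustration only: the extras test presumably runs each import_test_command in an environment that has the corresponding extra installed, so the new "[kafka]" entry amounts to roughly the following (assuming basic_import is simply "import mlrun"; the harness details are not part of this diff):

    import subprocess
    import sys

    # Command string built by the new "[kafka]" entry: basic import plus the targets module.
    cmd = "import mlrun; import mlrun.datastore.targets"
    subprocess.run([sys.executable, "-c", cmd], check=True)
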
1 change: 1 addition & 0 deletions extras-requirements.txt
@@ -26,3 +26,4 @@ bokeh~=2.4, >=2.4.2
gcsfs~=2021.8.1
plotly~=5.4
google-cloud-bigquery~=3.0
kafka-python~=2.0
13 changes: 11 additions & 2 deletions mlrun/datastore/__init__.py
@@ -21,6 +21,7 @@
"CSVTarget",
"NoSqlTarget",
"StreamTarget",
"KafkaTarget",
"CSVSource",
"ParquetSource",
"BigQuerySource",
Expand All @@ -29,7 +30,7 @@
"KafkaSource",
]

from ..platforms.iguazio import OutputStream, parse_v3io_path
from ..platforms.iguazio import KafkaOutputStream, OutputStream, parse_v3io_path
from ..utils import logger
from .base import DataItem
from .datastore import StoreManager, in_memory_store, uri_to_ipython
@@ -49,6 +50,7 @@
parse_store_uri,
)
from .targets import CSVTarget, NoSqlTarget, ParquetTarget, StreamTarget
from .utils import parse_kafka_url

store_manager = StoreManager()

@@ -75,7 +77,14 @@ def get_stream_pusher(stream_path: str, **kwargs):
:param stream_path: path/url of stream
"""

if "://" not in stream_path:
if stream_path.startswith("kafka://") or "kafka_bootstrap_servers" in kwargs:
topic, bootstrap_servers = parse_kafka_url(
stream_path, kwargs.get("kafka_bootstrap_servers")
)
return KafkaOutputStream(
topic, bootstrap_servers, kwargs.get("kafka_producer_options")
)
elif "://" not in stream_path:
return OutputStream(stream_path, **kwargs)
elif stream_path.startswith("v3io"):
endpoint, stream_path = parse_v3io_path(stream_path)
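A minimal sketch of the new routing in get_stream_pusher (broker address and topic are hypothetical, and a reachable Kafka broker is assumed for the push to succeed):

    from mlrun.datastore import get_stream_pusher

    # A kafka:// path (or a kafka_bootstrap_servers kwarg) now returns a KafkaOutputStream.
    stream = get_stream_pusher("kafka://broker:9092/events")
    stream.push({"id": 1, "value": 0.5})

    # Scheme-less paths still return the v3io OutputStream, as before:
    # stream = get_stream_pusher("projects/demo/streams/events")
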
2 changes: 1 addition & 1 deletion mlrun/datastore/sources.py
@@ -811,7 +811,7 @@ def __init__(
def add_nuclio_trigger(self, function):
partitions = self.attributes.get("partitions")
trigger = KafkaTrigger(
brokers=self.attributes["brokers"],
brokers=self.attributes["brokers"].split(","),
topics=self.attributes["topics"],
partitions=partitions,
consumer_group=self.attributes["group"],
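The one-line fix above matters because nuclio's KafkaTrigger takes a list of brokers, while the source stores them as a comma-separated string. A hedged sketch (parameter names inferred from the attributes used above; addresses hypothetical):

    from mlrun.datastore.sources import KafkaSource

    source = KafkaSource(
        brokers="broker1:9092,broker2:9092",
        topics=["events"],
        group="serving",
    )
    # add_nuclio_trigger(fn) now hands ["broker1:9092", "broker2:9092"] to KafkaTrigger
    # instead of the raw string.
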
58 changes: 57 additions & 1 deletion mlrun/datastore/targets.py
@@ -31,7 +31,7 @@
from .. import errors
from ..data_types import ValueType
from ..platforms.iguazio import parse_v3io_path, split_path
from .utils import store_path_to_spark
from .utils import parse_kafka_url, store_path_to_spark


class TargetTypes:
@@ -40,6 +40,7 @@ class TargetTypes:
nosql = "nosql"
tsdb = "tsdb"
stream = "stream"
kafka = "kafka"
dataframe = "dataframe"
custom = "custom"

@@ -51,6 +52,7 @@ def all():
TargetTypes.nosql,
TargetTypes.tsdb,
TargetTypes.stream,
TargetTypes.kafka,
TargetTypes.dataframe,
TargetTypes.custom,
]
@@ -1152,6 +1154,59 @@ def as_df(self, columns=None, df_module=None, **kwargs):
raise NotImplementedError()


class KafkaTarget(BaseStoreTarget):
kind = TargetTypes.kafka
is_table = False
is_online = False
support_spark = False
support_storey = True
support_append = True

def __init__(
self,
*args,
bootstrap_servers=None,
producer_options=None,
**kwargs,
):
attrs = {
"bootstrap_servers": bootstrap_servers,
"producer_options": producer_options,
}
super().__init__(*args, attributes=attrs, **kwargs)

def add_writer_step(
self,
graph,
after,
features,
key_columns=None,
timestamp_key=None,
featureset_status=None,
):
key_columns = list(key_columns.keys())
column_list = self._get_column_list(
features=features, timestamp_key=timestamp_key, key_columns=key_columns
)

bootstrap_servers = self.attributes.get("bootstrap_servers")
topic, bootstrap_servers = parse_kafka_url(self.path, bootstrap_servers)

graph.add_step(
name=self.name or "KafkaTarget",
after=after,
graph_shape="cylinder",
class_name="storey.KafkaTarget",
columns=column_list,
topic=topic,
bootstrap_servers=bootstrap_servers,
producer_options=self.attributes.get("producer_options"),
)

def as_df(self, columns=None, df_module=None, **kwargs):
raise NotImplementedError()


class TSDBTarget(BaseStoreTarget):
kind = TargetTypes.tsdb
is_table = False
@@ -1355,6 +1410,7 @@ def as_df(
TargetTypes.nosql: NoSqlTarget,
TargetTypes.dataframe: DFTarget,
TargetTypes.stream: StreamTarget,
TargetTypes.kafka: KafkaTarget,
TargetTypes.tsdb: TSDBTarget,
TargetTypes.custom: CustomTarget,
}
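A hedged sketch of wiring the new target into feature-store ingestion (feature set, data, broker, and topic are all hypothetical; a configured MLRun environment and a reachable Kafka broker are assumed):

    import pandas as pd
    import mlrun.feature_store as fstore
    from mlrun.datastore.targets import KafkaTarget

    stocks = fstore.FeatureSet("stocks", entities=[fstore.Entity("ticker")])
    df = pd.DataFrame({"ticker": ["AAA", "BBB"], "price": [1.0, 2.0]})

    # The path carries the brokers and topic; producer_options is passed through
    # to storey.KafkaTarget by add_writer_step above.
    fstore.ingest(
        stocks,
        df,
        targets=[KafkaTarget(path="kafka://broker:9092/stocks", producer_options={"acks": "all"})],
    )
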
12 changes: 12 additions & 0 deletions mlrun/datastore/utils.py
@@ -1,3 +1,6 @@
from urllib.parse import urlparse


def store_path_to_spark(path):
if path.startswith("v3io:///"):
path = "v3io:" + path[len("v3io:/") :]
@@ -13,3 +16,12 @@ def store_path_to_spark(path):
else:
path = "s3a:" + path[len("s3:") :]
return path


def parse_kafka_url(url, bootstrap_servers=None):
bootstrap_servers = bootstrap_servers or []
url = urlparse(url)
if url.netloc:
bootstrap_servers = [url.netloc] + bootstrap_servers
topic = url.path
return topic, bootstrap_servers
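For reference, how the new helper behaves on typical inputs (values illustrative). Note that the topic is taken verbatim from url.path, so it keeps its leading slash:

    from mlrun.datastore.utils import parse_kafka_url

    topic, servers = parse_kafka_url("kafka://broker1:9092/events")
    # topic == "/events", servers == ["broker1:9092"]

    topic, servers = parse_kafka_url("kafka:///events", bootstrap_servers=["b1:9092", "b2:9092"])
    # topic == "/events", servers == ["b1:9092", "b2:9092"]
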
57 changes: 55 additions & 2 deletions mlrun/platforms/iguazio.py
@@ -30,7 +30,6 @@

_cached_control_session = None


VolumeMount = namedtuple("Mount", ["path", "sub_path"])


@@ -424,6 +423,61 @@ def dump_record(rec):
)


class KafkaOutputStream:
def __init__(
self,
topic,
brokers,
producer_options=None,
mock=False,
):
self._kafka_producer = None
self._topic = topic
self._brokers = brokers
self._producer_options = producer_options or {}

self._mock = mock
self._mock_queue = []

self._initialized = False

def _lazy_init(self):
if self._initialized:
return

import kafka

self._kafka_producer = kafka.KafkaProducer(
bootstrap_servers=self._brokers,
**self._producer_options,
)

self._initialized = True

def push(self, data):
self._lazy_init()

def dump_record(rec):
if isinstance(rec, bytes):
return rec

if not isinstance(rec, str):
rec = dict_to_json(rec)

return rec.encode("UTF-8")

if not isinstance(data, list):
data = [data]

if self._mock:
# for mock testing
self._mock_queue.extend(data)
else:
for record in data:
serialized_record = dump_record(record)
self._kafka_producer.send(self._topic, serialized_record)
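A minimal sketch of using the new class directly (broker and topic are hypothetical, and a reachable Kafka broker is assumed; the kafka-python producer is created lazily on the first push):

    from mlrun.platforms.iguazio import KafkaOutputStream

    stream = KafkaOutputStream("events", ["broker:9092"], producer_options={"acks": "all"})
    stream.push({"id": 1, "value": 0.5})           # dicts are serialized to JSON and UTF-8 encoded
    stream.push(["plain string", b"raw bytes"])    # strings are encoded, bytes pass through unchanged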


class V3ioStreamClient:
def __init__(self, url: str, shard_id: int = 0, seek_to: str = None, **kwargs):
endpoint, stream_path = parse_v3io_path(url)
@@ -528,7 +582,6 @@ def is_iguazio_system_2_10_or_above(dashboard_url):
def add_or_refresh_credentials(
api_url: str, username: str = "", password: str = "", token: str = ""
) -> (str, str, str):

if is_iguazio_session(password):
return username, password, token

24 changes: 21 additions & 3 deletions mlrun/runtimes/serving.py
@@ -18,10 +18,12 @@
from typing import List, Union

import nuclio
from nuclio import KafkaTrigger

import mlrun
import mlrun.api.schemas

from ..datastore import parse_kafka_url
from ..model import ObjectList
from ..secrets import SecretsStore
from ..serving.server import GraphServer, create_graph_server
@@ -433,9 +435,25 @@ def _add_ref_triggers(self):

child_function = self._spec.function_refs[function_name]
trigger_args = stream.trigger_args or {}
child_function.function_object.add_v3io_stream_trigger(
stream.path, group=group, shards=stream.shards, **trigger_args
)

if (
stream.path.startswith("kafka://")
or "kafka_bootstrap_servers" in stream.options
):
brokers = stream.options.get("kafka_bootstrap_servers")
if brokers:
brokers = brokers.split(",")
topic, brokers = parse_kafka_url(stream.path, brokers)
trigger = KafkaTrigger(
brokers=brokers,
topics=[topic],
**trigger_args,
)
child_function.function_object.add_trigger("kafka", trigger)
else:
child_function.function_object.add_v3io_stream_trigger(
stream.path, group=group, shards=stream.shards, **trigger_args
)

def _deploy_function_refs(self, builder_env: dict = None):
"""set metadata and deploy child functions"""
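In effect, a queue step whose path starts with kafka:// (or whose options carry kafka_bootstrap_servers) now gives the downstream child function a Kafka trigger instead of a v3io stream trigger. A condensed, hedged equivalent of the new branch (broker and topic hypothetical):

    from nuclio import KafkaTrigger
    from mlrun.datastore import parse_kafka_url

    topic, brokers = parse_kafka_url("kafka://broker:9092/events")
    trigger = KafkaTrigger(brokers=brokers, topics=[topic])
    # child_function.function_object.add_trigger("kafka", trigger)
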
41 changes: 28 additions & 13 deletions mlrun/serving/states.py
@@ -24,6 +24,7 @@

from ..config import config
from ..datastore import get_stream_pusher
from ..datastore.utils import parse_kafka_url
from ..errors import MLRunInvalidArgumentError
from ..model import ModelObj, ObjectDict
from ..platforms.iguazio import parse_v3io_path
@@ -617,15 +618,11 @@ def __init__(
name: str = None,
path: str = None,
after: list = None,
shards: int = None,
retention_in_hours: int = None,
trigger_args: dict = None,
**options,
):
super().__init__(name, after)
self.path = path
self.shards = shards
self.retention_in_hours = retention_in_hours
self.options = options
self.trigger_args = trigger_args
self._stream = None
@@ -636,8 +633,7 @@ def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwar
if self.path:
self._stream = get_stream_pusher(
self.path,
shards=self.shards,
retention_in_hours=self.retention_in_hours,
**self.options,
)
self._set_error_handler()

@@ -1418,14 +1414,33 @@ def _init_async_objects(context, steps):
if step.path and not skip_stream:
stream_path = step.path
endpoint = None
if "://" in stream_path:
endpoint, stream_path = parse_v3io_path(step.path)
stream_path = stream_path.strip("/")
step._async_object = storey.StreamTarget(
storey.V3ioDriver(endpoint),
stream_path,
context=context,
kafka_bootstrap_servers = step.options.get(
"kafka_bootstrap_servers"
)
if stream_path.startswith("kafka://") or kafka_bootstrap_servers:
topic, bootstrap_servers = parse_kafka_url(
stream_path, kafka_bootstrap_servers
)

kafka_producer_options = step.options.get(
"kafka_producer_options"
)

step._async_object = storey.KafkaTarget(
topic=topic,
bootstrap_servers=bootstrap_servers,
producer_options=kafka_producer_options,
context=context,
)
else:
if stream_path.startswith("v3io://"):
endpoint, stream_path = parse_v3io_path(step.path)
stream_path = stream_path.strip("/")
step._async_object = storey.StreamTarget(
storey.V3ioDriver(endpoint),
stream_path,
context=context,
)
else:
step._async_object = storey.Map(lambda x: x)

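Queue steps now forward arbitrary **options to get_stream_pusher, so v3io-specific and Kafka-specific settings both ride on the step itself. A hedged sketch (assuming the class is exported as QueueStep, as in contemporary mlrun versions; paths and brokers hypothetical):

    from mlrun.serving.states import QueueStep

    # v3io stream settings such as shards now travel through **options.
    v3io_q = QueueStep("q1", path="v3io:///projects/demo/streams/events", shards=3)

    # A kafka:// path (or kafka_bootstrap_servers in the options) makes _init_async_objects
    # build a storey.KafkaTarget instead of a v3io StreamTarget.
    kafka_q = QueueStep(
        "q2",
        path="kafka://broker:9092/events",
        kafka_producer_options={"acks": "all"},
    )
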
2 changes: 1 addition & 1 deletion requirements.txt
@@ -49,7 +49,7 @@ fsspec~=2021.8.1
v3iofs~=0.1.7
# 3.4 and above failed builidng in some images - see https://github.com/pyca/cryptography/issues/5771
cryptography~=3.0, <3.4
storey~=1.1.1
storey~=1.1.3
deepdiff~=5.0
pymysql~=1.0
inflection~=0.5.0
(The remaining 5 changed files are not shown here.)
