From 87e5f93d47f076db7692d60df392f56182a0b2ba Mon Sep 17 00:00:00 2001 From: moranbental <107995850+moranbental@users.noreply.github.com> Date: Sun, 5 Jan 2025 15:06:09 +0200 Subject: [PATCH 01/15] [Pagination] Validate the Maximum Allowed Range for page and page_size (#7058) --- mlrun/config.py | 2 + server/py/framework/db/sqldb/db.py | 39 +++++++++++++++++++ .../tests/unit/crud/test_pagination_cache.py | 22 +++++++++++ .../api/tests/unit/utils/test_pagination.py | 30 ++++++++++++++ server/py/services/api/utils/pagination.py | 13 +++++++ 5 files changed, 106 insertions(+) diff --git a/mlrun/config.py b/mlrun/config.py index 848939ade34..322260ce488 100644 --- a/mlrun/config.py +++ b/mlrun/config.py @@ -537,6 +537,8 @@ }, "pagination": { "default_page_size": 200, + "page_limit": 1000000, + "page_size_limit": 1000000, "pagination_cache": { "interval": 60, "ttl": 3600, diff --git a/server/py/framework/db/sqldb/db.py b/server/py/framework/db/sqldb/db.py index 30412a81a2e..20f46db4d7e 100644 --- a/server/py/framework/db/sqldb/db.py +++ b/server/py/framework/db/sqldb/db.py @@ -128,6 +128,10 @@ NULL = None # Avoid flake8 issuing warnings when comparing in filter unversioned_tagged_object_uid_prefix = "unversioned-" +# Max values for 32-bit and 64-bit signed integers +MAX_INT_32 = 2_147_483_647 # For Integer (4-byte) +MAX_INT_64 = 9_223_372_036_854_775_807 # For BigInteger (8-byte) + conflict_messages = [ "(sqlite3.IntegrityError) UNIQUE constraint failed", "(pymysql.err.IntegrityError) (1062", @@ -6956,6 +6960,13 @@ def store_paginated_query_cache_record( page_size: int, kwargs: dict, ): + self._validate_integer_max_value( + PaginationCache.__table__.c.current_page, current_page + ) + self._validate_integer_max_value( + PaginationCache.__table__.c.page_size, page_size + ) + # generate key hash from user, function, current_page and kwargs key = hashlib.sha256( f"{user}/{function}/{page_size}/{kwargs}".encode() @@ -7350,3 +7361,31 @@ def _paginate_query( query = 
query.limit(limit) return query + + @staticmethod + def _validate_integer_max_value(column: Column, value: int): + """ + Validate that the value of a column does not exceed the max allowed integer value for that column's type. + + :param column: The SQLAlchemy column to check (e.g., PaginationCache.__table__.c.current_page). + :param value: The value to validate. + :raises: MLRunInvalidArgumentError if value exceeds the max allowed integer value for the column's type. + """ + if isinstance(column.type, sqlalchemy.Integer): + # Validate against 32-bit max + if value > MAX_INT_32: + raise mlrun.errors.MLRunInvalidArgumentError( + f"The '{column.name}' field value must be less than or equal to {MAX_INT_32}." + ) + + elif isinstance(column.type, sqlalchemy.BigInteger): + # Validate against 64-bit max + if value > MAX_INT_64: + raise mlrun.errors.MLRunInvalidArgumentError( + f"The '{column.name}' field value must be less than or equal to {MAX_INT_64}." + ) + + else: + raise mlrun.errors.MLRunInvalidArgumentError( + f"Unsupported column type '{column.type}' for validation." 
@pytest.mark.parametrize(
    "page, page_size",
    [
        (MAX_INT_32 + 1, 100),  # page above the 32-bit integer limit
        (200, MAX_INT_32 + 1),  # page_size above the 32-bit integer limit
    ],
)
def test_store_paginated_query_cache_record_out_of_range(
    db: sqlalchemy.orm.Session, page: int, page_size: int
):
    """Storing a pagination cache record with an out-of-range page or page_size must be rejected."""
    list_method = services.api.crud.Projects().list_projects

    # the DB layer validates page/page_size against the cache table's integer columns
    with pytest.raises(mlrun.errors.MLRunInvalidArgumentError):
        services.api.crud.PaginationCache().store_pagination_cache_record(
            db, "user_name", list_method, page, page_size, {}
        )
value + ( + 1, + mlconf.httpdb.pagination.page_size_limit + 1, + ), # page_size exceeds max allowed value + ], +) +async def test_paginate_request_invalid_page_or_page_size( + mock_paginated_method, + cleanup_pagination_cache_on_teardown, + db: sqlalchemy.orm.Session, + page, + page_size, +): + auth_info = mlrun.common.schemas.AuthInfo(user_id="user1") + method_kwargs = {"total_amount": 5, "since": datetime.datetime.now()} + + paginator = services.api.utils.pagination.Paginator() + + with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): + await paginator.paginate_request( + db, paginated_method, auth_info, None, page, page_size, **method_kwargs + ) + + def _assert_paginated_response( response, pagination_info, diff --git a/server/py/services/api/utils/pagination.py b/server/py/services/api/utils/pagination.py index f37ad2fd707..562533e3a12 100644 --- a/server/py/services/api/utils/pagination.py +++ b/server/py/services/api/utils/pagination.py @@ -167,6 +167,19 @@ async def paginate_request( method, session, **method_kwargs ), None + if page is not None and page > mlconf.httpdb.pagination.page_limit: + raise mlrun.errors.MLRunInvalidArgumentError( + f"'page' must be less than or equal to {mlconf.httpdb.pagination.page_limit}" + ) + + if ( + page_size is not None + and page_size > mlconf.httpdb.pagination.page_size_limit + ): + raise mlrun.errors.MLRunInvalidArgumentError( + f"'page_size' must be less than or equal to {mlconf.httpdb.pagination.page_size_limit}" + ) + page_size = page_size or mlconf.httpdb.pagination.default_page_size ( From 7f5f7af922e68338905bd5789cb010ddeb38c03b Mon Sep 17 00:00:00 2001 From: roei3000b <40743125+roei3000b@users.noreply.github.com> Date: Sun, 5 Jan 2025 18:55:14 +0200 Subject: [PATCH 02/15] [Notifications] Add refresh_smtp_configuration to RunDBInterface (#7064) Add `refresh_smtp_configuration` to RunDBInterface. 
--- mlrun/db/base.py | 3 +++ mlrun/db/nopdb.py | 3 +++ server/py/framework/rundb/sqldb.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/mlrun/db/base.py b/mlrun/db/base.py index d3058e4ce8f..959369ef4cd 100644 --- a/mlrun/db/base.py +++ b/mlrun/db/base.py @@ -68,6 +68,9 @@ def push_run_notifications( ): pass + def refresh_smtp_configuration(self): + pass + def push_pipeline_notifications( self, pipeline_id, diff --git a/mlrun/db/nopdb.py b/mlrun/db/nopdb.py index ab30367f551..e01756fb515 100644 --- a/mlrun/db/nopdb.py +++ b/mlrun/db/nopdb.py @@ -84,6 +84,9 @@ def push_run_notifications( ): pass + def refresh_smtp_configuration(self): + pass + def push_pipeline_notifications( self, pipeline_id, diff --git a/server/py/framework/rundb/sqldb.py b/server/py/framework/rundb/sqldb.py index 6364d257d92..1adc91f2c7c 100644 --- a/server/py/framework/rundb/sqldb.py +++ b/server/py/framework/rundb/sqldb.py @@ -115,6 +115,9 @@ def push_run_notifications( ): raise NotImplementedError() + def refresh_smtp_configuration(self): + raise NotImplementedError() + def push_pipeline_notifications( self, pipeline_id, From 2d0f7acfdf984980e8cf407beadfd8a2a14ac6d7 Mon Sep 17 00:00:00 2001 From: roei3000b <40743125+roei3000b@users.noreply.github.com> Date: Sun, 5 Jan 2025 21:07:18 +0200 Subject: [PATCH 03/15] [Notifications] Add mail notifications docs (#7063) --- docs/concepts/notifications.md | 54 +++++++++++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/docs/concepts/notifications.md b/docs/concepts/notifications.md index df80cba76b3..be836c6b8c8 100644 --- a/docs/concepts/notifications.md +++ b/docs/concepts/notifications.md @@ -8,6 +8,7 @@ MLRun supports configuring notifications on jobs and scheduled jobs. This sectio - [Local vs. 
remote](#local-vs-remote) - [Notification parameters and secrets](#notification-parameters-and-secrets) - [Notification kinds](#notification-kinds) +- [Mail notifications](#mail-notifications) - [Configuring notifications for runs](#configuring-notifications-for-runs) - [Configuring notifications for pipelines](#configuring-notifications-for-pipelines) - [Setting notifications on live runs](#setting-notifications-on-live-runs) @@ -24,7 +25,7 @@ Notifications can be sent either locally from the SDK, or remotely from the MLRu Usually, a local run sends locally, and a remote run sends remotely. However, there are several special cases where the notification is sent locally either way. These cases are: -- Local or KFP Engine Pipelines: To conserve backwards compatibility, the SDK sends the notifications as it did before adding the run +- Local: To conserve backwards compatibility, the SDK sends the notifications as it did before adding the run notifications mechanism. This means you need to watch the pipeline in order for its notifications to be sent. (Remote pipelines act differently. See [Configuring Notifications For Pipelines](#configuring-notifications-for-pipelines) for more details. - Dask: Dask runs are always local (against a remote Dask cluster), so the notifications are sent locally as well. @@ -45,6 +46,57 @@ It's essential to utilize `secret_params` exclusively for handling sensitive inf See {py:class}`~mlrun.common.schemas.notification.NotificationKind`. +## Mail notifications +To send mail notifications, you need an existing SMTP server. 
+```python +mail_notification = mlrun.model.Notification( + kind="mail", + when=["completed", "error", "running"], + name="mail-notification", + message="", + condition="", + severity="verbose", + params={ + "start_tls": True, + "use_tls": False, + "validate_certs": False, + "email_addresses": ["user.name@domain.com"], + }, +) +``` +We use the [aiosmtplib](https://aiosmtplib.readthedocs.io/en/stable/) library for sending mail notifications. +The `params` argument is a dictionary, that supports the following fields: + - server_host (string): The SMTP server host. + - server_port (int): The SMTP server port. + - sender_address (string): The sender email address. + - username (string): The username for the SMTP server. + - password (string): The password for the SMTP server. + - email_addresses (list of strings): The list of email addresses to send the mail to. + - start_tls (boolean): Whether to start the TLS connection. + - use_tls (boolean): Whether to use TLS. + - validate_certs (boolean): Whether to validate the certificates. + +You can read more about `start_tls` and `use_tls` on the [aiosmtplib docs](https://aiosmtplib.readthedocs.io/en/stable/encryption.html). +Missing params are enriched with default values which can be configured in the `mlrun-smtp-config` kubernetes (see below). + +### MLRun on Iguazio +If MLRun is deployed on the Iguazio platform, an SMTP server already exists. +To use it, run the following (with privileged user - `IT Admin`): +```python +import mlrun + +mlrun.get_run_db().refresh_smtp_configuration() +``` +The `refresh_smtp_configuration` method will get the smtp configuration from the Iguazio platform and set it +as the default smtp configuration (create a `mlrun-smtp-config` with the smtp configuration). +If you edit the configuration on the Iguazio platform, you should run the `refresh_smtp_configuration` method again. + +### MLRun CE +In the community edition, you can use your own SMTP server. 
+To configure it, manually create the `mlrun-smtp-config` kubernetes secret with the default +params for the SMTP server (`server_host`, `server_port`, `username`, `password`, etc..). +After creating or editing the secret, refresh the mlrun SMTP configuration by running the `refresh_smtp_configuration` method. + ## Configuring notifications for runs In any `run` method you can configure the notifications via their model. For example: From eaefd2c16473e2051ea63ce557420d1383467aba Mon Sep 17 00:00:00 2001 From: Katerina Molchanova <35141662+rokatyy@users.noreply.github.com> Date: Sun, 5 Jan 2025 19:46:56 +0000 Subject: [PATCH 04/15] [Alerts] Fix test_job_failure_alert_sliding_window (#7051) --- tests/system/alerts/test_alerts.py | 394 ++++++++++++++++------------- 1 file changed, 213 insertions(+), 181 deletions(-) diff --git a/tests/system/alerts/test_alerts.py b/tests/system/alerts/test_alerts.py index 06246b41684..58a011f6ea9 100644 --- a/tests/system/alerts/test_alerts.py +++ b/tests/system/alerts/test_alerts.py @@ -25,6 +25,7 @@ import mlrun.common.schemas.model_monitoring.constants as mm_constants import mlrun.model_monitoring.api import tests.system.common.helpers.notifications as notification_helpers +from mlrun import mlconf from mlrun.common.schemas.model_monitoring.model_endpoints import ( ModelEndpoint, ModelEndpointList, @@ -46,8 +47,9 @@ def test_job_failure_alert(self): """ validate that an alert is sent in case a job fails """ + function_name = "test-func-job-failure-alert" self.project.set_function( - name="test-func", + name=function_name, func=str(self.assets_path / "function.py"), handler="handler", image="mlrun/mlrun" if self.image is None else self.image, @@ -62,7 +64,7 @@ def test_job_failure_alert(self): # create an alert with webhook notification alert_name = "failure-webhook" alert_summary = "Job failed" - run_id = "test-func-handler" + run_id = f"{function_name}-handler" notifications = 
self._generate_failure_notifications(nuclio_function_url) self._create_custom_alert_config( name=alert_name, @@ -74,19 +76,18 @@ def test_job_failure_alert(self): ) with pytest.raises(Exception): - self.project.run_function("test-func", watch=False) - self.project.run_function("test-func") + self.project.run_function(function_name, watch=False) + self.project.run_function(function_name) # in order to trigger the periodic monitor runs function, to detect the failed run and send an event on it - time.sleep(35) - - # get project summary to validate the alert activations counters - project_summary = mlrun.get_run_db().get_project_summary( - project=self.project_name + mlrun.utils.retry_until_successful( + 3, + 10 * 6, + self._logger, + True, + self._validate_project_alerts_summary, + expected_job_alerts_count=2, ) - assert project_summary.job_alerts_count == 2 - assert project_summary.endpoint_alerts_count == 0 - assert project_summary.other_alerts_count == 0 # Validate that the notifications was sent on the failed job expected_notifications = ["notification failure"] @@ -146,160 +147,9 @@ def test_job_failure_alert(self): assert len(entities) == 0 assert token is None - @staticmethod - def _generate_typical_event( - endpoint_id: str, - result_name: str, - endpoint_name: str, - ) -> dict[str, typing.Any]: - return { - mm_constants.WriterEvent.ENDPOINT_ID: endpoint_id, - mm_constants.WriterEvent.ENDPOINT_NAME: endpoint_name, - mm_constants.WriterEvent.APPLICATION_NAME: mm_constants.HistogramDataDriftApplicationConstants.NAME, - mm_constants.WriterEvent.START_INFER_TIME: "2023-09-11T12:00:00", - mm_constants.WriterEvent.END_INFER_TIME: "2023-09-11T12:01:00", - mm_constants.WriterEvent.EVENT_KIND: "result", - mm_constants.WriterEvent.DATA: json.dumps( - { - mm_constants.ResultData.RESULT_NAME: result_name, - mm_constants.ResultData.RESULT_KIND: mm_constants.ResultKindApp.model_performance.value, - mm_constants.ResultData.RESULT_VALUE: 0.1, - 
mm_constants.ResultData.RESULT_STATUS: mm_constants.ResultStatusApp.detected.value, - mm_constants.ResultData.RESULT_EXTRA_DATA: json.dumps( - {"threshold": 0.3} - ), - } - ), - } - - @staticmethod - def _generate_anomaly_events( - endpoint_id: str, - result_name: str, - endpoint_name: str, - ) -> list[dict[str, typing.Any]]: - data_drift_example = { - mm_constants.WriterEvent.ENDPOINT_ID: endpoint_id, - mm_constants.WriterEvent.ENDPOINT_NAME: endpoint_name, - mm_constants.WriterEvent.APPLICATION_NAME: mm_constants.HistogramDataDriftApplicationConstants.NAME, - mm_constants.WriterEvent.START_INFER_TIME: "2023-09-11T12:00:00", - mm_constants.WriterEvent.END_INFER_TIME: "2023-09-11T12:01:00", - mm_constants.WriterEvent.EVENT_KIND: "result", - mm_constants.WriterEvent.DATA: json.dumps( - { - mm_constants.ResultData.RESULT_NAME: result_name, - mm_constants.ResultData.RESULT_KIND: mm_constants.ResultKindApp.data_drift.value, - mm_constants.ResultData.RESULT_VALUE: 0.5, - mm_constants.ResultData.RESULT_STATUS: mm_constants.ResultStatusApp.detected.value, - mm_constants.ResultData.RESULT_EXTRA_DATA: json.dumps( - {"threshold": 0.3} - ), - } - ), - } - - concept_drift_example = { - mm_constants.WriterEvent.ENDPOINT_ID: endpoint_id, - mm_constants.WriterEvent.ENDPOINT_NAME: endpoint_name, - mm_constants.WriterEvent.APPLICATION_NAME: mm_constants.HistogramDataDriftApplicationConstants.NAME, - mm_constants.WriterEvent.START_INFER_TIME: "2023-09-11T12:00:00", - mm_constants.WriterEvent.END_INFER_TIME: "2023-09-11T12:01:00", - mm_constants.WriterEvent.EVENT_KIND: "result", - mm_constants.WriterEvent.DATA: json.dumps( - { - mm_constants.ResultData.RESULT_NAME: result_name, - mm_constants.ResultData.RESULT_KIND: mm_constants.ResultKindApp.concept_drift.value, - mm_constants.ResultData.RESULT_VALUE: 0.9, - mm_constants.ResultData.RESULT_STATUS: mm_constants.ResultStatusApp.potential_detection.value, - mm_constants.ResultData.RESULT_EXTRA_DATA: json.dumps( - {"threshold": 0.7} - ), 
- } - ), - } - - anomaly_example = { - mm_constants.WriterEvent.ENDPOINT_ID: endpoint_id, - mm_constants.WriterEvent.ENDPOINT_NAME: endpoint_name, - mm_constants.WriterEvent.APPLICATION_NAME: mm_constants.HistogramDataDriftApplicationConstants.NAME, - mm_constants.WriterEvent.START_INFER_TIME: "2023-09-11T12:00:00", - mm_constants.WriterEvent.END_INFER_TIME: "2023-09-11T12:01:00", - mm_constants.WriterEvent.EVENT_KIND: "result", - mm_constants.WriterEvent.DATA: json.dumps( - { - mm_constants.ResultData.RESULT_NAME: result_name, - mm_constants.ResultData.RESULT_KIND: mm_constants.ResultKindApp.mm_app_anomaly.value, - mm_constants.ResultData.RESULT_VALUE: 0.9, - mm_constants.ResultData.RESULT_STATUS: mm_constants.ResultStatusApp.detected.value, - mm_constants.ResultData.RESULT_EXTRA_DATA: json.dumps( - {"threshold": 0.4} - ), - } - ), - } - - system_performance_example = { - mm_constants.WriterEvent.ENDPOINT_ID: endpoint_id, - mm_constants.WriterEvent.ENDPOINT_NAME: endpoint_name, - mm_constants.WriterEvent.APPLICATION_NAME: mm_constants.HistogramDataDriftApplicationConstants.NAME, - mm_constants.WriterEvent.START_INFER_TIME: "2023-09-11T12:00:00", - mm_constants.WriterEvent.END_INFER_TIME: "2023-09-11T12:01:00", - mm_constants.WriterEvent.EVENT_KIND: "result", - mm_constants.WriterEvent.DATA: json.dumps( - { - mm_constants.ResultData.RESULT_NAME: result_name, - mm_constants.ResultData.RESULT_KIND: mm_constants.ResultKindApp.system_performance.value, - mm_constants.ResultData.RESULT_VALUE: 0.9, - mm_constants.ResultData.RESULT_STATUS: mm_constants.ResultStatusApp.detected.value, - mm_constants.ResultData.RESULT_EXTRA_DATA: json.dumps( - {"threshold": 0.4} - ), - } - ), - } - - return [ - data_drift_example, - concept_drift_example, - anomaly_example, - system_performance_example, - ] - - def _generate_alerts( - self, nuclio_function_url: str, model_endpoint: ModelEndpoint - ) -> list[str]: - """Generate alerts for the different result kind and return data from the 
expected notifications.""" - expected_notifications = [] - alerts_kind_to_test = [ - alert_objects.EventKind.DATA_DRIFT_DETECTED, - alert_objects.EventKind.CONCEPT_DRIFT_SUSPECTED, - alert_objects.EventKind.MM_APP_ANOMALY_DETECTED, - alert_objects.EventKind.SYSTEM_PERFORMANCE_DETECTED, - ] - # Create alert configurations for each alert_kind individually. - # This ensures different notifications are raised, which is why we don't send all alert_kinds as events at once. - - for alert_kind in alerts_kind_to_test: - alert_name = mlrun.utils.helpers.normalize_name( - f"drift-webhook-{alert_kind}" - ) - alert_summary = "Model is drifting" - self._create_alert_config( - name=alert_name, - summary=alert_summary, - model_endpoint=model_endpoint, - events=[alert_kind], - notifications=self._generate_drift_notifications( - nuclio_function_url, alert_kind.value - ), - ) - expected_notifications.extend( - [ - f"first drift of {alert_kind.value}", - f"second drift of {alert_kind.value}", - ] - ) - return expected_notifications + mlrun.get_run_db().delete_function( + name=function_name, project=self.project.name + ) @pytest.mark.model_monitoring def test_drift_detection_alert(self): @@ -365,14 +215,14 @@ def test_drift_detection_alert(self): ) # wait for the periodic project summaries calculation to start - time.sleep(20) - # validate the alert activations counters - project_summary = mlrun.get_run_db().get_project_summary( - project=self.project_name + mlrun.utils.retry_until_successful( + 3, + 10 * 6, + self._logger, + True, + self._validate_project_alerts_summary, + expected_endpoint_alerts_count=4, ) - assert project_summary.job_alerts_count == 0 - assert project_summary.endpoint_alerts_count == 4 - assert project_summary.other_alerts_count == 0 def test_job_failure_alert_sliding_window(self): """ @@ -386,9 +236,9 @@ def test_job_failure_alert_sliding_window(self): another job failure to confirm that the alert does not trigger prematurely. 
Finally, a third failure within the adjusted window is used to confirm that the alert triggers as expected. """ - + function_name = "test-func-failure-alert-sliding-window" self.project.set_function( - name="test-func", + name=function_name, func=str(self.assets_path / "function.py"), handler="handler", image="mlrun/mlrun" if self.image is None else self.image, @@ -404,7 +254,7 @@ def test_job_failure_alert_sliding_window(self): alert_name = "failure-webhook" alert_summary = "Job failed" alert_criteria = alert_objects.AlertCriteria(period="2m", count=2) - run_id = "test-func-handler" + run_id = f"{function_name}-handler" notifications = self._generate_failure_notifications(nuclio_function_url) self._create_custom_alert_config( @@ -419,14 +269,17 @@ def test_job_failure_alert_sliding_window(self): # this is the first failure with pytest.raises(Exception): - self.project.run_function("test-func") + self.project.run_function(function_name) # Wait for more than two minutes to simulate a delay that is slightly longer than the alert period time.sleep(125) # this is the second failure with pytest.raises(Exception): - self.project.run_function("test-func") + self.project.run_function(function_name) + + # wait since there is a might be a delay + time.sleep(mlconf.alerts.events_generation_interval) # validate that no notifications were sent yet, as the two failures did not occur within the same period expected_notifications = [] @@ -437,12 +290,23 @@ def test_job_failure_alert_sliding_window(self): # this failure should fall within the adjusted sliding window when combined with the second failure # should trigger the alert with pytest.raises(Exception): - self.project.run_function("test-func") + self.project.run_function(function_name) # validate that the alert was triggered and the notification was sent expected_notifications = ["notification failure"] - self._validate_notifications_on_nuclio( - nuclio_function_url, expected_notifications + + # wait since there is a might 
be a delay + mlrun.utils.retry_until_successful( + 3, + 10 * 3, + self._logger, + True, + self._validate_notifications_on_nuclio, + nuclio_function_url, + expected_notifications, + ) + mlrun.get_run_db().delete_function( + name=function_name, project=self.project.name ) @staticmethod @@ -535,6 +399,19 @@ def _create_alert_config( ) mlrun.get_run_db().store_alert_config(name, alert_data[0]) + def _validate_project_alerts_summary( + self, + expected_job_alerts_count=0, + expected_endpoint_alerts_count=0, + expected_other_alerts_count=0, + ): + project_summary = mlrun.get_run_db().get_project_summary( + project=self.project_name + ) + assert project_summary.job_alerts_count == expected_job_alerts_count + assert project_summary.endpoint_alerts_count == expected_endpoint_alerts_count + assert project_summary.other_alerts_count == expected_other_alerts_count + @staticmethod def _validate_notifications_on_nuclio(nuclio_function_url, expected_notifications): sent_notifications = list( @@ -548,3 +425,158 @@ def _validate_notifications_on_nuclio(nuclio_function_url, expected_notification ) == {} ) + + @staticmethod + def _generate_typical_event( + endpoint_id: str, + result_name: str, + endpoint_name: str, + ) -> dict[str, typing.Any]: + return { + mm_constants.WriterEvent.ENDPOINT_ID: endpoint_id, + mm_constants.WriterEvent.ENDPOINT_NAME: endpoint_name, + mm_constants.WriterEvent.APPLICATION_NAME: mm_constants.HistogramDataDriftApplicationConstants.NAME, + mm_constants.WriterEvent.START_INFER_TIME: "2023-09-11T12:00:00", + mm_constants.WriterEvent.END_INFER_TIME: "2023-09-11T12:01:00", + mm_constants.WriterEvent.EVENT_KIND: "result", + mm_constants.WriterEvent.DATA: json.dumps( + { + mm_constants.ResultData.RESULT_NAME: result_name, + mm_constants.ResultData.RESULT_KIND: mm_constants.ResultKindApp.model_performance.value, + mm_constants.ResultData.RESULT_VALUE: 0.1, + mm_constants.ResultData.RESULT_STATUS: mm_constants.ResultStatusApp.detected.value, + 
mm_constants.ResultData.RESULT_EXTRA_DATA: json.dumps( + {"threshold": 0.3} + ), + } + ), + } + + @staticmethod + def _generate_anomaly_events( + endpoint_id: str, + result_name: str, + endpoint_name: str, + ) -> list[dict[str, typing.Any]]: + data_drift_example = { + mm_constants.WriterEvent.ENDPOINT_ID: endpoint_id, + mm_constants.WriterEvent.ENDPOINT_NAME: endpoint_name, + mm_constants.WriterEvent.APPLICATION_NAME: mm_constants.HistogramDataDriftApplicationConstants.NAME, + mm_constants.WriterEvent.START_INFER_TIME: "2023-09-11T12:00:00", + mm_constants.WriterEvent.END_INFER_TIME: "2023-09-11T12:01:00", + mm_constants.WriterEvent.EVENT_KIND: "result", + mm_constants.WriterEvent.DATA: json.dumps( + { + mm_constants.ResultData.RESULT_NAME: result_name, + mm_constants.ResultData.RESULT_KIND: mm_constants.ResultKindApp.data_drift.value, + mm_constants.ResultData.RESULT_VALUE: 0.5, + mm_constants.ResultData.RESULT_STATUS: mm_constants.ResultStatusApp.detected.value, + mm_constants.ResultData.RESULT_EXTRA_DATA: json.dumps( + {"threshold": 0.3} + ), + } + ), + } + + concept_drift_example = { + mm_constants.WriterEvent.ENDPOINT_ID: endpoint_id, + mm_constants.WriterEvent.ENDPOINT_NAME: endpoint_name, + mm_constants.WriterEvent.APPLICATION_NAME: mm_constants.HistogramDataDriftApplicationConstants.NAME, + mm_constants.WriterEvent.START_INFER_TIME: "2023-09-11T12:00:00", + mm_constants.WriterEvent.END_INFER_TIME: "2023-09-11T12:01:00", + mm_constants.WriterEvent.EVENT_KIND: "result", + mm_constants.WriterEvent.DATA: json.dumps( + { + mm_constants.ResultData.RESULT_NAME: result_name, + mm_constants.ResultData.RESULT_KIND: mm_constants.ResultKindApp.concept_drift.value, + mm_constants.ResultData.RESULT_VALUE: 0.9, + mm_constants.ResultData.RESULT_STATUS: mm_constants.ResultStatusApp.potential_detection.value, + mm_constants.ResultData.RESULT_EXTRA_DATA: json.dumps( + {"threshold": 0.7} + ), + } + ), + } + + anomaly_example = { + mm_constants.WriterEvent.ENDPOINT_ID: 
endpoint_id, + mm_constants.WriterEvent.ENDPOINT_NAME: endpoint_name, + mm_constants.WriterEvent.APPLICATION_NAME: mm_constants.HistogramDataDriftApplicationConstants.NAME, + mm_constants.WriterEvent.START_INFER_TIME: "2023-09-11T12:00:00", + mm_constants.WriterEvent.END_INFER_TIME: "2023-09-11T12:01:00", + mm_constants.WriterEvent.EVENT_KIND: "result", + mm_constants.WriterEvent.DATA: json.dumps( + { + mm_constants.ResultData.RESULT_NAME: result_name, + mm_constants.ResultData.RESULT_KIND: mm_constants.ResultKindApp.mm_app_anomaly.value, + mm_constants.ResultData.RESULT_VALUE: 0.9, + mm_constants.ResultData.RESULT_STATUS: mm_constants.ResultStatusApp.detected.value, + mm_constants.ResultData.RESULT_EXTRA_DATA: json.dumps( + {"threshold": 0.4} + ), + } + ), + } + + system_performance_example = { + mm_constants.WriterEvent.ENDPOINT_ID: endpoint_id, + mm_constants.WriterEvent.ENDPOINT_NAME: endpoint_name, + mm_constants.WriterEvent.APPLICATION_NAME: mm_constants.HistogramDataDriftApplicationConstants.NAME, + mm_constants.WriterEvent.START_INFER_TIME: "2023-09-11T12:00:00", + mm_constants.WriterEvent.END_INFER_TIME: "2023-09-11T12:01:00", + mm_constants.WriterEvent.EVENT_KIND: "result", + mm_constants.WriterEvent.DATA: json.dumps( + { + mm_constants.ResultData.RESULT_NAME: result_name, + mm_constants.ResultData.RESULT_KIND: mm_constants.ResultKindApp.system_performance.value, + mm_constants.ResultData.RESULT_VALUE: 0.9, + mm_constants.ResultData.RESULT_STATUS: mm_constants.ResultStatusApp.detected.value, + mm_constants.ResultData.RESULT_EXTRA_DATA: json.dumps( + {"threshold": 0.4} + ), + } + ), + } + + return [ + data_drift_example, + concept_drift_example, + anomaly_example, + system_performance_example, + ] + + def _generate_alerts( + self, nuclio_function_url: str, model_endpoint: ModelEndpoint + ) -> list[str]: + """Generate alerts for the different result kind and return data from the expected notifications.""" + expected_notifications = [] + alerts_kind_to_test 
= [ + alert_objects.EventKind.DATA_DRIFT_DETECTED, + alert_objects.EventKind.CONCEPT_DRIFT_SUSPECTED, + alert_objects.EventKind.MM_APP_ANOMALY_DETECTED, + alert_objects.EventKind.SYSTEM_PERFORMANCE_DETECTED, + ] + # Create alert configurations for each alert_kind individually. + # This ensures different notifications are raised, which is why we don't send all alert_kinds as events at once. + + for alert_kind in alerts_kind_to_test: + alert_name = mlrun.utils.helpers.normalize_name( + f"drift-webhook-{alert_kind}" + ) + alert_summary = "Model is drifting" + self._create_alert_config( + name=alert_name, + summary=alert_summary, + model_endpoint=model_endpoint, + events=[alert_kind], + notifications=self._generate_drift_notifications( + nuclio_function_url, alert_kind.value + ), + ) + expected_notifications.extend( + [ + f"first drift of {alert_kind.value}", + f"second drift of {alert_kind.value}", + ] + ) + return expected_notifications From 6395cea7f3ddf4f72794b6727be6926b3f24107d Mon Sep 17 00:00:00 2001 From: moranbental <107995850+moranbental@users.noreply.github.com> Date: Sun, 5 Jan 2025 22:39:32 +0200 Subject: [PATCH 05/15] [Config] Ensure Correct Ordering of Environment Variables (#7065) --- mlrun/__init__.py | 40 +++++++++++++++++++++++++++++++++++++--- mlrun/config.py | 7 +++++++ tests/assets/envfile | 2 ++ tests/system/base.py | 5 +---- tests/test_config.py | 7 ++++++- 5 files changed, 53 insertions(+), 8 deletions(-) diff --git a/mlrun/__init__.py b/mlrun/__init__.py index 1497597a15b..63c360d357c 100644 --- a/mlrun/__init__.py +++ b/mlrun/__init__.py @@ -213,7 +213,41 @@ def set_env_from_file(env_file: str, return_dict: bool = False) -> Optional[dict env_vars = dotenv.dotenv_values(env_file) if None in env_vars.values(): raise MLRunInvalidArgumentError("env file lines must be in the form key=value") - for key, value in env_vars.items(): - environ[key] = value # Load to local environ + + ordered_env_vars = order_env_vars(env_vars) + for key, value in 
ordered_env_vars.items(): + environ[key] = value + mlconf.reload() # reload mlrun configuration - return env_vars if return_dict else None + return ordered_env_vars if return_dict else None + + +def order_env_vars(env_vars: dict[str, str]) -> dict[str, str]: + """ + Order and process environment variables by first handling specific ordered keys, + then processing the remaining keys in the given dictionary. + + The function ensures that environment variables defined in the `ordered_keys` list + are added to the result dictionary first. Any other environment variables from + `env_vars` are then added in the order they appear in the input dictionary. + + :param env_vars: A dictionary where each key is the name of an environment variable (str), + and each value is the corresponding environment variable value (str). + :return: A dictionary with the processed environment variables, ordered with the specific + keys first, followed by the rest in their original order. + """ + ordered_keys = mlconf.get_ordered_keys() + + ordered_env_vars: dict[str, str] = {} + + # First, add the ordered keys to the dictionary + for key in ordered_keys: + if key in env_vars: + ordered_env_vars[key] = env_vars[key] + + # Then, add the remaining keys (those not in ordered_keys) + for key, value in env_vars.items(): + if key not in ordered_keys: + ordered_env_vars[key] = value + + return ordered_env_vars diff --git a/mlrun/config.py b/mlrun/config.py index 322260ce488..fcdd3661114 100644 --- a/mlrun/config.py +++ b/mlrun/config.py @@ -1365,6 +1365,13 @@ def is_explicit_ack_enabled(self) -> bool: >= semver.VersionInfo.parse("1.12.10") ) + @staticmethod + def get_ordered_keys(): + # Define the keys to process first + return [ + "MLRUN_HTTPDB__HTTP__VERIFY" # Ensure this key is processed first for proper connection setup + ] + # Global configuration config = Config.from_dict(default_config) diff --git a/tests/assets/envfile b/tests/assets/envfile index 2cec3adea77..d49105ec93d 100644 --- 
a/tests/assets/envfile +++ b/tests/assets/envfile @@ -1,5 +1,7 @@ MLRUN_KFP_TTL=12345 ENV_ARG1=123 +MLRUN_HTTPDB__HTTP__VERIFY=false + # comment ENV_ARG2=abc diff --git a/tests/system/base.py b/tests/system/base.py index 78ba5e481c5..57195251ac4 100644 --- a/tests/system/base.py +++ b/tests/system/base.py @@ -224,10 +224,7 @@ def _setup_env(cls, env: dict): cls._logger.debug("Setting up test environment") cls._test_env.update(env) - # Define the keys to process first - ordered_keys = [ - "MLRUN_HTTPDB__HTTP__VERIFY" # Ensure this key is processed first for proper connection setup - ] + ordered_keys = mlconf.get_ordered_keys() # Process ordered keys for key in ordered_keys & env.keys(): diff --git a/tests/test_config.py b/tests/test_config.py index de84e32e6fa..a01e97e7225 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -587,7 +587,12 @@ def test_set_environment_cred(): def test_env_from_file(): env_path = str(assets_path / "envfile") env_dict = mlrun.set_env_from_file(env_path, return_dict=True) - assert env_dict == {"ENV_ARG1": "123", "ENV_ARG2": "abc", "MLRUN_KFP_TTL": "12345"} + assert env_dict == { + "ENV_ARG1": "123", + "ENV_ARG2": "abc", + "MLRUN_HTTPDB__HTTP__VERIFY": "false", + "MLRUN_KFP_TTL": "12345", + } assert mlrun.mlconf.kfp_ttl == 12345 for key, value in env_dict.items(): assert os.environ[key] == value From 0034f5a33fc030f929fbd399e2bdf639d29c5e1e Mon Sep 17 00:00:00 2001 From: daniels290813 <78727943+daniels290813@users.noreply.github.com> Date: Mon, 6 Jan 2025 15:45:28 +0200 Subject: [PATCH 06/15] [Jupyter] Correct docs image in dockerfile (#7054) --- dockerfiles/jupyter/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dockerfiles/jupyter/Dockerfile b/dockerfiles/jupyter/Dockerfile index 3667ea2352b..fd741b4653e 100644 --- a/dockerfiles/jupyter/Dockerfile +++ b/dockerfiles/jupyter/Dockerfile @@ -49,7 +49,7 @@ RUN python -m pip install --upgrade pip~=${MLRUN_PIP_VERSION} WORKDIR $HOME COPY 
--chown=$NB_UID:$NB_GID ./docs/tutorials $HOME/tutorials -COPY --chown=$NB_UID:$NB_GID ./docs/_static/images/MLRun-logo.png $HOME/_static/images +COPY --chown=$NB_UID:$NB_GID ./docs/_static/images/MLRun-logo.png $HOME/_static/images/MLRun-logo.png COPY --chown=$NB_UID:$NB_GID ./dockerfiles/jupyter/README.ipynb $HOME COPY --chown=$NB_UID:$NB_GID ./dockerfiles/jupyter/mlrun.env $HOME COPY --chown=$NB_UID:$NB_GID ./dockerfiles/jupyter/mlce-start.sh /usr/local/bin/mlce-start.sh From 121c355835f42b9336ed298225240f5b30c78ffb Mon Sep 17 00:00:00 2001 From: Liran BG Date: Mon, 6 Jan 2025 16:13:02 +0200 Subject: [PATCH 07/15] [Microservices] Move pagination to framework (#7066) --- .../py/framework/routers/alert_activations.py | 4 +- server/py/framework/service/__init__.py | 40 +++++++-- .../api => framework}/utils/pagination.py | 25 +++--- .../utils}/pagination_cache.py | 0 server/py/services/alerts/main.py | 10 ++- .../api/api/endpoints/artifacts_v2.py | 4 +- .../services/api/api/endpoints/functions.py | 4 +- server/py/services/api/api/endpoints/runs.py | 4 +- server/py/services/api/crud/__init__.py | 1 - server/py/services/api/initial_data.py | 4 +- server/py/services/api/main.py | 13 ++- .../tests/unit/crud/test_pagination_cache.py | 73 +++++++++++---- .../unit/utils/clients/test_discovery.py | 9 +- .../api/tests/unit/utils/test_pagination.py | 89 ++++++++++++------- 14 files changed, 189 insertions(+), 91 deletions(-) rename server/py/{services/api => framework}/utils/pagination.py (95%) rename server/py/{services/api/crud => framework/utils}/pagination_cache.py (100%) diff --git a/server/py/framework/routers/alert_activations.py b/server/py/framework/routers/alert_activations.py index 3b61adda4f6..038cb50fab8 100644 --- a/server/py/framework/routers/alert_activations.py +++ b/server/py/framework/routers/alert_activations.py @@ -22,9 +22,7 @@ import mlrun.common.schemas -import framework.utils.auth.verifier -import framework.utils.clients.chief -import 
framework.utils.singletons.project_member +import framework.service from framework.api import deps router = APIRouter() diff --git a/server/py/framework/service/__init__.py b/server/py/framework/service/__init__.py index ad97961ad4a..8fd35a2652a 100644 --- a/server/py/framework/service/__init__.py +++ b/server/py/framework/service/__init__.py @@ -35,6 +35,7 @@ import framework.middlewares import framework.utils.clients.chief import framework.utils.clients.messaging +import framework.utils.pagination import framework.utils.periodic from framework.utils.singletons.db import initialize_db @@ -45,10 +46,11 @@ def __init__(self): self.service_prefix = f"/{self.service_name}" self.base_versioned_service_prefix = f"{self.service_prefix}/v1" self.v2_service_prefix = f"{self.service_prefix}/v2" - self.app: fastapi.FastAPI = None + self.app: typing.Optional[fastapi.FastAPI] = None self._logger = mlrun.utils.logger.get_child(self.service_name) self._mounted_services: list[Service] = [] self._messaging_client = framework.utils.clients.messaging.Client() + self._paginated_methods: list[tuple[typing.Callable, str]] = [] def initialize(self, mounts: typing.Optional[list] = None): self._logger.info("Initializing service", service_name=self.service_name) @@ -57,6 +59,7 @@ def initialize(self, mounts: typing.Optional[list] = None): self._mount_services(mounts) self._add_middlewares() self._add_exception_handlers() + self._ensure_paginated_methods() async def move_service_to_online(self): self._logger.info("Moving service to online", service_name=self.service_name) @@ -95,6 +98,21 @@ async def handle_request( **kwargs, ) + def is_forwarded_request(self, request: fastapi.Request) -> bool: + """ + Determines whether the incoming request should be forwarded to another service. + + :param request: The incoming FastAPI request. + :return: `True` if the request should be forwarded to another service, otherwise `False`. 
+ """ + + # let non-api requests pass through + if request.url.path.startswith( + self.service_prefix + ) and not request.url.path.startswith("/api/"): + return False + return self._messaging_client.is_forwarded_request(request) + @abstractmethod async def _move_service_to_online(self): pass @@ -367,14 +385,20 @@ async def _align_worker_state_with_chief_state( self._synchronize_with_chief_clusterization_spec.__name__ ) - def is_forwarded_request(self, request: fastapi.Request) -> bool: - """ - Determines whether the incoming request should be forwarded to another service. + def _ensure_paginated_methods(self): + for cls, method in self._resolve_paginated_methods(): + framework.utils.pagination.PaginatedMethods.add_method( + getattr(cls(), method) + ) - :param request: The incoming FastAPI request. - :return: `True` if the request should be forwarded to another service, otherwise `False`. - """ - return self._messaging_client.is_forwarded_request(request) + def _resolve_paginated_methods( + self, + ) -> typing.Generator[tuple[typing.Callable, str], None, None]: + for cls, method in self._paginated_methods: + yield cls, method + for mounted_service in self._mounted_services: + for cls, method in mounted_service._paginated_methods: + yield cls, method class Daemon(ABC): diff --git a/server/py/services/api/utils/pagination.py b/server/py/framework/utils/pagination.py similarity index 95% rename from server/py/services/api/utils/pagination.py rename to server/py/framework/utils/pagination.py index 562533e3a12..b1963f07b23 100644 --- a/server/py/services/api/utils/pagination.py +++ b/server/py/framework/utils/pagination.py @@ -26,8 +26,7 @@ from mlrun.utils import logger import framework.utils.asyncio -import services.alerts.crud -import services.api.crud +import framework.utils.pagination_cache def _generate_pydantic_schema_from_method_signature( @@ -59,21 +58,17 @@ class Config: ) -class PaginatedMethods: - _methods: list[typing.Callable] = [ - # TODO: add methods when 
they implement pagination - services.api.crud.Runs().list_runs, - services.api.crud.Functions().list_functions, - services.api.crud.Artifacts().list_artifacts, - services.alerts.crud.AlertActivation().list_alert_activations, - ] - _method_map = { - method.__name__: { +class PaginatedMethods(metaclass=mlrun.utils.singleton.Singleton): + _methods: list[typing.Callable] = [] + _method_map = {} + + @classmethod + def add_method(cls, method: typing.Callable): + cls._methods.append(method) + cls._method_map[method.__name__] = { "method": method, "schema": _generate_pydantic_schema_from_method_signature(method), } - for method in _methods - } @classmethod def method_is_supported(cls, method: typing.Union[str, typing.Callable]) -> bool: @@ -92,7 +87,7 @@ def get_method_schema(cls, method_name: str) -> pydantic.v1.BaseModel: class Paginator(metaclass=mlrun.utils.singleton.Singleton): def __init__(self): self._logger = logger.get_child("paginator") - self._pagination_cache = services.api.crud.PaginationCache() + self._pagination_cache = framework.utils.pagination_cache.PaginationCache() async def paginate_permission_filtered_request( self, diff --git a/server/py/services/api/crud/pagination_cache.py b/server/py/framework/utils/pagination_cache.py similarity index 100% rename from server/py/services/api/crud/pagination_cache.py rename to server/py/framework/utils/pagination_cache.py diff --git a/server/py/services/alerts/main.py b/server/py/services/alerts/main.py index 1aaf9b05add..590f0476ec4 100644 --- a/server/py/services/alerts/main.py +++ b/server/py/services/alerts/main.py @@ -34,6 +34,7 @@ import framework.service import framework.utils.auth.verifier import framework.utils.clients.chief +import framework.utils.pagination import framework.utils.periodic import framework.utils.singletons.db import framework.utils.singletons.project_member @@ -41,7 +42,6 @@ import services.alerts.crud import services.alerts.initial_data import services.api.crud -import 
services.api.utils.pagination from framework.db.session import close_session, create_session from framework.routers import ( alert_activations, @@ -58,6 +58,12 @@ class Service(framework.service.Service): + def __init__(self): + super().__init__() + self._paginated_methods = [ + (services.alerts.crud.AlertActivation, "list_alert_activations"), + ] + async def store_alert( self, request: fastapi.Request, @@ -409,7 +415,7 @@ async def list_alert_activations( project=project, ) ) - paginator = services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() async def _filter_alert_activations_by_permissions(_alert_activations): return await framework.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( diff --git a/server/py/services/api/api/endpoints/artifacts_v2.py b/server/py/services/api/api/endpoints/artifacts_v2.py index 8230a91b336..6501af193e4 100644 --- a/server/py/services/api/api/endpoints/artifacts_v2.py +++ b/server/py/services/api/api/endpoints/artifacts_v2.py @@ -25,9 +25,9 @@ from mlrun.utils import logger import framework.utils.auth.verifier +import framework.utils.pagination import framework.utils.singletons.project_member import services.api.crud -import services.api.utils.pagination from framework.api import deps from framework.api.utils import artifact_project_and_resource_name_extractor @@ -197,7 +197,7 @@ async def list_artifacts( "'page/page_size' and 'limit' are conflicting, only one can be specified." 
) - paginator = services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() async def _filter_artifacts(_artifacts): return await framework.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( diff --git a/server/py/services/api/api/endpoints/functions.py b/server/py/services/api/api/endpoints/functions.py index 7346b41b523..e63c3f70332 100644 --- a/server/py/services/api/api/endpoints/functions.py +++ b/server/py/services/api/api/endpoints/functions.py @@ -48,13 +48,13 @@ import framework.utils.background_tasks import framework.utils.clients.chief import framework.utils.helpers +import framework.utils.pagination import framework.utils.singletons.k8s import framework.utils.singletons.project_member import services.api.crud.model_monitoring.deployment import services.api.crud.runtimes.nuclio.function import services.api.launcher import services.api.utils.functions -import services.api.utils.pagination from framework.api import deps from services.api.api.endpoints.nuclio import ( _get_api_gateways_urls_for_function, @@ -232,7 +232,7 @@ async def list_functions( ) ) - paginator = services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() async def _filter_functions_by_permissions(_functions): return await framework.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( diff --git a/server/py/services/api/api/endpoints/runs.py b/server/py/services/api/api/endpoints/runs.py index ad2d56bf692..37d9cdb8c66 100644 --- a/server/py/services/api/api/endpoints/runs.py +++ b/server/py/services/api/api/endpoints/runs.py @@ -29,10 +29,10 @@ import framework.utils.auth.verifier import framework.utils.background_tasks import framework.utils.notifications +import framework.utils.pagination import framework.utils.singletons.db as db_singleton import framework.utils.singletons.project_member import services.api.crud -import services.api.utils.pagination from framework.api 
import deps from framework.api.utils import log_and_raise @@ -242,7 +242,7 @@ async def list_runs( ) ) - paginator = services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() async def _filter_runs(_runs): return await framework.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( diff --git a/server/py/services/api/crud/__init__.py b/server/py/services/api/crud/__init__.py index 7047adcbb73..334d881e594 100644 --- a/server/py/services/api/crud/__init__.py +++ b/server/py/services/api/crud/__init__.py @@ -23,7 +23,6 @@ from .logs import Logs from .model_monitoring import ModelEndpoints from .notifications import Notifications -from .pagination_cache import PaginationCache from .pipelines import Pipelines from .projects import Projects from .runs import Runs diff --git a/server/py/services/api/initial_data.py b/server/py/services/api/initial_data.py index 97fe1c7edda..bfda3a98904 100644 --- a/server/py/services/api/initial_data.py +++ b/server/py/services/api/initial_data.py @@ -43,7 +43,7 @@ import framework.db.sqldb.db import framework.db.sqldb.models import framework.utils.db.mysql -import services.api.crud.pagination_cache +import framework.utils.pagination_cache import services.api.utils.db.alembic import services.api.utils.db.backup import services.api.utils.scheduler @@ -133,7 +133,7 @@ def init_data( if not from_scratch: # Cleanup pagination cache on api startup session = create_session() - services.api.crud.pagination_cache.PaginationCache().cleanup_pagination_cache( + framework.utils.pagination_cache.PaginationCache().cleanup_pagination_cache( session ) session.commit() diff --git a/server/py/services/api/main.py b/server/py/services/api/main.py index 347a176e78c..aa831b97cae 100644 --- a/server/py/services/api/main.py +++ b/server/py/services/api/main.py @@ -42,6 +42,7 @@ import framework.utils.clients.log_collector import framework.utils.clients.messaging import 
framework.utils.notifications.notification_pusher +import framework.utils.pagination_cache import framework.utils.time_window_tracker import services.api.crud import services.api.initial_data @@ -73,6 +74,14 @@ class Service(framework.service.Service): + def __init__(self): + super().__init__() + self._paginated_methods = [ + (services.api.crud.Runs, "list_runs"), + (services.api.crud.Functions, "list_functions"), + (services.api.crud.Artifacts, "list_artifacts"), + ] + async def _move_service_to_online(self): # scheduler is needed on both workers and chief # on workers - it allows to us to list/get scheduler(s) @@ -545,10 +554,10 @@ def _start_periodic_pagination_cache_monitoring(self): ) run_function_periodically( interval, - services.api.crud.pagination_cache.PaginationCache().monitor_pagination_cache.__name__, + framework.utils.pagination_cache.PaginationCache().monitor_pagination_cache.__name__, False, framework.db.session.run_function_with_new_db_session, - services.api.crud.pagination_cache.PaginationCache().monitor_pagination_cache, + framework.utils.pagination_cache.PaginationCache().monitor_pagination_cache, ) def _start_periodic_project_summaries_calculation(self): diff --git a/server/py/services/api/tests/unit/crud/test_pagination_cache.py b/server/py/services/api/tests/unit/crud/test_pagination_cache.py index 5e19fc1737a..8e479f70adb 100644 --- a/server/py/services/api/tests/unit/crud/test_pagination_cache.py +++ b/server/py/services/api/tests/unit/crud/test_pagination_cache.py @@ -22,6 +22,7 @@ from mlrun import mlconf from mlrun.utils import logger +import framework.utils.pagination_cache import services.api.crud from framework.db.sqldb.db import MAX_INT_32 @@ -41,12 +42,17 @@ def test_pagination_cache_monitor_ttl(db: sqlalchemy.orm.Session): logger.debug("Creating paginated cache records") for i in range(3): - services.api.crud.PaginationCache().store_pagination_cache_record( + 
framework.utils.pagination_cache.PaginationCache().store_pagination_cache_record( db, f"user{i}", method, page, page_size, kwargs ) assert ( - len(services.api.crud.PaginationCache().list_pagination_cache_records(db)) == 3 + len( + framework.utils.pagination_cache.PaginationCache().list_pagination_cache_records( + db + ) + ) + == 3 ) logger.debug( @@ -55,19 +61,26 @@ def test_pagination_cache_monitor_ttl(db: sqlalchemy.orm.Session): time.sleep(ttl + 2) logger.debug("Creating new paginated cache record that won't be expired") - new_key = services.api.crud.PaginationCache().store_pagination_cache_record( + new_key = framework.utils.pagination_cache.PaginationCache().store_pagination_cache_record( db, "user3", method, page, page_size, kwargs ) logger.debug("Monitoring pagination cache") - services.api.crud.PaginationCache().monitor_pagination_cache(db) + framework.utils.pagination_cache.PaginationCache().monitor_pagination_cache(db) logger.debug("Checking that old records were removed and new record still exists") assert ( - len(services.api.crud.PaginationCache().list_pagination_cache_records(db)) == 1 + len( + framework.utils.pagination_cache.PaginationCache().list_pagination_cache_records( + db + ) + ) + == 1 ) assert ( - services.api.crud.PaginationCache().get_pagination_cache_record(db, new_key) + framework.utils.pagination_cache.PaginationCache().get_pagination_cache_record( + db, new_key + ) is not None ) @@ -86,7 +99,7 @@ def test_pagination_cache_monitor_max_table_size(db: sqlalchemy.orm.Session): kwargs = {} logger.debug("Creating old paginated cache record") - old_key = services.api.crud.PaginationCache().store_pagination_cache_record( + old_key = framework.utils.pagination_cache.PaginationCache().store_pagination_cache_record( db, "user0", method, page, page_size, kwargs ) @@ -97,36 +110,48 @@ def test_pagination_cache_monitor_max_table_size(db: sqlalchemy.orm.Session): "Creating paginated cache records up to max size (including the old record)" ) for i in 
range(1, max_size): - services.api.crud.PaginationCache().store_pagination_cache_record( + framework.utils.pagination_cache.PaginationCache().store_pagination_cache_record( db, f"user{i}", method, page, page_size, kwargs ) assert ( - len(services.api.crud.PaginationCache().list_pagination_cache_records(db)) + len( + framework.utils.pagination_cache.PaginationCache().list_pagination_cache_records( + db + ) + ) == max_size ) logger.debug("Creating new paginated cache record to replace the old one") - new_key = services.api.crud.PaginationCache().store_pagination_cache_record( + new_key = framework.utils.pagination_cache.PaginationCache().store_pagination_cache_record( db, "user3", method, page, page_size, kwargs ) logger.debug("Monitoring pagination cache") - services.api.crud.PaginationCache().monitor_pagination_cache(db) + framework.utils.pagination_cache.PaginationCache().monitor_pagination_cache(db) logger.debug( "Checking that old record was removed and all other records still exist" ) assert ( - len(services.api.crud.PaginationCache().list_pagination_cache_records(db)) + len( + framework.utils.pagination_cache.PaginationCache().list_pagination_cache_records( + db + ) + ) == max_size ) assert ( - services.api.crud.PaginationCache().get_pagination_cache_record(db, new_key) + framework.utils.pagination_cache.PaginationCache().get_pagination_cache_record( + db, new_key + ) is not None ) assert ( - services.api.crud.PaginationCache().get_pagination_cache_record(db, old_key) + framework.utils.pagination_cache.PaginationCache().get_pagination_cache_record( + db, old_key + ) is None ) @@ -142,21 +167,31 @@ def test_pagination_cleanup(db: sqlalchemy.orm.Session): logger.debug("Creating paginated cache records") for i in range(3): - services.api.crud.PaginationCache().store_pagination_cache_record( + framework.utils.pagination_cache.PaginationCache().store_pagination_cache_record( db, f"user{i}", method, page, page_size, kwargs ) assert ( - 
len(services.api.crud.PaginationCache().list_pagination_cache_records(db)) == 3 + len( + framework.utils.pagination_cache.PaginationCache().list_pagination_cache_records( + db + ) + ) + == 3 ) logger.debug("Cleaning up pagination cache") - services.api.crud.PaginationCache().cleanup_pagination_cache(db) + framework.utils.pagination_cache.PaginationCache().cleanup_pagination_cache(db) db.commit() logger.debug("Checking that all records were removed") assert ( - len(services.api.crud.PaginationCache().list_pagination_cache_records(db)) == 0 + len( + framework.utils.pagination_cache.PaginationCache().list_pagination_cache_records( + db + ) + ) + == 0 ) @@ -174,6 +209,6 @@ def test_store_paginated_query_cache_record_out_of_range( kwargs = {} with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): - services.api.crud.PaginationCache().store_pagination_cache_record( + framework.utils.pagination_cache.PaginationCache().store_pagination_cache_record( db, "user_name", method, page, page_size, kwargs ) diff --git a/server/py/services/api/tests/unit/utils/clients/test_discovery.py b/server/py/services/api/tests/unit/utils/clients/test_discovery.py index 4cc5e3b8217..e22291a25b5 100644 --- a/server/py/services/api/tests/unit/utils/clients/test_discovery.py +++ b/server/py/services/api/tests/unit/utils/clients/test_discovery.py @@ -15,6 +15,7 @@ import re import fastapi.testclient +import pytest from mlrun import mlconf @@ -66,13 +67,17 @@ def test_star_notation_translation(): assert route_regex in service_instance.method_routes["get"] -def test_find_service(): - method, path = "get", "projects/test/alerts" +@pytest.mark.parametrize( + "method, path", [("get", "projects/test/alerts"), ("get", "projects/*/alerts")] +) +def test_find_service(method, path): + # requests goes to api mlconf.services.hydra.services = "*" discovery = framework.utils.clients.discovery.Client() service_instance = discovery.resolve_service_by_request(method, path) assert service_instance is None + # 
request goes to api > alerts mlconf.services.hydra.services = "" discovery.initialize() service_instance = discovery.resolve_service_by_request(method, path) diff --git a/server/py/services/api/tests/unit/utils/test_pagination.py b/server/py/services/api/tests/unit/utils/test_pagination.py index f7f00506749..5145cd0bf54 100644 --- a/server/py/services/api/tests/unit/utils/test_pagination.py +++ b/server/py/services/api/tests/unit/utils/test_pagination.py @@ -23,8 +23,8 @@ from mlrun.utils import logger import framework.db.sqldb.models -import services.api.crud -import services.api.utils.pagination +import framework.utils.pagination +import framework.utils.pagination_cache def paginated_method( @@ -52,12 +52,12 @@ def dict(self): return self._dict monkeypatch.setattr( - services.api.utils.pagination.PaginatedMethods, + framework.utils.pagination.PaginatedMethods, "_method_map", { paginated_method.__name__: { "method": paginated_method, - "schema": services.api.utils.pagination._generate_pydantic_schema_from_method_signature( + "schema": framework.utils.pagination._generate_pydantic_schema_from_method_signature( paginated_method ), } @@ -69,7 +69,7 @@ def dict(self): @pytest.fixture() def cleanup_pagination_cache_on_teardown(db: sqlalchemy.orm.Session): yield - services.api.crud.PaginationCache().cleanup_pagination_cache(db) + framework.utils.pagination_cache.PaginationCache().cleanup_pagination_cache(db) def test_paginated_method(): @@ -80,7 +80,7 @@ def test_paginated_method(): total_amount = 10 page_size = 3 since = datetime.datetime.now() - paginator = services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() offset, limit = paginator._calculate_offset_and_limit(1, page_size) items = paginated_method(None, total_amount, since, offset, limit - 1) @@ -135,7 +135,7 @@ async def test_paginate_request( page_size = 3 method_kwargs = {"total_amount": 5, "since": datetime.datetime.now()} - paginator = 
services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() logger.info("Requesting first page") response, pagination_info = await paginator.paginate_request( @@ -151,8 +151,10 @@ async def test_paginate_request( ) logger.info("Checking db cache record") - cache_record = services.api.crud.PaginationCache().get_pagination_cache_record( - db, pagination_info.page_token + cache_record = ( + framework.utils.pagination_cache.PaginationCache().get_pagination_cache_record( + db, pagination_info.page_token + ) ) _assert_cache_record( cache_record, auth_info.user_id, paginated_method, 1, page_size @@ -173,8 +175,10 @@ async def test_paginate_request( ) logger.info("Checking db cache record") - cache_record = services.api.crud.PaginationCache().get_pagination_cache_record( - db, pagination_info.page_token + cache_record = ( + framework.utils.pagination_cache.PaginationCache().get_pagination_cache_record( + db, pagination_info.page_token + ) ) _assert_cache_record( cache_record, auth_info.user_id, paginated_method, 2, page_size @@ -198,7 +202,7 @@ async def test_paginate_other_users_token( page_size = 3 method_kwargs = {"total_amount": 5, "since": datetime.datetime.now()} - paginator = services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() logger.info("Requesting first page with user1") response, pagination_info = await paginator.paginate_request( @@ -214,8 +218,10 @@ async def test_paginate_other_users_token( ) logger.info("Checking db cache record") - cache_record = services.api.crud.PaginationCache().get_pagination_cache_record( - db, pagination_info.page_token + cache_record = ( + framework.utils.pagination_cache.PaginationCache().get_pagination_cache_record( + db, pagination_info.page_token + ) ) _assert_cache_record( cache_record, auth_info_1.user_id, paginated_method, 1, page_size @@ -251,7 +257,7 @@ async def test_paginate_no_auth( page_size = 3 method_kwargs = {"total_amount": 5, "since": 
datetime.datetime.now()} - paginator = services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() logger.info("Requesting first page") response, pagination_info = await paginator.paginate_request( @@ -267,8 +273,10 @@ async def test_paginate_no_auth( ) logger.info("Checking db cache record") - cache_record = services.api.crud.PaginationCache().get_pagination_cache_record( - db, pagination_info.page_token + cache_record = ( + framework.utils.pagination_cache.PaginationCache().get_pagination_cache_record( + db, pagination_info.page_token + ) ) _assert_cache_record(cache_record, None, paginated_method, 1, page_size) @@ -290,8 +298,10 @@ async def test_paginate_no_auth( ) logger.info("Checking old db cache record") - cache_record = services.api.crud.PaginationCache().get_pagination_cache_record( - db, old_token + cache_record = ( + framework.utils.pagination_cache.PaginationCache().get_pagination_cache_record( + db, old_token + ) ) # The request with AuthInfo creates a new cache record, therefore the old one # should still be on page 1 and without a user. 
@@ -314,8 +324,10 @@ async def test_paginate_no_auth( ) logger.info("Checking old db cache record again") - cache_record = services.api.crud.PaginationCache().get_pagination_cache_record( - db, old_token + cache_record = ( + framework.utils.pagination_cache.PaginationCache().get_pagination_cache_record( + db, old_token + ) ) _assert_cache_record(cache_record, None, paginated_method, 2, page_size) @@ -333,7 +345,7 @@ async def test_no_pagination( auth_info = mlrun.common.schemas.AuthInfo(user_id="user1") method_kwargs = {"total_amount": 5, "since": datetime.datetime.now()} - paginator = services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() logger.info("Requesting all items") response, pagination_info = await paginator.paginate_request( @@ -350,7 +362,12 @@ async def test_no_pagination( logger.info("Checking that no cache record was created") assert ( - len(services.api.crud.PaginationCache().list_pagination_cache_records(db)) == 0 + len( + framework.utils.pagination_cache.PaginationCache().list_pagination_cache_records( + db + ) + ) + == 0 ) @@ -367,7 +384,7 @@ async def test_pagination_not_supported( auth_info = mlrun.common.schemas.AuthInfo(user_id="user1") method_kwargs = {"total_amount": 5, "since": datetime.datetime.now()} - paginator = services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() logger.info("Requesting a method that is not supported for pagination") with pytest.raises(NotImplementedError): @@ -397,7 +414,7 @@ async def test_pagination_cache_cleanup( page_size = 3 token = None - paginator = services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() logger.info("Creating paginated cache records") for i in range(3): @@ -415,7 +432,12 @@ async def test_pagination_cache_cleanup( token = pagination_info.page_token assert ( - len(services.api.crud.PaginationCache().list_pagination_cache_records(db)) == 3 + len( + 
framework.utils.pagination_cache.PaginationCache().list_pagination_cache_records( + db + ) + ) + == 3 ) logger.info("Cleaning up pagination cache") @@ -424,7 +446,12 @@ async def test_pagination_cache_cleanup( logger.info("Checking that all records were removed") assert ( - len(services.api.crud.PaginationCache().list_pagination_cache_records(db)) == 0 + len( + framework.utils.pagination_cache.PaginationCache().list_pagination_cache_records( + db + ) + ) + == 0 ) logger.info("Try to get page with token") @@ -518,7 +545,7 @@ async def filter_(items): last_page = total // page_size method_kwargs = {"total_amount": total, "since": datetime.datetime.now()} - paginator = services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() response, pagination_info = await paginator.paginate_permission_filtered_request( db, @@ -557,7 +584,7 @@ async def test_paginate_permission_filtered_no_pagination( auth_info = mlrun.common.schemas.AuthInfo(user_id="user1") method_kwargs = {"total_amount": 5, "since": datetime.datetime.now()} - paginator = services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() async def filter_(items): return items @@ -612,7 +639,7 @@ async def filter_(items): page_size = 4 method_kwargs = {"total_amount": 20, "since": datetime.datetime.now()} - paginator = services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() response, pagination_info = await paginator.paginate_permission_filtered_request( db, @@ -687,7 +714,7 @@ async def test_paginate_request_invalid_page_or_page_size( auth_info = mlrun.common.schemas.AuthInfo(user_id="user1") method_kwargs = {"total_amount": 5, "since": datetime.datetime.now()} - paginator = services.api.utils.pagination.Paginator() + paginator = framework.utils.pagination.Paginator() with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): await paginator.paginate_request( From 6116039fed541ed1f68e2cbc3b27fd6c3d9f052e Mon 
Sep 17 00:00:00 2001 From: davesh0812 <85231462+davesh0812@users.noreply.github.com> Date: Mon, 6 Jan 2025 20:50:06 +0200 Subject: [PATCH 08/15] [Serving] Raise error if there are no routes (#7068) --- mlrun/serving/states.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlrun/serving/states.py b/mlrun/serving/states.py index 399d6b52bbf..9fe173ad4fb 100644 --- a/mlrun/serving/states.py +++ b/mlrun/serving/states.py @@ -806,6 +806,10 @@ def clear_children(self, routes: list): del self._routes[key] def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): + if not self.routes: + raise mlrun.errors.MLRunRuntimeError( + "You have to add models to the router step before initializing it" + ) if not self._is_local_function(context): return From c9fb305bc73d8ce095cf7288d88211a8318bfe14 Mon Sep 17 00:00:00 2001 From: davesh0812 <85231462+davesh0812@users.noreply.github.com> Date: Mon, 6 Jan 2025 20:52:50 +0200 Subject: [PATCH 09/15] [Serving] Limit the number of models to 5K per Router (#7062) --- mlrun/errors.py | 4 ++++ mlrun/serving/states.py | 20 +++++++++--------- tests/serving/test_serving.py | 38 ++++++++++++++++++++++++++++------- 3 files changed, 46 insertions(+), 16 deletions(-) diff --git a/mlrun/errors.py b/mlrun/errors.py index 821cfe217b4..c53ed97e716 100644 --- a/mlrun/errors.py +++ b/mlrun/errors.py @@ -174,6 +174,10 @@ class MLRunInvalidArgumentError(MLRunHTTPStatusError, ValueError): error_status_code = HTTPStatus.BAD_REQUEST.value +class MLRunModelLimitExceededError(MLRunHTTPStatusError, ValueError): + error_status_code = HTTPStatus.BAD_REQUEST.value + + class MLRunInvalidArgumentTypeError(MLRunHTTPStatusError, TypeError): error_status_code = HTTPStatus.BAD_REQUEST.value diff --git a/mlrun/serving/states.py b/mlrun/serving/states.py index 9fe173ad4fb..d07c4b04061 100644 --- a/mlrun/serving/states.py +++ b/mlrun/serving/states.py @@ -31,6 +31,7 @@ import mlrun import mlrun.common.schemas as schemas +from mlrun.utils 
import logger from ..config import config from ..datastore import get_stream_pusher @@ -49,6 +50,8 @@ previous_step = "$prev" queue_class_names = [">>", "$queue"] +MAX_MODELS_PER_ROUTER = 5000 + class GraphError(Exception): """error in graph topology or configuration""" @@ -87,9 +90,6 @@ class StepKinds: ] -MAX_ALLOWED_STEPS = 4500 - - def new_remote_endpoint( url: str, creation_strategy: schemas.ModelEndpointCreationStrategy, @@ -755,7 +755,7 @@ def add_route( creation_strategy: schemas.ModelEndpointCreationStrategy = schemas.ModelEndpointCreationStrategy.INPLACE, **class_args, ): - """add child route step or class to the router + """add child route step or class to the router, if key exists it will be updated :param key: unique name (and route path) for the child step :param route: child step object (Task, ..) @@ -775,7 +775,13 @@ def add_route( 2. Create a new model endpoint with the same name and set it to `latest`. """ - + if len(self.routes.keys()) >= MAX_MODELS_PER_ROUTER and key not in self.routes: + raise mlrun.errors.MLRunModelLimitExceededError( + f"Router cannot support more than {MAX_MODELS_PER_ROUTER} model endpoints. " + f"To add a new route, edit an existing one by passing the same key." 
+ ) + if key in self.routes: + logger.info(f"Model {key} already exists, updating it.") if not route and not class_name and not handler: raise MLRunInvalidArgumentError("route or class_name must be specified") if not route: @@ -790,10 +796,6 @@ def add_route( ) route.function = function or route.function - if len(self._routes) >= MAX_ALLOWED_STEPS: - raise mlrun.errors.MLRunInvalidArgumentError( - f"Cannot create the serving graph: the maximum number of steps is {MAX_ALLOWED_STEPS}" - ) route = self._routes.update(key, route) route.set_parent(self) return route diff --git a/tests/serving/test_serving.py b/tests/serving/test_serving.py index 3a1f60a2b7c..cc0e3116b2a 100644 --- a/tests/serving/test_serving.py +++ b/tests/serving/test_serving.py @@ -17,6 +17,7 @@ import pathlib import random import time +from unittest.mock import patch import pandas as pd import pytest @@ -811,10 +812,33 @@ def test_mock_invoke(): mlrun.mlconf.mock_nuclio_deployment = mock_nuclio_config -def test_add_route_exceeds_max_steps(): - """Test adding a route when the maximum number of steps is exceeded.""" - host = create_graph_server(graph=RouterStep()) - max_steps = mlrun.serving.states.MAX_ALLOWED_STEPS - with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): - for key in range(max_steps + 1): - host.graph.add_route(f"test_key_{key}", class_name=ModelTestingClass) +def test_updating_model(): + fn = mlrun.new_function("tests", kind="serving") + fn.add_model("my", ".", class_name=ModelTestingClass(multiplier=100)) + server = fn.to_mock_server() + resp = server.test("/v2/models/my/infer", testdata) + assert resp["outputs"] == 5 * 100, f"wrong data response {resp}" + + with patch("mlrun.utils.logger.info") as mock_warning: + # update the model + fn.add_model("my", ".", class_name=ModelTestingClass(multiplier=200)) + mock_warning.assert_called_with("Model my already exists, updating it.") + server = fn.to_mock_server() + resp = server.test("/v2/models/my/infer", testdata) + assert 
resp["outputs"] == 5 * 200, f"wrong data response {resp}" + + +def test_add_route_exceeds_max_models(): + """Test adding a route when the maximum number of models is exceeded.""" + server = create_graph_server(graph=RouterStep()) + max_models = mlrun.serving.states.MAX_MODELS_PER_ROUTER + with pytest.raises(mlrun.errors.MLRunModelLimitExceededError): + for key in range(max_models + 1): + server.graph.add_route(f"test_key_{key}", class_name=ModelTestingClass) + + # edit existing model + server.graph.add_route(f"test_key_{key-1}", class_name=ModelTestingClass) + + assert ( + len(server.graph.routes) == max_models + ), f"expected to have {max_models} models" From 3d3b0081ede1340e811b20c79817b53b8ce0e042 Mon Sep 17 00:00:00 2001 From: Jonathan Daniel <36337649+jond01@users.noreply.github.com> Date: Mon, 6 Jan 2025 20:58:27 +0200 Subject: [PATCH 10/15] [Datastore] Add TDEngine datastore profile (#7061) --- mlrun/datastore/datastore_profile.py | 71 ++++++++++++++++++----- tests/datastore/test_datastore_profile.py | 56 ++++++++++++++++++ 2 files changed, 112 insertions(+), 15 deletions(-) diff --git a/mlrun/datastore/datastore_profile.py b/mlrun/datastore/datastore_profile.py index 87f252699bf..8fed2506c71 100644 --- a/mlrun/datastore/datastore_profile.py +++ b/mlrun/datastore/datastore_profile.py @@ -17,7 +17,7 @@ import json import typing import warnings -from urllib.parse import ParseResult, urlparse, urlunparse +from urllib.parse import ParseResult, urlparse import pydantic.v1 from mergedeep import merge @@ -312,7 +312,7 @@ def url_with_credentials(self): query=parsed_url.query, fragment=parsed_url.fragment, ) - return urlunparse(new_parsed_url) + return new_parsed_url.geturl() def secrets(self) -> dict: res = {} @@ -473,6 +473,59 @@ def url(self, subpath): return f"webhdfs://{self.host}:{self.http_port}{subpath}" +class TDEngineDatastoreProfile(DatastoreProfile): + """ + A profile that holds the required parameters for a TDEngine database, with the websocket scheme. 
+ https://docs.tdengine.com/developer-guide/connecting-to-tdengine/#websocket-connection + """ + + type: str = pydantic.v1.Field("taosws") + _private_attributes = ["password"] + user: str + # The password cannot be empty in real world scenarios. It's here just because of the profiles completion design. + password: typing.Optional[str] + host: str + port: int + + def dsn(self) -> str: + """Get the Data Source Name of the configured TDEngine profile.""" + return f"{self.type}://{self.user}:{self.password}@{self.host}:{self.port}" + + @classmethod + def from_dsn(cls, dsn: str, profile_name: str) -> "TDEngineDatastoreProfile": + """ + Construct a TDEngine profile from DSN (connection string) and a name for the profile. + + :param dsn: The DSN (Data Source Name) of the TDEngine database, e.g.: ``"taosws://root:taosdata@localhost:6041"``. + :param profile_name: The new profile's name. + :return: The TDEngine profile. + """ + parsed_url = urlparse(dsn) + return cls( + name=profile_name, + user=parsed_url.username, + password=parsed_url.password, + host=parsed_url.hostname, + port=parsed_url.port, + ) + + +_DATASTORE_TYPE_TO_PROFILE_CLASS: dict[str, type[DatastoreProfile]] = { + "v3io": DatastoreProfileV3io, + "s3": DatastoreProfileS3, + "redis": DatastoreProfileRedis, + "basic": DatastoreProfileBasic, + "kafka_target": DatastoreProfileKafkaTarget, + "kafka_source": DatastoreProfileKafkaSource, + "dbfs": DatastoreProfileDBFS, + "gcs": DatastoreProfileGCS, + "az": DatastoreProfileAzureBlob, + "hdfs": DatastoreProfileHdfs, + "taosws": TDEngineDatastoreProfile, + "config": ConfigProfile, +} + + class DatastoreProfile2Json(pydantic.v1.BaseModel): @staticmethod def _to_json(attributes): @@ -523,19 +576,7 @@ def safe_literal_eval(value): decoded_dict = {k: safe_literal_eval(v) for k, v in decoded_dict.items()} datastore_type = decoded_dict.get("type") - ds_profile_factory = { - "v3io": DatastoreProfileV3io, - "s3": DatastoreProfileS3, - "redis": DatastoreProfileRedis, - "basic": 
DatastoreProfileBasic, - "kafka_target": DatastoreProfileKafkaTarget, - "kafka_source": DatastoreProfileKafkaSource, - "dbfs": DatastoreProfileDBFS, - "gcs": DatastoreProfileGCS, - "az": DatastoreProfileAzureBlob, - "hdfs": DatastoreProfileHdfs, - "config": ConfigProfile, - } + ds_profile_factory = _DATASTORE_TYPE_TO_PROFILE_CLASS if datastore_type in ds_profile_factory: return ds_profile_factory[datastore_type].parse_obj(decoded_dict) else: diff --git a/tests/datastore/test_datastore_profile.py b/tests/datastore/test_datastore_profile.py index b1457338e83..13aec5ce0db 100644 --- a/tests/datastore/test_datastore_profile.py +++ b/tests/datastore/test_datastore_profile.py @@ -13,15 +13,20 @@ # limitations under the License. from collections.abc import Iterator +from unittest.mock import patch import pytest +import mlrun import mlrun.common.schemas import mlrun.errors from mlrun.datastore.datastore_profile import ( + _DATASTORE_TYPE_TO_PROFILE_CLASS, + DatastoreProfile, DatastoreProfile2Json, DatastoreProfileKafkaTarget, DatastoreProfileV3io, + TDEngineDatastoreProfile, datastore_profile_read, register_temporary_client_datastore_profile, remove_temporary_client_datastore_profile, @@ -98,3 +103,54 @@ def test_from_public_json() -> None: ) profile = DatastoreProfile2Json.create_from_json(public_profile_schema.object) assert isinstance(profile, DatastoreProfileV3io), "Not the right profile" + + +class TestTDEngineProfile: + @staticmethod + def test_from_dsn() -> None: + dsn = "taosws://root:taosdata@localhost:6041" + profile_name = "test-taosws" + profile = TDEngineDatastoreProfile.from_dsn(dsn=dsn, profile_name=profile_name) + assert profile.type == "taosws" + assert profile.user == "root" + assert profile.password == "taosdata" + assert profile.host == "localhost" + assert profile.port == 6041 + assert ( + profile.dsn() == dsn + ), "Converting the profile back to DSN did not work as expected" + + @staticmethod + def test_datastore_profile_read_from_env(monkeypatch: 
pytest.MonkeyPatch) -> None: + profile_name = "test-profile" + project_name = "test-project" + + public_profile = mlrun.common.schemas.DatastoreProfile( + name=profile_name, + type="taosws", + object='{"type":"dGFvc3dz","name":"dGRlbmdpbmUx","user":"cm9vdA==","host":"MC4wLjAuMA==","port":"NjA0MQ=="}', + private=None, + project=project_name, + ) + + with patch( + "mlrun.db.nopdb.NopDB.get_datastore_profile", return_value=public_profile + ): + monkeypatch.setenv( + f"datastore-profiles.{project_name}.{profile_name}", + '{"password": "MTIzNA=="}', + ) + profile_read = datastore_profile_read(f"ds://{profile_name}", project_name) + + assert profile_read.type == "taosws", "Wrong profile type" + assert profile_read.password == "1234", "Wrong password" + + +def test_datastore_type_map() -> None: + assert set(_DATASTORE_TYPE_TO_PROFILE_CLASS.values()) == set( + DatastoreProfile.__subclasses__() + ), "Missing profiles in the map" + for type_, profile_class in _DATASTORE_TYPE_TO_PROFILE_CLASS.items(): + assert type_ == profile_class.schema().get("properties", {}).get("type").get( + "default" + ), "Type key and profile class type do not match" From b077c761e6c148f7ee4d1945b5aecbf824912e07 Mon Sep 17 00:00:00 2001 From: Jonathan Daniel <36337649+jond01@users.noreply.github.com> Date: Mon, 6 Jan 2025 21:05:23 +0200 Subject: [PATCH 11/15] [Model Monitoring] Support Kafka source datastore profile (#7042) --- .../schemas/model_monitoring/constants.py | 13 +- mlrun/config.py | 2 - mlrun/db/httpdb.py | 4 - mlrun/model_monitoring/helpers.py | 24 ++- mlrun/projects/project.py | 60 ++++-- mlrun/serving/server.py | 14 +- .../api/api/endpoints/model_monitoring.py | 10 +- server/py/services/api/crud/client_spec.py | 3 - .../api/crud/model_monitoring/deployment.py | 198 ++++++++++-------- .../crud/model_monitoring/model_endpoints.py | 20 +- .../crud/model_monitoring/test_deployment.py | 1 + tests/model_monitoring/test_target_path.py | 34 ++- tests/system/alerts/test_alerts.py | 7 +- 
tests/system/model_monitoring/test_app.py | 22 +- .../model_monitoring/test_model_monitoring.py | 15 +- 15 files changed, 253 insertions(+), 174 deletions(-) diff --git a/mlrun/common/schemas/model_monitoring/constants.py b/mlrun/common/schemas/model_monitoring/constants.py index b52668ac049..807c75a3219 100644 --- a/mlrun/common/schemas/model_monitoring/constants.py +++ b/mlrun/common/schemas/model_monitoring/constants.py @@ -228,19 +228,18 @@ class ModelEndpointTarget(MonitoringStrEnum): SQL = "sql" -class StreamKind(MonitoringStrEnum): - V3IO_STREAM = "v3io_stream" - KAFKA = "kafka" - - class TSDBTarget(MonitoringStrEnum): V3IO_TSDB = "v3io-tsdb" TDEngine = "tdengine" +class DefaultProfileName(StrEnum): + STREAM = "mm-infra-stream" + TSDB = "mm-infra-tsdb" + + class ProjectSecretKeys: ACCESS_KEY = "MODEL_MONITORING_ACCESS_KEY" - STREAM_PATH = "STREAM_PATH" TSDB_CONNECTION = "TSDB_CONNECTION" TSDB_PROFILE_NAME = "TSDB_PROFILE_NAME" STREAM_PROFILE_NAME = "STREAM_PROFILE_NAME" @@ -248,7 +247,7 @@ class ProjectSecretKeys: @classmethod def mandatory_secrets(cls): return [ - cls.STREAM_PATH, + cls.STREAM_PROFILE_NAME, cls.TSDB_CONNECTION, ] diff --git a/mlrun/config.py b/mlrun/config.py index fcdd3661114..77038dc3f36 100644 --- a/mlrun/config.py +++ b/mlrun/config.py @@ -610,8 +610,6 @@ "parquet_batching_timeout_secs": timedelta(minutes=1).total_seconds(), # See mlrun.model_monitoring.db.tsdb.ObjectTSDBFactory for available options "tsdb_connection": "", - # See mlrun.common.schemas.model_monitoring.constants.StreamKind for available options - "stream_connection": "", "tdengine": { "timeout": 10, "retries": 1, diff --git a/mlrun/db/httpdb.py b/mlrun/db/httpdb.py index 0f1e4827315..c5b025e073e 100644 --- a/mlrun/db/httpdb.py +++ b/mlrun/db/httpdb.py @@ -563,10 +563,6 @@ def connect(self, secrets=None): server_cfg.get("model_monitoring_tsdb_connection") or config.model_endpoint_monitoring.tsdb_connection ) - config.model_endpoint_monitoring.stream_connection = ( - 
server_cfg.get("stream_connection") - or config.model_endpoint_monitoring.stream_connection - ) config.packagers = server_cfg.get("packagers") or config.packagers server_data_prefixes = server_cfg.get("feature_store_data_prefixes") or {} for prefix in ["default", "nosql", "redisnosql"]: diff --git a/mlrun/model_monitoring/helpers.py b/mlrun/model_monitoring/helpers.py index 6563ac043c1..9f2a8af5560 100644 --- a/mlrun/model_monitoring/helpers.py +++ b/mlrun/model_monitoring/helpers.py @@ -117,6 +117,7 @@ def get_stream_path( function_name: str = mm_constants.MonitoringFunctionNames.STREAM, stream_uri: Optional[str] = None, secret_provider: Optional[Callable[[str], str]] = None, + profile: Optional[mlrun.datastore.datastore_profile.DatastoreProfile] = None, ) -> str: """ Get stream path from the project secret. If wasn't set, take it from the system configurations @@ -126,20 +127,25 @@ def get_stream_path( :param stream_uri: Stream URI. If provided, it will be used instead of the one from the project's secret. :param secret_provider: Optional secret provider to get the connection string secret. If not set, the env vars are used. + :param profile: Optional datastore profile of the stream (V3IO/KafkaSource profile). :return: Monitoring stream path to the relevant application. 
""" - try: - profile = _get_stream_profile(project=project, secret_provider=secret_provider) - except mlrun.errors.MLRunNotFoundError: - profile = None + profile = profile or _get_stream_profile( + project=project, secret_provider=secret_provider + ) if isinstance(profile, mlrun.datastore.datastore_profile.DatastoreProfileV3io): stream_uri = "v3io" - - stream_uri = stream_uri or mlrun.get_secret_or_env( - key=mm_constants.ProjectSecretKeys.STREAM_PATH, secret_provider=secret_provider - ) + elif isinstance( + profile, mlrun.datastore.datastore_profile.DatastoreProfileKafkaSource + ): + stream_uri = f"kafka://{profile.brokers[0]}" + else: + raise mlrun.errors.MLRunValueError( + f"Received an unexpected stream profile type: {type(profile)}\n" + "Expects `DatastoreProfileV3io` or `DatastoreProfileKafkaSource`." + ) if not stream_uri or stream_uri == "v3io": stream_uri = mlrun.mlconf.get_model_monitoring_file_target_path( @@ -273,7 +279,7 @@ def _get_profile( ) if not profile_name: raise mlrun.errors.MLRunNotFoundError( - f"Not found `{profile_name_key}` profile name" + f"Not found `{profile_name_key}` profile name for project '{project}'" ) return mlrun.datastore.datastore_profile.datastore_profile_read( url=f"ds://{profile_name}", project_name=project, secrets=secret_provider diff --git a/mlrun/projects/project.py b/mlrun/projects/project.py index b95e7f61c5c..d1e3e59b88b 100644 --- a/mlrun/projects/project.py +++ b/mlrun/projects/project.py @@ -29,6 +29,7 @@ from copy import deepcopy from os import environ, makedirs, path from typing import Callable, Optional, Union, cast +from urllib.parse import urlparse import dotenv import git @@ -3608,9 +3609,11 @@ def export(self, filepath=None, include_files: Optional[str] = None): def set_model_monitoring_credentials( self, access_key: Optional[str] = None, - stream_path: Optional[str] = None, + stream_path: Optional[str] = None, # Deprecated tsdb_connection: Optional[str] = None, replace_creds: bool = False, + *, + 
stream_profile_name: Optional[str] = None, ): """ Set the credentials that will be used by the project's model monitoring @@ -3622,13 +3625,13 @@ def set_model_monitoring_credentials( * None - will be set from the system configuration. * v3io - for v3io endpoint store, pass `v3io` and the system will generate the exact path. - :param stream_path: Path to the model monitoring stream. By default, None. Options: + :param stream_path: (Deprecated) This argument is deprecated. Use ``stream_profile_name`` instead. + Path to the model monitoring stream. By default, None. Options: - * None - will be set from the system configuration. - * v3io - for v3io stream, pass `v3io` and the system will generate the exact - path. - * Kafka - for Kafka stream, provide the full connection string without custom - topic, for example kafka://:. + * ``"v3io"`` - for v3io stream, pass ``"v3io"`` and the system will generate + the exact path. + * Kafka - for Kafka stream, provide the full connection string without acustom + topic, for example ``"kafka://:"``. :param tsdb_connection: Connection string to the time series database. By default, None. Options: @@ -3642,29 +3645,58 @@ def set_model_monitoring_credentials( your project this action can cause data loose and will require redeploying all model monitoring functions & model monitoring infra & tracked model server. + :param stream_profile_name: The datastore profile name of the stream to be used in model monitoring. + The supported profiles are: + + * :py:class:`~mlrun.datastore.datastore_profile.DatastoreProfileV3io` + * :py:class:`~mlrun.datastore.datastore_profile.DatastoreProfileKafkaSource` + + You need to register one of them, and pass the profile's name. 
""" db = mlrun.db.get_run_db(secrets=self._secrets) if tsdb_connection == "v3io": tsdb_profile = mlrun.datastore.datastore_profile.DatastoreProfileV3io( - name="mm-infra-tsdb" + name=mm_constants.DefaultProfileName.TSDB ) self.register_datastore_profile(tsdb_profile) tsdb_profile_name = tsdb_profile.name else: tsdb_profile_name = None - if stream_path == "v3io": - stream_profile = mlrun.datastore.datastore_profile.DatastoreProfileV3io( - name="mm-infra-stream" + + if stream_path: + warnings.warn( + "The `stream_path` argument is deprecated and will be removed in MLRun version 1.10.0. " + "Use `stream_profile_name` instead.", + FutureWarning, ) + if stream_profile_name: + raise mlrun.errors.MLRunValueError( + "If you set `stream_profile_name`, you must not pass `stream_path`." + ) + if stream_path == "v3io": + stream_profile = mlrun.datastore.datastore_profile.DatastoreProfileV3io( + name=mm_constants.DefaultProfileName.STREAM + ) + else: + parsed_stream = urlparse(stream_path) + if parsed_stream.scheme != "kafka": + raise mlrun.errors.MLRunValueError( + f"Unsupported `stream_path`: '{stream_path}'." 
+ ) + stream_profile = ( + mlrun.datastore.datastore_profile.DatastoreProfileKafkaSource( + name=mm_constants.DefaultProfileName.STREAM, + brokers=[parsed_stream.netloc], + topics=[], + ) + ) self.register_datastore_profile(stream_profile) stream_profile_name = stream_profile.name - else: - stream_profile_name = None + db.set_model_monitoring_credentials( project=self.name, credentials={ "access_key": access_key, - "stream_path": stream_path, "tsdb_connection": tsdb_connection, "tsdb_profile_name": tsdb_profile_name, "stream_profile_name": stream_profile_name, diff --git a/mlrun/serving/server.py b/mlrun/serving/server.py index 22ecdf68127..c37b4f01492 100644 --- a/mlrun/serving/server.py +++ b/mlrun/serving/server.py @@ -44,6 +44,8 @@ from .states import RootFlowStep, RouterStep, get_function, graph_root_setter from .utils import event_id_key, event_path_key +DUMMY_STREAM = "dummy://" + class _StreamContext: """Handles the stream context for the events stream process. Includes the configuration for the output stream @@ -72,14 +74,20 @@ def __init__(self, enabled: bool, parameters: dict, function_uri: str): function_uri, config.default_project ) - self.stream_uri = mlrun.model_monitoring.get_stream_path(project=project) + stream_args = parameters.get("stream_args", {}) + + if log_stream == DUMMY_STREAM: + # Dummy stream used for testing, see tests/serving/test_serving.py + self.stream_uri = DUMMY_STREAM + elif not stream_args.get("mock"): # if not a mock: `context.is_mock = True` + self.stream_uri = mlrun.model_monitoring.get_stream_path( + project=project + ) if log_stream: # Update the stream path to the log stream value self.stream_uri = log_stream.format(project=project) - stream_args = parameters.get("stream_args", {}) - self.output_stream = get_stream_pusher(self.stream_uri, **stream_args) diff --git a/server/py/services/api/api/endpoints/model_monitoring.py b/server/py/services/api/api/endpoints/model_monitoring.py index ed5230928f7..7638caf73ee 100644 --- 
a/server/py/services/api/api/endpoints/model_monitoring.py +++ b/server/py/services/api/api/endpoints/model_monitoring.py @@ -360,7 +360,6 @@ async def delete_model_monitoring_function( def set_model_monitoring_credentials( commons: Annotated[_CommonParams, Depends(_common_parameters)], access_key: Optional[str] = None, - stream_path: Optional[str] = None, tsdb_connection: Optional[str] = None, tsdb_profile_name: Optional[str] = None, stream_profile_name: Optional[str] = None, @@ -372,13 +371,6 @@ def set_model_monitoring_credentials( model monitoring or serving function. :param commons: The common parameters of the request. :param access_key: Model Monitoring access key for managing user permissions. - :param stream_path: Path to the model monitoring stream. By default, None. - Options: - 1. None, will be set from the system configuration. - 2. v3io - for v3io stream, - pass `v3io` and the system will generate the exact path. - 3. Kafka - for Kafka stream, please provide full connection string without - custom topic, for example kafka://:. :param tsdb_connection: Connection string to the time series database. By default, None. Options: 1. None, will be set from the system configuration. @@ -388,6 +380,7 @@ def set_model_monitoring_credentials( for example taosws://:@:. :param tsdb_profile_name: TSDB datastore profile name. If specified, takes precedence over tsdb_connection. :param stream_profile_name: Stream datastore profile name. If specified, takes precedence over stream_path. + The profile can be V3IO or KafkaSource. :param replace_creds: If True, it will force the credentials update. By default, False. 
""" MonitoringDeployment( @@ -397,7 +390,6 @@ def set_model_monitoring_credentials( model_monitoring_access_key=commons.model_monitoring_access_key, ).set_credentials( access_key=access_key, - stream_path=stream_path, tsdb_connection=tsdb_connection, tsdb_profile_name=tsdb_profile_name, stream_profile_name=stream_profile_name, diff --git a/server/py/services/api/crud/client_spec.py b/server/py/services/api/crud/client_spec.py index ea29f34fe0d..a1b74f36b4d 100644 --- a/server/py/services/api/crud/client_spec.py +++ b/server/py/services/api/crud/client_spec.py @@ -124,9 +124,6 @@ def get_client_spec( model_monitoring_tsdb_connection=self._get_config_value_if_not_default( "model_endpoint_monitoring.tsdb_connection" ), - model_monitoring_stream_connection=self._get_config_value_if_not_default( - "model_endpoint_monitoring.stream_connection" - ), packagers=self._get_config_value_if_not_default("packagers"), alerts_mode=self._get_config_value_if_not_default("alerts.mode"), system_id=self._get_config_value_if_not_default("system_id"), diff --git a/server/py/services/api/crud/model_monitoring/deployment.py b/server/py/services/api/crud/model_monitoring/deployment.py index 5be33eb0a72..dea9abb0fc6 100644 --- a/server/py/services/api/crud/model_monitoring/deployment.py +++ b/server/py/services/api/crud/model_monitoring/deployment.py @@ -31,6 +31,7 @@ import mlrun.common.model_monitoring.helpers import mlrun.common.schemas import mlrun.common.schemas.model_monitoring.constants as mm_constants +import mlrun.datastore.datastore_profile import mlrun.model_monitoring import mlrun.model_monitoring.api import mlrun.model_monitoring.applications @@ -38,6 +39,7 @@ import mlrun.model_monitoring.stream_processing import mlrun.model_monitoring.writer import mlrun.serving.states +import mlrun.utils.v3io_clients from mlrun import feature_store as fstore from mlrun.config import config from mlrun.model_monitoring.writer import ModelMonitoringWriter @@ -1018,12 +1020,7 @@ def 
_delete_model_monitoring_stream_resources( def _get_monitoring_mandatory_project_secrets(self) -> dict[str, str]: credentials_dict = { - key: services.api.crud.Secrets().get_project_secret( - project=self.project, - provider=mlrun.common.schemas.SecretProviderName.kubernetes, - secret_key=key, - allow_secrets_from_k8s=True, - ) + key: mlrun.get_secret_or_env(key, secret_provider=self._secret_provider) for key in mlrun.common.schemas.model_monitoring.ProjectSecretKeys.mandatory_secrets() } @@ -1048,10 +1045,94 @@ def check_if_credentials_are_set( "or pass fetch_credentials_from_sys_config=True when using enable_model_monitoring API/SDK." ) + def _validate_stream_profile(self, stream_profile_name: str) -> None: + try: + stream_profile = mlrun.datastore.datastore_profile.datastore_profile_read( + url=f"ds://{stream_profile_name}", + project_name=self.project, + secrets=self._secret_provider, + ) + except mlrun.errors.MLRunNotFoundError: + raise mlrun.errors.MLRunNotFoundError( + f"The given model monitoring stream profile name '{stream_profile_name}' " + "was not found. Please make sure to register it properly in the project with " + "`project.register_datastore_profile(stream_profile)`." + ) + if isinstance( + stream_profile, + mlrun.datastore.datastore_profile.DatastoreProfileKafkaSource, + ): + self._validate_kafka_stream(stream_profile) + elif isinstance( + stream_profile, mlrun.datastore.datastore_profile.DatastoreProfileV3io + ): + self._validate_v3io_stream(stream_profile) + else: + raise mlrun.errors.MLRunInvalidMMStoreTypeError( + f"The model monitoring stream profile is of an unexpected type: '{type(stream_profile)}'\n" + "Expects `DatastoreProfileV3io` or `DatastoreProfileKafkaSource`." 
+ ) + + def _validate_kafka_stream( + self, + kafka_profile: mlrun.datastore.datastore_profile.DatastoreProfileKafkaSource, + ) -> None: + if kafka_profile.topics: + raise mlrun.errors.MLRunInvalidMMStoreTypeError( + "Custom Kafka topics are not supported" + ) + self._verify_kafka_access(kafka_profile) + + @staticmethod + def _verify_kafka_access( + kafka_profile: mlrun.datastore.datastore_profile.DatastoreProfileKafkaSource, + ) -> None: + import kafka + import kafka.errors + + kafka_brokers = kafka_profile.brokers + try: + # The following constructor attempts to establish a connection + consumer = kafka.KafkaConsumer(brokers=kafka_brokers) + except kafka.errors.NoBrokersAvailable as err: + logger.warn( + "No Kafka brokers available for the given kafka source profile in model monitoring", + kafka_brokers=kafka_brokers, + err=mlrun.errors.err_to_str(err), + ) + raise + else: + consumer.close() + + def _validate_v3io_stream( + self, + v3io_profile: mlrun.datastore.datastore_profile.DatastoreProfileV3io, + ) -> None: + if mlrun.mlconf.is_ce_mode(): + raise mlrun.errors.MLRunInvalidMMStoreTypeError( + "MLRun CE supports only Kafka streams, received a V3IO profile for the stream" + ) + self._verify_v3io_access(v3io_profile) + + def _verify_v3io_access( + self, v3io_profile: mlrun.datastore.datastore_profile.DatastoreProfileV3io + ) -> None: + stream_path = mlrun.model_monitoring.get_stream_path( + project=self.project, profile=v3io_profile + ) + container, path = split_path(stream_path) + + v3io_client = mlrun.utils.v3io_clients.get_v3io_client( + endpoint=mlrun.mlconf.v3io_api, access_key=v3io_profile.v3io_access_key + ) + # We don't expect the stream to exist. The purpose is to make sure we have access. 
+ v3io_client.stream.describe( + container, path, raise_for_status=[HTTPStatus.OK, HTTPStatus.NOT_FOUND] + ) + def set_credentials( self, access_key: typing.Optional[str] = None, - stream_path: typing.Optional[str] = None, tsdb_connection: typing.Optional[str] = None, tsdb_profile_name: typing.Optional[str] = None, stream_profile_name: typing.Optional[str] = None, @@ -1062,13 +1143,6 @@ def set_credentials( Set the model monitoring credentials for the project. The credentials are stored in the project secrets. :param access_key: Model Monitoring access key for managing user permissions. - :param stream_path: Path to the model monitoring stream. By default, None. - Options: - 1. None, will be set from the system configuration. - 2. v3io - for v3io stream, - pass `v3io` and the system will generate the exact path. - 3. Kafka - for Kafka stream, please provide full connection string without - custom topic, for example kafka://:. :param tsdb_connection: Connection string to the time series database. By default, None. Options: 1. None, will be set from the system configuration. @@ -1078,7 +1152,7 @@ def set_credentials( for example taosws://:@:. :param tsdb_profile_name: The TSDB profile name to be used in the project's model monitoring framework. :param stream_profile_name: The stream profile name to be used in the project's model monitoring - framework. + framework. Either V3IO or KafkaSource profile. :param replace_creds: If True, the credentials will be set even if they are already set. :param _default_secrets_v3io: Optional parameter for the upgrade process in which the v3io default secret key is set. @@ -1086,10 +1160,11 @@ def set_credentials( provided different creds. :raise MLRunInvalidMMStoreTypeError: If the user provided invalid credentials. 
""" + if not replace_creds: try: self.check_if_credentials_are_set() - if self._is_the_same_cred(stream_path, tsdb_connection): + if self._is_the_same_cred(stream_profile_name, tsdb_connection): logger.debug( "The same credentials are already set for the project - aborting with no error", project=self.project, @@ -1112,48 +1187,14 @@ def set_credentials( mlrun.common.schemas.model_monitoring.ProjectSecretKeys.ACCESS_KEY ) - # stream_path - if not stream_path: - stream_path = ( - old_secrets_dict.get( - mlrun.common.schemas.model_monitoring.ProjectSecretKeys.STREAM_PATH - ) - or mlrun.mlconf.model_endpoint_monitoring.stream_connection - or _default_secrets_v3io - ) - + stream_profile_name = stream_profile_name or old_secrets_dict.get( + mlrun.common.schemas.model_monitoring.ProjectSecretKeys.STREAM_PROFILE_NAME + ) if stream_profile_name: - # TODO: Add checks. + self._validate_stream_profile(stream_profile_name) secrets_dict[ mlrun.common.schemas.model_monitoring.ProjectSecretKeys.STREAM_PROFILE_NAME ] = stream_profile_name - if stream_path: - if ( - stream_path == mm_constants.V3IO_MODEL_MONITORING_DB - and mlrun.mlconf.is_ce_mode() - ): - raise mlrun.errors.MLRunInvalidMMStoreTypeError( - "In CE mode, only kafka stream are supported for stream path" - ) - elif stream_path.startswith("kafka://") and "?topic" in stream_path: - raise mlrun.errors.MLRunInvalidMMStoreTypeError( - "Custom kafka topic is not allowed" - ) - elif not stream_path.startswith("kafka://") and ( - stream_path != mm_constants.V3IO_MODEL_MONITORING_DB - ): - raise mlrun.errors.MLRunInvalidMMStoreTypeError( - "Currently only Kafka connection is supported for non-v3io stream," - "please provide a full URL (e.g. 
kafka://:)" - ) - secrets_dict[ - mlrun.common.schemas.model_monitoring.ProjectSecretKeys.STREAM_PATH - ] = stream_path - elif stream_profile_name is None: - raise mlrun.errors.MLRunInvalidMMStoreTypeError( - "You must provide a valid stream path connection while using set_model_monitoring_credentials " - "API/SDK or in the system config" - ) if not tsdb_connection: tsdb_connection = ( @@ -1198,9 +1239,7 @@ def set_credentials( for key in ( mlrun.common.schemas.model_monitoring.ProjectSecretKeys.mandatory_secrets() ): - try: - secrets_dict[key] - except KeyError: + if key not in secrets_dict: raise mlrun.errors.MLRunInvalidMMStoreTypeError( f"You must provide a valid {key} connection while using set_model_monitoring_credentials." ) @@ -1212,12 +1251,6 @@ def set_credentials( ) ) - if not mlrun.mlconf.is_ce_mode(): - stream_path = secrets_dict.get( - mlrun.common.schemas.model_monitoring.ProjectSecretKeys.STREAM_PATH - ) - self._verify_v3io_access(stream_path) - services.api.crud.Secrets().store_project_secrets( project=self.project, secrets=mlrun.common.schemas.SecretsData( @@ -1226,41 +1259,22 @@ def set_credentials( ), ) - def _verify_v3io_access(self, stream_path: str): - import v3io.dataplane - - stream_path = mlrun.model_monitoring.get_stream_path( - project=self.project, - stream_uri=stream_path, - secret_provider=self._secret_provider, - ) - v3io_client = v3io.dataplane.Client(endpoint=mlrun.mlconf.v3io_api) - container, path = split_path(stream_path) - # We don't expect the stream to exist. The purpose is to make sure we have access. 
- v3io_client.stream.describe( - container, - path, - access_key=self.model_monitoring_access_key, - raise_for_status=[200, 404], - ) - - def _is_the_same_cred(self, stream_path: str, tsdb_connection: str) -> bool: + def _is_the_same_cred( + self, + stream_profile_name: typing.Optional[str], + tsdb_connection: typing.Optional[str], + ) -> bool: credentials_dict = { - key: services.api.crud.Secrets().get_project_secret( - project=self.project, - provider=mlrun.common.schemas.SecretProviderName.kubernetes, - secret_key=key, - allow_secrets_from_k8s=True, - ) + key: mlrun.get_secret_or_env(key, self._secret_provider) for key in mlrun.common.schemas.model_monitoring.ProjectSecretKeys.mandatory_secrets() } - old_stream = credentials_dict[ - mlrun.common.schemas.model_monitoring.ProjectSecretKeys.STREAM_PATH + old_stream_profile_name = credentials_dict[ + mlrun.common.schemas.model_monitoring.ProjectSecretKeys.STREAM_PROFILE_NAME ] - if stream_path and old_stream != stream_path: + if stream_profile_name and old_stream_profile_name != stream_profile_name: logger.debug( - "User provided different stream path", + "User provided different stream profile name", ) return False old_tsdb = credentials_dict[ diff --git a/server/py/services/api/crud/model_monitoring/model_endpoints.py b/server/py/services/api/crud/model_monitoring/model_endpoints.py index f4a0cf73b47..da36d0ce73e 100644 --- a/server/py/services/api/crud/model_monitoring/model_endpoints.py +++ b/server/py/services/api/crud/model_monitoring/model_endpoints.py @@ -874,18 +874,22 @@ def delete_model_endpoints_resources( V3IO resources. 
""" logger.debug( - "Deleting model monitoring endpoints resources", - project_name=project_name, + "Deleting model monitoring endpoints resources", project_name=project_name ) + try: + stream_path = mlrun.model_monitoring.get_stream_path( + project=project_name, + secret_provider=services.api.crud.secrets.get_project_secret_provider( + project=project_name + ), + ) + except mlrun.errors.MLRunNotFoundError: + # There is no MM infra in place for the project - no resources to delete + return + # We would ideally base on config.v3io_api but can't for backwards compatibility reasons, # we're using the igz version heuristic # TODO : adjust for ce scenario - stream_path = mlrun.model_monitoring.get_stream_path( - project=project_name, - secret_provider=services.api.crud.secrets.get_project_secret_provider( - project=project_name - ), - ) if stream_path.startswith("v3io") and ( not mlrun.mlconf.igz_version or not mlrun.mlconf.v3io_api ): diff --git a/server/py/services/api/tests/unit/crud/model_monitoring/test_deployment.py b/server/py/services/api/tests/unit/crud/model_monitoring/test_deployment.py index 655408303f7..fbdae7e96d9 100644 --- a/server/py/services/api/tests/unit/crud/model_monitoring/test_deployment.py +++ b/server/py/services/api/tests/unit/crud/model_monitoring/test_deployment.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import typing from collections.abc import Iterator diff --git a/tests/model_monitoring/test_target_path.py b/tests/model_monitoring/test_target_path.py index f920a50840d..57cbb3f783a 100644 --- a/tests/model_monitoring/test_target_path.py +++ b/tests/model_monitoring/test_target_path.py @@ -13,6 +13,7 @@ # limitations under the License. 
import os +from collections.abc import Iterator from unittest import mock import pytest @@ -20,6 +21,12 @@ import mlrun.common.schemas.model_monitoring.constants as mm_constants import mlrun.config import mlrun.model_monitoring +from mlrun.datastore.datastore_profile import ( + DatastoreProfileKafkaSource, + DatastoreProfileV3io, + register_temporary_client_datastore_profile, + remove_temporary_client_datastore_profile, +) TEST_PROJECT = "test-model-endpoints" @@ -71,14 +78,29 @@ def test_get_file_target_path(): ) -def test_get_stream_path(monkeypatch: pytest.MonkeyPatch): - # default stream path - stream_path = mlrun.model_monitoring.get_stream_path(project=TEST_PROJECT) +def test_get_v3io_stream_path() -> None: + stream_path = mlrun.model_monitoring.get_stream_path( + project=TEST_PROJECT, profile=DatastoreProfileV3io(name="tmp") + ) assert stream_path == f"v3io:///projects/{TEST_PROJECT}/model-endpoints/stream" - # kafka stream path from env - monkeypatch.setenv("STREAM_PATH", "kafka://some_kafka_broker:8080") - stream_path = mlrun.model_monitoring.get_stream_path(project=TEST_PROJECT) + +@pytest.fixture +def kafka_profile_name() -> Iterator[str]: + profile_name = "kafka-prof" + profile = DatastoreProfileKafkaSource( + name=profile_name, brokers=["some_kafka_broker:8080"], topics=[] + ) + register_temporary_client_datastore_profile(profile) + yield profile_name + remove_temporary_client_datastore_profile(profile_name) + + +def test_get_kafka_profile_stream_path(kafka_profile_name: str) -> None: + # kafka stream path from datastore profile + stream_path = mlrun.model_monitoring.get_stream_path( + project=TEST_PROJECT, secret_provider=lambda _: kafka_profile_name + ) assert ( stream_path == f"kafka://some_kafka_broker:8080?topic=monitoring_stream_{TEST_PROJECT}" diff --git a/tests/system/alerts/test_alerts.py b/tests/system/alerts/test_alerts.py index 58a011f6ea9..ca54093bcba 100644 --- a/tests/system/alerts/test_alerts.py +++ b/tests/system/alerts/test_alerts.py 
@@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + import json +import os import time import typing @@ -31,6 +32,7 @@ ModelEndpointList, ) from mlrun.datastore import get_stream_pusher +from mlrun.datastore.datastore_profile import DatastoreProfileV3io from mlrun.model_monitoring.helpers import get_stream_path from tests.system.base import TestMLRunSystem @@ -158,7 +160,7 @@ def test_drift_detection_alert(self): """ # enable model monitoring - deploy writer function self.project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) self.project.enable_model_monitoring(image=self.image or "mlrun/mlrun") @@ -181,6 +183,7 @@ def test_drift_detection_alert(self): stream_uri = get_stream_path( project=self.project.metadata.name, function_name=mm_constants.MonitoringFunctionNames.WRITER, + profile=DatastoreProfileV3io(name="tmp"), ) output_stream = get_stream_pusher( stream_uri, diff --git a/tests/system/model_monitoring/test_app.py b/tests/system/model_monitoring/test_app.py index 70b82c7f21c..5dba9685980 100644 --- a/tests/system/model_monitoring/test_app.py +++ b/tests/system/model_monitoring/test_app.py @@ -14,6 +14,7 @@ import concurrent.futures import json +import os import pickle import time import typing @@ -40,6 +41,7 @@ import mlrun.model_monitoring import mlrun.model_monitoring.api import mlrun.model_monitoring.applications.histogram_data_drift +from mlrun.datastore.datastore_profile import DatastoreProfileV3io from mlrun.datastore.targets import ParquetTarget from mlrun.model_monitoring.applications import ( SUPPORTED_EVIDENTLY_VERSION, @@ -110,7 +112,7 @@ def custom_setup(cls, project_name: str) -> None: 
project_name, "./", allow_cross_project=True ) project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) @@ -876,7 +878,7 @@ def test_model_monitoring_crud(self) -> None: image=self.image or "mlrun/mlrun" ) self.project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) self.project.enable_model_monitoring( @@ -952,7 +954,9 @@ def test_model_monitoring_crud(self) -> None: # controller and writer(with has stream) should be deleted for name in mm_constants.MonitoringFunctionNames.list(): stream_path = mlrun.model_monitoring.helpers.get_stream_path( - project=self.project.name, function_name=name + project=self.project.name, + function_name=name, + profile=DatastoreProfileV3io(name="tmp"), ) _, container, stream_path = ( mlrun.common.model_monitoring.helpers.parse_model_endpoint_store_prefix( @@ -980,6 +984,7 @@ def test_model_monitoring_crud(self) -> None: stream_path = mlrun.model_monitoring.helpers.get_stream_path( project=self.project.name, function_name=mm_constants.HistogramDataDriftApplicationConstants.NAME, + profile=DatastoreProfileV3io(name="tmp"), ) _, container, stream_path = ( mlrun.common.model_monitoring.helpers.parse_model_endpoint_store_prefix( @@ -995,6 +1000,7 @@ def test_model_monitoring_crud(self) -> None: stream_path = mlrun.model_monitoring.helpers.get_stream_path( project=self.project.name, function_name=mm_constants.HistogramDataDriftApplicationConstants.NAME, + profile=DatastoreProfileV3io(name="tmp"), ) _, container, stream_path = ( mlrun.common.model_monitoring.helpers.parse_model_endpoint_store_prefix( @@ -1258,7 +1264,7 @@ def 
_test_endpoint( def test_different_kind_of_serving(self) -> None: self.function_name = "serving-router" self.project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) self.project.enable_model_monitoring( @@ -1303,7 +1309,7 @@ def test_different_kind_of_serving(self) -> None: def test_tracking(self) -> None: self.function_name = "serving-1" self.project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) self.project.enable_model_monitoring( @@ -1384,7 +1390,7 @@ def test_tracking(self) -> None: def test_enable_model_monitoring_after_failure(self) -> None: self.function_name = "test-function" self.project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) @@ -1490,7 +1496,7 @@ class TestAppJobModelEndpointData(TestMLRunSystem): def _set_credentials(self) -> None: self.project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) @@ -1627,7 +1633,7 @@ class TestBatchServingWithSampling(TestMLRunSystem): def _set_credentials(self) -> None: self.project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), 
tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) diff --git a/tests/system/model_monitoring/test_model_monitoring.py b/tests/system/model_monitoring/test_model_monitoring.py index c45ddcf1685..6d43bb78b69 100644 --- a/tests/system/model_monitoring/test_model_monitoring.py +++ b/tests/system/model_monitoring/test_model_monitoring.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import json import os import pickle @@ -66,7 +67,7 @@ def setup_method(self, method): ]: return self.project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) @@ -110,7 +111,7 @@ def test_get_model_endpoint_metrics(self): tsdb_connection_string=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) self.project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) db = mlrun.get_run_db() @@ -534,7 +535,7 @@ def test_basic_model_monitoring(self) -> None: project = self.project project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, replace_creds=True, # remove once ML-7501 is resolved ) @@ -1141,7 +1142,7 @@ def test_batch_drift(self): # Deploy model monitoring infra project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + 
stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) project.enable_model_monitoring( @@ -1371,7 +1372,7 @@ def custom_setup(self) -> None: mlrun.runtimes.utils.global_context.set(None) # Set the model monitoring credentials self.project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) @@ -1510,7 +1511,7 @@ def _test_v3io_tsdb_record(cls) -> None: def test_record(self) -> None: self.project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) self.project.enable_model_monitoring( @@ -1548,7 +1549,7 @@ def test_model_endpoint_with_many_features(self) -> None: project = self.project project.set_model_monitoring_credentials( - stream_path=mlrun.mlconf.model_endpoint_monitoring.stream_connection, + stream_path=os.getenv("MLRUN_MODEL_ENDPOINT_MONITORING__STREAM_CONNECTION"), tsdb_connection=mlrun.mlconf.model_endpoint_monitoring.tsdb_connection, ) From 35d22985f375d3b2b50b3bbb2c2c5b20ea907fb1 Mon Sep 17 00:00:00 2001 From: davesh0812 <85231462+davesh0812@users.noreply.github.com> Date: Tue, 7 Jan 2025 12:00:05 +0200 Subject: [PATCH 12/15] [Tests] Fix TestInferenceWithSpecialChars::test_inference_feature_set & TestRecordResults::test_inference_feature_set (#7060) --- tests/system/model_monitoring/test_app.py | 4 +--- .../system/model_monitoring/test_model_monitoring.py | 12 +++++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/system/model_monitoring/test_app.py b/tests/system/model_monitoring/test_app.py index 5dba9685980..e8a46074141 
100644 --- a/tests/system/model_monitoring/test_app.py +++ b/tests/system/model_monitoring/test_app.py @@ -751,8 +751,6 @@ def custom_setup_class(cls) -> None: cls.training_set = cls.x_train.join(cls.y_train) cls.test_set = cls.x_test.join(cls.y_test) cls.infer_results_df = cls.test_set - # endpoint - cls.endpoint_id = "58d42fdd76ad999c377fad1adcafd2790b5a89b9" cls.function_name = f"{cls.name_prefix}-function" # training cls._train() @@ -847,7 +845,7 @@ def test_inference_feature_set(self) -> None: self._test_v3io_records( mep.metadata.uid, inputs=set(self.columns), outputs=set(self.y_name) ) - self._test_predictions_table(self.endpoint_id, should_be_empty=True) + self._test_predictions_table(mep.metadata.uid, should_be_empty=True) @TestMLRunSystem.skip_test_if_env_not_configured diff --git a/tests/system/model_monitoring/test_model_monitoring.py b/tests/system/model_monitoring/test_model_monitoring.py index 6d43bb78b69..699bb96199b 100644 --- a/tests/system/model_monitoring/test_model_monitoring.py +++ b/tests/system/model_monitoring/test_model_monitoring.py @@ -1364,8 +1364,8 @@ def custom_setup_class(cls) -> None: cls.infer_results_df[mlrun.common.schemas.EventFieldType.TIMESTAMP] = ( mlrun.utils.datetime_now() ) - cls.endpoint_id = "5d6ce0e704442c0ac59a933cb4d238baba83bb5d" cls.function_name = f"{cls.name_prefix}-function" + cls.model_endpoint_name = f"{cls.name_prefix}-test" cls._train() def custom_setup(self) -> None: @@ -1393,10 +1393,13 @@ def _train(cls) -> None: def _get_monitoring_feature_set(self) -> mlrun.feature_store.FeatureSet: model_endpoint = mlrun.get_run_db().get_model_endpoint( - project=self.project_name, endpoint_id=self.endpoint_id, name="testsssssss" + project=self.project_name, + name=self.model_endpoint_name, + function_name=self.function_name, + function_tag="latest", ) return mlrun.feature_store.get_feature_set( - model_endpoint.status.monitoring_feature_set_uri + model_endpoint.spec.monitoring_feature_set_uri ) def 
_test_feature_names(self) -> None: @@ -1430,9 +1433,8 @@ def test_inference_feature_set(self) -> None: model_path=self.project.get_artifact_uri( key=self.model_name, category="model", tag="latest" ), - model_endpoint_name=f"{self.name_prefix}-test", + model_endpoint_name=self.model_endpoint_name, function_name=self.function_name, - endpoint_id=self.endpoint_id, context=mlrun.get_or_create_ctx(name=f"{self.name_prefix}-context"), # pyright: ignore[reportGeneralTypeIssues] infer_results_df=self.infer_results_df, # TODO: activate ad-hoc mode when ML-5792 is done From 643d691eaa47968b09de96642d4cde6cecf13b31 Mon Sep 17 00:00:00 2001 From: TomerShor <90552140+TomerShor@users.noreply.github.com> Date: Tue, 7 Jan 2025 12:03:36 +0200 Subject: [PATCH 13/15] [Datastore] Fix Kafka source SASL configuration (#7067) --- mlrun/datastore/datastore_profile.py | 3 ++- mlrun/datastore/sources.py | 3 ++- tests/datastore/test_base.py | 9 ++++++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/mlrun/datastore/datastore_profile.py b/mlrun/datastore/datastore_profile.py index 8fed2506c71..c71ea28dca9 100644 --- a/mlrun/datastore/datastore_profile.py +++ b/mlrun/datastore/datastore_profile.py @@ -211,9 +211,10 @@ def attributes(self): attributes["partitions"] = self.partitions sasl = attributes.pop("sasl", {}) if self.sasl_user and self.sasl_pass: - sasl["enabled"] = True + sasl["enable"] = True sasl["user"] = self.sasl_user sasl["password"] = self.sasl_pass + sasl["mechanism"] = "PLAIN" if sasl: attributes["sasl"] = sasl return attributes diff --git a/mlrun/datastore/sources.py b/mlrun/datastore/sources.py index c13fe2b60aa..784538c2de0 100644 --- a/mlrun/datastore/sources.py +++ b/mlrun/datastore/sources.py @@ -1089,9 +1089,10 @@ def __init__( attributes["partitions"] = partitions sasl = attributes.pop("sasl", {}) if sasl_user and sasl_pass: - sasl["enabled"] = True + sasl["enable"] = True sasl["user"] = sasl_user sasl["password"] = sasl_pass + sasl["mechanism"] = 
"PLAIN" if sasl: attributes["sasl"] = sasl super().__init__(attributes=attributes, **kwargs) diff --git a/tests/datastore/test_base.py b/tests/datastore/test_base.py index e6d1f631d4d..6309edce4c4 100644 --- a/tests/datastore/test_base.py +++ b/tests/datastore/test_base.py @@ -81,9 +81,10 @@ def test_kafka_source_with_attributes(): assert attributes["topics"] == ["mytopic"] assert attributes["consumerGroup"] == "mygroup" assert attributes["sasl"] == { - "enabled": True, + "enable": True, "user": "myuser", "password": "mypassword", + "mechanism": "PLAIN", "handshake": True, } @@ -116,9 +117,10 @@ def test_kafka_source_with_attributes_as_ds_profile(): assert attributes["topics"] == ["mytopic"] assert attributes["consumerGroup"] == "mygroup" assert attributes["sasl"] == { - "enabled": True, + "enable": True, "user": "myuser", "password": "mypassword", + "mechanism": "PLAIN", "handshake": True, } @@ -173,9 +175,10 @@ def test_kafka_source_without_attributes(): assert attributes["topics"] == ["mytopic"] assert attributes["consumerGroup"] == "mygroup" assert attributes["sasl"] == { - "enabled": True, + "enable": True, "user": "myuser", "password": "mypassword", + "mechanism": "PLAIN", } From be5271326750acfe2e03ab358bdf096920a3177c Mon Sep 17 00:00:00 2001 From: Katerina Molchanova <35141662+rokatyy@users.noreply.github.com> Date: Tue, 7 Jan 2025 11:18:35 +0000 Subject: [PATCH 14/15] [Tests] Fix sleep interval (#7069) title says it all --- mlrun/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlrun/config.py b/mlrun/config.py index 77038dc3f36..c072cb3fcc2 100644 --- a/mlrun/config.py +++ b/mlrun/config.py @@ -799,7 +799,7 @@ # maximum allowed value for count in criteria field inside AlertConfig "max_criteria_count": 100, # interval for periodic events generation job - "events_generation_interval": "30", + "events_generation_interval": 30, # seconds }, "auth_with_client_id": { "enabled": False, From 98c1f95a8ccc2e0d9de01e1ea98c5cec4be51e0e Mon 
Sep 17 00:00:00 2001 From: Roy Schossberger <85231212+royischoss@users.noreply.github.com> Date: Tue, 7 Jan 2025 13:35:26 +0200 Subject: [PATCH 15/15] [Model Monitoring] Controller stream, chief worker implementation (#7045) --- .../schemas/model_monitoring/constants.py | 19 + mlrun/config.py | 29 +- mlrun/datastore/sources.py | 5 + mlrun/model_monitoring/controller.py | 369 +++++++++++++----- .../db/tsdb/tdengine/tdengine_connector.py | 2 +- .../db/tsdb/v3io/v3io_connector.py | 4 +- mlrun/model_monitoring/stream_processing.py | 74 +++- .../api/crud/model_monitoring/deployment.py | 36 +- .../test_stream_processing.py | 9 +- tests/system/model_monitoring/test_app.py | 11 +- 10 files changed, 429 insertions(+), 129 deletions(-) diff --git a/mlrun/common/schemas/model_monitoring/constants.py b/mlrun/common/schemas/model_monitoring/constants.py index 807c75a3219..7fd8fd86bf2 100644 --- a/mlrun/common/schemas/model_monitoring/constants.py +++ b/mlrun/common/schemas/model_monitoring/constants.py @@ -183,6 +183,25 @@ class WriterEventKind(MonitoringStrEnum): STATS = "stats" +class ControllerEvent(MonitoringStrEnum): + KIND = "kind" + ENDPOINT_ID = "endpoint_id" + ENDPOINT_NAME = "endpoint_name" + PROJECT = "project" + TIMESTAMP = "timestamp" + FIRST_REQUEST = "first_request" + FEATURE_SET_URI = "feature_set_uri" + ENDPOINT_TYPE = "endpoint_type" + ENDPOINT_POLICY = "endpoint_policy" + # Note: currently under endpoint policy we will have a dictionary including the keys: "application_names" + # and "base_period" + + +class ControllerEventKind(MonitoringStrEnum): + NOP_EVENT = "nop_event" + REGULAR_EVENT = "regular_event" + + class MetricData(MonitoringStrEnum): METRIC_NAME = "metric_name" METRIC_VALUE = "metric_value" diff --git a/mlrun/config.py b/mlrun/config.py index c072cb3fcc2..e969361f862 100644 --- a/mlrun/config.py +++ b/mlrun/config.py @@ -596,6 +596,22 @@ "max_replicas": 1, }, }, + "controller_stream_args": { + "v3io": { + "shard_count": 10, + 
"retention_period_hours": 24, + "num_workers": 10, + "min_replicas": 1, + "max_replicas": 1, + }, + "kafka": { + "partition_count": 10, + "replication_factor": 1, + "num_workers": 10, + "min_replicas": 1, + "max_replicas": 1, + }, + }, # Store prefixes are used to handle model monitoring storing policies based on project and kind, such as events, # stream, and endpoints. "store_prefixes": { @@ -1282,6 +1298,8 @@ def get_model_monitoring_file_target_path( function_name and function_name != mlrun.common.schemas.model_monitoring.constants.MonitoringFunctionNames.STREAM + and function_name + != mlrun.common.schemas.model_monitoring.constants.MonitoringFunctionNames.APPLICATION_CONTROLLER ): return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.user_space.format( project=project, @@ -1289,12 +1307,21 @@ def get_model_monitoring_file_target_path( if function_name is None else f"{kind}-{function_name.lower()}", ) - elif kind == "stream": + elif ( + kind == "stream" + and function_name + != mlrun.common.schemas.model_monitoring.constants.MonitoringFunctionNames.APPLICATION_CONTROLLER + ): return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.user_space.format( project=project, kind=kind, ) else: + if ( + function_name + == mlrun.common.schemas.model_monitoring.constants.MonitoringFunctionNames.APPLICATION_CONTROLLER + ): + kind = function_name return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format( project=project, kind=kind, diff --git a/mlrun/datastore/sources.py b/mlrun/datastore/sources.py index 784538c2de0..5d36469d7cc 100644 --- a/mlrun/datastore/sources.py +++ b/mlrun/datastore/sources.py @@ -1128,8 +1128,13 @@ def add_nuclio_trigger(self, function): extra_attributes["workerAllocationMode"] = extra_attributes.get( "worker_allocation_mode", "static" ) + else: + extra_attributes["workerAllocationMode"] = extra_attributes.get( + "worker_allocation_mode", "pool" + ) trigger_kwargs = {} + if "max_workers" in extra_attributes: 
trigger_kwargs = {"max_workers": extra_attributes.pop("max_workers")} diff --git a/mlrun/model_monitoring/controller.py b/mlrun/model_monitoring/controller.py index 5188d180855..bfe3a154fc6 100644 --- a/mlrun/model_monitoring/controller.py +++ b/mlrun/model_monitoring/controller.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import concurrent.futures import datetime import json import os from collections.abc import Iterator from contextlib import AbstractContextManager from types import TracebackType -from typing import NamedTuple, Optional, cast +from typing import Any, NamedTuple, Optional, cast import nuclio_sdk @@ -28,6 +27,10 @@ import mlrun.feature_store as fstore import mlrun.model_monitoring from mlrun.common.schemas import EndpointType +from mlrun.common.schemas.model_monitoring.constants import ( + ControllerEvent, + ControllerEventKind, +) from mlrun.datastore import get_stream_pusher from mlrun.errors import err_to_str from mlrun.model_monitoring.db._schedules import ModelMonitoringSchedulesFile @@ -140,6 +143,7 @@ def __init__(self, project: str, endpoint_id: str, window_length: int) -> None: Initialize a batch window generator object that generates batch window objects for the monitoring functions. """ + self.batch_window: _BatchWindow = None self._project = project self._endpoint_id = endpoint_id self._timedelta = window_length @@ -199,14 +203,14 @@ def get_intervals( `first_request` and `last_request` are the timestamps of the first request and last request to the endpoint, respectively. They are guaranteed to be nonempty at this point. 
""" - batch_window = _BatchWindow( + self.batch_window = _BatchWindow( schedules_file=self._schedules_file, application=application, timedelta_seconds=self._timedelta, last_updated=self._get_last_updated_time(last_request, not_batch_endpoint), first_request=int(first_request.timestamp()), ) - yield from batch_window.get_intervals() + yield from self.batch_window.get_intervals() def _get_window_length() -> int: @@ -237,6 +241,7 @@ def __init__(self) -> None: self._window_length = _get_window_length() self.model_monitoring_access_key = self._get_model_monitoring_access_key() + self.v3io_access_key = mlrun.get_secret_or_env("V3IO_ACCESS_KEY") self.storage_options = None if mlrun.mlconf.artifact_path.startswith("s3://"): self.storage_options = mlrun.mlconf.get_s3_storage_options() @@ -262,112 +267,65 @@ def _should_monitor_endpoint(endpoint: mlrun.common.schemas.ModelEndpoint) -> bo != mm_constants.EndpointType.ROUTER.value ) - def run(self) -> None: + def run(self, event: nuclio_sdk.Event) -> None: """ - Main method for run all the relevant monitoring applications on each endpoint. + Main method for controller chief, runs all the relevant monitoring applications for a single endpoint. + Handles nop events logic. This method handles the following: - 1. List model endpoints - 2. List applications - 3. Check model monitoring windows - 4. Send data to applications - 5. Delete old parquets + 1. Read applications from the event (endpoint_policy) + 2. Check model monitoring windows + 3. Send data to applications + 4. 
Pushes nop event to main stream if needed """ - logger.info("Start running monitoring controller") + logger.info("Start running monitoring controller worker") try: - applications_names = [] - endpoints_list = mlrun.db.get_run_db().list_model_endpoints( - project=self.project, tsdb_metrics=True - ) - endpoints = endpoints_list.endpoints - if not endpoints: - logger.info("No model endpoints found", project=self.project) - return - monitoring_functions = self.project_obj.list_model_monitoring_functions() - if monitoring_functions: - applications_names = list( - {app.metadata.name for app in monitoring_functions} - ) - # if monitoring_functions: - TODO : ML-7700 - # Gets only application in ready state - # applications_names = list( - # { - # app.metadata.name - # for app in monitoring_functions - # if ( - # app.status.state == "ready" - # # workaround for the default app, as its `status.state` is `None` - # or app.metadata.name - # == mm_constants.HistogramDataDriftApplicationConstants.NAME - # ) - # } - # ) - if not applications_names: - logger.info("No monitoring functions found", project=self.project) - return - logger.info( - "Starting to iterate over the applications", - applications=applications_names, - ) - + body = json.loads(event.body.decode("utf-8")) except Exception as e: logger.error( - "Failed to list endpoints and monitoring applications", + "Failed to decode event", exc=err_to_str(e), ) return - # Initialize a thread pool that will be used to monitor each endpoint on a dedicated thread - with concurrent.futures.ThreadPoolExecutor( - max_workers=min(len(endpoints), 10) - ) as pool: - for endpoint in endpoints: - if self._should_monitor_endpoint(endpoint): - pool.submit( - MonitoringApplicationController.model_endpoint_process, - project=self.project, - endpoint=endpoint, - applications_names=applications_names, - window_length=self._window_length, - model_monitoring_access_key=self.model_monitoring_access_key, - storage_options=self.storage_options, - ) 
- else: - logger.debug( - "Skipping endpoint, not ready or not suitable for monitoring", - endpoint_id=endpoint.metadata.uid, - endpoint_name=endpoint.metadata.name, - ) - logger.info("Finished running monitoring controller") + # Run single endpoint process + self.model_endpoint_process(event=body) - @classmethod def model_endpoint_process( - cls, - project: str, - endpoint: mlrun.common.schemas.ModelEndpoint, - applications_names: list[str], - window_length: int, - model_monitoring_access_key: str, - storage_options: Optional[dict] = None, + self, + event: Optional[dict] = None, ) -> None: """ Process a model endpoint and trigger the monitoring applications. This function running on different process - for each endpoint. In addition, this function will generate a parquet file that includes the relevant data - for a specific time range. - - :param endpoint: (dict) Model endpoint record. - :param applications_names: (list[str]) List of application names to push results to. - :param batch_window_generator: (_BatchWindowGenerator) An object that generates _BatchWindow objects. - :param project: (str) Project name. - :param model_monitoring_access_key: (str) Access key to apply the model monitoring process. - :param storage_options: (dict) Storage options for reading the infer parquet files. + for each endpoint. + + :param event: (dict) Event that triggered the monitoring process. 
""" - endpoint_id = endpoint.metadata.uid - not_batch_endpoint = not ( - endpoint.metadata.endpoint_type == EndpointType.BATCH_EP - ) - m_fs = fstore.get_feature_set(endpoint.spec.monitoring_feature_set_uri) + logger.info("Model endpoint process started", event=event) + try: + project_name = event[ControllerEvent.PROJECT] + endpoint_id = event[ControllerEvent.ENDPOINT_ID] + endpoint_name = event[ControllerEvent.ENDPOINT_NAME] + applications_names = event[ControllerEvent.ENDPOINT_POLICY][ + "monitoring_applications" + ] + + not_batch_endpoint = ( + event[ControllerEvent.ENDPOINT_POLICY] != EndpointType.BATCH_EP + ) + m_fs = fstore.get_feature_set(event[ControllerEvent.FEATURE_SET_URI]) + logger.info( + "Starting analyzing for:", timestamp=event[ControllerEvent.TIMESTAMP] + ) + last_stream_timestamp = datetime.datetime.fromisoformat( + event[ControllerEvent.TIMESTAMP] + ) + first_request = datetime.datetime.fromisoformat( + event[ControllerEvent.FIRST_REQUEST] + ) with _BatchWindowGenerator( - project=project, endpoint_id=endpoint_id, window_length=window_length + project=project_name, + endpoint_id=endpoint_id, + window_length=self._window_length, ) as batch_window_generator: for application in applications_names: for ( @@ -375,15 +333,15 @@ def model_endpoint_process( end_infer_time, ) in batch_window_generator.get_intervals( application=application, - first_request=endpoint.status.first_request, - last_request=endpoint.status.last_request, not_batch_endpoint=not_batch_endpoint, + first_request=first_request, + last_request=last_stream_timestamp, ): df = m_fs.to_dataframe( start_time=start_infer_time, end_time=end_infer_time, time_column=mm_constants.EventFieldType.TIMESTAMP, - storage_options=storage_options, + storage_options=self.storage_options, ) if len(df) == 0: logger.info( @@ -399,21 +357,53 @@ def model_endpoint_process( end=end_infer_time, endpoint_id=endpoint_id, ) - cls._push_to_applications( + self._push_to_applications( 
start_infer_time=start_infer_time, end_infer_time=end_infer_time, endpoint_id=endpoint_id, - endpoint_name=endpoint.metadata.name, - project=project, + endpoint_name=endpoint_name, + project=project_name, applications_names=[application], - model_monitoring_access_key=model_monitoring_access_key, + model_monitoring_access_key=self.model_monitoring_access_key, ) - logger.info("Finished processing endpoint", endpoint_id=endpoint_id) + base_period = event[ControllerEvent.ENDPOINT_POLICY]["base_period"] + current_time = mlrun.utils.datetime_now() + if ( + current_time.timestamp() + - batch_window_generator.batch_window._get_last_analyzed() + >= datetime.timedelta(minutes=base_period).total_seconds() + and event[ControllerEvent.KIND] != ControllerEventKind.NOP_EVENT + ): + event = { + ControllerEvent.KIND: mm_constants.ControllerEventKind.NOP_EVENT, + ControllerEvent.PROJECT: project_name, + ControllerEvent.ENDPOINT_ID: endpoint_id, + ControllerEvent.ENDPOINT_NAME: endpoint_name, + ControllerEvent.TIMESTAMP: current_time.isoformat( + timespec="microseconds" + ), + ControllerEvent.ENDPOINT_POLICY: event[ + ControllerEvent.ENDPOINT_POLICY + ], + ControllerEvent.ENDPOINT_TYPE: event[ + ControllerEvent.ENDPOINT_TYPE + ], + ControllerEvent.FEATURE_SET_URI: event[ + ControllerEvent.FEATURE_SET_URI + ], + ControllerEvent.FIRST_REQUEST: event[ + ControllerEvent.FIRST_REQUEST + ], + } + self._push_to_main_stream( + event=event, + endpoint_id=endpoint_id, + ) except Exception: logger.exception( "Encountered an exception", - endpoint_id=endpoint.metadata.uid, + endpoint_id=event[ControllerEvent.ENDPOINT_ID], ) @staticmethod @@ -465,6 +455,168 @@ def _push_to_applications( [data] ) + def push_regular_event_to_controller_stream(self, event: nuclio_sdk.Event) -> None: + """ + pushes a regular event to the controller stream. 
+ :param event: the nuclio trigger event + """ + logger.info("Starting monitoring controller chief") + applications_names = [] + db = mlrun.get_run_db() + endpoints = db.list_model_endpoints( + project=self.project, tsdb_metrics=True + ).endpoints + if not endpoints: + logger.info("No model endpoints found", project=self.project) + return + monitoring_functions = self.project_obj.list_model_monitoring_functions() + if monitoring_functions: + # if monitoring_functions: - TODO : ML-7700 + # Gets only application in ready state + # applications_names = list( + # { + # app.metadata.name + # for app in monitoring_functions + # if ( + # app.status.state == "ready" + # # workaround for the default app, as its `status.state` is `None` + # or app.metadata.name + # == mm_constants.HistogramDataDriftApplicationConstants.NAME + # ) + # } + # ) + applications_names = list( + {app.metadata.name for app in monitoring_functions} + ) + if not applications_names: + logger.info("No monitoring functions found", project=self.project) + return + policy = { + "monitoring_applications": applications_names, + "base_period": int( + batch_dict2timedelta( + json.loads( + cast( + str, + os.getenv(mm_constants.EventFieldType.BATCH_INTERVALS_DICT), + ) + ) + ).total_seconds() + // 60 + ), + } + for endpoint in endpoints: + if self._should_monitor_endpoint(endpoint): + logger.info( + "Regular event is being pushed to controller stream for model endpoint", + endpoint_id=endpoint.metadata.uid, + endpoint_name=endpoint.metadata.name, + timestamp=endpoint.status.last_request.isoformat( + sep=" ", timespec="microseconds" + ), + first_request=endpoint.status.first_request.isoformat( + sep=" ", timespec="microseconds" + ), + endpoint_type=endpoint.metadata.endpoint_type, + feature_set_uri=endpoint.spec.monitoring_feature_set_uri, + endpoint_policy=json.dumps(policy), + ) + self.push_to_controller_stream( + kind=mm_constants.ControllerEventKind.REGULAR_EVENT, + project=self.project, + 
endpoint_id=endpoint.metadata.uid, + endpoint_name=endpoint.metadata.name, + stream_access_key=self.v3io_access_key, + timestamp=endpoint.status.last_request.isoformat( + sep=" ", timespec="microseconds" + ), + first_request=endpoint.status.first_request.isoformat( + sep=" ", timespec="microseconds" + ), + endpoint_type=endpoint.metadata.endpoint_type, + feature_set_uri=endpoint.spec.monitoring_feature_set_uri, + endpoint_policy=policy, + ) + else: + logger.info( + "Should not monitor model endpoint, didn't push regular event", + endpoint_id=endpoint.metadata.uid, + endpoint_name=endpoint.metadata.name, + timestamp=endpoint.status.last_request, + first_request=endpoint.status.first_request, + endpoint_type=endpoint.metadata.endpoint_type, + feature_set_uri=endpoint.spec.monitoring_feature_set_uri, + ) + + @staticmethod + def push_to_controller_stream( + kind: str, + project: str, + endpoint_id: str, + endpoint_name: str, + stream_access_key: str, + timestamp: str, + first_request: str, + endpoint_type: str, + feature_set_uri: str, + endpoint_policy: dict[str, Any], + ) -> None: + """ + Pushes event data to controller stream. + :param timestamp: the event timestamp str isoformat utc timezone + :param first_request: the first request str isoformat utc timezone + :param endpoint_policy: dictionary hold the monitoring policy + :param kind: str event kind + :param project: project name + :param endpoint_id: endpoint id string + :param endpoint_name: the endpoint name string + :param endpoint_type: Enum of the endpoint type + :param feature_set_uri: the feature set uri string + :param stream_access_key: access key to apply the model monitoring process. 
+ """ + stream_uri = get_stream_path( + project=project, + function_name=mm_constants.MonitoringFunctionNames.APPLICATION_CONTROLLER, + ) + event = { + ControllerEvent.KIND.value: kind, + ControllerEvent.PROJECT.value: project, + ControllerEvent.ENDPOINT_ID.value: endpoint_id, + ControllerEvent.ENDPOINT_NAME.value: endpoint_name, + ControllerEvent.TIMESTAMP.value: timestamp, + ControllerEvent.FIRST_REQUEST.value: first_request, + ControllerEvent.ENDPOINT_TYPE.value: endpoint_type, + ControllerEvent.FEATURE_SET_URI.value: feature_set_uri, + ControllerEvent.ENDPOINT_POLICY.value: endpoint_policy, + } + logger.info( + "Pushing data to controller stream", + event=event, + endpoint_id=endpoint_id, + stream_uri=stream_uri, + ) + get_stream_pusher(stream_uri, access_key=stream_access_key).push( + [event], partition_key=endpoint_id + ) + + def _push_to_main_stream(self, event: dict, endpoint_id: str) -> None: + """ + Pushes the given event to model monitoring stream + :param event: event dictionary to push to stream + :param endpoint_id: endpoint id string + """ + stream_uri = get_stream_path(project=event.get(ControllerEvent.PROJECT)) + + logger.info( + "Pushing data to main stream, NOP event is been generated", + event=json.dumps(event), + endpoint_id=endpoint_id, + stream_uri=stream_uri, + ) + get_stream_pusher(stream_uri, access_key=self.model_monitoring_access_key).push( + [event], partition_key=endpoint_id + ) + def handler(context: nuclio_sdk.Context, event: nuclio_sdk.Event) -> None: """ @@ -473,4 +625,15 @@ def handler(context: nuclio_sdk.Context, event: nuclio_sdk.Event) -> None: :param context: the Nuclio context :param event: trigger event """ - MonitoringApplicationController().run() + logger.info( + "Controller got event", + trigger=event.trigger, + trigger_kind=event.trigger.kind, + ) + + if event.trigger.kind == "http": + # Runs controller chief: + MonitoringApplicationController().push_regular_event_to_controller_stream(event) + else: + # Runs controller 
worker: + MonitoringApplicationController().run(event=event) diff --git a/mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py b/mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py index 7480a5c5934..6901b5b1af0 100644 --- a/mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +++ b/mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py @@ -188,7 +188,7 @@ def apply_process_before_tsdb(): graph.add_step( "mlrun.model_monitoring.db.tsdb.tdengine.stream_graph_steps.ProcessBeforeTDEngine", name="ProcessBeforeTDEngine", - after="MapFeatureNames", + after="FilterNOP", ) def apply_tdengine_target(name, after): diff --git a/mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py b/mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py index eb03cad7184..70ee7e5af15 100644 --- a/mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +++ b/mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py @@ -204,7 +204,7 @@ def apply_storey_aggregations(): } ], name=EventFieldType.LATENCY, - after="MapFeatureNames", + after="FilterNOP", step_name="Aggregates", table=".", key_field=EventFieldType.ENDPOINT_ID, @@ -225,7 +225,7 @@ def apply_storey_aggregations(): graph.add_step( "storey.TSDBTarget", name="tsdb_predictions", - after="MapFeatureNames", + after="FilterNOP", path=f"{self.container}/{self.tables[mm_schemas.FileTargetKind.PREDICTIONS]}", rate="1/s", time_col=mm_schemas.EventFieldType.TIMESTAMP, diff --git a/mlrun/model_monitoring/stream_processing.py b/mlrun/model_monitoring/stream_processing.py index 073b91f2fc7..dc6fc63af9f 100644 --- a/mlrun/model_monitoring/stream_processing.py +++ b/mlrun/model_monitoring/stream_processing.py @@ -29,11 +29,14 @@ import mlrun.serving.states import mlrun.utils from mlrun.common.schemas.model_monitoring.constants import ( + ControllerEvent, + ControllerEventKind, EndpointType, EventFieldType, FileTargetKind, ProjectSecretKeys, ) +from mlrun.datastore import parse_kafka_url from mlrun.model_monitoring.db import 
TSDBConnector from mlrun.utils import logger @@ -88,7 +91,9 @@ def _initialize_v3io_configurations( self.v3io_framesd = v3io_framesd or mlrun.mlconf.v3io_framesd self.v3io_api = v3io_api or mlrun.mlconf.v3io_api - self.v3io_access_key = v3io_access_key or os.environ.get("V3IO_ACCESS_KEY") + self.v3io_access_key = v3io_access_key or mlrun.get_secret_or_env( + "V3IO_ACCESS_KEY" + ) self.model_monitoring_access_key = ( model_monitoring_access_key or os.environ.get(ProjectSecretKeys.ACCESS_KEY) @@ -118,6 +123,7 @@ def apply_monitoring_serving_graph( self, fn: mlrun.runtimes.ServingRuntime, tsdb_connector: TSDBConnector, + controller_stream_uri: str, ) -> None: """ Apply monitoring serving graph to a given serving function. The following serving graph includes about 4 main @@ -146,6 +152,8 @@ def apply_monitoring_serving_graph( :param fn: A serving function. :param tsdb_connector: Time series database connector. + :param controller_stream_uri: The controller stream URI. Runs on server api pod so needed to be provided as + input """ graph = typing.cast( @@ -209,6 +217,20 @@ def apply_map_feature_names(): ) apply_map_feature_names() + # split the graph between event with error vs valid event + graph.add_step( + "storey.Filter", + "FilterNOP", + after="MapFeatureNames", + _fn="(event.get('kind', " ") != 'nop_event')", + ) + graph.add_step( + "storey.Filter", + "ForwardNOP", + after="MapFeatureNames", + _fn="(event.get('kind', " ") == 'nop_event')", + ) + tsdb_connector.apply_monitoring_stream_steps( graph=graph, aggregate_windows=self.aggregate_windows, @@ -221,7 +243,7 @@ def apply_process_before_parquet(): graph.add_step( "ProcessBeforeParquet", name="ProcessBeforeParquet", - after="MapFeatureNames", + after="FilterNOP", _fn="(event)", ) @@ -248,6 +270,44 @@ def apply_parquet_target(): apply_parquet_target() + # controller branch + def apply_push_controller_stream(stream_uri: str): + if stream_uri.startswith("v3io://"): + graph.add_step( + ">>", + 
"controller_stream_v3io", + path=stream_uri, + sharding_func=ControllerEvent.ENDPOINT_ID, + access_key=self.v3io_access_key, + after="ForwardNOP", + ) + elif stream_uri.startswith("kafka://"): + topic, brokers = parse_kafka_url(stream_uri) + logger.info( + "Controller stream uri for kafka", + stream_uri=stream_uri, + topic=topic, + brokers=brokers, + ) + if isinstance(brokers, list): + path = f"kafka://{brokers[0]}/{topic}" + elif isinstance(brokers, str): + path = f"kafka://{brokers}/{topic}" + else: + raise mlrun.errors.MLRunInvalidArgumentError( + "Brokers must be a list or str check controller stream uri" + ) + graph.add_step( + ">>", + "controller_stream_kafka", + path=path, + kafka_brokers=brokers, + _sharding_func="kafka_sharding_func", # TODO: remove this when storey handle str key + after="ForwardNOP", + ) + + apply_push_controller_stream(controller_stream_uri) + class ProcessBeforeParquet(mlrun.feature_store.steps.MapClass): def __init__(self, **kwargs): @@ -321,6 +381,9 @@ def __init__( def do(self, full_event): event = full_event.body + if event.get(ControllerEvent.KIND, "") == ControllerEventKind.NOP_EVENT: + logger.info("Skipped nop event inside of ProcessEndpointEvent", event=event) + return storey.Event(body=[event]) # Getting model version and function uri from event # and use them for retrieving the endpoint_id function_uri = full_event.body.get(EventFieldType.FUNCTION_URI) @@ -589,6 +652,9 @@ def _infer_label_columns_from_data(self, event): return None def do(self, event: dict): + if event.get(ControllerEvent.KIND, "") == ControllerEventKind.NOP_EVENT: + logger.info("Skipped nop event inside of MapFeatureNames", event=event) + return event endpoint_id = event[EventFieldType.ENDPOINT_ID] feature_values = event[EventFieldType.FEATURES] @@ -827,3 +893,7 @@ def update_monitoring_feature_set( ) monitoring_feature_set.save() + + +def kafka_sharding_func(event): + return event.body[ControllerEvent.ENDPOINT_ID].encode("UTF-8") diff --git 
a/server/py/services/api/crud/model_monitoring/deployment.py b/server/py/services/api/crud/model_monitoring/deployment.py index dea9abb0fc6..cacb400e635 100644 --- a/server/py/services/api/crud/model_monitoring/deployment.py +++ b/server/py/services/api/crud/model_monitoring/deployment.py @@ -328,7 +328,8 @@ def apply_and_create_stream_trigger( function=function, function_name=function_name ) - function.spec.disable_default_http_trigger = True + if function_name != mm_constants.MonitoringFunctionNames.APPLICATION_CONTROLLER: + function.spec.disable_default_http_trigger = True return function @@ -346,7 +347,10 @@ def _apply_and_create_kafka_source( stream_source = mlrun.datastore.sources.KafkaSource( brokers=brokers, topics=[topic], - attributes={"max_workers": stream_args.kafka.num_workers}, + attributes={ + "max_workers": stream_args.kafka.num_workers, + "worker_allocation_mode": "static", + }, ) try: stream_source.create_topics( @@ -375,11 +379,16 @@ def _apply_and_create_v3io_source( function_name: str, stream_args: mlrun.config.Config, ): - access_key = self.model_monitoring_access_key - kwargs = {"access_key": self.model_monitoring_access_key} + access_key = ( + self.model_monitoring_access_key + if function_name + != mm_constants.MonitoringFunctionNames.APPLICATION_CONTROLLER + else mlrun.mlconf.get_v3io_access_key() + ) + kwargs = {"access_key": access_key} if mlrun.mlconf.is_explicit_ack_enabled(): kwargs["explicit_ack_mode"] = "explicitOnly" - kwargs["worker_allocation_mode"] = "static" + kwargs["worker_allocation_mode"] = "static" kwargs["max_workers"] = stream_args.v3io.num_workers services.api.api.endpoints.nuclio.create_model_monitoring_stream( project=self.project, @@ -444,10 +453,15 @@ def _initial_model_monitoring_stream_processing_function( project=self.project, secret_provider=self._secret_provider ) + controller_stream_uri = mlrun.model_monitoring.get_stream_path( + project=self.project, + 
function_name=mm_constants.MonitoringFunctionNames.APPLICATION_CONTROLLER, + secret_provider=self._secret_provider, + ) + # Create monitoring serving graph stream_processor.apply_monitoring_serving_graph( - function, - tsdb_connector, + function, tsdb_connector, controller_stream_uri ) # Set the project to the serving function @@ -489,11 +503,17 @@ def _get_model_monitoring_controller_function(self, image: str): # Set the project to the job function function.metadata.project = self.project + # Add stream triggers + function = self.apply_and_create_stream_trigger( + function=function, + function_name=mm_constants.MonitoringFunctionNames.APPLICATION_CONTROLLER, + stream_args=config.model_endpoint_monitoring.controller_stream_args, + ) + function = self._apply_access_key_and_mount_function( function=function, function_name=mm_constants.MonitoringFunctionNames.APPLICATION_CONTROLLER, ) - function.spec.max_replicas = 1 # Enrich runtime with the required configurations framework.api.utils.apply_enrichment_and_validation_on_function( function, self.auth_info diff --git a/tests/model_monitoring/test_stream_processing.py b/tests/model_monitoring/test_stream_processing.py index 2db1c72f7f8..9ad57b2c6e5 100644 --- a/tests/model_monitoring/test_stream_processing.py +++ b/tests/model_monitoring/test_stream_processing.py @@ -20,7 +20,8 @@ @pytest.mark.parametrize("tsdb_connector", ["v3io", "taosws"]) -def test_plot_monitoring_serving_graph(tsdb_connector): +@pytest.mark.parametrize("stream_path", ["v3io", "kafka://192.168.226.176:9092/topic"]) +def test_plot_monitoring_serving_graph(tsdb_connector, stream_path): project_name = "test-stream-processing" project = mlrun.get_or_create_project(project_name) @@ -40,11 +41,13 @@ def test_plot_monitoring_serving_graph(tsdb_connector): project=project_name, tsdb_connection_string=tsdb_connector ) - processor.apply_monitoring_serving_graph(fn, tsdb_connector) + processor.apply_monitoring_serving_graph(fn, tsdb_connector, stream_path) 
graph = fn.spec.graph.plot(rankdir="TB") print() - print(f"Graphviz graph definition with tsdb_connector={tsdb_connector}") + print( + f"Graphviz graph definition with tsdb_connector={tsdb_connector} and stream_path={stream_path}" + ) print("Feed this to graphviz, or to https://dreampuf.github.io/GraphvizOnline") print() print(graph) diff --git a/tests/system/model_monitoring/test_app.py b/tests/system/model_monitoring/test_app.py index e8a46074141..c248db5b9d2 100644 --- a/tests/system/model_monitoring/test_app.py +++ b/tests/system/model_monitoring/test_app.py @@ -685,23 +685,16 @@ def test_app_flow(self, with_training_set: bool) -> None: self._add_error_alert() time.sleep(5) - self._infer( + last_request = self._infer( serving_fn, num_events=self.num_events, with_training_set=with_training_set ) self._infer_with_error(serving_fn, with_training_set=with_training_set) # mark the first window as "done" with another request time.sleep( - self.app_interval_seconds + 2 * self.app_interval_seconds + mlrun.mlconf.model_endpoint_monitoring.parquet_batching_timeout_secs - + 2 ) - for i in range(10): - last_request = self._infer( - serving_fn, num_events=1, with_training_set=with_training_set - ) - # wait for the completed window to be processed - time.sleep(1.2 * self.app_interval_seconds) mep = mlrun.db.get_run_db().get_model_endpoint( name=f"{self.model_name}_{with_training_set}",