Skip to content

Commit

Permalink
bring back deprecated params and add warn (#834) (#835)
Browse files Browse the repository at this point in the history
  • Loading branch information
Eyal-Danieli authored Oct 10, 2024
1 parent 668d545 commit c9d97fb
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 40 deletions.
35 changes: 34 additions & 1 deletion batch_inference_v2/batch_inference_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,17 @@ def infer(
model_endpoint_sample_set: Union[
mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray
] = None,

# the following parameters are deprecated and will be removed once the versioning mechanism is implemented
# TODO: Remove the following parameters once FHUB-13 is resolved
trigger_monitoring_job: Optional[bool] = None,
batch_image_job: Optional[str] = None,
model_endpoint_drift_threshold: Optional[float] = None,
model_endpoint_possible_drift_threshold: Optional[float] = None,

# prediction kwargs to pass to the model predict function
**predict_kwargs: Dict[str, Any],

):
"""
Perform a prediction on the provided dataset using the specified model.
Expand Down Expand Up @@ -173,10 +183,33 @@ def infer(
:param model_endpoint_sample_set: A sample dataset to give to compare the inputs in the drift analysis.
Can be provided as an input (DataItem) or as a parameter (e.g. string, list, DataFrame).
The default chosen sample set will always be the one who is set in the model artifact itself.
:param trigger_monitoring_job: Whether to trigger the batch drift analysis after the infer job.
:param batch_image_job: The image that will be used to register the monitoring batch job if not exist.
By default, the image is mlrun/mlrun.
:param model_endpoint_drift_threshold: The threshold of which to mark drifts. Defaulted to 0.7.
:param model_endpoint_possible_drift_threshold: The threshold of which to mark possible drifts. Defaulted to 0.5.
raises MLRunInvalidArgumentError: if both `model_path` and `endpoint_id` are not provided
"""


if trigger_monitoring_job:
context.logger.warning("The `trigger_monitoring_job` parameter is deprecated and will be removed once the versioning mechanism is implemented. "
"if you are using mlrun<1.7.0, please import the previous version of this function, for example "
"'hub://batch_inference_v2:2.5.0'.")
if batch_image_job:
context.logger.warning("The `batch_image_job` parameter is deprecated and will be removed once the versioning mechanism is implemented. "
"if you are using mlrun<1.7.0, please import the previous version of this function, for example "
"'hub://batch_inference_v2:2.5.0'.")
if model_endpoint_drift_threshold:
context.logger.warning("The `model_endpoint_drift_threshold` parameter is deprecated and will be removed once the versioning mechanism is implemented. "
"if you are using mlrun<1.7.0, please import the previous version of this function, for example "
"'hub://batch_inference_v2:2.5.0'.")
if model_endpoint_possible_drift_threshold:
context.logger.warning("The `model_endpoint_possible_drift_threshold` parameter is deprecated and will be removed once the versioning mechanism is implemented. "
"if you are using mlrun<1.7.0, please import the previous version of this function, for example "
"'hub://batch_inference_v2:2.5.0'.")

# Loading the model:
context.logger.info(f"Loading model...")
if isinstance(model_path, mlrun.DataItem):
Expand Down Expand Up @@ -250,4 +283,4 @@ def infer(
model_endpoint_name=model_endpoint_name,
infer_results_df=result_set.copy(),
sample_set_statistics=sample_set_statistics,
)
)
Loading

0 comments on commit c9d97fb

Please sign in to comment.