Commit
Pass mark_monitoring_window_completed to record_results (batch infer)

Update the `batch_inference_v2` function to accept the new parameter
and pass it forward to mlrun if it's expected.
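
For context, here is a minimal, self-contained sketch of the forwarding pattern this commit applies: a keyword argument is passed on only when the target function's signature declares it. The helper and the two `record_results` stand-ins below are hypothetical illustrations, not part of mlrun's API.

from inspect import signature
from typing import Any, Callable, Dict, Optional


def _forward_if_supported(func: Callable, name: str, value: Optional[Any]) -> Dict[str, Any]:
    # Return {name: value} only when `func` declares a parameter called `name`;
    # an unset value (None) is simply dropped, an unknown name is an error.
    if value is None:
        return {}
    if signature(func).parameters.get(name) is None:
        raise TypeError(f"`{name}` is not a parameter of `{func.__name__}`")
    return {name: value}


# Hypothetical stand-ins for an older and a newer record_results API:
def record_results_old(endpoint_id: str) -> None:
    print("old:", endpoint_id)


def record_results_new(endpoint_id: str, mark_monitoring_window_completed: bool = True) -> None:
    print("new:", endpoint_id, mark_monitoring_window_completed)


# Forwarded, because the new signature declares the parameter:
record_results_new(
    "ep-1", **_forward_if_supported(record_results_new, "mark_monitoring_window_completed", True)
)
# Silently skipped, because the caller left the value as None:
record_results_old(
    "ep-1", **_forward_if_supported(record_results_old, "mark_monitoring_window_completed", None)
)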
jond01 committed Dec 11, 2023
1 parent 388009f commit a1bf738
Showing 2 changed files with 43 additions and 8 deletions.
38 changes: 33 additions & 5 deletions batch_inference_v2/batch_inference_v2.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from typing import Any, Dict, List, Union
from inspect import signature
from typing import Any, Dict, List, Optional, Union

import mlrun

@@ -79,6 +78,28 @@ def _prepare_result_set(
)


def _parse_record_results_kwarg(
mark_monitoring_window_completed: Optional[bool],
) -> dict[str, bool]:
"""
Check if `mark_monitoring_window_completed` is provided and expected as a parameter.
Return it as a dictionary.
"""
kwarg = "mark_monitoring_window_completed"
if mark_monitoring_window_completed is None:
return {}
if (
signature(mlrun.model_monitoring.api.record_results).parameters.get(kwarg)
is None
):
raise mlrun.errors.MLRunInvalidArgumentError(
f"Unexpected parameter `{kwarg}` for function: "
"`mlrun.model_monitoring.api.record_results`. "
"Please make sure that you are using `mlrun>=1.6.0` version."
)
return {kwarg: mark_monitoring_window_completed}


def infer(
context: mlrun.MLClientCtx,
dataset: Union[mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray],
@@ -104,6 +125,7 @@ def infer(
model_endpoint_sample_set: Union[
mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray
] = None,
mark_monitoring_window_completed: Optional[bool] = None,
**predict_kwargs: Dict[str, Any],
):
"""
@@ -163,8 +185,11 @@ def infer(
:param model_endpoint_sample_set: A sample dataset against which to compare the inputs in the drift analysis.
Can be provided as an input (DataItem) or as a parameter (e.g. string, list, DataFrame).
The default chosen sample set will always be the one that is set in the model artifact itself.
raises MLRunInvalidArgumentError: if both `model_path` and `endpoint_id` are not provided
:param mark_monitoring_window_completed: Relevant only when `trigger_monitoring_job` and `perform_drift_analysis` are both `True`.
Whether to mark the monitoring window as completed and allow monitoring without extra inferences.
Defaults to None, which means use the `mlrun` default if the parameter exists.
raises MLRunInvalidArgumentError: if both `model_path` and `endpoint_id` are not provided, or if `mark_monitoring_window_completed` is
provided for an unsupported `mlrun` version.
"""

# Loading the model:
Expand Down Expand Up @@ -241,4 +266,7 @@ def infer(
artifacts_tag=artifacts_tag,
trigger_monitoring_job=trigger_monitoring_job,
default_batch_image=batch_image_job,
**_parse_record_results_kwarg(
mark_monitoring_window_completed=mark_monitoring_window_completed
),
)
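
For reference, a hypothetical way to exercise the new parameter when running this function from the mlrun function hub; the project name, dataset, and model store paths below are placeholders and are not part of this commit.

import mlrun

project = mlrun.get_or_create_project("batch-infer-demo", context="./")
fn = mlrun.import_function("hub://batch_inference_v2")
run = fn.run(
    handler="infer",
    inputs={"dataset": "store://datasets/batch-infer-demo/test-set"},
    params={
        "model_path": "store://models/batch-infer-demo/my-model:latest",
        "perform_drift_analysis": True,
        "trigger_monitoring_job": True,
        # Forwarded to record_results only when the installed mlrun declares it
        # (mlrun>=1.6.0); otherwise _parse_record_results_kwarg raises
        # MLRunInvalidArgumentError, as shown in the diff above.
        "mark_monitoring_window_completed": True,
    },
)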