Skip to content

Commit

Permalink
[batch_inference] Remove last_in_batch_set parameter (#783)
Browse files Browse the repository at this point in the history
  • Loading branch information
jond01 authored Jan 18, 2024
1 parent 603d8d8 commit ca1f49b
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 54 deletions.
40 changes: 3 additions & 37 deletions batch_inference_v2/batch_inference_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from inspect import signature
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Union

import mlrun

Expand Down Expand Up @@ -78,28 +77,6 @@ def _prepare_result_set(
)


def _parse_record_results_kwarg(
last_in_batch_set: Optional[bool],
) -> dict[str, bool]:
"""
Check if `last_in_batch_set` is provided and expected as a parameter.
Return it as a dictionary.
"""
kwarg = "last_in_batch_set"
if last_in_batch_set is None:
return {}
if (
signature(mlrun.model_monitoring.api.record_results).parameters.get(kwarg)
is None
):
raise mlrun.errors.MLRunInvalidArgumentError(
f"Unexpected parameter `{kwarg}` for function: "
"`mlrun.model_monitoring.api.record_results`. "
"Please make sure that you are using `mlrun>=1.6.0` version."
)
return {kwarg: last_in_batch_set}


def infer(
context: mlrun.MLClientCtx,
dataset: Union[mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray],
Expand All @@ -125,7 +102,6 @@ def infer(
model_endpoint_sample_set: Union[
mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray
] = None,
last_in_batch_set: Optional[bool] = None,
**predict_kwargs: Dict[str, Any],
):
"""
Expand Down Expand Up @@ -185,17 +161,8 @@ def infer(
:param model_endpoint_sample_set: A sample dataset to give to compare the inputs in the drift analysis.
Can be provided as an input (DataItem) or as a parameter (e.g. string, list, DataFrame).
The default chosen sample set will always be the one who is set in the model artifact itself.
:param last_in_batch_set: Relevant only when `perform_drift_analysis` is `True`.
This flag can (and should only) be used when the model endpoint does not have
model-monitoring set.
If set to `True` (the default), this flag marks the current monitoring window
(on this monitoring endpoint) as completed - the data inferred so far is assumed
to be the complete data for this monitoring window.
You may want to set this flag to `False` if you want to record multiple results in
close time proximity ("batch set"). In this case, set this flag to `False` on all
but the last batch in the set.
raises MLRunInvalidArgumentError: if both `model_path` and `endpoint_id` are not provided, or if `last_in_batch_set` is
provided for an unsupported `mlrun` version.
:raises MLRunInvalidArgumentError: if neither `model_path` nor `endpoint_id` is provided
"""

# Loading the model:
Expand Down Expand Up @@ -272,5 +239,4 @@ def infer(
artifacts_tag=artifacts_tag,
trigger_monitoring_job=trigger_monitoring_job,
default_batch_image=batch_image_job,
**_parse_record_results_kwarg(last_in_batch_set=last_in_batch_set),
)
Loading

0 comments on commit ca1f49b

Please sign in to comment.