Skip to content

Commit

Permalink
[Frameworks] Enable scikit-learn v1.2.0 to work with `mlrun.frameworks` (mlrun#2810)
Browse files Browse the repository at this point in the history
  • Loading branch information
guy1992l authored Dec 25, 2022
1 parent ba9b54c commit 621b7e1
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 16 deletions.
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ nuclio-sdk~=0.3.0
isort~=5.7
avro~=1.11
# needed for mlutils tests
scikit-learn~=1.0, <1.2
scikit-learn~=1.0
# needed for frameworks tests
lightgbm~=3.0
xgboost~=1.1
2 changes: 1 addition & 1 deletion dockerfiles/jupyter/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
matplotlib~=3.5
scipy~=1.0
scikit-learn~=1.0, <1.2
scikit-learn~=1.0
seaborn~=0.11.0
scikit-plot~=0.3.7
xgboost~=1.1
Expand Down
2 changes: 1 addition & 1 deletion dockerfiles/mlrun/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
matplotlib~=3.5
scipy~=1.0
scikit-learn~=1.0, <1.2
scikit-learn~=1.0
seaborn~=0.11.0
scikit-plot~=0.3.7
2 changes: 1 addition & 1 deletion dockerfiles/test-system/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pytest~=5.4
matplotlib~=3.5
graphviz~=0.20.0
scikit-learn~=1.0, <1.2
scikit-learn~=1.0
22 changes: 11 additions & 11 deletions mlrun/frameworks/sklearn/mlrun_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class SKLearnMLRunInterface(MLRunInterface, ABC):
# A producer instance for logging this model's training / evaluation artifacts:
"_producer": None, # type: MLProducer
# An estimator instance for logging this model's training / evaluation metrics results:
"_estimator": None, # type: Estimator
"_mlrun_estimator": None, # type: Estimator
# The test set (For validation post training or evaluation post prediction):
"_x_test": None, # type: SKLearnTypes.DatasetType
"_y_test": None, # type: SKLearnTypes.DatasetType
Expand Down Expand Up @@ -86,7 +86,7 @@ def add_interface(
# Setup a default producer and estimator:
if obj._producer is None:
obj._producer = MLProducer()
obj._estimator = Estimator()
obj._mlrun_estimator = Estimator()

@classmethod
def mlrun_fit(cls):
Expand Down Expand Up @@ -186,17 +186,17 @@ def configure_logging(
else model_handler.context
)
self._producer.set_context(context=context)
self._estimator.set_context(context=context)
self._mlrun_estimator.set_context(context=context)
self._model_handler.set_context(context=context)

# Set the logging attributes:
self._producer.set_plans(plans=plans)
self._estimator.set_metrics(metrics=metrics)
self._mlrun_estimator.set_metrics(metrics=metrics)

# Validate that if the prediction probabilities are required, this model has the 'predict_proba' method:
if (
self._producer.is_probabilities_required()
or self._estimator.is_probabilities_required()
or self._mlrun_estimator.is_probabilities_required()
) and not hasattr(self, "predict_proba"):
raise mlrun.errors.MLRunInvalidArgumentError(
f"Some of the metrics and or artifacts required to be calculated and produced require prediction "
Expand Down Expand Up @@ -256,7 +256,7 @@ def _post_fit(
self._model_handler.set_sample_set(sample_set=sample_set)
# Log the model:
self._model_handler.log(
metrics=self._estimator.results,
metrics=self._mlrun_estimator.results,
artifacts=self._producer.artifacts,
)
self._model_handler.context.commit(completed=False)
Expand All @@ -270,7 +270,7 @@ def _pre_predict(self, x: SKLearnTypes.DatasetType, y: SKLearnTypes.DatasetType)
"""
# This function is only called for evaluation, then set the mode to the producer and estimator:
self._producer.set_mode(mode=LoggingMode.EVALUATION)
self._estimator.set_mode(mode=LoggingMode.EVALUATION)
self._mlrun_estimator.set_mode(mode=LoggingMode.EVALUATION)

# Produce and log all the artifacts pre prediction:
self._producer.produce_stage(
Expand Down Expand Up @@ -304,14 +304,14 @@ def _post_predict(
)

# Calculate and log the metrics results:
self._estimator.estimate(
self._mlrun_estimator.estimate(
y_true=y, y_pred=y_pred, is_probabilities=is_predict_proba
)

# If some metrics and / or plans require probabilities, run 'predict_proba':
if not is_predict_proba and (
self._producer.is_probabilities_required()
or self._estimator.is_probabilities_required()
or self._mlrun_estimator.is_probabilities_required()
):
y_pred_proba = self.predict_proba(x)
self._producer.produce_stage(
Expand All @@ -322,7 +322,7 @@ def _post_predict(
y=y,
y_pred=y_pred_proba,
)
self._estimator.estimate(
self._mlrun_estimator.estimate(
y_true=y, y_pred=y_pred_proba, is_probabilities=True
)

Expand All @@ -333,7 +333,7 @@ def _post_predict(
# Update the model with the testing artifacts and results:
if self._model_handler is not None:
self._model_handler.update(
metrics=self._estimator.results,
metrics=self._mlrun_estimator.results,
artifacts=self._producer.artifacts,
)
self._model_handler.context.commit(completed=False)
1 change: 0 additions & 1 deletion tests/test_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ def test_requirement_specifiers_convention():
"protobuf": {">=3.20.2, <4"},
"pandas": {"~=1.2, <1.5.0"},
"importlib_metadata": {">=3.6"},
"scikit-learn": {"~=1.0, <1.2"},
}

for (
Expand Down

0 comments on commit 621b7e1

Please sign in to comment.