Skip to content

Commit

Permalink
Merge pull request #386 from scikit-learn-contrib/385-regressor-any-s…
Browse files Browse the repository at this point in the history
…plit-strategy
  • Loading branch information
thibaultcordier authored Dec 20, 2023
2 parents 960d982 + f46f117 commit 3a93f31
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 22 deletions.
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History

##### (##########)
------------------
* Allow using more split methods for MapieRegressor (ShuffleSplit, PredefinedSplit).
* Integrate ConformityScore into MapieTimeSeriesRegressor.
* Add (extend) the optimal estimation strategy for the bounds of the prediction intervals for regression via ConformityScore.
* Add new checks for metrics calculations.
Expand Down
17 changes: 10 additions & 7 deletions mapie/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import numpy as np
from joblib import Parallel, delayed
from sklearn.base import RegressorMixin, clone
from sklearn.model_selection import BaseCrossValidator, ShuffleSplit
from sklearn.model_selection import BaseCrossValidator
from sklearn.utils import _safe_indexing
from sklearn.utils.validation import (_num_samples, check_is_fitted)

from mapie._typing import ArrayLike, NDArray
from mapie.aggregation_functions import aggregate_all, phi2D
from mapie.utils import (check_nan_in_aposteriori_prediction,
from mapie.utils import (check_nan_in_aposteriori_prediction, check_no_agg_cv,
fit_estimator)
from mapie.estimator.interface import EnsembleEstimator

Expand Down Expand Up @@ -152,6 +152,7 @@ class EnsembleRegressor(EnsembleEstimator):
"single_estimator_",
"estimators_",
"k_",
"use_split_method_",
]

def __init__(
Expand Down Expand Up @@ -278,10 +279,10 @@ def _aggregate_with_mask(
ArrayLike of shape (n_samples_test,)
Array of aggregated predictions for each testing sample.
"""
if self.method in self.no_agg_methods_ or self.cv in self.no_agg_cv_:
if self.method in self.no_agg_methods_ or self.use_split_method_:
raise ValueError(
"There should not be aggregation of predictions "
f"if cv is in '{self.no_agg_cv_}' "
f"if cv is in '{self.no_agg_cv_}', if cv >=2 "
f"or if method is in '{self.no_agg_methods_}'."
)
elif self.agg_function == "median":
Expand Down Expand Up @@ -406,6 +407,7 @@ def fit(
estimators_: List[RegressorMixin] = []
full_indexes = np.arange(_num_samples(X))
cv = self.cv
self.use_split_method_ = check_no_agg_cv(X, self.cv, self.no_agg_cv_)
estimator = self.estimator
n_samples = _num_samples(y)

Expand Down Expand Up @@ -434,8 +436,9 @@ def fit(
)
for train_index, _ in cv.split(X)
)
if isinstance(cv, ShuffleSplit):
single_estimator_ = estimators_[0]
# In split-CP, we keep only the model fitted on train dataset
if self.use_split_method_:
single_estimator_ = estimators_[0]

self.single_estimator_ = single_estimator_
self.estimators_ = estimators_
Expand Down Expand Up @@ -487,7 +490,7 @@ def predict(
if not return_multi_pred and not ensemble:
return y_pred

if self.method in self.no_agg_methods_ or self.cv in self.no_agg_cv_:
if self.method in self.no_agg_methods_ or self.use_split_method_:
y_pred_multi_low = y_pred[:, np.newaxis]
y_pred_multi_up = y_pred[:, np.newaxis]
else:
Expand Down
4 changes: 3 additions & 1 deletion mapie/regression/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def _check_agg_function(
"You need to specify an aggregation function when "
f"cv's type is in {self.cv_need_agg_function_}."
)
elif (agg_function is not None) or (self.cv in self.no_agg_cv_):
elif agg_function is not None:
return agg_function
else:
return "mean"
Expand Down Expand Up @@ -508,6 +508,8 @@ def fit(
)
# Fit the prediction function
self.estimator_ = self.estimator_.fit(X, y, sample_weight)

# Predict on calibration data
y_pred = self.estimator_.predict_calib(X)

# Compute the conformity scores (manage jk-ab case)
Expand Down
9 changes: 6 additions & 3 deletions mapie/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from sklearn.dummy import DummyRegressor
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import (KFold, LeaveOneOut, ShuffleSplit,
train_test_split)
from sklearn.model_selection import (KFold, LeaveOneOut, PredefinedSplit,
ShuffleSplit, train_test_split)
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils.validation import check_is_fitted
Expand Down Expand Up @@ -211,7 +211,9 @@ def test_valid_agg_function(agg_function: str) -> None:

@pytest.mark.parametrize(
"cv", [None, -1, 2, KFold(), LeaveOneOut(),
ShuffleSplit(n_splits=1), "prefit", "split"]
ShuffleSplit(n_splits=1),
PredefinedSplit(test_fold=[-1]*3+[0]*3),
"prefit", "split"]
)
def test_valid_cv(cv: Any) -> None:
"""Test that valid cv raise no errors."""
Expand Down Expand Up @@ -526,6 +528,7 @@ def test_aggregate_with_mask_with_invalid_agg_function() -> None:
0.20,
False
)
ens_reg.use_split_method_ = False
with pytest.raises(
ValueError,
match=r".*The value of self.agg_function is not correct*",
Expand Down
39 changes: 34 additions & 5 deletions mapie/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from typing import Any, Optional
from typing import Any, Optional, Tuple

import numpy as np
import pytest
from numpy.random import RandomState
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import BaseCrossValidator
from sklearn.model_selection import (BaseCrossValidator, KFold, LeaveOneOut,
ShuffleSplit)
from sklearn.utils.validation import check_is_fitted

from mapie._typing import ArrayLike, NDArray
Expand All @@ -16,9 +17,10 @@
check_array_nan, check_array_inf, check_arrays_length,
check_binary_zero_one, check_cv,
check_lower_upper_bounds, check_n_features_in,
check_n_jobs, check_null_weight, check_number_bins,
check_split_strategy, check_verbose,
compute_quantiles, fit_estimator, get_binning_groups)
check_n_jobs, check_no_agg_cv, check_null_weight,
check_number_bins, check_split_strategy,
check_verbose, compute_quantiles, fit_estimator,
get_binning_groups)

X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1)
y_toy = np.array([5, 7, 9, 11, 13, 15])
Expand Down Expand Up @@ -474,3 +476,30 @@ def test_check_cv_same_split_no_random_state(cv: BaseCrossValidator) -> None:

for i in range(cv.get_n_splits()):
np.testing.assert_allclose(train_indices_1[i], train_indices_2[i])


@pytest.mark.parametrize(
    "cv_result", [
        (1, True), (2, False),
        ("split", True), (KFold(5), False),
        (ShuffleSplit(1), True),
        (ShuffleSplit(2), False),
        (LeaveOneOut(), False),
    ]
)
def test_check_no_agg_cv(cv_result: Tuple) -> None:
    """Test that `check_no_agg_cv` returns the expected result for each cv."""
    # Each parametrized case is a (cv, expected flag) pair.
    splitter, expected = cv_result
    no_agg_cv = ["prefit", "split"]
    observed = check_no_agg_cv(X_toy, splitter, no_agg_cv)
    np.testing.assert_almost_equal(observed, expected)


@pytest.mark.parametrize("cv", [object()])
def test_check_no_agg_cv_value_error(cv: Any) -> None:
    """Test that `check_no_agg_cv` raises a ValueError on an invalid cv."""
    no_agg_cv = ["prefit", "split"]
    # Regex matched against the message of the raised ValueError.
    expected_message = r"Allowed values must have the `get_n_splits` method"
    with pytest.raises(ValueError, match=expected_message):
        check_no_agg_cv(X_toy, cv, no_agg_cv)
52 changes: 46 additions & 6 deletions mapie/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,21 +129,22 @@ def fit_estimator(


def check_cv(
cv: Optional[Union[int, str, BaseCrossValidator]] = None,
cv: Optional[Union[int, str, BaseCrossValidator, BaseShuffleSplit]] = None,
test_size: Optional[Union[int, float]] = None,
random_state: Optional[Union[int, np.random.RandomState]] = None,
) -> Union[str, BaseCrossValidator]:
) -> Union[str, BaseCrossValidator, BaseShuffleSplit]:
"""
Check if cross-validator is
``None``, ``int``, ``"prefit"``, ``"split"``or ``BaseCrossValidator``.
``None``, ``int``, ``"prefit"``, ``"split"``, ``BaseCrossValidator`` or
``BaseShuffleSplit``.
Return a ``LeaveOneOut`` instance if integer equal to -1.
Return a ``KFold`` instance if integer superior or equal to 2.
Return a ``KFold`` instance if ``None``.
Else raise error.
Parameters
----------
cv: Optional[Union[int, str, BaseCrossValidator]], optional
cv: Optional[Union[int, str, BaseCrossValidator, BaseShuffleSplit]]
Cross-validator to check, by default ``None``.
test_size: Optional[Union[int, float]]
Expand All @@ -163,8 +164,8 @@ def check_cv(
Returns
-------
Optional[Union[float, str]]
'prefit' or None.
Union[str, BaseCrossValidator, BaseShuffleSplit]
The cast `cv` parameter.
Raises
------
Expand Down Expand Up @@ -208,6 +209,45 @@ def check_cv(
)


def check_no_agg_cv(
    X: ArrayLike,
    cv: Union[int, str, BaseCrossValidator, BaseShuffleSplit],
    no_agg_cv_array: list,
) -> bool:
    """
    Check if cross-validator is ``"prefit"``, ``"split"`` or any split
    equivalent `BaseCrossValidator` or `BaseShuffleSplit`.

    Parameters
    ----------
    X: ArrayLike of shape (n_samples, n_features)
        Training data.

    cv: Union[int, str, BaseCrossValidator, BaseShuffleSplit]
        Cross-validator to check.

    no_agg_cv_array: list
        List of all non-aggregated cv methods.

    Returns
    -------
    bool
        True if `cv` is a split equivalent / non-aggregated cv method.

    Raises
    ------
    ValueError
        If `cv` is neither a string, nor an integer, nor an object exposing
        a ``get_n_splits`` attribute.
    """
    # Local import: only used to inspect the signature of ``get_n_splits``.
    import inspect

    if isinstance(cv, str):
        return cv in no_agg_cv_array
    if isinstance(cv, int):
        return cv == 1
    if not hasattr(cv, "get_n_splits"):
        raise ValueError(
            "Invalid cv argument. "
            "Allowed values must have the `get_n_splits` method "
            "with zero or one parameter (X)."
        )

    get_n_splits = cv.get_n_splits
    try:
        signature = inspect.signature(get_n_splits)
    except (TypeError, ValueError):
        # Signature cannot be introspected: keep the one-argument call.
        signature = None

    accepts_X = True
    if signature is not None:
        try:
            signature.bind(X)
        except TypeError:
            # ``get_n_splits`` cannot take X: the error message above
            # advertises zero-parameter methods, so call it without X.
            accepts_X = False

    n_splits = get_n_splits(X) if accepts_X else get_n_splits()
    return n_splits == 1


def check_alpha(
alpha: Optional[Union[float, Iterable[float]]] = None
) -> Optional[ArrayLike]:
Expand Down

0 comments on commit 3a93f31

Please sign in to comment.