Skip to content

Commit

Permalink
UPD: allow more split cv methods for regression
Browse files Browse the repository at this point in the history
  • Loading branch information
thibaultcordier committed Dec 13, 2023
1 parent 6a880c5 commit 051f8a4
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 51 deletions.
17 changes: 11 additions & 6 deletions mapie/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import numpy as np
from joblib import Parallel, delayed
from sklearn.base import RegressorMixin, clone
from sklearn.model_selection import BaseCrossValidator, ShuffleSplit
from sklearn.model_selection import BaseCrossValidator
from sklearn.utils import _safe_indexing
from sklearn.utils.validation import (_num_samples, check_is_fitted)

from mapie._typing import ArrayLike, NDArray
from mapie.aggregation_functions import aggregate_all, phi2D
from mapie.utils import (check_nan_in_aposteriori_prediction,
from mapie.utils import (check_nan_in_aposteriori_prediction, check_no_agg_cv,
fit_estimator)
from mapie.estimator.interface import EnsembleEstimator

Expand Down Expand Up @@ -278,7 +278,9 @@ def _aggregate_with_mask(
ArrayLike of shape (n_samples_test,)
Array of aggregated predictions for each testing sample.
"""
if self.method in self.no_agg_methods_ or self.cv in self.no_agg_cv_:
if self.method in self.no_agg_methods_ or (
check_no_agg_cv(self.cv, self.no_agg_cv_)
):
raise ValueError(
"There should not be aggregation of predictions "
f"if cv is in '{self.no_agg_cv_}' "
Expand Down Expand Up @@ -434,8 +436,9 @@ def fit(
)
for train_index, _ in cv.split(X)
)
if isinstance(cv, ShuffleSplit):
single_estimator_ = estimators_[0]
# In split-CP, we keep only the model fitted on train dataset
if check_no_agg_cv(cv, self.no_agg_cv_):
single_estimator_ = estimators_[0]

self.single_estimator_ = single_estimator_
self.estimators_ = estimators_
Expand Down Expand Up @@ -487,7 +490,9 @@ def predict(
if not return_multi_pred and not ensemble:
return y_pred

if self.method in self.no_agg_methods_ or self.cv in self.no_agg_cv_:
if self.method in self.no_agg_methods_ or (
check_no_agg_cv(self.cv, self.no_agg_cv_)
):
y_pred_multi_low = y_pred[:, np.newaxis]
y_pred_multi_up = y_pred[:, np.newaxis]
else:
Expand Down
9 changes: 7 additions & 2 deletions mapie/regression/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from mapie.utils import (check_alpha, check_alpha_and_n_samples,
check_conformity_score, check_cv,
check_estimator_fit_predict, check_n_features_in,
check_n_jobs, check_null_weight, check_verbose)
check_n_jobs, check_no_agg_cv, check_null_weight,
check_verbose)


class MapieRegressor(BaseEstimator, RegressorMixin):
Expand Down Expand Up @@ -315,7 +316,9 @@ def _check_agg_function(
"You need to specify an aggregation function when "
f"cv's type is in {self.cv_need_agg_function_}."
)
elif (agg_function is not None) or (self.cv in self.no_agg_cv_):
elif (agg_function is not None) or (
check_no_agg_cv(self.cv, self.no_agg_cv_)
):
return agg_function
else:
return "mean"
Expand Down Expand Up @@ -507,6 +510,8 @@ def fit(
)
# Fit the prediction function
self.estimator_ = self.estimator_.fit(X, y, sample_weight)

# Predict on calibration data
y_pred = self.estimator_.predict_calib(X)

# Compute the conformity scores (manage jk-ab case)
Expand Down
8 changes: 5 additions & 3 deletions mapie/regression/time_series_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from mapie._typing import ArrayLike, NDArray
from mapie.aggregation_functions import aggregate_all
from .regression import MapieRegressor
from mapie.utils import check_alpha, check_alpha_and_n_samples
from mapie.utils import (check_alpha, check_alpha_and_n_samples,
check_no_agg_cv)


class MapieTimeSeriesRegressor(MapieRegressor):
Expand Down Expand Up @@ -316,8 +317,9 @@ def predict(
self.lower_quantiles_ = lower_quantiles
self.higher_quantiles_ = higher_quantiles

if self.method in self.no_agg_methods_ \
or self.cv in self.no_agg_cv_:
if self.method in self.no_agg_methods_ or (
check_no_agg_cv(self.cv, self.no_agg_cv_)
):
y_pred_low = y_pred[:, np.newaxis] + lower_quantiles
y_pred_up = y_pred[:, np.newaxis] + higher_quantiles
else:
Expand Down
32 changes: 32 additions & 0 deletions mapie/tests/test_quantile_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,38 @@ def test_results_for_same_alpha(
np.testing.assert_allclose(y_pis[:, 1, 0], y_pis_clone[:, 1, 0])


def test_ensemble_in_predict() -> None:
"""Checking for ensemble defined in predict of CQR"""
mapie_reg = MapieQuantileRegressor()
mapie_reg.fit(X, y)
with pytest.warns(
UserWarning, match=r"WARNING: Alpha should not be spec.*"
):
mapie_reg.predict(X, alpha=0.2)


def test_alpha_in_predict() -> None:
"""Checking for alpha defined in predict of CQR"""
mapie_reg = MapieQuantileRegressor()
mapie_reg.fit(X, y)
with pytest.warns(UserWarning, match=r"WARNING: ensemble is not util*"):
mapie_reg.predict(X, ensemble=True)


@pytest.mark.parametrize("estimator", [-1, 3, 0.2])
def test_quantile_prefit_non_iterable(estimator: Any) -> None:
"""
Test that there is a list of estimators provided when cv='prefit'
is called for MapieQuantileRegressor.
"""
with pytest.raises(
ValueError,
match=r".*Estimator for prefit must be an iterable object.*",
):
mapie_reg = MapieQuantileRegressor(estimator=estimator, cv="prefit")
mapie_reg.fit([1, 2, 3], [4, 5, 6])


@pytest.mark.parametrize("alphas", ["hello", MapieQuantileRegressor, [2], 1])
def test_wrong_alphas_types(alphas: float) -> None:
"""Checking for wrong type of alphas"""
Expand Down
8 changes: 5 additions & 3 deletions mapie/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from sklearn.dummy import DummyRegressor
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import (KFold, LeaveOneOut, ShuffleSplit,
train_test_split)
from sklearn.model_selection import (KFold, LeaveOneOut, PredefinedSplit,
ShuffleSplit, train_test_split)
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils.validation import check_is_fitted
Expand Down Expand Up @@ -211,7 +211,9 @@ def test_valid_agg_function(agg_function: str) -> None:

@pytest.mark.parametrize(
"cv", [None, -1, 2, KFold(), LeaveOneOut(),
ShuffleSplit(n_splits=1), "prefit", "split"]
ShuffleSplit(n_splits=1),
PredefinedSplit(test_fold=[-1]*3+[0]*3),
"prefit", "split"]
)
def test_valid_cv(cv: Any) -> None:
"""Test that valid cv raise no errors."""
Expand Down
56 changes: 19 additions & 37 deletions mapie/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
from numpy.random import RandomState
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import BaseCrossValidator
from sklearn.model_selection import BaseCrossValidator, KFold, ShuffleSplit
from sklearn.utils.validation import check_is_fitted

from mapie._typing import ArrayLike, NDArray
from mapie.regression import MapieQuantileRegressor
from mapie.utils import (check_alpha, check_alpha_and_n_samples,
check_array_nan, check_array_inf, check_arrays_length,
check_binary_zero_one, check_cv,
check_lower_upper_bounds, check_n_features_in,
check_n_jobs, check_null_weight, check_number_bins,
check_split_strategy, check_verbose,
compute_quantiles, fit_estimator, get_binning_groups)
check_n_jobs, check_no_agg_cv, check_null_weight,
check_number_bins, check_split_strategy,
check_verbose, compute_quantiles, fit_estimator,
get_binning_groups)

X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1)
y_toy = np.array([5, 7, 9, 11, 13, 15])
Expand Down Expand Up @@ -254,24 +254,6 @@ def test_final1D_low_high_pred() -> None:
check_lower_upper_bounds(y_preds, y_pred_low, y_pred_up)


def test_ensemble_in_predict() -> None:
"""Checking for ensemble defined in predict of CQR"""
mapie_reg = MapieQuantileRegressor()
mapie_reg.fit(X, y)
with pytest.warns(
UserWarning, match=r"WARNING: Alpha should not be spec.*"
):
mapie_reg.predict(X, alpha=0.2)


def test_alpha_in_predict() -> None:
"""Checking for alpha defined in predict of CQR"""
mapie_reg = MapieQuantileRegressor()
mapie_reg.fit(X, y)
with pytest.warns(UserWarning, match=r"WARNING: ensemble is not util*"):
mapie_reg.predict(X, ensemble=True)


def test_compute_quantiles_value_error():
"""Test that if the size of the last axis of vector
is different from the number of aphas an error is raised.
Expand Down Expand Up @@ -325,20 +307,6 @@ def test_compute_quantiles_2D_and_3D(alphas: NDArray):
assert (quantiles1 == quantiles2).all()


@pytest.mark.parametrize("estimator", [-1, 3, 0.2])
def test_quantile_prefit_non_iterable(estimator: Any) -> None:
"""
Test that there is a list of estimators provided when cv='prefit'
is called for MapieQuantileRegressor.
"""
with pytest.raises(
ValueError,
match=r".*Estimator for prefit must be an iterable object.*",
):
mapie_reg = MapieQuantileRegressor(estimator=estimator, cv="prefit")
mapie_reg.fit([1, 2, 3], [4, 5, 6])


# def test_calib_set_no_Xy_but_sample_weight() -> None:
# """Test warning message if sample weight provided but no X y in calib."""
# X = np.array([4, 5, 6])
Expand Down Expand Up @@ -474,3 +442,17 @@ def test_check_cv_same_split_no_random_state(cv: BaseCrossValidator) -> None:

for i in range(cv.get_n_splits()):
np.testing.assert_allclose(train_indices_1[i], train_indices_2[i])


def test_check_no_agg_cv() -> None:
array = ["prefit", "split"]
np.testing.assert_almost_equal(check_no_agg_cv(1, array), True)
np.testing.assert_almost_equal(check_no_agg_cv(2, array), False)
cv = "split"
np.testing.assert_almost_equal(check_no_agg_cv(cv, array), True)
cv = KFold(5)
np.testing.assert_almost_equal(check_no_agg_cv(cv, array), False)
cv = ShuffleSplit(1)
np.testing.assert_almost_equal(check_no_agg_cv(cv, array), True)
cv = ShuffleSplit(2)
np.testing.assert_almost_equal(check_no_agg_cv(cv, array), False)
31 changes: 31 additions & 0 deletions mapie/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,37 @@ def check_cv(
)


def check_no_agg_cv(
cv: Union[int, str, BaseCrossValidator, BaseShuffleSplit],
no_agg_cv_array: list,
) -> bool:
"""
Check if cross-validator is ``"prefit"``, ``"split"`` or any split
equivalent `BaseCrossValidator` or `BaseShuffleSplit`.
Parameters
----------
cv: Union[int, str, BaseCrossValidator, BaseShuffleSplit]
Cross-validator to check.
no_agg_cv_array: list
List of all non-aggregated cv methods.
Returns
-------
bool
True if `cv` is a split equivalent / non-aggregated cv method.
"""
if isinstance(cv, str):
return cv in no_agg_cv_array
elif isinstance(cv, int):
return cv == 1
try:
return cv.get_n_splits() == 1
except Exception:
return False


def check_alpha(
alpha: Optional[Union[float, Iterable[float]]] = None
) -> Optional[ArrayLike]:
Expand Down

0 comments on commit 051f8a4

Please sign in to comment.