diff --git a/HISTORY.rst b/HISTORY.rst index 761c853d6..8a9f21b95 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -5,6 +5,7 @@ History ##### (##########) ------------------ * Integrate ConformityScore into MapieTimeSeriesRegressor. +* Add (extend) the optimal estimation strategy for the bounds of the prediction intervals for regression via ConformityScore. * Add new checks for metrics calculations. * Fix reference for residual normalised score in documentation. diff --git a/examples/regression/4-tutorials/plot_ts-tutorial.py b/examples/regression/4-tutorials/plot_ts-tutorial.py index d34e95ecb..10709f36d 100644 --- a/examples/regression/4-tutorials/plot_ts-tutorial.py +++ b/examples/regression/4-tutorials/plot_ts-tutorial.py @@ -216,8 +216,6 @@ class that block bootstraps the training set. y_pred_pfit = np.zeros(y_pred_npfit.shape) y_pis_pfit = np.zeros(y_pis_npfit.shape) conformity_scores_pfit = [] -lower_quantiles_pfit = [] -higher_quantiles_pfit = [] y_pred_pfit[:gap], y_pis_pfit[:gap, :, :] = mapie_enbpi.predict( X_test.iloc[:gap, :], alpha=alpha, ensemble=True, optimize_beta=True ) @@ -236,8 +234,6 @@ class that block bootstraps the training set. optimize_beta=True ) conformity_scores_pfit.append(mapie_enbpi.conformity_scores_) - lower_quantiles_pfit.append(mapie_enbpi.lower_quantiles_) - higher_quantiles_pfit.append(mapie_enbpi.higher_quantiles_) coverage_pfit = regression_coverage_score( y_test, y_pis_pfit[:, 0, 0], y_pis_pfit[:, 1, 0] ) diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index ef4a79ade..4e9d92f9b 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -254,14 +254,72 @@ def get_quantile( ]) return quantile + @staticmethod + def _beta_optimize( + alpha_np: NDArray, + upper_bounds: NDArray, + lower_bounds: NDArray, + ) -> NDArray: + """ + Minimize the width of the PIs, for a given difference of quantiles. 
+ + Parameters + ---------- + alpha_np: NDArray + The quantiles to compute. + + upper_bounds: NDArray + The array of upper values. + + lower_bounds: NDArray + The array of lower values. + + Returns + ------- + NDArray + Array of betas minimizing the differences + ``(1-alpha+beta)-quantile - beta-quantile``. + """ + beta_np = np.full( + shape=(len(lower_bounds), len(alpha_np)), + fill_value=np.nan, + dtype=float, + ) + + for ind_alpha, _alpha in enumerate(alpha_np): + betas = np.linspace( + _alpha / (len(lower_bounds) + 1), + _alpha, + num=len(lower_bounds), + endpoint=True, + ) + one_alpha_beta = np_nanquantile( + upper_bounds.astype(float), + 1 - _alpha + betas, + axis=1, + method="higher", + ) + beta = np_nanquantile( + lower_bounds.astype(float), + betas, + axis=1, + method="lower", + ) + beta_np[:, ind_alpha] = betas[ + np.argmin(one_alpha_beta - beta, axis=0) + ] + + return beta_np + def get_bounds( self, X: ArrayLike, estimator: EnsembleEstimator, conformity_scores: NDArray, alpha_np: NDArray, - ensemble: bool, - method: str + ensemble: bool = False, + method: str = 'base', + optimize_beta: bool = False, ) -> Tuple[NDArray, NDArray, NDArray]: """ Compute bounds of the prediction intervals from the observed values, @@ -285,6 +343,8 @@ def get_bounds( ensemble: bool Boolean determining whether the predictions are ensembled or not. + By default ``False``. + method: str Method to choose for prediction interval estimates. The ``"plus"`` method implies that the quantile is calculated @@ -292,6 +352,13 @@ def get_bounds( (among the ``"naive"``, ``"base"`` or ``"minmax"`` methods, for example) do the opposite. + By default ``base``. + + optimize_beta: bool + Whether to optimize the PIs' width or not. + + By default ``False``. + Returns ------- Tuple[NDArray, NDArray, NDArray] @@ -300,13 +367,33 @@ def get_bounds( (n_samples, n_alpha). - The upper bounds of the prediction intervals of shape (n_samples, n_alpha). 
+ + Raises + ------ + ValueError + If beta optimisation is requested with a symmetrical conformity score function. """ + if self.sym and optimize_beta: + raise ValueError( + "Beta optimisation cannot be used with " + + "symmetrical conformity score function." + ) + y_pred, y_pred_low, y_pred_up = estimator.predict(X, ensemble) signed = -1 if self.sym else 1 + if optimize_beta: + beta_np = self._beta_optimize( + alpha_np, + conformity_scores.reshape(1, -1), + conformity_scores.reshape(1, -1), + ) + else: + beta_np = alpha_np / 2 + if method == "plus": - alpha_low = alpha_np if self.sym else alpha_np / 2 - alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np / 2 + alpha_low = alpha_np if self.sym else beta_np + alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np + beta_np conformity_scores_low = self.get_estimation_distribution( X, y_pred_low, signed * conformity_scores @@ -322,8 +409,8 @@ def get_bounds( ) else: quantile_search = "higher" if self.sym else "lower" - alpha_low = 1 - alpha_np if self.sym else alpha_np / 2 - alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np / 2 + alpha_low = 1 - alpha_np if self.sym else beta_np + alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np + beta_np quantile_low = self.get_quantile( conformity_scores, alpha_low, axis=0, method=quantile_search diff --git a/mapie/estimator/estimator.py b/mapie/estimator/estimator.py index 33bda7b3d..b79521ea7 100644 --- a/mapie/estimator/estimator.py +++ b/mapie/estimator/estimator.py @@ -496,9 +496,12 @@ def predict( if self.method == "minmax": y_pred_multi_low = np.min(y_pred_multi, axis=1, keepdims=True) y_pred_multi_up = np.max(y_pred_multi, axis=1, keepdims=True) - else: + elif self.method == "plus": y_pred_multi_low = y_pred_multi y_pred_multi_up = y_pred_multi + else: + y_pred_multi_low = y_pred[:, np.newaxis] + y_pred_multi_up = y_pred[:, np.newaxis] if ensemble: y_pred = aggregate_all(self.agg_function, y_pred_multi) diff --git a/mapie/regression/regression.py b/mapie/regression/regression.py 
index de87b4b11..e380b9f35 100644 --- a/mapie/regression/regression.py +++ b/mapie/regression/regression.py @@ -209,6 +209,7 @@ class MapieRegressor(BaseEstimator, RegressorMixin): no_agg_methods_ = ["naive", "base"] valid_agg_functions_ = [None, "median", "mean"] ensemble_agg_functions_ = ["median", "mean"] + default_sym_ = True fit_attributes = [ "estimator_", "conformity_scores_", @@ -424,7 +425,7 @@ def _check_fit_parameters( estimator = self._check_estimator(self.estimator) agg_function = self._check_agg_function(self.agg_function) cs_estimator = check_conformity_score( - self.conformity_score + self.conformity_score, self.default_sym_ ) if isinstance(cs_estimator, ResidualNormalisedScore) and \ self.cv not in ["split", "prefit"]: @@ -522,6 +523,7 @@ def predict( X: ArrayLike, ensemble: bool = False, alpha: Optional[Union[float, Iterable[float]]] = None, + optimize_beta: bool = False, ) -> Union[NDArray, Tuple[NDArray, NDArray]]: """ Predict target on new samples with confidence intervals. @@ -561,6 +563,11 @@ def predict( By default ``None``. + optimize_beta: bool + Whether to optimize the PIs' width or not. + + By default ``False``. + Returns ------- Union[NDArray, Tuple[NDArray, NDArray]] @@ -582,6 +589,12 @@ def predict( return np.array(y_pred) else: + if optimize_beta and self.method != 'enbpi': + raise UserWarning( + "Beta optimisation should only be used for " + "method='enbpi'." 
+ ) + n = len(self.conformity_scores_) alpha_np = cast(NDArray, alpha) check_alpha_and_n_samples(alpha_np, n) @@ -592,7 +605,8 @@ def predict( self.estimator_, self.conformity_scores_, alpha_np, - ensemble, - self.method + ensemble=ensemble, + method=self.method, + optimize_beta=optimize_beta ) return np.array(y_pred), np.stack([y_pred_low, y_pred_up], axis=1) diff --git a/mapie/regression/time_series_regression.py b/mapie/regression/time_series_regression.py index 0ad6c8e0b..c68a68b40 100644 --- a/mapie/regression/time_series_regression.py +++ b/mapie/regression/time_series_regression.py @@ -1,18 +1,15 @@ from __future__ import annotations -from typing import Iterable, Optional, Tuple, Union, cast +from typing import Optional, Union, cast import numpy as np from sklearn.base import RegressorMixin from sklearn.model_selection import BaseCrossValidator from sklearn.utils.validation import check_is_fitted -from mapie._compatibility import np_nanquantile from mapie._typing import ArrayLike, NDArray -from mapie.aggregation_functions import aggregate_all from mapie.conformity_scores import ConformityScore from .regression import MapieRegressor -from mapie.utils import check_alpha, check_alpha_and_n_samples class MapieTimeSeriesRegressor(MapieRegressor): @@ -36,6 +33,7 @@ class MapieTimeSeriesRegressor(MapieRegressor): cv_need_agg_function_ = MapieRegressor.cv_need_agg_function_ \ + ["BlockBootstrap"] valid_methods_ = ["enbpi"] + default_sym_ = False def __init__( self, @@ -92,79 +90,12 @@ def _relative_conformity_scores( ------- The conformity scores corresponding to the input data set. 
""" - y_pred, _ = super().predict(X, alpha=0.5, ensemble=ensemble) + y_pred = super().predict(X, ensemble=ensemble) scores = np.array( self.conformity_score_function_.get_conformity_scores(X, y, y_pred) ) return scores - def _beta_optimize( - self, - alpha: Union[float, NDArray], - upper_bounds: NDArray, - lower_bounds: NDArray, - ) -> NDArray: - """ - Minimize the width of the PIs, for a given difference of quantiles. - - Parameters - ---------- - alpha: Union[float, NDArray] - The quantiles to compute. - - upper_bounds: NDArray - The array of upper values. - - lower_bounds: NDArray - The array of lower values. - - Returns - ------- - NDArray - Array of betas minimizing the differences - ``(1-alpa+beta)-quantile - beta-quantile``. - - Raises - ------ - ValueError - If lower and upper bounds arrays don't have the same shape. - """ - if lower_bounds.shape != upper_bounds.shape: - raise ValueError( - "Lower and upper bounds arrays should have the same shape." - ) - alpha = cast(NDArray, alpha) - betas_0 = np.full( - shape=(len(lower_bounds), len(alpha)), - fill_value=np.nan, - dtype=float, - ) - - for ind_alpha, _alpha in enumerate(alpha): - betas = np.linspace( - _alpha / (len(lower_bounds) + 1), - _alpha, - num=len(lower_bounds), - endpoint=True, - ) - one_alpha_beta = np_nanquantile( - upper_bounds.astype(float), - 1 - _alpha + betas, - axis=1, - method="higher", - ) - beta = np_nanquantile( - lower_bounds.astype(float), - betas, - axis=1, - method="lower", - ) - betas_0[:, ind_alpha] = betas[ - np.argmin(one_alpha_beta - beta, axis=0) - ] - - return betas_0 - def fit( self, X: ArrayLike, @@ -281,109 +212,6 @@ def partial_fit( ] = new_conformity_scores_ return self - def predict( - self, - X: ArrayLike, - ensemble: bool = False, - alpha: Optional[Union[float, Iterable[float]]] = None, - optimize_beta: bool = True, - ) -> Union[NDArray, Tuple[NDArray, NDArray]]: - """ - Correspond to 'Conformal prediction for dynamic time-series'. 
- - Parameters - ---------- - X: ArrayLike of shape (n_samples, n_features) - Test data. - - ensemble: bool - Boolean determining whether the predictions are ensembled or not. - If ``False``, predictions are those of the model trained on the - whole training set. - If ``True``, predictions from perturbed models are aggregated by - the aggregation function specified in the ``agg_function`` - attribute. - - If ``cv`` is ``"prefit"`` or ``"split"``, ``ensemble`` is ignored. - - By default ``False``. - - alpha: Optional[Union[float, Iterable[float]]] - Can be a float, a list of floats, or a ``ArrayLike`` of floats. - Between ``0`` and ``1``, represents the uncertainty of the - confidence interval. - Lower ``alpha`` produce larger (more conservative) prediction - intervals. - ``alpha`` is the complement of the target coverage level. - - By default ``None``. - - optimize_beta: bool - Whether to optimize the PIs' width or not. - - Returns - ------- - Union[NDArray, Tuple[NDArray, NDArray]] - - NDArray of shape (n_samples,) if ``alpha`` is ``None``. - - Tuple[NDArray, NDArray] of shapes (n_samples,) and - (n_samples, 2, n_alpha) if ``alpha`` is not ``None``. - - [:, 0, :]: Lower bound of the prediction interval. - - [:, 1, :]: Upper bound of the prediction interval. 
- """ - # Checks - check_is_fitted(self, self.fit_attributes) - self._check_ensemble(ensemble) - alpha = cast(Optional[NDArray], check_alpha(alpha)) - y_pred = self.estimator_.single_estimator_.predict(X) - n = len(self.conformity_scores_) - - if alpha is None: - return np.array(y_pred) - - alpha_np = cast(NDArray, alpha) - check_alpha_and_n_samples(alpha_np, n) - - if optimize_beta: - betas_0 = self._beta_optimize( - alpha_np, - self.conformity_scores_.reshape(1, -1), - self.conformity_scores_.reshape(1, -1), - ) - else: - betas_0 = np.repeat(alpha[:, np.newaxis] / 2, n, axis=0) - - lower_quantiles = np_nanquantile( - self.conformity_scores_.astype(float), - betas_0[0, :], - axis=0, - method="lower", - ).T - higher_quantiles = np_nanquantile( - self.conformity_scores_.astype(float), - 1 - alpha_np + betas_0[0, :], - axis=0, - method="higher", - ).T - self.lower_quantiles_ = lower_quantiles - self.higher_quantiles_ = higher_quantiles - - if self.method in self.no_agg_methods_ \ - or self.cv in self.no_agg_cv_: - y_pred_low = y_pred[:, np.newaxis] + lower_quantiles - y_pred_up = y_pred[:, np.newaxis] + higher_quantiles - else: - y_pred_multi = self.estimator_._pred_multi(X) - pred = aggregate_all(self.agg_function, y_pred_multi) - lower_bounds, upper_bounds = pred, pred - - y_pred_low = lower_bounds.reshape(-1, 1) + lower_quantiles - y_pred_up = upper_bounds.reshape(-1, 1) + higher_quantiles - - if ensemble: - y_pred = aggregate_all(self.agg_function, y_pred_multi) - - return y_pred, np.stack([y_pred_low, y_pred_up], axis=1) - def _more_tags(self): return { "_xfail_checks": diff --git a/mapie/tests/test_regression.py b/mapie/tests/test_regression.py index 14557982f..695effdd8 100644 --- a/mapie/tests/test_regression.py +++ b/mapie/tests/test_regression.py @@ -632,3 +632,14 @@ def test_return_multi_pred(ensemble: bool) -> None: X_toy, ensemble=ensemble, return_multi_pred=True ) assert len(output) == 3 + + +def test_beta_optimize_user_warning() -> None: + """ + Test 
that a UserWarning is displayed when optimize_beta is used. + """ + mapie_reg = MapieRegressor().fit(X, y) + with pytest.raises( + UserWarning, match=r"Beta optimisation should only be used for*", + ): + mapie_reg.predict(X, alpha=0.05, optimize_beta=True) diff --git a/mapie/tests/test_time_series_regression.py b/mapie/tests/test_time_series_regression.py index 5586d4a11..04ac58876 100644 --- a/mapie/tests/test_time_series_regression.py +++ b/mapie/tests/test_time_series_regression.py @@ -12,7 +12,7 @@ from mapie._typing import NDArray from mapie.aggregation_functions import aggregate_all -from mapie.conformity_scores import ConformityScore, AbsoluteConformityScore +from mapie.conformity_scores import AbsoluteConformityScore from mapie.metrics import regression_coverage_score from mapie.regression import MapieTimeSeriesRegressor from mapie.subsample import BlockBootstrap @@ -34,7 +34,6 @@ "method": str, "agg_function": str, "cv": Optional[Union[int, KFold, BlockBootstrap]], - "conformity_score": ConformityScore }, ) STRATEGIES = { @@ -44,7 +43,6 @@ cv=BlockBootstrap( n_resamplings=30, n_blocks=5, random_state=random_state ), - conformity_score=AbsoluteConformityScore(sym=False), ), "jackknife_enbpi_median_ab_wopt": Params( method="enbpi", @@ -54,7 +52,6 @@ n_blocks=5, random_state=random_state, ), - conformity_score=AbsoluteConformityScore(sym=False), ), "jackknife_enbpi_mean_ab": Params( method="enbpi", @@ -62,7 +59,6 @@ cv=BlockBootstrap( n_resamplings=30, n_blocks=5, random_state=random_state ), - conformity_score=AbsoluteConformityScore(sym=False), ), "jackknife_enbpi_median_ab": Params( method="enbpi", @@ -72,7 +68,6 @@ n_blocks=5, random_state=random_state, ), - conformity_score=AbsoluteConformityScore(sym=False), ), } @@ -275,8 +270,7 @@ def test_results_prefit() -> None: ) estimator = LinearRegression().fit(X_train, y_train) mapie_ts_reg = MapieTimeSeriesRegressor( - estimator=estimator, cv="prefit", - conformity_score=AbsoluteConformityScore(sym=False) + 
estimator=estimator, cv="prefit" ) mapie_ts_reg.fit(X_val, y_val) _, y_pis = mapie_ts_reg.predict(X_test, alpha=0.05) @@ -348,9 +342,7 @@ def test_MapieTimeSeriesRegressor_if_alpha_is_None() -> None: def test_MapieTimeSeriesRegressor_partial_fit_ensemble() -> None: """Test ``partial_fit``.""" - mapie_ts_reg = MapieTimeSeriesRegressor( - cv=-1, conformity_score=AbsoluteConformityScore(sym=False) - ) + mapie_ts_reg = MapieTimeSeriesRegressor(cv=-1) mapie_ts_reg = mapie_ts_reg.fit(X_toy, y_toy, ensemble=True) assert round(mapie_ts_reg.conformity_scores_[-1], 2) == round( np.abs(CONFORMITY_SCORES[0]), 2 @@ -371,13 +363,40 @@ def test_MapieTimeSeriesRegressor_partial_fit_too_big() -> None: mapie_ts_reg = mapie_ts_reg.partial_fit(X=X, y=y) -def test_MapieTimeSeriesRegressor_beta_optimize_eeror() -> None: +def test_MapieTimeSeriesRegressor_beta_optimize_error() -> None: """Test ``beta_optimize`` raised error.""" - mapie_ts_reg = MapieTimeSeriesRegressor(cv=-1) - with pytest.raises(ValueError, match=r".*Lower and upper bounds arrays*"): - mapie_ts_reg._beta_optimize( - alpha=0.1, upper_bounds=X, lower_bounds=X_toy + mapie_ts_reg = MapieTimeSeriesRegressor( + cv=-1, conformity_score=AbsoluteConformityScore(sym=True) + ).fit(X_toy, y_toy) + with pytest.raises( + ValueError, match=r"Beta optimisation cannot be used*" + ): + mapie_ts_reg.predict(X_toy, alpha=0.4, optimize_beta=True) + + +def test_interval_prediction_with_beta_optimize() -> None: + """Test use of ``beta_optimize`` in prediction.""" + X_train_val, X_test, y_train_val, y_test = train_test_split( + X, y, test_size=1 / 10, random_state=random_state + ) + X_train, X_val, y_train, y_val = train_test_split( + X_train_val, y_train_val, test_size=1 / 9, random_state=random_state + ) + estimator = LinearRegression().fit(X_train, y_train) + mapie_ts_reg = MapieTimeSeriesRegressor( + estimator=estimator, + cv=BlockBootstrap( + n_resamplings=30, n_blocks=5, random_state=random_state ) + ) + mapie_ts_reg.fit(X_val, y_val) + 
_, y_pis = mapie_ts_reg.predict(X_test, alpha=0.05, optimize_beta=True) + width_mean = (y_pis[:, 1, 0] - y_pis[:, 0, 0]).mean() + coverage = regression_coverage_score( + y_test, y_pis[:, 0, 0], y_pis[:, 1, 0] + ) + np.testing.assert_allclose(width_mean, 4.22, rtol=1e-2) + np.testing.assert_allclose(coverage, 0.9, rtol=1e-2) def test_deprecated_path_warning() -> None: diff --git a/mapie/utils.py b/mapie/utils.py index 51f9ec78d..a877cfa5d 100644 --- a/mapie/utils.py +++ b/mapie/utils.py @@ -543,13 +543,16 @@ def check_lower_upper_bounds( def check_conformity_score( conformity_score: Optional[ConformityScore], + sym: bool = True, ) -> ConformityScore: """ Check parameter ``conformity_score``. + Raises ------ ValueError If parameter is not valid. + Examples -------- >>> from mapie.utils import check_conformity_score @@ -562,7 +565,7 @@ def check_conformity_score( Must be None or a ConformityScore instance. """ if conformity_score is None: - return AbsoluteConformityScore() + return AbsoluteConformityScore(sym=sym) elif isinstance(conformity_score, ConformityScore): return conformity_score else: