Skip to content

Commit

Permalink
ENH: V1 conformalized quantile regressor implementation (#579)
Browse files Browse the repository at this point in the history
ENH: V1 CQR implmentation

---------

Co-authored-by: Valentin Laurent <[email protected]>
  • Loading branch information
jawadhussein462 and Valentin-Laurent committed Jan 7, 2025
1 parent 193e193 commit 35255e0
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 45 deletions.
6 changes: 4 additions & 2 deletions mapie_v1/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ def check_if_X_y_different_from_fit(

def make_intervals_single_if_single_alpha(
intervals: NDArray,
alphas: List[float]
alphas: Union[float, List[float]]
) -> NDArray:
if len(alphas) == 1:
if isinstance(alphas, float):
return intervals[:, :, 0]
if isinstance(alphas, list) and len(alphas) == 1:
return intervals[:, :, 0]
return intervals

Expand Down
24 changes: 18 additions & 6 deletions mapie_v1/integration_tests/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import QuantileRegressor
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split

from mapie.subsample import Subsample
from mapie._typing import ArrayLike
Expand Down Expand Up @@ -306,22 +307,29 @@ def test_intervals_and_predictions_exact_equality_jackknife(params_jackknife):
)
gbr_models.append(estimator_)

sample_weight_train = train_test_split(
X,
y,
sample_weight,
test_size=0.4,
random_state=RANDOM_STATE
)[-2]

params_test_cases_quantile = [
{
"v0": {
"alpha": 0.2,
"cv": "split",
"method": "quantile",
"calib_size": 0.3,
"calib_size": 0.4,
"sample_weight": sample_weight,
"random_state": RANDOM_STATE,
},
"v1": {
"confidence_level": 0.8,
"prefit": False,
"test_size": 0.3,
"fit_params": {"sample_weight": sample_weight},
"test_size": 0.4,
"fit_params": {"sample_weight": sample_weight_train},
"random_state": RANDOM_STATE,
},
},
Expand All @@ -330,15 +338,15 @@ def test_intervals_and_predictions_exact_equality_jackknife(params_jackknife):
"estimator": gbr_models,
"cv": "prefit",
"method": "quantile",
"calib_size": 0.3,
"calib_size": 0.2,
"sample_weight": sample_weight,
"optimize_beta": True,
"random_state": RANDOM_STATE,
},
"v1": {
"estimator": gbr_models,
"prefit": True,
"test_size": 0.3,
"test_size": 0.2,
"fit_params": {"sample_weight": sample_weight},
"minimize_interval_width": True,
"random_state": RANDOM_STATE,
Expand Down Expand Up @@ -418,12 +426,16 @@ def compare_model_predictions_and_intervals(
v1_params: Dict = {},
prefit: bool = False,
test_size: Optional[float] = None,
sample_weight: Optional[ArrayLike] = None,
random_state: int = 42,
) -> None:

if test_size is not None:
X_train, X_conf, y_train, y_conf = train_test_split_shuffle(
X, y, test_size=test_size, random_state=random_state
X,
y,
test_size=test_size,
random_state=random_state,
)
else:
X_train, X_conf, y_train, y_conf = X, X, y, y
Expand Down
145 changes: 108 additions & 37 deletions mapie_v1/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from typing_extensions import Self

import numpy as np
from sklearn.linear_model import LinearRegression, QuantileRegressor
from sklearn.linear_model import LinearRegression
from sklearn.base import RegressorMixin, clone
from sklearn.model_selection import BaseCrossValidator
from sklearn.pipeline import Pipeline

from mapie.subsample import Subsample
from mapie._typing import ArrayLike, NDArray
from mapie.conformity_scores import BaseRegressionScore
from mapie.regression import MapieRegressor
from mapie.regression import MapieRegressor, MapieQuantileRegressor
from mapie.utils import check_estimator_fit_predict
from mapie_v1.conformity_scores._utils import (
check_and_select_regression_conformity_score,
Expand Down Expand Up @@ -833,43 +834,54 @@ def predict(

class ConformalizedQuantileRegressor:
"""
A conformal quantile regression model that generates prediction intervals
using quantile regression as the base estimator.
A model that combines quantile regression with conformal prediction to
generate reliable prediction intervals with specified coverage levels.
This approach provides prediction intervals by leveraging
quantile predictions and applying conformal adjustments to ensure coverage.
The `ConformalizedQuantileRegressor` leverages quantile regression as its
base estimator to predict conditional quantiles of the target variable,
and applies conformal adjustments to ensure prediction intervals achieve
the desired confidence levels. This approach is particularly useful in
uncertainty quantification for regression tasks.
Parameters
----------
estimator : RegressorMixin, default=QuantileRegressor()
The base quantile regression estimator used to generate point and
interval predictions.
confidence_level : Union[float, List[float]], default=0.9
estimator : Union[`RegressorMixin`, `Pipeline`, \
`List[Union[RegressorMixin, Pipeline]]`]
The base quantile regression model(s) for estimating target quantiles.
- When `prefit=False` (default):
A single quantile regression estimator (e.g., `QuantileRegressor`)
or a pipeline that combines preprocessing and regression.
Supported Regression estimators:
- ``sklearn.linear_model.QuantileRegressor``
- ``sklearn.ensemble.GradientBoostingRegressor``
- ``sklearn.ensemble.HistGradientBoostingRegressor``
- ``lightgbm.LGBMRegressor``
- When `prefit=True`:
A list of three fitted quantile regression estimators corresponding
to lower, upper, and median quantiles. These estimators should be
pre-trained with consistent quantile settings:
* ``lower quantile = 1 - confidence_level / 2``
* ``upper quantile = confidence_level / 2``
* ``median quantile = 0.5``
confidence_level : float default=0.9
The confidence level(s) for the prediction intervals, indicating the
desired coverage probability of the prediction intervals. If a float
is provided, it represents a single confidence level. If a list,
multiple prediction intervals for each specified confidence level
are returned.
desired coverage probability of the prediction intervals.
conformity_score : Union[str, BaseRegressionScore], default="absolute"
The conformity score method used to calculate the conformity error.
Valid options: TODO : reference here the valid options, once the list
has been be created during the implementation.
See: TODO : reference conformity score classes or documentation
A custom score function inheriting from BaseRegressionScore may also
be provided.
random_state : Optional[Union[int, np.random.RandomState]], default=None
A seed or random state instance to ensure reproducibility in any random
operations within the regressor.
prefit : bool, default=False
If `True`, assumes the base estimators are already fitted.
When set to `True`, the `fit` method cannot be called and the
provided estimators should be pre-trained.
Methods
-------
fit(X_train, y_train, fit_params=None) -> Self
Fits the base estimator to the training data and initializes internal
parameters required for conformal prediction.
Trains the base quantile regression estimator on the provided data.
Not applicable if `prefit=True`.
conformalize(X_conf, y_conf, predict_params=None) -> Self
Calibrates the model on provided data, adjusting the prediction
Expand Down Expand Up @@ -904,12 +916,29 @@ class ConformalizedQuantileRegressor:

def __init__(
self,
estimator: RegressorMixin = QuantileRegressor(),
confidence_level: Union[float, List[float]] = 0.9,
conformity_score: Union[str, BaseRegressionScore] = "absolute",
random_state: Optional[Union[int, np.random.RandomState]] = None,
estimator: Optional[
Union[
RegressorMixin,
Pipeline,
List[Union[RegressorMixin, Pipeline]]
]
] = None,
confidence_level: float = 0.9,
prefit: bool = False,
) -> None:
pass

self._alpha = 1 - confidence_level
self.prefit = prefit

cv: str = "prefit" if prefit else "split"
self._mapie_quantile_regressor = MapieQuantileRegressor(
estimator=estimator,
method="quantile",
cv=cv,
alpha=self._alpha,
)

self._sample_weight: Optional[NDArray] = None

def fit(
self,
Expand Down Expand Up @@ -937,6 +966,27 @@ def fit(
Self
The fitted ConformalizedQuantileRegressor instance.
"""

if self.prefit:
raise ValueError(
"The estimators are already fitted, the .fit() method should"
" not be called with prefit=True."
)

if fit_params:
fit_params_ = copy.deepcopy(fit_params)
self._sample_weight = fit_params_.pop("sample_weight", None)
else:
fit_params_ = {}

self._mapie_quantile_regressor._initialize_fit_conformalize()
self._mapie_quantile_regressor._fit_estimators(
X=X_train,
y=y_train,
sample_weight=self._sample_weight,
**fit_params_,
)

return self

def conformalize(
Expand All @@ -948,7 +998,7 @@ def conformalize(
"""
Calibrates the model on the provided data, adjusting the prediction
intervals based on quantile predictions and specified confidence
levels. This step analyzes the conformity scores and refines the
level. This step analyzes the conformity scores and refines the
intervals to ensure desired coverage.
Parameters
Expand All @@ -969,6 +1019,14 @@ def conformalize(
The ConformalizedQuantileRegressor instance with calibrated
prediction intervals.
"""
self.predict_params = predict_params if predict_params else {}

self._mapie_quantile_regressor.conformalize(
X_conf,
y_conf,
**self.predict_params
)

return self

def predict_set(
Expand Down Expand Up @@ -1007,7 +1065,18 @@ def predict_set(
Prediction intervals with shape `(n_samples, 2)`, with lower
and upper bounds for each sample.
"""
return np.ndarray(0)
_, intervals = self._mapie_quantile_regressor.predict(
X,
optimize_beta=minimize_interval_width,
allow_infinite_bounds=allow_infinite_bounds,
symmetry=symmetric_intervals,
**self.predict_params
)

return make_intervals_single_if_single_alpha(
intervals,
self._alpha
)

def predict(
self,
Expand All @@ -1026,7 +1095,9 @@ def predict(
NDArray
Array of point predictions with shape `(n_samples,)`.
"""
return np.ndarray(0)
estimator = self._mapie_quantile_regressor
predictions, _ = estimator.predict(X, **self.predict_params)
return predictions


class GibbsConformalRegressor:
Expand Down

0 comments on commit 35255e0

Please sign in to comment.