ENH: implement several changes to the API:
 - make prefit=True default for SplitConformalRegressor
 - change predict_set to predict_interval, and make it return point predictions
 - make mean aggregation default for predictions in cross conformal methods
Valentin-Laurent committed Jan 22, 2025
1 parent d047838 commit 77e74f0
Showing 7 changed files with 139 additions and 167 deletions.
5 changes: 2 additions & 3 deletions README.rst
@@ -119,15 +119,14 @@ As **MAPIE** is compatible with the standard scikit-learn API, you can see that
X_train, X_conformalize, y_train, y_conformalize = train_test_split(X_train_conformalize, y_train_conformalize, test_size=0.5)

regressor = LinearRegression()
regressor.fit(X_train, y_train)
mapie_regressor = SplitConformalRegressor(
    regressor,
    confidence_level=[0.95, 0.68],
)
mapie_regressor.fit(X_train, y_train)
mapie_regressor.conformalize(X_conformalize, y_conformalize)

y_pred = mapie_regressor.predict(X_test)
y_pred_intervals = mapie_regressor.predict_set(X_test)
y_pred, y_pred_intervals = mapie_regressor.predict_interval(X_test)

.. code:: python
4 changes: 2 additions & 2 deletions doc/quick_start.rst
@@ -58,12 +58,12 @@ Here, we generate one-dimensional noisy data that we fit with a linear model.
mapie_regressor = SplitConformalRegressor(
    regressor,
    confidence_level=[0.95, 0.68],
    prefit=False,
)
mapie_regressor.fit(X_train, y_train)
mapie_regressor.conformalize(X_conformalize, y_conformalize)

y_pred = mapie_regressor.predict(X_test)
y_pred_intervals = mapie_regressor.predict_set(X_test)
y_pred, y_pred_intervals = mapie_regressor.predict_interval(X_test)

# MAPIE's ``predict`` method returns point predictions as a ``np.ndarray`` of shape ``(n_samples)``.
# The ``predict_set`` method returns prediction intervals as a ``np.ndarray`` of shape ``(n_samples, 2, 2)``
28 changes: 14 additions & 14 deletions doc/v1_migration_guide.rst
@@ -38,7 +38,7 @@ Step 1: Data splitting
~~~~~~~~~~~~~~~~~~~~~~
In v0.9, data splitting is handled by MAPIE.

In v1, the data splitting is left to the user, with the exception of cross-conformal methods (``CrossConformalRegressor``). The user can split the data into training, conformalization, and test sets using scikit-learn's ``train_test_split`` or other methods.
In v1, the data splitting is left to the user, with the exception of cross-conformal methods (``CrossConformalRegressor`` and ``JackknifeAfterBootstrapRegressor``). The user can split the data into training, conformalization, and test sets using scikit-learn's ``train_test_split`` or other methods.
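
For illustration, here is a minimal sketch of such a split (the dataset, split sizes, and random seeds below are arbitrary choices for the example, not taken from MAPIE):

.. code:: python

    from sklearn.datasets import make_regression
    from sklearn.model_selection import train_test_split

    X, y = make_regression(n_samples=500, n_features=2, noise=1.0, random_state=0)

    # First hold out a test set, then split the remainder into training
    # and conformalization sets.
    X_train_conformalize, X_test, y_train_conformalize, y_test = train_test_split(
        X, y, test_size=0.2, random_state=0
    )
    X_train, X_conformalize, y_train, y_conformalize = train_test_split(
        X_train_conformalize, y_train_conformalize, test_size=0.5, random_state=0
    )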

Step 2 & 3: Model training and conformalization (i.e., calibration)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -54,13 +54,12 @@ In v1.0: MAPIE separates between training and calibration. We decided to name th
- This new method performs conformalization after fitting, using separate conformalization data ``(X_conformalize, y_conformalize)``.
- ``predict_params`` can be passed here, allowing independent control over conformalization and prediction stages.
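
As a minimal sketch of this fit-then-conformalize workflow (continuing from the split above; the ``mapie_v1.regression`` import path is an assumption based on this commit's layout and may differ in your installed version):

.. code:: python

    from sklearn.linear_model import LinearRegression
    from mapie_v1.regression import SplitConformalRegressor

    mapie_regressor = SplitConformalRegressor(
        LinearRegression(), confidence_level=0.9, prefit=False
    )
    mapie_regressor.fit(X_train, y_train)                         # training step
    mapie_regressor.conformalize(X_conformalize, y_conformalize)  # calibration step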

Step 4: Making predictions (``predict`` and ``predict_set`` methods)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Step 4: Making predictions (``predict`` and ``predict_interval`` methods)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In MAPIE v0.9, both point predictions and prediction intervals were produced through the ``predict`` method.

MAPIE v1 introduces two distinct methods for prediction:
- ``.predict_set()`` is dedicated to generating prediction intervals (i.e., lower and upper bounds), clearly separating interval predictions from point predictions.
- ``.predict()`` now focuses solely on producing point predictions.
MAPIE v1 introduces a new prediction method, ``.predict_interval()``, which behaves like the v0.9 ``.predict(alpha=...)`` method: it returns both point predictions and prediction intervals.
The ``.predict()`` method now focuses solely on producing point predictions.
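
Continuing the sketch above, the two methods are used as follows (the interval shape is the one checked in this commit's integration tests):

.. code:: python

    # Point predictions and prediction intervals in a single call
    y_pred, y_pred_intervals = mapie_regressor.predict_interval(X_test)
    # y_pred has shape (n_samples,);
    # y_pred_intervals has shape (n_samples, 2, n_confidence_levels)

    # Point predictions only
    y_pred_points = mapie_regressor.predict(X_test)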



@@ -107,7 +106,7 @@ The ``groups`` parameter is used to specify group labels for cross-validation, e
Controls whether the model has been pre-fitted before applying conformal prediction.

- **v0.9**: Indicated through ``cv="prefit"`` in ``MapieRegressor``.
- **v1**: ``prefit`` is now a separate boolean parameter, allowing explicit control over whether the model has been pre-fitted before applying conformal methods.
- **v1**: ``prefit`` is now a separate boolean parameter, allowing explicit control over whether the model has been pre-fitted before applying conformal methods. It defaults to ``True`` for ``SplitConformalRegressor``, as we believe this will become the nominal MAPIE usage.
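
A sketch contrasting the two settings (the variable names mirror the split conformal example later in this guide; ``prefit_model`` stands for any already-fitted scikit-learn regressor, and the import path is an assumption based on this commit's layout):

.. code:: python

    from sklearn.linear_model import LinearRegression
    from mapie_v1.regression import SplitConformalRegressor

    # prefit=True (the new default for SplitConformalRegressor): the estimator
    # is already trained, so only the conformalization step is needed.
    v1_prefit = SplitConformalRegressor(estimator=prefit_model, confidence_level=0.9)
    v1_prefit.conformalize(X_conf, y_conf)

    # prefit=False: MAPIE trains the estimator itself before conformalization.
    v1_unfit = SplitConformalRegressor(
        estimator=LinearRegression(), confidence_level=0.9, prefit=False
    )
    v1_unfit.fit(X_train, y_train)
    v1_unfit.conformalize(X_conf, y_conf)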

``fit_params`` (includes ``sample_weight``)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -125,10 +124,12 @@ Defines additional parameters exclusively for prediction.

``agg_function``, ``aggregation_method``, ``aggregate_predictions``, and ``ensemble``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The aggregation method and technique for combining predictions in ensemble methods.
The aggregation method and technique for combining predictions in cross-conformal methods.

- **v0.9**: The ``agg_function`` parameter served two purposes: to aggregate predictions when setting ``ensemble=True`` in the ``predict`` method, and to specify the aggregation technique in ``JackknifeAfterBootstrapRegressor``.
- **v1**: The ``agg_function`` parameter has been split into two distinct parameters: ``aggregate_predictions`` and ``aggregation_method``. ``aggregate_predictions`` is specific to ``CrossConformalRegressor``, and it specifies how predictions from multiple conformal regressors are aggregated when making point predictions. ``aggregation_method`` is specific to ``JackknifeAfterBootstrapRegressor``, and it specifies the aggregation technique for combining predictions across different bootstrap samples during conformalization.
- **v1**:
- The ``agg_function`` parameter has been split into two distinct parameters: ``aggregate_predictions`` and ``aggregation_method``. ``aggregate_predictions`` is specific to ``CrossConformalRegressor``, and it specifies how predictions from multiple conformal regressors are aggregated when making point predictions. ``aggregation_method`` is specific to ``JackknifeAfterBootstrapRegressor``, and it specifies the aggregation technique for combining predictions across different bootstrap samples during conformalization.
  - Note that for both cross-conformal methods, point predictions are now computed by default using mean aggregation; this avoids point predictions falling outside the prediction intervals in the default setting (see the sketch below).
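
A sketch of where each parameter now lives (the data names mirror the cross-conformal example later in this guide; constructor arguments other than ``aggregation_method``, and the ``mapie_v1.regression`` import path, are assumptions):

.. code:: python

    from sklearn.linear_model import LinearRegression
    from mapie_v1.regression import (
        CrossConformalRegressor,
        JackknifeAfterBootstrapRegressor,
    )

    # aggregate_predictions is passed at predict time on CrossConformalRegressor;
    # by default, point predictions are aggregated with the mean.
    cross = CrossConformalRegressor(estimator=LinearRegression(), confidence_level=0.9)
    cross.fit(X, y)
    cross.conformalize(X, y)
    point_predictions = cross.predict(X_test, aggregate_predictions="median")

    # aggregation_method is a constructor parameter of JackknifeAfterBootstrapRegressor;
    # it controls how bootstrap predictions are combined during conformalization.
    jackknife = JackknifeAfterBootstrapRegressor(
        estimator=LinearRegression(), confidence_level=0.9, aggregation_method="mean"
    )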

``random_state``
~~~~~~~~~~~~~~~~~~
@@ -189,7 +190,7 @@ Below is a MAPIE v0.9 code for split conformal prediction in case of pre-fitted

v0.fit(X_conf, y_conf)

prediction_intervals_v0 = v0.predict(X_test, alpha=0.1)[1][:, :, 0]
prediction_points_v0, prediction_intervals_v0 = v0.predict(X_test, alpha=0.1)
prediction_points_v0 = v0.predict(X_test)

Equivalent MAPIE v1 code
@@ -215,13 +216,12 @@ Below is the equivalent MAPIE v1 code for split conformal prediction:
    estimator=prefit_model,
    confidence_level=0.9,
    conformity_score="residual_normalized",
    prefit=True
)

# Here we're not using v1.fit(), because the provided model is already fitted
v1.conformalize(X_conf, y_conf)

prediction_intervals_v1 = v1.predict_set(X_test)
prediction_points_v1, prediction_intervals_v1 = v1.predict_interval(X_test)
prediction_points_v1 = v1.predict(X_test)

Example 2: Cross-Conformal Prediction
@@ -263,7 +263,7 @@ Below is a MAPIE v0.9 code for cross-conformal prediction:

v0.fit(X, y, sample_weight=sample_weight, groups=groups)

prediction_intervals_v0 = v0.predict(X_test, alpha=0.1)[1][:, :, 0]
prediction_points_v0, prediction_intervals_v0 = v0.predict(X_test, alpha=0.1)
prediction_points_v0 = v0.predict(X_test, ensemble=True)

Equivalent MAPIE v1 code
@@ -299,5 +299,5 @@ Below is the equivalent MAPIE v1 code for cross-conformal prediction:
v1.fit(X, y, fit_params={"sample_weight": sample_weight})
v1.conformalize(X, y, groups=groups)

prediction_intervals_v1 = v1.predict_set(X_test)
prediction_points_v1, prediction_intervals_v1 = v1.predict_interval(X_test)
prediction_points_v1 = v1.predict(X_test, aggregate_predictions="median")
18 changes: 6 additions & 12 deletions mapie_v1/_utils.py
@@ -52,19 +52,13 @@ def check_if_X_y_different_from_fit(
)


def make_intervals_single_if_single_alpha(
    intervals: NDArray,
    alphas: Union[float, List[float]]
) -> NDArray:
    if isinstance(alphas, float):
        return intervals[:, :, 0]
    if isinstance(alphas, list) and len(alphas) == 1:
        return intervals[:, :, 0]
    return intervals


def cast_point_predictions_to_ndarray(
    point_predictions: Union[NDArray, Tuple[NDArray, NDArray]]
) -> NDArray:
    # This will be useless when we split .predict and .predict_set in back-end
    return cast(NDArray, point_predictions)


def cast_predictions_to_ndarray_tuple(
    predictions: Union[NDArray, Tuple[NDArray, NDArray]]
) -> Tuple[NDArray, NDArray]:
    return cast(Tuple[NDArray, NDArray], predictions)
2 changes: 1 addition & 1 deletion mapie_v1/classification.py
@@ -18,7 +18,7 @@ def __init__(
        estimator: ClassifierMixin = LogisticRegression(),
        confidence_level: Union[float, List[float]] = 0.9,
        conformity_score: Union[str, BaseClassificationScore] = "lac",
        prefit: bool = False,
        prefit: bool = True,
        n_jobs: Optional[int] = None,
        verbose: int = 0,
        random_state: Optional[Union[int, np.random.RandomState]] = None,
26 changes: 20 additions & 6 deletions mapie_v1/integration_tests/tests/test_regression.py
@@ -126,6 +126,7 @@
"v1": {
"estimator": positive_predictor,
"confidence_level": 0.9,
"prefit": False,
"conformity_score": GammaConformityScore(),
"test_size": 0.3,
"minimize_interval_width": True
@@ -183,6 +184,8 @@ def test_intervals_and_predictions_exact_equality_split(params_split):
"alpha": [0.5, 0.5],
"conformity_score": GammaConformityScore(),
"cv": LeaveOneOut(),
"agg_function": "mean",
"ensemble": True,
"method": "plus",
"optimize_beta": True,
"random_state": RANDOM_STATE,
@@ -210,6 +213,7 @@ def test_intervals_and_predictions_exact_equality_split(params_split):
"cv": GroupKFold(),
"groups": groups,
"method": "minmax",
"aggregate_predictions": None,
"allow_infinite_bounds": True,
"random_state": RANDOM_STATE,
}
@@ -262,6 +266,7 @@ def test_intervals_and_predictions_exact_equality_cross(params_cross):
"alpha": [0.5, 0.5],
"conformity_score": GammaConformityScore(),
"agg_function": "mean",
"ensemble": True,
"cv": Subsample(n_resamplings=20,
replace=True,
random_state=RANDOM_STATE),
@@ -448,9 +453,15 @@ def compare_model_predictions_and_intervals(
    v1_params: Dict = {},
    prefit: bool = False,
    test_size: Optional[float] = None,
    sample_weight: Optional[ArrayLike] = None,
    random_state: int = RANDOM_STATE,
) -> None:
    if v0_params.get("alpha"):
        if isinstance(v0_params["alpha"], float):
            n_alpha = 1
        else:
            n_alpha = len(v0_params["alpha"])
    else:
        n_alpha = 1

    if test_size is not None:
        X_train, X_conf, y_train, y_conf = train_test_split_shuffle(
@@ -496,14 +507,17 @@ v0_predict_params.pop('alpha')
        v0_predict_params.pop('alpha')

    v1_predict_params = filter_params(v1.predict, v1_params)
    v1_predict_set_params = filter_params(v1.predict_set, v1_params)
    v1_predict_interval_params = filter_params(v1.predict_interval, v1_params)

    v0_preds, v0_pred_intervals = v0.predict(X_conf, **v0_predict_params)
    v1_pred_intervals = v1.predict_set(X_conf, **v1_predict_set_params)
    if v1_pred_intervals.ndim == 2:
        v1_pred_intervals = np.expand_dims(v1_pred_intervals, axis=2)
    v1_preds, v1_pred_intervals = v1.predict_interval(
        X_conf, **v1_predict_interval_params
    )

    v1_preds: ArrayLike = v1.predict(X_conf, **v1_predict_params)
    v1_preds_using_predict: ArrayLike = v1.predict(X_conf, **v1_predict_params)

    np.testing.assert_array_equal(v0_preds, v1_preds)
    np.testing.assert_array_equal(v0_pred_intervals, v1_pred_intervals)
    np.testing.assert_array_equal(v1_preds_using_predict, v1_preds)
    if not v0_params.get("optimize_beta"):
        assert v1_pred_intervals.shape == (len(X_conf), 2, n_alpha)