Skip to content

Commit

Permalink
Update documentation according to the new regression API (missing /ex…
Browse files Browse the repository at this point in the history
…amples and /notebooks folders) (#594)
  • Loading branch information
Valentin-Laurent committed Jan 10, 2025
1 parent b50d8b9 commit 9c6e00c
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 64 deletions.
29 changes: 16 additions & 13 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ Here's a quick instantiation of MAPIE models for regression and classification p
.. code:: python
# Uncertainty quantification for regression problem
from mapie.regression import MapieRegressor
mapie_regressor = MapieRegressor(estimator=regressor, method='plus', cv=5)
from mapie_v1.regression import SplitConformalRegressor
mapie_regressor = SplitConformalRegressor(estimator=regressor)
.. code:: python
Expand Down Expand Up @@ -105,26 +105,29 @@ As **MAPIE** is compatible with the standard scikit-learn API, you can see that
- How easy it is **to wrap your favorite scikit-learn-compatible model** around your model.
- How easy it is **to follow the standard sequential** ``fit`` and ``predict`` process like any scikit-learn estimator.

.. code:: python
.. testcode::

# Uncertainty quantification for regression problem
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from mapie_v1.regression import SplitConformalRegressor

from mapie.regression import MapieRegressor
X, y = make_regression(n_samples=500, n_features=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
X, y = make_regression(n_samples=500, n_features=1, noise=20, random_state=59)
X_train_conf, X_test, y_train_conf, y_test = train_test_split(X, y, test_size=0.5)
X_train, X_conf, y_train, y_conf = train_test_split(X_train_conf, y_train_conf, test_size=0.5)

regressor = LinearRegression()
mapie_regressor = MapieRegressor(estimator=regressor, method='plus', cv=5)
mapie_regressor = mapie_regressor.fit(X_train, y_train)
y_pred, y_pis = mapie_regressor.predict(X_test, alpha=[0.05, 0.32])
mapie_regressor = SplitConformalRegressor(
regressor,
confidence_level=[0.95, 0.68],
)
mapie_regressor.fit(X_train, y_train)
mapie_regressor.conformalize(X_conf, y_conf)

y_pred = mapie_regressor.predict(X_test)
y_pred_intervals = mapie_regressor.predict_set(X_test)

.. code:: python
Expand Down
68 changes: 34 additions & 34 deletions doc/quick_start.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ To install directly from the github repository :
pip install git+https://github.com/scikit-learn-contrib/MAPIE
2. Run MapieRegressor
2. Regression
=====================

Let us start with a basic regression problem.
Let us start with a basic regression problem.
Here, we generate one-dimensional noisy data that we fit with a linear model.

.. code:: python
.. testcode::

import numpy as np
from sklearn.linear_model import LinearRegression
Expand All @@ -45,59 +45,60 @@ Here, we generate one-dimensional noisy data that we fit with a linear model.

regressor = LinearRegression()
X, y = make_regression(n_samples=500, n_features=1, noise=20, random_state=59)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
Since MAPIE is compliant with the standard scikit-learn API, we follow the standard
sequential ``fit`` and ``predict`` process like any scikit-learn regressor.
We set two values for alpha to estimate prediction intervals at approximately one
and two standard deviations from the mean.
X_train_conf, X_test, y_train_conf, y_test = train_test_split(X, y, test_size=0.5)
X_train, X_conf, y_train, y_conf = train_test_split(X_train_conf, y_train_conf,
test_size=0.5)

.. code:: python
# We follow a sequential ``fit``, ``conformalize``, and ``predict`` process.
# We set the confidence level to estimate prediction intervals at approximately one and two
# standard deviation from the mean.

from mapie.regression import MapieRegressor
from mapie_v1.regression import SplitConformalRegressor

mapie_regressor = MapieRegressor(regressor)
mapie_regressor = SplitConformalRegressor(
regressor,
confidence_level=[0.95, 0.68],
)
mapie_regressor.fit(X_train, y_train)
mapie_regressor.conformalize(X_conf, y_conf)

alpha = [0.05, 0.32]
y_pred, y_pred_intervals = mapie_regressor.predict(X_test, alpha=alpha)
y_pred = mapie_regressor.predict(X_test)
y_pred_intervals = mapie_regressor.predict_set(X_test)

MAPIE returns a tuple, the first element is a ``np.ndarray`` of shape ``(n_samples)`` giving the
predictions, and the second element a ``np.ndarray`` of shape ``(n_samples, 2, len(alpha))`` giving
the lower and upper bounds of the prediction intervals for the target quantile for each desired alpha value.
# MAPIE's ``predict`` method returns point predictions as a ``np.ndarray`` of shape ``(n_samples)``.
# The ``predict_set`` method returns prediction intervals as a ``np.ndarray`` of shape ``(n_samples, 2, 2)``
# giving the lower and upper bounds of the intervals for each confidence level.

You can compute the coverage of your prediction intervals.
# You can compute the coverage of your prediction intervals.

.. code:: python
from mapie.metrics import regression_coverage_score_v2

coverage_scores = regression_coverage_score_v2(y_test, y_pis)
The estimated prediction intervals can then be plotted as follows.
coverage_scores = regression_coverage_score_v2(y_test, y_pred_intervals)

.. code:: python
# The estimated prediction intervals can then be plotted as follows.

from matplotlib import pyplot as plt

confidence_level = [0.95, 0.68]

plt.xlabel("x")
plt.ylabel("y")
plt.scatter(X, y, alpha=0.3)
plt.plot(X_test, y_pred, color="C1")
order = np.argsort(X_test[:, 0])
plt.plot(X_test[order], y_pis[order][:, 0, 1], color="C1", ls="--")
plt.plot(X_test[order], y_pis[order][:, 1, 1], color="C1", ls="--")
plt.plot(X_test[order], y_pred_intervals[order, 0], color="C1", ls="--")
plt.plot(X_test[order], y_pred_intervals[order, 1], color="C1", ls="--")
plt.fill_between(
X_test[order].ravel(),
y_pis[order][:, 0, 0].ravel(),
y_pis[order][:, 1, 0].ravel(),
y_pred_intervals[order][:, 0, 0].ravel(),
y_pred_intervals[order][:, 1, 0].ravel(),
alpha=0.2
)
plt.title(
f"Target and effective coverages for "
f"alpha={alpha[0]:.2f}: ({1-alpha[0]:.3f}, {coverage_scores[0]:.3f})\n"
f"confidence_level={confidence_level[0]:.2f}: {coverage_scores[0]:.3f}\n"
f"Target and effective coverages for "
f"alpha={alpha[1]:.2f}: ({1-alpha[1]:.3f}, {coverage_scores[1]:.3f})"
f"confidence_level={confidence_level[1]:.2f}: {coverage_scores[1]:.3f}"
)
plt.show()

Expand All @@ -106,10 +107,9 @@ The estimated prediction intervals can then be plotted as follows.
:align: center

The title of the plot compares the target coverages with the effective coverages.
The target coverage, or the confidence interval, is the fraction of true labels lying in the
The target coverage, or the confidence level, is the fraction of true labels lying in the
prediction intervals that we aim to obtain for a given dataset.
It is given by the alpha parameter defined in ``MapieRegressor``, here equal to ``0.05`` and ``0.32``,
thus giving target coverages of ``0.95`` and ``0.68``.
It is given by the ``confidence_level`` parameter defined in ``SplitConformalRegressor``, here equal to ``0.95`` and ``0.68``.
The effective coverage is the actual fraction of true labels lying in the prediction intervals.

3. Run MapieClassifier
Expand Down Expand Up @@ -173,4 +173,4 @@ Similarly, it's possible to do the same for a basic classification problem.
.. image:: images/quickstart_2.png
:width: 400
:align: center
:align: center
3 changes: 2 additions & 1 deletion doc/theoretical_description_conformity_scores.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ it is not proportional to the uncertainty.
Key takeaways
-------------

- The absolute residual score is the basic conformity score and gives constant intervals. It is the one used by default by :class:`mapie.regression.MapieRegressor`.
- The absolute residual score is the basic conformity score and gives constant intervals. It is the one used by default by regression methods
such as :class:`mapie_v1.regression.SplitConformalRegressor`.
- The gamma conformity score adds a notion of adaptivity by giving intervals of different sizes
and is proportional to the uncertainty.
- The residual normalized score is a conformity score that requires an additional model
Expand Down
2 changes: 1 addition & 1 deletion doc/theoretical_description_regression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Theoretical Description
#######################

The :class:`mapie.regression.MapieRegressor` class uses various
The methods in `mapie_v1.regression` use various
resampling methods based on the jackknife strategy
recently introduced by Foygel-Barber et al. (2020) [1].
They allow the user to estimate robust prediction intervals with any kind of
Expand Down
28 changes: 13 additions & 15 deletions doc/v1_migration_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,29 +32,27 @@ In MAPIE v0.9, ``MapieRegressor`` managed all conformal regression methods under
2. Method changes
-----------------

In MAPIE v1, the conformal prediction workflow is more streamlined and modular, with distinct methods for training, calibration, and prediction. The calibration process in v1 consists of four steps.
In MAPIE v1, the conformal prediction workflow is more streamlined and modular, with distinct methods for training, conformalization (named calibration in the scientific literature), and prediction. The conformalization process in v1 consists of four steps.

Step 1: Data splitting
~~~~~~~~~~~~~~~~~~~~~~
In v0.9, Data splitting is done within two-phase process. First, data ``(X, y)`` was divided into training ``(X_train, y_train)`` and test ``(X_test, y_test)`` sets using ``train_test_split`` from ``sklearn``. In the second phase, the split between training and calibration was either done manually or handled internally by ``MapieRegressor``.
In v0.9, data splitting is handled by MAPIE.

In v1, a ``conf_split`` function has been introduced to split the data ``(X, y)`` into training ``(X_train, y_train)``, calibration ``(X_calib, y_calib)``, and test sets ``(X_test, y_test)``.
In v1, the data splitting is left to the user, with the exception of cross-conformal methods (``CrossConformalRegressor``). The user can split the data into training, conformalization, and test sets using scikit-learn's ``train_test_split`` or other methods.

This new approach in v1 gives users more control over data splitting, making it easier to manage training, calibration, and testing phases explicitly. The ``CrossConformalRegressor`` is an exception, where train/calibration splitting happens internally because cross-validation requires more granular control over data splits.
Step 2 & 3: Model training and conformalization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In v0.9, the ``fit`` method handled both model training and conformalization.

Step 2 & 3: Model training and calibration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In v0.9, the ``fit`` method handled both model training and calibration.

In v1.0: MAPIE separates between the training and calibration:
In v1.0: MAPIE separates between the training and conformalization:

- ``.fit()`` method:
- In v1, ``fit`` only trains the model on training data, without handling calibration.
- In v1, ``fit`` only trains the model on training data, without handling conformalization.
- Additional fitting parameters, like ``sample_weight``, should be included in ``fit_params``, keeping this method focused on training alone.

- ``.conformalize()`` method:
- This new method performs calibration after fitting, using separate calibration data ``(X_calib, y_calib)``.
- ``predict_params`` can be passed here, allowing independent control over calibration and prediction stages.
- This new method performs conformalization after fitting, using separate conformity data ``(X_conf, y_conf)``.
- ``predict_params`` can be passed here, allowing independent control over conformalization and prediction stages.

Step 4: Making predictions (``predict`` and ``predict_set`` methods)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -99,7 +97,7 @@ The ``cv`` parameter manages the cross-validation configuration, accepting eithe

``groups``
~~~~~~~~~~~
The ``groups`` parameter is used to specify group labels for cross-validation, ensuring that the same group is not present in both training and calibration sets.
The ``groups`` parameter is used to specify group labels for cross-validation, ensuring that the same group is not present in both training and conformity sets.

- **v0.9**: Passed as a parameter to the ``fit`` method.
- **v1**: The ``groups`` present is now only present in ``CrossConformalRegressor``. It is passed in the ``.conformalize()`` method instead of the ``.fit()`` method. In other classes (like ``SplitConformalRegressor``), groups can be directly handled by the user during data splitting.
Expand Down Expand Up @@ -130,7 +128,7 @@ Defines additional parameters exclusively for prediction.
The aggregation method and technique for combining predictions in ensemble methods.

- **v0.9**: Previously, the ``agg_function`` parameter had two usage: to aggregate predictions when setting ``ensemble=True`` in the ``predict`` method, and to specify the aggregation technique in ``JackknifeAfterBootstrapRegressor``.
- **v1**: The ``agg_function`` parameter has been split into two distinct parameters: ``aggregate_predictions`` and ``aggregation_method``. ``aggregate_predictions`` is specific to ``CrossConformalRegressor``, and it specifies how predictions from multiple conformal regressors are aggregated when making point predictions. ``aggregation_method`` is specific to ``JackknifeAfterBootstrapRegressor``, and it specifies the aggregation technique for combining predictions across different bootstrap samples during calibration.
- **v1**: The ``agg_function`` parameter has been split into two distinct parameters: ``aggregate_predictions`` and ``aggregation_method``. ``aggregate_predictions`` is specific to ``CrossConformalRegressor``, and it specifies how predictions from multiple conformal regressors are aggregated when making point predictions. ``aggregation_method`` is specific to ``JackknifeAfterBootstrapRegressor``, and it specifies the aggregation technique for combining predictions across different bootstrap samples during conformalization.

``Other parameters``
~~~~~~~~~~~~~~~~~~~~
Expand All @@ -155,7 +153,7 @@ Example 1: Split Conformal Prediction

Description
############
Split conformal prediction is a widely used method for generating prediction intervals, it splits the data into training, calibration, and test sets. The model is trained on the training set, calibrated on the calibration set, and then used to make predictions on the test set. In `MAPIE v1`, the `SplitConformalRegressor` replaces the older `MapieRegressor` with a more modular design and simplified API.
Split conformal prediction is a widely used method for generating prediction intervals, it splits the data into training, conformity, and test sets. The model is trained on the training set, calibrated on the conformity set, and then used to make predictions on the test set. In `MAPIE v1`, the `SplitConformalRegressor` replaces the older `MapieRegressor` with a more modular design and simplified API.

MAPIE v0.9 Code
###############
Expand Down

0 comments on commit 9c6e00c

Please sign in to comment.