Skip to content

Commit

Permalink
FIX: merge with master
Browse files Browse the repository at this point in the history
  • Loading branch information
LacombeLouis committed Feb 26, 2024
1 parent cdf6cdd commit 05003e6
Show file tree
Hide file tree
Showing 13 changed files with 305 additions and 94 deletions.
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ Contributors
* Rafael Saraiva <[email protected]>
* Mehdi Elion <[email protected]>
* Sami Kaddani <[email protected]>
* Pierre de Fréminville <pidefrem>
To be continued ...
10 changes: 10 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@ History

##### (##########)
------------------
* Allow the use of `y` and `groups` arguments in cross validator methods `get_n_splits`
and `split` to enable more cv-split variants for :class:`~regression.regression.MapieRegressor`
and :class:`~classification.MapieClassifier`
(e.g. :class:`sklearn.model_selection.GroupKFold`, stratified continuous split).
This change adds the `groups` argument to the following methods:
:meth:`~estimator.interface.EnsembleEstimator.fit()`,
:meth:`~estimator.estimator.EnsembleRegressor.predict_calib()`, :meth:`~estimator.estimator.EnsembleRegressor.fit()`,
:meth:`~regression.regression.MapieRegressor.fit()`,
:meth:`~classification.MapieClassifier.fit()`.
* Add possibility of passing fit parameters used by estimators.
* Fix memory issue CQR when testing for upper and lower bounds.

0.8.0 (2024-01-03)
------------------
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.PHONY: tests doc build

lint:
lint:
flake8 . --exclude=doc

type-check:
Expand Down
16 changes: 14 additions & 2 deletions mapie/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,7 @@ def fit(
y: ArrayLike,
sample_weight: Optional[ArrayLike] = None,
size_raps: Optional[float] = .2,
groups: Optional[ArrayLike] = None,
**fit_params,
) -> MapieClassifier:
"""
Expand Down Expand Up @@ -1081,10 +1082,15 @@ def fit(
By default ``.2``.
groups: Optional[ArrayLike] of shape (n_samples,)
Group labels for the samples used while splitting the dataset into
train/test set.
By default ``None``.
**fit_params : dict
Additional fit parameters.
Returns
-------
MapieClassifier
Expand All @@ -1099,6 +1105,7 @@ def fit(
y = _check_y(y)

sample_weight = cast(Optional[NDArray], sample_weight)
groups = cast(Optional[NDArray], groups)
sample_weight, X, y = check_null_weight(sample_weight, X, y)

y = cast(NDArray, y)
Expand Down Expand Up @@ -1147,6 +1154,9 @@ def fit(
if sample_weight is not None:
sample_weight = sample_weight[train_raps_index]
sample_weight = cast(NDArray, sample_weight)
if groups is not None:
groups = groups[train_raps_index]
groups = cast(NDArray, groups)

# Work
if cv == "prefit":
Expand Down Expand Up @@ -1174,7 +1184,9 @@ def fit(
sample_weight,
**fit_params,
)
for k, (train_index, val_index) in enumerate(cv.split(X))
for k, (train_index, val_index) in enumerate(
cv.split(X, y_enc, groups)
)
)
(
self.estimators_,
Expand Down
38 changes: 32 additions & 6 deletions mapie/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,12 @@ def _pred_multi(self, X: ArrayLike) -> NDArray:
y_pred_multi = self._aggregate_with_mask(y_pred_multi, self.k_)
return y_pred_multi

def predict_calib(self, X: ArrayLike) -> NDArray:
def predict_calib(
self,
X: ArrayLike,
y: Optional[ArrayLike] = None,
groups: Optional[ArrayLike] = None
) -> NDArray:
"""
Perform predictions on X : the calibration set.
Expand All @@ -339,6 +344,17 @@ def predict_calib(self, X: ArrayLike) -> NDArray:
X: ArrayLike of shape (n_samples_test, n_features)
Input data
y: Optional[ArrayLike] of shape (n_samples_test,)
Input labels.
By default ``None``.
groups: Optional[ArrayLike] of shape (n_samples_test,)
Group labels for the samples used while splitting the dataset into
train/test set.
By default ``None``.
Returns
-------
NDArray of shape (n_samples_test, 1)
Expand All @@ -357,15 +373,17 @@ def predict_calib(self, X: ArrayLike) -> NDArray:
delayed(self._predict_oof_estimator)(
estimator, X, calib_index,
)
for (_, calib_index), estimator in zip(cv.split(X),
self.estimators_)
for (_, calib_index), estimator in zip(
cv.split(X, y, groups),
self.estimators_
)
)
predictions, indices = map(
list, zip(*outputs)
)
n_samples = _num_samples(X)
pred_matrix = np.full(
shape=(n_samples, cv.get_n_splits(X)),
shape=(n_samples, cv.get_n_splits(X, y, groups)),
fill_value=np.nan,
dtype=float,
)
Expand All @@ -385,6 +403,7 @@ def fit(
X: ArrayLike,
y: ArrayLike,
sample_weight: Optional[ArrayLike] = None,
groups: Optional[ArrayLike] = None,
**fit_params,
) -> EnsembleRegressor:
"""
Expand All @@ -404,6 +423,13 @@ def fit(
sample_weight: Optional[ArrayLike] of shape (n_samples,)
Sample weights. If None, then samples are equally weighted.
By default ``None``.
groups: Optional[ArrayLike] of shape (n_samples,)
Group labels for the samples used while splitting the dataset into
train/test set.
By default ``None``.
**fit_params : dict
Expand Down Expand Up @@ -440,7 +466,7 @@ def fit(
)
cv = cast(BaseCrossValidator, cv)
self.k_ = np.full(
shape=(n_samples, cv.get_n_splits(X, y)),
shape=(n_samples, cv.get_n_splits(X, y, groups)),
fill_value=np.nan,
dtype=float,
)
Expand All @@ -456,7 +482,7 @@ def fit(
sample_weight,
**fit_params
)
for train_index, _ in cv.split(X)
for train_index, _ in cv.split(X, y, groups)
)
# In split-CP, we keep only the model fitted on train dataset
if self.use_split_method_:
Expand Down
6 changes: 6 additions & 0 deletions mapie/estimator/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def fit(
X: ArrayLike,
y: ArrayLike,
sample_weight: Optional[ArrayLike] = None,
groups: Optional[ArrayLike] = None,
**fit_params
) -> EnsembleEstimator:
"""
Expand All @@ -42,6 +43,11 @@ def fit(
Sample weights. If None, then samples are equally weighted.
By default ``None``.
groups: Optional[ArrayLike] of shape (n_samples,)
Group labels for the samples used while splitting the dataset into
train/test set.
By default ``None``.
**fit_params : dict
Additional fit parameters.
Expand Down
7 changes: 6 additions & 1 deletion mapie/regression/quantile_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ def fit(
X: ArrayLike,
y: ArrayLike,
sample_weight: Optional[ArrayLike] = None,
groups: Optional[ArrayLike] = None,
X_calib: Optional[ArrayLike] = None,
y_calib: Optional[ArrayLike] = None,
calib_size: Optional[float] = 0.3,
Expand Down Expand Up @@ -499,6 +500,9 @@ def fit(
By default ``None``.
groups: Optional[ArrayLike] of shape (n_samples,)
Always ignored, exists for compatibility.
X_calib: Optional[ArrayLike] of shape (n_calib_samples, n_features)
Calibration data.
Expand Down Expand Up @@ -696,6 +700,7 @@ def predict(
)
for i, est in enumerate(self.estimators_):
y_preds[i] = est.predict(X)
check_lower_upper_bounds(y_preds[0], y_preds[1], y_preds[2])
if symmetry:
quantile = np.full(
2,
Expand All @@ -716,5 +721,5 @@ def predict(
)
y_pred_low = y_preds[0][:, np.newaxis] - quantile[0]
y_pred_up = y_preds[1][:, np.newaxis] + quantile[1]
check_lower_upper_bounds(y_preds, y_pred_low, y_pred_up)
check_lower_upper_bounds(y_pred_low, y_pred_up, y_preds[2])
return y_preds[2], np.stack([y_pred_low, y_pred_up], axis=1)
27 changes: 23 additions & 4 deletions mapie/regression/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ def _check_fit_parameters(
X: ArrayLike,
y: ArrayLike,
sample_weight: Optional[ArrayLike] = None,
groups: Optional[ArrayLike] = None
):
"""
Perform several checks on class parameters.
Expand All @@ -407,6 +408,11 @@ def _check_fit_parameters(
sample_weight: Optional[NDArray] of shape (n_samples,)
Non-null sample weights.
groups: Optional[ArrayLike] of shape (n_samples,)
Group labels for the samples used while splitting the dataset into
train/test set.
By default ``None``.
Raises
------
ValueError
Expand Down Expand Up @@ -449,14 +455,21 @@ def _check_fit_parameters(
X = cast(NDArray, X)
y = cast(NDArray, y)
sample_weight = cast(Optional[NDArray], sample_weight)
groups = cast(Optional[NDArray], groups)

return estimator, cs_estimator, agg_function, cv, X, y, sample_weight
return (
estimator, cs_estimator,
agg_function, cv,
X, y,
sample_weight, groups
)

def fit(
self,
X: ArrayLike,
y: ArrayLike,
sample_weight: Optional[ArrayLike] = None,
groups: Optional[ArrayLike] = None,
**fit_params,
) -> MapieRegressor:
"""
Expand Down Expand Up @@ -485,6 +498,11 @@ def fit(
By default ``None``.
groups: Optional[ArrayLike] of shape (n_samples,)
Group labels for the samples used while splitting the dataset into
train/test set.
By default ``None``.
**fit_params : dict
Additional fit parameters.
Expand All @@ -500,7 +518,8 @@ def fit(
cv,
X,
y,
sample_weight) = self._check_fit_parameters(X, y, sample_weight)
sample_weight,
groups) = self._check_fit_parameters(X, y, sample_weight, groups)

self.estimator_ = EnsembleRegressor(
estimator,
Expand All @@ -514,11 +533,11 @@ def fit(
)
# Fit the prediction function
self.estimator_ = self.estimator_.fit(
X, y, sample_weight, **fit_params
X, y, sample_weight=sample_weight, groups=groups, **fit_params
)

# Predict on calibration data
y_pred = self.estimator_.predict_calib(X)
y_pred = self.estimator_.predict_calib(X, y=y, groups=groups)

# Compute the conformity scores (manage jk-ab case)
self.conformity_scores_ = \
Expand Down
4 changes: 2 additions & 2 deletions mapie/subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
self.random_state = random_state

def split(
self, X: NDArray
self, X: NDArray, *args: Any, **kargs: Any
) -> Generator[Tuple[NDArray, NDArray], None, None]:
"""
Generate indices to split data into training and test sets.
Expand Down Expand Up @@ -154,7 +154,7 @@ def __init__(
self.random_state = random_state

def split(
self, X: NDArray
self, X: NDArray, *args: Any, **kargs: Any
) -> Generator[Tuple[NDArray, NDArray], None, None]:
"""
Generate indices to split data into training and test sets.
Expand Down
Loading

0 comments on commit 05003e6

Please sign in to comment.