Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: V1 cqr implementation #576

87 changes: 50 additions & 37 deletions mapie/regression/quantile_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def fit(
if self.cv == "prefit":
X_calib, y_calib = X, y
else:
X_calib, y_calib = self._fit_estimators(
result = self._prepare_train_calib(
X=X,
y=y,
sample_weight=sample_weight,
Expand All @@ -533,7 +533,13 @@ def fit(
random_state=random_state,
shuffle=shuffle,
stratify=stratify,
**fit_params,
)
X_train, y_train, X_calib, y_calib, sample_weight = result
self._fit_estimators(
X=X_train,
y=y_train,
sample_weight=sample_weight,
**fit_params
)

self.conformalize(X_calib, y_calib)
Expand All @@ -551,7 +557,7 @@ def _initialize_and_check_prefit_estimators(self) -> None:
self.estimators_ = list(estimator)
self.single_estimator_ = self.estimators_[2]

def _fit_estimators(
def _prepare_train_calib(
self,
X: ArrayLike,
y: ArrayLike,
Expand All @@ -563,74 +569,81 @@ def _fit_estimators(
random_state: Optional[Union[int, np.random.RandomState]] = None,
shuffle: Optional[bool] = True,
stratify: Optional[ArrayLike] = None,
**fit_params,
) -> Tuple[ArrayLike, ArrayLike]:
) -> Tuple[
ArrayLike, ArrayLike, ArrayLike, ArrayLike, Optional[ArrayLike]
]:
"""
This method:
- Creates train and calib sets
- Checks and casts params, including the train set
- Fits the 3 estimators
- Returns the calib set
Handles the preparation of training and calibration datasets,
including validation and splitting.
Returns: X_train, y_train, X_calib, y_calib, sample_weight_train
"""

self._check_parameters()
checked_estimator = self._check_estimator(self.estimator)
random_state = check_random_state(random_state)
X, y = indexable(X, y)

if X_calib is None or y_calib is None:
(
X_train, y_train, X_calib, y_calib, sample_weight_train
) = self._train_calib_split(
return self._train_calib_split(
X,
y,
sample_weight,
calib_size,
random_state,
shuffle,
stratify,
stratify
)
else:
X_train, y_train, sample_weight_train = X, y, sample_weight
return X, y, X_calib, y_calib, sample_weight

X_train, y_train = cast(ArrayLike, X_train), cast(ArrayLike, y_train)
sample_weight_train = cast(ArrayLike, sample_weight_train)
X_train, y_train = indexable(X_train, y_train)
y_train = _check_y(y_train)
# Second function: Handles estimator fitting
def _fit_estimators(
self,
X: ArrayLike,
y: ArrayLike,
sample_weight: Optional[ArrayLike] = None,
**fit_params
) -> None:
"""
Fits the estimators with provided training data
and stores them in self.estimators_.
"""
checked_estimator = self._check_estimator(self.estimator)

sample_weight_train, X_train, y_train = check_null_weight(
sample_weight_train,
X_train,
y_train
X, y = indexable(X, y)
y = _check_y(y)

sample_weight, X, y = check_null_weight(
sample_weight, X, y
)
y_train = cast(NDArray, y_train)

if isinstance(checked_estimator, Pipeline):
estimator = checked_estimator[-1]
else:
estimator = checked_estimator

name_estimator = estimator.__class__.__name__
alpha_name = self.quantile_estimator_params[
name_estimator
]["alpha_name"]
alpha_name = self.quantile_estimator_params[name_estimator][
"alpha_name"
]

for i, alpha_ in enumerate(self.alpha_np):
cloned_estimator_ = clone(checked_estimator)
params = {alpha_name: alpha_}
if isinstance(checked_estimator, Pipeline):
cloned_estimator_[-1].set_params(**params)
else:
cloned_estimator_.set_params(**params)
self.estimators_.append(fit_estimator(
cloned_estimator_,
X_train,
y_train,
sample_weight_train,
**fit_params,

self.estimators_.append(
fit_estimator(
cloned_estimator_,
X,
y,
sample_weight,
**fit_params,
)
)
self.single_estimator_ = self.estimators_[2]

return X_calib, y_calib
self.single_estimator_ = self.estimators_[2]

def conformalize(
self,
Expand Down
6 changes: 4 additions & 2 deletions mapie_v1/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ def check_if_X_y_different_from_fit(

def make_intervals_single_if_single_alpha(
intervals: NDArray,
alphas: List[float]
alphas: Union[float, List[float]]
) -> NDArray:
if len(alphas) == 1:
if isinstance(alphas, float):
return intervals[:, :, 0]
if isinstance(alphas, list) and len(alphas) == 1:
return intervals[:, :, 0]
return intervals

Expand Down
Loading
Loading