From d8665e46788839111310f7b7aed31b359eb3a8c7 Mon Sep 17 00:00:00 2001 From: Valentin Laurent Date: Mon, 6 Jan 2025 15:43:19 +0100 Subject: [PATCH] REFACTO: in split setting, remove checking NaNs and irrelevant aggregation to avoid triggering unwanted warnings (#586) * REFACTO: in split setting, remove checking NaNs and irrelevant aggregation to avoid triggering unwanted warnings --- mapie/estimator/regressor.py | 7 +++++-- mapie/tests/test_regression.py | 2 +- mapie/tests/test_time_series_regression.py | 3 ++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mapie/estimator/regressor.py b/mapie/estimator/regressor.py index bad8988ca..ddf778e02 100644 --- a/mapie/estimator/regressor.py +++ b/mapie/estimator/regressor.py @@ -402,9 +402,12 @@ def predict_calib( predictions[i], dtype=float ) self.k_[ind, i] = 1 - check_nan_in_aposteriori_prediction(pred_matrix) - y_pred = aggregate_all(self.agg_function, pred_matrix) + if self.use_split_method_: + y_pred = pred_matrix.flatten() + else: + check_nan_in_aposteriori_prediction(pred_matrix) + y_pred = aggregate_all(self.agg_function, pred_matrix) return y_pred diff --git a/mapie/tests/test_regression.py b/mapie/tests/test_regression.py index e062a3704..f06fff2e3 100644 --- a/mapie/tests/test_regression.py +++ b/mapie/tests/test_regression.py @@ -701,7 +701,7 @@ def test_not_enough_resamplings() -> None: """ with pytest.warns(UserWarning, match=r"WARNING: at least one point of*"): mapie_reg = MapieRegressor( - cv=Subsample(n_resamplings=1), agg_function="mean" + cv=Subsample(n_resamplings=2, random_state=0), agg_function="mean" ) mapie_reg.fit(X, y) diff --git a/mapie/tests/test_time_series_regression.py b/mapie/tests/test_time_series_regression.py index 785cb9088..77e4607b4 100644 --- a/mapie/tests/test_time_series_regression.py +++ b/mapie/tests/test_time_series_regression.py @@ -318,7 +318,8 @@ def test_not_enough_resamplings() -> None: match=r"WARNING: at least one point of*" ): mapie_ts_reg = MapieTimeSeriesRegressor( - cv=BlockBootstrap(n_resamplings=1, n_blocks=1), agg_function="mean" + cv=BlockBootstrap(n_resamplings=2, n_blocks=1, random_state=0), + agg_function="mean" ) mapie_ts_reg.fit(X, y)