Skip to content

Commit

Permalink
UPD: change name of alpha method
Browse files Browse the repository at this point in the history
  • Loading branch information
Thibault Cordier committed Dec 21, 2023
1 parent e8c2ba4 commit 176bbaf
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions mapie/regression/time_series_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def partial_fit(
] = new_conformity_scores_
return self

def init_alpha(
def _get_alpha(
self,
alpha: Optional[Union[float, Iterable[float]]] = None,
reset: bool = False
Expand Down Expand Up @@ -251,7 +251,7 @@ def adapt_conformal_inference(
X, y = cast(NDArray, X), cast(NDArray, y)
X, y = convert_to_numpy(X, y)

self.init_alpha()
self._get_alpha()
alpha = cast(Optional[NDArray], check_alpha(alpha))
if alpha is None:
alpha = np.array(list(self.current_alpha.keys()))
Expand Down Expand Up @@ -387,7 +387,7 @@ def predict(
super().predict(X, ensemble, alpha, optimize_beta)

if self.method == "aci":
alpha = self.init_alpha(alpha)
alpha = self._get_alpha(alpha)

return super().predict(X, ensemble, alpha, optimize_beta)

Expand Down
8 changes: 4 additions & 4 deletions mapie/tests/test_time_series_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,17 +477,17 @@ def test_aci_method() -> None:


def test_aci_init_and_reset_alpha_dict() -> None:
"""Test that `init_alpha` resets all the values in the dictionary."""
"""Test that `_get_alpha` resets all the values in the dictionary."""
mapie_ts_reg = MapieTimeSeriesRegressor(method="aci")
mapie_ts_reg.init_alpha()
mapie_ts_reg._get_alpha()
np.testing.assert_equal(isinstance(mapie_ts_reg.current_alpha, dict), True)

mapie_ts_reg.current_alpha[0.05] = 0.45
mapie_ts_reg.init_alpha(reset=True)
mapie_ts_reg._get_alpha(reset=True)
np.testing.assert_equal(bool(mapie_ts_reg.current_alpha), False)


def test_aci_init_alpha_with_unknown_alpha() -> None:
def test_aci__get_alpha_with_unknown_alpha() -> None:
"""
Test that the `adapt_conformal_inference` method initializes
a new value if alpha is seen for the first time.
Expand Down

0 comments on commit 176bbaf

Please sign in to comment.