diff --git a/mapie_v1/_utils.py b/mapie_v1/_utils.py index 49a3381d..06caa7b4 100644 --- a/mapie_v1/_utils.py +++ b/mapie_v1/_utils.py @@ -4,16 +4,24 @@ from numpy import array from mapie._typing import ArrayLike, NDArray from sklearn.model_selection import BaseCrossValidator +from decimal import Decimal def transform_confidence_level_to_alpha_list( confidence_level: Union[float, List[float]] ) -> List[float]: - if isinstance(confidence_level, float): - confidence_levels = [confidence_level] - else: + if isinstance(confidence_level, list): confidence_levels = confidence_level - return [1 - level for level in confidence_levels] + else: + confidence_levels = [confidence_level] + + # Using decimals to avoid weird-looking float approximations + # when computing alpha = 1 - confidence_level + # Such approximations arise even with simple confidence levels like 0.9 + confidence_levels_d = [Decimal(str(conf_level)) for conf_level in confidence_levels] + alphas_d = [Decimal("1") - conf_level_d for conf_level_d in confidence_levels_d] + + return [float(alpha_d) for alpha_d in alphas_d] def check_if_param_in_allowed_values(