diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py index c8b6ed14..61f96978 100644 --- a/src/tabpfn/classifier.py +++ b/src/tabpfn/classifier.py @@ -33,10 +33,10 @@ determine_precision, initialize_tabpfn_model, ) +from tabpfn.config import ModelInterfaceConfig from tabpfn.constants import ( PROBABILITY_EPSILON_ROUND_ZERO, SKLEARN_16_DECIMAL_PRECISION, - ModelInterfaceConfig, XType, YType, ) diff --git a/src/tabpfn/config.py b/src/tabpfn/config.py new file mode 100644 index 00000000..18499677 --- /dev/null +++ b/src/tabpfn/config.py @@ -0,0 +1,199 @@ +"""Configuration for the model interfaces.""" + +# Copyright (c) Prior Labs GmbH 2025. + +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass +from typing import Literal + +from tabpfn.preprocessing import PreprocessorConfig + + +@dataclass +class ModelInterfaceConfig: + """Constants used as default HPs in the model interfaces. + + These constants are not exposed to the models' init on purpose + to reduce the complexity for users. Furthermore, most of these + should not be optimized over by the (standard) user. + + Several of the preprocessing options are supported by our code for efficiency + reasons (to avoid loading TabPFN multiple times). However, these can also be + applied outside of the model interface. + """ + + MAX_UNIQUE_FOR_CATEGORICAL_FEATURES: int = 30 + """The maximum number of unique values for a feature to be considered + categorical. Otherwise, it is considered numerical.""" + MIN_UNIQUE_FOR_NUMERICAL_FEATURES: int = 4 + """The minimum number of unique values for a feature to be considered numerical. + Otherwise, it is considered categorical.""" + MIN_NUMBER_SAMPLES_FOR_CATEGORICAL_INFERENCE: int = 100 + """The minimum number of samples in the data required to infer which features might + be categorical.""" + + OUTLIER_REMOVAL_STD: float | None | Literal["auto"] = "auto" + """The number of standard deviations from the mean to consider a sample an outlier. + - If None, no outliers are removed. + - If float, the number of standard deviations from the mean to consider a sample + an outlier. + - If "auto", the OUTLIER_REMOVAL_STD is automatically determined. + -> 12.0 for classification and None for regression. + """ + + FEATURE_SHIFT_METHOD: Literal["shuffle", "rotate"] | None = "shuffle" + """The method used to shift features during preprocessing for ensembling to emulate + the effect of invariance to feature position. Without ensembling, TabPFN is not + invariant to feature position due to using a transformer. Moreover, shifting + features can have a positive effect on the model's performance. The options are: + - If "shuffle", the features are shuffled. + - If "rotate", the features are rotated (think of a ring). + - If None, no feature shifting is done. + """ + CLASS_SHIFT_METHOD: Literal["rotate", "shuffle"] | None = "shuffle" + """The method used to shift classes during preprocessing for ensembling to emulate + the effect of invariance to class order. Without ensembling, TabPFN is not + invariant to class order due to using a transformer. Shifting classes can + have a positive effect on the model's performance. The options are: + - If "shuffle", the classes are shuffled. + - If "rotate", the classes are rotated (think of a ring). + - If None, no class shifting is done. + """ + + FINGERPRINT_FEATURE: bool = True + """Whether to add a fingerprint feature to the data. The added feature is a hash of + the row, counting up for duplicates. 
This helps TabPFN to distinguish between + duplicated data points in the input data. Otherwise, duplicates would be less + obvious during attention. This is expected to improve prediction performance and + help with stability if the data has many sample duplicates.""" + POLYNOMIAL_FEATURES: Literal["no", "all"] | int = "no" + """The number of two-factor polynomial features to generate and add to the original + data before passing the data to TabPFN. The polynomial features are generated by + multiplying the original features together, e.g., this might add a feature `x1*x2` + to the features, if `x1` and `x2` are features. In total, this can add up to O(n^2) + many features. Adding polynomial features can improve predictive performance by + exploiting simple feature engineering. + - If "no", no polynomial features are added. + - If "all", all possible polynomial features are added. + - If an int, determines the maximal number of polynomial features to add to the + original data. + """ + SUBSAMPLE_SAMPLES: ( + int | float | None # (0,1) percentage, (1+) n samples + ) = None + """Subsample the input data sample/row-wise before performing any preprocessing + and the TabPFN forward pass. + - If None, no subsampling is done. + - If an int, the number of samples to subsample (or oversample if + `SUBSAMPLE_SAMPLES` is larger than the number of samples). + - If a float, the percentage of samples to subsample. + """ + + PREPROCESS_TRANSFORMS: list[PreprocessorConfig | dict] | None = None + """The preprocessing applied to the data before passing it to TabPFN. See + `PreprocessorConfig` for options and more details. If a list of `PreprocessorConfig` + is provided, the preprocessors are (repeatedly) applied across different estimators. + + By default, for classification, two preprocessors are applied: + 1. Uses the original input data, all features transformed with a quantile + scaler, and the first n-many components of the SVD transformer (whereby + n is a fraction of the number of features or samples). Categorical features + are ordinal encoded, but categories appearing fewer than 10 times are + ignored. + 2. Uses the original input data, with categorical features as ordinal encoded. + + By default, for regression, two preprocessors are applied: + 1. The same as for classification, with a slightly different quantile scaler. + 2. The original input data power transformed and categories one-hot encoded. + """ + REGRESSION_Y_PREPROCESS_TRANSFORMS: tuple[ + Literal["safepower", "power", "quantile_norm", None], + ..., + ] = (None, "safepower") + """The preprocessing applied to the target variable before passing it to TabPFN for + regression. This can be understood as scaling the target variable to better predict + it. The preprocessors should be passed as a tuple/list and are then (repeatedly) + used by the estimators in the ensembles. + + By default, we use no preprocessing and a power transformation (if we have + more than one estimator). + + The options are: + - If None, no preprocessing is done. + - If "power", a power transformation is applied. + - If "safepower", a power transformation is applied with a safety factor to + avoid numerical issues. + - If "quantile_norm", a quantile normalization is applied. + """ + + USE_SKLEARN_16_DECIMAL_PRECISION: bool = False + """Whether to round the probabilities to float 16 to match the precision of + scikit-learn. This can help with reproducibility and compatibility with + scikit-learn but is not recommended for general use. 
This is not exposed to the + user or as a hyperparameter. + To improve reproducibility, set `._sklearn_16_decimal_precision = True` before + calling `.predict()` or `.predict_proba()`.""" + + # TODO: move this somewhere else to support that this might change. + MAX_NUMBER_OF_CLASSES: int = 10 + """The number of classes seen during pretraining for classification. If the + number of classes is larger than this number, TabPFN requires an additional step + to predict for more than this many classes.""" + MAX_NUMBER_OF_FEATURES: int = 500 + """The number of features that the pretraining was intended for. If the number of + features is larger than this number, you may see degraded performance. Note that this + is not the number of features seen by the model during pretraining; it also accounts + for expected generalization (i.e., length extrapolation).""" + MAX_NUMBER_OF_SAMPLES: int = 10_000 + """The number of samples that the pretraining was intended for. If the number of + samples is larger than this number, you may see degraded performance. Note that this + is not the number of samples seen by the model during pretraining; it also accounts + for expected generalization (i.e., length extrapolation).""" + + FIX_NAN_BORDERS_AFTER_TARGET_TRANSFORM: bool = True + """Whether to repair any borders of the bar distribution in regression that are NaN + after the transformation. This can happen due to multiple reasons and should in + general always be done.""" + + _REGRESSION_DEFAULT_OUTLIER_REMOVAL_STD: None = None + _CLASSIFICATION_DEFAULT_OUTLIER_REMOVAL_STD: float = 12.0 + + @staticmethod + def from_user_input( + *, + inference_config: dict | ModelInterfaceConfig | None, + ) -> ModelInterfaceConfig: + """Converts the user input to a `ModelInterfaceConfig` object. + + The input inference_config can be a dictionary, a `ModelInterfaceConfig` object, + or None. If a dictionary is passed, the keys must match the attributes of + `ModelInterfaceConfig`. If a `ModelInterfaceConfig` object is passed, a copy of it + is returned. If None is passed, a new `ModelInterfaceConfig` object is + created with default values. 
+ """ + if inference_config is None: + interface_config_ = ModelInterfaceConfig() + elif isinstance(inference_config, ModelInterfaceConfig): + interface_config_ = deepcopy(inference_config) + elif isinstance(inference_config, dict): + interface_config_ = ModelInterfaceConfig() + for key, value in inference_config.items(): + if not hasattr(interface_config_, key): + raise ValueError( + f"Unknown kwarg passed to model construction: {key}", + ) + setattr(interface_config_, key, value) + else: + raise ValueError(f"Unknown {inference_config=} passed to model.") + + if interface_config_.PREPROCESS_TRANSFORMS is not None: + interface_config_.PREPROCESS_TRANSFORMS = [ + PreprocessorConfig.from_dict(config) + if isinstance(config, dict) + else config + for config in interface_config_.PREPROCESS_TRANSFORMS + ] + + return interface_config_ diff --git a/src/tabpfn/constants.py b/src/tabpfn/constants.py index e570952f..3ab94bb1 100644 --- a/src/tabpfn/constants.py +++ b/src/tabpfn/constants.py @@ -7,20 +7,13 @@ # enumeration of things from __future__ import annotations -from copy import deepcopy -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal +from typing import Any, Literal from typing_extensions import TypeAlias import joblib import numpy as np from packaging import version -if TYPE_CHECKING: - from tabpfn.preprocessing import ( - PreprocessorConfig, - ) - TaskType: TypeAlias = Literal["multiclass", "regression"] TaskTypeValues: tuple[TaskType, ...] = ("multiclass", "regression") @@ -31,186 +24,6 @@ TODO_TYPE1: TypeAlias = str -@dataclass -class ModelInterfaceConfig: - """Constants used as default HPs in the model interfaces. - - These constants are not exposed to the models' init on purpose - to reduce the complexity for users. Furthermore, most of these - should not be optimized over by the (standard) user. - - Several of the preprocessing options are supported by our code for efficiency - reasons (to avoid loading TabPFN multiple times). However, these can also be - applied outside of the model interface. - """ - - MAX_UNIQUE_FOR_CATEGORICAL_FEATURES: int = 30 - """The maximum number of unique values for a feature to be considered - categorical. Otherwise, it is considered numerical.""" - MIN_UNIQUE_FOR_NUMERICAL_FEATURES: int = 4 - """The minimum number of unique values for a feature to be considered numerical. - Otherwise, it is considered categorical.""" - MIN_NUMBER_SAMPLES_FOR_CATEGORICAL_INFERENCE: int = 100 - """The minimum number of samples in the data to run our infer which features might - be categorical.""" - - OUTLIER_REMOVAL_STD: float | None | Literal["auto"] = "auto" - """The number of standard deviations from the mean to consider a sample an outlier. - - If None, no outliers are removed. - - If float, the number of standard deviations from the mean to consider a sample - an outlier. - - If "auto", the OUTLIER_REMOVAL_STD is automatically determined. - -> 12.0 for classification and None for regression. - """ - - FEATURE_SHIFT_METHOD: Literal["shuffle", "rotate"] | None = "shuffle" - """The method used to shift features during preprocessing for ensembling to emulate - the effect of invariance to feature position. Without ensembling, TabPFN is not - invariant to feature position due to using a transformer. Moreover, shifting - features can have a positive effect on the model's performance. The options are: - - If "shuffle", the features are shuffled. - - If "rotate", the features are rotated (think of a ring). 
- - If None, no feature shifting is done. - """ - CLASS_SHIFT_METHOD: Literal["rotate", "shuffle"] | None = "shuffle" - """The method used to shift classes during preprocessing for ensembling to emulate - the effect of invariance to class order. Without ensembling, TabPFN is not - invariant to class order due to using a transformer. Shifting classes can - have a positive effect on the model's performance. The options are: - - If "shuffle", the classes are shuffled. - - If "rotate", the classes are rotated (think of a ring). - - If None, no class shifting is done. - """ - - FINGERPRINT_FEATURE: bool = True - """Whether to add a fingerprint feature to the data. The added feature is a hash of - the row, counting up for duplicates. This helps TabPFN to distinguish between - duplicated data points in the input data. Otherwise, duplicates would be less - obvious during attention. This is expected to improve prediction performance and - help with stability if the data has many sample duplicates.""" - POLYNOMIAL_FEATURES: Literal["no", "all"] | int = "no" - """The number of 2 factor polynomial features to generate and add to the original - data before passing the data to TabPFN. The polynomial features are generated by - multiplying the original features together, e.g., this might add a feature `x1*x2` - to the features, if `x1` and `x2` are features. In total, this can add up O(n^2) - many features. Adding polynomial features can improve predictive performance by - exploiting simple feature engineering. - - If "no", no polynomial features are added. - - If "all", all possible polynomial features are added. - - If an int, determines the maximal number of polynomial features to add to the - original data. - """ - SUBSAMPLE_SAMPLES: ( - int | float | None # (0,1) percentage, (1+) n samples - ) = None - """Subsample the input data sample/row-wise before performing any preprocessing - and the TabPFN forward pass. - - If None, no subsampling is done. - - If an int, the number of samples to subsample (or oversample if - `SUBSAMPLE_SAMPLES` is larger than the number of samples). - - If a float, the percentage of samples to subsample. - """ - - PREPROCESS_TRANSFORMS: list[PreprocessorConfig] | None = None - """The preprocessing applied to the data before passing it to TabPFN. See - `PreprocessorConfig` for options and more details. If a list of `PreprocessorConfig` - is provided, the preprocessors are (repeatedly) applied across different estimators. - - By default, for classification, two preprocessors are applied: - 1. Uses the original input data, all features transformed with a quantile - scaler, and the first n-many components of SVD transformer (whereby - n is a fract of on the number of features or samples). Categorical features - are ordinal encoded but all categories with less than 10 features are - ignored. - 2. Uses the original input data, with categorical features as ordinal encoded. - - By default, for regression, two preprocessor are applied: - 1. The same as for classification, with a minimal different quantile scaler. - 2. The original input data power transformed and categories onehot encoded. - """ - REGRESSION_Y_PREPROCESS_TRANSFORMS: tuple[ - Literal["safepower", "power", "quantile_norm", None], - ..., - ] = (None, "safepower") - """The preprocessing applied to the target variable before passing it to TabPFN for - regression. This can be understood as scaling the target variable to better predict - it. 
The preprocessors should be passed as a tuple/list and are then (repeatedly) - used by the estimators in the ensembles. - - By default, we use no preprocessing and a power transformation (if we have - more than one estimator). - - The options are: - - If None, no preprocessing is done. - - If "power", a power transformation is applied. - - If "safepower", a power transformation is applied with a safety factor to - avoid numerical issues. - - If "quantile_norm", a quantile normalization is applied. - """ - - USE_SKLEARN_16_DECIMAL_PRECISION: bool = False - """Whether to round the probabilities to float 16 to match the precision of - scikit-learn. This can help with reproducibility and compatibility with - scikit-learn but is not recommended for general use. This is not exposed to the - user or as a hyperparameter. - To improve reproducibility,set `._sklearn_16_decimal_precision = True` before - calling `.predict()` or `.predict_proba()`.""" - - # TODO: move this somewhere else to support that this might change. - MAX_NUMBER_OF_CLASSES: int = 10 - """The number of classes seen during pretraining for classification. If the - number of classes is larger than this number, TabPFN requires an additional step - to predict for more than classes.""" - MAX_NUMBER_OF_FEATURES: int = 500 - """The number of features that the pretraining was intended for. If the number of - features is larger than this number, you may see degraded performance. Note, this - is not the number of features seen by the model during pretraining but also accounts - for expected generalization (i.e., length extrapolation).""" - MAX_NUMBER_OF_SAMPLES: int = 10_000 - """The number of samples that the pretraining was intended for. If the number of - samples is larger than this number, you may see degraded performance. Note, this - is not the number of samples seen by the model during pretraining but also accounts - for expected generalization (i.e., length extrapolation).""" - - FIX_NAN_BORDERS_AFTER_TARGET_TRANSFORM: bool = True - """Whether to repair any borders of the bar distribution in regression that are NaN - after the transformation. This can happen due to multiple reasons and should in - general always be done.""" - - _REGRESSION_DEFAULT_OUTLIER_REMOVAL_STD: None = None - _CLASSIFICATION_DEFAULT_OUTLIER_REMOVAL_STD: float = 12.0 - - @staticmethod - def from_user_input( - *, - inference_config: dict | ModelInterfaceConfig | None, - ) -> ModelInterfaceConfig: - """Converts the user input to a `ModelInterfaceConfig` object. - - The input inference_config can be a dictionary, a `ModelInterfaceConfig` object, - or None. If a dictionary is passed, the keys must match the attributes of - `ModelInterfaceConfig`. If a `ModelInterfaceConfig` object is passed, it is - returned as is. If None is passed, a new `ModelInterfaceConfig` object is - created with default values. 
- """ - if inference_config is None: - interface_config_ = ModelInterfaceConfig() - elif isinstance(inference_config, ModelInterfaceConfig): - interface_config_ = deepcopy(inference_config) - elif isinstance(inference_config, dict): - interface_config_ = ModelInterfaceConfig() - for key, value in inference_config.items(): - if not hasattr(interface_config_, key): - raise ValueError( - f"Unknown kwarg passed to model construction: {key}", - ) - setattr(interface_config_, key, value) - else: - raise ValueError(f"Unknown {inference_config=} passed to model.") - - return interface_config_ - - SKLEARN_16_DECIMAL_PRECISION = 16 PROBABILITY_EPSILON_ROUND_ZERO = 1e-3 REGRESSION_NAN_BORDER_LIMIT_UPPER = 1e3 diff --git a/src/tabpfn/preprocessing.py b/src/tabpfn/preprocessing.py index 85d7d384..85c232b6 100644 --- a/src/tabpfn/preprocessing.py +++ b/src/tabpfn/preprocessing.py @@ -138,6 +138,38 @@ def __str__(self) -> str: ) ) + def to_dict(self) -> dict: + """Convert the config to a dictionary. + + Returns: + Dictionary representation of the config. + """ + return { + "name": self.name, + "categorical_name": self.categorical_name, + "append_original": self.append_original, + "subsample_features": self.subsample_features, + "global_transformer_name": self.global_transformer_name, + } + + @classmethod + def from_dict(cls, config_dict: dict) -> PreprocessorConfig: + """Create a config from a dictionary. + + Args: + config_dict: Dictionary containing the config parameters. + + Returns: + PreprocessorConfig instance. + """ + return cls( + name=config_dict["name"], + categorical_name=config_dict["categorical_name"], + append_original=config_dict["append_original"], + subsample_features=config_dict["subsample_features"], + global_transformer_name=config_dict["global_transformer_name"], + ) + def default_classifier_preprocessor_configs() -> list[PreprocessorConfig]: """Default preprocessor configurations for classification.""" diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index 4fd77d6f..68774905 100644 --- a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -37,11 +37,7 @@ determine_precision, initialize_tabpfn_model, ) -from tabpfn.constants import ( - ModelInterfaceConfig, - XType, - YType, -) +from tabpfn.config import ModelInterfaceConfig from tabpfn.model.bar_distribution import FullSupportBarDistribution from tabpfn.model.preprocessing import ( ReshapeFeatureDistributionsStep, @@ -70,6 +66,10 @@ from sklearn.pipeline import Pipeline from torch.types import _dtype + from tabpfn.constants import ( + XType, + YType, + ) from tabpfn.inference import ( InferenceEngine, ) diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py index 679d968b..dfd5790a 100644 --- a/tests/test_classifier_interface.py +++ b/tests/test_classifier_interface.py @@ -13,6 +13,7 @@ from sklearn.utils.estimator_checks import parametrize_with_checks from tabpfn import TabPFNClassifier +from tabpfn.preprocessing import PreprocessorConfig devices = ["cpu"] if torch.cuda.is_available(): @@ -169,3 +170,52 @@ def test_classifier_in_pipeline(X_y: tuple[np.ndarray, np.ndarray]) -> None: expected_mean, rtol=0.1, ), "Class probabilities are not properly balanced in pipeline" + + +def test_dict_vs_object_preprocessor_config(X_y: tuple[np.ndarray, np.ndarray]) -> None: + """Test that dict configs behave identically to PreprocessorConfig objects.""" + X, y = X_y + + # Define same config as both dict and object + dict_config = { + "name": "quantile_uni_coarse", + "append_original": False, # 
changed from default + "categorical_name": "ordinal_very_common_categories_shuffled", + "global_transformer_name": "svd", + "subsample_features": -1, + } + + object_config = PreprocessorConfig( + name="quantile_uni_coarse", + append_original=False, # changed from default + categorical_name="ordinal_very_common_categories_shuffled", + global_transformer_name="svd", + subsample_features=-1, + ) + + # Create two models with same random state + model_dict = TabPFNClassifier( + inference_config={"PREPROCESS_TRANSFORMS": [dict_config]}, + n_estimators=2, + random_state=42, + ) + + model_obj = TabPFNClassifier( + inference_config={"PREPROCESS_TRANSFORMS": [object_config]}, + n_estimators=2, + random_state=42, + ) + + # Fit both models + model_dict.fit(X, y) + model_obj.fit(X, y) + + # Compare predictions + pred_dict = model_dict.predict(X) + pred_obj = model_obj.predict(X) + np.testing.assert_array_equal(pred_dict, pred_obj) + + # Compare probabilities + prob_dict = model_dict.predict_proba(X) + prob_obj = model_obj.predict_proba(X) + np.testing.assert_array_almost_equal(prob_dict, prob_obj) diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py index 0ab3ca74..ccd48226 100644 --- a/tests/test_regressor_interface.py +++ b/tests/test_regressor_interface.py @@ -13,6 +13,7 @@ from sklearn.utils.estimator_checks import parametrize_with_checks from tabpfn import TabPFNRegressor +from tabpfn.preprocessing import PreprocessorConfig devices = ["cpu"] if torch.cuda.is_available(): @@ -155,3 +156,64 @@ def test_regressor_in_pipeline(X_y: tuple[np.ndarray, np.ndarray]) -> None: assert quantiles[0].shape == ( X.shape[0], ), "Quantile predictions shape is incorrect" + + +def test_dict_vs_object_preprocessor_config(X_y: tuple[np.ndarray, np.ndarray]) -> None: + """Test that dict configs behave identically to PreprocessorConfig objects.""" + X, y = X_y + + # Define same config as both dict and object + dict_config = { + "name": "quantile_uni", + "append_original": False, # changed from default + "categorical_name": "ordinal_very_common_categories_shuffled", + "global_transformer_name": "svd", + "subsample_features": -1, + } + + object_config = PreprocessorConfig( + name="quantile_uni", + append_original=False, # changed from default + categorical_name="ordinal_very_common_categories_shuffled", + global_transformer_name="svd", + subsample_features=-1, + ) + + # Create two models with same random state + model_dict = TabPFNRegressor( + inference_config={"PREPROCESS_TRANSFORMS": [dict_config]}, + n_estimators=2, + random_state=42, + ) + + model_obj = TabPFNRegressor( + inference_config={"PREPROCESS_TRANSFORMS": [object_config]}, + n_estimators=2, + random_state=42, + ) + + # Fit both models + model_dict.fit(X, y) + model_obj.fit(X, y) + + # Compare predictions for different output types + for output_type in ["mean", "median", "mode"]: + pred_dict = model_dict.predict(X, output_type=output_type) + pred_obj = model_obj.predict(X, output_type=output_type) + np.testing.assert_array_almost_equal( + pred_dict, + pred_obj, + err_msg=f"Predictions differ for output_type={output_type}", + ) + + # Compare quantile predictions + quantiles = [0.1, 0.5, 0.9] + quant_dict = model_dict.predict(X, output_type="quantiles", quantiles=quantiles) + quant_obj = model_obj.predict(X, output_type="quantiles", quantiles=quantiles) + + for q_dict, q_obj in zip(quant_dict, quant_obj): + np.testing.assert_array_almost_equal( + q_dict, + q_obj, + err_msg="Quantile predictions differ", + )
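
Minimal usage sketch (not part of the diff) of the dict-based config path this change enables, mirroring the added tests; `make_classification` is used only to produce throwaway demo data.

from sklearn.datasets import make_classification

from tabpfn import TabPFNClassifier
from tabpfn.preprocessing import PreprocessorConfig

X, y = make_classification(n_samples=60, n_features=5, random_state=0)

# A preprocessing config given as a plain dict; ModelInterfaceConfig.from_user_input
# converts it to a PreprocessorConfig via PreprocessorConfig.from_dict.
dict_config = {
    "name": "quantile_uni_coarse",
    "categorical_name": "ordinal_very_common_categories_shuffled",
    "append_original": False,
    "subsample_features": -1,
    "global_transformer_name": "svd",
}

# The equivalent object form remains supported.
object_config = PreprocessorConfig(
    name="quantile_uni_coarse",
    categorical_name="ordinal_very_common_categories_shuffled",
    append_original=False,
    subsample_features=-1,
    global_transformer_name="svd",
)

for config in (dict_config, object_config):
    clf = TabPFNClassifier(
        inference_config={"PREPROCESS_TRANSFORMS": [config]},
        n_estimators=2,
        random_state=42,
    )
    clf.fit(X, y)
    proba = clf.predict_proba(X)  # both configs should yield matching probabilities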