Accept dict as preprocessor config #148

Merged · 4 commits · Jan 21, 2025
2 changes: 1 addition & 1 deletion src/tabpfn/classifier.py
@@ -33,10 +33,10 @@
     determine_precision,
     initialize_tabpfn_model,
 )
+from tabpfn.config import ModelInterfaceConfig
 from tabpfn.constants import (
     PROBABILITY_EPSILON_ROUND_ZERO,
     SKLEARN_16_DECIMAL_PRECISION,
-    ModelInterfaceConfig,
     XType,
     YType,
 )
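The import move above pairs with the new `tabpfn.config` module added below. As a hedged sketch (assuming `TabPFNClassifier` forwards its `inference_config` argument to `ModelInterfaceConfig.from_user_input`, as the "model construction" error message in that method suggests), user code can now pass either a config object or a plain dict:

from tabpfn import TabPFNClassifier

# Keys must match ModelInterfaceConfig attribute names; unknown keys raise
# a ValueError during model construction.
clf = TabPFNClassifier(inference_config={"OUTLIER_REMOVAL_STD": 12.0})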
199 changes: 199 additions & 0 deletions src/tabpfn/config.py
@@ -0,0 +1,199 @@
"""Configuration for the model interfaces."""

# Copyright (c) Prior Labs GmbH 2025.

from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
from typing import Literal

from tabpfn.preprocessing import PreprocessorConfig


@dataclass
class ModelInterfaceConfig:
"""Constants used as default HPs in the model interfaces.

These constants are not exposed to the models' init on purpose
to reduce the complexity for users. Furthermore, most of these
should not be optimized over by the (standard) user.

Several of the preprocessing options are supported by our code for efficiency
reasons (to avoid loading TabPFN multiple times). However, these can also be
applied outside of the model interface.
"""

MAX_UNIQUE_FOR_CATEGORICAL_FEATURES: int = 30
"""The maximum number of unique values for a feature to be considered
categorical. Otherwise, it is considered numerical."""
MIN_UNIQUE_FOR_NUMERICAL_FEATURES: int = 4
"""The minimum number of unique values for a feature to be considered numerical.
Otherwise, it is considered categorical."""
MIN_NUMBER_SAMPLES_FOR_CATEGORICAL_INFERENCE: int = 100
"""The minimum number of samples in the data to run our infer which features might
be categorical."""

OUTLIER_REMOVAL_STD: float | None | Literal["auto"] = "auto"
"""The number of standard deviations from the mean to consider a sample an outlier.
- If None, no outliers are removed.
- If float, the number of standard deviations from the mean to consider a sample
an outlier.
- If "auto", the OUTLIER_REMOVAL_STD is automatically determined.
-> 12.0 for classification and None for regression.
"""

    FEATURE_SHIFT_METHOD: Literal["shuffle", "rotate"] | None = "shuffle"
    """The method used to shift features during preprocessing for ensembling to
    emulate the effect of invariance to feature position. Without ensembling,
    TabPFN is not invariant to feature position due to using a transformer.
    Moreover, shifting features can have a positive effect on the model's
    performance. The options are:
    - If "shuffle", the features are shuffled.
    - If "rotate", the features are rotated (think of a ring).
    - If None, no feature shifting is done.
    """
    CLASS_SHIFT_METHOD: Literal["rotate", "shuffle"] | None = "shuffle"
    """The method used to shift classes during preprocessing for ensembling to
    emulate the effect of invariance to class order. Without ensembling, TabPFN
    is not invariant to class order due to using a transformer. Shifting classes
    can have a positive effect on the model's performance. The options are:
    - If "shuffle", the classes are shuffled.
    - If "rotate", the classes are rotated (think of a ring).
    - If None, no class shifting is done.
    """

    FINGERPRINT_FEATURE: bool = True
    """Whether to add a fingerprint feature to the data. The added feature is a
    hash of the row, counting up for duplicates. This helps TabPFN to distinguish
    between duplicated data points in the input data. Otherwise, duplicates would
    be less obvious during attention. This is expected to improve prediction
    performance and help with stability if the data has many sample duplicates."""
    POLYNOMIAL_FEATURES: Literal["no", "all"] | int = "no"
    """The number of two-factor polynomial features to generate and add to the
    original data before passing the data to TabPFN. The polynomial features are
    generated by multiplying the original features together, e.g., this might add
    a feature `x1*x2` to the features, if `x1` and `x2` are features. In total,
    this can add up to O(n^2) many features. Adding polynomial features can
    improve predictive performance by exploiting simple feature engineering.
    - If "no", no polynomial features are added.
    - If "all", all possible polynomial features are added.
    - If an int, determines the maximal number of polynomial features to add to
      the original data.
    """
    SUBSAMPLE_SAMPLES: (
        int | float | None  # (0, 1) percentage, (1+) n samples
    ) = None
    """Subsample the input data sample/row-wise before performing any
    preprocessing and the TabPFN forward pass.
    - If None, no subsampling is done.
    - If an int, the number of samples to subsample (or oversample if
      `SUBSAMPLE_SAMPLES` is larger than the number of samples).
    - If a float, the percentage of samples to subsample.
    """

    PREPROCESS_TRANSFORMS: list[PreprocessorConfig | dict] | None = None
    """The preprocessing applied to the data before passing it to TabPFN. See
    `PreprocessorConfig` for options and more details. If a list of
    `PreprocessorConfig` is provided, the preprocessors are (repeatedly) applied
    across different estimators.

    By default, for classification, two preprocessors are applied:
    1. Uses the original input data, all features transformed with a quantile
       scaler, and the first n-many components of the SVD transformer (whereby
       n is a fraction of the number of features or samples). Categorical
       features are ordinal encoded, but all categories with fewer than 10
       occurrences are ignored.
    2. Uses the original input data, with categorical features ordinal encoded.

    By default, for regression, two preprocessors are applied:
    1. The same as for classification, with a minimally different quantile scaler.
    2. The original input data power transformed and categories one-hot encoded.
    """
    REGRESSION_Y_PREPROCESS_TRANSFORMS: tuple[
        Literal["safepower", "power", "quantile_norm", None],
        ...,
    ] = (None, "safepower")
    """The preprocessing applied to the target variable before passing it to
    TabPFN for regression. This can be understood as scaling the target variable
    to better predict it. The preprocessors should be passed as a tuple/list and
    are then (repeatedly) used by the estimators in the ensembles.

    By default, we use no preprocessing and a power transformation (if we have
    more than one estimator).

    The options are:
    - If None, no preprocessing is done.
    - If "power", a power transformation is applied.
    - If "safepower", a power transformation is applied with a safety factor to
      avoid numerical issues.
    - If "quantile_norm", a quantile normalization is applied.
    """

    USE_SKLEARN_16_DECIMAL_PRECISION: bool = False
    """Whether to round the predicted probabilities to 16 decimal places to
    match the precision of scikit-learn. This can help with reproducibility and
    compatibility with scikit-learn but is not recommended for general use.
    This is not exposed to the user or as a hyperparameter.
    To improve reproducibility, set `._sklearn_16_decimal_precision = True`
    before calling `.predict()` or `.predict_proba()`."""

    # TODO: move this somewhere else to support that this might change.
    MAX_NUMBER_OF_CLASSES: int = 10
    """The number of classes seen during pretraining for classification. If the
    number of classes is larger than this number, TabPFN requires an additional
    step to predict more than this many classes."""
    MAX_NUMBER_OF_FEATURES: int = 500
    """The number of features that the pretraining was intended for. If the
    number of features is larger than this number, you may see degraded
    performance. Note that this is not the number of features seen by the model
    during pretraining but also accounts for expected generalization (i.e.,
    length extrapolation)."""
    MAX_NUMBER_OF_SAMPLES: int = 10_000
    """The number of samples that the pretraining was intended for. If the
    number of samples is larger than this number, you may see degraded
    performance. Note that this is not the number of samples seen by the model
    during pretraining but also accounts for expected generalization (i.e.,
    length extrapolation)."""

    FIX_NAN_BORDERS_AFTER_TARGET_TRANSFORM: bool = True
    """Whether to repair any borders of the bar distribution in regression that
    are NaN after the transformation. This can happen for multiple reasons and
    should in general always be done."""

    _REGRESSION_DEFAULT_OUTLIER_REMOVAL_STD: None = None
    _CLASSIFICATION_DEFAULT_OUTLIER_REMOVAL_STD: float = 12.0

    @staticmethod
    def from_user_input(
        *,
        inference_config: dict | ModelInterfaceConfig | None,
    ) -> ModelInterfaceConfig:
        """Converts the user input to a `ModelInterfaceConfig` object.

        The input inference_config can be a dictionary, a `ModelInterfaceConfig`
        object, or None. If a dictionary is passed, the keys must match the
        attributes of `ModelInterfaceConfig`. If a `ModelInterfaceConfig` object
        is passed, a deep copy of it is returned. If None is passed, a new
        `ModelInterfaceConfig` object is created with default values.
        """
        if inference_config is None:
            interface_config_ = ModelInterfaceConfig()
        elif isinstance(inference_config, ModelInterfaceConfig):
            interface_config_ = deepcopy(inference_config)
        elif isinstance(inference_config, dict):
            interface_config_ = ModelInterfaceConfig()
            for key, value in inference_config.items():
                if not hasattr(interface_config_, key):
                    raise ValueError(
                        f"Unknown kwarg passed to model construction: {key}",
                    )
                setattr(interface_config_, key, value)
        else:
            raise ValueError(f"Unknown {inference_config=} passed to model.")

        if interface_config_.PREPROCESS_TRANSFORMS is not None:
            interface_config_.PREPROCESS_TRANSFORMS = [
                PreprocessorConfig.from_dict(config)
                if isinstance(config, dict)
                else config
                for config in interface_config_.PREPROCESS_TRANSFORMS
            ]

        return interface_config_
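For review context, a minimal usage sketch of the new conversion path (the preprocessor dict keys below are illustrative; valid keys are the `PreprocessorConfig` fields consumed by `PreprocessorConfig.from_dict`):

from tabpfn.config import ModelInterfaceConfig
from tabpfn.preprocessing import PreprocessorConfig

# A plain dict is accepted; its keys must match ModelInterfaceConfig attributes.
config = ModelInterfaceConfig.from_user_input(
    inference_config={
        "SUBSAMPLE_SAMPLES": 0.5,
        # Dict entries are normalized via PreprocessorConfig.from_dict;
        # "name" is an assumed/illustrative PreprocessorConfig field.
        "PREPROCESS_TRANSFORMS": [{"name": "quantile_uni"}],
    },
)
assert all(
    isinstance(c, PreprocessorConfig) for c in config.PREPROCESS_TRANSFORMS
)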