From 3d63bac45f0129e1f6930ff8bbddc59e21c477f9 Mon Sep 17 00:00:00 2001
From: Alexander Pfefferle
Date: Fri, 31 Jan 2025 14:16:14 +0100
Subject: [PATCH 1/9] xfail on check_methods_sample_order_invariance too

---
 tests/test_regressor_interface.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py
index ccd48226..8aa71424 100644
--- a/tests/test_regressor_interface.py
+++ b/tests/test_regressor_interface.py
@@ -111,7 +111,6 @@ def test_sklearn_compatible_estimator(
         "check_methods_sample_order_invariance",
     ):
         estimator.inference_precision = torch.float64
-        if check.func.__name__ == "check_methods_sample_order_invariance":  # type: ignore
         pytest.xfail("We're not at 1e-7 difference yet")
     check(estimator)

From dd698095f9d9828a98985ac2ef2f20860f122ce1 Mon Sep 17 00:00:00 2001
From: Alexander Pfefferle
Date: Fri, 31 Jan 2025 20:04:10 +0100
Subject: [PATCH 2/9] use custom nanmean and nansum

---
 src/tabpfn/model/encoders.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/src/tabpfn/model/encoders.py b/src/tabpfn/model/encoders.py
index 2ae5e698..b39fce54 100644
--- a/src/tabpfn/model/encoders.py
+++ b/src/tabpfn/model/encoders.py
@@ -9,11 +9,17 @@
 from torch import nn


-# TODO(eddiebergman): These were used before but I have no idea why.
-# We use the implementations given by torch for now.
-# TODO(Arjun): Enabling these again because their behaviour is a little
-# different from torch's implementation (see Issue #2). We should check if this makes
-# a difference in the results.
+# usage of custom implementations is required to support ONNX export
+def torch_nansum(x: torch.Tensor, axis=None, keepdim=False, dtype=None):
+    nan_mask = torch.isnan(x)
+    masked_input = torch.where(
+        nan_mask,
+        torch.tensor(0.0, device=x.device, dtype=x.dtype),
+        x,
+    )
+    return torch.sum(masked_input, axis=axis, keepdim=keepdim, dtype=dtype)
+
+
 def torch_nanmean(
     x: torch.Tensor,
     axis: int = 0,
@@ -46,7 +52,7 @@ def torch_nanstd(x: torch.Tensor, axis: int = 0):
         dim=axis,
     )
     return torch.sqrt(
-        torch.nansum(torch.square(mean_broadcast - x), axis=axis) / (num - 1),  # type: ignore
+        torch_nansum(torch.square(mean_broadcast - x), axis=axis) / (num - 1),  # type: ignore
     )


@@ -458,7 +464,7 @@ def _fit(self, x: torch.Tensor, single_eval_pos: int, **kwargs: Any) -> None:
             single_eval_pos: The position to use for single evaluation.
             **kwargs: Additional keyword arguments (unused).
""" - self.feature_means_ = torch.nanmean(x[:single_eval_pos], dim=0) + self.feature_means_ = torch_nanmean(x[:single_eval_pos], dim=0) def _transform( self, From e19741e4f81e1848af8932c74cf4514d03c0f67f Mon Sep 17 00:00:00 2001 From: Alexander Pfefferle Date: Fri, 31 Jan 2025 20:29:58 +0100 Subject: [PATCH 3/9] replaced torch.Generator with isolated rng --- src/tabpfn/model/transformer.py | 119 ++++++++++++++------------------ 1 file changed, 51 insertions(+), 68 deletions(-) diff --git a/src/tabpfn/model/transformer.py b/src/tabpfn/model/transformer.py index b4a5f1ab..d3d88807 100644 --- a/src/tabpfn/model/transformer.py +++ b/src/tabpfn/model/transformer.py @@ -5,6 +5,7 @@ import random import warnings from collections.abc import Callable, Iterable +from contextlib import contextmanager from functools import partial from typing import Any, Literal @@ -25,6 +26,20 @@ DEFAULT_EMSIZE = 128 +@contextmanager +def isolate_torch_rng(seed, device): + torch_rng_state = torch.get_rng_state() + if torch.cuda.is_available(): + torch_cuda_rng_state = torch.cuda.get_rng_state(device=device) + torch.manual_seed(seed) + try: + yield + finally: + torch.set_rng_state(torch_rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(torch_cuda_rng_state, device=device) + + class LayerStack(nn.Module): """Similar to nn.Sequential, but with support for passing keyword arguments to layers and stacks the same layer multiple times. @@ -294,17 +309,6 @@ def __init__( # noqa: C901, D417, PLR0913 self.cached_feature_positional_embeddings: torch.Tensor | None = None self.seed = seed if seed is not None else random.randint(0, 1_000_000) # noqa: S311 - # Device on which the generator was last initialized. - # If loading from a checkpoint, this might be false, - # but it will be set to the correct device on the first forward pass. - self.generator_device = "cpu" - self._init_rnd() - - def _init_rnd(self) -> None: - self.generator = SerializableGenerator(device=self.generator_device) - if self.seed: # This can be none if set outside of the model. - self.generator.manual_seed(self.seed) - def reset_save_peak_mem_factor(self, factor: int | None = None) -> None: """Sets the save_peak_mem_factor for all layers. @@ -377,7 +381,6 @@ def forward(self, *args: Any, **kwargs: Any) -> dict[str, torch.Tensor]: # noqa Returns: The output of the model, which can be a tensor or a dictionary of tensors. 
""" - self._init_rnd() half_layers = kwargs.pop("half_layers", False) assert half_layers is False @@ -694,57 +697,47 @@ def add_embeddings( # noqa: C901, PLR0912 x += self.cached_embeddings[None, None] return x, y - if ( - self.generator_device != self.generator.device - or self.generator_device != x.device - ): - self.generator_device = x.device - self._init_rnd() - - if self.feature_positional_embedding == "normal_rand_vec": - embs = torch.randn( - (x.shape[2], x.shape[3]), - device=x.device, - dtype=x.dtype, - generator=self.generator, - ) - x += embs[None, None] - elif self.feature_positional_embedding == "uni_rand_vec": - embs = ( - torch.rand( + with isolate_torch_rng(self.seed, device=x.device): + if self.feature_positional_embedding == "normal_rand_vec": + embs = torch.randn( (x.shape[2], x.shape[3]), device=x.device, dtype=x.dtype, - generator=self.generator, ) - * 2 - - 1 - ) - x += embs[None, None] - elif self.feature_positional_embedding == "learned": - w = self.feature_positional_embedding_embeddings.weight - embs = w[ - torch.randint( - 0, - w.shape[0], - (x.shape[2],), - generator=self.generator, + x += embs[None, None] + elif self.feature_positional_embedding == "uni_rand_vec": + embs = ( + torch.rand( + (x.shape[2], x.shape[3]), + device=x.device, + dtype=x.dtype, + ) + * 2 + - 1 ) - ] - x += embs[None, None] - elif self.feature_positional_embedding == "subspace": - embs = torch.randn( - (x.shape[2], x.shape[3] // 4), - device=x.device, - dtype=x.dtype, - generator=self.generator, - ) - embs = self.feature_positional_embedding_embeddings(embs) - x += embs[None, None] - elif self.feature_positional_embedding is None: - embs = None - else: - raise ValueError(f"Unknown {self.feature_positional_embedding=}") + x += embs[None, None] + elif self.feature_positional_embedding == "learned": + w = self.feature_positional_embedding_embeddings.weight + embs = w[ + torch.randint( + 0, + w.shape[0], + (x.shape[2],), + ) + ] + x += embs[None, None] + elif self.feature_positional_embedding == "subspace": + embs = torch.randn( + (x.shape[2], x.shape[3] // 4), + device=x.device, + dtype=x.dtype, + ) + embs = self.feature_positional_embedding_embeddings(embs) + x += embs[None, None] + elif self.feature_positional_embedding is None: + embs = None + else: + raise ValueError(f"Unknown {self.feature_positional_embedding=}") self.cached_embeddings = None if cache_embeddings and embs is not None: @@ -869,13 +862,3 @@ def _add_pos_emb( # TODO(old) Double check the ordering is right for n, pe_ in zip(graph.nodes(), pe): graph.nodes[n]["positional_encoding"] = pe_ - - -class SerializableGenerator(torch.Generator): - """A serializable version of the torch.Generator, that cna be saved and pickled.""" - - def __getstate__(self) -> Any: - return self.__dict__ - - def __setstate__(self, d: Any) -> None: - self.__dict__ = d From bdd2d29712b60f65754ad12eb23c62cca5c05774 Mon Sep 17 00:00:00 2001 From: Alexander Pfefferle Date: Tue, 4 Feb 2025 10:23:13 +0100 Subject: [PATCH 4/9] bugfix --- src/tabpfn/model/encoders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tabpfn/model/encoders.py b/src/tabpfn/model/encoders.py index b39fce54..35c9265d 100644 --- a/src/tabpfn/model/encoders.py +++ b/src/tabpfn/model/encoders.py @@ -464,7 +464,7 @@ def _fit(self, x: torch.Tensor, single_eval_pos: int, **kwargs: Any) -> None: single_eval_pos: The position to use for single evaluation. **kwargs: Additional keyword arguments (unused). 
""" - self.feature_means_ = torch_nanmean(x[:single_eval_pos], dim=0) + self.feature_means_ = torch_nanmean(x[:single_eval_pos], axis=0) def _transform( self, From 50bba69f850e76bacffe8cf334c0207cbc5ccc43 Mon Sep 17 00:00:00 2001 From: Alexander Pfefferle Date: Thu, 6 Feb 2025 00:34:06 +0100 Subject: [PATCH 5/9] added test_onnx_exportable_cpu --- src/tabpfn/model/transformer.py | 4 +-- tests/test_classifier_interface.py | 58 ++++++++++++++++++++++++++++++ tests/test_regressor_interface.py | 58 ++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 2 deletions(-) diff --git a/src/tabpfn/model/transformer.py b/src/tabpfn/model/transformer.py index d3d88807..181feb3a 100644 --- a/src/tabpfn/model/transformer.py +++ b/src/tabpfn/model/transformer.py @@ -4,7 +4,7 @@ import random import warnings -from collections.abc import Callable, Iterable +from collections.abc import Callable, Generator, Iterable from contextlib import contextmanager from functools import partial from typing import Any, Literal @@ -27,7 +27,7 @@ @contextmanager -def isolate_torch_rng(seed, device): +def isolate_torch_rng(seed: int, device: torch.device) -> Generator[None, None, None]: torch_rng_state = torch.get_rng_state() if torch.cuda.is_available(): torch_cuda_rng_state = torch.cuda.get_rng_state(device=device) diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py index dfd5790a..6b9706c1 100644 --- a/tests/test_classifier_interface.py +++ b/tests/test_classifier_interface.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io from itertools import product from typing import Callable, Literal @@ -11,6 +12,7 @@ from sklearn.pipeline import Pipeline from sklearn.preprocessing import StandardScaler from sklearn.utils.estimator_checks import parametrize_with_checks +from torch import nn from tabpfn import TabPFNClassifier from tabpfn.preprocessing import PreprocessorConfig @@ -219,3 +221,59 @@ def test_dict_vs_object_preprocessor_config(X_y: tuple[np.ndarray, np.ndarray]) prob_dict = model_dict.predict_proba(X) prob_obj = model_obj.predict_proba(X) np.testing.assert_array_almost_equal(prob_dict, prob_obj) + + +class ModelWrapper(nn.Module): + def __init__(self, original_model): # noqa: D107 + super().__init__() + self.model = original_model + + def forward( + self, + X, + y, + single_eval_pos, + only_return_standard_out, + categorical_inds, + ): + return self.model( + None, + X, + y, + single_eval_pos=single_eval_pos, + only_return_standard_out=only_return_standard_out, + categorical_inds=categorical_inds, + ) + + +@pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning") +def test_onnx_exportable_cpu(X_y: tuple[np.ndarray, np.ndarray]) -> None: + X, y = X_y + with torch.no_grad(): + classifier = TabPFNClassifier(n_estimators=1, device="cpu") + # load the model so we can access it via classifier.model_ + classifier.fit(X, y) + # this is necessary if cuda is available + classifier.predict(X) + # replicate the above call with random tensors of same shape + X = torch.randn((X.shape[0] * 2, 1, X.shape[1] + 1)) + y = (torch.randn(y.shape) > 0).to(torch.float32) + dynamic_axes = { + "X": {0: "num_datapoints", 1: "batch_size", 2: "num_features"}, + "y": {0: "num_labels"}, + } + torch.onnx.export( + ModelWrapper(classifier.model_).eval(), + (X, y, y.shape[0], True, []), + io.BytesIO(), + input_names=[ + "X", + "y", + "single_eval_pos", + "only_return_standard_out", + "categorical_inds", + ], + output_names=["output"], + opset_version=17, # using 17 since we use torch>=2.1 + 
+            dynamic_axes=dynamic_axes,
+        )
diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py
index 8aa71424..03b65649 100644
--- a/tests/test_regressor_interface.py
+++ b/tests/test_regressor_interface.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import io
 from itertools import product
 from typing import Callable, Literal

@@ -11,6 +12,7 @@
 from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import StandardScaler
 from sklearn.utils.estimator_checks import parametrize_with_checks
+from torch import nn

 from tabpfn import TabPFNRegressor
 from tabpfn.preprocessing import PreprocessorConfig
@@ -216,3 +218,59 @@ def test_dict_vs_object_preprocessor_config(X_y: tuple[np.ndarray, np.ndarray])
         q_obj,
         err_msg="Quantile predictions differ",
     )
+
+
+class ModelWrapper(nn.Module):
+    def __init__(self, original_model):  # noqa: D107
+        super().__init__()
+        self.model = original_model
+
+    def forward(
+        self,
+        X,
+        y,
+        single_eval_pos,
+        only_return_standard_out,
+        categorical_inds,
+    ):
+        return self.model(
+            None,
+            X,
+            y,
+            single_eval_pos=single_eval_pos,
+            only_return_standard_out=only_return_standard_out,
+            categorical_inds=categorical_inds,
+        )
+
+
+@pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning")
+def test_onnx_exportable_cpu(X_y: tuple[np.ndarray, np.ndarray]) -> None:
+    X, y = X_y
+    with torch.no_grad():
+        regressor = TabPFNRegressor(n_estimators=1, device="cpu")
+        # load the model so we can access it via regressor.model_
+        regressor.fit(X, y)
+        # this is necessary if cuda is available
+        regressor.predict(X)
+        # replicate the above call with random tensors of same shape
+        X = torch.randn((X.shape[0] * 2, 1, X.shape[1] + 1))
+        y = (torch.randn(y.shape) > 0).to(torch.float32)
+        dynamic_axes = {
+            "X": {0: "num_datapoints", 1: "batch_size", 2: "num_features"},
+            "y": {0: "num_labels"},
+        }
+        torch.onnx.export(
+            ModelWrapper(regressor.model_).eval(),
+            (X, y, y.shape[0], True, []),
+            io.BytesIO(),
+            input_names=[
+                "X",
+                "y",
+                "single_eval_pos",
+                "only_return_standard_out",
+                "categorical_inds",
+            ],
+            output_names=["output"],
+            opset_version=17,  # using 17 since we use torch>=2.1
+            dynamic_axes=dynamic_axes,
+        )

From 591a73b86aca56958e426d918c159598dcb8a414 Mon Sep 17 00:00:00 2001
From: Alexander Pfefferle
Date: Thu, 6 Feb 2025 11:04:37 +0100
Subject: [PATCH 6/9] added onnx to dev dependencies

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index af086f77..471cfff6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -61,6 +61,7 @@ dev = [
   "mypy",
   # Test
   "pytest",
+  "onnx",  # required for onnx export tests
   # Docs
   "mkdocs",
   "mkdocs-material",

From dd10623acf074cda80d10483769219adc46e223f Mon Sep 17 00:00:00 2001
From: LeoGrin
Date: Thu, 6 Feb 2025 11:47:59 +0100
Subject: [PATCH 7/9] add onnx dependency to the ci

---
 .github/workflows/pull_request.yml | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml
index 1241e617..35b1f380 100644
--- a/.github/workflows/pull_request.yml
+++ b/.github/workflows/pull_request.yml
@@ -112,8 +112,10 @@
       - name: Install dependencies
         run: |
           uv pip install --system --no-deps .
-          uv pip install --system pytest
+          # onnx is required for onnx export tests
+          # we don't install all dev dependencies here for speed
           uv pip install --system -r requirements.txt
+          uv pip install --system pytest onnx

       - name: Initialize submodules
         run: git submodule update --init --recursive

From 41e360267ed42bb765c3340feb36c77567784380 Mon Sep 17 00:00:00 2001
From: LeoGrin
Date: Thu, 6 Feb 2025 12:47:12 +0100
Subject: [PATCH 8/9] make tests deterministic

---
 tests/test_classifier_interface.py | 13 ++++++++++---
 tests/test_regressor_interface.py  | 13 +++++++++----
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py
index 6b9706c1..67bbc16e 100644
--- a/tests/test_classifier_interface.py
+++ b/tests/test_classifier_interface.py
@@ -250,14 +250,21 @@
 def test_onnx_exportable_cpu(X_y: tuple[np.ndarray, np.ndarray]) -> None:
     X, y = X_y
     with torch.no_grad():
-        classifier = TabPFNClassifier(n_estimators=1, device="cpu")
+        classifier = TabPFNClassifier(n_estimators=1, device="cpu", random_state=42)
         # load the model so we can access it via classifier.model_
         classifier.fit(X, y)
         # this is necessary if cuda is available
         classifier.predict(X)
         # replicate the above call with random tensors of same shape
-        X = torch.randn((X.shape[0] * 2, 1, X.shape[1] + 1))
-        y = (torch.randn(y.shape) > 0).to(torch.float32)
+        X = torch.randn(
+            (X.shape[0] * 2, 1, X.shape[1] + 1),
+            generator=torch.Generator().manual_seed(42),
+        )
+        y = (
+            torch.rand(y.shape, generator=torch.Generator().manual_seed(42))
+            .round()
+            .to(torch.float32)
+        )
         dynamic_axes = {
             "X": {0: "num_datapoints", 1: "batch_size", 2: "num_features"},
             "y": {0: "num_labels"},
diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py
index 03b65649..ca34e576 100644
--- a/tests/test_regressor_interface.py
+++ b/tests/test_regressor_interface.py
@@ -247,14 +247,19 @@
 def test_onnx_exportable_cpu(X_y: tuple[np.ndarray, np.ndarray]) -> None:
     X, y = X_y
     with torch.no_grad():
-        regressor = TabPFNRegressor(n_estimators=1, device="cpu")
+        regressor = TabPFNRegressor(n_estimators=1, device="cpu", random_state=42)
         # load the model so we can access it via regressor.model_
         regressor.fit(X, y)
         # this is necessary if cuda is available
         regressor.predict(X)
-        # replicate the above call with random tensors of same shape
-        X = torch.randn((X.shape[0] * 2, 1, X.shape[1] + 1))
-        y = (torch.randn(y.shape) > 0).to(torch.float32)
+        # Use fixed random values instead of random generation
+        X = torch.randn(
+            (X.shape[0] * 2, 1, X.shape[1] + 1),
+            generator=torch.Generator().manual_seed(42),
+        )
+        y = (torch.randn(y.shape, generator=torch.Generator().manual_seed(42)) > 0).to(
+            torch.float32,
+        )
         dynamic_axes = {
             "X": {0: "num_datapoints", 1: "batch_size", 2: "num_features"},
             "y": {0: "num_labels"},

From 133e5ab044cbba8d32ab4d7d9ad444c1b26f7c9f Mon Sep 17 00:00:00 2001
From: LeoGrin
Date: Fri, 7 Feb 2025 10:28:54 +0100
Subject: [PATCH 9/9] make unstable test pass for older scipy

---
 tests/test_regressor_interface.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py
index ca34e576..ae130452 100644
--- a/tests/test_regressor_interface.py
+++ b/tests/test_regressor_interface.py
@@ -243,16 +243,17 @@ def forward(
         )


+# WARNING: unstable for scipy<1.11.0
 @pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning")
 def test_onnx_exportable_cpu(X_y: tuple[np.ndarray, np.ndarray]) -> None:
     X, y = X_y
     with torch.no_grad():
-        regressor = TabPFNRegressor(n_estimators=1, device="cpu", random_state=42)
+        regressor = TabPFNRegressor(n_estimators=1, device="cpu", random_state=43)
         # load the model so we can access it via regressor.model_
         regressor.fit(X, y)
         # this is necessary if cuda is available
         regressor.predict(X)
-        # Use fixed random values instead of random generation
+        # replicate the above call with random tensors of same shape
         X = torch.randn(
             (X.shape[0] * 2, 1, X.shape[1] + 1),
             generator=torch.Generator().manual_seed(42),
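
A standalone sketch of the export path these patches enable, for context:
patches 2, 3 and 5 are what make the forward pass traceable (torch_nansum
and torch_nanmean replace ops the ONNX exporter cannot handle, and
isolate_torch_rng removes the explicit torch.Generator objects that the
trace cannot capture). The Python below mirrors the test_onnx_exportable_cpu
tests added above but writes the model to disk and validates it with the
onnx checker instead of discarding it in a BytesIO buffer. It is an
illustrative sketch, not part of the patch series: the output path
"tabpfn.onnx" and the synthetic make_classification data are assumptions
made for the example, and ModelWrapper is copied from the tests.

from __future__ import annotations

import onnx
import torch
from sklearn.datasets import make_classification
from torch import nn

from tabpfn import TabPFNClassifier


class ModelWrapper(nn.Module):
    """Copy of the wrapper from the tests above: flattens the model's
    keyword-based forward signature into positional ONNX graph inputs."""

    def __init__(self, original_model):
        super().__init__()
        self.model = original_model

    def forward(self, X, y, single_eval_pos, only_return_standard_out, categorical_inds):
        return self.model(
            None,
            X,
            y,
            single_eval_pos=single_eval_pos,
            only_return_standard_out=only_return_standard_out,
            categorical_inds=categorical_inds,
        )


X, y = make_classification(n_samples=40, n_features=5, random_state=0)

classifier = TabPFNClassifier(n_estimators=1, device="cpu", random_state=42)
classifier.fit(X, y)   # loads the underlying torch module into classifier.model_
classifier.predict(X)  # builds the inference state that the export will trace

# Trace with random tensors of the shape the model sees internally,
# (num_datapoints, batch_size, num_features + 1), exactly as in the tests.
X_t = torch.randn(
    (X.shape[0] * 2, 1, X.shape[1] + 1),
    generator=torch.Generator().manual_seed(42),
)
y_t = (torch.randn(X.shape[0], generator=torch.Generator().manual_seed(42)) > 0).to(
    torch.float32,
)

with torch.no_grad():
    torch.onnx.export(
        ModelWrapper(classifier.model_).eval(),
        (X_t, y_t, y_t.shape[0], True, []),
        "tabpfn.onnx",
        input_names=[
            "X",
            "y",
            "single_eval_pos",
            "only_return_standard_out",
            "categorical_inds",
        ],
        output_names=["output"],
        opset_version=17,  # matches the tests; requires torch>=2.1
        dynamic_axes={
            "X": {0: "num_datapoints", 1: "batch_size", 2: "num_features"},
            "y": {0: "num_labels"},
        },
    )

# Structural validation with the onnx package that PATCH 6 adds to the dev
# dependencies; this checks the serialized graph, not numerical parity.
onnx.checker.check_model(onnx.load("tabpfn.onnx"))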