diff --git a/src/tabpfn/model/transformer.py b/src/tabpfn/model/transformer.py
index d3d88807..181feb3a 100644
--- a/src/tabpfn/model/transformer.py
+++ b/src/tabpfn/model/transformer.py
@@ -4,7 +4,7 @@
 
 import random
 import warnings
-from collections.abc import Callable, Iterable
+from collections.abc import Callable, Generator, Iterable
 from contextlib import contextmanager
 from functools import partial
 from typing import Any, Literal
@@ -27,7 +27,7 @@
 
 @contextmanager
-def isolate_torch_rng(seed, device):
+def isolate_torch_rng(seed: int, device: torch.device) -> Generator[None, None, None]:
     torch_rng_state = torch.get_rng_state()
 
     if torch.cuda.is_available():
         torch_cuda_rng_state = torch.cuda.get_rng_state(device=device)
diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py
index dfd5790a..6b9706c1 100644
--- a/tests/test_classifier_interface.py
+++ b/tests/test_classifier_interface.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import io
 from itertools import product
 from typing import Callable, Literal
 
@@ -11,6 +12,7 @@
 from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import StandardScaler
 from sklearn.utils.estimator_checks import parametrize_with_checks
+from torch import nn
 
 from tabpfn import TabPFNClassifier
 from tabpfn.preprocessing import PreprocessorConfig
@@ -219,3 +221,59 @@ def test_dict_vs_object_preprocessor_config(X_y: tuple[np.ndarray, np.ndarray])
     prob_dict = model_dict.predict_proba(X)
     prob_obj = model_obj.predict_proba(X)
     np.testing.assert_array_almost_equal(prob_dict, prob_obj)
+
+
+class ModelWrapper(nn.Module):
+    def __init__(self, original_model):  # noqa: D107
+        super().__init__()
+        self.model = original_model
+
+    def forward(
+        self,
+        X,
+        y,
+        single_eval_pos,
+        only_return_standard_out,
+        categorical_inds,
+    ):
+        return self.model(
+            None,
+            X,
+            y,
+            single_eval_pos=single_eval_pos,
+            only_return_standard_out=only_return_standard_out,
+            categorical_inds=categorical_inds,
+        )
+
+
+@pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning")
+def test_onnx_exportable_cpu(X_y: tuple[np.ndarray, np.ndarray]) -> None:
+    X, y = X_y
+    with torch.no_grad():
+        classifier = TabPFNClassifier(n_estimators=1, device="cpu")
+        # load the model so we can access it via classifier.model_
+        classifier.fit(X, y)
+        # this is necessary if cuda is available
+        classifier.predict(X)
+        # replicate the above call with random tensors of same shape
+        X = torch.randn((X.shape[0] * 2, 1, X.shape[1] + 1))
+        y = (torch.randn(y.shape) > 0).to(torch.float32)
+        dynamic_axes = {
+            "X": {0: "num_datapoints", 1: "batch_size", 2: "num_features"},
+            "y": {0: "num_labels"},
+        }
+        torch.onnx.export(
+            ModelWrapper(classifier.model_).eval(),
+            (X, y, y.shape[0], True, []),
+            io.BytesIO(),
+            input_names=[
+                "X",
+                "y",
+                "single_eval_pos",
+                "only_return_standard_out",
+                "categorical_inds",
+            ],
+            output_names=["output"],
+            opset_version=17,  # using 17 since we use torch>=2.1
+            dynamic_axes=dynamic_axes,
+        )
diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py
index 8aa71424..03b65649 100644
--- a/tests/test_regressor_interface.py
+++ b/tests/test_regressor_interface.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import io
 from itertools import product
 from typing import Callable, Literal
 
@@ -11,6 +12,7 @@
 from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import StandardScaler
 from sklearn.utils.estimator_checks import parametrize_with_checks
+from torch import nn
 
 from tabpfn import TabPFNRegressor
 from tabpfn.preprocessing import PreprocessorConfig
@@ -216,3 +218,59 @@ def test_dict_vs_object_preprocessor_config(X_y: tuple[np.ndarray, np.ndarray])
         q_obj,
         err_msg="Quantile predictions differ",
     )
+
+
+class ModelWrapper(nn.Module):
+    def __init__(self, original_model):  # noqa: D107
+        super().__init__()
+        self.model = original_model
+
+    def forward(
+        self,
+        X,
+        y,
+        single_eval_pos,
+        only_return_standard_out,
+        categorical_inds,
+    ):
+        return self.model(
+            None,
+            X,
+            y,
+            single_eval_pos=single_eval_pos,
+            only_return_standard_out=only_return_standard_out,
+            categorical_inds=categorical_inds,
+        )
+
+
+@pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning")
+def test_onnx_exportable_cpu(X_y: tuple[np.ndarray, np.ndarray]) -> None:
+    X, y = X_y
+    with torch.no_grad():
+        regressor = TabPFNRegressor(n_estimators=1, device="cpu")
+        # load the model so we can access it via regressor.model_
+        regressor.fit(X, y)
+        # this is necessary if cuda is available
+        regressor.predict(X)
+        # replicate the above call with random tensors of same shape
+        X = torch.randn((X.shape[0] * 2, 1, X.shape[1] + 1))
+        y = (torch.randn(y.shape) > 0).to(torch.float32)
+        dynamic_axes = {
+            "X": {0: "num_datapoints", 1: "batch_size", 2: "num_features"},
+            "y": {0: "num_labels"},
+        }
+        torch.onnx.export(
+            ModelWrapper(regressor.model_).eval(),
+            (X, y, y.shape[0], True, []),
+            io.BytesIO(),
+            input_names=[
+                "X",
+                "y",
+                "single_eval_pos",
+                "only_return_standard_out",
+                "categorical_inds",
+            ],
+            output_names=["output"],
+            opset_version=17,  # using 17 since we use torch>=2.1
+            dynamic_axes=dynamic_axes,
+        )
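
Note on the transformer.py hunk: the new annotations describe a context manager that yields nothing, hence the Generator[None, None, None] return type. Only the state-saving lines are visible in the diff, so the sketch below fills in the seeding and restore steps with the usual save/seed/restore pattern; those steps are an assumption about the rest of the function, not its verbatim body.

from collections.abc import Generator
from contextlib import contextmanager

import torch


@contextmanager
def isolate_torch_rng(seed: int, device: torch.device) -> Generator[None, None, None]:
    # Save the global (and, if available, CUDA) RNG state before seeding,
    # exactly as the lines visible in the hunk do.
    torch_rng_state = torch.get_rng_state()
    if torch.cuda.is_available():
        torch_cuda_rng_state = torch.cuda.get_rng_state(device=device)
    torch.manual_seed(seed)  # assumption: seeding happens inside the block
    try:
        yield
    finally:
        # Restore the saved state so callers observe no RNG side effects.
        torch.set_rng_state(torch_rng_state)
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(torch_cuda_rng_state, device=device)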
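
The new tests only assert that torch.onnx.export runs to completion; the exported bytes go into a throwaway io.BytesIO(). To also execute the exported graph, a sketch along these lines could be appended inside test_onnx_exportable_cpu. It assumes onnxruntime is installed (not a dependency added by this PR) and reuses ModelWrapper, classifier, X, y, and dynamic_axes from the classifier test above.

import io

import onnxruntime as ort
import torch

# Keep the exported bytes instead of discarding them.
buffer = io.BytesIO()
torch.onnx.export(
    ModelWrapper(classifier.model_).eval(),
    (X, y, y.shape[0], True, []),
    buffer,
    input_names=[
        "X",
        "y",
        "single_eval_pos",
        "only_return_standard_out",
        "categorical_inds",
    ],
    output_names=["output"],
    opset_version=17,
    dynamic_axes=dynamic_axes,
)

# Tracing bakes the non-tensor arguments (single_eval_pos, the boolean
# flag, the empty categorical_inds list) into the graph as constants,
# so only the tensor inputs are fed at run time.
session = ort.InferenceSession(buffer.getvalue())
(output,) = session.run(["output"], {"X": X.numpy(), "y": y.numpy()})

The dynamic_axes mapping is what lets such a session accept a different number of rows or features than the tensors used for tracing.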