added test_onnx_exportable_cpu
AlexanderPfefferle committed Feb 5, 2025
1 parent bdd2d29 commit 50bba69
Showing 3 changed files with 118 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/tabpfn/model/transformer.py
@@ -4,7 +4,7 @@

import random
import warnings
-from collections.abc import Callable, Iterable
+from collections.abc import Callable, Generator, Iterable
from contextlib import contextmanager
from functools import partial
from typing import Any, Literal
@@ -27,7 +27,7 @@


@contextmanager
-def isolate_torch_rng(seed, device):
+def isolate_torch_rng(seed: int, device: torch.device) -> Generator[None, None, None]:
    torch_rng_state = torch.get_rng_state()
    if torch.cuda.is_available():
        torch_cuda_rng_state = torch.cuda.get_rng_state(device=device)
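The diff above only shows the first lines of isolate_torch_rng. A minimal sketch of the save/restore pattern such a context manager typically follows is given below; everything beyond the lines visible in the diff (the seeding call and the try/finally restore) is an assumption, not the repository's actual implementation.

from collections.abc import Generator
from contextlib import contextmanager

import torch


@contextmanager
def isolate_torch_rng(seed: int, device: torch.device) -> Generator[None, None, None]:
    # Snapshot the global (and, if present, CUDA) RNG state before touching it.
    torch_rng_state = torch.get_rng_state()
    if torch.cuda.is_available():
        torch_cuda_rng_state = torch.cuda.get_rng_state(device=device)
    try:
        torch.manual_seed(seed)  # assumed: seed only within the isolated scope
        yield
    finally:
        # Restore whatever RNG state the caller had, even if the body raised.
        torch.set_rng_state(torch_rng_state)
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(torch_cuda_rng_state, device=device)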
58 changes: 58 additions & 0 deletions tests/test_classifier_interface.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import io
from itertools import product
from typing import Callable, Literal

@@ -11,6 +12,7 @@
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.estimator_checks import parametrize_with_checks
from torch import nn

from tabpfn import TabPFNClassifier
from tabpfn.preprocessing import PreprocessorConfig
@@ -219,3 +221,59 @@ def test_dict_vs_object_preprocessor_config(X_y: tuple[np.ndarray, np.ndarray])
    prob_dict = model_dict.predict_proba(X)
    prob_obj = model_obj.predict_proba(X)
    np.testing.assert_array_almost_equal(prob_dict, prob_obj)


class ModelWrapper(nn.Module):
    def __init__(self, original_model):  # noqa: D107
        super().__init__()
        self.model = original_model

    def forward(
        self,
        X,
        y,
        single_eval_pos,
        only_return_standard_out,
        categorical_inds,
    ):
        return self.model(
            None,
            X,
            y,
            single_eval_pos=single_eval_pos,
            only_return_standard_out=only_return_standard_out,
            categorical_inds=categorical_inds,
        )


@pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning")
def test_onnx_exportable_cpu(X_y: tuple[np.ndarray, np.ndarray]) -> None:
    X, y = X_y
    with torch.no_grad():
        classifier = TabPFNClassifier(n_estimators=1, device="cpu")
        # load the model so we can access it via classifier.model_
        classifier.fit(X, y)
        # calling predict here is necessary if CUDA is available
        classifier.predict(X)
        # replicate the above call with random tensors of the same shape
        X = torch.randn((X.shape[0] * 2, 1, X.shape[1] + 1))
        y = (torch.randn(y.shape) > 0).to(torch.float32)
        dynamic_axes = {
            "X": {0: "num_datapoints", 1: "batch_size", 2: "num_features"},
            "y": {0: "num_labels"},
        }
        torch.onnx.export(
            ModelWrapper(classifier.model_).eval(),
            (X, y, y.shape[0], True, []),
            io.BytesIO(),
            input_names=[
                "X",
                "y",
                "single_eval_pos",
                "only_return_standard_out",
                "categorical_inds",
            ],
            output_names=["output"],
            opset_version=17,  # using 17 since we use torch>=2.1
            dynamic_axes=dynamic_axes,
        )
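The test exports into an in-memory io.BytesIO() buffer purely to check that tracing succeeds. A hedged sketch of how the same export could be written to disk and structurally validated with the onnx package follows; the file name and the use of onnx.checker are illustrative assumptions, not part of the test above.

# Illustrative follow-up, not part of the test: write the model to a file and
# validate it. Assumes the `onnx` package is installed; the path is made up.
import onnx

torch.onnx.export(
    ModelWrapper(classifier.model_).eval(),
    (X, y, y.shape[0], True, []),
    "tabpfn_classifier.onnx",  # a real file instead of io.BytesIO()
    input_names=["X", "y", "single_eval_pos", "only_return_standard_out", "categorical_inds"],
    output_names=["output"],
    opset_version=17,
    dynamic_axes=dynamic_axes,
)
onnx.checker.check_model(onnx.load("tabpfn_classifier.onnx"))  # structural check only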
58 changes: 58 additions & 0 deletions tests/test_regressor_interface.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import io
from itertools import product
from typing import Callable, Literal

@@ -11,6 +12,7 @@
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.estimator_checks import parametrize_with_checks
from torch import nn

from tabpfn import TabPFNRegressor
from tabpfn.preprocessing import PreprocessorConfig
@@ -216,3 +218,59 @@ def test_dict_vs_object_preprocessor_config(X_y: tuple[np.ndarray, np.ndarray])
        q_obj,
        err_msg="Quantile predictions differ",
    )


class ModelWrapper(nn.Module):
    def __init__(self, original_model):  # noqa: D107
        super().__init__()
        self.model = original_model

    def forward(
        self,
        X,
        y,
        single_eval_pos,
        only_return_standard_out,
        categorical_inds,
    ):
        return self.model(
            None,
            X,
            y,
            single_eval_pos=single_eval_pos,
            only_return_standard_out=only_return_standard_out,
            categorical_inds=categorical_inds,
        )


@pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning")
def test_onnx_exportable_cpu(X_y: tuple[np.ndarray, np.ndarray]) -> None:
    X, y = X_y
    with torch.no_grad():
        regressor = TabPFNRegressor(n_estimators=1, device="cpu")
        # load the model so we can access it via regressor.model_
        regressor.fit(X, y)
        # calling predict here is necessary if CUDA is available
        regressor.predict(X)
        # replicate the above call with random tensors of the same shape
        X = torch.randn((X.shape[0] * 2, 1, X.shape[1] + 1))
        y = (torch.randn(y.shape) > 0).to(torch.float32)
        dynamic_axes = {
            "X": {0: "num_datapoints", 1: "batch_size", 2: "num_features"},
            "y": {0: "num_labels"},
        }
        torch.onnx.export(
            ModelWrapper(regressor.model_).eval(),
            (X, y, y.shape[0], True, []),
            io.BytesIO(),
            input_names=[
                "X",
                "y",
                "single_eval_pos",
                "only_return_standard_out",
                "categorical_inds",
            ],
            output_names=["output"],
            opset_version=17,  # using 17 since we use torch>=2.1
            dynamic_axes=dynamic_axes,
        )
