enable ONNX export #165

Merged
9 commits merged on Feb 7, 2025
20 changes: 13 additions & 7 deletions src/tabpfn/model/encoders.py
@@ -9,11 +9,17 @@
from torch import nn


# TODO(eddiebergman): These were used before but I have no idea why.
# We use the implementations given by torch for now.
# TODO(Arjun): Enabling these again because their behaviour is a little
# different from torch's implementation (see Issue #2). We should check if this makes
# a difference in the results.
# usage of custom implementations is required to support ONNX export
def torch_nansum(x: torch.Tensor, axis=None, keepdim=False, dtype=None):
nan_mask = torch.isnan(x)
masked_input = torch.where(
nan_mask,
torch.tensor(0.0, device=x.device, dtype=x.dtype),
x,
)
return torch.sum(masked_input, axis=axis, keepdim=keepdim, dtype=dtype)


def torch_nanmean(
x: torch.Tensor,
axis: int = 0,
@@ -46,7 +52,7 @@ def torch_nanstd(x: torch.Tensor, axis: int = 0):
dim=axis,
)
return torch.sqrt(
torch.nansum(torch.square(mean_broadcast - x), axis=axis) / (num - 1), # type: ignore
torch_nansum(torch.square(mean_broadcast - x), axis=axis) / (num - 1), # type: ignore
)


@@ -458,7 +464,7 @@ def _fit(self, x: torch.Tensor, single_eval_pos: int, **kwargs: Any) -> None:
single_eval_pos: The position to use for single evaluation.
**kwargs: Additional keyword arguments (unused).
"""
self.feature_means_ = torch.nanmean(x[:single_eval_pos], dim=0)
self.feature_means_ = torch_nanmean(x[:single_eval_pos], axis=0)

def _transform(
self,
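The new comment in encoders.py ties the custom implementations to ONNX export support. As a rough illustration of why the masked formulation is export-friendly (not part of the PR; the MaskedNanSum wrapper, output file name, and opset version are illustrative assumptions), the masked sum agrees numerically with torch.nansum and can be traced by the standard exporter:

import torch

def torch_nansum(x: torch.Tensor, axis=None, keepdim=False, dtype=None):
    # Same masking trick as in src/tabpfn/model/encoders.py: zero out NaNs,
    # then use an ordinary sum that the ONNX exporter understands.
    nan_mask = torch.isnan(x)
    masked_input = torch.where(
        nan_mask,
        torch.tensor(0.0, device=x.device, dtype=x.dtype),
        x,
    )
    return torch.sum(masked_input, axis=axis, keepdim=keepdim, dtype=dtype)

class MaskedNanSum(torch.nn.Module):
    """Toy wrapper (illustrative only) around the masked nansum."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch_nansum(x, axis=0)

x = torch.tensor([[1.0, float("nan"), 3.0], [float("nan"), 5.0, 6.0]])
assert torch.allclose(MaskedNanSum()(x), torch.nansum(x, dim=0))

# The masking maps onto plain ONNX ops (IsNaN / Where / ReduceSum), so no
# custom symbolic is needed; file name and opset are arbitrary choices here.
torch.onnx.export(MaskedNanSum(), (x,), "masked_nansum.onnx", opset_version=17)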
119 changes: 51 additions & 68 deletions src/tabpfn/model/transformer.py
@@ -5,6 +5,7 @@
import random
import warnings
from collections.abc import Callable, Iterable
from contextlib import contextmanager
from functools import partial
from typing import Any, Literal

@@ -25,6 +26,20 @@
DEFAULT_EMSIZE = 128


@contextmanager
def isolate_torch_rng(seed, device):
torch_rng_state = torch.get_rng_state()
if torch.cuda.is_available():
torch_cuda_rng_state = torch.cuda.get_rng_state(device=device)
torch.manual_seed(seed)
try:
yield
finally:
torch.set_rng_state(torch_rng_state)
if torch.cuda.is_available():
torch.cuda.set_rng_state(torch_cuda_rng_state, device=device)
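
A short usage sketch of the new isolate_torch_rng helper (not part of the PR; the seed, shapes, and assertions are illustrative). Draws inside the block are reproducible for a fixed seed, and the global RNG state is restored on exit, which is what lets the PR drop the module's per-instance generator (see the SerializableGenerator removal further down):

import torch

from tabpfn.model.transformer import isolate_torch_rng  # helper added in this PR

torch.manual_seed(0)
outer_state = torch.get_rng_state()

with isolate_torch_rng(seed=42, device=torch.device("cpu")):
    a = torch.randn(3)
with isolate_torch_rng(seed=42, device=torch.device("cpu")):
    b = torch.randn(3)

assert torch.equal(a, b)  # same seed inside the block -> identical draws
assert torch.equal(outer_state, torch.get_rng_state())  # outer RNG stream untouched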


class LayerStack(nn.Module):
"""Similar to nn.Sequential, but with support for passing keyword arguments
to layers and stacks the same layer multiple times.
@@ -294,17 +309,6 @@ def __init__( # noqa: C901, D417, PLR0913
self.cached_feature_positional_embeddings: torch.Tensor | None = None
self.seed = seed if seed is not None else random.randint(0, 1_000_000) # noqa: S311

# Device on which the generator was last initialized.
# If loading from a checkpoint, this might be false,
# but it will be set to the correct device on the first forward pass.
self.generator_device = "cpu"
self._init_rnd()

def _init_rnd(self) -> None:
self.generator = SerializableGenerator(device=self.generator_device)
if self.seed: # This can be none if set outside of the model.
self.generator.manual_seed(self.seed)

def reset_save_peak_mem_factor(self, factor: int | None = None) -> None:
"""Sets the save_peak_mem_factor for all layers.

@@ -377,7 +381,6 @@ def forward(self, *args: Any, **kwargs: Any) -> dict[str, torch.Tensor]: # noqa
Returns:
The output of the model, which can be a tensor or a dictionary of tensors.
"""
self._init_rnd()
half_layers = kwargs.pop("half_layers", False)
assert half_layers is False

@@ -694,57 +697,47 @@ def add_embeddings( # noqa: C901, PLR0912
x += self.cached_embeddings[None, None]
return x, y

        if (
            self.generator_device != self.generator.device
            or self.generator_device != x.device
        ):
            self.generator_device = x.device
            self._init_rnd()

        if self.feature_positional_embedding == "normal_rand_vec":
            embs = torch.randn(
                (x.shape[2], x.shape[3]),
                device=x.device,
                dtype=x.dtype,
                generator=self.generator,
            )
            x += embs[None, None]
        elif self.feature_positional_embedding == "uni_rand_vec":
            embs = (
                torch.rand(
                    (x.shape[2], x.shape[3]),
                    device=x.device,
                    dtype=x.dtype,
                    generator=self.generator,
                )
                * 2
                - 1
            )
            x += embs[None, None]
        elif self.feature_positional_embedding == "learned":
            w = self.feature_positional_embedding_embeddings.weight
            embs = w[
                torch.randint(
                    0,
                    w.shape[0],
                    (x.shape[2],),
                    generator=self.generator,
                )
            ]
            x += embs[None, None]
        elif self.feature_positional_embedding == "subspace":
            embs = torch.randn(
                (x.shape[2], x.shape[3] // 4),
                device=x.device,
                dtype=x.dtype,
                generator=self.generator,
            )
            embs = self.feature_positional_embedding_embeddings(embs)
            x += embs[None, None]
        elif self.feature_positional_embedding is None:
            embs = None
        else:
            raise ValueError(f"Unknown {self.feature_positional_embedding=}")

        with isolate_torch_rng(self.seed, device=x.device):
            if self.feature_positional_embedding == "normal_rand_vec":
                embs = torch.randn(
                    (x.shape[2], x.shape[3]),
                    device=x.device,
                    dtype=x.dtype,
                )
                x += embs[None, None]
            elif self.feature_positional_embedding == "uni_rand_vec":
                embs = (
                    torch.rand(
                        (x.shape[2], x.shape[3]),
                        device=x.device,
                        dtype=x.dtype,
                    )
                    * 2
                    - 1
                )
                x += embs[None, None]
            elif self.feature_positional_embedding == "learned":
                w = self.feature_positional_embedding_embeddings.weight
                embs = w[
                    torch.randint(
                        0,
                        w.shape[0],
                        (x.shape[2],),
                    )
                ]
                x += embs[None, None]
            elif self.feature_positional_embedding == "subspace":
                embs = torch.randn(
                    (x.shape[2], x.shape[3] // 4),
                    device=x.device,
                    dtype=x.dtype,
                )
                embs = self.feature_positional_embedding_embeddings(embs)
                x += embs[None, None]
            elif self.feature_positional_embedding is None:
                embs = None
            else:
                raise ValueError(f"Unknown {self.feature_positional_embedding=}")

self.cached_embeddings = None
if cache_embeddings and embs is not None:
@@ -869,13 +862,3 @@ def _add_pos_emb(
# TODO(old) Double check the ordering is right
for n, pe_ in zip(graph.nodes(), pe):
graph.nodes[n]["positional_encoding"] = pe_


class SerializableGenerator(torch.Generator):
"""A serializable version of the torch.Generator, that cna be saved and pickled."""

def __getstate__(self) -> Any:
return self.__dict__

def __setstate__(self, d: Any) -> None:
self.__dict__ = d
1 change: 0 additions & 1 deletion tests/test_regressor_interface.py
@@ -111,7 +111,6 @@ def test_sklearn_compatible_estimator(
"check_methods_sample_order_invariance",
):
estimator.inference_precision = torch.float64
if check.func.__name__ == "check_methods_sample_order_invariance": # type: ignore
pytest.xfail("We're not at 1e-7 difference yet")
check(estimator)
