def select_features(x: torch.Tensor, sel: torch.Tensor) -> torch.Tensor:
    """Select features by mask and pack them contiguously in the last dimension.

    With a single batch element the unselected features are simply dropped, so
    the feature dimension shrinks. With more than one batch element each batch
    element may keep a different number of features; to keep the output
    rectangular, the selected features are moved to the front (their original
    order is preserved) and the remaining trailing positions are zero-filled.

    Args:
        x: The input tensor of shape
            (sequence_length, batch_size, total_features).
        sel: The boolean selection mask of shape (batch_size, total_features)
            indicating which features to keep. The batch size that decides
            between the two output layouts is read from this mask.

    Returns:
        The tensor with the selected features. Its shape is
        (sequence_length, batch_size, number_of_selected_features) if
        batch_size is 1, and (sequence_length, batch_size, total_features)
        if batch_size is greater than 1.
    """
    batch_size, total_features = sel.shape

    # A single batch element needs no zero padding, because its number of
    # features does not have to agree with any other batch element.
    if batch_size == 1:
        return x[:, :, sel[0]]

    sel = sel.to(x.device)
    positions = torch.arange(total_features, device=x.device)

    # Build a unique sort key per column: a selected column keeps its own index,
    # an unselected one is shifted past every selected column. Because the keys
    # are unique, the resulting order is deterministic: selected columns first,
    # in their original order, followed by the unselected ones.
    sort_key = positions + total_features * (~sel).to(positions.dtype)
    order = sort_key.argsort(dim=-1)  # shape: (batch_size, total_features)

    # Reorder the feature axis for all sequence positions at once; this avoids
    # a Python loop over the batch and the host sync each iteration implied.
    packed = x.gather(-1, order.unsqueeze(0).expand(x.shape[0], -1, -1))

    # Everything at or beyond the per-batch selected count is padding. It is
    # overwritten (not multiplied) so NaN/inf in dropped columns cannot leak.
    is_padding = positions >= sel.sum(dim=-1, keepdim=True)
    return packed.masked_fill(is_padding, 0)