Skip to content

Commit

Permalink
Merge pull request #168 from PriorLabs/speedup_feature_select
Browse files Browse the repository at this point in the history
Speedup feature selection function
  • Loading branch information
LeoGrin authored Feb 5, 2025
2 parents 7e91976 + dbfc9ea commit 4b8b18a
Showing 1 changed file with 28 additions and 21 deletions.
49 changes: 28 additions & 21 deletions src/tabpfn/model/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,33 +100,40 @@ def normalize_data(


def select_features(x: torch.Tensor, sel: torch.Tensor) -> torch.Tensor:
"""Select features from the input tensor based on the selection mask.
"""Select features from the input tensor based on the selection mask,
and arrange them contiguously in the last dimension.
If batch size is bigger than 1, we pad the features with zeros to make the number of features fixed.
Args:
x: The input tensor.
sel: The boolean selection mask indicating which features to keep.
x: The input tensor of shape (sequence_length, batch_size, total_features)
sel: The boolean selection mask indicating which features to keep of shape (batch_size, total_features)
Returns:
The tensor with selected features.
The shape is (sequence_length, batch_size, number_of_selected_features) if batch_size is 1.
The shape is (sequence_length, batch_size, total_features) if batch_size is greater than 1.
"""
new_x = x.clone()
for B in range(x.shape[1]):
if x.shape[1] > 1:
new_x[:, B, :] = torch.cat(
[
x[:, B, sel[B]],
torch.zeros(
x.shape[0],
x.shape[-1] - sel[B].sum(),
device=x.device,
dtype=x.dtype,
),
],
-1,
)
else:
# If B == 1, we don't need to append zeros, as the number of features can change
new_x = x[:, :, sel[B]]
B, total_features = sel.shape
sequence_length = x.shape[0]

# If B == 1, we don't need to append zeros, as the number of features don't need to be fixed.
if B == 1:
return x[:, :, sel[0]]

new_x = torch.zeros(
(sequence_length, B, total_features),
device=x.device,
dtype=x.dtype,
)

# For each batch, compute the number of selected features.
sel_counts = sel.sum(dim=-1) # shape: (B,)

for b in range(B):
s = int(sel_counts[b])
if s > 0:
new_x[:, b, :s] = x[:, b, sel[b]]

return new_x


Expand Down

0 comments on commit 4b8b18a

Please sign in to comment.