Skip to content

Commit

Permalink
speedup feature selection
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoGrin committed Feb 4, 2025
1 parent cd4dc78 commit 2a22cb2
Showing 1 changed file with 17 additions and 19 deletions.
36 changes: 17 additions & 19 deletions src/tabpfn/model/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,27 +109,25 @@ def select_features(x: torch.Tensor, sel: torch.Tensor) -> torch.Tensor:
Returns:
The tensor with selected features.
"""
new_x = x.clone()
for B in range(x.shape[1]):
if x.shape[1] > 1:
new_x[:, B, :] = torch.cat(
[
x[:, B, sel[B]],
torch.zeros(
x.shape[0],
x.shape[-1] - sel[B].sum(),
device=x.device,
dtype=x.dtype,
),
],
-1,
)
else:
# If B == 1, we don't need to append zeros, as the number of features can change
new_x = x[:, :, sel[B]]
B, total_features = sel.shape
batch_size = x.shape[0]

# If B == 1, we don't need to append zeros, as the number of features doesn't need to be fixed.
if B == 1:
return x[:, :, sel[0]]

new_x = torch.zeros((batch_size, B, total_features), device=x.device, dtype=x.dtype)

# For each block, compute the number of selected features.
sel_counts = sel.sum(dim=-1) # shape: (B,)

for b in range(B):
s = int(sel_counts[b])
if s > 0:
new_x[:, b, :s] = x[:, b, sel[b]]

return new_x


def remove_outliers(
X: torch.Tensor,
n_sigma: float = 4,
Expand Down

0 comments on commit 2a22cb2

Please sign in to comment.