
Speedup feature selection function #168

Merged
5 commits merged into main on Feb 5, 2025

Conversation

@LeoGrin (Collaborator) commented on Feb 4, 2025

As noted by @LennartPurucker, the select_features function is quite slow. This PR makes it between 1.5x and 2x faster (though it is still quite slow).

I ran this test:

import torch
import time

def select_features(x: torch.Tensor, sel: torch.Tensor) -> torch.Tensor:
    """Select features from the input tensor based on the selection mask.

    Args:
        x: The input tensor.
        sel: The boolean selection mask indicating which features to keep.

    Returns:
        The tensor with selected features.
    """
    new_x = x.clone()
    for B in range(x.shape[1]):
        if x.shape[1] > 1:
            new_x[:, B, :] = torch.cat(
                [
                    x[:, B, sel[B]],
                    torch.zeros(
                        x.shape[0],
                        x.shape[-1] - sel[B].sum(),
                        device=x.device,
                        dtype=x.dtype,
                    ),
                ],
                -1,
            )
        else:
            # If x.shape[1] == 1, we don't need to append zeros, as the number of features is allowed to change
            new_x = x[:, :, sel[B]]
    return new_x

def select_features_new(x: torch.Tensor, sel: torch.Tensor) -> torch.Tensor:
    """Select features from the input tensor based on the selection mask.

    Args:
        x: The input tensor.
        sel: The boolean selection mask indicating which features to keep.

    Returns:
        The tensor with selected features.
    """
    B, total_features = sel.shape
    batch_size = x.shape[0]

    # If B == 1, we don't need to append zeros, as the number of features doesn't need to be fixed.
    if B == 1:
        return x[:, :, sel[0]]
    
    new_x = torch.zeros((batch_size, B, total_features), device=x.device, dtype=x.dtype)
    
    sel_counts = sel.sum(dim=-1)  # shape: (B,)
    
    for b in range(B):
        s = int(sel_counts[b])
        if s > 0:
            new_x[:, b, :s] = x[:, b, sel[b]]
    
    return new_x

def test_equivalence_and_speed():
    # Test cases with different shapes
    test_cases = [
        (torch.randn(32, 4, 100), torch.randint(0, 2, (4, 100)).bool()),  # Multiple batches
        (torch.randn(32, 1, 100), torch.randint(0, 2, (1, 100)).bool()),  # Single batch
        (torch.randn(1, 4, 50), torch.randint(0, 2, (4, 50)).bool()),     # Small
        # many features
        (torch.randn(1000, 4, 500), torch.randint(0, 2, (4, 500)).bool()),
        (torch.randn(2000, 8, 500), torch.randint(0, 2, (8, 500)).bool()),
        (torch.randn(1000, 8, 1000), torch.randint(0, 2, (8, 1000)).bool()),
        (torch.randn(1000, 1, 1000), torch.randint(0, 2, (1, 1000)).bool()),
    ]
    
    for i, (x, sel) in enumerate(test_cases):
        print(f"\nTest case {i+1}:")
        print(f"Input shape: {x.shape}, Selection mask shape: {sel.shape}")
        
        # Test equivalence
        try:
            old_result = select_features(x, sel)
            new_result = select_features_new(x, sel)
            
            if old_result.shape != new_result.shape:
                print(f"❌ Shape mismatch: Old {old_result.shape} vs New {new_result.shape}")
                continue
                
            # Compare results
            max_diff = torch.max(torch.abs(old_result - new_result))
            if max_diff > 1e-6:
                print(f"❌ Results differ! Max difference: {max_diff}")
            else:
                print("✓ Results match")
                
            # Speed test
            iterations = 1000
            
            # Time old version
            start = time.time()
            for _ in range(iterations):
                _ = select_features(x, sel)
            old_time = time.time() - start
            
            # Time new version
            start = time.time()
            for _ in range(iterations):
                _ = select_features_new(x, sel)
            new_time = time.time() - start
            
            print(f"Old version: {old_time:.4f}s")
            print(f"New version: {new_time:.4f}s")
            print(f"Speedup: {old_time/new_time:.2f}x")
            
        except Exception as e:
            print(f"❌ Error occurred: {str(e)}")

if __name__ == "__main__":
    test_equivalence_and_speed()

giving these results:
on GPU T4:

Using device: cuda
Warming up GPU...

Test case 1:
Input shape: torch.Size([32, 4, 100]), Selection mask shape: torch.Size([4, 100])
✓ Results match
Old version: 0.7961s
New version: 0.4571s
Speedup: 1.74x

Test case 2:
Input shape: torch.Size([32, 1, 100]), Selection mask shape: torch.Size([1, 100])
✓ Results match
Old version: 0.0770s
New version: 0.0605s
Speedup: 1.27x

Test case 3:
Input shape: torch.Size([1, 4, 50]), Selection mask shape: torch.Size([4, 50])
✓ Results match
Old version: 0.7872s
New version: 0.4613s
Speedup: 1.71x

Test case 4:
Input shape: torch.Size([1000, 4, 500]), Selection mask shape: torch.Size([4, 500])
✓ Results match
Old version: 0.8525s
New version: 0.4458s
Speedup: 1.91x

Test case 5:
Input shape: torch.Size([2000, 8, 500]), Selection mask shape: torch.Size([8, 500])
✓ Results match
Old version: 2.1386s
New version: 1.1011s
Speedup: 1.94x

Test case 6:
Input shape: torch.Size([1000, 8, 1000]), Selection mask shape: torch.Size([8, 1000])
✓ Results match
Old version: 2.0992s
New version: 1.0877s
Speedup: 1.93x

Test case 7:
Input shape: torch.Size([1000, 1, 1000]), Selection mask shape: torch.Size([1, 1000])
✓ Results match
Old version: 0.1086s
New version: 0.0685s
Speedup: 1.58x

on mac M3 cpu:

Test case 1:
Input shape: torch.Size([32, 4, 100]), Selection mask shape: torch.Size([4, 100])
✓ Results match
Old version: 0.0636s
New version: 0.0407s
Speedup: 1.56x

Test case 2:
Input shape: torch.Size([32, 1, 100]), Selection mask shape: torch.Size([1, 100])
✓ Results match
Old version: 0.0070s
New version: 0.0059s
Speedup: 1.19x

Test case 3:
Input shape: torch.Size([1, 4, 50]), Selection mask shape: torch.Size([4, 50])
✓ Results match
Old version: 0.0509s
New version: 0.0307s
Speedup: 1.66x

Test case 4:
Input shape: torch.Size([1000, 4, 500]), Selection mask shape: torch.Size([4, 500])
✓ Results match
Old version: 2.1650s
New version: 1.3150s
Speedup: 1.65x

Test case 5:
Input shape: torch.Size([2000, 8, 500]), Selection mask shape: torch.Size([8, 500])
✓ Results match
Old version: 11.9438s
New version: 8.1608s
Speedup: 1.46x

Test case 6:
Input shape: torch.Size([1000, 8, 1000]), Selection mask shape: torch.Size([8, 1000])
✓ Results match
Old version: 12.5993s
New version: 8.3466s
Speedup: 1.51x

Test case 7:
Input shape: torch.Size([1000, 1, 1000]), Selection mask shape: torch.Size([1, 1000])
✓ Results match
Old version: 0.7504s
New version: 0.3714s
Speedup: 2.02x
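
To make the behavior concrete (my own illustration, not part of the PR): for each entry along dim 1, the kept features are packed to the front of the last dimension and the rest is zero-padded, so the shape stays fixed whenever B > 1.

import torch

# Uses select_features_new from the snippet above. With B > 1, the selected
# features are left-packed and the remaining positions are zero-filled, so every
# entry along dim 1 keeps the same fixed feature dimension.
x = torch.arange(12, dtype=torch.float32).reshape(1, 2, 6)
sel = torch.tensor([[True, False, True, False, False, False],
                    [True, True, True, True, False, False]])

out = select_features_new(x, sel)
print(out.shape)   # torch.Size([1, 2, 6]) -- feature dimension unchanged
print(out[0, 0])   # tensor([0., 2., 0., 0., 0., 0.]) -- two kept values, then zeros
print(out[0, 1])   # tensor([6., 7., 8., 9., 0., 0.]) -- four kept values, then zeros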

The only behavior difference from the original function is that it doesn't clone x when B == 1, but I think that's fine.
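
A note on that point (my addition, not from the PR): boolean-mask indexing in PyTorch uses advanced indexing, which returns a copy rather than a view, so the missing clone() in the B == 1 path shouldn't let callers mutate the original input:

import torch

# Sanity check (not part of the PR): boolean-mask indexing uses advanced
# indexing and returns a copy, so the B == 1 fast path does not alias the input.
x = torch.randn(4, 1, 10)
sel = torch.zeros(1, 10, dtype=torch.bool)
sel[0, :5] = True

out = x[:, :, sel[0]]  # same expression as the B == 1 branch above
out.zero_()
print(bool((x[:, :, :5] != 0).any()))  # True: x was not modified by zeroing out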

@LennartPurucker (Collaborator) left a comment


LGTM, I have one minor request: Since I think you now understand the function, could you add a docstring explaining what it does? I am still a bit confused about its purpose.
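
For reference, a sketch of the kind of docstring that could cover this (my own wording, not what was merged in 4b8b18a):

def select_features(x: torch.Tensor, sel: torch.Tensor) -> torch.Tensor:
    """Select features per entry along dim 1 of x, given a boolean mask.

    For each index b along dim 1 (matching the rows of sel), the features where
    sel[b] is True are packed to the front of the last dimension and the rest is
    zero-filled, so the output has the same shape as x. When sel has a single
    row (B == 1), the selected features are returned directly and the size of
    the last dimension is allowed to shrink.
    """
    ...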

@LeoGrin merged commit 4b8b18a into main on Feb 5, 2025
5 checks passed