Low effective loading in get_class_subset function #159

Open
TDteach opened this issue Aug 15, 2022 · 1 comment
Comments

@TDteach

TDteach commented Aug 15, 2022

Currently, the get_class_subset method in trojanzoo.datasets.Dataset directly calls the get_class_subset function in trojanzoo.utils.data.py.
However, that function is inefficient, especially for ImageNet data: it loads the whole dataset, including images and labels, through this line.
Only the labels are used afterwards.
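
To illustrate the cost, here is a minimal sketch that uses a hypothetical toy dataset (`ToyImageDataset`, not part of trojanzoo or torchvision) whose `__getitem__` counts simulated image loads. It assumes an ImageFolder-style dataset that keeps its labels in a `targets` attribute. Collecting labels by indexing every sample triggers one load per sample, while reading `targets` triggers none.

```python
class ToyImageDataset:
    """Hypothetical stand-in for an ImageFolder-style dataset."""

    def __init__(self, labels):
        self.targets = list(labels)  # labels are already held in memory
        self.load_count = 0          # counts simulated image decodes

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        self.load_count += 1         # stands in for reading and decoding a file
        image = [0.0] * 4            # placeholder "image"
        return image, self.targets[index]


def labels_by_iteration(dataset):
    """Collect labels by indexing every sample, which also loads every image."""
    return [dataset[i][1] for i in range(len(dataset))]


def labels_from_targets(dataset):
    """Collect labels from the `targets` attribute without touching any image."""
    return list(dataset.targets)


dataset = ToyImageDataset([0, 1, 2, 1, 0])

slow = labels_by_iteration(dataset)
loads_after_iteration = dataset.load_count

fast = labels_from_targets(dataset)
loads_after_targets = dataset.load_count - loads_after_iteration
```

Both paths return the same labels; only the number of simulated image loads differs, and for ImageNet that difference is on the order of a million decoded files.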

I suggest replacing this function with the following code to avoid the unnecessary loading.

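    import numpy as np
    from torch.utils.data import Subset

    def get_class_subset(dataset, class_list):
        # Accept a single class index as well as a list of class indices.
        class_list = [class_list] if isinstance(class_list, int) else class_list
        indices = np.arange(len(dataset))
        # Unwrap a Subset so that labels are read from the underlying dataset.
        if isinstance(dataset, Subset):
            indices = np.asarray(dataset.indices)
            dataset = dataset.dataset

        # Read labels only; no image is loaded here.
        # (The original snippet checked self.target_transform; dataset is meant.)
        if dataset.target_transform is not None:
            targets = [dataset.target_transform(t) for t in dataset.targets]
        else:
            targets = dataset.targets
        targets = np.asarray(targets)
        # Keep samples whose label is in class_list and that belong to the original subset.
        idx = np.flatnonzero(np.isin(targets, class_list))
        idx = np.intersect1d(idx, indices)
        return Subset(dataset, idx)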
@ain-soph
Owner

ain-soph commented Aug 15, 2022

Your implementation assumes that dataset is a torchvision.datasets.VisionDataset rather than a generic torch.utils.data.Dataset (especially in the Subset case), and this assumption does not always hold.

Not every dataset has a target_transform method or a targets attribute.

Especially since the PyTorch team is gradually deprecating the dataset convention in favor of the new datapipe style, I don't think it's a good idea to make trojanzoo functions rely on such concrete internal methods and attributes.


But your claim is correct: the current implementation is too slow for ImageNet.
It would be ideal to find a solution that also works for the future ImageNet dataset. The old ImageNet dataset (ImageFolder style) will be deprecated next year, after PyTorch 2.0.
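
One possible middle ground, sketched here only as an assumption and not as trojanzoo code, is to use `targets` when the underlying dataset exposes it and fall back to plain iteration otherwise. The helper name `class_indices` and the toy classes are hypothetical. The sketch is duck-typed: any wrapper with `dataset` and `indices` attributes (as `torch.utils.data.Subset` has) is unwrapped, and the returned positions index into the dataset that was passed in, so they could be handed to `Subset(dataset, idx)`. It does not cover datapipes.

```python
def class_indices(dataset, class_list):
    """Return positions in `dataset` whose label is in `class_list`.

    Hypothetical helper: fast path via `targets`, slow path via iteration.
    """
    wanted = {class_list} if isinstance(class_list, int) else set(class_list)

    # Peel off Subset-like wrappers, mapping each position to the base dataset.
    base, positions = dataset, list(range(len(dataset)))
    while hasattr(base, "dataset") and hasattr(base, "indices"):
        positions = [base.indices[p] for p in positions]
        base = base.dataset

    targets = getattr(base, "targets", None)
    if targets is not None:
        # Fast path: read labels only, no sample is loaded.
        labels = [targets[p] for p in positions]
        transform = getattr(base, "target_transform", None)
        if transform is not None:
            labels = [transform(t) for t in labels]
    else:
        # Slow path: the dataset exposes no labels, so index every sample.
        labels = [dataset[i][1] for i in range(len(dataset))]

    return [i for i, label in enumerate(labels) if int(label) in wanted]


class ToySubset:
    """Toy stand-in with the same attributes as torch.utils.data.Subset."""

    def __init__(self, dataset, indices):
        self.dataset, self.indices = dataset, list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        return self.dataset[self.indices[i]]


class ToyWithoutTargets:
    """Toy dataset that only exposes labels through __getitem__."""

    def __init__(self, labels):
        self._labels, self.loads = list(labels), 0

    def __len__(self):
        return len(self._labels)

    def __getitem__(self, i):
        self.loads += 1  # stands in for decoding an image
        return None, self._labels[i]


class ToyWithTargets(ToyWithoutTargets):
    """Toy dataset that additionally exposes a `targets` list."""

    def __init__(self, labels):
        super().__init__(labels)
        self.targets = list(labels)


labels = [0, 1, 2, 1, 0, 2]
with_targets = ToyWithTargets(labels)
without_targets = ToyWithoutTargets(labels)

fast_idx = class_indices(ToySubset(with_targets, [5, 3, 1, 0]), [1])
slow_idx = class_indices(without_targets, 2)
```

The trade-off is that the fast path still depends on the `targets` convention; it only degrades gracefully where that convention is absent.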
