Currently, the get_class_subset function in trojanzoo.datasets.Dataset directly calls the get_class_subset function in trojanzoo.utils.data.py.
However, get_class_subset in trojanzoo.utils.data.py is inefficient, especially for ImageNet. It loads the whole dataset, images as well as labels, through this line.
Only the labels are used afterwards.
I suggest replacing the function with the following code to avoid the unnecessary loading.
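import numpy as np
from torch.utils.data import Subset

# Signature assumed from the names used in the body.
def get_class_subset(dataset, class_list):
    class_list = [class_list] if isinstance(class_list, int) else class_list
    indices = np.arange(len(dataset))
    # If given a Subset, map its positions back to indices of the underlying dataset.
    if isinstance(dataset, Subset):
        idx = np.array(dataset.indices)
        indices = idx[indices]
        dataset = dataset.dataset
    # Read the labels only; no image is loaded.
    # (The original snippet tested self.target_transform; there is no self here.)
    if dataset.target_transform is not None:
        targets = [dataset.target_transform(t) for t in dataset.targets]
    else:
        targets = dataset.targets
    targets = np.asarray(targets)
    idx_bool = np.isin(targets, class_list)
    idx = np.arange(len(dataset))[idx_bool]
    # Keep only the matching samples that also belong to the original subset.
    idx = np.intersect1d(idx, indices)
    return Subset(dataset, idx)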
Your implementation assumes that dataset is a torchvision.datasets.VisionDataset rather than a generic torch.utils.data.Dataset (this matters especially for Subset), and that assumption does not always hold.
Not every dataset has a target_transform method or a targets attribute.
Since the PyTorch team is gradually deprecating the dataset convention in favor of the new datapipe style, I don't think it's a good idea to make a trojanzoo function rely on such concrete internal methods and attributes.
But your claim is correct: the current implementation is too slow for ImageNet.
It would be perfect if we could find a solution that also works for the future ImageNet dataset. The old ImageNet dataset (ImageFolder style) will be deprecated next year, after PyTorch 2.0.
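One possible middle ground, sketched here only as an illustration and not as trojanzoo's actual code: use the targets attribute when the dataset happens to expose one, and fall back to iterating over samples when it does not. The helper name class_subset_indices is hypothetical, and the sketch assumes that Subset-like wrappers expose .dataset and .indices, that samples are (input, label) pairs, and that labels convert cleanly with int(). It does not cover datapipe-style datasets.

```python
from typing import Any, Iterable, List, Union


def class_subset_indices(dataset: Any,
                         class_list: Union[int, Iterable[int]]) -> List[int]:
    """Return positions in `dataset` whose label is in `class_list`.

    Hypothetical helper: duck-types on `.dataset`/`.indices` (Subset-like
    wrappers) and on `.targets`/`.target_transform` (VisionDataset-like).
    """
    classes = {class_list} if isinstance(class_list, int) else set(class_list)

    # Unwrap nested Subset-like wrappers, composing their index mappings so
    # that base_indices[k] is the base-dataset index of dataset[k].
    base, base_indices = dataset, None
    while hasattr(base, "dataset") and hasattr(base, "indices"):
        inner = list(base.indices)
        if base_indices is None:
            base_indices = inner
        else:
            base_indices = [inner[i] for i in base_indices]
        base = base.dataset

    targets = getattr(base, "targets", None)
    if targets is not None:
        # Fast path: read labels directly, no sample is loaded.
        if base_indices is None:
            base_indices = range(len(targets))
        labels = [targets[i] for i in base_indices]
        transform = getattr(base, "target_transform", None)
        if transform is not None:
            labels = [transform(t) for t in labels]
    else:
        # Slow fallback: loads every sample, assumes (input, label) pairs.
        labels = [dataset[i][1] for i in range(len(dataset))]

    return [pos for pos, label in enumerate(labels) if int(label) in classes]
```

The returned positions refer to the dataset that was passed in, so the caller could wrap the result as torch.utils.data.Subset(dataset, indices). Keeping the helper free of torch imports is a design choice of this sketch only; it keeps the fast path independent of any concrete dataset class, at the cost of relying on attribute names.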