A lightweight implementation of Dataset Reinforcement, pretrained checkpoints, and reinforced datasets.
Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with Dataset Reinforcement. Faghri, F., Pouransari, H., Mehta, S., Farajtabar, M., Farhadi, A., Rastegari, M., & Tuzel, O. Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2023.
Update 2023/09/22: The Average column of Table 7 is corrected in arXiv v3. Correct numbers: 30.4, 37.1, 37.9, 43.7, 39.6, 51.1.
Reinforced ImageNet (ImageNet+) improves accuracy at a similar number of iterations and similar wall-clock time
ImageNet validation accuracy of ResNet-50 as a function of training duration with (1) the ImageNet dataset, (2) knowledge distillation (KD), and (3) the ImageNet+ dataset (ours). Each point is a full training run with epochs varying from 50 to 1000. An epoch has the same number of iterations for ImageNet and ImageNet+.
Illustration of Dataset Reinforcement
Data augmentation and knowledge distillation are common approaches to improving accuracy. Dataset reinforcement combines the benefits of both by bringing the advantages of large models trained on large datasets to other datasets and models. Training of new models with a reinforced dataset is as fast as training on the original dataset for the same total iterations. Creating a reinforced dataset is a one-time process (e.g., ImageNet to ImageNet+) the cost of which is amortized over repeated uses.
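As a rough illustration of why per-iteration cost matches plain ImageNet training, the sketch below performs one optimization step using stored (sparse) teacher probabilities instead of a live teacher forward pass. The tensor layout, field names, and the handling of leftover probability mass are assumptions for illustration, not the repository's actual on-disk format.

```python
# A minimal sketch of a training step on reinforced data, assuming each
# reinforced sample provides an augmented image (reproduced from stored
# augmentation parameters) plus the teacher's top-k class indices and
# probabilities. Names and conventions here are illustrative only.
import torch
import torch.nn.functional as F

def reinforced_step(model, optimizer, images, topk_idx, topk_prob, num_classes=1000):
    """One optimization step with precomputed (sparse) teacher soft labels.

    images:    (B, 3, H, W) augmented images
    topk_idx:  (B, K) class indices of the stored teacher probabilities
    topk_prob: (B, K) corresponding probabilities (rows may sum to < 1)
    """
    # Expand the sparse teacher output into a dense soft-label vector.
    soft_targets = torch.zeros(images.size(0), num_classes, device=images.device)
    soft_targets.scatter_(1, topk_idx, topk_prob)
    # Spread the leftover mass (1 - sum of stored top-k) uniformly; this is
    # one simple convention among several possible ones.
    leftover = (1.0 - soft_targets.sum(dim=1, keepdim=True)).clamp(min=0)
    soft_targets = soft_targets + leftover / num_classes

    logits = model(images)
    # Cross-entropy with soft targets; no teacher forward pass is needed here,
    # which is why the per-iteration cost matches plain ImageNet training.
    loss = torch.sum(-soft_targets * F.log_softmax(logits, dim=1), dim=1).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```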
Install the requirements using:
```bash
pip install -r requirements.txt
```
We support loading models from the Timm and CVNets libraries.
To install the CVNets library, follow their installation instructions.
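For quick reference, a Timm model can be instantiated as in the minimal sketch below (the architecture name is only an example; CVNets models are built through the CVNets library itself and are not shown here).

```python
# Minimal sketch: build a model through Timm, e.g. as a student or teacher.
# "resnet50d" is just an example; see timm.list_models() for all options.
import timm
import torch

model = timm.create_model("resnet50d", pretrained=False, num_classes=1000)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```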
The following is a list of reinforcements for ImageNet/CIFAR-100/Food-101/Flowers-102. We recommend ImageNet+-RA/RE based on the analysis in the paper.
Reinforce Data | Task ID | Size (GB) | Comments |
---|---|---|---|
ImageNet+-RRC | rdata | 33.4 | [NS=400] |
ImageNet+-+M* | rdata | 46.3 | [NS=400] |
ImageNet+-+RA/RE | rdata | 37.5 | [NS=400] |
ImageNet+-+M*+R* | rdata | 53.3 | [NS=400] |
ImageNet+-RRC-Small | rdata | 4.7 | [NS=100, K=5] |
ImageNet+-+M*-Small | rdata | 7.8 | [NS=100, K=5] |
ImageNet+-+RA/RE-Small | rdata | 5.6 | [NS=100, K=5] |
ImageNet+-+M*+R*-Small | rdata | 9.4 | [NS=100, K=5] |
ImageNet+-RRC-Mini | rdata | 4.4 | [NS=50] |
ImageNet+-+M*-Mini | rdata | 6.1 | [NS=50] |
ImageNet+-+RA/RE-Mini | rdata | 4.9 | [NS=50] |
ImageNet+-+M*+R*-Mini | rdata | 7.0 | [NS=50] |
CIFAR-100 | rdata | 2.5 | [NS=800] |
Food-101 | rdata | 4.2 | [NS=800] |
Flowers-102 | rdata | 0.5 | [NS=8000] |
We provide pretrained checkpoints for various models in CVNets. The accuracies can be verified using the CVNets library.
- 150 Epochs Checkpoints
- 300 Epochs Checkpoints
- 1000 Epochs Checkpoints
- imagenet-cvnets.tar: All CVNets checkpoints trained on ImageNet (14.3 GB).
- imagenet-plus-cvnets.tar: All CVNets checkpoints trained on ImageNet+ (14.3 GB).
Selected results for models trained for 1000 epochs (ImageNet validation top-1 accuracy, %):
Name | Variant | Params | ImageNet | ImageNet+ | ImageNet (EMA) | ImageNet+ (EMA) | Links |
---|---|---|---|---|---|---|---|
MobileNetV3 | large | 5.5M | 74.8 | 77.9 (+3.1) | 75.8 | 77.9 (+2.1) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
ResNet | 50 | 25.6M | 80.0 | 82.0 (+2.0) | 80.1 | 82.0 (+1.9) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
ViT | base | 86.7M | 76.8 | 85.1 (+8.3) | 80.8 | 85.1 (+4.3) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
ViT-384 | base | 86.7M | 79.4 | 85.4 (+6.0) | 83.1 | 85.5 (+2.4) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
Swin | tiny | 28.3M | 81.3 | 84.0 (+2.7) | 80.5 | 83.5 (+3.0) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
Swin | small | 49.7M | 81.3 | 85.0 (+3.7) | 81.9 | 84.5 (+2.6) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
Swin | base | 87.8M | 81.5 | 85.4 (+3.9) | 81.8 | 85.2 (+3.4) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
Swin-384 | base | 87.8M | 83.6 | 85.8 (+2.2) | 83.8 | 85.5 (+1.7) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
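To spot-check a checkpoint's reported number independently of the training framework, a generic top-1 evaluation loop such as the hedged sketch below can be used once the model is built and its weights are loaded. The validation path, batch size, and 224-pixel preprocessing are assumptions; higher-resolution models (e.g., the 384 variants) and EMA weights need the matching preprocessing and checkpoint.

```python
# A framework-agnostic top-1 evaluation sketch. It assumes `model` is an
# already-built torch.nn.Module with checkpoint weights loaded, and that the
# ImageNet validation set is laid out for torchvision's ImageFolder.
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def evaluate_top1(model, val_dir="/path/to/imagenet/val", device="cuda"):
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    loader = DataLoader(datasets.ImageFolder(val_dir, preprocess),
                        batch_size=256, num_workers=8, shuffle=False)

    model.to(device).eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            preds = model(images.to(device)).argmax(dim=1).cpu()
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return 100.0 * correct / total
```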
We provide pretrained checkpoints for ResNet50d from the Timm library, trained for 150 epochs using various reinforced datasets:
- imagenet-timm.tar: All Timm checkpoints trained on ImageNet and ImageNet+ (2.3 GB).
Model | Reinforce Data | Accuracy (top-1, %) | Links |
---|---|---|---|
ResNet50d [ERM] | N/A | 78.9 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-RRC | 80.0 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+M* | 80.5 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+RA/RE | 80.4 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+M*+R* | 80.2 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-RRC-Small | 80.0 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+M*-Small | 80.6 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+RA/RE-Small | 80.2 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+M*+R*-Small | 80.1 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-RRC-Mini | 80.1 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+M*-Mini | 80.5 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+RA/RE-Mini | 80.4 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+M*+R*-Mini | 80.2 | [best.pt] [config.yaml] [metrics.jb] |
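The sketch below shows one plausible way to load a provided ResNet50d checkpoint into a Timm model before running an evaluation like the one above. The internal structure of best.pt (a raw state dict versus a wrapper dict, possible "module." prefixes from distributed training) is an assumption; inspect the loaded object and adapt accordingly.

```python
# Hedged sketch: load one of the provided ResNet50d checkpoints into a Timm
# model. The checkpoint layout is assumed, not documented here.
import timm
import torch

model = timm.create_model("resnet50d", pretrained=False, num_classes=1000)

ckpt = torch.load("best.pt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)  # handle either a wrapper dict or a raw state dict
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}

missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)
```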
We provide YAML configurations for training ResNet-50 in `CFG_FILE=configs/${DATASET}/${TRAINER}.yaml`, with the following options:

- `DATASET`: `imagenet`, `cifar100`, `flowers102`, and `food101`.
- `TRAINER`: standard training (`erm`), knowledge distillation (`kd`), and with reinforced data (`plus`).
Follow the steps:

- Choose the dataset and trainer from the choices above.
- Download ImageNet data and set `data_path` in `$CFG_FILE`.
- Download reinforcement metadata and set `reinforce.data_path` in `$CFG_FILE`.
```bash
python train.py --config configs/imagenet/erm.yaml   # ImageNet training without Reinforcements (ERM)
python train.py --config configs/imagenet/kd.yaml    # Knowledge Distillation
python train.py --config configs/imagenet/plus.yaml  # ImageNet+ training with reinforcements
```
Hyperparameters such as batch size for ImageNet training are optimized for running on a single node with 8xA100 40GB GPUs. For CIFAR-100/Flowers-102/Food-101, the configurations are optimized for training on a single GPU.
Follow the steps:

- Download ImageNet data and set `data_path` in `$CFG_FILE`.
- If needed, change the teacher in `$CFG_FILE` to a smaller architecture.
```bash
python reinforce.py --config configs/imagenet/reinforce/randaug.yaml
```
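Conceptually, reinforcement generation runs a strong teacher over many augmented views of every training image and stores the augmentation parameters together with the teacher's top-k probabilities. The sketch below illustrates that loop under simplifying assumptions (a single RandomResizedCrop augmentation, an example Timm teacher, preprocessing omitted); it is not the actual implementation in reinforce.py.

```python
# Simplified sketch of reinforcement generation: for each image, sample several
# augmentations, record their parameters, run the teacher once per view, and
# keep only the top-k probabilities. Mirrors the idea behind reinforce.py but
# is not its actual implementation.
import timm
import torch
import torch.nn.functional as F
from torchvision.transforms import RandomResizedCrop, functional as TF

@torch.no_grad()
def reinforce_image(teacher, image, num_samples=400, k=10, size=224):
    """image: a (3, H, W) float tensor, already preprocessed as the teacher expects.

    Returns a list of (augmentation parameters, top-k indices, top-k probabilities).
    """
    records = []
    for _ in range(num_samples):
        # Sample a random-resized-crop and remember its parameters so the exact
        # same view can be reproduced cheaply at training time.
        i, j, h, w = RandomResizedCrop.get_params(
            image, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3))
        view = TF.resized_crop(image, i, j, h, w, [size, size])

        probs = F.softmax(teacher(view.unsqueeze(0)), dim=1).squeeze(0)
        top_p, top_i = probs.topk(k)
        records.append({"crop": (i, j, h, w), "topk_idx": top_i, "topk_prob": top_p})
    return records

# Example teacher (illustrative only; the paper uses stronger teachers/ensembles):
teacher = timm.create_model("resnet50d", pretrained=True).eval()
```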
If you found this code useful, please cite the following paper:
```bibtex
@InProceedings{faghri2023reinforce,
  author    = {Faghri, Fartash and Pouransari, Hadi and Mehta, Sachin and Farajtabar, Mehrdad and Farhadi, Ali and Rastegari, Mohammad and Tuzel, Oncel},
  title     = {Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with Dataset Reinforcement},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  month     = {October},
  year      = {2023},
}
```
This sample code is released under the LICENSE terms.