A light-weight implementation of ICCV2023 paper "Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with Dataset Reinforcement."


apple/ml-dr

Dataset Reinforcement

A light-weight implementation of Dataset Reinforcement, pretrained checkpoints, and reinforced datasets.

Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with Dataset Reinforcement. Faghri, F., Pouransari, H., Mehta, S., Farajtabar, M., Farhadi, A., Rastegari, M., & Tuzel, O., Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2023.

Update 2023/09/22: Table 7-Average column corrected in ArXiv V3. Correct numbers: 30.4, 37.1, 37.9, 43.7, 39.6, 51.1.

Reinforced ImageNet, ImageNet+, improves accuracy at similar iterations/wall-clock.

ImageNet validation accuracy of ResNet-50 is shown as a function of training duration with (1) the ImageNet dataset, (2) knowledge distillation (KD), and (3) the ImageNet+ dataset (ours). Each point is a full training run, with the number of epochs varying from 50 to 1000. An epoch has the same number of iterations for ImageNet and ImageNet+.

Illustration of Dataset Reinforcement.

Data augmentation and knowledge distillation are common approaches to improving accuracy. Dataset reinforcement combines the benefits of both by bringing the advantages of large models trained on large datasets to other datasets and models. Training a new model on a reinforced dataset is as fast as training on the original dataset for the same total number of iterations. Creating a reinforced dataset is a one-time process (e.g., ImageNet to ImageNet+), the cost of which is amortized over repeated uses.
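As a rough illustration of the idea above, a student model can be trained against stored teacher probabilities instead of one-hot labels. The sketch below is a generic soft-target cross-entropy in plain Python; the function names are hypothetical and this is not taken from the repository's training code, which may differ.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def soft_target_cross_entropy(student_logits, teacher_probs):
    """Cross-entropy of student predictions against (stored) teacher probabilities."""
    student_probs = softmax(student_logits)
    return -sum(
        t * math.log(s)
        for t, s in zip(teacher_probs, student_probs)
        if t > 0.0  # zero-probability teacher entries contribute nothing
    )

# Hypothetical teacher output for a 3-class problem.
teacher_probs = [0.7, 0.2, 0.1]

# A student with uniform predictions versus one that matches the teacher.
loss_far = soft_target_cross_entropy([0.0, 0.0, 0.0], teacher_probs)
loss_near = soft_target_cross_entropy(
    [math.log(p) for p in teacher_probs], teacher_probs
)
```

The loss is lowest when the student reproduces the teacher distribution, which is why `loss_near` is smaller than `loss_far`. With a reinforced dataset the teacher probabilities are read from disk, so no teacher forward pass is needed during training.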

Requirements

Install the requirements using:

   pip install -r requirements.txt

We support loading models from the Timm and CVNets libraries.

To install the CVNets library, follow its installation instructions.

Reinforced Data

The following is a list of reinforcements for ImageNet/CIFAR-100/Food-101/Flowers-102. We recommend ImageNet+-RA/RE based on the analysis in the paper.

| Reinforce Data | Task ID | Size (GBs) | Comments |
|---|---|---|---|
| ImageNet+-RRC | rdata | 33.4 | [NS=400] |
| ImageNet+-+M* | rdata | 46.3 | [NS=400] |
| ImageNet+-+RA/RE | rdata | 37.5 | [NS=400] |
| ImageNet+-+M*+R* | rdata | 53.3 | [NS=400] |
| ImageNet+-RRC-Small | rdata | 4.7 | [NS=100, K=5] |
| ImageNet+-+M*-Small | rdata | 7.8 | [NS=100, K=5] |
| ImageNet+-+RA/RE-Small | rdata | 5.6 | [NS=100, K=5] |
| ImageNet+-+M*+R*-Small | rdata | 9.4 | [NS=100, K=5] |
| ImageNet+-RRC-Mini | rdata | 4.4 | [NS=50] |
| ImageNet+-+M*-Mini | rdata | 6.1 | [NS=50] |
| ImageNet+-+RA/RE-Mini | rdata | 4.9 | [NS=50] |
| ImageNet+-+M*+R*-Mini | rdata | 7.0 | [NS=50] |
| CIFAR-100 | rdata | 2.5 | [NS=800] |
| Food-101 | rdata | 4.2 | [NS=800] |
| Flowers-102 | rdata | 0.5 | [NS=8000] |
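The Comments column is not explained here. Assuming NS is the number of stored samples per image and K is the number of top teacher probabilities kept per stored sample (a reading based on the paper, not on this table), the sketch below shows how a probability vector could be reduced to its top-K entries and expanded again. Spreading the dropped mass uniformly over the remaining classes is an illustrative choice, not necessarily what the repository does, and both function names are hypothetical.

```python
def sparsify_top_k(probs, k):
    """Keep the k largest probabilities as (class_index, probability) pairs."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]

def densify(sparse, num_classes):
    """Rebuild a dense target; the dropped mass is spread uniformly (assumption)."""
    kept_idx = {i for i, _ in sparse}
    kept_mass = sum(p for _, p in sparse)
    rest = num_classes - len(kept_idx)
    fill = (1.0 - kept_mass) / rest if rest > 0 else 0.0
    dense = [fill] * num_classes
    for i, p in sparse:
        dense[i] = p
    return dense

# Hypothetical teacher output over 5 classes, stored with K=2.
probs = [0.5, 0.05, 0.3, 0.1, 0.05]
sparse = sparsify_top_k(probs, k=2)
dense = densify(sparse, num_classes=len(probs))
```

Storing only K index/probability pairs per sample, rather than one probability per class, is one way a smaller variant could reduce file size.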

Pretrained Checkpoints

CVNets Checkpoints

We provide pretrained checkpoints for various models in CVNets. The accuracies can be verified using the CVNets library.

Selected results trained for 1000 epochs:

| Name | Mode | Params | ImageNet | ImageNet+ | ImageNet (EMA) | ImageNet+ (EMA) | Links |
|---|---|---|---|---|---|---|---|
| MobileNetV3 | large | 5.5M | 74.8 | 77.9 (+3.1) | 75.8 | 77.9 (+2.1) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
| ResNet | 50 | 25.6M | 80.0 | 82.0 (+2.0) | 80.1 | 82.0 (+1.9) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
| ViT | base | 86.7M | 76.8 | 85.1 (+8.3) | 80.8 | 85.1 (+4.3) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
| ViT-384 | base | 86.7M | 79.4 | 85.4 (+6.0) | 83.1 | 85.5 (+2.4) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
| Swin | tiny | 28.3M | 81.3 | 84.0 (+2.7) | 80.5 | 83.5 (+3.0) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
| Swin | small | 49.7M | 81.3 | 85.0 (+3.7) | 81.9 | 84.5 (+2.6) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
| Swin | base | 87.8M | 81.5 | 85.4 (+3.9) | 81.8 | 85.2 (+3.4) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
| Swin-384 | base | 87.8M | 83.6 | 85.8 (+2.2) | 83.8 | 85.5 (+1.7) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
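The EMA columns refer to checkpoints whose weights are an exponential moving average of the training weights. A minimal sketch of such an update follows; the function name and the decay values are hypothetical, and the CVNets implementation may differ in its details.

```python
def ema_update(ema_weights, weights, decay=0.999):
    """One EMA step: move each averaged weight slightly toward the current weight."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]

# Toy example with a single weight and a deliberately small decay.
ema = [0.0]
for _ in range(3):
    ema = ema_update(ema, [1.0], decay=0.5)
```

After each training step the averaged copy moves a small fraction of the way toward the current weights, and that averaged copy is what gets evaluated and saved as the EMA checkpoint.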

Timm Checkpoints

We provide pretrained checkpoints for ResNet50d from the Timm library, trained for 150 epochs using various reinforced datasets:

  • imagenet-timm.tar: All Timm checkpoints trained on ImageNet and ImageNet+ (2.3GBs).

| Model | Reinforce Data | Accuracy | Links |
|---|---|---|---|
| ResNet50d [ERM] | N/A | 78.9 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-RRC | 80.0 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+M* | 80.5 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+RA/RE | 80.4 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+M*+R* | 80.2 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-RRC-Small | 80.0 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+M*-Small | 80.6 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+RA/RE-Small | 80.2 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+M*+R*-Small | 80.1 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-RRC-Mini | 80.1 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+M*-Mini | 80.5 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+RA/RE-Mini | 80.4 | [best.pt] [config.yaml] [metrics.jb] |
| ResNet50d | ImageNet+-+M*+R*-Mini | 80.2 | [best.pt] [config.yaml] [metrics.jb] |

Training

We provide YAML configurations for training ResNet-50 in CFG_FILE=configs/${DATASET}/${TRAINER}.yaml, with the following options:

  • DATASET: imagenet, cifar100, flowers102, and food101.
  • TRAINER: standard training (erm), knowledge distillation (kd), and with reinforced data (plus).
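The options above map onto a config path as follows. This helper is a hypothetical convenience, not part of the repository; it only mirrors the `configs/${DATASET}/${TRAINER}.yaml` pattern and the `train.py --config` invocation described in this README, and it does not check that a given combination exists on disk.

```python
from pathlib import PurePosixPath

# Option names as listed in this README.
DATASETS = ("imagenet", "cifar100", "flowers102", "food101")
TRAINERS = ("erm", "kd", "plus")

def config_path(dataset, trainer):
    """Build configs/<dataset>/<trainer>.yaml after validating both options."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset {dataset!r}; expected one of {DATASETS}")
    if trainer not in TRAINERS:
        raise ValueError(f"unknown trainer {trainer!r}; expected one of {TRAINERS}")
    return str(PurePosixPath("configs") / dataset / f"{trainer}.yaml")

def train_command(dataset, trainer):
    """Argument list for the training script, suitable for subprocess.run."""
    return ["python", "train.py", "--config", config_path(dataset, trainer)]

cmd = train_command("cifar100", "plus")
```

Here `cmd` holds the argument list for training on CIFAR-100 with reinforced data.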

Follow the steps:

  • Choose the dataset and trainer from the choices above.
  • Download ImageNet data and set data_path in $CFG_FILE.
  • Download reinforcement metadata and set reinforce.data_path in $CFG_FILE.

   python train.py --config configs/imagenet/erm.yaml  # ImageNet training without Reinforcements (ERM)
   python train.py --config configs/imagenet/kd.yaml  # Knowledge Distillation
   python train.py --config configs/imagenet/plus.yaml  # ImageNet+ training with reinforcements
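The steps above set two keys in the config file, `data_path` and `reinforce.data_path`. Assuming the dotted name denotes a key nested under a `reinforce` section (the YAML layout is not shown here), the sketch below sets such keys on an already-loaded config dictionary. The helper name and the paths are placeholders.

```python
def set_dotted(config, dotted_key, value):
    """Set config["a"]["b"] = value for dotted_key "a.b", creating sections as needed."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for name in parents:
        node = node.setdefault(name, {})
    node[leaf] = value
    return config

# Placeholder paths; replace them with the actual download locations.
cfg = {}
set_dotted(cfg, "data_path", "/path/to/imagenet")
set_dotted(cfg, "reinforce.data_path", "/path/to/rdata")
```

Editing the YAML file by hand achieves the same result; the sketch only makes the nesting explicit.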

Hyperparameters such as batch size for ImageNet training are optimized for running on a single node with 8xA100 40GB GPUs. For CIFAR-100/Flowers-102/Food-101, the configurations are optimized for training on a single GPU.

Reinforce ImageNet

Follow the steps:

  • Download ImageNet data and set data_path in $CFG_FILE.
  • If needed, change the teacher in $CFG_FILE to a smaller architecture.

   python reinforce.py --config configs/imagenet/reinforce/randaug.yaml

Reference

If you found this code useful, please cite the following paper:

@InProceedings{faghri2023reinforce,
    author    = {Faghri, Fartash and Pouransari, Hadi and Mehta, Sachin and Farajtabar, Mehrdad and Farhadi, Ali and Rastegari, Mohammad and Tuzel, Oncel},
    title     = {Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with Dataset Reinforcement},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2023},
}

License

This sample code is released under the LICENSE terms.
