Skip to content

Commit

Permalink
BNN-PYNQ training scripts and models
Browse files Browse the repository at this point in the history
  • Loading branch information
volcacius committed Apr 9, 2020
1 parent 2425a9e commit 5762115
Show file tree
Hide file tree
Showing 27 changed files with 1,487 additions and 0 deletions.
46 changes: 46 additions & 0 deletions brevitas_examples/bnn_pynq/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# BNN-PYNQ Brevitas experiments

This repo contains training scripts and pretrained models to recreate the LFC and CNV models
used in the [BNN-PYNQ](https://github.com/Xilinx/BNN-PYNQ) repo using [Brevitas](https://github.com/Xilinx/brevitas).
These pretrained models and training scripts are courtesy of
[Alessandro Pappalardo](https://github.com/volcacius) and [Ussama Zahid](https://github.com/ussamazahid96).

## Experiments

| Name | Input quantization | Weight quantization | Activation quantization | Brevitas Top1 | Theano Top1 |
|----------|------------------------------|---------------------|-------------------------|---------------|---------------|
| TFC_1W1A | 1 bit | 1 bit | 1 bit | 93.17% | |
| TFC_1W2A | 2 bit | 1 bit | 2 bit | 94.79% | |
| TFC_2W2A | 2 bit | 2 bit | 2 bit | 96.60% | |
| SFC_1W1A | 1 bit | 1 bit | 1 bit | 97.81% | |
| SFC_1W2A | 2 bit | 1 bit | 2 bit | 98.31% | |
| SFC_2W2A | 2 bit | 2 bit | 2 bit | 98.66% | |
| LFC_1W1A | 1 bit | 1 bit | 1 bit | 98.88% | 98.35% |
| LFC_1W2A | 2 bit | 1 bit | 2 bit | 98.99% | 98.55% |
| CNV_1W1A | 8 bit | 1 bit | 1 bit | 84.22% | 79.54% |
| CNV_1W2A | 8 bit | 1 bit | 2 bit | 87.80% | 83.63% |
| CNV_2W2A | 8 bit | 2 bit | 2 bit | 89.03% | 84.80% |

## Train

A few notes on training:
- An experiments folder at */path/to/experiments* must exist before launching the training.
- Training is set to 1000 epochs for 1W1A networks, 500 otherwise.
- Force-enabling the Pytorch JIT with the env flag PYTORCH_JIT=1 significantly speeds up training.

To start training a model from scratch, e.g. LFC_1W1A, run:
```bash
PYTORCH_JIT=1 brevitas_bnn_pynq_train --network LFC_1W1A --experiments /path/to/experiments
```

## Evaluate

To evaluate a pretrained model, e.g. LFC_1W1A, run:
```bash
PYTORCH_JIT=1 brevitas_bnn_pynq_train --evaluate --network LFC_1W1A --pretrained
```

To evaluate your own checkpoint, of e.g. LFC_1W1A, run:
```bash
PYTORCH_JIT=1 brevitas_bnn_pynq_train --evaluate --network LFC_1W1A --resume /path/to/checkpoint.tar
```
1 change: 1 addition & 0 deletions brevitas_examples/bnn_pynq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .models import *
135 changes: 135 additions & 0 deletions brevitas_examples/bnn_pynq/bnn_pynq_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# MIT License
#
# Copyright (c) 2019 Xilinx
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import argparse
import os

import torch
from .trainer import Trainer


# Util method to add mutually exclusive boolean
def add_bool_arg(parser, name, default):
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--" + name, dest=name, action="store_true")
group.add_argument("--no_" + name, dest=name, action="store_false")
parser.set_defaults(**{name: default})


# Util method to pass None as a string and be recognized as None value
def none_or_str(value):
if value == "None":
return None
return value


def none_or_int(value):
if value == "None":
return None
return int(value)


# I/O
parser = argparse.ArgumentParser(description="PyTorch MNIST/CIFAR10 Training")
parser.add_argument("--datadir", default="./data/", help="Dataset location")
parser.add_argument("--experiments", default="./experiments", help="Path to experiments folder")
parser.add_argument("--dry_run", action="store_true", help="Disable output files generation")
parser.add_argument("--log_freq", type=int, default=10)

# Execution modes
parser.add_argument("--evaluate", dest="evaluate", action="store_true", help="evaluate model on validation set")
parser.add_argument("--resume", dest="resume", type=none_or_str,
help="Resume from checkpoint. Overrides --pretrained flag.")
add_bool_arg(parser, "detect_nan", default=False)

# Compute resources
parser.add_argument("--num_workers", default=4, type=int, help="Number of workers")
parser.add_argument("--gpus", type=none_or_str, default="0", help="Comma separated GPUs")

# Optimizer hyperparams
parser.add_argument("--batch_size", default=100, type=int, help="batch size")
parser.add_argument("--lr", default=0.02, type=float, help="Learning rate")
parser.add_argument("--optim", type=none_or_str, default="ADAM",help="Optimizer to use")
parser.add_argument("--loss", type=none_or_str, default="SqrHinge",help="Loss function to use")
parser.add_argument("--scheduler", default="FIXED", type=none_or_str, help="LR Scheduler")
parser.add_argument("--milestones", type=none_or_str, default='100,150,200,250', help="Scheduler milestones")
parser.add_argument("--momentum", default=0.9, type=float, help="Momentum")
parser.add_argument("--weight_decay", default=0, type=float, help="Weight decay")
parser.add_argument("--epochs", default=1000, type=int, help="Number of epochs")
parser.add_argument("--random_seed", default=1, type=int, help="Random seed")

# Neural network Architecture
parser.add_argument("--network", default="LFC_1W1A", type=str, help="neural network")
parser.add_argument("--pretrained", action='store_true', help="Load pretrained model")

# Pytorch precision
torch.set_printoptions(precision=10)


class objdict(dict):
def __getattr__(self, name):
if name in self:
return self[name]
else:
raise AttributeError("No such attribute: " + name)

def __setattr__(self, name, value):
self[name] = value

def __delattr__(self, name):
if name in self:
del self[name]
else:
raise AttributeError("No such attribute: " + name)


def main():
args = parser.parse_args()

# Set relative paths relative to current workdir
path_args = ["datadir", "experiments", "resume"]
for path_arg in path_args:
path = getattr(args, path_arg)
if path is not None and not os.path.isabs(path):
abs_path = os.path.abspath(os.path.join(os.getcwd(), path))
setattr(args, path_arg, abs_path)

# Access config as an object
args = objdict(args.__dict__)

# Avoid creating new folders etc.
if args.evaluate:
args.dry_run = True

# Init trainer
trainer = Trainer(args)

# Execute
if args.evaluate:
with torch.no_grad():
trainer.eval_model()
else:
trainer.train_model()


if __name__ == "__main__":
main()
Empty file.
12 changes: 12 additions & 0 deletions brevitas_examples/bnn_pynq/cfg/cnv_1w1a.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[MODEL]
ARCH: CNV
PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/cnv_1w1a-758c8fef.pth
DATASET: CIFAR10
IN_CHANNELS: 3
NUM_CLASSES: 10

[QUANT]
WEIGHT_BIT_WIDTH: 1
ACT_BIT_WIDTH: 1
IN_BIT_WIDTH: 8

12 changes: 12 additions & 0 deletions brevitas_examples/bnn_pynq/cfg/cnv_1w2a.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[MODEL]
ARCH: CNV
PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/cnv_1w2a-23b6e2e4.pth
DATASET: CIFAR10
IN_CHANNELS: 3
NUM_CLASSES: 10

[QUANT]
WEIGHT_BIT_WIDTH: 1
ACT_BIT_WIDTH: 2
IN_BIT_WIDTH: 8

12 changes: 12 additions & 0 deletions brevitas_examples/bnn_pynq/cfg/cnv_2w2a.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[MODEL]
ARCH: CNV
PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/cnv_2w2a-0702987f.pth
DATASET: CIFAR10
IN_CHANNELS: 3
NUM_CLASSES: 10

[QUANT]
WEIGHT_BIT_WIDTH: 2
ACT_BIT_WIDTH: 2
IN_BIT_WIDTH: 8

12 changes: 12 additions & 0 deletions brevitas_examples/bnn_pynq/cfg/lfc_1w1a.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[MODEL]
ARCH: LFC
PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/lfc_1w1a-db6e13bd.pth
DATASET: MNIST
IN_CHANNELS: 1
NUM_CLASSES: 10

[QUANT]
WEIGHT_BIT_WIDTH: 1
ACT_BIT_WIDTH: 1
IN_BIT_WIDTH: 1

12 changes: 12 additions & 0 deletions brevitas_examples/bnn_pynq/cfg/lfc_1w2a.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[MODEL]
ARCH: LFC
PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/lfc_1w2a-0a771c67.pth
DATASET: MNIST
IN_CHANNELS: 1
NUM_CLASSES: 10

[QUANT]
WEIGHT_BIT_WIDTH: 1
ACT_BIT_WIDTH: 2
IN_BIT_WIDTH: 2

12 changes: 12 additions & 0 deletions brevitas_examples/bnn_pynq/cfg/sfc_1w1a.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[MODEL]
ARCH: SFC
PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_1w1a-fd8a6c3d.pth
DATASET: MNIST
IN_CHANNELS: 1
NUM_CLASSES: 10

[QUANT]
WEIGHT_BIT_WIDTH: 1
ACT_BIT_WIDTH: 1
IN_BIT_WIDTH: 1

12 changes: 12 additions & 0 deletions brevitas_examples/bnn_pynq/cfg/sfc_1w2a.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[MODEL]
ARCH: SFC
PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_1w2a-fdc0c779.pth
DATASET: MNIST
IN_CHANNELS: 1
NUM_CLASSES: 10

[QUANT]
WEIGHT_BIT_WIDTH: 1
ACT_BIT_WIDTH: 2
IN_BIT_WIDTH: 2

12 changes: 12 additions & 0 deletions brevitas_examples/bnn_pynq/cfg/sfc_2w2a.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[MODEL]
ARCH: SFC
PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_2w2a-35a7c41d.pth
DATASET: MNIST
IN_CHANNELS: 1
NUM_CLASSES: 10

[QUANT]
WEIGHT_BIT_WIDTH: 2
ACT_BIT_WIDTH: 2
IN_BIT_WIDTH: 2

12 changes: 12 additions & 0 deletions brevitas_examples/bnn_pynq/cfg/tfc_1w1a.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[MODEL]
ARCH: TFC
PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_1w1a-ff8140dc.pth
DATASET: MNIST
IN_CHANNELS: 1
NUM_CLASSES: 10

[QUANT]
WEIGHT_BIT_WIDTH: 1
ACT_BIT_WIDTH: 1
IN_BIT_WIDTH: 1

12 changes: 12 additions & 0 deletions brevitas_examples/bnn_pynq/cfg/tfc_1w2a.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[MODEL]
ARCH: TFC
PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_1w2a-95ad635b.pth
DATASET: MNIST
IN_CHANNELS: 1
NUM_CLASSES: 10

[QUANT]
WEIGHT_BIT_WIDTH: 1
ACT_BIT_WIDTH: 2
IN_BIT_WIDTH: 2

12 changes: 12 additions & 0 deletions brevitas_examples/bnn_pynq/cfg/tfc_2w2a.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[MODEL]
ARCH: TFC
PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_2w2a-7e0a62f1.pth
DATASET: MNIST
IN_CHANNELS: 1
NUM_CLASSES: 10

[QUANT]
WEIGHT_BIT_WIDTH: 2
ACT_BIT_WIDTH: 2
IN_BIT_WIDTH: 2

Loading

0 comments on commit 5762115

Please sign in to comment.