From 5762115e43543d42c0191e279ede9d1c8d933300 Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo Date: Wed, 8 Apr 2020 21:53:29 +0100 Subject: [PATCH] BNN-PYNQ training scripts and models --- brevitas_examples/bnn_pynq/README.md | 46 +++ brevitas_examples/bnn_pynq/__init__.py | 1 + brevitas_examples/bnn_pynq/bnn_pynq_train.py | 135 +++++++ brevitas_examples/bnn_pynq/cfg/__init__.py | 0 brevitas_examples/bnn_pynq/cfg/cnv_1w1a.ini | 12 + brevitas_examples/bnn_pynq/cfg/cnv_1w2a.ini | 12 + brevitas_examples/bnn_pynq/cfg/cnv_2w2a.ini | 12 + brevitas_examples/bnn_pynq/cfg/lfc_1w1a.ini | 12 + brevitas_examples/bnn_pynq/cfg/lfc_1w2a.ini | 12 + brevitas_examples/bnn_pynq/cfg/sfc_1w1a.ini | 12 + brevitas_examples/bnn_pynq/cfg/sfc_1w2a.ini | 12 + brevitas_examples/bnn_pynq/cfg/sfc_2w2a.ini | 12 + brevitas_examples/bnn_pynq/cfg/tfc_1w1a.ini | 12 + brevitas_examples/bnn_pynq/cfg/tfc_1w2a.ini | 12 + brevitas_examples/bnn_pynq/cfg/tfc_2w2a.ini | 12 + brevitas_examples/bnn_pynq/logger.py | 115 ++++++ brevitas_examples/bnn_pynq/models/CNV.py | 128 +++++++ brevitas_examples/bnn_pynq/models/LFC.py | 97 +++++ brevitas_examples/bnn_pynq/models/SFC.py | 97 +++++ brevitas_examples/bnn_pynq/models/TFC.py | 97 +++++ brevitas_examples/bnn_pynq/models/__init__.py | 116 ++++++ brevitas_examples/bnn_pynq/models/common.py | 96 +++++ brevitas_examples/bnn_pynq/models/losses.py | 31 ++ .../bnn_pynq/models/tensor_norm.py | 56 +++ brevitas_examples/bnn_pynq/trainer.py | 331 ++++++++++++++++++ setup.py | 1 + test/brevitas_examples/test_import.py | 8 + 27 files changed, 1487 insertions(+) create mode 100644 brevitas_examples/bnn_pynq/README.md create mode 100644 brevitas_examples/bnn_pynq/__init__.py create mode 100644 brevitas_examples/bnn_pynq/bnn_pynq_train.py create mode 100644 brevitas_examples/bnn_pynq/cfg/__init__.py create mode 100644 brevitas_examples/bnn_pynq/cfg/cnv_1w1a.ini create mode 100644 brevitas_examples/bnn_pynq/cfg/cnv_1w2a.ini create mode 100644 brevitas_examples/bnn_pynq/cfg/cnv_2w2a.ini create mode 100644 brevitas_examples/bnn_pynq/cfg/lfc_1w1a.ini create mode 100644 brevitas_examples/bnn_pynq/cfg/lfc_1w2a.ini create mode 100644 brevitas_examples/bnn_pynq/cfg/sfc_1w1a.ini create mode 100644 brevitas_examples/bnn_pynq/cfg/sfc_1w2a.ini create mode 100644 brevitas_examples/bnn_pynq/cfg/sfc_2w2a.ini create mode 100644 brevitas_examples/bnn_pynq/cfg/tfc_1w1a.ini create mode 100644 brevitas_examples/bnn_pynq/cfg/tfc_1w2a.ini create mode 100644 brevitas_examples/bnn_pynq/cfg/tfc_2w2a.ini create mode 100644 brevitas_examples/bnn_pynq/logger.py create mode 100644 brevitas_examples/bnn_pynq/models/CNV.py create mode 100644 brevitas_examples/bnn_pynq/models/LFC.py create mode 100644 brevitas_examples/bnn_pynq/models/SFC.py create mode 100644 brevitas_examples/bnn_pynq/models/TFC.py create mode 100644 brevitas_examples/bnn_pynq/models/__init__.py create mode 100644 brevitas_examples/bnn_pynq/models/common.py create mode 100644 brevitas_examples/bnn_pynq/models/losses.py create mode 100644 brevitas_examples/bnn_pynq/models/tensor_norm.py create mode 100644 brevitas_examples/bnn_pynq/trainer.py diff --git a/brevitas_examples/bnn_pynq/README.md b/brevitas_examples/bnn_pynq/README.md new file mode 100644 index 000000000..ce1cc3685 --- /dev/null +++ b/brevitas_examples/bnn_pynq/README.md @@ -0,0 +1,46 @@ +# BNN-PYNQ Brevitas experiments + +This repo contains training scripts and pretrained models to recreate the LFC and CNV models +used in the [BNN-PYNQ](https://github.com/Xilinx/BNN-PYNQ) repo using 
[Brevitas](https://github.com/Xilinx/brevitas).
+These pretrained models and training scripts are courtesy of
+[Alessandro Pappalardo](https://github.com/volcacius) and [Ussama Zahid](https://github.com/ussamazahid96).
+
+## Experiments
+
+| Name     | Input quantization | Weight quantization | Activation quantization | Brevitas Top1 | Theano Top1 |
+|----------|--------------------|---------------------|-------------------------|---------------|-------------|
+| TFC_1W1A | 1 bit              | 1 bit               | 1 bit                   | 93.17%        |             |
+| TFC_1W2A | 2 bit              | 1 bit               | 2 bit                   | 94.79%        |             |
+| TFC_2W2A | 2 bit              | 2 bit               | 2 bit                   | 96.60%        |             |
+| SFC_1W1A | 1 bit              | 1 bit               | 1 bit                   | 97.81%        |             |
+| SFC_1W2A | 2 bit              | 1 bit               | 2 bit                   | 98.31%        |             |
+| SFC_2W2A | 2 bit              | 2 bit               | 2 bit                   | 98.66%        |             |
+| LFC_1W1A | 1 bit              | 1 bit               | 1 bit                   | 98.88%        | 98.35%      |
+| LFC_1W2A | 2 bit              | 1 bit               | 2 bit                   | 98.99%        | 98.55%      |
+| CNV_1W1A | 8 bit              | 1 bit               | 1 bit                   | 84.22%        | 79.54%      |
+| CNV_1W2A | 8 bit              | 1 bit               | 2 bit                   | 87.80%        | 83.63%      |
+| CNV_2W2A | 8 bit              | 2 bit               | 2 bit                   | 89.03%        | 84.80%      |
+
+## Train
+
+A few notes on training:
+- An experiments folder at */path/to/experiments* must exist before launching training.
+- Training runs for 1000 epochs on 1W1A networks, 500 otherwise.
+- Force-enabling the PyTorch JIT with the env flag PYTORCH_JIT=1 significantly speeds up training.
+
+To start training a model from scratch, e.g. LFC_1W1A, run:
+```bash
+PYTORCH_JIT=1 brevitas_bnn_pynq_train --network LFC_1W1A --experiments /path/to/experiments
+```
+
+## Evaluate
+
+To evaluate a pretrained model, e.g. LFC_1W1A, run:
+```bash
+PYTORCH_JIT=1 brevitas_bnn_pynq_train --evaluate --network LFC_1W1A --pretrained
+```
+
+To evaluate your own checkpoint of, e.g., LFC_1W1A, run:
+```bash
+PYTORCH_JIT=1 brevitas_bnn_pynq_train --evaluate --network LFC_1W1A --resume /path/to/checkpoint.tar
+```
\ No newline at end of file
diff --git a/brevitas_examples/bnn_pynq/__init__.py b/brevitas_examples/bnn_pynq/__init__.py
new file mode 100644
index 000000000..cf4f59d6c
--- /dev/null
+++ b/brevitas_examples/bnn_pynq/__init__.py
@@ -0,0 +1 @@
+from .models import *
\ No newline at end of file
diff --git a/brevitas_examples/bnn_pynq/bnn_pynq_train.py b/brevitas_examples/bnn_pynq/bnn_pynq_train.py
new file mode 100644
index 000000000..c2224a160
--- /dev/null
+++ b/brevitas_examples/bnn_pynq/bnn_pynq_train.py
@@ -0,0 +1,135 @@
+# MIT License
+#
+# Copyright (c) 2019 Xilinx
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
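+
+# Command-line entry point, installed by setup.py as the brevitas_bnn_pynq_train
+# console script. Example invocation from the README:
+#   PYTORCH_JIT=1 brevitas_bnn_pynq_train --network LFC_1W1A --experiments /path/to/experiments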
+
+import argparse
+import os
+
+import torch
+from .trainer import Trainer
+
+
+# Util method to add a mutually exclusive boolean flag (--name / --no_name)
+def add_bool_arg(parser, name, default):
+    group = parser.add_mutually_exclusive_group(required=False)
+    group.add_argument("--" + name, dest=name, action="store_true")
+    group.add_argument("--no_" + name, dest=name, action="store_false")
+    parser.set_defaults(**{name: default})
+
+
+# Util method so that the string "None" is recognized as the None value
+def none_or_str(value):
+    if value == "None":
+        return None
+    return value
+
+
+def none_or_int(value):
+    if value == "None":
+        return None
+    return int(value)
+
+
+# I/O
+parser = argparse.ArgumentParser(description="PyTorch MNIST/CIFAR10 Training")
+parser.add_argument("--datadir", default="./data/", help="Dataset location")
+parser.add_argument("--experiments", default="./experiments", help="Path to experiments folder")
+parser.add_argument("--dry_run", action="store_true", help="Disable output files generation")
+parser.add_argument("--log_freq", type=int, default=10)
+
+# Execution modes
+parser.add_argument("--evaluate", dest="evaluate", action="store_true", help="Evaluate model on validation set")
+parser.add_argument("--resume", dest="resume", type=none_or_str,
+                    help="Resume from checkpoint. Overrides --pretrained flag.")
+# Whether checkpoints are loaded with strict=True; referenced by trainer.py on --resume
+add_bool_arg(parser, "strict", default=True)
+add_bool_arg(parser, "detect_nan", default=False)
+
+# Compute resources
+parser.add_argument("--num_workers", default=4, type=int, help="Number of workers")
+parser.add_argument("--gpus", type=none_or_str, default="0", help="Comma separated GPUs")
+
+# Optimizer hyperparams
+parser.add_argument("--batch_size", default=100, type=int, help="Batch size")
+parser.add_argument("--lr", default=0.02, type=float, help="Learning rate")
+parser.add_argument("--optim", type=none_or_str, default="ADAM", help="Optimizer to use")
+parser.add_argument("--loss", type=none_or_str, default="SqrHinge", help="Loss function to use")
+parser.add_argument("--scheduler", default="FIXED", type=none_or_str, help="LR Scheduler")
+parser.add_argument("--milestones", type=none_or_str, default='100,150,200,250', help="Scheduler milestones")
+parser.add_argument("--momentum", default=0.9, type=float, help="Momentum")
+parser.add_argument("--weight_decay", default=0, type=float, help="Weight decay")
+parser.add_argument("--epochs", default=1000, type=int, help="Number of epochs")
+parser.add_argument("--random_seed", default=1, type=int, help="Random seed")
+
+# Neural network architecture
+parser.add_argument("--network", default="LFC_1W1A", type=str, help="Neural network")
+parser.add_argument("--pretrained", action='store_true', help="Load pretrained model")
+
+# PyTorch print precision
+torch.set_printoptions(precision=10)
+
+
+class objdict(dict):
+    def __getattr__(self, name):
+        if name in self:
+            return self[name]
+        else:
+            raise AttributeError("No such attribute: " + name)
+
+    def __setattr__(self, name, value):
+        self[name] = value
+
+    def __delattr__(self, name):
+        if name in self:
+            del self[name]
+        else:
+            raise AttributeError("No such attribute: " + name)
+
+
+def main():
+    args = parser.parse_args()
+
+    # Resolve relative paths against the current workdir
+    path_args = ["datadir", "experiments", "resume"]
+    for path_arg in path_args:
+        path = getattr(args, path_arg)
+        if path is not None and not os.path.isabs(path):
+            abs_path = os.path.abspath(os.path.join(os.getcwd(), path))
+            setattr(args, path_arg, abs_path)
+
+    # Access config as an object
+    args = objdict(args.__dict__)
+
+    # Avoid creating new folders etc.
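+    # (--evaluate implies --dry_run: no experiment folders, checkpoints or log files are written)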
+ if args.evaluate: + args.dry_run = True + + # Init trainer + trainer = Trainer(args) + + # Execute + if args.evaluate: + with torch.no_grad(): + trainer.eval_model() + else: + trainer.train_model() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/brevitas_examples/bnn_pynq/cfg/__init__.py b/brevitas_examples/bnn_pynq/cfg/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/brevitas_examples/bnn_pynq/cfg/cnv_1w1a.ini b/brevitas_examples/bnn_pynq/cfg/cnv_1w1a.ini new file mode 100644 index 000000000..ef9fdc1f3 --- /dev/null +++ b/brevitas_examples/bnn_pynq/cfg/cnv_1w1a.ini @@ -0,0 +1,12 @@ +[MODEL] +ARCH: CNV +PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/cnv_1w1a-758c8fef.pth +DATASET: CIFAR10 +IN_CHANNELS: 3 +NUM_CLASSES: 10 + +[QUANT] +WEIGHT_BIT_WIDTH: 1 +ACT_BIT_WIDTH: 1 +IN_BIT_WIDTH: 8 + diff --git a/brevitas_examples/bnn_pynq/cfg/cnv_1w2a.ini b/brevitas_examples/bnn_pynq/cfg/cnv_1w2a.ini new file mode 100644 index 000000000..4219915d1 --- /dev/null +++ b/brevitas_examples/bnn_pynq/cfg/cnv_1w2a.ini @@ -0,0 +1,12 @@ +[MODEL] +ARCH: CNV +PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/cnv_1w2a-23b6e2e4.pth +DATASET: CIFAR10 +IN_CHANNELS: 3 +NUM_CLASSES: 10 + +[QUANT] +WEIGHT_BIT_WIDTH: 1 +ACT_BIT_WIDTH: 2 +IN_BIT_WIDTH: 8 + diff --git a/brevitas_examples/bnn_pynq/cfg/cnv_2w2a.ini b/brevitas_examples/bnn_pynq/cfg/cnv_2w2a.ini new file mode 100644 index 000000000..c434a9897 --- /dev/null +++ b/brevitas_examples/bnn_pynq/cfg/cnv_2w2a.ini @@ -0,0 +1,12 @@ +[MODEL] +ARCH: CNV +PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/cnv_2w2a-0702987f.pth +DATASET: CIFAR10 +IN_CHANNELS: 3 +NUM_CLASSES: 10 + +[QUANT] +WEIGHT_BIT_WIDTH: 2 +ACT_BIT_WIDTH: 2 +IN_BIT_WIDTH: 8 + diff --git a/brevitas_examples/bnn_pynq/cfg/lfc_1w1a.ini b/brevitas_examples/bnn_pynq/cfg/lfc_1w1a.ini new file mode 100644 index 000000000..e8198aa46 --- /dev/null +++ b/brevitas_examples/bnn_pynq/cfg/lfc_1w1a.ini @@ -0,0 +1,12 @@ +[MODEL] +ARCH: LFC +PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/lfc_1w1a-db6e13bd.pth +DATASET: MNIST +IN_CHANNELS: 1 +NUM_CLASSES: 10 + +[QUANT] +WEIGHT_BIT_WIDTH: 1 +ACT_BIT_WIDTH: 1 +IN_BIT_WIDTH: 1 + diff --git a/brevitas_examples/bnn_pynq/cfg/lfc_1w2a.ini b/brevitas_examples/bnn_pynq/cfg/lfc_1w2a.ini new file mode 100644 index 000000000..31cabb670 --- /dev/null +++ b/brevitas_examples/bnn_pynq/cfg/lfc_1w2a.ini @@ -0,0 +1,12 @@ +[MODEL] +ARCH: LFC +PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/lfc_1w2a-0a771c67.pth +DATASET: MNIST +IN_CHANNELS: 1 +NUM_CLASSES: 10 + +[QUANT] +WEIGHT_BIT_WIDTH: 1 +ACT_BIT_WIDTH: 2 +IN_BIT_WIDTH: 2 + diff --git a/brevitas_examples/bnn_pynq/cfg/sfc_1w1a.ini b/brevitas_examples/bnn_pynq/cfg/sfc_1w1a.ini new file mode 100644 index 000000000..300c7c1da --- /dev/null +++ b/brevitas_examples/bnn_pynq/cfg/sfc_1w1a.ini @@ -0,0 +1,12 @@ +[MODEL] +ARCH: SFC +PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_1w1a-fd8a6c3d.pth +DATASET: MNIST +IN_CHANNELS: 1 +NUM_CLASSES: 10 + +[QUANT] +WEIGHT_BIT_WIDTH: 1 +ACT_BIT_WIDTH: 1 +IN_BIT_WIDTH: 1 + diff --git a/brevitas_examples/bnn_pynq/cfg/sfc_1w2a.ini b/brevitas_examples/bnn_pynq/cfg/sfc_1w2a.ini new file mode 100644 index 000000000..e0ef97be1 --- /dev/null +++ b/brevitas_examples/bnn_pynq/cfg/sfc_1w2a.ini @@ -0,0 +1,12 @@ +[MODEL] +ARCH: SFC +PRETRAINED_URL: 
https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_1w2a-fdc0c779.pth +DATASET: MNIST +IN_CHANNELS: 1 +NUM_CLASSES: 10 + +[QUANT] +WEIGHT_BIT_WIDTH: 1 +ACT_BIT_WIDTH: 2 +IN_BIT_WIDTH: 2 + diff --git a/brevitas_examples/bnn_pynq/cfg/sfc_2w2a.ini b/brevitas_examples/bnn_pynq/cfg/sfc_2w2a.ini new file mode 100644 index 000000000..9c31d82db --- /dev/null +++ b/brevitas_examples/bnn_pynq/cfg/sfc_2w2a.ini @@ -0,0 +1,12 @@ +[MODEL] +ARCH: SFC +PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_2w2a-35a7c41d.pth +DATASET: MNIST +IN_CHANNELS: 1 +NUM_CLASSES: 10 + +[QUANT] +WEIGHT_BIT_WIDTH: 2 +ACT_BIT_WIDTH: 2 +IN_BIT_WIDTH: 2 + diff --git a/brevitas_examples/bnn_pynq/cfg/tfc_1w1a.ini b/brevitas_examples/bnn_pynq/cfg/tfc_1w1a.ini new file mode 100644 index 000000000..8eb074aa4 --- /dev/null +++ b/brevitas_examples/bnn_pynq/cfg/tfc_1w1a.ini @@ -0,0 +1,12 @@ +[MODEL] +ARCH: TFC +PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_1w1a-ff8140dc.pth +DATASET: MNIST +IN_CHANNELS: 1 +NUM_CLASSES: 10 + +[QUANT] +WEIGHT_BIT_WIDTH: 1 +ACT_BIT_WIDTH: 1 +IN_BIT_WIDTH: 1 + diff --git a/brevitas_examples/bnn_pynq/cfg/tfc_1w2a.ini b/brevitas_examples/bnn_pynq/cfg/tfc_1w2a.ini new file mode 100644 index 000000000..1c755a572 --- /dev/null +++ b/brevitas_examples/bnn_pynq/cfg/tfc_1w2a.ini @@ -0,0 +1,12 @@ +[MODEL] +ARCH: TFC +PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_1w2a-95ad635b.pth +DATASET: MNIST +IN_CHANNELS: 1 +NUM_CLASSES: 10 + +[QUANT] +WEIGHT_BIT_WIDTH: 1 +ACT_BIT_WIDTH: 2 +IN_BIT_WIDTH: 2 + diff --git a/brevitas_examples/bnn_pynq/cfg/tfc_2w2a.ini b/brevitas_examples/bnn_pynq/cfg/tfc_2w2a.ini new file mode 100644 index 000000000..50fd788ff --- /dev/null +++ b/brevitas_examples/bnn_pynq/cfg/tfc_2w2a.ini @@ -0,0 +1,12 @@ +[MODEL] +ARCH: TFC +PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_2w2a-7e0a62f1.pth +DATASET: MNIST +IN_CHANNELS: 1 +NUM_CLASSES: 10 + +[QUANT] +WEIGHT_BIT_WIDTH: 2 +ACT_BIT_WIDTH: 2 +IN_BIT_WIDTH: 2 + diff --git a/brevitas_examples/bnn_pynq/logger.py b/brevitas_examples/bnn_pynq/logger.py new file mode 100644 index 000000000..4ee5f51d0 --- /dev/null +++ b/brevitas_examples/bnn_pynq/logger.py @@ -0,0 +1,115 @@ +# MIT License +# +# Copyright (c) 2019 Xilinx +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
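+
+# Logging helpers for the BNN-PYNQ trainer: AverageMeter tracks running
+# value/sum/count/average stats, TrainingEpochMeters and EvalEpochMeters bundle
+# the per-epoch meters, and Logger mirrors messages to stdout and, unless
+# --dry_run is set, to log.txt in the experiment folder.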
+
+import logging
+import sys
+import os
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+class TrainingEpochMeters(object):
+    def __init__(self):
+        self.batch_time = AverageMeter()
+        self.data_time = AverageMeter()
+        self.losses = AverageMeter()
+        self.top1 = AverageMeter()
+        self.top5 = AverageMeter()
+
+
+class EvalEpochMeters(object):
+    def __init__(self):
+        self.model_time = AverageMeter()
+        self.loss_time = AverageMeter()
+        self.losses = AverageMeter()
+        self.top1 = AverageMeter()
+        self.top5 = AverageMeter()
+
+
+class Logger(object):
+
+    def __init__(self, output_dir_path, dry_run):
+        self.output_dir_path = output_dir_path
+        self.log = logging.getLogger('log')
+        self.log.setLevel(logging.INFO)
+
+        # Stdout logging
+        out_hdlr = logging.StreamHandler(sys.stdout)
+        out_hdlr.setFormatter(logging.Formatter('%(asctime)s %(message)s'))
+        out_hdlr.setLevel(logging.INFO)
+        self.log.addHandler(out_hdlr)
+
+        # Text file logging
+        if not dry_run:
+            file_hdlr = logging.FileHandler(os.path.join(self.output_dir_path, 'log.txt'))
+            file_hdlr.setFormatter(logging.Formatter('%(asctime)s %(message)s'))
+            file_hdlr.setLevel(logging.INFO)
+            self.log.addHandler(file_hdlr)
+        self.log.propagate = False
+
+    def info(self, arg):
+        self.log.info(arg)
+
+    def eval_batch_cli_log(self, epoch_meters, batch, tot_batches):
+        self.info('Test: [{0}/{1}]\t'
+                  'Model Time {model_time.val:.3f} ({model_time.avg:.3f})\t'
+                  'Loss Time {loss_time.val:.3f} ({loss_time.avg:.3f})\t'
+                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
+                  .format(batch, tot_batches,
+                          model_time=epoch_meters.model_time,
+                          loss_time=epoch_meters.loss_time,
+                          loss=epoch_meters.losses,
+                          top1=epoch_meters.top1,
+                          top5=epoch_meters.top5))
+
+    def training_batch_cli_log(self, epoch_meters, epoch, batch, tot_batches):
+        self.info('Epoch: [{0}][{1}/{2}]\t'
+                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
+                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
+                  .format(epoch, batch, tot_batches,
+                          batch_time=epoch_meters.batch_time,
+                          data_time=epoch_meters.data_time,
+                          loss=epoch_meters.losses,
+                          top1=epoch_meters.top1,
+                          top5=epoch_meters.top5))
diff --git a/brevitas_examples/bnn_pynq/models/CNV.py b/brevitas_examples/bnn_pynq/models/CNV.py
new file mode 100644
index 000000000..0c14264c2
--- /dev/null
+++ b/brevitas_examples/bnn_pynq/models/CNV.py
@@ -0,0 +1,128 @@
+# MIT License
+#
+# Copyright (c) 2019 Xilinx
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch +from torch.nn import Module, ModuleList, BatchNorm2d, MaxPool2d, BatchNorm1d +from .tensor_norm import TensorNorm +from .common import get_quant_conv2d, get_quant_linear, get_act_quant, get_quant_type +from brevitas.nn import QuantConv2d, QuantHardTanh, QuantLinear + +from brevitas.core.restrict_val import RestrictValueType +from brevitas.core.scaling import ScalingImplType + + +# QuantConv2d configuration +CNV_OUT_CH_POOL = [(64, False), (64, True), (128, False), (128, True), (256, False), (256, False)] + +# Intermediate QuantLinear configuration +INTERMEDIATE_FC_PER_OUT_CH_SCALING = False +INTERMEDIATE_FC_FEATURES = [(256, 512), (512, 512)] + +# Last QuantLinear configuration +LAST_FC_IN_FEATURES = 512 +LAST_FC_PER_OUT_CH_SCALING = False + +# MaxPool2d configuration +POOL_SIZE = 2 + + +class CNV(Module): + + def __init__(self, num_classes=10, weight_bit_width=None, act_bit_width=None, in_bit_width=None, in_ch=3): + super(CNV, self).__init__() + + weight_quant_type = get_quant_type(weight_bit_width) + act_quant_type = get_quant_type(act_bit_width) + in_quant_type = get_quant_type(in_bit_width) + max_in_val = 1-2**(-7) # for Q1.7 input format + self.conv_features = ModuleList() + self.linear_features = ModuleList() + + self.conv_features.append(QuantHardTanh(bit_width=in_bit_width, + quant_type=in_quant_type, + max_val=max_in_val, + restrict_scaling_type=RestrictValueType.POWER_OF_TWO, + scaling_impl_type=ScalingImplType.CONST)) + + for out_ch, is_pool_enabled in CNV_OUT_CH_POOL: + self.conv_features.append(get_quant_conv2d(in_ch=in_ch, + out_ch=out_ch, + bit_width=weight_bit_width, + quant_type=weight_quant_type)) + in_ch = out_ch + self.conv_features.append(BatchNorm2d(in_ch, eps=1e-4)) + self.conv_features.append(get_act_quant(act_bit_width, act_quant_type)) + if is_pool_enabled: + self.conv_features.append(MaxPool2d(kernel_size=2)) + + for in_features, out_features in INTERMEDIATE_FC_FEATURES: + self.linear_features.append(get_quant_linear(in_features=in_features, + out_features=out_features, + per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING, + bit_width=weight_bit_width, + quant_type=weight_quant_type)) + self.linear_features.append(BatchNorm1d(out_features, eps=1e-4)) + self.linear_features.append(get_act_quant(act_bit_width, act_quant_type)) + + self.linear_features.append(get_quant_linear(in_features=LAST_FC_IN_FEATURES, + out_features=num_classes, + per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING, + bit_width=weight_bit_width, + quant_type=weight_quant_type)) + self.linear_features.append(TensorNorm()) + + for m in self.modules(): + if isinstance(m, QuantConv2d) or isinstance(m, QuantLinear): + torch.nn.init.uniform_(m.weight.data, -1, 1) + + + def clip_weights(self, min_val, max_val): + for mod in self.conv_features: + if isinstance(mod, QuantConv2d): + mod.weight.data.clamp_(min_val, max_val) + for mod in self.linear_features: + if isinstance(mod, QuantLinear): + mod.weight.data.clamp_(min_val, max_val) + + def forward(self, x): + x = 2.0 * x - torch.tensor([1.0], device=x.device) + 
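+        # The line above rescales inputs from [0, 1] to [-1, 1]; together with
+        # max_val = 1 - 2**(-7) and CONST power-of-two scaling in __init__, this
+        # matches the signed Q1.7 input format noted there.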
for mod in self.conv_features: + x = mod(x) + x = x.view(x.shape[0], -1) + for mod in self.linear_features: + x = mod(x) + return x + + +def cnv(cfg): + weight_bit_width = cfg.getint('QUANT', 'WEIGHT_BIT_WIDTH') + act_bit_width = cfg.getint('QUANT', 'ACT_BIT_WIDTH') + in_bit_width = cfg.getint('QUANT', 'IN_BIT_WIDTH') + num_classes = cfg.getint('MODEL', 'NUM_CLASSES') + in_channels = cfg.getint('MODEL', 'IN_CHANNELS') + net = CNV(weight_bit_width=weight_bit_width, + act_bit_width=act_bit_width, + in_bit_width=in_bit_width, + num_classes=num_classes, + in_ch=in_channels) + return net + diff --git a/brevitas_examples/bnn_pynq/models/LFC.py b/brevitas_examples/bnn_pynq/models/LFC.py new file mode 100644 index 000000000..d55c80eee --- /dev/null +++ b/brevitas_examples/bnn_pynq/models/LFC.py @@ -0,0 +1,97 @@ +# MIT License +# +# Copyright (c) 2019 Xilinx +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
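+
+# LFC: an MNIST classifier with three 1024-unit quantized fully-connected
+# layers (FC_OUT_FEATURES below); forward() flattens inputs and rescales them
+# to [-1, 1].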
+ +from functools import reduce +from operator import mul + +from torch.nn import Module, ModuleList, BatchNorm1d, Dropout +import torch + +from .common import get_quant_linear, get_act_quant, get_quant_type, QuantLinear + +FC_OUT_FEATURES = [1024, 1024, 1024] +INTERMEDIATE_FC_PER_OUT_CH_SCALING = True +LAST_FC_PER_OUT_CH_SCALING = False +IN_DROPOUT = 0.2 +HIDDEN_DROPOUT = 0.2 + + +class LFC(Module): + + def __init__(self, num_classes=10, weight_bit_width=None, act_bit_width=None, + in_bit_width=None, in_ch=1, in_features=(28, 28)): + super(LFC, self).__init__() + + weight_quant_type = get_quant_type(weight_bit_width) + act_quant_type = get_quant_type(act_bit_width) + in_quant_type = get_quant_type(in_bit_width) + + self.features = ModuleList() + self.features.append(get_act_quant(in_bit_width, in_quant_type)) + self.features.append(Dropout(p=IN_DROPOUT)) + in_features = reduce(mul, in_features) + for out_features in FC_OUT_FEATURES: + self.features.append(get_quant_linear(in_features=in_features, + out_features=out_features, + per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING, + bit_width=weight_bit_width, + quant_type=weight_quant_type)) + in_features = out_features + self.features.append(BatchNorm1d(num_features=in_features)) + self.features.append(get_act_quant(act_bit_width, act_quant_type)) + self.features.append(Dropout(p=HIDDEN_DROPOUT)) + self.features.append(get_quant_linear(in_features=in_features, + out_features=num_classes, + per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING, + bit_width=weight_bit_width, + quant_type=weight_quant_type)) + self.features.append(BatchNorm1d(num_features=num_classes)) + + for m in self.modules(): + if isinstance(m, QuantLinear): + torch.nn.init.uniform_(m.weight.data, -1, 1) + + def clip_weights(self, min_val, max_val): + for mod in self.features: + if isinstance(mod, QuantLinear): + mod.weight.data.clamp_(min_val, max_val) + + def forward(self, x): + x = x.view(x.shape[0], -1) + x = 2.0 * x - torch.tensor([1.0], device=x.device) + for mod in self.features: + x = mod(x) + return x + + +def lfc(cfg): + weight_bit_width = cfg.getint('QUANT', 'WEIGHT_BIT_WIDTH') + act_bit_width = cfg.getint('QUANT', 'ACT_BIT_WIDTH') + in_bit_width = cfg.getint('QUANT', 'IN_BIT_WIDTH') + num_classes = cfg.getint('MODEL', 'NUM_CLASSES') + in_channels = cfg.getint('MODEL', 'IN_CHANNELS') + net = LFC(weight_bit_width=weight_bit_width, + act_bit_width=act_bit_width, + in_bit_width=in_bit_width, + num_classes=num_classes, + in_ch=in_channels) + return net \ No newline at end of file diff --git a/brevitas_examples/bnn_pynq/models/SFC.py b/brevitas_examples/bnn_pynq/models/SFC.py new file mode 100644 index 000000000..83cf53918 --- /dev/null +++ b/brevitas_examples/bnn_pynq/models/SFC.py @@ -0,0 +1,97 @@ +# MIT License +# +# Copyright (c) 2019 Xilinx +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from functools import reduce +from operator import mul + +from torch.nn import Module, ModuleList, BatchNorm1d, Dropout +import torch + +from .common import get_quant_linear, get_act_quant, get_quant_type, QuantLinear + +FC_OUT_FEATURES = [256, 256, 256] +INTERMEDIATE_FC_PER_OUT_CH_SCALING = True +LAST_FC_PER_OUT_CH_SCALING = False +IN_DROPOUT = 0.2 +HIDDEN_DROPOUT = 0.2 + + +class SFC(Module): + + def __init__(self, num_classes=10, weight_bit_width=None, act_bit_width=None, + in_bit_width=None, in_ch=1, in_features=(28, 28)): + super(SFC, self).__init__() + + weight_quant_type = get_quant_type(weight_bit_width) + act_quant_type = get_quant_type(act_bit_width) + in_quant_type = get_quant_type(in_bit_width) + + self.features = ModuleList() + self.features.append(get_act_quant(in_bit_width, in_quant_type)) + self.features.append(Dropout(p=IN_DROPOUT)) + in_features = reduce(mul, in_features) + for out_features in FC_OUT_FEATURES: + self.features.append(get_quant_linear(in_features=in_features, + out_features=out_features, + per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING, + bit_width=weight_bit_width, + quant_type=weight_quant_type)) + in_features = out_features + self.features.append(BatchNorm1d(num_features=in_features)) + self.features.append(get_act_quant(act_bit_width, act_quant_type)) + self.features.append(Dropout(p=HIDDEN_DROPOUT)) + self.features.append(get_quant_linear(in_features=in_features, + out_features=num_classes, + per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING, + bit_width=weight_bit_width, + quant_type=weight_quant_type)) + self.features.append(BatchNorm1d(num_features=num_classes)) + + for m in self.modules(): + if isinstance(m, QuantLinear): + torch.nn.init.uniform_(m.weight.data, -1, 1) + + def clip_weights(self, min_val, max_val): + for mod in self.features: + if isinstance(mod, QuantLinear): + mod.weight.data.clamp_(min_val, max_val) + + def forward(self, x): + x = x.view(x.shape[0], -1) + x = 2.0 * x - torch.tensor([1.0], device=x.device) + for mod in self.features: + x = mod(x) + return x + + +def sfc(cfg): + weight_bit_width = cfg.getint('QUANT', 'WEIGHT_BIT_WIDTH') + act_bit_width = cfg.getint('QUANT', 'ACT_BIT_WIDTH') + in_bit_width = cfg.getint('QUANT', 'IN_BIT_WIDTH') + num_classes = cfg.getint('MODEL', 'NUM_CLASSES') + in_channels = cfg.getint('MODEL', 'IN_CHANNELS') + net = SFC(weight_bit_width=weight_bit_width, + act_bit_width=act_bit_width, + in_bit_width=in_bit_width, + num_classes=num_classes, + in_ch=in_channels) + return net \ No newline at end of file diff --git a/brevitas_examples/bnn_pynq/models/TFC.py b/brevitas_examples/bnn_pynq/models/TFC.py new file mode 100644 index 000000000..be5d33f75 --- /dev/null +++ b/brevitas_examples/bnn_pynq/models/TFC.py @@ -0,0 +1,97 @@ +# MIT License +# +# Copyright (c) 2019 Xilinx +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, 
copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from functools import reduce +from operator import mul + +from torch.nn import Module, ModuleList, BatchNorm1d, Dropout +import torch + +from .common import get_quant_linear, get_act_quant, get_quant_type, QuantLinear + +FC_OUT_FEATURES = [64, 64, 64] +INTERMEDIATE_FC_PER_OUT_CH_SCALING = True +LAST_FC_PER_OUT_CH_SCALING = False +IN_DROPOUT = 0.2 +HIDDEN_DROPOUT = 0.2 + + +class TFC(Module): + + def __init__(self, num_classes=10, weight_bit_width=None, act_bit_width=None, + in_bit_width=None, in_ch=1, in_features=(28, 28)): + super(TFC, self).__init__() + + weight_quant_type = get_quant_type(weight_bit_width) + act_quant_type = get_quant_type(act_bit_width) + in_quant_type = get_quant_type(in_bit_width) + + self.features = ModuleList() + self.features.append(get_act_quant(in_bit_width, in_quant_type)) + self.features.append(Dropout(p=IN_DROPOUT)) + in_features = reduce(mul, in_features) + for out_features in FC_OUT_FEATURES: + self.features.append(get_quant_linear(in_features=in_features, + out_features=out_features, + per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING, + bit_width=weight_bit_width, + quant_type=weight_quant_type)) + in_features = out_features + self.features.append(BatchNorm1d(num_features=in_features)) + self.features.append(get_act_quant(act_bit_width, act_quant_type)) + self.features.append(Dropout(p=HIDDEN_DROPOUT)) + self.features.append(get_quant_linear(in_features=in_features, + out_features=num_classes, + per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING, + bit_width=weight_bit_width, + quant_type=weight_quant_type)) + self.features.append(BatchNorm1d(num_features=num_classes)) + + for m in self.modules(): + if isinstance(m, QuantLinear): + torch.nn.init.uniform_(m.weight.data, -1, 1) + + def clip_weights(self, min_val, max_val): + for mod in self.features: + if isinstance(mod, QuantLinear): + mod.weight.data.clamp_(min_val, max_val) + + def forward(self, x): + x = x.view(x.shape[0], -1) + x = 2.0 * x - torch.tensor([1.0], device=x.device) + for mod in self.features: + x = mod(x) + return x + + +def tfc(cfg): + weight_bit_width = cfg.getint('QUANT', 'WEIGHT_BIT_WIDTH') + act_bit_width = cfg.getint('QUANT', 'ACT_BIT_WIDTH') + in_bit_width = cfg.getint('QUANT', 'IN_BIT_WIDTH') + num_classes = cfg.getint('MODEL', 'NUM_CLASSES') + in_channels = cfg.getint('MODEL', 'IN_CHANNELS') + net = TFC(weight_bit_width=weight_bit_width, + act_bit_width=act_bit_width, + in_bit_width=in_bit_width, + num_classes=num_classes, + in_ch=in_channels) + return net diff --git a/brevitas_examples/bnn_pynq/models/__init__.py b/brevitas_examples/bnn_pynq/models/__init__.py new file mode 100644 index 000000000..b67fba803 --- /dev/null +++ b/brevitas_examples/bnn_pynq/models/__init__.py 
@@ -0,0 +1,116 @@ +# MIT License +# +# Copyright (c) 2019 Xilinx +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import os +from configparser import ConfigParser + +import torch +from torch import hub + +__all__ = ['cnv_1w1a', 'cnv_1w2a', 'cnv_2w2a', + 'sfc_1w1a', 'sfc_1w2a', 'sfc_2w2a', + 'tfc_1w1a', 'tfc_1w2a', 'tfc_2w2a', + 'lfc_1w1a', 'lfc_1w2a'] + +from .CNV import cnv +from .LFC import lfc +from .TFC import tfc +from .SFC import sfc + +model_impl = { + 'CNV': cnv, + 'LFC': lfc, + 'TFC': tfc, + 'SFC': sfc +} + + +def model_with_cfg(name, pretrained): + cfg = ConfigParser() + current_dir = os.path.dirname(os.path.abspath(__file__)) + config_path = os.path.join(current_dir, '..', 'cfg', name.lower() + '.ini') + assert os.path.exists(config_path) + cfg.read(config_path) + arch = cfg.get('MODEL', 'ARCH') + model = model_impl[arch](cfg) + if pretrained: + checkpoint = cfg.get('MODEL', 'PRETRAINED_URL') + state_dict = hub.load_state_dict_from_url(checkpoint, progress=True, map_location='cpu') + model.load_state_dict(state_dict, strict=True) + return model, cfg + + +def cnv_1w1a(pretrained=True): + model, _ = model_with_cfg('cnv_1w1a', pretrained) + return model + + +def cnv_1w2a(pretrained=True): + model, _ = model_with_cfg('cnv_1w2a', pretrained) + return model + + +def cnv_2w2a(pretrained=True): + model, _ = model_with_cfg('cnv_2w2a', pretrained) + return model + + +def sfc_1w1a(pretrained=True): + model, _ = model_with_cfg('sfc_1w1a', pretrained) + return model + + +def sfc_1w2a(pretrained=True): + model, _ = model_with_cfg('sfc_1w2a', pretrained) + return model + + +def sfc_2w2a(pretrained=True): + model, _ = model_with_cfg('sfc_2w2a', pretrained) + return model + + +def tfc_1w1a(pretrained=True): + model, _ = model_with_cfg('tfc_1w1a', pretrained) + return model + + +def tfc_1w2a(pretrained=True): + model, _ = model_with_cfg('tfc_1w2a', pretrained) + return model + + +def tfc_2w2a(pretrained=True): + model, _ = model_with_cfg('tfc_2w2a', pretrained) + return model + + +def lfc_1w1a(pretrained=True): + model, _ = model_with_cfg('lfc_1w1a', pretrained) + return model + + +def lfc_1w2a(pretrained=True): + model, _ = model_with_cfg('lfc_1w2a', pretrained) + return model + + diff --git a/brevitas_examples/bnn_pynq/models/common.py b/brevitas_examples/bnn_pynq/models/common.py new file mode 100644 index 000000000..3813786db --- /dev/null +++ b/brevitas_examples/bnn_pynq/models/common.py @@ -0,0 +1,96 @@ +# MIT License +# +# Copyright (c) 2019 Xilinx +# +# Permission is hereby 
granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from brevitas.core.bit_width import BitWidthImplType +from brevitas.core.quant import QuantType +from brevitas.core.restrict_val import RestrictValueType +from brevitas.core.scaling import ScalingImplType +from brevitas.nn import QuantConv2d, QuantHardTanh, QuantLinear + +# Quant common +BIT_WIDTH_IMPL_TYPE = BitWidthImplType.CONST +SCALING_VALUE_TYPE = RestrictValueType.LOG_FP +NARROW_RANGE_ENABLED = True + +# Weight quant common +BIAS_ENABLED = False +WEIGHT_SCALING_IMPL_TYPE = ScalingImplType.CONST +WEIGHT_SCALING_CONST = 1.0 + +# QuantHardTanh configuration +HARD_TANH_MIN = -1.0 +HARD_TANH_MAX = 1.0 +ACT_PER_OUT_CH_SCALING = False +ACT_SCALING_IMPL_TYPE = ScalingImplType.CONST + +# QuantConv2d configuration +KERNEL_SIZE = 3 +CONV_PER_OUT_CH_SCALING = False + + +def get_quant_type(bit_width): + if bit_width is None: + return QuantType.FP + elif bit_width == 1: + return QuantType.BINARY + else: + return QuantType.INT + + +def get_act_quant(act_bit_width, act_quant_type): + return QuantHardTanh(quant_type=act_quant_type, + bit_width=act_bit_width, + bit_width_impl_type=BIT_WIDTH_IMPL_TYPE, + min_val=HARD_TANH_MIN, + max_val=HARD_TANH_MAX, + scaling_impl_type=ACT_SCALING_IMPL_TYPE, + restrict_scaling_type=SCALING_VALUE_TYPE, + scaling_per_channel=ACT_PER_OUT_CH_SCALING, + narrow_range=NARROW_RANGE_ENABLED) + + +def get_quant_linear(in_features, out_features, per_out_ch_scaling, bit_width, quant_type): + return QuantLinear(bias=BIAS_ENABLED, + in_features=in_features, + out_features=out_features, + weight_quant_type=quant_type, + weight_bit_width=bit_width, + weight_scaling_const=WEIGHT_SCALING_CONST, + weight_bit_width_impl_type=BIT_WIDTH_IMPL_TYPE, + weight_scaling_per_output_channel=per_out_ch_scaling, + weight_scaling_impl_type=WEIGHT_SCALING_IMPL_TYPE, + weight_narrow_range=NARROW_RANGE_ENABLED) + + +def get_quant_conv2d(in_ch, out_ch, bit_width, quant_type): + return QuantConv2d(in_channels=in_ch, + kernel_size=KERNEL_SIZE, + out_channels=out_ch, + weight_quant_type=quant_type, + weight_bit_width=bit_width, + weight_narrow_range=NARROW_RANGE_ENABLED, + weight_scaling_impl_type=WEIGHT_SCALING_IMPL_TYPE, + weight_scaling_const=WEIGHT_SCALING_CONST, + weight_scaling_per_output_channel=CONV_PER_OUT_CH_SCALING, + weight_restrict_scaling_type=SCALING_VALUE_TYPE, + weight_bit_width_impl_type=BIT_WIDTH_IMPL_TYPE, + bias=BIAS_ENABLED) diff --git a/brevitas_examples/bnn_pynq/models/losses.py b/brevitas_examples/bnn_pynq/models/losses.py new file mode 100644 index 
000000000..d3887399d
--- /dev/null
+++ b/brevitas_examples/bnn_pynq/models/losses.py
@@ -0,0 +1,31 @@
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+
+class squared_hinge_loss(Function):
+    @staticmethod
+    def forward(ctx, predictions, targets):
+        # Squared hinge: mean(max(0, 1 - y * pred)^2), with targets y in {-1, +1}
+        ctx.save_for_backward(predictions, targets)
+        output = 1. - predictions.mul(targets)
+        output[output.le(0.)] = 0.
+        loss = torch.mean(output.mul(output))
+        return loss
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        predictions, targets = ctx.saved_tensors
+        output = 1. - predictions.mul(targets)
+        output[output.le(0.)] = 0.
+        # Gradient wrt predictions: -2 * y * max(0, 1 - y * pred) / numel, written
+        # in place into grad_output (this assumes the loss is the graph output,
+        # i.e. an incoming gradient of 1)
+        grad_output.resize_as_(predictions).copy_(targets).mul_(-2.).mul_(output)
+        grad_output.mul_(output.ne(0).float())
+        grad_output.div_(predictions.numel())
+        return grad_output, None
+
+class SqrHingeLoss(nn.Module):
+    # Squared Hinge Loss
+    def __init__(self):
+        super(SqrHingeLoss, self).__init__()
+
+    def forward(self, input, target):
+        return squared_hinge_loss.apply(input, target)
\ No newline at end of file
diff --git a/brevitas_examples/bnn_pynq/models/tensor_norm.py b/brevitas_examples/bnn_pynq/models/tensor_norm.py
new file mode 100644
index 000000000..613f1c04e
--- /dev/null
+++ b/brevitas_examples/bnn_pynq/models/tensor_norm.py
@@ -0,0 +1,56 @@
+# MIT License
+#
+# Copyright (c) 2019 Xilinx
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
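+
+# TensorNorm (below) is a batch-norm-style layer that normalizes over the whole
+# tensor with a single scalar weight and bias plus running statistics; CNV
+# appends it after its final fully-connected layer.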
+ +import torch +import torch.nn as nn +import torch.nn.init as init + + +class TensorNorm(nn.Module): + def __init__(self, eps=1e-4, momentum=0.1): + super().__init__() + + self.eps = eps + self.momentum = momentum + self.weight = nn.Parameter(torch.rand(1)) + self.bias = nn.Parameter(torch.rand(1)) + self.register_buffer('running_mean', torch.zeros(1)) + self.register_buffer('running_var', torch.ones(1)) + self.reset_running_stats() + + def reset_running_stats(self): + self.running_mean.zero_() + self.running_var.fill_(1) + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, x): + if self.training: + mean = x.mean() + unbias_var = x.var(unbiased=True) + biased_var = x.var(unbiased=False) + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.detach() + inv_std = 1 / (biased_var + self.eps).pow(0.5) + return (x - mean) * inv_std * self.weight + self.bias + else: + return ((x - self.running_mean) / (self.running_var + self.eps).pow(0.5)) * self.weight + self.bias \ No newline at end of file diff --git a/brevitas_examples/bnn_pynq/trainer.py b/brevitas_examples/bnn_pynq/trainer.py new file mode 100644 index 000000000..c1e09d8aa --- /dev/null +++ b/brevitas_examples/bnn_pynq/trainer.py @@ -0,0 +1,331 @@ +# MIT License +# +# Copyright (c) 2019 Xilinx +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
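+
+# Trainer drives the whole run: it builds the model from its .ini config via
+# models.model_with_cfg, sets up MNIST/CIFAR10 loaders, the SqrHinge or
+# cross-entropy loss, an ADAM/SGD optimizer and an optional MultiStepLR schedule.
+# A minimal programmatic sketch of the same model-building step (the supported
+# path is the brevitas_bnn_pynq_train console script):
+#   from brevitas_examples.bnn_pynq.models import model_with_cfg
+#   model, cfg = model_with_cfg('lfc_1w1a', pretrained=True)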
+
+import random
+import os
+import time
+from datetime import datetime
+
+import torch
+import torch.optim as optim
+from torch import nn
+from torch.optim.lr_scheduler import MultiStepLR
+from torch.utils.data import DataLoader
+from torchvision import transforms
+from torchvision.datasets import MNIST, CIFAR10
+
+from .logger import Logger, TrainingEpochMeters, EvalEpochMeters
+from .models import model_with_cfg
+from .models.losses import SqrHingeLoss
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the precision@k for the specified values of k"""
+    maxk = max(topk)
+    batch_size = target.size(0)
+
+    _, pred = output.topk(maxk, 1, True, True)
+    pred = pred.t()
+    correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+    res = []
+    for k in topk:
+        correct_k = correct[:k].view(-1).float().sum(0)
+        res.append(correct_k.mul_(100.0 / batch_size))
+    return res
+
+
+class Trainer(object):
+    def __init__(self, args):
+
+        model, cfg = model_with_cfg(args.network, args.pretrained)
+
+        # Init arguments
+        self.args = args
+        prec_name = "_{}W{}A".format(cfg.getint('QUANT', 'WEIGHT_BIT_WIDTH'), cfg.getint('QUANT', 'ACT_BIT_WIDTH'))
+        experiment_name = '{}{}_{}'.format(args.network, prec_name, datetime.now().strftime('%Y%m%d_%H%M%S'))
+        self.output_dir_path = os.path.join(args.experiments, experiment_name)
+
+        if self.args.resume:
+            self.output_dir_path, _ = os.path.split(args.resume)
+            self.output_dir_path, _ = os.path.split(self.output_dir_path)
+
+        if not args.dry_run:
+            self.checkpoints_dir_path = os.path.join(self.output_dir_path, 'checkpoints')
+            if not args.resume:
+                os.mkdir(self.output_dir_path)
+                os.mkdir(self.checkpoints_dir_path)
+        self.logger = Logger(self.output_dir_path, args.dry_run)
+
+        # Randomness
+        random.seed(args.random_seed)
+        torch.manual_seed(args.random_seed)
+        torch.cuda.manual_seed_all(args.random_seed)
+
+        # Datasets
+        transform_to_tensor = transforms.Compose([transforms.ToTensor()])
+
+        dataset = cfg.get('MODEL', 'DATASET')
+        self.num_classes = cfg.getint('MODEL', 'NUM_CLASSES')
+        if dataset == 'CIFAR10':
+            train_transforms_list = [transforms.RandomCrop(32, padding=4),
+                                     transforms.RandomHorizontalFlip(),
+                                     transforms.ToTensor()]
+            transform_train = transforms.Compose(train_transforms_list)
+            builder = CIFAR10
+        elif dataset == 'MNIST':
+            transform_train = transform_to_tensor
+            builder = MNIST
+        else:
+            raise Exception("Dataset not supported: {}".format(dataset))
+
+        train_set = builder(root=args.datadir,
+                            train=True,
+                            download=True,
+                            transform=transform_train)
+        test_set = builder(root=args.datadir,
+                           train=False,
+                           download=True,
+                           transform=transform_to_tensor)
+        self.train_loader = DataLoader(train_set,
+                                       batch_size=args.batch_size,
+                                       shuffle=True,
+                                       num_workers=args.num_workers)
+        self.test_loader = DataLoader(test_set,
+                                      batch_size=args.batch_size,
+                                      shuffle=False,
+                                      num_workers=args.num_workers)
+
+        # Init starting values
+        self.starting_epoch = 1
+        self.best_val_acc = 0
+
+        # Setup device
+        if args.gpus is not None:
+            args.gpus = [int(i) for i in args.gpus.split(',')]
+            self.device = 'cuda:' + str(args.gpus[0])
+            torch.backends.cudnn.benchmark = True
+        else:
+            self.device = 'cpu'
+        self.device = torch.device(self.device)
+
+        # Resume checkpoint, if any
+        if args.resume:
+            print('Loading model checkpoint at: {}'.format(args.resume))
+            package = torch.load(args.resume, map_location='cpu')
+            model_state_dict = package['state_dict']
+            # Load into the local model; self.model is only assigned further down
+            model.load_state_dict(model_state_dict, strict=args.strict)
+
+        if args.gpus is not None and len(args.gpus) == 1:
+            model = model.to(device=self.device)
+        if args.gpus is not None and len(args.gpus) > 1:
+            model = nn.DataParallel(model, args.gpus)
+        self.model = model
+
+        # Loss function
+        if args.loss == 'SqrHinge':
+            self.criterion = SqrHingeLoss()
+        else:
+            self.criterion = nn.CrossEntropyLoss()
+        self.criterion = self.criterion.to(device=self.device)
+
+        # Init optimizer
+        if args.optim == 'ADAM':
+            self.optimizer = optim.Adam(self.model.parameters(),
+                                        lr=args.lr,
+                                        weight_decay=args.weight_decay)
+        elif args.optim == 'SGD':
+            self.optimizer = optim.SGD(self.model.parameters(),
+                                       lr=self.args.lr,
+                                       momentum=self.args.momentum,
+                                       weight_decay=self.args.weight_decay)
+
+        # Resume optimizer, if any
+        if args.resume and not args.evaluate:
+            self.logger.log.info("Loading optimizer checkpoint")
+            if 'optim_dict' in package.keys():
+                self.optimizer.load_state_dict(package['optim_dict'])
+            if 'epoch' in package.keys():
+                self.starting_epoch = package['epoch']
+            if 'best_val_acc' in package.keys():
+                self.best_val_acc = package['best_val_acc']
+
+        # LR scheduler
+        if args.scheduler == 'STEP':
+            milestones = [int(i) for i in args.milestones.split(',')]
+            self.scheduler = MultiStepLR(optimizer=self.optimizer,
+                                         milestones=milestones,
+                                         gamma=0.1)
+        elif args.scheduler == 'FIXED':
+            self.scheduler = None
+        else:
+            raise Exception("Unrecognized scheduler {}".format(self.args.scheduler))
+
+        # Resume scheduler, if any
+        if args.resume and not args.evaluate and self.scheduler is not None:
+            self.scheduler.last_epoch = package['epoch'] - 1
+
+    def checkpoint_best(self, epoch, name):
+        best_path = os.path.join(self.checkpoints_dir_path, name)
+        self.logger.info("Saving checkpoint model to {}".format(best_path))
+        torch.save({
+            'state_dict': self.model.state_dict(),
+            'optim_dict': self.optimizer.state_dict(),
+            'epoch': epoch + 1,
+            'best_val_acc': self.best_val_acc,
+        }, best_path)
+
+    def train_model(self):
+
+        # Training starts
+        if self.args.detect_nan:
+            torch.autograd.set_detect_anomaly(True)
+
+        for epoch in range(self.starting_epoch, self.args.epochs):
+
+            # Set to training mode
+            self.model.train()
+            self.criterion.train()
+
+            # Init metrics
+            epoch_meters = TrainingEpochMeters()
+            start_data_loading = time.time()
+
+            for i, data in enumerate(self.train_loader):
+                (input, target) = data
+                input = input.to(self.device, non_blocking=True)
+                target = target.to(self.device, non_blocking=True)
+
+                # For hinge loss only: turn class indices into a {-1, +1} one-hot target
+                if isinstance(self.criterion, SqrHingeLoss):
+                    target = target.unsqueeze(1)
+                    target_onehot = torch.Tensor(target.size(0), self.num_classes).to(self.device, non_blocking=True)
+                    target_onehot.fill_(-1)
+                    target_onehot.scatter_(1, target, 1)
+                    target = target.squeeze()
+                    target_var = target_onehot
+                else:
+                    target_var = target
+
+                # Measure data loading time
+                epoch_meters.data_time.update(time.time() - start_data_loading)
+
+                # Training batch starts
+                start_batch = time.time()
+                output = self.model(input)
+                loss = self.criterion(output, target_var)
+
+                # Compute gradient and do SGD step
+                self.optimizer.zero_grad()
+                loss.backward()
+                self.optimizer.step()
+
+                # Clip quantized weights back to [-1, 1] after the update
+                self.model.clip_weights(-1, 1)
+
+                # Measure elapsed time
+                epoch_meters.batch_time.update(time.time() - start_batch)
+
+                if i % int(self.args.log_freq) == 0 or i == len(self.train_loader) - 1:
+                    prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
+                    epoch_meters.losses.update(loss.item(), input.size(0))
+                    epoch_meters.top1.update(prec1.item(), input.size(0))
+                    epoch_meters.top5.update(prec5.item(), input.size(0))
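+                    # Log running averages every --log_freq batches and on the last batch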
self.logger.training_batch_cli_log(epoch_meters, epoch, i, len(self.train_loader)) + + # training batch ends + start_data_loading = time.time() + + # Set the learning rate + if self.scheduler is not None: + self.scheduler.step(epoch) + else: + # Set the learning rate + if epoch%40==0: + self.optimizer.param_groups[0]['lr'] *= 0.5 + + # Perform eval + with torch.no_grad(): + top1avg = self.eval_model(epoch) + + # checkpoint + if top1avg >= self.best_val_acc and not self.args.dry_run: + self.best_val_acc = top1avg + self.checkpoint_best(epoch, "best.tar") + elif not self.args.dry_run: + self.checkpoint_best(epoch, "checkpoint.tar") + + # training ends + if not self.args.dry_run: + return os.path.join(self.checkpoints_dir_path, "best.tar") + + def eval_model(self, epoch=None): + eval_meters = EvalEpochMeters() + + # switch to evaluate mode + self.model.eval() + self.criterion.eval() + + for i, data in enumerate(self.test_loader): + + end = time.time() + (input, target) = data + + input = input.to(self.device, non_blocking=True) + target = target.to(self.device, non_blocking=True) + + # for hingeloss only + if isinstance(self.criterion, SqrHingeLoss): + target=target.unsqueeze(1) + target_onehot = torch.Tensor(target.size(0), self.num_classes).to(self.device, non_blocking=True) + target_onehot.fill_(-1) + target_onehot.scatter_(1, target, 1) + target=target.squeeze() + target_var = target_onehot + else: + target_var = target + + # compute output + output = self.model(input) + + # measure model elapsed time + eval_meters.model_time.update(time.time() - end) + end = time.time() + + #compute loss + loss = self.criterion(output, target_var) + eval_meters.loss_time.update(time.time() - end) + + pred = output.data.argmax(1, keepdim=True) + correct = pred.eq(target.data.view_as(pred)).sum() + prec1 = 100. * correct.float() / input.size(0) + + _, prec5 = accuracy(output, target, topk=(1, 5)) + eval_meters.losses.update(loss.item(), input.size(0)) + eval_meters.top1.update(prec1.item(), input.size(0)) + eval_meters.top5.update(prec5.item(), input.size(0)) + + #Eval batch ends + self.logger.eval_batch_cli_log(eval_meters, i, len(self.test_loader)) + + return eval_meters.top1.avg diff --git a/setup.py b/setup.py index 567c0877e..75054661c 100644 --- a/setup.py +++ b/setup.py @@ -184,6 +184,7 @@ def run(self): }, entry_points={ 'console_scripts': [ + 'brevitas_bnn_pynq_train = brevitas_examples.bnn_pynq.bnn_pynq_train:main', 'brevitas_imagenet_val = brevitas_examples.imagenet_classification.imagenet_val:main', 'brevitas_quartznet_val = brevitas_examples.speech_to_text.quartznet_val:main', 'brevitas_melgan_val = brevitas_examples.text_to_speech.melgan_val:main', diff --git a/test/brevitas_examples/test_import.py b/test/brevitas_examples/test_import.py index 2b61a7b45..9f7908067 100644 --- a/test/brevitas_examples/test_import.py +++ b/test/brevitas_examples/test_import.py @@ -39,6 +39,14 @@ # POSSIBILITY OF SUCH DAMAGE. +def test_import_bnn_pynq(): + from brevitas_examples.bnn_pynq import ( + cnv_1w1a, cnv_1w2a, cnv_2w2a, + sfc_1w1a, sfc_1w2a, sfc_2w2a, + tfc_1w1a, tfc_1w2a, tfc_2w2a, + lfc_1w1a, lfc_1w2a) + + def test_import_image_classification(): from brevitas_examples.imagenet_classification import ( quant_mobilenet_v1_4b,