From 7f126336eea4d3d807de45ac46cd8945ca673f68 Mon Sep 17 00:00:00 2001
From: Alessandro Pappalardo
Date: Tue, 8 Sep 2020 16:54:42 +0100
Subject: [PATCH] Examples (BNN-PYNQ): idiomatic and simplified implementation
 w.r.t. new D.I. APIs

Signed-off-by: Alessandro Pappalardo
---
 brevitas_examples/bnn_pynq/cfg/lfc_1w1a.ini   |  3 +-
 brevitas_examples/bnn_pynq/cfg/lfc_1w2a.ini   |  3 +-
 brevitas_examples/bnn_pynq/cfg/sfc_1w1a.ini   |  3 +-
 brevitas_examples/bnn_pynq/cfg/sfc_1w2a.ini   |  3 +-
 brevitas_examples/bnn_pynq/cfg/sfc_2w2a.ini   |  3 +-
 brevitas_examples/bnn_pynq/cfg/tfc_1w1a.ini   |  3 +-
 brevitas_examples/bnn_pynq/cfg/tfc_1w2a.ini   |  3 +-
 brevitas_examples/bnn_pynq/cfg/tfc_2w2a.ini   |  3 +-
 brevitas_examples/bnn_pynq/models/CNV.py      | 77 ++++++-------
 .../bnn_pynq/models/{SFC.py => FC.py}         | 76 ++++++-------
 brevitas_examples/bnn_pynq/models/LFC.py      | 97 -------------------
 brevitas_examples/bnn_pynq/models/TFC.py      | 97 -------------------
 brevitas_examples/bnn_pynq/models/__init__.py |  8 +-
 brevitas_examples/bnn_pynq/models/common.py   | 40 +++++++-
 brevitas_examples/bnn_pynq/trainer.py         |  2 +-
 15 files changed, 137 insertions(+), 284 deletions(-)
 rename brevitas_examples/bnn_pynq/models/{SFC.py => FC.py} (54%)
 delete mode 100644 brevitas_examples/bnn_pynq/models/LFC.py
 delete mode 100644 brevitas_examples/bnn_pynq/models/TFC.py

diff --git a/brevitas_examples/bnn_pynq/cfg/lfc_1w1a.ini b/brevitas_examples/bnn_pynq/cfg/lfc_1w1a.ini
index 5a09e84d0..22b1b8ea6 100644
--- a/brevitas_examples/bnn_pynq/cfg/lfc_1w1a.ini
+++ b/brevitas_examples/bnn_pynq/cfg/lfc_1w1a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: LFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/lfc_1w1a-db6e13bd.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/lfc_1w1a_eval-db6e13bd.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES: [1024, 1024, 1024]
 NUM_CLASSES: 10
 
 [QUANT]
diff --git a/brevitas_examples/bnn_pynq/cfg/lfc_1w2a.ini b/brevitas_examples/bnn_pynq/cfg/lfc_1w2a.ini
index e15c573a5..f76fb6e72 100644
--- a/brevitas_examples/bnn_pynq/cfg/lfc_1w2a.ini
+++ b/brevitas_examples/bnn_pynq/cfg/lfc_1w2a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: LFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/lfc_1w2a-0a771c67.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/lfc_1w2a_eval-0a771c67.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES: [1024, 1024, 1024]
 NUM_CLASSES: 10
 
 [QUANT]
diff --git a/brevitas_examples/bnn_pynq/cfg/sfc_1w1a.ini b/brevitas_examples/bnn_pynq/cfg/sfc_1w1a.ini
index c1f1211dc..1ed2f5f90 100644
--- a/brevitas_examples/bnn_pynq/cfg/sfc_1w1a.ini
+++ b/brevitas_examples/bnn_pynq/cfg/sfc_1w1a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: SFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_1w1a-fd8a6c3d.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_1w1a_eval-fd8a6c3d.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES: [256, 256, 256]
 NUM_CLASSES: 10
 
 [QUANT]
diff --git a/brevitas_examples/bnn_pynq/cfg/sfc_1w2a.ini b/brevitas_examples/bnn_pynq/cfg/sfc_1w2a.ini
index 13cf7a8d5..0dc26df51 100644
--- a/brevitas_examples/bnn_pynq/cfg/sfc_1w2a.ini
+++ b/brevitas_examples/bnn_pynq/cfg/sfc_1w2a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: SFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_1w2a-fdc0c779.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_1w2a_eval-fdc0c779.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES: [256, 256, 256]
 NUM_CLASSES: 10
 
 [QUANT]
diff --git a/brevitas_examples/bnn_pynq/cfg/sfc_2w2a.ini b/brevitas_examples/bnn_pynq/cfg/sfc_2w2a.ini
index ccb2db65e..1f4d268b2 100644
--- a/brevitas_examples/bnn_pynq/cfg/sfc_2w2a.ini
+++ b/brevitas_examples/bnn_pynq/cfg/sfc_2w2a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: SFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_2w2a-35a7c41d.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_2w2a_eval-35a7c41d.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES: [256, 256, 256]
 NUM_CLASSES: 10
 
 [QUANT]
diff --git a/brevitas_examples/bnn_pynq/cfg/tfc_1w1a.ini b/brevitas_examples/bnn_pynq/cfg/tfc_1w1a.ini
index fe9a9f1e3..2dcae82b4 100644
--- a/brevitas_examples/bnn_pynq/cfg/tfc_1w1a.ini
+++ b/brevitas_examples/bnn_pynq/cfg/tfc_1w1a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: TFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_1w1a-ff8140dc.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_1w1a_eval-ff8140dc.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES: [64, 64, 64]
 NUM_CLASSES: 10
 
 [QUANT]
diff --git a/brevitas_examples/bnn_pynq/cfg/tfc_1w2a.ini b/brevitas_examples/bnn_pynq/cfg/tfc_1w2a.ini
index 2a95141c0..147abe67d 100644
--- a/brevitas_examples/bnn_pynq/cfg/tfc_1w2a.ini
+++ b/brevitas_examples/bnn_pynq/cfg/tfc_1w2a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: TFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_1w2a-95ad635b.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_1w2a_eval-95ad635b.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES: [64, 64, 64]
 NUM_CLASSES: 10
 
 [QUANT]
diff --git a/brevitas_examples/bnn_pynq/cfg/tfc_2w2a.ini b/brevitas_examples/bnn_pynq/cfg/tfc_2w2a.ini
index bfffb0b65..c72ae67e5 100644
--- a/brevitas_examples/bnn_pynq/cfg/tfc_2w2a.ini
+++ b/brevitas_examples/bnn_pynq/cfg/tfc_2w2a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: TFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_2w2a-7e0a62f1.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_2w2a_eval-7e0a62f1.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES: [64, 64, 64]
 NUM_CLASSES: 10
 
 [QUANT]
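
Each cfg now carries the layer widths that used to be hard-coded per model. A minimal sketch of how the new OUT_FEATURES entry round-trips through configparser and ast.literal_eval, mirroring what fc() in FC.py further down does:

import ast
import configparser

# The value is stored as a Python list literal; configparser returns it as a
# string, and ast.literal_eval safely parses it back into a list of ints.
cfg = configparser.ConfigParser()
cfg.read_string("[MODEL]\nOUT_FEATURES: [1024, 1024, 1024]\n")
out_features = ast.literal_eval(cfg.get('MODEL', 'OUT_FEATURES'))
assert out_features == [1024, 1024, 1024]
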
diff --git a/brevitas_examples/bnn_pynq/models/CNV.py b/brevitas_examples/bnn_pynq/models/CNV.py
index 0c14264c2..859f7d97f 100644
--- a/brevitas_examples/bnn_pynq/models/CNV.py
+++ b/brevitas_examples/bnn_pynq/models/CNV.py
@@ -22,72 +22,71 @@
 import torch
 from torch.nn import Module, ModuleList, BatchNorm2d, MaxPool2d, BatchNorm1d
 
-from .tensor_norm import TensorNorm
-from .common import get_quant_conv2d, get_quant_linear, get_act_quant, get_quant_type
-from brevitas.nn import QuantConv2d, QuantHardTanh, QuantLinear
+from brevitas.nn import QuantConv2d, QuantIdentity, QuantLinear
 from brevitas.core.restrict_val import RestrictValueType
-from brevitas.core.scaling import ScalingImplType
+from .tensor_norm import TensorNorm
+from .common import CommonWeightQuant, CommonActQuant
 
-# QuantConv2d configuration
 CNV_OUT_CH_POOL = [(64, False), (64, True), (128, False), (128, True), (256, False), (256, False)]
-
-# Intermediate QuantLinear configuration
-INTERMEDIATE_FC_PER_OUT_CH_SCALING = False
 INTERMEDIATE_FC_FEATURES = [(256, 512), (512, 512)]
-
-# Last QuantLinear configuration
 LAST_FC_IN_FEATURES = 512
 LAST_FC_PER_OUT_CH_SCALING = False
-
-# MaxPool2d configuration
 POOL_SIZE = 2
+KERNEL_SIZE = 3
 
 
 class CNV(Module):
 
-    def __init__(self, num_classes=10, weight_bit_width=None, act_bit_width=None, in_bit_width=None, in_ch=3):
+    def __init__(self, num_classes, weight_bit_width, act_bit_width, in_bit_width, in_ch):
         super(CNV, self).__init__()
 
-        weight_quant_type = get_quant_type(weight_bit_width)
-        act_quant_type = get_quant_type(act_bit_width)
-        in_quant_type = get_quant_type(in_bit_width)
-        max_in_val = 1-2**(-7) # for Q1.7 input format
         self.conv_features = ModuleList()
         self.linear_features = ModuleList()
 
-        self.conv_features.append(QuantHardTanh(bit_width=in_bit_width,
-                                                quant_type=in_quant_type,
-                                                max_val=max_in_val,
-                                                restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
-                                                scaling_impl_type=ScalingImplType.CONST))
+        self.conv_features.append(QuantIdentity( # for Q1.7 input format
+            act_quant=CommonActQuant,
+            bit_width=in_bit_width,
+            min_val=-1.0,
+            max_val=1.0 - 2.0 ** (-7),
+            narrow_range=False,
+            restrict_scaling_type=RestrictValueType.POWER_OF_TWO))
 
         for out_ch, is_pool_enabled in CNV_OUT_CH_POOL:
-            self.conv_features.append(get_quant_conv2d(in_ch=in_ch,
-                                                       out_ch=out_ch,
-                                                       bit_width=weight_bit_width,
-                                                       quant_type=weight_quant_type))
+            self.conv_features.append(QuantConv2d(
+                kernel_size=KERNEL_SIZE,
+                in_channels=in_ch,
+                out_channels=out_ch,
+                bias=False,
+                weight_quant=CommonWeightQuant,
+                weight_bit_width=weight_bit_width))
             in_ch = out_ch
             self.conv_features.append(BatchNorm2d(in_ch, eps=1e-4))
-            self.conv_features.append(get_act_quant(act_bit_width, act_quant_type))
+            self.conv_features.append(QuantIdentity(
+                act_quant=CommonActQuant,
+                bit_width=act_bit_width))
             if is_pool_enabled:
                 self.conv_features.append(MaxPool2d(kernel_size=2))
 
         for in_features, out_features in INTERMEDIATE_FC_FEATURES:
-            self.linear_features.append(get_quant_linear(in_features=in_features,
-                                                         out_features=out_features,
-                                                         per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
-                                                         bit_width=weight_bit_width,
-                                                         quant_type=weight_quant_type))
+            self.linear_features.append(QuantLinear(
+                in_features=in_features,
+                out_features=out_features,
+                bias=False,
+                weight_quant=CommonWeightQuant,
+                weight_bit_width=weight_bit_width))
             self.linear_features.append(BatchNorm1d(out_features, eps=1e-4))
-            self.linear_features.append(get_act_quant(act_bit_width, act_quant_type))
-
-        self.linear_features.append(get_quant_linear(in_features=LAST_FC_IN_FEATURES,
-                                                     out_features=num_classes,
-                                                     per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
-                                                     bit_width=weight_bit_width,
-                                                     quant_type=weight_quant_type))
+            self.linear_features.append(QuantIdentity(
+                act_quant=CommonActQuant,
+                bit_width=act_bit_width))
+
+        self.linear_features.append(QuantLinear(
+            in_features=LAST_FC_IN_FEATURES,
+            out_features=num_classes,
+            bias=False,
+            weight_quant=CommonWeightQuant,
+            weight_bit_width=weight_bit_width))
         self.linear_features.append(TensorNorm())
 
         for m in self.modules():
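
The input quantizer swaps QuantHardTanh for QuantIdentity but keeps the same Q1.7 input range. A quick sanity check of the min_val/max_val constants above, assuming the 8-bit signed interpretation implied by the Q1.7 comment:

# 1 sign bit + 7 fractional bits: the LSB step is 2**-7, and the scale stays
# a power of two, matching restrict_scaling_type=RestrictValueType.POWER_OF_TWO.
scale = 2.0 ** -7
int_min, int_max = -128, 127  # narrow_range=False keeps the full signed range
assert int_min * scale == -1.0               # min_val
assert int_max * scale == 1.0 - 2.0 ** (-7)  # max_val
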
diff --git a/brevitas_examples/bnn_pynq/models/SFC.py b/brevitas_examples/bnn_pynq/models/FC.py
similarity index 54%
rename from brevitas_examples/bnn_pynq/models/SFC.py
rename to brevitas_examples/bnn_pynq/models/FC.py
index 83cf53918..848f9e26f 100644
--- a/brevitas_examples/bnn_pynq/models/SFC.py
+++ b/brevitas_examples/bnn_pynq/models/FC.py
@@ -20,50 +20,53 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+import ast
 from functools import reduce
 from operator import mul
 
 from torch.nn import Module, ModuleList, BatchNorm1d, Dropout
 import torch
 
-from .common import get_quant_linear, get_act_quant, get_quant_type, QuantLinear
+from brevitas.nn import QuantIdentity, QuantLinear
+from .common import CommonWeightQuant, CommonActQuant
 
-FC_OUT_FEATURES = [256, 256, 256]
-INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
-LAST_FC_PER_OUT_CH_SCALING = False
-IN_DROPOUT = 0.2
-HIDDEN_DROPOUT = 0.2
+DROPOUT = 0.2
 
 
-class SFC(Module):
+class FC(Module):
 
-    def __init__(self, num_classes=10, weight_bit_width=None, act_bit_width=None,
-                 in_bit_width=None, in_ch=1, in_features=(28, 28)):
-        super(SFC, self).__init__()
-
-        weight_quant_type = get_quant_type(weight_bit_width)
-        act_quant_type = get_quant_type(act_bit_width)
-        in_quant_type = get_quant_type(in_bit_width)
+    def __init__(
+            self,
+            num_classes,
+            weight_bit_width,
+            act_bit_width,
+            in_bit_width,
+            in_channels,
+            out_features,
+            in_features=(28, 28)):
+        super(FC, self).__init__()
 
         self.features = ModuleList()
-        self.features.append(get_act_quant(in_bit_width, in_quant_type))
-        self.features.append(Dropout(p=IN_DROPOUT))
+        self.features.append(QuantIdentity(act_quant=CommonActQuant, bit_width=in_bit_width))
+        self.features.append(Dropout(p=DROPOUT))
         in_features = reduce(mul, in_features)
-        for out_features in FC_OUT_FEATURES:
-            self.features.append(get_quant_linear(in_features=in_features,
-                                                  out_features=out_features,
-                                                  per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
-                                                  bit_width=weight_bit_width,
-                                                  quant_type=weight_quant_type))
+        for out_features in out_features:
+            self.features.append(QuantLinear(
+                in_features=in_features,
+                out_features=out_features,
+                bias=False,
+                weight_bit_width=weight_bit_width,
+                weight_quant=CommonWeightQuant))
             in_features = out_features
             self.features.append(BatchNorm1d(num_features=in_features))
-            self.features.append(get_act_quant(act_bit_width, act_quant_type))
-            self.features.append(Dropout(p=HIDDEN_DROPOUT))
-        self.features.append(get_quant_linear(in_features=in_features,
-                                              out_features=num_classes,
-                                              per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
-                                              bit_width=weight_bit_width,
-                                              quant_type=weight_quant_type))
+            self.features.append(QuantIdentity(act_quant=CommonActQuant, bit_width=act_bit_width))
+            self.features.append(Dropout(p=DROPOUT))
+        self.features.append(QuantLinear(
+            in_features=in_features,
+            out_features=num_classes,
+            bias=False,
+            weight_bit_width=weight_bit_width,
+            weight_quant=CommonWeightQuant))
         self.features.append(BatchNorm1d(num_features=num_classes))
 
         for m in self.modules():
@@ -83,15 +86,18 @@ def forward(self, x):
         return x
 
 
-def sfc(cfg):
+def fc(cfg):
     weight_bit_width = cfg.getint('QUANT', 'WEIGHT_BIT_WIDTH')
     act_bit_width = cfg.getint('QUANT', 'ACT_BIT_WIDTH')
     in_bit_width = cfg.getint('QUANT', 'IN_BIT_WIDTH')
     num_classes = cfg.getint('MODEL', 'NUM_CLASSES')
     in_channels = cfg.getint('MODEL', 'IN_CHANNELS')
-    net = SFC(weight_bit_width=weight_bit_width,
-              act_bit_width=act_bit_width,
-              in_bit_width=in_bit_width,
-              num_classes=num_classes,
-              in_ch=in_channels)
+    out_features = ast.literal_eval(cfg.get('MODEL', 'OUT_FEATURES'))
+    net = FC(
+        weight_bit_width=weight_bit_width,
+        act_bit_width=act_bit_width,
+        in_bit_width=in_bit_width,
+        in_channels=in_channels,
+        out_features=out_features,
+        num_classes=num_classes)
     return net
\ No newline at end of file
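
With OUT_FEATURES lifted into the cfg, the deleted LFC.py and TFC.py below (and the renamed SFC.py) collapse into FC instances that differ only in layer widths. A sketch, with illustrative 1-bit placeholders where the real values come from each cfg's [QUANT] section:

from brevitas_examples.bnn_pynq.models.FC import FC

# Bit widths below are placeholders for illustration only.
common = dict(num_classes=10, weight_bit_width=1, act_bit_width=1,
              in_bit_width=1, in_channels=1)
tfc = FC(out_features=[64, 64, 64], **common)        # formerly TFC.py
sfc = FC(out_features=[256, 256, 256], **common)     # formerly SFC.py
lfc = FC(out_features=[1024, 1024, 1024], **common)  # formerly LFC.py
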
diff --git a/brevitas_examples/bnn_pynq/models/LFC.py b/brevitas_examples/bnn_pynq/models/LFC.py
deleted file mode 100644
index d55c80eee..000000000
--- a/brevitas_examples/bnn_pynq/models/LFC.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# MIT License
-#
-# Copyright (c) 2019 Xilinx
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-from functools import reduce
-from operator import mul
-
-from torch.nn import Module, ModuleList, BatchNorm1d, Dropout
-import torch
-
-from .common import get_quant_linear, get_act_quant, get_quant_type, QuantLinear
-
-FC_OUT_FEATURES = [1024, 1024, 1024]
-INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
-LAST_FC_PER_OUT_CH_SCALING = False
-IN_DROPOUT = 0.2
-HIDDEN_DROPOUT = 0.2
-
-
-class LFC(Module):
-
-    def __init__(self, num_classes=10, weight_bit_width=None, act_bit_width=None,
-                 in_bit_width=None, in_ch=1, in_features=(28, 28)):
-        super(LFC, self).__init__()
-
-        weight_quant_type = get_quant_type(weight_bit_width)
-        act_quant_type = get_quant_type(act_bit_width)
-        in_quant_type = get_quant_type(in_bit_width)
-
-        self.features = ModuleList()
-        self.features.append(get_act_quant(in_bit_width, in_quant_type))
-        self.features.append(Dropout(p=IN_DROPOUT))
-        in_features = reduce(mul, in_features)
-        for out_features in FC_OUT_FEATURES:
-            self.features.append(get_quant_linear(in_features=in_features,
-                                                  out_features=out_features,
-                                                  per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
-                                                  bit_width=weight_bit_width,
-                                                  quant_type=weight_quant_type))
-            in_features = out_features
-            self.features.append(BatchNorm1d(num_features=in_features))
-            self.features.append(get_act_quant(act_bit_width, act_quant_type))
-            self.features.append(Dropout(p=HIDDEN_DROPOUT))
-        self.features.append(get_quant_linear(in_features=in_features,
-                                              out_features=num_classes,
-                                              per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
-                                              bit_width=weight_bit_width,
-                                              quant_type=weight_quant_type))
-        self.features.append(BatchNorm1d(num_features=num_classes))
-
-        for m in self.modules():
-            if isinstance(m, QuantLinear):
-                torch.nn.init.uniform_(m.weight.data, -1, 1)
-
-    def clip_weights(self, min_val, max_val):
-        for mod in self.features:
-            if isinstance(mod, QuantLinear):
-                mod.weight.data.clamp_(min_val, max_val)
-
-    def forward(self, x):
-        x = x.view(x.shape[0], -1)
-        x = 2.0 * x - torch.tensor([1.0], device=x.device)
-        for mod in self.features:
-            x = mod(x)
-        return x
-
-
-def lfc(cfg):
-    weight_bit_width = cfg.getint('QUANT', 'WEIGHT_BIT_WIDTH')
-    act_bit_width = cfg.getint('QUANT', 'ACT_BIT_WIDTH')
-    in_bit_width = cfg.getint('QUANT', 'IN_BIT_WIDTH')
-    num_classes = cfg.getint('MODEL', 'NUM_CLASSES')
-    in_channels = cfg.getint('MODEL', 'IN_CHANNELS')
-    net = LFC(weight_bit_width=weight_bit_width,
-              act_bit_width=act_bit_width,
-              in_bit_width=in_bit_width,
-              num_classes=num_classes,
-              in_ch=in_channels)
-    return net
\ No newline at end of file
diff --git a/brevitas_examples/bnn_pynq/models/TFC.py b/brevitas_examples/bnn_pynq/models/TFC.py
deleted file mode 100644
index be5d33f75..000000000
--- a/brevitas_examples/bnn_pynq/models/TFC.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# MIT License
-#
-# Copyright (c) 2019 Xilinx
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-from functools import reduce
-from operator import mul
-
-from torch.nn import Module, ModuleList, BatchNorm1d, Dropout
-import torch
-
-from .common import get_quant_linear, get_act_quant, get_quant_type, QuantLinear
-
-FC_OUT_FEATURES = [64, 64, 64]
-INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
-LAST_FC_PER_OUT_CH_SCALING = False
-IN_DROPOUT = 0.2
-HIDDEN_DROPOUT = 0.2
-
-
-class TFC(Module):
-
-    def __init__(self, num_classes=10, weight_bit_width=None, act_bit_width=None,
-                 in_bit_width=None, in_ch=1, in_features=(28, 28)):
-        super(TFC, self).__init__()
-
-        weight_quant_type = get_quant_type(weight_bit_width)
-        act_quant_type = get_quant_type(act_bit_width)
-        in_quant_type = get_quant_type(in_bit_width)
-
-        self.features = ModuleList()
-        self.features.append(get_act_quant(in_bit_width, in_quant_type))
-        self.features.append(Dropout(p=IN_DROPOUT))
-        in_features = reduce(mul, in_features)
-        for out_features in FC_OUT_FEATURES:
-            self.features.append(get_quant_linear(in_features=in_features,
-                                                  out_features=out_features,
-                                                  per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
-                                                  bit_width=weight_bit_width,
-                                                  quant_type=weight_quant_type))
-            in_features = out_features
-            self.features.append(BatchNorm1d(num_features=in_features))
-            self.features.append(get_act_quant(act_bit_width, act_quant_type))
-            self.features.append(Dropout(p=HIDDEN_DROPOUT))
-        self.features.append(get_quant_linear(in_features=in_features,
-                                              out_features=num_classes,
-                                              per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
-                                              bit_width=weight_bit_width,
-                                              quant_type=weight_quant_type))
-        self.features.append(BatchNorm1d(num_features=num_classes))
-
-        for m in self.modules():
-            if isinstance(m, QuantLinear):
-                torch.nn.init.uniform_(m.weight.data, -1, 1)
-
-    def clip_weights(self, min_val, max_val):
-        for mod in self.features:
-            if isinstance(mod, QuantLinear):
-                mod.weight.data.clamp_(min_val, max_val)
-
-    def forward(self, x):
-        x = x.view(x.shape[0], -1)
-        x = 2.0 * x - torch.tensor([1.0], device=x.device)
-        for mod in self.features:
-            x = mod(x)
-        return x
-
-
-def tfc(cfg):
-    weight_bit_width = cfg.getint('QUANT', 'WEIGHT_BIT_WIDTH')
-    act_bit_width = cfg.getint('QUANT', 'ACT_BIT_WIDTH')
-    in_bit_width = cfg.getint('QUANT', 'IN_BIT_WIDTH')
-    num_classes = cfg.getint('MODEL', 'NUM_CLASSES')
-    in_channels = cfg.getint('MODEL', 'IN_CHANNELS')
-    net = TFC(weight_bit_width=weight_bit_width,
-              act_bit_width=act_bit_width,
-              in_bit_width=in_bit_width,
-              num_classes=num_classes,
-              in_ch=in_channels)
-    return net
diff --git a/brevitas_examples/bnn_pynq/models/__init__.py b/brevitas_examples/bnn_pynq/models/__init__.py
index b849edd8e..5f231a924 100644
--- a/brevitas_examples/bnn_pynq/models/__init__.py
+++ b/brevitas_examples/bnn_pynq/models/__init__.py
@@ -33,15 +33,11 @@
            'model_with_cfg']
 
 from .CNV import cnv
-from .LFC import lfc
-from .TFC import tfc
-from .SFC import sfc
+from .FC import fc
 
 model_impl = {
     'CNV': cnv,
-    'LFC': lfc,
-    'TFC': tfc,
-    'SFC': sfc
+    'FC': fc,
 }
 
 def get_model_cfg(name):
diff --git a/brevitas_examples/bnn_pynq/models/common.py b/brevitas_examples/bnn_pynq/models/common.py
index 3813786db..b652e7813 100644
--- a/brevitas_examples/bnn_pynq/models/common.py
+++ b/brevitas_examples/bnn_pynq/models/common.py
@@ -20,12 +20,45 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+from dependencies import Injector, value
+
 from brevitas.core.bit_width import BitWidthImplType
 from brevitas.core.quant import QuantType
 from brevitas.core.restrict_val import RestrictValueType
 from brevitas.core.scaling import ScalingImplType
+
+
+class CommonQuant(Injector):
+    bit_width_impl_type = BitWidthImplType.CONST
+    scaling_impl_type = ScalingImplType.CONST
+    restrict_scaling_type = RestrictValueType.FP
+    scaling_per_output_channel = False
+    narrow_range = True
+    signed = True
+
+    @value
+    def quant_type(bit_width):
+        if bit_width is None:
+            return QuantType.FP
+        elif bit_width == 1:
+            return QuantType.BINARY
+        else:
+            return QuantType.INT
+
+
+class CommonWeightQuant(CommonQuant):
+    scaling_const = 1.0
+
+
+class CommonActQuant(CommonQuant):
+    min_val = -1.0
+    max_val = 1.0
+
+
 from brevitas.nn import QuantConv2d, QuantHardTanh, QuantLinear
 
 # Quant common
 BIT_WIDTH_IMPL_TYPE = BitWidthImplType.CONST
 SCALING_VALUE_TYPE = RestrictValueType.LOG_FP
@@ -56,7 +89,7 @@ def get_quant_type(bit_width):
     return QuantType.INT
 
 
-def get_act_quant(act_bit_width, act_quant_type): 
+def get_act_quant(act_bit_width, act_quant_type):
     return QuantHardTanh(quant_type=act_quant_type,
                          bit_width=act_bit_width,
                          bit_width_impl_type=BIT_WIDTH_IMPL_TYPE,
@@ -94,3 +127,8 @@ def get_quant_conv2d(in_ch, out_ch, bit_width, quant_type):
         weight_restrict_scaling_type=SCALING_VALUE_TYPE,
         weight_bit_width_impl_type=BIT_WIDTH_IMPL_TYPE,
         bias=BIAS_ENABLED)
+
+
+
+
+
diff --git a/brevitas_examples/bnn_pynq/trainer.py b/brevitas_examples/bnn_pynq/trainer.py
index 5941dcd02..23c96f54a 100644
--- a/brevitas_examples/bnn_pynq/trainer.py
+++ b/brevitas_examples/bnn_pynq/trainer.py
@@ -49,7 +49,7 @@ def accuracy(output, target, topk=(1,)):
 
     res = []
     for k in topk:
-        correct_k = correct[:k].view(-1).float().sum(0)
+        correct_k = correct[:k].flatten().float().sum(0)
         res.append(correct_k.mul_(100.0 / batch_size))
     return res
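
The trainer.py change replaces view(-1) with flatten() in accuracy(). A small sketch of why: view requires the flattened layout to be expressible without copying, which a transposed (hence non-contiguous) tensor like `correct` typically is not, while flatten() falls back to a copy:

import torch

x = torch.arange(6).reshape(2, 3).t()  # the transpose makes x non-contiguous
assert not x.is_contiguous()
# x.view(-1) would raise a RuntimeError here; flatten() copies instead.
assert torch.equal(x.flatten(), torch.tensor([0, 3, 1, 4, 2, 5]))
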
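The common.py rewrite above is the "new D.I. APIs" part of the subject line: quantizer configuration becomes partial Injector classes that the layers complete at construction time (for example via weight_bit_width or bit_width). A minimal resolution sketch, assuming the lazy-injection semantics of the `dependencies` package used above; BinaryWeightQuant is a hypothetical name introduced only for illustration:

from brevitas.core.quant import QuantType
from brevitas_examples.bnn_pynq.models.common import CommonWeightQuant

class BinaryWeightQuant(CommonWeightQuant):
    bit_width = 1  # normally supplied by the layer through weight_bit_width

# Accessing the attribute triggers injection: quant_type(bit_width=1).
assert BinaryWeightQuant.quant_type == QuantType.BINARY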