Commit 7f12633

Examples (BNN-PYNQ): idiomatic and simplified implementation w.r.t. new D.I. APIs

Signed-off-by: Alessandro Pappalardo <[email protected]>
volcacius committed Sep 8, 2020
1 parent 5fbb725 commit 7f12633
Showing 15 changed files with 137 additions and 284 deletions.
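The changes below move the BNN-PYNQ example models from the old helper-based quantization setup (get_quant_type, get_quant_conv2d, get_quant_linear, get_act_quant, QuantHardTanh) to quantizers injected directly into Brevitas layers through the new dependency-injection API: QuantConv2d, QuantLinear and QuantIdentity now receive weight_quant=CommonWeightQuant or act_quant=CommonActQuant. The sketch below condenses the new style from the CNV.py diff further down; the layer sizes and bit widths are illustrative, and CommonWeightQuant/CommonActQuant live in the examples' common.py, which is not part of this excerpt.

# Hedged sketch of the injected-quantizer style this commit adopts.
# Layer sizes and bit widths are illustrative, not taken from a specific config.
import torch
from brevitas.nn import QuantConv2d, QuantIdentity
from brevitas_examples.bnn_pynq.models.common import CommonWeightQuant, CommonActQuant

conv = QuantConv2d(
    in_channels=3, out_channels=64, kernel_size=3, bias=False,
    weight_quant=CommonWeightQuant, weight_bit_width=1)
act = QuantIdentity(act_quant=CommonActQuant, bit_width=1)
y = act(conv(torch.randn(1, 3, 32, 32)))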
3 changes: 2 additions & 1 deletion brevitas_examples/bnn_pynq/cfg/lfc_1w1a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: LFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/lfc_1w1a-db6e13bd.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/lfc_1w1a_eval-db6e13bd.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES = [1024, 1024, 1024]
 NUM_CLASSES: 10
 
 [QUANT]
3 changes: 2 additions & 1 deletion brevitas_examples/bnn_pynq/cfg/lfc_1w2a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: LFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/lfc_1w2a-0a771c67.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/lfc_1w2a_eval-0a771c67.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES = [1024, 1024, 1024]
 NUM_CLASSES: 10
 
 [QUANT]
3 changes: 2 additions & 1 deletion brevitas_examples/bnn_pynq/cfg/sfc_1w1a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: SFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_1w1a-fd8a6c3d.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_1w1a_eval-fd8a6c3d.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES = [256, 256, 256]
 NUM_CLASSES: 10
 
 [QUANT]
3 changes: 2 additions & 1 deletion brevitas_examples/bnn_pynq/cfg/sfc_1w2a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: SFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_1w2a-fdc0c779.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_1w2a_eval-fdc0c779.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES = [256, 256, 256]
 NUM_CLASSES: 10
 
 [QUANT]
3 changes: 2 additions & 1 deletion brevitas_examples/bnn_pynq/cfg/sfc_2w2a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: SFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_2w2a-35a7c41d.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/sfc_2w2a_eval-35a7c41d.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES = [256, 256, 256]
 NUM_CLASSES: 10
 
 [QUANT]
3 changes: 2 additions & 1 deletion brevitas_examples/bnn_pynq/cfg/tfc_1w1a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: TFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_1w1a-ff8140dc.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_1w1a_eval-ff8140dc.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES = [64, 64, 64]
 NUM_CLASSES: 10
 
 [QUANT]
3 changes: 2 additions & 1 deletion brevitas_examples/bnn_pynq/cfg/tfc_1w2a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: TFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_1w2a-95ad635b.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_1w2a_eval-95ad635b.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES = [64, 64, 64]
 NUM_CLASSES: 10
 
 [QUANT]
3 changes: 2 additions & 1 deletion brevitas_examples/bnn_pynq/cfg/tfc_2w2a.ini
@@ -1,9 +1,10 @@
 [MODEL]
-ARCH: TFC
+ARCH: FC
 PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_2w2a-7e0a62f1.pth
 EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/tfc_2w2a_eval-7e0a62f1.txt
 DATASET: MNIST
 IN_CHANNELS: 1
+OUT_FEATURES = [64, 64, 64]
 NUM_CLASSES: 10
 
 [QUANT]
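The new OUT_FEATURES entry is what lets a single FC architecture description replace the separate LFC/SFC/TFC variants: each .ini now carries its own layer widths. Below is a short, self-contained sketch of how such an entry round-trips through configparser (which accepts both ':' and '=' as delimiters) and ast.literal_eval, mirroring what the fc() builder further down does; the inline config string is illustrative.

# Illustrative parsing of an OUT_FEATURES entry like the ones added above.
import ast
import configparser

cfg = configparser.ConfigParser()
cfg.read_string("""
[MODEL]
ARCH: FC
DATASET: MNIST
IN_CHANNELS: 1
OUT_FEATURES = [64, 64, 64]
NUM_CLASSES: 10
""")

out_features = ast.literal_eval(cfg.get('MODEL', 'OUT_FEATURES'))
print(out_features)                        # [64, 64, 64]
print(cfg.getint('MODEL', 'NUM_CLASSES'))  # 10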
77 changes: 38 additions & 39 deletions brevitas_examples/bnn_pynq/models/CNV.py
@@ -22,72 +22,71 @@

import torch
from torch.nn import Module, ModuleList, BatchNorm2d, MaxPool2d, BatchNorm1d
from .tensor_norm import TensorNorm
from .common import get_quant_conv2d, get_quant_linear, get_act_quant, get_quant_type
from brevitas.nn import QuantConv2d, QuantHardTanh, QuantLinear

from brevitas.nn import QuantConv2d, QuantIdentity, QuantLinear
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
from .tensor_norm import TensorNorm
from .common import CommonWeightQuant, CommonActQuant


# QuantConv2d configuration
CNV_OUT_CH_POOL = [(64, False), (64, True), (128, False), (128, True), (256, False), (256, False)]

# Intermediate QuantLinear configuration
INTERMEDIATE_FC_PER_OUT_CH_SCALING = False
INTERMEDIATE_FC_FEATURES = [(256, 512), (512, 512)]

# Last QuantLinear configuration
LAST_FC_IN_FEATURES = 512
LAST_FC_PER_OUT_CH_SCALING = False

# MaxPool2d configuration
POOL_SIZE = 2
KERNEL_SIZE = 3


class CNV(Module):

def __init__(self, num_classes=10, weight_bit_width=None, act_bit_width=None, in_bit_width=None, in_ch=3):
def __init__(self, num_classes, weight_bit_width, act_bit_width, in_bit_width, in_ch):
super(CNV, self).__init__()

weight_quant_type = get_quant_type(weight_bit_width)
act_quant_type = get_quant_type(act_bit_width)
in_quant_type = get_quant_type(in_bit_width)
max_in_val = 1-2**(-7) # for Q1.7 input format
self.conv_features = ModuleList()
self.linear_features = ModuleList()

self.conv_features.append(QuantHardTanh(bit_width=in_bit_width,
quant_type=in_quant_type,
max_val=max_in_val,
restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
scaling_impl_type=ScalingImplType.CONST))
self.conv_features.append(QuantIdentity( # for Q1.7 input format
act_quant=CommonActQuant,
bit_width=in_bit_width,
min_val=- 1.0,
max_val=1.0 - 2.0 ** (-7),
narrow_range=False,
restrict_scaling_type=RestrictValueType.POWER_OF_TWO))

for out_ch, is_pool_enabled in CNV_OUT_CH_POOL:
self.conv_features.append(get_quant_conv2d(in_ch=in_ch,
out_ch=out_ch,
bit_width=weight_bit_width,
quant_type=weight_quant_type))
self.conv_features.append(QuantConv2d(
kernel_size=KERNEL_SIZE,
in_channels=in_ch,
out_channels=out_ch,
bias=False,
weight_quant=CommonWeightQuant,
weight_bit_width=weight_bit_width))
in_ch = out_ch
self.conv_features.append(BatchNorm2d(in_ch, eps=1e-4))
self.conv_features.append(get_act_quant(act_bit_width, act_quant_type))
self.conv_features.append(QuantIdentity(
act_quant=CommonActQuant,
bit_width=act_bit_width))
if is_pool_enabled:
self.conv_features.append(MaxPool2d(kernel_size=2))

for in_features, out_features in INTERMEDIATE_FC_FEATURES:
self.linear_features.append(get_quant_linear(in_features=in_features,
out_features=out_features,
per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
bit_width=weight_bit_width,
quant_type=weight_quant_type))
self.linear_features.append(QuantLinear(
in_features=in_features,
out_features=out_features,
bias=False,
weight_quant=CommonWeightQuant,
weight_bit_width=weight_bit_width))
self.linear_features.append(BatchNorm1d(out_features, eps=1e-4))
self.linear_features.append(get_act_quant(act_bit_width, act_quant_type))

self.linear_features.append(get_quant_linear(in_features=LAST_FC_IN_FEATURES,
out_features=num_classes,
per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
bit_width=weight_bit_width,
quant_type=weight_quant_type))
self.linear_features.append(QuantIdentity(
act_quant=CommonActQuant,
bit_width=act_bit_width))

self.linear_features.append(QuantLinear(
in_features=LAST_FC_IN_FEATURES,
out_features=num_classes,
bias=False,
weight_quant=CommonWeightQuant,
weight_bit_width=weight_bit_width))
self.linear_features.append(TensorNorm())

for m in self.modules():
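After this refactor, CNV takes every hyperparameter explicitly (the keyword defaults were dropped). Below is a hedged smoke test under the assumption of a CIFAR-10-sized 32x32 RGB input, which is what the BNN-PYNQ CNV topology targets; the cnv(cfg) builder and the forward() method fall outside this excerpt, and the bit widths used here are illustrative. The next diff applies the same treatment to the fully connected model, folding the former SFC class into a single parameterized FC class.

# Hedged smoke test for the refactored CNV (constructor signature taken from the
# diff above; the 32x32 RGB input and the 1-bit/8-bit widths are assumptions).
import torch
from brevitas_examples.bnn_pynq.models.CNV import CNV

model = CNV(num_classes=10, weight_bit_width=1, act_bit_width=1, in_bit_width=8, in_ch=3)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 32, 32))
print(out.shape)  # expected: torch.Size([1, 10])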
@@ -20,50 +20,53 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import ast
from functools import reduce
from operator import mul

from torch.nn import Module, ModuleList, BatchNorm1d, Dropout
import torch

from .common import get_quant_linear, get_act_quant, get_quant_type, QuantLinear
from brevitas.nn import QuantIdentity, QuantLinear
from .common import CommonWeightQuant, CommonActQuant

FC_OUT_FEATURES = [256, 256, 256]
INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
LAST_FC_PER_OUT_CH_SCALING = False
IN_DROPOUT = 0.2
HIDDEN_DROPOUT = 0.2
DROPOUT = 0.2


class SFC(Module):
class FC(Module):

def __init__(self, num_classes=10, weight_bit_width=None, act_bit_width=None,
in_bit_width=None, in_ch=1, in_features=(28, 28)):
super(SFC, self).__init__()

weight_quant_type = get_quant_type(weight_bit_width)
act_quant_type = get_quant_type(act_bit_width)
in_quant_type = get_quant_type(in_bit_width)
def __init__(
self,
num_classes,
weight_bit_width,
act_bit_width,
in_bit_width,
in_channels,
out_features,
in_features=(28, 28)):
super(FC, self).__init__()

self.features = ModuleList()
self.features.append(get_act_quant(in_bit_width, in_quant_type))
self.features.append(Dropout(p=IN_DROPOUT))
self.features.append(QuantIdentity(act_quant=CommonActQuant, bit_width=in_bit_width))
self.features.append(Dropout(p=DROPOUT))
in_features = reduce(mul, in_features)
for out_features in FC_OUT_FEATURES:
self.features.append(get_quant_linear(in_features=in_features,
out_features=out_features,
per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
bit_width=weight_bit_width,
quant_type=weight_quant_type))
for out_features in out_features:
self.features.append(QuantLinear(
in_features=in_features,
out_features=out_features,
bias=False,
weight_bit_width=weight_bit_width,
weight_quant=CommonWeightQuant))
in_features = out_features
self.features.append(BatchNorm1d(num_features=in_features))
self.features.append(get_act_quant(act_bit_width, act_quant_type))
self.features.append(Dropout(p=HIDDEN_DROPOUT))
self.features.append(get_quant_linear(in_features=in_features,
out_features=num_classes,
per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
bit_width=weight_bit_width,
quant_type=weight_quant_type))
self.features.append(QuantIdentity(act_quant=CommonActQuant, bit_width=act_bit_width))
self.features.append(Dropout(p=DROPOUT))
self.features.append(QuantLinear(
in_features=in_features,
out_features=num_classes,
bias=False,
weight_bit_width=weight_bit_width,
weight_quant=CommonWeightQuant))
self.features.append(BatchNorm1d(num_features=num_classes))

for m in self.modules():
@@ -83,15 +86,18 @@ def forward(self, x):
return x


def sfc(cfg):
def fc(cfg):
weight_bit_width = cfg.getint('QUANT', 'WEIGHT_BIT_WIDTH')
act_bit_width = cfg.getint('QUANT', 'ACT_BIT_WIDTH')
in_bit_width = cfg.getint('QUANT', 'IN_BIT_WIDTH')
num_classes = cfg.getint('MODEL', 'NUM_CLASSES')
in_channels = cfg.getint('MODEL', 'IN_CHANNELS')
net = SFC(weight_bit_width=weight_bit_width,
act_bit_width=act_bit_width,
in_bit_width=in_bit_width,
num_classes=num_classes,
in_ch=in_channels)
out_features = ast.literal_eval(cfg.get('MODEL', 'OUT_FEATURES'))
net = FC(
weight_bit_width=weight_bit_width,
act_bit_width=act_bit_width,
in_bit_width=in_bit_width,
in_channels=in_channels,
out_features=out_features,
num_classes=num_classes)
return net
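A hedged usage sketch tying the updated .ini files to the new fc() builder: the import path and the on-disk config path are assumptions, while the MODEL and QUANT keys are exactly the ones fc() reads above.

# Hypothetical end-to-end use of the new builder with one of the updated configs.
import configparser
from brevitas_examples.bnn_pynq.models.FC import fc  # module path assumed

cfg = configparser.ConfigParser()
cfg.read('brevitas_examples/bnn_pynq/cfg/tfc_1w1a.ini')  # path assumed
model = fc(cfg)  # builds FC(..., out_features=[64, 64, 64]) for the TFC config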