-
Notifications
You must be signed in to change notification settings - Fork 202
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Alessandro Pappalardo <[email protected]>
- Loading branch information
Showing
128 changed files
with
3,595 additions
and
10,418 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,3 @@ | ||
Components: pytorch | ||
|
||
|
||
/* | ||
Copyright (c) 2018- Xilinx, Inc (Alessandro Pappalardo) | ||
Copyright (c) 2016- Facebook, Inc (Adam Paszke) | ||
Copyright (c) 2014- Facebook, Inc (Soumith Chintala) | ||
|
@@ -25,9 +21,9 @@ modification, are permitted provided that the following conditions are met: | |
notice, this list of conditions and the following disclaimer in the | ||
documentation and/or other materials provided with the distribution. | ||
|
||
3. Neither the names of Xilinx, Facebook, Deepmind Technologies, NYU, | ||
NEC Laboratories America and IDIAP Research Institute nor the names | ||
of its contributors may be used to endorse or promote products derived | ||
3. Neither the names of Xilinx, Facebook, Deepmind Technologies, NYU, | ||
NEC Laboratories America and IDIAP Research Institute nor the names | ||
of its contributors may be used to endorse or promote products derived | ||
from this software without specific prior written permission. | ||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
|
@@ -40,46 +36,4 @@ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||
POSSIBILITY OF SUCH DAMAGE. | ||
*/ | ||
|
||
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||
## Mostly based on the Pytorch-Encoding source code, due MIT copyright below | ||
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||
## Created by: Hang Zhang | ||
## ECE Department, Rutgers University | ||
## Email: [email protected] | ||
## Copyright (c) 2017 | ||
## | ||
## This source code is licensed under the MIT-style license found below | ||
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||
# MIT License | ||
|
||
# Copyright (c) 2017 Hang Zhang | ||
|
||
# Permission is hereby granted, free of charge, to any person obtaining a copy | ||
# of this software and associated documentation files (the "Software"), to deal | ||
# in the Software without restriction, including without limitation the rights | ||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
# copies of the Software, and to permit persons to whom the Software is | ||
# furnished to do so, subject to the following conditions: | ||
|
||
# The above copyright notice and this permission notice shall be included in | ||
# all copies or substantial portions of the Software. | ||
|
||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
# SOFTWARE. | ||
|
||
|
||
Copyright 2012-2014 Deepmind Technologies | ||
Copyright 2018 Xilinx Inc | ||
Copyright 2017 Hang Zhang | ||
Copyright 2011-2013 NYU | ||
Copyright 2014, 2016 Facebook Inc | ||
Copyright 2001-2014 Idiap Research Institute | ||
Copyright 2006-2012 NEC Laboratories America | ||
POSSIBILITY OF SUCH DAMAGE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,40 @@ | ||
# Pytorch Quantization | ||
# Brevitas | ||
|
||
## Introduction | ||
|
||
This repository implements a set of quantization strategies to be applied to supported type of layers. | ||
Brevitas is a Pytorch library for training-aware quantization. | ||
|
||
The code originally started from the Pytorch and ATen implementation of a fused GRU/LSTM, extracted as a CFFI extension and expanded from there. | ||
*Brevitas is currently in alpha stage and under active development. APIs might and probably will change. Documentation, examples, and pretrained models will be progressively released.* | ||
|
||
## Requirements | ||
Building currently requires an appropriate CUDA environment, but execution is supported on CPU as well. | ||
* [Pytorch](https://pytorch.org) >= 1.1.0 | ||
|
||
* Nvidia CUDA Toolkit (tested with CUDA 9.0) | ||
* [Pytorch](https://pytorch.org) (tested with version 0.3.1) | ||
## Introduction | ||
|
||
## Installation | ||
Brevitas implements a set of building blocks to model a reduced precision hardware data-path at training time. | ||
While partially biased towards modelling dataflow-style, very low-precision implementations, the building blocks can be parametrized and assembled together to target all sorts of reduced precision hardware. | ||
|
||
1. Run `python build.py` | ||
2. Add current path to the python path: `EXPORT PYTHONPATH=/path/to/pytorch-quantization:PYTHONPATH` | ||
The implementations tries to adhere to the following design principles: | ||
- Idiomatic Pytorch, when possible. | ||
- Modularity first, at the cost of some verbosity. | ||
- Easily extendible. | ||
|
||
## Usage | ||
## Target audience | ||
Brevitas is mainly targeted at researchers and practicioners in the fields of training for reduced precision inference. | ||
|
||
The following quantization modes are implemented for weights: | ||
The implementation is quite rich in options and allows for very fine grained control over the trained model. However, compared to other software solutions in this space, the burden of correctly modelling the target data-path is currently placed on the user. | ||
|
||
* FP: full-precision, no quantization performed. | ||
* SIGNED_FIXED_UNIT: fixed point quantization between [-1,1). | ||
## Features | ||
|
||
The following quantization modes are implemented for activations: | ||
Soon. | ||
|
||
## Installation | ||
|
||
Soon. | ||
|
||
## Usage | ||
Soon. | ||
|
||
* FP: full-precision, no quantization performed. | ||
* SIGNED_FIXED_UNIT: fixed point quantization between [-1,1). | ||
## Author | ||
|
||
The following quantized layers are implemented: | ||
Alessandro Pappalardo @ Xilinx Research Labs. | ||
|
||
* QuantizedLinear | ||
* QuantizedLSTM | ||
## |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
ZERO_HW_SENTINEL_NAME = 'zero_hw_sentinel' | ||
ZERO_HW_SENTINEL_VALUE = 0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,244 @@ | ||
# Copyright (c) 2018- Xilinx, Inc (Alessandro Pappalardo) | ||
# Copyright (c) 2016- Facebook, Inc (Adam Paszke) | ||
# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) | ||
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) | ||
# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) | ||
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) | ||
# Copyright (c) 2011-2013 NYU (Clement Farabet) | ||
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) | ||
# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) | ||
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) | ||
|
||
# All rights reserved. | ||
|
||
# Redistribution and use in source and binary forms, with or without | ||
# modification, are permitted provided that the following conditions are met: | ||
|
||
# 1. Redistributions of source code must retain the above copyright | ||
# notice, this list of conditions and the following disclaimer. | ||
|
||
# 2. Redistributions in binary form must reproduce the above copyright | ||
# notice, this list of conditions and the following disclaimer in the | ||
# documentation and/or other materials provided with the distribution. | ||
|
||
# 3. Neither the names of Xilinx, Facebook, Deepmind Technologies, NYU, | ||
# NEC Laboratories America and IDIAP Research Institute nor the names | ||
# of its contributors may be used to endorse or promote products derived | ||
# from this software without specific prior written permission. | ||
|
||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE | ||
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||
# POSSIBILITY OF SUCH DAMAGE. | ||
|
||
from typing import Optional | ||
from enum import auto | ||
|
||
import torch | ||
from torch import Tensor | ||
from torch.nn import Parameter | ||
|
||
from brevitas.utils.python_utils import AutoName | ||
from brevitas.function.ops import tensor_clamp_ste, tensor_clamp | ||
from .restrict_val import RestrictValueOpImplType, RestrictValueType, RestrictValue, FloatToIntImplType | ||
|
||
|
||
MIN_INT_BIT_WIDTH = 2 | ||
NON_ZERO_EPSILON = 1e-6 | ||
REMOVE_ZERO_BIT_WIDTH = 0.1 | ||
|
||
|
||
class BitWidthImplType(AutoName): | ||
CONST = auto() | ||
PARAMETER = auto() | ||
|
||
|
||
class ZeroLsbTruncBitWidth(torch.jit.ScriptModule): | ||
|
||
def forward(self, input_bit_width: Tensor, zero_hw_sentinel: Tensor): | ||
return zero_hw_sentinel | ||
|
||
|
||
class BitWidthConst(torch.jit.ScriptModule): | ||
__constants__ = ['bit_width'] | ||
|
||
def __init__(self, bit_width_init: int, restrict_bit_width_type: RestrictValueType) -> None: | ||
super(BitWidthConst, self).__init__() | ||
|
||
if restrict_bit_width_type != RestrictValueType.INT: | ||
raise Exception("When bit width is predefined, it has to be an INT value.") | ||
|
||
self.bit_width = int(bit_width_init) | ||
|
||
@torch.jit.script_method | ||
def forward(self, zero_hw_sentinel: Tensor) -> Tensor: | ||
return self.bit_width + zero_hw_sentinel | ||
|
||
|
||
class BitWidthParameter(torch.jit.ScriptModule): | ||
__constants__ = ['bit_width_base', 'max_bit_width', 'override_pretrained'] | ||
|
||
def __init__(self, | ||
bit_width_init: int, | ||
min_overall_bit_width: Optional[int], | ||
max_overall_bit_width: Optional[int], | ||
restrict_bit_width_type: RestrictValueType, | ||
override_pretrained: bool) -> None: | ||
super(BitWidthParameter, self).__init__() | ||
|
||
if min_overall_bit_width is None: | ||
min_overall_bit_width = MIN_INT_BIT_WIDTH | ||
if not (restrict_bit_width_type == RestrictValueType.FP | ||
or restrict_bit_width_type == RestrictValueType.INT | ||
or restrict_bit_width_type == RestrictValueType.POWER_OF_TWO): | ||
raise Exception("Restriction on bit width {} not supported".format(restrict_bit_width_type)) | ||
if bit_width_init < MIN_INT_BIT_WIDTH or min_overall_bit_width < MIN_INT_BIT_WIDTH: | ||
raise Exception("Int bit width has to be at least {}, instead is {}." | ||
.format(MIN_INT_BIT_WIDTH, bit_width_init)) | ||
|
||
self.override_pretrained = override_pretrained | ||
bit_width_init_op = RestrictValue.restrict_value_op(restrict_bit_width_type, | ||
restrict_value_op_impl_type=RestrictValueOpImplType.MATH) | ||
self.restrict_bit_width = RestrictValue(restrict_bit_width_type, | ||
float_to_int_impl_type=FloatToIntImplType.ROUND) | ||
self.bit_width_base = bit_width_init_op(min_overall_bit_width) | ||
self.max_bit_width = bit_width_init_op(min_overall_bit_width) if max_overall_bit_width is not None else None | ||
bit_width_offset_init = bit_width_init_op(bit_width_init) - self.bit_width_base | ||
self.bit_width_offset = Parameter(torch.tensor(float(bit_width_offset_init))) | ||
|
||
@torch.jit.script_method | ||
def forward(self, zero_hw_sentinel: Tensor) -> Tensor: | ||
if self.max_bit_width is not None: | ||
raise Exception("Not implemented yet.") | ||
bit_width = torch.abs(self.bit_width_offset) + self.bit_width_base | ||
bit_width = self.restrict_bit_width(bit_width) | ||
return bit_width | ||
|
||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, | ||
missing_keys, unexpected_keys, error_msgs): | ||
super(BitWidthParameter, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict, | ||
missing_keys, unexpected_keys, error_msgs) | ||
bit_width_offset_key = prefix + 'bit_width_offset' | ||
if bit_width_offset_key in missing_keys: | ||
missing_keys.remove(bit_width_offset_key) | ||
if self.override_pretrained and bit_width_offset_key in state_dict: | ||
del state_dict[bit_width_offset_key] | ||
|
||
|
||
class RemoveBitwidthParameter(torch.jit.ScriptModule): | ||
__constants__ = ['min_overall_bit_width', 'non_zero_epsilon', 'override_pretrained', 'remove_at_least_init_val'] | ||
|
||
def __init__(self, bit_width_to_remove, remove_at_least_init_val, restrict_bit_width_impl, override_pretrained): | ||
super(RemoveBitwidthParameter, self).__init__() | ||
|
||
if bit_width_to_remove < 0: | ||
raise Exception("Bit width to clamp has to be at least 0, instead is {}." | ||
.format(bit_width_to_remove)) | ||
elif bit_width_to_remove == 0: | ||
bit_width_coeff_init = 1 / REMOVE_ZERO_BIT_WIDTH | ||
else: | ||
bit_width_coeff_init = 1 / bit_width_to_remove | ||
self.bit_width_coeff = Parameter(torch.tensor(bit_width_coeff_init)) | ||
self.restrict_bit_width_impl = restrict_bit_width_impl | ||
self.non_zero_epsilon = NON_ZERO_EPSILON | ||
self.override_pretrained = override_pretrained | ||
self.remove_at_least_init_val = remove_at_least_init_val | ||
|
||
@torch.jit.script_method | ||
def forward(self, zero_hw_sentinel) -> Tensor: | ||
bit_width_to_remove = 1.0 / (self.non_zero_epsilon + torch.abs(self.bit_width_coeff)) | ||
bit_width_to_remove = self.restrict_bit_width_impl(bit_width_to_remove) | ||
return bit_width_to_remove | ||
|
||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, | ||
missing_keys, unexpected_keys, error_msgs): | ||
super(RemoveBitwidthParameter, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict, | ||
missing_keys, unexpected_keys, error_msgs) | ||
bit_width_coeff_key = prefix + 'bit_width_coeff' | ||
if bit_width_coeff_key in missing_keys: | ||
missing_keys.remove(bit_width_coeff_key) | ||
if self.override_pretrained and bit_width_coeff_key in state_dict: | ||
del state_dict[bit_width_coeff_key] | ||
|
||
|
||
class MsbClampParameterBitWidth(torch.jit.ScriptModule): | ||
__constants__ = ['min_overall_bit_width', 'max_overall_bit_width'] | ||
|
||
def __init__(self, | ||
ms_bit_width_to_clamp: int, | ||
clamp_at_least_init_val: bool, | ||
min_overall_bit_width: int, | ||
max_overall_bit_width: int, | ||
bit_width_impl_type: BitWidthImplType, | ||
override_pretrained: bool) -> None: | ||
super(MsbClampParameterBitWidth, self).__init__() | ||
|
||
self.min_overall_bit_width = min_overall_bit_width | ||
self.max_overall_bit_width = max_overall_bit_width | ||
|
||
if bit_width_impl_type == BitWidthImplType.CONST: | ||
self.bit_width_to_remove_impl = BitWidthConst(ms_bit_width_to_clamp, RestrictValueType.INT) | ||
elif bit_width_impl_type == BitWidthImplType.PARAMETER: | ||
restrict_bit_width_impl = RestrictValue(RestrictValueType.INT, | ||
float_to_int_impl_type=FloatToIntImplType.ROUND) | ||
self.bit_width_to_remove_impl = RemoveBitwidthParameter(bit_width_to_remove=ms_bit_width_to_clamp, | ||
remove_at_least_init_val=clamp_at_least_init_val, | ||
restrict_bit_width_impl=restrict_bit_width_impl, | ||
override_pretrained=override_pretrained) | ||
else: | ||
raise Exception("Bit width implementation type {} not recognized for clamping accumulator." | ||
.format(bit_width_impl_type)) | ||
|
||
@torch.jit.script_method | ||
def forward(self, input_bit_width: Tensor, zero_hw_sentinel: Tensor) -> Tensor: | ||
bit_width_to_remove = self.bit_width_to_remove_impl(zero_hw_sentinel) | ||
output_bit_width = torch.abs(input_bit_width - bit_width_to_remove) | ||
output_bit_width = tensor_clamp_ste(output_bit_width, | ||
self.min_overall_bit_width + zero_hw_sentinel, | ||
self.max_overall_bit_width + zero_hw_sentinel) | ||
return output_bit_width | ||
|
||
|
||
class LsbTruncParameterBitWidth(torch.jit.ScriptModule): | ||
__constants__ = ['is_const', 'min_overall_bit_width', 'max_overall_bit_width'] | ||
|
||
def __init__(self, | ||
ls_bit_width_to_trunc: int, | ||
trunc_at_least_init_val: bool, | ||
min_overall_bit_width: int, | ||
max_overall_bit_width: int, | ||
bit_width_impl_type: BitWidthImplType, | ||
override_pretrained: bool): | ||
super(LsbTruncParameterBitWidth, self).__init__() | ||
|
||
self.min_overall_bit_width = min_overall_bit_width | ||
self.max_overall_bit_width = max_overall_bit_width | ||
|
||
if bit_width_impl_type == BitWidthImplType.CONST: | ||
self.bit_width_to_remove_impl = BitWidthConst(ls_bit_width_to_trunc, RestrictValueType.INT) | ||
elif bit_width_impl_type == BitWidthImplType.PARAMETER: | ||
restrict_bit_width_impl = RestrictValue(RestrictValueType.INT, | ||
float_to_int_impl_type=FloatToIntImplType.ROUND) | ||
self.bit_width_to_remove_impl = RemoveBitwidthParameter(bit_width_to_remove=ls_bit_width_to_trunc, | ||
remove_at_least_init_val=trunc_at_least_init_val, | ||
restrict_bit_width_impl=restrict_bit_width_impl, | ||
override_pretrained=override_pretrained) | ||
else: | ||
raise Exception("Bit width implementation type {} not recognized for truncating accumulator." | ||
.format(bit_width_impl_type)) | ||
|
||
@torch.jit.script_method | ||
def forward(self, input_bit_width: Tensor, zero_hw_sentinel: Tensor) -> Tensor: | ||
bit_width_to_remove = self.bit_width_to_remove_impl(zero_hw_sentinel) | ||
max_bit_width_to_remove = input_bit_width - self.min_overall_bit_width | ||
bit_width_to_remove = torch.where(bit_width_to_remove > max_bit_width_to_remove, | ||
max_bit_width_to_remove, | ||
bit_width_to_remove) | ||
return bit_width_to_remove |
Oops, something went wrong.