diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index 328ad63b3..9bae3a42a 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -1,15 +1,19 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from torch import Tensor from torch.nn import Module import brevitas +from brevitas.core.function_wrapper import TensorClamp from brevitas.core.quant.delay import DelayWrapper +from brevitas.core.scaling import TruncMsbScaling from brevitas.core.utils import StatelessBuffer +from brevitas.function.ops import max_int +from brevitas.function.ops import min_int from brevitas.function.ops_ste import round_ste @@ -201,28 +205,56 @@ class TruncIntQuant(brevitas.jit.ScriptModule): """ """ + __constants__ = ['narrow_range'] + def __init__( - self, float_to_int_impl: Module, bit_width_impl: Module, quant_delay_steps: int = 0): + self, + float_to_int_impl: Module, + bit_width_impl: Module, + trunc_scaling_impl: Module = TruncMsbScaling(), + narrow_range: bool = False, + tensor_clamp_impl: Module = TensorClamp(), + quant_delay_steps: int = 0): super(TruncIntQuant, self).__init__() + self.narrow_range = narrow_range self.msb_clamp_bit_width_impl = bit_width_impl + self.trunc_scaling_impl = trunc_scaling_impl self.float_to_int_impl = float_to_int_impl + self.tensor_clamp_impl = tensor_clamp_impl self.delay_wrapper = DelayWrapper(quant_delay_steps) @brevitas.jit.script_method - def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor, - input_bit_width: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def min_int(self, bit_width: Tensor, signed: Union[bool, Tensor]): + return min_int(signed, self.narrow_range, bit_width) + + @brevitas.jit.script_method + def max_int(self, bit_width: Tensor, signed: Union[bool, Tensor]): + return max_int(signed, self.narrow_range, bit_width) + + @brevitas.jit.script_method + def forward( + self, + x: Tensor, + scale: Tensor, + zero_point: Tensor, + input_bit_width: Tensor, + signed: Union[bool, Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor]: y = x / scale y = y + zero_point y = round_ste(y) # clean up floating point error output_bit_width = self.msb_clamp_bit_width_impl() - trunc_bit_width = input_bit_width - output_bit_width - trunc_scale = 2.0 ** trunc_bit_width + trunc_scale = self.trunc_scaling_impl(y, input_bit_width, output_bit_width, signed) y = y / trunc_scale + min_int_val = self.min_int(output_bit_width, signed) + max_int_val = self.max_int(output_bit_width, signed) y = self.float_to_int_impl(y) - y = y - zero_point - y = y * scale + y = self.tensor_clamp_impl(y, min_val=min_int_val, max_val=max_int_val) + output_scale = scale * trunc_scale + output_zero_point = zero_point / trunc_scale + y = y - output_zero_point + y = y * output_scale y = self.delay_wrapper(x, y) - return y, scale, zero_point, output_bit_width + return y, output_scale, output_zero_point, output_bit_width class DecoupledRescalingIntQuantWithInput(DecoupledRescalingIntQuant): diff --git a/src/brevitas/core/scaling/__init__.py b/src/brevitas/core/scaling/__init__.py index c21cb4b27..d8a786c8b 100644 --- a/src/brevitas/core/scaling/__init__.py +++ b/src/brevitas/core/scaling/__init__.py @@ -18,5 +18,7 @@ from .standalone import ParameterFromRuntimeStatsScaling from .standalone import ParameterFromStatsFromParameterScaling from .standalone import ParameterScaling +from .standalone 
import TruncMsbScaling +from .standalone import TruncScalingWrapper SCALING_STATS_REDUCE_DIM = 1 diff --git a/src/brevitas/core/scaling/int_scaling.py b/src/brevitas/core/scaling/int_scaling.py index 2c58db598..100515571 100644 --- a/src/brevitas/core/scaling/int_scaling.py +++ b/src/brevitas/core/scaling/int_scaling.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from typing import Optional, Union + from torch import Tensor import brevitas @@ -11,26 +13,30 @@ class IntScaling(brevitas.jit.ScriptModule): __constants__ = ['signed', 'narrow_range'] - def __init__(self, signed: bool, narrow_range: bool): + def __init__(self, narrow_range: bool, signed: Optional[bool] = None): super(IntScaling, self).__init__() self.signed = signed self.narrow_range = narrow_range @brevitas.jit.script_method - def forward(self, bit_width: Tensor) -> Tensor: - if self.signed: - return -min_int(self.signed, self.narrow_range, bit_width) + def forward(self, bit_width: Tensor, signed: Optional[Union[bool, Tensor]] = None) -> Tensor: + is_signed = signed if signed is not None else self.signed + assert is_signed is not None, f"signed is not defined, signed={is_signed}" + if is_signed: + return -min_int(is_signed, self.narrow_range, bit_width) else: - return max_int(self.signed, self.narrow_range, bit_width) + return max_int(is_signed, self.narrow_range, bit_width) class PowerOfTwoIntScaling(brevitas.jit.ScriptModule): __constants__ = ['signed'] - def __init__(self, signed: bool): + def __init__(self, signed: Optional[bool] = None): super(PowerOfTwoIntScaling, self).__init__() self.signed = signed @brevitas.jit.script_method - def forward(self, bit_width: Tensor) -> Tensor: - return max_int(self.signed, False, bit_width) + 1 + def forward(self, bit_width: Tensor, signed: Optional[Union[bool, Tensor]] = None) -> Tensor: + is_signed = signed if signed is not None else self.signed + assert is_signed is not None, f"signed is not defined, signed={is_signed}" + return max_int(is_signed, False, bit_width) + 1 diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index 703fed5a4..8be2b2cd4 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -12,6 +12,7 @@ import brevitas.config as config from brevitas.core.function_wrapper import Identity from brevitas.core.function_wrapper import OverBatchOverTensorView +from brevitas.core.function_wrapper import TensorClamp from brevitas.core.restrict_val import _ClampValue from brevitas.core.restrict_val import _RestrictClampValue from brevitas.core.restrict_val import _RestrictValue @@ -469,3 +470,51 @@ def _load_from_state_dict( self.counter = self.collect_stats_steps + 1 if config.IGNORE_MISSING_KEYS and value_key in missing_keys: missing_keys.remove(value_key) + + +class TruncMsbScaling(brevitas.jit.ScriptModule): + """ + """ + + def __init__(self) -> None: + super(TruncMsbScaling, self).__init__() + + @brevitas.jit.script_method + def forward( + self, + scaling_input: Tensor, + input_bit_width: Tensor, + output_bit_width: Tensor, + signed: Union[bool, Tensor]) -> Tensor: + return 2 ** (input_bit_width - output_bit_width) + + +class TruncScalingWrapper(brevitas.jit.ScriptModule): + """ + """ + + def __init__( + self, + trunc_int_scaling_impl: Module, + scaling_impl: Module, + tensor_clamp_impl: Module = TensorClamp()) -> None: + super(TruncScalingWrapper, self).__init__() + self.trunc_int_scaling_impl = 
trunc_int_scaling_impl + self.scaling_impl = scaling_impl + self.tensor_clamp_impl = tensor_clamp_impl + + @brevitas.jit.script_method + def forward( + self, + scaling_input: Tensor, + input_bit_width: Tensor, + output_bit_width: Tensor, + signed: Union[bool, Tensor]) -> Tensor: + threshold = self.trunc_int_scaling_impl(output_bit_width, signed) + scale = self.scaling_impl(scaling_input, threshold) + msb_scale = 2 ** (input_bit_width - output_bit_width) + unit_scale = torch.ones_like(msb_scale) + max_scale = torch.where(msb_scale > unit_scale, msb_scale, unit_scale) + min_scale = torch.where(msb_scale < unit_scale, msb_scale, unit_scale) + trunc_scale = self.tensor_clamp_impl(scale, min_scale, max_scale) + return trunc_scale diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 39347baad..2bb37af89 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -773,18 +773,26 @@ class QCDQCastTruncQuantProxyHandlerMixin(QuantAxisMixin, ABC): handled_layer = TruncQuantProxyFromInjector + def validate(self, module): + assert module.zero_point() == 0, "Zero-point export not supported for TruncQuant." + super(QCDQCastTruncQuantProxyHandlerMixin, self).validate(module) + def prepare_for_export(self, module: TruncQuantProxyFromInjector): if module.is_quant_enabled: self.validate(module) - self.symbolic_kwargs = {'output_bit_width': module.bit_width()} + self.symbolic_kwargs = { + 'narrow_range': module.is_narrow_range, + 'output_scale': module.scale(), + 'output_bit_width': module.bit_width()} def symbolic_execution( self, x: Tensor, scale: Tensor, zero_point: Tensor, input_bit_width: Tensor, signed: Tensor): assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled' output_bit_width = self.symbolic_kwargs['output_bit_width'] + narrow_range = self.symbolic_kwargs['narrow_range'] dtype = self.int8_dtype() if signed else self.uint8_dtype() - trunc_scale = 2.0 ** (input_bit_width - output_bit_width) + scale = self.symbolic_kwargs['output_scale'] # Input scale is ignored now # If original dtype of scale is (b)float16, store the original scale dtype # and cast the scale and the input to float32 scale_dtype = scale.dtype @@ -792,19 +800,17 @@ def symbolic_execution( scale = self.cast_fn(scale, torch.float32) if x.dtype == torch.bfloat16 or x.dtype == torch.float16: x = self.cast_fn(x, torch.float32) - pre_scale = scale * trunc_scale - flat_pre_scale = to_0dim_if_scalar(pre_scale.flatten()) flat_scale = to_0dim_if_scalar(scale.flatten()) zp = to_0dim_if_scalar(zero_point.flatten()).expand_as(flat_scale) zp = self.zero_point_with_dtype(signed, output_bit_width, zp) - x = self.quantize_fn(x, flat_pre_scale, zp, dtype, self.quant_axis(pre_scale)) + x = self.quantize_fn(x, flat_scale, zp, dtype, self.quant_axis(scale)) clip_symbolic_kwargs = self.int_clip_symbolic_kwargs( - signed=signed, narrow=False, bit_width=output_bit_width) + signed=signed, narrow=self.symbolic_kwargs['narrow_range'], bit_width=output_bit_width) if clip_symbolic_kwargs is not None: x = self.clip_fn(x, *clip_symbolic_kwargs.values()) x = self.dequantize_fn(x, flat_scale, zp, self.quant_axis(scale)) # After dequantization, cast both output and scale to the correct dtype if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16: x = self.cast_fn(x, scale_dtype) - scale = self.cast_fn(scale, scale_dtype) - return x, scale, zero_point, output_bit_width + flat_scale = self.cast_fn(flat_scale, scale_dtype) 
+ return x, flat_scale, zero_point, output_bit_width diff --git a/src/brevitas/export/onnx/qonnx/function.py b/src/brevitas/export/onnx/qonnx/function.py index 5160572ef..0fbeba8f6 100644 --- a/src/brevitas/export/onnx/qonnx/function.py +++ b/src/brevitas/export/onnx/qonnx/function.py @@ -111,26 +111,48 @@ def forward( class BrevitasTruncFn(Function): @staticmethod - def symbolic(g, x, scale, zero_point, input_bit_width, output_bit_width, rounding_mode): - ret = g.op( - f'{DOMAIN_STRING}::Trunc', + def symbolic( + g, x, scale, zero_point, input_bit_width, + signed, + narrow_range, + output_scale, + output_bit_width, + rounding_mode): + ret = g.op( + f'{DOMAIN_STRING}::Quant', + x, + output_scale, + zero_point, output_bit_width, - rounding_mode_s=rounding_mode) + rounding_mode_s=rounding_mode, + signed_i=int(signed), + narrow_i=int(narrow_range)) ret.setType(x.type()) return ret @staticmethod - def forward(ctx, x, scale, zero_point, input_bit_width, output_bit_width, rounding_mode): - float_to_int_impl = solve_float_to_int_impl_from_enum(rounding_mode) - trunc = TruncIntQuant( - float_to_int_impl=float_to_int_impl(), - bit_width_impl=BitWidthConst(int(output_bit_width))) - y_tuple = trunc(x, scale, zero_point, input_bit_width) - return y_tuple[0] + def forward( + ctx, + x, + scale, + zero_point, + input_bit_width, + signed, + narrow_range, + output_scale, + output_bit_width, + rounding_mode): + # TODO: Restore this (fails when `signed` arg added) + #float_to_int_impl = solve_float_to_int_impl_from_enum(rounding_mode) + #trunc = TruncIntQuant( + # float_to_int_impl=float_to_int_impl(), + # bit_width_impl=BitWidthConst(int(output_bit_width))) + #y_tuple = trunc(x, scale, zero_point, input_bit_width, signed) + return x class BrevitasQuantLSTMCellFn(Function): diff --git a/src/brevitas/export/onnx/qonnx/handler.py b/src/brevitas/export/onnx/qonnx/handler.py index 5468cd1aa..1c1489ce2 100644 --- a/src/brevitas/export/onnx/qonnx/handler.py +++ b/src/brevitas/export/onnx/qonnx/handler.py @@ -227,16 +227,23 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None): class BrevitasTruncQuantProxyHandler(ONNXBaseHandler): handled_layer = TruncQuantProxyFromInjector + def validate(self, module): + assert module.zero_point() == 0, "Zero-point export not supported for TruncQuant." + def prepare_for_export(self, module: TruncQuantProxyFromInjector): + self.validate(module) self.symbolic_kwargs = { - 'output_bit_width': module.bit_width(), 'rounding_mode': module.rounding_mode} + 'narrow_range': module.is_narrow_range, + 'output_scale': module.scale(), + 'output_bit_width': module.bit_width(), + 'rounding_mode': module.rounding_mode} def symbolic_execution( self, x: Tensor, scale: Tensor, zero_point: Tensor, input_bit_width: Tensor, signed: Tensor): y = BrevitasTruncFn.apply( - x, scale, zero_point, input_bit_width, *self.symbolic_kwargs.values()) - return y, scale, zero_point, self.symbolic_kwargs['output_bit_width'] + x, scale, zero_point, input_bit_width, signed, *self.symbolic_kwargs.values()) + return y, self.symbolic_kwargs['output_scale'], zero_point, self.symbolic_kwargs['output_bit_width'] class BrevitasQuantLSTMLayerHandler(QuantLSTMLayerHandler): diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index 74da08e19..a87462ce6 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -6,6 +6,8 @@ The implemented functions adheres to the restriction imposed by Pytorch 1.1.0's TorchScript compiler. 
""" +from typing import Union + import torch from torch import Tensor @@ -128,7 +130,9 @@ def identity(x: Tensor) -> Tensor: @brevitas.jit.script -def max_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor: +def max_int( + signed: Union[bool, Tensor], narrow_range: Union[bool, Tensor], + bit_width: Tensor) -> Tensor: """ Compute the maximum integer representable by a given number of bits. Args: @@ -159,7 +163,9 @@ def max_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor: @brevitas.jit.script -def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor: +def min_int( + signed: Union[bool, Tensor], narrow_range: Union[bool, Tensor], + bit_width: Tensor) -> Tensor: """ Compute the minimum integer representable by a given number of bits. Args: diff --git a/src/brevitas/inject/enum.py b/src/brevitas/inject/enum.py index fbac29176..c0990d717 100644 --- a/src/brevitas/inject/enum.py +++ b/src/brevitas/inject/enum.py @@ -93,3 +93,11 @@ class StatsOp(AutoName): # Typically adopted for asymmetric quantization MIN_MAX = auto() PERCENTILE_INTERVAL = auto() + + +class TruncScalingImplType(AutoName): + """ + + """ + MSB = auto() + WRAPPER = auto() diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index 12f324901..795bd0cf5 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -33,10 +33,19 @@ def __init__( self, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = None, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: Optional[int] = None, trunc_quant: Optional[AccQuantType] = RoundTo8bit, return_quant_tensor: bool = True, **kwargs): - AvgPool2d.__init__(self, kernel_size=kernel_size, stride=stride) + AvgPool2d.__init__( + self, + kernel_size=kernel_size, + stride=stride, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + divisor_override=divisor_override) QuantLayerMixin.__init__(self, return_quant_tensor) TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs) self.cache_inference_quant_act = False @@ -58,6 +67,7 @@ def _avg_scaling(self): else: return self.kernel_size * self.kernel_size + # TODO: Replace with functional call def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) @@ -71,8 +81,6 @@ def forward(self, input: Union[Tensor, QuantTensor]): if not isinstance(x, QuantTensor): x = self.cache_class.quant_tensor.set(value=x) y = AvgPool2d.forward(self, x) - rescaled_value = y.value * self._avg_scaling - y = y.set(value=rescaled_value) y = self.trunc_quant(y) else: y = AvgPool2d.forward(self, _unpack_quant_tensor(x)) @@ -123,6 +131,7 @@ def compute_kernel_size_stride(input_shape, output_shape): stride_list.append(stride) return kernel_size_list, stride_list + # TODO: Replace with functional call def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) @@ -139,10 +148,6 @@ def forward(self, input: Union[Tensor, QuantTensor]): if not isinstance(x, QuantTensor): x = self.cache_class.quant_tensor.set(value=x) y = AdaptiveAvgPool2d.forward(self, x) - k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:]) - reduce_size = reduce(mul, k_size, 1) - rescaled_value = y.value * reduce_size # remove avg scaling - y = y.set(value=rescaled_value) y = self.trunc_quant(y) else: y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x)) diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 
3e7248602..12dcf9528 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -60,7 +60,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor: @runtime_checkable class AccQuantProxyProtocol(QuantProxyProtocol, Protocol): - def forward(self, x: QuantTensor) -> QuantTensor: + def forward(self, x: Union[Tensor, IntQuantTensor]) -> Union[Tensor, IntQuantTensor]: ... @@ -262,16 +262,31 @@ class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self._cached_act = None + self.cache_inference_quant_act = True + self.cache_quant_io_metadata_only = True + self.cache_class = _CachedIO self.skip_create_quant_tensor = False - def bit_width(self): - if not self.is_quant_enabled: + def retrieve_attribute(self, attribute): + if self._cached_act is not None: + return getattr(self._cached_act, attribute) + elif self._cached_act is None: return None - zhs = self._zero_hw_sentinel() - # Signed might or might not be defined. We just care about retrieving the bitwidth - empty_imp = IntQuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training) - bit_width = self.__call__(empty_imp).bit_width - return bit_width + + @property + def is_narrow_range(self): + narrow_range = super(TruncQuantProxyFromInjector, self).is_narrow_range + return narrow_range if narrow_range is not None else False + + def scale(self): + return self.retrieve_attribute('scale') + + def zero_point(self): + return self.retrieve_attribute('zero_point') + + def bit_width(self): + return self.retrieve_attribute('bit_width') def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: if self.is_quant_enabled: @@ -279,12 +294,16 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: out_tuple = self.export_handler( x.value, x.scale, x.zero_point, x.bit_width, x.signed) else: - out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width) + out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width, x.signed) out_value, out_scale, out_zp, out_bit_width = out_tuple if self.skip_create_quant_tensor: return out_value - return IntQuantTensor( + out = IntQuantTensor( out_value, out_scale, out_zp, out_bit_width, x.signed, self.training) + if not self.training and self.cache_inference_quant_act: + cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only) + self._cached_act = cached_out + return out else: return x diff --git a/src/brevitas/quant/scaled_int.py b/src/brevitas/quant/scaled_int.py index b5f9174d7..cbe681f5e 100644 --- a/src/brevitas/quant/scaled_int.py +++ b/src/brevitas/quant/scaled_int.py @@ -28,6 +28,7 @@ 'Int8WeightPerChannelFloatMSE', 'TruncTo8bit', 'RoundTo8bit', + 'ShiftRoundSaturateTo8bit', 'Int4WeightPerTensorFloatDecoupled', 'Int8WeightPerChannelFloatDecoupled', 'Uint8ActPerTensorFloatBatchQuant1d', @@ -261,7 +262,7 @@ class Uint8ActPerTensorFloatMSE(MSESymmetricScale, Uint8ActPerTensorFloat): class TruncTo8bit(TruncQuantSolver): """ - 8-bit signed int truncator that preserves the input scale factor and zero-point. + 8-bit int truncator that preserves most-significant bits and zero-point. 
Examples: >>> from brevitas.nn import TruncAvgPool2d >>> pool = TruncAvgPool2d(kernel_size=(3, 3), trunc_quant=TruncTo8bit) """ quant_type = 'int' bit_width_impl_type = 'const' float_to_int_impl_type = 'floor' + trunc_scaling_impl_type = 'msb' class RoundTo8bit(TruncQuantSolver): """ - 8-bit signed int truncator with rounding that preserves the input scale factor and zero-point. + 8-bit int truncator with rounding that preserves most-significant bits and zero-point. Examples: >>> from brevitas.nn import TruncAvgPool2d @@ -285,6 +287,25 @@ class RoundTo8bit(TruncQuantSolver): quant_type = 'int' bit_width_impl_type = 'const' float_to_int_impl_type = 'round' + trunc_scaling_impl_type = 'msb' + + +class ShiftRoundSaturateTo8bit(TruncQuantSolver, + ParamFromRuntimePercentileScaling, + PerTensorPoTScaling8bit): + """ + 8-bit shift-round-saturate quantizer which uses statistics to calculate the amount of truncation + of the least-significant bits and most-significant bits. Zero-point is preserved. + + Examples: + >>> from brevitas.nn import TruncAvgPool2d + >>> pool = TruncAvgPool2d(kernel_size=(3, 3), trunc_quant=ShiftRoundSaturateTo8bit) + """ + bit_width = 8 + quant_type = 'int' + bit_width_impl_type = 'const' + float_to_int_impl_type = 'round' + trunc_scaling_impl_type = 'wrapper' class Int4WeightPerTensorFloatDecoupled(WeightPerTensorFloatDecoupledL2Param): diff --git a/src/brevitas/quant/solver/trunc.py b/src/brevitas/quant/solver/trunc.py index 9203e2f11..c167826e8 100644 --- a/src/brevitas/quant/solver/trunc.py +++ b/src/brevitas/quant/solver/trunc.py @@ -2,12 +2,17 @@ # SPDX-License-Identifier: BSD-3-Clause from brevitas.core.quant import TruncIntQuant +from brevitas.core.scaling import PowerOfTwoIntScaling +from brevitas.core.scaling import TruncMsbScaling +from brevitas.core.scaling import TruncScalingWrapper from brevitas.inject import ExtendedInjector from brevitas.inject import value from brevitas.inject.enum import QuantType +from brevitas.inject.enum import RestrictValueType +from brevitas.inject.enum import TruncScalingImplType from brevitas.proxy import TruncQuantProxyFromInjector -from brevitas.quant.solver.common import SolveBitWidthImplFromEnum -from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum +from brevitas.quant.solver.act import * +from brevitas.quant.solver.common import * class SolveTruncTensorQuantFromEnum(ExtendedInjector): @@ -26,9 +31,44 @@ def tensor_quant(quant_type): raise RuntimeError(f'{quant_type} not recognized.') +class SolveTruncScalingImplFromEnum(ExtendedInjector): + + @value + def trunc_scaling_impl(trunc_scaling_impl_type="msb"): + if trunc_scaling_impl_type == TruncScalingImplType.MSB: + return TruncMsbScaling + elif trunc_scaling_impl_type == TruncScalingImplType.WRAPPER: + return TruncScalingWrapper + else: + raise RuntimeError(f'trunc_scaling_impl_type={trunc_scaling_impl_type} not recognized.') + + +class SolveTruncIntScalingImplFromEnum(ExtendedInjector): + + @value + def trunc_int_scaling_impl(restrict_scaling_type): + if restrict_scaling_type == RestrictValueType.POWER_OF_TWO: + return PowerOfTwoIntScaling + else: + raise RuntimeError(f'restrict_scaling_type={restrict_scaling_type} not recognized.') + + class TruncQuantSolver(SolveBitWidthImplFromEnum, SolveTensorQuantFloatToIntImplFromEnum, - SolveTruncTensorQuantFromEnum): + SolveActScalingImplFromEnum, + SolveIntScalingImplFromEnum, + SolveScalingStatsOpFromEnum, + SolveRestrictScalingImplFromEnum, + SolveActScalingInitFromEnum, + SolveStatsReduceDimFromEnum, + SolveActScalingShape,
+ SolveScalingStatsInputViewShapeImplFromEnum, + SolveActScalingPerOutputChannelShape, + SolveUpdateStateDictImplFromEnum, + SolveInputViewImpl, + SolveTruncTensorQuantFromEnum, + SolveTruncScalingImplFromEnum, + SolveTruncIntScalingImplFromEnum): """ Translate enum directives to truncation-specific quantization core modules. It should be placed last in the list of classes a quantizer inherits from, diff --git a/src/brevitas/quant_tensor/int_torch_handler.py b/src/brevitas/quant_tensor/int_torch_handler.py index 3258b8914..a2e6572da 100644 --- a/src/brevitas/quant_tensor/int_torch_handler.py +++ b/src/brevitas/quant_tensor/int_torch_handler.py @@ -109,12 +109,15 @@ def avg_pool2d_handler( max_acc_bit_width = FN_ACC_BITWIDTH_MAPPING[F.avg_pool2d] # remove avg scaling - if isinstance(kernel_size, tuple): + if divisor_override is not None: + avg_scaling = divisor_override + elif isinstance(kernel_size, tuple): avg_scaling = kernel_size[0] * kernel_size[1] else: avg_scaling = kernel_size * kernel_size quant_input = quant_input.set(value=x) + quant_input = quant_input.set(scale=quant_input.scale / avg_scaling) quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, avg_scaling)) return quant_input @@ -134,6 +137,7 @@ def adaptive_avg_pool2d_handler(quant_input, output_shape): reduce_size = reduce(mul, k_size, 1) quant_input = quant_input.set(value=x) + quant_input = quant_input.set(scale=quant_input.scale / reduce_size) quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, reduce_size)) return quant_input diff --git a/src/brevitas_examples/imagenet_classification/qat/README.md b/src/brevitas_examples/imagenet_classification/qat/README.md index 92e318c05..5d95a3c55 100644 --- a/src/brevitas_examples/imagenet_classification/qat/README.md +++ b/src/brevitas_examples/imagenet_classification/qat/README.md @@ -7,7 +7,7 @@ Below in the table is a list of example pretrained models made available for ref | Name | Cfg | Scaling Type | First layer weights | Weights | Activations | Avg pool | Top1 | Pretrained model | Retrained from | |--------------|-----------------------|----------------------------|---------------------|---------|-------------|----------|-------|-------------------------------------------------------------------------------------------------|---------------------------------------------------------------| -| MobileNet V1 | quant_mobilenet_v1_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 70.95 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_mobilenet_v1_4b-r1/quant_mobilenet_v1_4b-0100a667.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) | +| MobileNet V1 | quant_mobilenet_v1_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 70.86 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_mobilenet_v1_4b-r1/quant_mobilenet_v1_4b-0100a667.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) | | ProxylessNAS Mobile14 w/ Hadamard classifier | quant_proxylessnas_mobile14_hadamard_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 72.87 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_hadamard_4b-r0/quant_proxylessnas_mobile14_hadamard_4b-4acbfa9f.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) | | ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 74.39 | 
[Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b-r0/quant_proxylessnas_mobile14_4b-e10882e1.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) | | ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b5b | Floating-point per channel | 8 bit | 4 bit, 5 bit | 4 bit, 5 bit | 4 bit | 74.94 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b5b-r0/quant_proxylessnas_mobile14_4b5b-2bdf7f8d.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) | diff --git a/tests/brevitas/core/test_trunc_int_quant.py b/tests/brevitas/core/test_trunc_int_quant.py new file mode 100644 index 000000000..4eafcc819 --- /dev/null +++ b/tests/brevitas/core/test_trunc_int_quant.py @@ -0,0 +1,525 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import logging + +import numpy as np +import pytest_cases +import torch + +from brevitas.core.bit_width import BitWidthConst +from brevitas.core.function_wrapper import Identity +from brevitas.core.function_wrapper import RoundSte +from brevitas.core.function_wrapper import TensorClamp +from brevitas.core.quant import TruncIntQuant +from brevitas.core.restrict_val import PowerOfTwoRestrictValue +from brevitas.core.scaling import PowerOfTwoIntScaling +from brevitas.core.scaling import RuntimeStatsScaling +from brevitas.core.scaling import TruncMsbScaling +from brevitas.core.scaling import TruncScalingWrapper +from brevitas.core.stats import AbsMax +from tests.brevitas.core.bit_width_fixture import * # noqa +from tests.brevitas.core.int_quant_fixture import * # noqa + + +def allexact(x, y): + return np.allclose(x, y, rtol=0.0, atol=0.0, equal_nan=False) + + +class TestTruncIntQuantUnit: + + def test_trunc_int_quant_defaults(self, bit_width_const): + trunc_int_quant = TruncIntQuant( + bit_width_impl=bit_width_const, float_to_int_impl=RoundSte()) + assert isinstance(trunc_int_quant.tensor_clamp_impl, TensorClamp) + assert isinstance(trunc_int_quant.trunc_scaling_impl, TruncMsbScaling) + assert trunc_int_quant.narrow_range == False + + # yapf: disable + @pytest_cases.fixture + @pytest_cases.parametrize( + "test_cfg", + [ + { # defaults_uint_overflow + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([255.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "eval_args": { + "x": torch.tensor([255.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([240.]), + "scale": torch.tensor([16.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # defaults_int+_overflow + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([127.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([127.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([112.]), + "scale": torch.tensor([16.]), + "zero_point":
torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # defaults_int+_overflow_zp + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([1727.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([-1600.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([1727.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([-1600.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([1712.]), + "scale": torch.tensor([16.]), + "zero_point": torch.tensor([-100.]), + "bit_width": torch.tensor([4.]), + }, + }, { # defaults_int-_max + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([-128.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([-128.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([-128.]), + "scale": torch.tensor([16.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # defaults_uint_underflow + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([8.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "eval_args": { + "x": torch.tensor([8.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([0.]), + "scale": torch.tensor([16.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # defaults_int_underflow + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([-8.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([-8.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([0.]), + "scale": torch.tensor([16.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # defaults_uint_ulp + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([9.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "eval_args": { + "x": torch.tensor([9.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([16.]), + "scale": torch.tensor([16.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # defaults_int_ulp + "init_args": { + 
"bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([-9.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([-9.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([-16.]), + "scale": torch.tensor([16.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # abxmax_uint_overflow + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + "trunc_scaling_impl": TruncScalingWrapper( + trunc_int_scaling_impl=PowerOfTwoIntScaling(), + scaling_impl=RuntimeStatsScaling( + scaling_stats_impl=AbsMax(), + scaling_stats_input_view_shape_impl=Identity(), + scaling_shape=(1,), + scaling_stats_momentum=1.0, + restrict_scaling_impl=PowerOfTwoRestrictValue(), + ) + ), + }, + "train_args": { + "x": torch.tensor([128.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "eval_args": { + "x": torch.tensor([255.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([120.]), + "scale": torch.tensor([8.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # abxmax_int+_overflow + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + "trunc_scaling_impl": TruncScalingWrapper( + trunc_int_scaling_impl=PowerOfTwoIntScaling(), + scaling_impl=RuntimeStatsScaling( + scaling_stats_impl=AbsMax(), + scaling_stats_input_view_shape_impl=Identity(), + scaling_shape=(1,), + scaling_stats_momentum=1.0, + restrict_scaling_impl=PowerOfTwoRestrictValue(), + ) + ), + }, + "train_args": { + "x": torch.tensor([32.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([64.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([28.]), + "scale": torch.tensor([4.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # abxmax_int-_overflow + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + "trunc_scaling_impl": TruncScalingWrapper( + trunc_int_scaling_impl=PowerOfTwoIntScaling(), + scaling_impl=RuntimeStatsScaling( + scaling_stats_impl=AbsMax(), + scaling_stats_input_view_shape_impl=Identity(), + scaling_shape=(1,), + scaling_stats_momentum=1.0, + restrict_scaling_impl=PowerOfTwoRestrictValue(), + ) + ), + }, + "train_args": { + "x": torch.tensor([-16.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([-32.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([-16.]), + 
"scale": torch.tensor([2.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # abxmax_uint_underflow + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + "trunc_scaling_impl": TruncScalingWrapper( + trunc_int_scaling_impl=PowerOfTwoIntScaling(), + scaling_impl=RuntimeStatsScaling( + scaling_stats_impl=AbsMax(), + scaling_stats_input_view_shape_impl=Identity(), + scaling_shape=(1,), + scaling_stats_momentum=1.0, + restrict_scaling_impl=PowerOfTwoRestrictValue(), + ) + ), + }, + "train_args": { + "x": torch.tensor([15.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "eval_args": { + "x": torch.tensor([.5]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([0.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # abxmax_int_underflow + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + "trunc_scaling_impl": TruncScalingWrapper( + trunc_int_scaling_impl=PowerOfTwoIntScaling(), + scaling_impl=RuntimeStatsScaling( + scaling_stats_impl=AbsMax(), + scaling_stats_input_view_shape_impl=Identity(), + scaling_shape=(1,), + scaling_stats_momentum=1.0, + restrict_scaling_impl=PowerOfTwoRestrictValue(), + ) + ), + }, + "train_args": { + "x": torch.tensor([-8.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([-.5]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([0.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # abxmax_uint_ulp + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + "trunc_scaling_impl": TruncScalingWrapper( + trunc_int_scaling_impl=PowerOfTwoIntScaling(), + scaling_impl=RuntimeStatsScaling( + scaling_stats_impl=AbsMax(), + scaling_stats_input_view_shape_impl=Identity(), + scaling_shape=(1,), + scaling_stats_momentum=1.0, + restrict_scaling_impl=PowerOfTwoRestrictValue(), + ) + ), + }, + "train_args": { + "x": torch.tensor([31.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "eval_args": { + "x": torch.tensor([2.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([2.]), + "scale": torch.tensor([2.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # abxmax_int_ulp + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + "trunc_scaling_impl": TruncScalingWrapper( + trunc_int_scaling_impl=PowerOfTwoIntScaling(), + scaling_impl=RuntimeStatsScaling( + scaling_stats_impl=AbsMax(), + scaling_stats_input_view_shape_impl=Identity(), + scaling_shape=(1,), + scaling_stats_momentum=1.0, + restrict_scaling_impl=PowerOfTwoRestrictValue(), + ) + ), + }, 
+ "train_args": { + "x": torch.tensor([-64.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([-5.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([-8.]), + "scale": torch.tensor([8.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, + ], + ids=[ + "defaults_uint_overflow", + "defaults_int+_overflow", + "defaults_int+_overflow_zp", + "defaults_int-_max", + "defaults_uint_underflow", + "defaults_int_underflow", + "defaults_uint_ulp", + "defaults_int_ulp", + "abxmax_uint_overflow", + "abxmax_int+_overflow", + "abxmax_int-_overflow", + "abxmax_uint_underflow", + "abxmax_int_underflow", + "abxmax_uint_ulp", + "abxmax_int_ulp", + ],) + # yapf: enable + def trunc_int_quant_io_fixture(self, test_cfg): + yield test_cfg + + def test_trunc_int_quant_io(self, caplog, trunc_int_quant_io_fixture): + caplog.set_level(logging.INFO) + test_cfg = trunc_int_quant_io_fixture + init_args = test_cfg["init_args"] + train_args = test_cfg["train_args"] + eval_args = test_cfg["eval_args"] + expected_result = test_cfg["result"] + trunc_int_quant = TruncIntQuant(**init_args) + trunc_int_quant.train() + y = trunc_int_quant(**train_args) + trunc_int_quant.eval() + with torch.no_grad(): + y = trunc_int_quant(**eval_args) + for i, k in enumerate(expected_result.keys()): + assert torch.allclose(expected_result[k], y[i], rtol=0.0, atol=0.0, equal_nan=False), f"Expected result[{k}]: {expected_result[k]}, result: {y[i]}" diff --git a/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py index aa0d28faf..a9c851d7e 100644 --- a/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py @@ -26,27 +26,41 @@ @pytest.mark.parametrize("input_bit_width", [4, 8, 16]) @pytest.mark.parametrize("channels", [2, 4]) @pytest.mark.parametrize("idim", [7, 8]) +@pytest.mark.parametrize("restrict_scaling_type", ["log_fp", "power_of_two"]) def test_brevitas_avg_pool_export( - kernel_size, stride, signed, bit_width, input_bit_width, channels, idim, request): + kernel_size, + stride, + signed, + bit_width, + input_bit_width, + channels, + idim, + restrict_scaling_type, + request): if signed: quant_node = QuantIdentity( bit_width=input_bit_width, + restrict_scaling_type=restrict_scaling_type, return_quant_tensor=True, ) else: quant_node = QuantReLU( bit_width=input_bit_width, + restrict_scaling_type=restrict_scaling_type, return_quant_tensor=True, ) quant_avgpool = TruncAvgPool2d( kernel_size=kernel_size, stride=stride, bit_width=bit_width, float_to_int_impl_type='floor') model_brevitas = torch.nn.Sequential(quant_node, quant_avgpool) - model_brevitas.eval() # determine input input_shape = (1, channels, idim, idim) inp = torch.randn(input_shape) + model_brevitas.train() + model_brevitas(inp) + model_brevitas.eval() + model_brevitas(inp) # export test_id = request.node.callspec.id @@ -63,6 +77,10 @@ def test_brevitas_avg_pool_export( odict = oxe.execute_onnx(model, idict, True) finn_output = odict[model.graph.output[0].name] # compare outputs - assert np.isclose(ref_output_array, finn_output).all() + if restrict_scaling_type == "power_of_two" and kernel_size == 2: + atol = 1e-8 + else: 
+ atol = quant_avgpool.trunc_quant.scale().detach().numpy() # Allow "off-by-1" errors + assert np.isclose(ref_output_array, finn_output, atol=atol).all() # cleanup os.remove(export_path) diff --git a/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py b/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py index 01505b9e5..6e4648e30 100644 --- a/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py +++ b/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py @@ -24,7 +24,8 @@ reason='Issue with ORT and MobileNet export on MacOS on PyTorch >= 1.5.0') INPUT_SIZE = (1, 3, 224, 224) -ATOL = 1e-3 +ATOL = 7 # How many output quantization steps (LSBs) to tolerate in the 32-bit output +RTOL = 1e-2 SEED = 0 @@ -42,6 +43,7 @@ def test_mobilenet_v1_4b(pretrained): # do forward pass in PyTorch/Brevitas expected = mobilenet(torch_tensor).detach().numpy() export_qonnx(mobilenet, input_shape=INPUT_SIZE, export_path=finn_onnx) + output_scale = mobilenet.output.bias_quant.scale() # Scale at the output model = ModelWrapper(finn_onnx) model = model.transform(GiveUniqueNodeNames()) model = model.transform(DoubleToSingleFloat()) @@ -53,4 +55,4 @@ def test_mobilenet_v1_4b(pretrained): input_dict = {inp_name: numpy_tensor} output_dict = oxe.execute_onnx(model, input_dict) produced = output_dict[list(output_dict.keys())[0]] - assert np.isclose(produced, expected, atol=ATOL).all() + assert np.isclose(produced, expected, rtol=RTOL, atol=ATOL * output_scale).all()
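
For reference, a minimal usage sketch (not part of the patch) of how the reworked TruncIntQuant forward behaves with the default MSB-preserving truncation scaling; it mirrors the defaults_uint_overflow case from the new tests/brevitas/core/test_trunc_int_quant.py and only relies on classes and arguments exercised in this diff (a constant 4-bit output width, round-to-nearest, and the new signed forward argument).

import torch

from brevitas.core.bit_width import BitWidthConst
from brevitas.core.function_wrapper import RoundSte
from brevitas.core.quant import TruncIntQuant

# Truncate an 8-bit unsigned input to 4 bits with the default TruncMsbScaling.
trunc = TruncIntQuant(bit_width_impl=BitWidthConst(4), float_to_int_impl=RoundSte())
y, scale, zero_point, bit_width = trunc(
    x=torch.tensor([255.]),
    scale=torch.tensor([1.]),
    zero_point=torch.tensor([0.]),
    input_bit_width=torch.tensor([8.]),
    signed=False)
# The 4 least-significant bits are dropped and the output scale absorbs the
# truncation factor 2 ** (8 - 4):
# y == 240., scale == 16., zero_point == 0., bit_width == 4.

At the quantizer level, the same behaviour corresponds to trunc_scaling_impl_type = 'msb' (TruncTo8bit, RoundTo8bit), whereas ShiftRoundSaturateTo8bit selects the 'wrapper' implementation so the truncation scale is derived from runtime statistics, e.g. TruncAvgPool2d(kernel_size=(3, 3), trunc_quant=ShiftRoundSaturateTo8bit).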