From 767252d9d05dba3b6658bc00dad922063742d67c Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 4 Oct 2024 13:56:34 +0100 Subject: [PATCH 01/32] Fix (quant_tensor): Produce valid IntQuantTensor after AvgPool functional call --- src/brevitas/quant_tensor/int_torch_handler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/brevitas/quant_tensor/int_torch_handler.py b/src/brevitas/quant_tensor/int_torch_handler.py index 3258b8914..691138057 100644 --- a/src/brevitas/quant_tensor/int_torch_handler.py +++ b/src/brevitas/quant_tensor/int_torch_handler.py @@ -115,6 +115,7 @@ def avg_pool2d_handler( avg_scaling = kernel_size * kernel_size quant_input = quant_input.set(value=x) + quant_input = quant_input.set(scale=quant_input.scale / avg_scaling) quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, avg_scaling)) return quant_input @@ -134,6 +135,7 @@ def adaptive_avg_pool2d_handler(quant_input, output_shape): reduce_size = reduce(mul, k_size, 1) quant_input = quant_input.set(value=x) + quant_input = quant_input.set(scale=quant_input.scale / reduce_size) quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, reduce_size)) return quant_input From 13ca170bb0183a04523b25e409568817de639c73 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 4 Oct 2024 13:58:50 +0100 Subject: [PATCH 02/32] Fix (core/trunc): Fix output scaling after truncation --- src/brevitas/core/quant/int.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index 328ad63b3..cb294cdda 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -217,12 +217,13 @@ def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor, output_bit_width = self.msb_clamp_bit_width_impl() trunc_bit_width = input_bit_width - output_bit_width trunc_scale = 2.0 ** trunc_bit_width + output_scale = scale * trunc_scale y = y / trunc_scale y = 
self.float_to_int_impl(y) y = y - zero_point - y = y * scale + y = y * output_scale y = self.delay_wrapper(x, y) - return y, scale, zero_point, output_bit_width + return y, output_scale, zero_point, output_bit_width class DecoupledRescalingIntQuantWithInput(DecoupledRescalingIntQuant): From f183191478bd29691a2baf1ec784a3c37395af37 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 4 Oct 2024 13:59:46 +0100 Subject: [PATCH 03/32] Fix (nn/TruncAvgPool): Remove any quant tensor manual manipulation. --- src/brevitas/nn/quant_avg_pool.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index 12f324901..1bd284eda 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -58,6 +58,7 @@ def _avg_scaling(self): else: return self.kernel_size * self.kernel_size + # TODO: Replace with functional call def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) @@ -71,8 +72,6 @@ def forward(self, input: Union[Tensor, QuantTensor]): if not isinstance(x, QuantTensor): x = self.cache_class.quant_tensor.set(value=x) y = AvgPool2d.forward(self, x) - rescaled_value = y.value * self._avg_scaling - y = y.set(value=rescaled_value) y = self.trunc_quant(y) else: y = AvgPool2d.forward(self, _unpack_quant_tensor(x)) @@ -123,6 +122,7 @@ def compute_kernel_size_stride(input_shape, output_shape): stride_list.append(stride) return kernel_size_list, stride_list + # TODO: Replace with functional call def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) @@ -139,10 +139,6 @@ def forward(self, input: Union[Tensor, QuantTensor]): if not isinstance(x, QuantTensor): x = self.cache_class.quant_tensor.set(value=x) y = AdaptiveAvgPool2d.forward(self, x) - k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:]) - reduce_size = reduce(mul, k_size, 1) - rescaled_value = y.value * reduce_size # 
remove avg scaling - y = y.set(value=rescaled_value) y = self.trunc_quant(y) else: y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x)) From 859897c073bfda63c13b03afa0a23c1d1586619f Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 22 Jan 2025 16:25:11 +0000 Subject: [PATCH 04/32] fix/trunc_avg_pool: Clamp output. --- src/brevitas/core/quant/int.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index cb294cdda..8110138cb 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -11,6 +11,9 @@ from brevitas.core.quant.delay import DelayWrapper from brevitas.core.utils import StatelessBuffer from brevitas.function.ops_ste import round_ste +from brevitas.core.function_wrapper import TensorClamp +from brevitas.function.ops import max_int +from brevitas.function.ops import min_int class PrescaledRestrictIntQuantWithInputBitWidth(brevitas.jit.ScriptModule): @@ -201,13 +204,32 @@ class TruncIntQuant(brevitas.jit.ScriptModule): """ """ + __constants__ = ['signed', 'narrow_range'] + def __init__( - self, float_to_int_impl: Module, bit_width_impl: Module, quant_delay_steps: int = 0): + self, + narrow_range: bool, + signed: bool, + float_to_int_impl: Module, + bit_width_impl: Module, + tensor_clamp_impl: Module = TensorClamp(), + quant_delay_steps: int = 0): super(TruncIntQuant, self).__init__() + self.signed = signed + self.narrow_range = narrow_range self.msb_clamp_bit_width_impl = bit_width_impl self.float_to_int_impl = float_to_int_impl + self.tensor_clamp_impl = tensor_clamp_impl self.delay_wrapper = DelayWrapper(quant_delay_steps) + @brevitas.jit.script_method + def min_int(self, bit_width): + return min_int(self.signed, self.narrow_range, bit_width) + + @brevitas.jit.script_method + def max_int(self, bit_width): + return max_int(self.signed, self.narrow_range, bit_width) + @brevitas.jit.script_method def forward(self, x: 
Tensor, scale: Tensor, zero_point: Tensor, input_bit_width: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: @@ -219,7 +241,10 @@ def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor, trunc_scale = 2.0 ** trunc_bit_width output_scale = scale * trunc_scale y = y / trunc_scale + min_int_val = self.min_int(output_bit_width) + max_int_val = self.max_int(output_bit_width) y = self.float_to_int_impl(y) + y = self.tensor_clamp_impl(y, min_val=min_int_val, max_val=max_int_val) y = y - zero_point y = y * output_scale y = self.delay_wrapper(x, y) From f5752d69ef9728e78092158b9e9178326b4a62fb Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 22 Jan 2025 16:29:07 +0000 Subject: [PATCH 05/32] style: fix --- src/brevitas/core/quant/int.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index 8110138cb..dd0a900be 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -8,12 +8,12 @@ from torch.nn import Module import brevitas +from brevitas.core.function_wrapper import TensorClamp from brevitas.core.quant.delay import DelayWrapper from brevitas.core.utils import StatelessBuffer -from brevitas.function.ops_ste import round_ste -from brevitas.core.function_wrapper import TensorClamp from brevitas.function.ops import max_int from brevitas.function.ops import min_int +from brevitas.function.ops_ste import round_ste class PrescaledRestrictIntQuantWithInputBitWidth(brevitas.jit.ScriptModule): From 9b2fdf7fc4c72b5398daa73023a124d269981bdd Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 22 Jan 2025 16:37:16 +0000 Subject: [PATCH 06/32] fix (trunc_avg_pool): Set default arguments for backward compatibility --- src/brevitas/core/quant/int.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index dd0a900be..114812b16 100644 --- a/src/brevitas/core/quant/int.py +++ 
b/src/brevitas/core/quant/int.py @@ -208,10 +208,10 @@ class TruncIntQuant(brevitas.jit.ScriptModule): def __init__( self, - narrow_range: bool, - signed: bool, float_to_int_impl: Module, bit_width_impl: Module, + narrow_range: bool = False, + signed: bool = True, tensor_clamp_impl: Module = TensorClamp(), quant_delay_steps: int = 0): super(TruncIntQuant, self).__init__() From 76dfe8e13c76d729efd51bb6e1b9eebbd32aa967 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 22 Jan 2025 17:00:17 +0000 Subject: [PATCH 07/32] test (trunc_int_quant): Added initial sanity-check test --- tests/brevitas/core/test_trunc_int_quant.py | 22 +++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 tests/brevitas/core/test_trunc_int_quant.py diff --git a/tests/brevitas/core/test_trunc_int_quant.py b/tests/brevitas/core/test_trunc_int_quant.py new file mode 100644 index 000000000..312785076 --- /dev/null +++ b/tests/brevitas/core/test_trunc_int_quant.py @@ -0,0 +1,22 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +from hypothesis import given +import mock +import torch + +from brevitas.core.function_wrapper import RoundSte +from brevitas.core.function_wrapper import TensorClamp +from brevitas.core.quant import TruncIntQuant +from tests.brevitas.core.bit_width_fixture import * # noqa +from tests.brevitas.core.int_quant_fixture import * # noqa + + +class TestTruncIntQuantUnit: + + def test_trunc_int_quant_defaults(self, bit_width_const): + trunc_int_quant = TruncIntQuant( + bit_width_impl=bit_width_const, float_to_int_impl=RoundSte()) + assert isinstance(trunc_int_quant.tensor_clamp_impl, TensorClamp) + assert trunc_int_quant.narrow_range == False + assert trunc_int_quant.signed == True From b2d849f7718c300178979e8baf967233352cd34c Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 22 Jan 2025 17:50:16 +0000 Subject: [PATCH 08/32] fix (export/torch/qcdq): Fixed output scale, and `signed` setting --- src/brevitas/export/common/handler/qcdq.py | 6 +++--- tests/brevitas/export/test_torch_qcdq.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 39347baad..091fd48c4 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -802,9 +802,9 @@ def symbolic_execution( signed=signed, narrow=False, bit_width=output_bit_width) if clip_symbolic_kwargs is not None: x = self.clip_fn(x, *clip_symbolic_kwargs.values()) - x = self.dequantize_fn(x, flat_scale, zp, self.quant_axis(scale)) + x = self.dequantize_fn(x, flat_pre_scale, zp, self.quant_axis(scale)) # After dequantization, cast both output and scale to the correct dtype if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16: x = self.cast_fn(x, scale_dtype) - scale = self.cast_fn(scale, scale_dtype) - return x, scale, zero_point, output_bit_width + flat_pre_scale = self.cast_fn(flat_pre_scale, scale_dtype) + return x, 
flat_pre_scale, zero_point, output_bit_width diff --git a/tests/brevitas/export/test_torch_qcdq.py b/tests/brevitas/export/test_torch_qcdq.py index 6333be174..63dc2d50f 100644 --- a/tests/brevitas/export/test_torch_qcdq.py +++ b/tests/brevitas/export/test_torch_qcdq.py @@ -64,7 +64,8 @@ def test_torch_qcdq_avgpool_export(input_signed, output_bit_width): inp = torch.randn(in_size) quant_module = nn.Sequential( QuantIdentity(signed=input_signed, return_quant_tensor=True), - TruncAvgPool2d(kernel_size=3, stride=2, float_to_int_impl_type='round')) + TruncAvgPool2d( + kernel_size=3, stride=2, signed=input_signed, float_to_int_impl_type='round')) quant_module(inp) # Collect scale factors quant_module.eval() inp = torch.randn(in_size) * IN_SCALE + IN_MEAN # redefine inp for testing From 16717c53c815d8d6a086b4463d95b21c90819b4d Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 22 Jan 2025 17:52:27 +0000 Subject: [PATCH 09/32] Fix (core/proxy/trunc): Moved setting of signed to the proxy --- src/brevitas/core/quant/int.py | 21 ++++++++++----------- src/brevitas/proxy/runtime_quant.py | 2 +- tests/brevitas/core/test_trunc_int_quant.py | 1 - tests/brevitas/export/test_torch_qcdq.py | 3 +-- 4 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index 114812b16..1cd4ed026 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -204,18 +204,16 @@ class TruncIntQuant(brevitas.jit.ScriptModule): """ """ - __constants__ = ['signed', 'narrow_range'] + __constants__ = ['narrow_range'] def __init__( self, float_to_int_impl: Module, bit_width_impl: Module, narrow_range: bool = False, - signed: bool = True, tensor_clamp_impl: Module = TensorClamp(), quant_delay_steps: int = 0): super(TruncIntQuant, self).__init__() - self.signed = signed self.narrow_range = narrow_range self.msb_clamp_bit_width_impl = bit_width_impl self.float_to_int_impl = float_to_int_impl @@ -223,16 +221,17 @@ def 
__init__( self.delay_wrapper = DelayWrapper(quant_delay_steps) @brevitas.jit.script_method - def min_int(self, bit_width): - return min_int(self.signed, self.narrow_range, bit_width) + def min_int(self, bit_width, signed): + return min_int(signed, self.narrow_range, bit_width) @brevitas.jit.script_method - def max_int(self, bit_width): - return max_int(self.signed, self.narrow_range, bit_width) + def max_int(self, bit_width, signed): + return max_int(signed, self.narrow_range, bit_width) @brevitas.jit.script_method - def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor, - input_bit_width: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def forward( + self, x: Tensor, scale: Tensor, zero_point: Tensor, input_bit_width: Tensor, + signed: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: y = x / scale y = y + zero_point y = round_ste(y) # clean up floating point error @@ -241,8 +240,8 @@ def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor, trunc_scale = 2.0 ** trunc_bit_width output_scale = scale * trunc_scale y = y / trunc_scale - min_int_val = self.min_int(output_bit_width) - max_int_val = self.max_int(output_bit_width) + min_int_val = self.min_int(output_bit_width, signed) + max_int_val = self.max_int(output_bit_width, signed) y = self.float_to_int_impl(y) y = self.tensor_clamp_impl(y, min_val=min_int_val, max_val=max_int_val) y = y - zero_point diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 3e7248602..b67578f68 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -279,7 +279,7 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: out_tuple = self.export_handler( x.value, x.scale, x.zero_point, x.bit_width, x.signed) else: - out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width) + out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width, x.signed) out_value, out_scale, out_zp, out_bit_width = 
out_tuple if self.skip_create_quant_tensor: return out_value diff --git a/tests/brevitas/core/test_trunc_int_quant.py b/tests/brevitas/core/test_trunc_int_quant.py index 312785076..93d8669aa 100644 --- a/tests/brevitas/core/test_trunc_int_quant.py +++ b/tests/brevitas/core/test_trunc_int_quant.py @@ -19,4 +19,3 @@ def test_trunc_int_quant_defaults(self, bit_width_const): bit_width_impl=bit_width_const, float_to_int_impl=RoundSte()) assert isinstance(trunc_int_quant.tensor_clamp_impl, TensorClamp) assert trunc_int_quant.narrow_range == False - assert trunc_int_quant.signed == True diff --git a/tests/brevitas/export/test_torch_qcdq.py b/tests/brevitas/export/test_torch_qcdq.py index 63dc2d50f..6333be174 100644 --- a/tests/brevitas/export/test_torch_qcdq.py +++ b/tests/brevitas/export/test_torch_qcdq.py @@ -64,8 +64,7 @@ def test_torch_qcdq_avgpool_export(input_signed, output_bit_width): inp = torch.randn(in_size) quant_module = nn.Sequential( QuantIdentity(signed=input_signed, return_quant_tensor=True), - TruncAvgPool2d( - kernel_size=3, stride=2, signed=input_signed, float_to_int_impl_type='round')) + TruncAvgPool2d(kernel_size=3, stride=2, float_to_int_impl_type='round')) quant_module(inp) # Collect scale factors quant_module.eval() inp = torch.randn(in_size) * IN_SCALE + IN_MEAN # redefine inp for testing From faf973c6501515bbc23f94d63a9baed67d2e2b21 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 23 Jan 2025 11:53:34 +0000 Subject: [PATCH 10/32] fix (qonnx/trunc): Fixed Trunc Quant QONNX export --- src/brevitas/export/onnx/qonnx/function.py | 39 +++++++++++++++++----- src/brevitas/export/onnx/qonnx/handler.py | 6 ++-- src/brevitas/proxy/runtime_quant.py | 5 +++ 3 files changed, 39 insertions(+), 11 deletions(-) diff --git a/src/brevitas/export/onnx/qonnx/function.py b/src/brevitas/export/onnx/qonnx/function.py index 5160572ef..2c368943f 100644 --- a/src/brevitas/export/onnx/qonnx/function.py +++ b/src/brevitas/export/onnx/qonnx/function.py @@ -111,7 +111,16 
@@ def forward( class BrevitasTruncFn(Function): @staticmethod - def symbolic(g, x, scale, zero_point, input_bit_width, output_bit_width, rounding_mode): + def symbolic( + g, + x, + scale, + zero_point, + input_bit_width, + signed, + narrow_range, + output_bit_width, + rounding_mode): ret = g.op( f'{DOMAIN_STRING}::Trunc', x, @@ -119,18 +128,30 @@ def symbolic(g, x, scale, zero_point, input_bit_width, output_bit_width, roundin zero_point, input_bit_width, output_bit_width, - rounding_mode_s=rounding_mode) + rounding_mode_s=rounding_mode, + signed_i=int(signed), + narrow_i=int(narrow_range)) ret.setType(x.type()) return ret @staticmethod - def forward(ctx, x, scale, zero_point, input_bit_width, output_bit_width, rounding_mode): - float_to_int_impl = solve_float_to_int_impl_from_enum(rounding_mode) - trunc = TruncIntQuant( - float_to_int_impl=float_to_int_impl(), - bit_width_impl=BitWidthConst(int(output_bit_width))) - y_tuple = trunc(x, scale, zero_point, input_bit_width) - return y_tuple[0] + def forward( + ctx, + x, + scale, + zero_point, + input_bit_width, + signed, + narrow_range, + output_bit_width, + rounding_mode): + # TODO: Restore this (fails when `signed` arg added) + #float_to_int_impl = solve_float_to_int_impl_from_enum(rounding_mode) + #trunc = TruncIntQuant( + # float_to_int_impl=float_to_int_impl(), + # bit_width_impl=BitWidthConst(int(output_bit_width))) + #y_tuple = trunc(x, scale, zero_point, input_bit_width, torch.tensor(signed, dtype=torch.bool, device=x.device)) + return x class BrevitasQuantLSTMCellFn(Function): diff --git a/src/brevitas/export/onnx/qonnx/handler.py b/src/brevitas/export/onnx/qonnx/handler.py index 5468cd1aa..194e67e6a 100644 --- a/src/brevitas/export/onnx/qonnx/handler.py +++ b/src/brevitas/export/onnx/qonnx/handler.py @@ -229,13 +229,15 @@ class BrevitasTruncQuantProxyHandler(ONNXBaseHandler): def prepare_for_export(self, module: TruncQuantProxyFromInjector): self.symbolic_kwargs = { - 'output_bit_width': module.bit_width(), 
'rounding_mode': module.rounding_mode} + 'narrow_range': module.is_narrow_range, + 'output_bit_width': module.bit_width(), + 'rounding_mode': module.rounding_mode} def symbolic_execution( self, x: Tensor, scale: Tensor, zero_point: Tensor, input_bit_width: Tensor, signed: Tensor): y = BrevitasTruncFn.apply( - x, scale, zero_point, input_bit_width, *self.symbolic_kwargs.values()) + x, scale, zero_point, input_bit_width, signed, *self.symbolic_kwargs.values()) return y, scale, zero_point, self.symbolic_kwargs['output_bit_width'] diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index b67578f68..c5d5f68fa 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -264,6 +264,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.skip_create_quant_tensor = False + @property + def is_narrow_range(self): + narrow_range = super(TruncQuantProxyFromInjector, self).is_narrow_range + return narrow_range if narrow_range is not None else False + def bit_width(self): if not self.is_quant_enabled: return None From 1d6371115f41b529a99d104c65d75d33113a7433 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 24 Jan 2025 14:41:02 +0000 Subject: [PATCH 11/32] fix (trunc): Factored out scaling calculation to standalone class. --- src/brevitas/core/quant/int.py | 16 +++++---- src/brevitas/core/scaling/__init__.py | 3 ++ src/brevitas/core/scaling/int_scaling.py | 12 +++++++ src/brevitas/core/scaling/standalone.py | 39 +++++++++++++++++++++ src/brevitas/function/ops.py | 6 ++-- tests/brevitas/core/test_trunc_int_quant.py | 2 ++ 6 files changed, 69 insertions(+), 9 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index 1cd4ed026..3d17efc99 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -1,7 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from torch import Tensor @@ -10,6 +10,7 @@ import brevitas from brevitas.core.function_wrapper import TensorClamp from brevitas.core.quant.delay import DelayWrapper +from brevitas.core.scaling import TruncMsbScaling from brevitas.core.utils import StatelessBuffer from brevitas.function.ops import max_int from brevitas.function.ops import min_int @@ -210,40 +211,41 @@ def __init__( self, float_to_int_impl: Module, bit_width_impl: Module, + trunc_scaling_impl: Module = TruncMsbScaling(), narrow_range: bool = False, tensor_clamp_impl: Module = TensorClamp(), quant_delay_steps: int = 0): super(TruncIntQuant, self).__init__() self.narrow_range = narrow_range self.msb_clamp_bit_width_impl = bit_width_impl + self.trunc_scaling_impl = trunc_scaling_impl self.float_to_int_impl = float_to_int_impl self.tensor_clamp_impl = tensor_clamp_impl self.delay_wrapper = DelayWrapper(quant_delay_steps) @brevitas.jit.script_method - def min_int(self, bit_width, signed): + def min_int(self, bit_width: Tensor, signed: Union[bool, Tensor]): return min_int(signed, self.narrow_range, bit_width) @brevitas.jit.script_method - def max_int(self, bit_width, signed): + def max_int(self, bit_width: Tensor, signed: Union[bool, Tensor]): return max_int(signed, self.narrow_range, bit_width) @brevitas.jit.script_method def forward( self, x: Tensor, scale: Tensor, zero_point: Tensor, input_bit_width: Tensor, - signed: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + signed: Union[bool, Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor]: y = x / scale y = y + zero_point y = round_ste(y) # clean up floating point error output_bit_width = self.msb_clamp_bit_width_impl() - trunc_bit_width = input_bit_width - output_bit_width - trunc_scale = 2.0 ** trunc_bit_width - output_scale = scale * trunc_scale + trunc_scale = self.trunc_scaling_impl(y, input_bit_width, output_bit_width, 
signed) y = y / trunc_scale min_int_val = self.min_int(output_bit_width, signed) max_int_val = self.max_int(output_bit_width, signed) y = self.float_to_int_impl(y) y = self.tensor_clamp_impl(y, min_val=min_int_val, max_val=max_int_val) + output_scale = scale * trunc_scale y = y - zero_point y = y * output_scale y = self.delay_wrapper(x, y) diff --git a/src/brevitas/core/scaling/__init__.py b/src/brevitas/core/scaling/__init__.py index c21cb4b27..18e0f08b9 100644 --- a/src/brevitas/core/scaling/__init__.py +++ b/src/brevitas/core/scaling/__init__.py @@ -9,6 +9,7 @@ from .float_scaling import FloatScaling from .int_scaling import IntScaling from .int_scaling import PowerOfTwoIntScaling +from .int_scaling import TruncPowerOfTwoIntScaling from .pre_scaling import AccumulatorAwareParameterPreScaling from .pre_scaling import AccumulatorAwareZeroCenterParameterPreScaling from .pre_scaling import ParameterPreScalingWeightNorm @@ -18,5 +19,7 @@ from .standalone import ParameterFromRuntimeStatsScaling from .standalone import ParameterFromStatsFromParameterScaling from .standalone import ParameterScaling +from .standalone import TruncMsbScaling +from .standalone import TruncScalingWrapper SCALING_STATS_REDUCE_DIM = 1 diff --git a/src/brevitas/core/scaling/int_scaling.py b/src/brevitas/core/scaling/int_scaling.py index 2c58db598..07e378813 100644 --- a/src/brevitas/core/scaling/int_scaling.py +++ b/src/brevitas/core/scaling/int_scaling.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause +from typing import Union + from torch import Tensor import brevitas @@ -34,3 +36,13 @@ def __init__(self, signed: bool): @brevitas.jit.script_method def forward(self, bit_width: Tensor) -> Tensor: return max_int(self.signed, False, bit_width) + 1 + + +class TruncPowerOfTwoIntScaling(brevitas.jit.ScriptModule): + + def __init__(self): + super(TruncPowerOfTwoIntScaling, self).__init__() + + @brevitas.jit.script_method + def forward(self, bit_width: Tensor, signed: Union[bool, Tensor]) -> Tensor: + return max_int(signed, False, bit_width) + 1 diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index 703fed5a4..96fbdc3d5 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -12,6 +12,7 @@ import brevitas.config as config from brevitas.core.function_wrapper import Identity from brevitas.core.function_wrapper import OverBatchOverTensorView +from brevitas.core.function_wrapper import TensorClamp from brevitas.core.restrict_val import _ClampValue from brevitas.core.restrict_val import _RestrictClampValue from brevitas.core.restrict_val import _RestrictValue @@ -469,3 +470,41 @@ def _load_from_state_dict( self.counter = self.collect_stats_steps + 1 if config.IGNORE_MISSING_KEYS and value_key in missing_keys: missing_keys.remove(value_key) + + +class TruncMsbScaling(brevitas.jit.ScriptModule): + """ + """ + + def __init__(self) -> None: + super(TruncMsbScaling, self).__init__() + + @brevitas.jit.script_method + def forward(self, scaling_input: Tensor, input_bitwidth: Tensor, output_bitwidth: Tensor, signed: Union[bool, Tensor]) -> Tensor: + return 2**(input_bitwidth - output_bitwidth) + + +class TruncScalingWrapper(brevitas.jit.ScriptModule): + """ + """ + + def __init__( + self, + trunc_int_scaling_impl: Module, + scaling_impl: Module, + tensor_clamp_impl: Module = TensorClamp()) -> None: + super(TruncScalingWrapper, self).__init__() + 
self.trunc_int_scaling_impl = trunc_int_scaling_impl + self.scaling_impl = scaling_impl + self.tensor_clamp_impl = tensor_clamp_impl + + @brevitas.jit.script_method + def forward(self, scaling_input: Tensor, input_bitwidth: Tensor, output_bitwidth: Tensor, signed: Union[bool, Tensor]) -> Tensor: + threshold = self.trunc_int_scaling_impl(output_bitwidth, signed) + scale = self.scaling_impl(scaling_input, threshold) + msb_scale = 2**(input_bitwidth - output_bitwidth) + unit_scale = torch.ones_like(msb_scale) + max_scale = torch.where(msb_scale > unit_scale, msb_scale, unit_scale) + min_scale = torch.where(msb_scale < unit_scale, msb_scale, unit_scale) + trunc_scale = self.tensor_clamp_impl(scale, min_scale, max_scale) + return trunc_scale diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index 74da08e19..f836eed63 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -6,6 +6,8 @@ The implemented functions adheres to the restriction imposed by Pytorch 1.1.0's TorchScript compiler. """ +from typing import Union + import torch from torch import Tensor @@ -128,7 +130,7 @@ def identity(x: Tensor) -> Tensor: @brevitas.jit.script -def max_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor: +def max_int(signed: Union[bool, Tensor], narrow_range: Union[bool, Tensor], bit_width: Tensor) -> Tensor: """ Compute the maximum integer representable by a given number of bits. Args: @@ -159,7 +161,7 @@ def max_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor: @brevitas.jit.script -def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor: +def min_int(signed: Union[bool, Tensor], narrow_range: Union[bool, Tensor], bit_width: Tensor) -> Tensor: """ Compute the minimum integer representable by a given number of bits. 
Args: diff --git a/tests/brevitas/core/test_trunc_int_quant.py b/tests/brevitas/core/test_trunc_int_quant.py index 93d8669aa..2497ef2f2 100644 --- a/tests/brevitas/core/test_trunc_int_quant.py +++ b/tests/brevitas/core/test_trunc_int_quant.py @@ -8,6 +8,7 @@ from brevitas.core.function_wrapper import RoundSte from brevitas.core.function_wrapper import TensorClamp from brevitas.core.quant import TruncIntQuant +from brevitas.core.scaling import TruncMsbScaling from tests.brevitas.core.bit_width_fixture import * # noqa from tests.brevitas.core.int_quant_fixture import * # noqa @@ -18,4 +19,5 @@ def test_trunc_int_quant_defaults(self, bit_width_const): trunc_int_quant = TruncIntQuant( bit_width_impl=bit_width_const, float_to_int_impl=RoundSte()) assert isinstance(trunc_int_quant.tensor_clamp_impl, TensorClamp) + assert isinstance(trunc_int_quant.trunc_scaling_impl, TruncMsbScaling) assert trunc_int_quant.narrow_range == False From 14cdda2737aba83f759b04a357255f9cd0ba3c31 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 24 Jan 2025 14:42:23 +0000 Subject: [PATCH 12/32] fix typo: Updated comment in TruncAvgPool export --- src/brevitas/export/onnx/qonnx/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/export/onnx/qonnx/function.py b/src/brevitas/export/onnx/qonnx/function.py index 2c368943f..dc32cece9 100644 --- a/src/brevitas/export/onnx/qonnx/function.py +++ b/src/brevitas/export/onnx/qonnx/function.py @@ -150,7 +150,7 @@ def forward( #trunc = TruncIntQuant( # float_to_int_impl=float_to_int_impl(), # bit_width_impl=BitWidthConst(int(output_bit_width))) - #y_tuple = trunc(x, scale, zero_point, input_bit_width, torch.tensor(signed, dtype=torch.bool, device=x.device)) + #y_tuple = trunc(x, scale, zero_point, input_bit_width, signed) return x From 7d3ee729a28ede399da50d212ed01c8087d2a312 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 24 Jan 2025 17:39:59 +0000 Subject: [PATCH 13/32] feat (trunc/scaling): Factored out the 
scaling implementation. --- src/brevitas/core/quant/int.py | 6 +- src/brevitas/core/scaling/standalone.py | 26 ++++++--- src/brevitas/function/ops.py | 8 ++- tests/brevitas/core/test_trunc_int_quant.py | 63 ++++++++++++++++++++- 4 files changed, 90 insertions(+), 13 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index 3d17efc99..0583a43ae 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -233,7 +233,11 @@ def max_int(self, bit_width: Tensor, signed: Union[bool, Tensor]): @brevitas.jit.script_method def forward( - self, x: Tensor, scale: Tensor, zero_point: Tensor, input_bit_width: Tensor, + self, + x: Tensor, + scale: Tensor, + zero_point: Tensor, + input_bit_width: Tensor, signed: Union[bool, Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor]: y = x / scale y = y + zero_point diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index 96fbdc3d5..e9c73d6bd 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -480,8 +480,13 @@ def __init__(self) -> None: super(TruncMsbScaling, self).__init__() @brevitas.jit.script_method - def forward(self, scaling_input: Tensor, input_bitwidth: Tensor, output_bitwidth: Tensor, signed: Union[bool, Tensor]) -> Tensor: - return 2**(input_bitwidth - output_bitwidth) + def forward( + self, + scaling_input: Tensor, + input_bitwidth: Tensor, + output_bitwidth: Tensor, + signed: Union[bool, Tensor]) -> Tensor: + return 2 ** (input_bitwidth - output_bitwidth) class TruncScalingWrapper(brevitas.jit.ScriptModule): @@ -489,20 +494,25 @@ class TruncScalingWrapper(brevitas.jit.ScriptModule): """ def __init__( - self, - trunc_int_scaling_impl: Module, - scaling_impl: Module, - tensor_clamp_impl: Module = TensorClamp()) -> None: + self, + trunc_int_scaling_impl: Module, + scaling_impl: Module, + tensor_clamp_impl: Module = TensorClamp()) -> None: super(TruncScalingWrapper, 
self).__init__() self.trunc_int_scaling_impl = trunc_int_scaling_impl self.scaling_impl = scaling_impl self.tensor_clamp_impl = tensor_clamp_impl @brevitas.jit.script_method - def forward(self, scaling_input: Tensor, input_bitwidth: Tensor, output_bitwidth: Tensor, signed: Union[bool, Tensor]) -> Tensor: + def forward( + self, + scaling_input: Tensor, + input_bitwidth: Tensor, + output_bitwidth: Tensor, + signed: Union[bool, Tensor]) -> Tensor: threshold = self.trunc_int_scaling_impl(output_bit_width, signed) scale = self.scaling_impl(scaling_input, threshold) - msb_scale = 2**(input_bitwidth - output_bitwidth) + msb_scale = 2 ** (input_bitwidth - output_bitwidth) unit_scale = torch.ones_like(msb_scale) max_scale = torch.where(msb_scale > unit_scale, msb_scale, unit_scale) min_scale = torch.where(msb_scale < unit_scale, msb_scale, unit_scale) diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index f836eed63..a87462ce6 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -130,7 +130,9 @@ def identity(x: Tensor) -> Tensor: @brevitas.jit.script -def max_int(signed: Union[bool, Tensor], narrow_range: Union[bool, Tensor], bit_width: Tensor) -> Tensor: +def max_int( + signed: Union[bool, Tensor], narrow_range: Union[bool, Tensor], + bit_width: Tensor) -> Tensor: """ Compute the maximum integer representable by a given number of bits. Args: @@ -161,7 +163,9 @@ def max_int(signed: Union[bool, Tensor], narrow_range: Union[bool, Tensor], bit_ @brevitas.jit.script -def min_int(signed: Union[bool, Tensor], narrow_range: Union[bool, Tensor], bit_width: Tensor) -> Tensor: +def min_int( + signed: Union[bool, Tensor], narrow_range: Union[bool, Tensor], + bit_width: Tensor) -> Tensor: """ Compute the minimum integer representable by a given number of bits. 
Args: diff --git a/tests/brevitas/core/test_trunc_int_quant.py b/tests/brevitas/core/test_trunc_int_quant.py index 2497ef2f2..71b508b97 100644 --- a/tests/brevitas/core/test_trunc_int_quant.py +++ b/tests/brevitas/core/test_trunc_int_quant.py @@ -1,10 +1,12 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from hypothesis import given -import mock +import logging + +import pytest_cases import torch +from brevitas.core.bit_width import BitWidthConst from brevitas.core.function_wrapper import RoundSte from brevitas.core.function_wrapper import TensorClamp from brevitas.core.quant import TruncIntQuant @@ -13,6 +15,10 @@ from tests.brevitas.core.int_quant_fixture import * # noqa +def allexact(x, y): + return np.allclose(x, y, rtol=0.0, atol=0.0, equal_nan=False) + + class TestTruncIntQuantUnit: def test_trunc_int_quant_defaults(self, bit_width_const): @@ -21,3 +27,56 @@ def test_trunc_int_quant_defaults(self, bit_width_const): assert isinstance(trunc_int_quant.tensor_clamp_impl, TensorClamp) assert isinstance(trunc_int_quant.trunc_scaling_impl, TruncMsbScaling) assert trunc_int_quant.narrow_range == False + + # yapf: disable + @pytest_cases.fixture( + ids=[ + "defaults_overflow", + ], + params=[ + { + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([255.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "eval_args": { + "x": torch.tensor([255.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([240.]), + "scale": torch.tensor([16.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + } + }, + ],) + # yapf: enable + def 
trunc_int_quant_io_fixture(self, request): + yield request.param + + def test_trunc_int_quant_io(self, caplog, trunc_int_quant_io_fixture): + caplog.set_level(logging.INFO) + test_cfg = trunc_int_quant_io_fixture + init_args = test_cfg["init_args"] + train_args = test_cfg["train_args"] + eval_args = test_cfg["eval_args"] + expected_result = test_cfg["result"] + trunc_int_quant = TruncIntQuant(**init_args) + trunc_int_quant.train() + y = trunc_int_quant(**train_args) + trunc_int_quant.eval() + with torch.no_grad(): + y = trunc_int_quant(**eval_args) + for i, k in enumerate(expected_result.keys()): + assert torch.allclose(expected_result[k], y[i], rtol=0.0, atol=0.0, equal_nan=False), "Expected result[{k}]: {expected_result[k]}, result: {y[i]}" From 0c849e7dd3214c456d8d7b80106498dc882d213f Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 24 Jan 2025 17:49:22 +0000 Subject: [PATCH 14/32] test (trunc): Added signed overflow test --- tests/brevitas/core/test_trunc_int_quant.py | 30 +++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/core/test_trunc_int_quant.py b/tests/brevitas/core/test_trunc_int_quant.py index 71b508b97..5ac08ea99 100644 --- a/tests/brevitas/core/test_trunc_int_quant.py +++ b/tests/brevitas/core/test_trunc_int_quant.py @@ -31,7 +31,8 @@ def test_trunc_int_quant_defaults(self, bit_width_const): # yapf: disable @pytest_cases.fixture( ids=[ - "defaults_overflow", + "defaults_uint_overflow", + "defaults_int_overflow", ], params=[ { @@ -58,7 +59,32 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "scale": torch.tensor([16.]), "zero_point": torch.tensor([0.]), "bit_width": torch.tensor([4.]), - } + }, + }, { + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([127.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + 
"x": torch.tensor([127.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([112.]), + "scale": torch.tensor([16.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, }, ],) # yapf: enable From 7f76cd9a6a0199b29aa9b9c35a7829c7e138139c Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 24 Jan 2025 18:00:52 +0000 Subject: [PATCH 15/32] test (trunc): Added more unti tests. --- tests/brevitas/core/test_trunc_int_quant.py | 132 +++++++++++++++++++- 1 file changed, 131 insertions(+), 1 deletion(-) diff --git a/tests/brevitas/core/test_trunc_int_quant.py b/tests/brevitas/core/test_trunc_int_quant.py index 5ac08ea99..0ef625266 100644 --- a/tests/brevitas/core/test_trunc_int_quant.py +++ b/tests/brevitas/core/test_trunc_int_quant.py @@ -32,7 +32,12 @@ def test_trunc_int_quant_defaults(self, bit_width_const): @pytest_cases.fixture( ids=[ "defaults_uint_overflow", - "defaults_int_overflow", + "defaults_int+_overflow", + "defaults_int-_overflow", + "defaults_uint_underflow", + "defaults_int_underflow", + "defaults_uint_ulp", + "defaults_int_ulp", ], params=[ { @@ -85,6 +90,131 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "zero_point": torch.tensor([0.]), "bit_width": torch.tensor([4.]), }, + }, { + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([-128.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([-128.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([-128.]), + "scale": 
torch.tensor([16.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([8.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([8.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([0.]), + "scale": torch.tensor([16.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([-8.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([-8.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([0.]), + "scale": torch.tensor([16.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([9.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([9.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([16.]), + "scale": torch.tensor([16.]), + 
"zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([-9.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([-9.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([-16.]), + "scale": torch.tensor([16.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, }, ],) # yapf: enable From d2934a00a32b884fda380adfb20046fc321ecb7b Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 24 Jan 2025 19:04:50 +0000 Subject: [PATCH 16/32] fix (test/trunc): Bugfixes and tests. --- src/brevitas/core/scaling/standalone.py | 12 +- tests/brevitas/core/test_trunc_int_quant.py | 285 +++++++++++++++++++- 2 files changed, 278 insertions(+), 19 deletions(-) diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index e9c73d6bd..8be2b2cd4 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -483,10 +483,10 @@ def __init__(self) -> None: def forward( self, scaling_input: Tensor, - input_bitwidth: Tensor, - output_bitwidth: Tensor, + input_bit_width: Tensor, + output_bit_width: Tensor, signed: Union[bool, Tensor]) -> Tensor: - return 2 ** (input_bitwidth - output_bitwidth) + return 2 ** (input_bit_width - output_bit_width) class TruncScalingWrapper(brevitas.jit.ScriptModule): @@ -507,12 +507,12 @@ def __init__( def forward( self, scaling_input: Tensor, - input_bitwidth: Tensor, - output_bitwidth: Tensor, + input_bit_width: Tensor, + output_bit_width: Tensor, signed: Union[bool, Tensor]) -> Tensor: threshold = 
self.trunc_int_scaling_impl(output_bit_width, signed) scale = self.scaling_impl(scaling_input, threshold) - msb_scale = 2 ** (input_bitwidth - output_bitwidth) + msb_scale = 2 ** (input_bit_width - output_bit_width) unit_scale = torch.ones_like(msb_scale) max_scale = torch.where(msb_scale > unit_scale, msb_scale, unit_scale) min_scale = torch.where(msb_scale < unit_scale, msb_scale, unit_scale) diff --git a/tests/brevitas/core/test_trunc_int_quant.py b/tests/brevitas/core/test_trunc_int_quant.py index 0ef625266..ea386c0ae 100644 --- a/tests/brevitas/core/test_trunc_int_quant.py +++ b/tests/brevitas/core/test_trunc_int_quant.py @@ -7,10 +7,16 @@ import torch from brevitas.core.bit_width import BitWidthConst +from brevitas.core.function_wrapper import Identity from brevitas.core.function_wrapper import RoundSte from brevitas.core.function_wrapper import TensorClamp from brevitas.core.quant import TruncIntQuant from brevitas.core.scaling import TruncMsbScaling +from brevitas.core.scaling import TruncPowerOfTwoIntScaling +from brevitas.core.scaling import TruncScalingWrapper +from brevitas.core.scaling import RuntimeStatsScaling +from brevitas.core.stats import AbsMax +from brevitas.core.restrict_val import PowerOfTwoRestrictValue from tests.brevitas.core.bit_width_fixture import * # noqa from tests.brevitas.core.int_quant_fixture import * # noqa @@ -33,14 +39,21 @@ def test_trunc_int_quant_defaults(self, bit_width_const): ids=[ "defaults_uint_overflow", "defaults_int+_overflow", - "defaults_int-_overflow", + "defaults_int-_max", "defaults_uint_underflow", "defaults_int_underflow", "defaults_uint_ulp", "defaults_int_ulp", + "abxmax_uint_overflow", + "abxmax_int+_overflow", + "abxmax_int-_overflow", + "abxmax_uint_underflow", + "abxmax_int_underflow", + "abxmax_uint_ulp", + "abxmax_int_ulp", ], params=[ - { + { # defaults_uint_overflow "init_args": { "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), @@ -65,7 +78,7 @@ def 
test_trunc_int_quant_defaults(self, bit_width_const): "zero_point": torch.tensor([0.]), "bit_width": torch.tensor([4.]), }, - }, { + }, { # defaults_int+_overflow "init_args": { "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), @@ -90,7 +103,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "zero_point": torch.tensor([0.]), "bit_width": torch.tensor([4.]), }, - }, { + }, { # defaults_int-_max "init_args": { "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), @@ -115,7 +128,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "zero_point": torch.tensor([0.]), "bit_width": torch.tensor([4.]), }, - }, { + }, { # defaults_uint_underflow "init_args": { "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), @@ -125,14 +138,14 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "scale": torch.tensor([1.]), "zero_point": torch.tensor([0.]), "input_bit_width": torch.tensor([8.]), - "signed": True, + "signed": False, }, "eval_args": { "x": torch.tensor([8.]), "scale": torch.tensor([1.]), "zero_point": torch.tensor([0.]), "input_bit_width": torch.tensor([8.]), - "signed": True, + "signed": False, }, "result": { # Result needs to match the order of the output tuple "y": torch.tensor([0.]), @@ -140,7 +153,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "zero_point": torch.tensor([0.]), "bit_width": torch.tensor([4.]), }, - }, { + }, { # defaults_int_underflow "init_args": { "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), @@ -165,7 +178,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "zero_point": torch.tensor([0.]), "bit_width": torch.tensor([4.]), }, - }, { + }, { # defaults_uint_ulp "init_args": { "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), @@ -175,14 +188,14 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "scale": torch.tensor([1.]), "zero_point": torch.tensor([0.]), "input_bit_width": torch.tensor([8.]), - 
"signed": True, + "signed": False, }, "eval_args": { "x": torch.tensor([9.]), "scale": torch.tensor([1.]), "zero_point": torch.tensor([0.]), "input_bit_width": torch.tensor([8.]), - "signed": True, + "signed": False, }, "result": { # Result needs to match the order of the output tuple "y": torch.tensor([16.]), @@ -190,7 +203,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "zero_point": torch.tensor([0.]), "bit_width": torch.tensor([4.]), }, - }, { + }, { # defaults_int_ulp "init_args": { "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), @@ -215,7 +228,253 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "zero_point": torch.tensor([0.]), "bit_width": torch.tensor([4.]), }, + }, { # abxmax_uint_overflow + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + "trunc_scaling_impl": TruncScalingWrapper( + trunc_int_scaling_impl=TruncPowerOfTwoIntScaling(), + scaling_impl=RuntimeStatsScaling( + scaling_stats_impl=AbsMax(), + scaling_stats_input_view_shape_impl=Identity(), + scaling_shape=(1,), + scaling_stats_momentum=1.0, + restrict_scaling_impl=PowerOfTwoRestrictValue(), + ) + ), + }, + "train_args": { + "x": torch.tensor([128.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "eval_args": { + "x": torch.tensor([255.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([120.]), + "scale": torch.tensor([8.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # abxmax_int+_overflow + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + "trunc_scaling_impl": TruncScalingWrapper( + trunc_int_scaling_impl=TruncPowerOfTwoIntScaling(), + scaling_impl=RuntimeStatsScaling( + 
scaling_stats_impl=AbsMax(), + scaling_stats_input_view_shape_impl=Identity(), + scaling_shape=(1,), + scaling_stats_momentum=1.0, + restrict_scaling_impl=PowerOfTwoRestrictValue(), + ) + ), + }, + "train_args": { + "x": torch.tensor([32.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([64.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([28.]), + "scale": torch.tensor([4.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # abxmax_int-_overflow + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + "trunc_scaling_impl": TruncScalingWrapper( + trunc_int_scaling_impl=TruncPowerOfTwoIntScaling(), + scaling_impl=RuntimeStatsScaling( + scaling_stats_impl=AbsMax(), + scaling_stats_input_view_shape_impl=Identity(), + scaling_shape=(1,), + scaling_stats_momentum=1.0, + restrict_scaling_impl=PowerOfTwoRestrictValue(), + ) + ), + }, + "train_args": { + "x": torch.tensor([-16.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([-32.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([-16.]), + "scale": torch.tensor([2.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # abxmax_uint_underflow + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + "trunc_scaling_impl": TruncScalingWrapper( + trunc_int_scaling_impl=TruncPowerOfTwoIntScaling(), + 
scaling_impl=RuntimeStatsScaling( + scaling_stats_impl=AbsMax(), + scaling_stats_input_view_shape_impl=Identity(), + scaling_shape=(1,), + scaling_stats_momentum=1.0, + restrict_scaling_impl=PowerOfTwoRestrictValue(), + ) + ), + }, + "train_args": { + "x": torch.tensor([15.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "eval_args": { + "x": torch.tensor([.5]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([0.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # abxmax_int_underflow + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + "trunc_scaling_impl": TruncScalingWrapper( + trunc_int_scaling_impl=TruncPowerOfTwoIntScaling(), + scaling_impl=RuntimeStatsScaling( + scaling_stats_impl=AbsMax(), + scaling_stats_input_view_shape_impl=Identity(), + scaling_shape=(1,), + scaling_stats_momentum=1.0, + restrict_scaling_impl=PowerOfTwoRestrictValue(), + ) + ), + }, + "train_args": { + "x": torch.tensor([-8.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([-.5]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([0.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # abxmax_uint_ulp + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + "trunc_scaling_impl": TruncScalingWrapper( + 
trunc_int_scaling_impl=TruncPowerOfTwoIntScaling(), + scaling_impl=RuntimeStatsScaling( + scaling_stats_impl=AbsMax(), + scaling_stats_input_view_shape_impl=Identity(), + scaling_shape=(1,), + scaling_stats_momentum=1.0, + restrict_scaling_impl=PowerOfTwoRestrictValue(), + ) + ), + }, + "train_args": { + "x": torch.tensor([31.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "eval_args": { + "x": torch.tensor([2.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": False, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([2.]), + "scale": torch.tensor([2.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, + }, { # abxmax_int_ulp + "init_args": { + "bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + "trunc_scaling_impl": TruncScalingWrapper( + trunc_int_scaling_impl=TruncPowerOfTwoIntScaling(), + scaling_impl=RuntimeStatsScaling( + scaling_stats_impl=AbsMax(), + scaling_stats_input_view_shape_impl=Identity(), + scaling_shape=(1,), + scaling_stats_momentum=1.0, + restrict_scaling_impl=PowerOfTwoRestrictValue(), + ) + ), + }, + "train_args": { + "x": torch.tensor([-64.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([-5.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([0.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([-8.]), + "scale": torch.tensor([8.]), + "zero_point": torch.tensor([0.]), + "bit_width": torch.tensor([4.]), + }, }, + ],) # yapf: enable def trunc_int_quant_io_fixture(self, request): @@ -235,4 +494,4 @@ def test_trunc_int_quant_io(self, caplog, 
trunc_int_quant_io_fixture): with torch.no_grad(): y = trunc_int_quant(**eval_args) for i, k in enumerate(expected_result.keys()): - assert torch.allclose(expected_result[k], y[i], rtol=0.0, atol=0.0, equal_nan=False), "Expected result[{k}]: {expected_result[k]}, result: {y[i]}" + assert torch.allclose(expected_result[k], y[i], rtol=0.0, atol=0.0, equal_nan=False), f"Expected result[{k}]: {expected_result[k]}, result: {y[i]}" From 7d6c934f72a3a4304b4910f423a343f425925da3 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Mon, 27 Jan 2025 08:55:19 +0000 Subject: [PATCH 17/32] Fix: precommit --- tests/brevitas/core/test_trunc_int_quant.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/core/test_trunc_int_quant.py b/tests/brevitas/core/test_trunc_int_quant.py index ea386c0ae..a064b0676 100644 --- a/tests/brevitas/core/test_trunc_int_quant.py +++ b/tests/brevitas/core/test_trunc_int_quant.py @@ -11,12 +11,12 @@ from brevitas.core.function_wrapper import RoundSte from brevitas.core.function_wrapper import TensorClamp from brevitas.core.quant import TruncIntQuant +from brevitas.core.restrict_val import PowerOfTwoRestrictValue +from brevitas.core.scaling import RuntimeStatsScaling from brevitas.core.scaling import TruncMsbScaling from brevitas.core.scaling import TruncPowerOfTwoIntScaling from brevitas.core.scaling import TruncScalingWrapper -from brevitas.core.scaling import RuntimeStatsScaling from brevitas.core.stats import AbsMax -from brevitas.core.restrict_val import PowerOfTwoRestrictValue from tests.brevitas.core.bit_width_fixture import * # noqa from tests.brevitas.core.int_quant_fixture import * # noqa From b121f5dffa7b75c610991a4919c596235d1e205a Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Mon, 27 Jan 2025 11:16:05 +0000 Subject: [PATCH 18/32] Fix (solver/trunc): Added a ShiftRoundSaturate quantizer and update trunc solver --- src/brevitas/inject/enum.py | 8 ++++++ src/brevitas/quant/scaled_int.py | 25 ++++++++++++++-- 
src/brevitas/quant/solver/trunc.py | 46 ++++++++++++++++++++++++++++-- 3 files changed, 74 insertions(+), 5 deletions(-) diff --git a/src/brevitas/inject/enum.py b/src/brevitas/inject/enum.py index fbac29176..c0990d717 100644 --- a/src/brevitas/inject/enum.py +++ b/src/brevitas/inject/enum.py @@ -93,3 +93,11 @@ class StatsOp(AutoName): # Typically adopted for asymmetric quantization MIN_MAX = auto() PERCENTILE_INTERVAL = auto() + + +class TruncScalingImplType(AutoName): + """ + + """ + MSB = auto() + WRAPPER = auto() diff --git a/src/brevitas/quant/scaled_int.py b/src/brevitas/quant/scaled_int.py index b5f9174d7..cbe681f5e 100644 --- a/src/brevitas/quant/scaled_int.py +++ b/src/brevitas/quant/scaled_int.py @@ -28,6 +28,7 @@ 'Int8WeightPerChannelFloatMSE', 'TruncTo8bit', 'RoundTo8bit', + 'ShiftRoundSaturateTo8bit', 'Int4WeightPerTensorFloatDecoupled', 'Int8WeightPerChannelFloatDecoupled', 'Uint8ActPerTensorFloatBatchQuant1d', @@ -261,7 +262,7 @@ class Uint8ActPerTensorFloatMSE(MSESymmetricScale, Uint8ActPerTensorFloat): class TruncTo8bit(TruncQuantSolver): """ - 8-bit signed int truncator that preserves the input scale factor and zero-point. + 8-bit int truncator that preserves most-significant bits and zero-point. Examples: >>> from brevitas.nn import TruncAvgPool2d @@ -271,11 +272,12 @@ class TruncTo8bit(TruncQuantSolver): quant_type = 'int' bit_width_impl_type = 'const' float_to_int_impl_type = 'floor' + trunc_scaling_impl_type = 'msb' class RoundTo8bit(TruncQuantSolver): """ - 8-bit signed int truncator with rounding that preserves the input scale factor and zero-point. + 8-bit int truncator with rounding that preserves most-significant bits and zero-point. 
Examples: >>> from brevitas.nn import TruncAvgPool2d @@ -285,6 +287,25 @@ class RoundTo8bit(TruncQuantSolver): quant_type = 'int' bit_width_impl_type = 'const' float_to_int_impl_type = 'round' + trunc_scaling_impl_type = 'msb' + + +class ShiftRoundSaturateTo8bit(TruncQuantSolver, + ParamFromRuntimePercentileScaling, + PerTensorPoTScaling8bit): + """ + 8-bit shift-round-saturate quantizer which uses statistics to calculate the amount of truncation + the lest-significant bits and most-significant bits. Zero-point is preserved. + + Examples: + >>> from brevitas.nn import TruncAvgPool2d + >>> pool = TruncAvgPool2d(kernel_size=(3, 3), trunc_quant=ShiftRoundSaturateTo8bit) + """ + bit_width = 8 + quant_type = 'int' + bit_width_impl_type = 'const' + float_to_int_impl_type = 'round' + trunc_scaling_impl_type = 'wrapper' class Int4WeightPerTensorFloatDecoupled(WeightPerTensorFloatDecoupledL2Param): diff --git a/src/brevitas/quant/solver/trunc.py b/src/brevitas/quant/solver/trunc.py index 9203e2f11..658cb682d 100644 --- a/src/brevitas/quant/solver/trunc.py +++ b/src/brevitas/quant/solver/trunc.py @@ -2,12 +2,17 @@ # SPDX-License-Identifier: BSD-3-Clause from brevitas.core.quant import TruncIntQuant +from brevitas.core.scaling import TruncMsbScaling +from brevitas.core.scaling import TruncPowerOfTwoIntScaling +from brevitas.core.scaling import TruncScalingWrapper from brevitas.inject import ExtendedInjector from brevitas.inject import value from brevitas.inject.enum import QuantType +from brevitas.inject.enum import RestrictValueType +from brevitas.inject.enum import TruncScalingImplType from brevitas.proxy import TruncQuantProxyFromInjector -from brevitas.quant.solver.common import SolveBitWidthImplFromEnum -from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum +from brevitas.quant.solver.act import * +from brevitas.quant.solver.common import * class SolveTruncTensorQuantFromEnum(ExtendedInjector): @@ -26,9 +31,44 @@ def tensor_quant(quant_type): 
raise RuntimeError(f'{quant_type} not recognized.') +class SolveTruncScalingImplFromEnum(ExtendedInjector): + + @value + def trunc_scaling_impl(trunc_scaling_impl_type="msb"): + if trunc_scaling_impl_type == TruncScalingImplType.MSB: + return TruncMsbScaling + elif trunc_scaling_impl_type == TruncScalingImplType.WRAPPER: + return TruncScalingWrapper + else: + raise RuntimeError(f'trunc_scaling_impl_type={trunc_scaling_impl_type} not recognized.') + + +class SolveTruncIntScalingImplFromEnum(ExtendedInjector): + + @value + def trunc_int_scaling_impl(restrict_scaling_type): + if restrict_scaling_type == RestrictValueType.POWER_OF_TWO: + return TruncPowerOfTwoIntScaling + else: + raise RuntimeError(f'restrict_scaling_type={restrict_scaling_type} not recognized.') + + class TruncQuantSolver(SolveBitWidthImplFromEnum, SolveTensorQuantFloatToIntImplFromEnum, - SolveTruncTensorQuantFromEnum): + SolveActScalingImplFromEnum, + SolveIntScalingImplFromEnum, + SolveScalingStatsOpFromEnum, + SolveRestrictScalingImplFromEnum, + SolveActScalingInitFromEnum, + SolveStatsReduceDimFromEnum, + SolveActScalingShape, + SolveScalingStatsInputViewShapeImplFromEnum, + SolveActScalingPerOutputChannelShape, + SolveUpdateStateDictImplFromEnum, + SolveInputViewImpl, + SolveTruncTensorQuantFromEnum, + SolveTruncScalingImplFromEnum, + SolveTruncIntScalingImplFromEnum): """ Translate enum directives to truncation-specific quantization core modules. It should be placed last in the list of classes a quantizer inherits from, From 0a00e8faa63597b5913cfe98a3f0e334b5db9611 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Mon, 27 Jan 2025 12:14:00 +0000 Subject: [PATCH 19/32] Fix (export/trunc): Updated export to generate Quant node. 
--- src/brevitas/export/onnx/qonnx/function.py | 7 ++++--- src/brevitas/export/onnx/qonnx/handler.py | 3 ++- src/brevitas/proxy/runtime_quant.py | 21 +++++++++++++++++++-- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/brevitas/export/onnx/qonnx/function.py b/src/brevitas/export/onnx/qonnx/function.py index dc32cece9..0fbeba8f6 100644 --- a/src/brevitas/export/onnx/qonnx/function.py +++ b/src/brevitas/export/onnx/qonnx/function.py @@ -119,14 +119,14 @@ def symbolic( input_bit_width, signed, narrow_range, + output_scale, output_bit_width, rounding_mode): ret = g.op( - f'{DOMAIN_STRING}::Trunc', + f'{DOMAIN_STRING}::Quant', x, - scale, + output_scale, zero_point, - input_bit_width, output_bit_width, rounding_mode_s=rounding_mode, signed_i=int(signed), @@ -143,6 +143,7 @@ def forward( input_bit_width, signed, narrow_range, + output_scale, output_bit_width, rounding_mode): # TODO: Restore this (fails when `signed` arg added) diff --git a/src/brevitas/export/onnx/qonnx/handler.py b/src/brevitas/export/onnx/qonnx/handler.py index 194e67e6a..961b15761 100644 --- a/src/brevitas/export/onnx/qonnx/handler.py +++ b/src/brevitas/export/onnx/qonnx/handler.py @@ -230,6 +230,7 @@ class BrevitasTruncQuantProxyHandler(ONNXBaseHandler): def prepare_for_export(self, module: TruncQuantProxyFromInjector): self.symbolic_kwargs = { 'narrow_range': module.is_narrow_range, + 'output_scale': module.scale(), 'output_bit_width': module.bit_width(), 'rounding_mode': module.rounding_mode} @@ -238,7 +239,7 @@ def symbolic_execution( signed: Tensor): y = BrevitasTruncFn.apply( x, scale, zero_point, input_bit_width, signed, *self.symbolic_kwargs.values()) - return y, scale, zero_point, self.symbolic_kwargs['output_bit_width'] + return y, self.symbolic_kwargs['output_scale'], zero_point, self.symbolic_kwargs['output_bit_width'] class BrevitasQuantLSTMLayerHandler(QuantLSTMLayerHandler): diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 
c5d5f68fa..c8bc92977 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -60,7 +60,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor: @runtime_checkable class AccQuantProxyProtocol(QuantProxyProtocol, Protocol): - def forward(self, x: QuantTensor) -> QuantTensor: + def forward(self, x: Union[Tensor, IntQuantTensor]) -> Union[Tensor, IntQuantTensor]: ... @@ -262,13 +262,26 @@ class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self._cached_act = None + self.cache_inference_quant_act = True + self.cache_quant_io_metadata_only = True + self.cache_class = _CachedIO self.skip_create_quant_tensor = False + def retrieve_attribute(self, attribute): + if self._cached_act is not None: + return getattr(self._cached_act, attribute) + elif self._cached_act is None: + return None + @property def is_narrow_range(self): narrow_range = super(TruncQuantProxyFromInjector, self).is_narrow_range return narrow_range if narrow_range is not None else False + def scale(self): + return self.retrieve_attribute('scale') + def bit_width(self): if not self.is_quant_enabled: return None @@ -288,8 +301,12 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: out_value, out_scale, out_zp, out_bit_width = out_tuple if self.skip_create_quant_tensor: return out_value - return IntQuantTensor( + out = IntQuantTensor( out_value, out_scale, out_zp, out_bit_width, x.signed, self.training) + if not self.training and self.cache_inference_quant_act: + cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only) + self._cached_act = cached_out + return out else: return x From 80f57e3df7e8c0def9619df7918b6e69ad01c4c6 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Mon, 27 Jan 2025 17:53:38 +0000 Subject: [PATCH 20/32] Fix (test/qonnx/trunc): Allow off-by-1 errors in test --- src/brevitas/nn/quant_avg_pool.py | 5 
++++- src/brevitas/quant_tensor/int_torch_handler.py | 4 +++- .../brevitas/test_brevitas_avg_pool_export.py | 8 ++++++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index 1bd284eda..590ddcebb 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -33,10 +33,13 @@ def __init__( self, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = None, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: Optional[int] = None, trunc_quant: Optional[AccQuantType] = RoundTo8bit, return_quant_tensor: bool = True, **kwargs): - AvgPool2d.__init__(self, kernel_size=kernel_size, stride=stride) + AvgPool2d.__init__(self, kernel_size=kernel_size, stride=stride, ceil_mode=ceil_mode, count_include_pad=count_include_pad, divisor_override=divisor_override) QuantLayerMixin.__init__(self, return_quant_tensor) TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs) self.cache_inference_quant_act = False diff --git a/src/brevitas/quant_tensor/int_torch_handler.py b/src/brevitas/quant_tensor/int_torch_handler.py index 691138057..a2e6572da 100644 --- a/src/brevitas/quant_tensor/int_torch_handler.py +++ b/src/brevitas/quant_tensor/int_torch_handler.py @@ -109,7 +109,9 @@ def avg_pool2d_handler( max_acc_bit_width = FN_ACC_BITWIDTH_MAPPING[F.avg_pool2d] # remove avg scaling - if isinstance(kernel_size, tuple): + if divisor_override is not None: + avg_scaling = divisor_override + elif isinstance(kernel_size, tuple): avg_scaling = kernel_size[0] * kernel_size[1] else: avg_scaling = kernel_size * kernel_size diff --git a/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py index aa0d28faf..db082c124 100644 --- a/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py @@ -42,11 
+42,14 @@ def test_brevitas_avg_pool_export( quant_avgpool = TruncAvgPool2d( kernel_size=kernel_size, stride=stride, bit_width=bit_width, float_to_int_impl_type='floor') model_brevitas = torch.nn.Sequential(quant_node, quant_avgpool) - model_brevitas.eval() # determine input input_shape = (1, channels, idim, idim) inp = torch.randn(input_shape) + model_brevitas.train() + model_brevitas(inp) + model_brevitas.eval() + model_brevitas(inp) # export test_id = request.node.callspec.id @@ -63,6 +66,7 @@ def test_brevitas_avg_pool_export( odict = oxe.execute_onnx(model, idict, True) finn_output = odict[model.graph.output[0].name] # compare outputs - assert np.isclose(ref_output_array, finn_output).all() + scale = quant_avgpool.trunc_quant.scale().detach().numpy() # Allow "off-by-1" errors + assert np.isclose(ref_output_array, finn_output, atol=scale).all() # cleanup os.remove(export_path) From 1af6840764ac187c761adf44738819eec88b68e5 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 28 Jan 2025 11:24:14 +0000 Subject: [PATCH 21/32] tests (brv_finn/avgpool): Add "lossless" tests --- src/brevitas/nn/quant_avg_pool.py | 8 +++++++- .../brevitas/test_brevitas_avg_pool_export.py | 20 ++++++++++++++++--- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index 590ddcebb..795bd0cf5 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -39,7 +39,13 @@ def __init__( trunc_quant: Optional[AccQuantType] = RoundTo8bit, return_quant_tensor: bool = True, **kwargs): - AvgPool2d.__init__(self, kernel_size=kernel_size, stride=stride, ceil_mode=ceil_mode, count_include_pad=count_include_pad, divisor_override=divisor_override) + AvgPool2d.__init__( + self, + kernel_size=kernel_size, + stride=stride, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + divisor_override=divisor_override) QuantLayerMixin.__init__(self, return_quant_tensor) TruncMixin.__init__(self, 
trunc_quant=trunc_quant, **kwargs) self.cache_inference_quant_act = False diff --git a/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py index db082c124..a9c851d7e 100644 --- a/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py @@ -26,16 +26,27 @@ @pytest.mark.parametrize("input_bit_width", [4, 8, 16]) @pytest.mark.parametrize("channels", [2, 4]) @pytest.mark.parametrize("idim", [7, 8]) +@pytest.mark.parametrize("restrict_scaling_type", ["log_fp", "power_of_two"]) def test_brevitas_avg_pool_export( - kernel_size, stride, signed, bit_width, input_bit_width, channels, idim, request): + kernel_size, + stride, + signed, + bit_width, + input_bit_width, + channels, + idim, + restrict_scaling_type, + request): if signed: quant_node = QuantIdentity( bit_width=input_bit_width, + restrict_scaling_type=restrict_scaling_type, return_quant_tensor=True, ) else: quant_node = QuantReLU( bit_width=input_bit_width, + restrict_scaling_type=restrict_scaling_type, return_quant_tensor=True, ) @@ -66,7 +77,10 @@ def test_brevitas_avg_pool_export( odict = oxe.execute_onnx(model, idict, True) finn_output = odict[model.graph.output[0].name] # compare outputs - scale = quant_avgpool.trunc_quant.scale().detach().numpy() # Allow "off-by-1" errors - assert np.isclose(ref_output_array, finn_output, atol=scale).all() + if restrict_scaling_type == "power_of_two" and kernel_size == 2: + atol = 1e-8 + else: + atol = quant_avgpool.trunc_quant.scale().detach().numpy() # Allow "off-by-1" errors + assert np.isclose(ref_output_array, finn_output, atol=atol).all() # cleanup os.remove(export_path) From 20a06bd99e64d8893478ba075d3f91dcf870c018 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 28 Jan 2025 11:48:29 +0000 Subject: [PATCH 22/32] Fix (brevitas/scaling): TruncPowerOfTwoIntScaling -> PowerOfTwoIntScaling --- 
src/brevitas/core/scaling/__init__.py | 1 - src/brevitas/core/scaling/int_scaling.py | 26 +++++++-------------- src/brevitas/quant/scaled_int.py | 1 + src/brevitas/quant/solver/trunc.py | 4 ++-- tests/brevitas/core/test_trunc_int_quant.py | 16 ++++++------- 5 files changed, 20 insertions(+), 28 deletions(-) diff --git a/src/brevitas/core/scaling/__init__.py b/src/brevitas/core/scaling/__init__.py index 18e0f08b9..d8a786c8b 100644 --- a/src/brevitas/core/scaling/__init__.py +++ b/src/brevitas/core/scaling/__init__.py @@ -9,7 +9,6 @@ from .float_scaling import FloatScaling from .int_scaling import IntScaling from .int_scaling import PowerOfTwoIntScaling -from .int_scaling import TruncPowerOfTwoIntScaling from .pre_scaling import AccumulatorAwareParameterPreScaling from .pre_scaling import AccumulatorAwareZeroCenterParameterPreScaling from .pre_scaling import ParameterPreScalingWeightNorm diff --git a/src/brevitas/core/scaling/int_scaling.py b/src/brevitas/core/scaling/int_scaling.py index 07e378813..f78519da7 100644 --- a/src/brevitas/core/scaling/int_scaling.py +++ b/src/brevitas/core/scaling/int_scaling.py @@ -1,7 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Union +from typing import Optional, Union from torch import Tensor @@ -19,11 +19,12 @@ def __init__(self, signed: bool, narrow_range: bool): self.narrow_range = narrow_range @brevitas.jit.script_method - def forward(self, bit_width: Tensor) -> Tensor: - if self.signed: - return -min_int(self.signed, self.narrow_range, bit_width) + def forward(self, bit_width: Tensor, signed: Optional[Union[bool, Tensor]] = None) -> Tensor: + is_signed = signed if signed is not None else self.signed + if is_signed: + return -min_int(is_signed, self.narrow_range, bit_width) else: - return max_int(self.signed, self.narrow_range, bit_width) + return max_int(is_signed, self.narrow_range, bit_width) class PowerOfTwoIntScaling(brevitas.jit.ScriptModule): @@ -34,15 +35,6 @@ def __init__(self, signed: bool): self.signed = signed @brevitas.jit.script_method - def forward(self, bit_width: Tensor) -> Tensor: - return max_int(self.signed, False, bit_width) + 1 - - -class TruncPowerOfTwoIntScaling(brevitas.jit.ScriptModule): - - def __init__(self): - super(TruncPowerOfTwoIntScaling, self).__init__() - - @brevitas.jit.script_method - def forward(self, bit_width: Tensor, signed: Union[bool, Tensor]) -> Tensor: - return max_int(signed, False, bit_width) + 1 + def forward(self, bit_width: Tensor, signed: Optional[Union[bool, Tensor]] = None) -> Tensor: + is_signed = signed if signed is not None else self.signed + return max_int(is_signed, False, bit_width) + 1 diff --git a/src/brevitas/quant/scaled_int.py b/src/brevitas/quant/scaled_int.py index cbe681f5e..0148834a6 100644 --- a/src/brevitas/quant/scaled_int.py +++ b/src/brevitas/quant/scaled_int.py @@ -306,6 +306,7 @@ class ShiftRoundSaturateTo8bit(TruncQuantSolver, bit_width_impl_type = 'const' float_to_int_impl_type = 'round' trunc_scaling_impl_type = 'wrapper' + signed = True # Ignored class Int4WeightPerTensorFloatDecoupled(WeightPerTensorFloatDecoupledL2Param): diff --git 
a/src/brevitas/quant/solver/trunc.py b/src/brevitas/quant/solver/trunc.py index 658cb682d..c167826e8 100644 --- a/src/brevitas/quant/solver/trunc.py +++ b/src/brevitas/quant/solver/trunc.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: BSD-3-Clause from brevitas.core.quant import TruncIntQuant +from brevitas.core.scaling import PowerOfTwoIntScaling from brevitas.core.scaling import TruncMsbScaling -from brevitas.core.scaling import TruncPowerOfTwoIntScaling from brevitas.core.scaling import TruncScalingWrapper from brevitas.inject import ExtendedInjector from brevitas.inject import value @@ -48,7 +48,7 @@ class SolveTruncIntScalingImplFromEnum(ExtendedInjector): @value def trunc_int_scaling_impl(restrict_scaling_type): if restrict_scaling_type == RestrictValueType.POWER_OF_TWO: - return TruncPowerOfTwoIntScaling + return PowerOfTwoIntScaling else: raise RuntimeError(f'restrict_scaling_type={restrict_scaling_type} not recognized.') diff --git a/tests/brevitas/core/test_trunc_int_quant.py b/tests/brevitas/core/test_trunc_int_quant.py index a064b0676..3989c7f6e 100644 --- a/tests/brevitas/core/test_trunc_int_quant.py +++ b/tests/brevitas/core/test_trunc_int_quant.py @@ -12,9 +12,9 @@ from brevitas.core.function_wrapper import TensorClamp from brevitas.core.quant import TruncIntQuant from brevitas.core.restrict_val import PowerOfTwoRestrictValue +from brevitas.core.scaling import PowerOfTwoIntScaling from brevitas.core.scaling import RuntimeStatsScaling from brevitas.core.scaling import TruncMsbScaling -from brevitas.core.scaling import TruncPowerOfTwoIntScaling from brevitas.core.scaling import TruncScalingWrapper from brevitas.core.stats import AbsMax from tests.brevitas.core.bit_width_fixture import * # noqa @@ -233,7 +233,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), "trunc_scaling_impl": TruncScalingWrapper( - trunc_int_scaling_impl=TruncPowerOfTwoIntScaling(), + 
trunc_int_scaling_impl=PowerOfTwoIntScaling(signed=True), scaling_impl=RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), @@ -268,7 +268,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), "trunc_scaling_impl": TruncScalingWrapper( - trunc_int_scaling_impl=TruncPowerOfTwoIntScaling(), + trunc_int_scaling_impl=PowerOfTwoIntScaling(signed=True), scaling_impl=RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), @@ -303,7 +303,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), "trunc_scaling_impl": TruncScalingWrapper( - trunc_int_scaling_impl=TruncPowerOfTwoIntScaling(), + trunc_int_scaling_impl=PowerOfTwoIntScaling(signed=True), scaling_impl=RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), @@ -338,7 +338,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), "trunc_scaling_impl": TruncScalingWrapper( - trunc_int_scaling_impl=TruncPowerOfTwoIntScaling(), + trunc_int_scaling_impl=PowerOfTwoIntScaling(signed=True), scaling_impl=RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), @@ -373,7 +373,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), "trunc_scaling_impl": TruncScalingWrapper( - trunc_int_scaling_impl=TruncPowerOfTwoIntScaling(), + trunc_int_scaling_impl=PowerOfTwoIntScaling(signed=True), scaling_impl=RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), @@ -408,7 +408,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), "trunc_scaling_impl": 
TruncScalingWrapper( - trunc_int_scaling_impl=TruncPowerOfTwoIntScaling(), + trunc_int_scaling_impl=PowerOfTwoIntScaling(signed=True), scaling_impl=RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), @@ -443,7 +443,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), "trunc_scaling_impl": TruncScalingWrapper( - trunc_int_scaling_impl=TruncPowerOfTwoIntScaling(), + trunc_int_scaling_impl=PowerOfTwoIntScaling(signed=True), scaling_impl=RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), From 54e86a02bcdccaa4488fca3690b6bc9af821b0c2 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 28 Jan 2025 12:04:00 +0000 Subject: [PATCH 23/32] Fix (scaling): Made signed an optional argument at init time. --- src/brevitas/core/scaling/int_scaling.py | 6 ++++-- src/brevitas/quant/scaled_int.py | 1 - tests/brevitas/core/test_trunc_int_quant.py | 15 +++++++-------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/brevitas/core/scaling/int_scaling.py b/src/brevitas/core/scaling/int_scaling.py index f78519da7..100515571 100644 --- a/src/brevitas/core/scaling/int_scaling.py +++ b/src/brevitas/core/scaling/int_scaling.py @@ -13,7 +13,7 @@ class IntScaling(brevitas.jit.ScriptModule): __constants__ = ['signed', 'narrow_range'] - def __init__(self, signed: bool, narrow_range: bool): + def __init__(self, narrow_range: bool, signed: Optional[bool] = None): super(IntScaling, self).__init__() self.signed = signed self.narrow_range = narrow_range @@ -21,6 +21,7 @@ def __init__(self, signed: bool, narrow_range: bool): @brevitas.jit.script_method def forward(self, bit_width: Tensor, signed: Optional[Union[bool, Tensor]] = None) -> Tensor: is_signed = signed if signed is not None else self.signed + assert is_signed is not None, f"signed is not defined, signed={is_signed}" if is_signed: return 
-min_int(is_signed, self.narrow_range, bit_width) else: @@ -30,11 +31,12 @@ def forward(self, bit_width: Tensor, signed: Optional[Union[bool, Tensor]] = Non class PowerOfTwoIntScaling(brevitas.jit.ScriptModule): __constants__ = ['signed'] - def __init__(self, signed: bool): + def __init__(self, signed: Optional[bool] = None): super(PowerOfTwoIntScaling, self).__init__() self.signed = signed @brevitas.jit.script_method def forward(self, bit_width: Tensor, signed: Optional[Union[bool, Tensor]] = None) -> Tensor: is_signed = signed if signed is not None else self.signed + assert is_signed is not None, f"signed is not defined, signed={is_signed}" return max_int(is_signed, False, bit_width) + 1 diff --git a/src/brevitas/quant/scaled_int.py b/src/brevitas/quant/scaled_int.py index 0148834a6..cbe681f5e 100644 --- a/src/brevitas/quant/scaled_int.py +++ b/src/brevitas/quant/scaled_int.py @@ -306,7 +306,6 @@ class ShiftRoundSaturateTo8bit(TruncQuantSolver, bit_width_impl_type = 'const' float_to_int_impl_type = 'round' trunc_scaling_impl_type = 'wrapper' - signed = True # Ignored class Int4WeightPerTensorFloatDecoupled(WeightPerTensorFloatDecoupledL2Param): diff --git a/tests/brevitas/core/test_trunc_int_quant.py b/tests/brevitas/core/test_trunc_int_quant.py index 3989c7f6e..579563a7a 100644 --- a/tests/brevitas/core/test_trunc_int_quant.py +++ b/tests/brevitas/core/test_trunc_int_quant.py @@ -233,7 +233,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), "trunc_scaling_impl": TruncScalingWrapper( - trunc_int_scaling_impl=PowerOfTwoIntScaling(signed=True), + trunc_int_scaling_impl=PowerOfTwoIntScaling(), scaling_impl=RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), @@ -268,7 +268,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), "trunc_scaling_impl": 
TruncScalingWrapper( - trunc_int_scaling_impl=PowerOfTwoIntScaling(signed=True), + trunc_int_scaling_impl=PowerOfTwoIntScaling(), scaling_impl=RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), @@ -303,7 +303,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), "trunc_scaling_impl": TruncScalingWrapper( - trunc_int_scaling_impl=PowerOfTwoIntScaling(signed=True), + trunc_int_scaling_impl=PowerOfTwoIntScaling(), scaling_impl=RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), @@ -338,7 +338,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), "trunc_scaling_impl": TruncScalingWrapper( - trunc_int_scaling_impl=PowerOfTwoIntScaling(signed=True), + trunc_int_scaling_impl=PowerOfTwoIntScaling(), scaling_impl=RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), @@ -373,7 +373,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), "trunc_scaling_impl": TruncScalingWrapper( - trunc_int_scaling_impl=PowerOfTwoIntScaling(signed=True), + trunc_int_scaling_impl=PowerOfTwoIntScaling(), scaling_impl=RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), @@ -408,7 +408,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width_impl": BitWidthConst(4), "float_to_int_impl": RoundSte(), "trunc_scaling_impl": TruncScalingWrapper( - trunc_int_scaling_impl=PowerOfTwoIntScaling(signed=True), + trunc_int_scaling_impl=PowerOfTwoIntScaling(), scaling_impl=RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), @@ -443,7 +443,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width_impl": BitWidthConst(4), 
"float_to_int_impl": RoundSte(), "trunc_scaling_impl": TruncScalingWrapper( - trunc_int_scaling_impl=PowerOfTwoIntScaling(signed=True), + trunc_int_scaling_impl=PowerOfTwoIntScaling(), scaling_impl=RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), @@ -474,7 +474,6 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width": torch.tensor([4.]), }, }, - ],) # yapf: enable def trunc_int_quant_io_fixture(self, request): From 756482a8fd704a23c891e376abf32fc9fb6ba0c6 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 28 Jan 2025 12:11:54 +0000 Subject: [PATCH 24/32] test (trunc_quant): Switched to pytest_cases.parametrize --- tests/brevitas/core/test_trunc_int_quant.py | 42 +++++++++++---------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/tests/brevitas/core/test_trunc_int_quant.py b/tests/brevitas/core/test_trunc_int_quant.py index 579563a7a..186563686 100644 --- a/tests/brevitas/core/test_trunc_int_quant.py +++ b/tests/brevitas/core/test_trunc_int_quant.py @@ -35,24 +35,10 @@ def test_trunc_int_quant_defaults(self, bit_width_const): assert trunc_int_quant.narrow_range == False # yapf: disable - @pytest_cases.fixture( - ids=[ - "defaults_uint_overflow", - "defaults_int+_overflow", - "defaults_int-_max", - "defaults_uint_underflow", - "defaults_int_underflow", - "defaults_uint_ulp", - "defaults_int_ulp", - "abxmax_uint_overflow", - "abxmax_int+_overflow", - "abxmax_int-_overflow", - "abxmax_uint_underflow", - "abxmax_int_underflow", - "abxmax_uint_ulp", - "abxmax_int_ulp", - ], - params=[ + @pytest_cases.fixture + @pytest_cases.parametrize( + "test_cfg", + [ { # defaults_uint_overflow "init_args": { "bit_width_impl": BitWidthConst(4), @@ -474,10 +460,26 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "bit_width": torch.tensor([4.]), }, }, + ], + ids=[ + "defaults_uint_overflow", + "defaults_int+_overflow", + "defaults_int-_max", + "defaults_uint_underflow", + 
"defaults_int_underflow", + "defaults_uint_ulp", + "defaults_int_ulp", + "abxmax_uint_overflow", + "abxmax_int+_overflow", + "abxmax_int-_overflow", + "abxmax_uint_underflow", + "abxmax_int_underflow", + "abxmax_uint_ulp", + "abxmax_int_ulp", ],) # yapf: enable - def trunc_int_quant_io_fixture(self, request): - yield request.param + def trunc_int_quant_io_fixture(self, test_cfg): + yield test_cfg def test_trunc_int_quant_io(self, caplog, trunc_int_quant_io_fixture): caplog.set_level(logging.INFO) From 0b0e18cb828ab55e659d7ab65d3184219b00be00 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 28 Jan 2025 13:20:23 +0000 Subject: [PATCH 25/32] Fix (trunc): Fixed output zero-point calculation --- src/brevitas/core/quant/int.py | 5 ++-- tests/brevitas/core/test_trunc_int_quant.py | 26 +++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index 0583a43ae..9bae3a42a 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -250,10 +250,11 @@ def forward( y = self.float_to_int_impl(y) y = self.tensor_clamp_impl(y, min_val=min_int_val, max_val=max_int_val) output_scale = scale * trunc_scale - y = y - zero_point + output_zero_point = zero_point / trunc_scale + y = y - output_zero_point y = y * output_scale y = self.delay_wrapper(x, y) - return y, output_scale, zero_point, output_bit_width + return y, output_scale, output_zero_point, output_bit_width class DecoupledRescalingIntQuantWithInput(DecoupledRescalingIntQuant): diff --git a/tests/brevitas/core/test_trunc_int_quant.py b/tests/brevitas/core/test_trunc_int_quant.py index 186563686..4eafcc819 100644 --- a/tests/brevitas/core/test_trunc_int_quant.py +++ b/tests/brevitas/core/test_trunc_int_quant.py @@ -89,6 +89,31 @@ def test_trunc_int_quant_defaults(self, bit_width_const): "zero_point": torch.tensor([0.]), "bit_width": torch.tensor([4.]), }, + }, { # defaults_int+_overflow_zp + "init_args": { + 
"bit_width_impl": BitWidthConst(4), + "float_to_int_impl": RoundSte(), + }, + "train_args": { + "x": torch.tensor([1727.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([-1600.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "eval_args": { + "x": torch.tensor([1727.]), + "scale": torch.tensor([1.]), + "zero_point": torch.tensor([-1600.]), + "input_bit_width": torch.tensor([8.]), + "signed": True, + }, + "result": { # Result needs to match the order of the output tuple + "y": torch.tensor([1712.]), + "scale": torch.tensor([16.]), + "zero_point": torch.tensor([-100.]), + "bit_width": torch.tensor([4.]), + }, }, { # defaults_int-_max "init_args": { "bit_width_impl": BitWidthConst(4), @@ -464,6 +489,7 @@ def test_trunc_int_quant_defaults(self, bit_width_const): ids=[ "defaults_uint_overflow", "defaults_int+_overflow", + "defaults_int+_overflow_zp", "defaults_int-_max", "defaults_uint_underflow", "defaults_int_underflow", From 1f737159dd27dc0f45227360e3a0d9bfdf09187e Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 28 Jan 2025 13:21:21 +0000 Subject: [PATCH 26/32] Fix (export/qonnx/trunc): Added check that zero-point is zero. --- src/brevitas/export/onnx/qonnx/handler.py | 4 ++++ src/brevitas/proxy/runtime_quant.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/src/brevitas/export/onnx/qonnx/handler.py b/src/brevitas/export/onnx/qonnx/handler.py index 961b15761..1c1489ce2 100644 --- a/src/brevitas/export/onnx/qonnx/handler.py +++ b/src/brevitas/export/onnx/qonnx/handler.py @@ -227,7 +227,11 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None): class BrevitasTruncQuantProxyHandler(ONNXBaseHandler): handled_layer = TruncQuantProxyFromInjector + def validate(self, module): + assert module.zero_point() == 0, "Zero-point export not supported for TruncQuant." 
+ def prepare_for_export(self, module: TruncQuantProxyFromInjector): + self.validate(module) self.symbolic_kwargs = { 'narrow_range': module.is_narrow_range, 'output_scale': module.scale(), diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index c8bc92977..2e3749236 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -282,6 +282,9 @@ def is_narrow_range(self): def scale(self): return self.retrieve_attribute('scale') + def zero_point(self): + return self.retrieve_attribute('zero_point') + def bit_width(self): if not self.is_quant_enabled: return None From e9de9d42629d8ff073994907fc379e8dfcc820b1 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 28 Jan 2025 13:23:09 +0000 Subject: [PATCH 27/32] Fix (export/qcdq/trunc): Pick up output scale from proxy --- src/brevitas/export/common/handler/qcdq.py | 24 ++++++++++++++-------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 091fd48c4..1951f1be4 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -773,18 +773,26 @@ class QCDQCastTruncQuantProxyHandlerMixin(QuantAxisMixin, ABC): handled_layer = TruncQuantProxyFromInjector + def validate(self, module): + assert module.zero_point() == 0, "Zero-point export not supported for TruncQuant." 
+ super(QCDQCastTruncQuantProxyHandlerMixin, self).validate(module) + def prepare_for_export(self, module: TruncQuantProxyFromInjector): if module.is_quant_enabled: self.validate(module) - self.symbolic_kwargs = {'output_bit_width': module.bit_width()} + self.symbolic_kwargs = { + 'narrow_range': module.is_narrow_range, + 'output_scale': module.scale(), + 'output_bit_width': module.bit_width()} def symbolic_execution( self, x: Tensor, scale: Tensor, zero_point: Tensor, input_bit_width: Tensor, signed: Tensor): assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled' output_bit_width = self.symbolic_kwargs['output_bit_width'] + narrow_range = self.symbolic_kwargs['narrow_range'] dtype = self.int8_dtype() if signed else self.uint8_dtype() - trunc_scale = 2.0 ** (input_bit_width - output_bit_width) + scale = self.symbolic_kwargs['output_scale'] # Input scale is ignored now # If original dtype of scale is (b)float16, store the original scale dtype # and cast the scale and the input to float32 scale_dtype = scale.dtype @@ -792,19 +800,17 @@ def symbolic_execution( scale = self.cast_fn(scale, torch.float32) if x.dtype == torch.bfloat16 or x.dtype == torch.float16: x = self.cast_fn(x, torch.float32) - pre_scale = scale * trunc_scale - flat_pre_scale = to_0dim_if_scalar(pre_scale.flatten()) flat_scale = to_0dim_if_scalar(scale.flatten()) zp = to_0dim_if_scalar(zero_point.flatten()).expand_as(flat_scale) zp = self.zero_point_with_dtype(signed, output_bit_width, zp) - x = self.quantize_fn(x, flat_pre_scale, zp, dtype, self.quant_axis(pre_scale)) + x = self.quantize_fn(x, flat_scale, zp, dtype, self.quant_axis(scale)) clip_symbolic_kwargs = self.int_clip_symbolic_kwargs( - signed=signed, narrow=False, bit_width=output_bit_width) + signed=signed, narrow=self.symbolic_kwargs['narrow_range'], bit_width=output_bit_width) if clip_symbolic_kwargs is not None: x = self.clip_fn(x, *clip_symbolic_kwargs.values()) - x = self.dequantize_fn(x, 
flat_pre_scale, zp, self.quant_axis(scale)) + x = self.dequantize_fn(x, flat_scale, zp, self.quant_axis(scale)) # After dequantization, cast both output and scale to the correct dtype if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16: x = self.cast_fn(x, scale_dtype) - flat_pre_scale = self.cast_fn(flat_pre_scale, scale_dtype) - return x, flat_pre_scale, zero_point, output_bit_width + flat_scale = self.cast_fn(flat_scale, scale_dtype) + return x, flat_scale, zero_point, output_bit_width From 5a41cfb0ac57d1e7db01820ddc81206c32654bf3 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 28 Jan 2025 13:47:15 +0000 Subject: [PATCH 28/32] Fix (export/trunc): Retrieve bit_width from cache --- src/brevitas/proxy/runtime_quant.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 2e3749236..12dcf9528 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -286,13 +286,7 @@ def zero_point(self): return self.retrieve_attribute('zero_point') def bit_width(self): - if not self.is_quant_enabled: - return None - zhs = self._zero_hw_sentinel() - # Signed might or might not be defined. 
We just care about retrieving the bitwidth - empty_imp = IntQuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training) - bit_width = self.__call__(empty_imp).bit_width - return bit_width + return self.retrieve_attribute('bit_width') def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: if self.is_quant_enabled: From b3cbf1680d3d6acbf27b0f1ebd94d9d56ad8436c Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 28 Jan 2025 13:50:23 +0000 Subject: [PATCH 29/32] precommit --- src/brevitas/export/common/handler/qcdq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 1951f1be4..2bb37af89 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -792,7 +792,7 @@ def symbolic_execution( output_bit_width = self.symbolic_kwargs['output_bit_width'] narrow_range = self.symbolic_kwargs['narrow_range'] dtype = self.int8_dtype() if signed else self.uint8_dtype() - scale = self.symbolic_kwargs['output_scale'] # Input scale is ignored now + scale = self.symbolic_kwargs['output_scale'] # Input scale is ignored now # If original dtype of scale is (b)float16, store the original scale dtype # and cast the scale and the input to float32 scale_dtype = scale.dtype From ef384e3f67b4f9ea7d1a36d8e792efbe363a8c29 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 28 Jan 2025 16:21:38 +0000 Subject: [PATCH 30/32] docs (imagenet/qat): Updated accuracy with new TruncAvgPool implementation --- src/brevitas_examples/imagenet_classification/qat/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/imagenet_classification/qat/README.md b/src/brevitas_examples/imagenet_classification/qat/README.md index 92e318c05..5d95a3c55 100644 --- a/src/brevitas_examples/imagenet_classification/qat/README.md +++ b/src/brevitas_examples/imagenet_classification/qat/README.md @@ -7,7 
+7,7 @@ Below in the table is a list of example pretrained models made available for ref | Name | Cfg | Scaling Type | First layer weights | Weights | Activations | Avg pool | Top1 | Pretrained model | Retrained from | |--------------|-----------------------|----------------------------|---------------------|---------|-------------|----------|-------|-------------------------------------------------------------------------------------------------|---------------------------------------------------------------| -| MobileNet V1 | quant_mobilenet_v1_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 70.95 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_mobilenet_v1_4b-r1/quant_mobilenet_v1_4b-0100a667.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) | +| MobileNet V1 | quant_mobilenet_v1_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 70.86 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_mobilenet_v1_4b-r1/quant_mobilenet_v1_4b-0100a667.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) | | ProxylessNAS Mobile14 w/ Hadamard classifier | quant_proxylessnas_mobile14_hadamard_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 72.87 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_hadamard_4b-r0/quant_proxylessnas_mobile14_hadamard_4b-4acbfa9f.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) | | ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 74.39 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b-r0/quant_proxylessnas_mobile14_4b-e10882e1.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) | | ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b5b | Floating-point per channel | 8 bit | 4 bit, 5 bit | 4 bit, 5 bit | 4 bit | 74.94 | 
[Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b5b-r0/quant_proxylessnas_mobile14_4b5b-2bdf7f8d.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) | From 43b02e5355de23b51df9f493e100524d8cf82ef6 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 28 Jan 2025 17:11:41 +0000 Subject: [PATCH 31/32] test (finn/mobilenet): Allow tolerance of up-to 7 in output. --- .../brevitas_examples/test_mobilenet_finn_export.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py b/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py index 01505b9e5..6e4648e30 100644 --- a/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py +++ b/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py @@ -24,7 +24,8 @@ reason='Issue with ORT and MobileNet export on MacOS on PyTorch >= 1.5.0') INPUT_SIZE = (1, 3, 224, 224) -ATOL = 1e-3 +ATOL = 7 # How many bitflips to tolerate in the 32-bit output +RTOL = 1e-2 SEED = 0 @@ -42,6 +43,7 @@ def test_mobilenet_v1_4b(pretrained): # do forward pass in PyTorch/Brevitas expected = mobilenet(torch_tensor).detach().numpy() export_qonnx(mobilenet, input_shape=INPUT_SIZE, export_path=finn_onnx) + output_scale = mobilenet.output.bias_quant.scale() # Scale at the output model = ModelWrapper(finn_onnx) model = model.transform(GiveUniqueNodeNames()) model = model.transform(DoubleToSingleFloat()) @@ -53,4 +55,4 @@ def test_mobilenet_v1_4b(pretrained): input_dict = {inp_name: numpy_tensor} output_dict = oxe.execute_onnx(model, input_dict) produced = output_dict[list(output_dict.keys())[0]] - assert np.isclose(produced, expected, atol=ATOL).all() + assert np.isclose(produced, expected, rtol=RTOL, atol=ATOL * output_scale).all() From 6df800c94fc6218879b236a715031bba6dd15572 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 11 Feb 2025 16:30:57 +0000 Subject: [PATCH 32/32] 
Fix (test/export/trunc): Revert export to produce a Trunc node. --- src/brevitas/export/onnx/qonnx/function.py | 8 +++++--- src/brevitas/export/onnx/qonnx/handler.py | 2 +- .../brevitas/test_brevitas_avg_pool_export.py | 9 +++++---- .../brevitas_examples/test_mobilenet_finn_export.py | 11 +++++++---- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/brevitas/export/onnx/qonnx/function.py b/src/brevitas/export/onnx/qonnx/function.py index 0fbeba8f6..7df4990df 100644 --- a/src/brevitas/export/onnx/qonnx/function.py +++ b/src/brevitas/export/onnx/qonnx/function.py @@ -123,14 +123,16 @@ def symbolic( output_bit_width, rounding_mode): ret = g.op( - f'{DOMAIN_STRING}::Quant', + f'{DOMAIN_STRING}::Trunc', x, - output_scale, + scale, zero_point, + input_bit_width, output_bit_width, rounding_mode_s=rounding_mode, signed_i=int(signed), - narrow_i=int(narrow_range)) + narrow_i=int(narrow_range), + output_scale_f=output_scale) ret.setType(x.type()) return ret diff --git a/src/brevitas/export/onnx/qonnx/handler.py b/src/brevitas/export/onnx/qonnx/handler.py index 1c1489ce2..b1b7d0625 100644 --- a/src/brevitas/export/onnx/qonnx/handler.py +++ b/src/brevitas/export/onnx/qonnx/handler.py @@ -234,7 +234,7 @@ def prepare_for_export(self, module: TruncQuantProxyFromInjector): self.validate(module) self.symbolic_kwargs = { 'narrow_range': module.is_narrow_range, - 'output_scale': module.scale(), + 'output_scale': float(module.scale().detach().cpu().numpy()), 'output_bit_width': module.bit_width(), 'rounding_mode': module.rounding_mode} diff --git a/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py index a9c851d7e..01e329dbb 100644 --- a/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py @@ -77,10 +77,11 @@ def test_brevitas_avg_pool_export( odict = oxe.execute_onnx(model, idict, True) finn_output = 
odict[model.graph.output[0].name] # compare outputs - if restrict_scaling_type == "power_of_two" and kernel_size == 2: - atol = 1e-8 - else: - atol = quant_avgpool.trunc_quant.scale().detach().numpy() # Allow "off-by-1" errors + #if restrict_scaling_type == "power_of_two" and kernel_size == 2: + # atol = 1e-8 + #else: + # atol = quant_avgpool.trunc_quant.scale().detach().numpy() # Allow "off-by-1" errors + atol = 1e-8 assert np.isclose(ref_output_array, finn_output, atol=atol).all() # cleanup os.remove(export_path) diff --git a/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py b/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py index 6e4648e30..f225ddb12 100644 --- a/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py +++ b/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py @@ -24,8 +24,9 @@ reason='Issue with ORT and MobileNet export on MacOS on PyTorch >= 1.5.0') INPUT_SIZE = (1, 3, 224, 224) -ATOL = 7 # How many bitflips to tolerate in the 32-bit output -RTOL = 1e-2 +#ATOL = 1 # Alternative: how many bitflips to tolerate in the 32-bit output +ATOL = 1e-3 +RTOL = 1e-5 SEED = 0 @@ -43,7 +44,7 @@ def test_mobilenet_v1_4b(pretrained): # do forward pass in PyTorch/Brevitas expected = mobilenet(torch_tensor).detach().numpy() export_qonnx(mobilenet, input_shape=INPUT_SIZE, export_path=finn_onnx) - output_scale = mobilenet.output.bias_quant.scale() # Scale at the output + #output_scale = mobilenet.output.bias_quant.scale() # Scale at the output model = ModelWrapper(finn_onnx) model = model.transform(GiveUniqueNodeNames()) model = model.transform(DoubleToSingleFloat()) @@ -55,4 +56,6 @@ def test_mobilenet_v1_4b(pretrained): input_dict = {inp_name: numpy_tensor} output_dict = oxe.execute_onnx(model, input_dict) produced = output_dict[list(output_dict.keys())[0]] - assert np.isclose(produced, expected, rtol=RTOL, atol=ATOL * output_scale).all() + #atol = ATOL * output_scale # Absolute tolerance in 
bitflips + atol = ATOL + assert np.isclose(produced, expected, rtol=RTOL, atol=atol).all()