Fix/Feat (trunc avg pool): Update truncation and average pool behaviour #1042

Status: Open. Wants to merge 31 commits into base branch: dev.
Diff shown below reflects changes from 19 of the 31 commits.

Commits (31)
767252d  Fix (quant_tensor): Produce valid IntQuantTensor after AvgPool functi… (nickfraser, Oct 4, 2024)
13ca170  Fix (core/trunc): Fix output scaling after truncation (nickfraser, Oct 4, 2024)
f183191  Fix (nn/TruncAvgPool): Remove any quant tensor manual manipulation. (nickfraser, Oct 4, 2024)
859897c  fix/trunc_avg_pool: Clamp output. (nickfraser, Jan 22, 2025)
f5752d6  style: fix (nickfraser, Jan 22, 2025)
9b2fdf7  fix (trunc_avg_pool): Set default arguments for backward compatibility (nickfraser, Jan 22, 2025)
76dfe8e  test (trunc_int_quant): Added initial sanity-check test (nickfraser, Jan 22, 2025)
b2d849f  fix (export/torch/qcdq): Fixed output scale, and `signed` setting (nickfraser, Jan 22, 2025)
16717c5  Fix (core/proxy/trunc): Moved setting of signed to the proxy (nickfraser, Jan 22, 2025)
faf973c  fix (qonnx/trunc): Fixed Trunc Quant QONNX export (nickfraser, Jan 23, 2025)
1d63711  fix (trunc): Factored out scaling calculation to standalone class. (nickfraser, Jan 24, 2025)
14cdda2  fix typo: Updated comment in TruncAvgPool export (nickfraser, Jan 24, 2025)
7d3ee72  feat (trunc/scaling): Factored out the scaling implementation. (nickfraser, Jan 24, 2025)
0c849e7  test (trunc): Added signed overflow test (nickfraser, Jan 24, 2025)
7f76cd9  test (trunc): Added more unti tests. (nickfraser, Jan 24, 2025)
d2934a0  fix (test/trunc): Bugfixes and tests. (nickfraser, Jan 24, 2025)
7d6c934  Fix: precommit (nickfraser, Jan 27, 2025)
b121f5d  Fix (solver/trunc): Added a ShiftRoundSaturate quantizer and update t… (nickfraser, Jan 27, 2025)
0a00e8f  Fix (export/trunc): Updated export to generate Quant node. (nickfraser, Jan 27, 2025)
80f57e3  Fix (test/qonnx/trunc): Allow off-by-1 errors in test (nickfraser, Jan 27, 2025)
1af6840  tests (brv_finn/avgpool): Add "lossless" tests (nickfraser, Jan 28, 2025)
20a06bd  Fix (brevitas/scaling): TruncPowerOfTwoIntScaling -> PowerOfTwoIntSca… (nickfraser, Jan 28, 2025)
54e86a0  Fix (scaling): Made signed an optional argument at init time. (nickfraser, Jan 28, 2025)
756482a  test (trunc_quant): Switched to pytest_cases.parametrize (nickfraser, Jan 28, 2025)
0b0e18c  Fix (trunc): Fixed output zero-point calculation (nickfraser, Jan 28, 2025)
1f73715  Fix (export/qonnx/trunc): Added check that zero-point is zero. (nickfraser, Jan 28, 2025)
e9de9d4  Fix (export/qcdq/trunc): Pick up output scale from proxy (nickfraser, Jan 28, 2025)
5a41cfb  Fix (export/trunc): Retrieve bit_width from cache (nickfraser, Jan 28, 2025)
b3cbf16  precommit (nickfraser, Jan 28, 2025)
ef384e3  docs (imagenet/qat): Updated accuracy with new TruncAvgPool implement… (nickfraser, Jan 28, 2025)
43b02e5  test (finn/mobilenet): Allow tolerance of up-to 7 in output. (nickfraser, Jan 28, 2025)
47 changes: 39 additions & 8 deletions src/brevitas/core/quant/int.py
@@ -1,15 +1,19 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import Tensor
 from torch.nn import Module
 
 import brevitas
+from brevitas.core.function_wrapper import TensorClamp
 from brevitas.core.quant.delay import DelayWrapper
+from brevitas.core.scaling import TruncMsbScaling
 from brevitas.core.utils import StatelessBuffer
+from brevitas.function.ops import max_int
+from brevitas.function.ops import min_int
 from brevitas.function.ops_ste import round_ste
 
 
@@ -201,28 +205,55 @@ class TruncIntQuant(brevitas.jit.ScriptModule):
     """
     """
 
+    __constants__ = ['narrow_range']
+
     def __init__(
-            self, float_to_int_impl: Module, bit_width_impl: Module, quant_delay_steps: int = 0):
+            self,
+            float_to_int_impl: Module,
+            bit_width_impl: Module,
+            trunc_scaling_impl: Module = TruncMsbScaling(),
+            narrow_range: bool = False,
+            tensor_clamp_impl: Module = TensorClamp(),
+            quant_delay_steps: int = 0):
         super(TruncIntQuant, self).__init__()
+        self.narrow_range = narrow_range
         self.msb_clamp_bit_width_impl = bit_width_impl
+        self.trunc_scaling_impl = trunc_scaling_impl
         self.float_to_int_impl = float_to_int_impl
+        self.tensor_clamp_impl = tensor_clamp_impl
        self.delay_wrapper = DelayWrapper(quant_delay_steps)
 
     @brevitas.jit.script_method
-    def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor,
-                input_bit_width: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    def min_int(self, bit_width: Tensor, signed: Union[bool, Tensor]):
+        return min_int(signed, self.narrow_range, bit_width)
+
+    @brevitas.jit.script_method
+    def max_int(self, bit_width: Tensor, signed: Union[bool, Tensor]):
+        return max_int(signed, self.narrow_range, bit_width)
+
+    @brevitas.jit.script_method
+    def forward(
+            self,
+            x: Tensor,
+            scale: Tensor,
+            zero_point: Tensor,
+            input_bit_width: Tensor,
+            signed: Union[bool, Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         y = x / scale
         y = y + zero_point
         y = round_ste(y)  # clean up floating point error
         output_bit_width = self.msb_clamp_bit_width_impl()
-        trunc_bit_width = input_bit_width - output_bit_width
-        trunc_scale = 2.0 ** trunc_bit_width
+        trunc_scale = self.trunc_scaling_impl(y, input_bit_width, output_bit_width, signed)
         y = y / trunc_scale
+        min_int_val = self.min_int(output_bit_width, signed)
+        max_int_val = self.max_int(output_bit_width, signed)
         y = self.float_to_int_impl(y)
+        y = self.tensor_clamp_impl(y, min_val=min_int_val, max_val=max_int_val)
+        output_scale = scale * trunc_scale
         y = y - zero_point
-        y = y * scale
+        y = y * output_scale
         y = self.delay_wrapper(x, y)
-        return y, scale, zero_point, output_bit_width
+        return y, output_scale, zero_point, output_bit_width
 
 
 class DecoupledRescalingIntQuantWithInput(DecoupledRescalingIntQuant):
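The substance of this change: truncation now derives `trunc_scale` from a pluggable `trunc_scaling_impl`, clamps the shifted value to the representable output range, and returns `scale * trunc_scale` as the output scale. Previously the module returned the input scale unchanged, so the dequantized output was a factor of `trunc_scale` too small (the bug named in commit 13ca170). A minimal pure-PyTorch sketch of the new arithmetic under the default MSB scaling; this is an illustration, not Brevitas API:

```python
import torch

# Truncate signed 8-bit integers to 4 bits: trunc_scale = 2 ** (8 - 4) = 16.
x_int = torch.tensor([-128.0, -37.0, 0.0, 55.0, 127.0])  # 8-bit integer values
scale = torch.tensor(0.1)

trunc_scale = 2.0 ** (8 - 4)
y = torch.round(x_int / trunc_scale)  # shift and round to the 4-bit grid
y = torch.clamp(y, -8, 7)             # new: saturate to the signed 4-bit range
output_scale = scale * trunc_scale    # new: scale grows by the truncated LSBs

print(y)                 # tensor([-8., -2., 0., 3., 7.])
print(y * output_scale)  # dequantized values, now at the correct magnitude
```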
3 changes: 3 additions & 0 deletions src/brevitas/core/scaling/__init__.py
@@ -9,6 +9,7 @@
 from .float_scaling import FloatScaling
 from .int_scaling import IntScaling
 from .int_scaling import PowerOfTwoIntScaling
+from .int_scaling import TruncPowerOfTwoIntScaling
 from .pre_scaling import AccumulatorAwareParameterPreScaling
 from .pre_scaling import AccumulatorAwareZeroCenterParameterPreScaling
 from .pre_scaling import ParameterPreScalingWeightNorm
@@ -18,5 +19,7 @@
 from .standalone import ParameterFromRuntimeStatsScaling
 from .standalone import ParameterFromStatsFromParameterScaling
 from .standalone import ParameterScaling
+from .standalone import TruncMsbScaling
+from .standalone import TruncScalingWrapper
 
 SCALING_STATS_REDUCE_DIM = 1
12 changes: 12 additions & 0 deletions src/brevitas/core/scaling/int_scaling.py
@@ -1,6 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+from typing import Union
+
 from torch import Tensor
 
 import brevitas
@@ -34,3 +36,13 @@ def __init__(self, signed: bool):
     @brevitas.jit.script_method
     def forward(self, bit_width: Tensor) -> Tensor:
         return max_int(self.signed, False, bit_width) + 1
+
+
+class TruncPowerOfTwoIntScaling(brevitas.jit.ScriptModule):
+
+    def __init__(self):
+        super(TruncPowerOfTwoIntScaling, self).__init__()
+
+    @brevitas.jit.script_method
+    def forward(self, bit_width: Tensor, signed: Union[bool, Tensor]) -> Tensor:
+        return max_int(signed, False, bit_width) + 1
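`TruncPowerOfTwoIntScaling` returns the power of two just above the largest representable integer, with `signed` supplied per call (as a bool or a tensor) instead of being fixed at construction. A quick worked check, reusing the `max_int` import path shown in the diffs above; the helper name is illustrative:

```python
import torch

from brevitas.function.ops import max_int

def trunc_threshold(bit_width: torch.Tensor, signed: bool) -> torch.Tensor:
    # Mirrors TruncPowerOfTwoIntScaling.forward.
    return max_int(signed, False, bit_width) + 1

bw = torch.tensor(4.0)
print(trunc_threshold(bw, True))   # tensor(8.):  signed 4-bit range is [-8, 7], max_int = 7
print(trunc_threshold(bw, False))  # tensor(16.): unsigned 4-bit range is [0, 15], max_int = 15
```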
49 changes: 49 additions & 0 deletions src/brevitas/core/scaling/standalone.py
@@ -12,6 +12,7 @@
 import brevitas.config as config
 from brevitas.core.function_wrapper import Identity
 from brevitas.core.function_wrapper import OverBatchOverTensorView
+from brevitas.core.function_wrapper import TensorClamp
 from brevitas.core.restrict_val import _ClampValue
 from brevitas.core.restrict_val import _RestrictClampValue
 from brevitas.core.restrict_val import _RestrictValue
@@ -469,3 +470,51 @@ def _load_from_state_dict(
         self.counter = self.collect_stats_steps + 1
         if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
             missing_keys.remove(value_key)
+
+
+class TruncMsbScaling(brevitas.jit.ScriptModule):
+    """
+    """
+
+    def __init__(self) -> None:
+        super(TruncMsbScaling, self).__init__()
+
+    @brevitas.jit.script_method
+    def forward(
+            self,
+            scaling_input: Tensor,
+            input_bit_width: Tensor,
+            output_bit_width: Tensor,
+            signed: Union[bool, Tensor]) -> Tensor:
+        return 2 ** (input_bit_width - output_bit_width)
+
+
+class TruncScalingWrapper(brevitas.jit.ScriptModule):
+    """
+    """
+
+    def __init__(
+            self,
+            trunc_int_scaling_impl: Module,
+            scaling_impl: Module,
+            tensor_clamp_impl: Module = TensorClamp()) -> None:
+        super(TruncScalingWrapper, self).__init__()
+        self.trunc_int_scaling_impl = trunc_int_scaling_impl
+        self.scaling_impl = scaling_impl
+        self.tensor_clamp_impl = tensor_clamp_impl
+
+    @brevitas.jit.script_method
+    def forward(
+            self,
+            scaling_input: Tensor,
+            input_bit_width: Tensor,
+            output_bit_width: Tensor,
+            signed: Union[bool, Tensor]) -> Tensor:
+        threshold = self.trunc_int_scaling_impl(output_bit_width, signed)
+        scale = self.scaling_impl(scaling_input, threshold)
+        msb_scale = 2 ** (input_bit_width - output_bit_width)
+        unit_scale = torch.ones_like(msb_scale)
+        max_scale = torch.where(msb_scale > unit_scale, msb_scale, unit_scale)
+        min_scale = torch.where(msb_scale < unit_scale, msb_scale, unit_scale)
+        trunc_scale = self.tensor_clamp_impl(scale, min_scale, max_scale)
+        return trunc_scale
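`TruncScalingWrapper` lets a statistics- or parameter-based `scaling_impl` propose the truncation scale, then clamps it between the unit scale and the MSB scale (in whichever order those fall), so the effective shift can never exceed a plain MSB truncation or invert it. A pure-PyTorch sketch of that clamp with illustrative values, not Brevitas API:

```python
import torch

input_bit_width = torch.tensor(8.0)
output_bit_width = torch.tensor(4.0)
proposed_scale = torch.tensor(24.0)  # stand-in for scaling_impl(scaling_input, threshold)

msb_scale = 2.0 ** (input_bit_width - output_bit_width)  # 16.0
unit_scale = torch.ones_like(msb_scale)
max_scale = torch.where(msb_scale > unit_scale, msb_scale, unit_scale)
min_scale = torch.where(msb_scale < unit_scale, msb_scale, unit_scale)

trunc_scale = torch.clamp(proposed_scale, min_scale, max_scale)
print(trunc_scale)  # tensor(16.): 24.0 is clamped down to the MSB scale
```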
6 changes: 3 additions & 3 deletions src/brevitas/export/common/handler/qcdq.py
@@ -802,9 +802,9 @@ def symbolic_execution(
             signed=signed, narrow=False, bit_width=output_bit_width)
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
-        x = self.dequantize_fn(x, flat_scale, zp, self.quant_axis(scale))
+        x = self.dequantize_fn(x, flat_pre_scale, zp, self.quant_axis(scale))
         # After dequantization, cast both output and scale to the correct dtype
         if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
             x = self.cast_fn(x, scale_dtype)
-            scale = self.cast_fn(scale, scale_dtype)
-        return x, scale, zero_point, output_bit_width
+            flat_pre_scale = self.cast_fn(flat_pre_scale, scale_dtype)
+        return x, flat_pre_scale, zero_point, output_bit_width
44 changes: 33 additions & 11 deletions src/brevitas/export/onnx/qonnx/function.py
@@ -111,26 +111,48 @@ def forward(
 class BrevitasTruncFn(Function):
 
     @staticmethod
-    def symbolic(g, x, scale, zero_point, input_bit_width, output_bit_width, rounding_mode):
-        ret = g.op(
-            f'{DOMAIN_STRING}::Trunc',
+    def symbolic(
+            g,
+            x,
+            scale,
+            zero_point,
+            input_bit_width,
+            signed,
+            narrow_range,
+            output_scale,
+            output_bit_width,
+            rounding_mode):
+        ret = g.op(
+            f'{DOMAIN_STRING}::Quant',
             x,
-            scale,
+            output_scale,
             zero_point,
-            input_bit_width,
             output_bit_width,
-            rounding_mode_s=rounding_mode)
+            rounding_mode_s=rounding_mode,
+            signed_i=int(signed),
+            narrow_i=int(narrow_range))
         ret.setType(x.type())
         return ret
 
     @staticmethod
-    def forward(ctx, x, scale, zero_point, input_bit_width, output_bit_width, rounding_mode):
-        float_to_int_impl = solve_float_to_int_impl_from_enum(rounding_mode)
-        trunc = TruncIntQuant(
-            float_to_int_impl=float_to_int_impl(),
-            bit_width_impl=BitWidthConst(int(output_bit_width)))
-        y_tuple = trunc(x, scale, zero_point, input_bit_width)
-        return y_tuple[0]
+    def forward(
+            ctx,
+            x,
+            scale,
+            zero_point,
+            input_bit_width,
+            signed,
+            narrow_range,
+            output_scale,
+            output_bit_width,
+            rounding_mode):
+        # TODO: Restore this (fails when `signed` arg added)
+        #float_to_int_impl = solve_float_to_int_impl_from_enum(rounding_mode)
+        #trunc = TruncIntQuant(
+        #    float_to_int_impl=float_to_int_impl(),
+        #    bit_width_impl=BitWidthConst(int(output_bit_width)))
+        #y_tuple = trunc(x, scale, zero_point, input_bit_width, signed)
+        return x
 
 
 class BrevitasQuantLSTMCellFn(Function):
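The export now emits a standard QONNX `Quant` node (with the post-truncation scale and explicit `signed`/`narrow` attributes) instead of the custom `Trunc` node, which is what downstream QONNX tooling expects. A hedged sketch of the resulting node built with `onnx.helper`; the `'onnx.brevitas'` domain, the `'FLOOR'` rounding mode, and the tensor names are assumptions, not values confirmed by this diff:

```python
from onnx import helper

# Illustrative reconstruction of the node the new symbolic() should produce.
node = helper.make_node(
    'Quant',
    inputs=['x', 'output_scale', 'zero_point', 'output_bit_width'],
    outputs=['y'],
    domain='onnx.brevitas',   # assumed value of DOMAIN_STRING
    rounding_mode='FLOOR',    # assumed; taken from module.rounding_mode at export time
    signed=1,
    narrow=0)
print(node)
```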
9 changes: 6 additions & 3 deletions src/brevitas/export/onnx/qonnx/handler.py
@@ -229,14 +229,17 @@ class BrevitasTruncQuantProxyHandler(ONNXBaseHandler):
 
     def prepare_for_export(self, module: TruncQuantProxyFromInjector):
         self.symbolic_kwargs = {
-            'output_bit_width': module.bit_width(), 'rounding_mode': module.rounding_mode}
+            'narrow_range': module.is_narrow_range,
+            'output_scale': module.scale(),
+            'output_bit_width': module.bit_width(),
+            'rounding_mode': module.rounding_mode}
 
     def symbolic_execution(
             self, x: Tensor, scale: Tensor, zero_point: Tensor, input_bit_width: Tensor,
             signed: Tensor):
         y = BrevitasTruncFn.apply(
-            x, scale, zero_point, input_bit_width, *self.symbolic_kwargs.values())
-        return y, scale, zero_point, self.symbolic_kwargs['output_bit_width']
+            x, scale, zero_point, input_bit_width, signed, *self.symbolic_kwargs.values())
+        return y, self.symbolic_kwargs['output_scale'], zero_point, self.symbolic_kwargs['output_bit_width']
 
 
 class BrevitasQuantLSTMLayerHandler(QuantLSTMLayerHandler):
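One subtlety worth noting: `BrevitasTruncFn.apply` receives `*self.symbolic_kwargs.values()` positionally, so the insertion order of the dict keys above must match the `narrow_range, output_scale, output_bit_width, rounding_mode` parameter order of the autograd Function. A tiny self-contained illustration of why that holds:

```python
# Python dicts preserve insertion order (guaranteed since 3.7),
# which is exactly what this handler relies on when unpacking.
symbolic_kwargs = {
    'narrow_range': False,
    'output_scale': 0.5,
    'output_bit_width': 4,
    'rounding_mode': 'FLOOR'}
print(*symbolic_kwargs.values())  # False 0.5 4 FLOOR
```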
10 changes: 8 additions & 2 deletions src/brevitas/function/ops.py
@@ -6,6 +6,8 @@
 The implemented functions adheres to the restriction imposed by Pytorch 1.1.0's TorchScript compiler.
 """
 
+from typing import Union
+
 import torch
 from torch import Tensor
 
@@ -128,7 +130,9 @@ def identity(x: Tensor) -> Tensor:
 
 
 @brevitas.jit.script
-def max_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor:
+def max_int(
+        signed: Union[bool, Tensor], narrow_range: Union[bool, Tensor],
+        bit_width: Tensor) -> Tensor:
     """ Compute the maximum integer representable by a given number of bits.
 
     Args:
@@ -159,7 +163,9 @@ def max_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor:
 
 
 @brevitas.jit.script
-def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor:
+def min_int(
+        signed: Union[bool, Tensor], narrow_range: Union[bool, Tensor],
+        bit_width: Tensor) -> Tensor:
     """ Compute the minimum integer representable by a given number of bits.
 
     Args:
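Only the signatures change here: `signed` and `narrow_range` widen from `bool` to `Union[bool, Tensor]` so the TorchScript path can pass tensors through. The returned values are unchanged; a quick check against the two's-complement formulas:

```python
import torch

from brevitas.function.ops import max_int, min_int

bw = torch.tensor(8.0)
print(max_int(True, False, bw))   # tensor(127.):  2 ** 7 - 1
print(min_int(True, False, bw))   # tensor(-128.): -(2 ** 7)
print(min_int(True, True, bw))    # tensor(-127.): narrow range drops -128
print(max_int(False, False, bw))  # tensor(255.):  2 ** 8 - 1
```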
8 changes: 8 additions & 0 deletions src/brevitas/inject/enum.py
@@ -93,3 +93,11 @@ class StatsOp(AutoName):
     # Typically adopted for asymmetric quantization
     MIN_MAX = auto()
     PERCENTILE_INTERVAL = auto()
+
+
+class TruncScalingImplType(AutoName):
+    """
+
+    """
+    MSB = auto()
+    WRAPPER = auto()
8 changes: 2 additions & 6 deletions src/brevitas/nn/quant_avg_pool.py
@@ -58,6 +58,7 @@ def _avg_scaling(self):
         else:
             return self.kernel_size * self.kernel_size
 
+    # TODO: Replace with functional call
     def forward(self, input: Union[Tensor, QuantTensor]):
         x = self.unpack_input(input)
 
@@ -71,8 +72,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
             if not isinstance(x, QuantTensor):
                 x = self.cache_class.quant_tensor.set(value=x)
             y = AvgPool2d.forward(self, x)
-            rescaled_value = y.value * self._avg_scaling
-            y = y.set(value=rescaled_value)
             y = self.trunc_quant(y)
         else:
             y = AvgPool2d.forward(self, _unpack_quant_tensor(x))
@@ -123,6 +122,7 @@ def compute_kernel_size_stride(input_shape, output_shape):
             stride_list.append(stride)
         return kernel_size_list, stride_list
 
+    # TODO: Replace with functional call
     def forward(self, input: Union[Tensor, QuantTensor]):
         x = self.unpack_input(input)
 
@@ -139,10 +139,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
             if not isinstance(x, QuantTensor):
                 x = self.cache_class.quant_tensor.set(value=x)
             y = AdaptiveAvgPool2d.forward(self, x)
-            k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
-            reduce_size = reduce(mul, k_size, 1)
-            rescaled_value = y.value * reduce_size  # remove avg scaling
-            y = y.set(value=rescaled_value)
             y = self.trunc_quant(y)
         else:
             y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x))
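The deleted lines undid the pooling division by hand: an average pool is a window sum divided by the window size, and the old forward multiplied the result back out so that truncation operated on integer-valued sums. With the truncation scale now computed and folded into the output scale inside `trunc_quant`, that manual rescale (and the kernel-size bookkeeping in the adaptive case) is redundant. A pure-PyTorch illustration of the arithmetic the removed lines performed:

```python
import torch
import torch.nn.functional as F

x = torch.arange(16.0).reshape(1, 1, 4, 4)  # stand-in for an integer-valued activation
k = 2                                        # 2x2 pooling window

avg = F.avg_pool2d(x, k)  # each output is a window sum divided by k * k
sums = avg * (k * k)      # what `rescaled_value = y.value * self._avg_scaling` recovered

print(avg)   # fractional values: dividing by 4 leaves the integer grid
print(sums)  # back on the integer grid that truncation expects
```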