From c36406c2387d5e1784066efad8d5c2d0ebb6f638 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 18 Jan 2024 13:44:07 +0000 Subject: [PATCH] review --- .../export/onnx/standard/qcdq/handler.py | 4 +- .../common/generative/quantize.py | 8 +++- .../common/generative/quantizers.py | 41 +++++++++++++++---- 3 files changed, 42 insertions(+), 11 deletions(-) diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index a1981eac1..b81b143b6 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -95,9 +95,9 @@ def int32_dtype(cls): def validate(self, module): super().validate(module) - # ONNX QuantizeLinear supports only 8b output with round to nearest even. - # Below 8b quantization is supported through clipping. + # ONNX DynamicQuantizeLinear supports only 8b output with round to nearest even. assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported' + # Below 8b quantization is not supported. self.validate_8b_bit_width(module.bit_width(), le_then=False) def quantize_fn(self, x, dtype): diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 979d21c48..929a5412f 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -40,6 +40,8 @@ from brevitas_examples.common.generative.quantizers import Int8ActPerRowFloat from brevitas_examples.common.generative.quantizers import Int8ActPerRowFloatMSE from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant +from brevitas_examples.common.generative.quantizers import ShiftedUint8ActDynamicPerGroupFloat +from brevitas_examples.common.generative.quantizers import ShiftedUint8ActDynamicPerRowFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8ActDynamicPerTensorFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8ActPerRowFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8ActPerRowFloatMSE @@ -112,9 +114,11 @@ 'sym': Int8ActDynamicPerTensorFloat, 'asym': ShiftedUint8ActDynamicPerTensorFloat}, 'per_row': { - 'sym': Int8ActDynamicPerRowFloat}, + 'sym': Int8ActDynamicPerRowFloat, + 'asym': ShiftedUint8ActDynamicPerRowFloat}, 'per_group': { - 'sym': Int8ActDynamicPerGroupFloat},}}}}, + 'sym': Int8ActDynamicPerGroupFloat, + 'asym': ShiftedUint8ActDynamicPerGroupFloat},}}}}, 'float': { 'static': { 'float_scale': { diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py index 8dc5db2ce..e15f3b723 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -67,6 +67,10 @@ def reshaped_scaling_shape(module): block_size = None +class ActDynamicProxyMixin(ExtendedInjector): + proxy_class = DynamicActQuantProxyFromInjector + + class IntWeightSymmetricGroupQuant(WeightSymmetricGroupQuantMixin, Int8WeightPerChannelFloat): """ Block / group / vector signed symmetric int weight quantizer with float scales. @@ -130,11 +134,32 @@ class Int8ActDynamicPerTensorFloat(Int8ActPerTensorFloat): scaling_stats_op = 'max' -class ShiftedUint8ActDynamicPerTensorFloat(ShiftedUint8ActPerTensorFloat): +class Int8ActDynamicPerRowFloat(Int8ActPerRowFloat, ActDynamicProxyMixin): """ - Symmetric quantizer with per tensor dynamic scale. + Symmetric quantizer with per row dynamic scale. + """ + scaling_impl = RuntimeDynamicStatsScaling + scaling_stats_input_view_shape_impl = OverBatchOverOutputChannelView + scaling_stats_op = 'max' + + +class Int8ActDynamicPerGroupFloat(Int8ActPerRowFloat, ActDynamicProxyMixin): + """ + Symmetric quantizer with per group scale. + """ + scaling_impl = RuntimeDynamicGroupStatsScaling + keepdim = True + scaling_stats_op = 'max' + + @value + def stats_reduce_dim(group_dim): + return group_dim + 1 + + +class ShiftedUint8ActDynamicPerTensorFloat(ShiftedUint8ActPerTensorFloat, ActDynamicProxyMixin): + """ + Asymmetric quantizer with per tensor dynamic scale. """ - proxy_class = DynamicActQuantProxyFromInjector scaling_impl = RuntimeDynamicStatsScaling scaling_stats_input_view_shape_impl = OverTensorView scaling_stats_op = 'max' @@ -143,22 +168,24 @@ class ShiftedUint8ActDynamicPerTensorFloat(ShiftedUint8ActPerTensorFloat): stats_reduce_dim = 0 -class Int8ActDynamicPerRowFloat(Int8ActPerRowFloat): +class ShiftedUint8ActDynamicPerRowFloat(ShiftedUint8ActPerRowFloat, ActDynamicProxyMixin): """ - Symmetric quantizer with per row dynamic scale. + Asymmetric quantizer with per row dynamic scale. """ scaling_impl = RuntimeDynamicStatsScaling scaling_stats_input_view_shape_impl = OverBatchOverOutputChannelView scaling_stats_op = 'max' + zero_point_stats_impl = NegativeMinOrZero -class Int8ActDynamicPerGroupFloat(Int8ActPerRowFloat): +class ShiftedUint8ActDynamicPerGroupFloat(ShiftedUint8ActPerRowFloat, ActDynamicProxyMixin): """ - Symmetric quantizer with per group scale. + Asymmetric quantizer with per group dynamic scale. """ scaling_impl = RuntimeDynamicGroupStatsScaling keepdim = True scaling_stats_op = 'max' + zero_point_stats_impl = NegativeMinOrZero @value def stats_reduce_dim(group_dim):