From eb6e1081dbfad9f3039bec1792d9a565b5eca73d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 23 Sep 2024 18:31:36 +0100 Subject: [PATCH] fix some errors --- src/brevitas/quant/base.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index 308fad3c8..f7a6aa171 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -357,12 +357,19 @@ class PerChannelPreNorm(ExtendedInjector): class AccumulatorAwarePerChannelPreNorm(PerChannelPreNorm): - @value - def accumulator_bit_width_impl(accumulator_bit_width): - return BitWidthStatefulConst(accumulator_bit_width) - pre_scaling_impl = AccumulatorAwareParameterPreScaling - accumulator_bit_width = 32 # default maximum accumulator width is 32 bits + accumulator_bit_width = (this << 1).accumulator_bit_width + accumulator_bit_width_impl = (this << 1).accumulator_bit_width_impl + + +class AccumulatorAwareZeroCenterPerChannelPreNorm(AccumulatorAwarePerChannelPreNorm): + + pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling + pre_zero_point_impl = PreZeroCenterZeroPoint + pre_zero_point_shape = this.scaling_shape # TODO: decouple zero_point from scaling + pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl + stats_reduce_dim = (this << 1).stats_reduce_dim + scaling_shape = (this << 1).scaling_shape class SolvePostScaleGranularity(ExtendedInjector): @@ -457,10 +464,12 @@ class AccumulatorAwareWeightQuant(WeightNormPerChannelFloatDecoupled): per_channel_pre_norm = AccumulatorAwarePerChannelPreNorm normalize_stats_impl = PerChannelL1Norm.normalize_stats_impl # required to align with derivations in paper float_to_int_impl = RoundToZeroSte # required to ensure no upwards rounding violates constraints + accumulator_bit_width = 32 # default maximum accumulator width is 32 bits @value - def accumulator_bit_width(): - return this.per_channel_pre_norm.accumulator_bit_width + def accumulator_bit_width_impl(accumulator_bit_width): + return BitWidthStatefulConst(accumulator_bit_width) + class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant): """Experimental zero-centered accumulator-aware weight quantized based on: @@ -470,10 +479,7 @@ class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant): (1) added zero-centering constraint on the weights (i.e., `PreZeroCenterZeroPoint`) (2) a more relaxed l1-norm bound that is derived in the referenced paper """ - pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling - pre_zero_point_impl = PreZeroCenterZeroPoint - pre_zero_point_shape = this.scaling_shape # TODO: decouple zero_point from scaling - pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl + per_channel_pre_norm = AccumulatorAwareZeroCenterPerChannelPreNorm class MSESubInjectorBase(ExtendedInjector):