diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index 7cd050e21..e27cd1094 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -363,6 +363,7 @@ class AccumulatorAwarePerChannelPreNorm(PerChannelPreNorm): pre_scaling_impl = AccumulatorAwareParameterPreScaling accumulator_bit_width = (this << 1).accumulator_bit_width accumulator_bit_width_impl = (this << 1).accumulator_bit_width_impl + restrict_pre_scaling_impl = (this << 1).restrict_pre_scaling_impl class AccumulatorAwareZeroCenterPerChannelPreNorm(AccumulatorAwarePerChannelPreNorm): @@ -373,6 +374,7 @@ class AccumulatorAwareZeroCenterPerChannelPreNorm(AccumulatorAwarePerChannelPreN pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl stats_reduce_dim = SCALING_STATS_REDUCE_DIM scaling_shape = (this << 1).scaling_shape + restrict_pre_scaling_impl = (this << 1).restrict_pre_scaling_impl class SolvePostScaleGranularity(ExtendedInjector):