diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index 790ceed32..e27cd1094 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -363,6 +363,7 @@ class AccumulatorAwarePerChannelPreNorm(PerChannelPreNorm): pre_scaling_impl = AccumulatorAwareParameterPreScaling accumulator_bit_width = (this << 1).accumulator_bit_width accumulator_bit_width_impl = (this << 1).accumulator_bit_width_impl + restrict_pre_scaling_impl = (this << 1).restrict_pre_scaling_impl class AccumulatorAwareZeroCenterPerChannelPreNorm(AccumulatorAwarePerChannelPreNorm):