diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index cdb75df74..46127eb08 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -211,12 +211,13 @@ def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor, output_bit_width = self.msb_clamp_bit_width_impl() trunc_bit_width = input_bit_width - output_bit_width trunc_scale = 2.0 ** trunc_bit_width + output_scale = scale * trunc_scale y = y / trunc_scale y = self.float_to_int_impl(y) y = y - zero_point - y = y * scale + y = y * output_scale y = self.delay_wrapper(x, y) - return y, scale, zero_point, output_bit_width + return y, output_scale, zero_point, output_bit_width class DecoupledRescalingIntQuantWithInput(DecoupledRescalingIntQuant):