From 5c89ef26ee587a25eb0a3feac7b97f9b741082b3 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 4 Oct 2024 13:58:50 +0100 Subject: [PATCH] Fix (core/trunc): Fix output scaling after truncation --- src/brevitas/core/quant/int.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index cdb75df74..46127eb08 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -211,12 +211,13 @@ def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor, output_bit_width = self.msb_clamp_bit_width_impl() trunc_bit_width = input_bit_width - output_bit_width trunc_scale = 2.0 ** trunc_bit_width + output_scale = scale * trunc_scale y = y / trunc_scale y = self.float_to_int_impl(y) y = y - zero_point - y = y * scale + y = y * output_scale y = self.delay_wrapper(x, y) - return y, scale, zero_point, output_bit_width + return y, output_scale, zero_point, output_bit_width class DecoupledRescalingIntQuantWithInput(DecoupledRescalingIntQuant):