Skip to content

Commit

Permalink
Fix (core/trunc): Fix output scaling after truncation
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Oct 4, 2024
1 parent af37fcc commit 5c89ef2
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/brevitas/core/quant/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,13 @@ def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor,
  output_bit_width = self.msb_clamp_bit_width_impl()
  trunc_bit_width = input_bit_width - output_bit_width
  trunc_scale = 2.0 ** trunc_bit_width
+ output_scale = scale * trunc_scale
  y = y / trunc_scale
  y = self.float_to_int_impl(y)
  y = y - zero_point
- y = y * scale
+ y = y * output_scale
  y = self.delay_wrapper(x, y)
- return y, scale, zero_point, output_bit_width
+ return y, output_scale, zero_point, output_bit_width


class DecoupledRescalingIntQuantWithInput(DecoupledRescalingIntQuant):
Expand Down

0 comments on commit 5c89ef2

Please sign in to comment.