Skip to content

Commit

Permalink
Fix (quant_tensor): Produce valid IntQuantTensor after AvgPool functi…
Browse files Browse the repository at this point in the history
…onal call
  • Loading branch information
nickfraser committed Oct 4, 2024
1 parent f84a0e1 commit af37fcc
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/brevitas/quant_tensor/int_torch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def avg_pool2d_handler(
avg_scaling = kernel_size * kernel_size

quant_input = quant_input.set(value=x)
quant_input = quant_input.set(scale=quant_input.scale / avg_scaling)
quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, avg_scaling))
return quant_input

Expand All @@ -133,6 +134,7 @@ def adaptive_avg_pool2d_handler(quant_input, output_shape):
reduce_size = reduce(mul, k_size, 1)

quant_input = quant_input.set(value=x)
quant_input = quant_input.set(scale=quant_input.scale / reduce_size)
quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, reduce_size))
return quant_input

Expand Down

0 comments on commit af37fcc

Please sign in to comment.