diff --git a/src/brevitas/quant_tensor/int_torch_handler.py b/src/brevitas/quant_tensor/int_torch_handler.py index 8882bd097..a9c6572fa 100644 --- a/src/brevitas/quant_tensor/int_torch_handler.py +++ b/src/brevitas/quant_tensor/int_torch_handler.py @@ -114,6 +114,7 @@ def avg_pool2d_handler( avg_scaling = kernel_size * kernel_size quant_input = quant_input.set(value=x) + quant_input = quant_input.set(scale=quant_input.scale / avg_scaling) quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, avg_scaling)) return quant_input @@ -133,6 +134,7 @@ def adaptive_avg_pool2d_handler(quant_input, output_shape): reduce_size = reduce(mul, k_size, 1) quant_input = quant_input.set(value=x) + quant_input = quant_input.set(scale=quant_input.scale / reduce_size) quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, reduce_size)) return quant_input