Skip to content

Commit

Permalink
Fix (nn/TruncAvgPool): Remove any quant tensor manual manipulation.
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Jan 22, 2025
1 parent 13ca170 commit f183191
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions src/brevitas/nn/quant_avg_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def _avg_scaling(self):
else:
return self.kernel_size * self.kernel_size

# TODO: Replace with functional call
def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)

Expand All @@ -71,8 +72,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
if not isinstance(x, QuantTensor):
x = self.cache_class.quant_tensor.set(value=x)
y = AvgPool2d.forward(self, x)
rescaled_value = y.value * self._avg_scaling
y = y.set(value=rescaled_value)
y = self.trunc_quant(y)
else:
y = AvgPool2d.forward(self, _unpack_quant_tensor(x))
Expand Down Expand Up @@ -123,6 +122,7 @@ def compute_kernel_size_stride(input_shape, output_shape):
stride_list.append(stride)
return kernel_size_list, stride_list

# TODO: Replace with functional call
def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)

Expand All @@ -139,10 +139,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
if not isinstance(x, QuantTensor):
x = self.cache_class.quant_tensor.set(value=x)
y = AdaptiveAvgPool2d.forward(self, x)
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
reduce_size = reduce(mul, k_size, 1)
rescaled_value = y.value * reduce_size # remove avg scaling
y = y.set(value=rescaled_value)
y = self.trunc_quant(y)
else:
y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x))
Expand Down

0 comments on commit f183191

Please sign in to comment.