Skip to content

Commit

Permalink
tests (brv_finn/avgpool): Add "lossless" tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Jan 28, 2025
1 parent 80f57e3 commit 1af6840
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
8 changes: 7 additions & 1 deletion src/brevitas/nn/quant_avg_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@ def __init__(
trunc_quant: Optional[AccQuantType] = RoundTo8bit,
return_quant_tensor: bool = True,
**kwargs):
AvgPool2d.__init__(self, kernel_size=kernel_size, stride=stride, ceil_mode=ceil_mode, count_include_pad=count_include_pad, divisor_override=divisor_override)
AvgPool2d.__init__(
self,
kernel_size=kernel_size,
stride=stride,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad,
divisor_override=divisor_override)
QuantLayerMixin.__init__(self, return_quant_tensor)
TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs)
self.cache_inference_quant_act = False
Expand Down
20 changes: 17 additions & 3 deletions tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,27 @@
@pytest.mark.parametrize("input_bit_width", [4, 8, 16])
@pytest.mark.parametrize("channels", [2, 4])
@pytest.mark.parametrize("idim", [7, 8])
@pytest.mark.parametrize("restrict_scaling_type", ["log_fp", "power_of_two"])
def test_brevitas_avg_pool_export(
kernel_size, stride, signed, bit_width, input_bit_width, channels, idim, request):
kernel_size,
stride,
signed,
bit_width,
input_bit_width,
channels,
idim,
restrict_scaling_type,
request):
if signed:
quant_node = QuantIdentity(
bit_width=input_bit_width,
restrict_scaling_type=restrict_scaling_type,
return_quant_tensor=True,
)
else:
quant_node = QuantReLU(
bit_width=input_bit_width,
restrict_scaling_type=restrict_scaling_type,
return_quant_tensor=True,
)

Expand Down Expand Up @@ -66,7 +77,10 @@ def test_brevitas_avg_pool_export(
odict = oxe.execute_onnx(model, idict, True)
finn_output = odict[model.graph.output[0].name]
# compare outputs
scale = quant_avgpool.trunc_quant.scale().detach().numpy() # Allow "off-by-1" errors
assert np.isclose(ref_output_array, finn_output, atol=scale).all()
if restrict_scaling_type == "power_of_two" and kernel_size == 2:
atol = 1e-8
else:
atol = quant_avgpool.trunc_quant.scale().detach().numpy() # Allow "off-by-1" errors
assert np.isclose(ref_output_array, finn_output, atol=atol).all()
# cleanup
os.remove(export_path)

0 comments on commit 1af6840

Please sign in to comment.