Skip to content

Commit

Permalink
Tests (export/finn): update quant avg pool tests
Browse files Browse the repository at this point in the history
  • Loading branch information
volcacius committed May 5, 2021
1 parent 4184c92 commit 3e24be8
Showing 1 changed file with 22 additions and 51 deletions.
73 changes: 22 additions & 51 deletions tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,68 +28,39 @@
@pytest.mark.parametrize("idim", [7, 8])
def test_brevitas_avg_pool_export(
kernel_size, stride, signed, bit_width, input_bit_width, channels, idim):
ishape = (1, channels, idim, idim)
ibw_tensor = torch.Tensor([input_bit_width])

b_avgpool = QuantAvgPool2d(
quant_avgpool = QuantAvgPool2d(
kernel_size=kernel_size,
stride=stride,
bit_width=bit_width,
quant_type=QuantType.INT)
# call forward pass manually once to cache scale factor and bitwidth
input_tensor = torch.from_numpy(np.zeros(ishape)).float()
scale = np.ones((1, channels, 1, 1))
output_scale = torch.from_numpy(scale).float()
zp = torch.tensor(0.)
input_quant_tensor = QuantTensor(input_tensor, output_scale, zp, ibw_tensor, signed)
FINNManager.export(b_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor)
model = ModelWrapper(export_onnx_path)
bit_width=bit_width)
quant_avgpool.eval()

# determine input FINN datatype
if signed is True:
prefix = "INT"
else:
prefix = "UINT"
# determine input
prefix = 'INT' if signed else 'UINT'
dt_name = prefix + str(input_bit_width)
dtype = DataType[dt_name]
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())

# execution with input tensor using integers and scale = 1
# calculate golden output
inp = gen_finn_dt_tensor(dtype, ishape)
input_tensor = torch.from_numpy(inp).float()
input_shape = (1, channels, idim, idim)
input_array = gen_finn_dt_tensor(dtype, input_shape)
scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(np.float32)
input_tensor = torch.from_numpy(input_array * scale_array).float()
scale_tensor = torch.from_numpy(scale_array).float()
zp = torch.tensor(0.)
input_quant_tensor = QuantTensor(input_tensor, output_scale, zp, ibw_tensor, signed)
b_avgpool.eval()
expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
input_quant_tensor = QuantTensor(input_tensor, scale_tensor, zp, input_bit_width, signed)

# finn execution
idict = {model.graph.input[0].name: inp}
odict = oxe.execute_onnx(model, idict, True)
produced = odict[model.graph.output[0].name]
assert (expected == produced).all()

# execution with input tensor using float and scale != 1
scale = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(np.float32)
inp_tensor = inp * scale
input_tensor = torch.from_numpy(inp_tensor).float()
input_scale = torch.from_numpy(scale).float()
zp = torch.tensor(0.)
input_quant_tensor = QuantTensor(input_tensor, input_scale, zp, ibw_tensor, signed)
# export again to set the scale values correctly
FINNManager.export(b_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor)
# export
FINNManager.export(quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
b_avgpool.eval()
expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
# finn execution
idict = {model.graph.input[0].name: inp_tensor}
odict = oxe.execute_onnx(model, idict, True)
produced = odict[model.graph.output[0].name]

assert np.isclose(expected, produced).all()

# reference brevitas output
ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy()
# finn output
idict = {model.graph.input[0].name: input_array}
odict = oxe.execute_onnx(model, idict, True)
finn_output = odict[model.graph.output[0].name]
# compare outputs
assert np.isclose(ref_output_array, finn_output).all()
# cleanup
os.remove(export_onnx_path)

0 comments on commit 3e24be8

Please sign in to comment.