Skip to content

Commit

Permalink
Fix (tests): fix tests after deprecation of FINN ONNX
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 27, 2023
1 parent de069f3 commit 63370f4
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 181 deletions.
1 change: 1 addition & 0 deletions requirements/requirements-ort-integration.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
onnx
onnxoptimizer
onnxruntime>=1.15.0
qonnx
37 changes: 19 additions & 18 deletions tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@

import numpy as np
import pytest
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
import qonnx.core.onnx_exec as oxe
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.util.basic import gen_finn_dt_tensor
import torch

from brevitas.export import FINNManager
from brevitas.export import export_qonnx
from brevitas.nn import TruncAvgPool2d
from brevitas.quant_tensor import QuantTensor
from brevitas.nn.quant_activation import QuantIdentity
from brevitas.nn.quant_activation import QuantReLU

export_onnx_path = "test_brevitas_avg_pool_export.onnx"

Expand All @@ -29,36 +28,38 @@
@pytest.mark.parametrize("idim", [7, 8])
def test_brevitas_avg_pool_export(
kernel_size, stride, signed, bit_width, input_bit_width, channels, idim, request):
if signed:
quant_node = QuantIdentity(
bit_width=input_bit_width,
return_quant_tensor=True,
)
else:
quant_node = QuantReLU(
bit_width=input_bit_width,
return_quant_tensor=True,
)

quant_avgpool = TruncAvgPool2d(
kernel_size=kernel_size, stride=stride, bit_width=bit_width, float_to_int_impl_type='floor')
quant_avgpool.eval()
model_brevitas = torch.nn.Sequential(quant_node, quant_avgpool)
model_brevitas.eval()

# determine input
prefix = 'INT' if signed else 'UINT'
dt_name = prefix + str(input_bit_width)
dtype = DataType[dt_name]
input_shape = (1, channels, idim, idim)
input_array = gen_finn_dt_tensor(dtype, input_shape)
scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(np.float32)
input_tensor = torch.from_numpy(input_array * scale_array).float()
scale_tensor = torch.from_numpy(scale_array).float()
zp = torch.tensor(0.)
input_quant_tensor = QuantTensor(
input_tensor, scale_tensor, zp, input_bit_width, signed, training=False)
inp = torch.randn(input_shape)

# export
test_id = request.node.callspec.id
export_path = test_id + '_' + export_onnx_path
FINNManager.export(quant_avgpool, export_path=export_path, input_t=input_quant_tensor)
export_qonnx(model_brevitas, export_path=export_path, input_t=inp)
model = ModelWrapper(export_path)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())

# reference brevitas output
ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy()
ref_output_array = model_brevitas(inp).tensor.detach().numpy()
# finn output
idict = {model.graph.input[0].name: input_array}
idict = {model.graph.input[0].name: inp.detach().numpy()}
odict = oxe.execute_onnx(model, idict, True)
finn_output = odict[model.graph.output[0].name]
# compare outputs
Expand Down
4 changes: 2 additions & 2 deletions tests/brevitas_finn/brevitas/test_debug_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from brevitas.export import enable_debug
from brevitas.export import export_finn_onnx
from brevitas.export import export_qonnx
from brevitas_examples import bnn_pynq

REF_MODEL = 'CNV_2W2A'
Expand All @@ -15,6 +15,6 @@ def test_debug_finn_onnx_export():
model.eval()
debug_hook = enable_debug(model)
input_tensor = torch.randn(1, 3, 32, 32)
export_finn_onnx(model, input_t=input_tensor, export_path='finn_debug.onnx')
export_qonnx(model, input_t=input_tensor, export_path='finn_debug.onnx')
model(input_tensor)
assert debug_hook.values
117 changes: 0 additions & 117 deletions tests/brevitas_finn/brevitas/test_wbiol.py

This file was deleted.

31 changes: 11 additions & 20 deletions tests/brevitas_finn/brevitas_examples/test_bnn_pynq_finn_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,13 @@
import torch

from brevitas import torch_version
from brevitas.export import export_finn_onnx
from brevitas.export import export_qonnx
from brevitas.quant_tensor import QuantTensor
from brevitas_examples.bnn_pynq.models import model_with_cfg

FC_INPUT_SIZE = (1, 1, 28, 28)
CNV_INPUT_SIZE = (1, 3, 32, 32)

MIN_INP_VAL = 0
MAX_INP_VAL = 255

MAX_WBITS = 2
MAX_ABITS = 2

Expand All @@ -48,12 +45,9 @@ def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, pretrained):
fc, _ = model_with_cfg(nname.lower(), pretrained=pretrained)
fc.eval()
# load a random int test vector
input_a = np.random.randint(MIN_INP_VAL, MAX_INP_VAL, size=FC_INPUT_SIZE).astype(np.float32)
scale = 1. / 255
input_t = torch.from_numpy(input_a * scale)
input_qt = QuantTensor(
input_t, scale=torch.tensor(scale), bit_width=torch.tensor(8.0), signed=False)
export_finn_onnx(fc, export_path=finn_onnx, input_t=input_qt, input_names=['input'])
input = torch.randn(FC_INPUT_SIZE)

export_qonnx(fc, export_path=finn_onnx, input_t=input, input_names=['input'])
model = ModelWrapper(finn_onnx)
model = model.transform(GiveUniqueNodeNames())
model = model.transform(DoubleToSingleFloat())
Expand All @@ -62,11 +56,11 @@ def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, pretrained):
model = model.transform(RemoveStaticGraphInputs())

# run using FINN-based execution
input_dict = {'input': input_a}
input_dict = {'input': input.detach().numpy()}
output_dict = oxe.execute_onnx(model, input_dict)
produced = output_dict[list(output_dict.keys())[0]]
# do forward pass in PyTorch/Brevitas
expected = fc.forward(input_t).detach().numpy()
expected = fc.forward(input).detach().numpy()
assert np.isclose(produced, expected, atol=ATOL).all()


Expand All @@ -84,12 +78,9 @@ def test_brevitas_cnv_onnx_export_and_exec(wbits, abits, pretrained):
cnv, _ = model_with_cfg(nname.lower(), pretrained=pretrained)
cnv.eval()
# load a random int test vector
input_a = np.random.randint(MIN_INP_VAL, MAX_INP_VAL, size=CNV_INPUT_SIZE).astype(np.float32)
scale = 1. / 255
input_t = torch.from_numpy(input_a * scale)
input_qt = QuantTensor(
input_t, scale=torch.tensor(scale), bit_width=torch.tensor(8.0), signed=False)
export_finn_onnx(cnv, export_path=finn_onnx, input_t=input_qt, input_names=['input'])
input = torch.randn(CNV_INPUT_SIZE)

export_qonnx(cnv, export_path=finn_onnx, input_t=input, input_names=['input'])
model = ModelWrapper(finn_onnx)
model = model.transform(GiveUniqueNodeNames())
model = model.transform(DoubleToSingleFloat())
Expand All @@ -98,9 +89,9 @@ def test_brevitas_cnv_onnx_export_and_exec(wbits, abits, pretrained):
model = model.transform(RemoveStaticGraphInputs())

# run using FINN-based execution
input_dict = {"input": input_a}
input_dict = {"input": input.detach().numpy()}
output_dict = oxe.execute_onnx(model, input_dict)
produced = output_dict[list(output_dict.keys())[0]]
# do forward pass in PyTorch/Brevitas
expected = cnv(input_t).detach().numpy()
expected = cnv(input).detach().numpy()
assert np.isclose(produced, expected, atol=ATOL).all()
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch

from brevitas import torch_version
from brevitas.export import export_finn_onnx
from brevitas.export import export_qonnx
from brevitas_examples.imagenet_classification import quant_mobilenet_v1_4b

ort_mac_fail = pytest.mark.skipif(
Expand All @@ -41,7 +41,7 @@ def test_mobilenet_v1_4b(pretrained):
torch_tensor = torch.from_numpy(numpy_tensor).float()
# do forward pass in PyTorch/Brevitas
expected = mobilenet(torch_tensor).detach().numpy()
export_finn_onnx(mobilenet, input_shape=INPUT_SIZE, export_path=finn_onnx)
export_qonnx(mobilenet, input_shape=INPUT_SIZE, export_path=finn_onnx)
model = ModelWrapper(finn_onnx)
model = model.transform(GiveUniqueNodeNames())
model = model.transform(DoubleToSingleFloat())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from qonnx.transformation.infer_shapes import InferShapes
import torch

from brevitas.export import export_finn_onnx
from brevitas.export import export_qonnx
from brevitas_examples.speech_to_text import quant_quartznet_perchannelscaling_4b

QUARTZNET_POSTPROCESSED_INPUT_SIZE = (1, 64, 256) # B, features, sequence
Expand All @@ -26,8 +26,7 @@ def test_quartznet_asr_4b(pretrained):
finn_onnx = "quant_quartznet_perchannelscaling_4b.onnx"
quartznet = quant_quartznet_perchannelscaling_4b(pretrained, export_mode=True)
quartznet.eval()
export_finn_onnx(
quartznet, input_shape=QUARTZNET_POSTPROCESSED_INPUT_SIZE, export_path=finn_onnx)
export_qonnx(quartznet, input_shape=QUARTZNET_POSTPROCESSED_INPUT_SIZE, export_path=finn_onnx)
model = ModelWrapper(finn_onnx)
model = model.transform(GiveUniqueNodeNames())
model = model.transform(DoubleToSingleFloat())
Expand Down
46 changes: 28 additions & 18 deletions tests/brevitas_ort/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import numpy as np
import onnxruntime as ort
import pytest
from qonnx.core.modelwrapper import ModelWrapper
import qonnx.core.onnx_exec as oxe
from qonnx.transformation.infer_shapes import InferShapes
import torch

from brevitas.export import export_onnx_qcdq
Expand Down Expand Up @@ -113,32 +116,39 @@ def is_brevitas_ort_close(
input_t = torch.from_numpy(np_input)
brevitas_output = model(input_t)

if export_type == 'qop':
export_onnx_qop(model, input_t, export_path=export_name)
brevitas_output = brevitas_output.int(float_datatype=False)
elif export_type == 'qcdq':
export_onnx_qcdq(model, input_t, export_path=export_name)
elif export_type == 'qcdq_opset14':
export_onnx_qcdq(model, input_t, opset_version=14, export_path=export_name)
elif export_type == 'qonnx_opset14':
export_qonnx(model, input_t, opset_version=14, export_path=export_name)
else:
raise RuntimeError(f"Export type {export_type} not recognized.")

if tolerance is not None and export_type == 'qcdq':
tolerance = tolerance * brevitas_output.scale # Float Output, tolerance is +/- output scale

ort_output = compute_ort(export_name, np_input)
if export_type == 'qonnx':
exported_model = export_qonnx(model, input_t, export_path=export_name)
exported_model = ModelWrapper(exported_model)
exported_model = exported_model.transform(InferShapes())
idict = {exported_model.graph.input[0].name: np_input}
odict = oxe.execute_onnx(exported_model, idict, True)
ort_output = odict[exported_model.graph.output[0].name]
else:
if export_type == 'qop':
export_onnx_qop(model, input_t, export_path=export_name)
brevitas_output = brevitas_output.int(float_datatype=False)
elif export_type == 'qcdq':
export_onnx_qcdq(model, input_t, export_path=export_name)
elif export_type == 'qcdq_opset14':
export_onnx_qcdq(model, input_t, opset_version=14, export_path=export_name)
elif export_type == 'qonnx_opset14':
export_qonnx(model, input_t, opset_version=14, export_path=export_name)
else:
raise RuntimeError(f"Export type {export_type} not recognized.")

ort_output = compute_ort(export_name, np_input)

if first_output_only:
if isinstance(ort_output, tuple):
if isinstance(ort_output, (tuple, list)):
ort_output = ort_output[0]
if isinstance(brevitas_output, tuple):
brevitas_output = brevitas_output[0]

# make sure we are not comparing 0s
if ort_output == 0 and (brevitas_output == 0).all():
pytest.skip("Skip testing against all 0s.")
# make sure we are not comparing 0s
if (ort_output == 0).all() and (brevitas_output == 0).all():
pytest.skip("Skip testing against all 0s.")

return recursive_allclose(ort_output, brevitas_output, tolerance)

Expand Down
Loading

0 comments on commit 63370f4

Please sign in to comment.