From 63370f4a770e71cce27b0057d3c970be022f8866 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 26 Nov 2023 22:26:51 +0000 Subject: [PATCH] Fix (tests): fix tests after deprecation of FINN ONNX --- requirements/requirements-ort-integration.txt | 1 + .../brevitas/test_brevitas_avg_pool_export.py | 37 +++--- .../brevitas/test_debug_export.py | 4 +- tests/brevitas_finn/brevitas/test_wbiol.py | 117 ------------------ .../test_bnn_pynq_finn_export.py | 31 ++--- .../test_mobilenet_finn_export.py | 4 +- .../test_quartznet_finn_export.py | 5 +- tests/brevitas_ort/common.py | 46 ++++--- tests/brevitas_ort/test_quant_module.py | 3 +- 9 files changed, 67 insertions(+), 181 deletions(-) delete mode 100644 tests/brevitas_finn/brevitas/test_wbiol.py diff --git a/requirements/requirements-ort-integration.txt b/requirements/requirements-ort-integration.txt index a00d61123..afc10d07b 100644 --- a/requirements/requirements-ort-integration.txt +++ b/requirements/requirements-ort-integration.txt @@ -1,3 +1,4 @@ onnx onnxoptimizer onnxruntime>=1.15.0 +qonnx diff --git a/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py index 090fc5345..aa0d28faf 100644 --- a/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py @@ -5,17 +5,16 @@ import numpy as np import pytest -from qonnx.core.datatype import DataType from qonnx.core.modelwrapper import ModelWrapper import qonnx.core.onnx_exec as oxe from qonnx.transformation.infer_datatypes import InferDataTypes from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import gen_finn_dt_tensor import torch -from brevitas.export import FINNManager +from brevitas.export import export_qonnx from brevitas.nn import TruncAvgPool2d -from brevitas.quant_tensor import QuantTensor +from brevitas.nn.quant_activation import QuantIdentity +from brevitas.nn.quant_activation import QuantReLU export_onnx_path = "test_brevitas_avg_pool_export.onnx" @@ -29,36 +28,38 @@ @pytest.mark.parametrize("idim", [7, 8]) def test_brevitas_avg_pool_export( kernel_size, stride, signed, bit_width, input_bit_width, channels, idim, request): + if signed: + quant_node = QuantIdentity( + bit_width=input_bit_width, + return_quant_tensor=True, + ) + else: + quant_node = QuantReLU( + bit_width=input_bit_width, + return_quant_tensor=True, + ) quant_avgpool = TruncAvgPool2d( kernel_size=kernel_size, stride=stride, bit_width=bit_width, float_to_int_impl_type='floor') - quant_avgpool.eval() + model_brevitas = torch.nn.Sequential(quant_node, quant_avgpool) + model_brevitas.eval() # determine input - prefix = 'INT' if signed else 'UINT' - dt_name = prefix + str(input_bit_width) - dtype = DataType[dt_name] input_shape = (1, channels, idim, idim) - input_array = gen_finn_dt_tensor(dtype, input_shape) - scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(np.float32) - input_tensor = torch.from_numpy(input_array * scale_array).float() - scale_tensor = torch.from_numpy(scale_array).float() - zp = torch.tensor(0.) 
- input_quant_tensor = QuantTensor( - input_tensor, scale_tensor, zp, input_bit_width, signed, training=False) + inp = torch.randn(input_shape) # export test_id = request.node.callspec.id export_path = test_id + '_' + export_onnx_path - FINNManager.export(quant_avgpool, export_path=export_path, input_t=input_quant_tensor) + export_qonnx(model_brevitas, export_path=export_path, input_t=inp) model = ModelWrapper(export_path) model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) # reference brevitas output - ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy() + ref_output_array = model_brevitas(inp).tensor.detach().numpy() # finn output - idict = {model.graph.input[0].name: input_array} + idict = {model.graph.input[0].name: inp.detach().numpy()} odict = oxe.execute_onnx(model, idict, True) finn_output = odict[model.graph.output[0].name] # compare outputs diff --git a/tests/brevitas_finn/brevitas/test_debug_export.py b/tests/brevitas_finn/brevitas/test_debug_export.py index 1b232b68f..b470e2b9d 100644 --- a/tests/brevitas_finn/brevitas/test_debug_export.py +++ b/tests/brevitas_finn/brevitas/test_debug_export.py @@ -4,7 +4,7 @@ import torch from brevitas.export import enable_debug -from brevitas.export import export_finn_onnx +from brevitas.export import export_qonnx from brevitas_examples import bnn_pynq REF_MODEL = 'CNV_2W2A' @@ -15,6 +15,6 @@ def test_debug_finn_onnx_export(): model.eval() debug_hook = enable_debug(model) input_tensor = torch.randn(1, 3, 32, 32) - export_finn_onnx(model, input_t=input_tensor, export_path='finn_debug.onnx') + export_qonnx(model, input_t=input_tensor, export_path='finn_debug.onnx') model(input_tensor) assert debug_hook.values diff --git a/tests/brevitas_finn/brevitas/test_wbiol.py b/tests/brevitas_finn/brevitas/test_wbiol.py deleted file mode 100644 index c9c5360b6..000000000 --- a/tests/brevitas_finn/brevitas/test_wbiol.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
-# SPDX-License-Identifier: BSD-3-Clause - -import os - -import numpy as np -import pytest -from qonnx.core.modelwrapper import ModelWrapper -import qonnx.core.onnx_exec as oxe -from qonnx.transformation.infer_shapes import InferShapes -import torch - -from brevitas.nn import QuantConv2d -from brevitas.nn import QuantIdentity -from brevitas.nn import QuantLinear -import brevitas.onnx as bo -from brevitas.quant import Int16Bias - - -@pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("bias_quant", [Int16Bias, None]) -@pytest.mark.parametrize("out_features", [3]) -@pytest.mark.parametrize("in_features", [4]) -@pytest.mark.parametrize("w_bits", [2, 4]) -@pytest.mark.parametrize("channel_scaling", [True, False]) -@pytest.mark.parametrize("i_bits", [2, 4]) -def test_quant_linear(bias, bias_quant, out_features, in_features, w_bits, channel_scaling, i_bits): - # required to generated quantized inputs, not part of the exported model to test - quant_inp = QuantIdentity(bit_width=i_bits, return_quant_tensor=True) - inp_tensor = quant_inp(torch.randn(1, in_features)) - linear = QuantLinear( - out_features=out_features, - in_features=in_features, - bias=bias, - bias_quant=bias_quant, - weight_bit_width=w_bits, - weight_scaling_per_output_channel=channel_scaling) - linear.eval() - model = bo.export_finn_onnx(linear, input_t=inp_tensor, export_path='linear.onnx') - model = ModelWrapper(model) - model = model.transform(InferShapes()) - # the quantized input tensor passed to FINN should be in integer form - int_inp_array = inp_tensor.int(float_datatype=True).detach().numpy() - idict = {model.graph.input[0].name: int_inp_array} - odict = oxe.execute_onnx(model, idict, True) - produced = odict[model.graph.output[0].name] - expected = linear(inp_tensor).detach().numpy() - assert np.isclose(produced, expected, atol=1e-3).all() - - -@pytest.mark.parametrize("dw", [True, False]) -@pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("bias_quant", [Int16Bias, None]) -@pytest.mark.parametrize("in_features", [8]) -@pytest.mark.parametrize("in_channels", [4]) -@pytest.mark.parametrize("out_channels", [5]) -@pytest.mark.parametrize("w_bits", [2, 4]) -@pytest.mark.parametrize("channel_scaling", [True, False]) -@pytest.mark.parametrize("kernel_size", [3, 4]) -@pytest.mark.parametrize("padding", [0, 1]) -@pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("i_bits", [2, 4]) -def test_quant_conv2d( - dw, - bias, - bias_quant, - in_features, - in_channels, - out_channels, - w_bits, - channel_scaling, - kernel_size, - padding, - stride, - i_bits): - # required to generated quantized inputs, not part of the exported model to test - quant_inp = QuantIdentity(bit_width=i_bits, return_quant_tensor=True) - inp_tensor = quant_inp(torch.randn(1, in_channels, in_features, in_features)) - try: - conv = QuantConv2d( - in_channels=in_channels, - # out_channels=in_channels if dw else out_channels, - out_channels= - out_channels, # this allows for multi-depthwise, but it needs exception check - groups=in_channels if dw else 1, - kernel_size=kernel_size, - padding=padding, - stride=stride, - bias=bias, - bias_quant=bias_quant, - weight_bit_width=w_bits, - weight_scaling_per_output_channel=channel_scaling) - except Exception as e: - # exception should be rised when (multi-)dw is expected and out_channels - # is not multiplication of in_channels - dw_groups = out_channels // in_channels - dw_out_channels = dw_groups * in_channels - if dw and dw_out_channels != out_channels: - # 
exception caused by inproper parameters is ok, - # but further computation gives an error. - # So return without assertion - return - else: - # any other exeptions are unknown... - assert False - - conv.eval() - model = bo.export_finn_onnx(conv, input_t=inp_tensor) - model = ModelWrapper(model) - model = model.transform(InferShapes()) - # the quantized input tensor passed to FINN should be in integer form - int_inp_array = inp_tensor.int(float_datatype=True).detach().numpy() - idict = {model.graph.input[0].name: int_inp_array} - odict = oxe.execute_onnx(model, idict, True) - produced = odict[model.graph.output[0].name] - expected = conv(inp_tensor).detach().numpy() - assert np.isclose(produced, expected, atol=1e-3).all() diff --git a/tests/brevitas_finn/brevitas_examples/test_bnn_pynq_finn_export.py b/tests/brevitas_finn/brevitas_examples/test_bnn_pynq_finn_export.py index ce6149b64..93dc94316 100644 --- a/tests/brevitas_finn/brevitas_examples/test_bnn_pynq_finn_export.py +++ b/tests/brevitas_finn/brevitas_examples/test_bnn_pynq_finn_export.py @@ -15,16 +15,13 @@ import torch from brevitas import torch_version -from brevitas.export import export_finn_onnx +from brevitas.export import export_qonnx from brevitas.quant_tensor import QuantTensor from brevitas_examples.bnn_pynq.models import model_with_cfg FC_INPUT_SIZE = (1, 1, 28, 28) CNV_INPUT_SIZE = (1, 3, 32, 32) -MIN_INP_VAL = 0 -MAX_INP_VAL = 255 - MAX_WBITS = 2 MAX_ABITS = 2 @@ -48,12 +45,9 @@ def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, pretrained): fc, _ = model_with_cfg(nname.lower(), pretrained=pretrained) fc.eval() # load a random int test vector - input_a = np.random.randint(MIN_INP_VAL, MAX_INP_VAL, size=FC_INPUT_SIZE).astype(np.float32) - scale = 1. / 255 - input_t = torch.from_numpy(input_a * scale) - input_qt = QuantTensor( - input_t, scale=torch.tensor(scale), bit_width=torch.tensor(8.0), signed=False) - export_finn_onnx(fc, export_path=finn_onnx, input_t=input_qt, input_names=['input']) + input = torch.randn(FC_INPUT_SIZE) + + export_qonnx(fc, export_path=finn_onnx, input_t=input, input_names=['input']) model = ModelWrapper(finn_onnx) model = model.transform(GiveUniqueNodeNames()) model = model.transform(DoubleToSingleFloat()) @@ -62,11 +56,11 @@ def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, pretrained): model = model.transform(RemoveStaticGraphInputs()) # run using FINN-based execution - input_dict = {'input': input_a} + input_dict = {'input': input.detach().numpy()} output_dict = oxe.execute_onnx(model, input_dict) produced = output_dict[list(output_dict.keys())[0]] # do forward pass in PyTorch/Brevitas - expected = fc.forward(input_t).detach().numpy() + expected = fc.forward(input).detach().numpy() assert np.isclose(produced, expected, atol=ATOL).all() @@ -84,12 +78,9 @@ def test_brevitas_cnv_onnx_export_and_exec(wbits, abits, pretrained): cnv, _ = model_with_cfg(nname.lower(), pretrained=pretrained) cnv.eval() # load a random int test vector - input_a = np.random.randint(MIN_INP_VAL, MAX_INP_VAL, size=CNV_INPUT_SIZE).astype(np.float32) - scale = 1. 
/ 255 - input_t = torch.from_numpy(input_a * scale) - input_qt = QuantTensor( - input_t, scale=torch.tensor(scale), bit_width=torch.tensor(8.0), signed=False) - export_finn_onnx(cnv, export_path=finn_onnx, input_t=input_qt, input_names=['input']) + input = torch.randn(CNV_INPUT_SIZE) + + export_qonnx(cnv, export_path=finn_onnx, input_t=input, input_names=['input']) model = ModelWrapper(finn_onnx) model = model.transform(GiveUniqueNodeNames()) model = model.transform(DoubleToSingleFloat()) @@ -98,9 +89,9 @@ def test_brevitas_cnv_onnx_export_and_exec(wbits, abits, pretrained): model = model.transform(RemoveStaticGraphInputs()) # run using FINN-based execution - input_dict = {"input": input_a} + input_dict = {"input": input.detach().numpy()} output_dict = oxe.execute_onnx(model, input_dict) produced = output_dict[list(output_dict.keys())[0]] # do forward pass in PyTorch/Brevitas - expected = cnv(input_t).detach().numpy() + expected = cnv(input).detach().numpy() assert np.isclose(produced, expected, atol=ATOL).all() diff --git a/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py b/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py index c0bb4db38..01505b9e5 100644 --- a/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py +++ b/tests/brevitas_finn/brevitas_examples/test_mobilenet_finn_export.py @@ -16,7 +16,7 @@ import torch from brevitas import torch_version -from brevitas.export import export_finn_onnx +from brevitas.export import export_qonnx from brevitas_examples.imagenet_classification import quant_mobilenet_v1_4b ort_mac_fail = pytest.mark.skipif( @@ -41,7 +41,7 @@ def test_mobilenet_v1_4b(pretrained): torch_tensor = torch.from_numpy(numpy_tensor).float() # do forward pass in PyTorch/Brevitas expected = mobilenet(torch_tensor).detach().numpy() - export_finn_onnx(mobilenet, input_shape=INPUT_SIZE, export_path=finn_onnx) + export_qonnx(mobilenet, input_shape=INPUT_SIZE, export_path=finn_onnx) model = ModelWrapper(finn_onnx) model = model.transform(GiveUniqueNodeNames()) model = model.transform(DoubleToSingleFloat()) diff --git a/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py b/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py index c785e66fb..dded0e276 100644 --- a/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py +++ b/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py @@ -12,7 +12,7 @@ from qonnx.transformation.infer_shapes import InferShapes import torch -from brevitas.export import export_finn_onnx +from brevitas.export import export_qonnx from brevitas_examples.speech_to_text import quant_quartznet_perchannelscaling_4b QUARTZNET_POSTPROCESSED_INPUT_SIZE = (1, 64, 256) # B, features, sequence @@ -26,8 +26,7 @@ def test_quartznet_asr_4b(pretrained): finn_onnx = "quant_quartznet_perchannelscaling_4b.onnx" quartznet = quant_quartznet_perchannelscaling_4b(pretrained, export_mode=True) quartznet.eval() - export_finn_onnx( - quartznet, input_shape=QUARTZNET_POSTPROCESSED_INPUT_SIZE, export_path=finn_onnx) + export_qonnx(quartznet, input_shape=QUARTZNET_POSTPROCESSED_INPUT_SIZE, export_path=finn_onnx) model = ModelWrapper(finn_onnx) model = model.transform(GiveUniqueNodeNames()) model = model.transform(DoubleToSingleFloat()) diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index f5c73a2a4..c05fd59b9 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -4,6 +4,9 @@ import numpy as np import onnxruntime as ort import 
pytest +from qonnx.core.modelwrapper import ModelWrapper +import qonnx.core.onnx_exec as oxe +from qonnx.transformation.infer_shapes import InferShapes import torch from brevitas.export import export_onnx_qcdq @@ -113,32 +116,39 @@ def is_brevitas_ort_close( input_t = torch.from_numpy(np_input) brevitas_output = model(input_t) - if export_type == 'qop': - export_onnx_qop(model, input_t, export_path=export_name) - brevitas_output = brevitas_output.int(float_datatype=False) - elif export_type == 'qcdq': - export_onnx_qcdq(model, input_t, export_path=export_name) - elif export_type == 'qcdq_opset14': - export_onnx_qcdq(model, input_t, opset_version=14, export_path=export_name) - elif export_type == 'qonnx_opset14': - export_qonnx(model, input_t, opset_version=14, export_path=export_name) - else: - raise RuntimeError(f"Export type {export_type} not recognized.") - if tolerance is not None and export_type == 'qcdq': tolerance = tolerance * brevitas_output.scale # Float Output, tolerance is +/- output scale - ort_output = compute_ort(export_name, np_input) + if export_type == 'qonnx': + exported_model = export_qonnx(model, input_t, export_path=export_name) + exported_model = ModelWrapper(exported_model) + exported_model = exported_model.transform(InferShapes()) + idict = {exported_model.graph.input[0].name: np_input} + odict = oxe.execute_onnx(exported_model, idict, True) + ort_output = odict[exported_model.graph.output[0].name] + else: + if export_type == 'qop': + export_onnx_qop(model, input_t, export_path=export_name) + brevitas_output = brevitas_output.int(float_datatype=False) + elif export_type == 'qcdq': + export_onnx_qcdq(model, input_t, export_path=export_name) + elif export_type == 'qcdq_opset14': + export_onnx_qcdq(model, input_t, opset_version=14, export_path=export_name) + elif export_type == 'qonnx_opset14': + export_qonnx(model, input_t, opset_version=14, export_path=export_name) + else: + raise RuntimeError(f"Export type {export_type} not recognized.") + + ort_output = compute_ort(export_name, np_input) if first_output_only: - if isinstance(ort_output, tuple): + if isinstance(ort_output, (tuple, list)): ort_output = ort_output[0] if isinstance(brevitas_output, tuple): brevitas_output = brevitas_output[0] - - # make sure we are not comparing 0s - if ort_output == 0 and (brevitas_output == 0).all(): - pytest.skip("Skip testing against all 0s.") + # make sure we are not comparing 0s + if (ort_output == 0).all() and (brevitas_output == 0).all(): + pytest.skip("Skip testing against all 0s.") return recursive_allclose(ort_output, brevitas_output, tolerance) diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index a5e5bbf7d..a4f7e7f5c 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -3,6 +3,7 @@ from functools import reduce from operator import mul +import os import pytest from pytest_cases import get_case_id @@ -18,7 +19,7 @@ @parametrize_with_cases('model', cases=QuantWBIOLCases) -@pytest.mark.parametrize('export_type', ['qcdq', 'qop']) +@pytest.mark.parametrize('export_type', ['qcdq', 'qonnx', 'qop']) @requires_pt_ge('1.8.1') def test_ort_wbiol(model, export_type, current_cases): cases_generator_func = current_cases['model'][1]
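
For reference, the updated tests all share one QONNX flow: export the Brevitas module with export_qonnx, wrap the result in a qonnx ModelWrapper, run InferShapes, execute the graph with qonnx.core.onnx_exec, and compare against the PyTorch forward pass. Below is a minimal sketch of that pattern; the QuantLinear configuration, file name, and tolerance are illustrative only and are not taken from this patch:

    import numpy as np
    import torch
    from qonnx.core.modelwrapper import ModelWrapper
    import qonnx.core.onnx_exec as oxe
    from qonnx.transformation.infer_shapes import InferShapes

    from brevitas.export import export_qonnx
    from brevitas.nn import QuantLinear

    # Illustrative quantized module; any Brevitas module exports the same way.
    model = QuantLinear(in_features=4, out_features=3, bias=True, weight_bit_width=4)
    model.eval()

    inp = torch.randn(1, 4)
    # QONNX export replaces the deprecated FINN ONNX flow (FINNManager / export_finn_onnx).
    export_qonnx(model, input_t=inp, export_path='linear_qonnx.onnx')

    qonnx_model = ModelWrapper('linear_qonnx.onnx')
    qonnx_model = qonnx_model.transform(InferShapes())

    # Execute the exported graph and compare with the Brevitas/PyTorch output.
    idict = {qonnx_model.graph.input[0].name: inp.detach().numpy()}
    odict = oxe.execute_onnx(qonnx_model, idict, True)
    produced = odict[qonnx_model.graph.output[0].name]
    expected = model(inp).detach().numpy()
    assert np.isclose(produced, expected, atol=1e-3).all()

This is the same sequence the patch adds to is_brevitas_ort_close in tests/brevitas_ort/common.py for the 'qonnx' export type, and the reason the deprecated FINN-specific QuantTensor input construction could be dropped from the exporter calls.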