diff --git a/torchpruner/graph.py b/torchpruner/graph.py index cbbd463..b393fe8 100644 --- a/torchpruner/graph.py +++ b/torchpruner/graph.py @@ -458,9 +458,11 @@ def build_graph(self, inputs, fill_value=True, training=False): graph, params_dict, torch_out = torch.onnx.utils._model_to_graph( model, inputs, - _retain_param_name=True, - do_constant_folding=False, - training=training, + **normalize_onnx_parameters( + _retain_param_name=True, + do_constant_folding=False, + training=training, + ) ) torch.onnx.symbolic_helper._set_opset_version(9) # create the inputs and the terminals diff --git a/torchpruner/model_tools.py b/torchpruner/model_tools.py index 3e59850..589ba7d 100644 --- a/torchpruner/model_tools.py +++ b/torchpruner/model_tools.py @@ -176,3 +176,10 @@ def replace_object_by_class( ): name_list = get_names_by_class(model, object_class, include_super_class) return replace_object_by_names(model, name_list, replace_function) + + +def normalize_onnx_parameters(**kwargs): + torch_version = torch.__version__.split(".") + if torch_version[0] > "2" or len(torch_version) > 1 and torch_version[1] >= "10": + kwargs.pop("_retain_param_name", None) + return kwargs diff --git a/torchslim/quantizing/qat_tools.py b/torchslim/quantizing/qat_tools.py index e51c4cb..f1b56d4 100644 --- a/torchslim/quantizing/qat_tools.py +++ b/torchslim/quantizing/qat_tools.py @@ -305,10 +305,12 @@ def export_onnx(model, inputs, onnx_file): model, inputs, onnx_file, - opset_version=10, - verbose=False, - enable_onnx_checker=False, - _retain_param_name=False, + **model_tools.normalize_onnx_parameters( + opset_version=10, + verbose=False, + enable_onnx_checker=False, + _retain_param_name=True, + ) ) onnx_model = onnx.load(onnx_file) onnx_post_process(onnx_model) diff --git a/torchslim/quantizing/quantizer_test.py b/torchslim/quantizing/quantizer_test.py index 6c8fb14..548eb2b 100644 --- a/torchslim/quantizing/quantizer_test.py +++ b/torchslim/quantizing/quantizer_test.py @@ -42,9 +42,15 @@ def pre_hook(self, input): # # torch.onnx.symbolic_opset10 # with torch.onnx.select_model_mode_for_export(resnet_prepare, None): # graph = torch.onnx.utils._trace(resnet_prepare,(torch.zeros(1,3,224,224),), OperatorExportTypes.ONNX) +try: + from torchpruner.model_tools import normalize_onnx_parameters +except ImportError: + def normalize_onnx_parameters (**_kwargs): + return {} graph, params_dict, torch_out = torch.onnx.utils._model_to_graph( - resnet_prepare, (torch.zeros(1, 3, 224, 224),), _retain_param_name=True + resnet_prepare, (torch.zeros(1, 3, 224, 224),), + **normalize_onnx_parameters(_retain_param_name=True) ) print(graph)