fix: export into onnx: remove _retain_param_name
This is reported in #7 (comment).
gdh1995 committed Dec 7, 2021
1 parent 8c6455c commit cb210e2
Showing 4 changed files with 25 additions and 8 deletions.
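
The change applies one pattern at every call site below: keyword arguments for the ONNX export helpers are routed through a new normalize_onnx_parameters helper, which forwards _retain_param_name only on torch releases that still accept it. A minimal sketch of that pattern, assuming torchpruner is installed; the toy model and the output filename are illustrative only:

import torch
from torchpruner.model_tools import normalize_onnx_parameters  # helper added by this commit

model = torch.nn.Linear(4, 2)      # toy model, for illustration only
dummy_input = torch.zeros(1, 4)

# On torch 1.10+ the helper drops _retain_param_name; older releases still receive it.
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    **normalize_onnx_parameters(
        opset_version=10,
        _retain_param_name=True,
    ),
)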
8 changes: 5 additions & 3 deletions torchpruner/graph.py
@@ -458,9 +458,11 @@ def build_graph(self, inputs, fill_value=True, training=False):
         graph, params_dict, torch_out = torch.onnx.utils._model_to_graph(
             model,
             inputs,
-            _retain_param_name=True,
-            do_constant_folding=False,
-            training=training,
+            **normalize_onnx_parameters(
+                _retain_param_name=True,
+                do_constant_folding=False,
+                training=training,
+            )
         )
         torch.onnx.symbolic_helper._set_opset_version(9)
         # create the inputs and the terminals
7 changes: 7 additions & 0 deletions torchpruner/model_tools.py
@@ -176,3 +176,10 @@ def replace_object_by_class(
 ):
     name_list = get_names_by_class(model, object_class, include_super_class)
     return replace_object_by_names(model, name_list, replace_function)
+
+
+def normalize_onnx_parameters(**kwargs):
+    major, minor = (int(v) for v in torch.__version__.split(".")[:2])
+    if (major, minor) >= (1, 10):  # torch 1.10 removed _retain_param_name
+        kwargs.pop("_retain_param_name", None)
+    return kwargs
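
A quick check of what the helper returns, a sketch assuming the function above is importable as torchpruner.model_tools.normalize_onnx_parameters:

from torchpruner.model_tools import normalize_onnx_parameters

kwargs = normalize_onnx_parameters(_retain_param_name=True, do_constant_folding=False)
print(kwargs)
# torch 1.10+:  {'do_constant_folding': False}
# older torch:  {'_retain_param_name': True, 'do_constant_folding': False}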
10 changes: 6 additions & 4 deletions torchslim/quantizing/qat_tools.py
@@ -305,10 +305,12 @@ def export_onnx(model, inputs, onnx_file):
         model,
         inputs,
         onnx_file,
-        opset_version=10,
-        verbose=False,
-        enable_onnx_checker=False,
-        _retain_param_name=False,
+        **model_tools.normalize_onnx_parameters(
+            opset_version=10,
+            verbose=False,
+            enable_onnx_checker=False,
+            _retain_param_name=True,
+        )
     )
     onnx_model = onnx.load(onnx_file)
     onnx_post_process(onnx_model)
8 changes: 7 additions & 1 deletion torchslim/quantizing/quantizer_test.py
@@ -42,9 +42,15 @@ def pre_hook(self, input):
 # # torch.onnx.symbolic_opset10
 # with torch.onnx.select_model_mode_for_export(resnet_prepare, None):
 # graph = torch.onnx.utils._trace(resnet_prepare,(torch.zeros(1,3,224,224),), OperatorExportTypes.ONNX)
+try:
+    from torchpruner.model_tools import normalize_onnx_parameters
+except ImportError:
+    def normalize_onnx_parameters(**_kwargs):
+        return {}

 graph, params_dict, torch_out = torch.onnx.utils._model_to_graph(
-    resnet_prepare, (torch.zeros(1, 3, 224, 224),), _retain_param_name=True
+    resnet_prepare, (torch.zeros(1, 3, 224, 224),),
+    **normalize_onnx_parameters(_retain_param_name=True)
 )

 print(graph)
