Skip to content

Commit

Permalink
Use iree-import-onnx --opset-version N in ImportOnnxAction. (iree-o…
Browse files Browse the repository at this point in the history
…rg#19210)

As discussed on
iree-org#18630 (comment), now
that `iree-import-onnx` supports upgrading the model version itself
(added in iree-org#18880), we don't need to
also maintain this logic in the `iree.build` package.
  • Loading branch information
ScottTodd authored and Groverkss committed Nov 29, 2024
1 parent cc0343c commit dc9ddaa
Showing 1 changed file with 14 additions and 44 deletions.
58 changes: 14 additions & 44 deletions compiler/bindings/python/iree/build/onnx_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,10 @@ def onnx_import(
input_file = context.file(source)
output_file = context.allocate_file(name)

# Chain through an upgrade if requested.
if upgrade:
processed_file = context.allocate_file(f"{name}__upgrade.onnx")
UpgradeOnnxAction(
input_file=input_file,
output_file=processed_file,
executor=context.executor,
desc=f"Upgrading ONNX {input_file} -> {processed_file}",
deps=[
input_file,
],
)
input_file = processed_file

# Import.
ImportOnnxAction(
input_file=input_file,
output_file=output_file,
upgrade=upgrade,
desc=f"Importing ONNX {name} -> {output_file}",
executor=context.executor,
deps=[
Expand All @@ -52,43 +38,27 @@ def onnx_import(
return output_file


class UpgradeOnnxAction(BuildAction):
def __init__(self, input_file: BuildFile, output_file: BuildFile, **kwargs):
super().__init__(**kwargs)
self.input_file = input_file
self.output_file = output_file
self.deps.add(self.input_file)
output_file.deps.add(self)
CompileSourceMeta.get(output_file).input_type = "onnx"

def _invoke(self):
import onnx

input_path = self.input_file.get_fs_path()
output_path = self.output_file.get_fs_path()

original_model = onnx.load_model(str(input_path))
converted_model = onnx.version_converter.convert_version(original_model, 17)
onnx.save(converted_model, str(output_path))


class ImportOnnxAction(BuildAction):
def __init__(self, input_file: BuildFile, output_file: BuildFile, **kwargs):
def __init__(
self, input_file: BuildFile, output_file: BuildFile, upgrade: bool, **kwargs
):
super().__init__(**kwargs)
self.input_file = input_file
self.output_file = output_file
self.upgrade = upgrade
self.deps.add(input_file)
output_file.deps.add(self)
CompileSourceMeta.get(output_file).input_type = "onnx"

def _invoke(self):
import iree.compiler.tools.import_onnx.__main__ as m

args = m.parse_arguments(
[
str(self.input_file.get_fs_path()),
"-o",
str(self.output_file.get_fs_path()),
]
)
m.main(args)
args = [
str(self.input_file.get_fs_path()),
"-o",
str(self.output_file.get_fs_path()),
]
if self.upgrade:
args.extend(["--opset-version", "17"])
parsed_args = m.parse_arguments(args)
m.main(parsed_args)

0 comments on commit dc9ddaa

Please sign in to comment.