Skip to content

Commit

Permalink
Fix IR import for integer to signless integer conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
vinayakdsci committed Oct 24, 2024
1 parent 5e8caf4 commit 95994fd
Showing 1 changed file with 40 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pathlib import Path
import sys
import tempfile
import warnings
import copy
import random
import iree.runtime as rt
import string
Expand Down Expand Up @@ -56,6 +56,7 @@
InsertionPoint,
Value,
SymbolTable,
IntegerType,
)


Expand Down Expand Up @@ -115,8 +116,11 @@ def create_tensor_global(
with InsertionPoint.at_block_begin(
self._m.regions[0].blocks[0]
), Location.unknown():
# After lowering to linalg-on-tensors, the data type need to be signless.
# So, we construct the globals to have signless types, and use
# torch_c.from_builtin_tensor to convert to the correct frontend type.
vtensor_type = RankedTensorType.get(
tuple(t.dims), self._cc.tensor_element_type(t.data_type)
tuple(t.dims), ELEM_TYPE_TO_SIGNLESS_IR_TYPE[t.data_type]()
)
ir_attrs = {
"sym_name": StringAttr.get(name),
Expand Down Expand Up @@ -240,10 +244,6 @@ def main(args: argparse.Namespace):
if args.externalize_params:
imp = IREENodeImporter.define_function(model_info.main_graph, m, args.max_numel)
else:
if args.max_numel:
warnings.warn(
"'--max-numel' has no effect until externalization is enabled with '--externalize-params'"
)
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
imp.import_all()

Expand Down Expand Up @@ -317,6 +317,40 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
return inferred_model


ELEM_TYPE_TO_SIGNLESS_IR_TYPE = copy.deepcopy(onnx_importer.ELEM_TYPE_TO_IR_TYPE_CB)

ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.INT64
] = lambda: IntegerType.get_signless(64)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.INT32
] = lambda: IntegerType.get_signless(32)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.INT16
] = lambda: IntegerType.get_signless(16)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.INT8
] = lambda: IntegerType.get_signless(8)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.INT4
] = lambda: IntegerType.get_signless(4)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.UINT8
] = lambda: IntegerType.get_signless(8)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.UINT4
] = lambda: IntegerType.get_signless(4)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.UINT16
] = lambda: IntegerType.get_signless(16)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.UINT64
] = lambda: IntegerType.get_signless(64)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.UINT32
] = lambda: IntegerType.get_signless(32)


def parse_arguments(argv=None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description="IREE ONNX import tool")
parser.add_argument("input_file", help="ONNX protobuf input", type=Path)
Expand Down

0 comments on commit 95994fd

Please sign in to comment.