From 9d4aa0125ab4ad23a913ee472a8264064aa485db Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 28 Jan 2025 13:57:11 +0000 Subject: [PATCH 1/3] Add support for TensorProto::UINT4/INT4 Signed-off-by: Rickert, Jonas --- src/Dialect/ONNX/ONNXOps/OpHelper.cpp | 10 ++++++++-- utils/gen_onnx_mlir.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp index da86f80e41..946527160e 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp @@ -667,11 +667,13 @@ Type convertONNXTypeToMLIRType( return builder.getI1Type(); case onnx::TensorProto_DataType::TensorProto_DataType_STRING: return ONNXStringType::get(builder.getContext()); + case onnx::TensorProto_DataType::TensorProto_DataType_INT4: + return builder.getIntegerType(/*width=*/4); + case onnx::TensorProto_DataType::TensorProto_DataType_UINT4: + return builder.getIntegerType(/*width=*/4, false); case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: - case onnx::TensorProto_DataType::TensorProto_DataType_INT4: - case onnx::TensorProto_DataType::TensorProto_DataType_UINT4: case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED: llvm_unreachable("Unsupported data type encountered."); return nullptr; @@ -721,6 +723,10 @@ int64_t mlirTypeToOnnxType(Type elemType) { ? onnx::TensorProto::UNDEFINED : onnx::TensorProto::BOOL; break; + case 4: + onnxType = type.isUnsigned() ? onnx::TensorProto::UINT4 + : onnx::TensorProto::INT4; + break; case 8: onnxType = type.isUnsigned() ? onnx::TensorProto::UINT8 : onnx::TensorProto::INT8; diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 45e879b0e1..4be0ad4b2b 100755 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -614,6 +614,10 @@ # FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients # FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero # +# // 4-bit integer data types +# UINT4 = 21; // Unsigned integer in range [0, 15] +# INT4 = 22; // Signed integer in range [-8, 7], using two's-complement representation +# # // Future extensions go here. # } onnx_types = ( @@ -638,6 +642,8 @@ "float8e4m3fnuz", "float8e5m2", "float8e5m2fnuz", + "uint4", + "int4", ) tblgen_types = ( "BF16", @@ -661,6 +667,8 @@ "F8E4M3FNUZ", "F8E5M2", "F8E5M2FNUZ", + "AnyUI4", + "AnyI4", ) # Maximum count for actual type. Number more than MAX_NUM_TYPES will be used to encode @@ -1051,10 +1059,12 @@ def parse_type_str(allowedType): "seq": "SeqOf", "map": "TupleOf", "bool": "I1", + "uint4": "UI<4>", "uint8": "UI8", "uint16": "UI16", "uint32": "UI32", "uint64": "UI64", + "int4": "I<4>", "int8": "I8", "int16": "I16", "int32": "I32", From b4b59930df4858ffaa96dcb6506b342c019bad15 Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 28 Jan 2025 14:02:20 +0000 Subject: [PATCH 2/3] Upgrade onnx.Cast to opset 21 Signed-off-by: Rickert, Jonas --- docs/Dialects/onnx.md | 4 ++-- src/Builder/OpBuildTable.inc | 2 +- src/Dialect/ONNX/ONNXOps.td.inc | 6 +++--- .../parse/cast_to_int_4_and_back.onnxtext | 19 +++++++++++++++++++ utils/gen_onnx_mlir.py | 2 +- 5 files changed, 26 insertions(+), 7 deletions(-) create mode 100644 test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md index 3996ad35d6..5470bd6bc0 100644 --- a/docs/Dialects/onnx.md +++ b/docs/Dialects/onnx.md @@ -1114,13 +1114,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values +| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or tensor of 4-bit unsigned integer values or tensor of 4-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values +| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or tensor of 4-bit unsigned integer values or tensor of 4-bit signless integer values ### `onnx.CategoryMapper` (ONNXCategoryMapperOp) diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc index 0a63f65ff7..96517d8d83 100644 --- a/src/Builder/OpBuildTable.inc +++ b/src/Builder/OpBuildTable.inc @@ -28,7 +28,7 @@ op_dialect_version_map_["BitwiseNot"] = {18}; op_dialect_version_map_["BitwiseOr"] = {18}; op_dialect_version_map_["BitwiseXor"] = {18}; op_dialect_version_map_["BlackmanWindow"] = {17}; -op_dialect_version_map_["Cast"] = {19}; +op_dialect_version_map_["Cast"] = {21}; op_dialect_version_map_["CastLike"] = {19}; op_dialect_version_map_["CastMap"] = {1}; op_dialect_version_map_["CategoryMapper"] = {1}; diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 0516cd5f3e..8065351fe0 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -843,7 +843,7 @@ def ONNXBlackmanWindowOp:ONNX_Op<"BlackmanWindow", } def ONNXCastOp:ONNX_Op<"Cast", - [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<21>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Cast operation"; let description = [{ @@ -912,10 +912,10 @@ def ONNXCastOp:ONNX_Op<"Cast", | [x] < -FLT_MAX | NaN | NaN | -Inf | NaN | | else | RNE | RNE | RNE | RNE | }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, TensorOf<[UI<4>]>, TensorOf<[I<4>]>]>:$input, DefaultValuedAttr:$saturate, TypeAttr:$to); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, TensorOf<[UI<4>]>, TensorOf<[I<4>]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; diff --git a/test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext b/test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext new file mode 100644 index 0000000000..c5005ca136 --- /dev/null +++ b/test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext @@ -0,0 +1,19 @@ +// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s +< + ir_version: 10, + opset_import: ["" : 22] +> +test_int4_casting (int4[1] input, uint4[1] input2) => (int4[1] int4_cast_output, uint4[1] uint4_cast_output) { + int8_cast_output = Cast (input) + int4_cast_output = Cast (int8_cast_output) + uint8_cast_output = Cast (input2) + uint4_cast_output = Cast (uint8_cast_output) +} +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi4> {onnx.name = "input"}, [[PARAM_1_:%.+]]: tensor<1xui4> {onnx.name = "input2"}) -> (tensor<1xi4> {onnx.name = "int4_cast_output"}, tensor<1xui4> {onnx.name = "uint4_cast_output"}) { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i8} : (tensor<1xi4>) -> tensor<1xi8> +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Cast"([[VAR_0_]]) {saturate = 1 : si64, to = i4} : (tensor<1xi8>) -> tensor<1xi4> +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Cast"([[PARAM_1_]]) {saturate = 1 : si64, to = ui8} : (tensor<1xui4>) -> tensor<1xui8> +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Cast"([[VAR_2_]]) {saturate = 1 : si64, to = ui4} : (tensor<1xui8>) -> tensor<1xui4> +// CHECK: onnx.Return [[VAR_1_]], [[VAR_3_]] : tensor<1xi4>, tensor<1xui4> +// CHECK: } diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 4be0ad4b2b..46db6d4737 100755 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -109,7 +109,7 @@ "BitwiseOr": [18], "BitwiseXor": [18], "BlackmanWindow": [17], - "Cast": [19], + "Cast": [21], "CastLike": [19], "CastMap": [1], "CategoryMapper": [1], From 177dd4d241475f006f6b26032b40fde8872358fc Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 28 Jan 2025 14:03:01 +0000 Subject: [PATCH 3/3] Add test for onnx.Cast v21 to tosa lowering Signed-off-by: Rickert, Jonas --- .../conversion/onnx_to_tosa/Math/Elementwise.mlir | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir index 3247d16c70..f39b4deaaf 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir @@ -19,6 +19,19 @@ func.func @test_cast_f32_i8(%arg0: tensor<13x21x1xf32>) -> tensor<13x21x1xi8> { // ----- +func.func @test_cast_int4_and_uint4_to_from_int8_uint8(%arg0: tensor<1xi4>, %arg1: tensor<1xui4>) -> (tensor<1xi4>, tensor<1xui4>) { + %0 = "onnx.Cast"(%arg0) {saturate = 1 : si64, to = i8} : (tensor<1xi4>) -> tensor<1xi8> + %1 = "onnx.Cast"(%0) {saturate = 1 : si64, to = i4} : (tensor<1xi8>) -> tensor<1xi4> + %2 = "onnx.Cast"(%arg1) {saturate = 1 : si64, to = ui8} : (tensor<1xui4>) -> tensor<1xui8> + %3 = "onnx.Cast"(%2) {saturate = 1 : si64, to = ui4} : (tensor<1xui8>) -> tensor<1xui4> + onnx.Return %1, %3 : tensor<1xi4>, tensor<1xui4> + // CHECK-LABEL: func.func @test_cast_int4_and_uint4_to_from_int8_uint8( + // TOSA does not support int4 casting + // CHECK-NOT: tosa.cast +} + +// ----- + func.func @test_cast_f16_i8(%arg0: tensor<13x21x1xf16>) -> tensor<13x21x1xi8> { %0 = "onnx.Cast"(%arg0) {to = i8} : (tensor<13x21x1xf16>) -> tensor<13x21x1xi8> "func.return"(%0) : (tensor<13x21x1xi8>) -> ()