From 5a5cc6b34117e9956a4c7438afa8d83ae0bb9ee6 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Wed, 11 Dec 2024 10:36:51 +0530
Subject: [PATCH] [MLIR][TORCH] Add aten.special.expm1 op lowering (#3878)

This commit adds support for the torch.aten.special.expm1 op by
decomposing it into the torch.aten.expm1 op.
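
For context, torch.special.expm1 is documented as an alias of
torch.expm1, which is what makes the one-op decomposition below
valid. A minimal sketch (not part of this patch) of the equivalence
the new e2e tests exercise:

    import torch

    # expm1(x) computes exp(x) - 1 with better precision than the
    # naive formula for x near zero; torch.special.expm1 aliases it.
    x = torch.tensor([-0.5, 0.0, 1e-8, 2.0])
    assert torch.equal(torch.special.expm1(x), torch.expm1(x))
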
---------

Signed-off-by: Vivek Khandelwal
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 23 +++++++++
 .../Transforms/AbstractInterpLibrary.cpp      |  9 ++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 14 ++++++
 .../Transforms/LowerToBackendContract.cpp     |  1 +
 projects/pt1/e2e_testing/xfail_sets.py        | 12 +++--
 .../build_tools/abstract_interp_lib_gen.py    |  8 +++
 .../build_tools/torch_ods_gen.py              |  1 +
 .../test_suite/elementwise.py                 | 50 ++++++++++++++++++-
 8 files changed, 112 insertions(+), 6 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index f951de9af795..556b0aa76e93 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4610,6 +4610,29 @@ def Torch_AtenTrunc_Op : Torch_Op<"aten.trunc_", [
   }];
 }
 
+def Torch_AtenSpecialExpm1Op : Torch_Op<"aten.special_expm1", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::special_expm1 : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenSpecialExpm1Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenSpecialExpm1Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenSignOp : Torch_Op<"aten.sign", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index edcc81a2847f..fb0aaa7201b8 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6495,6 +6495,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.special_expm1\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.isfinite\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
@@ -11589,6 +11593,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.special_expm1\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
+"    return %1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.isfinite\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %int11 = torch.constant.int 11\n"
 "    return %int11 : !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 919c4727b1f9..063dca041901 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -11177,6 +11177,19 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenSpecialExpm1Op
+    : public OpRewritePattern<AtenSpecialExpm1Op> {
+public:
+  using OpRewritePattern<AtenSpecialExpm1Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSpecialExpm1Op op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<AtenExpm1Op>(op, op.getType(), op.getSelf());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -11462,6 +11475,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(
         patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenSpecialExpm1Op>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAtenFMaxMinOp>(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index f868c4c1800a..25635d2c5c46 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -569,6 +569,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenSpecialExpm1Op>();
 
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 9f832cb9e033..d2c6e6c9a762 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -500,8 +500,6 @@
     "AdaptiveMaxPool1dStatic_basic",
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
-    "ElementwiseExpm1IntModule_basic",
-    "ElementwiseExpm1Module_basic",
     "IsInfiniteModule_basic",
     "InterpolateDynamicModule_sizes_nearest",
     "IouOfModule_basic",
@@ -909,8 +907,6 @@
     "AtenItemIntOpModule_basic",
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
-    "ElementwiseExpm1IntModule_basic",
-    "ElementwiseExpm1Module_basic",
     "InterpolateDynamicModule_sizes_nearest",
     "IouOfModule_basic",
     "IscloseStaticModuleTrue_basic",
@@ -1209,6 +1205,8 @@
     "ElementwiseRsqrtModule_basic",
     "ElementwiseSigmoidModule_basic",
     "ElementwiseSinModule_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseSqrtModule_basic",
     "ElementwiseTanIntModule_basic",
     "ElementwiseTanModule_basic",
@@ -2951,6 +2949,8 @@
     "ElementwiseEluNonDefaultModule_basic",
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseFmodTensor_Int_basic",
     "ElementwiseCreateComplexModule_basic",
     "ElementwiseMulTensorComplexModule_basic",
@@ -3662,6 +3662,8 @@
    "ElementwiseQuantizePerTensorUIntModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseToDtypeF32ToI64Module_basic",
     "ElementwiseToDtypeI64ToUI8Module_basic",
     "ElementwiseWhereScalarOtherStaticModule_basic",
@@ -4355,6 +4357,8 @@
     "ElementwiseSinIntModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseSqrtIntModule_basic",
"ElementwiseSubScalarIntModule_basic", "ElementwiseTanIntModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 331aa476910e..2a980bf534fd 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -222,6 +222,9 @@ def aten〇exp2〡shape(self: List[int]) -> List[int]: def aten〇expm1〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇special_expm1〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇isfinite〡shape(self: List[int]) -> List[int]: return self @@ -2717,6 +2720,11 @@ def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇special_expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + def aten〇isfinite〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.bool diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 8a0417a85189..4c2de094e109 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -452,6 +452,7 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True) + emit("aten::special_expm1 : (Tensor) -> (Tensor)") emit_with_mutating_variants( "aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 38fccc06b393..b1745fa5b85a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -5207,7 +5207,7 @@ def __init__(self): ] ) def forward(self, a): - return torch.special.expm1(a) + return torch.expm1(a) @register_test_case(module_factory=lambda: ElementwiseExpm1Module()) @@ -5230,7 +5230,7 @@ def __init__(self): ] ) def forward(self, a): - return torch.special.expm1(a) + return torch.expm1(a) @register_test_case(module_factory=lambda: ElementwiseExpm1IntModule()) @@ -5241,6 +5241,52 @@ def ElementwiseExpm1IntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseSpecialExpm1Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.special.expm1(a) + + +@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1Module()) +def ElementwiseSpecialExpm1Module_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + 
+class ElementwiseSpecialExpm1IntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.int32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.special.expm1(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1IntModule())
+def ElementwiseSpecialExpm1IntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
+
+
+# ==============================================================================
+
+
 class ElementwiseRad2DegModule(torch.nn.Module):
     def __init__(self):
         super().__init__()