From 6f104b1457d7f3669e8633db876442394060c328 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Sun, 28 Apr 2024 17:34:29 +0000 Subject: [PATCH 1/3] decompose AtenLerpTensorOp --- .../Torch/Transforms/DecomposeComplexOps.cpp | 30 +++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + 2 files changed, 31 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 6cb02297d497..d8ef033a880e 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2444,6 +2444,35 @@ class DecomposeAtenLerpScalarOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenLerpTensorOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLerpTensorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto start = op.getSelf(); + auto inputType = cast(start.getType()); + + auto delta = rewriter.create(loc, inputType, op.getEnd(), + start, cstOne); + + auto weightedDelta = + rewriter.create(loc, inputType, delta, op.getWeight()); + auto lerp = rewriter.create(loc, inputType, start, + weightedDelta, cstOne); + rewriter.replaceOp(op, lerp); + return success(); + } +}; +} // namespace + // Elu = scale * max(0,x) + alpha * scale * (exp(min(0,x) * input_scale) - 1) namespace { class DecomposeAtenEluOp : public OpRewritePattern { @@ -7780,6 +7809,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index c5855a1fa092..37168b95ee34 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -497,6 +497,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); From 452f45fa549af75ba7487cfd08bda1b43aee5c55 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Mon, 29 Apr 2024 16:31:38 +0000 Subject: [PATCH 2/3] fix xfails.py --- projects/pt1/e2e_testing/xfail_sets.py | 2 ++ .../test_suite/elementwise.py | 25 +++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 10c24b657128..f016a6d629a7 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1005,6 +1005,7 @@ "ElementwiseSqrtModule_basic", "ElementwiseTanIntModule_basic", "ElementwiseTanModule_basic", + "ElementwiseTernaryStaticShapeModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeI64ToI8Module_basic", "ElementwiseToDtypeIdentityModule_basic", @@ -1441,6 +1442,7 @@ # and very few tests work yet. TOSA_PASS_SET = { "ElementwiseLogSigmoidModule_basic", + "ElementwiseTernaryStaticShapeModule_basic", "ElementwiseTruncModule_basic", "ElementwiseTruncIntModule_basic", "ElementwiseSgnModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 8e287584295b..b17df4d3d875 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -414,6 +414,31 @@ def ElementwiseTernaryModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseTernaryStaticShapeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 4, 3], torch.float32, True), + ([4, 3], torch.float32, True), + ([3], torch.float32, True), + ] + ) + def forward(self, a, b, c): + return torch.lerp(a, b, c) + + +@register_test_case(module_factory=lambda: ElementwiseTernaryStaticShapeModule()) +def ElementwiseTernaryStaticShapeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), tu.rand(4, 3), tu.rand(3)) + + +# ============================================================================== + + class ElementwiseAtenWhereSelfModule(torch.nn.Module): def __init__(self): super().__init__() From 06b592f3d6476716388361be19a89a730c9b02a2 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Mon, 3 Jun 2024 06:50:36 +0000 Subject: [PATCH 3/3] fix result dtype --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 74861e75bc9e..91aae578e0ea 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2585,7 +2585,7 @@ class DecomposeAtenLerpScalarOp : public OpRewritePattern { auto weightedDelta = rewriter.create(loc, inputType, delta, op.getWeight()); - auto lerp = rewriter.create(loc, inputType, start, + auto lerp = rewriter.create(loc, resType, start, weightedDelta, cstOne); rewriter.replaceOp(op, lerp); return success(); @@ -2614,7 +2614,7 @@ class DecomposeAtenLerpTensorOp : public OpRewritePattern { auto weightedDelta = rewriter.create(loc, inputType, delta, op.getWeight()); - auto lerp = rewriter.create(loc, inputType, start, + auto lerp = rewriter.create(loc, resType, start, weightedDelta, cstOne); rewriter.replaceOp(op, lerp); return success();