diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index da49e2d77049..62a497f51074 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2585,7 +2585,36 @@ class DecomposeAtenLerpScalarOp : public OpRewritePattern { auto weightedDelta = rewriter.create(loc, inputType, delta, op.getWeight()); - auto lerp = rewriter.create(loc, inputType, start, + auto lerp = rewriter.create(loc, resType, start, + weightedDelta, cstOne); + rewriter.replaceOp(op, lerp); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenLerpTensorOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLerpTensorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto start = op.getSelf(); + auto inputType = cast(start.getType()); + + auto delta = rewriter.create(loc, inputType, op.getEnd(), + start, cstOne); + + auto weightedDelta = + rewriter.create(loc, inputType, delta, op.getWeight()); + auto lerp = rewriter.create(loc, resType, start, weightedDelta, cstOne); rewriter.replaceOp(op, lerp); return success(); @@ -8114,6 +8143,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index ffc45a1be859..dc18761e3127 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -507,6 +507,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 9bbeb9befef9..7ac40a365659 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1020,6 +1020,7 @@ "ElementwiseSqrtModule_basic", "ElementwiseTanIntModule_basic", "ElementwiseTanModule_basic", + "ElementwiseTernaryStaticShapeModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeI64ToI8Module_basic", "ElementwiseToDtypeIdentityModule_basic", @@ -1475,6 +1476,7 @@ "AtenDotModule_basic", "ElementwiseFloatTensorGtIntScalarModule_basic", "ElementwiseLogSigmoidModule_basic", + "ElementwiseTernaryStaticShapeModule_basic", "ElementwiseTruncModule_basic", "ElementwiseTruncIntModule_basic", "ElementwiseSgnModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index a7f27df555ba..67c2c1b6f3e8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -414,6 +414,31 @@ def ElementwiseTernaryModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseTernaryStaticShapeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 4, 3], torch.float32, True), + ([4, 3], torch.float32, True), + ([3], torch.float32, True), + ] + ) + def forward(self, a, b, c): + return torch.lerp(a, b, c) + + +@register_test_case(module_factory=lambda: ElementwiseTernaryStaticShapeModule()) +def ElementwiseTernaryStaticShapeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), tu.rand(4, 3), tu.rand(3)) + + +# ============================================================================== + + class ElementwiseAtenWhereSelfModule(torch.nn.Module): def __init__(self): super().__init__()