diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 91d6b5eb17fc..0d1eea1e3bea 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -8034,105 +8034,17 @@ class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
 } // namespace
 
 namespace {
+// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.
 // Decompose `aten.adaptive_max_pool1d` op into `aten.max_pool1d_with_indices`
-// op.
-class DecomposeAtenAdaptiveMaxPool1dOp
-    : public OpRewritePattern<AtenAdaptiveMaxPool1dOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenAdaptiveMaxPool1dOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    MLIRContext *context = op.getContext();
-
-    Value input = op.getSelf();
-    std::optional<unsigned> maybeRank = getTensorRank(input);
-    if (!maybeRank) {
-      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
-    }
-    unsigned rank = *maybeRank;
-    Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(rank - 1));
-    Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
-
-    Value outputShape = op.getOutputSize();
-    SmallVector<Value> outputShapeSizesTorchInt;
-    getListConstructElements(outputShape, outputShapeSizesTorchInt);
-    Value outputSize = outputShapeSizesTorchInt[0];
-
-    Value constantOne = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(1));
-    Value constantZero = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(0));
-    Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
-
-    int64_t outputSizeInt;
-    if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
-      return rewriter.notifyMatchFailure(
-          op, "the output size of adaptive_max_pool1d must be a constant int");
-    }
-
-    SmallVector<Value, 1> kernelSize;
-    if (outputSizeInt == 1) {
-      BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
-      ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
-      kernelSize.push_back(
-          inputShape[rank - 1] == kUnknownSize
-              ? inputSize
-              : rewriter.create<Torch::ConstantIntOp>(
-                    loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
-    } else {
-      if (!isAssumingStrictSymbolicShapes(rewriter)) {
-        Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
-        rewriter.create<RuntimeAssertOp>(
-            loc, cond,
-            "unimplemented: only support cases where input and output size are "
-            "equal for non-unit output size");
-      }
-      kernelSize.push_back(constantOne);
-    }
-
-    Value kernelSizeList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
-    Value strideList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(Torch::IntType::get(context)),
-        ValueRange{constantOne});
-    Value paddingSizeList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(Torch::IntType::get(context)),
-        ValueRange{constantZero});
-    Value dialationList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(Torch::IntType::get(context)),
-        ValueRange{constantOne});
-
-    if (op.getResult(1).use_empty()) {
-      auto maxPool = rewriter.create<AtenMaxPool1dOp>(
-          loc, op.getType(0), input, kernelSizeList, strideList,
-          paddingSizeList, dialationList,
-          /*ceil_mode=*/constantFalse);
-      rewriter.replaceOp(op, {maxPool.getResult(), Value()});
-    } else {
-      auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
-          loc, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
-          paddingSizeList, dialationList,
-          /*ceil_mode=*/constantFalse);
-      rewriter.replaceOp(op, maxPool.getResults());
-    }
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.
-
-// The logic of this decomposition is totally same with
-// the DecomposeAtenAdaptiveAvgPool2dOp, that means currently only following two
-// cases are supported:
+// or `aten.max_pool1d`.
+//
+// Only the following two cases are supported:
 //  1. inputSize = outputSize
 //  2. outputSize = 1
-class DecomposeAtenAdaptiveAvgPool1dOp
-    : public OpRewritePattern<AtenAdaptiveAvgPool1dOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenAdaptiveAvgPool1dOp op,
+template <typename AtenOpT>
+class DecomposeAtenAdaptivePool1dOp : public OpRewritePattern<AtenOpT> {
+  using OpRewritePattern<AtenOpT>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenOpT op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     MLIRContext *context = op.getContext();
@@ -8145,11 +8057,10 @@ class DecomposeAtenAdaptiveAvgPool1dOp
     unsigned rank = *maybeRank;
     Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(rank - 1));
-    Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
+    Value inputSize = rewriter.createOrFold<AtenSizeIntOp>(loc, input, sizeDim);
 
-    Value outputShape = op.getOutputSize();
     SmallVector<Value> outputShapeSizesTorchInt;
-    getListConstructElements(outputShape, outputShapeSizesTorchInt);
+    getListConstructElements(op.getOutputSize(), outputShapeSizesTorchInt);
     Value outputSize = outputShapeSizesTorchInt[0];
 
     Value constantOne = rewriter.create<Torch::ConstantIntOp>(
@@ -8162,18 +8073,12 @@ class DecomposeAtenAdaptiveAvgPool1dOp
     int64_t outputSizeInt;
     if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
       return rewriter.notifyMatchFailure(
-          op, "the output size of adaptive_pool_1d must be a constant int");
+          op, "the output size of adaptive pool1d must be a constant int");
     }
 
     SmallVector<Value, 1> kernelSize;
     if (outputSizeInt == 1) {
-      BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
-      ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
-      kernelSize.push_back(
-          inputShape[rank - 1] == kUnknownSize
-              ? inputSize
-              : rewriter.create<Torch::ConstantIntOp>(
-                    loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
+      kernelSize.push_back(inputSize);
     } else {
       if (!isAssumingStrictSymbolicShapes(rewriter)) {
         Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
@@ -8194,16 +8099,40 @@ class DecomposeAtenAdaptiveAvgPool1dOp
         loc, Torch::ListType::get(Torch::IntType::get(context)),
         ValueRange{constantZero});
 
-    rewriter.replaceOpWithNewOp<AtenAvgPool1dOp>(
-        op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
-        /*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue);
-    return success();
+    if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveAvgPool1dOp>) {
+      rewriter.replaceOpWithNewOp<AtenAvgPool1dOp>(
+          op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
+          /*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue);
+      return success();
+    } else if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveMaxPool1dOp>) {
+      Value dilationList = rewriter.create<PrimListConstructOp>(
+          loc, Torch::ListType::get(Torch::IntType::get(context)),
+          ValueRange{constantOne});
+      if (op.getResult(1).use_empty()) {
+        auto maxPool = rewriter.create<AtenMaxPool1dOp>(
+            loc, op.getType(0), input, kernelSizeList, strideList,
+            paddingSizeList, dilationList,
+            /*ceil_mode=*/constantFalse);
+        rewriter.replaceOp(op, {maxPool.getResult(), Value()});
+      } else {
+        auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
+            loc, op.getType(0), op.getType(1), input, kernelSizeList,
+            strideList, paddingSizeList, dilationList,
+            /*ceil_mode=*/constantFalse);
+        rewriter.replaceOp(op, maxPool.getResults());
+      }
+      return success();
+    }
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: unsupported template op");
   }
 };
 } // namespace
 
 namespace {
-// Decompose `aten.adaptiveAvgPool2d` op into `aten.avgPool2d` op.
+// Decompose `aten.adaptive_avg_pool2d` op into `aten.avg_pool2d` op.
+// Decompose `aten.adaptive_max_pool2d` op into `aten.max_pool2d` or
+// `aten.max_pool2d_with_indices` op.
 //
 // For AdaptiveAvgPool2d op, when the input size is an integer multiple of
 // output size the kernelSize, stride and padding is calculated as follows:
@@ -8213,10 +8142,10 @@ namespace {
 // kernelW = inW - [(outW - 1) * strideW] = strideW
 // paddingH = 0, paddingW = 0
 //
-class DecomposeAtenAdaptiveAvgPool2dOp
-    : public OpRewritePattern<AtenAdaptiveAvgPool2dOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenAdaptiveAvgPool2dOp op,
+template <typename AtenOpT>
+class DecomposeAtenAdaptivePool2dOp : public OpRewritePattern<AtenOpT> {
+  using OpRewritePattern<AtenOpT>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenOpT op,
                                 PatternRewriter &rewriter) const override {
 
     Location loc = op.getLoc();
@@ -8232,15 +8161,14 @@ class DecomposeAtenAdaptiveAvgPool2dOp
     Value dimH = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(rank - 2));
     inputHW.push_back(
-        /*inH=*/rewriter.create<AtenSizeIntOp>(loc, input, dimH));
+        /*inH=*/rewriter.createOrFold<AtenSizeIntOp>(loc, input, dimH));
     Value dimW = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(rank - 1));
     inputHW.push_back(
-        /*inW=*/rewriter.create<AtenSizeIntOp>(loc, input, dimW));
+        /*inW=*/rewriter.createOrFold<AtenSizeIntOp>(loc, input, dimW));
 
-    Value outputShape = op.getOutputSize();
     SmallVector<Value> outputShapeSizesTorchInt;
-    getListConstructElements(outputShape, outputShapeSizesTorchInt);
+    getListConstructElements(op.getOutputSize(), outputShapeSizesTorchInt);
 
     // TODO: Add support for cases other than:
     // inH % outH != 0 or inW % outW != 0 where
@@ -8321,11 +8249,32 @@ class DecomposeAtenAdaptiveAvgPool2dOp
         loc, Torch::ListType::get(Torch::IntType::get(context)),
         ValueRange{constantZero, constantZero});
 
-    rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
-        op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
-        /*ceilMode=*/constantFalse, /*countIncludePad=*/constantTrue,
-        /*divisorOverride=*/constantNone);
-    return success();
+    if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveAvgPool2dOp>) {
+      rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
+          op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
+          /*ceilMode=*/constantFalse, /*countIncludePad=*/constantTrue,
+          /*divisorOverride=*/constantNone);
+      return success();
+    } else if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveMaxPool2dOp>) {
+      Value dilationList = rewriter.create<PrimListConstructOp>(
+          loc, Torch::ListType::get(Torch::IntType::get(context)),
+          ValueRange{constantOne, constantOne});
+      if (op.getResult(1).use_empty()) {
+        auto maxPool = rewriter.create<AtenMaxPool2dOp>(
+            loc, op.getType(0), input, kernelSizeList, strideList,
+            paddingSizeList, dilationList, /*ceil_mode=*/constantFalse);
+        rewriter.replaceOp(op, {maxPool.getResult(), Value()});
+      } else {
+        auto maxPool = rewriter.create<AtenMaxPool2dWithIndicesOp>(
+            loc, op.getType(0), op.getType(1), input, kernelSizeList,
+            strideList, paddingSizeList, dilationList,
+            /*ceil_mode=*/constantFalse);
+        rewriter.replaceOp(op, maxPool.getResults());
+      }
+      return success();
+    }
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: unsupported template op");
   }
 };
 } // namespace
@@ -11640,9 +11589,14 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveMaxPool1dOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
+    addPatternIfTargetOpIsIllegal<
+        DecomposeAtenAdaptivePool1dOp<AtenAdaptiveAvgPool1dOp>>(patterns);
+    addPatternIfTargetOpIsIllegal<
+        DecomposeAtenAdaptivePool1dOp<AtenAdaptiveMaxPool1dOp>>(patterns);
+    addPatternIfTargetOpIsIllegal<
+        DecomposeAtenAdaptivePool2dOp<AtenAdaptiveAvgPool2dOp>>(patterns);
+    addPatternIfTargetOpIsIllegal<
+        DecomposeAtenAdaptivePool2dOp<AtenAdaptiveMaxPool2dOp>>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index f15911e2b5ba..a28a61c613e5 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -509,6 +509,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
     target.addIllegalOp();
     target.addIllegalOp();
     target.addIllegalOp();
+    target.addIllegalOp<AtenAdaptiveMaxPool1dOp>();
+    target.addIllegalOp<AtenAdaptiveMaxPool2dOp>();
     target.addIllegalOp();
     target.addIllegalOp();
     target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 38eb1f573362..f7978cdf954d 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2802,6 +2802,7 @@
     "AdaptiveMaxPool2dDynamicNoBatch_basic",
     "AdaptiveMaxPool2dDynamicWithIndices_basic",
     "AdaptiveMaxPool2dDynamic_basic",
+    "AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule_basic",
     "AdaptiveMaxPool2dStaticWithIndices_basic",
    "AdaptiveMaxPool2dStatic_basic",
     "AdaptiveMaxPool3dDynamicNoBatch_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index e2eaa4cfd0fe..a42ef43cd448 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1921,6 +1921,52 @@ def AdaptiveMaxPool2dStaticWithIndices_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 512, 10, 16))
 
 
+class AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.amp2d = torch.nn.AdaptiveMaxPool2d((2, 2))
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 3, 7, 7], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.amp2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule()
+)
+def AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 3, 7, 7))
+
+
+class AdaptiveMaxPool2dUnitOutputSizeStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.amp2d = torch.nn.AdaptiveMaxPool2d((1, 1))
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 512, 7, 7], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.amp2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AdaptiveMaxPool2dUnitOutputSizeStaticModule()
+)
+def AdaptiveMaxPool2dUnitOutputSizeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7, 7))
+
+
 # AdaptiveMaxPool3d
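
A note for reviewers on the unit-output-size branch: the shared 1d pattern now passes `inputSize` straight through as the kernel size instead of re-materializing a shape constant. A minimal eager-mode sketch of the identity this relies on, using standard `torch.nn.functional` ops (illustrative only, not part of the patch; tensor sizes are arbitrary):

import torch
import torch.nn.functional as F

x = torch.rand(2, 4, 13)

# Adaptive pooling to output size 1 is ordinary pooling with a kernel that
# spans the entire last dimension (stride 1, padding 0), which is exactly
# the operand list the decomposition builds.
assert torch.allclose(
    F.adaptive_avg_pool1d(x, 1),
    F.avg_pool1d(x, kernel_size=13, stride=1, padding=0, ceil_mode=False,
                 count_include_pad=True),
)

# The max variant decomposes the same way; indices are only materialized
# (via max_pool1d_with_indices) when the second result actually has uses.
assert torch.equal(
    F.adaptive_max_pool1d(x, 1),
    F.max_pool1d(x, kernel_size=13, stride=1, padding=0, dilation=1,
                 ceil_mode=False),
)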
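
Similarly, for the kernel/stride arithmetic in the 2d pattern's comment (strideH = inH // outH, kernelH = inH - (outH - 1) * strideH), the new AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule test exercises a 7x7 input with a (2, 2) output: stride = 7 // 2 = 3 and kernel = 7 - 1 * 3 = 4 reproduce the windows [0, 4) and [3, 7) that the adaptive op computes. A quick eager-mode check under the same assumptions as the sketch above:

import torch
import torch.nn.functional as F

y = torch.rand(1, 3, 7, 7)

# 7 -> 2 yields uniform windows of size 4 at step 3, i.e. a fixed
# kernel/stride pooling, which is what the decomposition emits.
assert torch.equal(
    F.adaptive_max_pool2d(y, (2, 2)),
    F.max_pool2d(y, kernel_size=4, stride=3, padding=0, dilation=1),
)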