From 4b15edd9a77fd54fa1a2354359521f51ab4268a3 Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram <96096277+nirvedhmeshram@users.noreply.github.com> Date: Wed, 13 Nov 2024 14:55:05 -0600 Subject: [PATCH] Extend hoist collapse out of scf.forall pattern to use same offsets for all users (#19139) The existing pattern added in https://github.com/iree-org/iree/pull/19044 created different offsets for each user even though we previously checked that the offsets will be exactly the same. This was preventing recursive application of the pattern as the comparison of the offsets for the next application of the pattern would fail. The change in this PR is tested by removing cse in test file which was added by https://github.com/iree-org/iree/pull/19044 to work around this exact issue. Signed-off-by: Nirvedh Meshram --- .../Common/PropagateReshapesByExpansion.cpp | 69 ++++++++----------- .../test/propagate_reshapes_by_expansion.mlir | 2 +- 2 files changed, 30 insertions(+), 41 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp index 0f16f67130be..aae1a7cc5f80 100644 --- a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp @@ -73,14 +73,12 @@ getExpandedShape(SmallVector reIndices, /// account for the expand. If not we bail out. There are two supported users /// which are extract_slice -> expand_shape with the same exact reassociation /// map as the collapse op to be hoisted out or the root parallel_insert_slice. 
-static LogicalResult -verifyAndCollectExpandableUsers(Value insertDest, - SmallVector reIndices, - tensor::ParallelInsertSliceOp parallelInsertOp, - SmallVector &expandableUsers) { +static LogicalResult verifyAndCollectExpandableUsers( + Value insertDest, SmallVector reIndices, + tensor::ParallelInsertSliceOp parallelInsertOp, + SmallVector &expandableUsers) { for (Operation *user : insertDest.getUsers()) { if (user == parallelInsertOp) { - expandableUsers.push_back(user); continue; } auto extractSliceOp = dyn_cast(user); @@ -98,19 +96,20 @@ verifyAndCollectExpandableUsers(Value insertDest, expandShapeOp.getReassociationIndices(); if (reIndices != expandReIndices) return failure(); - expandableUsers.push_back(user); + expandableUsers.push_back(extractSliceOp); } return success(); } /// Utility to expand the pre-verified expandable users of the scf.forall /// output. -static void expandVerifiedUsers(PatternRewriter &rewriter, Location loc, - MLIRContext *ctx, - SmallVector expandableUsers, - SmallVector totalInnerSizes, - SmallVector reIndices, - scf::ForallOp forallOp) { +static void +expandVerifiedUsers(PatternRewriter &rewriter, Location loc, MLIRContext *ctx, + SmallVector expandableUsers, + SmallVector totalInnerSizes, + SmallVector reIndices, + scf::ForallOp forallOp, + tensor::ParallelInsertSliceOp parallelInsertOp) { // compute the offsets,sizes,strides in the expanded dimensions. auto computeExpandedAccess = [&](ArrayRef mixedOffsets, ShapedType resultType) @@ -124,6 +123,7 @@ static void expandVerifiedUsers(PatternRewriter &rewriter, Location loc, for (size_t i = 1, e = reIndices[index].size(); i < e; ++i) { expandedOffsets.push_back(getAsIndexOpFoldResult(ctx, 0)); } + rewriter.setInsertionPointToStart(forallOp.getBody()); // Compute the outer dimension expression. 
AffineExpr s0, s1; bindSymbols(rewriter.getContext(), s0, s1); @@ -144,31 +144,20 @@ static void expandVerifiedUsers(PatternRewriter &rewriter, Location loc, rewriter.getIndexAttr(1)); return {expandedOffsets, expandedSizes, expandedStrides}; }; - for (Operation *user : expandableUsers) { - rewriter.setInsertionPointToStart(forallOp.getBody()); - if (auto extractSliceOp = dyn_cast(user)) { - auto expandShapeOp = - dyn_cast(*extractSliceOp->getUsers().begin()); - RankedTensorType resultType = expandShapeOp.getResultType(); - auto [expandedOffsets, expandedSizes, expandedStrides] = - computeExpandedAccess(extractSliceOp.getMixedOffsets(), resultType); - rewriter.setInsertionPoint(extractSliceOp); - rewriter.replaceOpWithNewOp( - extractSliceOp, resultType, extractSliceOp.getSource(), - expandedOffsets, expandedSizes, expandedStrides); - } else if (auto parallelInsertOp = - dyn_cast(user)) { - auto collapseShapeOp = - parallelInsertOp.getSource().getDefiningOp(); - RankedTensorType resultType = collapseShapeOp.getSrcType(); - auto [expandedOffsets, expandedSizes, expandedStrides] = - computeExpandedAccess(parallelInsertOp.getMixedOffsets(), resultType); - rewriter.setInsertionPoint(parallelInsertOp); - rewriter.replaceOpWithNewOp( - parallelInsertOp, collapseShapeOp.getSrc(), - parallelInsertOp.getDest(), expandedOffsets, expandedSizes, - expandedStrides); - } + auto collapseShapeOp = + parallelInsertOp.getSource().getDefiningOp(); + RankedTensorType resultType = collapseShapeOp.getSrcType(); + auto [expandedOffsets, expandedSizes, expandedStrides] = + computeExpandedAccess(parallelInsertOp.getMixedOffsets(), resultType); + rewriter.setInsertionPoint(parallelInsertOp); + rewriter.replaceOpWithNewOp( + parallelInsertOp, collapseShapeOp.getSrc(), parallelInsertOp.getDest(), + expandedOffsets, expandedSizes, expandedStrides); + for (tensor::ExtractSliceOp extractSliceOp : expandableUsers) { + rewriter.setInsertionPoint(extractSliceOp); + rewriter.replaceOpWithNewOp( + 
extractSliceOp, resultType, extractSliceOp.getSource(), expandedOffsets, + expandedSizes, expandedStrides); } return; } @@ -242,7 +231,7 @@ struct ExpandDestinationForallOp final // Verify that the users of destination are valid to expand and collect all // such users. - SmallVector expandableUsers; + SmallVector expandableUsers; if (failed(verifyAndCollectExpandableUsers( insertDest, collapseOp.getReassociationIndices(), parallelInsertOp, expandableUsers))) { @@ -252,7 +241,7 @@ struct ExpandDestinationForallOp final // Expand the users of the destination. rewriter.setInsertionPointToStart(forallOp.getBody()); expandVerifiedUsers(rewriter, loc, ctx, expandableUsers, totalInnerSizes, - reIndices, forallOp); + reIndices, forallOp, parallelInsertOp); rewriter.setInsertionPoint(forallOp); // Create the expand -> new scf.forall -> collapse chain. diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir index faeb828097b4..88abd0cac880 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion), cse)" \ +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion))" \ // RUN: --split-input-file %s --mlir-print-local-scope | FileCheck %s func.func @reshape_and_lowering_config(%src: tensor<3x4xf16>, %dest: tensor<12xf16>, %dest2: tensor<12xf16>) -> tensor<12xf16> {