Extend hoist collapse out of scf.forall pattern to use same offsets for all users (#19139)

The existing pattern, added in #19044, created different offsets for
each user even though we had already verified that the offsets would be
exactly the same. This prevented recursive application of the pattern,
since the offset comparison in the next application of the pattern
would fail. The change in this PR is tested by removing cse from the
test file; it was added in #19044 to work around this exact issue.

Signed-off-by: Nirvedh Meshram <[email protected]>
nirvedhmeshram authored Nov 13, 2024
1 parent ab35e1b commit 4b15edd
Showing 2 changed files with 30 additions and 41 deletions.
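
Background on the failing comparison: slice offsets in MLIR are OpFoldResults, which hold either a static attribute or a dynamic SSA Value, and dynamic offsets only compare equal by value identity, never by structure. The toy C++ sketch below (simplified stand-ins for illustration; this is not IREE or MLIR code) shows why materializing a fresh copy of the same offset expression for every user makes the equality check fail, while reusing one set of values makes it pass:

    #include <cassert>
    #include <variant>

    // Toy model (assumption: simplified stand-in for mlir::OpFoldResult):
    // a static offset is an attribute compared by value; a dynamic offset
    // is an SSA value compared by identity (pointer equality).
    struct Attr {
      int value;
      bool operator==(const Attr &) const = default;
    };
    struct Val {
      int id; // two definitions of the same expression stay distinct
    };
    using OffsetLike = std::variant<Attr, const Val *>;

    bool sameOffset(const OffsetLike &a, const OffsetLike &b) {
      return a == b;
    }

    int main() {
      assert(sameOffset(Attr{0}, Attr{0})); // static offsets: value equality
      Val x{1}, y{1};                       // same expression, two defs
      assert(!sameOffset(OffsetLike{&x}, OffsetLike{&y})); // per-user copies mismatch
      assert(sameOffset(OffsetLike{&x}, OffsetLike{&x}));  // one shared value matches
      return 0;
    }

This is the property the diff below establishes: the expanded offsets are computed once, from the parallel_insert_slice, and the identical values are reused for every extract_slice user.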
@@ -73,14 +73,12 @@ getExpandedShape(SmallVector<ReassociationIndices> reIndices,
 /// account for the expand. If not we bail out. There are two supported users
 /// which are extract_slice -> expand_shape with the same exact reassociation
 /// map as the collapse op to be hoisted out or the root parallel_insert_slice.
-static LogicalResult
-verifyAndCollectExpandableUsers(Value insertDest,
-                                SmallVector<ReassociationIndices> reIndices,
-                                tensor::ParallelInsertSliceOp parallelInsertOp,
-                                SmallVector<Operation *> &expandableUsers) {
+static LogicalResult verifyAndCollectExpandableUsers(
+    Value insertDest, SmallVector<ReassociationIndices> reIndices,
+    tensor::ParallelInsertSliceOp parallelInsertOp,
+    SmallVector<tensor::ExtractSliceOp> &expandableUsers) {
   for (Operation *user : insertDest.getUsers()) {
     if (user == parallelInsertOp) {
-      expandableUsers.push_back(user);
       continue;
     }
     auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
@@ -98,19 +96,20 @@ verifyAndCollectExpandableUsers(Value insertDest,
         expandShapeOp.getReassociationIndices();
     if (reIndices != expandReIndices)
       return failure();
-    expandableUsers.push_back(user);
+    expandableUsers.push_back(extractSliceOp);
   }
   return success();
 }
 
 /// Utility to expand the pre-verified expandable users of the scf.forall
 /// output.
-static void expandVerifiedUsers(PatternRewriter &rewriter, Location loc,
-                                MLIRContext *ctx,
-                                SmallVector<Operation *> expandableUsers,
-                                SmallVector<int64_t> totalInnerSizes,
-                                SmallVector<ReassociationIndices> reIndices,
-                                scf::ForallOp forallOp) {
+static void
+expandVerifiedUsers(PatternRewriter &rewriter, Location loc, MLIRContext *ctx,
+                    SmallVector<tensor::ExtractSliceOp> expandableUsers,
+                    SmallVector<int64_t> totalInnerSizes,
+                    SmallVector<ReassociationIndices> reIndices,
+                    scf::ForallOp forallOp,
+                    tensor::ParallelInsertSliceOp parallelInsertOp) {
   // compute the offsets,sizes,strides in the expanded dimensions.
   auto computeExpandedAccess = [&](ArrayRef<OpFoldResult> mixedOffsets,
                                    ShapedType resultType)
@@ -124,6 +123,7 @@ static void expandVerifiedUsers(PatternRewriter &rewriter, Location loc,
       for (size_t i = 1, e = reIndices[index].size(); i < e; ++i) {
         expandedOffsets.push_back(getAsIndexOpFoldResult(ctx, 0));
       }
+      rewriter.setInsertionPointToStart(forallOp.getBody());
       // Compute the outer dimension expression.
       AffineExpr s0, s1;
       bindSymbols(rewriter.getContext(), s0, s1);
@@ -144,31 +144,20 @@ static void expandVerifiedUsers(PatternRewriter &rewriter, Location loc,
                                               rewriter.getIndexAttr(1));
     return {expandedOffsets, expandedSizes, expandedStrides};
   };
-  for (Operation *user : expandableUsers) {
-    rewriter.setInsertionPointToStart(forallOp.getBody());
-    if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) {
-      auto expandShapeOp =
-          dyn_cast<tensor::ExpandShapeOp>(*extractSliceOp->getUsers().begin());
-      RankedTensorType resultType = expandShapeOp.getResultType();
-      auto [expandedOffsets, expandedSizes, expandedStrides] =
-          computeExpandedAccess(extractSliceOp.getMixedOffsets(), resultType);
-      rewriter.setInsertionPoint(extractSliceOp);
-      rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
-          extractSliceOp, resultType, extractSliceOp.getSource(),
-          expandedOffsets, expandedSizes, expandedStrides);
-    } else if (auto parallelInsertOp =
-                   dyn_cast<tensor::ParallelInsertSliceOp>(user)) {
-      auto collapseShapeOp =
-          parallelInsertOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
-      RankedTensorType resultType = collapseShapeOp.getSrcType();
-      auto [expandedOffsets, expandedSizes, expandedStrides] =
-          computeExpandedAccess(parallelInsertOp.getMixedOffsets(), resultType);
-      rewriter.setInsertionPoint(parallelInsertOp);
-      rewriter.replaceOpWithNewOp<tensor::ParallelInsertSliceOp>(
-          parallelInsertOp, collapseShapeOp.getSrc(),
-          parallelInsertOp.getDest(), expandedOffsets, expandedSizes,
-          expandedStrides);
-    }
+  auto collapseShapeOp =
+      parallelInsertOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+  RankedTensorType resultType = collapseShapeOp.getSrcType();
+  auto [expandedOffsets, expandedSizes, expandedStrides] =
+      computeExpandedAccess(parallelInsertOp.getMixedOffsets(), resultType);
+  rewriter.setInsertionPoint(parallelInsertOp);
+  rewriter.replaceOpWithNewOp<tensor::ParallelInsertSliceOp>(
+      parallelInsertOp, collapseShapeOp.getSrc(), parallelInsertOp.getDest(),
+      expandedOffsets, expandedSizes, expandedStrides);
+  for (tensor::ExtractSliceOp extractSliceOp : expandableUsers) {
+    rewriter.setInsertionPoint(extractSliceOp);
+    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+        extractSliceOp, resultType, extractSliceOp.getSource(), expandedOffsets,
+        expandedSizes, expandedStrides);
   }
   return;
 }
@@ -242,7 +231,7 @@ struct ExpandDestinationForallOp final
 
     // Verify that the users of destination are valid to expand and collect all
     // such users.
-    SmallVector<Operation *> expandableUsers;
+    SmallVector<tensor::ExtractSliceOp> expandableUsers;
     if (failed(verifyAndCollectExpandableUsers(
             insertDest, collapseOp.getReassociationIndices(), parallelInsertOp,
             expandableUsers))) {
@@ -252,7 +241,7 @@ struct ExpandDestinationForallOp final
     // Expand the users of the destination.
     rewriter.setInsertionPointToStart(forallOp.getBody());
     expandVerifiedUsers(rewriter, loc, ctx, expandableUsers, totalInnerSizes,
-                        reIndices, forallOp);
+                        reIndices, forallOp, parallelInsertOp);
     rewriter.setInsertionPoint(forallOp);
 
     // Create the expand -> new scf.forall -> collapse chain.
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion), cse)" \
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion))" \
 // RUN:   --split-input-file %s --mlir-print-local-scope | FileCheck %s
 
 func.func @reshape_and_lowering_config(%src: tensor<3x4xf16>, %dest: tensor<12xf16>, %dest2: tensor<12xf16>) -> tensor<12xf16> {
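With cse removed from the RUN line, the test now passes only if the pattern itself emits one shared set of offsets for every user. For intuition, here is a toy sketch of the kind of cleanup the removed cse pass performed (illustrative only; this is not MLIR's actual cse implementation): structurally identical computations are folded into a single representative, which is what previously made the per-user offset copies compare equal again before the next pattern application.

    #include <cassert>
    #include <map>
    #include <utility>

    // Toy common-subexpression elimination (assumption: illustrative model,
    // not MLIR's cse pass): structurally equal expressions map to one
    // representative, so later identity comparisons of "the same" offset
    // succeed even when it was materialized twice.
    struct Expr {
      int opcode;
      int operand;
    };

    const Expr *canonicalize(std::map<std::pair<int, int>, const Expr *> &seen,
                             const Expr *e) {
      // Keep the first expression with this structure; reuse it for duplicates.
      return seen.try_emplace(std::pair{e->opcode, e->operand}, e).first->second;
    }

    int main() {
      std::map<std::pair<int, int>, const Expr *> seen;
      Expr a{/*opcode=*/1, /*operand=*/42}; // offset computed for user #1
      Expr b{/*opcode=*/1, /*operand=*/42}; // same offset, recomputed for user #2
      const Expr *ca = canonicalize(seen, &a);
      const Expr *cb = canonicalize(seen, &b);
      assert(&a != &b);  // distinct values before CSE
      assert(ca == cb);  // one shared representative after CSE
      return 0;
    }

Because the pattern now reuses one set of offset values directly, the test no longer depends on this cleanup running between applications.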
