[GPU] Hoist collapse shape out of scf.forall when possible
Signed-off-by: Nirvedh <[email protected]>
nirvedhmeshram committed Nov 6, 2024
1 parent 842bcbc commit e44793b
Showing 2 changed files with 370 additions and 1 deletion.
@@ -7,6 +7,8 @@
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

@@ -17,6 +19,271 @@ namespace mlir::iree_compiler {

namespace {

/// Calculate the expanded shape of `dest` if it can be expanded with the inner
/// expanded sizes of `sliceStaticSizes`. Returns failure if such expansion is
/// not possible.
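/// For example, with reassociations [[0, 1], [2, 3, 4]], slice sizes
/// [8, 16, 8, 8, 2], and a destination of type tensor<2048x10240xf32> (the
/// shapes used in the lit test below), `expandedShape` becomes
/// [128, 16, 640, 8, 2] and `totalInnerSizes` becomes [16, 16], since
/// 2048 = 128 * 16 and 10240 = 640 * (8 * 2).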
static LogicalResult
getExpandedShape(SmallVector<ReassociationIndices> reIndices,
ArrayRef<int64_t> sliceStaticSizes, Value dest,
SmallVector<int64_t> &expandedShape,
SmallVector<int64_t> &totalInnerSizes) {
auto destType = dyn_cast<ShapedType>(dest.getType());
if (!destType)
return failure();
// Index at which to insert the next outer size, in front of that group's
// inner sizes.
size_t outerShapeIndex = 0;
for (auto [reassociations, destSize] :
llvm::zip_equal(reIndices, destType.getShape())) {
int64_t totalInnerSize = 1;
for (int i = 1; i < reassociations.size(); i++) {
int64_t expandedInnerSize = sliceStaticSizes[reassociations[i]];
if (ShapedType::isDynamic(expandedInnerSize)) {
return failure();
}
expandedShape.push_back(expandedInnerSize);
totalInnerSize *= expandedInnerSize;
}
if (destSize % totalInnerSize != 0) {
return failure();
}
totalInnerSizes.push_back(totalInnerSize);
// Insert the outer size in front of any inner sizes; use an index rather
// than an iterator since the push_backs above may reallocate the vector.
expandedShape.insert(expandedShape.begin() + outerShapeIndex,
destSize / totalInnerSize);
// Set up the insertion point for the next uncollapsed dimension.
outerShapeIndex = expandedShape.size();
}
return success();
}

/// Check if the users of the expanded scf.forall destination can be updated
/// to account for the expansion; if not, bail out. The two supported kinds of
/// users are an extract_slice -> expand_shape pair with exactly the same
/// reassociation map as the collapse op to be hoisted out, and a
/// parallel_insert_slice.
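/// For example, in the lit test below both kinds of users appear:
///   %extracted_slice = tensor.extract_slice %arg2[...] [128, 128] [1, 1]
///       : tensor<2048x10240xf32> to tensor<128x128xf32>
///   %expanded = tensor.expand_shape %extracted_slice [[0, 1], [2, 3, 4]] ...
/// and
///   tensor.parallel_insert_slice %collapsed into %arg2[...] [128, 128] [1, 1]
///       : tensor<128x128xf32> into tensor<2048x10240xf32>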
static LogicalResult
verifyAndCollectExpandableUsers(Value insertDest,
SmallVector<ReassociationIndices> reIndices,
SmallVector<Operation *> &expandableUsers) {
for (auto user : insertDest.getUsers()) {
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) {
auto expandShapeOp =
dyn_cast<tensor::ExpandShapeOp>(*extractSliceOp->getUsers().begin());
if (!expandShapeOp)
return failure();
auto expandReIndices = expandShapeOp.getReassociationIndices();
if (reIndices != expandReIndices) {
return failure();
}
expandableUsers.push_back(user);
} else if (auto parallelInsertOp =
dyn_cast<tensor::ParallelInsertSliceOp>(user)) {
expandableUsers.push_back(user);
} else
return failure();
}
return success();
}

/// Utility to expand the pre-verified expandable users of the scf.forall
/// output.
static void expandVerifiedUsers(PatternRewriter &rewriter, Location loc,
MLIRContext *ctx,
SmallVector<Operation *> expandableUsers,
SmallVector<int64_t> totalInnerSizes,
SmallVector<ReassociationIndices> reIndices,
scf::ForallOp forallOp) {
// The user expand_shape ops and the producer collapse_shape ops need to be
// unflattened, e.g.
//   %collapsed = tensor.collapse_shape %transposed [[0, 1], [2, 3]]
//       : tensor<8x16x8x16xf32> into tensor<128x128xf32>
// can be unflattened to
//   %collapsed = tensor.collapse_shape %transposed [[0], [1], [2], [3]]
//       : tensor<8x16x8x16xf32> into tensor<8x16x8x16xf32>
// which is then consumed by the expanded parallel_insert_slice op.
SmallVector<ReassociationIndices> unFlattenReassociations;
for (auto inds : reIndices) {
for (auto i : inds) {
unFlattenReassociations.push_back({i});
}
}
// Compute the offsets, sizes, and strides in the expanded dimensions.
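// For example, with reIndices [[0, 1], [2, 3, 4]] and totalInnerSizes
// [16, 16], the original offsets (%arg0, %arg1) expand to
// (%arg0 floordiv 16, 0, %arg1 floordiv 16, 0, 0), matching the lit test
// below.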
auto computeExpandedAccess = [&](ArrayRef<OpFoldResult> mixedOffsets,
ShapedType resultType) {
SmallVector<OpFoldResult> expandedOffsets;
// Index at which to insert the next outer offset, in front of that group's
// zero inner offsets.
size_t expandedOffsetsIndex = 0;

for (auto [index, offset] : llvm::enumerate(mixedOffsets)) {
// Add zero offsets for the extra dimensions from reIndices.
for (int i = 1; i < reIndices[index].size(); i++) {
expandedOffsets.push_back(getAsIndexOpFoldResult(ctx, 0));
}
// Compute the outer dimension expression.
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
AffineExpr outerDimExpr = (s0).floorDiv(s1);
// Insert the computed outer offset using the affine expression; use an
// index rather than an iterator since the push_backs above may reallocate
// the vector.
expandedOffsets.insert(
expandedOffsets.begin() + expandedOffsetsIndex,
affine::makeComposedFoldedAffineApply(
rewriter, loc, outerDimExpr,
{getValueOrCreateConstantIndexOp(rewriter, loc, offset),
rewriter.getIndexAttr(totalInnerSizes[index])}));

expandedOffsetsIndex = expandedOffsets.size();
}
ArrayRef<int64_t> expandedShape = resultType.getShape();
SmallVector<OpFoldResult> expandedSizes;
for (auto size : expandedShape) {
expandedSizes.push_back(getAsIndexOpFoldResult(ctx, size));
}
SmallVector<OpFoldResult> expandedStrides(resultType.getRank(),
rewriter.getIndexAttr(1));
return std::make_tuple(expandedOffsets, expandedSizes, expandedStrides);
};
for (auto user : expandableUsers) {
rewriter.setInsertionPointToStart(forallOp.getBody());
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) {
auto expandShapeOp =
dyn_cast<tensor::ExpandShapeOp>(*extractSliceOp->getUsers().begin());
auto resultType = expandShapeOp.getResultType();
auto [expandedOffsets, expandedSizes, expandedStrides] =
computeExpandedAccess(extractSliceOp.getMixedOffsets(), resultType);
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
extractSliceOp, resultType, extractSliceOp.getSource(),
expandedOffsets, expandedSizes, expandedStrides);
rewriter.setInsertionPoint(expandShapeOp);
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
expandShapeOp, resultType, expandShapeOp.getSrc(),
unFlattenReassociations);
} else if (auto parallelInsertOp =
dyn_cast<tensor::ParallelInsertSliceOp>(user)) {
auto collapseShapeOp =
parallelInsertOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
auto resultType = collapseShapeOp.getSrcType();
auto [expandedOffsets, expandedSizes, expandedStrides] =
computeExpandedAccess(parallelInsertOp.getMixedOffsets(), resultType);

rewriter.setInsertionPoint(collapseShapeOp);
auto newCollapseOp = rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
collapseShapeOp, collapseShapeOp.getSrcType(),
collapseShapeOp.getSrc(), unFlattenReassociations);
rewriter.setInsertionPoint(parallelInsertOp);
rewriter.replaceOpWithNewOp<tensor::ParallelInsertSliceOp>(
parallelInsertOp, newCollapseOp.getResult(),
parallelInsertOp.getDest(), expandedOffsets, expandedSizes,
expandedStrides);
}
}
return;
}

/// This pattern expands the destination of a workgroup-mapped scf.forall op
/// by hoisting the collapse_shape op consumed by its parallel_insert_slice op
/// out of the loop.
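/// A rough sketch of the rewrite, using the shapes from the lit test below
/// (loop bounds and slice offsets elided, value names illustrative; later
/// patterns in this pass can fold the leftover reshapes into the dispatch
/// tensor loads/stores):
///
///   %2 = scf.forall ... shared_outs(%arg2 = %1) -> (tensor<2048x10240xf32>) {
///     ...
///     %collapsed = tensor.collapse_shape %6 [[0, 1], [2, 3, 4]]
///         : tensor<8x16x8x8x2xf32> into tensor<128x128xf32>
///     scf.forall.in_parallel {
///       tensor.parallel_insert_slice %collapsed into %arg2 ...
///     }
///   } {mapping = [#iree_codegen.workgroup_mapping<y>, ...]}
///
/// becomes, roughly,
///
///   %dest = tensor.expand_shape %1 [[0, 1], [2, 3, 4]]
///       : tensor<2048x10240xf32> into tensor<128x16x640x8x2xf32>
///   %2 = scf.forall ... shared_outs(%arg2 = %dest)
///       -> (tensor<128x16x640x8x2xf32>) {
///     ...
///     scf.forall.in_parallel {
///       tensor.parallel_insert_slice %6 into %arg2 ...
///     }
///   } {mapping = [#iree_codegen.workgroup_mapping<y>, ...]}
///   %3 = tensor.collapse_shape %2 [[0, 1], [2, 3, 4]]
///       : tensor<128x16x640x8x2xf32> into tensor<2048x10240xf32>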
struct ExpandDestinationForallOp final
: OpRewritePattern<tensor::ParallelInsertSliceOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ParallelInsertSliceOp parallelInsertOp,
PatternRewriter &rewriter) const override {
Location loc = parallelInsertOp.getLoc();
MLIRContext *ctx = getContext();
auto collapseOp =
parallelInsertOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
// No collapse op to hoist out.
if (!collapseOp) {
return failure();
}

// Ignore trivially foldable collapse ops.
if (collapseOp.getSrcType().getRank() ==
collapseOp.getResultType().getRank()) {
return failure();
}

// Get the destination to expand.
Value insertDest = parallelInsertOp.getDest();

// Get the enclosing scf.forall op.
OpResult tiedResult = parallelInsertOp.getTiedOpResult();
auto forallOp = dyn_cast<scf::ForallOp>(tiedResult.getOwner());
if (!forallOp) {
return failure();
}
// This allows us to assume that the extract/inserts in the loop are
// disjoint and makes the application of this pattern safe.
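// (The @no_expand_dest_forall_not_workgroup_mapped lit test covers this
// bail-out.)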
if (!forallOpHasMappingType<IREE::Codegen::WorkgroupMappingAttr>(
forallOp)) {
return failure();
}
// This pattern only supports forall ops with a single output.
SmallVector<Value> forallOutputs(forallOp.getOutputs());
if (forallOutputs.size() != 1) {
return failure();
}

SmallVector<ReassociationIndices> reIndices =
collapseOp.getReassociationIndices();
SmallVector<int64_t> expandedDestShape;
SmallVector<int64_t> totalInnerSizes;
// Get the expanded shape, which will become the new destination of the
// scf.forall, and the total size of the inner dimensions per uncollapsed
// dimension.
if (failed(getExpandedShape(reIndices, collapseOp.getSrcType().getShape(),
insertDest, expandedDestShape,
totalInnerSizes))) {
return failure();
}

// Verify that the users of the destination can be expanded and collect all
// such users.
SmallVector<Operation *> expandableUsers;
if (failed(verifyAndCollectExpandableUsers(
insertDest, collapseOp.getReassociationIndices(),
expandableUsers))) {
return failure();
}

// Expand the users of the destination.
rewriter.setInsertionPointToStart(forallOp.getBody());
expandVerifiedUsers(rewriter, loc, ctx, expandableUsers, totalInnerSizes,
reIndices, forallOp);
rewriter.setInsertionPoint(forallOp);

auto outOp = forallOutputs[0].getDefiningOp();
if (!outOp) {
return failure();
}

// Create the expand -> new scf.forall -> collapse chain.
Type expandedDestType = RankedTensorType::get(
expandedDestShape,
cast<ShapedType>(outOp->getResult(0).getType()).getElementType());
auto expandedDest = rewriter.create<tensor::ExpandShapeOp>(
loc, expandedDestType, outOp->getResult(0), reIndices);

scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
forallOp.getMixedStep(), ValueRange{expandedDest},
forallOp.getMappingAttr());

auto collapsedResultOp = rewriter.create<tensor::CollapseShapeOp>(
loc, cast<ShapedType>(forallOp->getResult(0).getType()),
newForallOp->getResult(0), reIndices);

// Merge the old scf.forall block which has the expanded users into the new
// scf.forall which has the expanded destination.
SmallVector<Value> argReplacements(newForallOp.getInductionVars());
for (auto forallIterArg : newForallOp.getRegionIterArgs()) {
argReplacements.push_back(forallIterArg);
}
scf::InParallelOp parallelTerminator = newForallOp.getTerminator();
parallelTerminator->erase();
rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
argReplacements);

// Replace all uses of the old scf.forall with the collapsed result of the
// new scf.forall.
forallOp->getResult(0).replaceAllUsesWith(collapsedResultOp->getResult(0));
return success();
}
};

struct PropagateReshapesByExpansionPass final
: impl::PropagateReshapesByExpansionPassBase<
PropagateReshapesByExpansionPass> {
@@ -65,6 +332,7 @@ void PropagateReshapesByExpansionPass::runOnOperation() {
tensor::ExpandShapeOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
context);
populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns);
bubbleExpandShapePatterns.add<ExpandDestinationForallOp>(context);

if (failed(applyPatternsAndFoldGreedily(
getOperation(), std::move(bubbleExpandShapePatterns)))) {
@@ -1,4 +1,5 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion))" --split-input-file %s --mlir-print-local-scope | FileCheck %s
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion), cse)" \
// RUN: --split-input-file %s --mlir-print-local-scope | FileCheck %s

func.func @reshape_and_lowering_config(%src: tensor<3x4xf16>, %dest: tensor<12xf16>, %dest2: tensor<12xf16>) -> tensor<12xf16> {
%collapse = tensor.collapse_shape %src [[0, 1]] : tensor<3x4xf16> into tensor<12xf16>
@@ -86,3 +87,103 @@ func.func @fold_collapse_into_stores_dynamic(%arg0 : tensor<2x?x32xf32>) {
// CHECK: flow.dispatch.tensor.store %{{.+}}, %[[SUBSPAN]]
// CHECK-SAME: offsets = [0, 0, 0], sizes = [2, %[[SHAPE]], 32], strides = [1, 1, 1]
// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<2x?x32xf32>>{%[[SHAPE]]}

// -----

#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
#hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
func.func @expand_dest_forall_workgroup_mapped() {
%cst = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
%1 = tensor.empty() : tensor<2048x10240xf32>
%2 = scf.forall (%arg0, %arg1) = (0, 0) to (2048, 10240) step (128, 128)
shared_outs(%arg2 = %1) -> (tensor<2048x10240xf32>) {
%extracted_slice = tensor.extract_slice %arg2[%arg0, %arg1] [128, 128] [1, 1]
: tensor<2048x10240xf32> to tensor<128x128xf32>
%3 = tensor.empty() : tensor<8x8x16x8x2xf32>
%4 = linalg.fill ins(%cst : f16) outs(%3 : tensor<8x8x16x8x2xf32>) -> tensor<8x8x16x8x2xf32>
%5 = tensor.empty() : tensor<8x16x8x8x2xf32>
%transposed = linalg.transpose ins(%4 : tensor<8x8x16x8x2xf32>)
outs(%5 : tensor<8x16x8x8x2xf32>) permutation = [0, 2, 1, 3, 4]
%expanded = tensor.expand_shape %extracted_slice [[0, 1], [2, 3, 4]]
output_shape [8, 16, 8, 8, 2] : tensor<128x128xf32> into tensor<8x16x8x8x2xf32>
%6 = linalg.copy ins(%transposed : tensor<8x16x8x8x2xf32>)
outs(%expanded : tensor<8x16x8x8x2xf32>) -> tensor<8x16x8x8x2xf32>
%collapsed = tensor.collapse_shape %6 [[0, 1], [2, 3, 4]] : tensor<8x16x8x8x2xf32> into tensor<128x128xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %collapsed into %arg2[%arg0, %arg1] [128, 128] [1, 1]
: tensor<128x128xf32> into tensor<2048x10240xf32>
}
} {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
flow.dispatch.tensor.store %2, %0, offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1]
: tensor<2048x10240xf32> -> !flow.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
return
}

// CHECK-LABEL: func @expand_dest_forall_workgroup_mapped(
// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x16x640x8x2xf32>
// CHECK: %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) = (0, 0)
// CHECK-SAME: shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor<128x16x640x8x2xf32>) {
// CHECK-DAG: %[[OFFSET1:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 16)>()[%[[ARG0]]]
// CHECK-DAG: %[[OFFSET2:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 16)>()[%[[ARG1]]]
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG2]]
// CHECK-SAME: [%[[OFFSET1]], 0, %[[OFFSET2]], 0, 0] [8, 16, 8, 8, 2] [1, 1, 1, 1, 1]
// CHECK-SAME: tensor<128x16x640x8x2xf32> to tensor<8x16x8x8x2xf32>
// CHECK: tensor.parallel_insert_slice %{{.+}} into %[[ARG2]]
// CHECK-SAME: [%[[OFFSET1]], 0, %[[OFFSET2]], 0, 0] [8, 16, 8, 8, 2] [1, 1, 1, 1, 1]
// CHECK-SAME: tensor<8x16x8x8x2xf32> into tensor<128x16x640x8x2xf32>
// CHECK: flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]]
// CHECK-SAME: offsets = [0, 0, 0, 0, 0], sizes = [128, 16, 640, 8, 2], strides = [1, 1, 1, 1, 1]
// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<128x16x640x8x2xf32>>

// -----

#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
#hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
func.func @no_expand_dest_forall_not_workgroup_mapped() {
%cst = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
%1 = tensor.empty() : tensor<2048x10240xf32>
%2 = scf.forall (%arg0, %arg1) = (0, 0) to (2048, 10240) step (128, 128)
shared_outs(%arg2 = %1) -> (tensor<2048x10240xf32>) {
%extracted_slice = tensor.extract_slice %arg2[%arg0, %arg1] [128, 128] [1, 1]
: tensor<2048x10240xf32> to tensor<128x128xf32>
%3 = tensor.empty() : tensor<8x8x16x8x2xf32>
%4 = linalg.fill ins(%cst : f16) outs(%3 : tensor<8x8x16x8x2xf32>) -> tensor<8x8x16x8x2xf32>
%5 = tensor.empty() : tensor<8x16x8x8x2xf32>
%transposed = linalg.transpose ins(%4 : tensor<8x8x16x8x2xf32>)
outs(%5 : tensor<8x16x8x8x2xf32>) permutation = [0, 2, 1, 3, 4]
%expanded = tensor.expand_shape %extracted_slice [[0, 1], [2, 3, 4]]
output_shape [8, 16, 8, 8, 2] : tensor<128x128xf32> into tensor<8x16x8x8x2xf32>
%6 = linalg.copy ins(%transposed : tensor<8x16x8x8x2xf32>)
outs(%expanded : tensor<8x16x8x8x2xf32>) -> tensor<8x16x8x8x2xf32>
%collapsed = tensor.collapse_shape %6 [[0, 1], [2, 3, 4]] : tensor<8x16x8x8x2xf32> into tensor<128x128xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %collapsed into %arg2[%arg0, %arg1] [128, 128] [1, 1]
: tensor<128x128xf32> into tensor<2048x10240xf32>
}
}
flow.dispatch.tensor.store %2, %0, offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1]
: tensor<2048x10240xf32> -> !flow.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
return
}

// CHECK-LABEL: func @no_expand_dest_forall_not_workgroup_mapped(
// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2048x10240xf32>
// CHECK: %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) = (0, 0)
// CHECK-SAME: shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor<2048x10240xf32>) {
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG2]]
// CHECK-SAME: [%[[ARG0]], %[[ARG1]]] [128, 128] [1, 1]
// CHECK-SAME: tensor<2048x10240xf32> to tensor<128x128xf32>
// CHECK: tensor.parallel_insert_slice %{{.+}} into %[[ARG2]]
// CHECK-SAME: [%[[ARG0]], %[[ARG1]]] [128, 128] [1, 1]
// CHECK-SAME: tensor<128x128xf32> into tensor<2048x10240xf32>
// CHECK: flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]]
// CHECK-SAME: offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1]
// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
