diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
index 67d2358d7f895..0f16f67130be1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
@@ -7,6 +7,9 @@
 #include "iree/compiler/Codegen/Common/Passes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

@@ -17,6 +20,282 @@ namespace mlir::iree_compiler {

 namespace {

+/// Calculate the expanded shape of `dest` if it can be expanded with the inner
+/// expanded sizes of `sliceStaticSizes`. Returns failure if such expansion is
+/// not possible.
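+///
+/// For example (shapes taken from the @expand_dest_forall lit test added
+/// below), with reassociations [[0], [1], [2, 3, 4]], `sliceStaticSizes` of
+/// [1, 16, 2, 4, 2], and a `dest` of type tensor<?x64x32xf32>, this computes
+/// `expandedShape` = [?, 64, 4, 4, 2] and `totalInnerSizes` = [1, 1, 8].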
+static LogicalResult
+getExpandedShape(SmallVector<ReassociationIndices> reIndices,
+                 ArrayRef<int64_t> sliceStaticSizes, Value dest,
+                 SmallVectorImpl<int64_t> &expandedShape,
+                 SmallVectorImpl<int64_t> &totalInnerSizes) {
+  auto destType = dyn_cast<RankedTensorType>(dest.getType());
+  if (!destType)
+    return failure();
+  // TODO (nirvedhmeshram): Support rank reducing parallel_insert_slice.
+  if (reIndices.size() != destType.getShape().size())
+    return failure();
+  // Iterator to insert outer sizes.
+  auto outerShapeIter = expandedShape.begin();
+  for (auto [reassociations, destSize] :
+       llvm::zip_equal(reIndices, destType.getShape())) {
+    // Dynamic destination dims that are not getting expanded are allowed.
+    if (ShapedType::isDynamic(destSize) && reassociations.size() == 1) {
+      expandedShape.insert(outerShapeIter++, destSize);
+      totalInnerSizes.push_back(1);
+      continue;
+    }
+    // Dynamic destination dims that are expanded are currently unsupported,
+    // but this support can be added if needed.
+    if (ShapedType::isDynamic(destSize)) {
+      return failure();
+    }
+    int64_t totalInnerSize = 1;
+    for (int64_t reassociation : llvm::drop_begin(reassociations)) {
+      int64_t expandedInnerSize = sliceStaticSizes[reassociation];
+      // It is not safe to apply this pattern if inner dimensions are dynamic.
+      if (ShapedType::isDynamic(expandedInnerSize))
+        return failure();
+      expandedShape.push_back(expandedInnerSize);
+      totalInnerSize *= expandedInnerSize;
+    }
+    if (destSize % totalInnerSize != 0)
+      return failure();
+    totalInnerSizes.push_back(totalInnerSize);
+    // Insert the outer size in front of any inner sizes.
+    expandedShape.insert(outerShapeIter, destSize / totalInnerSize);
+    // Set up the iterator for the next uncollapsed dimension.
+    outerShapeIter = expandedShape.end();
+  }
+  return success();
+}
+
+/// Check if the users of the expanded scf.forall destination can be updated to
+/// account for the expand. If not, we bail out. There are two supported kinds
+/// of users: an extract_slice -> expand_shape pair with the exact same
+/// reassociation map as the collapse op to be hoisted out, and the root
+/// parallel_insert_slice.
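+///
+/// For instance, in the IR of the @expand_dest_forall lit test both of these
+/// users of the destination %arg2 are expandable (offsets elided):
+///
+///   %slice = tensor.extract_slice %arg2[...] [1, 16, 16] [1, 1, 1]
+///   %expanded = tensor.expand_shape %slice [[0], [1], [2, 3, 4]] ...
+///   ...
+///   tensor.parallel_insert_slice %collapsed into %arg2[...] [1, 16, 16] [1, 1, 1]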
+static LogicalResult
+verifyAndCollectExpandableUsers(Value insertDest,
+                                SmallVector<ReassociationIndices> reIndices,
+                                tensor::ParallelInsertSliceOp parallelInsertOp,
+                                SmallVector<Operation *> &expandableUsers) {
+  for (Operation *user : insertDest.getUsers()) {
+    if (user == parallelInsertOp) {
+      expandableUsers.push_back(user);
+      continue;
+    }
+    auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
+    if (!extractSliceOp)
+      return failure();
+    if (extractSliceOp.getMixedSizes() != parallelInsertOp.getMixedSizes())
+      return failure();
+    if (extractSliceOp.getMixedOffsets() != parallelInsertOp.getMixedOffsets())
+      return failure();
+    auto expandShapeOp =
+        dyn_cast<tensor::ExpandShapeOp>(*extractSliceOp->getUsers().begin());
+    if (!expandShapeOp)
+      return failure();
+    SmallVector<ReassociationIndices> expandReIndices =
+        expandShapeOp.getReassociationIndices();
+    if (reIndices != expandReIndices)
+      return failure();
+    expandableUsers.push_back(user);
+  }
+  return success();
+}
+
+/// Utility to expand the pre-verified expandable users of the scf.forall
+/// output.
+static void expandVerifiedUsers(PatternRewriter &rewriter, Location loc,
+                                MLIRContext *ctx,
+                                SmallVector<Operation *> expandableUsers,
+                                SmallVector<int64_t> totalInnerSizes,
+                                SmallVector<ReassociationIndices> reIndices,
+                                scf::ForallOp forallOp) {
+  // Compute the offsets, sizes and strides in the expanded dimensions.
+  auto computeExpandedAccess = [&](ArrayRef<OpFoldResult> mixedOffsets,
+                                   ShapedType resultType)
+      -> std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
+                    SmallVector<OpFoldResult>> {
+    SmallVector<OpFoldResult> expandedOffsets;
+    auto expandedOffsetsIter = expandedOffsets.begin();
+
+    for (auto [index, offset] : llvm::enumerate(mixedOffsets)) {
+      // Add zero offsets for the extra dimensions from reIndices.
+      for (size_t i = 1, e = reIndices[index].size(); i < e; ++i) {
+        expandedOffsets.push_back(getAsIndexOpFoldResult(ctx, 0));
+      }
+      // Compute the outer dimension expression.
+      AffineExpr s0, s1;
+      bindSymbols(rewriter.getContext(), s0, s1);
+      AffineExpr outerDimExpr = (s0).floorDiv(s1);
+      // Insert the computed offset using the affine expression.
+      expandedOffsets.insert(
+          expandedOffsetsIter,
+          affine::makeComposedFoldedAffineApply(
+              rewriter, loc, outerDimExpr,
+              {getValueOrCreateConstantIndexOp(rewriter, loc, offset),
+               rewriter.getIndexAttr(totalInnerSizes[index])}));
+
+      expandedOffsetsIter = expandedOffsets.end();
+    }
+    SmallVector<OpFoldResult> expandedSizes =
+        getAsIndexOpFoldResult(ctx, resultType.getShape());
+    SmallVector<OpFoldResult> expandedStrides(resultType.getRank(),
+                                              rewriter.getIndexAttr(1));
+    return {expandedOffsets, expandedSizes, expandedStrides};
+  };
+  for (Operation *user : expandableUsers) {
+    rewriter.setInsertionPointToStart(forallOp.getBody());
+    if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) {
+      auto expandShapeOp =
+          dyn_cast<tensor::ExpandShapeOp>(*extractSliceOp->getUsers().begin());
+      RankedTensorType resultType = expandShapeOp.getResultType();
+      auto [expandedOffsets, expandedSizes, expandedStrides] =
+          computeExpandedAccess(extractSliceOp.getMixedOffsets(), resultType);
+      rewriter.setInsertionPoint(extractSliceOp);
+      rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+          extractSliceOp, resultType, extractSliceOp.getSource(),
+          expandedOffsets, expandedSizes, expandedStrides);
+    } else if (auto parallelInsertOp =
+                   dyn_cast<tensor::ParallelInsertSliceOp>(user)) {
+      auto collapseShapeOp =
+          parallelInsertOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+      RankedTensorType resultType = collapseShapeOp.getSrcType();
+      auto [expandedOffsets, expandedSizes, expandedStrides] =
+          computeExpandedAccess(parallelInsertOp.getMixedOffsets(), resultType);
+      rewriter.setInsertionPoint(parallelInsertOp);
+      rewriter.replaceOpWithNewOp<tensor::ParallelInsertSliceOp>(
+          parallelInsertOp, collapseShapeOp.getSrc(),
+          parallelInsertOp.getDest(), expandedOffsets, expandedSizes,
+          expandedStrides);
+    }
+  }
+  return;
+}
+
+/// This pattern expands the destination of a workgroup-mapped scf.forall by
+/// hoisting out the collapse_shape op consumed by its parallel_insert_slice
+/// op.
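+///
+/// Sketch of the rewrite, using the shapes of the @expand_dest_forall lit test
+/// below (most operands elided):
+///
+///   %forall = scf.forall ... shared_outs(%arg = %dest) -> (tensor<?x64x32xf32>) {
+///     ...
+///     %collapsed = tensor.collapse_shape %src [[0], [1], [2, 3, 4]]
+///         : tensor<1x16x2x4x2xf32> into tensor<1x16x16xf32>
+///     scf.forall.in_parallel {
+///       tensor.parallel_insert_slice %collapsed into %arg ...
+///     }
+///   }
+///
+/// becomes
+///
+///   %expanded_dest = tensor.expand_shape %dest [[0], [1], [2, 3, 4]]
+///       : tensor<?x64x32xf32> into tensor<?x64x4x4x2xf32>
+///   %forall = scf.forall ... shared_outs(%arg = %expanded_dest)
+///       -> (tensor<?x64x4x4x2xf32>) {
+///     ...
+///     scf.forall.in_parallel {
+///       tensor.parallel_insert_slice %src into %arg ...
+///     }
+///   }
+///   %result = tensor.collapse_shape %forall [[0], [1], [2, 3, 4]]
+///       : tensor<?x64x4x4x2xf32> into tensor<?x64x32xf32>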
+struct ExpandDestinationForallOp final
+    : OpRewritePattern<tensor::ParallelInsertSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::ParallelInsertSliceOp parallelInsertOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = parallelInsertOp.getLoc();
+    MLIRContext *ctx = getContext();
+    auto collapseOp =
+        parallelInsertOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    // No collapse op to hoist out.
+    if (!collapseOp)
+      return failure();
+
+    // Ignore trivially foldable collapse ops.
+    if (collapseOp.getSrcType().getRank() ==
+        collapseOp.getResultType().getRank()) {
+      return failure();
+    }
+
+    // Get the destination to expand.
+    Value insertDest = parallelInsertOp.getDest();
+
+    // Get the enclosing scf.forall op.
+    OpResult tiedResult = parallelInsertOp.getTiedOpResult();
+    int64_t tiedResultIdx = tiedResult.getResultNumber();
+
+    auto forallOp = dyn_cast<scf::ForallOp>(tiedResult.getOwner());
+    if (!forallOp)
+      return failure();
+
+    // We only want this pattern if the forall op result is being written to a
+    // full slice. Otherwise the hoisted collapse op is not foldable.
+    for (Operation *forallUser : tiedResult.getUsers()) {
+      auto storeOp = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(forallUser);
+      if (!storeOp)
+        return failure();
+      if (!isFullSlice(storeOp, storeOp.getTargetType(),
+                       storeOp.getTargetDims())) {
+        return failure();
+      }
+    }
+
+    // This allows us to assume that the extracts/inserts in the loop are
+    // disjoint and makes the application of this pattern safe.
+    if (!forallOpHasMappingType<IREE::Codegen::WorkgroupMappingAttr>(
+            forallOp)) {
+      return failure();
+    }
+    // Only the forall output tied to this parallel_insert_slice is expanded;
+    // any other outputs are carried over to the new scf.forall unchanged.
+    SmallVector<Value> forallOutputs(forallOp.getOutputs());
+
+    SmallVector<ReassociationIndices> reIndices =
+        collapseOp.getReassociationIndices();
+    SmallVector<int64_t> expandedDestShape;
+    SmallVector<int64_t> totalInnerSizes;
+    // Get the shape of the outer expand which will be the new destination
+    // of the scf.forall and the total size of inner dimensions per uncollapsed
+    // dimension.
+    if (failed(getExpandedShape(reIndices, collapseOp.getSrcType().getShape(),
+                                insertDest, expandedDestShape,
+                                totalInnerSizes))) {
+      return failure();
+    }
+
+    // Verify that the users of the destination are valid to expand and collect
+    // all such users.
+    SmallVector<Operation *> expandableUsers;
+    if (failed(verifyAndCollectExpandableUsers(
+            insertDest, collapseOp.getReassociationIndices(), parallelInsertOp,
+            expandableUsers))) {
+      return failure();
+    }
+
+    // Expand the users of the destination.
+    rewriter.setInsertionPointToStart(forallOp.getBody());
+    expandVerifiedUsers(rewriter, loc, ctx, expandableUsers, totalInnerSizes,
+                        reIndices, forallOp);
+    rewriter.setInsertionPoint(forallOp);
+
+    // Create the expand -> new scf.forall -> collapse chain.
+    auto expandedDestType =
+        cast<RankedTensorType>(forallOutputs[tiedResultIdx].getType())
+            .clone(expandedDestShape);
+    auto expandedDest = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandedDestType, forallOutputs[tiedResultIdx], reIndices);
+
+    forallOutputs[tiedResultIdx] = expandedDest;
+
+    scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
+        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+        forallOp.getMixedStep(), forallOutputs, forallOp.getMappingAttr());
+
+    auto collapsedResultOp = rewriter.create<tensor::CollapseShapeOp>(
+        loc, cast<ShapedType>(forallOp->getResult(tiedResultIdx).getType()),
+        newForallOp->getResult(tiedResultIdx), reIndices);
+
+    // Merge the old scf.forall block, which has the expanded users, into the
+    // new scf.forall, which has the expanded destination.
+    SmallVector<Value> argReplacements(newForallOp.getInductionVars());
+    argReplacements.append(newForallOp.getRegionIterArgs().begin(),
+                           newForallOp.getRegionIterArgs().end());
+    scf::InParallelOp parallelTerminator = newForallOp.getTerminator();
+    parallelTerminator->erase();
+    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
+                         argReplacements);
+
+    // Replace the uses of the old scf.forall with the new scf.forall.
+    for (int idx = 0; idx < forallOp->getNumResults(); ++idx) {
+      if (idx == tiedResultIdx) {
+        forallOp->getResult(idx).replaceAllUsesWith(
+            collapsedResultOp->getResult(0));
+      } else {
+        forallOp->getResult(idx).replaceAllUsesWith(
+            newForallOp->getResult(idx));
+      }
+    }
+    return success();
+  }
+};
+
 struct PropagateReshapesByExpansionPass final
     : impl::PropagateReshapesByExpansionPassBase<
           PropagateReshapesByExpansionPass> {
@@ -65,6 +344,7 @@ void PropagateReshapesByExpansionPass::runOnOperation() {
   tensor::ExpandShapeOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
                                                      context);
   populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns);
+  bubbleExpandShapePatterns.add<ExpandDestinationForallOp>(context);

   if (failed(applyPatternsAndFoldGreedily(
           getOperation(), std::move(bubbleExpandShapePatterns)))) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
index fc9e85e3a764e..faeb828097b43 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
@@ -1,4 +1,5 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion))" --split-input-file %s --mlir-print-local-scope | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion), cse)" \
+// RUN:   --split-input-file %s --mlir-print-local-scope | FileCheck %s

 func.func @reshape_and_lowering_config(%src: tensor<3x4xf16>, %dest: tensor<12xf16>, %dest2: tensor<12xf16>) -> tensor<12xf16> {
   %collapse = tensor.collapse_shape %src [[0, 1]] : tensor<3x4xf16> into tensor<12xf16>
@@ -86,3 +87,253 @@ func.func @fold_collapse_into_stores_dynamic(%arg0 : tensor<2x?x32xf32>) {
 // CHECK:   flow.dispatch.tensor.store %{{.+}}, %[[SUBSPAN]]
 // CHECK-SAME:   offsets = [0, 0, 0], sizes = [2, %[[SHAPE]], 32], strides = [1, 1, 1]
 // CHECK-SAME:   !flow.dispatch.tensor<writeonly:tensor<2x?x32xf32>>{%[[SHAPE]]}
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
+  #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+func.func @expand_dest_forall() {
+  %cst = arith.constant 0.000000e+00 : f16
+  %c0 = arith.constant 0 : index
+  %index = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+      flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<?x64x32xf32>>{%index}
+  %1 = tensor.empty(%index) : tensor<?x64x32xf32>
+  %extra = tensor.empty() : tensor<32x32xf32>
+  %2 = scf.forall (%arg0, %arg1) = (0, 0) to (64, 32) step (16, 16)
+      shared_outs(%arg2 = %1) -> (tensor<?x64x32xf32>) {
+    %extracted_slice = tensor.extract_slice %arg2[%c0, %arg0, %arg1] [1, 16, 16] [1, 1, 1]
+        : tensor<?x64x32xf32> to tensor<1x16x16xf32>
+    %expanded = tensor.expand_shape %extracted_slice [[0], [1], [2, 3, 4]]
+        output_shape [1, 16, 2, 4, 2] : tensor<1x16x16xf32> into tensor<1x16x2x4x2xf32>
+    %expanded_barrier = util.optimization_barrier %expanded : tensor<1x16x2x4x2xf32>
+    %collapsed = tensor.collapse_shape %expanded_barrier [[0], [1], [2, 3, 4]]
+        : tensor<1x16x2x4x2xf32> into tensor<1x16x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %collapsed into %arg2[%c0, %arg0, %arg1] [1, 16, 16] [1, 1, 1]
+          : tensor<1x16x16xf32> into tensor<?x64x32xf32>
+    }
+  } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
+  flow.dispatch.tensor.store %2, %0, offsets = [0, 0, 0], sizes = [%index, 64, 32], strides = [1, 1, 1]
+      : tensor<?x64x32xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x64x32xf32>>{%index}
+  return
+}
+
+// CHECK-LABEL: func @expand_dest_forall(
+// CHECK:   %[[LOAD_CONST:.+]] = hal.interface.constant.load
+// CHECK:   %[[SUBSPAN:.+]] = hal.interface.binding.subspan
+// CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[LOAD_CONST]]) : tensor<?x64x4x4x2xf32>
+// CHECK:   %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) = (0, 0)
+// CHECK-SAME:   shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor<?x64x4x4x2xf32>) {
+// CHECK-DAG:   %[[OFFSET:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 8)>()[%[[ARG1]]]
+// CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG2]]
+// CHECK-SAME:   [0, %[[ARG0]], %[[OFFSET]], 0, 0] [1, 16, 2, 4, 2] [1, 1, 1, 1, 1]
+// CHECK-SAME:   tensor<?x64x4x4x2xf32> to tensor<1x16x2x4x2xf32>
+// CHECK:   %[[BARRIER:.+]] = util.optimization_barrier %[[EXTRACT]] : tensor<1x16x2x4x2xf32>
+// CHECK:   tensor.parallel_insert_slice %[[BARRIER]] into %[[ARG2]]
+// CHECK-SAME:   [0, %[[ARG0]], %[[OFFSET]], 0, 0] [1, 16, 2, 4, 2] [1, 1, 1, 1, 1]
+// CHECK-SAME:   tensor<1x16x2x4x2xf32> into tensor<?x64x4x4x2xf32>
+// CHECK:   flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]]
+// CHECK-SAME:   offsets = [0, 0, 0, 0, 0], sizes = [%[[LOAD_CONST]], 64, 4, 4, 2], strides = [1, 1, 1, 1, 1]
+// CHECK-SAME:   !flow.dispatch.tensor<writeonly:tensor<?x64x4x4x2xf32>>{%[[LOAD_CONST]]}
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 0, bindings = [
+  #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+func.func @expand_dest_forall_multiresult() {
+  %cst = arith.constant 0.000000e+00 : f16
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+      flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<32xf32>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64)
+      offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<32x32xf32>>
+  %2 = tensor.empty() : tensor<32xf32>
+  %3 = tensor.empty() : tensor<32x32xf32>
+  %4:2 = scf.forall (%arg0) = (0) to (32) step (16)
+      shared_outs(%arg1 = %3, %arg2 = %2) -> (tensor<32x32xf32>, tensor<32xf32>) {
+    %extracted_slice = tensor.extract_slice %arg2[%arg0] [16] [1] : tensor<32xf32> to tensor<16xf32>
+    %expanded = tensor.expand_shape %extracted_slice [[0, 1]] output_shape [2, 8]
+        : tensor<16xf32> into tensor<2x8xf32>
+    %5 = util.optimization_barrier %expanded : tensor<2x8xf32>
+    %collapsed = tensor.collapse_shape %5 [[0, 1]] : tensor<2x8xf32> into tensor<16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %arg1 into %arg1[%c0, %c0] [32, 32] [1, 1]
+          : tensor<32x32xf32> into tensor<32x32xf32>
+      tensor.parallel_insert_slice %collapsed into %arg2[%arg0] [16] [1]
+          : tensor<16xf32> into tensor<32xf32>
+    }
+  } {mapping = [#iree_codegen.workgroup_mapping<x>]}
+  flow.dispatch.tensor.store %4#1, %0, offsets = [0], sizes = [32], strides = [1]
+      : tensor<32xf32> -> !flow.dispatch.tensor<writeonly:tensor<32xf32>>
+  flow.dispatch.tensor.store %4#0, %1, offsets = [0, 0], sizes = [32, 32], strides = [1, 1]
+      : tensor<32x32xf32> -> !flow.dispatch.tensor<writeonly:tensor<32x32xf32>>
+  return
+}
+
+// CHECK-LABEL: func @expand_dest_forall_multiresult(
+// CHECK:   %[[SUBSPAN0:.+]] = hal.interface.binding.subspan
+// CHECK:   %[[SUBSPAN1:.+]] = hal.interface.binding.subspan
+// CHECK:   %[[EMPTY0:.+]] = tensor.empty() : tensor<32x32xf32>
+// CHECK:   %[[EMPTY1:.+]] = tensor.empty() : tensor<4x8xf32>
+// CHECK:   %[[SCFFORALL:.+]]:2 = scf.forall (%[[ARG0:.+]]) = (0) to (32) step (16)
+// CHECK-SAME:   shared_outs(%[[ARG1:.+]] = %[[EMPTY0]], %[[ARG2:.+]] = %[[EMPTY1]])
+// CHECK-SAME:   -> (tensor<32x32xf32>, tensor<4x8xf32>) {
+// CHECK-DAG:   %[[OFFSET:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 8)>()[%[[ARG0]]]
+// CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG2]]
+// CHECK-SAME:   [%[[OFFSET]], 0] [2, 8] [1, 1]
+// CHECK-SAME:   tensor<4x8xf32> to tensor<2x8xf32>
+// CHECK:   %[[BARRIER:.+]] = util.optimization_barrier %[[EXTRACT]] : tensor<2x8xf32>
+// CHECK:   tensor.parallel_insert_slice %[[ARG1]] into %[[ARG1]]
+// CHECK-SAME:   tensor<32x32xf32> into tensor<32x32xf32>
+// CHECK:   tensor.parallel_insert_slice %[[BARRIER]] into %[[ARG2]]
+// CHECK-SAME:   [%[[OFFSET]], 0] [2, 8] [1, 1]
+// CHECK-SAME:   tensor<2x8xf32> into tensor<4x8xf32>
+// CHECK:   flow.dispatch.tensor.store %[[SCFFORALL]]#1, %[[SUBSPAN0]]
+// CHECK-SAME:   offsets = [0, 0], sizes = [4, 8], strides = [1, 1]
+// CHECK-SAME:   !flow.dispatch.tensor<writeonly:tensor<4x8xf32>>
+// CHECK:   flow.dispatch.tensor.store %[[SCFFORALL]]#0, %[[SUBSPAN1]]
+// CHECK-SAME:   offsets = [0, 0], sizes = [32, 32], strides = [1, 1]
+// CHECK-SAME:   !flow.dispatch.tensor<writeonly:tensor<32x32xf32>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
+  #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+func.func @noexpand_dest_forall_dynamicpacked() {
+  %index1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+  %index2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+  %index3 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+      flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<32xf32>>
+  %2 = tensor.empty() : tensor<32xf32>
+  %4 = scf.forall (%arg0) = (0) to (32) step (16)
+      shared_outs(%arg2 = %2) -> (tensor<32xf32>) {
+    %extracted_slice = tensor.extract_slice %arg2[%arg0] [%index1] [1] : tensor<32xf32> to tensor<?xf32>
+    %expanded = tensor.expand_shape %extracted_slice [[0, 1]] output_shape [%index2, %index3]
+        : tensor<?xf32> into tensor<?x?xf32>
+    %5 = util.optimization_barrier %expanded : tensor<?x?xf32>
+    %collapsed = tensor.collapse_shape %5 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %collapsed into %arg2[%arg0] [%index1] [1]
+          : tensor<?xf32> into tensor<32xf32>
+    }
+  } {mapping = [#iree_codegen.workgroup_mapping<x>]}
+  flow.dispatch.tensor.store %4, %0, offsets = [0], sizes = [32], strides = [1] : tensor<32xf32>
+      -> !flow.dispatch.tensor<writeonly:tensor<32xf32>>
+  return
+}
+
+// CHECK-LABEL: func @noexpand_dest_forall_dynamicpacked(
+// CHECK:   %[[SUBSPAN:.+]] = hal.interface.binding.subspan
+// CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<32xf32>
+// CHECK:   %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]]) = (0) to (32) step (16)
+// CHECK-SAME:   shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor<32xf32>) {
+// CHECK:   flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]]
+// CHECK-SAME:   offsets = [0], sizes = [32], strides = [1] : tensor<32xf32>
+// CHECK-SAME:   !flow.dispatch.tensor<writeonly:tensor<32xf32>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 0, bindings = [
+  #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+func.func @expand_dest_forall_unsupporteduse() {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+      flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<32xf32>>
+  %2 = tensor.empty() : tensor<32xf32>
+  %4 = scf.forall (%arg0) = (0) to (32) step (16)
+      shared_outs(%arg2 = %2) -> (tensor<32xf32>) {
+    %extracted_slice = tensor.extract_slice %arg2[%arg0] [16] [1] : tensor<32xf32> to tensor<16xf32>
+    %arith_op = arith.negf %extracted_slice : tensor<16xf32>
+    %expanded = tensor.expand_shape %arith_op [[0, 1]] output_shape [2, 8]
+        : tensor<16xf32> into tensor<2x8xf32>
+    %5 = util.optimization_barrier %expanded : tensor<2x8xf32>
+    %collapsed = tensor.collapse_shape %5 [[0, 1]] : tensor<2x8xf32> into tensor<16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %collapsed into %arg2[%arg0] [16] [1]
+          : tensor<16xf32> into tensor<32xf32>
+    }
+  } {mapping = [#iree_codegen.workgroup_mapping<x>]}
+  flow.dispatch.tensor.store %4, %0, offsets = [0], sizes = [32], strides = [1] : tensor<32xf32> -> !flow.dispatch.tensor<writeonly:tensor<32xf32>>
+  return
+}
+
+// CHECK-LABEL: func @expand_dest_forall_unsupporteduse(
+// CHECK:   %[[SUBSPAN:.+]] = hal.interface.binding.subspan
+// CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<32xf32>
+// CHECK:   %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]]) = (0) to (32) step (16)
+// CHECK-SAME:   shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor<32xf32>) {
+// CHECK:   flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]]
+// CHECK-SAME:   offsets = [0], sizes = [32], strides = [1] : tensor<32xf32>
+// CHECK-SAME:   !flow.dispatch.tensor<writeonly:tensor<32xf32>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 0, bindings = [
+  #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+func.func @noexpand_dest_forall_nomapping() {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+      flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<32xf32>>
+  %2 = tensor.empty() : tensor<32xf32>
+  %4 = scf.forall (%arg0) = (0) to (32) step (16)
+      shared_outs(%arg2 = %2) -> (tensor<32xf32>) {
+    %extracted_slice = tensor.extract_slice %arg2[%arg0] [16] [1] : tensor<32xf32> to tensor<16xf32>
+    %expanded = tensor.expand_shape %extracted_slice [[0, 1]] output_shape [2, 8]
+        : tensor<16xf32> into tensor<2x8xf32>
+    %5 = util.optimization_barrier %expanded : tensor<2x8xf32>
+    %collapsed = tensor.collapse_shape %5 [[0, 1]] : tensor<2x8xf32> into tensor<16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %collapsed into %arg2[%arg0] [16] [1]
+          : tensor<16xf32> into tensor<32xf32>
+    }
+  }
+  flow.dispatch.tensor.store %4, %0, offsets = [0], sizes = [32], strides = [1] : tensor<32xf32> -> !flow.dispatch.tensor<writeonly:tensor<32xf32>>
+  return
+}
+
+// CHECK-LABEL: func @noexpand_dest_forall_nomapping(
+// CHECK:   %[[SUBSPAN:.+]] = hal.interface.binding.subspan
+// CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<32xf32>
+// CHECK:   %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]]) = (0) to (32) step (16)
+// CHECK-SAME:   shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor<32xf32>) {
+// CHECK:   flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]]
+// CHECK-SAME:   offsets = [0], sizes = [32], strides = [1] : tensor<32xf32>
+// CHECK-SAME:   !flow.dispatch.tensor<writeonly:tensor<32xf32>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 0, bindings = [
+  #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+func.func @noexpand_dest_forall_notfullslicestore() {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+      flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<64xf32>>
+  %2 = tensor.empty() : tensor<32xf32>
+  %4 = scf.forall (%arg0) = (0) to (32) step (16)
+      shared_outs(%arg2 = %2) -> (tensor<32xf32>) {
+    %extracted_slice = tensor.extract_slice %arg2[%arg0] [16] [1] : tensor<32xf32> to tensor<16xf32>
+    %expanded = tensor.expand_shape %extracted_slice [[0, 1]] output_shape [2, 8]
+        : tensor<16xf32> into tensor<2x8xf32>
+    %5 = util.optimization_barrier %expanded : tensor<2x8xf32>
+    %collapsed = tensor.collapse_shape %5 [[0, 1]] : tensor<2x8xf32> into tensor<16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %collapsed into %arg2[%arg0] [16] [1]
+          : tensor<16xf32> into tensor<32xf32>
+    }
+  } {mapping = [#iree_codegen.workgroup_mapping<x>]}
+  flow.dispatch.tensor.store %4, %0, offsets = [1], sizes = [32], strides = [1] : tensor<32xf32> -> !flow.dispatch.tensor<writeonly:tensor<64xf32>>
+  return
+}
+
+// CHECK-LABEL: func @noexpand_dest_forall_notfullslicestore(
+// CHECK:   %[[SUBSPAN:.+]] = hal.interface.binding.subspan
+// CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<32xf32>
+// CHECK:   %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]]) = (0) to (32) step (16)
+// CHECK-SAME:   shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor<32xf32>) {
+// CHECK:   flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]]
+// CHECK-SAME:   offsets = [1], sizes = [32], strides = [1] : tensor<32xf32>
+// CHECK-SAME:   !flow.dispatch.tensor<writeonly:tensor<64xf32>>
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
index 5fe45e9cd8625..980925ccb11ad 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -13,6 +13,7 @@
 #include "iree/compiler/Codegen/Transforms/Transforms.h"

 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/Liveness.h"
@@ -36,31 +37,6 @@ namespace mlir::iree_compiler {

-static bool isAllConstantValue(ArrayRef<OpFoldResult> ofrs, int64_t v) {
-  return llvm::all_of(
-      ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, v); });
-}
-
-static bool isFullSlice(ArrayRef<OpFoldResult> mixedOffsets,
-                        ArrayRef<OpFoldResult> mixedSizes,
-                        ArrayRef<OpFoldResult> mixedStrides,
-                        IREE::Flow::DispatchTensorType tensorType,
-                        ValueRange dynamicDims) {
-  OpBuilder builder(tensorType.getContext());
-  SmallVector<int64_t> tensorShape = llvm::to_vector(tensorType.getShape());
-  SmallVector<OpFoldResult> mixedTensorShape =
-      mlir::getMixedValues(tensorShape, dynamicDims, builder);
-  return isAllConstantValue(mixedOffsets, 0) &&
-         isAllConstantValue(mixedStrides, 1) && mixedTensorShape == mixedSizes;
-}
-static bool isFullSlice(OffsetSizeAndStrideOpInterface sliceLoadStoreOp,
-                        IREE::Flow::DispatchTensorType tensorType,
-                        ValueRange dynamicDims) {
-  return isFullSlice(
-      sliceLoadStoreOp.getMixedOffsets(), sliceLoadStoreOp.getMixedSizes(),
-      sliceLoadStoreOp.getMixedStrides(), tensorType, dynamicDims);
-}
-
 static bool sliceFilter(Operation *op, ValueRange nonIndexComputationOperands,
                         Operation *baseOp) {
   for (auto val : nonIndexComputationOperands) {
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index 544e0558ead22..f86f447c49dc1 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -1511,4 +1511,26 @@ computeDimUpperBound(Value shapedValue, unsigned dimNum,
   return roundedDimBound.getSize();
 }

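+// A full slice covers the entire dispatch tensor: all offsets are 0, all
+// strides are 1, and the mixed static/dynamic sizes equal the tensor shape,
+// e.g. offsets = [0, 0, 0], strides = [1, 1, 1], sizes = [2, %d, 32] for a
+// !flow.dispatch.tensor<writeonly:tensor<2x?x32xf32>>{%d}.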
+static bool isFullSlice(ArrayRef<OpFoldResult> mixedOffsets,
+                        ArrayRef<OpFoldResult> mixedSizes,
+                        ArrayRef<OpFoldResult> mixedStrides,
+                        IREE::Flow::DispatchTensorType tensorType,
+                        ValueRange dynamicDims) {
+  OpBuilder builder(tensorType.getContext());
+  SmallVector<int64_t> tensorShape = llvm::to_vector(tensorType.getShape());
+  SmallVector<OpFoldResult> mixedTensorShape =
+      mlir::getMixedValues(tensorShape, dynamicDims, builder);
+  return areAllConstantIntValue(mixedOffsets, 0) &&
+         areAllConstantIntValue(mixedStrides, 1) &&
+         mixedTensorShape == mixedSizes;
+}
+
+bool isFullSlice(OffsetSizeAndStrideOpInterface sliceLoadStoreOp,
+                 IREE::Flow::DispatchTensorType tensorType,
+                 ValueRange dynamicDims) {
+  return isFullSlice(
+      sliceLoadStoreOp.getMixedOffsets(), sliceLoadStoreOp.getMixedSizes(),
+      sliceLoadStoreOp.getMixedStrides(), tensorType, dynamicDims);
+}
+
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Utils/Utils.h
index 7337549d5ec21..4603429afd96c 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.h
@@ -7,6 +7,7 @@
 #ifndef IREE_COMPILER_CODEGEN_UTILS_UTILS_H_
 #define IREE_COMPILER_CODEGEN_UTILS_UTILS_H_

+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "llvm/TargetParser/Triple.h"
@@ -251,6 +252,12 @@ computeDimUpperBound(Value shapedValue, unsigned dimNum,
                      std::optional<VscaleRange> vscaleRange,
                      RoundUpVscaleMultiple = RoundUpVscaleMultiple::No);

+// Utility to make sure we are storing the full incoming subspan. Otherwise we
+// cannot simply adjust the subspan's resultant type later.
+bool isFullSlice(OffsetSizeAndStrideOpInterface sliceLoadStoreOp,
+                 IREE::Flow::DispatchTensorType tensorType,
+                 ValueRange dynamicDims);
+
 } // namespace mlir::iree_compiler

 #endif // IREE_COMPILER_CODEGEN_UTILS_UTILS_H_