diff --git a/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir b/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir index 50ca569bc8f1..6f1cc19b452e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir @@ -71,26 +71,36 @@ func.func @dont_fold_reshape_with_not_full_load() { // ----- #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding ]> -// CHECK-LABEL: func.func @dont_fold_dynamic_reshape() -func.func @dont_fold_dynamic_reshape() { +func.func @fold_dynamic_reshape() { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %dim0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index %dim1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index %dim2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor>{%dim0, %dim1} - %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor>{%dim2} + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor>{%dim2} %3 = flow.dispatch.tensor.load %1, offsets=[0, 0, 0], sizes =[%dim0, %dim1, 96], strides=[1, 1, 1] : !flow.dispatch.tensor>{%dim0, %dim1} -> tensor - // CHECK: tensor.collapse_shape - // CHECK: tensor.expand_shape %4 = tensor.collapse_shape %3 [[0, 1], [2]] : tensor into tensor %dyn = tensor.dim %4, %c0 : tensor %5 = tensor.expand_shape %4 [[0], [1, 2]] output_shape [%dyn, 12, 8] : tensor into tensor - flow.dispatch.tensor.store %5, %2, offsets = [%c0, %c0, %c0], sizes = [%c1, 12, 8], strides = [%c1, %c1, %c1] : tensor -> !flow.dispatch.tensor>{%dim2} + flow.dispatch.tensor.store %5, %2, offsets = [0, 0, 0], sizes = [%dim2, 12, 8], strides = [1, 1, 1] : tensor -> !flow.dispatch.tensor>{%dim2} return } +// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK: func.func @fold_dynamic_reshape() +// CHECK-DAG: %[[CST0:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(0) +// CHECK-DAG: %[[CST1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(1) +// CHECK-DAG: %[[CST2:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(2) +// CHECK: %[[COLLAPSED:.+]] = affine.apply #[[MAP]]()[%[[CST0]], %[[CST1]]] +// CHECK: %[[IN_BINDING:.+]] = hal.interface.binding.subspan +// CHECK-SAME: binding(0) : !flow.dispatch.tensor>{%[[COLLAPSED]]} +// CHECK: %[[OUT_BINDING:.+]] = hal.interface.binding.subspan +// CHECK-SAME: binding(1) : !flow.dispatch.tensor>{%[[CST2]]} +// CHECK: %[[IN:.+]] = flow.dispatch.tensor.load %[[IN_BINDING]] +// CHECK: flow.dispatch.tensor.store %[[IN]], %[[OUT_BINDING]] // ----- diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir index 7dd745e5a7c3..fc9e85e3a764 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion))" --split-input-file %s | FileCheck %s +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion))" --split-input-file %s 
--mlir-print-local-scope | FileCheck %s func.func @reshape_and_lowering_config(%src: tensor<3x4xf16>, %dest: tensor<12xf16>, %dest2: tensor<12xf16>) -> tensor<12xf16> { %collapse = tensor.collapse_shape %src [[0, 1]] : tensor<3x4xf16> into tensor<12xf16> @@ -14,3 +14,75 @@ func.func @reshape_and_lowering_config(%src: tensor<3x4xf16>, %dest: tensor<12xf // CHECK: linalg.copy // CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config // CHECK-SAME: ins(%[[COLLAPSE]] + +// ----- + +#pipeline_layout = #hal.pipeline.layout], flags = Indirect> +func.func @fold_collapse_into_loads_dynamic() -> tensor { + %c0 = arith.constant 0 : index + %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) + flags("ReadOnly|Indirect") : !flow.dispatch.tensor>{%0} + %2 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, %0, 32], strides = [1, 1, 1] + : !flow.dispatch.tensor>{%0} -> tensor<2x?x32xf32> + %3 = tensor.collapse_shape %2 [[0, 1], [2]] : tensor<2x?x32xf32> into tensor + return %3 : tensor +} +// CHECK-LABEL: func @fold_collapse_into_loads_dynamic() +// CHECK: %[[CONST:.+]] = hal.interface.constant.load +// CHECK: %[[SHAPE:.+]] = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%[[CONST]]] +// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan +// CHECK-SAME: !flow.dispatch.tensor>{%[[SHAPE]]} +// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[SUBSPAN]] +// CHECK-SAME: offsets = [0, 0], sizes = [%[[SHAPE]], 32], strides = [1, 1] +// CHECK-SAME: !flow.dispatch.tensor>{%[[SHAPE]]} + +// ----- + +#pipeline_layout = #hal.pipeline.layout], flags = Indirect> +func.func @fold_expand_into_loads_dynamic() -> tensor<2x?x16x32xf32> { + %c0 = arith.constant 0 : index + %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) + flags("ReadOnly|Indirect") : !flow.dispatch.tensor>{%0} + %2 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, %0, 32], strides = [1, 1, 1] + : !flow.dispatch.tensor>{%0} -> tensor<2x?x32xf32> + %3 = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%0] + %4 = tensor.expand_shape %2 [[0], [1, 2], [3]] output_shape [2, %3, 16, 32] : tensor<2x?x32xf32> into tensor<2x?x16x32xf32> + return %4 : tensor<2x?x16x32xf32> +} +// CHECK-LABEL: func @fold_expand_into_loads_dynamic() +// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index +// CHECK-DAG: %[[CONST:.+]] = hal.interface.constant.load +// CHECK: %[[SHAPE:.+]] = arith.divui %[[CONST]], %[[C16]] +// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan +// CHECK-SAME: !flow.dispatch.tensor>{%[[SHAPE]]} +// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[SUBSPAN]] +// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [2, %[[SHAPE]], 16, 32], strides = [1, 1, 1, 1] +// CHECK-SAME: !flow.dispatch.tensor>{%[[SHAPE]]} + +// ----- + +#pipeline_layout = #hal.pipeline.layout], flags = Indirect> +func.func @fold_collapse_into_stores_dynamic(%arg0 : tensor<2x?x32xf32>) { + %c0 = arith.constant 0 : index + %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) + flags("ReadOnly|Indirect") : !flow.dispatch.tensor>{%0} + %2 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x?x32xf32> into tensor + flow.dispatch.tensor.store %2, %1, offsets = [0, 0], sizes = [%0, 
32], strides = [1, 1] + : tensor -> !flow.dispatch.tensor>{%0} + return +} +// CHECK-LABEL: func @fold_collapse_into_stores_dynamic( +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK: %[[CONST:.+]] = hal.interface.constant.load +// CHECK: %[[SHAPE:.+]] = arith.divui %[[CONST]], %[[C2]] +// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan +// CHECK-SAME: !flow.dispatch.tensor>{%[[SHAPE]]} +// CHECK: flow.dispatch.tensor.store %{{.+}}, %[[SUBSPAN]] +// CHECK-SAME: offsets = [0, 0, 0], sizes = [2, %[[SHAPE]], 32], strides = [1, 1, 1] +// CHECK-SAME: !flow.dispatch.tensor>{%[[SHAPE]]} diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp index 0b8c49c4be69..6d0d05277b11 100644 --- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp @@ -36,20 +36,29 @@ namespace mlir::iree_compiler { -static bool isAllConstantValue(SmallVector ofrs, int64_t v) { +static bool isAllConstantValue(ArrayRef ofrs, int64_t v) { return llvm::all_of( ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, v); }); } -static bool isFullSlice(SmallVector mixedOffsets, - SmallVector mixedSizes, - SmallVector mixedStrides, - IREE::Flow::DispatchTensorType tensorType) { - std::optional> constSizes = - getConstantIntValues(mixedSizes); +static bool isFullSlice(ArrayRef mixedOffsets, + ArrayRef mixedSizes, + ArrayRef mixedStrides, + IREE::Flow::DispatchTensorType tensorType, + ValueRange dynamicDims) { + OpBuilder builder(tensorType.getContext()); + SmallVector tensorShape = llvm::to_vector(tensorType.getShape()); + SmallVector mixedTensorShape = + mlir::getMixedValues(tensorShape, dynamicDims, builder); return isAllConstantValue(mixedOffsets, 0) && - isAllConstantValue(mixedStrides, 1) && constSizes && - llvm::equal(tensorType.getShape(), *constSizes); + isAllConstantValue(mixedStrides, 1) && mixedTensorShape == mixedSizes; +} +static bool isFullSlice(OffsetSizeAndStrideOpInterface sliceLoadStoreOp, + IREE::Flow::DispatchTensorType tensorType, + ValueRange dynamicDims) { + return isFullSlice( + sliceLoadStoreOp.getMixedOffsets(), sliceLoadStoreOp.getMixedSizes(), + sliceLoadStoreOp.getMixedStrides(), tensorType, dynamicDims); } static bool sliceFilter(Operation *op, ValueRange nonIndexComputationOperands, @@ -546,14 +555,29 @@ void moveLoopInvariantCodeFromGuaranteedLoops(Operation *target) { namespace { -// TODO(antigainst): enable dynamic shape support once they are needed. 
-template <typename TensorReshapeOp>
-static std::optional<Value> getStaticReshapeOpSrc(TensorReshapeOp reshapeOp) {
-  auto reshapeSrcType = llvm::cast<ShapedType>(reshapeOp.getSrc().getType());
-  auto reshapeDstType = llvm::cast<ShapedType>(reshapeOp.getType());
-  if (!reshapeSrcType.hasStaticShape() || !reshapeDstType.hasStaticShape())
-    return std::nullopt;
-  return reshapeOp.getSrc();
+static SmallVector<OpFoldResult>
+inferCollapsedShape(RewriterBase &rewriter, Location loc,
+                    RankedTensorType expandedType,
+                    ArrayRef<ReassociationIndices> reassociations,
+                    ValueRange expandedDynamicDims) {
+  ArrayRef<int64_t> expandedStaticShape = expandedType.getShape();
+  SmallVector<OpFoldResult> expandedMixedShape =
+      mlir::getMixedValues(expandedStaticShape, expandedDynamicDims, rewriter);
+  SmallVector<OpFoldResult> collapsedShape;
+  unsigned expandedShapeDim = 0;
+  for (auto reassociation : reassociations) {
+    AffineExpr mulExpr = rewriter.getAffineSymbolExpr(0);
+    for (auto i : llvm::seq<unsigned>(1, reassociation.size())) {
+      mulExpr = mulExpr * rewriter.getAffineSymbolExpr(i);
+    }
+    auto collapsedDim = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, mulExpr,
+        ArrayRef<OpFoldResult>(expandedMixedShape)
+            .slice(expandedShapeDim, reassociation.size()));
+    collapsedShape.push_back(collapsedDim);
+    expandedShapeDim += reassociation.size();
+  }
+  return collapsedShape;
 }
 
@@ -576,35 +600,38 @@ static std::optional<Value> getStaticReshapeOpSrc(TensorReshapeOp reshapeOp) {
 /// Folds tensor.expand/collapse_shape into the source
 /// hal.interface.binding.subspan.
 ///
 /// For example, this matches the following pattern:
 ///
 ///   %subspan = hal.interface.binding.subspan ... :
 ///       !flow.dispatch.tensor>
 ///   %tensor = flow.dispatch.tensor.load %subspan :
 ///       !flow.dispatch.tensor> ->
 ///       tensor<3x3x1x96xf32>
 ///   %0 = linalg.expand_reshape %tensor [
 ///         affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 ///       ] : tensor<3x3x1x96xf32> into tensor<864xf32>
 ///
 /// And turns it into:
 ///
 ///   %subspan = hal.interface.binding.subspan ... :
 ///       !flow.dispatch.tensor>
 ///   %0 = flow.dispatch.tensor.load %subspan :
 ///       !flow.dispatch.tensor> -> tensor<864xf32>
-template <typename TensorReshapeOp>
-struct FoldReshapeIntoInterfaceTensorLoad : OpRewritePattern<TensorReshapeOp> {
-  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+struct FoldCollapseShapeIntoInterfaceTensorLoad
+    : OpRewritePattern<tensor::CollapseShapeOp> {
+  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
-    std::optional<Value> reshapeSrc =
-        getStaticReshapeOpSrc(reshapeOp);
-    if (!reshapeSrc)
-      return failure();
-
-    auto loadOp =
-        reshapeSrc->template getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+    Value reshapeSrc = reshapeOp.getSrc();
+    auto reshapeSrcType = cast<RankedTensorType>(reshapeSrc.getType());
+    auto loadOp = reshapeSrc.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
     if (!loadOp)
       return failure();
 
     // Make sure we are loading the full incoming subspan. Otherwise we cannot
     // simply adjust the subspan's resultant type later.
-    if (!isFullSlice(loadOp.getMixedOffsets(), loadOp.getMixedSizes(),
-                     loadOp.getMixedStrides(), loadOp.getSourceType())) {
+    if (!isFullSlice(loadOp, loadOp.getSourceType(), loadOp.getSourceDims())) {
       return failure();
     }
 
-    auto subspanOp =
-        loadOp.getSource()
-            .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+    auto subspanOp = loadOp.getSource()
+                         .getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
     if (!subspanOp)
       return failure();
-    assert(subspanOp.getDynamicDims().empty());
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(subspanOp);
+    SmallVector<OpFoldResult> collapsedShape = inferCollapsedShape(
+        rewriter, subspanOp.getLoc(), reshapeSrcType,
+        reshapeOp.getReassociationIndices(), subspanOp.getDynamicDims());
+    SmallVector<int64_t> collapsedStaticShape;
+    SmallVector<Value> collapsedDynamicShape;
+    dispatchIndexOpFoldResults(collapsedShape, collapsedDynamicShape,
+                               collapsedStaticShape);
 
     auto tensorAccess =
         llvm::cast<IREE::Flow::DispatchTensorType>(subspanOp.getType())
             .getAccess();
@@ -615,12 +642,111 @@ struct FoldReshapeIntoInterfaceTensorLoad : OpRewritePattern<TensorReshapeOp> {
     Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
         subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
         subspanOp.getBinding(), subspanOp.getByteOffset(),
-        subspanOp.getDynamicDims(), subspanOp.getAlignmentAttr(),
+        collapsedDynamicShape, subspanOp.getAlignmentAttr(),
         subspanOp.getDescriptorFlagsAttr());
+    rewriter.setInsertionPoint(reshapeOp);
+    rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorLoadOp>(
+        reshapeOp, reshapeOp.getResultType(), newSubspanOp,
+        collapsedDynamicShape);
+
+    return success();
+  }
+};
+
+/// Folds tensor.expand_shape into the source
+/// hal.interface.binding.subspan.
+///
+/// For example, this matches the following pattern:
+///
+///   %subspan = hal.interface.binding.subspan ... :
+///       !flow.dispatch.tensor<readonly:tensor<864xf32>>
+///   %tensor = flow.dispatch.tensor.load %subspan :
+///       !flow.dispatch.tensor<readonly:tensor<864xf32>> -> tensor<864xf32>
+///   %0 = tensor.expand_shape %tensor [[0, 1, 2, 3]]
+///       : tensor<864xf32> into tensor<3x3x1x96xf32>
+///
+/// And turns it into:
+///
+///   %subspan = hal.interface.binding.subspan ... :
+///       !flow.dispatch.tensor<readonly:tensor<3x3x1x96xf32>>
+///   %0 = flow.dispatch.tensor.load %subspan :
+///       !flow.dispatch.tensor<readonly:tensor<3x3x1x96xf32>> ->
+///       tensor<3x3x1x96xf32>
+struct FoldExpandShapeIntoInterfaceTensorLoad
+    : OpRewritePattern<tensor::ExpandShapeOp> {
+  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    Value reshapeSrc = reshapeOp.getSrc();
+    auto loadOp = reshapeSrc.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+    if (!loadOp) {
+      return failure();
+    }
+
+    // Make sure we are loading the full incoming subspan. Otherwise we cannot
+    // simply adjust the subspan's resultant type later.
+    if (!isFullSlice(loadOp, loadOp.getSourceType(), loadOp.getSourceDims())) {
+      return failure();
+    }
+
+    // In the corner case where the expand_shape is the source of a store,
+    // don't fold with the load. Instead fold with the store to reduce the
+    // dimensionality.
+    if (reshapeOp->hasOneUse()) {
+      if (auto storeOp = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(
+              *reshapeOp->getUsers().begin())) {
+        if (isFullSlice(storeOp, storeOp.getTargetType(),
+                        storeOp.getTargetDims())) {
+          return rewriter.notifyMatchFailure(reshapeOp,
+                                             "fold with store instead");
+        }
+      }
+    }
+
+    auto subspanOp = loadOp.getSource()
+                         .getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+    if (!subspanOp)
+      return failure();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(subspanOp);
+
+    auto currDynamicDims = subspanOp.getDynamicDims();
+    auto currStaticDims = loadOp.getType().getShape();
+    auto currOfrDynamicDims =
+        mlir::getMixedValues(currStaticDims, currDynamicDims, rewriter);
+    std::optional<SmallVector<OpFoldResult>> expandedDims =
+        mlir::inferExpandShapeOutputShape(
+            rewriter, subspanOp.getLoc(), reshapeOp.getType(),
+            reshapeOp.getReassociationIndices(), currOfrDynamicDims);
+    if (!expandedDims) {
+      return reshapeOp.emitOpError("failure in expanded shape");
+    }
+
+    auto tensorAccess =
+        llvm::cast<IREE::Flow::DispatchTensorType>(subspanOp.getType())
+            .getAccess();
+    auto newSubspanType = IREE::Flow::DispatchTensorType::get(
+        tensorAccess, reshapeOp.getResultType());
+
+    SmallVector<Value> expandedDynamicDims;
+    SmallVector<int64_t> expandedStaticDims;
+    dispatchIndexOpFoldResults(expandedDims.value(), expandedDynamicDims,
+                               expandedStaticDims);
+
+    Value newSubspanOp;
+    newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
+        subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+        subspanOp.getBinding(), subspanOp.getByteOffset(), expandedDynamicDims,
+        subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr());
+
+    rewriter.setInsertionPoint(reshapeOp);
     rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorLoadOp>(
         reshapeOp, reshapeOp.getResultType(), newSubspanOp,
-        loadOp.getSourceDims());
+        expandedDynamicDims);
 
     return success();
   }
 };
@@ -652,8 +778,8 @@ struct FoldExpandShapeIntoInterfaceTensorStore
                                 PatternRewriter &rewriter) const override {
     // Make sure we are storing the full incoming subspan. Otherwise we cannot
     // simply adjust the subspan's resultant type later.
-    if (!isFullSlice(storeOp.getMixedOffsets(), storeOp.getMixedSizes(),
-                     storeOp.getMixedStrides(), storeOp.getTargetType())) {
+    if (!isFullSlice(storeOp, storeOp.getTargetType(),
+                     storeOp.getTargetDims())) {
       return failure();
     }
 
@@ -662,38 +788,136 @@ struct FoldExpandShapeIntoInterfaceTensorStore
       return failure();
     }
 
-    // Dynamic shapes are currently unsupported.
-    std::optional<Value> reshapeSrc =
-        getStaticReshapeOpSrc(reshapeOp);
-    if (!reshapeSrc)
-      return failure();
+    Value reshapeSrc = reshapeOp.getSrc();
+    // If the source is a `flow.dispatch.tensor.load`, fold with the load
+    // instead to reduce dimensionality of the problem.
+    if (auto loadOp =
+            reshapeSrc.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>()) {
+      if (isFullSlice(loadOp, loadOp.getSourceType(), loadOp.getSourceDims())) {
+        return rewriter.notifyMatchFailure(
+            storeOp, "fold expand_shape with load instead");
+      }
+    }
 
-    auto subspanOp =
-        storeOp.getTarget()
-            .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+    auto subspanOp = storeOp.getTarget()
+                         .getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
     if (!subspanOp)
       return failure();
-    assert(subspanOp.getDynamicDims().empty());
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(subspanOp);
+    SmallVector<OpFoldResult> collapsedShape = inferCollapsedShape(
+        rewriter, subspanOp.getLoc(), reshapeOp.getResultType(),
+        reshapeOp.getReassociationIndices(), subspanOp.getDynamicDims());
+    SmallVector<int64_t> collapsedStaticShape;
+    SmallVector<Value> collapsedDynamicShape;
+    dispatchIndexOpFoldResults(collapsedShape, collapsedDynamicShape,
+                               collapsedStaticShape);
 
     auto tensorAccess =
         llvm::cast<IREE::Flow::DispatchTensorType>(subspanOp.getType())
             .getAccess();
-    auto newSubspanType = IREE::Flow::DispatchTensorType::get(
-        tensorAccess, reshapeSrc->getType());
+    auto newSubspanType =
+        IREE::Flow::DispatchTensorType::get(tensorAccess, reshapeSrc.getType());
 
-    Value newSubspanOp;
-    {
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPointAfter(subspanOp);
-      newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
-          subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
-          subspanOp.getBinding(), subspanOp.getByteOffset(),
-          subspanOp.getDynamicDims(), subspanOp.getAlignmentAttr(),
-          subspanOp.getDescriptorFlagsAttr());
+    Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
+        subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+        subspanOp.getBinding(), subspanOp.getByteOffset(),
+        collapsedDynamicShape, subspanOp.getAlignmentAttr(),
+        subspanOp.getDescriptorFlagsAttr());
+
+    rewriter.setInsertionPoint(storeOp);
+    rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorStoreOp>(
+        storeOp, reshapeSrc, newSubspanOp, collapsedDynamicShape);
+
+    return success();
+  }
+};
+
+/// Folds tensor.collapse_shape into the source hal.interface.binding.subspan.
+///
+/// For example, this matches the following pattern:
+///
+///   %subspan = hal.interface.binding.subspan ... :
+///       !flow.dispatch.tensor>
+///   %0 = tensor.collapse_shape %tensor [[0, 1, 2, 3]]
+///       : tensor<3x?x?x96xf32> into tensor<?xf32>
+///   flow.dispatch.tensor.store %0, %subspan :
+///       tensor<?xf32> -> !flow.dispatch.tensor>{%dim}
+///
+/// And turns it into:
+///
+///   %subspan = hal.interface.binding.subspan ... :
+///       !flow.dispatch.tensor>
+///   flow.dispatch.tensor.store %tensor, %subspan :
+///       tensor<3x?x?x96xf32> ->
+///       !flow.dispatch.tensor>{%d0, %d1}
+///
+/// TODO: This handles full slices. The pattern below
+/// (`FoldCollapseShapeIntoTensorInsertSlice`) handles cases where the slice is
+/// not a full slice but requires the shapes to be static; this pattern handles
+/// dynamic shapes as well. Combine the two, if possible (it is not clear that
+/// it is).
+struct FoldCollapseShapeIntoInterfaceTensorStoreFullSlice
+    : OpRewritePattern<IREE::Flow::DispatchTensorStoreOp> {
+  using OpRewritePattern<IREE::Flow::DispatchTensorStoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IREE::Flow::DispatchTensorStoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    // Make sure we are storing the full incoming subspan. Otherwise we cannot
+    // simply adjust the subspan's resultant type later.
+    if (!isFullSlice(storeOp, storeOp.getTargetType(),
+                     storeOp.getTargetDims())) {
+      return failure();
+    }
+    auto reshapeOp =
+        storeOp.getValue().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!reshapeOp) {
+      return failure();
+    }
+    auto subspanOp = storeOp.getTarget()
+                         .getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+    if (!subspanOp)
+      return failure();
+
+    Value reshapeSrc = reshapeOp.getSrc();
+    auto reshapeSrcType = cast<RankedTensorType>(reshapeSrc.getType());
+
+    // Compute the type and dynamic dims of the interface binding.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(subspanOp);
+    auto dynamicDims = subspanOp.getDynamicDims();
+    ArrayRef<int64_t> staticShape = reshapeOp.getType().getShape();
+    SmallVector<OpFoldResult> mixedShape =
+        mlir::getMixedValues(staticShape, dynamicDims, rewriter);
+    std::optional<SmallVector<OpFoldResult>> expandedShape =
+        mlir::inferExpandShapeOutputShape(
+            rewriter, subspanOp.getLoc(),
+            cast<RankedTensorType>(reshapeSrc.getType()),
+            reshapeOp.getReassociationIndices(), mixedShape);
+    if (!expandedShape) {
+      return rewriter.notifyMatchFailure(
+          storeOp, "failed to compute expand shape for interface binding");
+    }
+    SmallVector<int64_t> expandedStaticShape;
+    SmallVector<Value> expandedDynamicShape;
+    dispatchIndexOpFoldResults(*expandedShape, expandedDynamicShape,
+                               expandedStaticShape);
+
+    auto tensorAccess =
+        cast<IREE::Flow::DispatchTensorType>(subspanOp.getType()).getAccess();
+    auto newSubspanType =
+        IREE::Flow::DispatchTensorType::get(tensorAccess, reshapeSrcType);
+
+    auto newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
+        subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+        subspanOp.getBinding(), subspanOp.getByteOffset(), expandedDynamicShape,
+        subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr());
+
+    rewriter.setInsertionPoint(storeOp);
     rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorStoreOp>(
-        storeOp, *reshapeSrc, newSubspanOp, storeOp.getTargetDims());
+        storeOp, reshapeSrc, newSubspanOp, expandedDynamicShape);
 
     return success();
   }
@@ -840,12 +1064,11 @@ struct FoldCollapseShapeIntoInterfaceTensorStore
 } // namespace
 
 void populateReshapeToInterfaceTensorPatterns(RewritePatternSet &patterns) {
-  patterns.insert<FoldReshapeIntoInterfaceTensorLoad<tensor::CollapseShapeOp>,
-                  FoldReshapeIntoInterfaceTensorLoad<tensor::ExpandShapeOp>>(
-      patterns.getContext());
-  patterns.insert<FoldExpandShapeIntoInterfaceTensorStore>(
-      patterns.getContext());
-  patterns.insert<FoldCollapseShapeIntoInterfaceTensorStore>(
+  patterns.insert<FoldCollapseShapeIntoInterfaceTensorLoad,
+                  FoldCollapseShapeIntoInterfaceTensorStoreFullSlice,
+                  FoldExpandShapeIntoInterfaceTensorLoad,
+                  FoldExpandShapeIntoInterfaceTensorStore,
+                  FoldCollapseShapeIntoInterfaceTensorStore>(
       patterns.getContext());
 }
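
The dynamic-shape support added above reduces to one piece of bookkeeping: each collapsed binding dimension is the product of the expanded dimensions in its reassociation group (materialized as an affine.apply over the dynamic sizes), and dispatchIndexOpFoldResults then splits the mixed sizes back into static dimensions and dynamic SSA values. Below is a minimal, MLIR-free sketch of that computation, mirroring the fold_collapse_into_loads_dynamic test where tensor<2x?x32xf32> with groups [[0, 1], [2]] becomes a tensor<?x32xf32> binding whose dynamic size is s0 * 2. All names in the sketch are illustrative assumptions, not IREE or MLIR APIs.

// Standalone model of the collapsed-shape inference; -1 plays the role of
// ShapedType::kDynamic.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

constexpr int64_t kDynamic = -1;

// Collapse `expanded` according to `reassociation` groups: a group collapses
// to the product of its dims, or to kDynamic if any member is dynamic (the
// real pattern instead emits an affine.apply multiplying the dynamic sizes,
// so the exact value is still carried symbolically).
static std::vector<int64_t>
collapseShape(const std::vector<int64_t> &expanded,
              const std::vector<std::vector<int64_t>> &reassociation) {
  std::vector<int64_t> collapsed;
  for (const auto &group : reassociation) {
    int64_t product = 1;
    bool anyDynamic = false;
    for (int64_t dimIndex : group) {
      int64_t dim = expanded[dimIndex];
      if (dim == kDynamic)
        anyDynamic = true;
      else
        product *= dim;
    }
    collapsed.push_back(anyDynamic ? kDynamic : product);
  }
  return collapsed;
}

int main() {
  // tensor<2x?x32xf32> collapsed with reassociation [[0, 1], [2]]: the new
  // subspan type is tensor<?x32xf32>, and its dynamic dim is s0 * 2 where s0
  // is the original dynamic size.
  std::vector<int64_t> expanded = {2, kDynamic, 32};
  std::vector<int64_t> collapsed = collapseShape(expanded, {{0, 1}, {2}});
  for (int64_t dim : collapsed)
    std::cout << (dim == kDynamic ? std::string("?") : std::to_string(dim))
              << " ";
  std::cout << "\n"; // prints "? 32"
  return 0;
}

The real patterns keep the arithmetic symbolic rather than degrading to "dynamic", which is why the rewritten subspans in the tests carry an affine.apply-derived size instead of an unconstrained one.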