From c3eabac94c1cf50d9c7dc9b97c74b7fce3c7bda8 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak
Date: Tue, 12 Nov 2024 22:44:53 +0000
Subject: [PATCH] [GPU] Use affine.linearize_index (and delinearize_index)
 where possible

There have been issues with the composition of affine maps being too
general and losing important information, like the fact that
affine_map<(s0 + s1 * 32 + ... - (s0 floorDiv 16) * 16)> really should be
affine_map<(s0 mod 16 + s1 * 32 + ...)>, and other issues with the
resulting IR that block low-level arithmetic optimizations.

The affine.delinearize_index operation represents the div/mod chains
needed to break a flat index into its component parts. A recently added
affine.linearize_index operation is its inverse - combining multiple
indices into a flat 1D value. Another advantage of linearize/delinearize
is that they have simpler upstream canonicalizations and lead to more
streamlined generated code.

This PR updates the vector distribution code and other GPU-related code
that I could find to
1. Use affine.linearize_index to construct flat thread IDs
2. Use affine.delinearize_index in places where there was a floorDiv/mod
   chain.
3. Plumb the subgroup size through the transfer_read and transfer_write
   distribution patterns to enable better reasoning about when you
   do/don't need to take a mod of the lane ID
---
 .../Common/GPU/GPUDistributeForall.cpp        | 37 ++--
 .../GPU/GPUDistributeSharedMemoryCopy.cpp     | 32 ++--
 .../Common/GPU/GPUDistributionPatterns.cpp    | 97 +++++----
 .../compiler/Codegen/Common/GPU/GPUPatterns.h |  1 +
 .../GPU/test/gpu_distribute_forall.mlir       | 63 +++---
 .../test/gpu_distribute_shared_memory.mlir    | 31 ++-
 .../GPU/test/gpu_vector_distribution.mlir     | 33 ++--
 ...ransform_gpu_distribute_shared_memory.mlir | 17 +-
 .../TransformExtensions/CommonExtensions.cpp  | 11 +-
 .../CommonExtensionsOps.td                    |  3 +-
 .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp   | 18 +-
 .../Dialect/GPU/Transforms/Transforms.cpp     |  5 +-
 .../test/distribute_mma_to_lanes.mlir         | 43 ++---
 .../LLVMGPU/LLVMGPUVectorDistribute.cpp       | 28 ++-
 .../TransformExtensions/LLVMGPUExtensions.cpp |  6 +-
 .../LLVMGPUExtensionsOps.td                   |  3 +-
 .../test/ROCDL/pipeline_tile_and_fuse.mlir    |  8 +-
 .../LLVMGPU/test/transpose_pipeline_test.mlir | 181 +++++++++---------
 third_party/llvm-project                      |  2 +-
 19 files changed, 290 insertions(+), 329 deletions(-)

diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp
index 64623462a5269..b9d6e6c09e3c5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 
 namespace mlir::iree_compiler {
 
@@ -87,9 +88,12 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter,
   assert(!(hasThreadMapping && hasWarpMapping));
   Value flatId = linearThreadId;
   if (hasWarpMapping) {
-    OpFoldResult subgroupSizeVal = rewriter.getIndexAttr(subgroupSize);
-    flatId = affine::makeComposedAffineApply(rewriter, loc, d0.floorDiv(d1),
-                                             {flatId, subgroupSizeVal});
+    flatId = rewriter
+                 .create(
+                     loc, flatId,
+                     ArrayRef{flatWorkgroupSize / subgroupSize,
+                              subgroupSize})
+                 .getResult(0);
   }
 
   SmallVector delinSizes;
@@ -190,23 +194,18 @@ void GPUDistributeForallPass::runOnOperation() {
     return signalPassFailure();
   }
 
-  AffineExpr x, y, 
z; - bindSymbols(funcOp.getContext(), x, y, z); - // Compute the linearized thread id. - AffineExpr linearId = - x + workgroupSize[0] * y + workgroupSize[1] * workgroupSize[0] * z; - rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front()); - SmallVector threadGrid = { - rewriter.createOrFold(funcOp.getLoc(), - gpu::Dimension::x), - rewriter.createOrFold(funcOp.getLoc(), - gpu::Dimension::y), - rewriter.createOrFold(funcOp.getLoc(), - gpu::Dimension::z)}; - - Value linearThreadIdVal = affine::makeComposedAffineApply( - rewriter, funcOp.getLoc(), linearId, threadGrid); + SmallVector threadGrid = {rewriter.createOrFold( + funcOp.getLoc(), gpu::Dimension::z), + rewriter.createOrFold( + funcOp.getLoc(), gpu::Dimension::y), + rewriter.createOrFold( + funcOp.getLoc(), gpu::Dimension::x)}; + SmallVector threadGridBasis = {workgroupSize[2], workgroupSize[1], + workgroupSize[0]}; + + Value linearThreadIdVal = rewriter.create( + funcOp.getLoc(), threadGrid, threadGridBasis, /*disjoint=*/true); for (auto forall : forallOps) { rewriter.setInsertionPoint(forall); if (failed(resolveGPUMappedForallOp(rewriter, forall, linearThreadIdVal, diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp index 4610c545e553e..47329c84f1897 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp @@ -189,10 +189,8 @@ SmallVector getIds(OpBuilder &b, Location loc, ArrayRef parallelLoopRanges, Value flatThreadId) { SmallVector infos; - Value id = flatThreadId; - AffineExpr d0 = b.getAffineDimExpr(0); - for (Range r : llvm::reverse(parallelLoopRanges)) { - linalg::ProcInfo info; + SmallVector delinSizes; + for (Range r : parallelLoopRanges) { auto offset = dyn_cast(r.offset); auto stride = dyn_cast(r.stride); auto size = dyn_cast(r.size); @@ -200,19 +198,20 @@ SmallVector getIds(OpBuilder &b, Location loc, int64_t numThreadsDim = (llvm::cast(size).getInt() - llvm::cast(offset).getInt()) / llvm::cast(stride).getInt(); - Value dimId = id; - if (infos.size() != parallelLoopRanges.size() - 1) - dimId = - affine::makeComposedAffineApply(b, loc, d0 % numThreadsDim, {dimId}); + delinSizes.push_back(numThreadsDim); + } + ValueRange dims = + b.create(loc, flatThreadId, delinSizes) + .getResults(); + + for (auto [dimId, numThreadsDim] : llvm::zip_equal(dims, delinSizes)) { + linalg::ProcInfo info; info.procId = dimId; info.nprocs = b.create(loc, numThreadsDim); info.distributionMethod = linalg::DistributionMethod::CyclicNumProcsEqNumIters; infos.push_back(info); - id = affine::makeComposedAffineApply(b, loc, d0.floorDiv(numThreadsDim), - {id}); } - std::reverse(infos.begin(), infos.end()); return infos; } @@ -288,19 +287,16 @@ static Value createFlatId(mlir::FunctionOpInterface funcOp, ArrayRef workgroupSize) { OpBuilder b(funcOp.getFunctionBody()); Type indexType = b.getIndexType(); - AffineExpr d0 = getAffineDimExpr(0, b.getContext()); - AffineExpr d1 = getAffineDimExpr(1, b.getContext()); - AffineExpr d2 = getAffineDimExpr(2, b.getContext()); Value threadX = b.create(funcOp.getLoc(), indexType, gpu::Dimension::x); Value threadY = b.create(funcOp.getLoc(), indexType, gpu::Dimension::y); Value threadZ = b.create(funcOp.getLoc(), indexType, gpu::Dimension::z); - Value flatThreadId = affine::makeComposedAffineApply( - b, funcOp.getLoc(), - d0 + workgroupSize[0] * d1 + 
(workgroupSize[0] * workgroupSize[1]) * d2,
-      {threadX, threadY, threadZ});
+  Value flatThreadId = b.create(
+      funcOp.getLoc(), ValueRange{threadZ, threadY, threadX},
+      ArrayRef{workgroupSize[2], workgroupSize[1], workgroupSize[0]},
+      /*disjoint=*/true);
   return flatThreadId;
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
index 0ef6e64d2c263..5582a63ac5581 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
@@ -32,54 +32,55 @@ namespace {
 /// parameterized by the thread grid.
 static SmallVector computeSIMDIndex(const LayoutIterator::State &state,
                                     LayoutAttr layout, Value laneId,
+                                    int64_t subgroupSize,
                                     RewriterBase &rewriter) {
-  MLIRContext *ctx = layout.getContext();
-  AffineExpr threadX, threadY, threadZ;
-  bindSymbols(ctx, threadX, threadY, threadZ);
+  Location loc = laneId.getLoc();
+
+  auto [laneDimX, laneDimY, laneDimZ] = layout.getLaneGrid();
+  int64_t gridsPerSubgroup =
+      llvm::divideCeil(subgroupSize, laneDimX * laneDimY * laneDimZ);
+  // Note: we add an extra leading entry to the delinearization here to
+  // handle the case where the vector layout requires fewer lanes than are
+  // present in the subgroup. Otherwise we would, for example, construct a
+  // delinearization with the basis (1, 1, 16) when there are 32 lanes, which
+  // would simplify to no delinearization at all. The extra leading term
+  // captures the overflow.
+  auto reversedLaneGrid = rewriter.create(
+      loc, laneId,
+      ArrayRef{gridsPerSubgroup, laneDimZ, laneDimY, laneDimX});
 
   SmallVector simdIndex;
+  // Calculate the index for each dim separately.
   for (PerDimLayoutAttr dimLayout : layout.getLayouts()) {
-    AffineExpr offset = getAffineConstantExpr(0, ctx);
-    AffineExpr stride = getAffineConstantExpr(1, ctx);
-    for (auto [label, shape] : llvm::reverse(
-             llvm::zip(dimLayout.getLabels(), dimLayout.getShapes()))) {
+    SmallVector linearizeVals;
+    for (LayoutDimensionAttr label : dimLayout.getLabels()) {
       int64_t position = state.lookup(label.getValue()).getPosition();
 
+      // Note: indices are into a reversed lane grid that has an extra leading
+      // term we must ignore (so the X coordinate is result #3 and the Z
+      // coordinate is result #1).
switch (label.getValue()) { case LayoutDimension::LANEX: - offset = offset + stride * threadX; + linearizeVals.push_back(reversedLaneGrid.getResult(3)); break; case LayoutDimension::LANEY: - offset = offset + stride * threadY; + linearizeVals.push_back(reversedLaneGrid.getResult(2)); break; case LayoutDimension::LANEZ: - offset = offset + stride * threadZ; + linearizeVals.push_back(reversedLaneGrid.getResult(1)); break; default: - offset = offset + stride * getAffineConstantExpr(position, ctx); + linearizeVals.push_back( + rewriter.createOrFold(loc, position)); break; } - stride = stride * getAffineConstantExpr(shape, ctx); } - auto [laneDimX, laneDimY, laneDimZ] = layout.getLaneGrid(); - SmallVector laneGrid = { - rewriter.create(laneId.getLoc(), laneDimZ), - rewriter.create(laneId.getLoc(), laneDimY), - rewriter.create(laneId.getLoc(), laneDimX)}; - FailureOr> maybeReversedLaneGridVals = - affine::delinearizeIndex(rewriter, laneId.getLoc(), laneId, laneGrid); - assert(succeeded(maybeReversedLaneGridVals) && - "Failed to delinearize lane index"); - SmallVector laneGridVals = {(*maybeReversedLaneGridVals)[2], - (*maybeReversedLaneGridVals)[1], - (*maybeReversedLaneGridVals)[0]}; - // Compute the index for the dim. - AffineMap indexMap = AffineMap::get(0, 3, offset); - Value index = rewriter.create( - rewriter.getUnknownLoc(), indexMap, laneGridVals); + Value index = rewriter.create( + rewriter.getUnknownLoc(), linearizeVals, dimLayout.getShapes(), + /*disjoint=*/true); simdIndex.push_back(index); } @@ -199,8 +200,9 @@ struct DistributeXferLayoutAttr : OpDistributionPattern { "expected vector::TransferReadOp or vector::TransferWriteOp"); DistributeXferLayoutAttr(MLIRContext *context, Value laneId, - PatternBenefit benefit = 1) - : OpDistributionPattern(context, benefit), laneId(laneId) {} + int64_t subgroupSize, PatternBenefit benefit = 1) + : OpDistributionPattern(context, benefit), laneId(laneId), + subgroupSize(subgroupSize) {} VectorValue accessMemory(OpTy xferOp, VectorValue accumulator, LayoutAttr vectorLayout, @@ -237,7 +239,7 @@ struct DistributeXferLayoutAttr : OpDistributionPattern { llvm::SmallBitVector &projectedDims, RewriterBase &rewriter) const { SmallVector simdIndices = - computeSIMDIndex(state, memoryLayout, laneId, rewriter); + computeSIMDIndex(state, memoryLayout, laneId, subgroupSize, rewriter); SmallVector memoryIndices(indices); // The memory layout has some projected leading dims that indices doesn't. @@ -272,6 +274,7 @@ struct DistributeXferLayoutAttr : OpDistributionPattern { } Value laneId; + int64_t subgroupSize; }; struct DistributeTransferReadLayoutAttr final @@ -940,9 +943,6 @@ struct DistributeLayoutConflictToSharedMemory final // Offset and indexing computation such that subgroups can // write and read to shared memory correctly and without conflicts. 
- AffineExpr d0, d1, d2, s0; - bindDims(rewriter.getContext(), d0, d1, d2); - bindSymbols(rewriter.getContext(), s0); auto indexType = rewriter.getIndexType(); Value threadX = rewriter.create(loc, indexType, gpu::Dimension::x); @@ -950,16 +950,21 @@ struct DistributeLayoutConflictToSharedMemory final rewriter.create(loc, indexType, gpu::Dimension::y); Value threadZ = rewriter.create(loc, indexType, gpu::Dimension::z); - Value flatThreadId = affine::makeComposedAffineApply( - rewriter, loc, - (d0 + workgroupSize.value()[0] * d1 + - (workgroupSize.value()[0] * workgroupSize.value()[1]) * d2), - {threadX, threadY, threadZ}); - Value subgroupOffset = affine::makeComposedAffineApply( - rewriter, loc, - s0.floorDiv(subgroupSize.value()) * - resolutionType.getShape()[vectorRank - 1], - {flatThreadId}); + Value flatThreadId = rewriter.create( + loc, ValueRange{threadZ, threadY, threadX}, + ArrayRef{workgroupSize.value()[2], workgroupSize.value()[1], + workgroupSize.value()[0]}, + /*disjoint=*/true); + + Value c0 = rewriter.create(loc, 0); + auto splitBySubgroups = rewriter.create( + loc, flatThreadId, + ArrayRef{numSubgroups, subgroupSize.value()}); + Value subgroupOffset = rewriter.create( + loc, ValueRange{splitBySubgroups.getResult(0), c0}, + ArrayRef{numSubgroups, + resolutionType.getShape()[vectorRank - 1]}, + /*disjoint=*/true); // Create shared memory to store the intermediate from src layout. auto workgroupMemoryAddressSpace = Attribute(gpu::AddressSpaceAttr::get( @@ -980,7 +985,6 @@ struct DistributeLayoutConflictToSharedMemory final shapes, strides); // Creating write/trip to shared memory using src layout. - Value c0 = rewriter.create(loc, 0); SmallVector indices(resolutionType.getRank(), c0); SmallVector inBounds(vectorRank, true); auto write = rewriter.create(loc, vector, subview, @@ -1117,10 +1121,11 @@ void populateGPUDistributionPatterns(RewritePatternSet &patterns) { } void populateGPUDistributionLayoutAttrPatterns(Value laneId, + int64_t subgroupSize, RewritePatternSet &patterns) { patterns .add( - patterns.getContext(), laneId); + patterns.getContext(), laneId, subgroupSize); patterns.add( patterns.getContext()); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h index 87303844853ff..4044f7a46efaa 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h @@ -32,6 +32,7 @@ void populateDropSharedMemoryDeallocOpPatterns(RewritePatternSet &patterns); void populateGPUDistributionPatterns(RewritePatternSet &patterns); void populateGPUDistributionLayoutAttrPatterns(Value laneId, + int64_t subgroupSize, RewritePatternSet &patterns); void populateGPUReductionDistributionPatterns(RewritePatternSet &patterns, diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir index 214337437b763..32bda8c90f050 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir @@ -15,11 +15,9 @@ func.func @distribute_thread_forall(%out : memref) // CHECK-LABEL: func @distribute_thread_forall // CHECK-DAG: %[[TX:.+]] = gpu.thread_id x // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z +// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64) // CHECK: 
scf.for %[[I:.+]] = %c0 to %c1024 step %c128 { -// CHECK: %[[LINID:.+]] = affine.apply -// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s0 + s1 * 64 + s2 * 128)>(%[[I]]) -// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]] +// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]] // CHECK: memref.store {{.*}}[%[[LINID]]] // ----- @@ -38,11 +36,10 @@ func.func @distribute_warp_forall(%out : memref) // CHECK-LABEL: func @distribute_warp_forall // CHECK-DAG: %[[TX:.+]] = gpu.thread_id x // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z +// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64) +// CHECK: %[[WARPSPLIT:.+]]:2 = affine.delinearize_index %[[TFLAT]] into (4, 32) // CHECK: scf.for %[[I:.+]] = %c0 to %c32 step %c4 { -// CHECK: %[[LINID:.+]] = affine.apply -// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s1 * 2 + s2 * 4 + s0 floordiv 32)>(%[[I]]) -// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]] +// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[WARPSPLIT]]#0] // CHECK: memref.store {{.*}}[%[[LINID]]] // ----- @@ -78,11 +75,7 @@ func.func @distribute_thread_forall_drop_for_loop(%out : memref) // CHECK-LABEL: func @distribute_thread_forall_drop_for_loop // CHECK-DAG: %[[TX:.+]] = gpu.thread_id x // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z -// CHECK-NOT: scf.for -// CHECK: %[[LINID:.+]] = affine.apply -// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 64 + s2 * 128)> -// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]] +// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64) // CHECK: memref.store {{.*}}[%[[LINID]]] // ----- @@ -99,13 +92,32 @@ func.func @distribute_thread_forall_single_thread(%out : memref) } // CHECK-LABEL: func @distribute_thread_forall_single_thread +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[TX:.+]] = gpu.thread_id x // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z -// CHECK: %[[LINID:.+]] = affine.apply -// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 64 + s2 * 128)> -// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]] -// CHECK: scf.for %[[I:.+]] = %[[LINID]] to %c1 step %c128 { +// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64) +// CHECK: scf.for %[[I:.+]] = %[[TFLAT]] to %c1 step %c128 { +// CHECK: memref.store {{.*}}[%[[I]]] + +// ----- + +#translation_info = #iree_codegen.translation_info + +func.func @distribute_thread_forall_overhang(%out : memref) + attributes {translation_info = #translation_info} { + %c0 = arith.constant 0 : i32 + scf.forall (%arg0) in (513) { + memref.store %c0, %out[%arg0] : memref + } {mapping = [#gpu.thread]} + return +} + +// CHECK-LABEL: func @distribute_thread_forall_overhang +// CHECK-DAG: %[[C513:.+]] = arith.constant 513 : index +// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x +// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y +// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64) +// CHECK: scf.for %[[I:.+]] = %[[TFLAT]] to %[[C513]] step %c128 { // CHECK: memref.store {{.*}}[%[[I]]] // ----- @@ -124,11 +136,9 @@ func.func @distribute_thread_forall_multi_dim(%out : memref) // CHECK-LABEL: func @distribute_thread_forall_multi_dim // CHECK-DAG: %[[TX:.+]] = gpu.thread_id x // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z +// CHECK: %[[TFLAT:.+]] = affine.linearize_index 
disjoint [%[[TY]], %[[TX]]] by (2, 64) // CHECK: scf.for %[[I:.+]] = %c0 to %c512 step %c128 { -// CHECK: %[[LINID:.+]] = affine.apply -// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s0 + s1 * 64 + s2 * 128)>(%[[I]]) -// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]] +// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]] // CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LINID]] into (16, 8, 4) : index // CHECK: memref.store {{.*}}[%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2] @@ -147,10 +157,5 @@ func.func @distribute_thread_forall_small_workgroup(%out : memref) } // CHECK-LABEL: func @distribute_thread_forall_small_workgroup -// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x -// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z -// CHECK: %[[LINID:.+]] = affine.apply -// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 7 + s2 * 7)> -// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]] -// CHECK: memref.store {{.*}}[%[[LINID]]] +// CHECK: %[[TX:.+]] = gpu.thread_id x +// CHECK: memref.store {{.*}}[%[[TX]]] diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_shared_memory.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_shared_memory.mlir index 636add66dd0d6..8f526bd4dd91d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_shared_memory.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_shared_memory.mlir @@ -49,12 +49,9 @@ module { } } -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16)> -// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4 + 32)> -// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 128)> -// CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 128 + 128)> -// CHECK-DAG: #[[$MAP5:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 4 + s1 * 128 + s2 * 512)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 4)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 32)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 128)> // CHECK-LABEL: @shared_mem_cpy( // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index @@ -62,24 +59,22 @@ module { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[TX:.*]] = gpu.thread_id x // CHECK-DAG: %[[TY:.*]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.*]] = gpu.thread_id z - -// CHECK-DAG: %[[Y0:.*]] = affine.apply #[[$MAP0]]()[%[[TX]], %[[TY]], %[[TZ]]] -// CHECK-DAG: %[[X0:.*]] = affine.apply #[[$MAP1]]()[%[[TX]]] -// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[Y0]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32> -// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[Y0]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3> -// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP2]]()[%[[TX]], %[[TY]], %[[TZ]]] +// CHECK: %[[TFLAT:.*]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (4, 32) +// CHECK: %[[YX:.*]]:2 = affine.delinearize_index %[[TFLAT]] into (32, 4) +// CHECK: %[[X0:.*]] = affine.apply #[[$MAP0]]()[%[[YX]]#1] +// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[YX]]#0, %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32> +// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[YX]]#0, %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, 
memref<64x16xf32, 3> +// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP1]]()[%[[YX]]#0] // CHECK: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32> // CHECK: vector.transfer_write %[[R1]], %{{.*}}[%[[Y1]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3> -// CHECK: %[[Y1:.*]] = affine.apply #[[$MAP3]]()[%[[TX]], %[[TY]], %[[TZ]]] -// CHECK: %[[R2:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32> -// CHECK: vector.transfer_write %[[R2]], %{{.*}}[%[[Y1]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3> -// CHECK: %[[Y2:.*]] = affine.apply #[[$MAP4]]()[%[[TX]], %[[TY]], %[[TZ]]] +// CHECK: %[[R2:.*]] = vector.transfer_read %{{.*}}[%[[TFLAT]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32> +// CHECK: vector.transfer_write %[[R2]], %{{.*}}[%[[TFLAT]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3> +// CHECK: %[[Y2:.*]] = affine.apply #[[$MAP2]]()[%[[TFLAT]]] // CHECK: %[[R3:.*]] = vector.transfer_read %{{.*}}[%[[Y2]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32> // CHECK: vector.transfer_write %[[R3]], %{{.*}}[%[[Y2]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3> -// CHECK: %[[X1:.*]] = affine.apply #[[$MAP5]]()[%[[TX]], %[[TY]], %[[TZ]]] +// CHECK: %[[X1:.*]] = affine.apply #[[$MAP0]]()[%[[TFLAT]]] // CHECK: %[[R4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[X1]]], %{{.*}} {in_bounds = [true, true]} : memref<3x512xf32>, vector<1x4xf32> // CHECK: vector.transfer_write %[[R4]], %{{.*}}[%[[C0]], %[[X1]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3> // CHECK: %[[R5:.*]] = vector.transfer_read %{{.*}}[%[[C1]], %[[X1]]], %{{.*}} {in_bounds = [true, true]} : memref<3x512xf32>, vector<1x4xf32> diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir index c392eb7835810..ad6cd9d48c671 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir @@ -197,9 +197,6 @@ builtin.module attributes { transform.with_named_sequence } { #layout_row_major = #iree_vector_ext.layout<<[BATCHX, LANEY], [2, 8]>, <[BATCHY, LANEX, VECTORX], [2, 1, 8]>> #layout_col_major = #iree_vector_ext.layout<<[BATCHX, LANEY, VECTORX], [1, 4, 4]>, <[BATCHY, LANEX], [2, 8]>> -// TODO: Use affine min tricks based on the grid size to elide the mod. -// Note that this IR is invalid if subgroup size != 8. 
- func.func @distribute_transfer_write_row_major(%root: vector<16x16xf16>, %alloc: memref<64x64xf16>) { %c0 = arith.constant 0 : index %rootl = iree_vector_ext.to_layout %root to layout(#layout_row_major) : vector<16x16xf16> @@ -208,24 +205,23 @@ func.func @distribute_transfer_write_row_major(%root: vector<16x16xf16>, %alloc: : vector<16x16xf16>, memref<64x64xf16> func.return } -// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 mod 8)> -// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 mod 8 + 8)> // CHECK-LABEL: @distribute_transfer_write_row_major // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index // CHECK-DAG: %[[LANEID:.+]] = gpu.thread_id x -// CHECK: %[[VEC_LANE_Y:.+]] = affine.apply #[[$MAP0]]()[%[[LANEID]]] +// CHECK: %[[SPLIT_ID:.+]]:2 = affine.delinearize_index %[[LANEID]] into (8, 8) // CHECK: %[[DIST_SRC_VEC:.+]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<2x2x8xf16> // CHECK: %[[BATCH_0_0:.+]] = vector.extract %[[DIST_SRC_VEC]][0, 0] : vector<8xf16> from vector<2x2x8xf16> -// CHECK: vector.store %[[BATCH_0_0]], %{{.*}}[%[[VEC_LANE_Y]], %[[C0]]] : memref<64x64xf16>, vector<8xf16> +// CHECK: vector.store %[[BATCH_0_0]], %{{.*}}[%[[SPLIT_ID]]#1, %[[C0]]] : memref<64x64xf16>, vector<8xf16> -// CHECK: %[[NEXT_VEC_LANE_Y:.+]] = affine.apply #[[$MAP1]]()[%[[LANEID]]] +// CHECK: %[[NEXT_VEC_LANE_Y:.+]] = affine.linearize_index disjoint [%[[C1]], %[[SPLIT_ID]]#1] by (2, 8) : index // CHECK: %[[BATCH_1_0:.+]] = vector.extract %[[DIST_SRC_VEC]][1, 0] : vector<8xf16> from vector<2x2x8xf16> // CHECK: vector.store %[[BATCH_1_0]], %{{.*}}[%[[NEXT_VEC_LANE_Y]], %[[C0]]] : memref<64x64xf16>, vector<8xf16> // CHECK: %[[BATCH_0_1:.+]] = vector.extract %[[DIST_SRC_VEC]][0, 1] : vector<8xf16> from vector<2x2x8xf16> -// CHECK: vector.store %[[BATCH_0_1]], %{{.*}}[%[[VEC_LANE_Y]], %[[C8]]] : memref<64x64xf16>, vector<8xf16> +// CHECK: vector.store %[[BATCH_0_1]], %{{.*}}[%[[SPLIT_ID]]#1, %[[C8]]] : memref<64x64xf16>, vector<8xf16> // CHECK: %[[BATCH_1_1:.+]] = vector.extract %[[DIST_SRC_VEC]][1, 1] : vector<8xf16> from vector<2x2x8xf16> // CHECK: vector.store %[[BATCH_1_1]], %{{.*}}[%[[NEXT_VEC_LANE_Y]], %[[C8]]] : memref<64x64xf16>, vector<8xf16> @@ -560,8 +556,6 @@ builtin.module attributes { transform.with_named_sequence } { #layoutB2 = #iree_vector_ext.layout<<[ BATCHX, LANEY, VECTORX], [1, 1, 16]>, <[ BATCHY, LANEX], [1, 16]>> #layoutC2 = #iree_vector_ext.layout<<[ BATCHX, VECTORY, LANEY, VECTORX], [1, 8, 2, 1]>, <[ BATCHY, LANEX], [1, 16]>> -// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + (s0 floordiv 32) * 16)> -// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 mod 16)> // CHECK-LABEL: func.func @resolve_wmma_layout_conflict_with_shared_memory func.func @resolve_wmma_layout_conflict_with_shared_memory(%15 : vector<16x16xf16>, %14 : vector<16x16xf16>, @@ -607,19 +601,18 @@ func.func @resolve_wmma_layout_conflict_with_shared_memory(%15 : vector<16x16xf1 // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[VEC_INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x16xf16 -// CHECK: %[[VEC_INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x16xf16 -// CHECK: %[[TID_X:.+]] = gpu.thread_id x -// CHECK: %[[TID_Y:.+]] = gpu.thread_id y -// CHECK: %[[TID_Z:.+]] = gpu.thread_id z -// CHECK: 
%[[SUBGROUP_OFFSET:.+]] = affine.apply #[[$MAP0]]()[%[[TID_X]], %[[TID_Y]], %[[TID_Z]]] +// CHECK: %[[TIDX:.+]] = gpu.thread_id x +// CHECK: %[[TIDY:.+]] = gpu.thread_id y +// CHECK: %[[SUBGROUP_OFFSET:.+]] = affine.linearize_index disjoint [%[[TIDY]], %[[C0]]] by (2, 16) // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<16x32xf16, #gpu.address_space> // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][0, %[[SUBGROUP_OFFSET]]] [16, 16] [1, 1] -// CHECK: %[[HALF_LANE_ID:.+]] = affine.apply #[[$MAP1]]()[%[[TID_X]]] -// CHECK-COUNT-8: vector.store %{{.+}}, %[[SUBVIEW]][%{{.+}}, %[[HALF_LANE_ID]]] +// CHECK: %[[SPLIT_LANE_ID:.+]]:2 = affine.delinearize_index %[[TIDX]] into (2, 16) +// CHECK-COUNT-8: vector.store %{{.+}}, %[[SUBVIEW]][%{{.+}}, %[[SPLIT_LANE_ID]]#1] // CHECK-AFTER: gpu.barrier -// CHECK: %[[LANE_OFFSET:.+]] = arith.addi %[[SUBGROUP_OFFSET]], %[[HALF_LANE_ID]] +// CHECK: %[[LANE_OFFSET:.+]] = arith.addi %[[SUBGROUP_OFFSET]], %[[SPLIT_LANE_ID]]#1 // CHECK: %[[LOAD0:.+]] = vector.load %[[ALLOC]][%[[C0]], %[[LANE_OFFSET]]] // CHECK: %[[INSERT0:.+]] = vector.insert_strided_slice %[[LOAD0]], %[[VEC_INIT]] {offsets = [0, 0, 0], strides = [1]} : vector<1xf16> into vector<1x1x16xf16> // CHECK: %[[LOAD1:.+]] = vector.load %[[ALLOC]][%[[C1]], %[[LANE_OFFSET]]] @@ -636,7 +629,7 @@ func.func @resolve_wmma_layout_conflict_with_shared_memory(%15 : vector<16x16xf1 builtin.module attributes { transform.with_named_sequence } { transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op - transform.iree.test_gpu_vector_distribution %top_level_func {experimental = true} : !transform.any_op + transform.iree.test_gpu_vector_distribution %top_level_func {experimental = true, subgroup_size = 32 : i64} : !transform.any_op transform.yield } } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/transform_gpu_distribute_shared_memory.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/transform_gpu_distribute_shared_memory.mlir index ec765a1d5aa66..907070a35c5ce 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/transform_gpu_distribute_shared_memory.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/transform_gpu_distribute_shared_memory.mlir @@ -46,20 +46,19 @@ module attributes {transform.with_named_sequence} { transform.yield } } -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16)> -// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4 + 32)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 4)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 32)> // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: @shared_mem_cpy( // CHECK-DAG: %[[TX:.*]] = gpu.thread_id x // CHECK-DAG: %[[TY:.*]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.*]] = gpu.thread_id z -// CHECK-DAG: %[[Y0:.*]] = affine.apply #[[$MAP0]]()[%[[TX]], %[[TY]], %[[TZ]]] -// CHECK-DAG: %[[X0:.*]] = affine.apply #[[$MAP1]]()[%[[TX]]] -// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[Y0]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32, #hal.descriptor_type>, vector<1x4xf32> -// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[Y0]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 
#gpu.address_space> -// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP2]]()[%[[TX]], %[[TY]], %[[TZ]]] +// CHECK-DAG: %[[TFLAT:.*]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (4, 32) +// CHECK-DAG: %[[YX:.*]]:2 = affine.delinearize_index %[[TFLAT]] into (32, 4) +// CHECK-DAG: %[[X0:.*]] = affine.apply #[[$MAP0]]()[%[[YX]]#1] +// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[YX]]#0, %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32, #hal.descriptor_type>, vector<1x4xf32> +// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[YX]]#0, %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, #gpu.address_space> +// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP1]]()[%[[YX]]#0] // CHECK: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32, #hal.descriptor_type>, vector<1x4xf32> // CHECK: vector.transfer_write %[[R1]], %{{.*}}[%[[Y1]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, #gpu.address_space> // CHECK: linalg.generic diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index ac8ae7386f55f..d8c5f1c8081f3 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -1113,18 +1113,17 @@ transform_dialect::TestGpuVectorDistribution::applyToOne( rewriter.setInsertionPointToStart(&target.getFunctionBody().front()); // This is a test op so we unsafely use thread_id x as the lane ID. In // general this should linearize the thread IDs based on the workgroup size - // and divide by the subgroup size. i.e. + // and take the modulo by the subgroup size. i.e. // - // lane_id = (tid_x + tid_y * dim_x + tid_z * dim_y * dim_x) / subgroup_size; + // lane_id = (tid_x + tid_y * dim_x + tid_z * dim_y * dim_x) % subgroup_size; Value laneId = rewriter.create(target.getLoc(), gpu::Dimension::x); + int64_t subgroupSize = getSubgroupSize(); populateGPUDistributionPatterns(patterns); - populateGPUDistributionLayoutAttrPatterns(laneId, patterns); + populateGPUDistributionLayoutAttrPatterns(laneId, subgroupSize, patterns); populateGPUReductionDistributionPatterns(patterns); - // For testing we use subgroup size = 64. 
- populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, - /*subgroupSize=*/64); + populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, subgroupSize); populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns); if (getExperimental()) populateGPULayoutResolutionDistributionPatterns(patterns); diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td index 5219b4a2da9c9..0c05178043c88 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td @@ -631,7 +631,8 @@ def TestGpuVectorDistribution : }]; let arguments = (ins TransformHandleTypeInterface:$target, - DefaultValuedOptionalAttr:$experimental); + DefaultValuedOptionalAttr:$experimental, + DefaultValuedOptionalAttr:$subgroup_size); let results = (outs); let assemblyFormat = [{ $target attr-dict `:` type($target)}]; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 4f9b1ffab3b6d..bbb370283d648 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -1008,21 +1008,19 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides( getSubgroupSize() / intrinsicLayoutThreadBound); } - // AffineDelinearizeIndexOp requires an in-bounds input index, so we bound it. - OpFoldResult threadIdBound = - builder.getIndexAttr(ShapedType::getNumElements(distributionThreadSizes)); - AffineExpr d0 = builder.getAffineDimExpr(0), d1 = builder.getAffineDimExpr(1); - OpFoldResult boundedThreadId = affine::makeComposedFoldedAffineApply( - builder, loc, {d0 % d1}, {threadId, threadIdBound}); + // Add a `0` at the front of the distribution sizes so that + // `affine.delinearize_index` clamp its output (we'll throw away the first + // result). + distributionThreadSizes.insert(distributionThreadSizes.begin(), 0); // Obtain the offsets from delinearization along the distributionThreadSizes. SmallVector tileOffsets = builder .create( - loc, - getValueOrCreateConstantIndexOp(builder, loc, boundedThreadId), - getAsIndexOpFoldResult(ctx, distributionThreadSizes)) - ->getResults(); + loc, getValueOrCreateConstantIndexOp(builder, loc, threadId), + distributionThreadSizes) + ->getResults() + .drop_front(); if (hasDistributionOnlyDim) { // Erase the delinearized index that corresponds to the extra distribution diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp index 73774ba09c3c1..67ba2b8252477 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp @@ -209,11 +209,10 @@ LogicalResult fuseForallIntoConsumer(RewriterBase &rewriter, // Compute the total producer loop worker count (P0 * ... * Pn). 
Value linearConsumerIdVal = getValueOrCreateConstantIndexOp(rewriter, loc, linearId); - SmallVector producerRanges; + SmallVector producerRanges; OpFoldResult producerWorkerCount = rewriter.getIndexAttr(1); for (auto workerCount : producer.getMixedUpperBound()) { - producerRanges.push_back( - getValueOrCreateConstantIndexOp(rewriter, loc, workerCount)); + producerRanges.push_back(workerCount); producerWorkerCount = affine::makeComposedFoldedAffineApply( rewriter, loc, d0 * d1, {producerWorkerCount, workerCount}); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir index 27215dd770902..8e9e124e90f6e 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir @@ -387,24 +387,21 @@ func.func @data_tiled_1x1x1_tensor_multi_mma(%lhs: tensor<1x1x4x16xf32>, %rhs: t return %0 : tensor<1x1x4x16x4xf32> } -// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 mod 64)> - // CHECK-LABEL: func @data_tiled_1x1x1_tensor_multi_mma // CHECK-SAME: %[[LHS:[A-Za-z0-9]+]] // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]] // CHECK-SAME: %[[ACC:[A-Za-z0-9]+]] // CHECK: scf.forall (%[[THREAD_ID:.+]]) in (64) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x4x16x4xf32>) -// CHECK: %[[ID_CLAMPED:.+]] = affine.apply #[[$MAP]](%[[THREAD_ID]]) -// CHECK-DAG: %[[IN_IDS:.+]]:2 = affine.delinearize_index %[[ID_CLAMPED]] into (4, 16) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1] [1, 1, 1, 1] [1, 1, 1, 1] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1] [1, 1, 1, 1] [1, 1, 1, 1] +// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (0, 4, 16) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2] [1, 1, 1, 1] [1, 1, 1, 1] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2] [1, 1, 1, 1] [1, 1, 1, 1] // CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]] -// CHECK-SAME: [0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1] // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]] // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout // CHECK-SAME: : tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32> into tensor<1x1x1x1x4xf32> // CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC_ARG]] -// CHECK-SAME: [0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1] // CHECK: mapping = [#gpu.thread] // ----- @@ -424,26 +421,23 @@ func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled(%lhs: tensor<1x1x2x4x16x4x return %0 : tensor<1x1x2x2x4x16x4xf32> } -// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 mod 64)> - // CHECK-LABEL: func @data_tiled_2x2x4_tensor_multi_mma_unrolled // CHECK-SAME: %[[LHS:[A-Za-z0-9]+]] // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]] // CHECK-SAME: %[[ACC:[A-Za-z0-9]+]] // CHECK: scf.forall (%[[THREAD_ID:.+]]) in (64) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x2x2x4x16x4xf32>) -// CHECK: %[[ID_CLAMPED:.+]] = affine.apply #[[$MAP]](%[[THREAD_ID]]) -// CHECK-DAG: 
%[[IN_IDS:.+]]:2 = affine.delinearize_index %[[ID_CLAMPED]] into (4, 16) +// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (0, 4, 16) // CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]] -// CHECK-SAME: [0, 0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, 0] [1, 1, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1] // CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]] -// CHECK-SAME: [0, 0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, 0] [1, 1, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1] // CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]] -// CHECK-SAME: [0, 0, 0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, 0] [1, 1, 2, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, 0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 2, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]] // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout // CHECK-SAME: : tensor<1x1x2x1x1x4xf32>, tensor<1x1x2x1x1x4xf32> into tensor<1x1x2x2x1x1x4xf32> // CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC_ARG]] -// CHECK-SAME: [0, 0, 0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, 0] [1, 1, 2, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, 0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 2, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] // CHECK: mapping = [#gpu.thread] // ----- @@ -463,27 +457,22 @@ func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled_to_subgroups(%lhs: tensor< return %0 : tensor<1x1x2x2x4x16x4xf32> } -// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 mod 128)> -// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 mod 256)> - // CHECK-LABEL: func @data_tiled_2x2x4_tensor_multi_mma_unrolled_to_subgroups // CHECK-SAME: %[[LHS:[A-Za-z0-9]+]] // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]] // CHECK-SAME: %[[ACC:[A-Za-z0-9]+]] // CHECK: scf.forall (%[[THREAD_ID:.+]]) in (256) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x2x2x4x16x4xf32>) -// CHECK: %[[ID_CLAMPED_128:.+]] = affine.apply #[[$MAP]](%[[THREAD_ID]]) -// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[ID_CLAMPED_128]] into (2, 4, 16) +// CHECK-DAG: %[[IN_IDS:.+]]:4 = affine.delinearize_index %[[THREAD_ID]] into (0, 2, 4, 16) // CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]] -// CHECK-SAME: [0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, %[[IN_IDS]]#3, 0] [1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1] // CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]] -// CHECK-SAME: [0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1] -// CHECK: %[[ID_CLAMPED_256:.+]] = affine.apply #[[$MAP1]](%[[THREAD_ID]]) -// CHECK-DAG: %[[ACC_IDS:.+]]:4 = affine.delinearize_index %[[ID_CLAMPED_256]] into (2, 2, 4, 16) +// CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, %[[IN_IDS]]#3, 0] [1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1] +// CHECK-DAG: %[[ACC_IDS:.+]]:5 = affine.delinearize_index %[[THREAD_ID]] into (0, 2, 2, 4, 16) // CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]] -// CHECK-SAME: [0, 0, %[[ACC_IDS]]#0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, %[[ACC_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] // CHECK: 
%[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]] // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout} // CHECK-SAME: : tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x1x1x4xf32> into tensor<1x1x1x1x1x1x4xf32> // CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC_ARG]] -// CHECK-SAME: [0, 0, %[[ACC_IDS]]#0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, %[[ACC_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] // CHECK: mapping = [#gpu.thread] diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp index 466d7bd1bf805..54b5612b5d1e9 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp @@ -34,7 +34,7 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions { int64_t subgroupSize) : VectorLayoutOptions(root), patterns(root->getContext()) { populateGPUDistributionPatterns(patterns); - populateGPUDistributionLayoutAttrPatterns(laneId, patterns); + populateGPUDistributionLayoutAttrPatterns(laneId, subgroupSize, patterns); populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, subgroupSize); populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns); @@ -81,24 +81,18 @@ struct LLVMGPUVectorDistributePass final } } - AffineExpr x, y, z; - bindSymbols(func.getContext(), x, y, z); - // Construct the expression for linearizing the thread indices. - AffineExpr linearId = - x + workgroupSize[0] * y + workgroupSize[1] * workgroupSize[0] * z; - IRRewriter rewriter(func); rewriter.setInsertionPointToStart(&func.getFunctionBody().front()); - SmallVector threadGrid = { - rewriter.createOrFold(func.getLoc(), - gpu::Dimension::x), - rewriter.createOrFold(func.getLoc(), - gpu::Dimension::y), - rewriter.createOrFold(func.getLoc(), - gpu::Dimension::z)}; - - Value linearThreadIdVal = affine::makeComposedAffineApply( - rewriter, func.getLoc(), linearId, threadGrid); + SmallVector threadGrid = {rewriter.createOrFold( + func.getLoc(), gpu::Dimension::z), + rewriter.createOrFold( + func.getLoc(), gpu::Dimension::y), + rewriter.createOrFold( + func.getLoc(), gpu::Dimension::x)}; + std::reverse(workgroupSize.begin(), workgroupSize.end()); + + Value linearThreadIdVal = rewriter.create( + func.getLoc(), threadGrid, workgroupSize, /*disjoint=*/true); std::optional subgroupSize = getSubgroupSize(func); if (!subgroupSize) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index 3dd0c128008e8..beca786675a48 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -1473,13 +1473,13 @@ transform_dialect::AMDGPUDistributeVectorsOp::applyToOne( rewriter.setInsertionPointToStart(&target.getFunctionBody().front()); Value laneId = rewriter.create(target.getLoc(), gpu::Dimension::x); + int64_t subgroupSize = getSubgroupSize(); populateGPUDistributionPatterns(patterns); - populateGPUDistributionLayoutAttrPatterns(laneId, patterns); + populateGPUDistributionLayoutAttrPatterns(laneId, subgroupSize, patterns); populateGPUReductionDistributionPatterns(patterns); // For testing we use 
subgroup size = 64. - populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, - /*subgroupSize=*/64); + populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, subgroupSize); populateAMDGPUDistributionPatterns(patterns); populateGPULayoutResolutionDistributionPatterns(patterns); if (failed(distributeVectorOps(target, patterns, options))) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td index ac3e7eef75136..4f1a8c8f163e2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td @@ -699,7 +699,8 @@ def AMDGPUDistributeVectorsOp : }]; let arguments = (ins TransformHandleTypeInterface:$target, - UnitAttr:$test_conversion); + UnitAttr:$test_conversion, + DefaultValuedOptionalAttr:$subgroup_size); let results = (outs TransformHandleTypeInterface:$result); let assemblyFormat = [{ diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir index f05bb7c32991c..6d8695b372c63 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir @@ -544,17 +544,13 @@ hal.executable public @main { } } -// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 8 + s2 * 32)> -// CHECK: #[[$MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + s1)> +// CHECK: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + s1)> // CHECK-LABEL: func @skinny_matmul_config // CHECK-DAG: %[[IDX:.+]] = gpu.thread_id x // CHECK-DAG: %[[IDY:.+]] = gpu.thread_id y -// CHECK-DAG: %[[IDZ:.+]] = gpu.thread_id z -// CHECK: %[[LINID0:.+]] = affine.apply #[[$MAP]]()[%[[IDX]], %[[IDY]], %[[IDZ]]] -// CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %[[LINID0:.+]] into (4, 8) : index, index -// CHECK: %[[LINID1:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#0, %[[IDS]]#1] +// CHECK: %[[LINID1:.+]] = affine.apply #[[$MAP0]]()[%[[IDY]], %[[IDX]]] // CHECK: scf.forall ({{.*}}) in (32, 98) { // CHECK: scf.for %{{.*}} = %c0 to %c256 step %c4 {{.*}} -> (vector<1x4xf32>) // CHECK: scf.for %{{.*}} = %[[LINID1]] to %c4 step %c32 diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir index 8aa87740b0578..5e0288fbfe711 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir @@ -34,28 +34,28 @@ hal.executable @transpose_dispatch_0 { // CHECK-LABEL: hal.executable public @transpose_dispatch_0 // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[D0:.*]] = gpu.thread_id x -// CHECK-DAG: %[[D1:.*]] = gpu.thread_id y -// CHECK-DAG: %[[D2:.*]] = gpu.thread_id z -// CHECK-DAG: %[[D3:.*]] = memref.alloc() : memref<32x33xf32, #gpu.address_space> -// CHECK: %[[D4:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<4096x4096xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D4]], 64 : memref<4096x4096xf32, #hal.descriptor_type> -// CHECK: %[[D5:.*]] = 
hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<4096x4096xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D5]], 64 : memref<4096x4096xf32, #hal.descriptor_type> +// CHECK-DAG: %[[TX:.*]] = gpu.thread_id x +// CHECK-DAG: %[[TY:.*]] = gpu.thread_id y +// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() : memref<32x33xf32, #gpu.address_space> +// CHECK: %[[D0:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<4096x4096xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D0]], 64 : memref<4096x4096xf32, #hal.descriptor_type> +// CHECK: %[[D1:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<4096x4096xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D1]], 64 : memref<4096x4096xf32, #hal.descriptor_type> +// CHECK-DAG: %[[WG1:.*]] = hal.interface.workgroup.id[1] : index +// CHECK-DAG: %[[WG0:.*]] = hal.interface.workgroup.id[0] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index // CHECK: gpu.barrier -// CHECK: %[[D6:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D7:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: %[[D8:.*]] = vector.transfer_read %[[D4]][%[[D6]], %[[D7]]], %[[CST]] {in_bounds = [true, true]} : memref<4096x4096xf32, #hal.descriptor_type>, vector<1x4xf32> -// CHECK: %[[D9:.*]] = affine.apply #{{.*}}()[%[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D10:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: vector.transfer_write %[[D8]], %[[D3]][%[[D9]], %[[D10]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<32x33xf32, #gpu.address_space> +// CHECK: %[[D2:.*]] = affine.apply #{{.*}}()[%[[TY]], %[[WG0]]] +// CHECK: %[[D3:.*]] = affine.apply #{{.*}}()[%[[WG1]], %[[TX]]] +// CHECK: %[[D4:.*]] = vector.transfer_read %[[D0]][%[[D2]], %[[D3]]], %[[CST]] {in_bounds = [true, true]} : memref<4096x4096xf32, #hal.descriptor_type>, vector<1x4xf32> +// CHECK: %[[D5:.*]] = affine.apply #{{.*}}()[%[[TX]]] +// CHECK: vector.transfer_write %[[D4]], %[[ALLOC]][%[[TY]], %[[D5]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<32x33xf32, #gpu.address_space> // CHECK: gpu.barrier -// CHECK: %[[D11:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: %[[D12:.*]] = vector.transfer_read %[[D3]][%[[D11]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, #gpu.address_space>, vector<4x1xf32> -// CHECK: %[[D13:.*]] = vector.shape_cast %[[D12]] : vector<4x1xf32> to vector<4xf32> -// CHECK: %[[D15:.*]] = affine.apply #{{.*}}()[%[[D1]], %{{.*}}] -// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: vector.transfer_write %[[D13]], %[[D5]][%[[D15]], %[[D16]]] {in_bounds = [true]} : vector<4xf32>, memref<4096x4096xf32, #hal.descriptor_type> +// CHECK: %[[D6:.*]] = vector.transfer_read %[[ALLOC]][%[[D5]], %[[TY]]], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, #gpu.address_space>, vector<4x1xf32> +// CHECK: %[[D7:.*]] = vector.shape_cast %[[D6]] : vector<4x1xf32> to vector<4xf32> +// CHECK: %[[D8:.*]] = affine.apply #{{.*}}()[%[[TY]], %[[WG1]]] +// CHECK: %[[D9:.*]] = affine.apply #{{.*}}()[%[[WG0]], %[[TX]]] +// CHECK: vector.transfer_write %[[D7]], %[[D1]][%[[D8]], %[[D9]]] {in_bounds = [true]} : vector<4xf32>, memref<4096x4096xf32, #hal.descriptor_type> // ----- @@ -96,32 +96,29 @@ hal.executable @transpose_single_operand_dispatch_0_generic_768x2048 { // CHECK-LABEL: hal.executable public 
@transpose_single_operand_dispatch_0_generic_768x2048 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[D0:.*]] = gpu.thread_id x -// CHECK: %[[D1:.*]] = gpu.thread_id y -// CHECK: %[[D2:.*]] = gpu.thread_id z -// CHECK: %[[D3:.*]] = memref.alloc() : memref<32x33xf32, #gpu.address_space> -// CHECK: %[[D4:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<2048x768xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D4]], 64 : memref<2048x768xf32, #hal.descriptor_type> -// CHECK: %[[D5:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<768x2048xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D5]], 64 : memref<768x2048xf32, #hal.descriptor_type> -// CHECK: %[[D6:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<768x2048xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D6]], 64 : memref<768x2048xf32, #hal.descriptor_type> +// CHECK: %[[TX:.*]] = gpu.thread_id x +// CHECK: %[[TY:.*]] = gpu.thread_id y +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32x33xf32, #gpu.address_space> +// CHECK: %[[D0:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<2048x768xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D0]], 64 : memref<2048x768xf32, #hal.descriptor_type> +// CHECK: %[[D1:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<768x2048xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D1]], 64 : memref<768x2048xf32, #hal.descriptor_type> +// CHECK: %[[D2:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<768x2048xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D2]], 64 : memref<768x2048xf32, #hal.descriptor_type> // CHECK: gpu.barrier -// CHECK: %[[D7:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D8:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: %[[D9:.*]] = vector.transfer_read %[[D4]][%[[D7]], %[[D8]]], %[[CST]] {in_bounds = [true, true]} : memref<2048x768xf32, #hal.descriptor_type>, vector<1x4xf32> -// CHECK: %[[D10:.*]] = affine.apply #{{.*}}()[%[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D11:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: vector.transfer_write %[[D9]], %[[D3]][%[[D10]], %[[D11]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<32x33xf32, #gpu.address_space> +// CHECK: %[[D3:.*]] = affine.apply #{{.*}}()[%[[TY]], %{{.*}}] +// CHECK: %[[D4:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]] +// CHECK: %[[D5:.*]] = vector.transfer_read %[[D0]][%[[D3]], %[[D4]]], %[[CST]] {in_bounds = [true, true]} : memref<2048x768xf32, #hal.descriptor_type>, vector<1x4xf32> +// CHECK: %[[D6:.*]] = affine.apply #{{.*}}()[%[[TX]]] +// CHECK: vector.transfer_write %[[D5]], %[[ALLOC]][%[[TY]], %[[D6]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<32x33xf32, #gpu.address_space> // CHECK: gpu.barrier -// CHECK: %[[D12:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: %[[D13:.*]] = vector.transfer_read %[[D3]][%[[D12]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, #gpu.address_space>, vector<4x1xf32> -// CHECK: %[[D15:.*]] = affine.apply #{{.*}}()[%[[D1]], %{{.*}}] -// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: %[[D17:.*]] = 
vector.transfer_read %[[D5]][%[[D15]], %[[D16]]], %[[CST]] {in_bounds = [true]} : memref<768x2048xf32, #hal.descriptor_type>, vector<4xf32> -// CHECK: %[[D14:.*]] = vector.shape_cast %[[D13]] : vector<4x1xf32> to vector<4xf32> -// CHECK: %[[D19:.*]] = arith.addf %[[D14]], %[[D17]] : vector<4xf32> -// CHECK: vector.transfer_write %[[D19]], %[[D6]][%[[D15]], %[[D16]]] {in_bounds = [true]} : vector<4xf32>, memref<768x2048xf32, #hal.descriptor_type> +// CHECK: %[[D7:.*]] = vector.transfer_read %[[ALLOC]][%[[D6]], %[[TY]]], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, #gpu.address_space>, vector<4x1xf32> +// CHECK: %[[D8:.*]] = affine.apply #{{.*}}()[%[[TY]], %{{.*}}] +// CHECK: %[[D9:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]] +// CHECK: %[[D10:.*]] = vector.transfer_read %[[D1]][%[[D8]], %[[D9]]], %[[CST]] {in_bounds = [true]} : memref<768x2048xf32, #hal.descriptor_type>, vector<4xf32> +// CHECK: %[[D11:.*]] = vector.shape_cast %[[D7]] : vector<4x1xf32> to vector<4xf32> +// CHECK: %[[D12:.*]] = arith.addf %[[D11]], %[[D10]] : vector<4xf32> +// CHECK: vector.transfer_write %[[D12]], %[[D2]][%[[D8]], %[[D9]]] {in_bounds = [true]} : vector<4xf32>, memref<768x2048xf32, #hal.descriptor_type> // ----- @@ -203,32 +200,29 @@ hal.executable @transpose_3d_yes_dispatch_0_generic_10x768x2048 { // CHECK-LABEL: hal.executable public @transpose_3d_yes_dispatch_0_generic_10x768x2048 { // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[D0:.*]] = gpu.thread_id x -// CHECK: %[[D1:.*]] = gpu.thread_id y -// CHECK: %[[D2:.*]] = gpu.thread_id z -// CHECK: %[[D3:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space> -// CHECK: %[[D4:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<10x2048x768xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D4]], 64 : memref<10x2048x768xf32, #hal.descriptor_type> -// CHECK: %[[D5:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D5]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> -// CHECK: %[[D6:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D6]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: %[[TX:.*]] = gpu.thread_id x +// CHECK: %[[TY:.*]] = gpu.thread_id y +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space> +// CHECK: %[[D0:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<10x2048x768xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D0]], 64 : memref<10x2048x768xf32, #hal.descriptor_type> +// CHECK: %[[D1:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D1]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: %[[D2:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D2]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> // CHECK: gpu.barrier -// CHECK: %[[D7:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D8:.*]] = affine.apply 
#{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: %[[D9:.*]] = vector.transfer_read %[[D4]][%{{.*}}, %[[D7]], %[[D8]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x2048x768xf32, #hal.descriptor_type>, vector<1x1x4xf32> -// CHECK: %[[D10:.*]] = affine.apply #{{.*}}()[%[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D11:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: vector.transfer_write %[[D9]], %[[D3]][%[[C0]], %[[D10]], %[[D11]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space> +// CHECK: %[[D3:.*]] = affine.apply #{{.*}}()[%[[TY]], %{{.*}}] +// CHECK: %[[D4:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]] +// CHECK: %[[D5:.*]] = vector.transfer_read %[[D0]][%{{.*}}, %[[D3]], %[[D4]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x2048x768xf32, #hal.descriptor_type>, vector<1x1x4xf32> +// CHECK: %[[D6:.*]] = affine.apply #{{.*}}()[%[[TX]]] +// CHECK: vector.transfer_write %[[D5]], %[[ALLOC]][%[[C0]], %[[TY]], %[[D6]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space> // CHECK: gpu.barrier -// CHECK: %[[D12:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: %[[D13:.*]] = vector.transfer_read %[[D3]][%[[C0]], %[[D12]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space>, vector<4x1xf32> -// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%[[D1]], %{{.*}}] -// CHECK: %[[D17:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: %[[D18:.*]] = vector.transfer_read %[[D5]][%{{.*}}, %[[D16]], %[[D17]]], %[[CST]] {in_bounds = [true]} : memref<10x768x2048xf32, #hal.descriptor_type>, vector<4xf32> -// CHECK: %[[D15:.*]] = vector.shape_cast %[[D13]] : vector<4x1xf32> to vector<4xf32> -// CHECK: %[[D20:.*]] = arith.addf %[[D15]], %[[D18]] : vector<4xf32> -// CHECK: vector.transfer_write %[[D20]], %[[D6]][%{{.*}}, %[[D16]], %[[D17]]] {in_bounds = [true]} : vector<4xf32>, memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: %[[D7:.*]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[D6]], %[[TY]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space>, vector<4x1xf32> +// CHECK: %[[D8:.*]] = affine.apply #{{.*}}()[%[[TY]], %{{.*}}] +// CHECK: %[[D9:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]] +// CHECK: %[[D10:.*]] = vector.transfer_read %[[D1]][%{{.*}}, %[[D8]], %[[D9]]], %[[CST]] {in_bounds = [true]} : memref<10x768x2048xf32, #hal.descriptor_type>, vector<4xf32> +// CHECK: %[[D11:.*]] = vector.shape_cast %[[D7]] : vector<4x1xf32> to vector<4xf32> +// CHECK: %[[D12:.*]] = arith.addf %[[D11]], %[[D10]] : vector<4xf32> +// CHECK: vector.transfer_write %[[D12]], %[[D2]][%{{.*}}, %[[D8]], %[[D9]]] {in_bounds = [true]} : vector<4xf32>, memref<10x768x2048xf32, #hal.descriptor_type> // ----- @@ -269,35 +263,32 @@ hal.executable @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 { // CHECK-LABEL: hal.executable public @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 { // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[D0:.*]] = gpu.thread_id x -// CHECK: %[[D1:.*]] = gpu.thread_id y -// CHECK: %[[D2:.*]] = gpu.thread_id z -// CHECK: %[[D3:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space> -// CHECK: %[[D4:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space> -// CHECK: %[[D5:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment 
%[[D5]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> -// CHECK: %[[D6:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D6]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> -// CHECK: %[[D7:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<10x2048x768xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D7]], 64 : memref<10x2048x768xf32, #hal.descriptor_type> +// CHECK: %[[TX:.*]] = gpu.thread_id x +// CHECK: %[[TY:.*]] = gpu.thread_id y +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space> +// CHECK: %[[ALLOC1:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space> +// CHECK: %[[D0:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D0]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: %[[D1:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D1]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: %[[D2:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<10x2048x768xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D2]], 64 : memref<10x2048x768xf32, #hal.descriptor_type> // CHECK: gpu.barrier -// CHECK: %[[D8:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D9:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: %[[D10:.*]] = vector.transfer_read %[[D5]][%{{.*}}, %[[D8]], %[[D9]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x768x2048xf32, #hal.descriptor_type>, vector<1x1x4xf32> -// CHECK: %[[D11:.*]] = affine.apply #{{.*}}()[%[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D12:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: vector.transfer_write %[[D10]], %[[D4]][%[[C0]], %[[D11]], %[[D12]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space> -// CHECK: %[[D13:.*]] = vector.transfer_read %[[D6]][%{{.*}}, %[[D8]], %[[D9]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x768x2048xf32, #hal.descriptor_type>, vector<1x1x4xf32> -// CHECK: vector.transfer_write %[[D13]], %[[D3]][%[[C0]], %[[D11]], %[[D12]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space> +// CHECK: %[[D3:.*]] = affine.apply #{{.*}}()[%[[TY]], %{{.*}}] +// CHECK: %[[D4:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]] +// CHECK: %[[D5:.*]] = vector.transfer_read %[[D0]][%{{.*}}, %[[D3]], %[[D4]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x768x2048xf32, #hal.descriptor_type>, vector<1x1x4xf32> +// CHECK: %[[D6:.*]] = affine.apply #{{.*}}()[%[[TX]]] +// CHECK: vector.transfer_write %[[D5]], %[[ALLOC1]][%[[C0]], %[[TY]], %[[D6]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space> +// CHECK: %[[D7:.*]] = vector.transfer_read %[[D1]][%{{.*}}, %[[D3]], %[[D4]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x768x2048xf32, #hal.descriptor_type>, vector<1x1x4xf32> +// CHECK: vector.transfer_write %[[D7]], %[[ALLOC]][%[[C0]], %[[TY]], %[[D6]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space> // CHECK: gpu.barrier -// CHECK: 
%[[D14:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: %[[D15:.*]] = vector.transfer_read %[[D4]][%[[C0]], %[[D14]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space>, vector<4x1xf32> -// CHECK: %[[D16:.*]] = vector.transfer_read %[[D3]][%[[C0]], %[[D14]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space>, vector<4x1xf32> -// CHECK: %[[D17:.*]] = arith.addf %[[D15]], %[[D16]] : vector<4x1xf32> -// CHECK: %[[D19:.*]] = vector.shape_cast %[[D17]] : vector<4x1xf32> to vector<4xf32> -// CHECK: %[[D21:.*]] = affine.apply #{{.*}}()[%[[D1]], %{{.*}}] -// CHECK: %[[D22:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: vector.transfer_write %[[D19]], %[[D7]][%{{.*}}, %[[D21]], %[[D22]]] {in_bounds = [true]} : vector<4xf32>, memref<10x2048x768xf32, #hal.descriptor_type> +// CHECK: %[[D8:.*]] = vector.transfer_read %[[ALLOC1]][%[[C0]], %[[D6]], %[[TY]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space>, vector<4x1xf32> +// CHECK: %[[D9:.*]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[D6]], %[[TY]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space>, vector<4x1xf32> +// CHECK: %[[D10:.*]] = arith.addf %[[D8]], %[[D9]] : vector<4x1xf32> +// CHECK: %[[D11:.*]] = vector.shape_cast %[[D10]] : vector<4x1xf32> to vector<4xf32> +// CHECK: %[[D12:.*]] = affine.apply #{{.*}}()[%[[TY]], %{{.*}}] +// CHECK: %[[D13:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]] +// CHECK: vector.transfer_write %[[D11]], %[[D2]][%{{.*}}, %[[D12]], %[[D13]]] {in_bounds = [true]} : vector<4xf32>, memref<10x2048x768xf32, #hal.descriptor_type> // ----- diff --git a/third_party/llvm-project b/third_party/llvm-project index 889525fa99b25..0f52f88049c73 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 889525fa99b251dc962edb516e0108088ba7e44d +Subproject commit 0f52f88049c73d7c16a30cff8dfa0b40b63e5634
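Reviewer note (appended after the patch body, not part of the applied diff): the CHECK-line churn above follows one mechanical pattern, so a minimal sketch may help when reading the hunks. The thread-ID names (%tx, %ty, %tz) and basis sizes below are illustrative, not taken from these tests.

// Before: a flat ID built by one affine map, then split back apart with a
// floordiv/mod chain that affine-map composition tends to obscure.
//   %flat = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 64 + s2 * 128)>()[%tx, %ty, %tz]
//   %row  = affine.apply affine_map<()[s0] -> (s0 floordiv 32)>()[%flat]
//   %col  = affine.apply affine_map<()[s0] -> (s0 mod 32)>()[%flat]

// After: the same computation expressed with the dedicated ops, which keep the
// row/col structure explicit and canonicalize more cleanly.
%flat = affine.linearize_index disjoint [%tz, %ty, %tx] by (2, 2, 64) : index
%row_col:2 = affine.delinearize_index %flat into (8, 32) : index, index

In the updated CHECK lines this shows up as %[[TX]]/%[[TY]] feeding much smaller affine.apply maps (or being used directly), instead of every index depending on thread_id x, y, and z at once, which is also why the gpu.thread_id z captures disappear from the tests.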