From 93687466750c4b95d3e6bfb57bcaecf4ae14bd5f Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak
Date: Fri, 8 Nov 2024 20:46:48 +0000
Subject: [PATCH] [GPU] Use scf.if for forall overhangs

In cases where we can't determine whether the number of workitems per
workgroup evenly divides the number of iterations an scf.forall
requires, the current code uses
`scf.for %i = %id to %upperBound step %numWorkitems` so that the last
loop iteration only runs on the expected fraction of workitems. This
commit switches to a post-loop `scf.if` instead, which lets the main
body of the `scf.for` that the forall is lowered to use
`affine.linearize_index` and a step-1 loop.
---
 .../Common/GPU/GPUDistributeForall.cpp        | 72 ++++++++++++-------
 .../GPU/test/gpu_distribute_forall.mlir       | 30 +++++---
 .../test/ROCDL/pipeline_tile_and_fuse.mlir    |  3 +-
 3 files changed, 68 insertions(+), 37 deletions(-)

diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp
index 334427cfffb9..3cdc4e6bcc7b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp
@@ -100,43 +100,49 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter,
             .getResult(0);
   }
 
-  SmallVector<Value> delinSizes;
-  OpFoldResult totalLoopTripCount = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult> delinSizes;
+  OpFoldResult producerCount = rewriter.getIndexAttr(1);
   for (auto workerCount : forallOp.getMixedUpperBound()) {
-    delinSizes.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, workerCount));
-    totalLoopTripCount = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, d0 * d1, {totalLoopTripCount, workerCount});
+    delinSizes.push_back(workerCount);
+    producerCount = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, d0 * d1, {producerCount, workerCount});
   }
 
+  // If the total number of producers doesn't evenly divide into the number
+  // of workers, some workitems will need to run one extra loop iteration.
   int64_t flatTotalNumWorkers =
       hasWarpMapping ? flatWorkgroupSize / subgroupSize : flatWorkgroupSize;
-  std::optional<int64_t> staticProducerCount =
-      getConstantIntValue(totalLoopTripCount);
-  bool perfectlyDivides =
-      staticProducerCount &&
-      staticProducerCount.value() % flatTotalNumWorkers == 0;
+  OpFoldResult newLoopTripCount = affine::makeComposedFoldedAffineApply(
+      rewriter, loc, d0.floorDiv(flatTotalNumWorkers), producerCount);
+  OpFoldResult remainingLanes = affine::makeComposedFoldedAffineApply(
+      rewriter, loc, d0 % flatTotalNumWorkers, {producerCount});
+
+  // If the loop isn't guaranteed to perfectly tile onto the workers,
+  // we will run one more iteration of the loop on the workitems where it
+  // needs to execute.
+  std::optional<int64_t> remainingLanesCount =
+      getConstantIntValue(remainingLanes);
+  bool hasPostLoopTail =
+      !remainingLanesCount || remainingLanesCount.value() != 0;
+  OpFoldResult maxIteration =
+      hasPostLoopTail
+          ? affine::makeComposedFoldedAffineApply(
+                rewriter, loc, d0.ceilDiv(flatTotalNumWorkers), {producerCount})
+          : newLoopTripCount;
 
   // Step 3. Create the `scf.for` loop for the loop.
-  // If the workgroup count perfectly divides the loop's worker count, then we
-  // can use a lower bound of 0 and keep the loop bounds static. This helps
-  // simplify later loop folding patterns without an `affine.linearize_index` op
-  // to help with inferring int ranges.
-  Value lb = perfectlyDivides ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
-                              : flatId;
-  Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, totalLoopTripCount);
-  Value step =
-      rewriter.create<arith::ConstantIndexOp>(loc, flatTotalNumWorkers);
+  Value lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, newLoopTripCount);
+  Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
   auto forLoop = rewriter.create<scf::ForOp>(loc, lb, ub, step, ValueRange{});
   Block *loopBody = forLoop.getBody();
 
   // Get the replacement IDs for the forall iterator ids.
   rewriter.setInsertionPointToStart(loopBody);
-  Value newFlatProducerId =
-      perfectlyDivides
-          ? affine::makeComposedAffineApply(rewriter, loc, d0 + d1,
-                                            {forLoop.getInductionVar(), flatId})
-          : forLoop.getInductionVar();
+  Value newFlatProducerId = rewriter.create<affine::AffineLinearizeIndexOp>(
+      loc, ValueRange{forLoop.getInductionVar(), flatId},
+      ArrayRef<OpFoldResult>{maxIteration,
+                             rewriter.getIndexAttr(flatTotalNumWorkers)},
+      /*disjoint=*/true);
 
   // We require a descending relative mapping, so we can reuse the upper bound
   // sizes directly.
@@ -151,6 +157,22 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter,
                      newBlockArgs);
   rewriter.eraseOp(forallTerminator);
   rewriter.eraseOp(forallOp);
+
+  // Step 5. Create the post-loop code that only executes on some workitems.
+  if (hasPostLoopTail) {
+    rewriter.setInsertionPointAfter(forLoop);
+    IRMapping cloneMap;
+    Value willExecuteTail = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, flatId,
+        getValueOrCreateConstantIndexOp(rewriter, loc, remainingLanes));
+    auto tailIfOp = rewriter.create<scf::IfOp>(
+        loc, TypeRange{}, willExecuteTail, /*addThenBlock=*/false,
+        /*addElseBlock=*/false);
+    cloneMap.map(forLoop.getInductionVar(), ub);
+    // We're relying on the fact that `scf.for` and `scf.if` share the same
+    // terminator.
+    forLoop.getRegion().cloneInto(&tailIfOp.getThenRegion(), cloneMap);
+  }
   return success();
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir
index 32bda8c90f05..1d96d2902f0d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir
@@ -16,8 +16,8 @@ func.func @distribute_thread_forall(%out : memref)
 // CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
 // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
 // CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
-// CHECK: scf.for %[[I:.+]] = %c0 to %c1024 step %c128 {
-// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]]
+// CHECK: scf.for %[[I:.+]] = %c0 to %c8 step %c1 {
+// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[TFLAT]]] by (8, 128)
 // CHECK: memref.store {{.*}}[%[[LINID]]]
 
 // -----
 
@@ -38,8 +38,8 @@ func.func @distribute_warp_forall(%out : memref)
 // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
 // CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
 // CHECK: %[[WARPSPLIT:.+]]:2 = affine.delinearize_index %[[TFLAT]] into (4, 32)
-// CHECK: scf.for %[[I:.+]] = %c0 to %c32 step %c4 {
-// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[WARPSPLIT]]#0]
+// CHECK: scf.for %[[I:.+]] = %c0 to %c8 step %c1 {
+// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[WARPSPLIT]]#0] by (8, 4)
 // CHECK: memref.store {{.*}}[%[[LINID]]]
 
 // -----
 
@@ -96,8 +96,10 @@ func.func @distribute_thread_forall_single_thread(%out : memref)
 // CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
 // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
 // CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
-// CHECK: scf.for %[[I:.+]] = %[[TFLAT]] to %c1 step %c128 {
-// CHECK: memref.store {{.*}}[%[[I]]]
+// CHECK-NOT: scf.for
+// CHECK: %[[TIDGUARD:.+]] = arith.cmpi slt, %[[TFLAT]], %[[C1]]
+// CHECK: scf.if %[[TIDGUARD]] {
+// CHECK: memref.store {{.*}}[%[[TFLAT]]]
 
 // -----
 
@@ -113,12 +115,18 @@ func.func @distribute_thread_forall_overhang(%out : memref)
 }
 
 // CHECK-LABEL: func @distribute_thread_forall_overhang
-// CHECK-DAG: %[[C513:.+]] = arith.constant 513 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
 // CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
 // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
 // CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
-// CHECK: scf.for %[[I:.+]] = %[[TFLAT]] to %[[C513]] step %c128 {
-// CHECK: memref.store {{.*}}[%[[I]]]
+// CHECK: scf.for %[[I:.+]] = %c0 to %[[C4]] step %[[C1]] {
+// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[TFLAT]]] by (5, 128)
+// CHECK: memref.store {{.*}}[%[[LINID]]]
+// CHECK: %[[TIDGUARD:.+]] = arith.cmpi slt, %[[TFLAT]], %[[C1]]
+// CHECK: scf.if %[[TIDGUARD]] {
+// CHECK: %[[LINID_IF:.+]] = affine.linearize_index disjoint [%[[C4]], %[[TFLAT]]]
+// CHECK: memref.store {{.*}}[%[[LINID_IF]]]
 
 // -----
 
@@ -137,8 +145,8 @@ func.func @distribute_thread_forall_multi_dim(%out : memref)
 // CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
 // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
 // CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
-// CHECK: scf.for %[[I:.+]] = %c0 to %c512 step %c128 {
-// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]]
+// CHECK: scf.for %[[I:.+]] = %c0 to %c4 step %c1 {
+// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[TFLAT]]] by (4, 128)
 // CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LINID]] into (16, 8, 4) : index
 // CHECK: memref.store {{.*}}[%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2]
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index 3f5b280b6342..20776ce870a9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -902,7 +902,8 @@ hal.executable public @main {
 
 // CHECK-LABEL: func @small_matvec
 // CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
-// CHECK: scf.for %{{.*}} = %{{.*}} to %c1 step %c64
+// CHECK: %[[COND:.+]] = arith.cmpi slt, %{{.*}}, %c1
+// CHECK: scf.if %[[COND]]
 
 // Verify that the write does not get hoisted out of the single threaded
 // for loop.
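
Notes (not part of the patch): a minimal sketch of the IR this lowering now
produces, mirroring the `distribute_thread_forall_overhang` test above
(513 forall iterations distributed over a flat workgroup of 128 workitems).
The memref element type, stored value, and SSA names are illustrative
assumptions, not literal output of the pass:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %tid_x = gpu.thread_id x
  %tid_y = gpu.thread_id y
  // Flatten the 2x64 thread ids into a single flat workitem id in [0, 128).
  %tflat = affine.linearize_index disjoint [%tid_y, %tid_x] by (2, 64) : index
  // Main body: floor(513 / 128) = 4 iterations on every workitem, step 1.
  scf.for %i = %c0 to %c4 step %c1 {
    // Iteration %i of workitem %tflat handles element %i * 128 + %tflat.
    %idx = affine.linearize_index disjoint [%i, %tflat] by (5, 128) : index
    memref.store %val, %out[%idx] : memref<?xi32>
  }
  // Tail: 513 mod 128 = 1, so only workitems with %tflat < 1 run one more
  // iteration, covering element 4 * 128 + %tflat = 512.
  %in_tail = arith.cmpi slt, %tflat, %c1 : index
  scf.if %in_tail {
    %idx = affine.linearize_index disjoint [%c4, %tflat] by (5, 128) : index
    memref.store %val, %out[%idx] : memref<?xi32>
  }

Because the main loop now starts at 0 with a unit step and the id is rebuilt
with a disjoint `affine.linearize_index`, later range-inference and loop
folding patterns can reason about the bounds without the old non-unit-step
form.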