diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp
index b9d6e6c09e3c5..b3c6987336761 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp
@@ -96,43 +96,49 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter,
                  .getResult(0);
   }
 
-  SmallVector<Value> delinSizes;
-  OpFoldResult totalLoopTripCount = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult> delinSizes;
+  OpFoldResult producerCount = rewriter.getIndexAttr(1);
   for (auto workerCount : forallOp.getMixedUpperBound()) {
-    delinSizes.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, workerCount));
-    totalLoopTripCount = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, d0 * d1, {totalLoopTripCount, workerCount});
+    delinSizes.push_back(workerCount);
+    producerCount = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, d0 * d1, {producerCount, workerCount});
   }
 
+  // If the total number of producers doesn't evenly divide into
   int64_t flatTotalNumWorkers =
       hasWarpMapping ? flatWorkgroupSize / subgroupSize : flatWorkgroupSize;
-  std::optional<int64_t> staticProducerCount =
-      getConstantIntValue(totalLoopTripCount);
-  bool perfectlyDivides =
-      staticProducerCount &&
-      staticProducerCount.value() % flatTotalNumWorkers == 0;
+  OpFoldResult newLoopTripCount = affine::makeComposedFoldedAffineApply(
+      rewriter, loc, d0.floorDiv(flatTotalNumWorkers), producerCount);
+  OpFoldResult remainingLanes = affine::makeComposedFoldedAffineApply(
+      rewriter, loc, d0 % flatTotalNumWorkers, {producerCount});
+
+  // If the loop isn't guaranteed to perfectly tile onto the workers,
+  // we will run one more iteration of the loop on the workitems where it
+  // needs to execute.
+  std::optional<int64_t> remainingLanesCount =
+      getConstantIntValue(remainingLanes);
+  bool hasPostLoopTail =
+      !remainingLanesCount || remainingLanesCount.value() != 0;
+  OpFoldResult maxIteration =
+      hasPostLoopTail
+          ? affine::makeComposedFoldedAffineApply(
+                rewriter, loc, d0.ceilDiv(flatTotalNumWorkers), {producerCount})
+          : newLoopTripCount;
 
   // Step 3. Create the `scf.for` loop for the loop.
-  // If the workgroup count perfectly divides the loop's worker count, then we
-  // can use a lower bound of 0 and keep the loop bounds static. This helps
-  // simplify later loop folding patterns without an `affine.linearize_index` op
-  // to help with inferring int ranges.
-  Value lb = perfectlyDivides ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
-                              : flatId;
-  Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, totalLoopTripCount);
-  Value step =
-      rewriter.create<arith::ConstantIndexOp>(loc, flatTotalNumWorkers);
+  Value lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, newLoopTripCount);
+  Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
   auto forLoop = rewriter.create<scf::ForOp>(loc, lb, ub, step, ValueRange{});
   Block *loopBody = forLoop.getBody();
 
   // Get the replacement IDs for the forall iterator ids.
   rewriter.setInsertionPointToStart(loopBody);
-  Value newFlatProducerId =
-      perfectlyDivides
-          ? affine::makeComposedAffineApply(rewriter, loc, d0 + d1,
-                                            {forLoop.getInductionVar(), flatId})
-          : forLoop.getInductionVar();
+  Value newFlatProducerId = rewriter.create<affine::AffineLinearizeIndexOp>(
+      loc, ValueRange{forLoop.getInductionVar(), flatId},
+      ArrayRef<OpFoldResult>{maxIteration,
+                             rewriter.getIndexAttr(flatTotalNumWorkers)},
+      /*disjoint=*/true);
 
   // We require a descending relative mapping, so we can reuse the upper bound
   // sizes directly.
@@ -147,6 +153,22 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter,
                                      newBlockArgs);
   rewriter.eraseOp(forallTerminator);
   rewriter.eraseOp(forallOp);
+
+  // Step 5. Create the post-loop code that only executes on some workitems.
+  if (hasPostLoopTail) {
+    rewriter.setInsertionPointAfter(forLoop);
+    IRMapping cloneMap;
+    Value willExecuteTail = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, flatId,
+        getValueOrCreateConstantIndexOp(rewriter, loc, remainingLanes));
+    auto tailIfOp = rewriter.create<scf::IfOp>(
+        loc, TypeRange{}, willExecuteTail, /*addThenBlock=*/false,
+        /*addElseBlock=*/false);
+    cloneMap.map(forLoop.getInductionVar(), ub);
+    // We're relying on the fact that `scf.for` and `scf.if` share the same
+    // terminator.
+    forLoop.getRegion().cloneInto(&tailIfOp.getThenRegion(), cloneMap);
+  }
   return success();
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir
index 32bda8c90f050..1d96d2902f0dd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir
@@ -16,8 +16,8 @@ func.func @distribute_thread_forall(%out : memref)
 // CHECK-DAG:   %[[TX:.+]] = gpu.thread_id x
 // CHECK-DAG:   %[[TY:.+]] = gpu.thread_id y
 // CHECK:       %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
-// CHECK:       scf.for %[[I:.+]] = %c0 to %c1024 step %c128 {
-// CHECK:         %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]]
+// CHECK:       scf.for %[[I:.+]] = %c0 to %c8 step %c1 {
+// CHECK:         %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[TFLAT]]] by (8, 128)
 // CHECK:         memref.store {{.*}}[%[[LINID]]]
 
 // -----
 
@@ -38,8 +38,8 @@ func.func @distribute_warp_forall(%out : memref)
 // CHECK-DAG:   %[[TY:.+]] = gpu.thread_id y
 // CHECK:       %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
 // CHECK:       %[[WARPSPLIT:.+]]:2 = affine.delinearize_index %[[TFLAT]] into (4, 32)
-// CHECK:       scf.for %[[I:.+]] = %c0 to %c32 step %c4 {
-// CHECK:         %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[WARPSPLIT]]#0]
+// CHECK:       scf.for %[[I:.+]] = %c0 to %c8 step %c1 {
+// CHECK:         %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[WARPSPLIT]]#0] by (8, 4)
 // CHECK:         memref.store {{.*}}[%[[LINID]]]
 
 // -----
 
@@ -96,8 +96,10 @@ func.func @distribute_thread_forall_single_thread(%out : memref)
 // CHECK-DAG:   %[[TX:.+]] = gpu.thread_id x
 // CHECK-DAG:   %[[TY:.+]] = gpu.thread_id y
 // CHECK:       %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
-// CHECK:       scf.for %[[I:.+]] = %[[TFLAT]] to %c1 step %c128 {
-// CHECK:         memref.store {{.*}}[%[[I]]]
+// CHECK-NOT:   scf.for
+// CHECK:       %[[TIDGUARD:.+]] = arith.cmpi slt, %[[TFLAT]], %[[C1]]
+// CHECK:       scf.if %[[TIDGUARD]] {
+// CHECK:         memref.store {{.*}}[%[[TFLAT]]]
 
 // -----
 
@@ -113,12 +115,18 @@ func.func @distribute_thread_forall_overhang(%out : memref)
 }
 
 // CHECK-LABEL: func @distribute_thread_forall_overhang
-// CHECK-DAG:   %[[C513:.+]] = arith.constant 513 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
 // CHECK-DAG:   %[[TX:.+]] = gpu.thread_id x
 // CHECK-DAG:   %[[TY:.+]] = gpu.thread_id y
 // CHECK:       %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
-// CHECK:       scf.for %[[I:.+]] = %[[TFLAT]] to %[[C513]] step %c128 {
-// CHECK:         memref.store {{.*}}[%[[I]]]
+// CHECK:       scf.for %[[I:.+]] = %c0 to %[[C4]] step %[[C1]] {
+// CHECK:         %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[TFLAT]]] by (5, 128)
+// CHECK:         memref.store {{.*}}[%[[LINID]]]
+// CHECK:       %[[TIDGUARD:.+]] = arith.cmpi slt, %[[TFLAT]], %[[C1]]
+// CHECK:       scf.if %[[TIDGUARD]] {
+// CHECK:         %[[LINID_IF:.+]] = affine.linearize_index disjoint [%[[C4]], %[[TFLAT]]]
+// CHECK:         memref.store {{.*}}[%[[LINID_IF]]]
 
 // -----
 
@@ -137,8 +145,8 @@ func.func @distribute_thread_forall_multi_dim(%out : memref)
 // CHECK-DAG:   %[[TX:.+]] = gpu.thread_id x
 // CHECK-DAG:   %[[TY:.+]] = gpu.thread_id y
 // CHECK:       %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
-// CHECK:       scf.for %[[I:.+]] = %c0 to %c512 step %c128 {
-// CHECK:         %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]]
+// CHECK:       scf.for %[[I:.+]] = %c0 to %c4 step %c1 {
+// CHECK:         %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[TFLAT]]] by (4, 128)
 // CHECK:         %[[DELIN:.+]]:3 = affine.delinearize_index %[[LINID]] into (16, 8, 4) : index
 // CHECK:         memref.store {{.*}}[%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2]
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index 6d8695b372c63..93362198e2f7c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -902,7 +902,8 @@ hal.executable public @main {
 
 // CHECK-LABEL: func @small_matvec
 // CHECK-DAG:     %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
-// CHECK:         scf.for %{{.*}} = %{{.*}} to %c1 step %c64
+// CHECK:         %[[COND:.+]] = arith.cmpi slt, %{{.*}}, %c1
+// CHECK:         scf.if %[[COND]]
 
 // Verify that the write does not get hoisted out of the single threaded
 // for loop.
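
For reference, here is a hand-written sketch of the IR shape the updated pass is expected to produce for the `distribute_thread_forall_overhang` case above (513 producers on a flat 2x64 = 128-workitem workgroup). The function name, SSA value names, and the `?xi32` memref type are illustrative assumptions; the CHECK lines in the diff are the authoritative expectations.

```mlir
// Sketch only: 513 producers over 128 workitems, i.e. 513 = 4 * 128 + 1.
func.func @overhang_sketch(%out: memref<?xi32>, %v: i32) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %tid_x = gpu.thread_id x
  %tid_y = gpu.thread_id y
  // Flatten the 2-D workitem id into a single linear id in [0, 128).
  %tflat = affine.linearize_index disjoint [%tid_y, %tid_x] by (2, 64) : index
  // Main loop: every workitem runs floor(513 / 128) = 4 full iterations.
  scf.for %iv = %c0 to %c4 step %c1 {
    // Recover the flat producer id from (iteration, flat workitem id).
    %id = affine.linearize_index disjoint [%iv, %tflat] by (5, 128) : index
    memref.store %v, %out[%id] : memref<?xi32>
  }
  // Tail: the remaining 513 mod 128 = 1 producer, guarded so that only
  // workitems with a flat id below the remainder execute it.
  %in_tail = arith.cmpi slt, %tflat, %c1 : index
  scf.if %in_tail {
    %id = affine.linearize_index disjoint [%c4, %tflat] by (5, 128) : index
    memref.store %v, %out[%id] : memref<?xi32>
  }
  return
}
```

Compared to the old lowering (`scf.for %i = %tflat to %c513 step %c128`), the normalized zero-based loop plus a `disjoint` `affine.linearize_index` keeps the loop bounds static and appears intended to make index ranges easier for later patterns to infer, while the guarded `scf.if` peels the `producerCount mod flatTotalNumWorkers` remainder out of the loop.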