[GPU] Use scf.if for forall overhangs
In cases where we can't determine whether the number of workitems per
workgroup evenly divides the number of items an scf.forall needs to
process, the current code uses
`scf.for %i = %id to %upperBound step %numWorkitems`
so that the last loop iteration only runs on the expected
fraction of workitems.

This commit instead emits a post-loop `scf.if` statement, so the main
body of the `scf.for` loop that the forall is lowered to can use
`affine.linearize_index` and a step-1 loop.
krzysz00 committed Nov 26, 2024
1 parent 5708d42 commit 9368746
Showing 3 changed files with 68 additions and 37 deletions.
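
As an illustration of the change described above, here is a minimal standalone sketch (not code from this commit; the function name and stored value are invented) of the new lowering shape for a forall over 513 items distributed across a flat workgroup of 128 workitems, mirroring the updated @distribute_thread_forall_overhang test below:

// Sketch only: assumed names, IR shape follows the updated CHECK lines.
func.func @overhang_lowering_sketch(%out: memref<?xi32>, %v: i32) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %tx = gpu.thread_id x
  %ty = gpu.thread_id y
  // Flatten the 2x64 workgroup into a single id in [0, 128).
  %tflat = affine.linearize_index disjoint [%ty, %tx] by (2, 64) : index
  // Main body: 513 floordiv 128 = 4 full iterations run on every workitem.
  scf.for %i = %c0 to %c4 step %c1 {
    %idx = affine.linearize_index disjoint [%i, %tflat] by (5, 128) : index
    memref.store %v, %out[%idx] : memref<?xi32>
  }
  // Tail: only the 513 mod 128 = 1 leftover workitem runs iteration 4.
  %in_tail = arith.cmpi slt, %tflat, %c1 : index
  scf.if %in_tail {
    %idx4 = affine.linearize_index disjoint [%c4, %tflat] by (5, 128) : index
    memref.store %v, %out[%idx4] : memref<?xi32>
  }
  return
}

Under the previous lowering, the same case produced `scf.for %i = %tflat to %c513 step %c128`, whose id-dependent bounds prevented using `affine.linearize_index` and a step-1 loop in the main body.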
@@ -100,43 +100,49 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter,
.getResult(0);
}

SmallVector<Value> delinSizes;
OpFoldResult totalLoopTripCount = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> delinSizes;
OpFoldResult producerCount = rewriter.getIndexAttr(1);
for (auto workerCount : forallOp.getMixedUpperBound()) {
delinSizes.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, workerCount));
totalLoopTripCount = affine::makeComposedFoldedAffineApply(
rewriter, loc, d0 * d1, {totalLoopTripCount, workerCount});
delinSizes.push_back(workerCount);
producerCount = affine::makeComposedFoldedAffineApply(
rewriter, loc, d0 * d1, {producerCount, workerCount});
}

// If the total number of producers doesn't evenly divide into
int64_t flatTotalNumWorkers =
hasWarpMapping ? flatWorkgroupSize / subgroupSize : flatWorkgroupSize;
std::optional<int64_t> staticProducerCount =
getConstantIntValue(totalLoopTripCount);
bool perfectlyDivides =
staticProducerCount &&
staticProducerCount.value() % flatTotalNumWorkers == 0;
OpFoldResult newLoopTripCount = affine::makeComposedFoldedAffineApply(
rewriter, loc, d0.floorDiv(flatTotalNumWorkers), producerCount);
OpFoldResult remainingLanes = affine::makeComposedFoldedAffineApply(
rewriter, loc, d0 % flatTotalNumWorkers, {producerCount});

// If the loop isn't guaranteed to perfectly tile onto the workers,
// we will run one more iteration of the loop on the workitems where it
// needs to execute.
std::optional<int64_t> remainingLanesCount =
getConstantIntValue(remainingLanes);
bool hasPostLoopTail =
!remainingLanesCount || remainingLanesCount.value() != 0;
OpFoldResult maxIteration =
hasPostLoopTail
? affine::makeComposedFoldedAffineApply(
rewriter, loc, d0.ceilDiv(flatTotalNumWorkers), {producerCount})
: newLoopTripCount;

// Step 3. Create the `scf.for` loop for the forall.
// If the workgroup count perfectly divides the loop's worker count, then we
// can use a lower bound of 0 and keep the loop bounds static. This helps
// simplify later loop folding patterns without an `affine.linearize_index` op
// to help with inferring int ranges.
Value lb = perfectlyDivides ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
: flatId;
Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, totalLoopTripCount);
Value step =
rewriter.create<arith::ConstantIndexOp>(loc, flatTotalNumWorkers);
Value lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, newLoopTripCount);
Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto forLoop = rewriter.create<scf::ForOp>(loc, lb, ub, step, ValueRange{});
Block *loopBody = forLoop.getBody();

// Get the replacement IDs for the forall iterator ids.
rewriter.setInsertionPointToStart(loopBody);
Value newFlatProducerId =
perfectlyDivides
? affine::makeComposedAffineApply(rewriter, loc, d0 + d1,
{forLoop.getInductionVar(), flatId})
: forLoop.getInductionVar();
Value newFlatProducerId = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, ValueRange{forLoop.getInductionVar(), flatId},
ArrayRef<OpFoldResult>{maxIteration,
rewriter.getIndexAttr(flatTotalNumWorkers)},
/*disjoint=*/true);

// We require a descending relative mapping, so we can reuse the upper bound
// sizes directly.
@@ -151,6 +157,22 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter,
newBlockArgs);
rewriter.eraseOp(forallTerminator);
rewriter.eraseOp(forallOp);

// Step 5. Create the post-loop code that only executes on some workitems.
if (hasPostLoopTail) {
rewriter.setInsertionPointAfter(forLoop);
IRMapping cloneMap;
Value willExecuteTail = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, flatId,
getValueOrCreateConstantIndexOp(rewriter, loc, remainingLanes));
auto tailIfOp = rewriter.create<scf::IfOp>(
loc, TypeRange{}, willExecuteTail, /*addThenBlock=*/false,
/*addElseBlock=*/false);
cloneMap.map(forLoop.getInductionVar(), ub);
// We're relying on the fact that `scf.for` and `scf.if` share the same
// terminator.
forLoop.getRegion().cloneInto(&tailIfOp.getThenRegion(), cloneMap);
}
return success();
}

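As a worked example of the bound computations above (an illustration only; the function name is assumed and the worker count is fixed at 128): for a producer count of 513, newLoopTripCount folds to 513 floordiv 128 = 4, remainingLanes to 513 mod 128 = 1, and maxIteration to 513 ceildiv 128 = 5, so only lane 0 executes the trailing scf.if. Sketched as standalone IR:

func.func @forall_bounds_sketch(%producer_count: index) -> (index, index, index) {
  // newLoopTripCount: iterations every workitem runs.
  %trip = affine.apply affine_map<()[s0] -> (s0 floordiv 128)>()[%producer_count]
  // remainingLanes: number of workitems that run one extra iteration.
  %rem = affine.apply affine_map<()[s0] -> (s0 mod 128)>()[%producer_count]
  // maxIteration: outer basis size for affine.linearize_index in the loop body.
  %max = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%producer_count]
  return %trip, %rem, %max : index, index, index
}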
@@ -16,8 +16,8 @@ func.func @distribute_thread_forall(%out : memref<?xi32>)
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: scf.for %[[I:.+]] = %c0 to %c1024 step %c128 {
// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]]
// CHECK: scf.for %[[I:.+]] = %c0 to %c8 step %c1 {
// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[TFLAT]]] by (8, 128)
// CHECK: memref.store {{.*}}[%[[LINID]]]

// -----
@@ -38,8 +38,8 @@ func.func @distribute_warp_forall(%out : memref<?xi32>)
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: %[[WARPSPLIT:.+]]:2 = affine.delinearize_index %[[TFLAT]] into (4, 32)
// CHECK: scf.for %[[I:.+]] = %c0 to %c32 step %c4 {
// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[WARPSPLIT]]#0]
// CHECK: scf.for %[[I:.+]] = %c0 to %c8 step %c1 {
// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[WARPSPLIT]]#0] by (8, 4)
// CHECK: memref.store {{.*}}[%[[LINID]]]

// -----
@@ -96,8 +96,10 @@ func.func @distribute_thread_forall_single_thread(%out : memref<?xi32>)
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: scf.for %[[I:.+]] = %[[TFLAT]] to %c1 step %c128 {
// CHECK: memref.store {{.*}}[%[[I]]]
// CHECK-NOT: scf.for
// CHECK: %[[TIDGUARD:.+]] = arith.cmpi slt, %[[TFLAT]], %[[C1]]
// CHECK: scf.if %[[TIDGUARD]] {
// CHECK: memref.store {{.*}}[%[[TFLAT]]]

// -----

@@ -113,12 +115,18 @@ func.func @distribute_thread_forall_overhang(%out : memref<?xi32>)
}

// CHECK-LABEL: func @distribute_thread_forall_overhang
// CHECK-DAG: %[[C513:.+]] = arith.constant 513 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: scf.for %[[I:.+]] = %[[TFLAT]] to %[[C513]] step %c128 {
// CHECK: memref.store {{.*}}[%[[I]]]
// CHECK: scf.for %[[I:.+]] = %c0 to %[[C4]] step %[[C1]] {
// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[TFLAT]]] by (5, 128)
// CHECK: memref.store {{.*}}[%[[LINID]]]
// CHECK: %[[TIDGUARD:.+]] = arith.cmpi slt, %[[TFLAT]], %[[C1]]
// CHECK: scf.if %[[TIDGUARD]] {
// CHECK: %[[LINID_IF:.+]] = affine.linearize_index disjoint [%[[C4]], %[[TFLAT]]]
// CHECK: memref.store {{.*}}[%[[LINID_IF]]]

// -----

@@ -137,8 +145,8 @@ func.func @distribute_thread_forall_multi_dim(%out : memref<?x?x?xi32>)
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: scf.for %[[I:.+]] = %c0 to %c512 step %c128 {
// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]]
// CHECK: scf.for %[[I:.+]] = %c0 to %c4 step %c1 {
// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[TFLAT]]] by (4, 128)
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LINID]] into (16, 8, 4) : index
// CHECK: memref.store {{.*}}[%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2]

@@ -902,7 +902,8 @@ hal.executable public @main {

// CHECK-LABEL: func @small_matvec
// CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
// CHECK: scf.for %{{.*}} = %{{.*}} to %c1 step %c64
// CHECK: %[[COND:.+]] = arith.cmpi slt, %{{.*}}, %c1
// CHECK: scf.if %[[COND]]

// Verify that the write does not get hoisted out of the single threaded
// for loop.
