Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU] Use scf.if for forall overhangs #19125

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,43 +100,49 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter,
.getResult(0);
}

SmallVector<Value> delinSizes;
OpFoldResult totalLoopTripCount = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> delinSizes;
OpFoldResult producerCount = rewriter.getIndexAttr(1);
for (auto workerCount : forallOp.getMixedUpperBound()) {
delinSizes.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, workerCount));
totalLoopTripCount = affine::makeComposedFoldedAffineApply(
rewriter, loc, d0 * d1, {totalLoopTripCount, workerCount});
delinSizes.push_back(workerCount);
producerCount = affine::makeComposedFoldedAffineApply(
rewriter, loc, d0 * d1, {producerCount, workerCount});
}

// If the total number of producers doesn't evenly divide into
int64_t flatTotalNumWorkers =
hasWarpMapping ? flatWorkgroupSize / subgroupSize : flatWorkgroupSize;
std::optional<int64_t> staticProducerCount =
getConstantIntValue(totalLoopTripCount);
bool perfectlyDivides =
staticProducerCount &&
staticProducerCount.value() % flatTotalNumWorkers == 0;
OpFoldResult newLoopTripCount = affine::makeComposedFoldedAffineApply(
rewriter, loc, d0.floorDiv(flatTotalNumWorkers), producerCount);
OpFoldResult remainingLanes = affine::makeComposedFoldedAffineApply(
rewriter, loc, d0 % flatTotalNumWorkers, {producerCount});

// If the loop isn't guaranteed to perfectly tile onto the workers,
// we will run one more iteration of the loop on the workitems where it
// needs to execute.
std::optional<int64_t> remainingLanesCount =
getConstantIntValue(remainingLanes);
bool hasPostLoopTail =
!remainingLanesCount || remainingLanesCount.value() != 0;
OpFoldResult maxIteration =
hasPostLoopTail
? affine::makeComposedFoldedAffineApply(
rewriter, loc, d0.ceilDiv(flatTotalNumWorkers), {producerCount})
: newLoopTripCount;

// Step 3. Create the `scf.for` loop for the loop.
// If the workgroup count perfectly divides the loop's worker count, then we
// can use a lower bound of 0 and keep the loop bounds static. This helps
// simplify later loop folding patterns without an `affine.linearize_index` op
// to help with inferring int ranges.
Value lb = perfectlyDivides ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
: flatId;
Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, totalLoopTripCount);
Value step =
rewriter.create<arith::ConstantIndexOp>(loc, flatTotalNumWorkers);
Value lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, newLoopTripCount);
Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto forLoop = rewriter.create<scf::ForOp>(loc, lb, ub, step, ValueRange{});
Block *loopBody = forLoop.getBody();

// Get the replacement IDs for the forall iterator ids.
rewriter.setInsertionPointToStart(loopBody);
Value newFlatProducerId =
perfectlyDivides
? affine::makeComposedAffineApply(rewriter, loc, d0 + d1,
{forLoop.getInductionVar(), flatId})
: forLoop.getInductionVar();
Value newFlatProducerId = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, ValueRange{forLoop.getInductionVar(), flatId},
ArrayRef<OpFoldResult>{maxIteration,
rewriter.getIndexAttr(flatTotalNumWorkers)},
/*disjoint=*/true);

// We require a descending relative mapping, so we can reuse the upper bound
// sizes directly.
Expand All @@ -151,6 +157,22 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter,
newBlockArgs);
rewriter.eraseOp(forallTerminator);
rewriter.eraseOp(forallOp);

// Step 5. Create the post-loop code that only executes on some workitems.
if (hasPostLoopTail) {
rewriter.setInsertionPointAfter(forLoop);
IRMapping cloneMap;
Value willExecuteTail = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, flatId,
getValueOrCreateConstantIndexOp(rewriter, loc, remainingLanes));
auto tailIfOp = rewriter.create<scf::IfOp>(
loc, TypeRange{}, willExecuteTail, /*addThenBlock=*/false,
/*addElseBlock=*/false);
cloneMap.map(forLoop.getInductionVar(), ub);
// We're relying on the fact that `scf.for` and `scf.if` share the same
// terminator.
forLoop.getRegion().cloneInto(&tailIfOp.getThenRegion(), cloneMap);
}
return success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ func.func @distribute_thread_forall(%out : memref<?xi32>)
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: scf.for %[[I:.+]] = %c0 to %c1024 step %c128 {
// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]]
// CHECK: scf.for %[[I:.+]] = %c0 to %c8 step %c1 {
// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[TFLAT]]] by (8, 128)
// CHECK: memref.store {{.*}}[%[[LINID]]]

// -----
Expand All @@ -38,8 +38,8 @@ func.func @distribute_warp_forall(%out : memref<?xi32>)
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: %[[WARPSPLIT:.+]]:2 = affine.delinearize_index %[[TFLAT]] into (4, 32)
// CHECK: scf.for %[[I:.+]] = %c0 to %c32 step %c4 {
// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[WARPSPLIT]]#0]
// CHECK: scf.for %[[I:.+]] = %c0 to %c8 step %c1 {
// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[WARPSPLIT]]#0] by (8, 4)
// CHECK: memref.store {{.*}}[%[[LINID]]]

// -----
Expand Down Expand Up @@ -96,8 +96,10 @@ func.func @distribute_thread_forall_single_thread(%out : memref<?xi32>)
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: scf.for %[[I:.+]] = %[[TFLAT]] to %c1 step %c128 {
// CHECK: memref.store {{.*}}[%[[I]]]
// CHECK-NOT: scf.for
// CHECK: %[[TIDGUARD:.+]] = arith.cmpi slt, %[[TFLAT]], %[[C1]]
// CHECK: scf.if %[[TIDGUARD]] {
// CHECK: memref.store {{.*}}[%[[TFLAT]]]

// -----

Expand All @@ -113,12 +115,18 @@ func.func @distribute_thread_forall_overhang(%out : memref<?xi32>)
}

// CHECK-LABEL: func @distribute_thread_forall_overhang
// CHECK-DAG: %[[C513:.+]] = arith.constant 513 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: scf.for %[[I:.+]] = %[[TFLAT]] to %[[C513]] step %c128 {
// CHECK: memref.store {{.*}}[%[[I]]]
// CHECK: scf.for %[[I:.+]] = %c0 to %[[C4]] step %[[C1]] {
// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[TFLAT]]] by (5, 128)
// CHECK: memref.store {{.*}}[%[[LINID]]]
// CHECK: %[[TIDGUARD:.+]] = arith.cmpi slt, %[[TFLAT]], %[[C1]]
// CHECK: scf.if %[[TIDGUARD]] {
// CHECK: %[[LINID_IF:.+]] = affine.linearize_index disjoint [%[[C4]], %[[TFLAT]]]
// CHECK: memref.store {{.*}}[%[[LINID_IF]]]

// -----

Expand All @@ -137,8 +145,8 @@ func.func @distribute_thread_forall_multi_dim(%out : memref<?x?x?xi32>)
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: scf.for %[[I:.+]] = %c0 to %c512 step %c128 {
// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]]
// CHECK: scf.for %[[I:.+]] = %c0 to %c4 step %c1 {
// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[I]], %[[TFLAT]]] by (4, 128)
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LINID]] into (16, 8, 4) : index
// CHECK: memref.store {{.*}}[%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,8 @@ hal.executable public @main {

// CHECK-LABEL: func @small_matvec
// CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
// CHECK: scf.for %{{.*}} = %{{.*}} to %c1 step %c64
// CHECK: %[[COND:.+]] = arith.cmpi slt, %{{.*}}, %c1
// CHECK: scf.if %[[COND]]

// Verify that the write does not get hoisted out of the single threaded
// for loop.
Expand Down
Loading