[GPU] Use affine.linearize_index (and delinearize_index) where possible #19087

Closed
@@ -17,6 +17,7 @@
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

namespace mlir::iree_compiler {

@@ -87,48 +88,57 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter,
assert(!(hasThreadMapping && hasWarpMapping));
Value flatId = linearThreadId;
if (hasWarpMapping) {
OpFoldResult subgroupSizeVal = rewriter.getIndexAttr(subgroupSize);
flatId = affine::makeComposedAffineApply(rewriter, loc, d0.floorDiv(d1),
{flatId, subgroupSizeVal});
flatId = rewriter
.create<affine::AffineDelinearizeIndexOp>(
loc, flatId,
ArrayRef<int64_t>{flatWorkgroupSize / subgroupSize,
subgroupSize})
.getResult(0);
}
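
Illustrative sketch, not part of the diff: assuming a flat workgroup size of 256 and a subgroup size of 64 (and placeholder SSA names), the warp-ID computation changes from an affine floordiv into the leading result of a delinearization, roughly:

    // Before: %warp = affine.apply affine_map<()[s0] -> (s0 floordiv 64)>()[%flat_id]
    // After: result #0 is the warp ID; result #1 (the lane ID) is unused here.
    %warp, %lane = affine.delinearize_index %flat_id into (4, 64) : index, index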

SmallVector<Value> delinSizes;
OpFoldResult totalLoopTripCount = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> delinSizes;
OpFoldResult producerCount = rewriter.getIndexAttr(1);
for (auto workerCount : forallOp.getMixedUpperBound()) {
delinSizes.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, workerCount));
totalLoopTripCount = affine::makeComposedFoldedAffineApply(
rewriter, loc, d0 * d1, {totalLoopTripCount, workerCount});
delinSizes.push_back(workerCount);
producerCount = affine::makeComposedFoldedAffineApply(
rewriter, loc, d0 * d1, {producerCount, workerCount});
}

// If the total number of producers doesn't evenly divide into
int64_t flatTotalNumWorkers =
hasWarpMapping ? flatWorkgroupSize / subgroupSize : flatWorkgroupSize;
std::optional<int64_t> staticProducerCount =
getConstantIntValue(totalLoopTripCount);
bool perfectlyDivides =
staticProducerCount &&
staticProducerCount.value() % flatTotalNumWorkers == 0;
OpFoldResult newLoopTripCount = affine::makeComposedFoldedAffineApply(
rewriter, loc, d0.floorDiv(flatTotalNumWorkers), producerCount);
OpFoldResult remainingLanes = affine::makeComposedFoldedAffineApply(
rewriter, loc, d0 % flatTotalNumWorkers, {producerCount});

// If the loop isn't guaranteed to perfectly tile onto the workers,
// we will run one more iteration of the loop on the workitems where it
// needs to execute.
std::optional<int64_t> remainingLanesCount =
getConstantIntValue(remainingLanes);
bool hasPostLoopTail =
!remainingLanesCount || remainingLanesCount.value() != 0;
OpFoldResult maxIteration =
hasPostLoopTail
? affine::makeComposedFoldedAffineApply(
rewriter, loc, d0.ceilDiv(flatTotalNumWorkers), {producerCount})
: newLoopTripCount;

// Step 3. Create the `scf.for` loop for the distributed iteration.
// If the workgroup count perfectly divides the loop's worker count, then we
// can use a lower bound of 0 and keep the loop bounds static. This helps
// simplify later loop folding patterns without an `affine.linearize_index` op
// to help with inferring int ranges.
Value lb = perfectlyDivides ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
: flatId;
Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, totalLoopTripCount);
Value step =
rewriter.create<arith::ConstantIndexOp>(loc, flatTotalNumWorkers);
Value lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, newLoopTripCount);
Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto forLoop = rewriter.create<scf::ForOp>(loc, lb, ub, step, ValueRange{});
Block *loopBody = forLoop.getBody();

// Get the replacement IDs for the forall iterator ids.
rewriter.setInsertionPointToStart(loopBody);
Value newFlatProducerId =
perfectlyDivides
? affine::makeComposedAffineApply(rewriter, loc, d0 + d1,
{forLoop.getInductionVar(), flatId})
: forLoop.getInductionVar();
Value newFlatProducerId = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, ValueRange{forLoop.getInductionVar(), flatId},
ArrayRef<OpFoldResult>{maxIteration,
rewriter.getIndexAttr(flatTotalNumWorkers)},
/*disjoint=*/true);
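
As a concrete sketch with assumed values (300 producers and 128 flat workers, so the static trip count is 300 floordiv 128 = 2 and the iteration bound is ceil(300 / 128) = 3), the rewritten loop looks roughly like:

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index  // 300 floordiv 128
    scf.for %i = %c0 to %c2 step %c1 {
      // Flat producer ID = %i * 128 + %flat_id; `disjoint` holds since %flat_id < 128.
      %producer = affine.linearize_index disjoint [%i, %flat_id] by (3, 128) : index
      // ... loop body indexed by %producer ...
    }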

// We require a descending relative mapping, so we can reuse the upper bound
// sizes directly.
@@ -143,6 +153,22 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter,
newBlockArgs);
rewriter.eraseOp(forallTerminator);
rewriter.eraseOp(forallOp);

// Step 5. Create the post-loop code that only executes on some workitems.
if (hasPostLoopTail) {
[Inline review comment (Contributor)]: Could we move this into a separate PR?

rewriter.setInsertionPointAfter(forLoop);
IRMapping cloneMap;
Value willExecuteTail = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, flatId,
getValueOrCreateConstantIndexOp(rewriter, loc, remainingLanes));
auto tailIfOp = rewriter.create<scf::IfOp>(
loc, TypeRange{}, willExecuteTail, /*addThenBlock=*/false,
/*addElseBlock=*/false);
cloneMap.map(forLoop.getInductionVar(), ub);
// We're relying on the fact that `scf.for` and `scf.if` share the same
// terminator.
forLoop.getRegion().cloneInto(&tailIfOp.getThenRegion(), cloneMap);
}
return success();
}
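
Continuing the same sketch (300 producers, 128 workers, loop upper bound %c2), the post-loop tail clones the body under a lane predicate, roughly:

    // 300 mod 128 = 44: workers 0..43 run one extra iteration at %i = %c2.
    %c44 = arith.constant 44 : index
    %in_tail = arith.cmpi slt, %flat_id, %c44 : index
    scf.if %in_tail {
      %producer = affine.linearize_index disjoint [%c2, %flat_id] by (3, 128) : index
      // ... cloned loop body ...
    }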

@@ -190,23 +216,18 @@ void GPUDistributeForallPass::runOnOperation() {
return signalPassFailure();
}

AffineExpr x, y, z;
bindSymbols(funcOp.getContext(), x, y, z);
// Compute the linearized thread id.
AffineExpr linearId =
x + workgroupSize[0] * y + workgroupSize[1] * workgroupSize[0] * z;

rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
SmallVector<OpFoldResult> threadGrid = {
rewriter.createOrFold<gpu::ThreadIdOp>(funcOp.getLoc(),
gpu::Dimension::x),
rewriter.createOrFold<gpu::ThreadIdOp>(funcOp.getLoc(),
gpu::Dimension::y),
rewriter.createOrFold<gpu::ThreadIdOp>(funcOp.getLoc(),
gpu::Dimension::z)};

Value linearThreadIdVal = affine::makeComposedAffineApply(
rewriter, funcOp.getLoc(), linearId, threadGrid);
SmallVector<Value> threadGrid = {rewriter.createOrFold<gpu::ThreadIdOp>(
funcOp.getLoc(), gpu::Dimension::z),
rewriter.createOrFold<gpu::ThreadIdOp>(
funcOp.getLoc(), gpu::Dimension::y),
rewriter.createOrFold<gpu::ThreadIdOp>(
funcOp.getLoc(), gpu::Dimension::x)};
SmallVector<int64_t> threadGridBasis = {workgroupSize[2], workgroupSize[1],
workgroupSize[0]};

Value linearThreadIdVal = rewriter.create<affine::AffineLinearizeIndexOp>(
funcOp.getLoc(), threadGrid, threadGridBasis, /*disjoint=*/true);
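
For illustration, with an assumed workgroup size of (x, y, z) = (64, 2, 2), the pass now emits:

    %tz = gpu.thread_id z
    %ty = gpu.thread_id y
    %tx = gpu.thread_id x
    // Equals %tx + 64 * %ty + 128 * %tz; `disjoint` asserts each ID is in
    // bounds for its basis entry, which helps integer range inference later.
    %flat_tid = affine.linearize_index disjoint [%tz, %ty, %tx] by (2, 2, 64) : index
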
for (auto forall : forallOps) {
rewriter.setInsertionPoint(forall);
if (failed(resolveGPUMappedForallOp(rewriter, forall, linearThreadIdVal,
@@ -189,30 +189,29 @@ SmallVector<linalg::ProcInfo> getIds(OpBuilder &b, Location loc,
ArrayRef<Range> parallelLoopRanges,
Value flatThreadId) {
SmallVector<linalg::ProcInfo> infos;
Value id = flatThreadId;
AffineExpr d0 = b.getAffineDimExpr(0);
for (Range r : llvm::reverse(parallelLoopRanges)) {
linalg::ProcInfo info;
SmallVector<int64_t> delinSizes;
for (Range r : parallelLoopRanges) {
auto offset = dyn_cast<Attribute>(r.offset);
auto stride = dyn_cast<Attribute>(r.stride);
auto size = dyn_cast<Attribute>(r.size);
assert(offset && stride && size);
int64_t numThreadsDim = (llvm::cast<IntegerAttr>(size).getInt() -
llvm::cast<IntegerAttr>(offset).getInt()) /
llvm::cast<IntegerAttr>(stride).getInt();
Value dimId = id;
if (infos.size() != parallelLoopRanges.size() - 1)
dimId =
affine::makeComposedAffineApply(b, loc, d0 % numThreadsDim, {dimId});
delinSizes.push_back(numThreadsDim);
}
ValueRange dims =
b.create<affine::AffineDelinearizeIndexOp>(loc, flatThreadId, delinSizes)
.getResults();

for (auto [dimId, numThreadsDim] : llvm::zip_equal(dims, delinSizes)) {
linalg::ProcInfo info;
info.procId = dimId;
info.nprocs = b.create<arith::ConstantIndexOp>(loc, numThreadsDim);
info.distributionMethod =
linalg::DistributionMethod::CyclicNumProcsEqNumIters;
infos.push_back(info);
id = affine::makeComposedAffineApply(b, loc, d0.floorDiv(numThreadsDim),
{id});
}
std::reverse(infos.begin(), infos.end());
return infos;
}
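
A sketch of the new IR, assuming three parallel loop ranges with sizes (4, 8, 2) and a placeholder %flat_tid:

    // One delinearization replaces the old chain of mod / floordiv
    // affine.apply ops; result #i is the processor ID for parallel dim #i.
    %ids:3 = affine.delinearize_index %flat_tid into (4, 8, 2) : index, index, index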

@@ -288,19 +287,16 @@ static Value createFlatId(mlir::FunctionOpInterface funcOp,
ArrayRef<int64_t> workgroupSize) {
OpBuilder b(funcOp.getFunctionBody());
Type indexType = b.getIndexType();
AffineExpr d0 = getAffineDimExpr(0, b.getContext());
AffineExpr d1 = getAffineDimExpr(1, b.getContext());
AffineExpr d2 = getAffineDimExpr(2, b.getContext());
Value threadX =
b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::x);
Value threadY =
b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::y);
Value threadZ =
b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::z);
Value flatThreadId = affine::makeComposedAffineApply(
b, funcOp.getLoc(),
d0 + workgroupSize[0] * d1 + (workgroupSize[0] * workgroupSize[1]) * d2,
{threadX, threadY, threadZ});
Value flatThreadId = b.create<affine::AffineLinearizeIndexOp>(
funcOp.getLoc(), ValueRange{threadZ, threadY, threadX},
ArrayRef<int64_t>{workgroupSize[2], workgroupSize[1], workgroupSize[0]},
/*disjoint=*/true);
return flatThreadId;
}

@@ -32,54 +32,55 @@ namespace {
/// parameterized by the thread grid.
static SmallVector<Value> computeSIMDIndex(const LayoutIterator::State &state,
LayoutAttr layout, Value laneId,
int64_t subgroupSize,
RewriterBase &rewriter) {
MLIRContext *ctx = layout.getContext();
AffineExpr threadX, threadY, threadZ;
bindSymbols(ctx, threadX, threadY, threadZ);
Location loc = laneId.getLoc();

auto [laneDimX, laneDimY, laneDimZ] = layout.getLaneGrid();
int64_t gridsPerSubgroup =
llvm::divideCeil(subgroupSize, laneDimX * laneDimY * laneDimZ);
// Note: we add an extra leading entry to the delinearization basis to handle
// the case where the vector layout requires fewer lanes than are present in
// the subgroup. Otherwise we'd, for example, construct delinearizations with
// the basis (1, 1, 16) when there are 32 lanes, which would simplify to no
// delinearization at all. The extra term captures the overflow.
auto reversedLaneGrid = rewriter.create<affine::AffineDelinearizeIndexOp>(
loc, laneId,
ArrayRef<int64_t>{gridsPerSubgroup, laneDimZ, laneDimY, laneDimX});
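
Using the numbers from the comment above (a 32-lane subgroup and a lane grid of (z, y, x) = (1, 1, 16), so gridsPerSubgroup = ceil(32 / 16) = 2), the sketch is:

    // Without the extra leading term the basis would be (1, 1, 16), whose
    // delinearization folds away entirely under a 32-lane subgroup; the
    // leading 2 captures the overflow instead.
    %g:4 = affine.delinearize_index %lane_id into (2, 1, 1, 16) : index, index, index, index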

SmallVector<Value> simdIndex;

// Calculate the index for each dim separately.
for (PerDimLayoutAttr dimLayout : layout.getLayouts()) {
AffineExpr offset = getAffineConstantExpr(0, ctx);
AffineExpr stride = getAffineConstantExpr(1, ctx);
for (auto [label, shape] : llvm::reverse(
llvm::zip(dimLayout.getLabels(), dimLayout.getShapes()))) {
SmallVector<Value> linearizeVals;
for (LayoutDimensionAttr label : dimLayout.getLabels()) {
int64_t position = state.lookup(label.getValue()).getPosition();

// Note: indices are into a reversed lane grid that has an extra leading
// term we must ignore (so the X coordinate is result #3 and the Z
// coordinate is result #1).
switch (label.getValue()) {
case LayoutDimension::LANEX:
offset = offset + stride * threadX;
linearizeVals.push_back(reversedLaneGrid.getResult(3));
break;
case LayoutDimension::LANEY:
offset = offset + stride * threadY;
linearizeVals.push_back(reversedLaneGrid.getResult(2));
break;
case LayoutDimension::LANEZ:
offset = offset + stride * threadZ;
linearizeVals.push_back(reversedLaneGrid.getResult(1));
break;
default:
offset = offset + stride * getAffineConstantExpr(position, ctx);
linearizeVals.push_back(
rewriter.createOrFold<arith::ConstantIndexOp>(loc, position));
break;
}
stride = stride * getAffineConstantExpr(shape, ctx);
}

auto [laneDimX, laneDimY, laneDimZ] = layout.getLaneGrid();
SmallVector<Value> laneGrid = {
rewriter.create<arith::ConstantIndexOp>(laneId.getLoc(), laneDimZ),
rewriter.create<arith::ConstantIndexOp>(laneId.getLoc(), laneDimY),
rewriter.create<arith::ConstantIndexOp>(laneId.getLoc(), laneDimX)};
FailureOr<SmallVector<Value>> maybeReversedLaneGridVals =
affine::delinearizeIndex(rewriter, laneId.getLoc(), laneId, laneGrid);
assert(succeeded(maybeReversedLaneGridVals) &&
"Failed to delinearize lane index");
SmallVector<Value> laneGridVals = {(*maybeReversedLaneGridVals)[2],
(*maybeReversedLaneGridVals)[1],
(*maybeReversedLaneGridVals)[0]};

// Compute the index for the dim.
AffineMap indexMap = AffineMap::get(0, 3, offset);
Value index = rewriter.create<affine::AffineApplyOp>(
rewriter.getUnknownLoc(), indexMap, laneGridVals);
Value index = rewriter.create<affine::AffineLinearizeIndexOp>(
rewriter.getUnknownLoc(), linearizeVals, dimLayout.getShapes(),
/*disjoint=*/true);
simdIndex.push_back(index);
}
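
For a single dim layout, say hypothetical labels [LANEY, VECTORX] with shapes (4, 8), the per-dim index becomes roughly:

    // %laney is result #2 of the reversed lane-grid delinearization above;
    // the VECTORX coordinate is a constant position from the iterator state.
    %c3 = arith.constant 3 : index
    %idx = affine.linearize_index disjoint [%laney, %c3] by (4, 8) : index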

@@ -199,8 +200,9 @@ struct DistributeXferLayoutAttr : OpDistributionPattern<OpTy> {
"expected vector::TransferReadOp or vector::TransferWriteOp");

DistributeXferLayoutAttr(MLIRContext *context, Value laneId,
PatternBenefit benefit = 1)
: OpDistributionPattern<OpTy>(context, benefit), laneId(laneId) {}
int64_t subgroupSize, PatternBenefit benefit = 1)
: OpDistributionPattern<OpTy>(context, benefit), laneId(laneId),
subgroupSize(subgroupSize) {}

VectorValue accessMemory(OpTy xferOp, VectorValue accumulator,
LayoutAttr vectorLayout,
@@ -237,7 +239,7 @@ struct DistributeXferLayoutAttr : OpDistributionPattern<OpTy> {
llvm::SmallBitVector &projectedDims,
RewriterBase &rewriter) const {
SmallVector<Value> simdIndices =
computeSIMDIndex(state, memoryLayout, laneId, rewriter);
computeSIMDIndex(state, memoryLayout, laneId, subgroupSize, rewriter);
SmallVector<Value> memoryIndices(indices);

// The memory layout has some projected leading dims that indices doesn't.
@@ -272,6 +274,7 @@ struct DistributeXferLayoutAttr : OpDistributionPattern<OpTy> {
}

Value laneId;
int64_t subgroupSize;
};

struct DistributeTransferReadLayoutAttr final
@@ -940,26 +943,28 @@ struct DistributeLayoutConflictToSharedMemory final

// Offset and indexing computation such that subgroups can
// write and read to shared memory correctly and without conflicts.
AffineExpr d0, d1, d2, s0;
bindDims(rewriter.getContext(), d0, d1, d2);
bindSymbols(rewriter.getContext(), s0);
auto indexType = rewriter.getIndexType();
Value threadX =
rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::x);
Value threadY =
rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::y);
Value threadZ =
rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::z);
Value flatThreadId = affine::makeComposedAffineApply(
rewriter, loc,
(d0 + workgroupSize.value()[0] * d1 +
(workgroupSize.value()[0] * workgroupSize.value()[1]) * d2),
{threadX, threadY, threadZ});
Value subgroupOffset = affine::makeComposedAffineApply(
rewriter, loc,
s0.floorDiv(subgroupSize.value()) *
resolutionType.getShape()[vectorRank - 1],
{flatThreadId});
Value flatThreadId = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, ValueRange{threadZ, threadY, threadX},
ArrayRef<int64_t>{workgroupSize.value()[2], workgroupSize.value()[1],
workgroupSize.value()[0]},
/*disjoint=*/true);

Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto splitBySubgroups = rewriter.create<affine::AffineDelinearizeIndexOp>(
loc, flatThreadId,
ArrayRef<int64_t>{numSubgroups, subgroupSize.value()});
Value subgroupOffset = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, ValueRange{splitBySubgroups.getResult(0), c0},
ArrayRef<int64_t>{numSubgroups,
resolutionType.getShape()[vectorRank - 1]},
/*disjoint=*/true);
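
A sketch with assumed values (numSubgroups = 4, subgroup size 64, inner vector dimension 16) and placeholder SSA names:

    // Split the flat ID into (subgroup, lane), then compute
    // offset = subgroup * 16 by linearizing against the inner dimension.
    %sg:2 = affine.delinearize_index %flat_id into (4, 64) : index, index
    %zero = arith.constant 0 : index
    %offset = affine.linearize_index disjoint [%sg#0, %zero] by (4, 16) : index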

// Create shared memory to store the intermediate from src layout.
auto workgroupMemoryAddressSpace = Attribute(gpu::AddressSpaceAttr::get(
Expand All @@ -980,7 +985,6 @@ struct DistributeLayoutConflictToSharedMemory final
shapes, strides);

// Creating write/trip to shared memory using src layout.
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> indices(resolutionType.getRank(), c0);
SmallVector<bool> inBounds(vectorRank, true);
auto write = rewriter.create<vector::TransferWriteOp>(loc, vector, subview,
@@ -1117,10 +1121,11 @@ void populateGPUDistributionPatterns(RewritePatternSet &patterns) {
}

void populateGPUDistributionLayoutAttrPatterns(Value laneId,
int64_t subgroupSize,
RewritePatternSet &patterns) {
patterns
.add<DistributeTransferReadLayoutAttr, DistributeTransferWriteLayoutAttr>(
patterns.getContext(), laneId);
patterns.getContext(), laneId, subgroupSize);
patterns.add<DistributeBroadcastLayoutAttr, DistributeTransposeLayoutAttr>(
patterns.getContext());
}