[GPU] Use affine.linearize_index (and delinearize_index) where possible
There have been issues with the composition of affine maps being too
general and losing important information. For example,
affine_map<(s0 + s1 * 32 + ... - (s0 floorDiv 16) * 16)> really should
simplify to affine_map<(s0 mod 16 + s1 * 32 + ...)>, since
s0 - (s0 floorDiv 16) * 16 is s0 mod 16 by definition. There have also
been other issues with the resulting IR that block low-level arithmetic
optimizations.

The affine.delinearize_index operation represents the floorDiv/mod
chains needed to break a flat index into its component parts. The
recently added affine.linearize_index operation is its inverse,
combining multiple indices into a single flat 1D index.
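
As a rough illustration of the builder API used throughout this patch
(the {2, 4, 32} workgroup shape is only an example, tidZ/tidY/tidX stand
for gpu.thread_id values, and rewriter/loc are assumed to be in scope):

  // Combine (z, y, x) thread IDs into one flat ID; the basis lists the
  // extent of each position, outermost first. "disjoint" records that
  // each input is already in-bounds for its basis entry.
  Value flatTid = rewriter.create<affine::AffineLinearizeIndexOp>(
      loc, ValueRange{tidZ, tidY, tidX}, ArrayRef<int64_t>{2, 4, 32},
      /*disjoint=*/true);

  // The inverse: recover the (z, y, x) coordinates from the flat ID.
  // Result 0 is the outermost (slowest-varying) coordinate.
  auto coords = rewriter.create<affine::AffineDelinearizeIndexOp>(
      loc, flatTid, ArrayRef<int64_t>{2, 4, 32});
  Value tidZAgain = coords.getResult(0);
  Value tidYAgain = coords.getResult(1);
  Value tidXAgain = coords.getResult(2);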

Another advantage of linearize/delinearize is that these operations have
simpler upstream canonicalization patterns, which leads to more
streamlined generated code.

This PR updates the vector distribution code and other GPU-related
code that I could find to:

1. Use affine.linearize_index to construct flat thread IDs.
2. Use affine.delinearize_index in places where there was a
floorDiv/mod chain (see the sketch after this list).
3. Plumb the subgroup size through the transfer_read and
transfer_write distribution patterns to enable better reasoning about
when a mod of the lane ID is or isn't needed.
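
For example, where the forall-resolution code below needs a subgroup ID,
it now delinearizes the flat thread ID instead of composing a floorDiv
(a simplified sketch; flatWorkgroupSize, subgroupSize, and flatId stand
for the values already available at that point):

  // Split the flat thread ID into (subgroup ID, lane ID).
  auto ids = rewriter.create<affine::AffineDelinearizeIndexOp>(
      loc, flatId,
      ArrayRef<int64_t>{flatWorkgroupSize / subgroupSize, subgroupSize});
  Value subgroupId = ids.getResult(0); // flatId floordiv subgroupSize
  Value laneId = ids.getResult(1);     // flatId mod subgroupSize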
krzysz00 committed Nov 12, 2024
1 parent 4aa08f2 commit c3eabac
Showing 19 changed files with 290 additions and 329 deletions.
@@ -17,6 +17,7 @@
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

namespace mlir::iree_compiler {

@@ -87,9 +88,12 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter,
assert(!(hasThreadMapping && hasWarpMapping));
Value flatId = linearThreadId;
if (hasWarpMapping) {
OpFoldResult subgroupSizeVal = rewriter.getIndexAttr(subgroupSize);
flatId = affine::makeComposedAffineApply(rewriter, loc, d0.floorDiv(d1),
{flatId, subgroupSizeVal});
flatId = rewriter
.create<affine::AffineDelinearizeIndexOp>(
loc, flatId,
ArrayRef<int64_t>{flatWorkgroupSize / subgroupSize,
subgroupSize})
.getResult(0);
}

SmallVector<Value> delinSizes;
@@ -190,23 +194,18 @@ void GPUDistributeForallPass::runOnOperation() {
return signalPassFailure();
}

AffineExpr x, y, z;
bindSymbols(funcOp.getContext(), x, y, z);
// Compute the linearized thread id.
AffineExpr linearId =
x + workgroupSize[0] * y + workgroupSize[1] * workgroupSize[0] * z;

rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
SmallVector<OpFoldResult> threadGrid = {
rewriter.createOrFold<gpu::ThreadIdOp>(funcOp.getLoc(),
gpu::Dimension::x),
rewriter.createOrFold<gpu::ThreadIdOp>(funcOp.getLoc(),
gpu::Dimension::y),
rewriter.createOrFold<gpu::ThreadIdOp>(funcOp.getLoc(),
gpu::Dimension::z)};

Value linearThreadIdVal = affine::makeComposedAffineApply(
rewriter, funcOp.getLoc(), linearId, threadGrid);
SmallVector<Value> threadGrid = {rewriter.createOrFold<gpu::ThreadIdOp>(
funcOp.getLoc(), gpu::Dimension::z),
rewriter.createOrFold<gpu::ThreadIdOp>(
funcOp.getLoc(), gpu::Dimension::y),
rewriter.createOrFold<gpu::ThreadIdOp>(
funcOp.getLoc(), gpu::Dimension::x)};
SmallVector<int64_t> threadGridBasis = {workgroupSize[2], workgroupSize[1],
workgroupSize[0]};

Value linearThreadIdVal = rewriter.create<affine::AffineLinearizeIndexOp>(
funcOp.getLoc(), threadGrid, threadGridBasis, /*disjoint=*/true);
for (auto forall : forallOps) {
rewriter.setInsertionPoint(forall);
if (failed(resolveGPUMappedForallOp(rewriter, forall, linearThreadIdVal,
@@ -189,30 +189,29 @@ SmallVector<linalg::ProcInfo> getIds(OpBuilder &b, Location loc,
ArrayRef<Range> parallelLoopRanges,
Value flatThreadId) {
SmallVector<linalg::ProcInfo> infos;
Value id = flatThreadId;
AffineExpr d0 = b.getAffineDimExpr(0);
for (Range r : llvm::reverse(parallelLoopRanges)) {
linalg::ProcInfo info;
SmallVector<int64_t> delinSizes;
for (Range r : parallelLoopRanges) {
auto offset = dyn_cast<Attribute>(r.offset);
auto stride = dyn_cast<Attribute>(r.stride);
auto size = dyn_cast<Attribute>(r.size);
assert(offset && stride && size);
int64_t numThreadsDim = (llvm::cast<IntegerAttr>(size).getInt() -
llvm::cast<IntegerAttr>(offset).getInt()) /
llvm::cast<IntegerAttr>(stride).getInt();
Value dimId = id;
if (infos.size() != parallelLoopRanges.size() - 1)
dimId =
affine::makeComposedAffineApply(b, loc, d0 % numThreadsDim, {dimId});
delinSizes.push_back(numThreadsDim);
}
ValueRange dims =
b.create<affine::AffineDelinearizeIndexOp>(loc, flatThreadId, delinSizes)
.getResults();

for (auto [dimId, numThreadsDim] : llvm::zip_equal(dims, delinSizes)) {
linalg::ProcInfo info;
info.procId = dimId;
info.nprocs = b.create<arith::ConstantIndexOp>(loc, numThreadsDim);
info.distributionMethod =
linalg::DistributionMethod::CyclicNumProcsEqNumIters;
infos.push_back(info);
id = affine::makeComposedAffineApply(b, loc, d0.floorDiv(numThreadsDim),
{id});
}
std::reverse(infos.begin(), infos.end());
return infos;
}

@@ -288,19 +287,16 @@ static Value createFlatId(mlir::FunctionOpInterface funcOp,
ArrayRef<int64_t> workgroupSize) {
OpBuilder b(funcOp.getFunctionBody());
Type indexType = b.getIndexType();
AffineExpr d0 = getAffineDimExpr(0, b.getContext());
AffineExpr d1 = getAffineDimExpr(1, b.getContext());
AffineExpr d2 = getAffineDimExpr(2, b.getContext());
Value threadX =
b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::x);
Value threadY =
b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::y);
Value threadZ =
b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::z);
Value flatThreadId = affine::makeComposedAffineApply(
b, funcOp.getLoc(),
d0 + workgroupSize[0] * d1 + (workgroupSize[0] * workgroupSize[1]) * d2,
{threadX, threadY, threadZ});
Value flatThreadId = b.create<affine::AffineLinearizeIndexOp>(
funcOp.getLoc(), ValueRange{threadZ, threadY, threadX},
ArrayRef<int64_t>{workgroupSize[2], workgroupSize[1], workgroupSize[0]},
/*disjoint=*/true);
return flatThreadId;
}

@@ -32,54 +32,55 @@ namespace {
/// parameterized by the thread grid.
static SmallVector<Value> computeSIMDIndex(const LayoutIterator::State &state,
LayoutAttr layout, Value laneId,
int64_t subgroupSize,
RewriterBase &rewriter) {
MLIRContext *ctx = layout.getContext();
AffineExpr threadX, threadY, threadZ;
bindSymbols(ctx, threadX, threadY, threadZ);
Location loc = laneId.getLoc();

auto [laneDimX, laneDimY, laneDimZ] = layout.getLaneGrid();
int64_t gridsPerSubgroup =
llvm::divideCeil(subgroupSize, laneDimX * laneDimY * laneDimZ);
// Note: we add an extra leading entry to the delinearization basis to handle
// the case where the vector layout requires fewer lanes than are present in
// the subgroup. Without it we'd, for example, construct delinearizations with
// the basis (1, 1, 16) when there are 32 lanes, which would simplify to no
// delinearization at all; the extra term captures the overflow.
auto reversedLaneGrid = rewriter.create<affine::AffineDelinearizeIndexOp>(
loc, laneId,
ArrayRef<int64_t>{gridsPerSubgroup, laneDimZ, laneDimY, laneDimX});

SmallVector<Value> simdIndex;

// Calculate the index for each dim separately.
for (PerDimLayoutAttr dimLayout : layout.getLayouts()) {
AffineExpr offset = getAffineConstantExpr(0, ctx);
AffineExpr stride = getAffineConstantExpr(1, ctx);
for (auto [label, shape] : llvm::reverse(
llvm::zip(dimLayout.getLabels(), dimLayout.getShapes()))) {
SmallVector<Value> linearizeVals;
for (LayoutDimensionAttr label : dimLayout.getLabels()) {
int64_t position = state.lookup(label.getValue()).getPosition();

// Note: indices are into a reversed lane grid that has an extra leading
// term we must ignore (so the X coordinate is result #3 and the Z
// coordinate is result #1).
switch (label.getValue()) {
case LayoutDimension::LANEX:
offset = offset + stride * threadX;
linearizeVals.push_back(reversedLaneGrid.getResult(3));
break;
case LayoutDimension::LANEY:
offset = offset + stride * threadY;
linearizeVals.push_back(reversedLaneGrid.getResult(2));
break;
case LayoutDimension::LANEZ:
offset = offset + stride * threadZ;
linearizeVals.push_back(reversedLaneGrid.getResult(1));
break;
default:
offset = offset + stride * getAffineConstantExpr(position, ctx);
linearizeVals.push_back(
rewriter.createOrFold<arith::ConstantIndexOp>(loc, position));
break;
}
stride = stride * getAffineConstantExpr(shape, ctx);
}

auto [laneDimX, laneDimY, laneDimZ] = layout.getLaneGrid();
SmallVector<Value> laneGrid = {
rewriter.create<arith::ConstantIndexOp>(laneId.getLoc(), laneDimZ),
rewriter.create<arith::ConstantIndexOp>(laneId.getLoc(), laneDimY),
rewriter.create<arith::ConstantIndexOp>(laneId.getLoc(), laneDimX)};
FailureOr<SmallVector<Value>> maybeReversedLaneGridVals =
affine::delinearizeIndex(rewriter, laneId.getLoc(), laneId, laneGrid);
assert(succeeded(maybeReversedLaneGridVals) &&
"Failed to delinearize lane index");
SmallVector<Value> laneGridVals = {(*maybeReversedLaneGridVals)[2],
(*maybeReversedLaneGridVals)[1],
(*maybeReversedLaneGridVals)[0]};

// Compute the index for the dim.
AffineMap indexMap = AffineMap::get(0, 3, offset);
Value index = rewriter.create<affine::AffineApplyOp>(
rewriter.getUnknownLoc(), indexMap, laneGridVals);
Value index = rewriter.create<affine::AffineLinearizeIndexOp>(
rewriter.getUnknownLoc(), linearizeVals, dimLayout.getShapes(),
/*disjoint=*/true);
simdIndex.push_back(index);
}

@@ -199,8 +200,9 @@ struct DistributeXferLayoutAttr : OpDistributionPattern<OpTy> {
"expected vector::TransferReadOp or vector::TransferWriteOp");

DistributeXferLayoutAttr(MLIRContext *context, Value laneId,
PatternBenefit benefit = 1)
: OpDistributionPattern<OpTy>(context, benefit), laneId(laneId) {}
int64_t subgroupSize, PatternBenefit benefit = 1)
: OpDistributionPattern<OpTy>(context, benefit), laneId(laneId),
subgroupSize(subgroupSize) {}

VectorValue accessMemory(OpTy xferOp, VectorValue accumulator,
LayoutAttr vectorLayout,
@@ -237,7 +239,7 @@ struct DistributeXferLayoutAttr : OpDistributionPattern<OpTy> {
llvm::SmallBitVector &projectedDims,
RewriterBase &rewriter) const {
SmallVector<Value> simdIndices =
computeSIMDIndex(state, memoryLayout, laneId, rewriter);
computeSIMDIndex(state, memoryLayout, laneId, subgroupSize, rewriter);
SmallVector<Value> memoryIndices(indices);

// The memory layout has some projected leading dims that indices doesn't.
@@ -272,6 +274,7 @@ struct DistributeXferLayoutAttr : OpDistributionPattern<OpTy> {
}

Value laneId;
int64_t subgroupSize;
};

struct DistributeTransferReadLayoutAttr final
@@ -940,26 +943,28 @@ struct DistributeLayoutConflictToSharedMemory final

// Offset and indexing computation such that subgroups can
// write and read to shared memory correctly and without conflicts.
AffineExpr d0, d1, d2, s0;
bindDims(rewriter.getContext(), d0, d1, d2);
bindSymbols(rewriter.getContext(), s0);
auto indexType = rewriter.getIndexType();
Value threadX =
rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::x);
Value threadY =
rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::y);
Value threadZ =
rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::z);
Value flatThreadId = affine::makeComposedAffineApply(
rewriter, loc,
(d0 + workgroupSize.value()[0] * d1 +
(workgroupSize.value()[0] * workgroupSize.value()[1]) * d2),
{threadX, threadY, threadZ});
Value subgroupOffset = affine::makeComposedAffineApply(
rewriter, loc,
s0.floorDiv(subgroupSize.value()) *
resolutionType.getShape()[vectorRank - 1],
{flatThreadId});
Value flatThreadId = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, ValueRange{threadZ, threadY, threadX},
ArrayRef<int64_t>{workgroupSize.value()[2], workgroupSize.value()[1],
workgroupSize.value()[0]},
/*disjoint=*/true);

Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto splitBySubgroups = rewriter.create<affine::AffineDelinearizeIndexOp>(
loc, flatThreadId,
ArrayRef<int64_t>{numSubgroups, subgroupSize.value()});
Value subgroupOffset = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, ValueRange{splitBySubgroups.getResult(0), c0},
ArrayRef<int64_t>{numSubgroups,
resolutionType.getShape()[vectorRank - 1]},
/*disjoint=*/true);

// Create shared memory to store the intermediate from src layout.
auto workgroupMemoryAddressSpace = Attribute(gpu::AddressSpaceAttr::get(
@@ -980,7 +985,6 @@ struct DistributeLayoutConflictToSharedMemory final
shapes, strides);

// Creating write/trip to shared memory using src layout.
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> indices(resolutionType.getRank(), c0);
SmallVector<bool> inBounds(vectorRank, true);
auto write = rewriter.create<vector::TransferWriteOp>(loc, vector, subview,
@@ -1117,10 +1121,11 @@ void populateGPUDistributionPatterns(RewritePatternSet &patterns) {
}

void populateGPUDistributionLayoutAttrPatterns(Value laneId,
int64_t subgroupSize,
RewritePatternSet &patterns) {
patterns
.add<DistributeTransferReadLayoutAttr, DistributeTransferWriteLayoutAttr>(
patterns.getContext(), laneId);
patterns.getContext(), laneId, subgroupSize);
patterns.add<DistributeBroadcastLayoutAttr, DistributeTransposeLayoutAttr>(
patterns.getContext());
}
@@ -32,6 +32,7 @@ void populateDropSharedMemoryDeallocOpPatterns(RewritePatternSet &patterns);
void populateGPUDistributionPatterns(RewritePatternSet &patterns);

void populateGPUDistributionLayoutAttrPatterns(Value laneId,
int64_t subgroupSize,
RewritePatternSet &patterns);

void populateGPUReductionDistributionPatterns(RewritePatternSet &patterns,