Skip to content

Commit

Permalink
[LLVMGPU] Make linalg_ext.online_attention op the root op
Browse files Browse the repository at this point in the history
This commit changes LLVMGPUSelectLoweringStrategy to look
for the linalg_ext.online_attention op. To make that
work, the lowering from linalg_ext.attention is hoisted
to just before the pass.

As a consequence, the linalg_ext.online_attention op now
implements LinalgFusionInterface and generateResultTileValue.

Signed-off-by: Manupa Karunaratne <[email protected]>
  • Loading branch information
manupak committed Nov 12, 2024
1 parent 87e6e09 commit ba5c3fc
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 27 deletions.
12 changes: 6 additions & 6 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -722,10 +722,9 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
targetSubgroupSize, pipelineConfig);
}

static LogicalResult
setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint,
IREE::LinalgExt::AttentionOp op) {
static LogicalResult setOnlineAttentionVectorDistributionConfig(
IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint,
IREE::LinalgExt::OnlineAttentionOp op) {
if (target.getWgp().getMma().empty())
return failure();

Expand Down Expand Up @@ -1014,9 +1013,10 @@ setVectorDistributionConfig(IREE::GPU::TargetAttr target,
}
}

if (auto attnOp = dyn_cast<IREE::LinalgExt::AttentionOp>(computeOp)) {
if (auto attnOp = dyn_cast<IREE::LinalgExt::OnlineAttentionOp>(computeOp)) {
LDBG("VectorDistribution: trying to find a suitable attention config");
return setAttentionVectorDistributionConfig(target, entryPoint, attnOp);
return setOnlineAttentionVectorDistributionConfig(target, entryPoint,
attnOp);
}

LDBG("VectorDistribution: failed to find a suitable config");
Expand Down
6 changes: 2 additions & 4 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,10 +840,6 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
if (usePadToModelSharedMemcpy) {
funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass());
}

funcPassManager.addPass(
IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());

funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createGPUPromoteMatmulOperandsPass());
Expand Down Expand Up @@ -1167,6 +1163,8 @@ static void buildLLVMGPUCodegenConfigurationPassPipelineImpl(
funcPassManager.addPass(createBlockDynamicDimensionsPass);
funcPassManager.addPass(createCanonicalizerPass);
funcPassManager.addPass(createCSEPass);
funcPassManager.addPass(
IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass);
}
modulePassManager.addPass(createMaterializeUserConfigsPass());
modulePassManager.addPass(createLLVMGPUSelectLoweringStrategyPass());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --iree-codegen-llvmgpu-use-vector-distribution \
// RUN: --iree-codegen-llvmgpu-use-unaligned-gemm-vector-distribution --iree-codegen-llvmgpu-use-igemm=false \
// RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s
// RUN: --pass-pipeline="builtin.module(func.func(iree-linalg-ext-convert-attention-to-online-attention), iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s

// TODO: This test is still using the legacy LLVMGPU kernel config. This needs
// to be migrated to the rocdl heuristics, but for now is just physically
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 \
// RUN: --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline, func.func(iree-llvmgpu-lower-executable-target)))))" \
// RUN: %s | FileCheck %s

// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 \
// RUN: --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline, func.func(iree-llvmgpu-lower-executable-target)))))" \
// RUN: %s | FileCheck %s --check-prefix=MEMORY

#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2}>
Expand Down
55 changes: 42 additions & 13 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1340,17 +1340,11 @@ SmallVector<AffineMap> AttentionOp::getIndexingMapsArray() {
getIndexingMaps().getAsValueRange<AffineMapAttr>());
}

FailureOr<SmallVector<int64_t>> AttentionOp::getStaticLoopRanges() {
SmallVector<int64_t> bounds(getIterationDomainRank());
SmallVector<bool> dimsFound(getIterationDomainRank(), false);

// batch(s), m, k1
ArrayRef<int64_t> queryShape = getQuery().getType().getShape();
ArrayRef<AffineExpr> queryDims = getQueryMap().getResults();
// batch(s), k2, n
ArrayRef<int64_t> valueShape = getValue().getType().getShape();
ArrayRef<AffineExpr> valueDims = getValueMap().getResults();

static SmallVector<int64_t> getStaticLoopRangesForAttnLikeOps(
ArrayRef<int64_t> qShape, ArrayRef<AffineExpr> qDims,
ArrayRef<int64_t> vShape, ArrayRef<AffineExpr> vDims, int64_t iterDomRank) {
SmallVector<int64_t> bounds(iterDomRank);
SmallVector<bool> dimsFound(iterDomRank, false);
auto fillSizes = [&](ArrayRef<int64_t> sizes, ArrayRef<AffineExpr> dims) {
for (auto [size, dim] : llvm::zip_equal(sizes, dims)) {
int pos = cast<AffineDimExpr>(dim).getPosition();
Expand All @@ -1361,11 +1355,23 @@ FailureOr<SmallVector<int64_t>> AttentionOp::getStaticLoopRanges() {
dimsFound[pos] = true;
}
};
fillSizes(queryShape, queryDims);
fillSizes(valueShape, valueDims);
fillSizes(qShape, qDims);
fillSizes(vShape, vDims);
return bounds;
}

FailureOr<SmallVector<int64_t>> AttentionOp::getStaticLoopRanges() {
  // The query covers batch(s), m, k1 and the value covers batch(s), k2, n;
  // together their shapes bound every dimension of the iteration domain.
  return getStaticLoopRangesForAttnLikeOps(
      getQuery().getType().getShape(), getQueryMap().getResults(),
      getValue().getType().getShape(), getValueMap().getResults(),
      getIterationDomainRank());
}

SmallVector<AffineMap> AttentionOp::getIndexingMapsForOperands() {
auto maps = getIndexingMapsArray();
maps.resize(getNumDpsInputs());
Expand Down Expand Up @@ -1514,6 +1520,29 @@ SmallVector<AffineMap> OnlineAttentionOp::getIndexingMapsArray() {
getIndexingMaps().getAsValueRange<AffineMapAttr>());
}

FailureOr<SmallVector<int64_t>> OnlineAttentionOp::getStaticLoopRanges() {
  // Query shape provides batch(s), m, k1; value shape provides
  // batch(s), k2, n. Combined, they pin down every loop bound.
  ArrayRef<int64_t> qShape = getQuery().getType().getShape();
  ArrayRef<AffineExpr> qDims = getQueryMap().getResults();
  ArrayRef<int64_t> vShape = getValue().getType().getShape();
  ArrayRef<AffineExpr> vDims = getValueMap().getResults();
  return getStaticLoopRangesForAttnLikeOps(qShape, qDims, vShape, vDims,
                                           getIterationDomainRank());
}

SmallVector<AffineMap> OnlineAttentionOp::getIndexingMapsForOperands() {
  // The indexing-map list stores input maps first; keep only that prefix.
  SmallVector<AffineMap> operandMaps = getIndexingMapsArray();
  operandMaps.erase(operandMaps.begin() + getNumDpsInputs(),
                    operandMaps.end());
  return operandMaps;
}

SmallVector<AffineMap> OnlineAttentionOp::getIndexingMapsForResults() {
  // Result maps trail the operand maps; drop the input prefix.
  SmallVector<AffineMap> resultMaps = getIndexingMapsArray();
  resultMaps.erase(resultMaps.begin(),
                   resultMaps.begin() + getNumDpsInputs());
  return resultMaps;
}

//===----------------------------------------------------------------------===//
// Im2colOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -577,13 +577,17 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SingleBlockImplicitTerminator<"::mlir::iree_compiler::IREE::LinalgExt::YieldOp">,
DestinationStyleOpInterface, LinalgExtInterface,
DeclareOpInterfaceMethods<LinalgFusionInterface,
["getIndexingMapsForResults", "getIndexingMapsForOperands",
"getStaticLoopRanges"]>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
"getTiledImplementation",
"generateResultTileValue"]>]> {
let summary = "Online Attention operator";
let description = [{
Traditional scaled dot product attention computes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2136,6 +2136,37 @@ LogicalResult OnlineAttentionOp::getResultTilePosition(
return success();
}

FailureOr<TilingResult> OnlineAttentionOp::generateResultTileValue(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  // `offsets`/`sizes` are expressed relative to the requested result's
  // indexing map; translate them into the op's full iteration space before
  // asking for a tiled implementation.

  // Start from the identity tile: zero offsets and the full extent of every
  // iteration dimension.
  int64_t iterRank = getIterationDomainRank();
  SmallVector<OpFoldResult> fullOffsets(iterRank, builder.getIndexAttr(0));
  SmallVector<OpFoldResult> fullSizes;
  fullSizes.reserve(iterRank);
  for (const Range &range : getIterationDomain(builder))
    fullSizes.push_back(range.size);

  // Scatter the result-relative offsets/sizes into the iteration-space
  // positions named by the result's indexing map.
  AffineMap resultMap =
      getIndexingMapsArray()[getNumDpsInputs() + resultNumber];
  for (auto [idx, dimExpr] : llvm::enumerate(resultMap.getResults())) {
    int pos = cast<AffineDimExpr>(dimExpr).getPosition();
    fullOffsets[pos] = offsets[idx];
    fullSizes[pos] = sizes[idx];
  }

  FailureOr<TilingResult> tiled =
      getTiledImplementation(builder, fullOffsets, fullSizes);
  if (failed(tiled)) {
    return failure();
  }
  // Narrow the full tiling result down to only the value/slice the caller
  // requested.
  return TilingResult{tiled->tiledOps,
                      {tiled->tiledValues[resultNumber]},
                      {tiled->generatedSlices[resultNumber]}};
}

//===---------------------------------------------------------------------===//
// CustomOp
//===---------------------------------------------------------------------===//
Expand Down

0 comments on commit ba5c3fc

Please sign in to comment.