Skip to content

Commit

Permalink
[NFC][LLVMGPU] Cleanup layout configuration (#19059)
Browse files Browse the repository at this point in the history
Contractions and convolutions previously used some ad hoc logic to set
layouts, which is no longer needed now that codegen relies on
lowering_config to set layouts.
  • Loading branch information
Groverkss authored Nov 11, 2024
1 parent 48f6dee commit 300e0c3
Showing 1 changed file with 61 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,35 @@ static LogicalResult setDerivedThreadConfigLayout(
return success();
}

/// Sets anchor layouts on `candidate` using an MMA schedule built from the
/// intrinsic and subgroup M/N counts queried from the operation.
///
/// The contraction path is attempted first when `candidate` is a contraction;
/// otherwise (or if that attempt fails) the convolution path is attempted when
/// convolution dimensions can be inferred. If neither path succeeds, an error
/// mentioning `config` is emitted on `candidate` and failure is returned.
///
/// `config` is only used for the diagnostic message, and `workgroupSize` is
/// not read by this function.
static LogicalResult setIntrinsicLoweringConfigLayout(
    IREE::GPU::LoweringConfigAttr config, linalg::LinalgOp candidate,
    ArrayRef<int64_t> workgroupSize, RewriterBase &rewriter) {
  SmallVector<bool> promoted = getPromotedOperands(candidate);
  auto mmaSchedule = rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
      getIntrinsic(candidate), getSubgroupMCount(candidate),
      getSubgroupNCount(candidate));

  // Contraction path: short-circuiting keeps the anchor call from running
  // unless the op actually implements the contraction interface.
  const bool anchoredAsContraction =
      linalg::isaContractionOpInterface(candidate) &&
      succeeded(
          setContractionAnchor(mmaSchedule, promoted, rewriter, candidate));
  if (anchoredAsContraction) {
    return success();
  }

  // Convolution path: also reached when the contraction anchor failed.
  const bool anchoredAsConvolution =
      succeeded(linalg::inferConvolutionDims(candidate)) &&
      succeeded(
          setConvolutionAnchor(mmaSchedule, promoted, rewriter, candidate));
  if (anchoredAsConvolution) {
    return success();
  }

  // Neither path could anchor a layout; report the offending config.
  candidate->emitError() << "Unable to set intrinsic layouts on operation "
                            "based on given lowering config: "
                         << config;
  return failure();
}

static Operation *getOpWithAttr(Operation *root, StringRef attr) {
Operation *result = nullptr;
WalkResult walkResult = root->walk([&](Operation *op) {
Expand Down Expand Up @@ -505,16 +534,11 @@ struct LLVMGPUConfigureTensorLayoutsPass final
return signalPassFailure();
}

if (failed(setDerivedConfigLayouts(func, maybeWorkgroupSize.value(),
rewriter))) {
if (failed(setLayoutsFromLoweringConfig(func, maybeWorkgroupSize.value(),
rewriter))) {
return signalPassFailure();
}

// Vector layout option setter aimed at contractions and convolutions. For
// now, layout setting for other problems like reductions is TODO.
SmallVector<linalg::LinalgOp> contracts;
SmallVector<linalg::LinalgOp> convs;

auto attentionQKMatmul = dyn_cast_or_null<linalg::LinalgOp>(
getOpWithAttr(func, "attention_qk_matmul"));
auto attentionPVMatmul = dyn_cast_or_null<linalg::LinalgOp>(
Expand All @@ -530,40 +554,6 @@ struct LLVMGPUConfigureTensorLayoutsPass final
return signalPassFailure();
}

func->walk([&](linalg::LinalgOp linalgOp) {
if (linalgOp == attentionQKMatmul || linalgOp == attentionPVMatmul) {
return WalkResult::advance();
}

if (linalg::isaContractionOpInterface(linalgOp)) {
contracts.push_back(linalgOp);
} else if (succeeded(linalg::inferConvolutionDims(linalgOp))) {
convs.push_back(linalgOp);
}
return WalkResult::advance();
});

for (linalg::LinalgOp contract : contracts) {
SmallVector<bool> promotedOperands = getPromotedOperands(contract);
auto contractionSchedule = rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
getIntrinsic(contract), getSubgroupMCount(contract),
getSubgroupNCount(contract));
if (failed(setContractionAnchor(contractionSchedule, promotedOperands,
rewriter, contract))) {
return signalPassFailure();
}
}

for (linalg::LinalgOp conv : convs) {
SmallVector<bool> promotedOperands = getPromotedOperands(conv);
auto convSchedule = rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
getIntrinsic(conv), getSubgroupMCount(conv), getSubgroupNCount(conv));
if (failed(setConvolutionAnchor(convSchedule, promotedOperands, rewriter,
conv))) {
return signalPassFailure();
}
}

if (attentionQKMatmul && attentionPVMatmul) {
if (failed(setAttentionMatmulAnchor(rewriter, attentionQKMatmul,
attentionPVMatmul))) {
Expand All @@ -572,24 +562,43 @@ struct LLVMGPUConfigureTensorLayoutsPass final
}
}

LogicalResult setDerivedConfigLayouts(FunctionOpInterface funcOp,
ArrayRef<int64_t> workgroupSize,
RewriterBase &rewriter) {
LogicalResult setLayoutsFromLoweringConfig(FunctionOpInterface funcOp,
ArrayRef<int64_t> workgroupSize,
RewriterBase &rewriter) {
SmallVector<linalg::LinalgOp> candidates;
funcOp->walk([&](linalg::LinalgOp op) {
auto config = dyn_cast_or_null<IREE::GPU::DerivedThreadConfigAttr>(
getLoweringConfig(op));
if (config) {
if (getLoweringConfig(op)) {
candidates.push_back(op);
}
});

for (linalg::LinalgOp candidate : candidates) {
auto config = dyn_cast_or_null<IREE::GPU::DerivedThreadConfigAttr>(
getLoweringConfig(candidate));
assert(config);
if (failed(setDerivedThreadConfigLayout(config, candidate, workgroupSize,
rewriter))) {
// Skip attention candidates.
if (candidate->hasAttr("attention_qk_matmul") ||
candidate->hasAttr("attention_pv_matmul")) {
continue;
}

auto result =
TypeSwitch<IREE::Codegen::LoweringConfigAttrInterface, LogicalResult>(
getLoweringConfig(candidate))
.Case([&](IREE::GPU::DerivedThreadConfigAttr config) {
return setDerivedThreadConfigLayout(config, candidate,
workgroupSize, rewriter);
})
.Case([&](IREE::GPU::LoweringConfigAttr config) {
if (config.getMmaKind()) {
return setIntrinsicLoweringConfigLayout(
config, candidate, workgroupSize, rewriter);
}
candidate->emitError() << "Unable to set layouts on operation "
"based on given lowering config: "
<< config;
return failure();
})
.Default([](Attribute) -> LogicalResult { return failure(); });

if (failed(result)) {
return failure();
}
}
Expand Down

0 comments on commit 300e0c3

Please sign in to comment.