[LLVMGPU] Refactor VirtualMMAIntrinsic to its own attribute and enum. #19055

Merged · 5 commits · Nov 7, 2024
@@ -50,23 +50,22 @@ isSubgroupLayoutCompatible(IREE::GPU::MMASingleSubgroupLayout subgroupLayout,
   return success();
 }
 
-static LogicalResult isIntrinsicLayoutCompatible(VectorContractOpInfo &opInfo,
-                                                 IREE::GPU::MMAAttr intrinsic,
-                                                 NestedLayoutAttr lhsLayout,
-                                                 NestedLayoutAttr rhsLayout,
-                                                 NestedLayoutAttr accLayout) {
+static LogicalResult isIntrinsicLayoutCompatible(
+    VectorContractOpInfo &opInfo, IREE::GPU::MmaInterfaceAttr intrinsic,
+    NestedLayoutAttr lhsLayout, NestedLayoutAttr rhsLayout,
+    NestedLayoutAttr accLayout) {
   auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
   auto [lhsK, rhsK] = opInfo.getOperandKIndex();
   auto [accM, accN] = opInfo.getResultMNIndex();
-  if (failed(isSubgroupLayoutCompatible(intrinsic.getASingleSubgroupLayout(),
+  if (failed(isSubgroupLayoutCompatible(getASingleSubgroupLayout(intrinsic),
                                         lhsLayout, lhsM, lhsK))) {
     return failure();
   }
-  if (failed(isSubgroupLayoutCompatible(intrinsic.getBSingleSubgroupLayout(),
+  if (failed(isSubgroupLayoutCompatible(getBSingleSubgroupLayout(intrinsic),
                                         rhsLayout, rhsK, rhsN))) {
     return failure();
   }
-  if (failed(isSubgroupLayoutCompatible(intrinsic.getCSingleSubgroupLayout(),
+  if (failed(isSubgroupLayoutCompatible(getCSingleSubgroupLayout(intrinsic),
                                         accLayout, accM, accN))) {
     return failure();
   }
@@ -124,16 +123,16 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
 
     // We assume a decision was made beforehand regarding which mfma intrinsic
     // to use and it is attached as an attribute to this contract op.
-    auto mmaAttr =
-        contractOp->getAttrOfType<IREE::GPU::MMAAttr>("iree.amdgpu.mma");
-    if (!mmaAttr) {
+    auto mmaKind = contractOp->getAttrOfType<IREE::GPU::MmaInterfaceAttr>(
+        "iree.amdgpu.mma");
+    if (!mmaKind) {
       return rewriter.notifyMatchFailure(
           contractOp, "missing iree.amdgpu.mma intrinsic attribute");
     }
 
     // Check if the given intrinsic can be distributed with the given
     // layouts.
-    if (failed(isIntrinsicLayoutCompatible(opDetail, mmaAttr, lhsLayout,
+    if (failed(isIntrinsicLayoutCompatible(opDetail, mmaKind, lhsLayout,
                                            rhsLayout, accLayout))) {
       return rewriter.notifyMatchFailure(
           contractOp, "the intrinsic does not match the expected layouts");
@@ -230,7 +229,7 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
         Value rhsSlice =
             rewriter.create<vector::ExtractOp>(loc, rhs, rhsBatchOffsets);
         accSlice =
-            computeMMA(rewriter, loc, mmaAttr, lhsSlice, rhsSlice, accSlice);
+            computeMMA(rewriter, loc, mmaKind, lhsSlice, rhsSlice, accSlice);
       }
       finalTile = rewriter.create<vector::InsertOp>(loc, accSlice, finalTile,
                                                     resultBatchOffsets);
@@ -285,17 +284,18 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
 
   // Generates amdgpu.mfma operation on the given inputs for the given MFMA
   // |intrinsic|.
-  Value computeMMA(OpBuilder &builder, Location loc, IREE::GPU::MMAAttr mmaAttr,
-                   Value a, Value b, Value c) const {
+  Value computeMMA(OpBuilder &builder, Location loc,
+                   IREE::GPU::MmaInterfaceAttr mmaKind, Value a, Value b,
+                   Value c) const {
     // Get the storage vector types that each thread is in charge of.
-    auto [aVectorType, bVectorType, cVectorType] = mmaAttr.getABCVectorTypes();
+    auto [aVectorType, bVectorType, cVectorType] = mmaKind.getABCVectorTypes();
     Value aCast =
         builder.create<vector::ShapeCastOp>(a.getLoc(), aVectorType, a);
     Value bCast =
         builder.create<vector::ShapeCastOp>(b.getLoc(), bVectorType, b);
     Value cCast =
         builder.create<vector::ShapeCastOp>(c.getLoc(), cVectorType, c);
-    FailureOr<Value> mmaOp = mmaAttr.buildMmaOperation(
+    FailureOr<Value> mmaOp = mmaKind.buildMmaOperation(
         builder, loc, cVectorType, aCast, bCast, cCast);
     assert(succeeded(mmaOp) && "Failed to construct mma op");
     return builder.create<vector::ShapeCastOp>(c.getLoc(), c.getType(), *mmaOp);
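The pattern above replaces member calls on the concrete MMAAttr with free functions that take the MmaInterfaceAttr interface, so one code path can serve both the physical MMAAttr and the new VirtualMMAAttr. Below is a minimal sketch of how such a helper could dispatch; it is illustrative only, and the assumption that VirtualMMAAttr exposes getASingleSubgroupLayout() is mine, not taken from this PR.

  #include "llvm/ADT/TypeSwitch.h"

  // Sketch only: dispatch a per-operand layout query over the two concrete
  // attributes that implement IREE::GPU::MmaInterfaceAttr.
  static IREE::GPU::MMASingleSubgroupLayout
  getASingleSubgroupLayout(IREE::GPU::MmaInterfaceAttr intrinsic) {
    return llvm::TypeSwitch<mlir::Attribute, IREE::GPU::MMASingleSubgroupLayout>(
               intrinsic)
        .Case([](IREE::GPU::MMAAttr mma) {
          // Physical intrinsic: reuse the existing member implementation.
          return mma.getASingleSubgroupLayout();
        })
        .Case([](IREE::GPU::VirtualMMAAttr vmma) {
          // Assumed equivalent query on the new virtual-intrinsic attribute.
          return vmma.getASingleSubgroupLayout();
        })
        .Default([](mlir::Attribute) {
          return IREE::GPU::MMASingleSubgroupLayout{};
        });
  }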
@@ -593,8 +593,8 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
     }
     // If mmaAttr exists, defer the lowering to use MMA.
     // Notify failure if the "iree.amdgpu.mma" intrinsic attribute is present.
-    auto mmaAttr =
-        contractOp->getAttrOfType<IREE::GPU::MMAAttr>("iree.amdgpu.mma");
+    auto mmaAttr = contractOp->getAttrOfType<IREE::GPU::MmaInterfaceAttr>(
+        "iree.amdgpu.mma");
     if (mmaAttr) {
       return rewriter.notifyMatchFailure(
           contractOp, "iree.amdgpu.mma intrinsic attribute exists");
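Because this lookup now asks for the MmaInterfaceAttr interface rather than the concrete MMAAttr, the pattern bails out for any MMA-like attribute, virtual or physical. A hedged sketch of the equivalent explicit check (VirtualMMAAttr implementing the interface is the premise of this PR, not something verified here):

  // Sketch: an interface-typed lookup matches either concrete attribute.
  mlir::Attribute raw = contractOp->getAttr("iree.amdgpu.mma");
  if (auto mmaKind = llvm::dyn_cast_if_present<IREE::GPU::MmaInterfaceAttr>(raw)) {
    // True for #iree_gpu.mma_layout<...> (MMAAttr) and for the new
    // #iree_gpu.virtual_mma_layout<...> (VirtualMMAAttr); defer to the
    // MMA-specific lowering in that case.
    (void)mmaKind;
  }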
@@ -645,7 +645,7 @@ func.func @contract_to_vmfma_32x32x16_mm(%a : vector<32x16xf16>, %b : vector<16x
     indexing_maps = [#map1, #map2, #map3],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>,
-    iree.amdgpu.mma = #iree_gpu.mma_layout<VMFMA_F32_32x32x16_F16>
+    iree.amdgpu.mma = #iree_gpu.virtual_mma_layout<intrinsic = VMFMA_F32_32x32x16_F16>
   } %A, %B, %C : vector<32x16xf16>, vector<16x32xf16> into vector<32x32xf32>
 
   %O = iree_vector_ext.to_layout %output to layout(#layout_c) : vector<32x32xf32>
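The test now spells the virtual intrinsic with its own attribute, #iree_gpu.virtual_mma_layout, instead of overloading #iree_gpu.mma_layout. For illustration, a hypothetical sketch of how a pass could attach it; VirtualMMAAttr::get and the VirtualMMAIntrinsic enum value follow the PR title and the test above, but the exact builder signature is an assumption:

  // Hypothetical: construct the dedicated attribute from the new enum and
  // attach it where the distribution pattern expects to find it.
  mlir::MLIRContext *ctx = contractOp->getContext();
  auto vmma = IREE::GPU::VirtualMMAAttr::get(
      ctx, IREE::GPU::VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16);
  contractOp->setAttr("iree.amdgpu.mma", vmma);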