[LLVMGPU] Refactor VirtualMMAIntrinsic to its own attribute and enum. (iree-org#19055)

VirtualMMAs are variants of MMA ops whose layouts are modified through
interleaved reads and/or an unrolled K dimension. These are especially useful
for coalescing reads from shared memory to registers, and for Flash Attention
(FA) to align layouts between the first and second matmul.
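For a concrete picture, the lit test updated later in this commit tags a
contraction with the new dedicated attribute. An abridged excerpt follows; the
maps and operands are defined elsewhere in the test file:

```mlir
// Abridged from the test diff in this PR; #map1/#map2/#map3 and the
// %A/%B/%C operands are defined earlier in the test.
%output = vector.contract {
  indexing_maps = [#map1, #map2, #map3],
  iterator_types = ["parallel", "parallel", "reduction"],
  kind = #vector.kind<add>,
  iree.amdgpu.mma = #iree_gpu.virtual_mma_layout<intrinsic = VMFMA_F32_32x32x16_F16>
} %A, %B, %C : vector<32x16xf16>, vector<16x32xf16> into vector<32x32xf32>
```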

However, the current implementation of virtual MMA lives in MMAAttr and
MMAIntrinsic, which violates the abstraction boundary, especially since
data_tiled_mma_layout depends on them. Keeping VirtualMMA inside MMAAttr also
limits the flexibility to extend and further specify it in the future.

To resolve this conflict of abstractions, we are migrating VirtualMMAAttr to
its own enum and attr/class.

Changes made in this PR:

1. Factor VMFMA out of `MMAAttr` and `MMAIntrinsic` into its own attr and
enum, `VirtualMMAAttr` and `VirtualMMAIntrinsic`.
2. Update places that may use either `MMAAttr` or `VirtualMMAAttr` to work on
`MmaInterfaceAttr` instead, such as `KernelConfig`,
`LLVMGPUConfigureTensorLayout`, `AMDGPUDistributeContract`, and
`GPUNestedDistribution`.
3. Move `get[A,B,C]SingleSubgroupLayout` to global functions that work on
`MmaInterfaceAttr` (see the sketch below).
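A minimal sketch of the resulting call-site pattern for point 3, assuming the
free functions are declared alongside the GPU attributes (the function and
struct names come from the diff below; the header path and the helper function
here are illustrative only):

```cpp
// Hypothetical helper showing the free-function API after this PR. Both
// MMAAttr and VirtualMMAAttr implement MmaInterfaceAttr, so either kind of
// attribute can be passed without casting to MMAAttr first.
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" // assumed path

namespace GPU = mlir::iree_compiler::IREE::GPU;

static void inspectSubgroupLayouts(GPU::MmaInterfaceAttr kind) {
  // Previously these were methods on MMAAttr; after this PR they are free
  // functions that accept any attribute implementing MmaInterfaceAttr.
  GPU::MMASingleSubgroupLayout aLayout = GPU::getASingleSubgroupLayout(kind);
  GPU::MMASingleSubgroupLayout bLayout = GPU::getBSingleSubgroupLayout(kind);
  GPU::MMASingleSubgroupLayout cLayout = GPU::getCSingleSubgroupLayout(kind);
  (void)aLayout;
  (void)bLayout;
  (void)cLayout; // use as needed, e.g. for layout-compatibility checks
}
```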

---------

Signed-off-by: Stanley Winata <[email protected]>
raikonenfnu authored and Groverkss committed Nov 29, 2024
1 parent 7498572 commit 271c748
Showing 10 changed files with 449 additions and 136 deletions.
@@ -50,23 +50,22 @@ isSubgroupLayoutCompatible(IREE::GPU::MMASingleSubgroupLayout subgroupLayout,
   return success();
 }
 
-static LogicalResult isIntrinsicLayoutCompatible(VectorContractOpInfo &opInfo,
-                                                 IREE::GPU::MMAAttr intrinsic,
-                                                 NestedLayoutAttr lhsLayout,
-                                                 NestedLayoutAttr rhsLayout,
-                                                 NestedLayoutAttr accLayout) {
+static LogicalResult isIntrinsicLayoutCompatible(
+    VectorContractOpInfo &opInfo, IREE::GPU::MmaInterfaceAttr intrinsic,
+    NestedLayoutAttr lhsLayout, NestedLayoutAttr rhsLayout,
+    NestedLayoutAttr accLayout) {
   auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
   auto [lhsK, rhsK] = opInfo.getOperandKIndex();
   auto [accM, accN] = opInfo.getResultMNIndex();
-  if (failed(isSubgroupLayoutCompatible(intrinsic.getASingleSubgroupLayout(),
+  if (failed(isSubgroupLayoutCompatible(getASingleSubgroupLayout(intrinsic),
                                         lhsLayout, lhsM, lhsK))) {
     return failure();
   }
-  if (failed(isSubgroupLayoutCompatible(intrinsic.getBSingleSubgroupLayout(),
+  if (failed(isSubgroupLayoutCompatible(getBSingleSubgroupLayout(intrinsic),
                                         rhsLayout, rhsK, rhsN))) {
     return failure();
   }
-  if (failed(isSubgroupLayoutCompatible(intrinsic.getCSingleSubgroupLayout(),
+  if (failed(isSubgroupLayoutCompatible(getCSingleSubgroupLayout(intrinsic),
                                         accLayout, accM, accN))) {
     return failure();
   }
@@ -124,16 +123,16 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
 
     // We assume there is an decision made before regarding which mfma intrinsic
     // to use and it is attached as an attribute to this contract op.
-    auto mmaAttr =
-        contractOp->getAttrOfType<IREE::GPU::MMAAttr>("iree.amdgpu.mma");
-    if (!mmaAttr) {
+    auto mmaKind = contractOp->getAttrOfType<IREE::GPU::MmaInterfaceAttr>(
+        "iree.amdgpu.mma");
+    if (!mmaKind) {
       return rewriter.notifyMatchFailure(
           contractOp, "missing iree.amdgpu.mma intrinsic attribute");
     }
 
     // Check if the given intrinsic can be distributed with the given
     // layouts.
-    if (failed(isIntrinsicLayoutCompatible(opDetail, mmaAttr, lhsLayout,
+    if (failed(isIntrinsicLayoutCompatible(opDetail, mmaKind, lhsLayout,
                                            rhsLayout, accLayout))) {
       return rewriter.notifyMatchFailure(
           contractOp, "the intrinsic does not match the expected layouts");
@@ -230,7 +229,7 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
         Value rhsSlice =
             rewriter.create<vector::ExtractOp>(loc, rhs, rhsBatchOffsets);
         accSlice =
-            computeMMA(rewriter, loc, mmaAttr, lhsSlice, rhsSlice, accSlice);
+            computeMMA(rewriter, loc, mmaKind, lhsSlice, rhsSlice, accSlice);
       }
       finalTile = rewriter.create<vector::InsertOp>(loc, accSlice, finalTile,
                                                     resultBatchOffsets);
@@ -285,17 +284,18 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
 
   // Generates amdgpu.mfma operation on the given inputs for the given MFMA
   // |intrinsic|.
-  Value computeMMA(OpBuilder &builder, Location loc, IREE::GPU::MMAAttr mmaAttr,
-                   Value a, Value b, Value c) const {
+  Value computeMMA(OpBuilder &builder, Location loc,
+                   IREE::GPU::MmaInterfaceAttr mmaKind, Value a, Value b,
+                   Value c) const {
     // Get the storage vector types that each thread is in charge of.
-    auto [aVectorType, bVectorType, cVectorType] = mmaAttr.getABCVectorTypes();
+    auto [aVectorType, bVectorType, cVectorType] = mmaKind.getABCVectorTypes();
     Value aCast =
         builder.create<vector::ShapeCastOp>(a.getLoc(), aVectorType, a);
     Value bCast =
         builder.create<vector::ShapeCastOp>(b.getLoc(), bVectorType, b);
    Value cCast =
         builder.create<vector::ShapeCastOp>(c.getLoc(), cVectorType, c);
-    FailureOr<Value> mmaOp = mmaAttr.buildMmaOperation(
+    FailureOr<Value> mmaOp = mmaKind.buildMmaOperation(
         builder, loc, cVectorType, aCast, bCast, cCast);
     assert(succeeded(mmaOp) && "Failed to construct mma op");
     return builder.create<vector::ShapeCastOp>(c.getLoc(), c.getType(), *mmaOp);
@@ -593,8 +593,8 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
     }
     // If mmaAttr exists, defer the lowering to use MMA.
     // Notify failure if the "iree.amdgpu.mma" intrinsic attribute is present.
-    auto mmaAttr =
-        contractOp->getAttrOfType<IREE::GPU::MMAAttr>("iree.amdgpu.mma");
+    auto mmaAttr = contractOp->getAttrOfType<IREE::GPU::MmaInterfaceAttr>(
+        "iree.amdgpu.mma");
     if (mmaAttr) {
       return rewriter.notifyMatchFailure(
           contractOp, "iree.amdgpu.mma intrinsic attribute exists");
@@ -645,7 +645,7 @@ func.func @contract_to_vmfma_32x32x16_mm(%a : vector<32x16xf16>, %b : vector<16x
     indexing_maps = [#map1, #map2, #map3],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>,
-    iree.amdgpu.mma = #iree_gpu.mma_layout<VMFMA_F32_32x32x16_F16>
+    iree.amdgpu.mma = #iree_gpu.virtual_mma_layout<intrinsic = VMFMA_F32_32x32x16_F16>
   } %A, %B, %C : vector<32x16xf16>, vector<16x32xf16> into vector<32x32xf32>
 
   %O = iree_vector_ext.to_layout %output to layout(#layout_c) : vector<32x32xf32>