Skip to content

Commit

Permalink
[LLVMGPU] Refactor VirtualMMAIntrinsic to it's own class
Browse files Browse the repository at this point in the history
Signed-off-by: Stanley Winata <[email protected]>
  • Loading branch information
raikonenfnu committed Nov 7, 2024
1 parent 8920c28 commit 9561d57
Show file tree
Hide file tree
Showing 10 changed files with 429 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,10 @@ isSubgroupLayoutCompatible(IREE::GPU::MMASingleSubgroupLayout subgroupLayout,
return success();
}

static LogicalResult isIntrinsicLayoutCompatible(VectorContractOpInfo &opInfo,
IREE::GPU::MMAAttr intrinsic,
NestedLayoutAttr lhsLayout,
NestedLayoutAttr rhsLayout,
NestedLayoutAttr accLayout) {
static LogicalResult isIntrinsicLayoutCompatible(
VectorContractOpInfo &opInfo, IREE::GPU::MmaInterfaceAttr intrinsic,
NestedLayoutAttr lhsLayout, NestedLayoutAttr rhsLayout,
NestedLayoutAttr accLayout) {
auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
auto [lhsK, rhsK] = opInfo.getOperandKIndex();
auto [accM, accN] = opInfo.getResultMNIndex();
Expand Down Expand Up @@ -124,16 +123,16 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {

// We assume there is an decision made before regarding which mfma intrinsic
// to use and it is attached as an attribute to this contract op.
auto mmaAttr =
contractOp->getAttrOfType<IREE::GPU::MMAAttr>("iree.amdgpu.mma");
if (!mmaAttr) {
auto mmaKind = contractOp->getAttrOfType<IREE::GPU::MmaInterfaceAttr>(
"iree.amdgpu.mma");
if (!mmaKind) {
return rewriter.notifyMatchFailure(
contractOp, "missing iree.amdgpu.mma intrinsic attribute");
}

// Check if the given intrinsic can be distributed with the given
// layouts.
if (failed(isIntrinsicLayoutCompatible(opDetail, mmaAttr, lhsLayout,
if (failed(isIntrinsicLayoutCompatible(opDetail, mmaKind, lhsLayout,
rhsLayout, accLayout))) {
return rewriter.notifyMatchFailure(
contractOp, "the intrinsic does not match the expected layouts");
Expand Down Expand Up @@ -230,7 +229,7 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
Value rhsSlice =
rewriter.create<vector::ExtractOp>(loc, rhs, rhsBatchOffsets);
accSlice =
computeMMA(rewriter, loc, mmaAttr, lhsSlice, rhsSlice, accSlice);
computeMMA(rewriter, loc, mmaKind, lhsSlice, rhsSlice, accSlice);
}
finalTile = rewriter.create<vector::InsertOp>(loc, accSlice, finalTile,
resultBatchOffsets);
Expand Down Expand Up @@ -285,17 +284,18 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {

// Generates amdgpu.mfma operation on the given inputs for the given MFMA
// |intrinsic|.
Value computeMMA(OpBuilder &builder, Location loc, IREE::GPU::MMAAttr mmaAttr,
Value a, Value b, Value c) const {
Value computeMMA(OpBuilder &builder, Location loc,
IREE::GPU::MmaInterfaceAttr mmaKind, Value a, Value b,
Value c) const {
// Get the storage vector types that each thread is in charge of.
auto [aVectorType, bVectorType, cVectorType] = mmaAttr.getABCVectorTypes();
auto [aVectorType, bVectorType, cVectorType] = mmaKind.getABCVectorTypes();
Value aCast =
builder.create<vector::ShapeCastOp>(a.getLoc(), aVectorType, a);
Value bCast =
builder.create<vector::ShapeCastOp>(b.getLoc(), bVectorType, b);
Value cCast =
builder.create<vector::ShapeCastOp>(c.getLoc(), cVectorType, c);
FailureOr<Value> mmaOp = mmaAttr.buildMmaOperation(
FailureOr<Value> mmaOp = mmaKind.buildMmaOperation(
builder, loc, cVectorType, aCast, bCast, cCast);
assert(succeeded(mmaOp) && "Failed to construct mma op");
return builder.create<vector::ShapeCastOp>(c.getLoc(), c.getType(), *mmaOp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ func.func @contract_to_vmfma_32x32x16_mm(%a : vector<32x16xf16>, %b : vector<16x
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
iree.amdgpu.mma = #iree_gpu.mma_layout<VMFMA_F32_32x32x16_F16>
iree.amdgpu.mma = #iree_gpu.virtual_mma_layout<intrinsic = VMFMA_F32_32x32x16_F16>
} %A, %B, %C : vector<32x16xf16>, vector<16x32xf16> into vector<32x32xf32>

%O = iree_vector_ext.to_layout %output to layout(#layout_c) : vector<32x32xf32>
Expand Down
Loading

0 comments on commit 9561d57

Please sign in to comment.