From 68340c2441cd11d15c2032c3acacb95949a37935 Mon Sep 17 00:00:00 2001 From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com> Date: Thu, 7 Nov 2024 10:29:54 -0800 Subject: [PATCH] [LLVMGPU] Refactor VirtualMMAIntrinsic to it's own attribute and enum. (#19055) VirtualMMA are variants of MMAOps that has layouts modified through interleaved of reads and/or unrolled-K. These are especially useful for coalescing reads from shared memory to register and for FA to align layouts between the 1st and 2nd matmul. However, the current implementation of virtual MMA lives in MMAAttr and MMAIntrinsic which is a violation of abstractions. Especially since data_tiled_mma_layout is dependent on it. Additionally, having VirtualMMA in MMAAttr also limits the flexibility of adding more stuff and specification to it in the future. In order to resolve this conflict of abstractions, we are migrating VirtualMMAAttr to it's own enums and attr/class. Things that we have changed in this PR: 1. Factor our VMFMA from `MMAAttr` and `MMAIntrinsic` into it's own attr and enums `VirtualMMAAttr ` and `VirtualMMAIntrinsic`. 2. Update places that may use `MMAAttr` or `VirtualMMAAttr` to simply use `MmaInterfaceAttr` such as `KernelConfig`, `LLVMGPUConfigureTensorLayout`, and `AMDGPUDistributeContract`, `GPUNestedDistribution` 3. Move `get[A,B,C]SingleSubgroupLayout` as global method that works on `MmaInterfaceAttr`. --------- Signed-off-by: Stanley Winata Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com> --- .../Common/GPU/AMDGPUDistributeContract.cpp | 34 +- .../GPUNestedLayoutDistributionPatterns.cpp | 4 +- .../gpu_nested_layout_contract_amdgpu.mlir | 2 +- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 380 ++++++++++++++---- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.h | 17 +- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.td | 63 ++- .../Codegen/Dialect/GPU/IR/IREEGPUEnums.td | 10 +- .../compiler/Codegen/LLVMGPU/KernelConfig.cpp | 59 +-- .../LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp | 12 +- .../pipeline_vector_distribute_gfx940.mlir | 4 +- 10 files changed, 449 insertions(+), 136 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp index 3e2cb427ed989..8f5752ae2a1a1 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp @@ -50,23 +50,22 @@ isSubgroupLayoutCompatible(IREE::GPU::MMASingleSubgroupLayout subgroupLayout, return success(); } -static LogicalResult isIntrinsicLayoutCompatible(VectorContractOpInfo &opInfo, - IREE::GPU::MMAAttr intrinsic, - NestedLayoutAttr lhsLayout, - NestedLayoutAttr rhsLayout, - NestedLayoutAttr accLayout) { +static LogicalResult isIntrinsicLayoutCompatible( + VectorContractOpInfo &opInfo, IREE::GPU::MmaInterfaceAttr intrinsic, + NestedLayoutAttr lhsLayout, NestedLayoutAttr rhsLayout, + NestedLayoutAttr accLayout) { auto [lhsM, rhsN] = opInfo.getOperandMNIndex(); auto [lhsK, rhsK] = opInfo.getOperandKIndex(); auto [accM, accN] = opInfo.getResultMNIndex(); - if (failed(isSubgroupLayoutCompatible(intrinsic.getASingleSubgroupLayout(), + if (failed(isSubgroupLayoutCompatible(getASingleSubgroupLayout(intrinsic), lhsLayout, lhsM, lhsK))) { return failure(); } - if (failed(isSubgroupLayoutCompatible(intrinsic.getBSingleSubgroupLayout(), + if (failed(isSubgroupLayoutCompatible(getBSingleSubgroupLayout(intrinsic), rhsLayout, rhsK, rhsN))) { return failure(); } - if (failed(isSubgroupLayoutCompatible(intrinsic.getCSingleSubgroupLayout(), + if (failed(isSubgroupLayoutCompatible(getCSingleSubgroupLayout(intrinsic), accLayout, accM, accN))) { return failure(); } @@ -124,16 +123,16 @@ struct DistributeContract final : OpDistributionPattern { // We assume there is an decision made before regarding which mfma intrinsic // to use and it is attached as an attribute to this contract op. - auto mmaAttr = - contractOp->getAttrOfType("iree.amdgpu.mma"); - if (!mmaAttr) { + auto mmaKind = contractOp->getAttrOfType( + "iree.amdgpu.mma"); + if (!mmaKind) { return rewriter.notifyMatchFailure( contractOp, "missing iree.amdgpu.mma intrinsic attribute"); } // Check if the given intrinsic can be distributed with the given // layouts. - if (failed(isIntrinsicLayoutCompatible(opDetail, mmaAttr, lhsLayout, + if (failed(isIntrinsicLayoutCompatible(opDetail, mmaKind, lhsLayout, rhsLayout, accLayout))) { return rewriter.notifyMatchFailure( contractOp, "the intrinsic does not match the expected layouts"); @@ -230,7 +229,7 @@ struct DistributeContract final : OpDistributionPattern { Value rhsSlice = rewriter.create(loc, rhs, rhsBatchOffsets); accSlice = - computeMMA(rewriter, loc, mmaAttr, lhsSlice, rhsSlice, accSlice); + computeMMA(rewriter, loc, mmaKind, lhsSlice, rhsSlice, accSlice); } finalTile = rewriter.create(loc, accSlice, finalTile, resultBatchOffsets); @@ -285,17 +284,18 @@ struct DistributeContract final : OpDistributionPattern { // Generates amdgpu.mfma operation on the given inputs for the given MFMA // |intrinsic|. - Value computeMMA(OpBuilder &builder, Location loc, IREE::GPU::MMAAttr mmaAttr, - Value a, Value b, Value c) const { + Value computeMMA(OpBuilder &builder, Location loc, + IREE::GPU::MmaInterfaceAttr mmaKind, Value a, Value b, + Value c) const { // Get the storage vector types that each thread is in charge of. - auto [aVectorType, bVectorType, cVectorType] = mmaAttr.getABCVectorTypes(); + auto [aVectorType, bVectorType, cVectorType] = mmaKind.getABCVectorTypes(); Value aCast = builder.create(a.getLoc(), aVectorType, a); Value bCast = builder.create(b.getLoc(), bVectorType, b); Value cCast = builder.create(c.getLoc(), cVectorType, c); - FailureOr mmaOp = mmaAttr.buildMmaOperation( + FailureOr mmaOp = mmaKind.buildMmaOperation( builder, loc, cVectorType, aCast, bCast, cCast); assert(succeeded(mmaOp) && "Failed to construct mma op"); return builder.create(c.getLoc(), c.getType(), *mmaOp); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp index c97d8913b3cb6..4675d1dab9ce8 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp @@ -593,8 +593,8 @@ struct DistributeContract final : OpDistributionPattern { } // If mmaAttr exists, defer the lowering to use MMA. // Notify failure if the "iree.amdgpu.mma" intrinsic attribute is present. - auto mmaAttr = - contractOp->getAttrOfType("iree.amdgpu.mma"); + auto mmaAttr = contractOp->getAttrOfType( + "iree.amdgpu.mma"); if (mmaAttr) { return rewriter.notifyMatchFailure( contractOp, "iree.amdgpu.mma intrinsic attribute exists"); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir index db39c0b15742d..eecd2f0653ce7 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir @@ -645,7 +645,7 @@ func.func @contract_to_vmfma_32x32x16_mm(%a : vector<32x16xf16>, %b : vector<16x indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind, - iree.amdgpu.mma = #iree_gpu.mma_layout + iree.amdgpu.mma = #iree_gpu.virtual_mma_layout } %A, %B, %C : vector<32x16xf16>, vector<16x32xf16> into vector<32x32xf32> %O = iree_vector_ext.to_layout %output to layout(#layout_c) : vector<32x32xf32> diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 0be2db3428adc..a09ae277819f9 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -236,8 +236,7 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context, case MMAIntrinsic::MFMA_F32_32x32x8_BF16: { return OpaqueMmaLayout{32, 32, 8, bf16, bf16, f32}; } - case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: { + case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: { return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32}; } case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: { @@ -264,14 +263,6 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context, case MMAIntrinsic::WMMA_I32_16x16x16_I8: { return OpaqueMmaLayout{16, 16, 16, i8, i8, i32}; } - // V(Virtual)MFMA instructions which have 2 mfma instructions interleaved - // along the k dimension. - case MMAIntrinsic::VMFMA_F32_16x16x32_F16: { - return OpaqueMmaLayout{16, 16, 32, f16, f16, f32}; - } - case MMAIntrinsic::VMFMA_F32_32x32x16_F16: { - return OpaqueMmaLayout{32, 32, 16, f16, f16, f32}; - } } llvm_unreachable("unhandled mfma layout type"); return OpaqueMmaLayout{}; @@ -345,6 +336,43 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context, return concreteLayout; } +//===----------------------------------------------------------------------===// +// MmaInterface Attribute Helper Functions +//===----------------------------------------------------------------------===// + +MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind) { + if (auto mmaAttr = dyn_cast(mmaKind)) { + return mmaAttr.getASingleSubgroupLayout(); + } else if (auto vmmaAttr = dyn_cast(mmaKind)) { + return vmmaAttr.getASingleSubgroupLayout(); + } else { + assert(false && "unhandled MMA Interface type."); + return {}; + } +} + +MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind) { + if (auto mmaAttr = dyn_cast(mmaKind)) { + return mmaAttr.getBSingleSubgroupLayout(); + } else if (auto vmmaAttr = dyn_cast(mmaKind)) { + return vmmaAttr.getBSingleSubgroupLayout(); + } else { + assert(false && "unhandled MMA Interface type."); + return {}; + } +} + +MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind) { + if (auto mmaAttr = dyn_cast(mmaKind)) { + return mmaAttr.getCSingleSubgroupLayout(); + } else if (auto vmmaAttr = dyn_cast(mmaKind)) { + return vmmaAttr.getCSingleSubgroupLayout(); + } else { + assert(false && "unhandled MMA Interface type."); + return {}; + } +} + //===----------------------------------------------------------------------===// // MFMA Attributes //===----------------------------------------------------------------------===// @@ -421,15 +449,12 @@ MMAAttr::getABCVectorTypes() const { } case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F16: case MMAIntrinsic::MFMA_I32_16x16x32_I8: { auto aType = VectorType::get({8}, getAType()); auto bType = VectorType::get({8}, getBType()); auto cType = VectorType::get({4}, getCType()); return std::make_tuple(aType, bType, cType); } - case MMAIntrinsic::VMFMA_F32_32x32x16_F16: case MMAIntrinsic::MFMA_I32_32x32x16_I8: { auto aType = VectorType::get({8}, getAType()); auto bType = VectorType::get({8}, getBType()); @@ -473,10 +498,7 @@ int64_t MMAAttr::getBlockSize() const { case MMAIntrinsic::MFMA_I32_32x32x8_I8: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F16: case MMAIntrinsic::MFMA_I32_16x16x32_I8: - case MMAIntrinsic::VMFMA_F32_32x32x16_F16: case MMAIntrinsic::MFMA_I32_32x32x16_I8: case MMAIntrinsic::WMMA_F16_16x16x16_F16: case MMAIntrinsic::WMMA_F32_16x16x16_F16: @@ -499,10 +521,7 @@ static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) { case MMAIntrinsic::MFMA_I32_32x32x8_I8: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F16: case MMAIntrinsic::MFMA_I32_16x16x32_I8: - case MMAIntrinsic::VMFMA_F32_32x32x16_F16: case MMAIntrinsic::MFMA_I32_32x32x16_I8: { return 64; } @@ -567,7 +586,6 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1}, /*element=*/{4, 1}}; } - case MMAIntrinsic::VMFMA_F32_16x16x32_F16: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: @@ -582,19 +600,6 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, /*element=*/{4, 1}}; } - case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: - switch (fragment) { - case MMAFragment::Lhs: - return {/*outer=*/{1, 2}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16}, - /*element=*/{1, 4}}; - case MMAFragment::Rhs: - return {/*outer=*/{2, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, - /*element=*/{4, 1}}; - case MMAFragment::Acc: - return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, - /*element=*/{4, 1}}; - } - case MMAIntrinsic::VMFMA_F32_32x32x16_F16: case MMAIntrinsic::MFMA_I32_32x32x16_I8: switch (fragment) { case MMAFragment::Lhs: @@ -649,14 +654,14 @@ MMASingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const { } // Get virtual intrinsics that is composed/based on queried op. -SmallVector MMAAttr::getVirtualIntrinsics() const { +SmallVector MMAAttr::getVirtualIntrinsics() const { switch (getIntrinsic().getValue()) { case MMAIntrinsic::MFMA_F32_16x16x16_F16: - return {MMAIntrinsic::VMFMA_F32_16x16x32_F16}; + return {VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16}; case MMAIntrinsic::MFMA_F32_32x32x8_F16: - return {MMAIntrinsic::VMFMA_F32_32x32x16_F16}; + return {VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16}; case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: - return {MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ}; + return {VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ}; default: return {}; } @@ -690,37 +695,6 @@ FailureOr MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc, rhs, acc) .getResult(); } - case MMAIntrinsic::VMFMA_F32_16x16x32_F16: - case MMAIntrinsic::VMFMA_F32_32x32x16_F16: { - // Generate mfma's for K with unrolled kernels. - const int64_t unrollKFactor = 2; - auto [m, n, k] = getMNKShape(); - // Compute actual/native intrinsic's K size. - int64_t nativeKSize = k / unrollKFactor; - - auto [aType, bType, cType] = getABCVectorTypes(); - if (aType.getShape()[0] != bType.getShape()[0]) { - // Currently only support case where lhs and rhs - // has same vectorWidth. - return failure(); - } - int64_t vectorWidth = aType.getShape()[0] / unrollKFactor; - for (int i = 0; i < unrollKFactor; i++) { - int64_t offset = vectorWidth * i; - Value sliced_lhs = builder.create( - loc, lhs, ArrayRef{offset}, ArrayRef{vectorWidth}, - ArrayRef{1}); - Value sliced_rhs = builder.create( - loc, rhs, ArrayRef{offset}, ArrayRef{vectorWidth}, - ArrayRef{1}); - acc = builder - .create(loc, resultType, m, n, nativeKSize, - getBlockSize(), sliced_lhs, sliced_rhs, - acc) - .getResult(); - } - return acc; - } case MMAIntrinsic::MFMA_I32_16x16x16_I8: case MMAIntrinsic::MFMA_F32_16x16x16_F16: case MMAIntrinsic::MFMA_F32_16x16x16_BF16: @@ -729,7 +703,6 @@ FailureOr MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc, case MMAIntrinsic::MFMA_F32_32x32x8_BF16: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: case MMAIntrinsic::MFMA_I32_32x32x16_I8: { auto [m, n, k] = getMNKShape(); @@ -1179,6 +1152,267 @@ FailureOr DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder, return acc; } +//===----------------------------------------------------------------------===// +// VirtualMMA Attributes +//===----------------------------------------------------------------------===// + +VirtualMMAAttr VirtualMMAAttr::get(MLIRContext *context, + VirtualMMAIntrinsic type) { + auto intrinsicAttr = VirtualMMAIntrinsicAttr::get(context, type); + return VirtualMMAAttr::get(context, intrinsicAttr); +} + +static OpaqueMmaLayout getOpaqueVMMALayout(MLIRContext *context, + VirtualMMAIntrinsic type) { + Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context); + Type f16 = Float16Type::get(context); + Type f32 = Float32Type::get(context); + + switch (type) { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: { + return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32}; + } + // V(Virtual)MFMA instructions which have 2 mfma instructions interleaved + // along the k dimension. + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: { + return OpaqueMmaLayout{16, 16, 32, f16, f16, f32}; + } + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { + return OpaqueMmaLayout{32, 32, 16, f16, f16, f32}; + } + } + assert(false && "unhandled virtual mma layout type."); + return OpaqueMmaLayout{}; +} + +std::tuple VirtualMMAAttr::getABCElementTypes() const { + MLIRContext *ctx = getContext(); + auto opaqueLayout = getOpaqueVMMALayout(ctx, getIntrinsic().getValue()); + return {opaqueLayout.aType, opaqueLayout.bType, opaqueLayout.cType}; +} + +std::tuple +VirtualMMAAttr::getABCVectorTypes() const { + // Check https://github.com/ROCm/amd_matrix_instruction_calculator for + // instruction details. Note here we are returning the number elements, while + // amd_matrix_instruction_calculator tells us about the number of 32-bit + // registers. So need to adjust accordingly. All vectors should be 1-D. + auto [A, B, C] = getABCElementTypes(); + switch (getIntrinsic().getValue()) { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: { + auto aType = VectorType::get({8}, A); + auto bType = VectorType::get({8}, B); + auto cType = VectorType::get({4}, C); + return {aType, bType, cType}; + } + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { + auto aType = VectorType::get({8}, A); + auto bType = VectorType::get({8}, B); + auto cType = VectorType::get({16}, C); + return {aType, bType, cType}; + } + } + // This should not happen but just to make GCC happy. + assert(false && "unhandled virtual mma layout type."); + return {VectorType{}, VectorType{}, VectorType{}}; +} + +std::tuple VirtualMMAAttr::getMNKShape() const { + MLIRContext *ctx = getContext(); + auto opaqueLayout = getOpaqueVMMALayout(ctx, getIntrinsic().getValue()); + return {opaqueLayout.mSize, opaqueLayout.nSize, opaqueLayout.kSize}; +} + +int64_t VirtualMMAAttr::getSubgroupSize() const { + switch (getIntrinsic().getValue()) { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { + return 64; + } + } + // This should not happen but just to make GCC happy. + assert(false && "unhandled virtual mma layout type."); + return 0; +} + +FailureOr VirtualMMAAttr::getMmaScope() const { + return IREE::GPU::MMAScope::Subgroup; +} + +LogicalResult VirtualMMAAttr::populateOperandOffsetsSizesStrides( + OpBuilder &builder, Location loc, IREE::GPU::MMAFragment fragment, + Value laneId, ArrayRef permutation, + SmallVector &offsets, SmallVector &sizes, + SmallVector &strides) const { + + MMASingleSubgroupLayout subgroupLayout; + switch (fragment) { + case IREE::GPU::MMAFragment::Lhs: { + subgroupLayout = getASingleSubgroupLayout(); + break; + } + case IREE::GPU::MMAFragment::Rhs: { + subgroupLayout = getBSingleSubgroupLayout(); + break; + } + case IREE::GPU::MMAFragment::Acc: { + subgroupLayout = getCSingleSubgroupLayout(); + break; + } + } + + SmallVector canonicalOffsets; + SmallVector canonicalSizes; + if (failed(populateCanonicalOffsetsSizesAndStrides( + builder, loc, laneId, permutation, subgroupLayout, canonicalOffsets, + canonicalSizes, strides))) { + return failure(); + } + offsets.append(canonicalOffsets); + sizes.append(canonicalSizes); + + return success(); +} + +int64_t VirtualMMAAttr::getUnrollK() const { + switch (getIntrinsic().getValue()) { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { + return 2; + } + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: { + return 1; + } + } + // This should not happen but just to make GCC happy. + assert(false && "unhandled virtual mma layout type."); + return 0; +} + +// Generates amdgpu.mfma/wmma operation on the given inputs for this attribute +// type. +FailureOr VirtualMMAAttr::buildMmaOperation(OpBuilder &builder, + Location loc, + Type resultType, Value lhs, + Value rhs, Value acc) const { + auto [aType, bType, cType] = getABCVectorTypes(); + if (aType != lhs.getType() || bType != rhs.getType() || + cType != acc.getType()) { + return failure(); + } + // Fail if the result type does not match with the expected return type of + // the intrinsic. We expect the caller to handle type conversions externally. + if (cType != resultType) { + return failure(); + } + switch (getIntrinsic().getValue()) { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { + // Generate mfma's for K with unrolled kernels. + const int64_t unrollKFactor = getUnrollK(); + auto [m, n, k] = getMNKShape(); + // Compute actual/native intrinsic's K size. + int64_t nativeKSize = k / unrollKFactor; + + auto [aType, bType, cType] = getABCVectorTypes(); + if (aType.getShape()[0] != bType.getShape()[0]) { + // Currently only support case where lhs and rhs + // has same vectorWidth. + return failure(); + } + int64_t vectorWidth = aType.getShape()[0] / unrollKFactor; + for (int i = 0; i < unrollKFactor; i++) { + int64_t offset = vectorWidth * i; + Value sliced_lhs = builder.create( + loc, lhs, ArrayRef{offset}, ArrayRef{vectorWidth}, + ArrayRef{1}); + Value sliced_rhs = builder.create( + loc, rhs, ArrayRef{offset}, ArrayRef{vectorWidth}, + ArrayRef{1}); + acc = builder + .create(loc, resultType, m, n, nativeKSize, + getBlockSize(), sliced_lhs, sliced_rhs, + acc) + .getResult(); + } + return acc; + } + } + return failure(); +} + +int64_t VirtualMMAAttr::getBlockSize() const { + switch (getIntrinsic().getValue()) { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { + return 1; + } + } + // This should not happen but just to make GCC happy. + assert(false && "unhandled virtual mma layout type."); + return 0; +} + +MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic, + MMAFragment fragment) { + switch (intrinsic) { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: + switch (fragment) { + case MMAFragment::Lhs: + return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16}, + /*element=*/{1, 8}}; + case MMAFragment::Rhs: + return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, + /*element=*/{8, 1}}; + case MMAFragment::Acc: + return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, + /*element=*/{4, 1}}; + } + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: + switch (fragment) { + case MMAFragment::Lhs: + return {/*outer=*/{1, 2}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16}, + /*element=*/{1, 4}}; + case MMAFragment::Rhs: + return {/*outer=*/{2, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, + /*element=*/{4, 1}}; + case MMAFragment::Acc: + return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, + /*element=*/{4, 1}}; + } + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: + switch (fragment) { + case MMAFragment::Lhs: + return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*tstrides=*/{1, 32}, + /*element=*/{1, 8}}; + case MMAFragment::Rhs: + return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1}, + /*element=*/{8, 1}}; + case MMAFragment::Acc: + return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1}, + /*element=*/{4, 1}}; + } + } + assert(false && "unhandled virtual mma layout type."); + return {}; +} + +MMASingleSubgroupLayout VirtualMMAAttr::getASingleSubgroupLayout() const { + return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Lhs); +} + +MMASingleSubgroupLayout VirtualMMAAttr::getBSingleSubgroupLayout() const { + return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Rhs); +} + +MMASingleSubgroupLayout VirtualMMAAttr::getCSingleSubgroupLayout() const { + return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Acc); +} + //===----------------------------------------------------------------------===// // MMA Schedule Attributes //===----------------------------------------------------------------------===// @@ -1313,7 +1547,7 @@ MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo, }); int64_t rank = contractOp.getIteratorTypesArray().size(); - auto mmaAttr = llvm::cast(getIntrinsic()); + auto mmaAttr = llvm::cast(getIntrinsic()); MLIRContext *context = getContext(); SmallVector bounds = contractOp.getStaticLoopRanges(); @@ -1460,7 +1694,7 @@ MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo, /*subgroupCount=*/cSubgroupSizes, /*subgroupStrides=*/cSubgroupStrides, /*batchCount=*/cBatchSizes, - mmaAttr.getCSingleSubgroupLayout()); + getCSingleSubgroupLayout(mmaAttr)); LLVM_DEBUG({ llvm::errs() << "C layout: " << cLayout << "\n"; }); // A matrix layout @@ -1488,7 +1722,7 @@ MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo, /*subgroupCount=*/aSubgroupSizes, /*subgroupStrides=*/aSubgroupStrides, /*batchCount=*/aBatchSizes, - mmaAttr.getASingleSubgroupLayout()); + getASingleSubgroupLayout(mmaAttr)); LLVM_DEBUG({ llvm::errs() << "A layout: " << aLayout << "\n"; }); int64_t bRank = opInfo.getBRank(); @@ -1512,7 +1746,7 @@ MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo, /*subgroupCount=*/bSubgroupSizes, /*subgroupStrides=*/bSubgroupStrides, /*batchCount=*/bBatchSizes, - mmaAttr.getBSingleSubgroupLayout()); + getBSingleSubgroupLayout(mmaAttr)); LLVM_DEBUG({ llvm::errs() << "B layout: " << bLayout << "\n"; }); std::tuple outer; - llvm::SmallVector thread; - llvm::SmallVector tstrides; - llvm::SmallVector element; + SmallVector outer; + SmallVector thread; + SmallVector tstrides; + SmallVector element; }; MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, MMAFragment fragment); +MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic, + MMAFragment fragment); + +MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind); + +MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind); + +MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind); + } // namespace mlir::iree_compiler::IREE::GPU // clang-format off diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td index ee4cb932ded47..d7686562f7cda 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td @@ -152,6 +152,9 @@ class IREEGPU_MmaVectorLayoutAttr : "getMNKShape", "getSubgroupSize", "getMmaScope", + "getASingleSubgroupLayout", + "getBSingleSubgroupLayout", + "getCSingleSubgroupLayout", "buildMmaOperation", "populateOperandOffsetsSizesStrides", "materializeOperandConcreteShape", @@ -232,7 +235,7 @@ def IREEGPU_MMAAttr : IREEGPU_MmaVectorLayoutAttr<"MMA", "MMAIntrinsicAttr"> { MMASingleSubgroupLayout getBSingleSubgroupLayout() const; MMASingleSubgroupLayout getCSingleSubgroupLayout() const; - SmallVector getVirtualIntrinsics() const; + SmallVector getVirtualIntrinsics() const; }]; } @@ -283,6 +286,64 @@ def IREEGPU_DataTiledMMAAttr : ); } +def IREEGPU_VirtualMMAIntrinsicAttr + : IREEGPU_MmaEnumAttr; + +def IREEGPU_VirtualMMAAttr : + AttrDef +]> { + let mnemonic = "virtual_mma_layout"; + let cppNamespace = "::mlir::iree_compiler::IREE::GPU"; + + let description = [{ + This mma variant represents "virtual" MMA ops that has modification to + its native layouts by unrollK and/or interleave reads. The |intrinsic| + field represents different kinds of "Virtual" MMA Ops we found helpful. + + These interleaving and/or unrolling changes in the layout is especially + useful to coalesce reads from shared memory to register or align layouts + in a chained-matmul operation. + }]; + + let assemblyFormat = "`<` struct(params) `>`"; + let builders = [ + AttrBuilder<(ins "VirtualMMAIntrinsic":$intrinsic)> + ]; + + let parameters = (ins + "::mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsicAttr":$intrinsic + ); + + let extraClassDeclaration = [{ + int64_t getBlockSize() const; + + // Returns the A/B/C matrix's partial nested layout shape inside a single + // subgroup. Shape at each outer/thread/element level is a 2-D value, + // following canonical matmul order--(M, K) for A, (K, N) for B, and + // (M, N) for C. + MMASingleSubgroupLayout getASingleSubgroupLayout() const; + MMASingleSubgroupLayout getBSingleSubgroupLayout() const; + MMASingleSubgroupLayout getCSingleSubgroupLayout() const; + + // Factor to unroll K from native MMA/intrinsic size to virtual size. + // e.g MFMA_F32_16x16x16 has K of 16, while VMFMA_F32_16x16x32 has K of 32 + // in this example, unrollK = 32/16 = 2. + int64_t getUnrollK() const; + }]; +} + def IREEGPU_MMAOpsArrayAttr : ArrayOfAttr< IREEGPU_Dialect, "MMAOpsArray", "mma_ops", "MMAAttr"> { let cppNamespace = "::mlir::iree_compiler::IREE::GPU"; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td index 49f210e7f06dd..22a510db049cf 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td @@ -156,13 +156,10 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic", MFMA_F32_16x16x4_F32, MFMA_F32_16x16x16_F16, MFMA_F32_32x32x8_F16, - VMFMA_F32_16x16x32_F16, - VMFMA_F32_32x32x16_F16, MFMA_F32_16x16x16_BF16, MFMA_F32_32x32x8_BF16, MFMA_F32_16x16x32_F8E4M3FNUZ, MFMA_F32_16x16x32_F8E5M2FNUZ, - VMFMA_F32_16x16x32_F8E4M3FNUZ, MFMA_I32_16x16x32_I8, MFMA_I32_32x32x16_I8, MFMA_I32_16x16x16_I8, @@ -172,6 +169,13 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic", WMMA_I32_16x16x16_I8 ]>; +def IREEGPU_VirtualMMAIntrinsic : IREEGPU_I32MmaEnumAttr<"VirtualMMAIntrinsic", + "Descriptor for different Virtual MMA intrinsics", [ + VMFMA_F32_16x16x32_F16, + VMFMA_F32_32x32x16_F16, + VMFMA_F32_16x16x32_F8E4M3FNUZ, + ]>; + def MMA_LHS : I32EnumAttrCase<"Lhs", 0>; def MMA_RHS : I32EnumAttrCase<"Rhs", 1>; def MMA_ACC : I32EnumAttrCase<"Acc", 2>; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index 96a7a9d8ddb71..f269e2be0a9be 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -323,28 +323,29 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, lhsElemType, rhsElemType, initElemType}; // Helper fn to store mma information. - auto storeMmaInfo = [](IREE::GPU::MMAAttr mma, + auto storeMmaInfo = [](IREE::GPU::MmaInterfaceAttr mma, SmallVector &intrinsics, - SmallVector &mmaAttrs) { + SmallVector &mmaKinds) { auto [mSize, nSize, kSize] = mma.getMNKShape(); auto [aType, bType, cType] = mma.getABCElementTypes(); intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType); - mmaAttrs.emplace_back(mma); + mmaKinds.emplace_back(mma); }; SmallVector intrinsics; intrinsics.reserve(target.getWgp().getMma().size()); - SmallVector mmaAttrs; + SmallVector mmaKinds; MLIRContext *context = op.getContext(); for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { if (mma.getSubgroupSize() != targetSubgroupSize) continue; - storeMmaInfo(mma, intrinsics, mmaAttrs); + storeMmaInfo(mma, intrinsics, mmaKinds); // Store info on virtual intrinsics based on current mma if any - for (IREE::GPU::MMAIntrinsic virtualIntrinsic : + for (IREE::GPU::VirtualMMAIntrinsic virtualIntrinsic : mma.getVirtualIntrinsics()) { - auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic); - storeMmaInfo(virtualMma, intrinsics, mmaAttrs); + auto virtualMma = + IREE::GPU::VirtualMMAAttr::get(context, virtualIntrinsic); + storeMmaInfo(virtualMma, intrinsics, mmaKinds); } } @@ -417,7 +418,7 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, b.getI64ArrayAttr(reductionTileSizes)); IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1}); IREE::GPU::LoweringConfigAttr::setMmaKind(context, attrs, - mmaAttrs[schedule->index]); + mmaKinds[schedule->index]); IREE::GPU::LoweringConfigAttr::setSubgroupMCount( context, attrs, schedule->mSubgroupCounts[0]); IREE::GPU::LoweringConfigAttr::setSubgroupNCount( @@ -536,28 +537,29 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, lhsElemType, rhsElemType, initElemType}; // Helper fn to store mma information. - auto storeMmaInfo = [](IREE::GPU::MMAAttr mma, + auto storeMmaInfo = [](IREE::GPU::MmaInterfaceAttr mma, SmallVector &intrinsics, - SmallVector &mmaAttrs) { + SmallVector &mmaKinds) { auto [mSize, nSize, kSize] = mma.getMNKShape(); auto [aType, bType, cType] = mma.getABCElementTypes(); intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType); - mmaAttrs.emplace_back(mma); + mmaKinds.emplace_back(mma); }; SmallVector intrinsics; intrinsics.reserve(target.getWgp().getMma().size()); - SmallVector mmaAttrs; + SmallVector mmaKinds; MLIRContext *context = op.getContext(); for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { if (mma.getSubgroupSize() != targetSubgroupSize) continue; - storeMmaInfo(mma, intrinsics, mmaAttrs); + storeMmaInfo(mma, intrinsics, mmaKinds); // Store info on virtual intrinsics based on current mma if any - for (IREE::GPU::MMAIntrinsic virtualIntrinsic : + for (IREE::GPU::VirtualMMAIntrinsic virtualIntrinsic : mma.getVirtualIntrinsics()) { - auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic); - storeMmaInfo(virtualMma, intrinsics, mmaAttrs); + auto virtualMma = + IREE::GPU::VirtualMMAAttr::get(context, virtualIntrinsic); + storeMmaInfo(virtualMma, intrinsics, mmaKinds); } } @@ -682,7 +684,7 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, b.getI64ArrayAttr(reductionTileSizes)); IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1}); IREE::GPU::LoweringConfigAttr::setMmaKind(context, attrs, - mmaAttrs[schedule->index]); + mmaKinds[schedule->index]); IREE::GPU::LoweringConfigAttr::setSubgroupMCount( context, attrs, schedule->mSubgroupCounts[0]); IREE::GPU::LoweringConfigAttr::setSubgroupNCount( @@ -759,28 +761,29 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, Value vMatrix = op.getValue(); // Helper fn to store mma information. - auto storeMmaInfo = [](IREE::GPU::MMAAttr mma, + auto storeMmaInfo = [](IREE::GPU::MmaInterfaceAttr mma, SmallVector &intrinsics, - SmallVector &mmaAttrs) { + SmallVector &mmaKinds) { auto [mSize, nSize, kSize] = mma.getMNKShape(); auto [aType, bType, cType] = mma.getABCElementTypes(); intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType); - mmaAttrs.emplace_back(mma); + mmaKinds.emplace_back(mma); }; SmallVector intrinsics; intrinsics.reserve(target.getWgp().getMma().size()); - SmallVector mmaAttrs; + SmallVector mmaKinds; MLIRContext *context = op.getContext(); for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { if (mma.getSubgroupSize() != targetSubgroupSize) continue; - storeMmaInfo(mma, intrinsics, mmaAttrs); + storeMmaInfo(mma, intrinsics, mmaKinds); // Store info on virtual intrinsics based on current mma if any - for (IREE::GPU::MMAIntrinsic virtualIntrinsic : + for (IREE::GPU::VirtualMMAIntrinsic virtualIntrinsic : mma.getVirtualIntrinsics()) { - auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic); - storeMmaInfo(virtualMma, intrinsics, mmaAttrs); + auto virtualMma = + IREE::GPU::VirtualMMAAttr::get(context, virtualIntrinsic); + storeMmaInfo(virtualMma, intrinsics, mmaKinds); } } @@ -916,7 +919,7 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, qkConfig, {0, 1}); IREE::GPU::LoweringConfigAttr::setMmaKind(context, qkConfig, - mmaAttrs[schedule->index]); + mmaKinds[schedule->index]); IREE::GPU::LoweringConfigAttr::setSubgroupMCount( context, qkConfig, schedule->mSubgroupCounts[0]); IREE::GPU::LoweringConfigAttr::setSubgroupNCount(context, qkConfig, 1); @@ -924,7 +927,7 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, // Configuring for pv matmul. IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, pvConfig, {1}); IREE::GPU::LoweringConfigAttr::setMmaKind(context, pvConfig, - mmaAttrs[schedule->index]); + mmaKinds[schedule->index]); IREE::GPU::LoweringConfigAttr::setSubgroupMCount( context, pvConfig, schedule->mSubgroupCounts[0]); IREE::GPU::LoweringConfigAttr::setSubgroupNCount( diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp index 7008f3e1376e5..f9555ef10c5e1 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp @@ -308,14 +308,16 @@ static LogicalResult setAttentionMatmulAnchor(RewriterBase &rewriter, bool reuseIntrinsicOutput = false; bool transposeIntrinsic = false; - auto qkIntrinsic = cast(qkSchedule.getIntrinsic()); - auto pvIntrinsic = cast(pvSchedule.getIntrinsic()); + auto qkIntrinsic = + cast(qkSchedule.getIntrinsic()); + auto pvIntrinsic = + cast(pvSchedule.getIntrinsic()); IREE::GPU::MMASingleSubgroupLayout lhsLayout = - pvIntrinsic.getASingleSubgroupLayout(); + getASingleSubgroupLayout(pvIntrinsic); IREE::GPU::MMASingleSubgroupLayout rhsLayout = - pvIntrinsic.getBSingleSubgroupLayout(); + getBSingleSubgroupLayout(pvIntrinsic); IREE::GPU::MMASingleSubgroupLayout outLayout = - qkIntrinsic.getCSingleSubgroupLayout(); + getCSingleSubgroupLayout(qkIntrinsic); auto matchLayout = [](IREE::GPU::MMASingleSubgroupLayout layoutA, IREE::GPU::MMASingleSubgroupLayout layoutB) -> bool { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir index a38ae5cee7ec4..960fe0b9938c6 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir @@ -657,7 +657,7 @@ hal.executable public @contract_schedule_considering_read_layout { // This test ensures that we can generate and decompose the right instructions from V(Virtual) MFMAs. -#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2}> +#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.virtual_mma_layout, subgroup_m_count = 2, subgroup_n_count = 2}> #translation = #iree_codegen.translation_info}> #pipeline_layout = #hal.pipeline.layout) { // This test ensures we can generate correct instructions from V(Virtual) MFMAs that has only different read layouts. -#config = #iree_gpu.lowering_config<{workgroup = [32, 32, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2}> +#config = #iree_gpu.lowering_config<{workgroup = [32, 32, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.virtual_mma_layout, subgroup_m_count = 2, subgroup_n_count = 2}> #translation = #iree_codegen.translation_info}> #pipeline_layout = #hal.pipeline.layout