diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp index 3e2cb427ed989..8f5752ae2a1a1 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp @@ -50,23 +50,22 @@ isSubgroupLayoutCompatible(IREE::GPU::MMASingleSubgroupLayout subgroupLayout, return success(); } -static LogicalResult isIntrinsicLayoutCompatible(VectorContractOpInfo &opInfo, - IREE::GPU::MMAAttr intrinsic, - NestedLayoutAttr lhsLayout, - NestedLayoutAttr rhsLayout, - NestedLayoutAttr accLayout) { +static LogicalResult isIntrinsicLayoutCompatible( + VectorContractOpInfo &opInfo, IREE::GPU::MmaInterfaceAttr intrinsic, + NestedLayoutAttr lhsLayout, NestedLayoutAttr rhsLayout, + NestedLayoutAttr accLayout) { auto [lhsM, rhsN] = opInfo.getOperandMNIndex(); auto [lhsK, rhsK] = opInfo.getOperandKIndex(); auto [accM, accN] = opInfo.getResultMNIndex(); - if (failed(isSubgroupLayoutCompatible(intrinsic.getASingleSubgroupLayout(), + if (failed(isSubgroupLayoutCompatible(getASingleSubgroupLayout(intrinsic), lhsLayout, lhsM, lhsK))) { return failure(); } - if (failed(isSubgroupLayoutCompatible(intrinsic.getBSingleSubgroupLayout(), + if (failed(isSubgroupLayoutCompatible(getBSingleSubgroupLayout(intrinsic), rhsLayout, rhsK, rhsN))) { return failure(); } - if (failed(isSubgroupLayoutCompatible(intrinsic.getCSingleSubgroupLayout(), + if (failed(isSubgroupLayoutCompatible(getCSingleSubgroupLayout(intrinsic), accLayout, accM, accN))) { return failure(); } @@ -124,16 +123,16 @@ struct DistributeContract final : OpDistributionPattern { // We assume there is an decision made before regarding which mfma intrinsic // to use and it is attached as an attribute to this contract op. - auto mmaAttr = - contractOp->getAttrOfType("iree.amdgpu.mma"); - if (!mmaAttr) { + auto mmaKind = contractOp->getAttrOfType( + "iree.amdgpu.mma"); + if (!mmaKind) { return rewriter.notifyMatchFailure( contractOp, "missing iree.amdgpu.mma intrinsic attribute"); } // Check if the given intrinsic can be distributed with the given // layouts. - if (failed(isIntrinsicLayoutCompatible(opDetail, mmaAttr, lhsLayout, + if (failed(isIntrinsicLayoutCompatible(opDetail, mmaKind, lhsLayout, rhsLayout, accLayout))) { return rewriter.notifyMatchFailure( contractOp, "the intrinsic does not match the expected layouts"); @@ -230,7 +229,7 @@ struct DistributeContract final : OpDistributionPattern { Value rhsSlice = rewriter.create(loc, rhs, rhsBatchOffsets); accSlice = - computeMMA(rewriter, loc, mmaAttr, lhsSlice, rhsSlice, accSlice); + computeMMA(rewriter, loc, mmaKind, lhsSlice, rhsSlice, accSlice); } finalTile = rewriter.create(loc, accSlice, finalTile, resultBatchOffsets); @@ -285,17 +284,18 @@ struct DistributeContract final : OpDistributionPattern { // Generates amdgpu.mfma operation on the given inputs for the given MFMA // |intrinsic|. - Value computeMMA(OpBuilder &builder, Location loc, IREE::GPU::MMAAttr mmaAttr, - Value a, Value b, Value c) const { + Value computeMMA(OpBuilder &builder, Location loc, + IREE::GPU::MmaInterfaceAttr mmaKind, Value a, Value b, + Value c) const { // Get the storage vector types that each thread is in charge of. 
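Note on the refactor above: `isIntrinsicLayoutCompatible` and `computeMMA` now take the `MmaInterfaceAttr` interface and query per-fragment layouts through free helpers (`getASingleSubgroupLayout` and friends, defined later in this diff in IREEGPUAttrs.cpp) instead of `MMAAttr` member calls. Below is a minimal standalone C++ sketch of that dispatch shape; the class names are placeholders (not the real IREE attributes), and only the virtual-intrinsic layout numbers are taken from the VMFMA tables added later in this diff.

```cpp
// Standalone sketch (not IREE code) of the dispatch performed by the new free
// helpers getASingleSubgroupLayout / getBSingleSubgroupLayout /
// getCSingleSubgroupLayout: both concrete attribute kinds answer the same
// per-fragment layout query, so callers no longer care which kind they hold.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

struct SubgroupLayout {
  std::vector<int64_t> outer, thread, tstrides, element;
};

struct MmaKind {  // stands in for MmaInterfaceAttr
  virtual ~MmaKind() = default;
  virtual SubgroupLayout aSingleSubgroupLayout() const = 0;
};

struct NativeMma : MmaKind {  // stands in for MMAAttr (illustrative numbers)
  SubgroupLayout aSingleSubgroupLayout() const override {
    return {{1, 1}, {16, 4}, {1, 16}, {1, 4}};
  }
};

struct VirtualMma : MmaKind {  // stands in for VirtualMMAAttr
  SubgroupLayout aSingleSubgroupLayout() const override {
    // VMFMA_F32_16x16x32_F16 Lhs: outer {1,1}, thread {16,4}, element {1,8}.
    return {{1, 1}, {16, 4}, {1, 16}, {1, 8}};
  }
};

// Mirrors the shape of getASingleSubgroupLayout(MmaInterfaceAttr): try each
// concrete kind, assert on anything unhandled.
SubgroupLayout getASingleSubgroupLayout(const MmaKind &kind) {
  if (auto *mma = dynamic_cast<const NativeMma *>(&kind))
    return mma->aSingleSubgroupLayout();
  if (auto *vmma = dynamic_cast<const VirtualMma *>(&kind))
    return vmma->aSingleSubgroupLayout();
  assert(false && "unhandled MMA interface kind");
  return {};
}

int main() {
  VirtualMma vmma;
  SubgroupLayout a = getASingleSubgroupLayout(vmma);
  // 1*1 outer tiles * 1*8 elements = 8 f16 values per lane for the A fragment.
  std::cout << a.element[0] * a.element[1] << " elements per lane\n";
}
```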
- auto [aVectorType, bVectorType, cVectorType] = mmaAttr.getABCVectorTypes(); + auto [aVectorType, bVectorType, cVectorType] = mmaKind.getABCVectorTypes(); Value aCast = builder.create(a.getLoc(), aVectorType, a); Value bCast = builder.create(b.getLoc(), bVectorType, b); Value cCast = builder.create(c.getLoc(), cVectorType, c); - FailureOr mmaOp = mmaAttr.buildMmaOperation( + FailureOr mmaOp = mmaKind.buildMmaOperation( builder, loc, cVectorType, aCast, bCast, cCast); assert(succeeded(mmaOp) && "Failed to construct mma op"); return builder.create(c.getLoc(), c.getType(), *mmaOp); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp index c97d8913b3cb6..4675d1dab9ce8 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp @@ -593,8 +593,8 @@ struct DistributeContract final : OpDistributionPattern { } // If mmaAttr exists, defer the lowering to use MMA. // Notify failure if the "iree.amdgpu.mma" intrinsic attribute is present. - auto mmaAttr = - contractOp->getAttrOfType("iree.amdgpu.mma"); + auto mmaAttr = contractOp->getAttrOfType( + "iree.amdgpu.mma"); if (mmaAttr) { return rewriter.notifyMatchFailure( contractOp, "iree.amdgpu.mma intrinsic attribute exists"); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir index db39c0b15742d..eecd2f0653ce7 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir @@ -645,7 +645,7 @@ func.func @contract_to_vmfma_32x32x16_mm(%a : vector<32x16xf16>, %b : vector<16x indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind, - iree.amdgpu.mma = #iree_gpu.mma_layout + iree.amdgpu.mma = #iree_gpu.virtual_mma_layout } %A, %B, %C : vector<32x16xf16>, vector<16x32xf16> into vector<32x32xf32> %O = iree_vector_ext.to_layout %output to layout(#layout_c) : vector<32x32xf32> diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 0be2db3428adc..a09ae277819f9 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -236,8 +236,7 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context, case MMAIntrinsic::MFMA_F32_32x32x8_BF16: { return OpaqueMmaLayout{32, 32, 8, bf16, bf16, f32}; } - case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: { + case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: { return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32}; } case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: { @@ -264,14 +263,6 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context, case MMAIntrinsic::WMMA_I32_16x16x16_I8: { return OpaqueMmaLayout{16, 16, 16, i8, i8, i32}; } - // V(Virtual)MFMA instructions which have 2 mfma instructions interleaved - // along the k dimension. 
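The `computeMMA` path above shape-casts each distributed operand to the storage vector type reported by `getABCVectorTypes()`. Those lengths are exactly the per-lane element counts implied by the single-subgroup layouts (product of the outer and element tiles). A standalone check over the numbers that appear in this diff:

```cpp
// Standalone consistency check (not IREE code): for each virtual intrinsic in
// this diff, the per-lane storage length from getABCVectorTypes() equals the
// product of the outer and element tiles of the corresponding single-subgroup
// layout, which is what lets computeMMA() shape_cast the distributed operand
// into the intrinsic's register operand.
#include <cassert>
#include <cstdint>
#include <cstdio>

struct Frag {
  int64_t outer[2], element[2];
  int64_t storageLen;  // vector length from getABCVectorTypes()
};

int main() {
  const Frag frags[] = {
      // VMFMA_F32_16x16x32_F16:  A              B              C
      {{1, 1}, {1, 8}, 8}, {{1, 1}, {8, 1}, 8}, {{1, 1}, {4, 1}, 4},
      // VMFMA_F32_32x32x16_F16
      {{1, 1}, {1, 8}, 8}, {{1, 1}, {8, 1}, 8}, {{4, 1}, {4, 1}, 16},
      // VMFMA_F32_16x16x32_F8E4M3FNUZ
      {{1, 2}, {1, 4}, 8}, {{2, 1}, {4, 1}, 8}, {{1, 1}, {4, 1}, 4},
  };
  for (const Frag &f : frags) {
    int64_t perLane = f.outer[0] * f.outer[1] * f.element[0] * f.element[1];
    assert(perLane == f.storageLen && "layout does not match storage vector");
    std::printf("per-lane elements: %lld\n", static_cast<long long>(perLane));
  }
}
```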
- case MMAIntrinsic::VMFMA_F32_16x16x32_F16: { - return OpaqueMmaLayout{16, 16, 32, f16, f16, f32}; - } - case MMAIntrinsic::VMFMA_F32_32x32x16_F16: { - return OpaqueMmaLayout{32, 32, 16, f16, f16, f32}; - } } llvm_unreachable("unhandled mfma layout type"); return OpaqueMmaLayout{}; @@ -345,6 +336,43 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context, return concreteLayout; } +//===----------------------------------------------------------------------===// +// MmaInterface Attribute Helper Functions +//===----------------------------------------------------------------------===// + +MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind) { + if (auto mmaAttr = dyn_cast(mmaKind)) { + return mmaAttr.getASingleSubgroupLayout(); + } else if (auto vmmaAttr = dyn_cast(mmaKind)) { + return vmmaAttr.getASingleSubgroupLayout(); + } else { + assert(false && "unhandled MMA Interface type."); + return {}; + } +} + +MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind) { + if (auto mmaAttr = dyn_cast(mmaKind)) { + return mmaAttr.getBSingleSubgroupLayout(); + } else if (auto vmmaAttr = dyn_cast(mmaKind)) { + return vmmaAttr.getBSingleSubgroupLayout(); + } else { + assert(false && "unhandled MMA Interface type."); + return {}; + } +} + +MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind) { + if (auto mmaAttr = dyn_cast(mmaKind)) { + return mmaAttr.getCSingleSubgroupLayout(); + } else if (auto vmmaAttr = dyn_cast(mmaKind)) { + return vmmaAttr.getCSingleSubgroupLayout(); + } else { + assert(false && "unhandled MMA Interface type."); + return {}; + } +} + //===----------------------------------------------------------------------===// // MFMA Attributes //===----------------------------------------------------------------------===// @@ -421,15 +449,12 @@ MMAAttr::getABCVectorTypes() const { } case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F16: case MMAIntrinsic::MFMA_I32_16x16x32_I8: { auto aType = VectorType::get({8}, getAType()); auto bType = VectorType::get({8}, getBType()); auto cType = VectorType::get({4}, getCType()); return std::make_tuple(aType, bType, cType); } - case MMAIntrinsic::VMFMA_F32_32x32x16_F16: case MMAIntrinsic::MFMA_I32_32x32x16_I8: { auto aType = VectorType::get({8}, getAType()); auto bType = VectorType::get({8}, getBType()); @@ -473,10 +498,7 @@ int64_t MMAAttr::getBlockSize() const { case MMAIntrinsic::MFMA_I32_32x32x8_I8: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F16: case MMAIntrinsic::MFMA_I32_16x16x32_I8: - case MMAIntrinsic::VMFMA_F32_32x32x16_F16: case MMAIntrinsic::MFMA_I32_32x32x16_I8: case MMAIntrinsic::WMMA_F16_16x16x16_F16: case MMAIntrinsic::WMMA_F32_16x16x16_F16: @@ -499,10 +521,7 @@ static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) { case MMAIntrinsic::MFMA_I32_32x32x8_I8: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F16: case MMAIntrinsic::MFMA_I32_16x16x32_I8: - case MMAIntrinsic::VMFMA_F32_32x32x16_F16: case MMAIntrinsic::MFMA_I32_32x32x16_I8: { return 64; } @@ -567,7 +586,6 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic 
intrinsic, return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1}, /*element=*/{4, 1}}; } - case MMAIntrinsic::VMFMA_F32_16x16x32_F16: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: @@ -582,19 +600,6 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, /*element=*/{4, 1}}; } - case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: - switch (fragment) { - case MMAFragment::Lhs: - return {/*outer=*/{1, 2}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16}, - /*element=*/{1, 4}}; - case MMAFragment::Rhs: - return {/*outer=*/{2, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, - /*element=*/{4, 1}}; - case MMAFragment::Acc: - return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, - /*element=*/{4, 1}}; - } - case MMAIntrinsic::VMFMA_F32_32x32x16_F16: case MMAIntrinsic::MFMA_I32_32x32x16_I8: switch (fragment) { case MMAFragment::Lhs: @@ -649,14 +654,14 @@ MMASingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const { } // Get virtual intrinsics that is composed/based on queried op. -SmallVector MMAAttr::getVirtualIntrinsics() const { +SmallVector MMAAttr::getVirtualIntrinsics() const { switch (getIntrinsic().getValue()) { case MMAIntrinsic::MFMA_F32_16x16x16_F16: - return {MMAIntrinsic::VMFMA_F32_16x16x32_F16}; + return {VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16}; case MMAIntrinsic::MFMA_F32_32x32x8_F16: - return {MMAIntrinsic::VMFMA_F32_32x32x16_F16}; + return {VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16}; case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: - return {MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ}; + return {VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ}; default: return {}; } @@ -690,37 +695,6 @@ FailureOr MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc, rhs, acc) .getResult(); } - case MMAIntrinsic::VMFMA_F32_16x16x32_F16: - case MMAIntrinsic::VMFMA_F32_32x32x16_F16: { - // Generate mfma's for K with unrolled kernels. - const int64_t unrollKFactor = 2; - auto [m, n, k] = getMNKShape(); - // Compute actual/native intrinsic's K size. - int64_t nativeKSize = k / unrollKFactor; - - auto [aType, bType, cType] = getABCVectorTypes(); - if (aType.getShape()[0] != bType.getShape()[0]) { - // Currently only support case where lhs and rhs - // has same vectorWidth. 
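For reference, the native-to-virtual mapping returned by `MMAAttr::getVirtualIntrinsics()` keeps the native M and N and advertises K = nativeK * unrollK, where unrollK matches `VirtualMMAAttr::getUnrollK()` later in this diff (2 for the F16 variants, 1 for the interleaved-read F8 variant). A standalone arithmetic check of that relationship:

```cpp
// Standalone arithmetic check (not IREE code) of the native <-> virtual
// intrinsic mapping and the unrollK factors introduced in this diff.
#include <cassert>
#include <cstdint>
#include <cstdio>

struct Pair {
  const char *native, *virt;
  int64_t nativeK, virtualK, unrollK;
};

int main() {
  const Pair pairs[] = {
      {"MFMA_F32_16x16x16_F16", "VMFMA_F32_16x16x32_F16", 16, 32, 2},
      {"MFMA_F32_32x32x8_F16", "VMFMA_F32_32x32x16_F16", 8, 16, 2},
      {"MFMA_F32_16x16x32_F8E4M3FNUZ", "VMFMA_F32_16x16x32_F8E4M3FNUZ", 32, 32,
       1},
  };
  for (const Pair &p : pairs) {
    assert(p.virtualK == p.nativeK * p.unrollK);
    std::printf("%s -> %s: one virtual op = %lld native mfma(s)\n", p.native,
                p.virt, static_cast<long long>(p.unrollK));
  }
}
```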
- return failure(); - } - int64_t vectorWidth = aType.getShape()[0] / unrollKFactor; - for (int i = 0; i < unrollKFactor; i++) { - int64_t offset = vectorWidth * i; - Value sliced_lhs = builder.create( - loc, lhs, ArrayRef{offset}, ArrayRef{vectorWidth}, - ArrayRef{1}); - Value sliced_rhs = builder.create( - loc, rhs, ArrayRef{offset}, ArrayRef{vectorWidth}, - ArrayRef{1}); - acc = builder - .create(loc, resultType, m, n, nativeKSize, - getBlockSize(), sliced_lhs, sliced_rhs, - acc) - .getResult(); - } - return acc; - } case MMAIntrinsic::MFMA_I32_16x16x16_I8: case MMAIntrinsic::MFMA_F32_16x16x16_F16: case MMAIntrinsic::MFMA_F32_16x16x16_BF16: @@ -729,7 +703,6 @@ FailureOr MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc, case MMAIntrinsic::MFMA_F32_32x32x8_BF16: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: - case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: case MMAIntrinsic::MFMA_I32_32x32x16_I8: { auto [m, n, k] = getMNKShape(); @@ -1179,6 +1152,267 @@ FailureOr DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder, return acc; } +//===----------------------------------------------------------------------===// +// VirtualMMA Attributes +//===----------------------------------------------------------------------===// + +VirtualMMAAttr VirtualMMAAttr::get(MLIRContext *context, + VirtualMMAIntrinsic type) { + auto intrinsicAttr = VirtualMMAIntrinsicAttr::get(context, type); + return VirtualMMAAttr::get(context, intrinsicAttr); +} + +static OpaqueMmaLayout getOpaqueVMMALayout(MLIRContext *context, + VirtualMMAIntrinsic type) { + Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context); + Type f16 = Float16Type::get(context); + Type f32 = Float32Type::get(context); + + switch (type) { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: { + return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32}; + } + // V(Virtual)MFMA instructions which have 2 mfma instructions interleaved + // along the k dimension. + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: { + return OpaqueMmaLayout{16, 16, 32, f16, f16, f32}; + } + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { + return OpaqueMmaLayout{32, 32, 16, f16, f16, f32}; + } + } + assert(false && "unhandled virtual mma layout type."); + return OpaqueMmaLayout{}; +} + +std::tuple VirtualMMAAttr::getABCElementTypes() const { + MLIRContext *ctx = getContext(); + auto opaqueLayout = getOpaqueVMMALayout(ctx, getIntrinsic().getValue()); + return {opaqueLayout.aType, opaqueLayout.bType, opaqueLayout.cType}; +} + +std::tuple +VirtualMMAAttr::getABCVectorTypes() const { + // Check https://github.com/ROCm/amd_matrix_instruction_calculator for + // instruction details. Note here we are returning the number elements, while + // amd_matrix_instruction_calculator tells us about the number of 32-bit + // registers. So need to adjust accordingly. All vectors should be 1-D. 
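The comment above distinguishes element counts (what `getABCVectorTypes()` returns) from the 32-bit registers reported by amd_matrix_instruction_calculator. A small standalone helper making that conversion explicit for the operand types used in this diff; the conversion is plain bit arithmetic, not an IREE API:

```cpp
// Standalone helper (not IREE code): convert per-lane element counts to
// per-lane 32-bit registers, i.e. elements * bitwidth / 32.
#include <cstdint>
#include <cstdio>

static int64_t regsPerLane(int64_t numElements, int64_t elementBits) {
  return numElements * elementBits / 32;
}

int main() {
  // Values taken from the VirtualMMAAttr::getABCVectorTypes() cases below.
  std::printf("vector<8xf16>  -> %lld VGPRs\n", (long long)regsPerLane(8, 16));
  std::printf("vector<8xf8>   -> %lld VGPRs\n", (long long)regsPerLane(8, 8));
  std::printf("vector<4xf32>  -> %lld VGPRs\n", (long long)regsPerLane(4, 32));
  std::printf("vector<16xf32> -> %lld VGPRs\n", (long long)regsPerLane(16, 32));
}
```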
+ auto [A, B, C] = getABCElementTypes(); + switch (getIntrinsic().getValue()) { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: { + auto aType = VectorType::get({8}, A); + auto bType = VectorType::get({8}, B); + auto cType = VectorType::get({4}, C); + return {aType, bType, cType}; + } + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { + auto aType = VectorType::get({8}, A); + auto bType = VectorType::get({8}, B); + auto cType = VectorType::get({16}, C); + return {aType, bType, cType}; + } + } + // This should not happen but just to make GCC happy. + assert(false && "unhandled virtual mma layout type."); + return {VectorType{}, VectorType{}, VectorType{}}; +} + +std::tuple VirtualMMAAttr::getMNKShape() const { + MLIRContext *ctx = getContext(); + auto opaqueLayout = getOpaqueVMMALayout(ctx, getIntrinsic().getValue()); + return {opaqueLayout.mSize, opaqueLayout.nSize, opaqueLayout.kSize}; +} + +int64_t VirtualMMAAttr::getSubgroupSize() const { + switch (getIntrinsic().getValue()) { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { + return 64; + } + } + // This should not happen but just to make GCC happy. + assert(false && "unhandled virtual mma layout type."); + return 0; +} + +FailureOr VirtualMMAAttr::getMmaScope() const { + return IREE::GPU::MMAScope::Subgroup; +} + +LogicalResult VirtualMMAAttr::populateOperandOffsetsSizesStrides( + OpBuilder &builder, Location loc, IREE::GPU::MMAFragment fragment, + Value laneId, ArrayRef permutation, + SmallVector &offsets, SmallVector &sizes, + SmallVector &strides) const { + + MMASingleSubgroupLayout subgroupLayout; + switch (fragment) { + case IREE::GPU::MMAFragment::Lhs: { + subgroupLayout = getASingleSubgroupLayout(); + break; + } + case IREE::GPU::MMAFragment::Rhs: { + subgroupLayout = getBSingleSubgroupLayout(); + break; + } + case IREE::GPU::MMAFragment::Acc: { + subgroupLayout = getCSingleSubgroupLayout(); + break; + } + } + + SmallVector canonicalOffsets; + SmallVector canonicalSizes; + if (failed(populateCanonicalOffsetsSizesAndStrides( + builder, loc, laneId, permutation, subgroupLayout, canonicalOffsets, + canonicalSizes, strides))) { + return failure(); + } + offsets.append(canonicalOffsets); + sizes.append(canonicalSizes); + + return success(); +} + +int64_t VirtualMMAAttr::getUnrollK() const { + switch (getIntrinsic().getValue()) { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { + return 2; + } + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: { + return 1; + } + } + // This should not happen but just to make GCC happy. + assert(false && "unhandled virtual mma layout type."); + return 0; +} + +// Generates amdgpu.mfma/wmma operation on the given inputs for this attribute +// type. +FailureOr VirtualMMAAttr::buildMmaOperation(OpBuilder &builder, + Location loc, + Type resultType, Value lhs, + Value rhs, Value acc) const { + auto [aType, bType, cType] = getABCVectorTypes(); + if (aType != lhs.getType() || bType != rhs.getType() || + cType != acc.getType()) { + return failure(); + } + // Fail if the result type does not match with the expected return type of + // the intrinsic. We expect the caller to handle type conversions externally. 
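The unrolled-K lowering that follows in `VirtualMMAAttr::buildMmaOperation()` slices the A/B operands into unrollK contiguous pieces and issues one native amdgpu.mfma per piece, under the check (kept from the old MMAAttr code) that A and B have the same vector width. A standalone sketch of just the slicing arithmetic:

```cpp
// Standalone sketch (not IREE code) of the K-unrolling below: the 8-element
// A/B operands are split into unrollK contiguous slices of width
// 8 / unrollK, and one native mfma of K = virtualK / unrollK is issued per
// slice.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t operandLen = 8;  // vector<8xf16> A/B operands in this diff
  const int64_t virtualK = 32;   // VMFMA_F32_16x16x32_F16
  const int64_t unrollK = 2;

  const int64_t nativeK = virtualK / unrollK;       // 16
  const int64_t sliceWidth = operandLen / unrollK;  // 4 elements per native op

  for (int64_t i = 0; i < unrollK; ++i) {
    int64_t offset = sliceWidth * i;
    std::printf("mfma #%lld: K=%lld, operand slice [%lld, %lld)\n",
                (long long)i, (long long)nativeK, (long long)offset,
                (long long)(offset + sliceWidth));
  }
  assert(sliceWidth * unrollK == operandLen && "slices must tile the operand");
}
```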
+ if (cType != resultType) { + return failure(); + } + switch (getIntrinsic().getValue()) { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { + // Generate mfma's for K with unrolled kernels. + const int64_t unrollKFactor = getUnrollK(); + auto [m, n, k] = getMNKShape(); + // Compute actual/native intrinsic's K size. + int64_t nativeKSize = k / unrollKFactor; + + auto [aType, bType, cType] = getABCVectorTypes(); + if (aType.getShape()[0] != bType.getShape()[0]) { + // Currently only support case where lhs and rhs + // has same vectorWidth. + return failure(); + } + int64_t vectorWidth = aType.getShape()[0] / unrollKFactor; + for (int i = 0; i < unrollKFactor; i++) { + int64_t offset = vectorWidth * i; + Value sliced_lhs = builder.create( + loc, lhs, ArrayRef{offset}, ArrayRef{vectorWidth}, + ArrayRef{1}); + Value sliced_rhs = builder.create( + loc, rhs, ArrayRef{offset}, ArrayRef{vectorWidth}, + ArrayRef{1}); + acc = builder + .create(loc, resultType, m, n, nativeKSize, + getBlockSize(), sliced_lhs, sliced_rhs, + acc) + .getResult(); + } + return acc; + } + } + return failure(); +} + +int64_t VirtualMMAAttr::getBlockSize() const { + switch (getIntrinsic().getValue()) { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { + return 1; + } + } + // This should not happen but just to make GCC happy. + assert(false && "unhandled virtual mma layout type."); + return 0; +} + +MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic, + MMAFragment fragment) { + switch (intrinsic) { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: + switch (fragment) { + case MMAFragment::Lhs: + return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16}, + /*element=*/{1, 8}}; + case MMAFragment::Rhs: + return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, + /*element=*/{8, 1}}; + case MMAFragment::Acc: + return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, + /*element=*/{4, 1}}; + } + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: + switch (fragment) { + case MMAFragment::Lhs: + return {/*outer=*/{1, 2}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16}, + /*element=*/{1, 4}}; + case MMAFragment::Rhs: + return {/*outer=*/{2, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, + /*element=*/{4, 1}}; + case MMAFragment::Acc: + return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, + /*element=*/{4, 1}}; + } + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: + switch (fragment) { + case MMAFragment::Lhs: + return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*tstrides=*/{1, 32}, + /*element=*/{1, 8}}; + case MMAFragment::Rhs: + return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1}, + /*element=*/{8, 1}}; + case MMAFragment::Acc: + return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1}, + /*element=*/{4, 1}}; + } + } + assert(false && "unhandled virtual mma layout type."); + return {}; +} + +MMASingleSubgroupLayout VirtualMMAAttr::getASingleSubgroupLayout() const { + return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Lhs); +} + +MMASingleSubgroupLayout VirtualMMAAttr::getBSingleSubgroupLayout() const { + return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Rhs); +} + +MMASingleSubgroupLayout VirtualMMAAttr::getCSingleSubgroupLayout() const { + return 
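The per-fragment tables in `getSingleSubgroupLayout(VirtualMMAIntrinsic, MMAFragment)` above obey two invariants that are easy to verify offline: the thread tile always covers the full 64-lane subgroup, and outer * thread * element per dimension reproduces the fragment shape ((M, K) for A, (K, N) for B, (M, N) for C). A standalone check over the values in this diff:

```cpp
// Standalone consistency check (not IREE code) over the virtual-intrinsic
// subgroup layout tables added in this diff.
#include <cassert>
#include <cstdint>
#include <cstdio>

struct Layout {
  const char *name;
  int64_t outer[2], thread[2], element[2];
  int64_t shape[2];  // expected fragment shape
};

int main() {
  const Layout layouts[] = {
      {"VMFMA_F32_16x16x32_F16 A", {1, 1}, {16, 4}, {1, 8}, {16, 32}},
      {"VMFMA_F32_16x16x32_F16 B", {1, 1}, {4, 16}, {8, 1}, {32, 16}},
      {"VMFMA_F32_16x16x32_F16 C", {1, 1}, {4, 16}, {4, 1}, {16, 16}},
      {"VMFMA_F32_16x16x32_F8E4M3FNUZ A", {1, 2}, {16, 4}, {1, 4}, {16, 32}},
      {"VMFMA_F32_16x16x32_F8E4M3FNUZ B", {2, 1}, {4, 16}, {4, 1}, {32, 16}},
      {"VMFMA_F32_16x16x32_F8E4M3FNUZ C", {1, 1}, {4, 16}, {4, 1}, {16, 16}},
      {"VMFMA_F32_32x32x16_F16 A", {1, 1}, {32, 2}, {1, 8}, {32, 16}},
      {"VMFMA_F32_32x32x16_F16 B", {1, 1}, {2, 32}, {8, 1}, {16, 32}},
      {"VMFMA_F32_32x32x16_F16 C", {4, 1}, {2, 32}, {4, 1}, {32, 32}},
  };
  for (const Layout &l : layouts) {
    assert(l.thread[0] * l.thread[1] == 64 && "must cover the 64-lane subgroup");
    for (int d = 0; d < 2; ++d)
      assert(l.outer[d] * l.thread[d] * l.element[d] == l.shape[d]);
    std::printf("%s ok\n", l.name);
  }
}
```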
getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Acc); +} + //===----------------------------------------------------------------------===// // MMA Schedule Attributes //===----------------------------------------------------------------------===// @@ -1313,7 +1547,7 @@ MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo, }); int64_t rank = contractOp.getIteratorTypesArray().size(); - auto mmaAttr = llvm::cast(getIntrinsic()); + auto mmaAttr = llvm::cast(getIntrinsic()); MLIRContext *context = getContext(); SmallVector bounds = contractOp.getStaticLoopRanges(); @@ -1460,7 +1694,7 @@ MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo, /*subgroupCount=*/cSubgroupSizes, /*subgroupStrides=*/cSubgroupStrides, /*batchCount=*/cBatchSizes, - mmaAttr.getCSingleSubgroupLayout()); + getCSingleSubgroupLayout(mmaAttr)); LLVM_DEBUG({ llvm::errs() << "C layout: " << cLayout << "\n"; }); // A matrix layout @@ -1488,7 +1722,7 @@ MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo, /*subgroupCount=*/aSubgroupSizes, /*subgroupStrides=*/aSubgroupStrides, /*batchCount=*/aBatchSizes, - mmaAttr.getASingleSubgroupLayout()); + getASingleSubgroupLayout(mmaAttr)); LLVM_DEBUG({ llvm::errs() << "A layout: " << aLayout << "\n"; }); int64_t bRank = opInfo.getBRank(); @@ -1512,7 +1746,7 @@ MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo, /*subgroupCount=*/bSubgroupSizes, /*subgroupStrides=*/bSubgroupStrides, /*batchCount=*/bBatchSizes, - mmaAttr.getBSingleSubgroupLayout()); + getBSingleSubgroupLayout(mmaAttr)); LLVM_DEBUG({ llvm::errs() << "B layout: " << bLayout << "\n"; }); std::tuple outer; - llvm::SmallVector thread; - llvm::SmallVector tstrides; - llvm::SmallVector element; + SmallVector outer; + SmallVector thread; + SmallVector tstrides; + SmallVector element; }; MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, MMAFragment fragment); +MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic, + MMAFragment fragment); + +MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind); + +MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind); + +MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind); + } // namespace mlir::iree_compiler::IREE::GPU // clang-format off diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td index ee4cb932ded47..d7686562f7cda 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td @@ -152,6 +152,9 @@ class IREEGPU_MmaVectorLayoutAttr : "getMNKShape", "getSubgroupSize", "getMmaScope", + "getASingleSubgroupLayout", + "getBSingleSubgroupLayout", + "getCSingleSubgroupLayout", "buildMmaOperation", "populateOperandOffsetsSizesStrides", "materializeOperandConcreteShape", @@ -232,7 +235,7 @@ def IREEGPU_MMAAttr : IREEGPU_MmaVectorLayoutAttr<"MMA", "MMAIntrinsicAttr"> { MMASingleSubgroupLayout getBSingleSubgroupLayout() const; MMASingleSubgroupLayout getCSingleSubgroupLayout() const; - SmallVector getVirtualIntrinsics() const; + SmallVector getVirtualIntrinsics() const; }]; } @@ -283,6 +286,64 @@ def IREEGPU_DataTiledMMAAttr : ); } +def IREEGPU_VirtualMMAIntrinsicAttr + : IREEGPU_MmaEnumAttr; + +def IREEGPU_VirtualMMAAttr : + AttrDef +]> { + let mnemonic = "virtual_mma_layout"; + let cppNamespace = "::mlir::iree_compiler::IREE::GPU"; + + 
let description = [{ + This mma variant represents "virtual" MMA ops that has modification to + its native layouts by unrollK and/or interleave reads. The |intrinsic| + field represents different kinds of "Virtual" MMA Ops we found helpful. + + These interleaving and/or unrolling changes in the layout is especially + useful to coalesce reads from shared memory to register or align layouts + in a chained-matmul operation. + }]; + + let assemblyFormat = "`<` struct(params) `>`"; + let builders = [ + AttrBuilder<(ins "VirtualMMAIntrinsic":$intrinsic)> + ]; + + let parameters = (ins + "::mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsicAttr":$intrinsic + ); + + let extraClassDeclaration = [{ + int64_t getBlockSize() const; + + // Returns the A/B/C matrix's partial nested layout shape inside a single + // subgroup. Shape at each outer/thread/element level is a 2-D value, + // following canonical matmul order--(M, K) for A, (K, N) for B, and + // (M, N) for C. + MMASingleSubgroupLayout getASingleSubgroupLayout() const; + MMASingleSubgroupLayout getBSingleSubgroupLayout() const; + MMASingleSubgroupLayout getCSingleSubgroupLayout() const; + + // Factor to unroll K from native MMA/intrinsic size to virtual size. + // e.g MFMA_F32_16x16x16 has K of 16, while VMFMA_F32_16x16x32 has K of 32 + // in this example, unrollK = 32/16 = 2. + int64_t getUnrollK() const; + }]; +} + def IREEGPU_MMAOpsArrayAttr : ArrayOfAttr< IREEGPU_Dialect, "MMAOpsArray", "mma_ops", "MMAAttr"> { let cppNamespace = "::mlir::iree_compiler::IREE::GPU"; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td index 49f210e7f06dd..22a510db049cf 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td @@ -156,13 +156,10 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic", MFMA_F32_16x16x4_F32, MFMA_F32_16x16x16_F16, MFMA_F32_32x32x8_F16, - VMFMA_F32_16x16x32_F16, - VMFMA_F32_32x32x16_F16, MFMA_F32_16x16x16_BF16, MFMA_F32_32x32x8_BF16, MFMA_F32_16x16x32_F8E4M3FNUZ, MFMA_F32_16x16x32_F8E5M2FNUZ, - VMFMA_F32_16x16x32_F8E4M3FNUZ, MFMA_I32_16x16x32_I8, MFMA_I32_32x32x16_I8, MFMA_I32_16x16x16_I8, @@ -172,6 +169,13 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic", WMMA_I32_16x16x16_I8 ]>; +def IREEGPU_VirtualMMAIntrinsic : IREEGPU_I32MmaEnumAttr<"VirtualMMAIntrinsic", + "Descriptor for different Virtual MMA intrinsics", [ + VMFMA_F32_16x16x32_F16, + VMFMA_F32_32x32x16_F16, + VMFMA_F32_16x16x32_F8E4M3FNUZ, + ]>; + def MMA_LHS : I32EnumAttrCase<"Lhs", 0>; def MMA_RHS : I32EnumAttrCase<"Rhs", 1>; def MMA_ACC : I32EnumAttrCase<"Acc", 2>; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index 96a7a9d8ddb71..f269e2be0a9be 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -323,28 +323,29 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, lhsElemType, rhsElemType, initElemType}; // Helper fn to store mma information. 
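For orientation, a hedged sketch of how the new attribute could be constructed and queried inside the compiler, using only the builder and accessors added in this diff (`VirtualMMAAttr::get`, `getMNKShape`, `getUnrollK`, `getSubgroupSize`). The include paths and dialect registration below are assumptions based on the existing IREE source layout and the usual IREE build setup, not part of this change:

```cpp
// Hedged usage sketch: everything other than standard MLIR calls is an
// assumption based on this diff and the existing IREE header layout.
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext context;
  context.loadDialect<mlir::iree_compiler::IREE::GPU::IREEGPUDialect>();

  // Prints as #iree_gpu.virtual_mma_layout<...> per the assemblyFormat above.
  auto vmma = mlir::iree_compiler::IREE::GPU::VirtualMMAAttr::get(
      &context, mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsic::
                    VMFMA_F32_32x32x16_F16);

  auto [m, n, k] = vmma.getMNKShape();
  llvm::outs() << "MNK = " << m << "x" << n << "x" << k
               << ", unrollK = " << vmma.getUnrollK()
               << ", subgroup size = " << vmma.getSubgroupSize() << "\n";
  return 0;
}
```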
- auto storeMmaInfo = [](IREE::GPU::MMAAttr mma, + auto storeMmaInfo = [](IREE::GPU::MmaInterfaceAttr mma, SmallVector &intrinsics, - SmallVector &mmaAttrs) { + SmallVector &mmaKinds) { auto [mSize, nSize, kSize] = mma.getMNKShape(); auto [aType, bType, cType] = mma.getABCElementTypes(); intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType); - mmaAttrs.emplace_back(mma); + mmaKinds.emplace_back(mma); }; SmallVector intrinsics; intrinsics.reserve(target.getWgp().getMma().size()); - SmallVector mmaAttrs; + SmallVector mmaKinds; MLIRContext *context = op.getContext(); for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { if (mma.getSubgroupSize() != targetSubgroupSize) continue; - storeMmaInfo(mma, intrinsics, mmaAttrs); + storeMmaInfo(mma, intrinsics, mmaKinds); // Store info on virtual intrinsics based on current mma if any - for (IREE::GPU::MMAIntrinsic virtualIntrinsic : + for (IREE::GPU::VirtualMMAIntrinsic virtualIntrinsic : mma.getVirtualIntrinsics()) { - auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic); - storeMmaInfo(virtualMma, intrinsics, mmaAttrs); + auto virtualMma = + IREE::GPU::VirtualMMAAttr::get(context, virtualIntrinsic); + storeMmaInfo(virtualMma, intrinsics, mmaKinds); } } @@ -417,7 +418,7 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, b.getI64ArrayAttr(reductionTileSizes)); IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1}); IREE::GPU::LoweringConfigAttr::setMmaKind(context, attrs, - mmaAttrs[schedule->index]); + mmaKinds[schedule->index]); IREE::GPU::LoweringConfigAttr::setSubgroupMCount( context, attrs, schedule->mSubgroupCounts[0]); IREE::GPU::LoweringConfigAttr::setSubgroupNCount( @@ -536,28 +537,29 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, lhsElemType, rhsElemType, initElemType}; // Helper fn to store mma information. 
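Each of the three vector-distribution config paths (convolution above, matmul and attention below) repeats the same collection step: one `GPUMatmulShapeType` descriptor per candidate plus, in a parallel vector, the `MmaInterfaceAttr` kind it came from (the native `MMAAttr` followed by any `VirtualMMAAttr` derived from it), so that `schedule->index` later selects the matching kind. A standalone model of that bookkeeping; `ShapeInfo` and plain strings stand in for `GPUMatmulShapeType` and the attribute kinds:

```cpp
// Standalone model (not IREE code) of the candidate bookkeeping in the config
// functions: shape descriptors and their originating kinds are kept in two
// parallel vectors so an index chosen against one can be used on the other.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

struct ShapeInfo {
  int64_t m, n, k;
};

int main() {
  std::vector<ShapeInfo> intrinsics;
  std::vector<std::string> kinds;  // parallel to `intrinsics`

  auto storeMmaInfo = [&](const std::string &kind, ShapeInfo shape) {
    intrinsics.push_back(shape);
    kinds.push_back(kind);
  };

  // Native kind first, then the virtual kind derived from it, as in the loop
  // over target.getWgp().getMma().
  storeMmaInfo("MFMA_F32_16x16x16_F16", {16, 16, 16});
  storeMmaInfo("VMFMA_F32_16x16x32_F16", {16, 16, 32});
  storeMmaInfo("MFMA_F32_32x32x8_F16", {32, 32, 8});
  storeMmaInfo("VMFMA_F32_32x32x16_F16", {32, 32, 16});

  assert(intrinsics.size() == kinds.size());
  std::size_t chosen = 1;  // stands in for schedule->index
  std::printf("mma_kind = %s (K = %lld)\n", kinds[chosen].c_str(),
              (long long)intrinsics[chosen].k);
}
```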
- auto storeMmaInfo = [](IREE::GPU::MMAAttr mma, + auto storeMmaInfo = [](IREE::GPU::MmaInterfaceAttr mma, SmallVector &intrinsics, - SmallVector &mmaAttrs) { + SmallVector &mmaKinds) { auto [mSize, nSize, kSize] = mma.getMNKShape(); auto [aType, bType, cType] = mma.getABCElementTypes(); intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType); - mmaAttrs.emplace_back(mma); + mmaKinds.emplace_back(mma); }; SmallVector intrinsics; intrinsics.reserve(target.getWgp().getMma().size()); - SmallVector mmaAttrs; + SmallVector mmaKinds; MLIRContext *context = op.getContext(); for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { if (mma.getSubgroupSize() != targetSubgroupSize) continue; - storeMmaInfo(mma, intrinsics, mmaAttrs); + storeMmaInfo(mma, intrinsics, mmaKinds); // Store info on virtual intrinsics based on current mma if any - for (IREE::GPU::MMAIntrinsic virtualIntrinsic : + for (IREE::GPU::VirtualMMAIntrinsic virtualIntrinsic : mma.getVirtualIntrinsics()) { - auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic); - storeMmaInfo(virtualMma, intrinsics, mmaAttrs); + auto virtualMma = + IREE::GPU::VirtualMMAAttr::get(context, virtualIntrinsic); + storeMmaInfo(virtualMma, intrinsics, mmaKinds); } } @@ -682,7 +684,7 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, b.getI64ArrayAttr(reductionTileSizes)); IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1}); IREE::GPU::LoweringConfigAttr::setMmaKind(context, attrs, - mmaAttrs[schedule->index]); + mmaKinds[schedule->index]); IREE::GPU::LoweringConfigAttr::setSubgroupMCount( context, attrs, schedule->mSubgroupCounts[0]); IREE::GPU::LoweringConfigAttr::setSubgroupNCount( @@ -759,28 +761,29 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, Value vMatrix = op.getValue(); // Helper fn to store mma information. 
- auto storeMmaInfo = [](IREE::GPU::MMAAttr mma, + auto storeMmaInfo = [](IREE::GPU::MmaInterfaceAttr mma, SmallVector &intrinsics, - SmallVector &mmaAttrs) { + SmallVector &mmaKinds) { auto [mSize, nSize, kSize] = mma.getMNKShape(); auto [aType, bType, cType] = mma.getABCElementTypes(); intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType); - mmaAttrs.emplace_back(mma); + mmaKinds.emplace_back(mma); }; SmallVector intrinsics; intrinsics.reserve(target.getWgp().getMma().size()); - SmallVector mmaAttrs; + SmallVector mmaKinds; MLIRContext *context = op.getContext(); for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { if (mma.getSubgroupSize() != targetSubgroupSize) continue; - storeMmaInfo(mma, intrinsics, mmaAttrs); + storeMmaInfo(mma, intrinsics, mmaKinds); // Store info on virtual intrinsics based on current mma if any - for (IREE::GPU::MMAIntrinsic virtualIntrinsic : + for (IREE::GPU::VirtualMMAIntrinsic virtualIntrinsic : mma.getVirtualIntrinsics()) { - auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic); - storeMmaInfo(virtualMma, intrinsics, mmaAttrs); + auto virtualMma = + IREE::GPU::VirtualMMAAttr::get(context, virtualIntrinsic); + storeMmaInfo(virtualMma, intrinsics, mmaKinds); } } @@ -916,7 +919,7 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, qkConfig, {0, 1}); IREE::GPU::LoweringConfigAttr::setMmaKind(context, qkConfig, - mmaAttrs[schedule->index]); + mmaKinds[schedule->index]); IREE::GPU::LoweringConfigAttr::setSubgroupMCount( context, qkConfig, schedule->mSubgroupCounts[0]); IREE::GPU::LoweringConfigAttr::setSubgroupNCount(context, qkConfig, 1); @@ -924,7 +927,7 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, // Configuring for pv matmul. 
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, pvConfig, {1}); IREE::GPU::LoweringConfigAttr::setMmaKind(context, pvConfig, - mmaAttrs[schedule->index]); + mmaKinds[schedule->index]); IREE::GPU::LoweringConfigAttr::setSubgroupMCount( context, pvConfig, schedule->mSubgroupCounts[0]); IREE::GPU::LoweringConfigAttr::setSubgroupNCount( diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp index 7008f3e1376e5..f9555ef10c5e1 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp @@ -308,14 +308,16 @@ static LogicalResult setAttentionMatmulAnchor(RewriterBase &rewriter, bool reuseIntrinsicOutput = false; bool transposeIntrinsic = false; - auto qkIntrinsic = cast(qkSchedule.getIntrinsic()); - auto pvIntrinsic = cast(pvSchedule.getIntrinsic()); + auto qkIntrinsic = + cast(qkSchedule.getIntrinsic()); + auto pvIntrinsic = + cast(pvSchedule.getIntrinsic()); IREE::GPU::MMASingleSubgroupLayout lhsLayout = - pvIntrinsic.getASingleSubgroupLayout(); + getASingleSubgroupLayout(pvIntrinsic); IREE::GPU::MMASingleSubgroupLayout rhsLayout = - pvIntrinsic.getBSingleSubgroupLayout(); + getBSingleSubgroupLayout(pvIntrinsic); IREE::GPU::MMASingleSubgroupLayout outLayout = - qkIntrinsic.getCSingleSubgroupLayout(); + getCSingleSubgroupLayout(qkIntrinsic); auto matchLayout = [](IREE::GPU::MMASingleSubgroupLayout layoutA, IREE::GPU::MMASingleSubgroupLayout layoutB) -> bool { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir index a38ae5cee7ec4..960fe0b9938c6 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir @@ -657,7 +657,7 @@ hal.executable public @contract_schedule_considering_read_layout { // This test ensures that we can generate and decompose the right instructions from V(Virtual) MFMAs. -#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2}> +#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.virtual_mma_layout, subgroup_m_count = 2, subgroup_n_count = 2}> #translation = #iree_codegen.translation_info}> #pipeline_layout = #hal.pipeline.layout) { // This test ensures we can generate correct instructions from V(Virtual) MFMAs that has only different read layouts. -#config = #iree_gpu.lowering_config<{workgroup = [32, 32, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2}> +#config = #iree_gpu.lowering_config<{workgroup = [32, 32, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.virtual_mma_layout, subgroup_m_count = 2, subgroup_n_count = 2}> #translation = #iree_codegen.translation_info}> #pipeline_layout = #hal.pipeline.layout
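Back-of-the-envelope check for the first gfx940 test above (workgroup tile 64x64, reduction tile K = 128, subgroup_m_count = subgroup_n_count = 2), assuming the virtual 32x32x16 F16 intrinsic with unrollK = 2, which is the decomposition this test is described as exercising, and assuming each subgroup owns a single 32x32 intrinsic tile of the output. This is only an illustration of the expected op counts, not derived from the CHECK lines:

```cpp
// Back-of-the-envelope op count (not IREE code) per subgroup per K-tile of the
// main loop, under the assumptions stated above.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t wgM = 64, wgN = 64, kTile = 128;
  const int64_t subgroupM = 2, subgroupN = 2;
  const int64_t intrinsicM = 32, intrinsicN = 32, intrinsicK = 16;
  const int64_t unrollK = 2;

  const int64_t mBatch = (wgM / subgroupM) / intrinsicM;  // 1
  const int64_t nBatch = (wgN / subgroupN) / intrinsicN;  // 1
  const int64_t kBatch = kTile / intrinsicK;              // 8

  const int64_t virtualOps = mBatch * nBatch * kBatch;  // 8 per subgroup
  const int64_t nativeMfmas = virtualOps * unrollK;     // 16 x mfma 32x32x8
  std::printf("virtual ops per subgroup per K-tile: %lld\n",
              (long long)virtualOps);
  std::printf("native 32x32x8 mfmas per subgroup per K-tile: %lld\n",
              (long long)nativeMfmas);
}
```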