From 8c92a77c89a68b5d4f2935e8183330c62ee2498c Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Tue, 12 Nov 2024 10:50:00 -0800 Subject: [PATCH 1/2] [LLVMGPU] Add 32x32x16 F8 MFMA intrinsic To enable faster SDXL on attention we'd need different FP8 MFMA intrinsics. This 32x32x16 FP8 intrinsic (and virtual intrinsic for 2nd matmul) has been especially performant when used on this SDXL attention shape (B0: 2, B1: 10, (M, K2): 4096: K1: 64). Signed-off-by: Stanley Winata --- .../target/ROCM/test/target_device_features.mlir | 4 ++-- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 16 ++++++++++++++++ .../Codegen/Dialect/GPU/IR/IREEGPUEnums.td | 8 ++++++++ .../Dialect/GPU/TargetUtils/KnownTargets.cpp | 4 ++++ tests/e2e/matmul/generate_e2e_matmul_tests.py | 5 +++++ 5 files changed, 35 insertions(+), 2 deletions(-) diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir index 71a8272d341b..0809252abc9b 100644 --- a/compiler/plugins/target/ROCM/test/target_device_features.mlir +++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir @@ -15,7 +15,7 @@ // GFX942: target = #iree_gpu.target, , , , , , , , , , , ], +// GFX942-SAME: mma = [, , , , , , , , , , , , , , , ], // GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], // GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, // GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647], @@ -26,7 +26,7 @@ // GFX941-SAME: features = "+sramecc,-xnack" // GFX940: target = #iree_gpu.target, , , , , , , , , , , ], +// GFX940-SAME: mma = [, , , , , , , , , , , , , , , ], // GFX1100: target = #iree_gpu.target, , , , ] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 4f9b1ffab3b6..544acb3a67c6 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -256,6 +256,18 @@ static std::tuple getABCElementTypes(MLIRContext *context, case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: { return {f8E5M2FNUZ, f8E4M3FNUZ, f32}; } + case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ: { + return {f8E4M3FNUZ, f8E4M3FNUZ, f32}; + } + case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ: { + return {f8E5M2FNUZ, f8E5M2FNUZ, f32}; + } + case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ: { + return {f8E4M3FNUZ, f8E5M2FNUZ, f32}; + } + case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ: { + return {f8E5M2FNUZ, f8E4M3FNUZ, f32}; + } case MMAIntrinsic::MFMA_I32_16x16x32_I8: { return {i8, i8, i32}; } @@ -608,6 +620,10 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1}, /*element=*/{4, 1}}; } + case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ: + case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ: + case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ: + case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ: case MMAIntrinsic::MFMA_I32_32x32x16_I8: switch (fragment) { case MMAFragment::Lhs: diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td index ac03a9fa5fa2..5391c5c3c69b 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td @@ -158,6 +158,10 @@ def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ def MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ", 0x1231>; def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x1232>; def MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ", 0x1233>; +def MFMA_F32_32x32x16_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_32x32x16_F8E5M2FNUZ", 0x1234>; +def MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ", 0x1235>; +def MFMA_F32_32x32x16_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_32x32x16_F8E4M3FNUZ", 0x1236>; +def MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ", 0x1237>; def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x12C0>; def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 0x12C1>; @@ -193,6 +197,10 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic", MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ, MFMA_F32_16x16x32_F8E4M3FNUZ, MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ, + MFMA_F32_32x32x16_F8E5M2FNUZ, + MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ, + MFMA_F32_32x32x16_F8E4M3FNUZ, + MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ, MFMA_I32_16x16x32_I8, MFMA_I32_32x32x16_I8, diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp index e198e216ece7..f4a626a70b42 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp @@ -146,6 +146,10 @@ const WgpDetails *getCDNA3WgpDetails() { MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ, MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ, MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ, + MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ, + MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ, + MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ, + MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ, MMAIntrinsic::MFMA_I32_16x16x32_I8, MMAIntrinsic::MFMA_I32_32x32x16_I8, }; diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py index 5782120b0954..a97b5626c069 100644 --- a/tests/e2e/matmul/generate_e2e_matmul_tests.py +++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py @@ -339,6 +339,10 @@ def get_rocm_test_compilation_infos( MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 2, 2, 1, 1, 2), MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 4, 1, 4, 1, 1), MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 4, 2, 4, 2, 1), + MMASchedule("MFMA_F32_32x32x16_F8E4M3FNUZ", 1, 1, 1, 1, 1), + MMASchedule("MFMA_F32_32x32x16_F8E4M3FNUZ", 2, 2, 1, 1, 2), + MMASchedule("MFMA_F32_32x32x16_F8E4M3FNUZ", 4, 1, 1, 2, 2), + MMASchedule("MFMA_F32_32x32x16_F8E4M3FNUZ", 4, 2, 2, 2, 2), MMASchedule("MFMA_I32_16x16x32_I8", 1, 1, 1, 1, 1), MMASchedule("MFMA_I32_16x16x32_I8", 2, 2, 1, 1, 2), MMASchedule("MFMA_I32_16x16x32_I8", 4, 1, 4, 1, 1), @@ -409,6 +413,7 @@ def get_rocm_test_compilation_infos( wg_tile_k = schedule.k_tile_count * 32 elif ( schedule.intrinsic == "VMFMA_F32_32x32x16_F16" + or schedule.intrinsic == "MFMA_F32_32x32x16_F8E4M3FNUZ" or schedule.intrinsic == "MFMA_I32_32x32x16_I8" ): wg_tile_m = schedule.m_count * schedule.m_tile_count * 32 From 854e675d03287455f9975694a5d4138274ac14d8 Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Tue, 12 Nov 2024 13:23:48 -0800 Subject: [PATCH 2/2] add virtual mfma for 32x32x16xf8e4m3 Signed-off-by: Stanley Winata --- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 24 +++- .../Codegen/Dialect/GPU/IR/IREEGPUEnums.td | 2 + .../pipeline_vector_distribute_gfx940.mlir | 122 ++++++++++++++++-- 3 files changed, 134 insertions(+), 14 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 544acb3a67c6..ea06da542ce5 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -691,6 +691,8 @@ SmallVector MMAAttr::getVirtualIntrinsics() const { return {VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16}; case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: return {VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ}; + case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ: + return {VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ}; default: return {}; } @@ -1234,6 +1236,9 @@ static OpaqueMmaLayout getOpaqueVMMALayout(MLIRContext *context, case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: { return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32}; } + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: { + return OpaqueMmaLayout{32, 32, 16, f8E4M3FNUZ, f8E4M3FNUZ, f32}; + } // V(Virtual)MFMA instructions which have 2 mfma instructions interleaved // along the k dimension. case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: { @@ -1268,6 +1273,7 @@ VirtualMMAAttr::getABCVectorTypes() const { auto cType = VectorType::get({4}, C); return {aType, bType, cType}; } + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { auto aType = VectorType::get({8}, A); auto bType = VectorType::get({8}, B); @@ -1290,6 +1296,7 @@ int64_t VirtualMMAAttr::getSubgroupSize() const { switch (getIntrinsic().getValue()) { case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { return 64; } @@ -1344,7 +1351,8 @@ int64_t VirtualMMAAttr::getUnrollK() const { case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { return 2; } - case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: { + case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: { return 1; } } @@ -1372,6 +1380,7 @@ FailureOr VirtualMMAAttr::buildMmaOperation(OpBuilder &builder, switch (getIntrinsic().getValue()) { case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { // Generate mfma's for K with unrolled kernels. const int64_t unrollKFactor = getUnrollK(); @@ -1410,6 +1419,7 @@ int64_t VirtualMMAAttr::getBlockSize() const { switch (getIntrinsic().getValue()) { case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: { return 1; } @@ -1458,6 +1468,18 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic, return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1}, /*element=*/{4, 1}}; } + case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: + switch (fragment) { + case MMAFragment::Lhs: + return {/*outer=*/{1, 2}, /*thread=*/{32, 2}, /*tstrides=*/{1, 32}, + /*element=*/{1, 4}}; + case MMAFragment::Rhs: + return {/*outer=*/{2, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1}, + /*element=*/{4, 1}}; + case MMAFragment::Acc: + return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1}, + /*element=*/{4, 1}}; + } } assert(false && "unhandled virtual mma layout type."); return {}; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td index 5391c5c3c69b..b174ceaac916 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td @@ -219,12 +219,14 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic", def VMFMA_F32_16x16x32_F16 : I32EnumAttrCase<"VMFMA_F32_16x16x32_F16", 0>; def VMFMA_F32_32x32x16_F16 : I32EnumAttrCase<"VMFMA_F32_32x32x16_F16", 1>; def VMFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"VMFMA_F32_16x16x32_F8E4M3FNUZ", 2>; +def VMFMA_F32_32x32x16_F8E4M3FNUZ : I32EnumAttrCase<"VMFMA_F32_32x32x16_F8E4M3FNUZ", 3>; def IREEGPU_VirtualMMAIntrinsic : IREEGPU_I32MmaEnumAttr<"VirtualMMAIntrinsic", "Descriptor for different Virtual MMA intrinsics", [ VMFMA_F32_16x16x32_F16, VMFMA_F32_32x32x16_F16, VMFMA_F32_16x16x32_F8E4M3FNUZ, + VMFMA_F32_32x32x16_F8E4M3FNUZ, ]>; def MMA_LHS : I32EnumAttrCase<"Lhs", 0>; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir index 0eeeafe775ae..5e5096dfa637 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir @@ -215,7 +215,7 @@ hal.executable @matmul_multiple_k { // ----- -// Basic f8, f8 -> f32 matmul. +// Basic f8, f8 -> f32 matmul. (intrinsic with shape, m = 16, n = 16, k = 32) #config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2}> #translation = #iree_codegen.translation_info}> @@ -225,15 +225,15 @@ hal.executable @matmul_multiple_k { #hal.pipeline.binding, #hal.pipeline.binding ]> -hal.executable @matmul_256x256x256_f8_f32 { +hal.executable @matmul_256x256x256_16x16x32_f8_f32 { hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { - hal.executable.export @matmul_256x256x256_f8_f32 layout(#pipeline_layout) { + hal.executable.export @matmul_256x256x256_16x16x32_f8_f32 layout(#pipeline_layout) { ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index): %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @matmul_256x256x256_f8_f32() attributes {translation_info = #translation} { + func.func @matmul_256x256x256_16x16x32_f8_f32() attributes {translation_info = #translation} { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> @@ -253,7 +253,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // Make sure it generates the mfma instructions we expect for f8 inputs. -// CHECK-LABEL: func.func @matmul_256x256x256_f8_f32() +// CHECK-LABEL: func.func @matmul_256x256x256_16x16x32_f8_f32() // Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times // along the K dimension. So in total 32 mfma ops. // CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32> @@ -307,6 +307,52 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // ----- +// Basic f8, f8 -> f32 matmul. (intrinsic with shape, m = 32, n = 32, k = 16) + +#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2}> +#translation = #iree_codegen.translation_info}> + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> +hal.executable @matmul_256x256x256_32x32x16_f8_f32 { +hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { + hal.executable.export @matmul_256x256x256_32x32x16_f8_f32 layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @matmul_256x256x256_32x32x16_f8_f32() attributes {translation_info = #translation} { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf8E4M3FNUZ> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf8E4M3FNUZ> + %5 = tensor.empty() : tensor<256x256xf32> + %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32> + %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf8E4M3FNUZ>, tensor<256x256xf8E4M3FNUZ>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor> + return + } + } +} +} + +// Make sure it generates the mfma instructions we expect for f8 inputs. + +// CHECK-LABEL: func.func @matmul_256x256x256_32x32x16_f8_f32() +// Each subgroup handles 1 * 1 tiles, and for each tile we accumulate (256/16) = 16 times +// along the K dimension. So in total 16 mfma ops. +// CHECK-COUNT-16: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32} blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32> +// CHECK-COUNT-4: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<256x256xf32, #hal.descriptor_type> + +// ----- + // Basic i8, i8 -> i32 matmul_transpose_b. #config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2}> @@ -656,6 +702,7 @@ hal.executable public @contract_schedule_considering_read_layout { // ----- // This test ensures that we can generate and decompose the right instructions from V(Virtual) MFMAs. +// (intrinsic with shape, m = 32, n = 32, k = 16) #config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.virtual_mma_layout, subgroup_m_count = 2, subgroup_n_count = 2}> #translation = #iree_codegen.translation_info}> @@ -665,15 +712,15 @@ hal.executable public @contract_schedule_considering_read_layout { #hal.pipeline.binding, #hal.pipeline.binding ]> -hal.executable @virtual_intrinsic_256x256x256_f16_f32 { +hal.executable @virtual_intrinsic_256x256x256_16x16x32xf8E4M3FNUZ_f32 { hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { - hal.executable.export @virtual_intrinsic_256x256x256_f16_f32 layout(#pipeline_layout) { + hal.executable.export @virtual_intrinsic_256x256x256_16x16x32xf8E4M3FNUZ_f32 layout(#pipeline_layout) { ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index): %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @virtual_intrinsic_256x256x256_f16_f32() attributes {translation_info = #translation} { + func.func @virtual_intrinsic_256x256x256_16x16x32xf8E4M3FNUZ_f32() attributes {translation_info = #translation} { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> @@ -691,7 +738,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { } } -// CHECK-LABEL: func @virtual_intrinsic_256x256x256_f16_f32 +// CHECK-LABEL: func @virtual_intrinsic_256x256x256_16x16x32xf8E4M3FNUZ_f32 // CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<1x1x4x1x4x1xf32>) // Validate that VMFMA is decomposed into coalesced read and 2 MFMAs: @@ -718,6 +765,55 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // ----- +// This test ensures we can generate correct instructions from V(Virtual) MFMAs that has only different read layouts. +// (intrinsic with shape m = 16, n = 16, k = 32) + +#config = #iree_gpu.lowering_config<{workgroup = [32, 32, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.virtual_mma_layout, subgroup_m_count = 1, subgroup_n_count = 1}> +#translation = #iree_codegen.translation_info}> + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> +hal.executable @virtual_intrinsic_256x256x256_32x32x16_f8E4M3FNUZ_f32 { +hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { + hal.executable.export @virtual_intrinsic_256x256x256_32x32x16_f8E4M3FNUZ_f32 layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @virtual_intrinsic_256x256x256_32x32x16_f8E4M3FNUZ_f32() attributes {translation_info = #translation} { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf8E4M3FNUZ> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf8E4M3FNUZ> + %5 = tensor.empty() : tensor<256x256xf32> + %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32> + %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf8E4M3FNUZ>, tensor<256x256xf8E4M3FNUZ>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor> + return + } + } +} +} + +// Basic pipeline test to make sure it generates the instructions we expect. + +// CHECK-LABEL: func.func @virtual_intrinsic_256x256x256_32x32x16_f8E4M3FNUZ_f32() +// CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<1x1x4x1x4x1xf32>) +// Each subgroup handles 1 * 1 tiles, and for each tile we accumulate (128 / 16) = 8 times +// along the K dimension. So in total 8 mfma ops. +// CHECK-COUNT-8: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32} blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32> +// CHECK: scf.yield %{{.+}} : vector<1x1x4x1x4x1xf32> +// CHECK-COUNT-4: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<256x256xf32, #hal.descriptor_type> + +// ----- + // This test ensures we can generate correct instructions from V(Virtual) MFMAs that has only different read layouts. #config = #iree_gpu.lowering_config<{workgroup = [32, 32, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.virtual_mma_layout, subgroup_m_count = 2, subgroup_n_count = 2}> @@ -728,15 +824,15 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { #hal.pipeline.binding, #hal.pipeline.binding ]> -hal.executable @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32 { +hal.executable @virtual_intrinsic_256x256x256_16x16x32xf8E4M3FNUZ_f32 { hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { - hal.executable.export @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32 layout(#pipeline_layout) { + hal.executable.export @virtual_intrinsic_256x256x256_16x16x32xf8E4M3FNUZ_f32 layout(#pipeline_layout) { ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index): %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32() attributes {translation_info = #translation} { + func.func @virtual_intrinsic_256x256x256_16x16x32xf8E4M3FNUZ_f32() attributes {translation_info = #translation} { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> @@ -754,7 +850,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { } } -// CHECK-LABEL: func @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32 +// CHECK-LABEL: func @virtual_intrinsic_256x256x256_16x16x32xf8E4M3FNUZ_f32 // CHECK-DAG: %[[ALLOC_LHS:.+]] = memref.alloc() : memref<32x136xf8E4M3FNUZ, #gpu.address_space> // CHECK-DAG: %[[ALLOC_RHS:.+]] = memref.alloc() : memref<128x40xf8E4M3FNUZ, #gpu.address_space> // CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<1x1x1x1x4x1xf32>)