diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index 71a8272d341b..0809252abc9b 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -15,7 +15,7 @@
 
 // GFX942: target = #iree_gpu.target
-// GFX942-SAME: mma = [, , , , , , , , , , , ],
+// GFX942-SAME: mma = [, , , , , , , , , , , , , , , ],
 // GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
 // GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
 // GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647],
@@ -26,7 +26,7 @@
 // GFX941-SAME: features = "+sramecc,-xnack"
 
 // GFX940: target = #iree_gpu.target
-// GFX940-SAME: mma = [, , , , , , , , , , , ],
+// GFX940-SAME: mma = [, , , , , , , , , , , , , , , ],
 
 // GFX1100: target = #iree_gpu.target
 // GFX1100-SAME: mma = [, , , , ]
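The four intrinsics behind the new GFX942/GFX940 list entries (the advertised mma list grows from 12 to 16 entries) share a single 32x32x16 shape and differ only in which FNUZ f8 flavor feeds the LHS and RHS. For reference, a standalone C++ sketch of that table, with names taken from the IREEGPUEnums.td change below and element types from the getABCElementTypes() cases; an illustration, not code in the tree:

// Sketch: shapes and operand element types of the four new CDNA3 MFMA
// intrinsics, mirrored from the getABCElementTypes()/getOpaqueVMMALayout()
// cases in this patch. Not IREE code.
#include <cstdio>

struct MfmaInfo {
  const char *name;
  int m, n, k;
  const char *aType, *bType;
};

constexpr MfmaInfo kNewIntrinsics[] = {
    {"MFMA_F32_32x32x16_F8E5M2FNUZ", 32, 32, 16, "f8E5M2FNUZ", "f8E5M2FNUZ"},
    {"MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ", 32, 32, 16, "f8E5M2FNUZ", "f8E4M3FNUZ"},
    {"MFMA_F32_32x32x16_F8E4M3FNUZ", 32, 32, 16, "f8E4M3FNUZ", "f8E4M3FNUZ"},
    {"MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ", 32, 32, 16, "f8E4M3FNUZ", "f8E5M2FNUZ"},
};

int main() {
  // C is f32 for all four variants; A is m x k, B is k x n.
  for (const MfmaInfo &i : kNewIntrinsics)
    std::printf("%s: f32[%dx%d] += %s[%dx%d] * %s[%dx%d]\n", i.name, i.m, i.n,
                i.aType, i.m, i.k, i.bType, i.k, i.n);
}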
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 4f9b1ffab3b6..ea06da542ce5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -256,6 +256,18 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: {
     return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
   }
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ: {
+    return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
+  }
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ: {
+    return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
+  }
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ: {
+    return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
+  }
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ: {
+    return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
+  }
   case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
     return {i8, i8, i32};
   }
@@ -608,6 +620,10 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
     return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
             /*element=*/{4, 1}};
   }
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ:
   case MMAIntrinsic::MFMA_I32_32x32x16_I8:
     switch (fragment) {
     case MMAFragment::Lhs:
@@ -675,6 +691,8 @@ SmallVector<VirtualMMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
     return {VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16};
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
     return {VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ};
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ:
+    return {VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ};
   default:
     return {};
   }
@@ -1218,6 +1236,9 @@ static OpaqueMmaLayout getOpaqueVMMALayout(MLIRContext *context,
   case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: {
     return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
   }
+  case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: {
+    return OpaqueMmaLayout{32, 32, 16, f8E4M3FNUZ, f8E4M3FNUZ, f32};
+  }
   // V(Virtual)MFMA instructions which have 2 mfma instructions interleaved
   // along the k dimension.
   case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: {
@@ -1252,6 +1273,7 @@ VirtualMMAAttr::getABCVectorTypes() const {
     auto cType = VectorType::get({4}, C);
     return {aType, bType, cType};
   }
+  case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
   case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
     auto aType = VectorType::get({8}, A);
     auto bType = VectorType::get({8}, B);
@@ -1274,6 +1296,7 @@ int64_t VirtualMMAAttr::getSubgroupSize() const {
   switch (getIntrinsic().getValue()) {
   case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
   case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
+  case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
   case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
     return 64;
   }
@@ -1328,7 +1351,8 @@ int64_t VirtualMMAAttr::getUnrollK() const {
   case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
     return 2;
   }
-  case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: {
+  case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
+  case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: {
     return 1;
   }
 }
@@ -1356,6 +1380,7 @@ FailureOr<Value> VirtualMMAAttr::buildMmaOperation(OpBuilder &builder,
   switch (getIntrinsic().getValue()) {
   case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
   case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
+  case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
   case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
     // Generate mfma's for K with unrolled kernels.
     const int64_t unrollKFactor = getUnrollK();
@@ -1394,6 +1419,7 @@ int64_t VirtualMMAAttr::getBlockSize() const {
   switch (getIntrinsic().getValue()) {
   case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
   case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
+  case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
   case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
     return 1;
   }
@@ -1442,6 +1468,18 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
     return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
             /*element=*/{4, 1}};
   }
+  case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
+    switch (fragment) {
+    case MMAFragment::Lhs:
+      return {/*outer=*/{1, 2}, /*thread=*/{32, 2}, /*tstrides=*/{1, 32},
+              /*element=*/{1, 4}};
+    case MMAFragment::Rhs:
+      return {/*outer=*/{2, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
+              /*element=*/{4, 1}};
+    case MMAFragment::Acc:
+      return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
+              /*element=*/{4, 1}};
+    }
   }
   assert(false && "unhandled virtual mma layout type.");
   return {};
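The Acc layout added above (outer {4,1}, thread {2,32}, tstrides {32,1}, element {4,1}) can be sanity-checked by expanding it: 64 lanes times 16 elements must tile the 32x32 accumulator exactly once. A standalone sketch, assuming the outer -> thread -> element nesting that the MMASingleSubgroupLayout fields suggest (that nesting is an assumption, not something stated in this patch):

// Sketch: expand the Acc single-subgroup layout of the new 32x32x16 f8
// intrinsics and verify full, non-overlapping coverage of the fragment.
#include <array>
#include <cassert>
#include <cstdio>

int main() {
  constexpr int kOuter[2] = {4, 1}, kThread[2] = {2, 32};
  constexpr int kTStrides[2] = {32, 1}, kElement[2] = {4, 1};
  constexpr int kRows = kOuter[0] * kThread[0] * kElement[0];  // 32
  constexpr int kCols = kOuter[1] * kThread[1] * kElement[1];  // 32
  std::array<std::array<int, kCols>, kRows> hits{};

  for (int lane = 0; lane < 64; ++lane) {
    // Recover this lane's thread coordinates from its id via tstrides.
    int tRow = (lane / kTStrides[0]) % kThread[0];
    int tCol = (lane / kTStrides[1]) % kThread[1];
    for (int oRow = 0; oRow < kOuter[0]; ++oRow)
      for (int oCol = 0; oCol < kOuter[1]; ++oCol)
        for (int eRow = 0; eRow < kElement[0]; ++eRow)
          for (int eCol = 0; eCol < kElement[1]; ++eCol) {
            int row = (oRow * kThread[0] + tRow) * kElement[0] + eRow;
            int col = (oCol * kThread[1] + tCol) * kElement[1] + eCol;
            ++hits[row][col];
          }
  }
  for (auto &row : hits)
    for (int h : row) assert(h == 1);  // each C element owned exactly once
  std::printf("32x32 accumulator covered; 16 f32 values per lane\n");
}

Each lane ends up holding outer(4) * element(4) = 16 accumulator values, which is exactly the vector<16xf32> C operand the pipeline tests below expect.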
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index ac03a9fa5fa2..b174ceaac916 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -158,6 +158,10 @@ def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ
 def MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ", 0x1231>;
 def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x1232>;
 def MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ", 0x1233>;
+def MFMA_F32_32x32x16_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_32x32x16_F8E5M2FNUZ", 0x1234>;
+def MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ", 0x1235>;
+def MFMA_F32_32x32x16_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_32x32x16_F8E4M3FNUZ", 0x1236>;
+def MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ", 0x1237>;
 def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x12C0>;
 def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 0x12C1>;
@@ -193,6 +197,10 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
     MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ,
     MFMA_F32_16x16x32_F8E4M3FNUZ,
     MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ,
+    MFMA_F32_32x32x16_F8E5M2FNUZ,
+    MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ,
+    MFMA_F32_32x32x16_F8E4M3FNUZ,
+    MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ,
     MFMA_I32_16x16x32_I8,
     MFMA_I32_32x32x16_I8,
@@ -211,12 +219,14 @@
 def VMFMA_F32_16x16x32_F16 : I32EnumAttrCase<"VMFMA_F32_16x16x32_F16", 0>;
 def VMFMA_F32_32x32x16_F16 : I32EnumAttrCase<"VMFMA_F32_32x32x16_F16", 1>;
 def VMFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"VMFMA_F32_16x16x32_F8E4M3FNUZ", 2>;
+def VMFMA_F32_32x32x16_F8E4M3FNUZ : I32EnumAttrCase<"VMFMA_F32_32x32x16_F8E4M3FNUZ", 3>;
 
 def IREEGPU_VirtualMMAIntrinsic : IREEGPU_I32MmaEnumAttr<"VirtualMMAIntrinsic",
     "Descriptor for different Virtual MMA intrinsics", [
       VMFMA_F32_16x16x32_F16,
       VMFMA_F32_32x32x16_F16,
       VMFMA_F32_16x16x32_F8E4M3FNUZ,
+      VMFMA_F32_32x32x16_F8E4M3FNUZ,
     ]>;
 
 def MMA_LHS : I32EnumAttrCase<"Lhs", 0>;
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index e198e216ece7..f4a626a70b42 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -146,6 +146,10 @@ const WgpDetails *getCDNA3WgpDetails() {
       MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ,
       MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ,
       MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ,
+      MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ,
+      MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ,
+      MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ,
+      MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ,
      MMAIntrinsic::MFMA_I32_16x16x32_I8,
      MMAIntrinsic::MFMA_I32_32x32x16_I8,
  };
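The new discriminants simply extend the existing 0x123x block of f8 cases and stay below the integer block that starts at 0x12C0, so no existing encoding moves. A minimal mirror (not the TableGen-generated enum) with static_asserts over the values above:

// Sketch: the four new MMAIntrinsic discriminants continue the f8 block.
#include <cstdint>

enum class MMAIntrinsicValue : uint32_t {
  MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ = 0x1233,  // last pre-existing f8 case
  MFMA_F32_32x32x16_F8E5M2FNUZ = 0x1234,
  MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ = 0x1235,
  MFMA_F32_32x32x16_F8E4M3FNUZ = 0x1236,
  MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ = 0x1237,
  MFMA_I32_16x16x32_I8 = 0x12C0,  // integer block is untouched
};

static_assert(
    static_cast<uint32_t>(MMAIntrinsicValue::MFMA_F32_32x32x16_F8E5M2FNUZ) ==
    static_cast<uint32_t>(
        MMAIntrinsicValue::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ) + 1,
    "new f8 cases are contiguous with the old f8 block");
static_assert(
    static_cast<uint32_t>(
        MMAIntrinsicValue::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ) <
    static_cast<uint32_t>(MMAIntrinsicValue::MFMA_I32_16x16x32_I8),
    "new f8 cases do not collide with the i8 block");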
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index 0eeeafe775ae..5e5096dfa637 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -215,7 +215,7 @@ hal.executable @matmul_multiple_k {
 
 // -----
 
-// Basic f8, f8 -> f32 matmul.
+// Basic f8, f8 -> f32 matmul. (intrinsic with shape m = 16, n = 16, k = 32)
 
 #config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<intrinsic = MFMA_F32_16x16x32_F8E4M3FNUZ>, subgroup_m_count = 2, subgroup_n_count = 2}>
 #translation = #iree_codegen.translation_info}>
@@ -225,15 +225,15 @@
   #hal.pipeline.binding<storage_buffer, ReadOnly>,
   #hal.pipeline.binding<storage_buffer>
 ]>
-hal.executable @matmul_256x256x256_f8_f32 {
+hal.executable @matmul_256x256x256_16x16x32_f8_f32 {
 hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
-  hal.executable.export @matmul_256x256x256_f8_f32 layout(#pipeline_layout) {
+  hal.executable.export @matmul_256x256x256_16x16x32_f8_f32 layout(#pipeline_layout) {
     ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
       %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
       hal.return %x, %y, %z : index, index, index
   }
   builtin.module {
-    func.func @matmul_256x256x256_f8_f32() attributes {translation_info = #translation} {
+    func.func @matmul_256x256x256_16x16x32_f8_f32() attributes {translation_info = #translation} {
       %cst = arith.constant 0.000000e+00 : f32
       %c0 = arith.constant 0 : index
       %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
@@ -253,7 +253,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
 
 // Make sure it generates the mfma instructions we expect for f8 inputs.
 
-// CHECK-LABEL: func.func @matmul_256x256x256_f8_f32()
+// CHECK-LABEL: func.func @matmul_256x256x256_16x16x32_f8_f32()
 // Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
 // along the K dimension. So in total 32 mfma ops.
 // CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
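The CHECK-COUNT values in these tests follow directly from the lowering config: per subgroup, (workgroup M tile / (subgroup_m_count * intrinsic m)) * (workgroup N tile / (subgroup_n_count * intrinsic n)) * (reduction K tile / intrinsic k) mfma ops. A sketch of that arithmetic with a hypothetical helper (this is not IREE's scheduler, just the back-of-the-envelope count):

// Sketch: mfma ops per subgroup for the two f8 matmul tests in this file.
#include <cassert>

constexpr int mfmaOpsPerSubgroup(int wgM, int wgN, int redK, int sgM, int sgN,
                                 int im, int in, int ik) {
  return (wgM / (sgM * im)) * (wgN / (sgN * in)) * (redK / ik);
}

// 16x16x32 intrinsic: 2 * 2 tiles, 8 accumulations -> CHECK-COUNT-32.
static_assert(mfmaOpsPerSubgroup(64, 64, 256, 2, 2, 16, 16, 32) == 32);
// 32x32x16 intrinsic: 1 * 1 tiles, 16 accumulations -> CHECK-COUNT-16 below.
static_assert(mfmaOpsPerSubgroup(64, 64, 256, 2, 2, 32, 32, 16) == 16);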
@@ -307,6 +307,52 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
 
 // -----
 
+// Basic f8, f8 -> f32 matmul. (intrinsic with shape m = 32, n = 32, k = 16)
+
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<intrinsic = MFMA_F32_32x32x16_F8E4M3FNUZ>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer, ReadOnly>,
+  #hal.pipeline.binding<storage_buffer, ReadOnly>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable @matmul_256x256x256_32x32x16_f8_f32 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+  hal.executable.export @matmul_256x256x256_32x32x16_f8_f32 layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+  }
+  builtin.module {
+    func.func @matmul_256x256x256_32x32x16_f8_f32() attributes {translation_info = #translation} {
+      %cst = arith.constant 0.000000e+00 : f32
+      %c0 = arith.constant 0 : index
+      %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
+      %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
+      %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
+      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
+      %5 = tensor.empty() : tensor<256x256xf32>
+      %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf8E4M3FNUZ>, tensor<256x256xf8E4M3FNUZ>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+      return
+    }
+  }
+}
+}
+
+// Make sure it generates the mfma instructions we expect for f8 inputs.
+
+// CHECK-LABEL: func.func @matmul_256x256x256_32x32x16_f8_f32()
+// Each subgroup handles 1 * 1 tiles, and for each tile we accumulate (256/16) = 16 times
+// along the K dimension. So in total 16 mfma ops.
+// CHECK-COUNT-16: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32} blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32>
+// CHECK-COUNT-4: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<256x256xf32, #hal.descriptor_type<storage_buffer>>
+
+// -----
+
 // Basic i8, i8 -> i32 matmul_transpose_b.
 
 #config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2}>
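The operand vector sizes in the CHECK lines also fall out of the intrinsic shape alone: with a 64-lane subgroup, each lane holds m*k/64 A elements, k*n/64 B elements, and m*n/64 C elements. A small sketch verifying both f8 tests above (hypothetical helper, not an IREE API):

// Sketch: per-lane operand element counts for a 64-wide subgroup.
struct LaneCounts { int a, b, c; };

constexpr LaneCounts perLane(int m, int n, int k, int subgroup = 64) {
  return {m * k / subgroup, k * n / subgroup, m * n / subgroup};
}

// 16x16x32: vector<8xf8>, vector<8xf8>, vector<4xf32> in the CHECKs above.
static_assert(perLane(16, 16, 32).a == 8 && perLane(16, 16, 32).c == 4, "");
// 32x32x16: vector<8xf8>, vector<8xf8>, vector<16xf32>.
static_assert(perLane(32, 32, 16).a == 8 && perLane(32, 32, 16).c == 16, "");

Note that both intrinsics read the same 8 f8 elements per lane per operand; the 32x32x16 shape trades a larger accumulator (16 f32 per lane) for half as many K steps.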
@@ -656,6 +702,7 @@ hal.executable public @contract_schedule_considering_read_layout {
 
 // -----
 
 // This test ensures that we can generate and decompose the right instructions from V(Virtual) MFMAs.
+// (intrinsic with shape m = 32, n = 32, k = 16)
 
 #config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.virtual_mma_layout<intrinsic = VMFMA_F32_32x32x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2}>
 #translation = #iree_codegen.translation_info}>
@@ -665,15 +712,15 @@
   #hal.pipeline.binding<storage_buffer, ReadOnly>,
   #hal.pipeline.binding<storage_buffer>
 ]>
-hal.executable @virtual_intrinsic_256x256x256_f16_f32 {
+hal.executable @virtual_intrinsic_256x256x256_32x32x16xf16_f32 {
 hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
-  hal.executable.export @virtual_intrinsic_256x256x256_f16_f32 layout(#pipeline_layout) {
+  hal.executable.export @virtual_intrinsic_256x256x256_32x32x16xf16_f32 layout(#pipeline_layout) {
     ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
       %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
       hal.return %x, %y, %z : index, index, index
   }
   builtin.module {
-    func.func @virtual_intrinsic_256x256x256_f16_f32() attributes {translation_info = #translation} {
+    func.func @virtual_intrinsic_256x256x256_32x32x16xf16_f32() attributes {translation_info = #translation} {
       %cst = arith.constant 0.000000e+00 : f32
       %c0 = arith.constant 0 : index
       %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
@@ -691,7 +738,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
 
-// CHECK-LABEL: func @virtual_intrinsic_256x256x256_f16_f32
+// CHECK-LABEL: func @virtual_intrinsic_256x256x256_32x32x16xf16_f32
 // CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<1x1x4x1x4x1xf32>)
 
 // Validate that VMFMA is decomposed into coalesced read and 2 MFMAs:
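The "decomposed into 2 MFMAs" expectation matches getUnrollK() from the attrs change: the f16 virtual intrinsic covers k = 16 with two native k = 8 ops, while the new f8 virtual intrinsic maps 1:1 onto the native k = 16 op and only changes how operands are read. Sketch of that relation (the native k values are inferred from the MFMA shapes involved, not stated in this patch):

// Sketch: unrollK = virtual K / native K for the virtual intrinsics here.
struct VmfmaLowering { int virtualK, nativeK; };

constexpr int unrollK(VmfmaLowering l) { return l.virtualK / l.nativeK; }

constexpr VmfmaLowering kF16_32x32x16{16, 8};   // two 32x32x8 f16 mfmas
constexpr VmfmaLowering kF8_32x32x16{16, 16};   // one 32x32x16 f8 mfma

static_assert(unrollK(kF16_32x32x16) == 2, "matches 'decomposed into 2 MFMAs'");
static_assert(unrollK(kF8_32x32x16) == 1, "read layout differs, no k split");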
@@ -718,6 +765,55 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
 
 // -----
 
+// This test ensures we can generate correct instructions from V(Virtual) MFMAs that have only different read layouts.
+// (intrinsic with shape m = 32, n = 32, k = 16)
+
+#config = #iree_gpu.lowering_config<{workgroup = [32, 32, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.virtual_mma_layout<intrinsic = VMFMA_F32_32x32x16_F8E4M3FNUZ>, subgroup_m_count = 1, subgroup_n_count = 1}>
+#translation = #iree_codegen.translation_info}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer, ReadOnly>,
+  #hal.pipeline.binding<storage_buffer, ReadOnly>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable @virtual_intrinsic_256x256x256_32x32x16_f8E4M3FNUZ_f32 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+  hal.executable.export @virtual_intrinsic_256x256x256_32x32x16_f8E4M3FNUZ_f32 layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+  }
+  builtin.module {
+    func.func @virtual_intrinsic_256x256x256_32x32x16_f8E4M3FNUZ_f32() attributes {translation_info = #translation} {
+      %cst = arith.constant 0.000000e+00 : f32
+      %c0 = arith.constant 0 : index
+      %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
+      %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
+      %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
+      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
+      %5 = tensor.empty() : tensor<256x256xf32>
+      %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf8E4M3FNUZ>, tensor<256x256xf8E4M3FNUZ>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+      return
+    }
+  }
+}
+}
+
+// Basic pipeline test to make sure it generates the instructions we expect.
+
+// CHECK-LABEL: func.func @virtual_intrinsic_256x256x256_32x32x16_f8E4M3FNUZ_f32()
+// CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<1x1x4x1x4x1xf32>)
+// Each subgroup handles 1 * 1 tiles, and for each tile we accumulate (128 / 16) = 8 times
+// along the K dimension. So in total 8 mfma ops.
+// CHECK-COUNT-8: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32} blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32>
+// CHECK: scf.yield %{{.+}} : vector<1x1x4x1x4x1xf32>
+// CHECK-COUNT-4: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<256x256xf32, #hal.descriptor_type<storage_buffer>>
+
+// -----
+
 // This test ensures we can generate correct instructions from V(Virtual) MFMAs that has only different read layouts.
 
 #config = #iree_gpu.lowering_config<{workgroup = [32, 32, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.virtual_mma_layout<intrinsic = VMFMA_F32_16x16x32_F8E4M3FNUZ>, subgroup_m_count = 2, subgroup_n_count = 2}>
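A quick consistency check on the three fragment layouts added for VMFMA_F32_32x32x16_F8E4M3FNUZ: along each dimension, outer * thread * element must recover the fragment's shape, M x K for Lhs, K x N for Rhs, and M x N for Acc. Standalone sketch with the numbers copied from the getSingleSubgroupLayout() change above:

// Sketch: fragment shapes implied by the new virtual intrinsic's layouts.
struct SubgroupLayout { int outer[2], thread[2], element[2]; };

constexpr int dim(const SubgroupLayout &l, int d) {
  return l.outer[d] * l.thread[d] * l.element[d];
}

constexpr SubgroupLayout lhs{{1, 2}, {32, 2}, {1, 4}};  // M x K
constexpr SubgroupLayout rhs{{2, 1}, {2, 32}, {4, 1}};  // K x N
constexpr SubgroupLayout acc{{4, 1}, {2, 32}, {4, 1}};  // M x N

static_assert(dim(lhs, 0) == 32 && dim(lhs, 1) == 16, "Lhs is 32x16");
static_assert(dim(rhs, 0) == 16 && dim(rhs, 1) == 32, "Rhs is 16x32");
static_assert(dim(acc, 0) == 32 && dim(acc, 1) == 32, "Acc is 32x32");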
@@ -728,15 +824,15 @@
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer, ReadOnly>,
   #hal.pipeline.binding<storage_buffer, ReadOnly>,
   #hal.pipeline.binding<storage_buffer>
 ]>
-hal.executable @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32 {
+hal.executable @virtual_intrinsic_256x256x256_16x16x32xf8E4M3FNUZ_f32 {
 hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
-  hal.executable.export @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32 layout(#pipeline_layout) {
+  hal.executable.export @virtual_intrinsic_256x256x256_16x16x32xf8E4M3FNUZ_f32 layout(#pipeline_layout) {
     ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
       %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
       hal.return %x, %y, %z : index, index, index
   }
   builtin.module {
-    func.func @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32() attributes {translation_info = #translation} {
+    func.func @virtual_intrinsic_256x256x256_16x16x32xf8E4M3FNUZ_f32() attributes {translation_info = #translation} {
       %cst = arith.constant 0.000000e+00 : f32
       %c0 = arith.constant 0 : index
       %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
@@ -754,7 +850,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
 }
 }
 
-// CHECK-LABEL: func @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32
+// CHECK-LABEL: func @virtual_intrinsic_256x256x256_16x16x32xf8E4M3FNUZ_f32
 // CHECK-DAG: %[[ALLOC_LHS:.+]] = memref.alloc() : memref<32x136xf8E4M3FNUZ, #gpu.address_space<workgroup>>
 // CHECK-DAG: %[[ALLOC_RHS:.+]] = memref.alloc() : memref<128x40xf8E4M3FNUZ, #gpu.address_space<workgroup>>
 // CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<1x1x1x1x4x1xf32>)
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 5782120b0954..a97b5626c069 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -339,6 +339,10 @@ def get_rocm_test_compilation_infos(
             MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 2, 2, 1, 1, 2),
             MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 4, 1, 4, 1, 1),
             MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 4, 2, 4, 2, 1),
+            MMASchedule("MFMA_F32_32x32x16_F8E4M3FNUZ", 1, 1, 1, 1, 1),
+            MMASchedule("MFMA_F32_32x32x16_F8E4M3FNUZ", 2, 2, 1, 1, 2),
+            MMASchedule("MFMA_F32_32x32x16_F8E4M3FNUZ", 4, 1, 1, 2, 2),
+            MMASchedule("MFMA_F32_32x32x16_F8E4M3FNUZ", 4, 2, 2, 2, 2),
             MMASchedule("MFMA_I32_16x16x32_I8", 1, 1, 1, 1, 1),
             MMASchedule("MFMA_I32_16x16x32_I8", 2, 2, 1, 1, 2),
             MMASchedule("MFMA_I32_16x16x32_I8", 4, 1, 4, 1, 1),
@@ -409,6 +413,7 @@ def get_rocm_test_compilation_infos(
         wg_tile_k = schedule.k_tile_count * 32
     elif (
         schedule.intrinsic == "VMFMA_F32_32x32x16_F16"
+        or schedule.intrinsic == "MFMA_F32_32x32x16_F8E4M3FNUZ"
         or schedule.intrinsic == "MFMA_I32_32x32x16_I8"
     ):
         wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
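For the new 32x32x16 schedules, the generator scales workgroup tiles by the subgroup count times the per-subgroup tile count times the intrinsic dimension. A C++ mirror of that arithmetic; the MMASchedule field order, the N-tile formula, and the K scale of 16 are assumptions read off the Python above, which is truncated after the wg_tile_m line:

// Sketch: workgroup tile sizes for a 32x32x16 schedule, mirroring the
// Python generator's logic under the assumptions stated above.
struct Schedule {
  int m_count, n_count, m_tile_count, n_tile_count, k_tile_count;
};

constexpr int wgTileM(Schedule s) { return s.m_count * s.m_tile_count * 32; }
constexpr int wgTileN(Schedule s) { return s.n_count * s.n_tile_count * 32; }
constexpr int wgTileK(Schedule s) { return s.k_tile_count * 16; }  // k = 16

// The largest new schedule, MMASchedule("MFMA_F32_32x32x16_F8E4M3FNUZ",
// 4, 2, 2, 2, 2), would tile each workgroup as 256 x 128 with K steps of 32:
static_assert(wgTileM({4, 2, 2, 2, 2}) == 256, "");
static_assert(wgTileN({4, 2, 2, 2, 2}) == 128, "");
static_assert(wgTileK({4, 2, 2, 2, 2}) == 32, "");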