Add new MMAIntrinsic's
Signed-off-by: Benoit Jacob <[email protected]>
bjacob committed Nov 12, 2024
1 parent 5b9c4d9 commit ab2aa4c
Showing 5 changed files with 133 additions and 8 deletions.
6 changes: 3 additions & 3 deletions compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -15,7 +15,7 @@
// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647],
@@ -26,10 +26,10 @@
// GFX941-SAME: features = "+sramecc,-xnack"

// GFX940: target = #iree_gpu.target<arch = "gfx940",
// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],

// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
// GFX1100-SAME: mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>]
// GFX1100-SAME: mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_F32_16x16x16_BF16>, <WMMA_BF16_16x16x16_BF16>, <WMMA_I32_16x16x16_I8>]
// GFX1100-SAME: subgroup_size_choices = [32, 64]

stream.executable public @reduce_dispatch {
63 changes: 63 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -216,9 +216,13 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
Type f16 = Float16Type::get(context);
Type bf16 = BFloat16Type::get(context);
Type f32 = Float32Type::get(context);
Type f64 = Float64Type::get(context);
Type i8 = IntegerType::get(context, 8);
Type i32 = IntegerType::get(context, 32);
switch (intrinsic) {
case MMAIntrinsic::MFMA_F64_16x16x4_F64: {
return {f64, f64, f64};
}
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
return {f32, f32, f32};
}
@@ -228,6 +232,12 @@
case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
return {f16, f16, f32};
}
case MMAIntrinsic::MFMA_F32_16x16x8_BF16: {
return {bf16, bf16, f32};
}
case MMAIntrinsic::MFMA_F32_32x32x4_BF16: {
return {bf16, bf16, f32};
}
case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
return {bf16, bf16, f32};
}
@@ -240,6 +250,12 @@
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: {
return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
}
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ: {
return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
}
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: {
return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
}
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
return {i8, i8, i32};
}
@@ -258,6 +274,12 @@
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {f16, f16, f16};
}
case MMAIntrinsic::WMMA_F32_16x16x16_BF16: {
return {bf16, bf16, f32};
}
case MMAIntrinsic::WMMA_BF16_16x16x16_BF16: {
return {bf16, bf16, bf16};
}
case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return {i8, i8, i32};
}
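A quick way to read the new mixed-type cases above: the intrinsic name follows <accumulator type>_<MxNxK>_<A type>[_<B type>], and when only one operand suffix is present the B type matches the A type. The standalone sketch below spells out that convention for the four f8 variants; it is illustrative only (the enum and string names are made up for this example, not IREE code).

#include <cstdio>

// Hypothetical mirror of the f8 cases added to getABCElementTypes() above,
// using strings instead of MLIR types so the example stays self-contained.
enum class F8Mfma {
  F32_16x16x32_F8E4M3FNUZ,            // A = B = f8E4M3FNUZ
  F32_16x16x32_F8E5M2FNUZ,            // A = B = f8E5M2FNUZ
  F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ, // A = f8E4M3FNUZ, B = f8E5M2FNUZ
  F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ, // A = f8E5M2FNUZ, B = f8E4M3FNUZ
};

struct AbcNames { const char *a, *b, *c; };

AbcNames abcNames(F8Mfma kind) {
  switch (kind) {
  case F8Mfma::F32_16x16x32_F8E4M3FNUZ:
    return {"f8E4M3FNUZ", "f8E4M3FNUZ", "f32"};
  case F8Mfma::F32_16x16x32_F8E5M2FNUZ:
    return {"f8E5M2FNUZ", "f8E5M2FNUZ", "f32"};
  case F8Mfma::F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ:
    return {"f8E4M3FNUZ", "f8E5M2FNUZ", "f32"};
  case F8Mfma::F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ:
    return {"f8E5M2FNUZ", "f8E4M3FNUZ", "f32"};
  }
  return {nullptr, nullptr, nullptr};
}

int main() {
  AbcNames t = abcNames(F8Mfma::F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ);
  std::printf("A=%s B=%s C=%s\n", t.a, t.b, t.c);  // A=f8E4M3FNUZ B=f8E5M2FNUZ C=f32
  return 0;
}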
@@ -505,6 +527,43 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
/*element=*/{4, 1}};
}
case MMAIntrinsic::MFMA_F64_16x16x4_F64:
switch (fragment) {
case MMAFragment::Lhs:
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16},
/*element=*/{1, 1}};
case MMAFragment::Rhs:
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
/*element=*/{1, 1}};
case MMAFragment::Acc:
return {/*outer=*/{4, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
/*element=*/{1, 1}};
}
case MMAIntrinsic::MFMA_F32_16x16x8_BF16: {
switch (fragment) {
case MMAFragment::Lhs:
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16},
/*element=*/{1, 2}};
case MMAFragment::Rhs:
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
/*element=*/{2, 1}};
case MMAFragment::Acc:
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
/*element=*/{4, 1}};
}
}
case MMAIntrinsic::MFMA_F32_32x32x4_BF16:
switch (fragment) {
case MMAFragment::Lhs:
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*tstrides=*/{1, 32},
/*element=*/{1, 2}};
case MMAFragment::Rhs:
return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
/*element=*/{2, 1}};
case MMAFragment::Acc:
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
/*element=*/{4, 1}};
}
case MMAIntrinsic::MFMA_I32_16x16x16_I8:
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
@@ -535,6 +594,8 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
}
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
switch (fragment) {
case MMAFragment::Lhs:
@@ -560,6 +621,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
/*element=*/{4, 1}};
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_F32_16x16x16_BF16:
case MMAIntrinsic::WMMA_I32_16x16x16_I8:
switch (fragment) {
case MMAFragment::Lhs:
@@ -573,6 +635,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
/*element=*/{1, 1}};
}
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_BF16_16x16x16_BF16:
switch (fragment) {
case MMAFragment::Lhs:
return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*strides=*/{1, 0},
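As a cross-check on the layouts added above: the per-fragment shape of a single-subgroup layout is the elementwise product outer * thread * element, and for MFMA the thread grid covers a 64-lane subgroup. The standalone sketch below verifies this for the MFMA_F64_16x16x4_F64 values; the struct and function names are simplified stand-ins for illustration, not the actual IREE types.

#include <array>
#include <cassert>
#include <cstdio>

// Reduced stand-in for MMASingleSubgroupLayout: per-dimension counts only.
struct SubgroupLayout {
  std::array<int, 2> outer, thread, element;
};

// Fragment shape = outer * thread * element in each dimension.
std::array<int, 2> fragmentShape(const SubgroupLayout &l) {
  return {l.outer[0] * l.thread[0] * l.element[0],
          l.outer[1] * l.thread[1] * l.element[1]};
}

int main() {
  // Values copied from the MFMA_F64_16x16x4_F64 cases above (M = N = 16, K = 4).
  SubgroupLayout lhs = {{1, 1}, {16, 4}, {1, 1}};  // A fragment is M x K.
  SubgroupLayout acc = {{4, 1}, {4, 16}, {1, 1}};  // C fragment is M x N.
  std::array<int, 2> a = fragmentShape(lhs);
  std::array<int, 2> c = fragmentShape(acc);
  assert(a[0] == 16 && a[1] == 4);               // 16x4 A tile
  assert(c[0] == 16 && c[1] == 16);              // 16x16 accumulator tile
  assert(lhs.thread[0] * lhs.thread[1] == 64);   // one 64-lane MFMA subgroup
  std::printf("A = %dx%d, C = %dx%d\n", a[0], a[1], c[0], c[1]);
  return 0;
}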
18 changes: 18 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -146,17 +146,26 @@ def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 0x1021>;
def MFMA_I32_16x16x16_I8 : I32EnumAttrCase<"MFMA_I32_16x16x16_I8", 0x10C0>;
def MFMA_I32_32x32x8_I8 : I32EnumAttrCase<"MFMA_I32_32x32x8_I8", 0x10C1>;

// Introduced in CDNA2
def MFMA_F32_16x16x8_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x8_BF16", 0x1120>;
def MFMA_F32_32x32x4_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x4_BF16", 0x1121>;
def MFMA_F64_16x16x4_F64 : I32EnumAttrCase<"MFMA_F64_16x16x4_F64", 0x1100>;

// Introduced in CDNA3
def MFMA_F32_16x16x16_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x1220>;
def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x1221>;
def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x1230>;
def MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ", 0x1231>;
def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x1232>;
def MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ", 0x1233>;
def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x12C0>;
def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 0x12C1>;

// Introduced in RDNA3
def WMMA_F32_16x16x16_F16 : I32EnumAttrCase<"WMMA_F32_16x16x16_F16", 0x1820>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 0x1821>;
def WMMA_F32_16x16x16_BF16 : I32EnumAttrCase<"WMMA_F32_16x16x16_BF16", 0x1822>;
def WMMA_BF16_16x16x16_BF16 : I32EnumAttrCase<"WMMA_BF16_16x16x16_BF16", 0x1823>;
def WMMA_I32_16x16x16_I8 : I32EnumAttrCase<"WMMA_I32_16x16x16_I8", 0x18C0>;

// NV intrinsics
@@ -172,17 +181,26 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
MFMA_I32_16x16x16_I8,
MFMA_I32_32x32x8_I8,

// Introduced in CDNA2
MFMA_F32_16x16x8_BF16,
MFMA_F32_32x32x4_BF16,
MFMA_F64_16x16x4_F64,

// Introduced in CDNA3
MFMA_F32_16x16x16_BF16,
MFMA_F32_32x32x8_BF16,
MFMA_F32_16x16x32_F8E5M2FNUZ,
MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ,
MFMA_F32_16x16x32_F8E4M3FNUZ,
MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ,
MFMA_I32_16x16x32_I8,
MFMA_I32_32x32x16_I8,

// RDNA3 intrinsics
WMMA_F32_16x16x16_F16,
WMMA_F16_16x16x16_F16,
WMMA_F32_16x16x16_BF16,
WMMA_BF16_16x16x16_BF16,
WMMA_I32_16x16x16_I8,

// NV intrinsics
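The new enum values keep to the existing numbering pattern, in which the high byte appears to group intrinsics by the architecture generation that introduced them (0x10 CDNA1, 0x11 CDNA2, 0x12 CDNA3, 0x18 RDNA3). That is an observation about the constants listed above, not a documented API; a minimal sketch under that assumption:

#include <cstdint>
#include <cstdio>

// Assumed decoding of the I32EnumAttrCase values above; illustrative only.
const char *introducedIn(uint32_t mmaIntrinsicValue) {
  switch (mmaIntrinsicValue >> 8) {
  case 0x10: return "CDNA1";
  case 0x11: return "CDNA2";
  case 0x12: return "CDNA3";
  case 0x18: return "RDNA3";
  default:   return "unknown";
  }
}

int main() {
  std::printf("%s\n", introducedIn(0x1100));  // MFMA_F64_16x16x4_F64 -> CDNA2
  std::printf("%s\n", introducedIn(0x1823));  // WMMA_BF16_16x16x16_BF16 -> RDNA3
  return 0;
}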
@@ -133,13 +133,19 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,

const WgpDetails *getCDNA3WgpDetails() {
static const MMAIntrinsic cdna3MMAOps[] = {
// Introduced in CDNA1, still present in CDNA3
MMAIntrinsic::MFMA_F32_16x16x4_F32,
MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_32x32x8_F16,
// Introduced in CDNA2, still present in CDNA3
MMAIntrinsic::MFMA_F64_16x16x4_F64,
// Introduced in CDNA3
MMAIntrinsic::MFMA_F32_16x16x16_BF16,
MMAIntrinsic::MFMA_F32_32x32x8_BF16,
MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ,
MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ,
MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ,
MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ,
MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ,
MMAIntrinsic::MFMA_I32_16x16x32_I8,
MMAIntrinsic::MFMA_I32_32x32x16_I8,
};
@@ -162,10 +168,16 @@ const WgpDetails *getCDNA3WgpDetails() {

const WgpDetails *getCDNA2WgpDetails() {
static const MMAIntrinsic cdna2MMAOps[] = {
// Introduced in CDNA1
MMAIntrinsic::MFMA_F32_16x16x4_F32,
MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_32x32x8_F16,
MMAIntrinsic::MFMA_I32_16x16x16_I8,
MMAIntrinsic::MFMA_I32_32x32x8_I8,
// Introduced in CDNA2
MMAIntrinsic::MFMA_F32_16x16x8_BF16,
MMAIntrinsic::MFMA_F32_32x32x4_BF16,
MMAIntrinsic::MFMA_F64_16x16x4_F64,
};
static const WgpDetails cdna2Wgp = {allComputeBits,
allStorageBits,
@@ -183,8 +195,9 @@ const WgpDetails *getCDNA2WgpDetails() {

const WgpDetails *getCDNA1WgpDetails() {
static const MMAIntrinsic cdna1MMAOps[] = {
MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_32x32x8_F16,
MMAIntrinsic::MFMA_F32_16x16x4_F32, MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_32x32x8_F16, MMAIntrinsic::MFMA_I32_16x16x16_I8,
MMAIntrinsic::MFMA_I32_32x32x8_I8,
};
static const WgpDetails cdna1Wgp = {allComputeBits,
allStorageBits,
@@ -202,9 +215,10 @@ const WgpDetails *getCDNA1WgpDetails() {

const WgpDetails *getRDNA3WgpDetails() {
static const MMAIntrinsic rdna3MMAOps[] = {
MMAIntrinsic::WMMA_F32_16x16x16_F16,
MMAIntrinsic::WMMA_F16_16x16x16_F16,
MMAIntrinsic::WMMA_F32_16x16x16_F16, MMAIntrinsic::WMMA_F16_16x16x16_F16,
MMAIntrinsic::WMMA_F32_16x16x16_BF16, MMAIntrinsic::WMMA_BF16_16x16x16_BF16,
MMAIntrinsic::WMMA_I32_16x16x16_I8,
};
static const WgpDetails rdna3Wgp = {allComputeBits,
allStorageBits,
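The per-architecture tables above are plain static arrays, so checking whether a generation supports an intrinsic is a linear scan over its WgpDetails list. A self-contained sketch of that lookup pattern follows; the types and names are simplified stand-ins, not the actual KnownTargets.cpp API.

#include <algorithm>
#include <cstdio>

enum class MMAIntrinsic { MFMA_F32_16x16x4_F32, MFMA_F64_16x16x4_F64 };

// Reduced stand-in for WgpDetails: just the supported-MMA list.
struct WgpDetails {
  const MMAIntrinsic *mmaOps;
  int mmaCount;
};

bool supportsMMA(const WgpDetails &wgp, MMAIntrinsic mma) {
  return std::find(wgp.mmaOps, wgp.mmaOps + wgp.mmaCount, mma) !=
         wgp.mmaOps + wgp.mmaCount;
}

int main() {
  static const MMAIntrinsic cdna2MMAOps[] = {
      MMAIntrinsic::MFMA_F32_16x16x4_F32,
      MMAIntrinsic::MFMA_F64_16x16x4_F64,  // newly listed for CDNA2 above
  };
  WgpDetails cdna2 = {cdna2MMAOps, 2};
  std::printf("CDNA2 supports f64 MFMA: %s\n",
              supportsMMA(cdna2, MMAIntrinsic::MFMA_F64_16x16x4_F64) ? "yes"
                                                                     : "no");
  return 0;
}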
30 changes: 30 additions & 0 deletions tests/e2e/matmul/CMakeLists.txt
@@ -1741,6 +1741,36 @@ iree_generated_e2e_runner_test(
"requires-gpu-cdna3"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_cdna3_dt_f64
TEST_TYPE
matmul
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f64"
"--acc_type=f64"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
"rocm"
DRIVERS
"hip"
COMPILER_FLAGS
${IREE_HIP_TEST_COMPILER_FLAGS}
"--iree-opt-data-tiling"
"--iree-global-opt-experimental-rocm-data-tiling"
"--iree-global-opt-enable-early-materialization=true"
"--iree-input-demote-f64-to-f32=false"
LABELS
"noasan"
"nomsan"
"notsan"
"noubsan"
"requires-gpu-cdna3"
)

endif()

elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11")