Refactor existing MMA intrinsics (#19098)
Some code simplifications in `IREEGPUAttrs.cpp`, and some shuffling of
the MMAIntrinsic enum, making it easier to add more MMAIntrinsics in the
future (see #19099).

---------

Signed-off-by: Benoit Jacob <[email protected]>
bjacob authored Nov 12, 2024
1 parent 2bfc639 commit bbb87aa
Showing 2 changed files with 160 additions and 202 deletions.
237 changes: 85 additions & 152 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -59,9 +59,9 @@ namespace mlir::iree_compiler::IREE::GPU {
 namespace {
 // Struct containing abstract MMA shape and type information.
 struct OpaqueMmaLayout {
-  int64_t mSize;
-  int64_t nSize;
-  int64_t kSize;
+  int64_t mSize = 0;
+  int64_t nSize = 0;
+  int64_t kSize = 0;
   Type aType;
   Type bType;
   Type cType;
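
Aside: the only change in this hunk is the in-class zero initializers, which make a default-constructed OpaqueMmaLayout well-defined. A minimal standalone sketch of the effect (not from the commit; the struct is reduced to just its size fields):

#include <cassert>
#include <cstdint>

// Reduced stand-in for OpaqueMmaLayout: only the size members matter here.
struct Layout {
  int64_t mSize = 0;
  int64_t nSize = 0;
  int64_t kSize = 0;
};

int main() {
  Layout l; // previously these members would be indeterminate
  assert(l.mSize == 0 && l.nSize == 0 && l.kSize == 0);
  return 0;
}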
@@ -209,63 +209,77 @@ getContractionLayout(vector::ContractionOp contract, ConcreteMmaLayout layout) {
 // Layout Attribute Building Helpers
 //===----------------------------------------------------------------------===//
 
-static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
-                                           MMAIntrinsic type) {
+static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
+                                                       MMAIntrinsic intrinsic) {
   Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
   Type f8E5M2FNUZ = Float8E5M2FNUZType::get(context);
   Type f16 = Float16Type::get(context);
   Type bf16 = BFloat16Type::get(context);
   Type f32 = Float32Type::get(context);
 
   Type i8 = IntegerType::get(context, 8);
   Type i32 = IntegerType::get(context, 32);
 
-  switch (type) {
+  switch (intrinsic) {
   case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
-    return OpaqueMmaLayout{16, 16, 4, f32, f32, f32};
+    return {f32, f32, f32};
   }
   case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
-    return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
+    return {f16, f16, f32};
   }
   case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
-    return OpaqueMmaLayout{32, 32, 8, f16, f16, f32};
+    return {f16, f16, f32};
   }
   case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
-    return OpaqueMmaLayout{16, 16, 16, bf16, bf16, f32};
+    return {bf16, bf16, f32};
   }
   case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
-    return OpaqueMmaLayout{32, 32, 8, bf16, bf16, f32};
+    return {bf16, bf16, f32};
   }
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
-    return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
+    return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
   }
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: {
-    return OpaqueMmaLayout{16, 16, 32, f8E5M2FNUZ, f8E5M2FNUZ, f32};
+    return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
  }
   case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
-    return OpaqueMmaLayout{16, 16, 32, i8, i8, i32};
+    return {i8, i8, i32};
   }
   case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
-    return OpaqueMmaLayout{32, 32, 16, i8, i8, i32};
+    return {i8, i8, i32};
   }
   case MMAIntrinsic::MFMA_I32_32x32x8_I8: {
-    return OpaqueMmaLayout{32, 32, 8, i8, i8, i32};
+    return {i8, i8, i32};
   }
   case MMAIntrinsic::MFMA_I32_16x16x16_I8: {
-    return OpaqueMmaLayout{16, 16, 16, i8, i8, i32};
+    return {i8, i8, i32};
   }
   case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
-    return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
+    return {f16, f16, f32};
   }
   case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
-    return OpaqueMmaLayout{16, 16, 16, f16, f16, f16};
+    return {f16, f16, f16};
   }
   case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
-    return OpaqueMmaLayout{16, 16, 16, i8, i8, i32};
+    return {i8, i8, i32};
   }
+  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16: {
+    return {f16, f16, f16};
+  }
+  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16: {
+    return {f16, f16, f32};
+  }
   }
-  llvm_unreachable("unhandled mfma layout type");
-  return OpaqueMmaLayout{};
 }
 
+static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
+                                           MMAIntrinsic intrinsic) {
+  OpaqueMmaLayout o;
+  std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
+  auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
+  auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
+  o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
+  o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
+  o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
+  return o;
+}
+
 static std::tuple<PerDimLayoutAttr, PerDimLayoutAttr>
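
The key simplification in this hunk: M/N/K are now derived from the per-fragment single-subgroup layouts as outer * thread * element products instead of being hardcoded per intrinsic. A standalone sketch of that arithmetic (not from the commit; the layout values are assumed to match getSingleSubgroupLayout for MFMA_F32_16x16x16_F16):

#include <array>
#include <cassert>
#include <cstdint>

// Stand-in for MMASingleSubgroupLayout, keeping only the fields used here.
struct SubgroupLayout {
  std::array<int64_t, 2> outer, thread, element;
};

int main() {
  // Assumed Lhs/Rhs layouts for MFMA_F32_16x16x16_F16.
  SubgroupLayout lhs = {{1, 1}, {16, 4}, {1, 4}};
  SubgroupLayout rhs = {{1, 1}, {4, 16}, {4, 1}};
  int64_t m = lhs.outer[0] * lhs.thread[0] * lhs.element[0]; // 1 * 16 * 1 = 16
  int64_t k = lhs.outer[1] * lhs.thread[1] * lhs.element[1]; // 1 * 4 * 4 = 16
  int64_t n = rhs.outer[1] * rhs.thread[1] * rhs.element[1]; // 1 * 16 * 1 = 16
  assert(m == 16 && n == 16 && k == 16); // matches the 16x16x16 in the name
  return 0;
}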
@@ -416,67 +430,25 @@ std::tuple<int64_t, int64_t, int64_t> MMAAttr::getMNKShape() const {
   return {getMSize(), getNSize(), getKSize()};
 }
 
 // NOTE: For layout specifications of the WMMA intrinsics
 // below we are assuming subgroupsize of 32.
+static VectorType getVectorType(MLIRContext *context, MMAIntrinsic intrinsic,
+                                MMAFragment fragment) {
+  auto o = getOpaqueMFMALayout(context, intrinsic);
+  auto s = getSingleSubgroupLayout(intrinsic, fragment);
+  Type elemType = (fragment == MMAFragment::Lhs)   ? o.aType
+                  : (fragment == MMAFragment::Rhs) ? o.bType
+                                                   : o.cType;
+  return VectorType::get(
+      {s.element[0] * s.element[1] * s.outer[0] * s.outer[1]}, elemType);
+}
+
 std::tuple<VectorType, VectorType, VectorType>
 MMAAttr::getABCVectorTypes() const {
-  // Check https://github.com/ROCm/amd_matrix_instruction_calculator for
-  // instruction details. Note here we are returning the number of elements,
-  // while amd_matrix_instruction_calculator tells us about the number of
-  // 32-bit registers, so we need to adjust accordingly. All vectors are 1-D.
-  switch (getIntrinsic().getValue()) {
-  case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
-    auto aType = VectorType::get({1}, getAType());
-    auto bType = VectorType::get({1}, getBType());
-    auto cType = VectorType::get({4}, getCType());
-    return std::make_tuple(aType, bType, cType);
-  }
-  case MMAIntrinsic::MFMA_I32_16x16x16_I8:
-  case MMAIntrinsic::MFMA_F32_16x16x16_F16:
-  case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
-    auto aType = VectorType::get({4}, getAType());
-    auto bType = VectorType::get({4}, getBType());
-    auto cType = VectorType::get({4}, getCType());
-    return std::make_tuple(aType, bType, cType);
-  }
-  case MMAIntrinsic::MFMA_I32_32x32x8_I8:
-  case MMAIntrinsic::MFMA_F32_32x32x8_F16:
-  case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
-    auto aType = VectorType::get({4}, getAType());
-    auto bType = VectorType::get({4}, getBType());
-    auto cType = VectorType::get({16}, getCType());
-    return std::make_tuple(aType, bType, cType);
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
-  case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
-    auto aType = VectorType::get({8}, getAType());
-    auto bType = VectorType::get({8}, getBType());
-    auto cType = VectorType::get({4}, getCType());
-    return std::make_tuple(aType, bType, cType);
-  }
-  case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
-    auto aType = VectorType::get({8}, getAType());
-    auto bType = VectorType::get({8}, getBType());
-    auto cType = VectorType::get({16}, getCType());
-    return std::make_tuple(aType, bType, cType);
-  }
-  case MMAIntrinsic::WMMA_F32_16x16x16_F16:
-  case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
-    auto aType = VectorType::get({16}, getAType());
-    auto bType = VectorType::get({16}, getBType());
-    auto cType = VectorType::get({8}, getCType());
-    return std::make_tuple(aType, bType, cType);
-  }
-  case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
-    auto aType = VectorType::get({16}, getAType());
-    auto bType = VectorType::get({16}, getBType());
-    auto cType = VectorType::get({16}, getCType());
-    return std::make_tuple(aType, bType, cType);
-  }
-  }
-  // This should not happen but just to make GCC happy.
-  return std::make_tuple(VectorType{}, VectorType{}, VectorType{});
+  MLIRContext *context = getContext();
+  MMAIntrinsic intrinsic = getIntrinsic().getValue();
+  VectorType aVecType = getVectorType(context, intrinsic, MMAFragment::Lhs);
+  VectorType bVecType = getVectorType(context, intrinsic, MMAFragment::Rhs);
+  VectorType cVecType = getVectorType(context, intrinsic, MMAFragment::Acc);
+  return {aVecType, bVecType, cVecType};
 }
 
 FailureOr<std::tuple<VectorLayoutInterface, VectorLayoutInterface,
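
Same idea for the operand vector types: each fragment's 1-D vector length is the number of elements one thread holds, element[0] * element[1] * outer[0] * outer[1]. A worked check (standalone, not from the commit; the accumulator layout for MFMA_F32_32x32x8_F16 is an assumption):

#include <cassert>
#include <cstdint>

int main() {
  // Assumed acc layout for MFMA_F32_32x32x8_F16: outer {4, 1}, element {4, 1}.
  int64_t outer[2] = {4, 1};
  int64_t element[2] = {4, 1};
  int64_t len = element[0] * element[1] * outer[0] * outer[1]; // 4*1*4*1 = 16
  // The old switch hardcoded exactly this: VectorType::get({16}, getCType()).
  assert(len == 16);
  return 0;
}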
@@ -488,51 +460,26 @@ MMAAttr::getContractionLayout(vector::ContractionOp contract) const {
 }
 
 int64_t MMAAttr::getBlockSize() const {
-  switch (getIntrinsic().getValue()) {
-  case MMAIntrinsic::MFMA_F32_16x16x4_F32:
-  case MMAIntrinsic::MFMA_F32_16x16x16_F16:
-  case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
-  case MMAIntrinsic::MFMA_I32_16x16x16_I8:
-  case MMAIntrinsic::MFMA_F32_32x32x8_F16:
-  case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
-  case MMAIntrinsic::MFMA_I32_32x32x8_I8:
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
-  case MMAIntrinsic::MFMA_I32_16x16x32_I8:
-  case MMAIntrinsic::MFMA_I32_32x32x16_I8:
-  case MMAIntrinsic::WMMA_F16_16x16x16_F16:
-  case MMAIntrinsic::WMMA_F32_16x16x16_F16:
-  case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
-    return 1;
-  }
-  }
-  // This should not happen but just to make GCC happy.
-  return 0;
+  // Not supporting any block size other than 1 at the moment.
+  return 1;
 }
 
+static uint32_t getArchID(MMAIntrinsic intrinsic) {
+  return static_cast<int>(intrinsic) & 0xFF00;
+}
+
+static bool is_AMD_MFMA(MMAIntrinsic intrinsic) {
+  return getArchID(intrinsic) >= 0x1000 && getArchID(intrinsic) <= 0x17FF;
+}
+
+static bool is_AMD_WMMA(MMAIntrinsic intrinsic) {
+  return getArchID(intrinsic) >= 0x1800 && getArchID(intrinsic) <= 0x1FFF;
+}
+
 static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
-  switch (intrinsic) {
-  case MMAIntrinsic::MFMA_F32_16x16x4_F32:
-  case MMAIntrinsic::MFMA_F32_16x16x16_F16:
-  case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
-  case MMAIntrinsic::MFMA_I32_16x16x16_I8:
-  case MMAIntrinsic::MFMA_F32_32x32x8_F16:
-  case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
-  case MMAIntrinsic::MFMA_I32_32x32x8_I8:
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
-  case MMAIntrinsic::MFMA_I32_16x16x32_I8:
-  case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
-    return 64;
-  }
-  case MMAIntrinsic::WMMA_F32_16x16x16_F16:
-  case MMAIntrinsic::WMMA_F16_16x16x16_F16:
-  case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
-    return 32;
-  }
-  }
-  // This should not happen but just to make GCC happy.
-  return 0;
+  // Not using Wave64 at all at the moment, so the only place where the
+  // subgroup size is 64 is CDNA* architectures.
+  return is_AMD_MFMA(intrinsic) ? 64 : 32;
 }
 
 int64_t MMAAttr::getSubgroupSize() const {
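
These helpers rely on the enum shuffling mentioned in the commit message: each MMAIntrinsic value now encodes an architecture identifier in bits 8-15, so family checks become a mask plus a range compare, and the subgroup size falls out of the family. A sketch with hypothetical enum values (the real assignments live in the enum definition; see #19099):

#include <cassert>
#include <cstdint>

// Hypothetical values for illustration only.
enum class Intrinsic : uint32_t {
  SomeMFMA = 0x1020, // arch IDs 0x1000-0x17FF: CDNA* MFMA
  SomeWMMA = 0x1810, // arch IDs 0x1800-0x1FFF: RDNA* WMMA
};

static uint32_t archID(Intrinsic i) { return static_cast<uint32_t>(i) & 0xFF00; }
static bool isMFMA(Intrinsic i) { return archID(i) >= 0x1000 && archID(i) <= 0x17FF; }

int main() {
  assert(isMFMA(Intrinsic::SomeMFMA));
  assert(!isMFMA(Intrinsic::SomeWMMA));
  // Subgroup size then follows the family: Wave64 on CDNA*, Wave32 otherwise.
  int64_t subgroupSize = isMFMA(Intrinsic::SomeWMMA) ? 64 : 32;
  assert(subgroupSize == 32);
  return 0;
}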
@@ -637,6 +584,9 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
     return {/*outer=*/{16, 1}, /*thread=*/{1, 16}, /*tstrides=*/{0, 1},
             /*element=*/{1, 1}};
   }
+  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16:
+  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16:
+    return {};
   }
   return {};
 }
@@ -683,41 +633,24 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
   if (cType != resultType) {
     return failure();
   }
-  switch (getIntrinsic().getValue()) {
-  case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
-    // Update the lhs and rhs to extract the first element since vector<1xT>
-    // is not supported by the amdgpu.mfma op.
-    lhs = builder.create<vector::ExtractOp>(loc, lhs, ArrayRef{int64_t{0}});
-    rhs = builder.create<vector::ExtractOp>(loc, rhs, ArrayRef{int64_t{0}});
-    auto [m, n, k] = getMNKShape();
-    return builder
-        .create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
-                                rhs, acc)
-        .getResult();
-  }
-  case MMAIntrinsic::MFMA_I32_16x16x16_I8:
-  case MMAIntrinsic::MFMA_F32_16x16x16_F16:
-  case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
-  case MMAIntrinsic::MFMA_I32_32x32x8_I8:
-  case MMAIntrinsic::MFMA_F32_32x32x8_F16:
-  case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
-  case MMAIntrinsic::MFMA_I32_16x16x32_I8:
-  case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
+  auto getVecOrSingleElem = [&](Value vec) -> Value {
+    bool one = llvm::cast<VectorType>(vec.getType()).getNumElements() == 1;
+    return one ? builder.create<vector::ExtractOp>(loc, vec, 0) : vec;
+  };
+  MMAIntrinsic intrinsic = getIntrinsic().getValue();
+  if (is_AMD_MFMA(intrinsic)) {
+    // MFMA intrinsics want single-element operands of element type, not vector.
+    lhs = getVecOrSingleElem(lhs);
+    rhs = getVecOrSingleElem(rhs);
     auto [m, n, k] = getMNKShape();
     return builder
         .create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
                                 rhs, acc)
        .getResult();
-  }
-  case MMAIntrinsic::WMMA_F32_16x16x16_F16:
-  case MMAIntrinsic::WMMA_F16_16x16x16_F16:
-  case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
+  } else if (is_AMD_WMMA(intrinsic)) {
     return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
         .getResult();
   }
-  }
   return failure();
 }
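
The rewritten buildMmaOperation replaces the per-intrinsic switch with a family dispatch, unwrapping vector<1xT> operands to scalars only for MFMA. A behavioral sketch with stand-in types (not from the commit; MLIR values and op creation are reduced to plain structs and booleans):

#include <cassert>
#include <cstddef>

enum class Family { MFMA, WMMA, Other };

// Stand-in for an MLIR vector value: we only track its element count.
struct Operand {
  size_t numElements;
  bool unwrapped = false; // set when the single element is extracted
};

// Mirrors getVecOrSingleElem: extract element 0 only from 1-element vectors.
Operand getVecOrSingleElem(Operand v) {
  if (v.numElements == 1)
    v.unwrapped = true;
  return v;
}

bool build(Family f, Operand &lhs, Operand &rhs) {
  if (f == Family::MFMA) {
    lhs = getVecOrSingleElem(lhs);
    rhs = getVecOrSingleElem(rhs);
    return true; // would create amdgpu::MFMAOp
  }
  if (f == Family::WMMA)
    return true; // would create amdgpu::WMMAOp
  return false;  // mirrors the trailing `return failure();`
}

int main() {
  Operand a{1}, b{4};
  assert(build(Family::MFMA, a, b));
  assert(a.unwrapped && !b.unwrapped); // only vector<1xT> gets unwrapped
  assert(!build(Family::Other, a, b));
  return 0;
}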

(diff for the second changed file not shown)
