Skip to content

Commit

Permalink
[python] Add bindings for MMAIntrinsic
Browse files Browse the repository at this point in the history
Signed-off-by: Jakub Kuderski <[email protected]>
  • Loading branch information
kuhar committed Nov 11, 2024
1 parent c0dff68 commit e3347ca
Show file tree
Hide file tree
Showing 11 changed files with 316 additions and 21 deletions.
70 changes: 68 additions & 2 deletions compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,20 @@
extern "C" {
#endif

// The following C API is **NOT STABLE** and likely to change in the future.
// It mirrors the IREE GPU Dialect which is not stable itself.

#define IREE_GPU_FOR_ALL_REORDER_WORKGROUP_VALUES \
X_DO(None, 0) \
X_DO(Transpose, 1)

enum ireeGPUReorderWorkgroupsStrategyEnum {
ireeGPUReorderWorkgroupsStrategyEnumNone = 0,
ireeGPUReorderWorkgroupsStrategyEnumTranspose = 1,
#define X_DO(EnumName, EnumValue) \
ireeGPUReorderWorkgroupsStrategyEnum##EnumName = EnumValue,

IREE_GPU_FOR_ALL_REORDER_WORKGROUP_VALUES

#undef X_DO
};

MLIR_CAPI_EXPORTED bool
Expand Down Expand Up @@ -54,6 +65,61 @@ ireeGPUPipelineOptionsAttrGetReorderWorkgroupsStrategy(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID ireeGPUPipelineOptionsAttrGetTypeID(void);

#define IREE_GPU_FOR_ALL_MMA_INTRINSIC_VALUES \
X_DO(MFMA_F32_16x16x4_F32, 0x0900) \
X_DO(MFMA_F32_16x16x16_F16, 0x0910) \
X_DO(MFMA_F32_32x32x8_F16, 0x0911) \
X_DO(MFMA_F32_16x16x16_BF16, 0x0920) \
X_DO(MFMA_F32_32x32x8_BF16, 0x0921) \
X_DO(MFMA_F32_16x16x32_F8E4M3FNUZ, 0x0940) \
X_DO(MFMA_F32_16x16x32_F8E5M2FNUZ, 0x0930) \
X_DO(MFMA_I32_16x16x32_I8, 0x0980) \
X_DO(MFMA_I32_32x32x16_I8, 0x0981) \
X_DO(MFMA_I32_16x16x16_I8, 0x0880) \
X_DO(MFMA_I32_32x32x8_I8, 0x0881) \
X_DO(WMMA_F32_16x16x16_F16, 0x0010) \
X_DO(WMMA_F16_16x16x16_F16, 0x0011) \
X_DO(WMMA_I32_16x16x16_I8, 0x0080)

enum ireeGPUMMAIntrinsicEnum {
#define X_DO(EnumName, EnumValue) ireeGPUMMAIntrinsicEnum##EnumName = EnumValue,

IREE_GPU_FOR_ALL_MMA_INTRINSIC_VALUES

#undef X_DO
};

MLIR_CAPI_EXPORTED bool ireeAttributeIsAGPUMMAIntrinsicAttr(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID ireeGPUMMAIntrinsicAttrGetTypeID(void);

MLIR_CAPI_EXPORTED MlirAttribute
ireeGPUMMAIntrinsicAttrGet(MlirContext mlirCtx, ireeGPUMMAIntrinsicEnum value);

MLIR_CAPI_EXPORTED ireeGPUMMAIntrinsicEnum
ireeGPUMMAIntrinsicAttrGetValue(MlirAttribute attr);

MLIR_CAPI_EXPORTED bool ireeAttributeIsAGPUMMAAttr(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID ireeGPUMMAAttrGetTypeID(void);

MLIR_CAPI_EXPORTED MlirAttribute
ireeGPUMMAAttrGet(MlirContext mlirCtx, ireeGPUMMAIntrinsicEnum value);

struct ireeGPUMMAInfo {
MlirType aElementType;
MlirType bElementType;
MlirType cElementType;
MlirType aVectorType;
MlirType bVectorType;
MlirType cVectorType;
int64_t mElements;
int64_t nElements;
int64_t kElements;
};

MLIR_CAPI_EXPORTED ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr);

#ifdef __cplusplus
}
#endif
Expand Down
82 changes: 77 additions & 5 deletions compiler/bindings/python/IREECompilerDialectsModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,8 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
},
"cls"_a, "value"_a, "ctx"_a = py::none(),
"Gets a gpu.reorder_workgroups_strategy from parameters.")
.def_property_readonly(
"value",
[](MlirAttribute self) -> ireeGPUReorderWorkgroupsStrategyEnum {
return ireeGPUReorderWorkgroupsStrategyAttrGetValue(self);
});
.def_property_readonly("value",
ireeGPUReorderWorkgroupsStrategyAttrGetValue);

//===-------------------------------------------------------------------===//
// GPUPipelineOptionsAttr
Expand Down Expand Up @@ -129,4 +126,79 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
return attr;
return std::nullopt;
});

//===-------------------------------------------------------------------===//
// GPUMMAIntrinsicAttr
//===-------------------------------------------------------------------===//
auto mmaIntrinsicEnum =
py::enum_<ireeGPUMMAIntrinsicEnum>(iree_gpu_module, "MMAIntrinsic",
py::module_local())

#define X_DO(EnumName, EnumValue) \
.value(#EnumName, ireeGPUMMAIntrinsicEnum##EnumName)

IREE_GPU_FOR_ALL_MMA_INTRINSIC_VALUES

#undef X_DO

.def(
"__str__",
[](ireeGPUMMAIntrinsicEnum &self) {
switch (self) {
#define X_DO(EnumName, EnumValue) \
case ireeGPUMMAIntrinsicEnum##EnumName: \
return #EnumName;

IREE_GPU_FOR_ALL_MMA_INTRINSIC_VALUES

#undef X_DO
default:
llvm::report_fatal_error("unknown MMAIntrinsic variant");
}
},
py::prepend());

mlir_attribute_subclass(iree_gpu_module, "MMAIntrinsicAttr",
ireeAttributeIsAGPUMMAIntrinsicAttr,
ireeGPUMMAIntrinsicAttrGetTypeID)
.def_classmethod(
"get",
[](const py::object &, ireeGPUMMAIntrinsicEnum value,
MlirContext ctx) {
return ireeGPUMMAIntrinsicAttrGet(ctx, value);
},
"cls"_a, "value"_a, "ctx"_a = py::none(),
"Gets a gpu.mma_intrinsic from parameters.")
.def_property_readonly("value", ireeGPUMMAIntrinsicAttrGetValue)
.def_property_readonly("mma", [](MlirAttribute self) -> MlirAttribute {
ireeGPUMMAIntrinsicEnum value = ireeGPUMMAIntrinsicAttrGetValue(self);
return ireeGPUMMAAttrGet(mlirAttributeGetContext(self), value);
});

mlir_attribute_subclass(iree_gpu_module, "MMAAttr",
ireeAttributeIsAGPUMMAAttr, ireeGPUMMAAttrGetTypeID)
.def_classmethod(
"get",
[](const py::object &, ireeGPUMMAIntrinsicEnum value,
MlirContext ctx) { return ireeGPUMMAAttrGet(ctx, value); },
"cls"_a, "value"_a, "ctx"_a = py::none(),
"Gets a gpu.mma from parameters.")
.def_property_readonly(
"abc_element_types",
[](MlirAttribute self) -> py::tuple {
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.aElementType, info.bElementType,
info.cElementType);
})
.def_property_readonly(
"abc_vector_types",
[](MlirAttribute self) -> py::tuple {
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.aVectorType, info.bVectorType,
info.cVectorType);
})
.def_property_readonly("mnk_shape", [](MlirAttribute self) -> py::tuple {
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.mElements, info.nElements, info.kElements);
});
}
41 changes: 41 additions & 0 deletions compiler/bindings/python/test/ir/dialects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,44 @@ def gpu_pipeline_options_attr():
# unfortunately not `is`
== iree_gpu.ReorderWorkgroupsStrategy.Transpose
)


@lambda _: _()
def mma_intrinsic_attr():
with ir.Context() as ctx, ir.Location.unknown():
module = ir.Module.create()
with ir.InsertionPoint(module.body):
mma_intrinsic_attr = iree_gpu.MMAIntrinsicAttr.get(
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ctx
)
assert mma_intrinsic_attr is not None
assert (
mma_intrinsic_attr.value == iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
)
assert (
str(mma_intrinsic_attr)
== "#iree_gpu<mma_intrinsic MFMA_F32_32x32x8_F16>"
)
assert str(mma_intrinsic_attr.value) == "MFMA_F32_32x32x8_F16"

mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic_attr.value, ctx)
assert mma_attr is not None

f16 = ir.F16Type.get()
f32 = ir.F32Type.get()
a_type, b_type, c_type = mma_attr.abc_element_types
assert a_type == f16
assert b_type == f16
assert c_type == f32

vec_4xf16 = ir.VectorType.get((4,), f16)
a_vec_type, b_vec_type, _c_vec_type = mma_attr.abc_vector_types
assert a_vec_type == vec_4xf16
assert b_vec_type == vec_4xf16

M, N, K = mma_attr.mnk_shape
assert M == 32
assert N == 32
assert K == 8

assert mma_intrinsic_attr.mma == mma_attr
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/API/Internal/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ iree_compiler_cc_library(
deps = [
"//compiler/bindings/c:headers",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:CAPIIRHeaders",
],
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/API/Internal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ iree_cc_library(
"IREEGPUDialectCAPI.cpp"
DEPS
IREELLVMIncludeSetup
LLVMSupport
MLIRCAPIIR
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::bindings::c::headers
Expand Down
100 changes: 87 additions & 13 deletions compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,20 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
#include "iree/compiler/dialects/iree_gpu.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"

// Macro to check that enum values match across C++ and C API.
#define ASSERT_ENUM_VALS_MATCH(CppEnumName, CAPIEnumName, ValueName) \
static_assert(llvm::to_underlying(CppEnumName ::ValueName) == \
llvm::to_underlying(CAPIEnumName##ValueName), \
#CppEnumName " and " #CAPIEnumName " have diverged")

bool ireeAttributeIsAGPUPipelineOptionsAttr(MlirAttribute attr) {
return llvm::isa<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
unwrap(attr));
Expand Down Expand Up @@ -84,19 +93,16 @@ MlirTypeID ireeGPUPipelineOptionsAttrGetTypeID() {
mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::getTypeID());
}

static_assert(
static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumNone) ==
static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
ReorderWorkgroupsStrategy::None) &&
static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumTranspose) ==
static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
ReorderWorkgroupsStrategy::Transpose) &&
static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumTranspose) ==
mlir::iree_compiler::IREE::GPU::
getMaxEnumValForReorderWorkgroupsStrategy(),
"ireeGPUReorderWorkgroupsStrategyEnum and "
"mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy definitions "
"have diverged");
#define X_DO(EnumName, EnumValue) \
ASSERT_ENUM_VALS_MATCH( \
mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy, \
ireeGPUReorderWorkgroupsStrategyEnum, EnumName);

IREE_GPU_FOR_ALL_REORDER_WORKGROUP_VALUES

#undef X_DO

#undef FOR_ALL_REORDER_WORKGROUP_VALUES

bool ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(MlirAttribute attr) {
return llvm::isa<
Expand Down Expand Up @@ -128,3 +134,71 @@ ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr) {
unwrap(attr))
.getValue());
}

#define X_DO(EnumName, EnumVal) \
ASSERT_ENUM_VALS_MATCH(mlir::iree_compiler::IREE::GPU::MMAIntrinsic, \
ireeGPUMMAIntrinsicEnum, EnumName);

IREE_GPU_FOR_ALL_MMA_INTRINSIC_VALUES

#undef X_DO

bool ireeAttributeIsAGPUMMAIntrinsicAttr(MlirAttribute attr) {
return llvm::isa<mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr>(
unwrap(attr));
}

MlirTypeID ireeGPUMMAIntrinsicAttrGetTypeID() {
return wrap(mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr::getTypeID());
}

MlirAttribute ireeGPUMMAIntrinsicAttrGet(MlirContext mlirCtx,
ireeGPUMMAIntrinsicEnum value) {
mlir::MLIRContext *ctx = unwrap(mlirCtx);
return wrap(mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr::get(
ctx, static_cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsic>(value)));
}

ireeGPUMMAIntrinsicEnum ireeGPUMMAIntrinsicAttrGetValue(MlirAttribute attr) {
assert(ireeAttributeIsAGPUMMAIntrinsicAttr(attr) &&
"attr is not a GPUMMAIntrinsicAttr");
return static_cast<ireeGPUMMAIntrinsicEnum>(
llvm::cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr>(unwrap(attr))
.getValue());
}

bool ireeAttributeIsAGPUMMAAttr(MlirAttribute attr) {
return llvm::isa<mlir::iree_compiler::IREE::GPU::MMAAttr>(unwrap(attr));
}

MlirTypeID ireeGPUMMAAttrGetTypeID() {
return wrap(mlir::iree_compiler::IREE::GPU::MMAAttr::getTypeID());
}

MlirAttribute ireeGPUMMAAttrGet(MlirContext mlirCtx,
ireeGPUMMAIntrinsicEnum value) {
mlir::MLIRContext *ctx = unwrap(mlirCtx);
return wrap(mlir::iree_compiler::IREE::GPU::MMAAttr::get(
ctx, static_cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsic>(value)));
}

ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr) {
assert(ireeAttributeIsAGPUMMAAttr(attr) && "attr is not a MMAAttr");
auto mma = llvm::cast<mlir::iree_compiler::IREE::GPU::MMAAttr>(unwrap(attr));

ireeGPUMMAInfo info = {};
auto [aType, bType, cType] = mma.getABCElementTypes();
info.aElementType = wrap(aType);
info.bElementType = wrap(bType);
info.cElementType = wrap(cType);

auto [aVecType, bVecType, cVecType] = mma.getABCVectorTypes();
info.aVectorType = wrap(aVecType);
info.bVectorType = wrap(bVecType);
info.cVectorType = wrap(cVecType);

std::tie(info.mElements, info.nElements, info.kElements) = mma.getMNKShape();
return info;
}

#undef ASSERT_ENUM_VALS_MATCH
Loading

0 comments on commit e3347ca

Please sign in to comment.