Skip to content

Commit

Permalink
Plumb through MMAAttr
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhar committed Nov 11, 2024
1 parent 66db517 commit 3f88d20
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 25 deletions.
12 changes: 9 additions & 3 deletions compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,14 @@ ireeGPUMMAIntrinsicAttrGet(MlirContext mlirCtx, ireeGPUMMAIntrinsicEnum value);
MLIR_CAPI_EXPORTED ireeGPUMMAIntrinsicEnum
ireeGPUMMAIntrinsicAttrGetValue(MlirAttribute attr);

struct ireeGPUMMAIntrinsicInfo {
MLIR_CAPI_EXPORTED bool ireeAttributeIsAGPUMMAAttr(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID ireeGPUMMAAttrGetTypeID(void);

MLIR_CAPI_EXPORTED MlirAttribute
ireeGPUMMAAttrGet(MlirContext mlirCtx, ireeGPUMMAIntrinsicEnum value);

struct ireeGPUMMAInfo {
MlirType aElementType;
MlirType bElementType;
MlirType cElementType;
Expand All @@ -111,8 +118,7 @@ struct ireeGPUMMAIntrinsicInfo {
int64_t kElements;
};

MLIR_CAPI_EXPORTED ireeGPUMMAIntrinsicInfo
ireeGPUMMAIntrinsicAttrGetInfo(MlirAttribute attr);
MLIR_CAPI_EXPORTED ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr);

#ifdef __cplusplus
}
Expand Down
19 changes: 16 additions & 3 deletions compiler/bindings/python/IREECompilerDialectsModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,22 +170,35 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
"cls"_a, "value"_a, "ctx"_a = py::none(),
"Gets a gpu.mma_intrinsic from parameters.")
.def_property_readonly("value", ireeGPUMMAIntrinsicAttrGetValue)
.def_property_readonly("mma", [](MlirAttribute self) -> MlirAttribute {
ireeGPUMMAIntrinsicEnum value = ireeGPUMMAIntrinsicAttrGetValue(self);
return ireeGPUMMAAttrGet(mlirAttributeGetContext(self), value);
});

mlir_attribute_subclass(iree_gpu_module, "MMAAttr",
ireeAttributeIsAGPUMMAAttr, ireeGPUMMAAttrGetTypeID)
.def_classmethod(
"get",
[](const py::object &, ireeGPUMMAIntrinsicEnum value,
MlirContext ctx) { return ireeGPUMMAAttrGet(ctx, value); },
"cls"_a, "value"_a, "ctx"_a = py::none(),
"Gets a gpu.mma from parameters.")
.def_property_readonly(
"abc_element_types",
[](MlirAttribute self) -> py::tuple {
ireeGPUMMAIntrinsicInfo info = ireeGPUMMAIntrinsicAttrGetInfo(self);
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.aElementType, info.bElementType,
info.cElementType);
})
.def_property_readonly(
"abc_vector_types",
[](MlirAttribute self) -> py::tuple {
ireeGPUMMAIntrinsicInfo info = ireeGPUMMAIntrinsicAttrGetInfo(self);
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.aVectorType, info.bVectorType,
info.cVectorType);
})
.def_property_readonly("mnk_shape", [](MlirAttribute self) -> py::tuple {
ireeGPUMMAIntrinsicInfo info = ireeGPUMMAIntrinsicAttrGetInfo(self);
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.mElements, info.nElements, info.kElements);
});
}
18 changes: 14 additions & 4 deletions compiler/bindings/python/test/ir/dialects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,21 @@ def mma_intrinsic_attr():
with ir.Context() as ctx, ir.Location.unknown():
module = ir.Module.create()
with ir.InsertionPoint(module.body):
mma_attr = iree_gpu.MMAIntrinsicAttr.get(
mma_intrinsic_attr = iree_gpu.MMAIntrinsicAttr.get(
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ctx
)
assert mma_intrinsic_attr is not None
assert (
mma_intrinsic_attr.value == iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
)
assert (
str(mma_intrinsic_attr)
== "#iree_gpu<mma_intrinsic MFMA_F32_32x32x8_F16>"
)
assert str(mma_intrinsic_attr.value) == "MFMA_F32_32x32x8_F16"

mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic_attr.value, ctx)
assert mma_attr is not None
assert mma_attr.value == iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
assert str(mma_attr) == "#iree_gpu<mma_intrinsic MFMA_F32_32x32x8_F16>"
assert str(mma_attr.value) == "MFMA_F32_32x32x8_F16"

f16 = ir.F16Type.get()
f32 = ir.F32Type.get()
Expand All @@ -117,3 +125,5 @@ def mma_intrinsic_attr():
assert M == 32
assert N == 32
assert K == 8

assert mma_intrinsic_attr.mma == mma_attr
28 changes: 18 additions & 10 deletions compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,17 +167,26 @@ ireeGPUMMAIntrinsicEnum ireeGPUMMAIntrinsicAttrGetValue(MlirAttribute attr) {
.getValue());
}

ireeGPUMMAIntrinsicInfo ireeGPUMMAIntrinsicAttrGetInfo(MlirAttribute attr) {
assert(ireeAttributeIsAGPUMMAIntrinsicAttr(attr) &&
"attr is not a GPUMMAIntrinsicAttr");
auto intrinsicAttr =
llvm::cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr>(
unwrap(attr));
auto mma = mlir::iree_compiler::IREE::GPU::MMAAttr::get(
intrinsicAttr.getContext(), intrinsicAttr.getValue());
bool ireeAttributeIsAGPUMMAAttr(MlirAttribute attr) {
return llvm::isa<mlir::iree_compiler::IREE::GPU::MMAAttr>(unwrap(attr));
}

ireeGPUMMAIntrinsicInfo info = {};
MlirTypeID ireeGPUMMAAttrGetTypeID() {
return wrap(mlir::iree_compiler::IREE::GPU::MMAAttr::getTypeID());
}

MlirAttribute ireeGPUMMAAttrGet(MlirContext mlirCtx,
ireeGPUMMAIntrinsicEnum value) {
mlir::MLIRContext *ctx = unwrap(mlirCtx);
return wrap(mlir::iree_compiler::IREE::GPU::MMAAttr::get(
ctx, static_cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsic>(value)));
}

ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr) {
assert(ireeAttributeIsAGPUMMAAttr(attr) && "attr is not a MMAAttr");
auto mma = llvm::cast<mlir::iree_compiler::IREE::GPU::MMAAttr>(unwrap(attr));

ireeGPUMMAInfo info = {};
auto [aType, bType, cType] = mma.getABCElementTypes();
info.aElementType = wrap(aType);
info.bElementType = wrap(bType);
Expand All @@ -189,7 +198,6 @@ ireeGPUMMAIntrinsicInfo ireeGPUMMAIntrinsicAttrGetInfo(MlirAttribute attr) {
info.cVectorType = wrap(cVecType);

std::tie(info.mElements, info.nElements, info.kElements) = mma.getMNKShape();

return info;
}

Expand Down
10 changes: 8 additions & 2 deletions compiler/src/iree/compiler/API/api_exports.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <stdint.h>

extern void ireeAttributeIsAGPUMMAAttr();
extern void ireeAttributeIsAGPUMMAIntrinsicAttr();
extern void ireeAttributeIsAGPUPipelineOptionsAttr();
extern void ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr();
Expand Down Expand Up @@ -64,8 +65,10 @@ extern void ireeCompilerSourceDestroy();
extern void ireeCompilerSourceOpenFile();
extern void ireeCompilerSourceSplit();
extern void ireeCompilerSourceWrapBuffer();
extern void ireeGPUMMAAttrGet();
extern void ireeGPUMMAAttrGetInfo();
extern void ireeGPUMMAAttrGetTypeID();
extern void ireeGPUMMAIntrinsicAttrGet();
extern void ireeGPUMMAIntrinsicAttrGetInfo();
extern void ireeGPUMMAIntrinsicAttrGetTypeID();
extern void ireeGPUMMAIntrinsicAttrGetValue();
extern void ireeGPUPipelineOptionsAttrGet();
Expand Down Expand Up @@ -853,6 +856,7 @@ extern void mlirVectorTypeIsScalable();

uintptr_t __iree_compiler_hidden_force_extern() {
uintptr_t x = 0;
x += (uintptr_t)&ireeAttributeIsAGPUMMAAttr;
x += (uintptr_t)&ireeAttributeIsAGPUMMAIntrinsicAttr;
x += (uintptr_t)&ireeAttributeIsAGPUPipelineOptionsAttr;
x += (uintptr_t)&ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr;
Expand Down Expand Up @@ -907,8 +911,10 @@ uintptr_t __iree_compiler_hidden_force_extern() {
x += (uintptr_t)&ireeCompilerSourceOpenFile;
x += (uintptr_t)&ireeCompilerSourceSplit;
x += (uintptr_t)&ireeCompilerSourceWrapBuffer;
x += (uintptr_t)&ireeGPUMMAAttrGet;
x += (uintptr_t)&ireeGPUMMAAttrGetInfo;
x += (uintptr_t)&ireeGPUMMAAttrGetTypeID;
x += (uintptr_t)&ireeGPUMMAIntrinsicAttrGet;
x += (uintptr_t)&ireeGPUMMAIntrinsicAttrGetInfo;
x += (uintptr_t)&ireeGPUMMAIntrinsicAttrGetTypeID;
x += (uintptr_t)&ireeGPUMMAIntrinsicAttrGetValue;
x += (uintptr_t)&ireeGPUPipelineOptionsAttrGet;
Expand Down
5 changes: 4 additions & 1 deletion compiler/src/iree/compiler/API/api_exports.def
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
; Generated by generate_exports.py: Do not edit.
EXPORTS
ireeAttributeIsAGPUMMAAttr
ireeAttributeIsAGPUMMAIntrinsicAttr
ireeAttributeIsAGPUPipelineOptionsAttr
ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr
Expand Down Expand Up @@ -54,8 +55,10 @@ EXPORTS
ireeCompilerSourceOpenFile
ireeCompilerSourceSplit
ireeCompilerSourceWrapBuffer
ireeGPUMMAAttrGet
ireeGPUMMAAttrGetInfo
ireeGPUMMAAttrGetTypeID
ireeGPUMMAIntrinsicAttrGet
ireeGPUMMAIntrinsicAttrGetInfo
ireeGPUMMAIntrinsicAttrGetTypeID
ireeGPUMMAIntrinsicAttrGetValue
ireeGPUPipelineOptionsAttrGet
Expand Down
5 changes: 4 additions & 1 deletion compiler/src/iree/compiler/API/api_exports.ld
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Generated by generate_exports.py: Do not edit.
VER_0 {
global:
ireeAttributeIsAGPUMMAAttr;
ireeAttributeIsAGPUMMAIntrinsicAttr;
ireeAttributeIsAGPUPipelineOptionsAttr;
ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr;
Expand Down Expand Up @@ -55,8 +56,10 @@ VER_0 {
ireeCompilerSourceOpenFile;
ireeCompilerSourceSplit;
ireeCompilerSourceWrapBuffer;
ireeGPUMMAAttrGet;
ireeGPUMMAAttrGetInfo;
ireeGPUMMAAttrGetTypeID;
ireeGPUMMAIntrinsicAttrGet;
ireeGPUMMAIntrinsicAttrGetInfo;
ireeGPUMMAIntrinsicAttrGetTypeID;
ireeGPUMMAIntrinsicAttrGetValue;
ireeGPUPipelineOptionsAttrGet;
Expand Down
5 changes: 4 additions & 1 deletion compiler/src/iree/compiler/API/api_exports.macos.lst
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Generated by generate_exports.py: Do not edit.
_ireeAttributeIsAGPUMMAAttr
_ireeAttributeIsAGPUMMAIntrinsicAttr
_ireeAttributeIsAGPUPipelineOptionsAttr
_ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr
Expand Down Expand Up @@ -53,8 +54,10 @@ _ireeCompilerSourceDestroy
_ireeCompilerSourceOpenFile
_ireeCompilerSourceSplit
_ireeCompilerSourceWrapBuffer
_ireeGPUMMAAttrGet
_ireeGPUMMAAttrGetInfo
_ireeGPUMMAAttrGetTypeID
_ireeGPUMMAIntrinsicAttrGet
_ireeGPUMMAIntrinsicAttrGetInfo
_ireeGPUMMAIntrinsicAttrGetTypeID
_ireeGPUMMAIntrinsicAttrGetValue
_ireeGPUPipelineOptionsAttrGet
Expand Down

0 comments on commit 3f88d20

Please sign in to comment.