Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python][tuner] Add bindings for MMAIntrinsic #19095

Merged
merged 5 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,19 @@
extern "C" {
#endif

enum ireeGPUReorderWorkgroupsStrategyEnum {
ireeGPUReorderWorkgroupsStrategyEnumNone = 0,
ireeGPUReorderWorkgroupsStrategyEnumTranspose = 1,
};
// The following C API is **NOT STABLE** and likely to change in the future.
// It mirrors the IREE GPU Dialect which is not stable itself.

MLIR_CAPI_EXPORTED bool
ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID
ireeGPUReorderWorkgroupsStrategyAttrGetTypeID(void);

MLIR_CAPI_EXPORTED MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet(
MlirContext mlirCtx, ireeGPUReorderWorkgroupsStrategyEnum value);
MLIR_CAPI_EXPORTED MlirAttribute
ireeGPUReorderWorkgroupsStrategyAttrGet(MlirContext mlirCtx, uint32_t value);

MLIR_CAPI_EXPORTED ireeGPUReorderWorkgroupsStrategyEnum
MLIR_CAPI_EXPORTED uint32_t
ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr);

MLIR_CAPI_EXPORTED
Expand All @@ -54,6 +52,36 @@ ireeGPUPipelineOptionsAttrGetReorderWorkgroupsStrategy(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID ireeGPUPipelineOptionsAttrGetTypeID(void);

MLIR_CAPI_EXPORTED bool ireeAttributeIsAGPUMMAIntrinsicAttr(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID ireeGPUMMAIntrinsicAttrGetTypeID(void);

MLIR_CAPI_EXPORTED MlirAttribute ireeGPUMMAIntrinsicAttrGet(MlirContext mlirCtx,
uint32_t value);

MLIR_CAPI_EXPORTED uint32_t ireeGPUMMAIntrinsicAttrGetValue(MlirAttribute attr);

MLIR_CAPI_EXPORTED bool ireeAttributeIsAGPUMMAAttr(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID ireeGPUMMAAttrGetTypeID(void);

MLIR_CAPI_EXPORTED MlirAttribute ireeGPUMMAAttrGet(MlirContext mlirCtx,
uint32_t value);

struct ireeGPUMMAInfo {
MlirType aElementType;
MlirType bElementType;
MlirType cElementType;
MlirType aVectorType;
MlirType bVectorType;
MlirType cVectorType;
int64_t mElements;
int64_t nElements;
int64_t kElements;
};

MLIR_CAPI_EXPORTED ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr);

#ifdef __cplusplus
}
#endif
Expand Down
96 changes: 66 additions & 30 deletions compiler/bindings/python/IREECompilerDialectsModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cstdint>
#include "iree/compiler/dialects/iree_gpu.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"

static const char *kGpuModuleImportPath =
MAKE_MLIR_PYTHON_QUALNAME("dialects.iree_gpu");

namespace py = pybind11;
using namespace mlir::python::adaptors;

Expand All @@ -22,45 +26,23 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
// GPUReorderWorkgroupsStrategyAttr
//===-------------------------------------------------------------------===//

auto strategyEnum =
py::enum_<ireeGPUReorderWorkgroupsStrategyEnum>(
iree_gpu_module, "ReorderWorkgroupsStrategy", py::module_local())
.value("None_", ireeGPUReorderWorkgroupsStrategyEnumNone)
.value("Transpose", ireeGPUReorderWorkgroupsStrategyEnumTranspose)
.def(
"__str__",
[](ireeGPUReorderWorkgroupsStrategyEnum &self) {
switch (self) {
case ireeGPUReorderWorkgroupsStrategyEnumNone:
return "None";
case ireeGPUReorderWorkgroupsStrategyEnumTranspose:
return "Transpose";
default:
llvm::report_fatal_error(
"unknown ReorderWorkgroupsStrategy variant");
}
},
// pybind overloads are tried in the order they were registered.
// As a result, enums used the default __str__ method instead of
// the custom one. Adding py::prepend() fixes this issue.
py::prepend());

mlir_attribute_subclass(iree_gpu_module, "ReorderWorkgroupsStrategyAttr",
ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr,
ireeGPUReorderWorkgroupsStrategyAttrGetTypeID)
.def_classmethod(
"get",
[](const py::object &, ireeGPUReorderWorkgroupsStrategyEnum value,
MlirContext ctx) {
[](const py::object &, uint32_t value, MlirContext ctx) {
return ireeGPUReorderWorkgroupsStrategyAttrGet(ctx, value);
},
"cls"_a, "value"_a, "ctx"_a = py::none(),
"Gets a gpu.reorder_workgroups_strategy from parameters.")
.def_property_readonly(
"value",
[](MlirAttribute self) -> ireeGPUReorderWorkgroupsStrategyEnum {
return ireeGPUReorderWorkgroupsStrategyAttrGetValue(self);
});
.def_property_readonly("raw_value",
ireeGPUReorderWorkgroupsStrategyAttrGetValue)
.def_property_readonly("value", [](MlirAttribute self) -> py::object {
uint32_t rawValue = ireeGPUReorderWorkgroupsStrategyAttrGetValue(self);
return py::module_::import(kGpuModuleImportPath)
.attr("ReorderWorkgroupsStrategy")(rawValue);
});

//===-------------------------------------------------------------------===//
// GPUPipelineOptionsAttr
Expand Down Expand Up @@ -129,4 +111,58 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
return attr;
return std::nullopt;
});

//===-------------------------------------------------------------------===//
// GPUMMAIntrinsicAttr
//===-------------------------------------------------------------------===//
mlir_attribute_subclass(iree_gpu_module, "MMAIntrinsicAttr",
ireeAttributeIsAGPUMMAIntrinsicAttr,
ireeGPUMMAIntrinsicAttrGetTypeID)
.def_classmethod(
"get",
[](const py::object &, uint32_t value, MlirContext ctx) {
return ireeGPUMMAIntrinsicAttrGet(ctx, value);
},
"cls"_a, "value"_a, "ctx"_a = py::none(),
"Gets a gpu.mma_intrinsic from parameters.")
.def_property_readonly("raw_value", ireeGPUMMAIntrinsicAttrGetValue)
.def_property_readonly("value",
[](MlirAttribute self) -> py::object {
uint32_t rawValue =
ireeGPUMMAIntrinsicAttrGetValue(self);
return py::module_::import(kGpuModuleImportPath)
.attr("MMAIntrinsic")(rawValue);
})
.def_property_readonly("mma", [](MlirAttribute self) -> MlirAttribute {
uint32_t value = ireeGPUMMAIntrinsicAttrGetValue(self);
return ireeGPUMMAAttrGet(mlirAttributeGetContext(self), value);
});

mlir_attribute_subclass(iree_gpu_module, "MMAAttr",
ireeAttributeIsAGPUMMAAttr, ireeGPUMMAAttrGetTypeID)
.def_classmethod(
"get",
[](const py::object &, uint32_t value, MlirContext ctx) {
return ireeGPUMMAAttrGet(ctx, value);
},
"cls"_a, "value"_a, "ctx"_a = py::none(),
"Gets a gpu.mma from parameters.")
.def_property_readonly(
"abc_element_types",
[](MlirAttribute self) -> py::tuple {
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.aElementType, info.bElementType,
info.cElementType);
})
.def_property_readonly(
"abc_vector_types",
[](MlirAttribute self) -> py::tuple {
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.aVectorType, info.bVectorType,
info.cVectorType);
})
.def_property_readonly("mnk_shape", [](MlirAttribute self) -> py::tuple {
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.mElements, info.nElements, info.kElements);
});
}
45 changes: 45 additions & 0 deletions compiler/bindings/python/test/ir/dialects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def gpu_pipeline_options_attr():
reorder_attr = iree_gpu.ReorderWorkgroupsStrategyAttr.get(
iree_gpu.ReorderWorkgroupsStrategy.Transpose, ctx
)
assert reorder_attr.value == iree_gpu.ReorderWorkgroupsStrategy.Transpose

gpu_attr = iree_gpu.PipelineOptionsAttr.get(
True,
False,
Expand Down Expand Up @@ -86,3 +88,46 @@ def gpu_pipeline_options_attr():
# unfortunately not `is`
== iree_gpu.ReorderWorkgroupsStrategy.Transpose
)


@lambda _: _()
def mma_intrinsic_attr():
with ir.Context() as ctx, ir.Location.unknown():
module = ir.Module.create()
with ir.InsertionPoint(module.body):
mma_intrinsic_attr = iree_gpu.MMAIntrinsicAttr.get(
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ctx
)
assert mma_intrinsic_attr is not None
assert (
str(mma_intrinsic_attr)
== "#iree_gpu<mma_intrinsic MFMA_F32_32x32x8_F16>"
)

raw_value = mma_intrinsic_attr.raw_value
assert raw_value == iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
value = mma_intrinsic_attr.value
assert str(value) == "MFMA_F32_32x32x8_F16"
assert int(value) == raw_value

mma_attr = iree_gpu.MMAAttr.get(raw_value, ctx)
assert mma_attr is not None

f16 = ir.F16Type.get()
f32 = ir.F32Type.get()
a_type, b_type, c_type = mma_attr.abc_element_types
assert a_type == f16
assert b_type == f16
assert c_type == f32

vec_4xf16 = ir.VectorType.get((4,), f16)
a_vec_type, b_vec_type, _c_vec_type = mma_attr.abc_vector_types
assert a_vec_type == vec_4xf16
assert b_vec_type == vec_4xf16

M, N, K = mma_attr.mnk_shape
assert M == 32
assert N == 32
assert K == 8

assert mma_intrinsic_attr.mma == mma_attr
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/API/Internal/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ iree_compiler_cc_library(
deps = [
"//compiler/bindings/c:headers",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:CAPIIRHeaders",
],
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/API/Internal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ iree_cc_library(
"IREEGPUDialectCAPI.cpp"
DEPS
IREELLVMIncludeSetup
LLVMSupport
kuhar marked this conversation as resolved.
Show resolved Hide resolved
MLIRCAPIIR
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::bindings::c::headers
Expand Down
94 changes: 75 additions & 19 deletions compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cstdint>
#include <type_traits>
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/dialects/iree_gpu.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
Expand Down Expand Up @@ -84,20 +87,6 @@ MlirTypeID ireeGPUPipelineOptionsAttrGetTypeID() {
mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::getTypeID());
}

static_assert(
static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumNone) ==
static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
ReorderWorkgroupsStrategy::None) &&
static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumTranspose) ==
static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
ReorderWorkgroupsStrategy::Transpose) &&
static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumTranspose) ==
mlir::iree_compiler::IREE::GPU::
getMaxEnumValForReorderWorkgroupsStrategy(),
"ireeGPUReorderWorkgroupsStrategyEnum and "
"mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy definitions "
"have diverged");

bool ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(MlirAttribute attr) {
return llvm::isa<
mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>(
Expand All @@ -109,8 +98,15 @@ MlirTypeID ireeGPUReorderWorkgroupsStrategyAttrGetTypeID() {
getTypeID());
}

MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet(
MlirContext mlirCtx, ireeGPUReorderWorkgroupsStrategyEnum value) {
static_assert(
std::is_same_v<
uint32_t,
std::underlying_type_t<
mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy>>,
"Enum type changed");

MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet(MlirContext mlirCtx,
uint32_t value) {
mlir::MLIRContext *ctx = unwrap(mlirCtx);
return wrap(
mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr::get(
Expand All @@ -119,12 +115,72 @@ MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet(
value)));
}

ireeGPUReorderWorkgroupsStrategyEnum
ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr) {
uint32_t ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr) {
assert(ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(attr) &&
"attr is not a GPUReorderWorkgroupsStrategyAttr");
return static_cast<ireeGPUReorderWorkgroupsStrategyEnum>(
return static_cast<uint32_t>(
llvm::cast<mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>(
unwrap(attr))
.getValue());
}

bool ireeAttributeIsAGPUMMAIntrinsicAttr(MlirAttribute attr) {
return llvm::isa<mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr>(
unwrap(attr));
}

MlirTypeID ireeGPUMMAIntrinsicAttrGetTypeID() {
return wrap(mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr::getTypeID());
}

static_assert(
std::is_same_v<uint32_t, std::underlying_type_t<
mlir::iree_compiler::IREE::GPU::MMAIntrinsic>>,
"Enum type changed");

MlirAttribute ireeGPUMMAIntrinsicAttrGet(MlirContext mlirCtx, uint32_t value) {
mlir::MLIRContext *ctx = unwrap(mlirCtx);
return wrap(mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr::get(
ctx, static_cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsic>(value)));
}

uint32_t ireeGPUMMAIntrinsicAttrGetValue(MlirAttribute attr) {
assert(ireeAttributeIsAGPUMMAIntrinsicAttr(attr) &&
"attr is not a GPUMMAIntrinsicAttr");
return static_cast<uint32_t>(
llvm::cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr>(unwrap(attr))
.getValue());
}

bool ireeAttributeIsAGPUMMAAttr(MlirAttribute attr) {
return llvm::isa<mlir::iree_compiler::IREE::GPU::MMAAttr>(unwrap(attr));
}

MlirTypeID ireeGPUMMAAttrGetTypeID() {
return wrap(mlir::iree_compiler::IREE::GPU::MMAAttr::getTypeID());
}

MlirAttribute ireeGPUMMAAttrGet(MlirContext mlirCtx, uint32_t value) {
mlir::MLIRContext *ctx = unwrap(mlirCtx);
return wrap(mlir::iree_compiler::IREE::GPU::MMAAttr::get(
ctx, static_cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsic>(value)));
}

ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr) {
assert(ireeAttributeIsAGPUMMAAttr(attr) && "attr is not a MMAAttr");
auto mma = llvm::cast<mlir::iree_compiler::IREE::GPU::MMAAttr>(unwrap(attr));

ireeGPUMMAInfo info = {};
auto [aType, bType, cType] = mma.getABCElementTypes();
info.aElementType = wrap(aType);
info.bElementType = wrap(bType);
info.cElementType = wrap(cType);

auto [aVecType, bVecType, cVecType] = mma.getABCVectorTypes();
info.aVectorType = wrap(aVecType);
info.bVectorType = wrap(bVecType);
info.cVectorType = wrap(cVecType);

std::tie(info.mElements, info.nElements, info.kElements) = mma.getMNKShape();
return info;
}
Loading
Loading