Skip to content

Commit

Permalink
[python][tuner] Add bindings for lowering config (#19096)
Browse files Browse the repository at this point in the history
Keeping it simple for now -- no attribute methods are exposed beyond
access to the attribute dictionary.
  • Loading branch information
kuhar authored Nov 11, 2024
1 parent 55b998a commit 55f2fce
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 0 deletions.
11 changes: 11 additions & 0 deletions compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,17 @@ struct ireeGPUMMAInfo {

MLIR_CAPI_EXPORTED ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr);

MLIR_CAPI_EXPORTED bool
ireeAttributeIsAGPULoweringConfigAttr(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID ireeGPULoweringConfigAttrGetTypeID(void);

MLIR_CAPI_EXPORTED MlirAttribute ireeGPULoweringConfigAttrGet(
MlirContext mlirCtx, MlirAttribute attributesDictionary);

MLIR_CAPI_EXPORTED MlirAttribute
ireeGPULoweringConfigAttrGetAttributes(MlirAttribute attr);

#ifdef __cplusplus
}
#endif
Expand Down
23 changes: 23 additions & 0 deletions compiler/bindings/python/IREECompilerDialectsModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
//===-------------------------------------------------------------------===//
// GPUMMAIntrinsicAttr
//===-------------------------------------------------------------------===//

mlir_attribute_subclass(iree_gpu_module, "MMAIntrinsicAttr",
ireeAttributeIsAGPUMMAIntrinsicAttr,
ireeGPUMMAIntrinsicAttrGetTypeID)
Expand All @@ -138,6 +139,10 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
return ireeGPUMMAAttrGet(mlirAttributeGetContext(self), value);
});

//===-------------------------------------------------------------------===//
// GPUMMAAttr
//===-------------------------------------------------------------------===//

mlir_attribute_subclass(iree_gpu_module, "MMAAttr",
ireeAttributeIsAGPUMMAAttr, ireeGPUMMAAttrGetTypeID)
.def_classmethod(
Expand Down Expand Up @@ -165,4 +170,22 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.mElements, info.nElements, info.kElements);
});

//===-------------------------------------------------------------------===//
// GPULoweringConfigAttr
//===-------------------------------------------------------------------===//

mlir_attribute_subclass(iree_gpu_module, "LoweringConfigAttr",
ireeAttributeIsAGPULoweringConfigAttr,
ireeGPULoweringConfigAttrGetTypeID)
.def_classmethod(
"get",
[](const py::object &, MlirAttribute attributeDictionary,
MlirContext ctx) {
return ireeGPULoweringConfigAttrGet(ctx, attributeDictionary);
},
"cls"_a, "value"_a, "ctx"_a = py::none(),
"Gets a gpu.lowering_config from parameters.")
.def_property_readonly("attributes",
ireeGPULoweringConfigAttrGetAttributes);
}
12 changes: 12 additions & 0 deletions compiler/bindings/python/test/ir/dialects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,15 @@ def mma_intrinsic_attr():
assert K == 8

assert mma_intrinsic_attr.mma == mma_attr


@lambda _: _()
def lowering_config_attr():
with ir.Context() as ctx, ir.Location.unknown():
module = ir.Module.create()
with ir.InsertionPoint(module.body):
attributes = ir.DictAttr.get({"reduction": ir.ArrayAttr.get([])}, ctx)
lowering_config = iree_gpu.LoweringConfigAttr.get(attributes, ctx)
assert lowering_config is not None

assert lowering_config.attributes == attributes
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/API/Internal/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -139,5 +139,6 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:IR",
],
)
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/API/Internal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ iree_cc_library(
DEPS
IREELLVMIncludeSetup
MLIRCAPIIR
MLIRIR
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::bindings::c::headers
PUBLIC
Expand Down
29 changes: 29 additions & 0 deletions compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cassert>
#include <cstdint>
#include <type_traits>
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/dialects/iree_gpu.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/BuiltinAttributes.h"

bool ireeAttributeIsAGPUPipelineOptionsAttr(MlirAttribute attr) {
return llvm::isa<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
Expand Down Expand Up @@ -184,3 +187,29 @@ ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr) {
std::tie(info.mElements, info.nElements, info.kElements) = mma.getMNKShape();
return info;
}

bool ireeAttributeIsAGPULoweringConfigAttr(MlirAttribute attr) {
return llvm::isa<mlir::iree_compiler::IREE::GPU::LoweringConfigAttr>(
unwrap(attr));
}

MlirTypeID ireeGPULoweringConfigAttrGetTypeID() {
return wrap(mlir::iree_compiler::IREE::GPU::LoweringConfigAttr::getTypeID());
}

MlirAttribute ireeGPULoweringConfigAttrGet(MlirContext mlirCtx,
MlirAttribute attributesDictionary) {
assert(mlirAttributeIsADictionary(attributesDictionary));
auto attributes =
llvm::cast<mlir::DictionaryAttr>(unwrap(attributesDictionary));
mlir::MLIRContext *ctx = unwrap(mlirCtx);
return wrap(
mlir::iree_compiler::IREE::GPU::LoweringConfigAttr::get(ctx, attributes));
}

MlirAttribute ireeGPULoweringConfigAttrGetAttributes(MlirAttribute attr) {
assert(ireeAttributeIsAGPULoweringConfigAttr(attr));
return wrap(llvm::cast<mlir::iree_compiler::IREE::GPU::LoweringConfigAttr>(
unwrap(attr))
.getAttributes());
}

0 comments on commit 55f2fce

Please sign in to comment.