Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python][tuner] Add bindings for lowering config #19096

Merged
merged 4 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,17 @@ struct ireeGPUMMAInfo {

MLIR_CAPI_EXPORTED ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr);

MLIR_CAPI_EXPORTED bool
ireeAttributeIsAGPULoweringConfigAttr(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirTypeID ireeGPULoweringConfigAttrGetTypeID(void);

MLIR_CAPI_EXPORTED MlirAttribute ireeGPULoweringConfigAttrGet(
MlirContext mlirCtx, MlirAttribute attributesDictionary);

MLIR_CAPI_EXPORTED MlirAttribute
ireeGPULoweringConfigAttrGetAttributes(MlirAttribute attr);

#ifdef __cplusplus
}
#endif
Expand Down
23 changes: 23 additions & 0 deletions compiler/bindings/python/IREECompilerDialectsModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
//===-------------------------------------------------------------------===//
// GPUMMAIntrinsicAttr
//===-------------------------------------------------------------------===//

mlir_attribute_subclass(iree_gpu_module, "MMAIntrinsicAttr",
ireeAttributeIsAGPUMMAIntrinsicAttr,
ireeGPUMMAIntrinsicAttrGetTypeID)
Expand All @@ -138,6 +139,10 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
return ireeGPUMMAAttrGet(mlirAttributeGetContext(self), value);
});

//===-------------------------------------------------------------------===//
// GPUMMAAttr
//===-------------------------------------------------------------------===//

mlir_attribute_subclass(iree_gpu_module, "MMAAttr",
ireeAttributeIsAGPUMMAAttr, ireeGPUMMAAttrGetTypeID)
.def_classmethod(
Expand Down Expand Up @@ -165,4 +170,22 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
return py::make_tuple(info.mElements, info.nElements, info.kElements);
});

//===-------------------------------------------------------------------===//
// GPULoweringConfigAttr
//===-------------------------------------------------------------------===//

mlir_attribute_subclass(iree_gpu_module, "LoweringConfigAttr",
ireeAttributeIsAGPULoweringConfigAttr,
ireeGPULoweringConfigAttrGetTypeID)
.def_classmethod(
"get",
[](const py::object &, MlirAttribute attributeDictionary,
MlirContext ctx) {
return ireeGPULoweringConfigAttrGet(ctx, attributeDictionary);
},
"cls"_a, "value"_a, "ctx"_a = py::none(),
"Gets a gpu.lowering_config from parameters.")
.def_property_readonly("attributes",
ireeGPULoweringConfigAttrGetAttributes);
}
12 changes: 12 additions & 0 deletions compiler/bindings/python/test/ir/dialects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,15 @@ def mma_intrinsic_attr():
assert K == 8

assert mma_intrinsic_attr.mma == mma_attr


@lambda _: _()
def lowering_config_attr():
with ir.Context() as ctx, ir.Location.unknown():
module = ir.Module.create()
with ir.InsertionPoint(module.body):
attributes = ir.DictAttr.get({"reduction": ir.ArrayAttr.get([])}, ctx)
kuhar marked this conversation as resolved.
Show resolved Hide resolved
lowering_config = iree_gpu.LoweringConfigAttr.get(attributes, ctx)
assert lowering_config is not None

assert lowering_config.attributes == attributes
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/API/Internal/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -139,5 +139,6 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:IR",
],
)
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/API/Internal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ iree_cc_library(
DEPS
IREELLVMIncludeSetup
MLIRCAPIIR
MLIRIR
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::bindings::c::headers
PUBLIC
Expand Down
29 changes: 29 additions & 0 deletions compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cassert>
#include <cstdint>
#include <type_traits>
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/dialects/iree_gpu.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/BuiltinAttributes.h"

bool ireeAttributeIsAGPUPipelineOptionsAttr(MlirAttribute attr) {
return llvm::isa<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
Expand Down Expand Up @@ -184,3 +187,29 @@ ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr) {
std::tie(info.mElements, info.nElements, info.kElements) = mma.getMNKShape();
return info;
}

bool ireeAttributeIsAGPULoweringConfigAttr(MlirAttribute attr) {
return llvm::isa<mlir::iree_compiler::IREE::GPU::LoweringConfigAttr>(
unwrap(attr));
}

MlirTypeID ireeGPULoweringConfigAttrGetTypeID() {
return wrap(mlir::iree_compiler::IREE::GPU::LoweringConfigAttr::getTypeID());
}

MlirAttribute ireeGPULoweringConfigAttrGet(MlirContext mlirCtx,
MlirAttribute attributesDictionary) {
assert(mlirAttributeIsADictionary(attributesDictionary));
auto attributes =
llvm::cast<mlir::DictionaryAttr>(unwrap(attributesDictionary));
mlir::MLIRContext *ctx = unwrap(mlirCtx);
return wrap(
mlir::iree_compiler::IREE::GPU::LoweringConfigAttr::get(ctx, attributes));
}

MlirAttribute ireeGPULoweringConfigAttrGetAttributes(MlirAttribute attr) {
assert(ireeAttributeIsAGPULoweringConfigAttr(attr));
return wrap(llvm::cast<mlir::iree_compiler::IREE::GPU::LoweringConfigAttr>(
unwrap(attr))
.getAttributes());
}
Loading