From 551cd54462aef333ae1b950f9eeec5a36a632304 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 16 Aug 2024 06:40:20 +0200 Subject: [PATCH] [TOSA] Switch to tablegen pass generation (#18227) This switches the pass generation definition to tablegen. The cleanup includes switching passes to follow the `create*Pass` naming convention and introduces anonymous namespaces. --- .../input/TOSA/InputConversion/BUILD.bazel | 1 - .../input/TOSA/InputConversion/CMakeLists.txt | 1 - .../TOSA/InputConversion/Converti48Toi64.cpp | 16 ++++++------ .../input/TOSA/InputConversion/PassDetail.h | 25 ------------------- .../input/TOSA/InputConversion/Passes.cpp | 6 ++--- .../input/TOSA/InputConversion/Passes.h | 20 +++------------ .../input/TOSA/InputConversion/Passes.td | 12 +++------ .../TOSA/InputConversion/StripSignedness.cpp | 12 ++++----- .../TOSA/InputConversion/TosaToLinalgExt.cpp | 17 +++++++------ .../VerifyCompilerTOSAInputLegality.cpp | 17 +++++++------ 10 files changed, 42 insertions(+), 85 deletions(-) delete mode 100644 compiler/plugins/input/TOSA/InputConversion/PassDetail.h diff --git a/compiler/plugins/input/TOSA/InputConversion/BUILD.bazel b/compiler/plugins/input/TOSA/InputConversion/BUILD.bazel index 1f91a8f6ef0f..64f53168e1d1 100644 --- a/compiler/plugins/input/TOSA/InputConversion/BUILD.bazel +++ b/compiler/plugins/input/TOSA/InputConversion/BUILD.bazel @@ -30,7 +30,6 @@ iree_gentbl_cc_library( iree_compiler_cc_library( name = "PassHeaders", hdrs = [ - "PassDetail.h", "Passes.h", "Passes.h.inc", ], diff --git a/compiler/plugins/input/TOSA/InputConversion/CMakeLists.txt b/compiler/plugins/input/TOSA/InputConversion/CMakeLists.txt index b73fad01586d..bb9f3ad2c055 100644 --- a/compiler/plugins/input/TOSA/InputConversion/CMakeLists.txt +++ b/compiler/plugins/input/TOSA/InputConversion/CMakeLists.txt @@ -23,7 +23,6 @@ iree_cc_library( NAME PassHeaders HDRS - "PassDetail.h" "Passes.h" "Passes.h.inc" DEPS diff --git a/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp b/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp index 06725dcf082c..1249651a2e65 100644 --- a/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp +++ b/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp @@ -4,8 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" @@ -16,7 +14,13 @@ using namespace mlir; namespace mlir::iree_compiler { -class Converti48Toi64Pass : public Converti48Toi64Base { +#define GEN_PASS_DEF_CONVERTI48TOI64PASS +#include "compiler/plugins/input/TOSA/InputConversion/Passes.h.inc" + +namespace { + +class Converti48Toi64Pass final + : public impl::Converti48Toi64PassBase { public: explicit Converti48Toi64Pass() = default; void runOnOperation() override; @@ -174,9 +178,5 @@ void Converti48Toi64Pass::runOnOperation() { } } -std::unique_ptr> -createConverti48Toi64() { - return std::make_unique(); -} - +} // namespace } // namespace mlir::iree_compiler diff --git a/compiler/plugins/input/TOSA/InputConversion/PassDetail.h b/compiler/plugins/input/TOSA/InputConversion/PassDetail.h deleted file mode 100644 index 6444c75fd196..000000000000 --- a/compiler/plugins/input/TOSA/InputConversion/PassDetail.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_PLUGINS_INPUT_TOSA_INPUTCONVERSION_PASSDETAIL_H_ -#define IREE_COMPILER_PLUGINS_INPUT_TOSA_INPUTCONVERSION_PASSDETAIL_H_ - -#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Pass/Pass.h" - -namespace mlir::iree_compiler { - -#define GEN_PASS_CLASSES -#include "compiler/plugins/input/TOSA/InputConversion/Passes.h.inc" - -} // namespace mlir::iree_compiler - -#endif // IREE_COMPILER_PLUGINS_INPUT_TOSA_INPUTCONVERSION_PASSDETAIL_H_ diff --git a/compiler/plugins/input/TOSA/InputConversion/Passes.cpp b/compiler/plugins/input/TOSA/InputConversion/Passes.cpp index 451383ba76ee..c17b69470819 100644 --- a/compiler/plugins/input/TOSA/InputConversion/Passes.cpp +++ b/compiler/plugins/input/TOSA/InputConversion/Passes.cpp @@ -47,7 +47,7 @@ void buildTOSAInputConversionPassPipeline(OpPassManager &passManager) { passManager.addNestedPass(tosa::createTosaToArith()); passManager.addNestedPass(tosa::createTosaToTensor()); passManager.addNestedPass( - iree_compiler::createTosaToLinalgExt()); + iree_compiler::createTosaToLinalgExtPass()); passManager.addNestedPass(mlir::createCanonicalizerPass()); TosaToLinalgNamedOptions tosaToLinalgNamedOptions; @@ -55,7 +55,7 @@ void buildTOSAInputConversionPassPipeline(OpPassManager &passManager) { tosa::addTosaToLinalgPasses(passManager, TosaToLinalgOptions(), tosaToLinalgNamedOptions); passManager.addNestedPass( - iree_compiler::createConverti48Toi64()); + iree_compiler::createConverti48Toi64Pass()); // Sometimes we generate more TOSA operations during the lowering to linalg. passManager.addNestedPass(tosa::createTosaToArith()); @@ -74,7 +74,7 @@ void buildTOSAInputConversionPassPipeline(OpPassManager &passManager) { //---------------------------------------------------------------------------- // Entry dialect cleanup //---------------------------------------------------------------------------- - passManager.addPass(createVerifyCompilerTOSAInputLegality()); + passManager.addPass(createVerifyCompilerTOSAInputLegalityPass()); } namespace { diff --git a/compiler/plugins/input/TOSA/InputConversion/Passes.h b/compiler/plugins/input/TOSA/InputConversion/Passes.h index f47e4b10b478..f23d11b90ea3 100644 --- a/compiler/plugins/input/TOSA/InputConversion/Passes.h +++ b/compiler/plugins/input/TOSA/InputConversion/Passes.h @@ -29,27 +29,13 @@ void registerTOSAConversionPassPipeline(); // Set of patterns for materializing TOSA operations to linalg_ext. void populateTosaToLinalgExtPatterns(RewritePatternSet *patterns); -// Converts i48 to i64. -std::unique_ptr> -createConverti48Toi64(); - -// Strips the signed/unsigned portion off of tensors. -std::unique_ptr> -createStripSignednessPass(); - -// Converts TOSA operations to linalg_ext. -std::unique_ptr> -createTosaToLinalgExt(); - -// Verifies that a module only contains IR structures that are supported by the -// core compiler. -std::unique_ptr> -createVerifyCompilerTOSAInputLegality(); - //===----------------------------------------------------------------------===// // Register all Passes //===----------------------------------------------------------------------===// +#define GEN_PASS_DECL +#include "compiler/plugins/input/TOSA/InputConversion/Passes.h.inc" // IWYU pragma: export + void registerTOSAConversionPasses(); } // namespace mlir::iree_compiler diff --git a/compiler/plugins/input/TOSA/InputConversion/Passes.td b/compiler/plugins/input/TOSA/InputConversion/Passes.td index 7ea82d1e52ff..53294e14a940 100644 --- a/compiler/plugins/input/TOSA/InputConversion/Passes.td +++ b/compiler/plugins/input/TOSA/InputConversion/Passes.td @@ -9,22 +9,19 @@ include "mlir/Pass/PassBase.td" -def Converti48Toi64 : +def Converti48Toi64Pass : InterfacePass<"iree-tosa-convert-i48-to-i64", "mlir::FunctionOpInterface"> { let summary = "Converts all i48s to i64s"; - let constructor = "mlir::iree_compiler::createConverti48Toi64()"; } -def StripSignedness : +def StripSignednessPass : InterfacePass<"iree-tosa-strip-signedness", "mlir::FunctionOpInterface"> { let summary = "Legalizes ui tensors constants to uis"; - let constructor = "mlir::iree_compiler::createStripSignednessPass()"; } -def TosaToLinalgExt : +def TosaToLinalgExtPass : InterfacePass<"iree-tosa-to-linalg-ext", "mlir::FunctionOpInterface"> { let summary = "Convert TOSA operations to their equivalent linalg-ext operations."; - let constructor = "mlir::iree_compiler::createTosaToLinalgExt()"; let dependentDialects = [ "arith::ArithDialect", "linalg::LinalgDialect", @@ -33,10 +30,9 @@ def TosaToLinalgExt : ]; } -def VerifyCompilerTOSAInputLegality : +def VerifyCompilerTOSAInputLegalityPass : Pass<"iree-tosa-verify-compiler-input-legality", "ModuleOp"> { let summary = "Verifies that only supported IR constructs are passed to the compiler."; - let constructor = "mlir::iree_compiler::createVerifyCompilerTOSAInputLegality()"; } #endif // IREE_COMPILER_PLUGINS_INPUT_TOSA_INPUTCONVERSION_PASSES diff --git a/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp index d9097e321c61..0e32a1bfc28b 100644 --- a/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp +++ b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h" #include "compiler/plugins/input/TOSA/InputConversion/Passes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -12,9 +11,13 @@ namespace mlir::iree_compiler { +#define GEN_PASS_DEF_STRIPSIGNEDNESSPASS +#include "compiler/plugins/input/TOSA/InputConversion/Passes.h.inc" + namespace { -class StripSignednessPass : public StripSignednessBase { +class StripSignednessPass final + : public impl::StripSignednessPassBase { public: explicit StripSignednessPass() {} void runOnOperation() override; @@ -125,9 +128,4 @@ void StripSignednessPass::runOnOperation() { } // namespace -std::unique_ptr> -createStripSignednessPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler diff --git a/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp b/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp index 600b002b9add..72b227389fc5 100644 --- a/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp +++ b/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h" #include "compiler/plugins/input/TOSA/InputConversion/Passes.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" @@ -21,6 +20,11 @@ using namespace mlir::tosa; namespace mlir::iree_compiler { +#define GEN_PASS_DEF_TOSATOLINALGEXTPASS +#include "compiler/plugins/input/TOSA/InputConversion/Passes.h.inc" + +namespace { + // Converts tosa.scatter to the iree_linalg_ext.scatter operation. As the // LinalgExt version is not batched therefore we materialize the batch index // for each update. @@ -145,7 +149,9 @@ class ScatterConversion : public OpRewritePattern { } }; -struct TosaToLinalgExtPass : public TosaToLinalgExtBase { +class TosaToLinalgExtPass final + : public impl::TosaToLinalgExtPassBase { +public: void runOnOperation() override { RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); @@ -159,13 +165,10 @@ struct TosaToLinalgExtPass : public TosaToLinalgExtBase { } }; +} // namespace + void populateTosaToLinalgExtPatterns(RewritePatternSet *patterns) { patterns->add(patterns->getContext()); } -std::unique_ptr> -createTosaToLinalgExt() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler diff --git a/compiler/plugins/input/TOSA/InputConversion/VerifyCompilerTOSAInputLegality.cpp b/compiler/plugins/input/TOSA/InputConversion/VerifyCompilerTOSAInputLegality.cpp index dd0f27537e42..36f3cb51fe33 100644 --- a/compiler/plugins/input/TOSA/InputConversion/VerifyCompilerTOSAInputLegality.cpp +++ b/compiler/plugins/input/TOSA/InputConversion/VerifyCompilerTOSAInputLegality.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h" #include "compiler/plugins/input/TOSA/InputConversion/Passes.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Pass/Pass.h" @@ -13,9 +12,15 @@ namespace mlir::iree_compiler { -struct VerifyCompilerTOSAInputLegalityPass - : public VerifyCompilerTOSAInputLegalityBase< +#define GEN_PASS_DEF_VERIFYCOMPILERTOSAINPUTLEGALITYPASS +#include "compiler/plugins/input/TOSA/InputConversion/Passes.h.inc" + +namespace { + +class VerifyCompilerTOSAInputLegalityPass final + : public impl::VerifyCompilerTOSAInputLegalityPassBase< VerifyCompilerTOSAInputLegalityPass> { +public: void runOnOperation() override { auto *context = &getContext(); ConversionTarget conversionTarget(*context); @@ -63,9 +68,5 @@ struct VerifyCompilerTOSAInputLegalityPass } }; -std::unique_ptr> -createVerifyCompilerTOSAInputLegality() { - return std::make_unique(); -} - +} // namespace } // namespace mlir::iree_compiler