From 878a99b21ecefd7e259c05dc01bf570ece043f23 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 16 Aug 2024 06:39:55 +0200 Subject: [PATCH] [torch] Switch to tablegen pass generation (#18226) This switches the pass generation defintion to tablegen. The cleanup includes switching passes to follow the `create*Pass` naming convention. --- .../InputConversion/BitCastQuantTensor.cpp | 14 +++++-------- .../Torch/InputConversion/CMakeLists.txt | 1 - .../ConvertTMTensorToLinalgExt.cpp | 15 +++++++------ .../Torch/InputConversion/FuncConversion.cpp | 12 +++++------ .../input/Torch/InputConversion/PassDetail.h | 21 ------------------- .../input/Torch/InputConversion/Passes.h | 14 +++---------- .../input/Torch/InputConversion/Passes.td | 8 ++----- .../SetStrictSymbolicShapes.cpp | 16 +++++++------- 8 files changed, 30 insertions(+), 71 deletions(-) delete mode 100644 compiler/plugins/input/Torch/InputConversion/PassDetail.h diff --git a/compiler/plugins/input/Torch/InputConversion/BitCastQuantTensor.cpp b/compiler/plugins/input/Torch/InputConversion/BitCastQuantTensor.cpp index 0dd8c68f73dc..6a3dccdb97f9 100644 --- a/compiler/plugins/input/Torch/InputConversion/BitCastQuantTensor.cpp +++ b/compiler/plugins/input/Torch/InputConversion/BitCastQuantTensor.cpp @@ -4,8 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "compiler/plugins/input/Torch/InputConversion/PassDetail.h" - #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "mlir/Transforms/DialectConversion.h" @@ -17,6 +15,9 @@ namespace mlir::iree_compiler::TorchInput { +#define GEN_PASS_DEF_BITCASTQUANTTENSORPASS +#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc" + namespace { class BitCastQuantizedMatmul @@ -105,8 +106,8 @@ class BitCastQuantizedMatmul } // namespace namespace { -class BitCastQuantTensorPass - : public BitCastQuantTensorPassBase { +class BitCastQuantTensorPass final + : public impl::BitCastQuantTensorPassBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); @@ -124,9 +125,4 @@ class BitCastQuantTensorPass }; } // namespace -std::unique_ptr> -createBitCastQuantTensorPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::TorchInput diff --git a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt index b1fdbb816dc8..1db408527651 100644 --- a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt +++ b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt @@ -19,7 +19,6 @@ iree_cc_library( NAME PassHeaders HDRS - "PassDetail.h" "Passes.h" "Passes.h.inc" DEPS diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp index f88e4daf433b..ec5a0a7ec675 100644 --- a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp @@ -7,7 +7,6 @@ #include #include -#include "compiler/plugins/input/Torch/InputConversion/PassDetail.h" #include "compiler/plugins/input/Torch/InputConversion/Passes.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" @@ -20,6 +19,9 @@ namespace mlir::iree_compiler::TorchInput { +#define GEN_PASS_DEF_CONVERTTMTENSORTOLINALGEXTPASS +#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc" + namespace { template @@ -156,8 +158,10 @@ namespace { // Pass //===----------------------------------------------------------------------===// -struct ConvertTMTensorToLinalgExtPass - : public ConvertTMTensorToLinalgExtBase { +class ConvertTMTensorToLinalgExtPass final + : public impl::ConvertTMTensorToLinalgExtPassBase< + ConvertTMTensorToLinalgExtPass> { +public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); @@ -190,9 +194,4 @@ struct ConvertTMTensorToLinalgExtPass }; } // namespace -std::unique_ptr> -createConvertTMTensorToLinalgExtPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::TorchInput diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp index 553dbdeecf5d..01c7a6e7f9da 100644 --- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp +++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "compiler/plugins/input/Torch/InputConversion/PassDetail.h" #include "compiler/plugins/input/Torch/InputConversion/Passes.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" @@ -25,6 +24,9 @@ namespace TorchConversion = mlir::torch::TorchConversion; namespace mlir::iree_compiler::TorchInput { +#define GEN_PASS_DEF_FUNCCONVERSIONPASS +#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc" + namespace { //===----------------------------------------------------------------------===// @@ -545,7 +547,9 @@ void createCoarseFencesSyncWrapper(StringRef syncFunctionName, } // namespace -struct FuncConversionPass : public FuncConversionBase { +class FuncConversionPass final + : public impl::FuncConversionPassBase { +public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); @@ -758,8 +762,4 @@ struct FuncConversionPass : public FuncConversionBase { } }; -std::unique_ptr> createFuncConversionPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::TorchInput diff --git a/compiler/plugins/input/Torch/InputConversion/PassDetail.h b/compiler/plugins/input/Torch/InputConversion/PassDetail.h deleted file mode 100644 index e3e68b190846..000000000000 --- a/compiler/plugins/input/Torch/InputConversion/PassDetail.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_PLUGINS_INPUT_TORCH_INPUTCONVERSION_PASSDETAIL_H_ -#define IREE_COMPILER_PLUGINS_INPUT_TORCH_INPUTCONVERSION_PASSDETAIL_H_ - -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Pass/Pass.h" - -namespace mlir::iree_compiler::TorchInput { - -#define GEN_PASS_CLASSES -#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc" - -} // namespace mlir::iree_compiler::TorchInput - -#endif // IREE_COMPILER_PLUGINS_INPUT_TORCH_INPUTCONVERSION_PASSDETAIL_H_ diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.h b/compiler/plugins/input/Torch/InputConversion/Passes.h index 1abfab2c2ecb..d0e4b4f7f0d8 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.h +++ b/compiler/plugins/input/Torch/InputConversion/Passes.h @@ -20,17 +20,6 @@ struct TorchToIREELoweringPipelineOptions llvm::cl::desc("Use strict symbolic shapes."), llvm::cl::init(true)}; }; -std::unique_ptr> -createBitCastQuantTensorPass(); - -std::unique_ptr> -createConvertTMTensorToLinalgExtPass(); - -std::unique_ptr> -createSetStrictSymbolicShapesPass(); - -std::unique_ptr> createFuncConversionPass(); - // Creates a pipeline that lowers from the torch backend contract to IREE. // This is based on the torch-backend-to-linalg-on-tensors-backend-pipeline // pipeline in torch-mlir but includes IREE specific lowerings. @@ -41,6 +30,9 @@ void createTorchToIREEPipeline( // Register all Passes //===----------------------------------------------------------------------===// +#define GEN_PASS_DECL +#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc" // IWYU pragma: keep + void registerTMTensorConversionPasses(); } // namespace mlir::iree_compiler::TorchInput diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.td b/compiler/plugins/input/Torch/InputConversion/Passes.td index dd527a11a456..91b7792d08b3 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.td +++ b/compiler/plugins/input/Torch/InputConversion/Passes.td @@ -12,25 +12,21 @@ include "mlir/Pass/PassBase.td" def BitCastQuantTensorPass : InterfacePass<"torch-iree-bitcast-quant-tensor", "mlir::FunctionOpInterface"> { let summary = "Bitcasts i8 packed tensors of sub-byte types to the actual bit width"; - let constructor = "mlir::iree_compiler::TorchInput::createBitCastQuantTensorPass()"; } -def ConvertTMTensorToLinalgExt : +def ConvertTMTensorToLinalgExtPass : InterfacePass<"torch-iree-tm-tensor-to-linalg-ext", "mlir::FunctionOpInterface"> { let summary = "Convert from TMTensor ops to LinalgExt ops on tensors"; - let constructor = "mlir::iree_compiler::TorchInput::createConvertTMTensorToLinalgExtPass()"; } def SetStrictSymbolicShapesPass : InterfacePass<"torch-iree-set-strict-symbolic-shapes", "mlir::FunctionOpInterface"> { let summary = "Adds the attribute indicating strict symbolic shapes in Torch IR"; - let constructor = "mlir::iree_compiler::TorchInput::createSetStrictSymbolicShapesPass()"; } -def FuncConversion : +def FuncConversionPass : Pass<"torch-iree-func-conversion", "ModuleOp"> { let summary = "Finalizes conversion from torch to IREE"; - let constructor = "mlir::iree_compiler::TorchInput::createFuncConversionPass()"; let description = [{ Conversion pass for finalizing functions and ABI. Replaces the generic torch-func-backend-type-conversion pass. diff --git a/compiler/plugins/input/Torch/InputConversion/SetStrictSymbolicShapes.cpp b/compiler/plugins/input/Torch/InputConversion/SetStrictSymbolicShapes.cpp index 51dbc47ffd2d..40c6cfea63b8 100644 --- a/compiler/plugins/input/Torch/InputConversion/SetStrictSymbolicShapes.cpp +++ b/compiler/plugins/input/Torch/InputConversion/SetStrictSymbolicShapes.cpp @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -#include "compiler/plugins/input/Torch/InputConversion/PassDetail.h" #include "compiler/plugins/input/Torch/InputConversion/Passes.h" #include "llvm/ADT/StringRef.h" @@ -21,19 +20,18 @@ static const llvm::StringLiteral kStrictSymbolsMarker = namespace mlir::iree_compiler::TorchInput { -namespace { -struct SetStrictSymbolicShapesPass - : public SetStrictSymbolicShapesPassBase { +#define GEN_PASS_DEF_SETSTRICTSYMBOLICSHAPESPASS +#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc" +namespace { +class SetStrictSymbolicShapesPass final + : public impl::SetStrictSymbolicShapesPassBase< + SetStrictSymbolicShapesPass> { +public: void runOnOperation() override { getOperation()->setAttr(kStrictSymbolsMarker, UnitAttr::get(&getContext())); } }; } // namespace -std::unique_ptr> -createSetStrictSymbolicShapesPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::TorchInput