Skip to content

Commit

Permalink
[torch] Switch to tablegen pass generation (#18226)
Browse files Browse the repository at this point in the history
This switches the pass generation defintion to tablegen. The cleanup
includes switching passes to follow the `create*Pass` naming convention.
  • Loading branch information
marbre authored Aug 16, 2024
1 parent 12e2eb4 commit 878a99b
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "compiler/plugins/input/Torch/InputConversion/PassDetail.h"

#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "mlir/Transforms/DialectConversion.h"
Expand All @@ -17,6 +15,9 @@

namespace mlir::iree_compiler::TorchInput {

#define GEN_PASS_DEF_BITCASTQUANTTENSORPASS
#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc"

namespace {

class BitCastQuantizedMatmul
Expand Down Expand Up @@ -105,8 +106,8 @@ class BitCastQuantizedMatmul
} // namespace

namespace {
class BitCastQuantTensorPass
: public BitCastQuantTensorPassBase<BitCastQuantTensorPass> {
class BitCastQuantTensorPass final
: public impl::BitCastQuantTensorPassBase<BitCastQuantTensorPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Flow::FlowDialect>();
registry.insert<torch::Torch::TorchDialect>();
Expand All @@ -124,9 +125,4 @@ class BitCastQuantTensorPass
};
} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createBitCastQuantTensorPass() {
return std::make_unique<BitCastQuantTensorPass>();
}

} // namespace mlir::iree_compiler::TorchInput
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ iree_cc_library(
NAME
PassHeaders
HDRS
"PassDetail.h"
"Passes.h"
"Passes.h.inc"
DEPS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <cstdint>
#include <numeric>

#include "compiler/plugins/input/Torch/InputConversion/PassDetail.h"
#include "compiler/plugins/input/Torch/InputConversion/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
Expand All @@ -20,6 +19,9 @@

namespace mlir::iree_compiler::TorchInput {

#define GEN_PASS_DEF_CONVERTTMTENSORTOLINALGEXTPASS
#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc"

namespace {

template <typename SrcOpTy, typename TargetOpTy>
Expand Down Expand Up @@ -156,8 +158,10 @@ namespace {
// Pass
//===----------------------------------------------------------------------===//

struct ConvertTMTensorToLinalgExtPass
: public ConvertTMTensorToLinalgExtBase<ConvertTMTensorToLinalgExtPass> {
class ConvertTMTensorToLinalgExtPass final
: public impl::ConvertTMTensorToLinalgExtPassBase<
ConvertTMTensorToLinalgExtPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::LinalgExt::IREELinalgExtDialect>();
registry.insert<tensor::TensorDialect>();
Expand Down Expand Up @@ -190,9 +194,4 @@ struct ConvertTMTensorToLinalgExtPass
};
} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvertTMTensorToLinalgExtPass() {
return std::make_unique<ConvertTMTensorToLinalgExtPass>();
}

} // namespace mlir::iree_compiler::TorchInput
12 changes: 6 additions & 6 deletions compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "compiler/plugins/input/Torch/InputConversion/PassDetail.h"
#include "compiler/plugins/input/Torch/InputConversion/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
Expand All @@ -25,6 +24,9 @@ namespace TorchConversion = mlir::torch::TorchConversion;

namespace mlir::iree_compiler::TorchInput {

#define GEN_PASS_DEF_FUNCCONVERSIONPASS
#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc"

namespace {

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -545,7 +547,9 @@ void createCoarseFencesSyncWrapper(StringRef syncFunctionName,

} // namespace

struct FuncConversionPass : public FuncConversionBase<FuncConversionPass> {
class FuncConversionPass final
: public impl::FuncConversionPassBase<FuncConversionPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mlir::tensor::TensorDialect>();
registry.insert<IREE::HAL::HALDialect>();
Expand Down Expand Up @@ -758,8 +762,4 @@ struct FuncConversionPass : public FuncConversionBase<FuncConversionPass> {
}
};

std::unique_ptr<OperationPass<ModuleOp>> createFuncConversionPass() {
return std::make_unique<FuncConversionPass>();
}

} // namespace mlir::iree_compiler::TorchInput
21 changes: 0 additions & 21 deletions compiler/plugins/input/Torch/InputConversion/PassDetail.h

This file was deleted.

14 changes: 3 additions & 11 deletions compiler/plugins/input/Torch/InputConversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,6 @@ struct TorchToIREELoweringPipelineOptions
llvm::cl::desc("Use strict symbolic shapes."), llvm::cl::init(true)};
};

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createBitCastQuantTensorPass();

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvertTMTensorToLinalgExtPass();

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createSetStrictSymbolicShapesPass();

std::unique_ptr<OperationPass<ModuleOp>> createFuncConversionPass();

// Creates a pipeline that lowers from the torch backend contract to IREE.
// This is based on the torch-backend-to-linalg-on-tensors-backend-pipeline
// pipeline in torch-mlir but includes IREE specific lowerings.
Expand All @@ -41,6 +30,9 @@ void createTorchToIREEPipeline(
// Register all Passes
//===----------------------------------------------------------------------===//

#define GEN_PASS_DECL
#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc" // IWYU pragma: keep

void registerTMTensorConversionPasses();

} // namespace mlir::iree_compiler::TorchInput
Expand Down
8 changes: 2 additions & 6 deletions compiler/plugins/input/Torch/InputConversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,21 @@ include "mlir/Pass/PassBase.td"
def BitCastQuantTensorPass :
InterfacePass<"torch-iree-bitcast-quant-tensor", "mlir::FunctionOpInterface"> {
let summary = "Bitcasts i8 packed tensors of sub-byte types to the actual bit width";
let constructor = "mlir::iree_compiler::TorchInput::createBitCastQuantTensorPass()";
}

def ConvertTMTensorToLinalgExt :
def ConvertTMTensorToLinalgExtPass :
InterfacePass<"torch-iree-tm-tensor-to-linalg-ext", "mlir::FunctionOpInterface"> {
let summary = "Convert from TMTensor ops to LinalgExt ops on tensors";
let constructor = "mlir::iree_compiler::TorchInput::createConvertTMTensorToLinalgExtPass()";
}

def SetStrictSymbolicShapesPass :
InterfacePass<"torch-iree-set-strict-symbolic-shapes", "mlir::FunctionOpInterface"> {
let summary = "Adds the attribute indicating strict symbolic shapes in Torch IR";
let constructor = "mlir::iree_compiler::TorchInput::createSetStrictSymbolicShapesPass()";
}

def FuncConversion :
def FuncConversionPass :
Pass<"torch-iree-func-conversion", "ModuleOp"> {
let summary = "Finalizes conversion from torch to IREE";
let constructor = "mlir::iree_compiler::TorchInput::createFuncConversionPass()";
let description = [{
Conversion pass for finalizing functions and ABI. Replaces the generic
torch-func-backend-type-conversion pass.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
//
//===----------------------------------------------------------------------===//

#include "compiler/plugins/input/Torch/InputConversion/PassDetail.h"
#include "compiler/plugins/input/Torch/InputConversion/Passes.h"
#include "llvm/ADT/StringRef.h"

Expand All @@ -21,19 +20,18 @@ static const llvm::StringLiteral kStrictSymbolsMarker =

namespace mlir::iree_compiler::TorchInput {

namespace {
struct SetStrictSymbolicShapesPass
: public SetStrictSymbolicShapesPassBase<SetStrictSymbolicShapesPass> {
#define GEN_PASS_DEF_SETSTRICTSYMBOLICSHAPESPASS
#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc"

namespace {
class SetStrictSymbolicShapesPass final
: public impl::SetStrictSymbolicShapesPassBase<
SetStrictSymbolicShapesPass> {
public:
void runOnOperation() override {
getOperation()->setAttr(kStrictSymbolsMarker, UnitAttr::get(&getContext()));
}
};
} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createSetStrictSymbolicShapesPass() {
return std::make_unique<SetStrictSymbolicShapesPass>();
}

} // namespace mlir::iree_compiler::TorchInput

0 comments on commit 878a99b

Please sign in to comment.