Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LinalgExt] Switch to new pass generation tablegen definitions. #18216

Merged
merged 1 commit into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ iree_compiler_cc_library(
"DecomposeIm2col.cpp",
"DecomposeWinogradPass.cpp",
"PadContractionToBlockSize.cpp",
"PassDetail.h",
"Passes.cpp",
"SplitReduction.cpp",
"TileAttention.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ iree_cc_library(
"DecomposeIm2col.cpp"
"DecomposeWinogradPass.cpp"
"PadContractionToBlockSize.cpp"
"PassDetail.h"
"Passes.cpp"
"SplitReduction.cpp"
"TileAttention.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -16,6 +15,9 @@

namespace mlir::iree_compiler::IREE::LinalgExt {

#define GEN_PASS_DEF_CONVERTCONV2DTOIM2COLOPPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

static bool hasAllOneValues(DenseIntElementsAttr attr) {
return llvm::all_of(
attr, [](APInt element) { return element.getSExtValue() == 1; });
Expand Down Expand Up @@ -322,8 +324,8 @@ class ConvertConv2DNchwFchw final
ControlFnTy controlFn;
};

struct ConvertConv2DToIm2ColOpPass
: ConvertConv2DToIm2ColOpBase<ConvertConv2DToIm2ColOpPass> {
struct ConvertConv2DToIm2ColOpPass final
: impl::ConvertConv2DToIm2ColOpPassBase<ConvertConv2DToIm2ColOpPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tensor::TensorDialect, IREELinalgExtDialect>();
}
Expand All @@ -345,9 +347,4 @@ void populateConv2DToIm2colOpPatterns(RewritePatternSet &patterns,
patterns.getContext(), std::move(controlFn));
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvertConv2DToIm2ColOpPass() {
return std::make_unique<ConvertConv2DToIm2ColOpPass>();
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/WinogradConstants.h"
Expand All @@ -27,6 +26,9 @@

namespace mlir::iree_compiler::IREE::LinalgExt {

#define GEN_PASS_DEF_CONVERTCONV2DTOWINOGRADPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

static const char kWinogradAttr[] = "__winograd_conv";

static bool hasAllOneValues(DenseIntElementsAttr attr) {
Expand Down Expand Up @@ -403,8 +405,11 @@ class ConvertConvToWinograd final : public OpRewritePattern<ConvOp> {
/// }
/// }
/// ```
struct ConvertConv2DToWinogradPass
: ConvertConv2DToWinogradBase<ConvertConv2DToWinogradPass> {
struct ConvertConv2DToWinogradPass final
: impl::ConvertConv2DToWinogradPassBase<ConvertConv2DToWinogradPass> {
using impl::ConvertConv2DToWinogradPassBase<
ConvertConv2DToWinogradPass>::ConvertConv2DToWinogradPassBase;

void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<linalg::LinalgDialect, IREE::LinalgExt::IREELinalgExtDialect>();
Expand All @@ -423,10 +428,4 @@ struct ConvertConv2DToWinogradPass
};

} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvertConv2DToWinogradPass() {
return std::make_unique<ConvertConv2DToWinogradPass>();
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
Expand All @@ -23,9 +22,10 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
namespace IREE = mlir::iree_compiler::IREE;
using namespace IREE::LinalgExt;
namespace mlir::iree_compiler::IREE::LinalgExt {

#define GEN_PASS_DEF_LINALGEXTTOLOOPSPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

/// Recursive method that lowers one dimension of the `TiledOpInterface` to
/// scalar loops at a time.
Expand Down Expand Up @@ -100,8 +100,8 @@ struct TilingInterfaceLowerToLoopsPattern : public RewritePattern {
//===----------------------------------------------------------------------===//

namespace {
struct LinalgExtToLoopsPass
: public LinalgExtToLoopsBase<LinalgExtToLoopsPass> {
struct LinalgExtToLoopsPass final
: impl::LinalgExtToLoopsPassBase<LinalgExtToLoopsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<linalg::LinalgDialect, mlir::arith::ArithDialect,
Expand All @@ -120,8 +120,4 @@ struct LinalgExtToLoopsPass
}
};
} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
IREE::LinalgExt::createLinalgExtToLoopsPass() {
return std::make_unique<LinalgExtToLoopsPass>();
}
} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand All @@ -16,6 +15,9 @@

namespace mlir::iree_compiler::IREE::LinalgExt {

#define GEN_PASS_DEF_DECOMPOSEATTENTIONPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

namespace {

// Computes a reduction along the rows of a 2d tensor of shape MxN
Expand Down Expand Up @@ -337,20 +339,16 @@ void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
}

namespace {
struct DecomposeAttentionPass
: public DecomposeAttentionBase<DecomposeAttentionPass> {
struct DecomposeAttentionPass final
: impl::DecomposeAttentionPassBase<DecomposeAttentionPass> {
using impl::DecomposeAttentionPassBase<
DecomposeAttentionPass>::DecomposeAttentionPassBase;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<
affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
linalg::LinalgDialect, scf::SCFDialect, tensor::TensorDialect>();
}
DecomposeAttentionPass() = default;
DecomposeAttentionPass(bool onlyTile, uint64_t tileSize) {
this->tileSize = tileSize;
}
DecomposeAttentionPass(const DecomposeAttentionPass &pass) {
tileSize = pass.tileSize;
}
void runOnOperation() override;
};
} // namespace
Expand All @@ -377,9 +375,4 @@ void DecomposeAttentionPass::runOnOperation() {
rewriter.replaceOp(onlineAtt, results.value());
});
}

std::unique_ptr<Pass> createDecomposeAttentionPass() {
return std::make_unique<DecomposeAttentionPass>();
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -16,6 +15,10 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::IREE::LinalgExt {

#define GEN_PASS_DEF_DECOMPOSEIM2COLPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

namespace {

/// Pattern to decompose the tiled im2col op.
Expand All @@ -37,7 +40,8 @@ struct DecomposeIm2col : public OpRewritePattern<Im2colOp> {
} // namespace

namespace {
struct DecomposeIm2colPass : public DecomposeIm2colBase<DecomposeIm2colPass> {
struct DecomposeIm2colPass final
: impl::DecomposeIm2colPassBase<DecomposeIm2colPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<
affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
Expand All @@ -58,10 +62,4 @@ void DecomposeIm2colPass::runOnOperation() {
return signalPassFailure();
}
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createDecomposeIm2colPass() {
return std::make_unique<DecomposeIm2colPass>();
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/WinogradConstants.h"
Expand All @@ -24,6 +23,10 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::IREE::LinalgExt {

#define GEN_PASS_DEF_DECOMPOSEWINOGRADTRANSFORMPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

namespace {

/// Pattern to remove unit dims from winograd ops after tililng. Tiling is
Expand Down Expand Up @@ -333,8 +336,8 @@ struct DecomposeWinogradOutputTransform
} // namespace

namespace {
struct DecomposeWinogradTransformPass
: public DecomposeWinogradTransformBase<DecomposeWinogradTransformPass> {
struct DecomposeWinogradTransformPass final
: impl::DecomposeWinogradTransformPassBase<DecomposeWinogradTransformPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<
affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
Expand Down Expand Up @@ -363,9 +366,4 @@ void DecomposeWinogradTransformPass::runOnOperation() {
}
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createDecomposeWinogradTransformPass() {
return std::make_unique<DecomposeWinogradTransformPass>();
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree-dialects/Dialect/Input/InputDialect.h"
#include "iree-dialects/Dialect/Input/InputOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -16,9 +15,10 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
namespace IREE = mlir::iree_compiler::IREE;
using namespace IREE::LinalgExt;
namespace mlir::iree_compiler::IREE::LinalgExt {

#define GEN_PASS_DEF_PADCONTRACTIONTOBLOCKSIZEPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

static Operation *sliceTensor(Location loc, Value expanded, Value original,
OpBuilder &builder) {
Expand Down Expand Up @@ -87,8 +87,11 @@ static bool padTensor(Location loc, OpOperand *operand,

namespace {

struct PadContractionToBlockSizePass
: public PadContractionToBlockSizeBase<PadContractionToBlockSizePass> {
struct PadContractionToBlockSizePass final
: impl::PadContractionToBlockSizePassBase<PadContractionToBlockSizePass> {
using impl::PadContractionToBlockSizePassBase<
PadContractionToBlockSizePass>::PadContractionToBlockSizePassBase;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Input::IREEInputDialect>();
}
Expand Down Expand Up @@ -126,8 +129,4 @@ struct PadContractionToBlockSizePass
}
};
} // namespace

std::unique_ptr<OperationPass<>>
IREE::LinalgExt::createPadContractionToBlockSizePass() {
return std::make_unique<PadContractionToBlockSizePass>();
}
} // namespace mlir::iree_compiler::IREE::LinalgExt

This file was deleted.

Loading
Loading