Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simple dot transpose pattern #1317

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,38 @@ struct TrunciSimplifierPattern : public OpRewritePattern<arith::TruncIOp> {
}
};

// Simplifier dot and transpose op pattern, an example as following:
// from: Dot(A^T, B)^T
// to: Dot(B^T, A)
struct SimplifierDotTransposePattern
: public OpRewritePattern<mhlo::TransposeOp> {
using OpRewritePattern<mhlo::TransposeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(mhlo::TransposeOp op,
PatternRewriter& rewriter) const override {
auto loc = op->getLoc();
Value input = op->getOperand(0);
auto dotOp = input.getDefiningOp<mhlo::DotOp>();
if (!dotOp) {
return failure();
}
auto lhs = dotOp.getLhs();
auto rhs = dotOp.getRhs();
auto lhsTranspose = lhs.getDefiningOp<mhlo::TransposeOp>();
if (!lhsTranspose) {
llvm::dbgs() << "hls should be transposed tensor\n";
return failure();
}
auto rhsTranspose = rewriter.create<mhlo::TransposeOp>(
loc, rhs, lhsTranspose.getPermutation());
auto newDotOp = rewriter.create<mhlo::DotOp>(
loc, op.getType(), rhsTranspose, lhsTranspose.getOperand(),
dotOp.getPrecisionConfigAttr());
rewriter.replaceOp(op, newDotOp.getResult());
newDotOp->getParentOp()->dump();
return success();
}
};

// Simplify index_cast pattern. An examples as following:
// %0 = arith.index_cast %arg0 : index to i32
// %1 = arith.index_cast %0 : i32 to index
Expand Down Expand Up @@ -718,7 +750,8 @@ void populateDiscAlgebraicSimplifierPatterns(RewritePatternSet& patterns) {
SimplifierFromElementsPattern,
TrunciSimplifierPattern,
IndexCastSimplifierPattern,
SimplifierGetDimensionSizePattern
SimplifierGetDimensionSizePattern,
SimplifierDotTransposePattern
>(patterns.getContext());
if (isMemIntensiveOptExperimentalEnabled()) {
// Will be enabled by default after a set of robustness testing.
Expand Down
Loading