Revert "Add Scalarization Patterns for AtenToDtypeOp, AtenNegOp, …
Browse files Browse the repository at this point in the history
…`AtenRemainderTensorOp` (#3861)"

This reverts commit cd38ecf.
rahuls-cerebras committed Jan 3, 2025
1 parent c1c0524 commit 612ccc3
Showing 6 changed files with 30 additions and 302 deletions.
26 changes: 2 additions & 24 deletions lib/Conversion/TorchToArith/TorchToArith.cpp
@@ -82,25 +82,6 @@ class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
};
} // namespace

namespace {
class ConvertAtenNegIntOp : public OpConversionPattern<AtenNegIntOp> {
public:
using OpConversionPattern<AtenNegIntOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenNegIntOp op,
typename OpConversionPattern<AtenNegIntOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value a = adaptor.getA();
rewriter.replaceOpWithNewOp<arith::SubIOp>(
op,
rewriter.create<arith::ConstantIntOp>(op.getLoc(), /*value=*/0,
/*bitwidth=*/64),
a);
return success();
}
};
} // namespace

namespace {
template <typename AtenOp, typename UnaryOp>
class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
@@ -484,14 +465,11 @@ class ConvertTorchToArith

target.addIllegalOp<AtenAddOp>();
patterns.add<ConvertAtenAddOp>(typeConverter, context);
target.addIllegalOp<AtenNegIntOp>();
patterns.add<ConvertAtenNegIntOp>(typeConverter, context);

target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
AtenMulIntOp, AtenRemainderIntOp>();
AtenMulIntOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenAddFloatIntOp, arith::AddFOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
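For reference, the two TorchToArith conversions removed above lowered these scalar ops to `arith` roughly as follows. This is an illustrative sketch only; `%a` and `%b` stand for the already type-converted `i64` operands, and the exact IR produced by the pass may differ.

```mlir
// torch.aten.neg.int %a was rewritten as 0 - %a (ConvertAtenNegIntOp):
%zero = arith.constant 0 : i64
%neg  = arith.subi %zero, %a : i64

// torch.aten.remainder.int %a, %b was rewritten as a signed remainder
// (ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>):
%rem = arith.remsi %a, %b : i64
```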
4 changes: 0 additions & 4 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4068,10 +4068,6 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
int64_t lhs, rhs;
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
if (lConstant && lhs == 1)
return getOperand(1);
if (rConstant && rhs == 1)
return getOperand(0);
if ((lConstant && lhs == 0) || (rConstant && rhs == 0))
return getI64IntegerAttr(getContext(), 0);
if (lConstant && rConstant)
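The folder lines removed above covered the multiply-by-one identity for `AtenMulIntOp`. A minimal sketch of the case that no longer folds (value names are illustrative):

```mlir
%int1 = torch.constant.int 1
%r = torch.aten.mul.int %x, %int1 : !torch.int, !torch.int -> !torch.int
// With the reverted fold, %r was replaced by %x directly.
// After this revert, only the multiply-by-zero and both-operands-constant
// cases still fold here.
```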
5 changes: 0 additions & 5 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4587,11 +4587,6 @@ class DecomposeAtenUnflattenIntOp
if (!isValidDim(dimInt, inputRank))
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");

if (inputShape[dimInt] == Torch::kUnknownSize &&
llvm::count(sizesInts, -1) > 0)
return rewriter.notifyMatchFailure(
op, "Unimplemented: dynamic unflatten dim with an inferred size.");

SmallVector<Value> sizesTorchInt;
if (!getListConstructElements(op.getSizes(), sizesTorchInt))
return rewriter.notifyMatchFailure(
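The removed check guarded `DecomposeAtenUnflattenIntOp` against unflattening a dynamic dimension into sizes that include an inferred `-1`. A sketch of the rejected case, with assumed shapes and value names:

```mlir
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%infer = torch.constant.int -1
%sizes = torch.prim.ListConstruct %int2, %infer : (!torch.int, !torch.int) -> !torch.list<int>
// dim 0 of %self is dynamic and one target size is -1, so the removed
// guard reported a match failure instead of decomposing.
%out = torch.aten.unflatten.int %self, %int0, %sizes
    : !torch.vtensor<[?,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,?,4],f32>
```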
228 changes: 12 additions & 216 deletions lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
@@ -714,7 +714,7 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

// Rank 0 item op prop
if (selfTy.getSizes().empty()) {
if (selfTy.getSizes().size() == 0) {
auto numToTensor = self.getDefiningOp<Torch::PrimNumToTensorScalarOp>();
auto squeezeDim = self.getDefiningOp<AtenSqueezeDimOp>();
if (!squeezeDim && !numToTensor)
@@ -746,109 +746,6 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
};
} // namespace

namespace {

LogicalResult convertOpFoldResults(ImplicitLocOpBuilder &b,
SmallVector<OpFoldResult> &converted,
SmallVector<OpFoldResult> &elements,
Type inputDtype, Type resultDtype) {
auto inputIsInt = dyn_cast<mlir::IntegerType>(inputDtype);
auto resultIsInt = dyn_cast<mlir::IntegerType>(resultDtype);
if (!inputIsInt && !isa<mlir::FloatType>(inputDtype))
return failure();
if (!resultIsInt && !isa<mlir::FloatType>(resultDtype))
return failure();

// if dtypes are both int or both float, no conversion needed
if (static_cast<bool>(inputIsInt) == static_cast<bool>(resultIsInt)) {
converted = elements;
return success();
}

if (resultIsInt) {
for (auto &e : elements) {
auto eValue = dyn_cast<Value>(e);
if (eValue) {
converted.push_back(b.createOrFold<AtenIntScalarOp>(eValue));
continue;
}
auto eAttr = dyn_cast<Attribute>(e);
auto eFloatAttr = dyn_cast_or_null<FloatAttr>(eAttr);
if (!eFloatAttr)
return failure();

converted.push_back(IntegerAttr::get(
resultDtype, static_cast<int64_t>(eFloatAttr.getValueAsDouble())));
}
return success();
}

// result is float
for (auto &e : elements) {
auto eValue = dyn_cast<Value>(e);
if (eValue) {
converted.push_back(b.createOrFold<AtenFloatScalarOp>(eValue));
continue;
}
auto eAttr = dyn_cast<Attribute>(e);
auto eIntAttr = dyn_cast<IntegerAttr>(eAttr);
if (!eIntAttr)
return failure();

auto eInt = (inputIsInt.isSigned()) ? eIntAttr.getValue().getSExtValue()
: eIntAttr.getValue().getZExtValue();
converted.push_back(FloatAttr::get(resultDtype, static_cast<double>(eInt)));
}
return success();
}

class PropagateAtenToDtypePattern : public OpRewritePattern<AtenToDtypeOp> {
public:
using OpRewritePattern<AtenToDtypeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenToDtypeOp op,
PatternRewriter &rewriter) const override {
bool nonBlocking, copyArg;
// The non_blocking arg must be `False`.
if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
nonBlocking)
return failure();
// The copy arg must be `False`.
if (!matchPattern(op.getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
return failure();
// The memory_format arg must be `none`.
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType()))
return failure();

auto inputType = dyn_cast<ValueTensorType>(op.getSelf().getType());
auto resultType = dyn_cast<ValueTensorType>(op.getType());
if (!inputType || !resultType || !inputType.hasDtype() ||
!resultType.hasDtype())
return failure();
auto inputDtype = inputType.getDtype();
auto resultDtype = resultType.getDtype();

SmallVector<OpFoldResult> elements;
if (failed(getListFromTensor(op.getSelf(), elements)))
return failure();

ImplicitLocOpBuilder b(op.getLoc(), rewriter);
SmallVector<OpFoldResult> converted;
if (failed(convertOpFoldResults(b, converted, elements, inputDtype,
resultDtype)))
return rewriter.notifyMatchFailure(
op, "Unhandled attribute type encountered.");

SmallVector<Value> vals;
if (failed(materializeFolds(b, converted, vals)))
return failure();

Value result = constructAtenTensorOpFromList(b, op.getType(), vals);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace

namespace {
template <typename AtenViewLikeOp>
class PropagateAtenViewLikePattern : public OpRewritePattern<AtenViewLikeOp> {
@@ -931,49 +828,6 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
if (failed(materializeFolds(b, resultFolds, resultVals)))
return failure();

if (resultTy.getSizes().empty()) {
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
op, resultTy, resultVals.front());
return success();
}

Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace

namespace {
template <typename OpTy, typename ScalarOpTy>
class PropagateAtenUnaryPattern : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Check type
auto resultTy = cast<ValueTensorType>(op.getType());
if (resultTy.getSizes().size() > 1)
return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
return rewriter.notifyMatchFailure(op, "not an int type");

ImplicitLocOpBuilder b(op.getLoc(), rewriter);
SmallVector<OpFoldResult> selfFold;
if (failed(getListFromTensor(op.getSelf(), selfFold)))
return failure();
SmallVector<Value> selfVals;
if (failed(materializeFolds(b, selfFold, selfVals)))
return failure();
SmallVector<OpFoldResult> resultFolds;
for (uint64_t i = 0; i < selfVals.size(); i++) {
resultFolds.push_back(
b.createOrFold<ScalarOpTy>(selfVals[i].getType(), selfVals[i]));
}
SmallVector<Value> resultVals;
if (failed(materializeFolds(b, resultFolds, resultVals)))
return failure();

if (resultTy.getSizes().size() == 0) {
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
op, resultTy, resultVals.front());
@@ -986,6 +840,7 @@ class PropagateAtenUnaryPattern : public OpRewritePattern<OpTy> {
}
};
} // namespace

/// ------ Fold Patterns ------ ///
// These are shape-specific folding patterns

@@ -1060,22 +915,19 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
auto resultTy = cast<BaseTensorType>(op.getType());
if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
return rewriter.notifyMatchFailure(op, "dynamic output shape");
if (resultTy.getSizes().size() == 0) {
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
op, op.getType(), elements.front());
return success();
}

auto loc = op.getLoc();
SmallVector<Value> sizes;
for (auto size : resultTy.getSizes())
sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(size)));

Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(), 1);
Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
sizes);
one);

Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
@@ -1179,24 +1031,6 @@ class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
};
} // namespace

namespace {
// fold ridiculous patterns like size.int -> float.scalar -> int.scalar
class FoldAtenIntScalarPattern : public OpRewritePattern<AtenIntScalarOp> {
public:
using OpRewritePattern<AtenIntScalarOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenIntScalarOp op,
PatternRewriter &rewriter) const override {
auto floatScalarOp = op.getA().getDefiningOp<AtenFloatScalarOp>();
if (!floatScalarOp)
return failure();
auto sizeOp = floatScalarOp.getA().getDefiningOp<AtenSizeIntOp>();
if (!sizeOp)
return failure();
rewriter.replaceOp(op, floatScalarOp.getA());
return success();
}
};
} // namespace
namespace {
class FoldAtenUnsqueezePattern : public OpRewritePattern<AtenUnsqueezeOp> {
public:
@@ -1348,29 +1182,8 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
if (inputUnmatched == 1 && outputUnmatched > 1) {
Value dimVal =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
SmallVector<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
viewSizes.end() - rightMatchEnd);
// try to convert a single dynamic size input to -1
int64_t dynCount = 0;
int64_t dynIdx = 0;
for (auto [i, v] : llvm::enumerate(unflattenSizes)) {
int64_t szeInt;
if (!matchPattern(v, m_TorchConstantInt(&szeInt))) {
dynCount++;
dynIdx = i;
continue;
}
// if we have a -1 already, make dynCount invalid and break
if (szeInt == -1) {
dynCount = -1;
break;
}
}
// if only one size is dynamic, make it -1
if (dynCount == 1)
unflattenSizes[dynIdx] =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);

ArrayRef<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
viewSizes.end() - rightMatchEnd);
Value unflattenList = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(), op.getSize().getType(), unflattenSizes);
rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
@@ -1414,18 +1227,6 @@ template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {

namespace {

bool isItemForSliceOp(Operation *op) {
auto itemOp = dyn_cast_or_null<AtenItemOp>(op);
if (!itemOp)
return false;
for (OpOperand &use : op->getUses()) {
Operation *userOp = use.getOwner();
if (isa<AtenSliceTensorOp>(userOp))
return true;
}
return false;
}

bool isSourceOpForShapeScalarization(Operation *op) {
return llvm::isa<AtenSizeIntOp, Torch::ConstantIntOp, Torch::ConstantBoolOp,
Aten_ShapeAsTensorOp, Torch::ValueTensorLiteralOp>(op);
Expand All @@ -1443,7 +1244,7 @@ bool isPrimListOfInts(Operation *op) {

bool isAnchorOp(Operation *op) {
return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
isPrimListOfInts(op) || isItemForSliceOp(op);
isPrimListOfInts(op);
}

// The argument to this function, op, is the use of some source op, srcOp. If
@@ -1477,9 +1278,9 @@ bool isInvalidValidViewConsumer(Operation *op,
void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
FoldAtenSqueezePattern<AtenSqueezeDimOp>,
FoldAtenIntScalarPattern, FoldAtenUnsqueezePattern,
FoldAtenWhereSelf, FoldAtenTensorSplatPattern,
FoldAtenEqIntPattern>(patterns.getContext());
FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
patterns.getContext());
}

void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
Expand All @@ -1502,29 +1303,24 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern,
PropagateAtenUnaryPattern<AtenNegOp, AtenNegIntOp>,
PropagateAtenTransposeIntPattern,
PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
PropagateAtenArithmeticPattern<AtenRemainderTensorOp, AtenRemainderIntOp>,
PropagateAtenArithmeticPattern<AtenDivTensorOp, AtenFloordivIntOp>>(
patterns.getContext());
}

void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
patterns.insert<RemoveUnusedPattern<Torch::AtenIntBoolOp>,
RemoveUnusedPattern<Torch::AtenEqIntOp>,
RemoveUnusedPattern<Torch::AtenToDtypeOp>,
RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
RemoveUnusedPattern<Torch::AtenFullOp>,
RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
RemoveUnusedPattern<Torch::AtenTensorOp>,
RemoveUnusedPattern<Torch::AtenFloatScalarOp>,
RemoveUnusedPattern<Torch::AtenIntScalarOp>,
RemoveUnusedPattern<Torch::PrimListConstructOp>>(
patterns.getContext());
}
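Among the scalarization patterns removed above, `FoldAtenIntScalarPattern` collapsed int-to-float-to-int round-trips that originate from a `size.int`. An illustrative sketch (op spellings follow the Torch dialect; types and value names are assumptions):

```mlir
%int0 = torch.constant.int 0
%d = torch.aten.size.int %t, %int0 : !torch.vtensor<[?,4],f32>, !torch.int -> !torch.int
%f = torch.aten.Float.Scalar %d : !torch.int -> !torch.float
%i = torch.aten.Int.Scalar %f : !torch.float -> !torch.int
// With the reverted pattern, %i was replaced by %d; after this revert the
// Float.Scalar/Int.Scalar round-trip is left for other folds to handle.
```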