
Commit

Revert "[TorchToLinalg] Implement lowering of torch.aten.rrelu_with_n…
Browse files Browse the repository at this point in the history
…oise and torch.aten.rrelu_with_noise_backward ops (fix) (#3748)"

This reverts commit 54d9e24.
rahuls-cerebras committed Jan 2, 2025
1 parent c1892de commit 521b58d
Showing 9 changed files with 0 additions and 568 deletions.
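
For context, the reverted patch targeted PyTorch's randomized leaky ReLU (RReLU) family. The sketch below, which assumes only a stock PyTorch install and uses the public `torch.nn.RReLU` module rather than the raw `aten::rrelu_with_noise` op removed in this diff, shows the behaviour the lowering was meant to cover:

```python
import torch

# RReLU is the identity for non-negative inputs; negative inputs are scaled by
# a slope drawn uniformly from [lower, upper] in training mode, and by the
# fixed mean slope (lower + upper) / 2 in eval mode.
rrelu = torch.nn.RReLU(lower=0.125, upper=1.0 / 3.0)
x = torch.randn(4)

rrelu.train()
print(rrelu(x))  # per-element random negative slopes
rrelu.eval()
print(rrelu(x))  # deterministic slope (0.125 + 1/3) / 2
```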
84 changes: 0 additions & 84 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -309,61 +309,6 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [
}];
}

def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$noise,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRreluWithNoiseOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenRreluWithNoiseOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}

def Torch_AtenRreluWithNoise_Op : Torch_Op<"aten.rrelu_with_noise_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::rrelu_with_noise_ : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_NonValueTensorType:$noise,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRreluWithNoise_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenRreluWithNoise_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}

def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -17467,35 +17412,6 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
}];
}

def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self,
AnyTorchTensorType:$noise,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
Torch_BoolType:$self_is_result
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRreluWithNoiseBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenRreluWithNoiseBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}

def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
AllowsTypeRefinement,
HasValueSemantics,
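The deleted TableGen records above mirror the upstream ATen signatures quoted in their `summary` strings. As a rough illustration of the six-operand contract of `aten::rrelu_with_noise` (this call path through `torch.ops.aten` is an assumption about a typical PyTorch build, not something taken from this diff):

```python
import torch

x = torch.randn(3, 3)
noise = torch.empty_like(x)  # buffer the op may fill with the sampled slopes

# (Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator)
out = torch.ops.aten.rrelu_with_noise(x, noise, 0.125, 1.0 / 3.0, True, None)
```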
57 changes: 0 additions & 57 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
@@ -6690,10 +6690,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -7296,10 +7292,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -12409,14 +12401,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
@@ -12609,47 +12593,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %true = torch.constant.bool true\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" }\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" }\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
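The removed library entries were emitted from Python transfer functions. A loose Python rendering of what they encoded (function names here are illustrative, not the generator's real identifiers): shapes pass through unchanged because both ops are elementwise, and the dtype rule checks for a float or complex input with matching ranks and returns the input dtype, while the backward variant promotes the dtypes of `grad_output` and `self`.

```python
from typing import List, Tuple

def rrelu_with_noise_shape(self: List[int], noise: List[int]) -> List[int]:
    # elementwise op: the "unary" shape rule just returns the input shape
    return list(self)

def rrelu_with_noise_backward_shape(grad_output: List[int], self: List[int],
                                    noise: List[int]) -> List[int]:
    return list(grad_output)

def rrelu_with_noise_dtype(self_rank_dtype: Tuple[int, int],
                           noise_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    noise_rank, noise_dtype = noise_rank_dtype
    # the removed dtype function asserted float/complex inputs and equal ranks,
    # then returned the input dtype unchanged
    assert self_rank == noise_rank
    return self_dtype
```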
132 changes: 0 additions & 132 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
@@ -3675,59 +3675,6 @@ class DecomposeAtenLeakyReluBackwardOp
};
} // namespace

namespace {
class DecomposeAtenRreluWithNoiseBackwardOp
: public OpRewritePattern<AtenRreluWithNoiseBackwardOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRreluWithNoiseBackwardOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value gradOutput = op.getGradOutput();
Value self = op.getSelf();
Value noise = op.getNoise();
auto resType = cast<BaseTensorType>(op.getType());
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}

bool training;
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
return rewriter.notifyMatchFailure(op,
"training should be a bool constant");
}

bool selfIsResult = false;
if (!matchPattern(op.getSelfIsResult(),
m_TorchConstantBool(&selfIsResult)) ||
selfIsResult)
return rewriter.notifyMatchFailure(
op, "unimplemented: self_is_result should be false");

double lower, upper;
if (!matchPattern(op.getLower(), m_TorchConstantFloat(&lower)) ||
!matchPattern(op.getUpper(), m_TorchConstantFloat(&upper))) {
return rewriter.notifyMatchFailure(
op, "lower and upper should be float constants");
}

if (training && (upper - lower > 0.000001)) {
Value rreluWithNoiseBackwardOutput =
rewriter.create<AtenMulTensorOp>(loc, resType, gradOutput, noise);
rewriter.replaceOp(op, rreluWithNoiseBackwardOutput);
} else {
double negative_slope = (upper + lower) / 2;
Value cstNegativeSlope = rewriter.create<ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(negative_slope));
rewriter.replaceOpWithNewOp<AtenLeakyReluBackwardOp>(
op, resType, gradOutput, self, cstNegativeSlope,
op.getSelfIsResult());
}
return success();
}
};
} // namespace

namespace {
class DecomposeAtenPreluOp : public OpRewritePattern<AtenPreluOp> {
public:
@@ -3827,82 +3774,6 @@ class DecomposeAtenRreluOp : public OpRewritePattern<AtenRreluOp> {
};
} // namespace

namespace {
class DecomposeAtenRreluWithNoiseOp
: public OpRewritePattern<AtenRreluWithNoiseOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
Value noise = op.getNoise();
Value lower = op.getLower();
Value upper = op.getUpper();
auto resType = cast<BaseTensorType>(op.getType());
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}

bool training;
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
return rewriter.notifyMatchFailure(op, "training should be a constant");
}

Value constantZeroFloat =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
Value constantOneFloat =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value constantTwoFloat =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));

Value alpha;
if (training) {
Value none = rewriter.create<ConstantNoneOp>(loc);
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
loc, resType, self, constantZeroFloat, /*dtype=*/none,
/*layout=*/none,
/*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none);
alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
/*from=*/lower, /*to=*/upper,
/*generator=*/none);
} else {
Value half = rewriter.create<AtenAddOp>(loc, constantTwoFloat.getType(),
lower, upper);
alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(), half,
constantTwoFloat);
}

Value zeroTensor =
createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
Value positiveOutput =
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);

Value scaledSelf;
if (training) {
scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, alpha);
auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(),
rewriter.getI1Type());
Value oneTensor =
createRank0Tensor(rewriter, loc, resType, constantOneFloat);
Value not_positive = rewriter.create<AtenLtScalarOp>(
loc, boolResType, self, constantZeroFloat);
noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
alpha, oneTensor);
} else {
scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
}

Value negativeOutput =
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
Value rreluOutput = rewriter.create<AtenAddTensorOp>(
loc, resType, positiveOutput, negativeOutput, constantOneFloat);
rewriter.replaceOp(op, rreluOutput);
return success();
}
};
} // namespace

// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
namespace {
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
@@ -11319,9 +11190,6 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseBackwardOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
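For reference, the two deleted rewrite patterns implemented the decompositions sketched below in plain PyTorch. This is only a reading aid that mirrors the removed C++ arithmetic; it is not part of torch-mlir, and the helper names are made up:

```python
import torch

def rrelu_with_noise_ref(self, noise, lower, upper, training):
    # max(0, x) + min(0, scaled_x): identity on the positive side, scaled on
    # the negative side. In training the scale alpha is sampled per element
    # from U(lower, upper) and recorded into `noise`; in eval it is the scalar
    # mean (lower + upper) / 2.
    zero = torch.zeros_like(self)
    if training:
        alpha = torch.empty_like(self).uniform_(lower, upper)
        noise.copy_(torch.where(self < 0, alpha, torch.ones_like(self)))
        scaled = self * alpha
    else:
        scaled = self * ((lower + upper) / 2)
    return torch.maximum(zero, self) + torch.minimum(zero, scaled)

def rrelu_with_noise_backward_ref(grad_output, self, noise, lower, upper, training):
    # Training with a non-degenerate [lower, upper] interval: the gradient is
    # simply grad_output * noise. Otherwise the pattern fell back to
    # leaky_relu_backward with the fixed slope (lower + upper) / 2.
    if training and (upper - lower) > 1e-6:
        return grad_output * noise
    slope = (lower + upper) / 2
    return grad_output * torch.where(self > 0, torch.ones_like(self),
                                     torch.full_like(self, slope))
```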
2 changes: 0 additions & 2 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
@@ -500,8 +500,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenPadOp>();
target.addIllegalOp<AtenPreluOp>();
target.addIllegalOp<AtenRreluOp>();
target.addIllegalOp<AtenRreluWithNoiseOp>();
target.addIllegalOp<AtenRreluWithNoiseBackwardOp>();
target.addIllegalOp<AtenCeluOp>();
target.addIllegalOp<AtenToDtypeLayoutOp>();
target.addIllegalOp<AtenToDeviceOp>();
