From 521b58dc1facb7569c565ef20042606ccf0ec715 Mon Sep 17 00:00:00 2001 From: rahul shrivastava Date: Thu, 2 Jan 2025 06:23:24 -0800 Subject: [PATCH] Revert "[TorchToLinalg] Implement lowering of torch.aten.rrelu_with_noise and torch.aten.rrelu_with_noise_backward ops (fix) (#3748)" This reverts commit 54d9e2401376e7eb2c6c219e3b3555f45f8b2635. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 84 --------- .../Transforms/AbstractInterpLibrary.cpp | 57 ------- .../Torch/Transforms/DecomposeComplexOps.cpp | 132 -------------- .../Transforms/LowerToBackendContract.cpp | 2 - projects/pt1/e2e_testing/xfail_sets.py | 22 --- .../build_tools/abstract_interp_lib_gen.py | 24 --- .../build_tools/torch_ods_gen.py | 4 - .../test_suite/backprop.py | 161 ------------------ .../test_suite/elementwise.py | 82 --------- 9 files changed, 568 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index f951de9af795..c33a9d717eac 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -309,61 +309,6 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [ }]; } -def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$noise, - AnyTorchScalarType:$lower, - AnyTorchScalarType:$upper, - Torch_BoolType:$training, - AnyTorchOptionalGeneratorType:$generator - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenRreluWithNoiseOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); - } - void AtenRreluWithNoiseOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); - } - }]; -} - -def Torch_AtenRreluWithNoise_Op : Torch_Op<"aten.rrelu_with_noise_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::rrelu_with_noise_ : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$noise, - AnyTorchScalarType:$lower, - AnyTorchScalarType:$upper, - Torch_BoolType:$training, - AnyTorchOptionalGeneratorType:$generator - ); - let results = (outs - AnyTorchOptionalNonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenRreluWithNoise_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); - } - void AtenRreluWithNoise_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); - } - }]; -} - def Torch_AtenCeluOp : Torch_Op<"aten.celu", [ AllowsTypeRefinement, HasValueSemantics, @@ -17467,35 +17412,6 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [ }]; } -def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backward", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$grad_output, - AnyTorchTensorType:$self, - AnyTorchTensorType:$noise, - AnyTorchScalarType:$lower, - AnyTorchScalarType:$upper, - Torch_BoolType:$training, - Torch_BoolType:$self_is_result - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenRreluWithNoiseBackwardOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); - } - void AtenRreluWithNoiseBackwardOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); - } - }]; -} - def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index edcc81a2847f..933543c18aaf 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6690,10 +6690,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7296,10 +7292,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -12409,14 +12401,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number, %arg4: !torch.number, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.int {\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -12609,47 +12593,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %true = torch.constant.bool true\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If.yield %7 : !torch.bool\n" -" }\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %5 = torch.prim.If %4 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If.yield %7 : !torch.bool\n" -" }\n" -" torch.prim.If %5 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" return %0#1 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 919c4727b1f9..1caac461fe8b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3675,59 +3675,6 @@ class DecomposeAtenLeakyReluBackwardOp }; } // namespace -namespace { -class DecomposeAtenRreluWithNoiseBackwardOp - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenRreluWithNoiseBackwardOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value gradOutput = op.getGradOutput(); - Value self = op.getSelf(); - Value noise = op.getNoise(); - auto resType = cast(op.getType()); - if (!resType.hasDtype()) { - return rewriter.notifyMatchFailure(op, "result should have dtype"); - } - - bool training; - if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) { - return rewriter.notifyMatchFailure(op, - "training should be a bool constant"); - } - - bool selfIsResult = false; - if (!matchPattern(op.getSelfIsResult(), - m_TorchConstantBool(&selfIsResult)) || - selfIsResult) - return rewriter.notifyMatchFailure( - op, "unimplemented: self_is_result should be false"); - - double lower, upper; - if (!matchPattern(op.getLower(), m_TorchConstantFloat(&lower)) || - !matchPattern(op.getUpper(), m_TorchConstantFloat(&upper))) { - return rewriter.notifyMatchFailure( - op, "lower and upper should be float constants"); - } - - if (training && (upper - lower > 0.000001)) { - Value rreluWithNoiseBackwardOutput = - rewriter.create(loc, resType, gradOutput, noise); - rewriter.replaceOp(op, rreluWithNoiseBackwardOutput); - } else { - double negative_slope = (upper + lower) / 2; - Value cstNegativeSlope = rewriter.create( - loc, rewriter.getF64FloatAttr(negative_slope)); - rewriter.replaceOpWithNewOp( - op, resType, gradOutput, self, cstNegativeSlope, - op.getSelfIsResult()); - } - return success(); - } -}; -} // namespace - namespace { class DecomposeAtenPreluOp : public OpRewritePattern { public: @@ -3827,82 +3774,6 @@ class DecomposeAtenRreluOp : public OpRewritePattern { }; } // namespace -namespace { -class DecomposeAtenRreluWithNoiseOp - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value self = op.getSelf(); - Value noise = op.getNoise(); - Value lower = op.getLower(); - Value upper = op.getUpper(); - auto resType = cast(op.getType()); - if (!resType.hasDtype()) { - return rewriter.notifyMatchFailure(op, "result should have dtype"); - } - - bool training; - if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) { - return rewriter.notifyMatchFailure(op, "training should be a constant"); - } - - Value constantZeroFloat = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); - Value constantOneFloat = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); - Value constantTwoFloat = - rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); - - Value alpha; - if (training) { - Value none = rewriter.create(loc); - Value emptyTensor = rewriter.create( - loc, resType, self, constantZeroFloat, /*dtype=*/none, - /*layout=*/none, - /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); - alpha = rewriter.create(loc, resType, emptyTensor, - /*from=*/lower, /*to=*/upper, - /*generator=*/none); - } else { - Value half = rewriter.create(loc, constantTwoFloat.getType(), - lower, upper); - alpha = rewriter.create(loc, constantTwoFloat.getType(), half, - constantTwoFloat); - } - - Value zeroTensor = - createRank0Tensor(rewriter, loc, resType, constantZeroFloat); - Value positiveOutput = - rewriter.create(loc, resType, zeroTensor, self); - - Value scaledSelf; - if (training) { - scaledSelf = rewriter.create(loc, resType, self, alpha); - auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(), - rewriter.getI1Type()); - Value oneTensor = - createRank0Tensor(rewriter, loc, resType, constantOneFloat); - Value not_positive = rewriter.create( - loc, boolResType, self, constantZeroFloat); - noise = rewriter.create(loc, resType, not_positive, - alpha, oneTensor); - } else { - scaledSelf = rewriter.create(loc, resType, self, alpha); - } - - Value negativeOutput = - rewriter.create(loc, resType, zeroTensor, scaledSelf); - Value rreluOutput = rewriter.create( - loc, resType, positiveOutput, negativeOutput, constantOneFloat); - rewriter.replaceOp(op, rreluOutput); - return success(); - } -}; -} // namespace - // CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1)) namespace { class DecomposeAtenCeluOp : public OpRewritePattern { @@ -11319,9 +11190,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal( - patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index f868c4c1800a..ce675b2f9301 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -500,8 +500,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 2237ca1446ea..0ce81594794f 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1207,10 +1207,6 @@ "ElementwisePreluStaticModule_basic", "ElementwiseReciprocalModule_basic", "ElementwiseReluModule_basic", - "ElementwiseRreluWithNoiseEvalStaticModule_basic", - "ElementwiseRreluWithNoiseTrainStaticModule_basic", - "RreluWithNoiseBackwardEvalStaticModule_basic", - "RreluWithNoiseBackwardTrainStaticModule_basic", "ElementwiseRemainderTensorModule_Float_basic", "ElementwiseRemainderTensorModule_Float_NegativeDividend_basic", "ElementwiseRemainderTensorModule_Int_Float_basic", @@ -2200,7 +2196,6 @@ "ElementwiseReciprocalModule_basic", "ElementwiseRelu6Module_basic", "ElementwiseReluModule_basic", - "ElementwiseRreluWithNoiseEvalStaticModule_basic", "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic", "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic", "ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic", @@ -2336,10 +2331,6 @@ "ReduceSumFloatModule_basic", "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", - "RreluWithNoiseBackwardEvalModule_basic", - "RreluWithNoiseBackwardEvalStaticModule_basic", - "RreluWithNoiseBackwardTrainModule_basic", - "RreluWithNoiseBackwardTrainStaticModule_basic", "RepeatModule_basic", "RepeatInterleaveSelfIntNoDimModule_basic", "ReshapeAliasCollapseModule_basic", @@ -2536,10 +2527,6 @@ "ViewSizeFromOtherTensor_basic", "RenormModuleFloat32NegativeDim_basic", "RenormModuleFloat32_basic", - "RreluWithNoiseBackwardEvalModule_basic", - "RreluWithNoiseBackwardEvalStaticModule_basic", - "RreluWithNoiseBackwardTrainModule_basic", - "RreluWithNoiseBackwardTrainStaticModule_basic", } ) - { ### Test failing in make_fx_tosa but not in tosa @@ -2968,10 +2955,6 @@ "ElementwiseRemainderTensorModule_Int_basic", "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic", "ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic", - "ElementwiseRreluWithNoiseEvalModule_basic", - "ElementwiseRreluWithNoiseEvalStaticModule_basic", - "ElementwiseRreluWithNoiseTrainModule_basic", - "ElementwiseRreluWithNoiseTrainStaticModule_basic", "ElementwiseSgnModule_basic", "EmptyStridedModule_basic", "EmptyStridedSizeIntStrideModule_basic", @@ -3127,11 +3110,6 @@ "ReduceL1NormComplexModule_basic", "ReduceL2NormComplexModule_basic", "ReduceL3NormKeepDimComplexModule_basic", - "RreluWithNoiseBackwardEvalModule_basic", - "RreluWithNoiseBackwardEvalStaticModule_basic", - "RreluWithNoiseBackwardTrainModule_basic", - "RreluWithNoiseBackwardTrainStaticModule_basic", - "RreluWithNoiseForwardBackwardModule_basic", "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", "ReshapeExpandModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 331aa476910e..012833b64c8a 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -304,9 +304,6 @@ def aten〇gelu_backward〡shape(grad_output: List[int], self: List[int], approx def aten〇leaky_relu_backward〡shape(grad_output: List[int], self: List[int], negative_slope: float, self_is_result: bool) -> List[int]: return upstream_shape_functions.unary(grad_output) -def aten〇rrelu_with_noise_backward〡shape(grad_output: List[int], self: List[int], noise: List[int], lower: float, upper: float, training: bool, self_is_result: bool) -> List[int]: - return upstream_shape_functions.unary(grad_output) - def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], min_val: float, max_val: float) -> List[int]: return upstream_shape_functions.unary(grad_output) @@ -643,9 +640,6 @@ def aten〇celu〡shape(self: List[int], alpha: float = 1.) -> List[int]: def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]: return upstream_shape_functions.unary(self) -def aten〇rrelu_with_noise〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]: - return upstream_shape_functions.unary(self) - def aten〇selu〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -3294,15 +3288,6 @@ def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], promoted_dtype = promote_dtypes(ranks, dtypes) return promoted_dtype -@check_dtype_function([Invocation(TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), 0.1, 0.9, False, False) for dtype in _SORTED_TORCH_TYPES]) -def aten〇rrelu_with_noise_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex], upper: Union[int, float, complex], training: bool, self_is_result: bool) -> int: - grad_output_rank, grad_output_dtype = grad_output_rank_dtype - self_rank, self_dtype = self_rank_dtype - ranks: List[Optional[int]] = [grad_output_rank, self_rank] - dtypes = [grad_output_dtype, self_dtype] - promoted_dtype = promote_dtypes(ranks, dtypes) - return promoted_dtype - @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -3470,15 +3455,6 @@ def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, flo assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype) return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool, *all_integer_dtypes()})) -def aten〇rrelu_with_noise〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int: - self_rank, self_dtype = self_rank_dtype - noise_rank, noise_dtype = noise_rank_dtype - assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype) - assert is_float_dtype(noise_dtype) or is_complex_dtype(noise_dtype) - assert self_rank == noise_rank - return self_dtype - @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 8a0417a85189..698fec575749 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -302,7 +302,6 @@ def emit_with_mutating_variants(key, **kwargs): "aten::relu6 : (Tensor) -> (Tensor)", "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)", "aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)", - "aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)", "aten::celu : (Tensor, Scalar) -> (Tensor)", "aten::selu : (Tensor) -> (Tensor)", "aten::sigmoid : (Tensor) -> (Tensor)", @@ -1208,9 +1207,6 @@ def emit_with_mutating_variants(key, **kwargs): "aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)" ) emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)") - emit( - "aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)" - ) # quantized ops emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py index 5e6e093902c4..e209d15b2b0b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py @@ -322,164 +322,3 @@ def forward(self, grad, input): @register_test_case(module_factory=lambda: LeakyReluBackwardStaticModule()) def LeakyReluBackwardStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5)) - - -# ============================================================================== - - -class RreluWithNoiseBackwardTrainModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ] - ) - def forward(self, grad, input, noise): - return torch.ops.aten.rrelu_with_noise_backward( - grad, - input, - noise, - lower=0.1, - upper=0.9, - training=True, - self_is_result=False, - ) - - -@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainModule()) -def RreluWithNoiseBackwardTrainModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) - - -class RreluWithNoiseBackwardTrainStaticModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([3, 4, 5], torch.float32, True), - ([3, 4, 5], torch.float32, True), - ([3, 4, 5], torch.float32, True), - ] - ) - def forward(self, grad, input, noise): - return torch.ops.aten.rrelu_with_noise_backward( - grad, - input, - noise, - lower=0.1, - upper=0.9, - training=True, - self_is_result=False, - ) - - -@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainStaticModule()) -def RreluWithNoiseBackwardTrainStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) - - -# ============================================================================== - - -class RreluWithNoiseBackwardEvalModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ] - ) - def forward(self, grad, input, noise): - return torch.ops.aten.rrelu_with_noise_backward( - grad, - input, - noise, - lower=0.1, - upper=0.9, - training=False, - self_is_result=False, - ) - - -@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalModule()) -def RreluWithNoiseBackwardEvalModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) - - -class RreluWithNoiseBackwardEvalStaticModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([3, 4, 5], torch.float32, True), - ([3, 4, 5], torch.float32, True), - ([3, 4, 5], torch.float32, True), - ] - ) - def forward(self, grad, input, noise): - return torch.ops.aten.rrelu_with_noise_backward( - grad, - input, - noise, - lower=0.1, - upper=0.9, - training=False, - self_is_result=False, - ) - - -@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalStaticModule()) -def RreluWithNoiseBackwardEvalStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) - - -class RreluWithNoiseForwardBackwardModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ] - ) - def forward(self, grad, input, noise): - res = torch.ops.aten.rrelu_with_noise_backward( - grad, - input, - noise, - lower=0.4, - upper=0.6, - training=True, - self_is_result=False, - ) - return torch.mean(res), torch.std(res) - - -@register_test_case(module_factory=lambda: RreluWithNoiseForwardBackwardModule()) -def RreluWithNoiseForwardBackwardModule_basic(module, tu: TestUtils): - grad = tu.rand(256, 244) - input = tu.rand(256, 244, low=-1.0, high=1.0) - noise = tu.rand(256, 244) - torch.ops.aten.rrelu_with_noise(input, noise, lower=0.4, upper=0.6, training=True) - module.forward(grad, input, noise) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 38fccc06b393..2e59db727341 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1231,88 +1231,6 @@ def ElementwiseRreluEvalStaticModule_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseRreluWithNoiseTrainModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)] - ) - def forward(self, x, noise): - res = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.5, True) - return torch.mean(res), torch.std(res) - - -@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainModule()) -def ElementwiseRreluWithNoiseTrainModule_basic(module, tu: TestUtils): - module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128)) - - -# ============================================================================== - - -class ElementwiseRreluWithNoiseTrainStaticModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [None, ([128, 128], torch.float32, True), ([128, 128], torch.float32, True)] - ) - def forward(self, x, noise): - res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True) - return torch.mean(res), torch.std(res) - - -@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainStaticModule()) -def ElementwiseRreluWithNoiseTrainStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128)) - - -# ============================================================================== - - -class ElementwiseRreluWithNoiseEvalModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)] - ) - def forward(self, x, noise): - res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False) - return torch.mean(res), torch.std(res) - - -@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalModule()) -def ElementwiseRreluWithNoiseEvalModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3)) - - -# ============================================================================== - - -class ElementwiseRreluWithNoiseEvalStaticModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([None, ([5, 3], torch.float32, True), ([5, 3], torch.float32, True)]) - def forward(self, x, noise): - res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False) - return torch.mean(res), torch.std(res) - - -@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalStaticModule()) -def ElementwiseRreluWithNoiseEvalStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3)) - - -# ============================================================================== - - class ElementwiseCeluStaticModule(torch.nn.Module): def __init__(self): super().__init__()