Skip to content

Commit

Permalink
Revert "Fix onnx asinh lowering (llvm#3263)"
Browse files Browse the repository at this point in the history
This reverts commit bf04b53.
  • Loading branch information
rsuderman committed May 1, 2024
1 parent 8c48135 commit 1996fbd
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 35 deletions.
34 changes: 11 additions & 23 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,29 +198,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, resultType, operand);
return success();
});
patterns.onOp(
"Asinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();

// log(x + sqrt(x**2 + 1))
Value square = rewriter.create<Torch::AtenSquareOp>(
binder.getLoc(), resultType, operand);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value add0 = rewriter.create<Torch::AtenAddScalarOp>(
binder.getLoc(), resultType, square, cstOne, cstOne);
Value sqrt = rewriter.create<Torch::AtenSqrtOp>(binder.getLoc(),
resultType, add0);
Value add1 = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), resultType, operand, sqrt, cstOne);
rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(binder.op, resultType,
add1);
return success();
});
patterns.onOp("Asinh", 9,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenAsinhOp>(
binder.op, resultType, operand);
return success();
});
patterns.onOp("Atan", 7,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
14 changes: 2 additions & 12 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -748,12 +748,7 @@ func.func @test_asin(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,

// CHECK-LABEL: @test_asinh_example
func.func @test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[C1:.+]] = torch.constant.int 1
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[ADD]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[ADD_0:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: torch.aten.log %[[ADD_0]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
%0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32>
}
Expand All @@ -762,12 +757,7 @@ func.func @test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<

// CHECK-LABEL: @test_asinh
func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[C1:.+]] = torch.constant.int 1
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[ADD]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[ADD_0:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
// CHECK: torch.aten.log %[[ADD_0]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}
Expand Down

0 comments on commit 1996fbd

Please sign in to comment.