Skip to content

Commit

Permalink
Revert "[ONNX] Fix nonzero output type difference between onnx and torch (#3916)"
Browse files Browse the repository at this point in the history

This reverts commit 2c72a82.
  • Loading branch information
rahuls-cerebras committed Jan 3, 2025
1 parent f118b62 commit c6c429f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 37 deletions.
41 changes: 12 additions & 29 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1093,35 +1093,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
rewriter.replaceOp(binder.op, nllLoss);
return success();
});
patterns.onOp(
"NonZero", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
// Lowers onnx.NonZero to torch.aten.nonzero followed by a transpose.
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType)) {
return failure();
}
// Constant dims 0 and 1 used as the transpose axes below.
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
Value one = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
// Build the aten.nonzero result type by reversing the ONNX result's
// dims, since the transpose below swaps them back.
auto rawSize = resultType.getSizes();
SmallVector<int64_t> torchResultSize(rawSize.rbegin(), rawSize.rend());
auto torchResultType = rewriter.getType<Torch::ValueTensorType>(
torchResultSize, resultType.getDtype());
auto nonZero = rewriter.create<Torch::AtenNonzeroOp>(
binder.getLoc(), torchResultType, operand);
// The ONNX output tensor has a shape of (n, z), where n is the
// number of dimensions in the input tensor and z is the
// number of non-zero elements. This is the transpose of
// PyTorch's default behavior, where aten.nonzero returns
// shape (z, n).
rewriter.replaceOpWithNewOp<Torch::AtenTransposeIntOp>(
binder.op, resultType, nonZero, zero, one);
return success();
});
patterns.onOp(
    "NonZero", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Map onnx.NonZero straight onto torch.aten.nonzero, reusing the
      // result type already recorded on the ONNX op.
      Torch::ValueTensorType resultTy;
      Value input;
      if (binder.tensorOperand(input))
        return failure();
      if (binder.tensorResultType(resultTy))
        return failure();
      rewriter.replaceOpWithNewOp<Torch::AtenNonzeroOp>(binder.op, resultTy,
                                                        input);
      return success();
    });
patterns.onOp(
"MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
std::string autoPad;
Expand Down
14 changes: 6 additions & 8 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1580,14 +1580,12 @@ func.func @test_nllloss_iii_reduction_none_ignore_negative(%arg0: !torch.vtensor

// -----

// Checks that onnx.NonZero lowers to aten.nonzero plus a transpose: the
// CHECK lines below expect a (z, n)-shaped nonzero result transposed to
// the ONNX (n, z) layout.
func.func @test_nonzero(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1,?],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
// CHECK: %[[ONE:.*]] = torch.constant.int 1
// CHECK: %[[NONZERO:.*]] = torch.aten.nonzero %arg0 : !torch.vtensor<[?],f32> -> !torch.vtensor<[?,1],si64>
// CHECK: %[[TRANSPOSE:.*]] = torch.aten.transpose.int %[[NONZERO]], %[[ZERO]], %[[ONE]] : !torch.vtensor<[?,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64>
%0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[?],f32>) -> !torch.vtensor<[1,?],si64>
return %0 : !torch.vtensor<[1,?],si64>
}
// Checks the direct lowering of onnx.NonZero to aten.nonzero with the
// result type carried over unchanged (no transpose expected).
// CHECK-LABEL: func.func @test_nonzero
func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>
%0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64>
return %0 : !torch.vtensor<[3,4,5],si64>
}

// -----

Expand Down

0 comments on commit c6c429f

Please sign in to comment.