Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][ONNX] Add OnnxToTorch support for ReduceSumSquare Op #3188

Merged
merged 21 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,26 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*storeValue=*/data, keepDims,
noop_with_empty_axes, false);
});
patterns.onOp("ReduceSumSquare", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
int64_t keepDims, noop_with_empty_axes;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes,
"noop_with_empty_axes", 0))
return failure();

Value dataSquare = rewriter.create<Torch::AtenMulTensorOp>(
binder.getLoc(), data.getType(), data, data);

return reducedSumImpl(binder, rewriter, dataSquare,
resultType,
/*storeValue=*/dataSquare, keepDims,
noop_with_empty_axes, false);
});
patterns.onOp(
"ReduceMean", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Expand Down
94 changes: 93 additions & 1 deletion test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtens
// -----

// CHECK-LABEL: func.func @test_reduce_l1_default_axes_keepdims_example
func.func @test_reduce_l1_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
func.func @test_reduce_l1_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_torch.aten.mul.Tensormeta.producer_version = ""} {
archana-ramalingam marked this conversation as resolved.
Show resolved Hide resolved
// CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
Expand Down Expand Up @@ -998,6 +998,98 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<

// -----

// CHECK-LABEL: func.func @test_reduce_sum_square_default_axes_keepdims_example
func.func @test_reduce_sum_square_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[1,1,1],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
return %0 : !torch.vtensor<[1,1,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_sum_square_do_not_keepdims_example
func.func @test_reduce_sum_square_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[3,2],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
return %0 : !torch.vtensor<[3,2],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_sum_square_empty_axes_input_noop_example
func.func @test_reduce_sum_square_empty_axes_input_noop_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64, torch.onnx.noop_with_empty_axes = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32>
return %0 : !torch.vtensor<[3,2,2],f32>
}
archana-ramalingam marked this conversation as resolved.
Show resolved Hide resolved

// -----

// CHECK-LABEL: func.func @test_reduce_sum_square_empty_set_non_reduced_axis_zero
func.func @test_reduce_sum_square_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 8: si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[2,0,4],f32>, !torch.vtensor<[2,0,4],f32> -> !torch.vtensor<[2,0,4],f32>
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[2,0,4],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,0,1],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[2,0,1],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32>
return %0 : !torch.vtensor<[2,0,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_sum_square_keepdims_example
func.func @test_reduce_sum_square_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[3,1,2],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
return %0 : !torch.vtensor<[3,1,2],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_sum_square_keepdims_int_example
func.func @test_reduce_sum_square_keepdims_int_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],si64>, !torch.vtensor<[3,2,2],si64> -> !torch.vtensor<[3,2,2],si64>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],si64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[3,1,2],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
return %0 : !torch.vtensor<[3,1,2],f32>
}
vivekkhandelwal1 marked this conversation as resolved.
Show resolved Hide resolved

// -----

// CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example
func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
// CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
Expand Down
Loading