From 65f4a3d930bd93340f4a22a7956eeecfbecf46d3 Mon Sep 17 00:00:00 2001 From: AmosLewis Date: Wed, 11 Dec 2024 20:54:28 -0800 Subject: [PATCH] Add decomposenonzero e2e tests. Fix oneTensorType size. Fix torch.cat empty tensor. --- build_tools/ci/test_posix.sh | 21 ++- .../Torch/Transforms/DecomposeComplexOps.cpp | 59 ++++--- .../torch_mlir_e2e_test/test_suite/basic.py | 156 ++++++++++++++++++ .../torch_mlir_e2e_test/test_suite/scatter.py | 6 +- 4 files changed, 211 insertions(+), 31 deletions(-) diff --git a/build_tools/ci/test_posix.sh b/build_tools/ci/test_posix.sh index 97fb8a8336f2..9b48e09043d6 100755 --- a/build_tools/ci/test_posix.sh +++ b/build_tools/ci/test_posix.sh @@ -9,7 +9,26 @@ torch_version="${1:-unknown}" export PYTHONPATH="$repo_root/build/tools/torch-mlir/python_packages/torch_mlir:$repo_root/projects/pt1" echo "::group::Run ONNX e2e integration tests" -python -m e2e_testing.main --config=onnx -v --filter AtenNonzero1DModule_one_nonzero +# python -m e2e_testing.main --config=onnx -v --filter AtenNonzero1DModule_one_nonzero + +# python -m e2e_testing.main --config=onnx -v --filter NonzeroDecomposeModule_basic # Failed: 1 +# python -m e2e_testing.main --config=linalg -v --filter NonzeroDecomposeModule_basic # Passed: 1 + +# python -m e2e_testing.main --config=linalg -v --filter NonzeroFlattenDynamicModule # Passed: 1 + +# python -m e2e_testing.main --config=onnx -v --filter ScatterAddDynamicModule_basic # Passed: 1 + +# python -m e2e_testing.main --config=onnx -v --filter NonzeroCatModule # Passed: 1 +# python -m e2e_testing.main --config=linalg -v --filter NonzeroCatModule # Failed: 1 +# tensor with unknown dtype "torch.aten.cat"(%31, %4) : (!torch.list, !torch.int) -> !torch.vtensor<[1],unk> + +# python -m e2e_testing.main --config=linalg -v --filter NonzeroCatModule # Failed: 1 + +# python -m e2e_testing.main --config=linalg -v --filter NonzeroCumsumModule +# python -m e2e_testing.main --config=onnx -v --filter NonzeroCumsumModule # pass +python -m e2e_testing.main --config=onnx -v --filter NonzeroCumsumBoolModule # pass in torch-mlir, failed in iree + +# python -m e2e_testing.main --config=onnx -v --filter NonzeroLongModule echo "::endgroup::" # case $torch_version in diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index be868bf0488e..b3e365568ab0 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5814,7 +5814,8 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern { // // DEBUG one dimentional result. // auto unsqueezedResultType = ValueTensorType::get( - // rewriter.getContext(), SmallVector{kUnknownSize, 1}, si64Type); + // rewriter.getContext(), SmallVector{kUnknownSize, 1}, + // si64Type); // Value unsqueezedResult = rewriter.create( // loc, unsqueezedResultType, slicedResult, constantZero); // rewriter.replaceOp(op, unsqueezedResult); @@ -5837,38 +5838,42 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern { Value flippedCumulativeProduct = rewriter.create( loc, shapeType, cumulativeProduct, makeOneElementList(constantZero)); - // strides = torch.cat([strides[1:], torch.tensor([1])]) - // strides[1:] - auto slicedStrideType = Torch::ValueTensorType::get( - rewriter.getContext(), SmallVector{inputRank - 1}, // sizes - si64Type); - Value strideSliceEnd = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputRank)); - Value slicedStrides = rewriter.create( - loc, slicedStrideType, /*self*/ flippedCumulativeProduct, - /*dim*/ constantZero, - /*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne); + // strides = torch.cat([strides[1:-1], torch.tensor([1])]) // torch.tensor([1]) - auto oneTensorType = ValueTensorType::get( - rewriter.getContext(), SmallVector{1}, si64Type); + auto oneTensorType = ValueTensorType::get(rewriter.getContext(), + SmallVector{}, si64Type); Value oneTensor = rewriter.create( - loc, oneTensorType, constantOne, si64Dtype, noneCst, noneCst, - noneCst); - // torch.cat - auto tensorListElementType = Torch::ValueTensorType::get( - rewriter.getContext(), SmallVector{kUnknownSize}, si64Type); - Value tensorList = rewriter.create( - loc, Torch::ListType::get(tensorListElementType), - SmallVector{slicedStrides, oneTensor}); - Value strides = rewriter.create(loc, shapeType, - tensorList, - constantZero); + loc, oneTensorType, constantOne, si64Dtype, noneCst, noneCst, noneCst); + + Value strides; + if (inputRank > 1) { + // strides[1:-1] + auto slicedStrideType = Torch::ValueTensorType::get( + rewriter.getContext(), SmallVector{inputRank - 1}, // sizes + si64Type); + Value strideSliceEnd = rewriter.create( + loc, rewriter.getI64IntegerAttr(inputRank)); + Value slicedStrides = rewriter.create( + loc, slicedStrideType, /*self*/ flippedCumulativeProduct, + /*dim*/ constantZero, + /*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne); + // torch.cat + auto tensorListElementType = Torch::ValueTensorType::get( + rewriter.getContext(), SmallVector{kUnknownSize}, si64Type); + Value tensorList = rewriter.create( + loc, Torch::ListType::get(tensorListElementType), + SmallVector{slicedStrides, oneTensor}); + strides = rewriter.create(loc, shapeType, tensorList, + constantZero); + } else { + // strides[1:-1] is empty + strides = oneTensor; + } // multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) % // input_shape_tensor auto unsqueezedResultType = ValueTensorType::get( - rewriter.getContext(), SmallVector{kUnknownSize, 1}, - si64Type); + rewriter.getContext(), SmallVector{kUnknownSize, 1}, si64Type); Value unsqueezedResult = rewriter.create( loc, unsqueezedResultType, slicedResult, constantOne); diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index a64a521fce60..dee420a1d929 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -6430,3 +6430,159 @@ def forward(self, x): @register_test_case(module_factory=lambda: AtenNonzero1DModule()) def AtenNonzero1DModule_one_nonzero(module, tu: TestUtils): module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.int)) + + +class NonzeroDecomposeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.int, True), + ] + ) + def forward(self, t): + t_flat = t.flatten() + nonzero_mask = t_flat != 0 + nonzero_mask = nonzero_mask.long() + destination_indices = torch.cumsum(nonzero_mask, 0) - 1 + destination_indices_clamp = torch.clamp(destination_indices, min=0) + iota = torch.arange(t_flat.size(0)) * nonzero_mask + scatter_self = torch.zeros_like(t_flat, dtype=torch.int64) + # compacted = scatter_self.scatter_( + # dim=0, + # index=destination_indices_clamp, + # src=iota, + # reduce='add' + # ) + compacted = torch.scatter_add( + scatter_self, dim=0, index=destination_indices_clamp, src=iota + ) + result_flat = compacted[: torch.sum(nonzero_mask)] + + # multi dim + original_shape = t.shape + input_shape_tensor = torch.tensor(original_shape) + strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0) + + one = torch.tensor([1]) + if(t.dim() > 1): + slicedStrides = strides[1:-1] + strides = torch.cat([slicedStrides, one]) + # return strides + else: + strides = one + a = result_flat.unsqueeze(1) # tensor([[2], [3]]) torch.Size([2, 1]) + b = strides.unsqueeze(0) # tensor([[1]]) torch.Size([1, 1]) + c = a // b + multi_indices = c % input_shape_tensor + return multi_indices + + +@register_test_case(module_factory=lambda: NonzeroDecomposeModule()) +def NonzeroDecomposeModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.int)) + + +class NonzeroFlattenDynamicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.int, True), + ] + ) + def forward(self, x): + return x.flatten() + + +@register_test_case(module_factory=lambda: NonzeroFlattenDynamicModule()) +def NonzeroFlattenDynamicModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.int)) + + +class NonzeroCatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.int, True), + ] + ) + def forward(self, a): + a = a[1:1] + b = torch.tensor([1]) + return torch.cat([a, b]) + + +@register_test_case(module_factory=lambda: NonzeroCatModule()) +def NonzeroCatModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([6], dtype=torch.int)) + + +class NonzeroCumsumModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ] + ) + def forward(self, x): + return torch.cumsum(x, 0) + + +@register_test_case(module_factory=lambda: NonzeroCumsumModule()) +def NonzeroCumsumModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.int64)) + + +class NonzeroCumsumBoolModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.bool, True), + ] + ) + def forward(self, x): + return torch.cumsum(x.long(), 0) + + +@register_test_case(module_factory=lambda: NonzeroCumsumBoolModule()) +def NonzeroCumsumBoolModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool)) + + +class NonzeroLongModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.bool, True), + ] + ) + def forward(self, x): + return x.long() + + +@register_test_case(module_factory=lambda: NonzeroLongModule()) +def NonzeroLongModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py index 709702ca0054..43645e0a31e0 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py @@ -1065,9 +1065,9 @@ def forward(self, input, index, src): @register_test_case(module_factory=lambda: ScatterAddDynamicModule()) def ScatterAddDynamicModule_basic(module, tu: TestUtils): module.forward( - torch.tensor([0, 0, 0, 0, 0, 0]), - torch.tensor([0, 0, 0, 0, 0, 0]), - torch.tensor([0, 0, 0, 3, 0, 0]) + torch.tensor([0, 0, 0, 0, 0, 0]), + torch.tensor([0, 0, 0, 0, 0, 0]), + torch.tensor([0, 0, 0, 3, 0, 0]), )