From 508be9f8401a6171131c5eb8f525608cd706d528 Mon Sep 17 00:00:00 2001
From: AmosLewis
Date: Thu, 5 Dec 2024 23:05:24 -0800
Subject: [PATCH] [ONNX] Fix onnx-to-torch NonZero result type and delete the
 multi-dim index path for debug

---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    |  23 ++-
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 138 ++++++++++--------
 .../torch_mlir_e2e_test/test_suite/basic.py   |   2 +-
 .../torch_mlir_e2e_test/test_suite/scatter.py |  29 ++++
 4 files changed, 125 insertions(+), 67 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 7446b7faaa08..0d1b02fde1ab 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1101,8 +1101,27 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.tensorResultType(resultType)) {
           return failure();
         }
-        rewriter.replaceOpWithNewOp<Torch::AtenNonzeroOp>(
-            binder.op, resultType, operand);
+        Value zero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+        Value one = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
+        auto rawSize = resultType.getSizes();
+        SmallVector<int64_t> torchResultSize(rawSize.rbegin(),
+                                             rawSize.rend());
+        auto torchResultSizeType = Torch::ValueTensorType::get(
+            rewriter.getContext(), torchResultSize,
+            rewriter.getIntegerType(64, /*signed=*/true));
+        auto nonZero = rewriter.create<Torch::AtenNonzeroOp>(
+            binder.getLoc(), torchResultSizeType, operand);
+        // The ONNX output tensor has a shape of (n, z), where n is the
+        // number of dimensions in the input tensor and z is the
+        // number of non-zero elements. This is different from
+        // PyTorch's default behavior, where the two dimensions are
+        // reversed.
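+        // For example (a worked illustration, not part of the op
+        // documentation): for input [[1, 0], [1, 1]] the non-zero
+        // coordinates are (0, 0), (1, 0), and (1, 1). torch.nonzero
+        // produces [[0, 0], [1, 0], [1, 1]] of shape (z, n) = (3, 2),
+        // while ONNX NonZero expects the transpose [[0, 1, 1], [0, 0, 1]]
+        // of shape (n, z) = (2, 3); hence the transpose below.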
+        rewriter.replaceOpWithNewOp<Torch::AtenTransposeIntOp>(
+            binder.op, resultType, nonZero, zero, one);
         return success();
       });
   patterns.onOp(
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 62a8bcdd7ba1..f0b7bf2cdc79 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5726,15 +5726,6 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
     auto inputType = dyn_cast<BaseTensorType>(input.getType());
     int64_t inputRank = inputType.getSizes().size();
 
-    // original_shape = t.shape
-    // input_shape_tensor = torch.tensor(original_shape, device=t.device)
-    auto shapeType = Torch::ValueTensorType::get(
-        rewriter.getContext(), SmallVector<int64_t>{inputRank}, si64Type);
-    // %2 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?],i1> ->
-    // !torch.vtensor<[1],si64>
-    Value inputShapeTensor =
-        rewriter.create<Torch::Aten_ShapeAsTensorOp>(loc, shapeType, input);
-
     // t_flat = t.flatten() # torch.flatten(t, 0, 0)
     int64_t flattenedSize = 1;
     if (inputType.hasSizes()) {
@@ -5821,64 +5812,83 @@
         /*end=*/numNonzero, /*step=*/constantOne);
 
-    // Convert flattened indices back to multi-dimensional indices
-    // strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0)
-    Value flippedShape = rewriter.create<AtenFlipOp>(
-        loc, shapeType, inputShapeTensor, makeOneElementList(constantZero));
-    Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
-        loc, shapeType, flippedShape, constantZero, noneCst);
-    Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
-        loc, shapeType, cumulativeProduct, makeOneElementList(constantZero));
-
-    // strides = torch.cat([strides[1:], torch.tensor([1])])
-    // strides[1:]
-    auto slicedStrideType = Torch::ValueTensorType::get(
-        rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
-        si64Type);
-    Value strideSliceEnd = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(inputRank));
-    Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
-        loc, slicedStrideType, /*self*/ flippedCumulativeProduct,
-        /*dim*/ constantZero,
-        /*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne);
-    // torch.tensor([1])
-    auto oneTensorType = ValueTensorType::get(
-        rewriter.getContext(), SmallVector<int64_t>{1}, si64Type);
-    Value oneTensor = rewriter.create<AtenScalarTensorOp>(
-        loc, oneTensorType, constantOne, si64Dtype, noneCst, noneCst, noneCst);
-    // torch.cat
-    auto tensorListElementType = Torch::ValueTensorType::get(
-        rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
-    Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
-        loc, Torch::ListType::get(tensorListElementType),
-        SmallVector<Value>{slicedStrides, oneTensor});
-    Value strides = rewriter.create<Torch::AtenCatOp>(loc, shapeType,
-                                                      tensorList, constantZero);
-
-    // multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
-    // input_shape_tensor
+    // DEBUG: one-dimensional result.
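+    // NOTE: with the multi-dimensional reconstruction below commented out,
+    // the decomposition returns the flattened indices directly (plus one
+    // unsqueezed dimension), which matches torch.nonzero semantics only
+    // for effectively one-dimensional inputs.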
     auto unsqueezedResultType = ValueTensorType::get(
         rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1}, si64Type);
     Value unsqueezedResult = rewriter.create<AtenUnsqueezeOp>(
-        loc, unsqueezedResultType, slicedResult, constantOne);
-
-    auto unsqueezedStridesType = ValueTensorType::get(
-        rewriter.getContext(), SmallVector<int64_t>{1, inputRank}, si64Type);
-    Value unsqueezedStrides = rewriter.create<AtenUnsqueezeOp>(
-        loc, unsqueezedStridesType, strides, constantZero);
-
-    auto dividedBroadcastType = ValueTensorType::get(
-        rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, inputRank},
-        si64Type);
-    Value divided = rewriter.create<AtenFloorDivideOp>(
-        loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides);
-
-    auto resultType = cast<BaseTensorType>(op.getType());
-    Value modded = rewriter.create<AtenRemainderTensorOp>(
-        loc, resultType, divided, inputShapeTensor);
-
-    rewriter.replaceOp(op, modded);
-    return success();
+        loc, unsqueezedResultType, slicedResult, constantZero);
+    rewriter.replaceOp(op, unsqueezedResult);
+    return success();
+
+    // // Convert flattened indices back to multi-dimensional indices
+    // // original_shape = t.shape
+    // // input_shape_tensor = torch.tensor(original_shape)
+    // auto shapeType = Torch::ValueTensorType::get(
+    //     rewriter.getContext(), SmallVector<int64_t>{inputRank}, si64Type);
+    // // %2 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?],i1> ->
+    // // !torch.vtensor<[1],si64>
+    // Value inputShapeTensor =
+    //     rewriter.create<Torch::Aten_ShapeAsTensorOp>(loc, shapeType, input);
+    // // strides = torch.cumprod(torch.flip(input_shape_tensor,[0]),0).flip(0)
+    // Value flippedShape = rewriter.create<AtenFlipOp>(
+    //     loc, shapeType, inputShapeTensor, makeOneElementList(constantZero));
+    // Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
+    //     loc, shapeType, flippedShape, constantZero, noneCst);
+    // Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
+    //     loc, shapeType, cumulativeProduct, makeOneElementList(constantZero));
+
+    // // strides = torch.cat([strides[1:], torch.tensor([1])])
+    // // strides[1:]
+    // auto slicedStrideType = Torch::ValueTensorType::get(
+    //     rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
+    //     si64Type);
+    // Value strideSliceEnd = rewriter.create<Torch::ConstantIntOp>(
+    //     loc, rewriter.getI64IntegerAttr(inputRank));
+    // Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
+    //     loc, slicedStrideType, /*self*/ flippedCumulativeProduct,
+    //     /*dim*/ constantZero,
+    //     /*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne);
+    // // torch.tensor([1])
+    // auto oneTensorType = ValueTensorType::get(
+    //     rewriter.getContext(), SmallVector<int64_t>{1}, si64Type);
+    // Value oneTensor = rewriter.create<AtenScalarTensorOp>(
+    //     loc, oneTensorType, constantOne, si64Dtype, noneCst, noneCst,
+    //     noneCst);
+    // // torch.cat
+    // auto tensorListElementType = Torch::ValueTensorType::get(
+    //     rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
+    // Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
+    //     loc, Torch::ListType::get(tensorListElementType),
+    //     SmallVector<Value>{slicedStrides, oneTensor});
+    // Value strides = rewriter.create<Torch::AtenCatOp>(loc, shapeType,
+    //                                                   tensorList,
+    //                                                   constantZero);
+
+    // // multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
+    // // input_shape_tensor
+    // auto unsqueezedResultType = ValueTensorType::get(
+    //     rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1},
+    //     si64Type);
+    // Value unsqueezedResult = rewriter.create<AtenUnsqueezeOp>(
+    //     loc, unsqueezedResultType, slicedResult, constantOne);
+
+    // auto unsqueezedStridesType = ValueTensorType::get(
+    //     rewriter.getContext(), SmallVector<int64_t>{1, inputRank}, si64Type);
+    // Value unsqueezedStrides = rewriter.create<AtenUnsqueezeOp>(
+    //     loc, unsqueezedStridesType, strides, constantZero);
+
+    // auto dividedBroadcastType = ValueTensorType::get(
+    //     rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, inputRank},
+    //     si64Type);
+    // Value divided = rewriter.create<AtenFloorDivideOp>(
+    //     loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides);
+
+    // auto resultType = cast<BaseTensorType>(op.getType());
+    // Value modded = rewriter.create<AtenRemainderTensorOp>(
+    //     loc, resultType, divided, inputShapeTensor);
+
+    // rewriter.replaceOp(op, modded);
+    // return success();
   }
 };
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 3c6b4930f55c..a64a521fce60 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -6429,4 +6429,4 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: AtenNonzero1DModule())
 def AtenNonzero1DModule_one_nonzero(module, tu: TestUtils):
-    module.forward(torch.tensor([0, 0, 0, 1, 0, 0], dtype=torch.int))
+    module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.int))
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
index ee85855e4aa8..709702ca0054 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
@@ -1045,6 +1045,35 @@ def ScatterAddStaticModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ScatterAddDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.int64, True),
+            ([-1], torch.int64, True),
+            ([-1], torch.int64, True),
+        ]
+    )
+    def forward(self, input, index, src):
+        return torch.ops.aten.scatter_add(input, 0, index, src)
+
+
+@register_test_case(module_factory=lambda: ScatterAddDynamicModule())
+def ScatterAddDynamicModule_basic(module, tu: TestUtils):
+    module.forward(
+        torch.tensor([0, 0, 0, 0, 0, 0]),
+        torch.tensor([0, 0, 0, 0, 0, 0]),
+        torch.tensor([0, 0, 0, 3, 0, 0]),
+    )
+
+
+# ==============================================================================
+
+
 class ScatterReduceFloatModule(torch.nn.Module):
     include_self: bool
     reduce_type: str