diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 2b8ac02b761f..62a8bcdd7ba1 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5763,53 +5763,46 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
     Value boolMask = rewriter.create<AtenNeScalarOp>(
         loc, boolMaskType, flattenedInput, constantZero);
 
-    // nonzero_mask = nonzero_mask.int()
+    // nonzero_mask = nonzero_mask.int()
     Value falseCst = rewriter.create<ConstantBoolOp>(loc, false);
     Value noneCst = rewriter.create<ConstantNoneOp>(loc);
     auto intMaskType = flattenedInputType.getWithSizesAndDtype(
-        flattenedInputType.getOptionalSizes(), si64Type); // ####
+        flattenedInputType.getOptionalSizes(), si64Type);
     Value intMask = rewriter.create<AtenToDtypeOp>(
         loc, intMaskType, boolMask, si64Dtype, falseCst, falseCst, noneCst);
 
     // destination_indices = torch.cumsum(nonzero_mask, 0) - 1
-    auto cumulativeSumType =
-        dyn_cast<BaseTensorType>(flattenedInputType.getWithSizesAndDtype(
-            flattenedInputType.getOptionalSizes(), si64Type));
     Value cumulativeSum = rewriter.create<AtenCumsumOp>(
-        loc, cumulativeSumType, intMask, constantZero, noneCst);
-    Value subtracted =
-        rewriter.create<AtenSubScalarOp>(loc, cumulativeSumType, cumulativeSum,
-                                         constantOne, /*alpha=*/constantOne);
+        loc, intMaskType, intMask, constantZero, noneCst);
+    Value subtracted = rewriter.create<AtenSubScalarOp>(
+        loc, intMaskType, cumulativeSum, constantOne, /*alpha=*/constantOne);
 
     // destination_indices = torch.clamp(destination_indices, min=0)
-    Value indices = rewriter.create<AtenClampMinOp>(loc, cumulativeSumType,
+    Value indices = rewriter.create<AtenClampMinOp>(loc, intMaskType,
                                                     subtracted, constantZero);
 
-    // iota = torch.arange(len(t_flat), device=t.device) * nonzero_mask
+    // iota = torch.arange(len(t_flat)) * nonzero_mask
     Value end = rewriter.create<AtenSizeIntOp>(loc, flattenedInput,
                                                /*dim=*/constantZero);
     Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
-        loc, cumulativeSumType, constantZero, end, constantOne, noneCst,
-        noneCst, noneCst, noneCst);
-    Value multiplied = rewriter.create<AtenMulTensorOp>(loc, cumulativeSumType,
+        loc, intMaskType, /*start*/ constantZero, /*end*/ end,
+        /*step*/ constantOne, noneCst, noneCst, noneCst, noneCst);
+    Value multiplied = rewriter.create<AtenMulTensorOp>(loc, intMaskType,
                                                         rangeTensor, intMask);
 
     // scatter_self = torch.zeros_like(t, dtype=torch.int64)
     // AtenFullLike doesn't support index type so we have to use si64
-    auto zerosTensorType = cumulativeSumType.getWithSizesAndDtype(
-        cumulativeSumType.getOptionalSizes(), si64Type);
     Value zerosTensor = rewriter.create<AtenZerosLikeOp>(
-        loc, zerosTensorType, cumulativeSum, si64Dtype, noneCst, noneCst,
-        noneCst, noneCst);
+        loc, intMaskType, flattenedInput, si64Dtype, noneCst, noneCst, noneCst,
+        noneCst);
 
     // compacted = scatter_self.scatter_(
     //     dim=0,
     //     index=destination_indices,
     //     src=iota, reduce='add')
-    Value reduceStr = rewriter.create<ConstantStrOp>(loc, "sum");
-    Value scatteredTensor = rewriter.create<AtenScatterReduceTwoOp>(
-        loc, cumulativeSumType, zerosTensor, /*axis=*/constantZero,
-        /*dims=*/indices, /*src=*/multiplied, reduceStr, falseCst);
+    Value scatteredTensor = rewriter.create<AtenScatterSrcOp>(
+        loc, intMaskType, /*self*/ zerosTensor, /*dim=*/constantZero,
+        /*index=*/indices, /*src=*/multiplied);
 
     // result_flat = compacted[:torch.sum(nonzero_mask)]
     auto scalarType = ValueTensorType::get(rewriter.getContext(),
@@ -5828,6 +5821,7 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
         /*end=*/numNonzero,
         /*step=*/constantOne);
 
+    // Convert flattened indices back to multi-dimensional indices
     // strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0)
     Value flippedShape = rewriter.create<AtenFlipOp>(
         loc, shapeType, inputShapeTensor, makeOneElementList(constantZero));
@@ -5835,21 +5829,24 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
         loc, shapeType, flippedShape, constantZero, noneCst);
     Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
         loc, shapeType, cumulativeProduct, makeOneElementList(constantZero));
-    // strides = torch.cat([strides[1:], torch.tensor([1], device=t.device)])
-    auto oneTensorType = ValueTensorType::get(
-        rewriter.getContext(), SmallVector<int64_t>{1}, si64Type);
-    Value oneTensor = rewriter.create<AtenScalarTensorOp>(
-        loc, oneTensorType, constantOne, si64Dtype, noneCst, noneCst, noneCst);
+    // strides = torch.cat([strides[1:], torch.tensor([1])])
+    // strides[1:]
     auto slicedStrideType = Torch::ValueTensorType::get(
         rewriter.getContext(),
         SmallVector<int64_t>{inputRank - 1}, // sizes
         si64Type);
     Value strideSliceEnd = rewriter.create<ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(inputRank));
     Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
-        loc, slicedStrideType, flippedCumulativeProduct, /*dim*/ constantZero,
+        loc, slicedStrideType, /*self*/ flippedCumulativeProduct,
+        /*dim*/ constantZero,
         /*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne);
-
+    // torch.tensor([1])
+    auto oneTensorType = ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{1}, si64Type);
+    Value oneTensor = rewriter.create<AtenScalarTensorOp>(
+        loc, oneTensorType, constantOne, si64Dtype, noneCst, noneCst, noneCst);
+    // torch.cat
     auto tensorListElementType = Torch::ValueTensorType::get(
         rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
     Value tensorList = rewriter.create<PrimListConstructOp>(
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 72bd658b9439..3c6b4930f55c 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -6429,4 +6429,4 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: AtenNonzero1DModule())
 def AtenNonzero1DModule_one_nonzero(module, tu: TestUtils):
-    module.forward(torch.tensor([0, 0, 5, 0, 0, 0], dtype=torch.int))
+    module.forward(torch.tensor([0, 0, 0, 1, 0, 0], dtype=torch.int))
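
Reviewer note: for orientation, the decomposition this pattern emits corresponds
roughly to the eager PyTorch below, reconstructed from the pass's in-line
comments. It is a sketch only: the function and variable names are illustrative,
not identifiers from the pass, and it keeps the accumulating form (reduce='add')
from the old comment so that duplicate destination indices stay well defined in
eager mode, whereas the patch itself now emits the plain src variant of
aten.scatter.

    import torch

    def nonzero_decomposition_sketch(t: torch.Tensor) -> torch.Tensor:
        t_flat = t.flatten()
        # nonzero_mask = (t_flat != 0), cast to int64 like the si64 intMask.
        nonzero_mask = (t_flat != 0).to(torch.int64)

        # Slot each nonzero element should occupy in the compacted output.
        destination_indices = torch.cumsum(nonzero_mask, 0) - 1
        destination_indices = torch.clamp(destination_indices, min=0)

        # Flat position of every element, zeroed out where t_flat == 0.
        iota = torch.arange(len(t_flat)) * nonzero_mask

        # Compact the surviving flat positions to the front of the buffer.
        scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
        compacted = scatter_self.scatter_(
            dim=0, index=destination_indices, src=iota, reduce="add")
        result_flat = compacted[:torch.sum(nonzero_mask)]

        # Convert flattened indices back to multi-dimensional indices:
        # row-major strides from a flipped cumprod of the shape, shifted
        # left by one with a trailing 1.
        shape = torch.tensor(t.shape)
        strides = torch.cumprod(shape.flip(0), 0).flip(0)
        strides = torch.cat([strides[1:], torch.tensor([1])])
        return result_flat.unsqueeze(1) // strides % shape

On the updated test input, nonzero_decomposition_sketch(torch.tensor([0, 0, 0, 1, 0, 0]))
returns tensor([[3]]), agreeing with torch.nonzero.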