Skip to content

Commit

Permalink
Fix scatter op
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis committed Dec 7, 2024
1 parent f4f26db commit 3104e4f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 30 deletions.
55 changes: 26 additions & 29 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5763,53 +5763,46 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
Value boolMask = rewriter.create<AtenNeScalarOp>(
loc, boolMaskType, flattenedInput, constantZero);

// nonzero_mask = nonzero_mask.int()
// nonzero_mask = nonzero_mask.int()
Value falseCst = rewriter.create<ConstantBoolOp>(loc, false);
Value noneCst = rewriter.create<ConstantNoneOp>(loc);
auto intMaskType = flattenedInputType.getWithSizesAndDtype(
flattenedInputType.getOptionalSizes(), si64Type); // ####
flattenedInputType.getOptionalSizes(), si64Type);
Value intMask = rewriter.create<AtenToDtypeOp>(
loc, intMaskType, boolMask, si64Dtype, falseCst, falseCst, noneCst);

// destination_indices = torch.cumsum(nonzero_mask, 0) - 1
auto cumulativeSumType =
dyn_cast<BaseTensorType>(flattenedInputType.getWithSizesAndDtype(
flattenedInputType.getOptionalSizes(), si64Type));
Value cumulativeSum = rewriter.create<AtenCumsumOp>(
loc, cumulativeSumType, intMask, constantZero, noneCst);
Value subtracted =
rewriter.create<AtenSubScalarOp>(loc, cumulativeSumType, cumulativeSum,
constantOne, /*alpha=*/constantOne);
loc, intMaskType, intMask, constantZero, noneCst);
Value subtracted = rewriter.create<AtenSubScalarOp>(
loc, intMaskType, cumulativeSum, constantOne, /*alpha=*/constantOne);

// destination_indices = torch.clamp(destination_indices, min=0)
Value indices = rewriter.create<AtenClampMinOp>(loc, cumulativeSumType,
Value indices = rewriter.create<AtenClampMinOp>(loc, intMaskType,
subtracted, constantZero);

// iota = torch.arange(len(t_flat), device=t.device) * nonzero_mask
// iota = torch.arange(len(t_flat)) * nonzero_mask
Value end = rewriter.create<AtenSizeIntOp>(loc, flattenedInput,
/*dim=*/constantZero);
Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
loc, cumulativeSumType, constantZero, end, constantOne, noneCst,
noneCst, noneCst, noneCst);
Value multiplied = rewriter.create<AtenMulTensorOp>(loc, cumulativeSumType,
loc, intMaskType, /*start*/ constantZero, /*end*/ end,
/*step*/ constantOne, noneCst, noneCst, noneCst, noneCst);
Value multiplied = rewriter.create<AtenMulTensorOp>(loc, intMaskType,
rangeTensor, intMask);

// scatter_self = torch.zeros_like(t, dtype=torch.int64)
// AtenFullLike doesn't support index type so we have to use si64
auto zerosTensorType = cumulativeSumType.getWithSizesAndDtype(
cumulativeSumType.getOptionalSizes(), si64Type);
Value zerosTensor = rewriter.create<AtenZerosLikeOp>(
loc, zerosTensorType, cumulativeSum, si64Dtype, noneCst, noneCst,
noneCst, noneCst);
loc, intMaskType, flattenedInput, si64Dtype, noneCst, noneCst, noneCst,
noneCst);

// compacted = scatter_self.scatter_(
// dim=0,
// index=destination_indices,
// src=iota, reduce='add')
Value reduceStr = rewriter.create<ConstantStrOp>(loc, "sum");
Value scatteredTensor = rewriter.create<AtenScatterReduceTwoOp>(
loc, cumulativeSumType, zerosTensor, /*axis=*/constantZero,
/*dims=*/indices, /*src=*/multiplied, reduceStr, falseCst);
Value scatteredTensor = rewriter.create<AtenScatterAddOp>(
loc, intMaskType, /*self*/ zerosTensor, /*dim=*/constantZero,
/*index=*/indices, /*src=*/multiplied);

// result_flat = compacted[:torch.sum(nonzero_mask)]
auto scalarType = ValueTensorType::get(rewriter.getContext(),
Expand All @@ -5828,28 +5821,32 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
/*end=*/numNonzero,
/*step=*/constantOne);

// Convert flattened indices back to multi-dimensional indices
// strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0)
Value flippedShape = rewriter.create<AtenFlipOp>(
loc, shapeType, inputShapeTensor, makeOneElementList(constantZero));
Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
loc, shapeType, flippedShape, constantZero, noneCst);
Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
loc, shapeType, cumulativeProduct, makeOneElementList(constantZero));
// strides = torch.cat([strides[1:], torch.tensor([1], device=t.device)])
auto oneTensorType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{1}, si64Type);
Value oneTensor = rewriter.create<AtenScalarTensorOp>(
loc, oneTensorType, constantOne, si64Dtype, noneCst, noneCst, noneCst);

// strides = torch.cat([strides[1:], torch.tensor([1])])
// strides[1:]
auto slicedStrideType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
si64Type);
Value strideSliceEnd = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank));
Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
loc, slicedStrideType, flippedCumulativeProduct, /*dim*/ constantZero,
loc, slicedStrideType, /*self*/ flippedCumulativeProduct,
/*dim*/ constantZero,
/*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne);

// torch.tensor([1])
auto oneTensorType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{1}, si64Type);
Value oneTensor = rewriter.create<AtenScalarTensorOp>(
loc, oneTensorType, constantOne, si64Dtype, noneCst, noneCst, noneCst);
// torch.cat
auto tensorListElementType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6429,4 +6429,4 @@ def forward(self, x):

@register_test_case(module_factory=lambda: AtenNonzero1DModule())
def AtenNonzero1DModule_one_nonzero(module, tu: TestUtils):
module.forward(torch.tensor([0, 0, 5, 0, 0, 0], dtype=torch.int))
module.forward(torch.tensor([0, 0, 0, 1, 0, 0], dtype=torch.int))

0 comments on commit 3104e4f

Please sign in to comment.