Commit

Fix scatter op
AmosLewis committed Dec 6, 2024
1 parent 1ac257b commit 9c47fd1
Showing 1 changed file with 15 additions and 22 deletions.
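The fix replaces the AtenScatterReduceTwoOp lowering (reduce="sum", include_self=false) with AtenScatterAddOp. The two agree whenever the destination tensor is all zeros, which holds here because the scatter target is built with zeros_like. A quick PyTorch check of that equivalence (the index and source values are illustrative, not from the commit):

import torch

zeros = torch.zeros(4, dtype=torch.int64)
idx = torch.tensor([0, 0, 1, 2])
src = torch.tensor([0, 1, 3, 4])

# scatter_reduce(reduce="sum", include_self=False) drops the destination's
# own values; scatter_add always includes them. With an all-zero
# destination the two reductions coincide.
a = zeros.scatter_reduce(0, idx, src, reduce="sum", include_self=False)
b = zeros.scatter_add(0, idx, src)
assert torch.equal(a, b)  # tensor([1, 3, 4, 0]) in both cases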
37 changes: 15 additions & 22 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5687,53 +5687,46 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
     Value boolMask = rewriter.create<AtenNeScalarOp>(
         loc, boolMaskType, flattenedInput, constantZero);
 
-    // nonzero_mask = nonzero_mask.int()
+    // nonzero_mask = nonzero_mask.int()
     Value falseCst = rewriter.create<ConstantBoolOp>(loc, false);
     Value noneCst = rewriter.create<ConstantNoneOp>(loc);
     auto intMaskType = flattenedInputType.getWithSizesAndDtype(
-        flattenedInputType.getOptionalSizes(), si64Type); // ####
+        flattenedInputType.getOptionalSizes(), si64Type);
     Value intMask = rewriter.create<AtenToDtypeOp>(
         loc, intMaskType, boolMask, si64Dtype, falseCst, falseCst, noneCst);
 
     // destination_indices = torch.cumsum(nonzero_mask, 0) - 1
-    auto cumulativeSumType =
-        dyn_cast<BaseTensorType>(flattenedInputType.getWithSizesAndDtype(
-            flattenedInputType.getOptionalSizes(), si64Type));
     Value cumulativeSum = rewriter.create<AtenCumsumOp>(
-        loc, cumulativeSumType, intMask, constantZero, noneCst);
-    Value subtracted =
-        rewriter.create<AtenSubScalarOp>(loc, cumulativeSumType, cumulativeSum,
-                                         constantOne, /*alpha=*/constantOne);
+        loc, intMaskType, intMask, constantZero, noneCst);
+    Value subtracted = rewriter.create<AtenSubScalarOp>(
+        loc, intMaskType, cumulativeSum, constantOne, /*alpha=*/constantOne);
 
     // destination_indices = torch.clamp(destination_indices, min=0)
-    Value indices = rewriter.create<AtenClampMinOp>(loc, cumulativeSumType,
+    Value indices = rewriter.create<AtenClampMinOp>(loc, intMaskType,
                                                     subtracted, constantZero);
 
-    // iota = torch.arange(len(t_flat), device=t.device) * nonzero_mask
+    // iota = torch.arange(len(t_flat)) * nonzero_mask
     Value end = rewriter.create<AtenSizeIntOp>(loc, flattenedInput,
                                                /*dim=*/constantZero);
     Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
-        loc, cumulativeSumType, constantZero, end, constantOne, noneCst,
-        noneCst, noneCst, noneCst);
-    Value multiplied = rewriter.create<AtenMulTensorOp>(loc, cumulativeSumType,
+        loc, intMaskType, /*start*/ constantZero, /*end*/ end,
+        /*step*/ constantOne, noneCst, noneCst, noneCst, noneCst);
+    Value multiplied = rewriter.create<AtenMulTensorOp>(loc, intMaskType,
                                                         rangeTensor, intMask);
 
     // scatter_self = torch.zeros_like(t, dtype=torch.int64)
     // AtenFullLike doesn't support index type so we have to use si64
-    auto zerosTensorType = cumulativeSumType.getWithSizesAndDtype(
-        cumulativeSumType.getOptionalSizes(), si64Type);
     Value zerosTensor = rewriter.create<AtenZerosLikeOp>(
-        loc, zerosTensorType, cumulativeSum, si64Dtype, noneCst, noneCst,
-        noneCst, noneCst);
+        loc, intMaskType, flattenedInput, si64Dtype, noneCst, noneCst, noneCst,
+        noneCst);
 
     // compacted = scatter_self.scatter_(
     //    dim=0,
     //    index=destination_indices,
     //    src=iota, reduce='add')
-    Value reduceStr = rewriter.create<ConstantStrOp>(loc, "sum");
-    Value scatteredTensor = rewriter.create<AtenScatterReduceTwoOp>(
-        loc, cumulativeSumType, zerosTensor, /*axis=*/constantZero,
-        /*dims=*/indices, /*src=*/multiplied, reduceStr, falseCst);
+    Value scatteredTensor = rewriter.create<AtenScatterAddOp>(
+        loc, intMaskType, /*self*/ zerosTensor, /*dim=*/constantZero,
+        /*index=*/indices, /*src=*/multiplied);
 
     // result_flat = compacted[:torch.sum(nonzero_mask)]
     auto scalarType = ValueTensorType::get(rewriter.getContext(),
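For reference, a minimal PyTorch sketch of the decomposition the rewritten pattern emits, following the Python comments in the diff; it assumes a 1-D input (the C++ pattern flattens the input first), and the nonzero_1d name and sample tensor are illustrative:

import torch

def nonzero_1d(t_flat: torch.Tensor) -> torch.Tensor:
    # nonzero_mask = (t_flat != 0), cast to int for arithmetic.
    nonzero_mask = (t_flat != 0).to(torch.int64)
    # destination_indices = torch.cumsum(nonzero_mask, 0) - 1, clamped at 0,
    # so each nonzero element gets a compacted output slot.
    destination_indices = torch.clamp(
        torch.cumsum(nonzero_mask, 0) - 1, min=0)
    # iota holds each element's original position, zeroed where t_flat == 0.
    iota = torch.arange(len(t_flat)) * nonzero_mask
    # scatter_self = torch.zeros_like(t, dtype=torch.int64)
    scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
    # The commit lowers the additive scatter to scatter_add: zero elements
    # all collide at slot 0, but their iota contributions are already
    # zeroed by the mask, so the sums stay correct.
    compacted = scatter_self.scatter_add_(0, destination_indices, iota)
    # result_flat = compacted[:torch.sum(nonzero_mask)]
    return compacted[: nonzero_mask.sum()]

# Example: matches torch.nonzero on a 1-D tensor.
t = torch.tensor([0, 3, 0, 5, 7, 0])
assert torch.equal(nonzero_1d(t), torch.nonzero(t).flatten())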
