[Torch] Add decomposition for 1d torch.nonzero #3876

Merged · 9 commits · Dec 19, 2024
238 changes: 238 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5705,6 +5705,243 @@ class DecomposeAtenConvolutionBackwardOp
};
} // namespace

/**
* # one dim input
 * t = torch.tensor([0, 0, 1, 1, 0, 0])
* # t_flat:[0, 0, 1, 1, 0, 0]
* t_flat = t.flatten(0, 0)
* nonzero_mask = t_flat != 0
* # nonzero_mask:[0, 0, 1, 1, 0, 0]
* nonzero_mask = nonzero_mask.long()
* # destination_indices:[-1, -1, 0, 1, 1, 1]
* destination_indices = torch.cumsum(nonzero_mask, 0) - 1
* # destination_indices_clamp:[0, 0, 0, 1, 1, 1]
* destination_indices_clamp = torch.clamp(destination_indices, min=0)
* # iota:[0, 0, 2, 3, 0, 0]
* iota = torch.arange(t_flat.size(0)) * nonzero_mask
* # scatter_self:[0, 0, 0, 0, 0, 0]
* scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
* # compacted:[2, 3, 0, 0, 0, 0]
* compacted = torch.scatter_add(
* scatter_self, dim=0, index=destination_indices_clamp, src=iota
* )
* # result_flat:[2, 3]
* result_flat = compacted[: torch.sum(nonzero_mask)]
*
* # multi dim support
* original_shape = t.shape
* # input_shape_tensor:[6]
* input_shape_tensor = torch.tensor(original_shape)
* strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0)
*
* one = torch.tensor([1])
 * if t.dim() > 1:
 *     slicedStrides = strides[1:]
 *     strides = torch.cat([slicedStrides, one])
 * else:
 *     strides = one
* # a: tensor([[2], [3]]) torch.Size([2, 1])
 * a = result_flat.unsqueeze(1)
* # b: tensor([[1]]) torch.Size([1, 1])
* b = strides.unsqueeze(0)
* # c: tensor([[2], [3]]) torch.Size([2, 1])
* c = a // b
* # result: tensor([[2], [3]]) torch.Size([2, 1])
* result = c % input_shape_tensor
*/
class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNonzeroOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto resultType = cast<BaseTensorType>(op.getType());
auto intType = resultType.getDtype();
Value intTypeValue = getDtypeIntValueForType(rewriter, loc, intType);
auto constantZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
auto constantOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
std::function<Value(Value)> makeOneElementList = [&](Value element) {
auto listType = Torch::ListType::get(element.getType());
return rewriter.create<PrimListConstructOp>(loc, listType,
ArrayRef<Value>{element});
};

Value input = op.getSelf();
auto inputType = dyn_cast<BaseTensorType>(input.getType());
if (!inputType || !inputType.hasSizes())
return rewriter.notifyMatchFailure(
op, "expected input to be a ranked tensor with known sizes");
int64_t inputRank = inputType.getSizes().size();

// t_flat = t.flatten() # torch.flatten(t, 0, 0)
int64_t flattenedSize = 1;
if (inputType.hasSizes()) {
for (auto size : inputType.getSizes()) {
flattenedSize *= size;
}
} else {
flattenedSize = kUnknownSize;
}

auto flattenedInputShape = SmallVector<int64_t>{flattenedSize};
auto flattenedInputType = rewriter.getType<Torch::ValueTensorType>(
flattenedInputShape, inputType.getOptionalDtype());

// %1 = torch.aten.flatten.using_ints %arg0, %int0, %int0_0
auto inputDimsEnd = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank - 1));
Value flattenedInput = rewriter.create<AtenFlattenUsingIntsOp>(
loc, flattenedInputType, input, constantZero /*inputDimsStart*/,
inputDimsEnd /*inputDimsEnd*/);

// nonzero_mask = (t_flat != 0)
auto boolMaskType = inputType.getWithSizesAndDtype(
flattenedInputType.getOptionalSizes(), rewriter.getI1Type());
Value boolMask = rewriter.create<AtenNeScalarOp>(
loc, boolMaskType, flattenedInput, constantZero);

// nonzero_mask = nonzero_mask.int()
Value falseCst = rewriter.create<ConstantBoolOp>(loc, false);
Value noneCst = rewriter.create<ConstantNoneOp>(loc);
auto intMaskType = flattenedInputType.getWithSizesAndDtype(
flattenedInputType.getOptionalSizes(), intType);
Value intMask = rewriter.create<AtenToDtypeOp>(
loc, intMaskType, boolMask, intTypeValue, falseCst, falseCst, noneCst);

// destination_indices = torch.cumsum(nonzero_mask, 0) - 1
Value cumulativeSum = rewriter.create<AtenCumsumOp>(
loc, intMaskType, intMask, constantZero, noneCst);
Value subtracted = rewriter.create<AtenSubScalarOp>(
loc, intMaskType, cumulativeSum, constantOne, /*alpha=*/constantOne);

// destination_indices = torch.clamp(destination_indices, min=0)
Value indices = rewriter.create<AtenClampMinOp>(loc, intMaskType,
subtracted, constantZero);

// iota = torch.arange(len(t_flat)) * nonzero_mask
Value end = rewriter.create<AtenSizeIntOp>(loc, flattenedInput,
/*dim=*/constantZero);
Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
loc, intMaskType, /*start*/ constantZero, /*end*/ end,
/*step*/ constantOne, noneCst, noneCst, noneCst, noneCst);
Value multiplied = rewriter.create<AtenMulTensorOp>(loc, intMaskType,
rangeTensor, intMask);

// scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
// AtenFullLike doesn't support index type so we have to use int.
Value zerosTensor = rewriter.create<AtenZerosLikeOp>(
loc, intMaskType, flattenedInput, intTypeValue, noneCst, noneCst,
noneCst, noneCst);

// compacted = torch.scatter_add(
// scatter_self, dim=0, index=destination_indices_clamp, src=iota)
Value scatteredTensor = rewriter.create<AtenScatterAddOp>(
loc, intMaskType, /*self*/ zerosTensor, /*dim=*/constantZero,
/*index=*/indices, /*src=*/multiplied);

// result_flat = compacted[:torch.sum(nonzero_mask)]
auto scalarType = ValueTensorType::get(rewriter.getContext(),
ArrayRef<int64_t>{}, intType);
Value sumMask =
rewriter.create<AtenSumOp>(loc, scalarType, intMask, noneCst);
Value numNonzero = rewriter.create<AtenIntTensorOp>(loc, sumMask);

auto slicedResultType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, intType);
Value slicedResult =
rewriter.create<AtenSliceTensorOp>(loc, slicedResultType,
/*self=*/scatteredTensor,
/*dim=*/constantZero,
/*start=*/noneCst,
/*end=*/numNonzero,
/*step=*/constantOne);

// Convert flattened indices back to multi-dimensional indices
// original_shape = t.shape
// input_shape_tensor = torch.tensor(original_shape)
auto shapeType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{inputRank}, intType);
SmallVector<Value> shapeValues;
for (int i = 0; i < inputRank; i++) {
auto constantI =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
Value shape = rewriter.create<AtenSizeIntOp>(loc, input,
/*dim=*/constantI);
shapeValues.push_back(shape);
}
Value shapeTensorList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(shapeValues[0].getType()), shapeValues);
Value inputShapeTensor = rewriter.create<Torch::AtenTensorOp>(
loc, shapeType, shapeTensorList, noneCst, noneCst, falseCst);

// strides = torch.cumprod(torch.flip(input_shape_tensor,[0]),0).flip(0)
Value flippedShape = rewriter.create<AtenFlipOp>(
loc, shapeType, inputShapeTensor, makeOneElementList(constantZero));
Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
loc, shapeType, flippedShape, constantZero, noneCst);
Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
loc, shapeType, cumulativeProduct, makeOneElementList(constantZero));

// strides = torch.cat([strides[1:], torch.tensor([1])])
auto oneTensorType = ValueTensorType::get(rewriter.getContext(),
SmallVector<int64_t>{}, intType);
Value oneTensor = rewriter.create<AtenScalarTensorOp>(
loc, oneTensorType, constantOne, intTypeValue, noneCst, noneCst,
noneCst);

Value strides;
if (inputRank > 1) {
// strides[1:]
auto slicedStrideType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
intType);
Value strideSliceEnd = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank));
Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
loc, slicedStrideType, /*self*/ flippedCumulativeProduct,
/*dim*/ constantZero,
/*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne);
// torch.cat
auto tensorListElementType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, intType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(tensorListElementType),
SmallVector<Value>{slicedStrides, oneTensor});
strides = rewriter.create<Torch::AtenCatOp>(loc, shapeType, tensorList,
constantZero);
} else {
// strides[1:] is empty
strides = oneTensor;
}
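// Illustrative example (easy to verify by hand): for an input of shape
// [2, 3], inputShapeTensor is [2, 3], the flipped cumulative product is
// [6, 3], slicing [1:] yields [3], and concatenating [1] gives the
// row-major strides [3, 1].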

// multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
// input_shape_tensor
auto unsqueezedResultType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1}, intType);
Value unsqueezedResult = rewriter.create<AtenUnsqueezeOp>(
loc, unsqueezedResultType, slicedResult, constantOne);

auto unsqueezedStridesType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{1, inputRank}, intType);
Value unsqueezedStrides = rewriter.create<AtenUnsqueezeOp>(
loc, unsqueezedStridesType, strides, constantZero);

auto dividedBroadcastType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, inputRank},
intType);
Value divided = rewriter.create<AtenFloorDivideOp>(
loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides);

Value modded = rewriter.create<AtenRemainderTensorOp>(
loc, resultType, divided, inputShapeTensor);
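// Illustrative example: for t = [[0, 1, 0], [1, 0, 0]], slicedResult is
// [1, 3], divided is [[0, 1], [1, 3]], and modded is [[0, 1], [1, 0]],
// which matches torch.nonzero(t).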

rewriter.replaceOp(op, modded);
return success();
}
};

// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
namespace {
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
@@ -11263,6 +11500,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTanhBackwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNonzeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -399,6 +399,7 @@
"AtenIntBoolOpModule_basic",
"AtenIntMM_basic",
"AtenItemFpOpModule_basic",
"AtenNonzero1DDynamicModule_basic", # no lowering for torch.aten.sym_constrain_range_for_size
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"QuantizedReluInt32_basic",
@@ -628,6 +629,7 @@
"AtenMmQMixedSigni8_basic",
"AtenMmQint8_basic",
"AtenMmQuint8_basic",
"AtenNonzero1DDynamicModule_basic",
"AtenRealView128Module_basic",
"AtenRealView64Module_basic",
"AtenTopKModule_basic",
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -6407,3 +6407,26 @@ def AtenPolarDoubleModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 5, 3, 4).to(torch.float64), tu.rand(2, 5, 3, 4).to(torch.float64)
)


# ==============================================================================


class AtenNonzero1DDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1], torch.bool, True),
]
)
def forward(self, x):
return torch.ops.aten.nonzero(x)


@register_test_case(module_factory=lambda: AtenNonzero1DDynamicModule())
def AtenNonzero1DDynamicModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool))
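

For reference, the decomposition can be sanity-checked against torch.nonzero in eager mode. The sketch below mirrors the pseudocode at the top of DecomposeComplexOps.cpp; the helper name nonzero_1d_decomposed is illustrative only and is not part of this PR:

import torch


def nonzero_1d_decomposed(t):
    # Mirrors the decomposition: flatten, build the nonzero mask, cumsum to get
    # destination indices, scatter_add the flat positions, slice to the nonzero
    # count, then convert the flat indices back to per-dimension indices.
    t_flat = t.flatten(0, t.dim() - 1)
    nonzero_mask = (t_flat != 0).long()
    destination_indices = torch.clamp(torch.cumsum(nonzero_mask, 0) - 1, min=0)
    iota = torch.arange(t_flat.size(0)) * nonzero_mask
    compacted = torch.zeros_like(t_flat, dtype=torch.int64).scatter_add(
        0, destination_indices, iota
    )
    result_flat = compacted[: int(nonzero_mask.sum())]
    shape = torch.tensor(t.shape)
    if t.dim() > 1:
        strides = torch.cumprod(shape.flip(0), 0).flip(0)
        strides = torch.cat([strides[1:], torch.tensor([1])])
    else:
        strides = torch.tensor([1])
    return result_flat.unsqueeze(1) // strides.unsqueeze(0) % shape


t = torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool)
assert torch.equal(nonzero_1d_decomposed(t), torch.nonzero(t))  # [[2], [3]]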