Skip to content

Commit

Permalink
[ONNX] Fix ONNX-to-Torch NonZero result type and delete multi-dim conversion for debug
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis committed Dec 9, 2024
1 parent 3104e4f commit 508be9f
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 67 deletions.
23 changes: 21 additions & 2 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1101,8 +1101,27 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.tensorResultType(resultType)) {
return failure();
}
rewriter.replaceOpWithNewOp<Torch::AtenNonzeroOp>(
binder.op, resultType, operand);
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
Value one = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
auto rawSize = resultType.getSizes();
SmallVector<int64_t> torchResultSize(rawSize.rbegin(),
rawSize.rend());
auto torchResultSizeType = Torch::ValueTensorType::get(
rewriter.getContext(), torchResultSize,
rewriter.getIntegerType(64, /*signed=*/true));
auto nonZero = rewriter.create<Torch::AtenNonzeroOp>(
binder.getLoc(), torchResultSizeType, operand);
// The ONNX NonZero output tensor has shape (n, z), where (n) is the
// number of dimensions in the input tensor and (z) is the number of
// non-zero elements. This is transposed relative to PyTorch's
// torch.nonzero, which returns shape (z, n) — hence the transpose
// inserted below.
rewriter.replaceOpWithNewOp<Torch::AtenTransposeIntOp>(
binder.op, resultType, nonZero, zero, one);
return success();
});
patterns.onOp(
Expand Down
138 changes: 74 additions & 64 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5726,15 +5726,6 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
auto inputType = dyn_cast<BaseTensorType>(input.getType());
int64_t inputRank = inputType.getSizes().size();

// original_shape = t.shape
// input_shape_tensor = torch.tensor(original_shape, device=t.device)
auto shapeType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{inputRank}, si64Type);
// %2 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?],i1> ->
// !torch.vtensor<[1],si64>
Value inputShapeTensor =
rewriter.create<Torch::Aten_ShapeAsTensorOp>(loc, shapeType, input);

// t_flat = t.flatten() # torch.flatten(t, 0, 0)
int64_t flattenedSize = 1;
if (inputType.hasSizes()) {
Expand Down Expand Up @@ -5821,64 +5812,83 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
/*end=*/numNonzero,
/*step=*/constantOne);

// Convert flattened indices back to multi-dimensional indices
// strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0)
Value flippedShape = rewriter.create<AtenFlipOp>(
loc, shapeType, inputShapeTensor, makeOneElementList(constantZero));
Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
loc, shapeType, flippedShape, constantZero, noneCst);
Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
loc, shapeType, cumulativeProduct, makeOneElementList(constantZero));

// strides = torch.cat([strides[1:], torch.tensor([1])])
// strides[1:]
auto slicedStrideType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
si64Type);
Value strideSliceEnd = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank));
Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
loc, slicedStrideType, /*self*/ flippedCumulativeProduct,
/*dim*/ constantZero,
/*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne);
// torch.tensor([1])
auto oneTensorType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{1}, si64Type);
Value oneTensor = rewriter.create<AtenScalarTensorOp>(
loc, oneTensorType, constantOne, si64Dtype, noneCst, noneCst, noneCst);
// torch.cat
auto tensorListElementType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(tensorListElementType),
SmallVector<Value>{slicedStrides, oneTensor});
Value strides = rewriter.create<Torch::AtenCatOp>(loc, shapeType,
tensorList, constantZero);

// multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
// input_shape_tensor
// DEBUG: temporarily emit only the one-dimensional (flattened) result;
// the multi-dimensional index reconstruction below is commented out.
auto unsqueezedResultType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1}, si64Type);
Value unsqueezedResult = rewriter.create<AtenUnsqueezeOp>(
loc, unsqueezedResultType, slicedResult, constantOne);

auto unsqueezedStridesType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{1, inputRank}, si64Type);
Value unsqueezedStrides = rewriter.create<AtenUnsqueezeOp>(
loc, unsqueezedStridesType, strides, constantZero);

auto dividedBroadcastType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, inputRank},
si64Type);
Value divided = rewriter.create<AtenFloorDivideOp>(
loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides);

auto resultType = cast<BaseTensorType>(op.getType());
Value modded = rewriter.create<AtenRemainderTensorOp>(
loc, resultType, divided, inputShapeTensor);

rewriter.replaceOp(op, modded);
return success();
loc, unsqueezedResultType, slicedResult, constantZero);
rewriter.replaceOp(op, unsqueezedResult);
return success();

// // Convert flattened indices back to multi-dimensional indices
// // original_shape = t.shape
// // input_shape_tensor = torch.tensor(original_shape)
// auto shapeType = Torch::ValueTensorType::get(
// rewriter.getContext(), SmallVector<int64_t>{inputRank}, si64Type);
// // %2 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?],i1> ->
// // !torch.vtensor<[1],si64>
// Value inputShapeTensor =
// rewriter.create<Torch::Aten_ShapeAsTensorOp>(loc, shapeType, input);
// // strides = torch.cumprod(torch.flip(input_shape_tensor,[0]),0).flip(0)
// Value flippedShape = rewriter.create<AtenFlipOp>(
// loc, shapeType, inputShapeTensor, makeOneElementList(constantZero));
// Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
// loc, shapeType, flippedShape, constantZero, noneCst);
// Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
// loc, shapeType, cumulativeProduct, makeOneElementList(constantZero));

// // strides = torch.cat([strides[1:], torch.tensor([1])])
// // strides[1:]
// auto slicedStrideType = Torch::ValueTensorType::get(
// rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
// si64Type);
// Value strideSliceEnd = rewriter.create<ConstantIntOp>(
// loc, rewriter.getI64IntegerAttr(inputRank));
// Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
// loc, slicedStrideType, /*self*/ flippedCumulativeProduct,
// /*dim*/ constantZero,
// /*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne);
// // torch.tensor([1])
// auto oneTensorType = ValueTensorType::get(
// rewriter.getContext(), SmallVector<int64_t>{1}, si64Type);
// Value oneTensor = rewriter.create<AtenScalarTensorOp>(
// loc, oneTensorType, constantOne, si64Dtype, noneCst, noneCst,
// noneCst);
// // torch.cat
// auto tensorListElementType = Torch::ValueTensorType::get(
// rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
// Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
// loc, Torch::ListType::get(tensorListElementType),
// SmallVector<Value>{slicedStrides, oneTensor});
// Value strides = rewriter.create<Torch::AtenCatOp>(loc, shapeType,
// tensorList,
// constantZero);

// // multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
// // input_shape_tensor
// auto unsqueezedResultType = ValueTensorType::get(
// rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1},
// si64Type);
// Value unsqueezedResult = rewriter.create<AtenUnsqueezeOp>(
// loc, unsqueezedResultType, slicedResult, constantOne);

// auto unsqueezedStridesType = ValueTensorType::get(
// rewriter.getContext(), SmallVector<int64_t>{1, inputRank}, si64Type);
// Value unsqueezedStrides = rewriter.create<AtenUnsqueezeOp>(
// loc, unsqueezedStridesType, strides, constantZero);

// auto dividedBroadcastType = ValueTensorType::get(
// rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, inputRank},
// si64Type);
// Value divided = rewriter.create<AtenFloorDivideOp>(
// loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides);

// auto resultType = cast<BaseTensorType>(op.getType());
// Value modded = rewriter.create<AtenRemainderTensorOp>(
// loc, resultType, divided, inputShapeTensor);

// rewriter.replaceOp(op, modded);
// return success();
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6429,4 +6429,4 @@ def forward(self, x):

@register_test_case(module_factory=lambda: AtenNonzero1DModule())
def AtenNonzero1DModule_one_nonzero(module, tu: TestUtils):
    # NOTE(review): the input now contains TWO nonzero elements, so the
    # "_one_nonzero" suffix in the test name is stale — rename once any
    # test filters / xfail sets referencing this name are updated.
    module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.int))
29 changes: 29 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,35 @@ def ScatterAddStaticModule_basic(module, tu: TestUtils):
# ==============================================================================


class ScatterAddDynamicModule(torch.nn.Module):
    """E2E test module: aten.scatter_add along dim 0 with fully dynamic
    1-D int64 operands (input, index, and src all annotated as [-1])."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.int64, True),
            ([-1], torch.int64, True),
            ([-1], torch.int64, True),
        ]
    )
    def forward(self, input, index, src):
        # Accumulate src into input at the positions given by index.
        scatter_dim = 0
        return torch.ops.aten.scatter_add(input, scatter_dim, index, src)


@register_test_case(module_factory=lambda: ScatterAddDynamicModule())
def ScatterAddDynamicModule_basic(module, tu: TestUtils):
    # All three operands are 1-D int64 tensors (torch.tensor defaults to
    # int64 for Python ints), matching the module's ([-1], torch.int64)
    # annotations; shapes are dynamic at compile time.
    module.forward(
        torch.tensor([0, 0, 0, 0, 0, 0]),
        torch.tensor([0, 0, 0, 0, 0, 0]),
        torch.tensor([0, 0, 0, 3, 0, 0]),
    )


# ==============================================================================


class ScatterReduceFloatModule(torch.nn.Module):
include_self: bool
reduce_type: str
Expand Down

0 comments on commit 508be9f

Please sign in to comment.