
Commit

Add decomposenonzero e2e tests. Fix oneTensorType size. Fix torch.cat empty tensor.
AmosLewis committed Dec 13, 2024
1 parent 5a87158 commit 65f4a3d
Showing 4 changed files with 211 additions and 31 deletions.
21 changes: 20 additions & 1 deletion build_tools/ci/test_posix.sh
@@ -9,7 +9,26 @@ torch_version="${1:-unknown}"
export PYTHONPATH="$repo_root/build/tools/torch-mlir/python_packages/torch_mlir:$repo_root/projects/pt1"

echo "::group::Run ONNX e2e integration tests"
python -m e2e_testing.main --config=onnx -v --filter AtenNonzero1DModule_one_nonzero
# python -m e2e_testing.main --config=onnx -v --filter AtenNonzero1DModule_one_nonzero

# python -m e2e_testing.main --config=onnx -v --filter NonzeroDecomposeModule_basic # Failed: 1
# python -m e2e_testing.main --config=linalg -v --filter NonzeroDecomposeModule_basic # Passed: 1

# python -m e2e_testing.main --config=linalg -v --filter NonzeroFlattenDynamicModule # Passed: 1

# python -m e2e_testing.main --config=onnx -v --filter ScatterAddDynamicModule_basic # Passed: 1

# python -m e2e_testing.main --config=onnx -v --filter NonzeroCatModule # Passed: 1
# python -m e2e_testing.main --config=linalg -v --filter NonzeroCatModule # Failed: 1
# tensor with unknown dtype "torch.aten.cat"(%31, %4) : (!torch.list<vtensor>, !torch.int) -> !torch.vtensor<[1],unk>

# python -m e2e_testing.main --config=linalg -v --filter NonzeroCumsumModule
# python -m e2e_testing.main --config=onnx -v --filter NonzeroCumsumModule # pass
python -m e2e_testing.main --config=onnx -v --filter NonzeroCumsumBoolModule # passes in torch-mlir, fails in IREE

# python -m e2e_testing.main --config=onnx -v --filter NonzeroLongModule
echo "::endgroup::"

# case $torch_version in
59 changes: 32 additions & 27 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5814,7 +5814,8 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {

// // DEBUG one-dimensional result.
// auto unsqueezedResultType = ValueTensorType::get(
// rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1}, si64Type);
// rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1},
// si64Type);
// Value unsqueezedResult = rewriter.create<AtenUnsqueezeOp>(
// loc, unsqueezedResultType, slicedResult, constantZero);
// rewriter.replaceOp(op, unsqueezedResult);
@@ -5837,38 +5838,42 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
loc, shapeType, cumulativeProduct, makeOneElementList(constantZero));

// strides = torch.cat([strides[1:], torch.tensor([1])])
// strides[1:]
auto slicedStrideType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
si64Type);
Value strideSliceEnd = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank));
Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
loc, slicedStrideType, /*self*/ flippedCumulativeProduct,
/*dim*/ constantZero,
/*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne);
// strides = torch.cat([strides[1:-1], torch.tensor([1])])
// torch.tensor([1])
auto oneTensorType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{1}, si64Type);
auto oneTensorType = ValueTensorType::get(rewriter.getContext(),
SmallVector<int64_t>{}, si64Type);
Value oneTensor = rewriter.create<AtenScalarTensorOp>(
loc, oneTensorType, constantOne, si64Dtype, noneCst, noneCst,
noneCst);
// torch.cat
auto tensorListElementType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(tensorListElementType),
SmallVector<Value>{slicedStrides, oneTensor});
Value strides = rewriter.create<Torch::AtenCatOp>(loc, shapeType,
tensorList,
constantZero);
loc, oneTensorType, constantOne, si64Dtype, noneCst, noneCst, noneCst);

Value strides;
if (inputRank > 1) {
// strides[1:-1]
auto slicedStrideType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
si64Type);
Value strideSliceEnd = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank));
Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
loc, slicedStrideType, /*self*/ flippedCumulativeProduct,
/*dim*/ constantZero,
/*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne);
// torch.cat
auto tensorListElementType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(tensorListElementType),
SmallVector<Value>{slicedStrides, oneTensor});
strides = rewriter.create<Torch::AtenCatOp>(loc, shapeType, tensorList,
constantZero);
} else {
// strides[1:-1] is empty
strides = oneTensor;
}

// multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
// input_shape_tensor
auto unsqueezedResultType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1},
si64Type);
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1}, si64Type);
Value unsqueezedResult = rewriter.create<AtenUnsqueezeOp>(
loc, unsqueezedResultType, slicedResult, constantOne);

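For intuition, the stride arithmetic this decomposition emits can be sketched in eager PyTorch (an illustrative sketch, not part of the commit; the (2, 3) shape and the index values are assumed):

import torch

# Hypothetical (2, 3) input: row-major strides should be [3, 1].
shape = torch.tensor([2, 3])
strides = torch.cumprod(torch.flip(shape, [0]), 0).flip(0)   # tensor([6, 3])
strides = torch.cat([strides[1:], torch.tensor([1])])        # tensor([3, 1])

# Unravel flat index 4 into multi-dimensional indices (1, 1).
flat = torch.tensor([4])
multi = (flat.unsqueeze(1) // strides.unsqueeze(0)) % shape  # tensor([[1, 1]])
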
156 changes: 156 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -6430,3 +6430,159 @@ def forward(self, x):
@register_test_case(module_factory=lambda: AtenNonzero1DModule())
def AtenNonzero1DModule_one_nonzero(module, tu: TestUtils):
module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.int))


class NonzeroDecomposeModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1], torch.int, True),
]
)
def forward(self, t):
t_flat = t.flatten()
nonzero_mask = t_flat != 0
nonzero_mask = nonzero_mask.long()
destination_indices = torch.cumsum(nonzero_mask, 0) - 1
destination_indices_clamp = torch.clamp(destination_indices, min=0)
iota = torch.arange(t_flat.size(0)) * nonzero_mask
scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
# compacted = scatter_self.scatter_(
# dim=0,
# index=destination_indices_clamp,
# src=iota,
# reduce='add'
# )
compacted = torch.scatter_add(
scatter_self, dim=0, index=destination_indices_clamp, src=iota
)
result_flat = compacted[: torch.sum(nonzero_mask)]

# multi dim
original_shape = t.shape
input_shape_tensor = torch.tensor(original_shape)
strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0)

one = torch.tensor([1])
        if t.dim() > 1:
slicedStrides = strides[1:-1]
strides = torch.cat([slicedStrides, one])
# return strides
else:
strides = one
a = result_flat.unsqueeze(1) # tensor([[2], [3]]) torch.Size([2, 1])
b = strides.unsqueeze(0) # tensor([[1]]) torch.Size([1, 1])
c = a // b
multi_indices = c % input_shape_tensor
return multi_indices


@register_test_case(module_factory=lambda: NonzeroDecomposeModule())
def NonzeroDecomposeModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.int))
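
As an eager-mode sanity check (illustrative, not part of the commit), the decomposition above should agree with torch.nonzero:

import torch

t = torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.int)
print(torch.nonzero(t))             # tensor([[2], [3]])
print(NonzeroDecomposeModule()(t))  # expected to match: tensor([[2], [3]])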


class NonzeroFlattenDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1], torch.int, True),
]
)
def forward(self, x):
return x.flatten()


@register_test_case(module_factory=lambda: NonzeroFlattenDynamicModule())
def NonzeroFlattenDynamicModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.int))


class NonzeroCatModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1], torch.int, True),
]
)
def forward(self, a):
a = a[1:1]
b = torch.tensor([1])
return torch.cat([a, b])


@register_test_case(module_factory=lambda: NonzeroCatModule())
def NonzeroCatModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([6], dtype=torch.int))
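
Per the commit message, the behavior this test pins down is torch.cat with an empty tensor. A minimal eager sketch (illustrative; int64 promotion for the mixed-dtype cat is assumed):

import torch

a = torch.tensor([6], dtype=torch.int)[1:1]  # empty slice, shape [0], int32
b = torch.tensor([1])                        # default int64
print(torch.cat([a, b]))                     # tensor([1]); promoted to int64

The linalg failure noted in test_posix.sh (result type !torch.vtensor<[1],unk>) likely stems from the result dtype of this mixed-dtype cat being left unknown.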


class NonzeroCumsumModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1], torch.int64, True),
]
)
def forward(self, x):
return torch.cumsum(x, 0)


@register_test_case(module_factory=lambda: NonzeroCumsumModule())
def NonzeroCumsumModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.int64))


class NonzeroCumsumBoolModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1], torch.bool, True),
]
)
def forward(self, x):
return torch.cumsum(x.long(), 0)


@register_test_case(module_factory=lambda: NonzeroCumsumBoolModule())
def NonzeroCumsumBoolModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool))
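
For context, cumsum over the cast boolean nonzero mask is what produces the destination indices in the decomposition, which is why the bool input path matters. A minimal eager sketch (illustrative):

import torch

mask = torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool)
dest = torch.cumsum(mask.long(), 0) - 1   # tensor([-1, -1,  0,  1,  1,  1])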


class NonzeroLongModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1], torch.bool, True),
]
)
def forward(self, x):
return x.long()


@register_test_case(module_factory=lambda: NonzeroLongModule())
def NonzeroLongModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool))
6 changes: 3 additions & 3 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
@@ -1065,9 +1065,9 @@ def forward(self, input, index, src):
@register_test_case(module_factory=lambda: ScatterAddDynamicModule())
def ScatterAddDynamicModule_basic(module, tu: TestUtils):
module.forward(
torch.tensor([0, 0, 0, 0, 0, 0]),
torch.tensor([0, 0, 0, 0, 0, 0]),
torch.tensor([0, 0, 0, 3, 0, 0])
torch.tensor([0, 0, 0, 0, 0, 0]),
torch.tensor([0, 0, 0, 0, 0, 0]),
torch.tensor([0, 0, 0, 3, 0, 0]),
)


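ScatterAddDynamicModule appears here because the nonzero decomposition compacts flat indices via scatter_add. A minimal eager sketch of that pattern (illustrative values, not part of the commit):

import torch

mask = torch.tensor([0, 0, 1, 1, 0, 0]).ne(0).long()   # [0, 0, 1, 1, 0, 0]
dest = torch.clamp(torch.cumsum(mask, 0) - 1, min=0)   # [0, 0, 0, 1, 1, 1]
iota = torch.arange(6) * mask                          # [0, 0, 2, 3, 0, 0]
out = torch.zeros(6, dtype=torch.int64).scatter_add(0, dest, iota)
print(out[: mask.sum()])                               # tensor([2, 3]): flat nonzero indices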
