[TOSA] Add legalization for torch.aten.unfold #3922

Merged · 1 commit · Dec 19, 2024
193 changes: 193 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -8256,6 +8256,198 @@ LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
return success();
}

// Legalization for aten.unfold
template <>
LogicalResult ConvertAtenOp<AtenUnfoldOp>::matchAndRewrite(
AtenUnfoldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Approach: Use GatherNdOp to retrieve target elements from the target dim and
// then reshape the output into slices according to the output shape
//
// Lowering steps:
// 1. Create PyTorch-style indices tensor corresponding to target elements and
// reshape them to (d_0, d_1, ..., nWindows * size, ..., d_(rank - 1))
// with d_x being the dimension size of the input at dim x.
// The indices vector will be calculated using the following formula:
// for i in range(d_0 * d_1 * ... * d_(target_dim - 1)):
// for window in range(nWindows):
// for elementIndex in range(size):
// for j in range(d_(target_dim + 1) * ... * d_(rank-1)):
// indices_vec.push_back(elementIndex + window * step)
// 2. Convert PyTorch-style indices and target dim to TensorFlow-style indices
// 3. Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve
// target elements
// 4. Reshape result from above to correct output shape
auto self = adaptor.getSelf();

auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

auto selfShape = selfType.getShape();
auto selfRank = selfType.getRank();
auto selfElemTy = selfType.getElementType();

auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
auto resultElemTy = resultType.getElementType();

int64_t dim;
if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op,
"Only constant int dims are supported");

int64_t size;
if (!matchPattern(op.getSize(), m_TorchConstantInt(&size)))
return rewriter.notifyMatchFailure(op,
"Only constant int sizes are supported");

int64_t step;
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
return rewriter.notifyMatchFailure(op,
"Only constant int steps are supported");

if (step <= 0)
return rewriter.notifyMatchFailure(op, "Step value must be greater than 0");

// Handle rank zero
if (selfRank == 0) {
if (dim != 0)
return rewriter.notifyMatchFailure(
op, "Unsupported dim value for rank zero input");

if (size != 1)
return rewriter.notifyMatchFailure(
op, "Unsupported size value for rank zero input");

auto result = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), RankedTensorType::get({1}, selfElemTy), self,
rewriter.getDenseI64ArrayAttr({1}));

rewriter.replaceOp(op, {result.getResult()});
return success();
}

dim = toPositiveDim(dim, selfRank);
if (!isValidDim(dim, selfRank))
return rewriter.notifyMatchFailure(op, "Dim value is invalid");

// Size of dimension 'dim' in the returned tensor (or number of windows within
// the dimension that got sliced)
int64_t nWindows = (selfShape[dim] - size) / step + 1;

// Find the number of times that each base index value gets repeated for the
// target dim, based on the dim sizes before and after the target dim, i.e.:
// preDimAccumulate = d_0 * d_1 * ... * d_(target_dim - 1)
// postDimAccumulate = d_(target_dim + 1) * ... * d_(rank - 1)
int64_t preDimAccumulate =
std::accumulate(selfShape.begin(), selfShape.begin() + dim, 1,
std::multiplies<int64_t>());
int64_t postDimAccumulate =
std::accumulate(selfShape.begin() + dim + 1, selfShape.end(), 1,
std::multiplies<int64_t>());

// Calculate PyTorch-style gather indices vector
// Example: shape = (2, 4, 3), dim = 1, size = 3, step = 1
// -> preDimAccumulate = 2, postDimAccumulate = 3, nWindows = 2
// pyTorchIndicesBaseVec = [0, 0, 0, 1, 1, 1, 2, 2, 2,
// 1, 1, 1, 2, 2, 2, 3, 3, 3]
// pyTorchIndicesVec = [0, 0, 0, 1, 1, 1, 2, 2, 2,
// 1, 1, 1, 2, 2, 2, 3, 3, 3,
// 0, 0, 0, 1, 1, 1, 2, 2, 2,
// 1, 1, 1, 2, 2, 2, 3, 3, 3]
SmallVector<int32_t> pyTorchIndicesBaseVec;
SmallVector<int32_t> pyTorchIndicesVec;

for (int64_t window = 0; window < nWindows; window++) {
for (int64_t elementIndex = 0; elementIndex < size; elementIndex++) {
int32_t baseIndex = static_cast<int32_t>(elementIndex + window * step);
for (int64_t i = 0; i < postDimAccumulate; i++)
pyTorchIndicesBaseVec.push_back(baseIndex);
}
}

for (int64_t i = 0; i < preDimAccumulate; i++)
pyTorchIndicesVec.insert(pyTorchIndicesVec.end(),
pyTorchIndicesBaseVec.begin(),
pyTorchIndicesBaseVec.end());

// Create the PyTorch-style indices tensor
// Continuing with the previous example:
// pyTorchIndicesShape = (2, nWindows * size, 3) = (2, 6, 3)
// pyTorchIndices = tensor([[[0, 0, 0],
// [1, 1, 1],
// [2, 2, 2],
// [1, 1, 1],
// [2, 2, 2],
// [3, 3, 3]],
// [[0, 0, 0],
// [1, 1, 1],
// [2, 2, 2],
// [1, 1, 1],
// [2, 2, 2],
// [3, 3, 3]]])
SmallVector<int64_t> pyTorchIndicesShape(selfShape);
pyTorchIndicesShape[dim] = nWindows * size;
auto pyTorchIndices =
tosa::getConstTensor<int32_t>(rewriter, op, pyTorchIndicesVec,
pyTorchIndicesShape)
.value();

// Convert PyTorch-style indices to TensorFlow-style indices
auto tfIndices = tosa::convertTorchIndexToTfIndices(rewriter, op, self,
pyTorchIndices, dim);
if (!tfIndices)
return rewriter.notifyMatchFailure(op,
"Convert PyTorch-style indices and dim "
"to TensorFlow-style indices failed");

// Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve
// target elements
auto gatherNdOp = tosa::convertGatherNdOp(
rewriter, op, RankedTensorType::get(pyTorchIndicesShape, resultElemTy),
self, tfIndices.value());
if (!gatherNdOp)
return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");

// Reshape to an intermediary shape where the gathered elements in dimension
// 'dim' are split back into 2 dimensions of sizes 'nWindows' and 'size'
SmallVector<int64_t> intermediaryShape;
for (int64_t currentDim = 0; currentDim < selfRank; currentDim++) {
if (currentDim == dim) {
intermediaryShape.push_back(nWindows);
intermediaryShape.push_back(size);
} else {
intermediaryShape.push_back(pyTorchIndicesShape[currentDim]);
}
}

auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), RankedTensorType::get(intermediaryShape, resultElemTy),
gatherNdOp.value(), rewriter.getDenseI64ArrayAttr(intermediaryShape));

// Permute dims to the correct result order
SmallVector<int32_t> permutedDims;
for (int64_t currentDim = 0; currentDim < selfRank + 1; currentDim++) {
if (currentDim != dim + 1)
permutedDims.push_back(static_cast<int32_t>(currentDim));
}
permutedDims.push_back(static_cast<int32_t>(dim + 1));

auto permutedDimsConst = tosa::getConstTensor<int32_t>(
rewriter, op,
/*vec=*/permutedDims,
/*shape=*/{static_cast<int32_t>(selfRank + 1)})
.value();

auto result = rewriter.create<tosa::TransposeOp>(
op->getLoc(), resultType, reshapeOp.getResult(), permutedDimsConst);

rewriter.replaceOp(op, {result.getResult()});

return success();
}

} // namespace

// -----------------------------------------------------------------------------
@@ -8617,6 +8809,7 @@ std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
INSERT_ATENOP_PATTERN(AtenLog1pOp);
INSERT_ATENOP_PATTERN(AtenLog10Op);
INSERT_ATENOP_PATTERN(AtenTanOp);
INSERT_ATENOP_PATTERN(AtenUnfoldOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
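To make the four lowering steps described in the comments of ConvertAtenOp<AtenUnfoldOp> above easier to follow, here is a minimal Python sketch of the same gather / reshape / transpose scheme, using the worked example from the code (shape (2, 4, 3), dim = 1, size = 3, step = 1). The helper name unfold_via_gather and the use of torch.gather are illustrative only; the patch itself emits tosa.gather on a flattened tensor with linearized (TensorFlow-style) indices via convertTorchIndexToTfIndices and convertGatherNdOp.

```python
import torch

def unfold_via_gather(x: torch.Tensor, dim: int, size: int, step: int) -> torch.Tensor:
    shape = list(x.shape)
    n_windows = (shape[dim] - size) // step + 1

    # Step 1: per-element indices along `dim`: elementIndex + window * step,
    # broadcast across all other dims (the lowering materializes this as a
    # constant tensor of shape idx_shape).
    base = [e + w * step for w in range(n_windows) for e in range(size)]
    idx_shape = shape.copy()
    idx_shape[dim] = n_windows * size
    bcast_shape = [1] * len(shape)
    bcast_shape[dim] = n_windows * size
    indices = torch.tensor(base).view(*bcast_shape).expand(*idx_shape).contiguous()

    # Steps 2-3: gather the target elements along `dim`.
    gathered = torch.gather(x, dim, indices)

    # Step 4: split `dim` into (nWindows, size), then move `size` to the end.
    inter_shape = shape[:dim] + [n_windows, size] + shape[dim + 1:]
    perm = [d for d in range(len(inter_shape)) if d != dim + 1] + [dim + 1]
    return gathered.reshape(inter_shape).permute(*perm)

x = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
assert torch.equal(unfold_via_gather(x, 1, 3, 1), x.unfold(1, 3, 1))
```

Splitting the gathered dimension into (nWindows, size) and transposing size to the end mirrors PyTorch's unfold contract, where the window elements form the new trailing dimension.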
18 changes: 7 additions & 11 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1696,6 +1696,8 @@
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleSumdims_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
"ScatterSrcModule_basic",
"ScatterSrcStaticModule_basic",
"HBC_basic",
@@ -1704,6 +1706,9 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseIntTensorLtFloatScalarModule_basic",
"ElementwiseSigmoidIntModule_basic",
@@ -3440,6 +3445,8 @@
}

FX_IMPORTER_TOSA_XFAIL_SET = {
"UniformModule_basic",
"UniformStaticShapeModule_basic",
"AtenFftRfft2DLastDim_basic",
"AtenFftRfft2DMiddleDim_basic",
"IsInfiniteModule_basic",
@@ -3459,11 +3466,7 @@
"MaxPool3dModule_basic",
"MaxPool3dStaticModule_basic",
"ViewDtypeStaticModule_basic",
"Unfold_Module_Dynamic_basic",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_Size_Zero_basic",
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_basic",
"ArangeZeroElementOutputModule_basic",
"NumpyTRank0Module_basic",
"Permute0RankModule_basic",
@@ -3887,17 +3890,10 @@
"AdaptiveAvgPool2dDynamic_basic",
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
"ElementwiseRreluTrainModule_basic",
"ElementwiseRreluTrainStaticModule_basic",
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic",
"IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IouOfModule_basic",
"MeshgridIndexingIJ_basic",
"MeshgridIndexingXY_basic",
"Meshgrid_basic",
"OneHotModule_basic",
"ReduceFrobeniusNormKeepDimModule_basic",
"ReduceFrobeniusNormModule_basic",
"ScaledDotProductAttentionBoolMaskModule_basic",
@@ -1752,7 +1752,7 @@ def forward(self, x):
return x.unfold(0, 0, 1)


@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero())
@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero_Size_Zero())
def Unfold_Module_Rank_Zero_Size_Zero_basic(module, tu: TestUtils):
module.forward(tu.rand())

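For context on the rank-zero tests referenced above, here is a small illustrative snippet (assuming PyTorch is available) of how unfold behaves on a 0-d input; size == 1 is the only size the new TOSA lowering accepts for rank-zero inputs, and it is legalized as a plain reshape to a one-element tensor.

```python
import torch

x = torch.tensor(2.5)   # rank-zero input, as produced by tu.rand()
y = x.unfold(0, 1, 1)   # the size == 1, step == 1 case handled by the lowering
print(y, y.shape)       # tensor([2.5000]) torch.Size([1])
```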
50 changes: 50 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -2943,3 +2943,53 @@ func.func @torch.aten.pow.Tensor_Tensor$intfloat(%arg0: !torch.vtensor<[3,4,5],s
}

// -----

// CHECK-LABEL: func.func @torch.aten.unfold$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[6,4],f32> -> tensor<6x4xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5]]> : tensor<6x4xi32>}> : () -> tensor<6x4xi32>
// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 6, 4, 1>} : (tensor<6x4xi32>) -> tensor<6x4x1xi32>
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]]> : tensor<6x4x1xi32>}> : () -> tensor<6x4x1xi32>
// CHECK: %[[VAL_7:.*]] = tosa.concat %[[VAL_5]], %[[VAL_6]] {axis = 2 : i32} : (tensor<6x4x1xi32>, tensor<6x4x1xi32>) -> tensor<6x4x2xi32>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 24, 1>} : (tensor<6x4xf32>) -> tensor<1x24x1xf32>
// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 24, 2>} : (tensor<6x4x2xi32>) -> tensor<24x2xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[4, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_9]], %[[VAL_10]] {shift = 0 : i8} : (tensor<24x2xi32>, tensor<2xi32>) -> tensor<24x2xi32>
// CHECK: %[[VAL_12:.*]] = tosa.reduce_sum %[[VAL_11]] {axis = 1 : i32} : (tensor<24x2xi32>) -> tensor<24x1xi32>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 1, 24>} : (tensor<24x1xi32>) -> tensor<1x24xi32>
// CHECK: %[[VAL_14:.*]] = tosa.gather %[[VAL_8]], %[[VAL_13]] : (tensor<1x24x1xf32>, tensor<1x24xi32>) -> tensor<1x24x1xf32>
// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 6, 4>} : (tensor<1x24x1xf32>) -> tensor<6x4xf32>
// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array<i64: 3, 2, 4>} : (tensor<6x4xf32>) -> tensor<3x2x4xf32>
// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_16]], %[[VAL_17]] : (tensor<3x2x4xf32>, tensor<3xi32>) -> tensor<3x4x2xf32>
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<3x4x2xf32> -> !torch.vtensor<[3,4,2],f32>
// CHECK: return %[[VAL_19]] : !torch.vtensor<[3,4,2],f32>
// CHECK: }
func.func @torch.aten.unfold$basic(%arg0: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%0 = torch.aten.unfold %arg0, %int0, %int2, %int2 : !torch.vtensor<[6,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4,2],f32>
return %0 : !torch.vtensor<[3,4,2],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.unfold$rank_zero(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[1],f32>
// CHECK: }
func.func @torch.aten.unfold$rank_zero(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.aten.unfold %arg0, %int0, %int1, %int1 : !torch.vtensor<[],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
}

// -----
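As a quick cross-check of the FileCheck expectations above (assuming PyTorch is available), the window-count formula used by the lowering predicts exactly the result type in the basic test:

```python
import torch

# nWindows = (d_dim - size) // step + 1, as computed in the lowering.
d, size, step = 6, 2, 2
assert (d - size) // step + 1 == 3

# Matches the result type of @torch.aten.unfold$basic: !torch.vtensor<[3,4,2],f32>.
assert torch.randn(6, 4).unfold(0, 2, 2).shape == (3, 4, 2)
```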