diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 1c05ae49e18b..be51712a35de 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -8256,6 +8256,198 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+// Legalization for aten.unfold
+template <>
+LogicalResult ConvertAtenOp<AtenUnfoldOp>::matchAndRewrite(
+    AtenUnfoldOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Approach: Use GatherOp to retrieve target elements from target dim and
+  // then reshape the output into slices according to the output shape
+  //
+  // Lowering steps:
+  // 1. Create PyTorch-style indices tensor corresponding to target elements
+  //    and reshape them to (d_0, d_1, ..., nWindows * size, ..., d_(rank - 1))
+  //    with d_x being the dimension size of the input at dim x.
+  //    The indices vector will be calculated using the following formula:
+  //      for i in range(d_0 * d_1 * ... * d_(target_dim - 1)):
+  //        for window in range(nWindows):
+  //          for elementIndex in range(size):
+  //            for j in range(d_(target_dim + 1) * ... * d_(rank - 1)):
+  //              indices_vec.push_back(elementIndex + window * step)
+  // 2. Convert PyTorch-style indices and target dim to TensorFlow-style
+  //    indices
+  // 3. Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve
+  //    target elements
+  // 4. Reshape result from above to correct output shape
+  auto self = adaptor.getSelf();
+
+  auto selfType = dyn_cast<TensorType>(self.getType());
+  if (!selfType)
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
+
+  auto selfShape = selfType.getShape();
+  auto selfRank = selfType.getRank();
+  auto selfElemTy = selfType.getElementType();
+
+  auto resultType =
+      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
+  auto resultElemTy = resultType.getElementType();
+
+  int64_t dim;
+  if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Only constant int dims are supported");
+
+  int64_t size;
+  if (!matchPattern(op.getSize(), m_TorchConstantInt(&size)))
+    return rewriter.notifyMatchFailure(
+        op, "Only constant int sizes are supported");
+
+  int64_t step;
+  if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
+    return rewriter.notifyMatchFailure(
+        op, "Only constant int steps are supported");
+
+  if (step <= 0)
+    return rewriter.notifyMatchFailure(op,
+                                       "Step value must be greater than 0");
+
+  // Handle rank zero
+  if (selfRank == 0) {
+    if (dim != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported dim value for rank zero input");
+
+    if (size != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported size value for rank zero input");
+
+    auto result = rewriter.create<tosa::ReshapeOp>(
+        op->getLoc(), RankedTensorType::get({1}, selfElemTy), self,
+        rewriter.getDenseI64ArrayAttr({1}));
+
+    rewriter.replaceOp(op, {result.getResult()});
+    return success();
+  }
+
+  dim = toPositiveDim(dim, selfRank);
+  if (!isValidDim(dim, selfRank))
+    return rewriter.notifyMatchFailure(op, "Dim value is invalid");
+
+  // Size of dimension 'dim' in the returned tensor (or number of windows
+  // within the dimension that got sliced)
+  int64_t nWindows = (selfShape[dim] - size) / step + 1;
+
+  // Find number of times that each base index value gets repeated for target
+  // dim based on dim values before and after target dim, i.e.
+  // preDimAccumulate = d_0 * d_1 * ... * d_(target_dim - 1)
+  // postDimAccumulate = d_(target_dim + 1) * ... * d_(rank - 1)
+  int64_t preDimAccumulate =
+      std::accumulate(selfShape.begin(), selfShape.begin() + dim, 1,
+                      std::multiplies<int64_t>());
+  int64_t postDimAccumulate =
+      std::accumulate(selfShape.begin() + dim + 1, selfShape.end(), 1,
+                      std::multiplies<int64_t>());
+
+  // Calculate PyTorch-style gather indices vector
+  // Example: shape = (2, 4, 3), dim = 1, size = 3, step = 1
+  //       -> preDimAccumulate = 2, postDimAccumulate = 3, nWindows = 2
+  //  pyTorchIndicesBaseVec = [0, 0, 0, 1, 1, 1, 2, 2, 2,
+  //                           1, 1, 1, 2, 2, 2, 3, 3, 3]
+  //  pyTorchIndicesVec = [0, 0, 0, 1, 1, 1, 2, 2, 2,
+  //                       1, 1, 1, 2, 2, 2, 3, 3, 3,
+  //                       0, 0, 0, 1, 1, 1, 2, 2, 2,
+  //                       1, 1, 1, 2, 2, 2, 3, 3, 3]
+  SmallVector<int32_t> pyTorchIndicesBaseVec;
+  SmallVector<int32_t> pyTorchIndicesVec;
+
+  for (int64_t window = 0; window < nWindows; window++) {
+    for (int64_t elementIndex = 0; elementIndex < size; elementIndex++) {
+      int32_t baseIndex = static_cast<int32_t>(elementIndex + window * step);
+      for (int64_t i = 0; i < postDimAccumulate; i++)
+        pyTorchIndicesBaseVec.push_back(baseIndex);
+    }
+  }
+
+  for (int64_t i = 0; i < preDimAccumulate; i++)
+    pyTorchIndicesVec.insert(pyTorchIndicesVec.end(),
+                             pyTorchIndicesBaseVec.begin(),
+                             pyTorchIndicesBaseVec.end());
+
+  // Create the PyTorch-style indices tensor
+  // Continuing with the previous example:
+  //  pyTorchIndicesShape = (2, nWindows * size, 3) = (2, 6, 3)
+  //  pyTorchIndices = tensor([[[0, 0, 0],
+  //                            [1, 1, 1],
+  //                            [2, 2, 2],
+  //                            [1, 1, 1],
+  //                            [2, 2, 2],
+  //                            [3, 3, 3]],
+  //                           [[0, 0, 0],
+  //                            [1, 1, 1],
+  //                            [2, 2, 2],
+  //                            [1, 1, 1],
+  //                            [2, 2, 2],
+  //                            [3, 3, 3]]])
+  SmallVector<int64_t> pyTorchIndicesShape(selfShape);
+  pyTorchIndicesShape[dim] = nWindows * size;
+  auto pyTorchIndices =
+      tosa::getConstTensor<int32_t>(rewriter, op, pyTorchIndicesVec,
+                                    pyTorchIndicesShape)
+          .value();
+
+  // Convert PyTorch-style indices to TensorFlow-style indices
+  auto tfIndices = tosa::convertTorchIndexToTfIndices(rewriter, op, self,
+                                                      pyTorchIndices, dim);
+  if (!tfIndices)
+    return rewriter.notifyMatchFailure(op,
+                                       "Convert PyTorch-style indices and dim "
+                                       "to TensorFlow-style indices failed");
+
+  // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve
+  // target elements
+  auto gatherNdOp = tosa::convertGatherNdOp(
+      rewriter, op, RankedTensorType::get(pyTorchIndicesShape, resultElemTy),
+      self, tfIndices.value());
+  if (!gatherNdOp)
+    return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");
+
+  // Reshape to an intermediary shape where the gathered elements in dimension
+  // 'dim' are split back into 2 dimensions of sizes 'nWindows' and 'size'
+  SmallVector<int64_t> intermediaryShape;
+  for (int64_t currentDim = 0; currentDim < selfRank; currentDim++) {
+    if (currentDim == dim) {
+      intermediaryShape.push_back(nWindows);
+      intermediaryShape.push_back(size);
+    } else {
+      intermediaryShape.push_back(pyTorchIndicesShape[currentDim]);
+    }
+  }
+
+  auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
+      op->getLoc(), RankedTensorType::get(intermediaryShape, resultElemTy),
+      gatherNdOp.value(), rewriter.getDenseI64ArrayAttr(intermediaryShape));
+
+  // Permute dims to the correct result order
+  SmallVector<int32_t> permutedDims;
+  for (int64_t currentDim = 0; currentDim < selfRank + 1; currentDim++) {
+    if (currentDim != dim + 1)
+      permutedDims.push_back(static_cast<int32_t>(currentDim));
+  }
+  permutedDims.push_back(static_cast<int32_t>(dim + 1));
+
+  auto permutedDimsConst = tosa::getConstTensor<int32_t>(
+                               rewriter, op,
+                               /*vec=*/permutedDims,
+                               /*shape=*/{static_cast<int32_t>(selfRank + 1)})
+                               .value();
+
+  auto result = rewriter.create<tosa::TransposeOp>(
+      op->getLoc(), resultType, reshapeOp.getResult(), permutedDimsConst);
+
+  rewriter.replaceOp(op, {result.getResult()});
+
+  return success();
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -8617,6 +8809,7 @@ std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
   INSERT_ATENOP_PATTERN(AtenLog1pOp);
   INSERT_ATENOP_PATTERN(AtenLog10Op);
   INSERT_ATENOP_PATTERN(AtenTanOp);
+  INSERT_ATENOP_PATTERN(AtenUnfoldOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index fe3aa3c5dd41..88c57461cd1b 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1696,6 +1696,8 @@
     "Aten_TrilinearModuleSumAllDims_basic",
     "Aten_TrilinearModuleSumdims_basic",
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
+    "CrossEntropyLossModule_basic",
+    "CrossEntropyLossNoReductionModule_basic",
     "ScatterSrcModule_basic",
     "ScatterSrcStaticModule_basic",
     "HBC_basic",
@@ -1704,6 +1706,9 @@
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "Unfold_Module_Rank_4",
+    "Unfold_Module_Rank_Zero_basic",
+    "Unfold_Module_basic",
     "ElementwiseErfIntModule_basic",
     "ElementwiseIntTensorLtFloatScalarModule_basic",
     "ElementwiseSigmoidIntModule_basic",
@@ -3440,6 +3445,8 @@
 }
 
 FX_IMPORTER_TOSA_XFAIL_SET = {
+    "UniformModule_basic",
+    "UniformStaticShapeModule_basic",
     "AtenFftRfft2DLastDim_basic",
     "AtenFftRfft2DMiddleDim_basic",
     "IsInfiniteModule_basic",
@@ -3459,11 +3466,7 @@
     "MaxPool3dModule_basic",
     "MaxPool3dStaticModule_basic",
     "ViewDtypeStaticModule_basic",
-    "Unfold_Module_Dynamic_basic",
-    "Unfold_Module_Rank_4",
     "Unfold_Module_Rank_Zero_Size_Zero_basic",
-    "Unfold_Module_Rank_Zero_basic",
-    "Unfold_Module_basic",
     "ArangeZeroElementOutputModule_basic",
     "NumpyTRank0Module_basic",
     "Permute0RankModule_basic",
@@ -3887,17 +3890,10 @@
     "AdaptiveAvgPool2dDynamic_basic",
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
-    "ElementwiseRreluTrainModule_basic",
-    "ElementwiseRreluTrainStaticModule_basic",
-    "IndexPutImpl1DFloatNonAccumulateModule_basic",
-    "IndexPutImpl1DIntNonAccumulateModule_basic",
-    "IndexPutImpl2DFloatNonAccumulateModule_basic",
-    "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IouOfModule_basic",
     "MeshgridIndexingIJ_basic",
     "MeshgridIndexingXY_basic",
     "Meshgrid_basic",
-    "OneHotModule_basic",
     "ReduceFrobeniusNormKeepDimModule_basic",
     "ReduceFrobeniusNormModule_basic",
     "ScaledDotProductAttentionBoolMaskModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index a8820f59c373..d1ddc42b39b1 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -1752,7 +1752,7 @@ def forward(self, x):
         return x.unfold(0, 0, 1)
 
 
-@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero())
+@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero_Size_Zero())
 def Unfold_Module_Rank_Zero_Size_Zero_basic(module, tu: TestUtils):
     module.forward(tu.rand())
 
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 9e504c082a8c..a3d52166385a 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -2943,3 +2943,53 @@ func.func @torch.aten.pow.Tensor_Tensor$intfloat(%arg0: !torch.vtensor<[3,4,5],s
 }
 
 // -----
+
+// CHECK-LABEL:   func.func @torch.aten.unfold$basic(
+// CHECK-SAME:                                       %[[VAL_0:.*]]: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[6,4],f32> -> tensor<6x4xf32>
+// CHECK:           %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK:           %[[VAL_3:.*]] = torch.constant.int 2
+// CHECK:           %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5]]> : tensor<6x4xi32>}> : () -> tensor<6x4xi32>
+// CHECK:           %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 6, 4, 1>} : (tensor<6x4xi32>) -> tensor<6x4x1xi32>
+// CHECK:           %[[VAL_6:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]]> : tensor<6x4x1xi32>}> : () -> tensor<6x4x1xi32>
+// CHECK:           %[[VAL_7:.*]] = tosa.concat %[[VAL_5]], %[[VAL_6]] {axis = 2 : i32} : (tensor<6x4x1xi32>, tensor<6x4x1xi32>) -> tensor<6x4x2xi32>
+// CHECK:           %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 24, 1>} : (tensor<6x4xf32>) -> tensor<1x24x1xf32>
+// CHECK:           %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 24, 2>} : (tensor<6x4x2xi32>) -> tensor<24x2xi32>
+// CHECK:           %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[4, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK:           %[[VAL_11:.*]] = tosa.mul %[[VAL_9]], %[[VAL_10]] {shift = 0 : i8} : (tensor<24x2xi32>, tensor<2xi32>) -> tensor<24x2xi32>
+// CHECK:           %[[VAL_12:.*]] = tosa.reduce_sum %[[VAL_11]] {axis = 1 : i32} : (tensor<24x2xi32>) -> tensor<24x1xi32>
+// CHECK:           %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 1, 24>} : (tensor<24x1xi32>) -> tensor<1x24xi32>
+// CHECK:           %[[VAL_14:.*]] = tosa.gather %[[VAL_8]], %[[VAL_13]] : (tensor<1x24x1xf32>, tensor<1x24xi32>) -> tensor<1x24x1xf32>
+// CHECK:           %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 6, 4>} : (tensor<1x24x1xf32>) -> tensor<6x4xf32>
+// CHECK:           %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array<i64: 3, 2, 4>} : (tensor<6x4xf32>) -> tensor<3x2x4xf32>
+// CHECK:           %[[VAL_17:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK:           %[[VAL_18:.*]] = tosa.transpose %[[VAL_16]], %[[VAL_17]] : (tensor<3x2x4xf32>, tensor<3xi32>) -> tensor<3x4x2xf32>
+// CHECK:           %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<3x4x2xf32> -> !torch.vtensor<[3,4,2],f32>
+// CHECK:           return %[[VAL_19]] : !torch.vtensor<[3,4,2],f32>
+// CHECK:         }
+func.func @torch.aten.unfold$basic(%arg0: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> {
+  %int0 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.unfold %arg0, %int0, %int2, %int2 : !torch.vtensor<[6,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4,2],f32>
+  return %0 : !torch.vtensor<[3,4,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.unfold$rank_zero(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[],f32> -> tensor<f32>
+// CHECK:           %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK:           %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK:           %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[1],f32> +// CHECK: } +func.func @torch.aten.unfold$rank_zero(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.aten.unfold %arg0, %int0, %int1, %int1 : !torch.vtensor<[],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],f32> + return %0 : !torch.vtensor<[1],f32> +} + +// -----
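
Reviewer note (not part of the patch): the following is a minimal Python sketch of the index-building and gather/reshape/transpose sequence described in the lowering comment above, using the same example (shape (2, 4, 3), dim = 1, size = 3, step = 1). The helper name build_unfold_gather_indices is illustrative only and does not exist in the codebase; torch.gather stands in for the TOSA gather_nd step.

# Sketch of the aten.unfold lowering strategy: build PyTorch-style gather
# indices, gather along `dim`, split `dim` into (nWindows, size), then move
# `size` to the last dimension. Helper names are illustrative only.
import torch


def build_unfold_gather_indices(shape, dim, size, step):
    # nWindows, preDimAccumulate and postDimAccumulate as in the C++ lowering.
    n_windows = (shape[dim] - size) // step + 1
    pre = 1
    for d in shape[:dim]:
        pre *= d
    post = 1
    for d in shape[dim + 1:]:
        post *= d
    base = []
    for window in range(n_windows):
        for element_index in range(size):
            base.extend([element_index + window * step] * post)
    # Base vector repeated once per leading-dimension combination.
    return base * pre, n_windows


indices, n_windows = build_unfold_gather_indices((2, 4, 3), dim=1, size=3, step=1)
# First half matches the pyTorchIndicesBaseVec example in the lowering comment:
# [0, 0, 0, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 3, 3, 3]
print(indices[:18])

x = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
idx = torch.tensor(indices).reshape(2, n_windows * 3, 3)
gathered = torch.gather(x, 1, idx)              # emulates the gather step
windows = gathered.reshape(2, n_windows, 3, 3)  # intermediary reshape
result = windows.permute(0, 1, 3, 2)            # final transpose
assert torch.equal(result, x.unfold(1, 3, 1))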