From b3f7d63fc7488ab8e6bcd843901bede5683955d1 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Tue, 23 Apr 2024 17:51:34 +0000 Subject: [PATCH 01/10] add pixel_unshuffle definition --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 24 +++++++++ .../Transforms/AbstractInterpLibrary.cpp | 54 +++++++++++++++++++ .../build_tools/abstract_interp_lib_gen.py | 16 ++++++ .../build_tools/torch_ods_gen.py | 1 + 4 files changed, 95 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c38d0dbbd389..a5761d740376 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7431,6 +7431,30 @@ def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [ }]; } +def Torch_AtenPixelUnshuffleOp : Torch_Op<"aten.pixel_unshuffle", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::pixel_unshuffle : (Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$downscale_factor + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenPixelUnshuffleOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenPixelUnshuffleOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [ AllowsTypeRefinement, ReadOnly diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index f4415a480a7c..8c4abc7a6302 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7234,6 +7234,56 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %15 = torch.aten.append.t %6, %14 : !torch.list, !torch.int -> !torch.list\n" " return %6 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.pixel_unshuffle\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" +" %int1 = torch.constant.int 1\n" +" %int-3 = torch.constant.int -3\n" +" %str = torch.constant.str \"AssertionError: number of input height must be divisible by downscale_factor in pixel_unshuffle\"\n" +" %int-2 = torch.constant.int -2\n" +" %str_0 = torch.constant.str \"AssertionError: number of input width must be divisible by downscale_factor in pixel_unshuffle\"\n" +" %int-1 = torch.constant.int -1\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: input must be at least rank-3 in pixel_unshuffle\"\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.remainder.int %2, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %4 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %6 = torch.aten.remainder.int %5, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.aten.slice.t %arg0, %int0, %int-3, %int1 : !torch.list, !torch.int, !torch.int, !torch.int -> !torch.list\n" +" %9 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.mul.int %9, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.aten.mul.int %10, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %12 = torch.aten.append.t %8, %11 : !torch.list, !torch.int -> !torch.list\n" +" %13 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.floordiv.int %13, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.aten.append.t %8, %14 : !torch.list, !torch.int -> !torch.list\n" +" %16 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %17 = torch.aten.floordiv.int %16, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.aten.append.t %8, %17 : !torch.list, !torch.int -> !torch.list\n" +" return %8 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.permute\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9898,6 +9948,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pixel_unshuffle\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 06962010fea8..2d732c581983 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -667,6 +667,17 @@ def aten〇pixel_shuffle〡shape(self: List[int], upscale_factor: int) -> List[i out.append(self[-1] * upscale_factor) return out +def aten〇pixel_unshuffle〡shape(self: List[int], downscale_factor: int) -> List[int]: + + assert len(self) >= 3, "input must be at least rank-3 in pixel_unshuffle" + assert self[-1] % downscale_factor == 0, "number of input width must be divisible by downscale_factor in pixel_unshuffle" + assert self[-2] % downscale_factor == 0, "number of input height must be divisible by downscale_factor in pixel_unshuffle" + + out = self[0:-3] + out.append(self[-3] * downscale_factor * downscale_factor) + out.append(self[-2] // downscale_factor) + out.append(self[-1] // downscale_factor) + return out def aten〇permute〡shape(self: List[int], dims: List[int]) -> List[int]: @@ -2143,6 +2154,11 @@ def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_facto self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 12, 12)], downscale_factor = 4)) +def aten〇pixel_unshuffle〡dtype(self_rank_dtype: Tuple[int, int], downscale_factor: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2])) def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index e5b219e55e9c..3a9c9835080a 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -530,6 +530,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)") + emit("aten::pixel_unshuffle : (Tensor, int) -> (Tensor)") emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True) emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)") emit("aten::bmm : (Tensor, Tensor) -> (Tensor)") From e68a3aeac41f2b8e189a61b96b6242691ac30c16 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Wed, 24 Apr 2024 03:01:06 +0000 Subject: [PATCH 02/10] unfinish --- .../Torch/Transforms/DecomposeComplexOps.cpp | 126 ++++++++++++++++++ 1 file changed, 126 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 49dd5319514b..0a8eb82e8e25 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2231,6 +2231,132 @@ class DecomposeAtenPixelShuffleOp }; } // namespace +// Decompose aten.pixel_unshuffle into: prims.split_dim, aten.permute, and +// prims.collapse operations. +// +// If input is a tensor of shape +// (*leading_dims, C, H*r, W*r), +// +// where leading_dims is of size N, then +// X = pixel_unshuffle(input, downscale_factor) +// +// gets replaced with +// X = input.split_dim(...) # shape (*leading_dims, C, H*r, W, r) +// X = X.split_dim(...) # shape (*leading_dims, C, H, r, W, r) +// X = X.permute(0, ..., N, N+2, N+4, N+1, N+3) +// # shape (*leading_dims, C, r, r, H, W) +// X = X.collapse(...) # shape (*leading_dims, C, r*H, r*W) +namespace { +class DecomposeAtenPixelUnshuffleOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenPixelUnshuffleOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + Value inValue = op.getSelf(); + auto inType = inValue.getType().cast(); + auto maybeSizes = inType.getOptionalSizes(); + if (!maybeSizes) { + return rewriter.notifyMatchFailure( + op, "Expected input tensor to have known rank."); + } + auto inShape = maybeSizes.value(); + auto inRank = inShape.size(); + + const auto inOptionalDType = inType.getOptionalDtype(); + + auto getTypeFromShape = [inOptionalDType](auto &&vals) { + // Get a vector of integers from a vector of Values. + auto getIntShape = [](auto &&vals) { + SmallVector shape; + shape.reserve(vals.size()); + for (auto v : vals) { + int64_t cst_val; + if (matchPattern(v, m_TorchConstantInt(&cst_val))) { + shape.push_back(cst_val); + } else { + shape.push_back(kUnknownSize); + } + } + return shape; + }; + + const auto intShape = getIntShape(vals); + return ValueTensorType::get(vals[0].getContext(), + llvm::ArrayRef(intShape), inOptionalDType); + }; + + auto nLeadingDims = inRank - 3; + + // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead + // of 'create': if the dimension size is known, then the AtenSizeIntOp is + // folded to a ConstantOp. + auto getDimSize = [&](uint64_t i) -> Value { + Value dim = + rewriter.create(loc, rewriter.getI64IntegerAttr(i)); + return rewriter.createOrFold(loc, inValue, dim); + }; + + auto inC = getDimSize(inRank - 3); + auto inH = getDimSize(inRank - 2); + auto inW = getDimSize(inRank - 1); + + auto factor = op.getDownscaleFactor(); + + Value factorSquared = + rewriter.createOrFold(loc, factor, factor); + + Value outC = rewriter.createOrFold(loc, inC, factorSquared); + + Value outH = rewriter.createOrFold(loc, inH, factor); + Value outW = rewriter.createOrFold(loc, inW, factor); + + SmallVector dimensionConstants; + dimensionConstants.reserve(inRank + 2); + for (unsigned i = 0; i < inRank + 2; ++i) { + dimensionConstants.push_back( + rewriter.create(loc, rewriter.getI64IntegerAttr(i))); + } + + SmallVector leadingDims; + leadingDims.reserve(nLeadingDims); + for (unsigned i = 0; i < nLeadingDims; ++i) { + Value leadingDimSize = rewriter.createOrFold( + loc, inValue, dimensionConstants[i]); + leadingDims.push_back(leadingDimSize); + } + + SmallVector partiallyExpandedShape = leadingDims; + partiallyExpandedShape.append({outC, inH, factor, inW, factor}); + + SmallVector prePermuteShape = leadingDims; + prePermuteShape.append({outC, inH, factor, inW, factor}); + + SmallVector postPermuteShape = leadingDims; + postPermuteShape.append({outC, factor, factor, inH, inW}); + + SmallVector partiallyCollapsedShape = leadingDims; + partiallyCollapsedShape.append({outC, outH, outW}); + + SmallVector outShape = leadingDims; + outShape.append({outC, outH, outW}); + + SmallVector permutation{dimensionConstants.begin(), + dimensionConstants.begin() + nLeadingDims}; + SmallVector permutationTail{0, 2, 4, 1, 3}; + for (uint64_t d : permutationTail) { + permutation.push_back(dimensionConstants[nLeadingDims + d]); + } + + Value permuteDimsOrder = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), + permutation); + } +}; +} // namespace + // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6) static Value getRelu6Results(PatternRewriter &rewriter, Location loc, Value input) { From 17d90e3d8420fd1576cf8f2acf050fff7a78db26 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Sun, 28 Apr 2024 18:18:19 +0000 Subject: [PATCH 03/10] add pixelunshuffle --- .../Torch/Transforms/DecomposeComplexOps.cpp | 33 ++++++++++++++----- .../Transforms/LowerToBackendContract.cpp | 1 + .../torch_mlir_e2e_test/test_suite/basic.py | 18 ++++++++++ 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 2a98b0f65792..58e1aaf261e1 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2342,14 +2342,14 @@ class DecomposeAtenPixelUnshuffleOp leadingDims.push_back(leadingDimSize); } - SmallVector partiallyExpandedShape = leadingDims; - partiallyExpandedShape.append({outC, inH, factor, inW, factor}); + SmallVector partiallyExpandedShape0 = leadingDims; + partiallyExpandedShape0.append({inC, inH, outW, factor}); - SmallVector prePermuteShape = leadingDims; - prePermuteShape.append({outC, inH, factor, inW, factor}); + SmallVector partiallyExpandedShape1 = leadingDims; + partiallyExpandedShape1.append({inC, outH, factor, outW, factor}); SmallVector postPermuteShape = leadingDims; - postPermuteShape.append({outC, factor, factor, inH, inW}); + postPermuteShape.append({inC, factor, factor, outH, outW}); SmallVector partiallyCollapsedShape = leadingDims; partiallyCollapsedShape.append({outC, outH, outW}); @@ -2364,9 +2364,25 @@ class DecomposeAtenPixelUnshuffleOp permutation.push_back(dimensionConstants[nLeadingDims + d]); } - Value permuteDimsOrder = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), - permutation); + Value partiallyExpanded = rewriter.create( + loc, getTypeFromShape(partiallyExpandedShape0), inValue, + dimensionConstants[nLeadingDims + 2], outW); + + Value fullyExpanded = rewriter.create( + loc, getTypeFromShape(partiallyExpandedShape1), partiallyExpanded, + dimensionConstants[nLeadingDims + 1], outH); + + Value permuted = rewriter.create( + loc, getTypeFromShape(postPermuteShape), fullyExpanded, + rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), + permutation)); + + rewriter.replaceOpWithNewOp( + op, op.getType(), permuted, dimensionConstants[nLeadingDims + 0], + dimensionConstants[nLeadingDims + 2]); + + return success(); } }; } // namespace @@ -7779,6 +7795,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index c5855a1fa092..307436b8450a 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -398,6 +398,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addDynamicallyLegalOp([](AtenMatmulOp op) { diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index b483f9d3c689..0dda298e574a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -986,6 +986,24 @@ def PixelShuffleModuleSpatiallyStatic_basic(module, tu: TestUtils): # ============================================================================== +class PixelUnshuffleModuleStaticRank4Float32(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([3, 6, 4, 4], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.pixel_unshuffle(x, 2) + + +@register_test_case(module_factory=lambda: PixelUnshuffleModuleStaticRank4Float32()) +def PixelUnshuffleModuleStaticRank4Float32_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 6, 4, 4)) + + +# ============================================================================== + + class TensorsConcatModule(torch.nn.Module): def __init__(self): super().__init__() From 67a0b263ce0f5e511e9ab52ac7c68d5ca7bfa5e6 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Mon, 29 Apr 2024 17:17:59 +0000 Subject: [PATCH 04/10] fix --- .../Torch/Transforms/DecomposeComplexOps.cpp | 135 +++++++----------- 1 file changed, 50 insertions(+), 85 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index d1465656d122..7525a89661f5 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2074,6 +2074,36 @@ class DecomposeAtenLinalgCrossOp : public OpRewritePattern { }; } // namespace +namespace { +static ValueTensorType getTypeFromShape(Type dtype, ArrayRef vals) { + // Get a vector of integers from a vector of Values. + auto getIntShape = [](auto &&vals) { + SmallVector shape; + shape.reserve(vals.size()); + for (auto v : vals) { + int64_t cst_val; + if (matchPattern(v, m_TorchConstantInt(&cst_val))) { + shape.push_back(cst_val); + } else { + shape.push_back(kUnknownSize); + } + } + return shape; + }; + auto intShape = getIntShape(vals); + return ValueTensorType::get(vals[0].getContext(), intShape, dtype); +} + +// Get the size of the dimension 'i'. Note the use of 'createOrFold' instead +// of 'create': if the dimension size is known, then the AtenSizeIntOp is +// folded to a ConstantOp. +static Value getDimSize(Value tensor, int64_t dim) { + auto loc = tensor.getLoc(); + auto dimVal = + rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); + return rewriter.create(loc, tensor, dimVal); +} + // Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and // prims.collapse operations. // @@ -2092,7 +2122,6 @@ class DecomposeAtenLinalgCrossOp : public OpRewritePattern { // X = X.collapse(...) # shape (*leading_dims, C, r*H, r*W) // // 'r' above is referred to as the 'upscale factor' or just 'factor' below. -namespace { class DecomposeAtenPixelShuffleOp : public OpRewritePattern { public: @@ -2122,41 +2151,10 @@ class DecomposeAtenPixelShuffleOp const auto inOptionalDType = inType.getOptionalDtype(); - auto getTypeFromShape = [inOptionalDType](auto &&vals) { - // Get a vector of integers from a vector of Values. - auto getIntShape = [](auto &&vals) { - SmallVector shape; - shape.reserve(vals.size()); - for (auto v : vals) { - int64_t cst_val; - if (matchPattern(v, m_TorchConstantInt(&cst_val))) { - shape.push_back(cst_val); - } else { - shape.push_back(kUnknownSize); - } - } - return shape; - }; - - const auto intShape = getIntShape(vals); - return ValueTensorType::get(vals[0].getContext(), - llvm::ArrayRef(intShape), inOptionalDType); - }; - auto nLeadingDims = inRank - 3; - - // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead - // of 'create': if the dimension size is known, then the AtenSizeIntOp is - // folded to a ConstantOp. - auto getDimSize = [&](uint64_t i) -> Value { - Value dim = - rewriter.create(loc, rewriter.getI64IntegerAttr(i)); - return rewriter.createOrFold(loc, inValue, dim); - }; - - auto inC = getDimSize(inRank - 3); - auto inH = getDimSize(inRank - 2); - auto inW = getDimSize(inRank - 1); + auto inC = getDimSize(inValue, inRank - 3); + auto inH = getDimSize(inValue, inRank - 2); + auto inW = getDimSize(inValue, inRank - 1); auto factor = op.getUpscaleFactor(); @@ -2214,24 +2212,24 @@ class DecomposeAtenPixelShuffleOp auto partiallyExpanded = rewriter .create( - loc, getTypeFromShape(partiallyExpandedShape), inValue, - dimensionConstants[nLeadingDims], outC) + loc, getTypeFromShape(inOptionalDType, partiallyExpandedShape), + inValue, dimensionConstants[nLeadingDims], outC) .getResult(); // Split new dimension factorSquared -> (factor, factor) auto fullyExpanded = rewriter.create( - loc, getTypeFromShape(prePermuteShape), partiallyExpanded, - dimensionConstants[nLeadingDims + 1], factor); + loc, getTypeFromShape(inOptionalDType, prePermuteShape), + partiallyExpanded, dimensionConstants[nLeadingDims + 1], factor); // Perform the permutation - auto permuted = - rewriter.create(loc, getTypeFromShape(postPermuteShape), - fullyExpanded, permuteDimsOrder); + auto permuted = rewriter.create( + loc, getTypeFromShape(inOptionalDType, postPermuteShape), fullyExpanded, + permuteDimsOrder); // Collapse final 2 dimension auto partiallyCollapsed = rewriter.create( - loc, getTypeFromShape(partiallyCollapsedShape), permuted, - dimensionConstants[nLeadingDims + 3], + loc, getTypeFromShape(inOptionalDType, partiallyCollapsedShape), + permuted, dimensionConstants[nLeadingDims + 3], dimensionConstants[nLeadingDims + 4]); // Collapse back to original rank @@ -2243,7 +2241,6 @@ class DecomposeAtenPixelShuffleOp return success(); } }; -} // namespace // Decompose aten.pixel_unshuffle into: prims.split_dim, aten.permute, and // prims.collapse operations. @@ -2260,7 +2257,6 @@ class DecomposeAtenPixelShuffleOp // X = X.permute(0, ..., N, N+2, N+4, N+1, N+3) // # shape (*leading_dims, C, r, r, H, W) // X = X.collapse(...) # shape (*leading_dims, C, r*H, r*W) -namespace { class DecomposeAtenPixelUnshuffleOp : public OpRewritePattern { public: @@ -2281,41 +2277,10 @@ class DecomposeAtenPixelUnshuffleOp const auto inOptionalDType = inType.getOptionalDtype(); - auto getTypeFromShape = [inOptionalDType](auto &&vals) { - // Get a vector of integers from a vector of Values. - auto getIntShape = [](auto &&vals) { - SmallVector shape; - shape.reserve(vals.size()); - for (auto v : vals) { - int64_t cst_val; - if (matchPattern(v, m_TorchConstantInt(&cst_val))) { - shape.push_back(cst_val); - } else { - shape.push_back(kUnknownSize); - } - } - return shape; - }; - - const auto intShape = getIntShape(vals); - return ValueTensorType::get(vals[0].getContext(), - llvm::ArrayRef(intShape), inOptionalDType); - }; - auto nLeadingDims = inRank - 3; - - // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead - // of 'create': if the dimension size is known, then the AtenSizeIntOp is - // folded to a ConstantOp. - auto getDimSize = [&](uint64_t i) -> Value { - Value dim = - rewriter.create(loc, rewriter.getI64IntegerAttr(i)); - return rewriter.createOrFold(loc, inValue, dim); - }; - - auto inC = getDimSize(inRank - 3); - auto inH = getDimSize(inRank - 2); - auto inW = getDimSize(inRank - 1); + auto inC = getDimSize(inValue, inRank - 3); + auto inH = getDimSize(inValue, inRank - 2); + auto inW = getDimSize(inValue, inRank - 1); auto factor = op.getDownscaleFactor(); @@ -2365,15 +2330,15 @@ class DecomposeAtenPixelUnshuffleOp } Value partiallyExpanded = rewriter.create( - loc, getTypeFromShape(partiallyExpandedShape0), inValue, - dimensionConstants[nLeadingDims + 2], outW); + loc, getTypeFromShape(inOptionalDType, partiallyExpandedShape0), + inValue, dimensionConstants[nLeadingDims + 2], outW); Value fullyExpanded = rewriter.create( - loc, getTypeFromShape(partiallyExpandedShape1), partiallyExpanded, - dimensionConstants[nLeadingDims + 1], outH); + loc, getTypeFromShape(inOptionalDType, partiallyExpandedShape1), + partiallyExpanded, dimensionConstants[nLeadingDims + 1], outH); Value permuted = rewriter.create( - loc, getTypeFromShape(postPermuteShape), fullyExpanded, + loc, getTypeFromShape(inOptionalDType, postPermuteShape), fullyExpanded, rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), permutation)); From 2ef8e1e252ae29b3de979a4f24a539c9154beb24 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Mon, 29 Apr 2024 17:19:40 +0000 Subject: [PATCH 05/10] fix --- .../Torch/Transforms/DecomposeComplexOps.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 7525a89661f5..ae4d3a3dec4b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2097,7 +2097,7 @@ static ValueTensorType getTypeFromShape(Type dtype, ArrayRef vals) { // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead // of 'create': if the dimension size is known, then the AtenSizeIntOp is // folded to a ConstantOp. -static Value getDimSize(Value tensor, int64_t dim) { +static Value getDimSize(PatternRewriter &rewriter, Value tensor, int64_t dim) { auto loc = tensor.getLoc(); auto dimVal = rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); @@ -2152,9 +2152,9 @@ class DecomposeAtenPixelShuffleOp const auto inOptionalDType = inType.getOptionalDtype(); auto nLeadingDims = inRank - 3; - auto inC = getDimSize(inValue, inRank - 3); - auto inH = getDimSize(inValue, inRank - 2); - auto inW = getDimSize(inValue, inRank - 1); + auto inC = getDimSize(rewriter, inValue, inRank - 3); + auto inH = getDimSize(rewriter, inValue, inRank - 2); + auto inW = getDimSize(rewriter, inValue, inRank - 1); auto factor = op.getUpscaleFactor(); @@ -2278,9 +2278,9 @@ class DecomposeAtenPixelUnshuffleOp const auto inOptionalDType = inType.getOptionalDtype(); auto nLeadingDims = inRank - 3; - auto inC = getDimSize(inValue, inRank - 3); - auto inH = getDimSize(inValue, inRank - 2); - auto inW = getDimSize(inValue, inRank - 1); + auto inC = getDimSize(rewriter, inValue, inRank - 3); + auto inH = getDimSize(rewriter, inValue, inRank - 2); + auto inW = getDimSize(rewriter, inValue, inRank - 1); auto factor = op.getDownscaleFactor(); From 76c0f5b386a56f514a3134b9c6e0dd6ba24d60d5 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Mon, 29 Apr 2024 17:22:08 +0000 Subject: [PATCH 06/10] format --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index ae4d3a3dec4b..7aefb2bee5d5 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2266,7 +2266,7 @@ class DecomposeAtenPixelUnshuffleOp Location loc = op.getLoc(); Value inValue = op.getSelf(); - auto inType = inValue.getType().cast(); + auto inType = cast(inValue.getType()); auto maybeSizes = inType.getOptionalSizes(); if (!maybeSizes) { return rewriter.notifyMatchFailure( From 228f25e1792c00cf6c03a137558780997e1793f6 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Mon, 29 Apr 2024 17:34:15 +0000 Subject: [PATCH 07/10] fix nug --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 7aefb2bee5d5..5ab3f29f02b1 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2101,7 +2101,7 @@ static Value getDimSize(PatternRewriter &rewriter, Value tensor, int64_t dim) { auto loc = tensor.getLoc(); auto dimVal = rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); - return rewriter.create(loc, tensor, dimVal); + return rewriter.createOrFold(loc, tensor, dimVal); } // Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and From 69c7202d80647a3ab9e26185c1b68495639d26ec Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Mon, 29 Apr 2024 17:48:01 +0000 Subject: [PATCH 08/10] add test --- projects/pt1/e2e_testing/xfail_sets.py | 2 + .../torch_mlir_e2e_test/test_suite/basic.py | 54 +++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 10c24b657128..f99f13994836 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1154,6 +1154,8 @@ "PermuteNegativeIndexModule_basic", "PixelShuffleModuleStaticRank3Int64_basic", "PixelShuffleModuleStaticRank4Float32_basic", + "PixelUnshuffleModuleStaticRank3Int64_basic", + "PixelUnshuffleModuleStaticRank4Float32_basic", "PowIntFloatModule_basic", "PrimListUnpackNumMismatchModule_basic", "PrimMaxIntModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 0dda298e574a..dc0ffcae38c4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -986,6 +986,24 @@ def PixelShuffleModuleSpatiallyStatic_basic(module, tu: TestUtils): # ============================================================================== +class PixelUnshuffleModuleStaticRank3int64(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([12, 2, 2], torch.int64, True)]) + def forward(self, x): + return torch.ops.aten.pixel_unshuffle(x, 2) + + +@register_test_case(module_factory=lambda: PixelUnshuffleModuleStaticRank3int64()) +def PixelUnshuffleModuleStaticRank3int64_basic(module, tu: TestUtils): + module.forward(tu.randint(12, 2, 2, low=0, high=100)) + + +# ============================================================================== + + class PixelUnshuffleModuleStaticRank4Float32(torch.nn.Module): def __init__(self): super().__init__() @@ -1004,6 +1022,42 @@ def PixelUnshuffleModuleStaticRank4Float32_basic(module, tu: TestUtils): # ============================================================================== +class PixelUnshuffleModuleFullDynamic(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1, -1, -1], torch.int64, True)]) + def forward(self, x): + return torch.ops.aten.pixel_unshuffle(x, 2) + + +@register_test_case(module_factory=lambda: PixelUnshuffleModuleFullDynamic()) +def PixelUnshuffleModuleFullDynamic_basic(module, tu: TestUtils): + module.forward(tu.randint(1, 8, 4, 4, low=0, high=100)) + + +# ============================================================================== + + +class PixelUnshuffleModuleSpatiallyDynamic(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([2, 1, -1, -1], torch.int64, True)]) + def forward(self, x): + return torch.ops.aten.pixel_unshuffle(x, 2) + + +@register_test_case(module_factory=lambda: PixelUnshuffleModuleSpatiallyDynamic()) +def PixelUnshuffleModuleSpatiallyDynamic_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 1, 8, 8, low=0, high=100)) + + +# ============================================================================== + + class TensorsConcatModule(torch.nn.Module): def __init__(self): super().__init__() From 4754815132685d3ef5100b8fb9d4d3fc8f858c83 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Mon, 29 Apr 2024 17:55:54 +0000 Subject: [PATCH 09/10] fool bug --- projects/pt1/e2e_testing/xfail_sets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index f99f13994836..681e5be1d555 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1154,7 +1154,7 @@ "PermuteNegativeIndexModule_basic", "PixelShuffleModuleStaticRank3Int64_basic", "PixelShuffleModuleStaticRank4Float32_basic", - "PixelUnshuffleModuleStaticRank3Int64_basic", + "PixelUnshuffleModuleStaticRank3int64_basic", "PixelUnshuffleModuleStaticRank4Float32_basic", "PowIntFloatModule_basic", "PrimListUnpackNumMismatchModule_basic", From a27e990d91858aa5944363282d6574e7e77b6240 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Tue, 30 Apr 2024 06:37:01 +0000 Subject: [PATCH 10/10] fail onnx --- projects/pt1/e2e_testing/xfail_sets.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 681e5be1d555..5adb393d2cc1 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2653,6 +2653,9 @@ "MaskedFillTensorFloatValueModule_basic", "NativeDropoutTrainModule_basic", "NativeDropoutTrainStaticShapeModule_basic", + "PixelUnshuffleModuleFullDynamic_basic", + "PixelUnshuffleModuleSpatiallyDynamic_basic", + "PixelUnshuffleModuleStaticRank3int64_basic", "ReduceAnyFloatModule_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic",