[Torch] Emit ChannelShuffleOp and decompose it. #3268

Open · wants to merge 5 commits into base: main
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7478,6 +7478,30 @@ def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [
}];
}

def Torch_AtenChannelShuffleOp : Torch_Op<"aten.channel_shuffle", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::channel_shuffle : (Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$groups
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenChannelShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenChannelShuffleOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
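// Example textual IR for aten.channel_shuffle (illustrative only; the op uses
// the default torch-mlir assembly format, so it is not copied from the patch):
//   %out = torch.aten.channel_shuffle %self, %groups :
//       !torch.vtensor<[2,6,3,4],f32>, !torch.int -> !torch.vtensor<[2,6,3,4],f32>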

def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
AllowsTypeRefinement,
HasValueSemantics,
30 changes: 30 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7202,6 +7202,32 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %1, %arg2, %2) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %3 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.channel_shuffle\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: number of input channels must be divisible by groups in channel_shuffle\"\n"
" %int-3 = torch.constant.int -3\n"
" %none = torch.constant.none\n"
" %str_0 = torch.constant.str \"AssertionError: input must be at least rank-3 in channel_shuffle\"\n"
" %int3 = torch.constant.int 3\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %3 = torch.aten.remainder.int %2, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %4 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.pixel_shuffle\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %int-1 = torch.constant.int -1\n"
" %int-2 = torch.constant.int -2\n"
@@ -9917,6 +9943,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.channel_shuffle\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.pixel_shuffle\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
137 changes: 137 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2245,6 +2245,142 @@ class DecomposeAtenPixelShuffleOp
};
} // namespace

// Decompose aten.channel_shuffle into: prims.split_dim, aten.permute, and
// prims.collapse operations.
//
// If the input is a tensor of shape
//   (*leading_dims, g * n, H, W),
// where leading_dims is of size N, then
//   X = channel_shuffle(input, g)
// gets replaced with
//   X = input.split_dim(...)  # shape (*leading_dims, g, n, H, W)
//   X = X.permute(0, ..., N+1, N, N+2, N+3)
//                             # shape (*leading_dims, n, g, H, W)
//   X = X.collapse(...)       # shape (*leading_dims, g * n, H, W)
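// Worked example (illustrative, not taken from the patch): with g = 2 and an
// input of shape (1, 6, 4, 4), so n = 3, the intermediate shapes are
//   split_dim -> (1, 2, 3, 4, 4)
//   permute   -> (1, 3, 2, 4, 4)
//   collapse  -> (1, 6, 4, 4)
// and the channels end up in the order [0, 3, 1, 4, 2, 5].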
namespace {
class DecomposeAtenChannelShuffleOp
: public OpRewritePattern<AtenChannelShuffleOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenChannelShuffleOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getSelf();
Value groups = op.getGroups();

auto inType = cast<BaseTensorType>(input.getType());
auto maybeSizes = inType.getOptionalSizes();
if (!maybeSizes) {
return rewriter.notifyMatchFailure(
op, "Expected input tensor to have known rank.");
}
auto inShape = maybeSizes.value();
auto inRank = inShape.size();

if (inRank < 3)
return rewriter.notifyMatchFailure(
op, "Expected input tensor to have rank greater than 2.");

const auto inOptionalDType = inType.getOptionalDtype();

auto getTypeFromShape = [inOptionalDType](auto &&vals) {
// Get a vector of integers from a vector of Values.
auto getIntShape = [](auto &&vals) {
SmallVector<int64_t> shape;
shape.reserve(vals.size());
for (auto v : vals) {
int64_t cst_val;
if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
shape.push_back(cst_val);
} else {
shape.push_back(kUnknownSize);
}
}
return shape;
};

const auto intShape = getIntShape(vals);
return ValueTensorType::get(vals[0].getContext(),
llvm::ArrayRef(intShape), inOptionalDType);
};

auto nLeadingDims = inRank - 3;
// Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
// of 'create': if the dimension size is known, then the AtenSizeIntOp is
// folded to a ConstantOp.
auto getDimSize = [&](uint64_t i) -> Value {
Value dim =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
return rewriter.createOrFold<AtenSizeIntOp>(loc, input, dim);
};

auto inC = getDimSize(inRank - 3);
auto inH = getDimSize(inRank - 2);
auto inW = getDimSize(inRank - 1);

auto groupC = rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, groups);

auto outputH = inH;
auto outputW = inW;

SmallVector<Value> dimensionConstants;
dimensionConstants.reserve(inRank + 1);

for (unsigned i = 0; i < inRank + 1; ++i) {
dimensionConstants.push_back(
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
}

SmallVector<Value> leadingDims;
leadingDims.reserve(nLeadingDims);

for (unsigned i = 0; i < nLeadingDims; ++i) {
Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
loc, input, dimensionConstants[i]);
leadingDims.push_back(leadingDimSize);
}

SmallVector<Value> expandShape = leadingDims;
expandShape.append({groups, groupC, outputH, outputW});

auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
loc, getTypeFromShape(expandShape), input,
dimensionConstants[nLeadingDims], groups);

SmallVector<Value> permutation{dimensionConstants.begin(),
dimensionConstants.begin() + nLeadingDims};

SmallVector<uint64_t> permutationTail{1, 0, 2, 3};

for (uint64_t d : permutationTail) {
permutation.push_back(dimensionConstants[nLeadingDims + d]);
}

Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
permutation);

SmallVector<Value> postPermuteShape = leadingDims;
postPermuteShape.append({groupC, groups, outputH, outputW});

auto permuted =
rewriter.create<AtenPermuteOp>(loc, getTypeFromShape(postPermuteShape),
fullyExpanded, permuteDimsOrder);

// Collapse the adjacent (n, g) dims back into a single channel dimension of
// size g * n.

rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
op, op.getType(), permuted, dimensionConstants[nLeadingDims + 0],
dimensionConstants[nLeadingDims + 1]);

return success();
}
};
} // namespace

// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
Value input) {
@@ -7696,6 +7832,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenChannelShuffleOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -397,6 +397,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenSelectIntOp>();
target.addIllegalOp<AtenMvOp>();
target.addIllegalOp<AtenLinalgCrossOp>();
target.addIllegalOp<AtenChannelShuffleOp>();
target.addIllegalOp<AtenPixelShuffleOp>();
target.addIllegalOp<AtenTOp>();
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -818,6 +818,7 @@
}

STABLEHLO_PASS_SET = {
"ChannelShuffleModuleStatic_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
@@ -661,6 +661,15 @@ def aten〇sum〇dim_IntList〡shape(self: List[int], dim: Optional[List[int]],
def aten〇prod〇dim_int〡shape(self: List[int], dim: int, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, [dim], keepdim, dtype)

def aten〇channel_shuffle〡shape(self: List[int], groups: int) -> List[int]:

assert len(self) >= 3, "input must be at least rank-3 in channel_shuffle"

num_channels = self[-3]
assert num_channels % groups == 0, "number of input channels must be divisible by groups in channel_shuffle"
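# For example, a [2, 6, 3, 4] input with groups=2 passes both assertions and
# is returned unchanged; channel_shuffle never changes the shape.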
return self


def aten〇pixel_shuffle〡shape(self: List[int], upscale_factor: int) -> List[int]:

assert len(self) >= 3, "input must be at least rank-3 in pixel_shuffle"
@@ -2150,6 +2159,11 @@ def aten〇adaptive_avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], output_
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 6, 3, 4)], groups=2))
def aten〇channel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], groups: int) -> int:
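# channel_shuffle only reorders elements, so the result dtype is the input dtype.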
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4, 1, 1)], upscale_factor = 2))
def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_factor: int) -> int:
self_rank, self_dtype = self_rank_dtype
@@ -642,6 +642,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::adaptive_max_pool3d : (Tensor, int[]) -> (Tensor, Tensor)")
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
emit("aten::channel_shuffle : (Tensor, int) -> (Tensor)")
emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)")
emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True)
emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
18 changes: 18 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -986,6 +986,24 @@ def PixelShuffleModuleSpatiallyStatic_basic(module, tu: TestUtils):
# ==============================================================================


class ChannelShuffleModuleStatic(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([2, 6, 3, 4], torch.float32, True)])
def forward(self, x):
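# groups=2 on 6 channels reorders them as [0, 3, 1, 4, 2, 5]; the output
# shape matches the input shape.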
return torch.ops.aten.channel_shuffle(x, 2)


@register_test_case(module_factory=lambda: ChannelShuffleModuleStatic())
def ChannelShuffleModuleStatic_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 6, 3, 4))


# ==============================================================================


class TensorsConcatModule(torch.nn.Module):
def __init__(self):
super().__init__()