Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][TORCH] Add support for 1-d group convolution #3904

Merged
merged 4 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/torch-mlir/Conversion/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt,
Value defaultValue, Value dimSize);

// Helper function to unsqueeze the input tensor at given dim.
// Returns the unsqueezed tensor or failure.
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value input, int64_t dim);

// Helper function to squeeze the input tensor at given dim.
// Returns the squeezed tensor or failure.
FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
Value input, int64_t dim);
} // namespace Torch
} // namespace torch
} // namespace mlir
Expand Down
98 changes: 13 additions & 85 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1642,69 +1642,18 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Value input = adaptor.getSelf();
auto inputType = cast<RankedTensorType>(input.getType());
int64_t inputRank = inputType.getRank();

if (inputRank == 0) {
return rewriter.notifyMatchFailure(
op, "zero input rank should have been handled by the folder");
}

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "dim must be constant");
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

// assert dynamic squeeze dim size == 1
if (inputType.isDynamicDim(dim)) {
Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
Value cmp = rewriter.create<arith::CmpIOp>(
op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
rewriter.create<cf::AssertOp>(
op.getLoc(), cmp,
rewriter.getStringAttr(
"Expected dynamic squeeze dim size to be statically 1"));
}

const TypeConverter *typeConverter = getTypeConverter();
auto resultType =
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
int64_t resultRank = resultType.getRank();

// If the dim(th) dimension of operand tensor type is not statically unit,
// `aten.squeeze` will behave as an identity operation.
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
return success();
auto squeezeTensorInfo =
squeezeTensor(rewriter, op, adaptor.getSelf(), dim);
if (failed(squeezeTensorInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}

SmallVector<ReassociationIndices> reassociationMap(resultRank);
bool alreadyCrossedSqueezedDim = false;
for (int i = 0; i != resultRank; i++) {
if (alreadyCrossedSqueezedDim) {
reassociationMap[i].push_back(i + 1);
} else {
reassociationMap[i].push_back(i);
if (dim != 0 && i != dim - 1)
continue;

alreadyCrossedSqueezedDim = true;
if (dim == 0)
reassociationMap[0].push_back(1);
if (i == dim - 1)
reassociationMap[i].push_back(dim);
}
}
// Note: In case the operand tensor type is of unit rank and is statically
// shaped with unit dimension, the `reassociationMap` will be empty and the
// input will be collapsed to a 0-D tensor.
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(op, resultType, input,
reassociationMap);
rewriter.replaceOp(op, squeezeTensorInfo.value());
return success();
}
};
Expand All @@ -1722,36 +1671,15 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "dim must be constant");
auto inputRank =
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
dim = toPositiveDim(dim, inputRank + 1);
if (!isValidDim(dim, inputRank + 1))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

SmallVector<ReassociationIndices> reassociationMap(inputRank);
// From the perspective of the reassociation map, the situation of
// unsqueezing before or after the last dimension is symmetrical.
// Normalize it to the "before" case.
// The 0 case is special here, since there is no last dimension to insert
// before -- we simply rely on the loop below iterating 0 times.
if (dim == inputRank && inputRank != 0)
dim = inputRank - 1;
bool alreadyCrossedExpandedDim = false;
for (int i = 0; i != inputRank; i++) {
if (alreadyCrossedExpandedDim) {
reassociationMap[i].push_back(i + 1);
} else {
reassociationMap[i].push_back(i);
if (i == dim) {
reassociationMap[i].push_back(i + 1);
alreadyCrossedExpandedDim = true;
}
}
auto unsqueezeTensorInfo =
unsqueezeTensor(rewriter, op, adaptor.getSelf(), dim);
if (failed(unsqueezeTensorInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
op, resultType, adaptor.getSelf(), reassociationMap);

rewriter.replaceOp(op, unsqueezeTensorInfo.value());
return success();
}
};
Expand Down
72 changes: 64 additions & 8 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,48 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");

// Checks for valid group size
int64_t numGroups;
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
return rewriter.notifyMatchFailure(op,
"only constant group size supported.");
Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());

// Adding support for 1d group convolution by converting the 1d-conv to
// 2d-conv.
vivekkhandelwal1 marked this conversation as resolved.
Show resolved Hide resolved
// TODO: Replace this logic with the appropriate linalg op for 1-d group
// convolution once that support is added.
bool is1DGroupConv = (numSpatialDims == 1 && numGroups != 1);
if (is1DGroupConv) {
// Unsqueezing the last dim of input and weight. Also extending the
// dilation, stride, padding, and output padding lists.
auto unsqueezeInputInfo =
unsqueezeTensor(rewriter, op, input, /*dim=*/-1);
if (failed(unsqueezeInputInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}
input = unsqueezeInputInfo.value();

auto unsqueezeWeightInfo =
unsqueezeTensor(rewriter, op, weight, /*dim=*/-1);
if (failed(unsqueezeWeightInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}
weight = unsqueezeWeightInfo.value();

Value cstZero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(0));
paddingIntValues.push_back(cstZero);
outputPaddingIntValues.push_back(cstZero);
strideInts.push_back(1);
dilationInts.push_back(1);

inRank++;
numSpatialDims++;
}

Value inBatch = getDimOp(rewriter, loc, input, 0);
Value inChannels = getDimOp(rewriter, loc, input, 1);
SmallVector<Value> inDims;
Expand All @@ -861,13 +903,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
for (size_t i = 2; i < inRank; i++)
weightDims.push_back(getDimOp(rewriter, loc, weight, i));

// Checks for valid group size
int64_t numGroups;
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
return rewriter.notifyMatchFailure(op,
"only constant group size supported.");
Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());

auto validate = [&](Value toValidate, std::string err) {
Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Expand Down Expand Up @@ -1280,13 +1315,24 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}

if (is1DGroupConv) {
// Squeezing the last dim of the result of conv.
auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1);
if (failed(squeezeOutputInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate squeeze tensor");
}
conv = squeezeOutputInfo.value();
}

rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}

if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");
op, "unimplemented: only 1D and 2D grouped convolution supported");

// Grouped case, use the grouped conv linalg op
auto expandGroups = [&](Value tensor, size_t dim) {
Expand Down Expand Up @@ -1371,6 +1417,16 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}

if (is1DGroupConv) {
// Squeezing the last dim of the result of conv.
auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1);
if (failed(squeezeOutputInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate squeeze tensor");
}
conv = squeezeOutputInfo.value();
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}
Expand Down
113 changes: 113 additions & 0 deletions lib/Conversion/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,119 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
return castIntToIndex(rewriter, loc, boundedByDimSize);
}

// Helper function to unsqueeze the input tensor at given dim.
// Returns the unsqueezed tensor or failure.
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value input, int64_t dim) {
auto inputType = cast<RankedTensorType>(input.getType());
int64_t inputRank = inputType.getRank();
ArrayRef<int64_t> inputShape = inputType.getShape();

// `input` has a reduced rank. Hence add 1.
int64_t unsqueezedRank = inputShape.size() + 1;
dim = toPositiveDim(dim, unsqueezedRank);
if (!isValidDim(dim, unsqueezedRank)) {
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
}

SmallVector<int64_t> unsqueezedShape{inputShape};
unsqueezedShape.insert(unsqueezedShape.begin() + dim, 1);
Type unsqueezedType =
RankedTensorType::get(unsqueezedShape, inputType.getElementType());

SmallVector<ReassociationIndices> reassociationMap(inputRank);
// From the perspective of the reassociation map, the situation of
// unsqueezing before or after the last dimension is symmetrical.
// Normalize it to the "before" case.
// The 0 case is special here, since there is no last dimension to insert
// before -- we simply rely on the loop below iterating 0 times.
if (dim == inputRank && inputRank != 0)
dim = inputRank - 1;
bool alreadyCrossedExpandedDim = false;
for (int i = 0; i != inputRank; i++) {
if (alreadyCrossedExpandedDim) {
reassociationMap[i].push_back(i + 1);
} else {
reassociationMap[i].push_back(i);
if (i == dim) {
reassociationMap[i].push_back(i + 1);
alreadyCrossedExpandedDim = true;
}
}
}
Value unsqueezed = rewriter.create<tensor::ExpandShapeOp>(
op->getLoc(), unsqueezedType, input, reassociationMap);
return unsqueezed;
}

// Helper function to squeeze the input tensor at given dim.
// Returns the squeezed tensor or failure.
FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
Value input, int64_t dim) {
Location loc = op->getLoc();
auto inputType = cast<RankedTensorType>(input.getType());
int64_t inputRank = inputType.getRank();

// No scope for squeezing the input.
if (inputRank == 0)
return input;

dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

// assert dynamic squeeze dim size == 1
if (inputType.isDynamicDim(dim)) {
Value cstDim = rewriter.create<arith::ConstantIndexOp>(loc, dim);
Value dimVal = rewriter.create<tensor::DimOp>(loc, input, cstDim);
Value cstOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
dimVal, cstOne);
rewriter.create<cf::AssertOp>(
loc, cmp,
rewriter.getStringAttr(
"Expected dynamic squeeze dim size to be statically 1"));
}

ArrayRef<int64_t> inputShape = inputType.getShape();
SmallVector<int64_t> squeezedShape;
squeezedShape.append(inputShape.begin(), inputShape.begin() + dim);
squeezedShape.append(inputShape.begin() + dim + 1, inputShape.end());
int64_t squeezedRank = inputRank - 1;
Type squeezedType =
RankedTensorType::get(squeezedShape, inputType.getElementType());

// If the dim(th) dimension of operand tensor type is not statically unit,
// squeeze will behave as an identity operation.
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
return input;
}

SmallVector<ReassociationIndices> reassociationMap(squeezedRank);
bool alreadyCrossedSqueezedDim = false;
for (int i = 0; i != squeezedRank; i++) {
if (alreadyCrossedSqueezedDim) {
reassociationMap[i].push_back(i + 1);
} else {
reassociationMap[i].push_back(i);
if (dim != 0 && i != dim - 1)
continue;

alreadyCrossedSqueezedDim = true;
if (dim == 0)
reassociationMap[0].push_back(1);
if (i == dim - 1)
reassociationMap[i].push_back(dim);
}
}
// Note: In case the operand tensor type is of unit rank and is statically
// shaped with unit dimension, the `reassociationMap` will be empty and the
// input will be collapsed to a 0-D tensor.
Value squeezed = rewriter.create<tensor::CollapseShapeOp>(
op->getLoc(), squeezedType, input, reassociationMap);
return squeezed;
}

} // namespace Torch
} // namespace torch
} // namespace mlir
4 changes: 4 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2731,6 +2731,7 @@
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"Conv1dGroupModule_basic",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
Expand Down Expand Up @@ -2886,6 +2887,7 @@
"Conv1dModule_basic",
"Conv1dWithSamePaddingModule_basic",
"Conv1dWithValidPaddingModule_basic",
"Conv1dGroupModule_basic",
"Conv2dBiasNoPaddingModule_basic",
"Conv2dModule_basic",
"Conv2dNoPaddingModule_basic",
Expand Down Expand Up @@ -3593,6 +3595,7 @@
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv1dWithSamePaddingModule_basic",
"Conv1dWithValidPaddingModule_basic",
"Conv1dGroupModule_basic",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
Expand Down Expand Up @@ -4186,6 +4189,7 @@
"Conv1dWithSamePaddingModule_basic",
"Conv1dWithValidPaddingModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv1dGroupModule_basic",
"Conv2dBiasNoPaddingModule_basic",
"Conv2dModule_basic",
"Conv2dNoPaddingModule_basic",
Expand Down
Loading
Loading