From fd771f0c99681b4af34a44cfcdb7c262c462bae7 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 5 Dec 2024 10:40:42 +0000 Subject: [PATCH] [MLIR][TORCH] Add support for 1-d group convolution This commit adds the support for 1-d group convolution by transforming it into a 2-d group convolution which is already supported. This commit also refactors the unsqueeze and squeeze tensor utility. Signed-off-by: Vivek Khandelwal --- include/torch-mlir/Conversion/Utils/Utils.h | 9 ++ lib/Conversion/TorchToLinalg/DataMovement.cpp | 98 ++------------- lib/Conversion/TorchToLinalg/Linear.cpp | 59 +++++++-- lib/Conversion/Utils/Utils.cpp | 114 ++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 4 + .../torch_mlir_e2e_test/test_suite/conv.py | 27 +++++ 6 files changed, 218 insertions(+), 93 deletions(-) diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index d21dd5504dcd..264fb4966d39 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -97,6 +97,15 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, Value torchOptionalInt, Value builtinInt, Value defaultValue, Value dimSize); +// Helper function to unsqueeze the input tensor at given dim. +// Returns the unsqueezed tensor or failure. +FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim); + +// Helper function to squeeze the input tensor at given dim. +// Returns the squeezed tensor or failure. 
+FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim); } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index a18c0bae01fc..b8c20bc73f65 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1642,69 +1642,18 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - Value input = adaptor.getSelf(); - auto inputType = cast(input.getType()); - int64_t inputRank = inputType.getRank(); - - if (inputRank == 0) { - return rewriter.notifyMatchFailure( - op, "zero input rank should have been handled by the folder"); - } - int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); - dim = toPositiveDim(dim, inputRank); - if (!isValidDim(dim, inputRank)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - - // assert dynamic squeeze dim size == 1 - if (inputType.isDynamicDim(dim)) { - Value cstDim = rewriter.create(op.getLoc(), dim); - Value dimVal = rewriter.create(op.getLoc(), input, cstDim); - Value cstOne = rewriter.create(op.getLoc(), 1); - Value cmp = rewriter.create( - op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne); - rewriter.create( - op.getLoc(), cmp, - rewriter.getStringAttr( - "Expected dynamic squeeze dim size to be statically 1")); - } - - const TypeConverter *typeConverter = getTypeConverter(); - auto resultType = - cast(typeConverter->convertType(op.getType())); - int64_t resultRank = resultType.getRank(); - // If the dim(th) dimension of operand tensor type is not statically unit, - // `aten.squeeze` will behave as an identity operation. 
- if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) { - rewriter.replaceOpWithNewOp(op, resultType, input); - return success(); + auto squeezeTensorInfo = + squeezeTensor(rewriter, op, adaptor.getSelf(), dim); + if (failed(squeezeTensorInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate squeeze tensor"); } - SmallVector reassociationMap(resultRank); - bool alreadyCrossedSqueezedDim = false; - for (int i = 0; i != resultRank; i++) { - if (alreadyCrossedSqueezedDim) { - reassociationMap[i].push_back(i + 1); - } else { - reassociationMap[i].push_back(i); - if (dim != 0 && i != dim - 1) - continue; - - alreadyCrossedSqueezedDim = true; - if (dim == 0) - reassociationMap[0].push_back(1); - if (i == dim - 1) - reassociationMap[i].push_back(dim); - } - } - // Note: In case the operand tensor type is of unit rank and is statically - // shaped with unit dimension, the `reassociationMap` will be empty and the - // input will be collapsed to a 0-D tensor. - rewriter.replaceOpWithNewOp(op, resultType, input, - reassociationMap); + rewriter.replaceOp(op, squeezeTensorInfo.value()); return success(); } }; @@ -1722,36 +1671,15 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern { int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); - auto inputRank = - cast(adaptor.getSelf().getType()).getRank(); - dim = toPositiveDim(dim, inputRank + 1); - if (!isValidDim(dim, inputRank + 1)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - SmallVector reassociationMap(inputRank); - // From the perspective of the reassociation map, the situation of - // unsqueezing before or after the last dimension is symmetrical. - // Normalize it to the "before" case. - // The 0 case is special here, since there is no last dimension to insert - // before -- we simply rely on the loop below iterating 0 times. 
- if (dim == inputRank && inputRank != 0) - dim = inputRank - 1; - bool alreadyCrossedExpandedDim = false; - for (int i = 0; i != inputRank; i++) { - if (alreadyCrossedExpandedDim) { - reassociationMap[i].push_back(i + 1); - } else { - reassociationMap[i].push_back(i); - if (i == dim) { - reassociationMap[i].push_back(i + 1); - alreadyCrossedExpandedDim = true; - } - } + auto unsqueezeTensorInfo = + unsqueezeTensor(rewriter, op, adaptor.getSelf(), dim); + if (failed(unsqueezeTensorInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); } - auto resultType = cast( - getTypeConverter()->convertType(op->getResult(0).getType())); - rewriter.replaceOpWithNewOp( - op, resultType, adaptor.getSelf(), reassociationMap); + + rewriter.replaceOp(op, unsqueezeTensorInfo.value()); return success(); } }; diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 9ec7761704ea..02f6761de189 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -850,6 +850,46 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "only support constant int dilations"); + // Checks for valid group size + int64_t numGroups; + if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups))) + return rewriter.notifyMatchFailure(op, + "only constant group size supported."); + Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups()); + + // Adding support for 1d group convolution by converting the 1d-conv to + // 2d-conv. + bool is1DGroupConv = numSpatialDims == 1 && numGroups != 1; + if (is1DGroupConv) { + // Unsqueezing the last dim of input and weight. Also extending the + // dilation, stride, padding, and output padding lists. 
+ auto unsqueezeInputInfo = + unsqueezeTensor(rewriter, op, input, /*dim=*/-1); + if (failed(unsqueezeInputInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + } + input = unsqueezeInputInfo.value(); + + auto unsqueezeWeightInfo = + unsqueezeTensor(rewriter, op, weight, /*dim=*/-1); + if (failed(unsqueezeWeightInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + } + weight = unsqueezeWeightInfo.value(); + + Value cstZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + paddingIntValues.push_back(cstZero); + outputPaddingIntValues.push_back(cstZero); + strideInts.push_back(1); + dilationInts.push_back(1); + + inRank += 1; + numSpatialDims += 1; + } + Value inBatch = getDimOp(rewriter, loc, input, 0); Value inChannels = getDimOp(rewriter, loc, input, 1); SmallVector inDims; @@ -861,13 +901,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { for (size_t i = 2; i < inRank; i++) weightDims.push_back(getDimOp(rewriter, loc, weight, i)); - // Checks for valid group size - int64_t numGroups; - if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups))) - return rewriter.notifyMatchFailure(op, - "only constant group size supported."); - Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups()); - auto validate = [&](Value toValidate, std::string err) { Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); @@ -1286,7 +1319,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { if (numSpatialDims != 2) return rewriter.notifyMatchFailure( - op, "unimplemented: only 2D grouped convolution supported"); + op, "unimplemented: only 1D and 2D grouped convolution supported"); // Grouped case, use the grouped conv linalg op auto expandGroups = [&](Value tensor, size_t dim) { @@ -1371,6 +1404,16 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } + 
+ + if (is1DGroupConv) { + // Squeezing the last dim of the result of conv. + auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1); + if (failed(squeezeOutputInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate squeeze tensor"); + } + conv = squeezeOutputInfo.value(); + } rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index e3f5b6d0299a..7a9ff0077376 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -447,6 +447,120 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, return castIntToIndex(rewriter, loc, boundedByDimSize); } +// Helper function to unsqueeze the input tensor at given dim. +// Returns the unsqueezed tensor or failure. +FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim) { + auto inputType = cast(input.getType()); + int64_t inputRank = inputType.getRank(); + ArrayRef inputShape = inputType.getShape(); + + // `input` has a reduced rank. Hence add 1. + int64_t unsqueezedRank = inputShape.size() + 1; + dim = toPositiveDim(dim, unsqueezedRank); + if (!isValidDim(dim, unsqueezedRank)) { + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + } + + SmallVector unsqueezedShape{inputShape}; + unsqueezedShape.insert(unsqueezedShape.begin() + dim, 1); + Type unsqueezedType = + RankedTensorType::get(unsqueezedShape, inputType.getElementType()); + + SmallVector reassociationMap(inputRank); + // From the perspective of the reassociation map, the situation of + // unsqueezing before or after the last dimension is symmetrical. + // Normalize it to the "before" case. + // The 0 case is special here, since there is no last dimension to insert + // before -- we simply rely on the loop below iterating 0 times. 
+ if (dim == inputRank && inputRank != 0) + dim = inputRank - 1; + bool alreadyCrossedExpandedDim = false; + for (int i = 0; i != inputRank; i++) { + if (alreadyCrossedExpandedDim) { + reassociationMap[i].push_back(i + 1); + } else { + reassociationMap[i].push_back(i); + if (i == dim) { + reassociationMap[i].push_back(i + 1); + alreadyCrossedExpandedDim = true; + } + } + } + Value unsqueezed = rewriter.create( + op->getLoc(), unsqueezedType, input, reassociationMap); + return unsqueezed; +} + +// Helper function to squeeze the input tensor at given dim. +// Returns the squeezed tensor or failure. +FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim) { + Location loc = op->getLoc(); + auto inputType = cast(input.getType()); + int64_t inputRank = inputType.getRank(); + + if (inputRank == 0) { + return rewriter.notifyMatchFailure( + op, "zero input rank should have been handled by the folder"); + } + + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + + // assert dynamic squeeze dim size == 1 + if (inputType.isDynamicDim(dim)) { + Value cstDim = rewriter.create(loc, dim); + Value dimVal = rewriter.create(loc, input, cstDim); + Value cstOne = rewriter.create(loc, 1); + Value cmp = rewriter.create(loc, arith::CmpIPredicate::eq, + dimVal, cstOne); + rewriter.create( + loc, cmp, + rewriter.getStringAttr( + "Expected dynamic squeeze dim size to be statically 1")); + } + + ArrayRef inputShape = inputType.getShape(); + SmallVector squeezedShape; + squeezedShape.append(inputShape.begin(), inputShape.begin() + dim); + squeezedShape.append(inputShape.begin() + dim + 1, inputShape.end()); + int64_t squeezedRank = inputRank - 1; + Type squeezedType = + RankedTensorType::get(squeezedShape, inputType.getElementType()); + + // If the dim(th) dimension of operand tensor type is not statically unit, + // squeeze will behave as an identity operation. 
+ if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) { + return input; + } + + SmallVector reassociationMap(squeezedRank); + bool alreadyCrossedSqueezedDim = false; + for (int i = 0; i != squeezedRank; i++) { + if (alreadyCrossedSqueezedDim) { + reassociationMap[i].push_back(i + 1); + } else { + reassociationMap[i].push_back(i); + if (dim != 0 && i != dim - 1) + continue; + + alreadyCrossedSqueezedDim = true; + if (dim == 0) + reassociationMap[0].push_back(1); + if (i == dim - 1) + reassociationMap[i].push_back(dim); + } + } + // Note: In case the operand tensor type is of unit rank and is statically + // shaped with unit dimension, the `reassociationMap` will be empty and the + // input will be collapsed to a 0-D tensor. + Value squeezed = rewriter.create( + op->getLoc(), squeezedType, input, reassociationMap); + return squeezed; +} + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 7430ad89c2c2..0795e8e708f1 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2744,6 +2744,7 @@ "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", + "Conv1dGroupModule_basic", "Conv2dQInt8Module_basic", "Conv2dQInt8Module_depthwise", "Conv2dQInt8Module_grouped", @@ -2899,6 +2900,7 @@ "Conv1dModule_basic", "Conv1dWithSamePaddingModule_basic", "Conv1dWithValidPaddingModule_basic", + "Conv1dGroupModule_basic", "Conv2dBiasNoPaddingModule_basic", "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", @@ -3604,6 +3606,7 @@ "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", "Conv1dWithSamePaddingModule_basic", "Conv1dWithValidPaddingModule_basic", + "Conv1dGroupModule_basic", "Conv2dQInt8Module_basic", "Conv2dQInt8Module_depthwise", "Conv2dQInt8Module_grouped", @@ -4195,6 +4198,7 @@ 
"Conv1dWithSamePaddingModule_basic", "Conv1dWithValidPaddingModule_basic", "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", + "Conv1dGroupModule_basic", "Conv2dBiasNoPaddingModule_basic", "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 7a45dd7fc0ce..663c4b6a746b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -1199,6 +1199,33 @@ def Conv1dWithValidPaddingModule_basic(module, tu: TestUtils): module.forward(inputVec, weight, bias) +class Conv1dGroupModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv1d( + inputVec, weight, bias=bias, stride=[1], padding=[0], dilation=[1], groups=2 + ) + + +@register_test_case(module_factory=lambda: Conv1dGroupModule()) +def Conv1dGroupModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 4, 6) + weight = torch.randn(8, 2, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + + class Conv2dModule(torch.nn.Module): def __init__(self): super().__init__()