From 612ccc32e196cd8820132f9da8e411a1a67ffa85 Mon Sep 17 00:00:00 2001
From: rahul shrivastava
Date: Thu, 2 Jan 2025 19:55:33 -0800
Subject: [PATCH] Revert "Add Scalarization Patterns for `AtenToDtypeOp`, `AtenNegOp`, `AtenRemainderTensorOp` (#3861)"

This reverts commit cd38ecf6c223b94edf05a02dd10781264d762e76.
---
 lib/Conversion/TorchToArith/TorchToArith.cpp  |  26 +-
 lib/Dialect/Torch/IR/TorchOps.cpp             |   4 -
 .../Torch/Transforms/DecomposeComplexOps.cpp  |   5 -
 .../Torch/Transforms/ScalarizeShapes.cpp      | 228 +-----------------
 test/Dialect/Torch/decompose-complex-ops.mlir |   6 +-
 test/Dialect/Torch/scalarize-shapes.mlir      |  63 +----
 6 files changed, 30 insertions(+), 302 deletions(-)

diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp
index 458ea31852ec..143b46694030 100644
--- a/lib/Conversion/TorchToArith/TorchToArith.cpp
+++ b/lib/Conversion/TorchToArith/TorchToArith.cpp
@@ -82,25 +82,6 @@ class ConvertAtenBinaryOp : public OpConversionPattern {
 };
 } // namespace
 
-namespace {
-class ConvertAtenNegIntOp : public OpConversionPattern {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenNegIntOp op,
-                  typename OpConversionPattern::OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value a = adaptor.getA();
-    rewriter.replaceOpWithNewOp(
-        op,
-        rewriter.create(op.getLoc(), /*value=*/0,
-                        /*bitwidth=*/64),
-        a);
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 template 
 class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern {
@@ -484,14 +465,11 @@ class ConvertTorchToArith
     target.addIllegalOp();
     patterns.add(typeConverter, context);
-    target.addIllegalOp();
-    patterns.add(typeConverter, context);
+    target.addIllegalOp();
+                        AtenMulIntOp>();
     patterns.add>(
         typeConverter, context);
-    patterns.add>(
-        typeConverter, context);
     patterns.add>(
         typeConverter, context);
     patterns.add>(
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 868c5ef67a46..3774e65f0859 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4068,10 +4068,6 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
   int64_t lhs, rhs;
   bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
   bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
-  if (lConstant && lhs == 1)
-    return getOperand(1);
-  if (rConstant && rhs == 1)
-    return getOperand(0);
   if ((lConstant && lhs == 0) || (rConstant && rhs == 0))
     return getI64IntegerAttr(getContext(), 0);
   if (lConstant && rConstant)
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 9db8a6949063..aa15e3735dae 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4587,11 +4587,6 @@ class DecomposeAtenUnflattenIntOp
     if (!isValidDim(dimInt, inputRank))
       return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
 
-    if (inputShape[dimInt] == Torch::kUnknownSize &&
-        llvm::count(sizesInts, -1) > 0)
-      return rewriter.notifyMatchFailure(
-          op, "Unimplemented: dynamic unflatten dim with an inferred size.");
-
     SmallVector sizesTorchInt;
     if (!getListConstructElements(op.getSizes(), sizesTorchInt))
       return rewriter.notifyMatchFailure(
diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
index 989057501957..3d1a54de29f9 100644
--- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
+++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
@@ -714,7 +714,7 @@ class PropagateAtenItemPattern : public OpRewritePattern {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     // Rank 0 item op prop
-    if (selfTy.getSizes().empty()) {
+    if (selfTy.getSizes().size() == 0) {
       auto numToTensor = self.getDefiningOp();
       auto squeezeDim = self.getDefiningOp();
       if (!squeezeDim && !numToTensor)
@@ -746,109 +746,6 @@ class PropagateAtenItemPattern : public OpRewritePattern {
 };
 } // namespace
 
-namespace {
-
-LogicalResult convertOpFoldResults(ImplicitLocOpBuilder &b,
-                                   SmallVector &converted,
-                                   SmallVector &elements,
-                                   Type inputDtype, Type resultDtype) {
-  auto inputIsInt = dyn_cast(inputDtype);
-  auto resultIsInt = dyn_cast(resultDtype);
-  if (!inputIsInt && !isa(inputDtype))
-    return failure();
-  if (!resultIsInt && !isa(resultDtype))
-    return failure();
-
-  // if dtypes are both int or both float, no conversion needed
-  if (static_cast(inputIsInt) == static_cast(resultIsInt)) {
-    converted = elements;
-    return success();
-  }
-
-  if (resultIsInt) {
-    for (auto &e : elements) {
-      auto eValue = dyn_cast(e);
-      if (eValue) {
-        converted.push_back(b.createOrFold(eValue));
-        continue;
-      }
-      auto eAttr = dyn_cast(e);
-      auto eFloatAttr = dyn_cast_or_null(eAttr);
-      if (!eFloatAttr)
-        return failure();
-
-      converted.push_back(IntegerAttr::get(
-          resultDtype, static_cast(eFloatAttr.getValueAsDouble())));
-    }
-    return success();
-  }
-
-  // result is float
-  for (auto &e : elements) {
-    auto eValue = dyn_cast(e);
-    if (eValue) {
-      converted.push_back(b.createOrFold(eValue));
-      continue;
-    }
-    auto eAttr = dyn_cast(e);
-    auto eIntAttr = dyn_cast(eAttr);
-    if (!eIntAttr)
-      return failure();
-
-    auto eInt = (inputIsInt.isSigned()) ? eIntAttr.getValue().getSExtValue()
-                                        : eIntAttr.getValue().getZExtValue();
-    converted.push_back(FloatAttr::get(resultDtype, static_cast(eInt)));
-  }
-  return success();
-}
-
-class PropagateAtenToDtypePattern : public OpRewritePattern {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenToDtypeOp op,
-                                PatternRewriter &rewriter) const override {
-    bool nonBlocking, copyArg;
-    // The non_blocking arg must be `False`.
-    if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
-        nonBlocking)
-      return failure();
-    // The copy arg must be `False`.
-    if (!matchPattern(op.getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
-      return failure();
-    // The memory_format arg must be `none`.
-    if (!isa(op.getMemoryFormat().getType()))
-      return failure();
-
-    auto inputType = dyn_cast(op.getSelf().getType());
-    auto resultType = dyn_cast(op.getType());
-    if (!inputType || !resultType || !inputType.hasDtype() ||
-        !resultType.hasDtype())
-      return failure();
-    auto inputDtype = inputType.getDtype();
-    auto resultDtype = resultType.getDtype();
-
-    SmallVector elements;
-    if (failed(getListFromTensor(op.getSelf(), elements)))
-      return failure();
-
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    SmallVector converted;
-    if (failed(convertOpFoldResults(b, converted, elements, inputDtype,
-                                    resultDtype)))
-      return rewriter.notifyMatchFailure(
-          op, "Unhandled attribute type encountered.");
-
-    SmallVector vals;
-    if (failed(materializeFolds(b, converted, vals)))
-      return failure();
-
-    Value result = constructAtenTensorOpFromList(b, op.getType(), vals);
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 template 
 class PropagateAtenViewLikePattern : public OpRewritePattern {
@@ -931,49 +828,6 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern {
     if (failed(materializeFolds(b, resultFolds, resultVals)))
      return failure();
 
-    if (resultTy.getSizes().empty()) {
-      rewriter.replaceOpWithNewOp(
-          op, resultTy, resultVals.front());
-      return success();
-    }
-
-    Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-template 
-class PropagateAtenUnaryPattern : public OpRewritePattern {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(OpTy op,
-                                PatternRewriter &rewriter) const override {
-    // Check type
-    auto resultTy = cast(op.getType());
-    if (resultTy.getSizes().size() > 1)
-      return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
-    if (!resultTy.hasDtype() || !isa(resultTy.getDtype()))
-      return rewriter.notifyMatchFailure(op, "not an int type");
-
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    SmallVector selfFold;
-    if (failed(getListFromTensor(op.getSelf(), selfFold)))
-      return failure();
-    SmallVector selfVals;
-    if (failed(materializeFolds(b, selfFold, selfVals)))
-      return failure();
-    SmallVector resultFolds;
-    for (uint64_t i = 0; i < selfVals.size(); i++) {
-      resultFolds.push_back(
-          b.createOrFold(selfVals[i].getType(), selfVals[i]));
-    }
-    SmallVector resultVals;
-    if (failed(materializeFolds(b, resultFolds, resultVals)))
-      return failure();
-
     if (resultTy.getSizes().size() == 0) {
       rewriter.replaceOpWithNewOp(
           op, resultTy, resultVals.front());
@@ -986,6 +840,7 @@ class PropagateAtenUnaryPattern : public OpRewritePattern {
   }
 };
 } // namespace
+
 /// ------ Fold Patterns ------ ///
 // These are shape-specific folding patterns
@@ -1060,11 +915,6 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern {
     auto resultTy = cast(op.getType());
     if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
       return rewriter.notifyMatchFailure(op, "dynamic output shape");
-    if (resultTy.getSizes().size() == 0) {
-      rewriter.replaceOpWithNewOp(
-          op, op.getType(), elements.front());
-      return success();
-    }
 
     auto loc = op.getLoc();
     SmallVector sizes;
@@ -1072,10 +922,12 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern {
       sizes.push_back(rewriter.create(
          loc, rewriter.getI64IntegerAttr(size)));
 
+    Value one = rewriter.create(
+        loc, rewriter.getType(), 1);
     Value sizeList = rewriter.create(
         loc, rewriter.getType(rewriter.getType()),
-        sizes);
+        one);
     Value none = rewriter.create(loc);
     Value cstFalse = rewriter.create(loc, false);
@@ -1179,24 +1031,6 @@ class FoldAtenWhereSelf : public OpRewritePattern {
 };
 } // namespace
 
-namespace {
-// fold ridiculous patterns like size.int -> float.scalar -> int.scalar
-class FoldAtenIntScalarPattern : public OpRewritePattern {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenIntScalarOp op,
-                                PatternRewriter &rewriter) const override {
-    auto floatScalarOp = op.getA().getDefiningOp();
-    if (!floatScalarOp)
-      return failure();
-    auto sizeOp = floatScalarOp.getA().getDefiningOp();
-    if (!sizeOp)
-      return failure();
-    rewriter.replaceOp(op, floatScalarOp.getA());
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class FoldAtenUnsqueezePattern : public OpRewritePattern {
 public:
@@ -1348,29 +1182,8 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern {
     if (inputUnmatched == 1 && outputUnmatched > 1) {
       Value dimVal =
          rewriter.create(op.getLoc(), leftMatchEnd);
-      SmallVector unflattenSizes(viewSizes.begin() + leftMatchEnd,
-                                 viewSizes.end() - rightMatchEnd);
-      // try to convert a single dynamic size input to -1
-      int64_t dynCount = 0;
-      int64_t dynIdx = 0;
-      for (auto [i, v] : llvm::enumerate(unflattenSizes)) {
-        int64_t szeInt;
-        if (!matchPattern(v, m_TorchConstantInt(&szeInt))) {
-          dynCount++;
-          dynIdx = i;
-          continue;
-        }
-        // if we have a -1 already, make dynCount invalid and break
-        if (szeInt == -1) {
-          dynCount = -1;
-          break;
-        }
-      }
-      // if only one size is dynamic, make it -1
-      if (dynCount == 1)
-        unflattenSizes[dynIdx] =
-            rewriter.create(op.getLoc(), -1);
-
+      ArrayRef unflattenSizes(viewSizes.begin() + leftMatchEnd,
+                              viewSizes.end() - rightMatchEnd);
       Value unflattenList = rewriter.create(
           op.getLoc(), op.getSize().getType(), unflattenSizes);
       rewriter.replaceOpWithNewOp(
@@ -1414,18 +1227,6 @@ template class RemoveUnusedPattern : public OpRewritePattern {
 
 namespace {
 
-bool isItemForSliceOp(Operation *op) {
-  auto itemOp = dyn_cast_or_null(op);
-  if (!itemOp)
-    return false;
-  for (OpOperand &use : op->getUses()) {
-    Operation *userOp = use.getOwner();
-    if (isa(userOp))
-      return true;
-  }
-  return false;
-}
-
 bool isSourceOpForShapeScalarization(Operation *op) {
   return llvm::isa(op);
@@ -1443,7 +1244,7 @@ bool isPrimListOfInts(Operation *op) {
 bool isAnchorOp(Operation *op) {
   return isa(op) || isa(op) ||
-         isPrimListOfInts(op) || isItemForSliceOp(op);
+         isPrimListOfInts(op);
 }
 
 // The argument to this function, op, is the use of some source op, srcOp. If
@@ -1477,9 +1278,9 @@ bool isInvalidValidViewConsumer(Operation *op,
 void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
   patterns.insert, FoldAtenSqueezePattern,
-                  FoldAtenIntScalarPattern, FoldAtenUnsqueezePattern,
-                  FoldAtenWhereSelf, FoldAtenTensorSplatPattern,
-                  FoldAtenEqIntPattern>(patterns.getContext());
+                  FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
+                  FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
+      patterns.getContext());
 }
 
 void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
@@ -1502,12 +1303,10 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
       PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
       PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
       PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
-      PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern,
-      PropagateAtenUnaryPattern,
+      PropagateAtenTransposeIntPattern,
       PropagateAtenArithmeticPattern,
       PropagateAtenArithmeticPattern,
       PropagateAtenArithmeticPattern,
-      PropagateAtenArithmeticPattern,
      PropagateAtenArithmeticPattern>(
       patterns.getContext());
 }
@@ -1515,7 +1314,6 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
 void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
   patterns.insert, RemoveUnusedPattern,
-                  RemoveUnusedPattern,
                   RemoveUnusedPattern,
                   RemoveUnusedPattern,
                   RemoveUnusedPattern,
@@ -1523,8 +1321,6 @@ void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
                   RemoveUnusedPattern,
                   RemoveUnusedPattern,
                   RemoveUnusedPattern,
-                  RemoveUnusedPattern,
-                  RemoveUnusedPattern,
                   RemoveUnusedPattern>(
       patterns.getContext());
 }
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 4da482af03f3..f938a2637835 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -105,9 +105,9 @@ func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(%arg0: !torch.v
 // CHECK-LABEL: test_einsum_inner_prod
 func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[5],f64>) -> !torch.vtensor<[],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} {
-  // CHECK-DAG: %[[INT5:.+]] = torch.constant.int 5
-  // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
-  // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT5:.+]] = torch.constant.int 5
+  // CHECK: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
   // CHECK: %[[LHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]]
   // CHECK: %[[LHS_PERM:.+]] = torch.aten.permute %arg0, %[[LHS_LIST]]
   // CHECK: %[[RHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]]
diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir
index c7fc2c280a2b..5ea715735c70 100644
--- a/test/Dialect/Torch/scalarize-shapes.mlir
+++ b/test/Dialect/Torch/scalarize-shapes.mlir
@@ -27,8 +27,12 @@ func.func @shape_as_tensor(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtenso
 // CHECK-LABEL: @shape_as_tensor_dim
 func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si32> {
   // CHECK: %[[INT1:.+]] = torch.constant.int 1
-  // CHECK-DAG: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]]
-  // CHECK: %[[TENSOR:.+]] = torch.prim.NumToTensor.Scalar %[[SZ]] : !torch.int -> !torch.vtensor<[],si32>
+  // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]]
+  // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1_0]]
+  // CHECK: %[[TENSOR:.+]] = torch.aten.full %[[LIST]], %[[SZ]], %[[NONE]], %[[NONE]], %[[NONE]], %[[FALSE]]
   // CHECK: return %[[TENSOR]] : !torch.vtensor<[],si32>
   %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
   %dim = torch.constant.int 0
@@ -39,49 +43,6 @@ func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vt
   return %select : !torch.vtensor<[],si32>
 }
 
-// -----
-
-// CHECK-LABEL: @cast_int_int
-func.func @cast_int_int(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si64> {
-  // CHECK: %[[I1:.*]] = torch.constant.int 1
-  // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int
-  // CHECK: %[[TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SZE]] : !torch.int -> !torch.vtensor<[],si64>
-  // CHECK: return %[[TENSOR]] : !torch.vtensor<[],si64>
-  %int4 = torch.constant.int 4
-  %false = torch.constant.bool false
-  %none = torch.constant.none
-  %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
-  %cast_shape = torch.aten.to.dtype %shape, %int4, %false, %false, %none : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],si64>
-  %dim = torch.constant.int 0
-  %idx = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si32>
-  %select = torch.aten.index_select %cast_shape, %dim, %idx : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si64>
-  %item = torch.aten.item %select : !torch.vtensor<[],si64> -> !torch.int
-  %list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list
-  return %select : !torch.vtensor<[],si64>
-}
-
-// -----
-
-// CHECK-LABEL: @cast_int_float
-func.func @cast_int_float(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],f32> {
-  // CHECK: %[[I1:.*]] = torch.constant.int 1
-  // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int
-  // CHECK: %[[FLOAT:.*]] = torch.aten.Float.Scalar %[[SZE]] : !torch.int -> !torch.float
-  // CHECK: %[[TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[FLOAT]] : !torch.float -> !torch.vtensor<[],f32>
-  // CHECK: return %[[TENSOR]] : !torch.vtensor<[],f32>
-  %int6 = torch.constant.int 6
-  %false = torch.constant.bool false
-  %none = torch.constant.none
-  %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
-  %cast_shape = torch.aten.to.dtype %shape, %int6, %false, %false, %none : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],f32>
-  %dim = torch.constant.int 0
-  %idx = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si32>
-  %select = torch.aten.index_select %cast_shape, %dim, %idx : !torch.vtensor<[3],f32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],f32>
-  %item = torch.aten.item %select : !torch.vtensor<[],f32> -> !torch.float
-  %item_int = torch.aten.Int.Scalar %item : !torch.float -> !torch.int
-  %list = torch.prim.ListConstruct %item_int : (!torch.int) -> !torch.list
-  return %select : !torch.vtensor<[],f32>
-}
 
 // -----
 
@@ -128,12 +89,14 @@ func.func @arith_prop(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?]
   // CHECK: %[[x2:.*]] = torch.aten.floordiv.int %[[x0]], %[[int12]] : !torch.int, !torch.int -> !torch.int
   // CHECK: %[[x3:.*]] = torch.aten.floordiv.int %[[x1]], %[[int1_0]] : !torch.int, !torch.int -> !torch.int
   // CHECK: %[[int12_1:.*]] = torch.constant.int 12
+  // CHECK: %[[int1_2:.*]] = torch.constant.int 1
   // CHECK: %[[x4:.*]] = torch.aten.mul.int %[[x2]], %[[int12_1]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[x5:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x1]], %[[x3]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[x7:.*]] = torch.prim.ListConstruct %[[x6]], %[[x5]] : (!torch.int, !torch.int) -> !torch.list
-  // CHECK: %[[x8:.*]] = torch.aten.constant_pad_nd %arg0, %[[x7]], %[[float0]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.float -> !torch.vtensor<[?,?],f32>
-  // CHECK: return %[[x8]] : !torch.vtensor<[?,?],f32>
+  // CHECK: %[[x5:.*]] = torch.aten.mul.int %[[x3]], %[[int1_2]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[x7:.*]] = torch.aten.sub.int %[[x1]], %[[x5]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[x8:.*]] = torch.prim.ListConstruct %[[x7]], %[[x6]] : (!torch.int, !torch.int) -> !torch.list
+  // CHECK: %[[x9:.*]] = torch.aten.constant_pad_nd %arg0, %[[x8]], %[[float0]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.float -> !torch.vtensor<[?,?],f32>
+  // CHECK: return %[[x9]] : !torch.vtensor<[?,?],f32>
   %0 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64>
   %1 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64>
   %float0.000000e00 = torch.constant.float 0.000000e+00