From 160aedf6c6b758d390240eddfd73cc2474203fe7 Mon Sep 17 00:00:00 2001 From: Justin Ngo Date: Fri, 6 Dec 2024 19:25:34 +0000 Subject: [PATCH] [TOSA] Add some more mixed dtype handling * Add int input handling for activation functions like erf, sigmoid, and tanh * Fix mixed dtype handling for scalar comparison ops * Add mixed dtype handling for pow tensor op (with only floating point result type support for now) * Add Torch to TOSA lowering for torch.aten.tan Signed-off-by: Justin Ngo Change-Id: I3a8aa1e6febbc0e39ebdb5734f87ae171b03cd73 --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 107 ++++++++++++---- projects/pt1/e2e_testing/xfail_sets.py | 15 ++- test/Conversion/TorchToTosa/basic.mlir | 134 +++++++++++++++++++-- 3 files changed, 216 insertions(+), 40 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index cd23717f04eb..9572723fdd29 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -405,7 +405,7 @@ class ConvertAtenCompareOp : public OpConversionPattern { Value rhsAsTensor; if (!rhsTy) { if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(), - rhsAsTensor, lhsElemTy, {}))) + rhsAsTensor, rhs.getType(), {}))) return rewriter.notifyMatchFailure( op, "Currently only scalar constants are supported for " "conversion in TOSA operation"); @@ -414,11 +414,26 @@ class ConvertAtenCompareOp : public OpConversionPattern { auto rhsTensorTy = dyn_cast(rhsTensor.getType()); auto rhsElemTy = rhsTensorTy.getElementType(); + // There is no Lesser operator in TOSA. + constexpr auto swapLhsRhs = (std::is_same() || + std::is_same() || + std::is_same() || + std::is_same()); + + // Promote lhs and rhs dtypes for bitwise operators. + TensorType resultTy = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); + if (isBitwiseOp) { + lhs = tosa::promoteType(rewriter, lhs, resultTy); + rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy); + } + + // Support different types comparisons auto isLhsElemFloat = isa(lhsElemTy); auto isRhsElemFloat = isa(rhsElemTy); - // Support different types comparisons - if (lhsElemTy != rhsElemTy) { + if (lhsElemTy != rhsElemTy && !isBitwiseOp) { if (isLhsElemFloat && !isRhsElemFloat) { rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy); } else if (!isLhsElemFloat && isRhsElemFloat) { @@ -441,20 +456,6 @@ class ConvertAtenCompareOp : public OpConversionPattern { } } } - // There is no Lesser operator in TOSA. - constexpr auto swapLhsRhs = (std::is_same() || - std::is_same() || - std::is_same() || - std::is_same()); - - // Promote lhs and rhs dtypes for bitwise operators. - TensorType resultTy = cast( - OpConversionPattern::getTypeConverter()->convertType( - op.getType())); - if (isBitwiseOp) { - lhs = tosa::promoteType(rewriter, lhs, resultTy); - rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy); - } auto resultOp = rewriter.create(op.getLoc(), resultTy, (swapLhsRhs ? rhsTensor : lhs), @@ -770,17 +771,24 @@ class ConvertAtenActivationFunctionOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.getSelf(); - auto selfTy = cast(self.getType()); + auto selfTy = dyn_cast(self.getType()); if (!selfTy) return rewriter.notifyMatchFailure(op, "Only Tensor types supported"); - if (!isa(selfTy.getElementType())) + auto resultTy = dyn_cast( + this->getTypeConverter()->convertType(op.getType())); + + if (!isa(resultTy.getElementType())) return rewriter.notifyMatchFailure( - op, "Only floating-point datatype legalization currently supported"); + op, "Only floating-point datatype result types are supported"); - rewriter.replaceOpWithNewOp( - op, this->getTypeConverter()->convertType(op.getType()), self); + // Non floating point inputs are not supported for activation functions + // (erf, sigmoid, tanh) in TOSA so we cast the input to result type + if (!isa(selfTy.getElementType())) + self = tosa::promoteType(rewriter, self, resultTy); + + rewriter.replaceOpWithNewOp(op, resultTy, self); return success(); } @@ -1283,6 +1291,10 @@ class ConvertAtenPowOp : public OpConversionPattern { auto outType = cast(this->getTypeConverter()->convertType(op.getType())); + if (!isa(outType.getElementType())) + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype result types are supported"); + Value selfTensor; if constexpr (std::is_same()) { Value selfScalar = op.getSelf(); @@ -1299,9 +1311,10 @@ class ConvertAtenPowOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Pow"); + // Non floating point inputs are not supported for tosa.pow so we cast the + // input to result type if (!isa(selfTy.getElementType())) - return rewriter.notifyMatchFailure( - op, "Only floating-point datatype legalization supported"); + selfTensor = tosa::promoteType(rewriter, selfTensor, outType); } Value expTensor; @@ -1319,6 +1332,11 @@ class ConvertAtenPowOp : public OpConversionPattern { if (!expTy) return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Pow"); + + // Non floating point exponents are not supported for tosa.pow so we cast + // the exponent to result type + if (!isa(expTy.getElementType())) + expTensor = tosa::promoteType(rewriter, expTensor, outType); } auto powOp = tosa::createBinaryOpAndCast( @@ -8198,6 +8216,46 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// Legalization for aten.tan +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenTanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // tan = sin / cos + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + + if (!isa(resultType.getElementType())) + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype result types are supported"); + + // Non floating point inputs are not supported in TOSA so we cast the input + // to result type + if (!isa(selfType.getElementType())) + self = tosa::promoteType(rewriter, self, resultType); + + auto sinOp = rewriter.create(op->getLoc(), resultType, self); + + auto cosOp = rewriter.create(op->getLoc(), resultType, self); + + auto reciprocalOp = + rewriter.create(op->getLoc(), resultType, cosOp); + + auto result = rewriter.create( + op->getLoc(), resultType, sinOp.getResult(), reciprocalOp.getResult(), + /*shift=*/0); + + rewriter.replaceOp(op, {result.getResult()}); + + return success(); +} + } // namespace // ----------------------------------------------------------------------------- @@ -8540,6 +8598,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenLogitOp); INSERT_ATENOP_PATTERN(AtenLog1pOp); INSERT_ATENOP_PATTERN(AtenLog10Op); + INSERT_ATENOP_PATTERN(AtenTanOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 2237ca1446ea..7430ad89c2c2 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1717,6 +1717,13 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "ElementwiseErfIntModule_basic", + "ElementwiseIntTensorLtFloatScalarModule_basic", + "ElementwiseSigmoidIntModule_basic", + "ElementwiseTanIntModule_basic", + "ElementwiseTanModule_basic", + "ElementwiseUnaryIntModule_basic", + "PowIntFloatModule_basic", "Deg2radModule_basic", "ElementwiseIntTensorLtFloatTensorModule_basic", "L1LossMeanReductionModule_basic", @@ -3658,22 +3665,16 @@ "ElementwiseCoshModule_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", - "ElementwiseErfIntModule_basic", "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", - "ElementwiseIntTensorLtFloatScalarModule_basic", "ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseMulTensorComplexModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", - "ElementwiseSigmoidIntModule_basic", "ElementwiseSinhIntModule_basic", "ElementwiseSinhModule_basic", - "ElementwiseTanIntModule_basic", - "ElementwiseTanModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeI64ToUI8Module_basic", - "ElementwiseUnaryIntModule_basic", "ElementwiseWhereScalarOtherStaticModule_basic", "EqIntModule_basic", "FloatImplicitModule_basic", @@ -3780,7 +3781,6 @@ "NumelZeroRankModule_basic", "OnesLikeModule_falsePinMemory", "PowIntIntModule_basic", - "PowIntFloatModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -4369,7 +4369,6 @@ "ElementwiseSqrtIntModule_basic", "ElementwiseSubScalarIntModule_basic", "ElementwiseTanIntModule_basic", - "ElementwiseTanModule_basic", "ElementwiseTernaryModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeI64ToI8Module_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 02bb2338910f..9e504c082a8c 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1766,10 +1766,11 @@ func.func @torch.aten.erf$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.bitwise_and %[[VAL_1]], %[[VAL_3]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],si32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.bitwise_and %[[VAL_1]], %[[VAL_4]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],si32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],si32> // CHECK: } func.func @torch.aten.bitwise_and.Scalar$basic(%arg0: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { %int2 = torch.constant.int 2 @@ -1799,10 +1800,11 @@ func.func @torch.aten.le.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_3]], %[[VAL_1]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.greater_equal %[[VAL_4]], %[[VAL_1]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],i1> // CHECK: } func.func @torch.aten.le.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { %int2 = torch.constant.int 2 @@ -2825,3 +2827,119 @@ func.func @torch.aten.log2$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vten } // ----- + +// CHECK-LABEL: func.func @torch.aten.erf$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = tosa.erf %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.erf$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.erf %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.lt.Scalar$intfloat( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4],si64> -> tensor<4xi64> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.100000e+00 +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.100000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.1000000238418579> : tensor}> : () -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_1]] : (tensor<4xi64>) -> tensor<4xf64> +// CHECK: %[[VAL_6:.*]] = tosa.greater %[[VAL_4]], %[[VAL_5]] : (tensor, tensor<4xf64>) -> tensor<4xi1> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4xi1> -> !torch.vtensor<[4],i1> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[4],i1> +// CHECK: } +func.func @torch.aten.lt.Scalar$intfloat(%arg0: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> { + %float1.100000e00 = torch.constant.float 1.100000e+00 + %0 = torch.aten.lt.Scalar %arg0, %float1.100000e00 : !torch.vtensor<[4],si64>, !torch.float -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.sigmoid$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],si32>) -> !torch.vtensor<[3,5],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],si32> -> tensor<3x5xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x5xi32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_3:.*]] = tosa.sigmoid %[[VAL_2]] : (tensor<3x5xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[3,5],f32> +// CHECK: } +func.func @torch.aten.sigmoid$int(%arg0: !torch.vtensor<[3,5],si32>) -> !torch.vtensor<[3,5],f32> { + %0 = torch.aten.sigmoid %arg0 : !torch.vtensor<[3,5],si32> -> !torch.vtensor<[3,5],f32> + return %0 : !torch.vtensor<[3,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.tan$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> +// CHECK: %[[VAL_2:.*]] = tosa.sin %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = tosa.cos %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.tan$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.tan %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.tan$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = tosa.sin %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_4:.*]] = tosa.cos %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.tan$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.tan %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.tanh$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = tosa.tanh %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.tanh$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.pow.Tensor_Tensor$intfloat( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],si32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],si32> -> tensor<3x4x5xi32> +// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor<3x4x5xi32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_5:.*]] = tosa.pow %[[VAL_4]], %[[VAL_2]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4,5],f32> +// CHECK: } +func.func @torch.aten.pow.Tensor_Tensor$intfloat(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { + %0 = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// -----