Skip to content

Commit

Permalink
Merge branch 'llvm:main' into fix-conv
Browse files Browse the repository at this point in the history
  • Loading branch information
vivekkhandelwal1 authored Dec 10, 2024
2 parents 5da4d93 + 5077090 commit fd98ffe
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 42 deletions.
2 changes: 1 addition & 1 deletion build_tools/ci/build_posix.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$repo_root" \
-DLLVM_TARGETS_TO_BUILD=host \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DTORCH_MLIR_ENABLE_LTC=ON \
-DTORCH_MLIR_ENABLE_LTC=OFF \
-DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS=ON
echo "::endgroup::"

Expand Down
107 changes: 83 additions & 24 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
Value rhsAsTensor;
if (!rhsTy) {
if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(),
rhsAsTensor, lhsElemTy, {})))
rhsAsTensor, rhs.getType(), {})))
return rewriter.notifyMatchFailure(
op, "Currently only scalar constants are supported for "
"conversion in TOSA operation");
Expand All @@ -414,11 +414,26 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
auto rhsTensorTy = dyn_cast<TensorType>(rhsTensor.getType());
auto rhsElemTy = rhsTensorTy.getElementType();

// There is no Lesser operator in TOSA.
constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>() ||
std::is_same<AtenOpT, AtenLeTensorOp>() ||
std::is_same<AtenOpT, AtenLeScalarOp>());

// Promote lhs and rhs dtypes for bitwise operators.
TensorType resultTy = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (isBitwiseOp) {
lhs = tosa::promoteType(rewriter, lhs, resultTy);
rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy);
}

// Support different types comparisons
auto isLhsElemFloat = isa<mlir::FloatType>(lhsElemTy);
auto isRhsElemFloat = isa<mlir::FloatType>(rhsElemTy);

// Support different types comparisons
if (lhsElemTy != rhsElemTy) {
if (lhsElemTy != rhsElemTy && !isBitwiseOp) {
if (isLhsElemFloat && !isRhsElemFloat) {
rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy);
} else if (!isLhsElemFloat && isRhsElemFloat) {
Expand All @@ -441,20 +456,6 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
}
}
}
// There is no Lesser operator in TOSA.
constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>() ||
std::is_same<AtenOpT, AtenLeTensorOp>() ||
std::is_same<AtenOpT, AtenLeScalarOp>());

// Promote lhs and rhs dtypes for bitwise operators.
TensorType resultTy = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (isBitwiseOp) {
lhs = tosa::promoteType(rewriter, lhs, resultTy);
rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy);
}

auto resultOp = rewriter.create<TosaOpT>(op.getLoc(), resultTy,
(swapLhsRhs ? rhsTensor : lhs),
Expand Down Expand Up @@ -770,17 +771,24 @@ class ConvertAtenActivationFunctionOp : public OpConversionPattern<AtenOpT> {
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value self = adaptor.getSelf();
auto selfTy = cast<TensorType>(self.getType());
auto selfTy = dyn_cast<TensorType>(self.getType());

if (!selfTy)
return rewriter.notifyMatchFailure(op, "Only Tensor types supported");

if (!isa<mlir::FloatType>(selfTy.getElementType()))
auto resultTy = dyn_cast<TensorType>(
this->getTypeConverter()->convertType(op.getType()));

if (!isa<mlir::FloatType>(resultTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization currently supported");
op, "Only floating-point datatype result types are supported");

rewriter.replaceOpWithNewOp<TosaOpT>(
op, this->getTypeConverter()->convertType(op.getType()), self);
// Non floating point inputs are not supported for activation functions
// (erf, sigmoid, tanh) in TOSA so we cast the input to result type
if (!isa<mlir::FloatType>(selfTy.getElementType()))
self = tosa::promoteType(rewriter, self, resultTy);

rewriter.replaceOpWithNewOp<TosaOpT>(op, resultTy, self);

return success();
}
Expand Down Expand Up @@ -1283,6 +1291,10 @@ class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
auto outType =
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));

if (!isa<mlir::FloatType>(outType.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype result types are supported");

Value selfTensor;
if constexpr (std::is_same<AtenOpT, AtenPowScalarOp>()) {
Value selfScalar = op.getSelf();
Expand All @@ -1299,9 +1311,10 @@ class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Pow");

// Non floating point inputs are not supported for tosa.pow so we cast the
// input to result type
if (!isa<mlir::FloatType>(selfTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");
selfTensor = tosa::promoteType(rewriter, selfTensor, outType);
}

Value expTensor;
Expand All @@ -1319,6 +1332,11 @@ class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
if (!expTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Pow");

// Non floating point exponents are not supported for tosa.pow so we cast
// the exponent to result type
if (!isa<mlir::FloatType>(expTy.getElementType()))
expTensor = tosa::promoteType(rewriter, expTensor, outType);
}

auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(
Expand Down Expand Up @@ -8198,6 +8216,46 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
return success();
}

// Legalization for aten.tan
template <>
LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
AtenTanOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// tan = sin / cos
auto self = adaptor.getSelf();

auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));

if (!isa<mlir::FloatType>(resultType.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype result types are supported");

// Non floating point inputs are not supported in TOSA so we cast the input
// to result type
if (!isa<mlir::FloatType>(selfType.getElementType()))
self = tosa::promoteType(rewriter, self, resultType);

auto sinOp = rewriter.create<tosa::SinOp>(op->getLoc(), resultType, self);

auto cosOp = rewriter.create<tosa::CosOp>(op->getLoc(), resultType, self);

auto reciprocalOp =
rewriter.create<tosa::ReciprocalOp>(op->getLoc(), resultType, cosOp);

auto result = rewriter.create<tosa::MulOp>(
op->getLoc(), resultType, sinOp.getResult(), reciprocalOp.getResult(),
/*shift=*/0);

rewriter.replaceOp(op, {result.getResult()});

return success();
}

} // namespace

// -----------------------------------------------------------------------------
Expand Down Expand Up @@ -8540,6 +8598,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenLogitOp);
INSERT_ATENOP_PATTERN(AtenLog1pOp);
INSERT_ATENOP_PATTERN(AtenLog10Op);
INSERT_ATENOP_PATTERN(AtenTanOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
Expand Down
15 changes: 7 additions & 8 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1717,6 +1717,13 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"ElementwiseErfIntModule_basic",
"ElementwiseIntTensorLtFloatScalarModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic",
"ElementwiseUnaryIntModule_basic",
"PowIntFloatModule_basic",
"Deg2radModule_basic",
"ElementwiseIntTensorLtFloatTensorModule_basic",
"L1LossMeanReductionModule_basic",
Expand Down Expand Up @@ -3661,22 +3668,16 @@
"ElementwiseCoshModule_basic",
"ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseIntTensorLtFloatScalarModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseSinhIntModule_basic",
"ElementwiseSinhModule_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"ElementwiseUnaryIntModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"EqIntModule_basic",
"FloatImplicitModule_basic",
Expand Down Expand Up @@ -3783,7 +3784,6 @@
"NumelZeroRankModule_basic",
"OnesLikeModule_falsePinMemory",
"PowIntIntModule_basic",
"PowIntFloatModule_basic",
"PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic",
"PrimMinIntModule_basic",
Expand Down Expand Up @@ -4373,7 +4373,6 @@
"ElementwiseSqrtIntModule_basic",
"ElementwiseSubScalarIntModule_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic",
"ElementwiseTernaryModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseToDtypeI64ToI8Module_basic",
Expand Down
5 changes: 4 additions & 1 deletion python/torch_mlir/extras/onnx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,10 @@ def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType:
def _sanitize_name(self, name):
if not name.isidentifier():
name = "_" + name
return re.sub("[:/]", "_", name)

# Remove characters that are invalid in MLIR identifier names.
# https://mlir.llvm.org/docs/LangRef/#identifiers-and-keywords
return re.sub("[:/-]", "_", name)

def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
tensor_type = self.tensor_proto_to_builtin_type(tp)
Expand Down
Loading

0 comments on commit fd98ffe

Please sign in to comment.