From f03a5762c3598da39ac44f1edbc7aa4579ef3262 Mon Sep 17 00:00:00 2001
From: Sayan Saha
Date: Thu, 12 Dec 2024 04:08:27 -0500
Subject: [PATCH] [TorchToTosa] Refactoring to separate construction of legal/illegal ops and conversion patterns. (#3759)

This PR refactors TorchToTosa to separate the construction of legal/illegal ops and conversion patterns into their own functions:

1. populateTorchToTosaConversionLegalOps -- populate any ops that are legal after the conversion pass
2. populateTorchToTosaConversionIllegalOps -- populate any ops that are illegal after the conversion pass
3. populateTorchToTosaConversionPatterns -- populate the conversion patterns for the ops

Currently, the (il)legality of the ops that are (il)legal after the conversion pass runs is embedded within the conversion patterns. Our end goal is to write a new pass pipeline that converts `torch` ops to a mix of `tosa`, `linalg`, `tensor`, etc. dialect ops. The reason we also want to emit `tosa` ops (instead of using the existing `TorchToLinalg` to emit `linalg`+`tensor`+...) is that some operations, such as `conv2d`, encode the padding behavior in the op itself in `tosa`, unlike the `linalg` version -- this helps in lowering `tosa.conv2d` to a custom implementation that does padding on the fly.

To implement this new pipeline we need to be able to separate the set of ops that are illegal for `tosa` from the conversion patterns themselves. Otherwise we hit an issue with ops like `AtenMaxDimOp`, which can be lowered to both the `tosa` and the `linalg` (+ other) dialects. Not every `AtenMaxDimOp` can be lowered successfully to `tosa`, because the implementation uses `tosa.reshape`, which cannot handle multiple dynamic dimensions, whereas the `TorchToLinalg` lowering can handle it. With the current behavior, the pipeline stops as soon as the existing `TorchToTosa` conversion runs, since `AtenMaxDimOp` is marked as an illegal op. Essentially, we want to be able to control the legality of the ops independently of the conversion patterns. This is also in line with how conversion passes are structured upstream in the llvm-project repo, such as https://github.com/llvm/llvm-project/blob/000e790be35b77a01872851646d54432a203542c/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp#L718

A short usage sketch of the new entry points is included after the diff below.

--- .../Conversion/TorchToTosa/TorchToTosa.h | 15 +- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 456 +++++++++--------- 2 files changed, 249 insertions(+), 222 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h index a6d774a64db1..221745b1c26e 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h @@ -12,12 +12,25 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include namespace mlir { namespace torch { + +/// Collect a set of legal/illegal ops for converting Torch operations to Tosa +/// dialect.
+void populateTorchToTosaConversionLegalOps(ConversionTarget &target); + +/// Collect a set of patterns to convert Torch operations to Tosa dialect + +/// return the set of illegalOps +std::set +populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter, + RewritePatternSet &patterns); + std::unique_ptr> createConvertTorchToTosaPass(); -} +} // namespace torch } // namespace mlir #endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 9572723fdd29..1c05ae49e18b 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -8277,342 +8277,356 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { ConversionTarget target(*context); target.addLegalDialect(); + target.addIllegalDialect(); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); - // The following ops are never the primary reason why lowering fails. - // The backend contract only allows functions to return tensors thus there - // is always another op using them. - // When we have a chain of torch.constant.int followed by a unsupported - // torch op, we want the pass to mention the unsupported torch op - // in the error message. - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addIllegalDialect(); + populateTorchToTosaConversionLegalOps(target); RewritePatternSet patterns(context); + auto illegalOps = populateTorchToTosaConversionPatternsAndIllegalOps( + typeConverter, patterns); + + for (auto op : illegalOps) { + target.addIllegalOp(OperationName(op, context)); + } + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +void torch::populateTorchToTosaConversionLegalOps(ConversionTarget &target) { + // The following ops are never the primary reason why lowering fails. + // The backend contract only allows functions to return tensors thus there + // is always another op using them. + // When we have a chain of torch.constant.int followed by a unsupported + // torch op, we want the pass to mention the unsupported torch op + // in the error message. 
+ target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); +} + +std::set torch::populateTorchToTosaConversionPatternsAndIllegalOps( + TypeConverter &typeConverter, RewritePatternSet &patterns) { + + MLIRContext *context = patterns.getContext(); + std::set illegalOps; + #define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, tosa::LogOp) - INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, tosa::ExpOp) + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, tosa::LogOp) + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, tosa::ExpOp) #undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN #define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) - INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) - INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) - INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) - INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp) - INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp) - INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp) - INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp) - INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp) + INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) + INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) + INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) + INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) + INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp) + INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp) + INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp) + INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp) + INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp) #undef INSERT_UNARY_PATTERN #define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) - INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) - INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) - INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp) - INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp) - INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, - tosa::LogicalLeftShiftOp) - INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp, - tosa::ArithmeticRightShiftOp) + INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) + INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) + INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) + INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp) + INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp) + INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, tosa::LogicalLeftShiftOp) + INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp, + tosa::ArithmeticRightShiftOp) #undef INSERT_BINARY_PATTERN #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp) + 
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp) #undef INSERT_BINARY_ADDSUB_PATTERN #define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp) #undef INSERT_BINARY_COMPARE_PATTERN #define INSERT_BINARY_MUL_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp); - INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp); + INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp); + INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp); #undef INSERT_BINARY_MUL_PATTERN #define INSERT_BINARY_DIV_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp); - INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp); - INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp); - INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp); + INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp); + INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp); + INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp); + 
INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp); #undef INSERT_BINARY_DIV_PATTERN #define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp); #undef INSERT_REMAINDER_FMOD_OP_PATTERN #define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>( \ typeConverter, context); - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp, - mlir::tosa::convertReduceMeanOp) - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, - mlir::tosa::convertReduceSumOp) - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp, - mlir::tosa::convertLinalgVectorNormOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp, + mlir::tosa::convertReduceMeanOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, + mlir::tosa::convertReduceSumOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp, + mlir::tosa::convertLinalgVectorNormOp) #undef INSERT_NDIMS_REDUCTION_OP_PATTERN #define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>( \ typeConverter, context); - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp, - mlir::tosa::convertReduceAnyOp) - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp, - mlir::tosa::convertReduceAllOp) - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp, - mlir::tosa::convertReduceProdOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp, + mlir::tosa::convertReduceAnyOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp, + mlir::tosa::convertReduceAllOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp, + mlir::tosa::convertReduceProdOp) #undef INSERT_ONEDIM_REDUCTION_OP_PATTERN #define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>( \ typeConverter, context); - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, - mlir::tosa::convertReduceAllOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, - mlir::tosa::convertReduceAnyOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, - mlir::tosa::convertReduceSumOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, - mlir::tosa::convertReduceMaxOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, - mlir::tosa::convertReduceMinOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp, - mlir::tosa::convertReduceProdOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, mlir::tosa::convertReduceAllOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, mlir::tosa::convertReduceAnyOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, mlir::tosa::convertReduceSumOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, mlir::tosa::convertReduceMaxOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, mlir::tosa::convertReduceMinOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp, + mlir::tosa::convertReduceProdOp) #undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN #define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, 
TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp); - INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp); + INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp); + INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp); #undef INSERT_INDICES_REDUCTION_OP_PATTERN #define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp) - INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp) + INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp) + INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp) #undef INSERT_SQUEEZE_OP_PATTERN #define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); + INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); #undef INSERT_MATMUL_ATEMOP_PATTERN #define INSERT_MM_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_MM_ATENOP_PATTERN(AtenMmOp); - INSERT_MM_ATENOP_PATTERN(AtenBmmOp); + INSERT_MM_ATENOP_PATTERN(AtenMmOp); + INSERT_MM_ATENOP_PATTERN(AtenBmmOp); #undef INSERT_MM_ATEMOP_PATTERN #define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); + INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); #undef INSERT_LINEAR_ATEMOP_PATTERN #define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp, - tosa::AvgPool2dOp); + INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp, + tosa::AvgPool2dOp); #undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenMaxPool2dOp::getOperationName()); + patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenMaxPool1dOp::getOperationName()); + patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenAvgPool2dOp::getOperationName()); + patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenAvgPool1dOp::getOperationName()); + patterns.add(typeConverter, context); #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); - INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); - INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0); + INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); + INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); + INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0); #undef INSERT_CONSTANT_FILL_PATTERN #define INSERT_FILL_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + 
illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_FILL_PATTERN(AtenFill_ScalarOp); - INSERT_FILL_PATTERN(AtenFillScalarOp); - INSERT_FILL_PATTERN(AtenFillTensorOp); + INSERT_FILL_PATTERN(AtenFill_ScalarOp); + INSERT_FILL_PATTERN(AtenFillScalarOp); + INSERT_FILL_PATTERN(AtenFillTensorOp); #undef INSERT_FILL_PATTERN #define INSERT_MASKED_FILL_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp); - INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp); + INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp); + INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp); #undef INSERT_MASKED_FILL_PATTERN #define INSERT_POW_OP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp); - INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp); - INSERT_POW_OP_PATTERN(AtenPowScalarOp); + INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp); + INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp); + INSERT_POW_OP_PATTERN(AtenPowScalarOp); #undef INSERT_POW_OP_PATTERN +#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dOp); + INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dVecOp); +#undef INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN + #define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp); - INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp); - INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp); #undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN -#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dOp); - INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dVecOp); -#undef INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN - #define INSERT_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp); - INSERT_ATENOP_PATTERN(AtenReluOp); - INSERT_ATENOP_PATTERN(AtenLeakyReluOp); - INSERT_ATENOP_PATTERN(AtenArgmaxOp); - INSERT_ATENOP_PATTERN(AtenRsubScalarOp); - INSERT_ATENOP_PATTERN(AtenConvolutionOp); - INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); - INSERT_ATENOP_PATTERN(AtenReshapeOp); - INSERT_ATENOP_PATTERN(AtenBatchNormOp); - INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); - INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp); - INSERT_ATENOP_PATTERN(AtenUnflattenIntOp); - INSERT_ATENOP_PATTERN(AtenPermuteOp); - INSERT_ATENOP_PATTERN(AtenLog2Op); - INSERT_ATENOP_PATTERN(AtenThresholdOp); - INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); - INSERT_ATENOP_PATTERN(AtenContiguousOp); - INSERT_ATENOP_PATTERN(AtenDropoutOp); - INSERT_ATENOP_PATTERN(AtenViewOp); - 
INSERT_ATENOP_PATTERN(AtenGeluOp); - INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); - INSERT_ATENOP_PATTERN(AtenEmbeddingOp); - INSERT_ATENOP_PATTERN(AtenTransposeIntOp); - INSERT_ATENOP_PATTERN(AtenSliceTensorOp); - INSERT_ATENOP_PATTERN(AtenBroadcastToOp); - INSERT_ATENOP_PATTERN(AtenGatherOp); - INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp); - INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); - INSERT_ATENOP_PATTERN(AtenAbsOp); - INSERT_ATENOP_PATTERN(AtenWhereSelfOp); - INSERT_ATENOP_PATTERN(AtenClampOp); - INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); - INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); - INSERT_ATENOP_PATTERN(AtenCopyOp); - INSERT_ATENOP_PATTERN(AtenToDtypeOp); - INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); - INSERT_ATENOP_PATTERN(AtenCatOp); - INSERT_ATENOP_PATTERN(AtenSqrtOp); - INSERT_ATENOP_PATTERN(AtenIscloseOp); - INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); - INSERT_ATENOP_PATTERN(AtenTrilOp); - INSERT_ATENOP_PATTERN(AtenDiagonalOp); - INSERT_ATENOP_PATTERN(AtenIndexSelectOp); - INSERT_ATENOP_PATTERN(AtenFlipOp); - INSERT_ATENOP_PATTERN(AtenRoundOp); - INSERT_ATENOP_PATTERN(AtenScatterSrcOp); - INSERT_ATENOP_PATTERN(AtenSliceScatterOp); - INSERT_ATENOP_PATTERN(AtenDiagEmbedOp); - INSERT_ATENOP_PATTERN(AtenUniformOp); - INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp); - INSERT_ATENOP_PATTERN(AtenAsStridedOp); - INSERT_ATENOP_PATTERN(AtenClampTensorOp); - INSERT_ATENOP_PATTERN(PrimsCollapseOp); - INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp); - INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp); - INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp); - INSERT_ATENOP_PATTERN(PrimsSplitDimOp); - INSERT_ATENOP_PATTERN(AtenOuterOp); - INSERT_ATENOP_PATTERN(AtenLogitOp); - INSERT_ATENOP_PATTERN(AtenLog1pOp); - INSERT_ATENOP_PATTERN(AtenLog10Op); - INSERT_ATENOP_PATTERN(AtenTanOp); + INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp); + INSERT_ATENOP_PATTERN(AtenReluOp); + INSERT_ATENOP_PATTERN(AtenLeakyReluOp); + INSERT_ATENOP_PATTERN(AtenArgmaxOp); + INSERT_ATENOP_PATTERN(AtenRsubScalarOp); + INSERT_ATENOP_PATTERN(AtenConvolutionOp); + INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); + INSERT_ATENOP_PATTERN(AtenReshapeOp); + INSERT_ATENOP_PATTERN(AtenBatchNormOp); + INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); + INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp); + INSERT_ATENOP_PATTERN(AtenUnflattenIntOp); + INSERT_ATENOP_PATTERN(AtenPermuteOp); + INSERT_ATENOP_PATTERN(AtenLog2Op); + INSERT_ATENOP_PATTERN(AtenThresholdOp); + INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); + INSERT_ATENOP_PATTERN(AtenContiguousOp); + INSERT_ATENOP_PATTERN(AtenDropoutOp); + INSERT_ATENOP_PATTERN(AtenViewOp); + INSERT_ATENOP_PATTERN(AtenGeluOp); + INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); + INSERT_ATENOP_PATTERN(AtenEmbeddingOp); + INSERT_ATENOP_PATTERN(AtenTransposeIntOp); + INSERT_ATENOP_PATTERN(AtenSliceTensorOp); + INSERT_ATENOP_PATTERN(AtenBroadcastToOp); + INSERT_ATENOP_PATTERN(AtenGatherOp); + INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp); + INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); + INSERT_ATENOP_PATTERN(AtenAbsOp); + INSERT_ATENOP_PATTERN(AtenWhereSelfOp); + INSERT_ATENOP_PATTERN(AtenClampOp); + INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); + INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); + INSERT_ATENOP_PATTERN(AtenCopyOp); + INSERT_ATENOP_PATTERN(AtenToDtypeOp); + INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); + INSERT_ATENOP_PATTERN(AtenCatOp); + INSERT_ATENOP_PATTERN(AtenSqrtOp); + INSERT_ATENOP_PATTERN(AtenIscloseOp); + INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); + 
INSERT_ATENOP_PATTERN(AtenTrilOp); + INSERT_ATENOP_PATTERN(AtenDiagonalOp); + INSERT_ATENOP_PATTERN(AtenIndexSelectOp); + INSERT_ATENOP_PATTERN(AtenFlipOp); + INSERT_ATENOP_PATTERN(AtenRoundOp); + INSERT_ATENOP_PATTERN(AtenScatterSrcOp); + INSERT_ATENOP_PATTERN(AtenSliceScatterOp); + INSERT_ATENOP_PATTERN(AtenDiagEmbedOp); + INSERT_ATENOP_PATTERN(AtenUniformOp); + INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp); + INSERT_ATENOP_PATTERN(AtenAsStridedOp); + INSERT_ATENOP_PATTERN(AtenClampTensorOp); + INSERT_ATENOP_PATTERN(PrimsCollapseOp); + INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp); + INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp); + INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp); + INSERT_ATENOP_PATTERN(PrimsSplitDimOp); + INSERT_ATENOP_PATTERN(AtenOuterOp); + INSERT_ATENOP_PATTERN(AtenLogitOp); + INSERT_ATENOP_PATTERN(AtenLog1pOp); + INSERT_ATENOP_PATTERN(AtenLog10Op); + INSERT_ATENOP_PATTERN(AtenTanOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp); + INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp); #undef INSERT_CLONE_ATENOP_PATTERN - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) - return signalPassFailure(); - } -}; -} // namespace + return illegalOps; +} std::unique_ptr> mlir::torch::createConvertTorchToTosaPass() {
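
For reference, below is a minimal sketch of how a downstream pass could consume the two new entry points while overriding the legality of an individual op (the `AtenMaxDimOp` scenario described above). This is not part of the patch: the function name `runPartialTorchToTosa`, the exact include list, the legal-dialect set, and the `std::set<StringRef>` element type are assumptions modeled on the existing `ConvertTorchToTosa` pass.

```cpp
// Illustrative sketch only: reuse the new populate* entry points, but keep
// torch.aten.max.dim legal so a later TorchToLinalg run can lower it instead.
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include <set>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

static LogicalResult runPartialTorchToTosa(func::FuncOp func) {
  MLIRContext *context = func.getContext();

  ConversionTarget target(*context);
  // Assumed legal-dialect set, mirroring the in-tree ConvertTorchToTosa pass.
  target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
                         arith::ArithDialect>();

  TypeConverter typeConverter;
  typeConverter.addConversion([](Type type) { return type; });
  TorchConversion::setupBackendTypeConversion(target, typeConverter);

  // Ops such as torch.constant.int that remain legal after this pass.
  populateTorchToTosaConversionLegalOps(target);

  RewritePatternSet patterns(context);
  std::set<StringRef> illegalOps =
      populateTorchToTosaConversionPatternsAndIllegalOps(typeConverter,
                                                         patterns);

  // Override legality independently of the patterns: keep AtenMaxDimOp legal
  // here because its TOSA lowering cannot handle multiple dynamic dimensions;
  // a later TorchToLinalg pass is expected to pick it up instead.
  illegalOps.erase(AtenMaxDimOp::getOperationName());

  for (StringRef opName : illegalOps)
    target.addIllegalOp(OperationName(opName, context));

  // Partial conversion: ops left legal (or unknown) simply survive this pass.
  return applyPartialConversion(func.getOperation(), target,
                                std::move(patterns));
}
```

The key line is the `illegalOps.erase(...)`: with this refactoring, legality becomes data the caller can edit before it is applied to the `ConversionTarget`, which is not possible when `addIllegalOp` is baked into each pattern-registration macro.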