diff --git a/services/webnn/BUILD.gn b/services/webnn/BUILD.gn
index b0f27b4cd33b00..e936abe6f78e82 100644
--- a/services/webnn/BUILD.gn
+++ b/services/webnn/BUILD.gn
@@ -95,6 +95,7 @@ component("webnn_service") {
     deps += [
       "//third_party/fp16",
       "//third_party/microsoft_dxheaders:dxguids",
+      "//third_party/onnxruntime_headers",
      "//ui/gl",
       "//ui/gl/init",
     ]
diff --git a/services/webnn/ort/allocator_ort.h b/services/webnn/ort/allocator_ort.h
index 8b48692b810875..d5a966df9ff2c5 100644
--- a/services/webnn/ort/allocator_ort.h
+++ b/services/webnn/ort/allocator_ort.h
@@ -7,7 +7,7 @@
 
 #include "base/component_export.h"
 #include "base/memory/ref_counted.h"
-#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
+#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h"
 
 namespace webnn::ort {
diff --git a/services/webnn/ort/context_impl_ort.h b/services/webnn/ort/context_impl_ort.h
index f77f44888cab49..e6559e630f49b7 100644
--- a/services/webnn/ort/context_impl_ort.h
+++ b/services/webnn/ort/context_impl_ort.h
@@ -12,7 +12,7 @@
 #include "services/webnn/webnn_constant_operand.h"
 #include "services/webnn/webnn_context_impl.h"
 #include "services/webnn/webnn_graph_impl.h"
-#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
+#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h"
 
 namespace webnn::ort {
diff --git a/services/webnn/ort/graph_builder_ort.cc b/services/webnn/ort/graph_builder_ort.cc
index c69c0930a92b30..d1addc0ffb4919 100644
--- a/services/webnn/ort/graph_builder_ort.cc
+++ b/services/webnn/ort/graph_builder_ort.cc
@@ -247,7 +247,8 @@ void GraphBuilderOrt::AddInitializer(uint64_t constant_id) {
       operand.ByteSpan(), operand_info.onnx_data_type);
 
-  CHECK(result_->id_to_operand_info.try_emplace(constant_id, std::move(operand_info))
+  CHECK(result_->id_to_operand_info
+            .try_emplace(constant_id, std::move(operand_info))
             .second);
 }
 
@@ -411,12 +412,13 @@ void GraphBuilderOrt::AddCastOperation(const mojom::ElementWiseUnary& cast) {
 
   int64_t to_data_type = static_cast<int64_t>(
       OperandTypeToONNXTensorElementDataType(output_data_type));
-  ScopedOrtOpAttr attr_to;
+  ScopedOrtOpAttrPtr attr_to;
   model_builder_.CreateAttribute(attr_to, /*name=*/"to", to_data_type);
 
-  std::array<OrtOpAttr**, 1> attributes = {attr_to.get_pptr()};
+  std::array<OrtOpAttr*, 1> attributes = {attr_to};
 
-  model_builder_.AddNode(kOpTypeCast, node_name, input_names, output_names, attributes);
+  model_builder_.AddNode(kOpTypeCast, node_name, input_names, output_names,
+                         attributes);
 }
 
 void GraphBuilderOrt::AddClampOperation(const mojom::Clamp& clamp) {
@@ -487,11 +489,12 @@ void GraphBuilderOrt::AddConv2dOperation(const mojom::Conv2d& conv2d) {
   std::array<int64_t, 2> dilations = {
       base::checked_cast<int64_t>(conv2d.dilations->height),
       base::checked_cast<int64_t>(conv2d.dilations->width)};
-  ScopedOrtOpAttr attr_dilations;
-  model_builder_.CreateAttribute(attr_dilations, /*name=*/"dilations", dilations);
+  ScopedOrtOpAttrPtr attr_dilations;
+  model_builder_.CreateAttribute(attr_dilations, /*name=*/"dilations",
+                                 dilations);
 
   int64_t group = base::checked_cast<int64_t>(conv2d.groups);
-  ScopedOrtOpAttr attr_group;
+  ScopedOrtOpAttrPtr attr_group;
   model_builder_.CreateAttribute(attr_group, /*name=*/"group", group);
 
   std::array<int64_t, 4> pads = {
@@ -499,20 +502,20 @@ void GraphBuilderOrt::AddConv2dOperation(const mojom::Conv2d& conv2d) {
       base::checked_cast<int64_t>(conv2d.padding->beginning->height),
       base::checked_cast<int64_t>(conv2d.padding->beginning->width),
       base::checked_cast<int64_t>(conv2d.padding->ending->height),
       base::checked_cast<int64_t>(conv2d.padding->ending->width)};
-  ScopedOrtOpAttr attr_pads;
+  ScopedOrtOpAttrPtr attr_pads;
   model_builder_.CreateAttribute(attr_pads, /*name=*/"pads", pads);
 
   std::array<int64_t, 2> strides = {
       base::checked_cast<int64_t>(conv2d.strides->height),
       base::checked_cast<int64_t>(conv2d.strides->width)};
-  ScopedOrtOpAttr attr_strides;
+  ScopedOrtOpAttrPtr attr_strides;
   model_builder_.CreateAttribute(attr_strides, /*name=*/"strides", strides);
 
-  std::array<OrtOpAttr**, 4> attributes = {
-      attr_dilations.get_pptr(),
-      attr_group.get_pptr(),
-      attr_pads.get_pptr(),
-      attr_strides.get_pptr(),
+  std::array<OrtOpAttr*, 4> attributes = {
+      attr_dilations,
+      attr_group,
+      attr_pads,
+      attr_strides,
   };
 
   model_builder_.AddNode(kOpTypeConv2d, node_name, input_names, output_names,
                          attributes);
@@ -535,22 +538,21 @@ void GraphBuilderOrt::AddGemmOperation(const mojom::Gemm& gemm) {
   }
   std::array<const char*, 1> output_names = {output_name.c_str()};
 
-  ScopedOrtOpAttr attr_alpha;
+  ScopedOrtOpAttrPtr attr_alpha;
   model_builder_.CreateAttribute(attr_alpha, /*name=*/"alpha", gemm.alpha);
 
-  ScopedOrtOpAttr attr_beta;
+  ScopedOrtOpAttrPtr attr_beta;
   model_builder_.CreateAttribute(attr_beta, /*name=*/"beta", gemm.beta);
 
   int64_t trans_a = static_cast<int64_t>(gemm.a_transpose);
-  ScopedOrtOpAttr attr_transA;
+  ScopedOrtOpAttrPtr attr_transA;
   model_builder_.CreateAttribute(attr_transA, /*name=*/"transA", trans_a);
 
   int64_t trans_b = static_cast<int64_t>(gemm.b_transpose);
-  ScopedOrtOpAttr attr_transB;
+  ScopedOrtOpAttrPtr attr_transB;
   model_builder_.CreateAttribute(attr_transB, /*name=*/"transB", trans_b);
 
-  std::array<OrtOpAttr**, 4> attributes = {
-      attr_alpha.get_pptr(), attr_beta.get_pptr(), attr_transA.get_pptr(),
-      attr_transB.get_pptr()};
+  std::array<OrtOpAttr*, 4> attributes = {attr_alpha, attr_beta, attr_transA,
+                                          attr_transB};
 
   model_builder_.AddNode(kOpTypeGemm, node_name, input_names, output_names,
                          attributes);
@@ -576,20 +578,20 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) {
   std::array<int64_t, 2> dilations = {
       base::checked_cast<int64_t>(pool2d.dilations->height),
       base::checked_cast<int64_t>(pool2d.dilations->width)};
-  ScopedOrtOpAttr attr_dilations;
+  ScopedOrtOpAttrPtr attr_dilations;
   model_builder_.CreateAttribute(attr_dilations, /*name=*/"dilations",
                                  dilations);
 
   std::array<int64_t, 2> strides = {
       base::checked_cast<int64_t>(pool2d.strides->height),
       base::checked_cast<int64_t>(pool2d.strides->width)};
-  ScopedOrtOpAttr attr_strides;
+  ScopedOrtOpAttrPtr attr_strides;
   model_builder_.CreateAttribute(attr_strides, /*name=*/"strides", strides);
 
   std::array<int64_t, 2> window_dimensions = {
       base::checked_cast<int64_t>(pool2d.window_dimensions->height),
       base::checked_cast<int64_t>(pool2d.window_dimensions->width)};
-  ScopedOrtOpAttr attr_kernel_shape;
+  ScopedOrtOpAttrPtr attr_kernel_shape;
   model_builder_.CreateAttribute(attr_kernel_shape, /*name=*/"kernel_shape",
                                  window_dimensions);
 
@@ -600,7 +602,7 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) {
       base::checked_cast<int64_t>(pool2d.padding->beginning->width),
       base::checked_cast<int64_t>(pool2d.padding->ending->height),
       base::checked_cast<int64_t>(pool2d.padding->ending->width)};
-  ScopedOrtOpAttr attr_pads;
+  ScopedOrtOpAttrPtr attr_pads;
   model_builder_.CreateAttribute(attr_pads, /*name=*/"pads", pads);
 
   // Calculate the ceil_mode.
@@ -626,12 +628,12 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) {
   CHECK(float_output_height.has_value());
   int64_t ceil_mode = float_output_height.value() < output_height ? 1 : 0;
-  ScopedOrtOpAttr attr_ceil_mode;
+  ScopedOrtOpAttrPtr attr_ceil_mode;
   model_builder_.CreateAttribute(attr_ceil_mode, /*name=*/"ceil_mode",
                                  ceil_mode);
 
   // P value of the Lp norm used to pool over the input data.
-  std::optional<ScopedOrtOpAttr> attr_p;
+  std::optional<ScopedOrtOpAttrPtr> attr_p;
   std::optional<int64_t> p;
   std::string op_type;
   switch (pool2d.kind) {
@@ -652,14 +654,13 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) {
     }
   }
 
-  std::vector<OrtOpAttr**> attributes = {
-      attr_dilations.get_pptr(), attr_strides.get_pptr(),
-      attr_kernel_shape.get_pptr(), attr_pads.get_pptr(),
-      attr_ceil_mode.get_pptr()};
+  std::vector<OrtOpAttr*> attributes = {attr_dilations, attr_strides,
+                                        attr_kernel_shape, attr_pads,
+                                        attr_ceil_mode};
   if (op_type == kOpTypeLpPool2d) {
     CHECK(attr_p.has_value());
     CHECK(p.has_value());
-    attributes.push_back(attr_p.value().get_pptr());
+    attributes.push_back(attr_p.value());
   }
 
   const std::string node_name = GetNodeName(pool2d.label);
@@ -710,11 +711,12 @@ void GraphBuilderOrt::AddSoftmaxOperation(const mojom::Softmax& softmax) {
   std::array<const char*, 1> output_names = {output_name.c_str()};
 
   int64_t axis = static_cast<int64_t>(softmax.axis);
-  ScopedOrtOpAttr attr_axis;
+  ScopedOrtOpAttrPtr attr_axis;
   model_builder_.CreateAttribute(attr_axis, /*name=*/"axis", axis);
 
-  std::array<OrtOpAttr**, 1> attributes = {attr_axis.get_pptr()};
-  model_builder_.AddNode(kOpTypeSoftmax, node_name, input_names, output_names, attributes);
+  std::array<OrtOpAttr*, 1> attributes = {attr_axis};
+  model_builder_.AddNode(kOpTypeSoftmax, node_name, input_names, output_names,
+                         attributes);
 }
 
 void GraphBuilderOrt::AddTransposeOperation(const mojom::Transpose& transpose) {
@@ -727,10 +729,10 @@ void GraphBuilderOrt::AddTransposeOperation(const mojom::Transpose& transpose) {
 
   std::vector<int64_t> permutation(transpose.permutation.begin(),
                                    transpose.permutation.end());
-  ScopedOrtOpAttr attr_perm;
+  ScopedOrtOpAttrPtr attr_perm;
   model_builder_.CreateAttribute(attr_perm, /*name=*/"perm", permutation);
 
-  std::array<OrtOpAttr**, 1> attributes = {attr_perm.get_pptr()};
+  std::array<OrtOpAttr*, 1> attributes = {attr_perm};
   model_builder_.AddNode(kOpTypeTranspose, node_name, input_names,
                          output_names, attributes);
 }
@@ -749,12 +751,12 @@ void GraphBuilderOrt::AddWhereOperation(const mojom::Where& where) {
   std::array<const char*, 1> cast_output_names = {
       cast_node_output_name.c_str()};
 
-  ScopedOrtOpAttr attr_to;
+  ScopedOrtOpAttrPtr attr_to;
   int64_t to_data_type =
       static_cast<int64_t>(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
   model_builder_.CreateAttribute(attr_to, /*name=*/"to", to_data_type);
 
-  std::array<OrtOpAttr**, 1> cast_attributes = {attr_to.get_pptr()};
+  std::array<OrtOpAttr*, 1> cast_attributes = {attr_to};
   model_builder_.AddNode(kOpTypeCast, cast_node_name, cast_input_names,
                          cast_output_names, cast_attributes);
   next_operand_id_++;
diff --git a/services/webnn/ort/graph_builder_ort.h b/services/webnn/ort/graph_builder_ort.h
index 07c1286c06c8ee..adb2c130eb0b78 100644
--- a/services/webnn/ort/graph_builder_ort.h
+++ b/services/webnn/ort/graph_builder_ort.h
@@ -17,13 +17,13 @@
 #include "base/memory/stack_allocated.h"
 #include "base/types/expected.h"
 #include "services/webnn/ort/allocator_ort.h"
+#include "services/webnn/ort/ort_model_builder.h"
 #include "services/webnn/ort/scoped_ort_types.h"
 #include "services/webnn/public/cpp/context_properties.h"
 #include "services/webnn/public/cpp/operand_descriptor.h"
 #include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
 #include "services/webnn/public/mojom/webnn_error.mojom-forward.h"
 #include "services/webnn/public/mojom/webnn_graph.mojom.h"
-#include "services/webnn/ort/ort_model_builder.h"
 
 namespace webnn {
 
@@ -62,7 +62,7 @@ class GraphBuilderOrt {
     const OperandInfo& GetOperandInfo(uint64_t operand_id) const;
 
     std::map<uint64_t, OperandInfo> id_to_operand_info;
-    
+
     std::unique_ptr<OrtModelBuilder::ModelInfo> model_info;
   };
diff --git a/services/webnn/ort/graph_impl_ort.cc b/services/webnn/ort/graph_impl_ort.cc
index faad950c208a98..ab0e1b8d373cee 100644
--- a/services/webnn/ort/graph_impl_ort.cc
+++ b/services/webnn/ort/graph_impl_ort.cc
@@ -22,6 +22,7 @@
 #include "services/webnn/resource_task.h"
 #include "services/webnn/webnn_constant_operand.h"
 #include "services/webnn/webnn_graph_impl.h"
+#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/providers/dml/dml_provider_factory.h"
 
 namespace webnn::ort {
 
@@ -132,8 +133,8 @@ GraphImplOrt::CreateAndBuildOnBackgroundThread(
   OrtSession* session;
   const OrtEnv* env = allocator->env();
-  OrtStatus* status = GetOrtGraphApi()->CreateSessionFromModel(
-      env, result->model_info->model.get_ptr(), session_options, &session);
+  OrtStatus* status = GetOrtModelBuilderApi()->CreateSessionFromModel(
+      env, result->model_info->model, session_options, &session);
   ort_api->ReleaseSessionOptions(session_options);
 
   if (status != NULL) {
diff --git a/services/webnn/ort/graph_impl_ort.h b/services/webnn/ort/graph_impl_ort.h
index 7b58acc694bdcb..12dd70dce8e016 100644
--- a/services/webnn/ort/graph_impl_ort.h
+++ b/services/webnn/ort/graph_impl_ort.h
@@ -22,7 +22,7 @@
 #include "services/webnn/queueable_resource_state.h"
 #include "services/webnn/webnn_context_impl.h"
 #include "services/webnn/webnn_graph_impl.h"
-#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
+#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h"
 
 namespace webnn {
diff --git a/services/webnn/ort/ort_model_builder.cc b/services/webnn/ort/ort_model_builder.cc
index 2bd05b2ed57cf4..876a63899fd523 100644
--- a/services/webnn/ort/ort_model_builder.cc
+++ b/services/webnn/ort/ort_model_builder.cc
@@ -22,39 +22,45 @@ namespace ort {
 OrtModelBuilder::ModelInfo::ModelInfo() = default;
 OrtModelBuilder::ModelInfo::~ModelInfo() = default;
 
+ScopedOrtValueInfoPtr CreateOrtValueInfo(std::string_view name,
+                                         base::span<const int64_t> shape,
+                                         ONNXTensorElementDataType data_type) {
+  ScopedOrtTensorTypeAndShapeInfoPtr tensor_type_and_shape_info;
+  CHECK_STATUS(GetOrtApi()->CreateTensorTypeAndShapeInfo(
+      tensor_type_and_shape_info.GetAddressOf()));
+  CHECK_STATUS(
+      GetOrtApi()->SetTensorElementType(tensor_type_and_shape_info, data_type));
+  CHECK_STATUS(GetOrtApi()->SetDimensions(tensor_type_and_shape_info,
+                                          shape.data(), shape.size()));
+
+  ScopedOrtTypeInfoPtr type_info;
+  CHECK_STATUS(GetOrtApi()->CreateTensorTypeInfo(tensor_type_and_shape_info,
+                                                 type_info.GetAddressOf()));
+
+  ScopedOrtValueInfoPtr value_info;
+  CHECK_STATUS(GetOrtModelBuilderApi()->CreateValueInfo(
+      name.data(), type_info, value_info.GetAddressOf()));
+  return value_info;
+}
+
 OrtModelBuilder::OrtModelBuilder(scoped_refptr<AllocatorOrt> allocator)
     : allocator_(std::move(allocator)),
       model_info_(std::make_unique<ModelInfo>()) {
-  CHECK_STATUS(GetOrtGraphApi()->CreateGraph(graph_.get_pptr()));
+  CHECK_STATUS(GetOrtModelBuilderApi()->CreateGraph(graph_.GetAddressOf()));
 }
+
 OrtModelBuilder::~OrtModelBuilder() = default;
 
 void OrtModelBuilder::AddInput(std::string_view name,
                                base::span<const int64_t> shape,
                                ONNXTensorElementDataType data_type) {
-  ScopedOrtShape input_shape;
-  CHECK_STATUS(GetOrtGraphApi()->CreateFixedShape(shape.data(), shape.size(),
-                                                  input_shape.get_pptr()));
-
-  ScopedOrtValueInfo input_info;
-  CHECK_STATUS(GetOrtGraphApi()->CreateTensorValueInfo(
-      name.data(), data_type, input_shape.get_pptr(), input_info.get_pptr()));
-  CHECK_STATUS(
-      GetOrtGraphApi()->AddInput(graph_.get_ptr(), input_info.get_pptr()));
+  inputs_.push_back(CreateOrtValueInfo(name, shape, data_type));
 }
 
 void OrtModelBuilder::AddOutput(std::string_view name,
                                 base::span<const int64_t> shape,
                                 ONNXTensorElementDataType data_type) {
-  ScopedOrtShape output_shape;
-  CHECK_STATUS(GetOrtGraphApi()->CreateFixedShape(shape.data(), shape.size(),
-                                                  output_shape.get_pptr()));
-
-  ScopedOrtValueInfo output_info;
-  CHECK_STATUS(GetOrtGraphApi()->CreateTensorValueInfo(
-      name.data(), data_type, output_shape.get_pptr(), output_info.get_pptr()));
-  CHECK_STATUS(
-      GetOrtGraphApi()->AddOutput(graph_.get_ptr(), output_info.get_pptr()));
+  outputs_.push_back(CreateOrtValueInfo(name, shape, data_type));
 }
 
 void OrtModelBuilder::AddInitializerAsRawData(
@@ -62,20 +68,22 @@ void OrtModelBuilder::AddInitializerAsRawData(
     base::span<const int64_t> shape,
     base::span<const uint8_t> data,
     ONNXTensorElementDataType data_type) {
-  ScopedOrtValue initializer;
+  ScopedOrtValuePtr initializer;
   CHECK_STATUS(GetOrtApi()->CreateTensorAsOrtValue(
       allocator_->allocator(), shape.data(), shape.size(), data_type,
-      initializer.get_pptr()));
+      initializer.GetAddressOf()));
 
   void* ort_tensor_raw_data = nullptr;
-  CHECK_STATUS(GetOrtApi()->GetTensorMutableData(initializer.get_ptr(),
-                                                 &ort_tensor_raw_data));
+  CHECK_STATUS(
+      GetOrtApi()->GetTensorMutableData(initializer, &ort_tensor_raw_data));
   CHECK(ort_tensor_raw_data);
   UNSAFE_BUFFERS(
       base::span(static_cast<uint8_t*>(ort_tensor_raw_data), data.size()))
       .copy_from(data);
-  CHECK_STATUS(GetOrtGraphApi()->AddInitializer(graph_.get_ptr(), name.data(),
-                                                initializer.get_pptr()));
+  // Graph will own the initializer.
+  CHECK_STATUS(GetOrtModelBuilderApi()->AddInitializerToGraph(
+      graph_, name.data(), initializer.Release(),
+      /*data_is_external=*/false));
 }
 
 void OrtModelBuilder::AddInitializerAsExternalData(
@@ -86,49 +94,52 @@ void OrtModelBuilder::AddInitializerAsExternalData(
     base::span<const int64_t> shape,
     base::span<const uint8_t> data,
     ONNXTensorElementDataType data_type) {
   auto weight = base::HeapArray<uint8_t>::CopiedFrom(data);
   model_info_->external_data.push_back(std::move(weight));
 
-  ScopedOrtValue initializer;
+  ScopedOrtValuePtr initializer;
+  // TODO: Use `CreateTensorWithDataAndDeleterAsOrtValue()`.
   CHECK_STATUS(GetOrtApi()->CreateTensorWithDataAsOrtValue(
       allocator_->memory_info(), model_info_->external_data.back().data(),
       model_info_->external_data.back().size(), shape.data(), shape.size(),
-      data_type, initializer.get_pptr()));
-  CHECK_STATUS(GetOrtGraphApi()->AddInitializer(graph_.get_ptr(), name.data(),
-                                                initializer.get_pptr()));
+      data_type, initializer.GetAddressOf()));
+  // Graph will own the initializer.
+  CHECK_STATUS(GetOrtModelBuilderApi()->AddInitializerToGraph(
+      graph_, name.data(), initializer.Release(),
+      /*data_is_external=*/true));
 }
 
-void OrtModelBuilder::CreateAttribute(ScopedOrtOpAttr& attribute,
+void OrtModelBuilder::CreateAttribute(ScopedOrtOpAttrPtr& attribute,
                                       std::string_view name,
                                       OrtOpAttrData data) {
   if (absl::holds_alternative<int64_t>(data)) {
     CHECK_STATUS(GetOrtApi()->CreateOpAttr(
         name.data(), &absl::get<int64_t>(data), /*len=*/1,
-        OrtOpAttrType::ORT_OP_ATTR_INT, attribute.get_pptr()));
+        OrtOpAttrType::ORT_OP_ATTR_INT, attribute.GetAddressOf()));
   } else if (absl::holds_alternative<float>(data)) {
     CHECK_STATUS(GetOrtApi()->CreateOpAttr(
         name.data(), &absl::get<float>(data), /*len=*/1,
-        OrtOpAttrType::ORT_OP_ATTR_FLOAT, attribute.get_pptr()));
+        OrtOpAttrType::ORT_OP_ATTR_FLOAT, attribute.GetAddressOf()));
   } else if (absl::holds_alternative<std::string_view>(data)) {
     std::string_view string_data = absl::get<std::string_view>(data);
     CHECK_STATUS(GetOrtApi()->CreateOpAttr(
         name.data(), string_data.data(), string_data.size(),
-        OrtOpAttrType::ORT_OP_ATTR_STRING, attribute.get_pptr()));
+        OrtOpAttrType::ORT_OP_ATTR_STRING, attribute.GetAddressOf()));
   } else if (absl::holds_alternative<base::span<const int64_t>>(data)) {
     base::span<const int64_t> ints_data =
         absl::get<base::span<const int64_t>>(data);
     CHECK_STATUS(GetOrtApi()->CreateOpAttr(
         name.data(), ints_data.data(), ints_data.size(),
-        OrtOpAttrType::ORT_OP_ATTR_INTS, attribute.get_pptr()));
+        OrtOpAttrType::ORT_OP_ATTR_INTS, attribute.GetAddressOf()));
   } else if (absl::holds_alternative<base::span<const float>>(data)) {
     base::span<const float> floats_data =
         absl::get<base::span<const float>>(data);
     CHECK_STATUS(GetOrtApi()->CreateOpAttr(
         name.data(), floats_data.data(), floats_data.size(),
-        OrtOpAttrType::ORT_OP_ATTR_FLOATS, attribute.get_pptr()));
+        OrtOpAttrType::ORT_OP_ATTR_FLOATS, attribute.GetAddressOf()));
   } else if (absl::holds_alternative<base::span<const char*>>(data)) {
     base::span<const char*> strings_data =
         absl::get<base::span<const char*>>(data);
     CHECK_STATUS(GetOrtApi()->CreateOpAttr(
         name.data(), strings_data.data(), strings_data.size(),
-        OrtOpAttrType::ORT_OP_ATTR_STRINGS, attribute.get_pptr()));
+        OrtOpAttrType::ORT_OP_ATTR_STRINGS, attribute.GetAddressOf()));
   }
 }
 
@@ -136,26 +147,45 @@ void OrtModelBuilder::AddNode(std::string_view op_type,
                               std::string_view node_name,
                               base::span<const char*> input_names,
                               base::span<const char*> output_names,
-                              base::span<OrtOpAttr**> attributes) {
-  ScopedOrtNode node;
-  CHECK_STATUS(GetOrtGraphApi()->CreateNode(
+                              base::span<OrtOpAttr*> attributes) {
+  ScopedOrtNodePtr node;
+  CHECK_STATUS(GetOrtModelBuilderApi()->CreateNode(
       op_type.data(), kOrtDomainName, node_name.data(), input_names.data(),
       input_names.size(), output_names.data(), output_names.size(),
-      attributes.data(), attributes.size(), node.get_pptr()));
-  CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr()));
+      attributes.data(), attributes.size(), node.GetAddressOf()));
+  // Graph will own the node.
+  CHECK_STATUS(GetOrtModelBuilderApi()->AddNodeToGraph(graph_, node.Release()));
 }
 
 std::unique_ptr<OrtModelBuilder::ModelInfo>
 OrtModelBuilder::BuildAndTakeModelInfo() {
+  // Graph will own the input/output `OrtValueInfo`s.
+  std::vector<OrtValueInfo*> graph_inputs;
+  graph_inputs.reserve(inputs_.size());
+  for (auto& input : inputs_) {
+    graph_inputs.push_back(input.Release());
+  }
+  CHECK_STATUS(GetOrtModelBuilderApi()->SetGraphInputs(
+      graph_, graph_inputs.data(), graph_inputs.size()));
+
+  std::vector<OrtValueInfo*> graph_outputs;
+  graph_outputs.reserve(outputs_.size());
+  for (auto& output : outputs_) {
+    graph_outputs.push_back(output.Release());
+  }
+  CHECK_STATUS(GetOrtModelBuilderApi()->SetGraphOutputs(
+      graph_, graph_outputs.data(), graph_outputs.size()));
+
   std::vector<const char*> domain_names = {kOrtDomainName};
   std::vector<int> opset_versions = {kOrtOpsetVersion};
-  CHECK_STATUS(GetOrtGraphApi()->CreateModel(
+  CHECK_STATUS(GetOrtModelBuilderApi()->CreateModel(
       domain_names.data(), opset_versions.data(), domain_names.size(),
-      model_info_->model.get_pptr()));
+      model_info_->model.GetAddressOf()));
 
-  CHECK_STATUS(GetOrtGraphApi()->AddGraph(model_info_->model.get_ptr(),
-                                          graph_.get_pptr()));
+  // Model will own the graph.
+  CHECK_STATUS(GetOrtModelBuilderApi()->AddGraphToModel(model_info_->model,
+                                                        graph_.Release()));
 
   return std::move(model_info_);
 }
diff --git a/services/webnn/ort/ort_model_builder.h b/services/webnn/ort/ort_model_builder.h
index 27e8360e7633a7..6cf1c6d56d559f 100644
--- a/services/webnn/ort/ort_model_builder.h
+++ b/services/webnn/ort/ort_model_builder.h
@@ -14,7 +14,7 @@
 #include "services/webnn/ort/allocator_ort.h"
 #include "services/webnn/ort/scoped_ort_types.h"
 #include "third_party/abseil-cpp/absl/types/variant.h"
-#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
+#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h"
 
 namespace webnn {
 
@@ -30,7 +30,7 @@ class OrtModelBuilder final {
     ModelInfo& operator=(const ModelInfo&) = delete;
     ~ModelInfo();
 
-    ScopedOrtModel model;
+    ScopedOrtModelPtr model;
 
     // TODO: Consider reusing constant operands instead of copying them to
     // `external_data`.
@@ -61,13 +61,14 @@ class OrtModelBuilder final {
                                     base::span<const int64_t> shape,
                                     base::span<const uint8_t> data,
                                     ONNXTensorElementDataType data_type);
+
   using OrtOpAttrData = absl::variant<int64_t,
                                       float,
                                       std::string_view,
                                       base::span<const int64_t>,
                                       base::span<const float>,
                                       base::span<const char*>>;
 
-  void CreateAttribute(ScopedOrtOpAttr& attribute,
+  void CreateAttribute(ScopedOrtOpAttrPtr& attribute,
                        std::string_view name,
                        OrtOpAttrData data);
 
@@ -75,14 +76,17 @@ class OrtModelBuilder final {
                std::string_view node_name,
                base::span<const char*> input_names,
                base::span<const char*> output_names,
-               base::span<OrtOpAttr**> attributes = {});
+               base::span<OrtOpAttr*> attributes = {});
 
   std::unique_ptr<ModelInfo> BuildAndTakeModelInfo();
 
  private:
   scoped_refptr<AllocatorOrt> allocator_;
 
-  ScopedOrtGraph graph_;
+  std::vector<ScopedOrtValueInfoPtr> inputs_;
+  std::vector<ScopedOrtValueInfoPtr> outputs_;
+
+  ScopedOrtGraphPtr graph_;
   std::unique_ptr<ModelInfo> model_info_;
 };
diff --git a/services/webnn/ort/platform_functions_ort.cc b/services/webnn/ort/platform_functions_ort.cc
index d5f5dd0da831d3..8e46cc58c65002 100644
--- a/services/webnn/ort/platform_functions_ort.cc
+++ b/services/webnn/ort/platform_functions_ort.cc
@@ -8,7 +8,6 @@
 #include "base/logging.h"
 #include "base/native_library.h"
 #include "base/path_service.h"
-#include "third_party/onnx/proto/onnx.pb.h"
 
 namespace webnn::ort {
 
@@ -39,16 +38,18 @@ PlatformFunctions::PlatformFunctions() {
     return;
   }
 
-  const OrtApi* ort_api =
-      ort_get_api_base_proc()->GetApi(onnx::Version::IR_VERSION);
+  // ORT_API_VERSION is defined in onnxruntime_c_api.h and must be passed to
+  // `OrtApiBase::GetApi()`.
+  const OrtApi* ort_api = ort_get_api_base_proc()->GetApi(ORT_API_VERSION);
   if (!ort_api) {
     LOG(ERROR) << "[WebNN] Failed to get OrtApi.";
     return;
   }
 
-  const OrtGraphApi* ort_graph_api = ort_api->GetGraphApi();
-  if (!ort_graph_api) {
-    LOG(ERROR) << "[WebNN] Failed to get OrtGraphApi.";
+  const OrtModelBuilderApi* ort_model_builder_api =
+      ort_api->GetModelBuilderApi();
+  if (!ort_model_builder_api) {
+    LOG(ERROR) << "[WebNN] Failed to get OrtModelBuilderApi.";
     return;
   }
 
@@ -56,7 +57,7 @@ PlatformFunctions::PlatformFunctions() {
   ort_library_ = std::move(ort_library);
   ort_get_api_base_proc_ = std::move(ort_get_api_base_proc);
   ort_api_ = ort_api;
-  ort_graph_api_ = ort_graph_api;
+  ort_model_builder_api_ = ort_model_builder_api;
 }
 
 PlatformFunctions::~PlatformFunctions() = default;
@@ -71,7 +72,7 @@ PlatformFunctions* PlatformFunctions::GetInstance() {
 }
 
 bool PlatformFunctions::AllFunctionsLoaded() {
-  return ort_get_api_base_proc_ && ort_api_ && ort_graph_api_;
+  return ort_get_api_base_proc_ && ort_api_ && ort_model_builder_api_;
 }
 
 }  // namespace webnn::ort
diff --git a/services/webnn/ort/platform_functions_ort.h b/services/webnn/ort/platform_functions_ort.h
index 9902a9ee704441..07549f1150b050 100644
--- a/services/webnn/ort/platform_functions_ort.h
+++ b/services/webnn/ort/platform_functions_ort.h
@@ -10,8 +10,7 @@
 #include "base/component_export.h"
 #include "base/no_destructor.h"
 #include "base/scoped_native_library.h"
-#include "third_party/microsoft_dxheaders/include/dml_provider_factory.h"
-#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
+#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h"
 
 namespace webnn::ort {
 
@@ -27,7 +26,9 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) PlatformFunctions {
     return ort_get_api_base_proc_;
   }
   const OrtApi* ort_api() const { return ort_api_.get(); }
-  const OrtGraphApi* ort_graph_api() const { return ort_graph_api_.get(); }
+  const OrtModelBuilderApi* ort_model_builder_api() const {
+    return ort_model_builder_api_.get();
+  }
 
  private:
  friend class base::NoDestructor<PlatformFunctions>;
@@ -40,7 +41,7 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) PlatformFunctions {
   base::ScopedNativeLibrary ort_library_;
   OrtGetApiBaseProc ort_get_api_base_proc_ = nullptr;
   raw_ptr<const OrtApi> ort_api_ = nullptr;
-  raw_ptr<const OrtGraphApi> ort_graph_api_ = nullptr;
+  raw_ptr<const OrtModelBuilderApi> ort_model_builder_api_ = nullptr;
 };
 
 }  // namespace webnn::ort
diff --git a/services/webnn/ort/scoped_ort_types.cc b/services/webnn/ort/scoped_ort_types.cc
index 83fece2eabae0c..d5c519fa617617 100644
--- a/services/webnn/ort/scoped_ort_types.cc
+++ b/services/webnn/ort/scoped_ort_types.cc
@@ -4,67 +4,30 @@
 
 #include "services/webnn/ort/scoped_ort_types.h"
 
-#include <memory>
-
-#include "services/webnn/ort/utils_ort.h"
-
 namespace webnn::ort {
 
-ScopedOrtValue::ScopedOrtValue() {
-  pptr_ = std::make_unique<OrtValue*>(nullptr);
-}
-ScopedOrtValue::~ScopedOrtValue() {
-  // TODO: use deleter instead.
-  GetOrtApi()->ReleaseValue(*pptr_);
-}
-
-ScopedOrtMemoryInfo::ScopedOrtMemoryInfo() {
-  pptr_ = std::make_unique<OrtMemoryInfo*>(nullptr);
-}
-ScopedOrtMemoryInfo::~ScopedOrtMemoryInfo() {
-  GetOrtApi()->ReleaseMemoryInfo(*pptr_);
-}
-
-ScopedOrtOpAttr::ScopedOrtOpAttr() {
-  pptr_ = std::make_unique<OrtOpAttr*>(nullptr);
-}
-ScopedOrtOpAttr::~ScopedOrtOpAttr() {
-  GetOrtApi()->ReleaseOpAttr(*pptr_);
-}
-
-ScopedOrtGraph::ScopedOrtGraph() {
-  pptr_ = std::make_unique<OrtGraph*>(nullptr);
-}
-ScopedOrtGraph::~ScopedOrtGraph() {
-  GetOrtGraphApi()->ReleaseGraph(*pptr_);
-}
-
-ScopedOrtShape::ScopedOrtShape() {
-  pptr_ = std::make_unique<OrtShape*>(nullptr);
-}
-ScopedOrtShape::~ScopedOrtShape() {
-  GetOrtGraphApi()->ReleaseShape(*pptr_);
-}
-
-ScopedOrtValueInfo::ScopedOrtValueInfo() {
-  pptr_ = std::make_unique<OrtValueInfo*>(nullptr);
-}
-ScopedOrtValueInfo::~ScopedOrtValueInfo() {
-  GetOrtGraphApi()->ReleaseValueInfo(*pptr_);
-}
-
-ScopedOrtNode::ScopedOrtNode() {
-  pptr_ = std::make_unique<OrtNode*>(nullptr);
-}
-ScopedOrtNode::~ScopedOrtNode() {
-  GetOrtGraphApi()->ReleaseNode(*pptr_);
-}
-
-ScopedOrtModel::ScopedOrtModel() {
-  pptr_ = std::make_unique<OrtModel*>(nullptr);
-}
-ScopedOrtModel::~ScopedOrtModel() {
-  GetOrtGraphApi()->ReleaseModel(*pptr_);
-}
+#define SCOPED_ORT_TYPE_PTR_DEFINITION(ort_type, ort_api)         \
+  ScopedOrt##ort_type##Ptr::ScopedOrt##ort_type##Ptr() {          \
+    pptr_ = std::make_unique<Ort##ort_type*>(nullptr);            \
+  }                                                               \
+  ScopedOrt##ort_type##Ptr::~ScopedOrt##ort_type##Ptr() {         \
+    if (pptr_) {                                                  \
+      Get##ort_api()->Release##ort_type(*pptr_);                  \
+    }                                                             \
+  }                                                               \
+  ScopedOrt##ort_type##Ptr::ScopedOrt##ort_type##Ptr(             \
+      ScopedOrt##ort_type##Ptr&&) = default;                      \
+  ScopedOrt##ort_type##Ptr& ScopedOrt##ort_type##Ptr::operator=(  \
+      ScopedOrt##ort_type##Ptr&&) = default;
+
+SCOPED_ORT_TYPE_PTR_DEFINITION(Value, OrtApi)
+SCOPED_ORT_TYPE_PTR_DEFINITION(MemoryInfo, OrtApi)
+SCOPED_ORT_TYPE_PTR_DEFINITION(OpAttr, OrtApi)
+SCOPED_ORT_TYPE_PTR_DEFINITION(TypeInfo, OrtApi)
+SCOPED_ORT_TYPE_PTR_DEFINITION(TensorTypeAndShapeInfo, OrtApi)
+SCOPED_ORT_TYPE_PTR_DEFINITION(ValueInfo, OrtModelBuilderApi)
+SCOPED_ORT_TYPE_PTR_DEFINITION(Node, OrtModelBuilderApi)
+SCOPED_ORT_TYPE_PTR_DEFINITION(Graph, OrtModelBuilderApi)
+SCOPED_ORT_TYPE_PTR_DEFINITION(Model, OrtModelBuilderApi)
 
 }  // namespace webnn::ort
diff --git a/services/webnn/ort/scoped_ort_types.h b/services/webnn/ort/scoped_ort_types.h
index f3d1841037a961..0075a33aaad7ff 100644
--- a/services/webnn/ort/scoped_ort_types.h
+++ b/services/webnn/ort/scoped_ort_types.h
@@ -7,121 +7,47 @@
 
 #include <memory>
 
-#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
+#include "services/webnn/ort/utils_ort.h"
+#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h"
 
 namespace webnn::ort {
 
-class ScopedOrtValue {
- public:
-  ScopedOrtValue();
-  ScopedOrtValue(const ScopedOrtValue&) = delete;
-  ScopedOrtValue& operator=(const ScopedOrtValue&) = delete;
-  ~ScopedOrtValue();
-
-  OrtValue* get_ptr() { return *pptr_; }
-  OrtValue** get_pptr() { return pptr_.get(); }
-
- private:
-  std::unique_ptr<OrtValue*> pptr_;
-};
-
-class ScopedOrtMemoryInfo {
- public:
-  ScopedOrtMemoryInfo();
-  ScopedOrtMemoryInfo(const ScopedOrtMemoryInfo&) = delete;
-  ScopedOrtMemoryInfo& operator=(const ScopedOrtMemoryInfo&) = delete;
-  ~ScopedOrtMemoryInfo();
-
-  OrtMemoryInfo* get_ptr() { return *pptr_; }
-  OrtMemoryInfo** get_pptr() { return pptr_.get(); }
-
- private:
-  std::unique_ptr<OrtMemoryInfo*> pptr_;
-};
-
-class ScopedOrtOpAttr {
- public:
-  ScopedOrtOpAttr();
-  ScopedOrtOpAttr(const ScopedOrtOpAttr&) = delete;
-  ScopedOrtOpAttr& operator=(const ScopedOrtOpAttr&) = delete;
-  ~ScopedOrtOpAttr();
-
-  OrtOpAttr* get_ptr() { return *pptr_; }
-  OrtOpAttr** get_pptr() { return pptr_.get(); }
-
- private:
-  std::unique_ptr<OrtOpAttr*> pptr_;
-};
-
-class ScopedOrtGraph {
- public:
-  ScopedOrtGraph();
-  ScopedOrtGraph(const ScopedOrtGraph&) = delete;
-  ScopedOrtGraph& operator=(const ScopedOrtGraph&) = delete;
-  ~ScopedOrtGraph();
-
-  OrtGraph* get_ptr() { return *pptr_; }
-  OrtGraph** get_pptr() { return pptr_.get(); }
-
- private:
-  std::unique_ptr<OrtGraph*> pptr_;
-};
-
-class ScopedOrtShape {
- public:
-  ScopedOrtShape();
-  ScopedOrtShape(const ScopedOrtShape&) = delete;
-  ScopedOrtShape& operator=(const ScopedOrtShape&) = delete;
-  ~ScopedOrtShape();
-
-  OrtShape* get_ptr() { return *pptr_; }
-  OrtShape** get_pptr() { return pptr_.get(); }
-
- private:
-  std::unique_ptr<OrtShape*> pptr_;
-};
-
-class ScopedOrtValueInfo {
- public:
-  ScopedOrtValueInfo();
-  ScopedOrtValueInfo(const ScopedOrtValueInfo&) = delete;
-  ScopedOrtValueInfo& operator=(const ScopedOrtValueInfo&) = delete;
-  ~ScopedOrtValueInfo();
-
-  OrtValueInfo* get_ptr() { return *pptr_; }
-  OrtValueInfo** get_pptr() { return pptr_.get(); }
-
- private:
-  std::unique_ptr<OrtValueInfo*> pptr_;
-};
-
-class ScopedOrtNode {
- public:
-  ScopedOrtNode();
-  ScopedOrtNode(const ScopedOrtNode&) = delete;
-  ScopedOrtNode& operator=(const ScopedOrtNode&) = delete;
-  ~ScopedOrtNode();
-
-  OrtNode* get_ptr() { return *pptr_; }
-  OrtNode** get_pptr() { return pptr_.get(); }
-
- private:
-  std::unique_ptr<OrtNode*> pptr_;
-};
-
-class ScopedOrtModel {
- public:
-  ScopedOrtModel();
-  ScopedOrtModel(const ScopedOrtModel&) = delete;
-  ScopedOrtModel& operator=(const ScopedOrtModel&) = delete;
-  ~ScopedOrtModel();
-
-  OrtModel* get_ptr() { return *pptr_; }
-  OrtModel** get_pptr() { return pptr_.get(); }
-
- private:
-  std::unique_ptr<OrtModel*> pptr_;
-};
+#define SCOPED_ORT_TYPE_PTR_DECLARATION(ort_type)                           \
+  class ScopedOrt##ort_type##Ptr {                                          \
+   public:                                                                  \
+    ScopedOrt##ort_type##Ptr();                                             \
+    ~ScopedOrt##ort_type##Ptr();                                            \
+    ScopedOrt##ort_type##Ptr(const ScopedOrt##ort_type##Ptr&) = delete;     \
+    ScopedOrt##ort_type##Ptr& operator=(const ScopedOrt##ort_type##Ptr&) =  \
+        delete;                                                             \
+    ScopedOrt##ort_type##Ptr(ScopedOrt##ort_type##Ptr&&);                   \
+    ScopedOrt##ort_type##Ptr& operator=(ScopedOrt##ort_type##Ptr&&);        \
+    operator Ort##ort_type*() const {                                       \
+      return *pptr_;                                                        \
+    }                                                                       \
+    Ort##ort_type* Get() const {                                            \
+      return *pptr_;                                                        \
+    }                                                                       \
+    Ort##ort_type** GetAddressOf() const {                                  \
+      return pptr_.get();                                                   \
+    }                                                                       \
+    Ort##ort_type* Release() {                                              \
+      return *pptr_.release();                                              \
+    }                                                                       \
+                                                                            \
+   private:                                                                 \
+    std::unique_ptr<Ort##ort_type*> pptr_;                                  \
+  };
+
+SCOPED_ORT_TYPE_PTR_DECLARATION(Value)
+SCOPED_ORT_TYPE_PTR_DECLARATION(MemoryInfo)
+SCOPED_ORT_TYPE_PTR_DECLARATION(OpAttr)
+SCOPED_ORT_TYPE_PTR_DECLARATION(TypeInfo)
+SCOPED_ORT_TYPE_PTR_DECLARATION(TensorTypeAndShapeInfo)
+SCOPED_ORT_TYPE_PTR_DECLARATION(ValueInfo)
+SCOPED_ORT_TYPE_PTR_DECLARATION(Node)
+SCOPED_ORT_TYPE_PTR_DECLARATION(Graph)
+SCOPED_ORT_TYPE_PTR_DECLARATION(Model)
 
 }  // namespace webnn::ort
diff --git a/services/webnn/ort/tensor_impl_ort.h b/services/webnn/ort/tensor_impl_ort.h
index f9931c2853915e..64c6df0c3d44a3 100644
--- a/services/webnn/ort/tensor_impl_ort.h
+++ b/services/webnn/ort/tensor_impl_ort.h
@@ -11,7 +11,7 @@
 #include "services/webnn/public/mojom/webnn_tensor.mojom-forward.h"
 #include "services/webnn/queueable_resource_state.h"
 #include "services/webnn/webnn_tensor_impl.h"
-#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
+#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h"
 
 namespace webnn::ort {
diff --git a/services/webnn/ort/utils_ort.cc b/services/webnn/ort/utils_ort.cc
index 6e17478919c307..bb2965296f27f5 100644
--- a/services/webnn/ort/utils_ort.cc
+++ b/services/webnn/ort/utils_ort.cc
@@ -48,10 +48,10 @@ const OrtApi* GetOrtApi() {
   return platform_functions->ort_api();
 }
 
-const OrtGraphApi* GetOrtGraphApi() {
+const OrtModelBuilderApi* GetOrtModelBuilderApi() {
   PlatformFunctions* platform_functions = PlatformFunctions::GetInstance();
   CHECK(platform_functions);
-  return platform_functions->ort_graph_api();
+  return platform_functions->ort_model_builder_api();
 }
 
 mojom::ErrorPtr CreateError(mojom::Error::Code error_code,
diff --git a/services/webnn/ort/utils_ort.h b/services/webnn/ort/utils_ort.h
index f76d7f3804a1b1..8b502c5a56ea96 100644
--- a/services/webnn/ort/utils_ort.h
+++ b/services/webnn/ort/utils_ort.h
@@ -7,7 +7,7 @@
 
 #include "services/webnn/public/cpp/operand_descriptor.h"
 #include "services/webnn/public/mojom/webnn_error.mojom.h"
-#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
+#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h"
 
 namespace webnn::ort {
 
@@ -16,7 +16,7 @@ ONNXTensorElementDataType OperandTypeToONNXTensorElementDataType(
 
 const OrtApi* GetOrtApi();
 
-const OrtGraphApi* GetOrtGraphApi();
+const OrtModelBuilderApi* GetOrtModelBuilderApi();
 
 mojom::ErrorPtr CreateError(mojom::Error::Code error_code,
                             const std::string& error_message,
diff --git a/third_party/onnxruntime_headers/BUILD.gn b/third_party/onnxruntime_headers/BUILD.gn
new file mode 100644
index 00000000000000..6e56ff1b2aa9a7
--- /dev/null
+++ b/third_party/onnxruntime_headers/BUILD.gn
@@ -0,0 +1,11 @@
+# Copyright 2024 The Chromium Authors
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+config("onnxruntime_headers_config") {
+  include_dirs = [ "src/include/onnxruntime/core/session" ]
+}
+
+source_set("onnxruntime_headers") {
+  public_configs = [ ":onnxruntime_headers_config" ]
+}
diff --git a/third_party/onnxruntime_headers/README.chromium b/third_party/onnxruntime_headers/README.chromium
new file mode 100644
index 00000000000000..a266bd00391ebe
--- /dev/null
+++ b/third_party/onnxruntime_headers/README.chromium
@@ -0,0 +1,6 @@
+Name: onnxruntime_headers
+URL: https://github.com/microsoft/onnxruntime/tree/main/include
+Revision: 6e76179a4e1e76761bfd7be2ad6d12c3f99ec938
+
+Description:
+This directory contains a copy of the ONNX Runtime headers.
diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/basic_types.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/basic_types.h
new file mode 100644
index 00000000000000..3eb4869377d406
--- /dev/null
+++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/basic_types.h
@@ -0,0 +1,19 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <cstdint>
+
+namespace onnxruntime {
+
+/** A computed hash value. */
+using HashValue = uint64_t;
+
+/** The type of an argument (input or output).*/
+enum class ArgType : uint8_t {
+  kInput,
+  kOutput,
+};
+
+}  // namespace onnxruntime
diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/code_location.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/code_location.h
new file mode 100644
index 00000000000000..dbff69099ba78f
--- /dev/null
+++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/code_location.h
@@ -0,0 +1,58 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+namespace onnxruntime {
+/**
+   CodeLocation captures information on where in the source code a message came from.
+*/
+struct CodeLocation {
+  /**
+     @param file_path Usually the value of __FILE__
+     @param line Usually the value of __LINE__
+     @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__
+  */
+  CodeLocation(const char* file_path, const int line, const char* func)
+      : file_and_path{file_path}, line_num{line}, function{func} {
+  }
+
+  /**
+     @param file_path Usually the value of __FILE__
+     @param line Usually the value of __LINE__
+     @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__
+     @param stacktrace Stacktrace from source of message.
+  */
+  CodeLocation(const char* file_path, const int line, const char* func,
+               const std::vector<std::string>& stacktrace)
+      : file_and_path{file_path}, line_num{line}, function{func}, stacktrace(stacktrace) {
+  }
+
+  std::string FileNoPath() const {
+    // assuming we always have work to do, so not trying to avoid creating a new string if
+    // no path was removed.
+    return file_and_path.substr(file_and_path.find_last_of("/\\") + 1);
+  }
+
+  enum Format {
+    kFilename,
+    kFilenameAndPath
+  };
+
+  std::string ToString(Format format = Format::kFilename) const {
+    std::ostringstream out;
+    out << (format == Format::kFilename ? FileNoPath() : file_and_path) << ":" << line_num << " " << function;
+    return out.str();
+  }
+
+  // utf-8. Because on Windows we compile our code with "/utf-8". And we assume the other platforms only use utf-8.
+  const std::string file_and_path;
+  const int line_num;
+  // utf-8
+  const std::string function;
+  const std::vector<std::string> stacktrace;
+};
+
+}  // namespace onnxruntime
diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/common.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/common.h
new file mode 100644
index 00000000000000..0822eba950f500
--- /dev/null
+++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/common.h
@@ -0,0 +1,286 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// Portions Copyright (c) Microsoft Corporation
+
+#pragma once
+
+#include <algorithm>
+#include <chrono>
+#include <climits>
+#include <cmath>
+#include <cstring>
+#include <functional>
+#include <memory>
+#include <numeric>
+#include <sstream>
+#include <string>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "core/common/code_location.h"
+#include "core/common/exceptions.h"
+#include "core/common/make_string.h"
+#include "core/common/status.h"
+
+namespace onnxruntime {
+
+using TimePoint = std::chrono::high_resolution_clock::time_point;
+
+#ifdef _WIN32
+#define ORT_UNUSED_PARAMETER(x) (x)
+#else
+#define ORT_UNUSED_PARAMETER(x) (void)(x)
+#endif
+
+#ifndef ORT_HAVE_ATTRIBUTE
+#ifdef __has_attribute
+#define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x)
+#else
+#define ORT_HAVE_ATTRIBUTE(x) 0
+#endif
+#endif
+
+// ORT_ATTRIBUTE_UNUSED
+//
+// Prevents the compiler from complaining about or optimizing away variables
+// that appear unused on Linux
+#if ORT_HAVE_ATTRIBUTE(unused) || (defined(__GNUC__) && !defined(__clang__))
+#undef ORT_ATTRIBUTE_UNUSED
+#define ORT_ATTRIBUTE_UNUSED __attribute__((__unused__))
+#else
+#define ORT_ATTRIBUTE_UNUSED
+#endif
+
+#ifdef ORT_NO_EXCEPTIONS
+// Print the given final message, the message must be a null terminated char*
+// ORT will abort after printing the message.
+// For Android, will print to Android system log
+// For other platforms, will print to stderr
+void PrintFinalMessage(const char* msg);
+#endif
+
+// macro to explicitly ignore the return value from a function call so Code Analysis doesn't complain
+#define ORT_IGNORE_RETURN_VALUE(fn) \
+  static_cast<void>(fn)
+
+std::vector<std::string> GetStackTrace();
+// these is a helper function that gets defined by platform/Telemetry
+void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
+                     const char* function, uint32_t line);
+
+// __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER
+// so we only define it as one for MSVC
+#if (_MSC_VER && !defined(__PRETTY_FUNCTION__))
+#define __PRETTY_FUNCTION__ __FUNCTION__
+#endif
+
+// Capture where a message is coming from. Use __FUNCTION__ rather than the much longer __PRETTY_FUNCTION__
+#define ORT_WHERE ::onnxruntime::CodeLocation(__FILE__, __LINE__, static_cast<const char*>(__FUNCTION__))
+
+#define ORT_WHERE_WITH_STACK \
+  ::onnxruntime::CodeLocation(__FILE__, __LINE__, static_cast<const char*>(__PRETTY_FUNCTION__), ::onnxruntime::GetStackTrace())
+
+#ifdef ORT_NO_EXCEPTIONS
+
+#define ORT_TRY if (true)
+#define ORT_CATCH(x) else if (false)
+#define ORT_RETHROW
+
+// In order to ignore the catch statement when a specific exception (not ... ) is caught and referred
+// in the body of the catch statements, it is necessary to wrap the body of the catch statement into
+// a lambda function. otherwise the exception referred will be undefined and cause build break
+#define ORT_HANDLE_EXCEPTION(func)
+
+// Throw an exception with optional message.
+// NOTE: The arguments get streamed into a string via ostringstream::operator<<
+// DO NOT use a printf format string, as that will not work as you expect.
+#define ORT_THROW(...)                                                     \
+  do {                                                                     \
+    ::onnxruntime::PrintFinalMessage(                                      \
+        ::onnxruntime::OnnxRuntimeException(                               \
+            ORT_WHERE_WITH_STACK, ::onnxruntime::MakeString(__VA_ARGS__))  \
+            .what());                                                      \
+    abort();                                                               \
+  } while (false)
+
+// Just in order to mark things as not implemented. Do not use in final code.
+#define ORT_NOT_IMPLEMENTED(...)                                                         \
+  do {                                                                                   \
+    ::onnxruntime::PrintFinalMessage(                                                    \
+        ::onnxruntime::NotImplementedException(::onnxruntime::MakeString(__VA_ARGS__))   \
+            .what());                                                                    \
+    abort();                                                                             \
+  } while (false)
+
+// Check condition.
+// NOTE: The arguments get streamed into a string via ostringstream::operator<<
+// DO NOT use a printf format string, as that will not work as you expect.
+#define ORT_ENFORCE(condition, ...)                                                  \
+  do {                                                                               \
+    if (!(condition)) {                                                              \
+      ::onnxruntime::PrintFinalMessage(                                              \
+          ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, #condition,      \
+                                              ::onnxruntime::MakeString(__VA_ARGS__)) \
+              .what());                                                              \
+      abort();                                                                       \
+    }                                                                                \
+  } while (false)
+
+#define ORT_THROW_EX(ex, ...)                                                                       \
+  do {                                                                                              \
+    ::onnxruntime::PrintFinalMessage(                                                               \
+        ::onnxruntime::MakeString(#ex, "(", ::onnxruntime::MakeString(__VA_ARGS__), ")").c_str());  \
+    abort();                                                                                        \
+  } while (false)
+
+#else
+
+#define ORT_TRY try
+#define ORT_CATCH(x) catch (x)
+#define ORT_RETHROW throw;
+
+#define ORT_HANDLE_EXCEPTION(func) func()
+
+// Throw an exception with optional message.
+// NOTE: The arguments get streamed into a string via ostringstream::operator<<
+// DO NOT use a printf format string, as that will not work as you expect.
+#define ORT_THROW(...) \
+  throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, ::onnxruntime::MakeString(__VA_ARGS__))
+
+// Just in order to mark things as not implemented. Do not use in final code.
+#define ORT_NOT_IMPLEMENTED(...) \
+  throw ::onnxruntime::NotImplementedException(::onnxruntime::MakeString(__VA_ARGS__))
+
+// Check condition.
+// NOTE: The arguments get streamed into a string via ostringstream::operator<<
+// DO NOT use a printf format string, as that will not work as you expect.
+#define ORT_ENFORCE(condition, ...)                                                      \
+  do {                                                                                   \
+    if (!(condition)) {                                                                  \
+      throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, #condition,        \
+                                                ::onnxruntime::MakeString(__VA_ARGS__)); \
+    }                                                                                    \
+  } while (false)
+
+#define ORT_THROW_EX(ex, ...) \
+  throw ex(__VA_ARGS__)
+
+#endif
+
+#define ORT_MAKE_STATUS(category, code, ...)                     \
+  ::onnxruntime::common::Status(::onnxruntime::common::category, \
+                                ::onnxruntime::common::code,     \
+                                ::onnxruntime::MakeString(__VA_ARGS__))
+
+// Check condition. if met, return status.
+#define ORT_RETURN_IF(condition, ...)                                                       \
+  do {                                                                                      \
+    if (condition) {                                                                        \
+      return ::onnxruntime::common::Status(::onnxruntime::common::ONNXRUNTIME,              \
+                                           ::onnxruntime::common::FAIL,                     \
+                                           ::onnxruntime::MakeString(ORT_WHERE.ToString(), " ", __VA_ARGS__)); \
+    }                                                                                       \
+  } while (false)
+
+// Check condition. if not met, return status.
+#define ORT_RETURN_IF_NOT(condition, ...) \
+  ORT_RETURN_IF(!(condition), __VA_ARGS__)
+
+// Macros to disable the copy and/or move ctor and assignment methods
+// These are usually placed in the private: declarations for a class.
+
+#define ORT_DISALLOW_COPY(TypeName) TypeName(const TypeName&) = delete
+
+#define ORT_DISALLOW_ASSIGNMENT(TypeName) TypeName& operator=(const TypeName&) = delete
+
+#define ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName) \
+  ORT_DISALLOW_COPY(TypeName);                     \
+  ORT_DISALLOW_ASSIGNMENT(TypeName)
+
+#define ORT_DISALLOW_MOVE(TypeName) \
+  TypeName(TypeName&&) = delete;    \
+  TypeName& operator=(TypeName&&) = delete
+
+#define ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName) \
+  ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName);           \
+  ORT_DISALLOW_MOVE(TypeName)
+
+#define ORT_RETURN_IF_ERROR_SESSIONID(expr, session_id)                                                              \
+  do {                                                                                                               \
+    auto _status = (expr);                                                                                           \
+    if ((!_status.IsOK())) {                                                                                         \
+      ::onnxruntime::LogRuntimeError(session_id, _status, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__); \
+      return _status;                                                                                                \
+    }                                                                                                                \
+  } while (0)
+
+#define ORT_RETURN_IF_ERROR_SESSIONID_(expr) ORT_RETURN_IF_ERROR_SESSIONID(expr, session_id_)
+#define ORT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR_SESSIONID(expr, 0)
+
+#define ORT_THROW_IF_ERROR(expr)                                                                            \
+  do {                                                                                                      \
+    auto _status = (expr);                                                                                  \
+    if ((!_status.IsOK())) {                                                                                \
+      ::onnxruntime::LogRuntimeError(0, _status, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__); \
+      ORT_THROW(_status);                                                                                   \
+    }                                                                                                       \
+  } while (0)
+
+// use this macro when cannot early return
+#define ORT_CHECK_AND_SET_RETVAL(expr) \
+  do {                                 \
+    if (retval.IsOK()) {               \
+      retval = (expr);                 \
+    }                                  \
+  } while (0)
+
+inline long long TimeDiffMicroSeconds(TimePoint start_time) {
+  auto end_time = std::chrono::high_resolution_clock::now();
+  return std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
+}
+
+inline long long TimeDiffMicroSeconds(TimePoint start_time, TimePoint end_time) {
+  return std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
+}
+
+struct null_type {};
+inline std::string ToUTF8String(const std::string& s) { return s; }
+#ifdef _WIN32
+/**
+ * Convert a wide character string to a UTF-8 string
+ */
+std::string ToUTF8String(const std::wstring& s);
+
+std::wstring ToWideString(const std::string& s);
+inline std::wstring ToWideString(const std::wstring& s) { return s; }
+#else
+inline std::string ToWideString(const std::string& s) { return s; }
+#endif
+
+constexpr size_t kMaxStrLen = 2048;
+
+// Returns whether `key` is in `container`.
+// Like C++20's map/set contains() member function.
+template <typename Key, typename... OtherContainerArgs,
+          template <typename...> typename AssociativeContainer,
+          typename LookupKey>
+inline bool Contains(const AssociativeContainer<Key, OtherContainerArgs...>& container, LookupKey&& key) {
+  return container.find(std::forward<LookupKey>(key)) != container.end();
+}
+
+}  // namespace onnxruntime
diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/const_pointer_container.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/const_pointer_container.h
new file mode 100644
index 00000000000000..1d821ba6092050
--- /dev/null
+++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/const_pointer_container.h
@@ -0,0 +1,85 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <type_traits>
+
+namespace onnxruntime {
+/**
+   Container has T* entries. e.g. std::vector<T*>, and this class provides const access to those
+   via iterators and direct access, as the standard behavior only makes the pointer constant,
+   and not what is pointed too. i.e. you get a const pointer to T not a pointer to const T without this wrapper.
+   See https://stackoverflow.com/questions/8017036/understanding-const-iterator-with-pointers
+*/
+template <typename Container>
+class ConstPointerContainer {
+ public:
+  using T = typename std::remove_pointer<typename Container::value_type>::type;
+
+  class ConstIterator {
+   public:
+    using const_iterator = typename Container::const_iterator;
+    using iterator_category = std::input_iterator_tag;
+    using value_type = T*;
+    using difference_type = std::ptrdiff_t;
+    using pointer = T**;
+    using reference = T*&;
+
+    /** Construct iterator for container that will return const T* entries.*/
+    explicit ConstIterator(const_iterator position) noexcept : current_{position}, item_{nullptr} {}
+    ConstIterator(const ConstIterator& other) = default;
+    ConstIterator& operator=(const ConstIterator& other) = default;
+
+    bool operator==(const ConstIterator& other) const noexcept { return current_ == other.current_; }
+    bool operator!=(const ConstIterator& other) const noexcept { return current_ != other.current_; }
+
+    ConstIterator& operator++() {
+      ++current_;
+      return *this;
+    }
+
+    ConstIterator operator++(int) {
+      ConstIterator tmp{*this};
+      ++(*this);
+      return tmp;
+    }
+
+    const T*& operator*() const {
+      item_ = *current_;
+      return item_;
+    }
+
+    const T** operator->() const { return &(operator*()); };
+
+   private:
+    const_iterator current_;
+    mutable const T* item_;
+  };
+
+  /**
+     Construct wrapper class that will provide const access to the pointers in a container of non-const pointers.
+     @param data Container with non-const pointers. e.g. std::vector<T*>
+  */
+  explicit ConstPointerContainer(const Container& data) noexcept : data_(data) {}
+
+  size_t size() const noexcept { return data_.size(); }
+  bool empty() const noexcept { return data_.empty(); }
+
+  ConstIterator cbegin() const noexcept { return ConstIterator(data_.cbegin()); }
+  ConstIterator cend() const noexcept { return ConstIterator(data_.cend()); }
+
+  ConstIterator begin() const noexcept { return ConstIterator(data_.cbegin()); }
+  ConstIterator end() const noexcept { return ConstIterator(data_.cend()); }
+
+  const T* operator[](size_t index) const { return data_[index]; }
+
+  const T* at(size_t index) const {
+    ORT_ENFORCE(index < data_.size());
+    return data_[index];
+  }
+
+ private:
+  const Container& data_;
+};
+}  // namespace onnxruntime
diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/denormal.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/denormal.h
new file mode 100644
index 00000000000000..ca944117813116
--- /dev/null
+++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/denormal.h
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+namespace onnxruntime {
+
+// Set or unset flush-to-zero and denormal-as-zero if SSE3 instructions are supported.
+// Return true if SSE3 instruction is supported, otherwise return false.
+bool SetDenormalAsZero(bool on);
+
+}  // namespace onnxruntime
diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/eigen_common_wrapper.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/eigen_common_wrapper.h
new file mode 100644
index 00000000000000..19efa7bcff107d
--- /dev/null
+++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/eigen_common_wrapper.h
@@ -0,0 +1,76 @@
+//-----------------------------------------------------------------------------
+//
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+//
+//-----------------------------------------------------------------------------
+#pragma once
+#include "onnxruntime_config.h"
+// build/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h:162:71:
+// error: ignoring attributes on template argument "Eigen::PacketType<const float, Eigen::DefaultDevice>::type {aka __vector(4) float}" [-Werror=ignored-attributes]
+#if defined(__GNUC__)
+#pragma GCC diagnostic push
+#if __GNUC__ >= 6
+#pragma GCC diagnostic ignored "-Wignored-attributes"
+#endif
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#pragma GCC diagnostic ignored "-Wunused-result"
+#ifdef HAS_DEPRECATED_COPY
+#pragma GCC diagnostic ignored "-Wdeprecated-copy"
+#endif
+// cmake/external/eigen/unsupported/Eigen/CXX11/../../../Eigen/src/Core/arch/NEON/PacketMath.h:1633:9:
+// error: ‘void* memcpy(void*, const void*, size_t)’ copying an object of non-trivial type ‘Eigen::internal::Packet4c’
+// {aka ‘struct Eigen::internal::eigen_packet_wrapper<int, 2>’} from an array of ‘const int8_t’
+// {aka ‘const signed char’} [-Werror=class-memaccess]
+#ifdef HAS_CLASS_MEMACCESS
+#pragma GCC diagnostic ignored "-Wclass-memaccess"
+#endif
+
+// cmake/external/eigen\Eigen/src/Core/util/Meta.h:454:25:
+// error: 'result_of<Eigen::internal::scalar_sum_op<unsigned long long, unsigned long long>
+// (const unsigned long long &, const unsigned long long &)>'
+// is deprecated [-Werror,-Wdeprecated-declarations]
+// typedef typename std::result_of<decltype(func)(Args...)>::type type1;
+#ifdef HAS_DEPRECATED_DECLARATIONS
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#endif
+
+// cmake/external/eigen\Eigen/CXX11/src/Tensor/TensorTrace.h:130:9:
+// error: variable 'num_distinct_reduce_dims' set but not used [-Werror,-Wunused-but-set-variable]
+// int num_distinct_reduce_dims = 0;
+#ifdef HAS_UNUSED_BUT_SET_VARIABLE
+#pragma GCC diagnostic ignored "-Wunused-but-set-variable"
+#endif
+
+// eigen-src/unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h:231:56: error: implicit conversion loses integer
+// precision: 'uint64_t' (aka 'unsigned long long') to 'size_t' (aka 'unsigned long') [-Werror,-Wshorten-64-to-32]
+// next = wnext == kStackMask ? nullptr : &waiters_[wnext];
+//        ~~~~~~~~              ^~~~~
+#ifdef HAS_SHORTEN_64_TO_32
+#pragma GCC diagnostic ignored "-Wshorten-64-to-32"
+#endif
+
+// eigen-src/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h:215:9:
+// error: implicit capture of 'this' with a capture default of '=' is deprecated [-Werror,-Wdeprecated-this-capture]
+#ifdef HAS_DEPRECATED_THIS_CAPTURE
+#pragma GCC diagnostic ignored "-Wdeprecated-this-capture"
+#endif
+
+#elif defined(_MSC_VER)
+// build\windows\debug\external\eigen3\unsupported\eigen\cxx11\src/Tensor/Tensor.h(76):
+// warning C4554: '&': check operator precedence for possible error; use parentheses to clarify precedence
+
+// unsupported\eigen\cxx11\src\Tensor\TensorUInt128.h(150,0): Warning C4245: 'initializing': conversion from '__int64'
+// to 'uint64_t', signed/unsigned mismatch
+#pragma warning(push)
+#pragma warning(disable : 4554)
+#pragma warning(disable : 4245)
+#pragma warning(disable : 4127)
+#endif
+
+#include "unsupported/Eigen/CXX11/Tensor"
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic pop
+#elif defined(_MSC_VER)
+#pragma warning(pop)
+#endif
diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/exceptions.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/exceptions.h
new file mode 100644
index 00000000000000..494a770b8db985
--- /dev/null
+++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/exceptions.h
@@ -0,0 +1,71 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <algorithm>
+#include <exception>
+#include <iterator>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+#include "core/common/common.h"
+#include "core/common/code_location.h"
+
+namespace onnxruntime {
+
+class NotImplementedException : public std::logic_error {
+ public:
+  explicit NotImplementedException(const char* _Message = "Function not yet implemented") noexcept : std::logic_error(_Message) {};
+  explicit NotImplementedException(const std::string& _Message = "Function not yet implemented") noexcept : std::logic_error(_Message) {};
+};
+
+class TypeMismatchException : public std::logic_error {
+ public:
+  TypeMismatchException() noexcept : logic_error("Type mismatch") {};
+};
+
+class OnnxRuntimeException : public std::exception {
+ public:
+  OnnxRuntimeException(const CodeLocation& location, const std::string& msg) noexcept
+      : OnnxRuntimeException(location, nullptr, msg) {
+  }
+
+  /**
+     Create a new exception that captures the location it was thrown from.
+     @param location Location in the source code the exception is being thrown from
+     @param failed_condition Optional string containing the condition that failed.
+     e.g. "tensor.Size() == input.Size()". May be nullptr.
+     @param msg Message containing additional information about the exception cause.
+  */
+  OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg)
+      : location_{location} {
+    std::ostringstream ss;
+
+    ss << location.ToString(CodeLocation::kFilenameAndPath);  // output full path in case just the filename is ambiguous
+    if (failed_condition != nullptr) {
+      ss << " " << failed_condition << " was false.";
+    }
+
+    ss << " " << msg << "\n";
+    if (!location.stacktrace.empty()) {
+      ss << "Stacktrace:\n";
+      // skip the first entry in the stacktrace as we have that information from location.ToString()
+      std::copy(std::next(location.stacktrace.begin()), location.stacktrace.end(), std::ostream_iterator<std::string>(ss, "\n"));
+    }
+
+    what_ = ss.str();
+  }
+
+  const char* what() const noexcept override {
+    return what_.c_str();
+  }
+
+ private:
+  const CodeLocation location_;
+  const std::vector<std::string> stacktrace_;
+  std::string what_;
+};
+
+}  // namespace onnxruntime
diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/gpu_profiler_common.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/gpu_profiler_common.h
new file mode 100644
index 00000000000000..00d5033ef2df40
--- /dev/null
+++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/gpu_profiler_common.h
@@ -0,0 +1,472 @@
+#pragma once
+
+#include "core/common/profiler_common.h"
+#include "core/common/inlined_containers.h"
+
+#include <chrono>
+#include <cstdint>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+namespace onnxruntime {
+namespace profiling {
+
+// The classes in this header are implemented as template/inline classes
+// to avoid having to export symbols from the main onnxruntime shared library
+// to ExecutionProvider (EP) shared libraries.
+// More context: The main onnxruntime shared library is optimized for size
+// using --gc-sections during link time to ensure that any unreferenced code
+// is not retained. This poses a problem in using a design pattern where the
+// (abstract) base class is implemented in the main onnxruntime shared library,
+// but (concrete) subclasses are implemented in EP shared libraries. Now, because
+// EP shared libraries are loaded at runtime (as of 11/2022), there will be no
+// references to the base class symbols when the main onnxruntime shared library
+// is compiled. Thus, the base class symbols will not be included in the
+// main onnxruntime shared library. This manifests in being unable to load
+// EP shared libs (because the base class symbols referenced by derived
+// classes are missing).
+// We solve this by implementing base classes that are common to all GPU profilers
+// inline in this header.
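To make the CRTP contract just described concrete, here is a minimal editorial sketch of what an EP-side subclass could look like. It is not part of the patch; every name except GPUTracerManager, ProfilerActivityBuffer and TimePoint is invented, and the hook signatures follow the documentation-only `#if 0` block further down in this header:

// Editorial sketch, not part of the upstream header.
class FakeGpuTracerManager : public GPUTracerManager<FakeGpuTracerManager> {
  // The base template reaches the hooks via static_cast<TDerived*>(this),
  // so they must be visible to it; keeping them private requires friendship.
  friend class GPUTracerManager<FakeGpuTracerManager>;

 public:
  static FakeGpuTracerManager& GetInstance() {
    static FakeGpuTracerManager instance;  // GPUProfilerBase<T> expects a GetInstance()
    return instance;
  }

 private:
  bool OnStartLogging() { return true; }  // hook up the vendor tracing API here
  void OnStopLogging() {}
  void FlushActivities() {}
  void ProcessActivityBuffers(const std::vector<ProfilerActivityBuffer>& /*buffers*/,
                              const TimePoint& /*start_time*/) {}
  bool PushUniqueCorrelation(uint64_t /*unique_cid*/) { return true; }
  void PopUniqueCorrelation(uint64_t& popped_unique_cid) { popped_unique_cid = 0; }
  uint64_t GetGPUTimestampInNanoseconds() { return 0; }
};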
+
+class ProfilerActivityBuffer {
+ public:
+  ProfilerActivityBuffer() noexcept
+      : data_(nullptr), size_(0) {}
+
+  ProfilerActivityBuffer(const char* data, size_t size) noexcept
+      : data_(std::make_unique<char[]>(size)), size_(size) {
+    memcpy(data_.get(), data, size_);
+  }
+
+  ProfilerActivityBuffer(const ProfilerActivityBuffer& other) noexcept
+      : ProfilerActivityBuffer(other.GetData(), other.GetSize()) {}
+
+  ProfilerActivityBuffer(ProfilerActivityBuffer&& other) noexcept
+      : ProfilerActivityBuffer() {
+    std::swap(data_, other.data_);
+    std::swap(size_, other.size_);
+  }
+
+  ProfilerActivityBuffer& operator=(const ProfilerActivityBuffer& other) noexcept {
+    if (&other == this) {
+      return *this;
+    }
+
+    new (this) ProfilerActivityBuffer{other};
+    return *this;
+  }
+
+  ProfilerActivityBuffer& operator=(ProfilerActivityBuffer&& other) noexcept {
+    if (&other == this) {
+      return *this;
+    }
+
+    new (this) ProfilerActivityBuffer{std::move(other)};
+    return *this;
+  }
+
+  static ProfilerActivityBuffer CreateFromPreallocatedBuffer(std::unique_ptr<char[]>&& buffer_ptr, size_t size) {
+    ProfilerActivityBuffer res{};
+    res.data_ = std::move(buffer_ptr);
+    res.size_ = size;
+    return res;
+  }
+
+  // accessors
+  char* GetData() { return data_.get(); }
+  const char* GetData() const { return data_.get(); }
+  size_t GetSize() const { return size_; }
+
+ private:
+  std::unique_ptr<char[]> data_;
+  size_t size_;
+};  /* end class ProfilerActivityBuffer */
+
+template <typename TDerived>
+class GPUTracerManager {
+ public:
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GPUTracerManager);
+  virtual ~GPUTracerManager() {}
+
+  uint64_t RegisterClient() {
+    std::lock_guard<std::mutex> lock(manager_instance_mutex_);
+    auto res = next_client_id_++;
+    per_client_events_by_ext_correlation_.insert({res, {}});
+    ++num_active_clients_;
+    return res;
+  }
+
+  void DeregisterClient(uint64_t client_handle) {
+    std::lock_guard<std::mutex> lock(manager_instance_mutex_);
+    auto it = per_client_events_by_ext_correlation_.find(client_handle);
+    if (it == per_client_events_by_ext_correlation_.end()) {
+      return;
+    }
+    per_client_events_by_ext_correlation_.erase(it);
+    --num_active_clients_;
+    if (num_active_clients_ == 0 && tracing_enabled_) {
+      StopLogging();
+    }
+  }
+
+  void StartLogging() {
+    std::lock_guard<std::mutex> lock(manager_instance_mutex_);
+    if (tracing_enabled_) {
+      return;
+    }
+
+    auto this_as_derived = static_cast<TDerived*>(this);
+    tracing_enabled_ = this_as_derived->OnStartLogging();
+  }
+
+  void Consume(uint64_t client_handle, const TimePoint& start_time, std::map<uint64_t, Events>& events) {
+    auto this_as_derived = static_cast<TDerived*>(this);
+    events.clear();
+    {
+      // Flush any pending activity records before starting
+      // to process the accumulated activity records.
+      std::lock_guard<std::mutex> lock_manager(manager_instance_mutex_);
+      if (!tracing_enabled_) {
+        return;
+      }
+
+      this_as_derived->FlushActivities();
+    }
+
+    std::vector<ProfilerActivityBuffer> activity_buffers;
+    {
+      std::lock_guard<std::mutex> lock(unprocessed_activity_buffers_mutex_);
+      std::swap(unprocessed_activity_buffers_, activity_buffers);
+      unprocessed_activity_buffers_.clear();
+    }
+
+    {
+      // Ensure that at most one thread is working through the activity buffers at any time.
+      std::lock_guard<std::mutex> lock_two(activity_buffer_processor_mutex_);
+      this_as_derived->ProcessActivityBuffers(activity_buffers, start_time);
+      auto it = per_client_events_by_ext_correlation_.find(client_handle);
+      if (it == per_client_events_by_ext_correlation_.end()) {
+        return;
+      }
+      std::swap(events, it->second);
+    }
+  }
+
+  void PushCorrelation(uint64_t client_handle,
+                       uint64_t external_correlation_id,
+                       TimePoint profiling_start_time) {
+    auto this_as_derived = static_cast<TDerived*>(this);
+    std::lock_guard<std::mutex> lock(manager_instance_mutex_);
+    if (!tracing_enabled_) {
+      return;
+    }
+
+    auto it = per_client_events_by_ext_correlation_.find(client_handle);
+    if (it == per_client_events_by_ext_correlation_.end()) {
+      // not a registered client, do nothing
+      return;
+    }
+
+    // external_correlation_id is simply the timestamp of this event,
+    // relative to profiling_start_time. i.e., it was computed as:
+    // external_correlation_id =
+    //   std::chrono::duration_cast<std::chrono::nanoseconds>(event_start_time - profiling_start_time).count()
+    //
+    // Because of the relative nature of the external_correlation_id, the same
+    // external_correlation_id can be reused across different clients, which then makes it
+    // impossible to recover the client from the external_correlation_id, which in turn
+    // makes it impossible to map events (which are tagged with external_correlation_id) to clients.
+    //
+    // To address these difficulties, we construct a new correlation_id (let's call it unique_cid)
+    // as follows:
+    // unique_cid =
+    //   external_correlation_id +
+    //   std::chrono::duration_cast<std::chrono::nanoseconds>(profiling_start_time.time_since_epoch()).count()
+    // now, unique_cid is monotonically increasing with time, so it can be used to reliably map events to clients.
+    //
+    // Of course, clients expect lists of events to be returned (on a call to Consume()), that are
+    // still keyed on the external_correlation_id that they've specified here, so we need to remember the
+    // offset to be subtracted
+    uint64_t offset = std::chrono::duration_cast<std::chrono::nanoseconds>(profiling_start_time.time_since_epoch()).count();
+    auto unique_cid = external_correlation_id + offset;
+    unique_correlation_id_to_client_offset_[unique_cid] = std::make_pair(client_handle, offset);
+    this_as_derived->PushUniqueCorrelation(unique_cid);
+  }
+
+  void PopCorrelation(uint64_t& popped_external_correlation_id) {
+    auto this_as_derived = static_cast<TDerived*>(this);
+    std::lock_guard<std::mutex> lock(manager_instance_mutex_);
+    if (!tracing_enabled_) {
+      return;
+    }
+    uint64_t unique_cid;
+    this_as_derived->PopUniqueCorrelation(unique_cid);
+    // lookup the offset and subtract it before returning popped_external_correlation_id to the client
+    auto client_it = unique_correlation_id_to_client_offset_.find(unique_cid);
+    if (client_it == unique_correlation_id_to_client_offset_.end()) {
+      popped_external_correlation_id = 0;
+      return;
+    }
+    popped_external_correlation_id = unique_cid - client_it->second.second;
+  }
+
+  void PopCorrelation() {
+    uint64_t unused;
+    PopCorrelation(unused);
+  }
+
+ protected:
+  GPUTracerManager() {
+    auto this_as_derived = static_cast<TDerived*>(this);
+    uint64_t gpu_ts1, gpu_ts2, cpu_ts;
+
+    // Get the CPU and GPU timestamps to warm up
+    gpu_ts1 = this_as_derived->GetGPUTimestampInNanoseconds();
+    cpu_ts = this->GetCPUTimestampInNanoseconds();
+
+    // Estimate the skew/offset between the CPU and GPU timestamps.
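// [Editorial worked example, not part of the upstream header] With hypothetical
// readings gpu_ts1 = 100 and gpu_ts2 = 140 (both ns), and cpu_ts = 5'120 ns taken
// between them, gpu_ts = (100 + 140) / 2 = 120 approximates the GPU clock at the
// moment the CPU clock read 5'120, so the stored offset is 5'120 - 120 = 5'000,
// and a later GPU timestamp t is normalized to the CPU epoch as t + 5'000.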
+    gpu_ts1 = this_as_derived->GetGPUTimestampInNanoseconds();
+    cpu_ts = this->GetCPUTimestampInNanoseconds();
+    gpu_ts2 = this_as_derived->GetGPUTimestampInNanoseconds();
+
+    auto gpu_ts = (gpu_ts1 + gpu_ts2) / 2;
+    offset_to_add_to_gpu_timestamps_ = cpu_ts - gpu_ts;
+  }
+
+#if 0
+  // Functional API to be implemented by subclasses
+  // Included here only for documentation purposes
+protected:
+  bool OnStartLogging();
+  void OnStopLogging();
+  void ProcessActivityBuffers(const std::vector<ProfilerActivityBuffer>& buffers,
+                              const TimePoint& start_time);
+  bool PushUniqueCorrelation(uint64_t unique_cid);
+  void PopUniqueCorrelation(uint64_t& popped_unique_cid);
+  void FlushActivities();
+  uint64_t GetGPUTimestampInNanoseconds();
+#endif
+
+  void EnqueueActivityBuffer(ProfilerActivityBuffer&& buffer) {
+    std::lock_guard<std::mutex> lock(unprocessed_activity_buffers_mutex_);
+    unprocessed_activity_buffers_.emplace_back(std::move(buffer));
+  }
+
+  // To be called by subclasses only from ProcessActivityBuffers
+  void MapEventToClient(uint64_t tracer_correlation_id, EventRecord&& event) {
+    auto it = tracer_correlation_to_unique_correlation_.find(tracer_correlation_id);
+    if (it == tracer_correlation_to_unique_correlation_.end()) {
+      // We're yet to receive a mapping to unique_correlation_id for this tracer_correlation_id
+      DeferEventMapping(std::move(event), tracer_correlation_id);
+      return;
+    }
+    auto unique_correlation_id = it->second;
+    auto p_event_list = GetEventListForUniqueCorrelationId(unique_correlation_id);
+    if (p_event_list != nullptr) {
+      p_event_list->emplace_back(std::move(event));
+    }
+  }
+
+  // To be called by subclasses only from ProcessActivityBuffers
+  void NotifyNewCorrelation(uint64_t tracer_correlation_id, uint64_t unique_correlation_id) {
+    tracer_correlation_to_unique_correlation_[tracer_correlation_id] = unique_correlation_id;
+    auto pending_it = events_pending_client_mapping_.find(tracer_correlation_id);
+    if (pending_it == events_pending_client_mapping_.end()) {
+      return;
+    }
+    // Map the pending events to the right client
+    MapEventsToClient(unique_correlation_id, std::move(pending_it->second));
+    events_pending_client_mapping_.erase(pending_it);
+  }
+
+  uint64_t NormalizeGPUTimestampToCPUEpoch(uint64_t gpu_timestamp_in_nanoseconds) {
+    return gpu_timestamp_in_nanoseconds + this->offset_to_add_to_gpu_timestamps_;
+  }
+
+ private:
+  // Requires: manager_instance_mutex_ should be held
+  void StopLogging() {
+    auto this_as_derived = static_cast<TDerived*>(this);
+    if (!tracing_enabled_) {
+      return;
+    }
+    this_as_derived->OnStopLogging();
+    tracing_enabled_ = false;
+    Clear();
+  }
+
+  // Requires: manager_instance_mutex_ should be held
+  void Clear() {
+    unprocessed_activity_buffers_.clear();
+    unique_correlation_id_to_client_offset_.clear();
+    per_client_events_by_ext_correlation_.clear();
+    tracer_correlation_to_unique_correlation_.clear();
+    events_pending_client_mapping_.clear();
+  }
+
+  Events* GetEventListForUniqueCorrelationId(uint64_t unique_correlation_id) {
+    auto client_it = unique_correlation_id_to_client_offset_.find(unique_correlation_id);
+    if (client_it == unique_correlation_id_to_client_offset_.end()) {
+      return nullptr;
+    }
+
+    // See the comments on the GetUniqueCorrelationId method for an explanation
+    // of this offset computation and why it's required.
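// [Editorial worked example, not part of the upstream header] Numerically: if a
// client pushed external_correlation_id = 7 with offset = 5'000, the map holds
// unique_correlation_id 5'007 -> (client_handle, 5'000), so below
// external_correlation = 5'007 - 5'000 = 7 recovers the client's original key.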
+    auto const& client_handle_offset = client_it->second;
+    auto external_correlation = unique_correlation_id - client_handle_offset.second;
+    auto& event_list = per_client_events_by_ext_correlation_[client_handle_offset.first][external_correlation];
+    return &event_list;
+  }
+
+  void MapEventsToClient(uint64_t unique_correlation_id, std::vector<EventRecord>&& events) {
+    auto p_event_list = GetEventListForUniqueCorrelationId(unique_correlation_id);
+    if (p_event_list != nullptr) {
+      p_event_list->insert(p_event_list->end(),
+                           std::make_move_iterator(events.begin()),
+                           std::make_move_iterator(events.end()));
+    }
+  }
+
+  void DeferEventMapping(EventRecord&& event, uint64_t tracer_correlation_id) {
+    events_pending_client_mapping_[tracer_correlation_id].emplace_back(std::move(event));
+  }
+
+  uint64_t GetCPUTimestampInNanoseconds() {
+    return std::chrono::duration_cast<std::chrono::nanoseconds>(
+               std::chrono::high_resolution_clock::now().time_since_epoch())
+        .count();
+  }
+
+  std::mutex manager_instance_mutex_;
+  uint64_t next_client_id_ = 1;
+  uint64_t num_active_clients_ = 0;
+  bool tracing_enabled_ = false;
+  std::mutex unprocessed_activity_buffers_mutex_;
+  std::mutex activity_buffer_processor_mutex_;
+
+  // Unprocessed activity buffers
+  std::vector<ProfilerActivityBuffer> unprocessed_activity_buffers_;
+
+  // Keyed on unique_correlation_id -> (client_id/client_handle, offset)
+  // unique_correlation_id - offset == external_correlation_id
+  InlinedHashMap<uint64_t, std::pair<uint64_t, uint64_t>> unique_correlation_id_to_client_offset_;
+
+  // Keyed on tracer_correlation_id -> unique_correlation_id
+  InlinedHashMap<uint64_t, uint64_t> tracer_correlation_to_unique_correlation_;
+
+  // client_id/client_handle -> external_correlation_id -> events
+  InlinedHashMap<uint64_t, std::map<uint64_t, Events>> per_client_events_by_ext_correlation_;
+
+  // Keyed on tracer correlation_id, keeps track of activity records
+  // for which we haven't established the external_correlation_id yet.
+  InlinedHashMap<uint64_t, std::vector<EventRecord>> events_pending_client_mapping_;
+
+  // An offset to add to (the possibly skewed) GPU timestamps
+  // to normalize GPU timestamps with CPU timestamps
+  int64_t offset_to_add_to_gpu_timestamps_;
+};  /* class GPUTracerManager */
+
+// Base class for a GPU profiler
+template <typename TManager>
+class GPUProfilerBase : public EpProfiler {
+ protected:
+  GPUProfilerBase() = default;
+  virtual ~GPUProfilerBase() {}
+
+  void MergeEvents(std::map<uint64_t, Events>& events_to_merge, Events& events) {
+    Events merged_events;
+
+    auto event_iter = std::make_move_iterator(events.begin());
+    auto event_end = std::make_move_iterator(events.end());
+    for (auto& map_iter : events_to_merge) {
+      if (map_iter.second.empty()) {
+        continue;
+      }
+
+      auto ts = static_cast<long long>(map_iter.first);
+
+      // find the last occurrence of a matching timestamp,
+      // if one exists
+      while (event_iter != event_end &&
+             (event_iter->ts < ts ||
+              (event_iter->ts == ts &&
+               (event_iter + 1) != event_end &&
+               (event_iter + 1)->ts == ts))) {
+        merged_events.emplace_back(*event_iter);
+        ++event_iter;
+      }
+
+      bool copy_op_names = false;
+      std::string op_name;
+      std::string parent_name;
+
+      if (event_iter != event_end && event_iter->ts == ts) {
+        // We've located a parent event, copy the op_name and set
+        // this event's parent_name property to the name of the parent.
+        copy_op_names = true;
+        op_name = event_iter->args["op_name"];
+        parent_name = event_iter->name;
+        merged_events.emplace_back(*event_iter);
+        ++event_iter;
+      }
+
+      for (auto& evt : map_iter.second) {
+        if (copy_op_names) {
+          // If we have found a matching parent event,
+          // then inherit some names from the parent.
+          evt.args["op_name"] = op_name;
+          evt.args["parent_name"] = parent_name;
+        }
+      }
+
+      merged_events.insert(merged_events.end(),
+                           std::make_move_iterator(map_iter.second.begin()),
+                           std::make_move_iterator(map_iter.second.end()));
+    }
+
+    // move any remaining events
+    merged_events.insert(merged_events.end(), event_iter, event_end);
+    std::swap(events, merged_events);
+  }
+
+  uint64_t client_handle_;
+  TimePoint profiling_start_time_;
+
+ public:
+  virtual bool StartProfiling(TimePoint profiling_start_time) override {
+    auto& manager = TManager::GetInstance();
+    manager.StartLogging();
+    profiling_start_time_ = profiling_start_time;
+    return true;
+  }
+
+  virtual void EndProfiling(TimePoint start_time, Events& events) override {
+    auto& manager = TManager::GetInstance();
+    std::map<uint64_t, Events> event_map;
+    manager.Consume(client_handle_, start_time, event_map);
+    MergeEvents(event_map, events);
+  }
+
+  virtual void Start(uint64_t id) override {
+    auto& manager = TManager::GetInstance();
+    manager.PushCorrelation(client_handle_, id, profiling_start_time_);
+  }
+
+  virtual void Stop(uint64_t) override {
+    auto& manager = TManager::GetInstance();
+    manager.PopCorrelation();
+  }
+};  /* class GPUProfilerBase */
+
+// Convert a pointer to a hex string
+static inline std::string PointerToHexString(const void* ptr) {
+  std::ostringstream sstr;
+  sstr << std::hex << ptr;
+  return sstr.str();
+}
+
+}  /* end namespace profiling */
+}  /* end namespace onnxruntime */
diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/hash_combine.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/hash_combine.h
new file mode 100644
index 00000000000000..5662a329ea77f3
--- /dev/null
+++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/hash_combine.h
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+namespace onnxruntime {
+
+// Combine hash value `seed` with hash value `h`, updating `seed` in place.
+// TODO(edgchen1) find a better implementation? e.g., see a more recent version of boost::hash_combine()
+inline void HashCombineWithHashValue(size_t h, size_t& seed) {
+  seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+}
+
+// Combine hash value `seed` with the hash value of `value`, updating `seed` in place.
+// The hash value computation is specified by the `Hash` template parameter.
+template <typename T, typename Hash = std::hash<T>>
+inline void HashCombine(const T& value, size_t& seed) {
+  HashCombineWithHashValue(Hash{}(value), seed);
+}
+
+}  // namespace onnxruntime
diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/inlined_containers.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/inlined_containers.h
new file mode 100644
index 00000000000000..bd61e691a5d5dd
--- /dev/null
+++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/inlined_containers.h
@@ -0,0 +1,175 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include
+
+#include "core/common/inlined_containers_fwd.h"
+
+#ifndef DISABLE_ABSEIL
+
+#ifdef _MSC_VER
+#pragma warning(push)
+// C4127: conditional expression is constant
+#pragma warning(disable : 4127)
+// C4324: structure was padded due to alignment specifier
+// Usage of alignas causes some internal padding in places.
+#pragma warning(disable : 4324)
+#endif  // _MSC_VER
+
+#include <absl/container/flat_hash_map.h>
+#include <absl/container/flat_hash_set.h>
+
+#include <absl/container/node_hash_map.h>
+#include <absl/container/node_hash_set.h>
+
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif  // _MSC_VER
+
+#else  // DISABLE_ABSEIL
+
+#include <functional>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+
+#endif  // DISABLE_ABSEIL
+
+namespace onnxruntime {
+
+#ifndef DISABLE_ABSEIL
+// InlinedHashSet and InlinedHashMap are preferred
+// hash based containers. They store their values in the
+// buckets array that is allocated in one shot. It eliminates
+// per-node new/delete calls. Always call reserve() on any hash set/map
+// when the number of items is known in advance.
+// This does not allocate a dummy 'end' node on default construction.
+template <typename T, typename Allocator = std::allocator<T>>
+class InlinedHashSet : public absl::flat_hash_set<T,
+                                                  absl::container_internal::hash_default_hash<T>,
+                                                  absl::container_internal::hash_default_eq<T>,
+                                                  Allocator> {
+  using Base = absl::flat_hash_set<T,
+                                   absl::container_internal::hash_default_hash<T>,
+                                   absl::container_internal::hash_default_eq<T>,
+                                   Allocator>;
+
+ public:
+  using Base::Base;
+};
+
+template <typename K, typename V, typename Allocator = std::allocator<std::pair<const K, V>>>
+class InlinedHashMap : public absl::flat_hash_map<K, V,
+                                                  absl::container_internal::hash_default_hash<K>,
+                                                  absl::container_internal::hash_default_eq<K>,
+                                                  Allocator> {
+  using Base = absl::flat_hash_map<K, V,
+                                   absl::container_internal::hash_default_hash<K>,
+                                   absl::container_internal::hash_default_eq<K>,
+                                   Allocator>;
+
+ public:
+  using Base::Base;
+};
+
+// Use this hash set/map where pointer stability is required, otherwise use
+// InlinedHashSet and InlinedHashMap
+// This does not allocate a dummy 'end' node on default construction.
+// Use reserve() when the number of elements is known.
+template <typename T, typename Allocator = std::allocator<T>>
+class NodeHashSet : public absl::node_hash_set<T,
+                                               absl::container_internal::hash_default_hash<T>,
+                                               absl::container_internal::hash_default_eq<T>,
+                                               Allocator> {
+  using Base = absl::node_hash_set<T,
+                                   absl::container_internal::hash_default_hash<T>,
+                                   absl::container_internal::hash_default_eq<T>,
+                                   Allocator>;
+
+ public:
+  using Base::Base;
+};
+
+template <typename K, typename V, typename Allocator = std::allocator<std::pair<const K, V>>>
+class NodeHashMap : public absl::node_hash_map<K, V,
+                                               absl::container_internal::hash_default_hash<K>,
+                                               absl::container_internal::hash_default_eq<K>,
+                                               Allocator> {
+  using Base = absl::node_hash_map<K, V,
+                                   absl::container_internal::hash_default_hash<K>,
+                                   absl::container_internal::hash_default_eq<K>,
+                                   Allocator>;
+
+ public:
+  using Base::Base;
+};
+
+#else  // DISABLE_ABSEIL
+
+template <typename T, typename Allocator = std::allocator<T>>
+class InlinedHashSet : public std::unordered_set<T,
+                                                 std::hash<T>,
+                                                 std::equal_to<T>,
+                                                 Allocator> {
+  using Base = std::unordered_set<T,
+                                  std::hash<T>,
+                                  std::equal_to<T>,
+                                  Allocator>;
+
+ public:
+  using Base::Base;
+};
+
+template <typename K, typename V, typename Allocator = std::allocator<std::pair<const K, V>>>
+class InlinedHashMap : public std::unordered_map<K, V,
+                                                 std::hash<K>,
+                                                 std::equal_to<K>,
+                                                 Allocator> {
+  using Base = std::unordered_map<K, V,
+                                  std::hash<K>,
+                                  std::equal_to<K>,
+                                  Allocator>;
+
+ public:
+  using Base::Base;
+};
+
+// Use this hash set/map where pointer stability is required, otherwise use
+// InlinedHashSet and InlinedHashMap
+// This does not allocate a dummy 'end' node on default construction.
+// Use reserve() when the number of elements is known.
+template <typename T, typename Allocator = std::allocator<T>>
+class NodeHashSet : public std::unordered_set<T,
+                                              std::hash<T>,
+                                              std::equal_to<T>,
+                                              Allocator> {
+  using Base = std::unordered_set<T,
+                                  std::hash<T>,
+                                  std::equal_to<T>,
+                                  Allocator>;
+
+ public:
+  using Base::Base;
+};
+
+template <typename K, typename V, typename Allocator = std::allocator<std::pair<const K, V>>>
+class NodeHashMap : public std::unordered_map<K, V,
+                                              std::hash<K>,
+                                              std::equal_to<K>,
+                                              Allocator> {
+  using Base = std::unordered_map<K, V,
+                                  std::hash<K>,
+                                  std::equal_to<K>,
+                                  Allocator>;
+
+ public:
+  using Base::Base;
+};
+
+#endif  // DISABLE_ABSEIL
+
+}  // namespace onnxruntime
diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/inlined_containers_fwd.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/inlined_containers_fwd.h
new file mode 100644
index 00000000000000..21a55f9b315bc4
--- /dev/null
+++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/inlined_containers_fwd.h
@@ -0,0 +1,151 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
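Usage of the wrappers defined in inlined_containers.h above is identical under either branch; as an editorial sketch (names invented for illustration, not part of the patch), the reserve-first pattern the header's comments recommend looks like this:

// Editorial sketch, not part of the upstream headers.
#include <cstdint>
#include <string>

#include "core/common/inlined_containers.h"

void BuildOpNameTable() {
  // Reserve up front, as the header advises, so the one-shot bucket array
  // is allocated once and no rehashing occurs while inserting.
  onnxruntime::InlinedHashMap<uint64_t, std::string> op_names;
  op_names.reserve(2);
  op_names.try_emplace(1, "Conv");
  op_names.try_emplace(2, "Gemm");

  onnxruntime::InlinedHashSet<uint64_t> seen_ids;
  seen_ids.reserve(2);
  seen_ids.insert(1);
}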
+ +#pragma once + +#include +#include + +#ifndef DISABLE_ABSEIL +#ifdef _MSC_VER +#pragma warning(push) +// C4127: conditional expression is constant +#pragma warning(disable : 4127) +// C4324: structure was padded due to alignment specifier +// Usage of alignas causes some internal padding in places. +#pragma warning(disable : 4324) +#else +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=102329#c2 +#if !defined(__clang__) && defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#endif +#endif // _MSC_VER + +#include + +#ifdef _MSC_VER +#pragma warning(pop) +#else +#if !defined(__clang__) && defined(__GNUC__) +#pragma GCC diagnostic pop +#endif +#endif // _MSC_VER + +#else + +#include + +#endif // DISABLE_ABSEIL + +// Forward declarations for contexts where abseil can not be compiled and +// not really needed but we want to have it in the headers that are included +// e.g. CUDA 10 and .CU files +// InlinedVector seems to be fine with old CUDA + +//===- llvm/ADT/SmallVector.h - 'Normally small' vectors --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// This file contains code and comments derived from llvm/ADT/SmallVector.h +// +// Specifically CalculateInlinedVectorDefaultInlinedElements() template is derived from +// CalculateSmallVectorDefaultInlinedElements() and its comments. + +namespace onnxruntime { +#ifndef DISABLE_ABSEIL +/// Inspired by LLVM SmallVector with ONNX Runtime adjustments for abseil. +/// https://github.com/llvm/llvm-project/blob/a85b37d0ca819776c6034c2dbda2b21e54e3393a/llvm/include/llvm/ADT/SmallVector.h#L1128-L1179 +/// +/// Helper class for calculating the default number of inline elements for +/// `InlinedVector`. +/// This produces the following on MSVC x64 +/// int8_t -> 41 +// int16_t -> 21 +// int32_t -> 11 +// int64_t -> 6 +// std::string 40 -> 1 +template +struct CalculateInlinedVectorDefaultInlinedElements { + // Parameter controlling the default number of inlined elements + // for `InlinedVector`. + // + // The default number of inlined elements ensures that + // 1. There is at least one inlined element. + // 2. `sizeof(InlinedVector) <= kPreferredInlinedVectorSizeof` unless + // it contradicts 1. + static constexpr size_t kPreferredInlinedVectorSizeof = 64; + + // Largest allowed element size for default element count calculation. + static constexpr size_t kElementSizeCutoff = 256; + + // static_assert that sizeof(T) is not "too big". + // + // Because the InlinedVector must have at least one inlined element, it is possible + // for an arbitrarily large inlined element to allocate an arbitrarily large + // amount of inline storage. So we want to call attention to these cases and + // make sure that users are making an intentional decision if they request a lot of inline storage. + // + // We want this assertion to trigger in pathological cases, but otherwise + // not be too easy to hit. To accomplish that, the cutoff is actually somewhat + // larger than kPreferredInlinedVectorSizeof (otherwise, + // `InlinedVector>` would be one easy way to trip it, and that + // pattern seems useful in practice). + // + // One wrinkle is that this assertion is in theory non-portable, since + // sizeof(absl::InlinedVector) is in general platform-dependent. 
However, we don't expect this + // to be much of an issue, because most LLVM development happens on 64-bit + // hosts, and therefore sizeof(T) is expected to *decrease* when compiled for + // 32-bit hosts, dodging the issue. The reverse situation, where development + // happens on a 32-bit host and then fails due to sizeof(T) *increasing* on a + // 64-bit host, is expected to be very rare. + static_assert( + sizeof(T) <= kElementSizeCutoff, + "You are trying to use a default number of inlined elements for " + "`InlinedVector` but `sizeof(T)` is really big! Please use an " + "explicit number of inlined elements with `InlinedVector` to make " + "sure you really want that much inline storage."); + + // Discount the size of the header itself when calculating the maximum inline + // bytes. + static constexpr size_t InlinedVectorHeaderSize = sizeof(absl::InlinedVector) - sizeof(T); + static constexpr size_t PreferredInlineBytes = kPreferredInlinedVectorSizeof - InlinedVectorHeaderSize; + static constexpr size_t NumElementsThatFit = PreferredInlineBytes / sizeof(T); + static constexpr size_t value = + NumElementsThatFit == 0 ? 1 : NumElementsThatFit; +}; + +// Use InlinedVector for small arrays that can fit on a stack with a default +// value pre-calculated. +// Use TensorShapeVector for shapes. +template ::value, + typename Allocator = std::allocator> +using InlinedVector = absl::InlinedVector; + +#else + +template > +using InlinedVector = std::vector; + +#endif // DISABLE_ABSEIL + +template > +class InlinedHashSet; + +template >> +class InlinedHashMap; + +template > +class NodeHashSet; + +template >> +class NodeHashMap; +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/capture.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/capture.h new file mode 100644 index 00000000000000..13d3a3ad17aff5 --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/capture.h @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/common/common.h" +#include "core/common/code_location.h" +#include "core/common/logging/severity.h" + +namespace onnxruntime { +namespace logging { + +class Logger; +enum class DataType; + +/** + Class to capture the details of a log message. +*/ +class Capture { + public: + /** + Initializes a new instance of the Capture class. + @param logger The logger. + @param severity The severity. + @param category The category. + @param dataType Type of the data. + @param location The file location the log message is coming from. + */ + Capture(const Logger& logger, logging::Severity severity, const char* category, + logging::DataType dataType, const CodeLocation& location) + : logger_{&logger}, severity_{severity}, category_{category}, data_type_{dataType}, location_{location} { + } + + /** + The stream that can capture the message via operator<<. + @returns Output stream. + */ + std::ostream& Stream() noexcept { + return stream_; + } + +#ifdef _MSC_VER +// add SAL annotation for printf format string. requires Code Analysis to run to validate usage. +#define msvc_printf_check _Printf_format_string_ +#define __attribute__(x) // Disable for MSVC. Supported by GCC and CLang. +#else +#define msvc_printf_check +#endif + + /** + Captures a printf style log message. + @param name="format">The printf format. 
+ @param name="">Arguments to the printf format if needed. + @remarks + A maximum of 2K of output will be captured currently. + Non-static method, so 'this' is implicit first arg, and we use format(printf(2,3) + */ + void CapturePrintf(msvc_printf_check const char* format, ...) __attribute__((format(printf, 2, 3))); + + /** + Process a printf style log message. + @param format The printf format. + @param ... Arguments to the printf format if needed. + @remarks + A maximum of 2K of output will be captured currently. + Note: As va_list is 'char *', we have to disambiguate this from CapturePrintf + so that something like "One string: %s", "the string" does not consider "the string" + to be the va_list. + */ + void ProcessPrintf(msvc_printf_check const char* format, va_list args); + + logging::Severity Severity() const noexcept { + return severity_; + } + + char SeverityPrefix() const noexcept { + // Carefully setup so severity_ is a valid index + GSL_SUPPRESS(bounds.2) { + return logging::SEVERITY_PREFIX[static_cast(severity_)]; + } + } + + const char* Category() const noexcept { + return category_; + } + + logging::DataType DataType() const noexcept { + return data_type_; + } + + const CodeLocation& Location() const noexcept { + return location_; + } + + std::string Message() const noexcept { + return stream_.str(); + } + + ~Capture(); + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Capture); + + const Logger* logger_; + const logging::Severity severity_; + const char* category_; + const logging::DataType data_type_; + const CodeLocation location_; + + std::ostringstream stream_; +}; +} // namespace logging +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/isink.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/isink.h new file mode 100644 index 00000000000000..fd011e71611fc8 --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/isink.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/common/logging/logging.h" +#include "core/common/logging/sink_types.h" + +namespace onnxruntime { +namespace logging { +class ISink { + public: + explicit ISink(SinkType type = SinkType::BaseSink) : type_(type) {} + + SinkType GetType() const { return type_; } + + /** + Sends the message to the sink. + @param timestamp The timestamp. + @param logger_id The logger identifier. + @param message The captured message. + */ + void Send(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) { + SendImpl(timestamp, logger_id, message); + } + + /** + Sends a Profiling Event Record to the sink. + @param Profiling Event Record + */ + virtual void SendProfileEvent(profiling::EventRecord&) const {}; + + virtual ~ISink() = default; + + private: + SinkType type_; + + // Make Code Analysis happy by disabling all for now. Enable as needed. 
+ ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ISink); + + virtual void SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) = 0; +}; +} // namespace logging +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/logging.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/logging.h new file mode 100644 index 00000000000000..ab2c476f2975a8 --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/logging.h @@ -0,0 +1,421 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/profiler_common.h" +#include "core/common/logging/capture.h" +#include "core/common/logging/macros.h" +#include "core/common/logging/severity.h" +#include "core/common/logging/sink_types.h" +#include "date/date.h" + +/* + + Logging overview and expected usage: + + At program startup: + * Create one or more ISink instances. If multiple, combine using composite_sink. + * Create a LoggingManager instance with the sink/s with is_default_instance set to true + * Only one instance should be created in this way, and it should remain valid for + until the program no longer needs to produce log output. + + You can either use the static default Logger which LoggingManager will create when constructed + via LoggingManager::DefaultLogger(), or separate Logger instances each with different log ids + via LoggingManager::CreateLogger. + + The log id is passed to the ISink instance with the sink determining how the log id is used + in the output. + + LoggingManager + * creates the Logger instances used by the application + * provides a static default logger instance + * owns the log sink instance + * applies checks on severity and output of user data + + The log macros create a Capture instance to capture the information to log. + If the severity and/or user filtering settings would prevent logging, no evaluation + of the log arguments will occur, so no performance cost beyond the severity and user + filtering check. + + A sink can do further filter as needed. + +*/ + +namespace onnxruntime { + +namespace logging { + +using Timestamp = std::chrono::time_point; + +// C++20 has operator<< in std::chrono for Timestamp type but mac builds need additional checks +// to ensure usage is valid. +// TODO: As we enable C++20 on other platforms we may need similar checks. +// define a temporary value to determine whether to use the std::chrono or date implementation. +#define ORT_USE_CXX20_STD_CHRONO __cplusplus >= 202002L + +// Apply constraints for mac builds +#if __APPLE__ +#include + +// Catalyst check must be first as it has both TARGET_OS_MACCATALYST and TARGET_OS_MAC set +#if TARGET_OS_MACCATALYST +// maccatalyst requires version 16.3 +#if (defined(__IPHONE_OS_VERSION_MIN_REQUIRED) && __IPHONE_OS_VERSION_MIN_REQUIRED < 160300) +#undef ORT_USE_CXX20_STD_CHRONO +#endif + +#elif TARGET_OS_MAC +// Xcode added support for C++20's std::chrono::operator<< in SDK version 14.4, +// but the target macOS version must also be >= 13.3 for it to be used. 
+#if (defined(__MAC_OS_X_VERSION_MAX_ALLOWED) && __MAC_OS_X_VERSION_MAX_ALLOWED < 140400) || \ + (defined(__MAC_OS_X_VERSION_MIN_REQUIRED) && __MAC_OS_X_VERSION_MIN_REQUIRED < 130300) +#undef ORT_USE_CXX20_STD_CHRONO +#endif + +#endif +#endif // __APPLE__ + +#if ORT_USE_CXX20_STD_CHRONO +namespace timestamp_ns = std::chrono; +#else +namespace timestamp_ns = ::date; +#endif + +#undef ORT_USE_CXX20_STD_CHRONO + +#ifndef NDEBUG +ORT_ATTRIBUTE_UNUSED static bool vlog_enabled = true; // Set directly based on your needs. +#else +constexpr bool vlog_enabled = false; // no VLOG output +#endif + +enum class DataType { + SYSTEM = 0, ///< System data. + USER = 1 ///< Contains potentially sensitive user data. +}; + +// Internal log categories. +// Logging interface takes const char* so arbitrary values can also be used. +struct Category { + static const char* onnxruntime; ///< General output + static const char* System; ///< Log output regarding interactions with the host system + // TODO: What other high level categories are meaningful? Model? Optimizer? Execution? +}; + +/// +/// ORT TraceLogging keywords for categories of dynamic logging enablement +/// +enum class ORTTraceLoggingKeyword : uint64_t { + Session = 0x1, // ORT Session TraceLoggingWrite + Logs = 0x2, // LOGS() Macro ORT logs. Pair with an appropriate level depending on detail required + Reserved1 = 0x4, // Reserved if we want to add some specific sub-categories instead of just LOGS() or other uses + Reserved2 = 0x8, + Reserved3 = 0x10, + Reserved4 = 0x20, + Reserved5 = 0x40, + Reserved6 = 0x80, + Profiling = 0x100 // Enables profiling. At higher levels >5 can impact inference performance +}; + +class ISink; +class Logger; +class Capture; + +/// +/// The logging manager. +/// Owns the log sink and potentially provides a default Logger instance. +/// Provides filtering based on a minimum LogSeverity level, and of messages with DataType::User if enabled. +/// +class LoggingManager final { + public: + enum InstanceType { + Default, ///< Default instance of LoggingManager that should exist for the lifetime of the program + Temporal ///< Temporal instance. CreateLogger(...) should be used, however DefaultLogger() will NOT be provided via this instance. + }; + + /** + Initializes a new instance of the LoggingManager class. + @param sink The sink to write to. Use CompositeSink if you need to write to multiple places. + @param default_min_severity The default minimum severity. Messages with lower severity will be ignored unless + overridden in CreateLogger. + @param default_filter_user_data If set to true ignore messages with DataType::USER unless overridden in CreateLogger. + @param instance_type If InstanceType::Default, this is the default instance of the LoggingManager + and is expected to exist for the lifetime of the program. + It creates and owns the default logger that calls to the static DefaultLogger method return. + @param default_logger_id Logger Id to use for the default logger. nullptr/ignored if instance_type == Temporal. + @param default_max_vlog_level Default maximum level for VLOG messages to be created unless overridden in CreateLogger. + Requires a severity of kVERBOSE for VLOG messages to be logged. 
+ */ + LoggingManager(std::unique_ptr sink, Severity default_min_severity, bool default_filter_user_data, + InstanceType instance_type, + const std::string* default_logger_id = nullptr, + int default_max_vlog_level = -1); + + /** + Creates a new logger instance which will use the provided logger_id and default severity and vlog levels. + @param logger_id The log identifier. + @returns A new Logger instance that the caller owns. + */ + std::unique_ptr CreateLogger(const std::string& logger_id); + + /** + Creates a new logger instance which will use the provided logger_id, severity and vlog levels. + @param logger_id The log identifier. + @param min_severity The minimum severity. Requests to create messages with lower severity will be ignored. + @param filter_user_data If set to true ignore messages with DataType::USER. + @param max_vlog_level Maximum level for VLOG messages to be created. + @returns A new Logger instance that the caller owns. + */ + std::unique_ptr CreateLogger(const std::string& logger_id, + Severity min_severity, bool filter_user_data, int max_vlog_level = -1); + + /** + Gets the default logger instance if set. Throws if no default logger is currently registered. + @remarks + Creating a LoggingManager instance with is_default_instance == true registers a default logger. + Note that the default logger is only valid until the LoggerManager that registered it is destroyed. + @returns The default logger if available. + */ + static const Logger& DefaultLogger(); + + /** + Return a boolean indicating if the default logger has been initialized + */ + static bool HasDefaultLogger() { return nullptr != s_default_logger_; } + + /** + Gets the default instance of the LoggingManager. + */ + static LoggingManager* GetDefaultInstance(); + + /** + Removes a Sink if one is present + */ + void RemoveSink(SinkType sinkType); + + /** + Adds a Sink to the current sink creating a CompositeSink if necessary + Sinks types must be unique + @param severity The severity level for the new Sink + */ + bool AddSinkOfType(SinkType sinkType, std::function()> sinkFactory, logging::Severity severity); + + /** + Change the minimum severity level for log messages to be output by the default logger. + @param severity The severity. + */ + static void SetDefaultLoggerSeverity(Severity severity); + + /** + Change the maximum verbosity level for log messages to be output by the default logger. + @remarks + To activate the verbose log, the logger severity must also be set to kVERBOSE. + @param vlog_level The verbosity level. + */ + static void SetDefaultLoggerVerbosity(int vlog_level); + + /** + Logs a FATAL level message and creates an exception that can be thrown with error information. + @param category The log category. + @param location The location the log message was generated. + @param format_str The printf format string. + @param ... The printf arguments. + @returns A new Logger instance that the caller owns. + */ + static std::exception LogFatalAndCreateException(const char* category, + const CodeLocation& location, + const char* format_str, ...); + + /** + Logs the message using the provided logger id. + @param logger_id The log identifier. + @param message The log message. + */ + void Log(const std::string& logger_id, const Capture& message) const; + + /** + Sends a Profiling Event Record to the sink. 
+ @param Profiling Event Record + */ + void SendProfileEvent(profiling::EventRecord& eventRecord) const; + ~LoggingManager(); + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(LoggingManager); + + Timestamp GetTimestamp() const noexcept; + void CreateDefaultLogger(const std::string& logger_id); + + std::unique_ptr sink_; +#ifdef _WIN32 + mutable std::mutex sink_mutex_; +#endif + Severity default_min_severity_; + const bool default_filter_user_data_; + const int default_max_vlog_level_; + bool owns_default_logger_; + + static Logger* s_default_logger_; + + struct Epochs { + const std::chrono::time_point high_res; + const std::chrono::time_point system; + const std::chrono::minutes localtime_offset_from_utc; + }; + + static const Epochs& GetEpochs() noexcept; +}; + +/** + Logger provides a per-instance log id. Everything else is passed back up to the LoggingManager +*/ +class Logger { + public: + /** + Initializes a new instance of the Logger class. + @param loggingManager The logging manager. + @param id The identifier for messages coming from this Logger. + @param severity Minimum severity for messages to be created and logged. + @param filter_user_data Should USER data be filtered from output. + @param vlog_level Minimum level for VLOG messages to be created. Note that a severity of kVERBOSE must be provided + for VLOG messages to be logged. + */ + Logger(const LoggingManager& loggingManager, std::string id, + Severity severity, bool filter_user_data, int vlog_level) + : logging_manager_{&loggingManager}, + id_{id}, + min_severity_{severity}, + filter_user_data_{filter_user_data}, + max_vlog_level_{vlog_level} { + } + + /** + Get the minimum severity level for log messages to be output. + @returns The severity. + */ + Severity GetSeverity() const noexcept { return min_severity_; } + + /** + Change the minimum severity level for log messages to be output. + @param severity The severity. + */ + void SetSeverity(Severity severity) noexcept { min_severity_ = severity; } + + /** + Change the maximum verbosity level for log messages to be output. + @remarks + To activate the verbose log, the logger severity must also be set to kVERBOSE. + @param vlog_level The verbosity. + */ + void SetVerbosity(int vlog_level) noexcept { max_vlog_level_ = vlog_level; } + + /** + Check if output is enabled for the provided LogSeverity and DataType values. + @param severity The severity. + @param data_type Type of the data. + @returns True if a message with these values will be logged. + */ + bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept { + return (severity >= min_severity_ && (data_type != DataType::USER || !filter_user_data_)); + } + + /** + Return the maximum VLOG level allowed. Disabled unless logging VLOG messages + */ + int VLOGMaxLevel() const noexcept { + return min_severity_ > Severity::kVERBOSE ? -1 : max_vlog_level_; + } + + /** + Logs the captured message. + @param message The log message. + */ + void Log(const Capture& message) const { + logging_manager_->Log(id_, message); + } + + /** + Sends a Profiling Event Record to the sink. 
+ @param Profiling Event Record + */ + void SendProfileEvent(profiling::EventRecord& eventRecord) const { + logging_manager_->SendProfileEvent(eventRecord); + } + + private: + const LoggingManager* logging_manager_; + const std::string id_; + Severity min_severity_; + const bool filter_user_data_; + int max_vlog_level_; +}; + +inline const Logger& LoggingManager::DefaultLogger() { + if (s_default_logger_ == nullptr) { + // fail early for attempted misuse. don't use logging macros as we have no logger. + ORT_THROW("Attempt to use DefaultLogger but none has been registered."); + } + + return *s_default_logger_; +} + +inline void LoggingManager::SetDefaultLoggerSeverity(Severity severity) { + if (s_default_logger_ == nullptr) { + // fail early for attempted misuse. don't use logging macros as we have no logger. + ORT_THROW("Attempt to use DefaultLogger but none has been registered."); + } + + s_default_logger_->SetSeverity(severity); +} + +inline void LoggingManager::SetDefaultLoggerVerbosity(int vlog_level) { + if (s_default_logger_ == nullptr) { + // fail early for attempted misuse. don't use logging macros as we have no logger. + ORT_THROW("Attempt to use DefaultLogger but none has been registered."); + } + + s_default_logger_->SetVerbosity(vlog_level); +} + +inline Timestamp LoggingManager::GetTimestamp() const noexcept { + static const Epochs& epochs = GetEpochs(); + + const auto high_res_now = std::chrono::high_resolution_clock::now(); + return std::chrono::time_point_cast( + epochs.system + (high_res_now - epochs.high_res) + epochs.localtime_offset_from_utc); +} + +/** + Return the current thread id. +*/ +unsigned int GetThreadId(); + +/** + Return the current process id. +*/ +unsigned int GetProcessId(); + +/** + If the ONNXRuntimeTraceLoggingProvider ETW Provider is enabled, then adds to the existing logger. +*/ +std::unique_ptr EnhanceSinkWithEtw(std::unique_ptr existingSink, logging::Severity originalSeverity, + logging::Severity etwSeverity); + +/** + If the ONNXRuntimeTraceLoggingProvider ETW Provider is enabled, then can override the logging level. + But this overrided level only applies to the ETW sink. The original logger(s) retain their original logging level +*/ +Severity OverrideLevelWithEtw(Severity originalSeverity); + +} // namespace logging +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/macros.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/macros.h new file mode 100644 index 00000000000000..18764460cba76e --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/macros.h @@ -0,0 +1,280 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +// NOTE: Don't include this file directly. Include logging.h + +#define CREATE_MESSAGE(logger, severity, category, datatype) \ + ::onnxruntime::logging::Capture(logger, ::onnxruntime::logging::Severity::k##severity, category, datatype, ORT_WHERE) + +/* + Both printf and stream style logging are supported. + Not that printf currently has a 2K limit to the message size. + + LOGS_* macros are for stream style + LOGF_* macros are for printf style + + The Message class captures the log input, and pushes it through the logger in its destructor. + + Use the *FATAL* macros if you want a Severity::kFatal message to also throw. + + There are a few variants to minimize the length of the macro name required in the calling code. 
+ They are optimized so the shortest names are for the (expected) most common usage. This can be + tweaked if needed. + + Explicit logger vs LoggingManager::DefaulLogger() + Default is for a logger instance to be explicitly passed in. + The logger instance provides an identifier so that log messages from different runs can be separated. + + Variants with DEFAULT in the macro name use the default logger provided by logging manager. This is + static so accessible from any code, provided a LoggingManager instance created with InstanceType::Default + exists somewhere. See logging.h for further explanation of the expected setup. + + DataType + Default uses DataType::SYSTEM. + + Variants with USER in the macro name use DataType::USER. This is data that could be PII, and may need to + be filtered from output. LoggingManager applies this filtering. + + Category + Default category is ::onnxruntime::Logging::Category::onnxruntime. + + If you wish to provide a different category, use variants with CATEGORY in the macro name + +*/ + +/** + * Note: + * The stream style logging macros (something like `LOGS() << message`) are designed to be appended to. + * Normally, we can isolate macro code in a separate scope (e.g., `do {...} while(0)`), but here we need the macro code + * to interact with subsequent code (i.e., the values to log). + * + * When an unisolated conditional is involved, extra care needs to be taken to avoid unexpected parsing behavior. + * For example: + * + * if (enabled) + * Capture().Stream() + * + * is more direct, but + * + * if (!enabled) { + * } else Capture().Stream() + * + * ensures that the `if` does not unintentionally associate with a subsequent `else`. + */ + +// Logging with explicit category + +// iostream style logging. Capture log info in Message, and push to the logger in ~Message. +#define LOGS_CATEGORY(logger, severity, category) \ + if (!(logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, \ + ::onnxruntime::logging::DataType::SYSTEM)) { \ + /* do nothing */ \ + } else \ + CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::SYSTEM).Stream() + +#define LOGS_USER_CATEGORY(logger, severity, category) \ + if (!(logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, \ + ::onnxruntime::logging::DataType::USER)) { \ + /* do nothing */ \ + } else \ + CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::USER).Stream() + +// printf style logging. Capture log info in Message, and push to the logger in ~Message. +#define LOGF_CATEGORY(logger, severity, category, format_str, ...) \ + do { \ + if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, \ + ::onnxruntime::logging::DataType::SYSTEM)) \ + CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::SYSTEM) \ + .CapturePrintf(format_str, ##__VA_ARGS__); \ + } while (0) + +#define LOGF_USER_CATEGORY(logger, severity, category, format_str, ...) 
\ + do { \ + if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, \ + ::onnxruntime::logging::DataType::USER)) \ + CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::USER) \ + .CapturePrintf(format_str, ##__VA_ARGS__); \ + } while (0) + +// Logging with category of "onnxruntime" + +#define LOGS(logger, severity) \ + LOGS_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime) + +#define LOGS_USER(logger, severity) \ + LOGS_USER_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime) + +// printf style logging. Capture log info in Message, and push to the logger in ~Message. +#define LOGF(logger, severity, format_str, ...) \ + LOGF_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) + +#define LOGF_USER(logger, severity, format_str, ...) \ + LOGF_USER_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) + +/* + Macros that use the default logger. + A LoggingManager instance must be currently valid for the default logger to be available. +*/ + +// Logging with explicit category + +#define LOGS_DEFAULT_CATEGORY(severity, category) \ + LOGS_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category) + +#define LOGS_USER_DEFAULT_CATEGORY(severity, category) \ + LOGS_USER_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category) + +#define LOGF_DEFAULT_CATEGORY(severity, category, format_str, ...) \ + LOGF_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category, format_str, ##__VA_ARGS__) + +#define LOGF_USER_DEFAULT_CATEGORY(severity, category, format_str, ...) \ + LOGF_USER_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category, format_str, ##__VA_ARGS__) + +// Logging with category of "onnxruntime" + +#define LOGS_DEFAULT(severity) \ + LOGS_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime) + +#define LOGS_USER_DEFAULT(severity) \ + LOGS_USER_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime) + +#define LOGF_DEFAULT(severity, format_str, ...) \ + LOGF_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) + +#define LOGF_USER_DEFAULT(severity, format_str, ...) \ + LOGF_USER_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) + +/* + Conditional logging +*/ + +// Logging with explicit category + +#define LOGS_CATEGORY_IF(boolean_expression, logger, severity, category) \ + if (!((boolean_expression) == true)) { \ + /* do nothing */ \ + } else \ + LOGS_CATEGORY(logger, severity, category) + +#define LOGS_DEFAULT_CATEGORY_IF(boolean_expression, severity, category) \ + if (!((boolean_expression) == true)) { \ + /* do nothing */ \ + } else \ + LOGS_DEFAULT_CATEGORY(severity, category) + +#define LOGS_USER_CATEGORY_IF(boolean_expression, logger, severity, category) \ + if (!((boolean_expression) == true)) { \ + /* do nothing */ \ + } else \ + LOGS_USER_CATEGORY(logger, severity, category) + +#define LOGS_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, category) \ + if (!((boolean_expression) == true)) { \ + /* do nothing */ \ + } else \ + LOGS_USER_DEFAULT_CATEGORY(severity, category) + +#define LOGF_CATEGORY_IF(boolean_expression, logger, severity, category, format_str, ...) 
\ + do { \ + if ((boolean_expression) == true) LOGF_CATEGORY(logger, severity, category, format_str, ##__VA_ARGS__); \ + } while (0) + +#define LOGF_DEFAULT_CATEGORY_IF(boolean_expression, severity, category, format_str, ...) \ + do { \ + if ((boolean_expression) == true) LOGF_DEFAULT_CATEGORY(severity, category, format_str, ##__VA_ARGS__); \ + } while (0) + +#define LOGF_USER_CATEGORY_IF(boolean_expression, logger, severity, category, format_str, ...) \ + do { \ + if ((boolean_expression) == true) LOGF_USER_CATEGORY(logger, severity, category, format_str, ##__VA_ARGS__); \ + } while (0) + +#define LOGF_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, category, format_str, ...) \ + do { \ + if ((boolean_expression) == true) LOGF_USER_DEFAULT_CATEGORY(severity, category, format_str, ##__VA_ARGS__); \ + } while (0) + +// Logging with category of "onnxruntime" + +#define LOGS_IF(boolean_expression, logger, severity) \ + LOGS_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime) + +#define LOGS_DEFAULT_IF(boolean_expression, severity) \ + LOGS_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime) + +#define LOGS_USER_IF(boolean_expression, logger, severity) \ + LOGS_USER_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime) + +#define LOGS_USER_DEFAULT_IF(boolean_expression, severity) \ + LOGS_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime) + +#define LOGF_IF(boolean_expression, logger, severity, format_str, ...) \ + LOGF_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) + +#define LOGF_DEFAULT_IF(boolean_expression, severity, format_str, ...) \ + LOGF_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) + +#define LOGF_USER_IF(boolean_expression, logger, severity, format_str, ...) \ + LOGF_USER_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime, \ + format_str, ##__VA_ARGS__) + +#define LOGF_USER_DEFAULT_IF(boolean_expression, severity, format_str, ...) \ + LOGF_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime, \ + format_str, ##__VA_ARGS__) + +/* + Debug verbose logging of caller provided level. + Disabled in Release builds. + Use the _USER variants for VLOG statements involving user data that may need to be filtered. +*/ +#ifndef NDEBUG +#define VLOGS(logger, level) \ + if (!(::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel())) { \ + /* do nothing */ \ + } else \ + LOGS_CATEGORY(logger, VERBOSE, "VLOG" #level) + +#define VLOGS_USER(logger, level) \ + if (!(::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel())) { \ + /* do nothing */ \ + } else \ + LOGS_USER_CATEGORY(logger, VERBOSE, "VLOG" #level) + +#define VLOGF(logger, level, format_str, ...) \ + do { \ + if (::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel()) \ + LOGF_CATEGORY(logger, VERBOSE, "VLOG" #level, format_str, ##__VA_ARGS__); \ + } while (0) + +#define VLOGF_USER(logger, level, format_str, ...) \ + do { \ + if (::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel()) \ + LOGF_USER_CATEGORY(logger, VERBOSE, "VLOG" #level, format_str, ##__VA_ARGS__); \ + } while (0) +#else +// Disabled in Release builds. 
+#define VLOGS(logger, level) \ + if constexpr (true) { \ + } else \ + LOGS_CATEGORY(logger, VERBOSE, "VLOG" #level) +#define VLOGS_USER(logger, level) \ + if constexpr (true) { \ + } else \ + LOGS_USER_CATEGORY(logger, VERBOSE, "VLOG" #level) +#define VLOGF(logger, level, format_str, ...) +#define VLOGF_USER(logger, level, format_str, ...) +#endif + +// Default logger variants +#define VLOGS_DEFAULT(level) \ + VLOGS(::onnxruntime::logging::LoggingManager::DefaultLogger(), level) + +#define VLOGS_USER_DEFAULT(level) \ + VLOGS_USER(::onnxruntime::logging::LoggingManager::DefaultLogger(), level) + +#define VLOGF_DEFAULT(level, format_str, ...) \ + VLOGF(::onnxruntime::logging::LoggingManager::DefaultLogger(), level, format_str, ##__VA_ARGS__) + +#define VLOGF_USER_DEFAULT(level, format_str, ...) \ + VLOGF_USER(::onnxruntime::logging::LoggingManager::DefaultLogger(), level, format_str, ##__VA_ARGS__) diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/severity.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/severity.h new file mode 100644 index 00000000000000..e43f192eb1807e --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/severity.h @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace logging { +// mild violation of naming convention. the 'k' lets us use token concatenation in the macro +// ::onnxruntime::Logging::Severity::k##severity. It's not legal to have ::onnxruntime::Logging::Severity::##severity +// the uppercase makes the LOG macro usage look as expected for passing an enum value as it will be LOGS(logger, ERROR) +enum class Severity { + kVERBOSE = 0, + kINFO = 1, + kWARNING = 2, + kERROR = 3, + kFATAL = 4 +}; + +constexpr const char* SEVERITY_PREFIX = "VIWEF"; + +} // namespace logging +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/sink_types.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/sink_types.h new file mode 100644 index 00000000000000..a99b0fca58d9d1 --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/logging/sink_types.h @@ -0,0 +1,11 @@ +#pragma once + +namespace onnxruntime { +namespace logging { +enum class SinkType { + BaseSink, + CompositeSink, + EtwSink +}; +} // namespace logging +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/make_string.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/make_string.h new file mode 100644 index 00000000000000..6148ef63e7264e --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/make_string.h @@ -0,0 +1,132 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+// Portions Copyright (c) Microsoft Corporation
+
+#pragma once
+
+#include <locale>
+#include <sstream>
+#include <string>
+#include <type_traits>
+
+namespace onnxruntime {
+
+namespace detail {
+
+inline void MakeStringImpl(std::ostringstream& /*ss*/) noexcept {
+}
+
+template <typename T>
+inline void MakeStringImpl(std::ostringstream& ss, const T& t) noexcept {
+  ss << t;
+}
+
+template <typename T, typename... Args>
+inline void MakeStringImpl(std::ostringstream& ss, const T& t, const Args&... args) noexcept {
+  MakeStringImpl(ss, t);
+  MakeStringImpl(ss, args...);
+}
+
+// see MakeString comments for explanation of why this is necessary
+template <typename... Args>
+inline std::string MakeStringImpl(const Args&... args) noexcept {
+  std::ostringstream ss;
+  MakeStringImpl(ss, args...);
+  return ss.str();
+}
+
+template <typename... Args>
+inline std::string MakeStringWithClassicLocaleImpl(const Args&... args) noexcept {
+  std::ostringstream ss;
+  ss.imbue(std::locale::classic());
+  MakeStringImpl(ss, args...);
+  return ss.str();
+}
+
+//
+// Infrastructure to convert char[n] to char* to reduce binary size
+//
+
+// default is to leave the type as is
+template <class T>
+struct if_char_array_make_ptr {
+  using type = T;
+};
+
+// specialization that matches an array reference, which is what the char array from a string literal
+// used in a call to MakeString will be.
+// if the type is a char[n] array we 'decay' it to a char* so that the usages can be folded.
+template <class T, size_t N>
+struct if_char_array_make_ptr<T (&)[N]> {
+  // remove a single extent (T[x] -> T, but T[x][y] -> T[y]) so we only match char[x],
+  // and get the type name without the 'const' so both 'const char (&)[n]' and 'char (&)[n]' are matched.
+  using element_type = typename std::remove_const<typename std::remove_extent<T>::type>::type;
+  using type = typename std::conditional<std::is_same<element_type, char>::value, T*, T (&)[N]>::type;
+};
+
+// helper to make usage simpler in MakeString
+template <class T>
+using if_char_array_make_ptr_t = typename if_char_array_make_ptr<T>::type;
+}  // namespace detail
+
+/**
+ * Makes a string by concatenating string representations of the arguments.
+ * This version uses the current locale.
+ */
+template <typename... Args>
+inline std::string MakeString(const Args&... args) {
+  // We need to update the types from the MakeString template instantiation to decay any char[n] to char*.
+  // e.g. MakeString("in", "out") goes from MakeString<char[3], char[4]> to MakeStringImpl<char*, char*>
+  // so that MakeString("out", "in") will also match MakeStringImpl<char*, char*> instead of requiring
+  // MakeStringImpl<char[4], char[3]>.
+  //
+  // We have to do the type processing before any actual work, so this function purely implements the type processing.
+  // If we do not do it this way we do not get the full binary size reduction.
+  //
+  // See https://stackoverflow.com/a/29418212/684911 for overall details of the approach, but note it does not cover
+  // the need to do the type processing as a separate step.
+
+  return detail::MakeStringImpl(detail::if_char_array_make_ptr_t<Args const&>(args)...);
+}
+
+/**
+ * Makes a string by concatenating string representations of the arguments.
+ * This version uses std::locale::classic().
+ */
+template <typename... Args>
+inline std::string MakeStringWithClassicLocale(const Args&... args) {
+  return detail::MakeStringWithClassicLocaleImpl(detail::if_char_array_make_ptr_t<Args const&>(args)...);
+}
+
+// MakeString versions for already-a-string types.
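+// Usage sketch (illustrative, not part of the upstream header; `n` is a
+// hypothetical local):
+//   int n = 3;
+//   std::string s = onnxruntime::MakeString("n = ", n, "!");  // "n = 3!"
+// The overloads below short-circuit the already-a-string cases so that a lone
+// std::string or const char* argument does not pay for an std::ostringstream.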
+
+inline std::string MakeString(const std::string& str) {
+  return str;
+}
+
+inline std::string MakeString(const char* cstr) {
+  return cstr;
+}
+
+inline std::string MakeStringWithClassicLocale(const std::string& str) {
+  return str;
+}
+
+inline std::string MakeStringWithClassicLocale(const char* cstr) {
+  return cstr;
+}
+
+}  // namespace onnxruntime
diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/narrow.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/narrow.h
new file mode 100644
index 00000000000000..49dfbf3c459537
--- /dev/null
+++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/narrow.h
@@ -0,0 +1,69 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+// onnxruntime::narrow() is like gsl::narrow() but it is also available when exceptions are disabled.
+
+#if !defined(ORT_NO_EXCEPTIONS)
+
+#include "gsl/narrow"
+
+namespace onnxruntime {
+using gsl::narrow;
+}  // namespace onnxruntime
+
+#else  // ^^ !defined(ORT_NO_EXCEPTIONS) ^^ / vv defined(ORT_NO_EXCEPTIONS) vv
+
+#include <cstdio>     // std::fprintf
+#include <exception>  // std::terminate
+#include <type_traits>
+
+#include "gsl/util"  // gsl::narrow_cast
+
+namespace onnxruntime {
+
+namespace detail {
+[[noreturn]] inline void OnNarrowingError() noexcept {
+  std::fprintf(stderr, "%s", "narrowing error\n");
+  std::terminate();
+}
+}  // namespace detail
+
+// This implementation of onnxruntime::narrow was copied and adapted from:
+// https://github.com/microsoft/GSL/blob/a3534567187d2edc428efd3f13466ff75fe5805c/include/gsl/narrow
+
+// narrow() : a checked version of narrow_cast() that terminates if the cast changed the value
+template <class T, class U, typename std::enable_if<std::is_arithmetic<T>::value>::type* = nullptr>
+GSL_SUPPRESS(type.1) constexpr T narrow(U u) noexcept {
+  constexpr const bool is_different_signedness =
+      (std::is_signed<T>::value != std::is_signed<U>::value);
+
+  GSL_SUPPRESS(es.103)  // don't overflow
+  GSL_SUPPRESS(es.104)  // don't underflow
+  GSL_SUPPRESS(p.2)     // don't rely on undefined behavior
+  const T t = gsl::narrow_cast<T>(u);  // While this is technically undefined behavior in some cases (i.e., if the source value is of floating-point type
+                                       // and cannot fit into the destination integral type), the resultant behavior is benign on the platforms
+                                       // that we target (i.e., no hardware trap representations are hit).
+
+  if (static_cast<U>(t) != u || (is_different_signedness && ((t < T{}) != (u < U{})))) {
+    detail::OnNarrowingError();
+  }
+
+  return t;
+}
+
+template <class T, class U, typename std::enable_if<!std::is_arithmetic<T>::value>::type* = nullptr>
+GSL_SUPPRESS(type.1) constexpr T narrow(U u) noexcept {
+  const T t = gsl::narrow_cast<T>(u);
+
+  if (static_cast<U>(t) != u) {
+    detail::OnNarrowingError();
+  }
+
+  return t;
+}
+
+}  // namespace onnxruntime
+
+#endif  // defined(ORT_NO_EXCEPTIONS)
diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/optional.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/optional.h
new file mode 100644
index 00000000000000..f7106a3bbfb1ed
--- /dev/null
+++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/optional.h
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +#pragma once +#include + +namespace onnxruntime { + +using std::optional; + +#ifndef ORT_NO_EXCEPTIONS +using std::bad_optional_access; +#endif + +using std::nullopt; +using std::nullopt_t; + +using std::in_place; +using std::in_place_t; + +using std::make_optional; + +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/parse_string.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/parse_string.h new file mode 100644 index 00000000000000..941e3f3377ecc7 --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/parse_string.h @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "core/common/common.h" + +namespace onnxruntime { + +/** + * Tries to parse a value from an entire string. + */ +template +bool TryParseStringWithClassicLocale(std::string_view str, T& value) { + if constexpr (std::is_integral::value && std::is_unsigned::value) { + // if T is unsigned integral type, reject negative values which will wrap + if (!str.empty() && str[0] == '-') { + return false; + } + } + + // don't allow leading whitespace + if (!str.empty() && std::isspace(str[0], std::locale::classic())) { + return false; + } + + std::istringstream is{std::string{str}}; + is.imbue(std::locale::classic()); + T parsed_value{}; + + const bool parse_successful = + is >> parsed_value && + is.get() == std::istringstream::traits_type::eof(); // don't allow trailing characters + if (!parse_successful) { + return false; + } + + value = std::move(parsed_value); + return true; +} + +inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) { + value = str; + return true; +} + +inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) { + if (str == "0" || str == "False" || str == "false") { + value = false; + return true; + } + + if (str == "1" || str == "True" || str == "true") { + value = true; + return true; + } + + return false; +} + +/** + * Parses a value from an entire string. + */ +template +Status ParseStringWithClassicLocale(std::string_view s, T& value) { + ORT_RETURN_IF_NOT(TryParseStringWithClassicLocale(s, value), "Failed to parse value: \"", value, "\""); + return Status::OK(); +} + +/** + * Parses a value from an entire string. + */ +template +T ParseStringWithClassicLocale(std::string_view s) { + T value{}; + ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(s, value)); + return value; +} + +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/profiler_common.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/profiler_common.h new file mode 100644 index 00000000000000..0074d5e74a461d --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/profiler_common.h @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" + +#include +#include + +namespace onnxruntime { +namespace profiling { + +enum EventCategory { + SESSION_EVENT = 0, + NODE_EVENT, + KERNEL_EVENT, + API_EVENT, + EVENT_CATEGORY_MAX +}; + +// Event descriptions for the above session events. +static constexpr const char* event_category_names_[EVENT_CATEGORY_MAX] = { + "Session", + "Node", + "Kernel", + "Api"}; + +// Timing record for all events. 
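+// Usage sketch (illustrative values only, not part of the upstream header):
+//   EventRecord e(NODE_EVENT, /*pid*/ 0, /*tid*/ 1, "Conv_0",
+//                 /*ts*/ 100, /*dur*/ 42, {{"op", "Conv"}});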
+struct EventRecord { + EventRecord() = default; + EventRecord(EventCategory category, + int process_id, + int thread_id, + std::string&& event_name, + long long time_stamp, + long long duration, + std::unordered_map&& event_args) + : cat(category), + pid(process_id), + tid(thread_id), + name(std::move(event_name)), + ts(time_stamp), + dur(duration), + args(std::move(event_args)) {} + + EventRecord(EventCategory category, + int process_id, + int thread_id, + const std::string& event_name, + long long time_stamp, + long long duration, + const std::unordered_map& event_args) + : cat(category), + pid(process_id), + tid(thread_id), + name(event_name), + ts(time_stamp), + dur(duration), + args(event_args) {} + + EventRecord(const EventRecord& other) = default; + EventRecord(EventRecord&& other) noexcept = default; + EventRecord& operator=(const EventRecord& other) = default; + EventRecord& operator=(EventRecord&& other) = default; + + EventCategory cat = EventCategory::API_EVENT; + int pid = -1; + int tid = -1; + std::string name{}; + long long ts = 0; + long long dur = 0; + std::unordered_map args{}; +}; + +using Events = std::vector; + +// Execution Provider Profiler +class EpProfiler { + public: + virtual ~EpProfiler() = default; + virtual bool StartProfiling(TimePoint profiling_start_time) = 0; // called when profiling starts + virtual void EndProfiling(TimePoint start_time, Events& events) = 0; // called when profiling ends, save all captures numbers to "events" + virtual void Start(uint64_t){}; // called before op start, accept an id as argument to identify the op + virtual void Stop(uint64_t){}; // called after op stop, accept an id as argument to identify the op +}; + +// Demangle C++ symbols +std::string demangle(const char* name); +std::string demangle(const std::string& name); + +} // namespace profiling +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/span_utils.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/span_utils.h new file mode 100644 index 00000000000000..9f7454625fcd18 --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/span_utils.h @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include + +namespace onnxruntime { + +// AsSpan inspired by Fekir's Blog https://fekir.info/post/span-the-missing-constructor/ +// Used under MIT license + +// Use AsSpan for less typing on any container including initializer list to create a span +// (unnamed, untyped initializer list does not automatically convert to gsl::span). +// {1, 2, 3} as such does not have a type +// (see https://scottmeyers.blogspot.com/2014/03/if-braced-initializers-have-no-type-why.html) +// +// Example: AsSpan({1, 2, 3}) results in gsl::span +// +// The above would deduce to std::initializer_list and the result is gsl::span +// +// AsSpan({1, 2, 3}) produces gsl::span +// +// We can also do std::array{1, 2, 3} that can be automatically converted to span +// without memory allocation. +// +// If type conversion is not required, then for C++17 std::array template parameters are +// auto-deduced. Example: std::array{1, 2, 3}. +// We are aiming at not allocating memory dynamically. + +namespace details { +template +constexpr auto AsSpanImpl(P* p, size_t s) { + return gsl::span
<P>
(p, s); +} +} // namespace details + +template +constexpr auto AsSpan(C& c) { + return details::AsSpanImpl(c.data(), c.size()); +} + +template +constexpr auto AsSpan(const C& c) { + return details::AsSpanImpl(c.data(), c.size()); +} + +template +constexpr auto AsSpan(C&& c) { + return details::AsSpanImpl(c.data(), c.size()); +} + +template +constexpr auto AsSpan(std::initializer_list c) { + return details::AsSpanImpl(c.begin(), c.size()); +} + +template +constexpr auto AsSpan(T (&arr)[N]) { + return details::AsSpanImpl(arr, N); +} + +template +constexpr auto AsSpan(const T (&arr)[N]) { + return details::AsSpanImpl(arr, N); +} + +template +inline gsl::span EmptySpan() { return gsl::span(); } + +template +[[nodiscard]] inline gsl::span ReinterpretAsSpan(gsl::span src) { + // adapted from gsl-lite span::as_span(): + // https://github.com/gsl-lite/gsl-lite/blob/4720a2980a30da085b4ddb4a0ea2a71af7351a48/include/gsl/gsl-lite.hpp#L4102-L4108 + Expects(src.size_bytes() % sizeof(U) == 0); + return gsl::span(reinterpret_cast(src.data()), src.size_bytes() / sizeof(U)); +} + +[[nodiscard]] inline gsl::span AsByteSpan(const void* data, size_t length) { + return gsl::span(reinterpret_cast(data), length); +} + +template +[[nodiscard]] inline bool SpanEq(gsl::span a, gsl::span b) { + static_assert(std::is_same_v, std::remove_const_t>, + "T1 and T2 should be the same type except for const qualification"); + return std::equal(a.begin(), a.end(), b.begin(), b.end()); +} + +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/spin_pause.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/spin_pause.h new file mode 100644 index 00000000000000..49b71e5567d3e2 --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/spin_pause.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if defined(_M_AMD64) +#include +#endif + +#if defined(__x86_64__) +#include +#endif + +namespace onnxruntime { + +namespace concurrency { + +// Intrinsic to use in spin-loops + +inline void SpinPause() { +#if defined(_M_AMD64) || defined(__x86_64__) + _mm_pause(); +#endif +} + +} // namespace concurrency + +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/status.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/status.h new file mode 100644 index 00000000000000..8f171daabbb1ea --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/status.h @@ -0,0 +1,192 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Modifications Copyright (c) Microsoft. 
+ +#pragma once + +#include +#include +#include +#ifdef _WIN32 +#include +#endif +namespace onnxruntime { +namespace common { + +enum StatusCategory { + NONE = 0, + SYSTEM = 1, + ONNXRUNTIME = 2, +}; + +/** + Error code for ONNXRuntime. +*/ +enum StatusCode { + OK = 0, + FAIL = 1, + INVALID_ARGUMENT = 2, + NO_SUCHFILE = 3, + NO_MODEL = 4, + ENGINE_ERROR = 5, + RUNTIME_EXCEPTION = 6, + INVALID_PROTOBUF = 7, + MODEL_LOADED = 8, + NOT_IMPLEMENTED = 9, + INVALID_GRAPH = 10, + EP_FAIL = 11 +}; + +constexpr const char* StatusCodeToString(StatusCode status) noexcept { + switch (status) { + case StatusCode::OK: + return "SUCCESS"; + case StatusCode::FAIL: + return "FAIL"; + case StatusCode::INVALID_ARGUMENT: + return "INVALID_ARGUMENT"; + case StatusCode::NO_SUCHFILE: + return "NO_SUCHFILE"; + case StatusCode::NO_MODEL: + return "NO_MODEL"; + case StatusCode::ENGINE_ERROR: + return "ENGINE_ERROR"; + case StatusCode::RUNTIME_EXCEPTION: + return "RUNTIME_EXCEPTION"; + case StatusCode::INVALID_PROTOBUF: + return "INVALID_PROTOBUF"; + case StatusCode::MODEL_LOADED: + return "MODEL_LOADED"; + case StatusCode::NOT_IMPLEMENTED: + return "NOT_IMPLEMENTED"; + case StatusCode::INVALID_GRAPH: + return "INVALID_GRAPH"; + case StatusCode::EP_FAIL: + return "EP_FAIL"; + default: + return "GENERAL ERROR"; + } +} + +#ifdef _WIN32 +constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { + switch (status) { + case StatusCode::OK: + return S_OK; + case StatusCode::FAIL: + return E_FAIL; + case StatusCode::INVALID_ARGUMENT: + return E_INVALIDARG; + case StatusCode::NO_SUCHFILE: + return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); + case StatusCode::NO_MODEL: + return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); + case StatusCode::ENGINE_ERROR: + return E_FAIL; + case StatusCode::RUNTIME_EXCEPTION: + return E_FAIL; + case StatusCode::INVALID_PROTOBUF: + return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); + case StatusCode::MODEL_LOADED: + return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + case StatusCode::NOT_IMPLEMENTED: + return E_NOTIMPL; + case StatusCode::INVALID_GRAPH: + return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); + case StatusCode::EP_FAIL: + return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + default: + return E_FAIL; + } +} +#endif + +class [[nodiscard]] Status { + public: + Status() noexcept = default; + + Status(StatusCategory category, int code, const std::string& msg); + + Status(StatusCategory category, int code, const char* msg); + + Status(StatusCategory category, int code); + + Status(const Status& other) + : state_((other.state_ == nullptr) ? 
nullptr : new State(*other.state_)) {} + Status& operator=(const Status& other) { + if (state_ != other.state_) { + if (other.state_ == nullptr) { + state_.reset(); + } else { + state_.reset(new State(*other.state_)); + } + } + return *this; + } + + Status(Status&&) = default; + Status& operator=(Status&&) = default; + ~Status() = default; + + bool IsOK() const { + return (state_ == nullptr); + } + + int Code() const noexcept; + + StatusCategory Category() const noexcept; + + const std::string& ErrorMessage() const noexcept; + + std::string ToString() const; + + bool operator==(const Status& other) const { + return (this->state_ == other.state_) || (ToString() == other.ToString()); + } + + bool operator!=(const Status& other) const { + return !(*this == other); + } + + static Status OK() { + return Status(); + } + + private: + static const std::string& EmptyString() noexcept; + + struct State { + State(StatusCategory cat0, int code0, const std::string& msg0) + : category(cat0), code(code0), msg(msg0) {} + + State(StatusCategory cat0, int code0, const char* msg0) + : category(cat0), code(code0), msg(msg0) {} + + const StatusCategory category; + const int code; + const std::string msg; + }; + + // As long as Code() is OK, state_ == nullptr. + std::unique_ptr state_; +}; + +inline std::ostream& operator<<(std::ostream& out, const Status& status) { + return out << status.ToString(); +} + +} // namespace common + +// make Status directly available in the onnxruntime namespace as it is widely used +using common::Status; + +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/string_helper.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/string_helper.h new file mode 100644 index 00000000000000..1304303132d5a4 --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/common/string_helper.h @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +// forward declaration +struct OrtAllocator; +namespace onnxruntime { +char* StrDup(const std::string& str, OrtAllocator* allocator); +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/eager/ort_kernel_invoker.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/eager/ort_kernel_invoker.h new file mode 100644 index 00000000000000..fcf92de2ee39a9 --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/eager/ort_kernel_invoker.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include + +#include "core/common/common.h" +#include "core/framework/allocator.h" +#include "core/framework/tensor.h" +#include "core/framework/execution_provider.h" +#include "core/graph/constants.h" +#include "core/session/environment.h" +#include "core/graph/basic_types.h" +#include "core/graph/model.h" + +namespace onnxruntime { +#ifdef __GNUC__ +#pragma GCC diagnostic push +#endif + +class ORTInvoker { + public: + ORTInvoker(std::shared_ptr execution_provider, + const logging::Logger& logger, + const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries) + : execution_provider_(std::move(execution_provider)), + logger_(logger), + custom_op_registries_(custom_op_registries) { + if (!execution_provider_) { + ORT_THROW("Execution provider is nullptr"); + } + } + + IExecutionProvider& GetCurrentExecutionProvider() { + return *execution_provider_; + } + + common::Status Invoke(const std::string& op_name, + // optional inputs / outputs? + const std::vector& inputs, + std::vector& outputs, + const NodeAttributes* attributes, + const std::string& domain = kOnnxDomain, + const int version = -1); + + private: + std::shared_ptr execution_provider_; + const logging::Logger& logger_; + // custom ops for current execution provider + // we need the op schema to resolve the output type during invoke + const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries_; +}; + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/alloc_kind.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/alloc_kind.h new file mode 100644 index 00000000000000..c7a953a44b872e --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/alloc_kind.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +namespace onnxruntime { +// The ml-Values fall into the following categories with respect to their +// memory management: +// - inference inputs: owned (allocated and freed) by caller, and is by +// default read-only by the runtime. +// - inference outputs: allocated by runtime, ownership transferred to +// caller. TODO: Make sure this semantics is clear in InferenceSession API. +// - weights (constant tensors): can be allocated once (statically), and +// reused by all inference calls within an InferenceSession. +// - tensor values: The lifetimes of these tensor-values are statically +// determined, which is used for memory reuse/sharing optimizations. The +// runtime allocates/frees these values at the right time (as determined +// by the static allocation plan). Note that this is simplified since we +// do not try to optimize for "slice" like ops, where we may be able to +// conditionally reuse memory/data in some cases but not others. +// Generalizing this is future work. 
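+// For example (an illustrative reading of the comment above, not upstream
+// text): an inference output maps to kAllocateOutput below, while a weight
+// that is allocated once and reused across calls maps to kAllocateStatically.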
+ +enum class AllocKind { + kNotSet = -1, + kAllocate = 0, + kReuse = 1, + kPreExisting = 2, + kAllocateStatically = 3, + kAllocateOutput = 4, + kShare = 5, + kAllocatedExternally = 6 +}; + +std::ostream& operator<<(std::ostream& out, AllocKind alloc_kind); +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/allocator.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/allocator.h new file mode 100644 index 00000000000000..57b332ce65b93b --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/allocator.h @@ -0,0 +1,268 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/common/common.h" +#include "core/framework/allocator_stats.h" +// some enums are defined in session/onnxruntime_c_api.h but used in ortdevice.h/ortmemory.h +#include "core/session/onnxruntime_c_api.h" +#include "core/framework/ortdevice.h" +#include "core/framework/ortmemoryinfo.h" + +// This configures the arena based allocator used by ORT +// See docs/C_API.md for details on what these mean and how to choose these values +struct OrtArenaCfg { + OrtArenaCfg() : max_mem(0), + arena_extend_strategy(-1), + initial_chunk_size_bytes(-1), + max_dead_bytes_per_chunk(-1), + initial_growth_chunk_size_bytes(-1), + max_power_of_two_extend_bytes(-1) {} + OrtArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, + int max_dead_bytes_per_chunk, int initial_growth_chunk_size_bytes, + int64_t max_power_of_two_extend_bytes) + : max_mem(max_mem), + arena_extend_strategy(arena_extend_strategy), + initial_chunk_size_bytes(initial_chunk_size_bytes), + max_dead_bytes_per_chunk(max_dead_bytes_per_chunk), + initial_growth_chunk_size_bytes(initial_growth_chunk_size_bytes), + max_power_of_two_extend_bytes(max_power_of_two_extend_bytes) {} + + size_t max_mem; // use 0 to allow ORT to choose the default + int arena_extend_strategy; // use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested + int initial_chunk_size_bytes; // use -1 to allow ORT to choose the default + int max_dead_bytes_per_chunk; // use -1 to allow ORT to choose the default + int initial_growth_chunk_size_bytes; // use -1 to allow ORT to choose the default + int64_t max_power_of_two_extend_bytes; // use -1 to allow ORT to choose the default +}; + +namespace onnxruntime { +constexpr const char* CPU = "Cpu"; +constexpr const char* CUDA = "Cuda"; +constexpr const char* CUDA_PINNED = "CudaPinned"; +constexpr const char* CANN = "Cann"; +constexpr const char* CANN_PINNED = "CannPinned"; +constexpr const char* DML = "DML"; +constexpr const char* HIP = "Hip"; +constexpr const char* HIP_PINNED = "HipPinned"; +constexpr const char* OpenVINO_CPU = "OpenVINO_CPU"; +constexpr const char* OpenVINO_GPU = "OpenVINO_GPU"; +constexpr const char* OpenVINO_RT = "OpenVINO_RT"; +constexpr const char* OpenVINO_RT_NPU = "OpenVINO_RT_NPU"; +constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer"; +constexpr const char* WEBNN_TENSOR = "WebNN_Tensor"; + +constexpr size_t kAllocAlignment = 256; + +class IAllocator; +class Stream; +namespace synchronize { +class Notification; +} +using WaitNotificationFn = std::function; +void* AllocateBufferWithOptions(IAllocator& allocator, size_t size, bool use_reserve, Stream* stream, WaitNotificationFn wait_fn); + +template +using IAllocatorUniquePtr = std::unique_ptr>; + +class IAllocator { + public: + 
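+  // Usage sketch (illustrative, not part of the upstream header): the common
+  // consumption pattern is a typed, self-freeing buffer --
+  //   AllocatorPtr alloc = std::make_shared<CPUAllocator>();
+  //   auto data = IAllocator::MakeUniquePtr<float>(alloc, /*count*/ 16);
+  // MakeUniquePtr is declared further down in this class; CPUAllocator is
+  // declared near the end of this header.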
IAllocator(const OrtMemoryInfo& info) : memory_info_(info) {} + virtual ~IAllocator() = default; + /** + * Allocate memory of the specified size. + * If size is 0, nullptr is returned. + * If allocation fails, an exception is thrown. + * + * @remarks Use SafeInt when calculating the size of memory to allocate using Alloc. + */ + virtual void* Alloc(size_t size) = 0; + + virtual void Free(void* p) = 0; + + // Reserve() is an interface exposed for an implementation of IAllocator + // to optionally implement some allocation logic that by-passes any arena-based + // logic that may be housed in the Alloc() implementation. + // There are SessionOptions config(s) that allow users to allocate some memory + // by-passing arena-based logic. + // By default, the base implementation just calls Alloc(). + virtual void* Reserve(size_t size) { return Alloc(size); } + + const OrtMemoryInfo& Info() const { return memory_info_; }; + + // Each implementation of IAllocator can override and provide their own implementation + virtual void GetStats(AllocatorStats* /*stats*/) { return; } + + static bool CalcMemSizeForArray(size_t nmemb, size_t size, size_t* out) noexcept { + return CalcMemSizeForArrayWithAlignment(nmemb, size, 0, out); + } + + /** + * Calculate the memory size for an array. The size is bounds checked using SafeInt. + * \tparam alignment must be power of 2 + * \param nmemb Number of members or elements in the array + * \param size Size of each element + * \param out Total size required after any alignment is applied + * \return true, successful. false, overflow + */ + [[nodiscard]] static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, + size_t* out) noexcept; + + /** + * https://cwe.mitre.org/data/definitions/190.html + * \param alignment must be power of 2 + * \param nmemb Number of members or elements in the array + * \param size Size of each element + * \param out Total size required after any alignment is applied + * \return true, successful. false, overflow + * \remarks This was the original API and was implemented in the header. Replaced with the above version + * implemented in the .cc file so that the SafeInt dependency is internal. + */ + template + [[nodiscard]] static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept; + + /** + * allocate memory for an array which has nmemb items of data, each size bytes long + */ + void* AllocArray(size_t nmemb, size_t size) { + size_t len; + if (!CalcMemSizeForArray(nmemb, size, &len)) { + ORT_THROW("Invalid size requested for allocation: ", nmemb, " * ", size); + } + + return Alloc(len); + } + + /** + * allocate memory for an array which has nmemb items of data, each size bytes long + */ + template + void* AllocArrayWithAlignment(size_t nmemb, size_t size) { + size_t len; + if (!CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, &len)) { + ORT_THROW("Invalid size requested for allocation: ", nmemb, " * ", size, " with alignment ", alignment); + } + + return Alloc(len); + } + + /** + Create a std::unique_ptr that is allocated and freed by the provided IAllocator. + @param allocator The allocator. + @param count_or_bytes The exact bytes to allocate if T is void, otherwise the number of elements to allocate. + @param use_reserve If true, call Reserve() instead of Alloc() to allocate memory. + @param stream Which stream instance allocated chunk will be used with. 
+ @param wait_fn If the allocator want to dynamic reuse a chunk from another stream, use this wait_fn to sync on + the target stream to make the reuse safe. + @returns std::unique_ptr with allocated memory and deleter. Throws if it cannot allocate memory. + */ + template + static IAllocatorUniquePtr MakeUniquePtr(std::shared_ptr allocator, size_t count_or_bytes, + bool use_reserve = false, + Stream* stream = nullptr, WaitNotificationFn wait_fn = nullptr) { + ValidateAllocator(allocator); + + // for now limit to fundamental types. we could support others, but to do so either we or the caller + // needs to call the dtor for the objects, for buffers allocated on device we don't have destructor + // static_assert(std::is_fundamental::value, "Fundamental type required as no destructors are called."); + + size_t alloc_size = count_or_bytes; + + // if T is not void, 'count_or_bytes' == number of items so allow for that + if constexpr (!std::is_void::value) { + // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't + // reachable if T is void. use std::conditional to 'use' void* in the sizeof call + constexpr auto size = sizeof(typename std::conditional::value, void*, T>::type); + alloc_size = ValidatedCalcMemSizeForArray(count_or_bytes, size); + } + + // allocate + T* p = static_cast(AllocateBufferWithOptions(*allocator, alloc_size, use_reserve, stream, std::move(wait_fn))); + ValidateAllocation(p, alloc_size); + + return IAllocatorUniquePtr{p, + [allocator = std::move(allocator)](T* p) { + allocator->Free(p); + }}; + } + + /** + Create a std::unique_ptr that is allocated and freed by the provided OrtAllocator. + @param ort_allocator The allocator. + @param count_or_bytes The exact bytes to allocate if T is void, otherwise the number of elements to allocate. + @returns std::unique_ptr with allocated memory and deleter. Throws if it cannot allocate memory. + */ + template + static IAllocatorUniquePtr MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, size_t count_or_bytes) { + ValidateAllocator(ort_allocator); + + size_t alloc_size = count_or_bytes; + // if T is not void, 'count_or_bytes' == number of items so allow for that + if constexpr (!std::is_void::value) { + // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't + // reachable if T is void. use std::conditional to 'use' void* in the sizeof call + constexpr auto size = sizeof(typename std::conditional::value, void*, T>::type); + alloc_size = ValidatedCalcMemSizeForArray(count_or_bytes, size); + } + + T* p = static_cast(ort_allocator->Alloc(ort_allocator, alloc_size)); + ValidateAllocation(p, alloc_size); + + return IAllocatorUniquePtr{p, + [ort_allocator](T* p) { + ort_allocator->Free(ort_allocator, p); + }}; + } + + private: + // + // validation functions. split out from methods that are templatized on the data type to minimize binary size. + // + + template + static void ValidateAllocator(const T& allocator) { + ORT_ENFORCE(allocator != nullptr); + } + + static size_t ValidatedCalcMemSizeForArray(size_t count, size_t size) { + size_t alloc_size = 0; + if (!CalcMemSizeForArray(count, size, &alloc_size)) { + ORT_THROW("Invalid size requested for allocation: ", count, " * ", size); + } + + return alloc_size; + } + + static void ValidateAllocation(void* p, size_t size) { + // allocator should throw directly but in case it didn't ensure we do here so that calling code doesn't + // need to check for nullptr when an actual allocation was expected. 
+ ORT_ENFORCE(p != nullptr || size == 0, "Memory allocation failed. Size=", size); + }; + + OrtMemoryInfo memory_info_; +}; + +template +bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept { + return CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, out); +} + +class CPUAllocator : public IAllocator { + public: + explicit CPUAllocator(const OrtMemoryInfo& memory_info) : IAllocator(memory_info) {} + + CPUAllocator() : IAllocator(OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)) {} + + void* Alloc(size_t size) override; + void Free(void* p) override; +}; + +using AllocatorPtr = std::shared_ptr; +using AllocatorMap = std::map; + +void* AllocatorDefaultAlloc(size_t size); +void AllocatorDefaultFree(void* p); +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/buffer_deleter.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/buffer_deleter.h new file mode 100644 index 00000000000000..961eb443ee1c7a --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/buffer_deleter.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/allocator.h" + +namespace onnxruntime { + +// TODO: Do we need this class or is IAllocator::MakeUniquePtr sufficient/better +class BufferDeleter { + public: + BufferDeleter() = default; + explicit BufferDeleter(AllocatorPtr alloc) + : alloc_(std::move(alloc)) {} + + void operator()(void* p) const { + if (alloc_) + alloc_->Free(p); + } + + private: + // TODO: we may need consider the lifetime of alloc carefully + // The alloc_ here is the allocator that used to allocate the buffer + // And need go with the unique_ptr together. If it is using our internal + // allocator, it is ok as our allocators are global managed. But if it + // is provide by user, user need to be very careful about it. + // A weak_ptr may be a choice to reduce the impact, but that require to + // change our current allocator mgr to use shared_ptr. Will revisit it + // later. + AllocatorPtr alloc_{nullptr}; +}; + +using BufferUniquePtr = std::unique_ptr; +using BufferNakedPtr = void*; +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/customregistry.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/customregistry.h new file mode 100644 index 00000000000000..52f6169e2e8294 --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/customregistry.h @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/status.h" +#include "core/common/logging/logging.h" +#include "core/framework/op_kernel.h" +#include "core/framework/kernel_def_builder.h" +#include "core/framework/kernel_registry.h" + +#if !defined(ORT_MINIMAL_BUILD) +#include "core/graph/schema_registry.h" +#endif + +namespace onnxruntime { + +/** + Represents a registry that contains both custom kernels and custom schemas. +*/ +class CustomRegistry final { + public: + CustomRegistry() + : kernel_registry_(std::make_shared()) +#if !defined(ORT_MINIMAL_BUILD) + , + opschema_registry_(std::make_shared()) +#endif + { + } + + /** + * Register a kernel definition together with kernel factory method to this session. 
+ * If any conflict happened between registered kernel def and built-in kernel def, + * registered kernel will have higher priority. + * Call this before invoking Initialize(). + * @return OK if success. + */ + common::Status RegisterCustomKernel(KernelDefBuilder& kernel_def_builder, const KernelCreateFn& kernel_creator); + + common::Status RegisterCustomKernel(KernelCreateInfo&); + + const std::shared_ptr& GetKernelRegistry(); + +#if !defined(ORT_MINIMAL_BUILD) + common::Status RegisterOpSet(std::vector& schemas, const std::string& domain, + int baseline_opset_version, int opset_version); + + const std::shared_ptr& GetOpschemaRegistry(); +#endif + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomRegistry); + std::shared_ptr kernel_registry_; +#if !defined(ORT_MINIMAL_BUILD) + std::shared_ptr opschema_registry_; +#endif +}; + +} // namespace onnxruntime diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/data_types.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/data_types.h new file mode 100644 index 00000000000000..87feefa10ca4a0 --- /dev/null +++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/data_types.h @@ -0,0 +1,1125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "core/common/common.h" +#include "core/common/exceptions.h" +#include "core/framework/endian.h" +#include "core/framework/float8.h" +#include "core/framework/float16.h" +#include "core/framework/int4.h" +#include "core/graph/onnx_protobuf.h" +#include "core/framework/to_tensor_proto_element_type.h" + +struct OrtValue; + +namespace ONNX_NAMESPACE { +class TypeProto; +} // namespace ONNX_NAMESPACE + +namespace onnxruntime { +/// Predefined registered types + +#if !defined(DISABLE_ML_OPS) + +// maps (only used by ML ops) +using MapStringToString = std::map; +using MapStringToInt64 = std::map; +using MapStringToFloat = std::map; +using MapStringToDouble = std::map; +using MapInt64ToString = std::map; +using MapInt64ToInt64 = std::map; +using MapInt64ToFloat = std::map; +using MapInt64ToDouble = std::map; + +// vectors/sequences +using VectorMapStringToFloat = std::vector; +using VectorMapInt64ToFloat = std::vector; + +#endif + +using VectorString = std::vector; +using VectorInt64 = std::vector; + +// Forward declarations +class DataTypeImpl; +class TensorTypeBase; +#if !defined(DISABLE_SPARSE_TENSORS) +class SparseTensorTypeBase; +#endif +class SequenceTensorTypeBase; +class NonTensorTypeBase; +#if !defined(DISABLE_OPTIONAL_TYPE) +class OptionalTypeBase; +#endif +class PrimitiveDataTypeBase; +class Tensor; +class TensorSeq; + +// DataTypeImpl pointer as unique DataTypeImpl identifier. 
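+// Since the pointer itself is the identity, runtime type checks are pointer
+// comparisons, e.g. (illustrative):
+//   if (dtype == DataTypeImpl::GetType<float>()) { /* float element type */ }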
+using MLDataType = const DataTypeImpl*; +// be used with class MLValue +using DeleteFunc = void (*)(void*); +using CreateFunc = void* (*)(); + +/** + * \brief Base class for MLDataType + * + */ +class DataTypeImpl { + public: + enum class GeneralType { + kInvalid = 0, + kNonTensor = 1, + kTensor = 2, + kTensorSequence = 3, + kSparseTensor = 4, + kOptional = 5, + kPrimitive = 6, + }; + + const GeneralType type_; + const size_t size_; + + protected: + DataTypeImpl(GeneralType type, size_t size) : type_{type}, size_{size} {} + + public: + virtual ~DataTypeImpl() = default; + + /** + * \brief this API will be used to check type compatibility at runtime + * + * \param type_proto a TypeProto instance that is constructed for a specific type + * will be checked against a TypeProto instance contained within a corresponding + * MLDataType instance. + */ + virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const = 0; + + size_t Size() const { return size_; } + + virtual DeleteFunc GetDeleteFunc() const = 0; + + /** + * \brief Retrieves an instance of TypeProto for + * a given MLDataType + * \returns optional TypeProto. Only ONNX types + has type proto, non-ONNX types will return nullptr. + */ + virtual const ONNX_NAMESPACE::TypeProto* GetTypeProto() const = 0; + + bool IsTensorType() const { + return type_ == GeneralType::kTensor; + } + + bool IsTensorSequenceType() const { + return type_ == GeneralType::kTensorSequence; + } + + bool IsSparseTensorType() const { + return type_ == GeneralType::kSparseTensor; + } + + bool IsOptionalType() const { + return type_ == GeneralType::kOptional; + } + + bool IsNonTensorType() const { + return type_ == GeneralType::kNonTensor; + } + + bool IsPrimitiveDataType() const { + return type_ == GeneralType::kPrimitive; + } + + // Returns this if this is of tensor-type and null otherwise + const TensorTypeBase* AsTensorType() const; + + const SequenceTensorTypeBase* AsSequenceTensorType() const; + +#if !defined(DISABLE_SPARSE_TENSORS) + // Returns this if this is of sparse-tensor-type and null otherwise + const SparseTensorTypeBase* AsSparseTensorType() const; +#endif + +#if !defined(DISABLE_OPTIONAL_TYPE) + const OptionalTypeBase* AsOptionalType() const; +#endif + + const NonTensorTypeBase* AsNonTensorType() const; + + // Returns this if this is one of the primitive data types (specialization of PrimitiveDataTypeBase) + // and null otherwise + const PrimitiveDataTypeBase* AsPrimitiveDataType() const; + + // Return the type meta that we are using in the runtime. + template + static MLDataType GetType(); + + // Return the types for a concrete tensor type, like Tensor_Float + template + static MLDataType GetTensorType(); + + template + static MLDataType GetSequenceTensorType(); + +#if !defined(DISABLE_SPARSE_TENSORS) + // Return the MLDataType for a concrete sparse tensor type. + template + static MLDataType GetSparseTensorType(); +#endif + + template + static MLDataType GetOptionalType(); + + /** + * Convert an ONNX TypeProto to onnxruntime DataTypeImpl. + * However, this conversion is lossy. Don't try to use 'this->GetTypeProto()' converting it back. + * Even though GetTypeProto() will not have the original information, it will still have enough to correctly + * map to MLDataType. 
+ * \param proto + */ + static MLDataType TypeFromProto(const ONNX_NAMESPACE::TypeProto& proto); + + static const TensorTypeBase* TensorTypeFromONNXEnum(int type); + static const SequenceTensorTypeBase* SequenceTensorTypeFromONNXEnum(int type); +#if !defined(DISABLE_SPARSE_TENSORS) + static const SparseTensorTypeBase* SparseTensorTypeFromONNXEnum(int type); +#endif + + static const char* ToString(MLDataType type); + static std::vector ToString(const std::vector& types); + // Registers ONNX_NAMESPACE::DataType (internalized string) with + // MLDataType. DataType is produced by internalizing an instance of + // TypeProto contained within MLDataType + static void RegisterDataType(MLDataType); + static MLDataType GetDataType(const std::string&); + + // IR4: includes all float types, includes float16, bfloat16 + // IR9: includes float 8 types as well + static const std::vector& AllTensorTypes(); // up to IR4 (no float 8), deprecated + static const std::vector& AllTensorTypesIRv4(); + static const std::vector& AllTensorTypesIRv9(); + + static const std::vector& AllFixedSizeTensorTypes(); // up to IR4 (no float 8), deprecated + static const std::vector& AllFixedSizeTensorTypesIRv4(); + static const std::vector& AllFixedSizeTensorTypesIRv9(); + + static const std::vector& AllSequenceTensorTypes(); // up to IR4 (no float 8), deprecated + static const std::vector& AllSequenceTensorTypesIRv4(); + static const std::vector& AllSequenceTensorTypesIRv9(); + + static const std::vector& AllFixedSizeSequenceTensorTypes(); // up to IR4 (no float 8), deprecated + static const std::vector& AllFixedSizeSequenceTensorTypesIRv4(); + static const std::vector& AllFixedSizeSequenceTensorTypesIRv9(); + + static const std::vector& AllNumericTensorTypes(); // up to IR4 (no float 8), deprecated + static const std::vector& AllNumericTensorTypesIRv4(); + static const std::vector& AllNumericTensorTypesIRv9(); + + static const std::vector& AllIEEEFloatTensorTypes(); // float16, float, double + + static const std::vector& AllTensorAndSequenceTensorTypes(); // up to IR4 (no float 8), deprecated + static const std::vector& AllTensorAndSequenceTensorTypesIRv4(); + static const std::vector& AllTensorAndSequenceTensorTypesIRv9(); + + static const std::vector& AllOptionalAndTensorAndSequenceTensorTypes(); // up to IR4 (no float 8), deprecated + static const std::vector& AllOptionalAndTensorAndSequenceTensorTypesIRv4(); + static const std::vector& AllOptionalAndTensorAndSequenceTensorTypesIRv9(); + + static const std::vector& AllFixedSizeTensorAndSequenceTensorTypes(); // up to IR4 (no float 8), deprecated + static const std::vector& AllFixedSizeTensorAndSequenceTensorTypesIRv4(); + static const std::vector& AllFixedSizeTensorAndSequenceTensorTypesIRv9(); + + static const std::vector& AllOptionalTypes(); // up to IR4 (no float 8), deprecated + static const std::vector& AllOptionalTypesIRv4(); + static const std::vector& AllOptionalTypesIRv9(); + + static const std::vector& AllTensorAndSequenceTensorAndOptionalTypes(); // up to IR4 (no float 8), deprecated + static const std::vector& AllTensorAndSequenceTensorAndOptionalTypesIRv4(); + static const std::vector& AllTensorAndSequenceTensorAndOptionalTypesIRv9(); +}; + +std::ostream& operator<<(std::ostream& out, MLDataType data_type); + +/* + * Type registration helpers + */ +namespace data_types_internal { +/// TensorType helpers +/// + +/// Is a given type on the list of types? 
+/// Accepts a list of types and the first argument is the type +/// We are checking if it is listed among those that follow +template +struct IsAnyOf; + +/// Two types remaining, end of the list +template +struct IsAnyOf : public std::is_same { +}; + +template +struct IsAnyOf { + static constexpr bool value = (std::is_same::value || + IsAnyOf::value); +}; + +/// Tells if the specified type is one of fundamental types +/// that can be contained within a tensor. +/// We do not have raw fundamental types, rather a subset +/// of fundamental types is contained within tensors. +template +struct IsTensorContainedType : public IsAnyOf { +}; + +#if !defined(DISABLE_SPARSE_TENSORS) +/// Use "IsSparseTensorContainedType::value" to test if a type T +/// is permitted as the element-type of a sparse-tensor. + +template +struct IsSparseTensorContainedType : public IsAnyOf { +}; +#endif + +#if !defined(DISABLE_OPTIONAL_TYPE) +/// Tells if the specified type is one of ORT types +/// that can be contained within an optional struct. +template +struct IsOptionalOrtType : public IsAnyOf { +}; +#endif + +/// This template's Get() returns a corresponding MLDataType +/// It dispatches the call to either GetTensorType<>() or +/// GetType<>() +template +struct GetMLDataType; + +template +struct GetMLDataType { + static MLDataType Get() { + return DataTypeImpl::GetTensorType(); + } +}; + +template +struct GetMLDataType { + static MLDataType Get() { + return DataTypeImpl::GetType(); + } +}; + +struct TensorTypeHelper { + static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type, + ONNX_NAMESPACE::TypeProto& proto) { + proto.mutable_tensor_type()->set_elem_type(element_type); + } +}; + +#if !defined(DISABLE_SPARSE_TENSORS) +struct SparseTensorTypeHelper { + static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type, + ONNX_NAMESPACE::TypeProto& proto) { + proto.mutable_sparse_tensor_type()->set_elem_type(element_type); + } +}; +#endif // !defined(DISABLE_SPARSE_TENSORS) + +#if !defined(DISABLE_ML_OPS) +/// Map helpers + +void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto&, + ONNX_NAMESPACE::TypeProto&); + +struct MapTypeHelper { + // V can be either a primitive type (in which case it is a tensor) + // or other preregistered types + template + static MLDataType GetValueType() { + return GetMLDataType::value>::Get(); + } + + static void Set(ONNX_NAMESPACE::TensorProto_DataType key_type, const ONNX_NAMESPACE::TypeProto* value_proto, + ONNX_NAMESPACE::TypeProto& proto) { + ORT_ENFORCE(value_proto != nullptr, "expected a registered ONNX type"); + proto.mutable_map_type()->set_key_type(key_type); + CopyMutableMapValue(*value_proto, proto); + } +}; +#endif + +/// Sequence helpers + +// Element type is a primitive type so we set it to a tensor +void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto&, + ONNX_NAMESPACE::TypeProto&); + +// helper to create TypeProto with minimal binary size impact +struct SequenceTypeHelper { + template + static MLDataType GetElemType() { + return GetMLDataType::value>::Get(); + } + + static void Set(const ONNX_NAMESPACE::TypeProto* elem_proto, + ONNX_NAMESPACE::TypeProto& proto) { + ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type"); + CopyMutableSeqElement(*elem_proto, proto); + } +}; + +/// Optional helpers + +void CopyMutableOptionalElement(const ONNX_NAMESPACE::TypeProto&, + ONNX_NAMESPACE::TypeProto&); + +// helper to create TypeProto with minimal binary size impact +struct OptionalTypeHelper { + template + static MLDataType GetElemType() { + 
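+    // Compile-time dispatch: optional<Tensor> resolves to the tensor
+    // MLDataType and optional<TensorSeq> to the sequence-of-tensors
+    // MLDataType; any other T trips the static_assert below.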
if constexpr (std::is_same::value) { + return DataTypeImpl::GetTensorType(); + } else { + static_assert(std::is_same::value, "Unsupported element type for optional type"); + return DataTypeImpl::GetSequenceTensorType(); + } + } + + static void Set(const onnx::TypeProto* elem_proto, ONNX_NAMESPACE::TypeProto& proto) { + ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type"); + CopyMutableOptionalElement(*elem_proto, proto); + } +}; + +/// OpaqueTypes helpers + +void AssignOpaqueDomainName(const char* domain, const char* name, + ONNX_NAMESPACE::TypeProto& proto); + +} // namespace data_types_internal + +// The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor." +// However, we do not allocate this type on heap. +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +#pragma warning(disable : 26436) +#endif +/// All tensors base +class TensorTypeBase : public DataTypeImpl { + public: + static MLDataType Type(); + + /// We first compare type_proto pointers and then + /// if they do not match try to account for the case + /// where TypeProto was created ad-hoc and not queried from MLDataType + bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override; + + DeleteFunc GetDeleteFunc() const override; + + const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override; + + virtual MLDataType GetElementType() const { + // should never reach here. + ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); + } + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorTypeBase); + + protected: + ONNX_NAMESPACE::TypeProto& MutableTypeProto(); + + TensorTypeBase(); + ~TensorTypeBase() override; + + private: + struct Impl; + Impl* impl_; +}; + +/** + * \brief Tensor type. This type does not have a C++ type associated with + * it at registration time except the element type. One of the types mentioned + * above at IsTensorContainedType<> list is acceptable. + * + * \details + * Usage: + * ORT_REGISTER_TENSOR(ELEMENT_TYPE) + * Currently all of the Tensors irrespective of the dimensions are mapped to Tensor + * type. IsCompatible() currently ignores shape. + */ + +template +class TensorType : public TensorTypeBase { + public: + static_assert(data_types_internal::IsTensorContainedType::value, + "Requires one of the tensor fundamental types"); + + static MLDataType Type(); + + /// Tensors only can contain basic data types + /// that have been previously registered with ONNXRuntime + MLDataType GetElementType() const override { + return DataTypeImpl::GetType(); + } + + private: + TensorType() { + using namespace data_types_internal; + TensorTypeHelper::Set(utils::ToTensorProtoElementType(), MutableTypeProto()); + } +}; + +#if defined(DISABLE_OPTIONAL_TYPE) + +// TODO is this still needed after removing kernel def hashes? +/// Common base-class for all disabled types. We need DataTypeImpl::ToString to work in a minimal build +/// with disabled types to keep the ORT format model kernel hashes stable. +class DisabledTypeBase : public DataTypeImpl { + public: + static MLDataType Type(); + + bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override { + // We always want to return false for the IsCompatible() for a disabled type + // because this will ensure that no kernel supporting the disabled type will + // be matched to a model node requiring that type and the model load will + // result in failure. 
+#if defined(DISABLE_OPTIONAL_TYPE)
+
+// TODO is this still needed after removing kernel def hashes?
+/// Common base-class for all disabled types. We need DataTypeImpl::ToString to work in a minimal build
+/// with disabled types to keep the ORT format model kernel hashes stable.
+class DisabledTypeBase : public DataTypeImpl {
+ public:
+  static MLDataType Type();
+
+  bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override {
+    // We always want to return false for the IsCompatible() for a disabled type
+    // because this will ensure that no kernel supporting the disabled type will
+    // be matched to a model node requiring that type and the model load will
+    // result in failure.
+    return false;
+  }
+
+  DeleteFunc GetDeleteFunc() const override {
+    ORT_THROW("Type is disabled in this build.");
+  }
+
+  // This must work
+  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
+
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DisabledTypeBase);
+
+ protected:
+  // This must work
+  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
+
+  DisabledTypeBase(DataTypeImpl::GeneralType type, size_t size);
+  ~DisabledTypeBase() override;
+
+ private:
+  struct Impl;
+  Impl* impl_;
+};
+
+#endif
+
+#if !defined(DISABLE_SPARSE_TENSORS)
+/// Common base-class for all sparse-tensors (with different element types).
+class SparseTensorTypeBase : public DataTypeImpl {
+ public:
+  static MLDataType Type();
+
+  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
+
+  DeleteFunc GetDeleteFunc() const override;
+
+  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
+
+  virtual MLDataType GetElementType() const {
+    // should never reach here.
+    ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
+  }
+
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SparseTensorTypeBase);
+
+ protected:
+  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
+
+  SparseTensorTypeBase();
+  ~SparseTensorTypeBase() override;
+
+ private:
+  struct Impl;
+  Impl* impl_;
+};
+
+template <typename elemT>
+class SparseTensorType : public SparseTensorTypeBase {
+ public:
+  static_assert(data_types_internal::IsSparseTensorContainedType<elemT>::value,
+                "Requires one of the sparse-tensor fundamental types");
+
+  static MLDataType Type();
+
+  /// Return a MLDataType representing the element-type
+  MLDataType GetElementType() const override {
+    return DataTypeImpl::GetType<elemT>();
+  }
+
+ private:
+  SparseTensorType() {
+    using namespace data_types_internal;
+    SparseTensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(), MutableTypeProto());
+  }
+};
+
+#endif  // !defined(DISABLE_SPARSE_TENSORS)
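(Aside: the sparse path mirrors the dense sketch earlier; a brief illustration, assuming float was registered via ORT_REGISTER_SPARSE_TENSOR_TYPE.)

    MLDataType sdt = DataTypeImpl::GetSparseTensorType<float>();
    const SparseTensorTypeBase* sparse = sdt->AsSparseTensorType();
    MLDataType elem = sparse->GetElementType();  // PrimitiveDataType<float>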
+/// Common base-class for all optional types.
+
+#if !defined(DISABLE_OPTIONAL_TYPE)
+class OptionalTypeBase : public DataTypeImpl {
+ public:
+  static MLDataType Type();
+
+  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
+
+  DeleteFunc GetDeleteFunc() const override {
+    // should never reach here.
+    ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
+  }
+
+  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
+
+  virtual MLDataType GetElementType() const {
+    // should never reach here.
+    ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
+  }
+
+  OptionalTypeBase(const OptionalTypeBase&) = delete;
+  OptionalTypeBase& operator=(const OptionalTypeBase&) = delete;
+
+ protected:
+  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
+
+  OptionalTypeBase();
+  ~OptionalTypeBase() override;
+
+ private:
+  struct Impl;
+  Impl* impl_;
+};
+#endif
+
+// Derive from OptionalTypeBase if the Optional type support is enabled,
+// else derive from DisabledTypeBase
+template <typename T, typename elemT>
+class OptionalType :
+#if !defined(DISABLE_OPTIONAL_TYPE)
+    public OptionalTypeBase
+#else
+    public DisabledTypeBase
+#endif
+{
+ public:
+  static MLDataType Type();
+
+#if !defined(DISABLE_OPTIONAL_TYPE)
+  static_assert(data_types_internal::IsOptionalOrtType<T>::value,
+                "Requires one of the supported types: Tensor or TensorSeq");
+
+  static_assert(data_types_internal::IsTensorContainedType<elemT>::value,
+                "Requires one of the tensor fundamental types");
+
+  MLDataType GetElementType() const override {
+    return data_types_internal::OptionalTypeHelper::GetElemType<T, elemT>();
+  }
+#endif
+
+ private:
+#if !defined(DISABLE_OPTIONAL_TYPE)
+  OptionalType()
+#else
+  OptionalType() : DisabledTypeBase{DataTypeImpl::GeneralType::kOptional, 0}
+#endif
+  {
+    using namespace data_types_internal;
+    OptionalTypeHelper::Set(OptionalTypeHelper::GetElemType<T, elemT>()->GetTypeProto(), MutableTypeProto());
+  }
+};
+
+/**
+ * \brief Provide a specialization for your C++ Non-tensor type
+ * so your implementation FromDataTypeContainer/ToDataTypeContainer
+ * functions correctly. Otherwise you get a default implementation
+ * which may not be what you need/want.
+ *
+ * This class is used to create OrtValue, fetch data from OrtValue via
+ * C/C++ APIs
+ */
+template <class T>
+struct NonTensorTypeConverter {
+  static void FromContainer(MLDataType /*dtype*/, const void* /*data*/, size_t /*data_size*/, OrtValue& /*output*/) {
+    ORT_THROW("Not implemented");
+  }
+  static void ToContainer(const OrtValue& /*input*/, size_t /*data_size*/, void* /*data*/) {
+    ORT_THROW("Not implemented");
+  }
+};
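(Aside: a hedged sketch of how a custom non-tensor type could plug into the C API through this hook. MyState is hypothetical, and the OrtValue::Init/Get calls are assumptions about the runtime-side API, not part of this header.)

    struct MyState {
      int64_t seed;
    };

    template <>
    struct NonTensorTypeConverter<MyState> {
      static void FromContainer(MLDataType dtype, const void* data, size_t data_size, OrtValue& output) {
        ORT_ENFORCE(data_size == sizeof(MyState), "size mismatch");
        // Copy the POD payload into a fresh MyState owned by the OrtValue.
        auto state = std::make_unique<MyState>(*static_cast<const MyState*>(data));
        output.Init(state.release(), dtype, dtype->GetDeleteFunc());
      }
      static void ToContainer(const OrtValue& input, size_t data_size, void* data) {
        ORT_ENFORCE(data_size == sizeof(MyState), "size mismatch");
        *static_cast<MyState*>(data) = input.Get<MyState>();
      }
    };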
+/**
+ * \brief Base type for all non-tensors, maps, sequences and opaques
+ */
+class NonTensorTypeBase : public DataTypeImpl {
+ public:
+  DeleteFunc GetDeleteFunc() const override = 0;
+
+  virtual CreateFunc GetCreateFunc() const = 0;
+
+  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
+
+  // \brief Override for Non-tensor types to initialize non-tensor CPP
+  // data representation from data. The caller of the interface
+  // should have a shared definition of the data which is used to initialize
+  // CPP data representation. This is used from C API.
+  //
+  // \param data - pointer to a data container structure non_tensor type specific
+  // \param data_size - size of the data container structure, used for rudimentary checks
+  // \param output - reference to a default constructed non-tensor type
+  // \returns OrtValue
+  // \throw if there is an error
+  virtual void FromDataContainer(const void* data, size_t data_size, OrtValue& output) const;
+
+  // \brief Override for Non-tensor types to fetch data from the internal CPP data representation
+  // The caller of the interface should have a shared definition of the data which is used to initialize
+  // CPP data representation. This is used from C API.
+  //
+  // \param input - OrtValue containing data
+  // \param data_size - size of the structure that is being passed for receiving data, used for
+  //                    validation
+  // \param data - pointer to receiving data structure
+  virtual void ToDataContainer(const OrtValue& input, size_t data_size, void* data) const;
+
+  NonTensorTypeBase(const NonTensorTypeBase&) = delete;
+  NonTensorTypeBase& operator=(const NonTensorTypeBase&) = delete;
+
+ protected:
+  NonTensorTypeBase(size_t size);
+  ~NonTensorTypeBase() override;
+
+  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
+
+  bool IsMapCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
+
+  bool IsSequenceCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
+
+  bool IsOpaqueCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
+
+ private:
+  struct Impl;
+  Impl* impl_;
+};
+
+// This is where T is the actual CPPRuntimeType
+template <typename T>
+class NonTensorType : public NonTensorTypeBase {
+ private:
+  static void Delete(void* p) {
+    delete static_cast<T*>(p);
+  }
+
+ public:
+  DeleteFunc GetDeleteFunc() const override {
+    return &Delete;
+  }
+
+  CreateFunc GetCreateFunc() const override {
+    return []() -> void* { return new T(); };
+  }
+
+ protected:
+  NonTensorType() : NonTensorTypeBase(sizeof(T)) {}
+};
+
+#if !defined(DISABLE_ML_OPS)
+/**
+ * \brief MapType. Use this type to register
+ * mapping types.
+ *
+ * \param T - cpp type that you wish to register as runtime MapType
+ *
+ * \details Usage: ORT_REGISTER_MAP(C++Type)
+ *          The type is required to have mapped_type and
+ *          key_type defined
+ */
+template <typename CPPType>
+class MapType : public NonTensorType<CPPType> {
+ public:
+  static_assert(data_types_internal::IsTensorContainedType<typename CPPType::key_type>::value,
+                "Requires one of the tensor fundamental types as key");
+
+  static MLDataType Type();
+
+  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
+    return this->IsMapCompatible(type_proto);
+  }
+
+ private:
+  MapType() {
+    using namespace data_types_internal;
+    MapTypeHelper::Set(utils::ToTensorProtoElementType<typename CPPType::key_type>(),
+                       MapTypeHelper::GetValueType<typename CPPType::mapped_type>()->GetTypeProto(),
+                       this->MutableTypeProto());
+  }
+};
+#endif
+
+/**
+ * \brief SequenceType. Use to register sequence for non-tensor types.
+ *
+ * \param T - CPP type that you wish to register as Sequence
+ *            runtime type.
+ *
+ * \details Usage: ORT_REGISTER_SEQ(C++Type)
+ *          The type is required to have value_type defined
+ */
+template <typename CPPType>
+class SequenceType : public NonTensorType<CPPType> {
+ public:
+  static MLDataType Type();
+
+  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
+    return this->IsSequenceCompatible(type_proto);
+  }
+
+ private:
+  SequenceType() {
+    using namespace data_types_internal;
+    SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<typename CPPType::value_type>()->GetTypeProto(),
+                            this->MutableTypeProto());
+  }
+};
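(Aside: as the usage notes above say, registration keys off nested typedefs; a hedged sketch with illustrative aliases, using the ORT_REGISTER_MAP/ORT_REGISTER_SEQ macros defined later in this header.)

    // MapType requires key_type and mapped_type; SequenceType requires value_type.
    using MyMap = std::map<int64_t, float>;  // key_type = int64_t, mapped_type = float
    using MyFloatSeq = std::vector<float>;   // value_type = float
    ORT_REGISTER_MAP(MyMap);                 // enables DataTypeImpl::GetType<MyMap>()
    ORT_REGISTER_SEQ(MyFloatSeq);            // enables DataTypeImpl::GetType<MyFloatSeq>()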
+/**
+ * \brief SequenceTensorTypeBase serves as a base type class for
+ *        Tensor sequences. Akin to TensorTypeBase.
+ *        Runtime representation is always TensorSeq.
+ */
+class SequenceTensorTypeBase : public DataTypeImpl {
+ public:
+  static MLDataType Type();
+
+  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
+
+  virtual MLDataType GetElementType() const {
+    // should never reach here.
+    ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
+  }
+
+  DeleteFunc GetDeleteFunc() const override;
+
+  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
+
+  SequenceTensorTypeBase(const SequenceTensorTypeBase&) = delete;
+  SequenceTensorTypeBase& operator=(const SequenceTensorTypeBase&) = delete;
+
+ protected:
+  SequenceTensorTypeBase();
+  ~SequenceTensorTypeBase();
+
+  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
+
+ private:
+  struct Impl;
+  Impl* impl_;
+};
+#if defined(_MSC_VER) && !defined(__clang__)
+#pragma warning(pop)
+#endif
+/**
+ * \brief SequenceTensorType. Use to register sequence for non-tensor types.
+ *
+ * \param CPPRuntime - We always use TensorSeq
+ *
+ * \param TensorElemType - one of the primitive types
+ *
+ * \details Usage: ORT_REGISTER_SEQ_TENSOR_TYPE()
+ *          The type is required to have value_type defined
+ */
+template <typename TensorElemType>
+class SequenceTensorType : public SequenceTensorTypeBase {
+ public:
+  static_assert(data_types_internal::IsTensorContainedType<TensorElemType>::value,
+                "Requires one of the tensor fundamental types");
+
+  static MLDataType Type();
+
+  /// Return a MLDataType representing the element-type
+  MLDataType GetElementType() const override {
+    return DataTypeImpl::GetType<TensorElemType>();
+  }
+
+ private:
+  SequenceTensorType() {
+    using namespace data_types_internal;
+    SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<TensorElemType>()->GetTypeProto(),
+                            MutableTypeProto());
+  }
+};
+
+/**
+ * \brief OpaqueType
+ *
+ * \tparam T - cpp runtime type that implements the Opaque type
+ *
+ * \tparam const char D[] - domain must be extern to be unique
+ *
+ * \tparam const char N[] - name must be extern to be unique
+ *
+ * \details Only one CPP type can be associated with a particular
+ *          OpaqueType registration
+ *
+ */
+template <typename T, const char D[], const char N[]>
+class OpaqueType : public NonTensorType<T> {
+ public:
+  static MLDataType Type();
+
+  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
+    return this->IsOpaqueCompatible(type_proto);
+  }
+
+  void FromDataContainer(const void* data, size_t data_size, OrtValue& output) const override {
+    NonTensorTypeConverter<T>::FromContainer(this, data, data_size, output);
+  }
+
+  void ToDataContainer(const OrtValue& input, size_t data_size, void* data) const override {
+    NonTensorTypeConverter<T>::ToContainer(input, data_size, data);
+  }
+
+ private:
+  OpaqueType() {
+    data_types_internal::AssignOpaqueDomainName(D, N, this->MutableTypeProto());
+  }
+};
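(Aside: a hedged sketch of an opaque registration; the domain/name constants and the state type are hypothetical.)

    // Domain and name must have external linkage so their pointers are unique.
    extern const char kMyDomain[] = "my.domain";
    extern const char kMyName[] = "KalmanFilterState";
    struct KalmanFilterState { /* opaque payload */ };
    // Later, in a .cc file: ORT_REGISTER_OPAQUE_TYPE(KalmanFilterState, kMyDomain, kMyName);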
+/**
+ * \brief PrimitiveDataTypeBase
+ *        Base class for primitive Tensor contained types
+ *
+ * \details This class contains an integer constant that can be
+ *          used for input data type dispatching. This class also stores the
+ *          number of subelements per size unit.
+ *          Example: For int4, the size unit is 1 byte and the number of subelements is 2.
+ *
+ */
+class PrimitiveDataTypeBase : public DataTypeImpl {
+ public:
+  bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override {
+    return false;
+  }
+
+  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const final {
+    return nullptr;
+  }
+
+  int32_t GetDataType() const {
+    return data_type_;
+  }
+
+  int32_t GetNumSubElems() const {
+    return num_sub_elems_;
+  }
+
+  bool HasSubElems() const {
+    return num_sub_elems_ > 1;
+  }
+
+ protected:
+  PrimitiveDataTypeBase(size_t size, int32_t data_type, int32_t num_sub_elems)
+      : DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type}, num_sub_elems_{num_sub_elems} {}
+
+ private:
+  const int32_t data_type_;
+  const int32_t num_sub_elems_;  // > 1 for subbyte primitives, 1 for normal primitives.
+};
+
+/**
+ * \brief PrimitiveDataType
+ *        Typed specialization for primitive types.
+ *        Concrete instances of this class are used by Tensor.
+ *
+ * \param T - primitive data type
+ *
+ */
+template <typename T>
+class PrimitiveDataType : public PrimitiveDataTypeBase {
+ private:
+  static void Delete(void* p) {
+    delete static_cast<T*>(p);
+  }
+
+ public:
+  static MLDataType Type();
+
+  DeleteFunc GetDeleteFunc() const override {
+    return &Delete;
+  }
+
+ private:
+  explicit PrimitiveDataType(int32_t num_sub_elems)
+      : PrimitiveDataTypeBase{sizeof(T),
+                              utils::ToTensorProtoElementType<T>(), num_sub_elems} {
+  }
+};
+
+inline const TensorTypeBase* DataTypeImpl::AsTensorType() const {
+  return IsTensorType() ? static_cast<const TensorTypeBase*>(this) : nullptr;
+}
+
+inline const SequenceTensorTypeBase* DataTypeImpl::AsSequenceTensorType() const {
+  return IsTensorSequenceType() ? static_cast<const SequenceTensorTypeBase*>(this) : nullptr;
+}
+
+#if !defined(DISABLE_SPARSE_TENSORS)
+inline const SparseTensorTypeBase* DataTypeImpl::AsSparseTensorType() const {
+  return IsSparseTensorType() ? static_cast<const SparseTensorTypeBase*>(this) : nullptr;
+}
+#endif
+
+#if !defined(DISABLE_OPTIONAL_TYPE)
+inline const OptionalTypeBase* DataTypeImpl::AsOptionalType() const {
+  return IsOptionalType() ? static_cast<const OptionalTypeBase*>(this) : nullptr;
+}
+#endif
+
+inline const NonTensorTypeBase* DataTypeImpl::AsNonTensorType() const {
+  return IsNonTensorType() ? static_cast<const NonTensorTypeBase*>(this) : nullptr;
+}
+
+inline const PrimitiveDataTypeBase* DataTypeImpl::AsPrimitiveDataType() const {
+  return IsPrimitiveDataType() ? static_cast<const PrimitiveDataTypeBase*>(this) : nullptr;
+}
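(Aside: the sub-element bookkeeping above is what makes packed sub-byte types work; a hedged arithmetic sketch for Int4x2, where two 4-bit values share one 1-byte size unit.)

    size_t num_elements = 5;  // logical int4 element count
    MLDataType dt = DataTypeImpl::GetType<Int4x2>();
    const PrimitiveDataTypeBase* prim = dt->AsPrimitiveDataType();
    // GetNumSubElems() is 2 for Int4x2, so N logical elements need ceil(N / 2) bytes.
    size_t bytes = (num_elements + prim->GetNumSubElems() - 1) / prim->GetNumSubElems();  // 3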
+// Explicit specialization of base class template function
+// is only possible within the enclosing namespace scope,
+// thus a simple way to pre-instantiate a given template
+// at a registration time does not currently work and the macro
+// is needed.
+#define ORT_REGISTER_TENSOR_TYPE(ELEM_TYPE)             \
+  template <>                                           \
+  MLDataType TensorType<ELEM_TYPE>::Type() {            \
+    static TensorType<ELEM_TYPE> tensor_type;           \
+    return &tensor_type;                                 \
+  }                                                     \
+  template <>                                           \
+  MLDataType DataTypeImpl::GetTensorType<ELEM_TYPE>() { \
+    return TensorType<ELEM_TYPE>::Type();               \
+  }
+
+#if !defined(DISABLE_SPARSE_TENSORS)
+#define ORT_REGISTER_SPARSE_TENSOR_TYPE(ELEM_TYPE)            \
+  template <>                                                 \
+  MLDataType SparseTensorType<ELEM_TYPE>::Type() {            \
+    static SparseTensorType<ELEM_TYPE> tensor_type;           \
+    return &tensor_type;                                      \
+  }                                                           \
+  template <>                                                 \
+  MLDataType DataTypeImpl::GetSparseTensorType<ELEM_TYPE>() { \
+    return SparseTensorType<ELEM_TYPE>::Type();               \
+  }
+#endif
+
+#define ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, TYPE)             \
+  template <>                                                  \
+  MLDataType OptionalType<ORT_TYPE, TYPE>::Type() {            \
+    static OptionalType<ORT_TYPE, TYPE> optional_type;         \
+    return &optional_type;                                     \
+  }                                                            \
+  template <>                                                  \
+  MLDataType DataTypeImpl::GetOptionalType<ORT_TYPE, TYPE>() { \
+    return OptionalType<ORT_TYPE, TYPE>::Type();               \
+  }
+
+#if !defined(DISABLE_ML_OPS)
+#define ORT_REGISTER_MAP(TYPE)               \
+  template <>                                \
+  MLDataType MapType<TYPE>::Type() {         \
+    static MapType<TYPE> map_type;           \
+    return &map_type;                        \
+  }                                          \
+  template <>                                \
+  MLDataType DataTypeImpl::GetType<TYPE>() { \
+    return MapType<TYPE>::Type();            \
+  }
+#endif
+
+#define ORT_REGISTER_SEQ(TYPE)               \
+  template <>                                \
+  MLDataType SequenceType<TYPE>::Type() {    \
+    static SequenceType<TYPE> sequence_type; \
+    return &sequence_type;                   \
+  }                                          \
+  template <>                                \
+  MLDataType DataTypeImpl::GetType<TYPE>() { \
+    return SequenceType<TYPE>::Type();       \
+  }
+
+#define ORT_REGISTER_SEQ_TENSOR_TYPE(ELEM_TYPE)                 \
+  template <>                                                   \
+  MLDataType SequenceTensorType<ELEM_TYPE>::Type() {            \
+    static SequenceTensorType<ELEM_TYPE> sequence_tensor_type;  \
+    return &sequence_tensor_type;                               \
+  }                                                             \
+  template <>                                                   \
+  MLDataType DataTypeImpl::GetSequenceTensorType<ELEM_TYPE>() { \
+    return SequenceTensorType<ELEM_TYPE>::Type();               \
+  }
+
+#define ORT_REGISTER_PRIM_TYPE(TYPE)                  \
+  template <>                                         \
+  MLDataType PrimitiveDataType<TYPE>::Type() {        \
+    static PrimitiveDataType<TYPE> prim_data_type(1); \
+    return &prim_data_type;                           \
+  }                                                   \
+  template <>                                         \
+  MLDataType DataTypeImpl::GetType<TYPE>() {          \
+    return PrimitiveDataType<TYPE>::Type();           \
+  }
+
+// Registers a subbyte primitive.
+// Examples:
+//   - Int4x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Int4x2, 2)
+//   - [not supported] Int3x8 could store 8 packed 3-bit elements in 3 bytes: ORT_*_SUBBYTE_TYPE(Int3x8, 8)
+#define ORT_REGISTER_PRIM_SUBBYTE_TYPE(TYPE, NUM_SUB_ELEMS)       \
+  template <>                                                     \
+  MLDataType PrimitiveDataType<TYPE>::Type() {                    \
+    static PrimitiveDataType<TYPE> prim_data_type(NUM_SUB_ELEMS); \
+    return &prim_data_type;                                       \
+  }                                                               \
+  template <>                                                     \
+  MLDataType DataTypeImpl::GetType<TYPE>() {                      \
+    return PrimitiveDataType<TYPE>::Type();                       \
+  }
+
+#define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name)   \
+  template <>                                             \
+  MLDataType OpaqueType<CPPType, Domain, Name>::Type() {  \
+    static OpaqueType<CPPType, Domain, Name> opaque_type; \
+    return &opaque_type;                                  \
+  }                                                       \
+  template <>                                             \
+  MLDataType DataTypeImpl::GetType<CPPType>() {           \
+    return OpaqueType<CPPType, Domain, Name>::Type();     \
+  }
+}  // namespace onnxruntime
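(Aside: for reference, substituting ELEM_TYPE = float into ORT_REGISTER_TENSOR_TYPE above yields the following, reformatted for readability.)

    template <>
    MLDataType TensorType<float>::Type() {
      static TensorType<float> tensor_type;  // function-local singleton
      return &tensor_type;
    }
    template <>
    MLDataType DataTypeImpl::GetTensorType<float>() {
      return TensorType<float>::Type();
    }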
diff --git a/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/data_types_internal.h b/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/data_types_internal.h
new file mode 100644
index 00000000000000..05f4c10995ef2a
--- /dev/null
+++ b/third_party/onnxruntime_headers/src/include/onnxruntime/core/framework/data_types_internal.h
@@ -0,0 +1,712 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <cstdint>
+#include <algorithm>
+#include <array>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "boost/mp11.hpp"
+
+#include "core/common/common.h"
+#include "core/framework/to_tensor_proto_element_type.h"
+#ifndef SHARED_PROVIDER
+#include "core/common/type_list.h"
+#include "core/framework/data_types.h"
+#include "core/graph/onnx_protobuf.h"
+#endif
+
+namespace onnxruntime {
+namespace utils {
+
+// The following primitives are strongly recommended for switching on tensor input datatypes for
+// kernel implementations.
+//
+//  1) If you need to handle all of the primitive tensor contained datatypes, the best choice would be macros
+//     DispatchOnTensorType or DispatchOnTensorTypeWithReturn. Use inline wrappers so your function can be invoked as function<T>().
+//  2) If you have a few types, use Tensor.IsDataType<T>()/IsDataTypeString() or use utils::IsPrimitiveDataType<T>()
+//     if you have a standalone MLDataType with a sequence of if/else statements.
+//  3) For something in between, we suggest to use the CallDispatcher pattern.
+//
+// Invoking DataTypeImpl::GetType<T>() for switching on input types is discouraged and should be avoided.
+// Every primitive type carries with it an integer constant that can be used for quick switching on types.
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+#define DispatchOnTensorType(tensor_type, function, ...)            \
+  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {      \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:                \
+      function<float>(__VA_ARGS__);                                 \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:                 \
+      function<bool>(__VA_ARGS__);                                  \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:               \
+      function<double>(__VA_ARGS__);                                \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_STRING:               \
+      function<std::string>(__VA_ARGS__);                           \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:                 \
+      function<int8_t>(__VA_ARGS__);                                \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:                \
+      function<uint8_t>(__VA_ARGS__);                               \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT16:                \
+      function<int16_t>(__VA_ARGS__);                               \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:               \
+      function<uint16_t>(__VA_ARGS__);                              \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT32:                \
+      function<int32_t>(__VA_ARGS__);                               \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:               \
+      function<uint32_t>(__VA_ARGS__);                              \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT64:                \
+      function<int64_t>(__VA_ARGS__);                               \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:               \
+      function<uint64_t>(__VA_ARGS__);                              \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:              \
+      function<MLFloat16>(__VA_ARGS__);                             \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:             \
+      function<BFloat16>(__VA_ARGS__);                              \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:         \
+      function<Float8E4M3FN>(__VA_ARGS__);                          \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:       \
+      function<Float8E4M3FNUZ>(__VA_ARGS__);                        \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:           \
+      function<Float8E5M2>(__VA_ARGS__);                            \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:       \
+      function<Float8E5M2FNUZ>(__VA_ARGS__);                        \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT4:                 \
+      function<Int4x2>(__VA_ARGS__);                                \
+      break;                                                        \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:                \
+      function<UInt4x2>(__VA_ARGS__);                               \
+      break;                                                        \
+    default:                                                        \
+      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type);   \
+  }
+#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
+  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {             \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:                       \
+      retval = function<float>(__VA_ARGS__);                               \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:                        \
+      retval = function<bool>(__VA_ARGS__);                                \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:                      \
+      retval = function<double>(__VA_ARGS__);                              \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_STRING:                      \
+      retval = function<std::string>(__VA_ARGS__);                         \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:                        \
+      retval = function<int8_t>(__VA_ARGS__);                              \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:                       \
+      retval = function<uint8_t>(__VA_ARGS__);                             \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:                      \
+      retval = function<uint16_t>(__VA_ARGS__);                            \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT16:                       \
+      retval = function<int16_t>(__VA_ARGS__);                             \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT32:                       \
+      retval = function<int32_t>(__VA_ARGS__);                             \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:                      \
+      retval = function<uint32_t>(__VA_ARGS__);                            \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT64:                       \
+      retval = function<int64_t>(__VA_ARGS__);                             \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:                      \
+      retval = function<uint64_t>(__VA_ARGS__);                            \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:                     \
+      retval = function<MLFloat16>(__VA_ARGS__);                           \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:                    \
+      retval = function<BFloat16>(__VA_ARGS__);                            \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:                \
+      retval = function<Float8E4M3FN>(__VA_ARGS__);                        \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:              \
+      retval = function<Float8E4M3FNUZ>(__VA_ARGS__);                      \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:                  \
+      retval = function<Float8E5M2>(__VA_ARGS__);                          \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:              \
+      retval = function<Float8E5M2FNUZ>(__VA_ARGS__);                      \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT4:                        \
+      retval = function<Int4x2>(__VA_ARGS__);                              \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:                       \
+      retval = function<UInt4x2>(__VA_ARGS__);                             \
+      break;                                                               \
+    default:                                                               \
+      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type);          \
+  }
+
+#else
+
+#define DispatchOnTensorType(tensor_type, function, ...)           \
+  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {     \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:               \
+      function<float>(__VA_ARGS__);                                \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:                \
+      function<bool>(__VA_ARGS__);                                 \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:              \
+      function<double>(__VA_ARGS__);                               \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_STRING:              \
+      function<std::string>(__VA_ARGS__);                          \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:                \
+      function<int8_t>(__VA_ARGS__);                               \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:               \
+      function<uint8_t>(__VA_ARGS__);                              \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT16:               \
+      function<int16_t>(__VA_ARGS__);                              \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:              \
+      function<uint16_t>(__VA_ARGS__);                             \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT32:               \
+      function<int32_t>(__VA_ARGS__);                              \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:              \
+      function<uint32_t>(__VA_ARGS__);                             \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT64:               \
+      function<int64_t>(__VA_ARGS__);                              \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:              \
+      function<uint64_t>(__VA_ARGS__);                             \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:             \
+      function<MLFloat16>(__VA_ARGS__);                            \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:            \
+      function<BFloat16>(__VA_ARGS__);                             \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT4:                \
+      function<Int4x2>(__VA_ARGS__);                               \
+      break;                                                       \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:               \
+      function<UInt4x2>(__VA_ARGS__);                              \
+      break;                                                       \
+    default:                                                       \
+      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type);  \
+  }
+
+#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
+  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {             \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:                       \
+      retval = function<float>(__VA_ARGS__);                               \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:                        \
+      retval = function<bool>(__VA_ARGS__);                                \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:                      \
+      retval = function<double>(__VA_ARGS__);                              \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_STRING:                      \
+      retval = function<std::string>(__VA_ARGS__);                         \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:                        \
+      retval = function<int8_t>(__VA_ARGS__);                              \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:                       \
+      retval = function<uint8_t>(__VA_ARGS__);                             \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:                      \
+      retval = function<uint16_t>(__VA_ARGS__);                            \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT16:                       \
+      retval = function<int16_t>(__VA_ARGS__);                             \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT32:                       \
+      retval = function<int32_t>(__VA_ARGS__);                             \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:                      \
+      retval = function<uint32_t>(__VA_ARGS__);                            \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT64:                       \
+      retval = function<int64_t>(__VA_ARGS__);                             \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:                      \
+      retval = function<uint64_t>(__VA_ARGS__);                            \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:                     \
+      retval = function<MLFloat16>(__VA_ARGS__);                           \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:                    \
+      retval = function<BFloat16>(__VA_ARGS__);                            \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT4:                        \
+      retval = function<Int4x2>(__VA_ARGS__);                              \
+      break;                                                               \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:                       \
+      retval = function<UInt4x2>(__VA_ARGS__);                             \
+      break;                                                               \
+    default:                                                               \
+      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type);          \
+  }
+
+#endif
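(Aside: a hedged usage sketch for the dispatch macros above. FillZero is a hypothetical function template; the macro instantiates it as function<T>(...) for the single case matching the runtime element type. Assumes <algorithm> and the Tensor span accessor are available.)

    template <typename T>
    void FillZero(Tensor& t) {
      auto span = t.MutableDataAsSpan<T>();
      std::fill(span.begin(), span.end(), T{});
    }

    void FillZeroAnyType(Tensor& t) {
      MLDataType dt = t.DataType();
      DispatchOnTensorType(dt, FillZero, t);  // selects the FillZero<T> matching dt at runtime
    }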
+
+////////////////////////////////////////////////////////////////////////////////
+/// Use the following primitives if you have a few types to switch on so you
+//  can write a short sequence of if/else statements.
+
+// This is a frequently used check so we make a separate utility function.
+inline bool IsDataTypeString(MLDataType dt_type) {
+  auto prim_type = dt_type->AsPrimitiveDataType();
+  return (prim_type != nullptr && prim_type->GetDataType() == ONNX_NAMESPACE::TensorProto_DataType_STRING);
+}
+
+// Test if MLDataType is a concrete type of PrimitiveDataTypeBase
+// and it is T
+template <class T>
+inline bool IsPrimitiveDataType(MLDataType dt_type) {
+  auto prim_type = dt_type->AsPrimitiveDataType();
+  return (prim_type != nullptr && prim_type->GetDataType() == ToTensorProtoElementType<T>());
+}
+
+// Use after AsPrimitiveDataType() is successful
+// Check if PrimitiveDataTypeBase is of type T
+template <class T>
+inline bool IsPrimitiveDataType(const PrimitiveDataTypeBase* prim_type) {
+  assert(prim_type != nullptr);
+  return prim_type->GetDataType() == ToTensorProtoElementType<T>();
+}
+
+// This implementation contains a workaround for GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=47226
+// GCC until very recently did not support template parameter pack expansion within a lambda context.
+namespace mltype_dispatcher_internal {
+
+// T - type handled by this helper
+class CallableDispatchableHelper {
+  int32_t dt_type_;  // Type currently dispatched
+  size_t called_;
+
+ public:
+  explicit CallableDispatchableHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0) {}
+
+  // Must return integer to be in an expandable context
+  template <class T, class Fn, class... Args>
+  int Invoke(Fn&& fn, Args&&... args) {
+    if (utils::ToTensorProtoElementType<T>() == dt_type_) {
+      std::forward<Fn>(fn)(std::forward<Args>(args)...);
+      ++called_;
+    }
+    return 0;
+  }
+
+  void CheckCalledOnce() const {
+    ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_);
+  }
+};
+
+// Default policy is to throw an exception.
+// Other policies may set the second result argument accordingly.
+template <class Ret>
+struct UnsupportedTypeDefaultPolicy {
+  void operator()(int32_t dt_type, Ret& /*result*/) const {
+    ORT_THROW("Unsupported data type: ", dt_type);
+  }
+};
+
+// Helper with the result type
+template <class Ret, class UnsupportedPolicy>
+class CallableDispatchableRetHelper {
+  int32_t dt_type_;  // Type currently dispatched
+  size_t called_;
+  Ret result_;
+
+ public:
+  explicit CallableDispatchableRetHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0), result_() {}
+
+  Ret Get() {
+    // No type was invoked
+    if (called_ == 0) {
+      UnsupportedPolicy()(dt_type_, result_);
+    }
+    return result_;
+  }
+
+  // Must return integer to be in an expandable context
+  template <class T, class Fn, class... Args>
+  int Invoke(Fn&& fn, Args&&... args) {
+    if (utils::ToTensorProtoElementType<T>() == dt_type_) {
+      result_ = std::forward<Fn>(fn)(std::forward<Args>(args)...);
+      ++called_;
+    }
+    return 0;
+  }
+};
+
+template <typename T>
+using TensorProtoElementTypeConstant =
+    std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ToTensorProtoElementType<T>()>;
+
+using UndefinedTensorProtoElementTypeConstant =
+    std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED>;
+
+}  // namespace mltype_dispatcher_internal
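(Aside: these helpers exist so that one Invoke<T> per supported type can be expanded inside a brace-initialized array, sidestepping the GCC lambda/pack-expansion bug cited above. Roughly how a dispatcher uses them; this is an illustrative sketch, not the literal implementation.)

    template <typename... Types, typename Fn, typename... Args>
    void DispatchSketch(int32_t dt_type, Fn fn, Args&... args) {
      mltype_dispatcher_internal::CallableDispatchableHelper helper(dt_type);
      // One Invoke<T> per supported type; only the T matching dt_type calls fn.
      int expander[] = {0, helper.Invoke<Types>(fn, args...)...};
      (void)expander;
      helper.CheckCalledOnce();  // enforce exactly one match
    }
    // Usage: DispatchSketch<float, double>(dt, MyFn{});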
+/**
+ * This class helps to efficiently dispatch calls to implementation function
+ * objects with a tensor element type template argument.
+ *
+ * The constructor accepts a value corresponding to a tensor element type.
+ * For example, it can be obtained from:
+ *   input_tensor->GetElementType()
+ *
+ * The Invoke member functions will instantiate and invoke the provided
+ * function object template, Fn. Fn must be default constructible. Fn must also
+ * have a tensor element type template argument. This type template argument
+ * will be the type that corresponds to the value given in the constructor.
+ * These functions accept and forward arbitrary function arguments. They ensure
+ * that Fn is called once with the type specified in the constructor.
+ *
+ * @tparam Types The types supported by the implementation. This should be a
+ *         set of ONNX tensor element types that are supported by ORT.
+ */
+template <typename... Types>
+class MLTypeCallDispatcher {
+  using SupportedTypeList = TypeList<Types...>;
+  using SupportedTensorProtoElementTypeList =
+      boost::mp11::mp_transform<
+          mltype_dispatcher_internal::TensorProtoElementTypeConstant, SupportedTypeList>;
+
+  static_assert(
+      boost::mp11::mp_and<
+          boost::mp11::mp_is_set<SupportedTensorProtoElementTypeList>,
+          boost::mp11::mp_not<
+              boost::mp11::mp_set_contains<
+                  SupportedTensorProtoElementTypeList,
+                  mltype_dispatcher_internal::UndefinedTensorProtoElementTypeConstant>>>::value,
+      "Types must map to a unique set of ONNX tensor element data types supported by ORT.");
+
+  int32_t dt_type_;
+
+ public:
+  /**
+   * Constructor.
+   * @param dt_type The value corresponding to the tensor element type to be
+   *        dispatched to. This can be obtained from
+   *        input_tensor->GetElementType() or
+   *        utils::ToTensorProtoElementType<T>().
+   */
+  explicit MLTypeCallDispatcher(int32_t dt_type) noexcept : dt_type_(dt_type) {}
+
+  /**
+   * Invokes Fn<T> with the specified arguments.
+   *
+   * @tparam Fn The function object template.
+   * @tparam Args The argument types.
+   */
+  template <template <typename...> class Fn, typename... Args>