forked from chromium/chromium
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Wrap OrtGraph API into OrtModelBuilder
- Loading branch information
Showing
8 changed files
with
366 additions
and
265 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
// Copyright 2024 The Chromium Authors | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#include "services/webnn/ort/ort_model_builder.h" | ||
|
||
#include "base/notreached.h" | ||
#include "services/webnn/ort/error_ort.h" | ||
#include "services/webnn/ort/utils_ort.h" | ||
|
||
namespace webnn { | ||
|
||
namespace { | ||
|
||
constexpr char kOrtDomainName[] = ""; | ||
constexpr int32_t kOrtOpsetVersion = 21; | ||
|
||
} // namespace | ||
|
||
namespace ort { | ||
|
||
OrtModelBuilder::ModelInfo::ModelInfo() = default; | ||
OrtModelBuilder::ModelInfo::~ModelInfo() = default; | ||
|
||
OrtModelBuilder::OrtModelBuilder(scoped_refptr<AllocatorOrt> allocator) | ||
: allocator_(std::move(allocator)), | ||
model_info_(std::make_unique<ModelInfo>()) { | ||
CHECK_STATUS(GetOrtGraphApi()->CreateGraph(graph_.get_pptr())); | ||
} | ||
OrtModelBuilder::~OrtModelBuilder() = default; | ||
|
||
void OrtModelBuilder::AddInput(std::string_view name, | ||
base::span<const int64_t> shape, | ||
ONNXTensorElementDataType data_type) { | ||
ScopedOrtShape input_shape; | ||
CHECK_STATUS(GetOrtGraphApi()->CreateFixedShape(shape.data(), shape.size(), | ||
input_shape.get_pptr())); | ||
|
||
ScopedOrtValueInfo input_info; | ||
CHECK_STATUS(GetOrtGraphApi()->CreateTensorValueInfo( | ||
name.data(), data_type, input_shape.get_pptr(), input_info.get_pptr())); | ||
CHECK_STATUS( | ||
GetOrtGraphApi()->AddInput(graph_.get_ptr(), input_info.get_pptr())); | ||
} | ||
|
||
void OrtModelBuilder::AddOutput(std::string_view name, | ||
base::span<const int64_t> shape, | ||
ONNXTensorElementDataType data_type) { | ||
ScopedOrtShape output_shape; | ||
CHECK_STATUS(GetOrtGraphApi()->CreateFixedShape(shape.data(), shape.size(), | ||
output_shape.get_pptr())); | ||
|
||
ScopedOrtValueInfo output_info; | ||
CHECK_STATUS(GetOrtGraphApi()->CreateTensorValueInfo( | ||
name.data(), data_type, output_shape.get_pptr(), output_info.get_pptr())); | ||
CHECK_STATUS( | ||
GetOrtGraphApi()->AddOutput(graph_.get_ptr(), output_info.get_pptr())); | ||
} | ||
|
||
void OrtModelBuilder::AddInitializerAsRawData( | ||
std::string_view name, | ||
base::span<const int64_t> shape, | ||
base::span<const uint8_t> data, | ||
ONNXTensorElementDataType data_type) { | ||
ScopedOrtValue initializer; | ||
CHECK_STATUS(GetOrtApi()->CreateTensorAsOrtValue( | ||
allocator_->allocator(), shape.data(), shape.size(), data_type, | ||
initializer.get_pptr())); | ||
|
||
void* ort_tensor_raw_data = nullptr; | ||
CHECK_STATUS(GetOrtApi()->GetTensorMutableData(initializer.get_ptr(), | ||
&ort_tensor_raw_data)); | ||
CHECK(ort_tensor_raw_data); | ||
UNSAFE_BUFFERS( | ||
base::span(static_cast<uint8_t*>(ort_tensor_raw_data), data.size())) | ||
.copy_from(data); | ||
CHECK_STATUS(GetOrtGraphApi()->AddInitializer(graph_.get_ptr(), name.data(), | ||
initializer.get_pptr())); | ||
} | ||
|
||
void OrtModelBuilder::AddInitializerAsExternalData( | ||
std::string_view name, | ||
base::span<const int64_t> shape, | ||
base::span<const uint8_t> data, | ||
ONNXTensorElementDataType data_type) { | ||
auto weight = base::HeapArray<uint8_t>::CopiedFrom(data); | ||
model_info_->external_data.push_back(std::move(weight)); | ||
|
||
ScopedOrtValue initializer; | ||
CHECK_STATUS(GetOrtApi()->CreateTensorWithDataAsOrtValue( | ||
allocator_->memory_info(), model_info_->external_data.back().data(), | ||
model_info_->external_data.back().size(), shape.data(), shape.size(), | ||
data_type, initializer.get_pptr())); | ||
CHECK_STATUS(GetOrtGraphApi()->AddInitializer(graph_.get_ptr(), name.data(), | ||
initializer.get_pptr())); | ||
} | ||
|
||
void OrtModelBuilder::CreateAttribute(ScopedOrtOpAttr& attribute, | ||
std::string_view name, | ||
OrtOpAttrData data) { | ||
if (absl::holds_alternative<int64_t>(data)) { | ||
CHECK_STATUS(GetOrtApi()->CreateOpAttr( | ||
name.data(), &absl::get<int64_t>(data), /*len=*/1, | ||
OrtOpAttrType::ORT_OP_ATTR_INT, attribute.get_pptr())); | ||
} else if (absl::holds_alternative<float>(data)) { | ||
CHECK_STATUS(GetOrtApi()->CreateOpAttr( | ||
name.data(), &absl::get<float>(data), /*len=*/1, | ||
OrtOpAttrType::ORT_OP_ATTR_FLOAT, attribute.get_pptr())); | ||
} else if (absl::holds_alternative<std::string_view>(data)) { | ||
std::string_view string_data = absl::get<std::string_view>(data); | ||
CHECK_STATUS(GetOrtApi()->CreateOpAttr( | ||
name.data(), string_data.data(), string_data.size(), | ||
OrtOpAttrType::ORT_OP_ATTR_STRING, attribute.get_pptr())); | ||
} else if (absl::holds_alternative<base::span<const int64_t>>(data)) { | ||
base::span<const int64_t> ints_data = | ||
absl::get<base::span<const int64_t>>(data); | ||
CHECK_STATUS(GetOrtApi()->CreateOpAttr( | ||
name.data(), ints_data.data(), ints_data.size(), | ||
OrtOpAttrType::ORT_OP_ATTR_INTS, attribute.get_pptr())); | ||
} else if (absl::holds_alternative<base::span<const float>>(data)) { | ||
base::span<const float> floats_data = | ||
absl::get<base::span<const float>>(data); | ||
CHECK_STATUS(GetOrtApi()->CreateOpAttr( | ||
name.data(), floats_data.data(), floats_data.size(), | ||
OrtOpAttrType::ORT_OP_ATTR_FLOATS, attribute.get_pptr())); | ||
} else if (absl::holds_alternative<base::span<const char*>>(data)) { | ||
base::span<const char*> strings_data = | ||
absl::get<base::span<const char*>>(data); | ||
CHECK_STATUS(GetOrtApi()->CreateOpAttr( | ||
name.data(), strings_data.data(), strings_data.size(), | ||
OrtOpAttrType::ORT_OP_ATTR_STRINGS, attribute.get_pptr())); | ||
} | ||
} | ||
|
||
void OrtModelBuilder::AddNode(std::string_view op_type, | ||
std::string_view node_name, | ||
base::span<const char*> input_names, | ||
base::span<const char*> output_names, | ||
base::span<OrtOpAttr**> attributes) { | ||
ScopedOrtNode node; | ||
CHECK_STATUS(GetOrtGraphApi()->CreateNode( | ||
op_type.data(), kOrtDomainName, node_name.data(), input_names.data(), | ||
input_names.size(), output_names.data(), output_names.size(), | ||
attributes.data(), attributes.size(), node.get_pptr())); | ||
CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr())); | ||
} | ||
|
||
std::unique_ptr<OrtModelBuilder::ModelInfo> | ||
OrtModelBuilder::BuildAndTakeModelInfo() { | ||
std::vector<const char*> domain_names = {kOrtDomainName}; | ||
std::vector<int32_t> opset_versions = {kOrtOpsetVersion}; | ||
|
||
CHECK_STATUS(GetOrtGraphApi()->CreateModel( | ||
domain_names.data(), opset_versions.data(), domain_names.size(), | ||
model_info_->model.get_pptr())); | ||
|
||
CHECK_STATUS(GetOrtGraphApi()->AddGraph(model_info_->model.get_ptr(), | ||
graph_.get_pptr())); | ||
|
||
return std::move(model_info_); | ||
} | ||
|
||
} // namespace ort | ||
|
||
} // namespace webnn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
// Copyright 2024 The Chromium Authors | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#ifndef SERVICES_WEBNN_ORT_ORT_MODEL_BUILDER_H_ | ||
#define SERVICES_WEBNN_ORT_ORT_MODEL_BUILDER_H_ | ||
|
||
#include <memory> | ||
#include <string> | ||
|
||
#include "base/containers/heap_array.h" | ||
#include "base/containers/span.h" | ||
#include "base/memory/stack_allocated.h" | ||
#include "services/webnn/ort/allocator_ort.h" | ||
#include "services/webnn/ort/scoped_ort_types.h" | ||
#include "third_party/abseil-cpp/absl/types/variant.h" | ||
#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h" | ||
|
||
namespace webnn { | ||
|
||
namespace ort { | ||
|
||
class OrtModelBuilder final { | ||
STACK_ALLOCATED(); | ||
|
||
public: | ||
struct ModelInfo { | ||
explicit ModelInfo(); | ||
ModelInfo(const ModelInfo&) = delete; | ||
ModelInfo& operator=(const ModelInfo&) = delete; | ||
~ModelInfo(); | ||
|
||
ScopedOrtModel model; | ||
|
||
// TODO: Consider reusing constant operands instead of copying them to | ||
// `external_data`. | ||
// | ||
// Store the external data which should be alive for inference session. | ||
std::vector<base::HeapArray<uint8_t>> external_data; | ||
}; | ||
|
||
explicit OrtModelBuilder(scoped_refptr<AllocatorOrt> allocator); | ||
~OrtModelBuilder(); | ||
OrtModelBuilder(const OrtModelBuilder&) = delete; | ||
OrtModelBuilder& operator=(const OrtModelBuilder&) = delete; | ||
|
||
void AddInput(std::string_view name, | ||
base::span<const int64_t> shape, | ||
ONNXTensorElementDataType data_type); | ||
|
||
void AddOutput(std::string_view name, | ||
base::span<const int64_t> shape, | ||
ONNXTensorElementDataType data_type); | ||
|
||
void AddInitializerAsRawData(std::string_view name, | ||
base::span<const int64_t> shape, | ||
base::span<const uint8_t> data, | ||
ONNXTensorElementDataType data_type); | ||
|
||
void AddInitializerAsExternalData(std::string_view name, | ||
base::span<const int64_t> shape, | ||
base::span<const uint8_t> data, | ||
ONNXTensorElementDataType data_type); | ||
using OrtOpAttrData = absl::variant<int64_t, | ||
float, | ||
std::string_view, | ||
base::span<const int64_t>, | ||
base::span<const float>, | ||
base::span<const char*>>; | ||
void CreateAttribute(ScopedOrtOpAttr& attribute, | ||
std::string_view name, | ||
OrtOpAttrData data); | ||
|
||
void AddNode(std::string_view op_type, | ||
std::string_view node_name, | ||
base::span<const char*> input_names, | ||
base::span<const char*> output_names, | ||
base::span<OrtOpAttr**> attributes = {}); | ||
|
||
std::unique_ptr<ModelInfo> BuildAndTakeModelInfo(); | ||
|
||
private: | ||
scoped_refptr<AllocatorOrt> allocator_; | ||
|
||
ScopedOrtGraph graph_; | ||
|
||
std::unique_ptr<ModelInfo> model_info_; | ||
}; | ||
|
||
} // namespace ort | ||
} // namespace webnn | ||
|
||
#endif // SERVICES_WEBNN_ORT_ORT_MODEL_BUILDER_H_ |
Oops, something went wrong.