Skip to content

Commit

Permalink
Wrap OrtGraph API into OrtModelBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyi9801 authored Dec 31, 2024
1 parent dc6ae60 commit 1eb9dfc
Show file tree
Hide file tree
Showing 8 changed files with 366 additions and 265 deletions.
2 changes: 2 additions & 0 deletions services/webnn/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ component("webnn_service") {
"ort/graph_builder_ort.h",
"ort/graph_impl_ort.cc",
"ort/graph_impl_ort.h",
"ort/ort_model_builder.cc",
"ort/ort_model_builder.h",
"ort/platform_functions_ort.cc",
"ort/platform_functions_ort.h",
"ort/scoped_ort_types.cc",
Expand Down
315 changes: 75 additions & 240 deletions services/webnn/ort/graph_builder_ort.cc

Large diffs are not rendered by default.

22 changes: 5 additions & 17 deletions services/webnn/ort/graph_builder_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <string_view>

#include "base/containers/flat_map.h"
#include "base/containers/heap_array.h"
#include "base/containers/span.h"
#include "base/files/file_path.h"
#include "base/memory/raw_ptr.h"
Expand All @@ -24,7 +23,7 @@
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_error.mojom-forward.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
#include "services/webnn/ort/ort_model_builder.h"

namespace webnn {

Expand Down Expand Up @@ -60,20 +59,11 @@ class GraphBuilderOrt {
Result& operator=(const Result&) = delete;
~Result();

const ScopedOrtModel& GetModel();

const OperandInfo& GetOperandInfo(uint64_t operand_id) const;

const std::map<uint64_t, OperandInfo>& id_to_operand_info_map() const;

ScopedOrtModel model;
std::map<uint64_t, OperandInfo> operand_infos;

// TODO: Consider reusing constant operands instead of copying them to
// `weights`.
//
// Store the weights which should be alive for inference session.
std::vector<base::HeapArray<uint8_t>> weights;
std::map<uint64_t, OperandInfo> id_to_operand_info;

std::unique_ptr<OrtModelBuilder::ModelInfo> model_info;
};

// Factory method that creates a GraphBuilderOrt, builds and serializes the
Expand Down Expand Up @@ -148,8 +138,6 @@ class GraphBuilderOrt {

[[nodiscard]] base::expected<void, mojom::ErrorPtr> BuildModel();

scoped_refptr<AllocatorOrt> allocator_;

// Used for inserting new operands into graph.
uint64_t next_operand_id_ = 0;

Expand All @@ -163,7 +151,7 @@ class GraphBuilderOrt {

const ContextProperties context_properties_;

ScopedOrtGraph graph_;
OrtModelBuilder model_builder_;

std::unique_ptr<Result> result_;
};
Expand Down
13 changes: 7 additions & 6 deletions services/webnn/ort/graph_impl_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ void GraphImplOrt::CreateAndBuild(
std::move(wrapped_callback));
}

GraphImplOrt::Session::Session(OrtSession* session,
std::vector<base::HeapArray<uint8_t>> weights)
: weights(std::move(weights)), session(session) {}
GraphImplOrt::Session::Session(
OrtSession* session,
std::vector<base::HeapArray<uint8_t>> external_data)
: external_data(std::move(external_data)), session(session) {}

GraphImplOrt::Session::~Session() {
// TODO: Can we call `ReleaseSession` from Dllmain (because session owns a
Expand Down Expand Up @@ -132,7 +133,7 @@ GraphImplOrt::CreateAndBuildOnBackgroundThread(
OrtSession* session;
const OrtEnv* env = allocator->env();
OrtStatus* status = GetOrtGraphApi()->CreateSessionFromModel(
env, result->model.get_ptr(), session_options, &session);
env, result->model_info->model.get_ptr(), session_options, &session);
ort_api->ReleaseSessionOptions(session_options);

if (status != NULL) {
Expand All @@ -145,8 +146,8 @@ GraphImplOrt::CreateAndBuildOnBackgroundThread(

LOG(ERROR) << "Running on ORT=============";

return base::WrapUnique(
new GraphImplOrt::Session(session, std::move(result->weights)));
return base::WrapUnique(new GraphImplOrt::Session(
session, std::move(result->model_info->external_data)));
}

// static
Expand Down
5 changes: 3 additions & 2 deletions services/webnn/ort/graph_impl_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,15 @@ class GraphImplOrt final : public WebNNGraphImpl {
~GraphImplOrt() override;

struct Session {
Session(OrtSession* session, std::vector<base::HeapArray<uint8_t>> weights);
Session(OrtSession* session,
std::vector<base::HeapArray<uint8_t>> external_data);
Session(const Session&) = delete;
Session& operator=(const Session&) = delete;
~Session();

OrtSession* GetSession() { return session.get(); }

std::vector<base::HeapArray<uint8_t>> weights;
std::vector<base::HeapArray<uint8_t>> external_data;
raw_ptr<OrtSession> session;
};

Expand Down
165 changes: 165 additions & 0 deletions services/webnn/ort/ort_model_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/ort/ort_model_builder.h"

#include "base/notreached.h"
#include "services/webnn/ort/error_ort.h"
#include "services/webnn/ort/utils_ort.h"

namespace webnn {

namespace {

constexpr char kOrtDomainName[] = "";
constexpr int32_t kOrtOpsetVersion = 21;

} // namespace

namespace ort {

OrtModelBuilder::ModelInfo::ModelInfo() = default;
OrtModelBuilder::ModelInfo::~ModelInfo() = default;

OrtModelBuilder::OrtModelBuilder(scoped_refptr<AllocatorOrt> allocator)
: allocator_(std::move(allocator)),
model_info_(std::make_unique<ModelInfo>()) {
CHECK_STATUS(GetOrtGraphApi()->CreateGraph(graph_.get_pptr()));
}
OrtModelBuilder::~OrtModelBuilder() = default;

void OrtModelBuilder::AddInput(std::string_view name,
base::span<const int64_t> shape,
ONNXTensorElementDataType data_type) {
ScopedOrtShape input_shape;
CHECK_STATUS(GetOrtGraphApi()->CreateFixedShape(shape.data(), shape.size(),
input_shape.get_pptr()));

ScopedOrtValueInfo input_info;
CHECK_STATUS(GetOrtGraphApi()->CreateTensorValueInfo(
name.data(), data_type, input_shape.get_pptr(), input_info.get_pptr()));
CHECK_STATUS(
GetOrtGraphApi()->AddInput(graph_.get_ptr(), input_info.get_pptr()));
}

void OrtModelBuilder::AddOutput(std::string_view name,
base::span<const int64_t> shape,
ONNXTensorElementDataType data_type) {
ScopedOrtShape output_shape;
CHECK_STATUS(GetOrtGraphApi()->CreateFixedShape(shape.data(), shape.size(),
output_shape.get_pptr()));

ScopedOrtValueInfo output_info;
CHECK_STATUS(GetOrtGraphApi()->CreateTensorValueInfo(
name.data(), data_type, output_shape.get_pptr(), output_info.get_pptr()));
CHECK_STATUS(
GetOrtGraphApi()->AddOutput(graph_.get_ptr(), output_info.get_pptr()));
}

void OrtModelBuilder::AddInitializerAsRawData(
std::string_view name,
base::span<const int64_t> shape,
base::span<const uint8_t> data,
ONNXTensorElementDataType data_type) {
ScopedOrtValue initializer;
CHECK_STATUS(GetOrtApi()->CreateTensorAsOrtValue(
allocator_->allocator(), shape.data(), shape.size(), data_type,
initializer.get_pptr()));

void* ort_tensor_raw_data = nullptr;
CHECK_STATUS(GetOrtApi()->GetTensorMutableData(initializer.get_ptr(),
&ort_tensor_raw_data));
CHECK(ort_tensor_raw_data);
UNSAFE_BUFFERS(
base::span(static_cast<uint8_t*>(ort_tensor_raw_data), data.size()))
.copy_from(data);
CHECK_STATUS(GetOrtGraphApi()->AddInitializer(graph_.get_ptr(), name.data(),
initializer.get_pptr()));
}

void OrtModelBuilder::AddInitializerAsExternalData(
std::string_view name,
base::span<const int64_t> shape,
base::span<const uint8_t> data,
ONNXTensorElementDataType data_type) {
auto weight = base::HeapArray<uint8_t>::CopiedFrom(data);
model_info_->external_data.push_back(std::move(weight));

ScopedOrtValue initializer;
CHECK_STATUS(GetOrtApi()->CreateTensorWithDataAsOrtValue(
allocator_->memory_info(), model_info_->external_data.back().data(),
model_info_->external_data.back().size(), shape.data(), shape.size(),
data_type, initializer.get_pptr()));
CHECK_STATUS(GetOrtGraphApi()->AddInitializer(graph_.get_ptr(), name.data(),
initializer.get_pptr()));
}

void OrtModelBuilder::CreateAttribute(ScopedOrtOpAttr& attribute,
std::string_view name,
OrtOpAttrData data) {
if (absl::holds_alternative<int64_t>(data)) {
CHECK_STATUS(GetOrtApi()->CreateOpAttr(
name.data(), &absl::get<int64_t>(data), /*len=*/1,
OrtOpAttrType::ORT_OP_ATTR_INT, attribute.get_pptr()));
} else if (absl::holds_alternative<float>(data)) {
CHECK_STATUS(GetOrtApi()->CreateOpAttr(
name.data(), &absl::get<float>(data), /*len=*/1,
OrtOpAttrType::ORT_OP_ATTR_FLOAT, attribute.get_pptr()));
} else if (absl::holds_alternative<std::string_view>(data)) {
std::string_view string_data = absl::get<std::string_view>(data);
CHECK_STATUS(GetOrtApi()->CreateOpAttr(
name.data(), string_data.data(), string_data.size(),
OrtOpAttrType::ORT_OP_ATTR_STRING, attribute.get_pptr()));
} else if (absl::holds_alternative<base::span<const int64_t>>(data)) {
base::span<const int64_t> ints_data =
absl::get<base::span<const int64_t>>(data);
CHECK_STATUS(GetOrtApi()->CreateOpAttr(
name.data(), ints_data.data(), ints_data.size(),
OrtOpAttrType::ORT_OP_ATTR_INTS, attribute.get_pptr()));
} else if (absl::holds_alternative<base::span<const float>>(data)) {
base::span<const float> floats_data =
absl::get<base::span<const float>>(data);
CHECK_STATUS(GetOrtApi()->CreateOpAttr(
name.data(), floats_data.data(), floats_data.size(),
OrtOpAttrType::ORT_OP_ATTR_FLOATS, attribute.get_pptr()));
} else if (absl::holds_alternative<base::span<const char*>>(data)) {
base::span<const char*> strings_data =
absl::get<base::span<const char*>>(data);
CHECK_STATUS(GetOrtApi()->CreateOpAttr(
name.data(), strings_data.data(), strings_data.size(),
OrtOpAttrType::ORT_OP_ATTR_STRINGS, attribute.get_pptr()));
}
}

void OrtModelBuilder::AddNode(std::string_view op_type,
std::string_view node_name,
base::span<const char*> input_names,
base::span<const char*> output_names,
base::span<OrtOpAttr**> attributes) {
ScopedOrtNode node;
CHECK_STATUS(GetOrtGraphApi()->CreateNode(
op_type.data(), kOrtDomainName, node_name.data(), input_names.data(),
input_names.size(), output_names.data(), output_names.size(),
attributes.data(), attributes.size(), node.get_pptr()));
CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr()));
}

std::unique_ptr<OrtModelBuilder::ModelInfo>
OrtModelBuilder::BuildAndTakeModelInfo() {
std::vector<const char*> domain_names = {kOrtDomainName};
std::vector<int32_t> opset_versions = {kOrtOpsetVersion};

CHECK_STATUS(GetOrtGraphApi()->CreateModel(
domain_names.data(), opset_versions.data(), domain_names.size(),
model_info_->model.get_pptr()));

CHECK_STATUS(GetOrtGraphApi()->AddGraph(model_info_->model.get_ptr(),
graph_.get_pptr()));

return std::move(model_info_);
}

} // namespace ort

} // namespace webnn
93 changes: 93 additions & 0 deletions services/webnn/ort/ort_model_builder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef SERVICES_WEBNN_ORT_ORT_MODEL_BUILDER_H_
#define SERVICES_WEBNN_ORT_ORT_MODEL_BUILDER_H_

#include <memory>
#include <string>

#include "base/containers/heap_array.h"
#include "base/containers/span.h"
#include "base/memory/stack_allocated.h"
#include "services/webnn/ort/allocator_ort.h"
#include "services/webnn/ort/scoped_ort_types.h"
#include "third_party/abseil-cpp/absl/types/variant.h"
#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"

namespace webnn {

namespace ort {

class OrtModelBuilder final {
STACK_ALLOCATED();

public:
struct ModelInfo {
explicit ModelInfo();
ModelInfo(const ModelInfo&) = delete;
ModelInfo& operator=(const ModelInfo&) = delete;
~ModelInfo();

ScopedOrtModel model;

// TODO: Consider reusing constant operands instead of copying them to
// `external_data`.
//
// Store the external data which should be alive for inference session.
std::vector<base::HeapArray<uint8_t>> external_data;
};

explicit OrtModelBuilder(scoped_refptr<AllocatorOrt> allocator);
~OrtModelBuilder();
OrtModelBuilder(const OrtModelBuilder&) = delete;
OrtModelBuilder& operator=(const OrtModelBuilder&) = delete;

void AddInput(std::string_view name,
base::span<const int64_t> shape,
ONNXTensorElementDataType data_type);

void AddOutput(std::string_view name,
base::span<const int64_t> shape,
ONNXTensorElementDataType data_type);

void AddInitializerAsRawData(std::string_view name,
base::span<const int64_t> shape,
base::span<const uint8_t> data,
ONNXTensorElementDataType data_type);

void AddInitializerAsExternalData(std::string_view name,
base::span<const int64_t> shape,
base::span<const uint8_t> data,
ONNXTensorElementDataType data_type);
using OrtOpAttrData = absl::variant<int64_t,
float,
std::string_view,
base::span<const int64_t>,
base::span<const float>,
base::span<const char*>>;
void CreateAttribute(ScopedOrtOpAttr& attribute,
std::string_view name,
OrtOpAttrData data);

void AddNode(std::string_view op_type,
std::string_view node_name,
base::span<const char*> input_names,
base::span<const char*> output_names,
base::span<OrtOpAttr**> attributes = {});

std::unique_ptr<ModelInfo> BuildAndTakeModelInfo();

private:
scoped_refptr<AllocatorOrt> allocator_;

ScopedOrtGraph graph_;

std::unique_ptr<ModelInfo> model_info_;
};

} // namespace ort
} // namespace webnn

#endif // SERVICES_WEBNN_ORT_ORT_MODEL_BUILDER_H_
Loading

0 comments on commit 1eb9dfc

Please sign in to comment.