Skip to content

Commit

Permalink
Use ORT Model Builder API
Browse files Browse the repository at this point in the history
* Pass ORT_API_VERSION to `OrtApiBase::GetApi()`

Also removes the inclusion of onnx.pb.h header.

* Add third_party/onnxruntime_headers

Import https://github.com/microsoft/onnxruntime/tree/main/include

Commit is based on microsoft/onnxruntime#23223

* Use ORT Model Builder API

* Refactor scoped ORT type ptr

1. Rename to ScopedOrtTypePtr
2. Use macros
3. Introduce `operator T*()`
4. Introduce `Release()` method
5. Rename `get_ptr()` to `Get()`
6. Rename `get_pptr()` to `GetAddressOf()`

* Remove ONNX Runtime headers from third_party/microsoft_dxheaders
  • Loading branch information
huningxin authored Jan 2, 2025
1 parent 1eb9dfc commit 2bc8f32
Show file tree
Hide file tree
Showing 134 changed files with 25,097 additions and 385 deletions.
1 change: 1 addition & 0 deletions services/webnn/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ component("webnn_service") {
deps += [
"//third_party/fp16",
"//third_party/microsoft_dxheaders:dxguids",
"//third_party/onnxruntime_headers",
"//ui/gl",
"//ui/gl/init",
]
Expand Down
2 changes: 1 addition & 1 deletion services/webnn/ort/allocator_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

#include "base/component_export.h"
#include "base/memory/ref_counted.h"
#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h"

namespace webnn::ort {

Expand Down
2 changes: 1 addition & 1 deletion services/webnn/ort/context_impl_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "services/webnn/webnn_constant_operand.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_graph_impl.h"
#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h"

namespace webnn::ort {

Expand Down
80 changes: 41 additions & 39 deletions services/webnn/ort/graph_builder_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ void GraphBuilderOrt::AddInitializer(uint64_t constant_id) {
operand.ByteSpan(),
operand_info.onnx_data_type);

CHECK(result_->id_to_operand_info.try_emplace(constant_id, std::move(operand_info))
CHECK(result_->id_to_operand_info
.try_emplace(constant_id, std::move(operand_info))
.second);
}

Expand Down Expand Up @@ -411,12 +412,13 @@ void GraphBuilderOrt::AddCastOperation(const mojom::ElementWiseUnary& cast) {

int64_t to_data_type = static_cast<int64_t>(
OperandTypeToONNXTensorElementDataType(output_data_type));
ScopedOrtOpAttr attr_to;
ScopedOrtOpAttrPtr attr_to;
model_builder_.CreateAttribute(attr_to, /*name=*/"to", to_data_type);

std::array<OrtOpAttr**, 1> attributes = {attr_to.get_pptr()};
std::array<OrtOpAttr*, 1> attributes = {attr_to};

model_builder_.AddNode(kOpTypeCast, node_name, input_names, output_names, attributes);
model_builder_.AddNode(kOpTypeCast, node_name, input_names, output_names,
attributes);
}

void GraphBuilderOrt::AddClampOperation(const mojom::Clamp& clamp) {
Expand Down Expand Up @@ -487,32 +489,33 @@ void GraphBuilderOrt::AddConv2dOperation(const mojom::Conv2d& conv2d) {
std::array<int64_t, 2> dilations = {
base::checked_cast<int64_t>(conv2d.dilations->height),
base::checked_cast<int64_t>(conv2d.dilations->width)};
ScopedOrtOpAttr attr_dilations;
model_builder_.CreateAttribute(attr_dilations, /*name=*/"dilations", dilations);
ScopedOrtOpAttrPtr attr_dilations;
model_builder_.CreateAttribute(attr_dilations, /*name=*/"dilations",
dilations);

int64_t group = base::checked_cast<int64_t>(conv2d.groups);
ScopedOrtOpAttr attr_group;
ScopedOrtOpAttrPtr attr_group;
model_builder_.CreateAttribute(attr_group, /*name=*/"group", group);

std::array<int64_t, 4> pads = {
base::checked_cast<int64_t>(conv2d.padding->beginning->height),
base::checked_cast<int64_t>(conv2d.padding->beginning->width),
base::checked_cast<int64_t>(conv2d.padding->ending->height),
base::checked_cast<int64_t>(conv2d.padding->ending->width)};
ScopedOrtOpAttr attr_pads;
ScopedOrtOpAttrPtr attr_pads;
model_builder_.CreateAttribute(attr_pads, /*name=*/"pads", pads);

std::array<int64_t, 2> strides = {
base::checked_cast<int64_t>(conv2d.strides->height),
base::checked_cast<int64_t>(conv2d.strides->width)};
ScopedOrtOpAttr attr_strides;
ScopedOrtOpAttrPtr attr_strides;
model_builder_.CreateAttribute(attr_strides, /*name=*/"strides", strides);

std::array<OrtOpAttr**, 4> attributes = {
attr_dilations.get_pptr(),
attr_group.get_pptr(),
attr_pads.get_pptr(),
attr_strides.get_pptr(),
std::array<OrtOpAttr*, 4> attributes = {
attr_dilations,
attr_group,
attr_pads,
attr_strides,
};
model_builder_.AddNode(kOpTypeConv2d, node_name, input_names, output_names,
attributes);
Expand All @@ -535,22 +538,21 @@ void GraphBuilderOrt::AddGemmOperation(const mojom::Gemm& gemm) {
}
std::array<const char*, 1> output_names = {output_name.c_str()};

ScopedOrtOpAttr attr_alpha;
ScopedOrtOpAttrPtr attr_alpha;
model_builder_.CreateAttribute(attr_alpha, /*name=*/"alpha", gemm.alpha);
ScopedOrtOpAttr attr_beta;
ScopedOrtOpAttrPtr attr_beta;
model_builder_.CreateAttribute(attr_beta, /*name=*/"beta", gemm.beta);

int64_t trans_a = static_cast<int64_t>(gemm.a_transpose);
ScopedOrtOpAttr attr_transA;
ScopedOrtOpAttrPtr attr_transA;
model_builder_.CreateAttribute(attr_transA, /*name=*/"transA", trans_a);

int64_t trans_b = static_cast<int64_t>(gemm.b_transpose);
ScopedOrtOpAttr attr_transB;
ScopedOrtOpAttrPtr attr_transB;
model_builder_.CreateAttribute(attr_transB, /*name=*/"transB", trans_b);

std::array<OrtOpAttr**, 4> attributes = {
attr_alpha.get_pptr(), attr_beta.get_pptr(), attr_transA.get_pptr(),
attr_transB.get_pptr()};
std::array<OrtOpAttr*, 4> attributes = {attr_alpha, attr_beta, attr_transA,
attr_transB};

model_builder_.AddNode(kOpTypeGemm, node_name, input_names, output_names,
attributes);
Expand All @@ -576,20 +578,20 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) {
std::array<int64_t, 2> dilations = {
base::checked_cast<int64_t>(pool2d.dilations->height),
base::checked_cast<int64_t>(pool2d.dilations->width)};
ScopedOrtOpAttr attr_dilations;
ScopedOrtOpAttrPtr attr_dilations;
model_builder_.CreateAttribute(attr_dilations, /*name=*/"dilations",
dilations);

std::array<int64_t, 2> strides = {
base::checked_cast<int64_t>(pool2d.strides->height),
base::checked_cast<int64_t>(pool2d.strides->width)};
ScopedOrtOpAttr attr_strides;
ScopedOrtOpAttrPtr attr_strides;
model_builder_.CreateAttribute(attr_strides, /*name=*/"strides", strides);

std::array<int64_t, 2> window_dimensions = {
base::checked_cast<int64_t>(pool2d.window_dimensions->height),
base::checked_cast<int64_t>(pool2d.window_dimensions->width)};
ScopedOrtOpAttr attr_kernel_shape;
ScopedOrtOpAttrPtr attr_kernel_shape;
model_builder_.CreateAttribute(attr_kernel_shape,
/*name=*/"kernel_shape", window_dimensions);

Expand All @@ -600,7 +602,7 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) {
base::checked_cast<int64_t>(pool2d.padding->beginning->width),
base::checked_cast<int64_t>(pool2d.padding->ending->height),
base::checked_cast<int64_t>(pool2d.padding->ending->width)};
ScopedOrtOpAttr attr_pads;
ScopedOrtOpAttrPtr attr_pads;
model_builder_.CreateAttribute(attr_pads, /*name=*/"pads", pads);

// Calculate the ceil_mode.
Expand All @@ -626,12 +628,12 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) {
CHECK(float_output_height.has_value());

int64_t ceil_mode = float_output_height.value() < output_height ? 1 : 0;
ScopedOrtOpAttr attr_ceil_mode;
ScopedOrtOpAttrPtr attr_ceil_mode;
model_builder_.CreateAttribute(attr_ceil_mode, /*name=*/"ceil_mode",
ceil_mode);

// P value of the Lp norm used to pool over the input data.
std::optional<ScopedOrtOpAttr> attr_p;
std::optional<ScopedOrtOpAttrPtr> attr_p;
std::optional<int64_t> p;
std::string op_type;
switch (pool2d.kind) {
Expand All @@ -652,14 +654,13 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) {
}
}

std::vector<OrtOpAttr**> attributes = {
attr_dilations.get_pptr(), attr_strides.get_pptr(),
attr_kernel_shape.get_pptr(), attr_pads.get_pptr(),
attr_ceil_mode.get_pptr()};
std::vector<OrtOpAttr*> attributes = {attr_dilations, attr_strides,
attr_kernel_shape, attr_pads,
attr_ceil_mode};
if (op_type == kOpTypeLpPool2d) {
CHECK(attr_p.has_value());
CHECK(p.has_value());
attributes.push_back(attr_p.value().get_pptr());
attributes.push_back(attr_p.value());
}

const std::string node_name = GetNodeName(pool2d.label);
Expand Down Expand Up @@ -710,11 +711,12 @@ void GraphBuilderOrt::AddSoftmaxOperation(const mojom::Softmax& softmax) {
std::array<const char*, 1> output_names = {output_name.c_str()};

int64_t axis = static_cast<int64_t>(softmax.axis);
ScopedOrtOpAttr attr_axis;
ScopedOrtOpAttrPtr attr_axis;
model_builder_.CreateAttribute(attr_axis, /*name=*/"axis", axis);

std::array<OrtOpAttr**, 1> attributes = {attr_axis.get_pptr()};
model_builder_.AddNode(kOpTypeSoftmax, node_name, input_names, output_names, attributes);
std::array<OrtOpAttr*, 1> attributes = {attr_axis};
model_builder_.AddNode(kOpTypeSoftmax, node_name, input_names, output_names,
attributes);
}

void GraphBuilderOrt::AddTransposeOperation(const mojom::Transpose& transpose) {
Expand All @@ -727,10 +729,10 @@ void GraphBuilderOrt::AddTransposeOperation(const mojom::Transpose& transpose) {

std::vector<int64_t> permutation(transpose.permutation.begin(),
transpose.permutation.end());
ScopedOrtOpAttr attr_perm;
ScopedOrtOpAttrPtr attr_perm;
model_builder_.CreateAttribute(attr_perm, /*name=*/"perm", permutation);

std::array<OrtOpAttr**, 1> attributes = {attr_perm.get_pptr()};
std::array<OrtOpAttr*, 1> attributes = {attr_perm};
model_builder_.AddNode(kOpTypeTranspose, node_name, input_names, output_names,
attributes);
}
Expand All @@ -749,12 +751,12 @@ void GraphBuilderOrt::AddWhereOperation(const mojom::Where& where) {
std::array<const char*, 1> cast_output_names = {
cast_node_output_name.c_str()};

ScopedOrtOpAttr attr_to;
ScopedOrtOpAttrPtr attr_to;
int64_t to_data_type =
static_cast<int64_t>(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
model_builder_.CreateAttribute(attr_to, /*name=*/"to", to_data_type);

std::array<OrtOpAttr**, 1> cast_attributes = {attr_to.get_pptr()};
std::array<OrtOpAttr*, 1> cast_attributes = {attr_to};
model_builder_.AddNode(kOpTypeCast, cast_node_name, cast_input_names,
cast_output_names, cast_attributes);
next_operand_id_++;
Expand Down
4 changes: 2 additions & 2 deletions services/webnn/ort/graph_builder_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
#include "base/memory/stack_allocated.h"
#include "base/types/expected.h"
#include "services/webnn/ort/allocator_ort.h"
#include "services/webnn/ort/ort_model_builder.h"
#include "services/webnn/ort/scoped_ort_types.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_error.mojom-forward.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/ort/ort_model_builder.h"

namespace webnn {

Expand Down Expand Up @@ -62,7 +62,7 @@ class GraphBuilderOrt {
const OperandInfo& GetOperandInfo(uint64_t operand_id) const;

std::map<uint64_t, OperandInfo> id_to_operand_info;

std::unique_ptr<OrtModelBuilder::ModelInfo> model_info;
};

Expand Down
5 changes: 3 additions & 2 deletions services/webnn/ort/graph_impl_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "services/webnn/resource_task.h"
#include "services/webnn/webnn_constant_operand.h"
#include "services/webnn/webnn_graph_impl.h"
#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/providers/dml/dml_provider_factory.h"

namespace webnn::ort {

Expand Down Expand Up @@ -132,8 +133,8 @@ GraphImplOrt::CreateAndBuildOnBackgroundThread(

OrtSession* session;
const OrtEnv* env = allocator->env();
OrtStatus* status = GetOrtGraphApi()->CreateSessionFromModel(
env, result->model_info->model.get_ptr(), session_options, &session);
OrtStatus* status = GetOrtModelBuilderApi()->CreateSessionFromModel(
env, result->model_info->model, session_options, &session);
ort_api->ReleaseSessionOptions(session_options);

if (status != NULL) {
Expand Down
2 changes: 1 addition & 1 deletion services/webnn/ort/graph_impl_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "services/webnn/queueable_resource_state.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_graph_impl.h"
#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h"

namespace webnn {

Expand Down
Loading

0 comments on commit 2bc8f32

Please sign in to comment.