Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ORT Model Builder API #24

Merged
merged 5 commits on Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions services/webnn/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ component("webnn_service") {
deps += [
"//third_party/fp16",
"//third_party/microsoft_dxheaders:dxguids",
"//third_party/onnxruntime_headers",
"//ui/gl",
"//ui/gl/init",
]
Expand Down
2 changes: 1 addition & 1 deletion services/webnn/ort/allocator_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

#include "base/component_export.h"
#include "base/memory/ref_counted.h"
#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h"

namespace webnn::ort {

Expand Down
2 changes: 1 addition & 1 deletion services/webnn/ort/context_impl_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "services/webnn/webnn_constant_operand.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_graph_impl.h"
#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h"

namespace webnn::ort {

Expand Down
80 changes: 41 additions & 39 deletions services/webnn/ort/graph_builder_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ void GraphBuilderOrt::AddInitializer(uint64_t constant_id) {
operand.ByteSpan(),
operand_info.onnx_data_type);

CHECK(result_->id_to_operand_info.try_emplace(constant_id, std::move(operand_info))
CHECK(result_->id_to_operand_info
.try_emplace(constant_id, std::move(operand_info))
.second);
}

Expand Down Expand Up @@ -411,12 +412,13 @@ void GraphBuilderOrt::AddCastOperation(const mojom::ElementWiseUnary& cast) {

int64_t to_data_type = static_cast<int64_t>(
OperandTypeToONNXTensorElementDataType(output_data_type));
ScopedOrtOpAttr attr_to;
ScopedOrtOpAttrPtr attr_to;
model_builder_.CreateAttribute(attr_to, /*name=*/"to", to_data_type);

std::array<OrtOpAttr**, 1> attributes = {attr_to.get_pptr()};
std::array<OrtOpAttr*, 1> attributes = {attr_to};

model_builder_.AddNode(kOpTypeCast, node_name, input_names, output_names, attributes);
model_builder_.AddNode(kOpTypeCast, node_name, input_names, output_names,
attributes);
}

void GraphBuilderOrt::AddClampOperation(const mojom::Clamp& clamp) {
Expand Down Expand Up @@ -487,32 +489,33 @@ void GraphBuilderOrt::AddConv2dOperation(const mojom::Conv2d& conv2d) {
std::array<int64_t, 2> dilations = {
base::checked_cast<int64_t>(conv2d.dilations->height),
base::checked_cast<int64_t>(conv2d.dilations->width)};
ScopedOrtOpAttr attr_dilations;
model_builder_.CreateAttribute(attr_dilations, /*name=*/"dilations", dilations);
ScopedOrtOpAttrPtr attr_dilations;
model_builder_.CreateAttribute(attr_dilations, /*name=*/"dilations",
dilations);

int64_t group = base::checked_cast<int64_t>(conv2d.groups);
ScopedOrtOpAttr attr_group;
ScopedOrtOpAttrPtr attr_group;
model_builder_.CreateAttribute(attr_group, /*name=*/"group", group);

std::array<int64_t, 4> pads = {
base::checked_cast<int64_t>(conv2d.padding->beginning->height),
base::checked_cast<int64_t>(conv2d.padding->beginning->width),
base::checked_cast<int64_t>(conv2d.padding->ending->height),
base::checked_cast<int64_t>(conv2d.padding->ending->width)};
ScopedOrtOpAttr attr_pads;
ScopedOrtOpAttrPtr attr_pads;
model_builder_.CreateAttribute(attr_pads, /*name=*/"pads", pads);

std::array<int64_t, 2> strides = {
base::checked_cast<int64_t>(conv2d.strides->height),
base::checked_cast<int64_t>(conv2d.strides->width)};
ScopedOrtOpAttr attr_strides;
ScopedOrtOpAttrPtr attr_strides;
model_builder_.CreateAttribute(attr_strides, /*name=*/"strides", strides);

std::array<OrtOpAttr**, 4> attributes = {
attr_dilations.get_pptr(),
attr_group.get_pptr(),
attr_pads.get_pptr(),
attr_strides.get_pptr(),
std::array<OrtOpAttr*, 4> attributes = {
attr_dilations,
attr_group,
attr_pads,
attr_strides,
};
model_builder_.AddNode(kOpTypeConv2d, node_name, input_names, output_names,
attributes);
Expand All @@ -535,22 +538,21 @@ void GraphBuilderOrt::AddGemmOperation(const mojom::Gemm& gemm) {
}
std::array<const char*, 1> output_names = {output_name.c_str()};

ScopedOrtOpAttr attr_alpha;
ScopedOrtOpAttrPtr attr_alpha;
model_builder_.CreateAttribute(attr_alpha, /*name=*/"alpha", gemm.alpha);
ScopedOrtOpAttr attr_beta;
ScopedOrtOpAttrPtr attr_beta;
model_builder_.CreateAttribute(attr_beta, /*name=*/"beta", gemm.beta);

int64_t trans_a = static_cast<int64_t>(gemm.a_transpose);
ScopedOrtOpAttr attr_transA;
ScopedOrtOpAttrPtr attr_transA;
model_builder_.CreateAttribute(attr_transA, /*name=*/"transA", trans_a);

int64_t trans_b = static_cast<int64_t>(gemm.b_transpose);
ScopedOrtOpAttr attr_transB;
ScopedOrtOpAttrPtr attr_transB;
model_builder_.CreateAttribute(attr_transB, /*name=*/"transB", trans_b);

std::array<OrtOpAttr**, 4> attributes = {
attr_alpha.get_pptr(), attr_beta.get_pptr(), attr_transA.get_pptr(),
attr_transB.get_pptr()};
std::array<OrtOpAttr*, 4> attributes = {attr_alpha, attr_beta, attr_transA,
attr_transB};

model_builder_.AddNode(kOpTypeGemm, node_name, input_names, output_names,
attributes);
Expand All @@ -576,20 +578,20 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) {
std::array<int64_t, 2> dilations = {
base::checked_cast<int64_t>(pool2d.dilations->height),
base::checked_cast<int64_t>(pool2d.dilations->width)};
ScopedOrtOpAttr attr_dilations;
ScopedOrtOpAttrPtr attr_dilations;
model_builder_.CreateAttribute(attr_dilations, /*name=*/"dilations",
dilations);

std::array<int64_t, 2> strides = {
base::checked_cast<int64_t>(pool2d.strides->height),
base::checked_cast<int64_t>(pool2d.strides->width)};
ScopedOrtOpAttr attr_strides;
ScopedOrtOpAttrPtr attr_strides;
model_builder_.CreateAttribute(attr_strides, /*name=*/"strides", strides);

std::array<int64_t, 2> window_dimensions = {
base::checked_cast<int64_t>(pool2d.window_dimensions->height),
base::checked_cast<int64_t>(pool2d.window_dimensions->width)};
ScopedOrtOpAttr attr_kernel_shape;
ScopedOrtOpAttrPtr attr_kernel_shape;
model_builder_.CreateAttribute(attr_kernel_shape,
/*name=*/"kernel_shape", window_dimensions);

Expand All @@ -600,7 +602,7 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) {
base::checked_cast<int64_t>(pool2d.padding->beginning->width),
base::checked_cast<int64_t>(pool2d.padding->ending->height),
base::checked_cast<int64_t>(pool2d.padding->ending->width)};
ScopedOrtOpAttr attr_pads;
ScopedOrtOpAttrPtr attr_pads;
model_builder_.CreateAttribute(attr_pads, /*name=*/"pads", pads);

// Calculate the ceil_mode.
Expand All @@ -626,12 +628,12 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) {
CHECK(float_output_height.has_value());

int64_t ceil_mode = float_output_height.value() < output_height ? 1 : 0;
ScopedOrtOpAttr attr_ceil_mode;
ScopedOrtOpAttrPtr attr_ceil_mode;
model_builder_.CreateAttribute(attr_ceil_mode, /*name=*/"ceil_mode",
ceil_mode);

// P value of the Lp norm used to pool over the input data.
std::optional<ScopedOrtOpAttr> attr_p;
std::optional<ScopedOrtOpAttrPtr> attr_p;
std::optional<int64_t> p;
std::string op_type;
switch (pool2d.kind) {
Expand All @@ -652,14 +654,13 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) {
}
}

std::vector<OrtOpAttr**> attributes = {
attr_dilations.get_pptr(), attr_strides.get_pptr(),
attr_kernel_shape.get_pptr(), attr_pads.get_pptr(),
attr_ceil_mode.get_pptr()};
std::vector<OrtOpAttr*> attributes = {attr_dilations, attr_strides,
attr_kernel_shape, attr_pads,
attr_ceil_mode};
if (op_type == kOpTypeLpPool2d) {
CHECK(attr_p.has_value());
CHECK(p.has_value());
attributes.push_back(attr_p.value().get_pptr());
attributes.push_back(attr_p.value());
}

const std::string node_name = GetNodeName(pool2d.label);
Expand Down Expand Up @@ -710,11 +711,12 @@ void GraphBuilderOrt::AddSoftmaxOperation(const mojom::Softmax& softmax) {
std::array<const char*, 1> output_names = {output_name.c_str()};

int64_t axis = static_cast<int64_t>(softmax.axis);
ScopedOrtOpAttr attr_axis;
ScopedOrtOpAttrPtr attr_axis;
model_builder_.CreateAttribute(attr_axis, /*name=*/"axis", axis);

std::array<OrtOpAttr**, 1> attributes = {attr_axis.get_pptr()};
model_builder_.AddNode(kOpTypeSoftmax, node_name, input_names, output_names, attributes);
std::array<OrtOpAttr*, 1> attributes = {attr_axis};
model_builder_.AddNode(kOpTypeSoftmax, node_name, input_names, output_names,
attributes);
}

void GraphBuilderOrt::AddTransposeOperation(const mojom::Transpose& transpose) {
Expand All @@ -727,10 +729,10 @@ void GraphBuilderOrt::AddTransposeOperation(const mojom::Transpose& transpose) {

std::vector<int64_t> permutation(transpose.permutation.begin(),
transpose.permutation.end());
ScopedOrtOpAttr attr_perm;
ScopedOrtOpAttrPtr attr_perm;
model_builder_.CreateAttribute(attr_perm, /*name=*/"perm", permutation);

std::array<OrtOpAttr**, 1> attributes = {attr_perm.get_pptr()};
std::array<OrtOpAttr*, 1> attributes = {attr_perm};
model_builder_.AddNode(kOpTypeTranspose, node_name, input_names, output_names,
attributes);
}
Expand All @@ -749,12 +751,12 @@ void GraphBuilderOrt::AddWhereOperation(const mojom::Where& where) {
std::array<const char*, 1> cast_output_names = {
cast_node_output_name.c_str()};

ScopedOrtOpAttr attr_to;
ScopedOrtOpAttrPtr attr_to;
int64_t to_data_type =
static_cast<int64_t>(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
model_builder_.CreateAttribute(attr_to, /*name=*/"to", to_data_type);

std::array<OrtOpAttr**, 1> cast_attributes = {attr_to.get_pptr()};
std::array<OrtOpAttr*, 1> cast_attributes = {attr_to};
model_builder_.AddNode(kOpTypeCast, cast_node_name, cast_input_names,
cast_output_names, cast_attributes);
next_operand_id_++;
Expand Down
4 changes: 2 additions & 2 deletions services/webnn/ort/graph_builder_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
#include "base/memory/stack_allocated.h"
#include "base/types/expected.h"
#include "services/webnn/ort/allocator_ort.h"
#include "services/webnn/ort/ort_model_builder.h"
#include "services/webnn/ort/scoped_ort_types.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_error.mojom-forward.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/ort/ort_model_builder.h"

namespace webnn {

Expand Down Expand Up @@ -62,7 +62,7 @@ class GraphBuilderOrt {
const OperandInfo& GetOperandInfo(uint64_t operand_id) const;

std::map<uint64_t, OperandInfo> id_to_operand_info;

std::unique_ptr<OrtModelBuilder::ModelInfo> model_info;
};

Expand Down
5 changes: 3 additions & 2 deletions services/webnn/ort/graph_impl_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "services/webnn/resource_task.h"
#include "services/webnn/webnn_constant_operand.h"
#include "services/webnn/webnn_graph_impl.h"
#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/providers/dml/dml_provider_factory.h"

namespace webnn::ort {

Expand Down Expand Up @@ -132,8 +133,8 @@ GraphImplOrt::CreateAndBuildOnBackgroundThread(

OrtSession* session;
const OrtEnv* env = allocator->env();
OrtStatus* status = GetOrtGraphApi()->CreateSessionFromModel(
env, result->model_info->model.get_ptr(), session_options, &session);
OrtStatus* status = GetOrtModelBuilderApi()->CreateSessionFromModel(
env, result->model_info->model, session_options, &session);
ort_api->ReleaseSessionOptions(session_options);

if (status != NULL) {
Expand Down
2 changes: 1 addition & 1 deletion services/webnn/ort/graph_impl_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "services/webnn/queueable_resource_state.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_graph_impl.h"
#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h"
#include "third_party/onnxruntime_headers/src/include/onnxruntime/core/session/onnxruntime_c_api.h"

namespace webnn {

Expand Down
Loading