diff --git a/services/webnn/BUILD.gn b/services/webnn/BUILD.gn index 8adec9255f9ad0..b0f27b4cd33b00 100644 --- a/services/webnn/BUILD.gn +++ b/services/webnn/BUILD.gn @@ -81,6 +81,8 @@ component("webnn_service") { "ort/graph_builder_ort.h", "ort/graph_impl_ort.cc", "ort/graph_impl_ort.h", + "ort/ort_model_builder.cc", + "ort/ort_model_builder.h", "ort/platform_functions_ort.cc", "ort/platform_functions_ort.h", "ort/scoped_ort_types.cc", diff --git a/services/webnn/ort/graph_builder_ort.cc b/services/webnn/ort/graph_builder_ort.cc index 8cb6d78862596c..c69c0930a92b30 100644 --- a/services/webnn/ort/graph_builder_ort.cc +++ b/services/webnn/ort/graph_builder_ort.cc @@ -20,9 +20,6 @@ namespace webnn { namespace { -constexpr char kOrtDomainName[] = ""; -constexpr int32_t kOrtOpsetVersion = 21; - // Element-wise binary constexpr char kOpTypeAdd[] = "Add"; constexpr char kOpTypeSub[] = "Sub"; @@ -123,22 +120,13 @@ GraphBuilderOrt::OperandInfo::OperandInfo(OperandInfo&&) = default; GraphBuilderOrt::Result::Result() = default; GraphBuilderOrt::Result::~Result() = default; -const ScopedOrtModel& GraphBuilderOrt::Result::GetModel() { - return model; -} - const GraphBuilderOrt::OperandInfo& GraphBuilderOrt::Result::GetOperandInfo( uint64_t operand_id) const { - auto it = operand_infos.find(operand_id); - CHECK(it != operand_infos.end()); + auto it = id_to_operand_info.find(operand_id); + CHECK(it != id_to_operand_info.end()); return it->second; } -const std::map& -GraphBuilderOrt::Result::id_to_operand_info_map() const { - return operand_infos; -} - // static base::expected, mojom::ErrorPtr> GraphBuilderOrt::CreateAndBuild( @@ -160,10 +148,10 @@ GraphBuilderOrt::GraphBuilderOrt( base::flat_map> constant_operands, scoped_refptr allocator) - : allocator_(allocator), - graph_info_(graph_info), + : graph_info_(graph_info), constant_operands_(std::move(constant_operands)), context_properties_(std::move(context_properties)), + model_builder_(OrtModelBuilder(std::move(allocator))), result_(std::make_unique()) { for (const auto& [id, _] : graph_info.id_to_operand_map) { next_operand_id_ = std::max(next_operand_id_, id + 1); @@ -209,23 +197,10 @@ uint64_t GraphBuilderOrt::NewInitializerAsRawData( std::string name = GetInsertedOperandName(next_operand_id_); OperandInfo operand_info{name, data_type, shape}; - ScopedOrtValue initializer; - CHECK_STATUS(GetOrtApi()->CreateTensorAsOrtValue( - allocator_->allocator(), operand_info.int64_shape.data(), - operand_info.int64_shape.size(), operand_info.onnx_data_type, - initializer.get_pptr())); - - void* ort_tensor_raw_data = nullptr; - CHECK_STATUS(GetOrtApi()->GetTensorMutableData(initializer.get_ptr(), - &ort_tensor_raw_data)); - CHECK(ort_tensor_raw_data); - UNSAFE_BUFFERS( - base::span(static_cast(ort_tensor_raw_data), data.size())) - .copy_from(data); - CHECK_STATUS(GetOrtGraphApi()->AddInitializer(graph_.get_ptr(), name.c_str(), - initializer.get_pptr())); - - CHECK(result_->operand_infos + model_builder_.AddInitializerAsRawData(name, operand_info.int64_shape, data, + operand_info.onnx_data_type); + + CHECK(result_->id_to_operand_info .try_emplace(next_operand_id_, std::move(operand_info)) .second); return next_operand_id_++; @@ -237,18 +212,13 @@ void GraphBuilderOrt::AddInput(uint64_t input_id) { OperandInfo operand_info{name, operand.descriptor.data_type(), operand.descriptor.shape()}; - ScopedOrtShape input_shape; - CHECK_STATUS(GetOrtGraphApi()->CreateFixedShape( - operand_info.int64_shape.data(), operand_info.int64_shape.size(), - input_shape.get_pptr())); - ScopedOrtValueInfo input_info; - CHECK_STATUS(GetOrtGraphApi()->CreateTensorValueInfo( - name.c_str(), operand_info.onnx_data_type, input_shape.get_pptr(), - input_info.get_pptr())); - CHECK_STATUS( - GetOrtGraphApi()->AddInput(graph_.get_ptr(), input_info.get_pptr())); - CHECK(result_->operand_infos.try_emplace(input_id, std::move(operand_info)) - .second); + + model_builder_.AddInput(name, operand_info.int64_shape, + operand_info.onnx_data_type); + + CHECK( + result_->id_to_operand_info.try_emplace(input_id, std::move(operand_info)) + .second); } void GraphBuilderOrt::AddOutput(uint64_t output_id) { @@ -258,20 +228,11 @@ void GraphBuilderOrt::AddOutput(uint64_t output_id) { OperandInfo operand_info{name, operand.descriptor.data_type(), operand.descriptor.shape()}; - ScopedOrtShape output_shape; - CHECK_STATUS(GetOrtGraphApi()->CreateFixedShape( - operand_info.int64_shape.data(), operand_info.int64_shape.size(), - output_shape.get_pptr())); + model_builder_.AddOutput(name, operand_info.int64_shape, + operand_info.onnx_data_type); - ScopedOrtValueInfo output_info; - CHECK_STATUS(GetOrtGraphApi()->CreateTensorValueInfo( - name.c_str(), operand_info.onnx_data_type, output_shape.get_pptr(), - output_info.get_pptr())); - - CHECK_STATUS( - GetOrtGraphApi()->AddOutput(graph_.get_ptr(), output_info.get_pptr())); - - CHECK(result_->operand_infos.try_emplace(output_id, std::move(operand_info)) + CHECK(result_->id_to_operand_info + .try_emplace(output_id, std::move(operand_info)) .second); } @@ -282,33 +243,12 @@ void GraphBuilderOrt::AddInitializer(uint64_t constant_id) { OperandInfo operand_info{name, operand.descriptor().data_type(), operand.descriptor().shape()}; - // auto weight = base::HeapArray::CopiedFrom(operand.ByteSpan()); - // result_->weights.push_back(std::move(weight)); - - // ScopedOrtValue initializer; - // CHECK_STATUS(GetOrtApi()->CreateTensorWithDataAsOrtValue( - // allocator_->memory_info(), result_->weights.back().data(), - // result_->weights.back().size(), operand_info.int64_shape.data(), - // operand_info.int64_shape.size(), operand_info.onnx_data_type, - // initializer.get_pptr())); - // CHECK_STATUS(GetOrtGraphApi()->AddInitializer(graph_.get_ptr(), - // name.c_str(), - // initializer.get_pptr())); - ScopedOrtValue initializer; - CHECK_STATUS(GetOrtApi()->CreateTensorAsOrtValue( - allocator_->allocator(), operand_info.int64_shape.data(), - operand_info.int64_shape.size(), operand_info.onnx_data_type, - initializer.get_pptr())); - - void* ort_tensor_raw_data = nullptr; - CHECK_STATUS(GetOrtApi()->GetTensorMutableData(initializer.get_ptr(), - &ort_tensor_raw_data)); - CHECK(ort_tensor_raw_data); - UNSAFE_BUFFERS(base::span(static_cast(ort_tensor_raw_data), - operand.ByteSpan().size())) - .copy_from(operand.ByteSpan()); - CHECK_STATUS(GetOrtGraphApi()->AddInitializer(graph_.get_ptr(), name.c_str(), - initializer.get_pptr())); + model_builder_.AddInitializerAsExternalData(name, operand_info.int64_shape, + operand.ByteSpan(), + operand_info.onnx_data_type); + + CHECK(result_->id_to_operand_info.try_emplace(constant_id, std::move(operand_info)) + .second); } template @@ -322,12 +262,7 @@ void GraphBuilderOrt::AddBinaryOperation(const T& operation, std::array input_names = {lhs_name.c_str(), rhs_name.c_str()}; std::array output_names = {output_name.c_str()}; - ScopedOrtNode node; - CHECK_STATUS(GetOrtGraphApi()->CreateNode( - op_type.data(), kOrtDomainName, node_name.c_str(), input_names.data(), - input_names.size(), output_names.data(), output_names.size(), - /*attributes=*/nullptr, /*attribs_len=*/0, node.get_pptr())); - CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr())); + model_builder_.AddNode(op_type, node_name, input_names, output_names); } void GraphBuilderOrt::AddElementWiseBinaryOperation( @@ -406,12 +341,7 @@ void GraphBuilderOrt::AddUnaryOperation(const T& operation, std::array input_names = {input_name.c_str()}; std::array output_names = {output_name.c_str()}; - ScopedOrtNode node; - CHECK_STATUS(GetOrtGraphApi()->CreateNode( - op_type.data(), kOrtDomainName, node_name.c_str(), input_names.data(), - input_names.size(), output_names.data(), output_names.size(), - /*attributes=*/nullptr, /*attribs_len=*/0, node.get_pptr())); - CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr())); + model_builder_.AddNode(op_type, node_name, input_names, output_names); } void GraphBuilderOrt::AddElementWiseUnaryOperation( @@ -479,20 +409,14 @@ void GraphBuilderOrt::AddCastOperation(const mojom::ElementWiseUnary& cast) { const OperandDataType output_data_type = GetOperand(cast.output_operand_id).descriptor.data_type(); - ScopedOrtOpAttr attr_to; int64_t to_data_type = static_cast( OperandTypeToONNXTensorElementDataType(output_data_type)); - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"to", &to_data_type, /*len=*/1, OrtOpAttrType::ORT_OP_ATTR_INT, - attr_to.get_pptr())); + ScopedOrtOpAttr attr_to; + model_builder_.CreateAttribute(attr_to, /*name=*/"to", to_data_type); - ScopedOrtNode node; std::array attributes = {attr_to.get_pptr()}; - CHECK_STATUS(GetOrtGraphApi()->CreateNode( - kOpTypeCast, kOrtDomainName, node_name.c_str(), input_names.data(), - input_names.size(), output_names.data(), output_names.size(), - attributes.data(), attributes.size(), node.get_pptr())); - CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr())); + + model_builder_.AddNode(kOpTypeCast, node_name, input_names, output_names, attributes); } void GraphBuilderOrt::AddClampOperation(const mojom::Clamp& clamp) { @@ -542,12 +466,7 @@ void GraphBuilderOrt::AddClampOperation(const mojom::Clamp& clamp) { min_name.c_str(), max_name.c_str()}; std::array output_names = {output_name.c_str()}; - ScopedOrtNode node; - CHECK_STATUS(GetOrtGraphApi()->CreateNode( - kOpTypeClamp, kOrtDomainName, node_name.c_str(), input_names.data(), - input_names.size(), output_names.data(), output_names.size(), - /*attributes=*/nullptr, /*attribs_len=*/0, node.get_pptr())); - CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr())); + model_builder_.AddNode(kOpTypeClamp, node_name, input_names, output_names); } void GraphBuilderOrt::AddConv2dOperation(const mojom::Conv2d& conv2d) { @@ -565,50 +484,38 @@ void GraphBuilderOrt::AddConv2dOperation(const mojom::Conv2d& conv2d) { } std::array output_names = {output_name.c_str()}; - ScopedOrtOpAttr attr_dilations; std::array dilations = { base::checked_cast(conv2d.dilations->height), base::checked_cast(conv2d.dilations->width)}; - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"dilations", dilations.data(), dilations.size(), - OrtOpAttrType::ORT_OP_ATTR_INTS, attr_dilations.get_pptr())); + ScopedOrtOpAttr attr_dilations; + model_builder_.CreateAttribute(attr_dilations, /*name=*/"dilations", dilations); - ScopedOrtOpAttr attr_group; int64_t group = base::checked_cast(conv2d.groups); - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"group", &group, /*len=*/1, OrtOpAttrType::ORT_OP_ATTR_INT, - attr_group.get_pptr())); + ScopedOrtOpAttr attr_group; + model_builder_.CreateAttribute(attr_group, /*name=*/"group", group); - ScopedOrtOpAttr attr_pads; std::array pads = { base::checked_cast(conv2d.padding->beginning->height), base::checked_cast(conv2d.padding->beginning->width), base::checked_cast(conv2d.padding->ending->height), base::checked_cast(conv2d.padding->ending->width)}; - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"pads", pads.data(), pads.size(), - OrtOpAttrType::ORT_OP_ATTR_INTS, attr_pads.get_pptr())); + ScopedOrtOpAttr attr_pads; + model_builder_.CreateAttribute(attr_pads, /*name=*/"pads", pads); - ScopedOrtOpAttr attr_strides; std::array strides = { base::checked_cast(conv2d.strides->height), base::checked_cast(conv2d.strides->width)}; - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"strides", strides.data(), strides.size(), - OrtOpAttrType::ORT_OP_ATTR_INTS, attr_strides.get_pptr())); + ScopedOrtOpAttr attr_strides; + model_builder_.CreateAttribute(attr_strides, /*name=*/"strides", strides); - ScopedOrtNode node; std::array attributes = { attr_dilations.get_pptr(), attr_group.get_pptr(), attr_pads.get_pptr(), attr_strides.get_pptr(), }; - CHECK_STATUS(GetOrtGraphApi()->CreateNode( - kOpTypeConv2d, kOrtDomainName, node_name.c_str(), input_names.data(), - input_names.size(), output_names.data(), output_names.size(), - attributes.data(), attributes.size(), node.get_pptr())); - CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr())); + model_builder_.AddNode(kOpTypeConv2d, node_name, input_names, output_names, + attributes); } void GraphBuilderOrt::AddGemmOperation(const mojom::Gemm& gemm) { @@ -629,37 +536,24 @@ void GraphBuilderOrt::AddGemmOperation(const mojom::Gemm& gemm) { std::array output_names = {output_name.c_str()}; ScopedOrtOpAttr attr_alpha; - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"alpha", &gemm.alpha, /*len=*/1, - OrtOpAttrType::ORT_OP_ATTR_FLOAT, attr_alpha.get_pptr())); - + model_builder_.CreateAttribute(attr_alpha, /*name=*/"alpha", gemm.alpha); ScopedOrtOpAttr attr_beta; - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"beta", &gemm.beta, /*len=*/1, OrtOpAttrType::ORT_OP_ATTR_FLOAT, - attr_beta.get_pptr())); + model_builder_.CreateAttribute(attr_beta, /*name=*/"beta", gemm.beta); - ScopedOrtOpAttr attr_transA; int64_t trans_a = static_cast(gemm.a_transpose); - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"transA", &trans_a, /*len=*/1, OrtOpAttrType::ORT_OP_ATTR_INT, - attr_transA.get_pptr())); + ScopedOrtOpAttr attr_transA; + model_builder_.CreateAttribute(attr_transA, /*name=*/"transA", trans_a); - ScopedOrtOpAttr attr_transB; int64_t trans_b = static_cast(gemm.b_transpose); - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"transB", &trans_b, /*len=*/1, OrtOpAttrType::ORT_OP_ATTR_INT, - attr_transB.get_pptr())); + ScopedOrtOpAttr attr_transB; + model_builder_.CreateAttribute(attr_transB, /*name=*/"transB", trans_b); std::array attributes = { attr_alpha.get_pptr(), attr_beta.get_pptr(), attr_transA.get_pptr(), attr_transB.get_pptr()}; - ScopedOrtNode node; - CHECK_STATUS(GetOrtGraphApi()->CreateNode( - kOpTypeGemm, kOrtDomainName, node_name.c_str(), input_names.data(), - input_names.size(), output_names.data(), output_names.size(), - attributes.data(), attributes.size(), node.get_pptr())); - CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr())) + model_builder_.AddNode(kOpTypeGemm, node_name, input_names, output_names, + attributes); } void GraphBuilderOrt::AddLogicalNotOperation( @@ -675,12 +569,7 @@ void GraphBuilderOrt::AddMatmulOperation(const mojom::Matmul& matmul) { input_b_name.c_str()}; std::array output_names = {output_name.c_str()}; - ScopedOrtNode node; - CHECK_STATUS(GetOrtGraphApi()->CreateNode( - kOpTypeMatmul, kOrtDomainName, node_name.c_str(), input_names.data(), - input_names.size(), output_names.data(), output_names.size(), - /*attributes=*/nullptr, /*attribs_len=*/0, node.get_pptr())); - CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr())) + model_builder_.AddNode(kOpTypeMatmul, node_name, input_names, output_names); } void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) { @@ -688,25 +577,21 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) { base::checked_cast(pool2d.dilations->height), base::checked_cast(pool2d.dilations->width)}; ScopedOrtOpAttr attr_dilations; - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"dilations", dilations.data(), /*len=*/2, - OrtOpAttrType::ORT_OP_ATTR_INTS, attr_dilations.get_pptr())); + model_builder_.CreateAttribute(attr_dilations, /*name=*/"dilations", + dilations); std::array strides = { base::checked_cast(pool2d.strides->height), base::checked_cast(pool2d.strides->width)}; ScopedOrtOpAttr attr_strides; - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"strides", strides.data(), /*len=*/2, - OrtOpAttrType::ORT_OP_ATTR_INTS, attr_strides.get_pptr())); + model_builder_.CreateAttribute(attr_strides, /*name=*/"strides", strides); std::array window_dimensions = { base::checked_cast(pool2d.window_dimensions->height), base::checked_cast(pool2d.window_dimensions->width)}; ScopedOrtOpAttr attr_kernel_shape; - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"kernel_shape", window_dimensions.data(), /*len=*/2, - OrtOpAttrType::ORT_OP_ATTR_INTS, attr_kernel_shape.get_pptr())); + model_builder_.CreateAttribute(attr_kernel_shape, + /*name=*/"kernel_shape", window_dimensions); // ONNX's pads are [beginning_height, beginning_width, ending_height, // ending_width] @@ -716,9 +601,7 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) { base::checked_cast(pool2d.padding->ending->height), base::checked_cast(pool2d.padding->ending->width)}; ScopedOrtOpAttr attr_pads; - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"pads", pads.data(), /*len=*/4, OrtOpAttrType::ORT_OP_ATTR_INTS, - attr_pads.get_pptr())); + model_builder_.CreateAttribute(attr_pads, /*name=*/"pads", pads); // Calculate the ceil_mode. const std::vector& input_shape = @@ -744,9 +627,8 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) { int64_t ceil_mode = float_output_height.value() < output_height ? 1 : 0; ScopedOrtOpAttr attr_ceil_mode; - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"ceil_mode", &ceil_mode, /*len=*/1, - OrtOpAttrType::ORT_OP_ATTR_INT, attr_ceil_mode.get_pptr())); + model_builder_.CreateAttribute(attr_ceil_mode, /*name=*/"ceil_mode", + ceil_mode); // P value of the Lp norm used to pool over the input data. std::optional attr_p; @@ -765,14 +647,11 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) { op_type = kOpTypeLpPool2d; p = 2; attr_p.emplace(); - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"p", &p.value(), /*len=*/1, OrtOpAttrType::ORT_OP_ATTR_INT, - attr_p.value().get_pptr())); + model_builder_.CreateAttribute(attr_p.value(), /*name=*/"p", p.value()); break; } } - ScopedOrtNode node; std::vector attributes = { attr_dilations.get_pptr(), attr_strides.get_pptr(), attr_kernel_shape.get_pptr(), attr_pads.get_pptr(), @@ -789,11 +668,8 @@ void GraphBuilderOrt::AddPool2dOperation(const mojom::Pool2d& pool2d) { std::array input_names = {input_name.c_str()}; std::array output_names = {output_name.c_str()}; - CHECK_STATUS(GetOrtGraphApi()->CreateNode( - op_type.data(), kOrtDomainName, node_name.c_str(), input_names.data(), - input_names.size(), output_names.data(), output_names.size(), - attributes.data(), attributes.size(), node.get_pptr())); - CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr())); + model_builder_.AddNode(op_type, node_name, input_names, output_names, + attributes); } void GraphBuilderOrt::AddReshapeOperation(const mojom::Reshape& reshape) { @@ -822,12 +698,7 @@ void GraphBuilderOrt::AddReshapeOperation(const mojom::Reshape& reshape) { shape_name.c_str()}; std::array output_names = {output_name.c_str()}; - ScopedOrtNode node; - CHECK_STATUS(GetOrtGraphApi()->CreateNode( - kOpTypeReshape, kOrtDomainName, node_name.c_str(), input_names.data(), - input_names.size(), output_names.data(), output_names.size(), - /*attributes=*/nullptr, /*attribs_len=*/0, node.get_pptr())); - CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr())); + model_builder_.AddNode(kOpTypeReshape, node_name, input_names, output_names); } void GraphBuilderOrt::AddSoftmaxOperation(const mojom::Softmax& softmax) { @@ -838,19 +709,12 @@ void GraphBuilderOrt::AddSoftmaxOperation(const mojom::Softmax& softmax) { std::array input_names = {input_name.c_str()}; std::array output_names = {output_name.c_str()}; - ScopedOrtOpAttr attr_axis; int64_t axis = static_cast(softmax.axis); - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"axis", &axis, /*len=*/1, OrtOpAttrType::ORT_OP_ATTR_INT, - attr_axis.get_pptr())); + ScopedOrtOpAttr attr_axis; + model_builder_.CreateAttribute(attr_axis, /*name=*/"axis", axis); - ScopedOrtNode node; std::array attributes = {attr_axis.get_pptr()}; - CHECK_STATUS(GetOrtGraphApi()->CreateNode( - kOpTypeSoftmax, kOrtDomainName, node_name.c_str(), input_names.data(), - input_names.size(), output_names.data(), output_names.size(), - attributes.data(), attributes.size(), node.get_pptr())); - CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr())); + model_builder_.AddNode(kOpTypeSoftmax, node_name, input_names, output_names, attributes); } void GraphBuilderOrt::AddTransposeOperation(const mojom::Transpose& transpose) { @@ -861,20 +725,14 @@ void GraphBuilderOrt::AddTransposeOperation(const mojom::Transpose& transpose) { std::array input_names = {input_name.c_str()}; std::array output_names = {output_name.c_str()}; - ScopedOrtOpAttr attr_perm; std::vector permutation(transpose.permutation.begin(), transpose.permutation.end()); - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"perm", permutation.data(), permutation.size(), - OrtOpAttrType::ORT_OP_ATTR_INTS, attr_perm.get_pptr())); + ScopedOrtOpAttr attr_perm; + model_builder_.CreateAttribute(attr_perm, /*name=*/"perm", permutation); - ScopedOrtNode node; std::array attributes = {attr_perm.get_pptr()}; - CHECK_STATUS(GetOrtGraphApi()->CreateNode( - kOpTypeTranspose, kOrtDomainName, node_name.c_str(), input_names.data(), - input_names.size(), output_names.data(), output_names.size(), - attributes.data(), attributes.size(), node.get_pptr())); - CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr())); + model_builder_.AddNode(kOpTypeTranspose, node_name, input_names, output_names, + attributes); } void GraphBuilderOrt::AddWhereOperation(const mojom::Where& where) { @@ -894,19 +752,11 @@ void GraphBuilderOrt::AddWhereOperation(const mojom::Where& where) { ScopedOrtOpAttr attr_to; int64_t to_data_type = static_cast(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL); - CHECK_STATUS(GetOrtApi()->CreateOpAttr( - /*name=*/"to", &to_data_type, /*len=*/1, OrtOpAttrType::ORT_OP_ATTR_INT, - attr_to.get_pptr())); + model_builder_.CreateAttribute(attr_to, /*name=*/"to", to_data_type); - ScopedOrtNode cast_node; std::array cast_attributes = {attr_to.get_pptr()}; - CHECK_STATUS(GetOrtGraphApi()->CreateNode( - kOpTypeCast, kOrtDomainName, cast_node_name.c_str(), - cast_input_names.data(), cast_input_names.size(), - cast_output_names.data(), cast_output_names.size(), - cast_attributes.data(), cast_attributes.size(), cast_node.get_pptr())); - CHECK_STATUS( - GetOrtGraphApi()->AddNode(graph_.get_ptr(), cast_node.get_pptr())); + model_builder_.AddNode(kOpTypeCast, cast_node_name, cast_input_names, + cast_output_names, cast_attributes); next_operand_id_++; } @@ -920,26 +770,11 @@ void GraphBuilderOrt::AddWhereOperation(const mojom::Where& where) { false_value_name.c_str()}; std::array output_names = {output_name.c_str()}; - ScopedOrtNode node; - CHECK_STATUS(GetOrtGraphApi()->CreateNode( - kOpTypeWhere, kOrtDomainName, node_name.c_str(), input_names.data(), - input_names.size(), output_names.data(), output_names.size(), - /*attributes=*/nullptr, /*attribs_len=*/0, node.get_pptr())); - CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr())); + model_builder_.AddNode(kOpTypeWhere, node_name, input_names, output_names); } [[nodiscard]] base::expected GraphBuilderOrt::BuildModel() { - ScopedOrtModel& model = result_->model; - - std::vector domain_names = {kOrtDomainName}; - std::vector opset_versions = {kOrtOpsetVersion}; - CHECK_STATUS( - GetOrtGraphApi()->CreateModel(domain_names.data(), opset_versions.data(), - domain_names.size(), model.get_pptr())); - - CHECK_STATUS(GetOrtGraphApi()->CreateGraph(graph_.get_pptr())); - // Add inputs. for (uint64_t input_id : graph_info_->input_operands) { AddInput(input_id); @@ -1047,7 +882,7 @@ GraphBuilderOrt::BuildModel() { AddOutput(output_id); } - CHECK_STATUS(GetOrtGraphApi()->AddGraph(model.get_ptr(), graph_.get_pptr())); + result_->model_info = model_builder_.BuildAndTakeModelInfo(); return base::ok(); } diff --git a/services/webnn/ort/graph_builder_ort.h b/services/webnn/ort/graph_builder_ort.h index 042c3e60a1289f..07c1286c06c8ee 100644 --- a/services/webnn/ort/graph_builder_ort.h +++ b/services/webnn/ort/graph_builder_ort.h @@ -11,7 +11,6 @@ #include #include "base/containers/flat_map.h" -#include "base/containers/heap_array.h" #include "base/containers/span.h" #include "base/files/file_path.h" #include "base/memory/raw_ptr.h" @@ -24,7 +23,7 @@ #include "services/webnn/public/mojom/webnn_context_provider.mojom.h" #include "services/webnn/public/mojom/webnn_error.mojom-forward.h" #include "services/webnn/public/mojom/webnn_graph.mojom.h" -#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h" +#include "services/webnn/ort/ort_model_builder.h" namespace webnn { @@ -60,20 +59,11 @@ class GraphBuilderOrt { Result& operator=(const Result&) = delete; ~Result(); - const ScopedOrtModel& GetModel(); - const OperandInfo& GetOperandInfo(uint64_t operand_id) const; - const std::map& id_to_operand_info_map() const; - - ScopedOrtModel model; - std::map operand_infos; - - // TODO: Consider reusing constant operands instead of copying them to - // `weights`. - // - // Store the weights which should be alive for inference session. - std::vector> weights; + std::map id_to_operand_info; + + std::unique_ptr model_info; }; // Factory method that creates a GraphBuilderOrt, builds and serializes the @@ -148,8 +138,6 @@ class GraphBuilderOrt { [[nodiscard]] base::expected BuildModel(); - scoped_refptr allocator_; - // Used for inserting new operands into graph. uint64_t next_operand_id_ = 0; @@ -163,7 +151,7 @@ class GraphBuilderOrt { const ContextProperties context_properties_; - ScopedOrtGraph graph_; + OrtModelBuilder model_builder_; std::unique_ptr result_; }; diff --git a/services/webnn/ort/graph_impl_ort.cc b/services/webnn/ort/graph_impl_ort.cc index c6111ba6ca501b..faad950c208a98 100644 --- a/services/webnn/ort/graph_impl_ort.cc +++ b/services/webnn/ort/graph_impl_ort.cc @@ -74,9 +74,10 @@ void GraphImplOrt::CreateAndBuild( std::move(wrapped_callback)); } -GraphImplOrt::Session::Session(OrtSession* session, - std::vector> weights) - : weights(std::move(weights)), session(session) {} +GraphImplOrt::Session::Session( + OrtSession* session, + std::vector> external_data) + : external_data(std::move(external_data)), session(session) {} GraphImplOrt::Session::~Session() { // TODO: Can we call `ReleaseSession` from Dllmain (because session owns a @@ -132,7 +133,7 @@ GraphImplOrt::CreateAndBuildOnBackgroundThread( OrtSession* session; const OrtEnv* env = allocator->env(); OrtStatus* status = GetOrtGraphApi()->CreateSessionFromModel( - env, result->model.get_ptr(), session_options, &session); + env, result->model_info->model.get_ptr(), session_options, &session); ort_api->ReleaseSessionOptions(session_options); if (status != NULL) { @@ -145,8 +146,8 @@ GraphImplOrt::CreateAndBuildOnBackgroundThread( LOG(ERROR) << "Running on ORT============="; - return base::WrapUnique( - new GraphImplOrt::Session(session, std::move(result->weights))); + return base::WrapUnique(new GraphImplOrt::Session( + session, std::move(result->model_info->external_data))); } // static diff --git a/services/webnn/ort/graph_impl_ort.h b/services/webnn/ort/graph_impl_ort.h index 9aa92b288175b4..7b58acc694bdcb 100644 --- a/services/webnn/ort/graph_impl_ort.h +++ b/services/webnn/ort/graph_impl_ort.h @@ -52,14 +52,15 @@ class GraphImplOrt final : public WebNNGraphImpl { ~GraphImplOrt() override; struct Session { - Session(OrtSession* session, std::vector> weights); + Session(OrtSession* session, + std::vector> external_data); Session(const Session&) = delete; Session& operator=(const Session&) = delete; ~Session(); OrtSession* GetSession() { return session.get(); } - std::vector> weights; + std::vector> external_data; raw_ptr session; }; diff --git a/services/webnn/ort/ort_model_builder.cc b/services/webnn/ort/ort_model_builder.cc new file mode 100644 index 00000000000000..2bd05b2ed57cf4 --- /dev/null +++ b/services/webnn/ort/ort_model_builder.cc @@ -0,0 +1,165 @@ +// Copyright 2024 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/webnn/ort/ort_model_builder.h" + +#include "base/notreached.h" +#include "services/webnn/ort/error_ort.h" +#include "services/webnn/ort/utils_ort.h" + +namespace webnn { + +namespace { + +constexpr char kOrtDomainName[] = ""; +constexpr int32_t kOrtOpsetVersion = 21; + +} // namespace + +namespace ort { + +OrtModelBuilder::ModelInfo::ModelInfo() = default; +OrtModelBuilder::ModelInfo::~ModelInfo() = default; + +OrtModelBuilder::OrtModelBuilder(scoped_refptr allocator) + : allocator_(std::move(allocator)), + model_info_(std::make_unique()) { + CHECK_STATUS(GetOrtGraphApi()->CreateGraph(graph_.get_pptr())); +} +OrtModelBuilder::~OrtModelBuilder() = default; + +void OrtModelBuilder::AddInput(std::string_view name, + base::span shape, + ONNXTensorElementDataType data_type) { + ScopedOrtShape input_shape; + CHECK_STATUS(GetOrtGraphApi()->CreateFixedShape(shape.data(), shape.size(), + input_shape.get_pptr())); + + ScopedOrtValueInfo input_info; + CHECK_STATUS(GetOrtGraphApi()->CreateTensorValueInfo( + name.data(), data_type, input_shape.get_pptr(), input_info.get_pptr())); + CHECK_STATUS( + GetOrtGraphApi()->AddInput(graph_.get_ptr(), input_info.get_pptr())); +} + +void OrtModelBuilder::AddOutput(std::string_view name, + base::span shape, + ONNXTensorElementDataType data_type) { + ScopedOrtShape output_shape; + CHECK_STATUS(GetOrtGraphApi()->CreateFixedShape(shape.data(), shape.size(), + output_shape.get_pptr())); + + ScopedOrtValueInfo output_info; + CHECK_STATUS(GetOrtGraphApi()->CreateTensorValueInfo( + name.data(), data_type, output_shape.get_pptr(), output_info.get_pptr())); + CHECK_STATUS( + GetOrtGraphApi()->AddOutput(graph_.get_ptr(), output_info.get_pptr())); +} + +void OrtModelBuilder::AddInitializerAsRawData( + std::string_view name, + base::span shape, + base::span data, + ONNXTensorElementDataType data_type) { + ScopedOrtValue initializer; + CHECK_STATUS(GetOrtApi()->CreateTensorAsOrtValue( + allocator_->allocator(), shape.data(), shape.size(), data_type, + initializer.get_pptr())); + + void* ort_tensor_raw_data = nullptr; + CHECK_STATUS(GetOrtApi()->GetTensorMutableData(initializer.get_ptr(), + &ort_tensor_raw_data)); + CHECK(ort_tensor_raw_data); + UNSAFE_BUFFERS( + base::span(static_cast(ort_tensor_raw_data), data.size())) + .copy_from(data); + CHECK_STATUS(GetOrtGraphApi()->AddInitializer(graph_.get_ptr(), name.data(), + initializer.get_pptr())); +} + +void OrtModelBuilder::AddInitializerAsExternalData( + std::string_view name, + base::span shape, + base::span data, + ONNXTensorElementDataType data_type) { + auto weight = base::HeapArray::CopiedFrom(data); + model_info_->external_data.push_back(std::move(weight)); + + ScopedOrtValue initializer; + CHECK_STATUS(GetOrtApi()->CreateTensorWithDataAsOrtValue( + allocator_->memory_info(), model_info_->external_data.back().data(), + model_info_->external_data.back().size(), shape.data(), shape.size(), + data_type, initializer.get_pptr())); + CHECK_STATUS(GetOrtGraphApi()->AddInitializer(graph_.get_ptr(), name.data(), + initializer.get_pptr())); +} + +void OrtModelBuilder::CreateAttribute(ScopedOrtOpAttr& attribute, + std::string_view name, + OrtOpAttrData data) { + if (absl::holds_alternative(data)) { + CHECK_STATUS(GetOrtApi()->CreateOpAttr( + name.data(), &absl::get(data), /*len=*/1, + OrtOpAttrType::ORT_OP_ATTR_INT, attribute.get_pptr())); + } else if (absl::holds_alternative(data)) { + CHECK_STATUS(GetOrtApi()->CreateOpAttr( + name.data(), &absl::get(data), /*len=*/1, + OrtOpAttrType::ORT_OP_ATTR_FLOAT, attribute.get_pptr())); + } else if (absl::holds_alternative(data)) { + std::string_view string_data = absl::get(data); + CHECK_STATUS(GetOrtApi()->CreateOpAttr( + name.data(), string_data.data(), string_data.size(), + OrtOpAttrType::ORT_OP_ATTR_STRING, attribute.get_pptr())); + } else if (absl::holds_alternative>(data)) { + base::span ints_data = + absl::get>(data); + CHECK_STATUS(GetOrtApi()->CreateOpAttr( + name.data(), ints_data.data(), ints_data.size(), + OrtOpAttrType::ORT_OP_ATTR_INTS, attribute.get_pptr())); + } else if (absl::holds_alternative>(data)) { + base::span floats_data = + absl::get>(data); + CHECK_STATUS(GetOrtApi()->CreateOpAttr( + name.data(), floats_data.data(), floats_data.size(), + OrtOpAttrType::ORT_OP_ATTR_FLOATS, attribute.get_pptr())); + } else if (absl::holds_alternative>(data)) { + base::span strings_data = + absl::get>(data); + CHECK_STATUS(GetOrtApi()->CreateOpAttr( + name.data(), strings_data.data(), strings_data.size(), + OrtOpAttrType::ORT_OP_ATTR_STRINGS, attribute.get_pptr())); + } +} + +void OrtModelBuilder::AddNode(std::string_view op_type, + std::string_view node_name, + base::span input_names, + base::span output_names, + base::span attributes) { + ScopedOrtNode node; + CHECK_STATUS(GetOrtGraphApi()->CreateNode( + op_type.data(), kOrtDomainName, node_name.data(), input_names.data(), + input_names.size(), output_names.data(), output_names.size(), + attributes.data(), attributes.size(), node.get_pptr())); + CHECK_STATUS(GetOrtGraphApi()->AddNode(graph_.get_ptr(), node.get_pptr())); +} + +std::unique_ptr +OrtModelBuilder::BuildAndTakeModelInfo() { + std::vector domain_names = {kOrtDomainName}; + std::vector opset_versions = {kOrtOpsetVersion}; + + CHECK_STATUS(GetOrtGraphApi()->CreateModel( + domain_names.data(), opset_versions.data(), domain_names.size(), + model_info_->model.get_pptr())); + + CHECK_STATUS(GetOrtGraphApi()->AddGraph(model_info_->model.get_ptr(), + graph_.get_pptr())); + + return std::move(model_info_); +} + +} // namespace ort + +} // namespace webnn diff --git a/services/webnn/ort/ort_model_builder.h b/services/webnn/ort/ort_model_builder.h new file mode 100644 index 00000000000000..27e8360e7633a7 --- /dev/null +++ b/services/webnn/ort/ort_model_builder.h @@ -0,0 +1,93 @@ +// Copyright 2024 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_WEBNN_ORT_ORT_MODEL_BUILDER_H_ +#define SERVICES_WEBNN_ORT_ORT_MODEL_BUILDER_H_ + +#include +#include + +#include "base/containers/heap_array.h" +#include "base/containers/span.h" +#include "base/memory/stack_allocated.h" +#include "services/webnn/ort/allocator_ort.h" +#include "services/webnn/ort/scoped_ort_types.h" +#include "third_party/abseil-cpp/absl/types/variant.h" +#include "third_party/microsoft_dxheaders/include/onnxruntime_c_api.h" + +namespace webnn { + +namespace ort { + +class OrtModelBuilder final { + STACK_ALLOCATED(); + + public: + struct ModelInfo { + explicit ModelInfo(); + ModelInfo(const ModelInfo&) = delete; + ModelInfo& operator=(const ModelInfo&) = delete; + ~ModelInfo(); + + ScopedOrtModel model; + + // TODO: Consider reusing constant operands instead of copying them to + // `external_data`. + // + // Store the external data which should be alive for inference session. + std::vector> external_data; + }; + + explicit OrtModelBuilder(scoped_refptr allocator); + ~OrtModelBuilder(); + OrtModelBuilder(const OrtModelBuilder&) = delete; + OrtModelBuilder& operator=(const OrtModelBuilder&) = delete; + + void AddInput(std::string_view name, + base::span shape, + ONNXTensorElementDataType data_type); + + void AddOutput(std::string_view name, + base::span shape, + ONNXTensorElementDataType data_type); + + void AddInitializerAsRawData(std::string_view name, + base::span shape, + base::span data, + ONNXTensorElementDataType data_type); + + void AddInitializerAsExternalData(std::string_view name, + base::span shape, + base::span data, + ONNXTensorElementDataType data_type); + using OrtOpAttrData = absl::variant, + base::span, + base::span>; + void CreateAttribute(ScopedOrtOpAttr& attribute, + std::string_view name, + OrtOpAttrData data); + + void AddNode(std::string_view op_type, + std::string_view node_name, + base::span input_names, + base::span output_names, + base::span attributes = {}); + + std::unique_ptr BuildAndTakeModelInfo(); + + private: + scoped_refptr allocator_; + + ScopedOrtGraph graph_; + + std::unique_ptr model_info_; +}; + +} // namespace ort +} // namespace webnn + +#endif // SERVICES_WEBNN_ORT_ORT_MODEL_BUILDER_H_ diff --git a/services/webnn/ort/scoped_ort_types.h b/services/webnn/ort/scoped_ort_types.h index f9fb7e5ed98fde..f3d1841037a961 100644 --- a/services/webnn/ort/scoped_ort_types.h +++ b/services/webnn/ort/scoped_ort_types.h @@ -14,6 +14,8 @@ namespace webnn::ort { class ScopedOrtValue { public: ScopedOrtValue(); + ScopedOrtValue(const ScopedOrtValue&) = delete; + ScopedOrtValue& operator=(const ScopedOrtValue&) = delete; ~ScopedOrtValue(); OrtValue* get_ptr() { return *pptr_; } @@ -26,6 +28,8 @@ class ScopedOrtValue { class ScopedOrtMemoryInfo { public: ScopedOrtMemoryInfo(); + ScopedOrtMemoryInfo(const ScopedOrtMemoryInfo&) = delete; + ScopedOrtMemoryInfo& operator=(const ScopedOrtMemoryInfo&) = delete; ~ScopedOrtMemoryInfo(); OrtMemoryInfo* get_ptr() { return *pptr_; } @@ -38,6 +42,8 @@ class ScopedOrtMemoryInfo { class ScopedOrtOpAttr { public: ScopedOrtOpAttr(); + ScopedOrtOpAttr(const ScopedOrtOpAttr&) = delete; + ScopedOrtOpAttr& operator=(const ScopedOrtOpAttr&) = delete; ~ScopedOrtOpAttr(); OrtOpAttr* get_ptr() { return *pptr_; } @@ -50,6 +56,8 @@ class ScopedOrtOpAttr { class ScopedOrtGraph { public: ScopedOrtGraph(); + ScopedOrtGraph(const ScopedOrtGraph&) = delete; + ScopedOrtGraph& operator=(const ScopedOrtGraph&) = delete; ~ScopedOrtGraph(); OrtGraph* get_ptr() { return *pptr_; } @@ -62,6 +70,8 @@ class ScopedOrtGraph { class ScopedOrtShape { public: ScopedOrtShape(); + ScopedOrtShape(const ScopedOrtShape&) = delete; + ScopedOrtShape& operator=(const ScopedOrtShape&) = delete; ~ScopedOrtShape(); OrtShape* get_ptr() { return *pptr_; } @@ -74,6 +84,8 @@ class ScopedOrtShape { class ScopedOrtValueInfo { public: ScopedOrtValueInfo(); + ScopedOrtValueInfo(const ScopedOrtValueInfo&) = delete; + ScopedOrtValueInfo& operator=(const ScopedOrtValueInfo&) = delete; ~ScopedOrtValueInfo(); OrtValueInfo* get_ptr() { return *pptr_; } @@ -86,6 +98,8 @@ class ScopedOrtValueInfo { class ScopedOrtNode { public: ScopedOrtNode(); + ScopedOrtNode(const ScopedOrtNode&) = delete; + ScopedOrtNode& operator=(const ScopedOrtNode&) = delete; ~ScopedOrtNode(); OrtNode* get_ptr() { return *pptr_; } @@ -98,6 +112,8 @@ class ScopedOrtNode { class ScopedOrtModel { public: ScopedOrtModel(); + ScopedOrtModel(const ScopedOrtModel&) = delete; + ScopedOrtModel& operator=(const ScopedOrtModel&) = delete; ~ScopedOrtModel(); OrtModel* get_ptr() { return *pptr_; }