Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[runtime] Upgrade onnx runtime version to v1.13.1 #2636

Merged
merged 3 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions runtime/core/cmake/onnx.cmake
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
if(ONNX)
set(ONNX_VERSION "1.12.0")
set(ONNX_VERSION "1.13.1")
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-win-x64-${ONNX_VERSION}.zip")
set(URL_HASH "SHA256=8b5d61204989350b7904ac277f5fbccd3e6736ddbb6ec001e412723d71c9c176")
set(URL_HASH "SHA256=cd8318dc30352e0d615f809bd544bfd18b578289ec16621252b5db1994f09e43")
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-aarch64-${ONNX_VERSION}.tgz")
set(URL_HASH "SHA256=5820d9f343df73c63b6b2b174a1ff62575032e171c9564bcf92060f46827d0ac")
set(URL_HASH "SHA256=18e441585de69ef8aab263e2e96f0325729537ebfbd17cdcee78b2eabf0594d2")
else()
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-${ONNX_VERSION}.tgz")
set(URL_HASH "SHA256=5d503ce8540358b59be26c675e42081be14a3e833a5301926f555451046929c5")
set(URL_HASH "SHA256=2c7fdcfa8131b52167b1870747758cb24265952eba975318a67cc840c04ca73e")
endif()
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin")
if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64")
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-osx-arm64-${ONNX_VERSION}.tgz")
set(URL_HASH "SHA256=23117b6f5d7324d4a7c51184e5f808dd952aec411a6b99a1b6fd1011de06e300")
set(URL_HASH "SHA256=10ce30925c789715f29424a7658b41c601dfbde5d58fe21cb53ad418cde3c215")
else()
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-osx-x86_64-${ONNX_VERSION}.tgz")
set(URL_HASH "SHA256=09b17f712f8c6f19bb63da35d508815b443cbb473e16c6192abfaa297c02f600")
set(URL_HASH "SHA256=32f3fff17b01db779e9e3cbe32f27adba40460e6202a79dfd1ac76b4f20588ef")
endif()
else()
message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux' or 'Darwin')")
Expand Down
135 changes: 94 additions & 41 deletions runtime/core/decoder/onnx_asr_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ void OnnxAsrModel::InitEngineThreads(int num_threads) {

void OnnxAsrModel::GetInputOutputInfo(
const std::shared_ptr<Ort::Session>& session,
std::vector<const char*>* in_names, std::vector<const char*>* out_names) {
std::vector<std::string>* in_names, std::vector<std::string>* out_names) {
Ort::AllocatorWithDefaultOptions allocator;
// Input info
int num_nodes = session->GetInputCount();
in_names->resize(num_nodes);
for (int i = 0; i < num_nodes; ++i) {
char* name = session->GetInputName(i, allocator);
Ort::AllocatedStringPtr in_name_ptr =
session->GetInputNameAllocated(i, allocator);
Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
ONNXTensorElementDataType type = tensor_info.GetElementType();
Expand All @@ -50,15 +51,16 @@ void OnnxAsrModel::GetInputOutputInfo(
shape << j;
shape << " ";
}
LOG(INFO) << "\tInput " << i << " : name=" << name << " type=" << type
<< " dims=" << shape.str();
(*in_names)[i] = name;
LOG(INFO) << "\tInput " << i << " : name=" << in_name_ptr.get()
<< " type=" << type << " dims=" << shape.str();
(*in_names)[i] = std::string(in_name_ptr.get());
}
// Output info
num_nodes = session->GetOutputCount();
out_names->resize(num_nodes);
for (int i = 0; i < num_nodes; ++i) {
char* name = session->GetOutputName(i, allocator);
Ort::AllocatedStringPtr out_name_ptr =
session->GetOutputNameAllocated(i, allocator);
Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
ONNXTensorElementDataType type = tensor_info.GetElementType();
Expand All @@ -68,9 +70,9 @@ void OnnxAsrModel::GetInputOutputInfo(
shape << j;
shape << " ";
}
LOG(INFO) << "\tOutput " << i << " : name=" << name << " type=" << type
<< " dims=" << shape.str();
(*out_names)[i] = name;
LOG(INFO) << "\tOutput " << i << " : name=" << out_name_ptr.get()
<< " type=" << type << " dims=" << shape.str();
(*out_names)[i] = std::string(out_name_ptr.get());
}
}

Expand Down Expand Up @@ -105,25 +107,43 @@ void OnnxAsrModel::Read(const std::string& model_dir) {
auto model_metadata = encoder_session_->GetModelMetadata();

Ort::AllocatorWithDefaultOptions allocator;
encoder_output_size_ =
atoi(model_metadata.LookupCustomMetadataMap("output_size", allocator));
num_blocks_ =
atoi(model_metadata.LookupCustomMetadataMap("num_blocks", allocator));
head_ = atoi(model_metadata.LookupCustomMetadataMap("head", allocator));
cnn_module_kernel_ = atoi(
model_metadata.LookupCustomMetadataMap("cnn_module_kernel", allocator));
subsampling_rate_ = atoi(
model_metadata.LookupCustomMetadataMap("subsampling_rate", allocator));
encoder_output_size_ = atoi(
model_metadata.LookupCustomMetadataMapAllocated("output_size", allocator)
.get());
num_blocks_ = atoi(
model_metadata.LookupCustomMetadataMapAllocated("num_blocks", allocator)
.get());
head_ = atoi(
model_metadata.LookupCustomMetadataMapAllocated("head", allocator).get());
cnn_module_kernel_ =
atoi(model_metadata
.LookupCustomMetadataMapAllocated("cnn_module_kernel", allocator)
.get());
subsampling_rate_ =
atoi(model_metadata
.LookupCustomMetadataMapAllocated("subsampling_rate", allocator)
.get());
right_context_ =
atoi(model_metadata.LookupCustomMetadataMap("right_context", allocator));
sos_ = atoi(model_metadata.LookupCustomMetadataMap("sos_symbol", allocator));
eos_ = atoi(model_metadata.LookupCustomMetadataMap("eos_symbol", allocator));
is_bidirectional_decoder_ = atoi(model_metadata.LookupCustomMetadataMap(
"is_bidirectional_decoder", allocator));
chunk_size_ =
atoi(model_metadata.LookupCustomMetadataMap("chunk_size", allocator));
num_left_chunks_ =
atoi(model_metadata.LookupCustomMetadataMap("left_chunks", allocator));
atoi(model_metadata
.LookupCustomMetadataMapAllocated("right_context", allocator)
.get());
sos_ = atoi(
model_metadata.LookupCustomMetadataMapAllocated("sos_symbol", allocator)
.get());
eos_ = atoi(
model_metadata.LookupCustomMetadataMapAllocated("eos_symbol", allocator)
.get());
is_bidirectional_decoder_ =
atoi(model_metadata
.LookupCustomMetadataMapAllocated("is_bidirectional_decoder",
allocator)
.get());
chunk_size_ = atoi(
model_metadata.LookupCustomMetadataMapAllocated("chunk_size", allocator)
.get());
num_left_chunks_ = atoi(
model_metadata.LookupCustomMetadataMapAllocated("left_chunks", allocator)
.get());

LOG(INFO) << "Onnx Model Info:";
LOG(INFO) << "\tencoder_output_size " << encoder_output_size_;
Expand Down Expand Up @@ -264,24 +284,35 @@ void OnnxAsrModel::ForwardEncoderFunc(
// 2. Encoder chunk forward
std::vector<Ort::Value> inputs;
for (auto name : encoder_in_names_) {
if (!strcmp(name, "chunk")) {
if (!strcmp(name.c_str(), "chunk")) {
inputs.emplace_back(std::move(feats_ort));
} else if (!strcmp(name, "offset")) {
} else if (!strcmp(name.c_str(), "offset")) {
inputs.emplace_back(std::move(offset_ort));
} else if (!strcmp(name, "required_cache_size")) {
} else if (!strcmp(name.c_str(), "required_cache_size")) {
inputs.emplace_back(std::move(required_cache_size_ort));
} else if (!strcmp(name, "att_cache")) {
} else if (!strcmp(name.c_str(), "att_cache")) {
inputs.emplace_back(std::move(att_cache_ort_));
} else if (!strcmp(name, "cnn_cache")) {
} else if (!strcmp(name.c_str(), "cnn_cache")) {
inputs.emplace_back(std::move(cnn_cache_ort_));
} else if (!strcmp(name, "att_mask")) {
} else if (!strcmp(name.c_str(), "att_mask")) {
inputs.emplace_back(std::move(att_mask_ort));
}
}

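// Session::Run is passed arrays of C strings, so build temporary
// std::vector<const char*> views over the std::string name vectors. The
// pointers borrow from encoder_in_names_ / encoder_out_names_ and must not
// outlive or be used after any modification of those vectors.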
std::vector<const char*> encoder_in_names(encoder_in_names_.size());
std::vector<const char*> encoder_out_names(encoder_out_names_.size());
std::transform(encoder_in_names_.begin(), encoder_in_names_.end(),
encoder_in_names.begin(),
[](const std::string& name) { return name.c_str(); });
std::transform(encoder_out_names_.begin(), encoder_out_names_.end(),
encoder_out_names.begin(),
[](const std::string& name) { return name.c_str(); });

std::vector<Ort::Value> ort_outputs = encoder_session_->Run(
Ort::RunOptions{nullptr}, encoder_in_names_.data(), inputs.data(),
inputs.size(), encoder_out_names_.data(), encoder_out_names_.size());
Ort::RunOptions{nullptr}, encoder_in_names.data(), inputs.data(),
inputs.size(), encoder_out_names.data(), encoder_out_names.size());

offset_ += static_cast<int>(
ort_outputs[0].GetTensorTypeAndShapeInfo().GetShape()[1]);
Expand All @@ -291,9 +322,20 @@ void OnnxAsrModel::ForwardEncoderFunc(
std::vector<Ort::Value> ctc_inputs;
ctc_inputs.emplace_back(std::move(ort_outputs[0]));

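// Session::Run is passed arrays of C strings, so build temporary
// std::vector<const char*> views over the std::string name vectors. The
// pointers borrow from ctc_in_names_ / ctc_out_names_ and must not outlive
// or be used after any modification of those vectors.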
std::vector<const char*> ctc_in_names(ctc_in_names_.size());
std::vector<const char*> ctc_out_names(ctc_out_names_.size());
std::transform(ctc_in_names_.begin(), ctc_in_names_.end(),
ctc_in_names.begin(),
[](const std::string& name) { return name.c_str(); });
std::transform(ctc_out_names_.begin(), ctc_out_names_.end(),
ctc_out_names.begin(),
[](const std::string& name) { return name.c_str(); });

std::vector<Ort::Value> ctc_ort_outputs = ctc_session_->Run(
Ort::RunOptions{nullptr}, ctc_in_names_.data(), ctc_inputs.data(),
ctc_inputs.size(), ctc_out_names_.data(), ctc_out_names_.size());
Ort::RunOptions{nullptr}, ctc_in_names.data(), ctc_inputs.data(),
ctc_inputs.size(), ctc_out_names.data(), ctc_out_names.size());
encoder_outs_.push_back(std::move(ctc_inputs[0]));

float* logp_data = ctc_ort_outputs[0].GetTensorMutableData<float>();
Expand Down Expand Up @@ -393,10 +435,21 @@ void OnnxAsrModel::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
rescore_inputs.emplace_back(std::move(hyps_lens_tensor_));
rescore_inputs.emplace_back(std::move(decode_input_tensor_));

std::vector<Ort::Value> rescore_outputs = rescore_session_->Run(
Ort::RunOptions{nullptr}, rescore_in_names_.data(), rescore_inputs.data(),
rescore_inputs.size(), rescore_out_names_.data(),
rescore_out_names_.size());
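// Session::Run is passed arrays of C strings, so build temporary
// std::vector<const char*> views over the std::string name vectors. The
// pointers borrow from rescore_in_names_ / rescore_out_names_ and must not
// outlive or be used after any modification of those vectors.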
std::vector<const char*> rescore_in_names(rescore_in_names_.size());
std::vector<const char*> rescore_out_names(rescore_out_names_.size());
std::transform(rescore_in_names_.begin(), rescore_in_names_.end(),
rescore_in_names.begin(),
[](const std::string& name) { return name.c_str(); });
std::transform(rescore_out_names_.begin(), rescore_out_names_.end(),
rescore_out_names.begin(),
[](const std::string& name) { return name.c_str(); });

std::vector<Ort::Value> rescore_outputs =
rescore_session_->Run(Ort::RunOptions{nullptr}, rescore_in_names.data(),
rescore_inputs.data(), rescore_inputs.size(),
rescore_out_names.data(), rescore_out_names.size());

float* decoder_outs_data = rescore_outputs[0].GetTensorMutableData<float>();
float* r_decoder_outs_data = rescore_outputs[1].GetTensorMutableData<float>();
Expand Down
10 changes: 5 additions & 5 deletions runtime/core/decoder/onnx_asr_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class OnnxAsrModel : public AsrModel {
std::vector<float>* rescoring_score) override;
std::shared_ptr<AsrModel> Copy() const override;
void GetInputOutputInfo(const std::shared_ptr<Ort::Session>& session,
std::vector<const char*>* in_names,
std::vector<const char*>* out_names);
std::vector<std::string>* in_names,
std::vector<std::string>* out_names);

protected:
void ForwardEncoderFunc(const std::vector<std::vector<float>>& chunk_feats,
Expand All @@ -70,9 +70,9 @@ class OnnxAsrModel : public AsrModel {
std::shared_ptr<Ort::Session> ctc_session_ = nullptr;

// node names
std::vector<const char*> encoder_in_names_, encoder_out_names_;
std::vector<const char*> ctc_in_names_, ctc_out_names_;
std::vector<const char*> rescore_in_names_, rescore_out_names_;
std::vector<std::string> encoder_in_names_, encoder_out_names_;
std::vector<std::string> ctc_in_names_, ctc_out_names_;
std::vector<std::string> rescore_in_names_, rescore_out_names_;

// caches
Ort::Value att_cache_ort_{nullptr};
Expand Down
Loading