Skip to content

Commit

Permalink
Merge pull request #128 from pipeless-ai/custom_op_lib
Browse files Browse the repository at this point in the history
feat(onnxruntime): Support custom op library
  • Loading branch information
miguelaeh authored Jan 29, 2024
2 parents 5abb981 + 643fec2 commit 9c09485
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 19 deletions.
57 changes: 39 additions & 18 deletions pipeless/src/stages/inference/onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub struct OnnxSessionParams {
execution_mode: Option<String>, // Parallel or sequential execution mode for onnx
inter_threads: Option<i16>, // If execution mode is Parallel (and nodes can be run in parallel), this sets the maximum number of threads to use to run them in parallel.
intra_threads: Option<i16>, // Number of threads to parallelize the execution within nodes
custom_op_lib_path: Option<String>, // Path to a custom op library
/*ir_version: Option<u32>,
opset_version: Option<u32>,
image_shape_format: Option<Vec<String>>,
Expand All @@ -20,13 +21,15 @@ impl OnnxSessionParams {
pub fn new(
stage_name: &str,
execution_provider: &str, execution_mode: Option<&str>,
inter_threads: Option<i16>, intra_threads: Option<i16>
inter_threads: Option<i16>, intra_threads: Option<i16>,
custom_op_lib_path: Option<&str>,
) -> Self {
Self {
stage_name: stage_name.to_string(),
execution_provider: execution_provider.to_string(),
execution_mode: execution_mode.map(|m| m.to_string()),
inter_threads, intra_threads,
custom_op_lib_path: custom_op_lib_path.map(|p| p.to_string()),
}
}
}
Expand Down Expand Up @@ -82,25 +85,41 @@ impl OnnxSession {
}
}

if let Some(lib_path) = onnx_params.custom_op_lib_path {
log::info!("Loading custom operations lib from: {}", lib_path);
session_builder = session_builder.with_custom_op_lib(&lib_path).unwrap();
}

let session = session_builder.with_model_from_file(model_file_path).unwrap();

// Run a first test inference that usually takes more time.
// This avoids adding an initial delay to the stream when it arrives, making the session ready
let input0_shape: Vec<usize> = session.inputs[0].dimensions()
.map(std::option::Option::unwrap)
.collect();
// Assuming the conventional input format: batch, channels, height, width
let batch_shift = if input0_shape.len() > 3 { 1 } else { 0 };
let width = input0_shape[2 + batch_shift];
let height = input0_shape[1 + batch_shift];
let channels = input0_shape[0 + batch_shift];
let test_image = ndarray::Array3::<u8>::zeros((channels, height, width)).into_dyn();
let cow_array = ndarray::CowArray::from(test_image);
let ort_input_value = ort::Value::from_array(
session.allocator(),
&cow_array
).unwrap();
let _ = session.run(vec![ort_input_value]);
let input0_shape: Vec<Option<usize>> = session.inputs[0].dimensions().map(|x| x).collect();
if input0_shape.len() > 2 {
// Assuming the conventional input format: batch, channels, height, width
let batch_shift = if input0_shape.len() > 3 { 1 } else { 0 };
let width = input0_shape[2 + batch_shift];
let height = input0_shape[1 + batch_shift];
let channels = input0_shape[0 + batch_shift];
if let (Some(width), Some(height), Some(channels)) = (width, height, channels) {
let test_image = ndarray::Array3::<u8>::zeros((channels, height, width)).into_dyn();
let cow_array = ndarray::CowArray::from(test_image);
let ort_input_value = ort::Value::from_array(
session.allocator(),
&cow_array
).unwrap();
let _ = session.run(vec![ort_input_value]);
} else {
warn!(
"Could not run an inference test because the model input shape was not properly recognized. Obtained: width: {:?}, height: {:?}, channels: {:?}",
width.map(|num| num.to_string()).unwrap_or_else(|| "None".to_string()), // Print the number on the option or "None"
height.map(|num| num.to_string()).unwrap_or_else(|| "None".to_string()),
channels.map(|num| num.to_string()).unwrap_or_else(|| "None".to_string())
);
}
} else {
warn!("Could not run an inference test because the model input shape does not contain all the image dimensions");
}

Ok(Self { session })
} else {
Expand All @@ -112,14 +131,16 @@ impl OnnxSession {

impl super::session::SessionTrait for OnnxSession {
fn infer(&self, mut frame: pipeless::data::Frame) -> pipeless::data::Frame {
// TODO: automatically resize and transpose the input image to the shape expected by the model

// FIXME: we are forcing users to provide float32 arrays which will produce the inference to fail if the model expects uint values.

let input_data = frame.get_inference_input().to_owned();
if input_data.len() == 0 {
warn!("No inference input data was provided. Did you forget to add it at your pre-process hook?");
return frame;
}

// TODO: automatically resize and transpose the input image to the shape expected by the model

let input_vec = input_data.view().insert_axis(ndarray::Axis(0)).into_dyn(); // Batch image with batch size 1
let cow_array = ndarray::CowArray::from(input_vec);
let ort_input_value_result = ort::Value::from_array(
Expand Down
4 changes: 3 additions & 1 deletion pipeless/src/stages/inference/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ impl SessionParams {
warn!("'execution_mode' must be set to 'Parallel' for 'inter_threads' to take effect");
}
let intra_threads = data["intra_threads"].as_i64();
let custom_op_lib_path = data["custom_op_lib_path"].as_str();
SessionParams::Onnx(
OnnxSessionParams::new(
stage_name,
execution_provider, execution_mode,
inter_threads.map(|t| t as i16),
intra_threads.map(|t| t as i16)
intra_threads.map(|t| t as i16),
custom_op_lib_path,
))
},
super::runtime::InferenceRuntime::Openvino => unimplemented!(),
Expand Down

0 comments on commit 9c09485

Please sign in to comment.