diff --git a/pipeless/src/stages/inference/onnx.rs b/pipeless/src/stages/inference/onnx.rs
index 0b780cb..18e3169 100644
--- a/pipeless/src/stages/inference/onnx.rs
+++ b/pipeless/src/stages/inference/onnx.rs
@@ -9,6 +9,7 @@ pub struct OnnxSessionParams {
     execution_mode: Option<String>, // Parallel or sequential exeuction mode or onnx
     inter_threads: Option<i16>, // If execution mode is Parallel (and nodes can be run in parallel), this sets the maximum number of threads to use to run them in parallel.
     intra_threads: Option<i16>, // Number of threads to parallelize the execution within nodes
+    custom_op_lib_path: Option<String>, // Path to a custom op library
     /*ir_version: Option<u32>,
     opset_version: Option<u32>,
     image_shape_format: Option<Vec<String>>,
@@ -20,13 +21,15 @@ impl OnnxSessionParams {
     pub fn new(
         stage_name: &str, execution_provider: &str, execution_mode: Option<&str>,
-        inter_threads: Option<i16>, intra_threads: Option<i16>
+        inter_threads: Option<i16>, intra_threads: Option<i16>,
+        custom_op_lib_path: Option<&str>,
     ) -> Self {
         Self {
             stage_name: stage_name.to_string(),
             execution_provider: execution_provider.to_string(),
             execution_mode: execution_mode.map(|m| m.to_string()),
             inter_threads, intra_threads,
+            custom_op_lib_path: custom_op_lib_path.map(|p| p.to_string()),
         }
     }
 }
 
@@ -82,25 +85,41 @@ impl OnnxSession {
             }
         }
 
+        if let Some(lib_path) = onnx_params.custom_op_lib_path {
+            log::info!("Loading custom operations lib from: {}", lib_path);
+            session_builder = session_builder.with_custom_op_lib(&lib_path).unwrap();
+        }
+
         let session = session_builder.with_model_from_file(model_file_path).unwrap();
 
         // Run a first test inference that usually takes more time.
         // This avoids to add an initial delay to the stream when it arrives, making the session ready
-        let input0_shape: Vec<usize> = session.inputs[0].dimensions()
-            .map(std::option::Option::unwrap)
-            .collect();
-        // Assuming the conventional input format: batch, channels, height, witdh
-        let batch_shift = if input0_shape.len() > 3 { 1 } else { 0 };
-        let width = input0_shape[2 + batch_shift];
-        let height = input0_shape[1 + batch_shift];
-        let channels = input0_shape[0 + batch_shift];
-        let test_image = ndarray::Array3::<f32>::zeros((channels, height, width)).into_dyn();
-        let cow_array = ndarray::CowArray::from(test_image);
-        let ort_input_value = ort::Value::from_array(
-            session.allocator(),
-            &cow_array
-        ).unwrap();
-        let _ = session.run(vec![ort_input_value]);
+        let input0_shape: Vec<Option<usize>> = session.inputs[0].dimensions().map(|x| x).collect();
+        if input0_shape.len() > 2 {
+            // Assuming the conventional input format: batch, channels, height, width
+            let batch_shift = if input0_shape.len() > 3 { 1 } else { 0 };
+            let width = input0_shape[2 + batch_shift];
+            let height = input0_shape[1 + batch_shift];
+            let channels = input0_shape[0 + batch_shift];
+            if let (Some(width), Some(height), Some(channels)) = (width, height, channels) {
+                let test_image = ndarray::Array3::<f32>::zeros((channels, height, width)).into_dyn();
+                let cow_array = ndarray::CowArray::from(test_image);
+                let ort_input_value = ort::Value::from_array(
+                    session.allocator(),
+                    &cow_array
+                ).unwrap();
+                let _ = session.run(vec![ort_input_value]);
+            } else {
+                warn!(
+                    "Could not run an inference test because the model input shape was not properly recognized. Obtained: width: {:?}, height: {:?}, channels: {:?}",
+                    width.map(|num| num.to_string()).unwrap_or_else(|| "None".to_string()), // Print the number on the option or "None"
+                    height.map(|num| num.to_string()).unwrap_or_else(|| "None".to_string()),
+                    channels.map(|num| num.to_string()).unwrap_or_else(|| "None".to_string())
+                );
+            }
+        } else {
+            warn!("Could not run an inference test because the model input shape does not contain all the image dimensions");
+        }
 
         Ok(Self { session })
     } else {
@@ -112,14 +131,16 @@ impl OnnxSession {
 
 impl super::session::SessionTrait for OnnxSession {
     fn infer(&self, mut frame: pipeless::data::Frame) -> pipeless::data::Frame {
+        // TODO: automatically resize and transpose the input image to the expected by the model
+
+        // FIXME: we are forcing users to provide float32 arrays which will produce the inference to fail if the model expects uint values.
+
         let input_data = frame.get_inference_input().to_owned();
         if input_data.len() == 0 {
             warn!("No inference input data was provided. Did you forget to add it at your pre-process hook?");
             return frame;
         }
 
-        // TODO: automatically resize and traspose the input image to the expected by the model
-
         let input_vec = input_data.view().insert_axis(ndarray::Axis(0)).into_dyn(); // Batch image with batch size 1
         let cow_array = ndarray::CowArray::from(input_vec);
         let ort_input_value_result = ort::Value::from_array(
diff --git a/pipeless/src/stages/inference/session.rs b/pipeless/src/stages/inference/session.rs
index 54ff160..e584cbc 100644
--- a/pipeless/src/stages/inference/session.rs
+++ b/pipeless/src/stages/inference/session.rs
@@ -42,12 +42,14 @@ impl SessionParams {
                     warn!("'execution_mode' must be set to 'Parallel' for 'inter_threads' to take effect");
                 }
                 let intra_threads = data["intra_threads"].as_i64();
+                let custom_op_lib_path = data["custom_op_lib_path"].as_str();
                 SessionParams::Onnx(
                     OnnxSessionParams::new(
                         stage_name, execution_provider, execution_mode,
                         inter_threads.map(|t| t as i16),
-                        intra_threads.map(|t| t as i16)
+                        intra_threads.map(|t| t as i16),
+                        custom_op_lib_path,
                     ))
             },
             super::runtime::InferenceRuntime::Openvino => unimplemented!(),