// varpulis_runtime/scoring.rs
//! ONNX model scoring for the `.score()` operator.
//!
//! When the `onnx` feature is enabled, provides `OnnxModel`, which loads an ONNX model
//! and runs inference against event fields. Without the feature, a stub is provided
//! that returns an error at compile time (engine compilation, not Rust compilation).
//!
//! GPU acceleration is available when both the `gpu` and `onnx` features are enabled.
//! The `gpu` feature implies `onnx`, so enabling `gpu` is sufficient.

10/// GPU execution provider configuration
11#[derive(Debug, Clone, Default)]
12pub enum GpuProvider {
13    /// CPU execution (default)
14    #[default]
15    Cpu,
16    /// CUDA GPU execution
17    Cuda { device_id: i32 },
18    /// TensorRT GPU execution (optimized inference)
19    TensorRT { device_id: i32 },
20}
21
22/// GPU and batching configuration for ONNX inference
23#[derive(Debug, Clone)]
24pub struct GpuConfig {
25    pub provider: GpuProvider,
26    pub batch_size: usize,
27}
28
29impl Default for GpuConfig {
30    fn default() -> Self {
31        Self {
32            provider: GpuProvider::Cpu,
33            batch_size: 1,
34        }
35    }
36}
37
38#[cfg(feature = "onnx")]
39mod inner {
40    use std::sync::Arc;
41
42    use ort::session::Session;
43    use ort::value::Tensor;
44
45    use super::{GpuConfig, GpuProvider};
46    use crate::event::Event;
47
48    /// An ONNX model loaded into ONNX Runtime for per-event or batch inference.
49    pub struct OnnxModel {
50        session: Arc<std::sync::Mutex<Session>>,
51        input_name: String,
52        pub input_fields: Vec<String>,
53        pub output_fields: Vec<String>,
54        pub batch_size: usize,
55    }
56
57    impl std::fmt::Debug for OnnxModel {
58        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59            f.debug_struct("OnnxModel")
60                .field("input_name", &self.input_name)
61                .field("input_fields", &self.input_fields)
62                .field("output_fields", &self.output_fields)
63                .field("batch_size", &self.batch_size)
64                .finish_non_exhaustive()
65        }
66    }
67
68    impl OnnxModel {
69        /// Load an ONNX model from a file path with optional GPU configuration.
70        pub fn load(
71            path: &str,
72            inputs: Vec<String>,
73            outputs: Vec<String>,
74            gpu_config: Option<GpuConfig>,
75        ) -> Result<Self, String> {
76            let mut builder =
77                Session::builder().map_err(|e| format!("ONNX session builder error: {}", e))?;
78
79            // Configure execution provider based on GPU settings
80            if let Some(ref config) = gpu_config {
81                match &config.provider {
82                    GpuProvider::Cpu => {
83                        // Default CPU provider, no extra config needed
84                    }
85                    GpuProvider::Cuda { device_id } => {
86                        builder = builder
87                            .with_execution_providers([
88                                ort::execution_providers::CUDAExecutionProvider::default()
89                                    .with_device_id(*device_id)
90                                    .build(),
91                            ])
92                            .map_err(|e| format!("CUDA provider error: {}", e))?;
93                    }
94                    GpuProvider::TensorRT { device_id } => {
95                        builder = builder
96                            .with_execution_providers([
97                                ort::execution_providers::TensorRTExecutionProvider::default()
98                                    .with_device_id(*device_id)
99                                    .build(),
100                            ])
101                            .map_err(|e| format!("TensorRT provider error: {}", e))?;
102                    }
103                }
104            }
105
106            let session = builder
107                .commit_from_file(path)
108                .map_err(|e| format!("Failed to load ONNX model '{}': {}", path, e))?;
109
110            // Read the first input tensor name from the model
111            let input_name = session
112                .inputs()
113                .first()
114                .map(|i| i.name().to_string())
115                .unwrap_or_else(|| "input".to_string());
116
117            let batch_size = gpu_config.as_ref().map(|c| c.batch_size).unwrap_or(1);
118
119            Ok(OnnxModel {
120                session: Arc::new(std::sync::Mutex::new(session)),
121                input_name,
122                input_fields: inputs,
123                output_fields: outputs,
124                batch_size,
125            })
126        }
127
128        /// Run inference on a single event, returning `(field_name, value)` pairs.
129        pub fn infer(&self, event: &Event) -> Result<Vec<(String, f64)>, String> {
130            let n = self.input_fields.len();
131            let mut input_data = Vec::with_capacity(n);
132
133            for field in &self.input_fields {
134                let val = if let Some(f) = event.get_float(field) {
135                    f as f32
136                } else if let Some(i) = event.get_int(field) {
137                    i as f32
138                } else {
139                    return Err(format!("Missing input field '{}' in event", field));
140                };
141                input_data.push(val);
142            }
143
144            // Use (shape, data) tuple form — compatible with ort's OwnedTensorArrayData
145            let input_tensor = Tensor::from_array((vec![1_i64, n as i64], input_data))
146                .map_err(|e| format!("Tensor creation error: {}", e))?;
147
148            let mut session = self
149                .session
150                .lock()
151                .map_err(|e| format!("Session lock error: {}", e))?;
152
153            let outputs = session
154                .run(ort::inputs![self.input_name.as_str() => input_tensor])
155                .map_err(|e| format!("ONNX inference error: {}", e))?;
156
157            // Index the first output tensor by position
158            let output_value = &outputs[0];
159
160            let (_, raw_data) = output_value
161                .try_extract_tensor::<f32>()
162                .map_err(|e| format!("Output tensor extract error: {}", e))?;
163
164            let mut results = Vec::with_capacity(self.output_fields.len());
165            for (i, field) in self.output_fields.iter().enumerate() {
166                let val = raw_data.get(i).copied().unwrap_or(0.0f32);
167                results.push((field.clone(), val as f64));
168            }
169
170            Ok(results)
171        }
172
173        /// Run batch inference on multiple events, returning per-event results.
174        ///
175        /// Constructs a [batch_size, features] tensor and runs a single inference call.
176        /// More efficient than calling `infer()` per event when GPU is enabled.
177        pub fn infer_batch(&self, events: &[&Event]) -> Result<Vec<Vec<(String, f64)>>, String> {
178            if events.is_empty() {
179                return Ok(Vec::new());
180            }
181
182            let n_features = self.input_fields.len();
183            let batch = events.len();
184            let mut input_data = Vec::with_capacity(batch * n_features);
185
186            for event in events {
187                for field in &self.input_fields {
188                    let val = if let Some(f) = event.get_float(field) {
189                        f as f32
190                    } else if let Some(i) = event.get_int(field) {
191                        i as f32
192                    } else {
193                        return Err(format!("Missing input field '{}' in event", field));
194                    };
195                    input_data.push(val);
196                }
197            }
198
199            let input_tensor =
200                Tensor::from_array((vec![batch as i64, n_features as i64], input_data))
201                    .map_err(|e| format!("Batch tensor creation error: {}", e))?;
202
203            let mut session = self
204                .session
205                .lock()
206                .map_err(|e| format!("Session lock error: {}", e))?;
207
208            let outputs = session
209                .run(ort::inputs![self.input_name.as_str() => input_tensor])
210                .map_err(|e| format!("ONNX batch inference error: {}", e))?;
211
212            let output_value = &outputs[0];
213            let (_, raw_data) = output_value
214                .try_extract_tensor::<f32>()
215                .map_err(|e| format!("Batch output tensor extract error: {}", e))?;
216
217            let n_outputs = self.output_fields.len();
218            let mut batch_results = Vec::with_capacity(batch);
219
220            for i in 0..batch {
221                let mut results = Vec::with_capacity(n_outputs);
222                for (j, field) in self.output_fields.iter().enumerate() {
223                    let idx = i * n_outputs + j;
224                    let val = raw_data.get(idx).copied().unwrap_or(0.0f32);
225                    results.push((field.clone(), val as f64));
226                }
227                batch_results.push(results);
228            }
229
230            Ok(batch_results)
231        }
232    }
233}
234
235#[cfg(not(feature = "onnx"))]
236mod inner {
237    use super::GpuConfig;
238    use crate::event::Event;
239
240    /// Stub `OnnxModel` when the `onnx` feature is not enabled.
241    #[derive(Debug)]
242    pub struct OnnxModel {
243        pub input_fields: Vec<String>,
244        pub output_fields: Vec<String>,
245        pub batch_size: usize,
246    }
247
248    impl OnnxModel {
249        pub fn load(
250            _path: &str,
251            _inputs: Vec<String>,
252            _outputs: Vec<String>,
253            _gpu_config: Option<GpuConfig>,
254        ) -> Result<Self, String> {
255            Err(
256                ".score() requires the 'onnx' feature — rebuild with: cargo build --features onnx"
257                    .to_string(),
258            )
259        }
260
261        pub fn infer(&self, _event: &Event) -> Result<Vec<(String, f64)>, String> {
262            Err(".score() requires the 'onnx' feature".to_string())
263        }
264
265        pub fn infer_batch(&self, _events: &[&Event]) -> Result<Vec<Vec<(String, f64)>>, String> {
266            Err(".score() requires the 'onnx' feature".to_string())
267        }
268    }
269}
270
271pub use inner::OnnxModel;