varpulis_runtime/scoring.rs

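/// Execution provider for ONNX scoring: plain CPU (the default), CUDA on a
/// given GPU device, or TensorRT on a given GPU device.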
#[derive(Debug, Clone, Default)]
pub enum GpuProvider {
    #[default]
    Cpu,
    Cuda { device_id: i32 },
    TensorRT { device_id: i32 },
}

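/// Configuration handed to `OnnxModel::load`: which execution provider to
/// register and the batch size callers should use when grouping events.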
#[derive(Debug, Clone)]
pub struct GpuConfig {
    pub provider: GpuProvider,
    pub batch_size: usize,
}

impl Default for GpuConfig {
    fn default() -> Self {
        Self {
            provider: GpuProvider::Cpu,
            batch_size: 1,
        }
    }
}

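// Real scoring backend, compiled only when the `onnx` feature is enabled.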
#[cfg(feature = "onnx")]
mod inner {
    use std::sync::Arc;

    use ort::session::Session;
    use ort::value::Tensor;

    use super::{GpuConfig, GpuProvider};
    use crate::event::Event;

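    /// A loaded ONNX model plus the mapping between event fields and the
    /// model's flat `f32` input/output vectors. The session sits behind an
    /// `Arc<Mutex<...>>` so the model can be shared while inference calls are
    /// serialized.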
    pub struct OnnxModel {
        session: Arc<std::sync::Mutex<Session>>,
        input_name: String,
        pub input_fields: Vec<String>,
        pub output_fields: Vec<String>,
        pub batch_size: usize,
    }

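    // Manual Debug impl: the ort session handle is omitted, and the omission
    // is marked via `finish_non_exhaustive`.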
    impl std::fmt::Debug for OnnxModel {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("OnnxModel")
                .field("input_name", &self.input_name)
                .field("input_fields", &self.input_fields)
                .field("output_fields", &self.output_fields)
                .field("batch_size", &self.batch_size)
                .finish_non_exhaustive()
        }
    }

    impl OnnxModel {
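        /// Loads an ONNX model from `path`, optionally registering a GPU
        /// execution provider, and records which event fields feed the model
        /// and which output names scores are published under.
        ///
        /// Illustrative call (the crate path, model path, and field names
        /// below are placeholders, not part of this API):
        ///
        /// ```ignore
        /// use varpulis_runtime::scoring::{GpuConfig, GpuProvider, OnnxModel};
        ///
        /// let model = OnnxModel::load(
        ///     "models/scorer.onnx",
        ///     vec!["amount".to_string(), "latency_ms".to_string()],
        ///     vec!["anomaly_score".to_string()],
        ///     Some(GpuConfig { provider: GpuProvider::Cuda { device_id: 0 }, batch_size: 32 }),
        /// )?;
        /// ```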
        pub fn load(
            path: &str,
            inputs: Vec<String>,
            outputs: Vec<String>,
            gpu_config: Option<GpuConfig>,
        ) -> Result<Self, String> {
            let mut builder =
                Session::builder().map_err(|e| format!("ONNX session builder error: {}", e))?;

            if let Some(ref config) = gpu_config {
                match &config.provider {
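                    // CPU is ONNX Runtime's default; nothing extra to register.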
                    GpuProvider::Cpu => {}
                    GpuProvider::Cuda { device_id } => {
                        builder = builder
                            .with_execution_providers([
                                ort::execution_providers::CUDAExecutionProvider::default()
                                    .with_device_id(*device_id)
                                    .build(),
                            ])
                            .map_err(|e| format!("CUDA provider error: {}", e))?;
                    }
                    GpuProvider::TensorRT { device_id } => {
                        builder = builder
                            .with_execution_providers([
                                ort::execution_providers::TensorRTExecutionProvider::default()
                                    .with_device_id(*device_id)
                                    .build(),
                            ])
                            .map_err(|e| format!("TensorRT provider error: {}", e))?;
                    }
                }
            }

            let session = builder
                .commit_from_file(path)
                .map_err(|e| format!("Failed to load ONNX model '{}': {}", path, e))?;

            let input_name = session
                .inputs()
                .first()
                .map(|i| i.name().to_string())
                .unwrap_or_else(|| "input".to_string());

            let batch_size = gpu_config.as_ref().map(|c| c.batch_size).unwrap_or(1);

            Ok(OnnxModel {
                session: Arc::new(std::sync::Mutex::new(session)),
                input_name,
                input_fields: inputs,
                output_fields: outputs,
                batch_size,
            })
        }

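        /// Scores a single event: builds a `[1, n]` feature tensor from the
        /// configured input fields and returns one `(output_field, value)`
        /// pair per configured output field.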
        pub fn infer(&self, event: &Event) -> Result<Vec<(String, f64)>, String> {
            let n = self.input_fields.len();
            let mut input_data = Vec::with_capacity(n);

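            // Assemble the feature vector in declared field order; float and
            // integer event fields are accepted and converted to f32.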
            for field in &self.input_fields {
                let val = if let Some(f) = event.get_float(field) {
                    f as f32
                } else if let Some(i) = event.get_int(field) {
                    i as f32
                } else {
                    return Err(format!("Missing input field '{}' in event", field));
                };
                input_data.push(val);
            }

            let input_tensor = Tensor::from_array((vec![1_i64, n as i64], input_data))
                .map_err(|e| format!("Tensor creation error: {}", e))?;

            let mut session = self
                .session
                .lock()
                .map_err(|e| format!("Session lock error: {}", e))?;

            let outputs = session
                .run(ort::inputs![self.input_name.as_str() => input_tensor])
                .map_err(|e| format!("ONNX inference error: {}", e))?;

            let output_value = &outputs[0];

            let (_, raw_data) = output_value
                .try_extract_tensor::<f32>()
                .map_err(|e| format!("Output tensor extract error: {}", e))?;

            let mut results = Vec::with_capacity(self.output_fields.len());
            for (i, field) in self.output_fields.iter().enumerate() {
                let val = raw_data.get(i).copied().unwrap_or(0.0f32);
                results.push((field.clone(), val as f64));
            }

            Ok(results)
        }

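        /// Scores several events in one session run: stacks the per-event
        /// feature vectors into a `[batch, n_features]` tensor and splits the
        /// first output back into one result vector per event.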
        pub fn infer_batch(&self, events: &[&Event]) -> Result<Vec<Vec<(String, f64)>>, String> {
            if events.is_empty() {
                return Ok(Vec::new());
            }

            let n_features = self.input_fields.len();
            let batch = events.len();
            let mut input_data = Vec::with_capacity(batch * n_features);

            for event in events {
                for field in &self.input_fields {
                    let val = if let Some(f) = event.get_float(field) {
                        f as f32
                    } else if let Some(i) = event.get_int(field) {
                        i as f32
                    } else {
                        return Err(format!("Missing input field '{}' in event", field));
                    };
                    input_data.push(val);
                }
            }

            let input_tensor =
                Tensor::from_array((vec![batch as i64, n_features as i64], input_data))
                    .map_err(|e| format!("Batch tensor creation error: {}", e))?;

            let mut session = self
                .session
                .lock()
                .map_err(|e| format!("Session lock error: {}", e))?;

            let outputs = session
                .run(ort::inputs![self.input_name.as_str() => input_tensor])
                .map_err(|e| format!("ONNX batch inference error: {}", e))?;

            let output_value = &outputs[0];
            let (_, raw_data) = output_value
                .try_extract_tensor::<f32>()
                .map_err(|e| format!("Batch output tensor extract error: {}", e))?;

            let n_outputs = self.output_fields.len();
            let mut batch_results = Vec::with_capacity(batch);

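            // Demultiplex the flat output buffer, assuming a row-major
            // [batch, n_outputs] layout for the first output tensor.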
            for i in 0..batch {
                let mut results = Vec::with_capacity(n_outputs);
                for (j, field) in self.output_fields.iter().enumerate() {
                    let idx = i * n_outputs + j;
                    let val = raw_data.get(idx).copied().unwrap_or(0.0f32);
                    results.push((field.clone(), val as f64));
                }
                batch_results.push(results);
            }

            Ok(batch_results)
        }
    }
}

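// Stub backend used when the `onnx` feature is disabled: same surface as the
// real type, but every operation fails with a hint to rebuild with the feature.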
#[cfg(not(feature = "onnx"))]
mod inner {
    use super::GpuConfig;
    use crate::event::Event;

    #[derive(Debug)]
    pub struct OnnxModel {
        pub input_fields: Vec<String>,
        pub output_fields: Vec<String>,
        pub batch_size: usize,
    }

    impl OnnxModel {
        pub fn load(
            _path: &str,
            _inputs: Vec<String>,
            _outputs: Vec<String>,
            _gpu_config: Option<GpuConfig>,
        ) -> Result<Self, String> {
            Err(
                ".score() requires the 'onnx' feature — rebuild with: cargo build --features onnx"
                    .to_string(),
            )
        }

        pub fn infer(&self, _event: &Event) -> Result<Vec<(String, f64)>, String> {
            Err(".score() requires the 'onnx' feature".to_string())
        }

        pub fn infer_batch(&self, _events: &[&Event]) -> Result<Vec<Vec<(String, f64)>>, String> {
            Err(".score() requires the 'onnx' feature".to_string())
        }
    }
}

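// Both backends export the same `OnnxModel` type, so callers compile the same
// way whether or not the `onnx` feature is enabled.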
pub use inner::OnnxModel;