1use crate::{Error, Result};
7use candle_core::{Device, Tensor};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::{Arc, RwLock};
12
/// Machine-learning inference backends this crate can dispatch to.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum MLFramework {
    /// Native Rust inference via the `candle` crate (linked into this binary).
    Candle,
    /// ONNX Runtime backend.
    OnnxRuntime,
    /// TensorFlow Lite backend.
    TensorFlowLite,
    /// PyTorch (libtorch) backend.
    PyTorch,
    /// User-supplied backend implementation.
    Custom,
}
27
/// Top-level configuration for framework selection and execution behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLFrameworkConfig {
    /// Framework tried first when creating a session.
    pub primary_framework: MLFramework,
    /// Frameworks tried, in order, when the primary cannot run a model.
    pub fallback_frameworks: Vec<MLFramework>,
    /// Where inference should run (CPU/GPU/auto/custom).
    pub device_preference: DevicePreference,
    /// Model optimization knobs (quantization, pruning, fusion, ...).
    pub optimization: ModelOptimization,
    /// Memory limits and pooling behavior.
    pub memory_config: MemoryConfig,
    /// Threading, batching, and caching behavior.
    pub performance_config: PerformanceConfig,
    /// Per-framework tuning (library paths, provider options, ...).
    pub framework_settings: HashMap<MLFramework, FrameworkSettings>,
}
46
/// Where inference should execute.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DevicePreference {
    /// Always run on the CPU.
    Cpu,
    /// Run on a GPU.
    Gpu {
        /// Specific GPU ordinal; `None` selects a default device.
        device_index: Option<usize>,
        /// Optional cap on GPU memory use, in megabytes.
        memory_limit_mb: Option<usize>,
    },
    /// Probe for a GPU and fall back to CPU if none is usable.
    Auto,
    /// Backend-specific device description.
    Custom {
        /// Opaque device identifier interpreted by the backend.
        device_id: String,
        /// Free-form capability hints for the backend.
        capabilities: HashMap<String, String>,
    },
}
69
/// Model-level optimization switches applied when preparing a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelOptimization {
    /// Whether weight/activation quantization is applied.
    pub quantization_enabled: bool,
    /// Precision used when quantization is enabled.
    pub quantization_precision: QuantizationPrecision,
    /// Whether weight pruning is applied.
    pub pruning_enabled: bool,
    /// Fraction of weights to prune (0.0..=1.0) when pruning is enabled.
    pub pruning_ratio: f32,
    /// Whether knowledge distillation is applied.
    pub distillation_enabled: bool,
    /// Fuse adjacent operators where the backend supports it.
    pub operator_fusion: bool,
    /// Pre-compute constant subgraphs at load time.
    pub constant_folding: bool,
}
88
89#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
91pub enum QuantizationPrecision {
92 Int8,
94 Int16,
96 Float16,
98 Dynamic,
100}
101
/// Memory budget and allocation behavior for inference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Hard cap on memory usage, in megabytes.
    pub max_memory_mb: usize,
    /// Size of the pre-allocated memory pool, in megabytes.
    pub memory_pool_size_mb: usize,
    /// Enable memory-usage optimizations.
    pub memory_optimization_enabled: bool,
    // Collection cadence; units are defined by the consumer of this config —
    // NOTE(review): confirm whether this counts inferences or seconds.
    pub gc_frequency: usize,
    /// Memory-map model files instead of reading them fully into RAM.
    pub memory_mapping_enabled: bool,
}
116
/// Throughput/latency tuning for inference execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceConfig {
    /// Number of CPU threads to use; `None` lets the backend decide.
    pub cpu_threads: Option<usize>,
    /// Number of items processed per inference call.
    pub batch_size: usize,
    /// Run inference asynchronously where supported.
    pub async_execution: bool,
    /// Number of batches to prefetch ahead of execution.
    pub prefetch_buffer_size: usize,
    /// Overlap stages of the model across devices/threads.
    pub pipeline_parallelism: bool,
    /// Cache loaded models for reuse across sessions.
    pub model_caching: bool,
}
133
/// Backend-specific settings, keyed free-form so each framework can define
/// its own parameter names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FrameworkSettings {
    /// Path to the backend's shared library, if it must be loaded explicitly.
    pub library_path: Option<PathBuf>,
    /// Parameters passed at framework initialization.
    pub init_params: HashMap<String, String>,
    /// Execution-provider options (e.g. CUDA provider settings).
    pub provider_options: HashMap<String, String>,
    /// Per-session configuration entries.
    pub session_config: HashMap<String, String>,
}
146
/// Registry entry describing a model: identity, I/O contract, and capabilities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLModelMetadata {
    /// Unique model name; used as the registry key.
    pub name: String,
    /// Model version string.
    pub version: String,
    /// Framework the model was exported for.
    pub framework: MLFramework,
    /// Expected input tensors.
    pub input_specs: Vec<TensorSpec>,
    /// Produced output tensors.
    pub output_specs: Vec<TensorSpec>,
    /// Location of the model file on disk.
    pub model_path: PathBuf,
    /// Size of the model file, in bytes.
    pub model_size_bytes: u64,
    /// Audio sample rates (Hz) the model accepts.
    pub supported_sample_rates: Vec<u32>,
    /// Runtime capabilities of the model.
    pub capabilities: ModelCapabilities,
}
169
/// Shape and type contract for a single model input or output tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSpec {
    /// Tensor name as exported by the model.
    pub name: String,
    /// Dimensions; by convention a negative value marks a dynamic axis
    /// (see the `-1` usage in the tests below).
    pub shape: Vec<i64>,
    /// Element data type.
    pub data_type: TensorDataType,
    /// Optional human-readable description.
    pub description: Option<String>,
}
182
/// Element types a tensor may carry.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TensorDataType {
    /// 32-bit IEEE float.
    Float32,
    /// 64-bit IEEE float.
    Float64,
    /// 32-bit signed integer.
    Int32,
    /// 64-bit signed integer.
    Int64,
    /// 8-bit unsigned integer.
    UInt8,
    /// 8-bit signed integer.
    Int8,
    /// 16-bit IEEE float.
    Float16,
}
201
/// What a model can do at runtime; consulted when matching models to
/// frameworks and workloads.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCapabilities {
    /// Fast enough for real-time use.
    pub realtime_capable: bool,
    /// Supports batched inputs.
    pub batch_capable: bool,
    /// Supports incremental/streaming input.
    pub streaming_capable: bool,
    /// Can take advantage of GPU acceleration.
    pub gpu_accelerated: bool,
    /// Can run in a quantized form.
    pub quantization_support: bool,
    /// Maximum supported input length, if bounded.
    pub max_input_length: Option<usize>,
}
218
/// A live inference session: one loaded model bound to one framework.
///
/// Created by [`MLFrameworkManager::create_session`] and stored in the
/// manager's session table.
pub struct MLInferenceSession {
    /// Backend actually selected for this session.
    framework: MLFramework,
    /// Snapshot of the model's registry metadata.
    model_metadata: MLModelMetadata,
    /// Candle-specific state; `Some` only when `framework == Candle`.
    candle_session: Option<CandleSession>,
    /// Snapshot of the manager configuration at session-creation time.
    config: MLFrameworkConfig,
    /// Per-session metrics, independently lockable so inference over the
    /// shared session table only needs a read lock.
    metrics: Arc<RwLock<InferenceMetrics>>,
}
232
/// Candle-backend state for a session: target device plus loaded model.
pub struct CandleSession {
    /// Device (CPU/CUDA) tensors are placed on.
    device: Device,
    /// Model weights by parameter name.
    model_weights: HashMap<String, Tensor>,
    /// Structural description of the network.
    model_architecture: ModelArchitecture,
}
242
/// High-level network topology used by the Candle backend.
#[derive(Debug, Clone)]
pub enum ModelArchitecture {
    /// Transformer encoder/decoder stack.
    Transformer {
        /// Number of transformer layers.
        num_layers: usize,
        /// Hidden (embedding) dimension.
        hidden_dim: usize,
        /// Number of attention heads.
        num_heads: usize,
    },
    /// Convolutional network followed by fully-connected layers.
    Cnn {
        /// Convolutional layer configurations, in order.
        conv_layers: Vec<ConvLayerConfig>,
        /// Fully-connected layer widths, in order.
        fc_layers: Vec<usize>,
    },
    /// Recurrent network.
    Rnn {
        /// Cell type (LSTM/GRU/vanilla).
        rnn_type: RnnType,
        /// Hidden state size.
        hidden_size: usize,
        /// Number of stacked recurrent layers.
        num_layers: usize,
        /// Process the sequence in both directions.
        bidirectional: bool,
    },
    /// Arbitrary architecture described layer-by-layer.
    Custom {
        /// Human-readable description.
        description: String,
        /// Ordered layer specifications.
        layers: Vec<LayerSpec>,
    },
}
281
/// Configuration of one convolutional layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvLayerConfig {
    /// Input channel count.
    pub in_channels: usize,
    /// Output channel count.
    pub out_channels: usize,
    /// Convolution kernel size.
    pub kernel_size: usize,
    /// Convolution stride.
    pub stride: usize,
    /// Zero-padding applied to the input.
    pub padding: usize,
    /// Activation applied after the convolution.
    pub activation: ActivationFunction,
}
298
/// Recurrent cell variants.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum RnnType {
    /// Long short-term memory cell.
    Lstm,
    /// Gated recurrent unit cell.
    Gru,
    /// Plain (Elman-style) recurrent cell.
    Vanilla,
}
309
/// Supported nonlinearities for network layers.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ActivationFunction {
    /// Rectified linear unit.
    ReLU,
    /// ReLU with a small negative-side slope.
    LeakyReLU,
    /// Hyperbolic tangent.
    Tanh,
    /// Logistic sigmoid.
    Sigmoid,
    /// x * sigmoid(x).
    Swish,
    /// Gaussian error linear unit.
    GELU,
    /// x * tanh(softplus(x)).
    Mish,
}
328
/// Generic layer description for [`ModelArchitecture::Custom`] networks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerSpec {
    /// Layer kind, as a free-form string interpreted by the backend.
    pub layer_type: String,
    /// Numeric layer parameters by name.
    pub parameters: HashMap<String, f32>,
    /// Expected input shape.
    pub input_shape: Vec<usize>,
    /// Produced output shape.
    pub output_shape: Vec<usize>,
}
341
/// Running statistics for a single inference session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceMetrics {
    /// Number of successful inference calls.
    pub inference_count: u64,
    /// Sum of all inference durations, in milliseconds.
    pub total_inference_time_ms: u64,
    /// Mean inference duration (total / count), in milliseconds.
    pub avg_inference_time_ms: f32,
    /// Fastest observed inference, in milliseconds.
    pub min_inference_time_ms: f32,
    /// Slowest observed inference, in milliseconds.
    pub max_inference_time_ms: f32,
    /// Memory-usage statistics.
    pub memory_usage: MemoryUsageStats,
    /// Number of failed inference calls.
    pub error_count: u64,
    /// Wall-clock time of the most recent metrics update.
    pub last_update: std::time::SystemTime,
}
362
363impl Default for InferenceMetrics {
364 fn default() -> Self {
365 Self {
366 inference_count: 0,
367 total_inference_time_ms: 0,
368 avg_inference_time_ms: 0.0,
369 min_inference_time_ms: 0.0,
370 max_inference_time_ms: 0.0,
371 memory_usage: MemoryUsageStats::default(),
372 error_count: 0,
373 last_update: std::time::SystemTime::now(),
374 }
375 }
376}
377
/// Memory-consumption statistics for a session.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct MemoryUsageStats {
    /// Highest memory usage observed, in bytes.
    pub peak_usage_bytes: u64,
    /// Memory currently in use, in bytes.
    pub current_usage_bytes: u64,
    /// Average memory usage, in bytes.
    pub avg_usage_bytes: u64,
    /// Number of allocations performed.
    pub allocation_count: u64,
}
390
/// Central coordinator: tracks available frameworks, registered models, and
/// live inference sessions.
pub struct MLFrameworkManager {
    /// Known backends and their capabilities.
    frameworks: HashMap<MLFramework, FrameworkInfo>,
    /// Live sessions keyed by caller-chosen session id.
    active_sessions: Arc<RwLock<HashMap<String, MLInferenceSession>>>,
    /// Manager-wide configuration.
    config: MLFrameworkConfig,
    /// Registered models keyed by model name.
    model_registry: Arc<RwLock<HashMap<String, MLModelMetadata>>>,
}
402
/// Static information about one available backend.
#[derive(Debug, Clone)]
pub struct FrameworkInfo {
    /// Backend version string.
    version: String,
    /// Execution providers offered (e.g. "CPU", "CUDA").
    providers: Vec<String>,
    /// Whether the backend is ready for use; uninitialized backends are
    /// rejected by framework selection.
    initialized: bool,
    /// Feature capabilities of the backend.
    capabilities: FrameworkCapabilities,
}
415
/// Feature flags describing what a backend supports.
#[derive(Debug, Clone)]
pub struct FrameworkCapabilities {
    /// Can execute on a GPU.
    gpu_support: bool,
    /// Can run quantized models.
    quantization_support: bool,
    /// Accepts inputs with dynamic shapes.
    dynamic_shapes: bool,
    /// Supports streaming inference.
    streaming_support: bool,
}
428
429impl Default for MLFrameworkConfig {
430 fn default() -> Self {
431 Self {
432 primary_framework: MLFramework::Candle,
433 fallback_frameworks: vec![MLFramework::OnnxRuntime],
434 device_preference: DevicePreference::Auto,
435 optimization: ModelOptimization::default(),
436 memory_config: MemoryConfig::default(),
437 performance_config: PerformanceConfig::default(),
438 framework_settings: HashMap::new(),
439 }
440 }
441}
442
443impl Default for ModelOptimization {
444 fn default() -> Self {
445 Self {
446 quantization_enabled: true,
447 quantization_precision: QuantizationPrecision::Int8,
448 pruning_enabled: false,
449 pruning_ratio: 0.1,
450 distillation_enabled: false,
451 operator_fusion: true,
452 constant_folding: true,
453 }
454 }
455}
456
457impl Default for MemoryConfig {
458 fn default() -> Self {
459 Self {
460 max_memory_mb: 4096, memory_pool_size_mb: 512, memory_optimization_enabled: true,
463 gc_frequency: 100,
464 memory_mapping_enabled: true,
465 }
466 }
467}
468
469impl Default for PerformanceConfig {
470 fn default() -> Self {
471 Self {
472 cpu_threads: None, batch_size: 1,
474 async_execution: true,
475 prefetch_buffer_size: 4,
476 pipeline_parallelism: false,
477 model_caching: true,
478 }
479 }
480}
481
482impl MLFrameworkManager {
483 pub fn new(config: MLFrameworkConfig) -> Result<Self> {
485 let mut frameworks = HashMap::new();
486
487 frameworks.insert(
489 MLFramework::Candle,
490 FrameworkInfo {
491 version: env!("CARGO_PKG_VERSION").to_string(),
492 providers: vec!["CPU".to_string(), "CUDA".to_string(), "Metal".to_string()],
493 initialized: true,
494 capabilities: FrameworkCapabilities {
495 gpu_support: true,
496 quantization_support: true,
497 dynamic_shapes: true,
498 streaming_support: true,
499 },
500 },
501 );
502
503 frameworks.insert(
505 MLFramework::OnnxRuntime,
506 FrameworkInfo {
507 version: "1.16.0".to_string(),
508 providers: vec!["CPU".to_string(), "CUDA".to_string()],
509 initialized: false, capabilities: FrameworkCapabilities {
511 gpu_support: true,
512 quantization_support: true,
513 dynamic_shapes: true,
514 streaming_support: false,
515 },
516 },
517 );
518
519 Ok(Self {
520 frameworks,
521 active_sessions: Arc::new(RwLock::new(HashMap::new())),
522 config,
523 model_registry: Arc::new(RwLock::new(HashMap::new())),
524 })
525 }
526
527 pub fn register_model(&self, metadata: MLModelMetadata) -> Result<()> {
529 let mut registry = self.model_registry.write().map_err(|_| {
530 Error::runtime("Failed to acquire write lock on model registry".to_string())
531 })?;
532
533 registry.insert(metadata.name.clone(), metadata);
534 Ok(())
535 }
536
537 pub fn create_session(&self, model_name: &str, session_id: String) -> Result<()> {
539 let model_metadata = {
540 let registry = self.model_registry.read().map_err(|_| {
541 Error::runtime("Failed to acquire read lock on model registry".to_string())
542 })?;
543
544 registry.get(model_name).cloned().ok_or_else(|| {
545 Error::model(format!("Model '{model_name}' not found in registry"))
546 })?
547 };
548
549 let framework = self.select_framework(&model_metadata)?;
551
552 let session = match framework {
553 MLFramework::Candle => self.create_candle_session(&model_metadata)?,
554 MLFramework::OnnxRuntime => self.create_onnx_session(&model_metadata)?,
555 MLFramework::TensorFlowLite => self.create_tflite_session(&model_metadata)?,
556 MLFramework::PyTorch => self.create_pytorch_session(&model_metadata)?,
557 MLFramework::Custom => self.create_custom_session(&model_metadata)?,
558 };
559
560 let mut sessions = self.active_sessions.write().map_err(|_| {
561 Error::runtime("Failed to acquire write lock on active sessions".to_string())
562 })?;
563
564 sessions.insert(session_id, session);
565 Ok(())
566 }
567
568 pub fn run_inference(&self, session_id: &str, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
570 let mut sessions = self.active_sessions.write().map_err(|_| {
571 Error::runtime("Failed to acquire write lock on active sessions".to_string())
572 })?;
573
574 let session = sessions
575 .get_mut(session_id)
576 .ok_or_else(|| Error::runtime(format!("Session '{session_id}' not found")))?;
577
578 let start_time = std::time::Instant::now();
579
580 let outputs = match session.framework {
581 MLFramework::Candle => self.run_candle_inference(session, inputs)?,
582 MLFramework::OnnxRuntime => self.run_onnx_inference(session, inputs)?,
583 MLFramework::TensorFlowLite => self.run_tflite_inference(session, inputs)?,
584 MLFramework::PyTorch => self.run_pytorch_inference(session, inputs)?,
585 MLFramework::Custom => self.run_custom_inference(session, inputs)?,
586 };
587
588 let inference_time = start_time.elapsed();
589
590 {
592 let mut metrics = session.metrics.write().map_err(|_| {
593 Error::runtime("Failed to acquire write lock on metrics".to_string())
594 })?;
595
596 metrics.inference_count += 1;
597 let inference_time_ms = inference_time.as_millis() as u64;
598 metrics.total_inference_time_ms += inference_time_ms;
599 metrics.avg_inference_time_ms =
600 metrics.total_inference_time_ms as f32 / metrics.inference_count as f32;
601
602 if metrics.inference_count == 1 {
603 metrics.min_inference_time_ms = inference_time_ms as f32;
604 metrics.max_inference_time_ms = inference_time_ms as f32;
605 } else {
606 metrics.min_inference_time_ms =
607 metrics.min_inference_time_ms.min(inference_time_ms as f32);
608 metrics.max_inference_time_ms =
609 metrics.max_inference_time_ms.max(inference_time_ms as f32);
610 }
611
612 metrics.last_update = std::time::SystemTime::now();
613 }
614
615 Ok(outputs)
616 }
617
618 fn select_framework(&self, model_metadata: &MLModelMetadata) -> Result<MLFramework> {
620 if self.framework_supports_model(self.config.primary_framework, model_metadata)? {
622 return Ok(self.config.primary_framework);
623 }
624
625 for &framework in &self.config.fallback_frameworks {
627 if self.framework_supports_model(framework, model_metadata)? {
628 return Ok(framework);
629 }
630 }
631
632 Err(Error::model(format!(
633 "No compatible framework found for model '{}'",
634 model_metadata.name
635 )))
636 }
637
638 fn framework_supports_model(
640 &self,
641 framework: MLFramework,
642 model_metadata: &MLModelMetadata,
643 ) -> Result<bool> {
644 let framework_info = self
645 .frameworks
646 .get(&framework)
647 .ok_or_else(|| Error::model(format!("Framework {framework:?} not available")))?;
648
649 if !framework_info.initialized {
650 return Ok(false);
651 }
652
653 match (framework, model_metadata.framework) {
655 (MLFramework::Candle, _) => Ok(true), (a, b) if a == b => Ok(true), (MLFramework::OnnxRuntime, MLFramework::PyTorch) => Ok(true), (MLFramework::OnnxRuntime, MLFramework::TensorFlowLite) => Ok(true), _ => Ok(false),
660 }
661 }
662
663 fn create_candle_session(
665 &self,
666 model_metadata: &MLModelMetadata,
667 ) -> Result<MLInferenceSession> {
668 let device = match &self.config.device_preference {
669 DevicePreference::Cpu => Device::Cpu,
670 DevicePreference::Gpu { device_index, .. } => match device_index {
671 Some(idx) => std::panic::catch_unwind(move || Device::cuda_if_available(*idx))
672 .ok()
673 .and_then(|r| r.ok())
674 .ok_or_else(|| Error::model(format!("Failed to create CUDA device {idx}")))?,
675 None => std::panic::catch_unwind(|| Device::cuda_if_available(0))
676 .ok()
677 .and_then(|r| r.ok())
678 .ok_or_else(|| Error::model("Failed to create CUDA device".to_string()))?,
679 },
680 DevicePreference::Auto => std::panic::catch_unwind(|| Device::cuda_if_available(0))
681 .ok()
682 .and_then(|r| r.ok())
683 .unwrap_or(Device::Cpu),
684 DevicePreference::Custom { .. } => Device::Cpu, };
686
687 let model_weights = HashMap::new();
689
690 let model_architecture = ModelArchitecture::Transformer {
692 num_layers: 12,
693 hidden_dim: 768,
694 num_heads: 12,
695 };
696
697 let candle_session = CandleSession {
698 device,
699 model_weights,
700 model_architecture,
701 };
702
703 Ok(MLInferenceSession {
704 framework: MLFramework::Candle,
705 model_metadata: model_metadata.clone(),
706 candle_session: Some(candle_session),
707 config: self.config.clone(),
708 metrics: Arc::new(RwLock::new(InferenceMetrics::default())),
709 })
710 }
711
712 fn create_onnx_session(&self, model_metadata: &MLModelMetadata) -> Result<MLInferenceSession> {
714 Ok(MLInferenceSession {
716 framework: MLFramework::OnnxRuntime,
717 model_metadata: model_metadata.clone(),
718 candle_session: None,
719 config: self.config.clone(),
720 metrics: Arc::new(RwLock::new(InferenceMetrics::default())),
721 })
722 }
723
724 fn create_tflite_session(
726 &self,
727 model_metadata: &MLModelMetadata,
728 ) -> Result<MLInferenceSession> {
729 Ok(MLInferenceSession {
731 framework: MLFramework::TensorFlowLite,
732 model_metadata: model_metadata.clone(),
733 candle_session: None,
734 config: self.config.clone(),
735 metrics: Arc::new(RwLock::new(InferenceMetrics::default())),
736 })
737 }
738
739 fn create_pytorch_session(
741 &self,
742 model_metadata: &MLModelMetadata,
743 ) -> Result<MLInferenceSession> {
744 Ok(MLInferenceSession {
746 framework: MLFramework::PyTorch,
747 model_metadata: model_metadata.clone(),
748 candle_session: None,
749 config: self.config.clone(),
750 metrics: Arc::new(RwLock::new(InferenceMetrics::default())),
751 })
752 }
753
754 fn create_custom_session(
756 &self,
757 model_metadata: &MLModelMetadata,
758 ) -> Result<MLInferenceSession> {
759 Ok(MLInferenceSession {
761 framework: MLFramework::Custom,
762 model_metadata: model_metadata.clone(),
763 candle_session: None,
764 config: self.config.clone(),
765 metrics: Arc::new(RwLock::new(InferenceMetrics::default())),
766 })
767 }
768
769 fn run_candle_inference(
771 &self,
772 session: &MLInferenceSession,
773 inputs: &[Tensor],
774 ) -> Result<Vec<Tensor>> {
775 let candle_session = session
777 .candle_session
778 .as_ref()
779 .ok_or_else(|| Error::model("Candle session not initialized".to_string()))?;
780
781 Ok(inputs.to_vec())
783 }
784
785 fn run_onnx_inference(
787 &self,
788 _session: &MLInferenceSession,
789 inputs: &[Tensor],
790 ) -> Result<Vec<Tensor>> {
791 Ok(inputs.to_vec())
793 }
794
795 fn run_tflite_inference(
797 &self,
798 _session: &MLInferenceSession,
799 inputs: &[Tensor],
800 ) -> Result<Vec<Tensor>> {
801 Ok(inputs.to_vec())
803 }
804
805 fn run_pytorch_inference(
807 &self,
808 _session: &MLInferenceSession,
809 inputs: &[Tensor],
810 ) -> Result<Vec<Tensor>> {
811 Ok(inputs.to_vec())
813 }
814
815 fn run_custom_inference(
817 &self,
818 _session: &MLInferenceSession,
819 inputs: &[Tensor],
820 ) -> Result<Vec<Tensor>> {
821 Ok(inputs.to_vec())
823 }
824
825 pub fn get_metrics(&self, session_id: &str) -> Result<InferenceMetrics> {
827 let sessions = self.active_sessions.read().map_err(|_| {
828 Error::runtime("Failed to acquire read lock on active sessions".to_string())
829 })?;
830
831 let session = sessions
832 .get(session_id)
833 .ok_or_else(|| Error::runtime(format!("Session '{session_id}' not found")))?;
834
835 let metrics = session
836 .metrics
837 .read()
838 .map_err(|_| Error::runtime("Failed to acquire read lock on metrics".to_string()))?;
839
840 Ok(metrics.clone())
841 }
842
843 pub fn close_session(&self, session_id: &str) -> Result<()> {
845 let mut sessions = self.active_sessions.write().map_err(|_| {
846 Error::runtime("Failed to acquire write lock on active sessions".to_string())
847 })?;
848
849 sessions
850 .remove(session_id)
851 .ok_or_else(|| Error::runtime(format!("Session '{session_id}' not found")))?;
852
853 Ok(())
854 }
855
856 pub fn list_frameworks(&self) -> Vec<(MLFramework, &FrameworkInfo)> {
858 self.frameworks
859 .iter()
860 .map(|(&framework, info)| (framework, info))
861 .collect()
862 }
863
864 pub fn get_framework_capabilities(
866 &self,
867 framework: MLFramework,
868 ) -> Option<&FrameworkCapabilities> {
869 self.frameworks
870 .get(&framework)
871 .map(|info| &info.capabilities)
872 }
873}
874
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// The default config prefers Candle, with quantization on and batch 1.
    #[test]
    fn test_ml_framework_config_default() {
        let config = MLFrameworkConfig::default();

        assert_eq!(config.primary_framework, MLFramework::Candle);
        assert!(config.optimization.quantization_enabled);
        assert_eq!(config.performance_config.batch_size, 1);
    }

    /// A fresh manager must list at least the Candle backend.
    #[test]
    fn test_ml_framework_manager_creation() {
        let manager = MLFrameworkManager::new(MLFrameworkConfig::default()).unwrap();
        let frameworks = manager.list_frameworks();

        assert!(!frameworks.is_empty());
        let has_candle = frameworks
            .iter()
            .any(|(framework, _)| *framework == MLFramework::Candle);
        assert!(has_candle);
    }

    /// Registering a well-formed model succeeds.
    #[test]
    fn test_model_registration() {
        let manager = MLFrameworkManager::new(MLFrameworkConfig::default()).unwrap();

        // Input and output specs share everything but name and description.
        let spec = |name: &str, description: &str| TensorSpec {
            name: name.to_string(),
            shape: vec![1, -1, 80],
            data_type: TensorDataType::Float32,
            description: Some(description.to_string()),
        };

        let model_metadata = MLModelMetadata {
            name: "test-model".to_string(),
            version: "1.0.0".to_string(),
            framework: MLFramework::Candle,
            input_specs: vec![spec("input", "Audio features")],
            output_specs: vec![spec("output", "Converted features")],
            model_path: PathBuf::from("test_model.safetensors"),
            model_size_bytes: 1024 * 1024,
            supported_sample_rates: vec![22050, 44100],
            capabilities: ModelCapabilities {
                realtime_capable: true,
                batch_capable: true,
                streaming_capable: true,
                gpu_accelerated: true,
                quantization_support: true,
                max_input_length: Some(1000),
            },
        };

        manager.register_model(model_metadata).unwrap();
    }

    /// Quantization precision can be overridden after constructing defaults.
    #[test]
    fn test_quantization_precision() {
        let config = {
            let mut c = MLFrameworkConfig::default();
            c.optimization.quantization_precision = QuantizationPrecision::Float16;
            c
        };

        assert_eq!(
            config.optimization.quantization_precision,
            QuantizationPrecision::Float16
        );
    }

    /// Device-preference variants pattern-match as expected.
    #[test]
    fn test_device_preference() {
        let cpu_preference = DevicePreference::Cpu;
        match cpu_preference {
            DevicePreference::Cpu => {}
            _ => panic!("Expected CPU preference"),
        }

        let gpu_preference = DevicePreference::Gpu {
            device_index: Some(0),
            memory_limit_mb: Some(4096),
        };
        match gpu_preference {
            DevicePreference::Gpu {
                device_index: Some(0),
                memory_limit_mb: Some(4096),
            } => {}
            _ => panic!("Expected GPU preference with specific settings"),
        }
    }

    /// Metrics fields round-trip the values written to them.
    #[test]
    fn test_inference_metrics() {
        let metrics = InferenceMetrics {
            inference_count: 10,
            total_inference_time_ms: 1000,
            avg_inference_time_ms: 100.0,
            min_inference_time_ms: 50.0,
            max_inference_time_ms: 200.0,
            ..InferenceMetrics::default()
        };

        assert_eq!(metrics.inference_count, 10);
        assert_eq!(metrics.avg_inference_time_ms, 100.0);
    }
}