1use super::types_config::{OutputType, ProcessorType, TransformParameter};
8use scirs2_core::ndarray::Array1;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::SystemTime;
12
/// Describes the input tensor a vision model expects: image geometry, channel
/// count, normalization, preprocessing steps, tensor layout and element type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInputSpec {
    /// Image dimensions; both components must be non-zero (checked by `validate`).
    /// NOTE(review): component order (height, width) vs (width, height) is not
    /// fixed by this file — confirm against callers.
    pub image_size: (usize, usize),
    /// Number of channels per pixel; must be non-zero (checked by `validate`).
    pub channels: usize,
    /// Normalization parameters applied to pixel values.
    pub normalization: NormalizationSpec,
    /// Names of preprocessing steps such as `"resize"` or `"normalize"`.
    /// NOTE(review): presumably applied in list order — confirm with the pipeline.
    pub preprocessing: Vec<String>,
    /// Memory layout of the input tensor.
    pub tensor_format: TensorFormat,
    /// Element type of the input tensor; drives `memory_requirements`.
    pub data_type: ModelDataType,
    /// Expected `(min, max)` range of input values; `min` must be below `max`.
    pub value_range: (f32, f32),
}
31
32impl Default for ModelInputSpec {
33 fn default() -> Self {
34 Self {
35 image_size: (224, 224),
36 channels: 3,
37 normalization: NormalizationSpec::imagenet(),
38 preprocessing: vec!["resize".to_string(), "normalize".to_string()],
39 tensor_format: TensorFormat::NCHW,
40 data_type: ModelDataType::Float32,
41 value_range: (0.0, 1.0),
42 }
43 }
44}
45
46impl ModelInputSpec {
47 #[must_use]
49 pub fn classification(image_size: (usize, usize)) -> Self {
50 Self {
51 image_size,
52 channels: 3,
53 normalization: NormalizationSpec::imagenet(),
54 preprocessing: vec![
55 "resize".to_string(),
56 "center_crop".to_string(),
57 "normalize".to_string(),
58 ],
59 tensor_format: TensorFormat::NCHW,
60 data_type: ModelDataType::Float32,
61 value_range: (0.0, 1.0),
62 }
63 }
64
65 #[must_use]
67 pub fn object_detection(image_size: (usize, usize)) -> Self {
68 Self {
69 image_size,
70 channels: 3,
71 normalization: NormalizationSpec::coco(),
72 preprocessing: vec![
73 "resize".to_string(),
74 "letterbox".to_string(),
75 "normalize".to_string(),
76 ],
77 tensor_format: TensorFormat::NCHW,
78 data_type: ModelDataType::Float32,
79 value_range: (0.0, 1.0),
80 }
81 }
82
83 #[must_use]
85 pub fn segmentation(image_size: (usize, usize)) -> Self {
86 Self {
87 image_size,
88 channels: 3,
89 normalization: NormalizationSpec::cityscapes(),
90 preprocessing: vec!["resize".to_string(), "normalize".to_string()],
91 tensor_format: TensorFormat::NCHW,
92 data_type: ModelDataType::Float32,
93 value_range: (0.0, 1.0),
94 }
95 }
96
97 pub fn validate(&self) -> Result<(), ModelError> {
99 if self.image_size.0 == 0 || self.image_size.1 == 0 {
100 return Err(ModelError::InvalidInputSpec(
101 "Image size must be greater than zero".to_string(),
102 ));
103 }
104
105 if self.channels == 0 {
106 return Err(ModelError::InvalidInputSpec(
107 "Number of channels must be greater than zero".to_string(),
108 ));
109 }
110
111 if self.value_range.0 >= self.value_range.1 {
112 return Err(ModelError::InvalidInputSpec(
113 "Value range minimum must be less than maximum".to_string(),
114 ));
115 }
116
117 Ok(())
118 }
119
120 #[must_use]
122 pub fn memory_requirements(&self, batch_size: usize) -> usize {
123 let element_size = match self.data_type {
124 ModelDataType::Float32 => 4,
125 ModelDataType::Float16 => 2,
126 ModelDataType::Int8 => 1,
127 ModelDataType::UInt8 => 1,
128 };
129
130 batch_size * self.channels * self.image_size.0 * self.image_size.1 * element_size
131 }
132}
133
/// Per-channel normalization parameters applied to model inputs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizationSpec {
    /// Per-channel mean values (three entries in every built-in preset).
    pub mean: Array1<f32>,
    /// Per-channel standard deviation or divisor (`min_max` stores 255.0 here).
    pub std: Array1<f32>,
    /// `(min, max)` value range. NOTE(review): presumably the range pixels are
    /// scaled into before `mean`/`std` are applied — confirm with the consumer.
    pub range: (f32, f32),
    /// Which normalization formula these parameters are meant for.
    pub norm_type: NormalizationType,
}
146
147impl Default for NormalizationSpec {
148 fn default() -> Self {
149 Self::imagenet()
150 }
151}
152
153impl NormalizationSpec {
154 #[must_use]
156 pub fn imagenet() -> Self {
157 Self {
158 mean: Array1::from(vec![0.485, 0.456, 0.406]),
159 std: Array1::from(vec![0.229, 0.224, 0.225]),
160 range: (0.0, 1.0),
161 norm_type: NormalizationType::StandardScore,
162 }
163 }
164
165 #[must_use]
167 pub fn coco() -> Self {
168 Self {
169 mean: Array1::from(vec![0.485, 0.456, 0.406]),
170 std: Array1::from(vec![0.229, 0.224, 0.225]),
171 range: (0.0, 1.0),
172 norm_type: NormalizationType::StandardScore,
173 }
174 }
175
176 #[must_use]
178 pub fn cityscapes() -> Self {
179 Self {
180 mean: Array1::from(vec![0.485, 0.456, 0.406]),
181 std: Array1::from(vec![0.229, 0.224, 0.225]),
182 range: (0.0, 1.0),
183 norm_type: NormalizationType::StandardScore,
184 }
185 }
186
187 #[must_use]
189 pub fn custom(mean: Vec<f32>, std: Vec<f32>, range: (f32, f32)) -> Self {
190 Self {
191 mean: Array1::from(mean),
192 std: Array1::from(std),
193 range,
194 norm_type: NormalizationType::StandardScore,
195 }
196 }
197
198 #[must_use]
200 pub fn min_max() -> Self {
201 Self {
202 mean: Array1::from(vec![0.0, 0.0, 0.0]),
203 std: Array1::from(vec![255.0, 255.0, 255.0]),
204 range: (0.0, 1.0),
205 norm_type: NormalizationType::MinMax,
206 }
207 }
208}
209
/// Selects the normalization formula applied with a [`NormalizationSpec`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NormalizationType {
    /// Standard score. NOTE(review): presumably `(x - mean) / std` — confirm.
    StandardScore,
    /// Min-max scaling (the variant produced by `NormalizationSpec::min_max`).
    MinMax,
    /// L2 normalization. NOTE(review): presumably scaling to unit L2 norm — confirm.
    L2,
    /// No normalization.
    None,
}
222
/// Memory layout of an image tensor (N = batch, C = channels, H = height, W = width).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TensorFormat {
    /// Batched, channels-first layout.
    NCHW,
    /// Batched, channels-last layout.
    NHWC,
    /// Single image, channels-first layout.
    CHW,
    /// Single image, channels-last layout.
    HWC,
}
235
236impl Default for TensorFormat {
237 fn default() -> Self {
238 Self::NCHW
239 }
240}
241
/// Element type of a model's input tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelDataType {
    /// 32-bit float (counted as 4 bytes by `ModelInputSpec::memory_requirements`).
    Float32,
    /// 16-bit float (counted as 2 bytes).
    Float16,
    /// Signed 8-bit integer (counted as 1 byte).
    Int8,
    /// Unsigned 8-bit integer (counted as 1 byte).
    UInt8,
}
254
255impl Default for ModelDataType {
256 fn default() -> Self {
257 Self::Float32
258 }
259}
260
/// Describes the tensors a model produces and how to post-process them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelOutputSpec {
    /// Shape of each output tensor.
    pub output_shapes: Vec<Vec<usize>>,
    /// Kind of each output; presumably parallel to `output_shapes` — confirm.
    pub output_types: Vec<OutputType>,
    /// Names of post-processing steps such as `"softmax"` or `"nms"`.
    pub postprocessing: Vec<String>,
    /// Human-oriented description of the output values.
    pub interpretation: OutputInterpretation,
    /// Named thresholds (e.g. `"min_confidence"`, `"detection_threshold"`, `"nms_threshold"`).
    pub confidence_thresholds: HashMap<String, f64>,
}
275
276impl Default for ModelOutputSpec {
277 fn default() -> Self {
278 Self {
279 output_shapes: vec![vec![1000]], output_types: vec![OutputType::Classification],
281 postprocessing: vec!["softmax".to_string()],
282 interpretation: OutputInterpretation::default(),
283 confidence_thresholds: HashMap::new(),
284 }
285 }
286}
287
288impl ModelOutputSpec {
289 #[must_use]
291 pub fn classification(num_classes: usize) -> Self {
292 let mut thresholds = HashMap::new();
293 thresholds.insert("min_confidence".to_string(), 0.5);
294
295 Self {
296 output_shapes: vec![vec![num_classes]],
297 output_types: vec![OutputType::Classification],
298 postprocessing: vec!["softmax".to_string(), "argmax".to_string()],
299 interpretation: OutputInterpretation::classification(),
300 confidence_thresholds: thresholds,
301 }
302 }
303
304 #[must_use]
306 pub fn object_detection() -> Self {
307 let mut thresholds = HashMap::new();
308 thresholds.insert("detection_threshold".to_string(), 0.5);
309 thresholds.insert("nms_threshold".to_string(), 0.4);
310
311 Self {
312 output_shapes: vec![
313 vec![1, 25200, 85], ],
315 output_types: vec![OutputType::Detection],
316 postprocessing: vec![
317 "decode_boxes".to_string(),
318 "nms".to_string(),
319 "filter_confidence".to_string(),
320 ],
321 interpretation: OutputInterpretation::object_detection(),
322 confidence_thresholds: thresholds,
323 }
324 }
325
326 #[must_use]
328 pub fn segmentation(num_classes: usize, height: usize, width: usize) -> Self {
329 Self {
330 output_shapes: vec![vec![num_classes, height, width]],
331 output_types: vec![OutputType::Segmentation],
332 postprocessing: vec!["argmax".to_string(), "colormap".to_string()],
333 interpretation: OutputInterpretation::segmentation(),
334 confidence_thresholds: HashMap::new(),
335 }
336 }
337}
338
/// Human-oriented description of what a model's output values mean.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputInterpretation {
    /// Ordered class label names (empty in every built-in preset).
    pub class_labels: Vec<String>,
    /// Mapping from class index to label name.
    pub label_mapping: HashMap<usize, String>,
    /// Free-text description of the output format.
    pub format_description: String,
    /// Units of the output values, if any (e.g. `"probability"`, `"class_index"`).
    pub units: Option<String>,
}
351
352impl Default for OutputInterpretation {
353 fn default() -> Self {
354 Self {
355 class_labels: vec![],
356 label_mapping: HashMap::new(),
357 format_description: "Raw model output".to_string(),
358 units: None,
359 }
360 }
361}
362
363impl OutputInterpretation {
364 #[must_use]
366 pub fn classification() -> Self {
367 Self {
368 class_labels: vec![],
369 label_mapping: HashMap::new(),
370 format_description: "Class probabilities".to_string(),
371 units: Some("probability".to_string()),
372 }
373 }
374
375 #[must_use]
377 pub fn object_detection() -> Self {
378 Self {
379 class_labels: vec![],
380 label_mapping: HashMap::new(),
381 format_description: "Bounding boxes with class and confidence".to_string(),
382 units: Some("normalized_coordinates".to_string()),
383 }
384 }
385
386 #[must_use]
388 pub fn segmentation() -> Self {
389 Self {
390 class_labels: vec![],
391 label_mapping: HashMap::new(),
392 format_description: "Per-pixel class predictions".to_string(),
393 units: Some("class_index".to_string()),
394 }
395 }
396}
397
/// Expected performance profile of a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPerformance {
    /// Per-inference latency. NOTE(review): presumably milliseconds (presets pair
    /// 100.0 with 10 images/s) — confirm.
    pub inference_time: f64,
    /// Memory footprint; presets use `N * 1024 * 1024`, so presumably bytes.
    pub memory_usage: u64,
    /// Floating-point operations per inference (presumably per image — confirm).
    pub flops: u64,
    /// Accuracy metrics.
    pub accuracy: AccuracyMetrics,
    /// Throughput metrics.
    pub throughput: ThroughputMetrics,
    /// Resource-utilization profile.
    pub resource_utilization: ModelResourceUtilization,
}
414
415impl Default for ModelPerformance {
416 fn default() -> Self {
417 Self {
418 inference_time: 100.0, memory_usage: 100 * 1024 * 1024, flops: 1_000_000_000, accuracy: AccuracyMetrics::default(),
422 throughput: ThroughputMetrics::default(),
423 resource_utilization: ModelResourceUtilization::default(),
424 }
425 }
426}
427
428impl ModelPerformance {
429 #[must_use]
431 pub fn mobile_optimized() -> Self {
432 Self {
433 inference_time: 50.0, memory_usage: 20 * 1024 * 1024, flops: 100_000_000, accuracy: AccuracyMetrics::mobile(),
437 throughput: ThroughputMetrics::mobile(),
438 resource_utilization: ModelResourceUtilization::low(),
439 }
440 }
441
442 #[must_use]
444 pub fn server_optimized() -> Self {
445 Self {
446 inference_time: 200.0, memory_usage: 500 * 1024 * 1024, flops: 10_000_000_000, accuracy: AccuracyMetrics::high_accuracy(),
450 throughput: ThroughputMetrics::server(),
451 resource_utilization: ModelResourceUtilization::high(),
452 }
453 }
454
455 #[must_use]
457 pub fn edge_optimized() -> Self {
458 Self {
459 inference_time: 75.0, memory_usage: 50 * 1024 * 1024, flops: 500_000_000, accuracy: AccuracyMetrics::balanced(),
463 throughput: ThroughputMetrics::edge(),
464 resource_utilization: ModelResourceUtilization::moderate(),
465 }
466 }
467}
468
/// Accuracy figures for a model; each metric is optional because not every
/// metric applies to every task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccuracyMetrics {
    /// Top-1 accuracy; presets use fractions in `[0, 1]`.
    pub top1_accuracy: Option<f64>,
    /// Top-5 accuracy.
    pub top5_accuracy: Option<f64>,
    /// NOTE(review): presumably mean average precision (detection) — confirm.
    pub map: Option<f64>,
    /// Intersection-over-union score.
    pub iou: Option<f64>,
    /// F1 score.
    pub f1_score: Option<f64>,
    /// Precision.
    pub precision: Option<f64>,
    /// Recall.
    pub recall: Option<f64>,
    /// Additional named metrics.
    pub custom_metrics: HashMap<String, f64>,
}
489
490impl Default for AccuracyMetrics {
491 fn default() -> Self {
492 Self {
493 top1_accuracy: Some(0.75),
494 top5_accuracy: Some(0.92),
495 map: None,
496 iou: None,
497 f1_score: Some(0.75),
498 precision: Some(0.75),
499 recall: Some(0.75),
500 custom_metrics: HashMap::new(),
501 }
502 }
503}
504
505impl AccuracyMetrics {
506 #[must_use]
508 pub fn mobile() -> Self {
509 Self {
510 top1_accuracy: Some(0.65),
511 top5_accuracy: Some(0.85),
512 map: None,
513 iou: None,
514 f1_score: Some(0.65),
515 precision: Some(0.70),
516 recall: Some(0.60),
517 custom_metrics: HashMap::new(),
518 }
519 }
520
521 #[must_use]
523 pub fn high_accuracy() -> Self {
524 Self {
525 top1_accuracy: Some(0.85),
526 top5_accuracy: Some(0.97),
527 map: Some(0.75),
528 iou: Some(0.80),
529 f1_score: Some(0.85),
530 precision: Some(0.88),
531 recall: Some(0.82),
532 custom_metrics: HashMap::new(),
533 }
534 }
535
536 #[must_use]
538 pub fn balanced() -> Self {
539 Self::default()
540 }
541}
542
/// Throughput figures for a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThroughputMetrics {
    /// Single-stream throughput in images per second.
    pub images_per_second: f64,
    /// Throughput keyed by batch size. NOTE(review): values presumably also in
    /// images per second — confirm.
    pub batch_throughput: HashMap<usize, f64>,
    /// Peak throughput.
    pub peak_throughput: f64,
    /// Throughput that can be sustained over time.
    pub sustained_throughput: f64,
}
555
556impl Default for ThroughputMetrics {
557 fn default() -> Self {
558 let mut batch_throughput = HashMap::new();
559 batch_throughput.insert(1, 10.0);
560 batch_throughput.insert(8, 40.0);
561 batch_throughput.insert(16, 60.0);
562
563 Self {
564 images_per_second: 10.0,
565 batch_throughput,
566 peak_throughput: 80.0,
567 sustained_throughput: 10.0,
568 }
569 }
570}
571
572impl ThroughputMetrics {
573 #[must_use]
575 pub fn mobile() -> Self {
576 let mut batch_throughput = HashMap::new();
577 batch_throughput.insert(1, 20.0);
578 batch_throughput.insert(4, 30.0);
579
580 Self {
581 images_per_second: 20.0,
582 batch_throughput,
583 peak_throughput: 35.0,
584 sustained_throughput: 18.0,
585 }
586 }
587
588 #[must_use]
590 pub fn server() -> Self {
591 let mut batch_throughput = HashMap::new();
592 batch_throughput.insert(1, 5.0);
593 batch_throughput.insert(8, 30.0);
594 batch_throughput.insert(16, 50.0);
595 batch_throughput.insert(32, 80.0);
596 batch_throughput.insert(64, 100.0);
597
598 Self {
599 images_per_second: 5.0,
600 batch_throughput,
601 peak_throughput: 120.0,
602 sustained_throughput: 4.0,
603 }
604 }
605
606 #[must_use]
608 pub fn edge() -> Self {
609 let mut batch_throughput = HashMap::new();
610 batch_throughput.insert(1, 13.0);
611 batch_throughput.insert(4, 20.0);
612 batch_throughput.insert(8, 25.0);
613
614 Self {
615 images_per_second: 13.0,
616 batch_throughput,
617 peak_throughput: 28.0,
618 sustained_throughput: 12.0,
619 }
620 }
621}
622
/// Expected hardware utilization while the model runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResourceUtilization {
    /// CPU utilization. NOTE(review): presets use 20–80, presumably percent — confirm.
    pub cpu_utilization: f64,
    /// GPU utilization, if a GPU is used (same units as `cpu_utilization`).
    pub gpu_utilization: Option<f64>,
    /// Memory utilization (same units as `cpu_utilization`).
    pub memory_utilization: f64,
    /// Power draw, if known. NOTE(review): presumably watts — confirm.
    pub power_consumption: Option<f64>,
    /// Thermal characteristics.
    pub thermal_profile: ThermalProfile,
}
637
638impl Default for ModelResourceUtilization {
639 fn default() -> Self {
640 Self {
641 cpu_utilization: 50.0,
642 gpu_utilization: Some(60.0),
643 memory_utilization: 40.0,
644 power_consumption: Some(15.0),
645 thermal_profile: ThermalProfile::default(),
646 }
647 }
648}
649
650impl ModelResourceUtilization {
651 #[must_use]
653 pub fn low() -> Self {
654 Self {
655 cpu_utilization: 20.0,
656 gpu_utilization: Some(25.0),
657 memory_utilization: 15.0,
658 power_consumption: Some(5.0),
659 thermal_profile: ThermalProfile::cool(),
660 }
661 }
662
663 #[must_use]
665 pub fn moderate() -> Self {
666 Self::default()
667 }
668
669 #[must_use]
671 pub fn high() -> Self {
672 Self {
673 cpu_utilization: 80.0,
674 gpu_utilization: Some(90.0),
675 memory_utilization: 70.0,
676 power_consumption: Some(50.0),
677 thermal_profile: ThermalProfile::hot(),
678 }
679 }
680}
681
/// Thermal envelope of a deployment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThermalProfile {
    /// `(low, high)` operating temperature. NOTE(review): presumably °C — confirm.
    pub temperature_range: (f32, f32),
    /// Thermal design power, if known. NOTE(review): presumably watts — confirm.
    pub tdp: Option<f32>,
    /// Cooling needed to stay inside the envelope.
    pub cooling_requirements: CoolingRequirements,
}
692
693impl Default for ThermalProfile {
694 fn default() -> Self {
695 Self {
696 temperature_range: (20.0, 70.0),
697 tdp: Some(15.0),
698 cooling_requirements: CoolingRequirements::Passive,
699 }
700 }
701}
702
703impl ThermalProfile {
704 #[must_use]
706 pub fn cool() -> Self {
707 Self {
708 temperature_range: (15.0, 50.0),
709 tdp: Some(5.0),
710 cooling_requirements: CoolingRequirements::None,
711 }
712 }
713
714 #[must_use]
716 pub fn hot() -> Self {
717 Self {
718 temperature_range: (25.0, 85.0),
719 tdp: Some(50.0),
720 cooling_requirements: CoolingRequirements::Active,
721 }
722 }
723}
724
/// Cooling a deployment needs, from none to liquid cooling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CoolingRequirements {
    /// No dedicated cooling.
    None,
    /// Passive cooling (e.g. heatsink).
    Passive,
    /// Active cooling (e.g. fan).
    Active,
    /// Liquid cooling.
    Liquid,
}
737
/// Descriptive and provenance metadata for a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Model name.
    pub name: String,
    /// Model version string (e.g. `"1.0.0"`).
    pub version: String,
    /// Author.
    pub author: String,
    /// Free-text description.
    pub description: String,
    /// Dataset the model was trained on, if known.
    pub training_dataset: Option<String>,
    /// License identifier, if known.
    pub license: Option<String>,
    /// Model size; `classification` uses `100 * 1024 * 1024`, so presumably bytes.
    pub model_size: u64,
    /// Supported platforms (e.g. `"cpu"`, `"gpu"`).
    pub platforms: Vec<String>,
    /// Creation timestamp.
    pub created_at: SystemTime,
    /// Last-modification timestamp; refreshed by `touch`.
    pub modified_at: SystemTime,
    /// Free-form tags; `add_tag` keeps them unique.
    pub tags: Vec<String>,
    /// What is needed to deploy the model.
    pub deployment_requirements: DeploymentRequirements,
}
766
767impl Default for ModelMetadata {
768 fn default() -> Self {
769 Self {
770 name: "Unknown Model".to_string(),
771 version: "1.0.0".to_string(),
772 author: "Unknown".to_string(),
773 description: "Computer vision model".to_string(),
774 training_dataset: None,
775 license: None,
776 model_size: 0,
777 platforms: vec!["cpu".to_string()],
778 created_at: SystemTime::now(),
779 modified_at: SystemTime::now(),
780 tags: vec![],
781 deployment_requirements: DeploymentRequirements::default(),
782 }
783 }
784}
785
786impl ModelMetadata {
787 #[must_use]
789 pub fn classification(name: &str, version: &str) -> Self {
790 Self {
791 name: name.to_string(),
792 version: version.to_string(),
793 author: "Unknown".to_string(),
794 description: "Image classification model".to_string(),
795 training_dataset: Some("ImageNet".to_string()),
796 license: Some("MIT".to_string()),
797 model_size: 100 * 1024 * 1024, platforms: vec!["cpu".to_string(), "gpu".to_string()],
799 created_at: SystemTime::now(),
800 modified_at: SystemTime::now(),
801 tags: vec!["classification".to_string(), "vision".to_string()],
802 deployment_requirements: DeploymentRequirements::standard(),
803 }
804 }
805
806 pub fn touch(&mut self) {
808 self.modified_at = SystemTime::now();
809 }
810
811 pub fn add_tag(&mut self, tag: String) {
813 if !self.tags.contains(&tag) {
814 self.tags.push(tag);
815 }
816 }
817
818 #[must_use]
820 pub fn has_tag(&self, tag: &str) -> bool {
821 self.tags.contains(&tag.to_string())
822 }
823}
824
/// Everything needed to deploy a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeploymentRequirements {
    /// Minimum runtime version string (e.g. `"1.0.0"`).
    pub min_runtime_version: String,
    /// Names of required dependencies (e.g. `"opencv"`).
    pub dependencies: Vec<String>,
    /// Hardware requirements.
    pub hardware_requirements: HardwareRequirements,
    /// Environment variables to set, by name.
    pub environment_variables: HashMap<String, String>,
    /// Configuration files the deployment expects (e.g. `"model_config.json"`).
    pub config_files: Vec<String>,
}
839
840impl Default for DeploymentRequirements {
841 fn default() -> Self {
842 Self {
843 min_runtime_version: "1.0.0".to_string(),
844 dependencies: vec![],
845 hardware_requirements: HardwareRequirements::default(),
846 environment_variables: HashMap::new(),
847 config_files: vec![],
848 }
849 }
850}
851
852impl DeploymentRequirements {
853 #[must_use]
855 pub fn standard() -> Self {
856 Self {
857 min_runtime_version: "1.0.0".to_string(),
858 dependencies: vec!["opencv".to_string(), "numpy".to_string()],
859 hardware_requirements: HardwareRequirements::standard(),
860 environment_variables: HashMap::new(),
861 config_files: vec!["model_config.json".to_string()],
862 }
863 }
864
865 #[must_use]
867 pub fn minimal() -> Self {
868 Self {
869 min_runtime_version: "1.0.0".to_string(),
870 dependencies: vec![],
871 hardware_requirements: HardwareRequirements::minimal(),
872 environment_variables: HashMap::new(),
873 config_files: vec![],
874 }
875 }
876}
877
/// Minimum hardware needed to run a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HardwareRequirements {
    /// Minimum RAM; presets use `N * 1024 * 1024`, so presumably bytes.
    pub min_ram: u64,
    /// Minimum free storage, in the same unit as `min_ram`.
    pub min_storage: u64,
    /// Required CPU feature flags (e.g. `"sse4.1"`).
    pub cpu_features: Vec<String>,
    /// GPU requirements, or `None` if no GPU is needed.
    pub gpu_requirements: Option<GpuRequirements>,
    /// Supported CPU architectures (e.g. `"x86_64"`, `"arm64"`).
    pub architectures: Vec<String>,
}
892
893impl Default for HardwareRequirements {
894 fn default() -> Self {
895 Self {
896 min_ram: 1024 * 1024 * 1024, min_storage: 500 * 1024 * 1024, cpu_features: vec!["sse4.1".to_string()],
899 gpu_requirements: None,
900 architectures: vec!["x86_64".to_string(), "arm64".to_string()],
901 }
902 }
903}
904
905impl HardwareRequirements {
906 #[must_use]
908 pub fn standard() -> Self {
909 Self::default()
910 }
911
912 #[must_use]
914 pub fn minimal() -> Self {
915 Self {
916 min_ram: 256 * 1024 * 1024, min_storage: 100 * 1024 * 1024, cpu_features: vec![],
919 gpu_requirements: None,
920 architectures: vec!["x86_64".to_string(), "arm64".to_string()],
921 }
922 }
923}
924
/// Minimum GPU capabilities needed to run a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuRequirements {
    /// Minimum GPU memory; the default is `2 * 1024 * 1024 * 1024`, so presumably bytes.
    pub min_gpu_memory: u64,
    /// Required compute capability, if any (e.g. `"6.0"`).
    /// NOTE(review): looks like a CUDA compute-capability string — confirm.
    pub compute_capability: Option<String>,
    /// Acceptable GPU vendors (e.g. `"NVIDIA"`, `"AMD"`).
    pub vendors: Vec<String>,
    /// Minimum driver version, if any.
    pub min_driver_version: Option<String>,
}
937
938impl Default for GpuRequirements {
939 fn default() -> Self {
940 Self {
941 min_gpu_memory: 2 * 1024 * 1024 * 1024, compute_capability: Some("6.0".to_string()),
943 vendors: vec!["NVIDIA".to_string(), "AMD".to_string()],
944 min_driver_version: Some("450.0".to_string()),
945 }
946 }
947}
948
/// Configuration of one output post-processing stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessorConfig {
    /// Which processor this stage runs.
    pub processor_type: ProcessorType,
    /// Named processor parameters (e.g. `"iou_threshold"`).
    pub parameters: HashMap<String, TransformParameter>,
    /// Quality-enhancement options for this stage.
    pub quality_enhancement: QualityEnhancementConfig,
    /// Stage priority (presets: filtering 5, NMS 10, tracking 20).
    /// NOTE(review): whether higher runs first is not established here — confirm.
    pub priority: i32,
    /// Whether the stage is active.
    pub enabled: bool,
}
963
964impl Default for ProcessorConfig {
965 fn default() -> Self {
966 Self {
967 processor_type: ProcessorType::Filtering,
968 parameters: HashMap::new(),
969 quality_enhancement: QualityEnhancementConfig::default(),
970 priority: 0,
971 enabled: true,
972 }
973 }
974}
975
976impl ProcessorConfig {
977 #[must_use]
979 pub fn nms(confidence_threshold: f64, iou_threshold: f64) -> Self {
980 let mut parameters = HashMap::new();
981 parameters.insert(
982 "confidence_threshold".to_string(),
983 TransformParameter::Float(confidence_threshold),
984 );
985 parameters.insert(
986 "iou_threshold".to_string(),
987 TransformParameter::Float(iou_threshold),
988 );
989
990 Self {
991 processor_type: ProcessorType::NonMaximumSuppression,
992 parameters,
993 quality_enhancement: QualityEnhancementConfig::default(),
994 priority: 10,
995 enabled: true,
996 }
997 }
998
999 #[must_use]
1001 pub fn filtering(min_confidence: f64) -> Self {
1002 let mut parameters = HashMap::new();
1003 parameters.insert(
1004 "min_confidence".to_string(),
1005 TransformParameter::Float(min_confidence),
1006 );
1007
1008 Self {
1009 processor_type: ProcessorType::Filtering,
1010 parameters,
1011 quality_enhancement: QualityEnhancementConfig::default(),
1012 priority: 5,
1013 enabled: true,
1014 }
1015 }
1016
1017 #[must_use]
1019 pub fn tracking() -> Self {
1020 let mut parameters = HashMap::new();
1021 parameters.insert("max_disappeared".to_string(), TransformParameter::Int(30));
1022 parameters.insert("max_distance".to_string(), TransformParameter::Float(50.0));
1023
1024 Self {
1025 processor_type: ProcessorType::Tracking,
1026 parameters,
1027 quality_enhancement: QualityEnhancementConfig::default(),
1028 priority: 20,
1029 enabled: true,
1030 }
1031 }
1032}
1033
/// Optional quality-enhancement steps for a post-processing stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityEnhancementConfig {
    /// Adjustment applied to confidences (0.0 by default).
    /// NOTE(review): additive vs multiplicative is not established here — confirm.
    pub confidence_adjustment: f64,
    /// Enable noise filtering.
    pub noise_filtering: bool,
    /// Enable outlier removal.
    pub outlier_removal: bool,
    /// Enable temporal-consistency enforcement across frames.
    pub temporal_consistency: bool,
    /// Smoothing settings.
    pub smoothing: SmoothingConfig,
}
1048
1049impl Default for QualityEnhancementConfig {
1050 fn default() -> Self {
1051 Self {
1052 confidence_adjustment: 0.0,
1053 noise_filtering: false,
1054 outlier_removal: false,
1055 temporal_consistency: false,
1056 smoothing: SmoothingConfig::default(),
1057 }
1058 }
1059}
1060
/// Smoothing settings for post-processed outputs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmoothingConfig {
    /// Whether smoothing is applied.
    pub enabled: bool,
    /// Window size (5 by default). NOTE(review): presumably a count of frames/samples — confirm.
    pub window_size: usize,
    /// Smoothing algorithm.
    pub algorithm: SmoothingAlgorithm,
    /// Smoothing strength (0.5 by default).
    pub strength: f64,
}
1073
1074impl Default for SmoothingConfig {
1075 fn default() -> Self {
1076 Self {
1077 enabled: false,
1078 window_size: 5,
1079 algorithm: SmoothingAlgorithm::MovingAverage,
1080 strength: 0.5,
1081 }
1082 }
1083}
1084
/// Smoothing algorithm selector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SmoothingAlgorithm {
    /// Moving average (the `SmoothingConfig` default).
    MovingAverage,
    /// Exponential smoothing.
    Exponential,
    /// Gaussian smoothing.
    Gaussian,
    /// Median filtering.
    Median,
}
1097
/// Errors produced while specifying, loading or running a vision model.
/// Each variant carries a human-readable detail message.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelError {
    /// The input specification is inconsistent (returned by `ModelInputSpec::validate`).
    InvalidInputSpec(String),
    /// The output specification is invalid.
    InvalidOutputSpec(String),
    /// The model could not be loaded.
    LoadingError(String),
    /// Inference failed.
    InferenceError(String),
    /// Post-processing failed.
    PostProcessingError(String),
    /// A configuration value is invalid.
    ConfigurationError(String),
    /// The host does not meet a hardware requirement.
    HardwareRequirementNotMet(String),
    /// The named platform is not supported.
    UnsupportedPlatform(String),
}
1118
1119impl std::fmt::Display for ModelError {
1120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1121 match self {
1122 Self::InvalidInputSpec(msg) => write!(f, "Invalid input specification: {msg}"),
1123 Self::InvalidOutputSpec(msg) => write!(f, "Invalid output specification: {msg}"),
1124 Self::LoadingError(msg) => write!(f, "Model loading error: {msg}"),
1125 Self::InferenceError(msg) => write!(f, "Inference error: {msg}"),
1126 Self::PostProcessingError(msg) => write!(f, "Post-processing error: {msg}"),
1127 Self::ConfigurationError(msg) => write!(f, "Configuration error: {msg}"),
1128 Self::HardwareRequirementNotMet(msg) => {
1129 write!(f, "Hardware requirement not met: {msg}")
1130 }
1131 Self::UnsupportedPlatform(platform) => write!(f, "Unsupported platform: {platform}"),
1132 }
1133 }
1134}
1135
1136impl std::error::Error for ModelError {}
1137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_input_spec() {
        let spec = ModelInputSpec::classification((224, 224));
        assert_eq!(spec.image_size, (224, 224));
        assert_eq!(spec.channels, 3);
        assert!(spec.validate().is_ok());

        // One Float32 image: 3 channels × 224 × 224 pixels × 4 bytes.
        let memory = spec.memory_requirements(1);
        assert_eq!(memory, 3 * 224 * 224 * 4);
    }

    #[test]
    fn test_model_input_spec_validation_rejects_bad_specs() {
        let zero_size = ModelInputSpec {
            image_size: (0, 224),
            ..ModelInputSpec::default()
        };
        assert!(matches!(
            zero_size.validate(),
            Err(ModelError::InvalidInputSpec(_))
        ));

        let zero_channels = ModelInputSpec {
            channels: 0,
            ..ModelInputSpec::default()
        };
        assert!(matches!(
            zero_channels.validate(),
            Err(ModelError::InvalidInputSpec(_))
        ));

        let empty_range = ModelInputSpec {
            value_range: (1.0, 1.0),
            ..ModelInputSpec::default()
        };
        assert!(matches!(
            empty_range.validate(),
            Err(ModelError::InvalidInputSpec(_))
        ));
    }

    #[test]
    fn test_normalization_spec() {
        let imagenet = NormalizationSpec::imagenet();
        assert_eq!(imagenet.mean.len(), 3);
        assert_eq!(imagenet.std.len(), 3);

        let custom =
            NormalizationSpec::custom(vec![0.5, 0.5, 0.5], vec![0.5, 0.5, 0.5], (0.0, 1.0));
        assert_eq!(custom.mean.len(), 3);
        assert_eq!(custom.norm_type, NormalizationType::StandardScore);
    }

    #[test]
    fn test_model_output_spec() {
        let classification = ModelOutputSpec::classification(1000);
        assert_eq!(classification.output_shapes[0], vec![1000]);
        assert_eq!(classification.output_types[0], OutputType::Classification);

        let detection = ModelOutputSpec::object_detection();
        assert_eq!(detection.output_types[0], OutputType::Detection);
        assert!(detection
            .confidence_thresholds
            .contains_key("detection_threshold"));
    }

    #[test]
    fn test_model_performance() {
        let mobile = ModelPerformance::mobile_optimized();
        assert_eq!(mobile.inference_time, 50.0);
        assert!(mobile.memory_usage < 50 * 1024 * 1024);

        let server = ModelPerformance::server_optimized();
        assert!(server.inference_time > mobile.inference_time);
        assert!(server.memory_usage > mobile.memory_usage);
    }

    #[test]
    fn test_accuracy_metrics() {
        let mobile = AccuracyMetrics::mobile();
        let high_acc = AccuracyMetrics::high_accuracy();

        assert!(high_acc.top1_accuracy.unwrap() > mobile.top1_accuracy.unwrap());
        assert!(high_acc.f1_score.unwrap() > mobile.f1_score.unwrap());
    }

    #[test]
    fn test_model_metadata() {
        let mut metadata = ModelMetadata::classification("ResNet50", "1.0.0");
        assert_eq!(metadata.name, "ResNet50");
        assert_eq!(metadata.version, "1.0.0");

        metadata.add_tag("pretrained".to_string());
        assert!(metadata.has_tag("pretrained"));
        assert!(metadata.has_tag("classification"));

        let before = metadata.modified_at;
        metadata.touch();
        // The system clock may not advance between two back-to-back reads on
        // coarse-resolution platforms, so a strict `>` would be flaky.
        assert!(metadata.modified_at >= before);
    }

    #[test]
    fn test_processor_config() {
        let nms = ProcessorConfig::nms(0.5, 0.4);
        assert_eq!(nms.processor_type, ProcessorType::NonMaximumSuppression);
        assert_eq!(nms.priority, 10);
        assert!(nms.enabled);

        let filtering = ProcessorConfig::filtering(0.3);
        assert_eq!(filtering.processor_type, ProcessorType::Filtering);
        assert_eq!(filtering.priority, 5);
    }

    #[test]
    fn test_deployment_requirements() {
        let standard = DeploymentRequirements::standard();
        assert!(!standard.dependencies.is_empty());
        assert!(standard.hardware_requirements.min_ram > 0);

        let minimal = DeploymentRequirements::minimal();
        assert!(minimal.dependencies.is_empty());
        assert!(minimal.hardware_requirements.min_ram < standard.hardware_requirements.min_ram);
    }

    #[test]
    fn test_model_error_display() {
        let error = ModelError::InvalidInputSpec("Image size must be positive".to_string());
        let error_str = error.to_string();
        assert!(error_str.contains("Invalid input specification"));
        assert!(error_str.contains("Image size must be positive"));

        let error = ModelError::UnsupportedPlatform("windows-arm".to_string());
        let error_str = error.to_string();
        assert!(error_str.contains("Unsupported platform"));
        assert!(error_str.contains("windows-arm"));
    }
}