sklears_compose/cv_pipelines/
model_management.rs

1//! Model management and configuration
2//!
3//! This module provides comprehensive model management capabilities including
4//! model specifications, metadata, performance characteristics, processor
5//! configurations, and quality enhancement for computer vision pipelines.
6
7use super::types_config::{OutputType, ProcessorType, TransformParameter};
8use scirs2_core::ndarray::Array1;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::SystemTime;
12
/// Model input specification for computer vision models
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInputSpec {
    /// Expected input image size (width, height) in pixels
    pub image_size: (usize, usize),
    /// Number of input channels (e.g. 3 for RGB)
    pub channels: usize,
    /// Normalization parameters applied before inference
    pub normalization: NormalizationSpec,
    /// Required preprocessing steps, by name, in execution order
    pub preprocessing: Vec<String>,
    /// Input tensor format (NCHW, NHWC, etc.)
    pub tensor_format: TensorFormat,
    /// Data type requirements for the input tensor
    pub data_type: ModelDataType,
    /// Valid input value range as (min, max)
    pub value_range: (f32, f32),
}
31
32impl Default for ModelInputSpec {
33    fn default() -> Self {
34        Self {
35            image_size: (224, 224),
36            channels: 3,
37            normalization: NormalizationSpec::imagenet(),
38            preprocessing: vec!["resize".to_string(), "normalize".to_string()],
39            tensor_format: TensorFormat::NCHW,
40            data_type: ModelDataType::Float32,
41            value_range: (0.0, 1.0),
42        }
43    }
44}
45
46impl ModelInputSpec {
47    /// Create specification for classification models
48    #[must_use]
49    pub fn classification(image_size: (usize, usize)) -> Self {
50        Self {
51            image_size,
52            channels: 3,
53            normalization: NormalizationSpec::imagenet(),
54            preprocessing: vec![
55                "resize".to_string(),
56                "center_crop".to_string(),
57                "normalize".to_string(),
58            ],
59            tensor_format: TensorFormat::NCHW,
60            data_type: ModelDataType::Float32,
61            value_range: (0.0, 1.0),
62        }
63    }
64
65    /// Create specification for object detection models
66    #[must_use]
67    pub fn object_detection(image_size: (usize, usize)) -> Self {
68        Self {
69            image_size,
70            channels: 3,
71            normalization: NormalizationSpec::coco(),
72            preprocessing: vec![
73                "resize".to_string(),
74                "letterbox".to_string(),
75                "normalize".to_string(),
76            ],
77            tensor_format: TensorFormat::NCHW,
78            data_type: ModelDataType::Float32,
79            value_range: (0.0, 1.0),
80        }
81    }
82
83    /// Create specification for segmentation models
84    #[must_use]
85    pub fn segmentation(image_size: (usize, usize)) -> Self {
86        Self {
87            image_size,
88            channels: 3,
89            normalization: NormalizationSpec::cityscapes(),
90            preprocessing: vec!["resize".to_string(), "normalize".to_string()],
91            tensor_format: TensorFormat::NCHW,
92            data_type: ModelDataType::Float32,
93            value_range: (0.0, 1.0),
94        }
95    }
96
97    /// Validate input specification
98    pub fn validate(&self) -> Result<(), ModelError> {
99        if self.image_size.0 == 0 || self.image_size.1 == 0 {
100            return Err(ModelError::InvalidInputSpec(
101                "Image size must be greater than zero".to_string(),
102            ));
103        }
104
105        if self.channels == 0 {
106            return Err(ModelError::InvalidInputSpec(
107                "Number of channels must be greater than zero".to_string(),
108            ));
109        }
110
111        if self.value_range.0 >= self.value_range.1 {
112            return Err(ModelError::InvalidInputSpec(
113                "Value range minimum must be less than maximum".to_string(),
114            ));
115        }
116
117        Ok(())
118    }
119
120    /// Calculate memory requirements for input
121    #[must_use]
122    pub fn memory_requirements(&self, batch_size: usize) -> usize {
123        let element_size = match self.data_type {
124            ModelDataType::Float32 => 4,
125            ModelDataType::Float16 => 2,
126            ModelDataType::Int8 => 1,
127            ModelDataType::UInt8 => 1,
128        };
129
130        batch_size * self.channels * self.image_size.0 * self.image_size.1 * element_size
131    }
132}
133
/// Normalization specification for model inputs
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizationSpec {
    /// Mean values, one entry per channel
    pub mean: Array1<f32>,
    /// Standard deviation, one entry per channel (should be non-zero)
    pub std: Array1<f32>,
    /// Value range (min, max) for clipping
    pub range: (f32, f32),
    /// Normalization type applied using the parameters above
    pub norm_type: NormalizationType,
}
146
147impl Default for NormalizationSpec {
148    fn default() -> Self {
149        Self::imagenet()
150    }
151}
152
153impl NormalizationSpec {
154    /// `ImageNet` normalization parameters
155    #[must_use]
156    pub fn imagenet() -> Self {
157        Self {
158            mean: Array1::from(vec![0.485, 0.456, 0.406]),
159            std: Array1::from(vec![0.229, 0.224, 0.225]),
160            range: (0.0, 1.0),
161            norm_type: NormalizationType::StandardScore,
162        }
163    }
164
165    /// COCO dataset normalization parameters
166    #[must_use]
167    pub fn coco() -> Self {
168        Self {
169            mean: Array1::from(vec![0.485, 0.456, 0.406]),
170            std: Array1::from(vec![0.229, 0.224, 0.225]),
171            range: (0.0, 1.0),
172            norm_type: NormalizationType::StandardScore,
173        }
174    }
175
176    /// Cityscapes dataset normalization parameters
177    #[must_use]
178    pub fn cityscapes() -> Self {
179        Self {
180            mean: Array1::from(vec![0.485, 0.456, 0.406]),
181            std: Array1::from(vec![0.229, 0.224, 0.225]),
182            range: (0.0, 1.0),
183            norm_type: NormalizationType::StandardScore,
184        }
185    }
186
187    /// Custom normalization parameters
188    #[must_use]
189    pub fn custom(mean: Vec<f32>, std: Vec<f32>, range: (f32, f32)) -> Self {
190        Self {
191            mean: Array1::from(mean),
192            std: Array1::from(std),
193            range,
194            norm_type: NormalizationType::StandardScore,
195        }
196    }
197
198    /// Min-max normalization (0-1)
199    #[must_use]
200    pub fn min_max() -> Self {
201        Self {
202            mean: Array1::from(vec![0.0, 0.0, 0.0]),
203            std: Array1::from(vec![255.0, 255.0, 255.0]),
204            range: (0.0, 1.0),
205            norm_type: NormalizationType::MinMax,
206        }
207    }
208}
209
/// Normalization types
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NormalizationType {
    /// Standard score normalization (z-score): `(x - mean) / std`
    StandardScore,
    /// Min-max normalization into a target range
    MinMax,
    /// L2 (unit-norm) normalization
    L2,
    /// No normalization applied
    None,
}
222
/// Tensor format specifications (memory layout of image tensors)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TensorFormat {
    /// Batch, Channels, Height, Width
    NCHW,
    /// Batch, Height, Width, Channels
    NHWC,
    /// Channels, Height, Width (single image, no batch dimension)
    CHW,
    /// Height, Width, Channels (single image, no batch dimension)
    HWC,
}
235
236impl Default for TensorFormat {
237    fn default() -> Self {
238        Self::NCHW
239    }
240}
241
/// Model data types for input/output tensors
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelDataType {
    /// 32-bit floating point (4 bytes per element)
    Float32,
    /// 16-bit floating point (2 bytes per element)
    Float16,
    /// 8-bit signed integer (1 byte per element)
    Int8,
    /// 8-bit unsigned integer (1 byte per element)
    UInt8,
}
254
255impl Default for ModelDataType {
256    fn default() -> Self {
257        Self::Float32
258    }
259}
260
/// Model output specification
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelOutputSpec {
    /// Output tensor shapes, one entry per output tensor
    pub output_shapes: Vec<Vec<usize>>,
    /// Output types for each tensor (parallel to `output_shapes`)
    pub output_types: Vec<OutputType>,
    /// Post-processing steps, by name, in execution order
    pub postprocessing: Vec<String>,
    /// How the raw outputs should be interpreted
    pub interpretation: OutputInterpretation,
    /// Confidence thresholds keyed by threshold name (e.g. "min_confidence")
    pub confidence_thresholds: HashMap<String, f64>,
}
275
276impl Default for ModelOutputSpec {
277    fn default() -> Self {
278        Self {
279            output_shapes: vec![vec![1000]], // ImageNet classes
280            output_types: vec![OutputType::Classification],
281            postprocessing: vec!["softmax".to_string()],
282            interpretation: OutputInterpretation::default(),
283            confidence_thresholds: HashMap::new(),
284        }
285    }
286}
287
288impl ModelOutputSpec {
289    /// Create specification for classification outputs
290    #[must_use]
291    pub fn classification(num_classes: usize) -> Self {
292        let mut thresholds = HashMap::new();
293        thresholds.insert("min_confidence".to_string(), 0.5);
294
295        Self {
296            output_shapes: vec![vec![num_classes]],
297            output_types: vec![OutputType::Classification],
298            postprocessing: vec!["softmax".to_string(), "argmax".to_string()],
299            interpretation: OutputInterpretation::classification(),
300            confidence_thresholds: thresholds,
301        }
302    }
303
304    /// Create specification for object detection outputs
305    #[must_use]
306    pub fn object_detection() -> Self {
307        let mut thresholds = HashMap::new();
308        thresholds.insert("detection_threshold".to_string(), 0.5);
309        thresholds.insert("nms_threshold".to_string(), 0.4);
310
311        Self {
312            output_shapes: vec![
313                vec![1, 25200, 85], // YOLO-style output
314            ],
315            output_types: vec![OutputType::Detection],
316            postprocessing: vec![
317                "decode_boxes".to_string(),
318                "nms".to_string(),
319                "filter_confidence".to_string(),
320            ],
321            interpretation: OutputInterpretation::object_detection(),
322            confidence_thresholds: thresholds,
323        }
324    }
325
326    /// Create specification for segmentation outputs
327    #[must_use]
328    pub fn segmentation(num_classes: usize, height: usize, width: usize) -> Self {
329        Self {
330            output_shapes: vec![vec![num_classes, height, width]],
331            output_types: vec![OutputType::Segmentation],
332            postprocessing: vec!["argmax".to_string(), "colormap".to_string()],
333            interpretation: OutputInterpretation::segmentation(),
334            confidence_thresholds: HashMap::new(),
335        }
336    }
337}
338
/// Output interpretation configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputInterpretation {
    /// Class labels (for classification/detection)
    pub class_labels: Vec<String>,
    /// Mapping from class index to human-readable label
    pub label_mapping: HashMap<usize, String>,
    /// Human-readable description of the output format
    pub format_description: String,
    /// Units for regression outputs, if applicable
    pub units: Option<String>,
}
351
352impl Default for OutputInterpretation {
353    fn default() -> Self {
354        Self {
355            class_labels: vec![],
356            label_mapping: HashMap::new(),
357            format_description: "Raw model output".to_string(),
358            units: None,
359        }
360    }
361}
362
363impl OutputInterpretation {
364    /// Create interpretation for classification
365    #[must_use]
366    pub fn classification() -> Self {
367        Self {
368            class_labels: vec![],
369            label_mapping: HashMap::new(),
370            format_description: "Class probabilities".to_string(),
371            units: Some("probability".to_string()),
372        }
373    }
374
375    /// Create interpretation for object detection
376    #[must_use]
377    pub fn object_detection() -> Self {
378        Self {
379            class_labels: vec![],
380            label_mapping: HashMap::new(),
381            format_description: "Bounding boxes with class and confidence".to_string(),
382            units: Some("normalized_coordinates".to_string()),
383        }
384    }
385
386    /// Create interpretation for segmentation
387    #[must_use]
388    pub fn segmentation() -> Self {
389        Self {
390            class_labels: vec![],
391            label_mapping: HashMap::new(),
392            format_description: "Per-pixel class predictions".to_string(),
393            units: Some("class_index".to_string()),
394        }
395    }
396}
397
/// Model performance characteristics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPerformance {
    /// Average inference time in milliseconds
    pub inference_time: f64,
    /// Memory usage in bytes
    pub memory_usage: u64,
    /// FLOPs (floating point operations) required per inference
    pub flops: u64,
    /// Model accuracy metrics
    pub accuracy: AccuracyMetrics,
    /// Throughput metrics
    pub throughput: ThroughputMetrics,
    /// Resource utilization during execution
    pub resource_utilization: ModelResourceUtilization,
}
414
415impl Default for ModelPerformance {
416    fn default() -> Self {
417        Self {
418            inference_time: 100.0,           // 100ms
419            memory_usage: 100 * 1024 * 1024, // 100MB
420            flops: 1_000_000_000,            // 1 GFLOP
421            accuracy: AccuracyMetrics::default(),
422            throughput: ThroughputMetrics::default(),
423            resource_utilization: ModelResourceUtilization::default(),
424        }
425    }
426}
427
428impl ModelPerformance {
429    /// Create performance profile for lightweight mobile models
430    #[must_use]
431    pub fn mobile_optimized() -> Self {
432        Self {
433            inference_time: 50.0,           // 50ms
434            memory_usage: 20 * 1024 * 1024, // 20MB
435            flops: 100_000_000,             // 100 MFLOP
436            accuracy: AccuracyMetrics::mobile(),
437            throughput: ThroughputMetrics::mobile(),
438            resource_utilization: ModelResourceUtilization::low(),
439        }
440    }
441
442    /// Create performance profile for server-side models
443    #[must_use]
444    pub fn server_optimized() -> Self {
445        Self {
446            inference_time: 200.0,           // 200ms
447            memory_usage: 500 * 1024 * 1024, // 500MB
448            flops: 10_000_000_000,           // 10 GFLOP
449            accuracy: AccuracyMetrics::high_accuracy(),
450            throughput: ThroughputMetrics::server(),
451            resource_utilization: ModelResourceUtilization::high(),
452        }
453    }
454
455    /// Create performance profile for edge devices
456    #[must_use]
457    pub fn edge_optimized() -> Self {
458        Self {
459            inference_time: 75.0,           // 75ms
460            memory_usage: 50 * 1024 * 1024, // 50MB
461            flops: 500_000_000,             // 500 MFLOP
462            accuracy: AccuracyMetrics::balanced(),
463            throughput: ThroughputMetrics::edge(),
464            resource_utilization: ModelResourceUtilization::moderate(),
465        }
466    }
467}
468
/// Accuracy metrics for model evaluation (all values in [0, 1] fractions)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccuracyMetrics {
    /// Top-1 accuracy (for classification)
    pub top1_accuracy: Option<f64>,
    /// Top-5 accuracy (for classification)
    pub top5_accuracy: Option<f64>,
    /// Mean Average Precision (for detection/segmentation)
    pub map: Option<f64>,
    /// Intersection over Union (for detection/segmentation)
    pub iou: Option<f64>,
    /// F1 score
    pub f1_score: Option<f64>,
    /// Precision
    pub precision: Option<f64>,
    /// Recall
    pub recall: Option<f64>,
    /// Custom metrics keyed by metric name
    pub custom_metrics: HashMap<String, f64>,
}
489
490impl Default for AccuracyMetrics {
491    fn default() -> Self {
492        Self {
493            top1_accuracy: Some(0.75),
494            top5_accuracy: Some(0.92),
495            map: None,
496            iou: None,
497            f1_score: Some(0.75),
498            precision: Some(0.75),
499            recall: Some(0.75),
500            custom_metrics: HashMap::new(),
501        }
502    }
503}
504
505impl AccuracyMetrics {
506    /// Create metrics profile for mobile models
507    #[must_use]
508    pub fn mobile() -> Self {
509        Self {
510            top1_accuracy: Some(0.65),
511            top5_accuracy: Some(0.85),
512            map: None,
513            iou: None,
514            f1_score: Some(0.65),
515            precision: Some(0.70),
516            recall: Some(0.60),
517            custom_metrics: HashMap::new(),
518        }
519    }
520
521    /// Create metrics profile for high-accuracy models
522    #[must_use]
523    pub fn high_accuracy() -> Self {
524        Self {
525            top1_accuracy: Some(0.85),
526            top5_accuracy: Some(0.97),
527            map: Some(0.75),
528            iou: Some(0.80),
529            f1_score: Some(0.85),
530            precision: Some(0.88),
531            recall: Some(0.82),
532            custom_metrics: HashMap::new(),
533        }
534    }
535
536    /// Create balanced accuracy metrics
537    #[must_use]
538    pub fn balanced() -> Self {
539        Self::default()
540    }
541}
542
/// Throughput metrics for model performance
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThroughputMetrics {
    /// Images per second at batch size 1
    pub images_per_second: f64,
    /// Batch processing throughput, keyed by batch size (batch_size -> images/s)
    pub batch_throughput: HashMap<usize, f64>,
    /// Peak (burst) throughput in images per second
    pub peak_throughput: f64,
    /// Sustained (steady-state) throughput in images per second
    pub sustained_throughput: f64,
}
555
556impl Default for ThroughputMetrics {
557    fn default() -> Self {
558        let mut batch_throughput = HashMap::new();
559        batch_throughput.insert(1, 10.0);
560        batch_throughput.insert(8, 40.0);
561        batch_throughput.insert(16, 60.0);
562
563        Self {
564            images_per_second: 10.0,
565            batch_throughput,
566            peak_throughput: 80.0,
567            sustained_throughput: 10.0,
568        }
569    }
570}
571
572impl ThroughputMetrics {
573    /// Create throughput metrics for mobile devices
574    #[must_use]
575    pub fn mobile() -> Self {
576        let mut batch_throughput = HashMap::new();
577        batch_throughput.insert(1, 20.0);
578        batch_throughput.insert(4, 30.0);
579
580        Self {
581            images_per_second: 20.0,
582            batch_throughput,
583            peak_throughput: 35.0,
584            sustained_throughput: 18.0,
585        }
586    }
587
588    /// Create throughput metrics for server deployments
589    #[must_use]
590    pub fn server() -> Self {
591        let mut batch_throughput = HashMap::new();
592        batch_throughput.insert(1, 5.0);
593        batch_throughput.insert(8, 30.0);
594        batch_throughput.insert(16, 50.0);
595        batch_throughput.insert(32, 80.0);
596        batch_throughput.insert(64, 100.0);
597
598        Self {
599            images_per_second: 5.0,
600            batch_throughput,
601            peak_throughput: 120.0,
602            sustained_throughput: 4.0,
603        }
604    }
605
606    /// Create throughput metrics for edge devices
607    #[must_use]
608    pub fn edge() -> Self {
609        let mut batch_throughput = HashMap::new();
610        batch_throughput.insert(1, 13.0);
611        batch_throughput.insert(4, 20.0);
612        batch_throughput.insert(8, 25.0);
613
614        Self {
615            images_per_second: 13.0,
616            batch_throughput,
617            peak_throughput: 28.0,
618            sustained_throughput: 12.0,
619        }
620    }
621}
622
/// Resource utilization for model execution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResourceUtilization {
    /// CPU utilization percentage (0-100)
    pub cpu_utilization: f64,
    /// GPU utilization percentage (0-100), if a GPU is used
    pub gpu_utilization: Option<f64>,
    /// Memory utilization percentage (0-100)
    pub memory_utilization: f64,
    /// Power consumption in watts, if known
    pub power_consumption: Option<f64>,
    /// Thermal characteristics
    pub thermal_profile: ThermalProfile,
}
637
638impl Default for ModelResourceUtilization {
639    fn default() -> Self {
640        Self {
641            cpu_utilization: 50.0,
642            gpu_utilization: Some(60.0),
643            memory_utilization: 40.0,
644            power_consumption: Some(15.0),
645            thermal_profile: ThermalProfile::default(),
646        }
647    }
648}
649
650impl ModelResourceUtilization {
651    /// Create low resource utilization profile
652    #[must_use]
653    pub fn low() -> Self {
654        Self {
655            cpu_utilization: 20.0,
656            gpu_utilization: Some(25.0),
657            memory_utilization: 15.0,
658            power_consumption: Some(5.0),
659            thermal_profile: ThermalProfile::cool(),
660        }
661    }
662
663    /// Create moderate resource utilization profile
664    #[must_use]
665    pub fn moderate() -> Self {
666        Self::default()
667    }
668
669    /// Create high resource utilization profile
670    #[must_use]
671    pub fn high() -> Self {
672        Self {
673            cpu_utilization: 80.0,
674            gpu_utilization: Some(90.0),
675            memory_utilization: 70.0,
676            power_consumption: Some(50.0),
677            thermal_profile: ThermalProfile::hot(),
678        }
679    }
680}
681
/// Thermal profile for model execution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThermalProfile {
    /// Operating temperature range (min, max) in Celsius
    pub temperature_range: (f32, f32),
    /// Thermal design power in watts, if known
    pub tdp: Option<f32>,
    /// Cooling requirements to stay within the temperature range
    pub cooling_requirements: CoolingRequirements,
}
692
693impl Default for ThermalProfile {
694    fn default() -> Self {
695        Self {
696            temperature_range: (20.0, 70.0),
697            tdp: Some(15.0),
698            cooling_requirements: CoolingRequirements::Passive,
699        }
700    }
701}
702
703impl ThermalProfile {
704    /// Create cool thermal profile for low-power models
705    #[must_use]
706    pub fn cool() -> Self {
707        Self {
708            temperature_range: (15.0, 50.0),
709            tdp: Some(5.0),
710            cooling_requirements: CoolingRequirements::None,
711        }
712    }
713
714    /// Create hot thermal profile for high-performance models
715    #[must_use]
716    pub fn hot() -> Self {
717        Self {
718            temperature_range: (25.0, 85.0),
719            tdp: Some(50.0),
720            cooling_requirements: CoolingRequirements::Active,
721        }
722    }
723}
724
/// Cooling requirements for model execution
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CoolingRequirements {
    /// No cooling required
    None,
    /// Passive cooling (heat sinks, no moving parts)
    Passive,
    /// Active cooling (fans)
    Active,
    /// Liquid cooling
    Liquid,
}
737
/// Model metadata for management and deployment
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Model name
    pub name: String,
    /// Model version string (e.g. "1.0.0")
    pub version: String,
    /// Author/organization
    pub author: String,
    /// Model description
    pub description: String,
    /// Training dataset information, if known
    pub training_dataset: Option<String>,
    /// License information, if known
    pub license: Option<String>,
    /// Model file size in bytes
    pub model_size: u64,
    /// Supported platforms/architectures (e.g. "cpu", "gpu")
    pub platforms: Vec<String>,
    /// Creation timestamp
    pub created_at: SystemTime,
    /// Last modified timestamp (see `touch`)
    pub modified_at: SystemTime,
    /// Model tags for categorization
    pub tags: Vec<String>,
    /// Deployment requirements
    pub deployment_requirements: DeploymentRequirements,
}
766
767impl Default for ModelMetadata {
768    fn default() -> Self {
769        Self {
770            name: "Unknown Model".to_string(),
771            version: "1.0.0".to_string(),
772            author: "Unknown".to_string(),
773            description: "Computer vision model".to_string(),
774            training_dataset: None,
775            license: None,
776            model_size: 0,
777            platforms: vec!["cpu".to_string()],
778            created_at: SystemTime::now(),
779            modified_at: SystemTime::now(),
780            tags: vec![],
781            deployment_requirements: DeploymentRequirements::default(),
782        }
783    }
784}
785
786impl ModelMetadata {
787    /// Create metadata for a classification model
788    #[must_use]
789    pub fn classification(name: &str, version: &str) -> Self {
790        Self {
791            name: name.to_string(),
792            version: version.to_string(),
793            author: "Unknown".to_string(),
794            description: "Image classification model".to_string(),
795            training_dataset: Some("ImageNet".to_string()),
796            license: Some("MIT".to_string()),
797            model_size: 100 * 1024 * 1024, // 100MB
798            platforms: vec!["cpu".to_string(), "gpu".to_string()],
799            created_at: SystemTime::now(),
800            modified_at: SystemTime::now(),
801            tags: vec!["classification".to_string(), "vision".to_string()],
802            deployment_requirements: DeploymentRequirements::standard(),
803        }
804    }
805
806    /// Update modification timestamp
807    pub fn touch(&mut self) {
808        self.modified_at = SystemTime::now();
809    }
810
811    /// Add a tag if not already present
812    pub fn add_tag(&mut self, tag: String) {
813        if !self.tags.contains(&tag) {
814            self.tags.push(tag);
815        }
816    }
817
818    /// Check if model has a specific tag
819    #[must_use]
820    pub fn has_tag(&self, tag: &str) -> bool {
821        self.tags.contains(&tag.to_string())
822    }
823}
824
/// Deployment requirements for models
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeploymentRequirements {
    /// Minimum runtime version required (semver string)
    pub min_runtime_version: String,
    /// Required dependencies, by name
    pub dependencies: Vec<String>,
    /// Hardware requirements
    pub hardware_requirements: HardwareRequirements,
    /// Environment variables (name -> value) required at runtime
    pub environment_variables: HashMap<String, String>,
    /// Configuration files needed alongside the model
    pub config_files: Vec<String>,
}
839
840impl Default for DeploymentRequirements {
841    fn default() -> Self {
842        Self {
843            min_runtime_version: "1.0.0".to_string(),
844            dependencies: vec![],
845            hardware_requirements: HardwareRequirements::default(),
846            environment_variables: HashMap::new(),
847            config_files: vec![],
848        }
849    }
850}
851
852impl DeploymentRequirements {
853    /// Create standard deployment requirements
854    #[must_use]
855    pub fn standard() -> Self {
856        Self {
857            min_runtime_version: "1.0.0".to_string(),
858            dependencies: vec!["opencv".to_string(), "numpy".to_string()],
859            hardware_requirements: HardwareRequirements::standard(),
860            environment_variables: HashMap::new(),
861            config_files: vec!["model_config.json".to_string()],
862        }
863    }
864
865    /// Create minimal deployment requirements
866    #[must_use]
867    pub fn minimal() -> Self {
868        Self {
869            min_runtime_version: "1.0.0".to_string(),
870            dependencies: vec![],
871            hardware_requirements: HardwareRequirements::minimal(),
872            environment_variables: HashMap::new(),
873            config_files: vec![],
874        }
875    }
876}
877
/// Hardware requirements for model deployment
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HardwareRequirements {
    /// Minimum RAM in bytes
    pub min_ram: u64,
    /// Minimum storage in bytes
    pub min_storage: u64,
    /// Required CPU features (e.g. "sse4.1")
    pub cpu_features: Vec<String>,
    /// GPU requirements, if a GPU is needed
    pub gpu_requirements: Option<GpuRequirements>,
    /// Supported CPU architectures (e.g. "x86_64", "arm64")
    pub architectures: Vec<String>,
}
892
893impl Default for HardwareRequirements {
894    fn default() -> Self {
895        Self {
896            min_ram: 1024 * 1024 * 1024,    // 1GB
897            min_storage: 500 * 1024 * 1024, // 500MB
898            cpu_features: vec!["sse4.1".to_string()],
899            gpu_requirements: None,
900            architectures: vec!["x86_64".to_string(), "arm64".to_string()],
901        }
902    }
903}
904
905impl HardwareRequirements {
906    /// Create standard hardware requirements
907    #[must_use]
908    pub fn standard() -> Self {
909        Self::default()
910    }
911
912    /// Create minimal hardware requirements
913    #[must_use]
914    pub fn minimal() -> Self {
915        Self {
916            min_ram: 256 * 1024 * 1024,     // 256MB
917            min_storage: 100 * 1024 * 1024, // 100MB
918            cpu_features: vec![],
919            gpu_requirements: None,
920            architectures: vec!["x86_64".to_string(), "arm64".to_string()],
921        }
922    }
923}
924
/// GPU requirements for models
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuRequirements {
    /// Minimum GPU memory in bytes
    pub min_gpu_memory: u64,
    /// Required compute capability (e.g. "6.0"), if applicable
    pub compute_capability: Option<String>,
    /// Supported GPU vendors (e.g. "NVIDIA", "AMD")
    pub vendors: Vec<String>,
    /// Minimum GPU driver version, if applicable
    pub min_driver_version: Option<String>,
}
937
938impl Default for GpuRequirements {
939    fn default() -> Self {
940        Self {
941            min_gpu_memory: 2 * 1024 * 1024 * 1024, // 2GB
942            compute_capability: Some("6.0".to_string()),
943            vendors: vec!["NVIDIA".to_string(), "AMD".to_string()],
944            min_driver_version: Some("450.0".to_string()),
945        }
946    }
947}
948
/// Post-processor configuration for model outputs
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessorConfig {
    /// Processor type
    pub processor_type: ProcessorType,
    /// Processing parameters keyed by parameter name
    pub parameters: HashMap<String, TransformParameter>,
    /// Quality enhancement settings
    pub quality_enhancement: QualityEnhancementConfig,
    /// Processing order/priority (higher runs with higher priority)
    pub priority: i32,
    /// Enable/disable flag for this processor
    pub enabled: bool,
}
963
964impl Default for ProcessorConfig {
965    fn default() -> Self {
966        Self {
967            processor_type: ProcessorType::Filtering,
968            parameters: HashMap::new(),
969            quality_enhancement: QualityEnhancementConfig::default(),
970            priority: 0,
971            enabled: true,
972        }
973    }
974}
975
976impl ProcessorConfig {
977    /// Create NMS processor configuration
978    #[must_use]
979    pub fn nms(confidence_threshold: f64, iou_threshold: f64) -> Self {
980        let mut parameters = HashMap::new();
981        parameters.insert(
982            "confidence_threshold".to_string(),
983            TransformParameter::Float(confidence_threshold),
984        );
985        parameters.insert(
986            "iou_threshold".to_string(),
987            TransformParameter::Float(iou_threshold),
988        );
989
990        Self {
991            processor_type: ProcessorType::NonMaximumSuppression,
992            parameters,
993            quality_enhancement: QualityEnhancementConfig::default(),
994            priority: 10,
995            enabled: true,
996        }
997    }
998
999    /// Create filtering processor configuration
1000    #[must_use]
1001    pub fn filtering(min_confidence: f64) -> Self {
1002        let mut parameters = HashMap::new();
1003        parameters.insert(
1004            "min_confidence".to_string(),
1005            TransformParameter::Float(min_confidence),
1006        );
1007
1008        Self {
1009            processor_type: ProcessorType::Filtering,
1010            parameters,
1011            quality_enhancement: QualityEnhancementConfig::default(),
1012            priority: 5,
1013            enabled: true,
1014        }
1015    }
1016
1017    /// Create tracking processor configuration
1018    #[must_use]
1019    pub fn tracking() -> Self {
1020        let mut parameters = HashMap::new();
1021        parameters.insert("max_disappeared".to_string(), TransformParameter::Int(30));
1022        parameters.insert("max_distance".to_string(), TransformParameter::Float(50.0));
1023
1024        Self {
1025            processor_type: ProcessorType::Tracking,
1026            parameters,
1027            quality_enhancement: QualityEnhancementConfig::default(),
1028            priority: 20,
1029            enabled: true,
1030        }
1031    }
1032}
1033
/// Quality enhancement configuration for post-processing
///
/// Optional output clean-up applied as part of a processor stage; every
/// feature is off in the `Default` configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityEnhancementConfig {
    /// Confidence threshold adjustment (0.0 = no adjustment)
    // NOTE(review): assumed to be an additive offset based on the 0.0 default —
    // confirm against the code that applies it.
    pub confidence_adjustment: f64,
    /// Enable noise filtering
    pub noise_filtering: bool,
    /// Enable outlier removal
    pub outlier_removal: bool,
    /// Enable temporal consistency (for video)
    pub temporal_consistency: bool,
    /// Smoothing parameters
    pub smoothing: SmoothingConfig,
}
1048
1049impl Default for QualityEnhancementConfig {
1050    fn default() -> Self {
1051        Self {
1052            confidence_adjustment: 0.0,
1053            noise_filtering: false,
1054            outlier_removal: false,
1055            temporal_consistency: false,
1056            smoothing: SmoothingConfig::default(),
1057        }
1058    }
1059}
1060
/// Smoothing configuration for quality enhancement
///
/// Controls the optional smoothing pass used by `QualityEnhancementConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmoothingConfig {
    /// Enable smoothing (off by default)
    pub enabled: bool,
    /// Smoothing window size
    // NOTE(review): the window's unit (detections, frames, …) is not visible
    // here — confirm against the smoothing implementation.
    pub window_size: usize,
    /// Smoothing algorithm
    pub algorithm: SmoothingAlgorithm,
    /// Smoothing strength (0.0-1.0)
    pub strength: f64,
}
1073
1074impl Default for SmoothingConfig {
1075    fn default() -> Self {
1076        Self {
1077            enabled: false,
1078            window_size: 5,
1079            algorithm: SmoothingAlgorithm::MovingAverage,
1080            strength: 0.5,
1081        }
1082    }
1083}
1084
/// Smoothing algorithms for post-processing
///
/// Selected via `SmoothingConfig::algorithm`; `MovingAverage` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SmoothingAlgorithm {
    /// Moving average smoothing (the default)
    MovingAverage,
    /// Exponential smoothing
    Exponential,
    /// Gaussian smoothing
    Gaussian,
    /// Median filtering
    Median,
}
1097
/// Model management errors
///
/// Error type for model configuration, loading, inference, and
/// post-processing failures. Each variant carries a human-readable detail
/// message that the `Display` implementation renders as
/// `"<label>: <detail>"`.
// `Eq` is derivable because every payload is a `String`; deriving it lets
// callers use these errors wherever full equivalence is required
// (clippy: `derive_partial_eq_without_eq`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelError {
    /// Invalid input specification
    InvalidInputSpec(String),
    /// Invalid output specification
    InvalidOutputSpec(String),
    /// Model loading error
    LoadingError(String),
    /// Inference error
    InferenceError(String),
    /// Post-processing error
    PostProcessingError(String),
    /// Configuration error
    ConfigurationError(String),
    /// Hardware requirement not met
    HardwareRequirementNotMet(String),
    /// Unsupported platform
    UnsupportedPlatform(String),
}
1118
1119impl std::fmt::Display for ModelError {
1120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1121        match self {
1122            Self::InvalidInputSpec(msg) => write!(f, "Invalid input specification: {msg}"),
1123            Self::InvalidOutputSpec(msg) => write!(f, "Invalid output specification: {msg}"),
1124            Self::LoadingError(msg) => write!(f, "Model loading error: {msg}"),
1125            Self::InferenceError(msg) => write!(f, "Inference error: {msg}"),
1126            Self::PostProcessingError(msg) => write!(f, "Post-processing error: {msg}"),
1127            Self::ConfigurationError(msg) => write!(f, "Configuration error: {msg}"),
1128            Self::HardwareRequirementNotMet(msg) => {
1129                write!(f, "Hardware requirement not met: {msg}")
1130            }
1131            Self::UnsupportedPlatform(platform) => write!(f, "Unsupported platform: {platform}"),
1132        }
1133    }
1134}
1135
1136impl std::error::Error for ModelError {}
1137
// NOTE: the previous blanket `#[allow(non_snake_case)]` was removed — every
// test name below is already snake_case, so the allow suppressed nothing.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_input_spec() {
        let spec = ModelInputSpec::classification((224, 224));
        assert_eq!(spec.image_size, (224, 224));
        assert_eq!(spec.channels, 3);
        assert!(spec.validate().is_ok());

        let memory = spec.memory_requirements(1);
        assert_eq!(memory, 1 * 3 * 224 * 224 * 4); // Float32 = 4 bytes
    }

    #[test]
    fn test_normalization_spec() {
        let imagenet = NormalizationSpec::imagenet();
        assert_eq!(imagenet.mean.len(), 3);
        assert_eq!(imagenet.std.len(), 3);

        let custom =
            NormalizationSpec::custom(vec![0.5, 0.5, 0.5], vec![0.5, 0.5, 0.5], (0.0, 1.0));
        assert_eq!(custom.mean.len(), 3);
        assert_eq!(custom.norm_type, NormalizationType::StandardScore);
    }

    #[test]
    fn test_model_output_spec() {
        let classification = ModelOutputSpec::classification(1000);
        assert_eq!(classification.output_shapes[0], vec![1000]);
        assert_eq!(classification.output_types[0], OutputType::Classification);

        let detection = ModelOutputSpec::object_detection();
        assert_eq!(detection.output_types[0], OutputType::Detection);
        assert!(detection
            .confidence_thresholds
            .contains_key("detection_threshold"));
    }

    #[test]
    fn test_model_performance() {
        let mobile = ModelPerformance::mobile_optimized();
        assert_eq!(mobile.inference_time, 50.0);
        assert!(mobile.memory_usage < 50 * 1024 * 1024);

        let server = ModelPerformance::server_optimized();
        assert!(server.inference_time > mobile.inference_time);
        assert!(server.memory_usage > mobile.memory_usage);
    }

    #[test]
    fn test_accuracy_metrics() {
        let mobile = AccuracyMetrics::mobile();
        let high_acc = AccuracyMetrics::high_accuracy();

        assert!(high_acc.top1_accuracy.unwrap() > mobile.top1_accuracy.unwrap());
        assert!(high_acc.f1_score.unwrap() > mobile.f1_score.unwrap());
    }

    #[test]
    fn test_model_metadata() {
        let mut metadata = ModelMetadata::classification("ResNet50", "1.0.0");
        assert_eq!(metadata.name, "ResNet50");
        assert_eq!(metadata.version, "1.0.0");

        metadata.add_tag("pretrained".to_string());
        assert!(metadata.has_tag("pretrained"));
        assert!(metadata.has_tag("classification"));

        let before = metadata.modified_at;
        metadata.touch();
        // `SystemTime` has platform-dependent granularity, so two calls in
        // quick succession can yield equal timestamps; `>=` avoids a flaky
        // failure while still verifying the timestamp never moves backwards.
        assert!(metadata.modified_at >= before);
    }

    #[test]
    fn test_processor_config() {
        let nms = ProcessorConfig::nms(0.5, 0.4);
        assert_eq!(nms.processor_type, ProcessorType::NonMaximumSuppression);
        assert_eq!(nms.priority, 10);
        assert!(nms.enabled);

        let filtering = ProcessorConfig::filtering(0.3);
        assert_eq!(filtering.processor_type, ProcessorType::Filtering);
        assert_eq!(filtering.priority, 5);
    }

    #[test]
    fn test_deployment_requirements() {
        let standard = DeploymentRequirements::standard();
        assert!(!standard.dependencies.is_empty());
        assert!(standard.hardware_requirements.min_ram > 0);

        let minimal = DeploymentRequirements::minimal();
        assert!(minimal.dependencies.is_empty());
        assert!(minimal.hardware_requirements.min_ram < standard.hardware_requirements.min_ram);
    }

    #[test]
    fn test_model_error_display() {
        let error = ModelError::InvalidInputSpec("Image size must be positive".to_string());
        let error_str = error.to_string();
        assert!(error_str.contains("Invalid input specification"));
        assert!(error_str.contains("Image size must be positive"));

        let error = ModelError::UnsupportedPlatform("windows-arm".to_string());
        let error_str = error.to_string();
        assert!(error_str.contains("Unsupported platform"));
        assert!(error_str.contains("windows-arm"));
    }
}