sklears_compose/cv_pipelines/
image_specification.rs

1//! Image specification and data type management
2//!
3//! This module provides comprehensive image specification handling including
4//! format validation, color space management, normalization parameters,
5//! and data type conversions for computer vision pipelines.
6
7use super::types_config::{ColorSpace, ImageDataType, ImageFormat};
8use scirs2_core::ndarray::{Array1, Array3};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
/// Comprehensive image specification for input validation and processing
///
/// Describes the size, channel count, element type, and color space an
/// input image is expected to have. `ImageSpecification::validate` compares
/// an `ImageData` against `dimensions`, `channels`, `dtype`, and
/// `color_space`, then delegates to the nested `validation` rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageSpecification {
    /// Expected image dimensions (height, width)
    ///
    /// `None` disables the exact-size check in `validate` and makes
    /// `memory_requirements` report `0`.
    pub dimensions: Option<(usize, usize)>,
    /// Number of channels (1=grayscale, 3=RGB, 4=RGBA)
    pub channels: usize,
    /// Data type (uint8, float32, etc.)
    pub dtype: ImageDataType,
    /// Color space (RGB, BGR, HSV, LAB, etc.)
    pub color_space: ColorSpace,
    /// Bit depth per channel
    ///
    /// Read by `memory_requirements`; `validate` does not inspect it.
    pub bit_depth: u8,
    /// Supported input formats
    ///
    /// Queried through `supports_format`; `validate` does not inspect it.
    pub supported_formats: Vec<ImageFormat>,
    /// Validation requirements
    ///
    /// Applied by `validate` after the dimension, channel, data type, and
    /// color space checks have passed.
    pub validation: ImageValidationSpec,
    /// Preprocessing requirements
    ///
    /// Free-form labels; nothing in this file interprets them.
    pub preprocessing_requirements: Vec<String>,
}
32
33impl Default for ImageSpecification {
34    fn default() -> Self {
35        Self {
36            dimensions: Some((224, 224)),
37            channels: 3,
38            dtype: ImageDataType::UInt8,
39            color_space: ColorSpace::RGB,
40            bit_depth: 8,
41            supported_formats: vec![ImageFormat::JPEG, ImageFormat::PNG, ImageFormat::BMP],
42            validation: ImageValidationSpec::default(),
43            preprocessing_requirements: vec![],
44        }
45    }
46}
47
48impl ImageSpecification {
49    /// Create a new image specification with specified dimensions
50    #[must_use]
51    pub fn new(width: usize, height: usize, channels: usize) -> Self {
52        Self {
53            dimensions: Some((height, width)),
54            channels,
55            ..Default::default()
56        }
57    }
58
59    /// Create specification for grayscale images
60    #[must_use]
61    pub fn grayscale(width: usize, height: usize) -> Self {
62        Self {
63            dimensions: Some((height, width)),
64            channels: 1,
65            color_space: ColorSpace::Grayscale,
66            ..Default::default()
67        }
68    }
69
70    /// Create specification for RGB images
71    #[must_use]
72    pub fn rgb(width: usize, height: usize) -> Self {
73        Self {
74            dimensions: Some((height, width)),
75            channels: 3,
76            color_space: ColorSpace::RGB,
77            ..Default::default()
78        }
79    }
80
81    /// Create specification for RGBA images with alpha channel
82    #[must_use]
83    pub fn rgba(width: usize, height: usize) -> Self {
84        Self {
85            dimensions: Some((height, width)),
86            channels: 4,
87            color_space: ColorSpace::RGB,
88            ..Default::default()
89        }
90    }
91
92    /// Create specification for high dynamic range images
93    #[must_use]
94    pub fn hdr(width: usize, height: usize) -> Self {
95        Self {
96            dimensions: Some((height, width)),
97            channels: 3,
98            dtype: ImageDataType::Float32,
99            color_space: ColorSpace::RGB,
100            bit_depth: 32,
101            supported_formats: vec![ImageFormat::HDR, ImageFormat::EXR],
102            ..Default::default()
103        }
104    }
105
106    /// Validate an image against this specification
107    pub fn validate(&self, image_data: &ImageData) -> Result<(), ValidationError> {
108        // Check dimensions
109        if let Some((expected_h, expected_w)) = self.dimensions {
110            if image_data.height != expected_h || image_data.width != expected_w {
111                return Err(ValidationError::DimensionMismatch {
112                    expected: (expected_h, expected_w),
113                    actual: (image_data.height, image_data.width),
114                });
115            }
116        }
117
118        // Check channels
119        if image_data.channels != self.channels {
120            return Err(ValidationError::ChannelMismatch {
121                expected: self.channels,
122                actual: image_data.channels,
123            });
124        }
125
126        // Check data type
127        if image_data.dtype != self.dtype {
128            return Err(ValidationError::DataTypeMismatch {
129                expected: self.dtype,
130                actual: image_data.dtype,
131            });
132        }
133
134        // Check color space
135        if image_data.color_space != self.color_space {
136            return Err(ValidationError::ColorSpaceMismatch {
137                expected: self.color_space,
138                actual: image_data.color_space,
139            });
140        }
141
142        // Apply validation rules
143        self.validation.validate(image_data)?;
144
145        Ok(())
146    }
147
148    /// Check if a format is supported
149    #[must_use]
150    pub fn supports_format(&self, format: &ImageFormat) -> bool {
151        self.supported_formats.contains(format)
152    }
153
154    /// Get memory requirements for an image with this specification
155    #[must_use]
156    pub fn memory_requirements(&self) -> usize {
157        if let Some((height, width)) = self.dimensions {
158            let bytes_per_pixel = self.channels * (self.bit_depth as usize / 8);
159            height * width * bytes_per_pixel
160        } else {
161            0 // Unknown dimensions
162        }
163    }
164
165    /// Create specification for object detection tasks
166    #[must_use]
167    pub fn object_detection(dimensions: (usize, usize)) -> Self {
168        Self {
169            dimensions: Some(dimensions),
170            channels: 3,
171            color_space: ColorSpace::RGB,
172            dtype: ImageDataType::UInt8,
173            bit_depth: 8,
174            supported_formats: vec![ImageFormat::JPEG, ImageFormat::PNG],
175            ..Default::default()
176        }
177    }
178
179    /// Create specification for classification tasks
180    #[must_use]
181    pub fn classification(dimensions: (usize, usize)) -> Self {
182        Self {
183            dimensions: Some(dimensions),
184            channels: 3,
185            color_space: ColorSpace::RGB,
186            dtype: ImageDataType::UInt8,
187            bit_depth: 8,
188            supported_formats: vec![ImageFormat::JPEG, ImageFormat::PNG],
189            ..Default::default()
190        }
191    }
192
193    /// Create specification for segmentation tasks
194    #[must_use]
195    pub fn segmentation(dimensions: (usize, usize)) -> Self {
196        Self {
197            dimensions: Some(dimensions),
198            channels: 3,
199            color_space: ColorSpace::RGB,
200            dtype: ImageDataType::UInt8,
201            bit_depth: 8,
202            supported_formats: vec![ImageFormat::PNG, ImageFormat::TIFF],
203            ..Default::default()
204        }
205    }
206}
207
/// Image validation specification
///
/// Range-style rules applied by `ImageValidationSpec::validate`: size
/// bounds, allowed aspect ratios, file-size bounds, and quality thresholds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageValidationSpec {
    /// Minimum allowed dimensions
    ///
    /// Interpreted as `(height, width)`; `None` disables the check.
    pub min_dimensions: Option<(usize, usize)>,
    /// Maximum allowed dimensions
    ///
    /// Interpreted as `(height, width)`; `None` disables the check.
    pub max_dimensions: Option<(usize, usize)>,
    /// Allowed aspect ratios (width/height)
    ///
    /// An empty list accepts any aspect ratio.
    pub allowed_aspect_ratios: Vec<f64>,
    /// Maximum file size in bytes
    ///
    /// Only checked when the image reports a `file_size`.
    pub max_file_size: Option<usize>,
    /// Minimum file size in bytes
    ///
    /// Only checked when the image reports a `file_size`.
    pub min_file_size: Option<usize>,
    /// Required metadata fields
    ///
    /// NOTE(review): `validate` in this file never reads this field, so
    /// missing metadata is not rejected here — confirm whether another
    /// component enforces it.
    pub required_metadata: Vec<String>,
    /// Quality thresholds
    pub quality_thresholds: QualityThresholds,
}
226
227impl Default for ImageValidationSpec {
228    fn default() -> Self {
229        Self {
230            min_dimensions: Some((32, 32)),
231            max_dimensions: Some((4096, 4096)),
232            allowed_aspect_ratios: vec![], // Empty means any aspect ratio is allowed
233            max_file_size: Some(10 * 1024 * 1024), // 10MB
234            min_file_size: Some(1024),     // 1KB
235            required_metadata: vec![],
236            quality_thresholds: QualityThresholds::default(),
237        }
238    }
239}
240
241impl ImageValidationSpec {
242    /// Validate image data against these specifications
243    pub fn validate(&self, image_data: &ImageData) -> Result<(), ValidationError> {
244        // Check minimum dimensions
245        if let Some((min_h, min_w)) = self.min_dimensions {
246            if image_data.height < min_h || image_data.width < min_w {
247                return Err(ValidationError::DimensionTooSmall {
248                    minimum: (min_h, min_w),
249                    actual: (image_data.height, image_data.width),
250                });
251            }
252        }
253
254        // Check maximum dimensions
255        if let Some((max_h, max_w)) = self.max_dimensions {
256            if image_data.height > max_h || image_data.width > max_w {
257                return Err(ValidationError::DimensionTooLarge {
258                    maximum: (max_h, max_w),
259                    actual: (image_data.height, image_data.width),
260                });
261            }
262        }
263
264        // Check aspect ratio
265        if !self.allowed_aspect_ratios.is_empty() {
266            let aspect_ratio = image_data.width as f64 / image_data.height as f64;
267            let tolerance = 0.01; // 1% tolerance
268
269            let aspect_ratio_valid = self
270                .allowed_aspect_ratios
271                .iter()
272                .any(|&allowed| (aspect_ratio - allowed).abs() < tolerance);
273
274            if !aspect_ratio_valid {
275                return Err(ValidationError::InvalidAspectRatio {
276                    allowed: self.allowed_aspect_ratios.clone(),
277                    actual: aspect_ratio,
278                });
279            }
280        }
281
282        // Check file size if available
283        if let Some(file_size) = image_data.file_size {
284            if let Some(max_size) = self.max_file_size {
285                if file_size > max_size {
286                    return Err(ValidationError::FileTooLarge {
287                        maximum: max_size,
288                        actual: file_size,
289                    });
290                }
291            }
292
293            if let Some(min_size) = self.min_file_size {
294                if file_size < min_size {
295                    return Err(ValidationError::FileTooSmall {
296                        minimum: min_size,
297                        actual: file_size,
298                    });
299                }
300            }
301        }
302
303        // Validate quality thresholds
304        self.quality_thresholds.validate(image_data)?;
305
306        Ok(())
307    }
308}
309
/// Quality thresholds for image validation
///
/// Each field is optional; `None` means that bound is not configured.
/// `QualityThresholds::validate` compares these bounds with values looked up
/// in `ImageData::quality_metrics` rather than computing metrics itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityThresholds {
    /// Minimum brightness level (0.0-1.0)
    ///
    /// Compared with the `"brightness"` quality metric.
    pub min_brightness: Option<f64>,
    /// Maximum brightness level (0.0-1.0)
    ///
    /// Compared with the `"brightness"` quality metric.
    pub max_brightness: Option<f64>,
    /// Minimum contrast level (0.0-1.0)
    ///
    /// Compared with the `"contrast"` quality metric.
    pub min_contrast: Option<f64>,
    /// Maximum blur metric (lower is sharper)
    pub max_blur: Option<f64>,
    /// Minimum signal-to-noise ratio
    pub min_snr: Option<f64>,
}
324
325impl Default for QualityThresholds {
326    fn default() -> Self {
327        Self {
328            min_brightness: Some(0.1),
329            max_brightness: Some(0.9),
330            min_contrast: Some(0.1),
331            max_blur: Some(10.0),
332            min_snr: Some(20.0),
333        }
334    }
335}
336
337impl QualityThresholds {
338    /// Validate image quality against thresholds
339    pub fn validate(&self, image_data: &ImageData) -> Result<(), ValidationError> {
340        // Note: In a real implementation, these would calculate actual metrics
341        // For now, we'll assume the image data includes quality metrics
342
343        if let Some(brightness) = image_data.quality_metrics.get("brightness") {
344            if let Some(min_brightness) = self.min_brightness {
345                if *brightness < min_brightness {
346                    return Err(ValidationError::QualityTooLow {
347                        metric: "brightness".to_string(),
348                        threshold: min_brightness,
349                        actual: *brightness,
350                    });
351                }
352            }
353
354            if let Some(max_brightness) = self.max_brightness {
355                if *brightness > max_brightness {
356                    return Err(ValidationError::QualityTooHigh {
357                        metric: "brightness".to_string(),
358                        threshold: max_brightness,
359                        actual: *brightness,
360                    });
361                }
362            }
363        }
364
365        if let Some(contrast) = image_data.quality_metrics.get("contrast") {
366            if let Some(min_contrast) = self.min_contrast {
367                if *contrast < min_contrast {
368                    return Err(ValidationError::QualityTooLow {
369                        metric: "contrast".to_string(),
370                        threshold: min_contrast,
371                        actual: *contrast,
372                    });
373                }
374            }
375        }
376
377        Ok(())
378    }
379}
380
/// Normalization specification for preprocessing
///
/// `NormalizationSpec::normalize` applies `(value - mean[c]) / std[c]` to
/// each element of channel `c` and then clamps the result to `range`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizationSpec {
    /// Mean values per channel
    ///
    /// Must have one entry per image channel for `normalize` to succeed.
    pub mean: Array1<f32>,
    /// Standard deviation per channel
    ///
    /// Must have one entry per image channel, none of them zero, for
    /// `normalize` to succeed.
    pub std: Array1<f32>,
    /// Value range (min, max) for clipping
    ///
    /// Applied to the value after mean/std scaling.
    pub range: (f32, f32),
    /// Whether to apply per-channel normalization
    ///
    /// NOTE(review): `normalize` in this file does not read this flag —
    /// confirm whether another component uses it.
    pub per_channel: bool,
    /// Whether to apply global normalization
    ///
    /// NOTE(review): `normalize` in this file does not read this flag —
    /// confirm whether another component uses it.
    pub global_normalization: bool,
}
395
396impl Default for NormalizationSpec {
397    fn default() -> Self {
398        // ImageNet normalization values for RGB
399        Self {
400            mean: Array1::from(vec![0.485, 0.456, 0.406]),
401            std: Array1::from(vec![0.229, 0.224, 0.225]),
402            range: (0.0, 1.0),
403            per_channel: true,
404            global_normalization: false,
405        }
406    }
407}
408
409impl NormalizationSpec {
410    /// Create normalization spec for grayscale images
411    #[must_use]
412    pub fn grayscale() -> Self {
413        Self {
414            mean: Array1::from(vec![0.5]),
415            std: Array1::from(vec![0.5]),
416            range: (0.0, 1.0),
417            per_channel: true,
418            global_normalization: false,
419        }
420    }
421
422    /// Create custom normalization spec
423    #[must_use]
424    pub fn custom(mean: Vec<f32>, std: Vec<f32>, range: (f32, f32)) -> Self {
425        Self {
426            mean: Array1::from(mean),
427            std: Array1::from(std),
428            range,
429            per_channel: true,
430            global_normalization: false,
431        }
432    }
433
434    /// Apply normalization to image data
435    pub fn normalize(&self, image: &mut Array3<f32>) -> Result<(), ValidationError> {
436        let (height, width, channels) = image.dim();
437
438        if self.mean.len() != channels || self.std.len() != channels {
439            return Err(ValidationError::NormalizationError(
440                "Mean and std dimensions don't match image channels".to_string(),
441            ));
442        }
443
444        for c in 0..channels {
445            let mean_val = self.mean[c];
446            let std_val = self.std[c];
447
448            if std_val == 0.0 {
449                return Err(ValidationError::NormalizationError(
450                    "Standard deviation cannot be zero".to_string(),
451                ));
452            }
453
454            for h in 0..height {
455                for w in 0..width {
456                    let pixel_value = (image[[h, w, c]] - mean_val) / std_val;
457                    image[[h, w, c]] = pixel_value.clamp(self.range.0, self.range.1);
458                }
459            }
460        }
461
462        Ok(())
463    }
464}
465
/// Image data structure for validation and processing
///
/// Bundles the pixel tensor with the descriptive properties the validators
/// in this file compare against. The descriptive fields are plain values:
/// `ImageData::new` stores them as given and does not cross-check them
/// against the shape of `data`.
#[derive(Debug, Clone)]
pub struct ImageData {
    /// Image height in pixels
    pub height: usize,
    /// Image width in pixels
    pub width: usize,
    /// Number of channels
    pub channels: usize,
    /// Data type
    ///
    /// Describes the declared element type; `data` itself is always stored
    /// as `f32`.
    pub dtype: ImageDataType,
    /// Color space
    pub color_space: ColorSpace,
    /// Optional file size in bytes
    ///
    /// When `None`, file-size validation is skipped.
    pub file_size: Option<usize>,
    /// Quality metrics
    ///
    /// Keyed by metric name (for example `"brightness"` and `"contrast"`,
    /// which `QualityThresholds::validate` looks up).
    pub quality_metrics: HashMap<String, f64>,
    /// Image tensor data
    ///
    /// Presumably laid out as (height, width, channels), matching how
    /// `NormalizationSpec::normalize` indexes its input — TODO confirm.
    pub data: Array3<f32>,
    /// Optional metadata
    pub metadata: HashMap<String, String>,
}
488
489impl ImageData {
490    /// Create new image data with specified properties
491    #[must_use]
492    pub fn new(
493        height: usize,
494        width: usize,
495        channels: usize,
496        dtype: ImageDataType,
497        color_space: ColorSpace,
498        data: Array3<f32>,
499    ) -> Self {
500        Self {
501            height,
502            width,
503            channels,
504            dtype,
505            color_space,
506            file_size: None,
507            quality_metrics: HashMap::new(),
508            data,
509            metadata: HashMap::new(),
510        }
511    }
512
513    /// Get aspect ratio (width/height)
514    #[must_use]
515    pub fn aspect_ratio(&self) -> f64 {
516        self.width as f64 / self.height as f64
517    }
518
519    /// Get total number of pixels
520    #[must_use]
521    pub fn pixel_count(&self) -> usize {
522        self.height * self.width
523    }
524
525    /// Calculate memory footprint in bytes
526    #[must_use]
527    pub fn memory_footprint(&self) -> usize {
528        let bytes_per_element = match self.dtype {
529            ImageDataType::UInt8 => 1,
530            ImageDataType::UInt16 => 2,
531            ImageDataType::Float32 => 4,
532            ImageDataType::Float64 => 8,
533        };
534        self.height * self.width * self.channels * bytes_per_element
535    }
536}
537
/// Validation errors for image specifications
///
/// Every `(usize, usize)` payload is ordered `(height, width)`, matching how
/// the validators in this file construct them.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationError {
    /// Image dimensions don't match expected values
    DimensionMismatch {
        /// Dimensions required by the specification
        expected: (usize, usize),
        /// Dimensions reported by the image
        actual: (usize, usize),
    },
    /// Number of channels doesn't match
    ChannelMismatch { expected: usize, actual: usize },
    /// Data type doesn't match
    DataTypeMismatch {
        /// Data type required by the specification
        expected: ImageDataType,
        /// Data type reported by the image
        actual: ImageDataType,
    },
    /// Color space doesn't match
    ColorSpaceMismatch {
        /// Color space required by the specification
        expected: ColorSpace,
        /// Color space reported by the image
        actual: ColorSpace,
    },
    /// Image dimensions are too small
    DimensionTooSmall {
        /// Smallest accepted dimensions
        minimum: (usize, usize),
        /// Dimensions reported by the image
        actual: (usize, usize),
    },
    /// Image dimensions are too large
    DimensionTooLarge {
        /// Largest accepted dimensions
        maximum: (usize, usize),
        /// Dimensions reported by the image
        actual: (usize, usize),
    },
    /// Invalid aspect ratio
    ///
    /// `allowed` is the configured list of ratios; `actual` is the image's
    /// width divided by its height.
    InvalidAspectRatio { allowed: Vec<f64>, actual: f64 },
    /// File size too large
    FileTooLarge { maximum: usize, actual: usize },
    /// File size too small
    FileTooSmall { minimum: usize, actual: usize },
    /// Quality metric below threshold
    QualityTooLow {
        /// Name of the metric that failed
        metric: String,
        /// Configured minimum for the metric
        threshold: f64,
        /// Value reported for the metric
        actual: f64,
    },
    /// Quality metric above threshold
    QualityTooHigh {
        /// Name of the metric that failed
        metric: String,
        /// Configured maximum for the metric
        threshold: f64,
        /// Value reported for the metric
        actual: f64,
    },
    /// Normalization error
    ///
    /// Carries a human-readable description of the problem.
    NormalizationError(String),
}
589
590impl std::fmt::Display for ValidationError {
591    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
592        match self {
593            Self::DimensionMismatch { expected, actual } => {
594                write!(
595                    f,
596                    "Dimension mismatch: expected {expected:?}, got {actual:?}"
597                )
598            }
599            Self::ChannelMismatch { expected, actual } => {
600                write!(f, "Channel mismatch: expected {expected}, got {actual}")
601            }
602            Self::DataTypeMismatch { expected, actual } => {
603                write!(
604                    f,
605                    "Data type mismatch: expected {expected:?}, got {actual:?}"
606                )
607            }
608            Self::ColorSpaceMismatch { expected, actual } => {
609                write!(
610                    f,
611                    "Color space mismatch: expected {expected:?}, got {actual:?}"
612                )
613            }
614            Self::DimensionTooSmall { minimum, actual } => {
615                write!(
616                    f,
617                    "Dimensions too small: minimum {minimum:?}, got {actual:?}"
618                )
619            }
620            Self::DimensionTooLarge { maximum, actual } => {
621                write!(
622                    f,
623                    "Dimensions too large: maximum {maximum:?}, got {actual:?}"
624                )
625            }
626            Self::InvalidAspectRatio { allowed, actual } => {
627                write!(f, "Invalid aspect ratio: allowed {allowed:?}, got {actual}")
628            }
629            Self::FileTooLarge { maximum, actual } => {
630                write!(
631                    f,
632                    "File too large: maximum {maximum} bytes, got {actual} bytes"
633                )
634            }
635            Self::FileTooSmall { minimum, actual } => {
636                write!(
637                    f,
638                    "File too small: minimum {minimum} bytes, got {actual} bytes"
639                )
640            }
641            Self::QualityTooLow {
642                metric,
643                threshold,
644                actual,
645            } => {
646                write!(
647                    f,
648                    "Quality too low for {metric}: minimum {threshold}, got {actual}"
649                )
650            }
651            Self::QualityTooHigh {
652                metric,
653                threshold,
654                actual,
655            } => {
656                write!(
657                    f,
658                    "Quality too high for {metric}: maximum {threshold}, got {actual}"
659                )
660            }
661            Self::NormalizationError(msg) => {
662                write!(f, "Normalization error: {msg}")
663            }
664        }
665    }
666}
667
// Empty impl: `ValidationError` relies entirely on the trait's default
// methods and supplies no overrides of its own.
impl std::error::Error for ValidationError {}
669
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;

    /// Builds a zero-filled, three-channel `UInt8`/RGB image of the given size.
    fn blank_rgb_image(height: usize, width: usize) -> ImageData {
        let pixels = Array3::<f32>::zeros((height, width, 3));
        ImageData::new(
            height,
            width,
            3,
            ImageDataType::UInt8,
            ColorSpace::RGB,
            pixels,
        )
    }

    /// Constructors store `(height, width)` and set channels / color space.
    #[test]
    fn test_image_specification_creation() {
        let rgb_spec = ImageSpecification::rgb(640, 480);
        assert_eq!(rgb_spec.dimensions, Some((480, 640)));
        assert_eq!(rgb_spec.channels, 3);
        assert_eq!(rgb_spec.color_space, ColorSpace::RGB);

        let gray_spec = ImageSpecification::grayscale(224, 224);
        assert_eq!(gray_spec.channels, 1);
        assert_eq!(gray_spec.color_space, ColorSpace::Grayscale);
    }

    /// A matching image passes; a differently sized one is rejected.
    #[test]
    fn test_image_validation() {
        let spec = ImageSpecification::rgb(224, 224);

        let matching = blank_rgb_image(224, 224);
        assert!(spec.validate(&matching).is_ok());

        // Test dimension mismatch
        let oversized = blank_rgb_image(256, 256);
        assert!(spec.validate(&oversized).is_err());
    }

    /// Preset and custom specs carry one mean/std entry per channel.
    #[test]
    fn test_normalization_spec() {
        let gray = NormalizationSpec::grayscale();
        assert_eq!(gray.mean.len(), 1);
        assert_eq!(gray.std.len(), 1);

        let mean = vec![0.485, 0.456, 0.406];
        let std = vec![0.229, 0.224, 0.225];
        let custom = NormalizationSpec::custom(mean, std, (0.0, 1.0));
        assert_eq!(custom.mean.len(), 3);
        assert_eq!(custom.std.len(), 3);
    }

    /// Derived properties follow from height, width, channels, and dtype.
    #[test]
    fn test_image_data_properties() {
        let image = blank_rgb_image(480, 640);

        assert_eq!(image.aspect_ratio(), 640.0 / 480.0);
        assert_eq!(image.pixel_count(), 480 * 640);
        // UInt8 = 1 byte per element
        assert_eq!(image.memory_footprint(), 480 * 640 * 3);
    }

    /// The rendered message names the error and includes both size pairs.
    #[test]
    fn test_validation_error_display() {
        let rendered = ValidationError::DimensionMismatch {
            expected: (224, 224),
            actual: (256, 256),
        }
        .to_string();

        for fragment in ["Dimension mismatch", "224", "256"] {
            assert!(rendered.contains(fragment));
        }
    }
}