Skip to main content

ruvector_cnn/contrastive/
augmentation.rs

1//! # Contrastive Augmentation
2//!
3//! SimCLR-style data augmentation for contrastive learning.
4//!
5//! ## Augmentation Pipeline
6//!
7//! The default SimCLR augmentation pipeline includes:
8//! 1. Random resized crop (scale 0.08-1.0, ratio 3/4-4/3)
9//! 2. Random horizontal flip (p=0.5)
10//! 3. Color jitter (brightness, contrast, saturation, hue)
11//! 4. Random grayscale (p=0.2)
12//! 5. Gaussian blur (optional)
13//!
14//! ## References
15//!
16//! - SimCLR: "A Simple Framework for Contrastive Learning of Visual Representations"
17//! - MoCo: "Momentum Contrast for Unsupervised Visual Representation Learning"
18
19#[cfg(feature = "augmentation")]
20use crate::error::{CnnError, CnnResult};
21#[cfg(feature = "augmentation")]
22use image::{DynamicImage, GenericImageView, ImageBuffer, Rgb, RgbImage};
23#[cfg(feature = "augmentation")]
24use rand::Rng;
25use serde::{Deserialize, Serialize};
26
27/// Configuration for contrastive augmentation.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct AugmentationConfig {
30    /// Minimum crop scale (default: 0.08)
31    pub crop_scale_min: f64,
32    /// Maximum crop scale (default: 1.0)
33    pub crop_scale_max: f64,
34    /// Minimum aspect ratio (default: 0.75)
35    pub aspect_ratio_min: f64,
36    /// Maximum aspect ratio (default: 1.333)
37    pub aspect_ratio_max: f64,
38    /// Probability of horizontal flip (default: 0.5)
39    pub horizontal_flip_prob: f64,
40    /// Brightness jitter factor (default: 0.4)
41    pub brightness: f64,
42    /// Contrast jitter factor (default: 0.4)
43    pub contrast: f64,
44    /// Saturation jitter factor (default: 0.4)
45    pub saturation: f64,
46    /// Hue jitter factor (default: 0.1)
47    pub hue: f64,
48    /// Probability of color jitter (default: 0.8)
49    pub color_jitter_prob: f64,
50    /// Probability of grayscale conversion (default: 0.2)
51    pub grayscale_prob: f64,
52    /// Gaussian blur kernel size (0 to disable)
53    pub blur_kernel_size: u32,
54    /// Probability of Gaussian blur (default: 0.5)
55    pub blur_prob: f64,
56    /// Gaussian blur sigma range
57    pub blur_sigma_range: (f64, f64),
58    /// Target output size (width, height)
59    pub output_size: (u32, u32),
60}
61
62impl Default for AugmentationConfig {
63    fn default() -> Self {
64        Self {
65            crop_scale_min: 0.08,
66            crop_scale_max: 1.0,
67            aspect_ratio_min: 0.75,
68            aspect_ratio_max: 4.0 / 3.0,
69            horizontal_flip_prob: 0.5,
70            brightness: 0.4,
71            contrast: 0.4,
72            saturation: 0.4,
73            hue: 0.1,
74            color_jitter_prob: 0.8,
75            grayscale_prob: 0.2,
76            blur_kernel_size: 0,
77            blur_prob: 0.5,
78            blur_sigma_range: (0.1, 2.0),
79            output_size: (224, 224),
80        }
81    }
82}
83
84/// Builder for ContrastiveAugmentation.
85#[derive(Debug, Clone)]
86pub struct ContrastiveAugmentationBuilder {
87    config: AugmentationConfig,
88    seed: Option<u64>,
89}
90
91impl ContrastiveAugmentationBuilder {
92    /// Create a new builder with default config.
93    pub fn new() -> Self {
94        Self {
95            config: AugmentationConfig::default(),
96            seed: None,
97        }
98    }
99
100    /// Set the crop scale range.
101    pub fn crop_scale(mut self, min: f64, max: f64) -> Self {
102        self.config.crop_scale_min = min;
103        self.config.crop_scale_max = max;
104        self
105    }
106
107    /// Set the aspect ratio range.
108    pub fn aspect_ratio(mut self, min: f64, max: f64) -> Self {
109        self.config.aspect_ratio_min = min;
110        self.config.aspect_ratio_max = max;
111        self
112    }
113
114    /// Set the horizontal flip probability.
115    pub fn horizontal_flip_prob(mut self, prob: f64) -> Self {
116        self.config.horizontal_flip_prob = prob;
117        self
118    }
119
120    /// Set the color jitter parameters.
121    pub fn color_jitter(mut self, brightness: f64, contrast: f64, saturation: f64, hue: f64) -> Self {
122        self.config.brightness = brightness;
123        self.config.contrast = contrast;
124        self.config.saturation = saturation;
125        self.config.hue = hue;
126        self
127    }
128
129    /// Set the color jitter probability.
130    pub fn color_jitter_prob(mut self, prob: f64) -> Self {
131        self.config.color_jitter_prob = prob;
132        self
133    }
134
135    /// Set the grayscale probability.
136    pub fn grayscale_prob(mut self, prob: f64) -> Self {
137        self.config.grayscale_prob = prob;
138        self
139    }
140
141    /// Enable Gaussian blur with the specified kernel size.
142    pub fn gaussian_blur(mut self, kernel_size: u32, sigma_range: (f64, f64)) -> Self {
143        self.config.blur_kernel_size = kernel_size;
144        self.config.blur_sigma_range = sigma_range;
145        self
146    }
147
148    /// Set the blur probability.
149    pub fn blur_prob(mut self, prob: f64) -> Self {
150        self.config.blur_prob = prob;
151        self
152    }
153
154    /// Set the output size.
155    pub fn output_size(mut self, width: u32, height: u32) -> Self {
156        self.config.output_size = (width, height);
157        self
158    }
159
160    /// Set a fixed random seed for reproducibility.
161    pub fn seed(mut self, seed: u64) -> Self {
162        self.seed = Some(seed);
163        self
164    }
165
166    /// Build the ContrastiveAugmentation instance.
167    pub fn build(self) -> ContrastiveAugmentation {
168        let rng = if let Some(seed) = self.seed {
169            rand::SeedableRng::seed_from_u64(seed)
170        } else {
171            rand::SeedableRng::from_entropy()
172        };
173        ContrastiveAugmentation {
174            config: self.config,
175            rng,
176        }
177    }
178}
179
180impl Default for ContrastiveAugmentationBuilder {
181    fn default() -> Self {
182        Self::new()
183    }
184}
185
186/// SimCLR-style contrastive augmentation pipeline.
187///
188/// # Example
189///
190/// ```rust,no_run
191/// use ruvector_cnn::contrastive::ContrastiveAugmentation;
192///
193/// let aug = ContrastiveAugmentation::builder()
194///     .crop_scale(0.08, 1.0)
195///     .horizontal_flip_prob(0.5)
196///     .color_jitter(0.4, 0.4, 0.4, 0.1)
197///     .output_size(224, 224)
198///     .build();
199///
200/// // Generate two augmented views of an image (requires augmentation feature)
201/// // let (view1, view2) = aug.generate_pair(&image)?;
202/// ```
203#[derive(Debug, Clone)]
204pub struct ContrastiveAugmentation {
205    config: AugmentationConfig,
206    /// Random number generator for stochastic augmentations
207    #[allow(dead_code)]
208    rng: rand::rngs::StdRng,
209}
210
211impl ContrastiveAugmentation {
212    /// Create a builder for ContrastiveAugmentation.
213    pub fn builder() -> ContrastiveAugmentationBuilder {
214        ContrastiveAugmentationBuilder::new()
215    }
216
217    /// Get the current configuration.
218    pub fn config(&self) -> &AugmentationConfig {
219        &self.config
220    }
221
222    /// Generate two augmented views of an image.
223    ///
224    /// This is the core operation for SimCLR-style contrastive learning.
225    ///
226    /// # Arguments
227    ///
228    /// * `image` - The input image
229    ///
230    /// # Returns
231    ///
232    /// A tuple of two independently augmented views.
233    #[cfg(feature = "augmentation")]
234    pub fn generate_pair(&mut self, image: &DynamicImage) -> CnnResult<(RgbImage, RgbImage)> {
235        let view1 = self.augment(image)?;
236        let view2 = self.augment(image)?;
237        Ok((view1, view2))
238    }
239
240    /// Apply the full augmentation pipeline to an image.
241    #[cfg(feature = "augmentation")]
242    pub fn augment(&mut self, image: &DynamicImage) -> CnnResult<RgbImage> {
243        let mut img = image.to_rgb8();
244
245        // 1. Random resized crop
246        img = self.random_resized_crop(&img)?;
247
248        // 2. Random horizontal flip
249        if self.rng.gen::<f64>() < self.config.horizontal_flip_prob {
250            img = self.horizontal_flip(&img);
251        }
252
253        // 3. Color jitter (with probability)
254        if self.rng.gen::<f64>() < self.config.color_jitter_prob {
255            img = self.color_jitter(&img)?;
256        }
257
258        // 4. Random grayscale
259        if self.rng.gen::<f64>() < self.config.grayscale_prob {
260            img = self.to_grayscale(&img);
261        }
262
263        // 5. Gaussian blur (optional)
264        if self.config.blur_kernel_size > 0 && self.rng.gen::<f64>() < self.config.blur_prob {
265            img = self.gaussian_blur(&img)?;
266        }
267
268        Ok(img)
269    }
270
271    /// Random resized crop with configurable scale and aspect ratio.
272    #[cfg(feature = "augmentation")]
273    pub fn random_resized_crop(&mut self, image: &RgbImage) -> CnnResult<RgbImage> {
274        let (orig_w, orig_h) = image.dimensions();
275        let orig_area = (orig_w * orig_h) as f64;
276
277        // Try up to 10 times to find a valid crop
278        for _ in 0..10 {
279            // Sample scale and aspect ratio
280            let scale = self.rng.gen_range(self.config.crop_scale_min..=self.config.crop_scale_max);
281            let aspect = self.rng.gen_range(
282                self.config.aspect_ratio_min.ln()..=self.config.aspect_ratio_max.ln(),
283            ).exp();
284
285            // Compute crop dimensions
286            let crop_area = orig_area * scale;
287            let crop_w = (crop_area * aspect).sqrt() as u32;
288            let crop_h = (crop_area / aspect).sqrt() as u32;
289
290            if crop_w <= orig_w && crop_h <= orig_h && crop_w > 0 && crop_h > 0 {
291                // Random position
292                let x = self.rng.gen_range(0..=(orig_w - crop_w));
293                let y = self.rng.gen_range(0..=(orig_h - crop_h));
294
295                // Crop
296                let cropped = image::imageops::crop_imm(image, x, y, crop_w, crop_h).to_image();
297
298                // Resize to output size
299                let (target_w, target_h) = self.config.output_size;
300                let resized = image::imageops::resize(
301                    &cropped,
302                    target_w,
303                    target_h,
304                    image::imageops::FilterType::Lanczos3,
305                );
306
307                return Ok(resized);
308            }
309        }
310
311        // Fallback: center crop to maintain aspect ratio, then resize
312        let (target_w, target_h) = self.config.output_size;
313        let target_ratio = target_w as f64 / target_h as f64;
314        let orig_ratio = orig_w as f64 / orig_h as f64;
315
316        let (crop_w, crop_h) = if orig_ratio > target_ratio {
317            // Original is wider - crop width
318            let h = orig_h;
319            let w = (h as f64 * target_ratio) as u32;
320            (w, h)
321        } else {
322            // Original is taller - crop height
323            let w = orig_w;
324            let h = (w as f64 / target_ratio) as u32;
325            (w, h)
326        };
327
328        let x = (orig_w - crop_w) / 2;
329        let y = (orig_h - crop_h) / 2;
330
331        let cropped = image::imageops::crop_imm(image, x, y, crop_w, crop_h).to_image();
332        let resized = image::imageops::resize(
333            &cropped,
334            target_w,
335            target_h,
336            image::imageops::FilterType::Lanczos3,
337        );
338
339        Ok(resized)
340    }
341
342    /// Horizontal flip.
343    #[cfg(feature = "augmentation")]
344    pub fn horizontal_flip(&self, image: &RgbImage) -> RgbImage {
345        image::imageops::flip_horizontal(image)
346    }
347
348    /// Color jitter: randomly adjust brightness, contrast, saturation, and hue.
349    #[cfg(feature = "augmentation")]
350    pub fn color_jitter(&mut self, image: &RgbImage) -> CnnResult<RgbImage> {
351        let (width, height) = image.dimensions();
352        let mut result = image.clone();
353
354        // Sample jitter factors
355        let brightness_factor = 1.0 + self.rng.gen_range(-self.config.brightness..=self.config.brightness);
356        let contrast_factor = 1.0 + self.rng.gen_range(-self.config.contrast..=self.config.contrast);
357        let saturation_factor = 1.0 + self.rng.gen_range(-self.config.saturation..=self.config.saturation);
358        let hue_shift = self.rng.gen_range(-self.config.hue..=self.config.hue);
359
360        // Compute image mean for contrast adjustment
361        let mean = self.compute_mean(image);
362
363        for y in 0..height {
364            for x in 0..width {
365                let pixel = image.get_pixel(x, y);
366                let mut rgb = [pixel[0] as f64 / 255.0, pixel[1] as f64 / 255.0, pixel[2] as f64 / 255.0];
367
368                // Apply brightness
369                for c in rgb.iter_mut() {
370                    *c *= brightness_factor;
371                }
372
373                // Apply contrast
374                for (i, c) in rgb.iter_mut().enumerate() {
375                    *c = (*c - mean[i]) * contrast_factor + mean[i];
376                }
377
378                // Apply saturation and hue in HSV space
379                let (h, s, v) = rgb_to_hsv(rgb[0], rgb[1], rgb[2]);
380                let new_s = (s * saturation_factor).clamp(0.0, 1.0);
381                let new_h = (h + hue_shift * 360.0).rem_euclid(360.0);
382                let (r, g, b) = hsv_to_rgb(new_h, new_s, v);
383
384                rgb = [r, g, b];
385
386                // Clamp and convert back to u8
387                let out_pixel = Rgb([
388                    (rgb[0] * 255.0).clamp(0.0, 255.0) as u8,
389                    (rgb[1] * 255.0).clamp(0.0, 255.0) as u8,
390                    (rgb[2] * 255.0).clamp(0.0, 255.0) as u8,
391                ]);
392                result.put_pixel(x, y, out_pixel);
393            }
394        }
395
396        Ok(result)
397    }
398
399    /// Convert to grayscale (but keep 3 channels).
400    #[cfg(feature = "augmentation")]
401    pub fn to_grayscale(&self, image: &RgbImage) -> RgbImage {
402        let (width, height) = image.dimensions();
403        let mut result = ImageBuffer::new(width, height);
404
405        for y in 0..height {
406            for x in 0..width {
407                let pixel = image.get_pixel(x, y);
408                // Luminance formula: 0.299*R + 0.587*G + 0.114*B
409                let gray = (0.299 * pixel[0] as f64
410                    + 0.587 * pixel[1] as f64
411                    + 0.114 * pixel[2] as f64) as u8;
412                result.put_pixel(x, y, Rgb([gray, gray, gray]));
413            }
414        }
415
416        result
417    }
418
419    /// Gaussian blur (simplified box blur implementation).
420    #[cfg(feature = "augmentation")]
421    pub fn gaussian_blur(&mut self, image: &RgbImage) -> CnnResult<RgbImage> {
422        let sigma = self.rng.gen_range(self.config.blur_sigma_range.0..=self.config.blur_sigma_range.1);
423
424        // Use kernel size from config, or compute from sigma
425        let kernel_size = if self.config.blur_kernel_size > 0 {
426            self.config.blur_kernel_size
427        } else {
428            let k = (sigma * 6.0).ceil() as u32;
429            if k % 2 == 0 { k + 1 } else { k }
430        };
431
432        // Generate Gaussian kernel
433        let kernel = self.generate_gaussian_kernel(kernel_size, sigma);
434
435        // Apply separable convolution
436        let blurred = self.convolve_separable(image, &kernel)?;
437
438        Ok(blurred)
439    }
440
441    /// Generate 1D Gaussian kernel.
442    #[cfg(feature = "augmentation")]
443    fn generate_gaussian_kernel(&self, size: u32, sigma: f64) -> Vec<f64> {
444        let size = size as i32;
445        let center = size / 2;
446        let mut kernel = Vec::with_capacity(size as usize);
447        let mut sum = 0.0;
448
449        let sigma_sq_2 = 2.0 * sigma * sigma;
450
451        for i in 0..size {
452            let x = (i - center) as f64;
453            let value = (-x * x / sigma_sq_2).exp();
454            kernel.push(value);
455            sum += value;
456        }
457
458        // Normalize
459        for k in kernel.iter_mut() {
460            *k /= sum;
461        }
462
463        kernel
464    }
465
466    /// Apply separable convolution (horizontal then vertical pass).
467    #[cfg(feature = "augmentation")]
468    fn convolve_separable(&self, image: &RgbImage, kernel: &[f64]) -> CnnResult<RgbImage> {
469        let (width, height) = image.dimensions();
470        let radius = kernel.len() / 2;
471
472        // Horizontal pass
473        let mut temp = ImageBuffer::<Rgb<u8>, _>::new(width, height);
474        for y in 0..height {
475            for x in 0..width {
476                let mut sum = [0.0, 0.0, 0.0];
477                for (i, &k) in kernel.iter().enumerate() {
478                    let sx = (x as i32 + i as i32 - radius as i32).clamp(0, width as i32 - 1) as u32;
479                    let pixel = image.get_pixel(sx, y);
480                    sum[0] += pixel[0] as f64 * k;
481                    sum[1] += pixel[1] as f64 * k;
482                    sum[2] += pixel[2] as f64 * k;
483                }
484                temp.put_pixel(x, y, Rgb([
485                    sum[0].clamp(0.0, 255.0) as u8,
486                    sum[1].clamp(0.0, 255.0) as u8,
487                    sum[2].clamp(0.0, 255.0) as u8,
488                ]));
489            }
490        }
491
492        // Vertical pass
493        let mut result = ImageBuffer::<Rgb<u8>, _>::new(width, height);
494        for y in 0..height {
495            for x in 0..width {
496                let mut sum = [0.0, 0.0, 0.0];
497                for (i, &k) in kernel.iter().enumerate() {
498                    let sy = (y as i32 + i as i32 - radius as i32).clamp(0, height as i32 - 1) as u32;
499                    let pixel = temp.get_pixel(x, sy);
500                    sum[0] += pixel[0] as f64 * k;
501                    sum[1] += pixel[1] as f64 * k;
502                    sum[2] += pixel[2] as f64 * k;
503                }
504                result.put_pixel(x, y, Rgb([
505                    sum[0].clamp(0.0, 255.0) as u8,
506                    sum[1].clamp(0.0, 255.0) as u8,
507                    sum[2].clamp(0.0, 255.0) as u8,
508                ]));
509            }
510        }
511
512        Ok(result)
513    }
514
515    /// Compute mean pixel value per channel.
516    #[cfg(feature = "augmentation")]
517    fn compute_mean(&self, image: &RgbImage) -> [f64; 3] {
518        let (width, height) = image.dimensions();
519        let n = (width * height) as f64;
520        let mut sum = [0.0, 0.0, 0.0];
521
522        for pixel in image.pixels() {
523            sum[0] += pixel[0] as f64 / 255.0;
524            sum[1] += pixel[1] as f64 / 255.0;
525            sum[2] += pixel[2] as f64 / 255.0;
526        }
527
528        [sum[0] / n, sum[1] / n, sum[2] / n]
529    }
530}
531
532impl Default for ContrastiveAugmentation {
533    fn default() -> Self {
534        Self::builder().build()
535    }
536}
537
538/// Convert RGB to HSV.
539#[cfg(feature = "augmentation")]
fn rgb_to_hsv(r: f64, g: f64, b: f64) -> (f64, f64, f64) {
    // Returns (hue in degrees within [0, 360), saturation, value).
    // Value is the largest channel; saturation is chroma / value, or 0 when
    // value is (almost) zero.
    const EPS: f64 = 1e-8;

    let value = r.max(g).max(b);
    let chroma = value - r.min(g).min(b);
    let saturation = if value > EPS { chroma / value } else { 0.0 };

    // Greys have no defined hue; report 0.
    if chroma < EPS {
        return (0.0, saturation, value);
    }

    // Pick the hue sector from whichever channel is the maximum, testing
    // red first, then green, then blue.
    let red_is_max = (value - r).abs() < EPS;
    let green_is_max = (value - g).abs() < EPS;
    let raw_hue = if red_is_max {
        60.0 * (((g - b) / chroma) % 6.0)
    } else if green_is_max {
        60.0 * ((b - r) / chroma + 2.0)
    } else {
        60.0 * ((r - g) / chroma + 4.0)
    };

    // The red sector can come out negative (when b > g); wrap it around.
    let hue = if raw_hue < 0.0 { raw_hue + 360.0 } else { raw_hue };

    (hue, saturation, value)
}
563
564/// Convert HSV to RGB.
565#[cfg(feature = "augmentation")]
fn hsv_to_rgb(h: f64, s: f64, v: f64) -> (f64, f64, f64) {
    // `h` is a hue in degrees, `s` and `v` are saturation and value.
    // Returns (r, g, b) on the same scale as `v`.
    //
    // The hue is wrapped into one turn first, so angles such as -120 or 480
    // resolve to the colour they denote instead of falling into the first or
    // last branch of the sector selection.
    let sector = h.rem_euclid(360.0) / 60.0;

    let chroma = v * s;
    // Second-largest component within the current 60° sector.
    let x = chroma * (1.0 - (sector % 2.0 - 1.0).abs());

    // `sector` lies in [0, 6], so truncating gives the sector index; the
    // catch-all arm also covers the rounding edge case of exactly 6.
    let (r1, g1, b1) = match sector as u8 {
        0 => (chroma, x, 0.0),
        1 => (x, chroma, 0.0),
        2 => (0.0, chroma, x),
        3 => (0.0, x, chroma),
        4 => (x, 0.0, chroma),
        _ => (chroma, 0.0, x),
    };

    // Lift all channels so the largest one equals `v`.
    let lift = v - chroma;
    (r1 + lift, g1 + lift, b1 + lift)
}
588
#[cfg(all(test, feature = "augmentation"))]
mod tests {
    use super::*;

    /// Builds a gradient image: red grows with x, green with y, blue is constant.
    fn create_test_image(width: u32, height: u32) -> RgbImage {
        ImageBuffer::from_fn(width, height, |x, y| {
            let r = ((x * 255) / width) as u8;
            let g = ((y * 255) / height) as u8;
            Rgb([r, g, 128])
        })
    }

    /// Sum of absolute per-channel differences between two equally sized images.
    fn total_abs_diff(a: &RgbImage, b: &RgbImage) -> u32 {
        a.pixels()
            .zip(b.pixels())
            .flat_map(|(p, q)| p.0.iter().zip(q.0.iter()))
            .map(|(&c1, &c2)| (i32::from(c1) - i32::from(c2)).unsigned_abs())
            .sum()
    }

    #[test]
    fn test_augmentation_builder() {
        let aug = ContrastiveAugmentation::builder()
            .crop_scale(0.5, 1.0)
            .horizontal_flip_prob(0.3)
            .output_size(128, 128)
            .seed(42)
            .build();

        assert_eq!(aug.config.crop_scale_min, 0.5);
        assert_eq!(aug.config.horizontal_flip_prob, 0.3);
        assert_eq!(aug.config.output_size, (128, 128));
    }

    #[test]
    fn test_random_resized_crop() {
        let mut aug = ContrastiveAugmentation::builder()
            .output_size(64, 64)
            .seed(42)
            .build();

        let source = create_test_image(256, 256);
        let cropped = aug.random_resized_crop(&source).unwrap();

        assert_eq!(cropped.dimensions(), (64, 64));
    }

    #[test]
    fn test_horizontal_flip() {
        let aug = ContrastiveAugmentation::default();
        let source = create_test_image(4, 4);
        let flipped = aug.horizontal_flip(&source);

        // The leftmost and rightmost pixels of the first row trade places.
        assert_eq!(flipped.get_pixel(3, 0), source.get_pixel(0, 0));
        assert_eq!(flipped.get_pixel(0, 0), source.get_pixel(3, 0));
    }

    #[test]
    fn test_color_jitter() {
        let mut aug = ContrastiveAugmentation::builder()
            .color_jitter(0.2, 0.2, 0.2, 0.05)
            .seed(42)
            .build();

        let source = create_test_image(64, 64);
        let jittered = aug.color_jitter(&source).unwrap();

        // Jitter keeps the size but should change at least one pixel
        // (with high probability).
        assert_eq!(jittered.dimensions(), source.dimensions());
        assert!(total_abs_diff(&source, &jittered) > 0);
    }

    #[test]
    fn test_grayscale() {
        let aug = ContrastiveAugmentation::default();
        let gray = aug.to_grayscale(&create_test_image(64, 64));

        // All three channels carry the same value.
        for pixel in gray.pixels() {
            assert_eq!(pixel[0], pixel[1]);
            assert_eq!(pixel[1], pixel[2]);
        }
    }

    #[test]
    fn test_gaussian_blur() {
        let mut aug = ContrastiveAugmentation::builder()
            .gaussian_blur(5, (1.0, 1.0))
            .seed(42)
            .build();

        let source = create_test_image(64, 64);
        let blurred = aug.gaussian_blur(&source).unwrap();

        assert_eq!(blurred.dimensions(), source.dimensions());
    }

    #[test]
    fn test_generate_pair() {
        let mut aug = ContrastiveAugmentation::builder()
            .output_size(32, 32)
            .seed(42)
            .build();

        let source = DynamicImage::ImageRgb8(create_test_image(128, 128));
        let (view1, view2) = aug.generate_pair(&source).unwrap();

        // Both views have the target size.
        assert_eq!(view1.dimensions(), (32, 32));
        assert_eq!(view2.dimensions(), (32, 32));

        // The views come from independent random draws and should differ.
        assert!(
            total_abs_diff(&view1, &view2) > 0,
            "Two augmented views should differ"
        );
    }

    #[test]
    fn test_full_pipeline() {
        let mut aug = ContrastiveAugmentation::builder()
            .crop_scale(0.5, 1.0)
            .horizontal_flip_prob(1.0) // Always flip for testing
            .color_jitter(0.3, 0.3, 0.3, 0.1)
            .grayscale_prob(0.0) // Never grayscale for consistent test
            .output_size(48, 48)
            .seed(12345)
            .build();

        let source = DynamicImage::ImageRgb8(create_test_image(200, 200));
        let result = aug.augment(&source).unwrap();

        assert_eq!(result.dimensions(), (48, 48));
    }

    #[test]
    fn test_rgb_hsv_roundtrip() {
        let test_values = [
            (1.0, 0.0, 0.0), // Red
            (0.0, 1.0, 0.0), // Green
            (0.0, 0.0, 1.0), // Blue
            (0.5, 0.5, 0.5), // Gray
            (1.0, 1.0, 1.0), // White
            (0.0, 0.0, 0.0), // Black
            // Mixed colours: these reach the branches the primaries miss.
            (1.0, 1.0, 0.0), // Yellow (red-max branch, positive hue)
            (0.8, 0.2, 0.6), // Red-max with blue > green (negative-hue wrap)
            (0.2, 0.6, 0.4), // Green-max with partial saturation
            (0.1, 0.3, 0.9), // Blue-max with partial saturation
        ];

        for (r, g, b) in test_values {
            let (h, s, v) = rgb_to_hsv(r, g, b);
            let (r2, g2, b2) = hsv_to_rgb(h, s, v);

            assert!((r - r2).abs() < 1e-6, "R mismatch for ({}, {}, {})", r, g, b);
            assert!((g - g2).abs() < 1e-6, "G mismatch for ({}, {}, {})", r, g, b);
            assert!((b - b2).abs() < 1e-6, "B mismatch for ({}, {}, {})", r, g, b);
        }
    }

    #[test]
    fn test_default_config() {
        let config = AugmentationConfig::default();

        assert!((config.crop_scale_min - 0.08).abs() < 1e-6);
        assert!((config.crop_scale_max - 1.0).abs() < 1e-6);
        assert!((config.horizontal_flip_prob - 0.5).abs() < 1e-6);
        assert_eq!(config.output_size, (224, 224));
    }
}
771
#[cfg(test)]
mod tests_no_feature {
    use super::{AugmentationConfig, ContrastiveAugmentation};

    /// The builder and the config accessor must work even when the
    /// `augmentation` feature is disabled.
    #[test]
    fn test_builder_without_image_feature() {
        let builder = ContrastiveAugmentation::builder()
            .crop_scale(0.5, 1.0)
            .horizontal_flip_prob(0.3)
            .output_size(128, 128)
            .seed(42);
        let aug = builder.build();

        let cfg = aug.config();
        assert_eq!(cfg.crop_scale_min, 0.5);
        assert_eq!(cfg.horizontal_flip_prob, 0.3);
    }

    /// Spot-checks two of the default configuration values.
    #[test]
    fn test_default_config() {
        let AugmentationConfig {
            crop_scale_min,
            output_size,
            ..
        } = AugmentationConfig::default();

        assert!((crop_scale_min - 0.08).abs() < 1e-6);
        assert_eq!(output_size, (224, 224));
    }
}
796}