1#[cfg(feature = "augmentation")]
20use crate::error::{CnnError, CnnResult};
21#[cfg(feature = "augmentation")]
22use image::{DynamicImage, GenericImageView, ImageBuffer, Rgb, RgbImage};
23#[cfg(feature = "augmentation")]
24use rand::Rng;
25use serde::{Deserialize, Serialize};
26
/// Hyper-parameters for [`ContrastiveAugmentation`].
///
/// Every `*_prob` field is compared against a uniform draw from `[0, 1)`, so `0.0`
/// disables the corresponding step and `1.0` always applies it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AugmentationConfig {
    /// Lower bound of the crop area, as a fraction of the source image area.
    pub crop_scale_min: f64,
    /// Upper bound of the crop area, as a fraction of the source image area.
    pub crop_scale_max: f64,
    /// Lower bound of the crop aspect ratio (width / height); ratios are sampled log-uniformly.
    pub aspect_ratio_min: f64,
    /// Upper bound of the crop aspect ratio (width / height); ratios are sampled log-uniformly.
    pub aspect_ratio_max: f64,
    /// Probability of mirroring the image left-to-right.
    pub horizontal_flip_prob: f64,
    /// Brightness jitter strength: the brightness factor is drawn from about `1.0 ± brightness`.
    pub brightness: f64,
    /// Contrast jitter strength (same convention as `brightness`).
    pub contrast: f64,
    /// Saturation jitter strength (same convention as `brightness`).
    pub saturation: f64,
    /// Maximum hue shift as a fraction of a full turn; the shift in degrees lies within
    /// `±hue * 360`.
    pub hue: f64,
    /// Probability of applying colour jitter.
    pub color_jitter_prob: f64,
    /// Probability of converting to grayscale (the result is still 3-channel RGB).
    pub grayscale_prob: f64,
    /// Gaussian blur kernel size in pixels; `0` disables blur in the full augmentation pipeline.
    pub blur_kernel_size: u32,
    /// Probability of applying Gaussian blur (only consulted when `blur_kernel_size > 0`).
    pub blur_prob: f64,
    /// Inclusive `(min, max)` range from which the blur standard deviation is drawn.
    pub blur_sigma_range: (f64, f64),
    /// Output `(width, height)` in pixels of every augmented view.
    pub output_size: (u32, u32),
}
61
62impl Default for AugmentationConfig {
63 fn default() -> Self {
64 Self {
65 crop_scale_min: 0.08,
66 crop_scale_max: 1.0,
67 aspect_ratio_min: 0.75,
68 aspect_ratio_max: 4.0 / 3.0,
69 horizontal_flip_prob: 0.5,
70 brightness: 0.4,
71 contrast: 0.4,
72 saturation: 0.4,
73 hue: 0.1,
74 color_jitter_prob: 0.8,
75 grayscale_prob: 0.2,
76 blur_kernel_size: 0,
77 blur_prob: 0.5,
78 blur_sigma_range: (0.1, 2.0),
79 output_size: (224, 224),
80 }
81 }
82}
83
/// Fluent builder for [`ContrastiveAugmentation`].
///
/// Starts from [`AugmentationConfig::default`]. Each setter overrides part of the
/// configuration, and `build` produces the augmenter.
#[derive(Debug, Clone)]
pub struct ContrastiveAugmentationBuilder {
    /// Configuration handed to the built augmenter.
    config: AugmentationConfig,
    /// Optional RNG seed; `None` means the RNG is seeded from entropy at build time.
    seed: Option<u64>,
}
90
91impl ContrastiveAugmentationBuilder {
92 pub fn new() -> Self {
94 Self {
95 config: AugmentationConfig::default(),
96 seed: None,
97 }
98 }
99
100 pub fn crop_scale(mut self, min: f64, max: f64) -> Self {
102 self.config.crop_scale_min = min;
103 self.config.crop_scale_max = max;
104 self
105 }
106
107 pub fn aspect_ratio(mut self, min: f64, max: f64) -> Self {
109 self.config.aspect_ratio_min = min;
110 self.config.aspect_ratio_max = max;
111 self
112 }
113
114 pub fn horizontal_flip_prob(mut self, prob: f64) -> Self {
116 self.config.horizontal_flip_prob = prob;
117 self
118 }
119
120 pub fn color_jitter(mut self, brightness: f64, contrast: f64, saturation: f64, hue: f64) -> Self {
122 self.config.brightness = brightness;
123 self.config.contrast = contrast;
124 self.config.saturation = saturation;
125 self.config.hue = hue;
126 self
127 }
128
129 pub fn color_jitter_prob(mut self, prob: f64) -> Self {
131 self.config.color_jitter_prob = prob;
132 self
133 }
134
135 pub fn grayscale_prob(mut self, prob: f64) -> Self {
137 self.config.grayscale_prob = prob;
138 self
139 }
140
141 pub fn gaussian_blur(mut self, kernel_size: u32, sigma_range: (f64, f64)) -> Self {
143 self.config.blur_kernel_size = kernel_size;
144 self.config.blur_sigma_range = sigma_range;
145 self
146 }
147
148 pub fn blur_prob(mut self, prob: f64) -> Self {
150 self.config.blur_prob = prob;
151 self
152 }
153
154 pub fn output_size(mut self, width: u32, height: u32) -> Self {
156 self.config.output_size = (width, height);
157 self
158 }
159
160 pub fn seed(mut self, seed: u64) -> Self {
162 self.seed = Some(seed);
163 self
164 }
165
166 pub fn build(self) -> ContrastiveAugmentation {
168 let rng = if let Some(seed) = self.seed {
169 rand::SeedableRng::seed_from_u64(seed)
170 } else {
171 rand::SeedableRng::from_entropy()
172 };
173 ContrastiveAugmentation {
174 config: self.config,
175 rng,
176 }
177 }
178}
179
180impl Default for ContrastiveAugmentationBuilder {
181 fn default() -> Self {
182 Self::new()
183 }
184}
185
186#[derive(Debug, Clone)]
204pub struct ContrastiveAugmentation {
205 config: AugmentationConfig,
206 #[allow(dead_code)]
208 rng: rand::rngs::StdRng,
209}
210
211impl ContrastiveAugmentation {
212 pub fn builder() -> ContrastiveAugmentationBuilder {
214 ContrastiveAugmentationBuilder::new()
215 }
216
217 pub fn config(&self) -> &AugmentationConfig {
219 &self.config
220 }
221
222 #[cfg(feature = "augmentation")]
234 pub fn generate_pair(&mut self, image: &DynamicImage) -> CnnResult<(RgbImage, RgbImage)> {
235 let view1 = self.augment(image)?;
236 let view2 = self.augment(image)?;
237 Ok((view1, view2))
238 }
239
240 #[cfg(feature = "augmentation")]
242 pub fn augment(&mut self, image: &DynamicImage) -> CnnResult<RgbImage> {
243 let mut img = image.to_rgb8();
244
245 img = self.random_resized_crop(&img)?;
247
248 if self.rng.gen::<f64>() < self.config.horizontal_flip_prob {
250 img = self.horizontal_flip(&img);
251 }
252
253 if self.rng.gen::<f64>() < self.config.color_jitter_prob {
255 img = self.color_jitter(&img)?;
256 }
257
258 if self.rng.gen::<f64>() < self.config.grayscale_prob {
260 img = self.to_grayscale(&img);
261 }
262
263 if self.config.blur_kernel_size > 0 && self.rng.gen::<f64>() < self.config.blur_prob {
265 img = self.gaussian_blur(&img)?;
266 }
267
268 Ok(img)
269 }
270
271 #[cfg(feature = "augmentation")]
273 pub fn random_resized_crop(&mut self, image: &RgbImage) -> CnnResult<RgbImage> {
274 let (orig_w, orig_h) = image.dimensions();
275 let orig_area = (orig_w * orig_h) as f64;
276
277 for _ in 0..10 {
279 let scale = self.rng.gen_range(self.config.crop_scale_min..=self.config.crop_scale_max);
281 let aspect = self.rng.gen_range(
282 self.config.aspect_ratio_min.ln()..=self.config.aspect_ratio_max.ln(),
283 ).exp();
284
285 let crop_area = orig_area * scale;
287 let crop_w = (crop_area * aspect).sqrt() as u32;
288 let crop_h = (crop_area / aspect).sqrt() as u32;
289
290 if crop_w <= orig_w && crop_h <= orig_h && crop_w > 0 && crop_h > 0 {
291 let x = self.rng.gen_range(0..=(orig_w - crop_w));
293 let y = self.rng.gen_range(0..=(orig_h - crop_h));
294
295 let cropped = image::imageops::crop_imm(image, x, y, crop_w, crop_h).to_image();
297
298 let (target_w, target_h) = self.config.output_size;
300 let resized = image::imageops::resize(
301 &cropped,
302 target_w,
303 target_h,
304 image::imageops::FilterType::Lanczos3,
305 );
306
307 return Ok(resized);
308 }
309 }
310
311 let (target_w, target_h) = self.config.output_size;
313 let target_ratio = target_w as f64 / target_h as f64;
314 let orig_ratio = orig_w as f64 / orig_h as f64;
315
316 let (crop_w, crop_h) = if orig_ratio > target_ratio {
317 let h = orig_h;
319 let w = (h as f64 * target_ratio) as u32;
320 (w, h)
321 } else {
322 let w = orig_w;
324 let h = (w as f64 / target_ratio) as u32;
325 (w, h)
326 };
327
328 let x = (orig_w - crop_w) / 2;
329 let y = (orig_h - crop_h) / 2;
330
331 let cropped = image::imageops::crop_imm(image, x, y, crop_w, crop_h).to_image();
332 let resized = image::imageops::resize(
333 &cropped,
334 target_w,
335 target_h,
336 image::imageops::FilterType::Lanczos3,
337 );
338
339 Ok(resized)
340 }
341
342 #[cfg(feature = "augmentation")]
344 pub fn horizontal_flip(&self, image: &RgbImage) -> RgbImage {
345 image::imageops::flip_horizontal(image)
346 }
347
348 #[cfg(feature = "augmentation")]
350 pub fn color_jitter(&mut self, image: &RgbImage) -> CnnResult<RgbImage> {
351 let (width, height) = image.dimensions();
352 let mut result = image.clone();
353
354 let brightness_factor = 1.0 + self.rng.gen_range(-self.config.brightness..=self.config.brightness);
356 let contrast_factor = 1.0 + self.rng.gen_range(-self.config.contrast..=self.config.contrast);
357 let saturation_factor = 1.0 + self.rng.gen_range(-self.config.saturation..=self.config.saturation);
358 let hue_shift = self.rng.gen_range(-self.config.hue..=self.config.hue);
359
360 let mean = self.compute_mean(image);
362
363 for y in 0..height {
364 for x in 0..width {
365 let pixel = image.get_pixel(x, y);
366 let mut rgb = [pixel[0] as f64 / 255.0, pixel[1] as f64 / 255.0, pixel[2] as f64 / 255.0];
367
368 for c in rgb.iter_mut() {
370 *c *= brightness_factor;
371 }
372
373 for (i, c) in rgb.iter_mut().enumerate() {
375 *c = (*c - mean[i]) * contrast_factor + mean[i];
376 }
377
378 let (h, s, v) = rgb_to_hsv(rgb[0], rgb[1], rgb[2]);
380 let new_s = (s * saturation_factor).clamp(0.0, 1.0);
381 let new_h = (h + hue_shift * 360.0).rem_euclid(360.0);
382 let (r, g, b) = hsv_to_rgb(new_h, new_s, v);
383
384 rgb = [r, g, b];
385
386 let out_pixel = Rgb([
388 (rgb[0] * 255.0).clamp(0.0, 255.0) as u8,
389 (rgb[1] * 255.0).clamp(0.0, 255.0) as u8,
390 (rgb[2] * 255.0).clamp(0.0, 255.0) as u8,
391 ]);
392 result.put_pixel(x, y, out_pixel);
393 }
394 }
395
396 Ok(result)
397 }
398
399 #[cfg(feature = "augmentation")]
401 pub fn to_grayscale(&self, image: &RgbImage) -> RgbImage {
402 let (width, height) = image.dimensions();
403 let mut result = ImageBuffer::new(width, height);
404
405 for y in 0..height {
406 for x in 0..width {
407 let pixel = image.get_pixel(x, y);
408 let gray = (0.299 * pixel[0] as f64
410 + 0.587 * pixel[1] as f64
411 + 0.114 * pixel[2] as f64) as u8;
412 result.put_pixel(x, y, Rgb([gray, gray, gray]));
413 }
414 }
415
416 result
417 }
418
419 #[cfg(feature = "augmentation")]
421 pub fn gaussian_blur(&mut self, image: &RgbImage) -> CnnResult<RgbImage> {
422 let sigma = self.rng.gen_range(self.config.blur_sigma_range.0..=self.config.blur_sigma_range.1);
423
424 let kernel_size = if self.config.blur_kernel_size > 0 {
426 self.config.blur_kernel_size
427 } else {
428 let k = (sigma * 6.0).ceil() as u32;
429 if k % 2 == 0 { k + 1 } else { k }
430 };
431
432 let kernel = self.generate_gaussian_kernel(kernel_size, sigma);
434
435 let blurred = self.convolve_separable(image, &kernel)?;
437
438 Ok(blurred)
439 }
440
441 #[cfg(feature = "augmentation")]
443 fn generate_gaussian_kernel(&self, size: u32, sigma: f64) -> Vec<f64> {
444 let size = size as i32;
445 let center = size / 2;
446 let mut kernel = Vec::with_capacity(size as usize);
447 let mut sum = 0.0;
448
449 let sigma_sq_2 = 2.0 * sigma * sigma;
450
451 for i in 0..size {
452 let x = (i - center) as f64;
453 let value = (-x * x / sigma_sq_2).exp();
454 kernel.push(value);
455 sum += value;
456 }
457
458 for k in kernel.iter_mut() {
460 *k /= sum;
461 }
462
463 kernel
464 }
465
466 #[cfg(feature = "augmentation")]
468 fn convolve_separable(&self, image: &RgbImage, kernel: &[f64]) -> CnnResult<RgbImage> {
469 let (width, height) = image.dimensions();
470 let radius = kernel.len() / 2;
471
472 let mut temp = ImageBuffer::<Rgb<u8>, _>::new(width, height);
474 for y in 0..height {
475 for x in 0..width {
476 let mut sum = [0.0, 0.0, 0.0];
477 for (i, &k) in kernel.iter().enumerate() {
478 let sx = (x as i32 + i as i32 - radius as i32).clamp(0, width as i32 - 1) as u32;
479 let pixel = image.get_pixel(sx, y);
480 sum[0] += pixel[0] as f64 * k;
481 sum[1] += pixel[1] as f64 * k;
482 sum[2] += pixel[2] as f64 * k;
483 }
484 temp.put_pixel(x, y, Rgb([
485 sum[0].clamp(0.0, 255.0) as u8,
486 sum[1].clamp(0.0, 255.0) as u8,
487 sum[2].clamp(0.0, 255.0) as u8,
488 ]));
489 }
490 }
491
492 let mut result = ImageBuffer::<Rgb<u8>, _>::new(width, height);
494 for y in 0..height {
495 for x in 0..width {
496 let mut sum = [0.0, 0.0, 0.0];
497 for (i, &k) in kernel.iter().enumerate() {
498 let sy = (y as i32 + i as i32 - radius as i32).clamp(0, height as i32 - 1) as u32;
499 let pixel = temp.get_pixel(x, sy);
500 sum[0] += pixel[0] as f64 * k;
501 sum[1] += pixel[1] as f64 * k;
502 sum[2] += pixel[2] as f64 * k;
503 }
504 result.put_pixel(x, y, Rgb([
505 sum[0].clamp(0.0, 255.0) as u8,
506 sum[1].clamp(0.0, 255.0) as u8,
507 sum[2].clamp(0.0, 255.0) as u8,
508 ]));
509 }
510 }
511
512 Ok(result)
513 }
514
515 #[cfg(feature = "augmentation")]
517 fn compute_mean(&self, image: &RgbImage) -> [f64; 3] {
518 let (width, height) = image.dimensions();
519 let n = (width * height) as f64;
520 let mut sum = [0.0, 0.0, 0.0];
521
522 for pixel in image.pixels() {
523 sum[0] += pixel[0] as f64 / 255.0;
524 sum[1] += pixel[1] as f64 / 255.0;
525 sum[2] += pixel[2] as f64 / 255.0;
526 }
527
528 [sum[0] / n, sum[1] / n, sum[2] / n]
529 }
530}
531
532impl Default for ContrastiveAugmentation {
533 fn default() -> Self {
534 Self::builder().build()
535 }
536}
537
538#[cfg(feature = "augmentation")]
/// Converts an RGB colour (channels nominally in `[0, 1]`) to HSV.
///
/// Returns `(hue, saturation, value)`:
/// - `hue` is in degrees, wrapped to be non-negative, and `0.0` for achromatic colours.
/// - `saturation` is `chroma / value`, and `0.0` for black.
/// - `value` is the largest channel.
fn rgb_to_hsv(r: f64, g: f64, b: f64) -> (f64, f64, f64) {
    const EPS: f64 = 1e-8;

    let value = r.max(g).max(b);
    let chroma = value - r.min(g).min(b);
    let saturation = if value > EPS { chroma / value } else { 0.0 };

    // Position on the colour hexagon in units of 60°, measured from the dominant channel.
    let hexagon_pos = if chroma < EPS {
        None
    } else if (value - r).abs() < EPS {
        Some(((g - b) / chroma) % 6.0)
    } else if (value - g).abs() < EPS {
        Some((b - r) / chroma + 2.0)
    } else {
        Some((r - g) / chroma + 4.0)
    };

    let hue = hexagon_pos.map_or(0.0, |pos| 60.0 * pos);
    // A red-dominant colour with more blue than green lands just below 0°; wrap it.
    let hue = if hue < 0.0 { hue + 360.0 } else { hue };

    (hue, saturation, value)
}
563
564#[cfg(feature = "augmentation")]
/// Converts an HSV colour (`h` in degrees, `s` and `v` nominally in `[0, 1]`) back to RGB.
///
/// Hues are not wrapped. Anything at or above 300° (or NaN) falls into the last
/// (magenta→red) sector, and anything below 60°, including negative hues, into the first.
fn hsv_to_rgb(h: f64, s: f64, v: f64) -> (f64, f64, f64) {
    let chroma = v * s;
    let sector_pos = h / 60.0;
    let secondary = chroma * (1.0 - (sector_pos % 2.0 - 1.0).abs());
    let offset = v - chroma;

    // Index of the first 60° boundary above the hue; none found (≥ 5 or NaN) → last sector.
    let sector = [1.0, 2.0, 3.0, 4.0, 5.0]
        .iter()
        .position(|&bound| sector_pos < bound)
        .unwrap_or(5);

    let (red, green, blue) = match sector {
        0 => (chroma, secondary, 0.0),
        1 => (secondary, chroma, 0.0),
        2 => (0.0, chroma, secondary),
        3 => (0.0, secondary, chroma),
        4 => (secondary, 0.0, chroma),
        _ => (chroma, 0.0, secondary),
    };

    (red + offset, green + offset, blue + offset)
}
588
#[cfg(all(test, feature = "augmentation"))]
mod tests {
    use super::*;

    /// Builds a gradient test image: red ramps along x, green along y, blue is fixed at 128.
    fn create_test_image(width: u32, height: u32) -> RgbImage {
        ImageBuffer::from_fn(width, height, |x, y| {
            let red = ((x * 255) / width) as u8;
            let green = ((y * 255) / height) as u8;
            Rgb([red, green, 128])
        })
    }

    /// Sum of absolute per-channel differences between two images of equal size.
    fn l1_distance(a: &RgbImage, b: &RgbImage) -> u32 {
        a.pixels()
            .zip(b.pixels())
            .map(|(p, q)| {
                (0..3)
                    .map(|c| (p[c] as i32 - q[c] as i32).unsigned_abs())
                    .sum::<u32>()
            })
            .sum()
    }

    #[test]
    fn test_augmentation_builder() {
        let aug = ContrastiveAugmentation::builder()
            .crop_scale(0.5, 1.0)
            .horizontal_flip_prob(0.3)
            .output_size(128, 128)
            .seed(42)
            .build();

        let config = &aug.config;
        assert_eq!(config.crop_scale_min, 0.5);
        assert_eq!(config.horizontal_flip_prob, 0.3);
        assert_eq!(config.output_size, (128, 128));
    }

    #[test]
    fn test_random_resized_crop() {
        let mut aug = ContrastiveAugmentation::builder()
            .output_size(64, 64)
            .seed(42)
            .build();

        let source = create_test_image(256, 256);
        let cropped = aug.random_resized_crop(&source).unwrap();

        assert_eq!(cropped.dimensions(), (64, 64));
    }

    #[test]
    fn test_horizontal_flip() {
        let aug = ContrastiveAugmentation::default();
        let img = create_test_image(4, 4);
        let flipped = aug.horizontal_flip(&img);

        // The first and last columns swap places.
        assert_eq!(flipped.get_pixel(3, 0), img.get_pixel(0, 0));
        assert_eq!(flipped.get_pixel(0, 0), img.get_pixel(3, 0));
    }

    #[test]
    fn test_color_jitter() {
        let mut aug = ContrastiveAugmentation::builder()
            .color_jitter(0.2, 0.2, 0.2, 0.05)
            .seed(42)
            .build();

        let img = create_test_image(64, 64);
        let jittered = aug.color_jitter(&img).unwrap();

        assert_eq!(jittered.dimensions(), img.dimensions());
        assert!(l1_distance(&img, &jittered) > 0);
    }

    #[test]
    fn test_grayscale() {
        let aug = ContrastiveAugmentation::default();
        let gray = aug.to_grayscale(&create_test_image(64, 64));

        assert!(gray.pixels().all(|p| p[0] == p[1] && p[1] == p[2]));
    }

    #[test]
    fn test_gaussian_blur() {
        let mut aug = ContrastiveAugmentation::builder()
            .gaussian_blur(5, (1.0, 1.0))
            .seed(42)
            .build();

        let img = create_test_image(64, 64);
        let blurred = aug.gaussian_blur(&img).unwrap();

        assert_eq!(blurred.dimensions(), img.dimensions());
    }

    #[test]
    fn test_generate_pair() {
        let mut aug = ContrastiveAugmentation::builder()
            .output_size(32, 32)
            .seed(42)
            .build();

        let img = DynamicImage::ImageRgb8(create_test_image(128, 128));
        let (view1, view2) = aug.generate_pair(&img).unwrap();

        assert_eq!(view1.dimensions(), (32, 32));
        assert_eq!(view2.dimensions(), (32, 32));
        assert!(
            l1_distance(&view1, &view2) > 0,
            "Two augmented views should differ"
        );
    }

    #[test]
    fn test_full_pipeline() {
        let mut aug = ContrastiveAugmentation::builder()
            .crop_scale(0.5, 1.0)
            .horizontal_flip_prob(1.0)
            .color_jitter(0.3, 0.3, 0.3, 0.1)
            .grayscale_prob(0.0)
            .output_size(48, 48)
            .seed(12345)
            .build();

        let img = DynamicImage::ImageRgb8(create_test_image(200, 200));
        let result = aug.augment(&img).unwrap();

        assert_eq!(result.dimensions(), (48, 48));
    }

    #[test]
    fn test_rgb_hsv_roundtrip() {
        let test_values = [
            (1.0, 0.0, 0.0),
            (0.0, 1.0, 0.0),
            (0.0, 0.0, 1.0),
            (0.5, 0.5, 0.5),
            (1.0, 1.0, 1.0),
            (0.0, 0.0, 0.0),
        ];

        for (r, g, b) in test_values {
            let (h, s, v) = rgb_to_hsv(r, g, b);
            let (r2, g2, b2) = hsv_to_rgb(h, s, v);

            assert!((r - r2).abs() < 1e-6, "R mismatch for ({}, {}, {})", r, g, b);
            assert!((g - g2).abs() < 1e-6, "G mismatch for ({}, {}, {})", r, g, b);
            assert!((b - b2).abs() < 1e-6, "B mismatch for ({}, {}, {})", r, g, b);
        }
    }

    #[test]
    fn test_default_config() {
        let config = AugmentationConfig::default();

        assert!((config.crop_scale_min - 0.08).abs() < 1e-6);
        assert!((config.crop_scale_max - 1.0).abs() < 1e-6);
        assert!((config.horizontal_flip_prob - 0.5).abs() < 1e-6);
        assert_eq!(config.output_size, (224, 224));
    }
}
771
#[cfg(test)]
mod tests_no_feature {
    use super::*;

    /// The builder and the `config()` accessor must work without the `augmentation` feature.
    #[test]
    fn test_builder_without_image_feature() {
        let aug = ContrastiveAugmentation::builder()
            .crop_scale(0.5, 1.0)
            .horizontal_flip_prob(0.3)
            .output_size(128, 128)
            .seed(42)
            .build();
        let config = aug.config();

        assert_eq!(config.crop_scale_min, 0.5);
        assert_eq!(config.horizontal_flip_prob, 0.3);
    }

    #[test]
    fn test_default_config() {
        let AugmentationConfig {
            crop_scale_min,
            output_size,
            ..
        } = AugmentationConfig::default();

        assert!((crop_scale_min - 0.08).abs() < 1e-6);
        assert_eq!(output_size, (224, 224));
    }
}