1use crate::{TrainError, TrainResult};
14use scirs2_core::ndarray::{Array, Array2, ArrayView2};
15use scirs2_core::random::{Rng, StdRng};
16
17pub trait DataAugmenter {
19 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>>;
28}
29
30#[derive(Debug, Clone, Default)]
34pub struct NoAugmentation;
35
36impl DataAugmenter for NoAugmentation {
37 fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
38 Ok(data.to_owned())
39 }
40}
41
42#[derive(Debug, Clone)]
46pub struct NoiseAugmenter {
47 pub std_dev: f64,
49}
50
51impl NoiseAugmenter {
52 pub fn new(std_dev: f64) -> TrainResult<Self> {
57 if std_dev < 0.0 {
58 return Err(TrainError::InvalidParameter(
59 "std_dev must be non-negative".to_string(),
60 ));
61 }
62 Ok(Self { std_dev })
63 }
64}
65
66impl Default for NoiseAugmenter {
67 fn default() -> Self {
68 Self { std_dev: 0.01 }
69 }
70}
71
72impl DataAugmenter for NoiseAugmenter {
73 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
74 let mut augmented = data.to_owned();
75
76 for value in augmented.iter_mut() {
78 let u1: f64 = rng.random();
79 let u2: f64 = rng.random();
80
81 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
83 let noise = z0 * self.std_dev;
84
85 *value += noise;
86 }
87
88 Ok(augmented)
89 }
90}
91
92#[derive(Debug, Clone)]
96pub struct ScaleAugmenter {
97 pub scale_range: f64,
99}
100
101impl ScaleAugmenter {
102 pub fn new(scale_range: f64) -> TrainResult<Self> {
107 if !(0.0..=1.0).contains(&scale_range) {
108 return Err(TrainError::InvalidParameter(
109 "scale_range must be in [0, 1]".to_string(),
110 ));
111 }
112 Ok(Self { scale_range })
113 }
114}
115
116impl Default for ScaleAugmenter {
117 fn default() -> Self {
118 Self { scale_range: 0.1 }
119 }
120}
121
122impl DataAugmenter for ScaleAugmenter {
123 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
124 let scale = 1.0 + (rng.random::<f64>() * 2.0 - 1.0) * self.scale_range;
126
127 let augmented = data.mapv(|x| x * scale);
128 Ok(augmented)
129 }
130}
131
132#[derive(Debug, Clone)]
137pub struct RotationAugmenter {
138 pub max_angle: f64,
140}
141
142impl RotationAugmenter {
143 pub fn new(max_angle: f64) -> Self {
148 Self { max_angle }
149 }
150}
151
152impl Default for RotationAugmenter {
153 fn default() -> Self {
154 Self {
155 max_angle: std::f64::consts::PI / 18.0, }
157 }
158}
159
160impl DataAugmenter for RotationAugmenter {
161 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
162 let angle = (rng.random::<f64>() * 2.0 - 1.0) * self.max_angle;
165
166 let cos_a = angle.cos();
168 let sin_a = angle.sin();
169
170 let augmented = data.mapv(|x| x * cos_a + x * sin_a * 0.1);
171 Ok(augmented)
172 }
173}
174
175#[derive(Debug, Clone)]
182pub struct MixupAugmenter {
183 pub alpha: f64,
185}
186
187impl MixupAugmenter {
188 pub fn new(alpha: f64) -> TrainResult<Self> {
193 if alpha <= 0.0 {
194 return Err(TrainError::InvalidParameter(
195 "alpha must be positive".to_string(),
196 ));
197 }
198 Ok(Self { alpha })
199 }
200
201 pub fn augment_batch(
211 &self,
212 data: &ArrayView2<f64>,
213 labels: &ArrayView2<f64>,
214 rng: &mut StdRng,
215 ) -> TrainResult<(Array2<f64>, Array2<f64>)> {
216 if data.nrows() != labels.nrows() {
217 return Err(TrainError::InvalidParameter(
218 "data and labels must have same number of rows".to_string(),
219 ));
220 }
221
222 let n = data.nrows();
223 let mut augmented_data = Array::zeros(data.raw_dim());
224 let mut augmented_labels = Array::zeros(labels.raw_dim());
225
226 let mut indices: Vec<usize> = (0..n).collect();
228 for i in (1..n).rev() {
229 let j = rng.gen_range(0..=i);
230 indices.swap(i, j);
231 }
232
233 for i in 0..n {
234 let j = indices[i];
235
236 let lambda = self.sample_beta(rng);
239
240 for k in 0..data.ncols() {
242 augmented_data[[i, k]] = lambda * data[[i, k]] + (1.0 - lambda) * data[[j, k]];
243 }
244
245 for k in 0..labels.ncols() {
247 augmented_labels[[i, k]] =
248 lambda * labels[[i, k]] + (1.0 - lambda) * labels[[j, k]];
249 }
250 }
251
252 Ok((augmented_data, augmented_labels))
253 }
254
255 fn sample_beta(&self, rng: &mut StdRng) -> f64 {
259 if self.alpha < 0.5 {
260 if rng.random::<f64>() < 0.5 {
262 rng.random::<f64>().powf(2.0)
263 } else {
264 1.0 - rng.random::<f64>().powf(2.0)
265 }
266 } else {
267 rng.random::<f64>()
269 }
270 }
271}
272
273impl Default for MixupAugmenter {
274 fn default() -> Self {
275 Self { alpha: 1.0 }
276 }
277}
278
279impl DataAugmenter for MixupAugmenter {
280 fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
281 Ok(data.to_owned())
284 }
285}
286
287#[derive(Debug, Clone)]
296pub struct CutMixAugmenter {
297 pub alpha: f64,
299}
300
301impl CutMixAugmenter {
302 pub fn new(alpha: f64) -> TrainResult<Self> {
310 if alpha <= 0.0 {
311 return Err(TrainError::InvalidParameter(
312 "alpha must be positive".to_string(),
313 ));
314 }
315 Ok(Self { alpha })
316 }
317
318 pub fn augment_batch(
331 &self,
332 data: &ArrayView2<f64>,
333 labels: &ArrayView2<f64>,
334 rng: &mut StdRng,
335 ) -> TrainResult<(Array2<f64>, Array2<f64>)> {
336 if data.nrows() != labels.nrows() {
337 return Err(TrainError::InvalidParameter(
338 "data and labels must have same number of rows".to_string(),
339 ));
340 }
341
342 let n = data.nrows();
343 let features = data.ncols();
344 let mut augmented_data = data.to_owned();
345 let mut augmented_labels = labels.to_owned();
346
347 let mut indices: Vec<usize> = (0..n).collect();
349 for i in (1..n).rev() {
350 let j = rng.gen_range(0..=i);
351 indices.swap(i, j);
352 }
353
354 for i in 0..n {
355 let j = indices[i];
356
357 let lambda = self.sample_beta(rng);
359
360 let cut_ratio = (1.0 - lambda).sqrt();
363 let cut_size = (features as f64 * cut_ratio) as usize;
364 let cut_size = cut_size.max(1).min(features - 1);
365
366 let start = if features > cut_size {
368 rng.gen_range(0..=(features - cut_size))
369 } else {
370 0
371 };
372
373 for k in start..(start + cut_size).min(features) {
375 augmented_data[[i, k]] = data[[j, k]];
376 }
377
378 let actual_ratio = cut_size as f64 / features as f64;
380 for k in 0..labels.ncols() {
381 augmented_labels[[i, k]] =
382 (1.0 - actual_ratio) * labels[[i, k]] + actual_ratio * labels[[j, k]];
383 }
384 }
385
386 Ok((augmented_data, augmented_labels))
387 }
388
389 fn sample_beta(&self, rng: &mut StdRng) -> f64 {
393 if self.alpha < 0.5 {
394 if rng.random::<f64>() < 0.5 {
396 rng.random::<f64>().powf(2.0)
397 } else {
398 1.0 - rng.random::<f64>().powf(2.0)
399 }
400 } else {
401 rng.random::<f64>()
403 }
404 }
405}
406
407impl Default for CutMixAugmenter {
408 fn default() -> Self {
409 Self { alpha: 1.0 }
410 }
411}
412
413impl DataAugmenter for CutMixAugmenter {
414 fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
415 Ok(data.to_owned())
418 }
419}
420
421#[derive(Clone, Default)]
423pub struct CompositeAugmenter {
424 augmenters: Vec<Box<dyn AugmenterClone>>,
425}
426
427impl std::fmt::Debug for CompositeAugmenter {
428 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
429 f.debug_struct("CompositeAugmenter")
430 .field("num_augmenters", &self.augmenters.len())
431 .finish()
432 }
433}
434
435trait AugmenterClone: DataAugmenter {
437 fn clone_box(&self) -> Box<dyn AugmenterClone>;
438}
439
440impl<T: DataAugmenter + Clone + 'static> AugmenterClone for T {
441 fn clone_box(&self) -> Box<dyn AugmenterClone> {
442 Box::new(self.clone())
443 }
444}
445
446impl Clone for Box<dyn AugmenterClone> {
447 fn clone(&self) -> Self {
448 self.clone_box()
449 }
450}
451
452impl DataAugmenter for Box<dyn AugmenterClone> {
453 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
454 (**self).augment(data, rng)
455 }
456}
457
458impl CompositeAugmenter {
459 pub fn new() -> Self {
461 Self {
462 augmenters: Vec::new(),
463 }
464 }
465
466 pub fn add<A: DataAugmenter + Clone + 'static>(&mut self, augmenter: A) {
468 self.augmenters.push(Box::new(augmenter));
469 }
470
471 pub fn len(&self) -> usize {
473 self.augmenters.len()
474 }
475
476 pub fn is_empty(&self) -> bool {
478 self.augmenters.is_empty()
479 }
480}
481
482impl DataAugmenter for CompositeAugmenter {
483 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
484 let mut result = data.to_owned();
485
486 for augmenter in &self.augmenters {
487 result = augmenter.augment(&result.view(), rng)?;
488 }
489
490 Ok(result)
491 }
492}
493
494#[derive(Debug, Clone)]
507pub struct RandomErasingAugmenter {
508 pub probability: f64,
510 pub scale_min: f64,
512 pub scale_max: f64,
514 pub ratio_min: f64,
516 pub ratio_max: f64,
518 pub fill_value: f64,
520}
521
522impl RandomErasingAugmenter {
523 pub fn new(
525 probability: f64,
526 scale_min: f64,
527 scale_max: f64,
528 ratio_min: f64,
529 ratio_max: f64,
530 fill_value: f64,
531 ) -> TrainResult<Self> {
532 if !(0.0..=1.0).contains(&probability) {
533 return Err(TrainError::InvalidParameter(
534 "probability must be in [0, 1]".to_string(),
535 ));
536 }
537 if scale_min >= scale_max || scale_min < 0.0 || scale_max > 1.0 {
538 return Err(TrainError::InvalidParameter(
539 "scale range must be valid (0 <= min < max <= 1)".to_string(),
540 ));
541 }
542 if ratio_min <= 0.0 || ratio_min >= ratio_max {
543 return Err(TrainError::InvalidParameter(
544 "ratio range must be valid (0 < min < max)".to_string(),
545 ));
546 }
547
548 Ok(Self {
549 probability,
550 scale_min,
551 scale_max,
552 ratio_min,
553 ratio_max,
554 fill_value,
555 })
556 }
557
558 pub fn with_defaults() -> Self {
560 Self {
561 probability: 0.5,
562 scale_min: 0.02,
563 scale_max: 0.33,
564 ratio_min: 0.3,
565 ratio_max: 3.3,
566 fill_value: 0.0,
567 }
568 }
569}
570
571impl Default for RandomErasingAugmenter {
572 fn default() -> Self {
573 Self::with_defaults()
574 }
575}
576
577impl DataAugmenter for RandomErasingAugmenter {
578 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
579 let mut augmented = data.to_owned();
580
581 if rng.random::<f64>() > self.probability {
583 return Ok(augmented);
584 }
585
586 let (height, width) = (data.nrows(), data.ncols());
587 let area = (height * width) as f64;
588
589 for _ in 0..10 {
591 let scale = self.scale_min + rng.random::<f64>() * (self.scale_max - self.scale_min);
593 let erase_area = area * scale;
594
595 let aspect_ratio =
597 self.ratio_min + rng.random::<f64>() * (self.ratio_max - self.ratio_min);
598
599 let h = (erase_area * aspect_ratio).sqrt().min(height as f64);
601 let w = (erase_area / aspect_ratio).sqrt().min(width as f64);
602
603 if h >= 1.0 && w >= 1.0 {
604 let erase_h = h as usize;
605 let erase_w = w as usize;
606
607 let i = if erase_h < height {
609 (rng.random::<f64>() * (height - erase_h) as f64) as usize
610 } else {
611 0
612 };
613 let j = if erase_w < width {
614 (rng.random::<f64>() * (width - erase_w) as f64) as usize
615 } else {
616 0
617 };
618
619 if self.fill_value == 1.0 {
621 for row in i..i + erase_h.min(height - i) {
623 for col in j..j + erase_w.min(width - j) {
624 augmented[[row, col]] = rng.random();
625 }
626 }
627 } else {
628 for row in i..i + erase_h.min(height - i) {
630 for col in j..j + erase_w.min(width - j) {
631 augmented[[row, col]] = self.fill_value;
632 }
633 }
634 }
635
636 break;
637 }
638 }
639
640 Ok(augmented)
641 }
642}
643
644#[derive(Debug, Clone)]
656pub struct CutOutAugmenter {
657 pub cutout_size: usize,
659 pub probability: f64,
661 pub fill_value: f64,
663}
664
665impl CutOutAugmenter {
666 pub fn new(cutout_size: usize, probability: f64, fill_value: f64) -> TrainResult<Self> {
668 if cutout_size == 0 {
669 return Err(TrainError::InvalidParameter(
670 "cutout_size must be > 0".to_string(),
671 ));
672 }
673 if !(0.0..=1.0).contains(&probability) {
674 return Err(TrainError::InvalidParameter(
675 "probability must be in [0, 1]".to_string(),
676 ));
677 }
678
679 Ok(Self {
680 cutout_size,
681 probability,
682 fill_value,
683 })
684 }
685
686 pub fn with_size(size: usize) -> TrainResult<Self> {
688 Self::new(size, 1.0, 0.0)
689 }
690}
691
692impl DataAugmenter for CutOutAugmenter {
693 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
694 let mut augmented = data.to_owned();
695
696 if rng.random::<f64>() > self.probability {
698 return Ok(augmented);
699 }
700
701 let (height, width) = (data.nrows(), data.ncols());
702
703 let center_y = (rng.random::<f64>() * height as f64) as usize;
705 let center_x = (rng.random::<f64>() * width as f64) as usize;
706
707 let half_size = self.cutout_size / 2;
709
710 let y_start = center_y.saturating_sub(half_size);
711 let y_end = (center_y + half_size).min(height);
712
713 let x_start = center_x.saturating_sub(half_size);
714 let x_end = (center_x + half_size).min(width);
715
716 for i in y_start..y_end {
718 for j in x_start..x_end {
719 augmented[[i, j]] = self.fill_value;
720 }
721 }
722
723 Ok(augmented)
724 }
725}
726
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use scirs2_core::random::SeedableRng;

    /// Fixed-seed RNG so every test run sees the same random stream.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    // The identity augmenter returns an exact copy of its input.
    #[test]
    fn test_no_augmentation() {
        let augmenter = NoAugmentation;
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();
        assert_eq!(augmented, data);
    }

    // Gaussian noise keeps the shape, perturbs values, and stays small.
    #[test]
    fn test_noise_augmenter() {
        let augmenter = NoiseAugmenter::new(0.1).unwrap();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();

        assert_eq!(augmented.shape(), data.shape());

        // At least the first element must have been perturbed.
        assert_ne!(augmented[[0, 0]], data[[0, 0]]);

        // With std_dev = 0.1, a shift of 1.0 would be a 10-sigma event.
        for i in 0..data.nrows() {
            for j in 0..data.ncols() {
                let diff = (augmented[[i, j]] - data[[i, j]]).abs();
                assert!(diff < 1.0);
            }
        }
    }

    // A negative standard deviation is rejected by the constructor.
    #[test]
    fn test_noise_augmenter_invalid() {
        let result = NoiseAugmenter::new(-0.1);
        assert!(result.is_err());
    }

    // Every element is multiplied by the same factor within 1 ± scale_range.
    #[test]
    fn test_scale_augmenter() {
        let augmenter = ScaleAugmenter::new(0.2).unwrap();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();

        assert_eq!(augmented.shape(), data.shape());

        // Recover the factor from one element and check all others agree.
        let scale = augmented[[0, 0]] / data[[0, 0]];
        for i in 0..data.nrows() {
            for j in 0..data.ncols() {
                let computed_scale = augmented[[i, j]] / data[[i, j]];
                assert!((computed_scale - scale).abs() < 1e-10);
            }
        }

        // scale_range = 0.2 bounds the factor to [0.8, 1.2].
        assert!((0.8..=1.2).contains(&scale));
    }

    // scale_range outside [0, 1] is rejected.
    #[test]
    fn test_scale_augmenter_invalid() {
        assert!(ScaleAugmenter::new(-0.1).is_err());
        assert!(ScaleAugmenter::new(1.5).is_err());
    }

    // Only the output shape is checked for the rotation augmenter.
    #[test]
    fn test_rotation_augmenter() {
        let augmenter = RotationAugmenter::default();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();

        assert_eq!(augmented.shape(), data.shape());
    }

    // Mixup keeps shapes and produces convex combinations of the inputs.
    #[test]
    fn test_mixup_augmenter_batch() {
        let augmenter = MixupAugmenter::new(1.0).unwrap();
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];
        let mut rng = create_test_rng();

        let (aug_data, aug_labels) = augmenter
            .augment_batch(&data.view(), &labels.view(), &mut rng)
            .unwrap();

        assert_eq!(aug_data.shape(), data.shape());
        assert_eq!(aug_labels.shape(), labels.shape());

        // Convex combinations cannot leave the range spanned by the inputs.
        let data_min = data.iter().cloned().fold(f64::INFINITY, f64::min);
        let data_max = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

        for &val in aug_data.iter() {
            assert!(val >= data_min && val <= data_max);
        }
    }

    // alpha must be strictly positive.
    #[test]
    fn test_mixup_invalid_alpha() {
        let result = MixupAugmenter::new(0.0);
        assert!(result.is_err());

        let result = MixupAugmenter::new(-1.0);
        assert!(result.is_err());
    }

    // A row-count mismatch between data and labels is an error.
    #[test]
    fn test_mixup_mismatched_shapes() {
        let augmenter = MixupAugmenter::default();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        // One label row for two data rows.
        let labels = array![[1.0, 0.0]];
        let mut rng = create_test_rng();

        let result = augmenter.augment_batch(&data.view(), &labels.view(), &mut rng);
        assert!(result.is_err());
    }

    // CutMix keeps shapes and produces label rows that still sum to one.
    #[test]
    fn test_cutmix_augmenter_batch() {
        let augmenter = CutMixAugmenter::new(1.0).unwrap();
        let data = array![
            [1.0, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0]
        ];
        let labels = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];
        let mut rng = create_test_rng();

        let (aug_data, aug_labels) = augmenter
            .augment_batch(&data.view(), &labels.view(), &mut rng)
            .unwrap();

        assert_eq!(aug_data.shape(), data.shape());
        assert_eq!(aug_labels.shape(), labels.shape());

        for i in 0..aug_data.nrows() {
            let mut found_original = false;
            let mut found_different = false;

            for j in 0..aug_data.ncols() {
                if (aug_data[[i, j]] - data[[i, j]]).abs() < 1e-10 {
                    found_original = true;
                } else {
                    found_different = true;
                }
            }

            // NOTE(review): every element sets one of the two flags, so this
            // holds for any non-empty row; it does not verify that a cut happened.
            assert!(found_original || found_different);
        }

        // One-hot labels mixed with weights (1 - r, r) still sum to 1.
        for i in 0..aug_labels.nrows() {
            let sum: f64 = aug_labels.row(i).iter().sum();
            assert!((sum - 1.0).abs() < 1e-10, "Labels should sum to 1.0");
        }
    }

    // alpha must be strictly positive.
    #[test]
    fn test_cutmix_invalid_alpha() {
        let result = CutMixAugmenter::new(0.0);
        assert!(result.is_err());

        let result = CutMixAugmenter::new(-1.0);
        assert!(result.is_err());
    }

    // A row-count mismatch between data and labels is an error.
    #[test]
    fn test_cutmix_mismatched_shapes() {
        let augmenter = CutMixAugmenter::default();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        // One label row for two data rows.
        let labels = array![[1.0, 0.0]];
        let mut rng = create_test_rng();

        let result = augmenter.augment_batch(&data.view(), &labels.view(), &mut rng);
        assert!(result.is_err());
    }

    // Mixed labels remain a valid distribution: each entry in [0, 1], rows sum to 1.
    #[test]
    fn test_cutmix_label_proportions() {
        let augmenter = CutMixAugmenter::new(1.0).unwrap();
        let data = array![[10.0, 10.0, 10.0, 10.0], [20.0, 20.0, 20.0, 20.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0]];
        let mut rng = create_test_rng();

        let (aug_data, aug_labels) = augmenter
            .augment_batch(&data.view(), &labels.view(), &mut rng)
            .unwrap();

        for i in 0..aug_labels.nrows() {
            let sum: f64 = aug_labels.row(i).iter().sum();
            assert!((sum - 1.0).abs() < 1e-10);

            for j in 0..aug_labels.ncols() {
                assert!(aug_labels[[i, j]] >= 0.0);
                assert!(aug_labels[[i, j]] <= 1.0);
            }
        }

        assert_eq!(aug_data.shape(), data.shape());
    }

    // Noise followed by scaling should change the first element.
    #[test]
    fn test_composite_augmenter() {
        let mut composite = CompositeAugmenter::new();
        composite.add(NoiseAugmenter::new(0.01).unwrap());
        composite.add(ScaleAugmenter::new(0.1).unwrap());

        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = composite.augment(&data.view(), &mut rng).unwrap();

        assert_eq!(augmented.shape(), data.shape());

        assert_ne!(augmented[[0, 0]], data[[0, 0]]);
    }

    // An empty pipeline reports zero stages and returns the input unchanged.
    #[test]
    fn test_composite_empty() {
        let composite = CompositeAugmenter::new();
        assert!(composite.is_empty());
        assert_eq!(composite.len(), 0);

        let data = array![[1.0, 2.0]];
        let mut rng = create_test_rng();

        let augmented = composite.augment(&data.view(), &mut rng).unwrap();
        assert_eq!(augmented, data);
    }

    // Each `add` call appends one stage.
    #[test]
    fn test_composite_multiple() {
        let mut composite = CompositeAugmenter::new();
        composite.add(NoAugmentation);
        composite.add(ScaleAugmenter::default());
        composite.add(NoiseAugmenter::default());

        assert_eq!(composite.len(), 3);
        assert!(!composite.is_empty());
    }

    // Valid arguments are stored as given.
    #[test]
    fn test_random_erasing_creation() {
        let augmenter = RandomErasingAugmenter::new(0.5, 0.02, 0.33, 0.3, 3.3, 0.0).unwrap();
        assert_eq!(augmenter.probability, 0.5);
        assert_eq!(augmenter.scale_min, 0.02);
        assert_eq!(augmenter.scale_max, 0.33);
    }

    #[test]
    fn test_random_erasing_invalid_params() {
        // probability above 1
        assert!(RandomErasingAugmenter::new(1.5, 0.02, 0.33, 0.3, 3.3, 0.0).is_err());

        // scale_min greater than scale_max
        assert!(RandomErasingAugmenter::new(0.5, 0.33, 0.02, 0.3, 3.3, 0.0).is_err());

        // ratio_min greater than ratio_max
        assert!(RandomErasingAugmenter::new(0.5, 0.02, 0.33, 3.3, 0.3, 0.0).is_err());
    }

    // Only the shape is asserted; whether an erasure happens depends on the draw.
    #[test]
    fn test_random_erasing_augment() {
        let augmenter = RandomErasingAugmenter::with_defaults();
        let data = Array2::ones((10, 10));
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();

        assert_eq!(augmented.shape(), data.shape());

    }

    // With probability 0 the input must come back untouched.
    #[test]
    fn test_random_erasing_probability_zero() {
        let augmenter = RandomErasingAugmenter::new(0.0, 0.02, 0.33, 0.3, 3.3, 0.0).unwrap();
        let data = Array2::ones((10, 10));
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();

        assert_eq!(augmented, data);
    }

    // Valid arguments are stored as given.
    #[test]
    fn test_cutout_creation() {
        let augmenter = CutOutAugmenter::new(5, 1.0, 0.0).unwrap();
        assert_eq!(augmenter.cutout_size, 5);
        assert_eq!(augmenter.probability, 1.0);
        assert_eq!(augmenter.fill_value, 0.0);
    }

    #[test]
    fn test_cutout_invalid_params() {
        // zero-sized patch
        assert!(CutOutAugmenter::new(0, 1.0, 0.0).is_err());

        // probability above 1
        assert!(CutOutAugmenter::new(5, 1.5, 0.0).is_err());
    }

    // A size-3 patch on a 10x10 array of ones zeroes some, but not all, cells.
    #[test]
    fn test_cutout_augment() {
        let augmenter = CutOutAugmenter::with_size(3).unwrap();
        let data = Array2::ones((10, 10));
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();

        assert_eq!(augmented.shape(), data.shape());

        // `with_size` fills with 0.0, so count zeros to measure the patch.
        let zeros_count = augmented.iter().filter(|&&x| x == 0.0).count();
        assert!(zeros_count > 0, "Expected some values to be erased");
        assert!(zeros_count < 100, "Not all values should be erased");
    }

    // With probability 0 the input must come back untouched.
    #[test]
    fn test_cutout_probability_zero() {
        let augmenter = CutOutAugmenter::new(5, 0.0, 0.0).unwrap();
        let data = Array2::ones((10, 10));
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();

        assert_eq!(augmented, data);
    }

    // A custom fill value is actually written into the patch.
    #[test]
    fn test_cutout_fill_value() {
        let augmenter = CutOutAugmenter::new(3, 1.0, 0.5).unwrap();
        let data = Array2::ones((10, 10));
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();

        let filled_count = augmented.iter().filter(|&&x| x == 0.5).count();
        assert!(
            filled_count > 0,
            "Expected some values to be filled with 0.5"
        );
    }
}