1use crate::{TrainError, TrainResult};
2use scirs2_core::ndarray::{Array, Array2, ArrayView2};
3use scirs2_core::random::{RngExt, StdRng};
4
5pub trait DataAugmenter {
7 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>>;
16}
17
/// Augmenter that leaves its input untouched (identity transformation).
#[derive(Clone, Debug, Default)]
pub struct NoAugmentation;
23
24impl DataAugmenter for NoAugmentation {
25 fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
26 Ok(data.to_owned())
27 }
28}
29
/// Augmenter that perturbs every element with additive random noise.
#[derive(Clone, Debug)]
pub struct NoiseAugmenter {
    /// Factor by which each sampled noise value is multiplied before it is added.
    pub std_dev: f64,
}
38
39impl NoiseAugmenter {
40 pub fn new(std_dev: f64) -> TrainResult<Self> {
45 if std_dev < 0.0 {
46 return Err(TrainError::InvalidParameter(
47 "std_dev must be non-negative".to_string(),
48 ));
49 }
50 Ok(Self { std_dev })
51 }
52}
53
54impl Default for NoiseAugmenter {
55 fn default() -> Self {
56 Self { std_dev: 0.01 }
57 }
58}
59
60impl DataAugmenter for NoiseAugmenter {
61 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
62 let mut augmented = data.to_owned();
63
64 for value in augmented.iter_mut() {
66 let u1: f64 = rng.random();
67 let u2: f64 = rng.random();
68
69 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
71 let noise = z0 * self.std_dev;
72
73 *value += noise;
74 }
75
76 Ok(augmented)
77 }
78}
79
/// Augmenter that multiplies the whole input by a single random factor.
#[derive(Clone, Debug)]
pub struct ScaleAugmenter {
    /// Maximum relative deviation of the scale factor from `1.0`.
    pub scale_range: f64,
}
88
89impl ScaleAugmenter {
90 pub fn new(scale_range: f64) -> TrainResult<Self> {
95 if !(0.0..=1.0).contains(&scale_range) {
96 return Err(TrainError::InvalidParameter(
97 "scale_range must be in [0, 1]".to_string(),
98 ));
99 }
100 Ok(Self { scale_range })
101 }
102}
103
104impl Default for ScaleAugmenter {
105 fn default() -> Self {
106 Self { scale_range: 0.1 }
107 }
108}
109
110impl DataAugmenter for ScaleAugmenter {
111 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
112 let scale = 1.0 + (rng.random::<f64>() * 2.0 - 1.0) * self.scale_range;
114
115 let augmented = data.mapv(|x| x * scale);
116 Ok(augmented)
117 }
118}
119
/// Augmenter that applies a rotation-inspired rescaling driven by a random angle.
#[derive(Clone, Debug)]
pub struct RotationAugmenter {
    /// Largest magnitude of the sampled angle, in radians.
    pub max_angle: f64,
}
129
130impl RotationAugmenter {
131 pub fn new(max_angle: f64) -> Self {
136 Self { max_angle }
137 }
138}
139
140impl Default for RotationAugmenter {
141 fn default() -> Self {
142 Self {
143 max_angle: std::f64::consts::PI / 18.0, }
145 }
146}
147
148impl DataAugmenter for RotationAugmenter {
149 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
150 let angle = (rng.random::<f64>() * 2.0 - 1.0) * self.max_angle;
153
154 let cos_a = angle.cos();
156 let sin_a = angle.sin();
157
158 let augmented = data.mapv(|x| x * cos_a + x * sin_a * 0.1);
159 Ok(augmented)
160 }
161}
162
/// Augmenter that blends pairs of samples (and their labels) with a random weight.
#[derive(Clone, Debug)]
pub struct MixupAugmenter {
    /// Parameter that selects how the mixing weight is sampled.
    pub alpha: f64,
}
174
175impl MixupAugmenter {
176 pub fn new(alpha: f64) -> TrainResult<Self> {
181 if alpha <= 0.0 {
182 return Err(TrainError::InvalidParameter(
183 "alpha must be positive".to_string(),
184 ));
185 }
186 Ok(Self { alpha })
187 }
188
189 pub fn augment_batch(
199 &self,
200 data: &ArrayView2<f64>,
201 labels: &ArrayView2<f64>,
202 rng: &mut StdRng,
203 ) -> TrainResult<(Array2<f64>, Array2<f64>)> {
204 if data.nrows() != labels.nrows() {
205 return Err(TrainError::InvalidParameter(
206 "data and labels must have same number of rows".to_string(),
207 ));
208 }
209
210 let n = data.nrows();
211 let mut augmented_data = Array::zeros(data.raw_dim());
212 let mut augmented_labels = Array::zeros(labels.raw_dim());
213
214 let mut indices: Vec<usize> = (0..n).collect();
216 for i in (1..n).rev() {
217 let j = rng.gen_range(0..=i);
218 indices.swap(i, j);
219 }
220
221 for i in 0..n {
222 let j = indices[i];
223
224 let lambda = self.sample_beta(rng);
227
228 for k in 0..data.ncols() {
230 augmented_data[[i, k]] = lambda * data[[i, k]] + (1.0 - lambda) * data[[j, k]];
231 }
232
233 for k in 0..labels.ncols() {
235 augmented_labels[[i, k]] =
236 lambda * labels[[i, k]] + (1.0 - lambda) * labels[[j, k]];
237 }
238 }
239
240 Ok((augmented_data, augmented_labels))
241 }
242
243 fn sample_beta(&self, rng: &mut StdRng) -> f64 {
247 if self.alpha < 0.5 {
248 if rng.random::<f64>() < 0.5 {
250 rng.random::<f64>().powf(2.0)
251 } else {
252 1.0 - rng.random::<f64>().powf(2.0)
253 }
254 } else {
255 rng.random::<f64>()
257 }
258 }
259}
260
261impl Default for MixupAugmenter {
262 fn default() -> Self {
263 Self { alpha: 1.0 }
264 }
265}
266
267impl DataAugmenter for MixupAugmenter {
268 fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
269 Ok(data.to_owned())
272 }
273}
274
/// Augmenter that replaces a contiguous span of features with those of another sample.
#[derive(Clone, Debug)]
pub struct CutMixAugmenter {
    /// Parameter that selects how the mixing weight is sampled.
    pub alpha: f64,
}
288
289impl CutMixAugmenter {
290 pub fn new(alpha: f64) -> TrainResult<Self> {
298 if alpha <= 0.0 {
299 return Err(TrainError::InvalidParameter(
300 "alpha must be positive".to_string(),
301 ));
302 }
303 Ok(Self { alpha })
304 }
305
306 pub fn augment_batch(
319 &self,
320 data: &ArrayView2<f64>,
321 labels: &ArrayView2<f64>,
322 rng: &mut StdRng,
323 ) -> TrainResult<(Array2<f64>, Array2<f64>)> {
324 if data.nrows() != labels.nrows() {
325 return Err(TrainError::InvalidParameter(
326 "data and labels must have same number of rows".to_string(),
327 ));
328 }
329
330 let n = data.nrows();
331 let features = data.ncols();
332 let mut augmented_data = data.to_owned();
333 let mut augmented_labels = labels.to_owned();
334
335 let mut indices: Vec<usize> = (0..n).collect();
337 for i in (1..n).rev() {
338 let j = rng.gen_range(0..=i);
339 indices.swap(i, j);
340 }
341
342 for i in 0..n {
343 let j = indices[i];
344
345 let lambda = self.sample_beta(rng);
347
348 let cut_ratio = (1.0 - lambda).sqrt();
351 let cut_size = (features as f64 * cut_ratio) as usize;
352 let cut_size = cut_size.max(1).min(features - 1);
353
354 let start = if features > cut_size {
356 rng.gen_range(0..=(features - cut_size))
357 } else {
358 0
359 };
360
361 for k in start..(start + cut_size).min(features) {
363 augmented_data[[i, k]] = data[[j, k]];
364 }
365
366 let actual_ratio = cut_size as f64 / features as f64;
368 for k in 0..labels.ncols() {
369 augmented_labels[[i, k]] =
370 (1.0 - actual_ratio) * labels[[i, k]] + actual_ratio * labels[[j, k]];
371 }
372 }
373
374 Ok((augmented_data, augmented_labels))
375 }
376
377 fn sample_beta(&self, rng: &mut StdRng) -> f64 {
381 if self.alpha < 0.5 {
382 if rng.random::<f64>() < 0.5 {
384 rng.random::<f64>().powf(2.0)
385 } else {
386 1.0 - rng.random::<f64>().powf(2.0)
387 }
388 } else {
389 rng.random::<f64>()
391 }
392 }
393}
394
395impl Default for CutMixAugmenter {
396 fn default() -> Self {
397 Self { alpha: 1.0 }
398 }
399}
400
401impl DataAugmenter for CutMixAugmenter {
402 fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
403 Ok(data.to_owned())
406 }
407}
408
409#[derive(Clone, Default)]
411pub struct CompositeAugmenter {
412 augmenters: Vec<Box<dyn AugmenterClone>>,
413}
414
415impl std::fmt::Debug for CompositeAugmenter {
416 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417 f.debug_struct("CompositeAugmenter")
418 .field("num_augmenters", &self.augmenters.len())
419 .finish()
420 }
421}
422
423trait AugmenterClone: DataAugmenter {
425 fn clone_box(&self) -> Box<dyn AugmenterClone>;
426}
427
428impl<T: DataAugmenter + Clone + 'static> AugmenterClone for T {
429 fn clone_box(&self) -> Box<dyn AugmenterClone> {
430 Box::new(self.clone())
431 }
432}
433
434impl Clone for Box<dyn AugmenterClone> {
435 fn clone(&self) -> Self {
436 self.clone_box()
437 }
438}
439
440impl DataAugmenter for Box<dyn AugmenterClone> {
441 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
442 (**self).augment(data, rng)
443 }
444}
445
446impl CompositeAugmenter {
447 pub fn new() -> Self {
449 Self {
450 augmenters: Vec::new(),
451 }
452 }
453
454 pub fn add<A: DataAugmenter + Clone + 'static>(&mut self, augmenter: A) {
456 self.augmenters.push(Box::new(augmenter));
457 }
458
459 pub fn len(&self) -> usize {
461 self.augmenters.len()
462 }
463
464 pub fn is_empty(&self) -> bool {
466 self.augmenters.is_empty()
467 }
468}
469
470impl DataAugmenter for CompositeAugmenter {
471 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
472 let mut result = data.to_owned();
473
474 for augmenter in &self.augmenters {
475 result = augmenter.augment(&result.view(), rng)?;
476 }
477
478 Ok(result)
479 }
480}
481
/// Augmenter that overwrites one random rectangular region of the input.
#[derive(Clone, Debug)]
pub struct RandomErasingAugmenter {
    /// Probability that a call erases anything at all.
    pub probability: f64,
    /// Lower bound of the erased area as a fraction of the total area.
    pub scale_min: f64,
    /// Upper bound of the erased area as a fraction of the total area.
    pub scale_max: f64,
    /// Lower bound of the sampled aspect ratio of the erased region.
    pub ratio_min: f64,
    /// Upper bound of the sampled aspect ratio of the erased region.
    pub ratio_max: f64,
    /// Value written into the erased region; exactly `1.0` selects random fill instead.
    pub fill_value: f64,
}
509
510impl RandomErasingAugmenter {
511 pub fn new(
513 probability: f64,
514 scale_min: f64,
515 scale_max: f64,
516 ratio_min: f64,
517 ratio_max: f64,
518 fill_value: f64,
519 ) -> TrainResult<Self> {
520 if !(0.0..=1.0).contains(&probability) {
521 return Err(TrainError::InvalidParameter(
522 "probability must be in [0, 1]".to_string(),
523 ));
524 }
525 if scale_min >= scale_max || scale_min < 0.0 || scale_max > 1.0 {
526 return Err(TrainError::InvalidParameter(
527 "scale range must be valid (0 <= min < max <= 1)".to_string(),
528 ));
529 }
530 if ratio_min <= 0.0 || ratio_min >= ratio_max {
531 return Err(TrainError::InvalidParameter(
532 "ratio range must be valid (0 < min < max)".to_string(),
533 ));
534 }
535
536 Ok(Self {
537 probability,
538 scale_min,
539 scale_max,
540 ratio_min,
541 ratio_max,
542 fill_value,
543 })
544 }
545
546 pub fn with_defaults() -> Self {
548 Self {
549 probability: 0.5,
550 scale_min: 0.02,
551 scale_max: 0.33,
552 ratio_min: 0.3,
553 ratio_max: 3.3,
554 fill_value: 0.0,
555 }
556 }
557}
558
559impl Default for RandomErasingAugmenter {
560 fn default() -> Self {
561 Self::with_defaults()
562 }
563}
564
565impl DataAugmenter for RandomErasingAugmenter {
566 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
567 let mut augmented = data.to_owned();
568
569 if rng.random::<f64>() > self.probability {
571 return Ok(augmented);
572 }
573
574 let (height, width) = (data.nrows(), data.ncols());
575 let area = (height * width) as f64;
576
577 for _ in 0..10 {
579 let scale = self.scale_min + rng.random::<f64>() * (self.scale_max - self.scale_min);
581 let erase_area = area * scale;
582
583 let aspect_ratio =
585 self.ratio_min + rng.random::<f64>() * (self.ratio_max - self.ratio_min);
586
587 let h = (erase_area * aspect_ratio).sqrt().min(height as f64);
589 let w = (erase_area / aspect_ratio).sqrt().min(width as f64);
590
591 if h >= 1.0 && w >= 1.0 {
592 let erase_h = h as usize;
593 let erase_w = w as usize;
594
595 let i = if erase_h < height {
597 (rng.random::<f64>() * (height - erase_h) as f64) as usize
598 } else {
599 0
600 };
601 let j = if erase_w < width {
602 (rng.random::<f64>() * (width - erase_w) as f64) as usize
603 } else {
604 0
605 };
606
607 if self.fill_value == 1.0 {
609 for row in i..i + erase_h.min(height - i) {
611 for col in j..j + erase_w.min(width - j) {
612 augmented[[row, col]] = rng.random();
613 }
614 }
615 } else {
616 for row in i..i + erase_h.min(height - i) {
618 for col in j..j + erase_w.min(width - j) {
619 augmented[[row, col]] = self.fill_value;
620 }
621 }
622 }
623
624 break;
625 }
626 }
627
628 Ok(augmented)
629 }
630}
631
/// Augmenter that overwrites a square region around a random centre.
#[derive(Clone, Debug)]
pub struct CutOutAugmenter {
    /// Side length used to size the erased square.
    pub cutout_size: usize,
    /// Probability that a call erases anything at all.
    pub probability: f64,
    /// Value written into the erased region.
    pub fill_value: f64,
}
652
653impl CutOutAugmenter {
654 pub fn new(cutout_size: usize, probability: f64, fill_value: f64) -> TrainResult<Self> {
656 if cutout_size == 0 {
657 return Err(TrainError::InvalidParameter(
658 "cutout_size must be > 0".to_string(),
659 ));
660 }
661 if !(0.0..=1.0).contains(&probability) {
662 return Err(TrainError::InvalidParameter(
663 "probability must be in [0, 1]".to_string(),
664 ));
665 }
666
667 Ok(Self {
668 cutout_size,
669 probability,
670 fill_value,
671 })
672 }
673
674 pub fn with_size(size: usize) -> TrainResult<Self> {
676 Self::new(size, 1.0, 0.0)
677 }
678}
679
680impl DataAugmenter for CutOutAugmenter {
681 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
682 let mut augmented = data.to_owned();
683
684 if rng.random::<f64>() > self.probability {
686 return Ok(augmented);
687 }
688
689 let (height, width) = (data.nrows(), data.ncols());
690
691 let center_y = (rng.random::<f64>() * height as f64) as usize;
693 let center_x = (rng.random::<f64>() * width as f64) as usize;
694
695 let half_size = self.cutout_size / 2;
697
698 let y_start = center_y.saturating_sub(half_size);
699 let y_end = (center_y + half_size).min(height);
700
701 let x_start = center_x.saturating_sub(half_size);
702 let x_end = (center_x + half_size).min(width);
703
704 for i in y_start..y_end {
706 for j in x_start..x_end {
707 augmented[[i, j]] = self.fill_value;
708 }
709 }
710
711 Ok(augmented)
712 }
713}
714
#[cfg(test)]
mod tests {
    //! Unit tests for the augmenters. All tests use a fixed seed so they are repeatable.

    use super::*;
    use scirs2_core::ndarray::array;
    use scirs2_core::random::SeedableRng;

    /// Deterministic generator shared by every test.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    #[test]
    fn test_no_augmentation() {
        let augmenter = NoAugmentation;
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");
        assert_eq!(augmented, data);
    }

    #[test]
    fn test_noise_augmenter() {
        let augmenter = NoiseAugmenter::new(0.1).expect("unwrap");
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");

        assert_eq!(augmented.shape(), data.shape());

        // Noise must actually change the data...
        assert_ne!(augmented[[0, 0]], data[[0, 0]]);

        // ...but stay small for a standard deviation of 0.1.
        for i in 0..data.nrows() {
            for j in 0..data.ncols() {
                let diff = (augmented[[i, j]] - data[[i, j]]).abs();
                assert!(diff < 1.0);
            }
        }
    }

    #[test]
    fn test_noise_augmenter_invalid() {
        let result = NoiseAugmenter::new(-0.1);
        assert!(result.is_err());
    }

    #[test]
    fn test_scale_augmenter() {
        let augmenter = ScaleAugmenter::new(0.2).expect("unwrap");
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");

        assert_eq!(augmented.shape(), data.shape());

        // Every element must have been scaled by the same factor.
        let scale = augmented[[0, 0]] / data[[0, 0]];
        for i in 0..data.nrows() {
            for j in 0..data.ncols() {
                let computed_scale = augmented[[i, j]] / data[[i, j]];
                assert!((computed_scale - scale).abs() < 1e-10);
            }
        }

        assert!((0.8..=1.2).contains(&scale));
    }

    #[test]
    fn test_scale_augmenter_invalid() {
        assert!(ScaleAugmenter::new(-0.1).is_err());
        assert!(ScaleAugmenter::new(1.5).is_err());
    }

    #[test]
    fn test_rotation_augmenter() {
        let augmenter = RotationAugmenter::default();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");

        assert_eq!(augmented.shape(), data.shape());
    }

    #[test]
    fn test_mixup_augmenter_batch() {
        let augmenter = MixupAugmenter::new(1.0).expect("unwrap");
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];
        let mut rng = create_test_rng();

        let (aug_data, aug_labels) = augmenter
            .augment_batch(&data.view(), &labels.view(), &mut rng)
            .expect("unwrap");

        assert_eq!(aug_data.shape(), data.shape());
        assert_eq!(aug_labels.shape(), labels.shape());

        // A convex combination can never leave the range of the inputs.
        let data_min = data.iter().cloned().fold(f64::INFINITY, f64::min);
        let data_max = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

        for &val in aug_data.iter() {
            assert!(val >= data_min && val <= data_max);
        }
    }

    #[test]
    fn test_mixup_invalid_alpha() {
        let result = MixupAugmenter::new(0.0);
        assert!(result.is_err());

        let result = MixupAugmenter::new(-1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_mixup_mismatched_shapes() {
        let augmenter = MixupAugmenter::default();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        // One label row for two data rows.
        let labels = array![[1.0, 0.0]];
        let mut rng = create_test_rng();

        let result = augmenter.augment_batch(&data.view(), &labels.view(), &mut rng);
        assert!(result.is_err());
    }

    #[test]
    fn test_cutmix_augmenter_batch() {
        let augmenter = CutMixAugmenter::new(1.0).expect("unwrap");
        let data = array![
            [1.0, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0]
        ];
        let labels = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];
        let mut rng = create_test_rng();

        let (aug_data, aug_labels) = augmenter
            .augment_batch(&data.view(), &labels.view(), &mut rng)
            .expect("unwrap");

        assert_eq!(aug_data.shape(), data.shape());
        assert_eq!(aug_labels.shape(), labels.shape());

        // The pasted span covers at most `features - 1` columns, so every row must
        // keep at least one of its own values. A row paired with itself keeps all.
        for i in 0..aug_data.nrows() {
            let found_original =
                (0..aug_data.ncols()).any(|j| (aug_data[[i, j]] - data[[i, j]]).abs() < 1e-10);

            assert!(
                found_original,
                "Each row should keep at least one original value"
            );
        }

        for i in 0..aug_labels.nrows() {
            let sum: f64 = aug_labels.row(i).iter().sum();
            assert!((sum - 1.0).abs() < 1e-10, "Labels should sum to 1.0");
        }
    }

    #[test]
    fn test_cutmix_invalid_alpha() {
        let result = CutMixAugmenter::new(0.0);
        assert!(result.is_err());

        let result = CutMixAugmenter::new(-1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_cutmix_mismatched_shapes() {
        let augmenter = CutMixAugmenter::default();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        // One label row for two data rows.
        let labels = array![[1.0, 0.0]];
        let mut rng = create_test_rng();

        let result = augmenter.augment_batch(&data.view(), &labels.view(), &mut rng);
        assert!(result.is_err());
    }

    #[test]
    fn test_cutmix_label_proportions() {
        let augmenter = CutMixAugmenter::new(1.0).expect("unwrap");
        let data = array![[10.0, 10.0, 10.0, 10.0], [20.0, 20.0, 20.0, 20.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0]];
        let mut rng = create_test_rng();

        let (aug_data, aug_labels) = augmenter
            .augment_batch(&data.view(), &labels.view(), &mut rng)
            .expect("unwrap");

        // Mixed one-hot labels stay a valid probability vector.
        for i in 0..aug_labels.nrows() {
            let sum: f64 = aug_labels.row(i).iter().sum();
            assert!((sum - 1.0).abs() < 1e-10);

            for j in 0..aug_labels.ncols() {
                assert!(aug_labels[[i, j]] >= 0.0);
                assert!(aug_labels[[i, j]] <= 1.0);
            }
        }

        assert_eq!(aug_data.shape(), data.shape());
    }

    #[test]
    fn test_composite_augmenter() {
        let mut composite = CompositeAugmenter::new();
        composite.add(NoiseAugmenter::new(0.01).expect("unwrap"));
        composite.add(ScaleAugmenter::new(0.1).expect("unwrap"));

        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = composite.augment(&data.view(), &mut rng).expect("unwrap");

        assert_eq!(augmented.shape(), data.shape());

        assert_ne!(augmented[[0, 0]], data[[0, 0]]);
    }

    #[test]
    fn test_composite_empty() {
        let composite = CompositeAugmenter::new();
        assert!(composite.is_empty());
        assert_eq!(composite.len(), 0);

        let data = array![[1.0, 2.0]];
        let mut rng = create_test_rng();

        // An empty pipeline is the identity.
        let augmented = composite.augment(&data.view(), &mut rng).expect("unwrap");
        assert_eq!(augmented, data);
    }

    #[test]
    fn test_composite_multiple() {
        let mut composite = CompositeAugmenter::new();
        composite.add(NoAugmentation);
        composite.add(ScaleAugmenter::default());
        composite.add(NoiseAugmenter::default());

        assert_eq!(composite.len(), 3);
        assert!(!composite.is_empty());
    }

    #[test]
    fn test_random_erasing_creation() {
        let augmenter =
            RandomErasingAugmenter::new(0.5, 0.02, 0.33, 0.3, 3.3, 0.0).expect("unwrap");
        assert_eq!(augmenter.probability, 0.5);
        assert_eq!(augmenter.scale_min, 0.02);
        assert_eq!(augmenter.scale_max, 0.33);
    }

    #[test]
    fn test_random_erasing_invalid_params() {
        // Probability above 1.
        assert!(RandomErasingAugmenter::new(1.5, 0.02, 0.33, 0.3, 3.3, 0.0).is_err());

        // Scale bounds swapped.
        assert!(RandomErasingAugmenter::new(0.5, 0.33, 0.02, 0.3, 3.3, 0.0).is_err());

        // Ratio bounds swapped.
        assert!(RandomErasingAugmenter::new(0.5, 0.02, 0.33, 3.3, 0.3, 0.0).is_err());
    }

    #[test]
    fn test_random_erasing_augment() {
        let augmenter = RandomErasingAugmenter::with_defaults();
        let data = scirs2_core::ndarray::Array2::ones((10, 10));
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");

        // Erasing is probabilistic, so only the shape is guaranteed.
        assert_eq!(augmented.shape(), data.shape());
    }

    #[test]
    fn test_random_erasing_probability_zero() {
        let augmenter =
            RandomErasingAugmenter::new(0.0, 0.02, 0.33, 0.3, 3.3, 0.0).expect("unwrap");
        let data = scirs2_core::ndarray::Array2::ones((10, 10));
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");

        assert_eq!(augmented, data);
    }

    #[test]
    fn test_cutout_creation() {
        let augmenter = CutOutAugmenter::new(5, 1.0, 0.0).expect("unwrap");
        assert_eq!(augmenter.cutout_size, 5);
        assert_eq!(augmenter.probability, 1.0);
        assert_eq!(augmenter.fill_value, 0.0);
    }

    #[test]
    fn test_cutout_invalid_params() {
        // Zero-sized cutout.
        assert!(CutOutAugmenter::new(0, 1.0, 0.0).is_err());

        // Probability above 1.
        assert!(CutOutAugmenter::new(5, 1.5, 0.0).is_err());
    }

    #[test]
    fn test_cutout_augment() {
        let augmenter = CutOutAugmenter::with_size(3).expect("unwrap");
        let data = scirs2_core::ndarray::Array2::ones((10, 10));
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");

        assert_eq!(augmented.shape(), data.shape());

        // Input is all ones, so zeros can only come from the cutout.
        let zeros_count = augmented.iter().filter(|&&x| x == 0.0).count();
        assert!(zeros_count > 0, "Expected some values to be erased");
        assert!(zeros_count < 100, "Not all values should be erased");
    }

    #[test]
    fn test_cutout_probability_zero() {
        let augmenter = CutOutAugmenter::new(5, 0.0, 0.0).expect("unwrap");
        let data = scirs2_core::ndarray::Array2::ones((10, 10));
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");

        assert_eq!(augmented, data);
    }

    #[test]
    fn test_cutout_fill_value() {
        let augmenter = CutOutAugmenter::new(3, 1.0, 0.5).expect("unwrap");
        let data = scirs2_core::ndarray::Array2::ones((10, 10));
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");

        let filled_count = augmented.iter().filter(|&&x| x == 0.5).count();
        assert!(
            filled_count > 0,
            "Expected some values to be filled with 0.5"
        );
    }
}