Skip to main content

tensorlogic_train/augmentation/
trait_api.rs

1use crate::{TrainError, TrainResult};
2use scirs2_core::ndarray::{Array, Array2, ArrayView2};
3use scirs2_core::random::{RngExt, StdRng};
4
5/// Trait for data augmentation strategies.
6pub trait DataAugmenter {
7    /// Augment the given data.
8    ///
9    /// # Arguments
10    /// * `data` - Input data to augment
11    /// * `rng` - Random number generator for stochastic augmentation
12    ///
13    /// # Returns
14    /// Augmented data with the same shape as input
15    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>>;
16}
17
/// Identity augmenter: leaves the data untouched.
///
/// Handy in tests or wherever an augmenter is required but no change is wanted.
#[derive(Clone, Debug, Default)]
pub struct NoAugmentation;
23
24impl DataAugmenter for NoAugmentation {
25    fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
26        Ok(data.to_owned())
27    }
28}
29
/// Additive Gaussian noise augmenter.
///
/// Each element is perturbed independently: x' = x + N(0, σ²).
#[derive(Clone, Debug)]
pub struct NoiseAugmenter {
    /// Standard deviation σ of the added noise.
    pub std_dev: f64,
}
38
39impl NoiseAugmenter {
40    /// Create a new noise augmenter.
41    ///
42    /// # Arguments
43    /// * `std_dev` - Standard deviation of the noise
44    pub fn new(std_dev: f64) -> TrainResult<Self> {
45        if std_dev < 0.0 {
46            return Err(TrainError::InvalidParameter(
47                "std_dev must be non-negative".to_string(),
48            ));
49        }
50        Ok(Self { std_dev })
51    }
52}
53
54impl Default for NoiseAugmenter {
55    fn default() -> Self {
56        Self { std_dev: 0.01 }
57    }
58}
59
60impl DataAugmenter for NoiseAugmenter {
61    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
62        let mut augmented = data.to_owned();
63
64        // Add Gaussian noise using Box-Muller transform
65        for value in augmented.iter_mut() {
66            let u1: f64 = rng.random();
67            let u2: f64 = rng.random();
68
69            // Box-Muller transform
70            let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
71            let noise = z0 * self.std_dev;
72
73            *value += noise;
74        }
75
76        Ok(augmented)
77    }
78}
79
/// Random global scaling augmenter.
///
/// Multiplies the whole input by one factor drawn from
/// [1 - scale_range, 1 + scale_range].
#[derive(Clone, Debug)]
pub struct ScaleAugmenter {
    /// Half-width of the scaling interval (e.g. 0.1 scales within [0.9, 1.1]).
    pub scale_range: f64,
}
88
89impl ScaleAugmenter {
90    /// Create a new scale augmenter.
91    ///
92    /// # Arguments
93    /// * `scale_range` - Range of scaling factor
94    pub fn new(scale_range: f64) -> TrainResult<Self> {
95        if !(0.0..=1.0).contains(&scale_range) {
96            return Err(TrainError::InvalidParameter(
97                "scale_range must be in [0, 1]".to_string(),
98            ));
99        }
100        Ok(Self { scale_range })
101    }
102}
103
104impl Default for ScaleAugmenter {
105    fn default() -> Self {
106        Self { scale_range: 0.1 }
107    }
108}
109
110impl DataAugmenter for ScaleAugmenter {
111    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
112        // Generate random scale factor
113        let scale = 1.0 + (rng.random::<f64>() * 2.0 - 1.0) * self.scale_range;
114
115        let augmented = data.mapv(|x| x * scale);
116        Ok(augmented)
117    }
118}
119
/// Rotation augmenter (placeholder implementation).
///
/// A real version would rotate 2-D images by a random angle. For now only a
/// simplified, rotation-inspired transform for tabular data is applied.
#[derive(Clone, Debug)]
pub struct RotationAugmenter {
    /// Largest rotation angle, in radians.
    pub max_angle: f64,
}
129
130impl RotationAugmenter {
131    /// Create a new rotation augmenter.
132    ///
133    /// # Arguments
134    /// * `max_angle` - Maximum rotation angle in radians
135    pub fn new(max_angle: f64) -> Self {
136        Self { max_angle }
137    }
138}
139
140impl Default for RotationAugmenter {
141    fn default() -> Self {
142        Self {
143            max_angle: std::f64::consts::PI / 18.0, // 10 degrees
144        }
145    }
146}
147
148impl DataAugmenter for RotationAugmenter {
149    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
150        // For now, this is a placeholder that returns a simple transformation
151        // Future: implement proper 2D rotation for image data
152        let angle = (rng.random::<f64>() * 2.0 - 1.0) * self.max_angle;
153
154        // Apply a simple rotation-inspired transformation
155        let cos_a = angle.cos();
156        let sin_a = angle.sin();
157
158        let augmented = data.mapv(|x| x * cos_a + x * sin_a * 0.1);
159        Ok(augmented)
160    }
161}
162
/// Mixup augmenter.
///
/// Builds new training samples as convex combinations of sample pairs:
/// x' = λ * x₁ + (1 - λ) * x₂, y' = λ * y₁ + (1 - λ) * y₂
///
/// Reference: Zhang et al., "mixup: Beyond Empirical Risk Minimization", ICLR 2018
#[derive(Clone, Debug)]
pub struct MixupAugmenter {
    /// Alpha parameter of the Beta distribution that governs mixing strength.
    pub alpha: f64,
}
174
175impl MixupAugmenter {
176    /// Create a new mixup augmenter.
177    ///
178    /// # Arguments
179    /// * `alpha` - Alpha parameter for Beta distribution
180    pub fn new(alpha: f64) -> TrainResult<Self> {
181        if alpha <= 0.0 {
182            return Err(TrainError::InvalidParameter(
183                "alpha must be positive".to_string(),
184            ));
185        }
186        Ok(Self { alpha })
187    }
188
189    /// Apply mixup to a batch of data.
190    ///
191    /// # Arguments
192    /// * `data` - Input data batch [N, features]
193    /// * `labels` - Corresponding labels [N, classes]
194    /// * `rng` - Random number generator
195    ///
196    /// # Returns
197    /// Tuple of (augmented_data, augmented_labels)
198    pub fn augment_batch(
199        &self,
200        data: &ArrayView2<f64>,
201        labels: &ArrayView2<f64>,
202        rng: &mut StdRng,
203    ) -> TrainResult<(Array2<f64>, Array2<f64>)> {
204        if data.nrows() != labels.nrows() {
205            return Err(TrainError::InvalidParameter(
206                "data and labels must have same number of rows".to_string(),
207            ));
208        }
209
210        let n = data.nrows();
211        let mut augmented_data = Array::zeros(data.raw_dim());
212        let mut augmented_labels = Array::zeros(labels.raw_dim());
213
214        // Create random permutation
215        let mut indices: Vec<usize> = (0..n).collect();
216        for i in (1..n).rev() {
217            let j = rng.gen_range(0..=i);
218            indices.swap(i, j);
219        }
220
221        for i in 0..n {
222            let j = indices[i];
223
224            // Sample mixing coefficient from Beta distribution
225            // Simplified: use uniform distribution as approximation
226            let lambda = self.sample_beta(rng);
227
228            // Mix data: x' = λ * x_i + (1 - λ) * x_j
229            for k in 0..data.ncols() {
230                augmented_data[[i, k]] = lambda * data[[i, k]] + (1.0 - lambda) * data[[j, k]];
231            }
232
233            // Mix labels: y' = λ * y_i + (1 - λ) * y_j
234            for k in 0..labels.ncols() {
235                augmented_labels[[i, k]] =
236                    lambda * labels[[i, k]] + (1.0 - lambda) * labels[[j, k]];
237            }
238        }
239
240        Ok((augmented_data, augmented_labels))
241    }
242
243    /// Sample from Beta(alpha, alpha) distribution.
244    ///
245    /// Simplified implementation using uniform distribution when alpha is close to 1.
246    fn sample_beta(&self, rng: &mut StdRng) -> f64 {
247        if self.alpha < 0.5 {
248            // For small alpha, prefer values near 0 or 1
249            if rng.random::<f64>() < 0.5 {
250                rng.random::<f64>().powf(2.0)
251            } else {
252                1.0 - rng.random::<f64>().powf(2.0)
253            }
254        } else {
255            // For alpha >= 0.5, approximate with uniform
256            rng.random::<f64>()
257        }
258    }
259}
260
261impl Default for MixupAugmenter {
262    fn default() -> Self {
263        Self { alpha: 1.0 }
264    }
265}
266
267impl DataAugmenter for MixupAugmenter {
268    fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
269        // For single-sample augmentation, mix with itself (no-op)
270        // In practice, mixup should be used with augment_batch
271        Ok(data.to_owned())
272    }
273}
274
/// CutMix augmenter (ICCV 2019).
///
/// Where Mixup blends whole samples, CutMix cuts a rectangular region out of
/// one sample and pastes it into another. Labels are mixed in proportion to
/// the area of the pasted patch.
///
/// Reference: Yun et al. "CutMix: Regularization Strategy to Train Strong Classifiers
/// with Localizable Features" (ICCV 2019)
#[derive(Clone, Debug)]
pub struct CutMixAugmenter {
    /// Beta distribution parameter used when sampling the mixing ratio.
    pub alpha: f64,
}
288
289impl CutMixAugmenter {
290    /// Create a new CutMix augmenter.
291    ///
292    /// # Arguments
293    /// * `alpha` - Beta distribution parameter (typically 1.0)
294    ///
295    /// # Returns
296    /// New CutMix augmenter
297    pub fn new(alpha: f64) -> TrainResult<Self> {
298        if alpha <= 0.0 {
299            return Err(TrainError::InvalidParameter(
300                "alpha must be positive".to_string(),
301            ));
302        }
303        Ok(Self { alpha })
304    }
305
306    /// Apply CutMix augmentation to a batch of data.
307    ///
308    /// For 2D feature arrays, we treat the second dimension as a "spatial" dimension
309    /// and cut rectangular regions along it.
310    ///
311    /// # Arguments
312    /// * `data` - Input data batch [N, features]
313    /// * `labels` - Corresponding labels [N, classes]
314    /// * `rng` - Random number generator
315    ///
316    /// # Returns
317    /// Tuple of (augmented_data, augmented_labels)
318    pub fn augment_batch(
319        &self,
320        data: &ArrayView2<f64>,
321        labels: &ArrayView2<f64>,
322        rng: &mut StdRng,
323    ) -> TrainResult<(Array2<f64>, Array2<f64>)> {
324        if data.nrows() != labels.nrows() {
325            return Err(TrainError::InvalidParameter(
326                "data and labels must have same number of rows".to_string(),
327            ));
328        }
329
330        let n = data.nrows();
331        let features = data.ncols();
332        let mut augmented_data = data.to_owned();
333        let mut augmented_labels = labels.to_owned();
334
335        // Create random permutation for pairing samples
336        let mut indices: Vec<usize> = (0..n).collect();
337        for i in (1..n).rev() {
338            let j = rng.gen_range(0..=i);
339            indices.swap(i, j);
340        }
341
342        for i in 0..n {
343            let j = indices[i];
344
345            // Sample mixing ratio from Beta distribution
346            let lambda = self.sample_beta(rng);
347
348            // Generate random bounding box
349            // For 1D feature vectors, we cut along the feature dimension
350            let cut_ratio = (1.0 - lambda).sqrt();
351            let cut_size = (features as f64 * cut_ratio) as usize;
352            let cut_size = cut_size.max(1).min(features - 1);
353
354            // Random starting position
355            let start = if features > cut_size {
356                rng.gen_range(0..=(features - cut_size))
357            } else {
358                0
359            };
360
361            // Apply CutMix: replace region with data from paired sample
362            for k in start..(start + cut_size).min(features) {
363                augmented_data[[i, k]] = data[[j, k]];
364            }
365
366            // Mix labels proportionally to the area of the cut region
367            let actual_ratio = cut_size as f64 / features as f64;
368            for k in 0..labels.ncols() {
369                augmented_labels[[i, k]] =
370                    (1.0 - actual_ratio) * labels[[i, k]] + actual_ratio * labels[[j, k]];
371            }
372        }
373
374        Ok((augmented_data, augmented_labels))
375    }
376
377    /// Sample from Beta(alpha, alpha) distribution.
378    ///
379    /// Simplified implementation using uniform distribution when alpha is close to 1.
380    fn sample_beta(&self, rng: &mut StdRng) -> f64 {
381        if self.alpha < 0.5 {
382            // For small alpha, prefer values near 0 or 1
383            if rng.random::<f64>() < 0.5 {
384                rng.random::<f64>().powf(2.0)
385            } else {
386                1.0 - rng.random::<f64>().powf(2.0)
387            }
388        } else {
389            // For alpha >= 0.5, approximate with uniform
390            rng.random::<f64>()
391        }
392    }
393}
394
395impl Default for CutMixAugmenter {
396    fn default() -> Self {
397        Self { alpha: 1.0 }
398    }
399}
400
401impl DataAugmenter for CutMixAugmenter {
402    fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
403        // For single-sample augmentation, no operation
404        // In practice, CutMix should be used with augment_batch
405        Ok(data.to_owned())
406    }
407}
408
409/// Composite augmenter that applies multiple augmentations sequentially.
410#[derive(Clone, Default)]
411pub struct CompositeAugmenter {
412    augmenters: Vec<Box<dyn AugmenterClone>>,
413}
414
415impl std::fmt::Debug for CompositeAugmenter {
416    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417        f.debug_struct("CompositeAugmenter")
418            .field("num_augmenters", &self.augmenters.len())
419            .finish()
420    }
421}
422
423/// Helper trait for cloning boxed augmenters.
424trait AugmenterClone: DataAugmenter {
425    fn clone_box(&self) -> Box<dyn AugmenterClone>;
426}
427
428impl<T: DataAugmenter + Clone + 'static> AugmenterClone for T {
429    fn clone_box(&self) -> Box<dyn AugmenterClone> {
430        Box::new(self.clone())
431    }
432}
433
434impl Clone for Box<dyn AugmenterClone> {
435    fn clone(&self) -> Self {
436        self.clone_box()
437    }
438}
439
440impl DataAugmenter for Box<dyn AugmenterClone> {
441    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
442        (**self).augment(data, rng)
443    }
444}
445
446impl CompositeAugmenter {
447    /// Create a new composite augmenter.
448    pub fn new() -> Self {
449        Self {
450            augmenters: Vec::new(),
451        }
452    }
453
454    /// Add an augmenter to the pipeline.
455    pub fn add<A: DataAugmenter + Clone + 'static>(&mut self, augmenter: A) {
456        self.augmenters.push(Box::new(augmenter));
457    }
458
459    /// Get the number of augmenters.
460    pub fn len(&self) -> usize {
461        self.augmenters.len()
462    }
463
464    /// Check if the pipeline is empty.
465    pub fn is_empty(&self) -> bool {
466        self.augmenters.is_empty()
467    }
468}
469
470impl DataAugmenter for CompositeAugmenter {
471    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
472        let mut result = data.to_owned();
473
474        for augmenter in &self.augmenters {
475            result = augmenter.augment(&result.view(), rng)?;
476        }
477
478        Ok(result)
479    }
480}
481
/// Random Erasing augmentation.
///
/// Randomly erases a rectangular region of the input, filling it either with
/// a constant or with random values. This technique prevents overfitting and
/// improves generalization, especially for image data.
///
/// Reference: Zhong et al., "Random Erasing Data Augmentation" (AAAI 2020)
///
/// # Parameters
/// - `probability`: Probability of applying erasing (default: 0.5)
/// - `scale_min` / `scale_max`: Range of the erased proportion of the total
///   area (default: [0.02, 0.33])
/// - `ratio_min` / `ratio_max`: Range of the aspect ratio of the erased region
///   (default: [0.3, 3.3])
/// - `fill_value`: Exactly `1.0` selects random fill; any other value is
///   written into the region as is (default: 0.0). There is no "mean" fill
///   mode.
#[derive(Debug, Clone)]
pub struct RandomErasingAugmenter {
    /// Probability of applying erasing.
    pub probability: f64,
    /// Minimum proportion of erased area.
    pub scale_min: f64,
    /// Maximum proportion of erased area.
    pub scale_max: f64,
    /// Minimum aspect ratio.
    pub ratio_min: f64,
    /// Maximum aspect ratio.
    pub ratio_max: f64,
    /// Fill value: `1.0` means random values, anything else is used verbatim.
    pub fill_value: f64,
}
509
510impl RandomErasingAugmenter {
511    /// Create a new Random Erasing augmenter with custom parameters.
512    pub fn new(
513        probability: f64,
514        scale_min: f64,
515        scale_max: f64,
516        ratio_min: f64,
517        ratio_max: f64,
518        fill_value: f64,
519    ) -> TrainResult<Self> {
520        if !(0.0..=1.0).contains(&probability) {
521            return Err(TrainError::InvalidParameter(
522                "probability must be in [0, 1]".to_string(),
523            ));
524        }
525        if scale_min >= scale_max || scale_min < 0.0 || scale_max > 1.0 {
526            return Err(TrainError::InvalidParameter(
527                "scale range must be valid (0 <= min < max <= 1)".to_string(),
528            ));
529        }
530        if ratio_min <= 0.0 || ratio_min >= ratio_max {
531            return Err(TrainError::InvalidParameter(
532                "ratio range must be valid (0 < min < max)".to_string(),
533            ));
534        }
535
536        Ok(Self {
537            probability,
538            scale_min,
539            scale_max,
540            ratio_min,
541            ratio_max,
542            fill_value,
543        })
544    }
545
546    /// Create with default parameters (as in the paper).
547    pub fn with_defaults() -> Self {
548        Self {
549            probability: 0.5,
550            scale_min: 0.02,
551            scale_max: 0.33,
552            ratio_min: 0.3,
553            ratio_max: 3.3,
554            fill_value: 0.0,
555        }
556    }
557}
558
559impl Default for RandomErasingAugmenter {
560    fn default() -> Self {
561        Self::with_defaults()
562    }
563}
564
565impl DataAugmenter for RandomErasingAugmenter {
566    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
567        let mut augmented = data.to_owned();
568
569        // Apply with probability p
570        if rng.random::<f64>() > self.probability {
571            return Ok(augmented);
572        }
573
574        let (height, width) = (data.nrows(), data.ncols());
575        let area = (height * width) as f64;
576
577        // Try multiple times to find a valid erasing region
578        for _ in 0..10 {
579            // Random scale (proportion of total area)
580            let scale = self.scale_min + rng.random::<f64>() * (self.scale_max - self.scale_min);
581            let erase_area = area * scale;
582
583            // Random aspect ratio
584            let aspect_ratio =
585                self.ratio_min + rng.random::<f64>() * (self.ratio_max - self.ratio_min);
586
587            // Compute erase region dimensions
588            let h = (erase_area * aspect_ratio).sqrt().min(height as f64);
589            let w = (erase_area / aspect_ratio).sqrt().min(width as f64);
590
591            if h >= 1.0 && w >= 1.0 {
592                let erase_h = h as usize;
593                let erase_w = w as usize;
594
595                // Random position
596                let i = if erase_h < height {
597                    (rng.random::<f64>() * (height - erase_h) as f64) as usize
598                } else {
599                    0
600                };
601                let j = if erase_w < width {
602                    (rng.random::<f64>() * (width - erase_w) as f64) as usize
603                } else {
604                    0
605                };
606
607                // Fill with specified value
608                if self.fill_value == 1.0 {
609                    // Random values
610                    for row in i..i + erase_h.min(height - i) {
611                        for col in j..j + erase_w.min(width - j) {
612                            augmented[[row, col]] = rng.random();
613                        }
614                    }
615                } else {
616                    // Fixed value (0.0 or specified)
617                    for row in i..i + erase_h.min(height - i) {
618                        for col in j..j + erase_w.min(width - j) {
619                            augmented[[row, col]] = self.fill_value;
620                        }
621                    }
622                }
623
624                break;
625            }
626        }
627
628        Ok(augmented)
629    }
630}
631
/// CutOut augmenter.
///
/// Erases one randomly placed square region of the input. A simpler cousin
/// of Random Erasing that uses a fixed region size.
///
/// Reference: DeVries & Taylor, "Improved Regularization of Convolutional Neural Networks with Cutout" (2017)
///
/// # Parameters
/// - `cutout_size`: Nominal side length of the square region to erase
/// - `probability`: Probability of applying cutout (`with_size` uses 1.0)
/// - `fill_value`: Value written into the erased region (`with_size` uses 0.0)
#[derive(Clone, Debug)]
pub struct CutOutAugmenter {
    /// Nominal side length of the square cutout region.
    pub cutout_size: usize,
    /// Probability of applying cutout.
    pub probability: f64,
    /// Value written into the erased region.
    pub fill_value: f64,
}
652
653impl CutOutAugmenter {
654    /// Create a new CutOut augmenter.
655    pub fn new(cutout_size: usize, probability: f64, fill_value: f64) -> TrainResult<Self> {
656        if cutout_size == 0 {
657            return Err(TrainError::InvalidParameter(
658                "cutout_size must be > 0".to_string(),
659            ));
660        }
661        if !(0.0..=1.0).contains(&probability) {
662            return Err(TrainError::InvalidParameter(
663                "probability must be in [0, 1]".to_string(),
664            ));
665        }
666
667        Ok(Self {
668            cutout_size,
669            probability,
670            fill_value,
671        })
672    }
673
674    /// Create with default parameters.
675    pub fn with_size(size: usize) -> TrainResult<Self> {
676        Self::new(size, 1.0, 0.0)
677    }
678}
679
680impl DataAugmenter for CutOutAugmenter {
681    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
682        let mut augmented = data.to_owned();
683
684        // Apply with probability p
685        if rng.random::<f64>() > self.probability {
686            return Ok(augmented);
687        }
688
689        let (height, width) = (data.nrows(), data.ncols());
690
691        // Random center position
692        let center_y = (rng.random::<f64>() * height as f64) as usize;
693        let center_x = (rng.random::<f64>() * width as f64) as usize;
694
695        // Compute cutout region bounds (allow partial cutout at boundaries)
696        let half_size = self.cutout_size / 2;
697
698        let y_start = center_y.saturating_sub(half_size);
699        let y_end = (center_y + half_size).min(height);
700
701        let x_start = center_x.saturating_sub(half_size);
702        let x_end = (center_x + half_size).min(width);
703
704        // Erase the region
705        for i in y_start..y_end {
706            for j in x_start..x_end {
707                augmented[[i, j]] = self.fill_value;
708            }
709        }
710
711        Ok(augmented)
712    }
713}
714
715#[cfg(test)]
716mod tests {
717    use super::*;
718    use scirs2_core::ndarray::array;
719    use scirs2_core::random::SeedableRng;
720
721    fn create_test_rng() -> StdRng {
722        StdRng::seed_from_u64(42)
723    }
724
725    #[test]
726    fn test_no_augmentation() {
727        let augmenter = NoAugmentation;
728        let data = array![[1.0, 2.0], [3.0, 4.0]];
729        let mut rng = create_test_rng();
730
731        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");
732        assert_eq!(augmented, data);
733    }
734
735    #[test]
736    fn test_noise_augmenter() {
737        let augmenter = NoiseAugmenter::new(0.1).expect("unwrap");
738        let data = array![[1.0, 2.0], [3.0, 4.0]];
739        let mut rng = create_test_rng();
740
741        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");
742
743        // Shape should be preserved
744        assert_eq!(augmented.shape(), data.shape());
745
746        // Values should be different (with high probability)
747        assert_ne!(augmented[[0, 0]], data[[0, 0]]);
748
749        // But should be close to original values
750        for i in 0..data.nrows() {
751            for j in 0..data.ncols() {
752                let diff = (augmented[[i, j]] - data[[i, j]]).abs();
753                assert!(diff < 1.0); // Within reasonable noise range
754            }
755        }
756    }
757
758    #[test]
759    fn test_noise_augmenter_invalid() {
760        let result = NoiseAugmenter::new(-0.1);
761        assert!(result.is_err());
762    }
763
764    #[test]
765    fn test_scale_augmenter() {
766        let augmenter = ScaleAugmenter::new(0.2).expect("unwrap");
767        let data = array![[1.0, 2.0], [3.0, 4.0]];
768        let mut rng = create_test_rng();
769
770        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");
771
772        // Shape should be preserved
773        assert_eq!(augmented.shape(), data.shape());
774
775        // All values should be scaled by the same factor
776        let scale = augmented[[0, 0]] / data[[0, 0]];
777        for i in 0..data.nrows() {
778            for j in 0..data.ncols() {
779                let computed_scale = augmented[[i, j]] / data[[i, j]];
780                assert!((computed_scale - scale).abs() < 1e-10);
781            }
782        }
783
784        // Scale should be within range [0.8, 1.2]
785        assert!((0.8..=1.2).contains(&scale));
786    }
787
788    #[test]
789    fn test_scale_augmenter_invalid() {
790        assert!(ScaleAugmenter::new(-0.1).is_err());
791        assert!(ScaleAugmenter::new(1.5).is_err());
792    }
793
794    #[test]
795    fn test_rotation_augmenter() {
796        let augmenter = RotationAugmenter::default();
797        let data = array![[1.0, 2.0], [3.0, 4.0]];
798        let mut rng = create_test_rng();
799
800        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");
801
802        // Shape should be preserved
803        assert_eq!(augmented.shape(), data.shape());
804    }
805
806    #[test]
807    fn test_mixup_augmenter_batch() {
808        let augmenter = MixupAugmenter::new(1.0).expect("unwrap");
809        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
810        let labels = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];
811        let mut rng = create_test_rng();
812
813        let (aug_data, aug_labels) = augmenter
814            .augment_batch(&data.view(), &labels.view(), &mut rng)
815            .expect("unwrap");
816
817        // Shapes should be preserved
818        assert_eq!(aug_data.shape(), data.shape());
819        assert_eq!(aug_labels.shape(), labels.shape());
820
821        // Values should be interpolations (between min and max of original)
822        let data_min = data.iter().cloned().fold(f64::INFINITY, f64::min);
823        let data_max = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
824
825        for &val in aug_data.iter() {
826            assert!(val >= data_min && val <= data_max);
827        }
828    }
829
830    #[test]
831    fn test_mixup_invalid_alpha() {
832        let result = MixupAugmenter::new(0.0);
833        assert!(result.is_err());
834
835        let result = MixupAugmenter::new(-1.0);
836        assert!(result.is_err());
837    }
838
839    #[test]
840    fn test_mixup_mismatched_shapes() {
841        let augmenter = MixupAugmenter::default();
842        let data = array![[1.0, 2.0], [3.0, 4.0]];
843        let labels = array![[1.0, 0.0]]; // Wrong shape
844        let mut rng = create_test_rng();
845
846        let result = augmenter.augment_batch(&data.view(), &labels.view(), &mut rng);
847        assert!(result.is_err());
848    }
849
850    #[test]
851    fn test_cutmix_augmenter_batch() {
852        let augmenter = CutMixAugmenter::new(1.0).expect("unwrap");
853        let data = array![
854            [1.0, 2.0, 3.0, 4.0],
855            [5.0, 6.0, 7.0, 8.0],
856            [9.0, 10.0, 11.0, 12.0]
857        ];
858        let labels = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];
859        let mut rng = create_test_rng();
860
861        let (aug_data, aug_labels) = augmenter
862            .augment_batch(&data.view(), &labels.view(), &mut rng)
863            .expect("unwrap");
864
865        // Shapes should be preserved
866        assert_eq!(aug_data.shape(), data.shape());
867        assert_eq!(aug_labels.shape(), labels.shape());
868
869        // Each row should contain a mix of original values
870        // (some regions from original, some from paired sample)
871        for i in 0..aug_data.nrows() {
872            let mut found_original = false;
873            let mut found_different = false;
874
875            for j in 0..aug_data.ncols() {
876                // Check if value matches original row
877                if (aug_data[[i, j]] - data[[i, j]]).abs() < 1e-10 {
878                    found_original = true;
879                } else {
880                    found_different = true;
881                }
882            }
883
884            // Should have both original and swapped regions (unless randomly paired with self)
885            assert!(found_original || found_different);
886        }
887
888        // Label mixing: each element should be in [0, 1] and sum across classes
889        for i in 0..aug_labels.nrows() {
890            let sum: f64 = aug_labels.row(i).iter().sum();
891            assert!((sum - 1.0).abs() < 1e-10, "Labels should sum to 1.0");
892        }
893    }
894
895    #[test]
896    fn test_cutmix_invalid_alpha() {
897        let result = CutMixAugmenter::new(0.0);
898        assert!(result.is_err());
899
900        let result = CutMixAugmenter::new(-1.0);
901        assert!(result.is_err());
902    }
903
904    #[test]
905    fn test_cutmix_mismatched_shapes() {
906        let augmenter = CutMixAugmenter::default();
907        let data = array![[1.0, 2.0], [3.0, 4.0]];
908        let labels = array![[1.0, 0.0]]; // Wrong shape
909        let mut rng = create_test_rng();
910
911        let result = augmenter.augment_batch(&data.view(), &labels.view(), &mut rng);
912        assert!(result.is_err());
913    }
914
915    #[test]
916    fn test_cutmix_label_proportions() {
917        let augmenter = CutMixAugmenter::new(1.0).expect("unwrap");
918        // Use distinctive patterns
919        let data = array![[10.0, 10.0, 10.0, 10.0], [20.0, 20.0, 20.0, 20.0]];
920        let labels = array![[1.0, 0.0], [0.0, 1.0]];
921        let mut rng = create_test_rng();
922
923        let (aug_data, aug_labels) = augmenter
924            .augment_batch(&data.view(), &labels.view(), &mut rng)
925            .expect("unwrap");
926
927        // Verify that labels are mixed proportionally
928        for i in 0..aug_labels.nrows() {
929            // Each sample should have labels that sum to 1
930            let sum: f64 = aug_labels.row(i).iter().sum();
931            assert!((sum - 1.0).abs() < 1e-10);
932
933            // Labels should be between original values
934            for j in 0..aug_labels.ncols() {
935                assert!(aug_labels[[i, j]] >= 0.0);
936                assert!(aug_labels[[i, j]] <= 1.0);
937            }
938        }
939
940        // Verify data has been cut and mixed
941        assert_eq!(aug_data.shape(), data.shape());
942    }
943
944    #[test]
945    fn test_composite_augmenter() {
946        let mut composite = CompositeAugmenter::new();
947        composite.add(NoiseAugmenter::new(0.01).expect("unwrap"));
948        composite.add(ScaleAugmenter::new(0.1).expect("unwrap"));
949
950        let data = array![[1.0, 2.0], [3.0, 4.0]];
951        let mut rng = create_test_rng();
952
953        let augmented = composite.augment(&data.view(), &mut rng).expect("unwrap");
954
955        // Shape should be preserved
956        assert_eq!(augmented.shape(), data.shape());
957
958        // Values should be different due to augmentation
959        assert_ne!(augmented[[0, 0]], data[[0, 0]]);
960    }
961
962    #[test]
963    fn test_composite_empty() {
964        let composite = CompositeAugmenter::new();
965        assert!(composite.is_empty());
966        assert_eq!(composite.len(), 0);
967
968        let data = array![[1.0, 2.0]];
969        let mut rng = create_test_rng();
970
971        let augmented = composite.augment(&data.view(), &mut rng).expect("unwrap");
972        assert_eq!(augmented, data);
973    }
974
975    #[test]
976    fn test_composite_multiple() {
977        let mut composite = CompositeAugmenter::new();
978        composite.add(NoAugmentation);
979        composite.add(ScaleAugmenter::default());
980        composite.add(NoiseAugmenter::default());
981
982        assert_eq!(composite.len(), 3);
983        assert!(!composite.is_empty());
984    }
985
986    #[test]
987    fn test_random_erasing_creation() {
988        let augmenter =
989            RandomErasingAugmenter::new(0.5, 0.02, 0.33, 0.3, 3.3, 0.0).expect("unwrap");
990        assert_eq!(augmenter.probability, 0.5);
991        assert_eq!(augmenter.scale_min, 0.02);
992        assert_eq!(augmenter.scale_max, 0.33);
993    }
994
995    #[test]
996    fn test_random_erasing_invalid_params() {
997        // Invalid probability
998        assert!(RandomErasingAugmenter::new(1.5, 0.02, 0.33, 0.3, 3.3, 0.0).is_err());
999
1000        // Invalid scale range
1001        assert!(RandomErasingAugmenter::new(0.5, 0.33, 0.02, 0.3, 3.3, 0.0).is_err());
1002
1003        // Invalid ratio range
1004        assert!(RandomErasingAugmenter::new(0.5, 0.02, 0.33, 3.3, 0.3, 0.0).is_err());
1005    }
1006
1007    #[test]
1008    fn test_random_erasing_augment() {
1009        let augmenter = RandomErasingAugmenter::with_defaults();
1010        let data = scirs2_core::ndarray::Array2::ones((10, 10));
1011        let mut rng = create_test_rng();
1012
1013        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");
1014
1015        // Shape should be preserved
1016        assert_eq!(augmented.shape(), data.shape());
1017
1018        // Some values may be erased (but not guaranteed due to probability)
1019    }
1020
1021    #[test]
1022    fn test_random_erasing_probability_zero() {
1023        let augmenter =
1024            RandomErasingAugmenter::new(0.0, 0.02, 0.33, 0.3, 3.3, 0.0).expect("unwrap");
1025        let data = scirs2_core::ndarray::Array2::ones((10, 10));
1026        let mut rng = create_test_rng();
1027
1028        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");
1029
1030        // With probability 0, data should be unchanged
1031        assert_eq!(augmented, data);
1032    }
1033
1034    #[test]
1035    fn test_cutout_creation() {
1036        let augmenter = CutOutAugmenter::new(5, 1.0, 0.0).expect("unwrap");
1037        assert_eq!(augmenter.cutout_size, 5);
1038        assert_eq!(augmenter.probability, 1.0);
1039        assert_eq!(augmenter.fill_value, 0.0);
1040    }
1041
1042    #[test]
1043    fn test_cutout_invalid_params() {
1044        // Zero size
1045        assert!(CutOutAugmenter::new(0, 1.0, 0.0).is_err());
1046
1047        // Invalid probability
1048        assert!(CutOutAugmenter::new(5, 1.5, 0.0).is_err());
1049    }
1050
1051    #[test]
1052    fn test_cutout_augment() {
1053        let augmenter = CutOutAugmenter::with_size(3).expect("unwrap");
1054        let data = scirs2_core::ndarray::Array2::ones((10, 10));
1055        let mut rng = create_test_rng();
1056
1057        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");
1058
1059        // Shape should be preserved
1060        assert_eq!(augmented.shape(), data.shape());
1061
1062        // Some values should be zero (erased)
1063        let zeros_count = augmented.iter().filter(|&&x| x == 0.0).count();
1064        assert!(zeros_count > 0, "Expected some values to be erased");
1065        assert!(zeros_count < 100, "Not all values should be erased");
1066    }
1067
1068    #[test]
1069    fn test_cutout_probability_zero() {
1070        let augmenter = CutOutAugmenter::new(5, 0.0, 0.0).expect("unwrap");
1071        let data = scirs2_core::ndarray::Array2::ones((10, 10));
1072        let mut rng = create_test_rng();
1073
1074        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");
1075
1076        // With probability 0, data should be unchanged
1077        assert_eq!(augmented, data);
1078    }
1079
1080    #[test]
1081    fn test_cutout_fill_value() {
1082        let augmenter = CutOutAugmenter::new(3, 1.0, 0.5).expect("unwrap");
1083        let data = scirs2_core::ndarray::Array2::ones((10, 10));
1084        let mut rng = create_test_rng();
1085
1086        let augmented = augmenter.augment(&data.view(), &mut rng).expect("unwrap");
1087
1088        // Some values should be 0.5 (fill value)
1089        let filled_count = augmented.iter().filter(|&&x| x == 0.5).count();
1090        assert!(
1091            filled_count > 0,
1092            "Expected some values to be filled with 0.5"
1093        );
1094    }
1095}