Skip to main content

tensorlogic_train/
augmentation.rs

1//! Data augmentation techniques for training.
2//!
3//! This module provides various data augmentation strategies to improve model generalization:
4//! - Noise augmentation (Gaussian)
5//! - Scale augmentation (random scaling)
6//! - Rotation augmentation (placeholder for future)
7//! - Mixup augmentation (interpolation between samples)
8//! - CutMix augmentation (cutting and mixing patches)
9//! - Random Erasing (randomly erase rectangular regions)
10//! - CutOut (fixed-size random erasing)
11//!
12
13use crate::{TrainError, TrainResult};
14use scirs2_core::ndarray::{Array, Array2, ArrayView2};
15use scirs2_core::random::{Rng, StdRng};
16
17/// Trait for data augmentation strategies.
18pub trait DataAugmenter {
19    /// Augment the given data.
20    ///
21    /// # Arguments
22    /// * `data` - Input data to augment
23    /// * `rng` - Random number generator for stochastic augmentation
24    ///
25    /// # Returns
26    /// Augmented data with the same shape as input
27    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>>;
28}
29
/// Identity augmenter: hands back an unchanged copy of its input.
///
/// Handy as a baseline in tests or as a stand-in wherever an augmenter is required.
#[derive(Clone, Debug, Default)]
pub struct NoAugmentation;
35
36impl DataAugmenter for NoAugmentation {
37    fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
38        Ok(data.to_owned())
39    }
40}
41
/// Additive Gaussian-noise augmenter.
///
/// Every element is perturbed independently: `x' = x + ε` with `ε ~ N(0, std_dev²)`.
#[derive(Clone, Debug)]
pub struct NoiseAugmenter {
    /// Standard deviation (σ) of the zero-mean noise that is added to each element.
    pub std_dev: f64,
}
50
51impl NoiseAugmenter {
52    /// Create a new noise augmenter.
53    ///
54    /// # Arguments
55    /// * `std_dev` - Standard deviation of the noise
56    pub fn new(std_dev: f64) -> TrainResult<Self> {
57        if std_dev < 0.0 {
58            return Err(TrainError::InvalidParameter(
59                "std_dev must be non-negative".to_string(),
60            ));
61        }
62        Ok(Self { std_dev })
63    }
64}
65
66impl Default for NoiseAugmenter {
67    fn default() -> Self {
68        Self { std_dev: 0.01 }
69    }
70}
71
72impl DataAugmenter for NoiseAugmenter {
73    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
74        let mut augmented = data.to_owned();
75
76        // Add Gaussian noise using Box-Muller transform
77        for value in augmented.iter_mut() {
78            let u1: f64 = rng.random();
79            let u2: f64 = rng.random();
80
81            // Box-Muller transform
82            let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
83            let noise = z0 * self.std_dev;
84
85            *value += noise;
86        }
87
88        Ok(augmented)
89    }
90}
91
/// Random global rescaling.
///
/// A single factor is drawn from `[1 - scale_range, 1 + scale_range]` on each call and
/// applied to every element.
#[derive(Clone, Debug)]
pub struct ScaleAugmenter {
    /// Half-width of the scaling interval around 1.0 (0.1 gives factors in [0.9, 1.1]).
    pub scale_range: f64,
}
100
101impl ScaleAugmenter {
102    /// Create a new scale augmenter.
103    ///
104    /// # Arguments
105    /// * `scale_range` - Range of scaling factor
106    pub fn new(scale_range: f64) -> TrainResult<Self> {
107        if !(0.0..=1.0).contains(&scale_range) {
108            return Err(TrainError::InvalidParameter(
109                "scale_range must be in [0, 1]".to_string(),
110            ));
111        }
112        Ok(Self { scale_range })
113    }
114}
115
116impl Default for ScaleAugmenter {
117    fn default() -> Self {
118        Self { scale_range: 0.1 }
119    }
120}
121
122impl DataAugmenter for ScaleAugmenter {
123    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
124        // Generate random scale factor
125        let scale = 1.0 + (rng.random::<f64>() * 2.0 - 1.0) * self.scale_range;
126
127        let augmented = data.mapv(|x| x * scale);
128        Ok(augmented)
129    }
130}
131
/// Rotation augmentation (placeholder).
///
/// Meant to apply random rotations to 2-D image data eventually. The current
/// `DataAugmenter` implementation only performs a rotation-inspired element-wise
/// rescaling suitable for tabular data.
#[derive(Clone, Debug)]
pub struct RotationAugmenter {
    /// Largest angle, in radians, that may be sampled (symmetric around zero).
    pub max_angle: f64,
}
141
142impl RotationAugmenter {
143    /// Create a new rotation augmenter.
144    ///
145    /// # Arguments
146    /// * `max_angle` - Maximum rotation angle in radians
147    pub fn new(max_angle: f64) -> Self {
148        Self { max_angle }
149    }
150}
151
152impl Default for RotationAugmenter {
153    fn default() -> Self {
154        Self {
155            max_angle: std::f64::consts::PI / 18.0, // 10 degrees
156        }
157    }
158}
159
160impl DataAugmenter for RotationAugmenter {
161    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
162        // For now, this is a placeholder that returns a simple transformation
163        // Future: implement proper 2D rotation for image data
164        let angle = (rng.random::<f64>() * 2.0 - 1.0) * self.max_angle;
165
166        // Apply a simple rotation-inspired transformation
167        let cos_a = angle.cos();
168        let sin_a = angle.sin();
169
170        let augmented = data.mapv(|x| x * cos_a + x * sin_a * 0.1);
171        Ok(augmented)
172    }
173}
174
/// Mixup augmentation.
///
/// Builds synthetic samples as convex combinations of two examples and their labels:
/// `x' = λ·x₁ + (1 − λ)·x₂`, `y' = λ·y₁ + (1 − λ)·y₂`.
///
/// Reference: Zhang et al., "mixup: Beyond Empirical Risk Minimization", ICLR 2018
#[derive(Clone, Debug)]
pub struct MixupAugmenter {
    /// Alpha parameter controlling the distribution of the mixing coefficient λ.
    pub alpha: f64,
}
186
187impl MixupAugmenter {
188    /// Create a new mixup augmenter.
189    ///
190    /// # Arguments
191    /// * `alpha` - Alpha parameter for Beta distribution
192    pub fn new(alpha: f64) -> TrainResult<Self> {
193        if alpha <= 0.0 {
194            return Err(TrainError::InvalidParameter(
195                "alpha must be positive".to_string(),
196            ));
197        }
198        Ok(Self { alpha })
199    }
200
201    /// Apply mixup to a batch of data.
202    ///
203    /// # Arguments
204    /// * `data` - Input data batch [N, features]
205    /// * `labels` - Corresponding labels [N, classes]
206    /// * `rng` - Random number generator
207    ///
208    /// # Returns
209    /// Tuple of (augmented_data, augmented_labels)
210    pub fn augment_batch(
211        &self,
212        data: &ArrayView2<f64>,
213        labels: &ArrayView2<f64>,
214        rng: &mut StdRng,
215    ) -> TrainResult<(Array2<f64>, Array2<f64>)> {
216        if data.nrows() != labels.nrows() {
217            return Err(TrainError::InvalidParameter(
218                "data and labels must have same number of rows".to_string(),
219            ));
220        }
221
222        let n = data.nrows();
223        let mut augmented_data = Array::zeros(data.raw_dim());
224        let mut augmented_labels = Array::zeros(labels.raw_dim());
225
226        // Create random permutation
227        let mut indices: Vec<usize> = (0..n).collect();
228        for i in (1..n).rev() {
229            let j = rng.gen_range(0..=i);
230            indices.swap(i, j);
231        }
232
233        for i in 0..n {
234            let j = indices[i];
235
236            // Sample mixing coefficient from Beta distribution
237            // Simplified: use uniform distribution as approximation
238            let lambda = self.sample_beta(rng);
239
240            // Mix data: x' = λ * x_i + (1 - λ) * x_j
241            for k in 0..data.ncols() {
242                augmented_data[[i, k]] = lambda * data[[i, k]] + (1.0 - lambda) * data[[j, k]];
243            }
244
245            // Mix labels: y' = λ * y_i + (1 - λ) * y_j
246            for k in 0..labels.ncols() {
247                augmented_labels[[i, k]] =
248                    lambda * labels[[i, k]] + (1.0 - lambda) * labels[[j, k]];
249            }
250        }
251
252        Ok((augmented_data, augmented_labels))
253    }
254
255    /// Sample from Beta(alpha, alpha) distribution.
256    ///
257    /// Simplified implementation using uniform distribution when alpha is close to 1.
258    fn sample_beta(&self, rng: &mut StdRng) -> f64 {
259        if self.alpha < 0.5 {
260            // For small alpha, prefer values near 0 or 1
261            if rng.random::<f64>() < 0.5 {
262                rng.random::<f64>().powf(2.0)
263            } else {
264                1.0 - rng.random::<f64>().powf(2.0)
265            }
266        } else {
267            // For alpha >= 0.5, approximate with uniform
268            rng.random::<f64>()
269        }
270    }
271}
272
273impl Default for MixupAugmenter {
274    fn default() -> Self {
275        Self { alpha: 1.0 }
276    }
277}
278
279impl DataAugmenter for MixupAugmenter {
280    fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
281        // For single-sample augmentation, mix with itself (no-op)
282        // In practice, mixup should be used with augment_batch
283        Ok(data.to_owned())
284    }
285}
286
/// CutMix augmentation (ICCV 2019).
///
/// Mixup blends whole samples. CutMix instead replaces a region of one sample with the
/// corresponding region of another, and mixes the labels in proportion to the size of
/// the replaced region.
///
/// Reference: Yun et al. "CutMix: Regularization Strategy to Train Strong Classifiers
/// with Localizable Features" (ICCV 2019)
#[derive(Clone, Debug)]
pub struct CutMixAugmenter {
    /// Alpha parameter controlling the distribution of the mixing ratio.
    pub alpha: f64,
}
300
301impl CutMixAugmenter {
302    /// Create a new CutMix augmenter.
303    ///
304    /// # Arguments
305    /// * `alpha` - Beta distribution parameter (typically 1.0)
306    ///
307    /// # Returns
308    /// New CutMix augmenter
309    pub fn new(alpha: f64) -> TrainResult<Self> {
310        if alpha <= 0.0 {
311            return Err(TrainError::InvalidParameter(
312                "alpha must be positive".to_string(),
313            ));
314        }
315        Ok(Self { alpha })
316    }
317
318    /// Apply CutMix augmentation to a batch of data.
319    ///
320    /// For 2D feature arrays, we treat the second dimension as a "spatial" dimension
321    /// and cut rectangular regions along it.
322    ///
323    /// # Arguments
324    /// * `data` - Input data batch [N, features]
325    /// * `labels` - Corresponding labels [N, classes]
326    /// * `rng` - Random number generator
327    ///
328    /// # Returns
329    /// Tuple of (augmented_data, augmented_labels)
330    pub fn augment_batch(
331        &self,
332        data: &ArrayView2<f64>,
333        labels: &ArrayView2<f64>,
334        rng: &mut StdRng,
335    ) -> TrainResult<(Array2<f64>, Array2<f64>)> {
336        if data.nrows() != labels.nrows() {
337            return Err(TrainError::InvalidParameter(
338                "data and labels must have same number of rows".to_string(),
339            ));
340        }
341
342        let n = data.nrows();
343        let features = data.ncols();
344        let mut augmented_data = data.to_owned();
345        let mut augmented_labels = labels.to_owned();
346
347        // Create random permutation for pairing samples
348        let mut indices: Vec<usize> = (0..n).collect();
349        for i in (1..n).rev() {
350            let j = rng.gen_range(0..=i);
351            indices.swap(i, j);
352        }
353
354        for i in 0..n {
355            let j = indices[i];
356
357            // Sample mixing ratio from Beta distribution
358            let lambda = self.sample_beta(rng);
359
360            // Generate random bounding box
361            // For 1D feature vectors, we cut along the feature dimension
362            let cut_ratio = (1.0 - lambda).sqrt();
363            let cut_size = (features as f64 * cut_ratio) as usize;
364            let cut_size = cut_size.max(1).min(features - 1);
365
366            // Random starting position
367            let start = if features > cut_size {
368                rng.gen_range(0..=(features - cut_size))
369            } else {
370                0
371            };
372
373            // Apply CutMix: replace region with data from paired sample
374            for k in start..(start + cut_size).min(features) {
375                augmented_data[[i, k]] = data[[j, k]];
376            }
377
378            // Mix labels proportionally to the area of the cut region
379            let actual_ratio = cut_size as f64 / features as f64;
380            for k in 0..labels.ncols() {
381                augmented_labels[[i, k]] =
382                    (1.0 - actual_ratio) * labels[[i, k]] + actual_ratio * labels[[j, k]];
383            }
384        }
385
386        Ok((augmented_data, augmented_labels))
387    }
388
389    /// Sample from Beta(alpha, alpha) distribution.
390    ///
391    /// Simplified implementation using uniform distribution when alpha is close to 1.
392    fn sample_beta(&self, rng: &mut StdRng) -> f64 {
393        if self.alpha < 0.5 {
394            // For small alpha, prefer values near 0 or 1
395            if rng.random::<f64>() < 0.5 {
396                rng.random::<f64>().powf(2.0)
397            } else {
398                1.0 - rng.random::<f64>().powf(2.0)
399            }
400        } else {
401            // For alpha >= 0.5, approximate with uniform
402            rng.random::<f64>()
403        }
404    }
405}
406
407impl Default for CutMixAugmenter {
408    fn default() -> Self {
409        Self { alpha: 1.0 }
410    }
411}
412
413impl DataAugmenter for CutMixAugmenter {
414    fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
415        // For single-sample augmentation, no operation
416        // In practice, CutMix should be used with augment_batch
417        Ok(data.to_owned())
418    }
419}
420
421/// Composite augmenter that applies multiple augmentations sequentially.
422#[derive(Clone, Default)]
423pub struct CompositeAugmenter {
424    augmenters: Vec<Box<dyn AugmenterClone>>,
425}
426
427impl std::fmt::Debug for CompositeAugmenter {
428    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
429        f.debug_struct("CompositeAugmenter")
430            .field("num_augmenters", &self.augmenters.len())
431            .finish()
432    }
433}
434
435/// Helper trait for cloning boxed augmenters.
436trait AugmenterClone: DataAugmenter {
437    fn clone_box(&self) -> Box<dyn AugmenterClone>;
438}
439
440impl<T: DataAugmenter + Clone + 'static> AugmenterClone for T {
441    fn clone_box(&self) -> Box<dyn AugmenterClone> {
442        Box::new(self.clone())
443    }
444}
445
446impl Clone for Box<dyn AugmenterClone> {
447    fn clone(&self) -> Self {
448        self.clone_box()
449    }
450}
451
452impl DataAugmenter for Box<dyn AugmenterClone> {
453    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
454        (**self).augment(data, rng)
455    }
456}
457
458impl CompositeAugmenter {
459    /// Create a new composite augmenter.
460    pub fn new() -> Self {
461        Self {
462            augmenters: Vec::new(),
463        }
464    }
465
466    /// Add an augmenter to the pipeline.
467    pub fn add<A: DataAugmenter + Clone + 'static>(&mut self, augmenter: A) {
468        self.augmenters.push(Box::new(augmenter));
469    }
470
471    /// Get the number of augmenters.
472    pub fn len(&self) -> usize {
473        self.augmenters.len()
474    }
475
476    /// Check if the pipeline is empty.
477    pub fn is_empty(&self) -> bool {
478        self.augmenters.is_empty()
479    }
480}
481
482impl DataAugmenter for CompositeAugmenter {
483    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
484        let mut result = data.to_owned();
485
486        for augmenter in &self.augmenters {
487            result = augmenter.augment(&result.view(), rng)?;
488        }
489
490        Ok(result)
491    }
492}
493
/// Random Erasing augmentation.
///
/// With a configurable probability, overwrites one randomly placed rectangular region
/// of the input. This discourages reliance on any single region and is most often used
/// for image data.
///
/// Reference: Zhong et al., "Random Erasing Data Augmentation" (AAAI 2020)
///
/// # Fields
/// - `probability`: chance that a call erases anything at all (default 0.5)
/// - `scale_min` / `scale_max`: bounds on the erased fraction of the total area
///   (defaults 0.02 / 0.33)
/// - `ratio_min` / `ratio_max`: bounds on the aspect ratio of the erased region
///   (defaults 0.3 / 3.3)
/// - `fill_value`: exactly `1.0` fills the region with uniform random values; any other
///   value is written into the region as-is (the default `0.0` zeroes it)
#[derive(Clone, Debug)]
pub struct RandomErasingAugmenter {
    /// Chance, in `[0, 1]`, that erasing is applied on a given call.
    pub probability: f64,
    /// Smallest fraction of the total area that may be erased.
    pub scale_min: f64,
    /// Largest fraction of the total area that may be erased.
    pub scale_max: f64,
    /// Smallest aspect ratio of the erased region.
    pub ratio_min: f64,
    /// Largest aspect ratio of the erased region.
    pub ratio_max: f64,
    /// Fill value: `1.0` selects random filling, anything else is used literally.
    pub fill_value: f64,
}
521
522impl RandomErasingAugmenter {
523    /// Create a new Random Erasing augmenter with custom parameters.
524    pub fn new(
525        probability: f64,
526        scale_min: f64,
527        scale_max: f64,
528        ratio_min: f64,
529        ratio_max: f64,
530        fill_value: f64,
531    ) -> TrainResult<Self> {
532        if !(0.0..=1.0).contains(&probability) {
533            return Err(TrainError::InvalidParameter(
534                "probability must be in [0, 1]".to_string(),
535            ));
536        }
537        if scale_min >= scale_max || scale_min < 0.0 || scale_max > 1.0 {
538            return Err(TrainError::InvalidParameter(
539                "scale range must be valid (0 <= min < max <= 1)".to_string(),
540            ));
541        }
542        if ratio_min <= 0.0 || ratio_min >= ratio_max {
543            return Err(TrainError::InvalidParameter(
544                "ratio range must be valid (0 < min < max)".to_string(),
545            ));
546        }
547
548        Ok(Self {
549            probability,
550            scale_min,
551            scale_max,
552            ratio_min,
553            ratio_max,
554            fill_value,
555        })
556    }
557
558    /// Create with default parameters (as in the paper).
559    pub fn with_defaults() -> Self {
560        Self {
561            probability: 0.5,
562            scale_min: 0.02,
563            scale_max: 0.33,
564            ratio_min: 0.3,
565            ratio_max: 3.3,
566            fill_value: 0.0,
567        }
568    }
569}
570
571impl Default for RandomErasingAugmenter {
572    fn default() -> Self {
573        Self::with_defaults()
574    }
575}
576
577impl DataAugmenter for RandomErasingAugmenter {
578    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
579        let mut augmented = data.to_owned();
580
581        // Apply with probability p
582        if rng.random::<f64>() > self.probability {
583            return Ok(augmented);
584        }
585
586        let (height, width) = (data.nrows(), data.ncols());
587        let area = (height * width) as f64;
588
589        // Try multiple times to find a valid erasing region
590        for _ in 0..10 {
591            // Random scale (proportion of total area)
592            let scale = self.scale_min + rng.random::<f64>() * (self.scale_max - self.scale_min);
593            let erase_area = area * scale;
594
595            // Random aspect ratio
596            let aspect_ratio =
597                self.ratio_min + rng.random::<f64>() * (self.ratio_max - self.ratio_min);
598
599            // Compute erase region dimensions
600            let h = (erase_area * aspect_ratio).sqrt().min(height as f64);
601            let w = (erase_area / aspect_ratio).sqrt().min(width as f64);
602
603            if h >= 1.0 && w >= 1.0 {
604                let erase_h = h as usize;
605                let erase_w = w as usize;
606
607                // Random position
608                let i = if erase_h < height {
609                    (rng.random::<f64>() * (height - erase_h) as f64) as usize
610                } else {
611                    0
612                };
613                let j = if erase_w < width {
614                    (rng.random::<f64>() * (width - erase_w) as f64) as usize
615                } else {
616                    0
617                };
618
619                // Fill with specified value
620                if self.fill_value == 1.0 {
621                    // Random values
622                    for row in i..i + erase_h.min(height - i) {
623                        for col in j..j + erase_w.min(width - j) {
624                            augmented[[row, col]] = rng.random();
625                        }
626                    }
627                } else {
628                    // Fixed value (0.0 or specified)
629                    for row in i..i + erase_h.min(height - i) {
630                        for col in j..j + erase_w.min(width - j) {
631                            augmented[[row, col]] = self.fill_value;
632                        }
633                    }
634                }
635
636                break;
637            }
638        }
639
640        Ok(augmented)
641    }
642}
643
/// CutOut augmentation.
///
/// Overwrites one square region of fixed size, centred at a random cell, with a
/// constant. The region is clipped where it crosses the data's border. This is a
/// simpler relative of Random Erasing: the region is always square and its size does
/// not vary between calls.
///
/// Reference: DeVries & Taylor, "Improved Regularization of Convolutional Neural Networks with Cutout" (2017)
///
/// # Fields
/// - `cutout_size`: side length of the square region
/// - `probability`: chance that a call applies the cutout (`with_size` uses 1.0)
/// - `fill_value`: constant written into the region (`with_size` uses 0.0)
#[derive(Clone, Debug)]
pub struct CutOutAugmenter {
    /// Side length of the square region to erase.
    pub cutout_size: usize,
    /// Chance, in `[0, 1]`, that the cutout is applied on a given call.
    pub probability: f64,
    /// Constant written into the erased cells.
    pub fill_value: f64,
}
664
665impl CutOutAugmenter {
666    /// Create a new CutOut augmenter.
667    pub fn new(cutout_size: usize, probability: f64, fill_value: f64) -> TrainResult<Self> {
668        if cutout_size == 0 {
669            return Err(TrainError::InvalidParameter(
670                "cutout_size must be > 0".to_string(),
671            ));
672        }
673        if !(0.0..=1.0).contains(&probability) {
674            return Err(TrainError::InvalidParameter(
675                "probability must be in [0, 1]".to_string(),
676            ));
677        }
678
679        Ok(Self {
680            cutout_size,
681            probability,
682            fill_value,
683        })
684    }
685
686    /// Create with default parameters.
687    pub fn with_size(size: usize) -> TrainResult<Self> {
688        Self::new(size, 1.0, 0.0)
689    }
690}
691
692impl DataAugmenter for CutOutAugmenter {
693    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
694        let mut augmented = data.to_owned();
695
696        // Apply with probability p
697        if rng.random::<f64>() > self.probability {
698            return Ok(augmented);
699        }
700
701        let (height, width) = (data.nrows(), data.ncols());
702
703        // Random center position
704        let center_y = (rng.random::<f64>() * height as f64) as usize;
705        let center_x = (rng.random::<f64>() * width as f64) as usize;
706
707        // Compute cutout region bounds (allow partial cutout at boundaries)
708        let half_size = self.cutout_size / 2;
709
710        let y_start = center_y.saturating_sub(half_size);
711        let y_end = (center_y + half_size).min(height);
712
713        let x_start = center_x.saturating_sub(half_size);
714        let x_end = (center_x + half_size).min(width);
715
716        // Erase the region
717        for i in y_start..y_end {
718            for j in x_start..x_end {
719                augmented[[i, j]] = self.fill_value;
720            }
721        }
722
723        Ok(augmented)
724    }
725}
726
727#[cfg(test)]
728mod tests {
729    use super::*;
730    use scirs2_core::ndarray::array;
731    use scirs2_core::random::SeedableRng;
732
733    fn create_test_rng() -> StdRng {
734        StdRng::seed_from_u64(42)
735    }
736
737    #[test]
738    fn test_no_augmentation() {
739        let augmenter = NoAugmentation;
740        let data = array![[1.0, 2.0], [3.0, 4.0]];
741        let mut rng = create_test_rng();
742
743        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();
744        assert_eq!(augmented, data);
745    }
746
747    #[test]
748    fn test_noise_augmenter() {
749        let augmenter = NoiseAugmenter::new(0.1).unwrap();
750        let data = array![[1.0, 2.0], [3.0, 4.0]];
751        let mut rng = create_test_rng();
752
753        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();
754
755        // Shape should be preserved
756        assert_eq!(augmented.shape(), data.shape());
757
758        // Values should be different (with high probability)
759        assert_ne!(augmented[[0, 0]], data[[0, 0]]);
760
761        // But should be close to original values
762        for i in 0..data.nrows() {
763            for j in 0..data.ncols() {
764                let diff = (augmented[[i, j]] - data[[i, j]]).abs();
765                assert!(diff < 1.0); // Within reasonable noise range
766            }
767        }
768    }
769
770    #[test]
771    fn test_noise_augmenter_invalid() {
772        let result = NoiseAugmenter::new(-0.1);
773        assert!(result.is_err());
774    }
775
776    #[test]
777    fn test_scale_augmenter() {
778        let augmenter = ScaleAugmenter::new(0.2).unwrap();
779        let data = array![[1.0, 2.0], [3.0, 4.0]];
780        let mut rng = create_test_rng();
781
782        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();
783
784        // Shape should be preserved
785        assert_eq!(augmented.shape(), data.shape());
786
787        // All values should be scaled by the same factor
788        let scale = augmented[[0, 0]] / data[[0, 0]];
789        for i in 0..data.nrows() {
790            for j in 0..data.ncols() {
791                let computed_scale = augmented[[i, j]] / data[[i, j]];
792                assert!((computed_scale - scale).abs() < 1e-10);
793            }
794        }
795
796        // Scale should be within range [0.8, 1.2]
797        assert!((0.8..=1.2).contains(&scale));
798    }
799
800    #[test]
801    fn test_scale_augmenter_invalid() {
802        assert!(ScaleAugmenter::new(-0.1).is_err());
803        assert!(ScaleAugmenter::new(1.5).is_err());
804    }
805
806    #[test]
807    fn test_rotation_augmenter() {
808        let augmenter = RotationAugmenter::default();
809        let data = array![[1.0, 2.0], [3.0, 4.0]];
810        let mut rng = create_test_rng();
811
812        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();
813
814        // Shape should be preserved
815        assert_eq!(augmented.shape(), data.shape());
816    }
817
818    #[test]
819    fn test_mixup_augmenter_batch() {
820        let augmenter = MixupAugmenter::new(1.0).unwrap();
821        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
822        let labels = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];
823        let mut rng = create_test_rng();
824
825        let (aug_data, aug_labels) = augmenter
826            .augment_batch(&data.view(), &labels.view(), &mut rng)
827            .unwrap();
828
829        // Shapes should be preserved
830        assert_eq!(aug_data.shape(), data.shape());
831        assert_eq!(aug_labels.shape(), labels.shape());
832
833        // Values should be interpolations (between min and max of original)
834        let data_min = data.iter().cloned().fold(f64::INFINITY, f64::min);
835        let data_max = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
836
837        for &val in aug_data.iter() {
838            assert!(val >= data_min && val <= data_max);
839        }
840    }
841
842    #[test]
843    fn test_mixup_invalid_alpha() {
844        let result = MixupAugmenter::new(0.0);
845        assert!(result.is_err());
846
847        let result = MixupAugmenter::new(-1.0);
848        assert!(result.is_err());
849    }
850
851    #[test]
852    fn test_mixup_mismatched_shapes() {
853        let augmenter = MixupAugmenter::default();
854        let data = array![[1.0, 2.0], [3.0, 4.0]];
855        let labels = array![[1.0, 0.0]]; // Wrong shape
856        let mut rng = create_test_rng();
857
858        let result = augmenter.augment_batch(&data.view(), &labels.view(), &mut rng);
859        assert!(result.is_err());
860    }
861
862    #[test]
863    fn test_cutmix_augmenter_batch() {
864        let augmenter = CutMixAugmenter::new(1.0).unwrap();
865        let data = array![
866            [1.0, 2.0, 3.0, 4.0],
867            [5.0, 6.0, 7.0, 8.0],
868            [9.0, 10.0, 11.0, 12.0]
869        ];
870        let labels = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];
871        let mut rng = create_test_rng();
872
873        let (aug_data, aug_labels) = augmenter
874            .augment_batch(&data.view(), &labels.view(), &mut rng)
875            .unwrap();
876
877        // Shapes should be preserved
878        assert_eq!(aug_data.shape(), data.shape());
879        assert_eq!(aug_labels.shape(), labels.shape());
880
881        // Each row should contain a mix of original values
882        // (some regions from original, some from paired sample)
883        for i in 0..aug_data.nrows() {
884            let mut found_original = false;
885            let mut found_different = false;
886
887            for j in 0..aug_data.ncols() {
888                // Check if value matches original row
889                if (aug_data[[i, j]] - data[[i, j]]).abs() < 1e-10 {
890                    found_original = true;
891                } else {
892                    found_different = true;
893                }
894            }
895
896            // Should have both original and swapped regions (unless randomly paired with self)
897            assert!(found_original || found_different);
898        }
899
900        // Label mixing: each element should be in [0, 1] and sum across classes
901        for i in 0..aug_labels.nrows() {
902            let sum: f64 = aug_labels.row(i).iter().sum();
903            assert!((sum - 1.0).abs() < 1e-10, "Labels should sum to 1.0");
904        }
905    }
906
907    #[test]
908    fn test_cutmix_invalid_alpha() {
909        let result = CutMixAugmenter::new(0.0);
910        assert!(result.is_err());
911
912        let result = CutMixAugmenter::new(-1.0);
913        assert!(result.is_err());
914    }
915
916    #[test]
917    fn test_cutmix_mismatched_shapes() {
918        let augmenter = CutMixAugmenter::default();
919        let data = array![[1.0, 2.0], [3.0, 4.0]];
920        let labels = array![[1.0, 0.0]]; // Wrong shape
921        let mut rng = create_test_rng();
922
923        let result = augmenter.augment_batch(&data.view(), &labels.view(), &mut rng);
924        assert!(result.is_err());
925    }
926
927    #[test]
928    fn test_cutmix_label_proportions() {
929        let augmenter = CutMixAugmenter::new(1.0).unwrap();
930        // Use distinctive patterns
931        let data = array![[10.0, 10.0, 10.0, 10.0], [20.0, 20.0, 20.0, 20.0]];
932        let labels = array![[1.0, 0.0], [0.0, 1.0]];
933        let mut rng = create_test_rng();
934
935        let (aug_data, aug_labels) = augmenter
936            .augment_batch(&data.view(), &labels.view(), &mut rng)
937            .unwrap();
938
939        // Verify that labels are mixed proportionally
940        for i in 0..aug_labels.nrows() {
941            // Each sample should have labels that sum to 1
942            let sum: f64 = aug_labels.row(i).iter().sum();
943            assert!((sum - 1.0).abs() < 1e-10);
944
945            // Labels should be between original values
946            for j in 0..aug_labels.ncols() {
947                assert!(aug_labels[[i, j]] >= 0.0);
948                assert!(aug_labels[[i, j]] <= 1.0);
949            }
950        }
951
952        // Verify data has been cut and mixed
953        assert_eq!(aug_data.shape(), data.shape());
954    }
955
956    #[test]
957    fn test_composite_augmenter() {
958        let mut composite = CompositeAugmenter::new();
959        composite.add(NoiseAugmenter::new(0.01).unwrap());
960        composite.add(ScaleAugmenter::new(0.1).unwrap());
961
962        let data = array![[1.0, 2.0], [3.0, 4.0]];
963        let mut rng = create_test_rng();
964
965        let augmented = composite.augment(&data.view(), &mut rng).unwrap();
966
967        // Shape should be preserved
968        assert_eq!(augmented.shape(), data.shape());
969
970        // Values should be different due to augmentation
971        assert_ne!(augmented[[0, 0]], data[[0, 0]]);
972    }
973
974    #[test]
975    fn test_composite_empty() {
976        let composite = CompositeAugmenter::new();
977        assert!(composite.is_empty());
978        assert_eq!(composite.len(), 0);
979
980        let data = array![[1.0, 2.0]];
981        let mut rng = create_test_rng();
982
983        let augmented = composite.augment(&data.view(), &mut rng).unwrap();
984        assert_eq!(augmented, data);
985    }
986
987    #[test]
988    fn test_composite_multiple() {
989        let mut composite = CompositeAugmenter::new();
990        composite.add(NoAugmentation);
991        composite.add(ScaleAugmenter::default());
992        composite.add(NoiseAugmenter::default());
993
994        assert_eq!(composite.len(), 3);
995        assert!(!composite.is_empty());
996    }
997
998    #[test]
999    fn test_random_erasing_creation() {
1000        let augmenter = RandomErasingAugmenter::new(0.5, 0.02, 0.33, 0.3, 3.3, 0.0).unwrap();
1001        assert_eq!(augmenter.probability, 0.5);
1002        assert_eq!(augmenter.scale_min, 0.02);
1003        assert_eq!(augmenter.scale_max, 0.33);
1004    }
1005
1006    #[test]
1007    fn test_random_erasing_invalid_params() {
1008        // Invalid probability
1009        assert!(RandomErasingAugmenter::new(1.5, 0.02, 0.33, 0.3, 3.3, 0.0).is_err());
1010
1011        // Invalid scale range
1012        assert!(RandomErasingAugmenter::new(0.5, 0.33, 0.02, 0.3, 3.3, 0.0).is_err());
1013
1014        // Invalid ratio range
1015        assert!(RandomErasingAugmenter::new(0.5, 0.02, 0.33, 3.3, 0.3, 0.0).is_err());
1016    }
1017
1018    #[test]
1019    fn test_random_erasing_augment() {
1020        let augmenter = RandomErasingAugmenter::with_defaults();
1021        let data = Array2::ones((10, 10));
1022        let mut rng = create_test_rng();
1023
1024        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();
1025
1026        // Shape should be preserved
1027        assert_eq!(augmented.shape(), data.shape());
1028
1029        // Some values may be erased (but not guaranteed due to probability)
1030    }
1031
1032    #[test]
1033    fn test_random_erasing_probability_zero() {
1034        let augmenter = RandomErasingAugmenter::new(0.0, 0.02, 0.33, 0.3, 3.3, 0.0).unwrap();
1035        let data = Array2::ones((10, 10));
1036        let mut rng = create_test_rng();
1037
1038        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();
1039
1040        // With probability 0, data should be unchanged
1041        assert_eq!(augmented, data);
1042    }
1043
1044    #[test]
1045    fn test_cutout_creation() {
1046        let augmenter = CutOutAugmenter::new(5, 1.0, 0.0).unwrap();
1047        assert_eq!(augmenter.cutout_size, 5);
1048        assert_eq!(augmenter.probability, 1.0);
1049        assert_eq!(augmenter.fill_value, 0.0);
1050    }
1051
1052    #[test]
1053    fn test_cutout_invalid_params() {
1054        // Zero size
1055        assert!(CutOutAugmenter::new(0, 1.0, 0.0).is_err());
1056
1057        // Invalid probability
1058        assert!(CutOutAugmenter::new(5, 1.5, 0.0).is_err());
1059    }
1060
1061    #[test]
1062    fn test_cutout_augment() {
1063        let augmenter = CutOutAugmenter::with_size(3).unwrap();
1064        let data = Array2::ones((10, 10));
1065        let mut rng = create_test_rng();
1066
1067        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();
1068
1069        // Shape should be preserved
1070        assert_eq!(augmented.shape(), data.shape());
1071
1072        // Some values should be zero (erased)
1073        let zeros_count = augmented.iter().filter(|&&x| x == 0.0).count();
1074        assert!(zeros_count > 0, "Expected some values to be erased");
1075        assert!(zeros_count < 100, "Not all values should be erased");
1076    }
1077
1078    #[test]
1079    fn test_cutout_probability_zero() {
1080        let augmenter = CutOutAugmenter::new(5, 0.0, 0.0).unwrap();
1081        let data = Array2::ones((10, 10));
1082        let mut rng = create_test_rng();
1083
1084        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();
1085
1086        // With probability 0, data should be unchanged
1087        assert_eq!(augmented, data);
1088    }
1089
1090    #[test]
1091    fn test_cutout_fill_value() {
1092        let augmenter = CutOutAugmenter::new(3, 1.0, 0.5).unwrap();
1093        let data = Array2::ones((10, 10));
1094        let mut rng = create_test_rng();
1095
1096        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();
1097
1098        // Some values should be 0.5 (fill value)
1099        let filled_count = augmented.iter().filter(|&&x| x == 0.5).count();
1100        assert!(
1101            filled_count > 0,
1102            "Expected some values to be filled with 0.5"
1103        );
1104    }
1105}