tensorlogic_train/augmentation.rs

//! Data augmentation techniques for training.
//!
//! This module provides various data augmentation strategies to improve model generalization:
//! - Noise augmentation (Gaussian)
//! - Scale augmentation (random scaling)
//! - Rotation augmentation (placeholder for future)
//! - Mixup augmentation (interpolation between samples)
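//!
//! A minimal usage sketch (module path and re-exports are assumed to be
//! `tensorlogic_train::augmentation`; seeding mirrors the unit tests below):
//!
//! ```ignore
//! use tensorlogic_train::augmentation::{DataAugmenter, NoiseAugmenter};
//! use scirs2_core::ndarray::array;
//! use scirs2_core::random::{SeedableRng, StdRng};
//!
//! let augmenter = NoiseAugmenter::new(0.05).unwrap();
//! let data = array![[1.0, 2.0], [3.0, 4.0]];
//! let mut rng = StdRng::seed_from_u64(42);
//! let noisy = augmenter.augment(&data.view(), &mut rng).unwrap();
//! assert_eq!(noisy.shape(), data.shape());
//! ```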

use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Array2, ArrayView2};
use scirs2_core::random::{Rng, StdRng};

/// Trait for data augmentation strategies.
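///
/// Implementors take a 2-D view and return an owned array of the same shape.
/// A sketch of a custom augmenter (the `ClampAugmenter` below is hypothetical,
/// not part of this module):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array2, ArrayView2};
/// use scirs2_core::random::StdRng;
///
/// #[derive(Debug, Clone)]
/// struct ClampAugmenter {
///     max_abs: f64,
/// }
///
/// impl DataAugmenter for ClampAugmenter {
///     fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
///         // Deterministic "augmentation": clip every value into [-max_abs, max_abs].
///         Ok(data.mapv(|x| x.clamp(-self.max_abs, self.max_abs)))
///     }
/// }
/// ```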
pub trait DataAugmenter {
    /// Augment the given data.
    ///
    /// # Arguments
    /// * `data` - Input data to augment
    /// * `rng` - Random number generator for stochastic augmentation
    ///
    /// # Returns
    /// Augmented data with the same shape as input
    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>>;
}

/// No augmentation (identity transformation).
///
/// Useful for testing or as a placeholder.
#[derive(Debug, Clone, Default)]
pub struct NoAugmentation;

impl DataAugmenter for NoAugmentation {
    fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
        Ok(data.to_owned())
    }
}

/// Gaussian noise augmentation.
///
/// Adds random Gaussian noise to the input data: x' = x + N(0, σ²)
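///
/// A seeded sketch (imports elided; see the module-level example above):
///
/// ```ignore
/// let augmenter = NoiseAugmenter::new(0.1).unwrap();
/// let data = array![[1.0, 2.0], [3.0, 4.0]];
/// let mut rng = StdRng::seed_from_u64(42);
/// let noisy = augmenter.augment(&data.view(), &mut rng).unwrap();
/// // Same shape, each value perturbed by N(0, 0.1²) noise.
/// assert_eq!(noisy.shape(), data.shape());
/// ```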
#[derive(Debug, Clone)]
pub struct NoiseAugmenter {
    /// Standard deviation of the Gaussian noise.
    pub std_dev: f64,
}

impl NoiseAugmenter {
    /// Create a new noise augmenter.
    ///
    /// # Arguments
    /// * `std_dev` - Standard deviation of the noise
    pub fn new(std_dev: f64) -> TrainResult<Self> {
        if std_dev < 0.0 {
            return Err(TrainError::InvalidParameter(
                "std_dev must be non-negative".to_string(),
            ));
        }
        Ok(Self { std_dev })
    }
}

impl Default for NoiseAugmenter {
    fn default() -> Self {
        Self { std_dev: 0.01 }
    }
}

impl DataAugmenter for NoiseAugmenter {
    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
        let mut augmented = data.to_owned();

        // Add Gaussian noise using the Box-Muller transform.
        for value in augmented.iter_mut() {
            // Map the [0, 1) draw to (0, 1] so that ln(u1) stays finite.
            let u1: f64 = 1.0 - rng.random::<f64>();
            let u2: f64 = rng.random();

            // Box-Muller: z0 ~ N(0, 1); the second variate (sine branch) is discarded.
            let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
            let noise = z0 * self.std_dev;

            *value += noise;
        }

        Ok(augmented)
    }
}

/// Scale augmentation.
///
/// Randomly scales the input data by a factor in [1 - scale_range, 1 + scale_range].
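///
/// A seeded sketch (one shared factor is applied to the whole array per call;
/// imports elided):
///
/// ```ignore
/// let augmenter = ScaleAugmenter::new(0.2).unwrap(); // factor drawn from [0.8, 1.2]
/// let data = array![[1.0, 2.0], [3.0, 4.0]];
/// let mut rng = StdRng::seed_from_u64(42);
/// let scaled = augmenter.augment(&data.view(), &mut rng).unwrap();
/// let factor = scaled[[0, 0]] / data[[0, 0]];
/// assert!((0.8..=1.2).contains(&factor));
/// ```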
#[derive(Debug, Clone)]
pub struct ScaleAugmenter {
    /// Range of scaling factor (e.g., 0.1 means scale in [0.9, 1.1]).
    pub scale_range: f64,
}

impl ScaleAugmenter {
    /// Create a new scale augmenter.
    ///
    /// # Arguments
    /// * `scale_range` - Range of scaling factor
    pub fn new(scale_range: f64) -> TrainResult<Self> {
        if !(0.0..=1.0).contains(&scale_range) {
            return Err(TrainError::InvalidParameter(
                "scale_range must be in [0, 1]".to_string(),
            ));
        }
        Ok(Self { scale_range })
    }
}

impl Default for ScaleAugmenter {
    fn default() -> Self {
        Self { scale_range: 0.1 }
    }
}

impl DataAugmenter for ScaleAugmenter {
    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
        // Generate random scale factor
        let scale = 1.0 + (rng.random::<f64>() * 2.0 - 1.0) * self.scale_range;

        let augmented = data.mapv(|x| x * scale);
        Ok(augmented)
    }
}

/// Rotation augmentation (placeholder for future implementation).
///
/// For 2D images, this would apply random rotations.
/// Currently implements a simplified version for tabular data.
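///
/// For reference, a true rotation of paired (x, y) feature columns would apply
/// the standard 2-D rotation matrix; a sketch of that future direction (the
/// `rotate_pair` helper is hypothetical and is not what `augment` does today):
///
/// ```ignore
/// fn rotate_pair(x: f64, y: f64, angle: f64) -> (f64, f64) {
///     let (sin_a, cos_a) = angle.sin_cos();
///     // [x'] = [cos -sin] [x]
///     // [y']   [sin  cos] [y]
///     (cos_a * x - sin_a * y, sin_a * x + cos_a * y)
/// }
/// ```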
#[derive(Debug, Clone)]
pub struct RotationAugmenter {
    /// Maximum rotation angle in radians.
    pub max_angle: f64,
}

impl RotationAugmenter {
    /// Create a new rotation augmenter.
    ///
    /// # Arguments
    /// * `max_angle` - Maximum rotation angle in radians
    pub fn new(max_angle: f64) -> Self {
        Self { max_angle }
    }
}

impl Default for RotationAugmenter {
    fn default() -> Self {
        Self {
            max_angle: std::f64::consts::PI / 18.0, // 10 degrees
        }
    }
}

impl DataAugmenter for RotationAugmenter {
    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
        // For now, this is a placeholder that returns a simple transformation
        // Future: implement proper 2D rotation for image data
        let angle = (rng.random::<f64>() * 2.0 - 1.0) * self.max_angle;

        // Apply a simple rotation-inspired transformation
        let cos_a = angle.cos();
        let sin_a = angle.sin();

        let augmented = data.mapv(|x| x * cos_a + x * sin_a * 0.1);
        Ok(augmented)
    }
}

/// Mixup augmentation.
///
/// Creates new training samples by linearly interpolating between pairs of samples:
/// x' = λ * x₁ + (1 - λ) * x₂, y' = λ * y₁ + (1 - λ) * y₂,
/// where λ is drawn from a Beta(α, α) distribution.
///
/// Reference: Zhang et al., "mixup: Beyond Empirical Risk Minimization", ICLR 2018
#[derive(Debug, Clone)]
pub struct MixupAugmenter {
    /// Alpha parameter for Beta distribution (controls mixing strength).
    pub alpha: f64,
}

impl MixupAugmenter {
    /// Create a new mixup augmenter.
    ///
    /// # Arguments
    /// * `alpha` - Alpha parameter for Beta distribution
    pub fn new(alpha: f64) -> TrainResult<Self> {
        if alpha <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "alpha must be positive".to_string(),
            ));
        }
        Ok(Self { alpha })
    }

    /// Apply mixup to a batch of data.
    ///
    /// # Arguments
    /// * `data` - Input data batch [N, features]
    /// * `labels` - Corresponding labels [N, classes]
    /// * `rng` - Random number generator
    ///
    /// # Returns
    /// Tuple of (augmented_data, augmented_labels)
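    ///
    /// # Example
    ///
    /// A seeded sketch mirroring the unit tests below (imports elided):
    ///
    /// ```ignore
    /// let augmenter = MixupAugmenter::new(1.0).unwrap();
    /// let data = array![[1.0, 2.0], [3.0, 4.0]];
    /// let labels = array![[1.0, 0.0], [0.0, 1.0]];
    /// let mut rng = StdRng::seed_from_u64(42);
    /// let (mixed_data, mixed_labels) = augmenter
    ///     .augment_batch(&data.view(), &labels.view(), &mut rng)
    ///     .unwrap();
    /// // Each mixed row is a convex combination of two original rows.
    /// assert_eq!(mixed_data.shape(), data.shape());
    /// assert_eq!(mixed_labels.shape(), labels.shape());
    /// ```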
    pub fn augment_batch(
        &self,
        data: &ArrayView2<f64>,
        labels: &ArrayView2<f64>,
        rng: &mut StdRng,
    ) -> TrainResult<(Array2<f64>, Array2<f64>)> {
        if data.nrows() != labels.nrows() {
            return Err(TrainError::InvalidParameter(
                "data and labels must have same number of rows".to_string(),
            ));
        }

        let n = data.nrows();
        let mut augmented_data = Array::zeros(data.raw_dim());
        let mut augmented_labels = Array::zeros(labels.raw_dim());

        // Create a random permutation of row indices (Fisher-Yates shuffle),
        // pairing each sample i with a partner j = indices[i].
        let mut indices: Vec<usize> = (0..n).collect();
        for i in (1..n).rev() {
            let j = rng.gen_range(0..=i);
            indices.swap(i, j);
        }

        for i in 0..n {
            let j = indices[i];

            // Sample the mixing coefficient λ from (an approximation of)
            // Beta(alpha, alpha); see `sample_beta` below.
            let lambda = self.sample_beta(rng);

            // Mix data: x' = λ * x_i + (1 - λ) * x_j
            for k in 0..data.ncols() {
                augmented_data[[i, k]] = lambda * data[[i, k]] + (1.0 - lambda) * data[[j, k]];
            }

            // Mix labels: y' = λ * y_i + (1 - λ) * y_j
            for k in 0..labels.ncols() {
                augmented_labels[[i, k]] =
                    lambda * labels[[i, k]] + (1.0 - lambda) * labels[[j, k]];
            }
        }

        Ok((augmented_data, augmented_labels))
    }

    /// Sample from a Beta(alpha, alpha) distribution (approximately).
    ///
    /// Simplified implementation: for `alpha >= 0.5` it falls back to Uniform(0, 1),
    /// which is exact for the default `alpha = 1.0` (Beta(1, 1) is uniform) and only
    /// a rough approximation otherwise; for smaller `alpha` it skews draws toward 0
    /// or 1 to mimic the U-shaped density of Beta with `alpha < 1`.
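    ///
    /// An exact draw could instead use Jöhnk's rejection method; a sketch of that
    /// possible replacement (the `sample_beta_johnk` helper is hypothetical and
    /// reasonably efficient only for `alpha <= 1`):
    ///
    /// ```ignore
    /// fn sample_beta_johnk(alpha: f64, rng: &mut StdRng) -> f64 {
    ///     loop {
    ///         let x = rng.random::<f64>().powf(1.0 / alpha);
    ///         let y = rng.random::<f64>().powf(1.0 / alpha);
    ///         if x + y <= 1.0 && x + y > 0.0 {
    ///             return x / (x + y);
    ///         }
    ///     }
    /// }
    /// ```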
    fn sample_beta(&self, rng: &mut StdRng) -> f64 {
        if self.alpha < 0.5 {
            // For small alpha, prefer values near 0 or 1
            if rng.random::<f64>() < 0.5 {
                rng.random::<f64>().powf(2.0)
            } else {
                1.0 - rng.random::<f64>().powf(2.0)
            }
        } else {
            // For alpha >= 0.5, approximate with uniform
            rng.random::<f64>()
        }
    }
}

impl Default for MixupAugmenter {
    fn default() -> Self {
        Self { alpha: 1.0 }
    }
}

impl DataAugmenter for MixupAugmenter {
    fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
        // For single-sample augmentation, mix with itself (no-op)
        // In practice, mixup should be used with augment_batch
        Ok(data.to_owned())
    }
}

/// Composite augmenter that applies multiple augmentations sequentially.
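///
/// A pipeline sketch mirroring the composite unit test below (imports elided):
///
/// ```ignore
/// let mut composite = CompositeAugmenter::new();
/// composite.add(NoiseAugmenter::new(0.01).unwrap());
/// composite.add(ScaleAugmenter::new(0.1).unwrap());
///
/// let data = array![[1.0, 2.0], [3.0, 4.0]];
/// let mut rng = StdRng::seed_from_u64(42);
/// // Augmenters run in insertion order: noise first, then scaling.
/// let augmented = composite.augment(&data.view(), &mut rng).unwrap();
/// assert_eq!(augmented.shape(), data.shape());
/// ```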
#[derive(Clone, Default)]
pub struct CompositeAugmenter {
    augmenters: Vec<Box<dyn AugmenterClone>>,
}

impl std::fmt::Debug for CompositeAugmenter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CompositeAugmenter")
            .field("num_augmenters", &self.augmenters.len())
            .finish()
    }
}

/// Helper trait for cloning boxed augmenters.
///
/// `Clone` is not object-safe (its `clone` returns `Self`), so trait objects are
/// cloned through `clone_box`; the blanket impl below provides it for every
/// `Clone + 'static` augmenter.
trait AugmenterClone: DataAugmenter {
    fn clone_box(&self) -> Box<dyn AugmenterClone>;
}

impl<T: DataAugmenter + Clone + 'static> AugmenterClone for T {
    fn clone_box(&self) -> Box<dyn AugmenterClone> {
        Box::new(self.clone())
    }
}

impl Clone for Box<dyn AugmenterClone> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}

impl DataAugmenter for Box<dyn AugmenterClone> {
    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
        (**self).augment(data, rng)
    }
}

impl CompositeAugmenter {
    /// Create a new composite augmenter.
    pub fn new() -> Self {
        Self {
            augmenters: Vec::new(),
        }
    }

    /// Add an augmenter to the pipeline.
    pub fn add<A: DataAugmenter + Clone + 'static>(&mut self, augmenter: A) {
        self.augmenters.push(Box::new(augmenter));
    }

    /// Get the number of augmenters.
    pub fn len(&self) -> usize {
        self.augmenters.len()
    }

    /// Check if the pipeline is empty.
    pub fn is_empty(&self) -> bool {
        self.augmenters.is_empty()
    }
}

impl DataAugmenter for CompositeAugmenter {
    fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
        let mut result = data.to_owned();

        for augmenter in &self.augmenters {
            result = augmenter.augment(&result.view(), rng)?;
        }

        Ok(result)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use scirs2_core::random::SeedableRng;

    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    #[test]
    fn test_no_augmentation() {
        let augmenter = NoAugmentation;
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();
        assert_eq!(augmented, data);
    }

    #[test]
    fn test_noise_augmenter() {
        let augmenter = NoiseAugmenter::new(0.1).unwrap();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();

        // Shape should be preserved
        assert_eq!(augmented.shape(), data.shape());

        // Values should be different (with high probability)
        assert_ne!(augmented[[0, 0]], data[[0, 0]]);

        // But should be close to original values
        for i in 0..data.nrows() {
            for j in 0..data.ncols() {
                let diff = (augmented[[i, j]] - data[[i, j]]).abs();
                assert!(diff < 1.0); // Within reasonable noise range
            }
        }
    }

    #[test]
    fn test_noise_augmenter_invalid() {
        let result = NoiseAugmenter::new(-0.1);
        assert!(result.is_err());
    }

    #[test]
    fn test_scale_augmenter() {
        let augmenter = ScaleAugmenter::new(0.2).unwrap();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();

        // Shape should be preserved
        assert_eq!(augmented.shape(), data.shape());

        // All values should be scaled by the same factor
        let scale = augmented[[0, 0]] / data[[0, 0]];
        for i in 0..data.nrows() {
            for j in 0..data.ncols() {
                let computed_scale = augmented[[i, j]] / data[[i, j]];
                assert!((computed_scale - scale).abs() < 1e-10);
            }
        }

        // Scale should be within range [0.8, 1.2]
        assert!((0.8..=1.2).contains(&scale));
    }

    #[test]
    fn test_scale_augmenter_invalid() {
        assert!(ScaleAugmenter::new(-0.1).is_err());
        assert!(ScaleAugmenter::new(1.5).is_err());
    }

    #[test]
    fn test_rotation_augmenter() {
        let augmenter = RotationAugmenter::default();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();

        // Shape should be preserved
        assert_eq!(augmented.shape(), data.shape());
    }

    #[test]
    fn test_mixup_augmenter_batch() {
        let augmenter = MixupAugmenter::new(1.0).unwrap();
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];
        let mut rng = create_test_rng();

        let (aug_data, aug_labels) = augmenter
            .augment_batch(&data.view(), &labels.view(), &mut rng)
            .unwrap();

        // Shapes should be preserved
        assert_eq!(aug_data.shape(), data.shape());
        assert_eq!(aug_labels.shape(), labels.shape());

        // Values should be interpolations (between min and max of original)
        let data_min = data.iter().cloned().fold(f64::INFINITY, f64::min);
        let data_max = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

        for &val in aug_data.iter() {
            assert!(val >= data_min && val <= data_max);
        }
    }

    #[test]
    fn test_mixup_invalid_alpha() {
        let result = MixupAugmenter::new(0.0);
        assert!(result.is_err());

        let result = MixupAugmenter::new(-1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_mixup_mismatched_shapes() {
        let augmenter = MixupAugmenter::default();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let labels = array![[1.0, 0.0]]; // Wrong shape
        let mut rng = create_test_rng();

        let result = augmenter.augment_batch(&data.view(), &labels.view(), &mut rng);
        assert!(result.is_err());
    }

    #[test]
    fn test_composite_augmenter() {
        let mut composite = CompositeAugmenter::new();
        composite.add(NoiseAugmenter::new(0.01).unwrap());
        composite.add(ScaleAugmenter::new(0.1).unwrap());

        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = composite.augment(&data.view(), &mut rng).unwrap();

        // Shape should be preserved
        assert_eq!(augmented.shape(), data.shape());

        // Values should be different due to augmentation
        assert_ne!(augmented[[0, 0]], data[[0, 0]]);
    }

    #[test]
    fn test_composite_empty() {
        let composite = CompositeAugmenter::new();
        assert!(composite.is_empty());
        assert_eq!(composite.len(), 0);

        let data = array![[1.0, 2.0]];
        let mut rng = create_test_rng();

        let augmented = composite.augment(&data.view(), &mut rng).unwrap();
        assert_eq!(augmented, data);
    }

    #[test]
    fn test_composite_multiple() {
        let mut composite = CompositeAugmenter::new();
        composite.add(NoAugmentation);
        composite.add(ScaleAugmenter::default());
        composite.add(NoiseAugmenter::default());

        assert_eq!(composite.len(), 3);
        assert!(!composite.is_empty());
    }
}