Skip to main content

scirs2_datasets/
augmentation.rs

1//! Data augmentation pipeline with GPU support
2//!
3//! This module provides composable data augmentation transformations for various
4//! data types (images, audio, tabular) with optional GPU acceleration for improved
5//! performance on large datasets.
6
7use crate::error::{DatasetsError, Result};
8use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2};
9use scirs2_core::rand_distributions::Normal;
10use scirs2_core::random::Random;
11use scirs2_core::{Rng, RngExt};
12use std::sync::Arc;
13
14/// Helper function to create a random number generator with time-based seed
15fn create_rng() -> Random<scirs2_core::rand_prelude::StdRng> {
16    use std::time::{SystemTime, UNIX_EPOCH};
17    let seed = SystemTime::now()
18        .duration_since(UNIX_EPOCH)
19        .map(|d| d.as_secs())
20        .unwrap_or(0);
21    Random::seed(seed)
22}
23
24/// Augmentation transform trait
25pub trait Transform: Send + Sync {
26    /// Apply transformation to 2D array (tabular data)
27    fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>>;
28
29    /// Apply transformation to 3D array (image data)
30    fn transform_3d(&self, data: &Array3<f64>) -> Result<Array3<f64>> {
31        // Default implementation: process each channel separately
32        let (height, width, channels) = data.dim();
33        let mut result = Array3::zeros((height, width, channels));
34        for c in 0..channels {
35            let channel_2d = data.slice(s![.., .., c]).to_owned();
36            let transformed = self.transform_2d(&channel_2d)?;
37            result.slice_mut(s![.., .., c]).assign(&transformed);
38        }
39        Ok(result)
40    }
41
42    /// Whether this transform uses GPU acceleration
43    fn uses_gpu(&self) -> bool {
44        false
45    }
46
47    /// Name of the transform
48    fn name(&self) -> &str;
49}
50
51/// Pipeline of augmentation transforms
52#[derive(Clone)]
53pub struct AugmentationPipeline {
54    transforms: Vec<Arc<dyn Transform>>,
55    probability: f64,
56    seed: Option<u64>,
57}
58
59impl AugmentationPipeline {
60    /// Create a new augmentation pipeline
61    pub fn new() -> Self {
62        Self {
63            transforms: Vec::new(),
64            probability: 1.0,
65            seed: None,
66        }
67    }
68
69    /// Add a transform to the pipeline
70    pub fn add_transform(mut self, transform: Arc<dyn Transform>) -> Self {
71        self.transforms.push(transform);
72        self
73    }
74
75    /// Set the probability of applying the entire pipeline
76    pub fn with_probability(mut self, prob: f64) -> Self {
77        self.probability = prob.clamp(0.0, 1.0);
78        self
79    }
80
81    /// Set random seed for reproducibility
82    pub fn with_seed(mut self, seed: u64) -> Self {
83        self.seed = Some(seed);
84        self
85    }
86
87    /// Apply pipeline to 2D data
88    pub fn apply_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
89        // Check if we should apply augmentation
90        let mut rng = if let Some(seed) = self.seed {
91            Random::seed(seed)
92        } else {
93            create_rng()
94        };
95
96        if rng.random::<f64>() > self.probability {
97            return Ok(data.clone());
98        }
99
100        // Apply transforms sequentially
101        let mut result = data.clone();
102        for transform in &self.transforms {
103            result = transform.transform_2d(&result)?;
104        }
105        Ok(result)
106    }
107
108    /// Apply pipeline to 3D data
109    pub fn apply_3d(&self, data: &Array3<f64>) -> Result<Array3<f64>> {
110        let mut rng = if let Some(seed) = self.seed {
111            Random::seed(seed)
112        } else {
113            create_rng()
114        };
115
116        if rng.random::<f64>() > self.probability {
117            return Ok(data.clone());
118        }
119
120        let mut result = data.clone();
121        for transform in &self.transforms {
122            result = transform.transform_3d(&result)?;
123        }
124        Ok(result)
125    }
126
127    /// Check if any transform uses GPU
128    pub fn uses_gpu(&self) -> bool {
129        self.transforms.iter().any(|t| t.uses_gpu())
130    }
131}
132
133impl Default for AugmentationPipeline {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139// ============================================================================
140// Image Augmentation Transforms
141// ============================================================================
142
143/// Horizontal flip transform
144pub struct HorizontalFlip {
145    probability: f64,
146}
147
148impl HorizontalFlip {
149    /// Create a new horizontal flip transform
150    pub fn new(probability: f64) -> Self {
151        Self {
152            probability: probability.clamp(0.0, 1.0),
153        }
154    }
155}
156
157impl Transform for HorizontalFlip {
158    fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
159        let mut rng = create_rng();
160        if rng.random::<f64>() < self.probability {
161            // Flip horizontally (reverse columns)
162            let flipped = data.slice(s![.., ..;-1]).to_owned();
163            Ok(flipped)
164        } else {
165            Ok(data.clone())
166        }
167    }
168
169    fn name(&self) -> &str {
170        "HorizontalFlip"
171    }
172}
173
174/// Vertical flip transform
175pub struct VerticalFlip {
176    probability: f64,
177}
178
179impl VerticalFlip {
180    /// Create a new vertical flip transform
181    pub fn new(probability: f64) -> Self {
182        Self {
183            probability: probability.clamp(0.0, 1.0),
184        }
185    }
186}
187
188impl Transform for VerticalFlip {
189    fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
190        let mut rng = create_rng();
191        if rng.random::<f64>() < self.probability {
192            // Flip vertically (reverse rows)
193            let flipped = data.slice(s![..;-1, ..]).to_owned();
194            Ok(flipped)
195        } else {
196            Ok(data.clone())
197        }
198    }
199
200    fn name(&self) -> &str {
201        "VerticalFlip"
202    }
203}
204
205/// Random rotation transform (90, 180, 270 degrees)
206pub struct RandomRotation90 {
207    probability: f64,
208}
209
210impl RandomRotation90 {
211    /// Create a new random rotation transform
212    pub fn new(probability: f64) -> Self {
213        Self {
214            probability: probability.clamp(0.0, 1.0),
215        }
216    }
217
218    /// Rotate matrix 90 degrees clockwise
219    fn rotate_90(&self, data: &Array2<f64>) -> Array2<f64> {
220        let (rows, cols) = data.dim();
221        let mut result = Array2::zeros((cols, rows));
222        for i in 0..rows {
223            for j in 0..cols {
224                result[[j, rows - 1 - i]] = data[[i, j]];
225            }
226        }
227        result
228    }
229}
230
231impl Transform for RandomRotation90 {
232    fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
233        let mut rng = create_rng();
234        if rng.random::<f64>() < self.probability {
235            // Randomly choose 90, 180, or 270 degrees
236            let rotations = (rng.random::<f64>() * 3.0).floor() as usize + 1;
237            let mut result = data.clone();
238            for _ in 0..rotations {
239                result = self.rotate_90(&result);
240            }
241            Ok(result)
242        } else {
243            Ok(data.clone())
244        }
245    }
246
247    fn name(&self) -> &str {
248        "RandomRotation90"
249    }
250}
251
252/// Gaussian noise addition
253pub struct GaussianNoise {
254    mean: f64,
255    std: f64,
256    probability: f64,
257}
258
259impl GaussianNoise {
260    /// Create a new Gaussian noise transform
261    pub fn new(mean: f64, std: f64, probability: f64) -> Self {
262        Self {
263            mean,
264            std,
265            probability: probability.clamp(0.0, 1.0),
266        }
267    }
268}
269
270impl Transform for GaussianNoise {
271    fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
272        let mut rng = create_rng();
273        if rng.random::<f64>() < self.probability {
274            let (rows, cols) = data.dim();
275            let mut result = data.clone();
276            let normal = Normal::new(self.mean, self.std).map_err(|e| {
277                DatasetsError::ComputationError(format!(
278                    "Failed to create normal distribution: {}",
279                    e
280                ))
281            })?;
282            for i in 0..rows {
283                for j in 0..cols {
284                    let noise = rng.sample(normal);
285                    result[[i, j]] += noise;
286                }
287            }
288            Ok(result)
289        } else {
290            Ok(data.clone())
291        }
292    }
293
294    fn name(&self) -> &str {
295        "GaussianNoise"
296    }
297}
298
299/// Brightness adjustment
300pub struct Brightness {
301    delta_range: (f64, f64),
302    probability: f64,
303}
304
305impl Brightness {
306    /// Create a new brightness transform
307    pub fn new(delta_range: (f64, f64), probability: f64) -> Self {
308        Self {
309            delta_range,
310            probability: probability.clamp(0.0, 1.0),
311        }
312    }
313}
314
315impl Transform for Brightness {
316    fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
317        let mut rng = create_rng();
318        if rng.random::<f64>() < self.probability {
319            let delta = self.delta_range.0
320                + rng.random::<f64>() * (self.delta_range.1 - self.delta_range.0);
321            Ok(data + delta)
322        } else {
323            Ok(data.clone())
324        }
325    }
326
327    fn name(&self) -> &str {
328        "Brightness"
329    }
330}
331
332/// Contrast adjustment
333pub struct Contrast {
334    factor_range: (f64, f64),
335    probability: f64,
336}
337
338impl Contrast {
339    /// Create a new contrast transform
340    pub fn new(factor_range: (f64, f64), probability: f64) -> Self {
341        Self {
342            factor_range,
343            probability: probability.clamp(0.0, 1.0),
344        }
345    }
346}
347
348impl Transform for Contrast {
349    fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
350        let mut rng = create_rng();
351        if rng.random::<f64>() < self.probability {
352            let factor = self.factor_range.0
353                + rng.random::<f64>() * (self.factor_range.1 - self.factor_range.0);
354            let mean = data.mean().unwrap_or(0.0);
355            Ok((data - mean) * factor + mean)
356        } else {
357            Ok(data.clone())
358        }
359    }
360
361    fn name(&self) -> &str {
362        "Contrast"
363    }
364}
365
366// ============================================================================
367// Tabular Data Augmentation
368// ============================================================================
369
370/// Random feature scaling
371pub struct RandomFeatureScale {
372    scale_range: (f64, f64),
373    feature_probability: f64,
374}
375
376impl RandomFeatureScale {
377    /// Create a new random feature scaling transform
378    pub fn new(scale_range: (f64, f64), feature_probability: f64) -> Self {
379        Self {
380            scale_range,
381            feature_probability: feature_probability.clamp(0.0, 1.0),
382        }
383    }
384}
385
386impl Transform for RandomFeatureScale {
387    fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
388        let mut rng = create_rng();
389        let (rows, cols) = data.dim();
390        let mut result = data.clone();
391
392        for j in 0..cols {
393            if rng.random::<f64>() < self.feature_probability {
394                let scale = self.scale_range.0
395                    + rng.random::<f64>() * (self.scale_range.1 - self.scale_range.0);
396                for i in 0..rows {
397                    result[[i, j]] *= scale;
398                }
399            }
400        }
401
402        Ok(result)
403    }
404
405    fn name(&self) -> &str {
406        "RandomFeatureScale"
407    }
408}
409
410/// Mixup augmentation (linear interpolation between samples)
411pub struct Mixup {
412    alpha: f64,
413    probability: f64,
414}
415
416impl Mixup {
417    /// Create a new mixup transform
418    pub fn new(alpha: f64, probability: f64) -> Self {
419        Self {
420            alpha,
421            probability: probability.clamp(0.0, 1.0),
422        }
423    }
424}
425
426impl Transform for Mixup {
427    fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
428        let mut rng = create_rng();
429        if rng.random::<f64>() < self.probability {
430            let (rows, cols) = data.dim();
431            if rows < 2 {
432                return Ok(data.clone());
433            }
434
435            let mut result = data.clone();
436            for i in 0..rows {
437                // Randomly select another sample
438                let j = (rng.random::<f64>() * rows as f64).floor() as usize % rows;
439                if i != j {
440                    // Beta distribution parameter (simplified as uniform for now)
441                    let lambda = rng.random::<f64>();
442                    // Mix the two samples
443                    for k in 0..cols {
444                        result[[i, k]] = lambda * data[[i, k]] + (1.0 - lambda) * data[[j, k]];
445                    }
446                }
447            }
448            Ok(result)
449        } else {
450            Ok(data.clone())
451        }
452    }
453
454    fn name(&self) -> &str {
455        "Mixup"
456    }
457}
458
459// ============================================================================
460// Convenience Functions
461// ============================================================================
462
463/// Create a standard image augmentation pipeline
464pub fn standard_image_augmentation(probability: f64) -> AugmentationPipeline {
465    AugmentationPipeline::new()
466        .add_transform(Arc::new(HorizontalFlip::new(0.5)))
467        .add_transform(Arc::new(RandomRotation90::new(0.3)))
468        .add_transform(Arc::new(Brightness::new((-0.2, 0.2), 0.4)))
469        .add_transform(Arc::new(Contrast::new((0.8, 1.2), 0.4)))
470        .add_transform(Arc::new(GaussianNoise::new(0.0, 0.01, 0.3)))
471        .with_probability(probability)
472}
473
474/// Create a standard tabular augmentation pipeline
475pub fn standard_tabular_augmentation(probability: f64) -> AugmentationPipeline {
476    AugmentationPipeline::new()
477        .add_transform(Arc::new(RandomFeatureScale::new((0.9, 1.1), 0.3)))
478        .add_transform(Arc::new(GaussianNoise::new(0.0, 0.01, 0.2)))
479        .add_transform(Arc::new(Mixup::new(1.0, 0.5)))
480        .with_probability(probability)
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `rows x cols` array holding `1.0, 2.0, ...` in row-major order.
    fn sequential(rows: usize, cols: usize) -> Result<Array2<f64>> {
        let values: Vec<f64> = (1..=rows * cols).map(|v| v as f64).collect();
        Array2::from_shape_vec((rows, cols), values)
            .map_err(|e| DatasetsError::InvalidFormat(format!("{}", e)))
    }

    #[test]
    fn test_horizontal_flip() -> Result<()> {
        let data = sequential(3, 4)?;

        let flip = HorizontalFlip::new(1.0); // Always flip
        let result = flip.transform_2d(&data)?;

        assert_eq!(result[[0, 0]], 4.0);
        assert_eq!(result[[0, 3]], 1.0);
        assert_eq!(result.nrows(), 3);
        assert_eq!(result.ncols(), 4);

        Ok(())
    }

    #[test]
    fn test_horizontal_flip_3d() -> Result<()> {
        let data = Array3::from_shape_fn((2, 3, 2), |(i, j, c)| (100 * i + 10 * j + c) as f64);
        let result = HorizontalFlip::new(1.0).transform_3d(&data)?;

        // Every channel must be mirrored along the width axis.
        assert_eq!(result.dim(), data.dim());
        for ((i, j, c), &value) in result.indexed_iter() {
            assert_eq!(value, data[[i, 2 - j, c]]);
        }

        Ok(())
    }

    #[test]
    fn test_vertical_flip() -> Result<()> {
        let data = sequential(3, 4)?;
        let result = VerticalFlip::new(1.0).transform_2d(&data)?;

        assert_eq!(result.dim(), data.dim());
        assert_eq!(result[[0, 0]], 9.0);
        assert_eq!(result[[2, 3]], 4.0);

        Ok(())
    }

    #[test]
    fn test_zero_probability_is_identity() -> Result<()> {
        let data = sequential(3, 4)?;

        assert_eq!(HorizontalFlip::new(0.0).transform_2d(&data)?, data);
        assert_eq!(VerticalFlip::new(0.0).transform_2d(&data)?, data);
        assert_eq!(GaussianNoise::new(0.0, 1.0, 0.0).transform_2d(&data)?, data);

        Ok(())
    }

    #[test]
    fn test_rotation_preserves_values() -> Result<()> {
        let data = sequential(2, 3)?;
        let result = RandomRotation90::new(1.0).transform_2d(&data)?;

        // 90/270 degrees swap the dimensions; 180 degrees keeps them but reorders values.
        assert!(matches!(result.dim(), (3, 2) | (2, 3)));
        assert_ne!(result, data);
        assert_eq!(result.sum(), data.sum());

        Ok(())
    }

    #[test]
    fn test_gaussian_noise() -> Result<()> {
        let data = Array2::zeros((10, 10));
        let noise = GaussianNoise::new(0.0, 0.1, 1.0);
        let result = noise.transform_2d(&data)?;

        // Should have added noise (not all zeros)
        let sum = result.sum();
        assert!(sum.abs() > 1e-10);
        assert_eq!(result.dim(), data.dim());

        Ok(())
    }

    #[test]
    fn test_brightness() -> Result<()> {
        let data = Array2::from_elem((5, 5), 0.5);
        let brightness = Brightness::new((0.1, 0.1), 1.0); // Fixed delta
        let result = brightness.transform_2d(&data)?;

        // All values should be increased by ~0.1
        assert!((result[[0, 0]] - 0.6).abs() < 0.01);

        Ok(())
    }

    #[test]
    fn test_default_transform_3d() -> Result<()> {
        let data = Array3::from_elem((2, 3, 2), 0.5);
        let result = Brightness::new((0.1, 0.1), 1.0).transform_3d(&data)?;

        assert_eq!(result.dim(), data.dim());
        assert!(result.iter().all(|&v| (v - 0.6).abs() < 1e-12));

        Ok(())
    }

    #[test]
    fn test_augmentation_pipeline() -> Result<()> {
        let data = sequential(3, 3)?;

        let pipeline = AugmentationPipeline::new()
            .add_transform(Arc::new(HorizontalFlip::new(1.0)))
            .add_transform(Arc::new(Brightness::new((0.1, 0.1), 1.0)))
            .with_probability(1.0);

        let result = pipeline.apply_2d(&data)?;

        // Should be flipped and brightened: row 0 [1, 2, 3] -> [3, 2, 1] -> +0.1.
        assert_eq!(result.dim(), data.dim());
        assert!((result[[0, 0]] - 3.1).abs() < 1e-12);
        assert!((result[[0, 2]] - 1.1).abs() < 1e-12);
        assert!((result[[2, 0]] - 9.1).abs() < 1e-12);

        Ok(())
    }

    #[test]
    fn test_pipeline_zero_probability() -> Result<()> {
        let data = sequential(2, 2)?;
        let pipeline = AugmentationPipeline::new()
            .add_transform(Arc::new(Brightness::new((1.0, 1.0), 1.0)))
            .with_probability(0.0);

        assert_eq!(pipeline.apply_2d(&data)?, data);

        Ok(())
    }

    #[test]
    fn test_standard_pipelines() {
        let img_pipeline = standard_image_augmentation(0.8);
        assert!(!img_pipeline.uses_gpu());

        let tab_pipeline = standard_tabular_augmentation(0.8);
        assert!(!tab_pipeline.uses_gpu());
    }
}