// tensorlogic_train/augmentation/pipeline.rs

1use scirs2_core::ndarray::ArrayD;
2
3use super::error::AugmentationError;
4use super::functional::{
5    center_crop_2d, clip, dropout, gaussian_noise, normalize, random_crop_2d, random_hflip,
6    random_vflip,
7};
8use super::rng::AugRng;
9
/// A single step in an augmentation pipeline.
///
/// Each variant only stores parameters; `AugmentationPipeline::apply`
/// forwards them to the function of the same name in the `functional`
/// module. Steps compare by value (`PartialEq`), which makes pipeline
/// configurations easy to assert on in tests.
#[derive(Debug, Clone, PartialEq)]
pub enum AugmentationStep {
    /// Runs `gaussian_noise` with `std` as its noise parameter
    /// (presumably the standard deviation of the added noise).
    GaussianNoise { std: f64 },
    /// Runs `dropout` with `p` (presumably the drop probability) and the
    /// `training` flag given to `apply`.
    Dropout { p: f64 },
    /// Runs `random_hflip` with `p` (presumably the flip probability).
    RandomHFlip { p: f64 },
    /// Runs `random_vflip` with `p` (presumably the flip probability).
    RandomVFlip { p: f64 },
    /// Runs `random_crop_2d` with the requested crop height and width.
    RandomCrop { crop_h: usize, crop_w: usize },
    /// Runs `center_crop_2d` with the requested crop height and width.
    /// This step does not receive an RNG.
    CenterCrop { crop_h: usize, crop_w: usize },
    /// Runs `normalize` with the given `mean` and `std` vectors
    /// (presumably one entry per channel — confirm against `normalize`).
    Normalize { mean: Vec<f64>, std: Vec<f64> },
    /// Runs `clip` with the given lower and upper bounds.
    Clip { min_val: f64, max_val: f64 },
}
22
23/// A composable, ordered sequence of augmentation steps.
24///
25/// Steps are applied left-to-right. Each step receives its own `AugRng`
26/// derived from the pipeline seed advanced by the step index so results
27/// are deterministic given the same seed.
28#[derive(Debug, Clone)]
29pub struct AugmentationPipeline {
30    /// The ordered list of augmentation steps.
31    pub steps: Vec<AugmentationStep>,
32    /// Seed used to derive per-step RNG states.
33    pub rng_seed: u64,
34}
35
36impl AugmentationPipeline {
37    /// Create an empty pipeline with the given RNG seed.
38    pub fn new(seed: u64) -> Self {
39        Self {
40            steps: Vec::new(),
41            rng_seed: seed,
42        }
43    }
44
45    /// Append a step and return `self` (builder pattern).
46    pub fn add_step(mut self, step: AugmentationStep) -> Self {
47        self.steps.push(step);
48        self
49    }
50
51    /// Apply all steps to `input` sequentially.
52    ///
53    /// A fresh `AugRng` is derived for each step from `rng_seed ^ (step_index * prime)`,
54    /// guaranteeing reproducibility while avoiding correlation between steps.
55    pub fn apply(
56        &self,
57        input: &ArrayD<f64>,
58        training: bool,
59    ) -> Result<ArrayD<f64>, AugmentationError> {
60        let mut current = input.clone();
61        for (i, step) in self.steps.iter().enumerate() {
62            // Derive a per-step seed.
63            let step_seed = self
64                .rng_seed
65                .wrapping_add((i as u64).wrapping_mul(0x9e37_79b9_7f4a_7c15));
66            let mut rng = AugRng::new(step_seed);
67
68            current = match step {
69                AugmentationStep::GaussianNoise { std } => {
70                    gaussian_noise(&current, *std, &mut rng)?
71                }
72                AugmentationStep::Dropout { p } => dropout(&current, *p, training, &mut rng)?,
73                AugmentationStep::RandomHFlip { p } => random_hflip(&current, *p, &mut rng)?,
74                AugmentationStep::RandomVFlip { p } => random_vflip(&current, *p, &mut rng)?,
75                AugmentationStep::RandomCrop { crop_h, crop_w } => {
76                    random_crop_2d(&current, *crop_h, *crop_w, &mut rng)?
77                }
78                AugmentationStep::CenterCrop { crop_h, crop_w } => {
79                    center_crop_2d(&current, *crop_h, *crop_w)?
80                }
81                AugmentationStep::Normalize { mean, std } => normalize(&current, mean, std)?,
82                AugmentationStep::Clip { min_val, max_val } => clip(&current, *min_val, *max_val),
83            };
84        }
85        Ok(current)
86    }
87
88    /// Number of steps in the pipeline.
89    pub fn num_steps(&self) -> usize {
90        self.steps.len()
91    }
92}
93
/// Summary statistics comparing an original array with its augmented
/// counterpart, as produced by `AugStats::compute`.
///
/// Values compare field-by-field (`PartialEq`), so two results can be
/// checked for equality directly.
#[derive(Debug, Clone, PartialEq)]
pub struct AugStats {
    /// Mean of the original array (zero when the array is empty).
    pub original_mean: f64,
    /// Population standard deviation of the original array, i.e. the
    /// variance is divided by the element count, not by `count - 1`.
    pub original_std: f64,
    /// Mean of the augmented array (zero when the array is empty).
    pub augmented_mean: f64,
    /// Population standard deviation of the augmented array.
    pub augmented_std: f64,
    /// Fraction of compared elements whose value changed
    /// (|orig − aug| > 1e-12). Elements are paired in iteration order and
    /// only as many pairs as the shorter array has are compared.
    pub element_change_ratio: f64,
}
108
109impl AugStats {
110    /// Compute statistics comparing `original` and `augmented`.
111    pub fn compute(original: &ArrayD<f64>, augmented: &ArrayD<f64>) -> Self {
112        let orig_flat: Vec<f64> = original.iter().copied().collect();
113        let aug_flat: Vec<f64> = augmented.iter().copied().collect();
114        let n = orig_flat.len().max(1);
115
116        let orig_mean = orig_flat.iter().sum::<f64>() / n as f64;
117        let aug_mean = aug_flat.iter().sum::<f64>() / aug_flat.len().max(1) as f64;
118
119        let orig_var = orig_flat
120            .iter()
121            .map(|&x| (x - orig_mean).powi(2))
122            .sum::<f64>()
123            / n as f64;
124        let aug_var = aug_flat
125            .iter()
126            .map(|&x| (x - aug_mean).powi(2))
127            .sum::<f64>()
128            / aug_flat.len().max(1) as f64;
129
130        let compare_n = orig_flat.len().min(aug_flat.len()).max(1);
131        let changed = orig_flat
132            .iter()
133            .zip(aug_flat.iter())
134            .filter(|(&a, &b)| (a - b).abs() > 1e-12)
135            .count();
136
137        AugStats {
138            original_mean: orig_mean,
139            original_std: orig_var.sqrt(),
140            augmented_mean: aug_mean,
141            augmented_std: aug_var.sqrt(),
142            element_change_ratio: changed as f64 / compare_n as f64,
143        }
144    }
145
146    /// Human-readable one-line summary.
147    pub fn summary(&self) -> String {
148        format!(
149            "orig μ={:.4} σ={:.4} | aug μ={:.4} σ={:.4} | changed {:.1}%",
150            self.original_mean,
151            self.original_std,
152            self.augmented_mean,
153            self.augmented_std,
154            self.element_change_ratio * 100.0
155        )
156    }
157}
158
/// Tests for the functional augmentations, the pipeline, and `AugStats`.
#[cfg(test)]
mod aug_tests {
    use super::super::functional::{cutmix, denormalize, dropout_mask, mixup};
    use super::*;
    use scirs2_core::ndarray::{ArrayD, IxDyn};

    /// RNG with a fixed seed so every test is reproducible.
    fn make_rng() -> AugRng {
        AugRng::new(0xDEAD_BEEF)
    }

    /// Array of the given shape filled with `1.0`.
    fn ones(shape: &[usize]) -> ArrayD<f64> {
        let n: usize = shape.iter().product();
        ArrayD::from_shape_vec(IxDyn(shape), vec![1.0f64; n]).expect("shape ok")
    }

    /// Array of the given shape filled with `0.0, 1.0, 2.0, …` in order.
    fn arange(shape: &[usize]) -> ArrayD<f64> {
        let n: usize = shape.iter().product();
        let data: Vec<f64> = (0..n).map(|i| i as f64).collect();
        ArrayD::from_shape_vec(IxDyn(shape), data).expect("shape ok")
    }

    // ---- gaussian_noise ----

    #[test]
    fn test_gaussian_noise_shape_preserved() {
        let input = ones(&[3, 4]);
        let mut rng = make_rng();
        let out = gaussian_noise(&input, 0.1, &mut rng).expect("ok");
        assert_eq!(out.shape(), input.shape());
    }

    #[test]
    fn test_gaussian_noise_mean_near_original() {
        // With std=0.01 and 1000 elements the mean should stay close to 1.0.
        let input = ones(&[10, 100]);
        let mut rng = make_rng();
        let out = gaussian_noise(&input, 0.01, &mut rng).expect("ok");
        let sum: f64 = out.iter().sum();
        // Divide by the actual element count instead of a hard-coded 1000.
        let mean = sum / out.len() as f64;
        assert!((mean - 1.0).abs() < 0.05, "mean {mean} too far from 1.0");
    }

    // ---- dropout ----

    #[test]
    fn test_dropout_training_zeroes_some() {
        let input = ones(&[100]);
        let mut rng = make_rng();
        let out = dropout(&input, 0.5, true, &mut rng).expect("ok");
        let zero_count = out.iter().filter(|&&x| x == 0.0).count();
        // With p=0.5 and 100 elements, expect some zeros.
        assert!(zero_count > 0, "expected some zeros");
        assert!(zero_count < 100, "not all should be zero");
    }

    #[test]
    fn test_dropout_inference_unchanged() {
        let input = arange(&[5, 5]);
        let mut rng = make_rng();
        let out = dropout(&input, 0.9, false, &mut rng).expect("ok");
        assert_eq!(out, input);
    }

    // ---- dropout_mask ----

    #[test]
    fn test_dropout_mask_shape() {
        let mut rng = make_rng();
        let mask = dropout_mask(&[4, 4], 0.3, &mut rng).expect("ok");
        assert_eq!(mask.shape(), &[4, 4]);
        for &v in mask.iter() {
            assert!(v == 0.0 || v == 1.0);
        }
    }

    // ---- mixup ----

    #[test]
    fn test_mixup_shape() {
        let x1 = ones(&[3, 4]);
        let x2 = arange(&[3, 4]);
        let mut rng = make_rng();
        let (mixed, _lam) = mixup(&x1, &x2, 1.0, &mut rng).expect("ok");
        assert_eq!(mixed.shape(), x1.shape());
    }

    #[test]
    fn test_mixup_lambda_range() {
        let x1 = ones(&[2, 2]);
        let x2 = ones(&[2, 2]);
        let mut rng = make_rng();
        for _ in 0..50 {
            let (_mixed, lam) = mixup(&x1, &x2, 1.0, &mut rng).expect("ok");
            assert!((0.0..=1.0).contains(&lam), "lambda={lam} out of range");
        }
    }

    // ---- cutmix ----

    #[test]
    fn test_cutmix_shape() {
        let x1 = ones(&[1, 3, 8, 8]);
        let x2 = arange(&[1, 3, 8, 8]);
        let mut rng = make_rng();
        let (mixed, _lam) = cutmix(&x1, &x2, 1.0, &mut rng).expect("ok");
        assert_eq!(mixed.shape(), x1.shape());
    }

    #[test]
    fn test_cutmix_lambda_range() {
        let x1 = ones(&[1, 4, 8, 8]);
        let x2 = arange(&[1, 4, 8, 8]);
        let mut rng = make_rng();
        for _ in 0..20 {
            let (_mixed, lam) = cutmix(&x1, &x2, 1.0, &mut rng).expect("ok");
            assert!((0.0..=1.0).contains(&lam), "lambda={lam} out of range");
        }
    }

    // ---- random_crop_2d ----

    #[test]
    fn test_random_crop_2d_shape() {
        let input = arange(&[3, 16, 16]);
        let mut rng = make_rng();
        let out = random_crop_2d(&input, 12, 12, &mut rng).expect("ok");
        assert_eq!(out.shape(), &[3, 12, 12]);
    }

    #[test]
    fn test_random_crop_invalid_size() {
        let input = ones(&[8, 8]);
        let mut rng = make_rng();
        let result = random_crop_2d(&input, 16, 8, &mut rng);
        assert!(result.is_err(), "crop larger than input should fail");
    }

    // ---- center_crop_2d ----

    #[test]
    fn test_center_crop_2d_shape() {
        let input = arange(&[1, 3, 32, 32]);
        let out = center_crop_2d(&input, 24, 24).expect("ok");
        assert_eq!(out.shape(), &[1, 3, 24, 24]);
    }

    // ---- random_hflip ----

    #[test]
    fn test_random_hflip_probability_zero() {
        let input = arange(&[2, 4, 4]);
        let mut rng = make_rng();
        let out = random_hflip(&input, 0.0, &mut rng).expect("ok");
        assert_eq!(out, input, "p=0 must leave input unchanged");
    }

    #[test]
    fn test_random_hflip_probability_one() {
        // Flip of flip should give back original.
        let input = arange(&[1, 4, 4]);
        let mut rng = make_rng();
        let flipped = random_hflip(&input, 1.0, &mut rng).expect("ok");
        assert_ne!(flipped, input, "p=1 must flip");
        let mut rng2 = make_rng();
        let double_flipped = random_hflip(&flipped, 1.0, &mut rng2).expect("ok");
        assert_eq!(double_flipped, input, "double flip = identity");
    }

    // ---- normalize / denormalize ----

    #[test]
    fn test_normalize_and_denormalize_roundtrip() {
        let input = arange(&[2, 3, 4, 4]);
        let mean = vec![0.485, 0.456, 0.406];
        let std = vec![0.229, 0.224, 0.225];

        let normed = normalize(&input, &mean, &std).expect("normalize ok");
        let restored = denormalize(&normed, &mean, &std).expect("denormalize ok");

        for (a, b) in input.iter().zip(restored.iter()) {
            assert!((a - b).abs() < 1e-9, "roundtrip mismatch: {a} vs {b}");
        }
    }

    // ---- clip ----

    #[test]
    fn test_clip_bounds() {
        let input = arange(&[10]);
        let clipped = clip(&input, 2.0, 7.0);
        for &v in clipped.iter() {
            assert!((2.0..=7.0).contains(&v), "value {v} out of clipped range");
        }
    }

    // ---- pipeline ----

    #[test]
    fn test_pipeline_apply_empty() {
        let pipeline = AugmentationPipeline::new(42);
        let input = arange(&[4, 4]);
        let out = pipeline.apply(&input, true).expect("ok");
        assert_eq!(out, input, "empty pipeline is identity");
    }

    #[test]
    fn test_pipeline_apply_noise_step() {
        let pipeline = AugmentationPipeline::new(99)
            .add_step(AugmentationStep::GaussianNoise { std: 0.01 })
            .add_step(AugmentationStep::Clip {
                min_val: -10.0,
                max_val: 100.0,
            });
        let input = ones(&[20, 20]);
        let out = pipeline.apply(&input, true).expect("ok");
        assert_eq!(out.shape(), input.shape());
        // Clip runs last, so every output value must respect its bounds.
        for &v in out.iter() {
            assert!((-10.0..=100.0).contains(&v), "value {v} escaped the clip");
        }
    }

    #[test]
    fn test_pipeline_apply_is_deterministic() {
        // Per-step seeds depend only on the pipeline seed and step index,
        // so repeated runs (and clones) must agree exactly.
        let pipeline = AugmentationPipeline::new(7)
            .add_step(AugmentationStep::GaussianNoise { std: 0.01 })
            .add_step(AugmentationStep::Clip {
                min_val: -10.0,
                max_val: 100.0,
            });
        let input = ones(&[20, 20]);
        let first = pipeline.apply(&input, true).expect("ok");
        let second = pipeline.apply(&input, true).expect("ok");
        let from_clone = pipeline.clone().apply(&input, true).expect("ok");
        assert_eq!(first, second, "same pipeline must give same output");
        assert_eq!(first, from_clone, "cloned pipeline must give same output");
    }

    #[test]
    fn test_pipeline_propagates_step_error() {
        // A crop larger than the input fails in `random_crop_2d`; the
        // pipeline must surface that error instead of swallowing it.
        let pipeline = AugmentationPipeline::new(1).add_step(AugmentationStep::RandomCrop {
            crop_h: 16,
            crop_w: 8,
        });
        let input = ones(&[8, 8]);
        assert!(pipeline.apply(&input, true).is_err());
    }

    #[test]
    fn test_pipeline_num_steps() {
        let empty = AugmentationPipeline::new(0);
        assert_eq!(empty.num_steps(), 0);
        let two = empty
            .add_step(AugmentationStep::Dropout { p: 0.1 })
            .add_step(AugmentationStep::Clip {
                min_val: 0.0,
                max_val: 1.0,
            });
        assert_eq!(two.num_steps(), 2);
    }

    // ---- AugStats ----

    #[test]
    fn test_aug_stats_compute() {
        let orig = ones(&[10]);
        let aug = arange(&[10]);
        let stats = AugStats::compute(&orig, &aug);
        assert!((stats.original_mean - 1.0).abs() < 1e-9);
        assert!(stats.original_std.abs() < 1e-12);
        // 0..=9 has mean 4.5 and population variance 8.25.
        assert!((stats.augmented_mean - 4.5).abs() < 1e-9);
        assert!((stats.augmented_std - 8.25f64.sqrt()).abs() < 1e-9);
        // Only index 1 holds the same value in both arrays: 9 of 10 changed.
        assert!((stats.element_change_ratio - 0.9).abs() < 1e-12);
    }

    #[test]
    fn test_aug_stats_empty_arrays() {
        // Divisors are clamped to 1, so empty inputs yield zeros, not NaN.
        let empty = ones(&[0]);
        let stats = AugStats::compute(&empty, &empty);
        assert_eq!(stats.original_mean, 0.0);
        assert_eq!(stats.original_std, 0.0);
        assert_eq!(stats.augmented_mean, 0.0);
        assert_eq!(stats.augmented_std, 0.0);
        assert_eq!(stats.element_change_ratio, 0.0);
    }

    #[test]
    fn test_aug_stats_summary_nonempty() {
        let orig = ones(&[5]);
        let aug = arange(&[5]);
        let stats = AugStats::compute(&orig, &aug);
        let summary = stats.summary();
        assert!(!summary.is_empty());
        assert!(summary.contains("μ"));
    }
}