1use scirs2_core::ndarray::ArrayD;
2
3use super::error::AugmentationError;
4use super::functional::{
5 center_crop_2d, clip, dropout, gaussian_noise, normalize, random_crop_2d, random_hflip,
6 random_vflip,
7};
8use super::rng::AugRng;
9
/// One stage of an [`AugmentationPipeline`].
///
/// `AugmentationPipeline::apply` dispatches each variant to the matching
/// function in the sibling `functional` module, passing the fields through
/// unchanged.
#[derive(Debug, Clone, PartialEq)]
pub enum AugmentationStep {
    /// Random noise via `functional::gaussian_noise`; `std` is the noise scale
    /// (presumably the standard deviation of zero-mean noise — confirm).
    GaussianNoise { std: f64 },
    /// Random zeroing via `functional::dropout` with probability `p`.
    ///
    /// This is the only step that honours the pipeline's `training` flag. When
    /// the flag is `false`, the input is returned unchanged.
    Dropout { p: f64 },
    /// Horizontal flip via `functional::random_hflip`, taken with probability
    /// `p` (`0.0` never flips, `1.0` always flips).
    RandomHFlip { p: f64 },
    /// Vertical counterpart of `RandomHFlip`, via `functional::random_vflip`.
    RandomVFlip { p: f64 },
    /// Random `crop_h` x `crop_w` window over the last two axes, via
    /// `functional::random_crop_2d`.
    ///
    /// Fails when the window is larger than the input.
    RandomCrop { crop_h: usize, crop_w: usize },
    /// Central `crop_h` x `crop_w` window over the last two axes, via
    /// `functional::center_crop_2d`.
    CenterCrop { crop_h: usize, crop_w: usize },
    /// `functional::normalize` with the given `mean`/`std` vectors (presumably
    /// one entry per channel — confirm against `functional::normalize`).
    Normalize { mean: Vec<f64>, std: Vec<f64> },
    /// Clamps every element into `[min_val, max_val]` via `functional::clip`.
    Clip { min_val: f64, max_val: f64 },
}
22
23#[derive(Debug, Clone)]
29pub struct AugmentationPipeline {
30 pub steps: Vec<AugmentationStep>,
32 pub rng_seed: u64,
34}
35
36impl AugmentationPipeline {
37 pub fn new(seed: u64) -> Self {
39 Self {
40 steps: Vec::new(),
41 rng_seed: seed,
42 }
43 }
44
45 pub fn add_step(mut self, step: AugmentationStep) -> Self {
47 self.steps.push(step);
48 self
49 }
50
51 pub fn apply(
56 &self,
57 input: &ArrayD<f64>,
58 training: bool,
59 ) -> Result<ArrayD<f64>, AugmentationError> {
60 let mut current = input.clone();
61 for (i, step) in self.steps.iter().enumerate() {
62 let step_seed = self
64 .rng_seed
65 .wrapping_add((i as u64).wrapping_mul(0x9e37_79b9_7f4a_7c15));
66 let mut rng = AugRng::new(step_seed);
67
68 current = match step {
69 AugmentationStep::GaussianNoise { std } => {
70 gaussian_noise(¤t, *std, &mut rng)?
71 }
72 AugmentationStep::Dropout { p } => dropout(¤t, *p, training, &mut rng)?,
73 AugmentationStep::RandomHFlip { p } => random_hflip(¤t, *p, &mut rng)?,
74 AugmentationStep::RandomVFlip { p } => random_vflip(¤t, *p, &mut rng)?,
75 AugmentationStep::RandomCrop { crop_h, crop_w } => {
76 random_crop_2d(¤t, *crop_h, *crop_w, &mut rng)?
77 }
78 AugmentationStep::CenterCrop { crop_h, crop_w } => {
79 center_crop_2d(¤t, *crop_h, *crop_w)?
80 }
81 AugmentationStep::Normalize { mean, std } => normalize(¤t, mean, std)?,
82 AugmentationStep::Clip { min_val, max_val } => clip(¤t, *min_val, *max_val),
83 };
84 }
85 Ok(current)
86 }
87
88 pub fn num_steps(&self) -> usize {
90 self.steps.len()
91 }
92}
93
94#[derive(Debug, Clone)]
96pub struct AugStats {
97 pub original_mean: f64,
99 pub original_std: f64,
101 pub augmented_mean: f64,
103 pub augmented_std: f64,
105 pub element_change_ratio: f64,
107}
108
109impl AugStats {
110 pub fn compute(original: &ArrayD<f64>, augmented: &ArrayD<f64>) -> Self {
112 let orig_flat: Vec<f64> = original.iter().copied().collect();
113 let aug_flat: Vec<f64> = augmented.iter().copied().collect();
114 let n = orig_flat.len().max(1);
115
116 let orig_mean = orig_flat.iter().sum::<f64>() / n as f64;
117 let aug_mean = aug_flat.iter().sum::<f64>() / aug_flat.len().max(1) as f64;
118
119 let orig_var = orig_flat
120 .iter()
121 .map(|&x| (x - orig_mean).powi(2))
122 .sum::<f64>()
123 / n as f64;
124 let aug_var = aug_flat
125 .iter()
126 .map(|&x| (x - aug_mean).powi(2))
127 .sum::<f64>()
128 / aug_flat.len().max(1) as f64;
129
130 let compare_n = orig_flat.len().min(aug_flat.len()).max(1);
131 let changed = orig_flat
132 .iter()
133 .zip(aug_flat.iter())
134 .filter(|(&a, &b)| (a - b).abs() > 1e-12)
135 .count();
136
137 AugStats {
138 original_mean: orig_mean,
139 original_std: orig_var.sqrt(),
140 augmented_mean: aug_mean,
141 augmented_std: aug_var.sqrt(),
142 element_change_ratio: changed as f64 / compare_n as f64,
143 }
144 }
145
146 pub fn summary(&self) -> String {
148 format!(
149 "orig μ={:.4} σ={:.4} | aug μ={:.4} σ={:.4} | changed {:.1}%",
150 self.original_mean,
151 self.original_std,
152 self.augmented_mean,
153 self.augmented_std,
154 self.element_change_ratio * 100.0
155 )
156 }
157}
158
#[cfg(test)]
mod aug_tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    /// Fixed-seed RNG so every test is reproducible.
    fn make_rng() -> AugRng {
        AugRng::new(0xDEAD_BEEF)
    }

    /// Array of the given shape filled with `1.0`.
    fn ones(shape: &[usize]) -> ArrayD<f64> {
        use scirs2_core::ndarray::IxDyn;
        let n: usize = shape.iter().product();
        ArrayD::from_shape_vec(IxDyn(shape), vec![1.0f64; n]).expect("shape ok")
    }

    /// Array of the given shape holding `0.0, 1.0, 2.0, …` in iteration order.
    fn arange(shape: &[usize]) -> ArrayD<f64> {
        use scirs2_core::ndarray::IxDyn;
        let n: usize = shape.iter().product();
        let data: Vec<f64> = (0..n).map(|i| i as f64).collect();
        ArrayD::from_shape_vec(IxDyn(shape), data).expect("shape ok")
    }

    // ---- functional ops -------------------------------------------------

    #[test]
    fn test_gaussian_noise_shape_preserved() {
        let input = ones(&[3, 4]);
        let mut rng = make_rng();
        let out = gaussian_noise(&input, 0.1, &mut rng).expect("ok");
        assert_eq!(out.shape(), input.shape());
    }

    #[test]
    fn test_gaussian_noise_mean_near_original() {
        let input = ones(&[10, 100]);
        let mut rng = make_rng();
        let out = gaussian_noise(&input, 0.01, &mut rng).expect("ok");
        // Divide by the actual element count rather than a hard-coded literal.
        let mean = out.iter().sum::<f64>() / out.len() as f64;
        assert!((mean - 1.0).abs() < 0.05, "mean {mean} too far from 1.0");
    }

    #[test]
    fn test_dropout_training_zeroes_some() {
        let input = ones(&[100]);
        let mut rng = make_rng();
        let out = dropout(&input, 0.5, true, &mut rng).expect("ok");
        let zero_count = out.iter().filter(|&&x| x == 0.0).count();
        assert!(zero_count > 0, "expected some zeros");
        assert!(zero_count < 100, "not all should be zero");
    }

    #[test]
    fn test_dropout_inference_unchanged() {
        let input = arange(&[5, 5]);
        let mut rng = make_rng();
        let out = dropout(&input, 0.9, false, &mut rng).expect("ok");
        assert_eq!(out, input);
    }

    #[test]
    fn test_dropout_mask_shape() {
        use super::super::functional::dropout_mask;
        let mut rng = make_rng();
        let mask = dropout_mask(&[4, 4], 0.3, &mut rng).expect("ok");
        assert_eq!(mask.shape(), &[4, 4]);
        for &v in mask.iter() {
            assert!(v == 0.0 || v == 1.0);
        }
    }

    #[test]
    fn test_mixup_shape() {
        use super::super::functional::mixup;
        let x1 = ones(&[3, 4]);
        let x2 = arange(&[3, 4]);
        let mut rng = make_rng();
        let (mixed, _lam) = mixup(&x1, &x2, 1.0, &mut rng).expect("ok");
        assert_eq!(mixed.shape(), x1.shape());
    }

    #[test]
    fn test_mixup_lambda_range() {
        use super::super::functional::mixup;
        let x1 = ones(&[2, 2]);
        let x2 = ones(&[2, 2]);
        let mut rng = make_rng();
        for _ in 0..50 {
            let (_mixed, lam) = mixup(&x1, &x2, 1.0, &mut rng).expect("ok");
            assert!((0.0..=1.0).contains(&lam), "lambda={lam} out of range");
        }
    }

    #[test]
    fn test_cutmix_shape() {
        use super::super::functional::cutmix;
        let x1 = ones(&[1, 3, 8, 8]);
        let x2 = arange(&[1, 3, 8, 8]);
        let mut rng = make_rng();
        let (mixed, _lam) = cutmix(&x1, &x2, 1.0, &mut rng).expect("ok");
        assert_eq!(mixed.shape(), x1.shape());
    }

    #[test]
    fn test_cutmix_lambda_range() {
        use super::super::functional::cutmix;
        let x1 = ones(&[1, 4, 8, 8]);
        let x2 = arange(&[1, 4, 8, 8]);
        let mut rng = make_rng();
        for _ in 0..20 {
            let (_mixed, lam) = cutmix(&x1, &x2, 1.0, &mut rng).expect("ok");
            assert!((0.0..=1.0).contains(&lam), "lambda={lam} out of range");
        }
    }

    #[test]
    fn test_random_crop_2d_shape() {
        let input = arange(&[3, 16, 16]);
        let mut rng = make_rng();
        let out = random_crop_2d(&input, 12, 12, &mut rng).expect("ok");
        assert_eq!(out.shape(), &[3, 12, 12]);
    }

    #[test]
    fn test_random_crop_invalid_size() {
        let input = ones(&[8, 8]);
        let mut rng = make_rng();
        let result = random_crop_2d(&input, 16, 8, &mut rng);
        assert!(result.is_err(), "crop larger than input should fail");
    }

    #[test]
    fn test_center_crop_2d_shape() {
        let input = arange(&[1, 3, 32, 32]);
        let out = center_crop_2d(&input, 24, 24).expect("ok");
        assert_eq!(out.shape(), &[1, 3, 24, 24]);
    }

    #[test]
    fn test_random_hflip_probability_zero() {
        let input = arange(&[2, 4, 4]);
        let mut rng = make_rng();
        let out = random_hflip(&input, 0.0, &mut rng).expect("ok");
        assert_eq!(out, input, "p=0 must leave input unchanged");
    }

    #[test]
    fn test_random_hflip_probability_one() {
        let input = arange(&[1, 4, 4]);
        let mut rng = make_rng();
        let flipped = random_hflip(&input, 1.0, &mut rng).expect("ok");
        assert_ne!(flipped, input, "p=1 must flip");
        let mut rng2 = make_rng();
        let double_flipped = random_hflip(&flipped, 1.0, &mut rng2).expect("ok");
        assert_eq!(double_flipped, input, "double flip = identity");
    }

    #[test]
    fn test_normalize_and_denormalize_roundtrip() {
        use super::super::functional::denormalize;
        let input = arange(&[2, 3, 4, 4]);
        let mean = vec![0.485, 0.456, 0.406];
        let std = vec![0.229, 0.224, 0.225];

        let normed = normalize(&input, &mean, &std).expect("normalize ok");
        let restored = denormalize(&normed, &mean, &std).expect("denormalize ok");

        for (a, b) in input.iter().zip(restored.iter()) {
            assert!((a - b).abs() < 1e-9, "roundtrip mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn test_clip_bounds() {
        let input = arange(&[10]);
        let clipped = clip(&input, 2.0, 7.0);
        for &v in clipped.iter() {
            assert!((2.0..=7.0).contains(&v), "value {v} out of clipped range");
        }
    }

    // ---- pipeline -------------------------------------------------------

    #[test]
    fn test_pipeline_apply_empty() {
        let pipeline = AugmentationPipeline::new(42);
        let input = arange(&[4, 4]);
        let out = pipeline.apply(&input, true).expect("ok");
        assert_eq!(out, input, "empty pipeline is identity");
    }

    #[test]
    fn test_pipeline_apply_noise_step() {
        let pipeline = AugmentationPipeline::new(99)
            .add_step(AugmentationStep::GaussianNoise { std: 0.01 })
            .add_step(AugmentationStep::Clip {
                min_val: -10.0,
                max_val: 100.0,
            });
        let input = ones(&[20, 20]);
        let out = pipeline.apply(&input, true).expect("ok");
        assert_eq!(out.shape(), input.shape());
    }

    #[test]
    fn test_pipeline_num_steps() {
        assert_eq!(AugmentationPipeline::new(0).num_steps(), 0);
        let pipeline = AugmentationPipeline::new(0)
            .add_step(AugmentationStep::RandomHFlip { p: 0.5 })
            .add_step(AugmentationStep::Clip {
                min_val: 0.0,
                max_val: 1.0,
            });
        assert_eq!(pipeline.num_steps(), 2);
    }

    #[test]
    fn test_pipeline_is_deterministic() {
        let pipeline = AugmentationPipeline::new(7)
            .add_step(AugmentationStep::GaussianNoise { std: 0.1 })
            .add_step(AugmentationStep::RandomHFlip { p: 0.5 });
        let input = arange(&[2, 4, 4]);
        let first = pipeline.apply(&input, true).expect("ok");
        let second = pipeline.apply(&input, true).expect("ok");
        assert_eq!(first, second, "same seed and input must give same output");
    }

    #[test]
    fn test_pipeline_dropout_inference_is_identity() {
        let pipeline = AugmentationPipeline::new(3).add_step(AugmentationStep::Dropout { p: 0.9 });
        let input = arange(&[5, 5]);
        let out = pipeline.apply(&input, false).expect("ok");
        assert_eq!(out, input, "dropout must be a no-op when not training");
    }

    #[test]
    fn test_pipeline_propagates_step_error() {
        let pipeline = AugmentationPipeline::new(1).add_step(AugmentationStep::RandomCrop {
            crop_h: 16,
            crop_w: 8,
        });
        let result = pipeline.apply(&ones(&[8, 8]), true);
        assert!(result.is_err(), "an invalid crop must surface as an error");
    }

    // ---- stats ----------------------------------------------------------

    #[test]
    fn test_aug_stats_compute() {
        let orig = ones(&[10]);
        let aug = arange(&[10]);
        let stats = AugStats::compute(&orig, &aug);
        assert!((stats.original_mean - 1.0).abs() < 1e-9);
        assert!(stats.original_std.abs() < 1e-9);
        assert!((stats.augmented_mean - 4.5).abs() < 1e-9);
        assert!((stats.augmented_std - 8.25f64.sqrt()).abs() < 1e-9);
        // Only index 1 holds the same value (1.0) in both arrays.
        assert!((stats.element_change_ratio - 0.9).abs() < 1e-9);
    }

    #[test]
    fn test_aug_stats_identical_inputs() {
        let x = arange(&[3, 3]);
        let stats = AugStats::compute(&x, &x);
        assert_eq!(stats.element_change_ratio, 0.0);
        assert!((stats.original_mean - stats.augmented_mean).abs() < 1e-12);
    }

    #[test]
    fn test_aug_stats_empty_inputs() {
        let empty = arange(&[0]);
        let stats = AugStats::compute(&empty, &empty);
        assert_eq!(stats.original_mean, 0.0);
        assert_eq!(stats.augmented_std, 0.0);
        assert_eq!(stats.element_change_ratio, 0.0);
    }

    #[test]
    fn test_aug_stats_summary_nonempty() {
        let orig = ones(&[5]);
        let aug = arange(&[5]);
        let stats = AugStats::compute(&orig, &aug);
        let summary = stats.summary();
        assert!(!summary.is_empty());
        assert!(summary.contains("μ"));
        assert!(summary.contains("changed 80.0%"), "unexpected summary: {summary}");
    }
}
405}