1use crate::{TrainError, TrainResult};
10use scirs2_core::ndarray::{Array, Array2, ArrayView2};
11use scirs2_core::random::{Rng, StdRng};
12
13pub trait DataAugmenter {
15 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>>;
24}
25
/// Identity augmenter: its `augment` returns an unmodified copy of the input.
///
/// Handy as an explicit "augmentation off" setting wherever a
/// [`DataAugmenter`] is required. Derives `PartialEq`/`Eq` so configurations
/// containing it can be compared.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NoAugmentation;
31
32impl DataAugmenter for NoAugmentation {
33 fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
34 Ok(data.to_owned())
35 }
36}
37
/// Adds zero-mean Gaussian noise to every element of the input.
///
/// Prefer [`NoiseAugmenter::new`], which validates `std_dev`; the `Default`
/// impl uses `std_dev = 0.01`.
#[derive(Debug, Clone, PartialEq)]
pub struct NoiseAugmenter {
    /// Standard deviation of the added noise, in the same units as the data.
    pub std_dev: f64,
}
46
47impl NoiseAugmenter {
48 pub fn new(std_dev: f64) -> TrainResult<Self> {
53 if std_dev < 0.0 {
54 return Err(TrainError::InvalidParameter(
55 "std_dev must be non-negative".to_string(),
56 ));
57 }
58 Ok(Self { std_dev })
59 }
60}
61
62impl Default for NoiseAugmenter {
63 fn default() -> Self {
64 Self { std_dev: 0.01 }
65 }
66}
67
68impl DataAugmenter for NoiseAugmenter {
69 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
70 let mut augmented = data.to_owned();
71
72 for value in augmented.iter_mut() {
74 let u1: f64 = rng.random();
75 let u2: f64 = rng.random();
76
77 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
79 let noise = z0 * self.std_dev;
80
81 *value += noise;
82 }
83
84 Ok(augmented)
85 }
86}
87
/// Multiplies the whole input by one random factor drawn uniformly from
/// `[1 - scale_range, 1 + scale_range)`.
///
/// Prefer [`ScaleAugmenter::new`], which validates the range; the `Default`
/// impl uses `scale_range = 0.1`.
#[derive(Debug, Clone, PartialEq)]
pub struct ScaleAugmenter {
    /// Maximum relative deviation of the scale factor from `1.0`.
    pub scale_range: f64,
}
96
97impl ScaleAugmenter {
98 pub fn new(scale_range: f64) -> TrainResult<Self> {
103 if !(0.0..=1.0).contains(&scale_range) {
104 return Err(TrainError::InvalidParameter(
105 "scale_range must be in [0, 1]".to_string(),
106 ));
107 }
108 Ok(Self { scale_range })
109 }
110}
111
112impl Default for ScaleAugmenter {
113 fn default() -> Self {
114 Self { scale_range: 0.1 }
115 }
116}
117
118impl DataAugmenter for ScaleAugmenter {
119 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
120 let scale = 1.0 + (rng.random::<f64>() * 2.0 - 1.0) * self.scale_range;
122
123 let augmented = data.mapv(|x| x * scale);
124 Ok(augmented)
125 }
126}
127
/// Perturbs the input using a random angle drawn from
/// `[-max_angle, max_angle)`.
///
/// The `Default` impl uses `max_angle = PI / 18` (10 degrees).
#[derive(Debug, Clone, PartialEq)]
pub struct RotationAugmenter {
    /// Largest absolute angle to sample, in radians (it is fed to `sin`/`cos`).
    pub max_angle: f64,
}
137
138impl RotationAugmenter {
139 pub fn new(max_angle: f64) -> Self {
144 Self { max_angle }
145 }
146}
147
148impl Default for RotationAugmenter {
149 fn default() -> Self {
150 Self {
151 max_angle: std::f64::consts::PI / 18.0, }
153 }
154}
155
156impl DataAugmenter for RotationAugmenter {
157 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
158 let angle = (rng.random::<f64>() * 2.0 - 1.0) * self.max_angle;
161
162 let cos_a = angle.cos();
164 let sin_a = angle.sin();
165
166 let augmented = data.mapv(|x| x * cos_a + x * sin_a * 0.1);
167 Ok(augmented)
168 }
169}
170
/// Mixup augmentation: blends randomly paired rows of a batch, together with
/// their label rows, using a random mixing weight `lambda` in `[0, 1]`.
///
/// Mixing needs labels, so use [`MixupAugmenter::augment_batch`]. The plain
/// `DataAugmenter::augment` implementation returns the data unchanged.
#[derive(Debug, Clone, PartialEq)]
pub struct MixupAugmenter {
    /// Concentration parameter of the symmetric Beta distribution the mixing
    /// weight is intended to follow; [`MixupAugmenter::new`] requires it to be
    /// positive.
    pub alpha: f64,
}
182
183impl MixupAugmenter {
184 pub fn new(alpha: f64) -> TrainResult<Self> {
189 if alpha <= 0.0 {
190 return Err(TrainError::InvalidParameter(
191 "alpha must be positive".to_string(),
192 ));
193 }
194 Ok(Self { alpha })
195 }
196
197 pub fn augment_batch(
207 &self,
208 data: &ArrayView2<f64>,
209 labels: &ArrayView2<f64>,
210 rng: &mut StdRng,
211 ) -> TrainResult<(Array2<f64>, Array2<f64>)> {
212 if data.nrows() != labels.nrows() {
213 return Err(TrainError::InvalidParameter(
214 "data and labels must have same number of rows".to_string(),
215 ));
216 }
217
218 let n = data.nrows();
219 let mut augmented_data = Array::zeros(data.raw_dim());
220 let mut augmented_labels = Array::zeros(labels.raw_dim());
221
222 let mut indices: Vec<usize> = (0..n).collect();
224 for i in (1..n).rev() {
225 let j = rng.gen_range(0..=i);
226 indices.swap(i, j);
227 }
228
229 for i in 0..n {
230 let j = indices[i];
231
232 let lambda = self.sample_beta(rng);
235
236 for k in 0..data.ncols() {
238 augmented_data[[i, k]] = lambda * data[[i, k]] + (1.0 - lambda) * data[[j, k]];
239 }
240
241 for k in 0..labels.ncols() {
243 augmented_labels[[i, k]] =
244 lambda * labels[[i, k]] + (1.0 - lambda) * labels[[j, k]];
245 }
246 }
247
248 Ok((augmented_data, augmented_labels))
249 }
250
251 fn sample_beta(&self, rng: &mut StdRng) -> f64 {
255 if self.alpha < 0.5 {
256 if rng.random::<f64>() < 0.5 {
258 rng.random::<f64>().powf(2.0)
259 } else {
260 1.0 - rng.random::<f64>().powf(2.0)
261 }
262 } else {
263 rng.random::<f64>()
265 }
266 }
267}
268
269impl Default for MixupAugmenter {
270 fn default() -> Self {
271 Self { alpha: 1.0 }
272 }
273}
274
275impl DataAugmenter for MixupAugmenter {
276 fn augment(&self, data: &ArrayView2<f64>, _rng: &mut StdRng) -> TrainResult<Array2<f64>> {
277 Ok(data.to_owned())
280 }
281}
282
283#[derive(Clone, Default)]
285pub struct CompositeAugmenter {
286 augmenters: Vec<Box<dyn AugmenterClone>>,
287}
288
289impl std::fmt::Debug for CompositeAugmenter {
290 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291 f.debug_struct("CompositeAugmenter")
292 .field("num_augmenters", &self.augmenters.len())
293 .finish()
294 }
295}
296
297trait AugmenterClone: DataAugmenter {
299 fn clone_box(&self) -> Box<dyn AugmenterClone>;
300}
301
302impl<T: DataAugmenter + Clone + 'static> AugmenterClone for T {
303 fn clone_box(&self) -> Box<dyn AugmenterClone> {
304 Box::new(self.clone())
305 }
306}
307
308impl Clone for Box<dyn AugmenterClone> {
309 fn clone(&self) -> Self {
310 self.clone_box()
311 }
312}
313
314impl DataAugmenter for Box<dyn AugmenterClone> {
315 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
316 (**self).augment(data, rng)
317 }
318}
319
320impl CompositeAugmenter {
321 pub fn new() -> Self {
323 Self {
324 augmenters: Vec::new(),
325 }
326 }
327
328 pub fn add<A: DataAugmenter + Clone + 'static>(&mut self, augmenter: A) {
330 self.augmenters.push(Box::new(augmenter));
331 }
332
333 pub fn len(&self) -> usize {
335 self.augmenters.len()
336 }
337
338 pub fn is_empty(&self) -> bool {
340 self.augmenters.is_empty()
341 }
342}
343
344impl DataAugmenter for CompositeAugmenter {
345 fn augment(&self, data: &ArrayView2<f64>, rng: &mut StdRng) -> TrainResult<Array2<f64>> {
346 let mut result = data.to_owned();
347
348 for augmenter in &self.augmenters {
349 result = augmenter.augment(&result.view(), rng)?;
350 }
351
352 Ok(result)
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use scirs2_core::random::SeedableRng;

    // Every test uses the same fixed seed so each run draws identical values.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    #[test]
    fn test_no_augmentation() {
        let augmenter = NoAugmentation;
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();
        assert_eq!(augmented, data);
    }

    #[test]
    fn test_noise_augmenter() {
        let augmenter = NoiseAugmenter::new(0.1).unwrap();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();

        assert_eq!(augmented.shape(), data.shape());

        // Noise must actually have been applied.
        assert_ne!(augmented[[0, 0]], data[[0, 0]]);

        // With std_dev = 0.1, a deviation of 1.0 would be a 10-sigma event.
        for i in 0..data.nrows() {
            for j in 0..data.ncols() {
                let diff = (augmented[[i, j]] - data[[i, j]]).abs();
                assert!(diff < 1.0);
            }
        }
    }

    #[test]
    fn test_noise_augmenter_invalid() {
        let result = NoiseAugmenter::new(-0.1);
        assert!(result.is_err());
    }

    #[test]
    fn test_scale_augmenter() {
        let augmenter = ScaleAugmenter::new(0.2).unwrap();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();

        assert_eq!(augmented.shape(), data.shape());

        // One factor is applied to the whole array, so every element must show
        // the same ratio to its original value.
        let scale = augmented[[0, 0]] / data[[0, 0]];
        for i in 0..data.nrows() {
            for j in 0..data.ncols() {
                let computed_scale = augmented[[i, j]] / data[[i, j]];
                assert!((computed_scale - scale).abs() < 1e-10);
            }
        }

        // scale_range = 0.2 bounds the factor to [0.8, 1.2).
        assert!((0.8..=1.2).contains(&scale));
    }

    #[test]
    fn test_scale_augmenter_invalid() {
        assert!(ScaleAugmenter::new(-0.1).is_err());
        assert!(ScaleAugmenter::new(1.5).is_err());
    }

    #[test]
    fn test_rotation_augmenter() {
        let augmenter = RotationAugmenter::default();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = augmenter.augment(&data.view(), &mut rng).unwrap();

        assert_eq!(augmented.shape(), data.shape());
    }

    #[test]
    fn test_mixup_augmenter_batch() {
        let augmenter = MixupAugmenter::new(1.0).unwrap();
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];
        let mut rng = create_test_rng();

        let (aug_data, aug_labels) = augmenter
            .augment_batch(&data.view(), &labels.view(), &mut rng)
            .unwrap();

        assert_eq!(aug_data.shape(), data.shape());
        assert_eq!(aug_labels.shape(), labels.shape());

        let data_min = data.iter().cloned().fold(f64::INFINITY, f64::min);
        let data_max = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

        // Mixup outputs are convex combinations of input rows, so they stay within
        // the input's value range. The tolerance absorbs floating-point rounding
        // in `lambda * a + (1 - lambda) * b`, which can overshoot by about one ulp.
        let tolerance = 1e-12;
        for &val in aug_data.iter() {
            assert!(val >= data_min - tolerance && val <= data_max + tolerance);
        }
    }

    #[test]
    fn test_mixup_invalid_alpha() {
        let result = MixupAugmenter::new(0.0);
        assert!(result.is_err());

        let result = MixupAugmenter::new(-1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_mixup_mismatched_shapes() {
        let augmenter = MixupAugmenter::default();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        // One label row for two data rows must be rejected.
        let labels = array![[1.0, 0.0]];
        let mut rng = create_test_rng();

        let result = augmenter.augment_batch(&data.view(), &labels.view(), &mut rng);
        assert!(result.is_err());
    }

    #[test]
    fn test_composite_augmenter() {
        let mut composite = CompositeAugmenter::new();
        composite.add(NoiseAugmenter::new(0.01).unwrap());
        composite.add(ScaleAugmenter::new(0.1).unwrap());

        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let mut rng = create_test_rng();

        let augmented = composite.augment(&data.view(), &mut rng).unwrap();

        assert_eq!(augmented.shape(), data.shape());

        // At least one stage must have changed the data.
        assert_ne!(augmented[[0, 0]], data[[0, 0]]);
    }

    #[test]
    fn test_composite_empty() {
        let composite = CompositeAugmenter::new();
        assert!(composite.is_empty());
        assert_eq!(composite.len(), 0);

        let data = array![[1.0, 2.0]];
        let mut rng = create_test_rng();

        // With no stages the pipeline is the identity.
        let augmented = composite.augment(&data.view(), &mut rng).unwrap();
        assert_eq!(augmented, data);
    }

    #[test]
    fn test_composite_multiple() {
        let mut composite = CompositeAugmenter::new();
        composite.add(NoAugmentation);
        composite.add(ScaleAugmenter::default());
        composite.add(NoiseAugmenter::default());

        assert_eq!(composite.len(), 3);
        assert!(!composite.is_empty());
    }
}