1use crate::error::{DatasetsError, Result};
8use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2};
9use scirs2_core::rand_distributions::Normal;
10use scirs2_core::random::Random;
11use scirs2_core::{Rng, RngExt};
12use std::sync::Arc;
13
14fn create_rng() -> Random<scirs2_core::rand_prelude::StdRng> {
16 use std::time::{SystemTime, UNIX_EPOCH};
17 let seed = SystemTime::now()
18 .duration_since(UNIX_EPOCH)
19 .map(|d| d.as_secs())
20 .unwrap_or(0);
21 Random::seed(seed)
22}
23
24pub trait Transform: Send + Sync {
26 fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>>;
28
29 fn transform_3d(&self, data: &Array3<f64>) -> Result<Array3<f64>> {
31 let (height, width, channels) = data.dim();
33 let mut result = Array3::zeros((height, width, channels));
34 for c in 0..channels {
35 let channel_2d = data.slice(s![.., .., c]).to_owned();
36 let transformed = self.transform_2d(&channel_2d)?;
37 result.slice_mut(s![.., .., c]).assign(&transformed);
38 }
39 Ok(result)
40 }
41
42 fn uses_gpu(&self) -> bool {
44 false
45 }
46
47 fn name(&self) -> &str;
49}
50
51#[derive(Clone)]
53pub struct AugmentationPipeline {
54 transforms: Vec<Arc<dyn Transform>>,
55 probability: f64,
56 seed: Option<u64>,
57}
58
59impl AugmentationPipeline {
60 pub fn new() -> Self {
62 Self {
63 transforms: Vec::new(),
64 probability: 1.0,
65 seed: None,
66 }
67 }
68
69 pub fn add_transform(mut self, transform: Arc<dyn Transform>) -> Self {
71 self.transforms.push(transform);
72 self
73 }
74
75 pub fn with_probability(mut self, prob: f64) -> Self {
77 self.probability = prob.clamp(0.0, 1.0);
78 self
79 }
80
81 pub fn with_seed(mut self, seed: u64) -> Self {
83 self.seed = Some(seed);
84 self
85 }
86
87 pub fn apply_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
89 let mut rng = if let Some(seed) = self.seed {
91 Random::seed(seed)
92 } else {
93 create_rng()
94 };
95
96 if rng.random::<f64>() > self.probability {
97 return Ok(data.clone());
98 }
99
100 let mut result = data.clone();
102 for transform in &self.transforms {
103 result = transform.transform_2d(&result)?;
104 }
105 Ok(result)
106 }
107
108 pub fn apply_3d(&self, data: &Array3<f64>) -> Result<Array3<f64>> {
110 let mut rng = if let Some(seed) = self.seed {
111 Random::seed(seed)
112 } else {
113 create_rng()
114 };
115
116 if rng.random::<f64>() > self.probability {
117 return Ok(data.clone());
118 }
119
120 let mut result = data.clone();
121 for transform in &self.transforms {
122 result = transform.transform_3d(&result)?;
123 }
124 Ok(result)
125 }
126
127 pub fn uses_gpu(&self) -> bool {
129 self.transforms.iter().any(|t| t.uses_gpu())
130 }
131}
132
133impl Default for AugmentationPipeline {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
/// Mirrors an array left-to-right (reverses column order) with some probability.
pub struct HorizontalFlip {
    /// Chance that a call flips its input; `new` clamps it to `[0, 1]`.
    probability: f64,
}
147
148impl HorizontalFlip {
149 pub fn new(probability: f64) -> Self {
151 Self {
152 probability: probability.clamp(0.0, 1.0),
153 }
154 }
155}
156
157impl Transform for HorizontalFlip {
158 fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
159 let mut rng = create_rng();
160 if rng.random::<f64>() < self.probability {
161 let flipped = data.slice(s![.., ..;-1]).to_owned();
163 Ok(flipped)
164 } else {
165 Ok(data.clone())
166 }
167 }
168
169 fn name(&self) -> &str {
170 "HorizontalFlip"
171 }
172}
173
/// Mirrors an array top-to-bottom (reverses row order) with some probability.
pub struct VerticalFlip {
    /// Chance that a call flips its input; `new` clamps it to `[0, 1]`.
    probability: f64,
}
178
179impl VerticalFlip {
180 pub fn new(probability: f64) -> Self {
182 Self {
183 probability: probability.clamp(0.0, 1.0),
184 }
185 }
186}
187
188impl Transform for VerticalFlip {
189 fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
190 let mut rng = create_rng();
191 if rng.random::<f64>() < self.probability {
192 let flipped = data.slice(s![..;-1, ..]).to_owned();
194 Ok(flipped)
195 } else {
196 Ok(data.clone())
197 }
198 }
199
200 fn name(&self) -> &str {
201 "VerticalFlip"
202 }
203}
204
/// Rotates an array by a random multiple of 90° with some probability.
pub struct RandomRotation90 {
    /// Chance that a call rotates its input; `new` clamps it to `[0, 1]`.
    probability: f64,
}
209
210impl RandomRotation90 {
211 pub fn new(probability: f64) -> Self {
213 Self {
214 probability: probability.clamp(0.0, 1.0),
215 }
216 }
217
218 fn rotate_90(&self, data: &Array2<f64>) -> Array2<f64> {
220 let (rows, cols) = data.dim();
221 let mut result = Array2::zeros((cols, rows));
222 for i in 0..rows {
223 for j in 0..cols {
224 result[[j, rows - 1 - i]] = data[[i, j]];
225 }
226 }
227 result
228 }
229}
230
231impl Transform for RandomRotation90 {
232 fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
233 let mut rng = create_rng();
234 if rng.random::<f64>() < self.probability {
235 let rotations = (rng.random::<f64>() * 3.0).floor() as usize + 1;
237 let mut result = data.clone();
238 for _ in 0..rotations {
239 result = self.rotate_90(&result);
240 }
241 Ok(result)
242 } else {
243 Ok(data.clone())
244 }
245 }
246
247 fn name(&self) -> &str {
248 "RandomRotation90"
249 }
250}
251
/// Adds element-wise noise drawn from `Normal(mean, std)` with some probability.
pub struct GaussianNoise {
    /// Mean of the noise distribution.
    mean: f64,
    /// Standard deviation of the noise distribution. It is validated by
    /// `Normal::new` only when noise is applied.
    std: f64,
    /// Chance that a call adds noise; `new` clamps it to `[0, 1]`.
    probability: f64,
}
258
259impl GaussianNoise {
260 pub fn new(mean: f64, std: f64, probability: f64) -> Self {
262 Self {
263 mean,
264 std,
265 probability: probability.clamp(0.0, 1.0),
266 }
267 }
268}
269
270impl Transform for GaussianNoise {
271 fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
272 let mut rng = create_rng();
273 if rng.random::<f64>() < self.probability {
274 let (rows, cols) = data.dim();
275 let mut result = data.clone();
276 let normal = Normal::new(self.mean, self.std).map_err(|e| {
277 DatasetsError::ComputationError(format!(
278 "Failed to create normal distribution: {}",
279 e
280 ))
281 })?;
282 for i in 0..rows {
283 for j in 0..cols {
284 let noise = rng.sample(normal);
285 result[[i, j]] += noise;
286 }
287 }
288 Ok(result)
289 } else {
290 Ok(data.clone())
291 }
292 }
293
294 fn name(&self) -> &str {
295 "GaussianNoise"
296 }
297}
298
/// Shifts every element by one random offset with some probability.
pub struct Brightness {
    /// `(low, high)` bounds the offset is interpolated between.
    delta_range: (f64, f64),
    /// Chance that a call shifts its input; `new` clamps it to `[0, 1]`.
    probability: f64,
}
304
305impl Brightness {
306 pub fn new(delta_range: (f64, f64), probability: f64) -> Self {
308 Self {
309 delta_range,
310 probability: probability.clamp(0.0, 1.0),
311 }
312 }
313}
314
315impl Transform for Brightness {
316 fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
317 let mut rng = create_rng();
318 if rng.random::<f64>() < self.probability {
319 let delta = self.delta_range.0
320 + rng.random::<f64>() * (self.delta_range.1 - self.delta_range.0);
321 Ok(data + delta)
322 } else {
323 Ok(data.clone())
324 }
325 }
326
327 fn name(&self) -> &str {
328 "Brightness"
329 }
330}
331
/// Scales deviations from the array mean by a random factor with some probability.
pub struct Contrast {
    /// `(low, high)` bounds the factor is interpolated between.
    factor_range: (f64, f64),
    /// Chance that a call adjusts its input; `new` clamps it to `[0, 1]`.
    probability: f64,
}
337
338impl Contrast {
339 pub fn new(factor_range: (f64, f64), probability: f64) -> Self {
341 Self {
342 factor_range,
343 probability: probability.clamp(0.0, 1.0),
344 }
345 }
346}
347
348impl Transform for Contrast {
349 fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
350 let mut rng = create_rng();
351 if rng.random::<f64>() < self.probability {
352 let factor = self.factor_range.0
353 + rng.random::<f64>() * (self.factor_range.1 - self.factor_range.0);
354 let mean = data.mean().unwrap_or(0.0);
355 Ok((data - mean) * factor + mean)
356 } else {
357 Ok(data.clone())
358 }
359 }
360
361 fn name(&self) -> &str {
362 "Contrast"
363 }
364}
365
/// Multiplies randomly selected columns (features) by random factors.
pub struct RandomFeatureScale {
    /// `(low, high)` bounds each column's factor is interpolated between.
    scale_range: (f64, f64),
    /// Per-column chance of being scaled; `new` clamps it to `[0, 1]`.
    feature_probability: f64,
}
375
376impl RandomFeatureScale {
377 pub fn new(scale_range: (f64, f64), feature_probability: f64) -> Self {
379 Self {
380 scale_range,
381 feature_probability: feature_probability.clamp(0.0, 1.0),
382 }
383 }
384}
385
386impl Transform for RandomFeatureScale {
387 fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
388 let mut rng = create_rng();
389 let (rows, cols) = data.dim();
390 let mut result = data.clone();
391
392 for j in 0..cols {
393 if rng.random::<f64>() < self.feature_probability {
394 let scale = self.scale_range.0
395 + rng.random::<f64>() * (self.scale_range.1 - self.scale_range.0);
396 for i in 0..rows {
397 result[[i, j]] *= scale;
398 }
399 }
400 }
401
402 Ok(result)
403 }
404
405 fn name(&self) -> &str {
406 "RandomFeatureScale"
407 }
408}
409
/// Blends each row with a randomly chosen partner row, with some probability.
///
/// NOTE(review): `alpha` is stored but never read by this type's `Transform`
/// impl. The mixing weight there is a plain `random::<f64>()` draw, not a
/// Beta(alpha, alpha) sample as in standard mixup. Only features are mixed;
/// there are no labels.
pub struct Mixup {
    /// Presumably the Beta(alpha, alpha) concentration; currently unused.
    alpha: f64,
    /// Chance that a call mixes its input; `new` clamps it to `[0, 1]`.
    probability: f64,
}
415
416impl Mixup {
417 pub fn new(alpha: f64, probability: f64) -> Self {
419 Self {
420 alpha,
421 probability: probability.clamp(0.0, 1.0),
422 }
423 }
424}
425
426impl Transform for Mixup {
427 fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
428 let mut rng = create_rng();
429 if rng.random::<f64>() < self.probability {
430 let (rows, cols) = data.dim();
431 if rows < 2 {
432 return Ok(data.clone());
433 }
434
435 let mut result = data.clone();
436 for i in 0..rows {
437 let j = (rng.random::<f64>() * rows as f64).floor() as usize % rows;
439 if i != j {
440 let lambda = rng.random::<f64>();
442 for k in 0..cols {
444 result[[i, k]] = lambda * data[[i, k]] + (1.0 - lambda) * data[[j, k]];
445 }
446 }
447 }
448 Ok(result)
449 } else {
450 Ok(data.clone())
451 }
452 }
453
454 fn name(&self) -> &str {
455 "Mixup"
456 }
457}
458
459pub fn standard_image_augmentation(probability: f64) -> AugmentationPipeline {
465 AugmentationPipeline::new()
466 .add_transform(Arc::new(HorizontalFlip::new(0.5)))
467 .add_transform(Arc::new(RandomRotation90::new(0.3)))
468 .add_transform(Arc::new(Brightness::new((-0.2, 0.2), 0.4)))
469 .add_transform(Arc::new(Contrast::new((0.8, 1.2), 0.4)))
470 .add_transform(Arc::new(GaussianNoise::new(0.0, 0.01, 0.3)))
471 .with_probability(probability)
472}
473
474pub fn standard_tabular_augmentation(probability: f64) -> AugmentationPipeline {
476 AugmentationPipeline::new()
477 .add_transform(Arc::new(RandomFeatureScale::new((0.9, 1.1), 0.3)))
478 .add_transform(Arc::new(GaussianNoise::new(0.0, 0.01, 0.2)))
479 .add_transform(Arc::new(Mixup::new(1.0, 0.5)))
480 .with_probability(probability)
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `rows × cols` matrix from row-major `values`.
    fn matrix(rows: usize, cols: usize, values: Vec<f64>) -> Result<Array2<f64>> {
        Array2::from_shape_vec((rows, cols), values)
            .map_err(|e| DatasetsError::InvalidFormat(format!("{}", e)))
    }

    #[test]
    fn test_horizontal_flip() -> Result<()> {
        let data = matrix(
            3,
            4,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
        )?;

        let flip = HorizontalFlip::new(1.0);
        let result = flip.transform_2d(&data)?;

        assert_eq!(result[[0, 0]], 4.0);
        assert_eq!(result[[0, 3]], 1.0);
        assert_eq!(result.nrows(), 3);
        assert_eq!(result.ncols(), 4);

        Ok(())
    }

    #[test]
    fn test_vertical_flip() -> Result<()> {
        let data = matrix(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])?;
        let result = VerticalFlip::new(1.0).transform_2d(&data)?;
        assert_eq!(result, matrix(3, 2, vec![5.0, 6.0, 3.0, 4.0, 1.0, 2.0])?);
        Ok(())
    }

    #[test]
    fn test_rotate_90_clockwise() -> Result<()> {
        // [[1, 2, 3],          [[4, 1],
        //  [4, 5, 6]]   --->    [5, 2],
        //                       [6, 3]]
        let data = matrix(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])?;
        let rotated = RandomRotation90::new(1.0).rotate_90(&data);
        assert_eq!(rotated, matrix(3, 2, vec![4.0, 1.0, 5.0, 2.0, 6.0, 3.0])?);
        Ok(())
    }

    #[test]
    fn test_horizontal_flip_3d_flips_every_channel() -> Result<()> {
        let data =
            Array3::from_shape_fn((2, 3, 2), |(i, j, c)| (i * 3 + j) as f64 + 10.0 * c as f64);
        let result = HorizontalFlip::new(1.0).transform_3d(&data)?;

        assert_eq!(result.dim(), (2, 3, 2));
        for c in 0..2 {
            for i in 0..2 {
                for j in 0..3 {
                    assert_eq!(result[[i, j, c]], data[[i, 2 - j, c]]);
                }
            }
        }
        Ok(())
    }

    #[test]
    fn test_gaussian_noise() -> Result<()> {
        let data = Array2::zeros((10, 10));
        let noise = GaussianNoise::new(0.0, 0.1, 1.0);
        let result = noise.transform_2d(&data)?;

        // 100 independent normal draws summing to exactly zero is practically impossible.
        let sum = result.sum();
        assert!(sum.abs() > 1e-10);
        assert_eq!(result.dim(), data.dim());

        Ok(())
    }

    #[test]
    fn test_brightness() -> Result<()> {
        let data = Array2::from_elem((5, 5), 0.5);
        let brightness = Brightness::new((0.1, 0.1), 1.0);
        let result = brightness.transform_2d(&data)?;

        assert!((result[[0, 0]] - 0.6).abs() < 0.01);

        Ok(())
    }

    #[test]
    fn test_random_feature_scale_all_columns() -> Result<()> {
        let data = matrix(2, 2, vec![1.0, 2.0, 3.0, 4.0])?;
        let result = RandomFeatureScale::new((2.0, 2.0), 1.0).transform_2d(&data)?;
        assert_eq!(result, matrix(2, 2, vec![2.0, 4.0, 6.0, 8.0])?);
        Ok(())
    }

    #[test]
    fn test_mixup_single_row_is_unchanged() -> Result<()> {
        let data = matrix(1, 3, vec![1.0, 2.0, 3.0])?;
        assert_eq!(Mixup::new(1.0, 1.0).transform_2d(&data)?, data);
        Ok(())
    }

    #[test]
    fn test_augmentation_pipeline() -> Result<()> {
        let data = matrix(3, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])?;

        let pipeline = AugmentationPipeline::new()
            .add_transform(Arc::new(HorizontalFlip::new(1.0)))
            .add_transform(Arc::new(Brightness::new((0.1, 0.1), 1.0)))
            .with_probability(1.0);

        let result = pipeline.apply_2d(&data)?;

        assert_eq!(result.dim(), data.dim());
        // Flip, then +0.1: the first row [1, 2, 3] becomes [3.1, 2.1, 1.1].
        assert!((result[[0, 0]] - 3.1).abs() < 1e-12);
        assert!((result[[0, 1]] - 2.1).abs() < 1e-12);
        assert!((result[[0, 2]] - 1.1).abs() < 1e-12);

        Ok(())
    }

    #[test]
    fn test_pipeline_zero_probability_is_identity() -> Result<()> {
        let data = matrix(2, 2, vec![1.0, 2.0, 3.0, 4.0])?;
        let pipeline = AugmentationPipeline::new()
            .add_transform(Arc::new(Brightness::new((1.0, 1.0), 1.0)))
            .with_probability(0.0);
        assert_eq!(pipeline.apply_2d(&data)?, data);
        Ok(())
    }

    #[test]
    fn test_standard_pipelines() {
        let img_pipeline = standard_image_augmentation(0.8);
        assert!(!img_pipeline.uses_gpu());

        let tab_pipeline = standard_tabular_augmentation(0.8);
        assert!(!tab_pipeline.uses_gpu());
    }
}