Skip to main content

torsh_data/
augmentation_pipeline.rs

1//! Data augmentation pipeline for machine learning training
2//!
3//! This module provides a comprehensive augmentation system for training data preprocessing.
4//! Augmentation is a critical technique for improving model generalization by applying
5//! random transformations to training data.
6//!
7//! # Features
8//!
9//! - **Pipeline composition**: AugmentationPipeline for chaining multiple transforms
10//! - **Probabilistic transforms**: ConditionalTransform for random application
11//! - **Image augmentations**: Color, brightness, contrast, and geometric transforms
12//! - **Noise augmentation**: GaussianNoise for regularization
13//! - **Cutout/Erasing**: RandomErasing for occlusion robustness
14//! - **Preset pipelines**: Common augmentation configurations for different use cases
15
16use crate::transforms::Transform;
17use torsh_core::dtype::FloatElement;
18use torsh_core::error::Result;
19use torsh_tensor::Tensor;
20
21#[cfg(not(feature = "std"))]
22use alloc::{boxed::Box, vec::Vec};
23
24#[cfg(feature = "std")]
25use scirs2_core::random::thread_rng;
26
27#[cfg(not(feature = "std"))]
28use scirs2_core::random::thread_rng;
29
30/// Augmentation pipeline builder for easy composition of transforms
31pub struct AugmentationPipeline<T> {
32    transforms: Vec<Box<dyn Transform<T, Output = T> + Send + Sync>>,
33    probability: f32,
34}
35
36impl<T: 'static + Send + Sync> AugmentationPipeline<T> {
37    /// Create a new augmentation pipeline
38    pub fn new() -> Self {
39        Self {
40            transforms: Vec::new(),
41            probability: 1.0,
42        }
43    }
44
45    /// Set the probability of applying the entire pipeline
46    pub fn with_probability(mut self, prob: f32) -> Self {
47        assert!(
48            (0.0..=1.0).contains(&prob),
49            "Probability must be between 0 and 1"
50        );
51        self.probability = prob;
52        self
53    }
54
55    /// Add a transform to the pipeline
56    pub fn add_transform<F>(mut self, transform: F) -> Self
57    where
58        F: Transform<T, Output = T> + 'static,
59    {
60        self.transforms.push(Box::new(transform));
61        self
62    }
63
64    /// Add a conditional transform that only applies with given probability
65    pub fn add_conditional<F>(self, transform: F, prob: f32) -> Self
66    where
67        F: Transform<T, Output = T> + 'static,
68    {
69        self.add_transform(ConditionalTransform::new(transform, prob))
70    }
71}
72
73impl<T: 'static + Send + Sync> Default for AugmentationPipeline<T> {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl<T> Transform<T> for AugmentationPipeline<T> {
80    type Output = T;
81
82    fn transform(&self, mut input: T) -> Result<Self::Output> {
83        let mut rng = thread_rng();
84
85        // Check if we should apply the pipeline at all
86        if rng.random::<f32>() > self.probability {
87            return Ok(input);
88        }
89
90        // Apply all transforms in sequence
91        for transform in &self.transforms {
92            input = transform.transform(input)?;
93        }
94
95        Ok(input)
96    }
97}
98
/// Wrapper that pairs a transform with the probability of applying it.
pub struct ConditionalTransform<T, F> {
    /// The wrapped transform.
    transform: F,
    /// Probability of running `transform`; checked by `new` to be in `[0, 1]`.
    probability: f32,
    /// Marker tying the otherwise unused type parameter `T` to the struct.
    _phantom: core::marker::PhantomData<T>,
}

impl<T, F> ConditionalTransform<T, F> {
    /// Wraps `transform` so that it is applied with the given `probability`.
    ///
    /// # Panics
    ///
    /// Panics with "Probability must be between 0 and 1" when `probability`
    /// is not inside `[0, 1]` (NaN included).
    pub fn new(transform: F, probability: f32) -> Self {
        let unit_interval = 0.0_f32..=1.0;
        assert!(
            unit_interval.contains(&probability),
            "Probability must be between 0 and 1"
        );
        ConditionalTransform {
            transform,
            probability,
            _phantom: Default::default(),
        }
    }
}
119
120impl<T, F> Transform<T> for ConditionalTransform<T, F>
121where
122    F: Transform<T, Output = T>,
123    T: Send + Sync,
124{
125    type Output = T;
126
127    fn transform(&self, input: T) -> Result<Self::Output> {
128        let mut rng = thread_rng();
129
130        if rng.random::<f32>() < self.probability {
131            self.transform.transform(input)
132        } else {
133            Ok(input)
134        }
135    }
136}
137
/// Random brightness adjustment.
///
/// Stores the range of brightness factors as `(min, max)`.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomBrightness {
    // Not read yet: the `Transform` impl for this type is still a pass-through,
    // so the dead-code lint is silenced until the range is actually used.
    #[allow(dead_code)]
    factor_range: (f32, f32),
}

impl RandomBrightness {
    /// Creates the transform from a `(min, max)` factor range.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid factor range" unless
    /// `factor_range.0 <= factor_range.1`. NaN bounds are rejected as well,
    /// because every comparison with NaN is false.
    pub fn new(factor_range: (f32, f32)) -> Self {
        let (low, high) = factor_range;
        assert!(low <= high, "Invalid factor range");
        Self { factor_range }
    }

    /// Create with symmetric range around 1.0, i.e. `(1.0 - factor, 1.0 + factor)`.
    ///
    /// # Panics
    ///
    /// Panics like [`RandomBrightness::new`] when the resulting range is
    /// inverted or NaN (for example `factor = -0.5`).
    pub fn symmetric(factor: f32) -> Self {
        Self::new((1.0 - factor, 1.0 + factor))
    }
}
155
156impl<T: FloatElement> Transform<Tensor<T>> for RandomBrightness {
157    type Output = Tensor<T>;
158
159    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
160        // For now, return input as-is since tensor operations need proper trait bounds
161        // In a full implementation, we would apply brightness adjustment
162        Ok(input)
163    }
164}
165
/// Random contrast adjustment.
///
/// Stores the range of contrast factors as `(min, max)`.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomContrast {
    // Not read yet: the `Transform` impl for this type is still a pass-through,
    // so the dead-code lint is silenced until the range is actually used.
    #[allow(dead_code)]
    factor_range: (f32, f32),
}

impl RandomContrast {
    /// Creates the transform from a `(min, max)` factor range.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid factor range" unless
    /// `factor_range.0 <= factor_range.1`. NaN bounds are rejected as well,
    /// because every comparison with NaN is false.
    pub fn new(factor_range: (f32, f32)) -> Self {
        let (low, high) = factor_range;
        assert!(low <= high, "Invalid factor range");
        Self { factor_range }
    }

    /// Create with symmetric range around 1.0, i.e. `(1.0 - factor, 1.0 + factor)`.
    ///
    /// # Panics
    ///
    /// Panics like [`RandomContrast::new`] when the resulting range is
    /// inverted or NaN (for example `factor = -0.5`).
    pub fn symmetric(factor: f32) -> Self {
        Self::new((1.0 - factor, 1.0 + factor))
    }
}
183
184impl<T: FloatElement> Transform<Tensor<T>> for RandomContrast {
185    type Output = Tensor<T>;
186
187    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
188        // For now, return input as-is since tensor operations need proper trait bounds
189        // In a full implementation, we would apply contrast adjustment
190        Ok(input)
191    }
192}
193
/// Random saturation adjustment (for color images).
///
/// Stores the range of saturation factors as `(min, max)`.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomSaturation {
    // Not read yet: the `Transform` impl for this type is still a pass-through,
    // so the dead-code lint is silenced until the range is actually used.
    #[allow(dead_code)]
    factor_range: (f32, f32),
}

impl RandomSaturation {
    /// Creates the transform from a `(min, max)` factor range.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid factor range" unless
    /// `factor_range.0 <= factor_range.1`. NaN bounds are rejected as well,
    /// because every comparison with NaN is false.
    pub fn new(factor_range: (f32, f32)) -> Self {
        let (low, high) = factor_range;
        assert!(low <= high, "Invalid factor range");
        Self { factor_range }
    }

    /// Create with symmetric range around 1.0, i.e. `(1.0 - factor, 1.0 + factor)`.
    ///
    /// # Panics
    ///
    /// Panics like [`RandomSaturation::new`] when the resulting range is
    /// inverted or NaN (for example `factor = -0.5`).
    pub fn symmetric(factor: f32) -> Self {
        Self::new((1.0 - factor, 1.0 + factor))
    }
}
211
212impl<T: FloatElement> Transform<Tensor<T>> for RandomSaturation {
213    type Output = Tensor<T>;
214
215    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
216        // For now, return input as-is since proper saturation adjustment
217        // requires complex RGB to grayscale conversion operations
218        Ok(input)
219    }
220}
221
/// Random hue adjustment (for color images).
///
/// Stores the range of hue deltas as `(min, max)`, both within `[-1, 1]`.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomHue {
    // Not read yet: the `Transform` impl for this type is still a pass-through,
    // so the dead-code lint is silenced until the range is actually used.
    #[allow(dead_code)]
    delta_range: (f32, f32),
}

impl RandomHue {
    /// Creates the transform from a `(min, max)` hue delta range.
    ///
    /// # Panics
    ///
    /// - With "Invalid delta range" unless `delta_range.0 <= delta_range.1`
    ///   (NaN bounds fail this check too).
    /// - With "Hue delta must be in [-1, 1]" when either bound lies outside
    ///   `[-1, 1]`.
    pub fn new(delta_range: (f32, f32)) -> Self {
        let (low, high) = delta_range;
        assert!(low <= high, "Invalid delta range");
        assert!(low >= -1.0 && high <= 1.0, "Hue delta must be in [-1, 1]");
        Self { delta_range }
    }

    /// Create with symmetric range `(-delta, delta)`.
    ///
    /// # Panics
    ///
    /// Panics like [`RandomHue::new`], e.g. for a negative `delta` (inverted
    /// range) or a `delta` greater than `1.0`.
    pub fn symmetric(delta: f32) -> Self {
        Self::new((-delta, delta))
    }
}
243
244impl<T: FloatElement> Transform<Tensor<T>> for RandomHue {
245    type Output = Tensor<T>;
246
247    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
248        // For now, return input as-is since proper HSV conversion
249        // requires more complex operations
250        Ok(input)
251    }
252}
253
/// Random vertical flip.
///
/// Stores the flip probability, validated to lie in `[0, 1]`.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomVerticalFlip {
    // Not read yet: the `Transform` impl for this type is still a pass-through,
    // so the dead-code lint is silenced until the probability is actually used.
    #[allow(dead_code)]
    prob: f32,
}

impl RandomVerticalFlip {
    /// Creates the transform with flip probability `prob`.
    ///
    /// # Panics
    ///
    /// Panics with "Probability must be between 0 and 1" when `prob` is not
    /// inside `[0, 1]` (NaN included).
    pub fn new(prob: f32) -> Self {
        assert!(
            (0.0..=1.0).contains(&prob),
            "Probability must be between 0 and 1"
        );
        Self { prob }
    }
}
269
270impl<T: FloatElement> Transform<Tensor<T>> for RandomVerticalFlip {
271    type Output = Tensor<T>;
272
273    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
274        // For now, return input as-is
275        Ok(input)
276    }
277}
278
/// Gaussian noise addition.
///
/// Stores the mean and (non-negative) standard deviation of the noise.
#[derive(Debug, Clone, PartialEq)]
pub struct GaussianNoise {
    // Neither field is read yet: the `Transform` impl for this type is still a
    // pass-through, so the dead-code lint is silenced until they are used.
    #[allow(dead_code)]
    mean: f32,
    #[allow(dead_code)]
    std: f32,
}

impl GaussianNoise {
    /// Creates the transform from a noise `mean` and standard deviation `std`.
    ///
    /// # Panics
    ///
    /// Panics with "Standard deviation must be non-negative" when `std` is
    /// negative or NaN.
    pub fn new(mean: f32, std: f32) -> Self {
        assert!(std >= 0.0, "Standard deviation must be non-negative");
        Self { mean, std }
    }

    /// Create with zero mean.
    ///
    /// # Panics
    ///
    /// Panics like [`GaussianNoise::new`] when `std` is negative or NaN.
    pub fn with_std(std: f32) -> Self {
        Self::new(0.0, std)
    }
}
298
299impl<T: FloatElement> Transform<Tensor<T>> for GaussianNoise {
300    type Output = Tensor<T>;
301
302    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
303        // For now, return input as-is
304        Ok(input)
305    }
306}
307
/// Random erasing (cutout) augmentation.
///
/// Stores the erase probability, a `(min, max)` scale range, a `(min, max)`
/// ratio range and the fill value (which defaults to `0.0`).
#[derive(Debug, Clone, PartialEq)]
pub struct RandomErasing {
    // None of the fields are read yet: the `Transform` impl for this type is
    // still a pass-through, so the dead-code lint is silenced until they are.
    #[allow(dead_code)]
    prob: f32,
    #[allow(dead_code)]
    scale_range: (f32, f32),
    #[allow(dead_code)]
    ratio_range: (f32, f32),
    #[allow(dead_code)]
    fill_value: f32,
}

impl RandomErasing {
    /// Creates the transform with a fill value of `0.0`.
    ///
    /// # Panics
    ///
    /// - With "Probability must be between 0 and 1" when `prob` is not inside
    ///   `[0, 1]` (NaN included).
    /// - With "Invalid scale range" unless `scale_range.0 <= scale_range.1`.
    /// - With "Invalid ratio range" unless `ratio_range.0 <= ratio_range.1`.
    pub fn new(prob: f32, scale_range: (f32, f32), ratio_range: (f32, f32)) -> Self {
        assert!(
            (0.0..=1.0).contains(&prob),
            "Probability must be between 0 and 1"
        );
        assert!(scale_range.0 <= scale_range.1, "Invalid scale range");
        assert!(ratio_range.0 <= ratio_range.1, "Invalid ratio range");

        Self {
            prob,
            scale_range,
            ratio_range,
            fill_value: 0.0,
        }
    }

    /// Returns the same configuration with `fill_value` replaced.
    pub fn with_fill_value(self, fill_value: f32) -> Self {
        Self { fill_value, ..self }
    }
}
342
343impl<T: FloatElement> Transform<Tensor<T>> for RandomErasing {
344    type Output = Tensor<T>;
345
346    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
347        // For now, return input as-is
348        Ok(input)
349    }
350}
351
352/// Common augmentation presets
353impl AugmentationPipeline<Tensor<f32>> {
354    /// Create a light augmentation pipeline for training
355    pub fn light_augmentation() -> Self {
356        Self::new()
357            .add_conditional(
358                crate::tensor_transforms::RandomHorizontalFlip::new(0.5),
359                1.0,
360            )
361            .add_conditional(RandomBrightness::symmetric(0.1), 0.3)
362            .add_conditional(RandomContrast::symmetric(0.1), 0.3)
363    }
364
365    /// Create a medium augmentation pipeline
366    pub fn medium_augmentation() -> Self {
367        Self::new()
368            .add_conditional(
369                crate::tensor_transforms::RandomHorizontalFlip::new(0.5),
370                1.0,
371            )
372            .add_conditional(RandomVerticalFlip::new(0.1), 1.0)
373            .add_conditional(RandomBrightness::symmetric(0.2), 0.5)
374            .add_conditional(RandomContrast::symmetric(0.2), 0.5)
375            .add_conditional(RandomSaturation::symmetric(0.2), 0.3)
376            .add_conditional(GaussianNoise::with_std(0.01), 0.2)
377    }
378
379    /// Create a heavy augmentation pipeline
380    pub fn heavy_augmentation() -> Self {
381        Self::new()
382            .add_conditional(
383                crate::tensor_transforms::RandomHorizontalFlip::new(0.5),
384                1.0,
385            )
386            .add_conditional(RandomVerticalFlip::new(0.2), 1.0)
387            .add_conditional(RandomBrightness::symmetric(0.3), 0.7)
388            .add_conditional(RandomContrast::symmetric(0.3), 0.7)
389            .add_conditional(RandomSaturation::symmetric(0.3), 0.5)
390            .add_conditional(RandomHue::symmetric(0.1), 0.3)
391            .add_conditional(GaussianNoise::with_std(0.02), 0.3)
392            .add_conditional(RandomErasing::new(0.5, (0.02, 0.33), (0.3, 3.3)), 1.0)
393    }
394
395    /// Create an augmentation pipeline for ImageNet-style training
396    pub fn imagenet_augmentation() -> Self {
397        Self::new()
398            .add_conditional(
399                crate::tensor_transforms::RandomHorizontalFlip::new(0.5),
400                1.0,
401            )
402            .add_conditional(RandomBrightness::symmetric(0.2), 0.4)
403            .add_conditional(RandomContrast::symmetric(0.2), 0.4)
404            .add_conditional(RandomSaturation::symmetric(0.2), 0.4)
405            .add_conditional(RandomHue::symmetric(0.1), 0.1)
406    }
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::Tensor;

    /// Builds the small 2x2 CPU tensor used as input by the tensor tests.
    fn mock_tensor() -> Tensor<f32> {
        let values = vec![1.0f32, 2.0, 3.0, 4.0];
        let dims = vec![2, 2];
        Tensor::from_data(values, dims, DeviceType::Cpu).unwrap()
    }

    #[test]
    fn test_augmentation_pipeline_creation() {
        let AugmentationPipeline {
            transforms,
            probability,
        } = AugmentationPipeline::<i32>::new();
        assert_eq!(probability, 1.0);
        assert_eq!(transforms.len(), 0);
    }

    #[test]
    fn test_augmentation_pipeline_with_probability() {
        let halved = AugmentationPipeline::<i32>::new().with_probability(0.5);
        assert_eq!(halved.probability, 0.5);
    }

    #[test]
    #[should_panic(expected = "Probability must be between 0 and 1")]
    fn test_invalid_probability() {
        let _ = AugmentationPipeline::<i32>::new().with_probability(1.5);
    }

    #[test]
    fn test_conditional_transform_creation() {
        let doubler = crate::transforms::lambda(|x: i32| Ok(x * 2));
        let gated: ConditionalTransform<i32, _> = ConditionalTransform::new(doubler, 0.5);
        assert_eq!(gated.probability, 0.5);
    }

    #[test]
    fn test_random_brightness_creation() {
        let RandomBrightness { factor_range } = RandomBrightness::new((0.8, 1.2));
        assert_eq!(factor_range, (0.8, 1.2));
    }

    #[test]
    fn test_random_brightness_symmetric() {
        let RandomBrightness { factor_range } = RandomBrightness::symmetric(0.2);
        assert_eq!(factor_range, (0.8, 1.2));
    }

    #[test]
    fn test_gaussian_noise_creation() {
        let GaussianNoise { mean, std } = GaussianNoise::new(0.0, 0.1);
        assert_eq!(mean, 0.0);
        assert_eq!(std, 0.1);
    }

    #[test]
    fn test_gaussian_noise_with_std() {
        let GaussianNoise { mean, std } = GaussianNoise::with_std(0.05);
        assert_eq!(mean, 0.0);
        assert_eq!(std, 0.05);
    }

    #[test]
    fn test_random_erasing_creation() {
        let RandomErasing {
            prob,
            scale_range,
            ratio_range,
            fill_value,
        } = RandomErasing::new(0.5, (0.02, 0.33), (0.3, 3.3));
        assert_eq!(prob, 0.5);
        assert_eq!(scale_range, (0.02, 0.33));
        assert_eq!(ratio_range, (0.3, 3.3));
        assert_eq!(fill_value, 0.0);
    }

    #[test]
    fn test_light_augmentation_preset() {
        let step_count = AugmentationPipeline::light_augmentation().transforms.len();
        assert_eq!(step_count, 3);
    }

    #[test]
    fn test_medium_augmentation_preset() {
        let step_count = AugmentationPipeline::medium_augmentation()
            .transforms
            .len();
        assert_eq!(step_count, 6);
    }

    #[test]
    fn test_heavy_augmentation_preset() {
        let step_count = AugmentationPipeline::heavy_augmentation().transforms.len();
        assert_eq!(step_count, 8);
    }

    #[test]
    fn test_augmentation_transform_passthrough() {
        let input = mock_tensor();
        let output = RandomBrightness::symmetric(0.1)
            .transform(input.clone())
            .unwrap();

        // The transform is currently a pass-through, so only the shape is
        // compared against the input.
        assert_eq!(output.shape(), input.shape());
    }
}