1use crate::transforms::Transform;
17use torsh_core::dtype::FloatElement;
18use torsh_core::error::Result;
19use torsh_tensor::Tensor;
20
21#[cfg(not(feature = "std"))]
22use alloc::{boxed::Box, vec::Vec};
23
24#[cfg(feature = "std")]
25use scirs2_core::random::thread_rng;
26
27#[cfg(not(feature = "std"))]
28use scirs2_core::random::thread_rng;
29
30pub struct AugmentationPipeline<T> {
32 transforms: Vec<Box<dyn Transform<T, Output = T> + Send + Sync>>,
33 probability: f32,
34}
35
36impl<T: 'static + Send + Sync> AugmentationPipeline<T> {
37 pub fn new() -> Self {
39 Self {
40 transforms: Vec::new(),
41 probability: 1.0,
42 }
43 }
44
45 pub fn with_probability(mut self, prob: f32) -> Self {
47 assert!(
48 (0.0..=1.0).contains(&prob),
49 "Probability must be between 0 and 1"
50 );
51 self.probability = prob;
52 self
53 }
54
55 pub fn add_transform<F>(mut self, transform: F) -> Self
57 where
58 F: Transform<T, Output = T> + 'static,
59 {
60 self.transforms.push(Box::new(transform));
61 self
62 }
63
64 pub fn add_conditional<F>(self, transform: F, prob: f32) -> Self
66 where
67 F: Transform<T, Output = T> + 'static,
68 {
69 self.add_transform(ConditionalTransform::new(transform, prob))
70 }
71}
72
73impl<T: 'static + Send + Sync> Default for AugmentationPipeline<T> {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl<T> Transform<T> for AugmentationPipeline<T> {
80 type Output = T;
81
82 fn transform(&self, mut input: T) -> Result<Self::Output> {
83 let mut rng = thread_rng();
84
85 if rng.random::<f32>() > self.probability {
87 return Ok(input);
88 }
89
90 for transform in &self.transforms {
92 input = transform.transform(input)?;
93 }
94
95 Ok(input)
96 }
97}
98
/// Wraps a transform so that it is applied to each input only with a given
/// probability; otherwise the input passes through unchanged.
pub struct ConditionalTransform<T, F> {
    /// The wrapped transform.
    transform: F,
    /// Chance in `[0, 1]` that `transform` fires for a given input.
    probability: f32,
    /// Ties the wrapper to its input type `T` without storing a `T`.
    _phantom: core::marker::PhantomData<T>,
}

impl<T, F> ConditionalTransform<T, F> {
    /// Wraps `transform` so that it fires with the given `probability`.
    ///
    /// # Panics
    ///
    /// Panics if `probability` lies outside `[0, 1]` or is NaN.
    pub fn new(transform: F, probability: f32) -> Self {
        let in_unit_interval = (0.0..=1.0).contains(&probability);
        assert!(in_unit_interval, "Probability must be between 0 and 1");

        Self {
            probability,
            transform,
            _phantom: core::marker::PhantomData,
        }
    }
}
119
120impl<T, F> Transform<T> for ConditionalTransform<T, F>
121where
122 F: Transform<T, Output = T>,
123 T: Send + Sync,
124{
125 type Output = T;
126
127 fn transform(&self, input: T) -> Result<Self::Output> {
128 let mut rng = thread_rng();
129
130 if rng.random::<f32>() < self.probability {
131 self.transform.transform(input)
132 } else {
133 Ok(input)
134 }
135 }
136}
137
/// Brightness jitter configured by a `(min, max)` multiplicative factor range.
///
/// NOTE(review): the `Transform` impl for this type currently returns its
/// input unchanged, so `factor_range` is stored but not yet applied.
pub struct RandomBrightness {
    // Only read by tests for now, because the `Transform` impl ignores it.
    #[allow(dead_code)]
    factor_range: (f32, f32),
}

impl RandomBrightness {
    /// Creates the transform with an explicit `(min, max)` factor range.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid factor range" if `factor_range.0 > factor_range.1`
    /// or either bound is NaN.
    pub fn new(factor_range: (f32, f32)) -> Self {
        assert!(factor_range.0 <= factor_range.1, "Invalid factor range");
        Self { factor_range }
    }

    /// Creates the range `[max(0, 1 - factor), 1 + factor]` around the identity
    /// factor `1.0`.
    ///
    /// The lower bound is clipped at `0.0`, as torchvision's `ColorJitter`
    /// does, because a negative multiplicative factor is meaningless.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid factor range" if `factor < 0.0` or is NaN, since
    /// the resulting range would be inverted or invalid.
    pub fn symmetric(factor: f32) -> Self {
        Self::new(((1.0 - factor).max(0.0), 1.0 + factor))
    }
}
155
156impl<T: FloatElement> Transform<Tensor<T>> for RandomBrightness {
157 type Output = Tensor<T>;
158
159 fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
160 Ok(input)
163 }
164}
165
/// Contrast jitter configured by a `(min, max)` multiplicative factor range.
///
/// NOTE(review): the `Transform` impl for this type currently returns its
/// input unchanged, so `factor_range` is stored but not yet applied.
pub struct RandomContrast {
    // Only read by tests for now, because the `Transform` impl ignores it.
    #[allow(dead_code)]
    factor_range: (f32, f32),
}

impl RandomContrast {
    /// Creates the transform with an explicit `(min, max)` factor range.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid factor range" if `factor_range.0 > factor_range.1`
    /// or either bound is NaN.
    pub fn new(factor_range: (f32, f32)) -> Self {
        assert!(factor_range.0 <= factor_range.1, "Invalid factor range");
        Self { factor_range }
    }

    /// Creates the range `[max(0, 1 - factor), 1 + factor]` around the identity
    /// factor `1.0`.
    ///
    /// The lower bound is clipped at `0.0`, as torchvision's `ColorJitter`
    /// does, because a negative contrast factor is meaningless.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid factor range" if `factor < 0.0` or is NaN.
    pub fn symmetric(factor: f32) -> Self {
        Self::new(((1.0 - factor).max(0.0), 1.0 + factor))
    }
}
183
184impl<T: FloatElement> Transform<Tensor<T>> for RandomContrast {
185 type Output = Tensor<T>;
186
187 fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
188 Ok(input)
191 }
192}
193
/// Saturation jitter configured by a `(min, max)` multiplicative factor range.
///
/// NOTE(review): the `Transform` impl for this type currently returns its
/// input unchanged, so `factor_range` is stored but not yet applied.
pub struct RandomSaturation {
    // Only read by tests for now, because the `Transform` impl ignores it.
    #[allow(dead_code)]
    factor_range: (f32, f32),
}

impl RandomSaturation {
    /// Creates the transform with an explicit `(min, max)` factor range.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid factor range" if `factor_range.0 > factor_range.1`
    /// or either bound is NaN.
    pub fn new(factor_range: (f32, f32)) -> Self {
        assert!(factor_range.0 <= factor_range.1, "Invalid factor range");
        Self { factor_range }
    }

    /// Creates the range `[max(0, 1 - factor), 1 + factor]` around the identity
    /// factor `1.0`.
    ///
    /// The lower bound is clipped at `0.0`, as torchvision's `ColorJitter`
    /// does, because a negative saturation factor is meaningless.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid factor range" if `factor < 0.0` or is NaN.
    pub fn symmetric(factor: f32) -> Self {
        Self::new(((1.0 - factor).max(0.0), 1.0 + factor))
    }
}
211
212impl<T: FloatElement> Transform<Tensor<T>> for RandomSaturation {
213 type Output = Tensor<T>;
214
215 fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
216 Ok(input)
219 }
220}
221
/// Hue jitter configured by a `(min, max)` additive delta range within `[-1, 1]`.
///
/// NOTE(review): the `Transform` impl for this type currently returns its
/// input unchanged, so `delta_range` is stored but not yet applied.
pub struct RandomHue {
    // Only read by tests for now, because the `Transform` impl ignores it.
    #[allow(dead_code)]
    delta_range: (f32, f32),
}

impl RandomHue {
    /// Creates the transform with an explicit `(min, max)` delta range.
    ///
    /// # Panics
    ///
    /// Panics if the range is inverted or contains NaN ("Invalid delta range").
    /// Also panics if either bound falls outside `[-1, 1]`
    /// ("Hue delta must be in [-1, 1]").
    pub fn new(delta_range: (f32, f32)) -> Self {
        let (low, high) = delta_range;
        assert!(low <= high, "Invalid delta range");
        assert!(low >= -1.0 && high <= 1.0, "Hue delta must be in [-1, 1]");
        Self { delta_range }
    }

    /// Creates the range `[-delta, delta]`.
    ///
    /// # Panics
    ///
    /// Same conditions as [`RandomHue::new`]; in particular a negative `delta`
    /// yields an inverted range.
    pub fn symmetric(delta: f32) -> Self {
        let (low, high) = (-delta, delta);
        Self::new((low, high))
    }
}
243
244impl<T: FloatElement> Transform<Tensor<T>> for RandomHue {
245 type Output = Tensor<T>;
246
247 fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
248 Ok(input)
251 }
252}
253
/// Vertical flip configured by a probability `prob` in `[0, 1]`.
///
/// NOTE(review): the `Transform` impl for this type currently returns its
/// input unchanged, so `prob` is stored but not yet applied.
pub struct RandomVerticalFlip {
    // Only read by tests for now, because the `Transform` impl ignores it.
    #[allow(dead_code)]
    prob: f32,
}

impl RandomVerticalFlip {
    /// Creates the transform with flip probability `prob`.
    ///
    /// # Panics
    ///
    /// Panics if `prob` lies outside `[0, 1]` or is NaN.
    pub fn new(prob: f32) -> Self {
        if !(0.0..=1.0).contains(&prob) {
            panic!("Probability must be between 0 and 1");
        }
        Self { prob }
    }
}
269
270impl<T: FloatElement> Transform<Tensor<T>> for RandomVerticalFlip {
271 type Output = Tensor<T>;
272
273 fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
274 Ok(input)
276 }
277}
278
/// Additive noise configured by a `mean` and a non-negative standard deviation `std`.
///
/// NOTE(review): the `Transform` impl for this type currently returns its
/// input unchanged, so `mean` and `std` are stored but not yet applied.
pub struct GaussianNoise {
    // Both fields are only read by tests for now, because the `Transform` impl ignores them.
    #[allow(dead_code)]
    mean: f32,
    #[allow(dead_code)]
    std: f32,
}

impl GaussianNoise {
    /// Creates the transform with the given `mean` and `std`.
    ///
    /// # Panics
    ///
    /// Panics if `std` is negative or NaN.
    pub fn new(mean: f32, std: f32) -> Self {
        let non_negative = std >= 0.0;
        assert!(non_negative, "Standard deviation must be non-negative");
        Self { mean, std }
    }

    /// Creates zero-mean noise with standard deviation `std`.
    ///
    /// # Panics
    ///
    /// Same conditions as [`GaussianNoise::new`].
    pub fn with_std(std: f32) -> Self {
        const ZERO_MEAN: f32 = 0.0;
        Self::new(ZERO_MEAN, std)
    }
}
298
299impl<T: FloatElement> Transform<Tensor<T>> for GaussianNoise {
300 type Output = Tensor<T>;
301
302 fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
303 Ok(input)
305 }
306}
307
/// Random-erasing configuration in the style of torchvision's `RandomErasing`.
///
/// NOTE(review): the `Transform` impl for this type currently returns its
/// input unchanged; the parameters below are validated and stored but not yet used.
pub struct RandomErasing {
    // The fields are only read by tests for now, because the `Transform` impl ignores them.
    /// Probability in `[0, 1]` of erasing at all.
    #[allow(dead_code)]
    prob: f32,
    /// `(min, max)` erased-area fraction (presumably of the input area); within `[0, 1]`.
    #[allow(dead_code)]
    scale_range: (f32, f32),
    /// `(min, max)` aspect ratio of the erased region; both bounds positive.
    #[allow(dead_code)]
    ratio_range: (f32, f32),
    /// Value for the erased region; `new` sets it to `0.0`.
    #[allow(dead_code)]
    fill_value: f32,
}

impl RandomErasing {
    /// Creates the transform with `fill_value = 0.0`.
    ///
    /// # Panics
    ///
    /// Panics if any of the following holds:
    /// - `prob` is outside `[0, 1]`;
    /// - either range is inverted (`min > max`) or contains NaN;
    /// - `scale_range` is not within `[0, 1]`;
    /// - `ratio_range` has a non-positive lower bound.
    pub fn new(prob: f32, scale_range: (f32, f32), ratio_range: (f32, f32)) -> Self {
        assert!(
            (0.0..=1.0).contains(&prob),
            "Probability must be between 0 and 1"
        );
        assert!(scale_range.0 <= scale_range.1, "Invalid scale range");
        assert!(ratio_range.0 <= ratio_range.1, "Invalid ratio range");
        // An area fraction outside [0, 1] cannot be met by any region;
        // torchvision rejects it with "Scale should be between 0 and 1".
        assert!(
            scale_range.0 >= 0.0 && scale_range.1 <= 1.0,
            "Scale range must be within [0, 1]"
        );
        // An aspect ratio of a real rectangle is positive. A zero or negative
        // bound would also break the log-uniform ratio sampling torchvision uses.
        // Together with the ordering check, this makes both bounds positive.
        assert!(ratio_range.0 > 0.0, "Aspect ratio range must be positive");

        Self {
            prob,
            scale_range,
            ratio_range,
            fill_value: 0.0,
        }
    }

    /// Returns `self` with the erased-region value replaced by `fill_value`.
    #[must_use]
    pub fn with_fill_value(mut self, fill_value: f32) -> Self {
        self.fill_value = fill_value;
        self
    }
}
342
343impl<T: FloatElement> Transform<Tensor<T>> for RandomErasing {
344 type Output = Tensor<T>;
345
346 fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
347 Ok(input)
349 }
350}
351
352impl AugmentationPipeline<Tensor<f32>> {
354 pub fn light_augmentation() -> Self {
356 Self::new()
357 .add_conditional(
358 crate::tensor_transforms::RandomHorizontalFlip::new(0.5),
359 1.0,
360 )
361 .add_conditional(RandomBrightness::symmetric(0.1), 0.3)
362 .add_conditional(RandomContrast::symmetric(0.1), 0.3)
363 }
364
365 pub fn medium_augmentation() -> Self {
367 Self::new()
368 .add_conditional(
369 crate::tensor_transforms::RandomHorizontalFlip::new(0.5),
370 1.0,
371 )
372 .add_conditional(RandomVerticalFlip::new(0.1), 1.0)
373 .add_conditional(RandomBrightness::symmetric(0.2), 0.5)
374 .add_conditional(RandomContrast::symmetric(0.2), 0.5)
375 .add_conditional(RandomSaturation::symmetric(0.2), 0.3)
376 .add_conditional(GaussianNoise::with_std(0.01), 0.2)
377 }
378
379 pub fn heavy_augmentation() -> Self {
381 Self::new()
382 .add_conditional(
383 crate::tensor_transforms::RandomHorizontalFlip::new(0.5),
384 1.0,
385 )
386 .add_conditional(RandomVerticalFlip::new(0.2), 1.0)
387 .add_conditional(RandomBrightness::symmetric(0.3), 0.7)
388 .add_conditional(RandomContrast::symmetric(0.3), 0.7)
389 .add_conditional(RandomSaturation::symmetric(0.3), 0.5)
390 .add_conditional(RandomHue::symmetric(0.1), 0.3)
391 .add_conditional(GaussianNoise::with_std(0.02), 0.3)
392 .add_conditional(RandomErasing::new(0.5, (0.02, 0.33), (0.3, 3.3)), 1.0)
393 }
394
395 pub fn imagenet_augmentation() -> Self {
397 Self::new()
398 .add_conditional(
399 crate::tensor_transforms::RandomHorizontalFlip::new(0.5),
400 1.0,
401 )
402 .add_conditional(RandomBrightness::symmetric(0.2), 0.4)
403 .add_conditional(RandomContrast::symmetric(0.2), 0.4)
404 .add_conditional(RandomSaturation::symmetric(0.2), 0.4)
405 .add_conditional(RandomHue::symmetric(0.1), 0.1)
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::Tensor;

    /// Builds a 2x2 `f32` CPU tensor for the pass-through tests.
    fn mock_tensor() -> Tensor<f32> {
        Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu).unwrap()
    }

    /// Deterministic `i32` transform whose `+1` effect makes it observable
    /// whether a wrapper actually applied it.
    struct AddOne;

    impl Transform<i32> for AddOne {
        type Output = i32;

        fn transform(&self, input: i32) -> torsh_core::error::Result<Self::Output> {
            Ok(input + 1)
        }
    }

    #[test]
    fn test_augmentation_pipeline_creation() {
        let pipeline = AugmentationPipeline::<i32>::new();
        assert_eq!(pipeline.probability, 1.0);
        assert_eq!(pipeline.transforms.len(), 0);
    }

    #[test]
    fn test_augmentation_pipeline_with_probability() {
        let pipeline = AugmentationPipeline::<i32>::new().with_probability(0.5);
        assert_eq!(pipeline.probability, 0.5);
    }

    #[test]
    #[should_panic(expected = "Probability must be between 0 and 1")]
    fn test_invalid_probability() {
        // Bind the result so the test stays warning-free if the builder
        // methods are ever marked `#[must_use]`.
        let _ = AugmentationPipeline::<i32>::new().with_probability(1.5);
    }

    #[test]
    fn test_empty_pipeline_is_identity() {
        let pipeline = AugmentationPipeline::<i32>::new();
        assert_eq!(pipeline.transform(5).unwrap(), 5);
    }

    #[test]
    fn test_pipeline_applies_every_stage() {
        // The default probability is 1.0, so the chain always runs.
        let pipeline = AugmentationPipeline::<i32>::new()
            .add_transform(AddOne)
            .add_transform(AddOne);
        assert_eq!(pipeline.transforms.len(), 2);
        assert_eq!(pipeline.transform(0).unwrap(), 2);
    }

    #[test]
    fn test_add_conditional_with_probability_one() {
        let pipeline = AugmentationPipeline::<i32>::new().add_conditional(AddOne, 1.0);
        assert_eq!(pipeline.transforms.len(), 1);
        for _ in 0..100 {
            assert_eq!(pipeline.transform(0).unwrap(), 1);
        }
    }

    #[test]
    fn test_conditional_transform_creation() {
        let transform: ConditionalTransform<i32, _> =
            ConditionalTransform::new(crate::transforms::lambda(|x: i32| Ok(x * 2)), 0.5);
        assert_eq!(transform.probability, 0.5);
    }

    #[test]
    fn test_conditional_probability_one_always_applies() {
        let transform: ConditionalTransform<i32, _> = ConditionalTransform::new(AddOne, 1.0);
        for _ in 0..100 {
            assert_eq!(transform.transform(1).unwrap(), 2);
        }
    }

    #[test]
    fn test_conditional_probability_zero_never_applies() {
        let transform: ConditionalTransform<i32, _> = ConditionalTransform::new(AddOne, 0.0);
        for _ in 0..100 {
            assert_eq!(transform.transform(1).unwrap(), 1);
        }
    }

    #[test]
    #[should_panic(expected = "Probability must be between 0 and 1")]
    fn test_invalid_conditional_probability() {
        let _transform: ConditionalTransform<i32, _> = ConditionalTransform::new(AddOne, -0.1);
    }

    #[test]
    fn test_random_brightness_creation() {
        let brightness = RandomBrightness::new((0.8, 1.2));
        assert_eq!(brightness.factor_range, (0.8, 1.2));
    }

    #[test]
    fn test_random_brightness_symmetric() {
        let brightness = RandomBrightness::symmetric(0.2);
        assert_eq!(brightness.factor_range, (0.8, 1.2));
    }

    #[test]
    fn test_gaussian_noise_creation() {
        let noise = GaussianNoise::new(0.0, 0.1);
        assert_eq!(noise.mean, 0.0);
        assert_eq!(noise.std, 0.1);
    }

    #[test]
    fn test_gaussian_noise_with_std() {
        let noise = GaussianNoise::with_std(0.05);
        assert_eq!(noise.mean, 0.0);
        assert_eq!(noise.std, 0.05);
    }

    #[test]
    fn test_random_erasing_creation() {
        let erasing = RandomErasing::new(0.5, (0.02, 0.33), (0.3, 3.3));
        assert_eq!(erasing.prob, 0.5);
        assert_eq!(erasing.scale_range, (0.02, 0.33));
        assert_eq!(erasing.ratio_range, (0.3, 3.3));
        assert_eq!(erasing.fill_value, 0.0);
    }

    #[test]
    fn test_light_augmentation_preset() {
        let pipeline = AugmentationPipeline::light_augmentation();
        assert_eq!(pipeline.transforms.len(), 3);
    }

    #[test]
    fn test_medium_augmentation_preset() {
        let pipeline = AugmentationPipeline::medium_augmentation();
        assert_eq!(pipeline.transforms.len(), 6);
    }

    #[test]
    fn test_heavy_augmentation_preset() {
        let pipeline = AugmentationPipeline::heavy_augmentation();
        assert_eq!(pipeline.transforms.len(), 8);
    }

    #[test]
    fn test_augmentation_transform_passthrough() {
        let tensor = mock_tensor();
        let brightness = RandomBrightness::symmetric(0.1);
        let result = brightness.transform(tensor.clone()).unwrap();

        assert_eq!(result.shape(), tensor.shape());
    }
}