Skip to main content

yscv_model/
augmentation.rs

1use std::fmt;
2use std::sync::Arc;
3
4use yscv_imgproc::{
5    box_blur_3x3, flip_horizontal, flip_vertical, normalize, resize_nearest, rotate90_cw,
6};
7use yscv_tensor::{Tensor, TensorError};
8
9use super::transform::Transform;
10use crate::{ModelError, SupervisedDataset};
11
12/// Per-sample image augmentations for rank-4 NHWC training tensors.
13pub enum ImageAugmentationOp {
14    /// Flip image across width axis with configured probability.
15    HorizontalFlip { probability: f32 },
16    /// Flip image across height axis with configured probability.
17    VerticalFlip { probability: f32 },
18    /// Rotate sample by random multiples of 90 degrees with configured probability.
19    ///
20    /// For square samples, rotation is sampled from {0, 90, 180, 270} degrees.
21    /// For non-square samples, rotation is sampled from {0, 180} degrees to preserve shape.
22    RandomRotate90 { probability: f32 },
23    /// Add random uniform brightness delta in `[-max_delta, +max_delta]` and clamp to `[0, 1]`.
24    BrightnessJitter { max_delta: f32 },
25    /// Scale contrast around per-sample mean by factor in `[1-max_scale_delta, 1+max_scale_delta]`.
26    ContrastJitter { max_scale_delta: f32 },
27    /// Apply gamma correction with gamma sampled in `[1-max_gamma_delta, 1+max_gamma_delta]`.
28    GammaJitter { max_gamma_delta: f32 },
29    /// Add per-value Gaussian noise with configured standard deviation and clamp to `[0, 1]`.
30    GaussianNoise { probability: f32, std_dev: f32 },
31    /// Apply 3x3 box blur with configured probability.
32    BoxBlur3x3 { probability: f32 },
33    /// Crop a random window and resize it back to original sample size.
34    RandomResizedCrop {
35        probability: f32,
36        min_scale: f32,
37        max_scale: f32,
38    },
39    /// Apply random rectangular erasing with configured max size fractions and fill value.
40    Cutout {
41        probability: f32,
42        max_height_fraction: f32,
43        max_width_fraction: f32,
44        fill_value: f32,
45    },
46    /// Per-channel normalization in HWC layout: `(x - mean[c]) / std[c]`.
47    ChannelNormalize { mean: Vec<f32>, std: Vec<f32> },
48    /// User-provided closure for custom augmentation logic.
49    Custom(Arc<dyn Fn(&Tensor) -> Result<Tensor, ModelError> + Send + Sync>),
50    /// Random crop from larger image (does not resize back; changes spatial dims).
51    RandomCrop { height: usize, width: usize },
52    /// Apply gaussian blur with the given square kernel size (must be odd and >= 1).
53    GaussianBlur { kernel_size: usize },
54}
55
56impl fmt::Debug for ImageAugmentationOp {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        match self {
59            Self::HorizontalFlip { probability } => f
60                .debug_struct("HorizontalFlip")
61                .field("probability", probability)
62                .finish(),
63            Self::VerticalFlip { probability } => f
64                .debug_struct("VerticalFlip")
65                .field("probability", probability)
66                .finish(),
67            Self::RandomRotate90 { probability } => f
68                .debug_struct("RandomRotate90")
69                .field("probability", probability)
70                .finish(),
71            Self::BrightnessJitter { max_delta } => f
72                .debug_struct("BrightnessJitter")
73                .field("max_delta", max_delta)
74                .finish(),
75            Self::ContrastJitter { max_scale_delta } => f
76                .debug_struct("ContrastJitter")
77                .field("max_scale_delta", max_scale_delta)
78                .finish(),
79            Self::GammaJitter { max_gamma_delta } => f
80                .debug_struct("GammaJitter")
81                .field("max_gamma_delta", max_gamma_delta)
82                .finish(),
83            Self::GaussianNoise {
84                probability,
85                std_dev,
86            } => f
87                .debug_struct("GaussianNoise")
88                .field("probability", probability)
89                .field("std_dev", std_dev)
90                .finish(),
91            Self::BoxBlur3x3 { probability } => f
92                .debug_struct("BoxBlur3x3")
93                .field("probability", probability)
94                .finish(),
95            Self::RandomResizedCrop {
96                probability,
97                min_scale,
98                max_scale,
99            } => f
100                .debug_struct("RandomResizedCrop")
101                .field("probability", probability)
102                .field("min_scale", min_scale)
103                .field("max_scale", max_scale)
104                .finish(),
105            Self::Cutout {
106                probability,
107                max_height_fraction,
108                max_width_fraction,
109                fill_value,
110            } => f
111                .debug_struct("Cutout")
112                .field("probability", probability)
113                .field("max_height_fraction", max_height_fraction)
114                .field("max_width_fraction", max_width_fraction)
115                .field("fill_value", fill_value)
116                .finish(),
117            Self::ChannelNormalize { mean, std } => f
118                .debug_struct("ChannelNormalize")
119                .field("mean", mean)
120                .field("std", std)
121                .finish(),
122            Self::Custom(_) => f.debug_tuple("Custom").field(&"<closure>").finish(),
123            Self::RandomCrop { height, width } => f
124                .debug_struct("RandomCrop")
125                .field("height", height)
126                .field("width", width)
127                .finish(),
128            Self::GaussianBlur { kernel_size } => f
129                .debug_struct("GaussianBlur")
130                .field("kernel_size", kernel_size)
131                .finish(),
132        }
133    }
134}
135
136impl Clone for ImageAugmentationOp {
137    fn clone(&self) -> Self {
138        match self {
139            Self::HorizontalFlip { probability } => Self::HorizontalFlip {
140                probability: *probability,
141            },
142            Self::VerticalFlip { probability } => Self::VerticalFlip {
143                probability: *probability,
144            },
145            Self::RandomRotate90 { probability } => Self::RandomRotate90 {
146                probability: *probability,
147            },
148            Self::BrightnessJitter { max_delta } => Self::BrightnessJitter {
149                max_delta: *max_delta,
150            },
151            Self::ContrastJitter { max_scale_delta } => Self::ContrastJitter {
152                max_scale_delta: *max_scale_delta,
153            },
154            Self::GammaJitter { max_gamma_delta } => Self::GammaJitter {
155                max_gamma_delta: *max_gamma_delta,
156            },
157            Self::GaussianNoise {
158                probability,
159                std_dev,
160            } => Self::GaussianNoise {
161                probability: *probability,
162                std_dev: *std_dev,
163            },
164            Self::BoxBlur3x3 { probability } => Self::BoxBlur3x3 {
165                probability: *probability,
166            },
167            Self::RandomResizedCrop {
168                probability,
169                min_scale,
170                max_scale,
171            } => Self::RandomResizedCrop {
172                probability: *probability,
173                min_scale: *min_scale,
174                max_scale: *max_scale,
175            },
176            Self::Cutout {
177                probability,
178                max_height_fraction,
179                max_width_fraction,
180                fill_value,
181            } => Self::Cutout {
182                probability: *probability,
183                max_height_fraction: *max_height_fraction,
184                max_width_fraction: *max_width_fraction,
185                fill_value: *fill_value,
186            },
187            Self::ChannelNormalize { mean, std } => Self::ChannelNormalize {
188                mean: mean.clone(),
189                std: std.clone(),
190            },
191            Self::Custom(f) => Self::Custom(Arc::clone(f)),
192            Self::RandomCrop { height, width } => Self::RandomCrop {
193                height: *height,
194                width: *width,
195            },
196            Self::GaussianBlur { kernel_size } => Self::GaussianBlur {
197                kernel_size: *kernel_size,
198            },
199        }
200    }
201}
202
203impl PartialEq for ImageAugmentationOp {
204    fn eq(&self, other: &Self) -> bool {
205        match (self, other) {
206            (Self::HorizontalFlip { probability: a }, Self::HorizontalFlip { probability: b }) => {
207                a == b
208            }
209            (Self::VerticalFlip { probability: a }, Self::VerticalFlip { probability: b }) => {
210                a == b
211            }
212            (Self::RandomRotate90 { probability: a }, Self::RandomRotate90 { probability: b }) => {
213                a == b
214            }
215            (Self::BrightnessJitter { max_delta: a }, Self::BrightnessJitter { max_delta: b }) => {
216                a == b
217            }
218            (
219                Self::ContrastJitter { max_scale_delta: a },
220                Self::ContrastJitter { max_scale_delta: b },
221            ) => a == b,
222            (
223                Self::GammaJitter { max_gamma_delta: a },
224                Self::GammaJitter { max_gamma_delta: b },
225            ) => a == b,
226            (
227                Self::GaussianNoise {
228                    probability: p1,
229                    std_dev: s1,
230                },
231                Self::GaussianNoise {
232                    probability: p2,
233                    std_dev: s2,
234                },
235            ) => p1 == p2 && s1 == s2,
236            (Self::BoxBlur3x3 { probability: a }, Self::BoxBlur3x3 { probability: b }) => a == b,
237            (
238                Self::RandomResizedCrop {
239                    probability: p1,
240                    min_scale: mn1,
241                    max_scale: mx1,
242                },
243                Self::RandomResizedCrop {
244                    probability: p2,
245                    min_scale: mn2,
246                    max_scale: mx2,
247                },
248            ) => p1 == p2 && mn1 == mn2 && mx1 == mx2,
249            (
250                Self::Cutout {
251                    probability: p1,
252                    max_height_fraction: h1,
253                    max_width_fraction: w1,
254                    fill_value: f1,
255                },
256                Self::Cutout {
257                    probability: p2,
258                    max_height_fraction: h2,
259                    max_width_fraction: w2,
260                    fill_value: f2,
261                },
262            ) => p1 == p2 && h1 == h2 && w1 == w2 && f1 == f2,
263            (
264                Self::ChannelNormalize { mean: m1, std: s1 },
265                Self::ChannelNormalize { mean: m2, std: s2 },
266            ) => m1 == m2 && s1 == s2,
267            (Self::Custom(_), Self::Custom(_)) => false, // closures cannot be compared
268            (
269                Self::RandomCrop {
270                    height: h1,
271                    width: w1,
272                },
273                Self::RandomCrop {
274                    height: h2,
275                    width: w2,
276                },
277            ) => h1 == h2 && w1 == w2,
278            (Self::GaussianBlur { kernel_size: a }, Self::GaussianBlur { kernel_size: b }) => {
279                a == b
280            }
281            _ => false,
282        }
283    }
284}
285
286impl ImageAugmentationOp {
287    fn validate(&self) -> Result<(), ModelError> {
288        match self {
289            Self::HorizontalFlip { probability } => {
290                validate_probability("horizontal_flip", *probability)
291            }
292            Self::VerticalFlip { probability } => {
293                validate_probability("vertical_flip", *probability)
294            }
295            Self::RandomRotate90 { probability } => {
296                validate_probability("random_rotate90", *probability)
297            }
298            Self::BrightnessJitter { max_delta } => {
299                if !max_delta.is_finite() || *max_delta < 0.0 {
300                    return Err(ModelError::InvalidAugmentationArgument {
301                        operation: "brightness_jitter",
302                        message: format!("max_delta must be finite and >= 0, got {max_delta}"),
303                    });
304                }
305                Ok(())
306            }
307            Self::ContrastJitter { max_scale_delta } => {
308                if !max_scale_delta.is_finite() || *max_scale_delta < 0.0 {
309                    return Err(ModelError::InvalidAugmentationArgument {
310                        operation: "contrast_jitter",
311                        message: format!(
312                            "max_scale_delta must be finite and >= 0, got {max_scale_delta}"
313                        ),
314                    });
315                }
316                Ok(())
317            }
318            Self::GammaJitter { max_gamma_delta } => {
319                if !max_gamma_delta.is_finite() || *max_gamma_delta < 0.0 {
320                    return Err(ModelError::InvalidAugmentationArgument {
321                        operation: "gamma_jitter",
322                        message: format!(
323                            "max_gamma_delta must be finite and >= 0, got {max_gamma_delta}"
324                        ),
325                    });
326                }
327                Ok(())
328            }
329            Self::GaussianNoise {
330                probability,
331                std_dev,
332            } => {
333                validate_probability("gaussian_noise", *probability)?;
334                if !std_dev.is_finite() || *std_dev < 0.0 {
335                    return Err(ModelError::InvalidAugmentationArgument {
336                        operation: "gaussian_noise",
337                        message: format!("std_dev must be finite and >= 0, got {std_dev}"),
338                    });
339                }
340                Ok(())
341            }
342            Self::BoxBlur3x3 { probability } => validate_probability("box_blur_3x3", *probability),
343            Self::RandomResizedCrop {
344                probability,
345                min_scale,
346                max_scale,
347            } => {
348                validate_probability("random_resized_crop", *probability)?;
349                validate_fraction("random_resized_crop", "min_scale", *min_scale)?;
350                validate_fraction("random_resized_crop", "max_scale", *max_scale)?;
351                if min_scale > max_scale {
352                    return Err(ModelError::InvalidAugmentationArgument {
353                        operation: "random_resized_crop",
354                        message: format!(
355                            "min_scale must be <= max_scale, got min_scale={min_scale}, max_scale={max_scale}"
356                        ),
357                    });
358                }
359                Ok(())
360            }
361            Self::Cutout {
362                probability,
363                max_height_fraction,
364                max_width_fraction,
365                fill_value,
366            } => {
367                validate_probability("cutout", *probability)?;
368                validate_fraction("cutout", "max_height_fraction", *max_height_fraction)?;
369                validate_fraction("cutout", "max_width_fraction", *max_width_fraction)?;
370                if !fill_value.is_finite() {
371                    return Err(ModelError::InvalidAugmentationArgument {
372                        operation: "cutout",
373                        message: format!("fill_value must be finite, got {fill_value}"),
374                    });
375                }
376                Ok(())
377            }
378            Self::ChannelNormalize { mean, std } => {
379                if mean.is_empty() || std.is_empty() {
380                    return Err(ModelError::InvalidAugmentationArgument {
381                        operation: "channel_normalize",
382                        message: "mean/std must be non-empty".to_string(),
383                    });
384                }
385                if mean.len() != std.len() {
386                    return Err(ModelError::InvalidAugmentationArgument {
387                        operation: "channel_normalize",
388                        message: format!(
389                            "mean/std length mismatch: mean_len={}, std_len={}",
390                            mean.len(),
391                            std.len()
392                        ),
393                    });
394                }
395                for (channel, mean_value) in mean.iter().enumerate() {
396                    if !mean_value.is_finite() {
397                        return Err(ModelError::InvalidAugmentationArgument {
398                            operation: "channel_normalize",
399                            message: format!("mean[{channel}] must be finite"),
400                        });
401                    }
402                }
403                for (channel, std_value) in std.iter().enumerate() {
404                    if !std_value.is_finite() || *std_value <= 0.0 {
405                        return Err(ModelError::InvalidAugmentationArgument {
406                            operation: "channel_normalize",
407                            message: format!("std[{channel}] must be finite and > 0"),
408                        });
409                    }
410                }
411                Ok(())
412            }
413            Self::Custom(_) => Ok(()),
414            Self::RandomCrop { height, width } => {
415                if *height == 0 || *width == 0 {
416                    return Err(ModelError::InvalidAugmentationArgument {
417                        operation: "random_crop",
418                        message: format!(
419                            "height and width must be > 0, got height={height}, width={width}"
420                        ),
421                    });
422                }
423                Ok(())
424            }
425            Self::GaussianBlur { kernel_size } => {
426                if *kernel_size == 0 || *kernel_size % 2 == 0 {
427                    return Err(ModelError::InvalidAugmentationArgument {
428                        operation: "gaussian_blur",
429                        message: format!("kernel_size must be odd and >= 1, got {kernel_size}"),
430                    });
431                }
432                Ok(())
433            }
434        }
435    }
436
437    /// Create an augmentation op from any Transform implementation.
438    pub fn from_transform<T: Transform + Send + Sync + 'static>(t: T) -> Self {
439        Self::Custom(Arc::new(move |input| t.apply(input)))
440    }
441}
442
443/// Ordered per-sample augmentation pipeline for NHWC mini-batch data.
444#[derive(Debug, Clone, PartialEq)]
445pub struct ImageAugmentationPipeline {
446    ops: Vec<ImageAugmentationOp>,
447}
448
449impl ImageAugmentationPipeline {
450    pub fn new(ops: Vec<ImageAugmentationOp>) -> Result<Self, ModelError> {
451        for op in &ops {
452            op.validate()?;
453        }
454        Ok(Self { ops })
455    }
456
457    pub fn ops(&self) -> &[ImageAugmentationOp] {
458        &self.ops
459    }
460
461    pub fn apply_nhwc(&self, inputs: &Tensor, seed: u64) -> Result<Tensor, ModelError> {
462        if inputs.rank() != 4 {
463            return Err(ModelError::InvalidAugmentationInputShape {
464                got: inputs.shape().to_vec(),
465            });
466        }
467
468        let shape = inputs.shape();
469        let sample_count = shape[0];
470        let sample_len = shape[1..].iter().try_fold(1usize, |acc, dim| {
471            acc.checked_mul(*dim)
472                .ok_or_else(|| TensorError::SizeOverflow {
473                    shape: shape.to_vec(),
474                })
475        })?;
476
477        let mut out = Vec::with_capacity(inputs.data().len());
478        let mut rng = LcgRng::new(seed);
479        let mut out_sample_shape: Option<Vec<usize>> = None;
480
481        for sample_idx in 0..sample_count {
482            let start =
483                sample_idx
484                    .checked_mul(sample_len)
485                    .ok_or_else(|| TensorError::SizeOverflow {
486                        shape: shape.to_vec(),
487                    })?;
488            let end = start
489                .checked_add(sample_len)
490                .ok_or_else(|| TensorError::SizeOverflow {
491                    shape: shape.to_vec(),
492                })?;
493
494            let mut sample =
495                Tensor::from_vec(shape[1..].to_vec(), inputs.data()[start..end].to_vec())?;
496            for op in &self.ops {
497                sample = apply_op(sample, op, &mut rng)?;
498            }
499            // Record the expected sample shape from the first sample and verify
500            // all subsequent samples match (they may differ from the original
501            // input shape when shape-changing ops like RandomCrop are used).
502            if sample_idx == 0 {
503                out_sample_shape = Some(sample.shape().to_vec());
504            } else if let Some(ref expected) = out_sample_shape
505                && sample.shape() != expected.as_slice()
506            {
507                return Err(ModelError::InvalidAugmentationArgument {
508                    operation: "pipeline",
509                    message: format!(
510                        "augmentation produced inconsistent sample shapes: first={:?}, sample[{sample_idx}]={:?}",
511                        expected,
512                        sample.shape()
513                    ),
514                });
515            }
516            out.extend_from_slice(sample.data());
517        }
518
519        let final_sample_shape = out_sample_shape.unwrap_or_else(|| shape[1..].to_vec());
520        let mut final_shape = vec![sample_count];
521        final_shape.extend_from_slice(&final_sample_shape);
522        Tensor::from_vec(final_shape, out).map_err(Into::into)
523    }
524}
525
526impl SupervisedDataset {
527    /// Returns a new dataset with NHWC image augmentations applied to `inputs`.
528    pub fn augment_nhwc(
529        &self,
530        pipeline: &ImageAugmentationPipeline,
531        seed: u64,
532    ) -> Result<Self, ModelError> {
533        let augmented_inputs = pipeline.apply_nhwc(self.inputs(), seed)?;
534        Self::new(augmented_inputs, self.targets().clone())
535    }
536}
537
538fn apply_op(
539    input: Tensor,
540    op: &ImageAugmentationOp,
541    rng: &mut LcgRng,
542) -> Result<Tensor, ModelError> {
543    match op {
544        ImageAugmentationOp::HorizontalFlip { probability } => {
545            if should_apply(*probability, rng) {
546                flip_horizontal(&input).map_err(Into::into)
547            } else {
548                Ok(input)
549            }
550        }
551        ImageAugmentationOp::VerticalFlip { probability } => {
552            if should_apply(*probability, rng) {
553                flip_vertical(&input).map_err(Into::into)
554            } else {
555                Ok(input)
556            }
557        }
558        ImageAugmentationOp::RandomRotate90 { probability } => {
559            if should_apply(*probability, rng) {
560                apply_random_rotate90(input, rng)
561            } else {
562                Ok(input)
563            }
564        }
565        ImageAugmentationOp::BrightnessJitter { max_delta } => {
566            if *max_delta == 0.0 {
567                return Ok(input);
568            }
569            let delta = rng.next_signed_unit() * *max_delta;
570            apply_brightness_delta(&input, delta)
571        }
572        ImageAugmentationOp::ContrastJitter { max_scale_delta } => {
573            if *max_scale_delta == 0.0 {
574                return Ok(input);
575            }
576            let scale = (1.0 + rng.next_signed_unit() * *max_scale_delta).max(0.0);
577            apply_contrast_scale(&input, scale)
578        }
579        ImageAugmentationOp::GammaJitter { max_gamma_delta } => {
580            if *max_gamma_delta == 0.0 {
581                return Ok(input);
582            }
583            let gamma = (1.0 + rng.next_signed_unit() * *max_gamma_delta).max(0.01);
584            apply_gamma_correction(&input, gamma)
585        }
586        ImageAugmentationOp::GaussianNoise {
587            probability,
588            std_dev,
589        } => {
590            if !should_apply(*probability, rng) || *std_dev == 0.0 {
591                return Ok(input);
592            }
593            apply_gaussian_noise(&input, *std_dev, rng)
594        }
595        ImageAugmentationOp::BoxBlur3x3 { probability } => {
596            if should_apply(*probability, rng) {
597                box_blur_3x3(&input).map_err(Into::into)
598            } else {
599                Ok(input)
600            }
601        }
602        ImageAugmentationOp::RandomResizedCrop {
603            probability,
604            min_scale,
605            max_scale,
606        } => {
607            if !should_apply(*probability, rng) {
608                return Ok(input);
609            }
610            apply_random_resized_crop(&input, *min_scale, *max_scale, rng)
611        }
612        ImageAugmentationOp::Cutout {
613            probability,
614            max_height_fraction,
615            max_width_fraction,
616            fill_value,
617        } => {
618            if !should_apply(*probability, rng) {
619                return Ok(input);
620            }
621            apply_cutout(
622                &input,
623                *max_height_fraction,
624                *max_width_fraction,
625                *fill_value,
626                rng,
627            )
628        }
629        ImageAugmentationOp::ChannelNormalize { mean, std } => {
630            normalize(&input, mean, std).map_err(Into::into)
631        }
632        ImageAugmentationOp::Custom(f) => f(&input),
633        ImageAugmentationOp::RandomCrop { height, width } => {
634            apply_random_crop(&input, *height, *width, rng)
635        }
636        ImageAugmentationOp::GaussianBlur { kernel_size } => {
637            apply_gaussian_blur(&input, *kernel_size)
638        }
639    }
640}
641
642fn apply_brightness_delta(input: &Tensor, delta: f32) -> Result<Tensor, ModelError> {
643    let mut output = Vec::with_capacity(input.data().len());
644    for value in input.data() {
645        output.push((*value + delta).clamp(0.0, 1.0));
646    }
647    Tensor::from_vec(input.shape().to_vec(), output).map_err(Into::into)
648}
649
650fn apply_random_rotate90(input: Tensor, rng: &mut LcgRng) -> Result<Tensor, ModelError> {
651    if input.rank() != 3 {
652        return Err(ModelError::InvalidAugmentationArgument {
653            operation: "random_rotate90",
654            message: format!("expected rank-3 HWC sample, got shape {:?}", input.shape()),
655        });
656    }
657
658    let height = input.shape()[0];
659    let width = input.shape()[1];
660    if height == 0 || width == 0 {
661        return Ok(input);
662    }
663
664    let rotation_count = if height == width {
665        rng.next_usize(4)
666    } else {
667        rng.next_usize(2) * 2
668    };
669
670    let mut rotated = input;
671    for _ in 0..rotation_count {
672        rotated = rotate90_cw(&rotated).map_err(ModelError::from)?;
673    }
674    Ok(rotated)
675}
676
677fn apply_contrast_scale(input: &Tensor, scale: f32) -> Result<Tensor, ModelError> {
678    let mean = if input.is_empty() {
679        0.0
680    } else {
681        input.data().iter().copied().sum::<f32>() / input.len() as f32
682    };
683
684    let mut output = Vec::with_capacity(input.data().len());
685    for value in input.data() {
686        let scaled = (*value - mean) * scale + mean;
687        output.push(scaled.clamp(0.0, 1.0));
688    }
689    Tensor::from_vec(input.shape().to_vec(), output).map_err(Into::into)
690}
691
692fn apply_gamma_correction(input: &Tensor, gamma: f32) -> Result<Tensor, ModelError> {
693    let mut output = Vec::with_capacity(input.data().len());
694    for value in input.data() {
695        output.push(value.clamp(0.0, 1.0).powf(gamma));
696    }
697    Tensor::from_vec(input.shape().to_vec(), output).map_err(Into::into)
698}
699
700fn apply_gaussian_noise(
701    input: &Tensor,
702    std_dev: f32,
703    rng: &mut LcgRng,
704) -> Result<Tensor, ModelError> {
705    let mut output = Vec::with_capacity(input.data().len());
706    for value in input.data() {
707        let noise = rng.next_gaussian() * std_dev;
708        output.push((*value + noise).clamp(0.0, 1.0));
709    }
710    Tensor::from_vec(input.shape().to_vec(), output).map_err(Into::into)
711}
712
713fn apply_cutout(
714    input: &Tensor,
715    max_height_fraction: f32,
716    max_width_fraction: f32,
717    fill_value: f32,
718    rng: &mut LcgRng,
719) -> Result<Tensor, ModelError> {
720    if input.rank() != 3 {
721        return Err(ModelError::InvalidAugmentationArgument {
722            operation: "cutout",
723            message: format!("expected rank-3 HWC sample, got shape {:?}", input.shape()),
724        });
725    }
726    let height = input.shape()[0];
727    let width = input.shape()[1];
728    let channels = input.shape()[2];
729    if height == 0 || width == 0 || channels == 0 {
730        return Ok(input.clone());
731    }
732
733    let max_cut_height = ((height as f32 * max_height_fraction).floor() as usize)
734        .max(1)
735        .min(height);
736    let max_cut_width = ((width as f32 * max_width_fraction).floor() as usize)
737        .max(1)
738        .min(width);
739
740    let cut_height = rng.next_usize_inclusive(max_cut_height - 1) + 1;
741    let cut_width = rng.next_usize_inclusive(max_cut_width - 1) + 1;
742    let top = rng.next_usize(height - cut_height + 1);
743    let left = rng.next_usize(width - cut_width + 1);
744
745    let mut output = input.data().to_vec();
746    for y in top..(top + cut_height) {
747        for x in left..(left + cut_width) {
748            let base = (y * width + x) * channels;
749            for channel in 0..channels {
750                output[base + channel] = fill_value;
751            }
752        }
753    }
754    Tensor::from_vec(input.shape().to_vec(), output).map_err(Into::into)
755}
756
757fn apply_random_resized_crop(
758    input: &Tensor,
759    min_scale: f32,
760    max_scale: f32,
761    rng: &mut LcgRng,
762) -> Result<Tensor, ModelError> {
763    if input.rank() != 3 {
764        return Err(ModelError::InvalidAugmentationArgument {
765            operation: "random_resized_crop",
766            message: format!("expected rank-3 HWC sample, got shape {:?}", input.shape()),
767        });
768    }
769    let height = input.shape()[0];
770    let width = input.shape()[1];
771    let channels = input.shape()[2];
772    if height == 0 || width == 0 || channels == 0 {
773        return Ok(input.clone());
774    }
775
776    let scale = if (max_scale - min_scale).abs() <= f32::EPSILON {
777        min_scale
778    } else {
779        min_scale + rng.next_unit() * (max_scale - min_scale)
780    };
781
782    let crop_height = ((height as f32 * scale).floor() as usize)
783        .max(1)
784        .min(height);
785    let crop_width = ((width as f32 * scale).floor() as usize).max(1).min(width);
786    let top = rng.next_usize(height - crop_height + 1);
787    let left = rng.next_usize(width - crop_width + 1);
788
789    let mut cropped = vec![0.0f32; crop_height * crop_width * channels];
790    for y in 0..crop_height {
791        for x in 0..crop_width {
792            let src_base = ((top + y) * width + (left + x)) * channels;
793            let dst_base = (y * crop_width + x) * channels;
794            cropped[dst_base..(dst_base + channels)]
795                .copy_from_slice(&input.data()[src_base..(src_base + channels)]);
796        }
797    }
798
799    let cropped_tensor = Tensor::from_vec(vec![crop_height, crop_width, channels], cropped)?;
800    resize_nearest(&cropped_tensor, height, width).map_err(Into::into)
801}
802
803fn apply_random_crop(
804    input: &Tensor,
805    crop_height: usize,
806    crop_width: usize,
807    rng: &mut LcgRng,
808) -> Result<Tensor, ModelError> {
809    if input.rank() != 3 {
810        return Err(ModelError::InvalidAugmentationArgument {
811            operation: "random_crop",
812            message: format!("expected rank-3 HWC sample, got shape {:?}", input.shape()),
813        });
814    }
815    let height = input.shape()[0];
816    let width = input.shape()[1];
817    let channels = input.shape()[2];
818
819    if crop_height > height || crop_width > width {
820        return Err(ModelError::InvalidAugmentationArgument {
821            operation: "random_crop",
822            message: format!(
823                "crop size ({crop_height}x{crop_width}) exceeds input size ({height}x{width})"
824            ),
825        });
826    }
827
828    let top = rng.next_usize(height - crop_height + 1);
829    let left = rng.next_usize(width - crop_width + 1);
830
831    let mut cropped = vec![0.0f32; crop_height * crop_width * channels];
832    for y in 0..crop_height {
833        for x in 0..crop_width {
834            let src_base = ((top + y) * width + (left + x)) * channels;
835            let dst_base = (y * crop_width + x) * channels;
836            cropped[dst_base..(dst_base + channels)]
837                .copy_from_slice(&input.data()[src_base..(src_base + channels)]);
838        }
839    }
840
841    Tensor::from_vec(vec![crop_height, crop_width, channels], cropped).map_err(Into::into)
842}
843
844fn apply_gaussian_blur(input: &Tensor, kernel_size: usize) -> Result<Tensor, ModelError> {
845    if input.rank() != 3 {
846        return Err(ModelError::InvalidAugmentationArgument {
847            operation: "gaussian_blur",
848            message: format!("expected rank-3 HWC sample, got shape {:?}", input.shape()),
849        });
850    }
851    let height = input.shape()[0];
852    let width = input.shape()[1];
853    let channels = input.shape()[2];
854    if height == 0 || width == 0 || channels == 0 {
855        return Ok(input.clone());
856    }
857
858    // Build 1-D Gaussian kernel weights.
859    let sigma = (kernel_size as f32 - 1.0) / 2.0 * 0.5 + 0.8;
860    let half = (kernel_size / 2) as isize;
861    let mut kernel = Vec::with_capacity(kernel_size);
862    let mut sum = 0.0f32;
863    for i in 0..kernel_size {
864        let x = (i as isize - half) as f32;
865        let w = (-x * x / (2.0 * sigma * sigma)).exp();
866        kernel.push(w);
867        sum += w;
868    }
869    for w in &mut kernel {
870        *w /= sum;
871    }
872
873    // Horizontal pass.
874    let mut tmp = vec![0.0f32; height * width * channels];
875    for y in 0..height {
876        for x in 0..width {
877            for c in 0..channels {
878                let mut acc = 0.0f32;
879                for ki in 0..kernel_size {
880                    let sx =
881                        (x as isize + ki as isize - half).clamp(0, width as isize - 1) as usize;
882                    acc += input.data()[(y * width + sx) * channels + c] * kernel[ki];
883                }
884                tmp[(y * width + x) * channels + c] = acc;
885            }
886        }
887    }
888
889    // Vertical pass.
890    let mut out = vec![0.0f32; height * width * channels];
891    for y in 0..height {
892        for x in 0..width {
893            for c in 0..channels {
894                let mut acc = 0.0f32;
895                for ki in 0..kernel_size {
896                    let sy =
897                        (y as isize + ki as isize - half).clamp(0, height as isize - 1) as usize;
898                    acc += tmp[(sy * width + x) * channels + c] * kernel[ki];
899                }
900                out[(y * width + x) * channels + c] = acc;
901            }
902        }
903    }
904
905    Tensor::from_vec(vec![height, width, channels], out).map_err(Into::into)
906}
907
908fn validate_probability(operation: &'static str, probability: f32) -> Result<(), ModelError> {
909    if !probability.is_finite() || !(0.0..=1.0).contains(&probability) {
910        return Err(ModelError::InvalidAugmentationProbability {
911            operation,
912            value: probability,
913        });
914    }
915    Ok(())
916}
917
918fn validate_fraction(
919    operation: &'static str,
920    parameter: &'static str,
921    value: f32,
922) -> Result<(), ModelError> {
923    if !value.is_finite() || value <= 0.0 || value > 1.0 {
924        return Err(ModelError::InvalidAugmentationArgument {
925            operation,
926            message: format!("{parameter} must be finite in (0, 1], got {value}"),
927        });
928    }
929    Ok(())
930}
931
932fn should_apply(probability: f32, rng: &mut LcgRng) -> bool {
933    if probability <= 0.0 {
934        return false;
935    }
936    if probability >= 1.0 {
937        return true;
938    }
939    rng.next_unit() < probability
940}
941
/// Small deterministic pseudo-random generator that drives augmentation sampling.
///
/// The state advances as a 64-bit linear congruential generator
/// (`state = state * MULTIPLIER + INCREMENT`, with wrapping arithmetic). Each draw exposes the
/// upper 32 bits of the freshly advanced state.
#[derive(Debug, Clone, Copy)]
struct LcgRng {
    state: u64,
}

impl LcgRng {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1;
    /// Largest `f32` strictly below `1.0`; used to keep unit draws half-open.
    const UNIT_UPPER: f32 = 1.0 - f32::EPSILON / 2.0;

    /// Creates a generator whose initial state is `seed`.
    fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advances the state once and returns its upper 32 bits.
    fn next_u32(&mut self) -> u32 {
        self.state = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        (self.state >> 32) as u32
    }

    /// Returns a value in the half-open interval `[0, 1)`.
    ///
    /// Converting a `u32` to `f32` rounds to 24 significant bits, so draws within 128 of
    /// `u32::MAX` round up to `2^32` and the plain quotient would be exactly `1.0`. The result
    /// is therefore clamped to the largest `f32` below one; every other draw is unaffected.
    fn next_unit(&mut self) -> f32 {
        let unit = self.next_u32() as f32 / (u32::MAX as f32 + 1.0);
        unit.min(Self::UNIT_UPPER)
    }

    /// Returns a value in `[-1, 1)` derived from one unit draw.
    fn next_signed_unit(&mut self) -> f32 {
        self.next_unit() * 2.0 - 1.0
    }

    /// Returns a standard-normal sample via the Box-Muller transform (two unit draws).
    fn next_gaussian(&mut self) -> f32 {
        // Keep the first draw strictly positive so the logarithm stays finite.
        let u1 = self.next_unit().max(f32::MIN_POSITIVE);
        let u2 = self.next_unit();
        let magnitude = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f32::consts::PI * u2;
        magnitude * angle.cos()
    }

    /// Returns an index in `0..=upper_inclusive`.
    fn next_usize_inclusive(&mut self, upper_inclusive: usize) -> usize {
        self.next_usize(upper_inclusive.saturating_add(1))
    }

    /// Returns an index in `0..upper_exclusive`, or `0` without drawing when the range is empty.
    fn next_usize(&mut self, upper_exclusive: usize) -> usize {
        if upper_exclusive == 0 {
            return 0;
        }
        (self.next_u32() as usize) % upper_exclusive
    }
}