1use std::fmt;
2use std::sync::Arc;
3
4use yscv_imgproc::{
5 box_blur_3x3, flip_horizontal, flip_vertical, normalize, resize_nearest, rotate90_cw,
6};
7use yscv_tensor::{Tensor, TensorError};
8
9use super::transform::Transform;
10use crate::{ModelError, SupervisedDataset};
11
/// A single augmentation step applied to one sample at a time.
///
/// [`ImageAugmentationPipeline::apply_nhwc`] splits a rank-4 batch into per-sample
/// tensors of shape `shape[1..]` and runs each op on them in order. Most built-in
/// ops additionally require a rank-3 (HWC) sample. Parameters are validated when the
/// pipeline is constructed with [`ImageAugmentationPipeline::new`].
pub enum ImageAugmentationOp {
    /// Flips the sample with `flip_horizontal` with the given probability.
    HorizontalFlip { probability: f32 },
    /// Flips the sample with `flip_vertical` with the given probability.
    VerticalFlip { probability: f32 },
    /// With the given probability, rotates clockwise by a random number of quarter
    /// turns: 0-3 for square samples, 0 or 2 otherwise so the shape never changes.
    RandomRotate90 { probability: f32 },
    /// Adds one random offset in `[-max_delta, max_delta]` to every value of the
    /// sample and clamps to `[0, 1]`. Always applied; `max_delta == 0.0` is a no-op.
    BrightnessJitter { max_delta: f32 },
    /// Scales each value's distance from the sample-wide mean by a random factor
    /// `1 + u * max_scale_delta` (`u` in `[-1, 1]`, factor floored at 0), then
    /// clamps to `[0, 1]`. Always applied; zero is a no-op.
    ContrastJitter { max_scale_delta: f32 },
    /// Raises each value (clamped to `[0, 1]`) to a random power
    /// `1 + u * max_gamma_delta` (`u` in `[-1, 1]`, floored at 0.01). Always
    /// applied; zero is a no-op.
    GammaJitter { max_gamma_delta: f32 },
    /// With the given probability, adds independent zero-mean Gaussian noise with
    /// standard deviation `std_dev` to every value, then clamps to `[0, 1]`.
    GaussianNoise { probability: f32, std_dev: f32 },
    /// Applies `box_blur_3x3` with the given probability.
    BoxBlur3x3 { probability: f32 },
    /// With the given probability, crops a random window whose height and width are
    /// each `scale` times the sample's (`scale` drawn from `[min_scale, max_scale]`,
    /// both in `(0, 1]`) and resizes it back with `resize_nearest`. `scale` applies
    /// per side, so the kept area is roughly `scale * scale` of the original.
    RandomResizedCrop {
        probability: f32,
        min_scale: f32,
        max_scale: f32,
    },
    /// With the given probability, overwrites one random rectangle (at least one
    /// pixel, at most the given fractions of height/width) with `fill_value` in
    /// every channel.
    Cutout {
        probability: f32,
        max_height_fraction: f32,
        max_width_fraction: f32,
        fill_value: f32,
    },
    /// Normalises with `normalize` using per-channel `mean` and `std` (equal,
    /// non-zero lengths, finite values, every `std` > 0). Always applied.
    ChannelNormalize { mean: Vec<f32>, std: Vec<f32> },
    /// A user-supplied per-sample transform. `Custom` ops never compare equal, and
    /// `Debug` prints the closure as `"<closure>"`.
    Custom(Arc<dyn Fn(&Tensor) -> Result<Tensor, ModelError> + Send + Sync>),
    /// Crops a random `height` x `width` window without resizing, so the output
    /// shape changes. Fails if the window is larger than the sample. Always applied.
    RandomCrop { height: usize, width: usize },
    /// Separable Gaussian blur with an odd `kernel_size`. Taps past the image edge
    /// reuse the nearest border pixel. Always applied.
    GaussianBlur { kernel_size: usize },
}
55
56impl fmt::Debug for ImageAugmentationOp {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 match self {
59 Self::HorizontalFlip { probability } => f
60 .debug_struct("HorizontalFlip")
61 .field("probability", probability)
62 .finish(),
63 Self::VerticalFlip { probability } => f
64 .debug_struct("VerticalFlip")
65 .field("probability", probability)
66 .finish(),
67 Self::RandomRotate90 { probability } => f
68 .debug_struct("RandomRotate90")
69 .field("probability", probability)
70 .finish(),
71 Self::BrightnessJitter { max_delta } => f
72 .debug_struct("BrightnessJitter")
73 .field("max_delta", max_delta)
74 .finish(),
75 Self::ContrastJitter { max_scale_delta } => f
76 .debug_struct("ContrastJitter")
77 .field("max_scale_delta", max_scale_delta)
78 .finish(),
79 Self::GammaJitter { max_gamma_delta } => f
80 .debug_struct("GammaJitter")
81 .field("max_gamma_delta", max_gamma_delta)
82 .finish(),
83 Self::GaussianNoise {
84 probability,
85 std_dev,
86 } => f
87 .debug_struct("GaussianNoise")
88 .field("probability", probability)
89 .field("std_dev", std_dev)
90 .finish(),
91 Self::BoxBlur3x3 { probability } => f
92 .debug_struct("BoxBlur3x3")
93 .field("probability", probability)
94 .finish(),
95 Self::RandomResizedCrop {
96 probability,
97 min_scale,
98 max_scale,
99 } => f
100 .debug_struct("RandomResizedCrop")
101 .field("probability", probability)
102 .field("min_scale", min_scale)
103 .field("max_scale", max_scale)
104 .finish(),
105 Self::Cutout {
106 probability,
107 max_height_fraction,
108 max_width_fraction,
109 fill_value,
110 } => f
111 .debug_struct("Cutout")
112 .field("probability", probability)
113 .field("max_height_fraction", max_height_fraction)
114 .field("max_width_fraction", max_width_fraction)
115 .field("fill_value", fill_value)
116 .finish(),
117 Self::ChannelNormalize { mean, std } => f
118 .debug_struct("ChannelNormalize")
119 .field("mean", mean)
120 .field("std", std)
121 .finish(),
122 Self::Custom(_) => f.debug_tuple("Custom").field(&"<closure>").finish(),
123 Self::RandomCrop { height, width } => f
124 .debug_struct("RandomCrop")
125 .field("height", height)
126 .field("width", width)
127 .finish(),
128 Self::GaussianBlur { kernel_size } => f
129 .debug_struct("GaussianBlur")
130 .field("kernel_size", kernel_size)
131 .finish(),
132 }
133 }
134}
135
136impl Clone for ImageAugmentationOp {
137 fn clone(&self) -> Self {
138 match self {
139 Self::HorizontalFlip { probability } => Self::HorizontalFlip {
140 probability: *probability,
141 },
142 Self::VerticalFlip { probability } => Self::VerticalFlip {
143 probability: *probability,
144 },
145 Self::RandomRotate90 { probability } => Self::RandomRotate90 {
146 probability: *probability,
147 },
148 Self::BrightnessJitter { max_delta } => Self::BrightnessJitter {
149 max_delta: *max_delta,
150 },
151 Self::ContrastJitter { max_scale_delta } => Self::ContrastJitter {
152 max_scale_delta: *max_scale_delta,
153 },
154 Self::GammaJitter { max_gamma_delta } => Self::GammaJitter {
155 max_gamma_delta: *max_gamma_delta,
156 },
157 Self::GaussianNoise {
158 probability,
159 std_dev,
160 } => Self::GaussianNoise {
161 probability: *probability,
162 std_dev: *std_dev,
163 },
164 Self::BoxBlur3x3 { probability } => Self::BoxBlur3x3 {
165 probability: *probability,
166 },
167 Self::RandomResizedCrop {
168 probability,
169 min_scale,
170 max_scale,
171 } => Self::RandomResizedCrop {
172 probability: *probability,
173 min_scale: *min_scale,
174 max_scale: *max_scale,
175 },
176 Self::Cutout {
177 probability,
178 max_height_fraction,
179 max_width_fraction,
180 fill_value,
181 } => Self::Cutout {
182 probability: *probability,
183 max_height_fraction: *max_height_fraction,
184 max_width_fraction: *max_width_fraction,
185 fill_value: *fill_value,
186 },
187 Self::ChannelNormalize { mean, std } => Self::ChannelNormalize {
188 mean: mean.clone(),
189 std: std.clone(),
190 },
191 Self::Custom(f) => Self::Custom(Arc::clone(f)),
192 Self::RandomCrop { height, width } => Self::RandomCrop {
193 height: *height,
194 width: *width,
195 },
196 Self::GaussianBlur { kernel_size } => Self::GaussianBlur {
197 kernel_size: *kernel_size,
198 },
199 }
200 }
201}
202
203impl PartialEq for ImageAugmentationOp {
204 fn eq(&self, other: &Self) -> bool {
205 match (self, other) {
206 (Self::HorizontalFlip { probability: a }, Self::HorizontalFlip { probability: b }) => {
207 a == b
208 }
209 (Self::VerticalFlip { probability: a }, Self::VerticalFlip { probability: b }) => {
210 a == b
211 }
212 (Self::RandomRotate90 { probability: a }, Self::RandomRotate90 { probability: b }) => {
213 a == b
214 }
215 (Self::BrightnessJitter { max_delta: a }, Self::BrightnessJitter { max_delta: b }) => {
216 a == b
217 }
218 (
219 Self::ContrastJitter { max_scale_delta: a },
220 Self::ContrastJitter { max_scale_delta: b },
221 ) => a == b,
222 (
223 Self::GammaJitter { max_gamma_delta: a },
224 Self::GammaJitter { max_gamma_delta: b },
225 ) => a == b,
226 (
227 Self::GaussianNoise {
228 probability: p1,
229 std_dev: s1,
230 },
231 Self::GaussianNoise {
232 probability: p2,
233 std_dev: s2,
234 },
235 ) => p1 == p2 && s1 == s2,
236 (Self::BoxBlur3x3 { probability: a }, Self::BoxBlur3x3 { probability: b }) => a == b,
237 (
238 Self::RandomResizedCrop {
239 probability: p1,
240 min_scale: mn1,
241 max_scale: mx1,
242 },
243 Self::RandomResizedCrop {
244 probability: p2,
245 min_scale: mn2,
246 max_scale: mx2,
247 },
248 ) => p1 == p2 && mn1 == mn2 && mx1 == mx2,
249 (
250 Self::Cutout {
251 probability: p1,
252 max_height_fraction: h1,
253 max_width_fraction: w1,
254 fill_value: f1,
255 },
256 Self::Cutout {
257 probability: p2,
258 max_height_fraction: h2,
259 max_width_fraction: w2,
260 fill_value: f2,
261 },
262 ) => p1 == p2 && h1 == h2 && w1 == w2 && f1 == f2,
263 (
264 Self::ChannelNormalize { mean: m1, std: s1 },
265 Self::ChannelNormalize { mean: m2, std: s2 },
266 ) => m1 == m2 && s1 == s2,
267 (Self::Custom(_), Self::Custom(_)) => false, (
269 Self::RandomCrop {
270 height: h1,
271 width: w1,
272 },
273 Self::RandomCrop {
274 height: h2,
275 width: w2,
276 },
277 ) => h1 == h2 && w1 == w2,
278 (Self::GaussianBlur { kernel_size: a }, Self::GaussianBlur { kernel_size: b }) => {
279 a == b
280 }
281 _ => false,
282 }
283 }
284}
285
286impl ImageAugmentationOp {
287 fn validate(&self) -> Result<(), ModelError> {
288 match self {
289 Self::HorizontalFlip { probability } => {
290 validate_probability("horizontal_flip", *probability)
291 }
292 Self::VerticalFlip { probability } => {
293 validate_probability("vertical_flip", *probability)
294 }
295 Self::RandomRotate90 { probability } => {
296 validate_probability("random_rotate90", *probability)
297 }
298 Self::BrightnessJitter { max_delta } => {
299 if !max_delta.is_finite() || *max_delta < 0.0 {
300 return Err(ModelError::InvalidAugmentationArgument {
301 operation: "brightness_jitter",
302 message: format!("max_delta must be finite and >= 0, got {max_delta}"),
303 });
304 }
305 Ok(())
306 }
307 Self::ContrastJitter { max_scale_delta } => {
308 if !max_scale_delta.is_finite() || *max_scale_delta < 0.0 {
309 return Err(ModelError::InvalidAugmentationArgument {
310 operation: "contrast_jitter",
311 message: format!(
312 "max_scale_delta must be finite and >= 0, got {max_scale_delta}"
313 ),
314 });
315 }
316 Ok(())
317 }
318 Self::GammaJitter { max_gamma_delta } => {
319 if !max_gamma_delta.is_finite() || *max_gamma_delta < 0.0 {
320 return Err(ModelError::InvalidAugmentationArgument {
321 operation: "gamma_jitter",
322 message: format!(
323 "max_gamma_delta must be finite and >= 0, got {max_gamma_delta}"
324 ),
325 });
326 }
327 Ok(())
328 }
329 Self::GaussianNoise {
330 probability,
331 std_dev,
332 } => {
333 validate_probability("gaussian_noise", *probability)?;
334 if !std_dev.is_finite() || *std_dev < 0.0 {
335 return Err(ModelError::InvalidAugmentationArgument {
336 operation: "gaussian_noise",
337 message: format!("std_dev must be finite and >= 0, got {std_dev}"),
338 });
339 }
340 Ok(())
341 }
342 Self::BoxBlur3x3 { probability } => validate_probability("box_blur_3x3", *probability),
343 Self::RandomResizedCrop {
344 probability,
345 min_scale,
346 max_scale,
347 } => {
348 validate_probability("random_resized_crop", *probability)?;
349 validate_fraction("random_resized_crop", "min_scale", *min_scale)?;
350 validate_fraction("random_resized_crop", "max_scale", *max_scale)?;
351 if min_scale > max_scale {
352 return Err(ModelError::InvalidAugmentationArgument {
353 operation: "random_resized_crop",
354 message: format!(
355 "min_scale must be <= max_scale, got min_scale={min_scale}, max_scale={max_scale}"
356 ),
357 });
358 }
359 Ok(())
360 }
361 Self::Cutout {
362 probability,
363 max_height_fraction,
364 max_width_fraction,
365 fill_value,
366 } => {
367 validate_probability("cutout", *probability)?;
368 validate_fraction("cutout", "max_height_fraction", *max_height_fraction)?;
369 validate_fraction("cutout", "max_width_fraction", *max_width_fraction)?;
370 if !fill_value.is_finite() {
371 return Err(ModelError::InvalidAugmentationArgument {
372 operation: "cutout",
373 message: format!("fill_value must be finite, got {fill_value}"),
374 });
375 }
376 Ok(())
377 }
378 Self::ChannelNormalize { mean, std } => {
379 if mean.is_empty() || std.is_empty() {
380 return Err(ModelError::InvalidAugmentationArgument {
381 operation: "channel_normalize",
382 message: "mean/std must be non-empty".to_string(),
383 });
384 }
385 if mean.len() != std.len() {
386 return Err(ModelError::InvalidAugmentationArgument {
387 operation: "channel_normalize",
388 message: format!(
389 "mean/std length mismatch: mean_len={}, std_len={}",
390 mean.len(),
391 std.len()
392 ),
393 });
394 }
395 for (channel, mean_value) in mean.iter().enumerate() {
396 if !mean_value.is_finite() {
397 return Err(ModelError::InvalidAugmentationArgument {
398 operation: "channel_normalize",
399 message: format!("mean[{channel}] must be finite"),
400 });
401 }
402 }
403 for (channel, std_value) in std.iter().enumerate() {
404 if !std_value.is_finite() || *std_value <= 0.0 {
405 return Err(ModelError::InvalidAugmentationArgument {
406 operation: "channel_normalize",
407 message: format!("std[{channel}] must be finite and > 0"),
408 });
409 }
410 }
411 Ok(())
412 }
413 Self::Custom(_) => Ok(()),
414 Self::RandomCrop { height, width } => {
415 if *height == 0 || *width == 0 {
416 return Err(ModelError::InvalidAugmentationArgument {
417 operation: "random_crop",
418 message: format!(
419 "height and width must be > 0, got height={height}, width={width}"
420 ),
421 });
422 }
423 Ok(())
424 }
425 Self::GaussianBlur { kernel_size } => {
426 if *kernel_size == 0 || *kernel_size % 2 == 0 {
427 return Err(ModelError::InvalidAugmentationArgument {
428 operation: "gaussian_blur",
429 message: format!("kernel_size must be odd and >= 1, got {kernel_size}"),
430 });
431 }
432 Ok(())
433 }
434 }
435 }
436
437 pub fn from_transform<T: Transform + Send + Sync + 'static>(t: T) -> Self {
439 Self::Custom(Arc::new(move |input| t.apply(input)))
440 }
441}
442
/// An ordered list of [`ImageAugmentationOp`]s, validated on construction by
/// [`ImageAugmentationPipeline::new`] and applied sample-by-sample to NHWC batches.
///
/// Equality is derived from the ops, so a pipeline containing a `Custom` op never
/// compares equal to anything, itself included.
#[derive(Debug, Clone, PartialEq)]
pub struct ImageAugmentationPipeline {
    // Ops in application order; each one passed `validate` in `new`.
    ops: Vec<ImageAugmentationOp>,
}
448
449impl ImageAugmentationPipeline {
450 pub fn new(ops: Vec<ImageAugmentationOp>) -> Result<Self, ModelError> {
451 for op in &ops {
452 op.validate()?;
453 }
454 Ok(Self { ops })
455 }
456
457 pub fn ops(&self) -> &[ImageAugmentationOp] {
458 &self.ops
459 }
460
461 pub fn apply_nhwc(&self, inputs: &Tensor, seed: u64) -> Result<Tensor, ModelError> {
462 if inputs.rank() != 4 {
463 return Err(ModelError::InvalidAugmentationInputShape {
464 got: inputs.shape().to_vec(),
465 });
466 }
467
468 let shape = inputs.shape();
469 let sample_count = shape[0];
470 let sample_len = shape[1..].iter().try_fold(1usize, |acc, dim| {
471 acc.checked_mul(*dim)
472 .ok_or_else(|| TensorError::SizeOverflow {
473 shape: shape.to_vec(),
474 })
475 })?;
476
477 let mut out = Vec::with_capacity(inputs.data().len());
478 let mut rng = LcgRng::new(seed);
479 let mut out_sample_shape: Option<Vec<usize>> = None;
480
481 for sample_idx in 0..sample_count {
482 let start =
483 sample_idx
484 .checked_mul(sample_len)
485 .ok_or_else(|| TensorError::SizeOverflow {
486 shape: shape.to_vec(),
487 })?;
488 let end = start
489 .checked_add(sample_len)
490 .ok_or_else(|| TensorError::SizeOverflow {
491 shape: shape.to_vec(),
492 })?;
493
494 let mut sample =
495 Tensor::from_vec(shape[1..].to_vec(), inputs.data()[start..end].to_vec())?;
496 for op in &self.ops {
497 sample = apply_op(sample, op, &mut rng)?;
498 }
499 if sample_idx == 0 {
503 out_sample_shape = Some(sample.shape().to_vec());
504 } else if let Some(ref expected) = out_sample_shape
505 && sample.shape() != expected.as_slice()
506 {
507 return Err(ModelError::InvalidAugmentationArgument {
508 operation: "pipeline",
509 message: format!(
510 "augmentation produced inconsistent sample shapes: first={:?}, sample[{sample_idx}]={:?}",
511 expected,
512 sample.shape()
513 ),
514 });
515 }
516 out.extend_from_slice(sample.data());
517 }
518
519 let final_sample_shape = out_sample_shape.unwrap_or_else(|| shape[1..].to_vec());
520 let mut final_shape = vec![sample_count];
521 final_shape.extend_from_slice(&final_sample_shape);
522 Tensor::from_vec(final_shape, out).map_err(Into::into)
523 }
524}
525
526impl SupervisedDataset {
527 pub fn augment_nhwc(
529 &self,
530 pipeline: &ImageAugmentationPipeline,
531 seed: u64,
532 ) -> Result<Self, ModelError> {
533 let augmented_inputs = pipeline.apply_nhwc(self.inputs(), seed)?;
534 Self::new(augmented_inputs, self.targets().clone())
535 }
536}
537
538fn apply_op(
539 input: Tensor,
540 op: &ImageAugmentationOp,
541 rng: &mut LcgRng,
542) -> Result<Tensor, ModelError> {
543 match op {
544 ImageAugmentationOp::HorizontalFlip { probability } => {
545 if should_apply(*probability, rng) {
546 flip_horizontal(&input).map_err(Into::into)
547 } else {
548 Ok(input)
549 }
550 }
551 ImageAugmentationOp::VerticalFlip { probability } => {
552 if should_apply(*probability, rng) {
553 flip_vertical(&input).map_err(Into::into)
554 } else {
555 Ok(input)
556 }
557 }
558 ImageAugmentationOp::RandomRotate90 { probability } => {
559 if should_apply(*probability, rng) {
560 apply_random_rotate90(input, rng)
561 } else {
562 Ok(input)
563 }
564 }
565 ImageAugmentationOp::BrightnessJitter { max_delta } => {
566 if *max_delta == 0.0 {
567 return Ok(input);
568 }
569 let delta = rng.next_signed_unit() * *max_delta;
570 apply_brightness_delta(&input, delta)
571 }
572 ImageAugmentationOp::ContrastJitter { max_scale_delta } => {
573 if *max_scale_delta == 0.0 {
574 return Ok(input);
575 }
576 let scale = (1.0 + rng.next_signed_unit() * *max_scale_delta).max(0.0);
577 apply_contrast_scale(&input, scale)
578 }
579 ImageAugmentationOp::GammaJitter { max_gamma_delta } => {
580 if *max_gamma_delta == 0.0 {
581 return Ok(input);
582 }
583 let gamma = (1.0 + rng.next_signed_unit() * *max_gamma_delta).max(0.01);
584 apply_gamma_correction(&input, gamma)
585 }
586 ImageAugmentationOp::GaussianNoise {
587 probability,
588 std_dev,
589 } => {
590 if !should_apply(*probability, rng) || *std_dev == 0.0 {
591 return Ok(input);
592 }
593 apply_gaussian_noise(&input, *std_dev, rng)
594 }
595 ImageAugmentationOp::BoxBlur3x3 { probability } => {
596 if should_apply(*probability, rng) {
597 box_blur_3x3(&input).map_err(Into::into)
598 } else {
599 Ok(input)
600 }
601 }
602 ImageAugmentationOp::RandomResizedCrop {
603 probability,
604 min_scale,
605 max_scale,
606 } => {
607 if !should_apply(*probability, rng) {
608 return Ok(input);
609 }
610 apply_random_resized_crop(&input, *min_scale, *max_scale, rng)
611 }
612 ImageAugmentationOp::Cutout {
613 probability,
614 max_height_fraction,
615 max_width_fraction,
616 fill_value,
617 } => {
618 if !should_apply(*probability, rng) {
619 return Ok(input);
620 }
621 apply_cutout(
622 &input,
623 *max_height_fraction,
624 *max_width_fraction,
625 *fill_value,
626 rng,
627 )
628 }
629 ImageAugmentationOp::ChannelNormalize { mean, std } => {
630 normalize(&input, mean, std).map_err(Into::into)
631 }
632 ImageAugmentationOp::Custom(f) => f(&input),
633 ImageAugmentationOp::RandomCrop { height, width } => {
634 apply_random_crop(&input, *height, *width, rng)
635 }
636 ImageAugmentationOp::GaussianBlur { kernel_size } => {
637 apply_gaussian_blur(&input, *kernel_size)
638 }
639 }
640}
641
642fn apply_brightness_delta(input: &Tensor, delta: f32) -> Result<Tensor, ModelError> {
643 let mut output = Vec::with_capacity(input.data().len());
644 for value in input.data() {
645 output.push((*value + delta).clamp(0.0, 1.0));
646 }
647 Tensor::from_vec(input.shape().to_vec(), output).map_err(Into::into)
648}
649
650fn apply_random_rotate90(input: Tensor, rng: &mut LcgRng) -> Result<Tensor, ModelError> {
651 if input.rank() != 3 {
652 return Err(ModelError::InvalidAugmentationArgument {
653 operation: "random_rotate90",
654 message: format!("expected rank-3 HWC sample, got shape {:?}", input.shape()),
655 });
656 }
657
658 let height = input.shape()[0];
659 let width = input.shape()[1];
660 if height == 0 || width == 0 {
661 return Ok(input);
662 }
663
664 let rotation_count = if height == width {
665 rng.next_usize(4)
666 } else {
667 rng.next_usize(2) * 2
668 };
669
670 let mut rotated = input;
671 for _ in 0..rotation_count {
672 rotated = rotate90_cw(&rotated).map_err(ModelError::from)?;
673 }
674 Ok(rotated)
675}
676
677fn apply_contrast_scale(input: &Tensor, scale: f32) -> Result<Tensor, ModelError> {
678 let mean = if input.is_empty() {
679 0.0
680 } else {
681 input.data().iter().copied().sum::<f32>() / input.len() as f32
682 };
683
684 let mut output = Vec::with_capacity(input.data().len());
685 for value in input.data() {
686 let scaled = (*value - mean) * scale + mean;
687 output.push(scaled.clamp(0.0, 1.0));
688 }
689 Tensor::from_vec(input.shape().to_vec(), output).map_err(Into::into)
690}
691
692fn apply_gamma_correction(input: &Tensor, gamma: f32) -> Result<Tensor, ModelError> {
693 let mut output = Vec::with_capacity(input.data().len());
694 for value in input.data() {
695 output.push(value.clamp(0.0, 1.0).powf(gamma));
696 }
697 Tensor::from_vec(input.shape().to_vec(), output).map_err(Into::into)
698}
699
700fn apply_gaussian_noise(
701 input: &Tensor,
702 std_dev: f32,
703 rng: &mut LcgRng,
704) -> Result<Tensor, ModelError> {
705 let mut output = Vec::with_capacity(input.data().len());
706 for value in input.data() {
707 let noise = rng.next_gaussian() * std_dev;
708 output.push((*value + noise).clamp(0.0, 1.0));
709 }
710 Tensor::from_vec(input.shape().to_vec(), output).map_err(Into::into)
711}
712
713fn apply_cutout(
714 input: &Tensor,
715 max_height_fraction: f32,
716 max_width_fraction: f32,
717 fill_value: f32,
718 rng: &mut LcgRng,
719) -> Result<Tensor, ModelError> {
720 if input.rank() != 3 {
721 return Err(ModelError::InvalidAugmentationArgument {
722 operation: "cutout",
723 message: format!("expected rank-3 HWC sample, got shape {:?}", input.shape()),
724 });
725 }
726 let height = input.shape()[0];
727 let width = input.shape()[1];
728 let channels = input.shape()[2];
729 if height == 0 || width == 0 || channels == 0 {
730 return Ok(input.clone());
731 }
732
733 let max_cut_height = ((height as f32 * max_height_fraction).floor() as usize)
734 .max(1)
735 .min(height);
736 let max_cut_width = ((width as f32 * max_width_fraction).floor() as usize)
737 .max(1)
738 .min(width);
739
740 let cut_height = rng.next_usize_inclusive(max_cut_height - 1) + 1;
741 let cut_width = rng.next_usize_inclusive(max_cut_width - 1) + 1;
742 let top = rng.next_usize(height - cut_height + 1);
743 let left = rng.next_usize(width - cut_width + 1);
744
745 let mut output = input.data().to_vec();
746 for y in top..(top + cut_height) {
747 for x in left..(left + cut_width) {
748 let base = (y * width + x) * channels;
749 for channel in 0..channels {
750 output[base + channel] = fill_value;
751 }
752 }
753 }
754 Tensor::from_vec(input.shape().to_vec(), output).map_err(Into::into)
755}
756
757fn apply_random_resized_crop(
758 input: &Tensor,
759 min_scale: f32,
760 max_scale: f32,
761 rng: &mut LcgRng,
762) -> Result<Tensor, ModelError> {
763 if input.rank() != 3 {
764 return Err(ModelError::InvalidAugmentationArgument {
765 operation: "random_resized_crop",
766 message: format!("expected rank-3 HWC sample, got shape {:?}", input.shape()),
767 });
768 }
769 let height = input.shape()[0];
770 let width = input.shape()[1];
771 let channels = input.shape()[2];
772 if height == 0 || width == 0 || channels == 0 {
773 return Ok(input.clone());
774 }
775
776 let scale = if (max_scale - min_scale).abs() <= f32::EPSILON {
777 min_scale
778 } else {
779 min_scale + rng.next_unit() * (max_scale - min_scale)
780 };
781
782 let crop_height = ((height as f32 * scale).floor() as usize)
783 .max(1)
784 .min(height);
785 let crop_width = ((width as f32 * scale).floor() as usize).max(1).min(width);
786 let top = rng.next_usize(height - crop_height + 1);
787 let left = rng.next_usize(width - crop_width + 1);
788
789 let mut cropped = vec![0.0f32; crop_height * crop_width * channels];
790 for y in 0..crop_height {
791 for x in 0..crop_width {
792 let src_base = ((top + y) * width + (left + x)) * channels;
793 let dst_base = (y * crop_width + x) * channels;
794 cropped[dst_base..(dst_base + channels)]
795 .copy_from_slice(&input.data()[src_base..(src_base + channels)]);
796 }
797 }
798
799 let cropped_tensor = Tensor::from_vec(vec![crop_height, crop_width, channels], cropped)?;
800 resize_nearest(&cropped_tensor, height, width).map_err(Into::into)
801}
802
803fn apply_random_crop(
804 input: &Tensor,
805 crop_height: usize,
806 crop_width: usize,
807 rng: &mut LcgRng,
808) -> Result<Tensor, ModelError> {
809 if input.rank() != 3 {
810 return Err(ModelError::InvalidAugmentationArgument {
811 operation: "random_crop",
812 message: format!("expected rank-3 HWC sample, got shape {:?}", input.shape()),
813 });
814 }
815 let height = input.shape()[0];
816 let width = input.shape()[1];
817 let channels = input.shape()[2];
818
819 if crop_height > height || crop_width > width {
820 return Err(ModelError::InvalidAugmentationArgument {
821 operation: "random_crop",
822 message: format!(
823 "crop size ({crop_height}x{crop_width}) exceeds input size ({height}x{width})"
824 ),
825 });
826 }
827
828 let top = rng.next_usize(height - crop_height + 1);
829 let left = rng.next_usize(width - crop_width + 1);
830
831 let mut cropped = vec![0.0f32; crop_height * crop_width * channels];
832 for y in 0..crop_height {
833 for x in 0..crop_width {
834 let src_base = ((top + y) * width + (left + x)) * channels;
835 let dst_base = (y * crop_width + x) * channels;
836 cropped[dst_base..(dst_base + channels)]
837 .copy_from_slice(&input.data()[src_base..(src_base + channels)]);
838 }
839 }
840
841 Tensor::from_vec(vec![crop_height, crop_width, channels], cropped).map_err(Into::into)
842}
843
844fn apply_gaussian_blur(input: &Tensor, kernel_size: usize) -> Result<Tensor, ModelError> {
845 if input.rank() != 3 {
846 return Err(ModelError::InvalidAugmentationArgument {
847 operation: "gaussian_blur",
848 message: format!("expected rank-3 HWC sample, got shape {:?}", input.shape()),
849 });
850 }
851 let height = input.shape()[0];
852 let width = input.shape()[1];
853 let channels = input.shape()[2];
854 if height == 0 || width == 0 || channels == 0 {
855 return Ok(input.clone());
856 }
857
858 let sigma = (kernel_size as f32 - 1.0) / 2.0 * 0.5 + 0.8;
860 let half = (kernel_size / 2) as isize;
861 let mut kernel = Vec::with_capacity(kernel_size);
862 let mut sum = 0.0f32;
863 for i in 0..kernel_size {
864 let x = (i as isize - half) as f32;
865 let w = (-x * x / (2.0 * sigma * sigma)).exp();
866 kernel.push(w);
867 sum += w;
868 }
869 for w in &mut kernel {
870 *w /= sum;
871 }
872
873 let mut tmp = vec![0.0f32; height * width * channels];
875 for y in 0..height {
876 for x in 0..width {
877 for c in 0..channels {
878 let mut acc = 0.0f32;
879 for ki in 0..kernel_size {
880 let sx =
881 (x as isize + ki as isize - half).clamp(0, width as isize - 1) as usize;
882 acc += input.data()[(y * width + sx) * channels + c] * kernel[ki];
883 }
884 tmp[(y * width + x) * channels + c] = acc;
885 }
886 }
887 }
888
889 let mut out = vec![0.0f32; height * width * channels];
891 for y in 0..height {
892 for x in 0..width {
893 for c in 0..channels {
894 let mut acc = 0.0f32;
895 for ki in 0..kernel_size {
896 let sy =
897 (y as isize + ki as isize - half).clamp(0, height as isize - 1) as usize;
898 acc += tmp[(sy * width + x) * channels + c] * kernel[ki];
899 }
900 out[(y * width + x) * channels + c] = acc;
901 }
902 }
903 }
904
905 Tensor::from_vec(vec![height, width, channels], out).map_err(Into::into)
906}
907
908fn validate_probability(operation: &'static str, probability: f32) -> Result<(), ModelError> {
909 if !probability.is_finite() || !(0.0..=1.0).contains(&probability) {
910 return Err(ModelError::InvalidAugmentationProbability {
911 operation,
912 value: probability,
913 });
914 }
915 Ok(())
916}
917
918fn validate_fraction(
919 operation: &'static str,
920 parameter: &'static str,
921 value: f32,
922) -> Result<(), ModelError> {
923 if !value.is_finite() || value <= 0.0 || value > 1.0 {
924 return Err(ModelError::InvalidAugmentationArgument {
925 operation,
926 message: format!("{parameter} must be finite in (0, 1], got {value}"),
927 });
928 }
929 Ok(())
930}
931
932fn should_apply(probability: f32, rng: &mut LcgRng) -> bool {
933 if probability <= 0.0 {
934 return false;
935 }
936 if probability >= 1.0 {
937 return true;
938 }
939 rng.next_unit() < probability
940}
941
/// Minimal deterministic pseudo-random generator that makes augmentation
/// reproducible from a `u64` seed.
///
/// It is a 64-bit linear congruential generator
/// (`state = state * MULTIPLIER + INCREMENT`, wrapping) whose outputs are the high
/// 32 bits of the state. Not suitable for cryptographic use.
#[derive(Debug, Clone, Copy)]
struct LcgRng {
    state: u64,
}

impl LcgRng {
    /// 64-bit LCG multiplier from Knuth's MMIX.
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1;
    /// Largest `f32` strictly below `1.0`, i.e. `1 - 2^-24`.
    const BELOW_ONE: f32 = 1.0 - f32::EPSILON / 2.0;

    /// Creates a generator whose state is exactly `seed`.
    fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advances the state and returns its high 32 bits.
    fn next_u32(&mut self) -> u32 {
        self.state = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        (self.state >> 32) as u32
    }

    /// Returns a value in `[0, 1)`.
    ///
    /// The denominator evaluates to `2^32`. Because `u32 -> f32` rounds to nearest,
    /// the top 128 `u32` values also round to `2^32`, which would make the quotient
    /// exactly `1.0`. Those draws are clamped to the largest `f32` below one; every
    /// other output is bit-for-bit unchanged.
    fn next_unit(&mut self) -> f32 {
        let unit = self.next_u32() as f32 / (u32::MAX as f32 + 1.0);
        unit.min(Self::BELOW_ONE)
    }

    /// Returns a value in `[-1, 1)`.
    fn next_signed_unit(&mut self) -> f32 {
        self.next_unit() * 2.0 - 1.0
    }

    /// Draws a standard normal sample with the Box-Muller transform (cosine branch),
    /// consuming two `next_unit` draws.
    ///
    /// `u1` is floored at `f32::MIN_POSITIVE` so that `ln(u1)` stays finite.
    fn next_gaussian(&mut self) -> f32 {
        let u1 = self.next_unit().max(f32::MIN_POSITIVE);
        let u2 = self.next_unit();
        let magnitude = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f32::consts::PI * u2;
        magnitude * angle.cos()
    }

    /// Returns a value in `0..=upper_inclusive`. `usize::MAX` saturates to the same
    /// range as `next_usize(usize::MAX)`.
    fn next_usize_inclusive(&mut self, upper_inclusive: usize) -> usize {
        self.next_usize(upper_inclusive.saturating_add(1))
    }

    /// Returns a value in `0..upper_exclusive`, or `0` when the bound is zero.
    ///
    /// This takes a plain modulo of one 32-bit draw. Bounds that do not divide
    /// `2^32` are therefore slightly biased, and results never reach `2^32`.
    fn next_usize(&mut self, upper_exclusive: usize) -> usize {
        if upper_exclusive == 0 {
            return 0;
        }
        (self.next_u32() as usize) % upper_exclusive
    }
}