Skip to main content

tensorlogic_train/augmentation/
functional.rs

1use scirs2_core::ndarray::{ArrayD, Dimension, IxDyn};
2
3use super::error::AugmentationError;
4use super::rng::{sample_beta_symmetric, AugRng};
5
6/// Add element-wise Gaussian noise: output = input + N(0, std²).
7pub fn gaussian_noise(
8    input: &ArrayD<f64>,
9    std: f64,
10    rng: &mut AugRng,
11) -> Result<ArrayD<f64>, AugmentationError> {
12    if std < 0.0 {
13        return Err(AugmentationError::InvalidNoise { std });
14    }
15    if input.is_empty() {
16        return Err(AugmentationError::EmptyInput);
17    }
18    let noisy = input.mapv(|x| x + rng.next_normal() * std);
19    Ok(noisy)
20}
21
22/// Apply inverted dropout: zero each element with probability `p`; scale survivors by 1/(1−p).
23///
24/// When `training` is `false` the input is returned unchanged (inference mode).
25pub fn dropout(
26    input: &ArrayD<f64>,
27    p: f64,
28    training: bool,
29    rng: &mut AugRng,
30) -> Result<ArrayD<f64>, AugmentationError> {
31    if !(0.0..=1.0).contains(&p) {
32        return Err(AugmentationError::InvalidProbability(p));
33    }
34    if !training {
35        return Ok(input.clone());
36    }
37    let scale = if (p - 1.0).abs() < 1e-12 {
38        0.0
39    } else {
40        1.0 / (1.0 - p)
41    };
42    let result = input.mapv(|x| if rng.next_bool(p) { 0.0 } else { x * scale });
43    Ok(result)
44}
45
46/// Generate a binary dropout mask of the given shape.
47///
48/// Each element is 1.0 with probability (1 − p) and 0.0 with probability p.
49pub fn dropout_mask(
50    shape: &[usize],
51    p: f64,
52    rng: &mut AugRng,
53) -> Result<ArrayD<f64>, AugmentationError> {
54    if !(0.0..=1.0).contains(&p) {
55        return Err(AugmentationError::InvalidProbability(p));
56    }
57    let total: usize = shape.iter().product();
58    let data: Vec<f64> = (0..total)
59        .map(|_| if rng.next_bool(p) { 0.0 } else { 1.0 })
60        .collect();
61    ArrayD::from_shape_vec(IxDyn(shape), data).map_err(|_| AugmentationError::EmptyInput)
62}
63
64/// Mixup: λ·x1 + (1−λ)·x2 where λ ~ Beta(alpha, alpha).
65///
66/// Returns `(mixed, lambda)`.
67pub fn mixup(
68    x1: &ArrayD<f64>,
69    x2: &ArrayD<f64>,
70    alpha: f64,
71    rng: &mut AugRng,
72) -> Result<(ArrayD<f64>, f64), AugmentationError> {
73    if alpha <= 0.0 {
74        return Err(AugmentationError::InvalidAlpha(alpha));
75    }
76    if x1.shape() != x2.shape() {
77        return Err(AugmentationError::ShapeMismatch {
78            expected: x1.shape().to_vec(),
79            got: x2.shape().to_vec(),
80        });
81    }
82    if x1.is_empty() {
83        return Err(AugmentationError::EmptyInput);
84    }
85    let lambda = sample_beta_symmetric(alpha, rng);
86    let mixed = x1.mapv(|v| v * lambda) + x2.mapv(|v| v * (1.0 - lambda));
87    Ok((mixed, lambda))
88}
89
90/// CutMix: paste a random rectangular region from x2 into x1.
91///
92/// The patch covers fraction (1 − lambda) of the spatial area.
93/// Input must have at least 2 dimensions; last two are treated as (H, W).
94/// Returns `(mixed, lambda)` where lambda = fraction of x1 retained.
95pub fn cutmix(
96    x1: &ArrayD<f64>,
97    x2: &ArrayD<f64>,
98    alpha: f64,
99    rng: &mut AugRng,
100) -> Result<(ArrayD<f64>, f64), AugmentationError> {
101    if alpha <= 0.0 {
102        return Err(AugmentationError::InvalidAlpha(alpha));
103    }
104    if x1.shape() != x2.shape() {
105        return Err(AugmentationError::ShapeMismatch {
106            expected: x1.shape().to_vec(),
107            got: x2.shape().to_vec(),
108        });
109    }
110    if x1.ndim() < 2 {
111        return Err(AugmentationError::ShapeMismatch {
112            expected: vec![2],
113            got: x1.shape().to_vec(),
114        });
115    }
116    if x1.is_empty() {
117        return Err(AugmentationError::EmptyInput);
118    }
119
120    let ndim = x1.ndim();
121    let h = x1.shape()[ndim - 2];
122    let w = x1.shape()[ndim - 1];
123
124    // Sample mixing ratio
125    let lambda_raw = sample_beta_symmetric(alpha, rng);
126    // Patch area fraction = 1 - lambda_raw
127    let cut_ratio = (1.0 - lambda_raw).sqrt();
128    let cut_h = ((h as f64 * cut_ratio) as usize).max(1).min(h);
129    let cut_w = ((w as f64 * cut_ratio) as usize).max(1).min(w);
130
131    // Random top-left corner
132    let top = if h > cut_h {
133        rng.next_usize(h - cut_h + 1)
134    } else {
135        0
136    };
137    let left = if w > cut_w {
138        rng.next_usize(w - cut_w + 1)
139    } else {
140        0
141    };
142
143    let actual_lambda = 1.0 - (cut_h * cut_w) as f64 / (h * w) as f64;
144
145    let mut mixed = x1.clone();
146
147    // Iterate over all indices; replace elements inside the bounding box with x2.
148    for (idx, val) in mixed.indexed_iter_mut() {
149        let raw = idx.slice();
150        let ih = raw[ndim - 2];
151        let iw = raw[ndim - 1];
152        if ih >= top && ih < top + cut_h && iw >= left && iw < left + cut_w {
153            *val = x2[idx.clone()];
154        }
155    }
156
157    Ok((mixed, actual_lambda))
158}
159
160/// Random 2-D crop: extract a sub-array of size [.., crop_h, crop_w] at a random position.
161///
162/// Input must have at least 2 dimensions. All leading batch/channel dimensions are preserved.
163pub fn random_crop_2d(
164    input: &ArrayD<f64>,
165    crop_h: usize,
166    crop_w: usize,
167    rng: &mut AugRng,
168) -> Result<ArrayD<f64>, AugmentationError> {
169    let ndim = input.ndim();
170    if ndim < 2 {
171        return Err(AugmentationError::InvalidCrop {
172            crop_size: crop_h,
173            input_size: 0,
174        });
175    }
176    let h = input.shape()[ndim - 2];
177    let w = input.shape()[ndim - 1];
178    if crop_h > h {
179        return Err(AugmentationError::InvalidCrop {
180            crop_size: crop_h,
181            input_size: h,
182        });
183    }
184    if crop_w > w {
185        return Err(AugmentationError::InvalidCrop {
186            crop_size: crop_w,
187            input_size: w,
188        });
189    }
190    let top = if h > crop_h {
191        rng.next_usize(h - crop_h + 1)
192    } else {
193        0
194    };
195    let left = if w > crop_w {
196        rng.next_usize(w - crop_w + 1)
197    } else {
198        0
199    };
200
201    crop_2d_impl(input, top, left, crop_h, crop_w)
202}
203
204/// Center crop: crop `[crop_h, crop_w]` from the center of the last two spatial dims.
205pub fn center_crop_2d(
206    input: &ArrayD<f64>,
207    crop_h: usize,
208    crop_w: usize,
209) -> Result<ArrayD<f64>, AugmentationError> {
210    let ndim = input.ndim();
211    if ndim < 2 {
212        return Err(AugmentationError::InvalidCrop {
213            crop_size: crop_h,
214            input_size: 0,
215        });
216    }
217    let h = input.shape()[ndim - 2];
218    let w = input.shape()[ndim - 1];
219    if crop_h > h {
220        return Err(AugmentationError::InvalidCrop {
221            crop_size: crop_h,
222            input_size: h,
223        });
224    }
225    if crop_w > w {
226        return Err(AugmentationError::InvalidCrop {
227            crop_size: crop_w,
228            input_size: w,
229        });
230    }
231    let top = (h - crop_h) / 2;
232    let left = (w - crop_w) / 2;
233    crop_2d_impl(input, top, left, crop_h, crop_w)
234}
235
236/// Internal helper: extract sub-array given top-left corner and crop dimensions.
237fn crop_2d_impl(
238    input: &ArrayD<f64>,
239    top: usize,
240    left: usize,
241    crop_h: usize,
242    crop_w: usize,
243) -> Result<ArrayD<f64>, AugmentationError> {
244    let ndim = input.ndim();
245
246    // Build output shape: leading dims unchanged, last two = crop_h, crop_w.
247    let mut out_shape = input.shape().to_vec();
248    out_shape[ndim - 2] = crop_h;
249    out_shape[ndim - 1] = crop_w;
250
251    let total: usize = out_shape.iter().product();
252    let mut data = Vec::with_capacity(total);
253
254    // Iterate over all output linear indices, reconstruct multi-dim coords,
255    // offset the spatial dims, then index the input.
256    for flat in 0..total {
257        let mut rem = flat;
258        let mut out_idx = vec![0usize; ndim];
259        for d in (0..ndim).rev() {
260            out_idx[d] = rem % out_shape[d];
261            rem /= out_shape[d];
262        }
263        // Map output spatial coords to input spatial coords.
264        let mut src_idx = out_idx.clone();
265        src_idx[ndim - 2] += top;
266        src_idx[ndim - 1] += left;
267
268        let v = input[IxDyn(&src_idx)];
269        data.push(v);
270    }
271
272    ArrayD::from_shape_vec(IxDyn(&out_shape), data).map_err(|_| AugmentationError::EmptyInput)
273}
274
275/// Random horizontal flip of the last two spatial dimensions with probability `p`.
276pub fn random_hflip(
277    input: &ArrayD<f64>,
278    p: f64,
279    rng: &mut AugRng,
280) -> Result<ArrayD<f64>, AugmentationError> {
281    if !(0.0..=1.0).contains(&p) {
282        return Err(AugmentationError::InvalidProbability(p));
283    }
284    if !rng.next_bool(p) {
285        return Ok(input.clone());
286    }
287    hflip_impl(input)
288}
289
290/// Random vertical flip of the last two spatial dimensions with probability `p`.
291pub fn random_vflip(
292    input: &ArrayD<f64>,
293    p: f64,
294    rng: &mut AugRng,
295) -> Result<ArrayD<f64>, AugmentationError> {
296    if !(0.0..=1.0).contains(&p) {
297        return Err(AugmentationError::InvalidProbability(p));
298    }
299    if !rng.next_bool(p) {
300        return Ok(input.clone());
301    }
302    vflip_impl(input)
303}
304
305/// Internal horizontal flip (flip along last dim = width).
306fn hflip_impl(input: &ArrayD<f64>) -> Result<ArrayD<f64>, AugmentationError> {
307    let ndim = input.ndim();
308    if ndim < 2 {
309        return Err(AugmentationError::InvalidCrop {
310            crop_size: 0,
311            input_size: 0,
312        });
313    }
314    let w = input.shape()[ndim - 1];
315    let shape = input.shape().to_vec();
316    let total: usize = shape.iter().product();
317    let mut data = vec![0.0f64; total];
318
319    for (flat, val) in input.iter().enumerate() {
320        let mut rem = flat;
321        let mut idx = vec![0usize; ndim];
322        for d in (0..ndim).rev() {
323            idx[d] = rem % shape[d];
324            rem /= shape[d];
325        }
326        // Flip the last (width) dimension.
327        idx[ndim - 1] = w - 1 - idx[ndim - 1];
328        let mut dst_flat = 0usize;
329        let mut stride = 1usize;
330        for d in (0..ndim).rev() {
331            dst_flat += idx[d] * stride;
332            stride *= shape[d];
333        }
334        data[dst_flat] = *val;
335    }
336
337    ArrayD::from_shape_vec(IxDyn(&shape), data).map_err(|_| AugmentationError::EmptyInput)
338}
339
340/// Internal vertical flip (flip along second-to-last dim = height).
341fn vflip_impl(input: &ArrayD<f64>) -> Result<ArrayD<f64>, AugmentationError> {
342    let ndim = input.ndim();
343    if ndim < 2 {
344        return Err(AugmentationError::InvalidCrop {
345            crop_size: 0,
346            input_size: 0,
347        });
348    }
349    let h = input.shape()[ndim - 2];
350    let shape = input.shape().to_vec();
351    let total: usize = shape.iter().product();
352    let mut data = vec![0.0f64; total];
353
354    for (flat, val) in input.iter().enumerate() {
355        let mut rem = flat;
356        let mut idx = vec![0usize; ndim];
357        for d in (0..ndim).rev() {
358            idx[d] = rem % shape[d];
359            rem /= shape[d];
360        }
361        // Flip the second-to-last (height) dimension.
362        idx[ndim - 2] = h - 1 - idx[ndim - 2];
363        let mut dst_flat = 0usize;
364        let mut stride = 1usize;
365        for d in (0..ndim).rev() {
366            dst_flat += idx[d] * stride;
367            stride *= shape[d];
368        }
369        data[dst_flat] = *val;
370    }
371
372    ArrayD::from_shape_vec(IxDyn(&shape), data).map_err(|_| AugmentationError::EmptyInput)
373}
374
375/// Normalize input: `(x − mean[c]) / std[c]`.
376///
377/// For a `[B, C, H, W]` or `[C, H, W]` tensor the normalization is per-channel.
378/// If `mean` and `std` have length 1 the same value is applied to all channels.
379/// For 1-D or 2-D tensors element-wise normalization uses the first element of `mean`/`std`.
380pub fn normalize(
381    input: &ArrayD<f64>,
382    mean: &[f64],
383    std: &[f64],
384) -> Result<ArrayD<f64>, AugmentationError> {
385    if mean.is_empty() || std.is_empty() {
386        return Err(AugmentationError::EmptyInput);
387    }
388    if input.is_empty() {
389        return Err(AugmentationError::EmptyInput);
390    }
391
392    let ndim = input.ndim();
393    let shape = input.shape().to_vec();
394
395    // Determine which axis is the channel axis: axis 1 for ndim >= 3, else element-wise.
396    if ndim >= 3 {
397        let num_channels = shape[ndim - 3]; // C dim: ..., C, H, W
398                                            // Broadcast mean/std to num_channels length.
399        let m: Vec<f64> = broadcast_stats(mean, num_channels)?;
400        let s: Vec<f64> = broadcast_stats(std, num_channels)?;
401
402        let mut result = input.clone();
403        // Iterate over every element and apply per-channel normalization.
404        for (idx, val) in result.indexed_iter_mut() {
405            let raw = idx.slice();
406            let c = raw[ndim - 3];
407            *val = (*val - m[c]) / s[c];
408        }
409        Ok(result)
410    } else {
411        // 1-D or 2-D: scalar normalization.
412        let m = mean[0];
413        let s = std[0];
414        Ok(input.mapv(|x| (x - m) / s))
415    }
416}
417
418/// Denormalize input: `x * std[c] + mean[c]` (inverse of `normalize`).
419pub fn denormalize(
420    input: &ArrayD<f64>,
421    mean: &[f64],
422    std: &[f64],
423) -> Result<ArrayD<f64>, AugmentationError> {
424    if mean.is_empty() || std.is_empty() {
425        return Err(AugmentationError::EmptyInput);
426    }
427    if input.is_empty() {
428        return Err(AugmentationError::EmptyInput);
429    }
430
431    let ndim = input.ndim();
432    let shape = input.shape().to_vec();
433
434    if ndim >= 3 {
435        let num_channels = shape[ndim - 3];
436        let m: Vec<f64> = broadcast_stats(mean, num_channels)?;
437        let s: Vec<f64> = broadcast_stats(std, num_channels)?;
438
439        let mut result = input.clone();
440        for (idx, val) in result.indexed_iter_mut() {
441            let raw = idx.slice();
442            let c = raw[ndim - 3];
443            *val = *val * s[c] + m[c];
444        }
445        Ok(result)
446    } else {
447        let m = mean[0];
448        let s = std[0];
449        Ok(input.mapv(|x| x * s + m))
450    }
451}
452
453/// Broadcast a stats slice to `n` channels.
454///
455/// If len == 1, replicate; if len == n, use as-is; otherwise error.
456fn broadcast_stats(stats: &[f64], n: usize) -> Result<Vec<f64>, AugmentationError> {
457    if stats.len() == 1 {
458        Ok(vec![stats[0]; n])
459    } else if stats.len() == n {
460        Ok(stats.to_vec())
461    } else {
462        Err(AugmentationError::ShapeMismatch {
463            expected: vec![n],
464            got: vec![stats.len()],
465        })
466    }
467}
468
469/// Clamp all elements to `[min_val, max_val]`.
470pub fn clip(input: &ArrayD<f64>, min_val: f64, max_val: f64) -> ArrayD<f64> {
471    input.mapv(|x| x.clamp(min_val, max_val))
472}