Skip to main content

yscv_imgproc/ops/
augment.rs

1//! Image augmentation transforms (random crop, flip, rotation, color jitter, etc.).
2
3use yscv_tensor::Tensor;
4
5use super::super::ImgProcError;
6use super::super::shape::hwc_shape;
7
/// Minimal deterministic pseudo-random generator built on xorshift64.
///
/// The single field is the 64-bit generator state.
struct Rng(u64);

impl Rng {
    /// Creates a generator from `seed`.
    ///
    /// A zero seed is swapped for `0xDEAD_BEEF`: an all-zero xorshift state
    /// would stay zero forever.
    fn new(seed: u64) -> Self {
        match seed {
            0 => Self(0xDEAD_BEEF),
            nonzero => Self(nonzero),
        }
    }

    /// Advances the state by one xorshift step (shifts 13, 7, 17) and returns it.
    fn next_u64(&mut self) -> u64 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        self.0
    }

    /// Returns a value in `[0, 1)` built from the top 24 bits of the next output.
    fn uniform(&mut self) -> f32 {
        const BITS: u32 = 24;
        let top_bits = self.next_u64() >> (64 - BITS);
        top_bits as f32 / (1u64 << BITS) as f32
    }

    /// Returns `lo` plus a `uniform()` fraction of the span `hi - lo`.
    fn uniform_range(&mut self, lo: f32, hi: f32) -> f32 {
        let span = hi - lo;
        lo + span * self.uniform()
    }

    /// Draws one sample via the Box-Muller transform (cosine branch only).
    #[allow(dead_code)]
    fn normal(&mut self) -> f32 {
        // Floor the first draw so that `ln` never sees zero.
        let radial = self.uniform().max(1e-10);
        let angular = self.uniform();
        let radius = (-2.0 * radial.ln()).sqrt();
        let theta = 2.0 * std::f32::consts::PI * angular;
        radius * theta.cos()
    }
}
40
41/// Randomly crop a `[H, W, C]` image to `(out_h, out_w)`.
42pub fn random_crop(
43    image: &Tensor,
44    out_h: usize,
45    out_w: usize,
46    seed: u64,
47) -> Result<Tensor, ImgProcError> {
48    let (h, w, c) = hwc_shape(image)?;
49    if out_h > h || out_w > w {
50        return Err(ImgProcError::InvalidSize {
51            height: out_h,
52            width: out_w,
53        });
54    }
55    let mut rng = Rng::new(seed);
56    let y0 = (rng.uniform() * (h - out_h + 1) as f32) as usize;
57    let x0 = (rng.uniform() * (w - out_w + 1) as f32) as usize;
58
59    let data = image.data();
60    let mut out = vec![0.0f32; out_h * out_w * c];
61    for y in 0..out_h {
62        let src_start = ((y0 + y) * w + x0) * c;
63        let dst_start = (y * out_w) * c;
64        out[dst_start..dst_start + out_w * c]
65            .copy_from_slice(&data[src_start..src_start + out_w * c]);
66    }
67    Ok(Tensor::from_vec(vec![out_h, out_w, c], out)?)
68}
69
70/// Randomly flip horizontally with probability `p`.
71pub fn random_horizontal_flip(image: &Tensor, p: f32, seed: u64) -> Result<Tensor, ImgProcError> {
72    let mut rng = Rng::new(seed);
73    if rng.uniform() >= p {
74        return Ok(image.clone());
75    }
76    let (h, w, c) = hwc_shape(image)?;
77    let data = image.data();
78    let mut out = vec![0.0f32; h * w * c];
79    for y in 0..h {
80        for x in 0..w {
81            let src = (y * w + (w - 1 - x)) * c;
82            let dst = (y * w + x) * c;
83            out[dst..dst + c].copy_from_slice(&data[src..src + c]);
84        }
85    }
86    Ok(Tensor::from_vec(vec![h, w, c], out)?)
87}
88
89/// Randomly flip vertically with probability `p`.
90pub fn random_vertical_flip(image: &Tensor, p: f32, seed: u64) -> Result<Tensor, ImgProcError> {
91    let mut rng = Rng::new(seed);
92    if rng.uniform() >= p {
93        return Ok(image.clone());
94    }
95    let (h, w, c) = hwc_shape(image)?;
96    let data = image.data();
97    let mut out = vec![0.0f32; h * w * c];
98    for y in 0..h {
99        let src_row = (h - 1 - y) * w * c;
100        let dst_row = y * w * c;
101        out[dst_row..dst_row + w * c].copy_from_slice(&data[src_row..src_row + w * c]);
102    }
103    Ok(Tensor::from_vec(vec![h, w, c], out)?)
104}
105
106/// Rotate image by a random angle in `[-max_degrees, max_degrees]`.
107///
108/// Uses bilinear interpolation. Pixels outside the image are filled with 0.
109pub fn random_rotation(
110    image: &Tensor,
111    max_degrees: f32,
112    seed: u64,
113) -> Result<Tensor, ImgProcError> {
114    let (h, w, c) = hwc_shape(image)?;
115    let mut rng = Rng::new(seed);
116    let angle_deg = rng.uniform_range(-max_degrees, max_degrees);
117    let angle = angle_deg * std::f32::consts::PI / 180.0;
118    let cos_a = angle.cos();
119    let sin_a = angle.sin();
120    let cx = w as f32 / 2.0;
121    let cy = h as f32 / 2.0;
122
123    let data = image.data();
124    let mut out = vec![0.0f32; h * w * c];
125
126    for y in 0..h {
127        for x in 0..w {
128            // Inverse rotation to find source pixel
129            let dx = x as f32 - cx;
130            let dy = y as f32 - cy;
131            let src_x = cos_a * dx + sin_a * dy + cx;
132            let src_y = -sin_a * dx + cos_a * dy + cy;
133
134            if src_x >= 0.0 && src_x < (w - 1) as f32 && src_y >= 0.0 && src_y < (h - 1) as f32 {
135                let x0 = src_x.floor() as usize;
136                let y0 = src_y.floor() as usize;
137                let x1 = x0 + 1;
138                let y1 = y0 + 1;
139                let fx = src_x - x0 as f32;
140                let fy = src_y - y0 as f32;
141
142                for ch in 0..c {
143                    let v00 = data[(y0 * w + x0) * c + ch];
144                    let v10 = data[(y0 * w + x1) * c + ch];
145                    let v01 = data[(y1 * w + x0) * c + ch];
146                    let v11 = data[(y1 * w + x1) * c + ch];
147                    out[(y * w + x) * c + ch] = v00 * (1.0 - fx) * (1.0 - fy)
148                        + v10 * fx * (1.0 - fy)
149                        + v01 * (1.0 - fx) * fy
150                        + v11 * fx * fy;
151                }
152            }
153        }
154    }
155    Ok(Tensor::from_vec(vec![h, w, c], out)?)
156}
157
158/// Random erasing (cutout): randomly erase a rectangular region with probability `p`.
159///
160/// The erased region is filled with `fill_value` (typically 0.0).
161pub fn random_erasing(
162    image: &Tensor,
163    p: f32,
164    scale_min: f32,
165    scale_max: f32,
166    ratio_min: f32,
167    ratio_max: f32,
168    fill_value: f32,
169    seed: u64,
170) -> Result<Tensor, ImgProcError> {
171    let (h, w, c) = hwc_shape(image)?;
172    let mut rng = Rng::new(seed);
173    if rng.uniform() >= p {
174        return Ok(image.clone());
175    }
176
177    let area = (h * w) as f32;
178    let target_area = area * rng.uniform_range(scale_min, scale_max);
179    let ratio = rng.uniform_range(ratio_min, ratio_max);
180    let eh = (target_area * ratio).sqrt() as usize;
181    let ew = (target_area / ratio).sqrt() as usize;
182    let eh = eh.min(h);
183    let ew = ew.min(w);
184
185    let y0 = (rng.uniform() * (h - eh + 1) as f32) as usize;
186    let x0 = (rng.uniform() * (w - ew + 1) as f32) as usize;
187
188    let mut out = image.data().to_vec();
189    for y in y0..y0 + eh {
190        for x in x0..x0 + ew {
191            let base = (y * w + x) * c;
192            for ch in 0..c {
193                out[base + ch] = fill_value;
194            }
195        }
196    }
197    Ok(Tensor::from_vec(vec![h, w, c], out)?)
198}
199
200/// Apply random color jitter: brightness, contrast, saturation, hue adjustments.
201///
202/// Each factor is randomized in `[1-amount, 1+amount]`.
203pub fn color_jitter(
204    image: &Tensor,
205    brightness: f32,
206    contrast: f32,
207    saturation: f32,
208    hue: f32,
209    seed: u64,
210) -> Result<Tensor, ImgProcError> {
211    let (h, w, c) = hwc_shape(image)?;
212    if c != 3 {
213        return Err(ImgProcError::InvalidChannelCount {
214            expected: 3,
215            got: c,
216        });
217    }
218    let mut rng = Rng::new(seed);
219    let data = image.data();
220    let mut out = data.to_vec();
221
222    // Brightness
223    if brightness > 0.0 {
224        let factor = rng
225            .uniform_range(1.0 - brightness, 1.0 + brightness)
226            .max(0.0);
227        for v in out.iter_mut() {
228            *v *= factor;
229        }
230    }
231
232    // Contrast
233    if contrast > 0.0 {
234        let factor = rng.uniform_range(1.0 - contrast, 1.0 + contrast).max(0.0);
235        let mean: f32 = out.iter().sum::<f32>() / out.len() as f32;
236        for v in out.iter_mut() {
237            *v = (*v - mean) * factor + mean;
238        }
239    }
240
241    // Saturation
242    if saturation > 0.0 {
243        let factor = rng
244            .uniform_range(1.0 - saturation, 1.0 + saturation)
245            .max(0.0);
246        for i in 0..(h * w) {
247            let base = i * 3;
248            let gray = 0.299 * out[base] + 0.587 * out[base + 1] + 0.114 * out[base + 2];
249            for ch in 0..3 {
250                out[base + ch] = gray + (out[base + ch] - gray) * factor;
251            }
252        }
253    }
254
255    // Hue (simple rotation in RGB space approximation)
256    if hue > 0.0 {
257        let angle = rng.uniform_range(-hue, hue) * std::f32::consts::PI;
258        let cos_h = angle.cos();
259        let sin_h = angle.sin();
260        let sqrt3 = 3.0f32.sqrt();
261        for i in 0..(h * w) {
262            let base = i * 3;
263            let r = out[base];
264            let g = out[base + 1];
265            let b = out[base + 2];
266            out[base] =
267                (r + g + b) / 3.0 + (2.0 * r - g - b) / 3.0 * cos_h + (g - b) / sqrt3 * sin_h;
268            out[base + 1] = (r + g + b) / 3.0 - (2.0 * r - g - b) / 6.0 * cos_h
269                + (2.0 * b - 2.0 * g + 2.0 * r - g - b) / (2.0 * sqrt3) * sin_h;
270            // simplified: recompute b from the constraint r+g+b is preserved
271            out[base + 2] = r + g + b - out[base] - out[base + 1];
272        }
273    }
274
275    // Clamp to [0, 1]
276    for v in out.iter_mut() {
277        *v = v.clamp(0.0, 1.0);
278    }
279
280    Ok(Tensor::from_vec(vec![h, w, c], out)?)
281}
282
283/// Elastic distortion transform.
284///
285/// Generates random displacement fields and applies Gaussian smoothing,
286/// then remaps pixels with bilinear interpolation.
287pub fn elastic_transform(
288    image: &Tensor,
289    alpha: f32,
290    sigma: f32,
291    seed: u64,
292) -> Result<Tensor, ImgProcError> {
293    let (h, w, c) = hwc_shape(image)?;
294    let mut rng = Rng::new(seed);
295
296    // Generate random displacement fields [-1, 1]
297    let n = h * w;
298    let mut dx: Vec<f32> = (0..n).map(|_| rng.uniform_range(-1.0, 1.0)).collect();
299    let mut dy: Vec<f32> = (0..n).map(|_| rng.uniform_range(-1.0, 1.0)).collect();
300
301    // Gaussian blur the displacement fields (simple box blur approximation)
302    let kernel_size = (sigma * 3.0) as usize | 1; // ensure odd
303    let half_k = kernel_size / 2;
304    for _ in 0..2 {
305        // 2 passes of box blur ≈ gaussian
306        let dx_copy = dx.clone();
307        let dy_copy = dy.clone();
308        for y in 0..h {
309            for x in 0..w {
310                let mut sx = 0.0f32;
311                let mut sy = 0.0f32;
312                let mut count = 0.0f32;
313                for ky in y.saturating_sub(half_k)..=(y + half_k).min(h - 1) {
314                    for kx in x.saturating_sub(half_k)..=(x + half_k).min(w - 1) {
315                        sx += dx_copy[ky * w + kx];
316                        sy += dy_copy[ky * w + kx];
317                        count += 1.0;
318                    }
319                }
320                dx[y * w + x] = sx / count;
321                dy[y * w + x] = sy / count;
322            }
323        }
324    }
325
326    // Scale by alpha
327    for v in dx.iter_mut() {
328        *v *= alpha;
329    }
330    for v in dy.iter_mut() {
331        *v *= alpha;
332    }
333
334    // Remap with bilinear interpolation
335    let data = image.data();
336    let mut out = vec![0.0f32; h * w * c];
337    for y in 0..h {
338        for x in 0..w {
339            let src_x = x as f32 + dx[y * w + x];
340            let src_y = y as f32 + dy[y * w + x];
341
342            if src_x >= 0.0 && src_x < (w - 1) as f32 && src_y >= 0.0 && src_y < (h - 1) as f32 {
343                let x0 = src_x.floor() as usize;
344                let y0 = src_y.floor() as usize;
345                let x1 = x0 + 1;
346                let y1 = y0 + 1;
347                let fx = src_x - x0 as f32;
348                let fy = src_y - y0 as f32;
349                for ch in 0..c {
350                    let v00 = data[(y0 * w + x0) * c + ch];
351                    let v10 = data[(y0 * w + x1) * c + ch];
352                    let v01 = data[(y1 * w + x0) * c + ch];
353                    let v11 = data[(y1 * w + x1) * c + ch];
354                    out[(y * w + x) * c + ch] = v00 * (1.0 - fx) * (1.0 - fy)
355                        + v10 * fx * (1.0 - fy)
356                        + v01 * (1.0 - fx) * fy
357                        + v11 * fx * fy;
358                }
359            }
360        }
361    }
362    Ok(Tensor::from_vec(vec![h, w, c], out)?)
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `h x w` three-channel ramp image whose values rise from 0 towards 1.
    fn test_image(h: usize, w: usize) -> Tensor {
        let len = h * w * 3;
        let ramp = (0..len)
            .map(|idx| idx as f32 / len as f32)
            .collect::<Vec<f32>>();
        Tensor::from_vec(vec![h, w, 3], ramp).unwrap()
    }

    /// Cropping a 32x32 image to 16x16 yields a 16x16x3 tensor.
    #[test]
    fn test_random_crop_shape() {
        let source = test_image(32, 32);
        let result = random_crop(&source, 16, 16, 42).unwrap();
        assert_eq!(result.shape(), &[16, 16, 3]);
    }

    /// With `p = 1.0` a single-row image comes back with its columns reversed.
    #[test]
    fn test_random_horizontal_flip() {
        let row = Tensor::from_vec(vec![1, 3, 1], vec![1.0, 2.0, 3.0]).unwrap();
        let mirrored = random_horizontal_flip(&row, 1.0, 42).unwrap();
        assert_eq!(mirrored.data(), &[3.0, 2.0, 1.0]);
    }

    /// With `p = 1.0` a single-column image comes back with its rows reversed.
    #[test]
    fn test_random_vertical_flip() {
        let column = Tensor::from_vec(vec![3, 1, 1], vec![1.0, 2.0, 3.0]).unwrap();
        let mirrored = random_vertical_flip(&column, 1.0, 42).unwrap();
        assert_eq!(mirrored.data(), &[3.0, 2.0, 1.0]);
    }

    /// Rotation never changes the tensor shape.
    #[test]
    fn test_random_rotation_preserves_shape() {
        let source = test_image(20, 20);
        let result = random_rotation(&source, 30.0, 42).unwrap();
        assert_eq!(result.shape(), &[20, 20, 3]);
    }

    /// Erasing an all-ones image with fill 0 must leave at least one zero behind.
    #[test]
    fn test_random_erasing_modifies_pixels() {
        let ones = Tensor::from_vec(vec![10, 10, 3], vec![1.0f32; 300]).unwrap();
        let result = random_erasing(&ones, 1.0, 0.1, 0.3, 0.5, 2.0, 0.0, 42).unwrap();
        let any_erased = result.data().iter().any(|&value| value == 0.0);
        assert!(any_erased, "expected some erased pixels");
    }

    /// Jitter keeps the shape and leaves every value inside `[0, 1]`.
    #[test]
    fn test_color_jitter_preserves_shape() {
        let source = test_image(8, 8);
        let result = color_jitter(&source, 0.2, 0.2, 0.2, 0.0, 42).unwrap();
        assert_eq!(result.shape(), &[8, 8, 3]);
        // Values should be in [0, 1]
        let in_unit_range = result
            .data()
            .iter()
            .all(|value| (0.0..=1.0).contains(value));
        assert!(in_unit_range);
    }

    /// Elastic distortion never changes the tensor shape.
    #[test]
    fn test_elastic_transform_preserves_shape() {
        let source = test_image(16, 16);
        let result = elastic_transform(&source, 10.0, 3.0, 42).unwrap();
        assert_eq!(result.shape(), &[16, 16, 3]);
    }
}