Skip to main content

shrew_data/
augment.rs

1// Image Augmentation — random transforms for data augmentation
2//
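// Spatial augmentations (flips, crop, erasing) treat `Sample::features` as an
// image in [C, H, W] layout (channel-first, row-major) and return samples of any
// other rank unchanged; noise and color jitter act element-wise on any shape.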
5
6use rand::thread_rng;
7use rand::Rng;
8
9use crate::dataset::Sample;
10use crate::transform::Transform;
11
12// RandomHorizontalFlip
13
/// Mirrors an image left-to-right with probability `p`.
///
/// The sample's `feature_shape` must be `[C, H, W]`; samples of any other
/// rank are passed through untouched by the transform.
#[derive(Debug, Clone)]
pub struct RandomHorizontalFlip {
    /// Probability of flipping: `0.0` never flips, `1.0` always flips.
    pub p: f64,
}

impl RandomHorizontalFlip {
    /// Builds a horizontal flip that fires with probability `p`.
    pub fn new(p: f64) -> Self {
        RandomHorizontalFlip { p }
    }
}
27
28impl Transform for RandomHorizontalFlip {
29    fn apply(&self, mut sample: Sample) -> Sample {
30        let mut rng = thread_rng();
31        if rng.gen::<f64>() >= self.p {
32            return sample;
33        }
34        let shape = &sample.feature_shape;
35        if shape.len() != 3 {
36            return sample;
37        }
38        let (c, h, w) = (shape[0], shape[1], shape[2]);
39        let mut flipped = vec![0.0; c * h * w];
40        for ch in 0..c {
41            for row in 0..h {
42                for col in 0..w {
43                    let src = ch * h * w + row * w + col;
44                    let dst = ch * h * w + row * w + (w - 1 - col);
45                    flipped[dst] = sample.features[src];
46                }
47            }
48        }
49        sample.features = flipped;
50        sample
51    }
52}
53
54// RandomVerticalFlip
55
/// Mirrors an image top-to-bottom with probability `p`.
///
/// The sample's `feature_shape` must be `[C, H, W]`; samples of any other
/// rank are passed through untouched by the transform.
#[derive(Debug, Clone)]
pub struct RandomVerticalFlip {
    /// Chance of flipping, from `0.0` (never) to `1.0` (always).
    pub p: f64,
}

impl RandomVerticalFlip {
    /// Builds a vertical flip that fires with probability `p`.
    pub fn new(p: f64) -> Self {
        RandomVerticalFlip { p }
    }
}
69
70impl Transform for RandomVerticalFlip {
71    fn apply(&self, mut sample: Sample) -> Sample {
72        let mut rng = thread_rng();
73        if rng.gen::<f64>() >= self.p {
74            return sample;
75        }
76        let shape = &sample.feature_shape;
77        if shape.len() != 3 {
78            return sample;
79        }
80        let (c, h, w) = (shape[0], shape[1], shape[2]);
81        let mut flipped = vec![0.0; c * h * w];
82        for ch in 0..c {
83            for row in 0..h {
84                for col in 0..w {
85                    let src = ch * h * w + row * w + col;
86                    let dst = ch * h * w + (h - 1 - row) * w + col;
87                    flipped[dst] = sample.features[src];
88                }
89            }
90        }
91        sample.features = flipped;
92        sample
93    }
94}
95
96// RandomCrop
97
/// Cuts a random `crop_h x crop_w` window out of an image, optionally after
/// zero-padding it.
///
/// The sample's `feature_shape` must be `[C, H, W]`. With a non-zero
/// `padding`, the image is treated as if surrounded by `padding` pixels of
/// zeros on each of its four sides before the window position is chosen.
#[derive(Debug, Clone)]
pub struct RandomCrop {
    /// Height of the output window.
    pub crop_h: usize,
    /// Width of the output window.
    pub crop_w: usize,
    /// Number of zero pixels added on every side before cropping.
    pub padding: usize,
}

impl RandomCrop {
    /// Builds a cropper producing `crop_h x crop_w` windows after `padding`
    /// pixels of zero-padding on each side.
    pub fn new(crop_h: usize, crop_w: usize, padding: usize) -> Self {
        RandomCrop {
            crop_h,
            crop_w,
            padding,
        }
    }
}
118
119impl Transform for RandomCrop {
120    fn apply(&self, mut sample: Sample) -> Sample {
121        let shape = &sample.feature_shape;
122        if shape.len() != 3 {
123            return sample;
124        }
125        let (c, h, w) = (shape[0], shape[1], shape[2]);
126        let pad = self.padding;
127        let padded_h = h + 2 * pad;
128        let padded_w = w + 2 * pad;
129
130        // Build padded image
131        let mut padded = vec![0.0; c * padded_h * padded_w];
132        for ch in 0..c {
133            for row in 0..h {
134                for col in 0..w {
135                    let src = ch * h * w + row * w + col;
136                    let dst = ch * padded_h * padded_w + (row + pad) * padded_w + (col + pad);
137                    padded[dst] = sample.features[src];
138                }
139            }
140        }
141
142        // Random crop position
143        let mut rng = thread_rng();
144        let max_y = padded_h.saturating_sub(self.crop_h);
145        let max_x = padded_w.saturating_sub(self.crop_w);
146        let y0 = if max_y > 0 {
147            rng.gen_range(0..=max_y)
148        } else {
149            0
150        };
151        let x0 = if max_x > 0 {
152            rng.gen_range(0..=max_x)
153        } else {
154            0
155        };
156
157        let mut cropped = vec![0.0; c * self.crop_h * self.crop_w];
158        for ch in 0..c {
159            for row in 0..self.crop_h {
160                for col in 0..self.crop_w {
161                    let src = ch * padded_h * padded_w + (y0 + row) * padded_w + (x0 + col);
162                    let dst = ch * self.crop_h * self.crop_w + row * self.crop_w + col;
163                    cropped[dst] = padded[src];
164                }
165            }
166        }
167
168        sample.features = cropped;
169        sample.feature_shape = vec![c, self.crop_h, self.crop_w];
170        sample
171    }
172}
173
174// RandomNoise — additive Gaussian noise
175
/// Perturbs every feature with additive Gaussian noise: `x' = x + N(0, std_dev)`.
///
/// The noise is applied element-wise, so the feature shape does not matter.
#[derive(Debug, Clone)]
pub struct RandomNoise {
    /// Standard deviation of the zero-mean noise.
    pub std_dev: f64,
}

impl RandomNoise {
    /// Creates a noise transform with the given standard deviation.
    pub fn new(std_dev: f64) -> Self {
        RandomNoise { std_dev }
    }
}
187
188impl Transform for RandomNoise {
189    fn apply(&self, mut sample: Sample) -> Sample {
190        use rand_distr::{Distribution, Normal};
191        let normal = Normal::new(0.0, self.std_dev).unwrap();
192        let mut rng = thread_rng();
193        for v in &mut sample.features {
194            *v += normal.sample(&mut rng);
195        }
196        sample
197    }
198}
199
200// RandomErasing — randomly erase a rectangular region (cutout)
201
/// Erase (cut out) a random rectangular region, replacing it with a constant.
///
/// With probability `p`, a rectangle covering between `min_area_ratio` and
/// `max_area_ratio` of the `H * W` image area, with an aspect ratio
/// (height / width) drawn from the fixed range 0.3–3.3, is overwritten with
/// `fill_value` in every channel. If a sampled rectangle does not fit inside
/// the image, nothing may be erased for that sample.
///
/// Expects `feature_shape = [C, H, W]`.
#[derive(Debug, Clone)]
pub struct RandomErasing {
    /// Probability of attempting an erase at all.
    pub p: f64,
    /// Value written into every erased element.
    pub fill_value: f64,
    /// Lower bound of the erased area, as a fraction of `H * W`.
    pub min_area_ratio: f64,
    /// Upper bound of the erased area, as a fraction of `H * W`.
    pub max_area_ratio: f64,
}

impl RandomErasing {
    /// Creates an eraser that fires with probability `p`, fills with `0.0`,
    /// and targets between 2% and 33% of the image area.
    pub fn new(p: f64) -> Self {
        Self {
            p,
            fill_value: 0.0,
            min_area_ratio: 0.02,
            max_area_ratio: 0.33,
        }
    }
}
225
226impl Transform for RandomErasing {
227    fn apply(&self, mut sample: Sample) -> Sample {
228        let mut rng = thread_rng();
229        if rng.gen::<f64>() >= self.p {
230            return sample;
231        }
232        let shape = &sample.feature_shape;
233        if shape.len() != 3 {
234            return sample;
235        }
236        let (c, h, w) = (shape[0], shape[1], shape[2]);
237        let area = (h * w) as f64;
238
239        // Pick random area and aspect ratio
240        let target_area = area * rng.gen_range(self.min_area_ratio..self.max_area_ratio);
241        let aspect = rng.gen_range(0.3f64..3.3f64);
242        let erase_h = (target_area * aspect).sqrt().round() as usize;
243        let erase_w = (target_area / aspect).sqrt().round() as usize;
244
245        if erase_h >= h || erase_w >= w {
246            return sample;
247        }
248
249        let y0 = rng.gen_range(0..h - erase_h);
250        let x0 = rng.gen_range(0..w - erase_w);
251
252        for ch in 0..c {
253            for row in y0..y0 + erase_h {
254                for col in x0..x0 + erase_w {
255                    sample.features[ch * h * w + row * w + col] = self.fill_value;
256                }
257            }
258        }
259        sample
260    }
261}
262
263// ColorJitter — random brightness/contrast for images normalised to [0,1]
264
265/// Randomly adjust brightness and contrast.
266///
267/// brightness: `x' = x + uniform(-brightness, +brightness)`
268/// contrast:   `x' = mean + (x - mean) * factor` where factor ∈ `[1 - contrast, 1 + contrast]`
269#[derive(Debug, Clone)]
270pub struct ColorJitter {
271    pub brightness: f64,
272    pub contrast: f64,
273}
274
275impl ColorJitter {
276    pub fn new(brightness: f64, contrast: f64) -> Self {
277        Self {
278            brightness,
279            contrast,
280        }
281    }
282}
283
284impl Transform for ColorJitter {
285    fn apply(&self, mut sample: Sample) -> Sample {
286        let mut rng = thread_rng();
287
288        // Brightness
289        if self.brightness > 0.0 {
290            let delta = rng.gen_range(-self.brightness..self.brightness);
291            for v in &mut sample.features {
292                *v += delta;
293            }
294        }
295
296        // Contrast
297        if self.contrast > 0.0 {
298            let factor = rng.gen_range(1.0 - self.contrast..1.0 + self.contrast);
299            let mean: f64 = sample.features.iter().sum::<f64>() / sample.features.len() as f64;
300            for v in &mut sample.features {
301                *v = mean + (*v - mean) * factor;
302            }
303        }
304
305        sample
306    }
307}
308
309// Tests
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `[c, h, w]` sample whose features are `0.0, 1.0, 2.0, ...` in
    /// row-major order, so every element's value equals its flat index.
    fn make_image_sample(c: usize, h: usize, w: usize) -> Sample {
        let n = c * h * w;
        Sample {
            features: (0..n).map(|i| i as f64).collect(),
            feature_shape: vec![c, h, w],
            target: vec![0.0],
            target_shape: vec![1],
        }
    }

    /// The ramp `0.0, 1.0, ..., n - 1` as a vector, for comparisons.
    fn ramp(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64).collect()
    }

    #[test]
    fn horizontal_flip_deterministic() {
        // p=1.0 always flips
        let flip = RandomHorizontalFlip::new(1.0);
        let sample = make_image_sample(1, 2, 3);
        // Original: [0,1,2, 3,4,5]
        let result = flip.apply(sample);
        // Flipped:  [2,1,0, 5,4,3]
        assert_eq!(result.features, vec![2.0, 1.0, 0.0, 5.0, 4.0, 3.0]);
    }

    #[test]
    fn horizontal_flip_is_per_channel() {
        // Two channels, each a single 1x3 row: [0,1,2 | 3,4,5]
        let result = RandomHorizontalFlip::new(1.0).apply(make_image_sample(2, 1, 3));
        assert_eq!(result.features, vec![2.0, 1.0, 0.0, 5.0, 4.0, 3.0]);
        assert_eq!(result.feature_shape, vec![2, 1, 3]);
    }

    #[test]
    fn vertical_flip_deterministic() {
        let flip = RandomVerticalFlip::new(1.0);
        let sample = make_image_sample(1, 2, 3);
        // Original: [0,1,2, 3,4,5]
        let result = flip.apply(sample);
        // Flipped:  [3,4,5, 0,1,2]
        assert_eq!(result.features, vec![3.0, 4.0, 5.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn vertical_flip_is_per_channel() {
        // Two channels, each a 2x1 column: [0,1 | 2,3]
        let result = RandomVerticalFlip::new(1.0).apply(make_image_sample(2, 2, 1));
        assert_eq!(result.features, vec![1.0, 0.0, 3.0, 2.0]);
    }

    #[test]
    fn vertical_flip_odd_height_keeps_middle_row() {
        // Rows [0,1], [2,3], [4,5]
        let result = RandomVerticalFlip::new(1.0).apply(make_image_sample(1, 3, 2));
        assert_eq!(result.features, vec![4.0, 5.0, 2.0, 3.0, 0.0, 1.0]);
    }

    #[test]
    fn flips_with_p0_are_identity() {
        let sample = make_image_sample(2, 3, 4);
        let h = RandomHorizontalFlip::new(0.0).apply(sample.clone());
        assert_eq!(h.features, sample.features);
        let v = RandomVerticalFlip::new(0.0).apply(sample.clone());
        assert_eq!(v.features, sample.features);
    }

    #[test]
    fn non_image_rank_passes_through() {
        let sample = Sample {
            features: vec![1.0, 2.0, 3.0],
            feature_shape: vec![3],
            target: vec![0.0],
            target_shape: vec![1],
        };
        let flipped = RandomHorizontalFlip::new(1.0).apply(sample.clone());
        assert_eq!(flipped.features, sample.features);
        let cropped = RandomCrop::new(1, 1, 0).apply(sample.clone());
        assert_eq!(cropped.features, sample.features);
        assert_eq!(cropped.feature_shape, vec![3]);
        let erased = RandomErasing::new(1.0).apply(sample.clone());
        assert_eq!(erased.features, sample.features);
    }

    #[test]
    fn random_crop_no_padding_same_size() {
        let crop = RandomCrop::new(2, 3, 0);
        let sample = make_image_sample(1, 2, 3);
        let result = crop.apply(sample);
        assert_eq!(result.feature_shape, vec![1, 2, 3]);
        // Only one window position exists, so the crop is the identity.
        assert_eq!(result.features, ramp(6));
    }

    #[test]
    fn random_crop_with_padding() {
        let crop = RandomCrop::new(4, 4, 1);
        let sample = make_image_sample(1, 4, 4);
        let result = crop.apply(sample);
        assert_eq!(result.feature_shape, vec![1, 4, 4]);
        assert_eq!(result.features.len(), 16);
    }

    #[test]
    fn random_noise_changes_values() {
        let noise = RandomNoise::new(1.0);
        let sample = make_image_sample(1, 2, 2);
        let result = noise.apply(sample.clone());
        // Values should be different (with astronomical probability)
        let changed = result
            .features
            .iter()
            .zip(sample.features.iter())
            .any(|(a, b)| (a - b).abs() > 1e-10);
        assert!(changed);
    }

    #[test]
    fn random_erasing_p1() {
        // Area ratios in [0.1, 0.2] of an 8x8 image with aspect in [0.3, 3.3]
        // always give a rectangle between 1x1 and 7x7, so the first draw fits
        // and at least one pixel is erased. A fill value outside the ramp makes
        // erased pixels unambiguous (the ramp itself contains 0.0).
        let erasing = RandomErasing {
            p: 1.0,
            fill_value: -1.0,
            min_area_ratio: 0.1,
            max_area_ratio: 0.2,
        };
        let result = erasing.apply(make_image_sample(1, 8, 8));
        let erased = result.features.iter().filter(|&&v| v == -1.0).count();
        assert!(erased >= 1, "Expected erased pixels, got none");
        // Every pixel outside the rectangle keeps its original value.
        for (i, &v) in result.features.iter().enumerate() {
            assert!(v == -1.0 || v == i as f64, "pixel {} unexpectedly became {}", i, v);
        }
    }

    #[test]
    fn color_jitter() {
        let jitter = ColorJitter::new(0.1, 0.1);
        let sample = make_image_sample(1, 2, 2);
        let result = jitter.apply(sample.clone());
        let changed = result
            .features
            .iter()
            .zip(sample.features.iter())
            .any(|(a, b)| (a - b).abs() > 1e-10);
        assert!(changed);
    }

    #[test]
    fn color_jitter_zero_is_identity() {
        let sample = make_image_sample(1, 2, 2);
        let result = ColorJitter::new(0.0, 0.0).apply(sample.clone());
        assert_eq!(result.features, sample.features);
    }
}