1use yscv_tensor::Tensor;
4
5use super::super::ImgProcError;
6use super::super::shape::hwc_shape;
7
/// Minimal deterministic xorshift64 generator backing the augmentations in
/// this module.
///
/// The wrapped value is the generator state. It must never be zero, because
/// zero is a fixed point of the xorshift step in `next_u64`.
struct Rng(u64);

impl Rng {
    /// Creates a generator from `seed`.
    ///
    /// The seed is scrambled with the SplitMix64 finalizer before it becomes
    /// the xorshift state. Using the raw seed directly makes the first output
    /// smaller than `2^(bits(seed) + 30)`; since `uniform` keeps only the top
    /// 24 bits, every seed below 1024 would yield a first sample of exactly
    /// `0.0` and every 32-bit seed a first sample below `0.25`.
    fn new(seed: u64) -> Self {
        let mut z = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^= z >> 31;
        // The mix is a bijection, so exactly one seed maps to zero; replace it
        // with a fixed non-zero state to keep the generator from getting stuck.
        Self(if z == 0 { 0xDEAD_BEEF } else { z })
    }

    /// Advances the state with the (13, 7, 17) xorshift step and returns it.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Returns a sample in `[0, 1)` built from the top 24 bits of the next
    /// output, so every value is exactly representable as an `f32`.
    fn uniform(&mut self) -> f32 {
        (self.next_u64() >> 40) as f32 / (1u64 << 24) as f32
    }

    /// Returns `lo + (hi - lo) * u` for a fresh uniform sample `u`.
    ///
    /// Floating-point rounding may make the result equal to `hi`.
    fn uniform_range(&mut self, lo: f32, hi: f32) -> f32 {
        lo + (hi - lo) * self.uniform()
    }

    /// Returns a standard-normal sample via the Box-Muller transform.
    ///
    /// Consumes two uniform samples; the first is floored at `1e-10` so the
    /// logarithm stays finite.
    #[allow(dead_code)]
    fn normal(&mut self) -> f32 {
        let u1 = self.uniform().max(1e-10);
        let u2 = self.uniform();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()
    }
}
40
41pub fn random_crop(
43 image: &Tensor,
44 out_h: usize,
45 out_w: usize,
46 seed: u64,
47) -> Result<Tensor, ImgProcError> {
48 let (h, w, c) = hwc_shape(image)?;
49 if out_h > h || out_w > w {
50 return Err(ImgProcError::InvalidSize {
51 height: out_h,
52 width: out_w,
53 });
54 }
55 let mut rng = Rng::new(seed);
56 let y0 = (rng.uniform() * (h - out_h + 1) as f32) as usize;
57 let x0 = (rng.uniform() * (w - out_w + 1) as f32) as usize;
58
59 let data = image.data();
60 let mut out = vec![0.0f32; out_h * out_w * c];
61 for y in 0..out_h {
62 let src_start = ((y0 + y) * w + x0) * c;
63 let dst_start = (y * out_w) * c;
64 out[dst_start..dst_start + out_w * c]
65 .copy_from_slice(&data[src_start..src_start + out_w * c]);
66 }
67 Ok(Tensor::from_vec(vec![out_h, out_w, c], out)?)
68}
69
70pub fn random_horizontal_flip(image: &Tensor, p: f32, seed: u64) -> Result<Tensor, ImgProcError> {
72 let mut rng = Rng::new(seed);
73 if rng.uniform() >= p {
74 return Ok(image.clone());
75 }
76 let (h, w, c) = hwc_shape(image)?;
77 let data = image.data();
78 let mut out = vec![0.0f32; h * w * c];
79 for y in 0..h {
80 for x in 0..w {
81 let src = (y * w + (w - 1 - x)) * c;
82 let dst = (y * w + x) * c;
83 out[dst..dst + c].copy_from_slice(&data[src..src + c]);
84 }
85 }
86 Ok(Tensor::from_vec(vec![h, w, c], out)?)
87}
88
89pub fn random_vertical_flip(image: &Tensor, p: f32, seed: u64) -> Result<Tensor, ImgProcError> {
91 let mut rng = Rng::new(seed);
92 if rng.uniform() >= p {
93 return Ok(image.clone());
94 }
95 let (h, w, c) = hwc_shape(image)?;
96 let data = image.data();
97 let mut out = vec![0.0f32; h * w * c];
98 for y in 0..h {
99 let src_row = (h - 1 - y) * w * c;
100 let dst_row = y * w * c;
101 out[dst_row..dst_row + w * c].copy_from_slice(&data[src_row..src_row + w * c]);
102 }
103 Ok(Tensor::from_vec(vec![h, w, c], out)?)
104}
105
106pub fn random_rotation(
110 image: &Tensor,
111 max_degrees: f32,
112 seed: u64,
113) -> Result<Tensor, ImgProcError> {
114 let (h, w, c) = hwc_shape(image)?;
115 let mut rng = Rng::new(seed);
116 let angle_deg = rng.uniform_range(-max_degrees, max_degrees);
117 let angle = angle_deg * std::f32::consts::PI / 180.0;
118 let cos_a = angle.cos();
119 let sin_a = angle.sin();
120 let cx = w as f32 / 2.0;
121 let cy = h as f32 / 2.0;
122
123 let data = image.data();
124 let mut out = vec![0.0f32; h * w * c];
125
126 for y in 0..h {
127 for x in 0..w {
128 let dx = x as f32 - cx;
130 let dy = y as f32 - cy;
131 let src_x = cos_a * dx + sin_a * dy + cx;
132 let src_y = -sin_a * dx + cos_a * dy + cy;
133
134 if src_x >= 0.0 && src_x < (w - 1) as f32 && src_y >= 0.0 && src_y < (h - 1) as f32 {
135 let x0 = src_x.floor() as usize;
136 let y0 = src_y.floor() as usize;
137 let x1 = x0 + 1;
138 let y1 = y0 + 1;
139 let fx = src_x - x0 as f32;
140 let fy = src_y - y0 as f32;
141
142 for ch in 0..c {
143 let v00 = data[(y0 * w + x0) * c + ch];
144 let v10 = data[(y0 * w + x1) * c + ch];
145 let v01 = data[(y1 * w + x0) * c + ch];
146 let v11 = data[(y1 * w + x1) * c + ch];
147 out[(y * w + x) * c + ch] = v00 * (1.0 - fx) * (1.0 - fy)
148 + v10 * fx * (1.0 - fy)
149 + v01 * (1.0 - fx) * fy
150 + v11 * fx * fy;
151 }
152 }
153 }
154 }
155 Ok(Tensor::from_vec(vec![h, w, c], out)?)
156}
157
158pub fn random_erasing(
162 image: &Tensor,
163 p: f32,
164 scale_min: f32,
165 scale_max: f32,
166 ratio_min: f32,
167 ratio_max: f32,
168 fill_value: f32,
169 seed: u64,
170) -> Result<Tensor, ImgProcError> {
171 let (h, w, c) = hwc_shape(image)?;
172 let mut rng = Rng::new(seed);
173 if rng.uniform() >= p {
174 return Ok(image.clone());
175 }
176
177 let area = (h * w) as f32;
178 let target_area = area * rng.uniform_range(scale_min, scale_max);
179 let ratio = rng.uniform_range(ratio_min, ratio_max);
180 let eh = (target_area * ratio).sqrt() as usize;
181 let ew = (target_area / ratio).sqrt() as usize;
182 let eh = eh.min(h);
183 let ew = ew.min(w);
184
185 let y0 = (rng.uniform() * (h - eh + 1) as f32) as usize;
186 let x0 = (rng.uniform() * (w - ew + 1) as f32) as usize;
187
188 let mut out = image.data().to_vec();
189 for y in y0..y0 + eh {
190 for x in x0..x0 + ew {
191 let base = (y * w + x) * c;
192 for ch in 0..c {
193 out[base + ch] = fill_value;
194 }
195 }
196 }
197 Ok(Tensor::from_vec(vec![h, w, c], out)?)
198}
199
200pub fn color_jitter(
204 image: &Tensor,
205 brightness: f32,
206 contrast: f32,
207 saturation: f32,
208 hue: f32,
209 seed: u64,
210) -> Result<Tensor, ImgProcError> {
211 let (h, w, c) = hwc_shape(image)?;
212 if c != 3 {
213 return Err(ImgProcError::InvalidChannelCount {
214 expected: 3,
215 got: c,
216 });
217 }
218 let mut rng = Rng::new(seed);
219 let data = image.data();
220 let mut out = data.to_vec();
221
222 if brightness > 0.0 {
224 let factor = rng
225 .uniform_range(1.0 - brightness, 1.0 + brightness)
226 .max(0.0);
227 for v in out.iter_mut() {
228 *v *= factor;
229 }
230 }
231
232 if contrast > 0.0 {
234 let factor = rng.uniform_range(1.0 - contrast, 1.0 + contrast).max(0.0);
235 let mean: f32 = out.iter().sum::<f32>() / out.len() as f32;
236 for v in out.iter_mut() {
237 *v = (*v - mean) * factor + mean;
238 }
239 }
240
241 if saturation > 0.0 {
243 let factor = rng
244 .uniform_range(1.0 - saturation, 1.0 + saturation)
245 .max(0.0);
246 for i in 0..(h * w) {
247 let base = i * 3;
248 let gray = 0.299 * out[base] + 0.587 * out[base + 1] + 0.114 * out[base + 2];
249 for ch in 0..3 {
250 out[base + ch] = gray + (out[base + ch] - gray) * factor;
251 }
252 }
253 }
254
255 if hue > 0.0 {
257 let angle = rng.uniform_range(-hue, hue) * std::f32::consts::PI;
258 let cos_h = angle.cos();
259 let sin_h = angle.sin();
260 let sqrt3 = 3.0f32.sqrt();
261 for i in 0..(h * w) {
262 let base = i * 3;
263 let r = out[base];
264 let g = out[base + 1];
265 let b = out[base + 2];
266 out[base] =
267 (r + g + b) / 3.0 + (2.0 * r - g - b) / 3.0 * cos_h + (g - b) / sqrt3 * sin_h;
268 out[base + 1] = (r + g + b) / 3.0 - (2.0 * r - g - b) / 6.0 * cos_h
269 + (2.0 * b - 2.0 * g + 2.0 * r - g - b) / (2.0 * sqrt3) * sin_h;
270 out[base + 2] = r + g + b - out[base] - out[base + 1];
272 }
273 }
274
275 for v in out.iter_mut() {
277 *v = v.clamp(0.0, 1.0);
278 }
279
280 Ok(Tensor::from_vec(vec![h, w, c], out)?)
281}
282
283pub fn elastic_transform(
288 image: &Tensor,
289 alpha: f32,
290 sigma: f32,
291 seed: u64,
292) -> Result<Tensor, ImgProcError> {
293 let (h, w, c) = hwc_shape(image)?;
294 let mut rng = Rng::new(seed);
295
296 let n = h * w;
298 let mut dx: Vec<f32> = (0..n).map(|_| rng.uniform_range(-1.0, 1.0)).collect();
299 let mut dy: Vec<f32> = (0..n).map(|_| rng.uniform_range(-1.0, 1.0)).collect();
300
301 let kernel_size = (sigma * 3.0) as usize | 1; let half_k = kernel_size / 2;
304 for _ in 0..2 {
305 let dx_copy = dx.clone();
307 let dy_copy = dy.clone();
308 for y in 0..h {
309 for x in 0..w {
310 let mut sx = 0.0f32;
311 let mut sy = 0.0f32;
312 let mut count = 0.0f32;
313 for ky in y.saturating_sub(half_k)..=(y + half_k).min(h - 1) {
314 for kx in x.saturating_sub(half_k)..=(x + half_k).min(w - 1) {
315 sx += dx_copy[ky * w + kx];
316 sy += dy_copy[ky * w + kx];
317 count += 1.0;
318 }
319 }
320 dx[y * w + x] = sx / count;
321 dy[y * w + x] = sy / count;
322 }
323 }
324 }
325
326 for v in dx.iter_mut() {
328 *v *= alpha;
329 }
330 for v in dy.iter_mut() {
331 *v *= alpha;
332 }
333
334 let data = image.data();
336 let mut out = vec![0.0f32; h * w * c];
337 for y in 0..h {
338 for x in 0..w {
339 let src_x = x as f32 + dx[y * w + x];
340 let src_y = y as f32 + dy[y * w + x];
341
342 if src_x >= 0.0 && src_x < (w - 1) as f32 && src_y >= 0.0 && src_y < (h - 1) as f32 {
343 let x0 = src_x.floor() as usize;
344 let y0 = src_y.floor() as usize;
345 let x1 = x0 + 1;
346 let y1 = y0 + 1;
347 let fx = src_x - x0 as f32;
348 let fy = src_y - y0 as f32;
349 for ch in 0..c {
350 let v00 = data[(y0 * w + x0) * c + ch];
351 let v10 = data[(y0 * w + x1) * c + ch];
352 let v01 = data[(y1 * w + x0) * c + ch];
353 let v11 = data[(y1 * w + x1) * c + ch];
354 out[(y * w + x) * c + ch] = v00 * (1.0 - fx) * (1.0 - fy)
355 + v10 * fx * (1.0 - fy)
356 + v01 * (1.0 - fx) * fy
357 + v11 * fx * fy;
358 }
359 }
360 }
361 }
362 Ok(Tensor::from_vec(vec![h, w, c], out)?)
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `h` x `w` RGB image whose values ramp linearly from 0
    /// towards 1 in memory order.
    fn test_image(h: usize, w: usize) -> Tensor {
        let total = h * w * 3;
        let mut data = Vec::with_capacity(total);
        for i in 0..total {
            data.push((i as f32) / total as f32);
        }
        Tensor::from_vec(vec![h, w, 3], data).unwrap()
    }

    #[test]
    fn test_random_crop_shape() {
        let source = test_image(32, 32);
        let result = random_crop(&source, 16, 16, 42).unwrap();
        assert_eq!(result.shape(), &[16, 16, 3]);
    }

    #[test]
    fn test_random_horizontal_flip() {
        // One row, three single-channel pixels; p = 1.0 always flips.
        let source = Tensor::from_vec(vec![1, 3, 1], vec![1.0, 2.0, 3.0]).unwrap();
        let result = random_horizontal_flip(&source, 1.0, 42).unwrap();
        assert_eq!(result.data(), &[3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_vertical_flip() {
        // Three rows of one single-channel pixel; p = 1.0 always flips.
        let source = Tensor::from_vec(vec![3, 1, 1], vec![1.0, 2.0, 3.0]).unwrap();
        let result = random_vertical_flip(&source, 1.0, 42).unwrap();
        assert_eq!(result.data(), &[3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_rotation_preserves_shape() {
        let source = test_image(20, 20);
        let result = random_rotation(&source, 30.0, 42).unwrap();
        assert_eq!(result.shape(), &[20, 20, 3]);
    }

    #[test]
    fn test_random_erasing_modifies_pixels() {
        // All-ones input, so any 0.0 in the output was written by the eraser.
        let source = Tensor::from_vec(vec![10, 10, 3], vec![1.0f32; 300]).unwrap();
        let result = random_erasing(&source, 1.0, 0.1, 0.3, 0.5, 2.0, 0.0, 42).unwrap();
        let any_erased = result.data().iter().any(|&v| v == 0.0);
        assert!(any_erased, "expected some erased pixels");
    }

    #[test]
    fn test_color_jitter_preserves_shape() {
        let source = test_image(8, 8);
        let result = color_jitter(&source, 0.2, 0.2, 0.2, 0.0, 42).unwrap();
        assert_eq!(result.shape(), &[8, 8, 3]);
        // The final clamp must keep every value inside [0, 1].
        assert!(result.data().iter().all(|v| (0.0..=1.0).contains(v)));
    }

    #[test]
    fn test_elastic_transform_preserves_shape() {
        let source = test_image(16, 16);
        let result = elastic_transform(&source, 10.0, 3.0, 42).unwrap();
        assert_eq!(result.shape(), &[16, 16, 3]);
    }
}