tensorlogic_train/augmentation/
functional.rs1use scirs2_core::ndarray::{ArrayD, Dimension, IxDyn};
2
3use super::error::AugmentationError;
4use super::rng::{sample_beta_symmetric, AugRng};
5
6pub fn gaussian_noise(
8 input: &ArrayD<f64>,
9 std: f64,
10 rng: &mut AugRng,
11) -> Result<ArrayD<f64>, AugmentationError> {
12 if std < 0.0 {
13 return Err(AugmentationError::InvalidNoise { std });
14 }
15 if input.is_empty() {
16 return Err(AugmentationError::EmptyInput);
17 }
18 let noisy = input.mapv(|x| x + rng.next_normal() * std);
19 Ok(noisy)
20}
21
22pub fn dropout(
26 input: &ArrayD<f64>,
27 p: f64,
28 training: bool,
29 rng: &mut AugRng,
30) -> Result<ArrayD<f64>, AugmentationError> {
31 if !(0.0..=1.0).contains(&p) {
32 return Err(AugmentationError::InvalidProbability(p));
33 }
34 if !training {
35 return Ok(input.clone());
36 }
37 let scale = if (p - 1.0).abs() < 1e-12 {
38 0.0
39 } else {
40 1.0 / (1.0 - p)
41 };
42 let result = input.mapv(|x| if rng.next_bool(p) { 0.0 } else { x * scale });
43 Ok(result)
44}
45
46pub fn dropout_mask(
50 shape: &[usize],
51 p: f64,
52 rng: &mut AugRng,
53) -> Result<ArrayD<f64>, AugmentationError> {
54 if !(0.0..=1.0).contains(&p) {
55 return Err(AugmentationError::InvalidProbability(p));
56 }
57 let total: usize = shape.iter().product();
58 let data: Vec<f64> = (0..total)
59 .map(|_| if rng.next_bool(p) { 0.0 } else { 1.0 })
60 .collect();
61 ArrayD::from_shape_vec(IxDyn(shape), data).map_err(|_| AugmentationError::EmptyInput)
62}
63
64pub fn mixup(
68 x1: &ArrayD<f64>,
69 x2: &ArrayD<f64>,
70 alpha: f64,
71 rng: &mut AugRng,
72) -> Result<(ArrayD<f64>, f64), AugmentationError> {
73 if alpha <= 0.0 {
74 return Err(AugmentationError::InvalidAlpha(alpha));
75 }
76 if x1.shape() != x2.shape() {
77 return Err(AugmentationError::ShapeMismatch {
78 expected: x1.shape().to_vec(),
79 got: x2.shape().to_vec(),
80 });
81 }
82 if x1.is_empty() {
83 return Err(AugmentationError::EmptyInput);
84 }
85 let lambda = sample_beta_symmetric(alpha, rng);
86 let mixed = x1.mapv(|v| v * lambda) + x2.mapv(|v| v * (1.0 - lambda));
87 Ok((mixed, lambda))
88}
89
90pub fn cutmix(
96 x1: &ArrayD<f64>,
97 x2: &ArrayD<f64>,
98 alpha: f64,
99 rng: &mut AugRng,
100) -> Result<(ArrayD<f64>, f64), AugmentationError> {
101 if alpha <= 0.0 {
102 return Err(AugmentationError::InvalidAlpha(alpha));
103 }
104 if x1.shape() != x2.shape() {
105 return Err(AugmentationError::ShapeMismatch {
106 expected: x1.shape().to_vec(),
107 got: x2.shape().to_vec(),
108 });
109 }
110 if x1.ndim() < 2 {
111 return Err(AugmentationError::ShapeMismatch {
112 expected: vec![2],
113 got: x1.shape().to_vec(),
114 });
115 }
116 if x1.is_empty() {
117 return Err(AugmentationError::EmptyInput);
118 }
119
120 let ndim = x1.ndim();
121 let h = x1.shape()[ndim - 2];
122 let w = x1.shape()[ndim - 1];
123
124 let lambda_raw = sample_beta_symmetric(alpha, rng);
126 let cut_ratio = (1.0 - lambda_raw).sqrt();
128 let cut_h = ((h as f64 * cut_ratio) as usize).max(1).min(h);
129 let cut_w = ((w as f64 * cut_ratio) as usize).max(1).min(w);
130
131 let top = if h > cut_h {
133 rng.next_usize(h - cut_h + 1)
134 } else {
135 0
136 };
137 let left = if w > cut_w {
138 rng.next_usize(w - cut_w + 1)
139 } else {
140 0
141 };
142
143 let actual_lambda = 1.0 - (cut_h * cut_w) as f64 / (h * w) as f64;
144
145 let mut mixed = x1.clone();
146
147 for (idx, val) in mixed.indexed_iter_mut() {
149 let raw = idx.slice();
150 let ih = raw[ndim - 2];
151 let iw = raw[ndim - 1];
152 if ih >= top && ih < top + cut_h && iw >= left && iw < left + cut_w {
153 *val = x2[idx.clone()];
154 }
155 }
156
157 Ok((mixed, actual_lambda))
158}
159
160pub fn random_crop_2d(
164 input: &ArrayD<f64>,
165 crop_h: usize,
166 crop_w: usize,
167 rng: &mut AugRng,
168) -> Result<ArrayD<f64>, AugmentationError> {
169 let ndim = input.ndim();
170 if ndim < 2 {
171 return Err(AugmentationError::InvalidCrop {
172 crop_size: crop_h,
173 input_size: 0,
174 });
175 }
176 let h = input.shape()[ndim - 2];
177 let w = input.shape()[ndim - 1];
178 if crop_h > h {
179 return Err(AugmentationError::InvalidCrop {
180 crop_size: crop_h,
181 input_size: h,
182 });
183 }
184 if crop_w > w {
185 return Err(AugmentationError::InvalidCrop {
186 crop_size: crop_w,
187 input_size: w,
188 });
189 }
190 let top = if h > crop_h {
191 rng.next_usize(h - crop_h + 1)
192 } else {
193 0
194 };
195 let left = if w > crop_w {
196 rng.next_usize(w - crop_w + 1)
197 } else {
198 0
199 };
200
201 crop_2d_impl(input, top, left, crop_h, crop_w)
202}
203
204pub fn center_crop_2d(
206 input: &ArrayD<f64>,
207 crop_h: usize,
208 crop_w: usize,
209) -> Result<ArrayD<f64>, AugmentationError> {
210 let ndim = input.ndim();
211 if ndim < 2 {
212 return Err(AugmentationError::InvalidCrop {
213 crop_size: crop_h,
214 input_size: 0,
215 });
216 }
217 let h = input.shape()[ndim - 2];
218 let w = input.shape()[ndim - 1];
219 if crop_h > h {
220 return Err(AugmentationError::InvalidCrop {
221 crop_size: crop_h,
222 input_size: h,
223 });
224 }
225 if crop_w > w {
226 return Err(AugmentationError::InvalidCrop {
227 crop_size: crop_w,
228 input_size: w,
229 });
230 }
231 let top = (h - crop_h) / 2;
232 let left = (w - crop_w) / 2;
233 crop_2d_impl(input, top, left, crop_h, crop_w)
234}
235
236fn crop_2d_impl(
238 input: &ArrayD<f64>,
239 top: usize,
240 left: usize,
241 crop_h: usize,
242 crop_w: usize,
243) -> Result<ArrayD<f64>, AugmentationError> {
244 let ndim = input.ndim();
245
246 let mut out_shape = input.shape().to_vec();
248 out_shape[ndim - 2] = crop_h;
249 out_shape[ndim - 1] = crop_w;
250
251 let total: usize = out_shape.iter().product();
252 let mut data = Vec::with_capacity(total);
253
254 for flat in 0..total {
257 let mut rem = flat;
258 let mut out_idx = vec![0usize; ndim];
259 for d in (0..ndim).rev() {
260 out_idx[d] = rem % out_shape[d];
261 rem /= out_shape[d];
262 }
263 let mut src_idx = out_idx.clone();
265 src_idx[ndim - 2] += top;
266 src_idx[ndim - 1] += left;
267
268 let v = input[IxDyn(&src_idx)];
269 data.push(v);
270 }
271
272 ArrayD::from_shape_vec(IxDyn(&out_shape), data).map_err(|_| AugmentationError::EmptyInput)
273}
274
275pub fn random_hflip(
277 input: &ArrayD<f64>,
278 p: f64,
279 rng: &mut AugRng,
280) -> Result<ArrayD<f64>, AugmentationError> {
281 if !(0.0..=1.0).contains(&p) {
282 return Err(AugmentationError::InvalidProbability(p));
283 }
284 if !rng.next_bool(p) {
285 return Ok(input.clone());
286 }
287 hflip_impl(input)
288}
289
290pub fn random_vflip(
292 input: &ArrayD<f64>,
293 p: f64,
294 rng: &mut AugRng,
295) -> Result<ArrayD<f64>, AugmentationError> {
296 if !(0.0..=1.0).contains(&p) {
297 return Err(AugmentationError::InvalidProbability(p));
298 }
299 if !rng.next_bool(p) {
300 return Ok(input.clone());
301 }
302 vflip_impl(input)
303}
304
305fn hflip_impl(input: &ArrayD<f64>) -> Result<ArrayD<f64>, AugmentationError> {
307 let ndim = input.ndim();
308 if ndim < 2 {
309 return Err(AugmentationError::InvalidCrop {
310 crop_size: 0,
311 input_size: 0,
312 });
313 }
314 let w = input.shape()[ndim - 1];
315 let shape = input.shape().to_vec();
316 let total: usize = shape.iter().product();
317 let mut data = vec![0.0f64; total];
318
319 for (flat, val) in input.iter().enumerate() {
320 let mut rem = flat;
321 let mut idx = vec![0usize; ndim];
322 for d in (0..ndim).rev() {
323 idx[d] = rem % shape[d];
324 rem /= shape[d];
325 }
326 idx[ndim - 1] = w - 1 - idx[ndim - 1];
328 let mut dst_flat = 0usize;
329 let mut stride = 1usize;
330 for d in (0..ndim).rev() {
331 dst_flat += idx[d] * stride;
332 stride *= shape[d];
333 }
334 data[dst_flat] = *val;
335 }
336
337 ArrayD::from_shape_vec(IxDyn(&shape), data).map_err(|_| AugmentationError::EmptyInput)
338}
339
340fn vflip_impl(input: &ArrayD<f64>) -> Result<ArrayD<f64>, AugmentationError> {
342 let ndim = input.ndim();
343 if ndim < 2 {
344 return Err(AugmentationError::InvalidCrop {
345 crop_size: 0,
346 input_size: 0,
347 });
348 }
349 let h = input.shape()[ndim - 2];
350 let shape = input.shape().to_vec();
351 let total: usize = shape.iter().product();
352 let mut data = vec![0.0f64; total];
353
354 for (flat, val) in input.iter().enumerate() {
355 let mut rem = flat;
356 let mut idx = vec![0usize; ndim];
357 for d in (0..ndim).rev() {
358 idx[d] = rem % shape[d];
359 rem /= shape[d];
360 }
361 idx[ndim - 2] = h - 1 - idx[ndim - 2];
363 let mut dst_flat = 0usize;
364 let mut stride = 1usize;
365 for d in (0..ndim).rev() {
366 dst_flat += idx[d] * stride;
367 stride *= shape[d];
368 }
369 data[dst_flat] = *val;
370 }
371
372 ArrayD::from_shape_vec(IxDyn(&shape), data).map_err(|_| AugmentationError::EmptyInput)
373}
374
375pub fn normalize(
381 input: &ArrayD<f64>,
382 mean: &[f64],
383 std: &[f64],
384) -> Result<ArrayD<f64>, AugmentationError> {
385 if mean.is_empty() || std.is_empty() {
386 return Err(AugmentationError::EmptyInput);
387 }
388 if input.is_empty() {
389 return Err(AugmentationError::EmptyInput);
390 }
391
392 let ndim = input.ndim();
393 let shape = input.shape().to_vec();
394
395 if ndim >= 3 {
397 let num_channels = shape[ndim - 3]; let m: Vec<f64> = broadcast_stats(mean, num_channels)?;
400 let s: Vec<f64> = broadcast_stats(std, num_channels)?;
401
402 let mut result = input.clone();
403 for (idx, val) in result.indexed_iter_mut() {
405 let raw = idx.slice();
406 let c = raw[ndim - 3];
407 *val = (*val - m[c]) / s[c];
408 }
409 Ok(result)
410 } else {
411 let m = mean[0];
413 let s = std[0];
414 Ok(input.mapv(|x| (x - m) / s))
415 }
416}
417
418pub fn denormalize(
420 input: &ArrayD<f64>,
421 mean: &[f64],
422 std: &[f64],
423) -> Result<ArrayD<f64>, AugmentationError> {
424 if mean.is_empty() || std.is_empty() {
425 return Err(AugmentationError::EmptyInput);
426 }
427 if input.is_empty() {
428 return Err(AugmentationError::EmptyInput);
429 }
430
431 let ndim = input.ndim();
432 let shape = input.shape().to_vec();
433
434 if ndim >= 3 {
435 let num_channels = shape[ndim - 3];
436 let m: Vec<f64> = broadcast_stats(mean, num_channels)?;
437 let s: Vec<f64> = broadcast_stats(std, num_channels)?;
438
439 let mut result = input.clone();
440 for (idx, val) in result.indexed_iter_mut() {
441 let raw = idx.slice();
442 let c = raw[ndim - 3];
443 *val = *val * s[c] + m[c];
444 }
445 Ok(result)
446 } else {
447 let m = mean[0];
448 let s = std[0];
449 Ok(input.mapv(|x| x * s + m))
450 }
451}
452
453fn broadcast_stats(stats: &[f64], n: usize) -> Result<Vec<f64>, AugmentationError> {
457 if stats.len() == 1 {
458 Ok(vec![stats[0]; n])
459 } else if stats.len() == n {
460 Ok(stats.to_vec())
461 } else {
462 Err(AugmentationError::ShapeMismatch {
463 expected: vec![n],
464 got: vec![stats.len()],
465 })
466 }
467}
468
469pub fn clip(input: &ArrayD<f64>, min_val: f64, max_val: f64) -> ArrayD<f64> {
471 input.mapv(|x| x.clamp(min_val, max_val))
472}