use rayon::prelude::*;
use yscv_tensor::Tensor;

use super::super::ImgProcError;
use super::super::shape::hwc_shape;
use super::filter::border_coords_3x3;

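/// Computes horizontal and vertical 3x3 Sobel gradients of an HWC `f32`
/// tensor. Interior rows take a SIMD fast path for single-channel input;
/// border pixels fall back to a scalar loop that treats samples outside
/// the image as zero. Returns `(gx, gy)`, each with the input's shape.
///
/// A minimal sketch of the call shape (`h`, `w`, `pixels` are illustrative):
///
/// ```ignore
/// let img = Tensor::from_vec(vec![h, w, 1], pixels)?;
/// let (gx, gy) = sobel_3x3_gradients(&img)?;
/// ```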
#[allow(unsafe_code)]
pub fn sobel_3x3_gradients(input: &Tensor) -> Result<(Tensor, Tensor), ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let len = h * w * channels;
    let mut out_gx = vec![0.0f32; len];
    let mut out_gy = vec![0.0f32; len];

    const SOBEL_X: [[f32; 3]; 3] = [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]];
    const SOBEL_Y: [[f32; 3]; 3] = [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]];

    let data = input.data();
    let row_len = w * channels;
    let interior_h = h.saturating_sub(2);

    let compute_interior_row = |y: usize, gx_row: &mut [f32], gy_row: &mut [f32]| {
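        // Single-channel fast path: SIMD across the row, scalar tail.
        // Skipped under Miri, which cannot execute vendor intrinsics.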
        if channels == 1 && !cfg!(miri) {
            let row0 = &data[(y - 1) * w..y * w];
            let row1 = &data[y * w..(y + 1) * w];
            let row2 = &data[(y + 1) * w..(y + 2) * w];
            let done = sobel_simd_row_c1(row0, row1, row2, gx_row, gy_row, w);
            for x in done..w.saturating_sub(1) {
                if x == 0 {
                    continue;
                }
                let gx =
                    row0[x + 1] - row0[x - 1] + 2.0 * (row1[x + 1] - row1[x - 1]) + row2[x + 1]
                        - row2[x - 1];
                let gy = row2[x - 1] + 2.0 * row2[x] + row2[x + 1]
                    - row0[x - 1]
                    - 2.0 * row0[x]
                    - row0[x + 1];
                gx_row[x] = gx;
                gy_row[x] = gy;
            }
            return;
        }

        for x in 1..w.saturating_sub(1) {
            for c in 0..channels {
                let r0 = ((y - 1) * w + x - 1) * channels + c;
                let r1 = (y * w + x - 1) * channels + c;
                let r2 = ((y + 1) * w + x - 1) * channels + c;
                let mut gx = 0.0f32;
                let mut gy = 0.0f32;
                gx += data[r0] * SOBEL_X[0][0];
                gx += data[r0 + 2 * channels] * SOBEL_X[0][2];
                gx += data[r1] * SOBEL_X[1][0];
                gx += data[r1 + 2 * channels] * SOBEL_X[1][2];
                gx += data[r2] * SOBEL_X[2][0];
                gx += data[r2 + 2 * channels] * SOBEL_X[2][2];

                gy += data[r0] * SOBEL_Y[0][0];
                gy += data[r0 + channels] * SOBEL_Y[0][1];
                gy += data[r0 + 2 * channels] * SOBEL_Y[0][2];
                gy += data[r2] * SOBEL_Y[2][0];
                gy += data[r2 + channels] * SOBEL_Y[2][1];
                gy += data[r2 + 2 * channels] * SOBEL_Y[2][2];

                gx_row[x * channels + c] = gx;
                gy_row[x * channels + c] = gy;
            }
        }
    };

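    // Interior rows: large images go through the gcd helper on macOS or
    // rayon elsewhere; small images stay on the current thread.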
    if interior_h > 0 {
        let pixels = h * w;
        let gx_interior = &mut out_gx[row_len..row_len + interior_h * row_len];
        let gy_interior = &mut out_gy[row_len..row_len + interior_h * row_len];
        #[cfg(target_os = "macos")]
        let use_gcd = pixels > 4096 && !cfg!(miri);
        #[cfg(not(target_os = "macos"))]
        let use_gcd = false;

        if use_gcd {
            #[cfg(target_os = "macos")]
            {
                let gx_ptr = gx_interior.as_mut_ptr() as usize;
                let gy_ptr = gy_interior.as_mut_ptr() as usize;
                use super::u8ops::gcd;
                gcd::parallel_for(interior_h, |i| {
                    let y = i + 1;
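                    // SAFETY: rows are disjoint, so each task has exclusive
                    // access to its slice; the pointers travel as usize to
                    // make the closure Send.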
                    let gx_row = unsafe {
                        std::slice::from_raw_parts_mut(
                            (gx_ptr as *mut f32).add(i * row_len),
                            row_len,
                        )
                    };
                    let gy_row = unsafe {
                        std::slice::from_raw_parts_mut(
                            (gy_ptr as *mut f32).add(i * row_len),
                            row_len,
                        )
                    };
                    compute_interior_row(y, gx_row, gy_row);
                });
            }
        } else if pixels > 4096 {
            gx_interior
                .par_chunks_mut(row_len)
                .zip(gy_interior.par_chunks_mut(row_len))
                .enumerate()
                .for_each(|(i, (gx_row, gy_row))| {
                    compute_interior_row(i + 1, gx_row, gy_row);
                });
        } else {
            gx_interior
                .chunks_mut(row_len)
                .zip(gy_interior.chunks_mut(row_len))
                .enumerate()
                .for_each(|(i, (gx_row, gy_row))| {
                    compute_interior_row(i + 1, gx_row, gy_row);
                });
        }
    }

    let border = border_coords_3x3(h, w);
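    // Border pixels: scalar 3x3 with zero padding outside the image.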
    for (y, x) in border {
        for c in 0..channels {
            let mut gx = 0.0f32;
            let mut gy = 0.0f32;
            for ky in -1isize..=1 {
                for kx in -1isize..=1 {
                    let sy = y as isize + ky;
                    let sx = x as isize + kx;
                    if sy < 0 || sx < 0 || sy >= h as isize || sx >= w as isize {
                        continue;
                    }
                    let src = ((sy as usize) * w + sx as usize) * channels + c;
                    let kernel_y = (ky + 1) as usize;
                    let kernel_x = (kx + 1) as usize;
                    let value = data[src];
                    gx += value * SOBEL_X[kernel_y][kernel_x];
                    gy += value * SOBEL_Y[kernel_y][kernel_x];
                }
            }
            let dst = (y * w + x) * channels + c;
            out_gx[dst] = gx;
            out_gy[dst] = gy;
        }
    }

    Ok((
        Tensor::from_vec(vec![h, w, channels], out_gx)?,
        Tensor::from_vec(vec![h, w, channels], out_gy)?,
    ))
}

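/// Dispatches one single-channel Sobel row to the best available SIMD
/// implementation (NEON, AVX, or SSE). Returns the first x index that was
/// not written, so the caller can finish the row with scalar code.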
#[allow(unsafe_code)]
fn sobel_simd_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    gx_out: &mut [f32],
    gy_out: &mut [f32],
    w: usize,
) -> usize {
    if w < 6 {
        return 1;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return unsafe { sobel_neon_row_c1(row0, row1, row2, gx_out, gy_out, w) };
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            return unsafe { sobel_avx_row_c1(row0, row1, row2, gx_out, gy_out, w) };
        }
        if std::is_x86_feature_detected!("sse") {
            return unsafe { sobel_sse_row_c1(row0, row1, row2, gx_out, gy_out, w) };
        }
    }

    1
}

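/// NEON kernel for one single-channel Sobel row: four pixels per iteration
/// via unaligned loads of the three source rows. Returns the first
/// unprocessed x index.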
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn sobel_neon_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    gx_out: &mut [f32],
    gy_out: &mut [f32],
    w: usize,
) -> usize {
    use std::arch::aarch64::*;
    let two = vdupq_n_f32(2.0);
    let mut x = 1usize;
    while x + 5 <= w {
        let r0l = vld1q_f32(row0.as_ptr().add(x - 1));
        let r0m = vld1q_f32(row0.as_ptr().add(x));
        let r0r = vld1q_f32(row0.as_ptr().add(x + 1));
        let r1l = vld1q_f32(row1.as_ptr().add(x - 1));
        let r1r = vld1q_f32(row1.as_ptr().add(x + 1));
        let r2l = vld1q_f32(row2.as_ptr().add(x - 1));
        let r2m = vld1q_f32(row2.as_ptr().add(x));
        let r2r = vld1q_f32(row2.as_ptr().add(x + 1));

        let dx0 = vsubq_f32(r0r, r0l);
        let dx1 = vsubq_f32(r1r, r1l);
        let dx2 = vsubq_f32(r2r, r2l);
        let gx = vaddq_f32(vaddq_f32(dx0, dx2), vmulq_f32(dx1, two));

        let sy0 = vaddq_f32(vaddq_f32(r0l, r0r), vmulq_f32(r0m, two));
        let sy2 = vaddq_f32(vaddq_f32(r2l, r2r), vmulq_f32(r2m, two));
        let gy = vsubq_f32(sy2, sy0);

        vst1q_f32(gx_out.as_mut_ptr().add(x), gx);
        vst1q_f32(gy_out.as_mut_ptr().add(x), gy);
        x += 4;
    }
    x
}

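/// AVX kernel for one single-channel Sobel row: eight pixels per iteration.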
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn sobel_avx_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    gx_out: &mut [f32],
    gy_out: &mut [f32],
    w: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let two = _mm256_set1_ps(2.0);
    let mut x = 1usize;
    while x + 9 <= w {
        let r0l = _mm256_loadu_ps(row0.as_ptr().add(x - 1));
        let r0m = _mm256_loadu_ps(row0.as_ptr().add(x));
        let r0r = _mm256_loadu_ps(row0.as_ptr().add(x + 1));
        let r1l = _mm256_loadu_ps(row1.as_ptr().add(x - 1));
        let r1r = _mm256_loadu_ps(row1.as_ptr().add(x + 1));
        let r2l = _mm256_loadu_ps(row2.as_ptr().add(x - 1));
        let r2m = _mm256_loadu_ps(row2.as_ptr().add(x));
        let r2r = _mm256_loadu_ps(row2.as_ptr().add(x + 1));

        let dx0 = _mm256_sub_ps(r0r, r0l);
        let dx1 = _mm256_sub_ps(r1r, r1l);
        let dx2 = _mm256_sub_ps(r2r, r2l);
        let gx = _mm256_add_ps(_mm256_add_ps(dx0, dx2), _mm256_mul_ps(dx1, two));

        let sy0 = _mm256_add_ps(_mm256_add_ps(r0l, r0r), _mm256_mul_ps(r0m, two));
        let sy2 = _mm256_add_ps(_mm256_add_ps(r2l, r2r), _mm256_mul_ps(r2m, two));
        let gy = _mm256_sub_ps(sy2, sy0);

        _mm256_storeu_ps(gx_out.as_mut_ptr().add(x), gx);
        _mm256_storeu_ps(gy_out.as_mut_ptr().add(x), gy);
        x += 8;
    }
    x
}

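/// SSE kernel for one single-channel Sobel row: four pixels per iteration.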
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn sobel_sse_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    gx_out: &mut [f32],
    gy_out: &mut [f32],
    w: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let two = _mm_set1_ps(2.0);
    let mut x = 1usize;
    while x + 5 <= w {
        let r0l = _mm_loadu_ps(row0.as_ptr().add(x - 1));
        let r0m = _mm_loadu_ps(row0.as_ptr().add(x));
        let r0r = _mm_loadu_ps(row0.as_ptr().add(x + 1));
        let r1l = _mm_loadu_ps(row1.as_ptr().add(x - 1));
        let r1r = _mm_loadu_ps(row1.as_ptr().add(x + 1));
        let r2l = _mm_loadu_ps(row2.as_ptr().add(x - 1));
        let r2m = _mm_loadu_ps(row2.as_ptr().add(x));
        let r2r = _mm_loadu_ps(row2.as_ptr().add(x + 1));

        let dx0 = _mm_sub_ps(r0r, r0l);
        let dx1 = _mm_sub_ps(r1r, r1l);
        let dx2 = _mm_sub_ps(r2r, r2l);
        let gx = _mm_add_ps(_mm_add_ps(dx0, dx2), _mm_mul_ps(dx1, two));

        let sy0 = _mm_add_ps(_mm_add_ps(r0l, r0r), _mm_mul_ps(r0m, two));
        let sy2 = _mm_add_ps(_mm_add_ps(r2l, r2r), _mm_mul_ps(r2m, two));
        let gy = _mm_sub_ps(sy2, sy0);

        _mm_storeu_ps(gx_out.as_mut_ptr().add(x), gx);
        _mm_storeu_ps(gy_out.as_mut_ptr().add(x), gy);
        x += 4;
    }
    x
}

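/// Computes the Sobel gradient magnitude `sqrt(gx*gx + gy*gy)` per element.
/// Large images are processed row-parallel (gcd on macOS) or chunk-parallel
/// (rayon); the inner loops use SIMD where available.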
#[allow(unsafe_code, clippy::uninit_vec)]
pub fn sobel_3x3_magnitude(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let (gx, gy) = sobel_3x3_gradients(input)?;
    let gx_data = gx.data();
    let gy_data = gy.data();
    let total = gx.len();
    let mut out = Vec::with_capacity(total);
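    // Every element of `out` is overwritten below before being read;
    // `clippy::uninit_vec` is allowed above for exactly this pattern.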
    unsafe {
        out.set_len(total);
    }

    let compute_chunk = |chunk: &mut [f32], start: usize| {
        let end = start + chunk.len();
        let mut i = start;

        if !cfg!(miri) {
            i = start + magnitude_simd(&gx_data[start..end], &gy_data[start..end], chunk);
        }

        while i < end {
            let x = gx_data[i];
            let y = gy_data[i];
            chunk[i - start] = (x * x + y * y).sqrt();
            i += 1;
        }
    };

    #[cfg(target_os = "macos")]
    if total > 4096 && !cfg!(miri) {
        let shape = gx.shape().to_vec();
        let (h, w, channels) = (shape[0], shape[1], shape[2]);
        let row_len = w * channels;
        let gx_ptr = gx_data.as_ptr() as usize;
        let gy_ptr = gy_data.as_ptr() as usize;
        let out_ptr = out.as_mut_ptr() as usize;
        use super::u8ops::gcd;
        gcd::parallel_for(h, |y| {
            let start = y * row_len;
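            // SAFETY: each task touches a disjoint row; the pointers travel
            // as usize so the closure is Send.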
            let gx_slice =
                unsafe { std::slice::from_raw_parts((gx_ptr as *const f32).add(start), row_len) };
            let gy_slice =
                unsafe { std::slice::from_raw_parts((gy_ptr as *const f32).add(start), row_len) };
            let dst = unsafe {
                std::slice::from_raw_parts_mut((out_ptr as *mut f32).add(start), row_len)
            };
            let mut i = magnitude_simd(gx_slice, gy_slice, dst);
            while i < row_len {
                let x = gx_slice[i];
                let yv = gy_slice[i];
                dst[i] = (x * x + yv * yv).sqrt();
                i += 1;
            }
        });
        return Tensor::from_vec(shape, out).map_err(Into::into);
    }

    if total > 4096 {
        out.par_chunks_mut(1024)
            .enumerate()
            .for_each(|(chunk_idx, chunk)| {
                compute_chunk(chunk, chunk_idx * 1024);
            });
    } else {
        compute_chunk(&mut out, 0);
    }

    Tensor::from_vec(gx.shape().to_vec(), out).map_err(Into::into)
}

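/// Dispatches elementwise `sqrt(x*x + y*y)` to NEON, AVX, or SSE. Returns
/// the number of leading elements written; the caller handles the rest.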
#[allow(unsafe_code)]
fn magnitude_simd(gx: &[f32], gy: &[f32], out: &mut [f32]) -> usize {
    let len = gx.len().min(gy.len()).min(out.len());
    if len < 4 {
        return 0;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return unsafe { magnitude_neon(gx, gy, out, len) };
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            return unsafe { magnitude_avx(gx, gy, out, len) };
        }
        if std::is_x86_feature_detected!("sse") {
            return unsafe { magnitude_sse(gx, gy, out, len) };
        }
    }

    0
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn magnitude_neon(gx: &[f32], gy: &[f32], out: &mut [f32], len: usize) -> usize {
    use std::arch::aarch64::*;
    let gxp = gx.as_ptr();
    let gyp = gy.as_ptr();
    let op = out.as_mut_ptr();
    let mut i = 0usize;
    while i + 4 <= len {
        let x = vld1q_f32(gxp.add(i));
        let y = vld1q_f32(gyp.add(i));
        let sq = vaddq_f32(vmulq_f32(x, x), vmulq_f32(y, y));
        vst1q_f32(op.add(i), vsqrtq_f32(sq));
        i += 4;
    }
    i
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn magnitude_avx(gx: &[f32], gy: &[f32], out: &mut [f32], len: usize) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let gxp = gx.as_ptr();
    let gyp = gy.as_ptr();
    let op = out.as_mut_ptr();
    let mut i = 0usize;
    while i + 8 <= len {
        let x = _mm256_loadu_ps(gxp.add(i));
        let y = _mm256_loadu_ps(gyp.add(i));
        let sq = _mm256_add_ps(_mm256_mul_ps(x, x), _mm256_mul_ps(y, y));
        _mm256_storeu_ps(op.add(i), _mm256_sqrt_ps(sq));
        i += 8;
    }
    i
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn magnitude_sse(gx: &[f32], gy: &[f32], out: &mut [f32], len: usize) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let gxp = gx.as_ptr();
    let gyp = gy.as_ptr();
    let op = out.as_mut_ptr();
    let mut i = 0usize;
    while i + 4 <= len {
        let x = _mm_loadu_ps(gxp.add(i));
        let y = _mm_loadu_ps(gyp.add(i));
        let sq = _mm_add_ps(_mm_mul_ps(x, x), _mm_mul_ps(y, y));
        _mm_storeu_ps(op.add(i), _mm_sqrt_ps(sq));
        i += 4;
    }
    i
}

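/// Computes 3x3 Scharr gradients (weights 3, 10, 3) of an HWC `f32` tensor.
/// Interior rows run scalar code, rayon-parallel for large images; border
/// pixels use zero padding. Returns `(gx, gy)`.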
pub fn scharr_3x3_gradients(input: &Tensor) -> Result<(Tensor, Tensor), ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let len = h * w * channels;
    let mut out_gx = vec![0.0f32; len];
    let mut out_gy = vec![0.0f32; len];

    const SCHARR_X: [[f32; 3]; 3] = [[-3.0, 0.0, 3.0], [-10.0, 0.0, 10.0], [-3.0, 0.0, 3.0]];
    const SCHARR_Y: [[f32; 3]; 3] = [[-3.0, -10.0, -3.0], [0.0, 0.0, 0.0], [3.0, 10.0, 3.0]];

    let data = input.data();
    let row_len = w * channels;
    let interior_h = h.saturating_sub(2);

    let compute_interior_row = |y: usize, gx_row: &mut [f32], gy_row: &mut [f32]| {
        for x in 1..w.saturating_sub(1) {
            for c in 0..channels {
                let r0 = ((y - 1) * w + x - 1) * channels + c;
                let r1 = (y * w + x - 1) * channels + c;
                let r2 = ((y + 1) * w + x - 1) * channels + c;
                let mut gx = 0.0f32;
                let mut gy = 0.0f32;
                gx += data[r0] * SCHARR_X[0][0];
                gx += data[r0 + 2 * channels] * SCHARR_X[0][2];
                gx += data[r1] * SCHARR_X[1][0];
                gx += data[r1 + 2 * channels] * SCHARR_X[1][2];
                gx += data[r2] * SCHARR_X[2][0];
                gx += data[r2 + 2 * channels] * SCHARR_X[2][2];

                gy += data[r0] * SCHARR_Y[0][0];
                gy += data[r0 + channels] * SCHARR_Y[0][1];
                gy += data[r0 + 2 * channels] * SCHARR_Y[0][2];
                gy += data[r2] * SCHARR_Y[2][0];
                gy += data[r2 + channels] * SCHARR_Y[2][1];
                gy += data[r2 + 2 * channels] * SCHARR_Y[2][2];

                gx_row[x * channels + c] = gx;
                gy_row[x * channels + c] = gy;
            }
        }
    };

    if interior_h > 0 {
        let gx_interior = &mut out_gx[row_len..row_len + interior_h * row_len];
        let gy_interior = &mut out_gy[row_len..row_len + interior_h * row_len];
        if h * w > 4096 {
            gx_interior
                .par_chunks_mut(row_len)
                .zip(gy_interior.par_chunks_mut(row_len))
                .enumerate()
                .for_each(|(i, (gx_row, gy_row))| {
                    compute_interior_row(i + 1, gx_row, gy_row);
                });
        } else {
            gx_interior
                .chunks_mut(row_len)
                .zip(gy_interior.chunks_mut(row_len))
                .enumerate()
                .for_each(|(i, (gx_row, gy_row))| {
                    compute_interior_row(i + 1, gx_row, gy_row);
                });
        }
    }

    let border = border_coords_3x3(h, w);
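    // Border pixels: scalar 3x3 with zero padding outside the image.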
    for (y, x) in border {
        for c in 0..channels {
            let mut gx = 0.0f32;
            let mut gy = 0.0f32;
            for ky in -1isize..=1 {
                for kx in -1isize..=1 {
                    let sy = y as isize + ky;
                    let sx = x as isize + kx;
                    if sy < 0 || sx < 0 || sy >= h as isize || sx >= w as isize {
                        continue;
                    }
                    let src = ((sy as usize) * w + sx as usize) * channels + c;
                    let kernel_y = (ky + 1) as usize;
                    let kernel_x = (kx + 1) as usize;
                    let value = data[src];
                    gx += value * SCHARR_X[kernel_y][kernel_x];
                    gy += value * SCHARR_Y[kernel_y][kernel_x];
                }
            }
            let dst = (y * w + x) * channels + c;
            out_gx[dst] = gx;
            out_gy[dst] = gy;
        }
    }

    let shape = vec![h, w, channels];
    let gx = Tensor::from_vec(shape.clone(), out_gx)?;
    let gy = Tensor::from_vec(shape, out_gy)?;
    Ok((gx, gy))
}

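/// Computes the Scharr gradient magnitude `sqrt(gx*gx + gy*gy)` per
/// element, chunk-parallel for large images.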
pub fn scharr_3x3_magnitude(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let (gx, gy) = scharr_3x3_gradients(input)?;
    let gx_data = gx.data();
    let gy_data = gy.data();
    let mut out = vec![0.0f32; gx.len()];

    if out.len() > 4096 {
        out.par_chunks_mut(1024)
            .enumerate()
            .for_each(|(chunk_idx, chunk)| {
                let start = chunk_idx * 1024;
                for (j, v) in chunk.iter_mut().enumerate() {
                    let i = start + j;
                    let x = gx_data[i];
                    let y = gy_data[i];
                    *v = (x * x + y * y).sqrt();
                }
            });
    } else {
        for (idx, value) in out.iter_mut().enumerate() {
            let x = gx_data[idx];
            let y = gy_data[idx];
            *value = (x * x + y * y).sqrt();
        }
    }

    Tensor::from_vec(gx.shape().to_vec(), out).map_err(Into::into)
}

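/// Mirrors the image left to right; channel order within a pixel is kept.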
pub fn flip_horizontal(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let mut out = vec![0.0f32; input.len()];

    for y in 0..h {
        for x in 0..w {
            let src_x = w - 1 - x;
            let dst_base = (y * w + x) * channels;
            let src_base = (y * w + src_x) * channels;
            out[dst_base..dst_base + channels]
                .copy_from_slice(&data[src_base..src_base + channels]);
        }
    }

    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
}

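/// Mirrors the image top to bottom.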
pub fn flip_vertical(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let mut out = vec![0.0f32; input.len()];

    for y in 0..h {
        let src_y = h - 1 - y;
        for x in 0..w {
            let dst_base = (y * w + x) * channels;
            let src_base = (src_y * w + x) * channels;
            out[dst_base..dst_base + channels]
                .copy_from_slice(&data[src_base..src_base + channels]);
        }
    }

    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
}

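/// Rotates the image 90 degrees clockwise; an H x W input yields W x H.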
pub fn rotate90_cw(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let (in_h, in_w, channels) = hwc_shape(input)?;
    let out_h = in_w;
    let out_w = in_h;
    let data = input.data();
    let mut out = vec![0.0f32; input.len()];

    for y in 0..in_h {
        for x in 0..in_w {
            let dst_y = x;
            let dst_x = in_h - 1 - y;
            let src_base = (y * in_w + x) * channels;
            let dst_base = (dst_y * out_w + dst_x) * channels;
            out[dst_base..dst_base + channels]
                .copy_from_slice(&data[src_base..src_base + channels]);
        }
    }

    Tensor::from_vec(vec![out_h, out_w, channels], out).map_err(Into::into)
}

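/// Pads the image with the constant `value` on each side; the output is
/// `(h + pad_top + pad_bottom) x (w + pad_left + pad_right)`.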
pub fn pad_constant(
    input: &Tensor,
    pad_top: usize,
    pad_bottom: usize,
    pad_left: usize,
    pad_right: usize,
    value: f32,
) -> Result<Tensor, ImgProcError> {
    let (in_h, in_w, channels) = hwc_shape(input)?;
    let out_h = in_h + pad_top + pad_bottom;
    let out_w = in_w + pad_left + pad_right;
    let mut out = vec![value; out_h * out_w * channels];

    let data = input.data();
    for y in 0..in_h {
        for x in 0..in_w {
            let src_base = (y * in_w + x) * channels;
            let dst_base = ((y + pad_top) * out_w + x + pad_left) * channels;
            out[dst_base..dst_base + channels]
                .copy_from_slice(&data[src_base..src_base + channels]);
        }
    }

    Tensor::from_vec(vec![out_h, out_w, channels], out).map_err(Into::into)
}

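/// Extracts the `crop_h x crop_w` window with its top-left corner at
/// `(top, left)`. Fails if the window is empty or extends past the image.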
pub fn crop(
    input: &Tensor,
    top: usize,
    left: usize,
    crop_h: usize,
    crop_w: usize,
) -> Result<Tensor, ImgProcError> {
    let (in_h, in_w, channels) = hwc_shape(input)?;
    if top + crop_h > in_h || left + crop_w > in_w || crop_h == 0 || crop_w == 0 {
        return Err(ImgProcError::InvalidSize {
            height: crop_h,
            width: crop_w,
        });
    }

    let data = input.data();
    let mut out = vec![0.0f32; crop_h * crop_w * channels];
    for y in 0..crop_h {
        for x in 0..crop_w {
            let src_base = ((y + top) * in_w + x + left) * channels;
            let dst_base = (y * crop_w + x) * channels;
            out[dst_base..dst_base + channels]
                .copy_from_slice(&data[src_base..src_base + channels]);
        }
    }

    Tensor::from_vec(vec![crop_h, crop_w, channels], out).map_err(Into::into)
}

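/// Applies an affine warp with bilinear sampling. `matrix` is the row-major
/// `[a00, a01, tx, a10, a11, ty]` and maps *output* coordinates to source
/// coordinates; output pixels whose source sample falls outside the image
/// keep `border_value`.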
pub fn warp_affine(
    input: &Tensor,
    out_h: usize,
    out_w: usize,
    matrix: &[f32; 6],
    border_value: f32,
) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    if out_h == 0 || out_w == 0 {
        return Err(ImgProcError::InvalidOutputDimensions { out_h, out_w });
    }

    let data = input.data();
    let mut out = vec![border_value; out_h * out_w * channels];
    let [a00, a01, tx, a10, a11, ty] = *matrix;

    for dy in 0..out_h {
        for dx in 0..out_w {
            let src_xf = a00 * dx as f32 + a01 * dy as f32 + tx;
            let src_yf = a10 * dx as f32 + a11 * dy as f32 + ty;

            let x0 = src_xf.floor() as isize;
            let y0 = src_yf.floor() as isize;
            let x1 = x0 + 1;
            let y1 = y0 + 1;

            if x0 < 0 || y0 < 0 || x1 >= w as isize || y1 >= h as isize {
                continue;
            }

            let fx = src_xf - x0 as f32;
            let fy = src_yf - y0 as f32;

            let x0u = x0 as usize;
            let y0u = y0 as usize;
            let x1u = x1 as usize;
            let y1u = y1 as usize;

            for c in 0..channels {
                let v00 = data[(y0u * w + x0u) * channels + c];
                let v01 = data[(y0u * w + x1u) * channels + c];
                let v10 = data[(y1u * w + x0u) * channels + c];
                let v11 = data[(y1u * w + x1u) * channels + c];

                let val = v00 * (1.0 - fx) * (1.0 - fy)
                    + v01 * fx * (1.0 - fy)
                    + v10 * (1.0 - fx) * fy
                    + v11 * fx * fy;

                out[(dy * out_w + dx) * channels + c] = val;
            }
        }
    }

    Tensor::from_vec(vec![out_h, out_w, channels], out).map_err(Into::into)
}

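/// Applies a perspective (homography) warp with bilinear sampling.
/// `transform` is a row-major 3x3 matrix; it is inverted internally so
/// every output pixel back-maps into the source, sampling at pixel centers
/// (`+0.5`). Fails if the matrix is not invertible.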
#[allow(clippy::too_many_arguments, unsafe_code)]
pub fn warp_perspective(
    input: &Tensor,
    transform: &[f32; 9],
    out_h: usize,
    out_w: usize,
    border_value: f32,
) -> Result<Tensor, ImgProcError> {
    let (ih, iw, channels) = hwc_shape(input)?;
    let in_data = input.data();

    let inv = invert_3x3(transform)
        .ok_or(ImgProcError::InvalidOutputDimensions { out_h: 0, out_w: 0 })?;

    let mut out = vec![border_value; out_h * out_w * channels];

    if channels == 1 {
        warp_perspective_c1(in_data, ih, iw, &inv, &mut out, out_h, out_w, border_value);
    } else {
        warp_perspective_scalar(in_data, ih, iw, channels, &inv, &mut out, out_h, out_w);
    }

    Tensor::from_vec(vec![out_h, out_w, channels], out).map_err(Into::into)
}

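/// Multi-channel scalar back-mapping loop for the perspective warp,
/// row-parallel via the `gcd` helper.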
#[allow(unsafe_code)]
fn warp_perspective_scalar(
    in_data: &[f32],
    ih: usize,
    iw: usize,
    channels: usize,
    inv: &[f32; 9],
    out: &mut [f32],
    out_h: usize,
    out_w: usize,
) {
    use super::u8ops::gcd;
    let out_ptr = out.as_mut_ptr() as usize;
    let in_ptr = in_data.as_ptr() as usize;
    let in_len = in_data.len();
    let row_stride = out_w * channels;
    let inv = *inv;

    gcd::parallel_for(out_h, |dy| {
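        // SAFETY: each task writes only its own output row; the input is
        // rebuilt as a shared read-only slice from the smuggled pointer.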
        let out_row = unsafe {
            std::slice::from_raw_parts_mut((out_ptr as *mut f32).add(dy * row_stride), row_stride)
        };
        let in_data = unsafe { std::slice::from_raw_parts(in_ptr as *const f32, in_len) };

        let yf = dy as f32 + 0.5;
        let base_num_x = inv[1] * yf + inv[2];
        let base_num_y = inv[4] * yf + inv[5];
        let base_den = inv[7] * yf + inv[8];

        for dx in 0..out_w {
            let xf = dx as f32 + 0.5;
            let denom = inv[6] * xf + base_den;
            if denom.abs() < 1e-10 {
                continue;
            }
            let inv_denom = 1.0 / denom;
            let sx = (inv[0] * xf + base_num_x) * inv_denom - 0.5;
            let sy = (inv[3] * xf + base_num_y) * inv_denom - 0.5;

            if sx < 0.0 || sy < 0.0 || sx >= (iw - 1) as f32 || sy >= (ih - 1) as f32 {
                continue;
            }
            let x0 = sx.floor() as usize;
            let y0 = sy.floor() as usize;
            let x1 = (x0 + 1).min(iw - 1);
            let y1 = (y0 + 1).min(ih - 1);
            let fx = sx - x0 as f32;
            let fy = sy - y0 as f32;

            for c in 0..channels {
                let v00 = in_data[(y0 * iw + x0) * channels + c];
                let v10 = in_data[(y0 * iw + x1) * channels + c];
                let v01 = in_data[(y1 * iw + x0) * channels + c];
                let v11 = in_data[(y1 * iw + x1) * channels + c];
                let val = v00 * (1.0 - fx) * (1.0 - fy)
                    + v10 * fx * (1.0 - fy)
                    + v01 * (1.0 - fx) * fy
                    + v11 * fx * fy;
                out_row[dx * channels + c] = val;
            }
        }
    });
}

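/// Single-channel back-mapping loop for the perspective warp: each output
/// row first runs a NEON or SSE fast path, then a scalar bilinear tail.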
#[allow(unsafe_code)]
fn warp_perspective_c1(
    in_data: &[f32],
    ih: usize,
    iw: usize,
    inv: &[f32; 9],
    out: &mut [f32],
    out_h: usize,
    out_w: usize,
    border_value: f32,
) {
    use super::u8ops::gcd;
    let iw_f = (iw - 1) as f32;
    let ih_f = (ih - 1) as f32;
    let out_ptr = out.as_mut_ptr() as usize;
    let in_ptr = in_data.as_ptr() as usize;
    let in_len = in_data.len();
    let inv = *inv;

    gcd::parallel_for(out_h, |dy| {
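        // SAFETY: each task writes only its own output row; the input is
        // rebuilt as a shared read-only slice from the smuggled pointer.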
        let out_row =
            unsafe { std::slice::from_raw_parts_mut((out_ptr as *mut f32).add(dy * out_w), out_w) };
        let in_data = unsafe { std::slice::from_raw_parts(in_ptr as *const f32, in_len) };

        let yf = dy as f32 + 0.5;
        let base_num_x = inv[1] * yf + inv[2];
        let base_num_y = inv[4] * yf + inv[5];
        let base_den = inv[7] * yf + inv[8];

        let mut dx = 0usize;

        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon") {
                dx = unsafe {
                    warp_perspective_c1_neon_row(
                        in_data,
                        ih,
                        iw,
                        &inv,
                        out_row,
                        out_w,
                        0,
                        yf,
                        base_num_x,
                        base_num_y,
                        base_den,
                        iw_f,
                        ih_f,
                        border_value,
                    )
                };
            }
        }
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if std::is_x86_feature_detected!("sse") {
                dx = unsafe {
                    warp_perspective_c1_sse_row(
                        in_data,
                        ih,
                        iw,
                        &inv,
                        out_row,
                        out_w,
                        0,
                        yf,
                        base_num_x,
                        base_num_y,
                        base_den,
                        iw_f,
                        ih_f,
                        border_value,
                    )
                };
            }
        }

        while dx < out_w {
            let xf = dx as f32 + 0.5;
            let denom = inv[6] * xf + base_den;
            if denom.abs() < 1e-10 {
                dx += 1;
                continue;
            }
            let inv_denom = 1.0 / denom;
            let sx = (inv[0] * xf + base_num_x) * inv_denom - 0.5;
            let sy = (inv[3] * xf + base_num_y) * inv_denom - 0.5;

            if sx >= 0.0 && sy >= 0.0 && sx < iw_f && sy < ih_f {
                let x0 = sx.floor() as usize;
                let y0 = sy.floor() as usize;
                let x1 = x0 + 1;
                let y1 = y0 + 1;
                let fx = sx - x0 as f32;
                let fy = sy - y0 as f32;
                let v00 = in_data[y0 * iw + x0];
                let v10 = in_data[y0 * iw + x1];
                let v01 = in_data[y1 * iw + x0];
                let v11 = in_data[y1 * iw + x1];
                out_row[dx] = v00 * (1.0 - fx) * (1.0 - fy)
                    + v10 * fx * (1.0 - fy)
                    + v01 * (1.0 - fx) * fy
                    + v11 * fx * fy;
            }
            dx += 1;
        }
    });
}

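/// NEON row kernel for the single-channel perspective warp: back-maps four
/// output pixels at a time, masks out-of-bounds lanes, and gathers the
/// bilinear taps with scalar loads. Returns the first x index not handled.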
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn, clippy::too_many_arguments)]
#[target_feature(enable = "neon")]
unsafe fn warp_perspective_c1_neon_row(
    in_data: &[f32],
    _ih: usize,
    iw: usize,
    inv: &[f32; 9],
    out: &mut [f32],
    out_w: usize,
    dy: usize,
    _yf: f32,
    base_num_x: f32,
    base_num_y: f32,
    base_den: f32,
    iw_f: f32,
    ih_f: f32,
    border_value: f32,
) -> usize {
    use std::arch::aarch64::*;

    let inv0 = vdupq_n_f32(inv[0]);
    let inv3 = vdupq_n_f32(inv[3]);
    let inv6 = vdupq_n_f32(inv[6]);
    let base_nx = vdupq_n_f32(base_num_x);
    let base_ny = vdupq_n_f32(base_num_y);
    let base_d = vdupq_n_f32(base_den);
    let half = vdupq_n_f32(0.5);
    let zero = vdupq_n_f32(0.0);
    let iw_max = vdupq_n_f32(iw_f);
    let ih_max = vdupq_n_f32(ih_f);
    let iw_s = vdupq_n_f32(iw as f32);
    let border = vdupq_n_f32(border_value);

    let row_start = dy * out_w;
    let out_ptr = out.as_mut_ptr().add(row_start);
    let in_ptr = in_data.as_ptr();

    let xf_init = [0.5f32, 1.5, 2.5, 3.5];
    let mut xf = vld1q_f32(xf_init.as_ptr());
    let xf_step = vdupq_n_f32(4.0);

    let mut dx = 0usize;
    while dx + 4 <= out_w {
        let denom = vmlaq_f32(base_d, inv6, xf);
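        // Approximate reciprocal refined with one Newton-Raphson step.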
        let recip_est = vrecpeq_f32(denom);
        let inv_denom = vmulq_f32(recip_est, vrecpsq_f32(denom, recip_est));

        let num_x = vmlaq_f32(base_nx, inv0, xf);
        let sx = vsubq_f32(vmulq_f32(num_x, inv_denom), half);

        let num_y = vmlaq_f32(base_ny, inv3, xf);
        let sy = vsubq_f32(vmulq_f32(num_y, inv_denom), half);

        let in_bounds = vandq_u32(
            vandq_u32(vcgeq_f32(sx, zero), vcgeq_f32(sy, zero)),
            vandq_u32(vcltq_f32(sx, iw_max), vcltq_f32(sy, ih_max)),
        );

        let mask_bits = vgetq_lane_u32(in_bounds, 0)
            | vgetq_lane_u32(in_bounds, 1)
            | vgetq_lane_u32(in_bounds, 2)
            | vgetq_lane_u32(in_bounds, 3);
        if mask_bits == 0 {
            xf = vaddq_f32(xf, xf_step);
            dx += 4;
            continue;
        }

        let sx_floor = vrndmq_f32(sx);
        let sy_floor = vrndmq_f32(sy);
        let fx = vsubq_f32(sx, sx_floor);
        let fy = vsubq_f32(sy, sy_floor);

        let idx_base = vmlaq_f32(sx_floor, sy_floor, iw_s);

        let fx_arr: [f32; 4] = std::mem::transmute(fx);
        let fy_arr: [f32; 4] = std::mem::transmute(fy);
        let idx_arr: [f32; 4] = std::mem::transmute(idx_base);
        let mask_arr: [u32; 4] = std::mem::transmute(in_bounds);
        let mut res_arr: [f32; 4] = std::mem::transmute(border);
        for i in 0..4 {
            if mask_arr[i] != 0 {
                let base = idx_arr[i] as usize;
                let v00 = *in_ptr.add(base);
                let v10 = *in_ptr.add(base + 1);
                let v01 = *in_ptr.add(base + iw);
                let v11 = *in_ptr.add(base + iw + 1);
                let fxi = fx_arr[i];
                let fyi = fy_arr[i];
                res_arr[i] = v00 * (1.0 - fxi) * (1.0 - fyi)
                    + v10 * fxi * (1.0 - fyi)
                    + v01 * (1.0 - fxi) * fyi
                    + v11 * fxi * fyi;
            }
        }
        vst1q_f32(out_ptr.add(dx), vld1q_f32(res_arr.as_ptr()));

        xf = vaddq_f32(xf, xf_step);
        dx += 4;
    }
    dx
}

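/// SSE row kernel for the single-channel perspective warp; same structure
/// as the NEON version, with `movemask` for the bounds test. Truncation
/// stands in for floor, which is safe because masked-in coordinates are
/// non-negative.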
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn, clippy::too_many_arguments)]
#[target_feature(enable = "sse")]
unsafe fn warp_perspective_c1_sse_row(
    in_data: &[f32],
    _ih: usize,
    iw: usize,
    inv: &[f32; 9],
    out: &mut [f32],
    out_w: usize,
    dy: usize,
    _yf: f32,
    base_num_x: f32,
    base_num_y: f32,
    base_den: f32,
    iw_f: f32,
    ih_f: f32,
    border_value: f32,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let inv0 = _mm_set1_ps(inv[0]);
    let inv3 = _mm_set1_ps(inv[3]);
    let inv6 = _mm_set1_ps(inv[6]);
    let base_nx = _mm_set1_ps(base_num_x);
    let base_ny = _mm_set1_ps(base_num_y);
    let base_d = _mm_set1_ps(base_den);
    let half = _mm_set1_ps(0.5);
    let zero = _mm_setzero_ps();
    let iw_max = _mm_set1_ps(iw_f);
    let ih_max = _mm_set1_ps(ih_f);
    let one = _mm_set1_ps(1.0);
    let iw_s = _mm_set1_ps(iw as f32);

    let row_start = dy * out_w;
    let out_ptr = out.as_mut_ptr().add(row_start);
    let in_ptr = in_data.as_ptr();

    let mut dx = 0usize;
    while dx + 4 <= out_w {
        let xf = _mm_set_ps(
            dx as f32 + 3.5,
            dx as f32 + 2.5,
            dx as f32 + 1.5,
            dx as f32 + 0.5,
        );

        let denom = _mm_add_ps(_mm_mul_ps(inv6, xf), base_d);
        let inv_denom = _mm_div_ps(one, denom);

        let num_x = _mm_add_ps(_mm_mul_ps(inv0, xf), base_nx);
        let sx = _mm_sub_ps(_mm_mul_ps(num_x, inv_denom), half);

        let num_y = _mm_add_ps(_mm_mul_ps(inv3, xf), base_ny);
        let sy = _mm_sub_ps(_mm_mul_ps(num_y, inv_denom), half);

        let in_bounds = _mm_and_ps(
            _mm_and_ps(_mm_cmpge_ps(sx, zero), _mm_cmpge_ps(sy, zero)),
            _mm_and_ps(_mm_cmplt_ps(sx, iw_max), _mm_cmplt_ps(sy, ih_max)),
        );

        let mask_bits = _mm_movemask_ps(in_bounds);
        if mask_bits == 0 {
            dx += 4;
            continue;
        }

        let sx_floor = _mm_cvtepi32_ps(_mm_cvttps_epi32(sx));
        let sy_floor = _mm_cvtepi32_ps(_mm_cvttps_epi32(sy));
        let fx = _mm_sub_ps(sx, sx_floor);
        let fy = _mm_sub_ps(sy, sy_floor);

        let idx_base = _mm_add_ps(_mm_mul_ps(sy_floor, iw_s), sx_floor);

        let mut res_arr = [border_value; 4];
        let fx_arr: [f32; 4] = std::mem::transmute(fx);
        let fy_arr: [f32; 4] = std::mem::transmute(fy);
        let idx_arr: [f32; 4] = std::mem::transmute(idx_base);

        for i in 0..4 {
            if (mask_bits >> i) & 1 != 0 {
                let base = idx_arr[i] as usize;
                let v00 = *in_ptr.add(base);
                let v10 = *in_ptr.add(base + 1);
                let v01 = *in_ptr.add(base + iw);
                let v11 = *in_ptr.add(base + iw + 1);
                let fxi = fx_arr[i];
                let fyi = fy_arr[i];
                res_arr[i] = v00 * (1.0 - fxi) * (1.0 - fyi)
                    + v10 * fxi * (1.0 - fyi)
                    + v01 * (1.0 - fxi) * fyi
                    + v11 * fxi * fyi;
            }
        }
        _mm_storeu_ps(out_ptr.add(dx), _mm_loadu_ps(res_arr.as_ptr()));

        dx += 4;
    }
    dx
}

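/// Inverts a row-major 3x3 matrix via its adjugate; returns `None` when
/// the determinant is smaller than 1e-10 in magnitude.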
pub(crate) fn invert_3x3(m: &[f32; 9]) -> Option<[f32; 9]> {
    let det = m[0] * (m[4] * m[8] - m[5] * m[7]) - m[1] * (m[3] * m[8] - m[5] * m[6])
        + m[2] * (m[3] * m[7] - m[4] * m[6]);
    if det.abs() < 1e-10 {
        return None;
    }
    let inv_det = 1.0 / det;
    Some([
        (m[4] * m[8] - m[5] * m[7]) * inv_det,
        (m[2] * m[7] - m[1] * m[8]) * inv_det,
        (m[1] * m[5] - m[2] * m[4]) * inv_det,
        (m[5] * m[6] - m[3] * m[8]) * inv_det,
        (m[0] * m[8] - m[2] * m[6]) * inv_det,
        (m[2] * m[3] - m[0] * m[5]) * inv_det,
        (m[3] * m[7] - m[4] * m[6]) * inv_det,
        (m[1] * m[6] - m[0] * m[7]) * inv_det,
        (m[0] * m[4] - m[1] * m[3]) * inv_det,
    ])
}