use rayon::prelude::*;
use yscv_tensor::Tensor;

use super::super::ImgProcError;
use super::super::shape::hwc_shape;

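// Runtime-dispatched SIMD kernel for one interior row of the 3×3 box blur
// (single-channel). Returns the first x index that was NOT processed, so the
// scalar caller can finish the tail; a return of 1 means "nothing vectorized".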
#[allow(unsafe_code)]
fn box_blur_simd_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    if w < 6 {
        return 1;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return unsafe { box_blur_neon_row_c1(row0, row1, row2, out, w) };
        }
    }
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            return unsafe { box_blur_avx_row_c1(row0, row1, row2, out, w) };
        }
    }
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("sse") {
            return unsafe { box_blur_sse_row_c1(row0, row1, row2, out, w) };
        }
    }
    1
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn box_blur_neon_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    use std::arch::aarch64::*;
    let inv9 = vdupq_n_f32(1.0 / 9.0);
    let mut x = 1usize;
    // Each iteration writes out[x..x + 4] and reads columns x - 1..=x + 4,
    // so the widest read (x + 4) must stay within w - 1.
    while x + 5 <= w {
        let r0l = vld1q_f32(row0.as_ptr().add(x - 1));
        let r0m = vld1q_f32(row0.as_ptr().add(x));
        let r0r = vld1q_f32(row0.as_ptr().add(x + 1));
        let r1l = vld1q_f32(row1.as_ptr().add(x - 1));
        let r1m = vld1q_f32(row1.as_ptr().add(x));
        let r1r = vld1q_f32(row1.as_ptr().add(x + 1));
        let r2l = vld1q_f32(row2.as_ptr().add(x - 1));
        let r2m = vld1q_f32(row2.as_ptr().add(x));
        let r2r = vld1q_f32(row2.as_ptr().add(x + 1));

        let sum = vaddq_f32(
            vaddq_f32(vaddq_f32(r0l, r0m), vaddq_f32(r0r, r1l)),
            vaddq_f32(vaddq_f32(r1m, r1r), vaddq_f32(r2l, vaddq_f32(r2m, r2r))),
        );
        vst1q_f32(out.as_mut_ptr().add(x), vmulq_f32(sum, inv9));
        x += 4;
    }
    x
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn box_blur_avx_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let inv9 = _mm256_set1_ps(1.0 / 9.0);
    let mut x = 1usize;
    // Each iteration writes out[x..x + 8] and reads columns x - 1..=x + 8,
    // so the widest read (x + 8) must stay within w - 1.
    while x + 9 <= w {
        let r0l = _mm256_loadu_ps(row0.as_ptr().add(x - 1));
        let r0m = _mm256_loadu_ps(row0.as_ptr().add(x));
        let r0r = _mm256_loadu_ps(row0.as_ptr().add(x + 1));
        let r1l = _mm256_loadu_ps(row1.as_ptr().add(x - 1));
        let r1m = _mm256_loadu_ps(row1.as_ptr().add(x));
        let r1r = _mm256_loadu_ps(row1.as_ptr().add(x + 1));
        let r2l = _mm256_loadu_ps(row2.as_ptr().add(x - 1));
        let r2m = _mm256_loadu_ps(row2.as_ptr().add(x));
        let r2r = _mm256_loadu_ps(row2.as_ptr().add(x + 1));

        let sum = _mm256_add_ps(
            _mm256_add_ps(_mm256_add_ps(r0l, r0m), _mm256_add_ps(r0r, r1l)),
            _mm256_add_ps(
                _mm256_add_ps(r1m, r1r),
                _mm256_add_ps(r2l, _mm256_add_ps(r2m, r2r)),
            ),
        );
        _mm256_storeu_ps(out.as_mut_ptr().add(x), _mm256_mul_ps(sum, inv9));
        x += 8;
    }
    x
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn box_blur_sse_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let inv9 = _mm_set1_ps(1.0 / 9.0);
    let mut x = 1usize;
    while x + 5 <= w {
        let r0l = _mm_loadu_ps(row0.as_ptr().add(x - 1));
        let r0m = _mm_loadu_ps(row0.as_ptr().add(x));
        let r0r = _mm_loadu_ps(row0.as_ptr().add(x + 1));
        let r1l = _mm_loadu_ps(row1.as_ptr().add(x - 1));
        let r1m = _mm_loadu_ps(row1.as_ptr().add(x));
        let r1r = _mm_loadu_ps(row1.as_ptr().add(x + 1));
        let r2l = _mm_loadu_ps(row2.as_ptr().add(x - 1));
        let r2m = _mm_loadu_ps(row2.as_ptr().add(x));
        let r2r = _mm_loadu_ps(row2.as_ptr().add(x + 1));

        let sum = _mm_add_ps(
            _mm_add_ps(_mm_add_ps(r0l, r0m), _mm_add_ps(r0r, r1l)),
            _mm_add_ps(_mm_add_ps(r1m, r1r), _mm_add_ps(r2l, _mm_add_ps(r2m, r2r))),
        );
        _mm_storeu_ps(out.as_mut_ptr().add(x), _mm_mul_ps(sum, inv9));
        x += 4;
    }
    x
}

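/// 3×3 box blur over an HWC `f32` tensor. Interior pixels average the full
/// 3×3 neighborhood; border pixels average only the in-bounds samples, so a
/// constant image stays constant.
///
/// A minimal usage sketch (the tensor constructor matches the
/// `Tensor::from_vec(shape, data)` call used throughout this module; the
/// 64×64 single-channel image is a made-up example):
///
/// ```ignore
/// let img = Tensor::from_vec(vec![64, 64, 1], vec![0.5f32; 64 * 64])?;
/// let blurred = box_blur_3x3(&img)?;
/// assert_eq!(blurred.data().len(), 64 * 64);
/// ```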
#[allow(unsafe_code, clippy::uninit_vec)]
pub fn box_blur_3x3(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let row_len = w * channels;
    let total = h * w * channels;
    let mut out = Vec::with_capacity(total);
    // Safety: every element of `out` is written by `compute_row` below
    // before it is read.
    unsafe {
        out.set_len(total);
    }

    let compute_row = |y: usize, row: &mut [f32]| {
        // Fast path: single-channel interior rows go through the SIMD helper.
        if channels == 1 && y > 0 && y < h - 1 && !cfg!(miri) {
            let row0 = &data[(y - 1) * w..y * w];
            let row1 = &data[y * w..(y + 1) * w];
            let row2 = &data[(y + 1) * w..(y + 2) * w];
            let done = box_blur_simd_row_c1(row0, row1, row2, row, w);
            // Scalar tail for the columns the SIMD helper did not cover.
            for x in done..w.saturating_sub(1) {
                if x == 0 {
                    continue;
                }
                let sum = row0[x - 1]
                    + row0[x]
                    + row0[x + 1]
                    + row1[x - 1]
                    + row1[x]
                    + row1[x + 1]
                    + row2[x - 1]
                    + row2[x]
                    + row2[x + 1];
                row[x] = sum / 9.0;
            }
            // Left border (x == 0): average the in-bounds 2-wide window.
            {
                let mut acc = 0.0f32;
                let mut count = 0.0f32;
                for ky in -1isize..=1 {
                    let sy = y as isize + ky;
                    if sy < 0 || sy >= h as isize {
                        continue;
                    }
                    for kx in 0isize..=1 {
                        if (kx as usize) < w {
                            acc += data[(sy as usize) * w + kx as usize];
                            count += 1.0;
                        }
                    }
                }
                row[0] = acc / count;
            }
            // Right border (x == w - 1): same, mirrored.
            if w > 1 {
                let mut acc = 0.0f32;
                let mut count = 0.0f32;
                for ky in -1isize..=1 {
                    let sy = y as isize + ky;
                    if sy < 0 || sy >= h as isize {
                        continue;
                    }
                    for kx in (w as isize - 2)..=(w as isize - 1) {
                        if kx >= 0 {
                            acc += data[(sy as usize) * w + kx as usize];
                            count += 1.0;
                        }
                    }
                }
                row[w - 1] = acc / count;
            }
            return;
        }

        // Generic path: bounds-checked average over the in-bounds samples.
        for x in 0..w {
            for c in 0..channels {
                let mut acc = 0.0f32;
                let mut count = 0.0f32;
                for ky in -1isize..=1 {
                    for kx in -1isize..=1 {
                        let sy = y as isize + ky;
                        let sx = x as isize + kx;
                        if sy < 0 || sx < 0 || sy >= h as isize || sx >= w as isize {
                            continue;
                        }
                        let src = ((sy as usize) * w + sx as usize) * channels + c;
                        acc += data[src];
                        count += 1.0;
                    }
                }
                row[x * channels + c] = acc / count;
            }
        }
    };

    let pixels = h * w;

    #[cfg(target_os = "macos")]
    if pixels > 4096 && !cfg!(miri) {
        let out_ptr = out.as_mut_ptr() as usize;
        use super::u8ops::gcd;
        gcd::parallel_for(h, |y| {
            // Safety: each task gets a disjoint row slice of `out`.
            let row = unsafe {
                std::slice::from_raw_parts_mut((out_ptr as *mut f32).add(y * row_len), row_len)
            };
            compute_row(y, row);
        });
        return Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into);
    }

    if pixels > 4096 {
        out.par_chunks_mut(row_len)
            .enumerate()
            .for_each(|(y, row)| compute_row(y, row));
    } else {
        out.chunks_mut(row_len)
            .enumerate()
            .for_each(|(y, row)| compute_row(y, row));
    }

    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
}

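// Runtime-dispatched horizontal [1, 2, 1] / 4 Gaussian pass for one
// single-channel row. Like the box-blur helper above, it returns the first
// unprocessed x index so the caller can finish the tail in scalar code.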
#[allow(unsafe_code)]
fn gauss_h_simd_row_c1(src: &[f32], out: &mut [f32], w: usize) -> usize {
    if w < 6 {
        return 1;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return unsafe { gauss_h_neon_row_c1(src, out, w) };
        }
    }
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            return unsafe { gauss_h_avx_row_c1(src, out, w) };
        }
    }
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("sse") {
            return unsafe { gauss_h_sse_row_c1(src, out, w) };
        }
    }
    1
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn gauss_h_neon_row_c1(src: &[f32], out: &mut [f32], w: usize) -> usize {
    use std::arch::aarch64::*;
    let two = vdupq_n_f32(2.0);
    let quarter = vdupq_n_f32(0.25);
    let mut x = 1usize;
    while x + 5 <= w {
        let left = vld1q_f32(src.as_ptr().add(x - 1));
        let center = vld1q_f32(src.as_ptr().add(x));
        let right = vld1q_f32(src.as_ptr().add(x + 1));
        let sum = vaddq_f32(vaddq_f32(left, right), vmulq_f32(center, two));
        vst1q_f32(out.as_mut_ptr().add(x), vmulq_f32(sum, quarter));
        x += 4;
    }
    x
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn gauss_h_avx_row_c1(src: &[f32], out: &mut [f32], w: usize) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let two = _mm256_set1_ps(2.0);
    let quarter = _mm256_set1_ps(0.25);
    let mut x = 1usize;
    while x + 9 <= w {
        let left = _mm256_loadu_ps(src.as_ptr().add(x - 1));
        let center = _mm256_loadu_ps(src.as_ptr().add(x));
        let right = _mm256_loadu_ps(src.as_ptr().add(x + 1));
        let sum = _mm256_add_ps(_mm256_add_ps(left, right), _mm256_mul_ps(center, two));
        _mm256_storeu_ps(out.as_mut_ptr().add(x), _mm256_mul_ps(sum, quarter));
        x += 8;
    }
    x
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn gauss_h_sse_row_c1(src: &[f32], out: &mut [f32], w: usize) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let two = _mm_set1_ps(2.0);
    let quarter = _mm_set1_ps(0.25);
    let mut x = 1usize;
    while x + 5 <= w {
        let left = _mm_loadu_ps(src.as_ptr().add(x - 1));
        let center = _mm_loadu_ps(src.as_ptr().add(x));
        let right = _mm_loadu_ps(src.as_ptr().add(x + 1));
        let sum = _mm_add_ps(_mm_add_ps(left, right), _mm_mul_ps(center, two));
        _mm_storeu_ps(out.as_mut_ptr().add(x), _mm_mul_ps(sum, quarter));
        x += 4;
    }
    x
}

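/// 3×3 Gaussian blur, computed as two separable [1, 2, 1] / 4 passes
/// (together equivalent to the [[1, 2, 1], [2, 4, 2], [1, 2, 1]] / 16
/// kernel). Borders use zero padding but keep the / 4 normalization, so edge
/// pixels come out slightly darker; the 5×5 variant below clamps instead.
///
/// A minimal usage sketch (same made-up 64×64 image as above):
///
/// ```ignore
/// let img = Tensor::from_vec(vec![64, 64, 1], vec![0.5f32; 64 * 64])?;
/// let smoothed = gaussian_blur_3x3(&img)?;
/// ```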
#[allow(unsafe_code, clippy::uninit_vec)]
pub fn gaussian_blur_3x3(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let total = h * w * channels;
    let mut tmp = Vec::with_capacity(total);
    // Safety: the horizontal pass below writes every element of `tmp`.
    unsafe {
        tmp.set_len(total);
    }

    // Horizontal [1, 2, 1] / 4 pass into `tmp` (zero padding at the row ends).
    for y in 0..h {
        if channels == 1 && !cfg!(miri) {
            let src = &data[y * w..(y + 1) * w];
            let dst = &mut tmp[y * w..(y + 1) * w];
            // Left border.
            {
                let center = src[0];
                let right = src[1.min(w - 1)];
                dst[0] = (center * 2.0 + right) / 4.0;
            }
            let done = gauss_h_simd_row_c1(src, dst, w);
            // Scalar tail for the columns the SIMD helper did not cover.
            for x in done..w.saturating_sub(1) {
                if x == 0 {
                    continue;
                }
                dst[x] = (src[x - 1] + src[x] * 2.0 + src[x + 1]) * 0.25;
            }
            // Right border.
            if w > 1 {
                dst[w - 1] = (src[w - 2] + src[w - 1] * 2.0) / 4.0;
            }
            continue;
        }

        for c in 0..channels {
            // Left border.
            {
                let center = data[(y * w) * channels + c];
                let right = data[(y * w + 1.min(w - 1)) * channels + c];
                tmp[(y * w) * channels + c] = (center * 2.0 + right) / 4.0;
            }
            for x in 1..w.saturating_sub(1) {
                let base = y * w;
                let left = data[(base + x - 1) * channels + c];
                let center = data[(base + x) * channels + c];
                let right = data[(base + x + 1) * channels + c];
                tmp[(base + x) * channels + c] = (left + center * 2.0 + right) * 0.25;
            }
            // Right border.
            if w > 1 {
                let base = y * w;
                let left = data[(base + w - 2) * channels + c];
                let center = data[(base + w - 1) * channels + c];
                tmp[(base + w - 1) * channels + c] = (left + center * 2.0) / 4.0;
            }
        }
    }

    let mut out = Vec::with_capacity(total);
    // Safety: every element of `out` is written by `compute_row` below.
    unsafe {
        out.set_len(total);
    }
    let row_len = w * channels;

    // Vertical [1, 2, 1] / 4 pass from `tmp` into `out`.
    let compute_row = |y: usize, row: &mut [f32]| {
        if channels == 1 && y > 0 && y < h - 1 && !cfg!(miri) {
            let above = &tmp[(y - 1) * w..y * w];
            let center = &tmp[y * w..(y + 1) * w];
            let below = &tmp[(y + 1) * w..(y + 2) * w];
            let done = gauss_v_simd_row_c1(above, center, below, row, w);
            for x in done..w {
                row[x] = (above[x] + center[x] * 2.0 + below[x]) * 0.25;
            }
            return;
        }

        for x in 0..w {
            for c in 0..channels {
                let val = if y == 0 {
                    let center = tmp[x * channels + c];
                    let below = tmp[(1.min(h - 1) * w + x) * channels + c];
                    (center * 2.0 + below) / 4.0
                } else if y == h - 1 && h > 1 {
                    let above = tmp[((h - 2) * w + x) * channels + c];
                    let center = tmp[((h - 1) * w + x) * channels + c];
                    (above + center * 2.0) / 4.0
                } else {
                    let above = tmp[((y - 1) * w + x) * channels + c];
                    let center = tmp[(y * w + x) * channels + c];
                    let below = tmp[((y + 1) * w + x) * channels + c];
                    (above + center * 2.0 + below) * 0.25
                };
                row[x * channels + c] = val;
            }
        }
    };

    let pixels = h * w;

    #[cfg(target_os = "macos")]
    if pixels > 4096 && !cfg!(miri) {
        let out_ptr = out.as_mut_ptr() as usize;
        use super::u8ops::gcd;
        gcd::parallel_for(h, |y| {
            // Safety: each task gets a disjoint row slice of `out`.
            let row = unsafe {
                std::slice::from_raw_parts_mut((out_ptr as *mut f32).add(y * row_len), row_len)
            };
            compute_row(y, row);
        });
        return Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into);
    }

    if pixels > 4096 {
        out.par_chunks_mut(row_len)
            .enumerate()
            .for_each(|(y, row)| compute_row(y, row));
    } else {
        out.chunks_mut(row_len)
            .enumerate()
            .for_each(|(y, row)| compute_row(y, row));
    }

    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
}

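// Runtime-dispatched vertical [1, 2, 1] / 4 Gaussian pass. Unlike the
// horizontal helpers it needs no left/right margin, so vectorization starts
// at x = 0 and a return of 0 means "nothing vectorized".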
#[allow(unsafe_code)]
fn gauss_v_simd_row_c1(
    above: &[f32],
    center: &[f32],
    below: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    if w < 4 {
        return 0;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return unsafe { gauss_v_neon_c1(above, center, below, out, w) };
        }
    }
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            return unsafe { gauss_v_avx_c1(above, center, below, out, w) };
        }
    }
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("sse") {
            return unsafe { gauss_v_sse_c1(above, center, below, out, w) };
        }
    }
    0
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn gauss_v_neon_c1(
    above: &[f32],
    center: &[f32],
    below: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    use std::arch::aarch64::*;
    let two = vdupq_n_f32(2.0);
    let quarter = vdupq_n_f32(0.25);
    let mut x = 0usize;
    while x + 4 <= w {
        let a = vld1q_f32(above.as_ptr().add(x));
        let c = vld1q_f32(center.as_ptr().add(x));
        let b = vld1q_f32(below.as_ptr().add(x));
        let sum = vaddq_f32(vaddq_f32(a, b), vmulq_f32(c, two));
        vst1q_f32(out.as_mut_ptr().add(x), vmulq_f32(sum, quarter));
        x += 4;
    }
    x
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn gauss_v_avx_c1(
    above: &[f32],
    center: &[f32],
    below: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let two = _mm256_set1_ps(2.0);
    let quarter = _mm256_set1_ps(0.25);
    let mut x = 0usize;
    while x + 8 <= w {
        let a = _mm256_loadu_ps(above.as_ptr().add(x));
        let c = _mm256_loadu_ps(center.as_ptr().add(x));
        let b = _mm256_loadu_ps(below.as_ptr().add(x));
        let sum = _mm256_add_ps(_mm256_add_ps(a, b), _mm256_mul_ps(c, two));
        _mm256_storeu_ps(out.as_mut_ptr().add(x), _mm256_mul_ps(sum, quarter));
        x += 8;
    }
    x
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn gauss_v_sse_c1(
    above: &[f32],
    center: &[f32],
    below: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let two = _mm_set1_ps(2.0);
    let quarter = _mm_set1_ps(0.25);
    let mut x = 0usize;
    while x + 4 <= w {
        let a = _mm_loadu_ps(above.as_ptr().add(x));
        let c = _mm_loadu_ps(center.as_ptr().add(x));
        let b = _mm_loadu_ps(below.as_ptr().add(x));
        let sum = _mm_add_ps(_mm_add_ps(a, b), _mm_mul_ps(c, two));
        _mm_storeu_ps(out.as_mut_ptr().add(x), _mm_mul_ps(sum, quarter));
        x += 4;
    }
    x
}

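/// 5×5 Gaussian blur using the separable binomial kernel
/// [1, 4, 6, 4, 1] / 16 in each direction. Borders are handled by clamping
/// sample coordinates to the image (edge replication), so a constant image
/// stays constant.
///
/// A minimal usage sketch (made-up RGB image, same constructor as above):
///
/// ```ignore
/// let img = Tensor::from_vec(vec![64, 64, 3], vec![0.5f32; 64 * 64 * 3])?;
/// let smoothed = gaussian_blur_5x5(&img)?;
/// ```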
pub fn gaussian_blur_5x5(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let k: [f32; 5] = [1.0 / 16.0, 4.0 / 16.0, 6.0 / 16.0, 4.0 / 16.0, 1.0 / 16.0];

    // Horizontal pass with coordinates clamped to the row (edge replication).
    let mut tmp = vec![0.0f32; h * w * channels];
    for y in 0..h {
        for x in 0..w {
            for c in 0..channels {
                let base = y * w;
                let mut acc = 0.0f32;
                for i in 0..5 {
                    let sx = (x as isize + i as isize - 2).clamp(0, w as isize - 1) as usize;
                    acc += data[(base + sx) * channels + c] * k[i];
                }
                tmp[(base + x) * channels + c] = acc;
            }
        }
    }

    // Vertical pass, parallelized per row for larger images.
    let mut out = vec![0.0f32; h * w * channels];
    let row_len = w * channels;

    let compute_row = |y: usize, row: &mut [f32]| {
        for x in 0..w {
            for c in 0..channels {
                let mut acc = 0.0f32;
                for i in 0..5 {
                    let sy = (y as isize + i as isize - 2).clamp(0, h as isize - 1) as usize;
                    acc += tmp[(sy * w + x) * channels + c] * k[i];
                }
                row[x * channels + c] = acc;
            }
        }
    };

    if h * w > 4096 {
        out.par_chunks_mut(row_len)
            .enumerate()
            .for_each(|(y, row)| compute_row(y, row));
    } else {
        out.chunks_mut(row_len)
            .enumerate()
            .for_each(|(y, row)| compute_row(y, row));
    }

    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
}

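/// Applies an arbitrary 3×3 kernel as a correlation (the kernel is not
/// flipped). Interior rows use unrolled indexing, in parallel for images
/// above 4096 pixels; border pixels fall back to a bounds-checked loop with
/// zero padding.
///
/// A sketch with a common sharpening kernel (the kernel values here are an
/// illustration, not something defined elsewhere in this module):
///
/// ```ignore
/// let sharpen = [[0.0, -1.0, 0.0], [-1.0, 5.0, -1.0], [0.0, -1.0, 0.0]];
/// let sharpened = apply_kernel_3x3(&img, &sharpen)?;
/// ```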
pub(crate) fn apply_kernel_3x3(
    input: &Tensor,
    kernel: &[[f32; 3]; 3],
) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let mut out = vec![0.0f32; h * w * channels];

    let row_len = w * channels;
    let interior_h = h.saturating_sub(2);

    let compute_interior_row = |y: usize, row: &mut [f32]| {
        for x in 1..w.saturating_sub(1) {
            for c in 0..channels {
                let mut acc = 0.0f32;
                let r0 = ((y - 1) * w + x - 1) * channels + c;
                let r1 = (y * w + x - 1) * channels + c;
                let r2 = ((y + 1) * w + x - 1) * channels + c;
                acc += data[r0] * kernel[0][0];
                acc += data[r0 + channels] * kernel[0][1];
                acc += data[r0 + 2 * channels] * kernel[0][2];
                acc += data[r1] * kernel[1][0];
                acc += data[r1 + channels] * kernel[1][1];
                acc += data[r1 + 2 * channels] * kernel[1][2];
                acc += data[r2] * kernel[2][0];
                acc += data[r2 + channels] * kernel[2][1];
                acc += data[r2 + 2 * channels] * kernel[2][2];
                row[x * channels + c] = acc;
            }
        }
    };

    if interior_h > 0 {
        let interior_out = &mut out[row_len..row_len + interior_h * row_len];
        if h * w > 4096 {
            interior_out
                .par_chunks_mut(row_len)
                .enumerate()
                .for_each(|(i, row)| compute_interior_row(i + 1, row));
        } else {
            interior_out
                .chunks_mut(row_len)
                .enumerate()
                .for_each(|(i, row)| compute_interior_row(i + 1, row));
        }
    }

    // Border pixels: bounds-checked correlation with zero padding.
    let border_pixels = border_coords_3x3(h, w);
    for (y, x) in border_pixels {
        for c in 0..channels {
            let mut acc = 0.0f32;
            for ky in -1isize..=1 {
                for kx in -1isize..=1 {
                    let sy = y as isize + ky;
                    let sx = x as isize + kx;
                    if sy < 0 || sx < 0 || sy >= h as isize || sx >= w as isize {
                        continue;
                    }
                    let src = ((sy as usize) * w + sx as usize) * channels + c;
                    acc += data[src] * kernel[(ky + 1) as usize][(kx + 1) as usize];
                }
            }
            out[(y * w + x) * channels + c] = acc;
        }
    }

    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
}

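/// Coordinates of the 1-pixel border of an `h`×`w` image: the full top and
/// bottom rows, then the first and last columns of the remaining rows. Each
/// border pixel appears exactly once, even for 1-pixel-wide or 1-pixel-tall
/// images.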
pub(crate) fn border_coords_3x3(h: usize, w: usize) -> Vec<(usize, usize)> {
    let mut coords = Vec::with_capacity(2 * w + 2 * h);
    for x in 0..w {
        coords.push((0, x));
    }
    if h > 1 {
        for x in 0..w {
            coords.push((h - 1, x));
        }
    }
    for y in 1..h.saturating_sub(1) {
        coords.push((y, 0));
        if w > 1 {
            coords.push((y, w - 1));
        }
    }
    coords
}

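/// 3×3 Laplacian (the 4-connected second-derivative kernel
/// [[0, 1, 0], [1, -4, 1], [0, 1, 0]]), implemented via `apply_kernel_3x3`,
/// so borders use the same zero-padding convention.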
pub fn laplacian_3x3(input: &Tensor) -> Result<Tensor, ImgProcError> {
    const KERNEL: [[f32; 3]; 3] = [[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]];
    apply_kernel_3x3(input, &KERNEL)
}

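// Compare–exchange primitive for the median-of-9 sorting network below:
// after the call, v[a] <= v[b].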
#[inline(always)]
fn cswap(v: &mut [f32; 9], a: usize, b: usize) {
    if v[a] > v[b] {
        v.swap(a, b);
    }
}

#[inline(always)]
fn median9(v: &mut [f32; 9]) -> f32 {
    // Sort each row of the 3×3 window.
    cswap(v, 0, 1);
    cswap(v, 3, 4);
    cswap(v, 6, 7);
    cswap(v, 1, 2);
    cswap(v, 4, 5);
    cswap(v, 7, 8);
    cswap(v, 0, 1);
    cswap(v, 3, 4);
    cswap(v, 6, 7);
    // Sort each column.
    cswap(v, 0, 3);
    cswap(v, 3, 6);
    cswap(v, 0, 3);
    cswap(v, 1, 4);
    cswap(v, 4, 7);
    cswap(v, 1, 4);
    cswap(v, 2, 5);
    cswap(v, 5, 8);
    cswap(v, 2, 5);
    // Merge: after these exchanges v[4] holds the median.
    cswap(v, 1, 3);
    cswap(v, 5, 7);
    cswap(v, 2, 6);
    cswap(v, 4, 6);
    cswap(v, 2, 4);
    cswap(v, 2, 3);
    cswap(v, 5, 6);
    v[4]
}

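/// 3×3 median blur. Interior pixels go through the branch-light sorting
/// network in `median9`; border pixels collect their in-bounds neighbors and
/// sort them directly.
///
/// A minimal usage sketch (made-up image, same constructor as above):
///
/// ```ignore
/// let denoised = median_blur_3x3(&img)?;
/// ```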
pub fn median_blur_3x3(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let mut out = vec![0.0f32; h * w * channels];
    let mut neighborhood = [0.0f32; 9];

    // Interior pixels: full 3×3 window through the sorting network.
    for y in 1..h.saturating_sub(1) {
        for x in 1..w.saturating_sub(1) {
            for c in 0..channels {
                let r0 = ((y - 1) * w + x - 1) * channels + c;
                let r1 = (y * w + x - 1) * channels + c;
                let r2 = ((y + 1) * w + x - 1) * channels + c;
                neighborhood[0] = data[r0];
                neighborhood[1] = data[r0 + channels];
                neighborhood[2] = data[r0 + 2 * channels];
                neighborhood[3] = data[r1];
                neighborhood[4] = data[r1 + channels];
                neighborhood[5] = data[r1 + 2 * channels];
                neighborhood[6] = data[r2];
                neighborhood[7] = data[r2 + channels];
                neighborhood[8] = data[r2 + 2 * channels];
                out[(y * w + x) * channels + c] = median9(&mut neighborhood);
            }
        }
    }

    // Border pixels: collect the in-bounds neighbors and sort them directly.
    let border = border_coords_3x3(h, w);
    for (y, x) in border {
        for c in 0..channels {
            let mut count = 0usize;
            for ky in -1isize..=1 {
                for kx in -1isize..=1 {
                    let sy = y as isize + ky;
                    let sx = x as isize + kx;
                    if sy < 0 || sx < 0 || sy >= h as isize || sx >= w as isize {
                        continue;
                    }
                    let src = ((sy as usize) * w + sx as usize) * channels + c;
                    neighborhood[count] = data[src];
                    count += 1;
                }
            }
            let slice = &mut neighborhood[..count];
            slice.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            out[(y * w + x) * channels + c] = slice[count / 2];
        }
    }

    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
}

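/// Median filter with an arbitrary odd `kernel_size`, single-channel input
/// only. Samples are clamped to the image borders (edge replication); an
/// even or zero `kernel_size` is rejected with `InvalidBlockSize`.
///
/// A minimal usage sketch (made-up grayscale image, as above):
///
/// ```ignore
/// let denoised = median_filter(&gray, 5)?;
/// assert!(median_filter(&gray, 4).is_err()); // even sizes are rejected
/// ```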
pub fn median_filter(input: &Tensor, kernel_size: usize) -> Result<Tensor, ImgProcError> {
    if kernel_size == 0 || kernel_size.is_multiple_of(2) {
        return Err(ImgProcError::InvalidBlockSize {
            block_size: kernel_size,
        });
    }
    let (h, w, c) = hwc_shape(input)?;
    if c != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: c,
        });
    }
    let data = input.data();
    let radius = (kernel_size / 2) as isize;
    let mut out = vec![0.0f32; h * w];
    let mut neighborhood = vec![0.0f32; kernel_size * kernel_size];

    for y in 0..h {
        for x in 0..w {
            let mut count = 0usize;
            // Sample coordinates are clamped to the image (edge replication).
            for ky in -radius..=radius {
                for kx in -radius..=radius {
                    let sy = (y as isize + ky).clamp(0, h as isize - 1) as usize;
                    let sx = (x as isize + kx).clamp(0, w as isize - 1) as usize;
                    neighborhood[count] = data[sy * w + sx];
                    count += 1;
                }
            }
            let slice = &mut neighborhood[..count];
            slice.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            out[y * w + x] = slice[count / 2];
        }
    }

    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
}

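// SIMD rows for the bilateral filter. Both variants assume the full
// (2·radius + 1)² window is in bounds for every x in [x_start, x_end), and
// read the spatial and color Gaussian weights from precomputed lookup
// tables; the color LUT is indexed by |Δ| · 255 clamped to 255, i.e. it
// assumes intensities are roughly in [0, 1].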
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn bilateral_neon_row(
    data: &[f32],
    w: usize,
    y: usize,
    x_start: usize,
    x_end: usize,
    radius: i32,
    diameter: usize,
    spatial_lut: &[f32],
    color_lut: &[f32; 256],
    row_out: &mut [f32],
) {
    use std::arch::aarch64::*;

    let scale_255 = vdupq_n_f32(255.0);
    let max255 = vdupq_n_u32(255);
    let clut = color_lut.as_ptr();

    for x in x_start..x_end {
        let center = *data.get_unchecked(y * w + x);
        let center_v = vdupq_n_f32(center);
        let mut sum_v = vdupq_n_f32(0.0);
        let mut wsum_v = vdupq_n_f32(0.0);
        let mut sum_s = 0.0f32;
        let mut wsum_s = 0.0f32;

        for dy in -radius..=radius {
            let ny = (y as i32 + dy) as usize;
            let row_ptr = data.as_ptr().add(ny * w + x - (radius as usize));
            let sp_ptr = spatial_lut
                .as_ptr()
                .add(((dy + radius) as usize) * diameter);

            let mut dx = 0usize;

            // Main loop: two 4-lane groups (8 taps) per iteration.
            while dx + 8 <= diameter {
                let n1 = vld1q_f32(row_ptr.add(dx));
                let sp1 = vld1q_f32(sp_ptr.add(dx));
                let diff1 = vabsq_f32(vsubq_f32(n1, center_v));
                let idx1 = vminq_u32(vcvtq_u32_f32(vmulq_f32(diff1, scale_255)), max255);
                let mut ia1 = [0u32; 4];
                vst1q_u32(ia1.as_mut_ptr(), idx1);

                let n2 = vld1q_f32(row_ptr.add(dx + 4));
                let sp2 = vld1q_f32(sp_ptr.add(dx + 4));
                let diff2 = vabsq_f32(vsubq_f32(n2, center_v));
                let idx2 = vminq_u32(vcvtq_u32_f32(vmulq_f32(diff2, scale_255)), max255);
                let mut ia2 = [0u32; 4];
                vst1q_u32(ia2.as_mut_ptr(), idx2);

                // Gather the color weights through scalar LUT lookups.
                let cw1_arr = [
                    *clut.add(ia1[0] as usize),
                    *clut.add(ia1[1] as usize),
                    *clut.add(ia1[2] as usize),
                    *clut.add(ia1[3] as usize),
                ];
                let cw2_arr = [
                    *clut.add(ia2[0] as usize),
                    *clut.add(ia2[1] as usize),
                    *clut.add(ia2[2] as usize),
                    *clut.add(ia2[3] as usize),
                ];

                let wt1 = vmulq_f32(sp1, vld1q_f32(cw1_arr.as_ptr()));
                let wt2 = vmulq_f32(sp2, vld1q_f32(cw2_arr.as_ptr()));
                sum_v = vfmaq_f32(sum_v, n1, wt1);
                sum_v = vfmaq_f32(sum_v, n2, wt2);
                wsum_v = vaddq_f32(wsum_v, vaddq_f32(wt1, wt2));

                dx += 8;
            }

            // Single 4-lane group.
            while dx + 4 <= diameter {
                let neighbors = vld1q_f32(row_ptr.add(dx));
                let spatial_w = vld1q_f32(sp_ptr.add(dx));
                let diff = vabsq_f32(vsubq_f32(neighbors, center_v));
                let idx_u32 = vminq_u32(vcvtq_u32_f32(vmulq_f32(diff, scale_255)), max255);
                let mut idx_arr = [0u32; 4];
                vst1q_u32(idx_arr.as_mut_ptr(), idx_u32);

                let cw_arr = [
                    *clut.add(idx_arr[0] as usize),
                    *clut.add(idx_arr[1] as usize),
                    *clut.add(idx_arr[2] as usize),
                    *clut.add(idx_arr[3] as usize),
                ];
                let wt = vmulq_f32(spatial_w, vld1q_f32(cw_arr.as_ptr()));
                sum_v = vfmaq_f32(sum_v, neighbors, wt);
                wsum_v = vaddq_f32(wsum_v, wt);
                dx += 4;
            }

            // Scalar tail.
            while dx < diameter {
                let neighbor = *row_ptr.add(dx);
                let color_diff = (neighbor - center).abs();
                let color_idx = ((color_diff * 255.0) as usize).min(255);
                let wt = *sp_ptr.add(dx) * *clut.add(color_idx);
                sum_s += neighbor * wt;
                wsum_s += wt;
                dx += 1;
            }
        }

        let total_sum = vaddvq_f32(sum_v) + sum_s;
        let total_wsum = vaddvq_f32(wsum_v) + wsum_s;
        *row_out.get_unchecked_mut(x) = if total_wsum > 0.0 {
            total_sum / total_wsum
        } else {
            center
        };
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse2")]
unsafe fn bilateral_sse_row(
    data: &[f32],
    w: usize,
    y: usize,
    x_start: usize,
    x_end: usize,
    radius: i32,
    diameter: usize,
    spatial_lut: &[f32],
    color_lut: &[f32; 256],
    row_out: &mut [f32],
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let scale_255 = _mm_set1_ps(255.0);
    let max_idx = 255i32;
    let clut = color_lut.as_ptr();

    for x in x_start..x_end {
        let center = *data.get_unchecked(y * w + x);
        let center_v = _mm_set1_ps(center);
        let mut sum_v = _mm_setzero_ps();
        let mut wsum_v = _mm_setzero_ps();
        let mut sum_s = 0.0f32;
        let mut wsum_s = 0.0f32;

        for dy in -radius..=radius {
            let ny = (y as i32 + dy) as usize;
            let row_ptr = data.as_ptr().add(ny * w + x - (radius as usize));
            let sp_ptr = spatial_lut
                .as_ptr()
                .add(((dy + radius) as usize) * diameter);

            let mut dx = 0usize;

            while dx + 4 <= diameter {
                let neighbors = _mm_loadu_ps(row_ptr.add(dx));
                let spatial_w = _mm_loadu_ps(sp_ptr.add(dx));

                // |neighbor - center| without an abs intrinsic: max(d, -d).
                let diff = _mm_sub_ps(neighbors, center_v);
                let neg_diff = _mm_sub_ps(_mm_setzero_ps(), diff);
                let abs_diff = _mm_max_ps(diff, neg_diff);

                let scaled = _mm_mul_ps(abs_diff, scale_255);
                let idx_i32 = _mm_cvttps_epi32(scaled);

                // Clamp the indices on the scalar side before the LUT gather.
                let mut idx_arr = [0i32; 4];
                _mm_storeu_si128(idx_arr.as_mut_ptr() as *mut __m128i, idx_i32);
                idx_arr[0] = idx_arr[0].min(max_idx).max(0);
                idx_arr[1] = idx_arr[1].min(max_idx).max(0);
                idx_arr[2] = idx_arr[2].min(max_idx).max(0);
                idx_arr[3] = idx_arr[3].min(max_idx).max(0);

                let cw_arr = [
                    *clut.add(idx_arr[0] as usize),
                    *clut.add(idx_arr[1] as usize),
                    *clut.add(idx_arr[2] as usize),
                    *clut.add(idx_arr[3] as usize),
                ];

                let color_w = _mm_loadu_ps(cw_arr.as_ptr());
                let wt = _mm_mul_ps(spatial_w, color_w);

                sum_v = _mm_add_ps(sum_v, _mm_mul_ps(neighbors, wt));
                wsum_v = _mm_add_ps(wsum_v, wt);

                dx += 4;
            }

            // Scalar tail.
            while dx < diameter {
                let neighbor = *row_ptr.add(dx);
                let color_diff = (neighbor - center).abs();
                let color_idx = ((color_diff * 255.0) as usize).min(255);
                let wt = *sp_ptr.add(dx) * *clut.add(color_idx);
                sum_s += neighbor * wt;
                wsum_s += wt;
                dx += 1;
            }
        }

        // Horizontal reduction of the 4-lane accumulators.
        let hi = _mm_movehl_ps(sum_v, sum_v);
        let sum_lo = _mm_add_ps(sum_v, hi);
        let sum_shuf = _mm_shuffle_ps(sum_lo, sum_lo, 1);
        let total_sum_v = _mm_add_ss(sum_lo, sum_shuf);

        let hi_w = _mm_movehl_ps(wsum_v, wsum_v);
        let wsum_lo = _mm_add_ps(wsum_v, hi_w);
        let wsum_shuf = _mm_shuffle_ps(wsum_lo, wsum_lo, 1);
        let total_wsum_v = _mm_add_ss(wsum_lo, wsum_shuf);

        let total_sum = _mm_cvtss_f32(total_sum_v) + sum_s;
        let total_wsum = _mm_cvtss_f32(total_wsum_v) + wsum_s;

        *row_out.get_unchecked_mut(x) = if total_wsum > 0.0 {
            total_sum / total_wsum
        } else {
            center
        };
    }
}

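/// Bilateral filter for single-channel `f32` images with intensities in
/// roughly [0, 1]. Note that `d` is used as the window *radius* here (the
/// window is (2·d + 1)² pixels), and both Gaussian terms are read from
/// lookup tables: a (2d + 1)² spatial table and a 256-entry color table.
///
/// A minimal usage sketch (made-up parameters):
///
/// ```ignore
/// let smoothed = bilateral_filter(&gray, 4, 0.1, 2.0)?;
/// ```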
#[allow(unsafe_code)]
pub fn bilateral_filter(
    input: &Tensor,
    d: usize,
    sigma_color: f32,
    sigma_space: f32,
) -> Result<Tensor, ImgProcError> {
    let (h, w, c) = hwc_shape(input)?;
    if c != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: c,
        });
    }
    let data = input.data();
    let mut out = vec![0.0f32; h * w];
    let radius = d as i32;
    let color_coeff = -0.5 / (sigma_color * sigma_color);
    let space_coeff = -0.5 / (sigma_space * sigma_space);

    // Precompute the spatial Gaussian over the (2r + 1)² window offsets.
    let diameter = (2 * radius + 1) as usize;
    let mut spatial_lut = vec![0.0f32; diameter * diameter];
    for dy in -radius..=radius {
        for dx in -radius..=radius {
            let spatial_dist_sq = (dy * dy + dx * dx) as f32;
            let idx = ((dy + radius) as usize) * diameter + (dx + radius) as usize;
            spatial_lut[idx] = (space_coeff * spatial_dist_sq).exp();
        }
    }

    // Precompute the color Gaussian for 256 quantized intensity differences.
    let mut color_lut = [0.0f32; 256];
    for i in 0..256 {
        let diff = i as f32 / 255.0;
        color_lut[i] = (color_coeff * diff * diff).exp();
    }

    let radius_u = d;

    // Bounds-checked fallback used for border pixels and non-SIMD builds.
    let process_pixel_scalar = |y: usize, x: usize| -> f32 {
        let center = data[y * w + x];
        let mut sum = 0.0f32;
        let mut weight_sum = 0.0f32;
        for dy in -radius..=radius {
            let ny = y as i32 + dy;
            if ny < 0 || ny >= h as i32 {
                continue;
            }
            let ny = ny as usize;
            let spatial_row_off = ((dy + radius) as usize) * diameter;
            for dx in -radius..=radius {
                let nx = x as i32 + dx;
                if nx < 0 || nx >= w as i32 {
                    continue;
                }
                let neighbor = data[ny * w + nx as usize];
                let color_diff = (neighbor - center).abs();
                let color_idx = ((color_diff * 255.0) as usize).min(255);
                let spatial_idx = spatial_row_off + (dx + radius) as usize;
                let wt = spatial_lut[spatial_idx] * color_lut[color_idx];
                sum += neighbor * wt;
                weight_sum += wt;
            }
        }
        if weight_sum > 0.0 {
            sum / weight_sum
        } else {
            center
        }
    };

    #[cfg(target_arch = "aarch64")]
    let use_neon = !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon");
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    let use_sse2 = !cfg!(miri) && std::is_x86_feature_detected!("sse2");

    // Columns whose full window is horizontally in bounds.
    let x_start = radius_u;
    let x_end = w.saturating_sub(radius_u);

    let compute_row = |y: usize, row_out: &mut [f32]| {
        let is_interior_y = y >= radius_u && y + radius_u < h;

        if is_interior_y {
            // Left margin.
            for x in 0..x_start {
                row_out[x] = process_pixel_scalar(y, x);
            }
            // Interior columns: SIMD when available, scalar otherwise.
            #[cfg(target_arch = "aarch64")]
            if use_neon {
                unsafe {
                    bilateral_neon_row(
                        data,
                        w,
                        y,
                        x_start,
                        x_end,
                        radius,
                        diameter,
                        &spatial_lut,
                        &color_lut,
                        row_out,
                    );
                }
            } else {
                for x in x_start..x_end {
                    row_out[x] = process_pixel_scalar(y, x);
                }
            }
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            if use_sse2 {
                unsafe {
                    bilateral_sse_row(
                        data,
                        w,
                        y,
                        x_start,
                        x_end,
                        radius,
                        diameter,
                        &spatial_lut,
                        &color_lut,
                        row_out,
                    );
                }
            } else {
                for x in x_start..x_end {
                    row_out[x] = process_pixel_scalar(y, x);
                }
            }
            #[cfg(not(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64")))]
            {
                for x in x_start..x_end {
                    row_out[x] = process_pixel_scalar(y, x);
                }
            }
            // Right margin.
            for x in x_end..w {
                row_out[x] = process_pixel_scalar(y, x);
            }
        } else {
            // Rows near the top/bottom edge: scalar everywhere.
            for x in 0..w {
                row_out[x] = process_pixel_scalar(y, x);
            }
        }
    };

    let pixels = h * w;

    #[cfg(target_os = "macos")]
    if pixels > 4096 && !cfg!(miri) {
        let out_ptr = out.as_mut_ptr() as usize;
        use super::u8ops::gcd;
        gcd::parallel_for(h, |y| {
            // Safety: each task gets a disjoint row slice of `out`.
            let row =
                unsafe { std::slice::from_raw_parts_mut((out_ptr as *mut f32).add(y * w), w) };
            compute_row(y, row);
        });
        return Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into);
    }

    if pixels > 4096 {
        out.par_chunks_mut(w)
            .enumerate()
            .for_each(|(y, row)| compute_row(y, row));
    } else {
        out.chunks_mut(w)
            .enumerate()
            .for_each(|(y, row)| compute_row(y, row));
    }

    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
}

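/// Correlates a single-channel image with an arbitrary single-channel kernel
/// (no kernel flip, zero padding at the borders). The anchor is the kernel
/// center (`kh / 2`, `kw / 2`).
///
/// A minimal usage sketch (a hypothetical 3×3 averaging kernel built with
/// the same `Tensor::from_vec` constructor used above):
///
/// ```ignore
/// let kernel = Tensor::from_vec(vec![3, 3, 1], vec![1.0 / 9.0; 9])?;
/// let filtered = filter2d(&img, &kernel)?;
/// ```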
pub fn filter2d(input: &Tensor, kernel: &Tensor) -> Result<Tensor, ImgProcError> {
    let (h, w, c) = hwc_shape(input)?;
    if c != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: c,
        });
    }
    let (kh, kw, kc) = hwc_shape(kernel)?;
    if kc != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: kc,
        });
    }
    let data = input.data();
    let kern = kernel.data();
    let rh = kh / 2;
    let rw = kw / 2;
    let mut out = vec![0.0f32; h * w];

    for y in 0..h {
        for x in 0..w {
            let mut sum = 0.0f32;
            for ky in 0..kh {
                for kx in 0..kw {
                    let ny = y as i32 + ky as i32 - rh as i32;
                    let nx = x as i32 + kx as i32 - rw as i32;
                    if ny >= 0 && ny < h as i32 && nx >= 0 && nx < w as i32 {
                        sum += data[ny as usize * w + nx as usize] * kern[ky * kw + kx];
                    }
                }
            }
            out[y * w + x] = sum;
        }
    }

    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
}
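
// A minimal sanity-check sketch, not part of the original module: it assumes
// only the `Tensor::from_vec(shape, data)` constructor and `data()` accessor
// already used above. A constant image must be a fixed point of the
// count-normalized box blur and of the median blur.
#[cfg(test)]
mod blur_invariant_tests {
    use super::*;

    // Avoids requiring `Debug` on the (unknown) error types.
    fn unwrap_or_panic<T, E>(r: Result<T, E>) -> T {
        match r {
            Ok(v) => v,
            Err(_) => panic!("operation failed"),
        }
    }

    fn constant_image(h: usize, w: usize, v: f32) -> Tensor {
        unwrap_or_panic(Tensor::from_vec(vec![h, w, 1], vec![v; h * w]))
    }

    #[test]
    fn box_blur_preserves_constant_image() {
        let img = constant_image(8, 8, 0.5);
        let out = unwrap_or_panic(box_blur_3x3(&img));
        assert!(out.data().iter().all(|&v| (v - 0.5).abs() < 1e-6));
    }

    #[test]
    fn median_blur_preserves_constant_image() {
        let img = constant_image(8, 8, 0.5);
        let out = unwrap_or_panic(median_blur_3x3(&img));
        assert!(out.data().iter().all(|&v| (v - 0.5).abs() < 1e-6));
    }

    #[test]
    fn median_filter_rejects_even_kernel() {
        let img = constant_image(4, 4, 0.0);
        assert!(median_filter(&img, 2).is_err());
        assert!(median_filter(&img, 0).is_err());
    }
}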