Skip to main content

yscv_imgproc/ops/
filter.rs

1use rayon::prelude::*;
2use yscv_tensor::Tensor;
3
4use super::super::ImgProcError;
5use super::super::shape::hwc_shape;
6
7/// SIMD 3x3 box blur for single-channel interior row. Returns first x NOT processed.
8#[allow(unsafe_code)]
9fn box_blur_simd_row_c1(
10    row0: &[f32],
11    row1: &[f32],
12    row2: &[f32],
13    out: &mut [f32],
14    w: usize,
15) -> usize {
16    if w < 6 {
17        return 1;
18    }
19
20    #[cfg(target_arch = "aarch64")]
21    {
22        if std::arch::is_aarch64_feature_detected!("neon") {
23            return unsafe { box_blur_neon_row_c1(row0, row1, row2, out, w) };
24        }
25    }
26    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
27    if std::is_x86_feature_detected!("avx") {
28        return unsafe { box_blur_avx_row_c1(row0, row1, row2, out, w) };
29    }
30    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
31    {
32        if std::is_x86_feature_detected!("sse") {
33            return unsafe { box_blur_sse_row_c1(row0, row1, row2, out, w) };
34        }
35    }
36    1
37}
38
/// NEON kernel for the single-channel 3x3 box blur on one interior row.
///
/// Starts at x = 1 and processes 4 pixels per iteration, summing the nine
/// taps from `row0`/`row1`/`row2` and scaling by 1/9. Returns the first x
/// NOT processed; the caller handles the tail and the row borders.
///
/// # Safety
/// NEON must be available (enforced by `target_feature` plus the caller's
/// runtime check). All slices are assumed to be at least `w` elements long
/// — TODO confirm against callers — so the loads at offset x + 1 stay in
/// bounds under the loop condition below.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn box_blur_neon_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    use std::arch::aarch64::*;
    let inv9 = vdupq_n_f32(1.0 / 9.0);
    let mut x = 1usize;
    // x + 5 <= w keeps the furthest load (offset x + 1, 4 lanes) in bounds.
    while x + 5 <= w {
        let r0l = vld1q_f32(row0.as_ptr().add(x - 1));
        let r0m = vld1q_f32(row0.as_ptr().add(x));
        let r0r = vld1q_f32(row0.as_ptr().add(x + 1));
        let r1l = vld1q_f32(row1.as_ptr().add(x - 1));
        let r1m = vld1q_f32(row1.as_ptr().add(x));
        let r1r = vld1q_f32(row1.as_ptr().add(x + 1));
        let r2l = vld1q_f32(row2.as_ptr().add(x - 1));
        let r2m = vld1q_f32(row2.as_ptr().add(x));
        let r2r = vld1q_f32(row2.as_ptr().add(x + 1));

        // Nine-tap sum, then multiply by the precomputed reciprocal.
        let sum = vaddq_f32(
            vaddq_f32(vaddq_f32(r0l, r0m), vaddq_f32(r0r, r1l)),
            vaddq_f32(vaddq_f32(r1m, r1r), vaddq_f32(r2l, vaddq_f32(r2m, r2r))),
        );
        vst1q_f32(out.as_mut_ptr().add(x), vmulq_f32(sum, inv9));
        x += 4;
    }
    x
}
72
/// AVX kernel for the single-channel 3x3 box blur on one interior row.
///
/// Same contract as the NEON variant but 8 pixels per iteration. Returns
/// the first x NOT processed.
///
/// # Safety
/// AVX must be available (enforced by `target_feature` plus the caller's
/// runtime check). Slices are assumed at least `w` long — TODO confirm —
/// so the unaligned loads at offset x + 1 stay in bounds.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn box_blur_avx_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let inv9 = _mm256_set1_ps(1.0 / 9.0);
    let mut x = 1usize;
    // x + 9 <= w keeps the furthest load (offset x + 1, 8 lanes) in bounds.
    while x + 9 <= w {
        let r0l = _mm256_loadu_ps(row0.as_ptr().add(x - 1));
        let r0m = _mm256_loadu_ps(row0.as_ptr().add(x));
        let r0r = _mm256_loadu_ps(row0.as_ptr().add(x + 1));
        let r1l = _mm256_loadu_ps(row1.as_ptr().add(x - 1));
        let r1m = _mm256_loadu_ps(row1.as_ptr().add(x));
        let r1r = _mm256_loadu_ps(row1.as_ptr().add(x + 1));
        let r2l = _mm256_loadu_ps(row2.as_ptr().add(x - 1));
        let r2m = _mm256_loadu_ps(row2.as_ptr().add(x));
        let r2r = _mm256_loadu_ps(row2.as_ptr().add(x + 1));

        // Nine-tap sum, then multiply by the precomputed reciprocal.
        let sum = _mm256_add_ps(
            _mm256_add_ps(_mm256_add_ps(r0l, r0m), _mm256_add_ps(r0r, r1l)),
            _mm256_add_ps(
                _mm256_add_ps(r1m, r1r),
                _mm256_add_ps(r2l, _mm256_add_ps(r2m, r2r)),
            ),
        );
        _mm256_storeu_ps(out.as_mut_ptr().add(x), _mm256_mul_ps(sum, inv9));
        x += 8;
    }
    x
}
113
/// SSE kernel for the single-channel 3x3 box blur on one interior row.
///
/// Same contract as the AVX variant but 4 pixels per iteration. Returns
/// the first x NOT processed.
///
/// # Safety
/// SSE must be available (enforced by `target_feature` plus the caller's
/// runtime check). Slices are assumed at least `w` long — TODO confirm —
/// so the unaligned loads at offset x + 1 stay in bounds.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn box_blur_sse_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let inv9 = _mm_set1_ps(1.0 / 9.0);
    let mut x = 1usize;
    // x + 5 <= w keeps the furthest load (offset x + 1, 4 lanes) in bounds.
    while x + 5 <= w {
        let r0l = _mm_loadu_ps(row0.as_ptr().add(x - 1));
        let r0m = _mm_loadu_ps(row0.as_ptr().add(x));
        let r0r = _mm_loadu_ps(row0.as_ptr().add(x + 1));
        let r1l = _mm_loadu_ps(row1.as_ptr().add(x - 1));
        let r1m = _mm_loadu_ps(row1.as_ptr().add(x));
        let r1r = _mm_loadu_ps(row1.as_ptr().add(x + 1));
        let r2l = _mm_loadu_ps(row2.as_ptr().add(x - 1));
        let r2m = _mm_loadu_ps(row2.as_ptr().add(x));
        let r2r = _mm_loadu_ps(row2.as_ptr().add(x + 1));

        // Nine-tap sum, then multiply by the precomputed reciprocal.
        let sum = _mm_add_ps(
            _mm_add_ps(_mm_add_ps(r0l, r0m), _mm_add_ps(r0r, r1l)),
            _mm_add_ps(_mm_add_ps(r1m, r1r), _mm_add_ps(r2l, _mm_add_ps(r2m, r2r))),
        );
        _mm_storeu_ps(out.as_mut_ptr().add(x), _mm_mul_ps(sum, inv9));
        x += 4;
    }
    x
}
151
152/// Applies zero-padded 3x3 box blur over each channel.
153#[allow(unsafe_code, clippy::uninit_vec)]
154pub fn box_blur_3x3(input: &Tensor) -> Result<Tensor, ImgProcError> {
155    let (h, w, channels) = hwc_shape(input)?;
156    let data = input.data();
157    let row_len = w * channels;
158    let total = h * w * channels;
159    // SAFETY: every element is written by compute_row (SIMD + scalar) below.
160    let mut out = Vec::with_capacity(total);
161    unsafe {
162        out.set_len(total);
163    }
164
165    let compute_row = |y: usize, row: &mut [f32]| {
166        // SIMD fast path for single-channel interior rows
167        if channels == 1 && y > 0 && y < h - 1 && !cfg!(miri) {
168            let row0 = &data[(y - 1) * w..y * w];
169            let row1 = &data[y * w..(y + 1) * w];
170            let row2 = &data[(y + 1) * w..(y + 2) * w];
171            let done = box_blur_simd_row_c1(row0, row1, row2, row, w);
172            // Scalar tail for interior
173            for x in done..w.saturating_sub(1) {
174                if x == 0 {
175                    continue;
176                }
177                let sum = row0[x - 1]
178                    + row0[x]
179                    + row0[x + 1]
180                    + row1[x - 1]
181                    + row1[x]
182                    + row1[x + 1]
183                    + row2[x - 1]
184                    + row2[x]
185                    + row2[x + 1];
186                row[x] = sum / 9.0;
187            }
188            // Border pixels x=0 and x=w-1 still need bounds-checked path
189            // x=0
190            {
191                let mut acc = 0.0f32;
192                let mut count = 0.0f32;
193                for ky in -1isize..=1 {
194                    let sy = y as isize + ky;
195                    if sy < 0 || sy >= h as isize {
196                        continue;
197                    }
198                    for kx in 0isize..=1 {
199                        acc += data[(sy as usize) * w + kx as usize];
200                        count += 1.0;
201                    }
202                }
203                row[0] = acc / count;
204            }
205            // x=w-1
206            if w > 1 {
207                let mut acc = 0.0f32;
208                let mut count = 0.0f32;
209                for ky in -1isize..=1 {
210                    let sy = y as isize + ky;
211                    if sy < 0 || sy >= h as isize {
212                        continue;
213                    }
214                    for kx in (w as isize - 2)..=(w as isize - 1) {
215                        if kx >= 0 {
216                            acc += data[(sy as usize) * w + kx as usize];
217                            count += 1.0;
218                        }
219                    }
220                }
221                row[w - 1] = acc / count;
222            }
223            return;
224        }
225
226        for x in 0..w {
227            for c in 0..channels {
228                let mut acc = 0.0f32;
229                let mut count = 0.0f32;
230                for ky in -1isize..=1 {
231                    for kx in -1isize..=1 {
232                        let sy = y as isize + ky;
233                        let sx = x as isize + kx;
234                        if sy < 0 || sx < 0 || sy >= h as isize || sx >= w as isize {
235                            continue;
236                        }
237                        let src = ((sy as usize) * w + sx as usize) * channels + c;
238                        acc += data[src];
239                        count += 1.0;
240                    }
241                }
242                row[x * channels + c] = acc / count;
243            }
244        }
245    };
246
247    let pixels = h * w;
248
249    #[cfg(target_os = "macos")]
250    if pixels > 4096 && !cfg!(miri) {
251        let out_ptr = out.as_mut_ptr() as usize;
252        use super::u8ops::gcd;
253        gcd::parallel_for(h, |y| {
254            // SAFETY: each row writes to a disjoint slice of out.
255            let row = unsafe {
256                std::slice::from_raw_parts_mut((out_ptr as *mut f32).add(y * row_len), row_len)
257            };
258            compute_row(y, row);
259        });
260        return Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into);
261    }
262
263    if pixels > 4096 {
264        out.par_chunks_mut(row_len)
265            .enumerate()
266            .for_each(|(y, row)| compute_row(y, row));
267    } else {
268        out.chunks_mut(row_len)
269            .enumerate()
270            .for_each(|(y, row)| compute_row(y, row));
271    }
272
273    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
274}
275
276// ── Gaussian blur (separable) ─────────────────────────────────────
277
278/// SIMD `[1,2,1]` horizontal pass for c=1: `out[x] = (src[x-1] + 2*src[x] + src[x+1]) * 0.25`
279#[allow(unsafe_code)]
280fn gauss_h_simd_row_c1(src: &[f32], out: &mut [f32], w: usize) -> usize {
281    if w < 6 {
282        return 1;
283    }
284
285    #[cfg(target_arch = "aarch64")]
286    {
287        if std::arch::is_aarch64_feature_detected!("neon") {
288            return unsafe { gauss_h_neon_row_c1(src, out, w) };
289        }
290    }
291    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
292    if std::is_x86_feature_detected!("avx") {
293        return unsafe { gauss_h_avx_row_c1(src, out, w) };
294    }
295    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
296    {
297        if std::is_x86_feature_detected!("sse") {
298            return unsafe { gauss_h_sse_row_c1(src, out, w) };
299        }
300    }
301    1
302}
303
/// NEON horizontal `[1,2,1]`/4 kernel for one single-channel row.
///
/// Starts at x = 1, 4 pixels per iteration; returns the first x NOT
/// processed. The caller handles the tail and both borders.
///
/// # Safety
/// NEON must be available (enforced by `target_feature` plus the caller's
/// runtime check); `src`/`out` are assumed at least `w` long — TODO
/// confirm — so loads at offset x + 1 stay in bounds.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn gauss_h_neon_row_c1(src: &[f32], out: &mut [f32], w: usize) -> usize {
    use std::arch::aarch64::*;
    let two = vdupq_n_f32(2.0);
    let quarter = vdupq_n_f32(0.25);
    let mut x = 1usize;
    // x + 5 <= w keeps the furthest load (offset x + 1, 4 lanes) in bounds.
    while x + 5 <= w {
        let left = vld1q_f32(src.as_ptr().add(x - 1));
        let center = vld1q_f32(src.as_ptr().add(x));
        let right = vld1q_f32(src.as_ptr().add(x + 1));
        let sum = vaddq_f32(vaddq_f32(left, right), vmulq_f32(center, two));
        vst1q_f32(out.as_mut_ptr().add(x), vmulq_f32(sum, quarter));
        x += 4;
    }
    x
}
322
/// AVX horizontal `[1,2,1]`/4 kernel for one single-channel row.
///
/// Same contract as the NEON variant but 8 pixels per iteration.
///
/// # Safety
/// AVX must be available (enforced by `target_feature` plus the caller's
/// runtime check); slices assumed at least `w` long — TODO confirm.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn gauss_h_avx_row_c1(src: &[f32], out: &mut [f32], w: usize) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let two = _mm256_set1_ps(2.0);
    let quarter = _mm256_set1_ps(0.25);
    let mut x = 1usize;
    // x + 9 <= w keeps the furthest load (offset x + 1, 8 lanes) in bounds.
    while x + 9 <= w {
        let left = _mm256_loadu_ps(src.as_ptr().add(x - 1));
        let center = _mm256_loadu_ps(src.as_ptr().add(x));
        let right = _mm256_loadu_ps(src.as_ptr().add(x + 1));
        let sum = _mm256_add_ps(_mm256_add_ps(left, right), _mm256_mul_ps(center, two));
        _mm256_storeu_ps(out.as_mut_ptr().add(x), _mm256_mul_ps(sum, quarter));
        x += 8;
    }
    x
}
345
/// SSE horizontal `[1,2,1]`/4 kernel for one single-channel row.
///
/// Same contract as the AVX variant but 4 pixels per iteration.
///
/// # Safety
/// SSE must be available (enforced by `target_feature` plus the caller's
/// runtime check); slices assumed at least `w` long — TODO confirm.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn gauss_h_sse_row_c1(src: &[f32], out: &mut [f32], w: usize) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let two = _mm_set1_ps(2.0);
    let quarter = _mm_set1_ps(0.25);
    let mut x = 1usize;
    // x + 5 <= w keeps the furthest load (offset x + 1, 4 lanes) in bounds.
    while x + 5 <= w {
        let left = _mm_loadu_ps(src.as_ptr().add(x - 1));
        let center = _mm_loadu_ps(src.as_ptr().add(x));
        let right = _mm_loadu_ps(src.as_ptr().add(x + 1));
        let sum = _mm_add_ps(_mm_add_ps(left, right), _mm_mul_ps(center, two));
        _mm_storeu_ps(out.as_mut_ptr().add(x), _mm_mul_ps(sum, quarter));
        x += 4;
    }
    x
}
368
369/// Applies zero-padded 3x3 Gaussian blur per channel.
370/// Kernel: `[[1,2,1],[2,4,2],[1,2,1]] / 16`.
371///
372/// Uses separable decomposition: horizontal `[1,2,1]`/4 then vertical `[1,2,1]`/4.
373#[allow(unsafe_code, clippy::uninit_vec)]
374pub fn gaussian_blur_3x3(input: &Tensor) -> Result<Tensor, ImgProcError> {
375    let (h, w, channels) = hwc_shape(input)?;
376    let data = input.data();
377    let total = h * w * channels;
378    // Horizontal pass: [1, 2, 1] / 4
379    // SAFETY: every element is written by the horizontal pass below.
380    let mut tmp = Vec::with_capacity(total);
381    unsafe {
382        tmp.set_len(total);
383    }
384
385    for y in 0..h {
386        // SIMD fast path for c=1
387        if channels == 1 && !cfg!(miri) {
388            let src = &data[y * w..(y + 1) * w];
389            let dst = &mut tmp[y * w..(y + 1) * w];
390            // Left border
391            {
392                let center = src[0];
393                let right = src[1.min(w - 1)];
394                dst[0] = (center * 2.0 + right) / 4.0;
395            }
396            let done = gauss_h_simd_row_c1(src, dst, w);
397            // Scalar tail
398            for x in done..w.saturating_sub(1) {
399                if x == 0 {
400                    continue;
401                }
402                dst[x] = (src[x - 1] + src[x] * 2.0 + src[x + 1]) * 0.25;
403            }
404            // Right border
405            if w > 1 {
406                dst[w - 1] = (src[w - 2] + src[w - 1] * 2.0) / 4.0;
407            }
408            continue;
409        }
410
411        for c in 0..channels {
412            // Left border (x=0)
413            {
414                let center = data[(y * w) * channels + c];
415                let right = data[(y * w + 1.min(w - 1)) * channels + c];
416                tmp[(y * w) * channels + c] = (center * 2.0 + right) / 4.0;
417            }
418            // Interior (no bounds checks needed)
419            for x in 1..w.saturating_sub(1) {
420                let base = y * w;
421                let left = data[(base + x - 1) * channels + c];
422                let center = data[(base + x) * channels + c];
423                let right = data[(base + x + 1) * channels + c];
424                tmp[(base + x) * channels + c] = (left + center * 2.0 + right) * 0.25;
425            }
426            // Right border (x=w-1)
427            if w > 1 {
428                let base = y * w;
429                let left = data[(base + w - 2) * channels + c];
430                let center = data[(base + w - 1) * channels + c];
431                tmp[(base + w - 1) * channels + c] = (left + center * 2.0) / 4.0;
432            }
433        }
434    }
435    // Vertical pass: [1, 2, 1] / 4
436    // SAFETY: every element is written by compute_row below.
437    let mut out = Vec::with_capacity(total);
438    unsafe {
439        out.set_len(total);
440    }
441    let row_len = w * channels;
442
443    let compute_row = |y: usize, row: &mut [f32]| {
444        // SIMD fast path for c=1 interior rows
445        if channels == 1 && y > 0 && y < h - 1 && !cfg!(miri) {
446            let above = &tmp[(y - 1) * w..y * w];
447            let center = &tmp[y * w..(y + 1) * w];
448            let below = &tmp[(y + 1) * w..(y + 2) * w];
449            let done = gauss_v_simd_row_c1(above, center, below, row, w);
450            for x in done..w {
451                row[x] = (above[x] + center[x] * 2.0 + below[x]) * 0.25;
452            }
453            return;
454        }
455
456        for x in 0..w {
457            for c in 0..channels {
458                let val = if y == 0 {
459                    let center = tmp[x * channels + c];
460                    let below = tmp[(1.min(h - 1) * w + x) * channels + c];
461                    (center * 2.0 + below) / 4.0
462                } else if y == h - 1 && h > 1 {
463                    let above = tmp[((h - 2) * w + x) * channels + c];
464                    let center = tmp[((h - 1) * w + x) * channels + c];
465                    (above + center * 2.0) / 4.0
466                } else {
467                    let above = tmp[((y - 1) * w + x) * channels + c];
468                    let center = tmp[(y * w + x) * channels + c];
469                    let below = tmp[((y + 1) * w + x) * channels + c];
470                    (above + center * 2.0 + below) * 0.25
471                };
472                row[x * channels + c] = val;
473            }
474        }
475    };
476
477    let pixels = h * w;
478
479    #[cfg(target_os = "macos")]
480    if pixels > 4096 && !cfg!(miri) {
481        let out_ptr = out.as_mut_ptr() as usize;
482        use super::u8ops::gcd;
483        gcd::parallel_for(h, |y| {
484            // SAFETY: each row writes to a disjoint slice of out.
485            let row = unsafe {
486                std::slice::from_raw_parts_mut((out_ptr as *mut f32).add(y * row_len), row_len)
487            };
488            compute_row(y, row);
489        });
490        return Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into);
491    }
492
493    if pixels > 4096 {
494        out.par_chunks_mut(row_len)
495            .enumerate()
496            .for_each(|(y, row)| compute_row(y, row));
497    } else {
498        out.chunks_mut(row_len)
499            .enumerate()
500            .for_each(|(y, row)| compute_row(y, row));
501    }
502
503    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
504}
505
506/// SIMD vertical `[1,2,1]`/4 pass for c=1
507#[allow(unsafe_code)]
508fn gauss_v_simd_row_c1(
509    above: &[f32],
510    center: &[f32],
511    below: &[f32],
512    out: &mut [f32],
513    w: usize,
514) -> usize {
515    if w < 4 {
516        return 0;
517    }
518
519    #[cfg(target_arch = "aarch64")]
520    {
521        if std::arch::is_aarch64_feature_detected!("neon") {
522            return unsafe { gauss_v_neon_c1(above, center, below, out, w) };
523        }
524    }
525    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
526    if std::is_x86_feature_detected!("avx") {
527        return unsafe { gauss_v_avx_c1(above, center, below, out, w) };
528    }
529    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
530    {
531        if std::is_x86_feature_detected!("sse") {
532            return unsafe { gauss_v_sse_c1(above, center, below, out, w) };
533        }
534    }
535    0
536}
537
/// NEON vertical `[1,2,1]`/4 kernel for c=1.
///
/// Unlike the horizontal kernels this starts at x = 0 (no lateral taps),
/// 4 pixels per iteration; returns the first x NOT processed.
///
/// # Safety
/// NEON must be available (enforced by `target_feature` plus the caller's
/// runtime check); all slices assumed at least `w` long — TODO confirm.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn gauss_v_neon_c1(
    above: &[f32],
    center: &[f32],
    below: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    use std::arch::aarch64::*;
    let two = vdupq_n_f32(2.0);
    let quarter = vdupq_n_f32(0.25);
    let mut x = 0usize;
    while x + 4 <= w {
        let a = vld1q_f32(above.as_ptr().add(x));
        let c = vld1q_f32(center.as_ptr().add(x));
        let b = vld1q_f32(below.as_ptr().add(x));
        let sum = vaddq_f32(vaddq_f32(a, b), vmulq_f32(c, two));
        vst1q_f32(out.as_mut_ptr().add(x), vmulq_f32(sum, quarter));
        x += 4;
    }
    x
}
562
/// AVX vertical `[1,2,1]`/4 kernel for c=1.
///
/// Same contract as the NEON variant but 8 pixels per iteration.
///
/// # Safety
/// AVX must be available (enforced by `target_feature` plus the caller's
/// runtime check); all slices assumed at least `w` long — TODO confirm.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn gauss_v_avx_c1(
    above: &[f32],
    center: &[f32],
    below: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let two = _mm256_set1_ps(2.0);
    let quarter = _mm256_set1_ps(0.25);
    let mut x = 0usize;
    while x + 8 <= w {
        let a = _mm256_loadu_ps(above.as_ptr().add(x));
        let c = _mm256_loadu_ps(center.as_ptr().add(x));
        let b = _mm256_loadu_ps(below.as_ptr().add(x));
        let sum = _mm256_add_ps(_mm256_add_ps(a, b), _mm256_mul_ps(c, two));
        _mm256_storeu_ps(out.as_mut_ptr().add(x), _mm256_mul_ps(sum, quarter));
        x += 8;
    }
    x
}
591
/// SSE vertical `[1,2,1]`/4 kernel for c=1.
///
/// Same contract as the AVX variant but 4 pixels per iteration.
///
/// # Safety
/// SSE must be available (enforced by `target_feature` plus the caller's
/// runtime check); all slices assumed at least `w` long — TODO confirm.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn gauss_v_sse_c1(
    above: &[f32],
    center: &[f32],
    below: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let two = _mm_set1_ps(2.0);
    let quarter = _mm_set1_ps(0.25);
    let mut x = 0usize;
    while x + 4 <= w {
        let a = _mm_loadu_ps(above.as_ptr().add(x));
        let c = _mm_loadu_ps(center.as_ptr().add(x));
        let b = _mm_loadu_ps(below.as_ptr().add(x));
        let sum = _mm_add_ps(_mm_add_ps(a, b), _mm_mul_ps(c, two));
        _mm_storeu_ps(out.as_mut_ptr().add(x), _mm_mul_ps(sum, quarter));
        x += 4;
    }
    x
}
620
621/// Applies zero-padded 5x5 Gaussian blur per channel.
622///
623/// Uses separable decomposition: horizontal `[1,4,6,4,1]`/16 then vertical `[1,4,6,4,1]`/16.
624pub fn gaussian_blur_5x5(input: &Tensor) -> Result<Tensor, ImgProcError> {
625    let (h, w, channels) = hwc_shape(input)?;
626    let data = input.data();
627    let k: [f32; 5] = [1.0 / 16.0, 4.0 / 16.0, 6.0 / 16.0, 4.0 / 16.0, 1.0 / 16.0];
628
629    // Horizontal pass
630    let mut tmp = vec![0.0f32; h * w * channels];
631    for y in 0..h {
632        for x in 0..w {
633            for c in 0..channels {
634                let base = y * w;
635                let mut acc = 0.0f32;
636                for i in 0..5 {
637                    let sx = (x as isize + i as isize - 2).clamp(0, w as isize - 1) as usize;
638                    acc += data[(base + sx) * channels + c] * k[i];
639                }
640                tmp[(base + x) * channels + c] = acc;
641            }
642        }
643    }
644    // Vertical pass
645    let mut out = vec![0.0f32; h * w * channels];
646    let row_len = w * channels;
647
648    let compute_row = |y: usize, row: &mut [f32]| {
649        for x in 0..w {
650            for c in 0..channels {
651                let mut acc = 0.0f32;
652                for i in 0..5 {
653                    let sy = (y as isize + i as isize - 2).clamp(0, h as isize - 1) as usize;
654                    acc += tmp[(sy * w + x) * channels + c] * k[i];
655                }
656                row[x * channels + c] = acc;
657            }
658        }
659    };
660
661    if h * w > 4096 {
662        out.par_chunks_mut(row_len)
663            .enumerate()
664            .for_each(|(y, row)| compute_row(y, row));
665    } else {
666        out.chunks_mut(row_len)
667            .enumerate()
668            .for_each(|(y, row)| compute_row(y, row));
669    }
670
671    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
672}
673
674// ── Generic 3x3 kernel (with interior/border split) ───────────────
675
/// Applies an arbitrary 3x3 convolution kernel per channel with zero
/// padding: out-of-bounds taps contribute nothing to the sum.
///
/// Interior pixels (y in 1..h-1, x in 1..w-1) take an unchecked flat-index
/// fast path, parallelized for images larger than 4096 pixels; border
/// pixels are finished afterwards with a bounds-checked loop.
pub(crate) fn apply_kernel_3x3(
    input: &Tensor,
    kernel: &[[f32; 3]; 3],
) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let mut out = vec![0.0f32; h * w * channels];

    // Interior pixels: no bounds checks needed (y=1..h-1, x=1..w-1)
    let row_len = w * channels;
    let interior_h = h.saturating_sub(2); // number of interior rows

    let compute_interior_row = |y: usize, row: &mut [f32]| {
        for x in 1..w.saturating_sub(1) {
            for c in 0..channels {
                let mut acc = 0.0f32;
                // Flat offsets of the (x-1) tap in the three source rows.
                let r0 = ((y - 1) * w + x - 1) * channels + c;
                let r1 = (y * w + x - 1) * channels + c;
                let r2 = ((y + 1) * w + x - 1) * channels + c;
                acc += data[r0] * kernel[0][0];
                acc += data[r0 + channels] * kernel[0][1];
                acc += data[r0 + 2 * channels] * kernel[0][2];
                acc += data[r1] * kernel[1][0];
                acc += data[r1 + channels] * kernel[1][1];
                acc += data[r1 + 2 * channels] * kernel[1][2];
                acc += data[r2] * kernel[2][0];
                acc += data[r2 + channels] * kernel[2][1];
                acc += data[r2 + 2 * channels] * kernel[2][2];
                row[x * channels + c] = acc;
            }
        }
    };

    if interior_h > 0 {
        // Slice out rows 1..h-1; x=0 and x=w-1 of these rows are filled by
        // the border pass below, so each row chunk only writes 1..w-1.
        let interior_out = &mut out[row_len..row_len + interior_h * row_len];
        if h * w > 4096 {
            interior_out
                .par_chunks_mut(row_len)
                .enumerate()
                .for_each(|(i, row)| compute_interior_row(i + 1, row));
        } else {
            interior_out
                .chunks_mut(row_len)
                .enumerate()
                .for_each(|(i, row)| compute_interior_row(i + 1, row));
        }
    }

    // Border pixels: use bounds-checked path (top/bottom rows, left/right columns)
    let border_pixels = border_coords_3x3(h, w);
    for (y, x) in border_pixels {
        for c in 0..channels {
            let mut acc = 0.0f32;
            for ky in -1isize..=1 {
                for kx in -1isize..=1 {
                    let sy = y as isize + ky;
                    let sx = x as isize + kx;
                    if sy < 0 || sx < 0 || sy >= h as isize || sx >= w as isize {
                        continue;
                    }
                    let src = ((sy as usize) * w + sx as usize) * channels + c;
                    acc += data[src] * kernel[(ky + 1) as usize][(kx + 1) as usize];
                }
            }
            out[(y * w + x) * channels + c] = acc;
        }
    }

    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
}
747
/// Collects the (y, x) coordinates every 3x3 kernel must bounds-check:
/// the top row, the bottom row, then the left/right columns of each
/// remaining row (corners appear exactly once, in the row passes).
pub(crate) fn border_coords_3x3(h: usize, w: usize) -> Vec<(usize, usize)> {
    let mut coords = Vec::with_capacity(2 * w + 2 * h);
    // Top row
    coords.extend((0..w).map(|x| (0, x)));
    // Bottom row, only when distinct from the top row
    if h > 1 {
        coords.extend((0..w).map(|x| (h - 1, x)));
    }
    // Left/right columns, skipping the rows already emitted above
    for y in 1..h.saturating_sub(1) {
        coords.push((y, 0));
        if w > 1 {
            coords.push((y, w - 1));
        }
    }
    coords
}
770
771/// Applies 3x3 Laplacian edge detection per channel.
772pub fn laplacian_3x3(input: &Tensor) -> Result<Tensor, ImgProcError> {
773    const KERNEL: [[f32; 3]; 3] = [[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]];
774    apply_kernel_3x3(input, &KERNEL)
775}
776
777// ── Median blur (sorting network for N=9) ─────────────────────────
778
/// Compare-and-swap: orders the pair so `v[a] <= v[b]` after the call.
/// (If either value is NaN the comparison is false and nothing moves.)
#[inline(always)]
fn cswap(v: &mut [f32; 9], a: usize, b: usize) {
    let (lo, hi) = (v[a], v[b]);
    if lo > hi {
        v[a] = hi;
        v[b] = lo;
    }
}
786
/// Finds the median of exactly 9 f32 values using an optimal sorting network.
/// Uses 25 compare-and-swap operations (optimal for N=9).
///
/// The array is permuted in place; the median ends up at `v[4]`.
/// NOTE(review): with NaN inputs `cswap`'s comparison is always false, so
/// the result is unspecified for non-finite data — callers appear to feed
/// image pixels only, but confirm if NaNs can occur.
#[inline(always)]
fn median9(v: &mut [f32; 9]) -> f32 {
    // Optimal 9-element sorting network (25 comparisons)
    // Stage 1: sort each triple (0,1,2), (3,4,5), (6,7,8).
    cswap(v, 0, 1);
    cswap(v, 3, 4);
    cswap(v, 6, 7);
    cswap(v, 1, 2);
    cswap(v, 4, 5);
    cswap(v, 7, 8);
    cswap(v, 0, 1);
    cswap(v, 3, 4);
    cswap(v, 6, 7);
    // Stage 2: sort across the triples (columns of the 3x3 arrangement).
    cswap(v, 0, 3);
    cswap(v, 3, 6);
    cswap(v, 0, 3);
    cswap(v, 1, 4);
    cswap(v, 4, 7);
    cswap(v, 1, 4);
    cswap(v, 2, 5);
    cswap(v, 5, 8);
    cswap(v, 2, 5);
    // Stage 3: final merge placing the median at index 4.
    cswap(v, 1, 3);
    cswap(v, 5, 7);
    cswap(v, 2, 6);
    cswap(v, 4, 6);
    cswap(v, 2, 4);
    cswap(v, 2, 3);
    cswap(v, 5, 6);
    v[4]
}
819
/// Applies 3x3 median filter per channel.
///
/// Interior pixels have a full 9-tap neighborhood and use the branch-light
/// `median9` sorting network; border pixels gather their smaller in-bounds
/// neighborhood and sort it. For even-sized border neighborhoods the upper
/// median (`slice[count / 2]`) is taken.
pub fn median_blur_3x3(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let mut out = vec![0.0f32; h * w * channels];
    // Scratch buffer reused for every pixel (9 taps max).
    let mut neighborhood = [0.0f32; 9];

    // Interior pixels: full 9-element neighborhood, use sorting network
    for y in 1..h.saturating_sub(1) {
        for x in 1..w.saturating_sub(1) {
            for c in 0..channels {
                // Flat offsets of the (x-1) tap in the three source rows.
                let r0 = ((y - 1) * w + x - 1) * channels + c;
                let r1 = (y * w + x - 1) * channels + c;
                let r2 = ((y + 1) * w + x - 1) * channels + c;
                neighborhood[0] = data[r0];
                neighborhood[1] = data[r0 + channels];
                neighborhood[2] = data[r0 + 2 * channels];
                neighborhood[3] = data[r1];
                neighborhood[4] = data[r1 + channels];
                neighborhood[5] = data[r1 + 2 * channels];
                neighborhood[6] = data[r2];
                neighborhood[7] = data[r2 + channels];
                neighborhood[8] = data[r2 + 2 * channels];
                out[(y * w + x) * channels + c] = median9(&mut neighborhood);
            }
        }
    }

    // Border pixels: variable neighborhood size, use sort
    let border = border_coords_3x3(h, w);
    for (y, x) in border {
        for c in 0..channels {
            let mut count = 0usize;
            for ky in -1isize..=1 {
                for kx in -1isize..=1 {
                    let sy = y as isize + ky;
                    let sx = x as isize + kx;
                    if sy < 0 || sx < 0 || sy >= h as isize || sx >= w as isize {
                        continue;
                    }
                    let src = ((sy as usize) * w + sx as usize) * channels + c;
                    neighborhood[count] = data[src];
                    count += 1;
                }
            }
            let slice = &mut neighborhood[..count];
            // NaN-tolerant total order (NaN compares Equal, so no panic).
            slice.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            out[(y * w + x) * channels + c] = slice[count / 2];
        }
    }

    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
}
873
874/// Applies NxN median filter on a single-channel `[H, W, 1]` image.
875///
876/// `kernel_size` must be odd and >= 1. Border pixels use replicate (clamp)
877/// padding so every pixel gets a full NxN neighborhood.
878pub fn median_filter(input: &Tensor, kernel_size: usize) -> Result<Tensor, ImgProcError> {
879    if kernel_size == 0 || kernel_size.is_multiple_of(2) {
880        return Err(ImgProcError::InvalidBlockSize {
881            block_size: kernel_size,
882        });
883    }
884    let (h, w, c) = hwc_shape(input)?;
885    if c != 1 {
886        return Err(ImgProcError::InvalidChannelCount {
887            expected: 1,
888            got: c,
889        });
890    }
891    let data = input.data();
892    let radius = (kernel_size / 2) as isize;
893    let mut out = vec![0.0f32; h * w];
894    let mut neighborhood = vec![0.0f32; kernel_size * kernel_size];
895
896    for y in 0..h {
897        for x in 0..w {
898            let mut count = 0usize;
899            for ky in -radius..=radius {
900                for kx in -radius..=radius {
901                    let sy = (y as isize + ky).clamp(0, h as isize - 1) as usize;
902                    let sx = (x as isize + kx).clamp(0, w as isize - 1) as usize;
903                    neighborhood[count] = data[sy * w + sx];
904                    count += 1;
905                }
906            }
907            let slice = &mut neighborhood[..count];
908            slice.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
909            out[y * w + x] = slice[count / 2];
910        }
911    }
912
913    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
914}
915
/// NEON-accelerated bilateral filter for an entire interior row.
/// Processes pixels x in [x_start, x_end) where all neighbors are in bounds.
/// Uses double-batch (8 neighbors at a time) with interleaved LUT lookups.
///
/// `spatial_lut` is indexed as `[(dy + radius) * diameter + (dx + radius)]`;
/// `color_lut[i]` holds the color weight for an absolute intensity
/// difference of `i / 255` (pixel values presumably lie in `[0, 1]` —
/// NOTE(review): confirm against callers).
///
/// # Safety
///
/// Callers must guarantee:
/// - the NEON target feature is available on the running CPU;
/// - every pixel in `[x_start, x_end)` of row `y` has its full
///   `(2 * radius + 1)^2` neighborhood inside `data` (all loads below are
///   unchecked pointer reads);
/// - `spatial_lut.len() >= diameter * diameter` and `row_out.len() >= x_end`.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn bilateral_neon_row(
    data: &[f32],
    w: usize,
    y: usize,
    x_start: usize,
    x_end: usize,
    radius: i32,
    diameter: usize,
    spatial_lut: &[f32],
    color_lut: &[f32; 256],
    row_out: &mut [f32],
) {
    use std::arch::aarch64::*;

    // Constants for quantizing |neighbor - center| into a [0, 255] LUT index.
    let scale_255 = vdupq_n_f32(255.0);
    let max255 = vdupq_n_u32(255);
    let clut = color_lut.as_ptr();

    for x in x_start..x_end {
        let center = *data.get_unchecked(y * w + x);
        let center_v = vdupq_n_f32(center);
        // Vector accumulators for the SIMD batches, scalar ones for the tail.
        let mut sum_v = vdupq_n_f32(0.0);
        let mut wsum_v = vdupq_n_f32(0.0);
        let mut sum_s = 0.0f32;
        let mut wsum_s = 0.0f32;

        for dy in -radius..=radius {
            let ny = (y as i32 + dy) as usize;
            // Leftmost neighbor of this kernel row; x - radius is in bounds
            // by the interior-row contract in `# Safety`.
            let row_ptr = data.as_ptr().add(ny * w + x - (radius as usize));
            let sp_ptr = spatial_lut
                .as_ptr()
                .add(((dy + radius) as usize) * diameter);

            let mut dx = 0usize;

            // Process 8 neighbors at a time (2 NEON batches interleaved)
            while dx + 8 <= diameter {
                // Batch 1
                let n1 = vld1q_f32(row_ptr.add(dx));
                let sp1 = vld1q_f32(sp_ptr.add(dx));
                let diff1 = vabsq_f32(vsubq_f32(n1, center_v));
                // Quantize the color difference to a LUT index, clamped to 255.
                let idx1 = vminq_u32(vcvtq_u32_f32(vmulq_f32(diff1, scale_255)), max255);
                let mut ia1 = [0u32; 4];
                vst1q_u32(ia1.as_mut_ptr(), idx1);

                // Batch 2
                let n2 = vld1q_f32(row_ptr.add(dx + 4));
                let sp2 = vld1q_f32(sp_ptr.add(dx + 4));
                let diff2 = vabsq_f32(vsubq_f32(n2, center_v));
                let idx2 = vminq_u32(vcvtq_u32_f32(vmulq_f32(diff2, scale_255)), max255);
                let mut ia2 = [0u32; 4];
                vst1q_u32(ia2.as_mut_ptr(), idx2);

                // Interleaved LUT lookups (while CPU works on store above)
                let cw1_arr = [
                    *clut.add(ia1[0] as usize),
                    *clut.add(ia1[1] as usize),
                    *clut.add(ia1[2] as usize),
                    *clut.add(ia1[3] as usize),
                ];
                let cw2_arr = [
                    *clut.add(ia2[0] as usize),
                    *clut.add(ia2[1] as usize),
                    *clut.add(ia2[2] as usize),
                    *clut.add(ia2[3] as usize),
                ];

                // Combined weight = spatial weight * color weight; fused
                // multiply-add accumulates neighbor * weight.
                let wt1 = vmulq_f32(sp1, vld1q_f32(cw1_arr.as_ptr()));
                let wt2 = vmulq_f32(sp2, vld1q_f32(cw2_arr.as_ptr()));
                sum_v = vfmaq_f32(sum_v, n1, wt1);
                sum_v = vfmaq_f32(sum_v, n2, wt2);
                wsum_v = vaddq_f32(wsum_v, vaddq_f32(wt1, wt2));

                dx += 8;
            }

            // Remaining 4-element batch
            while dx + 4 <= diameter {
                let neighbors = vld1q_f32(row_ptr.add(dx));
                let spatial_w = vld1q_f32(sp_ptr.add(dx));
                let diff = vabsq_f32(vsubq_f32(neighbors, center_v));
                let idx_u32 = vminq_u32(vcvtq_u32_f32(vmulq_f32(diff, scale_255)), max255);
                let mut idx_arr = [0u32; 4];
                vst1q_u32(idx_arr.as_mut_ptr(), idx_u32);

                let cw_arr = [
                    *clut.add(idx_arr[0] as usize),
                    *clut.add(idx_arr[1] as usize),
                    *clut.add(idx_arr[2] as usize),
                    *clut.add(idx_arr[3] as usize),
                ];
                let wt = vmulq_f32(spatial_w, vld1q_f32(cw_arr.as_ptr()));
                sum_v = vfmaq_f32(sum_v, neighbors, wt);
                wsum_v = vaddq_f32(wsum_v, wt);
                dx += 4;
            }

            // Scalar tail
            while dx < diameter {
                let neighbor = *row_ptr.add(dx);
                let color_diff = (neighbor - center).abs();
                let color_idx = ((color_diff * 255.0) as usize).min(255);
                let wt = *sp_ptr.add(dx) * *clut.add(color_idx);
                sum_s += neighbor * wt;
                wsum_s += wt;
                dx += 1;
            }
        }

        // Horizontal reduction of the vector lanes plus the scalar tail.
        let total_sum = vaddvq_f32(sum_v) + sum_s;
        let total_wsum = vaddvq_f32(wsum_v) + wsum_s;
        // Degenerate zero-weight case falls back to the unfiltered center value.
        *row_out.get_unchecked_mut(x) = if total_wsum > 0.0 {
            total_sum / total_wsum
        } else {
            center
        };
    }
}
1040
/// SSE2-accelerated bilateral filter for an entire interior row.
/// Mirrors `bilateral_neon_row`: processes pixels x in [x_start, x_end) where
/// all neighbors are in bounds, using a 256-entry color-weight LUT instead of
/// computing `exp` per neighbor (same scheme as the NEON path).
///
/// # Safety
///
/// Callers must guarantee:
/// - the SSE2 target feature is available on the running CPU;
/// - every pixel in `[x_start, x_end)` of row `y` has its full
///   `(2 * radius + 1)^2` neighborhood inside `data` (all loads below are
///   unchecked pointer reads);
/// - `spatial_lut.len() >= diameter * diameter` and `row_out.len() >= x_end`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse2")]
unsafe fn bilateral_sse_row(
    data: &[f32],
    w: usize,
    y: usize,
    x_start: usize,
    x_end: usize,
    radius: i32,
    diameter: usize,
    spatial_lut: &[f32],
    color_lut: &[f32; 256],
    row_out: &mut [f32],
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Constants for quantizing |neighbor - center| into a [0, 255] LUT index.
    let scale_255 = _mm_set1_ps(255.0);
    let max_idx = 255i32;
    let clut = color_lut.as_ptr();

    for x in x_start..x_end {
        let center = *data.get_unchecked(y * w + x);
        let center_v = _mm_set1_ps(center);
        // Vector accumulators for the SIMD batches, scalar ones for the tail.
        let mut sum_v = _mm_setzero_ps();
        let mut wsum_v = _mm_setzero_ps();
        let mut sum_s = 0.0f32;
        let mut wsum_s = 0.0f32;

        for dy in -radius..=radius {
            let ny = (y as i32 + dy) as usize;
            // Leftmost neighbor of this kernel row; x - radius is in bounds
            // by the interior-row contract in `# Safety`.
            let row_ptr = data.as_ptr().add(ny * w + x - (radius as usize));
            let sp_ptr = spatial_lut
                .as_ptr()
                .add(((dy + radius) as usize) * diameter);

            let mut dx = 0usize;

            // Process 4 neighbors at a time
            while dx + 4 <= diameter {
                let neighbors = _mm_loadu_ps(row_ptr.add(dx));
                let spatial_w = _mm_loadu_ps(sp_ptr.add(dx));

                // |neighbor - center|
                let diff = _mm_sub_ps(neighbors, center_v);
                // SSE doesn't have vabsq, use max(diff, -diff)
                let neg_diff = _mm_sub_ps(_mm_setzero_ps(), diff);
                let abs_diff = _mm_max_ps(diff, neg_diff);

                // Convert to LUT index: clamp(abs_diff * 255, 0, 255)
                let scaled = _mm_mul_ps(abs_diff, scale_255);
                // Convert to int (truncate)
                let idx_i32 = _mm_cvttps_epi32(scaled);

                // Extract indices and clamp to [0, 255]
                // SSE2 doesn't have _mm_extract_epi32, use a union or store
                let mut idx_arr = [0i32; 4];
                _mm_storeu_si128(idx_arr.as_mut_ptr() as *mut __m128i, idx_i32);
                idx_arr[0] = idx_arr[0].min(max_idx).max(0);
                idx_arr[1] = idx_arr[1].min(max_idx).max(0);
                idx_arr[2] = idx_arr[2].min(max_idx).max(0);
                idx_arr[3] = idx_arr[3].min(max_idx).max(0);

                // Gather color weights from LUT
                let cw_arr = [
                    *clut.add(idx_arr[0] as usize),
                    *clut.add(idx_arr[1] as usize),
                    *clut.add(idx_arr[2] as usize),
                    *clut.add(idx_arr[3] as usize),
                ];

                // Combined weight = spatial weight * color weight.
                let color_w = _mm_loadu_ps(cw_arr.as_ptr());
                let wt = _mm_mul_ps(spatial_w, color_w);

                // sum += neighbor * wt;  wsum += wt
                sum_v = _mm_add_ps(sum_v, _mm_mul_ps(neighbors, wt));
                wsum_v = _mm_add_ps(wsum_v, wt);

                dx += 4;
            }

            // Scalar tail
            while dx < diameter {
                let neighbor = *row_ptr.add(dx);
                let color_diff = (neighbor - center).abs();
                let color_idx = ((color_diff * 255.0) as usize).min(255);
                let wt = *sp_ptr.add(dx) * *clut.add(color_idx);
                sum_s += neighbor * wt;
                wsum_s += wt;
                dx += 1;
            }
        }

        // Horizontal sum of SSE vectors
        // sum_v = [a, b, c, d] -> a+b+c+d
        let hi = _mm_movehl_ps(sum_v, sum_v); // [c, d, c, d]
        let sum_lo = _mm_add_ps(sum_v, hi); // [a+c, b+d, ...]
        let sum_shuf = _mm_shuffle_ps(sum_lo, sum_lo, 1); // [b+d, ...]
        let total_sum_v = _mm_add_ss(sum_lo, sum_shuf);

        let hi_w = _mm_movehl_ps(wsum_v, wsum_v);
        let wsum_lo = _mm_add_ps(wsum_v, hi_w);
        let wsum_shuf = _mm_shuffle_ps(wsum_lo, wsum_lo, 1);
        let total_wsum_v = _mm_add_ss(wsum_lo, wsum_shuf);

        let total_sum = _mm_cvtss_f32(total_sum_v) + sum_s;
        let total_wsum = _mm_cvtss_f32(total_wsum_v) + wsum_s;

        // Degenerate zero-weight case falls back to the unfiltered center value.
        *row_out.get_unchecked_mut(x) = if total_wsum > 0.0 {
            total_sum / total_wsum
        } else {
            center
        };
    }
}
1163
1164/// Bilateral filter on a single-channel `[H, W, 1]` image.
1165///
1166/// Preserves edges while smoothing. `d` is the spatial kernel radius,
1167/// `sigma_color` controls color similarity range, `sigma_space` controls spatial decay.
1168///
1169/// Uses rayon parallelism + NEON SIMD on aarch64 / SSE2 on x86 for high performance.
1170#[allow(unsafe_code)]
1171pub fn bilateral_filter(
1172    input: &Tensor,
1173    d: usize,
1174    sigma_color: f32,
1175    sigma_space: f32,
1176) -> Result<Tensor, ImgProcError> {
1177    let (h, w, c) = hwc_shape(input)?;
1178    if c != 1 {
1179        return Err(ImgProcError::InvalidChannelCount {
1180            expected: 1,
1181            got: c,
1182        });
1183    }
1184    let data = input.data();
1185    let mut out = vec![0.0f32; h * w];
1186    let radius = d as i32;
1187    let color_coeff = -0.5 / (sigma_color * sigma_color);
1188    let space_coeff = -0.5 / (sigma_space * sigma_space);
1189
1190    // Precompute spatial weight LUT: spatial_lut[(dy+radius)*diameter + (dx+radius)]
1191    let diameter = (2 * radius + 1) as usize;
1192    let mut spatial_lut = vec![0.0f32; diameter * diameter];
1193    for dy in -radius..=radius {
1194        for dx in -radius..=radius {
1195            let spatial_dist_sq = (dy * dy + dx * dx) as f32;
1196            let idx = ((dy + radius) as usize) * diameter + (dx + radius) as usize;
1197            spatial_lut[idx] = (space_coeff * spatial_dist_sq).exp();
1198        }
1199    }
1200
1201    // Precompute color weight LUT: color_lut[i] = exp(color_coeff * (i/255)^2) for i in 0..256
1202    let mut color_lut = [0.0f32; 256];
1203    for i in 0..256 {
1204        let diff = i as f32 / 255.0;
1205        color_lut[i] = (color_coeff * diff * diff).exp();
1206    }
1207
1208    let radius_u = d;
1209
1210    // Process pixel at (y, x) — scalar fallback (used for borders and non-NEON platforms)
1211    let process_pixel_scalar = |y: usize, x: usize| -> f32 {
1212        let center = data[y * w + x];
1213        let mut sum = 0.0f32;
1214        let mut weight_sum = 0.0f32;
1215        for dy in -radius..=radius {
1216            let ny = y as i32 + dy;
1217            if ny < 0 || ny >= h as i32 {
1218                continue;
1219            }
1220            let ny = ny as usize;
1221            let spatial_row_off = ((dy + radius) as usize) * diameter;
1222            for dx in -radius..=radius {
1223                let nx = x as i32 + dx;
1224                if nx < 0 || nx >= w as i32 {
1225                    continue;
1226                }
1227                let neighbor = data[ny * w + nx as usize];
1228                let color_diff = (neighbor - center).abs();
1229                let color_idx = ((color_diff * 255.0) as usize).min(255);
1230                let spatial_idx = spatial_row_off + (dx + radius) as usize;
1231                let wt = spatial_lut[spatial_idx] * color_lut[color_idx];
1232                sum += neighbor * wt;
1233                weight_sum += wt;
1234            }
1235        }
1236        if weight_sum > 0.0 {
1237            sum / weight_sum
1238        } else {
1239            center
1240        }
1241    };
1242
1243    // Check SIMD availability once, outside the hot loop
1244    #[cfg(target_arch = "aarch64")]
1245    let use_neon = !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon");
1246    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1247    let use_sse2 = !cfg!(miri) && std::is_x86_feature_detected!("sse2");
1248
1249    // Interior x range: [radius_u, w - radius_u)
1250    let x_start = radius_u;
1251    let x_end = w.saturating_sub(radius_u);
1252
1253    // Parallel processing: each row is independent
1254    let compute_row = |y: usize, row_out: &mut [f32]| {
1255        let is_interior_y = y >= radius_u && y + radius_u < h;
1256
1257        if is_interior_y {
1258            // Border pixels on left
1259            for x in 0..x_start {
1260                row_out[x] = process_pixel_scalar(y, x);
1261            }
1262            // Interior pixels — SIMD fast path
1263            #[cfg(target_arch = "aarch64")]
1264            if use_neon {
1265                unsafe {
1266                    bilateral_neon_row(
1267                        data,
1268                        w,
1269                        y,
1270                        x_start,
1271                        x_end,
1272                        radius,
1273                        diameter,
1274                        &spatial_lut,
1275                        &color_lut,
1276                        row_out,
1277                    );
1278                }
1279            } else {
1280                for x in x_start..x_end {
1281                    row_out[x] = process_pixel_scalar(y, x);
1282                }
1283            }
1284            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1285            if use_sse2 {
1286                unsafe {
1287                    bilateral_sse_row(
1288                        data,
1289                        w,
1290                        y,
1291                        x_start,
1292                        x_end,
1293                        radius,
1294                        diameter,
1295                        &spatial_lut,
1296                        &color_lut,
1297                        row_out,
1298                    );
1299                }
1300            } else {
1301                for x in x_start..x_end {
1302                    row_out[x] = process_pixel_scalar(y, x);
1303                }
1304            }
1305            #[cfg(not(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64")))]
1306            {
1307                for x in x_start..x_end {
1308                    row_out[x] = process_pixel_scalar(y, x);
1309                }
1310            }
1311            // Border pixels on right
1312            for x in x_end..w {
1313                row_out[x] = process_pixel_scalar(y, x);
1314            }
1315        } else {
1316            // Entire row is border
1317            for x in 0..w {
1318                row_out[x] = process_pixel_scalar(y, x);
1319            }
1320        }
1321    };
1322
1323    let pixels = h * w;
1324
1325    #[cfg(target_os = "macos")]
1326    if pixels > 4096 && !cfg!(miri) {
1327        let out_ptr = out.as_mut_ptr() as usize;
1328        use super::u8ops::gcd;
1329        gcd::parallel_for(h, |y| {
1330            let row =
1331                unsafe { std::slice::from_raw_parts_mut((out_ptr as *mut f32).add(y * w), w) };
1332            compute_row(y, row);
1333        });
1334        return Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into);
1335    }
1336
1337    if pixels > 4096 {
1338        out.par_chunks_mut(w)
1339            .enumerate()
1340            .for_each(|(y, row)| compute_row(y, row));
1341    } else {
1342        out.chunks_mut(w)
1343            .enumerate()
1344            .for_each(|(y, row)| compute_row(y, row));
1345    }
1346
1347    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
1348}
1349
1350/// 2D convolution of a single-channel `[H, W, 1]` image with an arbitrary kernel `[KH, KW, 1]`.
1351///
1352/// Zero-pads the borders. Output has the same shape as input.
1353pub fn filter2d(input: &Tensor, kernel: &Tensor) -> Result<Tensor, ImgProcError> {
1354    let (h, w, c) = hwc_shape(input)?;
1355    if c != 1 {
1356        return Err(ImgProcError::InvalidChannelCount {
1357            expected: 1,
1358            got: c,
1359        });
1360    }
1361    let (kh, kw, kc) = hwc_shape(kernel)?;
1362    if kc != 1 {
1363        return Err(ImgProcError::InvalidChannelCount {
1364            expected: 1,
1365            got: kc,
1366        });
1367    }
1368    let data = input.data();
1369    let kern = kernel.data();
1370    let rh = kh / 2;
1371    let rw = kw / 2;
1372    let mut out = vec![0.0f32; h * w];
1373
1374    for y in 0..h {
1375        for x in 0..w {
1376            let mut sum = 0.0f32;
1377            for ky in 0..kh {
1378                for kx in 0..kw {
1379                    let ny = y as i32 + ky as i32 - rh as i32;
1380                    let nx = x as i32 + kx as i32 - rw as i32;
1381                    if ny >= 0 && ny < h as i32 && nx >= 0 && nx < w as i32 {
1382                        sum += data[ny as usize * w + nx as usize] * kern[ky * kw + kx];
1383                    }
1384                }
1385            }
1386            out[y * w + x] = sum;
1387        }
1388    }
1389
1390    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
1391}