
yscv_imgproc/ops/morphology.rs

use rayon::prelude::*;
use yscv_tensor::Tensor;

use super::super::ImgProcError;
use super::super::shape::hwc_shape;
use super::filter::border_coords_3x3;

/// SIMD 3x3 max (dilate) for single-channel row. Returns first x NOT processed.
#[allow(unsafe_code)]
fn dilate_simd_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    if w < 6 {
        return 1;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return unsafe { dilate_neon_row_c1(row0, row1, row2, out, w) };
        }
    }
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            return unsafe { dilate_avx_row_c1(row0, row1, row2, out, w) };
        }
    }
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("sse") {
            return unsafe { dilate_sse_row_c1(row0, row1, row2, out, w) };
        }
    }
    1
}

/// SIMD 3x3 max (dilate) for multi-channel (HWC) row.
/// Operates on flat row data of length `row_len = w * channels`.
/// Horizontal neighbors are at offsets `-channels, 0, +channels` in the flat array.
/// Returns the flat index up to which processing was done (interior pixels only,
/// i.e., from `channels` to `row_len - channels`).
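///
/// For example, with `channels = 3` (RGB), pixel `x`'s red sample sits at flat
/// index `3*x`, and the red samples of its left and right neighbors sit at
/// `3*x - 3` and `3*x + 3`. A SIMD load taken at a flat offset of `-channels`
/// or `+channels` therefore lines each lane up with the same channel of the
/// neighboring pixel, so no lane shuffling is needed.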
#[allow(unsafe_code)]
fn dilate_simd_row_mc(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    row_len: usize,
    channels: usize,
) -> usize {
    if row_len < channels * 3 + 4 {
        return channels;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return unsafe { dilate_neon_row_mc(row0, row1, row2, out, row_len, channels) };
        }
    }
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            return unsafe { dilate_avx_row_mc(row0, row1, row2, out, row_len, channels) };
        }
    }
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("sse") {
            return unsafe { dilate_sse_row_mc(row0, row1, row2, out, row_len, channels) };
        }
    }
    channels
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn dilate_neon_row_mc(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    row_len: usize,
    channels: usize,
) -> usize {
    use std::arch::aarch64::*;
    let mut i = channels;
    let end = row_len - channels;
    while i + 4 <= end {
        let r0l = vld1q_f32(row0.as_ptr().add(i - channels));
        let r0m = vld1q_f32(row0.as_ptr().add(i));
        let r0r = vld1q_f32(row0.as_ptr().add(i + channels));
        let r1l = vld1q_f32(row1.as_ptr().add(i - channels));
        let r1m = vld1q_f32(row1.as_ptr().add(i));
        let r1r = vld1q_f32(row1.as_ptr().add(i + channels));
        let r2l = vld1q_f32(row2.as_ptr().add(i - channels));
        let r2m = vld1q_f32(row2.as_ptr().add(i));
        let r2r = vld1q_f32(row2.as_ptr().add(i + channels));

        let m0 = vmaxq_f32(vmaxq_f32(r0l, r0m), r0r);
        let m1 = vmaxq_f32(vmaxq_f32(r1l, r1m), r1r);
        let m2 = vmaxq_f32(vmaxq_f32(r2l, r2m), r2r);
        vst1q_f32(out.as_mut_ptr().add(i), vmaxq_f32(vmaxq_f32(m0, m1), m2));
        i += 4;
    }
    i
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn dilate_sse_row_mc(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    row_len: usize,
    channels: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mut i = channels;
    let end = row_len - channels;
    while i + 4 <= end {
        let r0l = _mm_loadu_ps(row0.as_ptr().add(i - channels));
        let r0m = _mm_loadu_ps(row0.as_ptr().add(i));
        let r0r = _mm_loadu_ps(row0.as_ptr().add(i + channels));
        let r1l = _mm_loadu_ps(row1.as_ptr().add(i - channels));
        let r1m = _mm_loadu_ps(row1.as_ptr().add(i));
        let r1r = _mm_loadu_ps(row1.as_ptr().add(i + channels));
        let r2l = _mm_loadu_ps(row2.as_ptr().add(i - channels));
        let r2m = _mm_loadu_ps(row2.as_ptr().add(i));
        let r2r = _mm_loadu_ps(row2.as_ptr().add(i + channels));

        let m0 = _mm_max_ps(_mm_max_ps(r0l, r0m), r0r);
        let m1 = _mm_max_ps(_mm_max_ps(r1l, r1m), r1r);
        let m2 = _mm_max_ps(_mm_max_ps(r2l, r2m), r2r);
        _mm_storeu_ps(out.as_mut_ptr().add(i), _mm_max_ps(_mm_max_ps(m0, m1), m2));
        i += 4;
    }
    i
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn dilate_avx_row_mc(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    row_len: usize,
    channels: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mut i = channels;
    let end = row_len - channels;
    while i + 8 <= end {
        let r0l = _mm256_loadu_ps(row0.as_ptr().add(i - channels));
        let r0m = _mm256_loadu_ps(row0.as_ptr().add(i));
        let r0r = _mm256_loadu_ps(row0.as_ptr().add(i + channels));
        let r1l = _mm256_loadu_ps(row1.as_ptr().add(i - channels));
        let r1m = _mm256_loadu_ps(row1.as_ptr().add(i));
        let r1r = _mm256_loadu_ps(row1.as_ptr().add(i + channels));
        let r2l = _mm256_loadu_ps(row2.as_ptr().add(i - channels));
        let r2m = _mm256_loadu_ps(row2.as_ptr().add(i));
        let r2r = _mm256_loadu_ps(row2.as_ptr().add(i + channels));

        let m0 = _mm256_max_ps(_mm256_max_ps(r0l, r0m), r0r);
        let m1 = _mm256_max_ps(_mm256_max_ps(r1l, r1m), r1r);
        let m2 = _mm256_max_ps(_mm256_max_ps(r2l, r2m), r2r);
        _mm256_storeu_ps(
            out.as_mut_ptr().add(i),
            _mm256_max_ps(_mm256_max_ps(m0, m1), m2),
        );
        i += 8;
    }
    i
}

/// SIMD 3x3 min (erode) for multi-channel (HWC) row.
#[allow(unsafe_code)]
fn erode_simd_row_mc(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    row_len: usize,
    channels: usize,
) -> usize {
    if row_len < channels * 3 + 4 {
        return channels;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return unsafe { erode_neon_row_mc(row0, row1, row2, out, row_len, channels) };
        }
    }
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            return unsafe { erode_avx_row_mc(row0, row1, row2, out, row_len, channels) };
        }
    }
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("sse") {
            return unsafe { erode_sse_row_mc(row0, row1, row2, out, row_len, channels) };
        }
    }
    channels
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn erode_neon_row_mc(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    row_len: usize,
    channels: usize,
) -> usize {
    use std::arch::aarch64::*;
    let mut i = channels;
    let end = row_len - channels;
    while i + 4 <= end {
        let r0l = vld1q_f32(row0.as_ptr().add(i - channels));
        let r0m = vld1q_f32(row0.as_ptr().add(i));
        let r0r = vld1q_f32(row0.as_ptr().add(i + channels));
        let r1l = vld1q_f32(row1.as_ptr().add(i - channels));
        let r1m = vld1q_f32(row1.as_ptr().add(i));
        let r1r = vld1q_f32(row1.as_ptr().add(i + channels));
        let r2l = vld1q_f32(row2.as_ptr().add(i - channels));
        let r2m = vld1q_f32(row2.as_ptr().add(i));
        let r2r = vld1q_f32(row2.as_ptr().add(i + channels));

        let m0 = vminq_f32(vminq_f32(r0l, r0m), r0r);
        let m1 = vminq_f32(vminq_f32(r1l, r1m), r1r);
        let m2 = vminq_f32(vminq_f32(r2l, r2m), r2r);
        vst1q_f32(out.as_mut_ptr().add(i), vminq_f32(vminq_f32(m0, m1), m2));
        i += 4;
    }
    i
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn erode_sse_row_mc(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    row_len: usize,
    channels: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mut i = channels;
    let end = row_len - channels;
    while i + 4 <= end {
        let r0l = _mm_loadu_ps(row0.as_ptr().add(i - channels));
        let r0m = _mm_loadu_ps(row0.as_ptr().add(i));
        let r0r = _mm_loadu_ps(row0.as_ptr().add(i + channels));
        let r1l = _mm_loadu_ps(row1.as_ptr().add(i - channels));
        let r1m = _mm_loadu_ps(row1.as_ptr().add(i));
        let r1r = _mm_loadu_ps(row1.as_ptr().add(i + channels));
        let r2l = _mm_loadu_ps(row2.as_ptr().add(i - channels));
        let r2m = _mm_loadu_ps(row2.as_ptr().add(i));
        let r2r = _mm_loadu_ps(row2.as_ptr().add(i + channels));

        let m0 = _mm_min_ps(_mm_min_ps(r0l, r0m), r0r);
        let m1 = _mm_min_ps(_mm_min_ps(r1l, r1m), r1r);
        let m2 = _mm_min_ps(_mm_min_ps(r2l, r2m), r2r);
        _mm_storeu_ps(out.as_mut_ptr().add(i), _mm_min_ps(_mm_min_ps(m0, m1), m2));
        i += 4;
    }
    i
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn erode_avx_row_mc(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    row_len: usize,
    channels: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mut i = channels;
    let end = row_len - channels;
    while i + 8 <= end {
        let r0l = _mm256_loadu_ps(row0.as_ptr().add(i - channels));
        let r0m = _mm256_loadu_ps(row0.as_ptr().add(i));
        let r0r = _mm256_loadu_ps(row0.as_ptr().add(i + channels));
        let r1l = _mm256_loadu_ps(row1.as_ptr().add(i - channels));
        let r1m = _mm256_loadu_ps(row1.as_ptr().add(i));
        let r1r = _mm256_loadu_ps(row1.as_ptr().add(i + channels));
        let r2l = _mm256_loadu_ps(row2.as_ptr().add(i - channels));
        let r2m = _mm256_loadu_ps(row2.as_ptr().add(i));
        let r2r = _mm256_loadu_ps(row2.as_ptr().add(i + channels));

        let m0 = _mm256_min_ps(_mm256_min_ps(r0l, r0m), r0r);
        let m1 = _mm256_min_ps(_mm256_min_ps(r1l, r1m), r1r);
        let m2 = _mm256_min_ps(_mm256_min_ps(r2l, r2m), r2r);
        _mm256_storeu_ps(
            out.as_mut_ptr().add(i),
            _mm256_min_ps(_mm256_min_ps(m0, m1), m2),
        );
        i += 8;
    }
    i
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn dilate_neon_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    use std::arch::aarch64::*;
    let mut x = 1usize;
    while x + 5 <= w {
        let r0l = vld1q_f32(row0.as_ptr().add(x - 1));
        let r0m = vld1q_f32(row0.as_ptr().add(x));
        let r0r = vld1q_f32(row0.as_ptr().add(x + 1));
        let r1l = vld1q_f32(row1.as_ptr().add(x - 1));
        let r1m = vld1q_f32(row1.as_ptr().add(x));
        let r1r = vld1q_f32(row1.as_ptr().add(x + 1));
        let r2l = vld1q_f32(row2.as_ptr().add(x - 1));
        let r2m = vld1q_f32(row2.as_ptr().add(x));
        let r2r = vld1q_f32(row2.as_ptr().add(x + 1));

        let m0 = vmaxq_f32(vmaxq_f32(r0l, r0m), r0r);
        let m1 = vmaxq_f32(vmaxq_f32(r1l, r1m), r1r);
        let m2 = vmaxq_f32(vmaxq_f32(r2l, r2m), r2r);
        let result = vmaxq_f32(vmaxq_f32(m0, m1), m2);

        vst1q_f32(out.as_mut_ptr().add(x), result);
        x += 4;
    }
    x
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn dilate_sse_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mut x = 1usize;
    while x + 5 <= w {
        let r0l = _mm_loadu_ps(row0.as_ptr().add(x - 1));
        let r0m = _mm_loadu_ps(row0.as_ptr().add(x));
        let r0r = _mm_loadu_ps(row0.as_ptr().add(x + 1));
        let r1l = _mm_loadu_ps(row1.as_ptr().add(x - 1));
        let r1m = _mm_loadu_ps(row1.as_ptr().add(x));
        let r1r = _mm_loadu_ps(row1.as_ptr().add(x + 1));
        let r2l = _mm_loadu_ps(row2.as_ptr().add(x - 1));
        let r2m = _mm_loadu_ps(row2.as_ptr().add(x));
        let r2r = _mm_loadu_ps(row2.as_ptr().add(x + 1));

        let m0 = _mm_max_ps(_mm_max_ps(r0l, r0m), r0r);
        let m1 = _mm_max_ps(_mm_max_ps(r1l, r1m), r1r);
        let m2 = _mm_max_ps(_mm_max_ps(r2l, r2m), r2r);
        let result = _mm_max_ps(_mm_max_ps(m0, m1), m2);

        _mm_storeu_ps(out.as_mut_ptr().add(x), result);
        x += 4;
    }
    x
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn dilate_avx_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mut x = 1usize;
    while x + 9 <= w {
        let r0l = _mm256_loadu_ps(row0.as_ptr().add(x - 1));
        let r0m = _mm256_loadu_ps(row0.as_ptr().add(x));
        let r0r = _mm256_loadu_ps(row0.as_ptr().add(x + 1));
        let r1l = _mm256_loadu_ps(row1.as_ptr().add(x - 1));
        let r1m = _mm256_loadu_ps(row1.as_ptr().add(x));
        let r1r = _mm256_loadu_ps(row1.as_ptr().add(x + 1));
        let r2l = _mm256_loadu_ps(row2.as_ptr().add(x - 1));
        let r2m = _mm256_loadu_ps(row2.as_ptr().add(x));
        let r2r = _mm256_loadu_ps(row2.as_ptr().add(x + 1));

        let m0 = _mm256_max_ps(_mm256_max_ps(r0l, r0m), r0r);
        let m1 = _mm256_max_ps(_mm256_max_ps(r1l, r1m), r1r);
        let m2 = _mm256_max_ps(_mm256_max_ps(r2l, r2m), r2r);
        let result = _mm256_max_ps(_mm256_max_ps(m0, m1), m2);

        _mm256_storeu_ps(out.as_mut_ptr().add(x), result);
        x += 8;
    }
    x
}

/// SIMD 3x3 min (erode) for single-channel row. Returns first x NOT processed.
#[allow(unsafe_code)]
fn erode_simd_row_c1(row0: &[f32], row1: &[f32], row2: &[f32], out: &mut [f32], w: usize) -> usize {
    if w < 6 {
        return 1;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return unsafe { erode_neon_row_c1(row0, row1, row2, out, w) };
        }
    }
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            return unsafe { erode_avx_row_c1(row0, row1, row2, out, w) };
        }
    }
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("sse") {
            return unsafe { erode_sse_row_c1(row0, row1, row2, out, w) };
        }
    }
    1
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn erode_neon_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    use std::arch::aarch64::*;
    let mut x = 1usize;
    while x + 5 <= w {
        let r0l = vld1q_f32(row0.as_ptr().add(x - 1));
        let r0m = vld1q_f32(row0.as_ptr().add(x));
        let r0r = vld1q_f32(row0.as_ptr().add(x + 1));
        let r1l = vld1q_f32(row1.as_ptr().add(x - 1));
        let r1m = vld1q_f32(row1.as_ptr().add(x));
        let r1r = vld1q_f32(row1.as_ptr().add(x + 1));
        let r2l = vld1q_f32(row2.as_ptr().add(x - 1));
        let r2m = vld1q_f32(row2.as_ptr().add(x));
        let r2r = vld1q_f32(row2.as_ptr().add(x + 1));

        let m0 = vminq_f32(vminq_f32(r0l, r0m), r0r);
        let m1 = vminq_f32(vminq_f32(r1l, r1m), r1r);
        let m2 = vminq_f32(vminq_f32(r2l, r2m), r2r);
        let result = vminq_f32(vminq_f32(m0, m1), m2);

        vst1q_f32(out.as_mut_ptr().add(x), result);
        x += 4;
    }
    x
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn erode_sse_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mut x = 1usize;
    while x + 5 <= w {
        let r0l = _mm_loadu_ps(row0.as_ptr().add(x - 1));
        let r0m = _mm_loadu_ps(row0.as_ptr().add(x));
        let r0r = _mm_loadu_ps(row0.as_ptr().add(x + 1));
        let r1l = _mm_loadu_ps(row1.as_ptr().add(x - 1));
        let r1m = _mm_loadu_ps(row1.as_ptr().add(x));
        let r1r = _mm_loadu_ps(row1.as_ptr().add(x + 1));
        let r2l = _mm_loadu_ps(row2.as_ptr().add(x - 1));
        let r2m = _mm_loadu_ps(row2.as_ptr().add(x));
        let r2r = _mm_loadu_ps(row2.as_ptr().add(x + 1));

        let m0 = _mm_min_ps(_mm_min_ps(r0l, r0m), r0r);
        let m1 = _mm_min_ps(_mm_min_ps(r1l, r1m), r1r);
        let m2 = _mm_min_ps(_mm_min_ps(r2l, r2m), r2r);
        let result = _mm_min_ps(_mm_min_ps(m0, m1), m2);

        _mm_storeu_ps(out.as_mut_ptr().add(x), result);
        x += 4;
    }
    x
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn erode_avx_row_c1(
    row0: &[f32],
    row1: &[f32],
    row2: &[f32],
    out: &mut [f32],
    w: usize,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mut x = 1usize;
    while x + 9 <= w {
        let r0l = _mm256_loadu_ps(row0.as_ptr().add(x - 1));
        let r0m = _mm256_loadu_ps(row0.as_ptr().add(x));
        let r0r = _mm256_loadu_ps(row0.as_ptr().add(x + 1));
        let r1l = _mm256_loadu_ps(row1.as_ptr().add(x - 1));
        let r1m = _mm256_loadu_ps(row1.as_ptr().add(x));
        let r1r = _mm256_loadu_ps(row1.as_ptr().add(x + 1));
        let r2l = _mm256_loadu_ps(row2.as_ptr().add(x - 1));
        let r2m = _mm256_loadu_ps(row2.as_ptr().add(x));
        let r2r = _mm256_loadu_ps(row2.as_ptr().add(x + 1));

        let m0 = _mm256_min_ps(_mm256_min_ps(r0l, r0m), r0r);
        let m1 = _mm256_min_ps(_mm256_min_ps(r1l, r1m), r1r);
        let m2 = _mm256_min_ps(_mm256_min_ps(r2l, r2m), r2r);
        let result = _mm256_min_ps(_mm256_min_ps(m0, m1), m2);

        _mm256_storeu_ps(out.as_mut_ptr().add(x), result);
        x += 8;
    }
    x
}

/// Applies a 3x3 grayscale/RGB dilation (local maximum per channel).
///
/// Border handling uses only in-bounds neighbors.
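///
/// # Example
///
/// A minimal sketch (kept out of doctests; assumes only the `Tensor::from_vec`
/// and `Tensor::data` APIs already used in this module):
///
/// ```ignore
/// // 3x3 single-channel image with one bright pixel in the center.
/// let img = Tensor::from_vec(
///     vec![3, 3, 1],
///     vec![0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
/// )?;
/// let dilated = dilate_3x3(&img)?;
/// // Every pixel of a 3x3 image touches the center, so the local max is 1.0
/// // everywhere.
/// assert!(dilated.data().iter().all(|&v| v == 1.0));
/// ```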
#[allow(unsafe_code, clippy::uninit_vec)]
pub fn dilate_3x3(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let total = h * w * channels;
    // SAFETY: every element is written by interior + border paths below.
    let mut out = Vec::with_capacity(total);
    unsafe {
        out.set_len(total);
    }
    let data = input.data();
    let row_len = w * channels;
    let interior_h = h.saturating_sub(2);

    // Interior pixels (y=1..h-1, x=1..w-1): no bounds checks
    let compute_interior_row = |y: usize, row: &mut [f32]| {
        if !cfg!(miri) {
            // SIMD fast path for any channel count.
            // For HWC layout, flat row data is [R0 G0 B0 R1 G1 B1 ...].
            // Horizontal neighbors for pixel x, channel c are at flat offsets
            // (x-1)*C+c, x*C+c, (x+1)*C+c — which is offset -C, 0, +C in flat space.
            // SIMD loads at these offsets naturally compare same channels.
            let row0 = &data[(y - 1) * row_len..y * row_len];
            let row1 = &data[y * row_len..(y + 1) * row_len];
            let row2 = &data[(y + 1) * row_len..(y + 2) * row_len];

            let done = if channels == 1 {
                dilate_simd_row_c1(row0, row1, row2, row, w)
            } else {
                dilate_simd_row_mc(row0, row1, row2, row, row_len, channels)
            };

            // Scalar fallback for remaining interior elements
            if channels == 1 {
                for x in done..w.saturating_sub(1) {
                    if x == 0 {
                        continue;
                    }
                    let mut m = row0[x - 1];
                    m = m.max(row0[x]).max(row0[x + 1]);
                    m = m.max(row1[x - 1]).max(row1[x]).max(row1[x + 1]);
                    m = m.max(row2[x - 1]).max(row2[x]).max(row2[x + 1]);
                    row[x] = m;
                }
            } else {
                // Convert flat index back to pixel x, handle remaining pixels
                let start_x = done / channels;
                for x in start_x..w.saturating_sub(1) {
                    if x == 0 {
                        continue;
                    }
                    for c in 0..channels {
                        let i = x * channels + c;
                        let mut m = row0[i - channels];
                        m = m.max(row0[i]).max(row0[i + channels]);
                        m = m
                            .max(row1[i - channels])
                            .max(row1[i])
                            .max(row1[i + channels]);
                        m = m
                            .max(row2[i - channels])
                            .max(row2[i])
                            .max(row2[i + channels]);
                        row[i] = m;
                    }
                }
            }
            return;
        }

        if channels == 1 {
            for x in 1..w.saturating_sub(1) {
                let row0 = &data[(y - 1) * w..y * w];
                let row1 = &data[y * w..(y + 1) * w];
                let row2 = &data[(y + 1) * w..(y + 2) * w];
                let mut m = row0[x - 1];
                m = m.max(row0[x]).max(row0[x + 1]);
                m = m.max(row1[x - 1]).max(row1[x]).max(row1[x + 1]);
                m = m.max(row2[x - 1]).max(row2[x]).max(row2[x + 1]);
                row[x] = m;
            }
        } else {
            for x in 1..w.saturating_sub(1) {
                for c in 0..channels {
                    let r0 = ((y - 1) * w + x - 1) * channels + c;
                    let r1 = (y * w + x - 1) * channels + c;
                    let r2 = ((y + 1) * w + x - 1) * channels + c;
                    let mut max_value = data[r0];
                    max_value = max_value.max(data[r0 + channels]);
                    max_value = max_value.max(data[r0 + 2 * channels]);
                    max_value = max_value.max(data[r1]);
                    max_value = max_value.max(data[r1 + channels]);
                    max_value = max_value.max(data[r1 + 2 * channels]);
                    max_value = max_value.max(data[r2]);
                    max_value = max_value.max(data[r2 + channels]);
                    max_value = max_value.max(data[r2 + 2 * channels]);
                    row[x * channels + c] = max_value;
                }
            }
        }
    };

    if interior_h > 0 {
        let pixels = h * w;

        #[cfg(target_os = "macos")]
        let use_gcd = pixels > 4096 && !cfg!(miri);
        #[cfg(not(target_os = "macos"))]
        let use_gcd = false;

        if use_gcd {
            #[cfg(target_os = "macos")]
            {
                let out_ptr = out.as_mut_ptr() as usize;
                use super::u8ops::gcd;
                gcd::parallel_for(interior_h, |i| {
                    let y = i + 1;
                    // SAFETY: each row writes to a disjoint slice of out.
                    let row = unsafe {
                        std::slice::from_raw_parts_mut(
                            (out_ptr as *mut f32).add(y * row_len),
                            row_len,
                        )
                    };
                    compute_interior_row(y, row);
                });
            }
        } else if pixels > 4096 {
            let interior_out = &mut out[row_len..row_len + interior_h * row_len];
            interior_out
                .par_chunks_mut(row_len)
                .enumerate()
                .for_each(|(i, row)| compute_interior_row(i + 1, row));
        } else {
            let interior_out = &mut out[row_len..row_len + interior_h * row_len];
            interior_out
                .chunks_mut(row_len)
                .enumerate()
                .for_each(|(i, row)| compute_interior_row(i + 1, row));
        }
    }

    // Border pixels: bounds-checked path
    let border = border_coords_3x3(h, w);
    for (y, x) in border {
        for c in 0..channels {
            let mut max_value = f32::NEG_INFINITY;
            for ky in -1isize..=1 {
                for kx in -1isize..=1 {
                    let sy = y as isize + ky;
                    let sx = x as isize + kx;
                    if sy < 0 || sx < 0 || sy >= h as isize || sx >= w as isize {
                        continue;
                    }
                    let src = ((sy as usize) * w + sx as usize) * channels + c;
                    max_value = max_value.max(data[src]);
                }
            }
            let dst = (y * w + x) * channels + c;
            out[dst] = max_value;
        }
    }

    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
}

/// Applies a 3x3 grayscale/RGB erosion (local minimum per channel).
///
/// Border handling uses only in-bounds neighbors.
#[allow(unsafe_code, clippy::uninit_vec)]
pub fn erode_3x3(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let total = h * w * channels;
    // SAFETY: every element is written by interior + border paths below.
    let mut out = Vec::with_capacity(total);
    unsafe {
        out.set_len(total);
    }
    let data = input.data();
    let row_len = w * channels;
    let interior_h = h.saturating_sub(2);

    // Interior pixels (y=1..h-1, x=1..w-1): no bounds checks
    let compute_interior_row = |y: usize, row: &mut [f32]| {
        if !cfg!(miri) {
            let row0 = &data[(y - 1) * row_len..y * row_len];
            let row1 = &data[y * row_len..(y + 1) * row_len];
            let row2 = &data[(y + 1) * row_len..(y + 2) * row_len];

            let done = if channels == 1 {
                erode_simd_row_c1(row0, row1, row2, row, w)
            } else {
                erode_simd_row_mc(row0, row1, row2, row, row_len, channels)
            };

            if channels == 1 {
                for x in done..w.saturating_sub(1) {
                    if x == 0 {
                        continue;
                    }
                    let mut m = row0[x - 1];
                    m = m.min(row0[x]).min(row0[x + 1]);
                    m = m.min(row1[x - 1]).min(row1[x]).min(row1[x + 1]);
                    m = m.min(row2[x - 1]).min(row2[x]).min(row2[x + 1]);
                    row[x] = m;
                }
            } else {
                let start_x = done / channels;
                for x in start_x..w.saturating_sub(1) {
                    if x == 0 {
                        continue;
                    }
                    for c in 0..channels {
                        let i = x * channels + c;
                        let mut m = row0[i - channels];
                        m = m.min(row0[i]).min(row0[i + channels]);
                        m = m
                            .min(row1[i - channels])
                            .min(row1[i])
                            .min(row1[i + channels]);
                        m = m
                            .min(row2[i - channels])
                            .min(row2[i])
                            .min(row2[i + channels]);
                        row[i] = m;
                    }
                }
            }
            return;
        }

        if channels == 1 {
            for x in 1..w.saturating_sub(1) {
                let row0 = &data[(y - 1) * w..y * w];
                let row1 = &data[y * w..(y + 1) * w];
                let row2 = &data[(y + 1) * w..(y + 2) * w];
                let mut m = row0[x - 1];
                m = m.min(row0[x]).min(row0[x + 1]);
                m = m.min(row1[x - 1]).min(row1[x]).min(row1[x + 1]);
                m = m.min(row2[x - 1]).min(row2[x]).min(row2[x + 1]);
                row[x] = m;
            }
        } else {
            for x in 1..w.saturating_sub(1) {
                for c in 0..channels {
                    let r0 = ((y - 1) * w + x - 1) * channels + c;
                    let r1 = (y * w + x - 1) * channels + c;
                    let r2 = ((y + 1) * w + x - 1) * channels + c;
                    let mut min_value = data[r0];
                    min_value = min_value.min(data[r0 + channels]);
                    min_value = min_value.min(data[r0 + 2 * channels]);
                    min_value = min_value.min(data[r1]);
                    min_value = min_value.min(data[r1 + channels]);
                    min_value = min_value.min(data[r1 + 2 * channels]);
                    min_value = min_value.min(data[r2]);
                    min_value = min_value.min(data[r2 + channels]);
                    min_value = min_value.min(data[r2 + 2 * channels]);
                    row[x * channels + c] = min_value;
                }
            }
        }
    };

    if interior_h > 0 {
        let pixels = h * w;

        #[cfg(target_os = "macos")]
        let use_gcd = pixels > 4096 && !cfg!(miri);
        #[cfg(not(target_os = "macos"))]
        let use_gcd = false;

        if use_gcd {
            #[cfg(target_os = "macos")]
            {
                let out_ptr = out.as_mut_ptr() as usize;
                use super::u8ops::gcd;
                gcd::parallel_for(interior_h, |i| {
                    let y = i + 1;
                    // SAFETY: each row writes to a disjoint slice of out.
                    let row = unsafe {
                        std::slice::from_raw_parts_mut(
                            (out_ptr as *mut f32).add(y * row_len),
                            row_len,
                        )
                    };
                    compute_interior_row(y, row);
                });
            }
        } else if pixels > 4096 {
            let interior_out = &mut out[row_len..row_len + interior_h * row_len];
            interior_out
                .par_chunks_mut(row_len)
                .enumerate()
                .for_each(|(i, row)| compute_interior_row(i + 1, row));
        } else {
            let interior_out = &mut out[row_len..row_len + interior_h * row_len];
            interior_out
                .chunks_mut(row_len)
                .enumerate()
                .for_each(|(i, row)| compute_interior_row(i + 1, row));
        }
    }

    // Border pixels: bounds-checked path
    let border = border_coords_3x3(h, w);
    for (y, x) in border {
        for c in 0..channels {
            let mut min_value = f32::INFINITY;
            for ky in -1isize..=1 {
                for kx in -1isize..=1 {
                    let sy = y as isize + ky;
                    let sx = x as isize + kx;
                    if sy < 0 || sx < 0 || sy >= h as isize || sx >= w as isize {
                        continue;
                    }
                    let src = ((sy as usize) * w + sx as usize) * channels + c;
                    min_value = min_value.min(data[src]);
                }
            }
            let dst = (y * w + x) * channels + c;
            out[dst] = min_value;
        }
    }

    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
}

/// Applies a 3x3 opening (`erode` followed by `dilate`) per channel.
pub fn opening_3x3(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let eroded = erode_3x3(input)?;
    dilate_3x3(&eroded)
}

/// Applies a 3x3 closing (`dilate` followed by `erode`) per channel.
pub fn closing_3x3(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let dilated = dilate_3x3(input)?;
    erode_3x3(&dilated)
}

/// Applies a 3x3 morphological gradient (`dilate - erode`) per channel.
pub fn morph_gradient_3x3(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let dilated = dilate_3x3(input)?;
    let eroded = erode_3x3(input)?;

    let mut out = vec![0.0f32; input.len()];
    for (idx, value) in out.iter_mut().enumerate() {
        *value = dilated.data()[idx] - eroded.data()[idx];
    }

    let (h, w, channels) = hwc_shape(input)?;
    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
}

/// Dilate a single-channel `[H, W, 1]` image with an arbitrary structuring element.
///
/// `kernel` is `[kh, kw, 1]`; values greater than zero mark active elements
/// (non-positive entries are ignored).
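///
/// # Example
///
/// A minimal sketch (kept out of doctests; `img` stands for any `[H, W, 1]`
/// tensor):
///
/// ```ignore
/// // 3x3 cross-shaped structuring element: positive entries are active.
/// let kernel = Tensor::from_vec(
///     vec![3, 3, 1],
///     vec![0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0],
/// )?;
/// let dilated = dilate(&img, &kernel)?;
/// ```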
pub fn dilate(input: &Tensor, kernel: &Tensor) -> Result<Tensor, ImgProcError> {
    let (h, w, c) = hwc_shape(input)?;
    if c != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: c,
        });
    }
    let (kh, kw, kc) = hwc_shape(kernel)?;
    if kc != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: kc,
        });
    }
    let data = input.data();
    let kern = kernel.data();
    let rh = kh / 2;
    let rw = kw / 2;
    let mut out = vec![0.0f32; h * w];

    for y in 0..h {
        for x in 0..w {
            let mut max_val = f32::NEG_INFINITY;
            for ky in 0..kh {
                for kx in 0..kw {
                    if kern[ky * kw + kx] <= 0.0 {
                        continue;
                    }
                    let ny = y as i32 + ky as i32 - rh as i32;
                    let nx = x as i32 + kx as i32 - rw as i32;
                    if ny >= 0 && ny < h as i32 && nx >= 0 && nx < w as i32 {
                        max_val = max_val.max(data[ny as usize * w + nx as usize]);
                    }
                }
            }
            out[y * w + x] = if max_val.is_finite() { max_val } else { 0.0 };
        }
    }

    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
}

/// Top-hat transform: `input - opening(input)`.
///
/// Extracts bright features smaller than the 3x3 structuring element.
pub fn morph_tophat(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let opened = opening_3x3(input)?;
    let mut out = vec![0.0f32; input.len()];
    for (i, v) in out.iter_mut().enumerate() {
        *v = input.data()[i] - opened.data()[i];
    }
    let (h, w, c) = hwc_shape(input)?;
    Tensor::from_vec(vec![h, w, c], out).map_err(Into::into)
}

/// Black-hat transform: `closing(input) - input`.
///
/// Extracts dark features smaller than the 3x3 structuring element.
pub fn morph_blackhat(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let closed = closing_3x3(input)?;
    let mut out = vec![0.0f32; input.len()];
    for (i, v) in out.iter_mut().enumerate() {
        *v = closed.data()[i] - input.data()[i];
    }
    let (h, w, c) = hwc_shape(input)?;
    Tensor::from_vec(vec![h, w, c], out).map_err(Into::into)
}

/// Zhang-Suen thinning algorithm on a binary single-channel `[H, W, 1]` image.
///
/// Pixels > 0.5 are foreground. Iteratively removes boundary pixels to produce
/// a one-pixel-wide skeleton. Returns a binary image with 1.0 for skeleton pixels.
pub fn skeletonize(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let (h, w, c) = hwc_shape(input)?;
    if c != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: c,
        });
    }
    let data = input.data();
    let mut img: Vec<u8> = data
        .iter()
        .map(|&v| if v > 0.5 { 1u8 } else { 0u8 })
        .collect();

    loop {
        let mut changed = false;

        // Sub-iteration 1
        let mut markers = vec![false; h * w];
        for y in 1..h.saturating_sub(1) {
            for x in 1..w.saturating_sub(1) {
                if img[y * w + x] == 0 {
                    continue;
                }
                let p = zhang_suen_neighbors(&img, w, x, y);
                let b = p.iter().map(|&v| v as u32).sum::<u32>();
                let a = zhang_suen_transitions(&p);
                if (2..=6).contains(&b)
                    && a == 1
                    && p[0] * p[2] * p[4] == 0
                    && p[2] * p[4] * p[6] == 0
                {
                    markers[y * w + x] = true;
                }
            }
        }
        for i in 0..h * w {
            if markers[i] {
                img[i] = 0;
                changed = true;
            }
        }

        // Sub-iteration 2
        markers.fill(false);
        for y in 1..h.saturating_sub(1) {
            for x in 1..w.saturating_sub(1) {
                if img[y * w + x] == 0 {
                    continue;
                }
                let p = zhang_suen_neighbors(&img, w, x, y);
                let b = p.iter().map(|&v| v as u32).sum::<u32>();
                let a = zhang_suen_transitions(&p);
                if (2..=6).contains(&b)
                    && a == 1
                    && p[0] * p[2] * p[6] == 0
                    && p[0] * p[4] * p[6] == 0
                {
                    markers[y * w + x] = true;
                }
            }
        }
        for i in 0..h * w {
            if markers[i] {
                img[i] = 0;
                changed = true;
            }
        }

        if !changed {
            break;
        }
    }

    let out: Vec<f32> = img.iter().map(|&v| v as f32).collect();
    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
}

/// Returns the 8 neighbors P2..P9 in Zhang-Suen order.
/// P2=N, P3=NE, P4=E, P5=SE, P6=S, P7=SW, P8=W, P9=NW
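///
/// ```text
/// P9 P2 P3
/// P8 P1 P4      (P1 is the pixel being tested)
/// P7 P6 P5
/// ```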
fn zhang_suen_neighbors(img: &[u8], w: usize, x: usize, y: usize) -> [u8; 8] {
    [
        img[(y - 1) * w + x],     // P2 (N)
        img[(y - 1) * w + x + 1], // P3 (NE)
        img[y * w + x + 1],       // P4 (E)
        img[(y + 1) * w + x + 1], // P5 (SE)
        img[(y + 1) * w + x],     // P6 (S)
        img[(y + 1) * w + x - 1], // P7 (SW)
        img[y * w + x - 1],       // P8 (W)
        img[(y - 1) * w + x - 1], // P9 (NW)
    ]
}

/// Number of 0->1 transitions in the circular sequence P2..P9..P2.
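///
/// For example, `p = [0, 0, 1, 1, 0, 0, 1, 0]` contains two 0->1 transitions
/// (at indices 1->2 and 5->6), so this function returns 2.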
fn zhang_suen_transitions(p: &[u8; 8]) -> u32 {
    let mut count = 0u32;
    for i in 0..8 {
        if p[i] == 0 && p[(i + 1) % 8] == 1 {
            count += 1;
        }
    }
    count
}

/// Removes connected components with area less than `min_size`.
///
/// Takes a single-channel `[H, W, 1]` binary image. Pixels > 0.5 are foreground.
/// Connected components (4-connected) with fewer than `min_size` foreground pixels
/// are set to 0.0.
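///
/// # Example
///
/// A minimal sketch (kept out of doctests; `mask` stands for any `[H, W, 1]`
/// binary tensor):
///
/// ```ignore
/// // Drop connected specks smaller than 5 pixels.
/// let cleaned = remove_small_objects(&mask, 5)?;
/// ```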
pub fn remove_small_objects(input: &Tensor, min_size: usize) -> Result<Tensor, ImgProcError> {
    let (h, w, c) = hwc_shape(input)?;
    if c != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: c,
        });
    }
    let data = input.data();
    let mut labels = vec![0u32; h * w];
    let mut label_id = 0u32;
    let mut label_sizes: Vec<usize> = Vec::new();

    // Flood-fill connected-component labeling (4-connected); the worklist is
    // popped from the back, so traversal is depth-first.
    for y in 0..h {
        for x in 0..w {
            let idx = y * w + x;
            if data[idx] <= 0.5 || labels[idx] != 0 {
                continue;
            }
            label_id += 1;
            let mut queue = vec![(x, y)];
            labels[idx] = label_id;
            let mut size = 0usize;
            while let Some((cx, cy)) = queue.pop() {
                size += 1;
                for &(dx, dy) in &[(0isize, -1isize), (0, 1), (-1, 0), (1, 0)] {
                    let nx = cx as isize + dx;
                    let ny = cy as isize + dy;
                    if nx >= 0 && nx < w as isize && ny >= 0 && ny < h as isize {
                        let nidx = ny as usize * w + nx as usize;
                        if data[nidx] > 0.5 && labels[nidx] == 0 {
                            labels[nidx] = label_id;
                            queue.push((nx as usize, ny as usize));
                        }
                    }
                }
            }
            label_sizes.push(size);
        }
    }

    let mut out: Vec<f32> = data.to_vec();
    for i in 0..h * w {
        if labels[i] > 0 && label_sizes[(labels[i] - 1) as usize] < min_size {
            out[i] = 0.0;
        }
    }

    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
}

/// Erode a single-channel `[H, W, 1]` image with an arbitrary structuring element.
///
/// As with [`dilate`], `kernel` is `[kh, kw, 1]`; values greater than zero mark
/// active elements.
pub fn erode(input: &Tensor, kernel: &Tensor) -> Result<Tensor, ImgProcError> {
    let (h, w, c) = hwc_shape(input)?;
    if c != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: c,
        });
    }
    let (kh, kw, kc) = hwc_shape(kernel)?;
    if kc != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: kc,
        });
    }
    let data = input.data();
    let kern = kernel.data();
    let rh = kh / 2;
    let rw = kw / 2;
    let mut out = vec![0.0f32; h * w];

    for y in 0..h {
        for x in 0..w {
            let mut min_val = f32::INFINITY;
            for ky in 0..kh {
                for kx in 0..kw {
                    if kern[ky * kw + kx] <= 0.0 {
                        continue;
                    }
                    let ny = y as i32 + ky as i32 - rh as i32;
                    let nx = x as i32 + kx as i32 - rw as i32;
                    if ny >= 0 && ny < h as i32 && nx >= 0 && nx < w as i32 {
                        min_val = min_val.min(data[ny as usize * w + nx as usize]);
                    }
                }
            }
            out[y * w + x] = if min_val.is_finite() { min_val } else { 0.0 };
        }
    }

    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
}
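
#[cfg(test)]
mod tests {
    //! Illustrative smoke tests sketching how these ops are called. They assume
    //! only the `Tensor::from_vec(shape, data)` / `Tensor::data()` API already
    //! used in this module; the helper and test names are new to this sketch.
    use super::*;

    /// Builds a 3x3 single-channel tensor from row-major values.
    fn tensor_3x3(values: [f32; 9]) -> Tensor {
        Tensor::from_vec(vec![3, 3, 1], values.to_vec()).unwrap()
    }

    #[test]
    fn dilate_spreads_single_bright_pixel() {
        let img = tensor_3x3([0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]);
        let out = dilate_3x3(&img).unwrap();
        // Every pixel of a 3x3 image is 8-adjacent to the center, so the
        // 3x3 local max is 1.0 everywhere.
        assert!(out.data().iter().all(|&v| v == 1.0));
    }

    #[test]
    fn erode_clears_single_bright_pixel() {
        let img = tensor_3x3([0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]);
        let out = erode_3x3(&img).unwrap();
        // The 3x3 local min around an isolated bright pixel is 0.0 everywhere.
        assert!(out.data().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn remove_small_objects_drops_specks() {
        let img = tensor_3x3([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0]);
        let out = remove_small_objects(&img, 2).unwrap();
        // The lone corner pixel (area 1) is removed; the 2-pixel component
        // in the bottom row survives.
        assert_eq!(out.data()[0], 0.0);
        assert_eq!(out.data()[7], 1.0);
        assert_eq!(out.data()[8], 1.0);
    }
}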