//! Threshold operations (`yscv_imgproc/ops/threshold.rs`).

1use rayon::prelude::*;
2use yscv_tensor::{AlignedVec, Tensor};
3
4use super::super::ImgProcError;
5use super::super::shape::hwc_shape;
6
7use super::geometry::sobel_3x3_gradients;
8
// Element-count cutoff for parallel dispatch: buffers with fewer elements are
// processed single-threaded; at or above it, rows are dispatched in parallel
// (Grand Central Dispatch on macOS, rayon elsewhere).
const RAYON_THRESHOLD: usize = 4096;
10
/// Binary threshold: outputs `max_val` where value > threshold, else `0.0`.
///
/// Applied element-wise across the whole HWC buffer (all channels alike).
///
/// # Errors
/// Fails if `input` is not HWC-shaped or the output tensor cannot be built.
#[allow(unsafe_code, clippy::uninit_vec)]
pub fn threshold_binary(
    input: &Tensor,
    threshold: f32,
    max_val: f32,
) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let len = data.len();
    // SAFETY: every element is written by threshold_binary_simd_slice + scalar tail below.
    let mut out = AlignedVec::<f32>::uninitialized(len);

    let row_len = w * channels;

    // macOS fast path: row-parallel dispatch via Grand Central Dispatch.
    // Pointers are passed as usize so the closure is Send; each row index `y`
    // addresses a disjoint `row_len` slice, so rows never alias.
    #[cfg(target_os = "macos")]
    if len >= RAYON_THRESHOLD && !cfg!(miri) {
        let src_ptr = data.as_ptr() as usize;
        let dst_ptr = out.as_mut_ptr() as usize;
        use super::u8ops::gcd;
        gcd::parallel_for(h, |y| {
            // SAFETY: y < h, so [y*row_len, (y+1)*row_len) is in bounds for
            // both buffers and disjoint across iterations.
            let src = unsafe {
                std::slice::from_raw_parts((src_ptr as *const f32).add(y * row_len), row_len)
            };
            let dst = unsafe {
                std::slice::from_raw_parts_mut((dst_ptr as *mut f32).add(y * row_len), row_len)
            };
            threshold_binary_simd_slice(src, dst, threshold, max_val);
        });
        return Tensor::from_aligned(vec![h, w, channels], out).map_err(Into::into);
    }

    // Portable path: rayon row-parallelism for large buffers; below the
    // threshold, parallel overhead outweighs the win, so run single-threaded.
    if len >= RAYON_THRESHOLD {
        out.par_chunks_mut(row_len)
            .enumerate()
            .for_each(|(y, dst)| {
                let src = &data[y * row_len..(y + 1) * row_len];
                threshold_binary_simd_slice(src, dst, threshold, max_val);
            });
    } else {
        threshold_binary_simd_slice(data, &mut out, threshold, max_val);
    }
    Tensor::from_aligned(vec![h, w, channels], out).map_err(Into::into)
}
55
/// Inverse binary threshold: outputs `0.0` where value > threshold, else `max_val`.
///
/// Applied element-wise across the whole HWC buffer (all channels alike).
///
/// # Errors
/// Fails if `input` is not HWC-shaped or the output tensor cannot be built.
#[allow(unsafe_code, clippy::uninit_vec)]
pub fn threshold_binary_inv(
    input: &Tensor,
    threshold: f32,
    max_val: f32,
) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let len = data.len();
    // SAFETY: every element is written by threshold_binary_inv_simd_slice + scalar tail below.
    let mut out = AlignedVec::<f32>::uninitialized(len);

    let row_len = w * channels;

    // macOS fast path: row-parallel dispatch via Grand Central Dispatch.
    // Pointers are passed as usize so the closure is Send; each row index `y`
    // addresses a disjoint `row_len` slice, so rows never alias.
    #[cfg(target_os = "macos")]
    if len >= RAYON_THRESHOLD && !cfg!(miri) {
        let src_ptr = data.as_ptr() as usize;
        let dst_ptr = out.as_mut_ptr() as usize;
        use super::u8ops::gcd;
        gcd::parallel_for(h, |y| {
            // SAFETY: y < h, so [y*row_len, (y+1)*row_len) is in bounds for
            // both buffers and disjoint across iterations.
            let src = unsafe {
                std::slice::from_raw_parts((src_ptr as *const f32).add(y * row_len), row_len)
            };
            let dst = unsafe {
                std::slice::from_raw_parts_mut((dst_ptr as *mut f32).add(y * row_len), row_len)
            };
            threshold_binary_inv_simd_slice(src, dst, threshold, max_val);
        });
        return Tensor::from_aligned(vec![h, w, channels], out).map_err(Into::into);
    }

    // Portable path: rayon rows for large buffers, single-threaded otherwise.
    if len >= RAYON_THRESHOLD {
        out.par_chunks_mut(row_len)
            .enumerate()
            .for_each(|(y, dst)| {
                let src = &data[y * row_len..(y + 1) * row_len];
                threshold_binary_inv_simd_slice(src, dst, threshold, max_val);
            });
    } else {
        threshold_binary_inv_simd_slice(data, &mut out, threshold, max_val);
    }
    Tensor::from_aligned(vec![h, w, channels], out).map_err(Into::into)
}
100
/// Truncate threshold: caps values above `threshold`.
///
/// Values at or below `threshold` pass through unchanged; larger values are
/// replaced by `threshold` itself. Applied element-wise over the HWC buffer.
///
/// # Errors
/// Fails if `input` is not HWC-shaped or the output tensor cannot be built.
#[allow(unsafe_code, clippy::uninit_vec)]
pub fn threshold_truncate(input: &Tensor, threshold: f32) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let len = data.len();
    // SAFETY: every element is written by threshold_truncate_simd_slice + scalar tail below.
    let mut out = AlignedVec::<f32>::uninitialized(len);

    let row_len = w * channels;

    // macOS fast path: row-parallel dispatch via Grand Central Dispatch.
    // Pointers are passed as usize so the closure is Send; each row index `y`
    // addresses a disjoint `row_len` slice, so rows never alias.
    #[cfg(target_os = "macos")]
    if len >= RAYON_THRESHOLD && !cfg!(miri) {
        let src_ptr = data.as_ptr() as usize;
        let dst_ptr = out.as_mut_ptr() as usize;
        use super::u8ops::gcd;
        gcd::parallel_for(h, |y| {
            // SAFETY: y < h, so [y*row_len, (y+1)*row_len) is in bounds for
            // both buffers and disjoint across iterations.
            let src = unsafe {
                std::slice::from_raw_parts((src_ptr as *const f32).add(y * row_len), row_len)
            };
            let dst = unsafe {
                std::slice::from_raw_parts_mut((dst_ptr as *mut f32).add(y * row_len), row_len)
            };
            threshold_truncate_simd_slice(src, dst, threshold);
        });
        return Tensor::from_aligned(vec![h, w, channels], out).map_err(Into::into);
    }

    // Portable path: rayon rows for large buffers, single-threaded otherwise.
    if len >= RAYON_THRESHOLD {
        out.par_chunks_mut(row_len)
            .enumerate()
            .for_each(|(y, dst)| {
                let src = &data[y * row_len..(y + 1) * row_len];
                threshold_truncate_simd_slice(src, dst, threshold);
            });
    } else {
        threshold_truncate_simd_slice(data, &mut out, threshold);
    }
    Tensor::from_aligned(vec![h, w, channels], out).map_err(Into::into)
}
141
/// SIMD-accelerated binary threshold for an entire slice.
///
/// Dispatches at runtime to the best available vector kernel (NEON on
/// aarch64; AVX, else SSE, on x86/x86_64). Each kernel returns how many
/// leading elements it processed; the scalar loop finishes the remainder.
/// Under Miri all SIMD paths are skipped and the scalar loop does everything.
#[allow(unsafe_code)]
#[inline(always)]
fn threshold_binary_simd_slice(src: &[f32], dst: &mut [f32], threshold: f32, max_val: f32) {
    debug_assert_eq!(src.len(), dst.len());
    let len = src.len();
    let mut i = 0usize;

    if !cfg!(miri) {
        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon") {
                // SAFETY: feature detected; pointers valid for len elements.
                i = unsafe {
                    threshold_binary_neon(src.as_ptr(), dst.as_mut_ptr(), len, threshold, max_val)
                };
            }
        }
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if std::is_x86_feature_detected!("avx") {
                // SAFETY: feature detected; pointers valid for len elements.
                i = unsafe {
                    threshold_binary_avx(src.as_ptr(), dst.as_mut_ptr(), len, threshold, max_val)
                };
            } else if std::is_x86_feature_detected!("sse") {
                // SAFETY: feature detected; pointers valid for len elements.
                i = unsafe {
                    threshold_binary_sse(src.as_ptr(), dst.as_mut_ptr(), len, threshold, max_val)
                };
            }
        }
    }

    // Scalar tail: covers whatever the vector kernel (if any) left over.
    while i < len {
        dst[i] = if src[i] > threshold { max_val } else { 0.0 };
        i += 1;
    }
}
182
/// NEON kernel for the binary threshold.
///
/// Processes 32 floats (8 vectors of 4) per unrolled iteration, then 4 at a
/// time. Returns the number of elements written — a multiple of 4, possibly
/// less than `len`; the caller handles the scalar tail.
///
/// # Safety
/// `src` must be valid for `len` reads and `dst` for `len` writes, and the
/// NEON target feature must be available at runtime.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn threshold_binary_neon(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
    max_val: f32,
) -> usize {
    use std::arch::aarch64::*;
    let thresh_v = vdupq_n_f32(threshold);
    let max_v = vdupq_n_f32(max_val);
    let zero_v = vdupq_n_f32(0.0);
    let mut x = 0usize;
    // Process 32 floats (8×4) per iteration for better throughput.
    // vbslq_f32(mask, max_v, zero_v) selects max_val where v > threshold,
    // 0.0 elsewhere.
    while x + 32 <= len {
        let v0 = vld1q_f32(src.add(x));
        let v1 = vld1q_f32(src.add(x + 4));
        let v2 = vld1q_f32(src.add(x + 8));
        let v3 = vld1q_f32(src.add(x + 12));
        let v4 = vld1q_f32(src.add(x + 16));
        let v5 = vld1q_f32(src.add(x + 20));
        let v6 = vld1q_f32(src.add(x + 24));
        let v7 = vld1q_f32(src.add(x + 28));
        vst1q_f32(
            dst.add(x),
            vbslq_f32(vcgtq_f32(v0, thresh_v), max_v, zero_v),
        );
        vst1q_f32(
            dst.add(x + 4),
            vbslq_f32(vcgtq_f32(v1, thresh_v), max_v, zero_v),
        );
        vst1q_f32(
            dst.add(x + 8),
            vbslq_f32(vcgtq_f32(v2, thresh_v), max_v, zero_v),
        );
        vst1q_f32(
            dst.add(x + 12),
            vbslq_f32(vcgtq_f32(v3, thresh_v), max_v, zero_v),
        );
        vst1q_f32(
            dst.add(x + 16),
            vbslq_f32(vcgtq_f32(v4, thresh_v), max_v, zero_v),
        );
        vst1q_f32(
            dst.add(x + 20),
            vbslq_f32(vcgtq_f32(v5, thresh_v), max_v, zero_v),
        );
        vst1q_f32(
            dst.add(x + 24),
            vbslq_f32(vcgtq_f32(v6, thresh_v), max_v, zero_v),
        );
        vst1q_f32(
            dst.add(x + 28),
            vbslq_f32(vcgtq_f32(v7, thresh_v), max_v, zero_v),
        );
        x += 32;
    }
    // Single-vector cleanup for the 4..31 leftover elements.
    while x + 4 <= len {
        let v = vld1q_f32(src.add(x));
        let mask = vcgtq_f32(v, thresh_v);
        let result = vbslq_f32(mask, max_v, zero_v);
        vst1q_f32(dst.add(x), result);
        x += 4;
    }
    x
}
251
/// AVX kernel for the binary threshold.
///
/// Processes 32 floats (4 vectors of 8) per unrolled iteration, then 8 at a
/// time. Returns the number of elements written — a multiple of 8, possibly
/// less than `len`; the caller handles the scalar tail.
///
/// # Safety
/// `src` must be valid for `len` reads and `dst` for `len` writes, and the
/// AVX target feature must be available at runtime.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn threshold_binary_avx(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
    max_val: f32,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let thresh_v = _mm256_set1_ps(threshold);
    let max_v = _mm256_set1_ps(max_val);
    let mut x = 0usize;
    // 4× unrolled: 32 elements per iteration.
    // Predicate 14 = _CMP_GT_OS (greater-than, ordered, signaling); the
    // compare yields all-ones/all-zeros lanes, so ANDing with max_v produces
    // max_val where v > threshold and 0.0 elsewhere.
    while x + 32 <= len {
        let v0 = _mm256_loadu_ps(src.add(x));
        let v1 = _mm256_loadu_ps(src.add(x + 8));
        let v2 = _mm256_loadu_ps(src.add(x + 16));
        let v3 = _mm256_loadu_ps(src.add(x + 24));
        _mm256_storeu_ps(
            dst.add(x),
            _mm256_and_ps(_mm256_cmp_ps::<14>(v0, thresh_v), max_v),
        );
        _mm256_storeu_ps(
            dst.add(x + 8),
            _mm256_and_ps(_mm256_cmp_ps::<14>(v1, thresh_v), max_v),
        );
        _mm256_storeu_ps(
            dst.add(x + 16),
            _mm256_and_ps(_mm256_cmp_ps::<14>(v2, thresh_v), max_v),
        );
        _mm256_storeu_ps(
            dst.add(x + 24),
            _mm256_and_ps(_mm256_cmp_ps::<14>(v3, thresh_v), max_v),
        );
        x += 32;
    }
    // Single-vector cleanup for the 8..31 leftover elements.
    while x + 8 <= len {
        _mm256_storeu_ps(
            dst.add(x),
            _mm256_and_ps(
                _mm256_cmp_ps::<14>(_mm256_loadu_ps(src.add(x)), thresh_v),
                max_v,
            ),
        );
        x += 8;
    }
    x
}
305
/// SSE kernel for the binary threshold (fallback when AVX is unavailable).
///
/// Processes 4 floats per iteration; returns the number of elements written
/// (a multiple of 4). The caller handles the scalar tail.
///
/// # Safety
/// `src` must be valid for `len` reads and `dst` for `len` writes, and the
/// SSE target feature must be available at runtime.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn threshold_binary_sse(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
    max_val: f32,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let thresh_v = _mm_set1_ps(threshold);
    let max_v = _mm_set1_ps(max_val);
    let mut x = 0usize;
    while x + 4 <= len {
        let v = _mm_loadu_ps(src.add(x));
        // All-ones/all-zeros mask ANDed with max_v -> max_val or 0.0 per lane.
        let mask = _mm_cmpgt_ps(v, thresh_v);
        let result = _mm_and_ps(mask, max_v);
        _mm_storeu_ps(dst.add(x), result);
        x += 4;
    }
    x
}
332
/// SIMD-accelerated inverse binary threshold for an entire slice.
///
/// Dispatches at runtime to the best available vector kernel (NEON on
/// aarch64; AVX, else SSE, on x86/x86_64). Each kernel returns how many
/// leading elements it processed; the scalar loop finishes the remainder.
/// Under Miri all SIMD paths are skipped and the scalar loop does everything.
#[allow(unsafe_code)]
#[inline(always)]
fn threshold_binary_inv_simd_slice(src: &[f32], dst: &mut [f32], threshold: f32, max_val: f32) {
    debug_assert_eq!(src.len(), dst.len());
    let len = src.len();
    let mut i = 0usize;

    if !cfg!(miri) {
        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon") {
                // SAFETY: feature detected; pointers valid for len elements.
                i = unsafe {
                    threshold_binary_inv_neon(
                        src.as_ptr(),
                        dst.as_mut_ptr(),
                        len,
                        threshold,
                        max_val,
                    )
                };
            }
        }
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if std::is_x86_feature_detected!("avx") {
                // SAFETY: feature detected; pointers valid for len elements.
                i = unsafe {
                    threshold_binary_inv_avx(
                        src.as_ptr(),
                        dst.as_mut_ptr(),
                        len,
                        threshold,
                        max_val,
                    )
                };
            } else if std::is_x86_feature_detected!("sse") {
                // SAFETY: feature detected; pointers valid for len elements.
                i = unsafe {
                    threshold_binary_inv_sse(
                        src.as_ptr(),
                        dst.as_mut_ptr(),
                        len,
                        threshold,
                        max_val,
                    )
                };
            }
        }
    }

    // Scalar tail: covers whatever the vector kernel (if any) left over.
    while i < len {
        dst[i] = if src[i] > threshold { 0.0 } else { max_val };
        i += 1;
    }
}
391
/// NEON kernel for the inverse binary threshold.
///
/// Processes 16 floats (4 vectors of 4) per unrolled iteration, then 4 at a
/// time. Returns the number of elements written — a multiple of 4, possibly
/// less than `len`; the caller handles the scalar tail.
///
/// # Safety
/// `src` must be valid for `len` reads and `dst` for `len` writes, and the
/// NEON target feature must be available at runtime.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn threshold_binary_inv_neon(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
    max_val: f32,
) -> usize {
    use std::arch::aarch64::*;
    let thresh_v = vdupq_n_f32(threshold);
    let max_v = vdupq_n_f32(max_val);
    let zero_v = vdupq_n_f32(0.0);
    let mut x = 0usize;
    // Process 16 floats (4x4) per iteration for better throughput.
    // Select operands are swapped vs. the non-inverse kernel:
    // vbslq_f32(mask, zero_v, max_v) yields 0.0 where v > threshold.
    while x + 16 <= len {
        let v0 = vld1q_f32(src.add(x));
        let v1 = vld1q_f32(src.add(x + 4));
        let v2 = vld1q_f32(src.add(x + 8));
        let v3 = vld1q_f32(src.add(x + 12));
        vst1q_f32(
            dst.add(x),
            vbslq_f32(vcgtq_f32(v0, thresh_v), zero_v, max_v),
        );
        vst1q_f32(
            dst.add(x + 4),
            vbslq_f32(vcgtq_f32(v1, thresh_v), zero_v, max_v),
        );
        vst1q_f32(
            dst.add(x + 8),
            vbslq_f32(vcgtq_f32(v2, thresh_v), zero_v, max_v),
        );
        vst1q_f32(
            dst.add(x + 12),
            vbslq_f32(vcgtq_f32(v3, thresh_v), zero_v, max_v),
        );
        x += 16;
    }
    // Single-vector cleanup for the 4..15 leftover elements.
    while x + 4 <= len {
        let v = vld1q_f32(src.add(x));
        let mask = vcgtq_f32(v, thresh_v);
        let result = vbslq_f32(mask, zero_v, max_v);
        vst1q_f32(dst.add(x), result);
        x += 4;
    }
    x
}
440
/// AVX kernel for the inverse binary threshold.
///
/// Processes 8 floats per iteration; returns the number of elements written
/// (a multiple of 8). The caller handles the scalar tail.
///
/// # Safety
/// `src` must be valid for `len` reads and `dst` for `len` writes, and the
/// AVX target feature must be available at runtime.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn threshold_binary_inv_avx(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
    max_val: f32,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let thresh_v = _mm256_set1_ps(threshold);
    let max_v = _mm256_set1_ps(max_val);
    let mut x = 0usize;
    while x + 8 <= len {
        let v = _mm256_loadu_ps(src.add(x));
        // Predicate 14 = _CMP_GT_OS: greater-than, ordered, signaling.
        // (Not _CMP_GT_OQ, which is 30 — the comparison semantics are the
        // same "greater-than"; only the NaN-exception behavior differs.)
        let mask = _mm256_cmp_ps::<14>(v, thresh_v);
        // ANDNOT inverts the selection: max_val where v <= threshold, else 0.0.
        let result = _mm256_andnot_ps(mask, max_v);
        _mm256_storeu_ps(dst.add(x), result);
        x += 8;
    }
    x
}
468
/// SSE kernel for the inverse binary threshold (fallback when AVX is
/// unavailable).
///
/// Processes 4 floats per iteration; returns the number of elements written
/// (a multiple of 4). The caller handles the scalar tail.
///
/// # Safety
/// `src` must be valid for `len` reads and `dst` for `len` writes, and the
/// SSE target feature must be available at runtime.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn threshold_binary_inv_sse(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
    max_val: f32,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let thresh_v = _mm_set1_ps(threshold);
    let max_v = _mm_set1_ps(max_val);
    let mut x = 0usize;
    while x + 4 <= len {
        let v = _mm_loadu_ps(src.add(x));
        let mask = _mm_cmpgt_ps(v, thresh_v);
        // ANDNOT inverts the selection: max_val where v <= threshold, else 0.0.
        let result = _mm_andnot_ps(mask, max_v);
        _mm_storeu_ps(dst.add(x), result);
        x += 4;
    }
    x
}
495
/// SIMD-accelerated truncate threshold for an entire slice.
///
/// Dispatches at runtime to the best available vector kernel (NEON on
/// aarch64; AVX, else SSE, on x86/x86_64). Each kernel returns how many
/// leading elements it processed; the scalar loop finishes the remainder.
/// Under Miri all SIMD paths are skipped and the scalar loop does everything.
#[allow(unsafe_code)]
#[inline(always)]
fn threshold_truncate_simd_slice(src: &[f32], dst: &mut [f32], threshold: f32) {
    debug_assert_eq!(src.len(), dst.len());
    let len = src.len();
    let mut i = 0usize;

    if !cfg!(miri) {
        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon") {
                // SAFETY: feature detected; pointers valid for len elements.
                i = unsafe {
                    threshold_truncate_neon(src.as_ptr(), dst.as_mut_ptr(), len, threshold)
                };
            }
        }
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if std::is_x86_feature_detected!("avx") {
                // SAFETY: feature detected; pointers valid for len elements.
                i = unsafe {
                    threshold_truncate_avx(src.as_ptr(), dst.as_mut_ptr(), len, threshold)
                };
            } else if std::is_x86_feature_detected!("sse") {
                // SAFETY: feature detected; pointers valid for len elements.
                i = unsafe {
                    threshold_truncate_sse(src.as_ptr(), dst.as_mut_ptr(), len, threshold)
                };
            }
        }
    }

    // Scalar tail: covers whatever the vector kernel (if any) left over.
    while i < len {
        dst[i] = if src[i] > threshold {
            threshold
        } else {
            src[i]
        };
        i += 1;
    }
}
540
/// NEON kernel for the truncate threshold (vector `min` against the cap).
///
/// Processes 16 floats (4 vectors of 4) per unrolled iteration, then 4 at a
/// time. Returns the number of elements written (a multiple of 4); the caller
/// handles the scalar tail.
///
/// NOTE(review): `vminq_f32` and the scalar tail may disagree on NaN inputs —
/// confirm NaN handling is not relied upon by callers.
///
/// # Safety
/// `src` must be valid for `len` reads and `dst` for `len` writes, and the
/// NEON target feature must be available at runtime.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn threshold_truncate_neon(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
) -> usize {
    use std::arch::aarch64::*;
    let thresh_v = vdupq_n_f32(threshold);
    let mut x = 0usize;
    // Process 16 floats (4x4) per iteration for better throughput
    while x + 16 <= len {
        let v0 = vld1q_f32(src.add(x));
        let v1 = vld1q_f32(src.add(x + 4));
        let v2 = vld1q_f32(src.add(x + 8));
        let v3 = vld1q_f32(src.add(x + 12));
        vst1q_f32(dst.add(x), vminq_f32(v0, thresh_v));
        vst1q_f32(dst.add(x + 4), vminq_f32(v1, thresh_v));
        vst1q_f32(dst.add(x + 8), vminq_f32(v2, thresh_v));
        vst1q_f32(dst.add(x + 12), vminq_f32(v3, thresh_v));
        x += 16;
    }
    // Single-vector cleanup for the 4..15 leftover elements.
    while x + 4 <= len {
        let v = vld1q_f32(src.add(x));
        vst1q_f32(dst.add(x), vminq_f32(v, thresh_v));
        x += 4;
    }
    x
}
572
/// AVX kernel for the truncate threshold (vector `min` against the cap).
///
/// Processes 8 floats per iteration; returns the number of elements written
/// (a multiple of 8). The caller handles the scalar tail.
///
/// NOTE(review): `_mm256_min_ps` and the scalar tail may disagree on NaN
/// inputs — confirm NaN handling is not relied upon by callers.
///
/// # Safety
/// `src` must be valid for `len` reads and `dst` for `len` writes, and the
/// AVX target feature must be available at runtime.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn threshold_truncate_avx(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let thresh_v = _mm256_set1_ps(threshold);
    let mut x = 0usize;
    while x + 8 <= len {
        let v = _mm256_loadu_ps(src.add(x));
        let result = _mm256_min_ps(v, thresh_v);
        _mm256_storeu_ps(dst.add(x), result);
        x += 8;
    }
    x
}
596
/// SSE kernel for the truncate threshold (fallback when AVX is unavailable).
///
/// Processes 4 floats per iteration; returns the number of elements written
/// (a multiple of 4). The caller handles the scalar tail.
///
/// # Safety
/// `src` must be valid for `len` reads and `dst` for `len` writes, and the
/// SSE target feature must be available at runtime.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn threshold_truncate_sse(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let thresh_v = _mm_set1_ps(threshold);
    let mut x = 0usize;
    while x + 4 <= len {
        let v = _mm_loadu_ps(src.add(x));
        let result = _mm_min_ps(v, thresh_v);
        _mm_storeu_ps(dst.add(x), result);
        x += 4;
    }
    x
}
620
621/// Otsu threshold for single-channel `[H, W, 1]` images.
622/// Returns `(threshold, thresholded_image)`.
623pub fn threshold_otsu(input: &Tensor, max_val: f32) -> Result<(f32, Tensor), ImgProcError> {
624    let (_h, _w, channels) = hwc_shape(input)?;
625    if channels != 1 {
626        return Err(ImgProcError::InvalidChannelCount {
627            expected: 1,
628            got: channels,
629        });
630    }
631
632    // Build 256-bin histogram over [0, 1] range
633    let data = input.data();
634    let total = data.len() as f32;
635    let mut hist = [0u32; 256];
636    for &v in data {
637        let bin = (v.clamp(0.0, 1.0) * 255.0) as usize;
638        hist[bin.min(255)] += 1;
639    }
640
641    let mut sum_total = 0.0f64;
642    for (i, &count) in hist.iter().enumerate() {
643        sum_total += i as f64 * count as f64;
644    }
645
646    let mut sum_bg = 0.0f64;
647    let mut weight_bg = 0.0f64;
648    let mut max_variance = 0.0f64;
649    let mut best_t = 0usize;
650
651    for (t, &count) in hist.iter().enumerate() {
652        weight_bg += count as f64;
653        if weight_bg == 0.0 {
654            continue;
655        }
656        let weight_fg = total as f64 - weight_bg;
657        if weight_fg == 0.0 {
658            break;
659        }
660
661        sum_bg += t as f64 * count as f64;
662        let mean_bg = sum_bg / weight_bg;
663        let mean_fg = (sum_total - sum_bg) / weight_fg;
664        let diff = mean_bg - mean_fg;
665        let variance = weight_bg * weight_fg * diff * diff;
666
667        if variance > max_variance {
668            max_variance = variance;
669            best_t = t;
670        }
671    }
672
673    let threshold = best_t as f32 / 255.0;
674    let thresholded = threshold_binary(input, threshold, max_val)?;
675    Ok((threshold, thresholded))
676}
677
/// Reusable scratch buffers for [`canny_with_scratch`].
///
/// Pre-allocating these avoids per-call allocation overhead in hot loops
/// (e.g., processing video frames). The buffers hold the per-pixel
/// intermediates of the Canny pipeline (Sobel gradients -> non-maximum
/// suppression -> double-threshold hysteresis) plus the BFS work queue.
pub struct CannyScratch {
    // Approximate gradient magnitude per pixel (alpha-max-beta-min).
    magnitude: Vec<f32>,
    // Quantized gradient direction per pixel:
    // 0 horizontal, 1 NE-SW diagonal, 2 vertical, 3 NW-SE diagonal.
    direction: Vec<u8>,
    // Magnitudes surviving non-maximum suppression (0.0 elsewhere).
    nms: Vec<f32>,
    // Edge class per pixel: 0 none, 1 weak, 2 strong/confirmed.
    edges: Vec<u8>,
    // FIFO of confirmed-edge indices for the hysteresis BFS.
    queue: Vec<usize>,
}
693
694impl CannyScratch {
695    /// Creates empty scratch (buffers grow on first use).
696    pub fn new() -> Self {
697        Self {
698            magnitude: Vec::new(),
699            direction: Vec::new(),
700            nms: Vec::new(),
701            edges: Vec::new(),
702            queue: Vec::new(),
703        }
704    }
705
706    fn ensure_capacity(&mut self, len: usize) {
707        self.magnitude.resize(len, 0.0);
708        self.direction.resize(len, 0);
709        self.nms.resize(len, 0.0);
710        self.edges.resize(len, 0);
711    }
712}
713
714impl Default for CannyScratch {
715    fn default() -> Self {
716        Self::new()
717    }
718}
719
/// Canny edge detection on a single-channel HWC image.
///
/// Steps: Sobel gradients -> non-maximum suppression -> double-threshold
/// hysteresis. Returns a binary edge map with values `0.0` or `1.0`.
/// Convenience wrapper over [`canny_with_scratch`] that allocates fresh
/// scratch buffers on every call.
pub fn canny(input: &Tensor, low_thresh: f32, high_thresh: f32) -> Result<Tensor, ImgProcError> {
    let mut scratch = CannyScratch::new();
    canny_with_scratch(input, low_thresh, high_thresh, &mut scratch)
}
724
725/// Canny edge detection with reusable scratch buffers.
726///
727/// Optimized with:
728/// - Alpha-max-beta-min magnitude approximation (no sqrt)
729/// - Fast gradient direction via sign comparison (no atan2)
730/// - Single-pass BFS hysteresis (no iterative convergence loop)
731pub fn canny_with_scratch(
732    input: &Tensor,
733    low_thresh: f32,
734    high_thresh: f32,
735    scratch: &mut CannyScratch,
736) -> Result<Tensor, ImgProcError> {
737    let (_h, _w, channels) = hwc_shape(input)?;
738    if channels != 1 {
739        return Err(ImgProcError::InvalidChannelCount {
740            expected: 1,
741            got: channels,
742        });
743    }
744
745    let (h, w, _) = hwc_shape(input)?;
746    let (gx, gy) = sobel_3x3_gradients(input)?;
747    let gx_data = gx.data();
748    let gy_data = gy.data();
749    let len = h * w;
750
751    scratch.ensure_capacity(len);
752    let magnitude = &mut scratch.magnitude;
753    let direction = &mut scratch.direction;
754    let nms = &mut scratch.nms;
755    let edges = &mut scratch.edges;
756
757    // Pass 1: Compute magnitude (alpha-max-beta-min) and direction (sign-based)
758    // Alpha-max-beta-min: mag ≈ max(|dx|,|dy|) + 0.414 * min(|dx|,|dy|)
759    // ~4% max error vs sqrt, but avoids expensive per-pixel sqrt
760    for i in 0..len {
761        let dx = gx_data[i];
762        let dy = gy_data[i];
763        let adx = dx.abs();
764        let ady = dy.abs();
765        let (big, small) = if adx > ady { (adx, ady) } else { (ady, adx) };
766        magnitude[i] = big + 0.414 * small;
767
768        // Fast direction quantization using sign comparisons:
769        // 0 = horizontal (|dx| > 2.414*|dy|)
770        // 2 = vertical (|dy| > 2.414*|dx|)
771        // 1 = diagonal /  (dx*dy < 0, roughly 45°)
772        // 3 = diagonal \  (dx*dy > 0, roughly 135°)
773        // Threshold 2.414 = tan(67.5°), using a fast approximation
774        direction[i] = if ady * 5.0 < adx * 2.0 {
775            // |dy|/|dx| < 0.4 → horizontal
776            0
777        } else if adx * 5.0 < ady * 2.0 {
778            // |dx|/|dy| < 0.4 → vertical
779            2
780        } else if (dx > 0.0) == (dy > 0.0) {
781            // Same sign → NE-SW diagonal
782            1
783        } else {
784            // Different sign → NW-SE diagonal
785            3
786        };
787    }
788
789    // Pass 2: Non-maximum suppression
790    for v in nms.iter_mut() {
791        *v = 0.0;
792    }
793    for y in 1..h.saturating_sub(1) {
794        for x in 1..w.saturating_sub(1) {
795            let idx = y * w + x;
796            let mag = magnitude[idx];
797            let (n1, n2) = match direction[idx] {
798                0 => (magnitude[y * w + x - 1], magnitude[y * w + x + 1]),
799                1 => (
800                    magnitude[(y - 1) * w + x + 1],
801                    magnitude[(y + 1) * w + x - 1],
802                ),
803                2 => (magnitude[(y - 1) * w + x], magnitude[(y + 1) * w + x]),
804                _ => (
805                    magnitude[(y - 1) * w + x - 1],
806                    magnitude[(y + 1) * w + x + 1],
807                ),
808            };
809            if mag >= n1 && mag >= n2 {
810                nms[idx] = mag;
811            }
812        }
813    }
814
815    // Pass 3: Double threshold + BFS hysteresis (single pass, no convergence loop)
816    for v in edges.iter_mut() {
817        *v = 0;
818    }
819    scratch.queue.clear();
820
821    // Seed the queue with strong edges
822    for i in 0..len {
823        if nms[i] >= high_thresh {
824            edges[i] = 2;
825            scratch.queue.push(i);
826        } else if nms[i] >= low_thresh {
827            edges[i] = 1;
828        }
829    }
830
831    // BFS: propagate from strong edges to connected weak edges
832    let mut head = 0;
833    while head < scratch.queue.len() {
834        let idx = scratch.queue[head];
835        head += 1;
836        let y = idx / w;
837        let x = idx % w;
838        if y == 0 || y >= h - 1 || x == 0 || x >= w - 1 {
839            continue;
840        }
841        // Check 8 neighbors
842        for dy in [-1isize, 0, 1] {
843            for dx in [-1isize, 0, 1] {
844                if dy == 0 && dx == 0 {
845                    continue;
846                }
847                let ny = (y as isize + dy) as usize;
848                let nx = (x as isize + dx) as usize;
849                let ni = ny * w + nx;
850                if edges[ni] == 1 {
851                    edges[ni] = 2;
852                    scratch.queue.push(ni);
853                }
854            }
855        }
856    }
857
858    let out: Vec<f32> = edges
859        .iter()
860        .map(|&e| if e == 2 { 1.0 } else { 0.0 })
861        .collect();
862    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
863}
864
865/// Adaptive threshold using a local mean window.
866///
867/// For each pixel, threshold = local_mean - constant. Pixels above threshold get `max_val`.
868/// Operates on single-channel HWC input. `block_size` must be odd and > 0.
869pub fn adaptive_threshold_mean(
870    input: &Tensor,
871    max_val: f32,
872    block_size: usize,
873    constant: f32,
874) -> Result<Tensor, ImgProcError> {
875    let (h, w, channels) = hwc_shape(input)?;
876    if channels != 1 {
877        return Err(ImgProcError::InvalidChannelCount {
878            expected: 1,
879            got: channels,
880        });
881    }
882    if block_size == 0 || block_size.is_multiple_of(2) {
883        return Err(ImgProcError::InvalidBlockSize { block_size });
884    }
885
886    let data = input.data();
887    let half = (block_size / 2) as isize;
888    let mut out = vec![0.0f32; h * w];
889
890    for y in 0..h {
891        for x in 0..w {
892            let mut sum = 0.0f32;
893            let mut count = 0u32;
894            for ky in -half..=half {
895                for kx in -half..=half {
896                    let sy = y as isize + ky;
897                    let sx = x as isize + kx;
898                    if sy >= 0 && sy < h as isize && sx >= 0 && sx < w as isize {
899                        sum += data[sy as usize * w + sx as usize];
900                        count += 1;
901                    }
902                }
903            }
904            let local_mean = sum / count as f32;
905            let threshold = local_mean - constant;
906            out[y * w + x] = if data[y * w + x] > threshold {
907                max_val
908            } else {
909                0.0
910            };
911        }
912    }
913
914    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
915}
916
917/// Adaptive threshold using a Gaussian-weighted local window.
918///
919/// Operates on single-channel HWC input. `block_size` must be odd and > 0.
920pub fn adaptive_threshold_gaussian(
921    input: &Tensor,
922    max_val: f32,
923    block_size: usize,
924    constant: f32,
925) -> Result<Tensor, ImgProcError> {
926    let (h, w, channels) = hwc_shape(input)?;
927    if channels != 1 {
928        return Err(ImgProcError::InvalidChannelCount {
929            expected: 1,
930            got: channels,
931        });
932    }
933    if block_size == 0 || block_size.is_multiple_of(2) {
934        return Err(ImgProcError::InvalidBlockSize { block_size });
935    }
936
937    let half = block_size / 2;
938    let sigma = 0.3 * ((block_size as f64 - 1.0) * 0.5 - 1.0) + 0.8;
939    let sigma2 = sigma * sigma;
940
941    let mut kernel = vec![0.0f64; block_size * block_size];
942    let mut ksum = 0.0f64;
943    for ky in 0..block_size {
944        for kx in 0..block_size {
945            let dy = ky as f64 - half as f64;
946            let dx = kx as f64 - half as f64;
947            let val = (-(dy * dy + dx * dx) / (2.0 * sigma2)).exp();
948            kernel[ky * block_size + kx] = val;
949            ksum += val;
950        }
951    }
952    for v in &mut kernel {
953        *v /= ksum;
954    }
955
956    let data = input.data();
957    let half_i = half as isize;
958    let mut out = vec![0.0f32; h * w];
959
960    for y in 0..h {
961        for x in 0..w {
962            let mut wsum = 0.0f64;
963            let mut wnorm = 0.0f64;
964            for ky in -half_i..=half_i {
965                for kx in -half_i..=half_i {
966                    let sy = y as isize + ky;
967                    let sx = x as isize + kx;
968                    if sy >= 0 && sy < h as isize && sx >= 0 && sx < w as isize {
969                        let kw =
970                            kernel[(ky + half_i) as usize * block_size + (kx + half_i) as usize];
971                        wsum += data[sy as usize * w + sx as usize] as f64 * kw;
972                        wnorm += kw;
973                    }
974                }
975            }
976            let local_mean = (wsum / wnorm) as f32;
977            let threshold = local_mean - constant;
978            out[y * w + x] = if data[y * w + x] > threshold {
979                max_val
980            } else {
981                0.0
982            };
983        }
984    }
985
986    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
987}
988
989/// Connected-component labeling on a binary single-channel HWC image (4-connectivity).
990///
991/// Input pixels > 0 are foreground. Returns `(label_map, num_labels)` where
992/// each connected component has a unique positive integer label, background is 0.
pub fn connected_components_4(input: &Tensor) -> Result<(Tensor, usize), ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    if channels != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: channels,
        });
    }

    let data = input.data();
    let len = h * w;
    // Provisional label per pixel; 0 = background.
    let mut labels = vec![0u32; len];
    let mut next_label = 1u32;
    // Union-find parent table indexed by label; slot 0 is a dummy entry so
    // labels can index it directly.
    let mut equivalences: Vec<u32> = vec![0];

    // First pass (raster order): assign provisional labels from the already
    // visited left/top neighbors and record equivalences between them.
    for y in 0..h {
        for x in 0..w {
            let idx = y * w + x;
            if data[idx] <= 0.0 {
                // Background pixel: label stays 0.
                continue;
            }
            let left = if x > 0 { labels[y * w + x - 1] } else { 0 };
            let above = if y > 0 { labels[(y - 1) * w + x] } else { 0 };

            match (left > 0, above > 0) {
                // No labeled neighbor: start a new component whose parent
                // entry points to itself (a fresh root).
                (false, false) => {
                    labels[idx] = next_label;
                    equivalences.push(next_label);
                    next_label += 1;
                }
                (true, false) => labels[idx] = left,
                (false, true) => labels[idx] = above,
                // Both neighbors labeled: take the smaller root for this
                // pixel and union the two trees if they differ.
                (true, true) => {
                    let rl = find_root(&equivalences, left);
                    let ra = find_root(&equivalences, above);
                    labels[idx] = rl.min(ra);
                    if rl != ra {
                        // Always point the larger root at the smaller one so
                        // a set's root is the minimum label merged into it.
                        let (lo, hi) = if rl < ra { (rl, ra) } else { (ra, rl) };
                        equivalences[hi as usize] = lo;
                    }
                }
            }
        }
    }

    // Assign each surviving root a dense id in 1..=label_count.
    let mut canonical = vec![0u32; next_label as usize];
    let mut label_count = 0u32;
    #[allow(clippy::needless_range_loop)]
    for i in 1..next_label as usize {
        let root = find_root(&equivalences, i as u32);
        if root == i as u32 {
            label_count += 1;
            canonical[i] = label_count;
        }
    }
    // Propagate each root's dense id to every provisional label in its set.
    #[allow(clippy::needless_range_loop)]
    for i in 1..next_label as usize {
        let root = find_root(&equivalences, i as u32);
        canonical[i] = canonical[root as usize];
    }

    // Second pass: remap provisional labels to dense ids; background stays 0.
    let out: Vec<f32> = labels
        .iter()
        .map(|&l| {
            if l == 0 {
                0.0
            } else {
                canonical[l as usize] as f32
            }
        })
        .collect();

    Ok((Tensor::from_vec(vec![h, w, 1], out)?, label_count as usize))
}
1067
/// Follows parent links in the union-find table until reaching a label that
/// is its own parent (the set's root), which is returned. No path
/// compression is performed; `equiv` is read-only.
pub(crate) fn find_root(equiv: &[u32], mut label: u32) -> u32 {
    loop {
        let parent = equiv[label as usize];
        if parent == label {
            return label;
        }
        label = parent;
    }
}