// yscv_imgproc/ops/features.rs — feature-detection ops: Harris and FAST
// corners, Hough lines, Gaussian pyramid, L1 distance transform, ORB.
1use rayon::prelude::*;
2use yscv_tensor::Tensor;
3
4use super::super::ImgProcError;
5use super::super::shape::hwc_shape;
6
/// Detected corner/interest point.
///
/// Produced by `harris_corners` and `fast_corners`. `x`/`y` are pixel
/// coordinates (column/row); `response` is the detector score — the Harris
/// corner response for `harris_corners`, or the sum of absolute
/// center-to-circle differences for `fast_corners`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HarrisKeypoint {
    // Column (pixel x-coordinate).
    pub x: usize,
    // Row (pixel y-coordinate).
    pub y: usize,
    // Detector score; larger means a stronger corner.
    pub response: f32,
}
14
/// Harris corner detector on a single-channel `[H, W, 1]` image.
///
/// Returns corners whose Harris response exceeds `threshold`, after a 3x3
/// non-maximum suppression pass.
/// `k` is the Harris sensitivity parameter (typical 0.04-0.06).
/// `block_size` sets the aggregation window for the structure tensor; the
/// effective window radius is `block_size / 2`.
///
/// Pipeline:
///   1. Sobel gradients fused into interleaved products `[Ix², IxIy, Iy²]`
///      (NEON/SSE-accelerated when available).
///   2. Separable (horizontal then vertical) un-normalized box sum of the
///      products over the window — so `threshold` is relative to the summed,
///      not averaged, tensor.
///   3. Per-pixel response `det(M) - k·trace(M)²`, thresholded
///      (rayon-parallel over rows).
///   4. 3x3 non-maximum suppression.
///
/// # Errors
/// Returns `ImgProcError::InvalidChannelCount` when the image is not
/// single-channel; propagates shape errors from `hwc_shape`.
#[allow(unsafe_code)]
pub fn harris_corners(
    input: &Tensor,
    block_size: usize,
    k: f32,
    threshold: f32,
) -> Result<Vec<HarrisKeypoint>, ImgProcError> {
    let (h, w, c) = hwc_shape(input)?;
    if c != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: c,
        });
    }
    let data = input.data();

    // Interleaved structure tensor products: [sxx, sxy, syy] per pixel
    let n = h * w;
    let mut prods = vec![0.0f32; n * 3]; // interleaved [sxx0,sxy0,syy0, sxx1,sxy1,syy1, ...]

    // Stage 1: gradient products, dispatched to the fastest available
    // kernel. SIMD paths are skipped under Miri, which cannot execute
    // vendor intrinsics.
    #[cfg(target_arch = "aarch64")]
    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
        // SAFETY: the NEON feature was verified at runtime just above.
        unsafe {
            sobel_products_interleaved_neon(data, &mut prods, h, w);
        }
    } else {
        sobel_products_interleaved_scalar(data, &mut prods, h, w);
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
            // SAFETY: the SSE feature was verified at runtime just above.
            unsafe {
                sobel_products_interleaved_sse(data, &mut prods, h, w);
            }
        } else {
            sobel_products_interleaved_scalar(data, &mut prods, h, w);
        }
        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
        sobel_products_interleaved_scalar(data, &mut prods, h, w);
    }

    let half = (block_size / 2) as i32;
    let r = half as usize;

    // Stage 2: separable box filter on the interleaved buffer.
    // Horizontal pass: incremental sliding-window sum over columns
    // [x - r, x + r] (clipped to the row), 3 channels at once.
    let mut h_prods = vec![0.0f32; n * 3];
    for y in 0..h {
        let row3 = y * w * 3;
        let mut r0 = 0.0f32;
        let mut r1 = 0.0f32;
        let mut r2 = 0.0f32;
        // Prime the window with columns [0, r).
        for x in 0..r.min(w) {
            let i = row3 + x * 3;
            r0 += prods[i];
            r1 += prods[i + 1];
            r2 += prods[i + 2];
        }
        for x in 0..w {
            // Window gains column x + r (its right edge) ...
            if x + r < w {
                let i = row3 + (x + r) * 3;
                r0 += prods[i];
                r1 += prods[i + 1];
                r2 += prods[i + 2];
            }
            let o = row3 + x * 3;
            h_prods[o] = r0;
            h_prods[o + 1] = r1;
            h_prods[o + 2] = r2;
            // ... and drops column x - r before advancing to x + 1.
            if x >= r {
                let i = row3 + (x - r) * 3;
                r0 -= prods[i];
                r1 -= prods[i + 1];
                r2 -= prods[i + 2];
            }
        }
    }

    // Vertical pass: same sliding window over rows [y - r, y + r], using a
    // row-wide interleaved column accumulator (w*3 elements).
    let stride3 = w * 3;
    let mut col = vec![0.0f32; stride3];

    // Prime the accumulator with rows [0, r).
    for y in 0..r.min(h) {
        vec_add_row(&mut col, &h_prods[y * stride3..(y + 1) * stride3]);
    }

    // Reuse prods for the final box-filtered result
    for y in 0..h {
        if y + r < h {
            let bot = (y + r) * stride3;
            vec_add_row(&mut col, &h_prods[bot..bot + stride3]);
        }
        let row = y * stride3;
        prods[row..row + stride3].copy_from_slice(&col);
        if y >= r {
            let top = (y - r) * stride3;
            vec_sub_row(&mut col, &h_prods[top..top + stride3]);
        }
    }

    // prods now contains interleaved box-filtered [sxx,sxy,syy] per pixel.
    // Stage 3: Harris response det(M) - k*trace(M)^2, rows in parallel.
    // The margin keeps the window (and the later 3x3 NMS) inside the image.
    let margin = r.max(1);
    let row_range: Vec<usize> = (margin..h.saturating_sub(margin)).collect();
    let corners: Vec<HarrisKeypoint> = row_range
        .par_iter()
        .flat_map(|&y| {
            let mut row_corners = Vec::new();
            for x in margin..w.saturating_sub(margin) {
                let i = (y * w + x) * 3;
                let a = prods[i]; // sxx
                let b = prods[i + 1]; // sxy
                let c = prods[i + 2]; // syy
                let det = a * c - b * b;
                let trace = a + c;
                let response = det - k * trace * trace;
                if response > threshold {
                    row_corners.push(HarrisKeypoint { x, y, response });
                }
            }
            row_corners
        })
        .collect();

    // Stage 4: non-maximum suppression within a 3x3 neighborhood. Only
    // above-threshold responses are splatted into the map, so sub-threshold
    // neighbors (left at 0.0) never suppress a corner.
    let mut response_map = vec![0.0f32; h * w];
    for corner in &corners {
        response_map[corner.y * w + corner.x] = corner.response;
    }
    let corners: Vec<HarrisKeypoint> = corners
        .into_iter()
        .filter(|corner| {
            // Drop the corner only if a strictly larger response is
            // adjacent; equal-valued plateau corners are all kept.
            for dy in -1i32..=1 {
                for dx in -1i32..=1 {
                    if dy == 0 && dx == 0 {
                        continue;
                    }
                    let ny = (corner.y as i32 + dy) as usize;
                    let nx = (corner.x as i32 + dx) as usize;
                    if ny < h && nx < w && response_map[ny * w + nx] > corner.response {
                        return false;
                    }
                }
            }
            true
        })
        .collect();

    Ok(corners)
}
170
171// ── SIMD row-vector helpers ────────────────────────────────────────
172
/// Element-wise `acc[i] += src[i]` over the common prefix of the two
/// slices; SIMD-accelerated (NEON on aarch64, SSE on x86/x86_64).
#[allow(unsafe_code)]
#[inline]
fn vec_add_row(acc: &mut [f32], src: &[f32]) {
    let len = acc.len().min(src.len());
    let mut done = 0usize;

    #[cfg(target_arch = "aarch64")]
    if !cfg!(miri) {
        // SAFETY: every load/store index is below `len`, which bounds both
        // slices.
        unsafe {
            use std::arch::aarch64::*;
            let dst = acc.as_mut_ptr();
            let add = src.as_ptr();
            while done + 4 <= len {
                let lhs = vld1q_f32(dst.add(done));
                let rhs = vld1q_f32(add.add(done));
                vst1q_f32(dst.add(done), vaddq_f32(lhs, rhs));
                done += 4;
            }
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) {
        // SAFETY: every load/store index is below `len`, which bounds both
        // slices; SSE availability is checked before any intrinsic runs.
        unsafe {
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;
            if std::is_x86_feature_detected!("sse") {
                let dst = acc.as_mut_ptr();
                let add = src.as_ptr();
                while done + 4 <= len {
                    let lhs = _mm_loadu_ps(dst.add(done));
                    let rhs = _mm_loadu_ps(add.add(done));
                    _mm_storeu_ps(dst.add(done), _mm_add_ps(lhs, rhs));
                    done += 4;
                }
            }
        }
    }

    // Scalar tail (also the full fallback when no SIMD path ran).
    for (a, s) in acc[done..len].iter_mut().zip(&src[done..len]) {
        *a += *s;
    }
}
220
/// Element-wise `acc[i] -= src[i]` over the common prefix of the two
/// slices; SIMD-accelerated (NEON on aarch64, SSE on x86/x86_64).
#[allow(unsafe_code)]
#[inline]
fn vec_sub_row(acc: &mut [f32], src: &[f32]) {
    let len = acc.len().min(src.len());
    let mut done = 0usize;

    #[cfg(target_arch = "aarch64")]
    if !cfg!(miri) {
        // SAFETY: every load/store index is below `len`, which bounds both
        // slices.
        unsafe {
            use std::arch::aarch64::*;
            let dst = acc.as_mut_ptr();
            let sub = src.as_ptr();
            while done + 4 <= len {
                let lhs = vld1q_f32(dst.add(done));
                let rhs = vld1q_f32(sub.add(done));
                vst1q_f32(dst.add(done), vsubq_f32(lhs, rhs));
                done += 4;
            }
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) {
        // SAFETY: every load/store index is below `len`, which bounds both
        // slices; SSE availability is checked before any intrinsic runs.
        unsafe {
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;
            if std::is_x86_feature_detected!("sse") {
                let dst = acc.as_mut_ptr();
                let sub = src.as_ptr();
                while done + 4 <= len {
                    let lhs = _mm_loadu_ps(dst.add(done));
                    let rhs = _mm_loadu_ps(sub.add(done));
                    _mm_storeu_ps(dst.add(done), _mm_sub_ps(lhs, rhs));
                    done += 4;
                }
            }
        }
    }

    // Scalar tail (also the full fallback when no SIMD path ran).
    for (a, s) in acc[done..len].iter_mut().zip(&src[done..len]) {
        *a -= *s;
    }
}
268
269// ── Sobel gradient helpers for Harris ──────────────────────────────
270
/// Scalar fused Sobel gradient + interleaved product computation.
/// Stores [Ix², Ix*Iy, Iy²] interleaved at each pixel position.
///
/// Only interior pixels are written; border pixels are left untouched, so
/// `prods` should be zero-initialized by the caller. Images with `h < 3` or
/// `w < 3` have no interior and are a no-op.
#[inline]
fn sobel_products_interleaved_scalar(data: &[f32], prods: &mut [f32], h: usize, w: usize) {
    // `saturating_sub` guards h < 2 / w < 2, where the previous `h - 1`
    // underflowed usize (panic in debug, wild indexing range in release).
    for y in 1..h.saturating_sub(1) {
        for x in 1..w.saturating_sub(1) {
            // Neighbor lookup relative to (y, x); offsets are in -1..=1 so
            // the interior bounds keep every access valid.
            let p =
                |dy: i32, dx: i32| data[(y as i32 + dy) as usize * w + (x as i32 + dx) as usize];
            // 3x3 Sobel kernels, horizontal (gx) and vertical (gy).
            let gx = -p(-1, -1) + p(-1, 1) - 2.0 * p(0, -1) + 2.0 * p(0, 1) - p(1, -1) + p(1, 1);
            let gy = -p(-1, -1) - 2.0 * p(-1, 0) - p(-1, 1) + p(1, -1) + 2.0 * p(1, 0) + p(1, 1);
            let idx = (y * w + x) * 3;
            prods[idx] = gx * gx;
            prods[idx + 1] = gx * gy;
            prods[idx + 2] = gy * gy;
        }
    }
}
288
/// NEON SIMD fused Sobel gradient + interleaved product computation.
///
/// Vector body handles 4 interior pixels per iteration by loading the three
/// neighborhood rows shifted left/center/right by one column; remaining
/// interior columns use the scalar kernel. Border pixels are left untouched
/// (caller zero-initializes `prods`).
///
/// # Safety
/// Caller must ensure NEON is available, `data.len() >= h * w`, and
/// `prods.len() >= h * w * 3`. NOTE(review): assumes `h >= 2` — `h - 1`
/// underflows otherwise; confirm callers guarantee this.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn sobel_products_interleaved_neon(data: &[f32], prods: &mut [f32], h: usize, w: usize) {
    use std::arch::aarch64::*;

    let two = vdupq_n_f32(2.0);

    for y in 1..h - 1 {
        let row_above = (y - 1) * w;
        let row_curr = y * w;
        let row_below = (y + 1) * w;

        let dp = data.as_ptr();
        let pp = prods.as_mut_ptr();

        let mut x = 1usize;

        // Pixels x..x+4 read columns up to x+4, hence the `x + 5 <= w`
        // bound (the last interior column is w - 2).
        while x + 5 <= w {
            let r0l = vld1q_f32(dp.add(row_above + x - 1));
            let r0c = vld1q_f32(dp.add(row_above + x));
            let r0r = vld1q_f32(dp.add(row_above + x + 1));
            let r1l = vld1q_f32(dp.add(row_curr + x - 1));
            let r1r = vld1q_f32(dp.add(row_curr + x + 1));
            let r2l = vld1q_f32(dp.add(row_below + x - 1));
            let r2c = vld1q_f32(dp.add(row_below + x));
            let r2r = vld1q_f32(dp.add(row_below + x + 1));

            // gx = (r0r - r0l) + (r2r - r2l) + 2*(r1r - r1l)
            let gx = vaddq_f32(
                vaddq_f32(vsubq_f32(r0r, r0l), vsubq_f32(r2r, r2l)),
                vmulq_f32(vsubq_f32(r1r, r1l), two),
            );
            // gy = (r2l - r0l) + (r2r - r0r) + 2*(r2c - r0c)
            let gy = vaddq_f32(
                vaddq_f32(vsubq_f32(r2l, r0l), vsubq_f32(r2r, r0r)),
                vmulq_f32(vsubq_f32(r2c, r0c), two),
            );

            // Interleaved store: [sxx0,sxy0,syy0, sxx1,sxy1,syy1, sxx2,sxy2,syy2, sxx3,sxy3,syy3]
            let gxx = vmulq_f32(gx, gx);
            let gxy = vmulq_f32(gx, gy);
            let gyy = vmulq_f32(gy, gy);
            vst3q_f32(pp.add((row_curr + x) * 3), float32x4x3_t(gxx, gxy, gyy));

            x += 4;
        }

        // Scalar tail
        while x < w - 1 {
            let p = |dy: i32, dx: i32| {
                *dp.add(((y as i32 + dy) as usize) * w + ((x as i32 + dx) as usize))
            };
            let gx = -p(-1, -1) + p(-1, 1) - 2.0 * p(0, -1) + 2.0 * p(0, 1) - p(1, -1) + p(1, 1);
            let gy = -p(-1, -1) - 2.0 * p(-1, 0) - p(-1, 1) + p(1, -1) + 2.0 * p(1, 0) + p(1, 1);
            let idx = (row_curr + x) * 3;
            *pp.add(idx) = gx * gx;
            *pp.add(idx + 1) = gx * gy;
            *pp.add(idx + 2) = gy * gy;
            x += 1;
        }
    }
}
351
/// SSE SIMD fused Sobel gradient + interleaved product computation.
///
/// Mirrors the NEON kernel: 4 interior pixels per iteration, scalar tail for
/// the rest, borders left untouched (caller zero-initializes `prods`).
///
/// # Safety
/// Caller must ensure SSE is available, `data.len() >= h * w`, and
/// `prods.len() >= h * w * 3`. NOTE(review): assumes `h >= 2` — `h - 1`
/// underflows otherwise; confirm callers guarantee this.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn sobel_products_interleaved_sse(data: &[f32], prods: &mut [f32], h: usize, w: usize) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let two = _mm_set1_ps(2.0);

    for y in 1..h - 1 {
        let row_above = (y - 1) * w;
        let row_curr = y * w;
        let row_below = (y + 1) * w;

        let dp = data.as_ptr();
        let pp = prods.as_mut_ptr();

        let mut x = 1usize;

        // Pixels x..x+4 read columns up to x+4, hence the `x + 5 <= w`
        // bound (the last interior column is w - 2).
        while x + 5 <= w {
            let r0l = _mm_loadu_ps(dp.add(row_above + x - 1));
            let r0c = _mm_loadu_ps(dp.add(row_above + x));
            let r0r = _mm_loadu_ps(dp.add(row_above + x + 1));
            let r1l = _mm_loadu_ps(dp.add(row_curr + x - 1));
            let r1r = _mm_loadu_ps(dp.add(row_curr + x + 1));
            let r2l = _mm_loadu_ps(dp.add(row_below + x - 1));
            let r2c = _mm_loadu_ps(dp.add(row_below + x));
            let r2r = _mm_loadu_ps(dp.add(row_below + x + 1));

            // gx = (r0r - r0l) + (r2r - r2l) + 2*(r1r - r1l)
            let gx = _mm_add_ps(
                _mm_add_ps(_mm_sub_ps(r0r, r0l), _mm_sub_ps(r2r, r2l)),
                _mm_mul_ps(_mm_sub_ps(r1r, r1l), two),
            );
            // gy = (r2l - r0l) + (r2r - r0r) + 2*(r2c - r0c)
            let gy = _mm_add_ps(
                _mm_add_ps(_mm_sub_ps(r2l, r0l), _mm_sub_ps(r2r, r0r)),
                _mm_mul_ps(_mm_sub_ps(r2c, r0c), two),
            );

            let gxx = _mm_mul_ps(gx, gx);
            let gxy = _mm_mul_ps(gx, gy);
            let gyy = _mm_mul_ps(gy, gy);

            // SSE has no vst3q equivalent, so store interleaved via scalar extracts
            // Extract 4 lanes and store interleaved [sxx, sxy, syy] per pixel
            let mut gxx_arr = [0.0f32; 4];
            let mut gxy_arr = [0.0f32; 4];
            let mut gyy_arr = [0.0f32; 4];
            _mm_storeu_ps(gxx_arr.as_mut_ptr(), gxx);
            _mm_storeu_ps(gxy_arr.as_mut_ptr(), gxy);
            _mm_storeu_ps(gyy_arr.as_mut_ptr(), gyy);

            for k in 0..4 {
                let idx = (row_curr + x + k) * 3;
                *pp.add(idx) = gxx_arr[k];
                *pp.add(idx + 1) = gxy_arr[k];
                *pp.add(idx + 2) = gyy_arr[k];
            }

            x += 4;
        }

        // Scalar tail
        while x < w - 1 {
            let p = |dy: i32, dx: i32| {
                *dp.add(((y as i32 + dy) as usize) * w + ((x as i32 + dx) as usize))
            };
            let gx = -p(-1, -1) + p(-1, 1) - 2.0 * p(0, -1) + 2.0 * p(0, 1) - p(1, -1) + p(1, 1);
            let gy = -p(-1, -1) - 2.0 * p(-1, 0) - p(-1, 1) + p(1, -1) + 2.0 * p(1, 0) + p(1, 1);
            let idx = (row_curr + x) * 3;
            *pp.add(idx) = gx * gx;
            *pp.add(idx + 1) = gx * gy;
            *pp.add(idx + 2) = gy * gy;
            x += 1;
        }
    }
}
433
434// ── FAST corner detection ──────────────────────────────────────────
435
436/// FAST-9 corner detector on a single-channel `[H, W, 1]` image.
437///
438/// Tests 16 pixels on a Bresenham circle of radius 3. A pixel is a corner
439/// if at least `min_consecutive` contiguous circle pixels are all brighter
440/// (or all darker) than the center by at least `threshold`.
441///
442/// Uses precomputed offsets and bitmask-based contiguity checking for speed.
443#[allow(unsafe_code)]
444pub fn fast_corners(
445    input: &Tensor,
446    threshold: f32,
447    min_consecutive: usize,
448) -> Result<Vec<HarrisKeypoint>, ImgProcError> {
449    let (h, w, c) = hwc_shape(input)?;
450    if c != 1 {
451        return Err(ImgProcError::InvalidChannelCount {
452            expected: 1,
453            got: c,
454        });
455    }
456    let data = input.data();
457
458    const CIRCLE: [(i32, i32); 16] = [
459        (0, -3),
460        (1, -3),
461        (2, -2),
462        (3, -1),
463        (3, 0),
464        (3, 1),
465        (2, 2),
466        (1, 3),
467        (0, 3),
468        (-1, 3),
469        (-2, 2),
470        (-3, 1),
471        (-3, 0),
472        (-3, -1),
473        (-2, -2),
474        (-1, -3),
475    ];
476
477    // Precompute flat offsets from center pixel
478    let ws = w as isize;
479    let mut offsets = [0isize; 16];
480    for (i, &(dx, dy)) in CIRCLE.iter().enumerate() {
481        offsets[i] = dy as isize * ws + dx as isize;
482    }
483
484    // Cardinal offsets for early rejection: N(0), E(4), S(8), W(12)
485    let card = [offsets[0], offsets[4], offsets[8], offsets[12]];
486
487    let n = min_consecutive.min(16);
488    let mut corners = Vec::new();
489
490    for y in 3..h.saturating_sub(3) {
491        let row_base = y * w;
492        for x in 3..w.saturating_sub(3) {
493            let idx = row_base + x;
494            let center = unsafe { *data.get_unchecked(idx) };
495            let bright_thresh = center + threshold;
496            let dark_thresh = center - threshold;
497
498            // Quick reject: check cardinal pixels N, E, S, W
499            let mut bc = 0u32;
500            let mut dc = 0u32;
501            for &co in &card {
502                let v = unsafe { *data.get_unchecked((idx as isize + co) as usize) };
503                bc += (v > bright_thresh) as u32;
504                dc += (v < dark_thresh) as u32;
505            }
506            // Need at least 3 of 4 cardinals passing for any 9-run to exist
507            let min_card = if n >= 9 { 3 } else { n.min(4) };
508            if (bc as usize) < min_card && (dc as usize) < min_card {
509                continue;
510            }
511
512            // Build bitmasks
513            let mut bright_mask = 0u32;
514            let mut dark_mask = 0u32;
515            for i in 0..16 {
516                let v = unsafe { *data.get_unchecked((idx as isize + offsets[i]) as usize) };
517                if v > bright_thresh {
518                    bright_mask |= 1 << i;
519                }
520                if v < dark_thresh {
521                    dark_mask |= 1 << i;
522                }
523            }
524
525            let is_corner =
526                has_consecutive_mask(bright_mask, n) || has_consecutive_mask(dark_mask, n);
527            if is_corner {
528                // Compute score: sum of absolute differences
529                let mut score = 0.0f32;
530                for i in 0..16 {
531                    let v = unsafe { *data.get_unchecked((idx as isize + offsets[i]) as usize) };
532                    score += (v - center).abs();
533                }
534                corners.push(HarrisKeypoint {
535                    x,
536                    y,
537                    response: score,
538                });
539            }
540        }
541    }
542
543    Ok(corners)
544}
545
/// Check if a 16-bit circular bitmask has `n` contiguous set bits.
///
/// Wrap-around runs are handled by scanning the mask doubled into 32 bits.
/// `n == 0` is trivially true. `n > 16` is impossible on a 16-position
/// circle and returns false (previously the doubled scan could report runs
/// up to 32 for dense masks, e.g. `mask == 0xFFFF`).
pub(crate) fn has_consecutive_mask(mask: u32, n: usize) -> bool {
    if n == 0 {
        return true;
    }
    // A run longer than the circle itself cannot exist; bail before the
    // doubled scan over-counts.
    if n > 16 || mask == 0 {
        return false;
    }
    // Doubling the (16-bit) mask turns circular runs into linear runs.
    let low = mask & 0xFFFF;
    let doubled = low | (low << 16);
    let mut run = 0u32;
    for i in 0..32 {
        if (doubled >> i) & 1 != 0 {
            run += 1;
            if run >= n as u32 {
                return true;
            }
        } else {
            run = 0;
        }
    }
    false
}
568
/// Check if a circular array of 16 flags has `n` contiguous `true` values.
///
/// Scalar reference counterpart of `has_consecutive_mask`; `n == 0` is
/// trivially true, `n > 16` is impossible on a 16-position circle and
/// returns false (previously the doubled scan could count runs up to 32).
#[allow(dead_code)]
pub(crate) fn has_consecutive(flags: &[bool; 16], n: usize) -> bool {
    if n == 0 {
        return true;
    }
    // Consistent with `has_consecutive_mask`: no run can exceed the circle.
    if n > 16 {
        return false;
    }
    let mut count = 0usize;
    // Scan two laps (i % 16) so wrap-around runs are counted.
    for i in 0..32 {
        if flags[i % 16] {
            count += 1;
            if count >= n {
                return true;
            }
        } else {
            count = 0;
        }
    }
    false
}
587
588// ── Hough line detection ───────────────────────────────────────────
589
/// Detected line in Hough parameter space.
///
/// The line satisfies `x*cos(theta) + y*sin(theta) = rho` in pixel
/// coordinates; `votes` is the accumulator count that produced it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HoughLine {
    // Signed distance from the image origin, in pixels.
    pub rho: f32,
    // Line normal angle in radians, in [0, pi).
    pub theta: f32,
    // Number of edge pixels that voted for this (rho, theta) bin.
    pub votes: u32,
}
597
598/// Standard Hough line transform on a binary/edge single-channel `[H, W, 1]` image.
599///
600/// `rho_resolution` is the distance resolution in pixels (typically 1.0).
601/// `theta_resolution` is the angular resolution in radians (typically pi/180).
602/// `vote_threshold` is the minimum accumulator count for a line.
603pub fn hough_lines(
604    input: &Tensor,
605    rho_resolution: f32,
606    theta_resolution: f32,
607    vote_threshold: u32,
608) -> Result<Vec<HoughLine>, ImgProcError> {
609    let (h, w, c) = hwc_shape(input)?;
610    if c != 1 {
611        return Err(ImgProcError::InvalidChannelCount {
612            expected: 1,
613            got: c,
614        });
615    }
616    let data = input.data();
617    let diag = ((h * h + w * w) as f32).sqrt();
618    let max_rho = diag;
619    let num_rho = (2.0 * max_rho / rho_resolution) as usize + 1;
620    let num_theta = (std::f32::consts::PI / theta_resolution) as usize;
621
622    let mut accumulator = vec![0u32; num_rho * num_theta];
623
624    let sins: Vec<f32> = (0..num_theta)
625        .map(|t| (t as f32 * theta_resolution).sin())
626        .collect();
627    let coss: Vec<f32> = (0..num_theta)
628        .map(|t| (t as f32 * theta_resolution).cos())
629        .collect();
630
631    for y in 0..h {
632        for x in 0..w {
633            if data[y * w + x] > 0.5 {
634                for t in 0..num_theta {
635                    let rho = x as f32 * coss[t] + y as f32 * sins[t];
636                    let rho_idx = ((rho + max_rho) / rho_resolution) as usize;
637                    if rho_idx < num_rho {
638                        accumulator[rho_idx * num_theta + t] += 1;
639                    }
640                }
641            }
642        }
643    }
644
645    let mut lines = Vec::new();
646    for rho_idx in 0..num_rho {
647        for t in 0..num_theta {
648            let votes = accumulator[rho_idx * num_theta + t];
649            if votes >= vote_threshold {
650                let rho = rho_idx as f32 * rho_resolution - max_rho;
651                let theta = t as f32 * theta_resolution;
652                lines.push(HoughLine { rho, theta, votes });
653            }
654        }
655    }
656
657    lines.sort_by(|a, b| b.votes.cmp(&a.votes));
658    Ok(lines)
659}
660
661// ── Image pyramid ──────────────────────────────────────────────────
662
663/// Builds a Gaussian image pyramid by repeated 2x downsampling.
664///
665/// Input is `[H, W, C]`. Returns a vector of progressively smaller images.
666/// Level 0 is the original image.
667pub fn gaussian_pyramid(input: &Tensor, levels: usize) -> Result<Vec<Tensor>, ImgProcError> {
668    let mut pyramid = Vec::with_capacity(levels + 1);
669    pyramid.push(input.clone());
670
671    for _ in 0..levels {
672        // SAFETY: pyramid always has at least one element (input.clone() pushed above).
673        let prev = pyramid.last().expect("pyramid must be non-empty");
674        let (ph, pw, pc) = hwc_shape(prev)?;
675        if ph < 2 || pw < 2 {
676            break;
677        }
678        let nh = ph / 2;
679        let nw = pw / 2;
680        let prev_data = prev.data();
681        let mut down = vec![0.0f32; nh * nw * pc];
682        for y in 0..nh {
683            for x in 0..nw {
684                let sy = y * 2;
685                let sx = x * 2;
686                for c in 0..pc {
687                    let v00 = prev_data[(sy * pw + sx) * pc + c];
688                    let v01 = prev_data[(sy * pw + (sx + 1).min(pw - 1)) * pc + c];
689                    let v10 = prev_data[((sy + 1).min(ph - 1) * pw + sx) * pc + c];
690                    let v11 =
691                        prev_data[((sy + 1).min(ph - 1) * pw + (sx + 1).min(pw - 1)) * pc + c];
692                    down[(y * nw + x) * pc + c] = (v00 + v01 + v10 + v11) * 0.25;
693                }
694            }
695        }
696        let t = Tensor::from_vec(vec![nh, nw, pc], down)?;
697        pyramid.push(t);
698    }
699
700    Ok(pyramid)
701}
702
703// ── Distance transform ─────────────────────────────────────────────
704
/// L1 (Manhattan) distance transform on a binary single-channel `[H, W, 1]` image.
///
/// Input pixels > 0.5 are foreground. Output is the distance from each
/// foreground pixel to the nearest background pixel; background pixels map
/// to 0. Two interleaved sweeps (like OpenCV): forward (top→bottom,
/// left→right) and backward (bottom→top, right→left), each fusing a SIMD
/// vertical-min step (NEON/SSE, 4 f32 lanes) with a scalar horizontal scan
/// whose running minimum stays in a register to avoid store→load
/// forwarding stalls.
#[allow(unsafe_code)]
pub fn distance_transform(input: &Tensor) -> Result<Tensor, ImgProcError> {
    let (h, w, c) = hwc_shape(input)?;
    if c != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: c,
        });
    }
    let data = input.data();
    // Upper bound on any L1 distance inside the image (max is h + w - 2).
    let inf = (h + w) as f32;

    // Fast path: if no foreground pixels, distance is 0 everywhere.
    // (Also covers empty images, guaranteeing h >= 1 and w >= 1 below.)
    let has_foreground = data.iter().any(|&v| v > 0.5);
    if !has_foreground {
        return Tensor::from_vec(vec![h, w, 1], vec![0.0f32; h * w]).map_err(Into::into);
    }

    // Seed: background = 0, foreground = "infinity", to be relaxed below.
    let mut dist = vec![0.0f32; h * w];
    for i in 0..h * w {
        if data[i] > 0.5 {
            // SAFETY: i < h * w == dist.len().
            unsafe {
                *dist.as_mut_ptr().add(i) = inf;
            }
        }
    }

    // === Forward pass: top→bottom, vertical SIMD + horizontal register-scan ===
    // Key optimization: keep running min in a REGISTER, not memory.
    // Eliminates store→load forwarding latency (4→1 cycle per element).
    // Row 0 has no row above, so it gets only the left→right scan.
    {
        let p = dist.as_mut_ptr();
        // SAFETY: the image is non-empty (checked above) and x < w, so every
        // access stays inside row 0.
        unsafe {
            let mut run = *p; // register
            for x in 1..w {
                run += 1.0;
                let cur = *p.add(x);
                if run < cur {
                    *p.add(x) = run;
                } else {
                    run = cur;
                }
            }
        }
    }
    for y in 1..h {
        // Relax from the row above, then sweep left→right within the row.
        dt_vertical_min_forward(&mut dist, (y - 1) * w, y * w, w);
        let p = unsafe { dist.as_mut_ptr().add(y * w) };
        // SAFETY: y < h and x < w, so every access stays inside row y.
        unsafe {
            let mut run = *p;
            for x in 1..w {
                run += 1.0;
                let cur = *p.add(x);
                if run < cur {
                    *p.add(x) = run;
                } else {
                    run = cur;
                }
            }
        }
    }

    // === Backward pass: bottom→top, register-scan R→L ===
    // Mirror image of the forward pass; the last row has no row below.
    {
        let p = unsafe { dist.as_mut_ptr().add((h - 1) * w) };
        // SAFETY: h >= 1 and w >= 1 (non-empty image); every offset is < w.
        unsafe {
            let mut run = *p.add(w - 1);
            for x in (0..w.saturating_sub(1)).rev() {
                run += 1.0;
                let cur = *p.add(x);
                if run < cur {
                    *p.add(x) = run;
                } else {
                    run = cur;
                }
            }
        }
    }
    for y in (0..h.saturating_sub(1)).rev() {
        // Relax from the row below, then sweep right→left within the row.
        dt_vertical_min_forward(&mut dist, (y + 1) * w, y * w, w);
        let p = unsafe { dist.as_mut_ptr().add(y * w) };
        // SAFETY: y < h and every offset is < w.
        unsafe {
            let mut run = *p.add(w - 1);
            for x in (0..w.saturating_sub(1)).rev() {
                run += 1.0;
                let cur = *p.add(x);
                if run < cur {
                    *p.add(x) = run;
                } else {
                    run = cur;
                }
            }
        }
    }

    Tensor::from_vec(vec![h, w, 1], dist).map_err(Into::into)
}
812
813/// SIMD-accelerated vertical min step: cur[x] = min(cur[x], src[x] + 1) for all x.
814#[allow(unsafe_code)]
815fn dt_vertical_min_forward(dist: &mut [f32], src_start: usize, cur_start: usize, w: usize) {
816    let mut x = 0usize;
817
818    #[cfg(target_arch = "aarch64")]
819    {
820        if std::arch::is_aarch64_feature_detected!("neon") {
821            x = unsafe { dt_vertical_neon(dist, src_start, cur_start, w) };
822        }
823    }
824
825    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
826    {
827        if std::is_x86_feature_detected!("sse") {
828            x = unsafe { dt_vertical_sse(dist, src_start, cur_start, w) };
829        }
830    }
831
832    // Scalar tail
833    while x < w {
834        let src_val = dist[src_start + x] + 1.0;
835        if src_val < dist[cur_start + x] {
836            dist[cur_start + x] = src_val;
837        }
838        x += 1;
839    }
840}
841
/// NEON vertical-min kernel: `cur[x] = min(cur[x], src[x] + 1.0)` for the
/// leading 4-aligned prefix of the row.
///
/// Returns the number of elements processed (a multiple of 4, <= `w`); the
/// caller finishes the remainder with scalar code.
///
/// # Safety
/// Caller must ensure NEON is available and that `src_start + w` and
/// `cur_start + w` are within `dist`.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn dt_vertical_neon(
    dist: &mut [f32],
    src_start: usize,
    cur_start: usize,
    w: usize,
) -> usize {
    use std::arch::aarch64::*;
    let one = vdupq_n_f32(1.0);
    let sp = dist.as_ptr().add(src_start);
    let cp = dist.as_mut_ptr().add(cur_start);
    let mut x = 0usize;
    // 4× unrolled: 16 elements per iteration
    while x + 16 <= w {
        let s0 = vaddq_f32(vld1q_f32(sp.add(x)), one);
        let s1 = vaddq_f32(vld1q_f32(sp.add(x + 4)), one);
        let s2 = vaddq_f32(vld1q_f32(sp.add(x + 8)), one);
        let s3 = vaddq_f32(vld1q_f32(sp.add(x + 12)), one);
        vst1q_f32(cp.add(x), vminq_f32(vld1q_f32(cp.add(x)), s0));
        vst1q_f32(cp.add(x + 4), vminq_f32(vld1q_f32(cp.add(x + 4)), s1));
        vst1q_f32(cp.add(x + 8), vminq_f32(vld1q_f32(cp.add(x + 8)), s2));
        vst1q_f32(cp.add(x + 12), vminq_f32(vld1q_f32(cp.add(x + 12)), s3));
        x += 16;
    }
    // Remaining whole vectors of 4.
    while x + 4 <= w {
        let s = vaddq_f32(vld1q_f32(sp.add(x)), one);
        vst1q_f32(cp.add(x), vminq_f32(vld1q_f32(cp.add(x)), s));
        x += 4;
    }
    x
}
875
/// SSE vertical-min kernel: `cur[x] = min(cur[x], src[x] + 1.0)` for the
/// leading 4-aligned prefix of the row.
///
/// Returns the number of elements processed (a multiple of 4, <= `w`); the
/// caller finishes the remainder with scalar code.
///
/// # Safety
/// Caller must ensure SSE is available and that `src_start + w` and
/// `cur_start + w` are within `dist`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn dt_vertical_sse(dist: &mut [f32], src_start: usize, cur_start: usize, w: usize) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let one = _mm_set1_ps(1.0);
    let ptr = dist.as_mut_ptr();
    let mut x = 0usize;
    while x + 4 <= w {
        let src = _mm_loadu_ps(ptr.add(src_start + x));
        let cur = _mm_loadu_ps(ptr.add(cur_start + x));
        let src_plus_one = _mm_add_ps(src, one);
        let result = _mm_min_ps(cur, src_plus_one);
        _mm_storeu_ps(ptr.add(cur_start + x), result);
        x += 4;
    }
    x
}
898
899// ── ORB feature descriptors ────────────────────────────────────────
900
/// ORB descriptor: 256-bit binary descriptor stored as 32 bytes.
///
/// Each bit records one brightness comparison between a pair of points from
/// the `orb_pattern` sampling pattern. NOTE(review): exact bit packing
/// (bit order within bytes) is done in `orb_descriptors` — confirm there.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrbDescriptor {
    // Keypoint position, in the same (x, y) order as the `keypoints`
    // argument to `orb_descriptors`.
    pub keypoint: (usize, usize),
    // 256 comparison bits packed into 32 bytes.
    pub bits: [u8; 32],
}
907
/// Number of point-pair intensity comparisons (= descriptor bits) in the ORB
/// sampling pattern.
const ORB_PATTERN_LEN: usize = 256;

/// Builds the fixed pseudo-random sampling pattern of point pairs
/// `(a_x, a_y, b_x, b_y)` used by `orb_descriptors`.
///
/// A deterministic xorshift32 generator seeded with a constant guarantees the
/// pattern is identical on every run and platform; every coordinate offset
/// lies in `-15..=15`.
fn orb_pattern() -> Vec<(i32, i32, i32, i32)> {
    let mut state: u32 = 0x1234_5678;
    // One xorshift32 step, mapped onto a coordinate offset in -15..=15.
    let mut next_offset = || {
        state ^= state << 13;
        state ^= state >> 17;
        state ^= state << 5;
        ((state % 31) as i32) - 15
    };
    // Tuple fields evaluate left to right, preserving the draw order
    // a_x, a_y, b_x, b_y for each pair.
    (0..ORB_PATTERN_LEN)
        .map(|_| (next_offset(), next_offset(), next_offset(), next_offset()))
        .collect()
}
934
935/// Computes ORB descriptors for keypoints on a grayscale `[H,W,1]` image.
936pub fn orb_descriptors(
937    image: &Tensor,
938    keypoints: &[(usize, usize)],
939    patch_radius: usize,
940) -> Result<Vec<OrbDescriptor>, ImgProcError> {
941    let (h, w, c) = hwc_shape(image)?;
942    if c != 1 {
943        return Err(ImgProcError::InvalidChannelCount {
944            expected: 1,
945            got: c,
946        });
947    }
948    let data = image.data();
949    let pattern = orb_pattern();
950    let mut descriptors = Vec::new();
951    for &(kx, ky) in keypoints {
952        if kx < patch_radius
953            || ky < patch_radius
954            || kx + patch_radius >= w
955            || ky + patch_radius >= h
956        {
957            continue;
958        }
959        let mut bits = [0u8; 32];
960        for (i, &(ax, ay, bx, by)) in pattern.iter().enumerate() {
961            let pa = data[(ky as i32 + ay) as usize * w + (kx as i32 + ax) as usize];
962            let pb = data[(ky as i32 + by) as usize * w + (kx as i32 + bx) as usize];
963            if pa < pb {
964                bits[i / 8] |= 1 << (i % 8);
965            }
966        }
967        descriptors.push(OrbDescriptor {
968            keypoint: (kx, ky),
969            bits,
970        });
971    }
972    Ok(descriptors)
973}
974
975/// Hamming distance between two ORB descriptors.
976pub fn orb_hamming_distance(a: &OrbDescriptor, b: &OrbDescriptor) -> u32 {
977    a.bits
978        .iter()
979        .zip(b.bits.iter())
980        .map(|(&x, &y)| (x ^ y).count_ones())
981        .sum()
982}
983
984/// Brute-force ORB descriptor matching. Returns `(idx_a, idx_b, distance)`.
985pub fn orb_match(
986    desc_a: &[OrbDescriptor],
987    desc_b: &[OrbDescriptor],
988    max_distance: u32,
989) -> Vec<(usize, usize, u32)> {
990    let mut matches = Vec::new();
991    for (i, da) in desc_a.iter().enumerate() {
992        let mut best_j = 0;
993        let mut best_dist = u32::MAX;
994        for (j, db) in desc_b.iter().enumerate() {
995            let d = orb_hamming_distance(da, db);
996            if d < best_dist {
997                best_dist = d;
998                best_j = j;
999            }
1000        }
1001        if best_dist <= max_distance {
1002            matches.push((i, best_j, best_dist));
1003        }
1004    }
1005    matches
1006}
1007
1008// ── Shi-Tomasi good features to track ──────────────────────────────
1009
1010/// Shi-Tomasi corner detection (minimum eigenvalue approach).
1011///
1012/// Detects up to `max_corners` strong corners on a grayscale `[H, W, 1]` image.
1013/// `quality_level` (0..1) sets the threshold relative to the strongest corner response.
1014/// `min_distance` enforces minimum Euclidean distance between returned corners via
1015/// greedy non-maximum suppression.
1016///
1017/// Returns a vector of `(row, col)` coordinates.
1018pub fn good_features_to_track(
1019    img: &Tensor,
1020    max_corners: usize,
1021    quality_level: f32,
1022    min_distance: f32,
1023) -> Result<Vec<(usize, usize)>, ImgProcError> {
1024    let (h, w, c) = hwc_shape(img)?;
1025    if c != 1 {
1026        return Err(ImgProcError::InvalidChannelCount {
1027            expected: 1,
1028            got: c,
1029        });
1030    }
1031    let data = img.data();
1032
1033    // Compute gradients using [-1, 0, 1] kernel
1034    let mut ix = vec![0.0f32; h * w];
1035    let mut iy = vec![0.0f32; h * w];
1036    for y in 1..h - 1 {
1037        for x in 1..w - 1 {
1038            ix[y * w + x] = data[y * w + x + 1] - data[y * w + x - 1];
1039            iy[y * w + x] = data[(y + 1) * w + x] - data[(y - 1) * w + x];
1040        }
1041    }
1042
1043    // Compute structure tensor elements with 3x3 window sum, then min eigenvalue
1044    let mut min_eig = vec![0.0f32; h * w];
1045    let mut max_eig_val: f32 = 0.0;
1046
1047    for y in 1..h - 1 {
1048        for x in 1..w - 1 {
1049            let (mut sxx, mut sxy, mut syy) = (0.0f32, 0.0f32, 0.0f32);
1050            for dy in -1i32..=1 {
1051                for dx in -1i32..=1 {
1052                    let py = (y as i32 + dy) as usize;
1053                    let px = (x as i32 + dx) as usize;
1054                    let gx = ix[py * w + px];
1055                    let gy = iy[py * w + px];
1056                    sxx += gx * gx;
1057                    sxy += gx * gy;
1058                    syy += gy * gy;
1059                }
1060            }
1061            // Min eigenvalue of 2x2 matrix [[sxx, sxy], [sxy, syy]]
1062            let trace = sxx + syy;
1063            let det = sxx * syy - sxy * sxy;
1064            let disc = (trace * trace - 4.0 * det).max(0.0).sqrt();
1065            let lambda_min = (trace - disc) * 0.5;
1066            min_eig[y * w + x] = lambda_min;
1067            if lambda_min > max_eig_val {
1068                max_eig_val = lambda_min;
1069            }
1070        }
1071    }
1072
1073    // Threshold
1074    let thresh = quality_level * max_eig_val;
1075
1076    // Collect candidates above threshold
1077    let mut candidates: Vec<(usize, usize, f32)> = Vec::new();
1078    for y in 1..h - 1 {
1079        for x in 1..w - 1 {
1080            let e = min_eig[y * w + x];
1081            if e > thresh {
1082                candidates.push((y, x, e));
1083            }
1084        }
1085    }
1086
1087    // Sort by strength (strongest first)
1088    candidates.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
1089
1090    // Greedy non-maximum suppression by min_distance
1091    let min_dist_sq = min_distance * min_distance;
1092    let mut selected: Vec<(usize, usize)> = Vec::new();
1093    for (r, c_col, _) in &candidates {
1094        if selected.len() >= max_corners {
1095            break;
1096        }
1097        let too_close = selected.iter().any(|&(sr, sc)| {
1098            let dr = *r as f32 - sr as f32;
1099            let dc = *c_col as f32 - sc as f32;
1100            dr * dr + dc * dc < min_dist_sq
1101        });
1102        if !too_close {
1103            selected.push((*r, *c_col));
1104        }
1105    }
1106
1107    Ok(selected)
1108}
1109
1110// ── Sub-pixel corner refinement ────────────────────────────────────
1111
/// Refines corner locations to sub-pixel accuracy using a gradient-based method.
///
/// For each corner in `corners` (given as `(row, col)`), every pixel in the
/// `±win_size` window contributes a normal equation built from its image
/// gradient; solving the resulting 2x2 linear system yields the sub-pixel
/// offset. Corners whose window (plus one pixel for the gradient stencil)
/// does not fit inside the image, or whose system is numerically singular,
/// are returned at their original integer position.
///
/// Returns refined `(row, col)` as `f32` coordinates, one entry per input corner.
///
/// # Errors
/// Returns `ImgProcError::InvalidChannelCount` for multi-channel input.
pub fn corner_sub_pix(
    img: &Tensor,
    corners: &[(usize, usize)],
    win_size: usize,
) -> Result<Vec<(f32, f32)>, ImgProcError> {
    let (h, w, c) = hwc_shape(img)?;
    if c != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: c,
        });
    }
    let data = img.data();

    let mut refined = Vec::with_capacity(corners.len());

    for &(row, col) in corners {
        // Check if the window fits inside the image; the central-difference
        // gradient needs one extra pixel beyond the window on every side.
        if row < win_size + 1
            || row + win_size + 1 >= h
            || col < win_size + 1
            || col + win_size + 1 >= w
        {
            // Can't refine, return original position
            refined.push((row as f32, col as f32));
            continue;
        }

        // Build the normal equations for the sub-pixel shift:
        // for each pixel p in the window with gradient g = (gx, gy), we solve
        //   A * delta = b
        // where A = sum(g * g^T), b = sum(g * g^T * (p - corner)).
        // This is equivalent to minimizing the squared gradient projection error.
        let mut a00 = 0.0f32; // sum gx*gx
        let mut a01 = 0.0f32; // sum gx*gy
        let mut a11 = 0.0f32; // sum gy*gy
        let mut b0 = 0.0f32; // sum gx*gx*(px-col) + gx*gy*(py-row)
        let mut b1 = 0.0f32; // sum gx*gy*(px-col) + gy*gy*(py-row)

        let ws = win_size as i32;
        for dy in -ws..=ws {
            for dx in -ws..=ws {
                let py = (row as i32 + dy) as usize;
                let px = (col as i32 + dx) as usize;

                // Central-difference gradient at (py, px); in-bounds because
                // the border check above guarantees one pixel of margin.
                let gx = data[py * w + px + 1] - data[py * w + px - 1];
                let gy = data[(py + 1) * w + px] - data[(py - 1) * w + px];

                a00 += gx * gx;
                a01 += gx * gy;
                a11 += gy * gy;

                let dxf = dx as f32;
                let dyf = dy as f32;
                b0 += gx * gx * dxf + gx * gy * dyf;
                b1 += gx * gy * dxf + gy * gy * dyf;
            }
        }

        // Solve the 2x2 system [[a00, a01], [a01, a11]] * [dcol, drow] = [b0, b1]
        // by Cramer's rule.
        let det = a00 * a11 - a01 * a01;
        if det.abs() < 1e-10 {
            // Singular (flat or purely 1-D gradient structure): keep original.
            refined.push((row as f32, col as f32));
        } else {
            let inv_det = 1.0 / det;
            let dcol = (a11 * b0 - a01 * b1) * inv_det;
            let drow = (-a01 * b0 + a00 * b1) * inv_det;

            // Clamp offset to ±win_size to avoid wild jumps on noisy data.
            let dcol = dcol.clamp(-(win_size as f32), win_size as f32);
            let drow = drow.clamp(-(win_size as f32), win_size as f32);

            refined.push((row as f32 + drow, col as f32 + dcol));
        }
    }

    Ok(refined)
}
1198
1199// ── Gradient orientation ───────────────────────────────────────────
1200
1201/// Computes image gradients and orientation for descriptor construction.
1202/// Returns a gradient magnitude + orientation pair per pixel.
1203pub fn compute_gradient_orientation(image: &Tensor) -> Result<(Tensor, Tensor), ImgProcError> {
1204    let (h, w, c) = hwc_shape(image)?;
1205    if c != 1 {
1206        return Err(ImgProcError::InvalidChannelCount {
1207            expected: 1,
1208            got: c,
1209        });
1210    }
1211    let data = image.data();
1212    let mut mag = vec![0.0f32; h * w];
1213    let mut ori = vec![0.0f32; h * w];
1214
1215    for y in 1..h - 1 {
1216        for x in 1..w - 1 {
1217            let gx = data[y * w + x + 1] - data[y * w + x - 1];
1218            let gy = data[(y + 1) * w + x] - data[(y - 1) * w + x];
1219            mag[y * w + x] = (gx * gx + gy * gy).sqrt();
1220            ori[y * w + x] = gy.atan2(gx);
1221        }
1222    }
1223
1224    let mag_t = Tensor::from_vec(vec![h, w, 1], mag)?;
1225    let ori_t = Tensor::from_vec(vec![h, w, 1], ori)?;
1226    Ok((mag_t, ori_t))
1227}
1228
1229/// Histogram of Oriented Gradients (HOG) descriptor for a single cell.
1230///
1231/// `cell` is a grayscale patch `[cell_h, cell_w, 1]`, returns 9-bin orientation histogram.
1232pub fn hog_cell_descriptor(cell: &Tensor) -> Result<Vec<f32>, ImgProcError> {
1233    let (ch, cw, cc) = hwc_shape(cell)?;
1234    if cc != 1 {
1235        return Err(ImgProcError::InvalidChannelCount {
1236            expected: 1,
1237            got: cc,
1238        });
1239    }
1240    let data = cell.data();
1241    let mut bins = [0.0f32; 9];
1242    let bin_width = std::f32::consts::PI / 9.0;
1243
1244    for y in 1..ch.saturating_sub(1) {
1245        for x in 1..cw.saturating_sub(1) {
1246            let gx = data[y * cw + x + 1] - data[y * cw + x - 1];
1247            let gy = data[(y + 1) * cw + x] - data[(y - 1) * cw + x];
1248            let mag = (gx * gx + gy * gy).sqrt();
1249            let angle = gy.atan2(gx).rem_euclid(std::f32::consts::PI);
1250            let bin = ((angle / bin_width) as usize).min(8);
1251            bins[bin] += mag;
1252        }
1253    }
1254
1255    // L2-normalize
1256    let norm = bins.iter().map(|v| v * v).sum::<f32>().sqrt().max(1e-6);
1257    Ok(bins.iter().map(|v| v / norm).collect())
1258}
1259
1260// ── SIFT-like descriptor ───────────────────────────────────────────
1261
/// SIFT-like 128-element descriptor for a keypoint on a grayscale `[H,W,1]` image.
///
/// Computes a 4x4 grid of 8-bin orientation histograms in a 16x16 patch around
/// each keypoint (given here as `(row, col)`). Each descriptor is
/// L2-normalized, clamped at 0.2 and re-normalized. Keypoints too close to the
/// border keep an all-zero descriptor, so the output always has exactly one
/// entry per keypoint.
pub fn sift_descriptor(
    image: &Tensor,
    keypoints: &[(usize, usize)],
) -> Result<Vec<[f32; 128]>, ImgProcError> {
    let (h, w, c) = hwc_shape(image)?;
    if c != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: c,
        });
    }
    let data = image.data();
    let half_patch = 8; // 16x16 patch, radius 8

    let mut descriptors = Vec::with_capacity(keypoints.len());

    // NOTE(review): keypoints here are destructured as (row, col), while
    // `orb_descriptors` takes (x, y) — confirm call sites use the right order
    // for each API.
    for &(ky, kx) in keypoints {
        let mut desc = [0.0f32; 128];

        // Border keypoints keep the zero descriptor so indices stay aligned
        // with the input slice.
        if ky < half_patch || ky + half_patch >= h || kx < half_patch || kx + half_patch >= w {
            descriptors.push(desc);
            continue;
        }

        // 4x4 grid of 4x4-pixel cells, each producing an 8-bin histogram
        // (4 * 4 * 8 = 128 descriptor elements).
        for gy in 0..4 {
            for gx in 0..4 {
                let cell_y = (ky - half_patch) + gy * 4;
                let cell_x = (kx - half_patch) + gx * 4;
                let bin_offset = (gy * 4 + gx) * 8;

                for cy in 0..4 {
                    for cx in 0..4 {
                        let py = cell_y + cy;
                        let px = cell_x + cx;
                        // Extra per-pixel guard for the gradient stencil.
                        if py < 1 || py + 1 >= h || px < 1 || px + 1 >= w {
                            continue;
                        }
                        // Central-difference gradient; full-circle orientation
                        // folded into [0, 2π) and quantised into 8 bins.
                        let gx_val = data[py * w + px + 1] - data[py * w + px - 1];
                        let gy_val = data[(py + 1) * w + px] - data[(py - 1) * w + px];
                        let mag = (gx_val * gx_val + gy_val * gy_val).sqrt();
                        let angle = gy_val.atan2(gx_val).rem_euclid(2.0 * std::f32::consts::PI);
                        let bin = ((angle / (2.0 * std::f32::consts::PI) * 8.0) as usize).min(7);
                        desc[bin_offset + bin] += mag;
                    }
                }
            }
        }

        // L2-normalize, clamp each element at 0.2 (limits the influence of
        // large gradient magnitudes), then re-normalize.
        let norm1 = desc.iter().map(|v| v * v).sum::<f32>().sqrt().max(1e-7);
        for v in &mut desc {
            *v /= norm1;
        }
        for v in &mut desc {
            *v = v.min(0.2);
        }
        let norm2 = desc.iter().map(|v| v * v).sum::<f32>().sqrt().max(1e-7);
        for v in &mut desc {
            *v /= norm2;
        }

        descriptors.push(desc);
    }

    Ok(descriptors)
}
1334
/// Matches SIFT descriptors using Euclidean distance with Lowe's ratio test.
///
/// Returns `(idx_a, idx_b, distance)` pairs for which the ratio of the best
/// distance to the second-best distance is below `ratio_threshold`.
pub fn sift_match(
    desc_a: &[[f32; 128]],
    desc_b: &[[f32; 128]],
    ratio_threshold: f32,
) -> Vec<(usize, usize, f32)> {
    let mut accepted = Vec::new();

    for (ia, da) in desc_a.iter().enumerate() {
        // (index, distance) of the current best candidate, plus the distance
        // of the runner-up for the ratio test.
        let mut best = (0usize, f32::MAX);
        let mut runner_up = f32::MAX;

        for (ib, db) in desc_b.iter().enumerate() {
            let mut sq = 0.0f32;
            for k in 0..128 {
                let diff = da[k] - db[k];
                sq += diff * diff;
            }
            let dist = sq.sqrt();
            if dist < best.1 {
                runner_up = best.1;
                best = (ib, dist);
            } else if dist < runner_up {
                runner_up = dist;
            }
        }

        // Lowe's ratio test: keep only matches that are clearly better than
        // their nearest competitor.
        if runner_up > 0.0 && best.1 / runner_up < ratio_threshold {
            accepted.push((ia, best.0, best.1));
        }
    }

    accepted
}
1373
1374// ── Blob detection (Laplacian of Gaussian) ─────────────────────────
1375
/// Laplacian-of-Gaussian blob detection on a grayscale `[H, W, 1]` image.
///
/// Builds a scale space by applying Gaussian blur at `num_sigma` scales
/// linearly spaced between `min_sigma` and `max_sigma`, computing the
/// Laplacian at each scale (approximated with second differences), and
/// scale-normalising by `sigma²`. The absolute value of the Laplacian is
/// used, so both bright-on-dark and dark-on-bright blobs respond.
///
/// Local maxima in the 3×3×3 scale-space neighbourhood that exceed
/// `threshold` are returned as `(row, col, sigma)`.
pub fn blob_log(
    img: &Tensor,
    min_sigma: f32,
    max_sigma: f32,
    num_sigma: usize,
    threshold: f32,
) -> Result<Vec<(usize, usize, f32)>, ImgProcError> {
    let (h, w, c) = hwc_shape(img)?;
    if c != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: c,
        });
    }
    if num_sigma == 0 {
        return Ok(Vec::new());
    }

    // Linearly spaced sigma values in [min_sigma, max_sigma]; a single scale
    // degenerates to min_sigma only.
    let sigmas: Vec<f32> = if num_sigma == 1 {
        vec![min_sigma]
    } else {
        (0..num_sigma)
            .map(|i| min_sigma + (max_sigma - min_sigma) * i as f32 / (num_sigma - 1) as f32)
            .collect()
    };

    // Build scale-space: for each sigma, blur then compute the Laplacian,
    // then scale-normalise so responses are comparable across scales.
    let data = img.data();
    let n = h * w;
    let mut scale_space: Vec<Vec<f32>> = Vec::with_capacity(num_sigma);

    for &sigma in &sigmas {
        // Gaussian blur with the given sigma using separable convolution.
        let blurred = gaussian_blur_sigma(data, h, w, sigma);

        // Approximate Laplacian with second differences: d²/dx² + d²/dy².
        // Border pixels keep a zero response.
        let mut lap = vec![0.0f32; n];
        for y in 1..h.saturating_sub(1) {
            for x in 1..w.saturating_sub(1) {
                let idx = y * w + x;
                let dxx = blurred[idx + 1] - 2.0 * blurred[idx] + blurred[idx - 1];
                let dyy = blurred[idx + w] - 2.0 * blurred[idx] + blurred[idx - w];
                // |LoG| scale-normalised by sigma².
                lap[idx] = (dxx + dyy).abs() * sigma * sigma;
            }
        }
        scale_space.push(lap);
    }

    // Find local maxima in the 3×3×3 neighbourhood. A border of roughly
    // 3·sigma is excluded because the blur kernel is truncated there.
    let mut blobs = Vec::new();
    for s in 0..num_sigma {
        let border = (sigmas[s] * 3.0).ceil() as usize + 1;
        for y in border..h.saturating_sub(border) {
            for x in border..w.saturating_sub(border) {
                let val = scale_space[s][y * w + x];
                if val < threshold {
                    continue;
                }
                let mut is_max = true;
                // Strict-maximum test: any neighbour that ties (>=) suppresses
                // this candidate, so flat plateaus produce no blob at all.
                'outer: for ds in -1i32..=1 {
                    let si = s as i32 + ds;
                    if si < 0 || si >= num_sigma as i32 {
                        continue;
                    }
                    let si = si as usize;
                    for dy in -1i32..=1 {
                        for dx in -1i32..=1 {
                            if ds == 0 && dy == 0 && dx == 0 {
                                continue;
                            }
                            let ny = (y as i32 + dy) as usize;
                            let nx = (x as i32 + dx) as usize;
                            if ny < h && nx < w && scale_space[si][ny * w + nx] >= val {
                                is_max = false;
                                break 'outer;
                            }
                        }
                    }
                }
                if is_max {
                    blobs.push((y, x, sigmas[s]));
                }
            }
        }
    }

    Ok(blobs)
}
1475
/// Separable Gaussian blur with arbitrary sigma on flat single-channel data.
///
/// Kernel radius is `ceil(3·sigma)`. Samples that fall outside the image are
/// simply skipped (truncated kernel, no renormalisation at the borders), so
/// pixels near the edge are attenuated.
fn gaussian_blur_sigma(data: &[f32], h: usize, w: usize, sigma: f32) -> Vec<f32> {
    let radius = (sigma * 3.0).ceil() as usize;
    let taps = 2 * radius + 1;
    let denom = 2.0 * sigma * sigma;

    // Normalised 1-D Gaussian kernel.
    let mut kernel: Vec<f32> = (0..taps)
        .map(|i| {
            let x = i as f32 - radius as f32;
            (-x * x / denom).exp()
        })
        .collect();
    let total: f32 = kernel.iter().sum();
    for tap in kernel.iter_mut() {
        *tap /= total;
    }

    let pixels = h * w;

    // Pass 1: convolve every row.
    let mut horiz = vec![0.0f32; pixels];
    for row in 0..h {
        for col in 0..w {
            let mut sum = 0.0f32;
            for (k, &tap) in kernel.iter().enumerate() {
                let src = col as i32 + k as i32 - radius as i32;
                if src >= 0 && src < w as i32 {
                    sum += data[row * w + src as usize] * tap;
                }
            }
            horiz[row * w + col] = sum;
        }
    }

    // Pass 2: convolve every column of the intermediate image.
    let mut out = vec![0.0f32; pixels];
    for row in 0..h {
        for col in 0..w {
            let mut sum = 0.0f32;
            for (k, &tap) in kernel.iter().enumerate() {
                let src = row as i32 + k as i32 - radius as i32;
                if src >= 0 && src < h as i32 {
                    sum += horiz[src as usize * w + col] * tap;
                }
            }
            out[row * w + col] = sum;
        }
    }
    out
}
1525
1526// ── Hough circle transform ─────────────────────────────────────────
1527
1528/// Hough circle transform on a binary/edge single-channel `[H, W, 1]` image.
1529///
1530/// For each edge pixel (value > 0.5), votes are cast into an accumulator for
1531/// circles of every radius in `[min_radius, max_radius]`. Returns
1532/// `(center_row, center_col, radius)` for accumulator peaks above `threshold`.
1533pub fn hough_circles(
1534    img: &Tensor,
1535    min_radius: usize,
1536    max_radius: usize,
1537    threshold: usize,
1538) -> Result<Vec<(usize, usize, usize)>, ImgProcError> {
1539    let (h, w, c) = hwc_shape(img)?;
1540    if c != 1 {
1541        return Err(ImgProcError::InvalidChannelCount {
1542            expected: 1,
1543            got: c,
1544        });
1545    }
1546    if min_radius > max_radius || max_radius == 0 {
1547        return Ok(Vec::new());
1548    }
1549    let data = img.data();
1550    let num_radii = max_radius - min_radius + 1;
1551
1552    // Accumulator: [radius_idx][y * w + x]
1553    let mut acc = vec![vec![0u32; h * w]; num_radii];
1554
1555    // Pre-compute circle offsets for each radius
1556    for ri in 0..num_radii {
1557        let r = min_radius + ri;
1558        // Generate circle points using angular discretisation
1559        let circumference = (2.0 * std::f32::consts::PI * r as f32).ceil() as usize;
1560        let num_steps = circumference.max(36);
1561        let mut visited = std::collections::HashSet::new();
1562
1563        for step in 0..num_steps {
1564            let angle = 2.0 * std::f32::consts::PI * step as f32 / num_steps as f32;
1565            let dx = (r as f32 * angle.cos()).round() as i32;
1566            let dy = (r as f32 * angle.sin()).round() as i32;
1567            visited.insert((dy, dx));
1568        }
1569
1570        // Vote
1571        for y in 0..h {
1572            for x in 0..w {
1573                if data[y * w + x] <= 0.5 {
1574                    continue;
1575                }
1576                for &(dy, dx) in &visited {
1577                    let cy = y as i32 - dy;
1578                    let cx = x as i32 - dx;
1579                    if cy >= 0 && cy < h as i32 && cx >= 0 && cx < w as i32 {
1580                        acc[ri][cy as usize * w + cx as usize] += 1;
1581                    }
1582                }
1583            }
1584        }
1585    }
1586
1587    // Extract peaks above threshold with 3×3×3 non-maximum suppression
1588    let mut circles = Vec::new();
1589    for ri in 0..num_radii {
1590        for y in 0..h {
1591            for x in 0..w {
1592                let votes = acc[ri][y * w + x];
1593                if votes < threshold as u32 {
1594                    continue;
1595                }
1596                let mut is_max = true;
1597                'peak: for dri in -1i32..=1 {
1598                    let nri = ri as i32 + dri;
1599                    if nri < 0 || nri >= num_radii as i32 {
1600                        continue;
1601                    }
1602                    for dy in -1i32..=1 {
1603                        for dx in -1i32..=1 {
1604                            if dri == 0 && dy == 0 && dx == 0 {
1605                                continue;
1606                            }
1607                            let ny = y as i32 + dy;
1608                            let nx = x as i32 + dx;
1609                            if ny >= 0
1610                                && ny < h as i32
1611                                && nx >= 0
1612                                && nx < w as i32
1613                                && acc[nri as usize][ny as usize * w + nx as usize] >= votes
1614                            {
1615                                is_max = false;
1616                                break 'peak;
1617                            }
1618                        }
1619                    }
1620                }
1621                if is_max {
1622                    circles.push((y, x, min_radius + ri));
1623                }
1624            }
1625        }
1626    }
1627
1628    // Sort by votes descending
1629    circles.sort_by(|a, b| {
1630        let va = acc[a.2 - min_radius][a.0 * w + a.1];
1631        let vb = acc[b.2 - min_radius][b.0 * w + b.1];
1632        vb.cmp(&va)
1633    });
1634
1635    Ok(circles)
1636}