1use yscv_tensor::Tensor;
2
3use super::super::ImgProcError;
4use super::super::shape::hwc_shape;
5
/// A detected image feature point.
///
/// [`fast9_detect_raw`] fills `x`/`y` with the pixel column/row of the corner and
/// `response` with the length of the longest contiguous arc found by the segment
/// test; it always leaves `angle` at `0.0` and `octave` at `0`.
#[derive(Debug, Clone)]
pub struct Keypoint {
    /// Horizontal position (column index) in pixels.
    pub x: f32,
    /// Vertical position (row index) in pixels.
    pub y: f32,
    /// Detector strength; for FAST-9 this is the longest arc length.
    pub response: f32,
    /// Orientation of the keypoint.
    /// NOTE(review): presumably radians, as returned by `intensity_centroid_angle` — confirm with callers.
    pub angle: f32,
    /// NOTE(review): presumably the image-pyramid level the keypoint was found on — confirm with callers.
    pub octave: usize,
}
15
/// `(dx, dy)` offsets of the 16 pixels on the radius-3 circle used by the segment test.
///
/// The list starts at `(0, -3)` and visits neighbouring circle pixels in sequence, so
/// consecutive indices are adjacent on the circle and index 15 wraps around to index 0.
/// Indices 0, 4, 8 and 12 are the four axis-aligned ("cardinal") points.
const CIRCLE: [(i32, i32); 16] = [
    (0, -3),
    (1, -3),
    (2, -2),
    (3, -1),
    (3, 0),
    (3, 1),
    (2, 2),
    (1, 3),
    (0, 3),
    (-1, 3),
    (-2, 2),
    (-3, 1),
    (-3, 0),
    (-3, -1),
    (-2, -2),
    (-1, -3),
];
35
36#[inline]
38fn circle_offsets(w: usize) -> [isize; 16] {
39 let ws = w as isize;
40 let mut offsets = [0isize; 16];
41 for (i, &(dx, dy)) in CIRCLE.iter().enumerate() {
42 offsets[i] = dy as isize * ws + dx as isize;
43 }
44 offsets
45}
46
/// Returns the length of the longest circular run of set bits in the low 16 bits of
/// `mask`, capped at 16.
///
/// The mask is duplicated into the upper half-word so that a run wrapping from bit 15
/// back to bit 0 shows up as one contiguous run in the 32-bit value.
#[inline]
fn contiguous_run_from_mask(mask: u32) -> usize {
    if mask == 0 {
        return 0;
    }
    let wrapped = mask | (mask << 16);
    let (longest, _) = (0..32).fold((0u32, 0u32), |(longest, current), bit| {
        if (wrapped >> bit) & 1 != 0 {
            let extended = current + 1;
            (longest.max(extended), extended)
        } else {
            (longest, 0)
        }
    });
    longest.min(16) as usize
}
71
72#[allow(unsafe_code)]
83pub fn fast9_detect(
84 image: &Tensor,
85 threshold: f32,
86 non_max: bool,
87) -> Result<Vec<Keypoint>, ImgProcError> {
88 let (h, w, c) = hwc_shape(image)?;
89 if c != 1 {
90 return Err(ImgProcError::InvalidChannelCount {
91 expected: 1,
92 got: c,
93 });
94 }
95 Ok(fast9_detect_raw(image.data(), h, w, threshold, non_max))
96}
97
98pub fn fast9_detect_raw(
100 data: &[f32],
101 h: usize,
102 w: usize,
103 threshold: f32,
104 non_max: bool,
105) -> Vec<Keypoint> {
106 let offsets = circle_offsets(w);
107
108 let card = [offsets[0], offsets[4], offsets[8], offsets[12]];
110
111 let y_start = 3;
112 let y_end = h.saturating_sub(3);
113 let x_start = 3;
114 let x_end = w.saturating_sub(3);
115
116 let n_rows = y_end.saturating_sub(y_start);
118 let row_corners: Vec<Vec<Keypoint>> = {
119 use std::sync::Mutex;
120 let results: Vec<Mutex<Vec<Keypoint>>> =
121 (0..n_rows).map(|_| Mutex::new(Vec::new())).collect();
122
123 use super::u8ops::gcd;
124 gcd::parallel_for(n_rows, |row_idx| {
125 let y = y_start + row_idx;
126 let mut row_kps = Vec::new();
127
128 let row_base = y * w;
129 let mut x = x_start;
130
131 #[cfg(target_arch = "aarch64")]
132 if std::arch::is_aarch64_feature_detected!("neon") {
133 while x + 4 <= x_end {
134 let pass_mask =
135 unsafe { fast9_cardinal_check_neon(data, row_base + x, &card, threshold) };
136 if pass_mask == 0 {
137 x += 4;
138 continue;
139 }
140 for i in 0..4 {
141 if (pass_mask >> i) & 1 != 0 {
142 let cx = x + i;
143 let idx = row_base + cx;
144 let max_run =
145 unsafe { fast9_full_check(data, idx, &offsets, threshold) };
146 if max_run >= 9 {
147 row_kps.push(Keypoint {
148 x: cx as f32,
149 y: y as f32,
150 response: max_run as f32,
151 angle: 0.0,
152 octave: 0,
153 });
154 }
155 }
156 }
157 x += 4;
158 }
159 }
160
161 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
162 if std::is_x86_feature_detected!("sse") {
163 while x + 4 <= x_end {
164 let pass_mask =
165 unsafe { fast9_cardinal_check_sse(data, row_base + x, &card, threshold) };
166 if pass_mask == 0 {
167 x += 4;
168 continue;
169 }
170 for i in 0..4 {
171 if (pass_mask >> i) & 1 != 0 {
172 let cx = x + i;
173 let idx = row_base + cx;
174 let max_run =
175 unsafe { fast9_full_check(data, idx, &offsets, threshold) };
176 if max_run >= 9 {
177 row_kps.push(Keypoint {
178 x: cx as f32,
179 y: y as f32,
180 response: max_run as f32,
181 angle: 0.0,
182 octave: 0,
183 });
184 }
185 }
186 }
187 x += 4;
188 }
189 }
190
191 while x < x_end {
192 let idx = row_base + x;
193 let center = unsafe { *data.get_unchecked(idx) };
194 let bright_thresh = center + threshold;
195 let dark_thresh = center - threshold;
196 let mut bright_count = 0u32;
197 let mut dark_count = 0u32;
198 for &co in &card {
199 let v = unsafe { *data.get_unchecked((idx as isize + co) as usize) };
200 bright_count += (v > bright_thresh) as u32;
201 dark_count += (v < dark_thresh) as u32;
202 }
203 if bright_count < 3 && dark_count < 3 {
204 x += 1;
205 continue;
206 }
207 let max_run = unsafe { fast9_full_check(data, idx, &offsets, threshold) };
208 if max_run >= 9 {
209 row_kps.push(Keypoint {
210 x: x as f32,
211 y: y as f32,
212 response: max_run as f32,
213 angle: 0.0,
214 octave: 0,
215 });
216 }
217 x += 1;
218 }
219
220 *results[row_idx].lock().expect("mutex poisoned") = row_kps;
221 });
222
223 results
224 .into_iter()
225 .map(|m| m.into_inner().expect("mutex poisoned"))
226 .collect()
227 };
228
229 let mut corners: Vec<Keypoint> = row_corners.into_iter().flatten().collect();
230
231 #[allow(unreachable_code)]
233 if false {
234 let y = y_start;
235 for _y in y_start..y_end {
236 let row_base = y * w;
237 let mut x = x_start;
238
239 #[cfg(target_arch = "aarch64")]
242 if std::arch::is_aarch64_feature_detected!("neon") {
243 while x + 4 <= x_end {
244 let pass_mask =
245 unsafe { fast9_cardinal_check_neon(data, row_base + x, &card, threshold) };
246 if pass_mask == 0 {
248 x += 4;
249 continue;
250 }
251 for i in 0..4 {
253 if (pass_mask >> i) & 1 != 0 {
254 let cx = x + i;
255 let idx = row_base + cx;
256 let max_run =
257 unsafe { fast9_full_check(data, idx, &offsets, threshold) };
258 if max_run >= 9 {
259 corners.push(Keypoint {
260 x: cx as f32,
261 y: y as f32,
262 response: max_run as f32,
263 angle: 0.0,
264 octave: 0,
265 });
266 }
267 }
268 }
269 x += 4;
270 }
271 }
272
273 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
274 if std::is_x86_feature_detected!("sse") {
275 while x + 4 <= x_end {
276 let pass_mask =
277 unsafe { fast9_cardinal_check_sse(data, row_base + x, &card, threshold) };
278 if pass_mask == 0 {
279 x += 4;
280 continue;
281 }
282 for i in 0..4 {
283 if (pass_mask >> i) & 1 != 0 {
284 let cx = x + i;
285 let idx = row_base + cx;
286 let max_run =
287 unsafe { fast9_full_check(data, idx, &offsets, threshold) };
288 if max_run >= 9 {
289 corners.push(Keypoint {
290 x: cx as f32,
291 y: y as f32,
292 response: max_run as f32,
293 angle: 0.0,
294 octave: 0,
295 });
296 }
297 }
298 }
299 x += 4;
300 }
301 }
302
303 while x < x_end {
305 let idx = row_base + x;
306 let center = unsafe { *data.get_unchecked(idx) };
307 let bright_thresh = center + threshold;
308 let dark_thresh = center - threshold;
309
310 let mut bright_count = 0u32;
311 let mut dark_count = 0u32;
312 for &co in &card {
313 let v = unsafe { *data.get_unchecked((idx as isize + co) as usize) };
314 bright_count += (v > bright_thresh) as u32;
315 dark_count += (v < dark_thresh) as u32;
316 }
317 if bright_count < 3 && dark_count < 3 {
318 x += 1;
319 continue;
320 }
321
322 let max_run = unsafe { fast9_full_check(data, idx, &offsets, threshold) };
323 if max_run >= 9 {
324 corners.push(Keypoint {
325 x: x as f32,
326 y: y as f32,
327 response: max_run as f32,
328 angle: 0.0,
329 octave: 0,
330 });
331 }
332 x += 1;
333 }
334 }
335 } if non_max {
338 let mut response_map = vec![0.0f32; h * w];
339 for kp in &corners {
340 let ix = kp.x as usize;
341 let iy = kp.y as usize;
342 response_map[iy * w + ix] = kp.response;
343 }
344 corners.retain(|kp| {
345 let ix = kp.x as usize;
346 let iy = kp.y as usize;
347 for dy in -1i32..=1 {
348 for dx in -1i32..=1 {
349 if dy == 0 && dx == 0 {
350 continue;
351 }
352 let ny = (iy as i32 + dy) as usize;
353 let nx = (ix as i32 + dx) as usize;
354 if ny < h && nx < w && response_map[ny * w + nx] > kp.response {
355 return false;
356 }
357 }
358 }
359 true
360 });
361 }
362
363 corners
364}
365
366#[cfg(target_arch = "aarch64")]
369#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
370#[target_feature(enable = "neon")]
371unsafe fn fast9_cardinal_check_neon(
372 data: &[f32],
373 base_idx: usize,
374 card: &[isize; 4],
375 threshold: f32,
376) -> u32 {
377 use std::arch::aarch64::*;
378
379 let ptr = data.as_ptr();
380 let thresh = vdupq_n_f32(threshold);
381 let neg_thresh = vdupq_n_f32(-threshold);
382 let three = vdupq_n_u32(3);
383
384 let centers = vld1q_f32(ptr.add(base_idx));
386 let bright_thresh = vaddq_f32(centers, thresh);
387 let dark_thresh = vaddq_f32(centers, neg_thresh);
388
389 let mut bright_cnt = vdupq_n_u32(0);
392 let mut dark_cnt = vdupq_n_u32(0);
393
394 for &co in card.iter() {
395 let circle_px = vld1q_f32(ptr.add((base_idx as isize + co) as usize));
396 let b = vcgtq_f32(circle_px, bright_thresh);
398 bright_cnt = vsubq_u32(bright_cnt, vreinterpretq_u32_f32(vreinterpretq_f32_u32(b)));
399 let d = vcltq_f32(circle_px, dark_thresh);
401 dark_cnt = vsubq_u32(dark_cnt, vreinterpretq_u32_f32(vreinterpretq_f32_u32(d)));
402 }
403
404 let bright_pass = vcgeq_u32(bright_cnt, three);
406 let dark_pass = vcgeq_u32(dark_cnt, three);
407 let pass = vorrq_u32(bright_pass, dark_pass);
408
409 let mut mask = 0u32;
411 if vgetq_lane_u32(pass, 0) != 0 {
412 mask |= 1;
413 }
414 if vgetq_lane_u32(pass, 1) != 0 {
415 mask |= 2;
416 }
417 if vgetq_lane_u32(pass, 2) != 0 {
418 mask |= 4;
419 }
420 if vgetq_lane_u32(pass, 3) != 0 {
421 mask |= 8;
422 }
423 mask
424}
425
426#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
428#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
429#[target_feature(enable = "sse")]
430unsafe fn fast9_cardinal_check_sse(
431 data: &[f32],
432 base_idx: usize,
433 card: &[isize; 4],
434 threshold: f32,
435) -> u32 {
436 #[cfg(target_arch = "x86")]
437 use std::arch::x86::*;
438 #[cfg(target_arch = "x86_64")]
439 use std::arch::x86_64::*;
440
441 let ptr = data.as_ptr();
442 let thresh = _mm_set1_ps(threshold);
443 let neg_thresh = _mm_set1_ps(-threshold);
444 let _zero = _mm_setzero_ps();
445
446 let centers = _mm_loadu_ps(ptr.add(base_idx));
447 let bright_thresh = _mm_add_ps(centers, thresh);
448 let dark_thresh = _mm_add_ps(centers, neg_thresh);
449
450 let mut bright_cnt = _mm_setzero_ps();
452 let mut dark_cnt = _mm_setzero_ps();
453 let one_bits = _mm_set1_ps(1.0);
454
455 for &co in card.iter() {
456 let circle_px = _mm_loadu_ps(ptr.add((base_idx as isize + co) as usize));
457 let b = _mm_and_ps(_mm_cmpgt_ps(circle_px, bright_thresh), one_bits);
459 bright_cnt = _mm_add_ps(bright_cnt, b);
460 let d = _mm_and_ps(_mm_cmplt_ps(circle_px, dark_thresh), one_bits);
461 dark_cnt = _mm_add_ps(dark_cnt, d);
462 }
463
464 let three = _mm_set1_ps(3.0);
465 let bright_pass = _mm_cmpge_ps(bright_cnt, three);
466 let dark_pass = _mm_cmpge_ps(dark_cnt, three);
467 let pass = _mm_or_ps(bright_pass, dark_pass);
468
469 _mm_movemask_ps(pass) as u32
470}
471
472#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
474#[inline]
475unsafe fn fast9_full_check(
476 data: &[f32],
477 idx: usize,
478 offsets: &[isize; 16],
479 threshold: f32,
480) -> usize {
481 let center = *data.get_unchecked(idx);
482 let bright_thresh = center + threshold;
483 let dark_thresh = center - threshold;
484
485 let mut bright_mask = 0u32;
486 let mut dark_mask = 0u32;
487 for i in 0..16 {
488 let v = *data.get_unchecked((idx as isize + offsets[i]) as usize);
489 if v > bright_thresh {
490 bright_mask |= 1 << i;
491 }
492 if v < dark_thresh {
493 dark_mask |= 1 << i;
494 }
495 }
496
497 let bright_run = contiguous_run_from_mask(bright_mask);
498 let dark_run = contiguous_run_from_mask(dark_mask);
499 bright_run.max(dark_run)
500}
501
/// Computes the intensity-centroid orientation of the patch around `(kx, ky)`.
///
/// Accumulates the first-order moments `m10 = Σ dx·v` and `m01 = Σ dy·v` over the
/// pixels of a disc of the given `radius` that fall inside the `w`x`h` image, and
/// returns `atan2(m01, m10)` in radians. Note the `(w, h)` parameter order.
///
/// # Panics
/// Panics if `data` is too short for an in-image pixel index (`py * w + px`).
pub(crate) fn intensity_centroid_angle(
    data: &[f32],
    w: usize,
    h: usize,
    kx: usize,
    ky: usize,
    radius: i32,
) -> f32 {
    let (mut m10, mut m01) = (0.0f32, 0.0f32);
    for dy in -radius..=radius {
        // Half-width of the disc on this row.
        let half_width = ((radius * radius - dy * dy) as f32).sqrt() as i32;
        let py = ky as i32 + dy;
        if py < 0 || py >= h as i32 {
            continue;
        }
        let row_start = py as usize * w;
        for dx in -half_width..=half_width {
            let px = kx as i32 + dx;
            if px < 0 || px >= w as i32 {
                continue;
            }
            let v = data[row_start + px as usize];
            m10 += dx as f32 * v;
            m01 += dy as f32 * v;
        }
    }
    m01.atan2(m10)
}
528
/// Returns the longest circular run of `true` values in `flags`, capped at 16.
///
/// The array is walked twice in a row so that a run wrapping from the last element
/// back to the first is counted as one run.
#[allow(dead_code)]
fn max_consecutive(flags: &[bool; 16]) -> usize {
    let mut longest = 0usize;
    let mut streak = 0usize;
    for &set in flags.iter().cycle().take(32) {
        streak = if set { streak + 1 } else { 0 };
        longest = longest.max(streak);
    }
    longest.min(16)
}
547
#[cfg(test)]
mod tests {
    use super::*;

    /// Side length of the square synthetic images built by the tests below.
    const SIDE: usize = 30;

    /// Two crossing bright bars on a black background must yield at least one corner.
    #[test]
    fn test_fast9_detects_corner() {
        let mut pixels = vec![0.0f32; SIDE * SIDE];
        // Horizontal bar on row 15, columns 10..20.
        pixels[15 * SIDE + 10..15 * SIDE + 20].fill(1.0);
        // Vertical bar on column 10, rows 10..20.
        for row in pixels.chunks_exact_mut(SIDE).skip(10).take(10) {
            row[10] = 1.0;
        }
        let img = Tensor::from_vec(vec![30, 30, 1], pixels).unwrap();
        let kps = fast9_detect(&img, 0.3, false).unwrap();
        assert!(!kps.is_empty(), "should detect corners near the L-shape");
    }

    /// A constant image has no intensity differences and therefore no corners.
    #[test]
    fn test_fast9_no_corners_on_flat() {
        let flat = vec![0.5; 400];
        let img = Tensor::from_vec(vec![20, 20, 1], flat).unwrap();
        let kps = fast9_detect(&img, 0.1, false).unwrap();
        assert!(kps.is_empty(), "flat image should produce no corners");
    }

    /// Raising the threshold must never increase the number of detected corners.
    #[test]
    fn test_fast9_threshold() {
        let mut pixels = vec![0.0f32; SIDE * SIDE];
        // Bright centre surrounded by a dimmer ring on the test circle.
        pixels[15 * SIDE + 15] = 1.0;
        for (dx, dy) in CIRCLE.iter().copied() {
            let col = (15 + dx) as usize;
            let row = (15 + dy) as usize;
            pixels[row * SIDE + col] = 0.6;
        }
        let img = Tensor::from_vec(vec![30, 30, 1], pixels).unwrap();
        let low = fast9_detect(&img, 0.1, false).unwrap();
        let high = fast9_detect(&img, 0.8, false).unwrap();
        assert!(
            high.len() <= low.len(),
            "higher threshold should produce fewer or equal corners: low={} high={}",
            low.len(),
            high.len()
        );
    }
}