1use yscv_tensor::Tensor;
2
3use super::super::ImgProcError;
4use super::super::shape::hwc_shape;
5use super::brief::{BriefDescriptor, compute_rotated_brief, hamming_distance};
6use super::fast::{Keypoint, fast9_detect_raw, intensity_centroid_angle};
7
/// An ORB feature: an oriented FAST keypoint together with its rotated-BRIEF
/// descriptor, as produced by [`detect_orb`].
#[derive(Debug, Clone)]
pub struct OrbFeature {
    /// Keypoint location, orientation and response. `detect_orb` multiplies the
    /// coordinates by the pyramid level's scale so they refer to the input
    /// image's pixel grid, and stores the pyramid level index in `octave`.
    pub keypoint: Keypoint,
    /// Rotated BRIEF descriptor computed on the keypoint's pyramid level;
    /// descriptors are compared with Hamming distance (see [`match_features`]).
    pub descriptor: BriefDescriptor,
}
14
/// Parameters controlling [`detect_orb`].
#[derive(Debug, Clone)]
pub struct OrbConfig {
    /// Maximum number of features returned; the features with the highest
    /// FAST response across all pyramid levels are kept.
    pub num_features: usize,
    /// Downscaling ratio between consecutive pyramid levels: each level's
    /// height and width are the previous level's divided by this value and
    /// rounded. Should be greater than `1.0` so that levels actually shrink.
    pub scale_factor: f32,
    /// Maximum number of pyramid levels, including the full-resolution image.
    /// Fewer levels are built when a level would drop below 10 pixels per side.
    pub num_levels: usize,
    /// FAST-9 corner threshold, on the same intensity scale as the input image.
    pub fast_threshold: f32,
}

impl Default for OrbConfig {
    /// Returns the same values as OpenCV's `ORB::create` defaults: 500 features,
    /// an 8-level pyramid with a 1.2 scale step and a FAST threshold of 20
    /// (expressed here as 20/255).
    fn default() -> Self {
        // 20 on an 8-bit scale, rescaled for intensities normalised to [0, 1]
        // (assumption: callers feed normalised data, as the unit tests do).
        let fast_threshold = 20.0 / 255.0;
        OrbConfig {
            num_levels: 8,
            scale_factor: 1.2,
            num_features: 500,
            fast_threshold,
        }
    }
}
38
39pub fn detect_orb(image: &Tensor, config: &OrbConfig) -> Result<Vec<OrbFeature>, ImgProcError> {
50 let (h, w, c) = hwc_shape(image)?;
51 if c != 1 {
52 return Err(ImgProcError::InvalidChannelCount {
53 expected: 1,
54 got: c,
55 });
56 }
57
58 let mut all_features: Vec<OrbFeature> = Vec::new();
59
60 let mut pyramid: Vec<(Vec<f32>, usize, usize, f32)> = Vec::new(); {
64 let base_data = image.data().to_vec();
65 pyramid.push((base_data, h, w, 1.0));
66 let sf = config.scale_factor;
67 for level in 1..config.num_levels {
68 let (ref prev_data, prev_h, prev_w, prev_scale) = pyramid[level - 1];
69 let nh = (prev_h as f32 / sf).round() as usize;
70 let nw = (prev_w as f32 / sf).round() as usize;
71 if nh < 10 || nw < 10 {
72 break;
73 }
74 let mut dst = vec![0.0f32; nh * nw];
75 let y_ratio = prev_h as f32 / nh as f32;
76 let x_ratio = prev_w as f32 / nw as f32;
77 for dy in 0..nh {
78 let sy = ((dy as f32 + 0.5) * y_ratio) as usize;
79 let sy = sy.min(prev_h - 1);
80 let src_row = sy * prev_w;
81 let dst_row = dy * nw;
82 for dx in 0..nw {
83 let sx = ((dx as f32 + 0.5) * x_ratio) as usize;
84 dst[dst_row + dx] = prev_data[src_row + sx.min(prev_w - 1)];
85 }
86 }
87 pyramid.push((dst, nh, nw, prev_scale * sf));
88 }
89 }
90
91 {
93 use rayon::prelude::*;
94 let fast_threshold = config.fast_threshold;
95 let level_features: Vec<Vec<OrbFeature>> = pyramid
96 .par_iter()
97 .enumerate()
98 .map(|(level, (pdata, ph, pw, scale))| {
99 let ph = *ph;
100 let pw = *pw;
101 let scale = *scale;
102 let mut keypoints = fast9_detect_raw(pdata, ph, pw, fast_threshold, true);
103
104 let centroid_radius = 15i32.min((ph.min(pw) / 2) as i32 - 1);
105 if centroid_radius > 0 {
106 for kp in &mut keypoints {
107 let kx = kp.x.round() as usize;
108 let ky = kp.y.round() as usize;
109 kp.angle = intensity_centroid_angle(pdata, pw, ph, kx, ky, centroid_radius);
110 kp.octave = level;
111 }
112 }
113
114 let descs = compute_rotated_brief(pdata, pw, ph, &keypoints);
115 let mut features = Vec::new();
116 for (kp, desc_opt) in keypoints.into_iter().zip(descs) {
117 if let Some(descriptor) = desc_opt {
118 let mut scaled_kp = kp;
119 scaled_kp.x *= scale;
120 scaled_kp.y *= scale;
121 features.push(OrbFeature {
122 keypoint: scaled_kp,
123 descriptor,
124 });
125 }
126 }
127 features
128 })
129 .collect();
130
131 for features in level_features {
132 all_features.extend(features);
133 }
134 }
135
136 all_features.sort_by(|a, b| {
138 b.keypoint
139 .response
140 .partial_cmp(&a.keypoint.response)
141 .unwrap_or(std::cmp::Ordering::Equal)
142 });
143 all_features.truncate(config.num_features);
144
145 Ok(all_features)
146}
147
148pub fn match_features(
153 features_a: &[OrbFeature],
154 features_b: &[OrbFeature],
155 max_distance: u32,
156) -> Vec<(usize, usize)> {
157 let mut matches = Vec::new();
158 for (i, fa) in features_a.iter().enumerate() {
159 let mut best_j = 0usize;
160 let mut best_dist = u32::MAX;
161 for (j, fb) in features_b.iter().enumerate() {
162 let d = hamming_distance(&fa.descriptor, &fb.descriptor);
163 if d < best_dist {
164 best_dist = d;
165 best_j = j;
166 }
167 }
168 if best_dist <= max_distance {
169 matches.push((i, best_j));
170 }
171 }
172 matches
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_orb_detect_returns_features() {
        let size = 128;
        // Draw four L-shaped strokes (a 12 px horizontal arm and a 12 px
        // vertical arm meeting at each corner point) on a black background.
        let mut data = vec![0.0f32; size * size];
        let centers = [(50, 50), (50, 80), (80, 50), (80, 80)];
        for &(cy, cx) in &centers {
            // Horizontal arm.
            for x in cx..cx + 12 {
                if x < size {
                    data[cy * size + x] = 1.0;
                }
            }
            // Vertical arm.
            for y in cy..cy + 12 {
                if y < size {
                    data[y * size + cx] = 1.0;
                }
            }
        }
        let img = Tensor::from_vec(vec![size, size, 1], data).unwrap();
        let config = OrbConfig {
            num_features: 100,
            num_levels: 1,
            fast_threshold: 0.05,
            ..OrbConfig::default()
        };
        let features = detect_orb(&img, &config).unwrap();
        assert!(
            !features.is_empty(),
            "image with L-shaped corners should produce ORB features"
        );
        for f in &features {
            // Every descriptor has the same fixed length.
            assert_eq!(f.descriptor.bits.len(), 32);
        }
    }

    #[test]
    fn test_orb_config_defaults() {
        let cfg = OrbConfig::default();
        assert_eq!(cfg.num_features, 500);
        assert!((cfg.scale_factor - 1.2).abs() < 1e-6);
        assert_eq!(cfg.num_levels, 8);
        assert!((cfg.fast_threshold - 20.0 / 255.0).abs() < 1e-6);
    }

    #[test]
    fn test_match_features_self() {
        let size = 64;
        // Checkerboard of 4x4-pixel squares alternating between 0.9 and 0.1.
        let data: Vec<f32> = (0..size * size)
            .map(|i| {
                let x = i % size;
                let y = i / size;
                if (x / 4 + y / 4) % 2 == 0 { 0.9 } else { 0.1 }
            })
            .collect();
        let img = Tensor::from_vec(vec![size, size, 1], data).unwrap();
        let config = OrbConfig {
            num_features: 50,
            num_levels: 1,
            ..OrbConfig::default()
        };
        let features = detect_orb(&img, &config).unwrap();
        if features.is_empty() {
            // A regular checkerboard may produce no FAST corners; nothing to match.
            return;
        }
        let matches = match_features(&features, &features, 0);
        // Every feature is at distance 0 from itself, so every one is matched.
        assert_eq!(matches.len(), features.len());
        for &(a, b) in &matches {
            // The pattern is periodic, so distinct features can share an
            // identical descriptor; the matcher then reports the first such
            // feature, which may precede `a` but must be an exact match.
            assert!(b <= a, "self-matching should never pick a later duplicate");
            assert_eq!(
                hamming_distance(&features[a].descriptor, &features[b].descriptor),
                0,
                "self-matching should produce exact (distance-0) pairs"
            );
        }
    }
}
257}