// yscv_imgproc/ops/orb.rs — ORB (oriented FAST + rotated BRIEF) feature detection and matching.
use yscv_tensor::Tensor;

use super::super::ImgProcError;
use super::super::shape::hwc_shape;
use super::brief::{BriefDescriptor, compute_rotated_brief, hamming_distance};
use super::fast::{Keypoint, fast9_detect_raw, intensity_centroid_angle};
7
/// ORB feature: oriented FAST keypoint paired with a rotated BRIEF descriptor.
#[derive(Debug, Clone)]
pub struct OrbFeature {
    /// Detected keypoint, with coordinates scaled back to the original
    /// (level-0) image by `detect_orb`.
    pub keypoint: Keypoint,
    /// Rotated BRIEF descriptor computed by `compute_rotated_brief` at the
    /// keypoint's orientation; compared with `hamming_distance`.
    pub descriptor: BriefDescriptor,
}
14
/// Configuration for the ORB feature detector.
#[derive(Debug, Clone)]
pub struct OrbConfig {
    /// Maximum number of features to retain (default 500).
    pub num_features: usize,
    /// Scale factor between pyramid levels (default 1.2).
    pub scale_factor: f32,
    /// Number of pyramid levels (default 8). `detect_orb` stops building
    /// levels early once either dimension would shrink below 10 pixels, so
    /// fewer levels may actually be used on small images.
    pub num_levels: usize,
    /// FAST threshold (default 20.0/255.0), in the image's normalized
    /// intensity units.
    pub fast_threshold: f32,
}
27
28impl Default for OrbConfig {
29    fn default() -> Self {
30        Self {
31            num_features: 500,
32            scale_factor: 1.2,
33            num_levels: 8,
34            fast_threshold: 20.0 / 255.0,
35        }
36    }
37}
38
39/// Detect ORB features: oriented FAST keypoints + rotated BRIEF descriptors.
40///
41/// Input must be a single-channel `[H, W, 1]` image.
42///
43/// 1. Build an image pyramid.
44/// 2. At each level, run FAST-9 detection.
45/// 3. Compute orientation via the intensity centroid method.
46/// 4. Compute rotated BRIEF descriptors.
47/// 5. Keep top `num_features` by response across all levels.
48/// 6. Scale coordinates back to the original image.
49pub fn detect_orb(image: &Tensor, config: &OrbConfig) -> Result<Vec<OrbFeature>, ImgProcError> {
50    let (h, w, c) = hwc_shape(image)?;
51    if c != 1 {
52        return Err(ImgProcError::InvalidChannelCount {
53            expected: 1,
54            got: c,
55        });
56    }
57
58    let mut all_features: Vec<OrbFeature> = Vec::new();
59
60    // Build pyramid levels — hierarchical: level N from level N-1.
61    // Nearest-neighbor from previous level (fast, each level reads smaller image).
62    let mut pyramid: Vec<(Vec<f32>, usize, usize, f32)> = Vec::new(); // (data, h, w, scale)
63    {
64        let base_data = image.data().to_vec();
65        pyramid.push((base_data, h, w, 1.0));
66        let sf = config.scale_factor;
67        for level in 1..config.num_levels {
68            let (ref prev_data, prev_h, prev_w, prev_scale) = pyramid[level - 1];
69            let nh = (prev_h as f32 / sf).round() as usize;
70            let nw = (prev_w as f32 / sf).round() as usize;
71            if nh < 10 || nw < 10 {
72                break;
73            }
74            let mut dst = vec![0.0f32; nh * nw];
75            let y_ratio = prev_h as f32 / nh as f32;
76            let x_ratio = prev_w as f32 / nw as f32;
77            for dy in 0..nh {
78                let sy = ((dy as f32 + 0.5) * y_ratio) as usize;
79                let sy = sy.min(prev_h - 1);
80                let src_row = sy * prev_w;
81                let dst_row = dy * nw;
82                for dx in 0..nw {
83                    let sx = ((dx as f32 + 0.5) * x_ratio) as usize;
84                    dst[dst_row + dx] = prev_data[src_row + sx.min(prev_w - 1)];
85                }
86            }
87            pyramid.push((dst, nh, nw, prev_scale * sf));
88        }
89    }
90
91    // Process all pyramid levels in parallel (FAST + orientation + BRIEF per level).
92    {
93        use rayon::prelude::*;
94        let fast_threshold = config.fast_threshold;
95        let level_features: Vec<Vec<OrbFeature>> = pyramid
96            .par_iter()
97            .enumerate()
98            .map(|(level, (pdata, ph, pw, scale))| {
99                let ph = *ph;
100                let pw = *pw;
101                let scale = *scale;
102                let mut keypoints = fast9_detect_raw(pdata, ph, pw, fast_threshold, true);
103
104                let centroid_radius = 15i32.min((ph.min(pw) / 2) as i32 - 1);
105                if centroid_radius > 0 {
106                    for kp in &mut keypoints {
107                        let kx = kp.x.round() as usize;
108                        let ky = kp.y.round() as usize;
109                        kp.angle = intensity_centroid_angle(pdata, pw, ph, kx, ky, centroid_radius);
110                        kp.octave = level;
111                    }
112                }
113
114                let descs = compute_rotated_brief(pdata, pw, ph, &keypoints);
115                let mut features = Vec::new();
116                for (kp, desc_opt) in keypoints.into_iter().zip(descs) {
117                    if let Some(descriptor) = desc_opt {
118                        let mut scaled_kp = kp;
119                        scaled_kp.x *= scale;
120                        scaled_kp.y *= scale;
121                        features.push(OrbFeature {
122                            keypoint: scaled_kp,
123                            descriptor,
124                        });
125                    }
126                }
127                features
128            })
129            .collect();
130
131        for features in level_features {
132            all_features.extend(features);
133        }
134    }
135
136    // Sort by response (descending) and keep top N
137    all_features.sort_by(|a, b| {
138        b.keypoint
139            .response
140            .partial_cmp(&a.keypoint.response)
141            .unwrap_or(std::cmp::Ordering::Equal)
142    });
143    all_features.truncate(config.num_features);
144
145    Ok(all_features)
146}
147
148/// Match ORB features between two sets using brute-force Hamming distance.
149///
150/// Returns pairs `(idx_a, idx_b)` for the best match of each feature in `a`
151/// against `b`, provided the distance is at most `max_distance`.
152pub fn match_features(
153    features_a: &[OrbFeature],
154    features_b: &[OrbFeature],
155    max_distance: u32,
156) -> Vec<(usize, usize)> {
157    let mut matches = Vec::new();
158    for (i, fa) in features_a.iter().enumerate() {
159        let mut best_j = 0usize;
160        let mut best_dist = u32::MAX;
161        for (j, fb) in features_b.iter().enumerate() {
162            let d = hamming_distance(&fa.descriptor, &fb.descriptor);
163            if d < best_dist {
164                best_dist = d;
165                best_j = j;
166            }
167        }
168        if best_dist <= max_distance {
169            matches.push((i, best_j));
170        }
171    }
172    matches
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: an image containing corner structures must yield at least
    /// one feature, each with a full-length descriptor.
    #[test]
    fn test_orb_detect_returns_features() {
        // 128x128 with multiple L-shaped corners well inside the BRIEF margin (21px)
        let size = 128;
        let mut data = vec![0.0f32; size * size];
        // Place several L-shaped bright features centered around the image
        // Each L: horizontal bar + vertical bar meeting at a corner
        let centers = [(50, 50), (50, 80), (80, 50), (80, 80)];
        for &(cy, cx) in &centers {
            // Horizontal bright bar
            for x in cx..cx + 12 {
                if x < size {
                    data[cy * size + x] = 1.0;
                }
            }
            // Vertical bright bar
            for y in cy..cy + 12 {
                if y < size {
                    data[y * size + cx] = 1.0;
                }
            }
        }
        let img = Tensor::from_vec(vec![size, size, 1], data).unwrap();
        // Single pyramid level + low threshold keeps the test fast and deterministic.
        let config = OrbConfig {
            num_features: 100,
            num_levels: 1,
            fast_threshold: 0.05,
            ..OrbConfig::default()
        };
        let features = detect_orb(&img, &config).unwrap();
        assert!(
            !features.is_empty(),
            "image with L-shaped corners should produce ORB features"
        );
        // Each feature should have a valid descriptor
        for f in &features {
            assert_eq!(f.descriptor.bits.len(), 32);
        }
    }

    /// `OrbConfig::default()` must report the documented default values.
    #[test]
    fn test_orb_config_defaults() {
        let cfg = OrbConfig::default();
        assert_eq!(cfg.num_features, 500);
        assert!((cfg.scale_factor - 1.2).abs() < 1e-6);
        assert_eq!(cfg.num_levels, 8);
        assert!((cfg.fast_threshold - 20.0 / 255.0).abs() < 1e-6);
    }

    /// Matching a feature set against itself at max distance 0 must produce
    /// exactly the identity pairing.
    #[test]
    fn test_match_features_self() {
        // Create textured image and detect features, then match against self
        let size = 64;
        // 4px checkerboard: corners at every tile boundary.
        let data: Vec<f32> = (0..size * size)
            .map(|i| {
                let x = i % size;
                let y = i / size;
                if (x / 4 + y / 4) % 2 == 0 { 0.9 } else { 0.1 }
            })
            .collect();
        let img = Tensor::from_vec(vec![size, size, 1], data).unwrap();
        let config = OrbConfig {
            num_features: 50,
            num_levels: 1,
            ..OrbConfig::default()
        };
        let features = detect_orb(&img, &config).unwrap();
        if features.is_empty() {
            // Can't test matching without features; skip gracefully
            return;
        }
        let matches = match_features(&features, &features, 0);
        // Every feature should match itself at distance 0
        assert_eq!(matches.len(), features.len());
        for &(a, b) in &matches {
            assert_eq!(a, b, "self-matching should produce identity pairs");
        }
    }
}