Skip to main content

yscv_imgproc/ops/
brief.rs

1use yscv_tensor::Tensor;
2
3use super::super::ImgProcError;
4use super::super::shape::hwc_shape;
5use super::fast::Keypoint;
6
/// BRIEF descriptor: 256-bit binary descriptor stored as 32 bytes.
///
/// Bit `i` of the descriptor is stored in `bits[i / 8]` at bit position
/// `i % 8`, counting from the least-significant bit of each byte.
///
/// Equality and hashing are both derived from `bits`, so descriptors can be
/// de-duplicated or used as keys in `HashSet` / `HashMap`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BriefDescriptor {
    /// The 256 comparison results, packed eight per byte.
    pub bits: [u8; 32],
}
12
/// Builds the canonical BRIEF sampling pattern: 256 point pairs in a 31x31 patch.
///
/// Every entry is `(ax, ay, bx, by)`: two pixel offsets measured from the
/// keypoint centre, each coordinate lying in `-15..=15`. The offsets come from
/// a 32-bit xorshift generator with a fixed seed, so every call returns the
/// same table.
fn brief_pattern() -> Vec<(i32, i32, i32, i32)> {
    let mut state: u32 = 0xDEAD_BEEF;

    // One xorshift32 step (shifts 13 / 17 / 5), folded into the range -15..=15.
    let mut next_offset = move || -> i32 {
        state ^= state << 13;
        state ^= state >> 17;
        state ^= state << 5;
        (state % 31) as i32 - 15
    };

    (0..256)
        .map(|_| {
            // The draw order fixes the pattern: ax, ay, bx, by.
            let ax = next_offset();
            let ay = next_offset();
            let bx = next_offset();
            let by = next_offset();
            (ax, ay, bx, by)
        })
        .collect()
}
40
41/// Compute the rotated BRIEF pattern for a given angle (radians).
42/// Returns the pattern with each point rotated around the origin.
43fn rotated_brief_pattern(angle: f32) -> Vec<(i32, i32, i32, i32)> {
44    let base = brief_pattern();
45    let cos_a = angle.cos();
46    let sin_a = angle.sin();
47    base.iter()
48        .map(|&(ax, ay, bx, by)| {
49            let rax = (ax as f32 * cos_a - ay as f32 * sin_a).round() as i32;
50            let ray = (ax as f32 * sin_a + ay as f32 * cos_a).round() as i32;
51            let rbx = (bx as f32 * cos_a - by as f32 * sin_a).round() as i32;
52            let rby = (bx as f32 * sin_a + by as f32 * cos_a).round() as i32;
53            (rax, ray, rbx, rby)
54        })
55        .collect()
56}
57
58/// Compute BRIEF descriptors for keypoints on a grayscale `[H, W, 1]` image.
59///
60/// Uses a fixed set of 256 point-pair comparisons in a 31x31 patch.
61/// Keypoints too close to the image border are skipped (not included in output).
62pub fn compute_brief(
63    image: &Tensor,
64    keypoints: &[Keypoint],
65) -> Result<Vec<BriefDescriptor>, ImgProcError> {
66    let (h, w, c) = hwc_shape(image)?;
67    if c != 1 {
68        return Err(ImgProcError::InvalidChannelCount {
69            expected: 1,
70            got: c,
71        });
72    }
73    let data = image.data();
74    let pattern = brief_pattern();
75    let patch_radius = 15usize; // 31x31 patch
76
77    let mut descriptors = Vec::with_capacity(keypoints.len());
78
79    for kp in keypoints {
80        let kx = kp.x.round() as i32;
81        let ky = kp.y.round() as i32;
82
83        // Skip if too close to border
84        if kx < patch_radius as i32
85            || ky < patch_radius as i32
86            || kx + patch_radius as i32 >= w as i32
87            || ky + patch_radius as i32 >= h as i32
88        {
89            continue;
90        }
91
92        let mut bits = [0u8; 32];
93        for (i, &(ax, ay, bx, by)) in pattern.iter().enumerate() {
94            let pa = data[(ky + ay) as usize * w + (kx + ax) as usize];
95            let pb = data[(ky + by) as usize * w + (kx + bx) as usize];
96            if pa < pb {
97                bits[i / 8] |= 1 << (i % 8);
98            }
99        }
100        descriptors.push(BriefDescriptor { bits });
101    }
102
103    Ok(descriptors)
104}
105
106/// Compute rotated BRIEF descriptors for oriented keypoints.
107///
108/// The sampling pattern is rotated by each keypoint's angle.
109pub(crate) fn compute_rotated_brief(
110    data: &[f32],
111    w: usize,
112    h: usize,
113    keypoints: &[Keypoint],
114) -> Vec<Option<BriefDescriptor>> {
115    keypoints
116        .iter()
117        .map(|kp| {
118            let kx = kp.x.round() as i32;
119            let ky = kp.y.round() as i32;
120
121            // Need extra margin for rotated offsets (worst case ~21 pixels)
122            let margin = 21i32;
123            if kx < margin || ky < margin || kx + margin >= w as i32 || ky + margin >= h as i32 {
124                return None;
125            }
126
127            let pattern = rotated_brief_pattern(kp.angle);
128            let mut bits = [0u8; 32];
129            let mut valid = true;
130            for (i, &(ax, ay, bx, by)) in pattern.iter().enumerate() {
131                let pax = kx + ax;
132                let pay = ky + ay;
133                let pbx = kx + bx;
134                let pby = ky + by;
135                if pax < 0
136                    || pay < 0
137                    || pbx < 0
138                    || pby < 0
139                    || pax >= w as i32
140                    || pay >= h as i32
141                    || pbx >= w as i32
142                    || pby >= h as i32
143                {
144                    valid = false;
145                    break;
146                }
147                let pa = data[pay as usize * w + pax as usize];
148                let pb = data[pby as usize * w + pbx as usize];
149                if pa < pb {
150                    bits[i / 8] |= 1 << (i % 8);
151                }
152            }
153            if valid {
154                Some(BriefDescriptor { bits })
155            } else {
156                None
157            }
158        })
159        .collect()
160}
161
162/// Hamming distance between two BRIEF descriptors.
163pub fn hamming_distance(a: &BriefDescriptor, b: &BriefDescriptor) -> u32 {
164    a.bits
165        .iter()
166        .zip(b.bits.iter())
167        .map(|(&x, &y)| (x ^ y).count_ones())
168        .sum()
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// An interior keypoint on a 40x40 image yields exactly one 32-byte descriptor.
    #[test]
    fn test_brief_descriptor_length() {
        // Non-constant intensities, so the pairwise comparisons are not all ties.
        let mut pixels: Vec<f32> = Vec::with_capacity(40 * 40);
        for i in 0..40 * 40 {
            pixels.push(((i as f32) * 0.1).sin().abs());
        }
        let image = Tensor::from_vec(vec![40, 40, 1], pixels).unwrap();

        // Centre of the image: the whole 31x31 patch fits.
        let centre = Keypoint {
            x: 20.0,
            y: 20.0,
            response: 1.0,
            angle: 0.0,
            octave: 0,
        };

        let descriptors = compute_brief(&image, &[centre]).unwrap();
        assert_eq!(descriptors.len(), 1);
        assert_eq!(descriptors[0].bits.len(), 32);
    }

    /// A descriptor is at distance zero from itself.
    #[test]
    fn test_hamming_distance_identical() {
        let descriptor = BriefDescriptor { bits: [0xAB; 32] };
        let distance = hamming_distance(&descriptor, &descriptor);
        assert_eq!(distance, 0);
    }

    /// All-zero and all-one descriptors differ in all 256 bits.
    #[test]
    fn test_hamming_distance_opposite() {
        let all_clear = BriefDescriptor { bits: [0x00; 32] };
        let all_set = BriefDescriptor { bits: [0xFF; 32] };
        let distance = hamming_distance(&all_clear, &all_set);
        assert_eq!(distance, 256);
    }
}