Skip to main content

yscv_track/
reid.rs

1//! Re-identification feature extraction for DeepSORT appearance matching.
2
3use std::collections::HashMap;
4
5use yscv_tensor::Tensor;
6
7use crate::TrackError;
8
9/// Re-identification feature extractor trait.
10/// Implementations extract appearance embeddings from image crops.
11pub trait ReIdExtractor: Send + Sync {
12    /// Extract an embedding vector from an image crop [H, W, C].
13    fn extract(&self, crop: &Tensor) -> Result<Vec<f32>, TrackError>;
14    /// Embedding dimension.
15    fn dim(&self) -> usize;
16}
17
/// Re-identification extractor whose embedding is a per-channel colour
/// histogram, L2-normalised.
///
/// It needs no learned model, which makes it a baseline or fallback when
/// no learned re-id network is available.
#[derive(Debug, Clone)]
pub struct ColorHistogramReId {
    /// Number of equal-width buckets each channel's `[0, 1]` range is split into.
    bins: usize,
}

impl ColorHistogramReId {
    /// Create an extractor that quantises every channel into `bins` buckets.
    pub fn new(bins: usize) -> Self {
        Self { bins }
    }
}
29
30impl ReIdExtractor for ColorHistogramReId {
31    fn extract(&self, crop: &Tensor) -> Result<Vec<f32>, TrackError> {
32        // Compute color histogram per channel, normalize to unit vector
33        let shape = crop.shape();
34        let c = *shape.last().unwrap_or(&3);
35        let data = crop.data();
36        let pixels = data.len() / c;
37        let mut hist = vec![0.0f32; self.bins * c];
38        for px in 0..pixels {
39            for ch in 0..c {
40                let val = data[px * c + ch].clamp(0.0, 1.0);
41                let bin = ((val * self.bins as f32) as usize).min(self.bins - 1);
42                hist[ch * self.bins + bin] += 1.0;
43            }
44        }
45        // L2 normalize
46        let norm: f32 = hist.iter().map(|v| v * v).sum::<f32>().sqrt();
47        if norm > 0.0 {
48            hist.iter_mut().for_each(|v| *v /= norm);
49        }
50        Ok(hist)
51    }
52
53    fn dim(&self) -> usize {
54        self.bins * 3
55    }
56}
57
/// Per-track store of recent appearance embeddings, scored against new
/// detections by cosine distance.
///
/// Each track keeps at most `max_features` embeddings. Adding one to a full
/// track discards that track's oldest embedding.
#[derive(Debug, Clone)]
pub struct ReIdGallery {
    /// Embeddings per track id, oldest at the front. A deque makes evicting the
    /// oldest entry O(1) instead of the O(n) shift done by `Vec::remove(0)`.
    features: HashMap<u64, std::collections::VecDeque<Vec<f32>>>,
    /// Per-track capacity; 0 means no embedding is retained.
    max_features: usize,
}

impl ReIdGallery {
    /// Create an empty gallery that keeps at most `max_features` embeddings per track.
    pub fn new(max_features: usize) -> Self {
        Self {
            features: HashMap::new(),
            max_features,
        }
    }

    /// Record `feature` as the newest embedding for `track_id`, evicting the
    /// oldest embeddings while the track holds more than `max_features`.
    pub fn update(&mut self, track_id: u64, feature: Vec<f32>) {
        let history = self.features.entry(track_id).or_default();
        history.push_back(feature);
        while history.len() > self.max_features {
            history.pop_front();
        }
    }

    /// Forget every embedding stored for `track_id`.
    pub fn remove(&mut self, track_id: u64) {
        self.features.remove(&track_id);
    }

    /// Smallest cosine distance between `feature` and any embedding stored for
    /// `track_id`.
    ///
    /// Returns 1.0 when the track has no stored embeddings. Note that 1.0 is the
    /// distance between orthogonal vectors, not the maximum: cosine distance
    /// spans `[0, 2]` and reaches 2 for opposite vectors.
    pub fn min_cosine_distance(&self, track_id: u64, feature: &[f32]) -> f32 {
        Self::nearest(self.features.get(&track_id), feature)
    }

    /// Build a cost matrix of shape `[track_ids.len(), features.len()]`.
    ///
    /// Entry `[i][j]` is `min_cosine_distance(track_ids[i], &features[j])`.
    /// Each track's gallery is looked up once per row rather than once per entry.
    pub fn cost_matrix(&self, track_ids: &[u64], features: &[Vec<f32>]) -> Vec<Vec<f32>> {
        track_ids
            .iter()
            .map(|track_id| {
                let history = self.features.get(track_id);
                features.iter().map(|f| Self::nearest(history, f)).collect()
            })
            .collect()
    }

    /// Minimum cosine distance from `feature` to the embeddings in `history`,
    /// or 1.0 when `history` is missing or empty.
    fn nearest(
        history: Option<&std::collections::VecDeque<Vec<f32>>>,
        feature: &[f32],
    ) -> f32 {
        match history {
            Some(stored) if !stored.is_empty() => stored
                .iter()
                .map(|s| cosine_distance(feature, s))
                .fold(f32::INFINITY, f32::min),
            _ => 1.0,
        }
    }
}

/// Cosine distance `1 - cos(a, b)` between two embeddings, clamped to `[0, 2]`.
///
/// Returns 1.0 when either vector has (near-)zero norm, since the angle is
/// undefined. If the slices differ in length, only the common prefix is
/// compared because `zip` stops at the shorter one.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let (dot, norm_a, norm_b) = a
        .iter()
        .zip(b)
        .fold((0.0_f32, 0.0_f32, 0.0_f32), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom < 1e-12 {
        return 1.0;
    }
    // Rounding can push the cosine a hair outside [-1, 1]. Clamping keeps
    // identical vectors from producing a tiny negative distance.
    (1.0 - dot / denom).clamp(0.0, 2.0)
}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2x2 RGB crop yields an embedding of length `dim()` with unit L2 norm.
    #[test]
    fn color_histogram_reid_basic() {
        let extractor = ColorHistogramReId::new(8);
        assert_eq!(extractor.dim(), 24); // 8 bins * 3 channels

        // Create a small 2x2 RGB image with known values
        let data = vec![
            0.1, 0.2, 0.3, // pixel (0,0)
            0.4, 0.5, 0.6, // pixel (0,1)
            0.7, 0.8, 0.9, // pixel (1,0)
            0.0, 0.1, 0.2, // pixel (1,1)
        ];
        let crop = Tensor::from_vec(vec![2, 2, 3], data).unwrap();
        let embedding = extractor.extract(&crop).unwrap();

        // Verify dimension matches
        assert_eq!(embedding.len(), extractor.dim());

        // Verify L2 normalization: ||embedding|| should be ~1.0
        let norm: f32 = embedding.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Expected unit norm, got {norm}");
    }

    /// Out-of-range values are clamped into `[0, 1]`, and 1.0 (or more) lands
    /// in the last bucket rather than past the end. Buckets are laid out
    /// channel-major.
    #[test]
    fn color_histogram_reid_bin_placement() {
        let extractor = ColorHistogramReId::new(4);
        // One pixel whose channels hit the low clamp, a middle bucket, and the high clamp.
        let crop = Tensor::from_vec(vec![1, 1, 3], vec![-0.5, 0.5, 2.0]).unwrap();
        let embedding = extractor.extract(&crop).unwrap();
        assert_eq!(embedding.len(), 12);

        // Channel ch, bucket b lives at index ch * 4 + b:
        // ch0 -> bucket 0 (idx 0), ch1 -> bucket 2 (idx 6), ch2 -> bucket 3 (idx 11).
        let expected = 1.0 / 3.0_f32.sqrt();
        for (i, &v) in embedding.iter().enumerate() {
            let want = if matches!(i, 0 | 6 | 11) { expected } else { 0.0 };
            assert!((v - want).abs() < 1e-6, "bucket {i}: expected {want}, got {v}");
        }
    }

    /// Distances for unknown, matching, orthogonal and removed tracks.
    #[test]
    fn reid_gallery_update_and_distance() {
        let mut gallery = ReIdGallery::new(10);

        // Track with no features should have max distance
        assert!((gallery.min_cosine_distance(1, &[1.0, 0.0, 0.0]) - 1.0).abs() < 1e-6);

        // Add a feature for track 1
        gallery.update(1, vec![1.0, 0.0, 0.0]);

        // Same direction -> distance ~0
        let dist = gallery.min_cosine_distance(1, &[1.0, 0.0, 0.0]);
        assert!(
            dist < 1e-5,
            "Expected ~0 distance for identical vectors, got {dist}"
        );

        // Orthogonal -> distance ~1
        let dist = gallery.min_cosine_distance(1, &[0.0, 1.0, 0.0]);
        assert!(
            (dist - 1.0).abs() < 1e-5,
            "Expected ~1 distance for orthogonal vectors, got {dist}"
        );

        // Add a second feature; min distance should pick the closer one
        gallery.update(1, vec![0.0, 1.0, 0.0]);
        let dist = gallery.min_cosine_distance(1, &[0.0, 1.0, 0.0]);
        assert!(
            dist < 1e-5,
            "Expected ~0 after adding matching feature, got {dist}"
        );

        // Remove track and verify distance returns to max
        gallery.remove(1);
        assert!((gallery.min_cosine_distance(1, &[1.0, 0.0, 0.0]) - 1.0).abs() < 1e-6);
    }

    /// Once a track holds `max_features` embeddings, each update drops the oldest one.
    #[test]
    fn reid_gallery_evicts_oldest_feature() {
        let mut gallery = ReIdGallery::new(2);
        gallery.update(7, vec![1.0, 0.0, 0.0]);
        gallery.update(7, vec![0.0, 1.0, 0.0]);
        gallery.update(7, vec![0.0, 0.0, 1.0]);

        let dist = gallery.min_cosine_distance(7, &[1.0, 0.0, 0.0]);
        assert!(
            (dist - 1.0).abs() < 1e-5,
            "Expected oldest feature to be evicted, got {dist}"
        );
        assert!(gallery.min_cosine_distance(7, &[0.0, 1.0, 0.0]) < 1e-5);
        assert!(gallery.min_cosine_distance(7, &[0.0, 0.0, 1.0]) < 1e-5);
    }

    /// Cost matrix shape and entries for known tracks and a feature matching none.
    #[test]
    fn reid_gallery_cost_matrix() {
        let mut gallery = ReIdGallery::new(10);

        gallery.update(10, vec![1.0, 0.0, 0.0]);
        gallery.update(20, vec![0.0, 1.0, 0.0]);

        let track_ids = vec![10, 20];
        let features = vec![
            vec![1.0, 0.0, 0.0], // should be close to track 10, far from 20
            vec![0.0, 1.0, 0.0], // should be far from track 10, close to 20
            vec![0.0, 0.0, 1.0], // orthogonal to both
        ];

        let matrix = gallery.cost_matrix(&track_ids, &features);

        // Shape: 2 tracks x 3 features
        assert_eq!(matrix.len(), 2);
        assert_eq!(matrix[0].len(), 3);
        assert_eq!(matrix[1].len(), 3);

        // track 10 vs feature 0 (identical) -> ~0
        assert!(matrix[0][0] < 1e-5);
        // track 10 vs feature 1 (orthogonal) -> ~1
        assert!((matrix[0][1] - 1.0).abs() < 1e-5);
        // track 20 vs feature 0 (orthogonal) -> ~1
        assert!((matrix[1][0] - 1.0).abs() < 1e-5);
        // track 20 vs feature 1 (identical) -> ~0
        assert!(matrix[1][1] < 1e-5);
        // feature 2 is orthogonal to both tracks -> ~1
        assert!((matrix[0][2] - 1.0).abs() < 1e-5);
        assert!((matrix[1][2] - 1.0).abs() < 1e-5);
    }

    /// A track with no stored features yields a row of 1.0 in the cost matrix.
    #[test]
    fn reid_gallery_cost_matrix_unknown_track() {
        let gallery = ReIdGallery::new(10);
        let matrix = gallery.cost_matrix(&[42], &[vec![1.0, 0.0], vec![0.0, 1.0]]);
        assert_eq!(matrix.len(), 1);
        assert!(matrix[0].iter().all(|&d| (d - 1.0).abs() < 1e-6));
    }
}