Skip to main content

yscv_track/
byte_track.rs

1//! ByteTrack multi-object tracker with two-stage association.
2//!
3//! ByteTrack splits detections into high- and low-confidence groups, matching
4//! high-confidence detections to tracks first, then attempting to recover
5//! unmatched tracks with low-confidence detections.
6
7use yscv_detect::{BoundingBox, Detection, iou};
8
9use crate::types::TrackedDetection;
10
/// Internal track representation for ByteTrack.
///
/// One instance exists per live track; `ByteTracker::update` creates,
/// refreshes, ages and drops these.
#[derive(Debug, Clone)]
struct ByteTrack {
    /// Track identifier, taken from the tracker's `next_id` counter at creation.
    id: usize,
    /// Box of the detection that most recently matched (or created) this track.
    bbox: BoundingBox,
    /// Score of that same detection.
    score: f32,
    /// Class of that same detection; overwritten on every match.
    class_id: usize,
    /// Consecutive updates without a match; reset to 0 whenever the track matches.
    age: usize,
    /// Number of detections absorbed so far, counting the one that created the track.
    hits: usize,
}
21
22/// ByteTrack multi-object tracker.
23///
24/// Uses two-stage association: first matching high-confidence detections to
25/// existing tracks, then trying to match low-confidence detections to
26/// remaining unmatched tracks.
27pub struct ByteTracker {
28    next_id: usize,
29    tracks: Vec<ByteTrack>,
30    high_threshold: f32,
31    low_threshold: f32,
32    iou_threshold: f32,
33    max_age: usize,
34}
35
36impl ByteTracker {
37    /// Create a new `ByteTracker`.
38    ///
39    /// - `high_threshold`: minimum score to be considered high-confidence (e.g. 0.5).
40    /// - `low_threshold`: minimum score to be considered at all (e.g. 0.1).
41    /// - `iou_threshold`: minimum IoU for a detection-track match (e.g. 0.3).
42    /// - `max_age`: frames a track survives without a match before deletion.
43    pub fn new(
44        high_threshold: f32,
45        low_threshold: f32,
46        iou_threshold: f32,
47        max_age: usize,
48    ) -> Self {
49        Self {
50            next_id: 1,
51            tracks: Vec::new(),
52            high_threshold,
53            low_threshold,
54            iou_threshold,
55            max_age,
56        }
57    }
58
59    /// Update the tracker with a new frame of detections.
60    ///
61    /// Returns the list of currently tracked detections with their track IDs.
62    pub fn update(&mut self, detections: &[Detection]) -> Vec<TrackedDetection> {
63        // 1. Split detections into high and low confidence.
64        let mut high: Vec<usize> = Vec::new();
65        let mut low: Vec<usize> = Vec::new();
66        for (i, det) in detections.iter().enumerate() {
67            if det.score >= self.high_threshold {
68                high.push(i);
69            } else if det.score >= self.low_threshold {
70                low.push(i);
71            }
72        }
73
74        let all_track_indices: Vec<usize> = (0..self.tracks.len()).collect();
75        let mut matched_tracks: Vec<bool> = vec![false; self.tracks.len()];
76        let mut matched_dets: Vec<bool> = vec![false; detections.len()];
77
78        // 2. Match high-confidence detections to existing tracks (greedy by best IoU).
79        let assignments1 = greedy_match(
80            &self.tracks,
81            detections,
82            &all_track_indices,
83            &high,
84            self.iou_threshold,
85        );
86
87        for &(ti, di) in &assignments1 {
88            matched_tracks[ti] = true;
89            matched_dets[di] = true;
90            self.tracks[ti].bbox = detections[di].bbox;
91            self.tracks[ti].score = detections[di].score;
92            self.tracks[ti].class_id = detections[di].class_id;
93            self.tracks[ti].age = 0;
94            self.tracks[ti].hits += 1;
95        }
96
97        // 3. Match low-confidence detections to remaining unmatched tracks.
98        let unmatched_track_indices: Vec<usize> = (0..self.tracks.len())
99            .filter(|&i| !matched_tracks[i])
100            .collect();
101
102        let assignments2 = greedy_match(
103            &self.tracks,
104            detections,
105            &unmatched_track_indices,
106            &low,
107            self.iou_threshold,
108        );
109
110        for &(ti, di) in &assignments2 {
111            matched_tracks[ti] = true;
112            matched_dets[di] = true;
113            self.tracks[ti].bbox = detections[di].bbox;
114            self.tracks[ti].score = detections[di].score;
115            self.tracks[ti].class_id = detections[di].class_id;
116            self.tracks[ti].age = 0;
117            self.tracks[ti].hits += 1;
118        }
119
120        // 4. Create new tracks for unmatched high-confidence detections.
121        for &di in &high {
122            if !matched_dets[di] {
123                let id = self.next_id;
124                self.next_id += 1;
125                self.tracks.push(ByteTrack {
126                    id,
127                    bbox: detections[di].bbox,
128                    score: detections[di].score,
129                    class_id: detections[di].class_id,
130                    age: 0,
131                    hits: 1,
132                });
133            }
134        }
135
136        // 5. Age unmatched tracks, remove those exceeding max_age.
137        for (i, track) in self.tracks.iter_mut().enumerate() {
138            if i < matched_tracks.len() && !matched_tracks[i] {
139                track.age += 1;
140            }
141        }
142        self.tracks.retain(|t| t.age <= self.max_age);
143
144        // 6. Return all active tracks as TrackedDetection.
145        self.tracks
146            .iter()
147            .map(|t| TrackedDetection {
148                track_id: t.id as u64,
149                detection: Detection {
150                    bbox: t.bbox,
151                    score: t.score,
152                    class_id: t.class_id,
153                },
154            })
155            .collect()
156    }
157
158    /// Return the number of currently active tracks.
159    pub fn active_track_count(&self) -> usize {
160        self.tracks.len()
161    }
162}
163
164/// Greedy IoU matching: for each track, find the detection with the highest IoU
165/// at or above `iou_threshold`. Returns a list of (track_index, detection_index) pairs.
166fn greedy_match(
167    tracks: &[ByteTrack],
168    detections: &[Detection],
169    track_indices: &[usize],
170    det_indices: &[usize],
171    iou_threshold: f32,
172) -> Vec<(usize, usize)> {
173    let mut used_dets = vec![false; detections.len()];
174    let mut assignments = Vec::new();
175
176    for &ti in track_indices {
177        let mut best_iou = iou_threshold;
178        let mut best_di: Option<usize> = None;
179        for &di in det_indices {
180            if used_dets[di] {
181                continue;
182            }
183            let iou_val = iou(tracks[ti].bbox, detections[di].bbox);
184            if iou_val >= best_iou {
185                best_iou = iou_val;
186                best_di = Some(di);
187            }
188        }
189        if let Some(di) = best_di {
190            used_dets[di] = true;
191            assignments.push((ti, di));
192        }
193    }
194
195    assignments
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a class-0 detection from its corner coordinates and score.
    fn det(x1: f32, y1: f32, x2: f32, y2: f32, score: f32) -> Detection {
        let bbox = BoundingBox { x1, y1, x2, y2 };
        Detection {
            bbox,
            score,
            class_id: 0,
        }
    }

    /// Tracker with the thresholds shared by most tests: high 0.5, low 0.1, IoU 0.3.
    fn standard_tracker(max_age: usize) -> ByteTracker {
        ByteTracker::new(0.5, 0.1, 0.3, max_age)
    }

    /// Track IDs of an `update` result, in output order.
    fn track_ids(tracked: &[TrackedDetection]) -> Vec<u64> {
        tracked.iter().map(|entry| entry.track_id).collect()
    }

    #[test]
    fn byte_track_creates_tracks() {
        let mut tracker = standard_tracker(3);
        let frame = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 150.0, 150.0, 0.8),
        ];

        let output = tracker.update(&frame);

        assert_eq!(output.len(), 2);
        assert_eq!(tracker.active_track_count(), 2);
    }

    #[test]
    fn byte_track_maintains_ids() {
        let mut tracker = standard_tracker(3);

        let first_frame = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 150.0, 150.0, 0.8),
        ];
        let before = track_ids(&tracker.update(&first_frame));
        let (id0, id1) = (before[0], before[1]);

        // The same two objects, each nudged by two pixels.
        let second_frame = [
            det(12.0, 12.0, 52.0, 52.0, 0.9),
            det(102.0, 102.0, 152.0, 152.0, 0.85),
        ];
        let after = track_ids(&tracker.update(&second_frame));

        assert_eq!(after.len(), 2);
        assert!(after.contains(&id0));
        assert!(after.contains(&id1));
    }

    #[test]
    fn byte_track_removes_old_tracks() {
        let max_age = 2;
        let mut tracker = standard_tracker(max_age);

        // Seed a single track.
        tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 1);

        // max_age + 1 consecutive empty frames push the track past its limit.
        for _ in 0..max_age + 1 {
            tracker.update(&[]);
        }
        assert_eq!(tracker.active_track_count(), 0);
    }

    #[test]
    fn byte_track_low_confidence_association() {
        let mut tracker = standard_tracker(3);

        // A confident detection opens the track.
        let opened = tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        let id = opened[0].track_id;

        // Next frame the object scores 0.3: under the high threshold, over the low one.
        let followed = tracker.update(&[det(12.0, 12.0, 52.0, 52.0, 0.3)]);

        // The existing track must absorb it; no second track may appear.
        assert_eq!(followed.len(), 1);
        assert_eq!(followed[0].track_id, id);
    }

    #[test]
    fn byte_track_new_track_for_new_object() {
        let mut tracker = standard_tracker(3);

        // First frame holds a single object.
        let first = tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        assert_eq!(first.len(), 1);
        let id1 = first[0].track_id;

        // Second frame: that object again, plus one far away from it.
        let second = tracker.update(&[
            det(12.0, 12.0, 52.0, 52.0, 0.9),
            det(200.0, 200.0, 250.0, 250.0, 0.8),
        ]);
        assert_eq!(second.len(), 2);

        let ids = track_ids(&second);
        assert!(ids.contains(&id1));
        // The newcomer must carry an ID other than the original's.
        let fresh = ids
            .iter()
            .copied()
            .find(|&candidate| candidate != id1)
            .unwrap();
        assert_ne!(fresh, id1);
    }

    #[test]
    fn byte_track_three_objects_simultaneously() {
        let mut tracker = standard_tracker(5);
        let frame = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 140.0, 140.0, 0.8),
            det(200.0, 200.0, 240.0, 240.0, 0.7),
        ];

        let output = tracker.update(&frame);
        assert_eq!(output.len(), 3);
        assert_eq!(tracker.active_track_count(), 3);

        // No ID may repeat.
        let mut distinct = track_ids(&output);
        distinct.sort_unstable();
        distinct.dedup();
        assert_eq!(distinct.len(), 3);
    }

    #[test]
    fn byte_track_empty_detections_ages_tracks() {
        let mut tracker = standard_tracker(5);
        tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 1);

        // One empty frame ages the track without killing it.
        let output = tracker.update(&[]);
        assert_eq!(tracker.active_track_count(), 1);
        assert_eq!(output.len(), 1); // the aged track is still reported
    }

    #[test]
    fn byte_track_single_detection_stable_id() {
        let mut tracker = standard_tracker(5);
        let frame = [det(20.0, 20.0, 60.0, 60.0, 0.9)];
        let id = tracker.update(&frame)[0].track_id;

        // Replaying the identical frame must keep returning the same track.
        for _ in 0..10 {
            let output = tracker.update(&frame);
            assert_eq!(output.len(), 1);
            assert_eq!(output[0].track_id, id);
        }
    }

    #[test]
    fn byte_track_id_stability_smooth_motion() {
        let mut tracker = standard_tracker(5);
        let id = tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)])[0].track_id;

        // Slide the box diagonally by two pixels per frame, four frames in a row.
        for step in 1..=4 {
            let shift = 2.0 * step as f32;
            let moved = det(10.0 + shift, 10.0 + shift, 50.0 + shift, 50.0 + shift, 0.9);
            let output = tracker.update(&[moved]);
            assert_eq!(output.len(), 1);
            assert_eq!(output[0].track_id, id);
        }
    }

    #[test]
    fn byte_track_iou_matching_overlapping_bboxes() {
        let mut tracker = standard_tracker(5);

        // Two boxes that overlap each other.
        let first = tracker.update(&[
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(30.0, 30.0, 70.0, 70.0, 0.8),
        ]);
        assert_eq!(first.len(), 2);
        let (id_a, id_b) = (first[0].track_id, first[1].track_id);
        assert_ne!(id_a, id_b);

        // Shift each by one pixel; both IDs must survive.
        let second = tracker.update(&[
            det(11.0, 11.0, 51.0, 51.0, 0.9),
            det(31.0, 31.0, 71.0, 71.0, 0.8),
        ]);
        let ids = track_ids(&second);
        assert!(ids.contains(&id_a));
        assert!(ids.contains(&id_b));
    }

    #[test]
    fn byte_track_low_vs_high_confidence_precedence() {
        // Stage one consumes the confident detection, stage two the weak one.
        let mut tracker = standard_tracker(5);

        // Open two tracks from confident detections.
        let first = tracker.update(&[
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 140.0, 140.0, 0.8),
        ]);
        let (id_a, id_b) = (first[0].track_id, first[1].track_id);

        // Track A is seen confidently again; track B only weakly.
        let second = tracker.update(&[
            det(12.0, 12.0, 52.0, 52.0, 0.9),     // high: expected to match track A
            det(102.0, 102.0, 142.0, 142.0, 0.2), // low: expected to match track B
        ]);
        assert_eq!(second.len(), 2);

        let ids = track_ids(&second);
        assert!(ids.contains(&id_a));
        assert!(ids.contains(&id_b));
    }

    #[test]
    fn byte_track_low_confidence_no_new_track() {
        // A weak detection on its own must never open a track.
        let mut tracker = standard_tracker(5);

        let output = tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.2)]);

        assert_eq!(output.len(), 0);
        assert_eq!(tracker.active_track_count(), 0);
    }

    #[test]
    fn byte_track_below_low_threshold_ignored() {
        // A score of 0.05 sits under low_threshold = 0.1 and is dropped outright.
        let mut tracker = standard_tracker(5);

        let output = tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.05)]);

        assert_eq!(output.len(), 0);
        assert_eq!(tracker.active_track_count(), 0);
    }

    #[test]
    fn byte_track_config_different_max_age() {
        // With max_age = 0 a track cannot survive a single unmatched frame.
        let mut tracker = standard_tracker(0);
        tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 1);

        // A single empty frame removes it.
        tracker.update(&[]);
        assert_eq!(tracker.active_track_count(), 0);
    }

    #[test]
    fn byte_track_config_different_iou_threshold() {
        // An IoU threshold of 0.95 only accepts near-identical boxes.
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.95, 5);
        tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);

        // A five-pixel shift falls under 0.95 IoU, so the detection opens a
        // second track while the first one stays alive, unmatched.
        tracker.update(&[det(15.0, 15.0, 55.0, 55.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 2);
    }

    #[test]
    fn byte_track_more_tracks_than_detections() {
        let mut tracker = standard_tracker(5);

        // Open three tracks.
        tracker.update(&[
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 140.0, 140.0, 0.8),
            det(200.0, 200.0, 240.0, 240.0, 0.7),
        ]);
        assert_eq!(tracker.active_track_count(), 3);

        // Only one detection arrives; the other two tracks age by one frame,
        // which is well inside max_age = 5, so all three remain.
        tracker.update(&[det(12.0, 12.0, 52.0, 52.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 3);
    }

    #[test]
    fn byte_track_more_detections_than_tracks() {
        let mut tracker = standard_tracker(5);

        // Open one track.
        tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 1);

        // Three detections: one continues the track, two open new ones.
        tracker.update(&[
            det(12.0, 12.0, 52.0, 52.0, 0.9),
            det(100.0, 100.0, 140.0, 140.0, 0.8),
            det(200.0, 200.0, 240.0, 240.0, 0.7),
        ]);
        assert_eq!(tracker.active_track_count(), 3);
    }
}