1use yscv_detect::{BoundingBox, Detection, iou};
8
9use crate::types::TrackedDetection;
10
/// Internal state of one track maintained by [`ByteTracker`].
#[derive(Debug, Clone)]
struct ByteTrack {
    /// Tracker-assigned identifier, reported as `TrackedDetection::track_id`.
    id: usize,
    /// Box of the most recently matched detection (not predicted forward).
    bbox: BoundingBox,
    /// Confidence of the most recently matched detection.
    score: f32,
    /// Class of the most recently matched detection.
    class_id: usize,
    /// Consecutive frames without a match; reset to 0 whenever the track is matched.
    age: usize,
    /// Number of detections associated with this track, including the one that created it.
    /// NOTE(review): only written in this file, never read — presumably reserved for a
    /// "minimum hits before confirmation" rule; confirm before relying on it.
    hits: usize,
}
21
22pub struct ByteTracker {
28 next_id: usize,
29 tracks: Vec<ByteTrack>,
30 high_threshold: f32,
31 low_threshold: f32,
32 iou_threshold: f32,
33 max_age: usize,
34}
35
36impl ByteTracker {
37 pub fn new(
44 high_threshold: f32,
45 low_threshold: f32,
46 iou_threshold: f32,
47 max_age: usize,
48 ) -> Self {
49 Self {
50 next_id: 1,
51 tracks: Vec::new(),
52 high_threshold,
53 low_threshold,
54 iou_threshold,
55 max_age,
56 }
57 }
58
59 pub fn update(&mut self, detections: &[Detection]) -> Vec<TrackedDetection> {
63 let mut high: Vec<usize> = Vec::new();
65 let mut low: Vec<usize> = Vec::new();
66 for (i, det) in detections.iter().enumerate() {
67 if det.score >= self.high_threshold {
68 high.push(i);
69 } else if det.score >= self.low_threshold {
70 low.push(i);
71 }
72 }
73
74 let all_track_indices: Vec<usize> = (0..self.tracks.len()).collect();
75 let mut matched_tracks: Vec<bool> = vec![false; self.tracks.len()];
76 let mut matched_dets: Vec<bool> = vec![false; detections.len()];
77
78 let assignments1 = greedy_match(
80 &self.tracks,
81 detections,
82 &all_track_indices,
83 &high,
84 self.iou_threshold,
85 );
86
87 for &(ti, di) in &assignments1 {
88 matched_tracks[ti] = true;
89 matched_dets[di] = true;
90 self.tracks[ti].bbox = detections[di].bbox;
91 self.tracks[ti].score = detections[di].score;
92 self.tracks[ti].class_id = detections[di].class_id;
93 self.tracks[ti].age = 0;
94 self.tracks[ti].hits += 1;
95 }
96
97 let unmatched_track_indices: Vec<usize> = (0..self.tracks.len())
99 .filter(|&i| !matched_tracks[i])
100 .collect();
101
102 let assignments2 = greedy_match(
103 &self.tracks,
104 detections,
105 &unmatched_track_indices,
106 &low,
107 self.iou_threshold,
108 );
109
110 for &(ti, di) in &assignments2 {
111 matched_tracks[ti] = true;
112 matched_dets[di] = true;
113 self.tracks[ti].bbox = detections[di].bbox;
114 self.tracks[ti].score = detections[di].score;
115 self.tracks[ti].class_id = detections[di].class_id;
116 self.tracks[ti].age = 0;
117 self.tracks[ti].hits += 1;
118 }
119
120 for &di in &high {
122 if !matched_dets[di] {
123 let id = self.next_id;
124 self.next_id += 1;
125 self.tracks.push(ByteTrack {
126 id,
127 bbox: detections[di].bbox,
128 score: detections[di].score,
129 class_id: detections[di].class_id,
130 age: 0,
131 hits: 1,
132 });
133 }
134 }
135
136 for (i, track) in self.tracks.iter_mut().enumerate() {
138 if i < matched_tracks.len() && !matched_tracks[i] {
139 track.age += 1;
140 }
141 }
142 self.tracks.retain(|t| t.age <= self.max_age);
143
144 self.tracks
146 .iter()
147 .map(|t| TrackedDetection {
148 track_id: t.id as u64,
149 detection: Detection {
150 bbox: t.bbox,
151 score: t.score,
152 class_id: t.class_id,
153 },
154 })
155 .collect()
156 }
157
158 pub fn active_track_count(&self) -> usize {
160 self.tracks.len()
161 }
162}
163
164fn greedy_match(
167 tracks: &[ByteTrack],
168 detections: &[Detection],
169 track_indices: &[usize],
170 det_indices: &[usize],
171 iou_threshold: f32,
172) -> Vec<(usize, usize)> {
173 let mut used_dets = vec![false; detections.len()];
174 let mut assignments = Vec::new();
175
176 for &ti in track_indices {
177 let mut best_iou = iou_threshold;
178 let mut best_di: Option<usize> = None;
179 for &di in det_indices {
180 if used_dets[di] {
181 continue;
182 }
183 let iou_val = iou(tracks[ti].bbox, detections[di].bbox);
184 if iou_val >= best_iou {
185 best_iou = iou_val;
186 best_di = Some(di);
187 }
188 }
189 if let Some(di) = best_di {
190 used_dets[di] = true;
191 assignments.push((ti, di));
192 }
193 }
194
195 assignments
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a class-0 detection with the given corners and confidence.
    fn det(x1: f32, y1: f32, x2: f32, y2: f32, score: f32) -> Detection {
        Detection {
            bbox: BoundingBox { x1, y1, x2, y2 },
            score,
            class_id: 0,
        }
    }

    // Two separated high-confidence detections on an empty tracker each start a track.
    #[test]
    fn byte_track_creates_tracks() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 3);
        let dets = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 150.0, 150.0, 0.8),
        ];
        let tracked = tracker.update(&dets);
        assert_eq!(tracked.len(), 2);
        assert_eq!(tracker.active_track_count(), 2);
    }

    // Slightly shifted boxes in the next frame keep their existing IDs.
    #[test]
    fn byte_track_maintains_ids() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 3);
        let dets1 = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 150.0, 150.0, 0.8),
        ];
        let tracked1 = tracker.update(&dets1);
        let id0 = tracked1[0].track_id;
        let id1 = tracked1[1].track_id;

        let dets2 = [
            // Each box moved by 2 px, which keeps IoU well above the 0.3 threshold.
            det(12.0, 12.0, 52.0, 52.0, 0.9),
            det(102.0, 102.0, 152.0, 152.0, 0.85),
        ];
        let tracked2 = tracker.update(&dets2);
        assert_eq!(tracked2.len(), 2);

        let ids: Vec<u64> = tracked2.iter().map(|t| t.track_id).collect();
        assert!(ids.contains(&id0));
        assert!(ids.contains(&id1));
    }

    // A track unmatched for more than `max_age` consecutive frames is removed:
    // the loop runs `max_age + 1` empty frames.
    #[test]
    fn byte_track_removes_old_tracks() {
        let max_age = 2;
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, max_age);

        let dets = [det(10.0, 10.0, 50.0, 50.0, 0.9)];
        tracker.update(&dets);
        assert_eq!(tracker.active_track_count(), 1);

        for _ in 0..=max_age {
            tracker.update(&[]);
        }
        assert_eq!(tracker.active_track_count(), 0);
    }

    // A detection between the low and high thresholds keeps an existing track alive
    // through the second association pass.
    #[test]
    fn byte_track_low_confidence_association() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 3);

        let dets1 = [det(10.0, 10.0, 50.0, 50.0, 0.9)];
        let tracked1 = tracker.update(&dets1);
        let id = tracked1[0].track_id;

        let dets2 = [det(12.0, 12.0, 52.0, 52.0, 0.3)];
        let tracked2 = tracker.update(&dets2);

        assert_eq!(tracked2.len(), 1);
        assert_eq!(tracked2[0].track_id, id);
    }

    // An unmatched high-confidence detection starts a new track next to the continuing one.
    #[test]
    fn byte_track_new_track_for_new_object() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 3);

        let dets1 = [det(10.0, 10.0, 50.0, 50.0, 0.9)];
        let tracked1 = tracker.update(&dets1);
        assert_eq!(tracked1.len(), 1);
        let id1 = tracked1[0].track_id;

        let dets2 = [
            det(12.0, 12.0, 52.0, 52.0, 0.9),
            det(200.0, 200.0, 250.0, 250.0, 0.8),
        ];
        let tracked2 = tracker.update(&dets2);
        assert_eq!(tracked2.len(), 2);

        let ids: Vec<u64> = tracked2.iter().map(|t| t.track_id).collect();
        assert!(ids.contains(&id1));
        // `unwrap` asserts that some other ID exists; the `find` predicate already
        // guarantees the `assert_ne!` below.
        let new_id = ids.iter().find(|&&id| id != id1).unwrap();
        assert_ne!(*new_id, id1);
    }

    // Three separated detections yield three tracks with distinct IDs.
    #[test]
    fn byte_track_three_objects_simultaneously() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);
        let dets = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 140.0, 140.0, 0.8),
            det(200.0, 200.0, 240.0, 240.0, 0.7),
        ];
        let tracked = tracker.update(&dets);
        assert_eq!(tracked.len(), 3);
        assert_eq!(tracker.active_track_count(), 3);

        let ids: Vec<u64> = tracked.iter().map(|t| t.track_id).collect();
        let mut unique = ids.clone();
        unique.sort();
        unique.dedup();
        assert_eq!(unique.len(), 3);
    }

    // A frame without detections keeps the track (age 1 <= max_age 5) and still reports it.
    #[test]
    fn byte_track_empty_detections_ages_tracks() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);
        tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 1);

        let result = tracker.update(&[]);
        assert_eq!(tracker.active_track_count(), 1);
        assert_eq!(result.len(), 1);
    }

    // A stationary object keeps the same ID across many frames.
    #[test]
    fn byte_track_single_detection_stable_id() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);
        let d = det(20.0, 20.0, 60.0, 60.0, 0.9);
        let first = tracker.update(&[d]);
        let id = first[0].track_id;

        for _ in 0..10 {
            let tracked = tracker.update(&[d]);
            assert_eq!(tracked.len(), 1);
            assert_eq!(tracked[0].track_id, id);
        }
    }

    // An object moving 2 px per frame keeps its ID.
    #[test]
    fn byte_track_id_stability_smooth_motion() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);
        let first = tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        let id = first[0].track_id;

        let positions = [
            (12.0, 12.0, 52.0, 52.0),
            (14.0, 14.0, 54.0, 54.0),
            (16.0, 16.0, 56.0, 56.0),
            (18.0, 18.0, 58.0, 58.0),
        ];
        for (x1, y1, x2, y2) in positions {
            let tracked = tracker.update(&[det(x1, y1, x2, y2, 0.9)]);
            assert_eq!(tracked.len(), 1);
            assert_eq!(tracked[0].track_id, id);
        }
    }

    // Two partially overlapping tracks each re-associate with their own shifted box.
    #[test]
    fn byte_track_iou_matching_overlapping_bboxes() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);

        let dets1 = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(30.0, 30.0, 70.0, 70.0, 0.8),
        ];
        let tracked1 = tracker.update(&dets1);
        assert_eq!(tracked1.len(), 2);
        let id_a = tracked1[0].track_id;
        let id_b = tracked1[1].track_id;
        assert_ne!(id_a, id_b);

        let dets2 = [
            det(11.0, 11.0, 51.0, 51.0, 0.9),
            det(31.0, 31.0, 71.0, 71.0, 0.8),
        ];
        let tracked2 = tracker.update(&dets2);
        let ids2: Vec<u64> = tracked2.iter().map(|t| t.track_id).collect();
        assert!(ids2.contains(&id_a));
        assert!(ids2.contains(&id_b));
    }

    // Track A is matched by a high-confidence box in pass 1; track B is kept alive by a
    // low-confidence (0.2) box in pass 2.
    #[test]
    fn byte_track_low_vs_high_confidence_precedence() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);

        let dets1 = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 140.0, 140.0, 0.8),
        ];
        let tracked1 = tracker.update(&dets1);
        let id_a = tracked1[0].track_id;
        let id_b = tracked1[1].track_id;

        let dets2 = [
            det(12.0, 12.0, 52.0, 52.0, 0.9),
            det(102.0, 102.0, 142.0, 142.0, 0.2),
        ];
        let tracked2 = tracker.update(&dets2);
        assert_eq!(tracked2.len(), 2);
        let ids2: Vec<u64> = tracked2.iter().map(|t| t.track_id).collect();
        assert!(ids2.contains(&id_a));
        assert!(ids2.contains(&id_b));
    }

    // Low-confidence detections never start new tracks.
    #[test]
    fn byte_track_low_confidence_no_new_track() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);

        let dets = [det(10.0, 10.0, 50.0, 50.0, 0.2)];
        let tracked = tracker.update(&dets);
        assert_eq!(tracked.len(), 0);
        assert_eq!(tracker.active_track_count(), 0);
    }

    // Detections below `low_threshold` (0.05 < 0.1) are ignored entirely.
    #[test]
    fn byte_track_below_low_threshold_ignored() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);

        let dets = [det(10.0, 10.0, 50.0, 50.0, 0.05)];
        let tracked = tracker.update(&dets);
        assert_eq!(tracked.len(), 0);
        assert_eq!(tracker.active_track_count(), 0);
    }

    // With `max_age = 0`, a single missed frame drops the track.
    #[test]
    fn byte_track_config_different_max_age() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 0);
        tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 1);

        tracker.update(&[]);
        assert_eq!(tracker.active_track_count(), 0);
    }

    // A strict 0.95 IoU threshold rejects a 5 px shift, so the shifted box spawns a
    // second track while the first is merely aged.
    #[test]
    fn byte_track_config_different_iou_threshold() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.95, 5);
        tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);

        tracker.update(&[det(15.0, 15.0, 55.0, 55.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 2);
    }

    // Tracks left unmatched are aged but kept while within `max_age`.
    #[test]
    fn byte_track_more_tracks_than_detections() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);

        let dets = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 140.0, 140.0, 0.8),
            det(200.0, 200.0, 240.0, 240.0, 0.7),
        ];
        tracker.update(&dets);
        assert_eq!(tracker.active_track_count(), 3);

        tracker.update(&[det(12.0, 12.0, 52.0, 52.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 3);
    }

    // Extra high-confidence detections start new tracks alongside the matched one.
    #[test]
    fn byte_track_more_detections_than_tracks() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);

        tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 1);

        let dets = [
            det(12.0, 12.0, 52.0, 52.0, 0.9),
            det(100.0, 100.0, 140.0, 140.0, 0.8),
            det(200.0, 200.0, 240.0, 240.0, 0.7),
        ];
        tracker.update(&dets);
        assert_eq!(tracker.active_track_count(), 3);
    }
}