1use yscv_detect::{Detection, iou};
4
5use crate::KalmanFilter;
6use crate::hungarian::hungarian_assignment;
7
/// Tuning parameters for [`DeepSortTracker`].
///
/// Distances are "lower is better" costs: an assignment proposed by the
/// solver is kept only when its cost is at or below the matching gate.
#[derive(Debug, Clone)]
pub struct DeepSortConfig {
    /// Gate for the appearance stage. A confirmed track and a detection may be
    /// matched only if the smallest cosine distance between the detection's
    /// embedding and the track's stored embeddings is `<=` this value.
    pub max_cosine_distance: f32,
    /// Gate for the IoU stage, expressed as `1 - IoU` between the track's
    /// predicted box and the detection box (so 0.7 means IoU of roughly 0.3
    /// or more is required).
    pub max_iou_distance: f32,
    /// Number of consecutive unmatched frames a track may coast through; it is
    /// deleted once its `time_since_update` exceeds this value.
    pub max_age: usize,
    /// Hit count at which a tentative track is promoted to confirmed. Creating
    /// a track counts as its first hit, but promotion is only evaluated when
    /// the track is matched, so at least one further match is always needed.
    pub n_init: usize,
}

impl Default for DeepSortConfig {
    /// Cosine gate 0.3, IoU-distance gate 0.7, `max_age` of 30 frames and
    /// `n_init` of 3 hits.
    fn default() -> Self {
        let (max_cosine_distance, max_iou_distance) = (0.3, 0.7);
        let (max_age, n_init) = (30, 3);
        Self {
            max_age,
            n_init,
            max_cosine_distance,
            max_iou_distance,
        }
    }
}
31
/// Lifecycle state of a [`DeepSortTrack`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrackState {
    /// Newly created and not yet confirmed. The tracker deletes a tentative
    /// track as soon as it goes unmatched in a frame after a `predict` call.
    Tentative,
    /// The track has accumulated at least `n_init` hits. It survives misses
    /// until `time_since_update` exceeds `max_age`.
    Confirmed,
    /// Marked for removal; the tracker drops such tracks at the end of `update`.
    Deleted,
}
42
/// A single tracked object maintained by [`DeepSortTracker`].
#[derive(Debug, Clone)]
pub struct DeepSortTrack {
    /// Identifier assigned from the tracker's counter, which starts at 1 and
    /// only increases.
    pub id: usize,
    /// Current lifecycle state.
    pub state: TrackState,
    /// Motion model. It is advanced by the tracker's `predict` and corrected
    /// with the `[cx, cy, w, h]` of every matched detection.
    pub kalman: KalmanFilter,
    /// Gallery of appearance embeddings from this track's detections, oldest
    /// first. The tracker caps its length (oldest entries are evicted).
    pub features: Vec<Vec<f32>>,
    /// Number of detections associated with the track, including the one that
    /// created it.
    pub hits: usize,
    /// Frames since creation: starts at 1 and is incremented by every `predict`.
    pub age: usize,
    /// Frames since the last matched detection. Incremented by `predict` and
    /// reset to 0 on a match.
    pub time_since_update: usize,
}
61
62pub struct DeepSortTracker {
64 config: DeepSortConfig,
65 tracks: Vec<DeepSortTrack>,
66 next_id: usize,
67}
68
69impl DeepSortTracker {
70 pub fn new(config: DeepSortConfig) -> Self {
72 Self {
73 config,
74 tracks: Vec::new(),
75 next_id: 1,
76 }
77 }
78
79 pub fn predict(&mut self) {
81 for track in &mut self.tracks {
82 track.kalman.predict();
83 track.age += 1;
84 track.time_since_update += 1;
85 }
86 }
87
88 pub fn update(&mut self, detections: &[Detection], features: Option<&[Vec<f32>]>) {
93 let mut confirmed_indices: Vec<usize> = Vec::new();
95 let mut unconfirmed_indices: Vec<usize> = Vec::new();
96 for (i, track) in self.tracks.iter().enumerate() {
97 match track.state {
98 TrackState::Confirmed => confirmed_indices.push(i),
99 TrackState::Tentative => unconfirmed_indices.push(i),
100 TrackState::Deleted => {}
101 }
102 }
103
104 let n_dets = detections.len();
105 let mut matched_tracks: Vec<bool> = vec![false; self.tracks.len()];
106 let mut matched_dets: Vec<bool> = vec![false; n_dets];
107
108 if let Some(feats) = features
110 && !confirmed_indices.is_empty()
111 && !detections.is_empty()
112 {
113 let n_tracks = confirmed_indices.len();
114 let mut cost_matrix = vec![vec![0.0_f32; n_dets]; n_tracks];
115 for (ti, &track_idx) in confirmed_indices.iter().enumerate() {
116 let track = &self.tracks[track_idx];
117 for dj in 0..n_dets {
118 if track.features.is_empty() {
119 cost_matrix[ti][dj] = self.config.max_cosine_distance + 1.0;
121 } else {
122 cost_matrix[ti][dj] = min_cosine_distance(&feats[dj], &track.features);
123 }
124 }
125 }
126
127 let assignments = hungarian_assignment(&cost_matrix);
128 for (ti, dj) in assignments {
129 if cost_matrix[ti][dj] <= self.config.max_cosine_distance {
130 let track_idx = confirmed_indices[ti];
131 matched_tracks[track_idx] = true;
132 matched_dets[dj] = true;
133 self.update_track(track_idx, &detections[dj], Some(&feats[dj]));
134 }
135 }
136 }
137
138 let mut iou_track_indices: Vec<usize> = Vec::new();
141 for &ti in &confirmed_indices {
142 if !matched_tracks[ti] {
143 iou_track_indices.push(ti);
144 }
145 }
146 iou_track_indices.extend_from_slice(&unconfirmed_indices);
147
148 let unmatched_det_indices: Vec<usize> = (0..n_dets).filter(|&d| !matched_dets[d]).collect();
149
150 if !iou_track_indices.is_empty() && !unmatched_det_indices.is_empty() {
151 let n_t = iou_track_indices.len();
152 let n_d = unmatched_det_indices.len();
153 let mut cost_matrix = vec![vec![0.0_f32; n_d]; n_t];
154 for (ti, &track_idx) in iou_track_indices.iter().enumerate() {
155 let predicted = self.tracks[track_idx].kalman.bbox();
156 for (dj, &det_idx) in unmatched_det_indices.iter().enumerate() {
157 let iou_val = iou(predicted, detections[det_idx].bbox);
158 cost_matrix[ti][dj] = 1.0 - iou_val; }
160 }
161
162 let assignments = hungarian_assignment(&cost_matrix);
163 for (ti, dj) in assignments {
164 if cost_matrix[ti][dj] <= self.config.max_iou_distance {
165 let track_idx = iou_track_indices[ti];
166 let det_idx = unmatched_det_indices[dj];
167 matched_tracks[track_idx] = true;
168 matched_dets[det_idx] = true;
169 let feat = features.map(|f| &f[det_idx] as &[f32]);
170 self.update_track(track_idx, &detections[det_idx], feat);
171 }
172 }
173 }
174
175 for det_idx in 0..n_dets {
177 if matched_dets[det_idx] {
178 continue;
179 }
180 let feat = features.map(|f| f[det_idx].clone());
181 self.create_track(&detections[det_idx], feat);
182 }
183
184 for (i, track) in self.tracks.iter_mut().enumerate() {
186 if matched_tracks.get(i).copied().unwrap_or(false) {
187 continue;
188 }
189 if track.state == TrackState::Deleted {
190 continue;
191 }
192 if i >= matched_tracks.len() {
194 continue;
195 }
196 if track.state == TrackState::Tentative && track.time_since_update > 0 {
197 track.state = TrackState::Deleted;
198 } else if track.time_since_update > self.config.max_age {
199 track.state = TrackState::Deleted;
200 }
201 }
202
203 self.tracks.retain(|t| t.state != TrackState::Deleted);
205 }
206
207 pub fn tracks(&self) -> &[DeepSortTrack] {
209 &self.tracks
210 }
211
212 pub fn confirmed_tracks(&self) -> Vec<&DeepSortTrack> {
214 self.tracks
215 .iter()
216 .filter(|t| t.state == TrackState::Confirmed)
217 .collect()
218 }
219
220 fn update_track(&mut self, track_idx: usize, detection: &Detection, feature: Option<&[f32]>) {
221 let track = &mut self.tracks[track_idx];
222 let bbox = detection.bbox;
223 let cx = (bbox.x1 + bbox.x2) * 0.5;
224 let cy = (bbox.y1 + bbox.y2) * 0.5;
225 let w = bbox.width();
226 let h = bbox.height();
227 track.kalman.update([cx, cy, w, h]);
228 track.hits += 1;
229 track.time_since_update = 0;
230 if let Some(feat) = feature {
231 track.features.push(feat.to_vec());
232 if track.features.len() > 100 {
234 track.features.remove(0);
235 }
236 }
237 if track.state == TrackState::Tentative && track.hits >= self.config.n_init {
238 track.state = TrackState::Confirmed;
239 }
240 }
241
242 fn create_track(&mut self, detection: &Detection, feature: Option<Vec<f32>>) {
243 let id = self.next_id;
244 self.next_id += 1;
245 let kalman = KalmanFilter::new(detection.bbox);
246 let mut features = Vec::new();
247 if let Some(f) = feature {
248 features.push(f);
249 }
250 self.tracks.push(DeepSortTrack {
251 id,
252 state: TrackState::Tentative,
253 kalman,
254 features,
255 hits: 1,
256 age: 1,
257 time_since_update: 0,
258 });
259 }
260}
261
/// Cosine distance `1 - (a · b) / (|a| |b|)` between two embeddings.
///
/// Up to rounding the result lies in `[0, 2]`: 0 for parallel vectors, 1 for
/// orthogonal ones and 2 for opposite ones. Only the first
/// `min(a.len(), b.len())` elements are compared. When the product of the
/// norms is below `1e-12` (e.g. an all-zero vector) the distance is 1.0.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(d, sa, sb), (&x, &y)| (d + x * y, sa + x * x, sb + y * y),
    );
    let denom = sq_norm_a.sqrt() * sq_norm_b.sqrt();
    // Keep the `<` form: a NaN denominator must fall through to the division.
    if denom < 1e-12 {
        1.0
    } else {
        1.0 - dot / denom
    }
}

/// Smallest cosine distance from `feature` to any embedding in `gallery`.
///
/// Returns `f32::INFINITY` for an empty gallery. NaN distances are skipped,
/// because `f32::min` returns the non-NaN operand.
fn min_cosine_distance(feature: &[f32], gallery: &[Vec<f32>]) -> f32 {
    let mut best = f32::INFINITY;
    for candidate in gallery {
        best = best.min(cosine_distance(feature, candidate));
    }
    best
}
286
#[cfg(test)]
mod tests {
    //! Behavioral tests for the tracker lifecycle (creation, confirmation,
    //! deletion), IoU/appearance association and the cosine-distance helpers.
    //!
    //! Note: a track's confirmation is only evaluated when it is matched, so a
    //! freshly created track is tentative even with `n_init = 1`. Tests that
    //! exercise `max_age` therefore confirm the track with a second hit first.
    //! Otherwise the "tentative track deleted on first miss" rule fires instead.

    use super::*;
    use yscv_detect::BoundingBox;

    /// Builds a class-0 detection with score 0.9 and the given corner coordinates.
    fn make_detection(x1: f32, y1: f32, x2: f32, y2: f32) -> Detection {
        Detection {
            bbox: BoundingBox { x1, y1, x2, y2 },
            score: 0.9,
            class_id: 0,
        }
    }

    #[test]
    fn test_deep_sort_creation() {
        let tracker = DeepSortTracker::new(DeepSortConfig::default());
        assert!(tracker.tracks().is_empty());
        assert!(tracker.confirmed_tracks().is_empty());
    }

    #[test]
    fn test_deep_sort_single_detection() {
        let mut tracker = DeepSortTracker::new(DeepSortConfig::default());
        let dets = [make_detection(10.0, 10.0, 50.0, 50.0)];
        tracker.predict();
        tracker.update(&dets, None);
        assert_eq!(tracker.tracks().len(), 1);
        assert_eq!(tracker.tracks()[0].state, TrackState::Tentative);
    }

    #[test]
    fn test_deep_sort_track_confirmation() {
        let config = DeepSortConfig {
            n_init: 3,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);
        let det = make_detection(10.0, 10.0, 50.0, 50.0);

        // Frame 1 creates the track (hits = 1).
        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.tracks()[0].state, TrackState::Tentative);

        // Frame 2: hits = 2, still below n_init.
        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.tracks()[0].state, TrackState::Tentative);

        // Frame 3: hits = 3 reaches n_init.
        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.tracks()[0].state, TrackState::Confirmed);
    }

    #[test]
    fn test_deep_sort_track_deletion() {
        let config = DeepSortConfig {
            max_age: 2,
            n_init: 1,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);
        let det = make_detection(10.0, 10.0, 50.0, 50.0);

        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.tracks().len(), 1);
        // Even with n_init = 1 the new track is still tentative...
        assert_eq!(tracker.tracks()[0].state, TrackState::Tentative);

        // ...so it is deleted on its very first miss, long before max_age applies.
        tracker.predict();
        tracker.update(&[], None);
        assert!(tracker.tracks().is_empty());
    }

    #[test]
    fn test_deep_sort_iou_matching() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        let dets1 = [
            make_detection(10.0, 10.0, 50.0, 50.0),
            make_detection(100.0, 100.0, 150.0, 150.0),
        ];
        tracker.predict();
        tracker.update(&dets1, None);
        assert_eq!(tracker.tracks().len(), 2);
        let id0 = tracker.tracks()[0].id;
        let id1 = tracker.tracks()[1].id;

        // Both objects shift slightly; IoU association must keep both ids.
        let dets2 = [
            make_detection(12.0, 12.0, 52.0, 52.0),
            make_detection(102.0, 102.0, 152.0, 152.0),
        ];
        tracker.predict();
        tracker.update(&dets2, None);
        assert_eq!(tracker.tracks().len(), 2);

        let ids: Vec<usize> = tracker.tracks().iter().map(|t| t.id).collect();
        assert!(ids.contains(&id0));
        assert!(ids.contains(&id1));
    }

    #[test]
    fn test_cosine_distance() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_distance(&a, &b) - 0.0).abs() < 1e-6);

        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_distance(&a, &c) - 1.0).abs() < 1e-6);

        let d = vec![-1.0, 0.0, 0.0];
        assert!((cosine_distance(&a, &d) - 2.0).abs() < 1e-6);

        // A zero-norm vector has no direction; the distance is defined as 1.0.
        let zero = vec![0.0, 0.0, 0.0];
        assert_eq!(cosine_distance(&a, &zero), 1.0);
    }

    #[test]
    fn test_deep_sort_multiple_objects_tracked() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        let dets = [
            make_detection(10.0, 10.0, 50.0, 50.0),
            make_detection(100.0, 100.0, 140.0, 140.0),
            make_detection(200.0, 200.0, 240.0, 240.0),
        ];
        tracker.predict();
        tracker.update(&dets, None);
        assert_eq!(tracker.tracks().len(), 3);

        let ids: Vec<usize> = tracker.tracks().iter().map(|t| t.id).collect();
        let mut unique = ids.clone();
        unique.sort();
        unique.dedup();
        assert_eq!(unique.len(), 3);
    }

    #[test]
    fn test_deep_sort_occlusion_and_reappearance() {
        let config = DeepSortConfig {
            n_init: 2,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        let det = make_detection(10.0, 10.0, 50.0, 50.0);
        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.tracks()[0].state, TrackState::Tentative);
        let original_id = tracker.tracks()[0].id;

        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.tracks()[0].state, TrackState::Confirmed);

        // Occluded for 3 frames: time_since_update = 3 <= max_age, track coasts.
        for _ in 0..3 {
            tracker.predict();
            tracker.update(&[], None);
        }
        assert!(!tracker.tracks().is_empty());

        // The object reappears and must be re-associated with the coasting track.
        tracker.predict();
        tracker.update(&[det], None);
        let ids: Vec<usize> = tracker.tracks().iter().map(|t| t.id).collect();
        assert!(ids.contains(&original_id));
    }

    #[test]
    fn test_deep_sort_id_stability_smooth_motion() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        let positions = [
            (10.0, 10.0, 50.0, 50.0),
            (12.0, 12.0, 52.0, 52.0),
            (14.0, 14.0, 54.0, 54.0),
            (16.0, 16.0, 56.0, 56.0),
            (18.0, 18.0, 58.0, 58.0),
        ];

        tracker.predict();
        tracker.update(
            &[make_detection(
                positions[0].0,
                positions[0].1,
                positions[0].2,
                positions[0].3,
            )],
            None,
        );
        let original_id = tracker.tracks()[0].id;

        for &(x1, y1, x2, y2) in &positions[1..] {
            tracker.predict();
            tracker.update(&[make_detection(x1, y1, x2, y2)], None);
            assert_eq!(tracker.tracks().len(), 1);
            assert_eq!(tracker.tracks()[0].id, original_id);
        }
    }

    #[test]
    fn test_deep_sort_deletion_after_max_age() {
        let config = DeepSortConfig {
            max_age: 3,
            n_init: 1,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);
        let det = make_detection(10.0, 10.0, 50.0, 50.0);

        // Two hits are needed to confirm (confirmation is checked on match).
        for _ in 0..2 {
            tracker.predict();
            tracker.update(&[det], None);
        }
        assert_eq!(tracker.tracks().len(), 1);
        assert_eq!(tracker.tracks()[0].state, TrackState::Confirmed);

        // A confirmed track survives up to max_age consecutive misses...
        for _ in 0..3 {
            tracker.predict();
            tracker.update(&[], None);
            assert_eq!(tracker.tracks().len(), 1);
        }
        // ...and is deleted once time_since_update exceeds max_age.
        tracker.predict();
        tracker.update(&[], None);
        assert!(tracker.tracks().is_empty());
    }

    #[test]
    fn test_deep_sort_new_track_far_apart() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        let det1 = make_detection(10.0, 10.0, 50.0, 50.0);
        tracker.predict();
        tracker.update(&[det1], None);
        let id1 = tracker.tracks()[0].id;
        tracker.predict();
        tracker.update(&[det1], None);

        // A non-overlapping detection fails the IoU gate and spawns a new
        // track, while the confirmed original keeps coasting.
        tracker.predict();
        tracker.update(&[make_detection(500.0, 500.0, 540.0, 540.0)], None);

        let ids: Vec<usize> = tracker.tracks().iter().map(|t| t.id).collect();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&id1), "Original track should still be coasting");
        assert!(
            ids.iter().any(|&id| id != id1),
            "Should have created a new track"
        );
    }

    #[test]
    fn test_deep_sort_empty_detections_ages_tracks() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 10,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        let det = make_detection(10.0, 10.0, 50.0, 50.0);
        tracker.predict();
        tracker.update(&[det], None);
        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.tracks()[0].time_since_update, 0);
        assert_eq!(tracker.tracks()[0].state, TrackState::Confirmed);

        // One empty frame: predict bumps time_since_update, update leaves it.
        tracker.predict();
        tracker.update(&[], None);
        assert_eq!(tracker.tracks().len(), 1);
        assert_eq!(tracker.tracks()[0].time_since_update, 1);
    }

    #[test]
    fn test_deep_sort_single_detection_stable_id() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        let det = make_detection(20.0, 20.0, 60.0, 60.0);

        tracker.predict();
        tracker.update(&[det], None);
        let id = tracker.tracks()[0].id;

        for _ in 0..10 {
            tracker.predict();
            tracker.update(&[det], None);
            assert_eq!(tracker.tracks().len(), 1);
            assert_eq!(tracker.tracks()[0].id, id);
        }
    }

    #[test]
    fn test_deep_sort_overlapping_detections() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        // With no existing tracks, every detection starts its own track even
        // though the boxes overlap heavily.
        let dets = [
            make_detection(10.0, 10.0, 50.0, 50.0),
            make_detection(12.0, 12.0, 52.0, 52.0),
            make_detection(14.0, 14.0, 54.0, 54.0),
        ];
        tracker.predict();
        tracker.update(&dets, None);
        assert_eq!(tracker.tracks().len(), 3);
    }

    #[test]
    fn test_deep_sort_config_variations_max_age() {
        let config = DeepSortConfig {
            max_age: 1,
            n_init: 1,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);
        let det = make_detection(10.0, 10.0, 50.0, 50.0);

        // Confirm first so that max_age (not the tentative rule) governs deletion.
        for _ in 0..2 {
            tracker.predict();
            tracker.update(&[det], None);
        }
        assert_eq!(tracker.tracks()[0].state, TrackState::Confirmed);

        // One miss: time_since_update = 1 == max_age, still kept.
        tracker.predict();
        tracker.update(&[], None);
        assert_eq!(tracker.tracks().len(), 1);

        // Second miss: time_since_update = 2 > max_age, deleted.
        tracker.predict();
        tracker.update(&[], None);
        assert!(tracker.tracks().is_empty());
    }

    #[test]
    fn test_deep_sort_config_variations_n_init() {
        let config = DeepSortConfig {
            n_init: 5,
            max_age: 30,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);
        let det = make_detection(10.0, 10.0, 50.0, 50.0);

        for i in 0..5 {
            tracker.predict();
            tracker.update(&[det], None);
            if i < 4 {
                assert_eq!(tracker.tracks()[0].state, TrackState::Tentative);
            }
        }
        assert_eq!(tracker.tracks()[0].state, TrackState::Confirmed);
    }

    #[test]
    fn test_deep_sort_appearance_matching_with_features() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            max_cosine_distance: 0.5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        let det = make_detection(10.0, 10.0, 50.0, 50.0);
        let feat = vec![vec![1.0, 0.0, 0.0]];
        tracker.predict();
        tracker.update(&[det], Some(&feat));
        let id = tracker.tracks()[0].id;
        assert!(!tracker.tracks()[0].features.is_empty());

        let det2 = make_detection(12.0, 12.0, 52.0, 52.0);
        let feat2 = vec![vec![0.99, 0.1, 0.0]];
        tracker.predict();
        tracker.update(&[det2], Some(&feat2));
        assert_eq!(tracker.tracks()[0].id, id);
        // The matched detection's embedding is appended to the gallery.
        assert_eq!(tracker.tracks()[0].features.len(), 2);
    }

    #[test]
    fn test_deep_sort_feature_gallery_is_capped() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);
        let det = make_detection(10.0, 10.0, 50.0, 50.0);
        let feat = vec![vec![1.0, 0.0, 0.0]];

        for _ in 0..150 {
            tracker.predict();
            tracker.update(&[det], Some(&feat));
        }
        assert_eq!(tracker.tracks().len(), 1);
        // Oldest embeddings are evicted once the gallery exceeds 100 entries.
        assert_eq!(tracker.tracks()[0].features.len(), 100);
    }

    #[test]
    fn test_deep_sort_confirmed_tracks_filter() {
        let config = DeepSortConfig {
            n_init: 3,
            max_age: 10,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);
        let det = make_detection(10.0, 10.0, 50.0, 50.0);

        tracker.predict();
        tracker.update(&[det], None);
        assert!(tracker.confirmed_tracks().is_empty());

        tracker.predict();
        tracker.update(&[det], None);
        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.confirmed_tracks().len(), 1);
    }

    #[test]
    fn test_min_cosine_distance_gallery() {
        let gallery = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let query = vec![0.9, 0.1, 0.0];
        let dist = min_cosine_distance(&query, &gallery);
        assert!(dist < 0.2);
        // The nearest gallery entry ([1, 0, 0]) determines the result exactly.
        assert_eq!(dist, cosine_distance(&query, &gallery[0]));
        // An empty gallery has no nearest neighbour.
        assert_eq!(min_cosine_distance(&query, &[]), f32::INFINITY);
    }
}