Skip to main content

yscv_track/
deep_sort.rs

1//! DeepSORT-style multi-object tracker with appearance features and cascade matching.
2
3use yscv_detect::{Detection, iou};
4
5use crate::KalmanFilter;
6use crate::hungarian::hungarian_assignment;
7
/// Configuration for the DeepSORT tracker.
///
/// Start from [`DeepSortConfig::default`] and override individual fields with
/// struct-update syntax, e.g.
/// `DeepSortConfig { n_init: 1, ..DeepSortConfig::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct DeepSortConfig {
    /// Largest cosine distance (`1 - cosine similarity`, nominally in `[0, 2]`)
    /// at which a detection may be associated with a confirmed track by
    /// appearance. The comparison is inclusive.
    pub max_cosine_distance: f32,
    /// Largest IoU distance (`1 - IoU`) accepted by the IoU matching stage.
    /// The comparison is inclusive.
    pub max_iou_distance: f32,
    /// Maximum number of consecutive missed frames a confirmed track survives;
    /// it is deleted once its `time_since_update` exceeds this value.
    pub max_age: usize,
    /// Number of hits required to promote a tentative track to confirmed.
    /// Tentative tracks are deleted on their first miss, so these hits are
    /// effectively consecutive.
    pub n_init: usize,
}
20
21impl Default for DeepSortConfig {
22    fn default() -> Self {
23        Self {
24            max_cosine_distance: 0.3,
25            max_iou_distance: 0.7,
26            max_age: 30,
27            n_init: 3,
28        }
29    }
30}
31
/// Track state in the DeepSORT lifecycle.
///
/// A track moves `Tentative` → `Confirmed` once it has enough hits, and to
/// `Deleted` when it misses too many frames (a tentative track is deleted on
/// its first miss).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrackState {
    /// Track has been created but not yet confirmed; an unmatched tentative
    /// track is deleted.
    Tentative,
    /// Track has been confirmed after accumulating at least `n_init` hits; it
    /// is only deleted once `time_since_update` exceeds `max_age`.
    Confirmed,
    /// Track has been marked for deletion; the tracker removes such tracks at
    /// the end of each `update` call.
    Deleted,
}
42
/// A tracked object with appearance features.
#[derive(Debug, Clone)]
pub struct DeepSortTrack {
    /// Unique track identifier, assigned from a per-tracker counter starting at 1.
    pub id: usize,
    /// Current lifecycle state.
    pub state: TrackState,
    /// Kalman filter for motion prediction; its predicted box is what the IoU
    /// matching stage compares detections against.
    pub kalman: KalmanFilter,
    /// Appearance-feature gallery used for cosine matching; the tracker keeps
    /// at most the 100 most recent entries, evicting the oldest first.
    pub features: Vec<Vec<f32>>,
    /// Number of detections associated with this track, including the one
    /// that created it.
    pub hits: usize,
    /// Frames since creation, counting the creation frame (starts at 1 and is
    /// incremented by every `predict`).
    pub age: usize,
    /// Frames since the last matched detection: incremented by `predict`,
    /// reset to 0 whenever a detection is matched.
    pub time_since_update: usize,
}
61
62/// DeepSORT multi-object tracker.
63pub struct DeepSortTracker {
64    config: DeepSortConfig,
65    tracks: Vec<DeepSortTrack>,
66    next_id: usize,
67}
68
69impl DeepSortTracker {
70    /// Create a new DeepSORT tracker with the given configuration.
71    pub fn new(config: DeepSortConfig) -> Self {
72        Self {
73            config,
74            tracks: Vec::new(),
75            next_id: 1,
76        }
77    }
78
79    /// Predict the next state for all tracks using their Kalman filters.
80    pub fn predict(&mut self) {
81        for track in &mut self.tracks {
82            track.kalman.predict();
83            track.age += 1;
84            track.time_since_update += 1;
85        }
86    }
87
88    /// Main update step: match detections to tracks and update state.
89    ///
90    /// `detections` — the current frame's detections.
91    /// `features` — optional appearance feature vectors, one per detection.
92    pub fn update(&mut self, detections: &[Detection], features: Option<&[Vec<f32>]>) {
93        // Split track indices into confirmed and unconfirmed.
94        let mut confirmed_indices: Vec<usize> = Vec::new();
95        let mut unconfirmed_indices: Vec<usize> = Vec::new();
96        for (i, track) in self.tracks.iter().enumerate() {
97            match track.state {
98                TrackState::Confirmed => confirmed_indices.push(i),
99                TrackState::Tentative => unconfirmed_indices.push(i),
100                TrackState::Deleted => {}
101            }
102        }
103
104        let n_dets = detections.len();
105        let mut matched_tracks: Vec<bool> = vec![false; self.tracks.len()];
106        let mut matched_dets: Vec<bool> = vec![false; n_dets];
107
108        // ── Stage 1: Cascade matching (appearance) on confirmed tracks ──
109        if let Some(feats) = features
110            && !confirmed_indices.is_empty()
111            && !detections.is_empty()
112        {
113            let n_tracks = confirmed_indices.len();
114            let mut cost_matrix = vec![vec![0.0_f32; n_dets]; n_tracks];
115            for (ti, &track_idx) in confirmed_indices.iter().enumerate() {
116                let track = &self.tracks[track_idx];
117                for dj in 0..n_dets {
118                    if track.features.is_empty() {
119                        // No features yet — use a high cost so IoU matching picks it up.
120                        cost_matrix[ti][dj] = self.config.max_cosine_distance + 1.0;
121                    } else {
122                        cost_matrix[ti][dj] = min_cosine_distance(&feats[dj], &track.features);
123                    }
124                }
125            }
126
127            let assignments = hungarian_assignment(&cost_matrix);
128            for (ti, dj) in assignments {
129                if cost_matrix[ti][dj] <= self.config.max_cosine_distance {
130                    let track_idx = confirmed_indices[ti];
131                    matched_tracks[track_idx] = true;
132                    matched_dets[dj] = true;
133                    self.update_track(track_idx, &detections[dj], Some(&feats[dj]));
134                }
135            }
136        }
137
138        // ── Stage 2: IoU matching on remaining tracks vs remaining detections ──
139        // Collect unmatched track indices (confirmed that weren't matched + all unconfirmed).
140        let mut iou_track_indices: Vec<usize> = Vec::new();
141        for &ti in &confirmed_indices {
142            if !matched_tracks[ti] {
143                iou_track_indices.push(ti);
144            }
145        }
146        iou_track_indices.extend_from_slice(&unconfirmed_indices);
147
148        let unmatched_det_indices: Vec<usize> = (0..n_dets).filter(|&d| !matched_dets[d]).collect();
149
150        if !iou_track_indices.is_empty() && !unmatched_det_indices.is_empty() {
151            let n_t = iou_track_indices.len();
152            let n_d = unmatched_det_indices.len();
153            let mut cost_matrix = vec![vec![0.0_f32; n_d]; n_t];
154            for (ti, &track_idx) in iou_track_indices.iter().enumerate() {
155                let predicted = self.tracks[track_idx].kalman.bbox();
156                for (dj, &det_idx) in unmatched_det_indices.iter().enumerate() {
157                    let iou_val = iou(predicted, detections[det_idx].bbox);
158                    cost_matrix[ti][dj] = 1.0 - iou_val; // IoU distance
159                }
160            }
161
162            let assignments = hungarian_assignment(&cost_matrix);
163            for (ti, dj) in assignments {
164                if cost_matrix[ti][dj] <= self.config.max_iou_distance {
165                    let track_idx = iou_track_indices[ti];
166                    let det_idx = unmatched_det_indices[dj];
167                    matched_tracks[track_idx] = true;
168                    matched_dets[det_idx] = true;
169                    let feat = features.map(|f| &f[det_idx] as &[f32]);
170                    self.update_track(track_idx, &detections[det_idx], feat);
171                }
172            }
173        }
174
175        // ── Create new tracks for unmatched detections ──
176        for det_idx in 0..n_dets {
177            if matched_dets[det_idx] {
178                continue;
179            }
180            let feat = features.map(|f| f[det_idx].clone());
181            self.create_track(&detections[det_idx], feat);
182        }
183
184        // ── Mark unmatched tracks ──
185        for (i, track) in self.tracks.iter_mut().enumerate() {
186            if matched_tracks.get(i).copied().unwrap_or(false) {
187                continue;
188            }
189            if track.state == TrackState::Deleted {
190                continue;
191            }
192            // For newly created tracks (not in matched_tracks vec), skip.
193            if i >= matched_tracks.len() {
194                continue;
195            }
196            if track.state == TrackState::Tentative && track.time_since_update > 0 {
197                track.state = TrackState::Deleted;
198            } else if track.time_since_update > self.config.max_age {
199                track.state = TrackState::Deleted;
200            }
201        }
202
203        // Remove deleted tracks.
204        self.tracks.retain(|t| t.state != TrackState::Deleted);
205    }
206
207    /// Get all active (non-deleted) tracks.
208    pub fn tracks(&self) -> &[DeepSortTrack] {
209        &self.tracks
210    }
211
212    /// Get only confirmed tracks.
213    pub fn confirmed_tracks(&self) -> Vec<&DeepSortTrack> {
214        self.tracks
215            .iter()
216            .filter(|t| t.state == TrackState::Confirmed)
217            .collect()
218    }
219
220    fn update_track(&mut self, track_idx: usize, detection: &Detection, feature: Option<&[f32]>) {
221        let track = &mut self.tracks[track_idx];
222        let bbox = detection.bbox;
223        let cx = (bbox.x1 + bbox.x2) * 0.5;
224        let cy = (bbox.y1 + bbox.y2) * 0.5;
225        let w = bbox.width();
226        let h = bbox.height();
227        track.kalman.update([cx, cy, w, h]);
228        track.hits += 1;
229        track.time_since_update = 0;
230        if let Some(feat) = feature {
231            track.features.push(feat.to_vec());
232            // Keep only last 100 features.
233            if track.features.len() > 100 {
234                track.features.remove(0);
235            }
236        }
237        if track.state == TrackState::Tentative && track.hits >= self.config.n_init {
238            track.state = TrackState::Confirmed;
239        }
240    }
241
242    fn create_track(&mut self, detection: &Detection, feature: Option<Vec<f32>>) {
243        let id = self.next_id;
244        self.next_id += 1;
245        let kalman = KalmanFilter::new(detection.bbox);
246        let mut features = Vec::new();
247        if let Some(f) = feature {
248            features.push(f);
249        }
250        self.tracks.push(DeepSortTrack {
251            id,
252            state: TrackState::Tentative,
253            kalman,
254            features,
255            hits: 1,
256            age: 1,
257            time_since_update: 0,
258        });
259    }
260}
261
/// Cosine distance `1 - cos(θ)` between two feature vectors.
///
/// Up to rounding, the result is `0` for vectors pointing the same way, `1`
/// for orthogonal vectors and `2` for opposite ones. Only the first
/// `min(a.len(), b.len())` components are compared. If the product of the two
/// norms is below `1e-12` (e.g. either vector is all zeros), the distance is
/// reported as `1.0`.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let (dot, sq_a, sq_b) = a.iter().zip(b.iter()).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );
    let norm_product = sq_a.sqrt() * sq_b.sqrt();
    if norm_product < 1e-12 {
        // Degenerate (zero) vector: treat it as orthogonal to everything.
        1.0
    } else {
        1.0 - dot / norm_product
    }
}
278
279/// Minimum cosine distance between a feature and a gallery of features.
280fn min_cosine_distance(feature: &[f32], gallery: &[Vec<f32>]) -> f32 {
281    gallery
282        .iter()
283        .map(|g| cosine_distance(feature, g))
284        .fold(f32::INFINITY, f32::min)
285}
286
// Unit tests for the DeepSORT tracker. Every scenario that advances the
// tracker calls `predict()` immediately before each `update()`, mirroring the
// intended per-frame usage.
#[cfg(test)]
mod tests {
    use super::*;
    use yscv_detect::BoundingBox;

    /// Build a detection from box corners, with a fixed score of 0.9 and class 0.
    fn make_detection(x1: f32, y1: f32, x2: f32, y2: f32) -> Detection {
        Detection {
            bbox: BoundingBox { x1, y1, x2, y2 },
            score: 0.9,
            class_id: 0,
        }
    }

    #[test]
    fn test_deep_sort_creation() {
        let tracker = DeepSortTracker::new(DeepSortConfig::default());
        assert!(tracker.tracks().is_empty());
        assert!(tracker.confirmed_tracks().is_empty());
    }

    #[test]
    fn test_deep_sort_single_detection() {
        let mut tracker = DeepSortTracker::new(DeepSortConfig::default());
        let dets = [make_detection(10.0, 10.0, 50.0, 50.0)];
        tracker.predict();
        tracker.update(&dets, None);
        assert_eq!(tracker.tracks().len(), 1);
        assert_eq!(tracker.tracks()[0].state, TrackState::Tentative);
    }

    #[test]
    fn test_deep_sort_track_confirmation() {
        let config = DeepSortConfig {
            n_init: 3,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);
        let det = make_detection(10.0, 10.0, 50.0, 50.0);

        // First detection creates tentative track.
        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.tracks()[0].state, TrackState::Tentative);

        // Second hit.
        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.tracks()[0].state, TrackState::Tentative);

        // Third hit → confirmed.
        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.tracks()[0].state, TrackState::Confirmed);
    }

    #[test]
    fn test_deep_sort_track_deletion() {
        let config = DeepSortConfig {
            max_age: 2,
            n_init: 1,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);
        let det = make_detection(10.0, 10.0, 50.0, 50.0);

        // Create a single track from one detection.
        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.tracks().len(), 1);

        // Run more empty frames than max_age tolerates → the track must be gone.
        for _ in 0..4 {
            tracker.predict();
            tracker.update(&[], None);
        }
        assert!(tracker.tracks().is_empty());
    }

    #[test]
    fn test_deep_sort_iou_matching() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        // Frame 1: two detections.
        let dets1 = [
            make_detection(10.0, 10.0, 50.0, 50.0),
            make_detection(100.0, 100.0, 150.0, 150.0),
        ];
        tracker.predict();
        tracker.update(&dets1, None);
        assert_eq!(tracker.tracks().len(), 2);
        let id0 = tracker.tracks()[0].id;
        let id1 = tracker.tracks()[1].id;

        // Frame 2: same detections, slightly moved.
        let dets2 = [
            make_detection(12.0, 12.0, 52.0, 52.0),
            make_detection(102.0, 102.0, 152.0, 152.0),
        ];
        tracker.predict();
        tracker.update(&dets2, None);
        assert_eq!(tracker.tracks().len(), 2);

        // Track IDs should be preserved (same objects matched).
        let ids: Vec<usize> = tracker.tracks().iter().map(|t| t.id).collect();
        assert!(ids.contains(&id0));
        assert!(ids.contains(&id1));
    }

    #[test]
    fn test_cosine_distance() {
        // Identical vectors → distance 0.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_distance(&a, &b) - 0.0).abs() < 1e-6);

        // Orthogonal vectors → distance 1.
        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_distance(&a, &c) - 1.0).abs() < 1e-6);

        // Opposite vectors → distance 2.
        let d = vec![-1.0, 0.0, 0.0];
        assert!((cosine_distance(&a, &d) - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_deep_sort_multiple_objects_tracked() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        // Three detections far apart.
        let dets = [
            make_detection(10.0, 10.0, 50.0, 50.0),
            make_detection(100.0, 100.0, 140.0, 140.0),
            make_detection(200.0, 200.0, 240.0, 240.0),
        ];
        tracker.predict();
        tracker.update(&dets, None);
        assert_eq!(tracker.tracks().len(), 3);

        let ids: Vec<usize> = tracker.tracks().iter().map(|t| t.id).collect();
        // All IDs should be unique.
        let mut unique = ids.clone();
        unique.sort();
        unique.dedup();
        assert_eq!(unique.len(), 3);
    }

    #[test]
    fn test_deep_sort_occlusion_and_reappearance() {
        // Confirmed tracks survive occlusion up to max_age.
        let config = DeepSortConfig {
            n_init: 2,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        let det = make_detection(10.0, 10.0, 50.0, 50.0);
        // First frame: creates tentative track.
        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.tracks()[0].state, TrackState::Tentative);
        let original_id = tracker.tracks()[0].id;

        // Second frame: confirms the track (hits=2 >= n_init=2).
        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.tracks()[0].state, TrackState::Confirmed);

        // Object disappears for 3 frames (within max_age=5).
        for _ in 0..3 {
            tracker.predict();
            tracker.update(&[], None);
        }
        // Confirmed track should still exist (time_since_update=3 <= max_age=5).
        assert!(!tracker.tracks().is_empty());

        // Object reappears at same position.
        tracker.predict();
        tracker.update(&[det], None);
        let ids: Vec<usize> = tracker.tracks().iter().map(|t| t.id).collect();
        assert!(ids.contains(&original_id));
    }

    #[test]
    fn test_deep_sort_id_stability_smooth_motion() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        // Object moves smoothly across 5 frames.
        let positions = [
            (10.0, 10.0, 50.0, 50.0),
            (12.0, 12.0, 52.0, 52.0),
            (14.0, 14.0, 54.0, 54.0),
            (16.0, 16.0, 56.0, 56.0),
            (18.0, 18.0, 58.0, 58.0),
        ];

        tracker.predict();
        tracker.update(
            &[make_detection(
                positions[0].0,
                positions[0].1,
                positions[0].2,
                positions[0].3,
            )],
            None,
        );
        let original_id = tracker.tracks()[0].id;

        for &(x1, y1, x2, y2) in &positions[1..] {
            tracker.predict();
            tracker.update(&[make_detection(x1, y1, x2, y2)], None);
            assert_eq!(tracker.tracks().len(), 1);
            assert_eq!(tracker.tracks()[0].id, original_id);
        }
    }

    #[test]
    fn test_deep_sort_deletion_after_max_age() {
        let config = DeepSortConfig {
            max_age: 3,
            n_init: 1,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        tracker.predict();
        tracker.update(&[make_detection(10.0, 10.0, 50.0, 50.0)], None);
        assert_eq!(tracker.tracks().len(), 1);

        // Exactly max_age frames without detection.
        for _ in 0..3 {
            tracker.predict();
            tracker.update(&[], None);
        }
        // A confirmed track would still be alive at this point (time_since_update==3,
        // max_age==3, deletion uses >); only the final state is asserted below.
        // One more frame to exceed max_age.
        tracker.predict();
        tracker.update(&[], None);
        assert!(tracker.tracks().is_empty());
    }

    #[test]
    fn test_deep_sort_new_track_far_apart() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        // Create and confirm a track with two hits.
        let det1 = make_detection(10.0, 10.0, 50.0, 50.0);
        tracker.predict();
        tracker.update(&[det1], None);
        let id1 = tracker.tracks()[0].id;
        tracker.predict();
        tracker.update(&[det1], None);

        // Detection very far away: original track unmatched, new track created.
        tracker.predict();
        tracker.update(&[make_detection(500.0, 500.0, 540.0, 540.0)], None);

        let ids: Vec<usize> = tracker.tracks().iter().map(|t| t.id).collect();
        assert!(ids.len() >= 2);
        assert!(
            ids.iter().any(|&id| id != id1),
            "Should have created a new track"
        );
    }

    #[test]
    fn test_deep_sort_empty_detections_ages_tracks() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 10,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        let det = make_detection(10.0, 10.0, 50.0, 50.0);
        // Create and confirm track with two hits so it survives a miss.
        tracker.predict();
        tracker.update(&[det], None);
        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.tracks()[0].time_since_update, 0);
        assert_eq!(tracker.tracks()[0].state, TrackState::Confirmed);

        // Empty frame should increment time_since_update for confirmed tracks.
        tracker.predict();
        tracker.update(&[], None);
        assert_eq!(tracker.tracks().len(), 1);
        assert!(tracker.tracks()[0].time_since_update > 0);
    }

    #[test]
    fn test_deep_sort_single_detection_stable_id() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        let det = make_detection(20.0, 20.0, 60.0, 60.0);

        tracker.predict();
        tracker.update(&[det], None);
        let id = tracker.tracks()[0].id;

        // Repeat same detection for 10 frames.
        for _ in 0..10 {
            tracker.predict();
            tracker.update(&[det], None);
            assert_eq!(tracker.tracks().len(), 1);
            assert_eq!(tracker.tracks()[0].id, id);
        }
    }

    #[test]
    fn test_deep_sort_overlapping_detections() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        // Multiple heavily overlapping detections.
        let dets = [
            make_detection(10.0, 10.0, 50.0, 50.0),
            make_detection(12.0, 12.0, 52.0, 52.0),
            make_detection(14.0, 14.0, 54.0, 54.0),
        ];
        tracker.predict();
        tracker.update(&dets, None);
        // The tracker starts empty, so every detection is unmatched and creates a track.
        assert_eq!(tracker.tracks().len(), 3);
    }

    #[test]
    fn test_deep_sort_config_variations_max_age() {
        // Very short max_age.
        let config = DeepSortConfig {
            max_age: 1,
            n_init: 1,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        tracker.predict();
        tracker.update(&[make_detection(10.0, 10.0, 50.0, 50.0)], None);

        // Two empty frames should delete with max_age=1.
        tracker.predict();
        tracker.update(&[], None);
        tracker.predict();
        tracker.update(&[], None);
        assert!(tracker.tracks().is_empty());
    }

    #[test]
    fn test_deep_sort_config_variations_n_init() {
        // Need 5 hits to confirm.
        let config = DeepSortConfig {
            n_init: 5,
            max_age: 30,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);
        let det = make_detection(10.0, 10.0, 50.0, 50.0);

        for i in 0..5 {
            tracker.predict();
            tracker.update(&[det], None);
            if i < 4 {
                assert_eq!(tracker.tracks()[0].state, TrackState::Tentative);
            }
        }
        assert_eq!(tracker.tracks()[0].state, TrackState::Confirmed);
    }

    #[test]
    fn test_deep_sort_appearance_matching_with_features() {
        let config = DeepSortConfig {
            n_init: 1,
            max_age: 5,
            max_cosine_distance: 0.5,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);

        let det = make_detection(10.0, 10.0, 50.0, 50.0);
        let feat = vec![vec![1.0, 0.0, 0.0]];
        tracker.predict();
        tracker.update(&[det], Some(&feat));
        let id = tracker.tracks()[0].id;
        assert!(!tracker.tracks()[0].features.is_empty());

        // Same feature, slightly moved detection.
        let det2 = make_detection(12.0, 12.0, 52.0, 52.0);
        let feat2 = vec![vec![0.99, 0.1, 0.0]]; // similar feature
        tracker.predict();
        tracker.update(&[det2], Some(&feat2));
        assert_eq!(tracker.tracks()[0].id, id);
    }

    #[test]
    fn test_deep_sort_confirmed_tracks_filter() {
        let config = DeepSortConfig {
            n_init: 3,
            max_age: 10,
            ..DeepSortConfig::default()
        };
        let mut tracker = DeepSortTracker::new(config);
        let det = make_detection(10.0, 10.0, 50.0, 50.0);

        // After 1 hit: tentative.
        tracker.predict();
        tracker.update(&[det], None);
        assert!(tracker.confirmed_tracks().is_empty());

        // After 3 hits: confirmed.
        tracker.predict();
        tracker.update(&[det], None);
        tracker.predict();
        tracker.update(&[det], None);
        assert_eq!(tracker.confirmed_tracks().len(), 1);
    }

    #[test]
    fn test_min_cosine_distance_gallery() {
        let gallery = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let query = vec![0.9, 0.1, 0.0]; // closer to gallery[0]
        let dist = min_cosine_distance(&query, &gallery);
        // Should be close to 0 (matching gallery[0]).
        assert!(dist < 0.2);
    }
}