Skip to main content

yscv_track/
tracker.rs

1use yscv_detect::{BoundingBox, CLASS_ID_PERSON, Detection, iou};
2
3use crate::motion::{
4    MotionState, apply_motion, bbox_size_similarity, normalized_center_distance,
5    update_motion_state,
6};
7use crate::{Track, TrackError, TrackedDetection, TrackerConfig};
8
9#[derive(Debug, Clone)]
10pub struct Tracker {
11    config: TrackerConfig,
12    next_track_id: u64,
13    tracks: Vec<Track>,
14    motion: Vec<MotionState>,
15    pair_candidates: Vec<(f32, usize, usize)>,
16    track_taken: Vec<bool>,
17    det_taken: Vec<bool>,
18    det_to_track: Vec<Option<u64>>,
19}
20
21impl Tracker {
22    /// Creates a tracker with validated configuration.
23    pub fn new(config: TrackerConfig) -> Result<Self, TrackError> {
24        config.validate()?;
25        Ok(Self {
26            config,
27            next_track_id: 1,
28            tracks: Vec::new(),
29            motion: Vec::new(),
30            pair_candidates: Vec::new(),
31            track_taken: Vec::new(),
32            det_taken: Vec::new(),
33            det_to_track: Vec::new(),
34        })
35    }
36
37    /// Updates tracker state for one frame and returns tracked detections.
38    ///
39    /// For allocation-sensitive runtime loops, prefer [`Tracker::update_into`]
40    /// with a caller-owned output buffer that can be reused across frames.
41    pub fn update(&mut self, detections: &[Detection]) -> Vec<TrackedDetection> {
42        let mut out = Vec::with_capacity(detections.len());
43        self.update_into(detections, &mut out);
44        out
45    }
46
47    /// Updates tracker state for one frame and writes tracked detections into `out`.
48    ///
49    /// This API allows callers to reuse `out` across frames and avoid
50    /// allocating a fresh output vector for each `update` call.
51    pub fn update_into(&mut self, detections: &[Detection], out: &mut Vec<TrackedDetection>) {
52        debug_assert_eq!(self.tracks.len(), self.motion.len());
53
54        self.pair_candidates.clear();
55        for (track_idx, track) in self.tracks.iter().enumerate() {
56            let predicted = self.predict_bbox(track_idx, track);
57            for (det_idx, det) in detections.iter().enumerate() {
58                if track.class_id != det.class_id {
59                    continue;
60                }
61                if let Some(match_score) = self.match_score(
62                    track.missed_frames,
63                    predicted,
64                    det.bbox,
65                    self.config.match_iou_threshold,
66                ) {
67                    self.pair_candidates.push((match_score, track_idx, det_idx));
68                }
69            }
70        }
71        self.pair_candidates
72            .sort_by(|left, right| right.0.total_cmp(&left.0));
73
74        self.track_taken.clear();
75        self.track_taken.resize(self.tracks.len(), false);
76        self.det_taken.clear();
77        self.det_taken.resize(detections.len(), false);
78        self.det_to_track.clear();
79        self.det_to_track.resize(detections.len(), None);
80
81        for (_match_score, track_idx, det_idx) in self.pair_candidates.iter().copied() {
82            if self.track_taken[track_idx] || self.det_taken[det_idx] {
83                continue;
84            }
85            self.track_taken[track_idx] = true;
86            self.det_taken[det_idx] = true;
87            let det = detections[det_idx];
88            let track = &mut self.tracks[track_idx];
89            let previous_bbox = track.bbox;
90            track.bbox = det.bbox;
91            track.score = det.score;
92            track.class_id = det.class_id;
93            track.age += 1;
94            track.hits += 1;
95            track.missed_frames = 0;
96            update_motion_state(&mut self.motion[track_idx], previous_bbox, det.bbox);
97            self.det_to_track[det_idx] = Some(track.id);
98        }
99
100        for (idx, track) in self.tracks.iter_mut().enumerate() {
101            if !self.track_taken[idx] {
102                track.bbox = apply_motion(track.bbox, &self.motion[idx], 1.0);
103                track.age += 1;
104                track.missed_frames += 1;
105            }
106        }
107
108        let mut write = 0usize;
109        for read in 0..self.tracks.len() {
110            if self.tracks[read].missed_frames <= self.config.max_missed_frames {
111                if write != read {
112                    self.tracks[write] = self.tracks[read];
113                    self.motion[write] = self.motion[read];
114                }
115                write += 1;
116            }
117        }
118        self.tracks.truncate(write);
119        self.motion.truncate(write);
120
121        for (det_idx, det) in detections.iter().enumerate() {
122            if self.det_taken[det_idx] {
123                continue;
124            }
125            if self.tracks.len() >= self.config.max_tracks {
126                break;
127            }
128            let track_id = self.alloc_track_id();
129            self.tracks.push(Track {
130                id: track_id,
131                bbox: det.bbox,
132                score: det.score,
133                class_id: det.class_id,
134                age: 1,
135                hits: 1,
136                missed_frames: 0,
137            });
138            self.motion.push(MotionState::default());
139            self.det_to_track[det_idx] = Some(track_id);
140        }
141
142        out.clear();
143        if out.capacity() < detections.len() {
144            out.reserve(detections.len() - out.capacity());
145        }
146        for (det_idx, det) in detections.iter().enumerate() {
147            if let Some(track_id) = self.det_to_track[det_idx] {
148                out.push(TrackedDetection {
149                    track_id,
150                    detection: *det,
151                });
152            }
153        }
154    }
155
156    /// Returns active tracks owned by this tracker.
157    pub fn active_tracks(&self) -> &[Track] {
158        &self.tracks
159    }
160
161    /// Counts active tracks for the provided class id.
162    pub fn count_by_class(&self, class_id: usize) -> usize {
163        self.tracks
164            .iter()
165            .filter(|track| track.class_id == class_id)
166            .count()
167    }
168
169    /// Counts active person-class tracks.
170    pub fn people_count(&self) -> usize {
171        self.count_by_class(CLASS_ID_PERSON)
172    }
173
174    fn alloc_track_id(&mut self) -> u64 {
175        let id = self.next_track_id;
176        self.next_track_id += 1;
177        id
178    }
179
180    fn predict_bbox(&self, track_idx: usize, track: &Track) -> BoundingBox {
181        if track.missed_frames == 0 {
182            return track.bbox;
183        }
184        apply_motion(track.bbox, &self.motion[track_idx], 1.0)
185    }
186
187    fn match_score(
188        &self,
189        missed_frames: u32,
190        predicted: BoundingBox,
191        detection: BoundingBox,
192        iou_threshold: f32,
193    ) -> Option<f32> {
194        let overlap = iou(predicted, detection);
195        if overlap >= iou_threshold {
196            return Some(overlap);
197        }
198
199        let center_distance = normalized_center_distance(predicted, detection);
200        let size_similarity = bbox_size_similarity(predicted, detection);
201        let proximity_score = 1.0 / (1.0 + center_distance);
202        let blended = 0.6 * overlap + 0.4 * proximity_score;
203
204        if missed_frames > 0 && center_distance <= 2.0 && size_similarity >= 0.5 && blended >= 0.35
205        {
206            Some(blended)
207        } else {
208            None
209        }
210    }
211}