1use yscv_detect::{BoundingBox, CLASS_ID_PERSON, Detection, iou};
2
3use crate::motion::{
4 MotionState, apply_motion, bbox_size_similarity, normalized_center_distance,
5 update_motion_state,
6};
7use crate::{Track, TrackError, TrackedDetection, TrackerConfig};
8
/// Frame-to-frame multi-object tracker that associates detections with
/// persistent track IDs.
///
/// `tracks` and `motion` are parallel vectors: index `i` of each describes the
/// same track. The remaining `Vec` fields are per-update scratch buffers that
/// are cleared and refilled on every update rather than rebuilt, so their
/// allocations can be reused across frames.
#[derive(Debug, Clone)]
pub struct Tracker {
    /// Configuration, validated in `Tracker::new`.
    config: TrackerConfig,
    /// ID handed to the next newly created track (starts at 1).
    next_track_id: u64,
    /// Currently retained tracks, including ones coasting through missed frames.
    tracks: Vec<Track>,
    /// Motion estimate for each entry of `tracks` (same index).
    motion: Vec<MotionState>,
    /// Scratch: `(score, track_idx, det_idx)` association candidates.
    pair_candidates: Vec<(f32, usize, usize)>,
    /// Scratch: whether each track was matched in the current update.
    track_taken: Vec<bool>,
    /// Scratch: whether each detection was matched in the current update.
    det_taken: Vec<bool>,
    /// Scratch: track ID assigned to each detection in the current update, if any.
    det_to_track: Vec<Option<u64>>,
}
20
21impl Tracker {
22 pub fn new(config: TrackerConfig) -> Result<Self, TrackError> {
24 config.validate()?;
25 Ok(Self {
26 config,
27 next_track_id: 1,
28 tracks: Vec::new(),
29 motion: Vec::new(),
30 pair_candidates: Vec::new(),
31 track_taken: Vec::new(),
32 det_taken: Vec::new(),
33 det_to_track: Vec::new(),
34 })
35 }
36
37 pub fn update(&mut self, detections: &[Detection]) -> Vec<TrackedDetection> {
42 let mut out = Vec::with_capacity(detections.len());
43 self.update_into(detections, &mut out);
44 out
45 }
46
47 pub fn update_into(&mut self, detections: &[Detection], out: &mut Vec<TrackedDetection>) {
52 debug_assert_eq!(self.tracks.len(), self.motion.len());
53
54 self.pair_candidates.clear();
55 for (track_idx, track) in self.tracks.iter().enumerate() {
56 let predicted = self.predict_bbox(track_idx, track);
57 for (det_idx, det) in detections.iter().enumerate() {
58 if track.class_id != det.class_id {
59 continue;
60 }
61 if let Some(match_score) = self.match_score(
62 track.missed_frames,
63 predicted,
64 det.bbox,
65 self.config.match_iou_threshold,
66 ) {
67 self.pair_candidates.push((match_score, track_idx, det_idx));
68 }
69 }
70 }
71 self.pair_candidates
72 .sort_by(|left, right| right.0.total_cmp(&left.0));
73
74 self.track_taken.clear();
75 self.track_taken.resize(self.tracks.len(), false);
76 self.det_taken.clear();
77 self.det_taken.resize(detections.len(), false);
78 self.det_to_track.clear();
79 self.det_to_track.resize(detections.len(), None);
80
81 for (_match_score, track_idx, det_idx) in self.pair_candidates.iter().copied() {
82 if self.track_taken[track_idx] || self.det_taken[det_idx] {
83 continue;
84 }
85 self.track_taken[track_idx] = true;
86 self.det_taken[det_idx] = true;
87 let det = detections[det_idx];
88 let track = &mut self.tracks[track_idx];
89 let previous_bbox = track.bbox;
90 track.bbox = det.bbox;
91 track.score = det.score;
92 track.class_id = det.class_id;
93 track.age += 1;
94 track.hits += 1;
95 track.missed_frames = 0;
96 update_motion_state(&mut self.motion[track_idx], previous_bbox, det.bbox);
97 self.det_to_track[det_idx] = Some(track.id);
98 }
99
100 for (idx, track) in self.tracks.iter_mut().enumerate() {
101 if !self.track_taken[idx] {
102 track.bbox = apply_motion(track.bbox, &self.motion[idx], 1.0);
103 track.age += 1;
104 track.missed_frames += 1;
105 }
106 }
107
108 let mut write = 0usize;
109 for read in 0..self.tracks.len() {
110 if self.tracks[read].missed_frames <= self.config.max_missed_frames {
111 if write != read {
112 self.tracks[write] = self.tracks[read];
113 self.motion[write] = self.motion[read];
114 }
115 write += 1;
116 }
117 }
118 self.tracks.truncate(write);
119 self.motion.truncate(write);
120
121 for (det_idx, det) in detections.iter().enumerate() {
122 if self.det_taken[det_idx] {
123 continue;
124 }
125 if self.tracks.len() >= self.config.max_tracks {
126 break;
127 }
128 let track_id = self.alloc_track_id();
129 self.tracks.push(Track {
130 id: track_id,
131 bbox: det.bbox,
132 score: det.score,
133 class_id: det.class_id,
134 age: 1,
135 hits: 1,
136 missed_frames: 0,
137 });
138 self.motion.push(MotionState::default());
139 self.det_to_track[det_idx] = Some(track_id);
140 }
141
142 out.clear();
143 if out.capacity() < detections.len() {
144 out.reserve(detections.len() - out.capacity());
145 }
146 for (det_idx, det) in detections.iter().enumerate() {
147 if let Some(track_id) = self.det_to_track[det_idx] {
148 out.push(TrackedDetection {
149 track_id,
150 detection: *det,
151 });
152 }
153 }
154 }
155
156 pub fn active_tracks(&self) -> &[Track] {
158 &self.tracks
159 }
160
161 pub fn count_by_class(&self, class_id: usize) -> usize {
163 self.tracks
164 .iter()
165 .filter(|track| track.class_id == class_id)
166 .count()
167 }
168
169 pub fn people_count(&self) -> usize {
171 self.count_by_class(CLASS_ID_PERSON)
172 }
173
174 fn alloc_track_id(&mut self) -> u64 {
175 let id = self.next_track_id;
176 self.next_track_id += 1;
177 id
178 }
179
180 fn predict_bbox(&self, track_idx: usize, track: &Track) -> BoundingBox {
181 if track.missed_frames == 0 {
182 return track.bbox;
183 }
184 apply_motion(track.bbox, &self.motion[track_idx], 1.0)
185 }
186
187 fn match_score(
188 &self,
189 missed_frames: u32,
190 predicted: BoundingBox,
191 detection: BoundingBox,
192 iou_threshold: f32,
193 ) -> Option<f32> {
194 let overlap = iou(predicted, detection);
195 if overlap >= iou_threshold {
196 return Some(overlap);
197 }
198
199 let center_distance = normalized_center_distance(predicted, detection);
200 let size_similarity = bbox_size_similarity(predicted, detection);
201 let proximity_score = 1.0 / (1.0 + center_distance);
202 let blended = 0.6 * overlap + 0.4 * proximity_score;
203
204 if missed_frames > 0 && center_distance <= 2.0 && size_similarity >= 0.5 && blended >= 0.35
205 {
206 Some(blended)
207 } else {
208 None
209 }
210 }
211}