1use std::collections::HashMap;
2
3use yscv_detect::{BoundingBox, iou};
4use yscv_track::TrackedDetection;
5
6use crate::EvalError;
7use crate::util::{harmonic_mean, safe_ratio, validate_iou_threshold};
8
9#[derive(Debug, Clone, Copy, PartialEq)]
10pub struct GroundTruthTrack {
11 pub object_id: u64,
12 pub bbox: BoundingBox,
13 pub class_id: usize,
14}
15
16#[derive(Debug, Clone, Copy, PartialEq)]
17pub struct TrackingFrame<'a> {
18 pub ground_truth: &'a [GroundTruthTrack],
19 pub predictions: &'a [TrackedDetection],
20}
21
/// Parameters controlling how predictions are matched to ground truth.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TrackingEvalConfig {
    /// Minimum IoU a same-class (ground truth, prediction) pair must reach
    /// to be considered a match.
    pub iou_threshold: f32,
}

impl Default for TrackingEvalConfig {
    /// Uses an IoU threshold of `0.5`.
    fn default() -> Self {
        TrackingEvalConfig { iou_threshold: 0.5 }
    }
}
32
33impl TrackingEvalConfig {
34 pub fn validate(&self) -> Result<(), EvalError> {
35 validate_iou_threshold(self.iou_threshold)
36 }
37}
38
/// Aggregate tracking metrics produced by `evaluate_tracking`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TrackingMetrics {
    /// Number of ground-truth objects summed over all frames.
    pub total_ground_truth: u64,
    /// Number of matched (ground truth, prediction) pairs over all frames.
    pub matches: u64,
    /// Predictions left unmatched.
    pub false_positives: u64,
    /// Ground-truth objects left unmatched.
    pub false_negatives: u64,
    /// Times a ground-truth object was matched to a different track id than
    /// at its previous match.
    pub id_switches: u64,
    /// `matches / (matches + false_positives)`.
    pub precision: f32,
    /// `matches / (matches + false_negatives)`.
    pub recall: f32,
    /// Harmonic mean of `precision` and `recall`.
    pub f1: f32,
    /// `1 - (FN + FP + ID switches) / total_ground_truth`; can be negative.
    pub mota: f32,
    /// Mean IoU over all matches.
    pub motp: f32,
}
52
53#[derive(Debug, Clone, PartialEq)]
54pub struct TrackingDatasetFrame {
55 pub ground_truth: Vec<GroundTruthTrack>,
56 pub predictions: Vec<TrackedDetection>,
57}
58
59impl TrackingDatasetFrame {
60 pub fn as_view(&self) -> TrackingFrame<'_> {
61 TrackingFrame {
62 ground_truth: &self.ground_truth,
63 predictions: &self.predictions,
64 }
65 }
66}
67
68pub fn tracking_frames_as_view(frames: &[TrackingDatasetFrame]) -> Vec<TrackingFrame<'_>> {
69 frames.iter().map(TrackingDatasetFrame::as_view).collect()
70}
71
72pub fn evaluate_tracking_from_dataset(
73 frames: &[TrackingDatasetFrame],
74 config: TrackingEvalConfig,
75) -> Result<TrackingMetrics, EvalError> {
76 let borrowed = tracking_frames_as_view(frames);
77 evaluate_tracking(&borrowed, config)
78}
79
80fn greedy_iou_match(frame: &TrackingFrame<'_>, iou_threshold: f32) -> Vec<(usize, usize, f32)> {
82 let mut candidates = Vec::new();
83 for (gt_idx, gt) in frame.ground_truth.iter().enumerate() {
84 for (pred_idx, prediction) in frame.predictions.iter().enumerate() {
85 if gt.class_id != prediction.detection.class_id {
86 continue;
87 }
88 let overlap = iou(gt.bbox, prediction.detection.bbox);
89 if overlap >= iou_threshold {
90 candidates.push((overlap, gt_idx, pred_idx));
91 }
92 }
93 }
94 candidates.sort_by(|a, b| b.0.total_cmp(&a.0));
95
96 let mut gt_taken = vec![false; frame.ground_truth.len()];
97 let mut pred_taken = vec![false; frame.predictions.len()];
98 let mut matches = Vec::new();
99
100 for (overlap, gt_idx, pred_idx) in candidates {
101 if gt_taken[gt_idx] || pred_taken[pred_idx] {
102 continue;
103 }
104 gt_taken[gt_idx] = true;
105 pred_taken[pred_idx] = true;
106 matches.push((gt_idx, pred_idx, overlap));
107 }
108
109 matches
110}
111
112pub fn idf1(frames: &[TrackingFrame<'_>], config: TrackingEvalConfig) -> Result<f32, EvalError> {
114 config.validate()?;
115
116 let mut cooccurrence: HashMap<(u64, u64), u64> = HashMap::new();
118 let mut gt_appearances: HashMap<u64, u64> = HashMap::new();
119 let mut pred_appearances: HashMap<u64, u64> = HashMap::new();
120
121 for frame in frames {
122 for gt in frame.ground_truth {
123 *gt_appearances.entry(gt.object_id).or_insert(0) += 1;
124 }
125 for pred in frame.predictions {
126 *pred_appearances.entry(pred.track_id).or_insert(0) += 1;
127 }
128
129 let matches = greedy_iou_match(frame, config.iou_threshold);
130 for (gt_idx, pred_idx, _) in matches {
131 let gt_id = frame.ground_truth[gt_idx].object_id;
132 let pred_id = frame.predictions[pred_idx].track_id;
133 *cooccurrence.entry((gt_id, pred_id)).or_insert(0) += 1;
134 }
135 }
136
137 let total_gt: u64 = gt_appearances.values().sum();
138 let total_pred: u64 = pred_appearances.values().sum();
139
140 if total_gt == 0 && total_pred == 0 {
141 return Ok(0.0);
142 }
143
144 let mut best_for_gt: HashMap<u64, (u64, u64)> = HashMap::new(); for (&(gt_id, pred_id), &count) in &cooccurrence {
147 let entry = best_for_gt.entry(gt_id).or_insert((pred_id, 0));
148 if count > entry.1 {
149 *entry = (pred_id, count);
150 }
151 }
152
153 let idtp: u64 = best_for_gt.values().map(|(_, count)| count).sum();
154 let idfn = total_gt - idtp;
155 let idfp = total_pred - idtp;
156
157 let denom = 2 * idtp + idfp + idfn;
158 if denom == 0 {
159 return Ok(0.0);
160 }
161
162 Ok((2 * idtp) as f32 / denom as f32)
163}
164
165pub fn hota(frames: &[TrackingFrame<'_>], config: TrackingEvalConfig) -> Result<f32, EvalError> {
167 config.validate()?;
168
169 let mut all_matches: Vec<(u64, u64)> = Vec::new(); let mut total_tp = 0u64;
172 let mut total_fp = 0u64;
173 let mut total_fn = 0u64;
174
175 let mut pred_to_gt_per_frame: Vec<HashMap<u64, u64>> = Vec::new();
177 let mut gt_to_pred_per_frame: Vec<HashMap<u64, u64>> = Vec::new();
178
179 for frame in frames {
180 let matches = greedy_iou_match(frame, config.iou_threshold);
181 let tp = matches.len() as u64;
182 let fn_count = frame.ground_truth.len() as u64 - tp;
183 let fp_count = frame.predictions.len() as u64 - tp;
184
185 total_tp += tp;
186 total_fp += fp_count;
187 total_fn += fn_count;
188
189 let mut pred_to_gt = HashMap::new();
190 let mut gt_to_pred = HashMap::new();
191 for &(gt_idx, pred_idx, _) in &matches {
192 let gt_id = frame.ground_truth[gt_idx].object_id;
193 let pred_id = frame.predictions[pred_idx].track_id;
194 all_matches.push((gt_id, pred_id));
195 pred_to_gt.insert(pred_id, gt_id);
196 gt_to_pred.insert(gt_id, pred_id);
197 }
198 pred_to_gt_per_frame.push(pred_to_gt);
199 gt_to_pred_per_frame.push(gt_to_pred);
200 }
201
202 if total_tp == 0 {
203 return Ok(0.0);
204 }
205
206 let det_a = total_tp as f32 / (total_tp + total_fp + total_fn) as f32;
207
208 let mut ass_a_sum = 0.0f32;
210 let num_frames = frames.len();
211
212 for &(gt_id, pred_id) in &all_matches {
213 let mut tpa = 0u64;
214 let mut fpa = 0u64;
215 let mut fna = 0u64;
216
217 for f in 0..num_frames {
218 let pred_matched_gt = pred_to_gt_per_frame[f].get(&pred_id);
219 let gt_matched_pred = gt_to_pred_per_frame[f].get(>_id);
220
221 match (pred_matched_gt, gt_matched_pred) {
222 (Some(&matched_gt), Some(&matched_pred))
223 if matched_gt == gt_id && matched_pred == pred_id =>
224 {
225 tpa += 1;
226 }
227 _ => {
228 if let Some(&matched_gt) = pred_matched_gt
230 && matched_gt != gt_id
231 {
232 fpa += 1;
233 }
234 if let Some(&matched_pred) = gt_matched_pred
236 && matched_pred != pred_id
237 {
238 fna += 1;
239 }
240 }
241 }
242 }
243
244 let denom = tpa + fpa + fna;
245 if denom > 0 {
246 ass_a_sum += tpa as f32 / denom as f32;
247 }
248 }
249
250 let ass_a = ass_a_sum / all_matches.len() as f32;
251 Ok((det_a * ass_a).sqrt())
252}
253
254pub fn evaluate_tracking(
255 frames: &[TrackingFrame<'_>],
256 config: TrackingEvalConfig,
257) -> Result<TrackingMetrics, EvalError> {
258 config.validate()?;
259
260 let mut total_ground_truth = 0u64;
261 let mut matches = 0u64;
262 let mut false_positives = 0u64;
263 let mut false_negatives = 0u64;
264 let mut id_switches = 0u64;
265 let mut iou_sum = 0.0f32;
266 let mut last_assignment: HashMap<u64, u64> = HashMap::new();
267
268 for frame in frames {
269 total_ground_truth += frame.ground_truth.len() as u64;
270
271 let mut candidates = Vec::new();
272 for (gt_idx, gt) in frame.ground_truth.iter().enumerate() {
273 for (pred_idx, prediction) in frame.predictions.iter().enumerate() {
274 if gt.class_id != prediction.detection.class_id {
275 continue;
276 }
277 let overlap = iou(gt.bbox, prediction.detection.bbox);
278 if overlap >= config.iou_threshold {
279 candidates.push((overlap, gt_idx, pred_idx));
280 }
281 }
282 }
283 candidates.sort_by(|a, b| b.0.total_cmp(&a.0));
284
285 let mut gt_taken = vec![false; frame.ground_truth.len()];
286 let mut pred_taken = vec![false; frame.predictions.len()];
287
288 for (overlap, gt_idx, pred_idx) in candidates {
289 if gt_taken[gt_idx] || pred_taken[pred_idx] {
290 continue;
291 }
292
293 gt_taken[gt_idx] = true;
294 pred_taken[pred_idx] = true;
295 matches += 1;
296 iou_sum += overlap;
297
298 let gt_id = frame.ground_truth[gt_idx].object_id;
299 let pred_id = frame.predictions[pred_idx].track_id;
300 if let Some(previous_pred_id) = last_assignment.get(>_id)
301 && *previous_pred_id != pred_id
302 {
303 id_switches += 1;
304 }
305 last_assignment.insert(gt_id, pred_id);
306 }
307
308 false_negatives += gt_taken.iter().filter(|matched| !**matched).count() as u64;
309 false_positives += pred_taken.iter().filter(|matched| !**matched).count() as u64;
310 }
311
312 let precision = safe_ratio(matches, matches + false_positives);
313 let recall = safe_ratio(matches, matches + false_negatives);
314 let f1 = harmonic_mean(precision, recall);
315 let motp = if matches == 0 {
316 0.0
317 } else {
318 iou_sum / matches as f32
319 };
320 let mota = if total_ground_truth == 0 {
321 0.0
322 } else {
323 1.0 - ((false_negatives + false_positives + id_switches) as f32 / total_ground_truth as f32)
324 };
325
326 Ok(TrackingMetrics {
327 total_ground_truth,
328 matches,
329 false_positives,
330 false_negatives,
331 id_switches,
332 precision,
333 recall,
334 f1,
335 mota,
336 motp,
337 })
338}