// yscv_eval/tracking.rs

1use std::collections::HashMap;
2
3use yscv_detect::{BoundingBox, iou};
4use yscv_track::TrackedDetection;
5
6use crate::EvalError;
7use crate::util::{harmonic_mean, safe_ratio, validate_iou_threshold};
8
9#[derive(Debug, Clone, Copy, PartialEq)]
10pub struct GroundTruthTrack {
11    pub object_id: u64,
12    pub bbox: BoundingBox,
13    pub class_id: usize,
14}
15
16#[derive(Debug, Clone, Copy, PartialEq)]
17pub struct TrackingFrame<'a> {
18    pub ground_truth: &'a [GroundTruthTrack],
19    pub predictions: &'a [TrackedDetection],
20}
21
/// Settings shared by the tracking metrics in this module.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TrackingEvalConfig {
    /// Minimum IoU (inclusive) for a ground-truth/prediction pair to be a match candidate.
    pub iou_threshold: f32,
}
26
27impl Default for TrackingEvalConfig {
28    fn default() -> Self {
29        Self { iou_threshold: 0.5 }
30    }
31}
32
33impl TrackingEvalConfig {
34    pub fn validate(&self) -> Result<(), EvalError> {
35        validate_iou_threshold(self.iou_threshold)
36    }
37}
38
/// Aggregate result of `evaluate_tracking` over a sequence of frames.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TrackingMetrics {
    /// Number of ground-truth instances summed over all frames.
    pub total_ground_truth: u64,
    /// Number of ground-truth/prediction pairs accepted by the per-frame matching.
    pub matches: u64,
    /// Predictions left without a ground-truth partner.
    pub false_positives: u64,
    /// Ground-truth instances left without a prediction partner.
    pub false_negatives: u64,
    /// Matches whose track id differs from the one last matched to the same object.
    pub id_switches: u64,
    /// `matches / (matches + false_positives)`.
    pub precision: f32,
    /// `matches / (matches + false_negatives)`.
    pub recall: f32,
    /// Harmonic mean of `precision` and `recall`.
    pub f1: f32,
    /// `1 - (false_negatives + false_positives + id_switches) / total_ground_truth`,
    /// or `0.0` when there is no ground truth.
    pub mota: f32,
    /// Mean IoU over all matches, or `0.0` when there are none.
    pub motp: f32,
}
52
53#[derive(Debug, Clone, PartialEq)]
54pub struct TrackingDatasetFrame {
55    pub ground_truth: Vec<GroundTruthTrack>,
56    pub predictions: Vec<TrackedDetection>,
57}
58
59impl TrackingDatasetFrame {
60    pub fn as_view(&self) -> TrackingFrame<'_> {
61        TrackingFrame {
62            ground_truth: &self.ground_truth,
63            predictions: &self.predictions,
64        }
65    }
66}
67
68pub fn tracking_frames_as_view(frames: &[TrackingDatasetFrame]) -> Vec<TrackingFrame<'_>> {
69    frames.iter().map(TrackingDatasetFrame::as_view).collect()
70}
71
72pub fn evaluate_tracking_from_dataset(
73    frames: &[TrackingDatasetFrame],
74    config: TrackingEvalConfig,
75) -> Result<TrackingMetrics, EvalError> {
76    let borrowed = tracking_frames_as_view(frames);
77    evaluate_tracking(&borrowed, config)
78}
79
80/// Greedy IoU matching for a single frame. Returns vec of (gt_index, pred_index, iou).
81fn greedy_iou_match(frame: &TrackingFrame<'_>, iou_threshold: f32) -> Vec<(usize, usize, f32)> {
82    let mut candidates = Vec::new();
83    for (gt_idx, gt) in frame.ground_truth.iter().enumerate() {
84        for (pred_idx, prediction) in frame.predictions.iter().enumerate() {
85            if gt.class_id != prediction.detection.class_id {
86                continue;
87            }
88            let overlap = iou(gt.bbox, prediction.detection.bbox);
89            if overlap >= iou_threshold {
90                candidates.push((overlap, gt_idx, pred_idx));
91            }
92        }
93    }
94    candidates.sort_by(|a, b| b.0.total_cmp(&a.0));
95
96    let mut gt_taken = vec![false; frame.ground_truth.len()];
97    let mut pred_taken = vec![false; frame.predictions.len()];
98    let mut matches = Vec::new();
99
100    for (overlap, gt_idx, pred_idx) in candidates {
101        if gt_taken[gt_idx] || pred_taken[pred_idx] {
102            continue;
103        }
104        gt_taken[gt_idx] = true;
105        pred_taken[pred_idx] = true;
106        matches.push((gt_idx, pred_idx, overlap));
107    }
108
109    matches
110}
111
112/// Identity F1 score: measures how well predicted IDs match GT IDs across frames.
113pub fn idf1(frames: &[TrackingFrame<'_>], config: TrackingEvalConfig) -> Result<f32, EvalError> {
114    config.validate()?;
115
116    // Count co-occurrences of (gt_object_id, pred_track_id) pairs
117    let mut cooccurrence: HashMap<(u64, u64), u64> = HashMap::new();
118    let mut gt_appearances: HashMap<u64, u64> = HashMap::new();
119    let mut pred_appearances: HashMap<u64, u64> = HashMap::new();
120
121    for frame in frames {
122        for gt in frame.ground_truth {
123            *gt_appearances.entry(gt.object_id).or_insert(0) += 1;
124        }
125        for pred in frame.predictions {
126            *pred_appearances.entry(pred.track_id).or_insert(0) += 1;
127        }
128
129        let matches = greedy_iou_match(frame, config.iou_threshold);
130        for (gt_idx, pred_idx, _) in matches {
131            let gt_id = frame.ground_truth[gt_idx].object_id;
132            let pred_id = frame.predictions[pred_idx].track_id;
133            *cooccurrence.entry((gt_id, pred_id)).or_insert(0) += 1;
134        }
135    }
136
137    let total_gt: u64 = gt_appearances.values().sum();
138    let total_pred: u64 = pred_appearances.values().sum();
139
140    if total_gt == 0 && total_pred == 0 {
141        return Ok(0.0);
142    }
143
144    // For each GT object, find the best-matching predicted track (most co-occurrences)
145    let mut best_for_gt: HashMap<u64, (u64, u64)> = HashMap::new(); // gt_id -> (pred_id, count)
146    for (&(gt_id, pred_id), &count) in &cooccurrence {
147        let entry = best_for_gt.entry(gt_id).or_insert((pred_id, 0));
148        if count > entry.1 {
149            *entry = (pred_id, count);
150        }
151    }
152
153    let idtp: u64 = best_for_gt.values().map(|(_, count)| count).sum();
154    let idfn = total_gt - idtp;
155    let idfp = total_pred - idtp;
156
157    let denom = 2 * idtp + idfp + idfn;
158    if denom == 0 {
159        return Ok(0.0);
160    }
161
162    Ok((2 * idtp) as f32 / denom as f32)
163}
164
165/// Higher Order Tracking Accuracy.
166pub fn hota(frames: &[TrackingFrame<'_>], config: TrackingEvalConfig) -> Result<f32, EvalError> {
167    config.validate()?;
168
169    // Collect all per-frame TP matches and detection counts
170    let mut all_matches: Vec<(u64, u64)> = Vec::new(); // (gt_object_id, pred_track_id) per TP
171    let mut total_tp = 0u64;
172    let mut total_fp = 0u64;
173    let mut total_fn = 0u64;
174
175    // Track which pred_id maps to which gt_id per frame (for association computation)
176    let mut pred_to_gt_per_frame: Vec<HashMap<u64, u64>> = Vec::new();
177    let mut gt_to_pred_per_frame: Vec<HashMap<u64, u64>> = Vec::new();
178
179    for frame in frames {
180        let matches = greedy_iou_match(frame, config.iou_threshold);
181        let tp = matches.len() as u64;
182        let fn_count = frame.ground_truth.len() as u64 - tp;
183        let fp_count = frame.predictions.len() as u64 - tp;
184
185        total_tp += tp;
186        total_fp += fp_count;
187        total_fn += fn_count;
188
189        let mut pred_to_gt = HashMap::new();
190        let mut gt_to_pred = HashMap::new();
191        for &(gt_idx, pred_idx, _) in &matches {
192            let gt_id = frame.ground_truth[gt_idx].object_id;
193            let pred_id = frame.predictions[pred_idx].track_id;
194            all_matches.push((gt_id, pred_id));
195            pred_to_gt.insert(pred_id, gt_id);
196            gt_to_pred.insert(gt_id, pred_id);
197        }
198        pred_to_gt_per_frame.push(pred_to_gt);
199        gt_to_pred_per_frame.push(gt_to_pred);
200    }
201
202    if total_tp == 0 {
203        return Ok(0.0);
204    }
205
206    let det_a = total_tp as f32 / (total_tp + total_fp + total_fn) as f32;
207
208    // Compute association accuracy for each TP match
209    let mut ass_a_sum = 0.0f32;
210    let num_frames = frames.len();
211
212    for &(gt_id, pred_id) in &all_matches {
213        let mut tpa = 0u64;
214        let mut fpa = 0u64;
215        let mut fna = 0u64;
216
217        for f in 0..num_frames {
218            let pred_matched_gt = pred_to_gt_per_frame[f].get(&pred_id);
219            let gt_matched_pred = gt_to_pred_per_frame[f].get(&gt_id);
220
221            match (pred_matched_gt, gt_matched_pred) {
222                (Some(&matched_gt), Some(&matched_pred))
223                    if matched_gt == gt_id && matched_pred == pred_id =>
224                {
225                    tpa += 1;
226                }
227                _ => {
228                    // FPA: pred_id matched a different gt (or any gt)
229                    if let Some(&matched_gt) = pred_matched_gt
230                        && matched_gt != gt_id
231                    {
232                        fpa += 1;
233                    }
234                    // FNA: gt_id matched a different pred
235                    if let Some(&matched_pred) = gt_matched_pred
236                        && matched_pred != pred_id
237                    {
238                        fna += 1;
239                    }
240                }
241            }
242        }
243
244        let denom = tpa + fpa + fna;
245        if denom > 0 {
246            ass_a_sum += tpa as f32 / denom as f32;
247        }
248    }
249
250    let ass_a = ass_a_sum / all_matches.len() as f32;
251    Ok((det_a * ass_a).sqrt())
252}
253
254pub fn evaluate_tracking(
255    frames: &[TrackingFrame<'_>],
256    config: TrackingEvalConfig,
257) -> Result<TrackingMetrics, EvalError> {
258    config.validate()?;
259
260    let mut total_ground_truth = 0u64;
261    let mut matches = 0u64;
262    let mut false_positives = 0u64;
263    let mut false_negatives = 0u64;
264    let mut id_switches = 0u64;
265    let mut iou_sum = 0.0f32;
266    let mut last_assignment: HashMap<u64, u64> = HashMap::new();
267
268    for frame in frames {
269        total_ground_truth += frame.ground_truth.len() as u64;
270
271        let mut candidates = Vec::new();
272        for (gt_idx, gt) in frame.ground_truth.iter().enumerate() {
273            for (pred_idx, prediction) in frame.predictions.iter().enumerate() {
274                if gt.class_id != prediction.detection.class_id {
275                    continue;
276                }
277                let overlap = iou(gt.bbox, prediction.detection.bbox);
278                if overlap >= config.iou_threshold {
279                    candidates.push((overlap, gt_idx, pred_idx));
280                }
281            }
282        }
283        candidates.sort_by(|a, b| b.0.total_cmp(&a.0));
284
285        let mut gt_taken = vec![false; frame.ground_truth.len()];
286        let mut pred_taken = vec![false; frame.predictions.len()];
287
288        for (overlap, gt_idx, pred_idx) in candidates {
289            if gt_taken[gt_idx] || pred_taken[pred_idx] {
290                continue;
291            }
292
293            gt_taken[gt_idx] = true;
294            pred_taken[pred_idx] = true;
295            matches += 1;
296            iou_sum += overlap;
297
298            let gt_id = frame.ground_truth[gt_idx].object_id;
299            let pred_id = frame.predictions[pred_idx].track_id;
300            if let Some(previous_pred_id) = last_assignment.get(&gt_id)
301                && *previous_pred_id != pred_id
302            {
303                id_switches += 1;
304            }
305            last_assignment.insert(gt_id, pred_id);
306        }
307
308        false_negatives += gt_taken.iter().filter(|matched| !**matched).count() as u64;
309        false_positives += pred_taken.iter().filter(|matched| !**matched).count() as u64;
310    }
311
312    let precision = safe_ratio(matches, matches + false_positives);
313    let recall = safe_ratio(matches, matches + false_negatives);
314    let f1 = harmonic_mean(precision, recall);
315    let motp = if matches == 0 {
316        0.0
317    } else {
318        iou_sum / matches as f32
319    };
320    let mota = if total_ground_truth == 0 {
321        0.0
322    } else {
323        1.0 - ((false_negatives + false_positives + id_switches) as f32 / total_ground_truth as f32)
324    };
325
326    Ok(TrackingMetrics {
327        total_ground_truth,
328        matches,
329        false_positives,
330        false_negatives,
331        id_switches,
332        precision,
333        recall,
334        f1,
335        mota,
336        motp,
337    })
338}