//! `yscv_eval/detection.rs`
//!
//! Object-detection evaluation: per-frame greedy IoU matching,
//! precision/recall/F1, all-point interpolated average precision, and
//! COCO-style multi-threshold (AP@[.50:.95]) metrics.
1use yscv_detect::{BoundingBox, Detection, iou};
2
3use crate::EvalError;
4use crate::util::{harmonic_mean, safe_ratio, validate_iou_threshold, validate_score_threshold};
5
/// A ground-truth bounding box paired with its class label.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LabeledBox {
    /// Axis-aligned box geometry.
    pub bbox: BoundingBox,
    /// Class identifier; compared directly against `Detection::class_id`
    /// during matching, so both must use the same id space.
    pub class_id: usize,
}
11
/// Borrowed view of a single frame: its ground truth and its predictions.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectionFrame<'a> {
    /// Labeled ground-truth boxes for this frame.
    pub ground_truth: &'a [LabeledBox],
    /// Model predictions for this frame (scored, classed boxes).
    pub predictions: &'a [Detection],
}
17
/// Thresholds controlling prediction filtering and GT matching.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectionEvalConfig {
    /// Minimum IoU (inclusive) for a prediction to claim a ground-truth box.
    pub iou_threshold: f32,
    /// Predictions scoring below this value are dropped before matching.
    pub score_threshold: f32,
}
23
24impl Default for DetectionEvalConfig {
25    fn default() -> Self {
26        Self {
27            iou_threshold: 0.5,
28            score_threshold: 0.0,
29        }
30    }
31}
32
33impl DetectionEvalConfig {
34    pub fn validate(&self) -> Result<(), EvalError> {
35        validate_iou_threshold(self.iou_threshold)?;
36        validate_score_threshold(self.score_threshold)?;
37        Ok(())
38    }
39}
40
/// Aggregate detection metrics at a single IoU/score operating point.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectionMetrics {
    /// Predictions matched to a ground-truth box.
    pub true_positives: u64,
    /// Predictions that matched no ground-truth box.
    pub false_positives: u64,
    /// Ground-truth boxes no prediction claimed.
    pub false_negatives: u64,
    /// TP / (TP + FP); 0.0 when the denominator is zero (via `safe_ratio`).
    pub precision: f32,
    /// TP / (TP + FN); 0.0 when the denominator is zero (via `safe_ratio`).
    pub recall: f32,
    /// Harmonic mean of precision and recall.
    pub f1: f32,
    /// All-point interpolated AP, pooled across all frames.
    pub average_precision: f32,
}
51
/// Owned per-frame data; borrow it as a [`DetectionFrame`] via `as_view`.
#[derive(Debug, Clone, PartialEq)]
pub struct DetectionDatasetFrame {
    /// Owned ground-truth boxes for this frame.
    pub ground_truth: Vec<LabeledBox>,
    /// Owned predictions for this frame.
    pub predictions: Vec<Detection>,
}
57
58impl DetectionDatasetFrame {
59    pub fn as_view(&self) -> DetectionFrame<'_> {
60        DetectionFrame {
61            ground_truth: &self.ground_truth,
62            predictions: &self.predictions,
63        }
64    }
65}
66
67pub fn detection_frames_as_view(frames: &[DetectionDatasetFrame]) -> Vec<DetectionFrame<'_>> {
68    frames.iter().map(DetectionDatasetFrame::as_view).collect()
69}
70
71pub fn evaluate_detections_from_dataset(
72    frames: &[DetectionDatasetFrame],
73    config: DetectionEvalConfig,
74) -> Result<DetectionMetrics, EvalError> {
75    let borrowed = detection_frames_as_view(frames);
76    evaluate_detections(&borrowed, config)
77}
78
79pub fn evaluate_detections(
80    frames: &[DetectionFrame<'_>],
81    config: DetectionEvalConfig,
82) -> Result<DetectionMetrics, EvalError> {
83    config.validate()?;
84
85    let mut true_positives = 0u64;
86    let mut false_positives = 0u64;
87    let mut false_negatives = 0u64;
88
89    for frame in frames {
90        let mut predictions: Vec<Detection> = frame
91            .predictions
92            .iter()
93            .copied()
94            .filter(|prediction| prediction.score >= config.score_threshold)
95            .collect();
96        predictions.sort_by(|a, b| b.score.total_cmp(&a.score));
97
98        let mut gt_taken = vec![false; frame.ground_truth.len()];
99        for prediction in predictions {
100            if let Some(best_gt_idx) = best_gt_match(
101                prediction,
102                frame.ground_truth,
103                &gt_taken,
104                config.iou_threshold,
105            ) {
106                gt_taken[best_gt_idx] = true;
107                true_positives += 1;
108            } else {
109                false_positives += 1;
110            }
111        }
112
113        false_negatives += gt_taken.iter().filter(|matched| !**matched).count() as u64;
114    }
115
116    let precision = safe_ratio(true_positives, true_positives + false_positives);
117    let recall = safe_ratio(true_positives, true_positives + false_negatives);
118    let f1 = harmonic_mean(precision, recall);
119    let average_precision = average_precision(frames, config);
120
121    Ok(DetectionMetrics {
122        true_positives,
123        false_positives,
124        false_negatives,
125        precision,
126        recall,
127        f1,
128        average_precision,
129    })
130}
131
132fn best_gt_match(
133    prediction: Detection,
134    ground_truth: &[LabeledBox],
135    gt_taken: &[bool],
136    iou_threshold: f32,
137) -> Option<usize> {
138    let mut best_iou = iou_threshold;
139    let mut best_idx = None;
140
141    for (idx, gt) in ground_truth.iter().enumerate() {
142        if gt_taken[idx] || gt.class_id != prediction.class_id {
143            continue;
144        }
145        let overlap = iou(gt.bbox, prediction.bbox);
146        if overlap >= best_iou {
147            best_iou = overlap;
148            best_idx = Some(idx);
149        }
150    }
151    best_idx
152}
153
/// All-point interpolated average precision, pooled over every frame.
///
/// Predictions from all frames are ranked together by score, a running
/// precision/recall curve is accumulated, the precision envelope is made
/// monotonically non-increasing from the right, and the curve is
/// integrated over recall. Returns 0.0 when there is no ground truth or
/// no prediction survives the score threshold.
///
/// NOTE(review): `config` is assumed already validated — the only caller
/// in this file (`evaluate_detections`) validates before calling here.
fn average_precision(frames: &[DetectionFrame<'_>], config: DetectionEvalConfig) -> f32 {
    let total_ground_truth = frames
        .iter()
        .map(|frame| frame.ground_truth.len() as u64)
        .sum::<u64>();
    if total_ground_truth == 0 {
        // No positives exist, so recall is undefined; AP is defined as 0.
        return 0.0;
    }

    // Pool score-qualified predictions, tagged with their frame index so
    // the greedy matching below stays per-frame.
    let mut ranked_predictions = Vec::new();
    for (frame_idx, frame) in frames.iter().enumerate() {
        for prediction in frame.predictions {
            if prediction.score >= config.score_threshold {
                ranked_predictions.push((frame_idx, *prediction));
            }
        }
    }
    // Descending score; total_cmp imposes a total order, so NaN scores
    // cannot panic the sort.
    ranked_predictions.sort_by(|a, b| b.1.score.total_cmp(&a.1.score));

    if ranked_predictions.is_empty() {
        return 0.0;
    }

    // One "claimed" flag per ground-truth box, per frame.
    let mut gt_taken: Vec<Vec<bool>> = frames
        .iter()
        .map(|frame| vec![false; frame.ground_truth.len()])
        .collect();
    let mut precisions = Vec::with_capacity(ranked_predictions.len());
    let mut recalls = Vec::with_capacity(ranked_predictions.len());

    let mut true_positives = 0u64;
    let mut false_positives = 0u64;

    // Walk the ranked list, recording precision/recall after each prediction.
    for (frame_idx, prediction) in ranked_predictions {
        if let Some(best_gt_idx) = best_gt_match(
            prediction,
            frames[frame_idx].ground_truth,
            &gt_taken[frame_idx],
            config.iou_threshold,
        ) {
            gt_taken[frame_idx][best_gt_idx] = true;
            true_positives += 1;
        } else {
            false_positives += 1;
        }

        precisions.push(safe_ratio(true_positives, true_positives + false_positives));
        recalls.push(safe_ratio(true_positives, total_ground_truth));
    }

    // Pad the curve: recall spans [0, 1]; precision gets a 0 sentinel at
    // both ends (the leading 0 is overwritten by the max-scan below; the
    // trailing 0 zeroes the segment of recall the detector never reached).
    let mut monotonic_precisions = Vec::with_capacity(precisions.len() + 2);
    let mut padded_recalls = Vec::with_capacity(recalls.len() + 2);

    padded_recalls.push(0.0);
    padded_recalls.extend(recalls.iter().copied());
    padded_recalls.push(1.0);

    monotonic_precisions.push(0.0);
    monotonic_precisions.extend(precisions.iter().copied());
    monotonic_precisions.push(0.0);

    // Right-to-left running max makes the precision envelope non-increasing.
    for idx in (0..monotonic_precisions.len() - 1).rev() {
        monotonic_precisions[idx] = monotonic_precisions[idx].max(monotonic_precisions[idx + 1]);
    }

    // Riemann sum of interpolated precision over recall deltas.
    let mut ap = 0.0f32;
    for idx in 0..padded_recalls.len() - 1 {
        let recall_delta = padded_recalls[idx + 1] - padded_recalls[idx];
        if recall_delta > 0.0 {
            ap += recall_delta * monotonic_precisions[idx + 1];
        }
    }
    // Clamp guards against tiny float drift pushing AP outside [0, 1].
    ap.clamp(0.0, 1.0)
}
228
229// ---------------------------------------------------------------------------
230// COCO-style multi-threshold evaluation
231// ---------------------------------------------------------------------------
232
/// The COCO IoU sweep: 0.50..=0.95 in steps of 0.05 (10 thresholds).
const COCO_IOU_THRESHOLDS: [f32; 10] = [0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95];

/// COCO "small" object cut-off: area < 32² (presumably pixel² — depends
/// on the coordinate units of `BoundingBox`; confirm with callers).
const SMALL_AREA_MAX: f32 = 32.0 * 32.0;
/// COCO "medium" upper bound: area < 96²; at or above is "large".
const MEDIUM_AREA_MAX: f32 = 96.0 * 96.0;
237
/// COCO-style summary metrics produced by [`evaluate_detections_coco`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CocoMetrics {
    /// AP averaged over IoU thresholds 0.50..=0.95 (step 0.05).
    pub ap: f32,
    /// AP at IoU = 0.50.
    pub ap50: f32,
    /// AP at IoU = 0.75.
    pub ap75: f32,
    /// AP for small objects (area < 32²).
    pub ap_small: f32,
    /// AP for medium objects (32² <= area < 96²).
    pub ap_medium: f32,
    /// AP for large objects (area >= 96²).
    pub ap_large: f32,
    /// Average Recall: mean of max recall across IoU thresholds.
    pub ar: f32,
}
255
256fn box_area(b: &BoundingBox) -> f32 {
257    (b.x2 - b.x1) * (b.y2 - b.y1)
258}
259
260/// Filter frames so that only ground-truth boxes satisfying `pred` are kept.
261/// Predictions are kept unchanged (they are matched against the filtered GT).
262fn filter_gt_by<F>(frames: &[DetectionFrame<'_>], pred: F) -> Vec<DetectionDatasetFrame>
263where
264    F: Fn(&LabeledBox) -> bool,
265{
266    frames
267        .iter()
268        .map(|frame| {
269            let ground_truth: Vec<LabeledBox> = frame
270                .ground_truth
271                .iter()
272                .filter(|lb| pred(lb))
273                .copied()
274                .collect();
275            DetectionDatasetFrame {
276                ground_truth,
277                predictions: frame.predictions.to_vec(),
278            }
279        })
280        .collect()
281}
282
283/// Evaluate detections using COCO-style multi-threshold metrics.
284pub fn evaluate_detections_coco(
285    frames: &[DetectionFrame<'_>],
286    score_threshold: f32,
287) -> Result<CocoMetrics, EvalError> {
288    validate_score_threshold(score_threshold)?;
289
290    // Compute per-threshold AP and recall.
291    let mut aps = [0.0f32; 10];
292    let mut recalls = [0.0f32; 10];
293
294    for (i, &iou_thresh) in COCO_IOU_THRESHOLDS.iter().enumerate() {
295        let config = DetectionEvalConfig {
296            iou_threshold: iou_thresh,
297            score_threshold,
298        };
299        let m = evaluate_detections(frames, config)?;
300        aps[i] = m.average_precision;
301        recalls[i] = m.recall;
302    }
303
304    let ap = aps.iter().sum::<f32>() / aps.len() as f32;
305    let ap50 = aps[0]; // IoU 0.50
306    let ap75 = aps[5]; // IoU 0.75
307    let ar = recalls.iter().sum::<f32>() / recalls.len() as f32;
308
309    // Size-based AP (computed at all 10 thresholds, then averaged).
310    let ap_small = size_ap(frames, score_threshold, |a| a < SMALL_AREA_MAX)?;
311    let ap_medium = size_ap(frames, score_threshold, |a| {
312        (SMALL_AREA_MAX..MEDIUM_AREA_MAX).contains(&a)
313    })?;
314    let ap_large = size_ap(frames, score_threshold, |a| a >= MEDIUM_AREA_MAX)?;
315
316    Ok(CocoMetrics {
317        ap,
318        ap50,
319        ap75,
320        ap_small,
321        ap_medium,
322        ap_large,
323        ar,
324    })
325}
326
327fn size_ap<F>(
328    frames: &[DetectionFrame<'_>],
329    score_threshold: f32,
330    area_filter: F,
331) -> Result<f32, EvalError>
332where
333    F: Fn(f32) -> bool,
334{
335    let owned = filter_gt_by(frames, |lb| area_filter(box_area(&lb.bbox)));
336    let views = detection_frames_as_view(&owned);
337
338    let mut sum = 0.0f32;
339    for &iou_thresh in &COCO_IOU_THRESHOLDS {
340        let config = DetectionEvalConfig {
341            iou_threshold: iou_thresh,
342            score_threshold,
343        };
344        let m = evaluate_detections(&views, config)?;
345        sum += m.average_precision;
346    }
347    Ok(sum / COCO_IOU_THRESHOLDS.len() as f32)
348}