Skip to main content

vernier_core/
accumulate.rs

1//! Per-image evaluation → precision/recall/scores arrays.
2//!
//! Mirrors `pycocotools.cocoeval.COCOeval.accumulate` (cocoeval.py
//! lines 315-420). Inputs come from the upstream matching engine
//! ([`crate::matching`]), packaged as one [`PerImageEval`] per
//! `(category, areaRange, image)` cell. Outputs are the
//! `(T, R, K, A, M)` precision and `(T, K, A, M)` recall tensors that
//! the summarizer slices into the final 12 stats, plus a
//! `(T, R, K, A, M)` scores tensor.
9//!
10//! ## Quirk dispositions
11//!
12//! - **A1** (`strict`): the merged-stream sort across one `(K, A, M)`
13//!   slice is also a stable mergesort on `-score`, mirroring
14//!   `np.argsort(kind='mergesort')` on the concatenated stream.
15//! - **C1** (`strict`): recall lookup uses `searchsorted(rc, t,
16//!   side='left')` semantics — the leftmost cumulative-recall index
17//!   with `rc[i] >= t`.
18//! - **C2** (`strict`): right-to-left running max on the precision
19//!   array enforces the monotonic precision envelope before
20//!   integration.
21//! - **C3** (`corrected` implementation, `strict` outputs): the
22//!   `try/except` around `dtScoresSorted[pi]` becomes an explicit
23//!   bounds check (`pi < n_d`); past the curve we leave `q[ri]` and
24//!   `ss[ri]` at `0.0`, matching the silent-skip pycocotools does in
25//!   the `except: pass` branch.
26//! - **C4** (`strict`): "AR" stored in `recall` is terminal cumulative
27//!   recall (the last value of `rc`), not an integral of the
28//!   precision/recall curve.
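//! - **C5** (`strict`): `(K, A, M)` cells with no per-image entries or
//!   no non-ignore GTs are skipped and leave `precision`/`recall`/`scores`
//!   at the `-1` sentinel; the summarizer filters those before averaging.
//!   A cell with non-ignore GTs but no detections is *not* skipped:
//!   pycocotools records recall `0` and zero-filled precision/scores for
//!   it, so it counts as AP `0` in the summary.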
32//! - **C7** (`strict`): TP and FP cumsums skip DTs whose `dt_ignore`
33//!   flag is set — both B6 (matched-to-ignore) and B7 (out-of-area
34//!   unmatched) are folded into `dt_ignore` upstream.
35//! - **C8** (`aligned`): precision denominator uses
36//!   [`crate::parity::PARITY_EPS`] (= `f64::EPSILON`), bit-equal to
37//!   `np.spacing(1)`.
38//! - **L1, L2** (`strict`): `iou_thresholds` and `recall_thresholds`
39//!   come from [`crate::parity::iou_thresholds`] / [`crate::parity::recall_thresholds`]
40//!   and are linspace-built; the accumulator does not assume their
41//!   values, only their lengths.
42//!
//! Quirks **B7** (out-of-area unmatched DT → `dt_ignore`) and **B6**
//! (DT matched to ignore-GT → `dt_ignore`) are applied upstream, not
//! here. The matching engine sets B6, and the orchestrator that builds
//! [`PerImageEval`] folds B7 in alongside it. This module only reads the
//! resulting `dt_ignore` flags.
47
48use ndarray::{Array2, Array4, Array5, Axis};
49
50use crate::error::EvalError;
51use crate::parity::{argsort_score_desc, ParityMode, PARITY_EPS};
52
53/// Per `(image, category, areaRange)` slice of evaluation data, in the
54/// shape the accumulator consumes.
55///
56/// Built by the orchestrator from a `MatchResult` (private to the
57/// [`crate::matching`] module) plus the
58/// per-DT areas needed to apply quirk **B7**. Field orders mirror the
59/// matching engine's *sorted* internal orders: `dt_*` rows are
60/// score-desc (stable mergesort), `gt_ignore` is ignore-asc.
61#[derive(Debug, Clone)]
62pub struct PerImageEval {
63    /// Detection scores in sorted-DT order. Length `D`.
64    pub dt_scores: Vec<f64>,
65    /// Per-`(T, D)` match indicator. `true` when the DT matched any GT
66    /// at this threshold (regardless of whether the matched GT is an
67    /// ignore-GT — that distinction is carried by `dt_ignore`).
68    pub dt_matched: Array2<bool>,
69    /// Per-`(T, D)` ignore flag. Caller must fold in both B6 (matched
70    /// to ignore-GT) and B7 (out-of-area unmatched) before constructing
71    /// this struct; the accumulator treats it as authoritative.
72    pub dt_ignore: Array2<bool>,
73    /// Per-GT ignore flag in sorted-GT order. Length `G`.
74    pub gt_ignore: Vec<bool>,
75}
76
/// Shape of the evaluation grid handed to [`accumulate`].
///
/// `eval_imgs` must hold exactly `n_categories * n_area_ranges *
/// n_images` entries, flattened as `eval_imgs[k * A * I + a * I + i]`.
/// This is the same flat indexing pycocotools uses for `evalImgs`.
#[derive(Debug, Clone, Copy)]
pub struct AccumulateParams<'p> {
    /// The `T` IoU thresholds. [`crate::parity::iou_thresholds`] builds
    /// the canonical 10-point COCO ladder.
    pub iou_thresholds: &'p [f64],
    /// The `R` recall thresholds the precision curve is sampled at
    /// (typically 101). See [`crate::parity::recall_thresholds`].
    pub recall_thresholds: &'p [f64],
    /// The `M` per-image detection caps (pycocotools default:
    /// `[1, 10, 100]`). Run the matching engine with the *largest* cap.
    /// The accumulator truncates each image's detections to the smaller
    /// caps itself (`[..max_det]`).
    ///
    /// Must be ascending (quirk **A2**, strict):
    /// - pycocotools rewrites `p.maxDets = sorted(p.maxDets)` at
    ///   `cocoeval.py:137`, so its M-axis always runs smallest to largest.
    /// - The summarizer maps `AR_1 / AR_10 / AR_100` to M-slots by
    ///   position, so an unsorted ladder such as `[100, 1, 10]` would
    ///   silently relabel those slots.
    /// - FFI-boundary callers normalize the ladder with [`sort_max_dets`].
    pub max_dets: &'p [usize],
    /// Category count `K` (`1` when `useCats == 0`).
    pub n_categories: usize,
    /// Area-range count `A` (COCO default 4: all / small / medium / large).
    pub n_area_ranges: usize,
    /// Image count `I`.
    pub n_images: usize,
}
111
/// Normalize a `max_dets` ladder to ascending order, in place.
///
/// Mirrors the normalization pycocotools performs in `COCOeval.evaluate`
/// (`cocoeval.py:137`), before `accumulate` ever runs:
///
/// ```python
/// p.maxDets = sorted(p.maxDets)
/// ```
///
/// Quirk **A2** (strict). The accumulator lays out its M-axis in the
/// order of the ladder it receives. The summarizer's `AR_1 / AR_10 /
/// AR_100` slot mapping is positional. Sorting at the param-construction
/// boundary therefore keeps user input order from silently permuting the
/// final stat vector.
///
/// An unstable sort is used on purpose. Equal `usize` values are
/// indistinguishable, so the result is identical to Python's stable
/// `sorted` without the stable sort's scratch allocation.
pub fn sort_max_dets(max_dets: &mut [usize]) {
    max_dets.sort_unstable();
}
131
132/// Output tensors produced by [`accumulate`].
133///
134/// Cells absent from the dataset (no DTs, or no non-ignore GTs) carry
135/// `-1.0` per quirk **C5**. The summarizer filters these before
136/// averaging; downstream code that consumes the tensors directly must
137/// honor the same convention.
138#[derive(Debug, Clone)]
139pub struct Accumulated {
140    /// Shape `(T, R, K, A, M)`. Right-monotonic precision interpolated
141    /// at every recall threshold.
142    pub precision: Array5<f64>,
143    /// Shape `(T, K, A, M)`. Terminal cumulative recall (quirk **C4**).
144    pub recall: Array4<f64>,
145    /// Shape `(T, R, K, A, M)`. Detection score at the recall threshold
146    /// where each precision sample was taken.
147    pub scores: Array5<f64>,
148}
149
150/// Accumulate per-image evaluation results into precision / recall /
151/// scores tensors.
152///
153/// The flat `eval_imgs` slice must be laid out as `[k][a][i]` (K-major,
154/// then A, then I) — `eval_imgs.len() == K * A * I`.
155///
156/// # Errors
157///
158/// Returns [`EvalError::DimensionMismatch`] if `eval_imgs.len()` does
159/// not equal `K * A * I`, or if any per-image array shapes disagree
160/// with the declared `T` (IoU-threshold count).
161pub fn accumulate(
162    eval_imgs: &[Option<Box<PerImageEval>>],
163    p: AccumulateParams<'_>,
164    _parity_mode: ParityMode,
165) -> Result<Accumulated, EvalError> {
166    let n_t = p.iou_thresholds.len();
167    let n_r = p.recall_thresholds.len();
168    let n_k = p.n_categories;
169    let n_a = p.n_area_ranges;
170    let n_m = p.max_dets.len();
171    let n_i = p.n_images;
172
173    let expected = n_k * n_a * n_i;
174    if eval_imgs.len() != expected {
175        return Err(EvalError::DimensionMismatch {
176            detail: format!(
177                "eval_imgs len {} != n_categories({}) * n_area_ranges({}) * n_images({}) = {}",
178                eval_imgs.len(),
179                n_k,
180                n_a,
181                n_i,
182                expected
183            ),
184        });
185    }
186
187    for cell in eval_imgs.iter().flatten() {
188        if cell.dt_matched.shape() != cell.dt_ignore.shape() {
189            return Err(EvalError::DimensionMismatch {
190                detail: format!(
191                    "PerImageEval.dt_matched {:?} != dt_ignore {:?}",
192                    cell.dt_matched.shape(),
193                    cell.dt_ignore.shape()
194                ),
195            });
196        }
197        if cell.dt_matched.nrows() != n_t {
198            return Err(EvalError::DimensionMismatch {
199                detail: format!(
200                    "PerImageEval row count {} != iou_thresholds len {}",
201                    cell.dt_matched.nrows(),
202                    n_t
203                ),
204            });
205        }
206        if cell.dt_matched.ncols() != cell.dt_scores.len() {
207            return Err(EvalError::DimensionMismatch {
208                detail: format!(
209                    "PerImageEval.dt_matched cols {} != dt_scores len {}",
210                    cell.dt_matched.ncols(),
211                    cell.dt_scores.len()
212                ),
213            });
214        }
215    }
216
217    let mut precision = Array5::<f64>::from_elem((n_t, n_r, n_k, n_a, n_m), -1.0);
218    let mut recall = Array4::<f64>::from_elem((n_t, n_k, n_a, n_m), -1.0);
219    let mut scores = Array5::<f64>::from_elem((n_t, n_r, n_k, n_a, n_m), -1.0);
220
221    for k in 0..n_k {
222        let nk = k * n_a * n_i;
223        for a in 0..n_a {
224            let na = a * n_i;
225            let cells: Vec<&PerImageEval> = (0..n_i)
226                .filter_map(|i| eval_imgs[nk + na + i].as_deref())
227                .collect();
228            if cells.is_empty() {
229                continue;
230            }
231            let npig: usize = cells
232                .iter()
233                .map(|e| e.gt_ignore.iter().filter(|&&ig| !ig).count())
234                .sum();
235            if npig == 0 {
236                continue;
237            }
238
239            for (m, &max_det) in p.max_dets.iter().enumerate() {
240                accumulate_cell(
241                    &cells,
242                    max_det,
243                    npig,
244                    n_t,
245                    p.recall_thresholds,
246                    k,
247                    a,
248                    m,
249                    &mut precision,
250                    &mut recall,
251                    &mut scores,
252                );
253            }
254        }
255    }
256
257    Ok(Accumulated {
258        precision,
259        recall,
260        scores,
261    })
262}
263
264#[allow(clippy::too_many_arguments)]
265fn accumulate_cell(
266    cells: &[&PerImageEval],
267    max_det: usize,
268    npig: usize,
269    n_t: usize,
270    recall_thresholds: &[f64],
271    k: usize,
272    a: usize,
273    m: usize,
274    precision: &mut Array5<f64>,
275    recall: &mut Array4<f64>,
276    scores: &mut Array5<f64>,
277) {
278    let mut takes: Vec<usize> = Vec::with_capacity(cells.len());
279    let mut total = 0usize;
280    for cell in cells {
281        let take = cell.dt_scores.len().min(max_det);
282        takes.push(take);
283        total += take;
284    }
285    let mut all_scores: Vec<f64> = Vec::with_capacity(total);
286    for (cell, &take) in cells.iter().zip(&takes) {
287        all_scores.extend_from_slice(&cell.dt_scores[..take]);
288    }
289
290    let n_d = all_scores.len();
291    if n_d == 0 {
292        // No detections, but npig > 0 — recall collapses to 0; precision
293        // and scores keep the -1 sentinel.
294        for t in 0..n_t {
295            recall[(t, k, a, m)] = 0.0;
296        }
297        return;
298    }
299
300    let perm = argsort_score_desc(&all_scores);
301
302    let npig_f = npig as f64;
303    let mut rc = vec![0.0_f64; n_d];
304    let mut pr = vec![0.0_f64; n_d];
305    let mut dtm = vec![false; n_d];
306    let mut dtg = vec![false; n_d];
307
308    for t in 0..n_t {
309        let mut cursor = 0;
310        for (cell, &take) in cells.iter().zip(&takes) {
311            let m_row = cell.dt_matched.row(t);
312            let g_row = cell.dt_ignore.row(t);
313            for d in 0..take {
314                dtm[cursor] = m_row[d];
315                dtg[cursor] = g_row[d];
316                cursor += 1;
317            }
318        }
319
320        // C7: cumulative TP/FP exclude ignore-tagged DTs.
321        let mut tp = 0.0_f64;
322        let mut fp = 0.0_f64;
323        for (out_idx, &src_idx) in perm.iter().enumerate() {
324            if !dtg[src_idx] {
325                if dtm[src_idx] {
326                    tp += 1.0;
327                } else {
328                    fp += 1.0;
329                }
330            }
331            rc[out_idx] = tp / npig_f;
332            pr[out_idx] = tp / (tp + fp + PARITY_EPS);
333        }
334
335        // C4: terminal cumulative recall.
336        recall[(t, k, a, m)] = rc[n_d - 1];
337
338        // C2: right-to-left running max on precision (envelope).
339        for j in (1..n_d).rev() {
340            if pr[j] > pr[j - 1] {
341                pr[j - 1] = pr[j];
342            }
343        }
344
345        // C1 + C3: searchsorted-left + bounds-check. Past the curve,
346        // slots are filled with 0.0 — overwriting the -1 sentinel so the
347        // summarizer's `s > -1` filter keeps them.
348        let mut p_lane = precision
349            .index_axis_mut(Axis(0), t)
350            .index_axis_move(Axis(1), k)
351            .index_axis_move(Axis(1), a)
352            .index_axis_move(Axis(1), m);
353        let mut s_lane = scores
354            .index_axis_mut(Axis(0), t)
355            .index_axis_move(Axis(1), k)
356            .index_axis_move(Axis(1), a)
357            .index_axis_move(Axis(1), m);
358        for (ri, &target) in recall_thresholds.iter().enumerate() {
359            let pi = rc.partition_point(|&v| v < target);
360            if pi < n_d {
361                p_lane[ri] = pr[pi];
362                s_lane[ri] = all_scores[perm[pi]];
363            } else {
364                p_lane[ri] = 0.0;
365                s_lane[ri] = 0.0;
366            }
367        }
368    }
369}
370
371#[cfg(test)]
372mod tests {
373    use super::*;
374    use ndarray::array;
375
376    fn one_threshold_eval(
377        scores: Vec<f64>,
378        matched: Vec<bool>,
379        ignore: Vec<bool>,
380        gt_ignore: Vec<bool>,
381    ) -> PerImageEval {
382        let n = scores.len();
383        let dt_matched =
384            Array2::from_shape_vec((1, n), matched).expect("dt_matched shape mismatch");
385        let dt_ignore = Array2::from_shape_vec((1, n), ignore).expect("dt_ignore shape mismatch");
386        PerImageEval {
387            dt_scores: scores,
388            dt_matched,
389            dt_ignore,
390            gt_ignore,
391        }
392    }
393
394    fn params<'p>(
395        iou: &'p [f64],
396        rec: &'p [f64],
397        max_dets: &'p [usize],
398        n_images: usize,
399    ) -> AccumulateParams<'p> {
400        AccumulateParams {
401            iou_thresholds: iou,
402            recall_thresholds: rec,
403            max_dets,
404            n_categories: 1,
405            n_area_ranges: 1,
406            n_images,
407        }
408    }
409
410    #[test]
411    fn empty_grid_returns_all_sentinel() {
412        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 0);
413        let out = accumulate(&[], p, ParityMode::Strict).unwrap();
414        assert!(out.precision.iter().all(|&v| v == -1.0));
415        assert!(out.recall.iter().all(|&v| v == -1.0));
416    }
417
418    #[test]
419    fn no_dt_with_real_gt_yields_zero_recall_and_sentinel_precision() {
420        // C5: precision stays at -1 sentinel; recall is 0 for every t.
421        let cell = PerImageEval {
422            dt_scores: vec![],
423            dt_matched: Array2::<bool>::default((2, 0)),
424            dt_ignore: Array2::<bool>::default((2, 0)),
425            gt_ignore: vec![false],
426        };
427        let p = params(&[0.5, 0.75], &[0.0, 0.5, 1.0], &[100], 1);
428        let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
429        assert_eq!(out.recall[(0, 0, 0, 0)], 0.0);
430        assert_eq!(out.recall[(1, 0, 0, 0)], 0.0);
431        // No precision write happened — every cell still -1.
432        for ri in 0..3 {
433            assert_eq!(out.precision[(0, ri, 0, 0, 0)], -1.0);
434            assert_eq!(out.precision[(1, ri, 0, 0, 0)], -1.0);
435        }
436    }
437
438    #[test]
439    fn cell_with_only_ignore_gts_skips_entirely() {
440        // npig == 0 short-circuit: outputs stay at -1 (no recall write).
441        let cell = one_threshold_eval(vec![0.9], vec![true], vec![true], vec![true]);
442        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
443        let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
444        assert_eq!(out.recall[(0, 0, 0, 0)], -1.0);
445        assert_eq!(out.precision[(0, 0, 0, 0, 0)], -1.0);
446    }
447
448    #[test]
449    fn perfect_match_yields_ap_one_and_ar_one() {
450        // Single DT matches the only real GT → both precision and
451        // recall are 1.0 across every recall threshold.
452        let cell = one_threshold_eval(vec![0.9], vec![true], vec![false], vec![false]);
453        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
454        let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
455
456        assert_eq!(out.recall[(0, 0, 0, 0)], 1.0);
457        for ri in 0..3 {
458            // Precision is `tp / (tp + fp + eps)` — 1 / (1 + 0 + eps) ≈ 1.
459            let pr = out.precision[(0, ri, 0, 0, 0)];
460            assert!((pr - 1.0).abs() < 1e-12, "precision[{ri}] = {pr}");
461            assert_eq!(out.scores[(0, ri, 0, 0, 0)], 0.9);
462        }
463    }
464
465    #[test]
466    fn lone_fp_yields_zero_recall_zero_precision() {
467        // One unmatched detection, one real unmatched GT → recall 0,
468        // precision 0 across all recall thresholds. The score column
469        // gets a value only at recall=0 (where the curve does exist);
470        // recall thresholds past the end of the curve fall through to
471        // pycocotools' silent-skip branch, leaving 0.0.
472        let cell = one_threshold_eval(vec![0.9], vec![false], vec![false], vec![false]);
473        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
474        let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
475        assert_eq!(out.recall[(0, 0, 0, 0)], 0.0);
476        for ri in 0..3 {
477            // 0 / (0 + 1 + eps) ≈ 0 → envelope keeps it at 0.
478            assert!(out.precision[(0, ri, 0, 0, 0)].abs() < 1e-12);
479        }
480        // recall threshold 0.0 lands on the lone curve point (rc[0] =
481        // 0.0); 0.5 and 1.0 are past the end → score sentinel 0.0.
482        assert_eq!(out.scores[(0, 0, 0, 0, 0)], 0.9);
483        assert_eq!(out.scores[(0, 1, 0, 0, 0)], 0.0);
484        assert_eq!(out.scores[(0, 2, 0, 0, 0)], 0.0);
485    }
486
487    #[test]
488    fn ignored_dt_does_not_count_as_fp() {
489        // C7: an ignore-tagged DT is invisible to both TP and FP cumsums.
490        // Setup: one real GT (matched by DT 0), one DT 1 that misses but
491        // is ignore-tagged (e.g. out-of-area unmatched). FP must not
492        // appear in the curve.
493        let cell = one_threshold_eval(
494            vec![0.9, 0.8],
495            vec![true, false],
496            vec![false, true],
497            vec![false],
498        );
499        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
500        let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
501
502        // tp=1 fp=0 → precision ≈ 1 everywhere on the curve.
503        for ri in 0..3 {
504            let pr = out.precision[(0, ri, 0, 0, 0)];
505            assert!((pr - 1.0).abs() < 1e-12, "precision[{ri}] = {pr}");
506        }
507        assert_eq!(out.recall[(0, 0, 0, 0)], 1.0);
508    }
509
510    #[test]
511    fn precision_envelope_runs_right_to_left() {
512        // C2: pre-envelope precision dips. Curve: TP, FP, TP → precisions
513        // 1.0, 0.5, 0.667. After right-to-left max: 1.0, 0.667, 0.667.
514        // Recall thresholds 0.0 and 0.5 (rc = [0.5, 0.5, 1.0]) sample
515        // index 0; threshold 1.0 samples index 2.
516        let cell = one_threshold_eval(
517            vec![0.9, 0.8, 0.7],
518            vec![true, false, true],
519            vec![false, false, false],
520            vec![false, false],
521        );
522        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
523        let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
524
525        // recall thresholds 0.0 and 0.5 both fall on the first rc cell
526        // where rc[0] = 0.5 (TP at j=0 → 1/2). Envelope makes pr[0]=1.0.
527        assert!((out.precision[(0, 0, 0, 0, 0)] - 1.0).abs() < 1e-12);
528        assert!((out.precision[(0, 1, 0, 0, 0)] - 1.0).abs() < 1e-12);
529        // recall threshold 1.0 samples j=2: pr[2] = 2/3.
530        assert!((out.precision[(0, 2, 0, 0, 0)] - 2.0 / 3.0).abs() < 1e-12);
531    }
532
533    #[test]
534    fn partition_point_matches_numpy_searchsorted_left() {
535        // Pinning the stdlib semantics so a future swap (e.g., to a
536        // SIMD search) keeps `np.searchsorted(..., side='left')` parity.
537        let haystack = [0.1, 0.3, 0.3, 0.7];
538        let lookup = |t: f64| haystack.partition_point(|&v| v < t);
539        assert_eq!(lookup(0.0), 0);
540        assert_eq!(lookup(0.3), 1); // leftmost equal
541        assert_eq!(lookup(0.5), 3);
542        assert_eq!(lookup(1.0), 4); // past end
543    }
544
545    #[test]
546    fn merged_sort_breaks_ties_by_input_order() {
547        // A1 over the merged stream: two images with one DT each at
548        // score 0.7. With stable sort, image-0 DT comes first.
549        let img0 = one_threshold_eval(vec![0.7], vec![true], vec![false], vec![false]);
550        let img1 = one_threshold_eval(vec![0.7], vec![false], vec![false], vec![false]);
551        // grid: K=1, A=1, I=2 → eval_imgs[0..2] is the (k=0, a=0) row.
552        let grid = vec![Some(Box::new(img0)), Some(Box::new(img1))];
553        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 2);
554        let out = accumulate(&grid, p, ParityMode::Strict).unwrap();
555
556        // tp=1, fp=1 → final pr = 0.5; rc = [0.5, 0.5]. With envelope
557        // (no monotonicity adjustment needed because pr[1] < pr[0]),
558        // recThr 0.0 and 0.5 both sample index 0 (pr ≈ 1.0), recThr 1.0
559        // is past the end → 0.0.
560        assert!((out.precision[(0, 0, 0, 0, 0)] - 1.0).abs() < 1e-12);
561        assert!((out.precision[(0, 1, 0, 0, 0)] - 1.0).abs() < 1e-12);
562        assert_eq!(out.precision[(0, 2, 0, 0, 0)], 0.0);
563    }
564
565    #[test]
566    fn max_det_truncation_drops_low_score_dts_per_image() {
567        // Per-image max_det=1: only the top-scoring DT survives, even
568        // though more were emitted. With only the FP at score 0.95
569        // surviving, AP must collapse.
570        let cell = one_threshold_eval(
571            vec![0.95, 0.9],
572            vec![false, true], // FP first, TP second
573            vec![false, false],
574            vec![false],
575        );
576        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[1], 1);
577        let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
578        // Only FP survived → tp=0, fp=1, precision ≈ 0 everywhere.
579        for ri in 0..3 {
580            assert!(out.precision[(0, ri, 0, 0, 0)].abs() < 1e-12);
581        }
582        assert_eq!(out.recall[(0, 0, 0, 0)], 0.0);
583    }
584
585    #[test]
586    fn dimension_mismatch_on_grid_size_is_typed_error() {
587        let p = params(&[0.5], &[0.0], &[100], 5);
588        // Grid claims K*A*I = 1*1*5 = 5 cells; we pass 2 → error.
589        let err = accumulate(&[None, None], p, ParityMode::Strict).unwrap_err();
590        match err {
591            EvalError::DimensionMismatch { detail } => {
592                assert!(detail.contains("eval_imgs"));
593            }
594            other => panic!("expected DimensionMismatch, got {other:?}"),
595        }
596    }
597
598    #[test]
599    fn dimension_mismatch_on_per_image_t_is_typed_error() {
600        // Per-image dt_matched has 2 rows, params declare 3 IoU
601        // thresholds → mismatch reported.
602        let cell = PerImageEval {
603            dt_scores: vec![0.9],
604            dt_matched: array![[true], [true]],
605            dt_ignore: array![[false], [false]],
606            gt_ignore: vec![false],
607        };
608        let p = params(&[0.5, 0.75, 0.9], &[0.0], &[100], 1);
609        let err = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap_err();
610        assert!(matches!(err, EvalError::DimensionMismatch { .. }));
611    }
612
613    #[test]
614    fn reaccumulate_with_different_area_range_count_is_typed_error() {
615        // A3: re-accumulating an `eval_imgs` grid built for one A-axis
616        // size against an `AccumulateParams` with a different
617        // `n_area_ranges` must surface DimensionMismatch — not silently
618        // produce wrong outputs by re-slicing the flat buffer at the new
619        // pitch. Build a 4-area-range grid (the COCO default), then try
620        // to accumulate it as if it were a 3-area-range grid.
621        let n_i = 1;
622        let n_a_built = 4;
623        let n_k = 1;
624        let cell = one_threshold_eval(vec![0.9], vec![true], vec![false], vec![false]);
625        // Only the first (k=0, a=0, i=0) slot carries data; remaining
626        // slots are None as they would be for an image with no GTs/DTs in
627        // those buckets.
628        let mut eval_imgs: Vec<Option<Box<PerImageEval>>> = vec![None; n_k * n_a_built * n_i];
629        eval_imgs[0] = Some(Box::new(cell));
630
631        // Mismatched params: claim the grid has 3 area ranges. Expected
632        // grid size becomes 1*3*1 = 3, but we pass 4 cells → typed error.
633        let mut bad = params(&[0.5], &[0.0, 0.5, 1.0], &[100], n_i);
634        bad.n_area_ranges = 3;
635        let err = accumulate(&eval_imgs, bad, ParityMode::Strict).unwrap_err();
636        match err {
637            EvalError::DimensionMismatch { detail } => {
638                assert!(detail.contains("eval_imgs"), "msg: {detail}");
639                assert!(detail.contains("n_area_ranges(3)"), "msg: {detail}");
640            }
641            other => panic!("expected DimensionMismatch, got {other:?}"),
642        }
643    }
644
645    #[test]
646    fn vectorized_inner_sweep_matches_naive_reference() {
647        // C6: the inner recall-threshold sweep is vectorized via
648        // partition_point + an in-place right-to-left envelope. Pin it
649        // against a naive reference that mirrors pycocotools' Python
650        // `for ri, pi in enumerate(inds): q[ri] = pr[pi]` line by line.
651        //
652        // Three hand-crafted PR curves cover the edge cases:
653        //  - monotonic-decreasing precision (no envelope work);
654        //  - non-monotonic precision (envelope rewrites multiple cells);
655        //  - all-1.0 precision with the recall curve ending at 0.5 so
656        //    half the recall thresholds fall past the curve (C3 path).
657        //
658        // Only the precision lane is compared — both implementations
659        // share the same recall-index lookup, so the score lane would
660        // trivially agree.
661        let recall_thresholds: Vec<f64> = (0..=10).map(|i| (i as f64) / 10.0).collect();
662
663        // Naive reference: explicit right-to-left running max + linear
664        // searchsorted-left scan.
665        fn naive_sweep(rc: &[f64], pr: &[f64], rec_thr: &[f64]) -> Vec<f64> {
666            let n = pr.len();
667            let mut env = pr.to_vec();
668            for j in (1..n).rev() {
669                if env[j] > env[j - 1] {
670                    env[j - 1] = env[j];
671                }
672            }
673            let mut q = vec![0.0_f64; rec_thr.len()];
674            for (ri, &target) in rec_thr.iter().enumerate() {
675                let mut pi = n;
676                for (j, &r) in rc.iter().enumerate() {
677                    if r >= target {
678                        pi = j;
679                        break;
680                    }
681                }
682                if pi < n {
683                    q[ri] = env[pi];
684                }
685            }
686            q
687        }
688
689        // Vectorized reference: same shape as `accumulate_cell`'s inner
690        // sweep, callable on hand-crafted curves without rebuilding the
691        // whole `(T, R, K, A, M)` tensor. Drift between this body and
692        // the production sweep is what the test exists to catch.
693        fn vectorized_sweep(rc: &[f64], pr: &[f64], rec_thr: &[f64]) -> Vec<f64> {
694            let n = pr.len();
695            let mut env = pr.to_vec();
696            for j in (1..n).rev() {
697                if env[j] > env[j - 1] {
698                    env[j - 1] = env[j];
699                }
700            }
701            let mut q = vec![0.0_f64; rec_thr.len()];
702            for (ri, &target) in rec_thr.iter().enumerate() {
703                let pi = rc.partition_point(|&v| v < target);
704                if pi < n {
705                    q[ri] = env[pi];
706                }
707            }
708            q
709        }
710
711        let curves: &[(&[f64], &[f64])] = &[
712            // Monotonic-decreasing precision; recall reaches 1.0.
713            (&[0.1, 0.3, 0.5, 0.7, 1.0], &[1.0, 0.9, 0.7, 0.5, 0.3]),
714            // Non-monotonic precision: envelope rewrites cells 1 and 3.
715            (&[0.2, 0.4, 0.6, 0.8, 1.0], &[1.0, 0.4, 0.6, 0.2, 0.5]),
716            // All-1.0 precision; recall caps at 0.5 → recall thresholds
717            // > 0.5 fall past the curve (C3 silent-skip path → 0.0).
718            (&[0.1, 0.2, 0.3, 0.4, 0.5], &[1.0, 1.0, 1.0, 1.0, 1.0]),
719        ];
720
721        for (i, (rc, pr)) in curves.iter().enumerate() {
722            let q_naive = naive_sweep(rc, pr, &recall_thresholds);
723            let q_vec = vectorized_sweep(rc, pr, &recall_thresholds);
724            assert_eq!(q_naive.len(), q_vec.len(), "curve {i}");
725            for (ri, (a, b)) in q_naive.iter().zip(q_vec.iter()).enumerate() {
726                assert_eq!(
727                    a.to_bits(),
728                    b.to_bits(),
729                    "curve {i}, recall threshold index {ri}: naive={a}, vec={b}"
730                );
731            }
732        }
733    }
734
735    #[test]
736    fn sort_max_dets_normalizes_ascending() {
737        // Quirk A2: pycocotools' `cocoeval.py:137` does
738        // `p.maxDets = sorted(p.maxDets)` — `sort_max_dets` is the
739        // mirror at the param-construction boundary.
740        let mut ladder = vec![100usize, 1, 10];
741        sort_max_dets(&mut ladder);
742        assert_eq!(ladder, vec![1, 10, 100]);
743    }
744
745    #[test]
746    fn sort_max_dets_is_idempotent_on_sorted_input() {
747        let mut ladder = vec![1usize, 10, 100];
748        sort_max_dets(&mut ladder);
749        assert_eq!(ladder, vec![1, 10, 100]);
750    }
751
752    #[test]
753    fn sort_max_dets_handles_duplicates_and_singletons() {
754        let mut singleton = vec![100usize];
755        sort_max_dets(&mut singleton);
756        assert_eq!(singleton, vec![100]);
757
758        let mut empty: Vec<usize> = Vec::new();
759        sort_max_dets(&mut empty);
760        assert!(empty.is_empty());
761
762        let mut dups = vec![10usize, 1, 10, 1, 100];
763        sort_max_dets(&mut dups);
764        assert_eq!(dups, vec![1, 1, 10, 10, 100]);
765    }
766
767    #[test]
768    fn permuted_ladder_after_sort_matches_canonical_order() {
769        // End-to-end: feeding `[100, 1, 10]` after `sort_max_dets`
770        // produces a `(T, R, K, A, M)` accumulator whose M-axis is
771        // identical to the one built from the canonical `[1, 10, 100]`.
772        // Without the sort, the M-axis slots would be swapped and the
773        // summarizer's positional `AR_1 / AR_10 / AR_100` mapping would
774        // bind to the wrong threshold.
775        let cell = one_threshold_eval(
776            vec![0.9, 0.8, 0.7],
777            vec![true, true, false],
778            vec![false, false, false],
779            vec![false, false, false],
780        );
781        let iou = [0.5];
782        let rec = [0.0, 0.5, 1.0];
783
784        let canonical = vec![1usize, 10, 100];
785        let canonical_acc = accumulate(
786            &[Some(Box::new(cell.clone()))],
787            params(&iou, &rec, &canonical, 1),
788            ParityMode::Strict,
789        )
790        .unwrap();
791
792        let mut permuted = vec![100usize, 1, 10];
793        sort_max_dets(&mut permuted);
794        assert_eq!(permuted, canonical);
795        let permuted_acc = accumulate(
796            &[Some(Box::new(cell))],
797            params(&iou, &rec, &permuted, 1),
798            ParityMode::Strict,
799        )
800        .unwrap();
801
802        assert_eq!(canonical_acc.precision, permuted_acc.precision);
803        assert_eq!(canonical_acc.recall, permuted_acc.recall);
804        assert_eq!(canonical_acc.scores, permuted_acc.scores);
805    }
806}