vernier_core/accumulate.rs
1//! Per-image evaluation → precision/recall/scores arrays.
2//!
3//! Mirrors `pycocotools.cocoeval.COCOeval.accumulate` (cocoeval.py
4//! lines 315-420). Inputs come from the upstream matching engine
5//! (the [`crate::matching`] engine) packaged as one [`PerImageEval`] per
6//! `(category, areaRange, image)` cell; outputs are the
7//! `(T, R, K, A, M)` precision and `(T, K, A, M)` recall tensors that
8//! the summarizer slices into the final 12 stats.
9//!
10//! ## Quirk dispositions
11//!
12//! - **A1** (`strict`): the merged-stream sort across one `(K, A, M)`
13//! slice is also a stable mergesort on `-score`, mirroring
14//! `np.argsort(kind='mergesort')` on the concatenated stream.
15//! - **C1** (`strict`): recall lookup uses `searchsorted(rc, t,
16//! side='left')` semantics — the leftmost cumulative-recall index
17//! with `rc[i] >= t`.
18//! - **C2** (`strict`): right-to-left running max on the precision
19//! array enforces the monotonic precision envelope before
20//! integration.
21//! - **C3** (`corrected` implementation, `strict` outputs): the
22//! `try/except` around `dtScoresSorted[pi]` becomes an explicit
23//! bounds check (`pi < n_d`); past the curve we leave `q[ri]` and
24//! `ss[ri]` at `0.0`, matching the silent-skip pycocotools does in
25//! the `except: pass` branch.
26//! - **C4** (`strict`): "AR" stored in `recall` is terminal cumulative
27//! recall (the last value of `rc`), not an integral of the
28//! precision/recall curve.
//! - **C5** (`strict`): `(K, A, M)` cells with no detections or no
//!   non-ignore GTs leave `precision`/`recall`/`scores` at the `-1`
//!   sentinel; the summarizer filters those before averaging.
//!   NOTE(review): pycocotools only `continue`s on an empty image list
//!   or `npig == 0`; for a cell with real GTs but zero detections its
//!   zero-initialized `q`/`ss` survive the `except: pass` and 0.0 is
//!   stored in precision/scores — confirm against cocoeval.py before
//!   relying on the sentinel for that case.
32//! - **C7** (`strict`): TP and FP cumsums skip DTs whose `dt_ignore`
33//! flag is set — both B6 (matched-to-ignore) and B7 (out-of-area
34//! unmatched) are folded into `dt_ignore` upstream.
35//! - **C8** (`aligned`): precision denominator uses
36//! [`crate::parity::PARITY_EPS`] (= `f64::EPSILON`), bit-equal to
37//! `np.spacing(1)`.
38//! - **L1, L2** (`strict`): `iou_thresholds` and `recall_thresholds`
39//! come from [`crate::parity::iou_thresholds`] / [`crate::parity::recall_thresholds`]
40//! and are linspace-built; the accumulator does not assume their
41//! values, only their lengths.
42//!
43//! Quirks **B7** (out-of-area unmatched DT → `dt_ignore`) and **B6**
44//! (DT matched to ignore-GT → `dt_ignore`) are inputs here, not
45//! responsibilities. The orchestrator that builds [`PerImageEval`]
46//! folds B7 in alongside the matching engine's B6.
47
48use ndarray::{Array2, Array4, Array5, Axis};
49
50use crate::error::EvalError;
51use crate::parity::{argsort_score_desc, ParityMode, PARITY_EPS};
52
/// Per `(image, category, areaRange)` slice of evaluation data, in the
/// shape the accumulator consumes.
///
/// Built by the orchestrator from a `MatchResult` (private to the
/// [`crate::matching`] module) plus the
/// per-DT areas needed to apply quirk **B7**. Field orders mirror the
/// matching engine's *sorted* internal orders: `dt_*` rows are
/// score-desc (stable mergesort), `gt_ignore` is ignore-asc.
///
/// Shape invariants (validated up front by [`accumulate`]): `dt_matched`
/// and `dt_ignore` share one `(T, D)` shape, `T` equals the declared
/// IoU-threshold count, and `D == dt_scores.len()`.
#[derive(Debug, Clone)]
pub struct PerImageEval {
    /// Detection scores in sorted-DT order. Length `D`.
    pub dt_scores: Vec<f64>,
    /// Per-`(T, D)` match indicator. `true` when the DT matched any GT
    /// at this threshold (regardless of whether the matched GT is an
    /// ignore-GT — that distinction is carried by `dt_ignore`).
    pub dt_matched: Array2<bool>,
    /// Per-`(T, D)` ignore flag. Caller must fold in both B6 (matched
    /// to ignore-GT) and B7 (out-of-area unmatched) before constructing
    /// this struct; the accumulator treats it as authoritative.
    pub dt_ignore: Array2<bool>,
    /// Per-GT ignore flag in sorted-GT order. Length `G`. Downstream
    /// only consumes the count of `false` entries (`npig`).
    pub gt_ignore: Vec<bool>,
}
76
/// Inputs to [`accumulate`] that describe the evaluation grid.
///
/// `eval_imgs.len()` must equal `n_categories * n_area_ranges *
/// n_images`, with the layout `eval_imgs[k * A * I + a * I + i]`
/// matching pycocotools' flat indexing of `evalImgs`.
///
/// The slice lengths define the tensor axes: `T = iou_thresholds.len()`,
/// `R = recall_thresholds.len()`, `M = max_dets.len()`.
#[derive(Debug, Clone, Copy)]
pub struct AccumulateParams<'p> {
    /// IoU thresholds, length `T`. Use [`crate::parity::iou_thresholds`] for
    /// the canonical 10-point COCO ladder.
    pub iou_thresholds: &'p [f64],
    /// Recall integration thresholds, length `R` (typically 101). Use
    /// [`crate::parity::recall_thresholds`].
    pub recall_thresholds: &'p [f64],
    /// Per-image maxDet caps, length `M`. Pycocotools defaults to
    /// `[1, 10, 100]`. The matching engine should be invoked with the
    /// *largest* of these — the accumulator slices to smaller caps via
    /// `[..max_det]`.
    ///
    /// Must be sorted ascending (quirk **A2** — strict). Pycocotools
    /// silently overwrites `p.maxDets = sorted(p.maxDets)` at
    /// `cocoeval.py:137`, so the M-axis is always laid out
    /// smallest-to-largest. The summarizer's `AR_1 / AR_10 / AR_100`
    /// slot mapping depends on this ordering — passing `[100, 1, 10]`
    /// without sorting would silently swap the slot semantics. Callers
    /// at the FFI boundary use [`sort_max_dets`] to enforce this.
    pub max_dets: &'p [usize],
    /// Number of categories `K` (or `1` when `useCats == 0`).
    pub n_categories: usize,
    /// Number of area ranges `A` (COCO defaults to 4: all/small/medium/
    /// large).
    pub n_area_ranges: usize,
    /// Number of images `I`.
    pub n_images: usize,
}
111
/// Normalize a `max_dets` ladder to ascending order, in place.
///
/// Mirror of the opening line of
/// `pycocotools.cocoeval.COCOeval.accumulate` (`cocoeval.py:137`):
///
/// ```python
/// p.maxDets = sorted(p.maxDets)
/// ```
///
/// Quirk **A2** (strict). The accumulator lays its M-axis out in
/// whatever order the ladder arrives, and the summarizer's
/// `AR_1 / AR_10 / AR_100` slot mapping is positional — normalizing at
/// the param-construction boundary keeps user input order from
/// silently permuting the final stat vector. The sort is stable, like
/// Python's `sorted` (stability is moot for plain `usize` keys, but
/// the contract is stated for parity).
pub fn sort_max_dets(max_dets: &mut [usize]) {
    // Stable ascending sort — identical outcome to Python's `sorted`.
    max_dets.sort_by(|lhs, rhs| lhs.cmp(rhs));
}
131
/// Output tensors produced by [`accumulate`].
///
/// Axis legend: `T` = IoU thresholds, `R` = recall thresholds,
/// `K` = categories, `A` = area ranges, `M` = maxDet caps.
///
/// Cells absent from the dataset (no DTs, or no non-ignore GTs) carry
/// `-1.0` per quirk **C5**. The summarizer filters these before
/// averaging; downstream code that consumes the tensors directly must
/// honor the same convention.
#[derive(Debug, Clone)]
pub struct Accumulated {
    /// Shape `(T, R, K, A, M)`. Right-monotonic precision interpolated
    /// at every recall threshold.
    pub precision: Array5<f64>,
    /// Shape `(T, K, A, M)`. Terminal cumulative recall (quirk **C4**).
    pub recall: Array4<f64>,
    /// Shape `(T, R, K, A, M)`. Detection score at the recall threshold
    /// where each precision sample was taken.
    pub scores: Array5<f64>,
}
149
150/// Accumulate per-image evaluation results into precision / recall /
151/// scores tensors.
152///
153/// The flat `eval_imgs` slice must be laid out as `[k][a][i]` (K-major,
154/// then A, then I) — `eval_imgs.len() == K * A * I`.
155///
156/// # Errors
157///
158/// Returns [`EvalError::DimensionMismatch`] if `eval_imgs.len()` does
159/// not equal `K * A * I`, or if any per-image array shapes disagree
160/// with the declared `T` (IoU-threshold count).
161pub fn accumulate(
162 eval_imgs: &[Option<Box<PerImageEval>>],
163 p: AccumulateParams<'_>,
164 _parity_mode: ParityMode,
165) -> Result<Accumulated, EvalError> {
166 let n_t = p.iou_thresholds.len();
167 let n_r = p.recall_thresholds.len();
168 let n_k = p.n_categories;
169 let n_a = p.n_area_ranges;
170 let n_m = p.max_dets.len();
171 let n_i = p.n_images;
172
173 let expected = n_k * n_a * n_i;
174 if eval_imgs.len() != expected {
175 return Err(EvalError::DimensionMismatch {
176 detail: format!(
177 "eval_imgs len {} != n_categories({}) * n_area_ranges({}) * n_images({}) = {}",
178 eval_imgs.len(),
179 n_k,
180 n_a,
181 n_i,
182 expected
183 ),
184 });
185 }
186
187 for cell in eval_imgs.iter().flatten() {
188 if cell.dt_matched.shape() != cell.dt_ignore.shape() {
189 return Err(EvalError::DimensionMismatch {
190 detail: format!(
191 "PerImageEval.dt_matched {:?} != dt_ignore {:?}",
192 cell.dt_matched.shape(),
193 cell.dt_ignore.shape()
194 ),
195 });
196 }
197 if cell.dt_matched.nrows() != n_t {
198 return Err(EvalError::DimensionMismatch {
199 detail: format!(
200 "PerImageEval row count {} != iou_thresholds len {}",
201 cell.dt_matched.nrows(),
202 n_t
203 ),
204 });
205 }
206 if cell.dt_matched.ncols() != cell.dt_scores.len() {
207 return Err(EvalError::DimensionMismatch {
208 detail: format!(
209 "PerImageEval.dt_matched cols {} != dt_scores len {}",
210 cell.dt_matched.ncols(),
211 cell.dt_scores.len()
212 ),
213 });
214 }
215 }
216
217 let mut precision = Array5::<f64>::from_elem((n_t, n_r, n_k, n_a, n_m), -1.0);
218 let mut recall = Array4::<f64>::from_elem((n_t, n_k, n_a, n_m), -1.0);
219 let mut scores = Array5::<f64>::from_elem((n_t, n_r, n_k, n_a, n_m), -1.0);
220
221 for k in 0..n_k {
222 let nk = k * n_a * n_i;
223 for a in 0..n_a {
224 let na = a * n_i;
225 let cells: Vec<&PerImageEval> = (0..n_i)
226 .filter_map(|i| eval_imgs[nk + na + i].as_deref())
227 .collect();
228 if cells.is_empty() {
229 continue;
230 }
231 let npig: usize = cells
232 .iter()
233 .map(|e| e.gt_ignore.iter().filter(|&&ig| !ig).count())
234 .sum();
235 if npig == 0 {
236 continue;
237 }
238
239 for (m, &max_det) in p.max_dets.iter().enumerate() {
240 accumulate_cell(
241 &cells,
242 max_det,
243 npig,
244 n_t,
245 p.recall_thresholds,
246 k,
247 a,
248 m,
249 &mut precision,
250 &mut recall,
251 &mut scores,
252 );
253 }
254 }
255 }
256
257 Ok(Accumulated {
258 precision,
259 recall,
260 scores,
261 })
262}
263
/// Core of [`accumulate`] for one `(K, A, M)` cell: merges the
/// per-image DT streams (each capped at `max_det`), builds the
/// cumulative TP/FP → recall/precision curves per IoU threshold, and
/// writes the `(t, :, k, a, m)` lanes of the three output tensors.
///
/// Callers guarantee `npig > 0` and that every cell has `n_t` rows.
#[allow(clippy::too_many_arguments)]
fn accumulate_cell(
    cells: &[&PerImageEval],
    max_det: usize,
    npig: usize,
    n_t: usize,
    recall_thresholds: &[f64],
    k: usize,
    a: usize,
    m: usize,
    precision: &mut Array5<f64>,
    recall: &mut Array4<f64>,
    scores: &mut Array5<f64>,
) {
    // Per-image cap first (mirrors `e['dtScores'][0:maxDet]`): the cap
    // applies to each image's stream, not the merged stream.
    let mut takes: Vec<usize> = Vec::with_capacity(cells.len());
    let mut total = 0usize;
    for cell in cells {
        let take = cell.dt_scores.len().min(max_det);
        takes.push(take);
        total += take;
    }
    let mut all_scores: Vec<f64> = Vec::with_capacity(total);
    for (cell, &take) in cells.iter().zip(&takes) {
        all_scores.extend_from_slice(&cell.dt_scores[..take]);
    }

    let n_d = all_scores.len();
    if n_d == 0 {
        // No detections, but npig > 0 — recall collapses to 0; precision
        // and scores keep the -1 sentinel.
        //
        // NOTE(review): pycocotools appears to store 0.0 here instead —
        // its `q`/`ss` arrays are zero-initialized and the IndexError
        // from `pr[pi]` on an empty curve lands in `except: pass`, so
        // the zeros are written to precision/scores. Since the
        // summarizer filters on `> -1`, the sentinel EXCLUDES such
        // cells from the AP mean while pycocotools' zeros are INCLUDED.
        // Confirm against cocoeval.py before relying on strict parity
        // for images with GTs but zero detections.
        for t in 0..n_t {
            recall[(t, k, a, m)] = 0.0;
        }
        return;
    }

    // A1: stable mergesort-equivalent argsort on -score over the merged
    // stream; ties keep concatenation (image) order.
    let perm = argsort_score_desc(&all_scores);

    let npig_f = npig as f64;
    // Scratch buffers, all length n_d, reused across IoU thresholds.
    let mut rc = vec![0.0_f64; n_d];
    let mut pr = vec![0.0_f64; n_d];
    let mut dtm = vec![false; n_d];
    let mut dtg = vec![false; n_d];

    for t in 0..n_t {
        // Gather this threshold's match/ignore flags in concatenation
        // order; `perm` re-indexes them into score-desc order below.
        let mut cursor = 0;
        for (cell, &take) in cells.iter().zip(&takes) {
            let m_row = cell.dt_matched.row(t);
            let g_row = cell.dt_ignore.row(t);
            for d in 0..take {
                dtm[cursor] = m_row[d];
                dtg[cursor] = g_row[d];
                cursor += 1;
            }
        }

        // C7: cumulative TP/FP exclude ignore-tagged DTs.
        let mut tp = 0.0_f64;
        let mut fp = 0.0_f64;
        for (out_idx, &src_idx) in perm.iter().enumerate() {
            if !dtg[src_idx] {
                if dtm[src_idx] {
                    tp += 1.0;
                } else {
                    fp += 1.0;
                }
            }
            rc[out_idx] = tp / npig_f;
            // C8: PARITY_EPS in the denominator mirrors np.spacing(1).
            pr[out_idx] = tp / (tp + fp + PARITY_EPS);
        }

        // C4: terminal cumulative recall.
        recall[(t, k, a, m)] = rc[n_d - 1];

        // C2: right-to-left running max on precision (envelope).
        for j in (1..n_d).rev() {
            if pr[j] > pr[j - 1] {
                pr[j - 1] = pr[j];
            }
        }

        // C1 + C3: searchsorted-left + bounds-check. Past the curve,
        // slots are filled with 0.0 — overwriting the -1 sentinel so the
        // summarizer's `s > -1` filter keeps them.
        let mut p_lane = precision
            .index_axis_mut(Axis(0), t)
            .index_axis_move(Axis(1), k)
            .index_axis_move(Axis(1), a)
            .index_axis_move(Axis(1), m);
        let mut s_lane = scores
            .index_axis_mut(Axis(0), t)
            .index_axis_move(Axis(1), k)
            .index_axis_move(Axis(1), a)
            .index_axis_move(Axis(1), m);
        for (ri, &target) in recall_thresholds.iter().enumerate() {
            // Leftmost curve index whose cumulative recall reaches the
            // target (np.searchsorted side='left' semantics).
            let pi = rc.partition_point(|&v| v < target);
            if pi < n_d {
                p_lane[ri] = pr[pi];
                s_lane[ri] = all_scores[perm[pi]];
            } else {
                p_lane[ri] = 0.0;
                s_lane[ri] = 0.0;
            }
        }
    }
}
370
371#[cfg(test)]
372mod tests {
373 use super::*;
374 use ndarray::array;
375
    /// Build a single-image, single-IoU-threshold [`PerImageEval`] from
    /// parallel per-DT vectors (`scores` / `matched` / `ignore`) plus
    /// the per-GT ignore flags.
    fn one_threshold_eval(
        scores: Vec<f64>,
        matched: Vec<bool>,
        ignore: Vec<bool>,
        gt_ignore: Vec<bool>,
    ) -> PerImageEval {
        let n = scores.len();
        let dt_matched =
            Array2::from_shape_vec((1, n), matched).expect("dt_matched shape mismatch");
        let dt_ignore = Array2::from_shape_vec((1, n), ignore).expect("dt_ignore shape mismatch");
        PerImageEval {
            dt_scores: scores,
            dt_matched,
            dt_ignore,
            gt_ignore,
        }
    }
393
    /// Single-category, single-area-range [`AccumulateParams`] helper;
    /// only the threshold slices, maxDet ladder, and image count vary
    /// per test.
    fn params<'p>(
        iou: &'p [f64],
        rec: &'p [f64],
        max_dets: &'p [usize],
        n_images: usize,
    ) -> AccumulateParams<'p> {
        AccumulateParams {
            iou_thresholds: iou,
            recall_thresholds: rec,
            max_dets,
            n_categories: 1,
            n_area_ranges: 1,
            n_images,
        }
    }
409
    #[test]
    fn empty_grid_returns_all_sentinel() {
        // C5: a zero-image grid never writes anything — every tensor
        // cell keeps the -1 sentinel.
        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 0);
        let out = accumulate(&[], p, ParityMode::Strict).unwrap();
        assert!(out.precision.iter().all(|&v| v == -1.0));
        assert!(out.recall.iter().all(|&v| v == -1.0));
    }
417
    #[test]
    fn no_dt_with_real_gt_yields_zero_recall_and_sentinel_precision() {
        // C5: precision stays at -1 sentinel; recall is 0 for every t.
        // NOTE(review): pycocotools writes 0.0 into precision/scores for
        // this case (its zero-initialized `q`/`ss` survive the
        // `except: pass`), not a sentinel — confirm against cocoeval.py
        // before treating this expectation as strict parity.
        let cell = PerImageEval {
            dt_scores: vec![],
            dt_matched: Array2::<bool>::default((2, 0)),
            dt_ignore: Array2::<bool>::default((2, 0)),
            gt_ignore: vec![false],
        };
        let p = params(&[0.5, 0.75], &[0.0, 0.5, 1.0], &[100], 1);
        let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
        assert_eq!(out.recall[(0, 0, 0, 0)], 0.0);
        assert_eq!(out.recall[(1, 0, 0, 0)], 0.0);
        // No precision write happened — every cell still -1.
        for ri in 0..3 {
            assert_eq!(out.precision[(0, ri, 0, 0, 0)], -1.0);
            assert_eq!(out.precision[(1, ri, 0, 0, 0)], -1.0);
        }
    }
437
    #[test]
    fn cell_with_only_ignore_gts_skips_entirely() {
        // npig == 0 short-circuit: outputs stay at -1 (no recall write).
        // Mirrors pycocotools' `if npig == 0: continue`.
        let cell = one_threshold_eval(vec![0.9], vec![true], vec![true], vec![true]);
        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
        let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
        assert_eq!(out.recall[(0, 0, 0, 0)], -1.0);
        assert_eq!(out.precision[(0, 0, 0, 0, 0)], -1.0);
    }
447
    #[test]
    fn perfect_match_yields_ap_one_and_ar_one() {
        // Single DT matches the only real GT → both precision and
        // recall are 1.0 across every recall threshold.
        let cell = one_threshold_eval(vec![0.9], vec![true], vec![false], vec![false]);
        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
        let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();

        assert_eq!(out.recall[(0, 0, 0, 0)], 1.0);
        // rc = [1.0], so every recall threshold samples curve index 0.
        for ri in 0..3 {
            // Precision is `tp / (tp + fp + eps)` — 1 / (1 + 0 + eps) ≈ 1.
            let pr = out.precision[(0, ri, 0, 0, 0)];
            assert!((pr - 1.0).abs() < 1e-12, "precision[{ri}] = {pr}");
            assert_eq!(out.scores[(0, ri, 0, 0, 0)], 0.9);
        }
    }
464
    #[test]
    fn lone_fp_yields_zero_recall_zero_precision() {
        // One unmatched detection, one real unmatched GT → recall 0,
        // precision 0 across all recall thresholds. The score column
        // gets a value only at recall=0 (where the curve does exist);
        // recall thresholds past the end of the curve fall through to
        // pycocotools' silent-skip branch, leaving 0.0.
        let cell = one_threshold_eval(vec![0.9], vec![false], vec![false], vec![false]);
        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
        let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
        assert_eq!(out.recall[(0, 0, 0, 0)], 0.0);
        for ri in 0..3 {
            // 0 / (0 + 1 + eps) ≈ 0 → envelope keeps it at 0.
            assert!(out.precision[(0, ri, 0, 0, 0)].abs() < 1e-12);
        }
        // recall threshold 0.0 lands on the lone curve point (rc[0] =
        // 0.0); 0.5 and 1.0 are past the end → score sentinel 0.0.
        assert_eq!(out.scores[(0, 0, 0, 0, 0)], 0.9);
        assert_eq!(out.scores[(0, 1, 0, 0, 0)], 0.0);
        assert_eq!(out.scores[(0, 2, 0, 0, 0)], 0.0);
    }
486
    #[test]
    fn ignored_dt_does_not_count_as_fp() {
        // C7: an ignore-tagged DT is invisible to both TP and FP cumsums.
        // Setup: one real GT (matched by DT 0), one DT 1 that misses but
        // is ignore-tagged (e.g. out-of-area unmatched). FP must not
        // appear in the curve.
        let cell = one_threshold_eval(
            vec![0.9, 0.8],
            vec![true, false],
            vec![false, true],
            vec![false],
        );
        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
        let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();

        // tp=1 fp=0 → precision ≈ 1 everywhere on the curve.
        for ri in 0..3 {
            let pr = out.precision[(0, ri, 0, 0, 0)];
            assert!((pr - 1.0).abs() < 1e-12, "precision[{ri}] = {pr}");
        }
        // The one real GT was recalled, so AR is 1.0 despite the
        // ignored second DT.
        assert_eq!(out.recall[(0, 0, 0, 0)], 1.0);
    }
509
    #[test]
    fn precision_envelope_runs_right_to_left() {
        // C2: pre-envelope precision dips. Curve: TP, FP, TP → precisions
        // 1.0, 0.5, 0.667. After right-to-left max: 1.0, 0.667, 0.667.
        // Recall thresholds 0.0 and 0.5 (rc = [0.5, 0.5, 1.0]) sample
        // index 0; threshold 1.0 samples index 2.
        let cell = one_threshold_eval(
            vec![0.9, 0.8, 0.7],
            vec![true, false, true],
            vec![false, false, false],
            vec![false, false],
        );
        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
        let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();

        // recall thresholds 0.0 and 0.5 both fall on the first rc cell
        // where rc[0] = 0.5 (TP at j=0 → 1/2). Envelope makes pr[0]=1.0.
        assert!((out.precision[(0, 0, 0, 0, 0)] - 1.0).abs() < 1e-12);
        assert!((out.precision[(0, 1, 0, 0, 0)] - 1.0).abs() < 1e-12);
        // recall threshold 1.0 samples j=2: pr[2] = 2/3.
        assert!((out.precision[(0, 2, 0, 0, 0)] - 2.0 / 3.0).abs() < 1e-12);
    }
532
533 #[test]
534 fn partition_point_matches_numpy_searchsorted_left() {
535 // Pinning the stdlib semantics so a future swap (e.g., to a
536 // SIMD search) keeps `np.searchsorted(..., side='left')` parity.
537 let haystack = [0.1, 0.3, 0.3, 0.7];
538 let lookup = |t: f64| haystack.partition_point(|&v| v < t);
539 assert_eq!(lookup(0.0), 0);
540 assert_eq!(lookup(0.3), 1); // leftmost equal
541 assert_eq!(lookup(0.5), 3);
542 assert_eq!(lookup(1.0), 4); // past end
543 }
544
    #[test]
    fn merged_sort_breaks_ties_by_input_order() {
        // A1 over the merged stream: two images with one DT each at
        // score 0.7. With stable sort, image-0 DT comes first.
        let img0 = one_threshold_eval(vec![0.7], vec![true], vec![false], vec![false]);
        let img1 = one_threshold_eval(vec![0.7], vec![false], vec![false], vec![false]);
        // grid: K=1, A=1, I=2 → eval_imgs[0..2] is the (k=0, a=0) row.
        let grid = vec![Some(Box::new(img0)), Some(Box::new(img1))];
        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 2);
        let out = accumulate(&grid, p, ParityMode::Strict).unwrap();

        // tp=1, fp=1 → final pr = 0.5; rc = [0.5, 0.5]. With envelope
        // (no monotonicity adjustment needed because pr[1] < pr[0]),
        // recThr 0.0 and 0.5 both sample index 0 (pr ≈ 1.0), recThr 1.0
        // is past the end → 0.0. A flipped tie-break would swap TP/FP
        // order and drop the index-0 precision to 0.5.
        assert!((out.precision[(0, 0, 0, 0, 0)] - 1.0).abs() < 1e-12);
        assert!((out.precision[(0, 1, 0, 0, 0)] - 1.0).abs() < 1e-12);
        assert_eq!(out.precision[(0, 2, 0, 0, 0)], 0.0);
    }
564
    #[test]
    fn max_det_truncation_drops_low_score_dts_per_image() {
        // Per-image max_det=1: only the top-scoring DT survives, even
        // though more were emitted. With only the FP at score 0.95
        // surviving, AP must collapse.
        let cell = one_threshold_eval(
            vec![0.95, 0.9],
            vec![false, true], // FP first, TP second
            vec![false, false],
            vec![false],
        );
        let p = params(&[0.5], &[0.0, 0.5, 1.0], &[1], 1);
        let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
        // Only FP survived → tp=0, fp=1, precision ≈ 0 everywhere.
        for ri in 0..3 {
            assert!(out.precision[(0, ri, 0, 0, 0)].abs() < 1e-12);
        }
        // The TP at score 0.9 was truncated away, so nothing recalls
        // the lone real GT.
        assert_eq!(out.recall[(0, 0, 0, 0)], 0.0);
    }
584
    #[test]
    fn dimension_mismatch_on_grid_size_is_typed_error() {
        let p = params(&[0.5], &[0.0], &[100], 5);
        // Grid claims K*A*I = 1*1*5 = 5 cells; we pass 2 → error.
        // Must surface as the typed variant, not a panic.
        let err = accumulate(&[None, None], p, ParityMode::Strict).unwrap_err();
        match err {
            EvalError::DimensionMismatch { detail } => {
                assert!(detail.contains("eval_imgs"));
            }
            other => panic!("expected DimensionMismatch, got {other:?}"),
        }
    }
597
    #[test]
    fn dimension_mismatch_on_per_image_t_is_typed_error() {
        // Per-image dt_matched has 2 rows, params declare 3 IoU
        // thresholds → mismatch reported by the up-front validation.
        let cell = PerImageEval {
            dt_scores: vec![0.9],
            dt_matched: array![[true], [true]],
            dt_ignore: array![[false], [false]],
            gt_ignore: vec![false],
        };
        let p = params(&[0.5, 0.75, 0.9], &[0.0], &[100], 1);
        let err = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap_err();
        assert!(matches!(err, EvalError::DimensionMismatch { .. }));
    }
612
    #[test]
    fn reaccumulate_with_different_area_range_count_is_typed_error() {
        // A3: re-accumulating an `eval_imgs` grid built for one A-axis
        // size against an `AccumulateParams` with a different
        // `n_area_ranges` must surface DimensionMismatch — not silently
        // produce wrong outputs by re-slicing the flat buffer at the new
        // pitch. Build a 4-area-range grid (the COCO default), then try
        // to accumulate it as if it were a 3-area-range grid.
        let n_i = 1;
        let n_a_built = 4;
        let n_k = 1;
        let cell = one_threshold_eval(vec![0.9], vec![true], vec![false], vec![false]);
        // Only the first (k=0, a=0, i=0) slot carries data; remaining
        // slots are None as they would be for an image with no GTs/DTs in
        // those buckets.
        let mut eval_imgs: Vec<Option<Box<PerImageEval>>> = vec![None; n_k * n_a_built * n_i];
        eval_imgs[0] = Some(Box::new(cell));

        // Mismatched params: claim the grid has 3 area ranges. Expected
        // grid size becomes 1*3*1 = 3, but we pass 4 cells → typed error.
        let mut bad = params(&[0.5], &[0.0, 0.5, 1.0], &[100], n_i);
        bad.n_area_ranges = 3;
        let err = accumulate(&eval_imgs, bad, ParityMode::Strict).unwrap_err();
        match err {
            EvalError::DimensionMismatch { detail } => {
                // The message must name the mismatched axis so callers
                // can tell a pitch error from a shape error.
                assert!(detail.contains("eval_imgs"), "msg: {detail}");
                assert!(detail.contains("n_area_ranges(3)"), "msg: {detail}");
            }
            other => panic!("expected DimensionMismatch, got {other:?}"),
        }
    }
644
    #[test]
    fn vectorized_inner_sweep_matches_naive_reference() {
        // C6: the inner recall-threshold sweep is vectorized via
        // partition_point + an in-place right-to-left envelope. Pin it
        // against a naive reference that mirrors pycocotools' Python
        // `for ri, pi in enumerate(inds): q[ri] = pr[pi]` line by line.
        //
        // Three hand-crafted PR curves cover the edge cases:
        // - monotonic-decreasing precision (no envelope work);
        // - non-monotonic precision (envelope rewrites multiple cells);
        // - all-1.0 precision with the recall curve ending at 0.5 so
        //   half the recall thresholds fall past the curve (C3 path).
        //
        // Only the precision lane is compared — both implementations
        // share the same recall-index lookup, so the score lane would
        // trivially agree.
        let recall_thresholds: Vec<f64> = (0..=10).map(|i| (i as f64) / 10.0).collect();

        // Naive reference: explicit right-to-left running max + linear
        // searchsorted-left scan.
        fn naive_sweep(rc: &[f64], pr: &[f64], rec_thr: &[f64]) -> Vec<f64> {
            let n = pr.len();
            let mut env = pr.to_vec();
            for j in (1..n).rev() {
                if env[j] > env[j - 1] {
                    env[j - 1] = env[j];
                }
            }
            let mut q = vec![0.0_f64; rec_thr.len()];
            for (ri, &target) in rec_thr.iter().enumerate() {
                // Linear scan for the leftmost rc[j] >= target; n means
                // "past the curve" (the C3 silent-skip slot).
                let mut pi = n;
                for (j, &r) in rc.iter().enumerate() {
                    if r >= target {
                        pi = j;
                        break;
                    }
                }
                if pi < n {
                    q[ri] = env[pi];
                }
            }
            q
        }

        // Vectorized reference: same shape as `accumulate_cell`'s inner
        // sweep, callable on hand-crafted curves without rebuilding the
        // whole `(T, R, K, A, M)` tensor. Drift between this body and
        // the production sweep is what the test exists to catch.
        fn vectorized_sweep(rc: &[f64], pr: &[f64], rec_thr: &[f64]) -> Vec<f64> {
            let n = pr.len();
            let mut env = pr.to_vec();
            for j in (1..n).rev() {
                if env[j] > env[j - 1] {
                    env[j - 1] = env[j];
                }
            }
            let mut q = vec![0.0_f64; rec_thr.len()];
            for (ri, &target) in rec_thr.iter().enumerate() {
                let pi = rc.partition_point(|&v| v < target);
                if pi < n {
                    q[ri] = env[pi];
                }
            }
            q
        }

        let curves: &[(&[f64], &[f64])] = &[
            // Monotonic-decreasing precision; recall reaches 1.0.
            (&[0.1, 0.3, 0.5, 0.7, 1.0], &[1.0, 0.9, 0.7, 0.5, 0.3]),
            // Non-monotonic precision: envelope rewrites cells 1 and 3.
            (&[0.2, 0.4, 0.6, 0.8, 1.0], &[1.0, 0.4, 0.6, 0.2, 0.5]),
            // All-1.0 precision; recall caps at 0.5 → recall thresholds
            // > 0.5 fall past the curve (C3 silent-skip path → 0.0).
            (&[0.1, 0.2, 0.3, 0.4, 0.5], &[1.0, 1.0, 1.0, 1.0, 1.0]),
        ];

        for (i, (rc, pr)) in curves.iter().enumerate() {
            let q_naive = naive_sweep(rc, pr, &recall_thresholds);
            let q_vec = vectorized_sweep(rc, pr, &recall_thresholds);
            assert_eq!(q_naive.len(), q_vec.len(), "curve {i}");
            // Bit-level equality: both paths must perform the exact
            // same float operations, not merely land within tolerance.
            for (ri, (a, b)) in q_naive.iter().zip(q_vec.iter()).enumerate() {
                assert_eq!(
                    a.to_bits(),
                    b.to_bits(),
                    "curve {i}, recall threshold index {ri}: naive={a}, vec={b}"
                );
            }
        }
    }
734
735 #[test]
736 fn sort_max_dets_normalizes_ascending() {
737 // Quirk A2: pycocotools' `cocoeval.py:137` does
738 // `p.maxDets = sorted(p.maxDets)` — `sort_max_dets` is the
739 // mirror at the param-construction boundary.
740 let mut ladder = vec![100usize, 1, 10];
741 sort_max_dets(&mut ladder);
742 assert_eq!(ladder, vec![1, 10, 100]);
743 }
744
745 #[test]
746 fn sort_max_dets_is_idempotent_on_sorted_input() {
747 let mut ladder = vec![1usize, 10, 100];
748 sort_max_dets(&mut ladder);
749 assert_eq!(ladder, vec![1, 10, 100]);
750 }
751
752 #[test]
753 fn sort_max_dets_handles_duplicates_and_singletons() {
754 let mut singleton = vec![100usize];
755 sort_max_dets(&mut singleton);
756 assert_eq!(singleton, vec![100]);
757
758 let mut empty: Vec<usize> = Vec::new();
759 sort_max_dets(&mut empty);
760 assert!(empty.is_empty());
761
762 let mut dups = vec![10usize, 1, 10, 1, 100];
763 sort_max_dets(&mut dups);
764 assert_eq!(dups, vec![1, 1, 10, 10, 100]);
765 }
766
    #[test]
    fn permuted_ladder_after_sort_matches_canonical_order() {
        // End-to-end: feeding `[100, 1, 10]` after `sort_max_dets`
        // produces a `(T, R, K, A, M)` accumulator whose M-axis is
        // identical to the one built from the canonical `[1, 10, 100]`.
        // Without the sort, the M-axis slots would be swapped and the
        // summarizer's positional `AR_1 / AR_10 / AR_100` mapping would
        // bind to the wrong threshold.
        let cell = one_threshold_eval(
            vec![0.9, 0.8, 0.7],
            vec![true, true, false],
            vec![false, false, false],
            vec![false, false, false],
        );
        let iou = [0.5];
        let rec = [0.0, 0.5, 1.0];

        let canonical = vec![1usize, 10, 100];
        let canonical_acc = accumulate(
            &[Some(Box::new(cell.clone()))],
            params(&iou, &rec, &canonical, 1),
            ParityMode::Strict,
        )
        .unwrap();

        let mut permuted = vec![100usize, 1, 10];
        sort_max_dets(&mut permuted);
        assert_eq!(permuted, canonical);
        let permuted_acc = accumulate(
            &[Some(Box::new(cell))],
            params(&iou, &rec, &permuted, 1),
            ParityMode::Strict,
        )
        .unwrap();

        // Exact tensor equality — the two ladders must be
        // indistinguishable after normalization.
        assert_eq!(canonical_acc.precision, permuted_acc.precision);
        assert_eq!(canonical_acc.recall, permuted_acc.recall);
        assert_eq!(canonical_acc.scores, permuted_acc.scores);
    }
806}