vernier_core/summarize.rs
1//! Twelve-stat detection summary atop [`crate::Accumulated`].
2//!
3//! Mirrors `pycocotools.cocoeval.COCOeval.summarize` (cocoeval.py
4//! lines 422-475), but as a pure structured value — no stdout side
5//! effects (quirks **L5/L6/L7**, dispositioned `corrected`).
6//!
7//! ## Quirk dispositions
8//!
9//! - **C5** (`strict`): cells absent from the dataset carry `-1`;
10//! summarization filters them out via `s > -1` before averaging.
11//! - **L5** (`corrected`): the print/log side-effect from upstream
12//! `_summarize` is gone. Use [`Summary::pretty_lines`] for the
13//! pycocotools-shaped human-readable rendering.
14//! - **L6** (`corrected`): empty-eval `mean(empty)` no longer raises a
15//! numpy RuntimeWarning — the absent case explicitly returns `-1`.
16//! - **L7** (`corrected`): the result is a value (`Summary`), not a
17//! property side-effect on the evaluator.
18
19use std::borrow::Cow;
20use std::collections::HashMap;
21use std::ops::Range;
22
23use ndarray::Axis;
24
25use crate::accumulate::Accumulated;
26use crate::dataset::{CategoryId, Frequency};
27use crate::error::EvalError;
28
/// Absolute tolerance used when matching a requested IoU threshold
/// against the entries of the `iou_thresholds` ladder.
///
/// Ladders built as `linspace(0.5, 0.95, 10)` can sit an ulp or so away
/// from the decimal literal a caller writes (quirk **L1**); this slack
/// absorbs that drift without letting distinct rungs alias.
pub(crate) const IOU_LOOKUP_TOL: f64 = 1.0e-12;
33
/// A single bucket on the A-axis of an [`Accumulated`]: the axis
/// position plus a human-readable label used when rendering.
///
/// The four pycocotools detection buckets are provided as associated
/// constants — [`AreaRng::ALL`], [`SMALL`](Self::SMALL),
/// [`MEDIUM`](Self::MEDIUM) and [`LARGE`](Self::LARGE) — in the same
/// order as cocoeval's `Params.areaRngLbl`. Custom layouts (e.g. finer
/// robotics-style buckets) are built with [`AreaRng::new`] for any label
/// convertible into `Cow<'static, str>`, or [`AreaRng::from_static`] for
/// `&'static str` labels in `const` contexts.
///
/// This type carries no area *bounds*: mapping an annotation's area to a
/// bucket index happens upstream, where the
/// [`crate::accumulate::PerImageEval`] cells are built. The summarizer
/// only consumes the resulting A-axis index.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AreaRng {
    /// Position on the A-axis of [`Accumulated::precision`] /
    /// [`Accumulated::recall`]. Not checked on construction; the
    /// summarizer rejects an index past the end of the A-axis with
    /// [`EvalError::InvalidConfig`].
    pub index: usize,
    /// Text shown in the `area=` column of [`Summary::pretty_lines`].
    pub label: Cow<'static, str>,
}

impl AreaRng {
    /// COCO `all` bucket (pycocotools `[0, 1e10]`) at A-axis index 0.
    pub const ALL: Self = Self::from_static(0, "all");
    /// COCO `small` bucket (pycocotools `[0, 32^2]`) at A-axis index 1.
    pub const SMALL: Self = Self::from_static(1, "small");
    /// COCO `medium` bucket (pycocotools `[32^2, 96^2]`) at A-axis index 2.
    pub const MEDIUM: Self = Self::from_static(2, "medium");
    /// COCO `large` bucket (pycocotools `[96^2, 1e10]`) at A-axis index 3.
    pub const LARGE: Self = Self::from_static(3, "large");

    /// Build a bucket from any label convertible into
    /// `Cow<'static, str>` (an owned `String`, a `&'static str`, ...).
    pub fn new(index: usize, label: impl Into<Cow<'static, str>>) -> Self {
        let label = label.into();
        Self { index, label }
    }

    /// Build a bucket from a `&'static str` label; usable in `const`
    /// items such as the COCO constants above.
    pub const fn from_static(index: usize, label: &'static str) -> Self {
        Self {
            label: Cow::Borrowed(label),
            index,
        }
    }
}
84
/// Which family of statistic a [`StatRequest`] asks for and a
/// [`StatLine`] reports.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Metric {
    /// Average Precision: the mean is taken over a slice of
    /// `Accumulated::precision`.
    AveragePrecision,
    /// Average Recall: the mean is taken over a slice of
    /// `Accumulated::recall`. Per quirk **C4**, AR is the final
    /// cumulative recall value, not an integral under the
    /// precision/recall curve.
    AverageRecall,
}
95
96/// Single line of the COCO 12-stat summary table.
97#[derive(Debug, Clone)]
98pub struct StatLine {
99 /// AP or AR.
100 pub metric: Metric,
101 /// `None` means averaged across the whole IoU ladder; `Some(t)`
102 /// pins a specific threshold (e.g., 0.5 for AP@.50).
103 pub iou_threshold: Option<f64>,
104 /// Area-range bucket.
105 pub area: AreaRng,
106 /// Per-image maxDet cap.
107 pub max_dets: usize,
108 /// Mean over the matching slice, ignoring `-1` sentinels. `-1.0`
109 /// when the slice has no non-sentinel entries (quirks **C5/L6**).
110 pub value: f64,
111}
112
113/// Result of evaluating a summary plan over an [`Accumulated`].
114///
115/// `lines.len()` matches the plan length; for the canonical pycocotools
116/// detection summary built by [`summarize_detection`], that's 12 lines
117/// in the order `[AP, AP50, AP75, AP_S, AP_M, AP_L, AR_1, AR_10,
118/// AR_100, AR_S, AR_M, AR_L]`. For custom plans evaluated via
119/// [`summarize_with`], `lines` mirrors the request order.
120#[derive(Debug, Clone)]
121pub struct Summary {
122 /// One entry per request in the evaluated plan, paired with slicing
123 /// metadata.
124 pub lines: Vec<StatLine>,
125}
126
127impl Summary {
128 /// Numeric values in plan order. Equivalent to
129 /// `lines.iter().map(|l| l.value).collect()`.
130 pub fn stats(&self) -> Vec<f64> {
131 self.lines.iter().map(|l| l.value).collect()
132 }
133 /// Render the canonical pycocotools text table (12 lines, each in
134 /// the upstream `Average Precision (AP) @[ IoU=... | area=... |
135 /// maxDets=... ] = 0.xxx` shape). Returned as a `Vec<String>`; the
136 /// caller decides whether to print, log, or test against it.
137 pub fn pretty_lines(&self) -> Vec<String> {
138 self.lines
139 .iter()
140 .map(|line| {
141 let (title, kind) = match line.metric {
142 Metric::AveragePrecision => ("Average Precision", "(AP)"),
143 Metric::AverageRecall => ("Average Recall", "(AR)"),
144 };
145 let iou = match line.iou_threshold {
146 Some(t) => format!("{t:0.2}"),
147 None => "0.50:0.95".to_string(),
148 };
149 format!(
150 " {title:<18} {kind} @[ IoU={iou:<9} | area={:>6} | maxDets={:>3} ] = {:0.3}",
151 line.area.label, line.max_dets, line.value
152 )
153 })
154 .collect()
155 }
156}
157
/// How a [`StatRequest`] picks an entry on the M-axis of an
/// [`Accumulated`].
///
/// Pycocotools hard-codes `maxDets[0|1|2]` for `AR_{1,10,100}` and
/// `maxDets[-1]` for everything else. This enum lets a plan express that
/// intent — "the last (largest) cap" or "the entry whose value equals N" —
/// without binding to fixed positional indices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaxDetSelector {
    /// The **last** entry of the supplied `max_dets` slice, i.e.
    /// pycocotools' `maxDets[-1]`.
    ///
    /// This is the largest cap for the conventional ascending ladder
    /// (e.g. `[1, 10, 100]`). The slice is not searched for its maximum,
    /// so with an unsorted ladder this is simply the final element.
    ///
    /// The canonical detection plan uses this selector for every AP line
    /// and for `AR_S` / `AR_M` / `AR_L`.
    Largest,
    /// The M-axis entry whose value equals this number. Summarization
    /// fails with [`EvalError::InvalidConfig`] if no entry matches.
    Value(usize),
}
174
175/// K-axis subset selector (ADR-0026 D2). Filters which categories
176/// contribute to a [`StatRequest`]'s mean.
177///
178/// Frequency buckets are *not* a [`crate::breakdown::Breakdown`] axis — they
179/// are a category-subset selector, the K-axis equivalent of an area
180/// bucket. The discriminated form keeps frequency-keyed (LVIS) and
181/// id-keyed (per-supercategory, ablation subsets) intents cleanly
182/// separated; the resolution to a list of K indices happens at
183/// summarize time once the K-axis ordering is known.
184///
185/// Only [`CategoryFilter::All`] is supported by the standard
186/// [`summarize_with`] entry point; the [`Frequency`] and
187/// [`ByIds`](Self::ByIds) variants require the K-axis context (the
188/// list of `CategoryId`s in axis order, plus the per-category
189/// frequency map for [`Frequency`](Self::Frequency)) and route through
190/// [`summarize_with_lvis`].
191#[derive(Debug, Clone, PartialEq, Eq)]
192pub enum CategoryFilter {
193 /// No filter — every category contributes (the COCO default).
194 All,
195 /// Include only categories whose [`Frequency`] tag matches. Quirk
196 /// **AB3**: empty after the `-1`-sentinel drop (handled internally
197 /// by the summarizer) yields
198 /// `-1.0`, not `0.0` or `nan` (lvis-api's `eval.py:441-442`).
199 Frequency(Frequency),
200 /// Explicit subset: include only categories whose id is in the
201 /// list. Sorted ascending for stable membership tests; duplicates
202 /// are ignored.
203 ByIds(Vec<CategoryId>),
204 /// Include only categories that belong to the named group of the
205 /// active [`ClassGroupBreakdown`] (ADR-0041 / ADR-0042). The string
206 /// is the group label; resolution happens at summarize time
207 /// against the configured grouping.
208 ///
209 /// Unlike [`Frequency`](Self::Frequency) / [`ByIds`](Self::ByIds), this variant does *not*
210 /// require LVIS context — it routes through the standard
211 /// summarizer with the breakdown reference passed in via the
212 /// non-LVIS context shim.
213 ///
214 /// [`ClassGroupBreakdown`]: crate::breakdown::ClassGroupBreakdown
215 ByGrouping(Cow<'static, str>),
216}
217
218impl CategoryFilter {
219 /// `true` if this filter requires the K-axis context (frequency
220 /// map or id ordering) — i.e., [`Self::Frequency`] or
221 /// [`Self::ByIds`]. [`Self::All`] and [`Self::ByGrouping`] are
222 /// resolvable without LVIS context (the latter via the
223 /// `ClassGroupBreakdown` reference passed alongside).
224 pub fn needs_lvis_context(&self) -> bool {
225 matches!(self, Self::Frequency(_) | Self::ByIds(_))
226 }
227}
228
229/// One line of a summary plan — describes a single mean to compute.
230#[derive(Debug, Clone)]
231pub struct StatRequest {
232 /// AP or AR.
233 pub metric: Metric,
234 /// `None` averages across the IoU ladder; `Some(t)` pins one row.
235 /// Looked up against `iou_thresholds` within an internal absolute
236 /// tolerance (≈1e-9) at
237 /// summarize time; values not on the ladder produce
238 /// [`EvalError::InvalidConfig`].
239 pub iou_threshold: Option<f64>,
240 /// Area-range bucket on the A-axis.
241 pub area: AreaRng,
242 /// How to pick the M-axis entry.
243 pub max_dets: MaxDetSelector,
244 /// K-axis subset (ADR-0026 D2). Defaults to
245 /// [`CategoryFilter::All`] for COCO-shape plans; LVIS plans use
246 /// [`CategoryFilter::Frequency`] for the AP_r/c/f buckets.
247 pub category_filter: CategoryFilter,
248}
249
250impl StatRequest {
251 /// Convenience constructor. `const`-callable so [`coco_detection_default`]
252 /// and downstream user-defined plans can be assembled in `const`
253 /// contexts. Defaults `category_filter` to [`CategoryFilter::All`].
254 ///
255 /// [`coco_detection_default`]: Self::coco_detection_default
256 pub const fn new(
257 metric: Metric,
258 iou_threshold: Option<f64>,
259 area: AreaRng,
260 max_dets: MaxDetSelector,
261 ) -> Self {
262 Self {
263 metric,
264 iou_threshold,
265 area,
266 max_dets,
267 category_filter: CategoryFilter::All,
268 }
269 }
270
271 /// Construct with a non-default [`CategoryFilter`] in one shot.
272 /// `const`-callable for the [`Frequency`](CategoryFilter::Frequency)
273 /// and [`All`](CategoryFilter::All) variants;
274 /// [`ByIds`](CategoryFilter::ByIds) carries a heap-allocated
275 /// `Vec<CategoryId>` and is constructed at runtime.
276 pub const fn new_with_filter(
277 metric: Metric,
278 iou_threshold: Option<f64>,
279 area: AreaRng,
280 max_dets: MaxDetSelector,
281 category_filter: CategoryFilter,
282 ) -> Self {
283 Self {
284 metric,
285 iou_threshold,
286 area,
287 max_dets,
288 category_filter,
289 }
290 }
291
292 /// The canonical 12-entry pycocotools detection plan, in the
293 /// `[AP, AP50, AP75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S,
294 /// AR_M, AR_L]` order. Bit-exact with cocoeval is by construction:
295 /// [`summarize_detection`] is just `summarize_with(.., this, ..)`.
296 pub const fn coco_detection_default() -> [Self; 12] {
297 use MaxDetSelector::{Largest, Value};
298 use Metric::{AveragePrecision, AverageRecall};
299 [
300 Self::new(AveragePrecision, None, AreaRng::ALL, Largest),
301 Self::new(AveragePrecision, Some(0.5), AreaRng::ALL, Largest),
302 Self::new(AveragePrecision, Some(0.75), AreaRng::ALL, Largest),
303 Self::new(AveragePrecision, None, AreaRng::SMALL, Largest),
304 Self::new(AveragePrecision, None, AreaRng::MEDIUM, Largest),
305 Self::new(AveragePrecision, None, AreaRng::LARGE, Largest),
306 Self::new(AverageRecall, None, AreaRng::ALL, Value(1)),
307 Self::new(AverageRecall, None, AreaRng::ALL, Value(10)),
308 Self::new(AverageRecall, None, AreaRng::ALL, Value(100)),
309 Self::new(AverageRecall, None, AreaRng::SMALL, Largest),
310 Self::new(AverageRecall, None, AreaRng::MEDIUM, Largest),
311 Self::new(AverageRecall, None, AreaRng::LARGE, Largest),
312 ]
313 }
314
315 /// The canonical 13-entry LVIS detection plan (ADR-0026 AF1, AF4),
316 /// in the LVIS `print_results` order:
317 ///
318 /// `[AP, AP50, AP75, APs, APm, APl, APr, APc, APf,
319 /// AR@300, ARs@300, ARm@300, ARl@300]`
320 ///
321 /// Differences from [`Self::coco_detection_default`]:
322 ///
323 /// - **9 AP entries vs 6.** Three additional rows (APr/APc/APf)
324 /// filter the K-axis by [`Frequency`] tag. `lvis-api` reports
325 /// them as separate entries, not `Breakdown` axes (ADR-0016
326 /// `f64`-keyed type doesn't fit categorical tags).
327 /// - **4 AR entries vs 6.** No `AR@1` / `AR@10` / `AR@100` —
328 /// LVIS reports recall at `max_dets=300` only (AF4). The
329 /// `Largest` selector resolves to whatever the user passes;
330 /// pair the plan with `max_dets=[300]` for parity with
331 /// `LVISEval`.
332 ///
333 /// `Frequency`-filtered entries route through
334 /// [`summarize_with_lvis`]; calling [`summarize_with`] on this
335 /// plan returns [`EvalError::InvalidConfig`] (the plain entry
336 /// point has no K-axis context).
337 pub const fn lvis_default() -> [Self; 13] {
338 use CategoryFilter::{All as AllK, Frequency as FreqK};
339 use MaxDetSelector::Largest;
340 use Metric::{AveragePrecision, AverageRecall};
341 [
342 Self::new_with_filter(AveragePrecision, None, AreaRng::ALL, Largest, AllK),
343 Self::new_with_filter(AveragePrecision, Some(0.5), AreaRng::ALL, Largest, AllK),
344 Self::new_with_filter(AveragePrecision, Some(0.75), AreaRng::ALL, Largest, AllK),
345 Self::new_with_filter(AveragePrecision, None, AreaRng::SMALL, Largest, AllK),
346 Self::new_with_filter(AveragePrecision, None, AreaRng::MEDIUM, Largest, AllK),
347 Self::new_with_filter(AveragePrecision, None, AreaRng::LARGE, Largest, AllK),
348 Self::new_with_filter(
349 AveragePrecision,
350 None,
351 AreaRng::ALL,
352 Largest,
353 FreqK(Frequency::Rare),
354 ),
355 Self::new_with_filter(
356 AveragePrecision,
357 None,
358 AreaRng::ALL,
359 Largest,
360 FreqK(Frequency::Common),
361 ),
362 Self::new_with_filter(
363 AveragePrecision,
364 None,
365 AreaRng::ALL,
366 Largest,
367 FreqK(Frequency::Frequent),
368 ),
369 Self::new_with_filter(AverageRecall, None, AreaRng::ALL, Largest, AllK),
370 Self::new_with_filter(AverageRecall, None, AreaRng::SMALL, Largest, AllK),
371 Self::new_with_filter(AverageRecall, None, AreaRng::MEDIUM, Largest, AllK),
372 Self::new_with_filter(AverageRecall, None, AreaRng::LARGE, Largest, AllK),
373 ]
374 }
375
376 /// The canonical 10-entry pycocotools keypoints plan, in the
377 /// `[AP, AP50, AP75, AP_M, AP_L, AR, AR50, AR75, AR_M, AR_L]`
378 /// order (cocoeval.py:478-499 under `iouType="keypoints"`).
379 ///
380 /// Differs from [`Self::coco_detection_default`] in three ways,
381 /// all per ADR-0012:
382 ///
383 /// - 10 entries, not 12 — the small-area row is dropped on both
384 /// AP and AR (quirk **D5**).
385 /// - Every entry uses [`MaxDetSelector::Largest`], which resolves
386 /// to the kp-canonical `(20,)` ladder; there are no `AR_1` /
387 /// `AR_10` / `AR_100` rows because the kp ladder has only one
388 /// rung.
389 /// - The `AreaRng` indices `0/1/2` (all/medium/large) are
390 /// re-indexed for the kp A-axis. Callers must pair this plan
391 /// with [`crate::AreaRange::keypoints_default`] so the A-axis
392 /// indices line up; the const [`AreaRng::ALL`] / `MEDIUM` /
393 /// `LARGE` carry the four-bucket detection-grid indices and
394 /// would index off the end of a three-bucket accumulator.
395 pub const fn coco_keypoints_default() -> [Self; 10] {
396 use MaxDetSelector::Largest;
397 use Metric::{AveragePrecision, AverageRecall};
398 // D5: re-indexed kp A-axis (0=all, 1=medium, 2=large), no small.
399 // `from_static` is `const`, so each call site materializes a
400 // fresh `AreaRng` without an intermediate `clone()` — mirroring
401 // `coco_detection_default`'s use of the const `AreaRng::ALL`
402 // / `MEDIUM` / `LARGE` constants.
403 const ALL: AreaRng = AreaRng::from_static(0, "all");
404 const MEDIUM: AreaRng = AreaRng::from_static(1, "medium");
405 const LARGE: AreaRng = AreaRng::from_static(2, "large");
406 [
407 Self::new(AveragePrecision, None, ALL, Largest),
408 Self::new(AveragePrecision, Some(0.5), ALL, Largest),
409 Self::new(AveragePrecision, Some(0.75), ALL, Largest),
410 Self::new(AveragePrecision, None, MEDIUM, Largest),
411 Self::new(AveragePrecision, None, LARGE, Largest),
412 Self::new(AverageRecall, None, ALL, Largest),
413 Self::new(AverageRecall, Some(0.5), ALL, Largest),
414 Self::new(AverageRecall, Some(0.75), ALL, Largest),
415 Self::new(AverageRecall, None, MEDIUM, Largest),
416 Self::new(AverageRecall, None, LARGE, Largest),
417 ]
418 }
419}
420
421/// Twelve-stat COCO detection summary, bit-exact with cocoeval.
422///
423/// Thin wrapper over [`summarize_with`] that supplies the canonical
424/// 12-entry plan from [`StatRequest::coco_detection_default`].
425/// Downstream callers who need a different shape (keypoint `[20]`
426/// maxDets, custom AP@.30, …) should call `summarize_with` directly
427/// with their own plan; the canonical plan is available via the
428/// constructor for those who want to extend rather than replace it.
429///
430/// # Errors
431///
432/// Same conditions as [`summarize_with`].
433pub fn summarize_detection(
434 accum: &Accumulated,
435 iou_thresholds: &[f64],
436 max_dets: &[usize],
437) -> Result<Summary, EvalError> {
438 summarize_with(
439 accum,
440 &StatRequest::coco_detection_default(),
441 iou_thresholds,
442 max_dets,
443 )
444}
445
446/// Evaluate an arbitrary summary plan over an [`Accumulated`].
447///
448/// `iou_thresholds` and `max_dets` describe the grid the `Accumulated`
449/// was built against; they are needed to resolve [`StatRequest`]
450/// selectors (IoU value → T-axis index, [`MaxDetSelector`] → M-axis
451/// index) and to populate the `max_dets` field on each emitted
452/// [`StatLine`].
453///
454/// # Errors
455///
456/// Returns [`EvalError::DimensionMismatch`] if `iou_thresholds` or
457/// `max_dets` lengths disagree with `accum`'s `T`/`M` axes. Returns
458/// [`EvalError::InvalidConfig`] if any request names an IoU threshold
459/// not present in `iou_thresholds` (within `1e-12`) or a
460/// [`MaxDetSelector::Value`] absent from `max_dets`.
461pub fn summarize_with(
462 accum: &Accumulated,
463 plan: &[StatRequest],
464 iou_thresholds: &[f64],
465 max_dets: &[usize],
466) -> Result<Summary, EvalError> {
467 summarize_dispatch(accum, plan, iou_thresholds, max_dets, None)
468}
469
470/// LVIS variant of [`summarize_with`] that resolves
471/// [`CategoryFilter::Frequency`] / [`CategoryFilter::ByIds`] against
472/// the K-axis context (ADR-0026 D2).
473///
474/// `category_ids` lists the dataset's categories in K-axis order
475/// (id-ascending — the same ordering the orchestrator at
476/// `evaluate.rs:701-707` uses). `category_frequency` is the
477/// per-category tag from the LVIS JSON (quirk **AB1**); pass `None`
478/// to opt out of frequency-based filtering, in which case any
479/// [`CategoryFilter::Frequency`] entry yields `-1.0` (the AP-undefined
480/// sentinel — quirk **AB6**, also the migration-guide's note for
481/// COCO datasets that don't carry frequency tags).
482///
483/// # Errors
484///
485/// Same error surface as [`summarize_with`], plus
486/// [`EvalError::InvalidConfig`] when `category_ids.len()` does not
487/// match the K-axis size of `accum.precision`.
488pub fn summarize_with_lvis(
489 accum: &Accumulated,
490 plan: &[StatRequest],
491 iou_thresholds: &[f64],
492 max_dets: &[usize],
493 category_ids: &[CategoryId],
494 category_frequency: Option<&HashMap<CategoryId, Frequency>>,
495) -> Result<Summary, EvalError> {
496 let n_k = accum.precision.shape()[2];
497 if category_ids.len() != n_k {
498 return Err(EvalError::InvalidConfig {
499 detail: format!(
500 "category_ids len {} != precision K-axis {n_k}",
501 category_ids.len()
502 ),
503 });
504 }
505 let ctx = LvisCtx {
506 category_ids,
507 category_frequency,
508 };
509 summarize_dispatch(accum, plan, iou_thresholds, max_dets, Some(&ctx))
510}
511
512/// Internal context bundle for the LVIS K-axis resolution. Carrying it
513/// as an `Option` lets [`summarize_with`] and [`summarize_with_lvis`]
514/// share the body without exposing a fifth public parameter on the
515/// COCO path.
516struct LvisCtx<'a> {
517 category_ids: &'a [CategoryId],
518 category_frequency: Option<&'a HashMap<CategoryId, Frequency>>,
519}
520
521fn summarize_dispatch(
522 accum: &Accumulated,
523 plan: &[StatRequest],
524 iou_thresholds: &[f64],
525 max_dets: &[usize],
526 lvis: Option<&LvisCtx<'_>>,
527) -> Result<Summary, EvalError> {
528 let p_shape = accum.precision.shape();
529 let r_shape = accum.recall.shape();
530 let n_t = p_shape[0];
531 let n_m = p_shape[4];
532
533 if n_t != iou_thresholds.len() {
534 return Err(EvalError::DimensionMismatch {
535 detail: format!(
536 "precision T-axis {} != iou_thresholds len {}",
537 n_t,
538 iou_thresholds.len()
539 ),
540 });
541 }
542 if n_m != max_dets.len() {
543 return Err(EvalError::DimensionMismatch {
544 detail: format!(
545 "precision M-axis {} != max_dets len {}",
546 n_m,
547 max_dets.len()
548 ),
549 });
550 }
551 if r_shape[0] != n_t || r_shape[3] != n_m {
552 return Err(EvalError::DimensionMismatch {
553 detail: format!("recall {r_shape:?} disagrees with precision {p_shape:?}"),
554 });
555 }
556
557 // Resolve every selector before computing any means: a typo in any
558 // request fails early without wasting evaluation work, and the
559 // compute pass below stays infallible.
560 let n_a = p_shape[3];
561 let n_k = p_shape[2];
562 let m_max = max_dets.len() - 1;
563 let resolved: Vec<(usize, Range<usize>, Option<Vec<bool>>)> = plan
564 .iter()
565 .map(|req| {
566 if req.area.index >= n_a {
567 return Err(EvalError::InvalidConfig {
568 detail: format!(
569 "AreaRng index {} is out of range for A-axis (size {})",
570 req.area.index, n_a
571 ),
572 });
573 }
574 let m_idx = match req.max_dets {
575 MaxDetSelector::Largest => m_max,
576 MaxDetSelector::Value(v) => {
577 max_dets.iter().position(|&d| d == v).ok_or_else(|| {
578 EvalError::InvalidConfig {
579 detail: format!("max_dets does not contain {v}"),
580 }
581 })?
582 }
583 };
584 let t_range = match req.iou_threshold {
585 None => 0..n_t,
586 Some(target) => {
587 let t = iou_thresholds
588 .iter()
589 .position(|&v| (v - target).abs() < IOU_LOOKUP_TOL)
590 .ok_or_else(|| EvalError::InvalidConfig {
591 detail: format!("iou_threshold {target} not in ladder"),
592 })?;
593 t..(t + 1)
594 }
595 };
596 let k_mask = resolve_category_filter(&req.category_filter, n_k, lvis)?;
597 Ok((m_idx, t_range, k_mask))
598 })
599 .collect::<Result<Vec<_>, EvalError>>()?;
600
601 let lines = plan
602 .iter()
603 .zip(resolved)
604 .map(|(req, (m_idx, t_range, k_mask))| {
605 let value = mean_slice(
606 accum,
607 req.metric,
608 t_range,
609 req.area.index,
610 m_idx,
611 k_mask.as_deref(),
612 );
613 StatLine {
614 metric: req.metric,
615 iou_threshold: req.iou_threshold,
616 area: req.area.clone(),
617 max_dets: max_dets[m_idx],
618 value,
619 }
620 })
621 .collect();
622
623 Ok(Summary { lines })
624}
625
626/// Resolve a [`CategoryFilter`] to a K-axis bool mask of length `n_k`.
627/// Returns `None` for [`CategoryFilter::All`] (the no-op — equivalent
628/// to a mask of all `true`s), and `Some(mask)` for the filtered
629/// variants.
630///
631/// Errors when a non-`All` filter is encountered without an LVIS
632/// context (the standard [`summarize_with`] entry point — its error
633/// message points users at [`summarize_with_lvis`]).
634fn resolve_category_filter(
635 filter: &CategoryFilter,
636 n_k: usize,
637 lvis: Option<&LvisCtx<'_>>,
638) -> Result<Option<Vec<bool>>, EvalError> {
639 match filter {
640 CategoryFilter::All => Ok(None),
641 CategoryFilter::Frequency(target) => {
642 let Some(ctx) = lvis else {
643 return Err(EvalError::InvalidConfig {
644 detail: "CategoryFilter::Frequency requires summarize_with_lvis".to_string(),
645 });
646 };
647 let Some(freq_map) = ctx.category_frequency else {
648 // AB6 migration-guide note: dataset has no frequency
649 // tags. AP_r/c/f can't be computed; the request
650 // resolves to "no K passes", which mean_slice maps
651 // to the `-1` sentinel.
652 return Ok(Some(vec![false; n_k]));
653 };
654 Ok(Some(
655 ctx.category_ids
656 .iter()
657 .map(|cid| freq_map.get(cid).is_some_and(|f| f == target))
658 .collect(),
659 ))
660 }
661 CategoryFilter::ByIds(ids) => {
662 let Some(ctx) = lvis else {
663 return Err(EvalError::InvalidConfig {
664 detail: "CategoryFilter::ByIds requires summarize_with_lvis".to_string(),
665 });
666 };
667 let allow: std::collections::HashSet<&CategoryId> = ids.iter().collect();
668 Ok(Some(
669 ctx.category_ids
670 .iter()
671 .map(|cid| allow.contains(cid))
672 .collect(),
673 ))
674 }
675 CategoryFilter::ByGrouping(label) => Err(EvalError::InvalidConfig {
676 detail: format!(
677 "CategoryFilter::ByGrouping({label:?}) must be resolved to ByIds at the \
678 evaluator boundary before reaching the kernel summarizer (ADR-0041 / 0042). \
679 Resolution maps the group label against the active ClassGroupBreakdown."
680 ),
681 }),
682 }
683}
684
/// Mean of one `Accumulated` slice. Drops `-1` sentinel cells (quirks
/// **C5/L6**) and, optionally, the K-axis indices that the request's
/// [`CategoryFilter`] excludes (ADR-0026 D2).
///
/// Returns `-1.0` when no cell survives. That happens when the slice is
/// empty, every cell is the sentinel, or the mask excludes every
/// category. This mirrors pycocotools' `if len(s[s>-1])==0: -1`; per
/// quirk **AF6** the result stays at `-1` and never collapses to `0` or
/// `nan`. Cells are kept only when `v > -1.0`, so `NaN` cells are dropped
/// too.
///
/// The surviving values are summed with [`pairwise_sum`] and divided by
/// their count. This mirrors `np.mean(s[s>-1])`. The collection order
/// below is T, then R, then K for AP, and T, then K for AR. That order is
/// intended to match the C-order flattening of the numpy slice, which is
/// what keeps the result bit-comparable.
///
/// The K-axis mask is checked **before** the sentinel test. This matches
/// lvis-api `eval.py:444`'s `s[s>-1]` shape on a frequency-filtered
/// slice.
///
/// Parameters:
/// - `metric`: AP reads `accum.precision`, laid out (T, R, K, A, M). AR
///   reads `accum.recall`, laid out (T, K, A, M).
/// - `t_range`: T-axis rows to include (a single row for a pinned
///   threshold, the whole ladder otherwise).
/// - `area_idx` / `m_idx`: fixed A-axis and M-axis positions.
/// - `k_mask`: `None` is the COCO no-op (every K passes). `Some(mask)`
///   includes only K indices where `mask[k] == true`. The mask is indexed
///   for every K, so it must be at least as long as the K-axis; the
///   caller ([`resolve_category_filter`]) builds it from the same K-axis.
///
/// Infallible by contract. It does no validation of its own: an
/// out-of-range index panics inside ndarray. Callers must therefore
/// validate `t_range`, `area_idx`, and `m_idx` against the
/// `Accumulated`'s shape upfront. Because the K extent is read from
/// `precision`, `recall`'s K axis must also be at least that long.
fn mean_slice(
    accum: &Accumulated,
    metric: Metric,
    t_range: Range<usize>,
    area_idx: usize,
    m_idx: usize,
    k_mask: Option<&[bool]>,
) -> f64 {
    let t_count = t_range.len();
    // K extent comes from precision and is reused for recall below.
    let n_k = accum.precision.shape()[2];
    // Upper bound on surviving cells: T*R*K for AP (shape()[1] is the
    // R axis), T*K for AR.
    let cap = match metric {
        Metric::AveragePrecision => t_count * accum.precision.shape()[1] * n_k,
        Metric::AverageRecall => t_count * n_k,
    };
    let mut filtered: Vec<f64> = Vec::with_capacity(cap);
    // Sentinel drop: keep only values strictly above -1 (also drops NaN).
    let push_if = |filtered: &mut Vec<f64>, v: f64| {
        if v > -1.0 {
            filtered.push(v);
        }
    };
    for t in t_range {
        match metric {
            Metric::AveragePrecision => {
                // Slice T → (R, K, A, M); pick A and M → (R, K). The
                // intermediate axis-views have to live in `let`
                // bindings so the final `(R, K)` view borrows them
                // long enough to index — the chained-call form
                // dropped the inner views before `plane` was used.
                let p_t = accum.precision.index_axis(Axis(0), t);
                let p_ta = p_t.index_axis(Axis(2), area_idx);
                let plane = p_ta.index_axis(Axis(2), m_idx);
                // Walking R-major preserves the same sum order numpy
                // uses on a `(R, K)` slice — the K-mask filter just
                // skips columns the user opted out of.
                let n_r = plane.shape()[0];
                for r in 0..n_r {
                    for k in 0..n_k {
                        if k_mask.is_some_and(|m| !m[k]) {
                            continue;
                        }
                        push_if(&mut filtered, plane[(r, k)]);
                    }
                }
            }
            Metric::AverageRecall => {
                // recall is (T, K, A, M); slice T → (K, A, M); pick A
                // and M → (K,). One value per K.
                let r_t = accum.recall.index_axis(Axis(0), t);
                let r_ta = r_t.index_axis(Axis(1), area_idx);
                let plane = r_ta.index_axis(Axis(1), m_idx);
                for k in 0..n_k {
                    if k_mask.is_some_and(|m| !m[k]) {
                        continue;
                    }
                    push_if(&mut filtered, plane[k]);
                }
            }
        }
    }
    // Nothing survived: report the -1 "undefined" sentinel rather than
    // dividing by zero.
    if filtered.is_empty() {
        -1.0
    } else {
        pairwise_sum(&filtered) / filtered.len() as f64
    }
}
772
/// Sum `values` in the same order as numpy's pairwise summation for
/// contiguous `f64` data, so the rounding matches `np.add.reduce`. The
/// reference is `pairwise_sum_DOUBLE` in
/// `numpy/core/src/umath/loops_utils.h.src`.
///
/// - **Fewer than 8 values:** a left-to-right running sum starting at
///   `+0.0`.
/// - **8 to 128 values (`PW_BLOCKSIZE`):** eight interleaved lane
///   accumulators, seeded with the first eight values. They are combined
///   as `((l0+l1)+(l2+l3)) + ((l4+l5)+(l6+l7))`, and then the
///   `len % 8` leftover values are added one at a time.
/// - **More than 128 values:** split at `len / 2` rounded down to a
///   multiple of 8, then recurse on both halves.
///
/// Reproducing this order is a quirk-**C8**-style alignment. The public
/// summary stats ride on `np.mean(s[s > -1])`, and any other summation
/// order can drift by about one ULP.
pub(crate) fn pairwise_sum(values: &[f64]) -> f64 {
    const PW_BLOCKSIZE: usize = 128;
    const LANES: usize = 8;

    let len = values.len();
    if len > PW_BLOCKSIZE {
        let half = len / 2;
        let (head, tail) = values.split_at(half - half % LANES);
        return pairwise_sum(head) + pairwise_sum(tail);
    }
    if len < LANES {
        return values.iter().fold(0.0_f64, |acc, &v| acc + v);
    }

    // 8 <= len <= 128: the first full block seeds the lanes, each further
    // block is added lane-wise, and the ragged tail is added serially.
    let mut blocks = values.chunks_exact(LANES);
    let leftover = blocks.remainder();
    let mut lanes = [0.0_f64; LANES];
    if let Some(seed) = blocks.next() {
        lanes.copy_from_slice(seed);
    }
    for block in blocks {
        for (lane, &v) in lanes.iter_mut().zip(block) {
            *lane += v;
        }
    }
    let tree = ((lanes[0] + lanes[1]) + (lanes[2] + lanes[3]))
        + ((lanes[4] + lanes[5]) + (lanes[6] + lanes[7]));
    leftover.iter().fold(tree, |acc, &v| acc + v)
}
830
831#[cfg(test)]
832mod tests {
833 use super::*;
834 use crate::accumulate::{accumulate, AccumulateParams, PerImageEval};
835 use crate::parity::{iou_thresholds, recall_thresholds, ParityMode};
836 use ndarray::{Array2, Array4, Array5};
837
838 fn perfect_match_eval(t: usize) -> PerImageEval {
839 PerImageEval {
840 dt_scores: vec![0.9],
841 dt_matched: Array2::from_elem((t, 1), true),
842 dt_ignore: Array2::from_elem((t, 1), false),
843 gt_ignore: vec![false],
844 }
845 }
846
847 #[test]
848 fn perfect_match_summarizes_to_ones() {
849 // Single image, single category, all-area only — the simplest
850 // valid run that exercises every line of the 12-stat table.
851 let iou = iou_thresholds();
852 let rec = recall_thresholds();
853 let max_dets = [1usize, 10, 100];
854 let cell = perfect_match_eval(iou.len());
855
856 // K=1, A=4 (all/small/medium/large), I=1; we populate only the
857 // `all` cell. small/medium/large stay None → -1 sentinel.
858 let mut grid: Vec<Option<Box<PerImageEval>>> = vec![None; 4];
859 grid[0] = Some(Box::new(cell));
860
861 let p = AccumulateParams {
862 iou_thresholds: iou,
863 recall_thresholds: rec,
864 max_dets: &max_dets,
865 n_categories: 1,
866 n_area_ranges: 4,
867 n_images: 1,
868 };
869 let accum = accumulate(&grid, p, ParityMode::Strict).unwrap();
870 let summary = summarize_detection(&accum, iou, &max_dets).unwrap();
871
872 let stats = summary.stats();
873 assert_eq!(stats.len(), 12);
874 // AP[all], AP50, AP75, AR_1, AR_10, AR_100 should all be ~1.0.
875 for &i in &[0usize, 1, 2, 6, 7, 8] {
876 let v = stats[i];
877 assert!((v - 1.0).abs() < 1e-9, "stat[{i}] = {v}");
878 }
879 // small / medium / large carry -1 (no data).
880 for &i in &[3usize, 4, 5, 9, 10, 11] {
881 assert_eq!(stats[i], -1.0, "stat[{i}] should be -1 sentinel");
882 }
883 }
884
885 #[test]
886 fn empty_grid_yields_all_neg_one_stats() {
887 let iou = iou_thresholds();
888 let rec = recall_thresholds();
889 let max_dets = [1usize, 10, 100];
890 let p = AccumulateParams {
891 iou_thresholds: iou,
892 recall_thresholds: rec,
893 max_dets: &max_dets,
894 n_categories: 1,
895 n_area_ranges: 4,
896 n_images: 0,
897 };
898 let accum = accumulate(&[], p, ParityMode::Strict).unwrap();
899 let summary = summarize_detection(&accum, iou, &max_dets).unwrap();
900 assert!(summary.stats().iter().all(|&v| v == -1.0));
901 }
902
903 #[test]
904 fn missing_max_det_value_is_typed_error() {
905 // AR_1 line requires max_dets to contain 1; without it,
906 // summarization fails with InvalidConfig.
907 let iou = iou_thresholds();
908 let max_dets = [10usize, 100];
909 let accum = Accumulated {
910 precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 2), -1.0),
911 recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 2), -1.0),
912 scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 2), -1.0),
913 };
914 let err = summarize_detection(&accum, iou, &max_dets).unwrap_err();
915 assert!(matches!(err, EvalError::InvalidConfig { .. }));
916 }
917
918 #[test]
919 fn iou_threshold_dimension_mismatch_is_typed_error() {
920 let max_dets = [100usize];
921 let accum = Accumulated {
922 precision: Array5::<f64>::from_elem((10, 101, 1, 4, 1), -1.0),
923 recall: Array4::<f64>::from_elem((10, 1, 4, 1), -1.0),
924 scores: Array5::<f64>::from_elem((10, 101, 1, 4, 1), -1.0),
925 };
926 // pass only 5 thresholds — accum was built with 10.
927 let err = summarize_detection(&accum, &[0.5, 0.6, 0.7, 0.8, 0.9], &max_dets).unwrap_err();
928 assert!(matches!(err, EvalError::DimensionMismatch { .. }));
929 }
930
931 #[test]
932 fn summarize_with_custom_plan_evaluates_only_requested_lines() {
933 // Demonstrates the extension point: a 2-entry plan asking for
934 // AP@.50 across all areas and AR@.75 (not in the canonical 12)
935 // — both at the largest cap. Order is preserved.
936 let iou = iou_thresholds();
937 let max_dets = [100usize];
938 let accum = Accumulated {
939 precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 0.5),
940 recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 1), 0.7),
941 scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 1.0),
942 };
943 let plan = [
944 StatRequest::new(
945 Metric::AveragePrecision,
946 Some(0.5),
947 AreaRng::ALL,
948 MaxDetSelector::Largest,
949 ),
950 StatRequest::new(
951 Metric::AverageRecall,
952 Some(0.75),
953 AreaRng::ALL,
954 MaxDetSelector::Largest,
955 ),
956 ];
957 let summary = summarize_with(&accum, &plan, iou, &max_dets).unwrap();
958 assert_eq!(summary.lines.len(), 2);
959 assert!((summary.lines[0].value - 0.5).abs() < 1e-12);
960 assert_eq!(summary.lines[0].iou_threshold, Some(0.5));
961 assert!((summary.lines[1].value - 0.7).abs() < 1e-12);
962 assert_eq!(summary.lines[1].metric, Metric::AverageRecall);
963 }
964
965 #[test]
966 fn summarize_detection_matches_canonical_plan_via_summarize_with() {
967 // The thin-wrapper invariant: results are bit-equal whether the
968 // caller invokes summarize_detection or summarize_with with the
969 // canonical plan.
970 let iou = iou_thresholds();
971 let max_dets = [1usize, 10, 100];
972 let accum = Accumulated {
973 precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 0.5),
974 recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 3), 0.7),
975 scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 1.0),
976 };
977 let direct = summarize_detection(&accum, iou, &max_dets).unwrap();
978 let via_plan = summarize_with(
979 &accum,
980 &StatRequest::coco_detection_default(),
981 iou,
982 &max_dets,
983 )
984 .unwrap();
985 assert_eq!(direct.stats(), via_plan.stats());
986 }
987
988 #[test]
989 fn custom_area_bucket_with_owned_label_renders_in_pretty_lines() {
990 // 5-bucket A-axis (e.g. an orchestrator that adds a "tiny"
991 // bucket below "small"). The plan addresses index 4 by name and
992 // the label flows through to pretty_lines.
993 let iou = iou_thresholds();
994 let max_dets = [100usize];
995 let accum = Accumulated {
996 precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 5, 1), 1.0),
997 recall: Array4::<f64>::from_elem((iou.len(), 1, 5, 1), 1.0),
998 scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 5, 1), 1.0),
999 };
1000 let plan = [StatRequest::new(
1001 Metric::AveragePrecision,
1002 None,
1003 AreaRng::new(4, "tiny"),
1004 MaxDetSelector::Largest,
1005 )];
1006 let summary = summarize_with(&accum, &plan, iou, &max_dets).unwrap();
1007 let lines = summary.pretty_lines();
1008 assert_eq!(lines.len(), 1);
1009 assert!(lines[0].contains("tiny"), "unexpected line: {}", lines[0]);
1010 }
1011
1012 #[test]
1013 fn out_of_range_area_index_is_typed_error() {
1014 // Plan addresses A-axis index 4 against a 4-bucket Accumulated.
1015 let iou = iou_thresholds();
1016 let max_dets = [100usize];
1017 let accum = Accumulated {
1018 precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 1.0),
1019 recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 1), 1.0),
1020 scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 1.0),
1021 };
1022 let plan = [StatRequest::new(
1023 Metric::AveragePrecision,
1024 None,
1025 AreaRng::new(4, "tiny"),
1026 MaxDetSelector::Largest,
1027 )];
1028 let err = summarize_with(&accum, &plan, iou, &max_dets).unwrap_err();
1029 assert!(matches!(err, EvalError::InvalidConfig { .. }));
1030 }
1031
1032 #[test]
1033 fn pretty_lines_match_pycocotools_shape() {
1034 let iou = iou_thresholds();
1035 let max_dets = [1usize, 10, 100];
1036 let accum = Accumulated {
1037 precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 1.0),
1038 recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 3), 1.0),
1039 scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 1.0),
1040 };
1041 let summary = summarize_detection(&accum, iou, &max_dets).unwrap();
1042 let lines = summary.pretty_lines();
1043 assert_eq!(lines.len(), 12);
1044 // Spot-check the first AP line and the first AR line for the
1045 // pycocotools-shaped layout.
1046 assert!(lines[0].contains("Average Precision"));
1047 assert!(lines[0].contains("(AP)"));
1048 assert!(lines[0].contains("0.50:0.95"));
1049 assert!(lines[0].contains("maxDets=100"));
1050 assert!(lines[6].contains("Average Recall"));
1051 assert!(lines[6].contains("maxDets= 1"));
1052 }
1053
1054 #[test]
1055 fn pairwise_sum_matches_numpy_add_reduce_bitwise() {
1056 // 1010 alternating elements is large enough to drive both the
1057 // 8-lane unrolled block and the recursive split (n > 128). The
1058 // expected hex below is `np.add.reduce(v).hex()` for the same
1059 // sequence; naive forward summation lands one ULP higher
1060 // (`0x1.f900000002309p+8`).
1061 let v: Vec<f64> = (0..1010)
1062 .map(|i| if i % 2 == 0 { 1.0 } else { 1e-12 })
1063 .collect();
1064 let got = pairwise_sum(&v);
1065 let expected = f64::from_bits(0x407f_9000_0000_22b4);
1066 assert_eq!(
1067 got.to_bits(),
1068 expected.to_bits(),
1069 "pairwise_sum drifts from numpy: got {got:e}, expected {expected:e}",
1070 );
1071 }
1072
1073 #[test]
1074 fn coco_keypoints_default_plan_pins_canonical_order() {
1075 // ADR-0012 / D5: pycocotools' kp summary is exactly these 10
1076 // lines, in this order. Pin metric, threshold, A-axis index,
1077 // and selector so a refactor cannot silently re-order, drop a
1078 // row, or re-introduce the small bucket.
1079 let plan = StatRequest::coco_keypoints_default();
1080 assert_eq!(plan.len(), 10);
1081
1082 // Each entry: (metric, iou_threshold, area_index, selector).
1083 let expected: [(Metric, Option<f64>, usize, MaxDetSelector); 10] = [
1084 (Metric::AveragePrecision, None, 0, MaxDetSelector::Largest), // AP
1085 (
1086 Metric::AveragePrecision,
1087 Some(0.5),
1088 0,
1089 MaxDetSelector::Largest,
1090 ), // AP50
1091 (
1092 Metric::AveragePrecision,
1093 Some(0.75),
1094 0,
1095 MaxDetSelector::Largest,
1096 ), // AP75
1097 (Metric::AveragePrecision, None, 1, MaxDetSelector::Largest), // AP_M
1098 (Metric::AveragePrecision, None, 2, MaxDetSelector::Largest), // AP_L
1099 (Metric::AverageRecall, None, 0, MaxDetSelector::Largest), // AR
1100 (Metric::AverageRecall, Some(0.5), 0, MaxDetSelector::Largest), // AR50
1101 (
1102 Metric::AverageRecall,
1103 Some(0.75),
1104 0,
1105 MaxDetSelector::Largest,
1106 ), // AR75
1107 (Metric::AverageRecall, None, 1, MaxDetSelector::Largest), // AR_M
1108 (Metric::AverageRecall, None, 2, MaxDetSelector::Largest), // AR_L
1109 ];
1110
1111 for (i, (metric, iou, idx, sel)) in expected.into_iter().enumerate() {
1112 assert_eq!(plan[i].metric, metric, "row {i} metric");
1113 assert_eq!(plan[i].iou_threshold, iou, "row {i} iou_threshold");
1114 assert_eq!(plan[i].area.index, idx, "row {i} area index");
1115 assert_eq!(plan[i].max_dets, sel, "row {i} selector");
1116 }
1117
1118 // No row addresses A-axis index 3 (would land off the end of a
1119 // 3-bucket kp accumulator) and no row addresses index 1 of the
1120 // detection-grid (which is "small" — D5 forbids).
1121 assert!(plan.iter().all(|r| r.area.index <= 2));
1122 }
1123
1124 #[test]
1125 fn pairwise_sum_handles_short_inputs_with_naive_fallback() {
1126 // n < 8 uses the simple loop; verify a hand-checked tiny case.
1127 let v = [1.0_f64, 2.0, 3.0, 4.0];
1128 assert_eq!(pairwise_sum(&v), 10.0);
1129 assert_eq!(pairwise_sum(&[]), 0.0);
1130 assert_eq!(pairwise_sum(&[42.0]), 42.0);
1131 }
1132
1133 // -- ADR-0026: lvis_default plan and CategoryFilter dispatch --------------
1134
1135 /// Synthesize a `(T, R, K, A, M)` precision tensor + matching
1136 /// recall and counts for unit-testing the K-axis filter. The
1137 /// shape parameters mirror lvis-api's typical run: T=10, R=101,
1138 /// A=4, M=1.
1139 fn fake_accumulated(n_k: usize, precision_per_k: &[f64], recall_per_k: &[f64]) -> Accumulated {
1140 const N_T: usize = 10;
1141 const N_R: usize = 101;
1142 const N_A: usize = 4;
1143 const N_M: usize = 1;
1144 assert_eq!(precision_per_k.len(), n_k);
1145 assert_eq!(recall_per_k.len(), n_k);
1146 let mut precision = Array5::<f64>::from_elem((N_T, N_R, n_k, N_A, N_M), 0.0);
1147 let mut recall = Array4::<f64>::from_elem((N_T, n_k, N_A, N_M), 0.0);
1148 for k in 0..n_k {
1149 // Fill every cell on the K-axis with the same constant so
1150 // mean(s[s>-1]) trivially equals that constant; sentinels
1151 // (`-1`) on whole-K rows fall through to the AF6 path.
1152 for t in 0..N_T {
1153 for r in 0..N_R {
1154 for a in 0..N_A {
1155 for m in 0..N_M {
1156 precision[(t, r, k, a, m)] = precision_per_k[k];
1157 }
1158 }
1159 }
1160 for a in 0..N_A {
1161 for m in 0..N_M {
1162 recall[(t, k, a, m)] = recall_per_k[k];
1163 }
1164 }
1165 }
1166 }
1167 Accumulated {
1168 precision,
1169 recall,
1170 scores: Array5::<f64>::from_elem((N_T, N_R, n_k, N_A, N_M), 0.0),
1171 }
1172 }
1173
1174 #[test]
1175 fn lvis_default_has_13_entries_in_canonical_order() {
1176 let plan = StatRequest::lvis_default();
1177 assert_eq!(plan.len(), 13, "AF1: 9 AP + 4 AR");
1178 // First 6 are AP across the COCO area buckets (no freq filter).
1179 for (i, req) in plan.iter().take(6).enumerate() {
1180 assert_eq!(req.metric, Metric::AveragePrecision, "row {i}");
1181 assert_eq!(req.category_filter, CategoryFilter::All, "row {i}");
1182 }
1183 // Rows 6/7/8: APr/APc/APf — ALL area, frequency filter set.
1184 for (i, expected) in [Frequency::Rare, Frequency::Common, Frequency::Frequent]
1185 .iter()
1186 .enumerate()
1187 {
1188 let req = &plan[6 + i];
1189 assert_eq!(req.metric, Metric::AveragePrecision);
1190 assert_eq!(req.area.index, AreaRng::ALL.index);
1191 assert_eq!(
1192 req.category_filter,
1193 CategoryFilter::Frequency(*expected),
1194 "row {}: AP{}",
1195 6 + i,
1196 expected_letter(*expected),
1197 );
1198 }
1199 // Rows 9..13: AR@300 across all four area buckets, no freq filter.
1200 for (i, area_idx) in [
1201 AreaRng::ALL.index,
1202 AreaRng::SMALL.index,
1203 AreaRng::MEDIUM.index,
1204 AreaRng::LARGE.index,
1205 ]
1206 .iter()
1207 .enumerate()
1208 {
1209 let req = &plan[9 + i];
1210 assert_eq!(req.metric, Metric::AverageRecall);
1211 assert_eq!(req.area.index, *area_idx);
1212 assert_eq!(req.category_filter, CategoryFilter::All);
1213 assert_eq!(req.max_dets, MaxDetSelector::Largest);
1214 }
1215 }
1216
1217 fn expected_letter(f: Frequency) -> char {
1218 match f {
1219 Frequency::Rare => 'r',
1220 Frequency::Common => 'c',
1221 Frequency::Frequent => 'f',
1222 }
1223 }
1224
1225 #[test]
1226 fn summarize_with_rejects_frequency_filter_on_coco_path() {
1227 // The plain `summarize_with` entry point has no K-axis context;
1228 // a Frequency-filtered plan must surface an InvalidConfig that
1229 // points the caller at summarize_with_lvis.
1230 let accum = fake_accumulated(3, &[1.0, 1.0, 1.0], &[1.0, 1.0, 1.0]);
1231 let plan = StatRequest::lvis_default();
1232 let err = summarize_with(&accum, &plan, iou_thresholds(), &[300]).unwrap_err();
1233 match err {
1234 EvalError::InvalidConfig { detail } => {
1235 assert!(detail.contains("summarize_with_lvis"), "msg: {detail}");
1236 }
1237 other => panic!("expected InvalidConfig, got {other:?}"),
1238 }
1239 }
1240
1241 #[test]
1242 fn summarize_with_lvis_routes_frequency_buckets_correctly() {
1243 // 3 categories: cat 1 = Frequent (precision 0.6), cat 2 =
1244 // Common (0.4), cat 3 = Rare (0.2). AP_r/c/f must equal each
1245 // bucket's per-K precision; AP overall = (0.6+0.4+0.2)/3.
1246 let accum = fake_accumulated(3, &[0.6, 0.4, 0.2], &[0.6, 0.4, 0.2]);
1247 let cat_ids = [CategoryId(1), CategoryId(2), CategoryId(3)];
1248 let mut freq_map = HashMap::new();
1249 freq_map.insert(CategoryId(1), Frequency::Frequent);
1250 freq_map.insert(CategoryId(2), Frequency::Common);
1251 freq_map.insert(CategoryId(3), Frequency::Rare);
1252
1253 let plan = StatRequest::lvis_default();
1254 let summary = summarize_with_lvis(
1255 &accum,
1256 &plan,
1257 iou_thresholds(),
1258 &[300],
1259 &cat_ids,
1260 Some(&freq_map),
1261 )
1262 .unwrap();
1263
1264 // Index 6/7/8 → APr/APc/APf.
1265 let apr = summary.lines[6].value;
1266 let apc = summary.lines[7].value;
1267 let apf = summary.lines[8].value;
1268 assert!((apr - 0.2).abs() < 1e-12, "APr expected 0.2, got {apr}");
1269 assert!((apc - 0.4).abs() < 1e-12, "APc expected 0.4, got {apc}");
1270 assert!((apf - 0.6).abs() < 1e-12, "APf expected 0.6, got {apf}");
1271 // Index 0 → AP overall.
1272 let ap = summary.lines[0].value;
1273 let expected = (0.6 + 0.4 + 0.2) / 3.0;
1274 assert!((ap - expected).abs() < 1e-12, "AP overall: {ap}");
1275 }
1276
1277 #[test]
1278 fn ab3_filters_minus_one_sentinels_before_mean() {
1279 // Mix two categories: one with positive precision, one with
1280 // the `-1` sentinel (a category that produced no eval_imgs
1281 // entries). The AP overall must equal the positive value
1282 // alone — the sentinel falls out of the mean.
1283 let accum = fake_accumulated(2, &[-1.0, 0.5], &[-1.0, 0.5]);
1284 let cat_ids = [CategoryId(1), CategoryId(2)];
1285 let mut freq_map = HashMap::new();
1286 freq_map.insert(CategoryId(1), Frequency::Rare);
1287 freq_map.insert(CategoryId(2), Frequency::Frequent);
1288 let summary = summarize_with_lvis(
1289 &accum,
1290 &StatRequest::lvis_default(),
1291 iou_thresholds(),
1292 &[300],
1293 &cat_ids,
1294 Some(&freq_map),
1295 )
1296 .unwrap();
1297 // AP overall: only cat 2 contributes (cat 1 is `-1`).
1298 assert!((summary.lines[0].value - 0.5).abs() < 1e-12);
1299 // APr: cat 1 is the only Rare category — and it's `-1`, so
1300 // the bucket is empty after filtering. AF6: returns `-1.0`.
1301 assert_eq!(summary.lines[6].value, -1.0, "APr empty bucket → -1");
1302 // APf: cat 2 is the only Frequent category, value 0.5.
1303 assert!((summary.lines[8].value - 0.5).abs() < 1e-12);
1304 }
1305
1306 #[test]
1307 fn af6_empty_frequency_bucket_returns_minus_one_not_zero_or_nan() {
1308 // Three Frequent categories, no Rare or Common. APr and APc
1309 // must both surface the `-1` sentinel, distinct from the
1310 // panoptic ADR-0025 W6 corrected behavior (returns `0.0`)
1311 // and from numpy's `nan` on an unfiltered empty mean.
1312 let accum = fake_accumulated(3, &[0.7, 0.8, 0.9], &[0.7, 0.8, 0.9]);
1313 let cat_ids = [CategoryId(1), CategoryId(2), CategoryId(3)];
1314 let mut freq_map = HashMap::new();
1315 freq_map.insert(CategoryId(1), Frequency::Frequent);
1316 freq_map.insert(CategoryId(2), Frequency::Frequent);
1317 freq_map.insert(CategoryId(3), Frequency::Frequent);
1318 let summary = summarize_with_lvis(
1319 &accum,
1320 &StatRequest::lvis_default(),
1321 iou_thresholds(),
1322 &[300],
1323 &cat_ids,
1324 Some(&freq_map),
1325 )
1326 .unwrap();
1327 // AP overall: mean of 0.7/0.8/0.9 = 0.8.
1328 assert!((summary.lines[0].value - 0.8).abs() < 1e-12);
1329 // APr / APc: empty K filter → -1.0 (not 0.0, not nan).
1330 assert_eq!(summary.lines[6].value, -1.0, "APr");
1331 assert_eq!(summary.lines[7].value, -1.0, "APc");
1332 assert!(!summary.lines[6].value.is_nan(), "AF6: never nan");
1333 assert!(summary.lines[6].value != 0.0, "AF6: never 0.0");
1334 }
1335
1336 #[test]
1337 fn ab6_no_frequency_map_yields_minus_one_for_frequency_filtered_lines() {
1338 // The dataset doesn't carry frequency tags (the COCO loader
1339 // path on an LVIS-shaped JSON, or a programmatically built
1340 // dataset). Frequency-filtered entries gracefully surface
1341 // the `-1` sentinel — quirk **AB6** corrected (no panic).
1342 let accum = fake_accumulated(2, &[0.5, 0.5], &[0.5, 0.5]);
1343 let cat_ids = [CategoryId(1), CategoryId(2)];
1344 let summary = summarize_with_lvis(
1345 &accum,
1346 &StatRequest::lvis_default(),
1347 iou_thresholds(),
1348 &[300],
1349 &cat_ids,
1350 None,
1351 )
1352 .unwrap();
1353 assert!((summary.lines[0].value - 0.5).abs() < 1e-12, "AP overall");
1354 assert_eq!(summary.lines[6].value, -1.0, "APr without freq map");
1355 assert_eq!(summary.lines[7].value, -1.0, "APc without freq map");
1356 assert_eq!(summary.lines[8].value, -1.0, "APf without freq map");
1357 }
1358
1359 #[test]
1360 fn category_filter_by_ids_subsets_correctly() {
1361 // ByIds filter: include only cat 2 → AP equals cat 2's
1362 // per-K precision regardless of the other categories.
1363 let accum = fake_accumulated(3, &[0.1, 0.5, 0.9], &[0.1, 0.5, 0.9]);
1364 let cat_ids = [CategoryId(1), CategoryId(2), CategoryId(3)];
1365 let plan = vec![StatRequest::new_with_filter(
1366 Metric::AveragePrecision,
1367 None,
1368 AreaRng::ALL,
1369 MaxDetSelector::Largest,
1370 CategoryFilter::ByIds(vec![CategoryId(2)]),
1371 )];
1372 let summary =
1373 summarize_with_lvis(&accum, &plan, iou_thresholds(), &[300], &cat_ids, None).unwrap();
1374 assert!((summary.lines[0].value - 0.5).abs() < 1e-12);
1375 }
1376
1377 #[test]
1378 fn category_axis_size_mismatch_is_typed_error() {
1379 let accum = fake_accumulated(2, &[0.5, 0.5], &[0.5, 0.5]);
1380 let cat_ids = [CategoryId(1), CategoryId(2), CategoryId(3)]; // wrong length
1381 let err = summarize_with_lvis(
1382 &accum,
1383 &StatRequest::lvis_default(),
1384 iou_thresholds(),
1385 &[300],
1386 &cat_ids,
1387 None,
1388 )
1389 .unwrap_err();
1390 assert!(matches!(err, EvalError::InvalidConfig { .. }));
1391 }
1392}