Skip to main content

vernier_core/
summarize.rs

1//! Twelve-stat detection summary atop [`crate::Accumulated`].
2//!
3//! Mirrors `pycocotools.cocoeval.COCOeval.summarize` (cocoeval.py
4//! lines 422-475), but as a pure structured value — no stdout side
5//! effects (quirks **L5/L6/L7**, dispositioned `corrected`).
6//!
7//! ## Quirk dispositions
8//!
9//! - **C5** (`strict`): cells absent from the dataset carry `-1`;
10//!   summarization filters them out via `s > -1` before averaging.
11//! - **L5** (`corrected`): the print/log side-effect from upstream
12//!   `_summarize` is gone. Use [`Summary::pretty_lines`] for the
13//!   pycocotools-shaped human-readable rendering.
14//! - **L6** (`corrected`): empty-eval `mean(empty)` no longer raises a
15//!   numpy RuntimeWarning — the absent case explicitly returns `-1`.
16//! - **L7** (`corrected`): the result is a value (`Summary`), not a
17//!   property side-effect on the evaluator.
18
19use std::borrow::Cow;
20use std::collections::HashMap;
21use std::ops::Range;
22
23use ndarray::Axis;
24
25use crate::accumulate::Accumulated;
26use crate::dataset::{CategoryId, Frequency};
27use crate::error::EvalError;
28
/// Absolute tolerance applied when a user-supplied IoU threshold is
/// matched against an entry of the `iou_thresholds` ladder. It absorbs
/// the ulp-level error left by the `linspace(0.5, 0.95, 10)` build
/// (quirk **L1**).
pub(crate) const IOU_LOOKUP_TOL: f64 = 1.0e-12;
33
/// A single bucket on the A-axis of an [`Accumulated`]: its position
/// on the axis together with the label used when rendering.
///
/// The canonical pycocotools detection layout is available as
/// [`AreaRng::ALL`] / [`SMALL`](Self::SMALL) / [`MEDIUM`](Self::MEDIUM)
/// / [`LARGE`](Self::LARGE), in the cocoeval `Params.areaRngLbl`
/// order. Custom layouts (e.g., robotics-style finer buckets) are
/// built with [`AreaRng::new`] when the label is owned, or with
/// [`AreaRng::from_static`] when it is a `&'static str`.
///
/// The *bounds* that map an annotation's area to a bucket index live
/// upstream, on the orchestrator that builds
/// [`crate::accumulate::PerImageEval`] cells; the summarizer only
/// consumes the resulting A-axis index.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AreaRng {
    /// Position on the A-axis of [`Accumulated::precision`] /
    /// [`Accumulated::recall`]. Nothing is checked at construction;
    /// the index is validated against the real A-axis length at
    /// summarize time, where an out-of-range value produces
    /// [`EvalError::InvalidConfig`].
    pub index: usize,
    /// Text shown for this bucket by [`Summary::pretty_lines`].
    pub label: Cow<'static, str>,
}

impl AreaRng {
    /// COCO `all` bucket — pycocotools' `[0, 1e10]`, A-axis index 0.
    pub const ALL: Self = Self::from_static(0, "all");
    /// COCO `small` bucket — pycocotools' `[0, 32^2]`, A-axis index 1.
    pub const SMALL: Self = Self::from_static(1, "small");
    /// COCO `medium` bucket — pycocotools' `[32^2, 96^2]`, A-axis index 2.
    pub const MEDIUM: Self = Self::from_static(2, "medium");
    /// COCO `large` bucket — pycocotools' `[96^2, 1e10]`, A-axis index 3.
    pub const LARGE: Self = Self::from_static(3, "large");

    /// Build a bucket from anything convertible into a
    /// `Cow<'static, str>` (an owned `String` or a `&'static str`).
    pub fn new(index: usize, label: impl Into<Cow<'static, str>>) -> Self {
        let label = label.into();
        Self { index, label }
    }

    /// `const`-friendly constructor for labels known at compile time;
    /// the label is stored borrowed, so nothing is allocated.
    pub const fn from_static(index: usize, label: &'static str) -> Self {
        Self {
            label: Cow::Borrowed(label),
            index,
        }
    }
}
84
/// Which statistic a [`StatLine`] reports: AP or AR.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Metric {
    /// Average Precision — computed from slices of
    /// `Accumulated::precision`.
    AveragePrecision,
    /// Average Recall — computed from slices of `Accumulated::recall`.
    /// Quirk **C4**: AR is the terminal cumulative recall, not an
    /// integral of the precision/recall curve.
    AverageRecall,
}
95
96/// Single line of the COCO 12-stat summary table.
97#[derive(Debug, Clone)]
98pub struct StatLine {
99    /// AP or AR.
100    pub metric: Metric,
101    /// `None` means averaged across the whole IoU ladder; `Some(t)`
102    /// pins a specific threshold (e.g., 0.5 for AP@.50).
103    pub iou_threshold: Option<f64>,
104    /// Area-range bucket.
105    pub area: AreaRng,
106    /// Per-image maxDet cap.
107    pub max_dets: usize,
108    /// Mean over the matching slice, ignoring `-1` sentinels. `-1.0`
109    /// when the slice has no non-sentinel entries (quirks **C5/L6**).
110    pub value: f64,
111}
112
113/// Result of evaluating a summary plan over an [`Accumulated`].
114///
115/// `lines.len()` matches the plan length; for the canonical pycocotools
116/// detection summary built by [`summarize_detection`], that's 12 lines
117/// in the order `[AP, AP50, AP75, AP_S, AP_M, AP_L, AR_1, AR_10,
118/// AR_100, AR_S, AR_M, AR_L]`. For custom plans evaluated via
119/// [`summarize_with`], `lines` mirrors the request order.
120#[derive(Debug, Clone)]
121pub struct Summary {
122    /// One entry per request in the evaluated plan, paired with slicing
123    /// metadata.
124    pub lines: Vec<StatLine>,
125}
126
127impl Summary {
128    /// Numeric values in plan order. Equivalent to
129    /// `lines.iter().map(|l| l.value).collect()`.
130    pub fn stats(&self) -> Vec<f64> {
131        self.lines.iter().map(|l| l.value).collect()
132    }
133    /// Render the canonical pycocotools text table (12 lines, each in
134    /// the upstream `Average Precision (AP) @[ IoU=... | area=... |
135    /// maxDets=... ] = 0.xxx` shape). Returned as a `Vec<String>`; the
136    /// caller decides whether to print, log, or test against it.
137    pub fn pretty_lines(&self) -> Vec<String> {
138        self.lines
139            .iter()
140            .map(|line| {
141                let (title, kind) = match line.metric {
142                    Metric::AveragePrecision => ("Average Precision", "(AP)"),
143                    Metric::AverageRecall => ("Average Recall", "(AR)"),
144                };
145                let iou = match line.iou_threshold {
146                    Some(t) => format!("{t:0.2}"),
147                    None => "0.50:0.95".to_string(),
148                };
149                format!(
150                    " {title:<18} {kind} @[ IoU={iou:<9} | area={:>6} | maxDets={:>3} ] = {:0.3}",
151                    line.area.label, line.max_dets, line.value
152                )
153            })
154            .collect()
155    }
156}
157
/// How a [`StatRequest`] chooses its entry on the M-axis of an
/// [`Accumulated`].
///
/// Pycocotools hard-codes `maxDets[0|1|2]` for `AR_{1,10,100}` and
/// `maxDets[-1]` for everything else. This enum lets a plan state the
/// intent instead — "the top rung of the ladder" or "the entry whose
/// value equals N" — without binding to fixed positional indices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaxDetSelector {
    /// The top rung of the supplied `max_dets` ladder — the final
    /// entry of an ascending ladder, i.e. its largest cap. Every
    /// cocoeval AP line and `AR_S` / `AR_M` / `AR_L` use this.
    Largest,
    /// The M-axis entry whose cap equals the given value. A value that
    /// is absent from `max_dets` is reported as
    /// [`EvalError::InvalidConfig`].
    Value(usize),
}
174
175/// K-axis subset selector (ADR-0026 D2). Filters which categories
176/// contribute to a [`StatRequest`]'s mean.
177///
178/// Frequency buckets are *not* a [`crate::breakdown::Breakdown`] axis — they
179/// are a category-subset selector, the K-axis equivalent of an area
180/// bucket. The discriminated form keeps frequency-keyed (LVIS) and
181/// id-keyed (per-supercategory, ablation subsets) intents cleanly
182/// separated; the resolution to a list of K indices happens at
183/// summarize time once the K-axis ordering is known.
184///
185/// Only [`CategoryFilter::All`] is supported by the standard
186/// [`summarize_with`] entry point; the [`Frequency`] and
187/// [`ByIds`](Self::ByIds) variants require the K-axis context (the
188/// list of `CategoryId`s in axis order, plus the per-category
189/// frequency map for [`Frequency`](Self::Frequency)) and route through
190/// [`summarize_with_lvis`].
191#[derive(Debug, Clone, PartialEq, Eq)]
192pub enum CategoryFilter {
193    /// No filter — every category contributes (the COCO default).
194    All,
195    /// Include only categories whose [`Frequency`] tag matches. Quirk
196    /// **AB3**: empty after the `-1`-sentinel drop (handled internally
197    /// by the summarizer) yields
198    /// `-1.0`, not `0.0` or `nan` (lvis-api's `eval.py:441-442`).
199    Frequency(Frequency),
200    /// Explicit subset: include only categories whose id is in the
201    /// list. Sorted ascending for stable membership tests; duplicates
202    /// are ignored.
203    ByIds(Vec<CategoryId>),
204    /// Include only categories that belong to the named group of the
205    /// active [`ClassGroupBreakdown`] (ADR-0041 / ADR-0042). The string
206    /// is the group label; resolution happens at summarize time
207    /// against the configured grouping.
208    ///
209    /// Unlike [`Frequency`](Self::Frequency) / [`ByIds`](Self::ByIds), this variant does *not*
210    /// require LVIS context — it routes through the standard
211    /// summarizer with the breakdown reference passed in via the
212    /// non-LVIS context shim.
213    ///
214    /// [`ClassGroupBreakdown`]: crate::breakdown::ClassGroupBreakdown
215    ByGrouping(Cow<'static, str>),
216}
217
218impl CategoryFilter {
219    /// `true` if this filter requires the K-axis context (frequency
220    /// map or id ordering) — i.e., [`Self::Frequency`] or
221    /// [`Self::ByIds`]. [`Self::All`] and [`Self::ByGrouping`] are
222    /// resolvable without LVIS context (the latter via the
223    /// `ClassGroupBreakdown` reference passed alongside).
224    pub fn needs_lvis_context(&self) -> bool {
225        matches!(self, Self::Frequency(_) | Self::ByIds(_))
226    }
227}
228
229/// One line of a summary plan — describes a single mean to compute.
230#[derive(Debug, Clone)]
231pub struct StatRequest {
232    /// AP or AR.
233    pub metric: Metric,
234    /// `None` averages across the IoU ladder; `Some(t)` pins one row.
235    /// Looked up against `iou_thresholds` within an internal absolute
236    /// tolerance (≈1e-9) at
237    /// summarize time; values not on the ladder produce
238    /// [`EvalError::InvalidConfig`].
239    pub iou_threshold: Option<f64>,
240    /// Area-range bucket on the A-axis.
241    pub area: AreaRng,
242    /// How to pick the M-axis entry.
243    pub max_dets: MaxDetSelector,
244    /// K-axis subset (ADR-0026 D2). Defaults to
245    /// [`CategoryFilter::All`] for COCO-shape plans; LVIS plans use
246    /// [`CategoryFilter::Frequency`] for the AP_r/c/f buckets.
247    pub category_filter: CategoryFilter,
248}
249
250impl StatRequest {
251    /// Convenience constructor. `const`-callable so [`coco_detection_default`]
252    /// and downstream user-defined plans can be assembled in `const`
253    /// contexts. Defaults `category_filter` to [`CategoryFilter::All`].
254    ///
255    /// [`coco_detection_default`]: Self::coco_detection_default
256    pub const fn new(
257        metric: Metric,
258        iou_threshold: Option<f64>,
259        area: AreaRng,
260        max_dets: MaxDetSelector,
261    ) -> Self {
262        Self {
263            metric,
264            iou_threshold,
265            area,
266            max_dets,
267            category_filter: CategoryFilter::All,
268        }
269    }
270
271    /// Construct with a non-default [`CategoryFilter`] in one shot.
272    /// `const`-callable for the [`Frequency`](CategoryFilter::Frequency)
273    /// and [`All`](CategoryFilter::All) variants;
274    /// [`ByIds`](CategoryFilter::ByIds) carries a heap-allocated
275    /// `Vec<CategoryId>` and is constructed at runtime.
276    pub const fn new_with_filter(
277        metric: Metric,
278        iou_threshold: Option<f64>,
279        area: AreaRng,
280        max_dets: MaxDetSelector,
281        category_filter: CategoryFilter,
282    ) -> Self {
283        Self {
284            metric,
285            iou_threshold,
286            area,
287            max_dets,
288            category_filter,
289        }
290    }
291
292    /// The canonical 12-entry pycocotools detection plan, in the
293    /// `[AP, AP50, AP75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S,
294    /// AR_M, AR_L]` order. Bit-exact with cocoeval is by construction:
295    /// [`summarize_detection`] is just `summarize_with(.., this, ..)`.
296    pub const fn coco_detection_default() -> [Self; 12] {
297        use MaxDetSelector::{Largest, Value};
298        use Metric::{AveragePrecision, AverageRecall};
299        [
300            Self::new(AveragePrecision, None, AreaRng::ALL, Largest),
301            Self::new(AveragePrecision, Some(0.5), AreaRng::ALL, Largest),
302            Self::new(AveragePrecision, Some(0.75), AreaRng::ALL, Largest),
303            Self::new(AveragePrecision, None, AreaRng::SMALL, Largest),
304            Self::new(AveragePrecision, None, AreaRng::MEDIUM, Largest),
305            Self::new(AveragePrecision, None, AreaRng::LARGE, Largest),
306            Self::new(AverageRecall, None, AreaRng::ALL, Value(1)),
307            Self::new(AverageRecall, None, AreaRng::ALL, Value(10)),
308            Self::new(AverageRecall, None, AreaRng::ALL, Value(100)),
309            Self::new(AverageRecall, None, AreaRng::SMALL, Largest),
310            Self::new(AverageRecall, None, AreaRng::MEDIUM, Largest),
311            Self::new(AverageRecall, None, AreaRng::LARGE, Largest),
312        ]
313    }
314
315    /// The canonical 13-entry LVIS detection plan (ADR-0026 AF1, AF4),
316    /// in the LVIS `print_results` order:
317    ///
318    /// `[AP, AP50, AP75, APs, APm, APl, APr, APc, APf,
319    ///   AR@300, ARs@300, ARm@300, ARl@300]`
320    ///
321    /// Differences from [`Self::coco_detection_default`]:
322    ///
323    /// - **9 AP entries vs 6.** Three additional rows (APr/APc/APf)
324    ///   filter the K-axis by [`Frequency`] tag. `lvis-api` reports
325    ///   them as separate entries, not `Breakdown` axes (ADR-0016
326    ///   `f64`-keyed type doesn't fit categorical tags).
327    /// - **4 AR entries vs 6.** No `AR@1` / `AR@10` / `AR@100` —
328    ///   LVIS reports recall at `max_dets=300` only (AF4). The
329    ///   `Largest` selector resolves to whatever the user passes;
330    ///   pair the plan with `max_dets=[300]` for parity with
331    ///   `LVISEval`.
332    ///
333    /// `Frequency`-filtered entries route through
334    /// [`summarize_with_lvis`]; calling [`summarize_with`] on this
335    /// plan returns [`EvalError::InvalidConfig`] (the plain entry
336    /// point has no K-axis context).
337    pub const fn lvis_default() -> [Self; 13] {
338        use CategoryFilter::{All as AllK, Frequency as FreqK};
339        use MaxDetSelector::Largest;
340        use Metric::{AveragePrecision, AverageRecall};
341        [
342            Self::new_with_filter(AveragePrecision, None, AreaRng::ALL, Largest, AllK),
343            Self::new_with_filter(AveragePrecision, Some(0.5), AreaRng::ALL, Largest, AllK),
344            Self::new_with_filter(AveragePrecision, Some(0.75), AreaRng::ALL, Largest, AllK),
345            Self::new_with_filter(AveragePrecision, None, AreaRng::SMALL, Largest, AllK),
346            Self::new_with_filter(AveragePrecision, None, AreaRng::MEDIUM, Largest, AllK),
347            Self::new_with_filter(AveragePrecision, None, AreaRng::LARGE, Largest, AllK),
348            Self::new_with_filter(
349                AveragePrecision,
350                None,
351                AreaRng::ALL,
352                Largest,
353                FreqK(Frequency::Rare),
354            ),
355            Self::new_with_filter(
356                AveragePrecision,
357                None,
358                AreaRng::ALL,
359                Largest,
360                FreqK(Frequency::Common),
361            ),
362            Self::new_with_filter(
363                AveragePrecision,
364                None,
365                AreaRng::ALL,
366                Largest,
367                FreqK(Frequency::Frequent),
368            ),
369            Self::new_with_filter(AverageRecall, None, AreaRng::ALL, Largest, AllK),
370            Self::new_with_filter(AverageRecall, None, AreaRng::SMALL, Largest, AllK),
371            Self::new_with_filter(AverageRecall, None, AreaRng::MEDIUM, Largest, AllK),
372            Self::new_with_filter(AverageRecall, None, AreaRng::LARGE, Largest, AllK),
373        ]
374    }
375
376    /// The canonical 10-entry pycocotools keypoints plan, in the
377    /// `[AP, AP50, AP75, AP_M, AP_L, AR, AR50, AR75, AR_M, AR_L]`
378    /// order (cocoeval.py:478-499 under `iouType="keypoints"`).
379    ///
380    /// Differs from [`Self::coco_detection_default`] in three ways,
381    /// all per ADR-0012:
382    ///
383    /// - 10 entries, not 12 — the small-area row is dropped on both
384    ///   AP and AR (quirk **D5**).
385    /// - Every entry uses [`MaxDetSelector::Largest`], which resolves
386    ///   to the kp-canonical `(20,)` ladder; there are no `AR_1` /
387    ///   `AR_10` / `AR_100` rows because the kp ladder has only one
388    ///   rung.
389    /// - The `AreaRng` indices `0/1/2` (all/medium/large) are
390    ///   re-indexed for the kp A-axis. Callers must pair this plan
391    ///   with [`crate::AreaRange::keypoints_default`] so the A-axis
392    ///   indices line up; the const [`AreaRng::ALL`] / `MEDIUM` /
393    ///   `LARGE` carry the four-bucket detection-grid indices and
394    ///   would index off the end of a three-bucket accumulator.
395    pub const fn coco_keypoints_default() -> [Self; 10] {
396        use MaxDetSelector::Largest;
397        use Metric::{AveragePrecision, AverageRecall};
398        // D5: re-indexed kp A-axis (0=all, 1=medium, 2=large), no small.
399        // `from_static` is `const`, so each call site materializes a
400        // fresh `AreaRng` without an intermediate `clone()` — mirroring
401        // `coco_detection_default`'s use of the const `AreaRng::ALL`
402        // / `MEDIUM` / `LARGE` constants.
403        const ALL: AreaRng = AreaRng::from_static(0, "all");
404        const MEDIUM: AreaRng = AreaRng::from_static(1, "medium");
405        const LARGE: AreaRng = AreaRng::from_static(2, "large");
406        [
407            Self::new(AveragePrecision, None, ALL, Largest),
408            Self::new(AveragePrecision, Some(0.5), ALL, Largest),
409            Self::new(AveragePrecision, Some(0.75), ALL, Largest),
410            Self::new(AveragePrecision, None, MEDIUM, Largest),
411            Self::new(AveragePrecision, None, LARGE, Largest),
412            Self::new(AverageRecall, None, ALL, Largest),
413            Self::new(AverageRecall, Some(0.5), ALL, Largest),
414            Self::new(AverageRecall, Some(0.75), ALL, Largest),
415            Self::new(AverageRecall, None, MEDIUM, Largest),
416            Self::new(AverageRecall, None, LARGE, Largest),
417        ]
418    }
419}
420
421/// Twelve-stat COCO detection summary, bit-exact with cocoeval.
422///
423/// Thin wrapper over [`summarize_with`] that supplies the canonical
424/// 12-entry plan from [`StatRequest::coco_detection_default`].
425/// Downstream callers who need a different shape (keypoint `[20]`
426/// maxDets, custom AP@.30, …) should call `summarize_with` directly
427/// with their own plan; the canonical plan is available via the
428/// constructor for those who want to extend rather than replace it.
429///
430/// # Errors
431///
432/// Same conditions as [`summarize_with`].
433pub fn summarize_detection(
434    accum: &Accumulated,
435    iou_thresholds: &[f64],
436    max_dets: &[usize],
437) -> Result<Summary, EvalError> {
438    summarize_with(
439        accum,
440        &StatRequest::coco_detection_default(),
441        iou_thresholds,
442        max_dets,
443    )
444}
445
446/// Evaluate an arbitrary summary plan over an [`Accumulated`].
447///
448/// `iou_thresholds` and `max_dets` describe the grid the `Accumulated`
449/// was built against; they are needed to resolve [`StatRequest`]
450/// selectors (IoU value → T-axis index, [`MaxDetSelector`] → M-axis
451/// index) and to populate the `max_dets` field on each emitted
452/// [`StatLine`].
453///
454/// # Errors
455///
456/// Returns [`EvalError::DimensionMismatch`] if `iou_thresholds` or
457/// `max_dets` lengths disagree with `accum`'s `T`/`M` axes. Returns
458/// [`EvalError::InvalidConfig`] if any request names an IoU threshold
459/// not present in `iou_thresholds` (within `1e-12`) or a
460/// [`MaxDetSelector::Value`] absent from `max_dets`.
461pub fn summarize_with(
462    accum: &Accumulated,
463    plan: &[StatRequest],
464    iou_thresholds: &[f64],
465    max_dets: &[usize],
466) -> Result<Summary, EvalError> {
467    summarize_dispatch(accum, plan, iou_thresholds, max_dets, None)
468}
469
470/// LVIS variant of [`summarize_with`] that resolves
471/// [`CategoryFilter::Frequency`] / [`CategoryFilter::ByIds`] against
472/// the K-axis context (ADR-0026 D2).
473///
474/// `category_ids` lists the dataset's categories in K-axis order
475/// (id-ascending — the same ordering the orchestrator at
476/// `evaluate.rs:701-707` uses). `category_frequency` is the
477/// per-category tag from the LVIS JSON (quirk **AB1**); pass `None`
478/// to opt out of frequency-based filtering, in which case any
479/// [`CategoryFilter::Frequency`] entry yields `-1.0` (the AP-undefined
480/// sentinel — quirk **AB6**, also the migration-guide's note for
481/// COCO datasets that don't carry frequency tags).
482///
483/// # Errors
484///
485/// Same error surface as [`summarize_with`], plus
486/// [`EvalError::InvalidConfig`] when `category_ids.len()` does not
487/// match the K-axis size of `accum.precision`.
488pub fn summarize_with_lvis(
489    accum: &Accumulated,
490    plan: &[StatRequest],
491    iou_thresholds: &[f64],
492    max_dets: &[usize],
493    category_ids: &[CategoryId],
494    category_frequency: Option<&HashMap<CategoryId, Frequency>>,
495) -> Result<Summary, EvalError> {
496    let n_k = accum.precision.shape()[2];
497    if category_ids.len() != n_k {
498        return Err(EvalError::InvalidConfig {
499            detail: format!(
500                "category_ids len {} != precision K-axis {n_k}",
501                category_ids.len()
502            ),
503        });
504    }
505    let ctx = LvisCtx {
506        category_ids,
507        category_frequency,
508    };
509    summarize_dispatch(accum, plan, iou_thresholds, max_dets, Some(&ctx))
510}
511
512/// Internal context bundle for the LVIS K-axis resolution. Carrying it
513/// as an `Option` lets [`summarize_with`] and [`summarize_with_lvis`]
514/// share the body without exposing a fifth public parameter on the
515/// COCO path.
516struct LvisCtx<'a> {
517    category_ids: &'a [CategoryId],
518    category_frequency: Option<&'a HashMap<CategoryId, Frequency>>,
519}
520
521fn summarize_dispatch(
522    accum: &Accumulated,
523    plan: &[StatRequest],
524    iou_thresholds: &[f64],
525    max_dets: &[usize],
526    lvis: Option<&LvisCtx<'_>>,
527) -> Result<Summary, EvalError> {
528    let p_shape = accum.precision.shape();
529    let r_shape = accum.recall.shape();
530    let n_t = p_shape[0];
531    let n_m = p_shape[4];
532
533    if n_t != iou_thresholds.len() {
534        return Err(EvalError::DimensionMismatch {
535            detail: format!(
536                "precision T-axis {} != iou_thresholds len {}",
537                n_t,
538                iou_thresholds.len()
539            ),
540        });
541    }
542    if n_m != max_dets.len() {
543        return Err(EvalError::DimensionMismatch {
544            detail: format!(
545                "precision M-axis {} != max_dets len {}",
546                n_m,
547                max_dets.len()
548            ),
549        });
550    }
551    if r_shape[0] != n_t || r_shape[3] != n_m {
552        return Err(EvalError::DimensionMismatch {
553            detail: format!("recall {r_shape:?} disagrees with precision {p_shape:?}"),
554        });
555    }
556
557    // Resolve every selector before computing any means: a typo in any
558    // request fails early without wasting evaluation work, and the
559    // compute pass below stays infallible.
560    let n_a = p_shape[3];
561    let n_k = p_shape[2];
562    let m_max = max_dets.len() - 1;
563    let resolved: Vec<(usize, Range<usize>, Option<Vec<bool>>)> = plan
564        .iter()
565        .map(|req| {
566            if req.area.index >= n_a {
567                return Err(EvalError::InvalidConfig {
568                    detail: format!(
569                        "AreaRng index {} is out of range for A-axis (size {})",
570                        req.area.index, n_a
571                    ),
572                });
573            }
574            let m_idx = match req.max_dets {
575                MaxDetSelector::Largest => m_max,
576                MaxDetSelector::Value(v) => {
577                    max_dets.iter().position(|&d| d == v).ok_or_else(|| {
578                        EvalError::InvalidConfig {
579                            detail: format!("max_dets does not contain {v}"),
580                        }
581                    })?
582                }
583            };
584            let t_range = match req.iou_threshold {
585                None => 0..n_t,
586                Some(target) => {
587                    let t = iou_thresholds
588                        .iter()
589                        .position(|&v| (v - target).abs() < IOU_LOOKUP_TOL)
590                        .ok_or_else(|| EvalError::InvalidConfig {
591                            detail: format!("iou_threshold {target} not in ladder"),
592                        })?;
593                    t..(t + 1)
594                }
595            };
596            let k_mask = resolve_category_filter(&req.category_filter, n_k, lvis)?;
597            Ok((m_idx, t_range, k_mask))
598        })
599        .collect::<Result<Vec<_>, EvalError>>()?;
600
601    let lines = plan
602        .iter()
603        .zip(resolved)
604        .map(|(req, (m_idx, t_range, k_mask))| {
605            let value = mean_slice(
606                accum,
607                req.metric,
608                t_range,
609                req.area.index,
610                m_idx,
611                k_mask.as_deref(),
612            );
613            StatLine {
614                metric: req.metric,
615                iou_threshold: req.iou_threshold,
616                area: req.area.clone(),
617                max_dets: max_dets[m_idx],
618                value,
619            }
620        })
621        .collect();
622
623    Ok(Summary { lines })
624}
625
626/// Resolve a [`CategoryFilter`] to a K-axis bool mask of length `n_k`.
627/// Returns `None` for [`CategoryFilter::All`] (the no-op — equivalent
628/// to a mask of all `true`s), and `Some(mask)` for the filtered
629/// variants.
630///
631/// Errors when a non-`All` filter is encountered without an LVIS
632/// context (the standard [`summarize_with`] entry point — its error
633/// message points users at [`summarize_with_lvis`]).
634fn resolve_category_filter(
635    filter: &CategoryFilter,
636    n_k: usize,
637    lvis: Option<&LvisCtx<'_>>,
638) -> Result<Option<Vec<bool>>, EvalError> {
639    match filter {
640        CategoryFilter::All => Ok(None),
641        CategoryFilter::Frequency(target) => {
642            let Some(ctx) = lvis else {
643                return Err(EvalError::InvalidConfig {
644                    detail: "CategoryFilter::Frequency requires summarize_with_lvis".to_string(),
645                });
646            };
647            let Some(freq_map) = ctx.category_frequency else {
648                // AB6 migration-guide note: dataset has no frequency
649                // tags. AP_r/c/f can't be computed; the request
650                // resolves to "no K passes", which mean_slice maps
651                // to the `-1` sentinel.
652                return Ok(Some(vec![false; n_k]));
653            };
654            Ok(Some(
655                ctx.category_ids
656                    .iter()
657                    .map(|cid| freq_map.get(cid).is_some_and(|f| f == target))
658                    .collect(),
659            ))
660        }
661        CategoryFilter::ByIds(ids) => {
662            let Some(ctx) = lvis else {
663                return Err(EvalError::InvalidConfig {
664                    detail: "CategoryFilter::ByIds requires summarize_with_lvis".to_string(),
665                });
666            };
667            let allow: std::collections::HashSet<&CategoryId> = ids.iter().collect();
668            Ok(Some(
669                ctx.category_ids
670                    .iter()
671                    .map(|cid| allow.contains(cid))
672                    .collect(),
673            ))
674        }
675        CategoryFilter::ByGrouping(label) => Err(EvalError::InvalidConfig {
676            detail: format!(
677                "CategoryFilter::ByGrouping({label:?}) must be resolved to ByIds at the \
678                 evaluator boundary before reaching the kernel summarizer (ADR-0041 / 0042). \
679                 Resolution maps the group label against the active ClassGroupBreakdown."
680            ),
681        }),
682    }
683}
684
685/// Mean of an `Accumulated` slice, filtering out the `-1` sentinel
686/// (quirks **C5/L6**) and optionally masking out K-axis indices that
687/// the request's [`CategoryFilter`] excludes (ADR-0026 D2). Returns
688/// `-1.0` if every surviving cell is the sentinel (mirrors
689/// pycocotools' `if len(s[s>-1])==0: -1`; quirk **AF6**: stays at
690/// `-1`, never collapses to `0` or `nan`).
691///
692/// The sum is computed via numpy-compatible pairwise summation
693/// ([`pairwise_sum`]) so the result is bit-identical to
694/// `np.mean(s[s>-1])` for the same input ordering. The K-axis mask
695/// is applied **before** the sentinel drop and the mean — matching
696/// lvis-api `eval.py:444`'s `s[s>-1]` shape on a frequency-filtered
697/// slice.
698///
699/// `k_mask`: `None` is the COCO no-op (every K passes); `Some(mask)`
700/// includes only K-axis indices where `mask[k] == true`. The mask
701/// length must equal the K-axis size; the caller (resolved in
702/// [`resolve_category_filter`]) guarantees this.
703///
704/// Infallible: callers must validate `t_range`, `area_idx`, and
705/// `m_idx` against the `Accumulated`'s shape upfront (see
706/// [`summarize_with`]).
707fn mean_slice(
708    accum: &Accumulated,
709    metric: Metric,
710    t_range: Range<usize>,
711    area_idx: usize,
712    m_idx: usize,
713    k_mask: Option<&[bool]>,
714) -> f64 {
715    let t_count = t_range.len();
716    let n_k = accum.precision.shape()[2];
717    let cap = match metric {
718        Metric::AveragePrecision => t_count * accum.precision.shape()[1] * n_k,
719        Metric::AverageRecall => t_count * n_k,
720    };
721    let mut filtered: Vec<f64> = Vec::with_capacity(cap);
722    let push_if = |filtered: &mut Vec<f64>, v: f64| {
723        if v > -1.0 {
724            filtered.push(v);
725        }
726    };
727    for t in t_range {
728        match metric {
729            Metric::AveragePrecision => {
730                // Slice T → (R, K, A, M); pick A and M → (R, K). The
731                // intermediate axis-views have to live in `let`
732                // bindings so the final `(R, K)` view borrows them
733                // long enough to index — the chained-call form
734                // dropped the inner views before `plane` was used.
735                let p_t = accum.precision.index_axis(Axis(0), t);
736                let p_ta = p_t.index_axis(Axis(2), area_idx);
737                let plane = p_ta.index_axis(Axis(2), m_idx);
738                // Walking R-major preserves the same sum order numpy
739                // uses on a `(R, K)` slice — the K-mask filter just
740                // skips columns the user opted out of.
741                let n_r = plane.shape()[0];
742                for r in 0..n_r {
743                    for k in 0..n_k {
744                        if k_mask.is_some_and(|m| !m[k]) {
745                            continue;
746                        }
747                        push_if(&mut filtered, plane[(r, k)]);
748                    }
749                }
750            }
751            Metric::AverageRecall => {
752                // recall is (T, K, A, M); slice T → (K, A, M); pick A
753                // and M → (K,). One value per K.
754                let r_t = accum.recall.index_axis(Axis(0), t);
755                let r_ta = r_t.index_axis(Axis(1), area_idx);
756                let plane = r_ta.index_axis(Axis(1), m_idx);
757                for k in 0..n_k {
758                    if k_mask.is_some_and(|m| !m[k]) {
759                        continue;
760                    }
761                    push_if(&mut filtered, plane[k]);
762                }
763            }
764        }
765    }
766    if filtered.is_empty() {
767        -1.0
768    } else {
769        pairwise_sum(&filtered) / filtered.len() as f64
770    }
771}
772
/// Sums an `f64` slice in the same order numpy's pairwise reduction uses.
///
/// The addition order follows `pairwise_sum_DOUBLE` from numpy's
/// `numpy/core/src/umath/loops_utils.h.src` (the routine behind
/// `np.add.reduce` on contiguous doubles):
///
/// - fewer than 8 values: plain left-to-right sum starting from `0.0`;
/// - 8 up to `PW_BLOCKSIZE` (128) values: eight independent lane
///   accumulators seeded with the first eight values, folded as
///   `((r0+r1)+(r2+r3)) + ((r4+r5)+(r6+r7))`, then the leftover
///   (`len % 8`) values added one at a time;
/// - more than `PW_BLOCKSIZE` values: split at `len / 2` rounded down to
///   a multiple of 8 and sum each half recursively.
///
/// This is a quirk-**C8**-style alignment: the summary stats are built on
/// `np.mean(s[s > -1])`, and a different addition order can change the
/// low bits of the result.
pub(crate) fn pairwise_sum(values: &[f64]) -> f64 {
    const PW_BLOCKSIZE: usize = 128;
    const LANES: usize = 8;

    let len = values.len();

    // Short input: straight forward accumulation.
    if len < LANES {
        return values.iter().fold(0.0_f64, |acc, &v| acc + v);
    }

    // Long input: halve (aligned down to the lane count) and recurse.
    if len > PW_BLOCKSIZE {
        let half = len / 2;
        let (left, right) = values.split_at(half - half % LANES);
        return pairwise_sum(left) + pairwise_sum(right);
    }

    // Block input: the first full chunk seeds the lanes, every further
    // full chunk is added lane-wise.
    let mut chunks = values.chunks_exact(LANES);
    let mut lanes = [0.0_f64; LANES];
    if let Some(seed) = chunks.next() {
        lanes.copy_from_slice(seed);
    }
    for chunk in chunks.by_ref() {
        for (lane, &v) in lanes.iter_mut().zip(chunk) {
            *lane += v;
        }
    }

    // Balanced-tree combine, then the tail that did not fill a chunk.
    let combined = ((lanes[0] + lanes[1]) + (lanes[2] + lanes[3]))
        + ((lanes[4] + lanes[5]) + (lanes[6] + lanes[7]));
    chunks.remainder().iter().fold(combined, |acc, &v| acc + v)
}
830
831#[cfg(test)]
832mod tests {
833    use super::*;
834    use crate::accumulate::{accumulate, AccumulateParams, PerImageEval};
835    use crate::parity::{iou_thresholds, recall_thresholds, ParityMode};
836    use ndarray::{Array2, Array4, Array5};
837
838    fn perfect_match_eval(t: usize) -> PerImageEval {
839        PerImageEval {
840            dt_scores: vec![0.9],
841            dt_matched: Array2::from_elem((t, 1), true),
842            dt_ignore: Array2::from_elem((t, 1), false),
843            gt_ignore: vec![false],
844        }
845    }
846
847    #[test]
848    fn perfect_match_summarizes_to_ones() {
849        // Single image, single category, all-area only — the simplest
850        // valid run that exercises every line of the 12-stat table.
851        let iou = iou_thresholds();
852        let rec = recall_thresholds();
853        let max_dets = [1usize, 10, 100];
854        let cell = perfect_match_eval(iou.len());
855
856        // K=1, A=4 (all/small/medium/large), I=1; we populate only the
857        // `all` cell. small/medium/large stay None → -1 sentinel.
858        let mut grid: Vec<Option<Box<PerImageEval>>> = vec![None; 4];
859        grid[0] = Some(Box::new(cell));
860
861        let p = AccumulateParams {
862            iou_thresholds: iou,
863            recall_thresholds: rec,
864            max_dets: &max_dets,
865            n_categories: 1,
866            n_area_ranges: 4,
867            n_images: 1,
868        };
869        let accum = accumulate(&grid, p, ParityMode::Strict).unwrap();
870        let summary = summarize_detection(&accum, iou, &max_dets).unwrap();
871
872        let stats = summary.stats();
873        assert_eq!(stats.len(), 12);
874        // AP[all], AP50, AP75, AR_1, AR_10, AR_100 should all be ~1.0.
875        for &i in &[0usize, 1, 2, 6, 7, 8] {
876            let v = stats[i];
877            assert!((v - 1.0).abs() < 1e-9, "stat[{i}] = {v}");
878        }
879        // small / medium / large carry -1 (no data).
880        for &i in &[3usize, 4, 5, 9, 10, 11] {
881            assert_eq!(stats[i], -1.0, "stat[{i}] should be -1 sentinel");
882        }
883    }
884
885    #[test]
886    fn empty_grid_yields_all_neg_one_stats() {
887        let iou = iou_thresholds();
888        let rec = recall_thresholds();
889        let max_dets = [1usize, 10, 100];
890        let p = AccumulateParams {
891            iou_thresholds: iou,
892            recall_thresholds: rec,
893            max_dets: &max_dets,
894            n_categories: 1,
895            n_area_ranges: 4,
896            n_images: 0,
897        };
898        let accum = accumulate(&[], p, ParityMode::Strict).unwrap();
899        let summary = summarize_detection(&accum, iou, &max_dets).unwrap();
900        assert!(summary.stats().iter().all(|&v| v == -1.0));
901    }
902
903    #[test]
904    fn missing_max_det_value_is_typed_error() {
905        // AR_1 line requires max_dets to contain 1; without it,
906        // summarization fails with InvalidConfig.
907        let iou = iou_thresholds();
908        let max_dets = [10usize, 100];
909        let accum = Accumulated {
910            precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 2), -1.0),
911            recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 2), -1.0),
912            scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 2), -1.0),
913        };
914        let err = summarize_detection(&accum, iou, &max_dets).unwrap_err();
915        assert!(matches!(err, EvalError::InvalidConfig { .. }));
916    }
917
918    #[test]
919    fn iou_threshold_dimension_mismatch_is_typed_error() {
920        let max_dets = [100usize];
921        let accum = Accumulated {
922            precision: Array5::<f64>::from_elem((10, 101, 1, 4, 1), -1.0),
923            recall: Array4::<f64>::from_elem((10, 1, 4, 1), -1.0),
924            scores: Array5::<f64>::from_elem((10, 101, 1, 4, 1), -1.0),
925        };
926        // pass only 5 thresholds — accum was built with 10.
927        let err = summarize_detection(&accum, &[0.5, 0.6, 0.7, 0.8, 0.9], &max_dets).unwrap_err();
928        assert!(matches!(err, EvalError::DimensionMismatch { .. }));
929    }
930
931    #[test]
932    fn summarize_with_custom_plan_evaluates_only_requested_lines() {
933        // Demonstrates the extension point: a 2-entry plan asking for
934        // AP@.50 across all areas and AR@.75 (not in the canonical 12)
935        // — both at the largest cap. Order is preserved.
936        let iou = iou_thresholds();
937        let max_dets = [100usize];
938        let accum = Accumulated {
939            precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 0.5),
940            recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 1), 0.7),
941            scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 1.0),
942        };
943        let plan = [
944            StatRequest::new(
945                Metric::AveragePrecision,
946                Some(0.5),
947                AreaRng::ALL,
948                MaxDetSelector::Largest,
949            ),
950            StatRequest::new(
951                Metric::AverageRecall,
952                Some(0.75),
953                AreaRng::ALL,
954                MaxDetSelector::Largest,
955            ),
956        ];
957        let summary = summarize_with(&accum, &plan, iou, &max_dets).unwrap();
958        assert_eq!(summary.lines.len(), 2);
959        assert!((summary.lines[0].value - 0.5).abs() < 1e-12);
960        assert_eq!(summary.lines[0].iou_threshold, Some(0.5));
961        assert!((summary.lines[1].value - 0.7).abs() < 1e-12);
962        assert_eq!(summary.lines[1].metric, Metric::AverageRecall);
963    }
964
965    #[test]
966    fn summarize_detection_matches_canonical_plan_via_summarize_with() {
967        // The thin-wrapper invariant: results are bit-equal whether the
968        // caller invokes summarize_detection or summarize_with with the
969        // canonical plan.
970        let iou = iou_thresholds();
971        let max_dets = [1usize, 10, 100];
972        let accum = Accumulated {
973            precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 0.5),
974            recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 3), 0.7),
975            scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 1.0),
976        };
977        let direct = summarize_detection(&accum, iou, &max_dets).unwrap();
978        let via_plan = summarize_with(
979            &accum,
980            &StatRequest::coco_detection_default(),
981            iou,
982            &max_dets,
983        )
984        .unwrap();
985        assert_eq!(direct.stats(), via_plan.stats());
986    }
987
988    #[test]
989    fn custom_area_bucket_with_owned_label_renders_in_pretty_lines() {
990        // 5-bucket A-axis (e.g. an orchestrator that adds a "tiny"
991        // bucket below "small"). The plan addresses index 4 by name and
992        // the label flows through to pretty_lines.
993        let iou = iou_thresholds();
994        let max_dets = [100usize];
995        let accum = Accumulated {
996            precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 5, 1), 1.0),
997            recall: Array4::<f64>::from_elem((iou.len(), 1, 5, 1), 1.0),
998            scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 5, 1), 1.0),
999        };
1000        let plan = [StatRequest::new(
1001            Metric::AveragePrecision,
1002            None,
1003            AreaRng::new(4, "tiny"),
1004            MaxDetSelector::Largest,
1005        )];
1006        let summary = summarize_with(&accum, &plan, iou, &max_dets).unwrap();
1007        let lines = summary.pretty_lines();
1008        assert_eq!(lines.len(), 1);
1009        assert!(lines[0].contains("tiny"), "unexpected line: {}", lines[0]);
1010    }
1011
1012    #[test]
1013    fn out_of_range_area_index_is_typed_error() {
1014        // Plan addresses A-axis index 4 against a 4-bucket Accumulated.
1015        let iou = iou_thresholds();
1016        let max_dets = [100usize];
1017        let accum = Accumulated {
1018            precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 1.0),
1019            recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 1), 1.0),
1020            scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 1.0),
1021        };
1022        let plan = [StatRequest::new(
1023            Metric::AveragePrecision,
1024            None,
1025            AreaRng::new(4, "tiny"),
1026            MaxDetSelector::Largest,
1027        )];
1028        let err = summarize_with(&accum, &plan, iou, &max_dets).unwrap_err();
1029        assert!(matches!(err, EvalError::InvalidConfig { .. }));
1030    }
1031
1032    #[test]
1033    fn pretty_lines_match_pycocotools_shape() {
1034        let iou = iou_thresholds();
1035        let max_dets = [1usize, 10, 100];
1036        let accum = Accumulated {
1037            precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 1.0),
1038            recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 3), 1.0),
1039            scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 1.0),
1040        };
1041        let summary = summarize_detection(&accum, iou, &max_dets).unwrap();
1042        let lines = summary.pretty_lines();
1043        assert_eq!(lines.len(), 12);
1044        // Spot-check the first AP line and the first AR line for the
1045        // pycocotools-shaped layout.
1046        assert!(lines[0].contains("Average Precision"));
1047        assert!(lines[0].contains("(AP)"));
1048        assert!(lines[0].contains("0.50:0.95"));
1049        assert!(lines[0].contains("maxDets=100"));
1050        assert!(lines[6].contains("Average Recall"));
1051        assert!(lines[6].contains("maxDets=  1"));
1052    }
1053
1054    #[test]
1055    fn pairwise_sum_matches_numpy_add_reduce_bitwise() {
1056        // 1010 alternating elements is large enough to drive both the
1057        // 8-lane unrolled block and the recursive split (n > 128). The
1058        // expected hex below is `np.add.reduce(v).hex()` for the same
1059        // sequence; naive forward summation lands one ULP higher
1060        // (`0x1.f900000002309p+8`).
1061        let v: Vec<f64> = (0..1010)
1062            .map(|i| if i % 2 == 0 { 1.0 } else { 1e-12 })
1063            .collect();
1064        let got = pairwise_sum(&v);
1065        let expected = f64::from_bits(0x407f_9000_0000_22b4);
1066        assert_eq!(
1067            got.to_bits(),
1068            expected.to_bits(),
1069            "pairwise_sum drifts from numpy: got {got:e}, expected {expected:e}",
1070        );
1071    }
1072
1073    #[test]
1074    fn coco_keypoints_default_plan_pins_canonical_order() {
1075        // ADR-0012 / D5: pycocotools' kp summary is exactly these 10
1076        // lines, in this order. Pin metric, threshold, A-axis index,
1077        // and selector so a refactor cannot silently re-order, drop a
1078        // row, or re-introduce the small bucket.
1079        let plan = StatRequest::coco_keypoints_default();
1080        assert_eq!(plan.len(), 10);
1081
1082        // Each entry: (metric, iou_threshold, area_index, selector).
1083        let expected: [(Metric, Option<f64>, usize, MaxDetSelector); 10] = [
1084            (Metric::AveragePrecision, None, 0, MaxDetSelector::Largest), // AP
1085            (
1086                Metric::AveragePrecision,
1087                Some(0.5),
1088                0,
1089                MaxDetSelector::Largest,
1090            ), // AP50
1091            (
1092                Metric::AveragePrecision,
1093                Some(0.75),
1094                0,
1095                MaxDetSelector::Largest,
1096            ), // AP75
1097            (Metric::AveragePrecision, None, 1, MaxDetSelector::Largest), // AP_M
1098            (Metric::AveragePrecision, None, 2, MaxDetSelector::Largest), // AP_L
1099            (Metric::AverageRecall, None, 0, MaxDetSelector::Largest),    // AR
1100            (Metric::AverageRecall, Some(0.5), 0, MaxDetSelector::Largest), // AR50
1101            (
1102                Metric::AverageRecall,
1103                Some(0.75),
1104                0,
1105                MaxDetSelector::Largest,
1106            ), // AR75
1107            (Metric::AverageRecall, None, 1, MaxDetSelector::Largest),    // AR_M
1108            (Metric::AverageRecall, None, 2, MaxDetSelector::Largest),    // AR_L
1109        ];
1110
1111        for (i, (metric, iou, idx, sel)) in expected.into_iter().enumerate() {
1112            assert_eq!(plan[i].metric, metric, "row {i} metric");
1113            assert_eq!(plan[i].iou_threshold, iou, "row {i} iou_threshold");
1114            assert_eq!(plan[i].area.index, idx, "row {i} area index");
1115            assert_eq!(plan[i].max_dets, sel, "row {i} selector");
1116        }
1117
1118        // No row addresses A-axis index 3 (would land off the end of a
1119        // 3-bucket kp accumulator) and no row addresses index 1 of the
1120        // detection-grid (which is "small" — D5 forbids).
1121        assert!(plan.iter().all(|r| r.area.index <= 2));
1122    }
1123
1124    #[test]
1125    fn pairwise_sum_handles_short_inputs_with_naive_fallback() {
1126        // n < 8 uses the simple loop; verify a hand-checked tiny case.
1127        let v = [1.0_f64, 2.0, 3.0, 4.0];
1128        assert_eq!(pairwise_sum(&v), 10.0);
1129        assert_eq!(pairwise_sum(&[]), 0.0);
1130        assert_eq!(pairwise_sum(&[42.0]), 42.0);
1131    }
1132
1133    // -- ADR-0026: lvis_default plan and CategoryFilter dispatch --------------
1134
1135    /// Synthesize a `(T, R, K, A, M)` precision tensor + matching
1136    /// recall and counts for unit-testing the K-axis filter. The
1137    /// shape parameters mirror lvis-api's typical run: T=10, R=101,
1138    /// A=4, M=1.
1139    fn fake_accumulated(n_k: usize, precision_per_k: &[f64], recall_per_k: &[f64]) -> Accumulated {
1140        const N_T: usize = 10;
1141        const N_R: usize = 101;
1142        const N_A: usize = 4;
1143        const N_M: usize = 1;
1144        assert_eq!(precision_per_k.len(), n_k);
1145        assert_eq!(recall_per_k.len(), n_k);
1146        let mut precision = Array5::<f64>::from_elem((N_T, N_R, n_k, N_A, N_M), 0.0);
1147        let mut recall = Array4::<f64>::from_elem((N_T, n_k, N_A, N_M), 0.0);
1148        for k in 0..n_k {
1149            // Fill every cell on the K-axis with the same constant so
1150            // mean(s[s>-1]) trivially equals that constant; sentinels
1151            // (`-1`) on whole-K rows fall through to the AF6 path.
1152            for t in 0..N_T {
1153                for r in 0..N_R {
1154                    for a in 0..N_A {
1155                        for m in 0..N_M {
1156                            precision[(t, r, k, a, m)] = precision_per_k[k];
1157                        }
1158                    }
1159                }
1160                for a in 0..N_A {
1161                    for m in 0..N_M {
1162                        recall[(t, k, a, m)] = recall_per_k[k];
1163                    }
1164                }
1165            }
1166        }
1167        Accumulated {
1168            precision,
1169            recall,
1170            scores: Array5::<f64>::from_elem((N_T, N_R, n_k, N_A, N_M), 0.0),
1171        }
1172    }
1173
1174    #[test]
1175    fn lvis_default_has_13_entries_in_canonical_order() {
1176        let plan = StatRequest::lvis_default();
1177        assert_eq!(plan.len(), 13, "AF1: 9 AP + 4 AR");
1178        // First 6 are AP across the COCO area buckets (no freq filter).
1179        for (i, req) in plan.iter().take(6).enumerate() {
1180            assert_eq!(req.metric, Metric::AveragePrecision, "row {i}");
1181            assert_eq!(req.category_filter, CategoryFilter::All, "row {i}");
1182        }
1183        // Rows 6/7/8: APr/APc/APf — ALL area, frequency filter set.
1184        for (i, expected) in [Frequency::Rare, Frequency::Common, Frequency::Frequent]
1185            .iter()
1186            .enumerate()
1187        {
1188            let req = &plan[6 + i];
1189            assert_eq!(req.metric, Metric::AveragePrecision);
1190            assert_eq!(req.area.index, AreaRng::ALL.index);
1191            assert_eq!(
1192                req.category_filter,
1193                CategoryFilter::Frequency(*expected),
1194                "row {}: AP{}",
1195                6 + i,
1196                expected_letter(*expected),
1197            );
1198        }
1199        // Rows 9..13: AR@300 across all four area buckets, no freq filter.
1200        for (i, area_idx) in [
1201            AreaRng::ALL.index,
1202            AreaRng::SMALL.index,
1203            AreaRng::MEDIUM.index,
1204            AreaRng::LARGE.index,
1205        ]
1206        .iter()
1207        .enumerate()
1208        {
1209            let req = &plan[9 + i];
1210            assert_eq!(req.metric, Metric::AverageRecall);
1211            assert_eq!(req.area.index, *area_idx);
1212            assert_eq!(req.category_filter, CategoryFilter::All);
1213            assert_eq!(req.max_dets, MaxDetSelector::Largest);
1214        }
1215    }
1216
1217    fn expected_letter(f: Frequency) -> char {
1218        match f {
1219            Frequency::Rare => 'r',
1220            Frequency::Common => 'c',
1221            Frequency::Frequent => 'f',
1222        }
1223    }
1224
1225    #[test]
1226    fn summarize_with_rejects_frequency_filter_on_coco_path() {
1227        // The plain `summarize_with` entry point has no K-axis context;
1228        // a Frequency-filtered plan must surface an InvalidConfig that
1229        // points the caller at summarize_with_lvis.
1230        let accum = fake_accumulated(3, &[1.0, 1.0, 1.0], &[1.0, 1.0, 1.0]);
1231        let plan = StatRequest::lvis_default();
1232        let err = summarize_with(&accum, &plan, iou_thresholds(), &[300]).unwrap_err();
1233        match err {
1234            EvalError::InvalidConfig { detail } => {
1235                assert!(detail.contains("summarize_with_lvis"), "msg: {detail}");
1236            }
1237            other => panic!("expected InvalidConfig, got {other:?}"),
1238        }
1239    }
1240
1241    #[test]
1242    fn summarize_with_lvis_routes_frequency_buckets_correctly() {
1243        // 3 categories: cat 1 = Frequent (precision 0.6), cat 2 =
1244        // Common (0.4), cat 3 = Rare (0.2). AP_r/c/f must equal each
1245        // bucket's per-K precision; AP overall = (0.6+0.4+0.2)/3.
1246        let accum = fake_accumulated(3, &[0.6, 0.4, 0.2], &[0.6, 0.4, 0.2]);
1247        let cat_ids = [CategoryId(1), CategoryId(2), CategoryId(3)];
1248        let mut freq_map = HashMap::new();
1249        freq_map.insert(CategoryId(1), Frequency::Frequent);
1250        freq_map.insert(CategoryId(2), Frequency::Common);
1251        freq_map.insert(CategoryId(3), Frequency::Rare);
1252
1253        let plan = StatRequest::lvis_default();
1254        let summary = summarize_with_lvis(
1255            &accum,
1256            &plan,
1257            iou_thresholds(),
1258            &[300],
1259            &cat_ids,
1260            Some(&freq_map),
1261        )
1262        .unwrap();
1263
1264        // Index 6/7/8 → APr/APc/APf.
1265        let apr = summary.lines[6].value;
1266        let apc = summary.lines[7].value;
1267        let apf = summary.lines[8].value;
1268        assert!((apr - 0.2).abs() < 1e-12, "APr expected 0.2, got {apr}");
1269        assert!((apc - 0.4).abs() < 1e-12, "APc expected 0.4, got {apc}");
1270        assert!((apf - 0.6).abs() < 1e-12, "APf expected 0.6, got {apf}");
1271        // Index 0 → AP overall.
1272        let ap = summary.lines[0].value;
1273        let expected = (0.6 + 0.4 + 0.2) / 3.0;
1274        assert!((ap - expected).abs() < 1e-12, "AP overall: {ap}");
1275    }
1276
1277    #[test]
1278    fn ab3_filters_minus_one_sentinels_before_mean() {
1279        // Mix two categories: one with positive precision, one with
1280        // the `-1` sentinel (a category that produced no eval_imgs
1281        // entries). The AP overall must equal the positive value
1282        // alone — the sentinel falls out of the mean.
1283        let accum = fake_accumulated(2, &[-1.0, 0.5], &[-1.0, 0.5]);
1284        let cat_ids = [CategoryId(1), CategoryId(2)];
1285        let mut freq_map = HashMap::new();
1286        freq_map.insert(CategoryId(1), Frequency::Rare);
1287        freq_map.insert(CategoryId(2), Frequency::Frequent);
1288        let summary = summarize_with_lvis(
1289            &accum,
1290            &StatRequest::lvis_default(),
1291            iou_thresholds(),
1292            &[300],
1293            &cat_ids,
1294            Some(&freq_map),
1295        )
1296        .unwrap();
1297        // AP overall: only cat 2 contributes (cat 1 is `-1`).
1298        assert!((summary.lines[0].value - 0.5).abs() < 1e-12);
1299        // APr: cat 1 is the only Rare category — and it's `-1`, so
1300        // the bucket is empty after filtering. AF6: returns `-1.0`.
1301        assert_eq!(summary.lines[6].value, -1.0, "APr empty bucket → -1");
1302        // APf: cat 2 is the only Frequent category, value 0.5.
1303        assert!((summary.lines[8].value - 0.5).abs() < 1e-12);
1304    }
1305
1306    #[test]
1307    fn af6_empty_frequency_bucket_returns_minus_one_not_zero_or_nan() {
1308        // Three Frequent categories, no Rare or Common. APr and APc
1309        // must both surface the `-1` sentinel, distinct from the
1310        // panoptic ADR-0025 W6 corrected behavior (returns `0.0`)
1311        // and from numpy's `nan` on an unfiltered empty mean.
1312        let accum = fake_accumulated(3, &[0.7, 0.8, 0.9], &[0.7, 0.8, 0.9]);
1313        let cat_ids = [CategoryId(1), CategoryId(2), CategoryId(3)];
1314        let mut freq_map = HashMap::new();
1315        freq_map.insert(CategoryId(1), Frequency::Frequent);
1316        freq_map.insert(CategoryId(2), Frequency::Frequent);
1317        freq_map.insert(CategoryId(3), Frequency::Frequent);
1318        let summary = summarize_with_lvis(
1319            &accum,
1320            &StatRequest::lvis_default(),
1321            iou_thresholds(),
1322            &[300],
1323            &cat_ids,
1324            Some(&freq_map),
1325        )
1326        .unwrap();
1327        // AP overall: mean of 0.7/0.8/0.9 = 0.8.
1328        assert!((summary.lines[0].value - 0.8).abs() < 1e-12);
1329        // APr / APc: empty K filter → -1.0 (not 0.0, not nan).
1330        assert_eq!(summary.lines[6].value, -1.0, "APr");
1331        assert_eq!(summary.lines[7].value, -1.0, "APc");
1332        assert!(!summary.lines[6].value.is_nan(), "AF6: never nan");
1333        assert!(summary.lines[6].value != 0.0, "AF6: never 0.0");
1334    }
1335
1336    #[test]
1337    fn ab6_no_frequency_map_yields_minus_one_for_frequency_filtered_lines() {
1338        // The dataset doesn't carry frequency tags (the COCO loader
1339        // path on an LVIS-shaped JSON, or a programmatically built
1340        // dataset). Frequency-filtered entries gracefully surface
1341        // the `-1` sentinel — quirk **AB6** corrected (no panic).
1342        let accum = fake_accumulated(2, &[0.5, 0.5], &[0.5, 0.5]);
1343        let cat_ids = [CategoryId(1), CategoryId(2)];
1344        let summary = summarize_with_lvis(
1345            &accum,
1346            &StatRequest::lvis_default(),
1347            iou_thresholds(),
1348            &[300],
1349            &cat_ids,
1350            None,
1351        )
1352        .unwrap();
1353        assert!((summary.lines[0].value - 0.5).abs() < 1e-12, "AP overall");
1354        assert_eq!(summary.lines[6].value, -1.0, "APr without freq map");
1355        assert_eq!(summary.lines[7].value, -1.0, "APc without freq map");
1356        assert_eq!(summary.lines[8].value, -1.0, "APf without freq map");
1357    }
1358
1359    #[test]
1360    fn category_filter_by_ids_subsets_correctly() {
1361        // ByIds filter: include only cat 2 → AP equals cat 2's
1362        // per-K precision regardless of the other categories.
1363        let accum = fake_accumulated(3, &[0.1, 0.5, 0.9], &[0.1, 0.5, 0.9]);
1364        let cat_ids = [CategoryId(1), CategoryId(2), CategoryId(3)];
1365        let plan = vec![StatRequest::new_with_filter(
1366            Metric::AveragePrecision,
1367            None,
1368            AreaRng::ALL,
1369            MaxDetSelector::Largest,
1370            CategoryFilter::ByIds(vec![CategoryId(2)]),
1371        )];
1372        let summary =
1373            summarize_with_lvis(&accum, &plan, iou_thresholds(), &[300], &cat_ids, None).unwrap();
1374        assert!((summary.lines[0].value - 0.5).abs() < 1e-12);
1375    }
1376
1377    #[test]
1378    fn category_axis_size_mismatch_is_typed_error() {
1379        let accum = fake_accumulated(2, &[0.5, 0.5], &[0.5, 0.5]);
1380        let cat_ids = [CategoryId(1), CategoryId(2), CategoryId(3)]; // wrong length
1381        let err = summarize_with_lvis(
1382            &accum,
1383            &StatRequest::lvis_default(),
1384            iou_thresholds(),
1385            &[300],
1386            &cat_ids,
1387            None,
1388        )
1389        .unwrap_err();
1390        assert!(matches!(err, EvalError::InvalidConfig { .. }));
1391    }
1392}