Skip to main content

vernier_core/tide/
assignment.rs

1//! Bin assignment for TIDE error decomposition.
2//!
//! Walks every detection that survives the per-image `max_dets` cap and
//! assigns it to one of the five detection-side TIDE bins (Cls / Loc /
//! Both / Dupe / Bkg) or to one of the two non-FP labels (`Tp` /
//! `Ignore`). The sixth bin, Missed, is GT-side: every non-ignore GT
//! without a same-class match at `t_f` is flagged as Missed.
7//!
8//! The algorithm mirrors `tests/python/oracle/tide/oracle.py::
9//! _attribute_bins` exactly — that file is the spec per ADR-0021 and
10//! the Rust output is correct iff `|delta_rust − delta_oracle| < 1e-9`
11//! per bin per fixture (see `tests/tide_oracle_parity.rs`).
12//!
13//! ## Inputs
14//!
15//! - `gt` / `dt` are the source dataset and detection list.
16//! - `cross_class` carries the un-class-filtered per-image IoU matrices
17//!   from the orchestrator-level side pass (ADR-0023). Rows index DTs in
18//!   the same per-image score-desc order [`crate::evaluate::
19//!   dt_top_indices_for_cell`] uses; columns index GTs in dataset
20//!   insertion order. We read it for `iou_same` / `iou_cross` to
21//!   sidestep recomputing the kernel.
22//! - `params` supplies `t_f` / `t_b` / `max_dets_per_image` / `use_cats`.
23//!
24//! ## Output
25//!
26//! [`BinAssignment`] carries a per-`(image_id, dt_input_idx)` label
27//! plus the per-bin `target_gt_local_idx` the rewrite layer needs:
28//! the wrong-class GT for Cls, the same-class GT for Loc. For Dupe /
29//! Bkg / Both / Missed the rewrite layer needs no extra payload (drop
30//! the DT, or flip the GT's ignore flag).
31
32use std::collections::HashMap;
33
34use ndarray::Array2;
35
36use crate::dataset::{
37    CategoryId, CocoAnnotation, CocoDataset, CocoDetection, CocoDetections, EvalDataset, ImageId,
38};
39use crate::error::EvalError;
40use crate::evaluate::dt_top_indices_for_cell;
41use crate::matching::{match_image, MatchResult};
42use crate::parity::ParityMode;
43use crate::tables::CrossClassIous;
44
45use super::params::TideParams;
46
/// One detection's TIDE label at `t_f`, plus the rewrite-layer target
/// (when the bin's correction needs one) and the IoU values that drove
/// the bin pick (the FP-IoU histogram reads these for ADR-0022's
/// `t_b` ratification).
///
/// `target_gt_local_idx` indexes into the **per-image** GT list in
/// dataset insertion order — the same axis [`CrossClassIous::gt_classes`]
/// uses as columns. Its meaning depends on `bin`:
///
/// - `Cls`  — index of the wrong-class GT to relabel onto.
/// - `Loc`  — index of the same-class GT to snap the bbox to.
/// - `Dupe` — best same-class column as recorded by the bin pick; the
///   rewrite layer needs no target for this bin.
/// - `Both` — best cross-class column as recorded by the bin pick; the
///   rewrite layer needs no target for this bin either.
/// - `Bkg` / `Tp` / `Ignore` — no target (`-1`).
///
/// `iou_same` / `iou_cross` are the best same-class and cross-class
/// IoUs computed during bin assignment. For TP / Ignore labels they're
/// recorded as zeros (those DTs aren't on the FP path and the
/// histogram filters them out).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DtBinLabel {
    /// The TIDE bin (or non-FP label).
    pub bin: DtBin,
    /// Per-image local GT index used by the `Cls` / `Loc` corrections.
    /// Also populated (but not needed by the rewrite) for `Dupe` and
    /// `Both`; `-1` for `Bkg`, `Tp` and `Ignore`, or when no GT column
    /// qualified as the best candidate.
    pub target_gt_local_idx: i32,
    /// Best same-class IoU at the time of bin pick. `0.0` for TP /
    /// Ignore labels (not used on the FP path).
    pub iou_same: f64,
    /// Best cross-class IoU at the time of bin pick. `0.0` for TP /
    /// Ignore labels.
    pub iou_cross: f64,
}
78
/// Per-detection TIDE label, including the two non-FP labels.
///
/// The TP and Ignore labels are not in [`super::TideErrorBin`] (which
/// only enumerates the six error bins) — they live here because the
/// bin-assignment loop needs to know that "this DT was a true positive,
/// no rewrite needed" or "this DT matched only an ignore-GT, not
/// counted as an FP". See `oracle.py:466-471`.
///
/// The five FP variants are mutually exclusive; when several
/// conditions hold at once the bin pick resolves them in the order
/// `Dupe → Cls → Loc → Both → Bkg`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DtBin {
    /// True positive: matched a non-ignore GT at `t_f`.
    Tp,
    /// Matched only an ignore-GT (e.g. `iscrowd=1`) at `t_f`. Excluded
    /// from FP accounting.
    Ignore,
    /// Cls error — wrong-class GT overlaps at IoU `>= t_f` (and no
    /// same-class GT does, otherwise the DT is a `Dupe`).
    Cls,
    /// Loc error — same-class GT overlaps at IoU `∈ [t_b, t_f)`, and
    /// the same-class IoU is at least the cross-class IoU.
    Loc,
    /// Both error — wrong-class GT overlaps at IoU `∈ [t_b, t_f)`,
    /// and the same-class IoU did not reach `t_b` (or was lower).
    Both,
    /// Dupe error — same-class GT overlaps at IoU `>= t_f` but a
    /// higher-scoring same-class DT already claimed it. The bin pick
    /// tests this geometrically: an unmatched DT whose best same-class
    /// IoU is `>= t_f`.
    Dupe,
    /// Bkg error — best IoU against any GT is `< t_b`.
    Bkg,
}
107
/// Bin-assignment output for one TIDE call.
///
/// A flat map keyed by `(image_id, dt_input_idx)` and a flat list of
/// `(image_id, gt_input_idx)` pairs. The `image_id` component is the
/// raw `i64` inside the image id. The `dt_input_idx` is the index into
/// the input [`CocoDetections`] (the auto-incrementing position the
/// detection originally occupied in the input list); the
/// `gt_input_idx` is the index into the input [`CocoDataset`]
/// annotations (also dataset insertion order).
///
/// Both indices are stable across the rewrite layer's per-bin calls:
/// the rewrite rebuilds detections preserving these positions so the
/// targets stay valid.
#[derive(Debug, Default, Clone)]
pub struct BinAssignment {
    /// `(image_id, dt_input_idx)` → bin label. DTs evicted by the
    /// per-image `max_dets` cap are absent (mirrors the oracle's
    /// `attribution.get(d.dt_idx)` returning `None` for evicted DTs).
    pub dt_labels: HashMap<(i64, usize), DtBinLabel>,
    /// `(image_id, gt_input_idx)` for non-ignore GTs unmatched by any
    /// same-class DT at `t_f`. The Missed correction marks these as
    /// `ignore=true` in the rewrite step. Entries are appended image by
    /// image, in the order the images are walked.
    pub missed_gts: Vec<(i64, usize)>,
}
131
132/// Walk every image and assign TIDE bins per the oracle's algorithm.
133///
134/// ## Algorithm
135///
136/// For each image:
137///
138/// 1. Apply the per-image score-desc `max_dets_per_image` cap (quirk
139///    **A1**) to the DTs. Evicted DTs are not labelled.
140/// 2. For each category present on the image, run a greedy match
141///    (matching engine, non-cross-class) at `t_f` against that
142///    category's same-class same-image GTs. Track `gt_taken_by` and
143///    per-DT `matched` / `ignore` status at `t_f`.
144/// 3. For each surviving DT, look up its `iou_same` and `iou_cross`
145///    from the cross-class side-pass storage and apply the priority
146///    decision (`oracle.py:496-531`):
147///    - `Tp` if matched at `t_f` to a non-ignore GT.
148///    - `Ignore` if matched only to an ignore-GT.
149///    - else FP, with the priority chain
150///      `Dupe → Cls → Loc → Both → Bkg`.
151/// 4. For each non-ignore GT, mark it Missed iff no same-class DT
152///    matched it at `t_f` (per `oracle.py:533-547`).
153///
154/// # Errors
155///
156/// Propagates [`EvalError`] from the matching engine (only on dimension
157/// mismatch — kernel work is already done by the time we get here).
158pub fn assign_bins(
159    gt: &CocoDataset,
160    dt: &CocoDetections,
161    cross_class: &CrossClassIous,
162    params: &TideParams<'_>,
163) -> Result<BinAssignment, EvalError> {
164    let mut images: Vec<&crate::dataset::ImageMeta> = gt.images().iter().collect();
165    images.sort_unstable_by_key(|im| im.id.0);
166
167    let mut out = BinAssignment::default();
168    let gt_anns = gt.annotations();
169    let dt_anns = dt.detections();
170
171    // Map (image_id, gt_input_idx) → per-image local column index used
172    // by CrossClassIous. Same ordering used by the side pass:
173    // `gt.ann_indices_for_image(image_id)` (dataset insertion order
174    // for that image).
175    for (image_idx, image) in images.iter().enumerate() {
176        assign_bins_for_image(
177            image_idx,
178            image.id,
179            gt,
180            dt,
181            cross_class,
182            params,
183            gt_anns,
184            dt_anns,
185            &mut out,
186        )?;
187    }
188    Ok(out)
189}
190
191#[allow(clippy::too_many_arguments)]
192fn assign_bins_for_image(
193    image_idx: usize,
194    image_id: ImageId,
195    gt: &CocoDataset,
196    dt: &CocoDetections,
197    cross_class: &CrossClassIous,
198    params: &TideParams<'_>,
199    gt_anns: &[CocoAnnotation],
200    dt_anns: &[CocoDetection],
201    out: &mut BinAssignment,
202) -> Result<(), EvalError> {
203    // Per-image GT list in the same column ordering CrossClassIous uses.
204    // `compute_cross_class_ious` calls `gt.ann_indices_for_image(image_id)`,
205    // which returns indices in dataset-insertion order (HashMap insertion
206    // order is irrelevant — the by_image map's Vec preserves the order
207    // the dataset was constructed in).
208    let gt_local_indices: &[usize] = gt.ann_indices_for_image(image_id);
209    // DT list capped + sorted in score-desc order, again matching the
210    // side pass's ordering for row indexing.
211    let dt_local_indices = dt_top_indices_for_cell(dt, image_id, None, params.max_dets_per_image);
212
213    if gt_local_indices.is_empty() && dt_local_indices.is_empty() {
214        return Ok(());
215    }
216
217    let cross = cross_class.get(image_idx);
218
219    // For each DT in the cap-applied list, its row index in the cross
220    // matrix is its position in `dt_local_indices` (the side pass walks
221    // the same `dt_top_indices_for_cell` output).
222    // For each GT in `gt_local_indices`, its column index in the cross
223    // matrix is its position in `gt_local_indices`.
224
225    // 1. Per-class same-class greedy match at t_f. Track per-DT match
226    //    status and per-GT-local-column "taken_by" map.
227    //    The categories we iterate over are the ids actually present on
228    //    the image (matches the oracle's `cats_in_image` set).
229    let mut per_dt_matched: HashMap<usize, bool> = HashMap::new();
230    let mut per_dt_ignore: HashMap<usize, bool> = HashMap::new();
231    // `gt_taken_by[gt_local_col_idx] = dt_input_idx` — same-class match
232    // took this GT at t_f. Used for Missed attribution and not for
233    // Dupe (Dupe is geometric: iou_same >= t_f).
234    let mut gt_taken_by: HashMap<usize, usize> = HashMap::new();
235
236    let cats_in_image: Vec<CategoryId> = if params.use_cats {
237        let mut cats: Vec<CategoryId> = gt_local_indices
238            .iter()
239            .map(|&j| gt_anns[j].category_id)
240            .chain(dt_local_indices.iter().map(|&j| dt_anns[j].category_id))
241            .collect();
242        cats.sort_unstable_by_key(|c| c.0);
243        cats.dedup();
244        cats
245    } else {
246        // L4: collapse — single virtual category.
247        vec![CategoryId(crate::evaluate::COLLAPSED_CATEGORY_SENTINEL)]
248    };
249
250    for cat in cats_in_image {
251        same_class_match_one_category(
252            &gt_local_indices_with_pos(gt_local_indices, gt_anns, cat, params.use_cats),
253            &dt_local_indices_with_pos(&dt_local_indices, dt_anns, cat, params.use_cats),
254            gt_anns,
255            dt_anns,
256            params,
257            &mut per_dt_matched,
258            &mut per_dt_ignore,
259            &mut gt_taken_by,
260        )?;
261    }
262
263    // 2. Per-DT bin label using the cross-class side-pass IoU.
264    for (row_idx, &dt_input_idx) in dt_local_indices.iter().enumerate() {
265        let dt = &dt_anns[dt_input_idx];
266        let key = (image_id.0, dt_input_idx);
267
268        if per_dt_ignore.get(&dt_input_idx).copied().unwrap_or(false) {
269            out.dt_labels.insert(
270                key,
271                DtBinLabel {
272                    bin: DtBin::Ignore,
273                    target_gt_local_idx: -1,
274                    iou_same: 0.0,
275                    iou_cross: 0.0,
276                },
277            );
278            continue;
279        }
280        if per_dt_matched.get(&dt_input_idx).copied().unwrap_or(false) {
281            out.dt_labels.insert(
282                key,
283                DtBinLabel {
284                    bin: DtBin::Tp,
285                    target_gt_local_idx: -1,
286                    iou_same: 0.0,
287                    iou_cross: 0.0,
288                },
289            );
290            continue;
291        }
292
293        // FP: compute iou_same / iou_cross from the side pass.
294        let (iou_same, best_same_col, iou_cross, best_cross_col) = best_same_and_cross(
295            row_idx,
296            dt.category_id,
297            cross,
298            gt_local_indices,
299            gt_anns,
300            params.use_cats,
301        );
302
303        let label = pick_bin(
304            iou_same,
305            best_same_col,
306            iou_cross,
307            best_cross_col,
308            params.t_f,
309            params.t_b,
310        );
311        out.dt_labels.insert(key, label);
312    }
313
314    // 3. Missed: non-ignore GTs not in `gt_taken_by`.
315    for (col_idx, &gt_input_idx) in gt_local_indices.iter().enumerate() {
316        let g = &gt_anns[gt_input_idx];
317        // Use the same effective_ignore semantics the matching path uses
318        // (D1) — Strict + Corrected both fold iscrowd into ignore.
319        if g.is_crowd || g.ignore_flag.unwrap_or(false) {
320            continue;
321        }
322        if gt_taken_by.contains_key(&col_idx) {
323            continue;
324        }
325        out.missed_gts.push((image_id.0, gt_input_idx));
326    }
327
328    Ok(())
329}
330
331/// Build a `(local_col_idx, gt_input_idx)` list for one category, where
332/// `local_col_idx` matches the cross-class column ordering.
333fn gt_local_indices_with_pos(
334    gt_local_indices: &[usize],
335    gt_anns: &[CocoAnnotation],
336    cat: CategoryId,
337    use_cats: bool,
338) -> Vec<(usize, usize)> {
339    gt_local_indices
340        .iter()
341        .enumerate()
342        .filter(|&(_, &gi)| !use_cats || gt_anns[gi].category_id == cat)
343        .map(|(col, &gi)| (col, gi))
344        .collect()
345}
346
347/// Build a `(row_idx, dt_input_idx)` list for one category. `row_idx`
348/// matches the cross-class row ordering.
349fn dt_local_indices_with_pos(
350    dt_local_indices: &[usize],
351    dt_anns: &[CocoDetection],
352    cat: CategoryId,
353    use_cats: bool,
354) -> Vec<(usize, usize)> {
355    dt_local_indices
356        .iter()
357        .enumerate()
358        .filter(|&(_, &di)| !use_cats || dt_anns[di].category_id == cat)
359        .map(|(row, &di)| (row, di))
360        .collect()
361}
362
363#[allow(clippy::too_many_arguments)]
364fn same_class_match_one_category(
365    gts_in_cat: &[(usize, usize)], // (col_idx in cross matrix, gt_input_idx)
366    dts_in_cat: &[(usize, usize)], // (row_idx in cross matrix, dt_input_idx)
367    gt_anns: &[CocoAnnotation],
368    dt_anns: &[CocoDetection],
369    params: &TideParams<'_>,
370    per_dt_matched: &mut HashMap<usize, bool>,
371    per_dt_ignore: &mut HashMap<usize, bool>,
372    gt_taken_by: &mut HashMap<usize, usize>,
373) -> Result<(), EvalError> {
374    if dts_in_cat.is_empty() {
375        return Ok(());
376    }
377    let n_g = gts_in_cat.len();
378    let n_d = dts_in_cat.len();
379
380    // Build same-class IoU matrix by computing afresh via the bbox
381    // kernel. Rebuilding here (rather than reading from CrossClassIous's
382    // submatrix) keeps the assignment module free of an axis-orientation
383    // mistake — the cross-class storage is `(D, G)` and the matching
384    // engine needs `(G, D)`, so a sub-slice would have to be transposed
385    // anyway. Bbox IoU is cheap and the alternative slicing is trickier
386    // to get right.
387    let mut iou = Array2::<f64>::zeros((n_g, n_d));
388    if n_g > 0 {
389        for (gi_local, &(_, gi)) in gts_in_cat.iter().enumerate() {
390            let g_box = gt_anns[gi].bbox;
391            for (di_local, &(_, di)) in dts_in_cat.iter().enumerate() {
392                let d_box = dt_anns[di].bbox;
393                iou[(gi_local, di_local)] = bbox_iou_pair(g_box, d_box);
394            }
395        }
396    }
397
398    let gt_ignore: Vec<bool> = gts_in_cat
399        .iter()
400        .map(|&(_, gi)| {
401            let g = &gt_anns[gi];
402            // Mirror the oracle's "iscrowd OR ignore" — see oracle.py:
403            // `gt_ignore_k = np.array([g.iscrowd or g.ignore for g in gts_k])`.
404            // The matching engine reads the same flag.
405            g.is_crowd || g.ignore_flag.unwrap_or(false)
406        })
407        .collect();
408    let gt_iscrowd: Vec<bool> = gts_in_cat
409        .iter()
410        .map(|&(_, gi)| gt_anns[gi].is_crowd)
411        .collect();
412    let dt_scores: Vec<f64> = dts_in_cat
413        .iter()
414        .map(|&(_, di)| dt_anns[di].score)
415        .collect();
416
417    let single_threshold = [params.t_f];
418    let MatchResult {
419        dt_perm,
420        gt_perm,
421        dt_matches: dt_matches_pos,
422        gt_matches: gt_matches_pos,
423        dt_ignore,
424    } = match_image(
425        iou.view(),
426        &gt_ignore,
427        &gt_iscrowd,
428        &dt_scores,
429        &single_threshold,
430        ParityMode::Strict,
431    )?;
432
433    // Record per-DT matched / ignore at t_f. Permutations are over the
434    // dts_in_cat slot ordering — map back to the global dt_input_idx.
435    for (sorted_d, &orig_d) in dt_perm.iter().enumerate() {
436        let (_row_idx, dt_input_idx) = dts_in_cat[orig_d];
437        let matched = dt_matches_pos[(0, sorted_d)] >= 0;
438        let is_ignore = dt_ignore[(0, sorted_d)];
439        per_dt_matched.insert(dt_input_idx, matched);
440        per_dt_ignore.insert(dt_input_idx, is_ignore);
441    }
442    // Record per-GT taken_by for Missed attribution. Note: the
443    // matching engine returns gt_matches in the gt_perm order; the
444    // oracle's gt_matched_by uses the original GT order. Map perm →
445    // gts_in_cat[orig_g] → cross-matrix column index.
446    for (sorted_g, &orig_g) in gt_perm.iter().enumerate() {
447        let dt_pos = gt_matches_pos[(0, sorted_g)];
448        if dt_pos < 0 {
449            continue;
450        }
451        // Skip ignore-GTs: pycocotools' matching can mark an ignore-GT
452        // as matched but the oracle's `gt_matched_by` only records
453        // non-ignore matches (see `_greedy_match`, `gt_matched_by`
454        // stays -1 for ignore matches per oracle.py:300-304).
455        if gt_ignore[orig_g] {
456            continue;
457        }
458        let (col_idx, _gt_input_idx) = gts_in_cat[orig_g];
459        let dt_orig = dt_perm[dt_pos as usize];
460        let (_row_idx, dt_input_idx) = dts_in_cat[dt_orig];
461        gt_taken_by.insert(col_idx, dt_input_idx);
462    }
463    Ok(())
464}
465
466/// Pure axis-aligned bbox IoU on COCO `[x, y, w, h]`. Mirrors
467/// `oracle.py::bbox_iou` for one pair.
468fn bbox_iou_pair(g: crate::dataset::Bbox, d: crate::dataset::Bbox) -> f64 {
469    let g_x2 = g.x + g.w;
470    let g_y2 = g.y + g.h;
471    let d_x2 = d.x + d.w;
472    let d_y2 = d.y + d.h;
473    let inter_w = (g_x2.min(d_x2) - g.x.max(d.x)).max(0.0);
474    let inter_h = (g_y2.min(d_y2) - g.y.max(d.y)).max(0.0);
475    let inter = inter_w * inter_h;
476    let union = g.w * g.h + d.w * d.h - inter;
477    if union <= 0.0 {
478        0.0
479    } else {
480        inter / union
481    }
482}
483
484/// Pull `iou_same` / `iou_cross` for one DT row out of the cross-class
485/// IoU matrix. Returns the best IoU and the per-image-local GT column
486/// index for each side, or `(0.0, -1, 0.0, -1)` when the image has no
487/// GTs (or the matrix is absent).
488///
489/// The cross-class side pass already labels each row/column with the
490/// category index, but we read the GT category from `gt_anns` directly
491/// because it costs nothing here and keeps the side-pass parallel
492/// vectors out of the hot path's mental model.
493fn best_same_and_cross(
494    row_idx: usize,
495    dt_cat: CategoryId,
496    cross: Option<ndarray::ArrayView2<'_, f64>>,
497    gt_local_indices: &[usize],
498    gt_anns: &[CocoAnnotation],
499    use_cats: bool,
500) -> (f64, i32, f64, i32) {
501    // No GT data on this image → both sides are zero (Bkg territory).
502    let cross = match cross {
503        Some(m) => m,
504        None => return (0.0, -1, 0.0, -1),
505    };
506    if cross.ncols() == 0 {
507        return (0.0, -1, 0.0, -1);
508    }
509
510    let mut iou_same = 0.0_f64;
511    let mut best_same: i32 = -1;
512    let mut iou_cross = 0.0_f64;
513    let mut best_cross: i32 = -1;
514
515    for (col, &gt_input_idx) in gt_local_indices.iter().enumerate() {
516        let v = cross[(row_idx, col)];
517        let g_cat = gt_anns[gt_input_idx].category_id;
518        let same_class = !use_cats || g_cat == dt_cat;
519        // Strict `>` mirrors the oracle (`if ious[g_local] > iou_same`)
520        // — the first column wins ties, matching the oracle's iteration.
521        if same_class {
522            if v > iou_same {
523                iou_same = v;
524                best_same = col as i32;
525            }
526        } else if v > iou_cross {
527            iou_cross = v;
528            best_cross = col as i32;
529        }
530    }
531    (iou_same, best_same, iou_cross, best_cross)
532}
533
534/// Apply the priority chain from `oracle.py:496-531`. Returns the bin
535/// label, the rewrite-layer target, and the iou_same / iou_cross values
536/// the histogram extractor reads (ADR-0022 t_b ratification).
537fn pick_bin(
538    iou_same: f64,
539    best_same_col: i32,
540    iou_cross: f64,
541    best_cross_col: i32,
542    t_f: f64,
543    t_b: f64,
544) -> DtBinLabel {
545    let (bin, target) = if iou_same >= t_f {
546        // The rewrite drops Dupe DTs; target unused but recorded
547        // for symmetry with the oracle's _BinAttribution shape.
548        (DtBin::Dupe, best_same_col)
549    } else if iou_cross >= t_f {
550        (DtBin::Cls, best_cross_col)
551    } else if iou_same >= t_b && iou_same >= iou_cross {
552        (DtBin::Loc, best_same_col)
553    } else if iou_cross >= t_b {
554        (DtBin::Both, best_cross_col)
555    } else {
556        (DtBin::Bkg, -1)
557    };
558    DtBinLabel {
559        bin,
560        target_gt_local_idx: target,
561        iou_same,
562        iou_cross,
563    }
564}