vernier_core/tide/assignment.rs
1//! Bin assignment for TIDE error decomposition.
2//!
//! Walks every detection that survives the per-image `max_dets` cap and
//! labels it with one of the five detection-side TIDE error bins
//! (Cls / Loc / Both / Dupe / Bkg) or one of the two non-FP labels
//! (`Tp` / `Ignore`). Then walks every non-ignore GT and flags those
//! without a same-class match at `t_f` as Missed (the sixth TIDE bin).
7//!
8//! The algorithm mirrors `tests/python/oracle/tide/oracle.py::
9//! _attribute_bins` exactly — that file is the spec per ADR-0021 and
10//! the Rust output is correct iff `|delta_rust − delta_oracle| < 1e-9`
11//! per bin per fixture (see `tests/tide_oracle_parity.rs`).
12//!
13//! ## Inputs
14//!
15//! - `gt` / `dt` are the source dataset and detection list.
//! - `cross_class` carries the un-class-filtered per-image IoU matrices
//!   from the orchestrator-level side pass (ADR-0023). Rows index DTs in
//!   the same per-image score-desc order that
//!   [`crate::evaluate::dt_top_indices_for_cell`] produces; columns index
//!   GTs in dataset insertion order. We read it for `iou_same` /
//!   `iou_cross` to sidestep recomputing the kernel.
22//! - `params` supplies `t_f` / `t_b` / `max_dets_per_image` / `use_cats`.
23//!
24//! ## Output
25//!
26//! [`BinAssignment`] carries a per-`(image_id, dt_input_idx)` label
27//! plus the per-bin `target_gt_local_idx` the rewrite layer needs:
28//! the wrong-class GT for Cls, the same-class GT for Loc. For Dupe /
29//! Bkg / Both / Missed the rewrite layer needs no extra payload (drop
30//! the DT, or flip the GT's ignore flag).
31
32use std::collections::HashMap;
33
34use ndarray::Array2;
35
36use crate::dataset::{
37 CategoryId, CocoAnnotation, CocoDataset, CocoDetection, CocoDetections, EvalDataset, ImageId,
38};
39use crate::error::EvalError;
40use crate::evaluate::dt_top_indices_for_cell;
41use crate::matching::{match_image, MatchResult};
42use crate::parity::ParityMode;
43use crate::tables::CrossClassIous;
44
45use super::params::TideParams;
46
47/// One detection's TIDE label at `t_f`, plus the rewrite-layer target
48/// (when the bin's correction needs one) and the IoU values that drove
49/// the bin pick (the FP-IoU histogram reads these for ADR-0022's
50/// `t_b` ratification).
51///
52/// `target_gt_local_idx` indexes into the **per-image** GT list in
53/// dataset insertion order — the same axis [`CrossClassIous::gt_classes`]
54/// uses as columns. Its meaning depends on `bin`:
55///
56/// - `Cls` — index of the wrong-class GT to relabel onto.
57/// - `Loc` — index of the same-class GT to snap the bbox to.
58/// - any other bin — meaningless (`-1`).
59///
60/// `iou_same` / `iou_cross` are the best same-class and cross-class
61/// IoUs computed during bin assignment. For TP / Ignore labels they're
62/// recorded as zeros (those DTs aren't on the FP path and the
63/// histogram filters them out).
64#[derive(Debug, Clone, Copy, PartialEq)]
65pub struct DtBinLabel {
66 /// The TIDE bin (or non-FP label).
67 pub bin: DtBin,
68 /// Per-image local GT index used by the `Cls` / `Loc` corrections;
69 /// `-1` for bins that need no target.
70 pub target_gt_local_idx: i32,
71 /// Best same-class IoU at the time of bin pick. `0.0` for TP /
72 /// Ignore labels (not used on the FP path).
73 pub iou_same: f64,
74 /// Best cross-class IoU at the time of bin pick. `0.0` for TP /
75 /// Ignore labels.
76 pub iou_cross: f64,
77}
78
/// Per-detection TIDE label, including the two non-FP labels.
///
/// The TP and Ignore labels are not in [`super::TideErrorBin`] (which
/// only enumerates the six error bins) — they live here because the
/// bin-assignment loop needs to know that "this DT was a true positive,
/// no rewrite needed" or "this DT matched only an ignore-GT, not
/// counted as an FP". See `oracle.py:466-471`.
///
/// The five FP variants are decided from the DT's best same-class IoU
/// (`iou_same`) and best cross-class IoU (`iou_cross`) using the priority
/// chain `Dupe → Cls → Loc → Both → Bkg`. The first condition that holds
/// wins, so each FP description below assumes the earlier ones failed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DtBin {
    /// True positive: matched a non-ignore GT at `t_f`.
    Tp,
    /// Matched only an ignore-GT (e.g. `iscrowd=1`) at `t_f`. Excluded
    /// from FP accounting.
    Ignore,
    /// Cls error — a wrong-class GT overlaps at IoU `>= t_f` (and
    /// `iou_same < t_f`; otherwise the DT is a Dupe).
    Cls,
    /// Loc error — `iou_same ∈ [t_b, t_f)`, `iou_cross < t_f`, and
    /// `iou_same >= iou_cross`.
    Loc,
    /// Both error — `iou_cross ∈ [t_b, t_f)` and the DT did not qualify
    /// for Loc (`iou_same < t_b`, or `iou_same < iou_cross`).
    Both,
    /// Dupe error — a same-class GT overlaps at IoU `>= t_f`, yet the
    /// greedy match left this DT unmatched because a higher-scoring
    /// same-class DT already claimed that GT. Takes priority over every
    /// other FP bin.
    Dupe,
    /// Bkg error — both `iou_same` and `iou_cross` are `< t_b`.
    Bkg,
}
107
108/// Bin-assignment output for one TIDE call.
109///
110/// Two flat maps keyed by `(image_id, dt_input_idx)` and `(image_id,
111/// gt_input_idx)` respectively. The `dt_input_idx` is the index into
112/// the input [`CocoDetections`] (the auto-incrementing position the
113/// detection originally occupied in the input list); the
114/// `gt_input_idx` is the index into the input [`CocoDataset`]
115/// annotations (also dataset insertion order).
116///
117/// Both indices are stable across the rewrite layer's per-bin calls:
118/// the rewrite rebuilds detections preserving these positions so the
119/// targets stay valid.
120#[derive(Debug, Default, Clone)]
121pub struct BinAssignment {
122 /// `(image_id, dt_input_idx)` → bin label. DTs evicted by the
123 /// per-image `max_dets` cap are absent (mirrors the oracle's
124 /// `attribution.get(d.dt_idx)` returning `None` for evicted DTs).
125 pub dt_labels: HashMap<(i64, usize), DtBinLabel>,
126 /// `(image_id, gt_input_idx)` for non-ignore GTs unmatched by any
127 /// same-class DT at `t_f`. The Missed correction marks these as
128 /// `ignore=true` in the rewrite step.
129 pub missed_gts: Vec<(i64, usize)>,
130}
131
132/// Walk every image and assign TIDE bins per the oracle's algorithm.
133///
134/// ## Algorithm
135///
136/// For each image:
137///
138/// 1. Apply the per-image score-desc `max_dets_per_image` cap (quirk
139/// **A1**) to the DTs. Evicted DTs are not labelled.
140/// 2. For each category present on the image, run a greedy match
141/// (matching engine, non-cross-class) at `t_f` against that
142/// category's same-class same-image GTs. Track `gt_taken_by` and
143/// per-DT `matched` / `ignore` status at `t_f`.
144/// 3. For each surviving DT, look up its `iou_same` and `iou_cross`
145/// from the cross-class side-pass storage and apply the priority
146/// decision (`oracle.py:496-531`):
147/// - `Tp` if matched at `t_f` to a non-ignore GT.
148/// - `Ignore` if matched only to an ignore-GT.
149/// - else FP, with the priority chain
150/// `Dupe → Cls → Loc → Both → Bkg`.
151/// 4. For each non-ignore GT, mark it Missed iff no same-class DT
152/// matched it at `t_f` (per `oracle.py:533-547`).
153///
154/// # Errors
155///
156/// Propagates [`EvalError`] from the matching engine (only on dimension
157/// mismatch — kernel work is already done by the time we get here).
158pub fn assign_bins(
159 gt: &CocoDataset,
160 dt: &CocoDetections,
161 cross_class: &CrossClassIous,
162 params: &TideParams<'_>,
163) -> Result<BinAssignment, EvalError> {
164 let mut images: Vec<&crate::dataset::ImageMeta> = gt.images().iter().collect();
165 images.sort_unstable_by_key(|im| im.id.0);
166
167 let mut out = BinAssignment::default();
168 let gt_anns = gt.annotations();
169 let dt_anns = dt.detections();
170
171 // Map (image_id, gt_input_idx) → per-image local column index used
172 // by CrossClassIous. Same ordering used by the side pass:
173 // `gt.ann_indices_for_image(image_id)` (dataset insertion order
174 // for that image).
175 for (image_idx, image) in images.iter().enumerate() {
176 assign_bins_for_image(
177 image_idx,
178 image.id,
179 gt,
180 dt,
181 cross_class,
182 params,
183 gt_anns,
184 dt_anns,
185 &mut out,
186 )?;
187 }
188 Ok(out)
189}
190
191#[allow(clippy::too_many_arguments)]
192fn assign_bins_for_image(
193 image_idx: usize,
194 image_id: ImageId,
195 gt: &CocoDataset,
196 dt: &CocoDetections,
197 cross_class: &CrossClassIous,
198 params: &TideParams<'_>,
199 gt_anns: &[CocoAnnotation],
200 dt_anns: &[CocoDetection],
201 out: &mut BinAssignment,
202) -> Result<(), EvalError> {
203 // Per-image GT list in the same column ordering CrossClassIous uses.
204 // `compute_cross_class_ious` calls `gt.ann_indices_for_image(image_id)`,
205 // which returns indices in dataset-insertion order (HashMap insertion
206 // order is irrelevant — the by_image map's Vec preserves the order
207 // the dataset was constructed in).
208 let gt_local_indices: &[usize] = gt.ann_indices_for_image(image_id);
209 // DT list capped + sorted in score-desc order, again matching the
210 // side pass's ordering for row indexing.
211 let dt_local_indices = dt_top_indices_for_cell(dt, image_id, None, params.max_dets_per_image);
212
213 if gt_local_indices.is_empty() && dt_local_indices.is_empty() {
214 return Ok(());
215 }
216
217 let cross = cross_class.get(image_idx);
218
219 // For each DT in the cap-applied list, its row index in the cross
220 // matrix is its position in `dt_local_indices` (the side pass walks
221 // the same `dt_top_indices_for_cell` output).
222 // For each GT in `gt_local_indices`, its column index in the cross
223 // matrix is its position in `gt_local_indices`.
224
225 // 1. Per-class same-class greedy match at t_f. Track per-DT match
226 // status and per-GT-local-column "taken_by" map.
227 // The categories we iterate over are the ids actually present on
228 // the image (matches the oracle's `cats_in_image` set).
229 let mut per_dt_matched: HashMap<usize, bool> = HashMap::new();
230 let mut per_dt_ignore: HashMap<usize, bool> = HashMap::new();
231 // `gt_taken_by[gt_local_col_idx] = dt_input_idx` — same-class match
232 // took this GT at t_f. Used for Missed attribution and not for
233 // Dupe (Dupe is geometric: iou_same >= t_f).
234 let mut gt_taken_by: HashMap<usize, usize> = HashMap::new();
235
236 let cats_in_image: Vec<CategoryId> = if params.use_cats {
237 let mut cats: Vec<CategoryId> = gt_local_indices
238 .iter()
239 .map(|&j| gt_anns[j].category_id)
240 .chain(dt_local_indices.iter().map(|&j| dt_anns[j].category_id))
241 .collect();
242 cats.sort_unstable_by_key(|c| c.0);
243 cats.dedup();
244 cats
245 } else {
246 // L4: collapse — single virtual category.
247 vec![CategoryId(crate::evaluate::COLLAPSED_CATEGORY_SENTINEL)]
248 };
249
250 for cat in cats_in_image {
251 same_class_match_one_category(
252 >_local_indices_with_pos(gt_local_indices, gt_anns, cat, params.use_cats),
253 &dt_local_indices_with_pos(&dt_local_indices, dt_anns, cat, params.use_cats),
254 gt_anns,
255 dt_anns,
256 params,
257 &mut per_dt_matched,
258 &mut per_dt_ignore,
259 &mut gt_taken_by,
260 )?;
261 }
262
263 // 2. Per-DT bin label using the cross-class side-pass IoU.
264 for (row_idx, &dt_input_idx) in dt_local_indices.iter().enumerate() {
265 let dt = &dt_anns[dt_input_idx];
266 let key = (image_id.0, dt_input_idx);
267
268 if per_dt_ignore.get(&dt_input_idx).copied().unwrap_or(false) {
269 out.dt_labels.insert(
270 key,
271 DtBinLabel {
272 bin: DtBin::Ignore,
273 target_gt_local_idx: -1,
274 iou_same: 0.0,
275 iou_cross: 0.0,
276 },
277 );
278 continue;
279 }
280 if per_dt_matched.get(&dt_input_idx).copied().unwrap_or(false) {
281 out.dt_labels.insert(
282 key,
283 DtBinLabel {
284 bin: DtBin::Tp,
285 target_gt_local_idx: -1,
286 iou_same: 0.0,
287 iou_cross: 0.0,
288 },
289 );
290 continue;
291 }
292
293 // FP: compute iou_same / iou_cross from the side pass.
294 let (iou_same, best_same_col, iou_cross, best_cross_col) = best_same_and_cross(
295 row_idx,
296 dt.category_id,
297 cross,
298 gt_local_indices,
299 gt_anns,
300 params.use_cats,
301 );
302
303 let label = pick_bin(
304 iou_same,
305 best_same_col,
306 iou_cross,
307 best_cross_col,
308 params.t_f,
309 params.t_b,
310 );
311 out.dt_labels.insert(key, label);
312 }
313
314 // 3. Missed: non-ignore GTs not in `gt_taken_by`.
315 for (col_idx, >_input_idx) in gt_local_indices.iter().enumerate() {
316 let g = >_anns[gt_input_idx];
317 // Use the same effective_ignore semantics the matching path uses
318 // (D1) — Strict + Corrected both fold iscrowd into ignore.
319 if g.is_crowd || g.ignore_flag.unwrap_or(false) {
320 continue;
321 }
322 if gt_taken_by.contains_key(&col_idx) {
323 continue;
324 }
325 out.missed_gts.push((image_id.0, gt_input_idx));
326 }
327
328 Ok(())
329}
330
331/// Build a `(local_col_idx, gt_input_idx)` list for one category, where
332/// `local_col_idx` matches the cross-class column ordering.
333fn gt_local_indices_with_pos(
334 gt_local_indices: &[usize],
335 gt_anns: &[CocoAnnotation],
336 cat: CategoryId,
337 use_cats: bool,
338) -> Vec<(usize, usize)> {
339 gt_local_indices
340 .iter()
341 .enumerate()
342 .filter(|&(_, &gi)| !use_cats || gt_anns[gi].category_id == cat)
343 .map(|(col, &gi)| (col, gi))
344 .collect()
345}
346
347/// Build a `(row_idx, dt_input_idx)` list for one category. `row_idx`
348/// matches the cross-class row ordering.
349fn dt_local_indices_with_pos(
350 dt_local_indices: &[usize],
351 dt_anns: &[CocoDetection],
352 cat: CategoryId,
353 use_cats: bool,
354) -> Vec<(usize, usize)> {
355 dt_local_indices
356 .iter()
357 .enumerate()
358 .filter(|&(_, &di)| !use_cats || dt_anns[di].category_id == cat)
359 .map(|(row, &di)| (row, di))
360 .collect()
361}
362
363#[allow(clippy::too_many_arguments)]
364fn same_class_match_one_category(
365 gts_in_cat: &[(usize, usize)], // (col_idx in cross matrix, gt_input_idx)
366 dts_in_cat: &[(usize, usize)], // (row_idx in cross matrix, dt_input_idx)
367 gt_anns: &[CocoAnnotation],
368 dt_anns: &[CocoDetection],
369 params: &TideParams<'_>,
370 per_dt_matched: &mut HashMap<usize, bool>,
371 per_dt_ignore: &mut HashMap<usize, bool>,
372 gt_taken_by: &mut HashMap<usize, usize>,
373) -> Result<(), EvalError> {
374 if dts_in_cat.is_empty() {
375 return Ok(());
376 }
377 let n_g = gts_in_cat.len();
378 let n_d = dts_in_cat.len();
379
380 // Build same-class IoU matrix by computing afresh via the bbox
381 // kernel. Rebuilding here (rather than reading from CrossClassIous's
382 // submatrix) keeps the assignment module free of an axis-orientation
383 // mistake — the cross-class storage is `(D, G)` and the matching
384 // engine needs `(G, D)`, so a sub-slice would have to be transposed
385 // anyway. Bbox IoU is cheap and the alternative slicing is trickier
386 // to get right.
387 let mut iou = Array2::<f64>::zeros((n_g, n_d));
388 if n_g > 0 {
389 for (gi_local, &(_, gi)) in gts_in_cat.iter().enumerate() {
390 let g_box = gt_anns[gi].bbox;
391 for (di_local, &(_, di)) in dts_in_cat.iter().enumerate() {
392 let d_box = dt_anns[di].bbox;
393 iou[(gi_local, di_local)] = bbox_iou_pair(g_box, d_box);
394 }
395 }
396 }
397
398 let gt_ignore: Vec<bool> = gts_in_cat
399 .iter()
400 .map(|&(_, gi)| {
401 let g = >_anns[gi];
402 // Mirror the oracle's "iscrowd OR ignore" — see oracle.py:
403 // `gt_ignore_k = np.array([g.iscrowd or g.ignore for g in gts_k])`.
404 // The matching engine reads the same flag.
405 g.is_crowd || g.ignore_flag.unwrap_or(false)
406 })
407 .collect();
408 let gt_iscrowd: Vec<bool> = gts_in_cat
409 .iter()
410 .map(|&(_, gi)| gt_anns[gi].is_crowd)
411 .collect();
412 let dt_scores: Vec<f64> = dts_in_cat
413 .iter()
414 .map(|&(_, di)| dt_anns[di].score)
415 .collect();
416
417 let single_threshold = [params.t_f];
418 let MatchResult {
419 dt_perm,
420 gt_perm,
421 dt_matches: dt_matches_pos,
422 gt_matches: gt_matches_pos,
423 dt_ignore,
424 } = match_image(
425 iou.view(),
426 >_ignore,
427 >_iscrowd,
428 &dt_scores,
429 &single_threshold,
430 ParityMode::Strict,
431 )?;
432
433 // Record per-DT matched / ignore at t_f. Permutations are over the
434 // dts_in_cat slot ordering — map back to the global dt_input_idx.
435 for (sorted_d, &orig_d) in dt_perm.iter().enumerate() {
436 let (_row_idx, dt_input_idx) = dts_in_cat[orig_d];
437 let matched = dt_matches_pos[(0, sorted_d)] >= 0;
438 let is_ignore = dt_ignore[(0, sorted_d)];
439 per_dt_matched.insert(dt_input_idx, matched);
440 per_dt_ignore.insert(dt_input_idx, is_ignore);
441 }
442 // Record per-GT taken_by for Missed attribution. Note: the
443 // matching engine returns gt_matches in the gt_perm order; the
444 // oracle's gt_matched_by uses the original GT order. Map perm →
445 // gts_in_cat[orig_g] → cross-matrix column index.
446 for (sorted_g, &orig_g) in gt_perm.iter().enumerate() {
447 let dt_pos = gt_matches_pos[(0, sorted_g)];
448 if dt_pos < 0 {
449 continue;
450 }
451 // Skip ignore-GTs: pycocotools' matching can mark an ignore-GT
452 // as matched but the oracle's `gt_matched_by` only records
453 // non-ignore matches (see `_greedy_match`, `gt_matched_by`
454 // stays -1 for ignore matches per oracle.py:300-304).
455 if gt_ignore[orig_g] {
456 continue;
457 }
458 let (col_idx, _gt_input_idx) = gts_in_cat[orig_g];
459 let dt_orig = dt_perm[dt_pos as usize];
460 let (_row_idx, dt_input_idx) = dts_in_cat[dt_orig];
461 gt_taken_by.insert(col_idx, dt_input_idx);
462 }
463 Ok(())
464}
465
466/// Pure axis-aligned bbox IoU on COCO `[x, y, w, h]`. Mirrors
467/// `oracle.py::bbox_iou` for one pair.
468fn bbox_iou_pair(g: crate::dataset::Bbox, d: crate::dataset::Bbox) -> f64 {
469 let g_x2 = g.x + g.w;
470 let g_y2 = g.y + g.h;
471 let d_x2 = d.x + d.w;
472 let d_y2 = d.y + d.h;
473 let inter_w = (g_x2.min(d_x2) - g.x.max(d.x)).max(0.0);
474 let inter_h = (g_y2.min(d_y2) - g.y.max(d.y)).max(0.0);
475 let inter = inter_w * inter_h;
476 let union = g.w * g.h + d.w * d.h - inter;
477 if union <= 0.0 {
478 0.0
479 } else {
480 inter / union
481 }
482}
483
484/// Pull `iou_same` / `iou_cross` for one DT row out of the cross-class
485/// IoU matrix. Returns the best IoU and the per-image-local GT column
486/// index for each side, or `(0.0, -1, 0.0, -1)` when the image has no
487/// GTs (or the matrix is absent).
488///
489/// The cross-class side pass already labels each row/column with the
490/// category index, but we read the GT category from `gt_anns` directly
491/// because it costs nothing here and keeps the side-pass parallel
492/// vectors out of the hot path's mental model.
493fn best_same_and_cross(
494 row_idx: usize,
495 dt_cat: CategoryId,
496 cross: Option<ndarray::ArrayView2<'_, f64>>,
497 gt_local_indices: &[usize],
498 gt_anns: &[CocoAnnotation],
499 use_cats: bool,
500) -> (f64, i32, f64, i32) {
501 // No GT data on this image → both sides are zero (Bkg territory).
502 let cross = match cross {
503 Some(m) => m,
504 None => return (0.0, -1, 0.0, -1),
505 };
506 if cross.ncols() == 0 {
507 return (0.0, -1, 0.0, -1);
508 }
509
510 let mut iou_same = 0.0_f64;
511 let mut best_same: i32 = -1;
512 let mut iou_cross = 0.0_f64;
513 let mut best_cross: i32 = -1;
514
515 for (col, >_input_idx) in gt_local_indices.iter().enumerate() {
516 let v = cross[(row_idx, col)];
517 let g_cat = gt_anns[gt_input_idx].category_id;
518 let same_class = !use_cats || g_cat == dt_cat;
519 // Strict `>` mirrors the oracle (`if ious[g_local] > iou_same`)
520 // — the first column wins ties, matching the oracle's iteration.
521 if same_class {
522 if v > iou_same {
523 iou_same = v;
524 best_same = col as i32;
525 }
526 } else if v > iou_cross {
527 iou_cross = v;
528 best_cross = col as i32;
529 }
530 }
531 (iou_same, best_same, iou_cross, best_cross)
532}
533
534/// Apply the priority chain from `oracle.py:496-531`. Returns the bin
535/// label, the rewrite-layer target, and the iou_same / iou_cross values
536/// the histogram extractor reads (ADR-0022 t_b ratification).
537fn pick_bin(
538 iou_same: f64,
539 best_same_col: i32,
540 iou_cross: f64,
541 best_cross_col: i32,
542 t_f: f64,
543 t_b: f64,
544) -> DtBinLabel {
545 let (bin, target) = if iou_same >= t_f {
546 // The rewrite drops Dupe DTs; target unused but recorded
547 // for symmetry with the oracle's _BinAttribution shape.
548 (DtBin::Dupe, best_same_col)
549 } else if iou_cross >= t_f {
550 (DtBin::Cls, best_cross_col)
551 } else if iou_same >= t_b && iou_same >= iou_cross {
552 (DtBin::Loc, best_same_col)
553 } else if iou_cross >= t_b {
554 (DtBin::Both, best_cross_col)
555 } else {
556 (DtBin::Bkg, -1)
557 };
558 DtBinLabel {
559 bin,
560 target_gt_local_idx: target,
561 iou_same,
562 iou_cross,
563 }
564}