Skip to main content

scirs2_vision/
detection.rs

1//! # Object Detection Utilities
2//!
3//! Comprehensive object detection infrastructure providing bounding box operations,
4//! non-maximum suppression algorithms, anchor generation, detection metrics,
5//! and image augmentation utilities for detection pipelines.
6//!
7//! ## Features
8//!
9//! - **DetectionBox**: Full bounding box representation with IoU, GIoU, DIoU, CIoU metrics
10//! - **NMS**: Standard, soft, batched (per-class), and weighted NMS algorithms
11//! - **Anchor Generation**: SSD, YOLO, and configurable anchor generators
12//! - **Detection Metrics**: Average Precision (AP), mAP, precision-recall curves
13//! - **Augmentation**: Bounding box-aware horizontal flip, crop, scale, translate, clip
14//!
15//! ## Example
16//!
17//! ```rust
18//! use scirs2_vision::detection::{DetectionBox, nms, compute_ap};
19//!
20//! // Create bounding boxes
21//! let b1 = DetectionBox::new(10.0, 10.0, 50.0, 50.0)
22//!     .with_confidence(0.9)
23//!     .with_class(1, Some("cat".to_string()));
24//! let b2 = DetectionBox::new(12.0, 12.0, 52.0, 52.0)
25//!     .with_confidence(0.7)
26//!     .with_class(1, Some("cat".to_string()));
27//! let b3 = DetectionBox::new(200.0, 200.0, 260.0, 260.0)
28//!     .with_confidence(0.85)
29//!     .with_class(2, Some("dog".to_string()));
30//!
31//! // Non-maximum suppression
32//! let kept = nms(&[b1.clone(), b2, b3], 0.5);
33//! assert_eq!(kept.len(), 2); // b1 and b3 survive
34//!
35//! // IoU computation
36//! let iou = b1.iou(&DetectionBox::new(10.0, 10.0, 50.0, 50.0));
37//! assert!((iou - 1.0).abs() < 1e-10);
38//! ```
39
40use crate::error::{Result, VisionError};
41use std::collections::HashMap;
42
43// ---------------------------------------------------------------------------
44// DetectionBox
45// ---------------------------------------------------------------------------
46
/// A bounding box for object detection, stored in `(x1, y1, x2, y2)` corner format.
///
/// All coordinates use `f64` precision. In addition to its geometry, the box
/// carries a confidence score, a class id and an optional class name. The NMS,
/// metric and augmentation helpers in this module read and propagate these fields.
///
/// [`DetectionBox::new`] and [`DetectionBox::from_center`] guarantee
/// `x1 <= x2` and `y1 <= y2`. Because the fields are public, code that writes
/// them directly is responsible for preserving that ordering.
#[derive(Clone, Debug, PartialEq)]
pub struct DetectionBox {
    /// Top-left x coordinate
    pub x1: f64,
    /// Top-left y coordinate
    pub y1: f64,
    /// Bottom-right x coordinate
    pub x2: f64,
    /// Bottom-right y coordinate
    pub y2: f64,
    /// Detection confidence score, conventionally in `[0, 1]`
    pub confidence: f64,
    /// Class identifier (0-indexed)
    pub class_id: usize,
    /// Optional human-readable class name
    pub class_name: Option<String>,
}

impl DetectionBox {
    /// Create a new detection box from corner coordinates.
    ///
    /// Coordinates are normalised so that `x1 <= x2` and `y1 <= y2`.
    /// The confidence starts at `0.0`, the class id at `0` and the class name at `None`.
    pub fn new(x1: f64, y1: f64, x2: f64, y2: f64) -> Self {
        Self {
            x1: x1.min(x2),
            y1: y1.min(y2),
            x2: x1.max(x2),
            y2: y1.max(y2),
            confidence: 0.0,
            class_id: 0,
            class_name: None,
        }
    }

    /// Create a detection box from centre coordinates and dimensions.
    ///
    /// # Arguments
    /// * `cx` - Centre x
    /// * `cy` - Centre y
    /// * `w`  - Width (the absolute value is used)
    /// * `h`  - Height (the absolute value is used)
    pub fn from_center(cx: f64, cy: f64, w: f64, h: f64) -> Self {
        let half_w = w.abs() / 2.0;
        let half_h = h.abs() / 2.0;
        Self::new(cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    }

    /// Builder: set confidence score.
    #[must_use]
    pub fn with_confidence(mut self, confidence: f64) -> Self {
        self.confidence = confidence;
        self
    }

    /// Builder: set class id and optional name.
    #[must_use]
    pub fn with_class(mut self, class_id: usize, class_name: Option<String>) -> Self {
        self.class_id = class_id;
        self.class_name = class_name;
        self
    }

    /// Box area. Negative extents are clamped to zero.
    #[inline]
    pub fn area(&self) -> f64 {
        (self.x2 - self.x1).max(0.0) * (self.y2 - self.y1).max(0.0)
    }

    /// Centre of the box as `(cx, cy)`.
    #[inline]
    pub fn center(&self) -> (f64, f64) {
        ((self.x1 + self.x2) / 2.0, (self.y1 + self.y2) / 2.0)
    }

    /// Width of the box (clamped to be non-negative).
    #[inline]
    pub fn width(&self) -> f64 {
        (self.x2 - self.x1).max(0.0)
    }

    /// Height of the box (clamped to be non-negative).
    #[inline]
    pub fn height(&self) -> f64 {
        (self.y2 - self.y1).max(0.0)
    }

    /// Aspect ratio (width / height). Returns 0.0 if height is zero.
    #[inline]
    pub fn aspect_ratio(&self) -> f64 {
        let h = self.height();
        if h == 0.0 {
            0.0
        } else {
            self.width() / h
        }
    }

    // -- overlap metrics ----------------------------------------------------

    /// Intersection area with another box.
    pub fn intersection_area(&self, other: &DetectionBox) -> f64 {
        let ix1 = self.x1.max(other.x1);
        let iy1 = self.y1.max(other.y1);
        let ix2 = self.x2.min(other.x2);
        let iy2 = self.y2.min(other.y2);
        (ix2 - ix1).max(0.0) * (iy2 - iy1).max(0.0)
    }

    /// Union area with another box.
    pub fn union_area(&self, other: &DetectionBox) -> f64 {
        self.overlap(other).1
    }

    /// `(intersection, union)` areas, computing the intersection only once.
    fn overlap(&self, other: &DetectionBox) -> (f64, f64) {
        let inter = self.intersection_area(other);
        (inter, self.area() + other.area() - inter)
    }

    /// Raw `(width, height)` of the smallest axis-aligned box enclosing both boxes.
    ///
    /// Not clamped: each caller keeps its own handling of degenerate or
    /// unnormalised inputs.
    fn enclosing_extent(&self, other: &DetectionBox) -> (f64, f64) {
        (
            self.x2.max(other.x2) - self.x1.min(other.x1),
            self.y2.max(other.y2) - self.y1.min(other.y1),
        )
    }

    /// `(d^2, c^2)`: the squared distance between the two centres and the
    /// squared diagonal of the enclosing box, as used by DIoU and CIoU.
    fn centre_and_diagonal_sq(&self, other: &DetectionBox) -> (f64, f64) {
        let (cx1, cy1) = self.center();
        let (cx2, cy2) = other.center();
        let d_sq = (cx1 - cx2).powi(2) + (cy1 - cy2).powi(2);
        let (enc_w, enc_h) = self.enclosing_extent(other);
        (d_sq, enc_w.powi(2) + enc_h.powi(2))
    }

    /// Intersection over Union (IoU).
    ///
    /// Returns 0.0 when the union area is zero.
    pub fn iou(&self, other: &DetectionBox) -> f64 {
        let (inter, union) = self.overlap(other);
        if union <= 0.0 {
            return 0.0;
        }
        inter / union
    }

    /// Generalized Intersection over Union (GIoU).
    ///
    /// GIoU = IoU - |C \ (A union B)| / |C|,
    /// where C is the smallest enclosing box.
    /// Range: [-1, 1]. Returns 0.0 when the union area is zero, and plain IoU
    /// when the enclosing box has zero area.
    pub fn giou(&self, other: &DetectionBox) -> f64 {
        let (inter, union) = self.overlap(other);
        if union <= 0.0 {
            return 0.0;
        }
        let (enc_w, enc_h) = self.enclosing_extent(other);
        let enc_area = enc_w.max(0.0) * enc_h.max(0.0);

        let iou_val = inter / union;
        if enc_area <= 0.0 {
            return iou_val;
        }
        iou_val - (enc_area - union) / enc_area
    }

    /// Distance-IoU (DIoU).
    ///
    /// DIoU = IoU - d^2 / c^2,
    /// where d is the Euclidean distance between the centres and c is the diagonal
    /// length of the smallest enclosing box. Returns 0.0 when the union area is zero.
    pub fn diou(&self, other: &DetectionBox) -> f64 {
        let (inter, union) = self.overlap(other);
        if union <= 0.0 {
            return 0.0;
        }
        let iou_val = inter / union;

        let (d_sq, c_sq) = self.centre_and_diagonal_sq(other);
        if c_sq <= 0.0 {
            return iou_val;
        }
        iou_val - d_sq / c_sq
    }

    /// Complete-IoU (CIoU).
    ///
    /// CIoU = IoU - d^2/c^2 - alpha * v,
    /// where `v = 4/pi^2 * (atan(w1/h1) - atan(w2/h2))^2` measures aspect-ratio
    /// consistency and `alpha = v / (1 - IoU + v)` is the trade-off weight.
    /// Returns 0.0 when the union area is zero.
    pub fn ciou(&self, other: &DetectionBox) -> f64 {
        let (inter, union) = self.overlap(other);
        if union <= 0.0 {
            return 0.0;
        }
        let iou_val = inter / union;
        let (d_sq, c_sq) = self.centre_and_diagonal_sq(other);

        // Aspect-ratio consistency term; heights are floored to avoid division by zero.
        let pi = std::f64::consts::PI;
        let atan_self = (self.width() / self.height().max(1e-12)).atan();
        let atan_other = (other.width() / other.height().max(1e-12)).atan();
        let v = (4.0 / (pi * pi)) * (atan_self - atan_other).powi(2);

        let alpha = if (1.0 - iou_val + v).abs() < 1e-12 {
            0.0
        } else {
            v / (1.0 - iou_val + v)
        };

        let distance_term = if c_sq > 0.0 { d_sq / c_sq } else { 0.0 };
        iou_val - distance_term - alpha * v
    }

    /// Check whether a point `(px, py)` lies inside this box (inclusive).
    pub fn contains_point(&self, px: f64, py: f64) -> bool {
        px >= self.x1 && px <= self.x2 && py >= self.y1 && py <= self.y2
    }
}
273
274// ---------------------------------------------------------------------------
275// Non-Maximum Suppression
276// ---------------------------------------------------------------------------
277
278/// Standard greedy Non-Maximum Suppression (NMS).
279///
280/// Returns the indices (into the input slice) of the boxes that survive
281/// suppression. Boxes are processed in descending order of confidence.
282///
283/// # Arguments
284/// * `boxes`         - Detection boxes.
285/// * `iou_threshold` - Boxes with IoU > threshold relative to a kept box are removed.
286pub fn nms(boxes: &[DetectionBox], iou_threshold: f64) -> Vec<usize> {
287    if boxes.is_empty() {
288        return Vec::new();
289    }
290
291    // Sort indices by descending confidence
292    let mut indices: Vec<usize> = (0..boxes.len()).collect();
293    indices.sort_by(|&a, &b| {
294        boxes[b]
295            .confidence
296            .partial_cmp(&boxes[a].confidence)
297            .unwrap_or(std::cmp::Ordering::Equal)
298    });
299
300    let mut keep = Vec::new();
301    let mut suppressed = vec![false; boxes.len()];
302
303    for &idx in &indices {
304        if suppressed[idx] {
305            continue;
306        }
307        keep.push(idx);
308        for &other in &indices {
309            if other != idx && !suppressed[other] && boxes[idx].iou(&boxes[other]) > iou_threshold {
310                suppressed[other] = true;
311            }
312        }
313    }
314    keep
315}
316
317/// Soft-NMS with Gaussian score decay.
318///
319/// Instead of hard suppression, overlapping boxes have their confidence
320/// reduced by `exp(-iou^2 / sigma)`. Boxes whose confidence drops below
321/// `score_threshold` are discarded.
322///
323/// The function modifies the confidence values in place and returns the
324/// indices of surviving boxes.
325pub fn soft_nms(boxes: &mut [DetectionBox], sigma: f64, score_threshold: f64) -> Vec<usize> {
326    if boxes.is_empty() {
327        return Vec::new();
328    }
329    let n = boxes.len();
330    let mut active: Vec<bool> = vec![true; n];
331    let mut keep = Vec::new();
332
333    for _ in 0..n {
334        // Find the active box with the highest confidence
335        let mut best_idx: Option<usize> = None;
336        let mut best_score = f64::NEG_INFINITY;
337        for (i, &is_active) in active.iter().enumerate() {
338            if is_active && boxes[i].confidence > best_score {
339                best_score = boxes[i].confidence;
340                best_idx = Some(i);
341            }
342        }
343        let best = match best_idx {
344            Some(i) => i,
345            None => break,
346        };
347
348        if boxes[best].confidence < score_threshold {
349            break;
350        }
351
352        keep.push(best);
353        active[best] = false;
354
355        // Decay overlapping boxes
356        for j in 0..n {
357            if active[j] {
358                let iou_val = boxes[best].iou(&boxes[j]);
359                if sigma > 0.0 {
360                    boxes[j].confidence *= (-iou_val * iou_val / sigma).exp();
361                }
362                if boxes[j].confidence < score_threshold {
363                    active[j] = false;
364                }
365            }
366        }
367    }
368    keep
369}
370
371/// Batched (per-class) NMS.
372///
373/// Applies standard NMS independently within each class and merges the results.
374pub fn batched_nms(boxes: &[DetectionBox], iou_threshold: f64) -> Vec<usize> {
375    if boxes.is_empty() {
376        return Vec::new();
377    }
378
379    // Group indices by class_id
380    let mut class_map: HashMap<usize, Vec<usize>> = HashMap::new();
381    for (i, b) in boxes.iter().enumerate() {
382        class_map.entry(b.class_id).or_default().push(i);
383    }
384
385    let mut keep = Vec::new();
386    for group_indices in class_map.values() {
387        let group_boxes: Vec<DetectionBox> =
388            group_indices.iter().map(|&i| boxes[i].clone()).collect();
389        let class_keep = nms(&group_boxes, iou_threshold);
390        for local_idx in class_keep {
391            keep.push(group_indices[local_idx]);
392        }
393    }
394    // Sort by confidence descending for deterministic output
395    keep.sort_by(|&a, &b| {
396        boxes[b]
397            .confidence
398            .partial_cmp(&boxes[a].confidence)
399            .unwrap_or(std::cmp::Ordering::Equal)
400    });
401    keep
402}
403
404/// Weighted NMS: merges overlapping boxes by confidence-weighted averaging.
405///
406/// For each cluster of overlapping boxes (IoU > threshold), produces a single
407/// box whose coordinates are the confidence-weighted mean of the cluster.
408pub fn weighted_nms(boxes: &[DetectionBox], iou_threshold: f64) -> Vec<DetectionBox> {
409    if boxes.is_empty() {
410        return Vec::new();
411    }
412
413    let mut indices: Vec<usize> = (0..boxes.len()).collect();
414    indices.sort_by(|&a, &b| {
415        boxes[b]
416            .confidence
417            .partial_cmp(&boxes[a].confidence)
418            .unwrap_or(std::cmp::Ordering::Equal)
419    });
420
421    let mut used = vec![false; boxes.len()];
422    let mut result = Vec::new();
423
424    for &idx in &indices {
425        if used[idx] {
426            continue;
427        }
428        // Collect cluster
429        let mut cluster: Vec<usize> = vec![idx];
430        for &other in &indices {
431            if other != idx && !used[other] && boxes[idx].iou(&boxes[other]) > iou_threshold {
432                cluster.push(other);
433            }
434        }
435        // Mark all as used
436        for &c in &cluster {
437            used[c] = true;
438        }
439
440        // Weighted merge
441        let total_conf: f64 = cluster.iter().map(|&c| boxes[c].confidence).sum();
442        if total_conf <= 0.0 {
443            result.push(boxes[idx].clone());
444            continue;
445        }
446        let mut wx1 = 0.0;
447        let mut wy1 = 0.0;
448        let mut wx2 = 0.0;
449        let mut wy2 = 0.0;
450        for &c in &cluster {
451            let w = boxes[c].confidence;
452            wx1 += boxes[c].x1 * w;
453            wy1 += boxes[c].y1 * w;
454            wx2 += boxes[c].x2 * w;
455            wy2 += boxes[c].y2 * w;
456        }
457        let merged = DetectionBox {
458            x1: wx1 / total_conf,
459            y1: wy1 / total_conf,
460            x2: wx2 / total_conf,
461            y2: wy2 / total_conf,
462            confidence: boxes[idx].confidence, // keep max confidence
463            class_id: boxes[idx].class_id,
464            class_name: boxes[idx].class_name.clone(),
465        };
466        result.push(merged);
467    }
468    result
469}
470
471// ---------------------------------------------------------------------------
472// Anchor Generation
473// ---------------------------------------------------------------------------
474
/// Configuration for anchor generation across feature map levels.
///
/// Consumed by [`generate_anchors`]. All sizes are in the same pixel units
/// as `image_size`.
#[derive(Clone, Debug)]
pub struct AnchorConfig {
    /// Feature map spatial sizes `(width, height)` at each level,
    /// e.g. `[(38,38), (19,19), (10,10)]`.
    pub feature_map_sizes: Vec<(usize, usize)>,
    /// Aspect ratios (width / height) generated at every cell, e.g. `[0.5, 1.0, 2.0]`.
    pub aspect_ratios: Vec<f64>,
    /// Base anchor size in pixels for each level (same length as `feature_map_sizes`).
    /// For aspect ratio `ar`, an anchor is `scale * sqrt(ar)` wide and
    /// `scale / sqrt(ar)` tall. This is an absolute size, not a multiplier.
    pub scales: Vec<f64>,
    /// Original image size `(width, height)`.
    pub image_size: (usize, usize),
}
487
488/// Generate anchors for all feature map levels according to `config`.
489///
490/// For each feature map cell and each (scale, aspect_ratio) combination,
491/// an anchor centred on that cell is produced with the given scale and aspect ratio.
492pub fn generate_anchors(config: &AnchorConfig) -> Result<Vec<DetectionBox>> {
493    if config.feature_map_sizes.is_empty() {
494        return Err(VisionError::InvalidParameter(
495            "feature_map_sizes must not be empty".into(),
496        ));
497    }
498    if config.aspect_ratios.is_empty() {
499        return Err(VisionError::InvalidParameter(
500            "aspect_ratios must not be empty".into(),
501        ));
502    }
503    if config.scales.len() != config.feature_map_sizes.len() {
504        return Err(VisionError::InvalidParameter(
505            "scales length must match feature_map_sizes length".into(),
506        ));
507    }
508
509    let (img_w, img_h) = config.image_size;
510    let img_w = img_w as f64;
511    let img_h = img_h as f64;
512
513    let mut anchors = Vec::new();
514
515    for (level, &(fm_w, fm_h)) in config.feature_map_sizes.iter().enumerate() {
516        if fm_w == 0 || fm_h == 0 {
517            continue;
518        }
519        let step_x = img_w / fm_w as f64;
520        let step_y = img_h / fm_h as f64;
521        let scale = config.scales[level];
522
523        for row in 0..fm_h {
524            for col in 0..fm_w {
525                let cx = (col as f64 + 0.5) * step_x;
526                let cy = (row as f64 + 0.5) * step_y;
527
528                for &ar in &config.aspect_ratios {
529                    let w = scale * ar.sqrt();
530                    let h = scale / ar.sqrt();
531                    anchors.push(DetectionBox::from_center(cx, cy, w, h));
532                }
533            }
534        }
535    }
536    Ok(anchors)
537}
538
539/// Generate SSD-style anchors.
540///
541/// Uses default aspect ratios `[1.0, 2.0, 0.5]` and derives scales from the
542/// ratio between the image size and each feature map size.
543pub fn generate_ssd_anchors(
544    image_size: (usize, usize),
545    feature_maps: &[(usize, usize)],
546) -> Result<Vec<DetectionBox>> {
547    if feature_maps.is_empty() {
548        return Err(VisionError::InvalidParameter(
549            "feature_maps must not be empty".into(),
550        ));
551    }
552    let aspect_ratios = vec![1.0, 2.0, 0.5, 3.0, 1.0 / 3.0];
553    let min_scale = 0.2;
554    let max_scale = 0.9;
555    let num_levels = feature_maps.len();
556    let scales: Vec<f64> = (0..num_levels)
557        .map(|k| {
558            let s = if num_levels > 1 {
559                min_scale + (max_scale - min_scale) * (k as f64) / ((num_levels - 1) as f64)
560            } else {
561                (min_scale + max_scale) / 2.0
562            };
563            s * image_size.0.min(image_size.1) as f64
564        })
565        .collect();
566
567    generate_anchors(&AnchorConfig {
568        feature_map_sizes: feature_maps.to_vec(),
569        aspect_ratios,
570        scales,
571        image_size,
572    })
573}
574
575/// Generate YOLO-style anchors from pre-defined anchor dimensions.
576///
577/// Each entry in `anchor_wh` is `(width, height)` in pixel space. The anchors
578/// are placed at every grid cell of the specified feature map. Returns one
579/// `DetectionBox` per (cell, anchor) pair.
580pub fn generate_yolo_anchors(
581    image_size: (usize, usize),
582    feature_map: (usize, usize),
583    anchor_wh: &[(f64, f64)],
584) -> Result<Vec<DetectionBox>> {
585    if anchor_wh.is_empty() {
586        return Err(VisionError::InvalidParameter(
587            "anchor_wh must not be empty".into(),
588        ));
589    }
590    let (fm_w, fm_h) = feature_map;
591    if fm_w == 0 || fm_h == 0 {
592        return Err(VisionError::InvalidParameter(
593            "feature map dimensions must be > 0".into(),
594        ));
595    }
596    let step_x = image_size.0 as f64 / fm_w as f64;
597    let step_y = image_size.1 as f64 / fm_h as f64;
598
599    let mut anchors = Vec::new();
600    for row in 0..fm_h {
601        for col in 0..fm_w {
602            let cx = (col as f64 + 0.5) * step_x;
603            let cy = (row as f64 + 0.5) * step_y;
604            for &(aw, ah) in anchor_wh {
605                anchors.push(DetectionBox::from_center(cx, cy, aw, ah));
606            }
607        }
608    }
609    Ok(anchors)
610}
611
612// ---------------------------------------------------------------------------
613// Detection Metrics
614// ---------------------------------------------------------------------------
615
616/// Compute Average Precision (AP) for a single class using the all-points interpolation method.
617///
618/// Both `predictions` and `ground_truth` should belong to the same class.
619/// A prediction is considered a true positive if it has IoU > `iou_threshold`
620/// with a ground truth box that has not already been matched.
621pub fn compute_ap(
622    predictions: &[DetectionBox],
623    ground_truth: &[DetectionBox],
624    iou_threshold: f64,
625) -> f64 {
626    if ground_truth.is_empty() {
627        return 0.0;
628    }
629    if predictions.is_empty() {
630        return 0.0;
631    }
632
633    let (precisions, recalls) = precision_recall_curve(predictions, ground_truth, iou_threshold);
634    if precisions.is_empty() {
635        return 0.0;
636    }
637
638    // All-points interpolation (PASCAL VOC 2010+ / COCO style)
639    ap_from_pr(&precisions, &recalls)
640}
641
642/// Compute mean Average Precision (mAP) across multiple images/classes and IoU thresholds.
643///
644/// `predictions[i]` and `ground_truth[i]` correspond to the same image/class.
645/// The function computes AP for each (image, threshold) pair and returns
646/// the mean over all.
647pub fn compute_map(
648    predictions: &[Vec<DetectionBox>],
649    ground_truth: &[Vec<DetectionBox>],
650    iou_thresholds: &[f64],
651) -> f64 {
652    if predictions.is_empty() || ground_truth.is_empty() || iou_thresholds.is_empty() {
653        return 0.0;
654    }
655    let n = predictions.len().min(ground_truth.len());
656    let mut total_ap = 0.0;
657    let mut count = 0usize;
658
659    for threshold in iou_thresholds {
660        for i in 0..n {
661            total_ap += compute_ap(&predictions[i], &ground_truth[i], *threshold);
662            count += 1;
663        }
664    }
665    if count == 0 {
666        0.0
667    } else {
668        total_ap / count as f64
669    }
670}
671
672/// Compute the precision-recall curve for a single class.
673///
674/// Returns `(precisions, recalls)` sorted by decreasing confidence.
675pub fn precision_recall_curve(
676    predictions: &[DetectionBox],
677    ground_truth: &[DetectionBox],
678    iou_threshold: f64,
679) -> (Vec<f64>, Vec<f64>) {
680    if ground_truth.is_empty() || predictions.is_empty() {
681        return (Vec::new(), Vec::new());
682    }
683
684    // Sort predictions by descending confidence
685    let mut sorted_preds: Vec<(usize, f64)> = predictions
686        .iter()
687        .enumerate()
688        .map(|(i, p)| (i, p.confidence))
689        .collect();
690    sorted_preds.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
691
692    let total_gt = ground_truth.len() as f64;
693    let mut matched_gt = vec![false; ground_truth.len()];
694    let mut tp = 0.0_f64;
695    let mut fp = 0.0_f64;
696    let mut precisions = Vec::with_capacity(sorted_preds.len());
697    let mut recalls = Vec::with_capacity(sorted_preds.len());
698
699    for &(pred_idx, _conf) in &sorted_preds {
700        let pred = &predictions[pred_idx];
701
702        // Find the best matching GT box
703        let mut best_iou = 0.0;
704        let mut best_gt: Option<usize> = None;
705        for (gt_idx, gt) in ground_truth.iter().enumerate() {
706            if matched_gt[gt_idx] {
707                continue;
708            }
709            let iou_val = pred.iou(gt);
710            if iou_val > best_iou {
711                best_iou = iou_val;
712                best_gt = Some(gt_idx);
713            }
714        }
715
716        if best_iou >= iou_threshold {
717            if let Some(gt_idx) = best_gt {
718                tp += 1.0;
719                matched_gt[gt_idx] = true;
720            } else {
721                fp += 1.0;
722            }
723        } else {
724            fp += 1.0;
725        }
726
727        let precision = if (tp + fp) > 0.0 { tp / (tp + fp) } else { 0.0 };
728        let recall = tp / total_gt;
729        precisions.push(precision);
730        recalls.push(recall);
731    }
732
733    (precisions, recalls)
734}
735
/// Compute AP from precision-recall arrays using all-points interpolation.
///
/// `precisions[i]` and `recalls[i]` describe the same point on the curve. If
/// the slices differ in length, only the common prefix is used. Returns `0.0`
/// if either slice is empty.
fn ap_from_pr(precisions: &[f64], recalls: &[f64]) -> f64 {
    if precisions.is_empty() || recalls.is_empty() {
        return 0.0;
    }

    // Sentinels as in the PASCAL VOC reference code: prepend (recall=0,
    // precision=0) and append (recall=1, precision=0). The leading precision
    // is raised by the envelope pass below. The trailing point adds no area.
    let n = precisions.len().min(recalls.len());
    let mut mrec = Vec::with_capacity(n + 2);
    let mut mprec = Vec::with_capacity(n + 2);
    mrec.push(0.0);
    mprec.push(0.0);
    for (&p, &r) in precisions.iter().zip(recalls) {
        mrec.push(r);
        mprec.push(p);
    }
    mrec.push(1.0);
    mprec.push(0.0);

    // Make precision monotonically non-increasing (right-to-left envelope).
    for i in (0..mprec.len() - 1).rev() {
        if mprec[i + 1] > mprec[i] {
            mprec[i] = mprec[i + 1];
        }
    }

    // Sum the rectangular areas at the points where recall changes.
    let mut ap = 0.0;
    for i in 1..mrec.len() {
        if (mrec[i] - mrec[i - 1]).abs() > 1e-15 {
            ap += (mrec[i] - mrec[i - 1]) * mprec[i];
        }
    }
    ap
}
771
772/// Compute a simple confusion matrix for detection results.
773///
774/// Returns a 2D map: `class_id -> class_id -> count`, where the first key
775/// is the ground truth class and the second key is the predicted class.
776/// Unmatched ground truths appear under predicted class `usize::MAX` (missed).
777/// False positive predictions appear under ground truth class `usize::MAX`.
778pub fn confusion_matrix(
779    predictions: &[DetectionBox],
780    ground_truth: &[DetectionBox],
781    iou_threshold: f64,
782) -> HashMap<usize, HashMap<usize, usize>> {
783    let mut matrix: HashMap<usize, HashMap<usize, usize>> = HashMap::new();
784    let mut matched_gt = vec![false; ground_truth.len()];
785
786    // Sort predictions by descending confidence
787    let mut sorted_indices: Vec<usize> = (0..predictions.len()).collect();
788    sorted_indices.sort_by(|&a, &b| {
789        predictions[b]
790            .confidence
791            .partial_cmp(&predictions[a].confidence)
792            .unwrap_or(std::cmp::Ordering::Equal)
793    });
794
795    for &pred_idx in &sorted_indices {
796        let pred = &predictions[pred_idx];
797        let mut best_iou = 0.0_f64;
798        let mut best_gt: Option<usize> = None;
799
800        for (gt_idx, gt) in ground_truth.iter().enumerate() {
801            if matched_gt[gt_idx] {
802                continue;
803            }
804            let iou_val = pred.iou(gt);
805            if iou_val > best_iou {
806                best_iou = iou_val;
807                best_gt = Some(gt_idx);
808            }
809        }
810
811        if best_iou >= iou_threshold {
812            if let Some(gt_idx) = best_gt {
813                matched_gt[gt_idx] = true;
814                let gt_class = ground_truth[gt_idx].class_id;
815                let pred_class = pred.class_id;
816                *matrix
817                    .entry(gt_class)
818                    .or_default()
819                    .entry(pred_class)
820                    .or_insert(0) += 1;
821            }
822        } else {
823            // False positive
824            *matrix
825                .entry(usize::MAX)
826                .or_default()
827                .entry(pred.class_id)
828                .or_insert(0) += 1;
829        }
830    }
831
832    // Record unmatched ground truths (missed detections)
833    for (gt_idx, gt) in ground_truth.iter().enumerate() {
834        if !matched_gt[gt_idx] {
835            *matrix
836                .entry(gt.class_id)
837                .or_default()
838                .entry(usize::MAX)
839                .or_insert(0) += 1;
840        }
841    }
842
843    matrix
844}
845
846// ---------------------------------------------------------------------------
847// Augmentation helpers
848// ---------------------------------------------------------------------------
849
850/// Flip bounding boxes horizontally within an image of width `image_width`.
851///
852/// Uses a simple xorshift PRNG seeded with `seed` and flips with `probability` in `[0,1]`.
853/// Modified in place.
854pub fn random_horizontal_flip(
855    boxes: &mut [DetectionBox],
856    image_width: f64,
857    probability: f64,
858    seed: u64,
859) {
860    // Simple xorshift64 PRNG
861    let mut state = if seed == 0 { 0x5DEECE66D } else { seed };
862    let should_flip = {
863        state ^= state << 13;
864        state ^= state >> 7;
865        state ^= state << 17;
866        let rand_val = (state as f64) / (u64::MAX as f64);
867        rand_val.abs() < probability
868    };
869
870    if should_flip {
871        for b in boxes.iter_mut() {
872            let new_x1 = image_width - b.x2;
873            let new_x2 = image_width - b.x1;
874            b.x1 = new_x1;
875            b.x2 = new_x2;
876        }
877    }
878}
879
880/// Return boxes that are fully or partially inside `crop_box`, with coordinates
881/// adjusted relative to the crop region.
882///
883/// Boxes that fall entirely outside the crop region are discarded.
884/// Partially visible boxes are clipped to the crop boundary.
885pub fn random_crop_with_boxes(
886    boxes: &[DetectionBox],
887    crop_box: &DetectionBox,
888) -> Vec<DetectionBox> {
889    let mut result = Vec::new();
890    for b in boxes {
891        // Clip to crop region
892        let clipped_x1 = b.x1.max(crop_box.x1);
893        let clipped_y1 = b.y1.max(crop_box.y1);
894        let clipped_x2 = b.x2.min(crop_box.x2);
895        let clipped_y2 = b.y2.min(crop_box.y2);
896
897        // Discard if no intersection
898        if clipped_x1 >= clipped_x2 || clipped_y1 >= clipped_y2 {
899            continue;
900        }
901
902        // Translate to crop-relative coordinates
903        result.push(DetectionBox {
904            x1: clipped_x1 - crop_box.x1,
905            y1: clipped_y1 - crop_box.y1,
906            x2: clipped_x2 - crop_box.x1,
907            y2: clipped_y2 - crop_box.y1,
908            confidence: b.confidence,
909            class_id: b.class_id,
910            class_name: b.class_name.clone(),
911        });
912    }
913    result
914}
915
916/// Scale bounding box coordinates by `(sx, sy)`.
917pub fn scale_boxes(boxes: &mut [DetectionBox], sx: f64, sy: f64) {
918    for b in boxes.iter_mut() {
919        b.x1 *= sx;
920        b.y1 *= sy;
921        b.x2 *= sx;
922        b.y2 *= sy;
923    }
924}
925
926/// Translate bounding box coordinates by `(tx, ty)`.
927pub fn translate_boxes(boxes: &mut [DetectionBox], tx: f64, ty: f64) {
928    for b in boxes.iter_mut() {
929        b.x1 += tx;
930        b.y1 += ty;
931        b.x2 += tx;
932        b.y2 += ty;
933    }
934}
935
936/// Clip bounding boxes to image boundaries `[0, image_width] x [0, image_height]`.
937///
938/// Boxes that end up with zero or negative area after clipping are removed.
939pub fn clip_boxes(boxes: &mut Vec<DetectionBox>, image_width: f64, image_height: f64) {
940    for b in boxes.iter_mut() {
941        b.x1 = b.x1.max(0.0).min(image_width);
942        b.y1 = b.y1.max(0.0).min(image_height);
943        b.x2 = b.x2.max(0.0).min(image_width);
944        b.y2 = b.y2.max(0.0).min(image_height);
945    }
946    boxes.retain(|b| b.width() > 0.0 && b.height() > 0.0);
947}
948
949/// Filter detection boxes by confidence threshold, returning only those above.
950pub fn filter_by_confidence(boxes: &[DetectionBox], threshold: f64) -> Vec<DetectionBox> {
951    boxes
952        .iter()
953        .filter(|b| b.confidence >= threshold)
954        .cloned()
955        .collect()
956}
957
958/// Convert a slice of `DetectionBox` to `(x1, y1, x2, y2, confidence, class_id)` tuples.
959pub fn boxes_to_tuples(boxes: &[DetectionBox]) -> Vec<(f64, f64, f64, f64, f64, usize)> {
960    boxes
961        .iter()
962        .map(|b| (b.x1, b.y1, b.x2, b.y2, b.confidence, b.class_id))
963        .collect()
964}
965
966// ---------------------------------------------------------------------------
967// Tests
968// ---------------------------------------------------------------------------
969
#[cfg(test)]
mod tests {
    // Unit tests for the detection utilities in this module: box geometry,
    // IoU variants, NMS, anchor generation, detection metrics and box-aware
    // augmentation helpers.
    use super::*;

    /// Absolute-tolerance float comparison: true when `|a - b| < eps`.
    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        (a - b).abs() < eps
    }

    // -----------------------------------------------------------------------
    // DetectionBox basics
    // -----------------------------------------------------------------------

    #[test]
    fn test_new_normalises_coordinates() {
        // Corners passed in reverse order must come back with x1 <= x2, y1 <= y2.
        let b = DetectionBox::new(50.0, 40.0, 10.0, 20.0);
        assert!(b.x1 <= b.x2);
        assert!(b.y1 <= b.y2);
        assert_eq!(b.x1, 10.0);
        assert_eq!(b.y1, 20.0);
        assert_eq!(b.x2, 50.0);
        assert_eq!(b.y2, 40.0);
    }

    #[test]
    fn test_from_center() {
        // Arguments are (cx, cy, w, h): a 40x20 box centred on (100, 100).
        let b = DetectionBox::from_center(100.0, 100.0, 40.0, 20.0);
        assert!(approx_eq(b.x1, 80.0, 1e-10));
        assert!(approx_eq(b.y1, 90.0, 1e-10));
        assert!(approx_eq(b.x2, 120.0, 1e-10));
        assert!(approx_eq(b.y2, 110.0, 1e-10));
    }

    #[test]
    fn test_area() {
        let b = DetectionBox::new(0.0, 0.0, 10.0, 20.0);
        assert!(approx_eq(b.area(), 200.0, 1e-10));
    }

    #[test]
    fn test_center() {
        let b = DetectionBox::new(10.0, 20.0, 30.0, 40.0);
        let (cx, cy) = b.center();
        assert!(approx_eq(cx, 20.0, 1e-10));
        assert!(approx_eq(cy, 30.0, 1e-10));
    }

    #[test]
    fn test_width_height() {
        let b = DetectionBox::new(5.0, 10.0, 25.0, 40.0);
        assert!(approx_eq(b.width(), 20.0, 1e-10));
        assert!(approx_eq(b.height(), 30.0, 1e-10));
    }

    #[test]
    fn test_aspect_ratio() {
        // Aspect ratio is width / height: 20 / 10 = 2.
        let b = DetectionBox::new(0.0, 0.0, 20.0, 10.0);
        assert!(approx_eq(b.aspect_ratio(), 2.0, 1e-10));
    }

    #[test]
    fn test_contains_point() {
        // Points on the border (here the top-left corner) count as inside.
        let b = DetectionBox::new(10.0, 10.0, 50.0, 50.0);
        assert!(b.contains_point(30.0, 30.0));
        assert!(b.contains_point(10.0, 10.0));
        assert!(!b.contains_point(5.0, 5.0));
        assert!(!b.contains_point(55.0, 55.0));
    }

    #[test]
    fn test_builder_methods() {
        let b = DetectionBox::new(0.0, 0.0, 10.0, 10.0)
            .with_confidence(0.95)
            .with_class(3, Some("person".to_string()));
        assert!(approx_eq(b.confidence, 0.95, 1e-10));
        assert_eq!(b.class_id, 3);
        assert_eq!(b.class_name.as_deref(), Some("person"));
    }

    // -----------------------------------------------------------------------
    // IoU variants
    // -----------------------------------------------------------------------

    #[test]
    fn test_iou_identical() {
        let b = DetectionBox::new(10.0, 10.0, 50.0, 50.0);
        assert!(approx_eq(b.iou(&b), 1.0, 1e-10));
    }

    #[test]
    fn test_iou_no_overlap() {
        let a = DetectionBox::new(0.0, 0.0, 10.0, 10.0);
        let b = DetectionBox::new(20.0, 20.0, 30.0, 30.0);
        assert!(approx_eq(a.iou(&b), 0.0, 1e-10));
    }

    #[test]
    fn test_iou_partial_overlap() {
        let a = DetectionBox::new(0.0, 0.0, 10.0, 10.0);
        let b = DetectionBox::new(5.0, 5.0, 15.0, 15.0);
        // Intersection: 5x5 = 25. Union: 100 + 100 - 25 = 175. IoU = 25/175
        let expected = 25.0 / 175.0;
        assert!(approx_eq(a.iou(&b), expected, 1e-10));
    }

    #[test]
    fn test_giou_identical() {
        let b = DetectionBox::new(10.0, 10.0, 50.0, 50.0);
        assert!(approx_eq(b.giou(&b), 1.0, 1e-10));
    }

    #[test]
    fn test_giou_no_overlap() {
        let a = DetectionBox::new(0.0, 0.0, 10.0, 10.0);
        let b = DetectionBox::new(20.0, 20.0, 30.0, 30.0);
        // GIoU should be negative when boxes are far apart
        let val = a.giou(&b);
        assert!(val < 0.0);
    }

    #[test]
    fn test_giou_range() {
        // GIoU should be in [-1, 1]
        let a = DetectionBox::new(0.0, 0.0, 1.0, 1.0);
        let b = DetectionBox::new(100.0, 100.0, 101.0, 101.0);
        let val = a.giou(&b);
        assert!((-1.0..=1.0).contains(&val));
    }

    #[test]
    fn test_diou_identical() {
        let b = DetectionBox::new(10.0, 10.0, 50.0, 50.0);
        assert!(approx_eq(b.diou(&b), 1.0, 1e-10));
    }

    #[test]
    fn test_ciou_identical() {
        let b = DetectionBox::new(10.0, 10.0, 50.0, 50.0);
        assert!(approx_eq(b.ciou(&b), 1.0, 1e-10));
    }

    #[test]
    fn test_ciou_different_aspect_ratios() {
        let a = DetectionBox::new(0.0, 0.0, 100.0, 50.0); // 2:1
        let b = DetectionBox::new(0.0, 0.0, 50.0, 100.0); // 1:2
        let ciou_val = a.ciou(&b);
        // CIoU penalises aspect-ratio difference, so it should be lower than IoU
        let iou_val = a.iou(&b);
        assert!(ciou_val <= iou_val);
    }

    // -----------------------------------------------------------------------
    // NMS
    // -----------------------------------------------------------------------

    #[test]
    fn test_nms_basic() {
        // `nms` returns the indices (into `boxes`) of the boxes it keeps.
        let boxes = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            DetectionBox::new(1.0, 1.0, 11.0, 11.0).with_confidence(0.7),
            DetectionBox::new(200.0, 200.0, 210.0, 210.0).with_confidence(0.8),
        ];
        let kept = nms(&boxes, 0.5);
        assert_eq!(kept.len(), 2);
        assert!(kept.contains(&0)); // highest confidence overlapping
        assert!(kept.contains(&2)); // non-overlapping
    }

    #[test]
    fn test_nms_empty() {
        let kept = nms(&[], 0.5);
        assert!(kept.is_empty());
    }

    #[test]
    fn test_nms_no_suppression() {
        let boxes = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            DetectionBox::new(100.0, 100.0, 110.0, 110.0).with_confidence(0.8),
        ];
        let kept = nms(&boxes, 0.5);
        assert_eq!(kept.len(), 2);
    }

    #[test]
    fn test_soft_nms_reduces_scores() {
        // `soft_nms` takes `&mut`, so the score decay is observed on `boxes` itself.
        // NOTE(review): the roles of the 0.5 and 0.01 arguments are not visible
        // from here (IoU threshold / sigma / score cut-off?) — confirm against
        // the function signature.
        let mut boxes = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            DetectionBox::new(1.0, 1.0, 11.0, 11.0).with_confidence(0.8),
            DetectionBox::new(200.0, 200.0, 210.0, 210.0).with_confidence(0.85),
        ];
        let original_conf = boxes[1].confidence;
        let kept = soft_nms(&mut boxes, 0.5, 0.01);
        // Overlapping box should have reduced confidence
        assert!(boxes[1].confidence < original_conf);
        assert!(!kept.is_empty());
    }

    #[test]
    fn test_batched_nms_separates_classes() {
        let boxes = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0)
                .with_confidence(0.9)
                .with_class(0, None),
            DetectionBox::new(1.0, 1.0, 11.0, 11.0)
                .with_confidence(0.8)
                .with_class(1, None), // different class - should not suppress
        ];
        let kept = batched_nms(&boxes, 0.5);
        assert_eq!(kept.len(), 2); // both kept since different classes
    }

    #[test]
    fn test_batched_nms_suppresses_same_class() {
        let boxes = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0)
                .with_confidence(0.9)
                .with_class(0, None),
            DetectionBox::new(1.0, 1.0, 11.0, 11.0)
                .with_confidence(0.7)
                .with_class(0, None),
        ];
        let kept = batched_nms(&boxes, 0.5);
        assert_eq!(kept.len(), 1);
    }

    #[test]
    fn test_weighted_nms_merges() {
        let boxes = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            DetectionBox::new(1.0, 1.0, 11.0, 11.0).with_confidence(0.8),
        ];
        let merged = weighted_nms(&boxes, 0.3);
        assert_eq!(merged.len(), 1);
        // Merged box should be a weighted average
        assert!(merged[0].x1 > 0.0 && merged[0].x1 < 1.0);
    }

    // -----------------------------------------------------------------------
    // Anchor Generation
    // -----------------------------------------------------------------------

    #[test]
    fn test_generate_anchors_basic() {
        let config = AnchorConfig {
            feature_map_sizes: vec![(4, 4)],
            aspect_ratios: vec![1.0, 2.0],
            scales: vec![32.0],
            image_size: (128, 128),
        };
        let anchors = generate_anchors(&config);
        assert!(anchors.is_ok());
        let anchors = anchors.expect("should succeed");
        // 4*4 cells * 2 aspect ratios = 32 anchors
        assert_eq!(anchors.len(), 32);
    }

    #[test]
    fn test_generate_anchors_empty_feature_maps() {
        // NOTE(review): `scales` is also empty here, so this test cannot tell
        // whether the error comes from the empty feature-map list or from the
        // empty scale list.
        let config = AnchorConfig {
            feature_map_sizes: vec![],
            aspect_ratios: vec![1.0],
            scales: vec![],
            image_size: (100, 100),
        };
        let result = generate_anchors(&config);
        assert!(result.is_err());
    }

    #[test]
    fn test_generate_ssd_anchors() {
        let anchors = generate_ssd_anchors((300, 300), &[(38, 38), (19, 19), (10, 10)]);
        assert!(anchors.is_ok());
        let anchors = anchors.expect("should succeed");
        // 38*38*5 + 19*19*5 + 10*10*5 = 7220 + 1805 + 500 = 9525
        assert_eq!(anchors.len(), 9525);
    }

    #[test]
    fn test_generate_yolo_anchors() {
        // One anchor per grid cell per (w, h) prior: 13 * 13 cells * 3 priors.
        let anchor_wh = vec![(10.0, 13.0), (16.0, 30.0), (33.0, 23.0)];
        let anchors = generate_yolo_anchors((416, 416), (13, 13), &anchor_wh);
        assert!(anchors.is_ok());
        let anchors = anchors.expect("should succeed");
        assert_eq!(anchors.len(), 13 * 13 * 3);
    }

    #[test]
    fn test_generate_yolo_anchors_empty() {
        let result = generate_yolo_anchors((416, 416), (13, 13), &[]);
        assert!(result.is_err());
    }

    // -----------------------------------------------------------------------
    // Detection Metrics
    // -----------------------------------------------------------------------

    #[test]
    fn test_compute_ap_perfect() {
        let preds = vec![DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9)];
        let gt = vec![DetectionBox::new(0.0, 0.0, 10.0, 10.0)];
        let ap = compute_ap(&preds, &gt, 0.5);
        assert!(approx_eq(ap, 1.0, 1e-10));
    }

    #[test]
    fn test_compute_ap_no_predictions() {
        let gt = vec![DetectionBox::new(0.0, 0.0, 10.0, 10.0)];
        let ap = compute_ap(&[], &gt, 0.5);
        assert!(approx_eq(ap, 0.0, 1e-10));
    }

    #[test]
    fn test_compute_ap_no_ground_truth() {
        let preds = vec![DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9)];
        let ap = compute_ap(&preds, &[], 0.5);
        assert!(approx_eq(ap, 0.0, 1e-10));
    }

    #[test]
    fn test_compute_map() {
        // A perfect match at every IoU threshold should give mAP = 1.0.
        // NOTE(review): the outer Vec level presumably groups detections per
        // class (or per image) — confirm against `compute_map`.
        let preds = vec![vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9)
        ]];
        let gt = vec![vec![DetectionBox::new(0.0, 0.0, 10.0, 10.0)]];
        let thresholds = vec![0.5, 0.75];
        let map_val = compute_map(&preds, &gt, &thresholds);
        assert!(approx_eq(map_val, 1.0, 1e-10));
    }

    #[test]
    fn test_precision_recall_curve_basic() {
        // One (precision, recall) point is produced per prediction.
        let preds = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            DetectionBox::new(100.0, 100.0, 110.0, 110.0).with_confidence(0.5),
        ];
        let gt = vec![DetectionBox::new(0.0, 0.0, 10.0, 10.0)];
        let (prec, rec) = precision_recall_curve(&preds, &gt, 0.5);
        assert_eq!(prec.len(), 2);
        assert_eq!(rec.len(), 2);
        // First prediction matches: precision=1.0, recall=1.0
        assert!(approx_eq(prec[0], 1.0, 1e-10));
        assert!(approx_eq(rec[0], 1.0, 1e-10));
    }

    #[test]
    fn test_confusion_matrix_basic() {
        // `cm` appears to be a nested map gt_class -> pred_class -> count, with
        // unmatched predictions filed under the gt key `usize::MAX`.
        let preds = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0)
                .with_confidence(0.9)
                .with_class(0, None),
            DetectionBox::new(200.0, 200.0, 210.0, 210.0)
                .with_confidence(0.8)
                .with_class(1, None),
        ];
        let gt = vec![DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_class(0, None)];
        let cm = confusion_matrix(&preds, &gt, 0.5);
        // GT class 0 matched by pred class 0
        assert_eq!(*cm.get(&0).and_then(|m| m.get(&0)).unwrap_or(&0), 1);
        // Pred class 1 is a false positive (gt=usize::MAX)
        assert_eq!(
            *cm.get(&usize::MAX).and_then(|m| m.get(&1)).unwrap_or(&0),
            1
        );
    }

    // -----------------------------------------------------------------------
    // Augmentation
    // -----------------------------------------------------------------------

    #[test]
    fn test_scale_boxes() {
        let mut boxes = vec![DetectionBox::new(10.0, 20.0, 30.0, 40.0)];
        scale_boxes(&mut boxes, 2.0, 0.5);
        assert!(approx_eq(boxes[0].x1, 20.0, 1e-10));
        assert!(approx_eq(boxes[0].y1, 10.0, 1e-10));
        assert!(approx_eq(boxes[0].x2, 60.0, 1e-10));
        assert!(approx_eq(boxes[0].y2, 20.0, 1e-10));
    }

    #[test]
    fn test_translate_boxes() {
        let mut boxes = vec![DetectionBox::new(10.0, 20.0, 30.0, 40.0)];
        translate_boxes(&mut boxes, 5.0, -5.0);
        assert!(approx_eq(boxes[0].x1, 15.0, 1e-10));
        assert!(approx_eq(boxes[0].y1, 15.0, 1e-10));
        assert!(approx_eq(boxes[0].x2, 35.0, 1e-10));
        assert!(approx_eq(boxes[0].y2, 35.0, 1e-10));
    }

    #[test]
    fn test_clip_boxes() {
        let mut boxes = vec![
            DetectionBox::new(-5.0, -5.0, 50.0, 50.0),
            DetectionBox::new(-10.0, -10.0, -1.0, -1.0), // fully outside
        ];
        clip_boxes(&mut boxes, 100.0, 100.0);
        assert_eq!(boxes.len(), 1); // second box removed
        assert!(approx_eq(boxes[0].x1, 0.0, 1e-10));
        assert!(approx_eq(boxes[0].y1, 0.0, 1e-10));
    }

    #[test]
    fn test_random_crop_with_boxes() {
        // The crop origin is (0, 0), so surviving coordinates are not shifted.
        let boxes = vec![
            DetectionBox::new(10.0, 10.0, 50.0, 50.0).with_confidence(0.9),
            DetectionBox::new(200.0, 200.0, 250.0, 250.0).with_confidence(0.8), // outside crop
        ];
        let crop = DetectionBox::new(0.0, 0.0, 100.0, 100.0);
        let result = random_crop_with_boxes(&boxes, &crop);
        assert_eq!(result.len(), 1);
        assert!(approx_eq(result[0].x1, 10.0, 1e-10));
        assert!(approx_eq(result[0].confidence, 0.9, 1e-10));
    }

    #[test]
    fn test_random_horizontal_flip_deterministic() {
        // With probability 1.0, flip should always happen
        let mut boxes = vec![DetectionBox::new(10.0, 0.0, 30.0, 20.0)];
        random_horizontal_flip(&mut boxes, 100.0, 1.0, 42);
        // After flip: new_x1 = 100 - 30 = 70, new_x2 = 100 - 10 = 90
        assert!(approx_eq(boxes[0].x1, 70.0, 1e-10));
        assert!(approx_eq(boxes[0].x2, 90.0, 1e-10));
    }

    #[test]
    fn test_filter_by_confidence() {
        let boxes = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.3),
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.6),
        ];
        let filtered = filter_by_confidence(&boxes, 0.5);
        assert_eq!(filtered.len(), 2);
    }

    #[test]
    fn test_boxes_to_tuples() {
        let boxes = vec![DetectionBox::new(1.0, 2.0, 3.0, 4.0)
            .with_confidence(0.5)
            .with_class(7, None)];
        let tuples = boxes_to_tuples(&boxes);
        assert_eq!(tuples.len(), 1);
        let (x1, y1, x2, y2, c, cls) = tuples[0];
        assert!(approx_eq(x1, 1.0, 1e-10));
        assert!(approx_eq(y1, 2.0, 1e-10));
        assert!(approx_eq(x2, 3.0, 1e-10));
        assert!(approx_eq(y2, 4.0, 1e-10));
        assert!(approx_eq(c, 0.5, 1e-10));
        assert_eq!(cls, 7);
    }

    #[test]
    fn test_intersection_area() {
        let a = DetectionBox::new(0.0, 0.0, 10.0, 10.0);
        let b = DetectionBox::new(5.0, 5.0, 15.0, 15.0);
        assert!(approx_eq(a.intersection_area(&b), 25.0, 1e-10));
    }

    #[test]
    fn test_union_area() {
        let a = DetectionBox::new(0.0, 0.0, 10.0, 10.0);
        let b = DetectionBox::new(5.0, 5.0, 15.0, 15.0);
        assert!(approx_eq(a.union_area(&b), 175.0, 1e-10));
    }

    #[test]
    fn test_zero_area_box() {
        let b = DetectionBox::new(5.0, 5.0, 5.0, 5.0);
        assert!(approx_eq(b.area(), 0.0, 1e-10));
        assert!(approx_eq(b.width(), 0.0, 1e-10));
        assert!(approx_eq(b.height(), 0.0, 1e-10));
    }

    #[test]
    fn test_iou_symmetry() {
        let a = DetectionBox::new(0.0, 0.0, 10.0, 10.0);
        let b = DetectionBox::new(5.0, 0.0, 15.0, 10.0);
        assert!(approx_eq(a.iou(&b), b.iou(&a), 1e-12));
    }

    #[test]
    fn test_nms_all_identical() {
        let boxes = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.8),
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.7),
        ];
        let kept = nms(&boxes, 0.5);
        // Only the highest-confidence box survives
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0], 0);
    }

    #[test]
    fn test_compute_ap_mixed() {
        // Two GT boxes, three predictions: one TP, one FP, one TP
        let preds = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9), // TP
            DetectionBox::new(500.0, 500.0, 510.0, 510.0).with_confidence(0.8), // FP
            DetectionBox::new(100.0, 100.0, 110.0, 110.0).with_confidence(0.7), // TP
        ];
        let gt = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0),
            DetectionBox::new(100.0, 100.0, 110.0, 110.0),
        ];
        let ap = compute_ap(&preds, &gt, 0.5);
        // AP should be reasonably high (2 out of 3 correct, ordered well)
        assert!(ap > 0.5);
    }

    #[test]
    fn test_anchor_centres_are_within_image() {
        // With 8x8 cells over a 256x256 image, every anchor centre must fall
        // strictly inside the image.
        let config = AnchorConfig {
            feature_map_sizes: vec![(8, 8)],
            aspect_ratios: vec![1.0],
            scales: vec![16.0],
            image_size: (256, 256),
        };
        let anchors = generate_anchors(&config).expect("should succeed");
        for a in &anchors {
            let (cx, cy) = a.center();
            assert!(cx > 0.0 && cx < 256.0, "cx={cx} out of range");
            assert!(cy > 0.0 && cy < 256.0, "cy={cy} out of range");
        }
    }
}