1use crate::error::{Result, VisionError};
41use std::collections::HashMap;
42
/// An axis-aligned detection box with a confidence score and a class label.
///
/// Corners are stored as `(x1, y1)`–`(x2, y2)`. [`DetectionBox::new`] and
/// [`DetectionBox::from_center`] guarantee `x1 <= x2` and `y1 <= y2`. The
/// fields are public, so code that writes them directly must keep that
/// ordering; otherwise the area-based metrics below see the box as empty.
#[derive(Clone, Debug, PartialEq)]
pub struct DetectionBox {
    /// Smaller x coordinate.
    pub x1: f64,
    /// Smaller y coordinate.
    pub y1: f64,
    /// Larger x coordinate.
    pub x2: f64,
    /// Larger y coordinate.
    pub y2: f64,
    /// Detection score (`0.0` unless set, e.g. via [`DetectionBox::with_confidence`]).
    pub confidence: f64,
    /// Numeric class label (`0` unless set).
    pub class_id: usize,
    /// Optional human-readable class label.
    pub class_name: Option<String>,
}

impl DetectionBox {
    /// Creates a box from two opposite corners given in any order.
    ///
    /// The corners are re-ordered so that `x1 <= x2` and `y1 <= y2`.
    /// Confidence starts at `0.0` and class at `0` with no name.
    pub fn new(x1: f64, y1: f64, x2: f64, y2: f64) -> Self {
        Self {
            x1: x1.min(x2),
            y1: y1.min(y2),
            x2: x1.max(x2),
            y2: y1.max(y2),
            confidence: 0.0,
            class_id: 0,
            class_name: None,
        }
    }

    /// Creates a box from its centre `(cx, cy)` and size `(w, h)`.
    ///
    /// Negative sizes are treated as their magnitude.
    pub fn from_center(cx: f64, cy: f64, w: f64, h: f64) -> Self {
        let half_w = w.abs() / 2.0;
        let half_h = h.abs() / 2.0;
        Self::new(cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    }

    /// Returns the box with its confidence replaced by `confidence`.
    #[must_use]
    pub fn with_confidence(mut self, confidence: f64) -> Self {
        self.confidence = confidence;
        self
    }

    /// Returns the box with its class id and optional class name replaced.
    #[must_use]
    pub fn with_class(mut self, class_id: usize, class_name: Option<String>) -> Self {
        self.class_id = class_id;
        self.class_name = class_name;
        self
    }

    /// Area of the box; `0.0` if either side is inverted or degenerate.
    #[inline]
    pub fn area(&self) -> f64 {
        self.width() * self.height()
    }

    /// Centre point `(cx, cy)`.
    #[inline]
    pub fn center(&self) -> (f64, f64) {
        ((self.x1 + self.x2) / 2.0, (self.y1 + self.y2) / 2.0)
    }

    /// Extent along x, clamped at `0.0`.
    #[inline]
    pub fn width(&self) -> f64 {
        (self.x2 - self.x1).max(0.0)
    }

    /// Extent along y, clamped at `0.0`.
    #[inline]
    pub fn height(&self) -> f64 {
        (self.y2 - self.y1).max(0.0)
    }

    /// Width divided by height; `0.0` when the height is zero.
    #[inline]
    pub fn aspect_ratio(&self) -> f64 {
        let h = self.height();
        if h == 0.0 {
            0.0
        } else {
            self.width() / h
        }
    }

    /// Area of the overlap between `self` and `other` (`0.0` when disjoint).
    pub fn intersection_area(&self, other: &DetectionBox) -> f64 {
        let ix1 = self.x1.max(other.x1);
        let iy1 = self.y1.max(other.y1);
        let ix2 = self.x2.min(other.x2);
        let iy2 = self.y2.min(other.y2);
        (ix2 - ix1).max(0.0) * (iy2 - iy1).max(0.0)
    }

    /// Area covered by `self`, `other`, or both.
    pub fn union_area(&self, other: &DetectionBox) -> f64 {
        self.intersection_and_union(other).1
    }

    /// Intersection and union areas, computing the intersection only once.
    fn intersection_and_union(&self, other: &DetectionBox) -> (f64, f64) {
        let inter = self.intersection_area(other);
        (inter, self.area() + other.area() - inter)
    }

    /// Corners `(x1, y1, x2, y2)` of the smallest box enclosing both boxes.
    fn enclosing_corners(&self, other: &DetectionBox) -> (f64, f64, f64, f64) {
        (
            self.x1.min(other.x1),
            self.y1.min(other.y1),
            self.x2.max(other.x2),
            self.y2.max(other.y2),
        )
    }

    /// Squared length of the diagonal of the enclosing box.
    fn enclosing_diagonal_sq(&self, other: &DetectionBox) -> f64 {
        let (ex1, ey1, ex2, ey2) = self.enclosing_corners(other);
        (ex2 - ex1).powi(2) + (ey2 - ey1).powi(2)
    }

    /// Squared Euclidean distance between the two box centres.
    fn center_distance_sq(&self, other: &DetectionBox) -> f64 {
        let (cx1, cy1) = self.center();
        let (cx2, cy2) = other.center();
        (cx1 - cx2).powi(2) + (cy1 - cy2).powi(2)
    }

    /// Intersection over union; `0.0` when the union area is not positive.
    pub fn iou(&self, other: &DetectionBox) -> f64 {
        let (inter, union) = self.intersection_and_union(other);
        if union <= 0.0 {
            return 0.0;
        }
        inter / union
    }

    /// Generalised IoU: IoU minus the fraction of the enclosing box not
    /// covered by the union. Lies in `[-1, 1]`; `0.0` when the union area is
    /// not positive.
    pub fn giou(&self, other: &DetectionBox) -> f64 {
        let (inter, union) = self.intersection_and_union(other);
        if union <= 0.0 {
            return 0.0;
        }

        let (ex1, ey1, ex2, ey2) = self.enclosing_corners(other);
        let enc_area = (ex2 - ex1).max(0.0) * (ey2 - ey1).max(0.0);

        let iou_val = inter / union;
        if enc_area <= 0.0 {
            return iou_val;
        }
        iou_val - (enc_area - union) / enc_area
    }

    /// Distance IoU: IoU minus the squared centre distance divided by the
    /// squared diagonal of the enclosing box. `0.0` when the union area is
    /// not positive.
    pub fn diou(&self, other: &DetectionBox) -> f64 {
        let (inter, union) = self.intersection_and_union(other);
        if union <= 0.0 {
            return 0.0;
        }
        let iou_val = inter / union;

        let c_sq = self.enclosing_diagonal_sq(other);
        if c_sq <= 0.0 {
            return iou_val;
        }
        iou_val - self.center_distance_sq(other) / c_sq
    }

    /// Complete IoU: DIoU further penalised by aspect-ratio mismatch.
    ///
    /// `v = 4/π² · (atan(w₁/h₁) − atan(w₂/h₂))²` and `alpha = v / (1 − IoU + v)`;
    /// the result is `IoU − d²/c² − alpha·v`. `0.0` when the union area is
    /// not positive.
    pub fn ciou(&self, other: &DetectionBox) -> f64 {
        let (inter, union) = self.intersection_and_union(other);
        if union <= 0.0 {
            return 0.0;
        }
        let iou_val = inter / union;

        let d_sq = self.center_distance_sq(other);
        let c_sq = self.enclosing_diagonal_sq(other);

        let pi = std::f64::consts::PI;
        // Heights are floored at 1e-12 so degenerate boxes do not divide by zero.
        let v = {
            let atan_self = (self.width() / self.height().max(1e-12)).atan();
            let atan_other = (other.width() / other.height().max(1e-12)).atan();
            (4.0 / (pi * pi)) * (atan_self - atan_other).powi(2)
        };

        // Guard the trade-off weight against a vanishing denominator
        // (identical boxes: IoU == 1 and v == 0).
        let alpha = if (1.0 - iou_val + v).abs() < 1e-12 {
            0.0
        } else {
            v / (1.0 - iou_val + v)
        };

        let distance_term = if c_sq > 0.0 { d_sq / c_sq } else { 0.0 };
        iou_val - distance_term - alpha * v
    }

    /// Whether `(px, py)` lies inside the box, boundaries included.
    pub fn contains_point(&self, px: f64, py: f64) -> bool {
        px >= self.x1 && px <= self.x2 && py >= self.y1 && py <= self.y2
    }
}
273
274pub fn nms(boxes: &[DetectionBox], iou_threshold: f64) -> Vec<usize> {
287 if boxes.is_empty() {
288 return Vec::new();
289 }
290
291 let mut indices: Vec<usize> = (0..boxes.len()).collect();
293 indices.sort_by(|&a, &b| {
294 boxes[b]
295 .confidence
296 .partial_cmp(&boxes[a].confidence)
297 .unwrap_or(std::cmp::Ordering::Equal)
298 });
299
300 let mut keep = Vec::new();
301 let mut suppressed = vec![false; boxes.len()];
302
303 for &idx in &indices {
304 if suppressed[idx] {
305 continue;
306 }
307 keep.push(idx);
308 for &other in &indices {
309 if other != idx && !suppressed[other] && boxes[idx].iou(&boxes[other]) > iou_threshold {
310 suppressed[other] = true;
311 }
312 }
313 }
314 keep
315}
316
317pub fn soft_nms(boxes: &mut [DetectionBox], sigma: f64, score_threshold: f64) -> Vec<usize> {
326 if boxes.is_empty() {
327 return Vec::new();
328 }
329 let n = boxes.len();
330 let mut active: Vec<bool> = vec![true; n];
331 let mut keep = Vec::new();
332
333 for _ in 0..n {
334 let mut best_idx: Option<usize> = None;
336 let mut best_score = f64::NEG_INFINITY;
337 for (i, &is_active) in active.iter().enumerate() {
338 if is_active && boxes[i].confidence > best_score {
339 best_score = boxes[i].confidence;
340 best_idx = Some(i);
341 }
342 }
343 let best = match best_idx {
344 Some(i) => i,
345 None => break,
346 };
347
348 if boxes[best].confidence < score_threshold {
349 break;
350 }
351
352 keep.push(best);
353 active[best] = false;
354
355 for j in 0..n {
357 if active[j] {
358 let iou_val = boxes[best].iou(&boxes[j]);
359 if sigma > 0.0 {
360 boxes[j].confidence *= (-iou_val * iou_val / sigma).exp();
361 }
362 if boxes[j].confidence < score_threshold {
363 active[j] = false;
364 }
365 }
366 }
367 }
368 keep
369}
370
371pub fn batched_nms(boxes: &[DetectionBox], iou_threshold: f64) -> Vec<usize> {
375 if boxes.is_empty() {
376 return Vec::new();
377 }
378
379 let mut class_map: HashMap<usize, Vec<usize>> = HashMap::new();
381 for (i, b) in boxes.iter().enumerate() {
382 class_map.entry(b.class_id).or_default().push(i);
383 }
384
385 let mut keep = Vec::new();
386 for group_indices in class_map.values() {
387 let group_boxes: Vec<DetectionBox> =
388 group_indices.iter().map(|&i| boxes[i].clone()).collect();
389 let class_keep = nms(&group_boxes, iou_threshold);
390 for local_idx in class_keep {
391 keep.push(group_indices[local_idx]);
392 }
393 }
394 keep.sort_by(|&a, &b| {
396 boxes[b]
397 .confidence
398 .partial_cmp(&boxes[a].confidence)
399 .unwrap_or(std::cmp::Ordering::Equal)
400 });
401 keep
402}
403
404pub fn weighted_nms(boxes: &[DetectionBox], iou_threshold: f64) -> Vec<DetectionBox> {
409 if boxes.is_empty() {
410 return Vec::new();
411 }
412
413 let mut indices: Vec<usize> = (0..boxes.len()).collect();
414 indices.sort_by(|&a, &b| {
415 boxes[b]
416 .confidence
417 .partial_cmp(&boxes[a].confidence)
418 .unwrap_or(std::cmp::Ordering::Equal)
419 });
420
421 let mut used = vec![false; boxes.len()];
422 let mut result = Vec::new();
423
424 for &idx in &indices {
425 if used[idx] {
426 continue;
427 }
428 let mut cluster: Vec<usize> = vec![idx];
430 for &other in &indices {
431 if other != idx && !used[other] && boxes[idx].iou(&boxes[other]) > iou_threshold {
432 cluster.push(other);
433 }
434 }
435 for &c in &cluster {
437 used[c] = true;
438 }
439
440 let total_conf: f64 = cluster.iter().map(|&c| boxes[c].confidence).sum();
442 if total_conf <= 0.0 {
443 result.push(boxes[idx].clone());
444 continue;
445 }
446 let mut wx1 = 0.0;
447 let mut wy1 = 0.0;
448 let mut wx2 = 0.0;
449 let mut wy2 = 0.0;
450 for &c in &cluster {
451 let w = boxes[c].confidence;
452 wx1 += boxes[c].x1 * w;
453 wy1 += boxes[c].y1 * w;
454 wx2 += boxes[c].x2 * w;
455 wy2 += boxes[c].y2 * w;
456 }
457 let merged = DetectionBox {
458 x1: wx1 / total_conf,
459 y1: wy1 / total_conf,
460 x2: wx2 / total_conf,
461 y2: wy2 / total_conf,
462 confidence: boxes[idx].confidence, class_id: boxes[idx].class_id,
464 class_name: boxes[idx].class_name.clone(),
465 };
466 result.push(merged);
467 }
468 result
469}
470
/// Parameters for [`generate_anchors`].
///
/// Each feature-map level is a grid laid over the image; every grid cell
/// receives one anchor per entry in `aspect_ratios`, sized by that level's
/// entry in `scales`.
#[derive(Debug, Clone)]
pub struct AnchorConfig {
    /// Grid dimensions `(width, height)` of each feature-map level.
    pub feature_map_sizes: Vec<(usize, usize)>,
    /// Width-to-height ratios applied at every grid cell.
    pub aspect_ratios: Vec<f64>,
    /// Base anchor size per level; must have one entry per feature-map level.
    pub scales: Vec<f64>,
    /// Image size `(width, height)` that the grids are spread across.
    pub image_size: (usize, usize),
}
487
488pub fn generate_anchors(config: &AnchorConfig) -> Result<Vec<DetectionBox>> {
493 if config.feature_map_sizes.is_empty() {
494 return Err(VisionError::InvalidParameter(
495 "feature_map_sizes must not be empty".into(),
496 ));
497 }
498 if config.aspect_ratios.is_empty() {
499 return Err(VisionError::InvalidParameter(
500 "aspect_ratios must not be empty".into(),
501 ));
502 }
503 if config.scales.len() != config.feature_map_sizes.len() {
504 return Err(VisionError::InvalidParameter(
505 "scales length must match feature_map_sizes length".into(),
506 ));
507 }
508
509 let (img_w, img_h) = config.image_size;
510 let img_w = img_w as f64;
511 let img_h = img_h as f64;
512
513 let mut anchors = Vec::new();
514
515 for (level, &(fm_w, fm_h)) in config.feature_map_sizes.iter().enumerate() {
516 if fm_w == 0 || fm_h == 0 {
517 continue;
518 }
519 let step_x = img_w / fm_w as f64;
520 let step_y = img_h / fm_h as f64;
521 let scale = config.scales[level];
522
523 for row in 0..fm_h {
524 for col in 0..fm_w {
525 let cx = (col as f64 + 0.5) * step_x;
526 let cy = (row as f64 + 0.5) * step_y;
527
528 for &ar in &config.aspect_ratios {
529 let w = scale * ar.sqrt();
530 let h = scale / ar.sqrt();
531 anchors.push(DetectionBox::from_center(cx, cy, w, h));
532 }
533 }
534 }
535 }
536 Ok(anchors)
537}
538
539pub fn generate_ssd_anchors(
544 image_size: (usize, usize),
545 feature_maps: &[(usize, usize)],
546) -> Result<Vec<DetectionBox>> {
547 if feature_maps.is_empty() {
548 return Err(VisionError::InvalidParameter(
549 "feature_maps must not be empty".into(),
550 ));
551 }
552 let aspect_ratios = vec![1.0, 2.0, 0.5, 3.0, 1.0 / 3.0];
553 let min_scale = 0.2;
554 let max_scale = 0.9;
555 let num_levels = feature_maps.len();
556 let scales: Vec<f64> = (0..num_levels)
557 .map(|k| {
558 let s = if num_levels > 1 {
559 min_scale + (max_scale - min_scale) * (k as f64) / ((num_levels - 1) as f64)
560 } else {
561 (min_scale + max_scale) / 2.0
562 };
563 s * image_size.0.min(image_size.1) as f64
564 })
565 .collect();
566
567 generate_anchors(&AnchorConfig {
568 feature_map_sizes: feature_maps.to_vec(),
569 aspect_ratios,
570 scales,
571 image_size,
572 })
573}
574
575pub fn generate_yolo_anchors(
581 image_size: (usize, usize),
582 feature_map: (usize, usize),
583 anchor_wh: &[(f64, f64)],
584) -> Result<Vec<DetectionBox>> {
585 if anchor_wh.is_empty() {
586 return Err(VisionError::InvalidParameter(
587 "anchor_wh must not be empty".into(),
588 ));
589 }
590 let (fm_w, fm_h) = feature_map;
591 if fm_w == 0 || fm_h == 0 {
592 return Err(VisionError::InvalidParameter(
593 "feature map dimensions must be > 0".into(),
594 ));
595 }
596 let step_x = image_size.0 as f64 / fm_w as f64;
597 let step_y = image_size.1 as f64 / fm_h as f64;
598
599 let mut anchors = Vec::new();
600 for row in 0..fm_h {
601 for col in 0..fm_w {
602 let cx = (col as f64 + 0.5) * step_x;
603 let cy = (row as f64 + 0.5) * step_y;
604 for &(aw, ah) in anchor_wh {
605 anchors.push(DetectionBox::from_center(cx, cy, aw, ah));
606 }
607 }
608 }
609 Ok(anchors)
610}
611
612pub fn compute_ap(
622 predictions: &[DetectionBox],
623 ground_truth: &[DetectionBox],
624 iou_threshold: f64,
625) -> f64 {
626 if ground_truth.is_empty() {
627 return 0.0;
628 }
629 if predictions.is_empty() {
630 return 0.0;
631 }
632
633 let (precisions, recalls) = precision_recall_curve(predictions, ground_truth, iou_threshold);
634 if precisions.is_empty() {
635 return 0.0;
636 }
637
638 ap_from_pr(&precisions, &recalls)
640}
641
642pub fn compute_map(
648 predictions: &[Vec<DetectionBox>],
649 ground_truth: &[Vec<DetectionBox>],
650 iou_thresholds: &[f64],
651) -> f64 {
652 if predictions.is_empty() || ground_truth.is_empty() || iou_thresholds.is_empty() {
653 return 0.0;
654 }
655 let n = predictions.len().min(ground_truth.len());
656 let mut total_ap = 0.0;
657 let mut count = 0usize;
658
659 for threshold in iou_thresholds {
660 for i in 0..n {
661 total_ap += compute_ap(&predictions[i], &ground_truth[i], *threshold);
662 count += 1;
663 }
664 }
665 if count == 0 {
666 0.0
667 } else {
668 total_ap / count as f64
669 }
670}
671
672pub fn precision_recall_curve(
676 predictions: &[DetectionBox],
677 ground_truth: &[DetectionBox],
678 iou_threshold: f64,
679) -> (Vec<f64>, Vec<f64>) {
680 if ground_truth.is_empty() || predictions.is_empty() {
681 return (Vec::new(), Vec::new());
682 }
683
684 let mut sorted_preds: Vec<(usize, f64)> = predictions
686 .iter()
687 .enumerate()
688 .map(|(i, p)| (i, p.confidence))
689 .collect();
690 sorted_preds.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
691
692 let total_gt = ground_truth.len() as f64;
693 let mut matched_gt = vec![false; ground_truth.len()];
694 let mut tp = 0.0_f64;
695 let mut fp = 0.0_f64;
696 let mut precisions = Vec::with_capacity(sorted_preds.len());
697 let mut recalls = Vec::with_capacity(sorted_preds.len());
698
699 for &(pred_idx, _conf) in &sorted_preds {
700 let pred = &predictions[pred_idx];
701
702 let mut best_iou = 0.0;
704 let mut best_gt: Option<usize> = None;
705 for (gt_idx, gt) in ground_truth.iter().enumerate() {
706 if matched_gt[gt_idx] {
707 continue;
708 }
709 let iou_val = pred.iou(gt);
710 if iou_val > best_iou {
711 best_iou = iou_val;
712 best_gt = Some(gt_idx);
713 }
714 }
715
716 if best_iou >= iou_threshold {
717 if let Some(gt_idx) = best_gt {
718 tp += 1.0;
719 matched_gt[gt_idx] = true;
720 } else {
721 fp += 1.0;
722 }
723 } else {
724 fp += 1.0;
725 }
726
727 let precision = if (tp + fp) > 0.0 { tp / (tp + fp) } else { 0.0 };
728 let recall = tp / total_gt;
729 precisions.push(precision);
730 recalls.push(recall);
731 }
732
733 (precisions, recalls)
734}
735
/// Area under a precision/recall curve using all-point interpolation
/// (the PASCAL VOC 2010+ method).
///
/// The curve is padded with `(recall 0, precision 0)` at the start and
/// `(recall 1, precision 0)` at the end. Precision is replaced by its running
/// maximum from the right (the precision envelope). The area is then summed
/// over every step where recall changes.
///
/// Points are paired positionally. If the slices differ in length, the extra
/// entries of the longer one are ignored. Returns `0.0` when there are no
/// points.
fn ap_from_pr(precisions: &[f64], recalls: &[f64]) -> f64 {
    let n = precisions.len().min(recalls.len());
    if n == 0 {
        return 0.0;
    }

    let mut mrec = Vec::with_capacity(n + 2);
    let mut mprec = Vec::with_capacity(n + 2);
    mrec.push(0.0);
    mprec.push(0.0);
    mrec.extend_from_slice(&recalls[..n]);
    mprec.extend_from_slice(&precisions[..n]);
    mrec.push(1.0);
    mprec.push(0.0);

    // Precision envelope: make precision non-increasing in recall.
    for i in (0..mprec.len() - 1).rev() {
        if mprec[i + 1] > mprec[i] {
            mprec[i] = mprec[i + 1];
        }
    }

    // Sum rectangle areas where recall actually advances.
    mrec.windows(2)
        .zip(&mprec[1..])
        .map(|(pair, &p)| (pair[1] - pair[0], p))
        .filter(|&(dr, _)| dr.abs() > 1e-15)
        .fold(0.0, |ap, (dr, p)| ap + dr * p)
}
771
772pub fn confusion_matrix(
779 predictions: &[DetectionBox],
780 ground_truth: &[DetectionBox],
781 iou_threshold: f64,
782) -> HashMap<usize, HashMap<usize, usize>> {
783 let mut matrix: HashMap<usize, HashMap<usize, usize>> = HashMap::new();
784 let mut matched_gt = vec![false; ground_truth.len()];
785
786 let mut sorted_indices: Vec<usize> = (0..predictions.len()).collect();
788 sorted_indices.sort_by(|&a, &b| {
789 predictions[b]
790 .confidence
791 .partial_cmp(&predictions[a].confidence)
792 .unwrap_or(std::cmp::Ordering::Equal)
793 });
794
795 for &pred_idx in &sorted_indices {
796 let pred = &predictions[pred_idx];
797 let mut best_iou = 0.0_f64;
798 let mut best_gt: Option<usize> = None;
799
800 for (gt_idx, gt) in ground_truth.iter().enumerate() {
801 if matched_gt[gt_idx] {
802 continue;
803 }
804 let iou_val = pred.iou(gt);
805 if iou_val > best_iou {
806 best_iou = iou_val;
807 best_gt = Some(gt_idx);
808 }
809 }
810
811 if best_iou >= iou_threshold {
812 if let Some(gt_idx) = best_gt {
813 matched_gt[gt_idx] = true;
814 let gt_class = ground_truth[gt_idx].class_id;
815 let pred_class = pred.class_id;
816 *matrix
817 .entry(gt_class)
818 .or_default()
819 .entry(pred_class)
820 .or_insert(0) += 1;
821 }
822 } else {
823 *matrix
825 .entry(usize::MAX)
826 .or_default()
827 .entry(pred.class_id)
828 .or_insert(0) += 1;
829 }
830 }
831
832 for (gt_idx, gt) in ground_truth.iter().enumerate() {
834 if !matched_gt[gt_idx] {
835 *matrix
836 .entry(gt.class_id)
837 .or_default()
838 .entry(usize::MAX)
839 .or_insert(0) += 1;
840 }
841 }
842
843 matrix
844}
845
846pub fn random_horizontal_flip(
855 boxes: &mut [DetectionBox],
856 image_width: f64,
857 probability: f64,
858 seed: u64,
859) {
860 let mut state = if seed == 0 { 0x5DEECE66D } else { seed };
862 let should_flip = {
863 state ^= state << 13;
864 state ^= state >> 7;
865 state ^= state << 17;
866 let rand_val = (state as f64) / (u64::MAX as f64);
867 rand_val.abs() < probability
868 };
869
870 if should_flip {
871 for b in boxes.iter_mut() {
872 let new_x1 = image_width - b.x2;
873 let new_x2 = image_width - b.x1;
874 b.x1 = new_x1;
875 b.x2 = new_x2;
876 }
877 }
878}
879
880pub fn random_crop_with_boxes(
886 boxes: &[DetectionBox],
887 crop_box: &DetectionBox,
888) -> Vec<DetectionBox> {
889 let mut result = Vec::new();
890 for b in boxes {
891 let clipped_x1 = b.x1.max(crop_box.x1);
893 let clipped_y1 = b.y1.max(crop_box.y1);
894 let clipped_x2 = b.x2.min(crop_box.x2);
895 let clipped_y2 = b.y2.min(crop_box.y2);
896
897 if clipped_x1 >= clipped_x2 || clipped_y1 >= clipped_y2 {
899 continue;
900 }
901
902 result.push(DetectionBox {
904 x1: clipped_x1 - crop_box.x1,
905 y1: clipped_y1 - crop_box.y1,
906 x2: clipped_x2 - crop_box.x1,
907 y2: clipped_y2 - crop_box.y1,
908 confidence: b.confidence,
909 class_id: b.class_id,
910 class_name: b.class_name.clone(),
911 });
912 }
913 result
914}
915
916pub fn scale_boxes(boxes: &mut [DetectionBox], sx: f64, sy: f64) {
918 for b in boxes.iter_mut() {
919 b.x1 *= sx;
920 b.y1 *= sy;
921 b.x2 *= sx;
922 b.y2 *= sy;
923 }
924}
925
926pub fn translate_boxes(boxes: &mut [DetectionBox], tx: f64, ty: f64) {
928 for b in boxes.iter_mut() {
929 b.x1 += tx;
930 b.y1 += ty;
931 b.x2 += tx;
932 b.y2 += ty;
933 }
934}
935
936pub fn clip_boxes(boxes: &mut Vec<DetectionBox>, image_width: f64, image_height: f64) {
940 for b in boxes.iter_mut() {
941 b.x1 = b.x1.max(0.0).min(image_width);
942 b.y1 = b.y1.max(0.0).min(image_height);
943 b.x2 = b.x2.max(0.0).min(image_width);
944 b.y2 = b.y2.max(0.0).min(image_height);
945 }
946 boxes.retain(|b| b.width() > 0.0 && b.height() > 0.0);
947}
948
949pub fn filter_by_confidence(boxes: &[DetectionBox], threshold: f64) -> Vec<DetectionBox> {
951 boxes
952 .iter()
953 .filter(|b| b.confidence >= threshold)
954 .cloned()
955 .collect()
956}
957
958pub fn boxes_to_tuples(boxes: &[DetectionBox]) -> Vec<(f64, f64, f64, f64, f64, usize)> {
960 boxes
961 .iter()
962 .map(|b| (b.x1, b.y1, b.x2, b.y2, b.confidence, b.class_id))
963 .collect()
964}
965
#[cfg(test)]
mod tests {
    use super::*;

    /// `true` when `lhs` and `rhs` differ by strictly less than `tol`.
    fn near(lhs: f64, rhs: f64, tol: f64) -> bool {
        (lhs - rhs).abs() < tol
    }

    // ------------------------------------------------------------------
    // Construction and basic geometry
    // ------------------------------------------------------------------

    #[test]
    fn test_new_normalises_coordinates() {
        // Corners supplied in reverse order must be swapped into place.
        let bx = DetectionBox::new(50.0, 40.0, 10.0, 20.0);
        assert!(bx.x1 <= bx.x2);
        assert!(bx.y1 <= bx.y2);
        assert_eq!((bx.x1, bx.y1, bx.x2, bx.y2), (10.0, 20.0, 50.0, 40.0));
    }

    #[test]
    fn test_from_center() {
        let bx = DetectionBox::from_center(100.0, 100.0, 40.0, 20.0);
        let checks = [(bx.x1, 80.0), (bx.y1, 90.0), (bx.x2, 120.0), (bx.y2, 110.0)];
        for &(got, want) in checks.iter() {
            assert!(near(got, want, 1e-10));
        }
    }

    #[test]
    fn test_area() {
        let bx = DetectionBox::new(0.0, 0.0, 10.0, 20.0);
        assert!(near(bx.area(), 200.0, 1e-10));
    }

    #[test]
    fn test_center() {
        let (cx, cy) = DetectionBox::new(10.0, 20.0, 30.0, 40.0).center();
        assert!(near(cx, 20.0, 1e-10));
        assert!(near(cy, 30.0, 1e-10));
    }

    #[test]
    fn test_width_height() {
        let bx = DetectionBox::new(5.0, 10.0, 25.0, 40.0);
        assert!(near(bx.width(), 20.0, 1e-10));
        assert!(near(bx.height(), 30.0, 1e-10));
    }

    #[test]
    fn test_aspect_ratio() {
        let landscape = DetectionBox::new(0.0, 0.0, 20.0, 10.0);
        assert!(near(landscape.aspect_ratio(), 2.0, 1e-10));
    }

    #[test]
    fn test_contains_point() {
        let bx = DetectionBox::new(10.0, 10.0, 50.0, 50.0);
        // Interior and boundary points are inside...
        assert!(bx.contains_point(30.0, 30.0));
        assert!(bx.contains_point(10.0, 10.0));
        // ...points beyond either corner are not.
        assert!(!bx.contains_point(5.0, 5.0));
        assert!(!bx.contains_point(55.0, 55.0));
    }

    #[test]
    fn test_builder_methods() {
        let tagged = DetectionBox::new(0.0, 0.0, 10.0, 10.0)
            .with_confidence(0.95)
            .with_class(3, Some("person".to_string()));
        assert!(near(tagged.confidence, 0.95, 1e-10));
        assert_eq!(tagged.class_id, 3);
        assert_eq!(tagged.class_name.as_deref(), Some("person"));
    }

    // ------------------------------------------------------------------
    // IoU family
    // ------------------------------------------------------------------

    #[test]
    fn test_iou_identical() {
        let bx = DetectionBox::new(10.0, 10.0, 50.0, 50.0);
        assert!(near(bx.iou(&bx), 1.0, 1e-10));
    }

    #[test]
    fn test_iou_no_overlap() {
        let left = DetectionBox::new(0.0, 0.0, 10.0, 10.0);
        let right = DetectionBox::new(20.0, 20.0, 30.0, 30.0);
        assert!(near(left.iou(&right), 0.0, 1e-10));
    }

    #[test]
    fn test_iou_partial_overlap() {
        let first = DetectionBox::new(0.0, 0.0, 10.0, 10.0);
        let second = DetectionBox::new(5.0, 5.0, 15.0, 15.0);
        // A 5x5 overlap over a 100 + 100 - 25 union.
        assert!(near(first.iou(&second), 25.0 / 175.0, 1e-10));
    }

    #[test]
    fn test_giou_identical() {
        let bx = DetectionBox::new(10.0, 10.0, 50.0, 50.0);
        assert!(near(bx.giou(&bx), 1.0, 1e-10));
    }

    #[test]
    fn test_giou_no_overlap() {
        let left = DetectionBox::new(0.0, 0.0, 10.0, 10.0);
        let right = DetectionBox::new(20.0, 20.0, 30.0, 30.0);
        // Disjoint boxes are pushed below zero by the enclosing-area penalty.
        assert!(left.giou(&right) < 0.0);
    }

    #[test]
    fn test_giou_range() {
        let tiny = DetectionBox::new(0.0, 0.0, 1.0, 1.0);
        let distant = DetectionBox::new(100.0, 100.0, 101.0, 101.0);
        let score = tiny.giou(&distant);
        assert!((-1.0..=1.0).contains(&score));
    }

    #[test]
    fn test_diou_identical() {
        let bx = DetectionBox::new(10.0, 10.0, 50.0, 50.0);
        assert!(near(bx.diou(&bx), 1.0, 1e-10));
    }

    #[test]
    fn test_ciou_identical() {
        let bx = DetectionBox::new(10.0, 10.0, 50.0, 50.0);
        assert!(near(bx.ciou(&bx), 1.0, 1e-10));
    }

    #[test]
    fn test_ciou_different_aspect_ratios() {
        let landscape = DetectionBox::new(0.0, 0.0, 100.0, 50.0);
        let portrait = DetectionBox::new(0.0, 0.0, 50.0, 100.0);
        // Distance and aspect penalties can only pull the score below IoU.
        assert!(landscape.ciou(&portrait) <= landscape.iou(&portrait));
    }

    // ------------------------------------------------------------------
    // Suppression
    // ------------------------------------------------------------------

    #[test]
    fn test_nms_basic() {
        let candidates = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            DetectionBox::new(1.0, 1.0, 11.0, 11.0).with_confidence(0.7),
            DetectionBox::new(200.0, 200.0, 210.0, 210.0).with_confidence(0.8),
        ];
        let survivors = nms(&candidates, 0.5);
        assert_eq!(survivors.len(), 2);
        // Index 1 overlaps the stronger index 0 and is suppressed.
        assert!(survivors.contains(&0));
        assert!(survivors.contains(&2));
    }

    #[test]
    fn test_nms_empty() {
        let survivors = nms(&[], 0.5);
        assert!(survivors.is_empty());
    }

    #[test]
    fn test_nms_no_suppression() {
        let candidates = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            DetectionBox::new(100.0, 100.0, 110.0, 110.0).with_confidence(0.8),
        ];
        assert_eq!(nms(&candidates, 0.5).len(), 2);
    }

    #[test]
    fn test_soft_nms_reduces_scores() {
        let mut candidates = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            DetectionBox::new(1.0, 1.0, 11.0, 11.0).with_confidence(0.8),
            DetectionBox::new(200.0, 200.0, 210.0, 210.0).with_confidence(0.85),
        ];
        let before = candidates[1].confidence;
        let picked = soft_nms(&mut candidates, 0.5, 0.01);
        // The box overlapping the top pick is decayed, not removed outright.
        assert!(candidates[1].confidence < before);
        assert!(!picked.is_empty());
    }

    #[test]
    fn test_batched_nms_separates_classes() {
        let candidates = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0)
                .with_confidence(0.9)
                .with_class(0, None),
            DetectionBox::new(1.0, 1.0, 11.0, 11.0)
                .with_confidence(0.8)
                .with_class(1, None),
        ];
        // Heavy overlap, but different classes: both survive.
        assert_eq!(batched_nms(&candidates, 0.5).len(), 2);
    }

    #[test]
    fn test_batched_nms_suppresses_same_class() {
        let candidates = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0)
                .with_confidence(0.9)
                .with_class(0, None),
            DetectionBox::new(1.0, 1.0, 11.0, 11.0)
                .with_confidence(0.7)
                .with_class(0, None),
        ];
        assert_eq!(batched_nms(&candidates, 0.5).len(), 1);
    }

    #[test]
    fn test_weighted_nms_merges() {
        let candidates = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            DetectionBox::new(1.0, 1.0, 11.0, 11.0).with_confidence(0.8),
        ];
        let fused = weighted_nms(&candidates, 0.3);
        assert_eq!(fused.len(), 1);
        // The merged left edge lies between the two inputs' left edges.
        assert!(fused[0].x1 > 0.0 && fused[0].x1 < 1.0);
    }

    // ------------------------------------------------------------------
    // Anchor generation
    // ------------------------------------------------------------------

    #[test]
    fn test_generate_anchors_basic() {
        let cfg = AnchorConfig {
            feature_map_sizes: vec![(4, 4)],
            aspect_ratios: vec![1.0, 2.0],
            scales: vec![32.0],
            image_size: (128, 128),
        };
        let outcome = generate_anchors(&cfg);
        assert!(outcome.is_ok());
        let grid = outcome.expect("should succeed");
        // 4 * 4 cells, two ratios each.
        assert_eq!(grid.len(), 32);
    }

    #[test]
    fn test_generate_anchors_empty_feature_maps() {
        let cfg = AnchorConfig {
            feature_map_sizes: vec![],
            aspect_ratios: vec![1.0],
            scales: vec![],
            image_size: (100, 100),
        };
        assert!(generate_anchors(&cfg).is_err());
    }

    #[test]
    fn test_generate_ssd_anchors() {
        let outcome = generate_ssd_anchors((300, 300), &[(38, 38), (19, 19), (10, 10)]);
        assert!(outcome.is_ok());
        let grid = outcome.expect("should succeed");
        // (38² + 19² + 10²) cells, five ratios each.
        assert_eq!(grid.len(), 9525);
    }

    #[test]
    fn test_generate_yolo_anchors() {
        let priors = vec![(10.0, 13.0), (16.0, 30.0), (33.0, 23.0)];
        let outcome = generate_yolo_anchors((416, 416), (13, 13), &priors);
        assert!(outcome.is_ok());
        let grid = outcome.expect("should succeed");
        assert_eq!(grid.len(), 13 * 13 * 3);
    }

    #[test]
    fn test_generate_yolo_anchors_empty() {
        assert!(generate_yolo_anchors((416, 416), (13, 13), &[]).is_err());
    }

    // ------------------------------------------------------------------
    // Evaluation metrics
    // ------------------------------------------------------------------

    #[test]
    fn test_compute_ap_perfect() {
        let predicted = vec![DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9)];
        let truth = vec![DetectionBox::new(0.0, 0.0, 10.0, 10.0)];
        assert!(near(compute_ap(&predicted, &truth, 0.5), 1.0, 1e-10));
    }

    #[test]
    fn test_compute_ap_no_predictions() {
        let truth = vec![DetectionBox::new(0.0, 0.0, 10.0, 10.0)];
        assert!(near(compute_ap(&[], &truth, 0.5), 0.0, 1e-10));
    }

    #[test]
    fn test_compute_ap_no_ground_truth() {
        let predicted = vec![DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9)];
        assert!(near(compute_ap(&predicted, &[], 0.5), 0.0, 1e-10));
    }

    #[test]
    fn test_compute_map() {
        let predicted = vec![vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9)
        ]];
        let truth = vec![vec![DetectionBox::new(0.0, 0.0, 10.0, 10.0)]];
        let thresholds = vec![0.5, 0.75];
        assert!(near(compute_map(&predicted, &truth, &thresholds), 1.0, 1e-10));
    }

    #[test]
    fn test_precision_recall_curve_basic() {
        let predicted = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            DetectionBox::new(100.0, 100.0, 110.0, 110.0).with_confidence(0.5),
        ];
        let truth = vec![DetectionBox::new(0.0, 0.0, 10.0, 10.0)];
        let (precision, recall) = precision_recall_curve(&predicted, &truth, 0.5);
        assert_eq!(precision.len(), 2);
        assert_eq!(recall.len(), 2);
        // The top-scoring prediction is a perfect hit.
        assert!(near(precision[0], 1.0, 1e-10));
        assert!(near(recall[0], 1.0, 1e-10));
    }

    #[test]
    fn test_confusion_matrix_basic() {
        let predicted = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0)
                .with_confidence(0.9)
                .with_class(0, None),
            DetectionBox::new(200.0, 200.0, 210.0, 210.0)
                .with_confidence(0.8)
                .with_class(1, None),
        ];
        let truth = vec![DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_class(0, None)];
        let cm = confusion_matrix(&predicted, &truth, 0.5);
        let cell = |row: usize, col: usize| *cm.get(&row).and_then(|m| m.get(&col)).unwrap_or(&0);
        assert_eq!(cell(0, 0), 1);
        // The unmatched class-1 prediction lands in the background row.
        assert_eq!(cell(usize::MAX, 1), 1);
    }

    // ------------------------------------------------------------------
    // Box transforms and augmentation
    // ------------------------------------------------------------------

    #[test]
    fn test_scale_boxes() {
        let mut shapes = vec![DetectionBox::new(10.0, 20.0, 30.0, 40.0)];
        scale_boxes(&mut shapes, 2.0, 0.5);
        assert!(near(shapes[0].x1, 20.0, 1e-10));
        assert!(near(shapes[0].y1, 10.0, 1e-10));
        assert!(near(shapes[0].x2, 60.0, 1e-10));
        assert!(near(shapes[0].y2, 20.0, 1e-10));
    }

    #[test]
    fn test_translate_boxes() {
        let mut shapes = vec![DetectionBox::new(10.0, 20.0, 30.0, 40.0)];
        translate_boxes(&mut shapes, 5.0, -5.0);
        assert!(near(shapes[0].x1, 15.0, 1e-10));
        assert!(near(shapes[0].y1, 15.0, 1e-10));
        assert!(near(shapes[0].x2, 35.0, 1e-10));
        assert!(near(shapes[0].y2, 35.0, 1e-10));
    }

    #[test]
    fn test_clip_boxes() {
        let mut shapes = vec![
            DetectionBox::new(-5.0, -5.0, 50.0, 50.0),
            // Entirely outside the image: collapses and is removed.
            DetectionBox::new(-10.0, -10.0, -1.0, -1.0),
        ];
        clip_boxes(&mut shapes, 100.0, 100.0);
        assert_eq!(shapes.len(), 1);
        assert!(near(shapes[0].x1, 0.0, 1e-10));
        assert!(near(shapes[0].y1, 0.0, 1e-10));
    }

    #[test]
    fn test_random_crop_with_boxes() {
        let shapes = vec![
            DetectionBox::new(10.0, 10.0, 50.0, 50.0).with_confidence(0.9),
            // Outside the crop window: dropped.
            DetectionBox::new(200.0, 200.0, 250.0, 250.0).with_confidence(0.8),
        ];
        let window = DetectionBox::new(0.0, 0.0, 100.0, 100.0);
        let cropped = random_crop_with_boxes(&shapes, &window);
        assert_eq!(cropped.len(), 1);
        assert!(near(cropped[0].x1, 10.0, 1e-10));
        assert!(near(cropped[0].confidence, 0.9, 1e-10));
    }

    #[test]
    fn test_random_horizontal_flip_deterministic() {
        let mut shapes = vec![DetectionBox::new(10.0, 0.0, 30.0, 20.0)];
        // probability 1.0: the flip must happen regardless of the draw.
        random_horizontal_flip(&mut shapes, 100.0, 1.0, 42);
        assert!(near(shapes[0].x1, 70.0, 1e-10));
        assert!(near(shapes[0].x2, 90.0, 1e-10));
    }

    #[test]
    fn test_filter_by_confidence() {
        let shapes = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.3),
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.6),
        ];
        assert_eq!(filter_by_confidence(&shapes, 0.5).len(), 2);
    }

    #[test]
    fn test_boxes_to_tuples() {
        let shapes = vec![DetectionBox::new(1.0, 2.0, 3.0, 4.0)
            .with_confidence(0.5)
            .with_class(7, None)];
        let rows = boxes_to_tuples(&shapes);
        assert_eq!(rows.len(), 1);
        let (x1, y1, x2, y2, score, label) = rows[0];
        assert!(near(x1, 1.0, 1e-10));
        assert!(near(y1, 2.0, 1e-10));
        assert!(near(x2, 3.0, 1e-10));
        assert!(near(y2, 4.0, 1e-10));
        assert!(near(score, 0.5, 1e-10));
        assert_eq!(label, 7);
    }

    // ------------------------------------------------------------------
    // Additional edge cases
    // ------------------------------------------------------------------

    #[test]
    fn test_intersection_area() {
        let first = DetectionBox::new(0.0, 0.0, 10.0, 10.0);
        let second = DetectionBox::new(5.0, 5.0, 15.0, 15.0);
        assert!(near(first.intersection_area(&second), 25.0, 1e-10));
    }

    #[test]
    fn test_union_area() {
        let first = DetectionBox::new(0.0, 0.0, 10.0, 10.0);
        let second = DetectionBox::new(5.0, 5.0, 15.0, 15.0);
        assert!(near(first.union_area(&second), 175.0, 1e-10));
    }

    #[test]
    fn test_zero_area_box() {
        let point = DetectionBox::new(5.0, 5.0, 5.0, 5.0);
        assert!(near(point.area(), 0.0, 1e-10));
        assert!(near(point.width(), 0.0, 1e-10));
        assert!(near(point.height(), 0.0, 1e-10));
    }

    #[test]
    fn test_iou_symmetry() {
        let first = DetectionBox::new(0.0, 0.0, 10.0, 10.0);
        let second = DetectionBox::new(5.0, 0.0, 15.0, 10.0);
        assert!(near(first.iou(&second), second.iou(&first), 1e-12));
    }

    #[test]
    fn test_nms_all_identical() {
        let candidates = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.8),
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.7),
        ];
        let survivors = nms(&candidates, 0.5);
        assert_eq!(survivors.len(), 1);
        // Only the highest-confidence copy survives.
        assert_eq!(survivors[0], 0);
    }

    #[test]
    fn test_compute_ap_mixed() {
        let predicted = vec![
            // Hit.
            DetectionBox::new(0.0, 0.0, 10.0, 10.0).with_confidence(0.9),
            // Miss.
            DetectionBox::new(500.0, 500.0, 510.0, 510.0).with_confidence(0.8),
            // Hit.
            DetectionBox::new(100.0, 100.0, 110.0, 110.0).with_confidence(0.7),
        ];
        let truth = vec![
            DetectionBox::new(0.0, 0.0, 10.0, 10.0),
            DetectionBox::new(100.0, 100.0, 110.0, 110.0),
        ];
        assert!(compute_ap(&predicted, &truth, 0.5) > 0.5);
    }

    #[test]
    fn test_anchor_centres_are_within_image() {
        let cfg = AnchorConfig {
            feature_map_sizes: vec![(8, 8)],
            aspect_ratios: vec![1.0],
            scales: vec![16.0],
            image_size: (256, 256),
        };
        let grid = generate_anchors(&cfg).expect("should succeed");
        for anchor in &grid {
            let (cx, cy) = anchor.center();
            assert!(cx > 0.0 && cx < 256.0, "cx={cx} out of range");
            assert!(cy > 0.0 && cy < 256.0, "cy={cy} out of range");
        }
    }
}