// ultralytics_inference/results.rs
1// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
3//! Results classes for YOLO inference output.
4//!
5//! This module provides Ultralytics-compatible result classes that match
6//! the Python API for easy migration and consistent usage patterns.
7
8use std::collections::HashMap;
9
10use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis, s};
11
/// Timing information for inference operations (in milliseconds).
#[derive(Debug, Clone, Default)]
pub struct Speed {
    /// Time spent on preprocessing.
    pub preprocess: Option<f64>,
    /// Time spent on model inference.
    pub inference: Option<f64>,
    /// Time spent on postprocessing.
    pub postprocess: Option<f64>,
}

impl Speed {
    /// Create a `Speed` with all three stage timings populated.
    ///
    /// # Arguments
    ///
    /// * `preprocess` - Time in milliseconds.
    /// * `inference` - Time in milliseconds.
    /// * `postprocess` - Time in milliseconds.
    ///
    /// # Returns
    ///
    /// * A new `Speed` instance.
    #[must_use]
    pub const fn new(preprocess: f64, inference: f64, postprocess: f64) -> Self {
        Self {
            preprocess: Some(preprocess),
            inference: Some(inference),
            postprocess: Some(postprocess),
        }
    }

    /// Get total inference time.
    ///
    /// # Returns
    ///
    /// * Sum of preprocess, inference, and postprocess times in milliseconds;
    ///   stages that were never recorded contribute 0.
    #[must_use]
    pub fn total(&self) -> f64 {
        // `flatten` drops the `None` stages, matching `unwrap_or(0.0)` semantics.
        [self.preprocess, self.inference, self.postprocess]
            .into_iter()
            .flatten()
            .sum()
    }
}
56
/// Main results container for YOLO inference.
///
/// Contains the original image, detection results (boxes, masks, keypoints, etc.),
/// timing information, and metadata. Mirrors the Python `ultralytics` `Results` API.
#[derive(Debug, Clone)]
pub struct Results {
    /// Original image as HWC array (height, width, channels).
    pub orig_img: Array3<u8>,
    /// Original image shape (height, width), derived from `orig_img`.
    pub orig_shape: (u32, u32),
    /// Inference tensor shape (height, width) after letterboxing.
    pub inference_shape: (u32, u32),
    /// Detection bounding boxes (if applicable).
    pub boxes: Option<Boxes>,
    /// Segmentation masks (if applicable).
    pub masks: Option<Masks>,
    /// Pose keypoints (if applicable).
    pub keypoints: Option<Keypoints>,
    /// Classification probabilities (if applicable).
    pub probs: Option<Probs>,
    /// Oriented bounding boxes (if applicable).
    pub obb: Option<Obb>,
    /// Inference timing information.
    pub speed: Speed,
    /// Class ID to name mapping.
    pub names: HashMap<usize, String>,
    /// Path to the source image/video.
    pub path: String,
}
85
86impl Results {
87 /// Create a new Results instance.
88 ///
89 /// # Arguments
90 ///
91 /// * `orig_img` - Original image as HWC array.
92 /// * `path` - Path to the source image/video.
93 /// * `names` - Map of class IDs to class names.
94 /// * `speed` - Timing information.
95 /// * `inference_shape` - Shape of the inference tensor (height, width).
96 ///
97 /// # Returns
98 ///
99 /// * A new `Results` instance.
100 #[must_use]
101 pub fn new(
102 orig_img: Array3<u8>,
103 path: String,
104 names: HashMap<usize, String>,
105 speed: Speed,
106 inference_shape: (u32, u32),
107 ) -> Self {
108 let shape = orig_img.shape();
109 #[allow(clippy::cast_possible_truncation)]
110 let orig_shape = (shape[0] as u32, shape[1] as u32);
111
112 Self {
113 orig_img,
114 orig_shape,
115 inference_shape,
116 boxes: None,
117 masks: None,
118 keypoints: None,
119 probs: None,
120 obb: None,
121 speed,
122 names,
123 path,
124 }
125 }
126
127 /// Get the number of detections.
128 ///
129 /// # Returns
130 ///
131 /// * The count of detected objects, keyspoints, or masks.
132 #[must_use]
133 pub fn len(&self) -> usize {
134 if let Some(ref boxes) = self.boxes {
135 return boxes.len();
136 }
137 if let Some(ref masks) = self.masks {
138 return masks.len();
139 }
140 if let Some(ref keypoints) = self.keypoints {
141 return keypoints.len();
142 }
143 if let Some(ref probs) = self.probs {
144 return usize::from(!probs.data.is_empty());
145 }
146 if let Some(ref obb) = self.obb {
147 return obb.len();
148 }
149 0
150 }
151
152 /// Check if there are no detections.
153 ///
154 /// # Returns
155 ///
156 /// * `true` if no objects were detected.
157 #[must_use]
158 pub fn is_empty(&self) -> bool {
159 self.len() == 0
160 }
161
162 /// Get the original image shape (height, width).
163 ///
164 /// # Returns
165 ///
166 /// * Tuple of (height, width).
167 #[must_use]
168 pub const fn orig_shape(&self) -> (u32, u32) {
169 self.orig_shape
170 }
171
172 /// Get the inference tensor shape (height, width) after letterboxing.
173 ///
174 /// # Returns
175 ///
176 /// * Tuple of (height, width).
177 #[must_use]
178 pub const fn inference_shape(&self) -> (u32, u32) {
179 self.inference_shape
180 }
181
182 /// Generate a verbose log string describing the results.
183 ///
184 /// # Returns
185 ///
186 /// * A string summary of detections (e.g., "2 persons, 1 car, ").
187 #[must_use]
188 pub fn verbose(&self) -> String {
189 if self.is_empty() {
190 if self.probs.is_some() {
191 return String::new();
192 }
193 return "(no detections), ".to_string();
194 }
195
196 if let Some(ref probs) = self.probs {
197 let top5: Vec<String> = probs
198 .top5()
199 .iter()
200 .map(|&i| {
201 let name = self.names.get(&i).cloned().unwrap_or_else(|| i.to_string());
202 format!("{} {:.2}", name, probs.data[i])
203 })
204 .collect();
205 return format!("{}, ", top5.join(", "));
206 }
207
208 if let Some(ref boxes) = self.boxes {
209 let cls = boxes.cls();
210 let mut counts: HashMap<usize, usize> = HashMap::new();
211 for &c in cls {
212 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
213 let c = c as usize;
214 *counts.entry(c).or_insert(0) += 1;
215 }
216
217 let mut parts = Vec::new();
218 for (class_id, count) in &counts {
219 let name = self
220 .names
221 .get(class_id)
222 .cloned()
223 .unwrap_or_else(|| class_id.to_string());
224 let suffix = if *count > 1 { "s" } else { "" };
225 parts.push(format!("{count} {name}{suffix}"));
226 }
227 return format!("{}, ", parts.join(", "));
228 }
229
230 String::new()
231 }
232
233 /// Convert results to a list of dictionaries (summary format).
234 ///
235 /// # Arguments
236 ///
237 /// * `normalize` - Whether to normalize coordinates to [0, 1] range.
238 ///
239 /// # Returns
240 ///
241 /// * A vector of hashmaps representing the detections.
242 #[must_use]
243 pub fn summary(&self, normalize: bool) -> Vec<HashMap<String, SummaryValue>> {
244 let mut results = Vec::new();
245
246 if let Some(ref probs) = self.probs {
247 let class_id = probs.top1();
248 let mut entry = HashMap::new();
249 entry.insert(
250 "name".to_string(),
251 SummaryValue::String(
252 self.names
253 .get(&class_id)
254 .cloned()
255 .unwrap_or_else(|| class_id.to_string()),
256 ),
257 );
258 entry.insert("class".to_string(), SummaryValue::Int(class_id));
259 entry.insert(
260 "confidence".to_string(),
261 SummaryValue::Float(probs.top1conf()),
262 );
263 results.push(entry);
264 return results;
265 }
266
267 if let Some(ref boxes) = self.boxes {
268 let (h, w) = if normalize {
269 #[allow(clippy::cast_precision_loss)]
270 (self.orig_shape.0 as f32, self.orig_shape.1 as f32)
271 } else {
272 (1.0, 1.0)
273 };
274
275 let xyxy = boxes.xyxy();
276 let conf = boxes.conf();
277 let cls = boxes.cls();
278
279 for i in 0..boxes.len() {
280 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
281 let class_id = cls[i] as usize;
282 let mut entry = HashMap::new();
283 entry.insert(
284 "name".to_string(),
285 SummaryValue::String(
286 self.names
287 .get(&class_id)
288 .cloned()
289 .unwrap_or_else(|| class_id.to_string()),
290 ),
291 );
292 entry.insert("class".to_string(), SummaryValue::Int(class_id));
293 entry.insert("confidence".to_string(), SummaryValue::Float(conf[i]));
294
295 let mut box_coords = HashMap::new();
296 box_coords.insert("x1".to_string(), SummaryValue::Float(xyxy[[i, 0]] / w));
297 box_coords.insert("y1".to_string(), SummaryValue::Float(xyxy[[i, 1]] / h));
298 box_coords.insert("x2".to_string(), SummaryValue::Float(xyxy[[i, 2]] / w));
299 box_coords.insert("y2".to_string(), SummaryValue::Float(xyxy[[i, 3]] / h));
300 entry.insert("box".to_string(), SummaryValue::Box(box_coords));
301
302 results.push(entry);
303 }
304 }
305
306 results
307 }
308
309 /// Save the annotated result to a file.
310 ///
311 /// # Arguments
312 ///
313 /// * `path` - The path to save the image to.
314 ///
315 /// # Errors
316 ///
317 /// Returns an error if the image cannot be saved or if the format is unsupported.
318 #[cfg(feature = "annotate")]
319 pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> crate::error::Result<()> {
320 let img = crate::utils::array_to_image(&self.orig_img)?;
321 let annotated = crate::annotate::annotate_image(&img, self, None);
322 annotated
323 .save(path)
324 .map_err(|e| crate::error::InferenceError::ImageError(e.to_string()))
325 }
326}
327
/// Values that can appear in a summary dictionary.
///
/// Rust stand-in for the heterogeneous dict values of the Python `summary()` API.
#[derive(Debug, Clone)]
pub enum SummaryValue {
    /// String value (e.g., a class name).
    String(String),
    /// Integer value (e.g., a class ID).
    Int(usize),
    /// Float value (e.g., a confidence score or coordinate).
    Float(f32),
    /// Box coordinates, keyed by "x1"/"y1"/"x2"/"y2".
    Box(HashMap<String, Self>),
}
340
/// Detection bounding boxes.
///
/// Stores bounding boxes in xyxy format along with confidence scores and class IDs.
#[derive(Debug, Clone)]
pub struct Boxes {
    /// Raw data array with shape (N, 6) containing [x1, y1, x2, y2, conf, cls].
    /// Or shape (N, 7) if tracking: [x1, y1, x2, y2, `track_id`, conf, cls].
    pub data: Array2<f32>,
    /// Original image shape (height, width) for normalization.
    pub orig_shape: (u32, u32),
    /// Whether tracking IDs are present (derived from the column count in `new`).
    is_track: bool,
}
354
355impl Boxes {
356 /// Create a new Boxes instance.
357 ///
358 /// # Arguments
359 ///
360 /// * `data` - Array with shape (N, 6) or (N, 7) containing box data.
361 /// * `orig_shape` - Original image shape (height, width).
362 ///
363 /// # Returns
364 ///
365 /// * A new `Boxes` instance.
366 #[must_use]
367 pub fn new(data: Array2<f32>, orig_shape: (u32, u32)) -> Self {
368 let is_track = data.shape()[1] == 7;
369 Self {
370 data,
371 orig_shape,
372 is_track,
373 }
374 }
375
376 /// Get the number of boxes.
377 ///
378 /// # Returns
379 ///
380 /// * The count of bounding boxes.
381 #[must_use]
382 pub fn len(&self) -> usize {
383 self.data.nrows()
384 }
385
386 /// Check if there are no boxes.
387 ///
388 /// # Returns
389 ///
390 /// * `true` if the boxes array is empty.
391 #[must_use]
392 pub fn is_empty(&self) -> bool {
393 self.data.is_empty()
394 }
395
396 /// Get boxes in xyxy format [x1, y1, x2, y2].
397 ///
398 /// # Returns
399 ///
400 /// * A view of the box coordinates.
401 #[must_use]
402 pub fn xyxy(&self) -> ArrayView2<'_, f32> {
403 self.data.slice(s![.., 0..4])
404 }
405
406 /// Get confidence scores.
407 ///
408 /// # Returns
409 ///
410 /// * A view of confidence scores (0.0 to 1.0).
411 #[must_use]
412 pub fn conf(&self) -> ArrayView1<'_, f32> {
413 self.data.slice(s![.., -2])
414 }
415
416 /// Get class IDs.
417 ///
418 /// # Returns
419 ///
420 /// * A view of class IDs.
421 #[must_use]
422 pub fn cls(&self) -> ArrayView1<'_, f32> {
423 self.data.slice(s![.., -1])
424 }
425
426 /// Get tracking IDs (if available).
427 ///
428 /// # Returns
429 ///
430 /// * `Some` view of track IDs if this is a tracking result, otherwise `None`.
431 #[must_use]
432 pub fn id(&self) -> Option<ArrayView1<'_, f32>> {
433 if self.is_track {
434 Some(self.data.slice(s![.., -3]))
435 } else {
436 None
437 }
438 }
439
440 /// Get boxes in xywh format [`x_center`, `y_center`, width, height].
441 ///
442 /// # Returns
443 ///
444 /// * An owned array of boxes in xywh format.
445 #[must_use]
446 pub fn xywh(&self) -> Array2<f32> {
447 let xyxy = self.xyxy();
448 let n = xyxy.nrows();
449 let mut xywh = Array2::zeros((n, 4));
450
451 for i in 0..n {
452 let x1 = xyxy[[i, 0]];
453 let y1 = xyxy[[i, 1]];
454 let x2 = xyxy[[i, 2]];
455 let y2 = xyxy[[i, 3]];
456
457 xywh[[i, 0]] = f32::midpoint(x1, x2); // x_center
458 xywh[[i, 1]] = f32::midpoint(y1, y2); // y_center
459 xywh[[i, 2]] = x2 - x1; // width
460 xywh[[i, 3]] = y2 - y1; // height
461 }
462
463 xywh
464 }
465
466 /// Get boxes in xyxy format normalized by image size.
467 ///
468 /// # Returns
469 ///
470 /// * An owned array of normalized boxes [0.0-1.0].
471 #[must_use]
472 pub fn xyxyn(&self) -> Array2<f32> {
473 let mut xyxyn = self.xyxy().to_owned();
474 #[allow(clippy::cast_precision_loss)]
475 let (h, w) = (self.orig_shape.0 as f32, self.orig_shape.1 as f32);
476
477 for mut row in xyxyn.rows_mut() {
478 row[0] /= w;
479 row[1] /= h;
480 row[2] /= w;
481 row[3] /= h;
482 }
483
484 xyxyn
485 }
486
487 /// Get boxes in xywh format normalized by image size.
488 ///
489 /// # Returns
490 ///
491 /// * An owned array of normalized boxes [0.0-1.0].
492 #[must_use]
493 pub fn xywhn(&self) -> Array2<f32> {
494 let mut xywhn = self.xywh();
495 #[allow(clippy::cast_precision_loss)]
496 let (h, w) = (self.orig_shape.0 as f32, self.orig_shape.1 as f32);
497
498 for mut row in xywhn.rows_mut() {
499 row[0] /= w;
500 row[1] /= h;
501 row[2] /= w;
502 row[3] /= h;
503 }
504
505 xywhn
506 }
507
508 /// Check if tracking IDs are available.
509 ///
510 /// # Returns
511 ///
512 /// * `true` if the boxes contain tracking information.
513 #[must_use]
514 pub const fn is_track(&self) -> bool {
515 self.is_track
516 }
517}
518
/// Segmentation masks.
///
/// Currently stores only the raw per-instance mask planes; polygon/coordinate
/// accessors are still to be implemented (see TODO in the impl).
#[derive(Debug, Clone)]
pub struct Masks {
    /// Raw mask data with shape (N, H, W).
    pub data: Array3<f32>,
    /// Original image shape (height, width).
    pub orig_shape: (u32, u32),
}
529
530impl Masks {
531 /// Create a new Masks instance.
532 ///
533 /// # Arguments
534 ///
535 /// * `data` - Raw mask data with shape (N, H, W).
536 /// * `orig_shape` - Original image shape (height, width).
537 ///
538 /// # Returns
539 ///
540 /// * A new `Masks` instance.
541 #[must_use]
542 pub const fn new(data: Array3<f32>, orig_shape: (u32, u32)) -> Self {
543 Self { data, orig_shape }
544 }
545
546 /// Get the number of masks.
547 ///
548 /// # Returns
549 ///
550 /// * The count of masks.
551 #[must_use]
552 pub fn len(&self) -> usize {
553 self.data.shape()[0]
554 }
555
556 /// Check if there are no masks.
557 ///
558 /// # Returns
559 ///
560 /// * `true` if the masks array is empty.
561 #[must_use]
562 pub fn is_empty(&self) -> bool {
563 self.data.is_empty()
564 }
565
566 // TODO: Implement xy and xyn properties for segment coordinates
567}
568
/// Pose keypoints.
///
/// Stores per-object keypoint coordinates, optionally with a per-keypoint
/// confidence channel.
#[derive(Debug, Clone)]
pub struct Keypoints {
    /// Raw keypoint data with shape (N, K, 2) or (N, K, 3) if confidence included.
    pub data: Array3<f32>,
    /// Original image shape (height, width).
    pub orig_shape: (u32, u32),
    /// Whether confidence values are included (derived from the last axis in `new`).
    has_visible: bool,
}
581
582impl Keypoints {
583 /// Create a new Keypoints instance.
584 ///
585 /// # Arguments
586 ///
587 /// * `data` - Raw keypoint data.
588 /// * `orig_shape` - Original image shape.
589 ///
590 /// # Returns
591 ///
592 /// * A new `Keypoints` instance.
593 #[must_use]
594 pub fn new(data: Array3<f32>, orig_shape: (u32, u32)) -> Self {
595 let has_visible = data.shape()[2] == 3;
596 Self {
597 data,
598 orig_shape,
599 has_visible,
600 }
601 }
602
603 /// Get the number of detected objects with keypoints.
604 ///
605 /// # Returns
606 ///
607 /// * The count of poses.
608 #[must_use]
609 pub fn len(&self) -> usize {
610 self.data.shape()[0]
611 }
612
613 /// Check if there are no keypoints.
614 ///
615 /// # Returns
616 ///
617 /// * `true` if no keypoints were detected.
618 #[must_use]
619 pub fn is_empty(&self) -> bool {
620 self.data.is_empty()
621 }
622
623 /// Get xy coordinates.
624 ///
625 /// # Returns
626 ///
627 /// * An owned array of keypoint coordinates.
628 #[must_use]
629 pub fn xy(&self) -> Array3<f32> {
630 self.data.slice(s![.., .., 0..2]).to_owned()
631 }
632
633 /// Get normalized xy coordinates.
634 ///
635 /// # Returns
636 ///
637 /// * An owned array of normalized keypoint coordinates.
638 #[must_use]
639 pub fn xyn(&self) -> Array3<f32> {
640 let mut xyn = self.xy();
641 #[allow(clippy::cast_precision_loss)]
642 let (h, w) = (self.orig_shape.0 as f32, self.orig_shape.1 as f32);
643
644 for mut point in xyn.axis_iter_mut(Axis(2)) {
645 if point.shape()[0] > 0 {
646 point.mapv_inplace(|v| v / w);
647 }
648 if point.shape()[0] > 1 {
649 point.mapv_inplace(|v| v / h);
650 }
651 }
652
653 xyn
654 }
655
656 /// Get confidence values (if available).
657 ///
658 /// # Returns
659 ///
660 /// * `Some` array of confidences if available, otherwise `None`.
661 #[must_use]
662 pub fn conf(&self) -> Option<Array2<f32>> {
663 if self.has_visible {
664 Some(self.data.slice(s![.., .., 2]).to_owned())
665 } else {
666 None
667 }
668 }
669}
670
/// Classification probabilities.
///
/// Stores class probabilities with convenience methods for top predictions.
#[derive(Debug, Clone)]
pub struct Probs {
    /// Raw probability data with shape (`num_classes`,), one score per class.
    pub data: Array1<f32>,
}
679
680impl Probs {
681 /// Create a new Probs instance.
682 ///
683 /// # Arguments
684 ///
685 /// * `data` - Raw probability array.
686 ///
687 /// # Returns
688 ///
689 /// * A new `Probs` instance.
690 #[must_use]
691 pub const fn new(data: Array1<f32>) -> Self {
692 Self { data }
693 }
694
695 /// Get the index of the top-1 class.
696 ///
697 /// # Returns
698 ///
699 /// * The class ID with the highest probability.
700 ///
701 /// # Panics
702 ///
703 /// Panics if valid comparison cannot be made (e.g. NaN) in `max_by`.
704 #[must_use]
705 pub fn top1(&self) -> usize {
706 self.data
707 .iter()
708 .enumerate()
709 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
710 .map_or(0, |(i, _)| i)
711 }
712 /// Get the indices of the top-5 classes.
713 ///
714 /// # Returns
715 ///
716 /// * A vector of the top 5 class IDs sorted by probability.
717 #[must_use]
718 pub fn top5(&self) -> Vec<usize> {
719 self.top_k(5)
720 }
721
722 /// Get the indices of the top-k classes.
723 ///
724 /// # Arguments
725 ///
726 /// * `k` - The number of classes to return.
727 ///
728 /// # Returns
729 ///
730 /// * A vector of the top k class IDs sorted by probability.
731 #[must_use]
732 pub fn top_k(&self, k: usize) -> Vec<usize> {
733 let mut indices: Vec<usize> = (0..self.data.len()).collect();
734 indices.sort_by(|&a, &b| {
735 self.data[b]
736 .partial_cmp(&self.data[a])
737 .unwrap_or(std::cmp::Ordering::Equal)
738 });
739 indices.truncate(k);
740 indices
741 }
742
743 /// Get the confidence of the top-1 class.
744 ///
745 /// # Returns
746 ///
747 /// * The probability of the top class.
748 #[must_use]
749 pub fn top1conf(&self) -> f32 {
750 self.data[self.top1()]
751 }
752
753 /// Get the confidences of the top-5 classes.
754 ///
755 /// # Returns
756 ///
757 /// * A vector of the top 5 probabilities.
758 #[must_use]
759 pub fn top5conf(&self) -> Vec<f32> {
760 self.top5().iter().map(|&i| self.data[i]).collect()
761 }
762}
763
/// Oriented bounding boxes.
///
/// Stores rotated boxes as center/size/rotation rows plus confidence and class.
#[derive(Debug, Clone)]
pub struct Obb {
    /// Raw OBB data with shape (N, 7) containing [x, y, w, h, rotation, conf, cls].
    /// Or shape (N, 8) if tracking (track ID inserted before conf).
    pub data: Array2<f32>,
    /// Original image shape (height, width).
    pub orig_shape: (u32, u32),
    /// Whether tracking IDs are present (derived from the column count in `new`).
    is_track: bool,
}
777
778impl Obb {
779 /// Create a new Obb instance.
780 ///
781 /// # Arguments
782 ///
783 /// * `data` - Raw OBB data.
784 /// * `orig_shape` - Original image shape.
785 ///
786 /// # Returns
787 ///
788 /// * A new `Obb` instance.
789 #[must_use]
790 pub fn new(data: Array2<f32>, orig_shape: (u32, u32)) -> Self {
791 let is_track = data.shape()[1] == 8;
792 Self {
793 data,
794 orig_shape,
795 is_track,
796 }
797 }
798
799 /// Get the number of OBBs.
800 ///
801 /// # Returns
802 ///
803 /// * The count of oriented bounding boxes.
804 #[must_use]
805 pub fn len(&self) -> usize {
806 self.data.nrows()
807 }
808
809 /// Check if there are no OBBs.
810 ///
811 /// # Returns
812 ///
813 /// * `true` if empty.
814 #[must_use]
815 pub fn is_empty(&self) -> bool {
816 self.data.is_empty()
817 }
818
819 /// Get boxes in xywhr format [`x_center`, `y_center`, width, height, rotation].
820 ///
821 /// # Returns
822 ///
823 /// * A view of the box parameters.
824 #[must_use]
825 pub fn xywhr(&self) -> ArrayView2<'_, f32> {
826 self.data.slice(s![.., 0..5])
827 }
828
829 /// Get confidence scores.
830 ///
831 /// # Returns
832 ///
833 /// * A view of confidence scores.
834 #[must_use]
835 pub fn conf(&self) -> ArrayView1<'_, f32> {
836 self.data.slice(s![.., -2])
837 }
838
839 /// Get class IDs.
840 ///
841 /// # Returns
842 ///
843 /// * A view of class IDs.
844 #[must_use]
845 pub fn cls(&self) -> ArrayView1<'_, f32> {
846 self.data.slice(s![.., -1])
847 }
848
849 /// Get tracking IDs (if available).
850 ///
851 /// # Returns
852 ///
853 /// * `Some` view of track IDs if available, otherwise `None`.
854 #[must_use]
855 pub fn id(&self) -> Option<ArrayView1<'_, f32>> {
856 if self.is_track {
857 Some(self.data.slice(s![.., -3]))
858 } else {
859 None
860 }
861 }
862
863 /// Get corner points for each OBB as (N, 4, 2) array.
864 /// Returns the 4 corner points of each rotated bounding box.
865 ///
866 /// # Returns
867 ///
868 /// * An owned array of shape (N, 4, 2) containing corner coordinates.
869 #[must_use]
870 pub fn xyxyxyxy(&self) -> Array3<f32> {
871 let n = self.len();
872 let mut corners = Array3::zeros((n, 4, 2));
873
874 for i in 0..n {
875 let cx = self.data[[i, 0]];
876 let cy = self.data[[i, 1]];
877 let w = self.data[[i, 2]];
878 let h = self.data[[i, 3]];
879 let angle = self.data[[i, 4]];
880
881 // Calculate corner offsets from center
882 let cos_a = angle.cos();
883 let sin_a = angle.sin();
884
885 // Half dimensions
886 let hw = w / 2.0;
887 let hh = h / 2.0;
888
889 // Corner offsets relative to center (before rotation)
890 let corners_rel = [
891 (-hw, -hh), // top-left
892 (hw, -hh), // top-right
893 (hw, hh), // bottom-right
894 (-hw, hh), // bottom-left
895 ];
896
897 // Apply rotation and translate to absolute coordinates
898 for (j, (dx, dy)) in corners_rel.iter().enumerate() {
899 let rotated_x = dx * cos_a - dy * sin_a;
900 let rotated_y = dx * sin_a + dy * cos_a;
901 corners[[i, j, 0]] = cx + rotated_x;
902 corners[[i, j, 1]] = cy + rotated_y;
903 }
904 }
905
906 corners
907 }
908
909 /// Get axis-aligned bounding box containing each OBB.
910 /// Returns array of shape (N, 4) with [x1, y1, x2, y2] for each OBB.
911 ///
912 /// # Returns
913 ///
914 /// * An owned array of axis-aligned bounding boxes.
915 #[must_use]
916 pub fn xyxy(&self) -> Array2<f32> {
917 let corners = self.xyxyxyxy();
918 let n = self.len();
919 let mut xyxy = Array2::zeros((n, 4));
920
921 for i in 0..n {
922 let mut min_x = f32::INFINITY;
923 let mut min_y = f32::INFINITY;
924 let mut max_x = f32::NEG_INFINITY;
925 let mut max_y = f32::NEG_INFINITY;
926
927 for j in 0..4 {
928 let x = corners[[i, j, 0]];
929 let y = corners[[i, j, 1]];
930 min_x = min_x.min(x);
931 min_y = min_y.min(y);
932 max_x = max_x.max(x);
933 max_y = max_y.max(y);
934 }
935
936 // Clip to image bounds
937 #[allow(clippy::cast_precision_loss)]
938 let (h, w) = (self.orig_shape.0 as f32, self.orig_shape.1 as f32);
939 xyxy[[i, 0]] = min_x.max(0.0).min(w);
940 xyxy[[i, 1]] = min_y.max(0.0).min(h);
941 xyxy[[i, 2]] = max_x.max(0.0).min(w);
942 xyxy[[i, 3]] = max_y.max(0.0).min(h);
943 }
944
945 xyxy
946 }
947}
948
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_boxes_xyxy() {
        // One box: [x1, y1, x2, y2, conf, cls] on a 640x480 image.
        let boxes = Boxes::new(array![[10.0, 20.0, 100.0, 200.0, 0.95, 0.0]], (480, 640));

        assert_eq!(boxes.len(), 1);
        assert!((boxes.conf()[0] - 0.95).abs() < 1e-6);
        assert!((boxes.cls()[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_boxes_xywh() {
        let boxes = Boxes::new(array![[0.0, 0.0, 100.0, 100.0, 0.9, 1.0]], (640, 640));
        let xywh = boxes.xywh();

        // A (0,0)-(100,100) box is centered at (50,50) with 100x100 extent.
        assert!((xywh[[0, 0]] - 50.0).abs() < 1e-6); // x_center
        assert!((xywh[[0, 1]] - 50.0).abs() < 1e-6); // y_center
        assert!((xywh[[0, 2]] - 100.0).abs() < 1e-6); // width
        assert!((xywh[[0, 3]] - 100.0).abs() < 1e-6); // height
    }

    #[test]
    fn test_boxes_normalized() {
        let boxes = Boxes::new(array![[0.0, 0.0, 320.0, 240.0, 0.9, 0.0]], (480, 640));
        let normalized = boxes.xyxyn();

        assert!((normalized[[0, 0]] - 0.0).abs() < 1e-6);
        assert!((normalized[[0, 1]] - 0.0).abs() < 1e-6);
        assert!((normalized[[0, 2]] - 0.5).abs() < 1e-6); // 320/640
        assert!((normalized[[0, 3]] - 0.5).abs() < 1e-6); // 240/480
    }

    #[test]
    fn test_probs() {
        let probs = Probs::new(array![0.1, 0.3, 0.6]);

        assert_eq!(probs.top1(), 2);
        assert_eq!(probs.top5(), vec![2, 1, 0]);
        assert!((probs.top1conf() - 0.6).abs() < 1e-6);
    }

    #[test]
    fn test_speed() {
        let timing = Speed::new(10.0, 20.0, 5.0);
        assert!((timing.total() - 35.0).abs() < 1e-6);
    }

    #[test]
    fn test_results_verbose() {
        let names = HashMap::from([(0, "person".to_string())]);

        // A results object with no detections attached reports "(no detections)".
        let results = Results::new(
            Array3::zeros((100, 100, 3)),
            "test.jpg".to_string(),
            names,
            Speed::default(),
            (640, 640),
        );
        assert!(results.is_empty());
        assert_eq!(results.verbose(), "(no detections), ");
    }
}
1014}