// ultralytics_inference/results.rs
// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

//! Results classes for YOLO inference output.
//!
//! This module provides Ultralytics-compatible result classes that match
//! the Python API for easy migration and consistent usage patterns.
use std::collections::HashMap;

use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis, s};
11
/// Timing information for inference operations (in milliseconds).
#[derive(Debug, Clone, Default)]
pub struct Speed {
    /// Time spent on preprocessing.
    pub preprocess: Option<f64>,
    /// Time spent on model inference.
    pub inference: Option<f64>,
    /// Time spent on postprocessing.
    pub postprocess: Option<f64>,
}

impl Speed {
    /// Build a `Speed` with every stage populated.
    ///
    /// # Arguments
    ///
    /// * `preprocess` - Preprocessing time in milliseconds.
    /// * `inference` - Model inference time in milliseconds.
    /// * `postprocess` - Postprocessing time in milliseconds.
    ///
    /// # Returns
    ///
    /// * A `Speed` whose three timings are all `Some`.
    #[must_use]
    pub const fn new(preprocess: f64, inference: f64, postprocess: f64) -> Self {
        Self {
            preprocess: Some(preprocess),
            inference: Some(inference),
            postprocess: Some(postprocess),
        }
    }

    /// Total wall time across all stages.
    ///
    /// # Returns
    ///
    /// * Sum of the recorded timings in milliseconds; unrecorded stages count as zero.
    #[must_use]
    pub fn total(&self) -> f64 {
        [self.preprocess, self.inference, self.postprocess]
            .iter()
            .map(|stage| stage.unwrap_or(0.0))
            .sum()
    }
}
56
/// Main results container for YOLO inference.
///
/// Contains the original image, detection results (boxes, masks, keypoints, etc.),
/// timing information, and metadata. Which of the optional fields are populated
/// depends on the task (detect, segment, pose, classify, OBB).
#[derive(Debug, Clone)]
pub struct Results {
    /// Original image as HWC array (height, width, channels).
    pub orig_img: Array3<u8>,
    /// Original image shape (height, width).
    pub orig_shape: (u32, u32),
    /// Inference tensor shape (height, width) after letterboxing.
    pub inference_shape: (u32, u32),
    /// Detection bounding boxes (if applicable).
    pub boxes: Option<Boxes>,
    /// Segmentation masks (if applicable).
    pub masks: Option<Masks>,
    /// Pose keypoints (if applicable).
    pub keypoints: Option<Keypoints>,
    /// Classification probabilities (if applicable).
    pub probs: Option<Probs>,
    /// Oriented bounding boxes (if applicable).
    pub obb: Option<Obb>,
    /// Inference timing information.
    pub speed: Speed,
    /// Class ID to name mapping.
    pub names: HashMap<usize, String>,
    /// Path to the source image/video.
    pub path: String,
}
85
86impl Results {
87    /// Create a new Results instance.
88    ///
89    /// # Arguments
90    ///
91    /// * `orig_img` - Original image as HWC array.
92    /// * `path` - Path to the source image/video.
93    /// * `names` - Map of class IDs to class names.
94    /// * `speed` - Timing information.
95    /// * `inference_shape` - Shape of the inference tensor (height, width).
96    ///
97    /// # Returns
98    ///
99    /// * A new `Results` instance.
100    #[must_use]
101    pub fn new(
102        orig_img: Array3<u8>,
103        path: String,
104        names: HashMap<usize, String>,
105        speed: Speed,
106        inference_shape: (u32, u32),
107    ) -> Self {
108        let shape = orig_img.shape();
109        #[allow(clippy::cast_possible_truncation)]
110        let orig_shape = (shape[0] as u32, shape[1] as u32);
111
112        Self {
113            orig_img,
114            orig_shape,
115            inference_shape,
116            boxes: None,
117            masks: None,
118            keypoints: None,
119            probs: None,
120            obb: None,
121            speed,
122            names,
123            path,
124        }
125    }
126
127    /// Get the number of detections.
128    ///
129    /// # Returns
130    ///
131    /// * The count of detected objects, keyspoints, or masks.
132    #[must_use]
133    pub fn len(&self) -> usize {
134        if let Some(ref boxes) = self.boxes {
135            return boxes.len();
136        }
137        if let Some(ref masks) = self.masks {
138            return masks.len();
139        }
140        if let Some(ref keypoints) = self.keypoints {
141            return keypoints.len();
142        }
143        if let Some(ref probs) = self.probs {
144            return usize::from(!probs.data.is_empty());
145        }
146        if let Some(ref obb) = self.obb {
147            return obb.len();
148        }
149        0
150    }
151
152    /// Check if there are no detections.
153    ///
154    /// # Returns
155    ///
156    /// * `true` if no objects were detected.
157    #[must_use]
158    pub fn is_empty(&self) -> bool {
159        self.len() == 0
160    }
161
162    /// Get the original image shape (height, width).
163    ///
164    /// # Returns
165    ///
166    /// * Tuple of (height, width).
167    #[must_use]
168    pub const fn orig_shape(&self) -> (u32, u32) {
169        self.orig_shape
170    }
171
172    /// Get the inference tensor shape (height, width) after letterboxing.
173    ///
174    /// # Returns
175    ///
176    /// * Tuple of (height, width).
177    #[must_use]
178    pub const fn inference_shape(&self) -> (u32, u32) {
179        self.inference_shape
180    }
181
182    /// Generate a verbose log string describing the results.
183    ///
184    /// # Returns
185    ///
186    /// * A string summary of detections (e.g., "2 persons, 1 car, ").
187    #[must_use]
188    pub fn verbose(&self) -> String {
189        if self.is_empty() {
190            if self.probs.is_some() {
191                return String::new();
192            }
193            return "(no detections), ".to_string();
194        }
195
196        if let Some(ref probs) = self.probs {
197            let top5: Vec<String> = probs
198                .top5()
199                .iter()
200                .map(|&i| {
201                    let name = self.names.get(&i).cloned().unwrap_or_else(|| i.to_string());
202                    format!("{} {:.2}", name, probs.data[i])
203                })
204                .collect();
205            return format!("{}, ", top5.join(", "));
206        }
207
208        if let Some(ref boxes) = self.boxes {
209            let cls = boxes.cls();
210            let mut counts: HashMap<usize, usize> = HashMap::new();
211            for &c in cls {
212                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
213                let c = c as usize;
214                *counts.entry(c).or_insert(0) += 1;
215            }
216
217            let mut parts = Vec::new();
218            for (class_id, count) in &counts {
219                let name = self
220                    .names
221                    .get(class_id)
222                    .cloned()
223                    .unwrap_or_else(|| class_id.to_string());
224                let suffix = if *count > 1 { "s" } else { "" };
225                parts.push(format!("{count} {name}{suffix}"));
226            }
227            return format!("{}, ", parts.join(", "));
228        }
229
230        String::new()
231    }
232
233    /// Convert results to a list of dictionaries (summary format).
234    ///
235    /// # Arguments
236    ///
237    /// * `normalize` - Whether to normalize coordinates to [0, 1] range.
238    ///
239    /// # Returns
240    ///
241    /// * A vector of hashmaps representing the detections.
242    #[must_use]
243    pub fn summary(&self, normalize: bool) -> Vec<HashMap<String, SummaryValue>> {
244        let mut results = Vec::new();
245
246        if let Some(ref probs) = self.probs {
247            let class_id = probs.top1();
248            let mut entry = HashMap::new();
249            entry.insert(
250                "name".to_string(),
251                SummaryValue::String(
252                    self.names
253                        .get(&class_id)
254                        .cloned()
255                        .unwrap_or_else(|| class_id.to_string()),
256                ),
257            );
258            entry.insert("class".to_string(), SummaryValue::Int(class_id));
259            entry.insert(
260                "confidence".to_string(),
261                SummaryValue::Float(probs.top1conf()),
262            );
263            results.push(entry);
264            return results;
265        }
266
267        if let Some(ref boxes) = self.boxes {
268            let (h, w) = if normalize {
269                #[allow(clippy::cast_precision_loss)]
270                (self.orig_shape.0 as f32, self.orig_shape.1 as f32)
271            } else {
272                (1.0, 1.0)
273            };
274
275            let xyxy = boxes.xyxy();
276            let conf = boxes.conf();
277            let cls = boxes.cls();
278
279            for i in 0..boxes.len() {
280                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
281                let class_id = cls[i] as usize;
282                let mut entry = HashMap::new();
283                entry.insert(
284                    "name".to_string(),
285                    SummaryValue::String(
286                        self.names
287                            .get(&class_id)
288                            .cloned()
289                            .unwrap_or_else(|| class_id.to_string()),
290                    ),
291                );
292                entry.insert("class".to_string(), SummaryValue::Int(class_id));
293                entry.insert("confidence".to_string(), SummaryValue::Float(conf[i]));
294
295                let mut box_coords = HashMap::new();
296                box_coords.insert("x1".to_string(), SummaryValue::Float(xyxy[[i, 0]] / w));
297                box_coords.insert("y1".to_string(), SummaryValue::Float(xyxy[[i, 1]] / h));
298                box_coords.insert("x2".to_string(), SummaryValue::Float(xyxy[[i, 2]] / w));
299                box_coords.insert("y2".to_string(), SummaryValue::Float(xyxy[[i, 3]] / h));
300                entry.insert("box".to_string(), SummaryValue::Box(box_coords));
301
302                results.push(entry);
303            }
304        }
305
306        results
307    }
308
309    /// Save the annotated result to a file.
310    ///
311    /// # Arguments
312    ///
313    /// * `path` - The path to save the image to.
314    ///
315    /// # Errors
316    ///
317    /// Returns an error if the image cannot be saved or if the format is unsupported.
318    #[cfg(feature = "annotate")]
319    pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> crate::error::Result<()> {
320        let img = crate::utils::array_to_image(&self.orig_img)?;
321        let annotated = crate::annotate::annotate_image(&img, self, None);
322        annotated
323            .save(path)
324            .map_err(|e| crate::error::InferenceError::ImageError(e.to_string()))
325    }
326}
327
/// Values that can appear in a summary dictionary.
///
/// Mirrors the loosely-typed dictionaries returned by the Python API's
/// `Results.summary()`.
#[derive(Debug, Clone)]
pub enum SummaryValue {
    /// String value (e.g. a class name).
    String(String),
    /// Integer value (e.g. a class ID).
    Int(usize),
    /// Float value (e.g. a confidence score or coordinate).
    Float(f32),
    /// Box coordinates keyed by "x1"/"y1"/"x2"/"y2".
    Box(HashMap<String, Self>),
}
340
/// Detection bounding boxes.
///
/// Stores bounding boxes in xyxy format along with confidence scores and class IDs.
#[derive(Debug, Clone)]
pub struct Boxes {
    /// Raw data array with shape (N, 6) containing [x1, y1, x2, y2, conf, cls].
    /// Or shape (N, 7) if tracking: [x1, y1, x2, y2, `track_id`, conf, cls].
    pub data: Array2<f32>,
    /// Original image shape (height, width) for normalization.
    pub orig_shape: (u32, u32),
    /// Whether tracking IDs are present (derived from the column count in `new`).
    is_track: bool,
}
354
355impl Boxes {
356    /// Create a new Boxes instance.
357    ///
358    /// # Arguments
359    ///
360    /// * `data` - Array with shape (N, 6) or (N, 7) containing box data.
361    /// * `orig_shape` - Original image shape (height, width).
362    ///
363    /// # Returns
364    ///
365    /// * A new `Boxes` instance.
366    #[must_use]
367    pub fn new(data: Array2<f32>, orig_shape: (u32, u32)) -> Self {
368        let is_track = data.shape()[1] == 7;
369        Self {
370            data,
371            orig_shape,
372            is_track,
373        }
374    }
375
376    /// Get the number of boxes.
377    ///
378    /// # Returns
379    ///
380    /// * The count of bounding boxes.
381    #[must_use]
382    pub fn len(&self) -> usize {
383        self.data.nrows()
384    }
385
386    /// Check if there are no boxes.
387    ///
388    /// # Returns
389    ///
390    /// * `true` if the boxes array is empty.
391    #[must_use]
392    pub fn is_empty(&self) -> bool {
393        self.data.is_empty()
394    }
395
396    /// Get boxes in xyxy format [x1, y1, x2, y2].
397    ///
398    /// # Returns
399    ///
400    /// * A view of the box coordinates.
401    #[must_use]
402    pub fn xyxy(&self) -> ArrayView2<'_, f32> {
403        self.data.slice(s![.., 0..4])
404    }
405
406    /// Get confidence scores.
407    ///
408    /// # Returns
409    ///
410    /// * A view of confidence scores (0.0 to 1.0).
411    #[must_use]
412    pub fn conf(&self) -> ArrayView1<'_, f32> {
413        self.data.slice(s![.., -2])
414    }
415
416    /// Get class IDs.
417    ///
418    /// # Returns
419    ///
420    /// * A view of class IDs.
421    #[must_use]
422    pub fn cls(&self) -> ArrayView1<'_, f32> {
423        self.data.slice(s![.., -1])
424    }
425
426    /// Get tracking IDs (if available).
427    ///
428    /// # Returns
429    ///
430    /// * `Some` view of track IDs if this is a tracking result, otherwise `None`.
431    #[must_use]
432    pub fn id(&self) -> Option<ArrayView1<'_, f32>> {
433        if self.is_track {
434            Some(self.data.slice(s![.., -3]))
435        } else {
436            None
437        }
438    }
439
440    /// Get boxes in xywh format [`x_center`, `y_center`, width, height].
441    ///
442    /// # Returns
443    ///
444    /// * An owned array of boxes in xywh format.
445    #[must_use]
446    pub fn xywh(&self) -> Array2<f32> {
447        let xyxy = self.xyxy();
448        let n = xyxy.nrows();
449        let mut xywh = Array2::zeros((n, 4));
450
451        for i in 0..n {
452            let x1 = xyxy[[i, 0]];
453            let y1 = xyxy[[i, 1]];
454            let x2 = xyxy[[i, 2]];
455            let y2 = xyxy[[i, 3]];
456
457            xywh[[i, 0]] = f32::midpoint(x1, x2); // x_center
458            xywh[[i, 1]] = f32::midpoint(y1, y2); // y_center
459            xywh[[i, 2]] = x2 - x1; // width
460            xywh[[i, 3]] = y2 - y1; // height
461        }
462
463        xywh
464    }
465
466    /// Get boxes in xyxy format normalized by image size.
467    ///
468    /// # Returns
469    ///
470    /// * An owned array of normalized boxes [0.0-1.0].
471    #[must_use]
472    pub fn xyxyn(&self) -> Array2<f32> {
473        let mut xyxyn = self.xyxy().to_owned();
474        #[allow(clippy::cast_precision_loss)]
475        let (h, w) = (self.orig_shape.0 as f32, self.orig_shape.1 as f32);
476
477        for mut row in xyxyn.rows_mut() {
478            row[0] /= w;
479            row[1] /= h;
480            row[2] /= w;
481            row[3] /= h;
482        }
483
484        xyxyn
485    }
486
487    /// Get boxes in xywh format normalized by image size.
488    ///
489    /// # Returns
490    ///
491    /// * An owned array of normalized boxes [0.0-1.0].
492    #[must_use]
493    pub fn xywhn(&self) -> Array2<f32> {
494        let mut xywhn = self.xywh();
495        #[allow(clippy::cast_precision_loss)]
496        let (h, w) = (self.orig_shape.0 as f32, self.orig_shape.1 as f32);
497
498        for mut row in xywhn.rows_mut() {
499            row[0] /= w;
500            row[1] /= h;
501            row[2] /= w;
502            row[3] /= h;
503        }
504
505        xywhn
506    }
507
508    /// Check if tracking IDs are available.
509    ///
510    /// # Returns
511    ///
512    /// * `true` if the boxes contain tracking information.
513    #[must_use]
514    pub const fn is_track(&self) -> bool {
515        self.is_track
516    }
517}
518
/// Segmentation masks.
///
/// Minimal container for per-detection masks; polygon accessors (`xy`/`xyn`)
/// are not implemented yet (see TODO in the impl).
#[derive(Debug, Clone)]
pub struct Masks {
    /// Raw mask data with shape (N, H, W).
    pub data: Array3<f32>,
    /// Original image shape (height, width).
    pub orig_shape: (u32, u32),
}
529
530impl Masks {
531    /// Create a new Masks instance.
532    ///
533    /// # Arguments
534    ///
535    /// * `data` - Raw mask data with shape (N, H, W).
536    /// * `orig_shape` - Original image shape (height, width).
537    ///
538    /// # Returns
539    ///
540    /// * A new `Masks` instance.
541    #[must_use]
542    pub const fn new(data: Array3<f32>, orig_shape: (u32, u32)) -> Self {
543        Self { data, orig_shape }
544    }
545
546    /// Get the number of masks.
547    ///
548    /// # Returns
549    ///
550    /// * The count of masks.
551    #[must_use]
552    pub fn len(&self) -> usize {
553        self.data.shape()[0]
554    }
555
556    /// Check if there are no masks.
557    ///
558    /// # Returns
559    ///
560    /// * `true` if the masks array is empty.
561    #[must_use]
562    pub fn is_empty(&self) -> bool {
563        self.data.is_empty()
564    }
565
566    // TODO: Implement xy and xyn properties for segment coordinates
567}
568
/// Pose keypoints.
///
/// Stores per-detection keypoint coordinates, with optional per-point
/// confidence/visibility when the trailing dimension is 3.
#[derive(Debug, Clone)]
pub struct Keypoints {
    /// Raw keypoint data with shape (N, K, 2) or (N, K, 3) if confidence included.
    pub data: Array3<f32>,
    /// Original image shape (height, width).
    pub orig_shape: (u32, u32),
    /// Whether confidence values are included (derived from the trailing dim in `new`).
    has_visible: bool,
}
581
582impl Keypoints {
583    /// Create a new Keypoints instance.
584    ///
585    /// # Arguments
586    ///
587    /// * `data` - Raw keypoint data.
588    /// * `orig_shape` - Original image shape.
589    ///
590    /// # Returns
591    ///
592    /// * A new `Keypoints` instance.
593    #[must_use]
594    pub fn new(data: Array3<f32>, orig_shape: (u32, u32)) -> Self {
595        let has_visible = data.shape()[2] == 3;
596        Self {
597            data,
598            orig_shape,
599            has_visible,
600        }
601    }
602
603    /// Get the number of detected objects with keypoints.
604    ///
605    /// # Returns
606    ///
607    /// * The count of poses.
608    #[must_use]
609    pub fn len(&self) -> usize {
610        self.data.shape()[0]
611    }
612
613    /// Check if there are no keypoints.
614    ///
615    /// # Returns
616    ///
617    /// * `true` if no keypoints were detected.
618    #[must_use]
619    pub fn is_empty(&self) -> bool {
620        self.data.is_empty()
621    }
622
623    /// Get xy coordinates.
624    ///
625    /// # Returns
626    ///
627    /// * An owned array of keypoint coordinates.
628    #[must_use]
629    pub fn xy(&self) -> Array3<f32> {
630        self.data.slice(s![.., .., 0..2]).to_owned()
631    }
632
633    /// Get normalized xy coordinates.
634    ///
635    /// # Returns
636    ///
637    /// * An owned array of normalized keypoint coordinates.
638    #[must_use]
639    pub fn xyn(&self) -> Array3<f32> {
640        let mut xyn = self.xy();
641        #[allow(clippy::cast_precision_loss)]
642        let (h, w) = (self.orig_shape.0 as f32, self.orig_shape.1 as f32);
643
644        for mut point in xyn.axis_iter_mut(Axis(2)) {
645            if point.shape()[0] > 0 {
646                point.mapv_inplace(|v| v / w);
647            }
648            if point.shape()[0] > 1 {
649                point.mapv_inplace(|v| v / h);
650            }
651        }
652
653        xyn
654    }
655
656    /// Get confidence values (if available).
657    ///
658    /// # Returns
659    ///
660    /// * `Some` array of confidences if available, otherwise `None`.
661    #[must_use]
662    pub fn conf(&self) -> Option<Array2<f32>> {
663        if self.has_visible {
664            Some(self.data.slice(s![.., .., 2]).to_owned())
665        } else {
666            None
667        }
668    }
669}
670
/// Classification probabilities.
///
/// Stores class probabilities with convenience methods for top predictions.
#[derive(Debug, Clone)]
pub struct Probs {
    /// Raw probability data with shape (`num_classes`,), indexed by class ID.
    pub data: Array1<f32>,
}
679
680impl Probs {
681    /// Create a new Probs instance.
682    ///
683    /// # Arguments
684    ///
685    /// * `data` - Raw probability array.
686    ///
687    /// # Returns
688    ///
689    /// * A new `Probs` instance.
690    #[must_use]
691    pub const fn new(data: Array1<f32>) -> Self {
692        Self { data }
693    }
694
695    /// Get the index of the top-1 class.
696    ///
697    /// # Returns
698    ///
699    /// * The class ID with the highest probability.
700    ///
701    /// # Panics
702    ///
703    /// Panics if valid comparison cannot be made (e.g. NaN) in `max_by`.
704    #[must_use]
705    pub fn top1(&self) -> usize {
706        self.data
707            .iter()
708            .enumerate()
709            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
710            .map_or(0, |(i, _)| i)
711    }
712    /// Get the indices of the top-5 classes.
713    ///
714    /// # Returns
715    ///
716    /// * A vector of the top 5 class IDs sorted by probability.
717    #[must_use]
718    pub fn top5(&self) -> Vec<usize> {
719        self.top_k(5)
720    }
721
722    /// Get the indices of the top-k classes.
723    ///
724    /// # Arguments
725    ///
726    /// * `k` - The number of classes to return.
727    ///
728    /// # Returns
729    ///
730    /// * A vector of the top k class IDs sorted by probability.
731    #[must_use]
732    pub fn top_k(&self, k: usize) -> Vec<usize> {
733        let mut indices: Vec<usize> = (0..self.data.len()).collect();
734        indices.sort_by(|&a, &b| {
735            self.data[b]
736                .partial_cmp(&self.data[a])
737                .unwrap_or(std::cmp::Ordering::Equal)
738        });
739        indices.truncate(k);
740        indices
741    }
742
743    /// Get the confidence of the top-1 class.
744    ///
745    /// # Returns
746    ///
747    /// * The probability of the top class.
748    #[must_use]
749    pub fn top1conf(&self) -> f32 {
750        self.data[self.top1()]
751    }
752
753    /// Get the confidences of the top-5 classes.
754    ///
755    /// # Returns
756    ///
757    /// * A vector of the top 5 probabilities.
758    #[must_use]
759    pub fn top5conf(&self) -> Vec<f32> {
760        self.top5().iter().map(|&i| self.data[i]).collect()
761    }
762}
763
/// Oriented bounding boxes.
///
/// Stores rotated boxes in xywhr format along with confidence scores and class IDs.
#[derive(Debug, Clone)]
pub struct Obb {
    /// Raw OBB data with shape (N, 7) containing [x, y, w, h, rotation, conf, cls].
    /// Or shape (N, 8) if tracking.
    pub data: Array2<f32>,
    /// Original image shape (height, width).
    pub orig_shape: (u32, u32),
    /// Whether tracking IDs are present (derived from the column count in `new`).
    is_track: bool,
}
777
778impl Obb {
779    /// Create a new Obb instance.
780    ///
781    /// # Arguments
782    ///
783    /// * `data` - Raw OBB data.
784    /// * `orig_shape` - Original image shape.
785    ///
786    /// # Returns
787    ///
788    /// * A new `Obb` instance.
789    #[must_use]
790    pub fn new(data: Array2<f32>, orig_shape: (u32, u32)) -> Self {
791        let is_track = data.shape()[1] == 8;
792        Self {
793            data,
794            orig_shape,
795            is_track,
796        }
797    }
798
799    /// Get the number of OBBs.
800    ///
801    /// # Returns
802    ///
803    /// * The count of oriented bounding boxes.
804    #[must_use]
805    pub fn len(&self) -> usize {
806        self.data.nrows()
807    }
808
809    /// Check if there are no OBBs.
810    ///
811    /// # Returns
812    ///
813    /// * `true` if empty.
814    #[must_use]
815    pub fn is_empty(&self) -> bool {
816        self.data.is_empty()
817    }
818
819    /// Get boxes in xywhr format [`x_center`, `y_center`, width, height, rotation].
820    ///
821    /// # Returns
822    ///
823    /// * A view of the box parameters.
824    #[must_use]
825    pub fn xywhr(&self) -> ArrayView2<'_, f32> {
826        self.data.slice(s![.., 0..5])
827    }
828
829    /// Get confidence scores.
830    ///
831    /// # Returns
832    ///
833    /// * A view of confidence scores.
834    #[must_use]
835    pub fn conf(&self) -> ArrayView1<'_, f32> {
836        self.data.slice(s![.., -2])
837    }
838
839    /// Get class IDs.
840    ///
841    /// # Returns
842    ///
843    /// * A view of class IDs.
844    #[must_use]
845    pub fn cls(&self) -> ArrayView1<'_, f32> {
846        self.data.slice(s![.., -1])
847    }
848
849    /// Get tracking IDs (if available).
850    ///
851    /// # Returns
852    ///
853    /// * `Some` view of track IDs if available, otherwise `None`.
854    #[must_use]
855    pub fn id(&self) -> Option<ArrayView1<'_, f32>> {
856        if self.is_track {
857            Some(self.data.slice(s![.., -3]))
858        } else {
859            None
860        }
861    }
862
863    /// Get corner points for each OBB as (N, 4, 2) array.
864    /// Returns the 4 corner points of each rotated bounding box.
865    ///
866    /// # Returns
867    ///
868    /// * An owned array of shape (N, 4, 2) containing corner coordinates.
869    #[must_use]
870    pub fn xyxyxyxy(&self) -> Array3<f32> {
871        let n = self.len();
872        let mut corners = Array3::zeros((n, 4, 2));
873
874        for i in 0..n {
875            let cx = self.data[[i, 0]];
876            let cy = self.data[[i, 1]];
877            let w = self.data[[i, 2]];
878            let h = self.data[[i, 3]];
879            let angle = self.data[[i, 4]];
880
881            // Calculate corner offsets from center
882            let cos_a = angle.cos();
883            let sin_a = angle.sin();
884
885            // Half dimensions
886            let hw = w / 2.0;
887            let hh = h / 2.0;
888
889            // Corner offsets relative to center (before rotation)
890            let corners_rel = [
891                (-hw, -hh), // top-left
892                (hw, -hh),  // top-right
893                (hw, hh),   // bottom-right
894                (-hw, hh),  // bottom-left
895            ];
896
897            // Apply rotation and translate to absolute coordinates
898            for (j, (dx, dy)) in corners_rel.iter().enumerate() {
899                let rotated_x = dx * cos_a - dy * sin_a;
900                let rotated_y = dx * sin_a + dy * cos_a;
901                corners[[i, j, 0]] = cx + rotated_x;
902                corners[[i, j, 1]] = cy + rotated_y;
903            }
904        }
905
906        corners
907    }
908
909    /// Get axis-aligned bounding box containing each OBB.
910    /// Returns array of shape (N, 4) with [x1, y1, x2, y2] for each OBB.
911    ///
912    /// # Returns
913    ///
914    /// * An owned array of axis-aligned bounding boxes.
915    #[must_use]
916    pub fn xyxy(&self) -> Array2<f32> {
917        let corners = self.xyxyxyxy();
918        let n = self.len();
919        let mut xyxy = Array2::zeros((n, 4));
920
921        for i in 0..n {
922            let mut min_x = f32::INFINITY;
923            let mut min_y = f32::INFINITY;
924            let mut max_x = f32::NEG_INFINITY;
925            let mut max_y = f32::NEG_INFINITY;
926
927            for j in 0..4 {
928                let x = corners[[i, j, 0]];
929                let y = corners[[i, j, 1]];
930                min_x = min_x.min(x);
931                min_y = min_y.min(y);
932                max_x = max_x.max(x);
933                max_y = max_y.max(y);
934            }
935
936            // Clip to image bounds
937            #[allow(clippy::cast_precision_loss)]
938            let (h, w) = (self.orig_shape.0 as f32, self.orig_shape.1 as f32);
939            xyxy[[i, 0]] = min_x.max(0.0).min(w);
940            xyxy[[i, 1]] = min_y.max(0.0).min(h);
941            xyxy[[i, 2]] = max_x.max(0.0).min(w);
942            xyxy[[i, 3]] = max_y.max(0.0).min(h);
943        }
944
945        xyxy
946    }
947}
948
949#[cfg(test)]
950mod tests {
951    use super::*;
952    use ndarray::array;
953
    #[test]
    fn test_boxes_xyxy() {
        // Single box [x1, y1, x2, y2, conf, cls] on a 640x480 image.
        let data = array![[10.0, 20.0, 100.0, 200.0, 0.95, 0.0]];
        let boxes = Boxes::new(data, (480, 640));

        assert_eq!(boxes.len(), 1);
        assert!((boxes.conf()[0] - 0.95).abs() < 1e-6);
        assert!((boxes.cls()[0] - 0.0).abs() < 1e-6);
    }
963
    #[test]
    fn test_boxes_xywh() {
        // A 100x100 box anchored at the origin converts to center (50, 50).
        let data = array![[0.0, 0.0, 100.0, 100.0, 0.9, 1.0]];
        let boxes = Boxes::new(data, (640, 640));
        let xywh = boxes.xywh();

        assert!((xywh[[0, 0]] - 50.0).abs() < 1e-6); // x_center
        assert!((xywh[[0, 1]] - 50.0).abs() < 1e-6); // y_center
        assert!((xywh[[0, 2]] - 100.0).abs() < 1e-6); // width
        assert!((xywh[[0, 3]] - 100.0).abs() < 1e-6); // height
    }
975
    #[test]
    fn test_boxes_normalized() {
        // x divides by width (640), y divides by height (480).
        let data = array![[0.0, 0.0, 320.0, 240.0, 0.9, 0.0]];
        let boxes = Boxes::new(data, (480, 640));
        let xyxyn = boxes.xyxyn();

        assert!((xyxyn[[0, 0]] - 0.0).abs() < 1e-6);
        assert!((xyxyn[[0, 1]] - 0.0).abs() < 1e-6);
        assert!((xyxyn[[0, 2]] - 0.5).abs() < 1e-6); // 320/640
        assert!((xyxyn[[0, 3]] - 0.5).abs() < 1e-6); // 240/480
    }
987
    #[test]
    fn test_probs() {
        // Three classes; top-5 falls back to all available indices.
        let data = array![0.1, 0.3, 0.6];
        let probs = Probs::new(data);

        assert_eq!(probs.top1(), 2);
        assert_eq!(probs.top5(), vec![2, 1, 0]);
        assert!((probs.top1conf() - 0.6).abs() < 1e-6);
    }
997
    #[test]
    fn test_speed() {
        // total() sums all three recorded stages.
        let speed = Speed::new(10.0, 20.0, 5.0);
        assert!((speed.total() - 35.0).abs() < 1e-6);
    }
    #[test]
    fn test_results_verbose() {
        // A Results with no detections attached should report the
        // "(no detections), " summary string.
        let names = HashMap::from([(0, "person".to_string())]);
        let speed = Speed::default();
        let orig_img = Array3::zeros((100, 100, 3));

        // Empty results
        let results = Results::new(orig_img, "test.jpg".to_string(), names, speed, (640, 640));
        assert!(results.is_empty());
        assert_eq!(results.verbose(), "(no detections), ");
    }
1014}