Skip to main content

ultralytics_inference/
task.rs

1// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
3//! Task definitions for YOLO models.
4//!
5//! This module defines the different tasks that YOLO models can perform,
6//! along with their associated capabilities and string representations.
7
8use std::fmt;
9use std::str::FromStr;
10
/// The kinds of computer-vision work a YOLO model can perform.
///
/// Which variant applies decides what outputs the model is expected to
/// produce and which post-processing path handles them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Task {
    /// Object detection: class labels plus bounding boxes for each object.
    /// This is the default task.
    #[default]
    Detect,
    /// Instance segmentation: bounding boxes and class labels, plus a
    /// pixel-level mask for every object.
    Segment,
    /// Pose estimation: bounding boxes plus skeletal keypoints
    /// (for example, human joints).
    Pose,
    /// Image classification: class probabilities for the image as a whole,
    /// without any localization.
    Classify,
    /// Oriented bounding box (OBB) detection: rotated boxes, which suit
    /// aerial imagery and similar data.
    Obb,
}

impl Task {
    /// Canonical lowercase task name, as stored in ONNX model metadata
    /// (for example `"detect"` or `"segment"`).
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Task::Detect => "detect",
            Task::Segment => "segment",
            Task::Pose => "pose",
            Task::Classify => "classify",
            Task::Obb => "obb",
        }
    }

    /// Filename suffix appended after the model stem for this task, giving
    /// names of the form `yolo26n{suffix}.onnx`. Detection uses no suffix.
    ///
    /// ```
    /// use ultralytics_inference::Task;
    /// assert_eq!(Task::Detect.model_suffix(), "");
    /// assert_eq!(Task::Segment.model_suffix(), "-seg");
    /// ```
    #[must_use]
    pub const fn model_suffix(&self) -> &'static str {
        match *self {
            Task::Detect => "",
            Task::Segment => "-seg",
            Task::Pose => "-pose",
            Task::Classify => "-cls",
            Task::Obb => "-obb",
        }
    }

    /// Filename of the default nano YOLO26 model for this task.
    ///
    /// The CLI falls back to this when `--task` is given without `--model`.
    ///
    /// ```
    /// use ultralytics_inference::Task;
    /// assert_eq!(Task::Detect.default_model(), "yolo26n.onnx");
    /// assert_eq!(Task::Segment.default_model(), "yolo26n-seg.onnx");
    /// ```
    #[must_use]
    pub fn default_model(&self) -> String {
        ["yolo26n", self.model_suffix(), ".onnx"].concat()
    }

    /// Whether this task produces bounding boxes. This holds for every task
    /// except `Classify`.
    #[must_use]
    pub const fn has_boxes(&self) -> bool {
        matches!(*self, Task::Detect | Task::Segment | Task::Pose | Task::Obb)
    }

    /// Whether this task produces per-instance segmentation masks
    /// (`Segment` only).
    #[must_use]
    pub const fn has_masks(&self) -> bool {
        matches!(*self, Task::Segment)
    }

    /// Whether this task produces skeletal keypoints (`Pose` only).
    #[must_use]
    pub const fn has_keypoints(&self) -> bool {
        matches!(*self, Task::Pose)
    }

    /// Whether this task produces whole-image class probabilities
    /// (`Classify` only).
    #[must_use]
    pub const fn has_probs(&self) -> bool {
        matches!(*self, Task::Classify)
    }

    /// Whether this task produces oriented (rotated) bounding boxes
    /// (`Obb` only).
    #[must_use]
    pub const fn has_obb(&self) -> bool {
        matches!(*self, Task::Obb)
    }
}
112
113impl fmt::Display for Task {
114    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115        f.write_str(self.as_str())
116    }
117}
118
119impl FromStr for Task {
120    type Err = TaskParseError;
121
122    fn from_str(s: &str) -> Result<Self, Self::Err> {
123        match s.to_lowercase().as_str() {
124            "detect" | "detection" => Ok(Self::Detect),
125            "segment" | "segmentation" => Ok(Self::Segment),
126            "pose" | "keypoint" | "keypoints" => Ok(Self::Pose),
127            "classify" | "classification" | "cls" => Ok(Self::Classify),
128            "obb" | "oriented" => Ok(Self::Obb),
129            _ => Err(TaskParseError(s.to_string())),
130        }
131    }
132}
133
/// Error returned when parsing an invalid task string.
///
/// Holds the rejected input exactly as it was passed to the parser.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskParseError(String);

impl TaskParseError {
    /// Returns the string that failed to parse.
    ///
    /// Without this accessor, the private field would be reachable only by
    /// picking apart the `Display` message.
    #[must_use]
    pub fn input(&self) -> &str {
        &self.0
    }
}

impl fmt::Display for TaskParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "invalid task '{}', expected one of: detect, segment, pose, classify, obb",
            self.0
        )
    }
}

impl std::error::Error for TaskParseError {}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant in declaration order, so exhaustive checks share one list.
    const ALL: [Task; 5] = [Task::Detect, Task::Segment, Task::Pose, Task::Classify, Task::Obb];

    #[test]
    fn test_task_from_str() {
        assert_eq!("detect".parse::<Task>().unwrap(), Task::Detect);
        assert_eq!("segment".parse::<Task>().unwrap(), Task::Segment);
        assert_eq!("pose".parse::<Task>().unwrap(), Task::Pose);
        assert_eq!("classify".parse::<Task>().unwrap(), Task::Classify);
        assert_eq!("obb".parse::<Task>().unwrap(), Task::Obb);

        // Alternative names
        assert_eq!("detection".parse::<Task>().unwrap(), Task::Detect);
        assert_eq!("segmentation".parse::<Task>().unwrap(), Task::Segment);
        assert_eq!("keypoint".parse::<Task>().unwrap(), Task::Pose);
        assert_eq!("keypoints".parse::<Task>().unwrap(), Task::Pose);
        assert_eq!("classification".parse::<Task>().unwrap(), Task::Classify);
        assert_eq!("cls".parse::<Task>().unwrap(), Task::Classify);
        assert_eq!("oriented".parse::<Task>().unwrap(), Task::Obb);
    }

    #[test]
    fn test_task_from_str_is_case_insensitive() {
        assert_eq!("DETECT".parse::<Task>().unwrap(), Task::Detect);
        assert_eq!("Segment".parse::<Task>().unwrap(), Task::Segment);
        assert_eq!("KeyPoints".parse::<Task>().unwrap(), Task::Pose);
        assert_eq!("OBB".parse::<Task>().unwrap(), Task::Obb);
    }

    #[test]
    fn test_task_from_str_rejects_unknown() {
        // The error reports the input exactly as given, not lowercased.
        let err = "Tracking".parse::<Task>().unwrap_err();
        assert_eq!(
            err.to_string(),
            "invalid task 'Tracking', expected one of: detect, segment, pose, classify, obb"
        );
        assert!("".parse::<Task>().is_err());
        assert!("detects".parse::<Task>().is_err());
    }

    #[test]
    fn test_task_display() {
        assert_eq!(Task::Detect.to_string(), "detect");
        assert_eq!(Task::Segment.to_string(), "segment");

        // Display and FromStr must round-trip for every variant.
        for task in ALL {
            assert_eq!(task.to_string(), task.as_str());
            assert_eq!(task.to_string().parse::<Task>().unwrap(), task);
        }
    }

    #[test]
    fn test_task_default_and_models() {
        assert_eq!(Task::default(), Task::Detect);
        assert_eq!(Task::Detect.default_model(), "yolo26n.onnx");
        assert_eq!(Task::Segment.default_model(), "yolo26n-seg.onnx");
        assert_eq!(Task::Pose.default_model(), "yolo26n-pose.onnx");
        assert_eq!(Task::Classify.default_model(), "yolo26n-cls.onnx");
        assert_eq!(Task::Obb.default_model(), "yolo26n-obb.onnx");
    }

    #[test]
    fn test_task_capabilities() {
        assert!(Task::Detect.has_boxes());
        assert!(!Task::Detect.has_masks());
        assert!(Task::Segment.has_masks());
        assert!(Task::Pose.has_keypoints());
        assert!(Task::Classify.has_probs());
        assert!(Task::Obb.has_obb());

        // Full truth table: (boxes, masks, keypoints, probs, obb) per task.
        let expected = [
            (Task::Detect, (true, false, false, false, false)),
            (Task::Segment, (true, true, false, false, false)),
            (Task::Pose, (true, false, true, false, false)),
            (Task::Classify, (false, false, false, true, false)),
            (Task::Obb, (true, false, false, false, true)),
        ];
        for (task, (boxes, masks, keypoints, probs, obb)) in expected {
            assert_eq!(task.has_boxes(), boxes, "has_boxes for {task}");
            assert_eq!(task.has_masks(), masks, "has_masks for {task}");
            assert_eq!(task.has_keypoints(), keypoints, "has_keypoints for {task}");
            assert_eq!(task.has_probs(), probs, "has_probs for {task}");
            assert_eq!(task.has_obb(), obb, "has_obb for {task}");
        }
    }
}