ultralytics_inference/
task.rs1use std::fmt;
9use std::str::FromStr;
10
/// A computer-vision task supported by the inference pipeline.
///
/// The capability helpers ([`Task::has_boxes`], [`Task::has_masks`], …) report
/// which kinds of output each task carries, and [`Task::default_model`] names
/// the model file used when none is specified. [`Task::Detect`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Task {
    /// Object detection (bounding boxes). This is the default task.
    #[default]
    Detect,
    /// Segmentation (bounding boxes plus masks).
    Segment,
    /// Pose estimation (bounding boxes plus keypoints).
    Pose,
    /// Classification (class probabilities, no boxes).
    Classify,
    /// Oriented bounding-box detection.
    Obb,
}

impl Task {
    /// Returns the canonical lowercase name of the task (e.g. `"segment"`).
    ///
    /// This is the text produced by the `Display` impl and one of the names
    /// accepted when parsing a `Task` from a string.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Detect => "detect",
            Self::Segment => "segment",
            Self::Pose => "pose",
            Self::Classify => "classify",
            Self::Obb => "obb",
        }
    }

    /// Returns the suffix appended to a model's base name for this task.
    ///
    /// Detection models carry no suffix (`""`); the other tasks use `-seg`,
    /// `-pose`, `-cls` and `-obb` respectively.
    #[must_use]
    pub const fn model_suffix(&self) -> &'static str {
        match self {
            Self::Detect => "",
            Self::Segment => "-seg",
            Self::Pose => "-pose",
            Self::Classify => "-cls",
            Self::Obb => "-obb",
        }
    }

    /// Builds the ONNX model file name for this task from a model base name.
    ///
    /// The result is `base`, followed by [`Task::model_suffix`], followed by
    /// `.onnx`. For example, `Task::Pose.model_name("yolo11s")` returns
    /// `"yolo11s-pose.onnx"`.
    #[must_use]
    pub fn model_name(&self, base: &str) -> String {
        format!("{base}{}.onnx", self.model_suffix())
    }

    /// Returns the default model file name for this task.
    ///
    /// This is [`Task::model_name`] with the base `yolo26n`, e.g.
    /// `yolo26n.onnx` for detection or `yolo26n-seg.onnx` for segmentation.
    #[must_use]
    pub fn default_model(&self) -> String {
        self.model_name("yolo26n")
    }

    /// Returns `true` for tasks whose results include bounding boxes
    /// (every task except [`Task::Classify`]).
    #[must_use]
    pub const fn has_boxes(&self) -> bool {
        matches!(self, Self::Detect | Self::Segment | Self::Pose | Self::Obb)
    }

    /// Returns `true` only for [`Task::Segment`], whose results include masks.
    #[must_use]
    pub const fn has_masks(&self) -> bool {
        matches!(self, Self::Segment)
    }

    /// Returns `true` only for [`Task::Pose`], whose results include keypoints.
    #[must_use]
    pub const fn has_keypoints(&self) -> bool {
        matches!(self, Self::Pose)
    }

    /// Returns `true` only for [`Task::Classify`], whose results include
    /// class probabilities.
    #[must_use]
    pub const fn has_probs(&self) -> bool {
        matches!(self, Self::Classify)
    }

    /// Returns `true` only for [`Task::Obb`], whose results include oriented
    /// bounding boxes.
    #[must_use]
    pub const fn has_obb(&self) -> bool {
        matches!(self, Self::Obb)
    }
}

impl fmt::Display for Task {
    /// Writes the canonical task name (see [`Task::as_str`]).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` (unlike `write_str`) honors width/fill/alignment flags, so
        // `format!("{:>8}", Task::Pose)` yields `"    pose"` as users expect.
        f.pad(self.as_str())
    }
}
118
119impl FromStr for Task {
120 type Err = TaskParseError;
121
122 fn from_str(s: &str) -> Result<Self, Self::Err> {
123 match s.to_lowercase().as_str() {
124 "detect" | "detection" => Ok(Self::Detect),
125 "segment" | "segmentation" => Ok(Self::Segment),
126 "pose" | "keypoint" | "keypoints" => Ok(Self::Pose),
127 "classify" | "classification" | "cls" => Ok(Self::Classify),
128 "obb" | "oriented" => Ok(Self::Obb),
129 _ => Err(TaskParseError(s.to_string())),
130 }
131 }
132}
133
/// Error returned when a string cannot be parsed into a [`Task`].
///
/// The rejected input is kept so callers can inspect or report it via
/// [`TaskParseError::input`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskParseError(String);

impl TaskParseError {
    /// Returns the input string that failed to parse.
    #[must_use]
    pub fn input(&self) -> &str {
        &self.0
    }
}

impl fmt::Display for TaskParseError {
    /// Describes the rejected input and lists the canonical task names.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "invalid task '{}', expected one of: detect, segment, pose, classify, obb",
            self.0
        )
    }
}

impl std::error::Error for TaskParseError {}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, used for exhaustive round-trip checks.
    const ALL: [Task; 5] = [
        Task::Detect,
        Task::Segment,
        Task::Pose,
        Task::Classify,
        Task::Obb,
    ];

    #[test]
    fn test_task_from_str() {
        assert_eq!("detect".parse::<Task>().unwrap(), Task::Detect);
        assert_eq!("segment".parse::<Task>().unwrap(), Task::Segment);
        assert_eq!("pose".parse::<Task>().unwrap(), Task::Pose);
        assert_eq!("classify".parse::<Task>().unwrap(), Task::Classify);
        assert_eq!("obb".parse::<Task>().unwrap(), Task::Obb);

        // Aliases.
        assert_eq!("detection".parse::<Task>().unwrap(), Task::Detect);
        assert_eq!("segmentation".parse::<Task>().unwrap(), Task::Segment);
        assert_eq!("keypoint".parse::<Task>().unwrap(), Task::Pose);
        assert_eq!("keypoints".parse::<Task>().unwrap(), Task::Pose);
        assert_eq!("classification".parse::<Task>().unwrap(), Task::Classify);
        assert_eq!("cls".parse::<Task>().unwrap(), Task::Classify);
        assert_eq!("oriented".parse::<Task>().unwrap(), Task::Obb);
    }

    #[test]
    fn test_task_from_str_is_case_insensitive() {
        assert_eq!("DETECT".parse::<Task>().unwrap(), Task::Detect);
        assert_eq!("Segment".parse::<Task>().unwrap(), Task::Segment);
        assert_eq!("OBB".parse::<Task>().unwrap(), Task::Obb);
    }

    #[test]
    fn test_task_from_str_rejects_unknown() {
        let err = "bogus".parse::<Task>().unwrap_err();
        assert!(err.to_string().contains("'bogus'"));
        assert!("".parse::<Task>().is_err());
        assert!("detector".parse::<Task>().is_err());
    }

    #[test]
    fn test_task_display() {
        assert_eq!(Task::Detect.to_string(), "detect");
        assert_eq!(Task::Segment.to_string(), "segment");
        assert_eq!(Task::Pose.to_string(), "pose");
        assert_eq!(Task::Classify.to_string(), "classify");
        assert_eq!(Task::Obb.to_string(), "obb");
    }

    #[test]
    fn test_task_round_trip() {
        for task in ALL {
            assert_eq!(task.to_string().parse::<Task>().unwrap(), task);
            assert_eq!(task.as_str().parse::<Task>().unwrap(), task);
        }
    }

    #[test]
    fn test_default_task_and_models() {
        assert_eq!(Task::default(), Task::Detect);
        assert_eq!(Task::Detect.default_model(), "yolo26n.onnx");
        assert_eq!(Task::Segment.default_model(), "yolo26n-seg.onnx");
        assert_eq!(Task::Pose.default_model(), "yolo26n-pose.onnx");
        assert_eq!(Task::Classify.default_model(), "yolo26n-cls.onnx");
        assert_eq!(Task::Obb.default_model(), "yolo26n-obb.onnx");
    }

    #[test]
    fn test_task_capabilities() {
        assert!(Task::Detect.has_boxes());
        assert!(!Task::Detect.has_masks());
        assert!(Task::Segment.has_masks());
        assert!(Task::Pose.has_keypoints());
        assert!(Task::Classify.has_probs());
        assert!(Task::Obb.has_obb());

        // Negative cases: each capability is exclusive to its task(s).
        assert!(!Task::Classify.has_boxes());
        assert!(!Task::Segment.has_keypoints());
        assert!(!Task::Detect.has_probs());
        assert!(!Task::Detect.has_obb());
    }
}