ultralytics_inference/
metadata.rs1use std::collections::HashMap;
9
10use crate::error::{InferenceError, Result};
11use crate::task::Task;
12
13#[derive(Debug, Clone)]
18pub struct ModelMetadata {
19 pub description: String,
21 pub author: String,
23 pub date: String,
25 pub version: String,
27 pub license: String,
29 pub docs: String,
31 pub task: Task,
33 pub stride: u32,
35 pub batch: usize,
37 pub imgsz: Option<(usize, usize)>,
39 pub channels: usize,
41 pub half: bool,
43 pub names: HashMap<usize, String>,
45 pub end2end: bool,
48 pub kpt_shape: Option<(usize, usize)>,
50}
51
52impl ModelMetadata {
53 pub fn from_onnx_metadata(metadata_map: &HashMap<String, String>) -> Result<Self> {
67 let yaml_str = metadata_map
70 .get("metadata")
71 .or_else(|| metadata_map.get("model_metadata"))
72 .or_else(|| {
73 metadata_map.values().find(|v| v.contains("task:"))
75 })
76 .ok_or_else(|| {
77 InferenceError::ModelLoadError(
78 "No metadata found in ONNX model. Ensure the model was exported with Ultralytics.".to_string()
79 )
80 })?;
81
82 Self::from_yaml_str(yaml_str)
83 }
84
85 pub fn from_yaml_str(yaml_str: &str) -> Result<Self> {
99 let mut metadata = Self::default();
100
101 for line in yaml_str.lines() {
102 let line = line.trim();
103 if line.is_empty() || line.starts_with('#') {
104 continue;
105 }
106
107 if let Some((key, value)) = line.split_once(':') {
109 let key = key.trim();
110 let value = value.trim().trim_matches('\'').trim_matches('"');
111
112 match key {
113 "description" => metadata.description = value.to_string(),
114 "author" => metadata.author = value.to_string(),
115 "date" => metadata.date = value.to_string(),
116 "version" => metadata.version = value.to_string(),
117 "license" => metadata.license = value.to_string(),
118 "docs" => metadata.docs = value.to_string(),
119 "task" => {
120 metadata.task = value.parse().map_err(|e| {
121 InferenceError::ModelLoadError(format!("Invalid task in metadata: {e}"))
122 })?;
123 }
124 "stride" => {
125 metadata.stride = value.parse().map_err(|_| {
126 InferenceError::ModelLoadError(format!("Invalid stride value: {value}"))
127 })?;
128 }
129 "batch" => {
130 metadata.batch = value.parse().map_err(|_| {
131 InferenceError::ModelLoadError(format!("Invalid batch value: {value}"))
132 })?;
133 }
134 "channels" => {
135 metadata.channels = value.parse().map_err(|_| {
136 InferenceError::ModelLoadError(format!(
137 "Invalid channels value: {value}"
138 ))
139 })?;
140 }
141 "half" => {
142 metadata.half = value == "true" || value == "True";
143 }
144 "end2end" => {
145 metadata.end2end = value == "true" || value == "True";
146 }
147 "kpt_shape" => {
148 metadata.kpt_shape = Self::parse_kpt_shape(value);
149 }
150 "args" => {
151 if value.contains("'half': True")
153 || value.contains("\"half\": true")
154 || value.contains("'half':True")
155 {
156 metadata.half = true;
157 }
158 }
159 _ => {
160 if let Ok(class_id) = key.trim().parse::<usize>() {
162 metadata.names.insert(class_id, value.to_string());
163 }
164 }
165 }
166 }
167 }
168
169 if let Some(imgsz_line) = yaml_str.lines().find(|l| l.contains("imgsz:")) {
171 metadata.imgsz = Self::parse_imgsz(yaml_str, imgsz_line);
172 }
173
174 if metadata.names.is_empty() {
176 metadata.names = Self::parse_names_block(yaml_str);
177 }
178
179 Ok(metadata)
180 }
181
/// Parses an inline keypoint shape such as `[17, 3]` or `(17, 3)`.
///
/// Surrounding brackets/parentheses and whitespace are ignored. The first two
/// comma-separated items must both be unsigned integers; additional items are
/// ignored.
///
/// Returns `None` when fewer than two items are present or when either of the
/// first two items is not a number. Unparsable items are never skipped over,
/// so a malformed list cannot yield a shifted (misaligned) pair.
fn parse_kpt_shape(value: &str) -> Option<(usize, usize)> {
    let inner = value
        .trim()
        .trim_matches(|c: char| matches!(c, '[' | ']' | '(' | ')'));

    let mut dims = inner.split(',').map(|part| part.trim().parse::<usize>());

    match (dims.next(), dims.next()) {
        (Some(Ok(first)), Some(Ok(second))) => Some((first, second)),
        _ => None,
    }
}
197
/// Extracts the image size from the metadata text.
///
/// Three spellings are understood, tried in this order:
///
/// 1. inline list on `imgsz_line`, e.g. `imgsz: [640, 640]`;
/// 2. a single scalar on `imgsz_line`, e.g. `imgsz: 640`, returned as a
///    square `(640, 640)`;
/// 3. a block sequence of `- <number>` items on the lines following the first
///    line of `yaml_str` that contains `imgsz:`.
///
/// `imgsz_line` is the line of `yaml_str` that contains `imgsz:`.
///
/// Returns the first two values in the order they appear, or `None` when no
/// spelling yields a usable size. Malformed input (for example a `]` that
/// precedes the `[`) results in `None` rather than a panic.
fn parse_imgsz(yaml_str: &str, imgsz_line: &str) -> Option<(usize, usize)> {
    // 1. Inline list. The closing bracket is searched for *after* the opening
    //    one, so mis-ordered brackets cannot produce an invalid slice range.
    if let Some((_, after_open)) = imgsz_line.split_once('[') {
        if let Some((inner, _)) = after_open.split_once(']') {
            let mut dims = inner.split(',').map(|part| part.trim().parse::<usize>());
            if let (Some(Ok(first)), Some(Ok(second))) = (dims.next(), dims.next()) {
                return Some((first, second));
            }
        }
    }

    // 2. Scalar value: a single number denotes a square image.
    if let Some((_, rest)) = imgsz_line.split_once("imgsz:") {
        let scalar = rest.trim().trim_matches('\'').trim_matches('"');
        if let Ok(side) = scalar.parse::<usize>() {
            return Some((side, side));
        }
    }

    // 3. Block sequence. Blank and comment lines inside the sequence are
    //    skipped; any other non-item line ends it.
    let mut items = yaml_str
        .lines()
        .skip_while(|line| !line.contains("imgsz:"))
        .skip(1)
        .map(str::trim)
        .take_while(|line| line.is_empty() || line.starts_with('#') || line.starts_with('-'))
        .filter(|line| line.starts_with('-'))
        .filter_map(|line| line.trim_start_matches('-').trim().parse::<usize>().ok());

    match (items.next(), items.next()) {
        (Some(first), Some(second)) => Some((first, second)),
        _ => None,
    }
}
243
244 fn parse_names_block(yaml_str: &str) -> HashMap<usize, String> {
246 let mut names = HashMap::new();
247
248 if let Some(start) = yaml_str.find("names:") {
251 let after_names = &yaml_str[start + 6..];
252 let trimmed = after_names.trim();
253
254 if trimmed.starts_with('{')
256 && let Some(end) = trimmed.find('}')
257 {
258 let dict_str = &trimmed[1..end];
259 return Self::parse_python_dict(dict_str);
260 }
261 }
262
263 let lines: Vec<&str> = yaml_str.lines().collect();
265 let mut in_names_block = false;
266 let mut names_indent = 0;
267
268 for line in &lines {
269 let trimmed = line.trim();
270
271 if trimmed.starts_with("names:") {
272 in_names_block = true;
273 names_indent = line.len() - line.trim_start().len();
274 continue;
275 }
276
277 if in_names_block {
278 let current_indent = line.len() - line.trim_start().len();
279
280 if !trimmed.is_empty()
282 && !trimmed.starts_with('#')
283 && current_indent <= names_indent
284 {
285 if !trimmed.chars().next().is_some_and(|c| c.is_ascii_digit()) {
287 break;
288 }
289 }
290
291 if let Some((key, value)) = trimmed.split_once(':')
293 && let Ok(class_id) = key.trim().parse::<usize>()
294 {
295 let class_name = value.trim().trim_matches('\'').trim_matches('"');
296 names.insert(class_id, class_name.to_string());
297 }
298 }
299 }
300
301 names
302 }
303
/// Parses the inside of a dictionary literal such as
/// `0: 'person', 1: 'traffic light'` into a class-id to class-name map.
///
/// `dict_str` is the text between the braces (braces excluded). Entries are
/// separated by commas that are *outside* quotes, so a quoted name may itself
/// contain commas (`2: 'tench, Tinca tinca'`). A quote only opens a quoted
/// region when it is the first non-blank character of a value, so apostrophes
/// inside unquoted names are left alone. Backslash-escaped characters inside
/// a quoted value never close the quote.
///
/// Entries without a `:` or whose key is not an unsigned integer are ignored.
/// One layer of surrounding single/double quotes is stripped from each name.
fn parse_python_dict(dict_str: &str) -> HashMap<usize, String> {
    /// Stores one `id: name` entry, ignoring entries without a numeric id.
    fn insert_entry(names: &mut HashMap<usize, String>, entry: &str) {
        if let Some((key, value)) = entry.trim().split_once(':') {
            if let Ok(class_id) = key.trim().parse::<usize>() {
                let name = value.trim().trim_matches('\'').trim_matches('"');
                names.insert(class_id, name.to_string());
            }
        }
    }

    let mut names = HashMap::new();
    let mut entry_start = 0;
    // Quote character of the quoted value currently being scanned, if any.
    let mut open_quote: Option<char> = None;
    // True directly after a backslash inside a quoted value.
    let mut escaped = false;
    // True once the key/value colon of the current entry has been seen.
    let mut seen_colon = false;
    // True between that colon and the first non-blank character of the value.
    let mut expect_value = false;

    for (idx, ch) in dict_str.char_indices() {
        if let Some(quote) = open_quote {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == quote {
                open_quote = None;
            }
            continue;
        }

        match ch {
            ',' => {
                insert_entry(&mut names, &dict_str[entry_start..idx]);
                entry_start = idx + 1; // ',' is one byte wide
                seen_colon = false;
                expect_value = false;
            }
            ':' if !seen_colon => {
                seen_colon = true;
                expect_value = true;
            }
            '\'' | '"' if expect_value => {
                open_quote = Some(ch);
                expect_value = false;
            }
            _ if ch.is_whitespace() => {}
            _ => expect_value = false,
        }
    }

    // The last entry has no trailing comma.
    insert_entry(&mut names, &dict_str[entry_start..]);

    names
}
322
323 #[must_use]
329 pub fn num_classes(&self) -> usize {
330 self.names.len()
331 }
332
333 #[must_use]
343 pub fn class_name(&self, class_id: usize) -> Option<&str> {
344 self.names.get(&class_id).map(String::as_str)
345 }
346
347 #[must_use]
352 pub fn model_name(&self) -> String {
353 self.description
355 .split_whitespace()
356 .find(|&word| word.to_lowercase().starts_with("yolo"))
357 .unwrap_or("YOLO")
358 .to_string()
359 }
360}
361
362impl Default for ModelMetadata {
363 fn default() -> Self {
364 Self {
365 description: String::new(),
366 author: "Ultralytics".to_string(),
367 date: String::new(),
368 version: String::new(),
369 license: "AGPL-3.0".to_string(),
370 docs: "https://docs.ultralytics.com".to_string(),
371 task: Task::Detect,
372 stride: 32,
373 batch: 1,
374 imgsz: None,
375 channels: 3,
376 half: false,
377 names: HashMap::new(),
378 end2end: false,
379 kpt_shape: None,
380 }
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Block-style metadata: `imgsz` as a dash sequence and `names` as an
    /// indented mapping.
    const SAMPLE_METADATA: &str = r"
description: Ultralytics YOLO11n model trained on /usr/src/ultralytics/ultralytics/cfg/datasets/coco.yaml
author: Ultralytics
date: '2025-12-11T20:19:45.464021'
version: 8.3.236
license: AGPL-3.0 License (https://ultralytics.com/license)
docs: https://docs.ultralytics.com
stride: 32
task: detect
batch: 1
imgsz:
- 640
- 640
names:
  0: person
  1: bicycle
  2: car
  3: motorcycle
channels: 3
";

    /// Scalar fields, the block-style image size and every class name are
    /// recovered from the sample metadata.
    #[test]
    fn test_parse_metadata() {
        let parsed = ModelMetadata::from_yaml_str(SAMPLE_METADATA).unwrap();

        assert_eq!(parsed.task, Task::Detect);
        assert_eq!(parsed.stride, 32);
        assert_eq!(parsed.batch, 1);
        assert_eq!(parsed.imgsz, Some((640, 640)));
        assert_eq!(parsed.channels, 3);

        let expected_names = ["person", "bicycle", "car", "motorcycle"];
        assert_eq!(parsed.num_classes(), expected_names.len());
        for (class_id, expected) in expected_names.iter().enumerate() {
            assert_eq!(parsed.class_name(class_id), Some(*expected));
        }
    }

    /// The inline list spelling of `imgsz` is understood as well.
    #[test]
    fn test_parse_inline_imgsz() {
        let yaml = "task: detect\nimgsz: [640, 640]\nstride: 32";
        let parsed = ModelMetadata::from_yaml_str(yaml).unwrap();

        assert_eq!(parsed.imgsz, Some((640, 640)));
    }

    /// Defaults describe a detection model with stride 32 and no image size.
    #[test]
    fn test_default_metadata() {
        let defaults = ModelMetadata::default();

        assert_eq!(defaults.task, Task::Detect);
        assert_eq!(defaults.stride, 32);
        assert_eq!(defaults.imgsz, None);
    }
}