Skip to main content

ultralytics_inference/
metadata.rs

1// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
3//! ONNX model metadata parsing.
4//!
5//! This module handles parsing metadata from Ultralytics YOLO ONNX models.
6//! The metadata is stored as YAML in the ONNX model's custom metadata properties.
7
8use std::collections::HashMap;
9
10use crate::error::{InferenceError, Result};
11use crate::task::Task;
12
/// Metadata extracted from an Ultralytics YOLO ONNX model.
///
/// This struct contains all the configuration information embedded in the model,
/// including class names, input dimensions, and task type. Parsing starts from
/// `ModelMetadata::default()`, so any field that is absent from the metadata keeps
/// its default value.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// Model description (e.g., "Ultralytics `YOLO11n` model trained on coco.yaml").
    pub description: String,
    /// Model author.
    pub author: String,
    /// Export date, stored verbatim as it appears in the metadata.
    pub date: String,
    /// Ultralytics version used for export.
    pub version: String,
    /// License information.
    pub license: String,
    /// Documentation URL.
    pub docs: String,
    /// The task this model performs.
    pub task: Task,
    /// Model stride (typically 32 for YOLO).
    pub stride: u32,
    /// Batch size the model was exported with.
    pub batch: usize,
    /// Input image size as (height, width); `None` when the metadata carries no
    /// parsable `imgsz` entry.
    pub imgsz: Option<(usize, usize)>,
    /// Number of input channels (typically 3 for RGB).
    pub channels: usize,
    /// Whether the model uses FP16 (half precision).
    pub half: bool,
    /// Class ID to class name mapping.
    pub names: HashMap<usize, String>,
    /// Whether the model was exported with end-to-end NMS-free output
    /// (YOLO26-style post-NMS output: `[B, max_det, 6+extra]`).
    pub end2end: bool,
    /// Pose keypoint shape as (`num_keypoints`, `dims`), e.g. (17, 3); `None` when
    /// the metadata has no parsable `kpt_shape` entry.
    pub kpt_shape: Option<(usize, usize)>,
}
51
52impl ModelMetadata {
53    /// Parse metadata from ONNX model custom metadata properties.
54    ///
55    /// # Arguments
56    ///
57    /// * `metadata_map` - The custom metadata from the ONNX model session.
58    ///
59    /// # Returns
60    ///
61    /// * A new `ModelMetadata` instance.
62    ///
63    /// # Errors
64    ///
65    /// Returns an error if the metadata is missing or malformed.
66    pub fn from_onnx_metadata(metadata_map: &HashMap<String, String>) -> Result<Self> {
67        // The metadata is typically stored under a single key containing YAML
68        // Try common key names used by Ultralytics
69        let yaml_str = metadata_map
70            .get("metadata")
71            .or_else(|| metadata_map.get("model_metadata"))
72            .or_else(|| {
73                // If no standard key, check if all metadata is in one value
74                metadata_map.values().find(|v| v.contains("task:"))
75            })
76            .ok_or_else(|| {
77                InferenceError::ModelLoadError(
78                    "No metadata found in ONNX model. Ensure the model was exported with Ultralytics.".to_string()
79                )
80            })?;
81
82        Self::from_yaml_str(yaml_str)
83    }
84
85    /// Parse metadata from a YAML string.
86    ///
87    /// # Arguments
88    ///
89    /// * `yaml_str` - The YAML-formatted metadata string.
90    ///
91    /// # Returns
92    ///
93    /// * A new `ModelMetadata` instance.
94    ///
95    /// # Errors
96    ///
97    /// Returns an error if the YAML is malformed or missing required fields.
98    pub fn from_yaml_str(yaml_str: &str) -> Result<Self> {
99        let mut metadata = Self::default();
100
101        for line in yaml_str.lines() {
102            let line = line.trim();
103            if line.is_empty() || line.starts_with('#') {
104                continue;
105            }
106
107            // Handle key: value pairs
108            if let Some((key, value)) = line.split_once(':') {
109                let key = key.trim();
110                let value = value.trim().trim_matches('\'').trim_matches('"');
111
112                match key {
113                    "description" => metadata.description = value.to_string(),
114                    "author" => metadata.author = value.to_string(),
115                    "date" => metadata.date = value.to_string(),
116                    "version" => metadata.version = value.to_string(),
117                    "license" => metadata.license = value.to_string(),
118                    "docs" => metadata.docs = value.to_string(),
119                    "task" => {
120                        metadata.task = value.parse().map_err(|e| {
121                            InferenceError::ModelLoadError(format!("Invalid task in metadata: {e}"))
122                        })?;
123                    }
124                    "stride" => {
125                        metadata.stride = value.parse().map_err(|_| {
126                            InferenceError::ModelLoadError(format!("Invalid stride value: {value}"))
127                        })?;
128                    }
129                    "batch" => {
130                        metadata.batch = value.parse().map_err(|_| {
131                            InferenceError::ModelLoadError(format!("Invalid batch value: {value}"))
132                        })?;
133                    }
134                    "channels" => {
135                        metadata.channels = value.parse().map_err(|_| {
136                            InferenceError::ModelLoadError(format!(
137                                "Invalid channels value: {value}"
138                            ))
139                        })?;
140                    }
141                    "half" => {
142                        metadata.half = value == "true" || value == "True";
143                    }
144                    "end2end" => {
145                        metadata.end2end = value == "true" || value == "True";
146                    }
147                    "kpt_shape" => {
148                        metadata.kpt_shape = Self::parse_kpt_shape(value);
149                    }
150                    "args" => {
151                        // Parse args dict for half flag: {'half': True, ...}
152                        if value.contains("'half': True")
153                            || value.contains("\"half\": true")
154                            || value.contains("'half':True")
155                        {
156                            metadata.half = true;
157                        }
158                    }
159                    _ => {
160                        // Check for class name entries (numeric keys)
161                        if let Ok(class_id) = key.trim().parse::<usize>() {
162                            metadata.names.insert(class_id, value.to_string());
163                        }
164                    }
165                }
166            }
167        }
168
169        // Parse imgsz which can be a list like [640, 640]
170        if let Some(imgsz_line) = yaml_str.lines().find(|l| l.contains("imgsz:")) {
171            metadata.imgsz = Self::parse_imgsz(yaml_str, imgsz_line);
172        }
173
174        // Parse names block if not already parsed inline
175        if metadata.names.is_empty() {
176            metadata.names = Self::parse_names_block(yaml_str);
177        }
178
179        Ok(metadata)
180    }
181
182    /// Parse a `kpt_shape` value like "[17, 3]" or "(17, 3)" into a tuple.
183    fn parse_kpt_shape(value: &str) -> Option<(usize, usize)> {
184        let inner = value
185            .trim()
186            .trim_matches(|c| matches!(c, '[' | ']' | '(' | ')'));
187        let parts: Vec<usize> = inner
188            .split(',')
189            .filter_map(|s| s.trim().parse().ok())
190            .collect();
191        if parts.len() >= 2 {
192            Some((parts[0], parts[1]))
193        } else {
194            None
195        }
196    }
197
198    /// Parse the imgsz field which can be a YAML list.
199    fn parse_imgsz(yaml_str: &str, imgsz_line: &str) -> Option<(usize, usize)> {
200        // Check if imgsz is on a single line like "imgsz: [640, 640]"
201        if let Some(bracket_start) = imgsz_line.find('[')
202            && let Some(bracket_end) = imgsz_line.find(']')
203        {
204            let values: Vec<usize> = imgsz_line[bracket_start + 1..bracket_end]
205                .split(',')
206                .filter_map(|s| s.trim().parse().ok())
207                .collect();
208            if values.len() >= 2 {
209                return Some((values[0], values[1]));
210            }
211        }
212
213        // Check for multi-line YAML list format
214        let lines: Vec<&str> = yaml_str.lines().collect();
215        let mut imgsz_values = Vec::new();
216
217        for (i, line) in lines.iter().enumerate() {
218            if line.contains("imgsz:") {
219                // Look at following lines for list items
220                for following in lines.iter().skip(i + 1) {
221                    let trimmed = following.trim();
222                    if trimmed.starts_with('-') {
223                        if let Ok(val) = trimmed.trim_start_matches('-').trim().parse::<usize>() {
224                            imgsz_values.push(val);
225                        }
226                    } else if !trimmed.is_empty() && !trimmed.starts_with('#') {
227                        break;
228                    }
229                    if imgsz_values.len() >= 2 {
230                        break;
231                    }
232                }
233                break;
234            }
235        }
236
237        if imgsz_values.len() >= 2 {
238            Some((imgsz_values[0], imgsz_values[1]))
239        } else {
240            None
241        }
242    }
243
244    /// Parse the names block from YAML or Python dict format.
245    fn parse_names_block(yaml_str: &str) -> HashMap<usize, String> {
246        let mut names = HashMap::new();
247
248        // First, try to find `names: {0: 'person', 1: 'bicycle', ...}` Python dict format
249        // This is how Ultralytics stores names in ONNX metadata
250        if let Some(start) = yaml_str.find("names:") {
251            let after_names = &yaml_str[start + 6..];
252            let trimmed = after_names.trim();
253
254            // Check if it's Python dict format (starts with {)
255            if trimmed.starts_with('{')
256                && let Some(end) = trimmed.find('}')
257            {
258                let dict_str = &trimmed[1..end];
259                return Self::parse_python_dict(dict_str);
260            }
261        }
262
263        // Fall back to YAML block format
264        let lines: Vec<&str> = yaml_str.lines().collect();
265        let mut in_names_block = false;
266        let mut names_indent = 0;
267
268        for line in &lines {
269            let trimmed = line.trim();
270
271            if trimmed.starts_with("names:") {
272                in_names_block = true;
273                names_indent = line.len() - line.trim_start().len();
274                continue;
275            }
276
277            if in_names_block {
278                let current_indent = line.len() - line.trim_start().len();
279
280                // Check if we've exited the names block
281                if !trimmed.is_empty()
282                    && !trimmed.starts_with('#')
283                    && current_indent <= names_indent
284                {
285                    // Only exit if this isn't a class entry
286                    if !trimmed.chars().next().is_some_and(|c| c.is_ascii_digit()) {
287                        break;
288                    }
289                }
290
291                // Parse class entries like "0: person" or "  0: person"
292                if let Some((key, value)) = trimmed.split_once(':')
293                    && let Ok(class_id) = key.trim().parse::<usize>()
294                {
295                    let class_name = value.trim().trim_matches('\'').trim_matches('"');
296                    names.insert(class_id, class_name.to_string());
297                }
298            }
299        }
300
301        names
302    }
303
304    /// Parse a Python dict string like `0: 'person', 1: 'bicycle'`.
305    fn parse_python_dict(dict_str: &str) -> HashMap<usize, String> {
306        let mut names = HashMap::new();
307
308        // Split by comma, but be careful with quotes
309        for entry in dict_str.split(',') {
310            let entry = entry.trim();
311            if let Some((key, value)) = entry.split_once(':') {
312                let key = key.trim();
313                let value = value.trim().trim_matches('\'').trim_matches('"');
314                if let Ok(class_id) = key.parse::<usize>() {
315                    names.insert(class_id, value.to_string());
316                }
317            }
318        }
319
320        names
321    }
322
323    /// Get the number of classes in this model.
324    ///
325    /// # Returns
326    ///
327    /// * The count of classes.
328    #[must_use]
329    pub fn num_classes(&self) -> usize {
330        self.names.len()
331    }
332
333    /// Get a class name by ID.
334    ///
335    /// # Arguments
336    ///
337    /// * `class_id` - The numeric identifier for the class.
338    ///
339    /// # Returns
340    ///
341    /// * `Some` class name if found, otherwise `None`.
342    #[must_use]
343    pub fn class_name(&self, class_id: usize) -> Option<&str> {
344        self.names.get(&class_id).map(String::as_str)
345    }
346
347    /// Extract the model name from the description.
348    ///
349    /// E.g. "Ultralytics `YOLO11n` model..." -> "`YOLO11n`"
350    /// Returns `YOLO` if extraction fails.
351    #[must_use]
352    pub fn model_name(&self) -> String {
353        // Description format: "Ultralytics <MODEL> model..."
354        self.description
355            .split_whitespace()
356            .find(|&word| word.to_lowercase().starts_with("yolo"))
357            .unwrap_or("YOLO")
358            .to_string()
359    }
360}
361
362impl Default for ModelMetadata {
363    fn default() -> Self {
364        Self {
365            description: String::new(),
366            author: "Ultralytics".to_string(),
367            date: String::new(),
368            version: String::new(),
369            license: "AGPL-3.0".to_string(),
370            docs: "https://docs.ultralytics.com".to_string(),
371            task: Task::Detect,
372            stride: 32,
373            batch: 1,
374            imgsz: None,
375            channels: 3,
376            half: false,
377            names: HashMap::new(),
378            end2end: false,
379            kpt_shape: None,
380        }
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Block-style YAML metadata with a multi-line `imgsz` list and four classes.
    const SAMPLE_METADATA: &str = r"
description: Ultralytics YOLO11n model trained on /usr/src/ultralytics/ultralytics/cfg/datasets/coco.yaml
author: Ultralytics
date: '2025-12-11T20:19:45.464021'
version: 8.3.236
license: AGPL-3.0 License (https://ultralytics.com/license)
docs: https://docs.ultralytics.com
stride: 32
task: detect
batch: 1
imgsz:
- 640
- 640
names:
  0: person
  1: bicycle
  2: car
  3: motorcycle
channels: 3
";

    /// Scalar fields, the multi-line `imgsz` list and the class names are all read.
    #[test]
    fn test_parse_metadata() {
        let parsed = ModelMetadata::from_yaml_str(SAMPLE_METADATA).unwrap();

        assert_eq!(parsed.task, Task::Detect);
        assert_eq!(parsed.stride, 32);
        assert_eq!(parsed.batch, 1);
        assert_eq!(parsed.imgsz, Some((640, 640)));
        assert_eq!(parsed.channels, 3);

        let expected_names = ["person", "bicycle", "car", "motorcycle"];
        assert_eq!(parsed.num_classes(), expected_names.len());
        for (class_id, expected) in expected_names.into_iter().enumerate() {
            assert_eq!(parsed.class_name(class_id), Some(expected));
        }
    }

    /// `imgsz` written as an inline list on a single line is understood.
    #[test]
    fn test_parse_inline_imgsz() {
        let yaml = "task: detect\nimgsz: [640, 640]\nstride: 32";
        let parsed = ModelMetadata::from_yaml_str(yaml).unwrap();
        assert_eq!(parsed.imgsz, Some((640, 640)));
    }

    /// The defaults describe a detection model with stride 32 and no image size.
    #[test]
    fn test_default_metadata() {
        let defaults = ModelMetadata::default();
        assert_eq!(defaults.task, Task::Detect);
        assert_eq!(defaults.stride, 32);
        assert_eq!(defaults.imgsz, None);
    }
}