// ultralytics_inference/inference.rs
// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

//! Inference configuration and common types.
//!
//! This module defines the [`InferenceConfig`] struct, which controls various parameters
//! for YOLO model inference, such as confidence thresholds, Non-Maximum Suppression (NMS),
//! input image sizing, and hardware execution options.
/// Configuration for YOLO inference.
///
/// This struct is used to customize the behavior of the inference engine.
/// It uses a builder pattern for convenient construction.
///
/// # Examples
///
/// Basic configuration:
/// ```rust
/// use ultralytics_inference::InferenceConfig;
///
/// let config = InferenceConfig::new()
///     .with_confidence(0.5)
///     .with_iou(0.45)
///     .with_max_det(300)
///     .with_imgsz(640, 640);
/// ```
///
/// With specific hardware device:
/// ```rust
/// use ultralytics_inference::{InferenceConfig, Device};
///
/// let config = InferenceConfig::new()
///     .with_confidence(0.5)
///     .with_device(Device::Cuda(0));
/// ```
#[derive(Debug, Clone)]
// The bools here are independent user-facing toggles (save, save_frames, rect, half),
// not an implicit state machine, so an enum refactor would hurt the builder API.
#[allow(clippy::struct_excessive_bools)]
pub struct InferenceConfig {
    /// Confidence threshold for detections (0.0 to 1.0).
    /// Detections with confidence scores lower than this value will be discarded.
    pub confidence_threshold: f32,
    /// Intersection over Union (`IoU`) threshold for Non-Maximum Suppression (NMS) (0.0 to 1.0).
    /// Used to merge overlapping boxes. Lower values filter more duplicates.
    pub iou_threshold: f32,
    /// Maximum number of detections to return per image.
    /// The top-k detections sorted by confidence will be returned.
    pub max_det: usize,
    /// Explicit input image size (height, width).
    /// If `None`, the model's metadata will be used to determine input size.
    pub imgsz: Option<(usize, usize)>,
    /// Batch size for inference when using [`BatchProcessor`](crate::batch::BatchProcessor).
    /// If `None`, defaults to 1 (single-image inference).
    pub batch: Option<usize>,
    /// Number of intra-op threads for ONNX Runtime.
    /// Setting this to `0` allows ONNX Runtime to choose the optimal number.
    pub num_threads: usize,
    /// Whether to use FP16 (half-precision) inference.
    /// This can improve performance on compatible hardware (e.g., GPUs) but may
    /// result in slight precision loss.
    pub half: bool,
    /// Hardware device to use for inference.
    /// If `None`, the best available device will be automatically selected.
    pub device: Option<crate::Device>,
    /// Whether to save annotated results.
    /// Defaults to `true`.
    pub save: bool,
    /// Whether to save individual frames instead of a video file when input is video.
    /// Defaults to `false` (save as video).
    pub save_frames: bool,
    /// Whether to use minimal padding (rectangular inference). Defaults to `true`.
    pub rect: bool,
    /// Class IDs to filter predictions. If `None`, all classes are returned.
    /// Useful for focusing on specific objects in multi-class detection tasks.
    pub classes: Option<Vec<usize>>,
}
75
76impl Default for InferenceConfig {
77    fn default() -> Self {
78        Self {
79            confidence_threshold: Self::DEFAULT_CONF,
80            iou_threshold: Self::DEFAULT_IOU,
81            max_det: Self::DEFAULT_MAX_DET,
82            imgsz: None,
83            batch: None,
84            num_threads: 0, // 0 = let ONNX Runtime decide (typically uses all cores efficiently)
85            half: Self::DEFAULT_HALF,
86            device: None,
87            save: Self::DEFAULT_SAVE,
88            save_frames: Self::DEFAULT_SAVE_FRAMES,
89            rect: Self::DEFAULT_RECT,
90            classes: None,
91        }
92    }
93}
94
95impl InferenceConfig {
96    /// Default confidence threshold (0.0 to 1.0).
97    pub const DEFAULT_CONF: f32 = 0.25;
98    /// Default `IoU` threshold for NMS (0.0 to 1.0).
99    pub const DEFAULT_IOU: f32 = 0.7;
100    /// Default maximum number of detections per image.
101    pub const DEFAULT_MAX_DET: usize = 300;
102    /// Default for FP16 half-precision inference.
103    pub const DEFAULT_HALF: bool = false;
104    /// Default for saving annotated results.
105    pub const DEFAULT_SAVE: bool = true;
106    /// Default for saving individual frames (vs video).
107    pub const DEFAULT_SAVE_FRAMES: bool = false;
108    /// Default for rectangular (minimal padding) inference.
109    pub const DEFAULT_RECT: bool = true;
110
111    /// Create a new configuration with default values.
112    ///
113    /// # Returns
114    ///
115    /// * A new `InferenceConfig` instance with default settings.
116    #[must_use]
117    pub fn new() -> Self {
118        Self::default()
119    }
120
121    /// Set the batch size.
122    ///
123    /// # Arguments
124    ///
125    /// * `batch` - The batch size.
126    ///
127    /// # Returns
128    ///
129    /// * The modified `InferenceConfig`.
130    #[must_use]
131    pub const fn with_batch(mut self, batch: usize) -> Self {
132        self.batch = Some(batch);
133        self
134    }
135
136    /// Set the confidence threshold.
137    ///
138    /// Detections with a confidence score below this threshold will be filtered out.
139    ///
140    /// # Arguments
141    ///
142    /// * `threshold` - The minimum confidence score (0.0 to 1.0).
143    ///
144    /// # Returns
145    ///
146    /// * The modified `InferenceConfig`.
147    #[must_use]
148    pub const fn with_confidence(mut self, threshold: f32) -> Self {
149        self.confidence_threshold = threshold;
150        self
151    }
152
153    /// Set the `IoU` threshold for Non-Maximum Suppression (NMS).
154    ///
155    /// NMS suppresses overlapping bounding boxes. This threshold determines how much overlap
156    /// is allowed before boxes are considered duplicates.
157    ///
158    /// # Arguments
159    ///
160    /// * `threshold` - The `IoU` threshold (0.0 to 1.0).
161    ///
162    /// # Returns
163    ///
164    /// * The modified `InferenceConfig`.
165    #[must_use]
166    pub const fn with_iou(mut self, threshold: f32) -> Self {
167        self.iou_threshold = threshold;
168        self
169    }
170
171    /// Set the maximum number of detections to return.
172    ///
173    /// Only the top `max` detections (sorted by confidence) will be kept after NMS.
174    ///
175    /// # Arguments
176    ///
177    /// * `max` - The maximum number of detections.
178    ///
179    /// # Returns
180    ///
181    /// * The modified `InferenceConfig`.
182    #[must_use]
183    pub const fn with_max_det(mut self, max: usize) -> Self {
184        self.max_det = max;
185        self
186    }
187
188    /// Set the input image size.
189    ///
190    /// This explicitly sets the size to resize images to before inference.
191    /// If not set, the model's internal metadata size will be used.
192    ///
193    /// # Arguments
194    ///
195    /// * `height` - The target image height.
196    /// * `width` - The target image width.
197    ///
198    /// # Returns
199    ///
200    /// * The modified `InferenceConfig`.
201    #[must_use]
202    pub const fn with_imgsz(mut self, height: usize, width: usize) -> Self {
203        self.imgsz = Some((height, width));
204        self
205    }
206
207    /// Set the number of threads for inference.
208    ///
209    /// # Arguments
210    ///
211    /// * `threads` - The number of intra-op threads. Set to `0` for auto-configuration.
212    ///
213    /// # Returns
214    ///
215    /// * The modified `InferenceConfig`.
216    #[must_use]
217    pub const fn with_threads(mut self, threads: usize) -> Self {
218        self.num_threads = threads;
219        self
220    }
221
222    /// Enable or disable FP16 (half-precision) inference.
223    ///
224    /// Using FP16 can significantly speed up inference on GPUs and some CPUS,
225    /// at the cost of potential minor precision loss.
226    ///
227    /// # Arguments
228    ///
229    /// * `half` - `true` to enable FP16, `false` for FP32.
230    ///
231    /// # Returns
232    ///
233    /// * The modified `InferenceConfig`.
234    #[must_use]
235    pub const fn with_half(mut self, half: bool) -> Self {
236        self.half = half;
237        self
238    }
239
240    /// Set the hardware device for inference.
241    ///
242    /// # Arguments
243    ///
244    /// * `device` - The device to use (e.g., CPU, CUDA, MPS).
245    ///
246    /// # Example
247    ///
248    /// ```rust
249    /// use ultralytics_inference::{InferenceConfig, Device};
250    ///
251    /// let config = InferenceConfig::new()
252    ///     .with_device(Device::Mps); // Use Apple Metal Performance Shaders
253    /// ```
254    ///
255    /// # Returns
256    ///
257    /// * The modified `InferenceConfig`.
258    #[must_use]
259    pub const fn with_device(mut self, device: crate::Device) -> Self {
260        self.device = Some(device);
261        self
262    }
263
264    /// Set whether to save annotated results.
265    ///
266    /// # Arguments
267    ///
268    /// * `save` - `true` to save results, `false` to skip saving.
269    ///
270    /// # Returns
271    ///
272    /// * The modified `InferenceConfig`.
273    #[must_use]
274    pub const fn with_save(mut self, save: bool) -> Self {
275        self.save = save;
276        self
277    }
278
279    /// Set whether to save individual frames for video inputs.
280    ///
281    /// # Arguments
282    ///
283    /// * `save_frames` - `true` to save frames, `false` to save as video.
284    ///
285    /// # Returns
286    ///
287    /// * The modified `InferenceConfig`.
288    #[must_use]
289    pub const fn with_save_frames(mut self, save_frames: bool) -> Self {
290        self.save_frames = save_frames;
291        self
292    }
293
294    /// Set whether to use minimal padding (rectangular inference).
295    ///
296    /// # Arguments
297    ///
298    /// * `rect` - `true` to enable, `false` to disable.
299    ///
300    /// # Returns
301    ///
302    /// * The modified `InferenceConfig`.
303    #[must_use]
304    pub const fn with_rect(mut self, rect: bool) -> Self {
305        self.rect = rect;
306        self
307    }
308
309    /// Set the class IDs to filter predictions.
310    ///
311    /// Only detections belonging to the specified classes will be returned.
312    ///
313    /// # Arguments
314    ///
315    /// * `classes` - A vector of class IDs to keep.
316    ///
317    /// # Example
318    ///
319    /// ```rust
320    /// use ultralytics_inference::InferenceConfig;
321    ///
322    /// // Only detect persons (class 0) and cars (class 2)
323    /// let config = InferenceConfig::new()
324    ///     .with_classes(vec![0, 2]);
325    /// ```
326    ///
327    /// # Returns
328    ///
329    /// * The modified `InferenceConfig`.
330    #[must_use]
331    pub fn with_classes(mut self, classes: Vec<usize>) -> Self {
332        self.classes = Some(classes);
333        self
334    }
335    /// Check if a class should be included in the results.
336    ///
337    /// # Arguments
338    ///
339    /// * `class_id` - The class index to check.
340    ///
341    /// # Returns
342    ///
343    /// * `true` if the class should be kept.
344    /// * `false` if the class should be filtered out.
345    #[must_use]
346    pub fn keep_class(&self, class_id: usize) -> bool {
347        self.classes.as_ref().is_none_or(|c| c.contains(&class_id))
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of the default configuration matches its documented default.
    #[test]
    fn test_config_default() {
        let config = InferenceConfig::default();
        assert!((config.confidence_threshold - InferenceConfig::DEFAULT_CONF).abs() < f32::EPSILON);
        assert!((config.iou_threshold - InferenceConfig::DEFAULT_IOU).abs() < f32::EPSILON);
        // Compare against the constant rather than a magic 300 so the test
        // tracks the single source of truth.
        assert_eq!(config.max_det, InferenceConfig::DEFAULT_MAX_DET);
        assert_eq!(config.imgsz, None);
        assert_eq!(config.batch, None);
        assert_eq!(config.num_threads, 0);
        assert_eq!(config.half, InferenceConfig::DEFAULT_HALF);
        assert!(config.device.is_none());
        assert_eq!(config.save, InferenceConfig::DEFAULT_SAVE);
        assert_eq!(config.save_frames, InferenceConfig::DEFAULT_SAVE_FRAMES);
        assert_eq!(config.rect, InferenceConfig::DEFAULT_RECT);
        assert_eq!(config.classes, None);
    }

    /// Builder methods store the values they are given.
    #[test]
    fn test_config_builder() {
        let config = InferenceConfig::new()
            .with_confidence(0.5)
            .with_iou(0.6)
            .with_max_det(300)
            .with_imgsz(640, 640)
            .with_threads(8)
            .with_batch(4)
            .with_half(true)
            .with_save(false)
            .with_save_frames(true)
            .with_rect(false);

        assert!((config.confidence_threshold - 0.5).abs() < f32::EPSILON);
        assert!((config.iou_threshold - 0.6).abs() < f32::EPSILON);
        assert_eq!(config.max_det, 300);
        assert_eq!(config.imgsz, Some((640, 640)));
        assert_eq!(config.num_threads, 8);
        assert_eq!(config.batch, Some(4));
        assert!(config.half);
        assert!(!config.save);
        assert!(config.save_frames);
        assert!(!config.rect);
    }

    /// With no filter all classes pass; with a filter only listed IDs pass.
    #[test]
    fn test_keep_class() {
        let config = InferenceConfig::default();
        assert!(config.keep_class(0));
        assert!(config.keep_class(100));

        let config_filtered = InferenceConfig::new().with_classes(vec![1, 3]);
        assert!(config_filtered.keep_class(1));
        assert!(config_filtered.keep_class(3));
        assert!(!config_filtered.keep_class(0));
        assert!(!config_filtered.keep_class(2));
    }
}