Skip to main content

ultralytics_inference/
batch.rs

1// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
3//! Batch processing module for YOLO inference.
4//!
5//! This module provides the [`BatchProcessor`] struct, which abstracts the logic for
6//! buffering images and running batch inference. It handles:
7//!
8//! - **Buffering**: Collects images until the batch size is reached
9//! - **Batch inference**: Runs inference on the full batch
10//! - **Automatic fallback**: Falls back to single-image inference if batch fails
11//! - **Callback invocation**: Invokes a user-provided callback with results
12//!
13//! # Usage
14//!
15//! ```no_run
16//! use ultralytics_inference::{YOLOModel, batch::BatchProcessor};
17//!
18//! let mut model = YOLOModel::load("yolo26n.onnx")?;
19//! let mut processor = BatchProcessor::new(&mut model, 4, |results, images, paths, metas| {
20//!     for (idx, result_vec) in results.iter().enumerate() {
21//!         println!("Image {}: {} detections", paths[idx], result_vec.len());
22//!     }
23//! });
24//!
25//! // Add images as they become available
26//! // processor.add(image, path, meta);
27//!
28//! // Don't forget to flush remaining images
29//! processor.flush();
30//! # Ok::<(), Box<dyn std::error::Error>>(())
31//! ```
32
33use crate::{Results, YOLOModel, source::SourceMeta};
34use image::DynamicImage;
35
36/// A processor for handling batch inference.
37///
38/// This struct manages collecting images into batches, running inference (with fallback),
39/// and invoking a callback with the results.
40///
41/// # Example
42///
43/// ```no_run
44/// use ultralytics_inference::{YOLOModel, batch::BatchProcessor};
45///
46/// fn main() -> Result<(), Box<dyn std::error::Error>> {
47///     let mut model = YOLOModel::load("yolo26n.onnx")?;
48///     let batch_size = 4;
49///     
50///     let mut processor = BatchProcessor::new(&mut model, batch_size, |results, images, paths, metas| {
51///         println!("Processed batch of {} images", results.len());
52///     });
53///     
54///     // Add images...
55///     // processor.add(image, path, meta);
56///     
57///     processor.flush();
58///     Ok(())
59/// }
60/// ```
61pub struct BatchProcessor<'a, F>
62where
63    F: FnMut(Vec<Vec<Results>>, &[DynamicImage], &[String], &[SourceMeta]),
64{
65    model: &'a mut YOLOModel,
66    batch_size: usize,
67    images: Vec<DynamicImage>,
68    paths: Vec<String>,
69    metas: Vec<SourceMeta>,
70    callback: F,
71}
72
73impl<'a, F> BatchProcessor<'a, F>
74where
75    F: FnMut(Vec<Vec<Results>>, &[DynamicImage], &[String], &[SourceMeta]),
76{
77    /// Create a new `BatchProcessor`.
78    ///
79    /// # Arguments
80    ///
81    /// * `model` - Mutable reference to the [`YOLOModel`] for inference.
82    /// * `batch_size` - Maximum number of images to collect before processing.
83    /// * `callback` - Closure invoked with batch results. Receives:
84    ///   - `Vec<Vec<Results>>` - Results for each image in the batch
85    ///   - `&[DynamicImage]` - The batch images
86    ///   - `&[String]` - Paths for each image
87    ///   - `&[SourceMeta]` - Metadata for each image
88    ///
89    /// # Returns
90    ///
91    /// A new `BatchProcessor` instance.
92    pub fn new(model: &'a mut YOLOModel, batch_size: usize, callback: F) -> Self {
93        Self {
94            model,
95            batch_size,
96            images: Vec::with_capacity(batch_size),
97            paths: Vec::with_capacity(batch_size),
98            metas: Vec::with_capacity(batch_size),
99            callback,
100        }
101    }
102
103    /// Add an image to the batch.
104    ///
105    /// If the batch becomes full (reaches `batch_size`), it is automatically processed
106    /// and the callback is invoked.
107    ///
108    /// # Arguments
109    ///
110    /// * `image` - The image to add.
111    /// * `path` - Path or identifier for this image.
112    /// * `meta` - Source metadata for this image.
113    pub fn add(&mut self, image: DynamicImage, path: String, meta: SourceMeta) {
114        self.images.push(image);
115        self.paths.push(path);
116        self.metas.push(meta);
117
118        if self.images.len() >= self.batch_size {
119            self.process();
120        }
121    }
122
123    /// Process any remaining images in the batch.
124    ///
125    /// This should be called after all images have been added to ensure
126    /// the last partial batch is processed. Has no effect if the batch is empty.
127    pub fn flush(&mut self) {
128        self.process();
129    }
130
131    fn process(&mut self) {
132        if self.images.is_empty() {
133            return;
134        }
135
136        let batch_results = self.run_inference();
137        (self.callback)(batch_results, &self.images, &self.paths, &self.metas);
138
139        self.images.clear();
140        self.paths.clear();
141        self.metas.clear();
142    }
143
144    fn run_inference(&mut self) -> Vec<Vec<Results>> {
145        if let Ok(batch_results) = self.model.predict_batch(&self.images, &self.paths) {
146            return batch_results;
147        }
148
149        eprintln!("WARNING ⚠️ Batch inference failed. Falling back to single-image inference...");
150
151        let mut fallback_results = Vec::with_capacity(self.images.len());
152        for (idx, img) in self.images.iter().enumerate() {
153            let path = &self.paths[idx];
154            match self.model.predict_image(img, path.clone()) {
155                Ok(results) => fallback_results.push(results),
156                Err(e) => {
157                    eprintln!("Error processing {path}: {e}");
158                    fallback_results.push(Vec::new());
159                }
160            }
161        }
162        fallback_results
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::cell::RefCell;
    use std::rc::Rc;

    /// Paths received by each callback invocation, one inner `Vec` per call.
    type CallLog = Rc<RefCell<Vec<Vec<String>>>>;

    /// Helper to load a test image from assets.
    ///
    /// Tries `assets/bus.jpg`, then `assets/zidane.jpg`, and falls back to a blank
    /// 640x640 RGB image if neither can be opened.
    fn load_test_image() -> DynamicImage {
        image::open("assets/bus.jpg")
            .or_else(|_| image::open("assets/zidane.jpg"))
            .unwrap_or_else(|_| DynamicImage::new_rgb8(640, 640))
    }

    /// Helper to build the single-frame metadata shared by these tests.
    fn test_meta() -> SourceMeta {
        SourceMeta {
            path: "test.jpg".to_string(),
            frame_idx: 0,
            total_frames: Some(1),
            fps: None,
        }
    }

    /// Build a callback that checks the per-image slices it receives line up with each
    /// other and appends the batch's paths to `log`.
    fn recording_callback(
        log: &CallLog,
    ) -> impl FnMut(Vec<Vec<Results>>, &[DynamicImage], &[String], &[SourceMeta]) {
        let log = Rc::clone(log);
        move |results, images, paths, metas| {
            assert_eq!(results.len(), images.len(), "one result list per image");
            assert_eq!(paths.len(), images.len(), "one path per image");
            assert_eq!(metas.len(), images.len(), "one meta per image");
            log.borrow_mut().push(paths.to_vec());
        }
    }

    /// Test that `BatchProcessor` correctly buffers images and invokes callback.
    ///
    /// Uses `batch_size=1` since the default yolo26n.onnx model only supports batch=1.
    /// The model is auto-downloaded if not present.
    #[test]
    #[serial]
    fn test_batch_processor_with_model() {
        let mut model = YOLOModel::load("yolo26n.onnx").expect("Model should load");
        let log = CallLog::default();

        // Use batch_size=1 since default model only supports batch=1
        let mut processor = BatchProcessor::new(&mut model, 1, recording_callback(&log));

        // Add first image - should trigger callback immediately (batch_size=1)
        processor.add(load_test_image(), "img1.jpg".to_string(), test_meta());
        assert_eq!(*log.borrow(), vec![vec!["img1.jpg".to_string()]]);

        // Add second image - should trigger another callback
        processor.add(load_test_image(), "img2.jpg".to_string(), test_meta());
        assert_eq!(
            *log.borrow(),
            vec![vec!["img1.jpg".to_string()], vec!["img2.jpg".to_string()]]
        );

        // Flush should not trigger callback (batch is empty)
        processor.flush();
        assert_eq!(log.borrow().len(), 2);
    }

    /// Test that flush on empty processor does nothing.
    #[test]
    #[serial]
    fn test_batch_processor_empty_flush() {
        let mut model = YOLOModel::load("yolo26n.onnx").expect("Model should load");
        let log = CallLog::default();

        let mut processor = BatchProcessor::new(&mut model, 1, recording_callback(&log));

        // Flush without adding anything should not call callback
        processor.flush();
        assert!(log.borrow().is_empty());
    }

    /// Test that the callback is invoked the correct number of times, with one result
    /// list, image, path and meta per image in each batch.
    #[test]
    #[serial]
    fn test_batch_processor_callback_count() {
        let mut model = YOLOModel::load("yolo26n.onnx").expect("Model should load");
        let log = CallLog::default();

        // Use `batch_size=1` to work with default model (which only supports batch=1)
        let mut processor = BatchProcessor::new(&mut model, 1, recording_callback(&log));

        // Add 3 images with batch_size=1
        for i in 0..3 {
            processor.add(load_test_image(), format!("img{i}.jpg"), test_meta());
        }
        processor.flush();

        // Should have 3 callbacks (one per image since batch_size=1), in insertion order
        let expected: Vec<Vec<String>> = (0..3).map(|i| vec![format!("img{i}.jpg")]).collect();
        assert_eq!(*log.borrow(), expected);
    }
}