Skip to main content

torsh_data/dataloader/
prefetch.rs

1//! Prefetch functionality for performance optimization
2//!
3//! This module provides prefetching capabilities to improve DataLoader performance
4//! by loading data in a separate thread while the main thread processes batches.
5
6use crossbeam::channel;
7use std::thread;
8
9/// Prefetch iterator for performance optimization
10///
11/// This iterator wraps another iterator and prefetches items in a background thread,
12/// allowing for overlapping computation and data loading to improve overall throughput.
13///
14/// # Type Parameters
15///
16/// * `T` - The item type that the iterator yields
17///
18/// # Examples
19///
20/// ```rust,ignore
21/// use torsh_data::dataloader::prefetch::{PrefetchIterator, PrefetchExt};
22///
23/// let data = vec![1, 2, 3, 4, 5];
24/// let iter = data.into_iter();
25/// let prefetch_iter = iter.prefetch(2); // Buffer size of 2
26///
27/// for item in prefetch_iter {
28///     // Process item while next items are being prefetched
29/// }
30/// ```
pub struct PrefetchIterator<T> {
    // Receives items from the background producer thread.
    // Protocol: `Some(item)` carries data; a single `None` is sent as a
    // sentinel to mark the end of the inner iterator.
    receiver: channel::Receiver<Option<T>>,
    // Handle for the producer thread. It is never joined: the thread exits
    // on its own, either after sending the `None` sentinel or when `send`
    // fails because this struct (and thus the receiver) was dropped.
    _handle: thread::JoinHandle<()>,
}
35
36impl<T> PrefetchIterator<T>
37where
38    T: Send + 'static,
39{
40    /// Create a new prefetch iterator
41    ///
42    /// # Arguments
43    ///
44    /// * `inner` - The iterator to wrap with prefetching
45    /// * `buffer_size` - The size of the prefetch buffer
46    ///
47    /// # Returns
48    ///
49    /// A new PrefetchIterator that will prefetch items from the inner iterator
50    ///
51    /// # Examples
52    ///
53    /// ```rust,ignore
54    /// use torsh_data::dataloader::prefetch::PrefetchIterator;
55    ///
56    /// let data = vec![1, 2, 3, 4, 5];
57    /// let iter = data.into_iter();
58    /// let prefetch_iter = PrefetchIterator::new(iter, 3);
59    /// ```
60    pub fn new<I>(inner: I, buffer_size: usize) -> Self
61    where
62        I: Iterator<Item = T> + Send + 'static,
63    {
64        let (sender, receiver) = channel::bounded(buffer_size);
65
66        let handle = thread::spawn(move || {
67            for item in inner {
68                if sender.send(Some(item)).is_err() {
69                    // Receiver has been dropped, stop producing
70                    break;
71                }
72            }
73            // Send None to signal end of iteration
74            let _ = sender.send(None);
75        });
76
77        Self {
78            receiver,
79            _handle: handle,
80        }
81    }
82
83    /// Create a new prefetch iterator with unbounded buffer
84    ///
85    /// This creates a prefetch iterator with an unbounded channel, which can be useful
86    /// when memory usage is not a concern and maximum throughput is desired.
87    ///
88    /// # Arguments
89    ///
90    /// * `inner` - The iterator to wrap with prefetching
91    ///
92    /// # Returns
93    ///
94    /// A new PrefetchIterator with unbounded buffering
95    ///
96    /// # Warning
97    ///
98    /// Using an unbounded buffer can lead to excessive memory usage if the consumer
99    /// is significantly slower than the producer.
100    pub fn new_unbounded<I>(inner: I) -> Self
101    where
102        I: Iterator<Item = T> + Send + 'static,
103    {
104        let (sender, receiver) = channel::unbounded();
105
106        let handle = thread::spawn(move || {
107            for item in inner {
108                if sender.send(Some(item)).is_err() {
109                    // Receiver has been dropped, stop producing
110                    break;
111                }
112            }
113            // Send None to signal end of iteration
114            let _ = sender.send(None);
115        });
116
117        Self {
118            receiver,
119            _handle: handle,
120        }
121    }
122
123    /// Get the number of items currently in the prefetch buffer
124    ///
125    /// This can be useful for monitoring the prefetch buffer utilization.
126    ///
127    /// # Returns
128    ///
129    /// The number of items currently buffered
130    pub fn buffer_len(&self) -> usize {
131        self.receiver.len()
132    }
133
134    /// Check if the prefetch buffer is empty
135    ///
136    /// # Returns
137    ///
138    /// True if the buffer is empty, false otherwise
139    pub fn buffer_is_empty(&self) -> bool {
140        self.receiver.is_empty()
141    }
142
143    /// Try to get the next item without blocking
144    ///
145    /// This is useful when you want to check if an item is available without
146    /// blocking the current thread.
147    ///
148    /// # Returns
149    ///
150    /// Some(item) if an item is available, None if no item is ready or the iterator is exhausted
151    pub fn try_next(&mut self) -> Option<T> {
152        match self.receiver.try_recv() {
153            Ok(Some(item)) => Some(item),
154            Ok(None) | Err(_) => None,
155        }
156    }
157}
158
159impl<T> Iterator for PrefetchIterator<T>
160where
161    T: Send + 'static,
162{
163    type Item = T;
164
165    fn next(&mut self) -> Option<Self::Item> {
166        match self.receiver.recv() {
167            Ok(Some(item)) => Some(item),
168            Ok(None) | Err(_) => None,
169        }
170    }
171}
172
173/// Extension trait for adding prefetching to iterators
174///
175/// This trait provides convenient methods for adding prefetching capabilities
176/// to any iterator that meets the requirements (Send + 'static items).
177///
178/// # Examples
179///
180/// ```rust,ignore
181/// use torsh_data::dataloader::prefetch::PrefetchExt;
182///
183/// let data = vec![1, 2, 3, 4, 5];
184/// let prefetch_iter = data.into_iter().prefetch(2);
185///
186/// for item in prefetch_iter {
187///     // Process item while next items are being prefetched
188/// }
189/// ```
190pub trait PrefetchExt<T>: Iterator<Item = T> + Sized + Send + 'static
191where
192    T: Send + 'static,
193{
194    /// Add prefetching to the iterator
195    ///
196    /// # Arguments
197    ///
198    /// * `buffer_size` - The size of the prefetch buffer
199    ///
200    /// # Returns
201    ///
202    /// A PrefetchIterator that will prefetch items from this iterator
203    fn prefetch(self, buffer_size: usize) -> PrefetchIterator<T> {
204        PrefetchIterator::new(self, buffer_size)
205    }
206
207    /// Add unbounded prefetching to the iterator
208    ///
209    /// # Returns
210    ///
211    /// A PrefetchIterator with unbounded buffering
212    ///
213    /// # Warning
214    ///
215    /// This can lead to excessive memory usage if the consumer is slower than the producer
216    fn prefetch_unbounded(self) -> PrefetchIterator<T> {
217        PrefetchIterator::new_unbounded(self)
218    }
219}
220
/// Blanket implementation of PrefetchExt for all compatible iterators
///
/// Any iterator that is itself `Send + 'static` (so it can be moved into
/// the producer thread) with `Send + 'static` items automatically gains
/// the `prefetch` / `prefetch_unbounded` methods.
impl<I, T> PrefetchExt<T> for I
where
    I: Iterator<Item = T> + Send + 'static,
    T: Send + 'static,
{
}
228
/// Configuration for prefetch operations
///
/// Built with the fluent setters on `PrefetchConfig` and turned into a
/// `PrefetchIterator` via `PrefetchConfig::apply`.
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
    /// Size of the prefetch buffer (ignored by `apply` when `unbounded` is true)
    pub buffer_size: usize,
    /// Whether to use unbounded buffering
    pub unbounded: bool,
}
237
238impl Default for PrefetchConfig {
239    fn default() -> Self {
240        Self {
241            buffer_size: 2,
242            unbounded: false,
243        }
244    }
245}
246
247impl PrefetchConfig {
248    /// Create a new prefetch configuration
249    pub fn new() -> Self {
250        Self::default()
251    }
252
253    /// Set the buffer size
254    pub fn buffer_size(mut self, size: usize) -> Self {
255        self.buffer_size = size;
256        self.unbounded = false;
257        self
258    }
259
260    /// Enable unbounded buffering
261    pub fn unbounded(mut self) -> Self {
262        self.unbounded = true;
263        self
264    }
265
266    /// Apply this configuration to an iterator
267    pub fn apply<I, T>(self, iter: I) -> PrefetchIterator<T>
268    where
269        I: Iterator<Item = T> + Send + 'static,
270        T: Send + 'static,
271    {
272        if self.unbounded {
273            PrefetchIterator::new_unbounded(iter)
274        } else {
275            PrefetchIterator::new(iter, self.buffer_size)
276        }
277    }
278}
279
280/// Utility functions for prefetch operations
281pub mod utils {
282    use super::*;
283
284    /// Create a prefetch iterator with optimal buffer size
285    ///
286    /// This function automatically determines an appropriate buffer size based on
287    /// heuristics and the expected workload characteristics.
288    ///
289    /// # Arguments
290    ///
291    /// * `iter` - The iterator to wrap
292    /// * `expected_item_processing_time` - Expected time to process each item in milliseconds
293    ///
294    /// # Returns
295    ///
296    /// A PrefetchIterator with an optimized buffer size
297    pub fn optimal_prefetch<I, T>(
298        iter: I,
299        expected_item_processing_time: u64,
300    ) -> PrefetchIterator<T>
301    where
302        I: Iterator<Item = T> + Send + 'static,
303        T: Send + 'static,
304    {
305        // Simple heuristic: buffer size inversely related to processing time
306        let buffer_size = if expected_item_processing_time > 100 {
307            2 // Small buffer for expensive operations
308        } else if expected_item_processing_time > 10 {
309            4 // Medium buffer for moderate operations
310        } else {
311            8 // Larger buffer for fast operations
312        };
313
314        PrefetchIterator::new(iter, buffer_size)
315    }
316
317    /// Create a prefetch iterator optimized for CPU-bound tasks
318    ///
319    /// # Arguments
320    ///
321    /// * `iter` - The iterator to wrap
322    ///
323    /// # Returns
324    ///
325    /// A PrefetchIterator configured for CPU-bound workloads
326    pub fn cpu_bound_prefetch<I, T>(iter: I) -> PrefetchIterator<T>
327    where
328        I: Iterator<Item = T> + Send + 'static,
329        T: Send + 'static,
330    {
331        // For CPU-bound tasks, use a smaller buffer to avoid excessive memory usage
332        PrefetchIterator::new(iter, 2)
333    }
334
335    /// Create a prefetch iterator optimized for I/O-bound tasks
336    ///
337    /// # Arguments
338    ///
339    /// * `iter` - The iterator to wrap
340    ///
341    /// # Returns
342    ///
343    /// A PrefetchIterator configured for I/O-bound workloads
344    pub fn io_bound_prefetch<I, T>(iter: I) -> PrefetchIterator<T>
345    where
346        I: Iterator<Item = T> + Send + 'static,
347        T: Send + 'static,
348    {
349        // For I/O-bound tasks, use a larger buffer to hide I/O latency
350        PrefetchIterator::new(iter, 8)
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    #[test]
    fn test_prefetch_iterator_basic() {
        let data = vec![1, 2, 3, 4, 5];
        let mut prefetch_iter = PrefetchIterator::new(data.into_iter(), 2);

        // All items arrive in order; the iterator is fused-like afterwards.
        assert_eq!(prefetch_iter.by_ref().collect::<Vec<_>>(), vec![1, 2, 3, 4, 5]);
        assert_eq!(prefetch_iter.next(), None);
    }

    #[test]
    fn test_prefetch_ext_trait() {
        let data = vec![1, 2, 3, 4, 5];
        let mut prefetch_iter = data.into_iter().prefetch(2);

        assert_eq!(prefetch_iter.by_ref().collect::<Vec<_>>(), vec![1, 2, 3, 4, 5]);
        assert_eq!(prefetch_iter.next(), None);
    }

    #[test]
    fn test_prefetch_unbounded() {
        let data = vec![1, 2, 3, 4, 5];
        let mut prefetch_iter = data.into_iter().prefetch_unbounded();

        assert_eq!(prefetch_iter.by_ref().collect::<Vec<_>>(), vec![1, 2, 3, 4, 5]);
        assert_eq!(prefetch_iter.next(), None);
    }

    #[test]
    fn test_prefetch_config() {
        let config = PrefetchConfig::new().buffer_size(4);
        assert_eq!(config.buffer_size, 4);
        assert!(!config.unbounded);

        let config = PrefetchConfig::new().unbounded();
        assert!(config.unbounded);
    }

    #[test]
    fn test_prefetch_config_apply() {
        let data = vec![1, 2, 3, 4, 5];
        let config = PrefetchConfig::new().buffer_size(3);
        let mut prefetch_iter = config.apply(data.into_iter());

        assert_eq!(prefetch_iter.by_ref().collect::<Vec<_>>(), vec![1, 2, 3, 4, 5]);
        assert_eq!(prefetch_iter.next(), None);
    }

    #[test]
    fn test_try_next() {
        let data = vec![1, 2, 3];
        let mut prefetch_iter = PrefetchIterator::new(data.into_iter(), 2);

        // Poll with a generous deadline instead of a single fixed sleep; a
        // fixed sleep is flaky when the scheduler is slow to run the
        // producer thread (same rationale as test_buffer_status below).
        let deadline = Instant::now() + Duration::from_millis(500);
        let mut first = prefetch_iter.try_next();
        while first.is_none() && Instant::now() < deadline {
            std::thread::sleep(Duration::from_millis(5));
            first = prefetch_iter.try_next();
        }

        // Should have at least one item prefetched within the deadline.
        assert!(first.is_some());
    }

    #[test]
    fn test_buffer_status() {
        let data = vec![1, 2, 3, 4, 5];
        let prefetch_iter = PrefetchIterator::new(data.into_iter(), 3);

        // Poll until the prefetch thread has had a chance to fill the buffer.
        // A fixed sleep is unreliable under heavy test-suite load; retry with
        // a generous timeout so the test is not flaky on slow CI runners.
        let deadline = Instant::now() + Duration::from_millis(500);
        while prefetch_iter.buffer_is_empty() && Instant::now() < deadline {
            std::thread::sleep(Duration::from_millis(5));
        }

        // Buffer should not be empty after prefetching starts
        assert!(!prefetch_iter.buffer_is_empty());
        assert!(prefetch_iter.buffer_len() > 0);
    }

    #[test]
    fn test_utils_optimal_prefetch() {
        let data = vec![1, 2, 3, 4, 5];
        let mut prefetch_iter = utils::optimal_prefetch(data.into_iter(), 50);

        assert_eq!(prefetch_iter.by_ref().collect::<Vec<_>>(), vec![1, 2, 3, 4, 5]);
        assert_eq!(prefetch_iter.next(), None);
    }

    #[test]
    fn test_utils_cpu_bound_prefetch() {
        let data = vec![1, 2, 3, 4, 5];
        let mut prefetch_iter = utils::cpu_bound_prefetch(data.into_iter());

        assert_eq!(prefetch_iter.by_ref().collect::<Vec<_>>(), vec![1, 2, 3, 4, 5]);
        assert_eq!(prefetch_iter.next(), None);
    }

    #[test]
    fn test_utils_io_bound_prefetch() {
        let data = vec![1, 2, 3, 4, 5];
        let mut prefetch_iter = utils::io_bound_prefetch(data.into_iter());

        assert_eq!(prefetch_iter.by_ref().collect::<Vec<_>>(), vec![1, 2, 3, 4, 5]);
        assert_eq!(prefetch_iter.next(), None);
    }

    #[test]
    fn test_empty_iterator() {
        let data: Vec<i32> = vec![];
        let mut prefetch_iter = data.into_iter().prefetch(2);

        assert_eq!(prefetch_iter.next(), None);
    }

    #[test]
    fn test_prefetch_performance() {
        // Create a slow iterator that simulates expensive computation
        let slow_iter = (0..10).map(|x| {
            std::thread::sleep(Duration::from_millis(10));
            x
        });

        let start = Instant::now();
        let mut prefetch_iter = slow_iter.prefetch(3);

        // Consume the first few items
        assert_eq!(prefetch_iter.next(), Some(0));
        assert_eq!(prefetch_iter.next(), Some(1));
        assert_eq!(prefetch_iter.next(), Some(2));

        let elapsed = start.elapsed();

        // The bound equals the fully-sequential cost of all 10 items
        // (10 × 10 ms), so this only fails if prefetching provides no
        // overlap at all or the runner is pathologically slow.
        assert!(elapsed < Duration::from_millis(100));
    }
}