1use crossbeam::channel;
7use std::thread;
8
9pub struct PrefetchIterator<T> {
32 receiver: channel::Receiver<Option<T>>,
33 _handle: thread::JoinHandle<()>,
34}
35
36impl<T> PrefetchIterator<T>
37where
38 T: Send + 'static,
39{
40 pub fn new<I>(inner: I, buffer_size: usize) -> Self
61 where
62 I: Iterator<Item = T> + Send + 'static,
63 {
64 let (sender, receiver) = channel::bounded(buffer_size);
65
66 let handle = thread::spawn(move || {
67 for item in inner {
68 if sender.send(Some(item)).is_err() {
69 break;
71 }
72 }
73 let _ = sender.send(None);
75 });
76
77 Self {
78 receiver,
79 _handle: handle,
80 }
81 }
82
83 pub fn new_unbounded<I>(inner: I) -> Self
101 where
102 I: Iterator<Item = T> + Send + 'static,
103 {
104 let (sender, receiver) = channel::unbounded();
105
106 let handle = thread::spawn(move || {
107 for item in inner {
108 if sender.send(Some(item)).is_err() {
109 break;
111 }
112 }
113 let _ = sender.send(None);
115 });
116
117 Self {
118 receiver,
119 _handle: handle,
120 }
121 }
122
123 pub fn buffer_len(&self) -> usize {
131 self.receiver.len()
132 }
133
134 pub fn buffer_is_empty(&self) -> bool {
140 self.receiver.is_empty()
141 }
142
143 pub fn try_next(&mut self) -> Option<T> {
152 match self.receiver.try_recv() {
153 Ok(Some(item)) => Some(item),
154 Ok(None) | Err(_) => None,
155 }
156 }
157}
158
159impl<T> Iterator for PrefetchIterator<T>
160where
161 T: Send + 'static,
162{
163 type Item = T;
164
165 fn next(&mut self) -> Option<Self::Item> {
166 match self.receiver.recv() {
167 Ok(Some(item)) => Some(item),
168 Ok(None) | Err(_) => None,
169 }
170 }
171}
172
173pub trait PrefetchExt<T>: Iterator<Item = T> + Sized + Send + 'static
191where
192 T: Send + 'static,
193{
194 fn prefetch(self, buffer_size: usize) -> PrefetchIterator<T> {
204 PrefetchIterator::new(self, buffer_size)
205 }
206
207 fn prefetch_unbounded(self) -> PrefetchIterator<T> {
217 PrefetchIterator::new_unbounded(self)
218 }
219}
220
221impl<I, T> PrefetchExt<T> for I
223where
224 I: Iterator<Item = T> + Send + 'static,
225 T: Send + 'static,
226{
227}
228
/// Declarative description of how to prefetch an iterator. It is turned into a
/// [`PrefetchIterator`] by `PrefetchConfig::apply`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrefetchConfig {
    /// Maximum number of items buffered ahead of the consumer. Ignored when
    /// `unbounded` is `true`.
    pub buffer_size: usize,
    /// When `true`, the buffer has no size limit.
    pub unbounded: bool,
}
237
238impl Default for PrefetchConfig {
239 fn default() -> Self {
240 Self {
241 buffer_size: 2,
242 unbounded: false,
243 }
244 }
245}
246
247impl PrefetchConfig {
248 pub fn new() -> Self {
250 Self::default()
251 }
252
253 pub fn buffer_size(mut self, size: usize) -> Self {
255 self.buffer_size = size;
256 self.unbounded = false;
257 self
258 }
259
260 pub fn unbounded(mut self) -> Self {
262 self.unbounded = true;
263 self
264 }
265
266 pub fn apply<I, T>(self, iter: I) -> PrefetchIterator<T>
268 where
269 I: Iterator<Item = T> + Send + 'static,
270 T: Send + 'static,
271 {
272 if self.unbounded {
273 PrefetchIterator::new_unbounded(iter)
274 } else {
275 PrefetchIterator::new(iter, self.buffer_size)
276 }
277 }
278}
279
280pub mod utils {
282 use super::*;
283
284 pub fn optimal_prefetch<I, T>(
298 iter: I,
299 expected_item_processing_time: u64,
300 ) -> PrefetchIterator<T>
301 where
302 I: Iterator<Item = T> + Send + 'static,
303 T: Send + 'static,
304 {
305 let buffer_size = if expected_item_processing_time > 100 {
307 2 } else if expected_item_processing_time > 10 {
309 4 } else {
311 8 };
313
314 PrefetchIterator::new(iter, buffer_size)
315 }
316
317 pub fn cpu_bound_prefetch<I, T>(iter: I) -> PrefetchIterator<T>
327 where
328 I: Iterator<Item = T> + Send + 'static,
329 T: Send + 'static,
330 {
331 PrefetchIterator::new(iter, 2)
333 }
334
335 pub fn io_bound_prefetch<I, T>(iter: I) -> PrefetchIterator<T>
345 where
346 I: Iterator<Item = T> + Send + 'static,
347 T: Send + 'static,
348 {
349 PrefetchIterator::new(iter, 8)
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Upper bound on how long a test waits for the producer thread. It is
    /// generous so a loaded machine does not cause spurious failures.
    const TIMEOUT: Duration = Duration::from_secs(2);

    /// Asserts that `iter` yields exactly `expected` and then keeps returning
    /// `None`.
    fn assert_yields<I: Iterator<Item = i32>>(mut iter: I, expected: &[i32]) {
        for &want in expected {
            assert_eq!(iter.next(), Some(want));
        }
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next(), None);
    }

    /// Polls `ready` every couple of milliseconds until it returns `true` or
    /// `timeout` elapses. Returns whether `ready` succeeded.
    fn wait_until(timeout: Duration, mut ready: impl FnMut() -> bool) -> bool {
        let deadline = Instant::now() + timeout;
        while !ready() {
            if Instant::now() >= deadline {
                return false;
            }
            std::thread::sleep(Duration::from_millis(2));
        }
        true
    }

    #[test]
    fn test_prefetch_iterator_basic() {
        let prefetch_iter = PrefetchIterator::new(vec![1, 2, 3, 4, 5].into_iter(), 2);
        assert_yields(prefetch_iter, &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_prefetch_ext_trait() {
        assert_yields(vec![1, 2, 3, 4, 5].into_iter().prefetch(2), &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_prefetch_unbounded() {
        assert_yields(
            vec![1, 2, 3, 4, 5].into_iter().prefetch_unbounded(),
            &[1, 2, 3, 4, 5],
        );
    }

    #[test]
    fn test_prefetch_config() {
        let config = PrefetchConfig::new().buffer_size(4);
        assert_eq!(config.buffer_size, 4);
        assert!(!config.unbounded);

        let config = PrefetchConfig::new().unbounded();
        assert!(config.unbounded);

        // A later `buffer_size` call overrides an earlier `unbounded` call.
        let config = PrefetchConfig::new().unbounded().buffer_size(3);
        assert_eq!(config.buffer_size, 3);
        assert!(!config.unbounded);
    }

    #[test]
    fn test_prefetch_config_apply() {
        let config = PrefetchConfig::new().buffer_size(3);
        assert_yields(config.apply(vec![1, 2, 3, 4, 5].into_iter()), &[1, 2, 3, 4, 5]);

        let config = PrefetchConfig::new().unbounded();
        assert_yields(config.apply(vec![1, 2, 3, 4, 5].into_iter()), &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_try_next() {
        let mut prefetch_iter = PrefetchIterator::new(vec![1, 2, 3].into_iter(), 2);

        // Poll rather than sleeping a fixed time, so a slow scheduler cannot
        // make the producer miss the window.
        let mut first = None;
        assert!(wait_until(TIMEOUT, || {
            first = prefetch_iter.try_next();
            first.is_some()
        }));
        assert_eq!(first, Some(1));

        // Once the stream is drained, `try_next` reports nothing available.
        assert_eq!(prefetch_iter.by_ref().collect::<Vec<_>>(), vec![2, 3]);
        assert_eq!(prefetch_iter.try_next(), None);
    }

    #[test]
    fn test_buffer_status() {
        let prefetch_iter = PrefetchIterator::new(vec![1, 2, 3, 4, 5].into_iter(), 3);

        assert!(wait_until(TIMEOUT, || !prefetch_iter.buffer_is_empty()));
        assert!(prefetch_iter.buffer_len() > 0);
    }

    #[test]
    fn test_utils_optimal_prefetch() {
        for time in [0, 50, 500] {
            assert_yields(
                utils::optimal_prefetch(vec![1, 2, 3, 4, 5].into_iter(), time),
                &[1, 2, 3, 4, 5],
            );
        }
    }

    #[test]
    fn test_utils_cpu_bound_prefetch() {
        assert_yields(
            utils::cpu_bound_prefetch(vec![1, 2, 3, 4, 5].into_iter()),
            &[1, 2, 3, 4, 5],
        );
    }

    #[test]
    fn test_utils_io_bound_prefetch() {
        assert_yields(
            utils::io_bound_prefetch(vec![1, 2, 3, 4, 5].into_iter()),
            &[1, 2, 3, 4, 5],
        );
    }

    #[test]
    fn test_empty_iterator() {
        let data: Vec<i32> = vec![];
        assert_yields(data.into_iter().prefetch(2), &[]);
    }

    #[test]
    fn test_prefetch_runs_ahead_of_consumer() {
        const ITEM_COST: Duration = Duration::from_millis(10);
        let slow_iter = (0..10).map(|x| {
            std::thread::sleep(ITEM_COST);
            x
        });
        let mut prefetch_iter = slow_iter.prefetch(3);

        // Let the producer fill the buffer while the consumer is idle.
        assert!(wait_until(TIMEOUT, || prefetch_iter.buffer_len() >= 3));

        // Buffered items are handed over without paying the producer's cost.
        // Producing three items sequentially would take at least 3 * ITEM_COST.
        let start = Instant::now();
        assert_eq!(prefetch_iter.next(), Some(0));
        assert_eq!(prefetch_iter.next(), Some(1));
        assert_eq!(prefetch_iter.next(), Some(2));
        assert!(start.elapsed() < 3 * ITEM_COST);
    }

    #[test]
    fn test_dropping_early_does_not_block() {
        // With an infinite producer, dropping the consumer must return promptly
        // instead of waiting on a thread that never finishes on its own.
        let mut prefetch_iter = (0u64..).prefetch(2);
        assert_eq!(prefetch_iter.next(), Some(0));
        drop(prefetch_iter);
    }
}