1use std::collections::{HashMap, HashSet};
2use std::future::Future;
3use std::hash::Hash;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use crossbeam_queue::SegQueue;
8use tokio::task::{JoinError, JoinHandle};
9
10pub async fn work<T, W, F, C>(
11 context: C,
12 starting_items: SegQueue<W>,
13 worker: F,
14) -> Result<HashMap<W, T>, JoinError>
15where
16 W: std::fmt::Debug + Clone + Eq + Hash + Send + Sync + 'static,
17 F: Fn(W, Arc<SegQueue<W>>, Arc<C>) -> Pin<Box<dyn Future<Output = T> + Send + Sync + 'static>>
18 + Send
19 + Sync
20 + 'static,
21 T: Send + Sync + 'static,
22{
23 let context = Arc::new(context);
24 let work_queue = Arc::new(starting_items);
25 let fut_queue = Arc::new(SegQueue::<(W, JoinHandle<T>)>::new());
26
27 let mut is_processed = HashSet::<W>::new();
28 let mut completed = HashMap::<W, T>::new();
29
30 loop {
31 if fut_queue.is_empty() && work_queue.is_empty() {
32 log::trace!("All queues empty, breaking.");
33 break;
34 }
35
36 if let Ok(item) = fut_queue.pop() {
37 let (work, future) = item;
38 log::trace!("Got item from future queue for {:?}.", &work);
39 completed.insert(work, future.await?);
40 }
41
42 while let Ok(work) = work_queue.pop() {
43 if is_processed.contains(&work) {
44 log::trace!("Item {:?} already processed.", &work);
45 continue;
46 }
47
48 is_processed.insert(work.clone());
49
50 log::trace!("Processing {:?}", &work);
51 fut_queue.push((
52 work.clone(),
53 tokio::spawn(worker(work, Arc::clone(&work_queue), Arc::clone(&context))),
54 ));
55 }
56 }
57
58 Ok(completed)
59}
60
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check of `work`.
    ///
    /// Every item up to the context limit (20) enqueues its successor, so the chain started
    /// at 1 runs through 21. The successors 2 and 3 are pushed again after already being
    /// seeded, which exercises deduplication. Each item must map to twice its value.
    #[test]
    fn smoke() {
        use tokio::runtime::Runtime;

        let mut rt = Runtime::new().unwrap();

        let q = SegQueue::new();
        q.push(1usize);
        q.push(2usize);
        q.push(3usize);

        let result = rt
            .block_on(work(20usize, q, |work, queue, context| {
                Box::pin(async move {
                    if work <= *context {
                        queue.push(work + 1);
                    }
                    work * 2
                })
            }))
            .unwrap();

        let expected: HashMap<usize, usize> = (1..=21).map(|k| (k, k * 2)).collect();
        assert_eq!(result, expected);
    }
}