workqueue/
lib.rs

1use std::collections::{HashMap, HashSet};
2use std::future::Future;
3use std::hash::Hash;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use crossbeam_queue::SegQueue;
8use tokio::task::{JoinError, JoinHandle};
9
10pub async fn work<T, W, F, C>(
11    context: C,
12    starting_items: SegQueue<W>,
13    worker: F,
14) -> Result<HashMap<W, T>, JoinError>
15where
16    W: std::fmt::Debug + Clone + Eq + Hash + Send + Sync + 'static,
17    F: Fn(W, Arc<SegQueue<W>>, Arc<C>) -> Pin<Box<dyn Future<Output = T> + Send + Sync + 'static>>
18        + Send
19        + Sync
20        + 'static,
21    T: Send + Sync + 'static,
22{
23    let context = Arc::new(context);
24    let work_queue = Arc::new(starting_items);
25    let fut_queue = Arc::new(SegQueue::<(W, JoinHandle<T>)>::new());
26
27    let mut is_processed = HashSet::<W>::new();
28    let mut completed = HashMap::<W, T>::new();
29
30    loop {
31        if fut_queue.is_empty() && work_queue.is_empty() {
32            log::trace!("All queues empty, breaking.");
33            break;
34        }
35
36        if let Ok(item) = fut_queue.pop() {
37            let (work, future) = item;
38            log::trace!("Got item from future queue for {:?}.", &work);
39            completed.insert(work, future.await?);
40        }
41
42        while let Ok(work) = work_queue.pop() {
43            if is_processed.contains(&work) {
44                log::trace!("Item {:?} already processed.", &work);
45                continue;
46            }
47
48            is_processed.insert(work.clone());
49
50            log::trace!("Processing {:?}", &work);
51            fut_queue.push((
52                work.clone(),
53                tokio::spawn(worker(work, Arc::clone(&work_queue), Arc::clone(&context))),
54            ));
55        }
56    }
57
58    Ok(completed)
59}
60
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::runtime::Runtime;

    /// Seeds the queue with 1, 2 and 3 and lets every worker enqueue its
    /// successor while the item is at most `LIMIT`. Each distinct item must be
    /// processed exactly once, so the result holds 1..=LIMIT + 1, each mapped
    /// to its double.
    #[test]
    fn smoke() {
        const LIMIT: usize = 20;

        let mut rt = Runtime::new().unwrap();

        let q = SegQueue::new();
        q.push(1usize);
        q.push(2usize);
        q.push(3usize);

        let result = rt
            .block_on(work(LIMIT, q, |item, queue, context| {
                Box::pin(async move {
                    if item <= *context {
                        queue.push(item + 1);
                    }
                    item * 2
                })
            }))
            .unwrap();

        // `LIMIT` itself still enqueues `LIMIT + 1`, which is processed but
        // enqueues nothing further.
        assert_eq!(result.len(), LIMIT + 1);
        for item in 1..=LIMIT + 1 {
            assert_eq!(
                result.get(&item),
                Some(&(item * 2)),
                "unexpected result for work item {}",
                item
            );
        }
    }
}