1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
use std::collections::{HashMap, HashSet, VecDeque};
use std::future::Future;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::Arc;

use crossbeam_queue::SegQueue;
use tokio::task::{JoinError, JoinHandle};

pub async fn work<T, W, F, C>(
    context: C,
    starting_items: SegQueue<W>,
    worker: F,
) -> Result<HashMap<W, T>, JoinError>
where
    W: std::fmt::Debug + Clone + Eq + Hash + Send + Sync + 'static,
    F: Fn(W, Arc<SegQueue<W>>, Arc<C>) -> Pin<Box<dyn Future<Output = T> + Send + Sync + 'static>>
        + Send
        + Sync
        + 'static,
    T: Send + Sync + 'static,
{
    let context = Arc::new(context);
    let work_queue = Arc::new(starting_items);
    let fut_queue = Arc::new(SegQueue::<(W, JoinHandle<T>)>::new());

    let mut is_processed = HashSet::<W>::new();
    let mut completed = HashMap::<W, T>::new();

    loop {
        if fut_queue.is_empty() && work_queue.is_empty() {
            log::trace!("All queues empty, breaking.");
            break;
        }

        if let Ok(item) = fut_queue.pop() {
            let (work, future) = item;
            log::trace!("Got item from future queue for {:?}.", &work);
            completed.insert(work, future.await?);
        }

        while let Ok(work) = work_queue.pop() {
            if is_processed.contains(&work) {
                log::trace!("Item {:?} already processed.", &work);
                continue;
            }

            is_processed.insert(work.clone());

            log::trace!("Processing {:?}", &work);
            fut_queue.push((
                work.clone(),
                tokio::spawn(worker(work, Arc::clone(&work_queue), Arc::clone(&context))),
            ));
        }
    }

    Ok(completed)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Seeds 1, 2 and 3. Each worker enqueues `work + 1` while `work <= 20`, so items 1..=21
    /// are reached. Re-queued duplicates (2 and 3 are also seeds) are skipped, so every item
    /// maps to exactly one output, `work * 2`.
    #[test]
    fn smoke() {
        use tokio::runtime::Runtime;

        let mut rt = Runtime::new().unwrap();

        let q = SegQueue::new();
        q.push(1usize);
        q.push(2usize);
        q.push(3usize);

        let result = rt
            .block_on(work(20usize, q, |work, queue, context| {
                Box::pin(async move {
                    if work <= *context {
                        queue.push(work + 1);
                    }
                    work * 2
                })
            }))
            .unwrap();

        let expected: HashMap<usize, usize> = (1..=21).map(|k| (k, k * 2)).collect();
        assert_eq!(result, expected);
    }
}