rust_waitgroup/
lib.rs

//! A Go-style `WaitGroup`: block until a collection of threads has finished.
2
3use std::sync::{Arc, Condvar, Mutex};
4
/// A `WaitGroup` waits for a collection of threads to finish.
///
/// The group holds a shared counter. [`add`](WaitGroup::add) increases it
/// (usually once per unit of work, before that work is spawned),
/// [`done`](WaitGroup::done) decrements it when the work finishes, and
/// [`wait`](WaitGroup::wait) blocks until the counter is back to zero.
/// Cloning a `WaitGroup` is cheap: every clone shares the same counter.
///
/// # Examples
/// ```
/// use rust_waitgroup::WaitGroup;
/// use std::thread;
///
/// let wg = WaitGroup::default();
/// let n = 10;
/// for _ in 0..n {
///     let wg = wg.clone();
///     wg.add(1);
///     thread::spawn(move || {
///         // do some work
///         wg.done();
///     });
/// }
/// wg.wait();
/// ```
#[derive(Clone, Debug)]
pub struct WaitGroup {
    /// Outstanding-work counter, paired with the condition variable that
    /// waiters block on until the counter reaches zero.
    counter: Arc<(Mutex<i64>, Condvar)>,
}
28
29impl WaitGroup {
30    pub fn new() -> Self {
31        WaitGroup {
32            counter: Arc::new((Mutex::new(0), Condvar::new())),
33        }
34    }
35    /// add adds delta, which may be negative, to the WaitGroup counter.
36    /// if the counter becomes zero, all coroutines blocked on Wait are released.
37    /// if the counter goes negative, add panics
38    pub fn add(&self, delta: i64) {
39        let (lock, cvar) = &*self.counter;
40        let mut count = lock.lock().unwrap();
41        *count += delta;
42        if *count < 0 {
43            panic!("negative WaitGroup counter");
44        }
45        if *count == 0 {
46            cvar.notify_all();
47        }
48    }
49    /// done decrements the WaitGroup counter by one.
50    pub fn done(&self) {
51        self.add(-1);
52    }
53    /// wait blocks until the WaitGroup counter is zero.
54    pub fn wait(&self) {
55        let (lock, cvar) = &*self.counter;
56        let mut count = lock.lock().unwrap();
57        while *count > 0 {
58            count = cvar.wait(count).unwrap();
59        }
60    }
61}
62
63impl Default for WaitGroup {
64    fn default() -> Self {
65        WaitGroup::new()
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicI64, Ordering};
    use std::thread;

    /// Every spawned worker must have run (and decremented `count`) by the
    /// time `wait` returns.
    #[test]
    fn it_works() {
        let count = Arc::new(AtomicI64::new(0));
        let wg = WaitGroup::default();
        let n = 10;
        for _ in 0..n {
            let wg = wg.clone();
            wg.add(1);
            // count += 1
            let count = count.clone();
            count.fetch_add(1, Ordering::Relaxed);
            thread::spawn(move || {
                // count -= 1
                count.fetch_sub(1, Ordering::Relaxed);
                wg.done();
            });
        }
        wg.wait();
        // assert count == 0
        assert_eq!(count.load(Ordering::Relaxed), 0);
    }

    /// `wait` on a group whose counter is already zero must not block.
    #[test]
    fn wait_returns_immediately_when_counter_is_zero() {
        WaitGroup::default().wait();
    }

    /// `done` without a matching `add` drives the counter negative, which
    /// `add` documents as a panic.
    #[test]
    #[should_panic(expected = "negative WaitGroup counter")]
    fn done_without_add_panics() {
        WaitGroup::default().done();
    }

    /// Reaching zero must release every blocked waiter, not just one.
    #[test]
    fn all_waiters_are_released() {
        let wg = WaitGroup::default();
        wg.add(1);
        let waiters: Vec<_> = (0..4)
            .map(|_| {
                let wg = wg.clone();
                thread::spawn(move || wg.wait())
            })
            .collect();
        wg.done();
        for waiter in waiters {
            waiter.join().unwrap();
        }
    }
}
96}