1use std::sync::{Arc, Condvar, Mutex};
4
/// A counter-based synchronization primitive: callers [`add`](Self::add)
/// pending work, call [`done`](Self::done) as each unit finishes, and
/// [`wait`](Self::wait) blocks until the counter returns to zero.
///
/// The API mirrors Go's `sync.WaitGroup` (`add` / `done` / `wait`). Cloning is
/// cheap and every clone shares the same counter, so a clone can be moved into
/// each worker thread.
#[derive(Clone, Debug)]
pub struct WaitGroup {
    /// Shared pending-work counter, plus the condition variable used to wake
    /// waiters when the counter reaches zero.
    counter: Arc<(Mutex<i64>, Condvar)>,
}

impl WaitGroup {
    /// Creates a new `WaitGroup` whose counter starts at zero.
    pub fn new() -> Self {
        WaitGroup {
            counter: Arc::new((Mutex::new(0), Condvar::new())),
        }
    }

    /// Adds `delta` (which may be negative) to the counter.
    ///
    /// If the counter becomes zero, every thread blocked in
    /// [`wait`](Self::wait) is woken.
    ///
    /// # Panics
    ///
    /// Panics with `"negative WaitGroup counter"` if the counter would drop
    /// below zero, and with `"WaitGroup counter overflow"` if it would exceed
    /// `i64::MAX`. In both cases the counter is left unchanged and the lock is
    /// released before panicking, so other clones of the group keep working.
    pub fn add(&self, delta: i64) {
        let (lock, cvar) = &*self.counter;
        // The counter is only ever written with a validated value, so it is
        // consistent even if the lock is somehow poisoned; recover instead of
        // cascading the panic to every clone.
        let mut count = lock.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        let next = match count.checked_add(delta) {
            Some(next) if next >= 0 => next,
            // Drop the guard before panicking: panicking while holding it would
            // poison the mutex and break every other clone of this group.
            Some(_) => {
                drop(count);
                panic!("negative WaitGroup counter");
            }
            None => {
                drop(count);
                panic!("WaitGroup counter overflow");
            }
        };
        *count = next;
        if next == 0 {
            cvar.notify_all();
        }
    }

    /// Decrements the counter by one; shorthand for `add(-1)`.
    ///
    /// # Panics
    ///
    /// Panics if the counter is already zero (see [`add`](Self::add)).
    pub fn done(&self) {
        self.add(-1);
    }

    /// Blocks the calling thread until the counter is zero.
    ///
    /// Returns immediately if the counter is already zero. `wait_while`
    /// re-checks the counter after every wake-up, so spurious wake-ups are
    /// handled.
    pub fn wait(&self) {
        let (lock, cvar) = &*self.counter;
        let count = lock.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        let count = cvar
            .wait_while(count, |pending| *pending > 0)
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // The counter is zero; nothing more to do under the lock.
        drop(count);
    }
}
62
63impl Default for WaitGroup {
64 fn default() -> Self {
65 WaitGroup::new()
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicI64, Ordering};
    use std::thread;

    /// Every worker decrements a shared counter and then calls `done`; once
    /// `wait` returns, all decrements must be visible.
    #[test]
    fn it_works() {
        let count = Arc::new(AtomicI64::new(0));
        let wg = WaitGroup::default();
        let n = 10;
        let mut handles = Vec::with_capacity(n);
        for _ in 0..n {
            wg.add(1);
            count.fetch_add(1, Ordering::SeqCst);
            let wg = wg.clone();
            let count = Arc::clone(&count);
            handles.push(thread::spawn(move || {
                count.fetch_sub(1, Ordering::SeqCst);
                wg.done();
            }));
        }
        wg.wait();
        assert_eq!(count.load(Ordering::SeqCst), 0);
        // Join so a panicking worker fails the test instead of being silently
        // detached.
        for handle in handles {
            handle.join().expect("worker thread panicked");
        }
    }

    /// With nothing pending, `wait` must not block.
    #[test]
    fn wait_with_zero_counter_returns_immediately() {
        WaitGroup::new().wait();
    }

    /// A single `add(n)` is balanced by `n` calls to `done` from another thread.
    #[test]
    fn add_n_is_released_by_n_dones() {
        let wg = WaitGroup::new();
        wg.add(3);
        let worker = wg.clone();
        let handle = thread::spawn(move || {
            for _ in 0..3 {
                worker.done();
            }
        });
        wg.wait();
        handle.join().expect("worker thread panicked");
    }

    /// Calling `done` more times than `add` is a usage error and must panic.
    #[test]
    #[should_panic(expected = "negative WaitGroup counter")]
    fn done_without_add_panics() {
        WaitGroup::new().done();
    }
}