sync_wait_group/
lib.rs

//! Enables threads to synchronize the beginning or end of some computation.
//!
//! # Examples
//!
//! ```
//! use sync_wait_group::WaitGroup;
//! use std::thread;
//!
//! // Create a new wait group.
//! let wg = WaitGroup::new();
//!
//! for _ in 0..4 {
//!     // Create another reference to the wait group.
//!     let wg = wg.clone();
//!
//!     thread::spawn(move || {
//!         // Do some work.
//!
//!         // Drop the reference to the wait group.
//!         drop(wg);
//!     });
//! }
//!
//! // Block until all threads have finished their work.
//! wg.wait();
//! ```
27
28use parking_lot::{Condvar, Mutex};
29use std::fmt;
30use std::sync::Arc;
31
32/// Enables threads to synchronize the beginning or end of some computation.
33pub struct WaitGroup {
34    inner: Arc<Inner>,
35}
36
37/// Inner state of a `WaitGroup`.
38struct Inner {
39    cvar: Condvar,
40    count: Mutex<usize>,
41}
42
43impl Default for WaitGroup {
44    #[inline]
45    fn default() -> Self {
46        WaitGroup::new()
47    }
48}
49
50impl WaitGroup {
51    /// Creates a new wait group and returns the single reference to it.
52    #[inline]
53    pub fn new() -> Self {
54        Self {
55            inner: Arc::new(Inner {
56                cvar: Condvar::new(),
57                count: Mutex::new(1),
58            }),
59        }
60    }
61
62    /// Drops this reference and waits until all other references are dropped.
63    #[inline]
64    pub fn wait(self) {
65        if *self.inner.count.lock() == 1 {
66            return;
67        }
68
69        let inner = self.inner.clone();
70        drop(self);
71
72        let mut count = inner.count.lock();
73        while *count > 0 {
74            inner.cvar.wait(&mut count);
75        }
76    }
77}
78
79impl Drop for WaitGroup {
80    #[inline]
81    fn drop(&mut self) {
82        let mut count = self.inner.count.lock();
83        *count -= 1;
84
85        if *count == 0 {
86            self.inner.cvar.notify_all();
87        }
88    }
89}
90
91impl Clone for WaitGroup {
92    #[inline]
93    fn clone(&self) -> WaitGroup {
94        let mut count = self.inner.count.lock();
95        *count += 1;
96
97        WaitGroup {
98            inner: self.inner.clone(),
99        }
100    }
101}
102
103impl fmt::Debug for WaitGroup {
104    #[inline]
105    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106        let count: &usize = &*self.inner.count.lock();
107        f.debug_struct("WaitGroup").field("count", count).finish()
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    /// Number of worker threads spawned by each test.
    const THREADS: usize = 10;

    /// Every worker calls `wait`, so none may finish until the main thread also waits.
    #[test]
    fn wait() {
        let group = WaitGroup::new();
        let (sender, receiver) = mpsc::channel();

        (0..THREADS).for_each(|_| {
            let group = group.clone();
            let sender = sender.clone();

            thread::spawn(move || {
                group.wait();
                sender.send(()).unwrap();
            });
        });

        thread::sleep(Duration::from_millis(100));

        // The main handle is still alive, so every worker should still be parked
        // inside `wait` and the channel should be empty.
        assert!(receiver.try_recv().is_err());

        group.wait();

        // With the group released, each worker reports in exactly once.
        (0..THREADS).for_each(|_| {
            receiver.recv().unwrap();
        });
    }

    /// Workers send and then drop their handle, so `wait` returning implies all sends happened.
    #[test]
    fn wait_and_drop() {
        let group = WaitGroup::new();
        let (sender, receiver) = mpsc::channel();

        (0..THREADS).for_each(|_| {
            let group = group.clone();
            let sender = sender.clone();

            thread::spawn(move || {
                thread::sleep(Duration::from_millis(100));
                sender.send(()).unwrap();
                drop(group);
            });
        });

        // The workers should still be sleeping, so nothing should have been sent yet.
        assert!(receiver.try_recv().is_err());

        group.wait();

        // Each worker sent before dropping its handle, so every message is already queued.
        (0..THREADS).for_each(|_| {
            receiver.try_recv().unwrap();
        });
    }
}