// async_std/sync/barrier.rs

1use broadcaster::BroadcastChannel;
2
3use crate::sync::Mutex;
4
/// A rendezvous point that lets a fixed number of tasks start some
/// computation together.
///
/// Every participating task calls [`wait`]; each call suspends until the
/// configured number of tasks have arrived, at which point all of them are
/// released at once. The barrier can then be reused for another round.
///
/// [`wait`]: #method.wait
///
/// # Examples
///
/// ```
/// # async_std::task::block_on(async {
/// #
/// use async_std::sync::{Arc, Barrier};
/// use async_std::task;
///
/// let mut handles = Vec::with_capacity(10);
/// let barrier = Arc::new(Barrier::new(10));
/// for _ in 0..10 {
///     let c = barrier.clone();
///     // All "before wait" lines are printed before any "after wait" line;
///     // the two groups never interleave.
///     handles.push(task::spawn(async move {
///         println!("before wait");
///         c.wait().await;
///         println!("after wait");
///     }));
/// }
/// // Wait for every spawned task to finish.
/// for handle in handles {
///     handle.await;
/// }
/// # });
/// ```
#[cfg(feature = "unstable")]
#[cfg_attr(feature = "docs", doc(cfg(unstable)))]
#[derive(Debug)]
pub struct Barrier {
    // Arrival counter, round number and the broadcast handle the leader
    // uses to release waiters; guarded by an async mutex.
    state: Mutex<BarrierState>,
    // Subscriber handle; each waiting task clones it (while holding the
    // state lock) to receive the leader's release message.
    wait: BroadcastChannel<(usize, usize)>,
    // Number of arrivals required to release a round (`new` never stores 0).
    n: usize,
}
42
43// The inner state of a double barrier
44#[derive(Debug)]
45struct BarrierState {
46    waker: BroadcastChannel<(usize, usize)>,
47    count: usize,
48    generation_id: usize,
49}
50
/// A `BarrierWaitResult` is returned by [`wait`] once all tasks using the
/// [`Barrier`] have rendezvoused.
///
/// Among the tasks released together, exactly one receives a result whose
/// [`is_leader`] returns `true`.
///
/// [`wait`]: struct.Barrier.html#method.wait
/// [`Barrier`]: struct.Barrier.html
/// [`is_leader`]: #method.is_leader
///
/// # Examples
///
/// ```
/// # async_std::task::block_on(async {
/// #
/// use async_std::sync::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait().await;
/// // With a single participant, that participant is always the leader.
/// assert!(barrier_wait_result.is_leader());
/// # });
/// ```
#[cfg(feature = "unstable")]
#[cfg_attr(feature = "docs", doc(cfg(unstable)))]
#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);
68
69impl Barrier {
70    /// Creates a new barrier that can block a given number of tasks.
71    ///
72    /// A barrier will block `n`-1 tasks which call [`wait`] and then wake up
73    /// all tasks at once when the `n`th task calls [`wait`].
74    ///
75    /// [`wait`]: #method.wait
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use std::sync::Barrier;
81    ///
82    /// let barrier = Barrier::new(10);
83    /// ```
84    pub fn new(mut n: usize) -> Barrier {
85        let waker = BroadcastChannel::new();
86        let wait = waker.clone();
87
88        if n == 0 {
89            // if n is 0, it's not clear what behavior the user wants.
90            // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
91            // .wait() immediately unblocks, so we adopt that here as well.
92            n = 1;
93        }
94
95        Barrier {
96            state: Mutex::new(BarrierState {
97                waker,
98                count: 0,
99                generation_id: 1,
100            }),
101            n,
102            wait,
103        }
104    }
105
106    /// Blocks the current task until all tasks have rendezvoused here.
107    ///
108    /// Barriers are re-usable after all tasks have rendezvoused once, and can
109    /// be used continuously.
110    ///
111    /// A single (arbitrary) task will receive a [`BarrierWaitResult`] that
112    /// returns `true` from [`is_leader`] when returning from this function, and
113    /// all other tasks will receive a result that will return `false` from
114    /// [`is_leader`].
115    ///
116    /// [`BarrierWaitResult`]: struct.BarrierWaitResult.html
117    /// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader
118    ///
119    /// # Examples
120    ///
121    /// ```
122    /// # async_std::task::block_on(async {
123    /// #
124    /// use async_std::sync::{Arc, Barrier};
125    /// use async_std::task;
126    ///
127    /// let mut handles = Vec::with_capacity(10);
128    /// let barrier = Arc::new(Barrier::new(10));
129    /// for _ in 0..10 {
130    ///     let c = barrier.clone();
131    ///     // The same messages will be printed together.
132    ///     // You will NOT see any interleaving.
133    ///     handles.push(task::spawn(async move {
134    ///         println!("before wait");
135    ///         c.wait().await;
136    ///         println!("after wait");
137    ///     }));
138    /// }
139    /// // Wait for the other futures to finish.
140    /// for handle in handles {
141    ///     handle.await;
142    /// }
143    /// # });
144    /// ```
145    pub async fn wait(&self) -> BarrierWaitResult {
146        let mut lock = self.state.lock().await;
147        let local_gen = lock.generation_id;
148
149        lock.count += 1;
150
151        if lock.count < self.n {
152            let mut wait = self.wait.clone();
153
154            let mut generation_id = lock.generation_id;
155            let mut count = lock.count;
156
157            drop(lock);
158
159            while local_gen == generation_id && count < self.n {
160                let (g, c) = wait.recv().await.expect("sender has not been closed");
161                generation_id = g;
162                count = c;
163            }
164
165            BarrierWaitResult(false)
166        } else {
167            lock.count = 0;
168            lock.generation_id = lock.generation_id.wrapping_add(1);
169
170            lock.waker
171                .send(&(lock.generation_id, lock.count))
172                .await
173                .expect("there should be at least one receiver");
174
175            BarrierWaitResult(true)
176        }
177    }
178}
179
180impl BarrierWaitResult {
181    /// Returns `true` if this task from [`wait`] is the "leader task".
182    ///
183    /// Only one task will have `true` returned from their result, all other
184    /// tasks will have `false` returned.
185    ///
186    /// [`wait`]: struct.Barrier.html#method.wait
187    ///
188    /// # Examples
189    ///
190    /// ```
191    /// # async_std::task::block_on(async {
192    /// #
193    /// use async_std::sync::Barrier;
194    ///
195    /// let barrier = Barrier::new(1);
196    /// let barrier_wait_result = barrier.wait().await;
197    /// println!("{:?}", barrier_wait_result.is_leader());
198    /// # });
199    /// ```
200    pub fn is_leader(&self) -> bool {
201        self.0
202    }
203}
204
#[cfg(test)]
mod test {
    use futures::channel::mpsc::unbounded;
    use futures::sink::SinkExt;
    use futures::stream::StreamExt;

    use crate::sync::{Arc, Barrier};
    use crate::task;

    // Runs on async-std's own executor (`task::block_on`) rather than a tokio
    // test runtime, consistent with every other example in this file.
    #[test]
    fn test_barrier() {
        // NOTE(dignifiedquire): Based on the test in std, I was seeing some
        // race conditions, so running it in a loop to make sure things are
        // solid.
        task::block_on(async move {
            for _ in 0..1_000 {
                const N: usize = 10;

                let barrier = Arc::new(Barrier::new(N));
                let (tx, mut rx) = unbounded();

                for _ in 0..N - 1 {
                    let c = barrier.clone();
                    let mut tx = tx.clone();
                    task::spawn(async move {
                        let res = c.wait().await;

                        tx.send(res.is_leader()).await.unwrap();
                    });
                }

                // At this point, all spawned tasks should be blocked,
                // so we shouldn't get anything from the channel.
                assert!(rx.try_next().is_err());

                let mut leader_found = barrier.wait().await.is_leader();

                // Now, the barrier is cleared and we should get data.
                for _ in 0..N - 1 {
                    if rx.next().await.unwrap() {
                        assert!(!leader_found);
                        leader_found = true;
                    }
                }
                assert!(leader_found);
            }
        });
    }
}
256}