async_std/sync/barrier.rs
1use broadcaster::BroadcastChannel;
2
3use crate::sync::Mutex;
4
/// A rendezvous point that lets a fixed number of tasks line up before any of
/// them starts the next phase of a computation.
///
/// # Examples
///
/// ```
/// # async_std::task::block_on(async {
/// #
/// use async_std::sync::{Arc, Barrier};
/// use async_std::task;
///
/// let mut handles = Vec::with_capacity(10);
/// let barrier = Arc::new(Barrier::new(10));
/// for _ in 0..10 {
///     let c = barrier.clone();
///     // The same messages will be printed together.
///     // You will NOT see any interleaving.
///     handles.push(task::spawn(async move {
///         println!("before wait");
///         c.wait().await;
///         println!("after wait");
///     }));
/// }
/// // Wait for the other futures to finish.
/// for handle in handles {
///     handle.await;
/// }
/// # });
/// ```
#[cfg(feature = "unstable")]
#[cfg_attr(feature = "docs", doc(cfg(unstable)))]
#[derive(Debug)]
pub struct Barrier {
    // Arrival counter, generation number and sending channel, behind an async mutex.
    state: Mutex<BarrierState>,
    // Clone of the channel stored in `state`, made in `new`; each waiting task
    // clones it again and receives the release message on its own copy.
    wait: BroadcastChannel<(usize, usize)>,
    // Number of `wait` calls that complete one rendezvous (`new` stores at least 1).
    n: usize,
}
42
43// The inner state of a double barrier
44#[derive(Debug)]
45struct BarrierState {
46 waker: BroadcastChannel<(usize, usize)>,
47 count: usize,
48 generation_id: usize,
49}
50
/// A `BarrierWaitResult` is returned by [`wait`] when all tasks in the [`Barrier`] have
/// rendezvoused.
///
/// [`wait`]: struct.Barrier.html#method.wait
/// [`Barrier`]: struct.Barrier.html
///
/// # Examples
///
/// ```
/// # async_std::task::block_on(async {
/// #
/// use async_std::sync::Barrier;
///
/// let barrier = Barrier::new(1);
/// // `wait` is an `async fn`: the result only exists once the future is awaited.
/// let barrier_wait_result = barrier.wait().await;
/// # });
/// ```
#[cfg(feature = "unstable")]
#[cfg_attr(feature = "docs", doc(cfg(unstable)))]
#[derive(Debug, Clone)]
// The wrapped flag is `true` only for the task whose `wait` call released the barrier.
pub struct BarrierWaitResult(bool);
68
69impl Barrier {
70 /// Creates a new barrier that can block a given number of tasks.
71 ///
72 /// A barrier will block `n`-1 tasks which call [`wait`] and then wake up
73 /// all tasks at once when the `n`th task calls [`wait`].
74 ///
75 /// [`wait`]: #method.wait
76 ///
77 /// # Examples
78 ///
79 /// ```
80 /// use std::sync::Barrier;
81 ///
82 /// let barrier = Barrier::new(10);
83 /// ```
84 pub fn new(mut n: usize) -> Barrier {
85 let waker = BroadcastChannel::new();
86 let wait = waker.clone();
87
88 if n == 0 {
89 // if n is 0, it's not clear what behavior the user wants.
90 // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
91 // .wait() immediately unblocks, so we adopt that here as well.
92 n = 1;
93 }
94
95 Barrier {
96 state: Mutex::new(BarrierState {
97 waker,
98 count: 0,
99 generation_id: 1,
100 }),
101 n,
102 wait,
103 }
104 }
105
106 /// Blocks the current task until all tasks have rendezvoused here.
107 ///
108 /// Barriers are re-usable after all tasks have rendezvoused once, and can
109 /// be used continuously.
110 ///
111 /// A single (arbitrary) task will receive a [`BarrierWaitResult`] that
112 /// returns `true` from [`is_leader`] when returning from this function, and
113 /// all other tasks will receive a result that will return `false` from
114 /// [`is_leader`].
115 ///
116 /// [`BarrierWaitResult`]: struct.BarrierWaitResult.html
117 /// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader
118 ///
119 /// # Examples
120 ///
121 /// ```
122 /// # async_std::task::block_on(async {
123 /// #
124 /// use async_std::sync::{Arc, Barrier};
125 /// use async_std::task;
126 ///
127 /// let mut handles = Vec::with_capacity(10);
128 /// let barrier = Arc::new(Barrier::new(10));
129 /// for _ in 0..10 {
130 /// let c = barrier.clone();
131 /// // The same messages will be printed together.
132 /// // You will NOT see any interleaving.
133 /// handles.push(task::spawn(async move {
134 /// println!("before wait");
135 /// c.wait().await;
136 /// println!("after wait");
137 /// }));
138 /// }
139 /// // Wait for the other futures to finish.
140 /// for handle in handles {
141 /// handle.await;
142 /// }
143 /// # });
144 /// ```
145 pub async fn wait(&self) -> BarrierWaitResult {
146 let mut lock = self.state.lock().await;
147 let local_gen = lock.generation_id;
148
149 lock.count += 1;
150
151 if lock.count < self.n {
152 let mut wait = self.wait.clone();
153
154 let mut generation_id = lock.generation_id;
155 let mut count = lock.count;
156
157 drop(lock);
158
159 while local_gen == generation_id && count < self.n {
160 let (g, c) = wait.recv().await.expect("sender has not been closed");
161 generation_id = g;
162 count = c;
163 }
164
165 BarrierWaitResult(false)
166 } else {
167 lock.count = 0;
168 lock.generation_id = lock.generation_id.wrapping_add(1);
169
170 lock.waker
171 .send(&(lock.generation_id, lock.count))
172 .await
173 .expect("there should be at least one receiver");
174
175 BarrierWaitResult(true)
176 }
177 }
178}
179
180impl BarrierWaitResult {
181 /// Returns `true` if this task from [`wait`] is the "leader task".
182 ///
183 /// Only one task will have `true` returned from their result, all other
184 /// tasks will have `false` returned.
185 ///
186 /// [`wait`]: struct.Barrier.html#method.wait
187 ///
188 /// # Examples
189 ///
190 /// ```
191 /// # async_std::task::block_on(async {
192 /// #
193 /// use async_std::sync::Barrier;
194 ///
195 /// let barrier = Barrier::new(1);
196 /// let barrier_wait_result = barrier.wait().await;
197 /// println!("{:?}", barrier_wait_result.is_leader());
198 /// # });
199 /// ```
200 pub fn is_leader(&self) -> bool {
201 self.0
202 }
203}
204
#[cfg(test)]
mod test {
    use futures::channel::mpsc::unbounded;
    use futures::sink::SinkExt;
    use futures::stream::StreamExt;

    use crate::sync::{Arc, Barrier};
    use crate::task;

    // Checks that no task gets past the barrier before the last one arrives and
    // that exactly one task per rendezvous is reported as leader.
    #[test]
    fn test_barrier() {
        // NOTE(dignifiedquire): Based on the test in std, I was seeing some
        // race conditions, so running it in a loop to make sure things are
        // solid.

        for _ in 0..1_000 {
            // Drive the test on this crate's own executor, the same one
            // `task::spawn` below schedules onto.
            task::block_on(async move {
                const N: usize = 10;

                let barrier = Arc::new(Barrier::new(N));
                let (tx, mut rx) = unbounded();

                for _ in 0..N - 1 {
                    let c = barrier.clone();
                    let mut tx = tx.clone();
                    task::spawn(async move {
                        let res = c.wait().await;

                        tx.send(res.is_leader()).await.unwrap();
                    });
                }

                // At this point, all spawned tasks should be blocked,
                // so we shouldn't get anything from the channel.
                let res = rx.try_next();
                assert!(res.is_err());

                let mut leader_found = barrier.wait().await.is_leader();

                // Now, the barrier is cleared and we should get data.
                for _ in 0..N - 1 {
                    if rx.next().await.unwrap() {
                        assert!(!leader_found);
                        leader_found = true;
                    }
                }
                assert!(leader_found);
            });
        }
    }
}