usync/barrier.rs
use crate::shared::{fence_acquire, invalid_mut, StrictProvenance, Waiter};
use std::{
    fmt,
    ptr::{self, NonNull},
    sync::atomic::{AtomicPtr, Ordering},
};

// Set when the state holds a queue of waiting threads instead of a count.
const QUEUED: usize = 1;
// Set while some thread holds the right to fix/link the waiter queue.
const QUEUE_LOCKED: usize = 2;
// A completed barrier: wait() returns without blocking.
const COMPLETED: usize = 0;
// The barrier count lives above the QUEUED bit (it may overlap QUEUE_LOCKED).
const COUNT_SHIFT: u32 = QUEUE_LOCKED.trailing_zeros();

/// A barrier enables multiple threads to synchronize the beginning
/// of some computation.
///
/// # Examples
///
/// ```
/// use usync::Barrier;
/// use std::sync::Arc;
/// use std::thread;
///
/// let mut handles = Vec::with_capacity(10);
/// let barrier = Arc::new(Barrier::new(10));
/// for _ in 0..10 {
///     let c = Arc::clone(&barrier);
///     // The same messages will be printed together.
///     // You will NOT see any interleaving.
///     handles.push(thread::spawn(move|| {
///         println!("before wait");
///         c.wait();
///         println!("after wait");
///     }));
/// }
/// // Wait for other threads to finish.
/// for handle in handles {
///     handle.join().unwrap();
/// }
/// ```
#[derive(Default)]
pub struct Barrier {
    /// This atomic integer holds the current state of the Barrier instance.
    /// The QUEUED bit switches between counting the barrier value and recording the waiters.
    ///
    /// # State table:
    ///
    /// QUEUED | QUEUE_LOCKED | Remaining | Description
    /// -------+--------------+-----------+----------------------------------------------------------------------
    ///    0   |      0       |     0     | The barrier was completed and wait()s will return without blocking.
    /// -------+--------------+-----------+----------------------------------------------------------------------
    ///    0   |       barrier count      | The barrier was initialized with a barrier count and has no waiting threads.
    /// -------+--------------+-----------+----------------------------------------------------------------------
    ///    1   |      0       |  *Waiter  | The barrier has waiting threads where the head of the queue is in Remaining bits.
    ///        |              |           | The barrier count was moved to the tail of the waiting-threads queue.
    /// -------+--------------+-----------+----------------------------------------------------------------------
    ///    1   |      1       |  *Waiter  | The barrier has waiting threads where the head of the queue is in Remaining bits.
    ///        |              |           | There is also a thread updating the waiting-threads queue.
    ///        |              |           | Said thread is counting how many threads are queued and may possibly
    ///        |              |           | complete the barrier if the amount waiting matches or goes above the barrier count.
    /// -------+--------------+-----------+----------------------------------------------------------------------
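    ///
    /// As a rough sketch of how the two non-completed encodings are built (here
    /// `n` and `waiter_ptr` are stand-ins for the barrier count and a queued waiter):
    ///
    /// ```ignore
    /// // Counting state, as produced by Barrier::new(n).
    /// let counting: *mut Waiter = invalid_mut(n << COUNT_SHIFT);
    /// // Queued state: the queue head pointer tagged with the QUEUED bit.
    /// let queued: *mut Waiter = waiter_ptr.map_address(|addr| addr | QUEUED);
    /// ```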
    state: AtomicPtr<Waiter>,
}

unsafe impl Send for Barrier {}
unsafe impl Sync for Barrier {}

impl fmt::Debug for Barrier {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Barrier").finish_non_exhaustive()
    }
}

impl Barrier {
    /// Creates a new barrier that can block a given number of threads.
    ///
    /// A barrier will block `n`-1 threads which call [`wait()`] and then wake
    /// up all threads at once when the `n`th thread calls [`wait()`].
    ///
    /// [`wait()`]: Barrier::wait
    ///
    /// # Examples
    ///
    /// ```
    /// use usync::Barrier;
    ///
    /// let barrier = Barrier::new(10);
    /// ```
    #[must_use]
    pub const fn new(n: usize) -> Self {
        let state = invalid_mut(n << COUNT_SHIFT);
        Self {
            state: AtomicPtr::new(state),
        }
    }

    /// Blocks the current thread until all threads have rendezvoused here.
    ///
    /// Barriers are re-usable after all threads have rendezvoused once, and can
    /// be used continuously.
    ///
    /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that
    /// returns `true` from [`BarrierWaitResult::is_leader()`] when returning
    /// from this function, and all other threads will receive a result that
    /// will return `false` from [`BarrierWaitResult::is_leader()`].
    ///
    /// # Examples
    ///
    /// ```
    /// use usync::Barrier;
    /// use std::sync::Arc;
    /// use std::thread;
    ///
    /// let mut handles = Vec::with_capacity(10);
    /// let barrier = Arc::new(Barrier::new(10));
    /// for _ in 0..10 {
    ///     let c = Arc::clone(&barrier);
    ///     // The same messages will be printed together.
    ///     // You will NOT see any interleaving.
    ///     handles.push(thread::spawn(move|| {
    ///         println!("before wait");
    ///         c.wait();
    ///         println!("after wait");
    ///     }));
    /// }
    /// // Wait for other threads to finish.
    /// for handle in handles {
    ///     handle.join().unwrap();
    /// }
    /// ```
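    ///
    /// Since the barrier is re-usable, a minimal sketch of rendezvousing on the
    /// same `Barrier` several times in a row:
    ///
    /// ```
    /// use usync::Barrier;
    /// use std::sync::Arc;
    /// use std::thread;
    ///
    /// let barrier = Arc::new(Barrier::new(2));
    /// let b = Arc::clone(&barrier);
    /// let handle = thread::spawn(move || {
    ///     for _ in 0..3 {
    ///         b.wait();
    ///     }
    /// });
    /// for _ in 0..3 {
    ///     barrier.wait();
    /// }
    /// handle.join().unwrap();
    /// ```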
    #[inline]
    pub fn wait(&self) -> BarrierWaitResult {
        let mut is_leader = false;

        // Quick check if the Barrier was already completed.
        // Acquire barrier to ensure the Barrier completion happens before we return.
        let state = self.state.load(Ordering::Acquire);
        if state.address() != COMPLETED {
            is_leader = self.wait_slow(state);
        }

        BarrierWaitResult(is_leader)
    }

    #[cold]
    fn wait_slow(&self, mut state: *mut Waiter) -> bool {
        Waiter::with(|waiter| {
            waiter.waiting_on.set(None);
            waiter.prev.set(None);

            loop {
                // If the Barrier was completed, return that we are not the leader.
                // Acquire barrier to ensure the queue completion happens before we return.
                if state.address() == COMPLETED {
                    fence_acquire(&self.state);
                    return false;
                }

                // Special case to complete the barrier if the count is just n=1.
                // This avoids going through the queue + QUEUE_LOCKED case below.
                // On success, returns true for being the leader as we completed the Barrier.
                // Release barrier ensures the Barrier completion happens before waiting threads return.
                if state.address() == (1 << COUNT_SHIFT) {
                    match self.state.compare_exchange_weak(
                        state,
                        state.with_address(COMPLETED),
                        Ordering::Release,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => return true,
                        Err(e) => state = e,
                    }
                    continue;
                }

                // Prepare the waiter to be queued onto the state.
                // NOTE: Don't keep the non-Waiter::MASK bits!
                // The first queued waiter will have the counter in those bits.
                let waiter_ptr = NonNull::from(&*waiter).as_ptr();
                let mut new_state = waiter_ptr.map_address(|addr| addr | QUEUED);

                if state.address() & QUEUED == 0 {
                    // If we're the first waiter, we move the counter to our node.
                    // We also subtract one from the counter to *account* (pun) for our waiting thread.
                    let counter = (state.address() >> COUNT_SHIFT)
                        .checked_sub(1)
                        .expect("Barrier counter with zero value when waiting");

                    // The first waiter also sets the tail to itself
                    // so that Waiter::get_and_link_queue() can find the queue tail.
                    waiter.counter.store(counter, Ordering::Relaxed);
                    waiter.next.set(None);
                    waiter.tail.set(Some(NonNull::from(&*waiter)));
                } else {
                    // Other waiters push to the queue in a stack-like manner.
                    // They also try to grab the QUEUE_LOCKED bit in order to fix/link the queue
                    // and possibly complete the Barrier in the process (depending on how many waiters there are).
                    let head = NonNull::new(state.map_address(|addr| addr & Waiter::MASK));
                    new_state = new_state.map_address(|addr| addr | QUEUE_LOCKED);
                    waiter.next.set(head);
                    waiter.tail.set(None);
                }

                // Release barrier synchronizes with the Acquire barrier of the QUEUE_LOCKED bit holder
                // doing Waiter::get_and_link_queue() to ensure that it sees the waiter writes we did
                // above when observing the state.
                if let Err(e) = self.state.compare_exchange_weak(
                    state,
                    new_state,
                    Ordering::Release,
                    Ordering::Relaxed,
                ) {
                    state = e;
                    continue;
                }

                // If we acquired the QUEUE_LOCKED bit, try to link the queue or complete the Barrier.
                // NOTE: The bits must be checked separately!
                // When the counter is still there, it could pose as a QUEUE_LOCKED bit.
                if (state.address() & QUEUED != 0) && (state.address() & QUEUE_LOCKED == 0) {
                    // If we manage to complete the Barrier, return is_leader=true here.
                    // SAFETY: we hold the QUEUE_LOCKED bit now.
                    if unsafe { self.link_queue_or_complete(new_state) } {
                        return true;
                    }
                }

                // Wait until we're woken up with the barrier completed.
                assert!(waiter.parker.park(None));

                // Ensure that once we're woken up, the barrier was completed.
                // Acquire barrier to ensure the queue completion happens before we return.
                state = self.state.load(Ordering::Acquire);
                assert_eq!(state.address(), COMPLETED);
                return false;
            }
        })
    }

    #[cold]
    unsafe fn link_queue_or_complete(&self, mut state: *mut Waiter) -> bool {
        loop {
            assert_ne!(state.address() & QUEUED, 0);
            assert_ne!(state.address() & QUEUE_LOCKED, 0);

            // Fix the prev links in the waiter queue now that we hold the QUEUE_LOCKED bit.
            // Also track how many waiters we discovered while trying to fix the waiter links.
            // Acquire barrier to ensure writes to waiters pushed to the queue happen before we start fixing it.
            fence_acquire(&self.state);
            let mut discovered = 0;
            let (_, tail) = Waiter::get_and_link_queue(state, |_| discovered += 1);

            // The barrier count is stored at the tail.
            // Subtract the amount of newly discovered nodes from the count.
            // Use saturating_sub() as technically more threads than the count could try to wait().
            let mut counter = tail.as_ref().counter.load(Ordering::Relaxed);
            counter = counter.saturating_sub(discovered);

            // When the count hits zero, complete the barrier.
            tail.as_ref().counter.store(counter, Ordering::Relaxed);
            if counter == 0 {
                return self.complete();
            }

            // The barrier count isn't zero yet.
            // Unset the QUEUE_LOCKED bit since we've updated the queue links for the next wait()'er to grab it.
            // Release barrier ensures the waiter writes to head/tail we did above happen before the next QUEUE_LOCKED bit holder.
            match self.state.compare_exchange_weak(
                state,
                state.map_address(|addr| addr & !QUEUE_LOCKED),
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                Ok(_) => return false,
                Err(e) => state = e,
            }
        }
    }

    #[cold]
    unsafe fn complete(&self) -> bool {
        // Complete the barrier while also dequeueing all the waiters.
        // AcqRel as Acquire barrier to ensure the writes to the pushed waiters happen before we iterate & wake them below.
        // AcqRel as Release barrier to ensure that the barrier completion happens before the wait() calls return.
        let completed = ptr::null_mut::<Waiter>().with_address(COMPLETED);
        let state = self.state.swap(completed, Ordering::AcqRel);

        assert_ne!(state.address() & QUEUED, 0);
        assert_ne!(state.address() & QUEUE_LOCKED, 0);

        let mut waiters = NonNull::new(state.map_address(|addr| addr & Waiter::MASK));
        while let Some(waiter) = waiters {
            waiters = waiter.as_ref().next.get();
            waiter.as_ref().parker.unpark();
        }

        // Since we completed the barrier, we are the leader.
        true
    }
}

/// A `BarrierWaitResult` is returned by [`Barrier::wait()`] when all threads
/// in the [`Barrier`] have rendezvoused.
///
/// # Examples
///
/// ```
/// use usync::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait();
/// ```
pub struct BarrierWaitResult(bool);

impl fmt::Debug for BarrierWaitResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BarrierWaitResult")
            .field("is_leader", &self.is_leader())
            .finish()
    }
}

impl BarrierWaitResult {
    /// Returns `true` if this thread is the "leader thread" for the call to
    /// [`Barrier::wait()`].
    ///
    /// Only one thread will have `true` returned from their result, all other
    /// threads will have `false` returned.
    ///
    /// # Examples
    ///
    /// ```
    /// use usync::Barrier;
    ///
    /// let barrier = Barrier::new(1);
    /// let barrier_wait_result = barrier.wait();
    /// println!("{:?}", barrier_wait_result.is_leader());
    /// ```
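    ///
    /// A sketch showing that exactly one of several waiting threads observes
    /// `is_leader() == true`:
    ///
    /// ```
    /// use usync::Barrier;
    /// use std::sync::Arc;
    /// use std::sync::atomic::{AtomicUsize, Ordering};
    /// use std::thread;
    ///
    /// let barrier = Arc::new(Barrier::new(4));
    /// let leaders = Arc::new(AtomicUsize::new(0));
    /// let handles: Vec<_> = (0..4)
    ///     .map(|_| {
    ///         let barrier = Arc::clone(&barrier);
    ///         let leaders = Arc::clone(&leaders);
    ///         thread::spawn(move || {
    ///             if barrier.wait().is_leader() {
    ///                 leaders.fetch_add(1, Ordering::Relaxed);
    ///             }
    ///         })
    ///     })
    ///     .collect();
    /// for handle in handles {
    ///     handle.join().unwrap();
    /// }
    /// assert_eq!(leaders.load(Ordering::Relaxed), 1);
    /// ```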
    #[must_use]
    pub fn is_leader(&self) -> bool {
        self.0
    }
}