usync/
barrier.rs

1use crate::shared::{fence_acquire, invalid_mut, StrictProvenance, Waiter};
2use std::{
3    fmt,
4    ptr::{self, NonNull},
5    sync::atomic::{AtomicPtr, Ordering},
6};
7
/// Bit 0 of the state word: when set, the upper bits hold a pointer to the head of the
/// waiter queue instead of the barrier count.
const QUEUED: usize = 1;
/// Bit 1 of the state word: only meaningful while `QUEUED` is set, where it marks that a
/// thread is currently linking/counting the waiter queue.
const QUEUE_LOCKED: usize = 2;
/// State value of a barrier that has been completed.
const COMPLETED: usize = 0;
/// Number of bits the barrier count is shifted left by while it lives in the state word.
///
/// The count must start *above* the `QUEUED` bit (i.e. at the `QUEUE_LOCKED` bit), otherwise
/// any odd count would set `QUEUED` and be mistaken for a queue of waiters. The previous
/// definition, `QUEUED.trailing_zeros()`, evaluated to `0` and had exactly that problem.
const COUNT_SHIFT: u32 = QUEUE_LOCKED.trailing_zeros();
12
/// A barrier enables multiple threads to synchronize the beginning
/// of some computation.
///
/// Threads calling [`Barrier::wait()`] block until the number of threads the barrier was
/// created with have all arrived. Once that happens the internal state is switched to its
/// completed value and is never re-armed by the code in this module, so later calls to
/// [`Barrier::wait()`] return immediately.
///
/// # Examples
///
/// ```
/// use usync::Barrier;
/// use std::sync::Arc;
/// use std::thread;
///
/// let mut handles = Vec::with_capacity(10);
/// let barrier = Arc::new(Barrier::new(10));
/// for _ in 0..10 {
///     let c = Arc::clone(&barrier);
///     // The same messages will be printed together.
///     // You will NOT see any interleaving.
///     handles.push(thread::spawn(move|| {
///         println!("before wait");
///         c.wait();
///         println!("after wait");
///     }));
/// }
/// // Wait for other threads to finish.
/// for handle in handles {
///     handle.join().unwrap();
/// }
/// ```
// NOTE(review): the derived `Default` builds the barrier from `AtomicPtr::default()`, a null
// pointer. Presumably its address equals `COMPLETED`, making a default barrier one whose
// `wait()` never blocks -- confirm this is intended.
#[derive(Default)]
pub struct Barrier {
    /// This atomic pointer-sized word holds the current state of the Barrier instance.
    /// The QUEUED bit switches between counting the barrier value and recording the waiters.
    /// While no thread is queued, the word stores the count as `n << COUNT_SHIFT`.
    ///
    /// # State table:
    ///
    /// ```text
    /// QUEUED | QUEUE_LOCKED | Remaining | Description
    ///    0   |      0       |     0     | The barrier was completed and wait()s will return without blocking.
    /// -------+--------------+-----------+----------------------------------------------------------------------
    ///    0   |      barrier count       | The barrier was initialized with a barrier count and has no waiting threads.
    /// -------+--------------+-----------+----------------------------------------------------------------------
    ///    1   |      0       |  *Waiter  | The barrier has waiting threads where the head of the queue is in Remaining bits.
    ///        |              |           | The barrier count was moved to the tail of the waiting-threads queue.
    /// -------+--------------+-----------+----------------------------------------------------------------------
    ///    1   |      1       |  *Waiter  | The barrier has waiting threads where the head of the queue is in Remaining bits.
    ///        |              |           | There is also a thread updating the waiting-threads queue.
    ///        |              |           | Said thread is counting how many threads are queued and may possibly
    ///        |              |           | complete the barrier if the amount waiting matches or goes above the barrier count.
    /// -------+--------------+-----------+----------------------------------------------------------------------
    /// ```
    state: AtomicPtr<Waiter>,
}
62
// SAFETY: the only field is an `AtomicPtr`, which the standard library already marks as
// `Send` and `Sync`. The `Waiter` nodes reachable through it are only touched by the
// methods of `Barrier`, which order their accesses through atomic operations on `state`.
unsafe impl Send for Barrier {}
// SAFETY: see the `Send` impl above; shared access goes through the same atomic state word.
unsafe impl Sync for Barrier {}
65
66impl fmt::Debug for Barrier {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        f.debug_struct("Barrier").finish_non_exhaustive()
69    }
70}
71
72impl Barrier {
73    /// Creates a new barrier that can block a given number of threads.
74    ///
75    /// A barrier will block `n`-1 threads which call [`wait()`] and then wake
76    /// up all threads at once when the `n`th thread calls [`wait()`].
77    ///
78    /// [`wait()`]: Barrier::wait
79    ///
80    /// # Examples
81    ///
82    /// ```
83    /// use usync::Barrier;
84    ///
85    /// let barrier = Barrier::new(10);
86    /// ```
87    #[must_use]
88    pub const fn new(n: usize) -> Self {
89        let state = invalid_mut(n << COUNT_SHIFT);
90        Self {
91            state: AtomicPtr::new(state),
92        }
93    }
94
95    /// Blocks the current thread until all threads have rendezvoused here.
96    ///
97    /// Barriers are re-usable after all threads have rendezvoused once, and can
98    /// be used continuously.
99    ///
100    /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that
101    /// returns `true` from [`BarrierWaitResult::is_leader()`] when returning
102    /// from this function, and all other threads will receive a result that
103    /// will return `false` from [`BarrierWaitResult::is_leader()`].
104    ///
105    /// # Examples
106    ///
107    /// ```
108    /// use usync::Barrier;
109    /// use std::sync::Arc;
110    /// use std::thread;
111    ///
112    /// let mut handles = Vec::with_capacity(10);
113    /// let barrier = Arc::new(Barrier::new(10));
114    /// for _ in 0..10 {
115    ///     let c = Arc::clone(&barrier);
116    ///     // The same messages will be printed together.
117    ///     // You will NOT see any interleaving.
118    ///     handles.push(thread::spawn(move|| {
119    ///         println!("before wait");
120    ///         c.wait();
121    ///         println!("after wait");
122    ///     }));
123    /// }
124    /// // Wait for other threads to finish.
125    /// for handle in handles {
126    ///     handle.join().unwrap();
127    /// }
128    /// ```
129    #[inline]
130    pub fn wait(&self) -> BarrierWaitResult {
131        let mut is_leader = false;
132
133        // Quick check if the Barrier was already completed.
134        // Acquire barrier to ensure Barrier completions happens before we return.
135        let state = self.state.load(Ordering::Acquire);
136        if state.address() != COMPLETED {
137            is_leader = self.wait_slow(state);
138        }
139
140        BarrierWaitResult(is_leader)
141    }
142
143    #[cold]
144    fn wait_slow(&self, mut state: *mut Waiter) -> bool {
145        Waiter::with(|waiter| {
146            waiter.waiting_on.set(None);
147            waiter.prev.set(None);
148
149            loop {
150                // If the queue became completed, return that we are not the leader.
151                // Acqire barrier to ensure the queue completion happens before we return.
152                if state.address() == COMPLETED {
153                    fence_acquire(&self.state);
154                    return false;
155                }
156
157                // Special case to complete the queue if there's only an n=1.
158                // This avoids going throught the queue + QUEUE_LOCKED case below.
159                // On success, returns true for being the leader as we completed the Barrier.
160                // Release barrier ensures the Barrier completions happens before waiting threads return.
161                if state.address() == (1 << COUNT_SHIFT) {
162                    match self.state.compare_exchange_weak(
163                        state,
164                        state.with_address(COMPLETED),
165                        Ordering::Release,
166                        Ordering::Relaxed,
167                    ) {
168                        Ok(_) => return true,
169                        Err(e) => state = e,
170                    }
171                    continue;
172                }
173
174                // Prepare the waiter to be queued onto the state.
175                // NOTE: Don't keep the non Waiter::MASK bits!
176                //       The first queued waiter will have the counter in those bits.
177                let waiter_ptr = NonNull::from(&*waiter).as_ptr();
178                let mut new_state = waiter_ptr.map_address(|addr| addr | QUEUED);
179
180                if state.address() & QUEUED == 0 {
181                    // If we're the first waiter, we move the counter to our node.
182                    // We also subtract one from the counter to *account* (pun) for our waiting thread.
183                    let counter = (state.address() >> COUNT_SHIFT)
184                        .checked_sub(1)
185                        .expect("Barrier counter with zero value when waiting");
186
187                    // The first waiter also sets the tail to itself
188                    // so that Waiter::get_and_link_queue() can find the queue tail.
189                    waiter.counter.store(counter, Ordering::Relaxed);
190                    waiter.next.set(None);
191                    waiter.tail.set(Some(NonNull::from(&*waiter)));
192                } else {
193                    // Other waiters push to the queue in a stack-like manner.
194                    // They also try to grab the QUEUE_LOCKED bit in order to fix/link the queue
195                    // and possibly complete the Barrier in the process (depending on how many waiters there are).
196                    let head = NonNull::new(state.map_address(|addr| addr & Waiter::MASK));
197                    new_state = new_state.map_address(|addr| addr | QUEUE_LOCKED);
198                    waiter.next.set(head);
199                    waiter.tail.set(None);
200                }
201
202                // Release barrier synchronizes with Acquire barrier by the QUEUE_LOCKED bit holder
203                // doing Waiter::get_and_link_queue() to ensure that it sees the waiter writes we did
204                // above when observing the state.
205                if let Err(e) = self.state.compare_exchange_weak(
206                    state,
207                    new_state,
208                    Ordering::Release,
209                    Ordering::Relaxed,
210                ) {
211                    state = e;
212                    continue;
213                }
214
215                // If we acquired the QUEUE_LOCKED bit, try to link the queue or complete the Barrier.
216                // NOTE: The bits must be checked separately!
217                //       When the counter is still there, it could pose as a QUEUE_LOCKED bit.
218                if (state.address() & QUEUED != 0) && (state.address() & QUEUE_LOCKED == 0) {
219                    // If we manage to complete the Barrier, return is_leader=true here.
220                    // SAFETY: we hold the QUEUE_LOCKED bit now.
221                    if unsafe { self.link_queue_or_complete(new_state) } {
222                        return true;
223                    }
224                }
225
226                // Wait until we're woken up with the barrier completed.
227                assert!(waiter.parker.park(None));
228
229                // Ensure that once we're woken up, the barrier was completed.
230                // Acqire barrier to ensure the queue completion happens before we return.
231                state = self.state.load(Ordering::Acquire);
232                assert_eq!(state.address(), COMPLETED);
233                return false;
234            }
235        })
236    }
237
238    #[cold]
239    unsafe fn link_queue_or_complete(&self, mut state: *mut Waiter) -> bool {
240        loop {
241            assert_ne!(state.address() & QUEUED, 0);
242            assert_ne!(state.address() & QUEUE_LOCKED, 0);
243
244            // Fix the prev links in the waiter queue now that we hold the QUEUE_LOCKED bit.
245            // Also track how many waiters we discovered while trying to fix the waiter links.
246            // Acquire barrier to ensure writes to waiters pushed to the queue happen before we start fixing it.
247            fence_acquire(&self.state);
248            let mut discovered = 0;
249            let (_, tail) = Waiter::get_and_link_queue(state, |_| discovered += 1);
250
251            // The barrier count is stored at the tail.
252            // Subtract the amount of newly discovered nodes from the count.
253            // Use saturating_sub() as technically more threads than the count could try to wait().
254            let mut counter = tail.as_ref().counter.load(Ordering::Relaxed);
255            counter = counter.saturating_sub(discovered);
256
257            // When the count hits zero, complete the barrier.
258            tail.as_ref().counter.store(counter, Ordering::Relaxed);
259            if counter == 0 {
260                return self.complete();
261            }
262
263            // The barrier count isnt zero yet.
264            // Unset the QUEUE_LOCKED bit since we've updated the queue links for the next wait()'er to grab it.
265            // Release barrier ensures the waiter writes to head/tail we did above happen before the next QUEUE_LOCKED bit holder.
266            match self.state.compare_exchange_weak(
267                state,
268                state.map_address(|addr| addr & !QUEUE_LOCKED),
269                Ordering::Release,
270                Ordering::Relaxed,
271            ) {
272                Ok(_) => return false,
273                Err(e) => state = e,
274            }
275        }
276    }
277
278    #[cold]
279    unsafe fn complete(&self) -> bool {
280        // Complete the barrier while also dequeueing all the waiters.
281        // AcqRel as Acquire barrier to ensure the writes to the pushed waiters happens before we iterate & wake them below.
282        // AcqRel as Release barrier to ensure that the barrier completion happens before the wait() calls return.
283        let completed = ptr::null_mut::<Waiter>().with_address(COMPLETED);
284        let state = self.state.swap(completed, Ordering::AcqRel);
285
286        assert_ne!(state.address() & QUEUED, 0);
287        assert_ne!(state.address() & QUEUE_LOCKED, 0);
288
289        let mut waiters = NonNull::new(state.map_address(|addr| addr & Waiter::MASK));
290        while let Some(waiter) = waiters {
291            waiters = waiter.as_ref().next.get();
292            waiter.as_ref().parker.unpark();
293        }
294
295        // Since we completed the barrier, we are the leader.
296        true
297    }
298}
299
/// A `BarrierWaitResult` is returned by [`Barrier::wait()`] when all threads
/// in the [`Barrier`] have rendezvoused.
///
/// It wraps a single flag recording whether the receiving thread was the leader;
/// query it with [`BarrierWaitResult::is_leader()`].
///
/// # Examples
///
/// ```
/// use usync::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait();
/// ```
pub struct BarrierWaitResult(bool);
312
313impl fmt::Debug for BarrierWaitResult {
314    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
315        f.debug_struct("BarrierWaitResult")
316            .field("is_leader", &self.is_leader())
317            .finish()
318    }
319}
320
321impl BarrierWaitResult {
322    /// Returns `true` if this thread is the "leader thread" for the call to
323    /// [`Barrier::wait()`].
324    ///
325    /// Only one thread will have `true` returned from their result, all other
326    /// threads will have `false` returned.
327    ///
328    /// # Examples
329    ///
330    /// ```
331    /// use usync::Barrier;
332    ///
333    /// let barrier = Barrier::new(1);
334    /// let barrier_wait_result = barrier.wait();
335    /// println!("{:?}", barrier_wait_result.is_leader());
336    /// ```
337    #[must_use]
338    pub fn is_leader(&self) -> bool {
339        self.0
340    }
341}