// sync_linux_no_libc/sync/barrier.rs

1use core::fmt;
2use crate::sync::{Condvar, Mutex};
3
4/// A barrier enables multiple threads to synchronize the beginning
5/// of some computation.
6///
7/// # Examples
8///
9/// ```
10/// use sync_linux_no_libc::sync::Barrier;
11/// use std::thread;
12///
13/// let n = 10;
14/// let barrier = Barrier::new(n);
15/// thread::scope(|s| {
16///     for _ in 0..n {
17///         // The same messages will be printed together.
18///         // You will NOT see any interleaving.
19///         s.spawn(|| {
20///             println!("before wait");
21///             barrier.wait();
22///             println!("after wait");
23///         });
24///     }
25/// });
26/// ```
27pub struct Barrier {
28    lock: Mutex<BarrierState>,
29    cvar: Condvar,
30    num_threads: usize,
31}
32
// Mutable bookkeeping for a reusable (generation-counted) barrier; it lives
// inside the barrier's mutex.
struct BarrierState {
    // Threads that have arrived in the current round; reset to 0 on release.
    count: usize,
    // Bumped (wrapping) every time a round is released, so a parked thread can
    // detect that the round it joined has finished.
    generation_id: usize,
}
38
/// The value handed back by [`Barrier::wait()`] once every thread taking part
/// in the [`Barrier`] round has arrived.
///
/// Call [`BarrierWaitResult::is_leader`] to learn whether this thread was the
/// one picked as leader for its round.
///
/// # Examples
///
/// ```
/// use sync_linux_no_libc::sync::Barrier;
///
/// let barrier = Barrier::new(1);
/// let result = barrier.wait();
/// // With a single participant, the caller is always the leader.
/// assert!(result.is_leader());
/// ```
pub struct BarrierWaitResult(bool);
51
52impl fmt::Debug for Barrier {
53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54        f.debug_struct("Barrier").finish_non_exhaustive()
55    }
56}
57
58impl Barrier {
59    /// Creates a new barrier that can block a given number of threads.
60    ///
61    /// A barrier will block `n`-1 threads which call [`wait()`] and then wake
62    /// up all threads at once when the `n`th thread calls [`wait()`].
63    ///
64    /// [`wait()`]: Barrier::wait
65    ///
66    /// # Examples
67    ///
68    /// ```
69    /// use sync_linux_no_libc::sync::Barrier;
70    ///
71    /// let barrier = Barrier::new(10);
72    /// ```
73    #[must_use]
74    #[inline]
75    pub const fn new(n: usize) -> Barrier {
76        Barrier {
77            lock: Mutex::new(BarrierState { count: 0, generation_id: 0 }),
78            cvar: Condvar::new(),
79            num_threads: n,
80        }
81    }
82
83    /// Blocks the current thread until all threads have rendezvoused here.
84    ///
85    /// Barriers are re-usable after all threads have rendezvoused once, and can
86    /// be used continuously.
87    ///
88    /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that
89    /// returns `true` from [`BarrierWaitResult::is_leader()`] when returning
90    /// from this function, and all other threads will receive a result that
91    /// will return `false` from [`BarrierWaitResult::is_leader()`].
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// use sync_linux_no_libc::sync::Barrier;
97    /// use std::thread;
98    ///
99    /// let n = 10;
100    /// let barrier = Barrier::new(n);
101    /// thread::scope(|s| {
102    ///     for _ in 0..n {
103    ///         // The same messages will be printed together.
104    ///         // You will NOT see any interleaving.
105    ///         s.spawn(|| {
106    ///             println!("before wait");
107    ///             barrier.wait();
108    ///             println!("after wait");
109    ///         });
110    ///     }
111    /// });
112    /// ```
113    pub fn wait(&self) -> BarrierWaitResult {
114        let mut lock = self.lock.lock();
115        let local_gen = lock.generation_id;
116        lock.count += 1;
117        if lock.count < self.num_threads {
118            let _guard =
119                self.cvar.wait_while(lock, |state| local_gen == state.generation_id);
120            BarrierWaitResult(false)
121        } else {
122            lock.count = 0;
123            lock.generation_id = lock.generation_id.wrapping_add(1);
124            self.cvar.notify_all();
125            BarrierWaitResult(true)
126        }
127    }
128}
129
130impl fmt::Debug for BarrierWaitResult {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        f.debug_struct("BarrierWaitResult").field("is_leader", &self.is_leader()).finish()
133    }
134}
135
136impl BarrierWaitResult {
137    /// Returns `true` if this thread is the "leader thread" for the call to
138    /// [`Barrier::wait()`].
139    ///
140    /// Only one thread will have `true` returned from their result, all other
141    /// threads will have `false` returned.
142    ///
143    /// # Examples
144    ///
145    /// ```
146    /// use sync_linux_no_libc::sync::Barrier;
147    ///
148    /// let barrier = Barrier::new(1);
149    /// let barrier_wait_result = barrier.wait();
150    /// println!("{:?}", barrier_wait_result.is_leader());
151    /// ```
152    #[must_use]
153    pub fn is_leader(&self) -> bool {
154        self.0
155    }
156}