Skip to main content

vpp_plugin/vlib/main/
sync.rs

1//! Core synchronization primitives for vlib main.
2
3use std::{
4    cell::{Cell, UnsafeCell},
5    fmt,
6    marker::PhantomData,
7    ops::{Deref, DerefMut},
8    ptr::NonNull,
9};
10
11use crate::vlib::{BarrierHeldMainRef, MainRef};
12
/// A read/write lock that relies on VPP's barrier for exclusion between threads.
///
/// VPP's main thread can raise a barrier that stops every worker thread from running.
/// `BarrierRwLock` builds on that: writes happen on the VPP main thread while the barrier is
/// held, and reads may happen on either the VPP main thread or a VPP worker thread.
///
/// Taking a read or write "lock" is guaranteed never to block. Any blocking happens in the
/// VPP main and worker threads themselves when the VPP barrier is taken.
pub struct BarrierRwLock<T: ?Sized> {
    /// Number of read guards currently alive on the VPP main thread.
    ///
    /// A plain `u32` rather than an `AtomicU32` is enough because only the VPP main thread
    /// modifies it.
    readers: UnsafeCell<u32>,
    /// Set while a write guard is alive.
    ///
    /// A plain `bool` rather than an `AtomicBool` is enough because only the VPP main thread
    /// modifies it.
    writer: UnsafeCell<bool>,
    /// The protected value. It must stay the last field so that `T` may be unsized.
    data: UnsafeCell<T>,
}
33
34impl<T> BarrierRwLock<T> {
35    /// Create a new barrier-backed read/write lock.
36    #[inline]
37    pub const fn new(t: T) -> Self {
38        Self {
39            data: UnsafeCell::new(t),
40            readers: UnsafeCell::new(0),
41            writer: UnsafeCell::new(false),
42        }
43    }
44}
45
46impl<T: ?Sized> BarrierRwLock<T> {
47    /// Locks this `BarrierRwLock` with shared read access.
48    ///
49    /// Returns an RAII guard which will release this thread's shared access
50    /// once it is dropped.
51    ///
52    /// # Panics
53    ///
54    /// Panics if a write lock has already been taken by this thread and not dropped.
55    #[inline(always)]
56    pub fn read(&self, vm: &MainRef) -> BarrierRwLockReadGuard<'_, T> {
57        let main_thread = vm.thread_index() == 0;
58        // SAFETY: calling `BarrierRwLockReadGuard::new` is valid when we have a reference to the lock
59        // and we are on a known VPP thread. These conditions are satisfied by the public API.
60        unsafe { BarrierRwLockReadGuard::new(self, main_thread) }
61    }
62
63    /// Locks this `BarrierRwLock` with write access.
64    ///
65    /// This is used on the VPP main thread in contexts where the VPP barrier is held.
66    ///
67    /// Returns an RAII guard which will release this thread's access
68    /// once it is dropped.
69    ///
70    /// # Panics
71    ///
72    /// Panics if a read or another write lock has already been taken by this thread and not
73    /// dropped.
74    #[inline(always)]
75    pub fn write(&self, vm: &BarrierHeldMainRef) -> BarrierRwLockWriteGuard<'_, T> {
76        // Make sure we match the check in read()
77        debug_assert_eq!(vm.thread_index(), 0);
78        // SAFETY: `BarrierRwLockWriteGuard::new` is only called on the main thread while the
79        // barrier is held.
80        unsafe { BarrierRwLockWriteGuard::new(self) }
81    }
82
83    /// Get a mutable reference to the contained data without locking.
84    ///
85    /// This call borrows the `BarrierRwLock` mutably (at compile-time) which guarantees that we
86    /// possess the only reference.
87    pub fn get_mut(&mut self) -> &mut T {
88        self.data.get_mut()
89    }
90
91    /// Returns a raw pointer to the underlying data.
92    ///
93    /// The returned pointer is always non-null and properly aligned, but it is
94    /// the user's responsibility to ensure that any reads and writes through it
95    /// are properly synchronized to avoid data races, and that it is not read
96    /// or written through after the lock is dropped.
97    pub const fn data_ptr(&self) -> *mut T {
98        self.data.get()
99    }
100}
101
102impl<T> BarrierRwLock<T> {
103    /// Consume the lock and return the underlying data.
104    pub fn into_inner(self) -> T {
105        self.data.into_inner()
106    }
107}
108
109// SAFETY: `BarrierRwLock<T>` is safe to send to another thread if `T: Send`.
110unsafe impl<T: ?Sized + Send> Send for BarrierRwLock<T> {}
111
112// SAFETY: `BarrierRwLock<T>` is safe to share between threads if `T: Send + Sync`.
113unsafe impl<T: ?Sized + Send + Sync> Sync for BarrierRwLock<T> {}
114
115impl<T: Default> Default for BarrierRwLock<T> {
116    /// Creates a new `BarrierRwLock<T>`, with the `Default` value for T.
117    fn default() -> BarrierRwLock<T> {
118        BarrierRwLock::new(Default::default())
119    }
120}
121
122/// Shared read guard returned by [`BarrierRwLock::read`].
123pub struct BarrierRwLockReadGuard<'rwlock, T: ?Sized + 'rwlock> {
124    /// A pointer to the data protected by the `BarrierRwLock`. Note that we use a pointer here
125    /// instead of `&'rwlock T` to avoid `noalias` violations, because a `BarrierRwLockReadGuard`
126    /// instance only holds immutability until it drops, not for its whole scope.
127    data: NonNull<T>,
128
129    /// A reference to the [`BarrierRwLock`] that we have read-locked.
130    lock: &'rwlock BarrierRwLock<T>,
131
132    /// Whether the lock is on the VPP main thread or not
133    main_thread: bool,
134}
135
136// Note: Send not implemented here as that would prevent the optimisation of not incrementing
137// readers for VPP worker threads, since the guard could then be sent to the VPP main thread
138// and used to access data while there is a write lock taken, which violates `noalias` rules.
139
140// SAFETY: `BarrierRwLockReadGuard` is immutable references to valid data; `Sync` is safe for T: Sync.
141unsafe impl<T: ?Sized + Sync> Sync for BarrierRwLockReadGuard<'_, T> {}
142
143impl<'rwlock, T: ?Sized> BarrierRwLockReadGuard<'rwlock, T> {
144    /// Creates a new instance of `BarrierRwLockReadGuard<T>` from a `BarrierRwLock<T>`.
145    ///
146    /// # Panics
147    ///
148    /// Panics if a write lock has already been taken by this thread and not dropped.
149    ///
150    /// # Safety
151    ///
152    /// This function is safe if and only if called from a thread that VPP barriers know about,
153    /// i.e. either the VPP main thread or a VPP worker thread.
154    #[inline(always)]
155    unsafe fn new(
156        lock: &'rwlock BarrierRwLock<T>,
157        main_thread: bool,
158    ) -> BarrierRwLockReadGuard<'rwlock, T> {
159        // SAFETY: `lock.writer` is valid because `lock` is a valid pointer to a live lock.
160        if main_thread && unsafe { *lock.writer.get() } {
161            panic!("Write lock already taken by this thread");
162        }
163
164        // SAFETY: `lock.data` is valid and aligned, and lock lifetime guarantees it outlives the guard.
165        let data = unsafe { NonNull::new_unchecked(lock.data.get()) };
166        if main_thread {
167            // SAFETY: Only main thread increments/decrements readers so there is no data race.
168            unsafe {
169                *lock.readers.get() += 1;
170            }
171        }
172        Self {
173            data,
174            lock,
175            main_thread,
176        }
177    }
178}
179
180impl<T: ?Sized> Drop for BarrierRwLockReadGuard<'_, T> {
181    #[inline(always)]
182    fn drop(&mut self) {
183        if self.main_thread {
184            // SAFETY: Only main thread mutates `readers` so there is no data race. We are on
185            // the main thread by conditional.
186            unsafe {
187                *self.lock.readers.get() -= 1;
188            }
189        }
190    }
191}
192
193impl<T: ?Sized> Deref for BarrierRwLockReadGuard<'_, T> {
194    type Target = T;
195
196    #[inline(always)]
197    fn deref(&self) -> &T {
198        // SAFETY: the conditions of `BarrierRwLockReadGuard::new` were satisfied when created.
199        unsafe { self.data.as_ref() }
200    }
201}
202
203impl<T: ?Sized + fmt::Debug> fmt::Debug for BarrierRwLockReadGuard<'_, T> {
204    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205        (**self).fmt(f)
206    }
207}
208
209impl<T: ?Sized + fmt::Display> fmt::Display for BarrierRwLockReadGuard<'_, T> {
210    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211        (**self).fmt(f)
212    }
213}
214
215/// Exclusive write guard returned by [`BarrierRwLock::write`].
216pub struct BarrierRwLockWriteGuard<'rwlock, T: ?Sized + 'rwlock> {
217    /// A reference to the [`RwLock`] that we have write-locked.
218    lock: &'rwlock BarrierRwLock<T>,
219
220    /// Prevent the type from being Send
221    _phantom: PhantomData<Cell<()>>,
222}
223
224impl<'rwlock, T: ?Sized> BarrierRwLockWriteGuard<'rwlock, T> {
225    /// Creates a new instance of `BarrierRwLockWriteGuard<T>` from a `BarrierRwLock<T>`.
226    ///
227    /// # Panics
228    ///
229    /// Panics if a read or another write lock has already been taken by this thread and not
230    /// dropped.
231    ///
232    /// # Safety
233    ///
234    /// This function is safe if and only if the same thread is holding the VPP barrier prior to
235    /// calling this function and continues to hold it for the lifetime of this object.
236    #[inline(always)]
237    unsafe fn new(lock: &'rwlock BarrierRwLock<T>) -> BarrierRwLockWriteGuard<'rwlock, T> {
238        // SAFETY: this function is only called with barrier held and no concurrent write.
239        unsafe {
240            if *lock.readers.get() != 0 {
241                panic!("Read lock already taken by this thread");
242            }
243            if *lock.writer.get() {
244                panic!("Write lock already taken by this thread");
245            }
246            *lock.writer.get() = true;
247        }
248        BarrierRwLockWriteGuard {
249            lock,
250            _phantom: PhantomData,
251        }
252    }
253}
254
255impl<T: ?Sized> Drop for BarrierRwLockWriteGuard<'_, T> {
256    #[inline(always)]
257    fn drop(&mut self) {
258        // SAFETY: This is the only writer and barrier is held while the guard is alive.
259        unsafe {
260            *self.lock.writer.get() = false;
261        }
262    }
263}
264
// Note: the write guard must not be `Send`. Its `Drop` writes the non-atomic
// `self.lock.writer` flag, which is only safe on the VPP main thread, and sending the write
// guard to another thread has limited usefulness anyway.
267
268// SAFETY: `BarrierRwLockWriteGuard` ensures exclusive write access to the protected data
269// during its lifetime via VPP's barrier mechanism, which prevents concurrent access from
270// worker threads. For `T: Sync`, the guard can be safely shared across threads because
271// the underlying data is `Sync` and the barrier guarantees no conflicting accesses occur.
272unsafe impl<T: ?Sized + Sync> Sync for BarrierRwLockWriteGuard<'_, T> {}
273
274impl<T: ?Sized> Deref for BarrierRwLockWriteGuard<'_, T> {
275    type Target = T;
276
277    #[inline(always)]
278    fn deref(&self) -> &T {
279        // SAFETY: the conditions of `BarrierRwLockWriteGuard::new` were satisfied when created.
280        unsafe { &*self.lock.data.get() }
281    }
282}
283
284impl<T: ?Sized> DerefMut for BarrierRwLockWriteGuard<'_, T> {
285    #[inline(always)]
286    fn deref_mut(&mut self) -> &mut T {
287        // SAFETY: the conditions of `BarrierRwLockWriteGuard::new` were satisfied when created.
288        unsafe { &mut *self.lock.data.get() }
289    }
290}
291
292impl<T: ?Sized + fmt::Debug> fmt::Debug for BarrierRwLockWriteGuard<'_, T> {
293    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294        (**self).fmt(f)
295    }
296}
297
298impl<T: ?Sized + fmt::Display> fmt::Display for BarrierRwLockWriteGuard<'_, T> {
299    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
300        (**self).fmt(f)
301    }
302}
303
#[cfg(test)]
mod tests {
    use std::{ptr::addr_of_mut, thread};

    use crate::{
        bindings::vlib_main_t,
        vlib::{BarrierHeldMainRef, MainRef, main::sync::BarrierRwLock},
    };

    /// Two reader threads read the same lock concurrently: one uses a default `vlib_main_t`
    /// (presumably thread index 0, the main thread) and one uses thread index 1 (a worker).
    #[test]
    fn concurrent_reads() {
        let rwlock = BarrierRwLock::new("value".to_string());
        let shared = &rwlock;
        thread::scope(|scope| {
            let main_reader = scope.spawn(move || {
                let mut raw_vm = vlib_main_t::default();
                // SAFETY: `raw_vm` is sufficiently initialised for the test and valid for the
                // duration of the call.
                let vm = unsafe { MainRef::from_ptr_mut(addr_of_mut!(raw_vm)) };
                for _ in 0..1000 {
                    assert_eq!(*shared.read(vm), "value");
                }
            });
            let worker_reader = scope.spawn(move || {
                let mut raw_vm = vlib_main_t {
                    thread_index: 1,
                    ..vlib_main_t::default()
                };
                // SAFETY: `raw_vm` is sufficiently initialised for the test and valid for the
                // duration of the call.
                let vm = unsafe { MainRef::from_ptr_mut(addr_of_mut!(raw_vm)) };
                for _ in 0..1000 {
                    assert_eq!(*shared.read(vm), "value");
                }
            });
            main_reader.join().unwrap();
            worker_reader.join().unwrap();
        });
    }

    /// A value stored through a write guard is seen by a later read guard.
    #[test]
    fn write_guard() {
        let mut raw_vm = vlib_main_t::default();
        // SAFETY: `raw_vm` is sufficiently initialised for the test and valid for the duration
        // of the call.
        let vm = unsafe { BarrierHeldMainRef::from_ptr_mut(addr_of_mut!(raw_vm)) };
        let rwlock = BarrierRwLock::new("value".to_string());
        *rwlock.write(vm) = "new value".to_string();
        assert_eq!(*rwlock.read(vm), "new value");
    }

    /// Taking a read guard while this thread still holds a write guard panics.
    #[test]
    #[should_panic(expected = "Write lock already taken by this thread")]
    fn read_and_write1() {
        let mut raw_vm = vlib_main_t::default();
        // SAFETY: `raw_vm` is sufficiently initialised for the test and valid for the duration
        // of the call.
        let vm = unsafe { BarrierHeldMainRef::from_ptr_mut(addr_of_mut!(raw_vm)) };
        let rwlock = BarrierRwLock::new("value".to_string());
        // Named (not `_`) bindings keep both guards alive until the end of the test.
        let _write_guard = rwlock.write(vm);
        let _read_guard = rwlock.read(vm);
    }

    /// Taking a write guard while this thread still holds a read guard panics.
    #[test]
    #[should_panic(expected = "Read lock already taken by this thread")]
    fn read_and_write2() {
        let mut raw_vm = vlib_main_t::default();
        // SAFETY: `raw_vm` is sufficiently initialised for the test and valid for the duration
        // of the call.
        let vm = unsafe { BarrierHeldMainRef::from_ptr_mut(addr_of_mut!(raw_vm)) };
        let rwlock = BarrierRwLock::new("value".to_string());
        let _read_guard = rwlock.read(vm);
        let _write_guard = rwlock.write(vm);
    }

    /// Taking a second write guard while the first is still alive panics.
    #[test]
    #[should_panic(expected = "Write lock already taken by this thread")]
    fn write_write() {
        let mut raw_vm = vlib_main_t::default();
        // SAFETY: `raw_vm` is sufficiently initialised for the test and valid for the duration
        // of the call.
        let vm = unsafe { BarrierHeldMainRef::from_ptr_mut(addr_of_mut!(raw_vm)) };
        let rwlock = BarrierRwLock::new("value".to_string());
        let _first_guard = rwlock.write(vm);
        let _second_guard = rwlock.write(vm);
    }

    /// Covers `Default`, `get_mut`, the `Debug`/`Display` forwarding of both guard types,
    /// `data_ptr` and `into_inner` of [`BarrierRwLock`].
    #[test]
    fn misc() {
        let mut raw_vm = vlib_main_t::default();
        // SAFETY: `raw_vm` is sufficiently initialised for the test and valid for the duration
        // of the call.
        let vm = unsafe { BarrierHeldMainRef::from_ptr_mut(addr_of_mut!(raw_vm)) };
        let mut rwlock: BarrierRwLock<String> = BarrierRwLock::default();

        assert_eq!(*rwlock.write(vm), "");

        *rwlock.get_mut() = "value".to_string();

        assert_eq!(rwlock.write(vm).to_string(), "value");
        assert_eq!(format!("{:?}", rwlock.write(vm)), "\"value\"");
        assert_eq!(rwlock.read(vm).to_string(), "value");
        assert_eq!(format!("{:?}", rwlock.read(vm)), "\"value\"");

        // SAFETY: no guard is alive here and `data_ptr()` points at the live data.
        unsafe {
            assert_eq!(&*rwlock.data_ptr(), "value");
        }

        assert_eq!(rwlock.into_inner(), "value");
    }
}