tiny_std/sync/
rwlock.rs

1//! Rw-lock implementation essentially copied from std.
2//! Thus the license for it is this:
3//! ---
4//!
5//! Permission is hereby granted, free of charge, to any
6//! person obtaining a copy of this software and associated
7//! documentation files (the "Software"), to deal in the
8//! Software without restriction, including without
9//! limitation the rights to use, copy, modify, merge,
10//! publish, distribute, sublicense, and/or sell copies of
11//! the Software, and to permit persons to whom the Software
12//! is furnished to do so, subject to the following
13//! conditions:
14//!
15//! The above copyright notice and this permission notice
16//! shall be included in all copies or substantial portions
17//! of the Software.
18//!
19//! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
20//! ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
21//! TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
22//! PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
23//! SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
24//! CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
25//! OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
26//! IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
27//! DEALINGS IN THE SOFTWARE.
28//!
29//! ---
30use crate::sync::{futex_wait_fast, NotSend};
31use core::cell::UnsafeCell;
32use core::fmt;
33use core::ops::{Deref, DerefMut};
34use core::ptr::NonNull;
35use core::sync::atomic::AtomicU32;
36use core::sync::atomic::Ordering::{Acquire, Relaxed, Release};
37use rusl::futex::futex_wake;
38
/// A readers-writer lock protecting `T`, backed by the futex-based
/// [`InnerLock`].
///
/// Unlike `std::sync::RwLock` there is no lock poisoning: `read`/`write`
/// hand back guards directly rather than `Result`s.
pub struct RwLock<T: ?Sized> {
    inner: InnerLock,
    // `UnsafeCell` lets guards hand out `&T` / `&mut T` through `&self`.
    data: UnsafeCell<T>,
}
43
// SAFETY: moving the lock moves the owned `T`, so `T: Send` is required and sufficient.
unsafe impl<T: ?Sized + Send> Send for RwLock<T> {}
// SAFETY: sharing the lock lets threads obtain both `&T` (readers) and
// `&mut T` (writer), hence `T` must be both `Send` and `Sync`.
unsafe impl<T: ?Sized + Send + Sync> Sync for RwLock<T> {}
/// RAII structure used to release the shared read access of a lock when
/// dropped.
///
/// This structure is created by the [`read`] and [`try_read`] methods on
/// [`RwLock`].
///
/// [`read`]: RwLock::read
/// [`try_read`]: RwLock::try_read
#[must_use = "if unused the RwLock will immediately unlock"]
#[clippy::has_significant_drop]
pub struct RwLockReadGuard<'a, T: ?Sized + 'a> {
    // NB: we use a pointer instead of `&'a T` to avoid `noalias` violations, because a
    // `Ref` argument doesn't hold immutability for its whole scope, only until it drops.
    // `NonNull` is also covariant over `T`, just like we would have with `&T`. `NonNull`
    // is preferable over `*const T` to allow for niche optimization.
    data: NonNull<T>,
    // Borrow of the lock state only (not the data), used to unlock on drop.
    inner_lock: &'a InnerLock,
    // Guards must be released on the thread that acquired them.
    _not_send: NotSend,
}
65
// SAFETY: a shared guard only ever exposes `&T` (via `Deref`), so `T: Sync` suffices.
unsafe impl<T: ?Sized + Sync> Sync for RwLockReadGuard<'_, T> {}
67
impl<'rwlock, T: ?Sized> RwLockReadGuard<'rwlock, T> {
    /// Create a new instance of `RwLockReadGuard<T>` from a `RwLock<T>`.
    ///
    /// # Safety
    ///
    /// Must be called if and only if `lock.inner.read()` (or
    /// `lock.inner.try_read()`) has been successfully called from the same
    /// thread before instantiating this object.
    unsafe fn new(lock: &'rwlock RwLock<T>) -> RwLockReadGuard<'rwlock, T> {
        RwLockReadGuard {
            // `UnsafeCell::get` on a live cell never returns null.
            data: NonNull::new_unchecked(lock.data.get()),
            inner_lock: &lock.inner,
            _not_send: NotSend::new(),
        }
    }
}
/// RAII structure used to release the exclusive write access of a lock when
/// dropped.
///
/// This structure is created by the [`write`] and [`try_write`] methods on
/// [`RwLock`].
///
/// [`write`]: RwLock::write
/// [`try_write`]: RwLock::try_write
#[must_use = "if unused the RwLock will immediately unlock"]
#[clippy::has_significant_drop]
pub struct RwLockWriteGuard<'a, T: ?Sized + 'a> {
    // Whole-lock borrow; data access goes through the `UnsafeCell`, which is
    // sound while the exclusive lock is held.
    lock: &'a RwLock<T>,
    // Guards must be released on the thread that acquired them.
    _not_send: NotSend,
}
86
// SAFETY: a shared `&RwLockWriteGuard` only exposes `&T` (via `Deref`), so `T: Sync` suffices.
unsafe impl<T: ?Sized + Sync> Sync for RwLockWriteGuard<'_, T> {}
88
impl<'rwlock, T: ?Sized> RwLockWriteGuard<'rwlock, T> {
    /// Create a new instance of `RwLockWriteGuard<T>` from a `RwLock<T>`.
    ///
    /// # Safety
    ///
    /// Must be called if and only if `lock.inner.write()` (or
    /// `lock.inner.try_write()`) has been successfully called from the same
    /// thread before instantiating this object.
    unsafe fn new(lock: &'rwlock RwLock<T>) -> RwLockWriteGuard<'rwlock, T> {
        RwLockWriteGuard {
            lock,
            _not_send: NotSend::new(),
        }
    }
}
100impl<T> RwLock<T> {
101    #[inline]
102    pub const fn new(t: T) -> RwLock<T> {
103        RwLock {
104            inner: InnerLock::new(),
105            data: UnsafeCell::new(t),
106        }
107    }
108}
109
impl<T: ?Sized> RwLock<T> {
    /// Locks this `RwLock` with shared read access, blocking the current
    /// thread until it can be acquired.
    ///
    /// # Panics
    ///
    /// Panics if the reader counter would overflow
    /// ("too many active read locks on RwLock").
    #[inline]
    pub fn read(&self) -> RwLockReadGuard<'_, T> {
        unsafe {
            self.inner.read();
            RwLockReadGuard::new(self)
        }
    }

    /// Attempts to acquire shared read access without blocking; `None` if
    /// the lock is write-locked or other threads are already waiting.
    #[inline]
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
        unsafe { self.inner.try_read().then(|| RwLockReadGuard::new(self)) }
    }

    /// Locks this `RwLock` with exclusive write access, blocking the
    /// current thread until it can be acquired.
    #[inline]
    pub fn write(&self) -> RwLockWriteGuard<'_, T> {
        unsafe {
            self.inner.write();
            RwLockWriteGuard::new(self)
        }
    }

    /// Attempts to acquire exclusive write access without blocking; `None`
    /// unless the lock is completely unlocked.
    #[inline]
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
        unsafe { self.inner.try_write().then(|| RwLockWriteGuard::new(self)) }
    }

    /// Consumes the lock, returning the underlying data. No locking is
    /// needed: ownership proves no guards exist.
    pub fn into_inner(self) -> T
    where
        T: Sized,
    {
        self.data.into_inner()
    }

    /// Returns a mutable reference to the underlying data. No locking is
    /// needed: the `&mut self` borrow proves no guards exist.
    pub fn get_mut(&mut self) -> &mut T {
        self.data.get_mut()
    }
}
148
/// Futex-based rwlock state machine, kept separate from the protected data.
struct InnerLock {
    // Bits 0..30: reader count (`MASK` in those bits means write-locked);
    // bit 30: readers waiting; bit 31: writers waiting.
    state: AtomicU32,
    // Incremented on every writer wake-up; contended writers futex-wait on
    // this counter so a notification between check and sleep isn't missed.
    writer_notify: AtomicU32,
}
153
// ---- state-word layout & predicates ---------------------------------------
//
// The low 30 bits hold the reader count; the all-ones pattern in that field
// marks an exclusive (write) lock. The two high bits flag parked readers and
// parked writers respectively.

/// Amount added to the state for each shared (read) acquisition.
const READ_LOCKED: u32 = 1;
/// Covers the reader-count field (low 30 bits).
const MASK: u32 = (1 << 30) - 1;
/// Sentinel value in the count field meaning "exclusively locked".
const WRITE_LOCKED: u32 = MASK;
/// Highest representable reader count (one below the sentinel).
const MAX_READERS: u32 = MASK - 1;
/// Set while at least one reader is parked on the state futex.
const READERS_WAITING: u32 = 1 << 30;
/// Set while at least one writer is parked on the notify futex.
const WRITERS_WAITING: u32 = 1 << 31;

/// Neither readers nor a writer currently hold the lock.
#[inline]
fn is_unlocked(state: u32) -> bool {
    (state & MASK) == 0
}

/// A writer currently holds the lock exclusively.
#[inline]
fn is_write_locked(state: u32) -> bool {
    (state & MASK) == WRITE_LOCKED
}

/// At least one reader is parked waiting on the lock.
#[inline]
fn has_readers_waiting(state: u32) -> bool {
    (state & READERS_WAITING) != 0
}

/// At least one writer is parked waiting on the lock.
#[inline]
fn has_writers_waiting(state: u32) -> bool {
    (state & WRITERS_WAITING) != 0
}

/// A new reader may acquire the lock right now.
///
/// Refuses when the count would overflow into the write-locked sentinel, and
/// whenever anyone is waiting: parked writers take priority over fresh
/// readers, and the readers-waiting bit is only set transiently while an
/// unlocking thread is in the middle of waking writers — that thread will
/// clear the bit and wake readers itself if needed.
#[inline]
fn is_read_lockable(state: u32) -> bool {
    let nobody_waiting = !has_readers_waiting(state) && !has_writers_waiting(state);
    (state & MASK) < MAX_READERS && nobody_waiting
}

/// The reader count has hit its ceiling; one more would overflow.
#[inline]
fn has_reached_max_readers(state: u32) -> bool {
    (state & MASK) == MAX_READERS
}
196
impl InnerLock {
    /// New lock: unlocked, no waiters, notification counter at zero.
    #[inline]
    pub const fn new() -> Self {
        Self {
            state: AtomicU32::new(0),
            writer_notify: AtomicU32::new(0),
        }
    }

    /// Tries to take a shared lock without blocking; `true` on success.
    ///
    /// Fails when write-locked, at the reader limit, or when any thread is
    /// waiting (waiting writers have priority over new readers).
    #[inline]
    pub fn try_read(&self) -> bool {
        self.state
            .fetch_update(Acquire, Relaxed, |s| {
                is_read_lockable(s).then_some(s + READ_LOCKED)
            })
            .is_ok()
    }

    /// Takes a shared lock, parking on the futex if contended.
    #[inline]
    pub fn read(&self) {
        let state = self.state.load(Relaxed);
        // Fast path: a single weak CAS bumping the reader count. Any
        // failure (real contention or spurious) falls to the slow path.
        if !is_read_lockable(state)
            || self
                .state
                .compare_exchange_weak(state, state + READ_LOCKED, Acquire, Relaxed)
                .is_err()
        {
            self.read_contended();
        }
    }

    /// Releases one shared lock.
    ///
    /// # Safety
    ///
    /// The caller must currently hold a read lock obtained from `read` or a
    /// successful `try_read` on this same `InnerLock`.
    #[inline]
    pub unsafe fn read_unlock(&self) {
        let state = self.state.fetch_sub(READ_LOCKED, Release) - READ_LOCKED;

        // It's impossible for a reader to be waiting on a read-locked RwLock,
        // except if there is also a writer waiting.
        debug_assert!(!has_readers_waiting(state) || has_writers_waiting(state));

        // Wake up a writer if we were the last reader and there's a writer waiting.
        if is_unlocked(state) && has_writers_waiting(state) {
            self.wake_writer_or_readers(state);
        }
    }

    /// Slow path of `read`: spin briefly, then set the readers-waiting bit
    /// and futex-wait on `state` until the lock becomes read-lockable.
    #[cold]
    fn read_contended(&self) {
        let mut state = self.spin_read();

        loop {
            // If we can lock it, lock it.
            if is_read_lockable(state) {
                match self
                    .state
                    .compare_exchange_weak(state, state + READ_LOCKED, Acquire, Relaxed)
                {
                    Ok(_) => return, // Locked!
                    Err(s) => {
                        state = s;
                        continue;
                    }
                }
            }

            // Check for overflow.
            assert!(
                !has_reached_max_readers(state),
                "too many active read locks on RwLock"
            );

            // Make sure the readers waiting bit is set before we go to sleep.
            if !has_readers_waiting(state) {
                if let Err(s) =
                    self.state
                        .compare_exchange(state, state | READERS_WAITING, Relaxed, Relaxed)
                {
                    state = s;
                    continue;
                }
            }

            // Wait for the state to change.
            futex_wait_fast(&self.state, state | READERS_WAITING);

            // Spin again after waking up.
            state = self.spin_read();
        }
    }

    /// Tries to take the exclusive lock without blocking; `true` on
    /// success. Only succeeds when the lock is completely unlocked
    /// (waiting bits, if any, are preserved by the addition).
    #[inline]
    pub fn try_write(&self) -> bool {
        self.state
            .fetch_update(Acquire, Relaxed, |s| {
                is_unlocked(s).then_some(s + WRITE_LOCKED)
            })
            .is_ok()
    }

    /// Takes the exclusive lock, parking on the futex if contended.
    #[inline]
    pub fn write(&self) {
        // Fast path only succeeds from the fully-idle state (no lock
        // holders *and* no waiting bits).
        if self
            .state
            .compare_exchange_weak(0, WRITE_LOCKED, Acquire, Relaxed)
            .is_err()
        {
            self.write_contended();
        }
    }

    /// Releases the exclusive lock.
    ///
    /// # Safety
    ///
    /// The caller must currently hold the write lock obtained from `write`
    /// or a successful `try_write` on this same `InnerLock`.
    #[inline]
    pub unsafe fn write_unlock(&self) {
        let state = self.state.fetch_sub(WRITE_LOCKED, Release) - WRITE_LOCKED;

        debug_assert!(is_unlocked(state));

        if has_writers_waiting(state) || has_readers_waiting(state) {
            self.wake_writer_or_readers(state);
        }
    }

    /// Slow path of `write`: spin, set the writers-waiting bit, then sleep
    /// on the `writer_notify` counter until the lock can be grabbed.
    #[cold]
    fn write_contended(&self) {
        let mut state = self.spin_write();

        // Becomes WRITERS_WAITING after our first failed attempt, so that
        // the bit is kept set for other parked writers once we do lock.
        let mut other_writers_waiting = 0;

        loop {
            // If it's unlocked, we try to lock it.
            if is_unlocked(state) {
                match self.state.compare_exchange_weak(
                    state,
                    state | WRITE_LOCKED | other_writers_waiting,
                    Acquire,
                    Relaxed,
                ) {
                    Ok(_) => return, // Locked!
                    Err(s) => {
                        state = s;
                        continue;
                    }
                }
            }

            // Set the waiting bit indicating that we're waiting on it.
            if !has_writers_waiting(state) {
                if let Err(s) =
                    self.state
                        .compare_exchange(state, state | WRITERS_WAITING, Relaxed, Relaxed)
                {
                    state = s;
                    continue;
                }
            }

            // Other writers might be waiting now too, so we should make sure
            // we keep that bit on once we manage lock it.
            other_writers_waiting = WRITERS_WAITING;

            // Examine the notification counter before we check if `state` has changed,
            // to make sure we don't miss any notifications.
            let seq = self.writer_notify.load(Acquire);

            // Don't go to sleep if the lock has become available,
            // or if the writers waiting bit is no longer set.
            state = self.state.load(Relaxed);
            if is_unlocked(state) || !has_writers_waiting(state) {
                continue;
            }

            // Wait for the state to change.
            futex_wait_fast(&self.writer_notify, seq);

            // Spin again after waking up.
            state = self.spin_write();
        }
    }

    /// Wake up waiting threads after unlocking.
    ///
    /// If both are waiting, this will wake up only one writer, but will fall
    /// back to waking up readers if there was no writer to wake up.
    #[cold]
    fn wake_writer_or_readers(&self, mut state: u32) {
        assert!(is_unlocked(state));

        // The readers waiting bit might be turned on at any point now,
        // since readers will block when there's anything waiting.
        // Writers will just lock the lock though, regardless of the waiting bits,
        // so we don't have to worry about the writer waiting bit.
        //
        // If the lock gets locked in the meantime, we don't have to do
        // anything, because then the thread that locked the lock will take
        // care of waking up waiters when it unlocks.

        // If only writers are waiting, wake one of them up.
        if state == WRITERS_WAITING {
            match self.state.compare_exchange(state, 0, Relaxed, Relaxed) {
                Ok(_) => {
                    self.wake_writer();
                    return;
                }
                Err(s) => {
                    // Maybe some readers are now waiting too. So, continue to the next `if`.
                    state = s;
                }
            }
        }

        // If both writers and readers are waiting, leave the readers waiting
        // and only wake up one writer.
        if state == READERS_WAITING + WRITERS_WAITING {
            if self
                .state
                .compare_exchange(state, READERS_WAITING, Relaxed, Relaxed)
                .is_err()
            {
                // The lock got locked. Not our problem anymore.
                return;
            }
            if self.wake_writer() {
                return;
            }
            // No writers were actually blocked on futex_wait, so we continue
            // to wake up readers instead, since we can't be sure if we notified a writer.
            state = READERS_WAITING;
        }

        // If readers are waiting, wake them all up.
        if state == READERS_WAITING
            && self
                .state
                .compare_exchange(state, 0, Relaxed, Relaxed)
                .is_ok()
        {
            let _ = futex_wake(&self.state, i32::MAX);
        }
    }

    /// Bumps the notification counter and wakes at most one parked writer;
    /// returns whether the futex reported waking any thread.
    fn wake_writer(&self) -> bool {
        self.writer_notify.fetch_add(1, Release);
        futex_wake(&self.writer_notify, 1).unwrap() != 0
        // Note that FreeBSD and DragonFlyBSD don't tell us whether they woke
        // up any threads or not, and always return `false` here. That still
        // results in correct behaviour: it just means readers get woken up as
        // well in case both readers and writers were waiting.
        // NOTE(review): the note above is inherited from std's cross-platform
        // shim; rusl presumably targets Linux futexes only — confirm.
    }

    /// Reloads `state` until `f` says stop or the spin budget runs out;
    /// returns the last observed state.
    #[inline]
    fn spin_until(&self, f: impl Fn(u32) -> bool) -> u32 {
        let mut spin = 100; // Chosen by fair dice roll.
        loop {
            let state = self.state.load(Relaxed);
            if f(state) || spin == 0 {
                return state;
            }
            core::hint::spin_loop();
            spin -= 1;
        }
    }

    /// Spin used by contended writers.
    #[inline]
    fn spin_write(&self) -> u32 {
        // Stop spinning when it's unlocked or when there's waiting writers, to keep things somewhat fair.
        self.spin_until(|state| is_unlocked(state) || has_writers_waiting(state))
    }

    /// Spin used by contended readers.
    #[inline]
    fn spin_read(&self) -> u32 {
        // Stop spinning when it's unlocked or read locked, or when there's waiting threads.
        self.spin_until(|state| {
            !is_write_locked(state) || has_readers_waiting(state) || has_writers_waiting(state)
        })
    }
}
471
472impl<T: fmt::Debug> fmt::Debug for RwLockReadGuard<'_, T> {
473    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
474        (**self).fmt(f)
475    }
476}
477
478impl<T: ?Sized + fmt::Display> fmt::Display for RwLockReadGuard<'_, T> {
479    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
480        (**self).fmt(f)
481    }
482}
483
484impl<T: fmt::Debug> fmt::Debug for RwLockWriteGuard<'_, T> {
485    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
486        (**self).fmt(f)
487    }
488}
489
490impl<T: ?Sized + fmt::Display> fmt::Display for RwLockWriteGuard<'_, T> {
491    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
492        (**self).fmt(f)
493    }
494}
495
impl<T: ?Sized> Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    /// Shared access to the protected data for the guard's lifetime.
    fn deref(&self) -> &T {
        // SAFETY: the conditions of `RwLockReadGuard::new` were satisfied when created.
        unsafe { self.data.as_ref() }
    }
}
504
impl<T: ?Sized> Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    /// Shared access to the protected data for the guard's lifetime.
    fn deref(&self) -> &T {
        // SAFETY: the conditions of `RwLockWriteGuard::new` were satisfied when created.
        unsafe { &*self.lock.data.get() }
    }
}
513
impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, T> {
    /// Exclusive access to the protected data for the guard's lifetime.
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: the conditions of `RwLockWriteGuard::new` were satisfied when created.
        unsafe { &mut *self.lock.data.get() }
    }
}
520
impl<T: ?Sized> Drop for RwLockReadGuard<'_, T> {
    /// Releases the shared lock; wakes a waiting writer if this was the
    /// last reader (see `InnerLock::read_unlock`).
    fn drop(&mut self) {
        // SAFETY: the conditions of `RwLockReadGuard::new` were satisfied when created.
        unsafe {
            self.inner_lock.read_unlock();
        }
    }
}
529
impl<T: ?Sized> Drop for RwLockWriteGuard<'_, T> {
    /// Releases the exclusive lock, waking waiters if any (see
    /// `InnerLock::write_unlock`).
    fn drop(&mut self) {
        // SAFETY: the conditions of `RwLockWriteGuard::new` were satisfied when created.
        unsafe {
            self.lock.inner.write_unlock();
        }
    }
}
538
#[cfg(test)]
mod tests {
    use crate::sync::RwLock;
    use core::time::Duration;

    // A reader spawned while the write lock is held must block until the
    // writer drops its guard, and then observe the written value.
    #[test]
    fn can_lock() {
        let rw = std::sync::Arc::new(super::RwLock::new(0));
        let rw_c = rw.clone();
        let mut guard = rw.write();
        // The write guard is held before the spawn, so this read() cannot
        // complete until `guard` is dropped below.
        let res = std::thread::spawn(move || *rw_c.read());
        *guard = 15;
        drop(guard);
        let thread_res = res.join().unwrap();
        assert_eq!(15, thread_res);
    }

    // NOTE(review): the name says "mutex" but this exercises the RwLock
    // write path under contention from many threads.
    #[test]
    fn can_mutex_contended() {
        const NUM_THREADS: usize = 32;
        let count = std::sync::Arc::new(RwLock::new(0));
        let mut handles = std::vec::Vec::new();
        for _i in 0..NUM_THREADS {
            let count_c = count.clone();
            let handle = std::thread::spawn(move || {
                // Try to create some contention
                let mut w_guard = count_c.write();
                let orig = *w_guard;
                std::thread::sleep(Duration::from_millis(1));
                *w_guard += 1;
                drop(w_guard);
                std::thread::sleep(Duration::from_millis(1));
                let r_guard = count_c.read();
                std::thread::sleep(Duration::from_millis(1));
                // We incremented this
                assert!(*r_guard > orig);
            });
            handles.push(handle);
        }
        for h in handles {
            h.join().unwrap();
        }
        // Every thread added exactly 1 under the write lock.
        assert_eq!(NUM_THREADS, *count.read());
    }

    // Single-threaded checks of try_read/try_write interaction:
    // readers coexist with readers, writers exclude everyone.
    #[test]
    fn can_try_rw_single_thread_contended() {
        let rw = std::sync::Arc::new(super::RwLock::new(0));
        let rw_c = rw.clone();
        assert_eq!(0, *rw_c.try_read().unwrap());
        let r_guard = rw.read();
        // A second shared lock succeeds while one is held...
        assert_eq!(0, *rw_c.try_read().unwrap());
        // ...but an exclusive lock does not.
        assert!(rw_c.try_write().is_none());
        drop(r_guard);
        // Temporary write guard drops at the end of this statement.
        assert_eq!(0, *rw_c.try_write().unwrap());
        let _w_guard = rw.write();
        // A held write lock excludes readers.
        assert!(rw_c.try_read().is_none());
    }
}