// async_std/sync/rwlock.rs

1use std::cell::UnsafeCell;
2use std::fmt;
3use std::isize;
4use std::ops::{Deref, DerefMut};
5use std::pin::Pin;
6use std::process;
7use std::future::Future;
8use std::sync::atomic::{AtomicUsize, Ordering};
9
10use crate::sync::WakerSet;
11use crate::task::{Context, Poll};
12
/// Bit of `state` that is set while a write lock is held.
const WRITE_LOCK: usize = 1;

/// The amount a single active (currently held) read lock adds to `state`.
///
/// The reader count lives in the bits above `WRITE_LOCK`, so each reader contributes `1 << 1`.
const ONE_READ: usize = 1 << 1;

/// Mask selecting the bits of `state` that store the active reader count.
const READ_COUNT_MASK: usize = !(ONE_READ - 1);
22
23/// A reader-writer lock for protecting shared data.
24///
25/// This type is an async version of [`std::sync::RwLock`].
26///
27/// [`std::sync::RwLock`]: https://doc.rust-lang.org/std/sync/struct.RwLock.html
28///
29/// # Examples
30///
31/// ```
32/// # async_std::task::block_on(async {
33/// #
34/// use async_std::sync::RwLock;
35///
36/// let lock = RwLock::new(5);
37///
38/// // Multiple read locks can be held at a time.
39/// let r1 = lock.read().await;
40/// let r2 = lock.read().await;
41/// assert_eq!(*r1, 5);
42/// assert_eq!(*r2, 5);
43/// drop((r1, r2));
44///
45/// // Only one write locks can be held at a time.
46/// let mut w = lock.write().await;
47/// *w += 1;
48/// assert_eq!(*w, 6);
49/// #
50/// # })
51/// ```
52pub struct RwLock<T: ?Sized> {
53    state: AtomicUsize,
54    read_wakers: WakerSet,
55    write_wakers: WakerSet,
56    value: UnsafeCell<T>,
57}
58
59unsafe impl<T: ?Sized + Send> Send for RwLock<T> {}
60unsafe impl<T: ?Sized + Send + Sync> Sync for RwLock<T> {}
61
62impl<T> RwLock<T> {
63    /// Creates a new reader-writer lock.
64    ///
65    /// # Examples
66    ///
67    /// ```
68    /// use async_std::sync::RwLock;
69    ///
70    /// let lock = RwLock::new(0);
71    /// ```
72    pub fn new(t: T) -> RwLock<T> {
73        RwLock {
74            state: AtomicUsize::new(0),
75            read_wakers: WakerSet::new(),
76            write_wakers: WakerSet::new(),
77            value: UnsafeCell::new(t),
78        }
79    }
80}
81
82impl<T: ?Sized> RwLock<T> {
83    /// Acquires a read lock.
84    ///
85    /// Returns a guard that releases the lock when dropped.
86    ///
87    /// # Examples
88    ///
89    /// ```
90    /// # async_std::task::block_on(async {
91    /// #
92    /// use async_std::sync::RwLock;
93    ///
94    /// let lock = RwLock::new(1);
95    ///
96    /// let n = lock.read().await;
97    /// assert_eq!(*n, 1);
98    ///
99    /// assert!(lock.try_read().is_some());
100    /// #
101    /// # })
102    /// ```
103    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
104        pub struct ReadFuture<'a, T: ?Sized> {
105            lock: &'a RwLock<T>,
106            opt_key: Option<usize>,
107        }
108
109        impl<'a, T: ?Sized> Future for ReadFuture<'a, T> {
110            type Output = RwLockReadGuard<'a, T>;
111
112            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
113                loop {
114                    // If the current task is in the set, remove it.
115                    if let Some(key) = self.opt_key.take() {
116                        self.lock.read_wakers.remove(key);
117                    }
118
119                    // Try acquiring a read lock.
120                    match self.lock.try_read() {
121                        Some(guard) => return Poll::Ready(guard),
122                        None => {
123                            // Insert this lock operation.
124                            self.opt_key = Some(self.lock.read_wakers.insert(cx));
125
126                            // If the lock is still acquired for writing, return.
127                            if self.lock.state.load(Ordering::SeqCst) & WRITE_LOCK != 0 {
128                                return Poll::Pending;
129                            }
130                        }
131                    }
132                }
133            }
134        }
135
136        impl<T: ?Sized> Drop for ReadFuture<'_, T> {
137            fn drop(&mut self) {
138                // If the current task is still in the set, that means it is being cancelled now.
139                if let Some(key) = self.opt_key {
140                    self.lock.read_wakers.cancel(key);
141
142                    // If there are no active readers, notify a blocked writer if none were
143                    // notified already.
144                    if self.lock.state.load(Ordering::SeqCst) & READ_COUNT_MASK == 0 {
145                        self.lock.write_wakers.notify_any();
146                    }
147                }
148            }
149        }
150
151        ReadFuture {
152            lock: self,
153            opt_key: None,
154        }
155        .await
156    }
157
158    /// Attempts to acquire a read lock.
159    ///
160    /// If a read lock could not be acquired at this time, then [`None`] is returned. Otherwise, a
161    /// guard is returned that releases the lock when dropped.
162    ///
163    /// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None
164    ///
165    /// # Examples
166    ///
167    /// ```
168    /// # async_std::task::block_on(async {
169    /// #
170    /// use async_std::sync::RwLock;
171    ///
172    /// let lock = RwLock::new(1);
173    ///
174    /// let n = lock.read().await;
175    /// assert_eq!(*n, 1);
176    ///
177    /// assert!(lock.try_read().is_some());
178    /// #
179    /// # })
180    /// ```
181    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
182        let mut state = self.state.load(Ordering::SeqCst);
183
184        loop {
185            // If a write lock is currently held, then a read lock cannot be acquired.
186            if state & WRITE_LOCK != 0 {
187                return None;
188            }
189
190            // Make sure the number of readers doesn't overflow.
191            if state > isize::MAX as usize {
192                process::abort();
193            }
194
195            // Increment the number of active reads.
196            match self.state.compare_exchange_weak(
197                state,
198                state + ONE_READ,
199                Ordering::SeqCst,
200                Ordering::SeqCst,
201            ) {
202                Ok(_) => return Some(RwLockReadGuard(self)),
203                Err(s) => state = s,
204            }
205        }
206    }
207
208    /// Acquires a write lock.
209    ///
210    /// Returns a guard that releases the lock when dropped.
211    ///
212    /// # Examples
213    ///
214    /// ```
215    /// # async_std::task::block_on(async {
216    /// #
217    /// use async_std::sync::RwLock;
218    ///
219    /// let lock = RwLock::new(1);
220    ///
221    /// let mut n = lock.write().await;
222    /// *n = 2;
223    ///
224    /// assert!(lock.try_read().is_none());
225    /// #
226    /// # })
227    /// ```
228    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
229        pub struct WriteFuture<'a, T: ?Sized> {
230            lock: &'a RwLock<T>,
231            opt_key: Option<usize>,
232        }
233
234        impl<'a, T: ?Sized> Future for WriteFuture<'a, T> {
235            type Output = RwLockWriteGuard<'a, T>;
236
237            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
238                loop {
239                    // If the current task is in the set, remove it.
240                    if let Some(key) = self.opt_key.take() {
241                        self.lock.write_wakers.remove(key);
242                    }
243
244                    // Try acquiring a write lock.
245                    match self.lock.try_write() {
246                        Some(guard) => return Poll::Ready(guard),
247                        None => {
248                            // Insert this lock operation.
249                            self.opt_key = Some(self.lock.write_wakers.insert(cx));
250
251                            // If the lock is still acquired for reading or writing, return.
252                            if self.lock.state.load(Ordering::SeqCst) != 0 {
253                                return Poll::Pending;
254                            }
255                        }
256                    }
257                }
258            }
259        }
260
261        impl<T: ?Sized> Drop for WriteFuture<'_, T> {
262            fn drop(&mut self) {
263                // If the current task is still in the set, that means it is being cancelled now.
264                if let Some(key) = self.opt_key {
265                    if !self.lock.write_wakers.cancel(key) {
266                        // If no other blocked reader was notified, notify all readers.
267                        self.lock.read_wakers.notify_all();
268                    }
269                }
270            }
271        }
272
273        WriteFuture {
274            lock: self,
275            opt_key: None,
276        }
277        .await
278    }
279
280    /// Attempts to acquire a write lock.
281    ///
282    /// If a write lock could not be acquired at this time, then [`None`] is returned. Otherwise, a
283    /// guard is returned that releases the lock when dropped.
284    ///
285    /// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None
286    ///
287    /// # Examples
288    ///
289    /// ```
290    /// # async_std::task::block_on(async {
291    /// #
292    /// use async_std::sync::RwLock;
293    ///
294    /// let lock = RwLock::new(1);
295    ///
296    /// let n = lock.read().await;
297    /// assert_eq!(*n, 1);
298    ///
299    /// assert!(lock.try_write().is_none());
300    /// #
301    /// # })
302    /// ```
303    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
304        if self.state.compare_and_swap(0, WRITE_LOCK, Ordering::SeqCst) == 0 {
305            Some(RwLockWriteGuard(self))
306        } else {
307            None
308        }
309    }
310
311    /// Consumes the lock, returning the underlying data.
312    ///
313    /// # Examples
314    ///
315    /// ```
316    /// use async_std::sync::RwLock;
317    ///
318    /// let lock = RwLock::new(10);
319    /// assert_eq!(lock.into_inner(), 10);
320    /// ```
321    pub fn into_inner(self) -> T where T: Sized {
322        self.value.into_inner()
323    }
324
325    /// Returns a mutable reference to the underlying data.
326    ///
327    /// Since this call borrows the lock mutably, no actual locking takes place -- the mutable
328    /// borrow statically guarantees no locks exist.
329    ///
330    /// # Examples
331    ///
332    /// ```
333    /// # async_std::task::block_on(async {
334    /// #
335    /// use async_std::sync::RwLock;
336    ///
337    /// let mut lock = RwLock::new(0);
338    /// *lock.get_mut() = 10;
339    /// assert_eq!(*lock.write().await, 10);
340    /// #
341    /// # })
342    /// ```
343    pub fn get_mut(&mut self) -> &mut T {
344        unsafe { &mut *self.value.get() }
345    }
346}
347
348impl<T: ?Sized + fmt::Debug> fmt::Debug for RwLock<T> {
349    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350        struct Locked;
351        impl fmt::Debug for Locked {
352            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
353                f.write_str("<locked>")
354            }
355        }
356
357        match self.try_read() {
358            None => f.debug_struct("RwLock").field("data", &Locked).finish(),
359            Some(guard) => f.debug_struct("RwLock").field("data", &&*guard).finish(),
360        }
361    }
362}
363
364impl<T> From<T> for RwLock<T> {
365    fn from(val: T) -> RwLock<T> {
366        RwLock::new(val)
367    }
368}
369
370impl<T: ?Sized + Default> Default for RwLock<T> {
371    fn default() -> RwLock<T> {
372        RwLock::new(Default::default())
373    }
374}
375
376/// A guard that releases the read lock when dropped.
377pub struct RwLockReadGuard<'a, T: ?Sized>(&'a RwLock<T>);
378
379unsafe impl<T: ?Sized + Send> Send for RwLockReadGuard<'_, T> {}
380unsafe impl<T: ?Sized + Sync> Sync for RwLockReadGuard<'_, T> {}
381
382impl<T: ?Sized> Drop for RwLockReadGuard<'_, T> {
383    fn drop(&mut self) {
384        let state = self.0.state.fetch_sub(ONE_READ, Ordering::SeqCst);
385
386        // If this was the last reader, notify a blocked writer if none were notified already.
387        if state & READ_COUNT_MASK == ONE_READ {
388            self.0.write_wakers.notify_any();
389        }
390    }
391}
392
393impl<T: ?Sized + fmt::Debug> fmt::Debug for RwLockReadGuard<'_, T> {
394    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
395        fmt::Debug::fmt(&**self, f)
396    }
397}
398
399impl<T: ?Sized + fmt::Display> fmt::Display for RwLockReadGuard<'_, T> {
400    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
401        (**self).fmt(f)
402    }
403}
404
405impl<T: ?Sized> Deref for RwLockReadGuard<'_, T> {
406    type Target = T;
407
408    fn deref(&self) -> &T {
409        unsafe { &*self.0.value.get() }
410    }
411}
412
413/// A guard that releases the write lock when dropped.
414pub struct RwLockWriteGuard<'a, T: ?Sized>(&'a RwLock<T>);
415
416unsafe impl<T: ?Sized + Send> Send for RwLockWriteGuard<'_, T> {}
417unsafe impl<T: ?Sized + Sync> Sync for RwLockWriteGuard<'_, T> {}
418
419impl<T: ?Sized> Drop for RwLockWriteGuard<'_, T> {
420    fn drop(&mut self) {
421        self.0.state.store(0, Ordering::SeqCst);
422
423        // Notify all blocked readers.
424        if !self.0.read_wakers.notify_all() {
425            // If there were no blocked readers, notify a blocked writer if none were notified
426            // already.
427            self.0.write_wakers.notify_any();
428        }
429    }
430}
431
432impl<T: ?Sized + fmt::Debug> fmt::Debug for RwLockWriteGuard<'_, T> {
433    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
434        fmt::Debug::fmt(&**self, f)
435    }
436}
437
438impl<T: ?Sized + fmt::Display> fmt::Display for RwLockWriteGuard<'_, T> {
439    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
440        (**self).fmt(f)
441    }
442}
443
444impl<T: ?Sized> Deref for RwLockWriteGuard<'_, T> {
445    type Target = T;
446
447    fn deref(&self) -> &T {
448        unsafe { &*self.0.value.get() }
449    }
450}
451
452impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, T> {
453    fn deref_mut(&mut self) -> &mut T {
454        unsafe { &mut *self.0.value.get() }
455    }
456}