// xiaoyong_value/unsync/async_rwlock.rs
1//! A single-threaded, asynchronous reader-writer lock.
2
3use std::{
4    cell::{
5        Cell,
6        UnsafeCell,
7    },
8    future::Future,
9    ops::{
10        Deref,
11        DerefMut,
12    },
13    pin::Pin,
14    task::{
15        Context,
16        Poll,
17        Waker,
18    },
19};
20
21use smallvec::SmallVec;
22
/// State value for a lock with no readers and no writer.
/// Any value strictly between `UNLOCKED` and `WRITE_LOCKED` is the count
/// of currently active readers.
const UNLOCKED: usize = 0;
/// Sentinel state value marking exclusive (write) ownership.
const WRITE_LOCKED: usize = usize::MAX;
25
/// Kind of access a queued waiter is blocked on; `wake_next` uses this to
/// wake either one head writer or the contiguous run of readers at the
/// front of the queue.
#[derive(Debug, Clone, Copy, PartialEq)]
enum WaiterType {
    /// Waiting for shared (read) access.
    Read,
    /// Waiting for exclusive (write) access.
    Write,
}
31
/// Asynchronous, single-threaded Reader-Writer Lock.
///
/// The `state` cell encodes the lock mode: `UNLOCKED` (0), a positive count
/// of active readers, or `WRITE_LOCKED` (`usize::MAX`) for exclusive access.
///
/// **Thread Safety:** This lock is built on Cell and UnsafeCell and does not
/// use Rc. It automatically implements Send if T: Send, so it can be moved
/// across threads. However, its guards (RwLockReadGuard and RwLockWriteGuard)
/// are !Send.
/// NOTE(review): no explicit negative `Send` impl for the guards is visible
/// in this file — they appear to be `!Send` only via auto-trait propagation
/// (they borrow a `!Sync` lock); confirm against the rest of the crate.
pub struct RwLock<T: ?Sized> {
    /// Lock mode: 0 = unlocked, n = n active readers, `usize::MAX` = writer.
    state:   Cell<usize>,
    /// Wrapping ticket counter; each pending future takes a unique id used
    /// to find and remove its entry in `waiters`.
    next_id: Cell<usize>,
    /// FIFO wait queue of (ticket id, kind, waker). Accessed by take/set
    /// pairs, since `Cell` offers no in-place borrow.
    waiters: Cell<SmallVec<[(usize, WaiterType, Waker); 8]>>,
    /// The protected value.
    value:   UnsafeCell<T>,
}
44
45impl<T> RwLock<T> {
46    /// Creates a new instance.
47    pub fn new(value: T) -> Self {
48        Self {
49            state:   Cell::new(UNLOCKED),
50            next_id: Cell::new(0),
51            waiters: Cell::new(SmallVec::new()),
52            value:   UnsafeCell::new(value),
53        }
54    }
55}
56
impl<T: ?Sized> RwLock<T> {
    /// Get a raw pointer to the underlying data.
    ///
    /// The pointer is handed out without taking the lock; callers are
    /// responsible for upholding the aliasing rules themselves.
    pub fn value_ptr(&self) -> *mut T {
        self.value.get()
    }

    /// Acquires the lock for reading asynchronously.
    ///
    /// The returned future will not barge past writers already waiting in
    /// the queue (see `RwLockReadFuture::poll`).
    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
        RwLockReadFuture {
            lock: self, id: None
        }
        .await
    }

    /// Acquires the lock for writing asynchronously.
    ///
    /// The returned future acquires only when it is at the head of the
    /// waiter queue (or the queue is empty), giving FIFO ordering.
    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
        RwLockWriteFuture {
            lock: self, id: None
        }
        .await
    }

    /// Attempts to acquire the lock for reading without blocking.
    ///
    /// Succeeds (incrementing the reader count) unless a writer currently
    /// holds the lock. Note: this does not consult the waiter queue, so it
    /// can barge ahead of queued writers.
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
        let s = self.state.get();
        if s != WRITE_LOCKED {
            self.state.set(s + 1);
            Some(RwLockReadGuard {
                lock: self
            })
        } else {
            None
        }
    }

    /// Attempts to acquire the lock for writing without blocking.
    ///
    /// Succeeds only when there are no readers and no writer. Like
    /// `try_read`, this ignores the waiter queue.
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
        if self.state.get() == UNLOCKED {
            self.state.set(WRITE_LOCKED);
            Some(RwLockWriteGuard {
                lock: self
            })
        } else {
            None
        }
    }

    /// Wakes the next eligible tasks. If a writer is first, wakes it.
    /// If a reader is first, wakes ALL contiguous readers at the front.
    ///
    /// Entries are NOT removed here: a woken future deregisters itself in
    /// its own `poll` once it actually acquires (or in `Drop` if cancelled).
    fn wake_next(&self) {
        // Move the queue out of the Cell for inspection, then put it back.
        let queue = self.waiters.take();

        match queue.first() {
            | Some((_, WaiterType::Write, waker)) => {
                waker.wake_by_ref();
            },
            | _ => {
                // Head is a reader (or the queue is empty): wake the whole
                // contiguous run of readers — they can all share the lock.
                for (_, typ, waker) in queue.iter() {
                    if *typ == WaiterType::Read {
                        waker.wake_by_ref();
                    } else {
                        break;
                    }
                }
            },
        }

        self.waiters.set(queue);
    }
}
127
/// Future that resolves to a read guard when the lock is acquired for reading.
pub struct RwLockReadFuture<'a, T: ?Sized> {
    /// The lock being acquired.
    lock: &'a RwLock<T>,
    /// Ticket id while registered in the wait queue; `None` before the
    /// first poll and after acquisition (so `Drop` knows whether to clean up).
    id:   Option<usize>,
}
133
impl<'a, T: ?Sized> Future for RwLockReadFuture<'a, T> {
    type Output = RwLockReadGuard<'a, T>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let s = self.lock.state.get();

        // Lazily claim a ticket id on the first poll; it identifies this
        // future's queue entry on later polls and in Drop.
        let id = self.id.unwrap_or_else(|| {
            let new_id = self.lock.next_id.get();
            self.lock.next_id.set(new_id.wrapping_add(1));
            self.id = Some(new_id);
            new_id
        });

        // The queue lives in a `Cell`; move it out, put it back before return.
        let mut queue = self.lock.waiters.take();

        // Ensure no writers are ahead of us to prevent writer starvation
        // (on a first poll `id` is not in the queue, so this scans the whole
        // queue — any queued writer blocks a brand-new reader).
        let has_writer_ahead = queue
            .iter()
            .take_while(|(i, ..)| *i != id)
            .any(|(_, typ, _)| *typ == WaiterType::Write);

        if s != WRITE_LOCKED && !has_writer_ahead {
            // Acquire shared access: bump the reader count, deregister from
            // the queue, and clear `id` so Drop performs no cleanup.
            self.lock.state.set(s + 1);
            queue.retain(|(i, ..)| *i != id);
            self.lock.waiters.set(queue);
            self.id = None;
            return Poll::Ready(RwLockReadGuard {
                lock: self.lock
            });
        }

        // Still blocked: register, or refresh the stored waker only if it
        // would not wake the current task.
        match queue.iter_mut().find(|(i, ..)| *i == id) {
            | Some(entry) => {
                if !entry.2.will_wake(cx.waker()) {
                    entry.2 = cx.waker().clone();
                }
            },
            | None => {
                queue.push((id, WaiterType::Read, cx.waker().clone()));
            },
        }

        self.lock.waiters.set(queue);
        Poll::Pending
    }
}
180
impl<'a, T: ?Sized> Drop for RwLockReadFuture<'a, T> {
    /// Cancellation cleanup: a future dropped before acquiring must remove
    /// its queue entry, or it would block waiters behind it forever.
    fn drop(&mut self) {
        // `id` is Some only while this future may be registered; a future
        // that completed cleared it in `poll`.
        if let Some(id) = self.id {
            let mut queue = self.lock.waiters.take();
            let was_first = queue.first().map_or(false, |(i, ..)| *i == id);
            queue.retain(|(i, ..)| *i != id);
            self.lock.waiters.set(queue);

            // Pass the baton if we were blocking a wakeup chain
            // (e.g. we were woken by an unlock but got cancelled before
            // polling — whoever is now at the head must be woken instead).
            if was_first && self.lock.state.get() == UNLOCKED {
                self.lock.wake_next();
            }
        }
    }
}
196
/// Future that resolves to a write guard when the lock is acquired for writing.
pub struct RwLockWriteFuture<'a, T: ?Sized> {
    /// The lock being acquired.
    lock: &'a RwLock<T>,
    /// Ticket id while registered in the wait queue; `None` before the
    /// first poll and after acquisition (so `Drop` knows whether to clean up).
    id:   Option<usize>,
}
202
impl<'a, T: ?Sized> Future for RwLockWriteFuture<'a, T> {
    type Output = RwLockWriteGuard<'a, T>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let s = self.lock.state.get();

        // Lazily claim a ticket id on the first poll; it identifies this
        // future's queue entry on later polls and in Drop.
        let id = self.id.unwrap_or_else(|| {
            let new_id = self.lock.next_id.get();
            self.lock.next_id.set(new_id.wrapping_add(1));
            self.id = Some(new_id);
            new_id
        });

        let mut queue = self.lock.waiters.take();
        // FIFO fairness: a writer acquires only from the head of the queue;
        // an empty queue counts as being first.
        let is_first = queue.first().map_or(true, |(i, ..)| *i == id);

        if s == UNLOCKED && is_first {
            // Acquire exclusive access, deregister, and clear `id` so Drop
            // performs no cleanup.
            self.lock.state.set(WRITE_LOCKED);
            queue.retain(|(i, ..)| *i != id);
            self.lock.waiters.set(queue);
            self.id = None;
            return Poll::Ready(RwLockWriteGuard {
                lock: self.lock
            });
        }

        // Still blocked: register, or refresh the stored waker only if it
        // would not wake the current task.
        match queue.iter_mut().find(|(i, ..)| *i == id) {
            | Some(entry) => {
                if !entry.2.will_wake(cx.waker()) {
                    entry.2 = cx.waker().clone();
                }
            },
            | None => {
                queue.push((id, WaiterType::Write, cx.waker().clone()));
            },
        }

        self.lock.waiters.set(queue);
        Poll::Pending
    }
}
244
245impl<'a, T: ?Sized> Drop for RwLockWriteFuture<'a, T> {
246    fn drop(&mut self) {
247        if let Some(id) = self.id {
248            let mut queue = self.lock.waiters.take();
249            let was_first = queue.first().map_or(false, |(i, ..)| *i == id);
250            queue.retain(|(i, ..)| *i != id);
251            self.lock.waiters.set(queue);
252
253            if was_first && self.lock.state.get() == UNLOCKED {
254                self.lock.wake_next();
255            }
256        }
257    }
258}
259
/// An RAII guard that provides shared read access to the protected data.
/// Releases one reader hold on the lock when dropped.
pub struct RwLockReadGuard<'a, T: ?Sized> {
    /// The lock this guard holds a shared claim on.
    lock: &'a RwLock<T>,
}
264
impl<'a, T: ?Sized> Deref for RwLockReadGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // SAFETY: while this guard exists the lock state records at least
        // one reader, so no write guard (the only source of `&mut T`) can
        // coexist with this shared borrow.
        unsafe { &*self.lock.value.get() }
    }
}
272
273impl<'a, T: ?Sized> Drop for RwLockReadGuard<'a, T> {
274    fn drop(&mut self) {
275        let s = self.lock.state.get();
276        self.lock.state.set(s - 1);
277
278        if self.lock.state.get() == UNLOCKED {
279            self.lock.wake_next();
280        }
281    }
282}
283
/// An RAII guard that provides exclusive write access to the protected data.
/// Releases the write lock (and wakes waiters) when dropped.
pub struct RwLockWriteGuard<'a, T: ?Sized> {
    /// The lock this guard holds exclusively.
    lock: &'a RwLock<T>,
}
288
impl<'a, T: ?Sized> Deref for RwLockWriteGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // SAFETY: while this guard exists the lock state is WRITE_LOCKED,
        // so no other guard can alias the value.
        unsafe { &*self.lock.value.get() }
    }
}
296
impl<'a, T: ?Sized> DerefMut for RwLockWriteGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: while this guard exists the lock state is WRITE_LOCKED,
        // so this is the sole access path to the value.
        unsafe { &mut *self.lock.value.get() }
    }
}
302
impl<'a, T: ?Sized> Drop for RwLockWriteGuard<'a, T> {
    /// Releases exclusive ownership and wakes the next eligible waiter(s):
    /// a head writer, or the contiguous run of readers at the queue front.
    fn drop(&mut self) {
        self.lock.state.set(UNLOCKED);
        self.lock.wake_next();
    }
}
309
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use super::*;

    /// Smoke test: a spawned local task takes the write lock and mutates
    /// the value; the main task then observes it through a read guard.
    ///
    /// BUGFIX: `tokio::task::spawn_local` panics unless called from within
    /// a `tokio::task::LocalSet` (the default `#[tokio::test]` runtime does
    /// not provide one), so the test body runs inside `LocalSet::run_until`.
    #[tokio::test]
    async fn async_rwlock() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let rwlock = Rc::new(RwLock::new(0));

                let r1 = Rc::clone(&rwlock);
                tokio::task::spawn_local(async move {
                    let mut guard = r1.write().await;
                    *guard = 42;
                })
                .await
                .unwrap();

                let guard = rwlock.read().await;
                assert_eq!(*guard, 42);
            })
            .await;
    }
}