screeps_async/sync/
rwlock.rs

1use std::cell::{Ref, RefCell, RefMut, UnsafeCell};
2use std::future::Future;
3use std::ops::{Deref, DerefMut};
4use std::pin::Pin;
5use std::rc::Rc;
6use std::task::{Context, Poll, Waker};
7
/// An async read-write lock.
///
/// Any number of [read](RwLock::read) locks may be held at once, or a single
/// [write](RwLock::write) lock. While any task is waiting on a write lock, no new read locks
/// can be acquired. Whenever a lock guard is released, every waiting task is woken (waiting
/// writers first, then waiting readers) and retries its acquisition.
///
/// Not thread-safe: the state lives in a `RefCell` and `UnsafeCell`s, so the lock is `!Sync`.
pub struct RwLock<T> {
    /// The wrapped value; its `RefCell` borrow state records which guards are currently held
    inner: RefCell<T>,
    /// Wakers of tasks waiting to acquire a read lock
    read_wakers: UnsafeCell<Vec<Waker>>,
    /// Wakers of tasks waiting to acquire a write lock; while non-empty, new reads are refused
    write_wakers: UnsafeCell<Vec<Waker>>,
}
20
21impl<T> RwLock<T> {
22    /// Construct a new [RwLock] wrapping `val`
23    pub fn new(val: T) -> Self {
24        Self {
25            inner: RefCell::new(val),
26            read_wakers: UnsafeCell::new(Vec::new()),
27            write_wakers: UnsafeCell::new(Vec::new()),
28        }
29    }
30
31    /// Block until the wrapped value can be immutably borrowed
32    pub fn read(&self) -> RwLockFuture<'_, T, RwLockReadGuard<'_, T>> {
33        RwLockFuture {
34            lock: self,
35            borrow: Self::try_read,
36            is_writer: false,
37        }
38    }
39
40    /// Attempt to immutably borrow the wrapped value.
41    ///
42    /// Returns [None] if the value is currently mutably borrowed or
43    /// a task is waiting on a mutable reference.
44    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
45        unsafe { RwLockReadGuard::new(self) }
46    }
47
48    /// Block until the wrapped value can be mutably borrowed
49    pub fn write(&self) -> RwLockFuture<'_, T, RwLockWriteGuard<'_, T>> {
50        RwLockFuture {
51            lock: self,
52            borrow: Self::try_write,
53            is_writer: true,
54        }
55    }
56
57    /// Attempt to mutably borrow the wrapped value.
58    ///
59    /// Returns [None] if the value is already borrowed (mutably or immutably)
60    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
61        RwLockWriteGuard::new(self)
62    }
63
64    /// Consumes this [RwLock] and returns ownership of the wrapped value
65    pub fn into_inner(self) -> T {
66        self.inner.into_inner()
67    }
68
69    /// Convenience method to consume [`Rc<RwLock<T>>`] and return the wrapped value
70    ///
71    /// # Panics
72    /// This method panics if the Rc has more than one strong reference
73    pub fn into_inner_rc(self: Rc<Self>) -> T {
74        Rc::into_inner(self).unwrap().into_inner()
75    }
76}
77
78impl<T> RwLock<T> {
79    unsafe fn unlock(&self) {
80        let wakers = &mut *self.write_wakers.get();
81        wakers.drain(..).for_each(Waker::wake);
82
83        let wakers = &mut *self.read_wakers.get();
84        wakers.drain(..).for_each(Waker::wake);
85    }
86}
87
88/// An RAII guard that releases a read lock when dropped
89pub struct RwLockReadGuard<'a, T> {
90    inner: &'a RwLock<T>,
91    data: Ref<'a, T>,
92}
93
94impl<'a, T> RwLockReadGuard<'a, T> {
95    unsafe fn new(lock: &'a RwLock<T>) -> Option<Self> {
96        if !(*lock.write_wakers.get()).is_empty() {
97            return None; // Cannot take new reads if a writer is waiting
98        }
99
100        let data = lock.inner.try_borrow().ok()?;
101
102        Some(RwLockReadGuard { data, inner: lock })
103    }
104}
105
106impl<T> Drop for RwLockReadGuard<'_, T> {
107    fn drop(&mut self) {
108        unsafe { self.inner.unlock() }
109    }
110}
111
112impl<T> Deref for RwLockReadGuard<'_, T> {
113    type Target = T;
114
115    fn deref(&self) -> &Self::Target {
116        &self.data
117    }
118}
119
120/// An RAII guard that releases the write lock when dropped
121pub struct RwLockWriteGuard<'a, T> {
122    inner: &'a RwLock<T>,
123    data: RefMut<'a, T>,
124}
125
126impl<'a, T> RwLockWriteGuard<'a, T> {
127    fn new(lock: &'a RwLock<T>) -> Option<Self> {
128        let data = lock.inner.try_borrow_mut().ok()?;
129
130        Some(Self { inner: lock, data })
131    }
132
133    /// Immediately drop the guard and release the write lock
134    ///
135    /// Equivalent to [drop(self)], but is more self-documenting
136    pub fn unlock(self) {
137        drop(self);
138    }
139
140    /// Release the write lock and immediately yield control back to the async runtime
141    ///
142    /// This essentially just calls [Self::unlock] then [yield_now()](crate::time::yield_now)
143    pub async fn unlock_fair(self) {
144        self.unlock();
145        crate::time::yield_now().await;
146    }
147}
148
149impl<T> Drop for RwLockWriteGuard<'_, T> {
150    fn drop(&mut self) {
151        unsafe { self.inner.unlock() }
152    }
153}
154
155impl<T> Deref for RwLockWriteGuard<'_, T> {
156    type Target = T;
157
158    fn deref(&self) -> &Self::Target {
159        &self.data
160    }
161}
162
163impl<T> DerefMut for RwLockWriteGuard<'_, T> {
164    fn deref_mut(&mut self) -> &mut Self::Target {
165        &mut self.data
166    }
167}
168
169/// A [Future] that blocks until the [RwLock] can be acquired.
170pub struct RwLockFuture<'a, T, G> {
171    lock: &'a RwLock<T>,
172    borrow: fn(&'a RwLock<T>) -> Option<G>,
173    is_writer: bool,
174}
175
176impl<T, G> Future for RwLockFuture<'_, T, G> {
177    type Output = G;
178
179    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
180        if let Some(guard) = (self.borrow)(self.lock) {
181            return Poll::Ready(guard);
182        }
183
184        let wakers = if self.is_writer {
185            self.lock.write_wakers.get()
186        } else {
187            self.lock.read_wakers.get()
188        };
189        let wakers = unsafe { &mut *wakers };
190
191        wakers.push(cx.waker().clone());
192
193        Poll::Pending
194    }
195}
196
// Scheduling tests for `RwLock`. Each test spawns tasks, drives them with
// `crate::tests::tick()`, and has every task assert the game tick
// (`crate::tests::game_time()`) at which it acquired its guard.
#[cfg(test)]
mod test {
    use super::*;
    use crate::time::delay_ticks;

    // N readers all acquire on tick 0 and hold their guards across `delay_ticks(1)`,
    // so the read locks must be able to overlap.
    #[test]
    fn can_read_multiple_times() {
        crate::tests::init_test();

        let lock = Rc::new(RwLock::new(()));
        const N: usize = 10;
        for _ in 0..N {
            let lock = lock.clone();
            crate::spawn(async move {
                let _guard = lock.read().await;
                // Lock should acquire first tick
                assert_eq!(0, crate::tests::game_time());
                // don't release till next tick to check if we can hold multiple read locks at once
                delay_ticks(1).await;
            })
            .detach();
        }

        // NOTE(review): N + 1 ticks is more than the two the readers appear to need; harmless.
        for _ in 0..=N {
            crate::tests::tick().unwrap();
        }
    }

    // Two writers: the first acquires on tick 0 and holds across a tick; the second can only
    // acquire afterwards (asserted at tick 1). Each increments once, so the result is 2.
    #[test]
    fn cannot_write_multiple_times() {
        crate::tests::init_test();

        let lock = Rc::new(RwLock::new(0));
        {
            let lock = lock.clone();
            crate::spawn(async move {
                let mut guard = lock.write().await;
                assert_eq!(0, crate::tests::game_time());
                delay_ticks(1).await;
                *guard += 1;
            })
            .detach();
        }
        {
            let lock = lock.clone();
            crate::spawn(async move {
                let mut guard = lock.write().await;
                assert_eq!(1, crate::tests::game_time());
                delay_ticks(1).await;
                *guard += 1;
            })
            .detach();
        }

        crate::tests::tick().unwrap();
        crate::tests::tick().unwrap();
        crate::tests::tick().unwrap();

        // `into_inner_rc` panics unless every spawned task has dropped its Rc clone.
        assert_eq!(2, lock.into_inner_rc());
    }

    // Spawn order: writer 1, reader 1, writer 2, reader 2. Writer 2 is queued before the
    // readers can get in, so both readers acquire at tick 2, after writer 2, and observe 2.
    #[test]
    fn cannot_read_while_writer_waiting() {
        crate::tests::init_test();

        let lock = Rc::new(RwLock::new(0));
        {
            let lock = lock.clone();
            crate::spawn(async move {
                let mut guard = lock.write().await;
                println!("write 1 acquired");
                assert_eq!(0, crate::tests::game_time());
                delay_ticks(1).await;
                *guard += 1;
            })
            .detach();
        }
        {
            let lock = lock.clone();
            crate::spawn(async move {
                let guard = lock.read().await;
                println!("read 1 acquired");
                // this should happen after second write
                assert_eq!(2, crate::tests::game_time());
                delay_ticks(1).await;
                assert_eq!(2, *guard);
            })
            .detach();
        }
        {
            let lock = lock.clone();
            crate::spawn(async move {
                let mut guard = lock.write().await;
                println!("write 2 acquired");
                assert_eq!(1, crate::tests::game_time());
                delay_ticks(1).await;
                *guard += 1;
            })
            .detach();
        }
        {
            let lock = lock.clone();
            crate::spawn(async move {
                let guard = lock.read().await;
                println!("read 2 acquired");
                assert_eq!(2, crate::tests::game_time());
                assert_eq!(2, *guard);
            })
            .detach();
        }

        crate::tests::tick().unwrap();
        crate::tests::tick().unwrap();
        crate::tests::tick().unwrap();
        crate::tests::tick().unwrap();

        assert_eq!(2, lock.into_inner_rc());
    }
}