screeps_async/sync/
rwlock.rs1use std::cell::{Ref, RefCell, RefMut, UnsafeCell};
2use std::future::Future;
3use std::ops::{Deref, DerefMut};
4use std::pin::Pin;
5use std::rc::Rc;
6use std::task::{Context, Poll, Waker};
7
8pub struct RwLock<T> {
13 inner: RefCell<T>,
15 read_wakers: UnsafeCell<Vec<Waker>>,
17 write_wakers: UnsafeCell<Vec<Waker>>,
19}
20
21impl<T> RwLock<T> {
22 pub fn new(val: T) -> Self {
24 Self {
25 inner: RefCell::new(val),
26 read_wakers: UnsafeCell::new(Vec::new()),
27 write_wakers: UnsafeCell::new(Vec::new()),
28 }
29 }
30
31 pub fn read(&self) -> RwLockFuture<'_, T, RwLockReadGuard<'_, T>> {
33 RwLockFuture {
34 lock: self,
35 borrow: Self::try_read,
36 is_writer: false,
37 }
38 }
39
40 pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
45 unsafe { RwLockReadGuard::new(self) }
46 }
47
48 pub fn write(&self) -> RwLockFuture<'_, T, RwLockWriteGuard<'_, T>> {
50 RwLockFuture {
51 lock: self,
52 borrow: Self::try_write,
53 is_writer: true,
54 }
55 }
56
57 pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
61 RwLockWriteGuard::new(self)
62 }
63
64 pub fn into_inner(self) -> T {
66 self.inner.into_inner()
67 }
68
69 pub fn into_inner_rc(self: Rc<Self>) -> T {
74 Rc::into_inner(self).unwrap().into_inner()
75 }
76}
77
78impl<T> RwLock<T> {
79 unsafe fn unlock(&self) {
80 let wakers = &mut *self.write_wakers.get();
81 wakers.drain(..).for_each(Waker::wake);
82
83 let wakers = &mut *self.read_wakers.get();
84 wakers.drain(..).for_each(Waker::wake);
85 }
86}
87
88pub struct RwLockReadGuard<'a, T> {
90 inner: &'a RwLock<T>,
91 data: Ref<'a, T>,
92}
93
94impl<'a, T> RwLockReadGuard<'a, T> {
95 unsafe fn new(lock: &'a RwLock<T>) -> Option<Self> {
96 if !(*lock.write_wakers.get()).is_empty() {
97 return None; }
99
100 let data = lock.inner.try_borrow().ok()?;
101
102 Some(RwLockReadGuard { data, inner: lock })
103 }
104}
105
106impl<T> Drop for RwLockReadGuard<'_, T> {
107 fn drop(&mut self) {
108 unsafe { self.inner.unlock() }
109 }
110}
111
112impl<T> Deref for RwLockReadGuard<'_, T> {
113 type Target = T;
114
115 fn deref(&self) -> &Self::Target {
116 &self.data
117 }
118}
119
120pub struct RwLockWriteGuard<'a, T> {
122 inner: &'a RwLock<T>,
123 data: RefMut<'a, T>,
124}
125
126impl<'a, T> RwLockWriteGuard<'a, T> {
127 fn new(lock: &'a RwLock<T>) -> Option<Self> {
128 let data = lock.inner.try_borrow_mut().ok()?;
129
130 Some(Self { inner: lock, data })
131 }
132
133 pub fn unlock(self) {
137 drop(self);
138 }
139
140 pub async fn unlock_fair(self) {
144 self.unlock();
145 crate::time::yield_now().await;
146 }
147}
148
149impl<T> Drop for RwLockWriteGuard<'_, T> {
150 fn drop(&mut self) {
151 unsafe { self.inner.unlock() }
152 }
153}
154
155impl<T> Deref for RwLockWriteGuard<'_, T> {
156 type Target = T;
157
158 fn deref(&self) -> &Self::Target {
159 &self.data
160 }
161}
162
163impl<T> DerefMut for RwLockWriteGuard<'_, T> {
164 fn deref_mut(&mut self) -> &mut Self::Target {
165 &mut self.data
166 }
167}
168
169pub struct RwLockFuture<'a, T, G> {
171 lock: &'a RwLock<T>,
172 borrow: fn(&'a RwLock<T>) -> Option<G>,
173 is_writer: bool,
174}
175
176impl<T, G> Future for RwLockFuture<'_, T, G> {
177 type Output = G;
178
179 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
180 if let Some(guard) = (self.borrow)(self.lock) {
181 return Poll::Ready(guard);
182 }
183
184 let wakers = if self.is_writer {
185 self.lock.write_wakers.get()
186 } else {
187 self.lock.read_wakers.get()
188 };
189 let wakers = unsafe { &mut *wakers };
190
191 wakers.push(cx.waker().clone());
192
193 Poll::Pending
194 }
195}
196
#[cfg(test)]
mod test {
    use super::*;
    use crate::time::delay_ticks;

    #[test]
    fn can_read_multiple_times() {
        crate::tests::init_test();

        let lock = Rc::new(RwLock::new(()));
        const N: usize = 10;
        for _ in 0..N {
            let lock = lock.clone();
            crate::spawn(async move {
                let _guard = lock.read().await;
                // Every reader acquires on the first tick: readers share the lock.
                assert_eq!(0, crate::tests::game_time());
                delay_ticks(1).await;
            })
            .detach();
        }

        for _ in 0..=N {
            crate::tests::tick().unwrap();
        }
    }

    #[test]
    fn cannot_write_multiple_times() {
        crate::tests::init_test();

        let lock = Rc::new(RwLock::new(0));
        {
            let lock = lock.clone();
            crate::spawn(async move {
                let mut guard = lock.write().await;
                assert_eq!(0, crate::tests::game_time());
                delay_ticks(1).await;
                *guard += 1;
            })
            .detach();
        }
        {
            let lock = lock.clone();
            crate::spawn(async move {
                let mut guard = lock.write().await;
                // Only acquired once the first writer releases on tick 1.
                assert_eq!(1, crate::tests::game_time());
                delay_ticks(1).await;
                *guard += 1;
            })
            .detach();
        }

        crate::tests::tick().unwrap();
        crate::tests::tick().unwrap();
        crate::tests::tick().unwrap();

        assert_eq!(2, lock.into_inner_rc());
    }

    /// Expected timeline:
    /// - tick 0: writer 1 acquires. Reader 1 is blocked by the mutable borrow, writer 2
    ///   queues, and reader 2 is refused because a writer is queued.
    /// - tick 1: writer 1 releases and writer 2 (woken first) acquires.
    /// - tick 2: writer 2 releases, and both readers acquire and observe the value 2.
    #[test]
    fn cannot_read_while_writer_waiting() {
        crate::tests::init_test();

        let lock = Rc::new(RwLock::new(0));
        {
            let lock = lock.clone();
            crate::spawn(async move {
                let mut guard = lock.write().await;
                assert_eq!(0, crate::tests::game_time());
                delay_ticks(1).await;
                *guard += 1;
            })
            .detach();
        }
        {
            let lock = lock.clone();
            crate::spawn(async move {
                let guard = lock.read().await;
                assert_eq!(2, crate::tests::game_time());
                delay_ticks(1).await;
                assert_eq!(2, *guard);
            })
            .detach();
        }
        {
            let lock = lock.clone();
            crate::spawn(async move {
                let mut guard = lock.write().await;
                assert_eq!(1, crate::tests::game_time());
                delay_ticks(1).await;
                *guard += 1;
            })
            .detach();
        }
        {
            let lock = lock.clone();
            crate::spawn(async move {
                let guard = lock.read().await;
                assert_eq!(2, crate::tests::game_time());
                assert_eq!(2, *guard);
            })
            .detach();
        }

        crate::tests::tick().unwrap();
        crate::tests::tick().unwrap();
        crate::tests::tick().unwrap();
        crate::tests::tick().unwrap();

        assert_eq!(2, lock.into_inner_rc());
    }

    /// The non-blocking API enforces reader/writer exclusion without needing the runtime.
    #[test]
    fn try_lock_respects_outstanding_guards() {
        let lock = RwLock::new(5);

        let first = lock.try_read().expect("an unlocked lock should allow reading");
        let second = lock.try_read().expect("readers should share the lock");
        assert!(lock.try_write().is_none(), "a writer must wait for readers");
        assert_eq!(10, *first + *second);
        drop((first, second));

        let mut writer = lock
            .try_write()
            .expect("the lock should be free once readers are dropped");
        assert!(lock.try_read().is_none(), "a reader must wait for the writer");
        assert!(lock.try_write().is_none(), "only one writer at a time");
        *writer += 1;
        drop(writer);

        assert_eq!(6, lock.into_inner());
    }
}