// screeps_async/sync/mutex.rs
use std::cell::{Cell, UnsafeCell};
use std::future::Future;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll, Waker};

7pub struct Mutex<T> {
23 state: Cell<bool>,
28 data: UnsafeCell<T>,
30 wakers: UnsafeCell<Vec<Waker>>,
32}
33
34impl<T> Mutex<T> {
35 pub fn new(val: T) -> Self {
37 Self {
38 state: Cell::new(false),
39 data: UnsafeCell::new(val),
40 wakers: UnsafeCell::new(Vec::new()),
41 }
42 }
43
44 pub fn lock(&self) -> MutexLockFuture<'_, T> {
48 MutexLockFuture::new(self)
49 }
50
51 pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
56 (!self.state.replace(true)).then(|| MutexGuard::new(self))
57 }
58
59 pub fn into_inner(self) -> T {
61 self.data.into_inner()
62 }
63
64 fn unlock(&self) {
65 self.state.set(false);
66 let wakers = unsafe { &mut *self.wakers.get() };
67 wakers.drain(..).for_each(Waker::wake);
68 }
69}
70
71pub struct MutexGuard<'a, T> {
73 lock: &'a Mutex<T>,
74}
75
76impl<'a, T> MutexGuard<'a, T> {
77 fn new(lock: &'a Mutex<T>) -> Self {
78 Self { lock }
79 }
80
81 pub fn unlock(self) {
85 drop(self);
86 }
87
88 pub async fn unlock_fair(self) {
92 self.unlock();
93 crate::time::yield_now().await;
94 }
95}
96
97impl<T> Deref for MutexGuard<'_, T> {
98 type Target = T;
99
100 fn deref(&self) -> &Self::Target {
101 unsafe { &*self.lock.data.get() }
102 }
103}
104
105impl<T> DerefMut for MutexGuard<'_, T> {
106 fn deref_mut(&mut self) -> &mut Self::Target {
107 unsafe { &mut *self.lock.data.get() }
108 }
109}
110
111impl<T> Drop for MutexGuard<'_, T> {
112 fn drop(&mut self) {
113 self.lock.unlock();
114 }
115}
116
117pub struct MutexLockFuture<'a, T> {
119 mutex: &'a Mutex<T>,
120}
121
122impl<'a, T> MutexLockFuture<'a, T> {
123 fn new(mutex: &'a Mutex<T>) -> Self {
124 Self { mutex }
125 }
126}
127
128impl<'a, T> Future for MutexLockFuture<'a, T> {
129 type Output = MutexGuard<'a, T>;
130
131 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
132 if let Some(val) = self.mutex.try_lock() {
133 return Poll::Ready(val);
134 }
135
136 unsafe {
137 (*self.mutex.wakers.get()).push(cx.waker().clone());
138 }
139
140 Poll::Pending
141 }
142}
143
#[cfg(test)]
mod test {
    use super::*;
    use crate::time::delay_ticks;
    use std::rc::Rc;

    /// A single task can lock the mutex, mutate the value, and release it.
    #[test]
    fn single_lock() {
        crate::tests::init_test();

        let mutex = Rc::new(Mutex::new(vec![]));
        {
            let mutex = Rc::clone(&mutex);
            crate::spawn(async move {
                let mut vec = mutex.lock().await;
                vec.push(0);
            })
            .detach();
        }

        crate::run().unwrap();

        let expected = vec![0];
        let actual = Rc::into_inner(mutex).unwrap().into_inner();
        assert_eq!(expected, actual);
    }

    /// `try_lock` fails while a guard is alive.
    #[test]
    fn cannot_lock_twice() {
        let mutex = Mutex::new(());
        let _guard = mutex.try_lock().unwrap();

        assert!(mutex.try_lock().is_none());
    }

    /// Dropping the guard releases the lock so it can be taken again.
    #[test]
    fn can_relock_after_guard_dropped() {
        let mutex = Mutex::new(());
        drop(mutex.try_lock().unwrap());

        assert!(mutex.try_lock().is_some());
    }

    /// `N` tasks each hold the lock across a tick. Every one of them eventually gets the lock,
    /// and their pushes land in spawn order.
    #[test]
    fn await_multiple_locks() {
        crate::tests::init_test();

        let mutex = Rc::new(Mutex::new(vec![]));
        const N: u32 = 10;
        for i in 0..N {
            let mutex = Rc::clone(&mutex);
            crate::spawn(async move {
                let mut vec = mutex.lock().await;
                // Hold the lock across a tick so the other tasks have to wait for it.
                delay_ticks(1).await;
                vec.push(i);
            })
            .detach();
        }

        for _ in 0..=N {
            crate::tests::tick().unwrap();
        }

        let expected = (0..N).collect::<Vec<_>>();
        let actual = Rc::into_inner(mutex).unwrap().into_inner();
        assert_eq!(expected, actual);
    }

    /// A waiting task that is dropped before it acquires the lock must not stop the remaining
    /// waiters from getting it.
    #[test]
    fn handles_dropped_futures() {
        crate::tests::init_test();

        let mutex = Rc::new(Mutex::new(vec![]));
        {
            let mutex = Rc::clone(&mutex);
            crate::spawn(async move {
                let mut guard = mutex.lock().await;
                delay_ticks(1).await;
                guard.push(0);
            })
            .detach();
        }
        let to_drop = {
            let mutex = Rc::clone(&mutex);
            crate::spawn(async move {
                let mut guard = mutex.lock().await;
                guard.push(1);
            })
        };
        {
            let mutex = Rc::clone(&mutex);
            crate::spawn(async move {
                let mut guard = mutex.lock().await;
                guard.push(2);
            })
            .detach();
        }

        crate::tests::tick().unwrap();
        // Drop the second task before it ever acquires the lock.
        drop(to_drop);
        crate::tests::tick().unwrap();

        let expected = vec![0, 2];
        let actual = Rc::into_inner(mutex).unwrap().into_inner();

        assert_eq!(expected, actual);
    }
}