screeps_async/sync/
mutex.rs

1use std::cell::{Cell, UnsafeCell};
2use std::future::Future;
3use std::ops::{Deref, DerefMut};
4use std::pin::Pin;
5use std::task::{Context, Poll, Waker};
6
/// An async mutex
///
/// Built on [`Cell`] and [`UnsafeCell`], so it is not `Sync`: share it between tasks
/// running on the same thread, e.g. through an [`Rc`](std::rc::Rc).
///
/// The [`Default`] mutex is unlocked and wraps `T::default()`.
///
/// # Ordering
///
/// When the lock is released, every waiting task is woken in the order it started
/// waiting, so with an executor that polls woken tasks in wake order, locks are
/// acquired in the order they were requested. This is best-effort rather than a strict
/// FIFO hand-off: releasing only marks the mutex unlocked, so whichever task polls first
/// (including a caller of [`try_lock`](Mutex::try_lock)) wins, and a woken task that
/// loses the race re-queues at the back.
///
/// # Examples
/// ```
/// # use std::rc::Rc;
/// # use screeps_async::sync::Mutex;
/// # screeps_async::initialize();
/// let mutex = Rc::new(Mutex::new(0));
/// screeps_async::spawn(async move {
///     let mut val = mutex.lock().await;
///     *val = 1;
/// }).detach();
/// ```
#[derive(Default)]
pub struct Mutex<T> {
    /// Whether the mutex is currently locked.
    ///
    /// Uses [`Cell<bool>`] instead of [`AtomicBool`](std::sync::atomic::AtomicBool) since
    /// atomics are not needed here and [`Cell`] is more general.
    state: Cell<bool>,
    /// Wrapped value; only accessed through a [`MutexGuard`] or [`Mutex::into_inner`].
    data: UnsafeCell<T>,
    /// Wakers of tasks waiting for the lock; all are woken (and cleared) on release.
    wakers: UnsafeCell<Vec<Waker>>,
}
33
34impl<T> Mutex<T> {
35    /// Construct a new [Mutex] in the unlocked state wrapping the given value
36    pub fn new(val: T) -> Self {
37        Self {
38            state: Cell::new(false),
39            data: UnsafeCell::new(val),
40            wakers: UnsafeCell::new(Vec::new()),
41        }
42    }
43
44    /// Acquire the mutex.
45    ///
46    /// Returns a guard that release the mutex when dropped
47    pub fn lock(&self) -> MutexLockFuture<'_, T> {
48        MutexLockFuture::new(self)
49    }
50
51    /// Try to acquire the mutex.
52    ///
53    /// If the mutex could not be acquired at this time return [`None`], otherwise
54    /// returns a guard that will release the mutex when dropped.
55    pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
56        (!self.state.replace(true)).then(|| MutexGuard::new(self))
57    }
58
59    /// Consumes the mutex, returning the underlying data
60    pub fn into_inner(self) -> T {
61        self.data.into_inner()
62    }
63
64    fn unlock(&self) {
65        self.state.set(false);
66        let wakers = unsafe { &mut *self.wakers.get() };
67        wakers.drain(..).for_each(Waker::wake);
68    }
69}
70
71/// An RAII guard that releases the mutex when dropped
72pub struct MutexGuard<'a, T> {
73    lock: &'a Mutex<T>,
74}
75
76impl<'a, T> MutexGuard<'a, T> {
77    fn new(lock: &'a Mutex<T>) -> Self {
78        Self { lock }
79    }
80
81    /// Immediately drops the guard, and consequently unlocks the mutex.
82    ///
83    /// This function is equivalent to calling [`drop`] on the guard but is more self-documenting.
84    pub fn unlock(self) {
85        drop(self);
86    }
87
88    /// Release the lock and immediately yield control back to the async runtime
89    ///
90    /// This essentially just calls [Self::unlock] then [yield_now()](crate::time::yield_now)
91    pub async fn unlock_fair(self) {
92        self.unlock();
93        crate::time::yield_now().await;
94    }
95}
96
97impl<T> Deref for MutexGuard<'_, T> {
98    type Target = T;
99
100    fn deref(&self) -> &Self::Target {
101        unsafe { &*self.lock.data.get() }
102    }
103}
104
105impl<T> DerefMut for MutexGuard<'_, T> {
106    fn deref_mut(&mut self) -> &mut Self::Target {
107        unsafe { &mut *self.lock.data.get() }
108    }
109}
110
111impl<T> Drop for MutexGuard<'_, T> {
112    fn drop(&mut self) {
113        self.lock.unlock();
114    }
115}
116
117/// A [Future] that blocks until the [Mutex] can be locked, then returns the [MutexGuard]
118pub struct MutexLockFuture<'a, T> {
119    mutex: &'a Mutex<T>,
120}
121
122impl<'a, T> MutexLockFuture<'a, T> {
123    fn new(mutex: &'a Mutex<T>) -> Self {
124        Self { mutex }
125    }
126}
127
128impl<'a, T> Future for MutexLockFuture<'a, T> {
129    type Output = MutexGuard<'a, T>;
130
131    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
132        if let Some(val) = self.mutex.try_lock() {
133            return Poll::Ready(val);
134        }
135
136        unsafe {
137            (*self.mutex.wakers.get()).push(cx.waker().clone());
138        }
139
140        Poll::Pending
141    }
142}
143
#[cfg(test)]
mod test {
    use super::*;
    use crate::time::delay_ticks;
    use std::rc::Rc;

    #[test]
    fn single_lock() {
        crate::tests::init_test();

        let shared = Rc::new(Mutex::new(Vec::new()));
        let task_mutex = Rc::clone(&shared);
        crate::spawn(async move {
            task_mutex.lock().await.push(0);
        })
        .detach();

        crate::run().unwrap();

        let contents = Rc::into_inner(shared).unwrap().into_inner();
        assert_eq!(vec![0], contents);
    }

    #[test]
    fn cannot_lock_twice() {
        let mutex = Mutex::new(());
        let first = mutex.try_lock();
        assert!(first.is_some());

        assert!(mutex.try_lock().is_none());
    }

    #[test]
    fn await_multiple_locks() {
        crate::tests::init_test();

        const N: u32 = 10;
        let shared = Rc::new(Mutex::new(Vec::new()));
        for id in 0..N {
            let task_mutex = Rc::clone(&shared);
            crate::spawn(async move {
                let mut items = task_mutex.lock().await;
                // Hold the lock across a tick so the remaining tasks have to wait
                delay_ticks(1).await;
                items.push(id);
            })
            .detach();
        }

        (0..=N).for_each(|_| {
            crate::tests::tick().unwrap();
        });

        let contents = Rc::into_inner(shared).unwrap().into_inner();
        assert_eq!((0..10).collect::<Vec<_>>(), contents);
    }

    #[test]
    fn handles_dropped_futures() {
        crate::tests::init_test();

        let shared = Rc::new(Mutex::new(Vec::new()));

        let holder = Rc::clone(&shared);
        crate::spawn(async move {
            let mut guard = holder.lock().await;
            delay_ticks(1).await;
            guard.push(0);
        })
        .detach();

        let cancelled = Rc::clone(&shared);
        let to_drop = crate::spawn(async move {
            cancelled.lock().await.push(1);
        });

        let survivor = Rc::clone(&shared);
        crate::spawn(async move {
            survivor.lock().await.push(2);
        })
        .detach();

        crate::tests::tick().unwrap();
        drop(to_drop);
        crate::tests::tick().unwrap();

        let contents = Rc::into_inner(shared).unwrap().into_inner();
        assert_eq!(vec![0, 2], contents);
    }
}