Skip to main content

xiaoyong_value/unsync/
async_mutex.rs

1//! Single-threaded, asynchronous mutual exclusion lock.
2
3use std::{
4    cell::{
5        Cell,
6        UnsafeCell,
7    },
8    ops::{
9        Deref,
10        DerefMut,
11    },
12    pin::Pin,
13    task::{
14        Context,
15        Poll,
16        Waker,
17    },
18};
19
20use smallvec::SmallVec;
21
22/// Asynchronous, single-threaded Mutex.
23///
24/// **Thread Safety:** This Mutex is built on Cell and UnsafeCell and does not
25/// use Rc. It automatically implements Send if T: Send, so it can be moved
26/// across threads. However, its guard (MutexGuard) is explicitly !Send.
27pub struct Mutex<T: ?Sized> {
28    is_locked: Cell<bool>,
29    next_id:   Cell<usize>,
30    waiters:   Cell<SmallVec<[(usize, Waker); 8]>>,
31    value:     UnsafeCell<T>,
32}
33
34impl<T> Mutex<T> {
35    /// Creates a new instance.
36    pub fn new(value: T) -> Self {
37        Self {
38            is_locked: Cell::new(false),
39            next_id:   Cell::new(0),
40            waiters:   Cell::new(SmallVec::new()),
41            value:     UnsafeCell::new(value),
42        }
43    }
44}
45
46impl<T: ?Sized> Mutex<T> {
47    /// Get a raw pointer to the underlying data.
48    pub fn value_ptr(&self) -> *mut T {
49        self.value.get()
50    }
51    
52    /// Acquire the lock.
53    pub async fn lock(&self) -> MutexGuard<'_, T> {
54        LockFuture {
55            mutex: self,
56            id:    None,
57        }
58        .await
59    }
60
61    /// Try to acquire the lock without blocking.
62    pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
63        if !self.is_locked.get() {
64            self.is_locked.set(true);
65            Some(MutexGuard {
66                mutex: self
67            })
68        } else {
69            None
70        }
71    }
72}
73
74/// Future that resolves to a mutex guard when the lock is acquired.
75pub struct LockFuture<'a, T: ?Sized> {
76    mutex: &'a Mutex<T>,
77    id:    Option<usize>,
78}
79
80impl<'a, T: ?Sized> Future for LockFuture<'a, T> {
81    type Output = MutexGuard<'a, T>;
82
83    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
84        if !self.mutex.is_locked.get() {
85            self.mutex.is_locked.set(true);
86
87            // Success: Clean up our queue entry if we were previously pending
88            if let Some(id) = self.id {
89                let mut queue = self.mutex.waiters.take();
90                queue.retain(|(w_id, _)| *w_id != id);
91                self.mutex.waiters.set(queue);
92
93                // Disable the Drop handler since we successfully acquired the lock
94                self.id = None;
95            }
96
97            Poll::Ready(MutexGuard {
98                mutex: self.mutex
99            })
100        } else {
101            // Assign a unique ID on the first Pending poll
102            let id = self.id.unwrap_or_else(|| {
103                let new_id = self.mutex.next_id.get();
104                self.mutex.next_id.set(new_id.wrapping_add(1));
105                self.id = Some(new_id);
106                new_id
107            });
108
109            let mut queue = self.mutex.waiters.take();
110
111            // Update the waker if we're already in the queue, else push
112            match queue.iter_mut().find(|(i, _)| *i == id) {
113                | Some(entry) => {
114                    if !entry.1.will_wake(cx.waker()) {
115                        entry.1 = cx.waker().clone();
116                    }
117                },
118                | None => {
119                    queue.push((id, cx.waker().clone()));
120                },
121            }
122
123            self.mutex.waiters.set(queue);
124            Poll::Pending
125        }
126    }
127}
128
129impl<'a, T: ?Sized> Drop for LockFuture<'a, T> {
130    fn drop(&mut self) {
131        if let Some(id) = self.id {
132            let mut queue = self.mutex.waiters.take();
133
134            // Remove this specific future from the wait queue
135            queue.retain(|(w_id, _)| *w_id != id);
136
137            // Prevent Lost Wakeups
138            if !self.mutex.is_locked.get() {
139                if let Some((_, next_waker)) = queue.first() {
140                    next_waker.wake_by_ref();
141                }
142            }
143
144            self.mutex.waiters.set(queue);
145        }
146    }
147}
148
149/// An RAII guard that provides mutable access to the protected data.
150pub struct MutexGuard<'a, T: ?Sized> {
151    mutex: &'a Mutex<T>,
152}
153
154impl<'a, T: ?Sized> Deref for MutexGuard<'a, T> {
155    type Target = T;
156
157    fn deref(&self) -> &Self::Target {
158        unsafe { &*self.mutex.value.get() }
159    }
160}
161
162impl<'a, T: ?Sized> DerefMut for MutexGuard<'a, T> {
163    fn deref_mut(&mut self) -> &mut Self::Target {
164        unsafe { &mut *self.mutex.value.get() }
165    }
166}
167
168impl<'a, T: ?Sized> Drop for MutexGuard<'a, T> {
169    fn drop(&mut self) {
170        self.mutex.is_locked.set(false);
171
172        let queue = self.mutex.waiters.take();
173        let next_waker = queue.first().map(|(_, waker)| waker.clone());
174        self.mutex.waiters.set(queue);
175
176        if let Some(waker) = next_waker {
177            waker.wake();
178        }
179    }
180}
181
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use tokio::task;

    use super::*;

    /// Uncontended hand-off between two tasks.
    #[tokio::test]
    async fn async_mutex() {
        let local = task::LocalSet::new();
        local
            .run_until(async move {
                let mutex = Rc::new(Mutex::new(0));

                let m1 = Rc::clone(&mutex);
                task::spawn_local(async move {
                    let mut guard = m1.lock().await;
                    *guard += 1;
                })
                .await
                .unwrap();

                let guard = mutex.lock().await;
                assert_eq!(*guard, 1);
            })
            .await;
    }

    /// `try_lock` must refuse a held lock and succeed again once it is released.
    #[test]
    fn try_lock_is_exclusive() {
        let mutex = Mutex::new(5);

        let guard = mutex.try_lock().expect("an unlocked mutex must be acquirable");
        assert!(mutex.try_lock().is_none(), "a held mutex must not be acquirable twice");
        drop(guard);

        let guard = mutex.try_lock().expect("a released mutex must be acquirable again");
        assert_eq!(*guard, 5);
    }

    /// Tasks queued behind a held lock must all eventually acquire it.
    #[tokio::test]
    async fn queued_waiters_all_acquire_after_release() {
        let local = task::LocalSet::new();
        local
            .run_until(async {
                let mutex = Rc::new(Mutex::new(Vec::new()));
                let guard = mutex.lock().await;

                let handles: Vec<_> = (0..3)
                    .map(|i| {
                        let m = Rc::clone(&mutex);
                        task::spawn_local(async move {
                            m.lock().await.push(i);
                        })
                    })
                    .collect();

                // Give the spawned tasks a chance to run and queue up behind `guard`.
                task::yield_now().await;
                drop(guard);

                for handle in handles {
                    handle.await.unwrap();
                }

                let mut seen = mutex.lock().await.clone();
                seen.sort_unstable();
                assert_eq!(seen, vec![0, 1, 2]);
            })
            .await;
    }

    /// Cancelling a queued waiter must not strand the waiters behind it.
    #[tokio::test]
    async fn cancelled_waiter_forwards_wakeup() {
        let local = task::LocalSet::new();
        local
            .run_until(async {
                let mutex = Rc::new(Mutex::new(0));
                let guard = mutex.lock().await;

                let m = Rc::clone(&mutex);
                let cancelled = task::spawn_local(async move {
                    let _guard = m.lock().await;
                });
                let m = Rc::clone(&mutex);
                let survivor = task::spawn_local(async move {
                    *m.lock().await += 1;
                });

                // Let both tasks queue, then cancel the one at the head.
                task::yield_now().await;
                cancelled.abort();
                drop(guard);

                survivor.await.unwrap();
                assert_eq!(*mutex.lock().await, 1);
            })
            .await;
    }
}