xiaoyong_value/unsync/
async_mutex.rs1use std::{
4 cell::{
5 Cell,
6 UnsafeCell,
7 },
8 ops::{
9 Deref,
10 DerefMut,
11 },
12 pin::Pin,
13 task::{
14 Context,
15 Poll,
16 Waker,
17 },
18};
19
20use smallvec::SmallVec;
21
22pub struct Mutex<T: ?Sized> {
28 is_locked: Cell<bool>,
29 next_id: Cell<usize>,
30 waiters: Cell<SmallVec<[(usize, Waker); 8]>>,
31 value: UnsafeCell<T>,
32}
33
34impl<T> Mutex<T> {
35 pub fn new(value: T) -> Self {
37 Self {
38 is_locked: Cell::new(false),
39 next_id: Cell::new(0),
40 waiters: Cell::new(SmallVec::new()),
41 value: UnsafeCell::new(value),
42 }
43 }
44}
45
46impl<T: ?Sized> Mutex<T> {
47 pub fn value_ptr(&self) -> *mut T {
49 self.value.get()
50 }
51
52 pub async fn lock(&self) -> MutexGuard<'_, T> {
54 LockFuture {
55 mutex: self,
56 id: None,
57 }
58 .await
59 }
60
61 pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
63 if !self.is_locked.get() {
64 self.is_locked.set(true);
65 Some(MutexGuard {
66 mutex: self
67 })
68 } else {
69 None
70 }
71 }
72}
73
74pub struct LockFuture<'a, T: ?Sized> {
76 mutex: &'a Mutex<T>,
77 id: Option<usize>,
78}
79
80impl<'a, T: ?Sized> Future for LockFuture<'a, T> {
81 type Output = MutexGuard<'a, T>;
82
83 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
84 if !self.mutex.is_locked.get() {
85 self.mutex.is_locked.set(true);
86
87 if let Some(id) = self.id {
89 let mut queue = self.mutex.waiters.take();
90 queue.retain(|(w_id, _)| *w_id != id);
91 self.mutex.waiters.set(queue);
92
93 self.id = None;
95 }
96
97 Poll::Ready(MutexGuard {
98 mutex: self.mutex
99 })
100 } else {
101 let id = self.id.unwrap_or_else(|| {
103 let new_id = self.mutex.next_id.get();
104 self.mutex.next_id.set(new_id.wrapping_add(1));
105 self.id = Some(new_id);
106 new_id
107 });
108
109 let mut queue = self.mutex.waiters.take();
110
111 match queue.iter_mut().find(|(i, _)| *i == id) {
113 | Some(entry) => {
114 if !entry.1.will_wake(cx.waker()) {
115 entry.1 = cx.waker().clone();
116 }
117 },
118 | None => {
119 queue.push((id, cx.waker().clone()));
120 },
121 }
122
123 self.mutex.waiters.set(queue);
124 Poll::Pending
125 }
126 }
127}
128
129impl<'a, T: ?Sized> Drop for LockFuture<'a, T> {
130 fn drop(&mut self) {
131 if let Some(id) = self.id {
132 let mut queue = self.mutex.waiters.take();
133
134 queue.retain(|(w_id, _)| *w_id != id);
136
137 if !self.mutex.is_locked.get() {
139 if let Some((_, next_waker)) = queue.first() {
140 next_waker.wake_by_ref();
141 }
142 }
143
144 self.mutex.waiters.set(queue);
145 }
146 }
147}
148
149pub struct MutexGuard<'a, T: ?Sized> {
151 mutex: &'a Mutex<T>,
152}
153
154impl<'a, T: ?Sized> Deref for MutexGuard<'a, T> {
155 type Target = T;
156
157 fn deref(&self) -> &Self::Target {
158 unsafe { &*self.mutex.value.get() }
159 }
160}
161
162impl<'a, T: ?Sized> DerefMut for MutexGuard<'a, T> {
163 fn deref_mut(&mut self) -> &mut Self::Target {
164 unsafe { &mut *self.mutex.value.get() }
165 }
166}
167
168impl<'a, T: ?Sized> Drop for MutexGuard<'a, T> {
169 fn drop(&mut self) {
170 self.mutex.is_locked.set(false);
171
172 let queue = self.mutex.waiters.take();
173 let next_waker = queue.first().map(|(_, waker)| waker.clone());
174 self.mutex.waiters.set(queue);
175
176 if let Some(waker) = next_waker {
177 waker.wake();
178 }
179 }
180}
181
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use tokio::task;

    use super::*;

    /// A spawned local task increments the value under the lock; the spawning task then
    /// locks the mutex itself and observes the increment.
    #[tokio::test]
    async fn async_mutex() {
        let scenario = async move {
            let shared = Rc::new(Mutex::new(0));

            let writer = {
                let shared = Rc::clone(&shared);
                task::spawn_local(async move {
                    let mut value = shared.lock().await;
                    *value += 1;
                })
            };
            writer.await.unwrap();

            let value = shared.lock().await;
            assert_eq!(*value, 1);
        };

        // `spawn_local` needs a `LocalSet` context, which `run_until` provides.
        let local = task::LocalSet::new();
        local.run_until(scenario).await;
    }
}