1use std::{
4 cell::{
5 Cell,
6 UnsafeCell,
7 },
8 future::Future,
9 ops::{
10 Deref,
11 DerefMut,
12 },
13 pin::Pin,
14 task::{
15 Context,
16 Poll,
17 Waker,
18 },
19};
20
21use smallvec::SmallVec;
22
/// `state` value meaning no guard of either kind is currently held.
const UNLOCKED: usize = 0;
/// Sentinel `state` value meaning a write guard is held. Any other non-zero
/// `state` is the number of live read guards (incremented on read acquire,
/// decremented when a read guard drops).
const WRITE_LOCKED: usize = usize::MAX;
25
/// Kind of access a queued future is waiting for.
///
/// The wake-up policy depends on it: a writer at the head of the queue is woken
/// alone, otherwise every reader up to the first writer is woken together.
#[derive(Clone, Copy, Debug, PartialEq)]
enum WaiterType {
    /// Waiting for shared access.
    Read,
    /// Waiting for exclusive access.
    Write,
}
31
/// An async reader-writer lock for single-threaded use.
///
/// All bookkeeping lives in `Cell`/`UnsafeCell`, so the type is `!Sync` and is
/// meant to be shared within one thread (for example through `Rc`). Futures
/// that have to wait are appended to `waiters` and are woken when the lock is
/// released.
pub struct RwLock<T: ?Sized> {
    /// `UNLOCKED`, `WRITE_LOCKED`, or otherwise the count of live read guards.
    state: Cell<usize>,
    /// Next id handed to a pending read/write future; wraps on overflow.
    next_id: Cell<usize>,
    /// Waiting futures as `(id, kind, waker)`, pushed at the back when a future
    /// first has to wait and removed when it acquires the lock or is dropped.
    waiters: Cell<SmallVec<[(usize, WaiterType, Waker); 8]>>,
    /// The protected value.
    value: UnsafeCell<T>,
}
44
45impl<T> RwLock<T> {
46 pub fn new(value: T) -> Self {
48 Self {
49 state: Cell::new(UNLOCKED),
50 next_id: Cell::new(0),
51 waiters: Cell::new(SmallVec::new()),
52 value: UnsafeCell::new(value),
53 }
54 }
55}
56
57impl<T: ?Sized> RwLock<T> {
58 pub fn value_ptr(&self) -> *mut T {
60 self.value.get()
61 }
62
63 pub async fn read(&self) -> RwLockReadGuard<'_, T> {
65 RwLockReadFuture {
66 lock: self, id: None
67 }
68 .await
69 }
70
71 pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
73 RwLockWriteFuture {
74 lock: self, id: None
75 }
76 .await
77 }
78
79 pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
81 let s = self.state.get();
82 if s != WRITE_LOCKED {
83 self.state.set(s + 1);
84 Some(RwLockReadGuard {
85 lock: self
86 })
87 } else {
88 None
89 }
90 }
91
92 pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
94 if self.state.get() == UNLOCKED {
95 self.state.set(WRITE_LOCKED);
96 Some(RwLockWriteGuard {
97 lock: self
98 })
99 } else {
100 None
101 }
102 }
103
104 fn wake_next(&self) {
107 let queue = self.waiters.take();
108
109 match queue.first() {
110 | Some((_, WaiterType::Write, waker)) => {
111 waker.wake_by_ref();
112 },
113 | _ => {
114 for (_, typ, waker) in queue.iter() {
115 if *typ == WaiterType::Read {
116 waker.wake_by_ref();
117 } else {
118 break;
119 }
120 }
121 },
122 }
123
124 self.waiters.set(queue);
125 }
126}
127
/// Future awaited by [`RwLock::read`]; resolves to a [`RwLockReadGuard`].
pub struct RwLockReadFuture<'a, T: ?Sized> {
    /// The lock being acquired.
    lock: &'a RwLock<T>,
    /// Queue id. It is assigned from `next_id` on first poll and reset to `None`
    /// once the guard is handed out, so `Drop` only touches the queue while the
    /// future may still be waiting.
    id: Option<usize>,
}
133
134impl<'a, T: ?Sized> Future for RwLockReadFuture<'a, T> {
135 type Output = RwLockReadGuard<'a, T>;
136
137 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
138 let s = self.lock.state.get();
139
140 let id = self.id.unwrap_or_else(|| {
141 let new_id = self.lock.next_id.get();
142 self.lock.next_id.set(new_id.wrapping_add(1));
143 self.id = Some(new_id);
144 new_id
145 });
146
147 let mut queue = self.lock.waiters.take();
148
149 let has_writer_ahead = queue
151 .iter()
152 .take_while(|(i, ..)| *i != id)
153 .any(|(_, typ, _)| *typ == WaiterType::Write);
154
155 if s != WRITE_LOCKED && !has_writer_ahead {
156 self.lock.state.set(s + 1);
157 queue.retain(|(i, ..)| *i != id);
158 self.lock.waiters.set(queue);
159 self.id = None;
160 return Poll::Ready(RwLockReadGuard {
161 lock: self.lock
162 });
163 }
164
165 match queue.iter_mut().find(|(i, ..)| *i == id) {
166 | Some(entry) => {
167 if !entry.2.will_wake(cx.waker()) {
168 entry.2 = cx.waker().clone();
169 }
170 },
171 | None => {
172 queue.push((id, WaiterType::Read, cx.waker().clone()));
173 },
174 }
175
176 self.lock.waiters.set(queue);
177 Poll::Pending
178 }
179}
180
181impl<'a, T: ?Sized> Drop for RwLockReadFuture<'a, T> {
182 fn drop(&mut self) {
183 if let Some(id) = self.id {
184 let mut queue = self.lock.waiters.take();
185 let was_first = queue.first().map_or(false, |(i, ..)| *i == id);
186 queue.retain(|(i, ..)| *i != id);
187 self.lock.waiters.set(queue);
188
189 if was_first && self.lock.state.get() == UNLOCKED {
191 self.lock.wake_next();
192 }
193 }
194 }
195}
196
/// Future awaited by [`RwLock::write`]; resolves to a [`RwLockWriteGuard`].
pub struct RwLockWriteFuture<'a, T: ?Sized> {
    /// The lock being acquired.
    lock: &'a RwLock<T>,
    /// Queue id. It is assigned from `next_id` on first poll and reset to `None`
    /// once the guard is handed out, so `Drop` only touches the queue while the
    /// future may still be waiting.
    id: Option<usize>,
}
202
203impl<'a, T: ?Sized> Future for RwLockWriteFuture<'a, T> {
204 type Output = RwLockWriteGuard<'a, T>;
205
206 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
207 let s = self.lock.state.get();
208
209 let id = self.id.unwrap_or_else(|| {
210 let new_id = self.lock.next_id.get();
211 self.lock.next_id.set(new_id.wrapping_add(1));
212 self.id = Some(new_id);
213 new_id
214 });
215
216 let mut queue = self.lock.waiters.take();
217 let is_first = queue.first().map_or(true, |(i, ..)| *i == id);
218
219 if s == UNLOCKED && is_first {
220 self.lock.state.set(WRITE_LOCKED);
221 queue.retain(|(i, ..)| *i != id);
222 self.lock.waiters.set(queue);
223 self.id = None;
224 return Poll::Ready(RwLockWriteGuard {
225 lock: self.lock
226 });
227 }
228
229 match queue.iter_mut().find(|(i, ..)| *i == id) {
230 | Some(entry) => {
231 if !entry.2.will_wake(cx.waker()) {
232 entry.2 = cx.waker().clone();
233 }
234 },
235 | None => {
236 queue.push((id, WaiterType::Write, cx.waker().clone()));
237 },
238 }
239
240 self.lock.waiters.set(queue);
241 Poll::Pending
242 }
243}
244
245impl<'a, T: ?Sized> Drop for RwLockWriteFuture<'a, T> {
246 fn drop(&mut self) {
247 if let Some(id) = self.id {
248 let mut queue = self.lock.waiters.take();
249 let was_first = queue.first().map_or(false, |(i, ..)| *i == id);
250 queue.retain(|(i, ..)| *i != id);
251 self.lock.waiters.set(queue);
252
253 if was_first && self.lock.state.get() == UNLOCKED {
254 self.lock.wake_next();
255 }
256 }
257 }
258}
259
260pub struct RwLockReadGuard<'a, T: ?Sized> {
262 lock: &'a RwLock<T>,
263}
264
265impl<'a, T: ?Sized> Deref for RwLockReadGuard<'a, T> {
266 type Target = T;
267
268 fn deref(&self) -> &Self::Target {
269 unsafe { &*self.lock.value.get() }
270 }
271}
272
273impl<'a, T: ?Sized> Drop for RwLockReadGuard<'a, T> {
274 fn drop(&mut self) {
275 let s = self.lock.state.get();
276 self.lock.state.set(s - 1);
277
278 if self.lock.state.get() == UNLOCKED {
279 self.lock.wake_next();
280 }
281 }
282}
283
284pub struct RwLockWriteGuard<'a, T: ?Sized> {
286 lock: &'a RwLock<T>,
287}
288
289impl<'a, T: ?Sized> Deref for RwLockWriteGuard<'a, T> {
290 type Target = T;
291
292 fn deref(&self) -> &Self::Target {
293 unsafe { &*self.lock.value.get() }
294 }
295}
296
297impl<'a, T: ?Sized> DerefMut for RwLockWriteGuard<'a, T> {
298 fn deref_mut(&mut self) -> &mut Self::Target {
299 unsafe { &mut *self.lock.value.get() }
300 }
301}
302
303impl<'a, T: ?Sized> Drop for RwLockWriteGuard<'a, T> {
304 fn drop(&mut self) {
305 self.lock.state.set(UNLOCKED);
306 self.lock.wake_next();
307 }
308}
309
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use super::*;

    /// A writer spawned as a local task publishes a value that a later reader
    /// observes.
    ///
    /// `tokio::task::spawn_local` panics unless it is called inside a
    /// `LocalSet`, and `#[tokio::test]` does not provide one. The body therefore
    /// runs under `LocalSet::run_until`.
    #[tokio::test]
    async fn async_rwlock() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let rwlock = Rc::new(RwLock::new(0));

                let r1 = Rc::clone(&rwlock);
                tokio::task::spawn_local(async move {
                    let mut guard = r1.write().await;
                    *guard = 42;
                })
                .await
                .unwrap();

                let guard = rwlock.read().await;
                assert_eq!(*guard, 42);
            })
            .await;
    }

    /// Readers share the lock and exclude writers; a writer excludes everyone.
    #[test]
    fn try_lock_exclusion() {
        let lock = RwLock::new(1);

        let r1 = lock.try_read().expect("an unlocked lock must be readable");
        let r2 = lock.try_read().expect("readers share the lock");
        assert!(lock.try_write().is_none());
        drop((r1, r2));

        let mut w = lock.try_write().expect("a released lock must be writable");
        *w = 2;
        assert!(lock.try_read().is_none());
        assert!(lock.try_write().is_none());
        drop(w);

        assert_eq!(*lock.try_read().expect("a released lock must be readable"), 2);
    }
}
331}