1extern crate alloc;
12use crate::sync::{Mutex, RWLock, ReadLockGuard, WriteLockGuard};
13use alloc::{collections::BTreeMap, sync::Arc};
14use core::{
15 arch::asm,
16 future::Future,
17 ops::{Deref, DerefMut},
18 pin::Pin,
19 task::{Context, Poll, Waker},
20};
21
/// An asynchronous reader-writer lock.
///
/// Tasks that cannot take the lock immediately register their waker and are
/// woken again when a guard handed out by this lock is released.
pub struct AsyncRWLock<T> {
    /// Waiter queue, shared (via `Arc` clones) with every guard of this lock.
    inner: Arc<Mutex<AsyncRWLockInner>>,
    /// The protected value behind the crate's (non-async) reader-writer lock.
    data: Arc<RWLock<T>>,
}
35
36impl<T> AsyncRWLock<T> {
37 pub fn new(value: T) -> Self {
39 Self {
40 inner: Arc::new(Mutex::new(AsyncRWLockInner::new())),
41 data: Arc::new(RWLock::new(value)),
42 }
43 }
44
45 pub async fn write(&self) -> AsyncWriteLockGuard<'_, T> {
48 if let Some(guard) = self.data.try_write() {
50 AsyncWriteLockGuard {
52 guard,
53 inner: Arc::clone(&self.inner),
54 }
55 } else {
56 let mut inner = self.inner.lock();
59 let current_id = inner.next_waiter;
60 inner.next_waiter += 1;
61 drop(inner);
62
63 AsyncWriteLockFuture::new(Arc::clone(&self.inner), Arc::clone(&self.data), current_id).await
66 }
67 }
68
69 pub fn write_blocking(&self) -> WriteLockGuard<'_, T> {
70 loop {
71 if let Some(write_guard) = self.data.try_write() {
72 return write_guard;
73 }
74 #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
77 unsafe {
78 asm!("wfe");
79 }
80 }
81 }
82
83 pub async fn read(&self) -> AsyncReadLockGuard<'_, T> {
86 if let Some(guard) = self.data.try_read() {
88 AsyncReadLockGuard {
90 guard,
91 inner: Arc::clone(&self.inner),
92 }
93 } else {
94 let mut inner = self.inner.lock();
97 let current_id = inner.next_waiter;
98 inner.next_waiter += 1;
99 drop(inner);
100
101 AsyncReadLockFuture::new(Arc::clone(&self.inner), Arc::clone(&self.data), current_id).await
104 }
105 }
106
107 pub fn into_inner(self) -> Result<T, Self> {
111 match Arc::try_unwrap(self.data) {
112 Ok(data) => Ok(data.into_inner()),
113 Err(origin) => Err(Self {
114 inner: self.inner,
115 data: origin,
116 }),
117 }
118 }
119}
120
121pub struct AsyncWriteLockGuard<'a, T: 'a> {
122 guard: WriteLockGuard<'a, T>,
123 inner: Arc<Mutex<AsyncRWLockInner>>,
124}
125
126impl<'a, T> Deref for AsyncWriteLockGuard<'a, T> {
127 type Target = WriteLockGuard<'a, T>;
128
129 fn deref(&self) -> &Self::Target {
130 &self.guard
131 }
132}
133
134impl<'a, T> DerefMut for AsyncWriteLockGuard<'a, T> {
135 fn deref_mut(&mut self) -> &mut Self::Target {
136 &mut self.guard
137 }
138}
139
140impl<T> Drop for AsyncWriteLockGuard<'_, T> {
143 fn drop(&mut self) {
144 let mut inner = self.inner.lock();
147 if let Some(&next_waiter) = inner.waiter.keys().next() {
148 let waiter = inner
151 .waiter
152 .remove(&next_waiter)
153 .expect("found key but can't remove it ???");
154 waiter.wake();
155 }
156 }
157}
158
159pub struct AsyncReadLockGuard<'a, T: 'a> {
160 guard: ReadLockGuard<'a, T>,
161 inner: Arc<Mutex<AsyncRWLockInner>>,
162}
163
164impl<'a, T> Deref for AsyncReadLockGuard<'a, T> {
165 type Target = ReadLockGuard<'a, T>;
166
167 fn deref(&self) -> &Self::Target {
168 &self.guard
169 }
170}
171
172impl<T> Drop for AsyncReadLockGuard<'_, T> {
175 fn drop(&mut self) {
176 let mut inner = self.inner.lock();
179 if let Some(&next_waiter) = inner.waiter.keys().next() {
180 let waiter = inner
183 .waiter
184 .remove(&next_waiter)
185 .expect("found key but can't remove it ???");
186 waiter.wake();
187 }
188 }
189}
190struct AsyncWriteLockFuture<'a, T: ?Sized> {
193 inner: Arc<Mutex<AsyncRWLockInner>>,
194 data: Arc<RWLock<T>>,
195 id: usize,
196 _p: core::marker::PhantomData<&'a T>,
197}
198
199impl<T> AsyncWriteLockFuture<'_, T> {
200 fn new(inner: Arc<Mutex<AsyncRWLockInner>>, data: Arc<RWLock<T>>, id: usize) -> Self {
201 Self {
202 inner,
203 data,
204 id,
205 _p: core::marker::PhantomData,
206 }
207 }
208}
209
210impl<'a, T> Future for AsyncWriteLockFuture<'a, T> {
211 type Output = AsyncWriteLockGuard<'a, T>;
212
213 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
214 let this = unsafe { &*(self.get_mut() as *const Self) };
219 if let Some(guard) = this.data.try_write() {
220 Poll::Ready(AsyncWriteLockGuard {
223 guard,
224 inner: Arc::clone(&this.inner),
225 })
226 } else {
227 let mut inner = this.inner.lock();
230 inner.waiter.insert(this.id, cx.waker().clone());
231 drop(inner);
232
233 Poll::Pending
234 }
235 }
236}
237
238struct AsyncReadLockFuture<'a, T> {
241 inner: Arc<Mutex<AsyncRWLockInner>>,
242 data: Arc<RWLock<T>>,
243 id: usize,
244 _p: core::marker::PhantomData<&'a T>,
245}
246
247impl<T> AsyncReadLockFuture<'_, T> {
248 fn new(inner: Arc<Mutex<AsyncRWLockInner>>, data: Arc<RWLock<T>>, id: usize) -> Self {
249 Self {
250 inner,
251 data,
252 id,
253 _p: core::marker::PhantomData,
254 }
255 }
256}
257
258impl<'a, T> Future for AsyncReadLockFuture<'a, T> {
259 type Output = AsyncReadLockGuard<'a, T>;
260
261 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
262 let this = unsafe { &*(self.get_mut() as *const Self) };
267 if let Some(guard) = this.data.try_read() {
268 Poll::Ready(AsyncReadLockGuard {
271 guard,
272 inner: Arc::clone(&this.inner),
273 })
274 } else {
275 let mut inner = this.inner.lock();
278 inner.waiter.insert(this.id, cx.waker().clone());
279 drop(inner);
280
281 Poll::Pending
282 }
283 }
284}
/// Waiter bookkeeping shared by an `AsyncRWLock` and the guards it hands out.
struct AsyncRWLockInner {
    /// Wakers of parked tasks, keyed by queue position.
    waiter: BTreeMap<usize, Waker>,
    /// Queue position that the next parked task will receive.
    next_waiter: usize,
}

impl AsyncRWLockInner {
    /// Returns an empty queue whose first handed-out position is `0`.
    fn new() -> Self {
        let waiter = BTreeMap::new();
        let next_waiter = 0;
        Self {
            waiter,
            next_waiter,
        }
    }
}
302
#[cfg(testing)]
mod tests {
    use super::*;
    use async_std::prelude::*;
    use async_std::task;
    use core::time::Duration;

    // `AsyncRWLock` has no `lock()` method; exclusive access is `write()`.

    #[async_std::test]
    #[ignore = "test leads sometimes to deadlock on travis-ci for an unknown reason"]
    async fn wait_on_rwlock_write() {
        let rwlock = Arc::new(AsyncRWLock::new(10_u32));
        let rwlock_clone = Arc::clone(&rwlock);

        // task1 takes the write lock first and keeps it across a sleep.
        let task1 = task::spawn(async move {
            let mut guard = rwlock_clone.write().await;
            **guard = 20;
            task::yield_now().await;
            task::sleep(Duration::from_secs(1)).await;
        });

        // task2 starts later, has to wait, and must observe task1's write.
        let task2 = task::spawn(async move {
            task::yield_now().await;
            task::sleep(Duration::from_secs(1)).await;
            let guard = rwlock.write().await;
            let value = **guard;
            assert_eq!(20, value);
        });

        task1.join(task2).await;
    }

    #[async_std::test]
    #[ignore = "test leads sometimes to deadlock on travis-ci for an unknown reason"]
    async fn wait_on_rwlock_read() {
        let rwlock = Arc::new(AsyncRWLock::new(10_u32));
        let rwlock_clone = Arc::clone(&rwlock);

        // task1 writes and keeps the write lock across a sleep.
        let task1 = task::spawn(async move {
            let mut guard = rwlock_clone.write().await;
            **guard = 20;
            task::yield_now().await;
            task::sleep(Duration::from_secs(1)).await;
        });

        // task2's read has to wait for task1 and must see the new value.
        let task2 = task::spawn(async move {
            task::yield_now().await;
            task::sleep(Duration::from_secs(1)).await;
            let guard = rwlock.read().await;
            let value = **guard;
            assert_eq!(20, value);
        });

        task1.join(task2).await;
    }

    #[async_std::test]
    #[ignore = "test leads sometimes to deadlock on travis-ci for an unknown reason"]
    async fn wait_on_rwlock_write_after_read() {
        let rwlock = Arc::new(AsyncRWLock::new(10_u32));
        let rwlock_clone = Arc::clone(&rwlock);
        let rwlock_clone2 = Arc::clone(&rwlock);

        // task1 holds a read lock long enough for task2's write to queue up.
        let task1 = task::spawn(async move {
            let guard = rwlock_clone.read().await;
            task::sleep(Duration::from_secs(10)).await;
            println!("{}", **guard);
        });

        // task2's write must wait until task1's read guard is released.
        let task2 = task::spawn(async move {
            task::sleep(Duration::from_secs(5)).await;
            let mut guard = rwlock.write().await;
            **guard = 20;
        });

        task1.join(task2).await;

        let guard = rwlock_clone2.read().await;
        assert_eq!(20, **guard);
    }

    #[test]
    fn rwlock_to_inner() {
        let rwlock = AsyncRWLock::new(10);
        let inner = rwlock.into_inner();
        match inner {
            Ok(data) => assert_eq!(data, 10),
            _ => panic!("unable to get inner data"),
        }
    }
}