ruspiro_lock/async/
asyncmutex.rs

1/***********************************************************************************************************************
2 * Copyright (c) 2020 by the authors
3 *
4 * Author: André Borrmann <pspwizard@gmx.de>
5 * License: Apache License 2.0 / MIT
6 **********************************************************************************************************************/
7
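//! # Async Mutex
//!
//! An [AsyncMutex] whose `lock()` returns a future, so async code that has to wait for the lock yields to its
//! executor instead of blocking.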
10
11extern crate alloc;
12use crate::sync::{Mutex, MutexGuard};
13use alloc::{collections::BTreeMap, sync::Arc};
14use core::{
15  future::Future,
16  ops::{Deref, DerefMut},
17  pin::Pin,
18  task::{Context, Poll, Waker},
19};
20
21/// An async mutex lock that can be used in async functions to prevent blocking current execution while waiting for the
22/// lock to become available. So for this to work the `lock()` method does not return a MutexGuard immediately but a
23/// [Future] that will resove into a [AsyncMutexGuard] when `await`ed.
24pub struct AsyncMutex<T> {
25  /// The inner wrapper to the actual [Mutex] requires to be secured with a [Mutex] on it's own
26  /// as we require mutual exclusive access to it. This actually should not harm any concurrent blocking
27  /// as this is a short living lock that will be only aquired to request the actual lock status. So it is
28  /// more then unlikely that this will happen in parallel at the same time
29  inner: Arc<Mutex<AsyncMutexInner>>,
30  /// The actual [Mutex] securing the contained data for mutual exclusive access
31  data: Arc<Mutex<T>>,
32}
33
34impl<T> AsyncMutex<T> {
35  /// Create the [AsyncMutex]
36  pub fn new(value: T) -> Self {
37    Self {
38      inner: Arc::new(Mutex::new(AsyncMutexInner::new())),
39      data: Arc::new(Mutex::new(value)),
40    }
41  }
42
43  /// Locking the data secured by the [AsyncMutex] will yield a `Future` that must be awaited to actually acquire
44  /// the lock.
45  pub async fn lock(&self) -> AsyncMutexGuard<'_, T> {
46    // check if we could immediately get the lock
47    if let Some(guard) = self.data.try_lock() {
48      // lock immediatly acquired, provide the lock guard as result
49      AsyncMutexGuard {
50        guard,
51        inner: Arc::clone(&self.inner),
52      }
53    } else {
54      // to be able to request the lock we require to upate the inner metadata. For this to work we require a
55      // short living exclusive lock to this data.
56      let mut inner = self.inner.lock();
57      let current_id = inner.next_waiter;
58      inner.next_waiter += 1;
59      drop(inner);
60
61      // once we have updated the metadata we can release the lock to it and create the `Future` that will yield
62      // the lock to the data once available
63      AsyncMutexFuture::new(Arc::clone(&self.inner), Arc::clone(&self.data), current_id).await
64    }
65  }
66
67  /// Provide the inner data wrapped by this [AsyncMutex]. This will only provide the contained data if there is only
68  /// one active reference to it. If the data is still shared more than once, eg. because there are active `Future`s
69  /// awaiting a lock this will return the actual `AsyncMutex` in the `Err` variant.
70  pub fn into_inner(self) -> Result<T, Self>
71  where
72    T: Sized,
73  {
74    match Arc::try_unwrap(self.data) {
75      Ok(data) => Ok(data.into_inner()),
76      Err(origin) => Err(Self {
77        inner: self.inner,
78        data: origin,
79      }),
80    }
81  }
82}
83
84pub struct AsyncMutexGuard<'a, T: 'a> {
85  guard: MutexGuard<'a, T>,
86  inner: Arc<Mutex<AsyncMutexInner>>,
87}
88
89impl<'a, T> Deref for AsyncMutexGuard<'a, T> {
90  type Target = MutexGuard<'a, T>;
91
92  fn deref(&self) -> &Self::Target {
93    &self.guard
94  }
95}
96
97impl<'a, T> DerefMut for AsyncMutexGuard<'a, T> {
98  fn deref_mut(&mut self) -> &mut Self::Target {
99    &mut self.guard
100  }
101}
102
103/// If an [AsyncMutexGuard] get's dropped we need to wake the `Future`s that might hav registered themself and
104/// are waiting to aquire the lock.
105impl<T> Drop for AsyncMutexGuard<'_, T> {
106  fn drop(&mut self) {
107    // if the mutex guard is about to be locked we need to check if there has been a waker send
108    // already to get woken
109    let mut inner = self.inner.lock();
110    if let Some(&next_waiter) = inner.waiter.keys().next() {
111      // remove the waker from the waiter list as it will re-register itself when the corresponding
112      // Future is polled and can't acquire the lock
113      let waiter = inner
114        .waiter
115        .remove(&next_waiter)
116        .expect("found key but can't remove it ???");
117      waiter.wake_by_ref();
118    }
119  }
120}
121
122/// The `Future` that represents an `await`able [AsynMutex] and can only be created from the functions of [AsyncMutex].
123struct AsyncMutexFuture<'a, T: 'a> {
124  inner: Arc<Mutex<AsyncMutexInner>>,
125  data: Arc<Mutex<T>>,
126  id: usize,
127  _p: core::marker::PhantomData<&'a T>,
128}
129
130impl<T> AsyncMutexFuture<'_, T> {
131  fn new(inner: Arc<Mutex<AsyncMutexInner>>, data: Arc<Mutex<T>>, id: usize) -> Self {
132    Self {
133      inner,
134      data,
135      id,
136      _p: core::marker::PhantomData,
137    }
138  }
139}
140
141impl<'a, T> Future for AsyncMutexFuture<'a, T> {
142  type Output = AsyncMutexGuard<'a, T>;
143
144  fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
145    // we need to elide the lifetime given by self.get_mut() using unsafe code here
146    // SAFETY: it's actually safe as we either return Poll::Pending without any lifetime or we
147    // handout the `AsyncMutexGuard` with lifetime 'a which bound to the AsyncMutex that created this Future and
148    // will always outlive this future and is therefore ok - I guess...
149    let this = unsafe { &*(self.get_mut() as *const Self) };
150    if let Some(guard) = this.data.try_lock() {
151      // data lock could be acquired
152      // provide the AsyncMutexGuard
153      Poll::Ready(AsyncMutexGuard {
154        guard,
155        inner: Arc::clone(&this.inner),
156      })
157    } else {
158      // data lock could not be acquired this time, so someone else is holding the lock. We need to register
159      // ourself to get woken as soon as the lock gets available
160      let mut inner = this.inner.lock();
161      inner.waiter.insert(this.id, cx.waker().clone());
162      drop(inner);
163
164      Poll::Pending
165    }
166  }
167}
168
/// Waiter bookkeeping shared by an [AsyncMutex], its pending lock futures and its guards.
struct AsyncMutexInner {
  /// Wakers of lock futures that found the data locked, keyed by their waiter id. A releasing guard removes and
  /// wakes the entry with the smallest id. A woken future registers again if it still cannot get the lock.
  waiter: BTreeMap<usize, Waker>,
  /// The id that will be handed to the next lock call that has to wait. This is a counter bumped for every such
  /// call, *not* the id of the next waiter to be woken.
  next_waiter: usize,
}

impl AsyncMutexInner {
  /// Create empty bookkeeping: no registered waiters, ids starting at 0.
  fn new() -> Self {
    Self {
      waiter: BTreeMap::new(),
      next_waiter: 0,
    }
  }
}
186
#[cfg(testing)]
mod tests {
  use super::*;
  use async_std::prelude::*;
  use async_std::task;
  use core::time::Duration;

  /// A task that waits for the lock must observe the value written by the task that held it before.
  #[async_std::test]
  async fn wait_on_mutex() {
    let shared = Arc::new(AsyncMutex::new(10_u32));
    let writer_handle = Arc::clone(&shared);

    let writer = task::spawn(async move {
      let mut locked = writer_handle.lock().await;
      **locked = 20;
      // keep holding the guard for a while so the reader really has to wait on it
      task::yield_now().await;
      task::sleep(Duration::from_secs(1)).await;
    });

    let reader = task::spawn(async move {
      // give the writer a head start so it is the one grabbing the lock first
      task::yield_now().await;
      task::sleep(Duration::from_millis(100)).await;
      let locked = shared.lock().await;
      let observed = **locked;
      assert_eq!(20, observed);
    });

    // drive both tasks concurrently
    writer.join(reader).await;
  }

  /// An unshared mutex hands its value back through `into_inner`.
  #[test]
  fn mutex_to_inner() {
    match AsyncMutex::new(10).into_inner() {
      Ok(data) => assert_eq!(data, 10),
      _ => panic!("unable to get inner data"),
    }
  }
}