// rust_query/async_db.rs

use std::{
    future,
    sync::{Arc, Mutex, PoisonError},
    task::{Poll, Waker},
};

use crate::{Database, Transaction, migrate::Schema};
9/// This is an async wrapper for [Database].
10///
11/// Transaction methods on this type are identical to [Database], but they spawn
12/// an async task and return a future to await the transaction completion.
13/// Because of this difference, the transaction closures must be `'static`.
14///
15/// You can easily implement [DatabaseAsync] manually using `tokio::task::spawn_blocking`,
16/// but this wrapper is a little bit more efficient while also being runtime agnostic.
17pub struct DatabaseAsync<S> {
18    inner: Arc<Database<S>>,
19}
20
21impl<S> Clone for DatabaseAsync<S> {
22    fn clone(&self) -> Self {
23        Self {
24            inner: self.inner.clone(),
25        }
26    }
27}
28
29impl<S: 'static + Send + Sync + Schema> DatabaseAsync<S> {
30    /// Create an async wrapper for the [Database].
31    ///
32    /// The database is wrapped in an [Arc] as it needs to be shared with any thread
33    /// executing a transaction. These threads can live longer than the future that
34    /// started the transaction.
35    ///
36    /// By accepting an [Arc], you can keep your own clone of the [Arc] and use
37    /// the database synchronously and asynchronously at the same time!
38    pub fn new(db: Arc<Database<S>>) -> Self {
39        DatabaseAsync { inner: db }
40    }
41
42    #[doc = include_str!("database/transaction.md")]
43    pub async fn transaction<R: 'static + Send>(
44        &self,
45        f: impl 'static + Send + FnOnce(&'static Transaction<S>) -> R,
46    ) -> R {
47        let db = self.inner.clone();
48        async_run(move || db.transaction_local(f)).await
49    }
50
51    #[doc = include_str!("database/transaction_mut.md")]
52    pub async fn transaction_mut<O: 'static + Send, E: 'static + Send>(
53        &self,
54        f: impl 'static + Send + FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
55    ) -> Result<O, E> {
56        let db = self.inner.clone();
57        async_run(move || db.transaction_mut_local(f)).await
58    }
59
60    #[doc = include_str!("database/transaction_mut_ok.md")]
61    pub async fn transaction_mut_ok<R: 'static + Send>(
62        &self,
63        f: impl 'static + Send + FnOnce(&'static mut Transaction<S>) -> R,
64    ) -> R {
65        self.transaction_mut(|txn| Ok::<R, std::convert::Infallible>(f(txn)))
66            .await
67            .unwrap()
68    }
69}
70
/// Runs `f` on a newly spawned OS thread and resolves to its return value.
///
/// The thread is spawned when the returned future is first polled. Completion is
/// signalled through a `WakeOnDrop` guard owned by the spawned thread. When `f` returns
/// or unwinds, the guard is dropped and wakes the waker most recently registered by the
/// future. Only `std` primitives are used, so any executor can drive the future.
///
/// If the future is dropped while `f` is still running, the thread is detached (its
/// `JoinHandle` is dropped) and the result is discarded.
///
/// # Panics
///
/// Panics if the thread cannot be spawned (see [`std::thread::spawn`]). Any panic from
/// `f` is re-raised in the awaiting task via [`std::panic::resume_unwind`].
async fn async_run<R: 'static + Send>(f: impl 'static + Send + FnOnce() -> R) -> R {
    // Wakes the stored waker when dropped, i.e. once the last strong reference is released.
    struct WakeOnDrop {
        waker: Mutex<Waker>,
    }

    impl Drop for WakeOnDrop {
        #[cfg_attr(test, mutants::skip)] // mutating this will make the test hang
        fn drop(&mut self) {
            // `&mut self` guarantees no other reference exists, so the lock need not be
            // taken. Poisoning is tolerated rather than unwrapped: panicking here while the
            // worker thread is already unwinding from a panic in `f` would abort the process.
            self.waker
                .get_mut()
                .unwrap_or_else(PoisonError::into_inner)
                .wake_by_ref();
        }
    }

    // Initially use a no-op waker; the first poll below replaces it with the real one.
    let wake_on_drop = Arc::new(WakeOnDrop {
        waker: Mutex::new(Waker::noop().clone()),
    });
    let weak = Arc::downgrade(&wake_on_drop);

    let handle = std::thread::spawn(move || {
        // The thread owns the only strong reference, so the guard is dropped (and the
        // waker fires) as soon as `f` returns or unwinds from a panic.
        let _wake_on_drop = wake_on_drop;
        f()
    });

    // Asynchronously wait for `f` to finish.
    future::poll_fn(|cx| match weak.upgrade() {
        Some(wake_on_drop) => {
            // Register the current waker. If the worker releases its reference while we
            // hold this temporary one, dropping ours at the end of this arm runs `Drop`
            // and wakes the waker just stored, so no wake-up is lost.
            wake_on_drop
                .waker
                .lock()
                .unwrap_or_else(PoisonError::into_inner)
                .clone_from(cx.waker());
            Poll::Pending
        }
        // Every strong reference is gone, so `f` has returned or unwound.
        None => Poll::Ready(()),
    })
    .await;

    // `f` has already returned or unwound. `join` only waits for the thread to store the
    // result and exit, so it blocks at most briefly.
    match handle.join() {
        Ok(val) => val,
        Err(payload) => std::panic::resume_unwind(payload),
    }
}