rust_query/
async_db.rs

1use std::{
2    future,
3    sync::Arc,
4    task::{Poll, Waker},
5};
6
7use crate::{Database, Transaction, migrate::Schema};
8
9/// This is an async wrapper for [Database].
10///
11/// You can easily achieve the same thing with `tokio::task::spawn_blocking`,
12/// but this wrapper is a little bit more efficient while also being runtime agnostic.
13pub struct DatabaseAsync<S> {
14    inner: Arc<Database<S>>,
15}
16
17impl<S> Clone for DatabaseAsync<S> {
18    fn clone(&self) -> Self {
19        Self {
20            inner: self.inner.clone(),
21        }
22    }
23}
24
25impl<S: 'static + Send + Sync + Schema> DatabaseAsync<S> {
26    /// Create an async wrapper for the [Database].
27    ///
28    /// The database is wrapped in an [Arc] as it needs to be shared with any thread
29    /// executing a transaction. These threads can live longer than the future that
30    /// started the transaction.
31    ///
32    /// By accepting an [Arc], you can keep your own clone of the [Arc] and use
33    /// the database synchronously and asynchronously at the same time!
34    pub fn new(db: Arc<Database<S>>) -> Self {
35        DatabaseAsync { inner: db }
36    }
37
38    /// This is a lot like [Database::transaction], the only difference is that the async function
39    /// does not block the runtime and requires the closure to be `'static`.
40    /// The static requirement is because the future may be canceled, but the transaction can not
41    /// be canceled.
42    pub async fn transaction<R: 'static + Send>(
43        &self,
44        f: impl 'static + Send + FnOnce(&'static Transaction<S>) -> R,
45    ) -> R {
46        let db = self.inner.clone();
47        async_run(move || db.transaction_local(f)).await
48    }
49
50    /// This is a lot like [Database::transaction_mut], the only difference is that the async function
51    /// does not block the runtime and requires the closure to be `'static`.
52    /// The static requirement is because the future may be canceled, but the transaction can not
53    /// be canceled.
54    pub async fn transaction_mut<O: 'static + Send, E: 'static + Send>(
55        &self,
56        f: impl 'static + Send + FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
57    ) -> Result<O, E> {
58        let db = self.inner.clone();
59        async_run(move || db.transaction_mut_local(f)).await
60    }
61
62    /// This is a lot like [Database::transaction_mut_ok], the only difference is that the async function
63    /// does not block the runtime and requires the closure to be `'static`.
64    /// The static requirement is because the future may be canceled, but the transaction can not
65    /// be canceled.
66    pub async fn transaction_mut_ok<R: 'static + Send>(
67        &self,
68        f: impl 'static + Send + FnOnce(&'static mut Transaction<S>) -> R,
69    ) -> R {
70        self.transaction_mut(|txn| Ok::<R, std::convert::Infallible>(f(txn)))
71            .await
72            .unwrap()
73    }
74}
75
/// Runs the blocking closure `f` on a newly spawned OS thread and resolves with its
/// result once it has finished, without blocking the async executor.
///
/// This is runtime agnostic: completion is signalled only through the [Waker]
/// handed to `poll`. The waker from the most recent poll is always the one that
/// gets woken, as the [std::future::Future] contract requires.
///
/// If the returned future is dropped after its first poll, the worker thread keeps
/// running `f` to completion and its result is discarded.
///
/// # Panics
///
/// If `f` panics, the panic is resumed (via [std::panic::resume_unwind]) in the
/// task awaiting this future. Also panics if the OS fails to spawn a thread, as
/// [std::thread::spawn] does.
async fn async_run<R: 'static + Send>(f: impl 'static + Send + FnOnce() -> R) -> R {
    use std::sync::Mutex;

    // Completion state shared between the awaiting future and the worker thread.
    #[derive(Default)]
    struct Shared {
        // Set by the worker once `f` has returned or unwound.
        done: bool,
        // Waker from the most recent poll that returned `Pending`.
        waker: Option<Waker>,
    }

    // Marks the work as done and wakes the awaiting task when dropped. Locals are
    // also dropped while unwinding, so a panicking `f` still notifies the task.
    struct NotifyOnDrop(Arc<Mutex<Shared>>);

    impl Drop for NotifyOnDrop {
        fn drop(&mut self) {
            let waker = {
                // Poisoning is ignored: every critical section leaves `Shared` consistent.
                let mut state = self.0.lock().unwrap_or_else(|e| e.into_inner());
                state.done = true;
                state.waker.take()
            };
            // Wake after releasing the lock so the woken task can poll right away.
            if let Some(waker) = waker {
                waker.wake();
            }
        }
    }

    let shared = Arc::new(Mutex::new(Shared::default()));

    let handle = std::thread::spawn({
        let shared = Arc::clone(&shared);
        move || {
            let _notify = NotifyOnDrop(shared);
            f()
        }
    });

    // Asynchronously wait for the worker to signal completion.
    future::poll_fn(|cx| {
        let mut state = shared.lock().unwrap_or_else(|e| e.into_inner());
        if state.done {
            return Poll::Ready(());
        }
        // Only the waker of the latest poll must be woken, and it can differ from
        // earlier ones (e.g. the future moved to another task), so refresh it on
        // every poll. Checking `done` under the same lock avoids lost wake-ups.
        if !matches!(&state.waker, Some(old) if old.will_wake(cx.waker())) {
            state.waker = Some(cx.waker().clone());
        }
        Poll::Pending
    })
    .await;

    // The worker has already produced its result; `join` only waits for the
    // thread to finish exiting.
    match handle.join() {
        Ok(val) => val,
        Err(payload) => std::panic::resume_unwind(payload),
    }
}