rust_query/async_db.rs
1use std::{
2 future,
3 sync::Arc,
4 task::{Poll, Waker},
5};
6
7use crate::{Database, Transaction, migrate::Schema};
8
9/// This is an async wrapper for [Database].
10///
11/// You can easily achieve the same thing with `tokio::task::spawn_blocking`,
12/// but this wrapper is a little bit more efficient while also being runtime agnostic.
13pub struct DatabaseAsync<S> {
14 inner: Arc<Database<S>>,
15}
16
17impl<S> Clone for DatabaseAsync<S> {
18 fn clone(&self) -> Self {
19 Self {
20 inner: self.inner.clone(),
21 }
22 }
23}
24
25impl<S: 'static + Send + Sync + Schema> DatabaseAsync<S> {
26 /// Create an async wrapper for the [Database].
27 ///
28 /// The database is wrapped in an [Arc] as it needs to be shared with any thread
29 /// executing a transaction. These threads can live longer than the future that
30 /// started the transaction.
31 ///
32 /// By accepting an [Arc], you can keep your own clone of the [Arc] and use
33 /// the database synchronously and asynchronously at the same time!
34 pub fn new(db: Arc<Database<S>>) -> Self {
35 DatabaseAsync { inner: db }
36 }
37
38 /// This is a lot like [Database::transaction], the only difference is that the async function
39 /// does not block the runtime and requires the closure to be `'static`.
40 /// The static requirement is because the future may be canceled, but the transaction can not
41 /// be canceled.
42 pub async fn transaction<R: 'static + Send>(
43 &self,
44 f: impl 'static + Send + FnOnce(&'static Transaction<S>) -> R,
45 ) -> R {
46 let db = self.inner.clone();
47 async_run(move || db.transaction_local(f)).await
48 }
49
50 /// This is a lot like [Database::transaction_mut], the only difference is that the async function
51 /// does not block the runtime and requires the closure to be `'static`.
52 /// The static requirement is because the future may be canceled, but the transaction can not
53 /// be canceled.
54 pub async fn transaction_mut<O: 'static + Send, E: 'static + Send>(
55 &self,
56 f: impl 'static + Send + FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
57 ) -> Result<O, E> {
58 let db = self.inner.clone();
59 async_run(move || db.transaction_mut_local(f)).await
60 }
61
62 /// This is a lot like [Database::transaction_mut_ok], the only difference is that the async function
63 /// does not block the runtime and requires the closure to be `'static`.
64 /// The static requirement is because the future may be canceled, but the transaction can not
65 /// be canceled.
66 pub async fn transaction_mut_ok<R: 'static + Send>(
67 &self,
68 f: impl 'static + Send + FnOnce(&'static mut Transaction<S>) -> R,
69 ) -> R {
70 self.transaction_mut(|txn| Ok::<R, std::convert::Infallible>(f(txn)))
71 .await
72 .unwrap()
73 }
74}
75
/// Runs `f` on a freshly spawned OS thread and asynchronously waits for its result.
///
/// The awaiting task is never blocked while `f` runs. The worker thread signals
/// completion (also when `f` panics) by waking the waker registered by the most
/// recent poll. Only after that signal is the thread joined, so the final join
/// waits only for the thread to finish tearing down.
///
/// If the returned future is dropped after the thread has started, the thread is
/// detached and `f` still runs to completion; its result is discarded.
///
/// # Panics
///
/// If `f` panics, the panic is resumed in the task awaiting this future.
/// Also panics if the OS fails to spawn a thread (see [std::thread::spawn]).
async fn async_run<R: 'static + Send>(f: impl 'static + Send + FnOnce() -> R) -> R {
    use std::sync::{Mutex, PoisonError};

    /// Completion state shared between the awaiting future and the worker thread.
    #[derive(Default)]
    struct State {
        /// Set by the worker once `f` has returned or unwound.
        done: bool,
        /// Waker from the most recent poll that returned `Pending`.
        waker: Option<Waker>,
    }

    /// Marks the shared state as done and wakes the waiting task when dropped.
    /// Being a drop guard, this also runs while the worker unwinds from a panic.
    struct NotifyOnDrop(Arc<Mutex<State>>);

    impl Drop for NotifyOnDrop {
        fn drop(&mut self) {
            let mut state = self.0.lock().unwrap_or_else(PoisonError::into_inner);
            state.done = true;
            let waker = state.waker.take();
            // Release the lock before waking so the woken task can poll immediately.
            std::mem::drop(state);
            if let Some(waker) = waker {
                waker.wake();
            }
        }
    }

    let shared = Arc::new(Mutex::new(State::default()));

    let handle = std::thread::spawn({
        let shared = Arc::clone(&shared);
        move || {
            // Dropped after `f` returns (or during unwinding), signalling completion.
            let _notify = NotifyOnDrop(shared);
            f()
        }
    });

    // Asynchronously wait until the worker has signalled completion.
    future::poll_fn(|cx| {
        let mut guard = shared.lock().unwrap_or_else(PoisonError::into_inner);
        if guard.done {
            return Poll::Ready(());
        }
        // Per the `Future` contract only the waker from the latest poll may be
        // relied upon, so replace a stored waker that would not wake this task.
        let stale = match &guard.waker {
            Some(waker) => !waker.will_wake(cx.waker()),
            None => true,
        };
        if stale {
            guard.waker = Some(cx.waker().clone());
        }
        Poll::Pending
    })
    .await;

    // The worker has finished `f`, so this join only waits for thread teardown.
    match handle.join() {
        Ok(value) => value,
        Err(payload) => std::panic::resume_unwind(payload),
    }
}