1use std::{
2 future,
3 sync::{Arc, Mutex},
4 task::{Poll, Waker},
5};
6
7use crate::{Database, Transaction, migrate::Schema};
8
9pub struct DatabaseAsync<S> {
18 inner: Arc<Database<S>>,
19}
20
21impl<S> Clone for DatabaseAsync<S> {
22 fn clone(&self) -> Self {
23 Self {
24 inner: self.inner.clone(),
25 }
26 }
27}
28
29impl<S: 'static + Send + Sync + Schema> DatabaseAsync<S> {
30 pub fn new(db: Arc<Database<S>>) -> Self {
39 DatabaseAsync { inner: db }
40 }
41
42 #[doc = include_str!("database/transaction.md")]
43 pub async fn transaction<R: 'static + Send>(
44 &self,
45 f: impl 'static + Send + FnOnce(&'static Transaction<S>) -> R,
46 ) -> R {
47 let db = self.inner.clone();
48 async_run(move || db.transaction_local(f)).await
49 }
50
51 #[doc = include_str!("database/transaction_mut.md")]
52 pub async fn transaction_mut<O: 'static + Send, E: 'static + Send>(
53 &self,
54 f: impl 'static + Send + FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
55 ) -> Result<O, E> {
56 let db = self.inner.clone();
57 async_run(move || db.transaction_mut_local(f)).await
58 }
59
60 #[doc = include_str!("database/transaction_mut_ok.md")]
61 pub async fn transaction_mut_ok<R: 'static + Send>(
62 &self,
63 f: impl 'static + Send + FnOnce(&'static mut Transaction<S>) -> R,
64 ) -> R {
65 self.transaction_mut(|txn| Ok::<R, std::convert::Infallible>(f(txn)))
66 .await
67 .unwrap()
68 }
69}
70
71async fn async_run<R: 'static + Send>(f: impl 'static + Send + FnOnce() -> R) -> R {
72 pub struct WakeOnDrop {
73 waker: Mutex<Waker>,
74 }
75
76 impl Drop for WakeOnDrop {
77 #[cfg_attr(test, mutants::skip)] fn drop(&mut self) {
79 self.waker.lock().unwrap().wake_by_ref();
80 }
81 }
82
83 let wake_on_drop = Arc::new(WakeOnDrop {
85 waker: Mutex::new(Waker::noop().clone()),
86 });
87 let weak = Arc::downgrade(&wake_on_drop);
88
89 let handle = std::thread::spawn(move || {
90 let _wake_on_drop = wake_on_drop;
92 f()
93 });
94
95 future::poll_fn(|cx| {
97 if let Some(wake_on_drop) = weak.upgrade() {
98 wake_on_drop.waker.lock().unwrap().clone_from(cx.waker());
99 Poll::Pending
100 } else {
101 Poll::Ready(())
102 }
103 })
104 .await;
105
106 match handle.join() {
108 Ok(val) => val,
109 Err(err) => std::panic::resume_unwind(err),
110 }
111}