// wasm_sql/core/bindings/transaction.rs

use std::sync::Arc;

use sqlx::Acquire;
use tokio::sync::{
    RwLock,
    mpsc::{Receiver, Sender},
    oneshot,
};
use wasmtime::component::{Accessor, AccessorTask, JoinHandle};

use crate::{
    core::bindings::{
        SqlHostState,
        executor::{ErasedExecutor, QueryOrRaw},
        generated::wasm_sql::core::{
            connection::Connection, transaction::Transaction, util_types::Error,
        },
    },
    execute_with,
    sqldb::SqlDatabase,
};

23#[allow(dead_code)]
24pub enum TransactionCommand {
25    FetchAll {
26        query: QueryOrRaw,
27        cb: oneshot::Sender<Result<Vec<<SqlDatabase as sqlx::Database>::Row>, Error>>,
28    },
29
30    Execute {
31        query: QueryOrRaw,
32        cb: oneshot::Sender<Result<<SqlDatabase as sqlx::Database>::QueryResult, Error>>,
33    },
34    Commit {
35        cb: oneshot::Sender<Result<(), Error>>,
36    },
37
38    Rollback {
39        cb: oneshot::Sender<Result<(), Error>>,
40    },
41}
42
43#[allow(dead_code)]
44pub struct ConnectionBoundTask {
45    pub(crate) resource: wasmtime::component::Resource<Connection>,
46    pub(crate) receiver: Receiver<TransactionCommand>,
47}
48
49#[allow(dead_code)]
50#[derive(Clone)]
51pub enum TransactionImpl {
52    Tx(Arc<RwLock<Option<sqlx::Transaction<'static, SqlDatabase>>>>),
53    ConnectionBound {
54        handle: Arc<JoinHandle>,
55        sender: Sender<TransactionCommand>,
56    },
57}
58
59#[allow(dead_code)]
60trait TransactionErasedOps {
61    async fn commit(self) -> Result<(), Error>;
62    async fn rollback(self) -> Result<(), Error>;
63}
64
65impl ErasedExecutor<SqlHostState> for TransactionImpl {
66    async fn fetch_all<T>(
67        &self,
68        query: QueryOrRaw,
69        accessor: &Accessor<T, SqlHostState>,
70    ) -> Result<Vec<<SqlDatabase as sqlx::Database>::Row>, Error> {
71        match self {
72            TransactionImpl::Tx(rw_lock) => {
73                let mut guard = rw_lock.write().await;
74                if let Some(tx) = guard.as_mut() {
75                    execute_with!(&mut **tx, accessor, query, fetch_all)
76                } else {
77                    Err(Error::TransactionClosed)
78                }
79            }
80            TransactionImpl::ConnectionBound { handle: _, sender } => {
81                let (s, r) = oneshot::channel();
82                sender
83                    .send(TransactionCommand::FetchAll { query, cb: s })
84                    .await?;
85                r.await?
86            }
87        }
88    }
89
90    async fn execute<T>(
91        &self,
92        query: super::executor::QueryOrRaw,
93        accessor: &Accessor<T, SqlHostState>,
94    ) -> Result<<SqlDatabase as sqlx::Database>::QueryResult, Error> {
95        match self {
96            TransactionImpl::Tx(rw_lock) => {
97                let mut guard = rw_lock.write().await;
98                if let Some(tx) = guard.as_mut() {
99                    execute_with!(&mut **tx, accessor, query, execute)
100                } else {
101                    Err(Error::TransactionClosed)
102                }
103            }
104            TransactionImpl::ConnectionBound { handle: _, sender } => {
105                let (s, r) = oneshot::channel();
106                sender
107                    .send(TransactionCommand::Execute { query, cb: s })
108                    .await?;
109                r.await?
110            }
111        }
112    }
113}
114
115impl TransactionErasedOps for TransactionImpl {
116    async fn commit(self) -> Result<(), Error> {
117        match self {
118            TransactionImpl::Tx(rw_lock) => {
119                let mut guard = rw_lock.write().await;
120                if let Some(tx) = guard.take() {
121                    Ok(tx.commit().await?)
122                } else {
123                    Err(Error::TransactionClosed)
124                }
125            }
126            TransactionImpl::ConnectionBound { handle: _, sender } => {
127                let (s, r) = oneshot::channel();
128                sender.send(TransactionCommand::Commit { cb: s }).await?;
129
130                r.await?
131            }
132        }
133    }
134
135    async fn rollback(self) -> Result<(), Error> {
136        match self {
137            TransactionImpl::Tx(rw_lock) => {
138                let mut guard = rw_lock.write().await;
139                if let Some(tx) = guard.take() {
140                    Ok(tx.rollback().await?)
141                } else {
142                    Err(Error::TransactionClosed)
143                }
144            }
145            TransactionImpl::ConnectionBound { handle: _, sender } => {
146                let (s, r) = oneshot::channel();
147                sender.send(TransactionCommand::Rollback { cb: s }).await?;
148
149                r.await?
150            }
151        }
152    }
153}
154
155impl<T> AccessorTask<T, SqlHostState> for ConnectionBoundTask {
156    async fn run(
157        mut self,
158        accessor: &wasmtime::component::Accessor<T, SqlHostState>,
159    ) -> Result<(), wasmtime::Error> {
160        let conn = accessor.with(|mut access| {
161            let state = access.get();
162
163            state
164                .table
165                .get(&self.resource)
166                .map(|x| x.connection.clone())
167        })?;
168
169        let mut guard = conn.write().await;
170        let mut tx = Some(guard.begin().await?);
171
172        while let Some(cmd) = self.receiver.recv().await {
173            match cmd {
174                TransactionCommand::FetchAll { query, cb } => {
175                    let res = if let Some(ref mut tx) = tx {
176                        execute_with!(tx, accessor, query, fetch_all)
177                    } else {
178                        Err(Error::TransactionClosed)
179                    };
180
181                    let _ = cb.send(res);
182                }
183                TransactionCommand::Execute { query, cb } => {
184                    let res = if let Some(ref mut tx) = tx {
185                        execute_with!(tx, accessor, query, execute)
186                    } else {
187                        Err(Error::TransactionClosed)
188                    };
189
190                    let _ = cb.send(res);
191                }
192
193                TransactionCommand::Commit { cb } => {
194                    let res = if let Some(tx) = tx.take() {
195                        tx.commit().await.map_err(|e| e.into())
196                    } else {
197                        Err(Error::TransactionClosed)
198                    };
199
200                    let _ = cb.send(res);
201                }
202                TransactionCommand::Rollback { cb } => {
203                    let res = if let Some(tx) = tx.take() {
204                        tx.rollback().await.map_err(|e| e.into())
205                    } else {
206                        Err(Error::TransactionClosed)
207                    };
208
209                    let _ = cb.send(res);
210                }
211            }
212        }
213
214        Ok(())
215    }
216}
217
218impl crate::core::bindings::generated::wasm_sql::core::transaction::HostTransaction
219    for SqlHostState
220{
221    async fn drop(
222        &mut self,
223        rep: wasmtime::component::Resource<Transaction>,
224    ) -> wasmtime::Result<()> {
225        self.table.delete(rep)?;
226
227        Ok(())
228    }
229}
230
231impl crate::core::bindings::generated::wasm_sql::core::transaction::Host for SqlHostState {}
232
233impl crate::core::bindings::generated::wasm_sql::core::transaction::HostTransactionWithStore
234    for SqlHostState
235{
236    async fn commit<T>(
237        accessor: &wasmtime::component::Accessor<T, Self>,
238        this: wasmtime::component::Resource<Transaction>,
239    ) -> Result<(), Error> {
240        let tx_impl = accessor.with(|mut access| access.get().table.delete(this))?;
241
242        tx_impl.commit().await
243    }
244
245    async fn rollback<T>(
246        accessor: &wasmtime::component::Accessor<T, Self>,
247        this: wasmtime::component::Resource<Transaction>,
248    ) -> Result<(), Error> {
249        let tx_impl = accessor.with(|mut access| access.get().table.delete(this))?;
250
251        tx_impl.rollback().await
252    }
253}