1use sqlx::Acquire;
2use std::sync::Arc;
3
4use tokio::sync::{
5 RwLock,
6 mpsc::{Receiver, Sender},
7 oneshot,
8};
9use wasmtime::component::{Accessor, AccessorTask, JoinHandle};
10
11use crate::{
12 core::bindings::{
13 SqlHostState,
14 executor::{ErasedExecutor, QueryOrRaw},
15 generated::wasm_sql::core::{
16 connection::Connection, transaction::Transaction, util_types::Error,
17 },
18 },
19 execute_with,
20 sqldb::SqlDatabase,
21};
22
23#[allow(dead_code)]
24pub enum TransactionCommand {
25 FetchAll {
26 query: QueryOrRaw,
27 cb: oneshot::Sender<Result<Vec<<SqlDatabase as sqlx::Database>::Row>, Error>>,
28 },
29
30 Execute {
31 query: QueryOrRaw,
32 cb: oneshot::Sender<Result<<SqlDatabase as sqlx::Database>::QueryResult, Error>>,
33 },
34 Commit {
35 cb: oneshot::Sender<Result<(), Error>>,
36 },
37
38 Rollback {
39 cb: oneshot::Sender<Result<(), Error>>,
40 },
41}
42
43#[allow(dead_code)]
44pub struct ConnectionBoundTask {
45 pub(crate) resource: wasmtime::component::Resource<Connection>,
46 pub(crate) receiver: Receiver<TransactionCommand>,
47}
48
49#[allow(dead_code)]
50#[derive(Clone)]
51pub enum TransactionImpl {
52 Tx(Arc<RwLock<Option<sqlx::Transaction<'static, SqlDatabase>>>>),
53 ConnectionBound {
54 handle: Arc<JoinHandle>,
55 sender: Sender<TransactionCommand>,
56 },
57}
58
59#[allow(dead_code)]
60trait TransactionErasedOps {
61 async fn commit(self) -> Result<(), Error>;
62 async fn rollback(self) -> Result<(), Error>;
63}
64
65impl ErasedExecutor<SqlHostState> for TransactionImpl {
66 async fn fetch_all<T>(
67 &self,
68 query: QueryOrRaw,
69 accessor: &Accessor<T, SqlHostState>,
70 ) -> Result<Vec<<SqlDatabase as sqlx::Database>::Row>, Error> {
71 match self {
72 TransactionImpl::Tx(rw_lock) => {
73 let mut guard = rw_lock.write().await;
74 if let Some(tx) = guard.as_mut() {
75 execute_with!(&mut **tx, accessor, query, fetch_all)
76 } else {
77 Err(Error::TransactionClosed)
78 }
79 }
80 TransactionImpl::ConnectionBound { handle: _, sender } => {
81 let (s, r) = oneshot::channel();
82 sender
83 .send(TransactionCommand::FetchAll { query, cb: s })
84 .await?;
85 r.await?
86 }
87 }
88 }
89
90 async fn execute<T>(
91 &self,
92 query: super::executor::QueryOrRaw,
93 accessor: &Accessor<T, SqlHostState>,
94 ) -> Result<<SqlDatabase as sqlx::Database>::QueryResult, Error> {
95 match self {
96 TransactionImpl::Tx(rw_lock) => {
97 let mut guard = rw_lock.write().await;
98 if let Some(tx) = guard.as_mut() {
99 execute_with!(&mut **tx, accessor, query, execute)
100 } else {
101 Err(Error::TransactionClosed)
102 }
103 }
104 TransactionImpl::ConnectionBound { handle: _, sender } => {
105 let (s, r) = oneshot::channel();
106 sender
107 .send(TransactionCommand::Execute { query, cb: s })
108 .await?;
109 r.await?
110 }
111 }
112 }
113}
114
115impl TransactionErasedOps for TransactionImpl {
116 async fn commit(self) -> Result<(), Error> {
117 match self {
118 TransactionImpl::Tx(rw_lock) => {
119 let mut guard = rw_lock.write().await;
120 if let Some(tx) = guard.take() {
121 Ok(tx.commit().await?)
122 } else {
123 Err(Error::TransactionClosed)
124 }
125 }
126 TransactionImpl::ConnectionBound { handle: _, sender } => {
127 let (s, r) = oneshot::channel();
128 sender.send(TransactionCommand::Commit { cb: s }).await?;
129
130 r.await?
131 }
132 }
133 }
134
135 async fn rollback(self) -> Result<(), Error> {
136 match self {
137 TransactionImpl::Tx(rw_lock) => {
138 let mut guard = rw_lock.write().await;
139 if let Some(tx) = guard.take() {
140 Ok(tx.rollback().await?)
141 } else {
142 Err(Error::TransactionClosed)
143 }
144 }
145 TransactionImpl::ConnectionBound { handle: _, sender } => {
146 let (s, r) = oneshot::channel();
147 sender.send(TransactionCommand::Rollback { cb: s }).await?;
148
149 r.await?
150 }
151 }
152 }
153}
154
155impl<T> AccessorTask<T, SqlHostState> for ConnectionBoundTask {
156 async fn run(
157 mut self,
158 accessor: &wasmtime::component::Accessor<T, SqlHostState>,
159 ) -> Result<(), wasmtime::Error> {
160 let conn = accessor.with(|mut access| {
161 let state = access.get();
162
163 state
164 .table
165 .get(&self.resource)
166 .map(|x| x.connection.clone())
167 })?;
168
169 let mut guard = conn.write().await;
170 let mut tx = Some(guard.begin().await?);
171
172 while let Some(cmd) = self.receiver.recv().await {
173 match cmd {
174 TransactionCommand::FetchAll { query, cb } => {
175 let res = if let Some(ref mut tx) = tx {
176 execute_with!(tx, accessor, query, fetch_all)
177 } else {
178 Err(Error::TransactionClosed)
179 };
180
181 let _ = cb.send(res);
182 }
183 TransactionCommand::Execute { query, cb } => {
184 let res = if let Some(ref mut tx) = tx {
185 execute_with!(tx, accessor, query, execute)
186 } else {
187 Err(Error::TransactionClosed)
188 };
189
190 let _ = cb.send(res);
191 }
192
193 TransactionCommand::Commit { cb } => {
194 let res = if let Some(tx) = tx.take() {
195 tx.commit().await.map_err(|e| e.into())
196 } else {
197 Err(Error::TransactionClosed)
198 };
199
200 let _ = cb.send(res);
201 }
202 TransactionCommand::Rollback { cb } => {
203 let res = if let Some(tx) = tx.take() {
204 tx.rollback().await.map_err(|e| e.into())
205 } else {
206 Err(Error::TransactionClosed)
207 };
208
209 let _ = cb.send(res);
210 }
211 }
212 }
213
214 Ok(())
215 }
216}
217
218impl crate::core::bindings::generated::wasm_sql::core::transaction::HostTransaction
219 for SqlHostState
220{
221 async fn drop(
222 &mut self,
223 rep: wasmtime::component::Resource<Transaction>,
224 ) -> wasmtime::Result<()> {
225 self.table.delete(rep)?;
226
227 Ok(())
228 }
229}
230
231impl crate::core::bindings::generated::wasm_sql::core::transaction::Host for SqlHostState {}
232
233impl crate::core::bindings::generated::wasm_sql::core::transaction::HostTransactionWithStore
234 for SqlHostState
235{
236 async fn commit<T>(
237 accessor: &wasmtime::component::Accessor<T, Self>,
238 this: wasmtime::component::Resource<Transaction>,
239 ) -> Result<(), Error> {
240 let tx_impl = accessor.with(|mut access| access.get().table.delete(this))?;
241
242 tx_impl.commit().await
243 }
244
245 async fn rollback<T>(
246 accessor: &wasmtime::component::Accessor<T, Self>,
247 this: wasmtime::component::Resource<Transaction>,
248 ) -> Result<(), Error> {
249 let tx_impl = accessor.with(|mut access| access.get().table.delete(this))?;
250
251 tx_impl.rollback().await
252 }
253}