1use crate::{
2 debug_print, error::*, AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult,
3 InnerConnection, IsolationLevel, QueryResult, Statement, StreamTrait, TransactionStream,
4 TransactionTrait,
5};
6#[cfg(feature = "sqlx-dep")]
7use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
8use futures_util::lock::Mutex;
9#[cfg(feature = "sqlx-dep")]
10use sqlx::TransactionManager;
11use std::{future::Future, pin::Pin, sync::Arc};
12use tracing::instrument;
13
14pub struct DatabaseTransaction {
18 conn: Arc<Mutex<InnerConnection>>,
19 backend: DbBackend,
20 open: bool,
21 metric_callback: Option<crate::metric::Callback>,
22}
23
24impl std::fmt::Debug for DatabaseTransaction {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 write!(f, "DatabaseTransaction")
27 }
28}
29
30impl DatabaseTransaction {
31 #[instrument(level = "trace", skip(metric_callback))]
32 pub(crate) async fn begin(
33 conn: Arc<Mutex<InnerConnection>>,
34 backend: DbBackend,
35 metric_callback: Option<crate::metric::Callback>,
36 isolation_level: Option<IsolationLevel>,
37 access_mode: Option<AccessMode>,
38 ) -> Result<DatabaseTransaction, DbErr> {
39 let res = DatabaseTransaction {
40 conn,
41 backend,
42 open: true,
43 metric_callback,
44 };
45 match *res.conn.lock().await {
46 #[cfg(feature = "sqlx-mysql")]
47 InnerConnection::MySql(ref mut c) => {
48 crate::driver::sqlx_mysql::set_transaction_config(c, isolation_level, access_mode)
50 .await?;
51 <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c, None)
52 .await
53 .map_err(sqlx_error_to_query_err)
54 }
55 #[cfg(feature = "sqlx-postgres")]
56 InnerConnection::Postgres(ref mut c) => {
57 <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c, None)
58 .await
59 .map_err(sqlx_error_to_query_err)?;
60 crate::driver::sqlx_postgres::set_transaction_config(
62 c,
63 isolation_level,
64 access_mode,
65 )
66 .await
67 }
68 #[cfg(feature = "sqlx-sqlite")]
69 InnerConnection::Sqlite(ref mut c) => {
70 crate::driver::sqlx_sqlite::set_transaction_config(c, isolation_level, access_mode)
72 .await?;
73 <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c, None)
74 .await
75 .map_err(sqlx_error_to_query_err)
76 }
77 #[cfg(feature = "mock")]
78 InnerConnection::Mock(ref mut c) => {
79 c.begin();
80 Ok(())
81 }
82 #[cfg(feature = "proxy")]
83 InnerConnection::Proxy(ref mut c) => {
84 c.begin().await;
85 Ok(())
86 }
87 #[allow(unreachable_patterns)]
88 _ => Err(conn_err("Disconnected")),
89 }?;
90 Ok(res)
91 }
92
93 #[instrument(level = "trace", skip(callback))]
96 pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
97 where
98 F: for<'b> FnOnce(
99 &'b DatabaseTransaction,
100 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
101 + Send,
102 T: Send,
103 E: std::fmt::Display + std::fmt::Debug + Send,
104 {
105 let res = callback(&self).await.map_err(TransactionError::Transaction);
106 if res.is_ok() {
107 self.commit().await.map_err(TransactionError::Connection)?;
108 } else {
109 self.rollback()
110 .await
111 .map_err(TransactionError::Connection)?;
112 }
113 res
114 }
115
116 #[instrument(level = "trace")]
118 #[allow(unreachable_code, unused_mut)]
119 pub async fn commit(mut self) -> Result<(), DbErr> {
120 match *self.conn.lock().await {
121 #[cfg(feature = "sqlx-mysql")]
122 InnerConnection::MySql(ref mut c) => {
123 <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
124 .await
125 .map_err(sqlx_error_to_query_err)
126 }
127 #[cfg(feature = "sqlx-postgres")]
128 InnerConnection::Postgres(ref mut c) => {
129 <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
130 .await
131 .map_err(sqlx_error_to_query_err)
132 }
133 #[cfg(feature = "sqlx-sqlite")]
134 InnerConnection::Sqlite(ref mut c) => {
135 <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
136 .await
137 .map_err(sqlx_error_to_query_err)
138 }
139 #[cfg(feature = "mock")]
140 InnerConnection::Mock(ref mut c) => {
141 c.commit();
142 Ok(())
143 }
144 #[cfg(feature = "proxy")]
145 InnerConnection::Proxy(ref mut c) => {
146 c.commit().await;
147 Ok(())
148 }
149 #[allow(unreachable_patterns)]
150 _ => Err(conn_err("Disconnected")),
151 }?;
152 self.open = false;
153 Ok(())
154 }
155
156 #[instrument(level = "trace")]
158 #[allow(unreachable_code, unused_mut)]
159 pub async fn rollback(mut self) -> Result<(), DbErr> {
160 match *self.conn.lock().await {
161 #[cfg(feature = "sqlx-mysql")]
162 InnerConnection::MySql(ref mut c) => {
163 <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
164 .await
165 .map_err(sqlx_error_to_query_err)
166 }
167 #[cfg(feature = "sqlx-postgres")]
168 InnerConnection::Postgres(ref mut c) => {
169 <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
170 .await
171 .map_err(sqlx_error_to_query_err)
172 }
173 #[cfg(feature = "sqlx-sqlite")]
174 InnerConnection::Sqlite(ref mut c) => {
175 <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
176 .await
177 .map_err(sqlx_error_to_query_err)
178 }
179 #[cfg(feature = "mock")]
180 InnerConnection::Mock(ref mut c) => {
181 c.rollback();
182 Ok(())
183 }
184 #[cfg(feature = "proxy")]
185 InnerConnection::Proxy(ref mut c) => {
186 c.rollback().await;
187 Ok(())
188 }
189 #[allow(unreachable_patterns)]
190 _ => Err(conn_err("Disconnected")),
191 }?;
192 self.open = false;
193 Ok(())
194 }
195
196 #[instrument(level = "trace")]
198 fn start_rollback(&mut self) -> Result<(), DbErr> {
199 if self.open {
200 if let Some(mut conn) = self.conn.try_lock() {
201 match &mut *conn {
202 #[cfg(feature = "sqlx-mysql")]
203 InnerConnection::MySql(c) => {
204 <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
205 }
206 #[cfg(feature = "sqlx-postgres")]
207 InnerConnection::Postgres(c) => {
208 <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
209 }
210 #[cfg(feature = "sqlx-sqlite")]
211 InnerConnection::Sqlite(c) => {
212 <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
213 }
214 #[cfg(feature = "mock")]
215 InnerConnection::Mock(c) => {
216 c.rollback();
217 }
218 #[cfg(feature = "proxy")]
219 InnerConnection::Proxy(c) => {
220 c.start_rollback();
221 }
222 #[allow(unreachable_patterns)]
223 _ => return Err(conn_err("Disconnected")),
224 }
225 } else {
226 return Err(conn_err("Dropping a locked Transaction"));
228 }
229 }
230 Ok(())
231 }
232}
233
234impl Drop for DatabaseTransaction {
235 fn drop(&mut self) {
236 self.start_rollback().expect("Fail to rollback transaction");
237 }
238}
239
240#[async_trait::async_trait]
241impl ConnectionTrait for DatabaseTransaction {
242 fn get_database_backend(&self) -> DbBackend {
243 self.backend
245 }
246
247 #[instrument(level = "trace")]
248 #[allow(unused_variables)]
249 async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
250 debug_print!("{}", stmt);
251
252 match &mut *self.conn.lock().await {
253 #[cfg(feature = "sqlx-mysql")]
254 InnerConnection::MySql(conn) => {
255 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
256 let conn: &mut sqlx::MySqlConnection = &mut *conn;
257 crate::metric::metric!(self.metric_callback, &stmt, {
258 query.execute(conn).await.map(Into::into)
259 })
260 .map_err(sqlx_error_to_exec_err)
261 }
262 #[cfg(feature = "sqlx-postgres")]
263 InnerConnection::Postgres(conn) => {
264 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
265 let conn: &mut sqlx::PgConnection = &mut *conn;
266 crate::metric::metric!(self.metric_callback, &stmt, {
267 query.execute(conn).await.map(Into::into)
268 })
269 .map_err(sqlx_error_to_exec_err)
270 }
271 #[cfg(feature = "sqlx-sqlite")]
272 InnerConnection::Sqlite(conn) => {
273 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
274 let conn: &mut sqlx::SqliteConnection = &mut *conn;
275 crate::metric::metric!(self.metric_callback, &stmt, {
276 query.execute(conn).await.map(Into::into)
277 })
278 .map_err(sqlx_error_to_exec_err)
279 }
280 #[cfg(feature = "mock")]
281 InnerConnection::Mock(conn) => return conn.execute(stmt),
282 #[cfg(feature = "proxy")]
283 InnerConnection::Proxy(conn) => return conn.execute(stmt).await,
284 #[allow(unreachable_patterns)]
285 _ => Err(conn_err("Disconnected")),
286 }
287 }
288
289 #[instrument(level = "trace")]
290 #[allow(unused_variables)]
291 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
292 debug_print!("{}", sql);
293
294 match &mut *self.conn.lock().await {
295 #[cfg(feature = "sqlx-mysql")]
296 InnerConnection::MySql(conn) => {
297 let conn: &mut sqlx::MySqlConnection = &mut *conn;
298 sqlx::Executor::execute(conn, sql)
299 .await
300 .map(Into::into)
301 .map_err(sqlx_error_to_exec_err)
302 }
303 #[cfg(feature = "sqlx-postgres")]
304 InnerConnection::Postgres(conn) => {
305 let conn: &mut sqlx::PgConnection = &mut *conn;
306 sqlx::Executor::execute(conn, sql)
307 .await
308 .map(Into::into)
309 .map_err(sqlx_error_to_exec_err)
310 }
311 #[cfg(feature = "sqlx-sqlite")]
312 InnerConnection::Sqlite(conn) => {
313 let conn: &mut sqlx::SqliteConnection = &mut *conn;
314 sqlx::Executor::execute(conn, sql)
315 .await
316 .map(Into::into)
317 .map_err(sqlx_error_to_exec_err)
318 }
319 #[cfg(feature = "mock")]
320 InnerConnection::Mock(conn) => {
321 let db_backend = conn.get_database_backend();
322 let stmt = Statement::from_string(db_backend, sql);
323 conn.execute(stmt)
324 }
325 #[cfg(feature = "proxy")]
326 InnerConnection::Proxy(conn) => {
327 let db_backend = conn.get_database_backend();
328 let stmt = Statement::from_string(db_backend, sql);
329 conn.execute(stmt).await
330 }
331 #[allow(unreachable_patterns)]
332 _ => Err(conn_err("Disconnected")),
333 }
334 }
335
336 #[instrument(level = "trace")]
337 #[allow(unused_variables)]
338 async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
339 debug_print!("{}", stmt);
340
341 match &mut *self.conn.lock().await {
342 #[cfg(feature = "sqlx-mysql")]
343 InnerConnection::MySql(conn) => {
344 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
345 let conn: &mut sqlx::MySqlConnection = &mut *conn;
346 crate::metric::metric!(self.metric_callback, &stmt, {
347 crate::sqlx_map_err_ignore_not_found(
348 query.fetch_one(conn).await.map(|row| Some(row.into())),
349 )
350 })
351 }
352 #[cfg(feature = "sqlx-postgres")]
353 InnerConnection::Postgres(conn) => {
354 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
355 let conn: &mut sqlx::PgConnection = &mut *conn;
356 crate::metric::metric!(self.metric_callback, &stmt, {
357 crate::sqlx_map_err_ignore_not_found(
358 query.fetch_one(conn).await.map(|row| Some(row.into())),
359 )
360 })
361 }
362 #[cfg(feature = "sqlx-sqlite")]
363 InnerConnection::Sqlite(conn) => {
364 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
365 let conn: &mut sqlx::SqliteConnection = &mut *conn;
366 crate::metric::metric!(self.metric_callback, &stmt, {
367 crate::sqlx_map_err_ignore_not_found(
368 query.fetch_one(conn).await.map(|row| Some(row.into())),
369 )
370 })
371 }
372 #[cfg(feature = "mock")]
373 InnerConnection::Mock(conn) => return conn.query_one(stmt),
374 #[cfg(feature = "proxy")]
375 InnerConnection::Proxy(conn) => return conn.query_one(stmt).await,
376 #[allow(unreachable_patterns)]
377 _ => Err(conn_err("Disconnected")),
378 }
379 }
380
381 #[instrument(level = "trace")]
382 #[allow(unused_variables)]
383 async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
384 debug_print!("{}", stmt);
385
386 match &mut *self.conn.lock().await {
387 #[cfg(feature = "sqlx-mysql")]
388 InnerConnection::MySql(conn) => {
389 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
390 let conn: &mut sqlx::MySqlConnection = &mut *conn;
391 crate::metric::metric!(self.metric_callback, &stmt, {
392 query
393 .fetch_all(conn)
394 .await
395 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
396 .map_err(sqlx_error_to_query_err)
397 })
398 }
399 #[cfg(feature = "sqlx-postgres")]
400 InnerConnection::Postgres(conn) => {
401 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
402 let conn: &mut sqlx::PgConnection = &mut *conn;
403 crate::metric::metric!(self.metric_callback, &stmt, {
404 query
405 .fetch_all(conn)
406 .await
407 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
408 .map_err(sqlx_error_to_query_err)
409 })
410 }
411 #[cfg(feature = "sqlx-sqlite")]
412 InnerConnection::Sqlite(conn) => {
413 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
414 let conn: &mut sqlx::SqliteConnection = &mut *conn;
415 crate::metric::metric!(self.metric_callback, &stmt, {
416 query
417 .fetch_all(conn)
418 .await
419 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
420 .map_err(sqlx_error_to_query_err)
421 })
422 }
423 #[cfg(feature = "mock")]
424 InnerConnection::Mock(conn) => return conn.query_all(stmt),
425 #[cfg(feature = "proxy")]
426 InnerConnection::Proxy(conn) => return conn.query_all(stmt).await,
427 #[allow(unreachable_patterns)]
428 _ => Err(conn_err("Disconnected")),
429 }
430 }
431}
432
433impl StreamTrait for DatabaseTransaction {
434 type Stream<'a> = TransactionStream<'a>;
435
436 #[instrument(level = "trace")]
437 fn stream<'a>(
438 &'a self,
439 stmt: Statement,
440 ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
441 Box::pin(async move {
442 let conn = self.conn.lock().await;
443 Ok(crate::TransactionStream::build(
444 conn,
445 stmt,
446 self.metric_callback.clone(),
447 ))
448 })
449 }
450}
451
452#[async_trait::async_trait]
453impl TransactionTrait for DatabaseTransaction {
454 #[instrument(level = "trace")]
455 async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
456 DatabaseTransaction::begin(
457 Arc::clone(&self.conn),
458 self.backend,
459 self.metric_callback.clone(),
460 None,
461 None,
462 )
463 .await
464 }
465
466 #[instrument(level = "trace")]
467 async fn begin_with_config(
468 &self,
469 isolation_level: Option<IsolationLevel>,
470 access_mode: Option<AccessMode>,
471 ) -> Result<DatabaseTransaction, DbErr> {
472 DatabaseTransaction::begin(
473 Arc::clone(&self.conn),
474 self.backend,
475 self.metric_callback.clone(),
476 isolation_level,
477 access_mode,
478 )
479 .await
480 }
481
482 #[instrument(level = "trace", skip(_callback))]
485 async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
486 where
487 F: for<'c> FnOnce(
488 &'c DatabaseTransaction,
489 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
490 + Send,
491 T: Send,
492 E: std::fmt::Display + std::fmt::Debug + Send,
493 {
494 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
495 transaction.run(_callback).await
496 }
497
498 #[instrument(level = "trace", skip(_callback))]
501 async fn transaction_with_config<F, T, E>(
502 &self,
503 _callback: F,
504 isolation_level: Option<IsolationLevel>,
505 access_mode: Option<AccessMode>,
506 ) -> Result<T, TransactionError<E>>
507 where
508 F: for<'c> FnOnce(
509 &'c DatabaseTransaction,
510 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
511 + Send,
512 T: Send,
513 E: std::fmt::Display + std::fmt::Debug + Send,
514 {
515 let transaction = self
516 .begin_with_config(isolation_level, access_mode)
517 .await
518 .map_err(TransactionError::Connection)?;
519 transaction.run(_callback).await
520 }
521}
522
523#[derive(Debug)]
525pub enum TransactionError<E> {
526 Connection(DbErr),
528 Transaction(E),
530}
531
532impl<E> std::fmt::Display for TransactionError<E>
533where
534 E: std::fmt::Display + std::fmt::Debug,
535{
536 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
537 match self {
538 TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
539 TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
540 }
541 }
542}
543
544impl<E> std::error::Error for TransactionError<E> where E: std::fmt::Display + std::fmt::Debug {}
545
546impl<E> From<DbErr> for TransactionError<E>
547where
548 E: std::fmt::Display + std::fmt::Debug,
549{
550 fn from(e: DbErr) -> Self {
551 Self::Connection(e)
552 }
553}