1#![allow(unused_assignments)]
2use crate::{
3 AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, IsolationLevel,
4 QueryResult, Statement, StreamTrait, TransactionSession, TransactionStream, TransactionTrait,
5 debug_print, error::*,
6};
7#[cfg(feature = "sqlx-dep")]
8use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
9use futures_util::lock::Mutex;
10#[cfg(feature = "sqlx-dep")]
11use sqlx::TransactionManager;
12use std::{future::Future, pin::Pin, sync::Arc};
13use tracing::instrument;
14
/// Handle to an open database transaction on a shared connection.
///
/// Finish it explicitly with [`DatabaseTransaction::commit`] or
/// [`DatabaseTransaction::rollback`]; if it is dropped while still open, its
/// `Drop` impl starts a rollback.
pub struct DatabaseTransaction {
    // Connection behind a mutex, shared via `Arc` (e.g. `TransactionTrait::begin`
    // hands a clone of it to the transaction it creates).
    conn: Arc<Mutex<InnerConnection>>,
    // Backend reported by `get_database_backend`.
    backend: DbBackend,
    // Set on creation, cleared once `commit`/`rollback` succeed; `start_rollback`
    // only acts while this is still true.
    open: bool,
    // Optional metrics hook forwarded to statement execution and streaming.
    metric_callback: Option<crate::metric::Callback>,
}
25
26impl std::fmt::Debug for DatabaseTransaction {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 write!(f, "DatabaseTransaction")
29 }
30}
31
impl DatabaseTransaction {
    /// Begins a transaction on `conn` and returns a handle that owns it.
    ///
    /// `isolation_level` and `access_mode` are passed to the driver's
    /// `set_transaction_config`. For MySQL and SQLite that call is made *before*
    /// the driver-level `begin`; for Postgres it is made *after* `begin`, once the
    /// transaction is already open. The rusqlite, mock and proxy arms call `begin()`
    /// without either option.
    ///
    /// # Errors
    ///
    /// Propagates any error from configuring or beginning the transaction, and
    /// returns `conn_err("Disconnected")` when the connection variant is not handled
    /// by any enabled feature. With the `sync` feature, a poisoned mutex yields
    /// `DbErr::MutexPoisonError`.
    #[instrument(level = "trace", skip(metric_callback))]
    pub(crate) async fn begin(
        conn: Arc<Mutex<InnerConnection>>,
        backend: DbBackend,
        metric_callback: Option<crate::metric::Callback>,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
    ) -> Result<DatabaseTransaction, DbErr> {
        // Built up front and marked `open`: if a step below fails, `res` is dropped
        // on the early return and its `Drop` impl calls `start_rollback`.
        let res = DatabaseTransaction {
            conn,
            backend,
            open: true,
            metric_callback,
        };
        // Scoped so the lock guard (which borrows `res.conn`) is released before
        // `res` is moved into the return value below.
        {
            #[cfg(not(feature = "sync"))]
            let conn = &mut *res.conn.lock().await;
            #[cfg(feature = "sync")]
            let conn = &mut *res.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

            match conn {
                #[cfg(feature = "sqlx-mysql")]
                InnerConnection::MySql(c) => {
                    // MySQL: apply the configuration first, then begin.
                    crate::driver::sqlx_mysql::set_transaction_config(
                        c,
                        isolation_level,
                        access_mode,
                    )
                    .await?;
                    <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c, None)
                        .await
                        .map_err(sqlx_error_to_query_err)
                }
                #[cfg(feature = "sqlx-postgres")]
                InnerConnection::Postgres(c) => {
                    // Postgres: begin first, then configure the open transaction.
                    <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c, None)
                        .await
                        .map_err(sqlx_error_to_query_err)?;
                    crate::driver::sqlx_postgres::set_transaction_config(
                        c,
                        isolation_level,
                        access_mode,
                    )
                    .await
                }
                #[cfg(feature = "sqlx-sqlite")]
                InnerConnection::Sqlite(c) => {
                    // SQLite: apply the configuration first, then begin.
                    crate::driver::sqlx_sqlite::set_transaction_config(
                        c,
                        isolation_level,
                        access_mode,
                    )
                    .await?;
                    <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c, None)
                        .await
                        .map_err(sqlx_error_to_query_err)
                }
                #[cfg(feature = "rusqlite")]
                InnerConnection::Rusqlite(c) => c.begin(),
                #[cfg(feature = "mock")]
                InnerConnection::Mock(c) => {
                    c.begin();
                    Ok(())
                }
                #[cfg(feature = "proxy")]
                InnerConnection::Proxy(c) => {
                    c.begin().await;
                    Ok(())
                }
                #[allow(unreachable_patterns)]
                _ => Err(conn_err("Disconnected")),
            }?
        };

        Ok(res)
    }

    /// Runs `callback` against this transaction, then commits if it returned `Ok`
    /// or rolls back if it returned `Err`.
    ///
    /// The callback's own error is returned as `TransactionError::Transaction`.
    /// A failure to commit or roll back is returned as
    /// `TransactionError::Connection` and takes precedence over the callback's result.
    #[instrument(level = "trace", skip(callback))]
    pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
    where
        F: for<'b> FnOnce(
                &'b DatabaseTransaction,
            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
            + Send,
        T: Send,
        E: std::fmt::Display + std::fmt::Debug + Send,
    {
        let res = callback(&self).await.map_err(TransactionError::Transaction);
        if res.is_ok() {
            self.commit().await.map_err(TransactionError::Connection)?;
        } else {
            self.rollback()
                .await
                .map_err(TransactionError::Connection)?;
        }
        res
    }

    /// Commits the transaction, consuming the handle.
    ///
    /// `open` is cleared only after the driver commit succeeds. If the commit
    /// fails, the handle is dropped while still marked open, so `Drop` starts a
    /// rollback.
    ///
    /// # Errors
    ///
    /// Propagates the driver's commit error, or `conn_err("Disconnected")` for an
    /// unhandled connection variant. With the `sync` feature, a poisoned mutex
    /// yields `DbErr::MutexPoisonError`.
    #[instrument(level = "trace")]
    // Presumably needed for builds with no driver feature enabled: the match then
    // always errors, which makes the tail unreachable and `mut` unused.
    #[allow(unreachable_code, unused_mut)]
    pub async fn commit(mut self) -> Result<(), DbErr> {
        #[cfg(not(feature = "sync"))]
        let conn = &mut *self.conn.lock().await;
        #[cfg(feature = "sync")]
        let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

        match conn {
            #[cfg(feature = "sqlx-mysql")]
            InnerConnection::MySql(c) => {
                <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "sqlx-postgres")]
            InnerConnection::Postgres(c) => {
                <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "sqlx-sqlite")]
            InnerConnection::Sqlite(c) => {
                <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "rusqlite")]
            InnerConnection::Rusqlite(c) => c.commit(),
            #[cfg(feature = "mock")]
            InnerConnection::Mock(c) => {
                c.commit();
                Ok(())
            }
            #[cfg(feature = "proxy")]
            InnerConnection::Proxy(c) => {
                c.commit().await;
                Ok(())
            }
            #[allow(unreachable_patterns)]
            _ => Err(conn_err("Disconnected")),
        }?;
        // Reached only on success: mark closed so `Drop` does not roll back.
        self.open = false;
        Ok(())
    }

    /// Rolls the transaction back, consuming the handle.
    ///
    /// `open` is cleared only after the driver rollback succeeds. If it fails, the
    /// handle is dropped while still marked open, so `Drop` calls `start_rollback`.
    ///
    /// # Errors
    ///
    /// Propagates the driver's rollback error, or `conn_err("Disconnected")` for an
    /// unhandled connection variant. With the `sync` feature, a poisoned mutex
    /// yields `DbErr::MutexPoisonError`.
    #[instrument(level = "trace")]
    // Presumably needed for builds with no driver feature enabled (see `commit`).
    #[allow(unreachable_code, unused_mut)]
    pub async fn rollback(mut self) -> Result<(), DbErr> {
        #[cfg(not(feature = "sync"))]
        let conn = &mut *self.conn.lock().await;
        #[cfg(feature = "sync")]
        let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

        match conn {
            #[cfg(feature = "sqlx-mysql")]
            InnerConnection::MySql(c) => {
                <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "sqlx-postgres")]
            InnerConnection::Postgres(c) => {
                <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "sqlx-sqlite")]
            InnerConnection::Sqlite(c) => {
                <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "rusqlite")]
            InnerConnection::Rusqlite(c) => c.rollback(),
            #[cfg(feature = "mock")]
            InnerConnection::Mock(c) => {
                c.rollback();
                Ok(())
            }
            #[cfg(feature = "proxy")]
            InnerConnection::Proxy(c) => {
                c.rollback().await;
                Ok(())
            }
            #[allow(unreachable_patterns)]
            _ => Err(conn_err("Disconnected")),
        }?;
        // Reached only on success: mark closed so `Drop` does not roll back again.
        self.open = false;
        Ok(())
    }

    /// Synchronously starts rolling back the transaction if it is still open.
    /// Called from `Drop`.
    ///
    /// This function is not `async`, so it takes the connection with a non-blocking
    /// `try_lock`. It calls each driver's non-async entry point: `start_rollback`,
    /// or `rollback` for the mock backend.
    /// NOTE(review): for the sqlx drivers the ROLLBACK itself presumably completes
    /// on the connection's next async use; confirm against sqlx
    /// `TransactionManager::start_rollback`.
    ///
    /// Returns `Ok(())` without doing anything when `open` is false.
    ///
    /// # Errors
    ///
    /// `conn_err("Dropping a locked Transaction")` if the mutex is currently held
    /// elsewhere; any error from the rusqlite driver's `start_rollback`; or
    /// `conn_err("Disconnected")` for an unhandled connection variant.
    #[instrument(level = "trace")]
    fn start_rollback(&mut self) -> Result<(), DbErr> {
        if self.open {
            if let Some(mut conn) = self.conn.try_lock() {
                match &mut *conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(c) => {
                        <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(c) => {
                        <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(c) => {
                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(c) => {
                        c.start_rollback()?;
                    }
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(c) => {
                        c.rollback();
                    }
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(c) => {
                        c.start_rollback();
                    }
                    #[allow(unreachable_patterns)]
                    _ => return Err(conn_err("Disconnected")),
                }
            } else {
                // The mutex is held by someone else right now, so no rollback can
                // be started from this synchronous context.
                return Err(conn_err("Dropping a locked Transaction"));
            }
        }
        Ok(())
    }
}
271
272#[async_trait::async_trait]
273impl TransactionSession for DatabaseTransaction {
274 async fn commit(self) -> Result<(), DbErr> {
275 self.commit().await
276 }
277
278 async fn rollback(self) -> Result<(), DbErr> {
279 self.rollback().await
280 }
281}
282
283impl Drop for DatabaseTransaction {
284 fn drop(&mut self) {
285 self.start_rollback().expect("Fail to rollback transaction");
286 }
287}
288
289#[async_trait::async_trait]
290impl ConnectionTrait for DatabaseTransaction {
291 fn get_database_backend(&self) -> DbBackend {
292 self.backend
294 }
295
296 #[instrument(level = "trace")]
297 #[allow(unused_variables)]
298 async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
299 debug_print!("{}", stmt);
300
301 #[cfg(not(feature = "sync"))]
302 let conn = &mut *self.conn.lock().await;
303 #[cfg(feature = "sync")]
304 let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
305
306 match conn {
307 #[cfg(feature = "sqlx-mysql")]
308 InnerConnection::MySql(conn) => {
309 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
310 let conn: &mut sqlx::MySqlConnection = &mut *conn;
311 crate::metric::metric!(self.metric_callback, &stmt, {
312 query.execute(conn).await.map(Into::into)
313 })
314 .map_err(sqlx_error_to_exec_err)
315 }
316 #[cfg(feature = "sqlx-postgres")]
317 InnerConnection::Postgres(conn) => {
318 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
319 let conn: &mut sqlx::PgConnection = &mut *conn;
320 crate::metric::metric!(self.metric_callback, &stmt, {
321 query.execute(conn).await.map(Into::into)
322 })
323 .map_err(sqlx_error_to_exec_err)
324 }
325 #[cfg(feature = "sqlx-sqlite")]
326 InnerConnection::Sqlite(conn) => {
327 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
328 let conn: &mut sqlx::SqliteConnection = &mut *conn;
329 crate::metric::metric!(self.metric_callback, &stmt, {
330 query.execute(conn).await.map(Into::into)
331 })
332 .map_err(sqlx_error_to_exec_err)
333 }
334 #[cfg(feature = "rusqlite")]
335 InnerConnection::Rusqlite(conn) => conn.execute(stmt, &self.metric_callback),
336 #[cfg(feature = "mock")]
337 InnerConnection::Mock(conn) => return conn.execute(stmt),
338 #[cfg(feature = "proxy")]
339 InnerConnection::Proxy(conn) => return conn.execute(stmt).await,
340 #[allow(unreachable_patterns)]
341 _ => Err(conn_err("Disconnected")),
342 }
343 }
344
345 #[instrument(level = "trace")]
346 #[allow(unused_variables)]
347 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
348 debug_print!("{}", sql);
349
350 #[cfg(not(feature = "sync"))]
351 let conn = &mut *self.conn.lock().await;
352 #[cfg(feature = "sync")]
353 let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
354
355 match conn {
356 #[cfg(feature = "sqlx-mysql")]
357 InnerConnection::MySql(conn) => {
358 let conn: &mut sqlx::MySqlConnection = &mut *conn;
359 sqlx::Executor::execute(conn, sql)
360 .await
361 .map(Into::into)
362 .map_err(sqlx_error_to_exec_err)
363 }
364 #[cfg(feature = "sqlx-postgres")]
365 InnerConnection::Postgres(conn) => {
366 let conn: &mut sqlx::PgConnection = &mut *conn;
367 sqlx::Executor::execute(conn, sql)
368 .await
369 .map(Into::into)
370 .map_err(sqlx_error_to_exec_err)
371 }
372 #[cfg(feature = "sqlx-sqlite")]
373 InnerConnection::Sqlite(conn) => {
374 let conn: &mut sqlx::SqliteConnection = &mut *conn;
375 sqlx::Executor::execute(conn, sql)
376 .await
377 .map(Into::into)
378 .map_err(sqlx_error_to_exec_err)
379 }
380 #[cfg(feature = "rusqlite")]
381 InnerConnection::Rusqlite(conn) => conn.execute_unprepared(sql),
382 #[cfg(feature = "mock")]
383 InnerConnection::Mock(conn) => {
384 let db_backend = conn.get_database_backend();
385 let stmt = Statement::from_string(db_backend, sql);
386 conn.execute(stmt)
387 }
388 #[cfg(feature = "proxy")]
389 InnerConnection::Proxy(conn) => {
390 let db_backend = conn.get_database_backend();
391 let stmt = Statement::from_string(db_backend, sql);
392 conn.execute(stmt).await
393 }
394 #[allow(unreachable_patterns)]
395 _ => Err(conn_err("Disconnected")),
396 }
397 }
398
399 #[instrument(level = "trace")]
400 #[allow(unused_variables)]
401 async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
402 debug_print!("{}", stmt);
403
404 #[cfg(not(feature = "sync"))]
405 let conn = &mut *self.conn.lock().await;
406 #[cfg(feature = "sync")]
407 let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
408
409 match conn {
410 #[cfg(feature = "sqlx-mysql")]
411 InnerConnection::MySql(conn) => {
412 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
413 let conn: &mut sqlx::MySqlConnection = &mut *conn;
414 crate::metric::metric!(self.metric_callback, &stmt, {
415 crate::sqlx_map_err_ignore_not_found(
416 query.fetch_one(conn).await.map(|row| Some(row.into())),
417 )
418 })
419 }
420 #[cfg(feature = "sqlx-postgres")]
421 InnerConnection::Postgres(conn) => {
422 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
423 let conn: &mut sqlx::PgConnection = &mut *conn;
424 crate::metric::metric!(self.metric_callback, &stmt, {
425 crate::sqlx_map_err_ignore_not_found(
426 query.fetch_one(conn).await.map(|row| Some(row.into())),
427 )
428 })
429 }
430 #[cfg(feature = "sqlx-sqlite")]
431 InnerConnection::Sqlite(conn) => {
432 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
433 let conn: &mut sqlx::SqliteConnection = &mut *conn;
434 crate::metric::metric!(self.metric_callback, &stmt, {
435 crate::sqlx_map_err_ignore_not_found(
436 query.fetch_one(conn).await.map(|row| Some(row.into())),
437 )
438 })
439 }
440 #[cfg(feature = "rusqlite")]
441 InnerConnection::Rusqlite(conn) => conn.query_one(stmt, &self.metric_callback),
442 #[cfg(feature = "mock")]
443 InnerConnection::Mock(conn) => return conn.query_one(stmt),
444 #[cfg(feature = "proxy")]
445 InnerConnection::Proxy(conn) => return conn.query_one(stmt).await,
446 #[allow(unreachable_patterns)]
447 _ => Err(conn_err("Disconnected")),
448 }
449 }
450
451 #[instrument(level = "trace")]
452 #[allow(unused_variables)]
453 async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
454 debug_print!("{}", stmt);
455
456 #[cfg(not(feature = "sync"))]
457 let conn = &mut *self.conn.lock().await;
458 #[cfg(feature = "sync")]
459 let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
460
461 match conn {
462 #[cfg(feature = "sqlx-mysql")]
463 InnerConnection::MySql(conn) => {
464 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
465 let conn: &mut sqlx::MySqlConnection = &mut *conn;
466 crate::metric::metric!(self.metric_callback, &stmt, {
467 query
468 .fetch_all(conn)
469 .await
470 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
471 .map_err(sqlx_error_to_query_err)
472 })
473 }
474 #[cfg(feature = "sqlx-postgres")]
475 InnerConnection::Postgres(conn) => {
476 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
477 let conn: &mut sqlx::PgConnection = &mut *conn;
478 crate::metric::metric!(self.metric_callback, &stmt, {
479 query
480 .fetch_all(conn)
481 .await
482 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
483 .map_err(sqlx_error_to_query_err)
484 })
485 }
486 #[cfg(feature = "sqlx-sqlite")]
487 InnerConnection::Sqlite(conn) => {
488 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
489 let conn: &mut sqlx::SqliteConnection = &mut *conn;
490 crate::metric::metric!(self.metric_callback, &stmt, {
491 query
492 .fetch_all(conn)
493 .await
494 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
495 .map_err(sqlx_error_to_query_err)
496 })
497 }
498 #[cfg(feature = "rusqlite")]
499 InnerConnection::Rusqlite(conn) => conn.query_all(stmt, &self.metric_callback),
500 #[cfg(feature = "mock")]
501 InnerConnection::Mock(conn) => return conn.query_all(stmt),
502 #[cfg(feature = "proxy")]
503 InnerConnection::Proxy(conn) => return conn.query_all(stmt).await,
504 #[allow(unreachable_patterns)]
505 _ => Err(conn_err("Disconnected")),
506 }
507 }
508}
509
510impl StreamTrait for DatabaseTransaction {
511 type Stream<'a> = TransactionStream<'a>;
512
513 fn get_database_backend(&self) -> DbBackend {
514 self.backend
515 }
516
517 #[instrument(level = "trace")]
518 fn stream_raw<'a>(
519 &'a self,
520 stmt: Statement,
521 ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
522 Box::pin(async move {
523 #[cfg(not(feature = "sync"))]
524 let conn = self.conn.lock().await;
525 #[cfg(feature = "sync")]
526 let conn = self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
527 Ok(crate::TransactionStream::build(
528 conn,
529 stmt,
530 self.metric_callback.clone(),
531 ))
532 })
533 }
534}
535
536#[async_trait::async_trait]
537impl TransactionTrait for DatabaseTransaction {
538 type Transaction = DatabaseTransaction;
539
540 #[instrument(level = "trace")]
541 async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
542 DatabaseTransaction::begin(
543 Arc::clone(&self.conn),
544 self.backend,
545 self.metric_callback.clone(),
546 None,
547 None,
548 )
549 .await
550 }
551
552 #[instrument(level = "trace")]
553 async fn begin_with_config(
554 &self,
555 isolation_level: Option<IsolationLevel>,
556 access_mode: Option<AccessMode>,
557 ) -> Result<DatabaseTransaction, DbErr> {
558 DatabaseTransaction::begin(
559 Arc::clone(&self.conn),
560 self.backend,
561 self.metric_callback.clone(),
562 isolation_level,
563 access_mode,
564 )
565 .await
566 }
567
568 #[instrument(level = "trace", skip(_callback))]
572 async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
573 where
574 F: for<'c> FnOnce(
575 &'c DatabaseTransaction,
576 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
577 + Send,
578 T: Send,
579 E: std::fmt::Display + std::fmt::Debug + Send,
580 {
581 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
582 transaction.run(_callback).await
583 }
584
585 #[instrument(level = "trace", skip(_callback))]
589 async fn transaction_with_config<F, T, E>(
590 &self,
591 _callback: F,
592 isolation_level: Option<IsolationLevel>,
593 access_mode: Option<AccessMode>,
594 ) -> Result<T, TransactionError<E>>
595 where
596 F: for<'c> FnOnce(
597 &'c DatabaseTransaction,
598 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
599 + Send,
600 T: Send,
601 E: std::fmt::Display + std::fmt::Debug + Send,
602 {
603 let transaction = self
604 .begin_with_config(isolation_level, access_mode)
605 .await
606 .map_err(TransactionError::Connection)?;
607 transaction.run(_callback).await
608 }
609}
610
/// Error from the closure-based transaction helpers (e.g. `transaction`,
/// `transaction_with_config`, and `DatabaseTransaction::run`).
///
/// It distinguishes database failures while managing the transaction from errors
/// returned by the user-supplied callback.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// Beginning, committing, or rolling back the transaction failed.
    Connection(DbErr),
    /// The user-supplied callback returned this error.
    Transaction(E),
}
619
620impl<E> std::fmt::Display for TransactionError<E>
621where
622 E: std::fmt::Display + std::fmt::Debug,
623{
624 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
625 match self {
626 TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
627 TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
628 }
629 }
630}
631
632impl<E> std::error::Error for TransactionError<E> where E: std::fmt::Display + std::fmt::Debug {}
633
634impl<E> From<DbErr> for TransactionError<E>
635where
636 E: std::fmt::Display + std::fmt::Debug,
637{
638 fn from(e: DbErr) -> Self {
639 Self::Connection(e)
640 }
641}