1use crate::{
2 AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, IsolationLevel,
3 QueryResult, Statement, StreamTrait, TransactionSession, TransactionStream, TransactionTrait,
4 debug_print, error::*,
5};
6#[cfg(feature = "sqlx-dep")]
7use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
8use futures_util::lock::Mutex;
9#[cfg(feature = "sqlx-dep")]
10use sqlx::TransactionManager;
11use std::{future::Future, pin::Pin, sync::Arc};
12use tracing::instrument;
13
/// Handle to a database transaction opened on a shared connection.
///
/// If the handle is dropped while the transaction is still open (neither
/// committed nor rolled back), its `Drop` impl requests a rollback.
pub struct DatabaseTransaction {
    // Connection shared with the parent connection / nested transactions
    // (nested `begin` clones this `Arc`).
    conn: Arc<Mutex<InnerConnection>>,
    // Backend reported by `get_database_backend`.
    backend: DbBackend,
    // `true` while the transaction still has to be committed or rolled back;
    // checked by `start_rollback` when the handle is dropped.
    open: bool,
    // Optional callback fed by `crate::metric::metric!` around sqlx statements.
    metric_callback: Option<crate::metric::Callback>,
}
24
25impl std::fmt::Debug for DatabaseTransaction {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 write!(f, "DatabaseTransaction")
28 }
29}
30
31impl DatabaseTransaction {
32 #[instrument(level = "trace", skip(metric_callback))]
33 pub(crate) async fn begin(
34 conn: Arc<Mutex<InnerConnection>>,
35 backend: DbBackend,
36 metric_callback: Option<crate::metric::Callback>,
37 isolation_level: Option<IsolationLevel>,
38 access_mode: Option<AccessMode>,
39 ) -> Result<DatabaseTransaction, DbErr> {
40 let res = DatabaseTransaction {
41 conn,
42 backend,
43 open: true,
44 metric_callback,
45 };
46 match *res.conn.lock().await {
47 #[cfg(feature = "sqlx-mysql")]
48 InnerConnection::MySql(ref mut c) => {
49 crate::driver::sqlx_mysql::set_transaction_config(c, isolation_level, access_mode)
51 .await?;
52 <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c, None)
53 .await
54 .map_err(sqlx_error_to_query_err)
55 }
56 #[cfg(feature = "sqlx-postgres")]
57 InnerConnection::Postgres(ref mut c) => {
58 <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c, None)
59 .await
60 .map_err(sqlx_error_to_query_err)?;
61 crate::driver::sqlx_postgres::set_transaction_config(
63 c,
64 isolation_level,
65 access_mode,
66 )
67 .await
68 }
69 #[cfg(feature = "sqlx-sqlite")]
70 InnerConnection::Sqlite(ref mut c) => {
71 crate::driver::sqlx_sqlite::set_transaction_config(c, isolation_level, access_mode)
73 .await?;
74 <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c, None)
75 .await
76 .map_err(sqlx_error_to_query_err)
77 }
78 #[cfg(feature = "mock")]
79 InnerConnection::Mock(ref mut c) => {
80 c.begin();
81 Ok(())
82 }
83 #[cfg(feature = "proxy")]
84 InnerConnection::Proxy(ref mut c) => {
85 c.begin().await;
86 Ok(())
87 }
88 #[allow(unreachable_patterns)]
89 _ => Err(conn_err("Disconnected")),
90 }?;
91 Ok(res)
92 }
93
94 #[instrument(level = "trace", skip(callback))]
97 pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
98 where
99 F: for<'b> FnOnce(
100 &'b DatabaseTransaction,
101 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
102 + Send,
103 T: Send,
104 E: std::fmt::Display + std::fmt::Debug + Send,
105 {
106 let res = callback(&self).await.map_err(TransactionError::Transaction);
107 if res.is_ok() {
108 self.commit().await.map_err(TransactionError::Connection)?;
109 } else {
110 self.rollback()
111 .await
112 .map_err(TransactionError::Connection)?;
113 }
114 res
115 }
116
117 #[instrument(level = "trace")]
119 #[allow(unreachable_code, unused_mut)]
120 pub async fn commit(mut self) -> Result<(), DbErr> {
121 match *self.conn.lock().await {
122 #[cfg(feature = "sqlx-mysql")]
123 InnerConnection::MySql(ref mut c) => {
124 <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
125 .await
126 .map_err(sqlx_error_to_query_err)
127 }
128 #[cfg(feature = "sqlx-postgres")]
129 InnerConnection::Postgres(ref mut c) => {
130 <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
131 .await
132 .map_err(sqlx_error_to_query_err)
133 }
134 #[cfg(feature = "sqlx-sqlite")]
135 InnerConnection::Sqlite(ref mut c) => {
136 <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
137 .await
138 .map_err(sqlx_error_to_query_err)
139 }
140 #[cfg(feature = "mock")]
141 InnerConnection::Mock(ref mut c) => {
142 c.commit();
143 Ok(())
144 }
145 #[cfg(feature = "proxy")]
146 InnerConnection::Proxy(ref mut c) => {
147 c.commit().await;
148 Ok(())
149 }
150 #[allow(unreachable_patterns)]
151 _ => Err(conn_err("Disconnected")),
152 }?;
153 self.open = false;
154 Ok(())
155 }
156
157 #[instrument(level = "trace")]
159 #[allow(unreachable_code, unused_mut)]
160 pub async fn rollback(mut self) -> Result<(), DbErr> {
161 match *self.conn.lock().await {
162 #[cfg(feature = "sqlx-mysql")]
163 InnerConnection::MySql(ref mut c) => {
164 <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
165 .await
166 .map_err(sqlx_error_to_query_err)
167 }
168 #[cfg(feature = "sqlx-postgres")]
169 InnerConnection::Postgres(ref mut c) => {
170 <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
171 .await
172 .map_err(sqlx_error_to_query_err)
173 }
174 #[cfg(feature = "sqlx-sqlite")]
175 InnerConnection::Sqlite(ref mut c) => {
176 <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
177 .await
178 .map_err(sqlx_error_to_query_err)
179 }
180 #[cfg(feature = "mock")]
181 InnerConnection::Mock(ref mut c) => {
182 c.rollback();
183 Ok(())
184 }
185 #[cfg(feature = "proxy")]
186 InnerConnection::Proxy(ref mut c) => {
187 c.rollback().await;
188 Ok(())
189 }
190 #[allow(unreachable_patterns)]
191 _ => Err(conn_err("Disconnected")),
192 }?;
193 self.open = false;
194 Ok(())
195 }
196
197 #[instrument(level = "trace")]
199 fn start_rollback(&mut self) -> Result<(), DbErr> {
200 if self.open {
201 if let Some(mut conn) = self.conn.try_lock() {
202 match &mut *conn {
203 #[cfg(feature = "sqlx-mysql")]
204 InnerConnection::MySql(c) => {
205 <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
206 }
207 #[cfg(feature = "sqlx-postgres")]
208 InnerConnection::Postgres(c) => {
209 <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
210 }
211 #[cfg(feature = "sqlx-sqlite")]
212 InnerConnection::Sqlite(c) => {
213 <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
214 }
215 #[cfg(feature = "mock")]
216 InnerConnection::Mock(c) => {
217 c.rollback();
218 }
219 #[cfg(feature = "proxy")]
220 InnerConnection::Proxy(c) => {
221 c.start_rollback();
222 }
223 #[allow(unreachable_patterns)]
224 _ => return Err(conn_err("Disconnected")),
225 }
226 } else {
227 return Err(conn_err("Dropping a locked Transaction"));
229 }
230 }
231 Ok(())
232 }
233}
234
235#[async_trait::async_trait]
236impl TransactionSession for DatabaseTransaction {
237 async fn commit(self) -> Result<(), DbErr> {
238 self.commit().await
239 }
240
241 async fn rollback(self) -> Result<(), DbErr> {
242 self.rollback().await
243 }
244}
245
246impl Drop for DatabaseTransaction {
247 fn drop(&mut self) {
248 self.start_rollback().expect("Fail to rollback transaction");
249 }
250}
251
252#[async_trait::async_trait]
253impl ConnectionTrait for DatabaseTransaction {
254 fn get_database_backend(&self) -> DbBackend {
255 self.backend
257 }
258
259 #[instrument(level = "trace")]
260 #[allow(unused_variables)]
261 async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
262 debug_print!("{}", stmt);
263
264 match &mut *self.conn.lock().await {
265 #[cfg(feature = "sqlx-mysql")]
266 InnerConnection::MySql(conn) => {
267 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
268 let conn: &mut sqlx::MySqlConnection = &mut *conn;
269 crate::metric::metric!(self.metric_callback, &stmt, {
270 query.execute(conn).await.map(Into::into)
271 })
272 .map_err(sqlx_error_to_exec_err)
273 }
274 #[cfg(feature = "sqlx-postgres")]
275 InnerConnection::Postgres(conn) => {
276 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
277 let conn: &mut sqlx::PgConnection = &mut *conn;
278 crate::metric::metric!(self.metric_callback, &stmt, {
279 query.execute(conn).await.map(Into::into)
280 })
281 .map_err(sqlx_error_to_exec_err)
282 }
283 #[cfg(feature = "sqlx-sqlite")]
284 InnerConnection::Sqlite(conn) => {
285 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
286 let conn: &mut sqlx::SqliteConnection = &mut *conn;
287 crate::metric::metric!(self.metric_callback, &stmt, {
288 query.execute(conn).await.map(Into::into)
289 })
290 .map_err(sqlx_error_to_exec_err)
291 }
292 #[cfg(feature = "mock")]
293 InnerConnection::Mock(conn) => return conn.execute(stmt),
294 #[cfg(feature = "proxy")]
295 InnerConnection::Proxy(conn) => return conn.execute(stmt).await,
296 #[allow(unreachable_patterns)]
297 _ => Err(conn_err("Disconnected")),
298 }
299 }
300
301 #[instrument(level = "trace")]
302 #[allow(unused_variables)]
303 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
304 debug_print!("{}", sql);
305
306 match &mut *self.conn.lock().await {
307 #[cfg(feature = "sqlx-mysql")]
308 InnerConnection::MySql(conn) => {
309 let conn: &mut sqlx::MySqlConnection = &mut *conn;
310 sqlx::Executor::execute(conn, sql)
311 .await
312 .map(Into::into)
313 .map_err(sqlx_error_to_exec_err)
314 }
315 #[cfg(feature = "sqlx-postgres")]
316 InnerConnection::Postgres(conn) => {
317 let conn: &mut sqlx::PgConnection = &mut *conn;
318 sqlx::Executor::execute(conn, sql)
319 .await
320 .map(Into::into)
321 .map_err(sqlx_error_to_exec_err)
322 }
323 #[cfg(feature = "sqlx-sqlite")]
324 InnerConnection::Sqlite(conn) => {
325 let conn: &mut sqlx::SqliteConnection = &mut *conn;
326 sqlx::Executor::execute(conn, sql)
327 .await
328 .map(Into::into)
329 .map_err(sqlx_error_to_exec_err)
330 }
331 #[cfg(feature = "mock")]
332 InnerConnection::Mock(conn) => {
333 let db_backend = conn.get_database_backend();
334 let stmt = Statement::from_string(db_backend, sql);
335 conn.execute(stmt)
336 }
337 #[cfg(feature = "proxy")]
338 InnerConnection::Proxy(conn) => {
339 let db_backend = conn.get_database_backend();
340 let stmt = Statement::from_string(db_backend, sql);
341 conn.execute(stmt).await
342 }
343 #[allow(unreachable_patterns)]
344 _ => Err(conn_err("Disconnected")),
345 }
346 }
347
348 #[instrument(level = "trace")]
349 #[allow(unused_variables)]
350 async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
351 debug_print!("{}", stmt);
352
353 match &mut *self.conn.lock().await {
354 #[cfg(feature = "sqlx-mysql")]
355 InnerConnection::MySql(conn) => {
356 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
357 let conn: &mut sqlx::MySqlConnection = &mut *conn;
358 crate::metric::metric!(self.metric_callback, &stmt, {
359 crate::sqlx_map_err_ignore_not_found(
360 query.fetch_one(conn).await.map(|row| Some(row.into())),
361 )
362 })
363 }
364 #[cfg(feature = "sqlx-postgres")]
365 InnerConnection::Postgres(conn) => {
366 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
367 let conn: &mut sqlx::PgConnection = &mut *conn;
368 crate::metric::metric!(self.metric_callback, &stmt, {
369 crate::sqlx_map_err_ignore_not_found(
370 query.fetch_one(conn).await.map(|row| Some(row.into())),
371 )
372 })
373 }
374 #[cfg(feature = "sqlx-sqlite")]
375 InnerConnection::Sqlite(conn) => {
376 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
377 let conn: &mut sqlx::SqliteConnection = &mut *conn;
378 crate::metric::metric!(self.metric_callback, &stmt, {
379 crate::sqlx_map_err_ignore_not_found(
380 query.fetch_one(conn).await.map(|row| Some(row.into())),
381 )
382 })
383 }
384 #[cfg(feature = "mock")]
385 InnerConnection::Mock(conn) => return conn.query_one(stmt),
386 #[cfg(feature = "proxy")]
387 InnerConnection::Proxy(conn) => return conn.query_one(stmt).await,
388 #[allow(unreachable_patterns)]
389 _ => Err(conn_err("Disconnected")),
390 }
391 }
392
393 #[instrument(level = "trace")]
394 #[allow(unused_variables)]
395 async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
396 debug_print!("{}", stmt);
397
398 match &mut *self.conn.lock().await {
399 #[cfg(feature = "sqlx-mysql")]
400 InnerConnection::MySql(conn) => {
401 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
402 let conn: &mut sqlx::MySqlConnection = &mut *conn;
403 crate::metric::metric!(self.metric_callback, &stmt, {
404 query
405 .fetch_all(conn)
406 .await
407 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
408 .map_err(sqlx_error_to_query_err)
409 })
410 }
411 #[cfg(feature = "sqlx-postgres")]
412 InnerConnection::Postgres(conn) => {
413 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
414 let conn: &mut sqlx::PgConnection = &mut *conn;
415 crate::metric::metric!(self.metric_callback, &stmt, {
416 query
417 .fetch_all(conn)
418 .await
419 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
420 .map_err(sqlx_error_to_query_err)
421 })
422 }
423 #[cfg(feature = "sqlx-sqlite")]
424 InnerConnection::Sqlite(conn) => {
425 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
426 let conn: &mut sqlx::SqliteConnection = &mut *conn;
427 crate::metric::metric!(self.metric_callback, &stmt, {
428 query
429 .fetch_all(conn)
430 .await
431 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
432 .map_err(sqlx_error_to_query_err)
433 })
434 }
435 #[cfg(feature = "mock")]
436 InnerConnection::Mock(conn) => return conn.query_all(stmt),
437 #[cfg(feature = "proxy")]
438 InnerConnection::Proxy(conn) => return conn.query_all(stmt).await,
439 #[allow(unreachable_patterns)]
440 _ => Err(conn_err("Disconnected")),
441 }
442 }
443}
444
445impl StreamTrait for DatabaseTransaction {
446 type Stream<'a> = TransactionStream<'a>;
447
448 fn get_database_backend(&self) -> DbBackend {
449 self.backend
450 }
451
452 #[instrument(level = "trace")]
453 fn stream_raw<'a>(
454 &'a self,
455 stmt: Statement,
456 ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
457 Box::pin(async move {
458 let conn = self.conn.lock().await;
459 Ok(crate::TransactionStream::build(
460 conn,
461 stmt,
462 self.metric_callback.clone(),
463 ))
464 })
465 }
466}
467
468#[async_trait::async_trait]
469impl TransactionTrait for DatabaseTransaction {
470 type Transaction = DatabaseTransaction;
471
472 #[instrument(level = "trace")]
473 async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
474 DatabaseTransaction::begin(
475 Arc::clone(&self.conn),
476 self.backend,
477 self.metric_callback.clone(),
478 None,
479 None,
480 )
481 .await
482 }
483
484 #[instrument(level = "trace")]
485 async fn begin_with_config(
486 &self,
487 isolation_level: Option<IsolationLevel>,
488 access_mode: Option<AccessMode>,
489 ) -> Result<DatabaseTransaction, DbErr> {
490 DatabaseTransaction::begin(
491 Arc::clone(&self.conn),
492 self.backend,
493 self.metric_callback.clone(),
494 isolation_level,
495 access_mode,
496 )
497 .await
498 }
499
500 #[instrument(level = "trace", skip(_callback))]
504 async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
505 where
506 F: for<'c> FnOnce(
507 &'c DatabaseTransaction,
508 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
509 + Send,
510 T: Send,
511 E: std::fmt::Display + std::fmt::Debug + Send,
512 {
513 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
514 transaction.run(_callback).await
515 }
516
517 #[instrument(level = "trace", skip(_callback))]
521 async fn transaction_with_config<F, T, E>(
522 &self,
523 _callback: F,
524 isolation_level: Option<IsolationLevel>,
525 access_mode: Option<AccessMode>,
526 ) -> Result<T, TransactionError<E>>
527 where
528 F: for<'c> FnOnce(
529 &'c DatabaseTransaction,
530 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
531 + Send,
532 T: Send,
533 E: std::fmt::Display + std::fmt::Debug + Send,
534 {
535 let transaction = self
536 .begin_with_config(isolation_level, access_mode)
537 .await
538 .map_err(TransactionError::Connection)?;
539 transaction.run(_callback).await
540 }
541}
542
/// Error produced by the transaction helpers (`transaction`, `run`, ...).
///
/// It separates failures of the database layer from errors returned by the
/// user-supplied callback.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// A [`DbErr`] from the database layer, e.g. while beginning, committing or
    /// rolling back the transaction.
    Connection(DbErr),
    /// The error returned by the user-supplied transaction callback.
    Transaction(E),
}
551
552impl<E> std::fmt::Display for TransactionError<E>
553where
554 E: std::fmt::Display + std::fmt::Debug,
555{
556 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
557 match self {
558 TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
559 TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
560 }
561 }
562}
563
564impl<E> std::error::Error for TransactionError<E> where E: std::fmt::Display + std::fmt::Debug {}
565
566impl<E> From<DbErr> for TransactionError<E>
567where
568 E: std::fmt::Display + std::fmt::Debug,
569{
570 fn from(e: DbErr) -> Self {
571 Self::Connection(e)
572 }
573}