// sea_orm/database/transaction.rs

1#![allow(unused_assignments)]
2use crate::{
3    AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, IsolationLevel,
4    QueryResult, Statement, StreamTrait, TransactionSession, TransactionStream, TransactionTrait,
5    debug_print, error::*,
6};
7#[cfg(feature = "sqlx-dep")]
8use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
9use futures_util::lock::Mutex;
10#[cfg(feature = "sqlx-dep")]
11use sqlx::TransactionManager;
12use std::{future::Future, pin::Pin, sync::Arc};
13use tracing::instrument;
14
/// Defines a database transaction, whether it is an open transaction and the type of
/// backend to use.
/// Under the hood, a Transaction is just a wrapper for a connection where
/// START TRANSACTION has been executed.
pub struct DatabaseTransaction {
    // Connection handle, shared (via `Arc::clone`) with nested transactions begun from
    // this one; every operation locks it for its duration.
    conn: Arc<Mutex<InnerConnection>>,
    // Cached backend so `get_database_backend` does not have to take the lock.
    backend: DbBackend,
    // `true` until `commit`/`rollback` succeeds; while `true`, dropping the handle
    // queues a rollback through `start_rollback`.
    open: bool,
    // Optional metrics hook passed to `metric!` for each statement and inherited by
    // nested transactions and streams.
    metric_callback: Option<crate::metric::Callback>,
}
25
26impl std::fmt::Debug for DatabaseTransaction {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        write!(f, "DatabaseTransaction")
29    }
30}
31
impl DatabaseTransaction {
    /// Begins a new transaction on `conn` and returns a handle to it.
    ///
    /// `isolation_level` and `access_mode` are applied through each driver's
    /// `set_transaction_config`, at the point in the sequence that backend needs
    /// (see the per-backend comments in the body).
    ///
    /// # Errors
    ///
    /// Returns the driver error if configuring or beginning the transaction fails, or
    /// a "Disconnected" connection error when `conn` matches no enabled backend. On
    /// error the handle built here is dropped while still marked open, so its `Drop`
    /// impl queues a rollback.
    #[instrument(level = "trace", skip(metric_callback))]
    pub(crate) async fn begin(
        conn: Arc<Mutex<InnerConnection>>,
        backend: DbBackend,
        metric_callback: Option<crate::metric::Callback>,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
    ) -> Result<DatabaseTransaction, DbErr> {
        // Built before BEGIN so the async block below can lock `res.conn`.
        let res = DatabaseTransaction {
            conn,
            backend,
            open: true,
            metric_callback,
        };

        // Run BEGIN inside a `sea_orm.begin` database span.
        let begin_result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
            "sea_orm.begin",
            backend,
            "BEGIN",
            record_stmt = false,
            async {
                #[cfg(not(feature = "sync"))]
                let conn = &mut *res.conn.lock().await;
                #[cfg(feature = "sync")]
                let conn = &mut *res.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

                match conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(c) => {
                        // in MySQL SET TRANSACTION operations must be executed before transaction start
                        crate::driver::sqlx_mysql::set_transaction_config(
                            c,
                            isolation_level,
                            access_mode,
                        )
                        .await?;
                        <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c, None)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(c) => {
                        <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c, None)
                            .await
                            .map_err(sqlx_error_to_query_err)?;
                        // in PostgreSQL SET TRANSACTION operations must be executed inside transaction
                        crate::driver::sqlx_postgres::set_transaction_config(
                            c,
                            isolation_level,
                            access_mode,
                        )
                        .await
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(c) => {
                        // in SQLite isolation level and access mode are global settings
                        crate::driver::sqlx_sqlite::set_transaction_config(
                            c,
                            isolation_level,
                            access_mode,
                        )
                        .await?;
                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c, None)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    // NOTE(review): rusqlite, mock and proxy ignore isolation_level/access_mode here.
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(c) => c.begin(),
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(c) => {
                        c.begin();
                        Ok(())
                    }
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(c) => {
                        c.begin().await;
                        Ok(())
                    }
                    // Reached only when the connection variant has no enabled backend feature.
                    #[allow(unreachable_patterns)]
                    _ => Err(conn_err("Disconnected")),
                }
            }
        );

        // On error `res` is dropped here with `open == true`, which queues a rollback.
        begin_result?;
        Ok(res)
    }

    /// Runs a transaction to completion passing through the result.
    /// Rolling back the transaction on encountering an error.
    ///
    /// If `callback` returns `Ok`, the transaction is committed; otherwise it is
    /// rolled back. A commit/rollback failure is returned as
    /// [`TransactionError::Connection`] and takes precedence over the callback's
    /// result. A callback error is returned as [`TransactionError::Transaction`].
    #[instrument(level = "trace", skip(callback))]
    pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
    where
        F: for<'b> FnOnce(
                &'b DatabaseTransaction,
            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
            + Send,
        T: Send,
        E: std::fmt::Display + std::fmt::Debug + Send,
    {
        let res = callback(&self).await.map_err(TransactionError::Transaction);
        if res.is_ok() {
            self.commit().await.map_err(TransactionError::Connection)?;
        } else {
            self.rollback()
                .await
                .map_err(TransactionError::Connection)?;
        }
        res
    }

    /// Commit a transaction
    ///
    /// Consumes the handle. On success `open` is cleared, so the subsequent drop
    /// does nothing. On failure `open` stays `true`, so the drop queues a rollback.
    ///
    /// # Errors
    ///
    /// Returns the driver's commit error, or a "Disconnected" connection error when
    /// the connection matches no enabled backend.
    #[instrument(level = "trace")]
    // NOTE(review): presumably silences lints that only fire under some feature combinations.
    #[allow(unreachable_code, unused_mut)]
    pub async fn commit(mut self) -> Result<(), DbErr> {
        let result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
            "sea_orm.commit",
            self.backend,
            "COMMIT",
            record_stmt = false,
            async {
                #[cfg(not(feature = "sync"))]
                let conn = &mut *self.conn.lock().await;
                #[cfg(feature = "sync")]
                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

                match conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(c) => {
                        <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(c) => {
                        <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(c) => {
                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(c) => c.commit(),
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(c) => {
                        c.commit();
                        Ok(())
                    }
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(c) => {
                        c.commit().await;
                        Ok(())
                    }
                    #[allow(unreachable_patterns)]
                    _ => Err(conn_err("Disconnected")),
                }
            }
        );

        result?;
        self.open = false; // read by start_rollback
        Ok(())
    }

    /// Rolls back a transaction explicitly
    ///
    /// Consumes the handle. On success `open` is cleared, so the subsequent drop
    /// does nothing. On failure `open` stays `true`, so the drop queues another
    /// rollback attempt.
    ///
    /// # Errors
    ///
    /// Returns the driver's rollback error, or a "Disconnected" connection error when
    /// the connection matches no enabled backend.
    #[instrument(level = "trace")]
    // NOTE(review): presumably silences lints that only fire under some feature combinations.
    #[allow(unreachable_code, unused_mut)]
    pub async fn rollback(mut self) -> Result<(), DbErr> {
        let result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
            "sea_orm.rollback",
            self.backend,
            "ROLLBACK",
            record_stmt = false,
            async {
                #[cfg(not(feature = "sync"))]
                let conn = &mut *self.conn.lock().await;
                #[cfg(feature = "sync")]
                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

                match conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(c) => {
                        <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(c) => {
                        <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(c) => {
                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(c) => c.rollback(),
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(c) => {
                        c.rollback();
                        Ok(())
                    }
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(c) => {
                        c.rollback().await;
                        Ok(())
                    }
                    #[allow(unreachable_patterns)]
                    _ => Err(conn_err("Disconnected")),
                }
            }
        );

        result?;
        self.open = false; // read by start_rollback
        Ok(())
    }

    /// Synchronously requests a rollback if the transaction is still open.
    ///
    /// Called from `Drop`, which cannot `.await`. It therefore uses `try_lock` instead
    /// of `lock().await`. For the sqlx backends the rollback is queued and will be
    /// performed on the next async operation, like returning the connection to the pool.
    ///
    /// # Errors
    ///
    /// Returns a connection error if the connection is currently locked ("Dropping a
    /// locked Transaction") or matches no enabled backend ("Disconnected"). It also
    /// propagates any error from rusqlite's `start_rollback`.
    #[instrument(level = "trace")]
    fn start_rollback(&mut self) -> Result<(), DbErr> {
        if self.open {
            if let Some(mut conn) = self.conn.try_lock() {
                match &mut *conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(c) => {
                        <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(c) => {
                        <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(c) => {
                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(c) => {
                        c.start_rollback()?;
                    }
                    // The mock's rollback is synchronous, so it is performed immediately.
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(c) => {
                        c.rollback();
                    }
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(c) => {
                        c.start_rollback();
                    }
                    #[allow(unreachable_patterns)]
                    _ => return Err(conn_err("Disconnected")),
                }
            } else {
                //this should never happen
                return Err(conn_err("Dropping a locked Transaction"));
            }
        }
        Ok(())
    }
}
299
300#[async_trait::async_trait]
301impl TransactionSession for DatabaseTransaction {
302    async fn commit(self) -> Result<(), DbErr> {
303        self.commit().await
304    }
305
306    async fn rollback(self) -> Result<(), DbErr> {
307        self.rollback().await
308    }
309}
310
311impl Drop for DatabaseTransaction {
312    fn drop(&mut self) {
313        self.start_rollback().expect("Fail to rollback transaction");
314    }
315}
316
317#[async_trait::async_trait]
318impl ConnectionTrait for DatabaseTransaction {
319    fn get_database_backend(&self) -> DbBackend {
320        // this way we don't need to lock just to know the backend
321        self.backend
322    }
323
324    #[instrument(level = "trace")]
325    #[allow(unused_variables)]
326    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
327        debug_print!("{}", stmt);
328
329        super::tracing_spans::with_db_span!(
330            "sea_orm.execute",
331            self.backend,
332            stmt.sql.as_str(),
333            record_stmt = true,
334            async {
335                #[cfg(not(feature = "sync"))]
336                let conn = &mut *self.conn.lock().await;
337                #[cfg(feature = "sync")]
338                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
339
340                match conn {
341                    #[cfg(feature = "sqlx-mysql")]
342                    InnerConnection::MySql(conn) => {
343                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
344                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
345                        crate::metric::metric!(self.metric_callback, &stmt, {
346                            query.execute(conn).await.map(Into::into)
347                        })
348                        .map_err(sqlx_error_to_exec_err)
349                    }
350                    #[cfg(feature = "sqlx-postgres")]
351                    InnerConnection::Postgres(conn) => {
352                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
353                        let conn: &mut sqlx::PgConnection = &mut *conn;
354                        crate::metric::metric!(self.metric_callback, &stmt, {
355                            query.execute(conn).await.map(Into::into)
356                        })
357                        .map_err(sqlx_error_to_exec_err)
358                    }
359                    #[cfg(feature = "sqlx-sqlite")]
360                    InnerConnection::Sqlite(conn) => {
361                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
362                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
363                        crate::metric::metric!(self.metric_callback, &stmt, {
364                            query.execute(conn).await.map(Into::into)
365                        })
366                        .map_err(sqlx_error_to_exec_err)
367                    }
368                    #[cfg(feature = "rusqlite")]
369                    InnerConnection::Rusqlite(conn) => conn.execute(stmt, &self.metric_callback),
370                    #[cfg(feature = "mock")]
371                    InnerConnection::Mock(conn) => conn.execute(stmt),
372                    #[cfg(feature = "proxy")]
373                    InnerConnection::Proxy(conn) => conn.execute(stmt).await,
374                    #[allow(unreachable_patterns)]
375                    _ => Err(conn_err("Disconnected")),
376                }
377            }
378        )
379    }
380
381    #[instrument(level = "trace")]
382    #[allow(unused_variables)]
383    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
384        debug_print!("{}", sql);
385
386        super::tracing_spans::with_db_span!(
387            "sea_orm.execute_unprepared",
388            self.backend,
389            sql,
390            record_stmt = false,
391            async {
392                #[cfg(not(feature = "sync"))]
393                let conn = &mut *self.conn.lock().await;
394                #[cfg(feature = "sync")]
395                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
396
397                match conn {
398                    #[cfg(feature = "sqlx-mysql")]
399                    InnerConnection::MySql(conn) => {
400                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
401                        sqlx::Executor::execute(conn, sql)
402                            .await
403                            .map(Into::into)
404                            .map_err(sqlx_error_to_exec_err)
405                    }
406                    #[cfg(feature = "sqlx-postgres")]
407                    InnerConnection::Postgres(conn) => {
408                        let conn: &mut sqlx::PgConnection = &mut *conn;
409                        sqlx::Executor::execute(conn, sql)
410                            .await
411                            .map(Into::into)
412                            .map_err(sqlx_error_to_exec_err)
413                    }
414                    #[cfg(feature = "sqlx-sqlite")]
415                    InnerConnection::Sqlite(conn) => {
416                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
417                        sqlx::Executor::execute(conn, sql)
418                            .await
419                            .map(Into::into)
420                            .map_err(sqlx_error_to_exec_err)
421                    }
422                    #[cfg(feature = "rusqlite")]
423                    InnerConnection::Rusqlite(conn) => conn.execute_unprepared(sql),
424                    #[cfg(feature = "mock")]
425                    InnerConnection::Mock(conn) => {
426                        let db_backend = conn.get_database_backend();
427                        let stmt = Statement::from_string(db_backend, sql);
428                        conn.execute(stmt)
429                    }
430                    #[cfg(feature = "proxy")]
431                    InnerConnection::Proxy(conn) => {
432                        let db_backend = conn.get_database_backend();
433                        let stmt = Statement::from_string(db_backend, sql);
434                        conn.execute(stmt).await
435                    }
436                    #[allow(unreachable_patterns)]
437                    _ => Err(conn_err("Disconnected")),
438                }
439            }
440        )
441    }
442
443    #[instrument(level = "trace")]
444    #[allow(unused_variables)]
445    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
446        debug_print!("{}", stmt);
447
448        super::tracing_spans::with_db_span!(
449            "sea_orm.query_one",
450            self.backend,
451            stmt.sql.as_str(),
452            record_stmt = true,
453            async {
454                #[cfg(not(feature = "sync"))]
455                let conn = &mut *self.conn.lock().await;
456                #[cfg(feature = "sync")]
457                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
458
459                match conn {
460                    #[cfg(feature = "sqlx-mysql")]
461                    InnerConnection::MySql(conn) => {
462                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
463                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
464                        crate::metric::metric!(self.metric_callback, &stmt, {
465                            crate::sqlx_map_err_ignore_not_found(
466                                query.fetch_one(conn).await.map(|row| Some(row.into())),
467                            )
468                        })
469                    }
470                    #[cfg(feature = "sqlx-postgres")]
471                    InnerConnection::Postgres(conn) => {
472                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
473                        let conn: &mut sqlx::PgConnection = &mut *conn;
474                        crate::metric::metric!(self.metric_callback, &stmt, {
475                            crate::sqlx_map_err_ignore_not_found(
476                                query.fetch_one(conn).await.map(|row| Some(row.into())),
477                            )
478                        })
479                    }
480                    #[cfg(feature = "sqlx-sqlite")]
481                    InnerConnection::Sqlite(conn) => {
482                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
483                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
484                        crate::metric::metric!(self.metric_callback, &stmt, {
485                            crate::sqlx_map_err_ignore_not_found(
486                                query.fetch_one(conn).await.map(|row| Some(row.into())),
487                            )
488                        })
489                    }
490                    #[cfg(feature = "rusqlite")]
491                    InnerConnection::Rusqlite(conn) => conn.query_one(stmt, &self.metric_callback),
492                    #[cfg(feature = "mock")]
493                    InnerConnection::Mock(conn) => conn.query_one(stmt),
494                    #[cfg(feature = "proxy")]
495                    InnerConnection::Proxy(conn) => conn.query_one(stmt).await,
496                    #[allow(unreachable_patterns)]
497                    _ => Err(conn_err("Disconnected")),
498                }
499            }
500        )
501    }
502
503    #[instrument(level = "trace")]
504    #[allow(unused_variables)]
505    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
506        debug_print!("{}", stmt);
507
508        super::tracing_spans::with_db_span!(
509            "sea_orm.query_all",
510            self.backend,
511            stmt.sql.as_str(),
512            record_stmt = true,
513            async {
514                #[cfg(not(feature = "sync"))]
515                let conn = &mut *self.conn.lock().await;
516                #[cfg(feature = "sync")]
517                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
518
519                match conn {
520                    #[cfg(feature = "sqlx-mysql")]
521                    InnerConnection::MySql(conn) => {
522                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
523                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
524                        crate::metric::metric!(self.metric_callback, &stmt, {
525                            query
526                                .fetch_all(conn)
527                                .await
528                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
529                                .map_err(sqlx_error_to_query_err)
530                        })
531                    }
532                    #[cfg(feature = "sqlx-postgres")]
533                    InnerConnection::Postgres(conn) => {
534                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
535                        let conn: &mut sqlx::PgConnection = &mut *conn;
536                        crate::metric::metric!(self.metric_callback, &stmt, {
537                            query
538                                .fetch_all(conn)
539                                .await
540                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
541                                .map_err(sqlx_error_to_query_err)
542                        })
543                    }
544                    #[cfg(feature = "sqlx-sqlite")]
545                    InnerConnection::Sqlite(conn) => {
546                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
547                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
548                        crate::metric::metric!(self.metric_callback, &stmt, {
549                            query
550                                .fetch_all(conn)
551                                .await
552                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
553                                .map_err(sqlx_error_to_query_err)
554                        })
555                    }
556                    #[cfg(feature = "rusqlite")]
557                    InnerConnection::Rusqlite(conn) => conn.query_all(stmt, &self.metric_callback),
558                    #[cfg(feature = "mock")]
559                    InnerConnection::Mock(conn) => conn.query_all(stmt),
560                    #[cfg(feature = "proxy")]
561                    InnerConnection::Proxy(conn) => conn.query_all(stmt).await,
562                    #[allow(unreachable_patterns)]
563                    _ => Err(conn_err("Disconnected")),
564                }
565            }
566        )
567    }
568}
569
570impl StreamTrait for DatabaseTransaction {
571    type Stream<'a> = TransactionStream<'a>;
572
573    fn get_database_backend(&self) -> DbBackend {
574        self.backend
575    }
576
577    #[instrument(level = "trace")]
578    fn stream_raw<'a>(
579        &'a self,
580        stmt: Statement,
581    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
582        Box::pin(async move {
583            #[cfg(not(feature = "sync"))]
584            let conn = self.conn.lock().await;
585            #[cfg(feature = "sync")]
586            let conn = self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
587            Ok(crate::TransactionStream::build(
588                conn,
589                stmt,
590                self.metric_callback.clone(),
591            ))
592        })
593    }
594}
595
596#[async_trait::async_trait]
597impl TransactionTrait for DatabaseTransaction {
598    type Transaction = DatabaseTransaction;
599
600    #[instrument(level = "trace")]
601    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
602        DatabaseTransaction::begin(
603            Arc::clone(&self.conn),
604            self.backend,
605            self.metric_callback.clone(),
606            None,
607            None,
608        )
609        .await
610    }
611
612    #[instrument(level = "trace")]
613    async fn begin_with_config(
614        &self,
615        isolation_level: Option<IsolationLevel>,
616        access_mode: Option<AccessMode>,
617    ) -> Result<DatabaseTransaction, DbErr> {
618        DatabaseTransaction::begin(
619            Arc::clone(&self.conn),
620            self.backend,
621            self.metric_callback.clone(),
622            isolation_level,
623            access_mode,
624        )
625        .await
626    }
627
628    /// Execute the async function inside a transaction.
629    /// If the function returns an error, the transaction will be rolled back.
630    /// Otherwise, the transaction will be committed.
631    #[instrument(level = "trace", skip(_callback))]
632    async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
633    where
634        F: for<'c> FnOnce(
635                &'c DatabaseTransaction,
636            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
637            + Send,
638        T: Send,
639        E: std::fmt::Display + std::fmt::Debug + Send,
640    {
641        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
642        transaction.run(_callback).await
643    }
644
645    /// Execute the async function inside a transaction.
646    /// If the function returns an error, the transaction will be rolled back.
647    /// Otherwise, the transaction will be committed.
648    #[instrument(level = "trace", skip(_callback))]
649    async fn transaction_with_config<F, T, E>(
650        &self,
651        _callback: F,
652        isolation_level: Option<IsolationLevel>,
653        access_mode: Option<AccessMode>,
654    ) -> Result<T, TransactionError<E>>
655    where
656        F: for<'c> FnOnce(
657                &'c DatabaseTransaction,
658            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
659            + Send,
660        T: Send,
661        E: std::fmt::Display + std::fmt::Debug + Send,
662    {
663        let transaction = self
664            .begin_with_config(isolation_level, access_mode)
665            .await
666            .map_err(TransactionError::Connection)?;
667        transaction.run(_callback).await
668    }
669}
670
671/// Defines errors for handling transaction failures
672#[derive(Debug)]
673pub enum TransactionError<E> {
674    /// A Database connection error
675    Connection(DbErr),
676    /// An error occurring when doing database transactions
677    Transaction(E),
678}
679
680impl<E> std::fmt::Display for TransactionError<E>
681where
682    E: std::fmt::Display + std::fmt::Debug,
683{
684    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
685        match self {
686            TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
687            TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
688        }
689    }
690}
691
692impl<E> std::error::Error for TransactionError<E> where E: std::fmt::Display + std::fmt::Debug {}
693
694impl<E> From<DbErr> for TransactionError<E>
695where
696    E: std::fmt::Display + std::fmt::Debug,
697{
698    fn from(e: DbErr) -> Self {
699        Self::Connection(e)
700    }
701}