Skip to main content

sea_orm/database/
transaction.rs

1#![allow(unused_assignments)]
2use std::{future::Future, pin::Pin, sync::Arc};
3
4use futures_util::lock::Mutex;
5#[cfg(feature = "sqlx-dep")]
6use sqlx::TransactionManager;
7use tracing::instrument;
8
9use crate::{
10    AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, IsolationLevel,
11    QueryResult, SqliteTransactionMode, Statement, StreamTrait, TransactionOptions,
12    TransactionSession, TransactionStream, TransactionTrait, debug_print, error::*,
13};
14#[cfg(feature = "sqlx-dep")]
15use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
16
/// Defines a database transaction, whether it is an open transaction and the type of
/// backend to use.
/// Under the hood, a Transaction is just a wrapper for a connection where
/// BEGIN has been executed.
///
/// The connection is shared behind an `Arc<Mutex<_>>`. A nested transaction opened
/// through `TransactionTrait::begin` clones that `Arc`, so every level operates on the
/// same underlying connection. Dropping a transaction that is still open queues a
/// rollback instead of awaiting one.
pub struct DatabaseTransaction {
    /// Shared, lock-protected handle to the connection this transaction runs on.
    conn: Arc<Mutex<InnerConnection>>,
    /// Backend kind, cached so it can be reported without taking the lock.
    backend: DbBackend,
    /// Cleared by a successful `commit`/`rollback`; `start_rollback` (run on drop) only
    /// acts while this is `true`.
    open: bool,
    /// Optional callback handed to `metric!` around statement executions.
    metric_callback: Option<crate::metric::Callback>,
}
27
28impl std::fmt::Debug for DatabaseTransaction {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        write!(f, "DatabaseTransaction")
31    }
32}
33
34impl DatabaseTransaction {
35    #[instrument(level = "trace", skip(metric_callback))]
36    pub(crate) async fn begin(
37        conn: Arc<Mutex<InnerConnection>>,
38        backend: DbBackend,
39        metric_callback: Option<crate::metric::Callback>,
40        isolation_level: Option<IsolationLevel>,
41        access_mode: Option<AccessMode>,
42        sqlite_transaction_mode: Option<SqliteTransactionMode>,
43    ) -> Result<DatabaseTransaction, DbErr> {
44        let res = DatabaseTransaction {
45            conn,
46            backend,
47            open: true,
48            metric_callback,
49        };
50
51        let begin_result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
52            "sea_orm.begin",
53            backend,
54            "BEGIN",
55            record_stmt = false,
56            async {
57                #[cfg(not(feature = "sync"))]
58                let conn = &mut *res.conn.lock().await;
59                #[cfg(feature = "sync")]
60                let conn = &mut *res.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
61
62                match conn {
63                    #[cfg(feature = "sqlx-mysql")]
64                    InnerConnection::MySql(c) => {
65                        // in MySQL SET TRANSACTION operations must be executed before transaction start
66                        crate::driver::sqlx_mysql::set_transaction_config(
67                            c,
68                            isolation_level,
69                            access_mode,
70                        )
71                        .await?;
72                        <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c, None)
73                            .await
74                            .map_err(sqlx_error_to_query_err)
75                    }
76                    #[cfg(feature = "sqlx-postgres")]
77                    InnerConnection::Postgres(c) => {
78                        <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c, None)
79                            .await
80                            .map_err(sqlx_error_to_query_err)?;
81                        // in PostgreSQL SET TRANSACTION operations must be executed inside transaction
82                        crate::driver::sqlx_postgres::set_transaction_config(
83                            c,
84                            isolation_level,
85                            access_mode,
86                        )
87                        .await
88                    }
89                    #[cfg(feature = "sqlx-sqlite")]
90                    InnerConnection::Sqlite(c) => {
91                        crate::driver::sqlx_sqlite::set_transaction_config(
92                            c,
93                            isolation_level,
94                            access_mode,
95                        )
96                        .await?;
97                        let depth = <sqlx::Sqlite as sqlx::Database>::TransactionManager::get_transaction_depth(c);
98                        let statement = if depth == 0 {
99                            sqlite_transaction_mode.map(|mode| {
100                                std::borrow::Cow::from(format!("BEGIN {}", mode.sqlite_keyword()))
101                            })
102                        } else {
103                            // Nested transaction uses SAVEPOINT; the mode only applies to the top-level BEGIN
104                            None
105                        };
106                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c, statement)
107                            .await
108                            .map_err(sqlx_error_to_query_err)
109                    }
110                    #[cfg(feature = "rusqlite")]
111                    InnerConnection::Rusqlite(c) => c.begin(sqlite_transaction_mode),
112                    #[cfg(feature = "mock")]
113                    InnerConnection::Mock(c) => {
114                        c.begin();
115                        Ok(())
116                    }
117                    #[cfg(feature = "proxy")]
118                    InnerConnection::Proxy(c) => {
119                        c.begin().await;
120                        Ok(())
121                    }
122                    #[allow(unreachable_patterns)]
123                    _ => Err(conn_err("Disconnected")),
124                }
125            }
126        );
127
128        begin_result?;
129        Ok(res)
130    }
131
132    /// Runs a transaction to completion passing through the result.
133    /// Rolling back the transaction on encountering an error.
134    #[instrument(level = "trace", skip(callback))]
135    pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
136    where
137        F: for<'b> FnOnce(
138                &'b DatabaseTransaction,
139            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
140            + Send,
141        T: Send,
142        E: std::fmt::Display + std::fmt::Debug + Send,
143    {
144        let res = callback(&self).await.map_err(TransactionError::Transaction);
145        if res.is_ok() {
146            self.commit().await.map_err(TransactionError::Connection)?;
147        } else {
148            self.rollback()
149                .await
150                .map_err(TransactionError::Connection)?;
151        }
152        res
153    }
154
155    /// Commit a transaction
156    #[instrument(level = "trace")]
157    #[allow(unreachable_code, unused_mut)]
158    pub async fn commit(mut self) -> Result<(), DbErr> {
159        let result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
160            "sea_orm.commit",
161            self.backend,
162            "COMMIT",
163            record_stmt = false,
164            async {
165                #[cfg(not(feature = "sync"))]
166                let conn = &mut *self.conn.lock().await;
167                #[cfg(feature = "sync")]
168                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
169
170                match conn {
171                    #[cfg(feature = "sqlx-mysql")]
172                    InnerConnection::MySql(c) => {
173                        <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
174                            .await
175                            .map_err(sqlx_error_to_query_err)
176                    }
177                    #[cfg(feature = "sqlx-postgres")]
178                    InnerConnection::Postgres(c) => {
179                        <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
180                            .await
181                            .map_err(sqlx_error_to_query_err)
182                    }
183                    #[cfg(feature = "sqlx-sqlite")]
184                    InnerConnection::Sqlite(c) => {
185                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
186                            .await
187                            .map_err(sqlx_error_to_query_err)
188                    }
189                    #[cfg(feature = "rusqlite")]
190                    InnerConnection::Rusqlite(c) => c.commit(),
191                    #[cfg(feature = "mock")]
192                    InnerConnection::Mock(c) => {
193                        c.commit();
194                        Ok(())
195                    }
196                    #[cfg(feature = "proxy")]
197                    InnerConnection::Proxy(c) => {
198                        c.commit().await;
199                        Ok(())
200                    }
201                    #[allow(unreachable_patterns)]
202                    _ => Err(conn_err("Disconnected")),
203                }
204            }
205        );
206
207        result?;
208        self.open = false; // read by start_rollback
209        Ok(())
210    }
211
212    /// Rolls back a transaction explicitly
213    #[instrument(level = "trace")]
214    #[allow(unreachable_code, unused_mut)]
215    pub async fn rollback(mut self) -> Result<(), DbErr> {
216        let result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
217            "sea_orm.rollback",
218            self.backend,
219            "ROLLBACK",
220            record_stmt = false,
221            async {
222                #[cfg(not(feature = "sync"))]
223                let conn = &mut *self.conn.lock().await;
224                #[cfg(feature = "sync")]
225                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
226
227                match conn {
228                    #[cfg(feature = "sqlx-mysql")]
229                    InnerConnection::MySql(c) => {
230                        <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
231                            .await
232                            .map_err(sqlx_error_to_query_err)
233                    }
234                    #[cfg(feature = "sqlx-postgres")]
235                    InnerConnection::Postgres(c) => {
236                        <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
237                            .await
238                            .map_err(sqlx_error_to_query_err)
239                    }
240                    #[cfg(feature = "sqlx-sqlite")]
241                    InnerConnection::Sqlite(c) => {
242                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
243                            .await
244                            .map_err(sqlx_error_to_query_err)
245                    }
246                    #[cfg(feature = "rusqlite")]
247                    InnerConnection::Rusqlite(c) => c.rollback(),
248                    #[cfg(feature = "mock")]
249                    InnerConnection::Mock(c) => {
250                        c.rollback();
251                        Ok(())
252                    }
253                    #[cfg(feature = "proxy")]
254                    InnerConnection::Proxy(c) => {
255                        c.rollback().await;
256                        Ok(())
257                    }
258                    #[allow(unreachable_patterns)]
259                    _ => Err(conn_err("Disconnected")),
260                }
261            }
262        );
263
264        result?;
265        self.open = false; // read by start_rollback
266        Ok(())
267    }
268
269    // the rollback is queued and will be performed on next async operation, like returning the connection to the pool
270    #[instrument(level = "trace")]
271    fn start_rollback(&mut self) -> Result<(), DbErr> {
272        if self.open {
273            if let Some(mut conn) = self.conn.try_lock() {
274                match &mut *conn {
275                    #[cfg(feature = "sqlx-mysql")]
276                    InnerConnection::MySql(c) => {
277                        <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
278                    }
279                    #[cfg(feature = "sqlx-postgres")]
280                    InnerConnection::Postgres(c) => {
281                        <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
282                    }
283                    #[cfg(feature = "sqlx-sqlite")]
284                    InnerConnection::Sqlite(c) => {
285                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
286                    }
287                    #[cfg(feature = "rusqlite")]
288                    InnerConnection::Rusqlite(c) => {
289                        c.start_rollback()?;
290                    }
291                    #[cfg(feature = "mock")]
292                    InnerConnection::Mock(c) => {
293                        c.rollback();
294                    }
295                    #[cfg(feature = "proxy")]
296                    InnerConnection::Proxy(c) => {
297                        c.start_rollback();
298                    }
299                    #[allow(unreachable_patterns)]
300                    _ => return Err(conn_err("Disconnected")),
301                }
302            } else {
303                //this should never happen
304                return Err(conn_err("Dropping a locked Transaction"));
305            }
306        }
307        Ok(())
308    }
309}
310
311#[async_trait::async_trait]
312impl TransactionSession for DatabaseTransaction {
313    async fn commit(self) -> Result<(), DbErr> {
314        self.commit().await
315    }
316
317    async fn rollback(self) -> Result<(), DbErr> {
318        self.rollback().await
319    }
320}
321
322impl Drop for DatabaseTransaction {
323    fn drop(&mut self) {
324        self.start_rollback().expect("Fail to rollback transaction");
325    }
326}
327
328#[async_trait::async_trait]
329impl ConnectionTrait for DatabaseTransaction {
330    fn get_database_backend(&self) -> DbBackend {
331        // this way we don't need to lock just to know the backend
332        self.backend
333    }
334
335    #[instrument(level = "trace")]
336    #[allow(unused_variables)]
337    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
338        debug_print!("{}", stmt);
339
340        super::tracing_spans::with_db_span!(
341            "sea_orm.execute",
342            self.backend,
343            stmt.sql.as_str(),
344            record_stmt = true,
345            async {
346                #[cfg(not(feature = "sync"))]
347                let conn = &mut *self.conn.lock().await;
348                #[cfg(feature = "sync")]
349                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
350
351                match conn {
352                    #[cfg(feature = "sqlx-mysql")]
353                    InnerConnection::MySql(conn) => {
354                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
355                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
356                        crate::metric::metric!(self.metric_callback, &stmt, {
357                            query.execute(conn).await.map(Into::into)
358                        })
359                        .map_err(sqlx_error_to_exec_err)
360                    }
361                    #[cfg(feature = "sqlx-postgres")]
362                    InnerConnection::Postgres(conn) => {
363                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
364                        let conn: &mut sqlx::PgConnection = &mut *conn;
365                        crate::metric::metric!(self.metric_callback, &stmt, {
366                            query.execute(conn).await.map(Into::into)
367                        })
368                        .map_err(sqlx_error_to_exec_err)
369                    }
370                    #[cfg(feature = "sqlx-sqlite")]
371                    InnerConnection::Sqlite(conn) => {
372                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
373                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
374                        crate::metric::metric!(self.metric_callback, &stmt, {
375                            query.execute(conn).await.map(Into::into)
376                        })
377                        .map_err(sqlx_error_to_exec_err)
378                    }
379                    #[cfg(feature = "rusqlite")]
380                    InnerConnection::Rusqlite(conn) => conn.execute(stmt, &self.metric_callback),
381                    #[cfg(feature = "mock")]
382                    InnerConnection::Mock(conn) => conn.execute(stmt),
383                    #[cfg(feature = "proxy")]
384                    InnerConnection::Proxy(conn) => conn.execute(stmt).await,
385                    #[allow(unreachable_patterns)]
386                    _ => Err(conn_err("Disconnected")),
387                }
388            }
389        )
390    }
391
392    #[instrument(level = "trace")]
393    #[allow(unused_variables)]
394    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
395        debug_print!("{}", sql);
396
397        super::tracing_spans::with_db_span!(
398            "sea_orm.execute_unprepared",
399            self.backend,
400            sql,
401            record_stmt = false,
402            async {
403                #[cfg(not(feature = "sync"))]
404                let conn = &mut *self.conn.lock().await;
405                #[cfg(feature = "sync")]
406                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
407
408                match conn {
409                    #[cfg(feature = "sqlx-mysql")]
410                    InnerConnection::MySql(conn) => {
411                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
412                        sqlx::Executor::execute(conn, sql)
413                            .await
414                            .map(Into::into)
415                            .map_err(sqlx_error_to_exec_err)
416                    }
417                    #[cfg(feature = "sqlx-postgres")]
418                    InnerConnection::Postgres(conn) => {
419                        let conn: &mut sqlx::PgConnection = &mut *conn;
420                        sqlx::Executor::execute(conn, sql)
421                            .await
422                            .map(Into::into)
423                            .map_err(sqlx_error_to_exec_err)
424                    }
425                    #[cfg(feature = "sqlx-sqlite")]
426                    InnerConnection::Sqlite(conn) => {
427                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
428                        sqlx::Executor::execute(conn, sql)
429                            .await
430                            .map(Into::into)
431                            .map_err(sqlx_error_to_exec_err)
432                    }
433                    #[cfg(feature = "rusqlite")]
434                    InnerConnection::Rusqlite(conn) => conn.execute_unprepared(sql),
435                    #[cfg(feature = "mock")]
436                    InnerConnection::Mock(conn) => {
437                        let db_backend = conn.get_database_backend();
438                        let stmt = Statement::from_string(db_backend, sql);
439                        conn.execute(stmt)
440                    }
441                    #[cfg(feature = "proxy")]
442                    InnerConnection::Proxy(conn) => {
443                        let db_backend = conn.get_database_backend();
444                        let stmt = Statement::from_string(db_backend, sql);
445                        conn.execute(stmt).await
446                    }
447                    #[allow(unreachable_patterns)]
448                    _ => Err(conn_err("Disconnected")),
449                }
450            }
451        )
452    }
453
454    #[instrument(level = "trace")]
455    #[allow(unused_variables)]
456    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
457        debug_print!("{}", stmt);
458
459        super::tracing_spans::with_db_span!(
460            "sea_orm.query_one",
461            self.backend,
462            stmt.sql.as_str(),
463            record_stmt = true,
464            async {
465                #[cfg(not(feature = "sync"))]
466                let conn = &mut *self.conn.lock().await;
467                #[cfg(feature = "sync")]
468                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
469
470                match conn {
471                    #[cfg(feature = "sqlx-mysql")]
472                    InnerConnection::MySql(conn) => {
473                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
474                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
475                        crate::metric::metric!(self.metric_callback, &stmt, {
476                            crate::sqlx_map_err_ignore_not_found(
477                                query.fetch_one(conn).await.map(|row| Some(row.into())),
478                            )
479                        })
480                    }
481                    #[cfg(feature = "sqlx-postgres")]
482                    InnerConnection::Postgres(conn) => {
483                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
484                        let conn: &mut sqlx::PgConnection = &mut *conn;
485                        crate::metric::metric!(self.metric_callback, &stmt, {
486                            crate::sqlx_map_err_ignore_not_found(
487                                query.fetch_one(conn).await.map(|row| Some(row.into())),
488                            )
489                        })
490                    }
491                    #[cfg(feature = "sqlx-sqlite")]
492                    InnerConnection::Sqlite(conn) => {
493                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
494                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
495                        crate::metric::metric!(self.metric_callback, &stmt, {
496                            crate::sqlx_map_err_ignore_not_found(
497                                query.fetch_one(conn).await.map(|row| Some(row.into())),
498                            )
499                        })
500                    }
501                    #[cfg(feature = "rusqlite")]
502                    InnerConnection::Rusqlite(conn) => conn.query_one(stmt, &self.metric_callback),
503                    #[cfg(feature = "mock")]
504                    InnerConnection::Mock(conn) => conn.query_one(stmt),
505                    #[cfg(feature = "proxy")]
506                    InnerConnection::Proxy(conn) => conn.query_one(stmt).await,
507                    #[allow(unreachable_patterns)]
508                    _ => Err(conn_err("Disconnected")),
509                }
510            }
511        )
512    }
513
514    #[instrument(level = "trace")]
515    #[allow(unused_variables)]
516    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
517        debug_print!("{}", stmt);
518
519        super::tracing_spans::with_db_span!(
520            "sea_orm.query_all",
521            self.backend,
522            stmt.sql.as_str(),
523            record_stmt = true,
524            async {
525                #[cfg(not(feature = "sync"))]
526                let conn = &mut *self.conn.lock().await;
527                #[cfg(feature = "sync")]
528                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
529
530                match conn {
531                    #[cfg(feature = "sqlx-mysql")]
532                    InnerConnection::MySql(conn) => {
533                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
534                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
535                        crate::metric::metric!(self.metric_callback, &stmt, {
536                            query
537                                .fetch_all(conn)
538                                .await
539                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
540                                .map_err(sqlx_error_to_query_err)
541                        })
542                    }
543                    #[cfg(feature = "sqlx-postgres")]
544                    InnerConnection::Postgres(conn) => {
545                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
546                        let conn: &mut sqlx::PgConnection = &mut *conn;
547                        crate::metric::metric!(self.metric_callback, &stmt, {
548                            query
549                                .fetch_all(conn)
550                                .await
551                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
552                                .map_err(sqlx_error_to_query_err)
553                        })
554                    }
555                    #[cfg(feature = "sqlx-sqlite")]
556                    InnerConnection::Sqlite(conn) => {
557                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
558                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
559                        crate::metric::metric!(self.metric_callback, &stmt, {
560                            query
561                                .fetch_all(conn)
562                                .await
563                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
564                                .map_err(sqlx_error_to_query_err)
565                        })
566                    }
567                    #[cfg(feature = "rusqlite")]
568                    InnerConnection::Rusqlite(conn) => conn.query_all(stmt, &self.metric_callback),
569                    #[cfg(feature = "mock")]
570                    InnerConnection::Mock(conn) => conn.query_all(stmt),
571                    #[cfg(feature = "proxy")]
572                    InnerConnection::Proxy(conn) => conn.query_all(stmt).await,
573                    #[allow(unreachable_patterns)]
574                    _ => Err(conn_err("Disconnected")),
575                }
576            }
577        )
578    }
579}
580
581impl StreamTrait for DatabaseTransaction {
582    type Stream<'a> = TransactionStream<'a>;
583
584    fn get_database_backend(&self) -> DbBackend {
585        self.backend
586    }
587
588    #[instrument(level = "trace")]
589    fn stream_raw<'a>(
590        &'a self,
591        stmt: Statement,
592    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
593        Box::pin(async move {
594            #[cfg(not(feature = "sync"))]
595            let conn = self.conn.lock().await;
596            #[cfg(feature = "sync")]
597            let conn = self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
598            Ok(crate::TransactionStream::build(
599                conn,
600                stmt,
601                self.metric_callback.clone(),
602            ))
603        })
604    }
605}
606
607#[async_trait::async_trait]
608impl TransactionTrait for DatabaseTransaction {
609    type Transaction = DatabaseTransaction;
610
611    #[instrument(level = "trace")]
612    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
613        DatabaseTransaction::begin(
614            Arc::clone(&self.conn),
615            self.backend,
616            self.metric_callback.clone(),
617            None,
618            None,
619            None,
620        )
621        .await
622    }
623
624    #[instrument(level = "trace")]
625    async fn begin_with_config(
626        &self,
627        isolation_level: Option<IsolationLevel>,
628        access_mode: Option<AccessMode>,
629    ) -> Result<DatabaseTransaction, DbErr> {
630        DatabaseTransaction::begin(
631            Arc::clone(&self.conn),
632            self.backend,
633            self.metric_callback.clone(),
634            isolation_level,
635            access_mode,
636            None,
637        )
638        .await
639    }
640
641    #[instrument(level = "trace")]
642    async fn begin_with_options(
643        &self,
644        options: TransactionOptions,
645    ) -> Result<DatabaseTransaction, DbErr> {
646        DatabaseTransaction::begin(
647            Arc::clone(&self.conn),
648            self.backend,
649            self.metric_callback.clone(),
650            options.isolation_level,
651            options.access_mode,
652            options.sqlite_transaction_mode,
653        )
654        .await
655    }
656
657    /// Execute the async function inside a transaction.
658    /// If the function returns an error, the transaction will be rolled back.
659    /// Otherwise, the transaction will be committed.
660    #[instrument(level = "trace", skip(_callback))]
661    async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
662    where
663        F: for<'c> FnOnce(
664                &'c DatabaseTransaction,
665            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
666            + Send,
667        T: Send,
668        E: std::fmt::Display + std::fmt::Debug + Send,
669    {
670        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
671        transaction.run(_callback).await
672    }
673
674    /// Execute the async function inside a transaction.
675    /// If the function returns an error, the transaction will be rolled back.
676    /// Otherwise, the transaction will be committed.
677    #[instrument(level = "trace", skip(_callback))]
678    async fn transaction_with_config<F, T, E>(
679        &self,
680        _callback: F,
681        isolation_level: Option<IsolationLevel>,
682        access_mode: Option<AccessMode>,
683    ) -> Result<T, TransactionError<E>>
684    where
685        F: for<'c> FnOnce(
686                &'c DatabaseTransaction,
687            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
688            + Send,
689        T: Send,
690        E: std::fmt::Display + std::fmt::Debug + Send,
691    {
692        let transaction = self
693            .begin_with_config(isolation_level, access_mode)
694            .await
695            .map_err(TransactionError::Connection)?;
696        transaction.run(_callback).await
697    }
698}
699
/// Defines errors for handling transaction failures
///
/// Returned by the `transaction` helpers. `Connection` carries a failure of the
/// database layer itself (beginning, committing or rolling back). `Transaction`
/// carries the error returned by the user-supplied callback.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// A Database connection error (begin, commit or rollback failed)
    Connection(DbErr),
    /// The error returned by the transaction callback
    Transaction(E),
}
708
709impl<E> std::fmt::Display for TransactionError<E>
710where
711    E: std::fmt::Display + std::fmt::Debug,
712{
713    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
714        match self {
715            TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
716            TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
717        }
718    }
719}
720
721impl<E> std::error::Error for TransactionError<E> where E: std::fmt::Display + std::fmt::Debug {}
722
723impl<E> From<DbErr> for TransactionError<E>
724where
725    E: std::fmt::Display + std::fmt::Debug,
726{
727    fn from(e: DbErr) -> Self {
728        Self::Connection(e)
729    }
730}