Skip to main content

sea_orm/database/
transaction.rs

1#![allow(unused_assignments)]
2use std::{future::Future, pin::Pin, sync::Arc};
3
4use futures_util::lock::Mutex;
5#[cfg(feature = "sqlx-sqlite")]
6use sqlx_core::sql_str::SqlSafeStr;
7#[cfg(feature = "sqlx-dep")]
8use sqlx_core::transaction::TransactionManager;
9use tracing::instrument;
10
11use crate::{
12    AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, IsolationLevel,
13    QueryResult, SqliteTransactionMode, Statement, TransactionOptions, TransactionSession,
14    TransactionTrait, debug_print, error::*,
15};
16#[cfg(feature = "sqlx-dep")]
17use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
18
19#[cfg(feature = "stream")]
20use crate::{StreamTrait, TransactionStream};
21
/// A live database transaction.
///
/// Obtain one with
/// [`TransactionTrait::begin`](crate::TransactionTrait::begin) (manual
/// commit/rollback) or
/// [`TransactionTrait::transaction`](crate::TransactionTrait::transaction)
/// (auto commit on `Ok`, auto rollback on `Err`). Like
/// [`DatabaseConnection`](crate::DatabaseConnection) it implements
/// [`ConnectionTrait`](crate::ConnectionTrait), so SeaORM's query and
/// mutation methods work against it transparently. Calling `begin` on a
/// transaction starts a nested transaction via `SAVEPOINT`.
///
/// If the handle is dropped while still open (neither `commit` nor
/// `rollback` completed successfully), its `Drop` impl queues a rollback on
/// the underlying connection instead of awaiting one.
pub struct DatabaseTransaction {
    /// The underlying connection. Nested transactions begun from this handle
    /// receive an `Arc::clone` of the same mutex-protected connection.
    conn: Arc<Mutex<InnerConnection>>,
    /// Cached backend so `get_database_backend` does not need to lock `conn`.
    backend: DbBackend,
    /// `true` until `commit` or `rollback` succeeds; read by `start_rollback`
    /// to decide whether a drop-time rollback is needed.
    open: bool,
    /// Optional metrics hook, forwarded to statement execution, nested
    /// transactions and streams.
    metric_callback: Option<crate::metric::Callback>,
    /// Whether statement SQL is passed to tracing spans for execute/query
    /// calls (`record_stmt = self.record_stmt_in_spans`).
    record_stmt_in_spans: bool,
}
40
41#[instrument(level = "trace", skip(transaction, callback))]
42pub(crate) async fn run_async_transaction_callback<Txn, F, T, E>(
43    transaction: Txn,
44    callback: F,
45) -> Result<T, TransactionError<E>>
46where
47    Txn: TransactionSession + Send + Sync,
48    F: for<'b> AsyncFnOnce(&'b Txn) -> Result<T, E> + Send,
49    T: Send,
50    E: std::fmt::Display + std::fmt::Debug + Send,
51{
52    let res = callback(&transaction)
53        .await
54        .map_err(TransactionError::Transaction);
55    if res.is_ok() {
56        transaction
57            .commit()
58            .await
59            .map_err(TransactionError::Connection)?;
60    } else {
61        transaction
62            .rollback()
63            .await
64            .map_err(TransactionError::Connection)?;
65    }
66    res
67}
68
69impl std::fmt::Debug for DatabaseTransaction {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        write!(f, "DatabaseTransaction")
72    }
73}
74
/// Transaction lifecycle: beginning, running a callback, committing, rolling
/// back, and the non-awaiting rollback used by `Drop`.
impl DatabaseTransaction {
    /// Builds a transaction handle over `conn` and issues the driver-specific
    /// `BEGIN`.
    ///
    /// Per-backend handling of the optional settings:
    /// - MySQL: `isolation_level` / `access_mode` are applied *before* `BEGIN`.
    /// - PostgreSQL: they are applied *after* `BEGIN`, inside the transaction.
    /// - sqlx SQLite: they are passed to `set_transaction_config` before
    ///   `BEGIN`. `sqlite_transaction_mode` is only used when the connection's
    ///   transaction depth is 0; nested levels pass `None` to the driver.
    /// - rusqlite: only `sqlite_transaction_mode` is forwarded.
    /// - mock / proxy: all three settings are ignored.
    ///
    /// The handle is created with `open: true` *before* `BEGIN` runs. If this
    /// function returns an error, that handle is dropped and its `Drop` impl
    /// calls `start_rollback` on it.
    ///
    /// # Errors
    /// Returns the driver error from configuring or beginning the transaction,
    /// or a `"Disconnected"` connection error when no enabled driver matches.
    #[instrument(level = "trace", skip(metric_callback))]
    pub(crate) async fn begin(
        conn: Arc<Mutex<InnerConnection>>,
        backend: DbBackend,
        metric_callback: Option<crate::metric::Callback>,
        record_stmt_in_spans: bool,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
        sqlite_transaction_mode: Option<SqliteTransactionMode>,
    ) -> Result<DatabaseTransaction, DbErr> {
        let res = DatabaseTransaction {
            conn,
            backend,
            open: true,
            metric_callback,
            record_stmt_in_spans,
        };

        let begin_result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
            "sea_orm.begin",
            backend,
            "BEGIN",
            record_stmt = false,
            async {
                // The lock guard lives until the end of this async block.
                #[cfg(not(feature = "sync"))]
                let conn = &mut *res.conn.lock().await;
                #[cfg(feature = "sync")]
                let conn = &mut *res.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

                match conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(c) => {
                        // in MySQL SET TRANSACTION operations must be executed before transaction start
                        crate::driver::sqlx_mysql::set_transaction_config(
                            c,
                            isolation_level,
                            access_mode,
                        )
                        .await?;
                        <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c, None)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(c) => {
                        <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c, None)
                            .await
                            .map_err(sqlx_error_to_query_err)?;
                        // in PostgreSQL SET TRANSACTION operations must be executed inside transaction
                        crate::driver::sqlx_postgres::set_transaction_config(
                            c,
                            isolation_level,
                            access_mode,
                        )
                        .await
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(c) => {
                        crate::driver::sqlx_sqlite::set_transaction_config(
                            c,
                            isolation_level,
                            access_mode,
                        )
                        .await?;
                        // Depth 0 means no transaction is active yet on this connection.
                        let depth = <sqlx::Sqlite as sqlx::Database>::TransactionManager::get_transaction_depth(c);
                        let statement = if depth == 0 {
                            sqlite_transaction_mode.map(|mode| {
                                sqlx::AssertSqlSafe(format!("BEGIN {}", mode.sqlite_keyword()))
                                    .into_sql_str()
                            })
                        } else {
                            // Nested transaction uses SAVEPOINT; the mode only applies to the top-level BEGIN
                            None
                        };
                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c, statement)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(c) => c.begin(sqlite_transaction_mode),
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(c) => {
                        c.begin();
                        Ok(())
                    }
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(c) => {
                        c.begin().await;
                        Ok(())
                    }
                    // Reached only when no enabled driver variant matches.
                    #[allow(unreachable_patterns)]
                    _ => Err(conn_err("Disconnected")),
                }
            }
        );

        begin_result?;
        Ok(res)
    }

    /// Runs a transaction to completion passing through the result.
    /// Rolling back the transaction on encountering an error.
    ///
    /// The callback's `Ok` value is returned after a successful commit, and
    /// its `Err` is returned as `TransactionError::Transaction` after a
    /// successful rollback. A commit or rollback failure is returned as
    /// `TransactionError::Connection` instead.
    #[instrument(level = "trace", skip(callback))]
    pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
    where
        F: for<'b> FnOnce(
                &'b DatabaseTransaction,
            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
            + Send,
        T: Send,
        E: std::fmt::Display + std::fmt::Debug + Send,
    {
        let res = callback(&self).await.map_err(TransactionError::Transaction);
        if res.is_ok() {
            self.commit().await.map_err(TransactionError::Connection)?;
        } else {
            self.rollback()
                .await
                .map_err(TransactionError::Connection)?;
        }
        res
    }

    /// Execute the function inside a transaction.
    /// If the function returns an error, the transaction will be rolled back.
    /// Otherwise, the transaction will be committed.
    ///
    /// The inner transaction is begun from `self` and therefore shares this
    /// transaction's connection.
    #[instrument(level = "trace", skip(callback))]
    pub async fn transaction_async<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
    where
        F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
        T: Send,
        E: std::fmt::Display + std::fmt::Debug + Send,
    {
        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
        run_async_transaction_callback(transaction, callback).await
    }

    /// Execute the function inside a transaction with isolation level and/or access mode.
    /// If the function returns an error, the transaction will be rolled back.
    /// Otherwise, the transaction will be committed.
    ///
    /// The inner transaction is begun from `self` via `begin_with_config` and
    /// therefore shares this transaction's connection.
    #[instrument(level = "trace", skip(callback))]
    pub async fn transaction_with_config_async<F, T, E>(
        &self,
        callback: F,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
    ) -> Result<T, TransactionError<E>>
    where
        F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
        T: Send,
        E: std::fmt::Display + std::fmt::Debug + Send,
    {
        let transaction = self
            .begin_with_config(isolation_level, access_mode)
            .await
            .map_err(TransactionError::Connection)?;
        run_async_transaction_callback(transaction, callback).await
    }

    /// Commits the transaction, consuming the handle.
    ///
    /// On success `open` is cleared, so the `Drop` that follows is a no-op.
    /// If the driver returns an error, `open` stays `true`, so dropping
    /// `self` queues a rollback via `start_rollback`.
    #[instrument(level = "trace")]
    #[allow(unreachable_code, unused_mut)]
    pub async fn commit(mut self) -> Result<(), DbErr> {
        let result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
            "sea_orm.commit",
            self.backend,
            "COMMIT",
            record_stmt = false,
            async {
                #[cfg(not(feature = "sync"))]
                let conn = &mut *self.conn.lock().await;
                #[cfg(feature = "sync")]
                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

                match conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(c) => {
                        <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(c) => {
                        <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(c) => {
                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(c) => c.commit(),
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(c) => {
                        c.commit();
                        Ok(())
                    }
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(c) => {
                        c.commit().await;
                        Ok(())
                    }
                    #[allow(unreachable_patterns)]
                    _ => Err(conn_err("Disconnected")),
                }
            }
        );

        result?;
        self.open = false; // read by start_rollback
        Ok(())
    }

    /// Rolls back a transaction explicitly
    ///
    /// On success `open` is cleared, so the `Drop` that follows is a no-op.
    /// If the driver returns an error, `open` stays `true`, so dropping
    /// `self` queues another rollback via `start_rollback`.
    #[instrument(level = "trace")]
    #[allow(unreachable_code, unused_mut)]
    pub async fn rollback(mut self) -> Result<(), DbErr> {
        let result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
            "sea_orm.rollback",
            self.backend,
            "ROLLBACK",
            record_stmt = false,
            async {
                #[cfg(not(feature = "sync"))]
                let conn = &mut *self.conn.lock().await;
                #[cfg(feature = "sync")]
                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

                match conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(c) => {
                        <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(c) => {
                        <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(c) => {
                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
                            .await
                            .map_err(sqlx_error_to_query_err)
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(c) => c.rollback(),
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(c) => {
                        c.rollback();
                        Ok(())
                    }
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(c) => {
                        c.rollback().await;
                        Ok(())
                    }
                    #[allow(unreachable_patterns)]
                    _ => Err(conn_err("Disconnected")),
                }
            }
        );

        result?;
        self.open = false; // read by start_rollback
        Ok(())
    }

    // the rollback is queued and will be performed on next async operation, like returning the connection to the pool
    //
    // Called from `Drop`, so it cannot `.await`: the lock is taken with
    // `try_lock`, and an already-held lock is reported as an error rather
    // than waited on. Does nothing once `open` is false.
    #[instrument(level = "trace")]
    fn start_rollback(&mut self) -> Result<(), DbErr> {
        if self.open {
            if let Some(mut conn) = self.conn.try_lock() {
                match &mut *conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(c) => {
                        <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(c) => {
                        <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(c) => {
                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(c) => {
                        c.start_rollback()?;
                    }
                    // The mock driver has no queued rollback; it rolls back immediately.
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(c) => {
                        c.rollback();
                    }
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(c) => {
                        c.start_rollback();
                    }
                    #[allow(unreachable_patterns)]
                    _ => return Err(conn_err("Disconnected")),
                }
            } else {
                //this should never happen
                return Err(conn_err("Dropping a locked Transaction"));
            }
        }
        Ok(())
    }
}
390
391#[async_trait::async_trait]
392impl TransactionSession for DatabaseTransaction {
393    async fn commit(self) -> Result<(), DbErr> {
394        self.commit().await
395    }
396
397    async fn rollback(self) -> Result<(), DbErr> {
398        self.rollback().await
399    }
400}
401
402impl Drop for DatabaseTransaction {
403    fn drop(&mut self) {
404        self.start_rollback().expect("Fail to rollback transaction");
405    }
406}
407
/// Statement execution inside a transaction. Each async method locks the
/// shared connection for the duration of the call, dispatches on the driver
/// variant, and runs inside a `with_db_span!` tracing span.
#[async_trait::async_trait]
impl ConnectionTrait for DatabaseTransaction {
    fn get_database_backend(&self) -> DbBackend {
        // this way we don't need to lock just to know the backend
        self.backend
    }

    /// Executes a prepared statement inside this transaction.
    ///
    /// The sqlx drivers wrap the call in `metric!`, and rusqlite receives the
    /// metric callback directly. Mock and proxy connections do not report
    /// metrics. Whether the SQL is recorded in the span follows
    /// `record_stmt_in_spans`.
    #[instrument(level = "trace", skip(stmt))]
    #[allow(unused_variables)]
    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
        debug_print!("{}", stmt);

        super::tracing_spans::with_db_span!(
            "sea_orm.execute",
            self.backend,
            stmt.sql.as_str(),
            record_stmt = self.record_stmt_in_spans,
            async {
                #[cfg(not(feature = "sync"))]
                let conn = &mut *self.conn.lock().await;
                #[cfg(feature = "sync")]
                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

                match conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(conn) => {
                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            query.execute(conn).await.map(Into::into)
                        })
                        .map_err(sqlx_error_to_exec_err)
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(conn) => {
                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
                        let conn: &mut sqlx::PgConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            query.execute(conn).await.map(Into::into)
                        })
                        .map_err(sqlx_error_to_exec_err)
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(conn) => {
                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            query.execute(conn).await.map(Into::into)
                        })
                        .map_err(sqlx_error_to_exec_err)
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(conn) => conn.execute(stmt, &self.metric_callback),
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(conn) => conn.execute(stmt),
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(conn) => conn.execute(stmt).await,
                    #[allow(unreachable_patterns)]
                    _ => Err(conn_err("Disconnected")),
                }
            }
        )
    }

    /// Executes raw SQL text, passed through `AssertSqlSafe` for sqlx.
    ///
    /// The span is opened with `record_stmt = false`, and no metric callback
    /// is invoked on any driver. Mock and proxy connections receive the text
    /// wrapped in a `Statement` for their own backend.
    #[instrument(level = "trace", skip(sql))]
    #[allow(unused_variables)]
    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
        debug_print!("{}", sql);

        super::tracing_spans::with_db_span!(
            "sea_orm.execute_unprepared",
            self.backend,
            sql,
            record_stmt = false,
            async {
                #[cfg(not(feature = "sync"))]
                let conn = &mut *self.conn.lock().await;
                #[cfg(feature = "sync")]
                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

                match conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(conn) => {
                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
                        sqlx::Executor::execute(conn, sqlx::AssertSqlSafe(sql.to_owned()))
                            .await
                            .map(Into::into)
                            .map_err(sqlx_error_to_exec_err)
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(conn) => {
                        let conn: &mut sqlx::PgConnection = &mut *conn;
                        sqlx::Executor::execute(conn, sqlx::AssertSqlSafe(sql.to_owned()))
                            .await
                            .map(Into::into)
                            .map_err(sqlx_error_to_exec_err)
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(conn) => {
                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
                        sqlx::Executor::execute(conn, sqlx::AssertSqlSafe(sql.to_owned()))
                            .await
                            .map(Into::into)
                            .map_err(sqlx_error_to_exec_err)
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(conn) => conn.execute_unprepared(sql),
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(conn) => {
                        let db_backend = conn.get_database_backend();
                        let stmt = Statement::from_string(db_backend, sql);
                        conn.execute(stmt)
                    }
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(conn) => {
                        let db_backend = conn.get_database_backend();
                        let stmt = Statement::from_string(db_backend, sql);
                        conn.execute(stmt).await
                    }
                    #[allow(unreachable_patterns)]
                    _ => Err(conn_err("Disconnected")),
                }
            }
        )
    }

    /// Runs the statement and returns at most one row.
    ///
    /// For sqlx drivers the result of `fetch_one` goes through
    /// `crate::sqlx_map_err_ignore_not_found`. NOTE(review): judging by its
    /// name, this maps "row not found" to `Ok(None)`; confirm in the helper.
    #[instrument(level = "trace", skip(stmt))]
    #[allow(unused_variables)]
    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
        debug_print!("{}", stmt);

        super::tracing_spans::with_db_span!(
            "sea_orm.query_one",
            self.backend,
            stmt.sql.as_str(),
            record_stmt = self.record_stmt_in_spans,
            async {
                #[cfg(not(feature = "sync"))]
                let conn = &mut *self.conn.lock().await;
                #[cfg(feature = "sync")]
                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

                match conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(conn) => {
                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            crate::sqlx_map_err_ignore_not_found(
                                query.fetch_one(conn).await.map(|row| Some(row.into())),
                            )
                        })
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(conn) => {
                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
                        let conn: &mut sqlx::PgConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            crate::sqlx_map_err_ignore_not_found(
                                query.fetch_one(conn).await.map(|row| Some(row.into())),
                            )
                        })
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(conn) => {
                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            crate::sqlx_map_err_ignore_not_found(
                                query.fetch_one(conn).await.map(|row| Some(row.into())),
                            )
                        })
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(conn) => conn.query_one(stmt, &self.metric_callback),
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(conn) => conn.query_one(stmt),
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(conn) => conn.query_one(stmt).await,
                    #[allow(unreachable_patterns)]
                    _ => Err(conn_err("Disconnected")),
                }
            }
        )
    }

    /// Runs the statement and collects every returned row.
    #[instrument(level = "trace", skip(stmt))]
    #[allow(unused_variables)]
    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
        debug_print!("{}", stmt);

        super::tracing_spans::with_db_span!(
            "sea_orm.query_all",
            self.backend,
            stmt.sql.as_str(),
            record_stmt = self.record_stmt_in_spans,
            async {
                #[cfg(not(feature = "sync"))]
                let conn = &mut *self.conn.lock().await;
                #[cfg(feature = "sync")]
                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

                match conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(conn) => {
                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            query
                                .fetch_all(conn)
                                .await
                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
                                .map_err(sqlx_error_to_query_err)
                        })
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(conn) => {
                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
                        let conn: &mut sqlx::PgConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            query
                                .fetch_all(conn)
                                .await
                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
                                .map_err(sqlx_error_to_query_err)
                        })
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(conn) => {
                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            query
                                .fetch_all(conn)
                                .await
                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
                                .map_err(sqlx_error_to_query_err)
                        })
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(conn) => conn.query_all(stmt, &self.metric_callback),
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(conn) => conn.query_all(stmt),
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(conn) => conn.query_all(stmt).await,
                    #[allow(unreachable_patterns)]
                    _ => Err(conn_err("Disconnected")),
                }
            }
        )
    }
}
660
/// Streaming queries inside a transaction.
#[cfg(feature = "stream")]
impl StreamTrait for DatabaseTransaction {
    type Stream<'a> = TransactionStream<'a>;

    fn get_database_backend(&self) -> DbBackend {
        self.backend
    }

    /// Locks the shared connection and passes the guard, the statement and a
    /// clone of the metric callback to `TransactionStream::build`.
    /// NOTE(review): the stream presumably keeps the guard, which would keep
    /// the connection locked while it is alive; confirm in
    /// `TransactionStream`.
    #[instrument(level = "trace", skip(stmt))]
    fn stream_raw<'a>(
        &'a self,
        stmt: Statement,
    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
        Box::pin(async move {
            #[cfg(not(feature = "sync"))]
            let guard = self.conn.lock().await;
            #[cfg(feature = "sync")]
            let guard = self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
            let callback = self.metric_callback.clone();
            let stream = crate::TransactionStream::build(guard, stmt, callback);
            Ok(stream)
        })
    }
}
687
688#[async_trait::async_trait]
689impl TransactionTrait for DatabaseTransaction {
690    type Transaction = DatabaseTransaction;
691
692    #[instrument(level = "trace")]
693    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
694        DatabaseTransaction::begin(
695            Arc::clone(&self.conn),
696            self.backend,
697            self.metric_callback.clone(),
698            self.record_stmt_in_spans,
699            None,
700            None,
701            None,
702        )
703        .await
704    }
705
706    #[instrument(level = "trace")]
707    async fn begin_with_config(
708        &self,
709        isolation_level: Option<IsolationLevel>,
710        access_mode: Option<AccessMode>,
711    ) -> Result<DatabaseTransaction, DbErr> {
712        DatabaseTransaction::begin(
713            Arc::clone(&self.conn),
714            self.backend,
715            self.metric_callback.clone(),
716            self.record_stmt_in_spans,
717            isolation_level,
718            access_mode,
719            None,
720        )
721        .await
722    }
723
724    #[instrument(level = "trace")]
725    async fn begin_with_options(
726        &self,
727        options: TransactionOptions,
728    ) -> Result<DatabaseTransaction, DbErr> {
729        DatabaseTransaction::begin(
730            Arc::clone(&self.conn),
731            self.backend,
732            self.metric_callback.clone(),
733            self.record_stmt_in_spans,
734            options.isolation_level,
735            options.access_mode,
736            options.sqlite_transaction_mode,
737        )
738        .await
739    }
740
741    /// Execute the async function inside a transaction.
742    /// If the function returns an error, the transaction will be rolled back.
743    /// Otherwise, the transaction will be committed.
744    #[instrument(level = "trace", skip(_callback))]
745    async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
746    where
747        F: for<'c> FnOnce(
748                &'c DatabaseTransaction,
749            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
750            + Send,
751        T: Send,
752        E: std::fmt::Display + std::fmt::Debug + Send,
753    {
754        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
755        transaction.run(_callback).await
756    }
757
758    /// Execute the async function inside a transaction.
759    /// If the function returns an error, the transaction will be rolled back.
760    /// Otherwise, the transaction will be committed.
761    #[instrument(level = "trace", skip(_callback))]
762    async fn transaction_with_config<F, T, E>(
763        &self,
764        _callback: F,
765        isolation_level: Option<IsolationLevel>,
766        access_mode: Option<AccessMode>,
767    ) -> Result<T, TransactionError<E>>
768    where
769        F: for<'c> FnOnce(
770                &'c DatabaseTransaction,
771            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
772            + Send,
773        T: Send,
774        E: std::fmt::Display + std::fmt::Debug + Send,
775    {
776        let transaction = self
777            .begin_with_config(isolation_level, access_mode)
778            .await
779            .map_err(TransactionError::Connection)?;
780        transaction.run(_callback).await
781    }
782}
783
/// Error returned by [`TransactionTrait::transaction`](crate::TransactionTrait::transaction):
/// either the database itself failed, or the user closure returned an `Err`
/// (causing a rollback).
///
/// `E` is the error type returned by the user closure. If the closure fails
/// *and* the subsequent rollback also fails, the rollback's `DbErr` is
/// returned as `Connection`, and the closure's error is discarded.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// Failure issuing `BEGIN`, `COMMIT`, or `ROLLBACK`.
    Connection(DbErr),
    /// The closure returned an `Err` (the transaction has been rolled back).
    Transaction(E),
}
794
795impl<E> std::fmt::Display for TransactionError<E>
796where
797    E: std::fmt::Display + std::fmt::Debug,
798{
799    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
800        match self {
801            TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
802            TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
803        }
804    }
805}
806
807impl<E> std::error::Error for TransactionError<E> where E: std::fmt::Display + std::fmt::Debug {}
808
809impl<E> From<DbErr> for TransactionError<E>
810where
811    E: std::fmt::Display + std::fmt::Debug,
812{
813    fn from(e: DbErr) -> Self {
814        Self::Connection(e)
815    }
816}