Skip to main content

sea_orm/database/
transaction.rs

1#![allow(unused_assignments)]
2use std::{future::Future, pin::Pin, sync::Arc};
3
4use futures_util::lock::Mutex;
5#[cfg(feature = "sqlx-sqlite")]
6use sqlx_core::sql_str::SqlSafeStr;
7#[cfg(feature = "sqlx-dep")]
8use sqlx_core::transaction::TransactionManager;
9use tracing::instrument;
10
11use crate::{
12    AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, IsolationLevel,
13    QueryResult, SqliteTransactionMode, Statement, StreamTrait, TransactionOptions,
14    TransactionSession, TransactionStream, TransactionTrait, debug_print, error::*,
15};
16#[cfg(feature = "sqlx-dep")]
17use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
18
19/// Defines a database transaction, whether it is an open transaction and the type of
20/// backend to use.
21/// Under the hood, a Transaction is just a wrapper for a connection where
22/// START TRANSACTION has been executed.
23pub struct DatabaseTransaction {
24    conn: Arc<Mutex<InnerConnection>>,
25    backend: DbBackend,
26    open: bool,
27    metric_callback: Option<crate::metric::Callback>,
28    record_stmt_in_spans: bool,
29}
30
31#[instrument(level = "trace", skip(transaction, callback))]
32pub(crate) async fn run_async_transaction_callback<Txn, F, T, E>(
33    transaction: Txn,
34    callback: F,
35) -> Result<T, TransactionError<E>>
36where
37    Txn: TransactionSession + Send + Sync,
38    F: for<'b> AsyncFnOnce(&'b Txn) -> Result<T, E> + Send,
39    T: Send,
40    E: std::fmt::Display + std::fmt::Debug + Send,
41{
42    let res = callback(&transaction)
43        .await
44        .map_err(TransactionError::Transaction);
45    if res.is_ok() {
46        transaction
47            .commit()
48            .await
49            .map_err(TransactionError::Connection)?;
50    } else {
51        transaction
52            .rollback()
53            .await
54            .map_err(TransactionError::Connection)?;
55    }
56    res
57}
58
59impl std::fmt::Debug for DatabaseTransaction {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        write!(f, "DatabaseTransaction")
62    }
63}
64
65impl DatabaseTransaction {
66    #[instrument(level = "trace", skip(metric_callback))]
67    pub(crate) async fn begin(
68        conn: Arc<Mutex<InnerConnection>>,
69        backend: DbBackend,
70        metric_callback: Option<crate::metric::Callback>,
71        record_stmt_in_spans: bool,
72        isolation_level: Option<IsolationLevel>,
73        access_mode: Option<AccessMode>,
74        sqlite_transaction_mode: Option<SqliteTransactionMode>,
75    ) -> Result<DatabaseTransaction, DbErr> {
76        let res = DatabaseTransaction {
77            conn,
78            backend,
79            open: true,
80            metric_callback,
81            record_stmt_in_spans,
82        };
83
84        let begin_result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
85            "sea_orm.begin",
86            backend,
87            "BEGIN",
88            record_stmt = false,
89            async {
90                #[cfg(not(feature = "sync"))]
91                let conn = &mut *res.conn.lock().await;
92                #[cfg(feature = "sync")]
93                let conn = &mut *res.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
94
95                match conn {
96                    #[cfg(feature = "sqlx-mysql")]
97                    InnerConnection::MySql(c) => {
98                        // in MySQL SET TRANSACTION operations must be executed before transaction start
99                        crate::driver::sqlx_mysql::set_transaction_config(
100                            c,
101                            isolation_level,
102                            access_mode,
103                        )
104                        .await?;
105                        <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c, None)
106                            .await
107                            .map_err(sqlx_error_to_query_err)
108                    }
109                    #[cfg(feature = "sqlx-postgres")]
110                    InnerConnection::Postgres(c) => {
111                        <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c, None)
112                            .await
113                            .map_err(sqlx_error_to_query_err)?;
114                        // in PostgreSQL SET TRANSACTION operations must be executed inside transaction
115                        crate::driver::sqlx_postgres::set_transaction_config(
116                            c,
117                            isolation_level,
118                            access_mode,
119                        )
120                        .await
121                    }
122                    #[cfg(feature = "sqlx-sqlite")]
123                    InnerConnection::Sqlite(c) => {
124                        crate::driver::sqlx_sqlite::set_transaction_config(
125                            c,
126                            isolation_level,
127                            access_mode,
128                        )
129                        .await?;
130                        let depth = <sqlx::Sqlite as sqlx::Database>::TransactionManager::get_transaction_depth(c);
131                        let statement = if depth == 0 {
132                            sqlite_transaction_mode.map(|mode| {
133                                sqlx::AssertSqlSafe(format!("BEGIN {}", mode.sqlite_keyword()))
134                                    .into_sql_str()
135                            })
136                        } else {
137                            // Nested transaction uses SAVEPOINT; the mode only applies to the top-level BEGIN
138                            None
139                        };
140                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c, statement)
141                            .await
142                            .map_err(sqlx_error_to_query_err)
143                    }
144                    #[cfg(feature = "rusqlite")]
145                    InnerConnection::Rusqlite(c) => c.begin(sqlite_transaction_mode),
146                    #[cfg(feature = "mock")]
147                    InnerConnection::Mock(c) => {
148                        c.begin();
149                        Ok(())
150                    }
151                    #[cfg(feature = "proxy")]
152                    InnerConnection::Proxy(c) => {
153                        c.begin().await;
154                        Ok(())
155                    }
156                    #[allow(unreachable_patterns)]
157                    _ => Err(conn_err("Disconnected")),
158                }
159            }
160        );
161
162        begin_result?;
163        Ok(res)
164    }
165
166    /// Runs a transaction to completion passing through the result.
167    /// Rolling back the transaction on encountering an error.
168    #[instrument(level = "trace", skip(callback))]
169    pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
170    where
171        F: for<'b> FnOnce(
172                &'b DatabaseTransaction,
173            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
174            + Send,
175        T: Send,
176        E: std::fmt::Display + std::fmt::Debug + Send,
177    {
178        let res = callback(&self).await.map_err(TransactionError::Transaction);
179        if res.is_ok() {
180            self.commit().await.map_err(TransactionError::Connection)?;
181        } else {
182            self.rollback()
183                .await
184                .map_err(TransactionError::Connection)?;
185        }
186        res
187    }
188
189    /// Execute the function inside a transaction.
190    /// If the function returns an error, the transaction will be rolled back.
191    /// Otherwise, the transaction will be committed.
192    #[instrument(level = "trace", skip(callback))]
193    pub async fn transaction_async<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
194    where
195        F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
196        T: Send,
197        E: std::fmt::Display + std::fmt::Debug + Send,
198    {
199        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
200        run_async_transaction_callback(transaction, callback).await
201    }
202
203    /// Execute the function inside a transaction with isolation level and/or access mode.
204    /// If the function returns an error, the transaction will be rolled back.
205    /// Otherwise, the transaction will be committed.
206    #[instrument(level = "trace", skip(callback))]
207    pub async fn transaction_with_config_async<F, T, E>(
208        &self,
209        callback: F,
210        isolation_level: Option<IsolationLevel>,
211        access_mode: Option<AccessMode>,
212    ) -> Result<T, TransactionError<E>>
213    where
214        F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
215        T: Send,
216        E: std::fmt::Display + std::fmt::Debug + Send,
217    {
218        let transaction = self
219            .begin_with_config(isolation_level, access_mode)
220            .await
221            .map_err(TransactionError::Connection)?;
222        run_async_transaction_callback(transaction, callback).await
223    }
224
225    /// Commit a transaction
226    #[instrument(level = "trace")]
227    #[allow(unreachable_code, unused_mut)]
228    pub async fn commit(mut self) -> Result<(), DbErr> {
229        let result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
230            "sea_orm.commit",
231            self.backend,
232            "COMMIT",
233            record_stmt = false,
234            async {
235                #[cfg(not(feature = "sync"))]
236                let conn = &mut *self.conn.lock().await;
237                #[cfg(feature = "sync")]
238                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
239
240                match conn {
241                    #[cfg(feature = "sqlx-mysql")]
242                    InnerConnection::MySql(c) => {
243                        <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
244                            .await
245                            .map_err(sqlx_error_to_query_err)
246                    }
247                    #[cfg(feature = "sqlx-postgres")]
248                    InnerConnection::Postgres(c) => {
249                        <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
250                            .await
251                            .map_err(sqlx_error_to_query_err)
252                    }
253                    #[cfg(feature = "sqlx-sqlite")]
254                    InnerConnection::Sqlite(c) => {
255                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
256                            .await
257                            .map_err(sqlx_error_to_query_err)
258                    }
259                    #[cfg(feature = "rusqlite")]
260                    InnerConnection::Rusqlite(c) => c.commit(),
261                    #[cfg(feature = "mock")]
262                    InnerConnection::Mock(c) => {
263                        c.commit();
264                        Ok(())
265                    }
266                    #[cfg(feature = "proxy")]
267                    InnerConnection::Proxy(c) => {
268                        c.commit().await;
269                        Ok(())
270                    }
271                    #[allow(unreachable_patterns)]
272                    _ => Err(conn_err("Disconnected")),
273                }
274            }
275        );
276
277        result?;
278        self.open = false; // read by start_rollback
279        Ok(())
280    }
281
282    /// Rolls back a transaction explicitly
283    #[instrument(level = "trace")]
284    #[allow(unreachable_code, unused_mut)]
285    pub async fn rollback(mut self) -> Result<(), DbErr> {
286        let result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
287            "sea_orm.rollback",
288            self.backend,
289            "ROLLBACK",
290            record_stmt = false,
291            async {
292                #[cfg(not(feature = "sync"))]
293                let conn = &mut *self.conn.lock().await;
294                #[cfg(feature = "sync")]
295                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
296
297                match conn {
298                    #[cfg(feature = "sqlx-mysql")]
299                    InnerConnection::MySql(c) => {
300                        <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
301                            .await
302                            .map_err(sqlx_error_to_query_err)
303                    }
304                    #[cfg(feature = "sqlx-postgres")]
305                    InnerConnection::Postgres(c) => {
306                        <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
307                            .await
308                            .map_err(sqlx_error_to_query_err)
309                    }
310                    #[cfg(feature = "sqlx-sqlite")]
311                    InnerConnection::Sqlite(c) => {
312                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
313                            .await
314                            .map_err(sqlx_error_to_query_err)
315                    }
316                    #[cfg(feature = "rusqlite")]
317                    InnerConnection::Rusqlite(c) => c.rollback(),
318                    #[cfg(feature = "mock")]
319                    InnerConnection::Mock(c) => {
320                        c.rollback();
321                        Ok(())
322                    }
323                    #[cfg(feature = "proxy")]
324                    InnerConnection::Proxy(c) => {
325                        c.rollback().await;
326                        Ok(())
327                    }
328                    #[allow(unreachable_patterns)]
329                    _ => Err(conn_err("Disconnected")),
330                }
331            }
332        );
333
334        result?;
335        self.open = false; // read by start_rollback
336        Ok(())
337    }
338
339    // the rollback is queued and will be performed on next async operation, like returning the connection to the pool
340    #[instrument(level = "trace")]
341    fn start_rollback(&mut self) -> Result<(), DbErr> {
342        if self.open {
343            if let Some(mut conn) = self.conn.try_lock() {
344                match &mut *conn {
345                    #[cfg(feature = "sqlx-mysql")]
346                    InnerConnection::MySql(c) => {
347                        <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
348                    }
349                    #[cfg(feature = "sqlx-postgres")]
350                    InnerConnection::Postgres(c) => {
351                        <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
352                    }
353                    #[cfg(feature = "sqlx-sqlite")]
354                    InnerConnection::Sqlite(c) => {
355                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
356                    }
357                    #[cfg(feature = "rusqlite")]
358                    InnerConnection::Rusqlite(c) => {
359                        c.start_rollback()?;
360                    }
361                    #[cfg(feature = "mock")]
362                    InnerConnection::Mock(c) => {
363                        c.rollback();
364                    }
365                    #[cfg(feature = "proxy")]
366                    InnerConnection::Proxy(c) => {
367                        c.start_rollback();
368                    }
369                    #[allow(unreachable_patterns)]
370                    _ => return Err(conn_err("Disconnected")),
371                }
372            } else {
373                //this should never happen
374                return Err(conn_err("Dropping a locked Transaction"));
375            }
376        }
377        Ok(())
378    }
379}
380
381#[async_trait::async_trait]
382impl TransactionSession for DatabaseTransaction {
383    async fn commit(self) -> Result<(), DbErr> {
384        self.commit().await
385    }
386
387    async fn rollback(self) -> Result<(), DbErr> {
388        self.rollback().await
389    }
390}
391
392impl Drop for DatabaseTransaction {
393    fn drop(&mut self) {
394        self.start_rollback().expect("Fail to rollback transaction");
395    }
396}
397
398#[async_trait::async_trait]
399impl ConnectionTrait for DatabaseTransaction {
400    fn get_database_backend(&self) -> DbBackend {
401        // this way we don't need to lock just to know the backend
402        self.backend
403    }
404
405    #[instrument(level = "trace")]
406    #[allow(unused_variables)]
407    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
408        debug_print!("{}", stmt);
409
410        super::tracing_spans::with_db_span!(
411            "sea_orm.execute",
412            self.backend,
413            stmt.sql.as_str(),
414            record_stmt = self.record_stmt_in_spans,
415            async {
416                #[cfg(not(feature = "sync"))]
417                let conn = &mut *self.conn.lock().await;
418                #[cfg(feature = "sync")]
419                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
420
421                match conn {
422                    #[cfg(feature = "sqlx-mysql")]
423                    InnerConnection::MySql(conn) => {
424                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
425                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
426                        crate::metric::metric!(self.metric_callback, &stmt, {
427                            query.execute(conn).await.map(Into::into)
428                        })
429                        .map_err(sqlx_error_to_exec_err)
430                    }
431                    #[cfg(feature = "sqlx-postgres")]
432                    InnerConnection::Postgres(conn) => {
433                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
434                        let conn: &mut sqlx::PgConnection = &mut *conn;
435                        crate::metric::metric!(self.metric_callback, &stmt, {
436                            query.execute(conn).await.map(Into::into)
437                        })
438                        .map_err(sqlx_error_to_exec_err)
439                    }
440                    #[cfg(feature = "sqlx-sqlite")]
441                    InnerConnection::Sqlite(conn) => {
442                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
443                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
444                        crate::metric::metric!(self.metric_callback, &stmt, {
445                            query.execute(conn).await.map(Into::into)
446                        })
447                        .map_err(sqlx_error_to_exec_err)
448                    }
449                    #[cfg(feature = "rusqlite")]
450                    InnerConnection::Rusqlite(conn) => conn.execute(stmt, &self.metric_callback),
451                    #[cfg(feature = "mock")]
452                    InnerConnection::Mock(conn) => conn.execute(stmt),
453                    #[cfg(feature = "proxy")]
454                    InnerConnection::Proxy(conn) => conn.execute(stmt).await,
455                    #[allow(unreachable_patterns)]
456                    _ => Err(conn_err("Disconnected")),
457                }
458            }
459        )
460    }
461
462    #[instrument(level = "trace")]
463    #[allow(unused_variables)]
464    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
465        debug_print!("{}", sql);
466
467        super::tracing_spans::with_db_span!(
468            "sea_orm.execute_unprepared",
469            self.backend,
470            sql,
471            record_stmt = false,
472            async {
473                #[cfg(not(feature = "sync"))]
474                let conn = &mut *self.conn.lock().await;
475                #[cfg(feature = "sync")]
476                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
477
478                match conn {
479                    #[cfg(feature = "sqlx-mysql")]
480                    InnerConnection::MySql(conn) => {
481                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
482                        sqlx::Executor::execute(conn, sqlx::AssertSqlSafe(sql.to_owned()))
483                            .await
484                            .map(Into::into)
485                            .map_err(sqlx_error_to_exec_err)
486                    }
487                    #[cfg(feature = "sqlx-postgres")]
488                    InnerConnection::Postgres(conn) => {
489                        let conn: &mut sqlx::PgConnection = &mut *conn;
490                        sqlx::Executor::execute(conn, sqlx::AssertSqlSafe(sql.to_owned()))
491                            .await
492                            .map(Into::into)
493                            .map_err(sqlx_error_to_exec_err)
494                    }
495                    #[cfg(feature = "sqlx-sqlite")]
496                    InnerConnection::Sqlite(conn) => {
497                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
498                        sqlx::Executor::execute(conn, sqlx::AssertSqlSafe(sql.to_owned()))
499                            .await
500                            .map(Into::into)
501                            .map_err(sqlx_error_to_exec_err)
502                    }
503                    #[cfg(feature = "rusqlite")]
504                    InnerConnection::Rusqlite(conn) => conn.execute_unprepared(sql),
505                    #[cfg(feature = "mock")]
506                    InnerConnection::Mock(conn) => {
507                        let db_backend = conn.get_database_backend();
508                        let stmt = Statement::from_string(db_backend, sql);
509                        conn.execute(stmt)
510                    }
511                    #[cfg(feature = "proxy")]
512                    InnerConnection::Proxy(conn) => {
513                        let db_backend = conn.get_database_backend();
514                        let stmt = Statement::from_string(db_backend, sql);
515                        conn.execute(stmt).await
516                    }
517                    #[allow(unreachable_patterns)]
518                    _ => Err(conn_err("Disconnected")),
519                }
520            }
521        )
522    }
523
524    #[instrument(level = "trace")]
525    #[allow(unused_variables)]
526    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
527        debug_print!("{}", stmt);
528
529        super::tracing_spans::with_db_span!(
530            "sea_orm.query_one",
531            self.backend,
532            stmt.sql.as_str(),
533            record_stmt = self.record_stmt_in_spans,
534            async {
535                #[cfg(not(feature = "sync"))]
536                let conn = &mut *self.conn.lock().await;
537                #[cfg(feature = "sync")]
538                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
539
540                match conn {
541                    #[cfg(feature = "sqlx-mysql")]
542                    InnerConnection::MySql(conn) => {
543                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
544                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
545                        crate::metric::metric!(self.metric_callback, &stmt, {
546                            crate::sqlx_map_err_ignore_not_found(
547                                query.fetch_one(conn).await.map(|row| Some(row.into())),
548                            )
549                        })
550                    }
551                    #[cfg(feature = "sqlx-postgres")]
552                    InnerConnection::Postgres(conn) => {
553                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
554                        let conn: &mut sqlx::PgConnection = &mut *conn;
555                        crate::metric::metric!(self.metric_callback, &stmt, {
556                            crate::sqlx_map_err_ignore_not_found(
557                                query.fetch_one(conn).await.map(|row| Some(row.into())),
558                            )
559                        })
560                    }
561                    #[cfg(feature = "sqlx-sqlite")]
562                    InnerConnection::Sqlite(conn) => {
563                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
564                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
565                        crate::metric::metric!(self.metric_callback, &stmt, {
566                            crate::sqlx_map_err_ignore_not_found(
567                                query.fetch_one(conn).await.map(|row| Some(row.into())),
568                            )
569                        })
570                    }
571                    #[cfg(feature = "rusqlite")]
572                    InnerConnection::Rusqlite(conn) => conn.query_one(stmt, &self.metric_callback),
573                    #[cfg(feature = "mock")]
574                    InnerConnection::Mock(conn) => conn.query_one(stmt),
575                    #[cfg(feature = "proxy")]
576                    InnerConnection::Proxy(conn) => conn.query_one(stmt).await,
577                    #[allow(unreachable_patterns)]
578                    _ => Err(conn_err("Disconnected")),
579                }
580            }
581        )
582    }
583
584    #[instrument(level = "trace")]
585    #[allow(unused_variables)]
586    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
587        debug_print!("{}", stmt);
588
589        super::tracing_spans::with_db_span!(
590            "sea_orm.query_all",
591            self.backend,
592            stmt.sql.as_str(),
593            record_stmt = self.record_stmt_in_spans,
594            async {
595                #[cfg(not(feature = "sync"))]
596                let conn = &mut *self.conn.lock().await;
597                #[cfg(feature = "sync")]
598                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
599
600                match conn {
601                    #[cfg(feature = "sqlx-mysql")]
602                    InnerConnection::MySql(conn) => {
603                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
604                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
605                        crate::metric::metric!(self.metric_callback, &stmt, {
606                            query
607                                .fetch_all(conn)
608                                .await
609                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
610                                .map_err(sqlx_error_to_query_err)
611                        })
612                    }
613                    #[cfg(feature = "sqlx-postgres")]
614                    InnerConnection::Postgres(conn) => {
615                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
616                        let conn: &mut sqlx::PgConnection = &mut *conn;
617                        crate::metric::metric!(self.metric_callback, &stmt, {
618                            query
619                                .fetch_all(conn)
620                                .await
621                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
622                                .map_err(sqlx_error_to_query_err)
623                        })
624                    }
625                    #[cfg(feature = "sqlx-sqlite")]
626                    InnerConnection::Sqlite(conn) => {
627                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
628                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
629                        crate::metric::metric!(self.metric_callback, &stmt, {
630                            query
631                                .fetch_all(conn)
632                                .await
633                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
634                                .map_err(sqlx_error_to_query_err)
635                        })
636                    }
637                    #[cfg(feature = "rusqlite")]
638                    InnerConnection::Rusqlite(conn) => conn.query_all(stmt, &self.metric_callback),
639                    #[cfg(feature = "mock")]
640                    InnerConnection::Mock(conn) => conn.query_all(stmt),
641                    #[cfg(feature = "proxy")]
642                    InnerConnection::Proxy(conn) => conn.query_all(stmt).await,
643                    #[allow(unreachable_patterns)]
644                    _ => Err(conn_err("Disconnected")),
645                }
646            }
647        )
648    }
649}
650
651impl StreamTrait for DatabaseTransaction {
652    type Stream<'a> = TransactionStream<'a>;
653
654    fn get_database_backend(&self) -> DbBackend {
655        self.backend
656    }
657
658    #[instrument(level = "trace")]
659    fn stream_raw<'a>(
660        &'a self,
661        stmt: Statement,
662    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
663        Box::pin(async move {
664            #[cfg(not(feature = "sync"))]
665            let conn = self.conn.lock().await;
666            #[cfg(feature = "sync")]
667            let conn = self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
668            Ok(crate::TransactionStream::build(
669                conn,
670                stmt,
671                self.metric_callback.clone(),
672            ))
673        })
674    }
675}
676
677#[async_trait::async_trait]
678impl TransactionTrait for DatabaseTransaction {
679    type Transaction = DatabaseTransaction;
680
681    #[instrument(level = "trace")]
682    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
683        DatabaseTransaction::begin(
684            Arc::clone(&self.conn),
685            self.backend,
686            self.metric_callback.clone(),
687            self.record_stmt_in_spans,
688            None,
689            None,
690            None,
691        )
692        .await
693    }
694
695    #[instrument(level = "trace")]
696    async fn begin_with_config(
697        &self,
698        isolation_level: Option<IsolationLevel>,
699        access_mode: Option<AccessMode>,
700    ) -> Result<DatabaseTransaction, DbErr> {
701        DatabaseTransaction::begin(
702            Arc::clone(&self.conn),
703            self.backend,
704            self.metric_callback.clone(),
705            self.record_stmt_in_spans,
706            isolation_level,
707            access_mode,
708            None,
709        )
710        .await
711    }
712
713    #[instrument(level = "trace")]
714    async fn begin_with_options(
715        &self,
716        options: TransactionOptions,
717    ) -> Result<DatabaseTransaction, DbErr> {
718        DatabaseTransaction::begin(
719            Arc::clone(&self.conn),
720            self.backend,
721            self.metric_callback.clone(),
722            self.record_stmt_in_spans,
723            options.isolation_level,
724            options.access_mode,
725            options.sqlite_transaction_mode,
726        )
727        .await
728    }
729
730    /// Execute the async function inside a transaction.
731    /// If the function returns an error, the transaction will be rolled back.
732    /// Otherwise, the transaction will be committed.
733    #[instrument(level = "trace", skip(_callback))]
734    async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
735    where
736        F: for<'c> FnOnce(
737                &'c DatabaseTransaction,
738            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
739            + Send,
740        T: Send,
741        E: std::fmt::Display + std::fmt::Debug + Send,
742    {
743        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
744        transaction.run(_callback).await
745    }
746
747    /// Execute the async function inside a transaction.
748    /// If the function returns an error, the transaction will be rolled back.
749    /// Otherwise, the transaction will be committed.
750    #[instrument(level = "trace", skip(_callback))]
751    async fn transaction_with_config<F, T, E>(
752        &self,
753        _callback: F,
754        isolation_level: Option<IsolationLevel>,
755        access_mode: Option<AccessMode>,
756    ) -> Result<T, TransactionError<E>>
757    where
758        F: for<'c> FnOnce(
759                &'c DatabaseTransaction,
760            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
761            + Send,
762        T: Send,
763        E: std::fmt::Display + std::fmt::Debug + Send,
764    {
765        let transaction = self
766            .begin_with_config(isolation_level, access_mode)
767            .await
768            .map_err(TransactionError::Connection)?;
769        transaction.run(_callback).await
770    }
771}
772
773/// Defines errors for handling transaction failures
774#[derive(Debug)]
775pub enum TransactionError<E> {
776    /// A Database connection error
777    Connection(DbErr),
778    /// An error occurring when doing database transactions
779    Transaction(E),
780}
781
782impl<E> std::fmt::Display for TransactionError<E>
783where
784    E: std::fmt::Display + std::fmt::Debug,
785{
786    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
787        match self {
788            TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
789            TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
790        }
791    }
792}
793
794impl<E> std::error::Error for TransactionError<E> where E: std::fmt::Display + std::fmt::Debug {}
795
796impl<E> From<DbErr> for TransactionError<E>
797where
798    E: std::fmt::Display + std::fmt::Debug,
799{
800    fn from(e: DbErr) -> Self {
801        Self::Connection(e)
802    }
803}