sea_orm/database/
transaction.rs

1use crate::{
2    AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, IsolationLevel,
3    QueryResult, Statement, StreamTrait, TransactionSession, TransactionStream, TransactionTrait,
4    debug_print, error::*,
5};
6#[cfg(feature = "sqlx-dep")]
7use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
8use futures_util::lock::Mutex;
9#[cfg(feature = "sqlx-dep")]
10use sqlx::TransactionManager;
11use std::{future::Future, pin::Pin, sync::Arc};
12use tracing::instrument;
13
/// Defines a database transaction, whether it is an open transaction and the type of
/// backend to use.
/// Under the hood, a Transaction is just a wrapper for a connection where
/// START TRANSACTION has been executed.
pub struct DatabaseTransaction {
    /// Connection the transaction runs on. Every operation in this file locks the mutex
    /// before touching the connection.
    conn: Arc<Mutex<InnerConnection>>,
    /// Backend kind, stored on its own so it can be read without locking `conn`.
    backend: DbBackend,
    /// Cleared once `commit` or `rollback` succeeds. While it is still set, dropping the
    /// value asks the backend to roll back.
    open: bool,
    /// Optional hook handed to `crate::metric::metric!` and to streams built from this
    /// transaction.
    metric_callback: Option<crate::metric::Callback>,
}
24
25impl std::fmt::Debug for DatabaseTransaction {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        write!(f, "DatabaseTransaction")
28    }
29}
30
31impl DatabaseTransaction {
32    #[instrument(level = "trace", skip(metric_callback))]
33    pub(crate) async fn begin(
34        conn: Arc<Mutex<InnerConnection>>,
35        backend: DbBackend,
36        metric_callback: Option<crate::metric::Callback>,
37        isolation_level: Option<IsolationLevel>,
38        access_mode: Option<AccessMode>,
39    ) -> Result<DatabaseTransaction, DbErr> {
40        let res = DatabaseTransaction {
41            conn,
42            backend,
43            open: true,
44            metric_callback,
45        };
46        match *res.conn.lock().await {
47            #[cfg(feature = "sqlx-mysql")]
48            InnerConnection::MySql(ref mut c) => {
49                // in MySQL SET TRANSACTION operations must be executed before transaction start
50                crate::driver::sqlx_mysql::set_transaction_config(c, isolation_level, access_mode)
51                    .await?;
52                <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c, None)
53                    .await
54                    .map_err(sqlx_error_to_query_err)
55            }
56            #[cfg(feature = "sqlx-postgres")]
57            InnerConnection::Postgres(ref mut c) => {
58                <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c, None)
59                    .await
60                    .map_err(sqlx_error_to_query_err)?;
61                // in PostgreSQL SET TRANSACTION operations must be executed inside transaction
62                crate::driver::sqlx_postgres::set_transaction_config(
63                    c,
64                    isolation_level,
65                    access_mode,
66                )
67                .await
68            }
69            #[cfg(feature = "sqlx-sqlite")]
70            InnerConnection::Sqlite(ref mut c) => {
71                // in SQLite isolation level and access mode are global settings
72                crate::driver::sqlx_sqlite::set_transaction_config(c, isolation_level, access_mode)
73                    .await?;
74                <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c, None)
75                    .await
76                    .map_err(sqlx_error_to_query_err)
77            }
78            #[cfg(feature = "mock")]
79            InnerConnection::Mock(ref mut c) => {
80                c.begin();
81                Ok(())
82            }
83            #[cfg(feature = "proxy")]
84            InnerConnection::Proxy(ref mut c) => {
85                c.begin().await;
86                Ok(())
87            }
88            #[allow(unreachable_patterns)]
89            _ => Err(conn_err("Disconnected")),
90        }?;
91        Ok(res)
92    }
93
94    /// Runs a transaction to completion passing through the result.
95    /// Rolling back the transaction on encountering an error.
96    #[instrument(level = "trace", skip(callback))]
97    pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
98    where
99        F: for<'b> FnOnce(
100                &'b DatabaseTransaction,
101            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
102            + Send,
103        T: Send,
104        E: std::fmt::Display + std::fmt::Debug + Send,
105    {
106        let res = callback(&self).await.map_err(TransactionError::Transaction);
107        if res.is_ok() {
108            self.commit().await.map_err(TransactionError::Connection)?;
109        } else {
110            self.rollback()
111                .await
112                .map_err(TransactionError::Connection)?;
113        }
114        res
115    }
116
117    /// Commit a transaction
118    #[instrument(level = "trace")]
119    #[allow(unreachable_code, unused_mut)]
120    pub async fn commit(mut self) -> Result<(), DbErr> {
121        match *self.conn.lock().await {
122            #[cfg(feature = "sqlx-mysql")]
123            InnerConnection::MySql(ref mut c) => {
124                <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
125                    .await
126                    .map_err(sqlx_error_to_query_err)
127            }
128            #[cfg(feature = "sqlx-postgres")]
129            InnerConnection::Postgres(ref mut c) => {
130                <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
131                    .await
132                    .map_err(sqlx_error_to_query_err)
133            }
134            #[cfg(feature = "sqlx-sqlite")]
135            InnerConnection::Sqlite(ref mut c) => {
136                <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
137                    .await
138                    .map_err(sqlx_error_to_query_err)
139            }
140            #[cfg(feature = "mock")]
141            InnerConnection::Mock(ref mut c) => {
142                c.commit();
143                Ok(())
144            }
145            #[cfg(feature = "proxy")]
146            InnerConnection::Proxy(ref mut c) => {
147                c.commit().await;
148                Ok(())
149            }
150            #[allow(unreachable_patterns)]
151            _ => Err(conn_err("Disconnected")),
152        }?;
153        self.open = false;
154        Ok(())
155    }
156
157    /// Rolls back a transaction explicitly
158    #[instrument(level = "trace")]
159    #[allow(unreachable_code, unused_mut)]
160    pub async fn rollback(mut self) -> Result<(), DbErr> {
161        match *self.conn.lock().await {
162            #[cfg(feature = "sqlx-mysql")]
163            InnerConnection::MySql(ref mut c) => {
164                <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
165                    .await
166                    .map_err(sqlx_error_to_query_err)
167            }
168            #[cfg(feature = "sqlx-postgres")]
169            InnerConnection::Postgres(ref mut c) => {
170                <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
171                    .await
172                    .map_err(sqlx_error_to_query_err)
173            }
174            #[cfg(feature = "sqlx-sqlite")]
175            InnerConnection::Sqlite(ref mut c) => {
176                <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
177                    .await
178                    .map_err(sqlx_error_to_query_err)
179            }
180            #[cfg(feature = "mock")]
181            InnerConnection::Mock(ref mut c) => {
182                c.rollback();
183                Ok(())
184            }
185            #[cfg(feature = "proxy")]
186            InnerConnection::Proxy(ref mut c) => {
187                c.rollback().await;
188                Ok(())
189            }
190            #[allow(unreachable_patterns)]
191            _ => Err(conn_err("Disconnected")),
192        }?;
193        self.open = false;
194        Ok(())
195    }
196
197    // the rollback is queued and will be performed on next async operation, like returning the connection to the pool
198    #[instrument(level = "trace")]
199    fn start_rollback(&mut self) -> Result<(), DbErr> {
200        if self.open {
201            if let Some(mut conn) = self.conn.try_lock() {
202                match &mut *conn {
203                    #[cfg(feature = "sqlx-mysql")]
204                    InnerConnection::MySql(c) => {
205                        <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
206                    }
207                    #[cfg(feature = "sqlx-postgres")]
208                    InnerConnection::Postgres(c) => {
209                        <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
210                    }
211                    #[cfg(feature = "sqlx-sqlite")]
212                    InnerConnection::Sqlite(c) => {
213                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
214                    }
215                    #[cfg(feature = "mock")]
216                    InnerConnection::Mock(c) => {
217                        c.rollback();
218                    }
219                    #[cfg(feature = "proxy")]
220                    InnerConnection::Proxy(c) => {
221                        c.start_rollback();
222                    }
223                    #[allow(unreachable_patterns)]
224                    _ => return Err(conn_err("Disconnected")),
225                }
226            } else {
227                //this should never happen
228                return Err(conn_err("Dropping a locked Transaction"));
229            }
230        }
231        Ok(())
232    }
233}
234
235#[async_trait::async_trait]
236impl TransactionSession for DatabaseTransaction {
237    async fn commit(self) -> Result<(), DbErr> {
238        self.commit().await
239    }
240
241    async fn rollback(self) -> Result<(), DbErr> {
242        self.rollback().await
243    }
244}
245
246impl Drop for DatabaseTransaction {
247    fn drop(&mut self) {
248        self.start_rollback().expect("Fail to rollback transaction");
249    }
250}
251
252#[async_trait::async_trait]
253impl ConnectionTrait for DatabaseTransaction {
254    fn get_database_backend(&self) -> DbBackend {
255        // this way we don't need to lock just to know the backend
256        self.backend
257    }
258
259    #[instrument(level = "trace")]
260    #[allow(unused_variables)]
261    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
262        debug_print!("{}", stmt);
263
264        match &mut *self.conn.lock().await {
265            #[cfg(feature = "sqlx-mysql")]
266            InnerConnection::MySql(conn) => {
267                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
268                let conn: &mut sqlx::MySqlConnection = &mut *conn;
269                crate::metric::metric!(self.metric_callback, &stmt, {
270                    query.execute(conn).await.map(Into::into)
271                })
272                .map_err(sqlx_error_to_exec_err)
273            }
274            #[cfg(feature = "sqlx-postgres")]
275            InnerConnection::Postgres(conn) => {
276                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
277                let conn: &mut sqlx::PgConnection = &mut *conn;
278                crate::metric::metric!(self.metric_callback, &stmt, {
279                    query.execute(conn).await.map(Into::into)
280                })
281                .map_err(sqlx_error_to_exec_err)
282            }
283            #[cfg(feature = "sqlx-sqlite")]
284            InnerConnection::Sqlite(conn) => {
285                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
286                let conn: &mut sqlx::SqliteConnection = &mut *conn;
287                crate::metric::metric!(self.metric_callback, &stmt, {
288                    query.execute(conn).await.map(Into::into)
289                })
290                .map_err(sqlx_error_to_exec_err)
291            }
292            #[cfg(feature = "mock")]
293            InnerConnection::Mock(conn) => return conn.execute(stmt),
294            #[cfg(feature = "proxy")]
295            InnerConnection::Proxy(conn) => return conn.execute(stmt).await,
296            #[allow(unreachable_patterns)]
297            _ => Err(conn_err("Disconnected")),
298        }
299    }
300
301    #[instrument(level = "trace")]
302    #[allow(unused_variables)]
303    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
304        debug_print!("{}", sql);
305
306        match &mut *self.conn.lock().await {
307            #[cfg(feature = "sqlx-mysql")]
308            InnerConnection::MySql(conn) => {
309                let conn: &mut sqlx::MySqlConnection = &mut *conn;
310                sqlx::Executor::execute(conn, sql)
311                    .await
312                    .map(Into::into)
313                    .map_err(sqlx_error_to_exec_err)
314            }
315            #[cfg(feature = "sqlx-postgres")]
316            InnerConnection::Postgres(conn) => {
317                let conn: &mut sqlx::PgConnection = &mut *conn;
318                sqlx::Executor::execute(conn, sql)
319                    .await
320                    .map(Into::into)
321                    .map_err(sqlx_error_to_exec_err)
322            }
323            #[cfg(feature = "sqlx-sqlite")]
324            InnerConnection::Sqlite(conn) => {
325                let conn: &mut sqlx::SqliteConnection = &mut *conn;
326                sqlx::Executor::execute(conn, sql)
327                    .await
328                    .map(Into::into)
329                    .map_err(sqlx_error_to_exec_err)
330            }
331            #[cfg(feature = "mock")]
332            InnerConnection::Mock(conn) => {
333                let db_backend = conn.get_database_backend();
334                let stmt = Statement::from_string(db_backend, sql);
335                conn.execute(stmt)
336            }
337            #[cfg(feature = "proxy")]
338            InnerConnection::Proxy(conn) => {
339                let db_backend = conn.get_database_backend();
340                let stmt = Statement::from_string(db_backend, sql);
341                conn.execute(stmt).await
342            }
343            #[allow(unreachable_patterns)]
344            _ => Err(conn_err("Disconnected")),
345        }
346    }
347
348    #[instrument(level = "trace")]
349    #[allow(unused_variables)]
350    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
351        debug_print!("{}", stmt);
352
353        match &mut *self.conn.lock().await {
354            #[cfg(feature = "sqlx-mysql")]
355            InnerConnection::MySql(conn) => {
356                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
357                let conn: &mut sqlx::MySqlConnection = &mut *conn;
358                crate::metric::metric!(self.metric_callback, &stmt, {
359                    crate::sqlx_map_err_ignore_not_found(
360                        query.fetch_one(conn).await.map(|row| Some(row.into())),
361                    )
362                })
363            }
364            #[cfg(feature = "sqlx-postgres")]
365            InnerConnection::Postgres(conn) => {
366                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
367                let conn: &mut sqlx::PgConnection = &mut *conn;
368                crate::metric::metric!(self.metric_callback, &stmt, {
369                    crate::sqlx_map_err_ignore_not_found(
370                        query.fetch_one(conn).await.map(|row| Some(row.into())),
371                    )
372                })
373            }
374            #[cfg(feature = "sqlx-sqlite")]
375            InnerConnection::Sqlite(conn) => {
376                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
377                let conn: &mut sqlx::SqliteConnection = &mut *conn;
378                crate::metric::metric!(self.metric_callback, &stmt, {
379                    crate::sqlx_map_err_ignore_not_found(
380                        query.fetch_one(conn).await.map(|row| Some(row.into())),
381                    )
382                })
383            }
384            #[cfg(feature = "mock")]
385            InnerConnection::Mock(conn) => return conn.query_one(stmt),
386            #[cfg(feature = "proxy")]
387            InnerConnection::Proxy(conn) => return conn.query_one(stmt).await,
388            #[allow(unreachable_patterns)]
389            _ => Err(conn_err("Disconnected")),
390        }
391    }
392
393    #[instrument(level = "trace")]
394    #[allow(unused_variables)]
395    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
396        debug_print!("{}", stmt);
397
398        match &mut *self.conn.lock().await {
399            #[cfg(feature = "sqlx-mysql")]
400            InnerConnection::MySql(conn) => {
401                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
402                let conn: &mut sqlx::MySqlConnection = &mut *conn;
403                crate::metric::metric!(self.metric_callback, &stmt, {
404                    query
405                        .fetch_all(conn)
406                        .await
407                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
408                        .map_err(sqlx_error_to_query_err)
409                })
410            }
411            #[cfg(feature = "sqlx-postgres")]
412            InnerConnection::Postgres(conn) => {
413                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
414                let conn: &mut sqlx::PgConnection = &mut *conn;
415                crate::metric::metric!(self.metric_callback, &stmt, {
416                    query
417                        .fetch_all(conn)
418                        .await
419                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
420                        .map_err(sqlx_error_to_query_err)
421                })
422            }
423            #[cfg(feature = "sqlx-sqlite")]
424            InnerConnection::Sqlite(conn) => {
425                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
426                let conn: &mut sqlx::SqliteConnection = &mut *conn;
427                crate::metric::metric!(self.metric_callback, &stmt, {
428                    query
429                        .fetch_all(conn)
430                        .await
431                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
432                        .map_err(sqlx_error_to_query_err)
433                })
434            }
435            #[cfg(feature = "mock")]
436            InnerConnection::Mock(conn) => return conn.query_all(stmt),
437            #[cfg(feature = "proxy")]
438            InnerConnection::Proxy(conn) => return conn.query_all(stmt).await,
439            #[allow(unreachable_patterns)]
440            _ => Err(conn_err("Disconnected")),
441        }
442    }
443}
444
445impl StreamTrait for DatabaseTransaction {
446    type Stream<'a> = TransactionStream<'a>;
447
448    fn get_database_backend(&self) -> DbBackend {
449        self.backend
450    }
451
452    #[instrument(level = "trace")]
453    fn stream_raw<'a>(
454        &'a self,
455        stmt: Statement,
456    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
457        Box::pin(async move {
458            let conn = self.conn.lock().await;
459            Ok(crate::TransactionStream::build(
460                conn,
461                stmt,
462                self.metric_callback.clone(),
463            ))
464        })
465    }
466}
467
468#[async_trait::async_trait]
469impl TransactionTrait for DatabaseTransaction {
470    type Transaction = DatabaseTransaction;
471
472    #[instrument(level = "trace")]
473    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
474        DatabaseTransaction::begin(
475            Arc::clone(&self.conn),
476            self.backend,
477            self.metric_callback.clone(),
478            None,
479            None,
480        )
481        .await
482    }
483
484    #[instrument(level = "trace")]
485    async fn begin_with_config(
486        &self,
487        isolation_level: Option<IsolationLevel>,
488        access_mode: Option<AccessMode>,
489    ) -> Result<DatabaseTransaction, DbErr> {
490        DatabaseTransaction::begin(
491            Arc::clone(&self.conn),
492            self.backend,
493            self.metric_callback.clone(),
494            isolation_level,
495            access_mode,
496        )
497        .await
498    }
499
500    /// Execute the async function inside a transaction.
501    /// If the function returns an error, the transaction will be rolled back.
502    /// Otherwise, the transaction will be committed.
503    #[instrument(level = "trace", skip(_callback))]
504    async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
505    where
506        F: for<'c> FnOnce(
507                &'c DatabaseTransaction,
508            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
509            + Send,
510        T: Send,
511        E: std::fmt::Display + std::fmt::Debug + Send,
512    {
513        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
514        transaction.run(_callback).await
515    }
516
517    /// Execute the async function inside a transaction.
518    /// If the function returns an error, the transaction will be rolled back.
519    /// Otherwise, the transaction will be committed.
520    #[instrument(level = "trace", skip(_callback))]
521    async fn transaction_with_config<F, T, E>(
522        &self,
523        _callback: F,
524        isolation_level: Option<IsolationLevel>,
525        access_mode: Option<AccessMode>,
526    ) -> Result<T, TransactionError<E>>
527    where
528        F: for<'c> FnOnce(
529                &'c DatabaseTransaction,
530            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
531            + Send,
532        T: Send,
533        E: std::fmt::Display + std::fmt::Debug + Send,
534    {
535        let transaction = self
536            .begin_with_config(isolation_level, access_mode)
537            .await
538            .map_err(TransactionError::Connection)?;
539        transaction.run(_callback).await
540    }
541}
542
/// Defines errors for handling transaction failures
///
/// `E` is the error type returned by the callback that runs inside the transaction.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// A database error raised while beginning, committing or rolling back the
    /// transaction
    Connection(DbErr),
    /// The error returned by the callback that ran inside the transaction
    Transaction(E),
}
551
552impl<E> std::fmt::Display for TransactionError<E>
553where
554    E: std::fmt::Display + std::fmt::Debug,
555{
556    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
557        match self {
558            TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
559            TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
560        }
561    }
562}
563
// Marker implementation: no trait method is overridden, so the wrapped error is exposed
// only through the `Display` output and never through `source()`.
impl<E> std::error::Error for TransactionError<E> where E: std::fmt::Display + std::fmt::Debug {}
565
566impl<E> From<DbErr> for TransactionError<E>
567where
568    E: std::fmt::Display + std::fmt::Debug,
569{
570    fn from(e: DbErr) -> Self {
571        Self::Connection(e)
572    }
573}