sea_orm/database/transaction.rs

1use crate::{
2    AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, IsolationLevel,
3    QueryResult, Statement, StreamTrait, TransactionStream, TransactionTrait, debug_print,
4    error::*,
5};
6#[cfg(feature = "sqlx-dep")]
7use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
8use futures_util::lock::Mutex;
9#[cfg(feature = "sqlx-dep")]
10use sqlx::TransactionManager;
11use std::{future::Future, pin::Pin, sync::Arc};
12use tracing::instrument;
13
14/// Defines a database transaction, whether it is an open transaction and the type of
15/// backend to use.
16/// Under the hood, a Transaction is just a wrapper for a connection where
17/// START TRANSACTION has been executed.
18pub struct DatabaseTransaction {
19    conn: Arc<Mutex<InnerConnection>>,
20    backend: DbBackend,
21    open: bool,
22    metric_callback: Option<crate::metric::Callback>,
23}
24
25impl std::fmt::Debug for DatabaseTransaction {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        write!(f, "DatabaseTransaction")
28    }
29}
30
31impl DatabaseTransaction {
32    #[instrument(level = "trace", skip(metric_callback))]
33    pub(crate) async fn begin(
34        conn: Arc<Mutex<InnerConnection>>,
35        backend: DbBackend,
36        metric_callback: Option<crate::metric::Callback>,
37        isolation_level: Option<IsolationLevel>,
38        access_mode: Option<AccessMode>,
39    ) -> Result<DatabaseTransaction, DbErr> {
40        let res = DatabaseTransaction {
41            conn,
42            backend,
43            open: true,
44            metric_callback,
45        };
46        match *res.conn.lock().await {
47            #[cfg(feature = "sqlx-mysql")]
48            InnerConnection::MySql(ref mut c) => {
49                // in MySQL SET TRANSACTION operations must be executed before transaction start
50                crate::driver::sqlx_mysql::set_transaction_config(c, isolation_level, access_mode)
51                    .await?;
52                <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c, None)
53                    .await
54                    .map_err(sqlx_error_to_query_err)
55            }
56            #[cfg(feature = "sqlx-postgres")]
57            InnerConnection::Postgres(ref mut c) => {
58                <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c, None)
59                    .await
60                    .map_err(sqlx_error_to_query_err)?;
61                // in PostgreSQL SET TRANSACTION operations must be executed inside transaction
62                crate::driver::sqlx_postgres::set_transaction_config(
63                    c,
64                    isolation_level,
65                    access_mode,
66                )
67                .await
68            }
69            #[cfg(feature = "sqlx-sqlite")]
70            InnerConnection::Sqlite(ref mut c) => {
71                // in SQLite isolation level and access mode are global settings
72                crate::driver::sqlx_sqlite::set_transaction_config(c, isolation_level, access_mode)
73                    .await?;
74                <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c, None)
75                    .await
76                    .map_err(sqlx_error_to_query_err)
77            }
78            #[cfg(feature = "mock")]
79            InnerConnection::Mock(ref mut c) => {
80                c.begin();
81                Ok(())
82            }
83            #[cfg(feature = "proxy")]
84            InnerConnection::Proxy(ref mut c) => {
85                c.begin().await;
86                Ok(())
87            }
88            #[allow(unreachable_patterns)]
89            _ => Err(conn_err("Disconnected")),
90        }?;
91        Ok(res)
92    }
93
94    /// Runs a transaction to completion passing through the result.
95    /// Rolling back the transaction on encountering an error.
96    #[instrument(level = "trace", skip(callback))]
97    pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
98    where
99        F: for<'b> FnOnce(
100                &'b DatabaseTransaction,
101            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
102            + Send,
103        T: Send,
104        E: std::fmt::Display + std::fmt::Debug + Send,
105    {
106        let res = callback(&self).await.map_err(TransactionError::Transaction);
107        if res.is_ok() {
108            self.commit().await.map_err(TransactionError::Connection)?;
109        } else {
110            self.rollback()
111                .await
112                .map_err(TransactionError::Connection)?;
113        }
114        res
115    }
116
117    /// Commit a transaction
118    #[instrument(level = "trace")]
119    #[allow(unreachable_code, unused_mut)]
120    pub async fn commit(mut self) -> Result<(), DbErr> {
121        match *self.conn.lock().await {
122            #[cfg(feature = "sqlx-mysql")]
123            InnerConnection::MySql(ref mut c) => {
124                <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
125                    .await
126                    .map_err(sqlx_error_to_query_err)
127            }
128            #[cfg(feature = "sqlx-postgres")]
129            InnerConnection::Postgres(ref mut c) => {
130                <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
131                    .await
132                    .map_err(sqlx_error_to_query_err)
133            }
134            #[cfg(feature = "sqlx-sqlite")]
135            InnerConnection::Sqlite(ref mut c) => {
136                <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
137                    .await
138                    .map_err(sqlx_error_to_query_err)
139            }
140            #[cfg(feature = "mock")]
141            InnerConnection::Mock(ref mut c) => {
142                c.commit();
143                Ok(())
144            }
145            #[cfg(feature = "proxy")]
146            InnerConnection::Proxy(ref mut c) => {
147                c.commit().await;
148                Ok(())
149            }
150            #[allow(unreachable_patterns)]
151            _ => Err(conn_err("Disconnected")),
152        }?;
153        self.open = false;
154        Ok(())
155    }
156
157    /// Rolls back a transaction explicitly
158    #[instrument(level = "trace")]
159    #[allow(unreachable_code, unused_mut)]
160    pub async fn rollback(mut self) -> Result<(), DbErr> {
161        match *self.conn.lock().await {
162            #[cfg(feature = "sqlx-mysql")]
163            InnerConnection::MySql(ref mut c) => {
164                <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
165                    .await
166                    .map_err(sqlx_error_to_query_err)
167            }
168            #[cfg(feature = "sqlx-postgres")]
169            InnerConnection::Postgres(ref mut c) => {
170                <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
171                    .await
172                    .map_err(sqlx_error_to_query_err)
173            }
174            #[cfg(feature = "sqlx-sqlite")]
175            InnerConnection::Sqlite(ref mut c) => {
176                <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
177                    .await
178                    .map_err(sqlx_error_to_query_err)
179            }
180            #[cfg(feature = "mock")]
181            InnerConnection::Mock(ref mut c) => {
182                c.rollback();
183                Ok(())
184            }
185            #[cfg(feature = "proxy")]
186            InnerConnection::Proxy(ref mut c) => {
187                c.rollback().await;
188                Ok(())
189            }
190            #[allow(unreachable_patterns)]
191            _ => Err(conn_err("Disconnected")),
192        }?;
193        self.open = false;
194        Ok(())
195    }
196
197    // the rollback is queued and will be performed on next async operation, like returning the connection to the pool
198    #[instrument(level = "trace")]
199    fn start_rollback(&mut self) -> Result<(), DbErr> {
200        if self.open {
201            if let Some(mut conn) = self.conn.try_lock() {
202                match &mut *conn {
203                    #[cfg(feature = "sqlx-mysql")]
204                    InnerConnection::MySql(c) => {
205                        <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
206                    }
207                    #[cfg(feature = "sqlx-postgres")]
208                    InnerConnection::Postgres(c) => {
209                        <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
210                    }
211                    #[cfg(feature = "sqlx-sqlite")]
212                    InnerConnection::Sqlite(c) => {
213                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
214                    }
215                    #[cfg(feature = "mock")]
216                    InnerConnection::Mock(c) => {
217                        c.rollback();
218                    }
219                    #[cfg(feature = "proxy")]
220                    InnerConnection::Proxy(c) => {
221                        c.start_rollback();
222                    }
223                    #[allow(unreachable_patterns)]
224                    _ => return Err(conn_err("Disconnected")),
225                }
226            } else {
227                //this should never happen
228                return Err(conn_err("Dropping a locked Transaction"));
229            }
230        }
231        Ok(())
232    }
233}
234
235impl Drop for DatabaseTransaction {
236    fn drop(&mut self) {
237        self.start_rollback().expect("Fail to rollback transaction");
238    }
239}
240
241#[async_trait::async_trait]
242impl ConnectionTrait for DatabaseTransaction {
243    fn get_database_backend(&self) -> DbBackend {
244        // this way we don't need to lock just to know the backend
245        self.backend
246    }
247
248    #[instrument(level = "trace")]
249    #[allow(unused_variables)]
250    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
251        debug_print!("{}", stmt);
252
253        match &mut *self.conn.lock().await {
254            #[cfg(feature = "sqlx-mysql")]
255            InnerConnection::MySql(conn) => {
256                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
257                let conn: &mut sqlx::MySqlConnection = &mut *conn;
258                crate::metric::metric!(self.metric_callback, &stmt, {
259                    query.execute(conn).await.map(Into::into)
260                })
261                .map_err(sqlx_error_to_exec_err)
262            }
263            #[cfg(feature = "sqlx-postgres")]
264            InnerConnection::Postgres(conn) => {
265                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
266                let conn: &mut sqlx::PgConnection = &mut *conn;
267                crate::metric::metric!(self.metric_callback, &stmt, {
268                    query.execute(conn).await.map(Into::into)
269                })
270                .map_err(sqlx_error_to_exec_err)
271            }
272            #[cfg(feature = "sqlx-sqlite")]
273            InnerConnection::Sqlite(conn) => {
274                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
275                let conn: &mut sqlx::SqliteConnection = &mut *conn;
276                crate::metric::metric!(self.metric_callback, &stmt, {
277                    query.execute(conn).await.map(Into::into)
278                })
279                .map_err(sqlx_error_to_exec_err)
280            }
281            #[cfg(feature = "mock")]
282            InnerConnection::Mock(conn) => return conn.execute(stmt),
283            #[cfg(feature = "proxy")]
284            InnerConnection::Proxy(conn) => return conn.execute(stmt).await,
285            #[allow(unreachable_patterns)]
286            _ => Err(conn_err("Disconnected")),
287        }
288    }
289
290    #[instrument(level = "trace")]
291    #[allow(unused_variables)]
292    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
293        debug_print!("{}", sql);
294
295        match &mut *self.conn.lock().await {
296            #[cfg(feature = "sqlx-mysql")]
297            InnerConnection::MySql(conn) => {
298                let conn: &mut sqlx::MySqlConnection = &mut *conn;
299                sqlx::Executor::execute(conn, sql)
300                    .await
301                    .map(Into::into)
302                    .map_err(sqlx_error_to_exec_err)
303            }
304            #[cfg(feature = "sqlx-postgres")]
305            InnerConnection::Postgres(conn) => {
306                let conn: &mut sqlx::PgConnection = &mut *conn;
307                sqlx::Executor::execute(conn, sql)
308                    .await
309                    .map(Into::into)
310                    .map_err(sqlx_error_to_exec_err)
311            }
312            #[cfg(feature = "sqlx-sqlite")]
313            InnerConnection::Sqlite(conn) => {
314                let conn: &mut sqlx::SqliteConnection = &mut *conn;
315                sqlx::Executor::execute(conn, sql)
316                    .await
317                    .map(Into::into)
318                    .map_err(sqlx_error_to_exec_err)
319            }
320            #[cfg(feature = "mock")]
321            InnerConnection::Mock(conn) => {
322                let db_backend = conn.get_database_backend();
323                let stmt = Statement::from_string(db_backend, sql);
324                conn.execute(stmt)
325            }
326            #[cfg(feature = "proxy")]
327            InnerConnection::Proxy(conn) => {
328                let db_backend = conn.get_database_backend();
329                let stmt = Statement::from_string(db_backend, sql);
330                conn.execute(stmt).await
331            }
332            #[allow(unreachable_patterns)]
333            _ => Err(conn_err("Disconnected")),
334        }
335    }
336
337    #[instrument(level = "trace")]
338    #[allow(unused_variables)]
339    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
340        debug_print!("{}", stmt);
341
342        match &mut *self.conn.lock().await {
343            #[cfg(feature = "sqlx-mysql")]
344            InnerConnection::MySql(conn) => {
345                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
346                let conn: &mut sqlx::MySqlConnection = &mut *conn;
347                crate::metric::metric!(self.metric_callback, &stmt, {
348                    crate::sqlx_map_err_ignore_not_found(
349                        query.fetch_one(conn).await.map(|row| Some(row.into())),
350                    )
351                })
352            }
353            #[cfg(feature = "sqlx-postgres")]
354            InnerConnection::Postgres(conn) => {
355                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
356                let conn: &mut sqlx::PgConnection = &mut *conn;
357                crate::metric::metric!(self.metric_callback, &stmt, {
358                    crate::sqlx_map_err_ignore_not_found(
359                        query.fetch_one(conn).await.map(|row| Some(row.into())),
360                    )
361                })
362            }
363            #[cfg(feature = "sqlx-sqlite")]
364            InnerConnection::Sqlite(conn) => {
365                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
366                let conn: &mut sqlx::SqliteConnection = &mut *conn;
367                crate::metric::metric!(self.metric_callback, &stmt, {
368                    crate::sqlx_map_err_ignore_not_found(
369                        query.fetch_one(conn).await.map(|row| Some(row.into())),
370                    )
371                })
372            }
373            #[cfg(feature = "mock")]
374            InnerConnection::Mock(conn) => return conn.query_one(stmt),
375            #[cfg(feature = "proxy")]
376            InnerConnection::Proxy(conn) => return conn.query_one(stmt).await,
377            #[allow(unreachable_patterns)]
378            _ => Err(conn_err("Disconnected")),
379        }
380    }
381
382    #[instrument(level = "trace")]
383    #[allow(unused_variables)]
384    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
385        debug_print!("{}", stmt);
386
387        match &mut *self.conn.lock().await {
388            #[cfg(feature = "sqlx-mysql")]
389            InnerConnection::MySql(conn) => {
390                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
391                let conn: &mut sqlx::MySqlConnection = &mut *conn;
392                crate::metric::metric!(self.metric_callback, &stmt, {
393                    query
394                        .fetch_all(conn)
395                        .await
396                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
397                        .map_err(sqlx_error_to_query_err)
398                })
399            }
400            #[cfg(feature = "sqlx-postgres")]
401            InnerConnection::Postgres(conn) => {
402                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
403                let conn: &mut sqlx::PgConnection = &mut *conn;
404                crate::metric::metric!(self.metric_callback, &stmt, {
405                    query
406                        .fetch_all(conn)
407                        .await
408                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
409                        .map_err(sqlx_error_to_query_err)
410                })
411            }
412            #[cfg(feature = "sqlx-sqlite")]
413            InnerConnection::Sqlite(conn) => {
414                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
415                let conn: &mut sqlx::SqliteConnection = &mut *conn;
416                crate::metric::metric!(self.metric_callback, &stmt, {
417                    query
418                        .fetch_all(conn)
419                        .await
420                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
421                        .map_err(sqlx_error_to_query_err)
422                })
423            }
424            #[cfg(feature = "mock")]
425            InnerConnection::Mock(conn) => return conn.query_all(stmt),
426            #[cfg(feature = "proxy")]
427            InnerConnection::Proxy(conn) => return conn.query_all(stmt).await,
428            #[allow(unreachable_patterns)]
429            _ => Err(conn_err("Disconnected")),
430        }
431    }
432}
433
434impl StreamTrait for DatabaseTransaction {
435    type Stream<'a> = TransactionStream<'a>;
436
437    fn get_database_backend(&self) -> DbBackend {
438        self.backend
439    }
440
441    #[instrument(level = "trace")]
442    fn stream_raw<'a>(
443        &'a self,
444        stmt: Statement,
445    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
446        Box::pin(async move {
447            let conn = self.conn.lock().await;
448            Ok(crate::TransactionStream::build(
449                conn,
450                stmt,
451                self.metric_callback.clone(),
452            ))
453        })
454    }
455}
456
457#[async_trait::async_trait]
458impl TransactionTrait for DatabaseTransaction {
459    type Transaction = DatabaseTransaction;
460
461    #[instrument(level = "trace")]
462    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
463        DatabaseTransaction::begin(
464            Arc::clone(&self.conn),
465            self.backend,
466            self.metric_callback.clone(),
467            None,
468            None,
469        )
470        .await
471    }
472
473    #[instrument(level = "trace")]
474    async fn begin_with_config(
475        &self,
476        isolation_level: Option<IsolationLevel>,
477        access_mode: Option<AccessMode>,
478    ) -> Result<DatabaseTransaction, DbErr> {
479        DatabaseTransaction::begin(
480            Arc::clone(&self.conn),
481            self.backend,
482            self.metric_callback.clone(),
483            isolation_level,
484            access_mode,
485        )
486        .await
487    }
488
489    /// Execute the async function inside a transaction.
490    /// If the function returns an error, the transaction will be rolled back.
491    /// Otherwise, the transaction will be committed.
492    #[instrument(level = "trace", skip(_callback))]
493    async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
494    where
495        F: for<'c> FnOnce(
496                &'c DatabaseTransaction,
497            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
498            + Send,
499        T: Send,
500        E: std::fmt::Display + std::fmt::Debug + Send,
501    {
502        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
503        transaction.run(_callback).await
504    }
505
506    /// Execute the async function inside a transaction.
507    /// If the function returns an error, the transaction will be rolled back.
508    /// Otherwise, the transaction will be committed.
509    #[instrument(level = "trace", skip(_callback))]
510    async fn transaction_with_config<F, T, E>(
511        &self,
512        _callback: F,
513        isolation_level: Option<IsolationLevel>,
514        access_mode: Option<AccessMode>,
515    ) -> Result<T, TransactionError<E>>
516    where
517        F: for<'c> FnOnce(
518                &'c DatabaseTransaction,
519            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
520            + Send,
521        T: Send,
522        E: std::fmt::Display + std::fmt::Debug + Send,
523    {
524        let transaction = self
525            .begin_with_config(isolation_level, access_mode)
526            .await
527            .map_err(TransactionError::Connection)?;
528        transaction.run(_callback).await
529    }
530}
531
532/// Defines errors for handling transaction failures
533#[derive(Debug)]
534pub enum TransactionError<E> {
535    /// A Database connection error
536    Connection(DbErr),
537    /// An error occurring when doing database transactions
538    Transaction(E),
539}
540
541impl<E> std::fmt::Display for TransactionError<E>
542where
543    E: std::fmt::Display + std::fmt::Debug,
544{
545    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
546        match self {
547            TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
548            TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
549        }
550    }
551}
552
553impl<E> std::error::Error for TransactionError<E> where E: std::fmt::Display + std::fmt::Debug {}
554
555impl<E> From<DbErr> for TransactionError<E>
556where
557    E: std::fmt::Display + std::fmt::Debug,
558{
559    fn from(e: DbErr) -> Self {
560        Self::Connection(e)
561    }
562}