sea_orm/database/transaction.rs

1use crate::{
2    debug_print, error::*, AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult,
3    InnerConnection, IsolationLevel, QueryResult, Statement, StreamTrait, TransactionStream,
4    TransactionTrait,
5};
6#[cfg(feature = "sqlx-dep")]
7use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
8use futures_util::lock::Mutex;
9#[cfg(feature = "sqlx-dep")]
10use sqlx::TransactionManager;
11use std::{future::Future, pin::Pin, sync::Arc};
12use tracing::instrument;
13
14// a Transaction is just a sugar for a connection where START TRANSACTION has been executed
15/// Defines a database transaction, whether it is an open transaction and the type of
16/// backend to use
17pub struct DatabaseTransaction {
18    conn: Arc<Mutex<InnerConnection>>,
19    backend: DbBackend,
20    open: bool,
21    metric_callback: Option<crate::metric::Callback>,
22}
23
24impl std::fmt::Debug for DatabaseTransaction {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        write!(f, "DatabaseTransaction")
27    }
28}
29
30impl DatabaseTransaction {
31    #[instrument(level = "trace", skip(metric_callback))]
32    pub(crate) async fn begin(
33        conn: Arc<Mutex<InnerConnection>>,
34        backend: DbBackend,
35        metric_callback: Option<crate::metric::Callback>,
36        isolation_level: Option<IsolationLevel>,
37        access_mode: Option<AccessMode>,
38    ) -> Result<DatabaseTransaction, DbErr> {
39        let res = DatabaseTransaction {
40            conn,
41            backend,
42            open: true,
43            metric_callback,
44        };
45        match *res.conn.lock().await {
46            #[cfg(feature = "sqlx-mysql")]
47            InnerConnection::MySql(ref mut c) => {
48                // in MySQL SET TRANSACTION operations must be executed before transaction start
49                crate::driver::sqlx_mysql::set_transaction_config(c, isolation_level, access_mode)
50                    .await?;
51                <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c)
52                    .await
53                    .map_err(sqlx_error_to_query_err)
54            }
55            #[cfg(feature = "sqlx-postgres")]
56            InnerConnection::Postgres(ref mut c) => {
57                <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c)
58                    .await
59                    .map_err(sqlx_error_to_query_err)?;
60                // in PostgreSQL SET TRANSACTION operations must be executed inside transaction
61                crate::driver::sqlx_postgres::set_transaction_config(
62                    c,
63                    isolation_level,
64                    access_mode,
65                )
66                .await
67            }
68            #[cfg(feature = "sqlx-sqlite")]
69            InnerConnection::Sqlite(ref mut c) => {
70                // in SQLite isolation level and access mode are global settings
71                crate::driver::sqlx_sqlite::set_transaction_config(c, isolation_level, access_mode)
72                    .await?;
73                <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c)
74                    .await
75                    .map_err(sqlx_error_to_query_err)
76            }
77            #[cfg(feature = "mock")]
78            InnerConnection::Mock(ref mut c) => {
79                c.begin();
80                Ok(())
81            }
82            #[allow(unreachable_patterns)]
83            _ => Err(conn_err("Disconnected")),
84        }?;
85        Ok(res)
86    }
87
88    /// Runs a transaction to completion returning an rolling back the transaction on
89    /// encountering an error if it fails
90    #[instrument(level = "trace", skip(callback))]
91    pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
92    where
93        F: for<'b> FnOnce(
94                &'b DatabaseTransaction,
95            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
96            + Send,
97        T: Send,
98        E: std::error::Error + Send,
99    {
100        let res = callback(&self).await.map_err(TransactionError::Transaction);
101        if res.is_ok() {
102            self.commit().await.map_err(TransactionError::Connection)?;
103        } else {
104            self.rollback()
105                .await
106                .map_err(TransactionError::Connection)?;
107        }
108        res
109    }
110
111    /// Commit a transaction atomically
112    #[instrument(level = "trace")]
113    #[allow(unreachable_code, unused_mut)]
114    pub async fn commit(mut self) -> Result<(), DbErr> {
115        match *self.conn.lock().await {
116            #[cfg(feature = "sqlx-mysql")]
117            InnerConnection::MySql(ref mut c) => {
118                <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
119                    .await
120                    .map_err(sqlx_error_to_query_err)
121            }
122            #[cfg(feature = "sqlx-postgres")]
123            InnerConnection::Postgres(ref mut c) => {
124                <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
125                    .await
126                    .map_err(sqlx_error_to_query_err)
127            }
128            #[cfg(feature = "sqlx-sqlite")]
129            InnerConnection::Sqlite(ref mut c) => {
130                <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
131                    .await
132                    .map_err(sqlx_error_to_query_err)
133            }
134            #[cfg(feature = "mock")]
135            InnerConnection::Mock(ref mut c) => {
136                c.commit();
137                Ok(())
138            }
139            #[allow(unreachable_patterns)]
140            _ => Err(conn_err("Disconnected")),
141        }?;
142        self.open = false;
143        Ok(())
144    }
145
146    /// rolls back a transaction in case error are encountered during the operation
147    #[instrument(level = "trace")]
148    #[allow(unreachable_code, unused_mut)]
149    pub async fn rollback(mut self) -> Result<(), DbErr> {
150        match *self.conn.lock().await {
151            #[cfg(feature = "sqlx-mysql")]
152            InnerConnection::MySql(ref mut c) => {
153                <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
154                    .await
155                    .map_err(sqlx_error_to_query_err)
156            }
157            #[cfg(feature = "sqlx-postgres")]
158            InnerConnection::Postgres(ref mut c) => {
159                <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
160                    .await
161                    .map_err(sqlx_error_to_query_err)
162            }
163            #[cfg(feature = "sqlx-sqlite")]
164            InnerConnection::Sqlite(ref mut c) => {
165                <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
166                    .await
167                    .map_err(sqlx_error_to_query_err)
168            }
169            #[cfg(feature = "mock")]
170            InnerConnection::Mock(ref mut c) => {
171                c.rollback();
172                Ok(())
173            }
174            #[allow(unreachable_patterns)]
175            _ => Err(conn_err("Disconnected")),
176        }?;
177        self.open = false;
178        Ok(())
179    }
180
181    // the rollback is queued and will be performed on next async operation, like returning the connection to the pool
182    #[instrument(level = "trace")]
183    fn start_rollback(&mut self) -> Result<(), DbErr> {
184        if self.open {
185            if let Some(mut conn) = self.conn.try_lock() {
186                match &mut *conn {
187                    #[cfg(feature = "sqlx-mysql")]
188                    InnerConnection::MySql(c) => {
189                        <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
190                    }
191                    #[cfg(feature = "sqlx-postgres")]
192                    InnerConnection::Postgres(c) => {
193                        <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
194                    }
195                    #[cfg(feature = "sqlx-sqlite")]
196                    InnerConnection::Sqlite(c) => {
197                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
198                    }
199                    #[cfg(feature = "mock")]
200                    InnerConnection::Mock(c) => {
201                        c.rollback();
202                    }
203                    #[allow(unreachable_patterns)]
204                    _ => return Err(conn_err("Disconnected")),
205                }
206            } else {
207                //this should never happen
208                return Err(conn_err("Dropping a locked Transaction"));
209            }
210        }
211        Ok(())
212    }
213}
214
215impl Drop for DatabaseTransaction {
216    fn drop(&mut self) {
217        self.start_rollback().expect("Fail to rollback transaction");
218    }
219}
220
221#[async_trait::async_trait]
222impl ConnectionTrait for DatabaseTransaction {
223    fn get_database_backend(&self) -> DbBackend {
224        // this way we don't need to lock
225        self.backend
226    }
227
228    #[instrument(level = "trace")]
229    #[allow(unused_variables)]
230    async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
231        debug_print!("{}", stmt);
232
233        match &mut *self.conn.lock().await {
234            #[cfg(feature = "sqlx-mysql")]
235            InnerConnection::MySql(conn) => {
236                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
237                let conn: &mut sqlx::MySqlConnection = &mut *conn;
238                crate::metric::metric!(self.metric_callback, &stmt, {
239                    query.execute(conn).await.map(Into::into)
240                })
241                .map_err(sqlx_error_to_exec_err)
242            }
243            #[cfg(feature = "sqlx-postgres")]
244            InnerConnection::Postgres(conn) => {
245                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
246                let conn: &mut sqlx::PgConnection = &mut *conn;
247                crate::metric::metric!(self.metric_callback, &stmt, {
248                    query.execute(conn).await.map(Into::into)
249                })
250                .map_err(sqlx_error_to_exec_err)
251            }
252            #[cfg(feature = "sqlx-sqlite")]
253            InnerConnection::Sqlite(conn) => {
254                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
255                let conn: &mut sqlx::SqliteConnection = &mut *conn;
256                crate::metric::metric!(self.metric_callback, &stmt, {
257                    query.execute(conn).await.map(Into::into)
258                })
259                .map_err(sqlx_error_to_exec_err)
260            }
261            #[cfg(feature = "mock")]
262            InnerConnection::Mock(conn) => return conn.execute(stmt),
263            #[allow(unreachable_patterns)]
264            _ => Err(conn_err("Disconnected")),
265        }
266    }
267
268    #[instrument(level = "trace")]
269    #[allow(unused_variables)]
270    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
271        debug_print!("{}", sql);
272
273        match &mut *self.conn.lock().await {
274            #[cfg(feature = "sqlx-mysql")]
275            InnerConnection::MySql(conn) => {
276                let conn: &mut sqlx::MySqlConnection = &mut *conn;
277                sqlx::Executor::execute(conn, sql)
278                    .await
279                    .map(Into::into)
280                    .map_err(sqlx_error_to_exec_err)
281            }
282            #[cfg(feature = "sqlx-postgres")]
283            InnerConnection::Postgres(conn) => {
284                let conn: &mut sqlx::PgConnection = &mut *conn;
285                sqlx::Executor::execute(conn, sql)
286                    .await
287                    .map(Into::into)
288                    .map_err(sqlx_error_to_exec_err)
289            }
290            #[cfg(feature = "sqlx-sqlite")]
291            InnerConnection::Sqlite(conn) => {
292                let conn: &mut sqlx::SqliteConnection = &mut *conn;
293                sqlx::Executor::execute(conn, sql)
294                    .await
295                    .map(Into::into)
296                    .map_err(sqlx_error_to_exec_err)
297            }
298            #[cfg(feature = "mock")]
299            InnerConnection::Mock(conn) => {
300                let db_backend = conn.get_database_backend();
301                let stmt = Statement::from_string(db_backend, sql);
302                conn.execute(stmt)
303            }
304            #[allow(unreachable_patterns)]
305            _ => Err(conn_err("Disconnected")),
306        }
307    }
308
309    #[instrument(level = "trace")]
310    #[allow(unused_variables)]
311    async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
312        debug_print!("{}", stmt);
313
314        match &mut *self.conn.lock().await {
315            #[cfg(feature = "sqlx-mysql")]
316            InnerConnection::MySql(conn) => {
317                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
318                let conn: &mut sqlx::MySqlConnection = &mut *conn;
319                crate::metric::metric!(self.metric_callback, &stmt, {
320                    crate::sqlx_map_err_ignore_not_found(
321                        query.fetch_one(conn).await.map(|row| Some(row.into())),
322                    )
323                })
324            }
325            #[cfg(feature = "sqlx-postgres")]
326            InnerConnection::Postgres(conn) => {
327                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
328                let conn: &mut sqlx::PgConnection = &mut *conn;
329                crate::metric::metric!(self.metric_callback, &stmt, {
330                    crate::sqlx_map_err_ignore_not_found(
331                        query.fetch_one(conn).await.map(|row| Some(row.into())),
332                    )
333                })
334            }
335            #[cfg(feature = "sqlx-sqlite")]
336            InnerConnection::Sqlite(conn) => {
337                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
338                let conn: &mut sqlx::SqliteConnection = &mut *conn;
339                crate::metric::metric!(self.metric_callback, &stmt, {
340                    crate::sqlx_map_err_ignore_not_found(
341                        query.fetch_one(conn).await.map(|row| Some(row.into())),
342                    )
343                })
344            }
345            #[cfg(feature = "mock")]
346            InnerConnection::Mock(conn) => return conn.query_one(stmt),
347            #[allow(unreachable_patterns)]
348            _ => Err(conn_err("Disconnected")),
349        }
350    }
351
352    #[instrument(level = "trace")]
353    #[allow(unused_variables)]
354    async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
355        debug_print!("{}", stmt);
356
357        match &mut *self.conn.lock().await {
358            #[cfg(feature = "sqlx-mysql")]
359            InnerConnection::MySql(conn) => {
360                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
361                let conn: &mut sqlx::MySqlConnection = &mut *conn;
362                crate::metric::metric!(self.metric_callback, &stmt, {
363                    query
364                        .fetch_all(conn)
365                        .await
366                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
367                        .map_err(sqlx_error_to_query_err)
368                })
369            }
370            #[cfg(feature = "sqlx-postgres")]
371            InnerConnection::Postgres(conn) => {
372                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
373                let conn: &mut sqlx::PgConnection = &mut *conn;
374                crate::metric::metric!(self.metric_callback, &stmt, {
375                    query
376                        .fetch_all(conn)
377                        .await
378                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
379                        .map_err(sqlx_error_to_query_err)
380                })
381            }
382            #[cfg(feature = "sqlx-sqlite")]
383            InnerConnection::Sqlite(conn) => {
384                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
385                let conn: &mut sqlx::SqliteConnection = &mut *conn;
386                crate::metric::metric!(self.metric_callback, &stmt, {
387                    query
388                        .fetch_all(conn)
389                        .await
390                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
391                        .map_err(sqlx_error_to_query_err)
392                })
393            }
394            #[cfg(feature = "mock")]
395            InnerConnection::Mock(conn) => return conn.query_all(stmt),
396            #[allow(unreachable_patterns)]
397            _ => Err(conn_err("Disconnected")),
398        }
399    }
400}
401
402impl StreamTrait for DatabaseTransaction {
403    type Stream<'a> = TransactionStream<'a>;
404
405    #[instrument(level = "trace")]
406    fn stream<'a>(
407        &'a self,
408        stmt: Statement,
409    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
410        Box::pin(async move {
411            let conn = self.conn.lock().await;
412            Ok(crate::TransactionStream::build(
413                conn,
414                stmt,
415                self.metric_callback.clone(),
416            ))
417        })
418    }
419}
420
421#[async_trait::async_trait]
422impl TransactionTrait for DatabaseTransaction {
423    #[instrument(level = "trace")]
424    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
425        DatabaseTransaction::begin(
426            Arc::clone(&self.conn),
427            self.backend,
428            self.metric_callback.clone(),
429            None,
430            None,
431        )
432        .await
433    }
434
435    #[instrument(level = "trace")]
436    async fn begin_with_config(
437        &self,
438        isolation_level: Option<IsolationLevel>,
439        access_mode: Option<AccessMode>,
440    ) -> Result<DatabaseTransaction, DbErr> {
441        DatabaseTransaction::begin(
442            Arc::clone(&self.conn),
443            self.backend,
444            self.metric_callback.clone(),
445            isolation_level,
446            access_mode,
447        )
448        .await
449    }
450
451    /// Execute the function inside a transaction.
452    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
453    #[instrument(level = "trace", skip(_callback))]
454    async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
455    where
456        F: for<'c> FnOnce(
457                &'c DatabaseTransaction,
458            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
459            + Send,
460        T: Send,
461        E: std::error::Error + Send,
462    {
463        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
464        transaction.run(_callback).await
465    }
466
467    /// Execute the function inside a transaction with isolation level and/or access mode.
468    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
469    #[instrument(level = "trace", skip(_callback))]
470    async fn transaction_with_config<F, T, E>(
471        &self,
472        _callback: F,
473        isolation_level: Option<IsolationLevel>,
474        access_mode: Option<AccessMode>,
475    ) -> Result<T, TransactionError<E>>
476    where
477        F: for<'c> FnOnce(
478                &'c DatabaseTransaction,
479            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
480            + Send,
481        T: Send,
482        E: std::error::Error + Send,
483    {
484        let transaction = self
485            .begin_with_config(isolation_level, access_mode)
486            .await
487            .map_err(TransactionError::Connection)?;
488        transaction.run(_callback).await
489    }
490}
491
492/// Defines errors for handling transaction failures
493#[derive(Debug)]
494pub enum TransactionError<E>
495where
496    E: std::error::Error,
497{
498    /// A Database connection error
499    Connection(DbErr),
500    /// An error occurring when doing database transactions
501    Transaction(E),
502}
503
504impl<E> std::fmt::Display for TransactionError<E>
505where
506    E: std::error::Error,
507{
508    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
509        match self {
510            TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
511            TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
512        }
513    }
514}
515
516impl<E> std::error::Error for TransactionError<E> where E: std::error::Error {}
517
518impl<E> From<DbErr> for TransactionError<E>
519where
520    E: std::error::Error,
521{
522    fn from(e: DbErr) -> Self {
523        Self::Connection(e)
524    }
525}