sea_orm/database/transaction.rs

1use crate::{
2    debug_print, error::*, AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult,
3    InnerConnection, IsolationLevel, QueryResult, Statement, StreamTrait, TransactionStream,
4    TransactionTrait,
5};
6#[cfg(feature = "sqlx-dep")]
7use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
8use futures_util::lock::Mutex;
9#[cfg(feature = "sqlx-dep")]
10use sqlx::TransactionManager;
11use std::{future::Future, pin::Pin, sync::Arc};
12use tracing::instrument;
13
14// a Transaction is just a sugar for a connection where START TRANSACTION has been executed
15/// Defines a database transaction, whether it is an open transaction and the type of
16/// backend to use
17pub struct DatabaseTransaction {
18    conn: Arc<Mutex<InnerConnection>>,
19    backend: DbBackend,
20    open: bool,
21    metric_callback: Option<crate::metric::Callback>,
22}
23
24impl std::fmt::Debug for DatabaseTransaction {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        write!(f, "DatabaseTransaction")
27    }
28}
29
30impl DatabaseTransaction {
31    #[instrument(level = "trace", skip(metric_callback))]
32    pub(crate) async fn begin(
33        conn: Arc<Mutex<InnerConnection>>,
34        backend: DbBackend,
35        metric_callback: Option<crate::metric::Callback>,
36        isolation_level: Option<IsolationLevel>,
37        access_mode: Option<AccessMode>,
38    ) -> Result<DatabaseTransaction, DbErr> {
39        let res = DatabaseTransaction {
40            conn,
41            backend,
42            open: true,
43            metric_callback,
44        };
45        match *res.conn.lock().await {
46            #[cfg(feature = "sqlx-mysql")]
47            InnerConnection::MySql(ref mut c) => {
48                // in MySQL SET TRANSACTION operations must be executed before transaction start
49                crate::driver::sqlx_mysql::set_transaction_config(c, isolation_level, access_mode)
50                    .await?;
51                <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c, None)
52                    .await
53                    .map_err(sqlx_error_to_query_err)
54            }
55            #[cfg(feature = "sqlx-postgres")]
56            InnerConnection::Postgres(ref mut c) => {
57                <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c, None)
58                    .await
59                    .map_err(sqlx_error_to_query_err)?;
60                // in PostgreSQL SET TRANSACTION operations must be executed inside transaction
61                crate::driver::sqlx_postgres::set_transaction_config(
62                    c,
63                    isolation_level,
64                    access_mode,
65                )
66                .await
67            }
68            #[cfg(feature = "sqlx-sqlite")]
69            InnerConnection::Sqlite(ref mut c) => {
70                // in SQLite isolation level and access mode are global settings
71                crate::driver::sqlx_sqlite::set_transaction_config(c, isolation_level, access_mode)
72                    .await?;
73                <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c, None)
74                    .await
75                    .map_err(sqlx_error_to_query_err)
76            }
77            #[cfg(feature = "mock")]
78            InnerConnection::Mock(ref mut c) => {
79                c.begin();
80                Ok(())
81            }
82            #[cfg(feature = "proxy")]
83            InnerConnection::Proxy(ref mut c) => {
84                c.begin().await;
85                Ok(())
86            }
87            #[allow(unreachable_patterns)]
88            _ => Err(conn_err("Disconnected")),
89        }?;
90        Ok(res)
91    }
92
93    /// Runs a transaction to completion returning an rolling back the transaction on
94    /// encountering an error if it fails
95    #[instrument(level = "trace", skip(callback))]
96    pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
97    where
98        F: for<'b> FnOnce(
99                &'b DatabaseTransaction,
100            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
101            + Send,
102        T: Send,
103        E: std::fmt::Display + std::fmt::Debug + Send,
104    {
105        let res = callback(&self).await.map_err(TransactionError::Transaction);
106        if res.is_ok() {
107            self.commit().await.map_err(TransactionError::Connection)?;
108        } else {
109            self.rollback()
110                .await
111                .map_err(TransactionError::Connection)?;
112        }
113        res
114    }
115
116    /// Commit a transaction atomically
117    #[instrument(level = "trace")]
118    #[allow(unreachable_code, unused_mut)]
119    pub async fn commit(mut self) -> Result<(), DbErr> {
120        match *self.conn.lock().await {
121            #[cfg(feature = "sqlx-mysql")]
122            InnerConnection::MySql(ref mut c) => {
123                <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
124                    .await
125                    .map_err(sqlx_error_to_query_err)
126            }
127            #[cfg(feature = "sqlx-postgres")]
128            InnerConnection::Postgres(ref mut c) => {
129                <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
130                    .await
131                    .map_err(sqlx_error_to_query_err)
132            }
133            #[cfg(feature = "sqlx-sqlite")]
134            InnerConnection::Sqlite(ref mut c) => {
135                <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
136                    .await
137                    .map_err(sqlx_error_to_query_err)
138            }
139            #[cfg(feature = "mock")]
140            InnerConnection::Mock(ref mut c) => {
141                c.commit();
142                Ok(())
143            }
144            #[cfg(feature = "proxy")]
145            InnerConnection::Proxy(ref mut c) => {
146                c.commit().await;
147                Ok(())
148            }
149            #[allow(unreachable_patterns)]
150            _ => Err(conn_err("Disconnected")),
151        }?;
152        self.open = false;
153        Ok(())
154    }
155
156    /// rolls back a transaction in case error are encountered during the operation
157    #[instrument(level = "trace")]
158    #[allow(unreachable_code, unused_mut)]
159    pub async fn rollback(mut self) -> Result<(), DbErr> {
160        match *self.conn.lock().await {
161            #[cfg(feature = "sqlx-mysql")]
162            InnerConnection::MySql(ref mut c) => {
163                <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
164                    .await
165                    .map_err(sqlx_error_to_query_err)
166            }
167            #[cfg(feature = "sqlx-postgres")]
168            InnerConnection::Postgres(ref mut c) => {
169                <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
170                    .await
171                    .map_err(sqlx_error_to_query_err)
172            }
173            #[cfg(feature = "sqlx-sqlite")]
174            InnerConnection::Sqlite(ref mut c) => {
175                <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
176                    .await
177                    .map_err(sqlx_error_to_query_err)
178            }
179            #[cfg(feature = "mock")]
180            InnerConnection::Mock(ref mut c) => {
181                c.rollback();
182                Ok(())
183            }
184            #[cfg(feature = "proxy")]
185            InnerConnection::Proxy(ref mut c) => {
186                c.rollback().await;
187                Ok(())
188            }
189            #[allow(unreachable_patterns)]
190            _ => Err(conn_err("Disconnected")),
191        }?;
192        self.open = false;
193        Ok(())
194    }
195
196    // the rollback is queued and will be performed on next async operation, like returning the connection to the pool
197    #[instrument(level = "trace")]
198    fn start_rollback(&mut self) -> Result<(), DbErr> {
199        if self.open {
200            if let Some(mut conn) = self.conn.try_lock() {
201                match &mut *conn {
202                    #[cfg(feature = "sqlx-mysql")]
203                    InnerConnection::MySql(c) => {
204                        <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
205                    }
206                    #[cfg(feature = "sqlx-postgres")]
207                    InnerConnection::Postgres(c) => {
208                        <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
209                    }
210                    #[cfg(feature = "sqlx-sqlite")]
211                    InnerConnection::Sqlite(c) => {
212                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
213                    }
214                    #[cfg(feature = "mock")]
215                    InnerConnection::Mock(c) => {
216                        c.rollback();
217                    }
218                    #[cfg(feature = "proxy")]
219                    InnerConnection::Proxy(c) => {
220                        c.start_rollback();
221                    }
222                    #[allow(unreachable_patterns)]
223                    _ => return Err(conn_err("Disconnected")),
224                }
225            } else {
226                //this should never happen
227                return Err(conn_err("Dropping a locked Transaction"));
228            }
229        }
230        Ok(())
231    }
232}
233
234impl Drop for DatabaseTransaction {
235    fn drop(&mut self) {
236        self.start_rollback().expect("Fail to rollback transaction");
237    }
238}
239
240#[async_trait::async_trait]
241impl ConnectionTrait for DatabaseTransaction {
242    fn get_database_backend(&self) -> DbBackend {
243        // this way we don't need to lock
244        self.backend
245    }
246
247    #[instrument(level = "trace")]
248    #[allow(unused_variables)]
249    async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
250        debug_print!("{}", stmt);
251
252        match &mut *self.conn.lock().await {
253            #[cfg(feature = "sqlx-mysql")]
254            InnerConnection::MySql(conn) => {
255                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
256                let conn: &mut sqlx::MySqlConnection = &mut *conn;
257                crate::metric::metric!(self.metric_callback, &stmt, {
258                    query.execute(conn).await.map(Into::into)
259                })
260                .map_err(sqlx_error_to_exec_err)
261            }
262            #[cfg(feature = "sqlx-postgres")]
263            InnerConnection::Postgres(conn) => {
264                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
265                let conn: &mut sqlx::PgConnection = &mut *conn;
266                crate::metric::metric!(self.metric_callback, &stmt, {
267                    query.execute(conn).await.map(Into::into)
268                })
269                .map_err(sqlx_error_to_exec_err)
270            }
271            #[cfg(feature = "sqlx-sqlite")]
272            InnerConnection::Sqlite(conn) => {
273                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
274                let conn: &mut sqlx::SqliteConnection = &mut *conn;
275                crate::metric::metric!(self.metric_callback, &stmt, {
276                    query.execute(conn).await.map(Into::into)
277                })
278                .map_err(sqlx_error_to_exec_err)
279            }
280            #[cfg(feature = "mock")]
281            InnerConnection::Mock(conn) => return conn.execute(stmt),
282            #[cfg(feature = "proxy")]
283            InnerConnection::Proxy(conn) => return conn.execute(stmt).await,
284            #[allow(unreachable_patterns)]
285            _ => Err(conn_err("Disconnected")),
286        }
287    }
288
289    #[instrument(level = "trace")]
290    #[allow(unused_variables)]
291    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
292        debug_print!("{}", sql);
293
294        match &mut *self.conn.lock().await {
295            #[cfg(feature = "sqlx-mysql")]
296            InnerConnection::MySql(conn) => {
297                let conn: &mut sqlx::MySqlConnection = &mut *conn;
298                sqlx::Executor::execute(conn, sql)
299                    .await
300                    .map(Into::into)
301                    .map_err(sqlx_error_to_exec_err)
302            }
303            #[cfg(feature = "sqlx-postgres")]
304            InnerConnection::Postgres(conn) => {
305                let conn: &mut sqlx::PgConnection = &mut *conn;
306                sqlx::Executor::execute(conn, sql)
307                    .await
308                    .map(Into::into)
309                    .map_err(sqlx_error_to_exec_err)
310            }
311            #[cfg(feature = "sqlx-sqlite")]
312            InnerConnection::Sqlite(conn) => {
313                let conn: &mut sqlx::SqliteConnection = &mut *conn;
314                sqlx::Executor::execute(conn, sql)
315                    .await
316                    .map(Into::into)
317                    .map_err(sqlx_error_to_exec_err)
318            }
319            #[cfg(feature = "mock")]
320            InnerConnection::Mock(conn) => {
321                let db_backend = conn.get_database_backend();
322                let stmt = Statement::from_string(db_backend, sql);
323                conn.execute(stmt)
324            }
325            #[cfg(feature = "proxy")]
326            InnerConnection::Proxy(conn) => {
327                let db_backend = conn.get_database_backend();
328                let stmt = Statement::from_string(db_backend, sql);
329                conn.execute(stmt).await
330            }
331            #[allow(unreachable_patterns)]
332            _ => Err(conn_err("Disconnected")),
333        }
334    }
335
336    #[instrument(level = "trace")]
337    #[allow(unused_variables)]
338    async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
339        debug_print!("{}", stmt);
340
341        match &mut *self.conn.lock().await {
342            #[cfg(feature = "sqlx-mysql")]
343            InnerConnection::MySql(conn) => {
344                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
345                let conn: &mut sqlx::MySqlConnection = &mut *conn;
346                crate::metric::metric!(self.metric_callback, &stmt, {
347                    crate::sqlx_map_err_ignore_not_found(
348                        query.fetch_one(conn).await.map(|row| Some(row.into())),
349                    )
350                })
351            }
352            #[cfg(feature = "sqlx-postgres")]
353            InnerConnection::Postgres(conn) => {
354                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
355                let conn: &mut sqlx::PgConnection = &mut *conn;
356                crate::metric::metric!(self.metric_callback, &stmt, {
357                    crate::sqlx_map_err_ignore_not_found(
358                        query.fetch_one(conn).await.map(|row| Some(row.into())),
359                    )
360                })
361            }
362            #[cfg(feature = "sqlx-sqlite")]
363            InnerConnection::Sqlite(conn) => {
364                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
365                let conn: &mut sqlx::SqliteConnection = &mut *conn;
366                crate::metric::metric!(self.metric_callback, &stmt, {
367                    crate::sqlx_map_err_ignore_not_found(
368                        query.fetch_one(conn).await.map(|row| Some(row.into())),
369                    )
370                })
371            }
372            #[cfg(feature = "mock")]
373            InnerConnection::Mock(conn) => return conn.query_one(stmt),
374            #[cfg(feature = "proxy")]
375            InnerConnection::Proxy(conn) => return conn.query_one(stmt).await,
376            #[allow(unreachable_patterns)]
377            _ => Err(conn_err("Disconnected")),
378        }
379    }
380
381    #[instrument(level = "trace")]
382    #[allow(unused_variables)]
383    async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
384        debug_print!("{}", stmt);
385
386        match &mut *self.conn.lock().await {
387            #[cfg(feature = "sqlx-mysql")]
388            InnerConnection::MySql(conn) => {
389                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
390                let conn: &mut sqlx::MySqlConnection = &mut *conn;
391                crate::metric::metric!(self.metric_callback, &stmt, {
392                    query
393                        .fetch_all(conn)
394                        .await
395                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
396                        .map_err(sqlx_error_to_query_err)
397                })
398            }
399            #[cfg(feature = "sqlx-postgres")]
400            InnerConnection::Postgres(conn) => {
401                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
402                let conn: &mut sqlx::PgConnection = &mut *conn;
403                crate::metric::metric!(self.metric_callback, &stmt, {
404                    query
405                        .fetch_all(conn)
406                        .await
407                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
408                        .map_err(sqlx_error_to_query_err)
409                })
410            }
411            #[cfg(feature = "sqlx-sqlite")]
412            InnerConnection::Sqlite(conn) => {
413                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
414                let conn: &mut sqlx::SqliteConnection = &mut *conn;
415                crate::metric::metric!(self.metric_callback, &stmt, {
416                    query
417                        .fetch_all(conn)
418                        .await
419                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
420                        .map_err(sqlx_error_to_query_err)
421                })
422            }
423            #[cfg(feature = "mock")]
424            InnerConnection::Mock(conn) => return conn.query_all(stmt),
425            #[cfg(feature = "proxy")]
426            InnerConnection::Proxy(conn) => return conn.query_all(stmt).await,
427            #[allow(unreachable_patterns)]
428            _ => Err(conn_err("Disconnected")),
429        }
430    }
431}
432
433impl StreamTrait for DatabaseTransaction {
434    type Stream<'a> = TransactionStream<'a>;
435
436    #[instrument(level = "trace")]
437    fn stream<'a>(
438        &'a self,
439        stmt: Statement,
440    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
441        Box::pin(async move {
442            let conn = self.conn.lock().await;
443            Ok(crate::TransactionStream::build(
444                conn,
445                stmt,
446                self.metric_callback.clone(),
447            ))
448        })
449    }
450}
451
452#[async_trait::async_trait]
453impl TransactionTrait for DatabaseTransaction {
454    #[instrument(level = "trace")]
455    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
456        DatabaseTransaction::begin(
457            Arc::clone(&self.conn),
458            self.backend,
459            self.metric_callback.clone(),
460            None,
461            None,
462        )
463        .await
464    }
465
466    #[instrument(level = "trace")]
467    async fn begin_with_config(
468        &self,
469        isolation_level: Option<IsolationLevel>,
470        access_mode: Option<AccessMode>,
471    ) -> Result<DatabaseTransaction, DbErr> {
472        DatabaseTransaction::begin(
473            Arc::clone(&self.conn),
474            self.backend,
475            self.metric_callback.clone(),
476            isolation_level,
477            access_mode,
478        )
479        .await
480    }
481
482    /// Execute the function inside a transaction.
483    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
484    #[instrument(level = "trace", skip(_callback))]
485    async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
486    where
487        F: for<'c> FnOnce(
488                &'c DatabaseTransaction,
489            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
490            + Send,
491        T: Send,
492        E: std::fmt::Display + std::fmt::Debug + Send,
493    {
494        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
495        transaction.run(_callback).await
496    }
497
498    /// Execute the function inside a transaction with isolation level and/or access mode.
499    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
500    #[instrument(level = "trace", skip(_callback))]
501    async fn transaction_with_config<F, T, E>(
502        &self,
503        _callback: F,
504        isolation_level: Option<IsolationLevel>,
505        access_mode: Option<AccessMode>,
506    ) -> Result<T, TransactionError<E>>
507    where
508        F: for<'c> FnOnce(
509                &'c DatabaseTransaction,
510            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
511            + Send,
512        T: Send,
513        E: std::fmt::Display + std::fmt::Debug + Send,
514    {
515        let transaction = self
516            .begin_with_config(isolation_level, access_mode)
517            .await
518            .map_err(TransactionError::Connection)?;
519        transaction.run(_callback).await
520    }
521}
522
523/// Defines errors for handling transaction failures
524#[derive(Debug)]
525pub enum TransactionError<E> {
526    /// A Database connection error
527    Connection(DbErr),
528    /// An error occurring when doing database transactions
529    Transaction(E),
530}
531
532impl<E> std::fmt::Display for TransactionError<E>
533where
534    E: std::fmt::Display + std::fmt::Debug,
535{
536    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
537        match self {
538            TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
539            TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
540        }
541    }
542}
543
544impl<E> std::error::Error for TransactionError<E> where E: std::fmt::Display + std::fmt::Debug {}
545
546impl<E> From<DbErr> for TransactionError<E>
547where
548    E: std::fmt::Display + std::fmt::Debug,
549{
550    fn from(e: DbErr) -> Self {
551        Self::Connection(e)
552    }
553}