sea_orm/database/
transaction.rs

1#![allow(unused_assignments)]
2use crate::{
3    AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, IsolationLevel,
4    QueryResult, Statement, StreamTrait, TransactionSession, TransactionStream, TransactionTrait,
5    debug_print, error::*,
6};
7#[cfg(feature = "sqlx-dep")]
8use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
9#[cfg(feature = "sqlx-dep")]
10use sqlx::TransactionManager;
11use std::sync::Arc;
12use std::sync::Mutex;
13use tracing::instrument;
14
/// Defines a database transaction, whether it is an open transaction and the type of
/// backend to use.
/// Under the hood, a Transaction is just a wrapper for a connection where
/// START TRANSACTION has been executed.
///
/// A transaction that is dropped without a successful `commit` or `rollback` requests a
/// rollback on its connection from the `Drop` impl.
pub struct DatabaseTransaction {
    /// Connection handle, locked for every operation; transactions begun from this one
    /// through `TransactionTrait` receive a clone of the same `Arc`.
    conn: Arc<Mutex<InnerConnection>>,
    /// Backend recorded at construction so `get_database_backend` needs no lock.
    backend: DbBackend,
    /// `true` until `commit` or `rollback` succeeds; `start_rollback` (run on drop) only
    /// acts while this is still set.
    open: bool,
    /// Optional metrics callback handed to statement execution (via `metric!` for sqlx,
    /// by reference for rusqlite) and cloned into nested transactions and streams.
    metric_callback: Option<crate::metric::Callback>,
}
25
26impl std::fmt::Debug for DatabaseTransaction {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        write!(f, "DatabaseTransaction")
29    }
30}
31
32impl DatabaseTransaction {
33    #[instrument(level = "trace", skip(metric_callback))]
34    pub(crate) fn begin(
35        conn: Arc<Mutex<InnerConnection>>,
36        backend: DbBackend,
37        metric_callback: Option<crate::metric::Callback>,
38        isolation_level: Option<IsolationLevel>,
39        access_mode: Option<AccessMode>,
40    ) -> Result<DatabaseTransaction, DbErr> {
41        let res = DatabaseTransaction {
42            conn,
43            backend,
44            open: true,
45            metric_callback,
46        };
47        {
48            #[cfg(not(feature = "sync"))]
49            let conn = &mut *res.conn.lock();
50            #[cfg(feature = "sync")]
51            let conn = &mut *res.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
52
53            match conn {
54                #[cfg(feature = "sqlx-mysql")]
55                InnerConnection::MySql(c) => {
56                    // in MySQL SET TRANSACTION operations must be executed before transaction start
57                    crate::driver::sqlx_mysql::set_transaction_config(
58                        c,
59                        isolation_level,
60                        access_mode,
61                    )?;
62                    <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c, None)
63                        .map_err(sqlx_error_to_query_err)
64                }
65                #[cfg(feature = "sqlx-postgres")]
66                InnerConnection::Postgres(c) => {
67                    <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c, None)
68                        .map_err(sqlx_error_to_query_err)?;
69                    // in PostgreSQL SET TRANSACTION operations must be executed inside transaction
70                    crate::driver::sqlx_postgres::set_transaction_config(
71                        c,
72                        isolation_level,
73                        access_mode,
74                    )
75                }
76                #[cfg(feature = "sqlx-sqlite")]
77                InnerConnection::Sqlite(c) => {
78                    // in SQLite isolation level and access mode are global settings
79                    crate::driver::sqlx_sqlite::set_transaction_config(
80                        c,
81                        isolation_level,
82                        access_mode,
83                    )?;
84                    <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c, None)
85                        .map_err(sqlx_error_to_query_err)
86                }
87                #[cfg(feature = "rusqlite")]
88                InnerConnection::Rusqlite(c) => c.begin(),
89                #[cfg(feature = "mock")]
90                InnerConnection::Mock(c) => {
91                    c.begin();
92                    Ok(())
93                }
94                #[cfg(feature = "proxy")]
95                InnerConnection::Proxy(c) => {
96                    c.begin();
97                    Ok(())
98                }
99                #[allow(unreachable_patterns)]
100                _ => Err(conn_err("Disconnected")),
101            }?
102        };
103
104        Ok(res)
105    }
106
107    /// Runs a transaction to completion passing through the result.
108    /// Rolling back the transaction on encountering an error.
109    #[instrument(level = "trace", skip(callback))]
110    pub(crate) fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
111    where
112        F: for<'b> FnOnce(&'b DatabaseTransaction) -> Result<T, E>,
113        E: std::fmt::Display + std::fmt::Debug,
114    {
115        let res = callback(&self).map_err(TransactionError::Transaction);
116        if res.is_ok() {
117            self.commit().map_err(TransactionError::Connection)?;
118        } else {
119            self.rollback().map_err(TransactionError::Connection)?;
120        }
121        res
122    }
123
124    /// Commit a transaction
125    #[instrument(level = "trace")]
126    #[allow(unreachable_code, unused_mut)]
127    pub fn commit(mut self) -> Result<(), DbErr> {
128        #[cfg(not(feature = "sync"))]
129        let conn = &mut *self.conn.lock();
130        #[cfg(feature = "sync")]
131        let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
132
133        match conn {
134            #[cfg(feature = "sqlx-mysql")]
135            InnerConnection::MySql(c) => {
136                <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
137                    .map_err(sqlx_error_to_query_err)
138            }
139            #[cfg(feature = "sqlx-postgres")]
140            InnerConnection::Postgres(c) => {
141                <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
142                    .map_err(sqlx_error_to_query_err)
143            }
144            #[cfg(feature = "sqlx-sqlite")]
145            InnerConnection::Sqlite(c) => {
146                <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
147                    .map_err(sqlx_error_to_query_err)
148            }
149            #[cfg(feature = "rusqlite")]
150            InnerConnection::Rusqlite(c) => c.commit(),
151            #[cfg(feature = "mock")]
152            InnerConnection::Mock(c) => {
153                c.commit();
154                Ok(())
155            }
156            #[cfg(feature = "proxy")]
157            InnerConnection::Proxy(c) => {
158                c.commit();
159                Ok(())
160            }
161            #[allow(unreachable_patterns)]
162            _ => Err(conn_err("Disconnected")),
163        }?;
164        self.open = false; // read by start_rollback
165        Ok(())
166    }
167
168    /// Rolls back a transaction explicitly
169    #[instrument(level = "trace")]
170    #[allow(unreachable_code, unused_mut)]
171    pub fn rollback(mut self) -> Result<(), DbErr> {
172        #[cfg(not(feature = "sync"))]
173        let conn = &mut *self.conn.lock();
174        #[cfg(feature = "sync")]
175        let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
176
177        match conn {
178            #[cfg(feature = "sqlx-mysql")]
179            InnerConnection::MySql(c) => {
180                <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
181                    .map_err(sqlx_error_to_query_err)
182            }
183            #[cfg(feature = "sqlx-postgres")]
184            InnerConnection::Postgres(c) => {
185                <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
186                    .map_err(sqlx_error_to_query_err)
187            }
188            #[cfg(feature = "sqlx-sqlite")]
189            InnerConnection::Sqlite(c) => {
190                <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
191                    .map_err(sqlx_error_to_query_err)
192            }
193            #[cfg(feature = "rusqlite")]
194            InnerConnection::Rusqlite(c) => c.rollback(),
195            #[cfg(feature = "mock")]
196            InnerConnection::Mock(c) => {
197                c.rollback();
198                Ok(())
199            }
200            #[cfg(feature = "proxy")]
201            InnerConnection::Proxy(c) => {
202                c.rollback();
203                Ok(())
204            }
205            #[allow(unreachable_patterns)]
206            _ => Err(conn_err("Disconnected")),
207        }?;
208        self.open = false; // read by start_rollback
209        Ok(())
210    }
211
212    // the rollback is queued and will be performed on next operation, like returning the connection to the pool
213    #[instrument(level = "trace")]
214    fn start_rollback(&mut self) -> Result<(), DbErr> {
215        if self.open {
216            if let Some(mut conn) = self.conn.try_lock().ok() {
217                match &mut *conn {
218                    #[cfg(feature = "sqlx-mysql")]
219                    InnerConnection::MySql(c) => {
220                        <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
221                    }
222                    #[cfg(feature = "sqlx-postgres")]
223                    InnerConnection::Postgres(c) => {
224                        <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
225                    }
226                    #[cfg(feature = "sqlx-sqlite")]
227                    InnerConnection::Sqlite(c) => {
228                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
229                    }
230                    #[cfg(feature = "rusqlite")]
231                    InnerConnection::Rusqlite(c) => {
232                        c.start_rollback()?;
233                    }
234                    #[cfg(feature = "mock")]
235                    InnerConnection::Mock(c) => {
236                        c.rollback();
237                    }
238                    #[cfg(feature = "proxy")]
239                    InnerConnection::Proxy(c) => {
240                        c.start_rollback();
241                    }
242                    #[allow(unreachable_patterns)]
243                    _ => return Err(conn_err("Disconnected")),
244                }
245            } else {
246                //this should never happen
247                return Err(conn_err("Dropping a locked Transaction"));
248            }
249        }
250        Ok(())
251    }
252}
253
254impl TransactionSession for DatabaseTransaction {
255    fn commit(self) -> Result<(), DbErr> {
256        self.commit()
257    }
258
259    fn rollback(self) -> Result<(), DbErr> {
260        self.rollback()
261    }
262}
263
264impl Drop for DatabaseTransaction {
265    fn drop(&mut self) {
266        self.start_rollback().expect("Fail to rollback transaction");
267    }
268}
269
270impl ConnectionTrait for DatabaseTransaction {
271    fn get_database_backend(&self) -> DbBackend {
272        // this way we don't need to lock just to know the backend
273        self.backend
274    }
275
276    #[instrument(level = "trace")]
277    #[allow(unused_variables)]
278    fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
279        debug_print!("{}", stmt);
280
281        #[cfg(not(feature = "sync"))]
282        let conn = &mut *self.conn.lock();
283        #[cfg(feature = "sync")]
284        let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
285
286        match conn {
287            #[cfg(feature = "sqlx-mysql")]
288            InnerConnection::MySql(conn) => {
289                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
290                let conn: &mut sqlx::MySqlConnection = &mut *conn;
291                crate::metric::metric!(self.metric_callback, &stmt, {
292                    query.execute(conn).map(Into::into)
293                })
294                .map_err(sqlx_error_to_exec_err)
295            }
296            #[cfg(feature = "sqlx-postgres")]
297            InnerConnection::Postgres(conn) => {
298                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
299                let conn: &mut sqlx::PgConnection = &mut *conn;
300                crate::metric::metric!(self.metric_callback, &stmt, {
301                    query.execute(conn).map(Into::into)
302                })
303                .map_err(sqlx_error_to_exec_err)
304            }
305            #[cfg(feature = "sqlx-sqlite")]
306            InnerConnection::Sqlite(conn) => {
307                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
308                let conn: &mut sqlx::SqliteConnection = &mut *conn;
309                crate::metric::metric!(self.metric_callback, &stmt, {
310                    query.execute(conn).map(Into::into)
311                })
312                .map_err(sqlx_error_to_exec_err)
313            }
314            #[cfg(feature = "rusqlite")]
315            InnerConnection::Rusqlite(conn) => conn.execute(stmt, &self.metric_callback),
316            #[cfg(feature = "mock")]
317            InnerConnection::Mock(conn) => return conn.execute(stmt),
318            #[cfg(feature = "proxy")]
319            InnerConnection::Proxy(conn) => return conn.execute(stmt),
320            #[allow(unreachable_patterns)]
321            _ => Err(conn_err("Disconnected")),
322        }
323    }
324
325    #[instrument(level = "trace")]
326    #[allow(unused_variables)]
327    fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
328        debug_print!("{}", sql);
329
330        #[cfg(not(feature = "sync"))]
331        let conn = &mut *self.conn.lock();
332        #[cfg(feature = "sync")]
333        let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
334
335        match conn {
336            #[cfg(feature = "sqlx-mysql")]
337            InnerConnection::MySql(conn) => {
338                let conn: &mut sqlx::MySqlConnection = &mut *conn;
339                sqlx::Executor::execute(conn, sql)
340                    .map(Into::into)
341                    .map_err(sqlx_error_to_exec_err)
342            }
343            #[cfg(feature = "sqlx-postgres")]
344            InnerConnection::Postgres(conn) => {
345                let conn: &mut sqlx::PgConnection = &mut *conn;
346                sqlx::Executor::execute(conn, sql)
347                    .map(Into::into)
348                    .map_err(sqlx_error_to_exec_err)
349            }
350            #[cfg(feature = "sqlx-sqlite")]
351            InnerConnection::Sqlite(conn) => {
352                let conn: &mut sqlx::SqliteConnection = &mut *conn;
353                sqlx::Executor::execute(conn, sql)
354                    .map(Into::into)
355                    .map_err(sqlx_error_to_exec_err)
356            }
357            #[cfg(feature = "rusqlite")]
358            InnerConnection::Rusqlite(conn) => conn.execute_unprepared(sql),
359            #[cfg(feature = "mock")]
360            InnerConnection::Mock(conn) => {
361                let db_backend = conn.get_database_backend();
362                let stmt = Statement::from_string(db_backend, sql);
363                conn.execute(stmt)
364            }
365            #[cfg(feature = "proxy")]
366            InnerConnection::Proxy(conn) => {
367                let db_backend = conn.get_database_backend();
368                let stmt = Statement::from_string(db_backend, sql);
369                conn.execute(stmt)
370            }
371            #[allow(unreachable_patterns)]
372            _ => Err(conn_err("Disconnected")),
373        }
374    }
375
376    #[instrument(level = "trace")]
377    #[allow(unused_variables)]
378    fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
379        debug_print!("{}", stmt);
380
381        #[cfg(not(feature = "sync"))]
382        let conn = &mut *self.conn.lock();
383        #[cfg(feature = "sync")]
384        let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
385
386        match conn {
387            #[cfg(feature = "sqlx-mysql")]
388            InnerConnection::MySql(conn) => {
389                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
390                let conn: &mut sqlx::MySqlConnection = &mut *conn;
391                crate::metric::metric!(self.metric_callback, &stmt, {
392                    crate::sqlx_map_err_ignore_not_found(
393                        query.fetch_one(conn).map(|row| Some(row.into())),
394                    )
395                })
396            }
397            #[cfg(feature = "sqlx-postgres")]
398            InnerConnection::Postgres(conn) => {
399                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
400                let conn: &mut sqlx::PgConnection = &mut *conn;
401                crate::metric::metric!(self.metric_callback, &stmt, {
402                    crate::sqlx_map_err_ignore_not_found(
403                        query.fetch_one(conn).map(|row| Some(row.into())),
404                    )
405                })
406            }
407            #[cfg(feature = "sqlx-sqlite")]
408            InnerConnection::Sqlite(conn) => {
409                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
410                let conn: &mut sqlx::SqliteConnection = &mut *conn;
411                crate::metric::metric!(self.metric_callback, &stmt, {
412                    crate::sqlx_map_err_ignore_not_found(
413                        query.fetch_one(conn).map(|row| Some(row.into())),
414                    )
415                })
416            }
417            #[cfg(feature = "rusqlite")]
418            InnerConnection::Rusqlite(conn) => conn.query_one(stmt, &self.metric_callback),
419            #[cfg(feature = "mock")]
420            InnerConnection::Mock(conn) => return conn.query_one(stmt),
421            #[cfg(feature = "proxy")]
422            InnerConnection::Proxy(conn) => return conn.query_one(stmt),
423            #[allow(unreachable_patterns)]
424            _ => Err(conn_err("Disconnected")),
425        }
426    }
427
428    #[instrument(level = "trace")]
429    #[allow(unused_variables)]
430    fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
431        debug_print!("{}", stmt);
432
433        #[cfg(not(feature = "sync"))]
434        let conn = &mut *self.conn.lock();
435        #[cfg(feature = "sync")]
436        let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
437
438        match conn {
439            #[cfg(feature = "sqlx-mysql")]
440            InnerConnection::MySql(conn) => {
441                let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
442                let conn: &mut sqlx::MySqlConnection = &mut *conn;
443                crate::metric::metric!(self.metric_callback, &stmt, {
444                    query
445                        .fetch_all(conn)
446                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
447                        .map_err(sqlx_error_to_query_err)
448                })
449            }
450            #[cfg(feature = "sqlx-postgres")]
451            InnerConnection::Postgres(conn) => {
452                let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
453                let conn: &mut sqlx::PgConnection = &mut *conn;
454                crate::metric::metric!(self.metric_callback, &stmt, {
455                    query
456                        .fetch_all(conn)
457                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
458                        .map_err(sqlx_error_to_query_err)
459                })
460            }
461            #[cfg(feature = "sqlx-sqlite")]
462            InnerConnection::Sqlite(conn) => {
463                let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
464                let conn: &mut sqlx::SqliteConnection = &mut *conn;
465                crate::metric::metric!(self.metric_callback, &stmt, {
466                    query
467                        .fetch_all(conn)
468                        .map(|rows| rows.into_iter().map(|r| r.into()).collect())
469                        .map_err(sqlx_error_to_query_err)
470                })
471            }
472            #[cfg(feature = "rusqlite")]
473            InnerConnection::Rusqlite(conn) => conn.query_all(stmt, &self.metric_callback),
474            #[cfg(feature = "mock")]
475            InnerConnection::Mock(conn) => return conn.query_all(stmt),
476            #[cfg(feature = "proxy")]
477            InnerConnection::Proxy(conn) => return conn.query_all(stmt),
478            #[allow(unreachable_patterns)]
479            _ => Err(conn_err("Disconnected")),
480        }
481    }
482}
483
484impl StreamTrait for DatabaseTransaction {
485    type Stream<'a> = TransactionStream<'a>;
486
487    fn get_database_backend(&self) -> DbBackend {
488        self.backend
489    }
490
491    #[instrument(level = "trace")]
492    fn stream_raw<'a>(&'a self, stmt: Statement) -> Result<Self::Stream<'a>, DbErr> {
493        ({
494            #[cfg(not(feature = "sync"))]
495            let conn = self.conn.lock();
496            #[cfg(feature = "sync")]
497            let conn = self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
498            Ok(crate::TransactionStream::build(
499                conn,
500                stmt,
501                self.metric_callback.clone(),
502            ))
503        })
504    }
505}
506
507impl TransactionTrait for DatabaseTransaction {
508    type Transaction = DatabaseTransaction;
509
510    #[instrument(level = "trace")]
511    fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
512        DatabaseTransaction::begin(
513            Arc::clone(&self.conn),
514            self.backend,
515            self.metric_callback.clone(),
516            None,
517            None,
518        )
519    }
520
521    #[instrument(level = "trace")]
522    fn begin_with_config(
523        &self,
524        isolation_level: Option<IsolationLevel>,
525        access_mode: Option<AccessMode>,
526    ) -> Result<DatabaseTransaction, DbErr> {
527        DatabaseTransaction::begin(
528            Arc::clone(&self.conn),
529            self.backend,
530            self.metric_callback.clone(),
531            isolation_level,
532            access_mode,
533        )
534    }
535
536    /// Execute the function inside a transaction.
537    /// If the function returns an error, the transaction will be rolled back.
538    /// Otherwise, the transaction will be committed.
539    #[instrument(level = "trace", skip(_callback))]
540    fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
541    where
542        F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
543        E: std::fmt::Display + std::fmt::Debug,
544    {
545        let transaction = self.begin().map_err(TransactionError::Connection)?;
546        transaction.run(_callback)
547    }
548
549    /// Execute the function inside a transaction.
550    /// If the function returns an error, the transaction will be rolled back.
551    /// Otherwise, the transaction will be committed.
552    #[instrument(level = "trace", skip(_callback))]
553    fn transaction_with_config<F, T, E>(
554        &self,
555        _callback: F,
556        isolation_level: Option<IsolationLevel>,
557        access_mode: Option<AccessMode>,
558    ) -> Result<T, TransactionError<E>>
559    where
560        F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
561        E: std::fmt::Display + std::fmt::Debug,
562    {
563        let transaction = self
564            .begin_with_config(isolation_level, access_mode)
565            .map_err(TransactionError::Connection)?;
566        transaction.run(_callback)
567    }
568}
569
/// Defines errors for handling transaction failures
///
/// Returned by the transaction runners (`transaction`, `transaction_with_config`), which
/// keep failures of the database itself apart from the error returned by the caller's
/// closure.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// A Database connection error: beginning, committing or rolling back the
    /// transaction failed
    Connection(DbErr),
    /// An error occurring when doing database transactions: the closure's own error,
    /// returned after the transaction was rolled back
    Transaction(E),
}
578
579impl<E> std::fmt::Display for TransactionError<E>
580where
581    E: std::fmt::Display + std::fmt::Debug,
582{
583    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
584        match self {
585            TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
586            TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
587        }
588    }
589}
590
591impl<E> std::error::Error for TransactionError<E> where E: std::fmt::Display + std::fmt::Debug {}
592
593impl<E> From<DbErr> for TransactionError<E>
594where
595    E: std::fmt::Display + std::fmt::Debug,
596{
597    fn from(e: DbErr) -> Self {
598        Self::Connection(e)
599    }
600}