Skip to main content

sea_orm/database/
db_connection.rs

1use super::transaction::run_async_transaction_callback;
2use crate::{
3    AccessMode, ConnectionTrait, DatabaseTransaction, ExecResult, IsolationLevel, QueryResult,
4    Schema, SchemaBuilder, Statement, StatementBuilder, StreamTrait, TransactionError,
5    TransactionOptions, TransactionTrait, error::*,
6};
7use std::fmt::Debug;
8use tracing::instrument;
9use url::Url;
10
11#[cfg(feature = "sqlx-dep")]
12use sqlx::pool::PoolConnection;
13
14#[cfg(feature = "rusqlite")]
15use crate::driver::rusqlite::{RusqliteInnerConnection, RusqliteSharedConnection};
16
17#[cfg(any(feature = "mock", feature = "proxy"))]
18use std::sync::Arc;
19
20/// Handle a database connection depending on the backend enabled by the feature
21/// flags. This creates a connection pool internally (for SQLx connections),
22/// and so is cheap to clone.
23#[derive(Debug, Clone)]
24#[non_exhaustive]
25pub struct DatabaseConnection {
26    /// `DatabaseConnection` used to be a enum. Now it's moved into inner,
27    /// because we have to attach other contexts.
28    pub inner: DatabaseConnectionType,
29    #[cfg(feature = "rbac")]
30    pub(crate) rbac: crate::RbacEngineMount,
31}
32
/// The underlying database connection type.
///
/// Which variants exist depends on the enabled feature flags; `Disconnected`
/// is always available. The `Debug` implementation prints only the variant
/// name, never the wrapped connection.
#[derive(Clone)]
pub enum DatabaseConnectionType {
    /// MySQL database connection pool
    #[cfg(feature = "sqlx-mysql")]
    SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),

    /// PostgreSQL database connection pool
    #[cfg(feature = "sqlx-postgres")]
    SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),

    /// SQLite database connection pool
    #[cfg(feature = "sqlx-sqlite")]
    SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection),

    /// SQLite database connection shareable across threads
    #[cfg(feature = "rusqlite")]
    RusqliteSharedConnection(RusqliteSharedConnection),

    /// Mock database connection useful for testing
    #[cfg(feature = "mock")]
    MockDatabaseConnection(Arc<crate::MockDatabaseConnection>),

    /// Proxy database connection
    #[cfg(feature = "proxy")]
    ProxyDatabaseConnection(Arc<crate::ProxyDatabaseConnection>),

    /// The connection has never been established. This is also the
    /// [`Default`] state of [`DatabaseConnection`].
    Disconnected,
}
63
64/// The same as a [DatabaseConnection]
65pub type DbConn = DatabaseConnection;
66
67impl Default for DatabaseConnection {
68    fn default() -> Self {
69        DatabaseConnectionType::Disconnected.into()
70    }
71}
72
73impl From<DatabaseConnectionType> for DatabaseConnection {
74    fn from(inner: DatabaseConnectionType) -> Self {
75        Self {
76            inner,
77            #[cfg(feature = "rbac")]
78            rbac: Default::default(),
79        }
80    }
81}
82
/// Identifies which real-world database engine a connection talks to.
///
/// The engines that can actually be connected to are selected by feature flags,
/// as described in the crate documentation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum DatabaseBackend {
    /// MySQL
    MySql,
    /// PostgreSQL
    Postgres,
    /// SQLite
    Sqlite,
}

/// Short alias for [DatabaseBackend].
pub type DbBackend = DatabaseBackend;
98
/// Backend-specific handle to a single connection, as opposed to the pooled or
/// shared handles held by [`DatabaseConnectionType`].
///
/// SQLx variants wrap a `PoolConnection` taken from a pool. Mock and proxy
/// variants share the connection through an `Arc`. NOTE(review): presumably
/// used by transactions and streams that need one dedicated connection; the
/// users are not visible in this file, so confirm.
#[derive(Debug)]
pub(crate) enum InnerConnection {
    #[cfg(feature = "sqlx-mysql")]
    MySql(PoolConnection<sqlx::MySql>),
    #[cfg(feature = "sqlx-postgres")]
    Postgres(PoolConnection<sqlx::Postgres>),
    #[cfg(feature = "sqlx-sqlite")]
    Sqlite(PoolConnection<sqlx::Sqlite>),
    #[cfg(feature = "rusqlite")]
    Rusqlite(RusqliteInnerConnection),
    #[cfg(feature = "mock")]
    Mock(Arc<crate::MockDatabaseConnection>),
    #[cfg(feature = "proxy")]
    Proxy(Arc<crate::ProxyDatabaseConnection>),
}
114
115impl Debug for DatabaseConnectionType {
116    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
117        write!(
118            f,
119            "{}",
120            match self {
121                #[cfg(feature = "sqlx-mysql")]
122                Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
123                #[cfg(feature = "sqlx-postgres")]
124                Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
125                #[cfg(feature = "sqlx-sqlite")]
126                Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection",
127                #[cfg(feature = "rusqlite")]
128                Self::RusqliteSharedConnection(_) => "RusqliteSharedConnection",
129                #[cfg(feature = "mock")]
130                Self::MockDatabaseConnection(_) => "MockDatabaseConnection",
131                #[cfg(feature = "proxy")]
132                Self::ProxyDatabaseConnection(_) => "ProxyDatabaseConnection",
133                Self::Disconnected => "Disconnected",
134            }
135        )
136    }
137}
138
/// Routes each [`ConnectionTrait`] operation to the driver-specific connection
/// stored in `self.inner`. Each call is wrapped in `with_db_span!`, presumably
/// a tracing span named after the operation (e.g. `sea_orm.execute`).
impl ConnectionTrait for DatabaseConnection {
    /// Delegates to the inherent `DatabaseConnection::get_database_backend`.
    /// Inherent methods win method resolution, so this is not a recursive call.
    ///
    /// # Panics
    ///
    /// Panics if the connection is `Disconnected`.
    fn get_database_backend(&self) -> DbBackend {
        self.get_database_backend()
    }

    /// Executes a prepared statement on the active backend.
    ///
    /// The `Disconnected` arm returns a "Disconnected" connection error.
    /// NOTE(review): `self.get_database_backend()` is passed to the span macro
    /// and panics on `Disconnected`. If the macro evaluates that argument
    /// unconditionally, the error arm is never reached; confirm.
    #[instrument(level = "trace")]
    // Presumably silences warnings for feature combinations where the driver
    // arms are compiled out.
    #[allow(unused_variables)]
    fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
        super::tracing_spans::with_db_span!(
            "sea_orm.execute",
            self.get_database_backend(),
            stmt.sql.as_str(),
            record_stmt = self.get_record_stmt_in_spans(),
            {
                match &self.inner {
                    #[cfg(feature = "sqlx-mysql")]
                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.execute(stmt),
                    #[cfg(feature = "sqlx-postgres")]
                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.execute(stmt),
                    #[cfg(feature = "sqlx-sqlite")]
                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.execute(stmt),
                    #[cfg(feature = "rusqlite")]
                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.execute(stmt),
                    #[cfg(feature = "mock")]
                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.execute(stmt),
                    #[cfg(feature = "proxy")]
                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.execute(stmt),
                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
                }
            }
        )
    }

    /// Executes raw SQL without preparing it.
    ///
    /// SQL drivers receive `sql` directly. Mock and proxy connections get it
    /// wrapped in a `Statement` for their own backend, then go through `execute`.
    /// NOTE(review): statement recording is hard-coded to `false` here, unlike
    /// the other methods; confirm this is intended.
    #[instrument(level = "trace")]
    #[allow(unused_variables)]
    fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
        super::tracing_spans::with_db_span!(
            "sea_orm.execute_unprepared",
            self.get_database_backend(),
            sql,
            record_stmt = false,
            {
                match &self.inner {
                    #[cfg(feature = "sqlx-mysql")]
                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
                        conn.execute_unprepared(sql)
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
                        conn.execute_unprepared(sql)
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
                        conn.execute_unprepared(sql)
                    }
                    #[cfg(feature = "rusqlite")]
                    DatabaseConnectionType::RusqliteSharedConnection(conn) => {
                        conn.execute_unprepared(sql)
                    }
                    #[cfg(feature = "mock")]
                    DatabaseConnectionType::MockDatabaseConnection(conn) => {
                        let db_backend = conn.get_database_backend();
                        let stmt = Statement::from_string(db_backend, sql);
                        conn.execute(stmt)
                    }
                    #[cfg(feature = "proxy")]
                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
                        let db_backend = conn.get_database_backend();
                        let stmt = Statement::from_string(db_backend, sql);
                        conn.execute(stmt)
                    }
                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
                }
            }
        )
    }

    /// Runs `stmt` and returns at most one row (`Ok(None)` when there is none).
    /// The `Disconnected` arm returns a "Disconnected" connection error.
    #[instrument(level = "trace")]
    #[allow(unused_variables)]
    fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
        super::tracing_spans::with_db_span!(
            "sea_orm.query_one",
            self.get_database_backend(),
            stmt.sql.as_str(),
            record_stmt = self.get_record_stmt_in_spans(),
            {
                match &self.inner {
                    #[cfg(feature = "sqlx-mysql")]
                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt),
                    #[cfg(feature = "sqlx-postgres")]
                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
                        conn.query_one(stmt)
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt),
                    #[cfg(feature = "rusqlite")]
                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_one(stmt),
                    #[cfg(feature = "mock")]
                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_one(stmt),
                    #[cfg(feature = "proxy")]
                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.query_one(stmt),
                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
                }
            }
        )
    }

    /// Runs `stmt` and collects every returned row.
    /// The `Disconnected` arm returns a "Disconnected" connection error.
    #[instrument(level = "trace")]
    #[allow(unused_variables)]
    fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
        super::tracing_spans::with_db_span!(
            "sea_orm.query_all",
            self.get_database_backend(),
            stmt.sql.as_str(),
            record_stmt = self.get_record_stmt_in_spans(),
            {
                match &self.inner {
                    #[cfg(feature = "sqlx-mysql")]
                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt),
                    #[cfg(feature = "sqlx-postgres")]
                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
                        conn.query_all(stmt)
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt),
                    #[cfg(feature = "rusqlite")]
                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_all(stmt),
                    #[cfg(feature = "mock")]
                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_all(stmt),
                    #[cfg(feature = "proxy")]
                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.query_all(stmt),
                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
                }
            }
        )
    }

    /// `true` only when `inner` is the `MockDatabaseConnection` variant.
    #[cfg(feature = "mock")]
    fn is_mock_connection(&self) -> bool {
        matches!(
            self,
            DatabaseConnection {
                inner: DatabaseConnectionType::MockDatabaseConnection(_),
                ..
            }
        )
    }
}
287
288impl StreamTrait for DatabaseConnection {
289    type Stream<'a> = crate::QueryStream;
290
291    fn get_database_backend(&self) -> DbBackend {
292        self.get_database_backend()
293    }
294
295    #[instrument(level = "trace")]
296    #[allow(unused_variables)]
297    fn stream_raw<'a>(&'a self, stmt: Statement) -> Result<Self::Stream<'a>, DbErr> {
298        ({
299            match &self.inner {
300                #[cfg(feature = "sqlx-mysql")]
301                DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.stream(stmt),
302                #[cfg(feature = "sqlx-postgres")]
303                DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.stream(stmt),
304                #[cfg(feature = "sqlx-sqlite")]
305                DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.stream(stmt),
306                #[cfg(feature = "rusqlite")]
307                DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.stream(stmt),
308                #[cfg(feature = "mock")]
309                DatabaseConnectionType::MockDatabaseConnection(conn) => {
310                    Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
311                }
312                #[cfg(feature = "proxy")]
313                DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
314                    Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
315                }
316                DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
317            }
318        })
319    }
320}
321
322impl TransactionTrait for DatabaseConnection {
323    type Transaction = DatabaseTransaction;
324
325    #[instrument(level = "trace")]
326    fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
327        match &self.inner {
328            #[cfg(feature = "sqlx-mysql")]
329            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.begin(None, None),
330            #[cfg(feature = "sqlx-postgres")]
331            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.begin(None, None),
332            #[cfg(feature = "sqlx-sqlite")]
333            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.begin(None, None, None),
334            #[cfg(feature = "rusqlite")]
335            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.begin(None, None, None),
336            #[cfg(feature = "mock")]
337            DatabaseConnectionType::MockDatabaseConnection(conn) => {
338                DatabaseTransaction::new_mock(Arc::clone(conn), None)
339            }
340            #[cfg(feature = "proxy")]
341            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
342                DatabaseTransaction::new_proxy(conn.clone(), None)
343            }
344            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
345        }
346    }
347
348    #[instrument(level = "trace")]
349    fn begin_with_config(
350        &self,
351        _isolation_level: Option<IsolationLevel>,
352        _access_mode: Option<AccessMode>,
353    ) -> Result<DatabaseTransaction, DbErr> {
354        match &self.inner {
355            #[cfg(feature = "sqlx-mysql")]
356            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
357                conn.begin(_isolation_level, _access_mode)
358            }
359            #[cfg(feature = "sqlx-postgres")]
360            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
361                conn.begin(_isolation_level, _access_mode)
362            }
363            #[cfg(feature = "sqlx-sqlite")]
364            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
365                conn.begin(_isolation_level, _access_mode, None)
366            }
367            #[cfg(feature = "rusqlite")]
368            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
369                conn.begin(_isolation_level, _access_mode, None)
370            }
371            #[cfg(feature = "mock")]
372            DatabaseConnectionType::MockDatabaseConnection(conn) => {
373                DatabaseTransaction::new_mock(Arc::clone(conn), None)
374            }
375            #[cfg(feature = "proxy")]
376            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
377                DatabaseTransaction::new_proxy(conn.clone(), None)
378            }
379            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
380        }
381    }
382
383    #[instrument(level = "trace")]
384    fn begin_with_options(
385        &self,
386        TransactionOptions {
387            isolation_level: _isolation_level,
388            access_mode: _access_mode,
389            sqlite_transaction_mode: _sqlite_transaction_mode,
390        }: TransactionOptions,
391    ) -> Result<DatabaseTransaction, DbErr> {
392        match &self.inner {
393            #[cfg(feature = "sqlx-mysql")]
394            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
395                conn.begin(_isolation_level, _access_mode)
396            }
397            #[cfg(feature = "sqlx-postgres")]
398            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
399                conn.begin(_isolation_level, _access_mode)
400            }
401            #[cfg(feature = "sqlx-sqlite")]
402            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
403                conn.begin(_isolation_level, _access_mode, _sqlite_transaction_mode)
404            }
405            #[cfg(feature = "rusqlite")]
406            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
407                conn.begin(_isolation_level, _access_mode, _sqlite_transaction_mode)
408            }
409            #[cfg(feature = "mock")]
410            DatabaseConnectionType::MockDatabaseConnection(conn) => {
411                DatabaseTransaction::new_mock(Arc::clone(conn), None)
412            }
413            #[cfg(feature = "proxy")]
414            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
415                DatabaseTransaction::new_proxy(conn.clone(), None)
416            }
417            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
418        }
419    }
420
421    /// Execute the function inside a transaction.
422    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
423    #[instrument(level = "trace", skip(_callback))]
424    fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
425    where
426        F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
427        E: std::fmt::Display + std::fmt::Debug,
428    {
429        match &self.inner {
430            #[cfg(feature = "sqlx-mysql")]
431            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
432                conn.transaction(_callback, None, None)
433            }
434            #[cfg(feature = "sqlx-postgres")]
435            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
436                conn.transaction(_callback, None, None)
437            }
438            #[cfg(feature = "sqlx-sqlite")]
439            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
440                conn.transaction(_callback, None, None)
441            }
442            #[cfg(feature = "rusqlite")]
443            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
444                conn.transaction(_callback, None, None)
445            }
446            #[cfg(feature = "mock")]
447            DatabaseConnectionType::MockDatabaseConnection(conn) => {
448                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
449                    .map_err(TransactionError::Connection)?;
450                transaction.run(_callback)
451            }
452            #[cfg(feature = "proxy")]
453            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
454                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
455                    .map_err(TransactionError::Connection)?;
456                transaction.run(_callback)
457            }
458            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
459        }
460    }
461
462    /// Execute the function inside a transaction.
463    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
464    #[instrument(level = "trace", skip(_callback))]
465    fn transaction_with_config<F, T, E>(
466        &self,
467        _callback: F,
468        _isolation_level: Option<IsolationLevel>,
469        _access_mode: Option<AccessMode>,
470    ) -> Result<T, TransactionError<E>>
471    where
472        F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
473        E: std::fmt::Display + std::fmt::Debug,
474    {
475        match &self.inner {
476            #[cfg(feature = "sqlx-mysql")]
477            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
478                conn.transaction(_callback, _isolation_level, _access_mode)
479            }
480            #[cfg(feature = "sqlx-postgres")]
481            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
482                conn.transaction(_callback, _isolation_level, _access_mode)
483            }
484            #[cfg(feature = "sqlx-sqlite")]
485            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
486                conn.transaction(_callback, _isolation_level, _access_mode)
487            }
488            #[cfg(feature = "rusqlite")]
489            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
490                conn.transaction(_callback, _isolation_level, _access_mode)
491            }
492            #[cfg(feature = "mock")]
493            DatabaseConnectionType::MockDatabaseConnection(conn) => {
494                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
495                    .map_err(TransactionError::Connection)?;
496                transaction.run(_callback)
497            }
498            #[cfg(feature = "proxy")]
499            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
500                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
501                    .map_err(TransactionError::Connection)?;
502                transaction.run(_callback)
503            }
504            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
505        }
506    }
507}
508
#[cfg(feature = "mock")]
impl DatabaseConnection {
    /// Borrow the underlying mock connection, for tests that use the Mock database.
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a mock connection.
    pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection {
        let DatabaseConnectionType::MockDatabaseConnection(mock_conn) = &self.inner else {
            panic!("Not mock connection");
        };
        // `&Arc<MockDatabaseConnection>` deref-coerces to the returned reference.
        mock_conn
    }

    /// Consume the connection and drain the mocker's recorded transactions as
    /// a `Vec<[crate::Transaction]>`.
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a mock connection, or if acquiring the mocker
    /// mutex returns an error. For a `std::sync::Mutex` that means it was
    /// poisoned; a lock merely held elsewhere blocks instead. NOTE(review):
    /// confirm the mutex type.
    pub fn into_transaction_log(self) -> Vec<crate::Transaction> {
        let mock = self.as_mock_connection();
        let mut guard = mock
            .get_mocker_mutex()
            .lock()
            .expect("Fail to acquire mocker");
        guard.drain_transaction_log()
    }
}
537
#[cfg(feature = "proxy")]
impl DatabaseConnection {
    /// Borrow the underlying proxy connection, for tests that use the Proxy database.
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a proxy connection.
    pub fn as_proxy_connection(&self) -> &crate::ProxyDatabaseConnection {
        let DatabaseConnectionType::ProxyDatabaseConnection(proxy_conn) = &self.inner else {
            panic!("Not proxy connection");
        };
        // `&Arc<ProxyDatabaseConnection>` deref-coerces to the returned reference.
        proxy_conn
    }
}
552
#[cfg(feature = "rbac")]
impl DatabaseConnection {
    /// Set up the RBAC engine from RBAC data stored in this connection's own
    /// database. Any engine already mounted is replaced.
    pub fn load_rbac(&self) -> Result<(), DbErr> {
        self.load_rbac_from(self)
    }

    /// Set up the RBAC engine from RBAC data read through `db`, which may point
    /// at a different database. Any engine already mounted is replaced.
    ///
    /// # Errors
    ///
    /// Propagates any error from loading the engine; in that case the current
    /// engine is left untouched.
    pub fn load_rbac_from(&self, db: &DbConn) -> Result<(), DbErr> {
        let loaded = crate::rbac::RbacEngine::load_from(db)?;
        self.replace_rbac(loaded);
        Ok(())
    }

    /// Mount `engine` as this connection's RBAC engine, replacing any existing one.
    pub fn replace_rbac(&self, engine: crate::rbac::RbacEngine) {
        self.rbac.replace(engine);
    }

    /// Build a [`crate::RestrictedConnection`] that enforces access control
    /// for `user_id`.
    ///
    /// # Errors
    ///
    /// Returns `DbErr::RbacError` if no RBAC engine has been set up.
    pub fn restricted_for(
        &self,
        user_id: crate::rbac::RbacUserId,
    ) -> Result<crate::RestrictedConnection, DbErr> {
        if !self.rbac.is_some() {
            return Err(DbErr::RbacError("engine not set up".into()));
        }
        Ok(crate::RestrictedConnection {
            user_id,
            conn: self.clone(),
        })
    }
}
589
590impl DatabaseConnection {
591    /// Execute the function inside a transaction.
592    /// If the function returns an error, the transaction will be rolled back.
593    /// Otherwise, the transaction will be committed.
594    #[instrument(level = "trace", skip(callback))]
595    pub fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
596    where
597        F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
598        E: std::fmt::Display + std::fmt::Debug,
599    {
600        let transaction = self.begin().map_err(TransactionError::Connection)?;
601        run_async_transaction_callback(transaction, callback)
602    }
603
604    /// Execute the function inside a transaction with isolation level and/or access mode.
605    /// If the function returns an error, the transaction will be rolled back.
606    /// Otherwise, the transaction will be committed.
607    #[instrument(level = "trace", skip(callback))]
608    pub fn transaction_with_config<F, T, E>(
609        &self,
610        callback: F,
611        isolation_level: Option<IsolationLevel>,
612        access_mode: Option<AccessMode>,
613    ) -> Result<T, TransactionError<E>>
614    where
615        F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
616        E: std::fmt::Display + std::fmt::Debug,
617    {
618        let transaction = self
619            .begin_with_config(isolation_level, access_mode)
620            .map_err(TransactionError::Connection)?;
621        run_async_transaction_callback(transaction, callback)
622    }
623
624    #[allow(unused)]
625    pub(crate) fn get_record_stmt_in_spans(&self) -> bool {
626        match &self.inner {
627            #[cfg(feature = "sqlx-mysql")]
628            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.record_stmt_in_spans,
629            #[cfg(feature = "sqlx-postgres")]
630            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.record_stmt_in_spans,
631            #[cfg(feature = "sqlx-sqlite")]
632            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.record_stmt_in_spans,
633            #[cfg(feature = "rusqlite")]
634            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.record_stmt_in_spans,
635            DatabaseConnectionType::Disconnected => true,
636            #[cfg(feature = "mock")]
637            DatabaseConnectionType::MockDatabaseConnection(_) => true,
638            #[cfg(feature = "proxy")]
639            DatabaseConnectionType::ProxyDatabaseConnection(_) => true,
640        }
641    }
642
643    /// Get the database backend for this connection
644    ///
645    /// # Panics
646    ///
647    /// Panics if [DatabaseConnection] is `Disconnected`.
648    pub fn get_database_backend(&self) -> DbBackend {
649        match &self.inner {
650            #[cfg(feature = "sqlx-mysql")]
651            DatabaseConnectionType::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
652            #[cfg(feature = "sqlx-postgres")]
653            DatabaseConnectionType::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
654            #[cfg(feature = "sqlx-sqlite")]
655            DatabaseConnectionType::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite,
656            #[cfg(feature = "rusqlite")]
657            DatabaseConnectionType::RusqliteSharedConnection(_) => DbBackend::Sqlite,
658            #[cfg(feature = "mock")]
659            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.get_database_backend(),
660            #[cfg(feature = "proxy")]
661            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.get_database_backend(),
662            DatabaseConnectionType::Disconnected => panic!("Disconnected"),
663        }
664    }
665
666    /// Creates a [`SchemaBuilder`] for this backend
667    pub fn get_schema_builder(&self) -> SchemaBuilder {
668        Schema::new(self.get_database_backend()).builder()
669    }
670
671    #[cfg(feature = "entity-registry")]
672    #[cfg_attr(docsrs, doc(cfg(feature = "entity-registry")))]
673    /// Builds a schema for all the entites in the given module
674    pub fn get_schema_registry(&self, prefix: &str) -> SchemaBuilder {
675        let schema = Schema::new(self.get_database_backend());
676        crate::EntityRegistry::build_schema(schema, prefix)
677    }
678
679    /// Sets a callback to metric this connection
680    pub fn set_metric_callback<F>(&mut self, _callback: F)
681    where
682        F: Fn(&crate::metric::Info<'_>) + 'static,
683    {
684        match &mut self.inner {
685            #[cfg(feature = "sqlx-mysql")]
686            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
687                conn.set_metric_callback(_callback)
688            }
689            #[cfg(feature = "sqlx-postgres")]
690            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
691                conn.set_metric_callback(_callback)
692            }
693            #[cfg(feature = "sqlx-sqlite")]
694            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
695                conn.set_metric_callback(_callback)
696            }
697            #[cfg(feature = "rusqlite")]
698            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
699                conn.set_metric_callback(_callback)
700            }
701            _ => {}
702        }
703    }
704
705    /// Checks if a connection to the database is still valid.
706    pub fn ping(&self) -> Result<(), DbErr> {
707        match &self.inner {
708            #[cfg(feature = "sqlx-mysql")]
709            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.ping(),
710            #[cfg(feature = "sqlx-postgres")]
711            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.ping(),
712            #[cfg(feature = "sqlx-sqlite")]
713            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.ping(),
714            #[cfg(feature = "rusqlite")]
715            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.ping(),
716            #[cfg(feature = "mock")]
717            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.ping(),
718            #[cfg(feature = "proxy")]
719            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.ping(),
720            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
721        }
722    }
723
724    /// Explicitly close the database connection.
725    /// See [`Self::close_by_ref`] for usage with references.
726    pub fn close(self) -> Result<(), DbErr> {
727        self.close_by_ref()
728    }
729
730    /// Explicitly close the database connection
731    pub fn close_by_ref(&self) -> Result<(), DbErr> {
732        match &self.inner {
733            #[cfg(feature = "sqlx-mysql")]
734            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.close_by_ref(),
735            #[cfg(feature = "sqlx-postgres")]
736            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.close_by_ref(),
737            #[cfg(feature = "sqlx-sqlite")]
738            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.close_by_ref(),
739            #[cfg(feature = "rusqlite")]
740            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.close_by_ref(),
741            #[cfg(feature = "mock")]
742            DatabaseConnectionType::MockDatabaseConnection(_) => {
743                // Nothing to cleanup, we just consume the `DatabaseConnection`
744                Ok(())
745            }
746            #[cfg(feature = "proxy")]
747            DatabaseConnectionType::ProxyDatabaseConnection(_) => {
748                // Nothing to cleanup, we just consume the `DatabaseConnection`
749                Ok(())
750            }
751            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
752        }
753    }
754}
755
756impl DatabaseConnection {
757    /// Get [sqlx::MySqlPool]
758    ///
759    /// # Panics
760    ///
761    /// Panics if [DbConn] is not a MySQL connection.
762    #[cfg(feature = "sqlx-mysql")]
763    pub fn get_mysql_connection_pool(&self) -> &sqlx::MySqlPool {
764        match &self.inner {
765            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => &conn.pool,
766            _ => panic!("Not MySQL Connection"),
767        }
768    }
769
770    /// Get [sqlx::PgPool]
771    ///
772    /// # Panics
773    ///
774    /// Panics if [DbConn] is not a Postgres connection.
775    #[cfg(feature = "sqlx-postgres")]
776    pub fn get_postgres_connection_pool(&self) -> &sqlx::PgPool {
777        match &self.inner {
778            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => &conn.pool,
779            _ => panic!("Not Postgres Connection"),
780        }
781    }
782
783    /// Get [sqlx::SqlitePool]
784    ///
785    /// # Panics
786    ///
787    /// Panics if [DbConn] is not a SQLite connection.
788    #[cfg(feature = "sqlx-sqlite")]
789    pub fn get_sqlite_connection_pool(&self) -> &sqlx::SqlitePool {
790        match &self.inner {
791            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => &conn.pool,
792            _ => panic!("Not SQLite Connection"),
793        }
794    }
795}
796
797impl DbBackend {
798    /// Check if the URI is the same as the specified database backend.
799    /// Returns true if they match.
800    ///
801    /// # Panics
802    ///
803    /// Panics if `base_url` cannot be parsed as `Url`.
804    pub fn is_prefix_of(self, base_url: &str) -> bool {
805        let base_url_parsed = Url::parse(base_url).expect("Fail to parse database URL");
806        match self {
807            Self::Postgres => {
808                base_url_parsed.scheme() == "postgres" || base_url_parsed.scheme() == "postgresql"
809            }
810            Self::MySql => base_url_parsed.scheme() == "mysql",
811            Self::Sqlite => base_url_parsed.scheme() == "sqlite",
812        }
813    }
814
815    /// Build an SQL [Statement]
816    pub fn build<S>(&self, statement: &S) -> Statement
817    where
818        S: StatementBuilder,
819    {
820        statement.build(self)
821    }
822
823    /// Check if the database supports `RETURNING` syntax on insert and update
824    pub fn support_returning(&self) -> bool {
825        match self {
826            Self::Postgres => true,
827            Self::Sqlite if cfg!(feature = "sqlite-use-returning-for-3_35") => true,
828            Self::MySql if cfg!(feature = "mariadb-use-returning") => true,
829            _ => false,
830        }
831    }
832
833    /// A getter for database dependent boolean value
834    pub fn boolean_value(&self, boolean: bool) -> sea_query::Value {
835        match self {
836            Self::MySql | Self::Postgres | Self::Sqlite => boolean.into(),
837        }
838    }
839
840    /// Get the display string for this enum
841    pub fn as_str(&self) -> &'static str {
842        match self {
843            DatabaseBackend::MySql => "MySql",
844            DatabaseBackend::Postgres => "Postgres",
845            DatabaseBackend::Sqlite => "Sqlite",
846        }
847    }
848}
849
850#[cfg(test)]
851mod tests {
852    use crate::DatabaseConnection;
853
854    #[cfg(not(feature = "sync"))]
855    #[test]
856    fn assert_database_connection_traits() {
857        fn assert_send_sync<T: Send>() {}
858
859        assert_send_sync::<DatabaseConnection>();
860    }
861}