Skip to main content

sea_orm/database/
db_connection.rs

1use super::transaction::run_async_transaction_callback;
2use crate::{
3    AccessMode, ConnectionTrait, DatabaseTransaction, ExecResult, IsolationLevel, QueryResult,
4    Schema, SchemaBuilder, Statement, StatementBuilder, StreamTrait, TransactionError,
5    TransactionOptions, TransactionTrait, error::*,
6};
7use std::{fmt::Debug, future::Future, pin::Pin};
8use tracing::instrument;
9use url::Url;
10
11#[cfg(feature = "sqlx-dep")]
12use sqlx::pool::PoolConnection;
13
14#[cfg(feature = "rusqlite")]
15use crate::driver::rusqlite::{RusqliteInnerConnection, RusqliteSharedConnection};
16
17#[cfg(any(feature = "mock", feature = "proxy"))]
18use std::sync::Arc;
19
20/// Handle a database connection depending on the backend enabled by the feature
21/// flags. This creates a connection pool internally (for SQLx connections),
22/// and so is cheap to clone.
23#[derive(Debug, Clone)]
24#[non_exhaustive]
25pub struct DatabaseConnection {
26    /// `DatabaseConnection` used to be a enum. Now it's moved into inner,
27    /// because we have to attach other contexts.
28    pub inner: DatabaseConnectionType,
29    #[cfg(feature = "rbac")]
30    pub(crate) rbac: crate::RbacEngineMount,
31}
32
/// The underlying, backend-specific connection held by a [`DatabaseConnection`].
///
/// Every variant except `Disconnected` is gated behind a cargo feature, so the
/// set of available variants depends on how the crate was built.
#[derive(Clone)]
pub enum DatabaseConnectionType {
    /// MySQL connection pool (requires the `sqlx-mysql` feature)
    #[cfg(feature = "sqlx-mysql")]
    SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),

    /// PostgreSQL connection pool (requires the `sqlx-postgres` feature)
    #[cfg(feature = "sqlx-postgres")]
    SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),

    /// SQLite connection pool (requires the `sqlx-sqlite` feature)
    #[cfg(feature = "sqlx-sqlite")]
    SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection),

    /// SQLite connection sharable across threads (requires the `rusqlite` feature)
    #[cfg(feature = "rusqlite")]
    RusqliteSharedConnection(RusqliteSharedConnection),

    /// Mock database connection useful for testing (requires the `mock` feature)
    #[cfg(feature = "mock")]
    MockDatabaseConnection(Arc<crate::MockDatabaseConnection>),

    /// Proxy database connection (requires the `proxy` feature)
    #[cfg(feature = "proxy")]
    ProxyDatabaseConnection(Arc<crate::ProxyDatabaseConnection>),

    /// The connection has never been established. This is the state produced
    /// by `DatabaseConnection::default()`.
    Disconnected,
}
63
/// Short alias for [`DatabaseConnection`].
pub type DbConn = DatabaseConnection;
66
67impl Default for DatabaseConnection {
68    fn default() -> Self {
69        DatabaseConnectionType::Disconnected.into()
70    }
71}
72
73impl From<DatabaseConnectionType> for DatabaseConnection {
74    fn from(inner: DatabaseConnectionType) -> Self {
75        Self {
76            inner,
77            #[cfg(feature = "rbac")]
78            rbac: Default::default(),
79        }
80    }
81}
82
/// The kind of real-world database a connection talks to.
///
/// Support for each backend is enabled by feature flags, as described in the
/// crate documentation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum DatabaseBackend {
    /// MySQL
    MySql,
    /// PostgreSQL
    Postgres,
    /// SQLite
    Sqlite,
}
95
/// Short alias for [`DatabaseBackend`].
pub type DbBackend = DatabaseBackend;
98
/// Crate-internal counterpart of [`DatabaseConnectionType`]: for the SQLx
/// backends it holds a single `PoolConnection` rather than a whole pool.
///
/// NOTE(review): presumably this is the connection a transaction or stream
/// runs on — its users are not visible in this part of the file; confirm.
#[derive(Debug)]
pub(crate) enum InnerConnection {
    /// A MySQL connection obtained from an SQLx pool
    #[cfg(feature = "sqlx-mysql")]
    MySql(PoolConnection<sqlx::MySql>),
    /// A PostgreSQL connection obtained from an SQLx pool
    #[cfg(feature = "sqlx-postgres")]
    Postgres(PoolConnection<sqlx::Postgres>),
    /// A SQLite connection obtained from an SQLx pool
    #[cfg(feature = "sqlx-sqlite")]
    Sqlite(PoolConnection<sqlx::Sqlite>),
    /// A rusqlite connection
    #[cfg(feature = "rusqlite")]
    Rusqlite(RusqliteInnerConnection),
    /// The shared mock connection
    #[cfg(feature = "mock")]
    Mock(Arc<crate::MockDatabaseConnection>),
    /// The shared proxy connection
    #[cfg(feature = "proxy")]
    Proxy(Arc<crate::ProxyDatabaseConnection>),
}
114
115impl Debug for DatabaseConnectionType {
116    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
117        write!(
118            f,
119            "{}",
120            match self {
121                #[cfg(feature = "sqlx-mysql")]
122                Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
123                #[cfg(feature = "sqlx-postgres")]
124                Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
125                #[cfg(feature = "sqlx-sqlite")]
126                Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection",
127                #[cfg(feature = "rusqlite")]
128                Self::RusqliteSharedConnection(_) => "RusqliteSharedConnection",
129                #[cfg(feature = "mock")]
130                Self::MockDatabaseConnection(_) => "MockDatabaseConnection",
131                #[cfg(feature = "proxy")]
132                Self::ProxyDatabaseConnection(_) => "ProxyDatabaseConnection",
133                Self::Disconnected => "Disconnected",
134            }
135        )
136    }
137}
138
139#[async_trait::async_trait]
140impl ConnectionTrait for DatabaseConnection {
141    fn get_database_backend(&self) -> DbBackend {
142        self.get_database_backend()
143    }
144
145    #[instrument(level = "trace")]
146    #[allow(unused_variables)]
147    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
148        super::tracing_spans::with_db_span!(
149            "sea_orm.execute",
150            self.get_database_backend(),
151            stmt.sql.as_str(),
152            record_stmt = self.get_record_stmt_in_spans(),
153            async {
154                match &self.inner {
155                    #[cfg(feature = "sqlx-mysql")]
156                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
157                        conn.execute(stmt).await
158                    }
159                    #[cfg(feature = "sqlx-postgres")]
160                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
161                        conn.execute(stmt).await
162                    }
163                    #[cfg(feature = "sqlx-sqlite")]
164                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
165                        conn.execute(stmt).await
166                    }
167                    #[cfg(feature = "rusqlite")]
168                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.execute(stmt),
169                    #[cfg(feature = "mock")]
170                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.execute(stmt),
171                    #[cfg(feature = "proxy")]
172                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
173                        conn.execute(stmt).await
174                    }
175                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
176                }
177            }
178        )
179    }
180
181    #[instrument(level = "trace")]
182    #[allow(unused_variables)]
183    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
184        super::tracing_spans::with_db_span!(
185            "sea_orm.execute_unprepared",
186            self.get_database_backend(),
187            sql,
188            record_stmt = false,
189            async {
190                match &self.inner {
191                    #[cfg(feature = "sqlx-mysql")]
192                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
193                        conn.execute_unprepared(sql).await
194                    }
195                    #[cfg(feature = "sqlx-postgres")]
196                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
197                        conn.execute_unprepared(sql).await
198                    }
199                    #[cfg(feature = "sqlx-sqlite")]
200                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
201                        conn.execute_unprepared(sql).await
202                    }
203                    #[cfg(feature = "rusqlite")]
204                    DatabaseConnectionType::RusqliteSharedConnection(conn) => {
205                        conn.execute_unprepared(sql)
206                    }
207                    #[cfg(feature = "mock")]
208                    DatabaseConnectionType::MockDatabaseConnection(conn) => {
209                        let db_backend = conn.get_database_backend();
210                        let stmt = Statement::from_string(db_backend, sql);
211                        conn.execute(stmt)
212                    }
213                    #[cfg(feature = "proxy")]
214                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
215                        let db_backend = conn.get_database_backend();
216                        let stmt = Statement::from_string(db_backend, sql);
217                        conn.execute(stmt).await
218                    }
219                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
220                }
221            }
222        )
223    }
224
225    #[instrument(level = "trace")]
226    #[allow(unused_variables)]
227    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
228        super::tracing_spans::with_db_span!(
229            "sea_orm.query_one",
230            self.get_database_backend(),
231            stmt.sql.as_str(),
232            record_stmt = self.get_record_stmt_in_spans(),
233            async {
234                match &self.inner {
235                    #[cfg(feature = "sqlx-mysql")]
236                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
237                        conn.query_one(stmt).await
238                    }
239                    #[cfg(feature = "sqlx-postgres")]
240                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
241                        conn.query_one(stmt).await
242                    }
243                    #[cfg(feature = "sqlx-sqlite")]
244                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
245                        conn.query_one(stmt).await
246                    }
247                    #[cfg(feature = "rusqlite")]
248                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_one(stmt),
249                    #[cfg(feature = "mock")]
250                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_one(stmt),
251                    #[cfg(feature = "proxy")]
252                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
253                        conn.query_one(stmt).await
254                    }
255                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
256                }
257            }
258        )
259    }
260
261    #[instrument(level = "trace")]
262    #[allow(unused_variables)]
263    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
264        super::tracing_spans::with_db_span!(
265            "sea_orm.query_all",
266            self.get_database_backend(),
267            stmt.sql.as_str(),
268            record_stmt = self.get_record_stmt_in_spans(),
269            async {
270                match &self.inner {
271                    #[cfg(feature = "sqlx-mysql")]
272                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
273                        conn.query_all(stmt).await
274                    }
275                    #[cfg(feature = "sqlx-postgres")]
276                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
277                        conn.query_all(stmt).await
278                    }
279                    #[cfg(feature = "sqlx-sqlite")]
280                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
281                        conn.query_all(stmt).await
282                    }
283                    #[cfg(feature = "rusqlite")]
284                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_all(stmt),
285                    #[cfg(feature = "mock")]
286                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_all(stmt),
287                    #[cfg(feature = "proxy")]
288                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
289                        conn.query_all(stmt).await
290                    }
291                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
292                }
293            }
294        )
295    }
296
297    #[cfg(feature = "mock")]
298    fn is_mock_connection(&self) -> bool {
299        matches!(
300            self,
301            DatabaseConnection {
302                inner: DatabaseConnectionType::MockDatabaseConnection(_),
303                ..
304            }
305        )
306    }
307}
308
309#[async_trait::async_trait]
310impl StreamTrait for DatabaseConnection {
311    type Stream<'a> = crate::QueryStream;
312
313    fn get_database_backend(&self) -> DbBackend {
314        self.get_database_backend()
315    }
316
317    #[instrument(level = "trace")]
318    #[allow(unused_variables)]
319    fn stream_raw<'a>(
320        &'a self,
321        stmt: Statement,
322    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
323        Box::pin(async move {
324            match &self.inner {
325                #[cfg(feature = "sqlx-mysql")]
326                DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await,
327                #[cfg(feature = "sqlx-postgres")]
328                DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await,
329                #[cfg(feature = "sqlx-sqlite")]
330                DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await,
331                #[cfg(feature = "rusqlite")]
332                DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.stream(stmt),
333                #[cfg(feature = "mock")]
334                DatabaseConnectionType::MockDatabaseConnection(conn) => {
335                    Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
336                }
337                #[cfg(feature = "proxy")]
338                DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
339                    Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
340                }
341                DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
342            }
343        })
344    }
345}
346
347#[async_trait::async_trait]
348impl TransactionTrait for DatabaseConnection {
349    type Transaction = DatabaseTransaction;
350
351    #[instrument(level = "trace")]
352    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
353        match &self.inner {
354            #[cfg(feature = "sqlx-mysql")]
355            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.begin(None, None).await,
356            #[cfg(feature = "sqlx-postgres")]
357            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
358                conn.begin(None, None).await
359            }
360            #[cfg(feature = "sqlx-sqlite")]
361            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
362                conn.begin(None, None, None).await
363            }
364            #[cfg(feature = "rusqlite")]
365            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.begin(None, None, None),
366            #[cfg(feature = "mock")]
367            DatabaseConnectionType::MockDatabaseConnection(conn) => {
368                DatabaseTransaction::new_mock(Arc::clone(conn), None).await
369            }
370            #[cfg(feature = "proxy")]
371            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
372                DatabaseTransaction::new_proxy(conn.clone(), None).await
373            }
374            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
375        }
376    }
377
378    #[instrument(level = "trace")]
379    async fn begin_with_config(
380        &self,
381        _isolation_level: Option<IsolationLevel>,
382        _access_mode: Option<AccessMode>,
383    ) -> Result<DatabaseTransaction, DbErr> {
384        match &self.inner {
385            #[cfg(feature = "sqlx-mysql")]
386            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
387                conn.begin(_isolation_level, _access_mode).await
388            }
389            #[cfg(feature = "sqlx-postgres")]
390            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
391                conn.begin(_isolation_level, _access_mode).await
392            }
393            #[cfg(feature = "sqlx-sqlite")]
394            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
395                conn.begin(_isolation_level, _access_mode, None).await
396            }
397            #[cfg(feature = "rusqlite")]
398            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
399                conn.begin(_isolation_level, _access_mode, None)
400            }
401            #[cfg(feature = "mock")]
402            DatabaseConnectionType::MockDatabaseConnection(conn) => {
403                DatabaseTransaction::new_mock(Arc::clone(conn), None).await
404            }
405            #[cfg(feature = "proxy")]
406            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
407                DatabaseTransaction::new_proxy(conn.clone(), None).await
408            }
409            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
410        }
411    }
412
413    #[instrument(level = "trace")]
414    async fn begin_with_options(
415        &self,
416        TransactionOptions {
417            isolation_level: _isolation_level,
418            access_mode: _access_mode,
419            sqlite_transaction_mode: _sqlite_transaction_mode,
420        }: TransactionOptions,
421    ) -> Result<DatabaseTransaction, DbErr> {
422        match &self.inner {
423            #[cfg(feature = "sqlx-mysql")]
424            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
425                conn.begin(_isolation_level, _access_mode).await
426            }
427            #[cfg(feature = "sqlx-postgres")]
428            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
429                conn.begin(_isolation_level, _access_mode).await
430            }
431            #[cfg(feature = "sqlx-sqlite")]
432            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
433                conn.begin(_isolation_level, _access_mode, _sqlite_transaction_mode)
434                    .await
435            }
436            #[cfg(feature = "rusqlite")]
437            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
438                conn.begin(_isolation_level, _access_mode, _sqlite_transaction_mode)
439            }
440            #[cfg(feature = "mock")]
441            DatabaseConnectionType::MockDatabaseConnection(conn) => {
442                DatabaseTransaction::new_mock(Arc::clone(conn), None).await
443            }
444            #[cfg(feature = "proxy")]
445            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
446                DatabaseTransaction::new_proxy(conn.clone(), None).await
447            }
448            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
449        }
450    }
451
452    /// Execute the function inside a transaction.
453    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
454    #[instrument(level = "trace", skip(_callback))]
455    async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
456    where
457        F: for<'c> FnOnce(
458                &'c DatabaseTransaction,
459            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
460            + Send,
461        T: Send,
462        E: std::fmt::Display + std::fmt::Debug + Send,
463    {
464        match &self.inner {
465            #[cfg(feature = "sqlx-mysql")]
466            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
467                conn.transaction(_callback, None, None).await
468            }
469            #[cfg(feature = "sqlx-postgres")]
470            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
471                conn.transaction(_callback, None, None).await
472            }
473            #[cfg(feature = "sqlx-sqlite")]
474            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
475                conn.transaction(_callback, None, None).await
476            }
477            #[cfg(feature = "rusqlite")]
478            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
479                conn.transaction(_callback, None, None)
480            }
481            #[cfg(feature = "mock")]
482            DatabaseConnectionType::MockDatabaseConnection(conn) => {
483                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
484                    .await
485                    .map_err(TransactionError::Connection)?;
486                transaction.run(_callback).await
487            }
488            #[cfg(feature = "proxy")]
489            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
490                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
491                    .await
492                    .map_err(TransactionError::Connection)?;
493                transaction.run(_callback).await
494            }
495            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
496        }
497    }
498
499    /// Execute the function inside a transaction.
500    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
501    #[instrument(level = "trace", skip(_callback))]
502    async fn transaction_with_config<F, T, E>(
503        &self,
504        _callback: F,
505        _isolation_level: Option<IsolationLevel>,
506        _access_mode: Option<AccessMode>,
507    ) -> Result<T, TransactionError<E>>
508    where
509        F: for<'c> FnOnce(
510                &'c DatabaseTransaction,
511            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
512            + Send,
513        T: Send,
514        E: std::fmt::Display + std::fmt::Debug + Send,
515    {
516        match &self.inner {
517            #[cfg(feature = "sqlx-mysql")]
518            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
519                conn.transaction(_callback, _isolation_level, _access_mode)
520                    .await
521            }
522            #[cfg(feature = "sqlx-postgres")]
523            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
524                conn.transaction(_callback, _isolation_level, _access_mode)
525                    .await
526            }
527            #[cfg(feature = "sqlx-sqlite")]
528            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
529                conn.transaction(_callback, _isolation_level, _access_mode)
530                    .await
531            }
532            #[cfg(feature = "rusqlite")]
533            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
534                conn.transaction(_callback, _isolation_level, _access_mode)
535            }
536            #[cfg(feature = "mock")]
537            DatabaseConnectionType::MockDatabaseConnection(conn) => {
538                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
539                    .await
540                    .map_err(TransactionError::Connection)?;
541                transaction.run(_callback).await
542            }
543            #[cfg(feature = "proxy")]
544            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
545                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
546                    .await
547                    .map_err(TransactionError::Connection)?;
548                transaction.run(_callback).await
549            }
550            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
551        }
552    }
553}
554
#[cfg(feature = "mock")]
impl DatabaseConnection {
    /// Borrows the underlying [`crate::MockDatabaseConnection`].
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a mock connection.
    pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection {
        let DatabaseConnectionType::MockDatabaseConnection(mock) = &self.inner else {
            panic!("Not mock connection");
        };
        mock
    }

    /// Consumes the connection and returns the transaction log as a
    /// `Vec<`[`crate::Transaction`]`>`, obtained from the mocker's
    /// `drain_transaction_log`.
    ///
    /// # Panics
    ///
    /// Panics if this is not a mock connection, or if locking the mocker
    /// mutex returns an error.
    pub fn into_transaction_log(self) -> Vec<crate::Transaction> {
        let mocker_mutex = self.as_mock_connection().get_mocker_mutex();
        let mut mocker = mocker_mutex.lock().expect("Fail to acquire mocker");
        mocker.drain_transaction_log()
    }
}
583
#[cfg(feature = "proxy")]
impl DatabaseConnection {
    /// Borrows the underlying [`crate::ProxyDatabaseConnection`].
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a proxy connection.
    pub fn as_proxy_connection(&self) -> &crate::ProxyDatabaseConnection {
        let DatabaseConnectionType::ProxyDatabaseConnection(proxy) = &self.inner else {
            panic!("Not proxy connection");
        };
        proxy
    }
}
598
#[cfg(feature = "rbac")]
impl DatabaseConnection {
    /// Loads RBAC data from the database behind this very connection and sets
    /// up the RBAC engine. An engine that already exists is replaced.
    pub async fn load_rbac(&self) -> Result<(), DbErr> {
        self.load_rbac_from(self).await
    }

    /// Loads RBAC data through `db` — which may point at another database —
    /// and installs the resulting engine on this connection.
    ///
    /// # Errors
    ///
    /// Returns the error from `RbacEngine::load_from`; in that case the
    /// current engine is left untouched.
    pub async fn load_rbac_from(&self, db: &DbConn) -> Result<(), DbErr> {
        let loaded = crate::rbac::RbacEngine::load_from(db).await?;
        self.replace_rbac(loaded);
        Ok(())
    }

    /// Replaces the internal RBAC engine with `engine`.
    pub fn replace_rbac(&self, engine: crate::rbac::RbacEngine) {
        self.rbac.replace(engine);
    }

    /// Creates a restricted connection whose access control is specific to
    /// `user_id`. The restricted connection holds a clone of this connection.
    ///
    /// # Errors
    ///
    /// Returns `DbErr::RbacError` when no RBAC engine has been set up.
    pub fn restricted_for(
        &self,
        user_id: crate::rbac::RbacUserId,
    ) -> Result<crate::RestrictedConnection, DbErr> {
        if !self.rbac.is_some() {
            return Err(DbErr::RbacError("engine not set up".into()));
        }
        Ok(crate::RestrictedConnection {
            user_id,
            conn: self.clone(),
        })
    }
}
635
636impl DatabaseConnection {
637    /// Execute the function inside a transaction.
638    /// If the function returns an error, the transaction will be rolled back.
639    /// Otherwise, the transaction will be committed.
640    #[instrument(level = "trace", skip(callback))]
641    pub async fn transaction_async<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
642    where
643        F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
644        T: Send,
645        E: std::fmt::Display + std::fmt::Debug + Send,
646    {
647        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
648        run_async_transaction_callback(transaction, callback).await
649    }
650
651    /// Execute the function inside a transaction with isolation level and/or access mode.
652    /// If the function returns an error, the transaction will be rolled back.
653    /// Otherwise, the transaction will be committed.
654    #[instrument(level = "trace", skip(callback))]
655    pub async fn transaction_with_config_async<F, T, E>(
656        &self,
657        callback: F,
658        isolation_level: Option<IsolationLevel>,
659        access_mode: Option<AccessMode>,
660    ) -> Result<T, TransactionError<E>>
661    where
662        F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
663        T: Send,
664        E: std::fmt::Display + std::fmt::Debug + Send,
665    {
666        let transaction = self
667            .begin_with_config(isolation_level, access_mode)
668            .await
669            .map_err(TransactionError::Connection)?;
670        run_async_transaction_callback(transaction, callback).await
671    }
672
673    #[allow(unused)]
674    pub(crate) fn get_record_stmt_in_spans(&self) -> bool {
675        match &self.inner {
676            #[cfg(feature = "sqlx-mysql")]
677            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.record_stmt_in_spans,
678            #[cfg(feature = "sqlx-postgres")]
679            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.record_stmt_in_spans,
680            #[cfg(feature = "sqlx-sqlite")]
681            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.record_stmt_in_spans,
682            #[cfg(feature = "rusqlite")]
683            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.record_stmt_in_spans,
684            DatabaseConnectionType::Disconnected => true,
685            #[cfg(feature = "mock")]
686            DatabaseConnectionType::MockDatabaseConnection(_) => true,
687            #[cfg(feature = "proxy")]
688            DatabaseConnectionType::ProxyDatabaseConnection(_) => true,
689        }
690    }
691
692    /// Get the database backend for this connection
693    ///
694    /// # Panics
695    ///
696    /// Panics if [DatabaseConnection] is `Disconnected`.
697    pub fn get_database_backend(&self) -> DbBackend {
698        match &self.inner {
699            #[cfg(feature = "sqlx-mysql")]
700            DatabaseConnectionType::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
701            #[cfg(feature = "sqlx-postgres")]
702            DatabaseConnectionType::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
703            #[cfg(feature = "sqlx-sqlite")]
704            DatabaseConnectionType::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite,
705            #[cfg(feature = "rusqlite")]
706            DatabaseConnectionType::RusqliteSharedConnection(_) => DbBackend::Sqlite,
707            #[cfg(feature = "mock")]
708            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.get_database_backend(),
709            #[cfg(feature = "proxy")]
710            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.get_database_backend(),
711            DatabaseConnectionType::Disconnected => panic!("Disconnected"),
712        }
713    }
714
715    /// Creates a [`SchemaBuilder`] for this backend
716    pub fn get_schema_builder(&self) -> SchemaBuilder {
717        Schema::new(self.get_database_backend()).builder()
718    }
719
720    #[cfg(feature = "entity-registry")]
721    #[cfg_attr(docsrs, doc(cfg(feature = "entity-registry")))]
722    /// Builds a schema for all the entites in the given module
723    pub fn get_schema_registry(&self, prefix: &str) -> SchemaBuilder {
724        let schema = Schema::new(self.get_database_backend());
725        crate::EntityRegistry::build_schema(schema, prefix)
726    }
727
728    /// Sets a callback to metric this connection
729    pub fn set_metric_callback<F>(&mut self, _callback: F)
730    where
731        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
732    {
733        match &mut self.inner {
734            #[cfg(feature = "sqlx-mysql")]
735            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
736                conn.set_metric_callback(_callback)
737            }
738            #[cfg(feature = "sqlx-postgres")]
739            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
740                conn.set_metric_callback(_callback)
741            }
742            #[cfg(feature = "sqlx-sqlite")]
743            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
744                conn.set_metric_callback(_callback)
745            }
746            #[cfg(feature = "rusqlite")]
747            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
748                conn.set_metric_callback(_callback)
749            }
750            _ => {}
751        }
752    }
753
754    /// Checks if a connection to the database is still valid.
755    pub async fn ping(&self) -> Result<(), DbErr> {
756        match &self.inner {
757            #[cfg(feature = "sqlx-mysql")]
758            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.ping().await,
759            #[cfg(feature = "sqlx-postgres")]
760            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.ping().await,
761            #[cfg(feature = "sqlx-sqlite")]
762            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.ping().await,
763            #[cfg(feature = "rusqlite")]
764            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.ping(),
765            #[cfg(feature = "mock")]
766            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.ping(),
767            #[cfg(feature = "proxy")]
768            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.ping().await,
769            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
770        }
771    }
772
773    /// Explicitly close the database connection.
774    /// See [`Self::close_by_ref`] for usage with references.
775    pub async fn close(self) -> Result<(), DbErr> {
776        self.close_by_ref().await
777    }
778
779    /// Explicitly close the database connection
780    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
781        match &self.inner {
782            #[cfg(feature = "sqlx-mysql")]
783            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.close_by_ref().await,
784            #[cfg(feature = "sqlx-postgres")]
785            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.close_by_ref().await,
786            #[cfg(feature = "sqlx-sqlite")]
787            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.close_by_ref().await,
788            #[cfg(feature = "rusqlite")]
789            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.close_by_ref(),
790            #[cfg(feature = "mock")]
791            DatabaseConnectionType::MockDatabaseConnection(_) => {
792                // Nothing to cleanup, we just consume the `DatabaseConnection`
793                Ok(())
794            }
795            #[cfg(feature = "proxy")]
796            DatabaseConnectionType::ProxyDatabaseConnection(_) => {
797                // Nothing to cleanup, we just consume the `DatabaseConnection`
798                Ok(())
799            }
800            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
801        }
802    }
803}
804
805impl DatabaseConnection {
806    /// Get [sqlx::MySqlPool]
807    ///
808    /// # Panics
809    ///
810    /// Panics if [DbConn] is not a MySQL connection.
811    #[cfg(feature = "sqlx-mysql")]
812    pub fn get_mysql_connection_pool(&self) -> &sqlx::MySqlPool {
813        match &self.inner {
814            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => &conn.pool,
815            _ => panic!("Not MySQL Connection"),
816        }
817    }
818
819    /// Get [sqlx::PgPool]
820    ///
821    /// # Panics
822    ///
823    /// Panics if [DbConn] is not a Postgres connection.
824    #[cfg(feature = "sqlx-postgres")]
825    pub fn get_postgres_connection_pool(&self) -> &sqlx::PgPool {
826        match &self.inner {
827            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => &conn.pool,
828            _ => panic!("Not Postgres Connection"),
829        }
830    }
831
832    /// Get [sqlx::SqlitePool]
833    ///
834    /// # Panics
835    ///
836    /// Panics if [DbConn] is not a SQLite connection.
837    #[cfg(feature = "sqlx-sqlite")]
838    pub fn get_sqlite_connection_pool(&self) -> &sqlx::SqlitePool {
839        match &self.inner {
840            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => &conn.pool,
841            _ => panic!("Not SQLite Connection"),
842        }
843    }
844}
845
846impl DbBackend {
847    /// Check if the URI is the same as the specified database backend.
848    /// Returns true if they match.
849    ///
850    /// # Panics
851    ///
852    /// Panics if `base_url` cannot be parsed as `Url`.
853    pub fn is_prefix_of(self, base_url: &str) -> bool {
854        let base_url_parsed = Url::parse(base_url).expect("Fail to parse database URL");
855        match self {
856            Self::Postgres => {
857                base_url_parsed.scheme() == "postgres" || base_url_parsed.scheme() == "postgresql"
858            }
859            Self::MySql => base_url_parsed.scheme() == "mysql",
860            Self::Sqlite => base_url_parsed.scheme() == "sqlite",
861        }
862    }
863
864    /// Build an SQL [Statement]
865    pub fn build<S>(&self, statement: &S) -> Statement
866    where
867        S: StatementBuilder,
868    {
869        statement.build(self)
870    }
871
872    /// Check if the database supports `RETURNING` syntax on insert and update
873    pub fn support_returning(&self) -> bool {
874        match self {
875            Self::Postgres => true,
876            Self::Sqlite if cfg!(feature = "sqlite-use-returning-for-3_35") => true,
877            Self::MySql if cfg!(feature = "mariadb-use-returning") => true,
878            _ => false,
879        }
880    }
881
882    /// A getter for database dependent boolean value
883    pub fn boolean_value(&self, boolean: bool) -> sea_query::Value {
884        match self {
885            Self::MySql | Self::Postgres | Self::Sqlite => boolean.into(),
886        }
887    }
888
889    /// Get the display string for this enum
890    pub fn as_str(&self) -> &'static str {
891        match self {
892            DatabaseBackend::MySql => "MySql",
893            DatabaseBackend::Postgres => "Postgres",
894            DatabaseBackend::Sqlite => "Sqlite",
895        }
896    }
897}
898
#[cfg(test)]
mod tests {
    /// Compile-time check that `DatabaseConnection` is both `Send` and `Sync`.
    ///
    /// The check is skipped when the `sync` feature is enabled. The type is
    /// referenced by path inside the test, so that build does not end up with
    /// an unused module-level import.
    #[cfg(not(feature = "sync"))]
    #[test]
    fn assert_database_connection_traits() {
        // Instantiating this helper only compiles when `T: Send + Sync`.
        fn assert_send_sync<T: Send + Sync>() {}

        assert_send_sync::<crate::DatabaseConnection>();
    }
}