Skip to main content

sea_orm/database/
db_connection.rs

1use super::transaction::run_async_transaction_callback;
2use crate::{
3    AccessMode, ConnectionTrait, DatabaseTransaction, ExecResult, IsolationLevel, QueryResult,
4    Schema, SchemaBuilder, Statement, StatementBuilder, TransactionError, TransactionOptions,
5    TransactionTrait, error::*,
6};
7use std::{fmt::Debug, future::Future, pin::Pin};
8use tracing::instrument;
9use url::Url;
10
11#[cfg(feature = "sqlx-dep")]
12use sqlx::pool::PoolConnection;
13
14#[cfg(feature = "rusqlite")]
15use crate::driver::rusqlite::{RusqliteInnerConnection, RusqliteSharedConnection};
16
17#[cfg(feature = "stream")]
18use crate::StreamTrait;
19
20#[cfg(any(feature = "mock", feature = "proxy"))]
21use std::sync::Arc;
22
23/// A handle to a database — implements [`ConnectionTrait`](crate::ConnectionTrait)
24/// and [`TransactionTrait`](crate::TransactionTrait) so it works with every
25/// query and mutation method in SeaORM.
26///
27/// Behind the scenes this is a connection pool (for SQLx-backed drivers) or
28/// a shared connection (for `rusqlite` / mocks / proxies), so it is cheap
29/// to clone — pass `&DbConn` around or `db.clone()` into spawned tasks.
30/// Obtain one via [`Database::connect`](crate::Database::connect).
31#[derive(Debug, Clone)]
32#[non_exhaustive]
33pub struct DatabaseConnection {
34    /// Driver-specific connection or pool. Held in a field so we can attach
35    /// orthogonal state (e.g. RBAC) alongside.
36    pub inner: DatabaseConnectionType,
37    #[cfg(feature = "rbac")]
38    pub(crate) rbac: crate::RbacEngineMount,
39}
40
/// Driver-level connection (or pool) held inside a [`DatabaseConnection`].
///
/// The set of variants depends on which feature flags are enabled. Most
/// callers never need to match on this type; prefer the methods on
/// [`DatabaseConnection`].
#[derive(Clone)]
pub enum DatabaseConnectionType {
    /// Pooled MySQL connections (`sqlx-mysql`).
    #[cfg(feature = "sqlx-mysql")]
    SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),

    /// Pooled PostgreSQL connections (`sqlx-postgres`).
    #[cfg(feature = "sqlx-postgres")]
    SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),

    /// Pooled SQLite connections (`sqlx-sqlite`).
    #[cfg(feature = "sqlx-sqlite")]
    SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection),

    /// A SQLite connection shared between threads (`rusqlite`).
    #[cfg(feature = "rusqlite")]
    RusqliteSharedConnection(RusqliteSharedConnection),

    /// Mock connection for tests, kept behind an `Arc` (`mock`).
    #[cfg(feature = "mock")]
    MockDatabaseConnection(Arc<crate::MockDatabaseConnection>),

    /// Connection that hands every statement to a user-supplied proxy (`proxy`).
    #[cfg(feature = "proxy")]
    ProxyDatabaseConnection(Arc<crate::ProxyDatabaseConnection>),

    /// Placeholder used by `DatabaseConnection::default()`. Statements and
    /// transactions dispatched through it yield a "Disconnected" connection
    /// error, while [`DatabaseConnection::get_database_backend`] panics on it.
    Disconnected,
}
76
77/// Short alias for [`DatabaseConnection`].
78pub type DbConn = DatabaseConnection;
79
80impl Default for DatabaseConnection {
81    fn default() -> Self {
82        DatabaseConnectionType::Disconnected.into()
83    }
84}
85
86impl From<DatabaseConnectionType> for DatabaseConnection {
87    fn from(inner: DatabaseConnectionType) -> Self {
88        Self {
89            inner,
90            #[cfg(feature = "rbac")]
91            rbac: Default::default(),
92        }
93    }
94}
95
/// The SQL dialect a connection speaks.
///
/// It is passed around so that `sea_query`-built statements can be rendered
/// with the right placeholders, identifier quoting, and feature support.
///
/// All three variants are always defined, whatever features are enabled.
/// Whether a given backend can actually be connected to depends on the
/// driver feature flags — see the [crate-level documentation](crate).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum DatabaseBackend {
    /// MySQL / MariaDB.
    MySql,
    /// PostgreSQL.
    Postgres,
    /// SQLite.
    Sqlite,
}

/// Short alias for [`DatabaseBackend`].
pub type DbBackend = DatabaseBackend;
113
/// A single driver-level connection: for SQLx drivers this is one
/// `PoolConnection` taken from a pool, rather than the pool handle kept in
/// [`DatabaseConnectionType`].
///
/// NOTE(review): presumably owned by an open transaction — confirm against
/// the users of this type elsewhere in the crate.
#[derive(Debug)]
pub(crate) enum InnerConnection {
    /// A MySQL connection checked out of an SQLx pool.
    #[cfg(feature = "sqlx-mysql")]
    MySql(PoolConnection<sqlx::MySql>),
    /// A PostgreSQL connection checked out of an SQLx pool.
    #[cfg(feature = "sqlx-postgres")]
    Postgres(PoolConnection<sqlx::Postgres>),
    /// A SQLite connection checked out of an SQLx pool.
    #[cfg(feature = "sqlx-sqlite")]
    Sqlite(PoolConnection<sqlx::Sqlite>),
    /// The `rusqlite` inner connection.
    #[cfg(feature = "rusqlite")]
    Rusqlite(RusqliteInnerConnection),
    /// Shared handle to the mock connection.
    #[cfg(feature = "mock")]
    Mock(Arc<crate::MockDatabaseConnection>),
    /// Shared handle to the proxy connection.
    #[cfg(feature = "proxy")]
    Proxy(Arc<crate::ProxyDatabaseConnection>),
}
129
130impl Debug for DatabaseConnectionType {
131    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
132        write!(
133            f,
134            "{}",
135            match self {
136                #[cfg(feature = "sqlx-mysql")]
137                Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
138                #[cfg(feature = "sqlx-postgres")]
139                Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
140                #[cfg(feature = "sqlx-sqlite")]
141                Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection",
142                #[cfg(feature = "rusqlite")]
143                Self::RusqliteSharedConnection(_) => "RusqliteSharedConnection",
144                #[cfg(feature = "mock")]
145                Self::MockDatabaseConnection(_) => "MockDatabaseConnection",
146                #[cfg(feature = "proxy")]
147                Self::ProxyDatabaseConnection(_) => "ProxyDatabaseConnection",
148                Self::Disconnected => "Disconnected",
149            }
150        )
151    }
152}
153
154#[async_trait::async_trait]
155impl ConnectionTrait for DatabaseConnection {
156    fn get_database_backend(&self) -> DbBackend {
157        self.get_database_backend()
158    }
159
160    #[instrument(level = "trace", skip(stmt))]
161    #[allow(unused_variables)]
162    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
163        super::tracing_spans::with_db_span!(
164            "sea_orm.execute",
165            self.get_database_backend(),
166            stmt.sql.as_str(),
167            record_stmt = self.get_record_stmt_in_spans(),
168            async {
169                match &self.inner {
170                    #[cfg(feature = "sqlx-mysql")]
171                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
172                        conn.execute(stmt).await
173                    }
174                    #[cfg(feature = "sqlx-postgres")]
175                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
176                        conn.execute(stmt).await
177                    }
178                    #[cfg(feature = "sqlx-sqlite")]
179                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
180                        conn.execute(stmt).await
181                    }
182                    #[cfg(feature = "rusqlite")]
183                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.execute(stmt),
184                    #[cfg(feature = "mock")]
185                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.execute(stmt),
186                    #[cfg(feature = "proxy")]
187                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
188                        conn.execute(stmt).await
189                    }
190                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
191                }
192            }
193        )
194    }
195
196    #[instrument(level = "trace", skip(sql))]
197    #[allow(unused_variables)]
198    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
199        super::tracing_spans::with_db_span!(
200            "sea_orm.execute_unprepared",
201            self.get_database_backend(),
202            sql,
203            record_stmt = false,
204            async {
205                match &self.inner {
206                    #[cfg(feature = "sqlx-mysql")]
207                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
208                        conn.execute_unprepared(sql).await
209                    }
210                    #[cfg(feature = "sqlx-postgres")]
211                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
212                        conn.execute_unprepared(sql).await
213                    }
214                    #[cfg(feature = "sqlx-sqlite")]
215                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
216                        conn.execute_unprepared(sql).await
217                    }
218                    #[cfg(feature = "rusqlite")]
219                    DatabaseConnectionType::RusqliteSharedConnection(conn) => {
220                        conn.execute_unprepared(sql)
221                    }
222                    #[cfg(feature = "mock")]
223                    DatabaseConnectionType::MockDatabaseConnection(conn) => {
224                        let db_backend = conn.get_database_backend();
225                        let stmt = Statement::from_string(db_backend, sql);
226                        conn.execute(stmt)
227                    }
228                    #[cfg(feature = "proxy")]
229                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
230                        let db_backend = conn.get_database_backend();
231                        let stmt = Statement::from_string(db_backend, sql);
232                        conn.execute(stmt).await
233                    }
234                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
235                }
236            }
237        )
238    }
239
240    #[instrument(level = "trace", skip(stmt))]
241    #[allow(unused_variables)]
242    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
243        super::tracing_spans::with_db_span!(
244            "sea_orm.query_one",
245            self.get_database_backend(),
246            stmt.sql.as_str(),
247            record_stmt = self.get_record_stmt_in_spans(),
248            async {
249                match &self.inner {
250                    #[cfg(feature = "sqlx-mysql")]
251                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
252                        conn.query_one(stmt).await
253                    }
254                    #[cfg(feature = "sqlx-postgres")]
255                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
256                        conn.query_one(stmt).await
257                    }
258                    #[cfg(feature = "sqlx-sqlite")]
259                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
260                        conn.query_one(stmt).await
261                    }
262                    #[cfg(feature = "rusqlite")]
263                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_one(stmt),
264                    #[cfg(feature = "mock")]
265                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_one(stmt),
266                    #[cfg(feature = "proxy")]
267                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
268                        conn.query_one(stmt).await
269                    }
270                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
271                }
272            }
273        )
274    }
275
276    #[instrument(level = "trace", skip(stmt))]
277    #[allow(unused_variables)]
278    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
279        super::tracing_spans::with_db_span!(
280            "sea_orm.query_all",
281            self.get_database_backend(),
282            stmt.sql.as_str(),
283            record_stmt = self.get_record_stmt_in_spans(),
284            async {
285                match &self.inner {
286                    #[cfg(feature = "sqlx-mysql")]
287                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
288                        conn.query_all(stmt).await
289                    }
290                    #[cfg(feature = "sqlx-postgres")]
291                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
292                        conn.query_all(stmt).await
293                    }
294                    #[cfg(feature = "sqlx-sqlite")]
295                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
296                        conn.query_all(stmt).await
297                    }
298                    #[cfg(feature = "rusqlite")]
299                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_all(stmt),
300                    #[cfg(feature = "mock")]
301                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_all(stmt),
302                    #[cfg(feature = "proxy")]
303                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
304                        conn.query_all(stmt).await
305                    }
306                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
307                }
308            }
309        )
310    }
311
312    #[cfg(feature = "mock")]
313    fn is_mock_connection(&self) -> bool {
314        matches!(
315            self,
316            DatabaseConnection {
317                inner: DatabaseConnectionType::MockDatabaseConnection(_),
318                ..
319            }
320        )
321    }
322}
323
#[async_trait::async_trait]
#[cfg(feature = "stream")]
impl StreamTrait for DatabaseConnection {
    type Stream<'a> = crate::QueryStream;

    /// Delegates to the inherent `DatabaseConnection::get_database_backend`.
    fn get_database_backend(&self) -> DbBackend {
        self.get_database_backend()
    }

    /// Opens a row stream for `stmt`. SQLx drivers produce the stream
    /// asynchronously; `rusqlite` builds it directly; mock and proxy
    /// connections build a `QueryStream` from a cloned `Arc` handle and the
    /// statement.
    #[instrument(level = "trace", skip(stmt))]
    #[allow(unused_variables)]
    fn stream_raw<'a>(
        &'a self,
        stmt: Statement,
    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
        Box::pin(async move {
            let inner = &self.inner;
            match inner {
                #[cfg(feature = "sqlx-mysql")]
                DatabaseConnectionType::SqlxMySqlPoolConnection(pool) => pool.stream(stmt).await,
                #[cfg(feature = "sqlx-postgres")]
                DatabaseConnectionType::SqlxPostgresPoolConnection(pool) => {
                    pool.stream(stmt).await
                }
                #[cfg(feature = "sqlx-sqlite")]
                DatabaseConnectionType::SqlxSqlitePoolConnection(pool) => pool.stream(stmt).await,
                #[cfg(feature = "rusqlite")]
                DatabaseConnectionType::RusqliteSharedConnection(shared) => shared.stream(stmt),
                #[cfg(feature = "mock")]
                DatabaseConnectionType::MockDatabaseConnection(mock) => {
                    let handle = Arc::clone(mock);
                    Ok(crate::QueryStream::from((handle, stmt, None)))
                }
                #[cfg(feature = "proxy")]
                DatabaseConnectionType::ProxyDatabaseConnection(proxy) => {
                    let handle = Arc::clone(proxy);
                    Ok(crate::QueryStream::from((handle, stmt, None)))
                }
                DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
            }
        })
    }
}
362
363#[async_trait::async_trait]
364impl TransactionTrait for DatabaseConnection {
365    type Transaction = DatabaseTransaction;
366
367    #[instrument(level = "trace")]
368    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
369        match &self.inner {
370            #[cfg(feature = "sqlx-mysql")]
371            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.begin(None, None).await,
372            #[cfg(feature = "sqlx-postgres")]
373            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
374                conn.begin(None, None).await
375            }
376            #[cfg(feature = "sqlx-sqlite")]
377            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
378                conn.begin(None, None, None).await
379            }
380            #[cfg(feature = "rusqlite")]
381            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.begin(None, None, None),
382            #[cfg(feature = "mock")]
383            DatabaseConnectionType::MockDatabaseConnection(conn) => {
384                DatabaseTransaction::new_mock(Arc::clone(conn), None).await
385            }
386            #[cfg(feature = "proxy")]
387            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
388                DatabaseTransaction::new_proxy(conn.clone(), None).await
389            }
390            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
391        }
392    }
393
394    #[instrument(level = "trace")]
395    async fn begin_with_config(
396        &self,
397        _isolation_level: Option<IsolationLevel>,
398        _access_mode: Option<AccessMode>,
399    ) -> Result<DatabaseTransaction, DbErr> {
400        match &self.inner {
401            #[cfg(feature = "sqlx-mysql")]
402            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
403                conn.begin(_isolation_level, _access_mode).await
404            }
405            #[cfg(feature = "sqlx-postgres")]
406            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
407                conn.begin(_isolation_level, _access_mode).await
408            }
409            #[cfg(feature = "sqlx-sqlite")]
410            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
411                conn.begin(_isolation_level, _access_mode, None).await
412            }
413            #[cfg(feature = "rusqlite")]
414            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
415                conn.begin(_isolation_level, _access_mode, None)
416            }
417            #[cfg(feature = "mock")]
418            DatabaseConnectionType::MockDatabaseConnection(conn) => {
419                DatabaseTransaction::new_mock(Arc::clone(conn), None).await
420            }
421            #[cfg(feature = "proxy")]
422            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
423                DatabaseTransaction::new_proxy(conn.clone(), None).await
424            }
425            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
426        }
427    }
428
429    #[instrument(level = "trace")]
430    async fn begin_with_options(
431        &self,
432        TransactionOptions {
433            isolation_level: _isolation_level,
434            access_mode: _access_mode,
435            sqlite_transaction_mode: _sqlite_transaction_mode,
436        }: TransactionOptions,
437    ) -> Result<DatabaseTransaction, DbErr> {
438        match &self.inner {
439            #[cfg(feature = "sqlx-mysql")]
440            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
441                conn.begin(_isolation_level, _access_mode).await
442            }
443            #[cfg(feature = "sqlx-postgres")]
444            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
445                conn.begin(_isolation_level, _access_mode).await
446            }
447            #[cfg(feature = "sqlx-sqlite")]
448            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
449                conn.begin(_isolation_level, _access_mode, _sqlite_transaction_mode)
450                    .await
451            }
452            #[cfg(feature = "rusqlite")]
453            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
454                conn.begin(_isolation_level, _access_mode, _sqlite_transaction_mode)
455            }
456            #[cfg(feature = "mock")]
457            DatabaseConnectionType::MockDatabaseConnection(conn) => {
458                DatabaseTransaction::new_mock(Arc::clone(conn), None).await
459            }
460            #[cfg(feature = "proxy")]
461            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
462                DatabaseTransaction::new_proxy(conn.clone(), None).await
463            }
464            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
465        }
466    }
467
468    /// Execute the function inside a transaction.
469    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
470    #[instrument(level = "trace", skip(_callback))]
471    async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
472    where
473        F: for<'c> FnOnce(
474                &'c DatabaseTransaction,
475            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
476            + Send,
477        T: Send,
478        E: std::fmt::Display + std::fmt::Debug + Send,
479    {
480        match &self.inner {
481            #[cfg(feature = "sqlx-mysql")]
482            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
483                conn.transaction(_callback, None, None).await
484            }
485            #[cfg(feature = "sqlx-postgres")]
486            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
487                conn.transaction(_callback, None, None).await
488            }
489            #[cfg(feature = "sqlx-sqlite")]
490            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
491                conn.transaction(_callback, None, None).await
492            }
493            #[cfg(feature = "rusqlite")]
494            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
495                conn.transaction(_callback, None, None)
496            }
497            #[cfg(feature = "mock")]
498            DatabaseConnectionType::MockDatabaseConnection(conn) => {
499                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
500                    .await
501                    .map_err(TransactionError::Connection)?;
502                transaction.run(_callback).await
503            }
504            #[cfg(feature = "proxy")]
505            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
506                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
507                    .await
508                    .map_err(TransactionError::Connection)?;
509                transaction.run(_callback).await
510            }
511            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
512        }
513    }
514
515    /// Execute the function inside a transaction.
516    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
517    #[instrument(level = "trace", skip(_callback))]
518    async fn transaction_with_config<F, T, E>(
519        &self,
520        _callback: F,
521        _isolation_level: Option<IsolationLevel>,
522        _access_mode: Option<AccessMode>,
523    ) -> Result<T, TransactionError<E>>
524    where
525        F: for<'c> FnOnce(
526                &'c DatabaseTransaction,
527            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
528            + Send,
529        T: Send,
530        E: std::fmt::Display + std::fmt::Debug + Send,
531    {
532        match &self.inner {
533            #[cfg(feature = "sqlx-mysql")]
534            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
535                conn.transaction(_callback, _isolation_level, _access_mode)
536                    .await
537            }
538            #[cfg(feature = "sqlx-postgres")]
539            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
540                conn.transaction(_callback, _isolation_level, _access_mode)
541                    .await
542            }
543            #[cfg(feature = "sqlx-sqlite")]
544            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
545                conn.transaction(_callback, _isolation_level, _access_mode)
546                    .await
547            }
548            #[cfg(feature = "rusqlite")]
549            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
550                conn.transaction(_callback, _isolation_level, _access_mode)
551            }
552            #[cfg(feature = "mock")]
553            DatabaseConnectionType::MockDatabaseConnection(conn) => {
554                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
555                    .await
556                    .map_err(TransactionError::Connection)?;
557                transaction.run(_callback).await
558            }
559            #[cfg(feature = "proxy")]
560            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
561                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
562                    .await
563                    .map_err(TransactionError::Connection)?;
564                transaction.run(_callback).await
565            }
566            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
567        }
568    }
569}
570
#[cfg(feature = "mock")]
impl DatabaseConnection {
    /// Borrow the underlying [`MockDatabaseConnection`](crate::MockDatabaseConnection),
    /// e.g. to inspect it in tests.
    ///
    /// # Panics
    ///
    /// Panics if this [DbConn] is not a mock connection.
    pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection {
        match &self.inner {
            DatabaseConnectionType::MockDatabaseConnection(mock_conn) => mock_conn,
            _ => panic!("Not mock connection"),
        }
    }

    /// Consume the connection and drain the mocker's transaction log,
    /// returning it as a `Vec<`[`crate::Transaction`]`>`.
    ///
    /// # Panics
    ///
    /// - Panics if this [DbConn] is not a mock connection.
    /// - Panics if acquiring the mocker lock returns an error. Assuming
    ///   `get_mocker_mutex` hands out a `std::sync::Mutex`, that happens only
    ///   when the lock is poisoned. A lock held by another thread is waited
    ///   on, not a panic.
    pub fn into_transaction_log(self) -> Vec<crate::Transaction> {
        let mut mocker = self
            .as_mock_connection()
            .get_mocker_mutex()
            .lock()
            .expect("Fail to acquire mocker");
        mocker.drain_transaction_log()
    }
}
599
#[cfg(feature = "proxy")]
impl DatabaseConnection {
    /// Borrow the underlying [`ProxyDatabaseConnection`](crate::ProxyDatabaseConnection).
    ///
    /// # Panics
    ///
    /// Panics if this [DbConn] is not a proxy connection.
    pub fn as_proxy_connection(&self) -> &crate::ProxyDatabaseConnection {
        match &self.inner {
            DatabaseConnectionType::ProxyDatabaseConnection(proxy_conn) => proxy_conn,
            _ => panic!("Not proxy connection"),
        }
    }
}
614
#[cfg(feature = "rbac")]
impl DatabaseConnection {
    /// Loads RBAC data from this connection's own database and installs the
    /// resulting engine, replacing any engine already present.
    pub async fn load_rbac(&self) -> Result<(), DbErr> {
        let source: &DbConn = self;
        self.load_rbac_from(source).await
    }

    /// Loads RBAC data through `db` — which may point at a different
    /// database — and installs the resulting engine on this connection.
    pub async fn load_rbac_from(&self, db: &DbConn) -> Result<(), DbErr> {
        let loaded = crate::rbac::RbacEngine::load_from(db).await?;
        self.replace_rbac(loaded);
        Ok(())
    }

    /// Swaps in `engine` as this connection's RBAC engine.
    pub fn replace_rbac(&self, engine: crate::rbac::RbacEngine) {
        self.rbac.replace(engine);
    }

    /// Returns a connection whose access is checked for `user_id`.
    ///
    /// Fails with [`DbErr::RbacError`] if no RBAC engine has been set up.
    pub fn restricted_for(
        &self,
        user_id: crate::rbac::RbacUserId,
    ) -> Result<crate::RestrictedConnection, DbErr> {
        if !self.rbac.is_some() {
            return Err(DbErr::RbacError("engine not set up".into()));
        }
        Ok(crate::RestrictedConnection {
            user_id,
            conn: self.clone(),
        })
    }
}
651
652impl DatabaseConnection {
653    /// Execute the function inside a transaction.
654    /// If the function returns an error, the transaction will be rolled back.
655    /// Otherwise, the transaction will be committed.
656    #[instrument(level = "trace", skip(callback))]
657    pub async fn transaction_async<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
658    where
659        F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
660        T: Send,
661        E: std::fmt::Display + std::fmt::Debug + Send,
662    {
663        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
664        run_async_transaction_callback(transaction, callback).await
665    }
666
667    /// Execute the function inside a transaction with isolation level and/or access mode.
668    /// If the function returns an error, the transaction will be rolled back.
669    /// Otherwise, the transaction will be committed.
670    #[instrument(level = "trace", skip(callback))]
671    pub async fn transaction_with_config_async<F, T, E>(
672        &self,
673        callback: F,
674        isolation_level: Option<IsolationLevel>,
675        access_mode: Option<AccessMode>,
676    ) -> Result<T, TransactionError<E>>
677    where
678        F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
679        T: Send,
680        E: std::fmt::Display + std::fmt::Debug + Send,
681    {
682        let transaction = self
683            .begin_with_config(isolation_level, access_mode)
684            .await
685            .map_err(TransactionError::Connection)?;
686        run_async_transaction_callback(transaction, callback).await
687    }
688
689    #[allow(unused)]
690    pub(crate) fn get_record_stmt_in_spans(&self) -> bool {
691        match &self.inner {
692            #[cfg(feature = "sqlx-mysql")]
693            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.record_stmt_in_spans,
694            #[cfg(feature = "sqlx-postgres")]
695            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.record_stmt_in_spans,
696            #[cfg(feature = "sqlx-sqlite")]
697            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.record_stmt_in_spans,
698            #[cfg(feature = "rusqlite")]
699            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.record_stmt_in_spans,
700            DatabaseConnectionType::Disconnected => true,
701            #[cfg(feature = "mock")]
702            DatabaseConnectionType::MockDatabaseConnection(_) => true,
703            #[cfg(feature = "proxy")]
704            DatabaseConnectionType::ProxyDatabaseConnection(_) => true,
705        }
706    }
707
708    /// Get the database backend for this connection
709    ///
710    /// # Panics
711    ///
712    /// Panics if [DatabaseConnection] is `Disconnected`.
713    pub fn get_database_backend(&self) -> DbBackend {
714        match &self.inner {
715            #[cfg(feature = "sqlx-mysql")]
716            DatabaseConnectionType::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
717            #[cfg(feature = "sqlx-postgres")]
718            DatabaseConnectionType::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
719            #[cfg(feature = "sqlx-sqlite")]
720            DatabaseConnectionType::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite,
721            #[cfg(feature = "rusqlite")]
722            DatabaseConnectionType::RusqliteSharedConnection(_) => DbBackend::Sqlite,
723            #[cfg(feature = "mock")]
724            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.get_database_backend(),
725            #[cfg(feature = "proxy")]
726            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.get_database_backend(),
727            DatabaseConnectionType::Disconnected => panic!("Disconnected"),
728        }
729    }
730
731    /// Creates a [`SchemaBuilder`] for this backend
732    pub fn get_schema_builder(&self) -> SchemaBuilder {
733        Schema::new(self.get_database_backend()).builder()
734    }
735
/// Builds a schema for all the entities in the given module.
///
/// `prefix` selects which registered entities are included; the actual
/// selection is performed by [`crate::EntityRegistry::build_schema`].
///
/// # Panics
///
/// Panics if the connection is `Disconnected`, because the backend is
/// resolved via `get_database_backend`.
#[cfg(feature = "entity-registry")]
#[cfg_attr(docsrs, doc(cfg(feature = "entity-registry")))]
pub fn get_schema_registry(&self, prefix: &str) -> SchemaBuilder {
    let schema = Schema::new(self.get_database_backend());
    crate::EntityRegistry::build_schema(schema, prefix)
}
743
744    /// Sets a callback to metric this connection
745    pub fn set_metric_callback<F>(&mut self, _callback: F)
746    where
747        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
748    {
749        match &mut self.inner {
750            #[cfg(feature = "sqlx-mysql")]
751            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
752                conn.set_metric_callback(_callback)
753            }
754            #[cfg(feature = "sqlx-postgres")]
755            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
756                conn.set_metric_callback(_callback)
757            }
758            #[cfg(feature = "sqlx-sqlite")]
759            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
760                conn.set_metric_callback(_callback)
761            }
762            #[cfg(feature = "rusqlite")]
763            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
764                conn.set_metric_callback(_callback)
765            }
766            _ => {}
767        }
768    }
769
770    /// Checks if a connection to the database is still valid.
771    pub async fn ping(&self) -> Result<(), DbErr> {
772        match &self.inner {
773            #[cfg(feature = "sqlx-mysql")]
774            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.ping().await,
775            #[cfg(feature = "sqlx-postgres")]
776            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.ping().await,
777            #[cfg(feature = "sqlx-sqlite")]
778            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.ping().await,
779            #[cfg(feature = "rusqlite")]
780            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.ping(),
781            #[cfg(feature = "mock")]
782            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.ping(),
783            #[cfg(feature = "proxy")]
784            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.ping().await,
785            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
786        }
787    }
788
789    /// Explicitly close the database connection.
790    /// See [`Self::close_by_ref`] for usage with references.
791    pub async fn close(self) -> Result<(), DbErr> {
792        self.close_by_ref().await
793    }
794
795    /// Explicitly close the database connection
796    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
797        match &self.inner {
798            #[cfg(feature = "sqlx-mysql")]
799            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.close_by_ref().await,
800            #[cfg(feature = "sqlx-postgres")]
801            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.close_by_ref().await,
802            #[cfg(feature = "sqlx-sqlite")]
803            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.close_by_ref().await,
804            #[cfg(feature = "rusqlite")]
805            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.close_by_ref(),
806            #[cfg(feature = "mock")]
807            DatabaseConnectionType::MockDatabaseConnection(_) => {
808                // Nothing to cleanup, we just consume the `DatabaseConnection`
809                Ok(())
810            }
811            #[cfg(feature = "proxy")]
812            DatabaseConnectionType::ProxyDatabaseConnection(_) => {
813                // Nothing to cleanup, we just consume the `DatabaseConnection`
814                Ok(())
815            }
816            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
817        }
818    }
819}
820
821impl DatabaseConnection {
822    /// Get [sqlx::MySqlPool]
823    ///
824    /// # Panics
825    ///
826    /// Panics if [DbConn] is not a MySQL connection.
827    #[cfg(feature = "sqlx-mysql")]
828    pub fn get_mysql_connection_pool(&self) -> &sqlx::MySqlPool {
829        match &self.inner {
830            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => &conn.pool,
831            _ => panic!("Not MySQL Connection"),
832        }
833    }
834
835    /// Get [sqlx::PgPool]
836    ///
837    /// # Panics
838    ///
839    /// Panics if [DbConn] is not a Postgres connection.
840    #[cfg(feature = "sqlx-postgres")]
841    pub fn get_postgres_connection_pool(&self) -> &sqlx::PgPool {
842        match &self.inner {
843            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => &conn.pool,
844            _ => panic!("Not Postgres Connection"),
845        }
846    }
847
848    /// Get [sqlx::SqlitePool]
849    ///
850    /// # Panics
851    ///
852    /// Panics if [DbConn] is not a SQLite connection.
853    #[cfg(feature = "sqlx-sqlite")]
854    pub fn get_sqlite_connection_pool(&self) -> &sqlx::SqlitePool {
855        match &self.inner {
856            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => &conn.pool,
857            _ => panic!("Not SQLite Connection"),
858        }
859    }
860}
861
862impl DbBackend {
863    /// Check if the URI is the same as the specified database backend.
864    /// Returns true if they match.
865    ///
866    /// # Panics
867    ///
868    /// Panics if `base_url` cannot be parsed as `Url`.
869    pub fn is_prefix_of(self, base_url: &str) -> bool {
870        let base_url_parsed = Url::parse(base_url).expect("Fail to parse database URL");
871        match self {
872            Self::Postgres => {
873                base_url_parsed.scheme() == "postgres" || base_url_parsed.scheme() == "postgresql"
874            }
875            Self::MySql => base_url_parsed.scheme() == "mysql",
876            Self::Sqlite => base_url_parsed.scheme() == "sqlite",
877        }
878    }
879
880    /// Build an SQL [Statement]
881    pub fn build<S>(&self, statement: &S) -> Statement
882    where
883        S: StatementBuilder,
884    {
885        statement.build(self)
886    }
887
888    /// Check if the database supports `RETURNING` syntax on insert and update
889    pub fn support_returning(&self) -> bool {
890        match self {
891            Self::Postgres => true,
892            Self::Sqlite if cfg!(feature = "sqlite-use-returning-for-3_35") => true,
893            Self::MySql if cfg!(feature = "mariadb-use-returning") => true,
894            _ => false,
895        }
896    }
897
898    /// A getter for database dependent boolean value
899    pub fn boolean_value(&self, boolean: bool) -> sea_query::Value {
900        match self {
901            Self::MySql | Self::Postgres | Self::Sqlite => boolean.into(),
902        }
903    }
904
905    /// Get the display string for this enum
906    pub fn as_str(&self) -> &'static str {
907        match self {
908            DatabaseBackend::MySql => "MySql",
909            DatabaseBackend::Postgres => "Postgres",
910            DatabaseBackend::Sqlite => "Sqlite",
911        }
912    }
913}
914
#[cfg(test)]
mod tests {
    use crate::DatabaseConnection;

    /// Compile-time check that `DatabaseConnection` is `Send + Sync`.
    /// This check is skipped when the `sync` feature is enabled.
    #[cfg(not(feature = "sync"))]
    #[test]
    fn assert_database_connection_traits() {
        fn is_send_and_sync<T: Send + Sync>() {}

        is_send_and_sync::<DatabaseConnection>();
    }
}
922        fn assert_send_sync<T: Send + Sync>() {}
923
924        assert_send_sync::<DatabaseConnection>();
925    }
926}