sea_orm/database/db_connection.rs

1use crate::{
2    AccessMode, ConnectionTrait, DatabaseTransaction, ExecResult, IsolationLevel, QueryResult,
3    Schema, SchemaBuilder, Statement, StatementBuilder, StreamTrait, TransactionError,
4    TransactionTrait, error::*,
5};
6use std::{fmt::Debug, future::Future, pin::Pin};
7use tracing::instrument;
8use url::Url;
9
10#[cfg(feature = "sqlx-dep")]
11use sqlx::pool::PoolConnection;
12
13#[cfg(feature = "rusqlite")]
14use crate::driver::rusqlite::{RusqliteInnerConnection, RusqliteSharedConnection};
15
16#[cfg(any(feature = "mock", feature = "proxy"))]
17use std::sync::Arc;
18
19/// Handle a database connection depending on the backend enabled by the feature
20/// flags. This creates a connection pool internally (for SQLx connections),
21/// and so is cheap to clone.
22#[derive(Debug, Clone)]
23#[non_exhaustive]
24pub struct DatabaseConnection {
25    /// `DatabaseConnection` used to be a enum. Now it's moved into inner,
26    /// because we have to attach other contexts.
27    pub inner: DatabaseConnectionType,
28    #[cfg(feature = "rbac")]
29    pub(crate) rbac: crate::RbacEngineMount,
30}
31
/// The underlying database connection type.
///
/// Which variants exist depends on the enabled feature flags. `Disconnected`
/// is always available and is the state produced by
/// `DatabaseConnection::default()`.
#[derive(Clone)]
pub enum DatabaseConnectionType {
    /// MySQL database connection pool
    #[cfg(feature = "sqlx-mysql")]
    SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),

    /// PostgreSQL database connection pool
    #[cfg(feature = "sqlx-postgres")]
    SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),

    /// SQLite database connection pool
    #[cfg(feature = "sqlx-sqlite")]
    SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection),

    /// SQLite database connection shareable across threads
    #[cfg(feature = "rusqlite")]
    RusqliteSharedConnection(RusqliteSharedConnection),

    /// Mock database connection useful for testing
    #[cfg(feature = "mock")]
    MockDatabaseConnection(Arc<crate::MockDatabaseConnection>),

    /// Proxy database connection
    #[cfg(feature = "proxy")]
    ProxyDatabaseConnection(Arc<crate::ProxyDatabaseConnection>),

    /// The connection has never been established
    Disconnected,
}
62
63/// The same as a [DatabaseConnection]
64pub type DbConn = DatabaseConnection;
65
66impl Default for DatabaseConnection {
67    fn default() -> Self {
68        DatabaseConnectionType::Disconnected.into()
69    }
70}
71
72impl From<DatabaseConnectionType> for DatabaseConnection {
73    fn from(inner: DatabaseConnectionType) -> Self {
74        Self {
75            inner,
76            #[cfg(feature = "rbac")]
77            rbac: Default::default(),
78        }
79    }
80}
81
/// Kinds of real-world database backends.
///
/// Which of these can actually be connected to is controlled by the feature
/// flags described in the crate documentation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum DatabaseBackend {
    /// MySQL
    MySql,
    /// PostgreSQL
    Postgres,
    /// SQLite
    Sqlite,
}
94
95/// A shorthand for [DatabaseBackend].
96pub type DbBackend = DatabaseBackend;
97
/// Per-connection handle for each enabled backend: a pooled SQLx connection,
/// a rusqlite inner connection, or a shared mock/proxy connection.
#[derive(Debug)]
pub(crate) enum InnerConnection {
    /// Pooled SQLx MySQL connection
    #[cfg(feature = "sqlx-mysql")]
    MySql(PoolConnection<sqlx::MySql>),
    /// Pooled SQLx PostgreSQL connection
    #[cfg(feature = "sqlx-postgres")]
    Postgres(PoolConnection<sqlx::Postgres>),
    /// Pooled SQLx SQLite connection
    #[cfg(feature = "sqlx-sqlite")]
    Sqlite(PoolConnection<sqlx::Sqlite>),
    /// rusqlite connection
    #[cfg(feature = "rusqlite")]
    Rusqlite(RusqliteInnerConnection),
    /// Shared mock connection
    #[cfg(feature = "mock")]
    Mock(Arc<crate::MockDatabaseConnection>),
    /// Shared proxy connection
    #[cfg(feature = "proxy")]
    Proxy(Arc<crate::ProxyDatabaseConnection>),
}
113
114impl Debug for DatabaseConnectionType {
115    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
116        write!(
117            f,
118            "{}",
119            match self {
120                #[cfg(feature = "sqlx-mysql")]
121                Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
122                #[cfg(feature = "sqlx-postgres")]
123                Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
124                #[cfg(feature = "sqlx-sqlite")]
125                Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection",
126                #[cfg(feature = "rusqlite")]
127                Self::RusqliteSharedConnection(_) => "RusqliteSharedConnection",
128                #[cfg(feature = "mock")]
129                Self::MockDatabaseConnection(_) => "MockDatabaseConnection",
130                #[cfg(feature = "proxy")]
131                Self::ProxyDatabaseConnection(_) => "ProxyDatabaseConnection",
132                Self::Disconnected => "Disconnected",
133            }
134        )
135    }
136}
137
138#[async_trait::async_trait]
139impl ConnectionTrait for DatabaseConnection {
140    fn get_database_backend(&self) -> DbBackend {
141        self.get_database_backend()
142    }
143
144    #[instrument(level = "trace")]
145    #[allow(unused_variables)]
146    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
147        match &self.inner {
148            #[cfg(feature = "sqlx-mysql")]
149            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await,
150            #[cfg(feature = "sqlx-postgres")]
151            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await,
152            #[cfg(feature = "sqlx-sqlite")]
153            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await,
154            #[cfg(feature = "rusqlite")]
155            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.execute(stmt),
156            #[cfg(feature = "mock")]
157            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.execute(stmt),
158            #[cfg(feature = "proxy")]
159            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.execute(stmt).await,
160            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
161        }
162    }
163
164    #[instrument(level = "trace")]
165    #[allow(unused_variables)]
166    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
167        match &self.inner {
168            #[cfg(feature = "sqlx-mysql")]
169            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
170                conn.execute_unprepared(sql).await
171            }
172            #[cfg(feature = "sqlx-postgres")]
173            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
174                conn.execute_unprepared(sql).await
175            }
176            #[cfg(feature = "sqlx-sqlite")]
177            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
178                conn.execute_unprepared(sql).await
179            }
180            #[cfg(feature = "rusqlite")]
181            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.execute_unprepared(sql),
182            #[cfg(feature = "mock")]
183            DatabaseConnectionType::MockDatabaseConnection(conn) => {
184                let db_backend = conn.get_database_backend();
185                let stmt = Statement::from_string(db_backend, sql);
186                conn.execute(stmt)
187            }
188            #[cfg(feature = "proxy")]
189            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
190                let db_backend = conn.get_database_backend();
191                let stmt = Statement::from_string(db_backend, sql);
192                conn.execute(stmt).await
193            }
194            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
195        }
196    }
197
198    #[instrument(level = "trace")]
199    #[allow(unused_variables)]
200    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
201        match &self.inner {
202            #[cfg(feature = "sqlx-mysql")]
203            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await,
204            #[cfg(feature = "sqlx-postgres")]
205            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await,
206            #[cfg(feature = "sqlx-sqlite")]
207            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await,
208            #[cfg(feature = "rusqlite")]
209            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_one(stmt),
210            #[cfg(feature = "mock")]
211            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_one(stmt),
212            #[cfg(feature = "proxy")]
213            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.query_one(stmt).await,
214            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
215        }
216    }
217
218    #[instrument(level = "trace")]
219    #[allow(unused_variables)]
220    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
221        match &self.inner {
222            #[cfg(feature = "sqlx-mysql")]
223            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await,
224            #[cfg(feature = "sqlx-postgres")]
225            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await,
226            #[cfg(feature = "sqlx-sqlite")]
227            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await,
228            #[cfg(feature = "rusqlite")]
229            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_all(stmt),
230            #[cfg(feature = "mock")]
231            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_all(stmt),
232            #[cfg(feature = "proxy")]
233            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.query_all(stmt).await,
234            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
235        }
236    }
237
238    #[cfg(feature = "mock")]
239    fn is_mock_connection(&self) -> bool {
240        matches!(
241            self,
242            DatabaseConnection {
243                inner: DatabaseConnectionType::MockDatabaseConnection(_),
244                ..
245            }
246        )
247    }
248}
249
250#[async_trait::async_trait]
251impl StreamTrait for DatabaseConnection {
252    type Stream<'a> = crate::QueryStream;
253
254    fn get_database_backend(&self) -> DbBackend {
255        self.get_database_backend()
256    }
257
258    #[instrument(level = "trace")]
259    #[allow(unused_variables)]
260    fn stream_raw<'a>(
261        &'a self,
262        stmt: Statement,
263    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
264        Box::pin(async move {
265            match &self.inner {
266                #[cfg(feature = "sqlx-mysql")]
267                DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await,
268                #[cfg(feature = "sqlx-postgres")]
269                DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await,
270                #[cfg(feature = "sqlx-sqlite")]
271                DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await,
272                #[cfg(feature = "rusqlite")]
273                DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.stream(stmt),
274                #[cfg(feature = "mock")]
275                DatabaseConnectionType::MockDatabaseConnection(conn) => {
276                    Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
277                }
278                #[cfg(feature = "proxy")]
279                DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
280                    Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
281                }
282                DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
283            }
284        })
285    }
286}
287
288#[async_trait::async_trait]
289impl TransactionTrait for DatabaseConnection {
290    type Transaction = DatabaseTransaction;
291
292    #[instrument(level = "trace")]
293    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
294        match &self.inner {
295            #[cfg(feature = "sqlx-mysql")]
296            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.begin(None, None).await,
297            #[cfg(feature = "sqlx-postgres")]
298            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
299                conn.begin(None, None).await
300            }
301            #[cfg(feature = "sqlx-sqlite")]
302            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.begin(None, None).await,
303            #[cfg(feature = "rusqlite")]
304            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.begin(None, None),
305            #[cfg(feature = "mock")]
306            DatabaseConnectionType::MockDatabaseConnection(conn) => {
307                DatabaseTransaction::new_mock(Arc::clone(conn), None).await
308            }
309            #[cfg(feature = "proxy")]
310            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
311                DatabaseTransaction::new_proxy(conn.clone(), None).await
312            }
313            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
314        }
315    }
316
317    #[instrument(level = "trace")]
318    async fn begin_with_config(
319        &self,
320        _isolation_level: Option<IsolationLevel>,
321        _access_mode: Option<AccessMode>,
322    ) -> Result<DatabaseTransaction, DbErr> {
323        match &self.inner {
324            #[cfg(feature = "sqlx-mysql")]
325            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
326                conn.begin(_isolation_level, _access_mode).await
327            }
328            #[cfg(feature = "sqlx-postgres")]
329            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
330                conn.begin(_isolation_level, _access_mode).await
331            }
332            #[cfg(feature = "sqlx-sqlite")]
333            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
334                conn.begin(_isolation_level, _access_mode).await
335            }
336            #[cfg(feature = "rusqlite")]
337            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
338                conn.begin(_isolation_level, _access_mode)
339            }
340            #[cfg(feature = "mock")]
341            DatabaseConnectionType::MockDatabaseConnection(conn) => {
342                DatabaseTransaction::new_mock(Arc::clone(conn), None).await
343            }
344            #[cfg(feature = "proxy")]
345            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
346                DatabaseTransaction::new_proxy(conn.clone(), None).await
347            }
348            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
349        }
350    }
351
352    /// Execute the function inside a transaction.
353    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
354    #[instrument(level = "trace", skip(_callback))]
355    async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
356    where
357        F: for<'c> FnOnce(
358                &'c DatabaseTransaction,
359            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
360            + Send,
361        T: Send,
362        E: std::fmt::Display + std::fmt::Debug + Send,
363    {
364        match &self.inner {
365            #[cfg(feature = "sqlx-mysql")]
366            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
367                conn.transaction(_callback, None, None).await
368            }
369            #[cfg(feature = "sqlx-postgres")]
370            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
371                conn.transaction(_callback, None, None).await
372            }
373            #[cfg(feature = "sqlx-sqlite")]
374            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
375                conn.transaction(_callback, None, None).await
376            }
377            #[cfg(feature = "rusqlite")]
378            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
379                conn.transaction(_callback, None, None)
380            }
381            #[cfg(feature = "mock")]
382            DatabaseConnectionType::MockDatabaseConnection(conn) => {
383                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
384                    .await
385                    .map_err(TransactionError::Connection)?;
386                transaction.run(_callback).await
387            }
388            #[cfg(feature = "proxy")]
389            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
390                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
391                    .await
392                    .map_err(TransactionError::Connection)?;
393                transaction.run(_callback).await
394            }
395            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
396        }
397    }
398
399    /// Execute the function inside a transaction.
400    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
401    #[instrument(level = "trace", skip(_callback))]
402    async fn transaction_with_config<F, T, E>(
403        &self,
404        _callback: F,
405        _isolation_level: Option<IsolationLevel>,
406        _access_mode: Option<AccessMode>,
407    ) -> Result<T, TransactionError<E>>
408    where
409        F: for<'c> FnOnce(
410                &'c DatabaseTransaction,
411            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
412            + Send,
413        T: Send,
414        E: std::fmt::Display + std::fmt::Debug + Send,
415    {
416        match &self.inner {
417            #[cfg(feature = "sqlx-mysql")]
418            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
419                conn.transaction(_callback, _isolation_level, _access_mode)
420                    .await
421            }
422            #[cfg(feature = "sqlx-postgres")]
423            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
424                conn.transaction(_callback, _isolation_level, _access_mode)
425                    .await
426            }
427            #[cfg(feature = "sqlx-sqlite")]
428            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
429                conn.transaction(_callback, _isolation_level, _access_mode)
430                    .await
431            }
432            #[cfg(feature = "rusqlite")]
433            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
434                conn.transaction(_callback, _isolation_level, _access_mode)
435            }
436            #[cfg(feature = "mock")]
437            DatabaseConnectionType::MockDatabaseConnection(conn) => {
438                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
439                    .await
440                    .map_err(TransactionError::Connection)?;
441                transaction.run(_callback).await
442            }
443            #[cfg(feature = "proxy")]
444            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
445                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
446                    .await
447                    .map_err(TransactionError::Connection)?;
448                transaction.run(_callback).await
449            }
450            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
451        }
452    }
453}
454
#[cfg(feature = "mock")]
impl DatabaseConnection {
    /// Borrow the underlying [crate::MockDatabaseConnection], e.g. to inspect
    /// it in tests.
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a mock connection.
    pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection {
        match &self.inner {
            DatabaseConnectionType::MockDatabaseConnection(mock_conn) => mock_conn,
            _ => panic!("Not mock connection"),
        }
    }

    /// Consume this connection and return the mocker's transaction log as a
    /// collection Vec<[crate::Transaction]>, obtained by calling
    /// `drain_transaction_log` on the locked mocker.
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a mock connection, or if locking the mocker
    /// mutex returns an error (e.g. the mutex was poisoned by a thread that
    /// panicked while holding it). A lock merely held by another thread is
    /// waited for, not a panic.
    pub fn into_transaction_log(self) -> Vec<crate::Transaction> {
        let mut mocker = self
            .as_mock_connection()
            .get_mocker_mutex()
            .lock()
            .expect("Fail to acquire mocker");
        mocker.drain_transaction_log()
    }
}
483
#[cfg(feature = "proxy")]
impl DatabaseConnection {
    /// Borrow the underlying [crate::ProxyDatabaseConnection].
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a proxy connection.
    pub fn as_proxy_connection(&self) -> &crate::ProxyDatabaseConnection {
        match &self.inner {
            DatabaseConnectionType::ProxyDatabaseConnection(proxy_conn) => proxy_conn,
            _ => panic!("Not proxy connection"),
        }
    }
}
498
#[cfg(feature = "rbac")]
impl DatabaseConnection {
    /// Set up the RBAC engine from RBAC data stored in this connection's own
    /// database. An already-present engine is replaced.
    pub async fn load_rbac(&self) -> Result<(), DbErr> {
        self.load_rbac_from(self).await
    }

    /// Set up the RBAC engine from RBAC data read through `db`, which may be a
    /// different database than the one `self` talks to.
    pub async fn load_rbac_from(&self, db: &DbConn) -> Result<(), DbErr> {
        let loaded = crate::rbac::RbacEngine::load_from(db).await?;
        self.rbac.replace(loaded);
        Ok(())
    }

    /// Swap in `engine` as this connection's RBAC engine.
    pub fn replace_rbac(&self, engine: crate::rbac::RbacEngine) {
        self.rbac.replace(engine);
    }

    /// Create a connection whose access is restricted for `user_id`.
    ///
    /// # Errors
    ///
    /// Returns `DbErr::RbacError` if no RBAC engine has been set up.
    pub fn restricted_for(
        &self,
        user_id: crate::rbac::RbacUserId,
    ) -> Result<crate::RestrictedConnection, DbErr> {
        if self.rbac.is_some() {
            return Ok(crate::RestrictedConnection {
                user_id,
                conn: self.clone(),
            });
        }
        Err(DbErr::RbacError("engine not set up".into()))
    }
}
535
536impl DatabaseConnection {
537    /// Get the database backend for this connection
538    ///
539    /// # Panics
540    ///
541    /// Panics if [DatabaseConnection] is `Disconnected`.
542    pub fn get_database_backend(&self) -> DbBackend {
543        match &self.inner {
544            #[cfg(feature = "sqlx-mysql")]
545            DatabaseConnectionType::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
546            #[cfg(feature = "sqlx-postgres")]
547            DatabaseConnectionType::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
548            #[cfg(feature = "sqlx-sqlite")]
549            DatabaseConnectionType::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite,
550            #[cfg(feature = "rusqlite")]
551            DatabaseConnectionType::RusqliteSharedConnection(_) => DbBackend::Sqlite,
552            #[cfg(feature = "mock")]
553            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.get_database_backend(),
554            #[cfg(feature = "proxy")]
555            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.get_database_backend(),
556            DatabaseConnectionType::Disconnected => panic!("Disconnected"),
557        }
558    }
559
560    /// Creates a [`SchemaBuilder`] for this backend
561    pub fn get_schema_builder(&self) -> SchemaBuilder {
562        Schema::new(self.get_database_backend()).builder()
563    }
564
565    #[cfg(feature = "entity-registry")]
566    #[cfg_attr(docsrs, doc(cfg(feature = "entity-registry")))]
567    /// Builds a schema for all the entites in the given module
568    pub fn get_schema_registry(&self, prefix: &str) -> SchemaBuilder {
569        let schema = Schema::new(self.get_database_backend());
570        crate::EntityRegistry::build_schema(schema, prefix)
571    }
572
573    /// Sets a callback to metric this connection
574    pub fn set_metric_callback<F>(&mut self, _callback: F)
575    where
576        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
577    {
578        match &mut self.inner {
579            #[cfg(feature = "sqlx-mysql")]
580            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
581                conn.set_metric_callback(_callback)
582            }
583            #[cfg(feature = "sqlx-postgres")]
584            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
585                conn.set_metric_callback(_callback)
586            }
587            #[cfg(feature = "sqlx-sqlite")]
588            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
589                conn.set_metric_callback(_callback)
590            }
591            #[cfg(feature = "rusqlite")]
592            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
593                conn.set_metric_callback(_callback)
594            }
595            _ => {}
596        }
597    }
598
599    /// Checks if a connection to the database is still valid.
600    pub async fn ping(&self) -> Result<(), DbErr> {
601        match &self.inner {
602            #[cfg(feature = "sqlx-mysql")]
603            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.ping().await,
604            #[cfg(feature = "sqlx-postgres")]
605            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.ping().await,
606            #[cfg(feature = "sqlx-sqlite")]
607            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.ping().await,
608            #[cfg(feature = "rusqlite")]
609            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.ping(),
610            #[cfg(feature = "mock")]
611            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.ping(),
612            #[cfg(feature = "proxy")]
613            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.ping().await,
614            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
615        }
616    }
617
618    /// Explicitly close the database connection.
619    /// See [`Self::close_by_ref`] for usage with references.
620    pub async fn close(self) -> Result<(), DbErr> {
621        self.close_by_ref().await
622    }
623
624    /// Explicitly close the database connection
625    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
626        match &self.inner {
627            #[cfg(feature = "sqlx-mysql")]
628            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.close_by_ref().await,
629            #[cfg(feature = "sqlx-postgres")]
630            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.close_by_ref().await,
631            #[cfg(feature = "sqlx-sqlite")]
632            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.close_by_ref().await,
633            #[cfg(feature = "rusqlite")]
634            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.close_by_ref(),
635            #[cfg(feature = "mock")]
636            DatabaseConnectionType::MockDatabaseConnection(_) => {
637                // Nothing to cleanup, we just consume the `DatabaseConnection`
638                Ok(())
639            }
640            #[cfg(feature = "proxy")]
641            DatabaseConnectionType::ProxyDatabaseConnection(_) => {
642                // Nothing to cleanup, we just consume the `DatabaseConnection`
643                Ok(())
644            }
645            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
646        }
647    }
648}
649
650impl DatabaseConnection {
651    /// Get [sqlx::MySqlPool]
652    ///
653    /// # Panics
654    ///
655    /// Panics if [DbConn] is not a MySQL connection.
656    #[cfg(feature = "sqlx-mysql")]
657    pub fn get_mysql_connection_pool(&self) -> &sqlx::MySqlPool {
658        match &self.inner {
659            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => &conn.pool,
660            _ => panic!("Not MySQL Connection"),
661        }
662    }
663
664    /// Get [sqlx::PgPool]
665    ///
666    /// # Panics
667    ///
668    /// Panics if [DbConn] is not a Postgres connection.
669    #[cfg(feature = "sqlx-postgres")]
670    pub fn get_postgres_connection_pool(&self) -> &sqlx::PgPool {
671        match &self.inner {
672            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => &conn.pool,
673            _ => panic!("Not Postgres Connection"),
674        }
675    }
676
677    /// Get [sqlx::SqlitePool]
678    ///
679    /// # Panics
680    ///
681    /// Panics if [DbConn] is not a SQLite connection.
682    #[cfg(feature = "sqlx-sqlite")]
683    pub fn get_sqlite_connection_pool(&self) -> &sqlx::SqlitePool {
684        match &self.inner {
685            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => &conn.pool,
686            _ => panic!("Not SQLite Connection"),
687        }
688    }
689}
690
691impl DbBackend {
692    /// Check if the URI is the same as the specified database backend.
693    /// Returns true if they match.
694    ///
695    /// # Panics
696    ///
697    /// Panics if `base_url` cannot be parsed as `Url`.
698    pub fn is_prefix_of(self, base_url: &str) -> bool {
699        let base_url_parsed = Url::parse(base_url).expect("Fail to parse database URL");
700        match self {
701            Self::Postgres => {
702                base_url_parsed.scheme() == "postgres" || base_url_parsed.scheme() == "postgresql"
703            }
704            Self::MySql => base_url_parsed.scheme() == "mysql",
705            Self::Sqlite => base_url_parsed.scheme() == "sqlite",
706        }
707    }
708
709    /// Build an SQL [Statement]
710    pub fn build<S>(&self, statement: &S) -> Statement
711    where
712        S: StatementBuilder,
713    {
714        statement.build(self)
715    }
716
717    /// Check if the database supports `RETURNING` syntax on insert and update
718    pub fn support_returning(&self) -> bool {
719        match self {
720            Self::Postgres => true,
721            Self::Sqlite if cfg!(feature = "sqlite-use-returning-for-3_35") => true,
722            Self::MySql if cfg!(feature = "mariadb-use-returning") => true,
723            _ => false,
724        }
725    }
726
727    /// A getter for database dependent boolean value
728    pub fn boolean_value(&self, boolean: bool) -> sea_query::Value {
729        match self {
730            Self::MySql | Self::Postgres | Self::Sqlite => boolean.into(),
731        }
732    }
733
734    /// Get the display string for this enum
735    pub fn as_str(&self) -> &'static str {
736        match self {
737            DatabaseBackend::MySql => "MySql",
738            DatabaseBackend::Postgres => "Postgres",
739            DatabaseBackend::Sqlite => "Sqlite",
740        }
741    }
742}
743
#[cfg(test)]
mod tests {
    use crate::DatabaseConnection;

    /// `DatabaseConnection` must be usable across threads; this check is
    /// compiled out when the `sync` feature is enabled.
    #[cfg(not(feature = "sync"))]
    #[test]
    fn assert_database_connection_traits() {
        fn require_send_and_sync<T: Send + Sync>() {}

        require_send_and_sync::<DatabaseConnection>();
    }
}