Skip to main content

sea_orm/database/
db_connection.rs

1use crate::{
2    AccessMode, ConnectionTrait, DatabaseTransaction, ExecResult, IsolationLevel, QueryResult,
3    Schema, SchemaBuilder, Statement, StatementBuilder, StreamTrait, TransactionError,
4    TransactionTrait, error::*,
5};
6use std::{fmt::Debug, future::Future, pin::Pin};
7use tracing::instrument;
8use url::Url;
9
10#[cfg(feature = "sqlx-dep")]
11use sqlx::pool::PoolConnection;
12
13#[cfg(feature = "rusqlite")]
14use crate::driver::rusqlite::{RusqliteInnerConnection, RusqliteSharedConnection};
15
16#[cfg(any(feature = "mock", feature = "proxy"))]
17use std::sync::Arc;
18
19/// Handle a database connection depending on the backend enabled by the feature
20/// flags. This creates a connection pool internally (for SQLx connections),
21/// and so is cheap to clone.
22#[derive(Debug, Clone)]
23#[non_exhaustive]
24pub struct DatabaseConnection {
25    /// `DatabaseConnection` used to be a enum. Now it's moved into inner,
26    /// because we have to attach other contexts.
27    pub inner: DatabaseConnectionType,
28    #[cfg(feature = "rbac")]
29    pub(crate) rbac: crate::RbacEngineMount,
30}
31
/// The underlying database connection type.
///
/// Which variants exist depends on the enabled feature flags. `Disconnected`
/// is always present and is what `DatabaseConnection::default()` produces.
/// `Debug` is implemented by hand and prints only the variant name, never the
/// wrapped pool or connection.
#[derive(Clone)]
pub enum DatabaseConnectionType {
    /// MySql database connection pool
    #[cfg(feature = "sqlx-mysql")]
    SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),

    /// PostgreSQL database connection pool
    #[cfg(feature = "sqlx-postgres")]
    SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),

    /// SQLite database connection pool
    #[cfg(feature = "sqlx-sqlite")]
    SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection),

    /// SQLite database connection sharable across threads
    #[cfg(feature = "rusqlite")]
    RusqliteSharedConnection(RusqliteSharedConnection),

    /// Mock database connection useful for testing
    /// (held in an `Arc`, so clones share the same mock)
    #[cfg(feature = "mock")]
    MockDatabaseConnection(Arc<crate::MockDatabaseConnection>),

    /// Proxy database connection
    /// (held in an `Arc`, so clones share the same proxy)
    #[cfg(feature = "proxy")]
    ProxyDatabaseConnection(Arc<crate::ProxyDatabaseConnection>),

    /// The connection has never been established.
    /// Most operations on it return a "Disconnected" connection error;
    /// `get_database_backend` panics instead.
    Disconnected,
}
62
/// Shorthand alias for [DatabaseConnection]; the two names are interchangeable.
pub type DbConn = DatabaseConnection;
65
66impl Default for DatabaseConnection {
67    fn default() -> Self {
68        DatabaseConnectionType::Disconnected.into()
69    }
70}
71
72impl From<DatabaseConnectionType> for DatabaseConnection {
73    fn from(inner: DatabaseConnectionType) -> Self {
74        Self {
75            inner,
76            #[cfg(feature = "rbac")]
77            rbac: Default::default(),
78        }
79    }
80}
81
/// The type of database backend for real world databases.
///
/// All three variants are always present, regardless of which driver features
/// are enabled. Because the enum is `#[non_exhaustive]`, code outside this
/// crate must include a wildcard arm when matching on it.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum DatabaseBackend {
    /// A MySQL backend
    MySql,
    /// A PostgreSQL backend
    Postgres,
    /// A SQLite backend
    Sqlite,
}

/// A shorthand for [DatabaseBackend].
pub type DbBackend = DatabaseBackend;
97
/// Crate-internal handle to a single backend connection.
///
/// The SQLx variants hold a `PoolConnection` checked out of a pool. The mock
/// and proxy variants hold an `Arc` to the shared connection object.
/// NOTE(review): presumably used by transactions/streams elsewhere in the
/// crate; confirm against its users.
#[derive(Debug)]
pub(crate) enum InnerConnection {
    #[cfg(feature = "sqlx-mysql")]
    MySql(PoolConnection<sqlx::MySql>),
    #[cfg(feature = "sqlx-postgres")]
    Postgres(PoolConnection<sqlx::Postgres>),
    #[cfg(feature = "sqlx-sqlite")]
    Sqlite(PoolConnection<sqlx::Sqlite>),
    #[cfg(feature = "rusqlite")]
    Rusqlite(RusqliteInnerConnection),
    #[cfg(feature = "mock")]
    Mock(Arc<crate::MockDatabaseConnection>),
    #[cfg(feature = "proxy")]
    Proxy(Arc<crate::ProxyDatabaseConnection>),
}
113
114impl Debug for DatabaseConnectionType {
115    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
116        write!(
117            f,
118            "{}",
119            match self {
120                #[cfg(feature = "sqlx-mysql")]
121                Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
122                #[cfg(feature = "sqlx-postgres")]
123                Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
124                #[cfg(feature = "sqlx-sqlite")]
125                Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection",
126                #[cfg(feature = "rusqlite")]
127                Self::RusqliteSharedConnection(_) => "RusqliteSharedConnection",
128                #[cfg(feature = "mock")]
129                Self::MockDatabaseConnection(_) => "MockDatabaseConnection",
130                #[cfg(feature = "proxy")]
131                Self::ProxyDatabaseConnection(_) => "ProxyDatabaseConnection",
132                Self::Disconnected => "Disconnected",
133            }
134        )
135    }
136}
137
138#[async_trait::async_trait]
139impl ConnectionTrait for DatabaseConnection {
140    fn get_database_backend(&self) -> DbBackend {
141        self.get_database_backend()
142    }
143
144    #[instrument(level = "trace")]
145    #[allow(unused_variables)]
146    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
147        super::tracing_spans::with_db_span!(
148            "sea_orm.execute",
149            self.get_database_backend(),
150            stmt.sql.as_str(),
151            record_stmt = true,
152            async {
153                match &self.inner {
154                    #[cfg(feature = "sqlx-mysql")]
155                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
156                        conn.execute(stmt).await
157                    }
158                    #[cfg(feature = "sqlx-postgres")]
159                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
160                        conn.execute(stmt).await
161                    }
162                    #[cfg(feature = "sqlx-sqlite")]
163                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
164                        conn.execute(stmt).await
165                    }
166                    #[cfg(feature = "rusqlite")]
167                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.execute(stmt),
168                    #[cfg(feature = "mock")]
169                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.execute(stmt),
170                    #[cfg(feature = "proxy")]
171                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
172                        conn.execute(stmt).await
173                    }
174                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
175                }
176            }
177        )
178    }
179
180    #[instrument(level = "trace")]
181    #[allow(unused_variables)]
182    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
183        super::tracing_spans::with_db_span!(
184            "sea_orm.execute_unprepared",
185            self.get_database_backend(),
186            sql,
187            record_stmt = false,
188            async {
189                match &self.inner {
190                    #[cfg(feature = "sqlx-mysql")]
191                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
192                        conn.execute_unprepared(sql).await
193                    }
194                    #[cfg(feature = "sqlx-postgres")]
195                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
196                        conn.execute_unprepared(sql).await
197                    }
198                    #[cfg(feature = "sqlx-sqlite")]
199                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
200                        conn.execute_unprepared(sql).await
201                    }
202                    #[cfg(feature = "rusqlite")]
203                    DatabaseConnectionType::RusqliteSharedConnection(conn) => {
204                        conn.execute_unprepared(sql)
205                    }
206                    #[cfg(feature = "mock")]
207                    DatabaseConnectionType::MockDatabaseConnection(conn) => {
208                        let db_backend = conn.get_database_backend();
209                        let stmt = Statement::from_string(db_backend, sql);
210                        conn.execute(stmt)
211                    }
212                    #[cfg(feature = "proxy")]
213                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
214                        let db_backend = conn.get_database_backend();
215                        let stmt = Statement::from_string(db_backend, sql);
216                        conn.execute(stmt).await
217                    }
218                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
219                }
220            }
221        )
222    }
223
224    #[instrument(level = "trace")]
225    #[allow(unused_variables)]
226    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
227        super::tracing_spans::with_db_span!(
228            "sea_orm.query_one",
229            self.get_database_backend(),
230            stmt.sql.as_str(),
231            record_stmt = true,
232            async {
233                match &self.inner {
234                    #[cfg(feature = "sqlx-mysql")]
235                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
236                        conn.query_one(stmt).await
237                    }
238                    #[cfg(feature = "sqlx-postgres")]
239                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
240                        conn.query_one(stmt).await
241                    }
242                    #[cfg(feature = "sqlx-sqlite")]
243                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
244                        conn.query_one(stmt).await
245                    }
246                    #[cfg(feature = "rusqlite")]
247                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_one(stmt),
248                    #[cfg(feature = "mock")]
249                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_one(stmt),
250                    #[cfg(feature = "proxy")]
251                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
252                        conn.query_one(stmt).await
253                    }
254                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
255                }
256            }
257        )
258    }
259
260    #[instrument(level = "trace")]
261    #[allow(unused_variables)]
262    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
263        super::tracing_spans::with_db_span!(
264            "sea_orm.query_all",
265            self.get_database_backend(),
266            stmt.sql.as_str(),
267            record_stmt = true,
268            async {
269                match &self.inner {
270                    #[cfg(feature = "sqlx-mysql")]
271                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
272                        conn.query_all(stmt).await
273                    }
274                    #[cfg(feature = "sqlx-postgres")]
275                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
276                        conn.query_all(stmt).await
277                    }
278                    #[cfg(feature = "sqlx-sqlite")]
279                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
280                        conn.query_all(stmt).await
281                    }
282                    #[cfg(feature = "rusqlite")]
283                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_all(stmt),
284                    #[cfg(feature = "mock")]
285                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_all(stmt),
286                    #[cfg(feature = "proxy")]
287                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
288                        conn.query_all(stmt).await
289                    }
290                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
291                }
292            }
293        )
294    }
295
296    #[cfg(feature = "mock")]
297    fn is_mock_connection(&self) -> bool {
298        matches!(
299            self,
300            DatabaseConnection {
301                inner: DatabaseConnectionType::MockDatabaseConnection(_),
302                ..
303            }
304        )
305    }
306}
307
308#[async_trait::async_trait]
309impl StreamTrait for DatabaseConnection {
310    type Stream<'a> = crate::QueryStream;
311
312    fn get_database_backend(&self) -> DbBackend {
313        self.get_database_backend()
314    }
315
316    #[instrument(level = "trace")]
317    #[allow(unused_variables)]
318    fn stream_raw<'a>(
319        &'a self,
320        stmt: Statement,
321    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
322        Box::pin(async move {
323            match &self.inner {
324                #[cfg(feature = "sqlx-mysql")]
325                DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await,
326                #[cfg(feature = "sqlx-postgres")]
327                DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await,
328                #[cfg(feature = "sqlx-sqlite")]
329                DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await,
330                #[cfg(feature = "rusqlite")]
331                DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.stream(stmt),
332                #[cfg(feature = "mock")]
333                DatabaseConnectionType::MockDatabaseConnection(conn) => {
334                    Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
335                }
336                #[cfg(feature = "proxy")]
337                DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
338                    Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
339                }
340                DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
341            }
342        })
343    }
344}
345
/// Transaction support for every connection variant.
///
/// Each method dispatches on `self.inner`. A `Disconnected` connection
/// yields a "Disconnected" connection error.
#[async_trait::async_trait]
impl TransactionTrait for DatabaseConnection {
    type Transaction = DatabaseTransaction;

    /// Begin a transaction with the backend's default isolation level and
    /// access mode (both passed as `None`).
    #[instrument(level = "trace")]
    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
        match &self.inner {
            #[cfg(feature = "sqlx-mysql")]
            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.begin(None, None).await,
            #[cfg(feature = "sqlx-postgres")]
            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
                conn.begin(None, None).await
            }
            #[cfg(feature = "sqlx-sqlite")]
            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.begin(None, None).await,
            #[cfg(feature = "rusqlite")]
            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.begin(None, None),
            #[cfg(feature = "mock")]
            DatabaseConnectionType::MockDatabaseConnection(conn) => {
                DatabaseTransaction::new_mock(Arc::clone(conn), None).await
            }
            #[cfg(feature = "proxy")]
            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
                DatabaseTransaction::new_proxy(conn.clone(), None).await
            }
            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
        }
    }

    /// Begin a transaction with an explicit isolation level and/or access mode.
    ///
    /// The parameters are underscore-prefixed, presumably so that builds with
    /// no backend feature (only the `Disconnected` arm) do not warn about
    /// unused parameters.
    #[instrument(level = "trace")]
    async fn begin_with_config(
        &self,
        _isolation_level: Option<IsolationLevel>,
        _access_mode: Option<AccessMode>,
    ) -> Result<DatabaseTransaction, DbErr> {
        match &self.inner {
            #[cfg(feature = "sqlx-mysql")]
            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
                conn.begin(_isolation_level, _access_mode).await
            }
            #[cfg(feature = "sqlx-postgres")]
            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
                conn.begin(_isolation_level, _access_mode).await
            }
            #[cfg(feature = "sqlx-sqlite")]
            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
                conn.begin(_isolation_level, _access_mode).await
            }
            #[cfg(feature = "rusqlite")]
            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
                conn.begin(_isolation_level, _access_mode)
            }
            // NOTE(review): the mock and proxy arms do not forward the isolation
            // level or access mode; confirm this is intended.
            #[cfg(feature = "mock")]
            DatabaseConnectionType::MockDatabaseConnection(conn) => {
                DatabaseTransaction::new_mock(Arc::clone(conn), None).await
            }
            #[cfg(feature = "proxy")]
            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
                DatabaseTransaction::new_proxy(conn.clone(), None).await
            }
            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
        }
    }

    /// Execute the function inside a transaction.
    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
    ///
    /// A failure to open the mock/proxy transaction is reported as
    /// `TransactionError::Connection`.
    #[instrument(level = "trace", skip(_callback))]
    async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
    where
        F: for<'c> FnOnce(
                &'c DatabaseTransaction,
            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
            + Send,
        T: Send,
        E: std::fmt::Display + std::fmt::Debug + Send,
    {
        match &self.inner {
            #[cfg(feature = "sqlx-mysql")]
            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
                conn.transaction(_callback, None, None).await
            }
            #[cfg(feature = "sqlx-postgres")]
            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
                conn.transaction(_callback, None, None).await
            }
            #[cfg(feature = "sqlx-sqlite")]
            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
                conn.transaction(_callback, None, None).await
            }
            #[cfg(feature = "rusqlite")]
            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
                conn.transaction(_callback, None, None)
            }
            // Mock/proxy: open the transaction here, then let `run` drive the
            // callback.
            #[cfg(feature = "mock")]
            DatabaseConnectionType::MockDatabaseConnection(conn) => {
                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
                    .await
                    .map_err(TransactionError::Connection)?;
                transaction.run(_callback).await
            }
            #[cfg(feature = "proxy")]
            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
                    .await
                    .map_err(TransactionError::Connection)?;
                transaction.run(_callback).await
            }
            // `.into()` converts the `DbErr` into a `TransactionError<E>`.
            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
        }
    }

    /// Execute the function inside a transaction.
    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
    ///
    /// `_isolation_level` and `_access_mode` are forwarded to the SQLx and
    /// rusqlite backends only.
    #[instrument(level = "trace", skip(_callback))]
    async fn transaction_with_config<F, T, E>(
        &self,
        _callback: F,
        _isolation_level: Option<IsolationLevel>,
        _access_mode: Option<AccessMode>,
    ) -> Result<T, TransactionError<E>>
    where
        F: for<'c> FnOnce(
                &'c DatabaseTransaction,
            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
            + Send,
        T: Send,
        E: std::fmt::Display + std::fmt::Debug + Send,
    {
        match &self.inner {
            #[cfg(feature = "sqlx-mysql")]
            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
                conn.transaction(_callback, _isolation_level, _access_mode)
                    .await
            }
            #[cfg(feature = "sqlx-postgres")]
            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
                conn.transaction(_callback, _isolation_level, _access_mode)
                    .await
            }
            #[cfg(feature = "sqlx-sqlite")]
            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
                conn.transaction(_callback, _isolation_level, _access_mode)
                    .await
            }
            #[cfg(feature = "rusqlite")]
            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
                conn.transaction(_callback, _isolation_level, _access_mode)
            }
            // NOTE(review): the isolation level and access mode are not passed
            // to `new_mock`/`new_proxy`; confirm this is intended.
            #[cfg(feature = "mock")]
            DatabaseConnectionType::MockDatabaseConnection(conn) => {
                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
                    .await
                    .map_err(TransactionError::Connection)?;
                transaction.run(_callback).await
            }
            #[cfg(feature = "proxy")]
            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
                    .await
                    .map_err(TransactionError::Connection)?;
                transaction.run(_callback).await
            }
            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
        }
    }
}
512
#[cfg(feature = "mock")]
impl DatabaseConnection {
    /// Borrows the underlying [`crate::MockDatabaseConnection`] for test
    /// assertions.
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a mock connection.
    pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection {
        let DatabaseConnectionType::MockDatabaseConnection(mock) = &self.inner else {
            panic!("Not mock connection")
        };
        mock
    }

    /// Consumes the connection and drains the recorded transaction log as a
    /// `Vec<[crate::Transaction]>`.
    ///
    /// # Panics
    ///
    /// Panics if this is not a mock connection, or if locking the mocker
    /// mutex fails.
    pub fn into_transaction_log(self) -> Vec<crate::Transaction> {
        let mock = self.as_mock_connection();
        let mut guard = mock
            .get_mocker_mutex()
            .lock()
            .expect("Fail to acquire mocker");
        guard.drain_transaction_log()
    }
}
541
#[cfg(feature = "proxy")]
impl DatabaseConnection {
    /// Borrows the underlying [`crate::ProxyDatabaseConnection`].
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a proxy connection.
    pub fn as_proxy_connection(&self) -> &crate::ProxyDatabaseConnection {
        let DatabaseConnectionType::ProxyDatabaseConnection(proxy) = &self.inner else {
            panic!("Not proxy connection")
        };
        proxy
    }
}
556
#[cfg(feature = "rbac")]
impl DatabaseConnection {
    /// Loads RBAC data from this connection's own database and installs the
    /// resulting engine, replacing any engine already installed.
    pub async fn load_rbac(&self) -> Result<(), DbErr> {
        let source: &DbConn = self;
        self.load_rbac_from(source).await
    }

    /// Loads RBAC data through `db`, which may point at a different database,
    /// and installs the resulting engine on `self`.
    ///
    /// # Errors
    ///
    /// Propagates any error from loading the engine. In that case the
    /// currently installed engine is left untouched.
    pub async fn load_rbac_from(&self, db: &DbConn) -> Result<(), DbErr> {
        let loaded_engine = crate::rbac::RbacEngine::load_from(db).await?;
        self.replace_rbac(loaded_engine);
        Ok(())
    }

    /// Swaps in `engine` as this connection's RBAC engine.
    pub fn replace_rbac(&self, engine: crate::rbac::RbacEngine) {
        self.rbac.replace(engine);
    }

    /// Returns a clone of this connection restricted to `user_id`'s permissions.
    ///
    /// # Errors
    ///
    /// Returns `DbErr::RbacError` if no RBAC engine has been set up yet.
    pub fn restricted_for(
        &self,
        user_id: crate::rbac::RbacUserId,
    ) -> Result<crate::RestrictedConnection, DbErr> {
        if !self.rbac.is_some() {
            return Err(DbErr::RbacError("engine not set up".into()));
        }
        Ok(crate::RestrictedConnection {
            user_id,
            conn: self.clone(),
        })
    }
}
593
594impl DatabaseConnection {
595    /// Get the database backend for this connection
596    ///
597    /// # Panics
598    ///
599    /// Panics if [DatabaseConnection] is `Disconnected`.
600    pub fn get_database_backend(&self) -> DbBackend {
601        match &self.inner {
602            #[cfg(feature = "sqlx-mysql")]
603            DatabaseConnectionType::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
604            #[cfg(feature = "sqlx-postgres")]
605            DatabaseConnectionType::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
606            #[cfg(feature = "sqlx-sqlite")]
607            DatabaseConnectionType::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite,
608            #[cfg(feature = "rusqlite")]
609            DatabaseConnectionType::RusqliteSharedConnection(_) => DbBackend::Sqlite,
610            #[cfg(feature = "mock")]
611            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.get_database_backend(),
612            #[cfg(feature = "proxy")]
613            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.get_database_backend(),
614            DatabaseConnectionType::Disconnected => panic!("Disconnected"),
615        }
616    }
617
618    /// Creates a [`SchemaBuilder`] for this backend
619    pub fn get_schema_builder(&self) -> SchemaBuilder {
620        Schema::new(self.get_database_backend()).builder()
621    }
622
623    #[cfg(feature = "entity-registry")]
624    #[cfg_attr(docsrs, doc(cfg(feature = "entity-registry")))]
625    /// Builds a schema for all the entites in the given module
626    pub fn get_schema_registry(&self, prefix: &str) -> SchemaBuilder {
627        let schema = Schema::new(self.get_database_backend());
628        crate::EntityRegistry::build_schema(schema, prefix)
629    }
630
631    /// Sets a callback to metric this connection
632    pub fn set_metric_callback<F>(&mut self, _callback: F)
633    where
634        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
635    {
636        match &mut self.inner {
637            #[cfg(feature = "sqlx-mysql")]
638            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
639                conn.set_metric_callback(_callback)
640            }
641            #[cfg(feature = "sqlx-postgres")]
642            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
643                conn.set_metric_callback(_callback)
644            }
645            #[cfg(feature = "sqlx-sqlite")]
646            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
647                conn.set_metric_callback(_callback)
648            }
649            #[cfg(feature = "rusqlite")]
650            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
651                conn.set_metric_callback(_callback)
652            }
653            _ => {}
654        }
655    }
656
657    /// Checks if a connection to the database is still valid.
658    pub async fn ping(&self) -> Result<(), DbErr> {
659        match &self.inner {
660            #[cfg(feature = "sqlx-mysql")]
661            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.ping().await,
662            #[cfg(feature = "sqlx-postgres")]
663            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.ping().await,
664            #[cfg(feature = "sqlx-sqlite")]
665            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.ping().await,
666            #[cfg(feature = "rusqlite")]
667            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.ping(),
668            #[cfg(feature = "mock")]
669            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.ping(),
670            #[cfg(feature = "proxy")]
671            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.ping().await,
672            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
673        }
674    }
675
676    /// Explicitly close the database connection.
677    /// See [`Self::close_by_ref`] for usage with references.
678    pub async fn close(self) -> Result<(), DbErr> {
679        self.close_by_ref().await
680    }
681
682    /// Explicitly close the database connection
683    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
684        match &self.inner {
685            #[cfg(feature = "sqlx-mysql")]
686            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.close_by_ref().await,
687            #[cfg(feature = "sqlx-postgres")]
688            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.close_by_ref().await,
689            #[cfg(feature = "sqlx-sqlite")]
690            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.close_by_ref().await,
691            #[cfg(feature = "rusqlite")]
692            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.close_by_ref(),
693            #[cfg(feature = "mock")]
694            DatabaseConnectionType::MockDatabaseConnection(_) => {
695                // Nothing to cleanup, we just consume the `DatabaseConnection`
696                Ok(())
697            }
698            #[cfg(feature = "proxy")]
699            DatabaseConnectionType::ProxyDatabaseConnection(_) => {
700                // Nothing to cleanup, we just consume the `DatabaseConnection`
701                Ok(())
702            }
703            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
704        }
705    }
706}
707
708impl DatabaseConnection {
709    /// Get [sqlx::MySqlPool]
710    ///
711    /// # Panics
712    ///
713    /// Panics if [DbConn] is not a MySQL connection.
714    #[cfg(feature = "sqlx-mysql")]
715    pub fn get_mysql_connection_pool(&self) -> &sqlx::MySqlPool {
716        match &self.inner {
717            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => &conn.pool,
718            _ => panic!("Not MySQL Connection"),
719        }
720    }
721
722    /// Get [sqlx::PgPool]
723    ///
724    /// # Panics
725    ///
726    /// Panics if [DbConn] is not a Postgres connection.
727    #[cfg(feature = "sqlx-postgres")]
728    pub fn get_postgres_connection_pool(&self) -> &sqlx::PgPool {
729        match &self.inner {
730            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => &conn.pool,
731            _ => panic!("Not Postgres Connection"),
732        }
733    }
734
735    /// Get [sqlx::SqlitePool]
736    ///
737    /// # Panics
738    ///
739    /// Panics if [DbConn] is not a SQLite connection.
740    #[cfg(feature = "sqlx-sqlite")]
741    pub fn get_sqlite_connection_pool(&self) -> &sqlx::SqlitePool {
742        match &self.inner {
743            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => &conn.pool,
744            _ => panic!("Not SQLite Connection"),
745        }
746    }
747}
748
749impl DbBackend {
750    /// Check if the URI is the same as the specified database backend.
751    /// Returns true if they match.
752    ///
753    /// # Panics
754    ///
755    /// Panics if `base_url` cannot be parsed as `Url`.
756    pub fn is_prefix_of(self, base_url: &str) -> bool {
757        let base_url_parsed = Url::parse(base_url).expect("Fail to parse database URL");
758        match self {
759            Self::Postgres => {
760                base_url_parsed.scheme() == "postgres" || base_url_parsed.scheme() == "postgresql"
761            }
762            Self::MySql => base_url_parsed.scheme() == "mysql",
763            Self::Sqlite => base_url_parsed.scheme() == "sqlite",
764        }
765    }
766
767    /// Build an SQL [Statement]
768    pub fn build<S>(&self, statement: &S) -> Statement
769    where
770        S: StatementBuilder,
771    {
772        statement.build(self)
773    }
774
775    /// Check if the database supports `RETURNING` syntax on insert and update
776    pub fn support_returning(&self) -> bool {
777        match self {
778            Self::Postgres => true,
779            Self::Sqlite if cfg!(feature = "sqlite-use-returning-for-3_35") => true,
780            Self::MySql if cfg!(feature = "mariadb-use-returning") => true,
781            _ => false,
782        }
783    }
784
785    /// A getter for database dependent boolean value
786    pub fn boolean_value(&self, boolean: bool) -> sea_query::Value {
787        match self {
788            Self::MySql | Self::Postgres | Self::Sqlite => boolean.into(),
789        }
790    }
791
792    /// Get the display string for this enum
793    pub fn as_str(&self) -> &'static str {
794        match self {
795            DatabaseBackend::MySql => "MySql",
796            DatabaseBackend::Postgres => "Postgres",
797            DatabaseBackend::Sqlite => "Sqlite",
798        }
799    }
800}
801
#[cfg(test)]
mod tests {
    use crate::DatabaseConnection;

    /// Compile-time check that `DatabaseConnection` can be sent and shared
    /// across threads (skipped under the `sync` feature).
    #[cfg(not(feature = "sync"))]
    #[test]
    fn assert_database_connection_traits() {
        fn require_thread_safe<T: Send + Sync>() {}
        require_thread_safe::<DatabaseConnection>();
    }
}