Skip to main content

sea_orm/database/
db_connection.rs

1use crate::{
2    AccessMode, ConnectionTrait, DatabaseTransaction, ExecResult, IsolationLevel, QueryResult,
3    Schema, SchemaBuilder, Statement, StatementBuilder, StreamTrait, TransactionError,
4    TransactionTrait, error::*,
5};
6use std::fmt::Debug;
7use tracing::instrument;
8use url::Url;
9
10#[cfg(feature = "sqlx-dep")]
11use sqlx::pool::PoolConnection;
12
13#[cfg(feature = "rusqlite")]
14use crate::driver::rusqlite::{RusqliteInnerConnection, RusqliteSharedConnection};
15
16#[cfg(any(feature = "mock", feature = "proxy"))]
17use std::sync::Arc;
18
19/// Handle a database connection depending on the backend enabled by the feature
20/// flags. This creates a connection pool internally (for SQLx connections),
21/// and so is cheap to clone.
22#[derive(Debug, Clone)]
23#[non_exhaustive]
24pub struct DatabaseConnection {
25    /// `DatabaseConnection` used to be a enum. Now it's moved into inner,
26    /// because we have to attach other contexts.
27    pub inner: DatabaseConnectionType,
28    #[cfg(feature = "rbac")]
29    pub(crate) rbac: crate::RbacEngineMount,
30}
31
/// Backend-specific payload carried by a [`DatabaseConnection`].
///
/// One variant exists per enabled driver feature. [`DatabaseConnectionType::Disconnected`]
/// is always available.
#[derive(Clone)]
pub enum DatabaseConnectionType {
    /// Pooled MySQL connection (SQLx driver)
    #[cfg(feature = "sqlx-mysql")]
    SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),

    /// Pooled PostgreSQL connection (SQLx driver)
    #[cfg(feature = "sqlx-postgres")]
    SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),

    /// Pooled SQLite connection (SQLx driver)
    #[cfg(feature = "sqlx-sqlite")]
    SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection),

    /// SQLite connection (rusqlite driver) that can be shared across threads
    #[cfg(feature = "rusqlite")]
    RusqliteSharedConnection(RusqliteSharedConnection),

    /// Mock connection, intended for tests
    #[cfg(feature = "mock")]
    MockDatabaseConnection(Arc<crate::MockDatabaseConnection>),

    /// Proxy connection
    #[cfg(feature = "proxy")]
    ProxyDatabaseConnection(Arc<crate::ProxyDatabaseConnection>),

    /// No connection has ever been established
    Disconnected,
}
62
63/// The same as a [DatabaseConnection]
64pub type DbConn = DatabaseConnection;
65
66impl Default for DatabaseConnection {
67    fn default() -> Self {
68        DatabaseConnectionType::Disconnected.into()
69    }
70}
71
72impl From<DatabaseConnectionType> for DatabaseConnection {
73    fn from(inner: DatabaseConnectionType) -> Self {
74        Self {
75            inner,
76            #[cfg(feature = "rbac")]
77            rbac: Default::default(),
78        }
79    }
80}
81
/// Which SQL database a connection talks to.
///
/// Which backends are actually usable is controlled by the feature flags listed in
/// the crate documentation.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum DatabaseBackend {
    /// MySQL backend
    MySql,
    /// PostgreSQL backend
    Postgres,
    /// SQLite backend
    Sqlite,
}

/// Shorter name for [`DatabaseBackend`].
pub type DbBackend = DatabaseBackend;
97
/// A single per-backend connection handle.
///
/// The SQLx variants hold a `PoolConnection` rather than the pool itself. The mock
/// and proxy variants hold the same `Arc` types as [`DatabaseConnectionType`].
///
/// NOTE(review): not constructed in this file; presumably used by transactions to
/// pin one underlying connection — confirm against `DatabaseTransaction`.
#[derive(Debug)]
pub(crate) enum InnerConnection {
    /// MySQL connection taken from an SQLx pool
    #[cfg(feature = "sqlx-mysql")]
    MySql(PoolConnection<sqlx::MySql>),
    /// PostgreSQL connection taken from an SQLx pool
    #[cfg(feature = "sqlx-postgres")]
    Postgres(PoolConnection<sqlx::Postgres>),
    /// SQLite connection taken from an SQLx pool
    #[cfg(feature = "sqlx-sqlite")]
    Sqlite(PoolConnection<sqlx::Sqlite>),
    /// rusqlite connection
    #[cfg(feature = "rusqlite")]
    Rusqlite(RusqliteInnerConnection),
    /// Shared mock connection
    #[cfg(feature = "mock")]
    Mock(Arc<crate::MockDatabaseConnection>),
    /// Shared proxy connection
    #[cfg(feature = "proxy")]
    Proxy(Arc<crate::ProxyDatabaseConnection>),
}
113
114impl Debug for DatabaseConnectionType {
115    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
116        write!(
117            f,
118            "{}",
119            match self {
120                #[cfg(feature = "sqlx-mysql")]
121                Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
122                #[cfg(feature = "sqlx-postgres")]
123                Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
124                #[cfg(feature = "sqlx-sqlite")]
125                Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection",
126                #[cfg(feature = "rusqlite")]
127                Self::RusqliteSharedConnection(_) => "RusqliteSharedConnection",
128                #[cfg(feature = "mock")]
129                Self::MockDatabaseConnection(_) => "MockDatabaseConnection",
130                #[cfg(feature = "proxy")]
131                Self::ProxyDatabaseConnection(_) => "ProxyDatabaseConnection",
132                Self::Disconnected => "Disconnected",
133            }
134        )
135    }
136}
137
138impl ConnectionTrait for DatabaseConnection {
139    fn get_database_backend(&self) -> DbBackend {
140        self.get_database_backend()
141    }
142
143    #[instrument(level = "trace")]
144    #[allow(unused_variables)]
145    fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
146        super::tracing_spans::with_db_span!(
147            "sea_orm.execute",
148            self.get_database_backend(),
149            stmt.sql.as_str(),
150            record_stmt = true,
151            {
152                match &self.inner {
153                    #[cfg(feature = "sqlx-mysql")]
154                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.execute(stmt),
155                    #[cfg(feature = "sqlx-postgres")]
156                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.execute(stmt),
157                    #[cfg(feature = "sqlx-sqlite")]
158                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.execute(stmt),
159                    #[cfg(feature = "rusqlite")]
160                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.execute(stmt),
161                    #[cfg(feature = "mock")]
162                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.execute(stmt),
163                    #[cfg(feature = "proxy")]
164                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.execute(stmt),
165                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
166                }
167            }
168        )
169    }
170
171    #[instrument(level = "trace")]
172    #[allow(unused_variables)]
173    fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
174        super::tracing_spans::with_db_span!(
175            "sea_orm.execute_unprepared",
176            self.get_database_backend(),
177            sql,
178            record_stmt = false,
179            {
180                match &self.inner {
181                    #[cfg(feature = "sqlx-mysql")]
182                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
183                        conn.execute_unprepared(sql)
184                    }
185                    #[cfg(feature = "sqlx-postgres")]
186                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
187                        conn.execute_unprepared(sql)
188                    }
189                    #[cfg(feature = "sqlx-sqlite")]
190                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
191                        conn.execute_unprepared(sql)
192                    }
193                    #[cfg(feature = "rusqlite")]
194                    DatabaseConnectionType::RusqliteSharedConnection(conn) => {
195                        conn.execute_unprepared(sql)
196                    }
197                    #[cfg(feature = "mock")]
198                    DatabaseConnectionType::MockDatabaseConnection(conn) => {
199                        let db_backend = conn.get_database_backend();
200                        let stmt = Statement::from_string(db_backend, sql);
201                        conn.execute(stmt)
202                    }
203                    #[cfg(feature = "proxy")]
204                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
205                        let db_backend = conn.get_database_backend();
206                        let stmt = Statement::from_string(db_backend, sql);
207                        conn.execute(stmt)
208                    }
209                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
210                }
211            }
212        )
213    }
214
215    #[instrument(level = "trace")]
216    #[allow(unused_variables)]
217    fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
218        super::tracing_spans::with_db_span!(
219            "sea_orm.query_one",
220            self.get_database_backend(),
221            stmt.sql.as_str(),
222            record_stmt = true,
223            {
224                match &self.inner {
225                    #[cfg(feature = "sqlx-mysql")]
226                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt),
227                    #[cfg(feature = "sqlx-postgres")]
228                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
229                        conn.query_one(stmt)
230                    }
231                    #[cfg(feature = "sqlx-sqlite")]
232                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt),
233                    #[cfg(feature = "rusqlite")]
234                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_one(stmt),
235                    #[cfg(feature = "mock")]
236                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_one(stmt),
237                    #[cfg(feature = "proxy")]
238                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.query_one(stmt),
239                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
240                }
241            }
242        )
243    }
244
245    #[instrument(level = "trace")]
246    #[allow(unused_variables)]
247    fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
248        super::tracing_spans::with_db_span!(
249            "sea_orm.query_all",
250            self.get_database_backend(),
251            stmt.sql.as_str(),
252            record_stmt = true,
253            {
254                match &self.inner {
255                    #[cfg(feature = "sqlx-mysql")]
256                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt),
257                    #[cfg(feature = "sqlx-postgres")]
258                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
259                        conn.query_all(stmt)
260                    }
261                    #[cfg(feature = "sqlx-sqlite")]
262                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt),
263                    #[cfg(feature = "rusqlite")]
264                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_all(stmt),
265                    #[cfg(feature = "mock")]
266                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_all(stmt),
267                    #[cfg(feature = "proxy")]
268                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.query_all(stmt),
269                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
270                }
271            }
272        )
273    }
274
275    #[cfg(feature = "mock")]
276    fn is_mock_connection(&self) -> bool {
277        matches!(
278            self,
279            DatabaseConnection {
280                inner: DatabaseConnectionType::MockDatabaseConnection(_),
281                ..
282            }
283        )
284    }
285}
286
287impl StreamTrait for DatabaseConnection {
288    type Stream<'a> = crate::QueryStream;
289
290    fn get_database_backend(&self) -> DbBackend {
291        self.get_database_backend()
292    }
293
294    #[instrument(level = "trace")]
295    #[allow(unused_variables)]
296    fn stream_raw<'a>(&'a self, stmt: Statement) -> Result<Self::Stream<'a>, DbErr> {
297        ({
298            match &self.inner {
299                #[cfg(feature = "sqlx-mysql")]
300                DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.stream(stmt),
301                #[cfg(feature = "sqlx-postgres")]
302                DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.stream(stmt),
303                #[cfg(feature = "sqlx-sqlite")]
304                DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.stream(stmt),
305                #[cfg(feature = "rusqlite")]
306                DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.stream(stmt),
307                #[cfg(feature = "mock")]
308                DatabaseConnectionType::MockDatabaseConnection(conn) => {
309                    Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
310                }
311                #[cfg(feature = "proxy")]
312                DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
313                    Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
314                }
315                DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
316            }
317        })
318    }
319}
320
321impl TransactionTrait for DatabaseConnection {
322    type Transaction = DatabaseTransaction;
323
324    #[instrument(level = "trace")]
325    fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
326        match &self.inner {
327            #[cfg(feature = "sqlx-mysql")]
328            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.begin(None, None),
329            #[cfg(feature = "sqlx-postgres")]
330            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.begin(None, None),
331            #[cfg(feature = "sqlx-sqlite")]
332            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.begin(None, None),
333            #[cfg(feature = "rusqlite")]
334            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.begin(None, None),
335            #[cfg(feature = "mock")]
336            DatabaseConnectionType::MockDatabaseConnection(conn) => {
337                DatabaseTransaction::new_mock(Arc::clone(conn), None)
338            }
339            #[cfg(feature = "proxy")]
340            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
341                DatabaseTransaction::new_proxy(conn.clone(), None)
342            }
343            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
344        }
345    }
346
347    #[instrument(level = "trace")]
348    fn begin_with_config(
349        &self,
350        _isolation_level: Option<IsolationLevel>,
351        _access_mode: Option<AccessMode>,
352    ) -> Result<DatabaseTransaction, DbErr> {
353        match &self.inner {
354            #[cfg(feature = "sqlx-mysql")]
355            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
356                conn.begin(_isolation_level, _access_mode)
357            }
358            #[cfg(feature = "sqlx-postgres")]
359            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
360                conn.begin(_isolation_level, _access_mode)
361            }
362            #[cfg(feature = "sqlx-sqlite")]
363            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
364                conn.begin(_isolation_level, _access_mode)
365            }
366            #[cfg(feature = "rusqlite")]
367            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
368                conn.begin(_isolation_level, _access_mode)
369            }
370            #[cfg(feature = "mock")]
371            DatabaseConnectionType::MockDatabaseConnection(conn) => {
372                DatabaseTransaction::new_mock(Arc::clone(conn), None)
373            }
374            #[cfg(feature = "proxy")]
375            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
376                DatabaseTransaction::new_proxy(conn.clone(), None)
377            }
378            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
379        }
380    }
381
382    /// Execute the function inside a transaction.
383    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
384    #[instrument(level = "trace", skip(_callback))]
385    fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
386    where
387        F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
388        E: std::fmt::Display + std::fmt::Debug,
389    {
390        match &self.inner {
391            #[cfg(feature = "sqlx-mysql")]
392            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
393                conn.transaction(_callback, None, None)
394            }
395            #[cfg(feature = "sqlx-postgres")]
396            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
397                conn.transaction(_callback, None, None)
398            }
399            #[cfg(feature = "sqlx-sqlite")]
400            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
401                conn.transaction(_callback, None, None)
402            }
403            #[cfg(feature = "rusqlite")]
404            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
405                conn.transaction(_callback, None, None)
406            }
407            #[cfg(feature = "mock")]
408            DatabaseConnectionType::MockDatabaseConnection(conn) => {
409                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
410                    .map_err(TransactionError::Connection)?;
411                transaction.run(_callback)
412            }
413            #[cfg(feature = "proxy")]
414            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
415                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
416                    .map_err(TransactionError::Connection)?;
417                transaction.run(_callback)
418            }
419            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
420        }
421    }
422
423    /// Execute the function inside a transaction.
424    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
425    #[instrument(level = "trace", skip(_callback))]
426    fn transaction_with_config<F, T, E>(
427        &self,
428        _callback: F,
429        _isolation_level: Option<IsolationLevel>,
430        _access_mode: Option<AccessMode>,
431    ) -> Result<T, TransactionError<E>>
432    where
433        F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
434        E: std::fmt::Display + std::fmt::Debug,
435    {
436        match &self.inner {
437            #[cfg(feature = "sqlx-mysql")]
438            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
439                conn.transaction(_callback, _isolation_level, _access_mode)
440            }
441            #[cfg(feature = "sqlx-postgres")]
442            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
443                conn.transaction(_callback, _isolation_level, _access_mode)
444            }
445            #[cfg(feature = "sqlx-sqlite")]
446            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
447                conn.transaction(_callback, _isolation_level, _access_mode)
448            }
449            #[cfg(feature = "rusqlite")]
450            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
451                conn.transaction(_callback, _isolation_level, _access_mode)
452            }
453            #[cfg(feature = "mock")]
454            DatabaseConnectionType::MockDatabaseConnection(conn) => {
455                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
456                    .map_err(TransactionError::Connection)?;
457                transaction.run(_callback)
458            }
459            #[cfg(feature = "proxy")]
460            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
461                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
462                    .map_err(TransactionError::Connection)?;
463                transaction.run(_callback)
464            }
465            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
466        }
467    }
468}
469
#[cfg(feature = "mock")]
impl DatabaseConnection {
    /// Borrow the underlying [`crate::MockDatabaseConnection`].
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a mock connection.
    pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection {
        let DatabaseConnectionType::MockDatabaseConnection(mock_conn) = &self.inner else {
            panic!("Not mock connection");
        };
        mock_conn
    }

    /// Consume the connection and drain the mocker's transaction log into a
    /// `Vec<[crate::Transaction]>`.
    ///
    /// # Panics
    ///
    /// - Panics if [DbConn] is not a mock connection (see [`Self::as_mock_connection`]).
    /// - Panics if acquiring the mocker lock returns an error. With `std::sync::Mutex`,
    ///   that happens when the lock is poisoned; a lock held by another thread makes
    ///   this call block rather than panic.
    pub fn into_transaction_log(self) -> Vec<crate::Transaction> {
        let mut mocker = self
            .as_mock_connection()
            .get_mocker_mutex()
            .lock()
            .expect("Fail to acquire mocker");
        mocker.drain_transaction_log()
    }
}
498
#[cfg(feature = "proxy")]
impl DatabaseConnection {
    /// Borrow the underlying [`crate::ProxyDatabaseConnection`].
    ///
    /// # Panics
    ///
    /// Panics when [DbConn] wraps anything other than a proxy connection.
    pub fn as_proxy_connection(&self) -> &crate::ProxyDatabaseConnection {
        let DatabaseConnectionType::ProxyDatabaseConnection(proxy) = &self.inner else {
            panic!("Not proxy connection");
        };
        proxy
    }
}
513
#[cfg(feature = "rbac")]
impl DatabaseConnection {
    /// Load RBAC data from this connection's own database and install the resulting
    /// engine, replacing any engine already installed.
    ///
    /// # Errors
    ///
    /// Propagates any error from loading the RBAC data.
    pub fn load_rbac(&self) -> Result<(), DbErr> {
        DatabaseConnection::load_rbac_from(self, self)
    }

    /// Load RBAC data from `db`, which may be a different database, and install the
    /// resulting engine on this connection.
    ///
    /// # Errors
    ///
    /// Propagates any error from loading the RBAC data. In that case the currently
    /// installed engine is left untouched.
    pub fn load_rbac_from(&self, db: &DbConn) -> Result<(), DbErr> {
        crate::rbac::RbacEngine::load_from(db).map(|engine| self.replace_rbac(engine))
    }

    /// Swap in `engine` as this connection's RBAC engine.
    pub fn replace_rbac(&self, engine: crate::rbac::RbacEngine) {
        self.rbac.replace(engine);
    }

    /// Build a [`crate::RestrictedConnection`] that enforces access control for `user_id`.
    ///
    /// # Errors
    ///
    /// Returns [`DbErr::RbacError`] when no RBAC engine has been set up yet.
    pub fn restricted_for(
        &self,
        user_id: crate::rbac::RbacUserId,
    ) -> Result<crate::RestrictedConnection, DbErr> {
        self.rbac
            .is_some()
            .then(|| crate::RestrictedConnection {
                user_id,
                conn: self.clone(),
            })
            .ok_or_else(|| DbErr::RbacError("engine not set up".into()))
    }
}
550
551impl DatabaseConnection {
552    /// Get the database backend for this connection
553    ///
554    /// # Panics
555    ///
556    /// Panics if [DatabaseConnection] is `Disconnected`.
557    pub fn get_database_backend(&self) -> DbBackend {
558        match &self.inner {
559            #[cfg(feature = "sqlx-mysql")]
560            DatabaseConnectionType::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
561            #[cfg(feature = "sqlx-postgres")]
562            DatabaseConnectionType::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
563            #[cfg(feature = "sqlx-sqlite")]
564            DatabaseConnectionType::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite,
565            #[cfg(feature = "rusqlite")]
566            DatabaseConnectionType::RusqliteSharedConnection(_) => DbBackend::Sqlite,
567            #[cfg(feature = "mock")]
568            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.get_database_backend(),
569            #[cfg(feature = "proxy")]
570            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.get_database_backend(),
571            DatabaseConnectionType::Disconnected => panic!("Disconnected"),
572        }
573    }
574
575    /// Creates a [`SchemaBuilder`] for this backend
576    pub fn get_schema_builder(&self) -> SchemaBuilder {
577        Schema::new(self.get_database_backend()).builder()
578    }
579
580    #[cfg(feature = "entity-registry")]
581    #[cfg_attr(docsrs, doc(cfg(feature = "entity-registry")))]
582    /// Builds a schema for all the entites in the given module
583    pub fn get_schema_registry(&self, prefix: &str) -> SchemaBuilder {
584        let schema = Schema::new(self.get_database_backend());
585        crate::EntityRegistry::build_schema(schema, prefix)
586    }
587
588    /// Sets a callback to metric this connection
589    pub fn set_metric_callback<F>(&mut self, _callback: F)
590    where
591        F: Fn(&crate::metric::Info<'_>) + 'static,
592    {
593        match &mut self.inner {
594            #[cfg(feature = "sqlx-mysql")]
595            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
596                conn.set_metric_callback(_callback)
597            }
598            #[cfg(feature = "sqlx-postgres")]
599            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
600                conn.set_metric_callback(_callback)
601            }
602            #[cfg(feature = "sqlx-sqlite")]
603            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
604                conn.set_metric_callback(_callback)
605            }
606            #[cfg(feature = "rusqlite")]
607            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
608                conn.set_metric_callback(_callback)
609            }
610            _ => {}
611        }
612    }
613
614    /// Checks if a connection to the database is still valid.
615    pub fn ping(&self) -> Result<(), DbErr> {
616        match &self.inner {
617            #[cfg(feature = "sqlx-mysql")]
618            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.ping(),
619            #[cfg(feature = "sqlx-postgres")]
620            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.ping(),
621            #[cfg(feature = "sqlx-sqlite")]
622            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.ping(),
623            #[cfg(feature = "rusqlite")]
624            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.ping(),
625            #[cfg(feature = "mock")]
626            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.ping(),
627            #[cfg(feature = "proxy")]
628            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.ping(),
629            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
630        }
631    }
632
633    /// Explicitly close the database connection.
634    /// See [`Self::close_by_ref`] for usage with references.
635    pub fn close(self) -> Result<(), DbErr> {
636        self.close_by_ref()
637    }
638
639    /// Explicitly close the database connection
640    pub fn close_by_ref(&self) -> Result<(), DbErr> {
641        match &self.inner {
642            #[cfg(feature = "sqlx-mysql")]
643            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.close_by_ref(),
644            #[cfg(feature = "sqlx-postgres")]
645            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.close_by_ref(),
646            #[cfg(feature = "sqlx-sqlite")]
647            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.close_by_ref(),
648            #[cfg(feature = "rusqlite")]
649            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.close_by_ref(),
650            #[cfg(feature = "mock")]
651            DatabaseConnectionType::MockDatabaseConnection(_) => {
652                // Nothing to cleanup, we just consume the `DatabaseConnection`
653                Ok(())
654            }
655            #[cfg(feature = "proxy")]
656            DatabaseConnectionType::ProxyDatabaseConnection(_) => {
657                // Nothing to cleanup, we just consume the `DatabaseConnection`
658                Ok(())
659            }
660            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
661        }
662    }
663}
664
665impl DatabaseConnection {
666    /// Get [sqlx::MySqlPool]
667    ///
668    /// # Panics
669    ///
670    /// Panics if [DbConn] is not a MySQL connection.
671    #[cfg(feature = "sqlx-mysql")]
672    pub fn get_mysql_connection_pool(&self) -> &sqlx::MySqlPool {
673        match &self.inner {
674            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => &conn.pool,
675            _ => panic!("Not MySQL Connection"),
676        }
677    }
678
679    /// Get [sqlx::PgPool]
680    ///
681    /// # Panics
682    ///
683    /// Panics if [DbConn] is not a Postgres connection.
684    #[cfg(feature = "sqlx-postgres")]
685    pub fn get_postgres_connection_pool(&self) -> &sqlx::PgPool {
686        match &self.inner {
687            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => &conn.pool,
688            _ => panic!("Not Postgres Connection"),
689        }
690    }
691
692    /// Get [sqlx::SqlitePool]
693    ///
694    /// # Panics
695    ///
696    /// Panics if [DbConn] is not a SQLite connection.
697    #[cfg(feature = "sqlx-sqlite")]
698    pub fn get_sqlite_connection_pool(&self) -> &sqlx::SqlitePool {
699        match &self.inner {
700            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => &conn.pool,
701            _ => panic!("Not SQLite Connection"),
702        }
703    }
704}
705
706impl DbBackend {
707    /// Check if the URI is the same as the specified database backend.
708    /// Returns true if they match.
709    ///
710    /// # Panics
711    ///
712    /// Panics if `base_url` cannot be parsed as `Url`.
713    pub fn is_prefix_of(self, base_url: &str) -> bool {
714        let base_url_parsed = Url::parse(base_url).expect("Fail to parse database URL");
715        match self {
716            Self::Postgres => {
717                base_url_parsed.scheme() == "postgres" || base_url_parsed.scheme() == "postgresql"
718            }
719            Self::MySql => base_url_parsed.scheme() == "mysql",
720            Self::Sqlite => base_url_parsed.scheme() == "sqlite",
721        }
722    }
723
724    /// Build an SQL [Statement]
725    pub fn build<S>(&self, statement: &S) -> Statement
726    where
727        S: StatementBuilder,
728    {
729        statement.build(self)
730    }
731
732    /// Check if the database supports `RETURNING` syntax on insert and update
733    pub fn support_returning(&self) -> bool {
734        match self {
735            Self::Postgres => true,
736            Self::Sqlite if cfg!(feature = "sqlite-use-returning-for-3_35") => true,
737            Self::MySql if cfg!(feature = "mariadb-use-returning") => true,
738            _ => false,
739        }
740    }
741
742    /// A getter for database dependent boolean value
743    pub fn boolean_value(&self, boolean: bool) -> sea_query::Value {
744        match self {
745            Self::MySql | Self::Postgres | Self::Sqlite => boolean.into(),
746        }
747    }
748
749    /// Get the display string for this enum
750    pub fn as_str(&self) -> &'static str {
751        match self {
752            DatabaseBackend::MySql => "MySql",
753            DatabaseBackend::Postgres => "Postgres",
754            DatabaseBackend::Sqlite => "Sqlite",
755        }
756    }
757}
758
#[cfg(test)]
mod tests {
    // NOTE(review): only `Send` is asserted, and the check is skipped under the `sync`
    // feature. Presumably that build's connection is not `Send`; confirm before
    // tightening the bound to `Send + Sync`.
    #[cfg(not(feature = "sync"))]
    #[test]
    fn assert_database_connection_traits() {
        // Imported here, not at module level, so `sync` test builds (where this test is
        // compiled out) do not warn about an unused import.
        use crate::DatabaseConnection;

        // Compile-time check: this only builds if `T: Send`.
        fn assert_send<T: Send>() {}

        assert_send::<DatabaseConnection>();
    }
}