Skip to main content

sea_orm/database/
db_connection.rs

1use crate::{
2    AccessMode, ConnectionTrait, DatabaseTransaction, ExecResult, IsolationLevel, QueryResult,
3    Schema, SchemaBuilder, Statement, StatementBuilder, StreamTrait, TransactionError,
4    TransactionOptions, TransactionTrait, error::*,
5};
6use std::fmt::Debug;
7use tracing::instrument;
8use url::Url;
9
10#[cfg(feature = "sqlx-dep")]
11use sqlx::pool::PoolConnection;
12
13#[cfg(feature = "rusqlite")]
14use crate::driver::rusqlite::{RusqliteInnerConnection, RusqliteSharedConnection};
15
16#[cfg(any(feature = "mock", feature = "proxy"))]
17use std::sync::Arc;
18
19/// Handle a database connection depending on the backend enabled by the feature
20/// flags. This creates a connection pool internally (for SQLx connections),
21/// and so is cheap to clone.
22#[derive(Debug, Clone)]
23#[non_exhaustive]
24pub struct DatabaseConnection {
25    /// `DatabaseConnection` used to be a enum. Now it's moved into inner,
26    /// because we have to attach other contexts.
27    pub inner: DatabaseConnectionType,
28    #[cfg(feature = "rbac")]
29    pub(crate) rbac: crate::RbacEngineMount,
30}
31
/// The concrete connection held by a [`DatabaseConnection`]: one variant per
/// enabled driver, plus `Disconnected`.
#[derive(Clone)]
pub enum DatabaseConnectionType {
    /// SQLx connection pool for MySQL
    #[cfg(feature = "sqlx-mysql")]
    SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),

    /// SQLx connection pool for PostgreSQL
    #[cfg(feature = "sqlx-postgres")]
    SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),

    /// SQLx connection pool for SQLite
    #[cfg(feature = "sqlx-sqlite")]
    SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection),

    /// rusqlite SQLite connection that can be shared across threads
    #[cfg(feature = "rusqlite")]
    RusqliteSharedConnection(RusqliteSharedConnection),

    /// Mock connection, intended for tests
    #[cfg(feature = "mock")]
    MockDatabaseConnection(Arc<crate::MockDatabaseConnection>),

    /// Proxy connection
    #[cfg(feature = "proxy")]
    ProxyDatabaseConnection(Arc<crate::ProxyDatabaseConnection>),

    /// No connection has ever been established (the `Default` state)
    Disconnected,
}
62
63/// The same as a [DatabaseConnection]
64pub type DbConn = DatabaseConnection;
65
66impl Default for DatabaseConnection {
67    fn default() -> Self {
68        DatabaseConnectionType::Disconnected.into()
69    }
70}
71
72impl From<DatabaseConnectionType> for DatabaseConnection {
73    fn from(inner: DatabaseConnectionType) -> Self {
74        Self {
75            inner,
76            #[cfg(feature = "rbac")]
77            rbac: Default::default(),
78        }
79    }
80}
81
/// Identifies the SQL dialect spoken by a real-world database.
///
/// All variants always exist; which drivers can actually connect to them is
/// governed by the feature flags described in the crate documentation.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DatabaseBackend {
    /// MySQL-compatible backend (MariaDB is handled through this variant too)
    MySql,
    /// PostgreSQL backend
    Postgres,
    /// SQLite backend
    Sqlite,
}
94
95/// A shorthand for [DatabaseBackend].
96pub type DbBackend = DatabaseBackend;
97
/// Crate-internal handle to a single underlying connection, one variant per
/// enabled driver.
///
/// The SQLx variants hold a connection checked out of a pool
/// (`PoolConnection`); mock and proxy variants share their connection via `Arc`.
#[derive(Debug)]
pub(crate) enum InnerConnection {
    /// Pooled MySQL connection
    #[cfg(feature = "sqlx-mysql")]
    MySql(PoolConnection<sqlx::MySql>),
    /// Pooled PostgreSQL connection
    #[cfg(feature = "sqlx-postgres")]
    Postgres(PoolConnection<sqlx::Postgres>),
    /// Pooled SQLite connection
    #[cfg(feature = "sqlx-sqlite")]
    Sqlite(PoolConnection<sqlx::Sqlite>),
    /// rusqlite connection
    #[cfg(feature = "rusqlite")]
    Rusqlite(RusqliteInnerConnection),
    /// Shared mock connection
    #[cfg(feature = "mock")]
    Mock(Arc<crate::MockDatabaseConnection>),
    /// Shared proxy connection
    #[cfg(feature = "proxy")]
    Proxy(Arc<crate::ProxyDatabaseConnection>),
}
113
114impl Debug for DatabaseConnectionType {
115    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
116        write!(
117            f,
118            "{}",
119            match self {
120                #[cfg(feature = "sqlx-mysql")]
121                Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
122                #[cfg(feature = "sqlx-postgres")]
123                Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
124                #[cfg(feature = "sqlx-sqlite")]
125                Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection",
126                #[cfg(feature = "rusqlite")]
127                Self::RusqliteSharedConnection(_) => "RusqliteSharedConnection",
128                #[cfg(feature = "mock")]
129                Self::MockDatabaseConnection(_) => "MockDatabaseConnection",
130                #[cfg(feature = "proxy")]
131                Self::ProxyDatabaseConnection(_) => "ProxyDatabaseConnection",
132                Self::Disconnected => "Disconnected",
133            }
134        )
135    }
136}
137
138impl ConnectionTrait for DatabaseConnection {
139    fn get_database_backend(&self) -> DbBackend {
140        self.get_database_backend()
141    }
142
143    #[instrument(level = "trace")]
144    #[allow(unused_variables)]
145    fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
146        super::tracing_spans::with_db_span!(
147            "sea_orm.execute",
148            self.get_database_backend(),
149            stmt.sql.as_str(),
150            record_stmt = true,
151            {
152                match &self.inner {
153                    #[cfg(feature = "sqlx-mysql")]
154                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.execute(stmt),
155                    #[cfg(feature = "sqlx-postgres")]
156                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.execute(stmt),
157                    #[cfg(feature = "sqlx-sqlite")]
158                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.execute(stmt),
159                    #[cfg(feature = "rusqlite")]
160                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.execute(stmt),
161                    #[cfg(feature = "mock")]
162                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.execute(stmt),
163                    #[cfg(feature = "proxy")]
164                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.execute(stmt),
165                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
166                }
167            }
168        )
169    }
170
171    #[instrument(level = "trace")]
172    #[allow(unused_variables)]
173    fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
174        super::tracing_spans::with_db_span!(
175            "sea_orm.execute_unprepared",
176            self.get_database_backend(),
177            sql,
178            record_stmt = false,
179            {
180                match &self.inner {
181                    #[cfg(feature = "sqlx-mysql")]
182                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
183                        conn.execute_unprepared(sql)
184                    }
185                    #[cfg(feature = "sqlx-postgres")]
186                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
187                        conn.execute_unprepared(sql)
188                    }
189                    #[cfg(feature = "sqlx-sqlite")]
190                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
191                        conn.execute_unprepared(sql)
192                    }
193                    #[cfg(feature = "rusqlite")]
194                    DatabaseConnectionType::RusqliteSharedConnection(conn) => {
195                        conn.execute_unprepared(sql)
196                    }
197                    #[cfg(feature = "mock")]
198                    DatabaseConnectionType::MockDatabaseConnection(conn) => {
199                        let db_backend = conn.get_database_backend();
200                        let stmt = Statement::from_string(db_backend, sql);
201                        conn.execute(stmt)
202                    }
203                    #[cfg(feature = "proxy")]
204                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
205                        let db_backend = conn.get_database_backend();
206                        let stmt = Statement::from_string(db_backend, sql);
207                        conn.execute(stmt)
208                    }
209                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
210                }
211            }
212        )
213    }
214
215    #[instrument(level = "trace")]
216    #[allow(unused_variables)]
217    fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
218        super::tracing_spans::with_db_span!(
219            "sea_orm.query_one",
220            self.get_database_backend(),
221            stmt.sql.as_str(),
222            record_stmt = true,
223            {
224                match &self.inner {
225                    #[cfg(feature = "sqlx-mysql")]
226                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt),
227                    #[cfg(feature = "sqlx-postgres")]
228                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
229                        conn.query_one(stmt)
230                    }
231                    #[cfg(feature = "sqlx-sqlite")]
232                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt),
233                    #[cfg(feature = "rusqlite")]
234                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_one(stmt),
235                    #[cfg(feature = "mock")]
236                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_one(stmt),
237                    #[cfg(feature = "proxy")]
238                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.query_one(stmt),
239                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
240                }
241            }
242        )
243    }
244
245    #[instrument(level = "trace")]
246    #[allow(unused_variables)]
247    fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
248        super::tracing_spans::with_db_span!(
249            "sea_orm.query_all",
250            self.get_database_backend(),
251            stmt.sql.as_str(),
252            record_stmt = true,
253            {
254                match &self.inner {
255                    #[cfg(feature = "sqlx-mysql")]
256                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt),
257                    #[cfg(feature = "sqlx-postgres")]
258                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
259                        conn.query_all(stmt)
260                    }
261                    #[cfg(feature = "sqlx-sqlite")]
262                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt),
263                    #[cfg(feature = "rusqlite")]
264                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_all(stmt),
265                    #[cfg(feature = "mock")]
266                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_all(stmt),
267                    #[cfg(feature = "proxy")]
268                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.query_all(stmt),
269                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
270                }
271            }
272        )
273    }
274
275    #[cfg(feature = "mock")]
276    fn is_mock_connection(&self) -> bool {
277        matches!(
278            self,
279            DatabaseConnection {
280                inner: DatabaseConnectionType::MockDatabaseConnection(_),
281                ..
282            }
283        )
284    }
285}
286
287impl StreamTrait for DatabaseConnection {
288    type Stream<'a> = crate::QueryStream;
289
290    fn get_database_backend(&self) -> DbBackend {
291        self.get_database_backend()
292    }
293
294    #[instrument(level = "trace")]
295    #[allow(unused_variables)]
296    fn stream_raw<'a>(&'a self, stmt: Statement) -> Result<Self::Stream<'a>, DbErr> {
297        ({
298            match &self.inner {
299                #[cfg(feature = "sqlx-mysql")]
300                DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.stream(stmt),
301                #[cfg(feature = "sqlx-postgres")]
302                DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.stream(stmt),
303                #[cfg(feature = "sqlx-sqlite")]
304                DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.stream(stmt),
305                #[cfg(feature = "rusqlite")]
306                DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.stream(stmt),
307                #[cfg(feature = "mock")]
308                DatabaseConnectionType::MockDatabaseConnection(conn) => {
309                    Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
310                }
311                #[cfg(feature = "proxy")]
312                DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
313                    Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
314                }
315                DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
316            }
317        })
318    }
319}
320
321impl TransactionTrait for DatabaseConnection {
322    type Transaction = DatabaseTransaction;
323
324    #[instrument(level = "trace")]
325    fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
326        match &self.inner {
327            #[cfg(feature = "sqlx-mysql")]
328            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.begin(None, None),
329            #[cfg(feature = "sqlx-postgres")]
330            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.begin(None, None),
331            #[cfg(feature = "sqlx-sqlite")]
332            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.begin(None, None, None),
333            #[cfg(feature = "rusqlite")]
334            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.begin(None, None, None),
335            #[cfg(feature = "mock")]
336            DatabaseConnectionType::MockDatabaseConnection(conn) => {
337                DatabaseTransaction::new_mock(Arc::clone(conn), None)
338            }
339            #[cfg(feature = "proxy")]
340            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
341                DatabaseTransaction::new_proxy(conn.clone(), None)
342            }
343            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
344        }
345    }
346
347    #[instrument(level = "trace")]
348    fn begin_with_config(
349        &self,
350        _isolation_level: Option<IsolationLevel>,
351        _access_mode: Option<AccessMode>,
352    ) -> Result<DatabaseTransaction, DbErr> {
353        match &self.inner {
354            #[cfg(feature = "sqlx-mysql")]
355            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
356                conn.begin(_isolation_level, _access_mode)
357            }
358            #[cfg(feature = "sqlx-postgres")]
359            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
360                conn.begin(_isolation_level, _access_mode)
361            }
362            #[cfg(feature = "sqlx-sqlite")]
363            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
364                conn.begin(_isolation_level, _access_mode, None)
365            }
366            #[cfg(feature = "rusqlite")]
367            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
368                conn.begin(_isolation_level, _access_mode, None)
369            }
370            #[cfg(feature = "mock")]
371            DatabaseConnectionType::MockDatabaseConnection(conn) => {
372                DatabaseTransaction::new_mock(Arc::clone(conn), None)
373            }
374            #[cfg(feature = "proxy")]
375            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
376                DatabaseTransaction::new_proxy(conn.clone(), None)
377            }
378            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
379        }
380    }
381
382    #[instrument(level = "trace")]
383    fn begin_with_options(
384        &self,
385        TransactionOptions {
386            isolation_level: _isolation_level,
387            access_mode: _access_mode,
388            sqlite_transaction_mode: _sqlite_transaction_mode,
389        }: TransactionOptions,
390    ) -> Result<DatabaseTransaction, DbErr> {
391        match &self.inner {
392            #[cfg(feature = "sqlx-mysql")]
393            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
394                conn.begin(_isolation_level, _access_mode)
395            }
396            #[cfg(feature = "sqlx-postgres")]
397            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
398                conn.begin(_isolation_level, _access_mode)
399            }
400            #[cfg(feature = "sqlx-sqlite")]
401            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
402                conn.begin(_isolation_level, _access_mode, _sqlite_transaction_mode)
403            }
404            #[cfg(feature = "rusqlite")]
405            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
406                conn.begin(_isolation_level, _access_mode, _sqlite_transaction_mode)
407            }
408            #[cfg(feature = "mock")]
409            DatabaseConnectionType::MockDatabaseConnection(conn) => {
410                DatabaseTransaction::new_mock(Arc::clone(conn), None)
411            }
412            #[cfg(feature = "proxy")]
413            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
414                DatabaseTransaction::new_proxy(conn.clone(), None)
415            }
416            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
417        }
418    }
419
420    /// Execute the function inside a transaction.
421    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
422    #[instrument(level = "trace", skip(_callback))]
423    fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
424    where
425        F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
426        E: std::fmt::Display + std::fmt::Debug,
427    {
428        match &self.inner {
429            #[cfg(feature = "sqlx-mysql")]
430            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
431                conn.transaction(_callback, None, None)
432            }
433            #[cfg(feature = "sqlx-postgres")]
434            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
435                conn.transaction(_callback, None, None)
436            }
437            #[cfg(feature = "sqlx-sqlite")]
438            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
439                conn.transaction(_callback, None, None)
440            }
441            #[cfg(feature = "rusqlite")]
442            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
443                conn.transaction(_callback, None, None)
444            }
445            #[cfg(feature = "mock")]
446            DatabaseConnectionType::MockDatabaseConnection(conn) => {
447                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
448                    .map_err(TransactionError::Connection)?;
449                transaction.run(_callback)
450            }
451            #[cfg(feature = "proxy")]
452            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
453                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
454                    .map_err(TransactionError::Connection)?;
455                transaction.run(_callback)
456            }
457            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
458        }
459    }
460
461    /// Execute the function inside a transaction.
462    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
463    #[instrument(level = "trace", skip(_callback))]
464    fn transaction_with_config<F, T, E>(
465        &self,
466        _callback: F,
467        _isolation_level: Option<IsolationLevel>,
468        _access_mode: Option<AccessMode>,
469    ) -> Result<T, TransactionError<E>>
470    where
471        F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
472        E: std::fmt::Display + std::fmt::Debug,
473    {
474        match &self.inner {
475            #[cfg(feature = "sqlx-mysql")]
476            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
477                conn.transaction(_callback, _isolation_level, _access_mode)
478            }
479            #[cfg(feature = "sqlx-postgres")]
480            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
481                conn.transaction(_callback, _isolation_level, _access_mode)
482            }
483            #[cfg(feature = "sqlx-sqlite")]
484            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
485                conn.transaction(_callback, _isolation_level, _access_mode)
486            }
487            #[cfg(feature = "rusqlite")]
488            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
489                conn.transaction(_callback, _isolation_level, _access_mode)
490            }
491            #[cfg(feature = "mock")]
492            DatabaseConnectionType::MockDatabaseConnection(conn) => {
493                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
494                    .map_err(TransactionError::Connection)?;
495                transaction.run(_callback)
496            }
497            #[cfg(feature = "proxy")]
498            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
499                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
500                    .map_err(TransactionError::Connection)?;
501                transaction.run(_callback)
502            }
503            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
504        }
505    }
506}
507
#[cfg(feature = "mock")]
impl DatabaseConnection {
    /// Borrow the underlying [crate::MockDatabaseConnection], e.g. to inspect
    /// it from a test.
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a mock connection.
    pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection {
        match &self.inner {
            DatabaseConnectionType::MockDatabaseConnection(mock_conn) => mock_conn,
            _ => panic!("Not mock connection"),
        }
    }

    /// Consume the connection and drain its recorded transaction log, returned
    /// as a collection Vec<[crate::Transaction]>.
    ///
    /// # Panics
    ///
    /// - Panics if [DbConn] is not a mock connection (via
    ///   [`Self::as_mock_connection`]).
    /// - Panics if locking the mocker returns an error. Presumably the mocker
    ///   mutex is a `std::sync::Mutex`, whose `lock` fails only when the mutex
    ///   is poisoned; a lock merely held by another thread makes this call block
    ///   rather than panic.
    pub fn into_transaction_log(self) -> Vec<crate::Transaction> {
        let mut mocker = self
            .as_mock_connection()
            .get_mocker_mutex()
            .lock()
            .expect("Fail to acquire mocker");
        mocker.drain_transaction_log()
    }
}
536
#[cfg(feature = "proxy")]
impl DatabaseConnection {
    /// Borrow the underlying [crate::ProxyDatabaseConnection].
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a proxy connection.
    pub fn as_proxy_connection(&self) -> &crate::ProxyDatabaseConnection {
        let DatabaseConnectionType::ProxyDatabaseConnection(proxy_conn) = &self.inner else {
            panic!("Not proxy connection");
        };
        // `&Arc<ProxyDatabaseConnection>` deref-coerces to the return type.
        proxy_conn
    }
}
551
#[cfg(feature = "rbac")]
impl DatabaseConnection {
    /// Set up the RBAC engine from RBAC data stored in this same database,
    /// replacing any engine that is already attached.
    ///
    /// # Errors
    ///
    /// Propagates any error from loading the RBAC data.
    pub fn load_rbac(&self) -> Result<(), DbErr> {
        self.load_rbac_from(self)
    }

    /// Set up the RBAC engine from RBAC data read through `db`, which may be a
    /// different database than this one. Any existing engine is replaced.
    ///
    /// # Errors
    ///
    /// Propagates any error from loading the RBAC data; the current engine is
    /// left in place in that case.
    pub fn load_rbac_from(&self, db: &DbConn) -> Result<(), DbErr> {
        self.rbac.replace(crate::rbac::RbacEngine::load_from(db)?);
        Ok(())
    }

    /// Swap in `engine` as this connection's RBAC engine.
    pub fn replace_rbac(&self, engine: crate::rbac::RbacEngine) {
        self.rbac.replace(engine);
    }

    /// Build a [crate::RestrictedConnection] that applies access control for
    /// `user_id`, holding a clone of this connection.
    ///
    /// # Errors
    ///
    /// Returns `DbErr::RbacError` if no RBAC engine has been set up.
    pub fn restricted_for(
        &self,
        user_id: crate::rbac::RbacUserId,
    ) -> Result<crate::RestrictedConnection, DbErr> {
        if !self.rbac.is_some() {
            return Err(DbErr::RbacError("engine not set up".into()));
        }
        Ok(crate::RestrictedConnection {
            user_id,
            conn: self.clone(),
        })
    }
}
588
589impl DatabaseConnection {
590    /// Get the database backend for this connection
591    ///
592    /// # Panics
593    ///
594    /// Panics if [DatabaseConnection] is `Disconnected`.
595    pub fn get_database_backend(&self) -> DbBackend {
596        match &self.inner {
597            #[cfg(feature = "sqlx-mysql")]
598            DatabaseConnectionType::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
599            #[cfg(feature = "sqlx-postgres")]
600            DatabaseConnectionType::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
601            #[cfg(feature = "sqlx-sqlite")]
602            DatabaseConnectionType::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite,
603            #[cfg(feature = "rusqlite")]
604            DatabaseConnectionType::RusqliteSharedConnection(_) => DbBackend::Sqlite,
605            #[cfg(feature = "mock")]
606            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.get_database_backend(),
607            #[cfg(feature = "proxy")]
608            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.get_database_backend(),
609            DatabaseConnectionType::Disconnected => panic!("Disconnected"),
610        }
611    }
612
613    /// Creates a [`SchemaBuilder`] for this backend
614    pub fn get_schema_builder(&self) -> SchemaBuilder {
615        Schema::new(self.get_database_backend()).builder()
616    }
617
618    #[cfg(feature = "entity-registry")]
619    #[cfg_attr(docsrs, doc(cfg(feature = "entity-registry")))]
620    /// Builds a schema for all the entites in the given module
621    pub fn get_schema_registry(&self, prefix: &str) -> SchemaBuilder {
622        let schema = Schema::new(self.get_database_backend());
623        crate::EntityRegistry::build_schema(schema, prefix)
624    }
625
626    /// Sets a callback to metric this connection
627    pub fn set_metric_callback<F>(&mut self, _callback: F)
628    where
629        F: Fn(&crate::metric::Info<'_>) + 'static,
630    {
631        match &mut self.inner {
632            #[cfg(feature = "sqlx-mysql")]
633            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
634                conn.set_metric_callback(_callback)
635            }
636            #[cfg(feature = "sqlx-postgres")]
637            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
638                conn.set_metric_callback(_callback)
639            }
640            #[cfg(feature = "sqlx-sqlite")]
641            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
642                conn.set_metric_callback(_callback)
643            }
644            #[cfg(feature = "rusqlite")]
645            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
646                conn.set_metric_callback(_callback)
647            }
648            _ => {}
649        }
650    }
651
652    /// Checks if a connection to the database is still valid.
653    pub fn ping(&self) -> Result<(), DbErr> {
654        match &self.inner {
655            #[cfg(feature = "sqlx-mysql")]
656            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.ping(),
657            #[cfg(feature = "sqlx-postgres")]
658            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.ping(),
659            #[cfg(feature = "sqlx-sqlite")]
660            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.ping(),
661            #[cfg(feature = "rusqlite")]
662            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.ping(),
663            #[cfg(feature = "mock")]
664            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.ping(),
665            #[cfg(feature = "proxy")]
666            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.ping(),
667            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
668        }
669    }
670
671    /// Explicitly close the database connection.
672    /// See [`Self::close_by_ref`] for usage with references.
673    pub fn close(self) -> Result<(), DbErr> {
674        self.close_by_ref()
675    }
676
677    /// Explicitly close the database connection
678    pub fn close_by_ref(&self) -> Result<(), DbErr> {
679        match &self.inner {
680            #[cfg(feature = "sqlx-mysql")]
681            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.close_by_ref(),
682            #[cfg(feature = "sqlx-postgres")]
683            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.close_by_ref(),
684            #[cfg(feature = "sqlx-sqlite")]
685            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.close_by_ref(),
686            #[cfg(feature = "rusqlite")]
687            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.close_by_ref(),
688            #[cfg(feature = "mock")]
689            DatabaseConnectionType::MockDatabaseConnection(_) => {
690                // Nothing to cleanup, we just consume the `DatabaseConnection`
691                Ok(())
692            }
693            #[cfg(feature = "proxy")]
694            DatabaseConnectionType::ProxyDatabaseConnection(_) => {
695                // Nothing to cleanup, we just consume the `DatabaseConnection`
696                Ok(())
697            }
698            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
699        }
700    }
701}
702
703impl DatabaseConnection {
704    /// Get [sqlx::MySqlPool]
705    ///
706    /// # Panics
707    ///
708    /// Panics if [DbConn] is not a MySQL connection.
709    #[cfg(feature = "sqlx-mysql")]
710    pub fn get_mysql_connection_pool(&self) -> &sqlx::MySqlPool {
711        match &self.inner {
712            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => &conn.pool,
713            _ => panic!("Not MySQL Connection"),
714        }
715    }
716
717    /// Get [sqlx::PgPool]
718    ///
719    /// # Panics
720    ///
721    /// Panics if [DbConn] is not a Postgres connection.
722    #[cfg(feature = "sqlx-postgres")]
723    pub fn get_postgres_connection_pool(&self) -> &sqlx::PgPool {
724        match &self.inner {
725            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => &conn.pool,
726            _ => panic!("Not Postgres Connection"),
727        }
728    }
729
730    /// Get [sqlx::SqlitePool]
731    ///
732    /// # Panics
733    ///
734    /// Panics if [DbConn] is not a SQLite connection.
735    #[cfg(feature = "sqlx-sqlite")]
736    pub fn get_sqlite_connection_pool(&self) -> &sqlx::SqlitePool {
737        match &self.inner {
738            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => &conn.pool,
739            _ => panic!("Not SQLite Connection"),
740        }
741    }
742}
743
744impl DbBackend {
745    /// Check if the URI is the same as the specified database backend.
746    /// Returns true if they match.
747    ///
748    /// # Panics
749    ///
750    /// Panics if `base_url` cannot be parsed as `Url`.
751    pub fn is_prefix_of(self, base_url: &str) -> bool {
752        let base_url_parsed = Url::parse(base_url).expect("Fail to parse database URL");
753        match self {
754            Self::Postgres => {
755                base_url_parsed.scheme() == "postgres" || base_url_parsed.scheme() == "postgresql"
756            }
757            Self::MySql => base_url_parsed.scheme() == "mysql",
758            Self::Sqlite => base_url_parsed.scheme() == "sqlite",
759        }
760    }
761
762    /// Build an SQL [Statement]
763    pub fn build<S>(&self, statement: &S) -> Statement
764    where
765        S: StatementBuilder,
766    {
767        statement.build(self)
768    }
769
770    /// Check if the database supports `RETURNING` syntax on insert and update
771    pub fn support_returning(&self) -> bool {
772        match self {
773            Self::Postgres => true,
774            Self::Sqlite if cfg!(feature = "sqlite-use-returning-for-3_35") => true,
775            Self::MySql if cfg!(feature = "mariadb-use-returning") => true,
776            _ => false,
777        }
778    }
779
780    /// A getter for database dependent boolean value
781    pub fn boolean_value(&self, boolean: bool) -> sea_query::Value {
782        match self {
783            Self::MySql | Self::Postgres | Self::Sqlite => boolean.into(),
784        }
785    }
786
787    /// Get the display string for this enum
788    pub fn as_str(&self) -> &'static str {
789        match self {
790            DatabaseBackend::MySql => "MySql",
791            DatabaseBackend::Postgres => "Postgres",
792            DatabaseBackend::Sqlite => "Sqlite",
793        }
794    }
795}
796
#[cfg(test)]
mod tests {
    use crate::DatabaseConnection;

    #[cfg(not(feature = "sync"))]
    #[test]
    fn assert_database_connection_traits() {
        // Compile-time check: this test fails to build if `DatabaseConnection`
        // stops being `Send`. Only `Send` is asserted, so the helper is named
        // accordingly (it was previously called `assert_send_sync`).
        fn assert_send<T: Send>() {}

        assert_send::<DatabaseConnection>();
    }
}