Skip to main content

sea_orm/database/
db_connection.rs

1use crate::{
2    AccessMode, ConnectionTrait, DatabaseTransaction, ExecResult, IsolationLevel, QueryResult,
3    Schema, SchemaBuilder, Statement, StatementBuilder, StreamTrait, TransactionError,
4    TransactionOptions, TransactionTrait, error::*,
5};
6use std::{fmt::Debug, future::Future, pin::Pin};
7use tracing::instrument;
8use url::Url;
9
10#[cfg(feature = "sqlx-dep")]
11use sqlx::pool::PoolConnection;
12
13#[cfg(feature = "rusqlite")]
14use crate::driver::rusqlite::{RusqliteInnerConnection, RusqliteSharedConnection};
15
16#[cfg(any(feature = "mock", feature = "proxy"))]
17use std::sync::Arc;
18
19/// Handle a database connection depending on the backend enabled by the feature
20/// flags. This creates a connection pool internally (for SQLx connections),
21/// and so is cheap to clone.
22#[derive(Debug, Clone)]
23#[non_exhaustive]
24pub struct DatabaseConnection {
25    /// `DatabaseConnection` used to be a enum. Now it's moved into inner,
26    /// because we have to attach other contexts.
27    pub inner: DatabaseConnectionType,
28    #[cfg(feature = "rbac")]
29    pub(crate) rbac: crate::RbacEngineMount,
30}
31
/// The underlying database connection type.
///
/// Every variant except `Disconnected` is only compiled in when its cargo
/// feature is enabled. `Debug` is implemented manually and prints only the
/// variant name, not the wrapped connection.
#[derive(Clone)]
pub enum DatabaseConnectionType {
    /// MySql database connection pool
    #[cfg(feature = "sqlx-mysql")]
    SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),

    /// PostgreSQL database connection pool
    #[cfg(feature = "sqlx-postgres")]
    SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),

    /// SQLite database connection pool
    #[cfg(feature = "sqlx-sqlite")]
    SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection),

    /// SQLite database connection sharable across threads
    #[cfg(feature = "rusqlite")]
    RusqliteSharedConnection(RusqliteSharedConnection),

    /// Mock database connection useful for testing
    #[cfg(feature = "mock")]
    MockDatabaseConnection(Arc<crate::MockDatabaseConnection>),

    /// Proxy database connection
    #[cfg(feature = "proxy")]
    ProxyDatabaseConnection(Arc<crate::ProxyDatabaseConnection>),

    /// The connection has never been established.
    ///
    /// This is the variant produced by `DatabaseConnection::default()`.
    /// Queries, transactions, `ping` and `close_by_ref` on it return a
    /// "Disconnected" connection error, while
    /// `DatabaseConnection::get_database_backend` panics.
    Disconnected,
}
62
/// The same as a [DatabaseConnection]; a shorter alias for convenience.
pub type DbConn = DatabaseConnection;
65
66impl Default for DatabaseConnection {
67    fn default() -> Self {
68        DatabaseConnectionType::Disconnected.into()
69    }
70}
71
72impl From<DatabaseConnectionType> for DatabaseConnection {
73    fn from(inner: DatabaseConnectionType) -> Self {
74        Self {
75            inner,
76            #[cfg(feature = "rbac")]
77            rbac: Default::default(),
78        }
79    }
80}
81
/// The type of database backend for real world databases.
/// This is enabled by feature flags as specified in the crate documentation.
///
/// Note that both the SQLx SQLite pool and the rusqlite connection report
/// [`DatabaseBackend::Sqlite`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum DatabaseBackend {
    /// A MySQL backend
    MySql,
    /// A PostgreSQL backend
    Postgres,
    /// A SQLite backend
    Sqlite,
}
94
/// A shorthand for [DatabaseBackend].
pub type DbBackend = DatabaseBackend;
97
/// A single connection handle, one variant per enabled backend feature.
///
/// For the SQLx backends this wraps a `PoolConnection` checked out of a pool,
/// unlike [`DatabaseConnectionType`], which holds the pool itself.
/// NOTE(review): presumably used by transactions/streams that need a dedicated
/// connection — confirm against the callers outside this file.
#[derive(Debug)]
pub(crate) enum InnerConnection {
    #[cfg(feature = "sqlx-mysql")]
    MySql(PoolConnection<sqlx::MySql>),
    #[cfg(feature = "sqlx-postgres")]
    Postgres(PoolConnection<sqlx::Postgres>),
    #[cfg(feature = "sqlx-sqlite")]
    Sqlite(PoolConnection<sqlx::Sqlite>),
    #[cfg(feature = "rusqlite")]
    Rusqlite(RusqliteInnerConnection),
    #[cfg(feature = "mock")]
    Mock(Arc<crate::MockDatabaseConnection>),
    #[cfg(feature = "proxy")]
    Proxy(Arc<crate::ProxyDatabaseConnection>),
}
113
114impl Debug for DatabaseConnectionType {
115    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
116        write!(
117            f,
118            "{}",
119            match self {
120                #[cfg(feature = "sqlx-mysql")]
121                Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
122                #[cfg(feature = "sqlx-postgres")]
123                Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
124                #[cfg(feature = "sqlx-sqlite")]
125                Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection",
126                #[cfg(feature = "rusqlite")]
127                Self::RusqliteSharedConnection(_) => "RusqliteSharedConnection",
128                #[cfg(feature = "mock")]
129                Self::MockDatabaseConnection(_) => "MockDatabaseConnection",
130                #[cfg(feature = "proxy")]
131                Self::ProxyDatabaseConnection(_) => "ProxyDatabaseConnection",
132                Self::Disconnected => "Disconnected",
133            }
134        )
135    }
136}
137
138#[async_trait::async_trait]
139impl ConnectionTrait for DatabaseConnection {
140    fn get_database_backend(&self) -> DbBackend {
141        self.get_database_backend()
142    }
143
144    #[instrument(level = "trace")]
145    #[allow(unused_variables)]
146    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
147        super::tracing_spans::with_db_span!(
148            "sea_orm.execute",
149            self.get_database_backend(),
150            stmt.sql.as_str(),
151            record_stmt = true,
152            async {
153                match &self.inner {
154                    #[cfg(feature = "sqlx-mysql")]
155                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
156                        conn.execute(stmt).await
157                    }
158                    #[cfg(feature = "sqlx-postgres")]
159                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
160                        conn.execute(stmt).await
161                    }
162                    #[cfg(feature = "sqlx-sqlite")]
163                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
164                        conn.execute(stmt).await
165                    }
166                    #[cfg(feature = "rusqlite")]
167                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.execute(stmt),
168                    #[cfg(feature = "mock")]
169                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.execute(stmt),
170                    #[cfg(feature = "proxy")]
171                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
172                        conn.execute(stmt).await
173                    }
174                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
175                }
176            }
177        )
178    }
179
180    #[instrument(level = "trace")]
181    #[allow(unused_variables)]
182    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
183        super::tracing_spans::with_db_span!(
184            "sea_orm.execute_unprepared",
185            self.get_database_backend(),
186            sql,
187            record_stmt = false,
188            async {
189                match &self.inner {
190                    #[cfg(feature = "sqlx-mysql")]
191                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
192                        conn.execute_unprepared(sql).await
193                    }
194                    #[cfg(feature = "sqlx-postgres")]
195                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
196                        conn.execute_unprepared(sql).await
197                    }
198                    #[cfg(feature = "sqlx-sqlite")]
199                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
200                        conn.execute_unprepared(sql).await
201                    }
202                    #[cfg(feature = "rusqlite")]
203                    DatabaseConnectionType::RusqliteSharedConnection(conn) => {
204                        conn.execute_unprepared(sql)
205                    }
206                    #[cfg(feature = "mock")]
207                    DatabaseConnectionType::MockDatabaseConnection(conn) => {
208                        let db_backend = conn.get_database_backend();
209                        let stmt = Statement::from_string(db_backend, sql);
210                        conn.execute(stmt)
211                    }
212                    #[cfg(feature = "proxy")]
213                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
214                        let db_backend = conn.get_database_backend();
215                        let stmt = Statement::from_string(db_backend, sql);
216                        conn.execute(stmt).await
217                    }
218                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
219                }
220            }
221        )
222    }
223
224    #[instrument(level = "trace")]
225    #[allow(unused_variables)]
226    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
227        super::tracing_spans::with_db_span!(
228            "sea_orm.query_one",
229            self.get_database_backend(),
230            stmt.sql.as_str(),
231            record_stmt = true,
232            async {
233                match &self.inner {
234                    #[cfg(feature = "sqlx-mysql")]
235                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
236                        conn.query_one(stmt).await
237                    }
238                    #[cfg(feature = "sqlx-postgres")]
239                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
240                        conn.query_one(stmt).await
241                    }
242                    #[cfg(feature = "sqlx-sqlite")]
243                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
244                        conn.query_one(stmt).await
245                    }
246                    #[cfg(feature = "rusqlite")]
247                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_one(stmt),
248                    #[cfg(feature = "mock")]
249                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_one(stmt),
250                    #[cfg(feature = "proxy")]
251                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
252                        conn.query_one(stmt).await
253                    }
254                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
255                }
256            }
257        )
258    }
259
260    #[instrument(level = "trace")]
261    #[allow(unused_variables)]
262    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
263        super::tracing_spans::with_db_span!(
264            "sea_orm.query_all",
265            self.get_database_backend(),
266            stmt.sql.as_str(),
267            record_stmt = true,
268            async {
269                match &self.inner {
270                    #[cfg(feature = "sqlx-mysql")]
271                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
272                        conn.query_all(stmt).await
273                    }
274                    #[cfg(feature = "sqlx-postgres")]
275                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
276                        conn.query_all(stmt).await
277                    }
278                    #[cfg(feature = "sqlx-sqlite")]
279                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
280                        conn.query_all(stmt).await
281                    }
282                    #[cfg(feature = "rusqlite")]
283                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_all(stmt),
284                    #[cfg(feature = "mock")]
285                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_all(stmt),
286                    #[cfg(feature = "proxy")]
287                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
288                        conn.query_all(stmt).await
289                    }
290                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
291                }
292            }
293        )
294    }
295
296    #[cfg(feature = "mock")]
297    fn is_mock_connection(&self) -> bool {
298        matches!(
299            self,
300            DatabaseConnection {
301                inner: DatabaseConnectionType::MockDatabaseConnection(_),
302                ..
303            }
304        )
305    }
306}
307
308#[async_trait::async_trait]
309impl StreamTrait for DatabaseConnection {
310    type Stream<'a> = crate::QueryStream;
311
312    fn get_database_backend(&self) -> DbBackend {
313        self.get_database_backend()
314    }
315
316    #[instrument(level = "trace")]
317    #[allow(unused_variables)]
318    fn stream_raw<'a>(
319        &'a self,
320        stmt: Statement,
321    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
322        Box::pin(async move {
323            match &self.inner {
324                #[cfg(feature = "sqlx-mysql")]
325                DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await,
326                #[cfg(feature = "sqlx-postgres")]
327                DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await,
328                #[cfg(feature = "sqlx-sqlite")]
329                DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await,
330                #[cfg(feature = "rusqlite")]
331                DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.stream(stmt),
332                #[cfg(feature = "mock")]
333                DatabaseConnectionType::MockDatabaseConnection(conn) => {
334                    Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
335                }
336                #[cfg(feature = "proxy")]
337                DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
338                    Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
339                }
340                DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
341            }
342        })
343    }
344}
345
346#[async_trait::async_trait]
347impl TransactionTrait for DatabaseConnection {
348    type Transaction = DatabaseTransaction;
349
350    #[instrument(level = "trace")]
351    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
352        match &self.inner {
353            #[cfg(feature = "sqlx-mysql")]
354            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.begin(None, None).await,
355            #[cfg(feature = "sqlx-postgres")]
356            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
357                conn.begin(None, None).await
358            }
359            #[cfg(feature = "sqlx-sqlite")]
360            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
361                conn.begin(None, None, None).await
362            }
363            #[cfg(feature = "rusqlite")]
364            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.begin(None, None, None),
365            #[cfg(feature = "mock")]
366            DatabaseConnectionType::MockDatabaseConnection(conn) => {
367                DatabaseTransaction::new_mock(Arc::clone(conn), None).await
368            }
369            #[cfg(feature = "proxy")]
370            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
371                DatabaseTransaction::new_proxy(conn.clone(), None).await
372            }
373            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
374        }
375    }
376
377    #[instrument(level = "trace")]
378    async fn begin_with_config(
379        &self,
380        _isolation_level: Option<IsolationLevel>,
381        _access_mode: Option<AccessMode>,
382    ) -> Result<DatabaseTransaction, DbErr> {
383        match &self.inner {
384            #[cfg(feature = "sqlx-mysql")]
385            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
386                conn.begin(_isolation_level, _access_mode).await
387            }
388            #[cfg(feature = "sqlx-postgres")]
389            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
390                conn.begin(_isolation_level, _access_mode).await
391            }
392            #[cfg(feature = "sqlx-sqlite")]
393            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
394                conn.begin(_isolation_level, _access_mode, None).await
395            }
396            #[cfg(feature = "rusqlite")]
397            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
398                conn.begin(_isolation_level, _access_mode, None)
399            }
400            #[cfg(feature = "mock")]
401            DatabaseConnectionType::MockDatabaseConnection(conn) => {
402                DatabaseTransaction::new_mock(Arc::clone(conn), None).await
403            }
404            #[cfg(feature = "proxy")]
405            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
406                DatabaseTransaction::new_proxy(conn.clone(), None).await
407            }
408            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
409        }
410    }
411
412    #[instrument(level = "trace")]
413    async fn begin_with_options(
414        &self,
415        TransactionOptions {
416            isolation_level: _isolation_level,
417            access_mode: _access_mode,
418            sqlite_transaction_mode: _sqlite_transaction_mode,
419        }: TransactionOptions,
420    ) -> Result<DatabaseTransaction, DbErr> {
421        match &self.inner {
422            #[cfg(feature = "sqlx-mysql")]
423            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
424                conn.begin(_isolation_level, _access_mode).await
425            }
426            #[cfg(feature = "sqlx-postgres")]
427            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
428                conn.begin(_isolation_level, _access_mode).await
429            }
430            #[cfg(feature = "sqlx-sqlite")]
431            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
432                conn.begin(_isolation_level, _access_mode, _sqlite_transaction_mode)
433                    .await
434            }
435            #[cfg(feature = "rusqlite")]
436            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
437                conn.begin(_isolation_level, _access_mode, _sqlite_transaction_mode)
438            }
439            #[cfg(feature = "mock")]
440            DatabaseConnectionType::MockDatabaseConnection(conn) => {
441                DatabaseTransaction::new_mock(Arc::clone(conn), None).await
442            }
443            #[cfg(feature = "proxy")]
444            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
445                DatabaseTransaction::new_proxy(conn.clone(), None).await
446            }
447            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
448        }
449    }
450
451    /// Execute the function inside a transaction.
452    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
453    #[instrument(level = "trace", skip(_callback))]
454    async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
455    where
456        F: for<'c> FnOnce(
457                &'c DatabaseTransaction,
458            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
459            + Send,
460        T: Send,
461        E: std::fmt::Display + std::fmt::Debug + Send,
462    {
463        match &self.inner {
464            #[cfg(feature = "sqlx-mysql")]
465            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
466                conn.transaction(_callback, None, None).await
467            }
468            #[cfg(feature = "sqlx-postgres")]
469            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
470                conn.transaction(_callback, None, None).await
471            }
472            #[cfg(feature = "sqlx-sqlite")]
473            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
474                conn.transaction(_callback, None, None).await
475            }
476            #[cfg(feature = "rusqlite")]
477            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
478                conn.transaction(_callback, None, None)
479            }
480            #[cfg(feature = "mock")]
481            DatabaseConnectionType::MockDatabaseConnection(conn) => {
482                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
483                    .await
484                    .map_err(TransactionError::Connection)?;
485                transaction.run(_callback).await
486            }
487            #[cfg(feature = "proxy")]
488            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
489                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
490                    .await
491                    .map_err(TransactionError::Connection)?;
492                transaction.run(_callback).await
493            }
494            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
495        }
496    }
497
498    /// Execute the function inside a transaction.
499    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
500    #[instrument(level = "trace", skip(_callback))]
501    async fn transaction_with_config<F, T, E>(
502        &self,
503        _callback: F,
504        _isolation_level: Option<IsolationLevel>,
505        _access_mode: Option<AccessMode>,
506    ) -> Result<T, TransactionError<E>>
507    where
508        F: for<'c> FnOnce(
509                &'c DatabaseTransaction,
510            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
511            + Send,
512        T: Send,
513        E: std::fmt::Display + std::fmt::Debug + Send,
514    {
515        match &self.inner {
516            #[cfg(feature = "sqlx-mysql")]
517            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
518                conn.transaction(_callback, _isolation_level, _access_mode)
519                    .await
520            }
521            #[cfg(feature = "sqlx-postgres")]
522            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
523                conn.transaction(_callback, _isolation_level, _access_mode)
524                    .await
525            }
526            #[cfg(feature = "sqlx-sqlite")]
527            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
528                conn.transaction(_callback, _isolation_level, _access_mode)
529                    .await
530            }
531            #[cfg(feature = "rusqlite")]
532            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
533                conn.transaction(_callback, _isolation_level, _access_mode)
534            }
535            #[cfg(feature = "mock")]
536            DatabaseConnectionType::MockDatabaseConnection(conn) => {
537                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
538                    .await
539                    .map_err(TransactionError::Connection)?;
540                transaction.run(_callback).await
541            }
542            #[cfg(feature = "proxy")]
543            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
544                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
545                    .await
546                    .map_err(TransactionError::Connection)?;
547                transaction.run(_callback).await
548            }
549            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
550        }
551    }
552}
553
#[cfg(feature = "mock")]
impl DatabaseConnection {
    /// Borrow the underlying [crate::MockDatabaseConnection].
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a mock connection.
    pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection {
        match &self.inner {
            DatabaseConnectionType::MockDatabaseConnection(mock_conn) => mock_conn,
            _ => panic!("Not mock connection"),
        }
    }

    /// Drain the mocker's transaction log and return it as a
    /// `Vec<[crate::Transaction]>`.
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a mock connection, or if locking the mocker
    /// mutex returns an error (for a `std::sync::Mutex` that means it was
    /// poisoned). A lock merely held by another thread blocks rather than
    /// panics.
    pub fn into_transaction_log(self) -> Vec<crate::Transaction> {
        let mut mocker = self
            .as_mock_connection()
            .get_mocker_mutex()
            .lock()
            .expect("Fail to acquire mocker");
        mocker.drain_transaction_log()
    }
}
582
#[cfg(feature = "proxy")]
impl DatabaseConnection {
    /// Borrow the underlying [crate::ProxyDatabaseConnection].
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a proxy connection.
    pub fn as_proxy_connection(&self) -> &crate::ProxyDatabaseConnection {
        match &self.inner {
            DatabaseConnectionType::ProxyDatabaseConnection(proxy_conn) => proxy_conn,
            _ => panic!("Not proxy connection"),
        }
    }
}
597
#[cfg(feature = "rbac")]
impl DatabaseConnection {
    /// Load RBAC data from this very database and install the resulting engine,
    /// replacing any engine that was already installed.
    pub async fn load_rbac(&self) -> Result<(), DbErr> {
        let source: &DbConn = self;
        self.load_rbac_from(source).await
    }

    /// Load RBAC data from `db`, which may be a different database, and install
    /// the resulting engine, replacing any engine that was already installed.
    ///
    /// On a load error the existing engine is left untouched.
    pub async fn load_rbac_from(&self, db: &DbConn) -> Result<(), DbErr> {
        let loaded_engine = crate::rbac::RbacEngine::load_from(db).await?;
        self.rbac.replace(loaded_engine);
        Ok(())
    }

    /// Install `engine` as the RBAC engine, replacing the current one.
    pub fn replace_rbac(&self, engine: crate::rbac::RbacEngine) {
        self.rbac.replace(engine);
    }

    /// Create a restricted connection that enforces access control for `user_id`.
    ///
    /// # Errors
    ///
    /// Returns [`DbErr::RbacError`] if no RBAC engine has been set up yet.
    pub fn restricted_for(
        &self,
        user_id: crate::rbac::RbacUserId,
    ) -> Result<crate::RestrictedConnection, DbErr> {
        if !self.rbac.is_some() {
            return Err(DbErr::RbacError("engine not set up".into()));
        }
        let conn = self.clone();
        Ok(crate::RestrictedConnection { user_id, conn })
    }
}
634
635impl DatabaseConnection {
636    /// Get the database backend for this connection
637    ///
638    /// # Panics
639    ///
640    /// Panics if [DatabaseConnection] is `Disconnected`.
641    pub fn get_database_backend(&self) -> DbBackend {
642        match &self.inner {
643            #[cfg(feature = "sqlx-mysql")]
644            DatabaseConnectionType::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
645            #[cfg(feature = "sqlx-postgres")]
646            DatabaseConnectionType::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
647            #[cfg(feature = "sqlx-sqlite")]
648            DatabaseConnectionType::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite,
649            #[cfg(feature = "rusqlite")]
650            DatabaseConnectionType::RusqliteSharedConnection(_) => DbBackend::Sqlite,
651            #[cfg(feature = "mock")]
652            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.get_database_backend(),
653            #[cfg(feature = "proxy")]
654            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.get_database_backend(),
655            DatabaseConnectionType::Disconnected => panic!("Disconnected"),
656        }
657    }
658
659    /// Creates a [`SchemaBuilder`] for this backend
660    pub fn get_schema_builder(&self) -> SchemaBuilder {
661        Schema::new(self.get_database_backend()).builder()
662    }
663
664    #[cfg(feature = "entity-registry")]
665    #[cfg_attr(docsrs, doc(cfg(feature = "entity-registry")))]
666    /// Builds a schema for all the entites in the given module
667    pub fn get_schema_registry(&self, prefix: &str) -> SchemaBuilder {
668        let schema = Schema::new(self.get_database_backend());
669        crate::EntityRegistry::build_schema(schema, prefix)
670    }
671
672    /// Sets a callback to metric this connection
673    pub fn set_metric_callback<F>(&mut self, _callback: F)
674    where
675        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
676    {
677        match &mut self.inner {
678            #[cfg(feature = "sqlx-mysql")]
679            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
680                conn.set_metric_callback(_callback)
681            }
682            #[cfg(feature = "sqlx-postgres")]
683            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
684                conn.set_metric_callback(_callback)
685            }
686            #[cfg(feature = "sqlx-sqlite")]
687            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
688                conn.set_metric_callback(_callback)
689            }
690            #[cfg(feature = "rusqlite")]
691            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
692                conn.set_metric_callback(_callback)
693            }
694            _ => {}
695        }
696    }
697
698    /// Checks if a connection to the database is still valid.
699    pub async fn ping(&self) -> Result<(), DbErr> {
700        match &self.inner {
701            #[cfg(feature = "sqlx-mysql")]
702            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.ping().await,
703            #[cfg(feature = "sqlx-postgres")]
704            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.ping().await,
705            #[cfg(feature = "sqlx-sqlite")]
706            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.ping().await,
707            #[cfg(feature = "rusqlite")]
708            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.ping(),
709            #[cfg(feature = "mock")]
710            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.ping(),
711            #[cfg(feature = "proxy")]
712            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.ping().await,
713            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
714        }
715    }
716
717    /// Explicitly close the database connection.
718    /// See [`Self::close_by_ref`] for usage with references.
719    pub async fn close(self) -> Result<(), DbErr> {
720        self.close_by_ref().await
721    }
722
723    /// Explicitly close the database connection
724    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
725        match &self.inner {
726            #[cfg(feature = "sqlx-mysql")]
727            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.close_by_ref().await,
728            #[cfg(feature = "sqlx-postgres")]
729            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.close_by_ref().await,
730            #[cfg(feature = "sqlx-sqlite")]
731            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.close_by_ref().await,
732            #[cfg(feature = "rusqlite")]
733            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.close_by_ref(),
734            #[cfg(feature = "mock")]
735            DatabaseConnectionType::MockDatabaseConnection(_) => {
736                // Nothing to cleanup, we just consume the `DatabaseConnection`
737                Ok(())
738            }
739            #[cfg(feature = "proxy")]
740            DatabaseConnectionType::ProxyDatabaseConnection(_) => {
741                // Nothing to cleanup, we just consume the `DatabaseConnection`
742                Ok(())
743            }
744            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
745        }
746    }
747}
748
749impl DatabaseConnection {
750    /// Get [sqlx::MySqlPool]
751    ///
752    /// # Panics
753    ///
754    /// Panics if [DbConn] is not a MySQL connection.
755    #[cfg(feature = "sqlx-mysql")]
756    pub fn get_mysql_connection_pool(&self) -> &sqlx::MySqlPool {
757        match &self.inner {
758            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => &conn.pool,
759            _ => panic!("Not MySQL Connection"),
760        }
761    }
762
763    /// Get [sqlx::PgPool]
764    ///
765    /// # Panics
766    ///
767    /// Panics if [DbConn] is not a Postgres connection.
768    #[cfg(feature = "sqlx-postgres")]
769    pub fn get_postgres_connection_pool(&self) -> &sqlx::PgPool {
770        match &self.inner {
771            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => &conn.pool,
772            _ => panic!("Not Postgres Connection"),
773        }
774    }
775
776    /// Get [sqlx::SqlitePool]
777    ///
778    /// # Panics
779    ///
780    /// Panics if [DbConn] is not a SQLite connection.
781    #[cfg(feature = "sqlx-sqlite")]
782    pub fn get_sqlite_connection_pool(&self) -> &sqlx::SqlitePool {
783        match &self.inner {
784            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => &conn.pool,
785            _ => panic!("Not SQLite Connection"),
786        }
787    }
788}
789
790impl DbBackend {
791    /// Check if the URI is the same as the specified database backend.
792    /// Returns true if they match.
793    ///
794    /// # Panics
795    ///
796    /// Panics if `base_url` cannot be parsed as `Url`.
797    pub fn is_prefix_of(self, base_url: &str) -> bool {
798        let base_url_parsed = Url::parse(base_url).expect("Fail to parse database URL");
799        match self {
800            Self::Postgres => {
801                base_url_parsed.scheme() == "postgres" || base_url_parsed.scheme() == "postgresql"
802            }
803            Self::MySql => base_url_parsed.scheme() == "mysql",
804            Self::Sqlite => base_url_parsed.scheme() == "sqlite",
805        }
806    }
807
808    /// Build an SQL [Statement]
809    pub fn build<S>(&self, statement: &S) -> Statement
810    where
811        S: StatementBuilder,
812    {
813        statement.build(self)
814    }
815
816    /// Check if the database supports `RETURNING` syntax on insert and update
817    pub fn support_returning(&self) -> bool {
818        match self {
819            Self::Postgres => true,
820            Self::Sqlite if cfg!(feature = "sqlite-use-returning-for-3_35") => true,
821            Self::MySql if cfg!(feature = "mariadb-use-returning") => true,
822            _ => false,
823        }
824    }
825
826    /// A getter for database dependent boolean value
827    pub fn boolean_value(&self, boolean: bool) -> sea_query::Value {
828        match self {
829            Self::MySql | Self::Postgres | Self::Sqlite => boolean.into(),
830        }
831    }
832
833    /// Get the display string for this enum
834    pub fn as_str(&self) -> &'static str {
835        match self {
836            DatabaseBackend::MySql => "MySql",
837            DatabaseBackend::Postgres => "Postgres",
838            DatabaseBackend::Sqlite => "Sqlite",
839        }
840    }
841}
842
#[cfg(test)]
mod tests {
    use crate::DatabaseConnection;

    /// Compile-time check that `DatabaseConnection` is `Send + Sync` when the
    /// `sync` feature is disabled. The test body does nothing at runtime.
    #[cfg(not(feature = "sync"))]
    #[test]
    fn assert_database_connection_traits() {
        fn require_send_sync<T: Send + Sync>() {}

        require_send_sync::<DatabaseConnection>();
    }
}