Skip to main content

sea_orm/database/
db_connection.rs

1use super::transaction::run_async_transaction_callback;
2use crate::{
3    AccessMode, ConnectionTrait, DatabaseTransaction, ExecResult, IsolationLevel, QueryResult,
4    Schema, SchemaBuilder, Statement, StatementBuilder, TransactionError, TransactionOptions,
5    TransactionTrait, error::*,
6};
7use std::fmt::Debug;
8use tracing::instrument;
9use url::Url;
10
11#[cfg(feature = "sqlx-dep")]
12use sqlx::pool::PoolConnection;
13
14#[cfg(feature = "rusqlite")]
15use crate::driver::rusqlite::{RusqliteInnerConnection, RusqliteSharedConnection};
16
17#[cfg(feature = "stream")]
18use crate::StreamTrait;
19
20#[cfg(any(feature = "mock", feature = "proxy"))]
21use std::sync::Arc;
22
/// A handle to a database — implements [`ConnectionTrait`](crate::ConnectionTrait)
/// and [`TransactionTrait`](crate::TransactionTrait) so it works with every
/// query and mutation method in SeaORM.
///
/// Behind the scenes this is a connection pool (for SQLx-backed drivers) or
/// a shared connection (for `rusqlite` / mocks / proxies), so it is cheap
/// to clone — pass `&DbConn` around or `db.clone()` into spawned tasks.
/// Obtain one via [`Database::connect`](crate::Database::connect).
///
/// The [`Default`] value wraps [`DatabaseConnectionType::Disconnected`]:
/// queries against it return an error, and
/// [`get_database_backend`](Self::get_database_backend) panics.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct DatabaseConnection {
    /// Driver-specific connection or pool. Held in a field so we can attach
    /// orthogonal state (e.g. RBAC) alongside.
    pub inner: DatabaseConnectionType,
    /// RBAC engine slot, compiled in only with the `rbac` feature. A
    /// connection converted from a [`DatabaseConnectionType`] starts with
    /// this field's `Default` value.
    #[cfg(feature = "rbac")]
    pub(crate) rbac: crate::RbacEngineMount,
}
40
/// The driver-specific connection or pool wrapped by [`DatabaseConnection`].
///
/// Which variants are available depends on enabled feature flags. End users
/// rarely match on this directly; use [`DatabaseConnection`]'s methods
/// instead.
///
/// The hand-written `Debug` implementation prints only the variant name, not
/// the wrapped pool or connection.
#[derive(Clone)]
pub enum DatabaseConnectionType {
    /// MySQL connection pool (`sqlx-mysql`).
    #[cfg(feature = "sqlx-mysql")]
    SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),

    /// PostgreSQL connection pool (`sqlx-postgres`).
    #[cfg(feature = "sqlx-postgres")]
    SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),

    /// SQLite connection pool (`sqlx-sqlite`).
    #[cfg(feature = "sqlx-sqlite")]
    SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection),

    /// SQLite connection shared across threads (`rusqlite`).
    #[cfg(feature = "rusqlite")]
    RusqliteSharedConnection(RusqliteSharedConnection),

    /// In-memory mock connection used for testing (`mock`).
    #[cfg(feature = "mock")]
    MockDatabaseConnection(Arc<crate::MockDatabaseConnection>),

    /// Proxy connection that forwards statements to a user callback (`proxy`).
    #[cfg(feature = "proxy")]
    ProxyDatabaseConnection(Arc<crate::ProxyDatabaseConnection>),

    /// Sentinel for an unconnected [`DatabaseConnection`] (default value);
    /// any query against it returns an error. This is the only variant that
    /// exists regardless of feature flags.
    Disconnected,
}
76
/// Short alias for [`DatabaseConnection`]; the two names are interchangeable.
pub type DbConn = DatabaseConnection;
79
80impl Default for DatabaseConnection {
81    fn default() -> Self {
82        DatabaseConnectionType::Disconnected.into()
83    }
84}
85
86impl From<DatabaseConnectionType> for DatabaseConnection {
87    fn from(inner: DatabaseConnectionType) -> Self {
88        Self {
89            inner,
90            #[cfg(feature = "rbac")]
91            rbac: Default::default(),
92        }
93    }
94}
95
/// Identifies which SQL dialect is in use. Passed around so that
/// `sea_query`-built statements can be rendered with the right placeholders,
/// quoting, and feature support.
///
/// All three variants are declared unconditionally here, independent of which
/// driver features are enabled. The enum is `#[non_exhaustive]`, so matches
/// outside this crate need a wildcard arm.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum DatabaseBackend {
    /// MySQL / MariaDB.
    MySql,
    /// PostgreSQL.
    Postgres,
    /// SQLite.
    Sqlite,
}
110
/// Short alias for [`DatabaseBackend`]; the two names are interchangeable.
pub type DbBackend = DatabaseBackend;
113
/// Crate-internal, backend-specific connection handle.
///
/// The SQLx variants wrap a `sqlx::pool::PoolConnection` for the matching
/// driver; the mock and proxy variants hold an `Arc` of the same connection
/// types that [`DatabaseConnectionType`] stores. Every variant is gated by
/// its feature flag, and there is no `Disconnected` counterpart.
// NOTE(review): presumably this is the connection a transaction holds for
// its lifetime — confirm against `DatabaseTransaction`.
#[derive(Debug)]
pub(crate) enum InnerConnection {
    /// Single MySQL connection taken from the SQLx pool (`sqlx-mysql`).
    #[cfg(feature = "sqlx-mysql")]
    MySql(PoolConnection<sqlx::MySql>),
    /// Single PostgreSQL connection taken from the SQLx pool (`sqlx-postgres`).
    #[cfg(feature = "sqlx-postgres")]
    Postgres(PoolConnection<sqlx::Postgres>),
    /// Single SQLite connection taken from the SQLx pool (`sqlx-sqlite`).
    #[cfg(feature = "sqlx-sqlite")]
    Sqlite(PoolConnection<sqlx::Sqlite>),
    /// `rusqlite` connection (`rusqlite`).
    #[cfg(feature = "rusqlite")]
    Rusqlite(RusqliteInnerConnection),
    /// Shared mock connection (`mock`).
    #[cfg(feature = "mock")]
    Mock(Arc<crate::MockDatabaseConnection>),
    /// Shared proxy connection (`proxy`).
    #[cfg(feature = "proxy")]
    Proxy(Arc<crate::ProxyDatabaseConnection>),
}
129
130impl Debug for DatabaseConnectionType {
131    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
132        write!(
133            f,
134            "{}",
135            match self {
136                #[cfg(feature = "sqlx-mysql")]
137                Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
138                #[cfg(feature = "sqlx-postgres")]
139                Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
140                #[cfg(feature = "sqlx-sqlite")]
141                Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection",
142                #[cfg(feature = "rusqlite")]
143                Self::RusqliteSharedConnection(_) => "RusqliteSharedConnection",
144                #[cfg(feature = "mock")]
145                Self::MockDatabaseConnection(_) => "MockDatabaseConnection",
146                #[cfg(feature = "proxy")]
147                Self::ProxyDatabaseConnection(_) => "ProxyDatabaseConnection",
148                Self::Disconnected => "Disconnected",
149            }
150        )
151    }
152}
153
154impl ConnectionTrait for DatabaseConnection {
155    fn get_database_backend(&self) -> DbBackend {
156        self.get_database_backend()
157    }
158
159    #[instrument(level = "trace", skip(stmt))]
160    #[allow(unused_variables)]
161    fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
162        super::tracing_spans::with_db_span!(
163            "sea_orm.execute",
164            self.get_database_backend(),
165            stmt.sql.as_str(),
166            record_stmt = self.get_record_stmt_in_spans(),
167            {
168                match &self.inner {
169                    #[cfg(feature = "sqlx-mysql")]
170                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.execute(stmt),
171                    #[cfg(feature = "sqlx-postgres")]
172                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.execute(stmt),
173                    #[cfg(feature = "sqlx-sqlite")]
174                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.execute(stmt),
175                    #[cfg(feature = "rusqlite")]
176                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.execute(stmt),
177                    #[cfg(feature = "mock")]
178                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.execute(stmt),
179                    #[cfg(feature = "proxy")]
180                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.execute(stmt),
181                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
182                }
183            }
184        )
185    }
186
187    #[instrument(level = "trace", skip(sql))]
188    #[allow(unused_variables)]
189    fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
190        super::tracing_spans::with_db_span!(
191            "sea_orm.execute_unprepared",
192            self.get_database_backend(),
193            sql,
194            record_stmt = false,
195            {
196                match &self.inner {
197                    #[cfg(feature = "sqlx-mysql")]
198                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
199                        conn.execute_unprepared(sql)
200                    }
201                    #[cfg(feature = "sqlx-postgres")]
202                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
203                        conn.execute_unprepared(sql)
204                    }
205                    #[cfg(feature = "sqlx-sqlite")]
206                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
207                        conn.execute_unprepared(sql)
208                    }
209                    #[cfg(feature = "rusqlite")]
210                    DatabaseConnectionType::RusqliteSharedConnection(conn) => {
211                        conn.execute_unprepared(sql)
212                    }
213                    #[cfg(feature = "mock")]
214                    DatabaseConnectionType::MockDatabaseConnection(conn) => {
215                        let db_backend = conn.get_database_backend();
216                        let stmt = Statement::from_string(db_backend, sql);
217                        conn.execute(stmt)
218                    }
219                    #[cfg(feature = "proxy")]
220                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
221                        let db_backend = conn.get_database_backend();
222                        let stmt = Statement::from_string(db_backend, sql);
223                        conn.execute(stmt)
224                    }
225                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
226                }
227            }
228        )
229    }
230
231    #[instrument(level = "trace", skip(stmt))]
232    #[allow(unused_variables)]
233    fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
234        super::tracing_spans::with_db_span!(
235            "sea_orm.query_one",
236            self.get_database_backend(),
237            stmt.sql.as_str(),
238            record_stmt = self.get_record_stmt_in_spans(),
239            {
240                match &self.inner {
241                    #[cfg(feature = "sqlx-mysql")]
242                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt),
243                    #[cfg(feature = "sqlx-postgres")]
244                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
245                        conn.query_one(stmt)
246                    }
247                    #[cfg(feature = "sqlx-sqlite")]
248                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt),
249                    #[cfg(feature = "rusqlite")]
250                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_one(stmt),
251                    #[cfg(feature = "mock")]
252                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_one(stmt),
253                    #[cfg(feature = "proxy")]
254                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.query_one(stmt),
255                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
256                }
257            }
258        )
259    }
260
261    #[instrument(level = "trace", skip(stmt))]
262    #[allow(unused_variables)]
263    fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
264        super::tracing_spans::with_db_span!(
265            "sea_orm.query_all",
266            self.get_database_backend(),
267            stmt.sql.as_str(),
268            record_stmt = self.get_record_stmt_in_spans(),
269            {
270                match &self.inner {
271                    #[cfg(feature = "sqlx-mysql")]
272                    DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt),
273                    #[cfg(feature = "sqlx-postgres")]
274                    DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
275                        conn.query_all(stmt)
276                    }
277                    #[cfg(feature = "sqlx-sqlite")]
278                    DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt),
279                    #[cfg(feature = "rusqlite")]
280                    DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.query_all(stmt),
281                    #[cfg(feature = "mock")]
282                    DatabaseConnectionType::MockDatabaseConnection(conn) => conn.query_all(stmt),
283                    #[cfg(feature = "proxy")]
284                    DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.query_all(stmt),
285                    DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
286                }
287            }
288        )
289    }
290
291    #[cfg(feature = "mock")]
292    fn is_mock_connection(&self) -> bool {
293        matches!(
294            self,
295            DatabaseConnection {
296                inner: DatabaseConnectionType::MockDatabaseConnection(_),
297                ..
298            }
299        )
300    }
301}
302
#[cfg(feature = "stream")]
impl StreamTrait for DatabaseConnection {
    type Stream<'a> = crate::QueryStream;

    /// Returns the SQL dialect of the underlying connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection is `Disconnected`.
    fn get_database_backend(&self) -> DbBackend {
        // Not recursive: method-call syntax resolves to the inherent
        // `DatabaseConnection::get_database_backend` ahead of this trait method.
        self.get_database_backend()
    }

    /// Runs `stmt` and returns a [`QueryStream`](crate::QueryStream) over the
    /// result.
    ///
    /// Mock and proxy connections build the stream from a clone of their
    /// shared handle; a `Disconnected` handle yields a connection error.
    #[instrument(level = "trace", skip(stmt))]
    #[allow(unused_variables)]
    fn stream_raw<'a>(&'a self, stmt: Statement) -> Result<Self::Stream<'a>, DbErr> {
        match &self.inner {
            #[cfg(feature = "sqlx-mysql")]
            DatabaseConnectionType::SqlxMySqlPoolConnection(pool) => pool.stream(stmt),
            #[cfg(feature = "sqlx-postgres")]
            DatabaseConnectionType::SqlxPostgresPoolConnection(pool) => pool.stream(stmt),
            #[cfg(feature = "sqlx-sqlite")]
            DatabaseConnectionType::SqlxSqlitePoolConnection(pool) => pool.stream(stmt),
            #[cfg(feature = "rusqlite")]
            DatabaseConnectionType::RusqliteSharedConnection(shared) => shared.stream(stmt),
            #[cfg(feature = "mock")]
            DatabaseConnectionType::MockDatabaseConnection(mock) => {
                Ok(crate::QueryStream::from((Arc::clone(mock), stmt, None)))
            }
            #[cfg(feature = "proxy")]
            DatabaseConnectionType::ProxyDatabaseConnection(proxy) => {
                Ok(crate::QueryStream::from((Arc::clone(proxy), stmt, None)))
            }
            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
        }
    }
}
337
338impl TransactionTrait for DatabaseConnection {
339    type Transaction = DatabaseTransaction;
340
341    #[instrument(level = "trace")]
342    fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
343        match &self.inner {
344            #[cfg(feature = "sqlx-mysql")]
345            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.begin(None, None),
346            #[cfg(feature = "sqlx-postgres")]
347            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.begin(None, None),
348            #[cfg(feature = "sqlx-sqlite")]
349            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.begin(None, None, None),
350            #[cfg(feature = "rusqlite")]
351            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.begin(None, None, None),
352            #[cfg(feature = "mock")]
353            DatabaseConnectionType::MockDatabaseConnection(conn) => {
354                DatabaseTransaction::new_mock(Arc::clone(conn), None)
355            }
356            #[cfg(feature = "proxy")]
357            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
358                DatabaseTransaction::new_proxy(conn.clone(), None)
359            }
360            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
361        }
362    }
363
364    #[instrument(level = "trace")]
365    fn begin_with_config(
366        &self,
367        _isolation_level: Option<IsolationLevel>,
368        _access_mode: Option<AccessMode>,
369    ) -> Result<DatabaseTransaction, DbErr> {
370        match &self.inner {
371            #[cfg(feature = "sqlx-mysql")]
372            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
373                conn.begin(_isolation_level, _access_mode)
374            }
375            #[cfg(feature = "sqlx-postgres")]
376            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
377                conn.begin(_isolation_level, _access_mode)
378            }
379            #[cfg(feature = "sqlx-sqlite")]
380            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
381                conn.begin(_isolation_level, _access_mode, None)
382            }
383            #[cfg(feature = "rusqlite")]
384            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
385                conn.begin(_isolation_level, _access_mode, None)
386            }
387            #[cfg(feature = "mock")]
388            DatabaseConnectionType::MockDatabaseConnection(conn) => {
389                DatabaseTransaction::new_mock(Arc::clone(conn), None)
390            }
391            #[cfg(feature = "proxy")]
392            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
393                DatabaseTransaction::new_proxy(conn.clone(), None)
394            }
395            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
396        }
397    }
398
399    #[instrument(level = "trace")]
400    fn begin_with_options(
401        &self,
402        TransactionOptions {
403            isolation_level: _isolation_level,
404            access_mode: _access_mode,
405            sqlite_transaction_mode: _sqlite_transaction_mode,
406        }: TransactionOptions,
407    ) -> Result<DatabaseTransaction, DbErr> {
408        match &self.inner {
409            #[cfg(feature = "sqlx-mysql")]
410            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
411                conn.begin(_isolation_level, _access_mode)
412            }
413            #[cfg(feature = "sqlx-postgres")]
414            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
415                conn.begin(_isolation_level, _access_mode)
416            }
417            #[cfg(feature = "sqlx-sqlite")]
418            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
419                conn.begin(_isolation_level, _access_mode, _sqlite_transaction_mode)
420            }
421            #[cfg(feature = "rusqlite")]
422            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
423                conn.begin(_isolation_level, _access_mode, _sqlite_transaction_mode)
424            }
425            #[cfg(feature = "mock")]
426            DatabaseConnectionType::MockDatabaseConnection(conn) => {
427                DatabaseTransaction::new_mock(Arc::clone(conn), None)
428            }
429            #[cfg(feature = "proxy")]
430            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
431                DatabaseTransaction::new_proxy(conn.clone(), None)
432            }
433            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
434        }
435    }
436
437    /// Execute the function inside a transaction.
438    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
439    #[instrument(level = "trace", skip(_callback))]
440    fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
441    where
442        F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
443        E: std::fmt::Display + std::fmt::Debug,
444    {
445        match &self.inner {
446            #[cfg(feature = "sqlx-mysql")]
447            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
448                conn.transaction(_callback, None, None)
449            }
450            #[cfg(feature = "sqlx-postgres")]
451            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
452                conn.transaction(_callback, None, None)
453            }
454            #[cfg(feature = "sqlx-sqlite")]
455            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
456                conn.transaction(_callback, None, None)
457            }
458            #[cfg(feature = "rusqlite")]
459            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
460                conn.transaction(_callback, None, None)
461            }
462            #[cfg(feature = "mock")]
463            DatabaseConnectionType::MockDatabaseConnection(conn) => {
464                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
465                    .map_err(TransactionError::Connection)?;
466                transaction.run(_callback)
467            }
468            #[cfg(feature = "proxy")]
469            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
470                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
471                    .map_err(TransactionError::Connection)?;
472                transaction.run(_callback)
473            }
474            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
475        }
476    }
477
478    /// Execute the function inside a transaction.
479    /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
480    #[instrument(level = "trace", skip(_callback))]
481    fn transaction_with_config<F, T, E>(
482        &self,
483        _callback: F,
484        _isolation_level: Option<IsolationLevel>,
485        _access_mode: Option<AccessMode>,
486    ) -> Result<T, TransactionError<E>>
487    where
488        F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
489        E: std::fmt::Display + std::fmt::Debug,
490    {
491        match &self.inner {
492            #[cfg(feature = "sqlx-mysql")]
493            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
494                conn.transaction(_callback, _isolation_level, _access_mode)
495            }
496            #[cfg(feature = "sqlx-postgres")]
497            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
498                conn.transaction(_callback, _isolation_level, _access_mode)
499            }
500            #[cfg(feature = "sqlx-sqlite")]
501            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
502                conn.transaction(_callback, _isolation_level, _access_mode)
503            }
504            #[cfg(feature = "rusqlite")]
505            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
506                conn.transaction(_callback, _isolation_level, _access_mode)
507            }
508            #[cfg(feature = "mock")]
509            DatabaseConnectionType::MockDatabaseConnection(conn) => {
510                let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
511                    .map_err(TransactionError::Connection)?;
512                transaction.run(_callback)
513            }
514            #[cfg(feature = "proxy")]
515            DatabaseConnectionType::ProxyDatabaseConnection(conn) => {
516                let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
517                    .map_err(TransactionError::Connection)?;
518                transaction.run(_callback)
519            }
520            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected").into()),
521        }
522    }
523}
524
#[cfg(feature = "mock")]
impl DatabaseConnection {
    /// Borrows the underlying mock connection.
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a mock connection.
    pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection {
        let DatabaseConnectionType::MockDatabaseConnection(mock_conn) = &self.inner else {
            panic!("Not mock connection")
        };
        mock_conn
    }

    /// Consumes the connection and drains the mock's transaction log into a
    /// `Vec` of [`crate::Transaction`].
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a mock connection, or if locking the mocker
    /// mutex returns an error.
    pub fn into_transaction_log(self) -> Vec<crate::Transaction> {
        let mock = self.as_mock_connection();
        let mut mocker = mock
            .get_mocker_mutex()
            .lock()
            .expect("Fail to acquire mocker");
        mocker.drain_transaction_log()
    }
}
553
#[cfg(feature = "proxy")]
impl DatabaseConnection {
    /// Borrows the underlying proxy connection.
    ///
    /// # Panics
    ///
    /// Panics if [DbConn] is not a proxy connection.
    pub fn as_proxy_connection(&self) -> &crate::ProxyDatabaseConnection {
        let DatabaseConnectionType::ProxyDatabaseConnection(proxy_conn) = &self.inner else {
            panic!("Not proxy connection")
        };
        proxy_conn
    }
}
568
#[cfg(feature = "rbac")]
impl DatabaseConnection {
    /// Loads RBAC data through this very connection and installs the
    /// resulting engine, replacing any engine that was already set.
    pub fn load_rbac(&self) -> Result<(), DbErr> {
        self.load_rbac_from(self)
    }

    /// Loads RBAC data through `db` — which may point at a different
    /// database — and installs the resulting engine on this connection.
    ///
    /// # Errors
    ///
    /// Propagates the error from `RbacEngine::load_from`; the current engine
    /// is left untouched in that case.
    pub fn load_rbac_from(&self, db: &DbConn) -> Result<(), DbErr> {
        self.rbac.replace(crate::rbac::RbacEngine::load_from(db)?);
        Ok(())
    }

    /// Installs `engine` as this connection's RBAC engine.
    pub fn replace_rbac(&self, engine: crate::rbac::RbacEngine) {
        self.rbac.replace(engine);
    }

    /// Creates a restricted connection whose access control is evaluated for
    /// `user_id`. The restricted connection holds a clone of this handle.
    ///
    /// # Errors
    ///
    /// Returns [`DbErr::RbacError`] when no RBAC engine has been set up.
    pub fn restricted_for(
        &self,
        user_id: crate::rbac::RbacUserId,
    ) -> Result<crate::RestrictedConnection, DbErr> {
        if !self.rbac.is_some() {
            return Err(DbErr::RbacError("engine not set up".into()));
        }
        Ok(crate::RestrictedConnection {
            user_id,
            conn: self.clone(),
        })
    }
}
605
606impl DatabaseConnection {
607    /// Execute the function inside a transaction.
608    /// If the function returns an error, the transaction will be rolled back.
609    /// Otherwise, the transaction will be committed.
610    #[instrument(level = "trace", skip(callback))]
611    pub fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
612    where
613        F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
614        E: std::fmt::Display + std::fmt::Debug,
615    {
616        let transaction = self.begin().map_err(TransactionError::Connection)?;
617        run_async_transaction_callback(transaction, callback)
618    }
619
620    /// Execute the function inside a transaction with isolation level and/or access mode.
621    /// If the function returns an error, the transaction will be rolled back.
622    /// Otherwise, the transaction will be committed.
623    #[instrument(level = "trace", skip(callback))]
624    pub fn transaction_with_config<F, T, E>(
625        &self,
626        callback: F,
627        isolation_level: Option<IsolationLevel>,
628        access_mode: Option<AccessMode>,
629    ) -> Result<T, TransactionError<E>>
630    where
631        F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
632        E: std::fmt::Display + std::fmt::Debug,
633    {
634        let transaction = self
635            .begin_with_config(isolation_level, access_mode)
636            .map_err(TransactionError::Connection)?;
637        run_async_transaction_callback(transaction, callback)
638    }
639
640    #[allow(unused)]
641    pub(crate) fn get_record_stmt_in_spans(&self) -> bool {
642        match &self.inner {
643            #[cfg(feature = "sqlx-mysql")]
644            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.record_stmt_in_spans,
645            #[cfg(feature = "sqlx-postgres")]
646            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.record_stmt_in_spans,
647            #[cfg(feature = "sqlx-sqlite")]
648            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.record_stmt_in_spans,
649            #[cfg(feature = "rusqlite")]
650            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.record_stmt_in_spans,
651            DatabaseConnectionType::Disconnected => true,
652            #[cfg(feature = "mock")]
653            DatabaseConnectionType::MockDatabaseConnection(_) => true,
654            #[cfg(feature = "proxy")]
655            DatabaseConnectionType::ProxyDatabaseConnection(_) => true,
656        }
657    }
658
659    /// Get the database backend for this connection
660    ///
661    /// # Panics
662    ///
663    /// Panics if [DatabaseConnection] is `Disconnected`.
664    pub fn get_database_backend(&self) -> DbBackend {
665        match &self.inner {
666            #[cfg(feature = "sqlx-mysql")]
667            DatabaseConnectionType::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
668            #[cfg(feature = "sqlx-postgres")]
669            DatabaseConnectionType::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
670            #[cfg(feature = "sqlx-sqlite")]
671            DatabaseConnectionType::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite,
672            #[cfg(feature = "rusqlite")]
673            DatabaseConnectionType::RusqliteSharedConnection(_) => DbBackend::Sqlite,
674            #[cfg(feature = "mock")]
675            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.get_database_backend(),
676            #[cfg(feature = "proxy")]
677            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.get_database_backend(),
678            DatabaseConnectionType::Disconnected => panic!("Disconnected"),
679        }
680    }
681
682    /// Creates a [`SchemaBuilder`] for this backend
683    pub fn get_schema_builder(&self) -> SchemaBuilder {
684        Schema::new(self.get_database_backend()).builder()
685    }
686
687    #[cfg(feature = "entity-registry")]
688    #[cfg_attr(docsrs, doc(cfg(feature = "entity-registry")))]
689    /// Builds a schema for all the entites in the given module
690    pub fn get_schema_registry(&self, prefix: &str) -> SchemaBuilder {
691        let schema = Schema::new(self.get_database_backend());
692        crate::EntityRegistry::build_schema(schema, prefix)
693    }
694
695    /// Sets a callback to metric this connection
696    pub fn set_metric_callback<F>(&mut self, _callback: F)
697    where
698        F: Fn(&crate::metric::Info<'_>) + 'static,
699    {
700        match &mut self.inner {
701            #[cfg(feature = "sqlx-mysql")]
702            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
703                conn.set_metric_callback(_callback)
704            }
705            #[cfg(feature = "sqlx-postgres")]
706            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
707                conn.set_metric_callback(_callback)
708            }
709            #[cfg(feature = "sqlx-sqlite")]
710            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
711                conn.set_metric_callback(_callback)
712            }
713            #[cfg(feature = "rusqlite")]
714            DatabaseConnectionType::RusqliteSharedConnection(conn) => {
715                conn.set_metric_callback(_callback)
716            }
717            _ => {}
718        }
719    }
720
721    /// Checks if a connection to the database is still valid.
722    pub fn ping(&self) -> Result<(), DbErr> {
723        match &self.inner {
724            #[cfg(feature = "sqlx-mysql")]
725            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.ping(),
726            #[cfg(feature = "sqlx-postgres")]
727            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.ping(),
728            #[cfg(feature = "sqlx-sqlite")]
729            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.ping(),
730            #[cfg(feature = "rusqlite")]
731            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.ping(),
732            #[cfg(feature = "mock")]
733            DatabaseConnectionType::MockDatabaseConnection(conn) => conn.ping(),
734            #[cfg(feature = "proxy")]
735            DatabaseConnectionType::ProxyDatabaseConnection(conn) => conn.ping(),
736            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
737        }
738    }
739
740    /// Explicitly close the database connection.
741    /// See [`Self::close_by_ref`] for usage with references.
742    pub fn close(self) -> Result<(), DbErr> {
743        self.close_by_ref()
744    }
745
746    /// Explicitly close the database connection
747    pub fn close_by_ref(&self) -> Result<(), DbErr> {
748        match &self.inner {
749            #[cfg(feature = "sqlx-mysql")]
750            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.close_by_ref(),
751            #[cfg(feature = "sqlx-postgres")]
752            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.close_by_ref(),
753            #[cfg(feature = "sqlx-sqlite")]
754            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.close_by_ref(),
755            #[cfg(feature = "rusqlite")]
756            DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.close_by_ref(),
757            #[cfg(feature = "mock")]
758            DatabaseConnectionType::MockDatabaseConnection(_) => {
759                // Nothing to cleanup, we just consume the `DatabaseConnection`
760                Ok(())
761            }
762            #[cfg(feature = "proxy")]
763            DatabaseConnectionType::ProxyDatabaseConnection(_) => {
764                // Nothing to cleanup, we just consume the `DatabaseConnection`
765                Ok(())
766            }
767            DatabaseConnectionType::Disconnected => Err(conn_err("Disconnected")),
768        }
769    }
770}
771
772impl DatabaseConnection {
773    /// Get [sqlx::MySqlPool]
774    ///
775    /// # Panics
776    ///
777    /// Panics if [DbConn] is not a MySQL connection.
778    #[cfg(feature = "sqlx-mysql")]
779    pub fn get_mysql_connection_pool(&self) -> &sqlx::MySqlPool {
780        match &self.inner {
781            DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => &conn.pool,
782            _ => panic!("Not MySQL Connection"),
783        }
784    }
785
786    /// Get [sqlx::PgPool]
787    ///
788    /// # Panics
789    ///
790    /// Panics if [DbConn] is not a Postgres connection.
791    #[cfg(feature = "sqlx-postgres")]
792    pub fn get_postgres_connection_pool(&self) -> &sqlx::PgPool {
793        match &self.inner {
794            DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => &conn.pool,
795            _ => panic!("Not Postgres Connection"),
796        }
797    }
798
799    /// Get [sqlx::SqlitePool]
800    ///
801    /// # Panics
802    ///
803    /// Panics if [DbConn] is not a SQLite connection.
804    #[cfg(feature = "sqlx-sqlite")]
805    pub fn get_sqlite_connection_pool(&self) -> &sqlx::SqlitePool {
806        match &self.inner {
807            DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => &conn.pool,
808            _ => panic!("Not SQLite Connection"),
809        }
810    }
811}
812
813impl DbBackend {
814    /// Check if the URI is the same as the specified database backend.
815    /// Returns true if they match.
816    ///
817    /// # Panics
818    ///
819    /// Panics if `base_url` cannot be parsed as `Url`.
820    pub fn is_prefix_of(self, base_url: &str) -> bool {
821        let base_url_parsed = Url::parse(base_url).expect("Fail to parse database URL");
822        match self {
823            Self::Postgres => {
824                base_url_parsed.scheme() == "postgres" || base_url_parsed.scheme() == "postgresql"
825            }
826            Self::MySql => base_url_parsed.scheme() == "mysql",
827            Self::Sqlite => base_url_parsed.scheme() == "sqlite",
828        }
829    }
830
831    /// Build an SQL [Statement]
832    pub fn build<S>(&self, statement: &S) -> Statement
833    where
834        S: StatementBuilder,
835    {
836        statement.build(self)
837    }
838
839    /// Check if the database supports `RETURNING` syntax on insert and update
840    pub fn support_returning(&self) -> bool {
841        match self {
842            Self::Postgres => true,
843            Self::Sqlite if cfg!(feature = "sqlite-use-returning-for-3_35") => true,
844            Self::MySql if cfg!(feature = "mariadb-use-returning") => true,
845            _ => false,
846        }
847    }
848
849    /// A getter for database dependent boolean value
850    pub fn boolean_value(&self, boolean: bool) -> sea_query::Value {
851        match self {
852            Self::MySql | Self::Postgres | Self::Sqlite => boolean.into(),
853        }
854    }
855
856    /// Get the display string for this enum
857    pub fn as_str(&self) -> &'static str {
858        match self {
859            DatabaseBackend::MySql => "MySql",
860            DatabaseBackend::Postgres => "Postgres",
861            DatabaseBackend::Sqlite => "Sqlite",
862        }
863    }
864}
865
#[cfg(test)]
mod tests {
    use crate::DatabaseConnection;

    /// Compile-time check that [`DatabaseConnection`] is `Send`; the test body
    /// has nothing to do at runtime.
    #[cfg(not(feature = "sync"))]
    #[test]
    fn assert_database_connection_traits() {
        fn assert_send<T: Send>() {}

        assert_send::<DatabaseConnection>();
    }
}