1use crate::{
2 error::*, AccessMode, ConnectionTrait, DatabaseTransaction, ExecResult, IsolationLevel,
3 QueryResult, Statement, StatementBuilder, StreamTrait, TransactionError, TransactionTrait,
4};
5use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder};
6use std::{future::Future, pin::Pin};
7use tracing::instrument;
8use url::Url;
9
10#[cfg(feature = "sqlx-dep")]
11use sqlx::pool::PoolConnection;
12
13#[cfg(any(feature = "mock", feature = "proxy"))]
14use std::sync::Arc;
15
/// Handle to a database that dispatches every operation to whichever driver backs it.
///
/// Each driver variant is compiled in only when its Cargo feature is enabled;
/// `Disconnected` is always present. `Clone` is derived only when the `mock`
/// feature is disabled.
#[cfg_attr(not(feature = "mock"), derive(Clone))]
pub enum DatabaseConnection {
    /// sqlx-backed MySQL connection (feature `sqlx-mysql`).
    #[cfg(feature = "sqlx-mysql")]
    SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),

    /// sqlx-backed PostgreSQL connection (feature `sqlx-postgres`).
    #[cfg(feature = "sqlx-postgres")]
    SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),

    /// sqlx-backed SQLite connection (feature `sqlx-sqlite`).
    #[cfg(feature = "sqlx-sqlite")]
    SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection),

    /// Mock connection shared through an `Arc` (feature `mock`).
    #[cfg(feature = "mock")]
    MockDatabaseConnection(Arc<crate::MockDatabaseConnection>),

    /// Proxy connection shared through an `Arc` (feature `proxy`).
    #[cfg(feature = "proxy")]
    ProxyDatabaseConnection(Arc<crate::ProxyDatabaseConnection>),

    /// Not attached to any database. Query/transaction methods on this variant
    /// return a "Disconnected" connection error; `get_database_backend` panics.
    Disconnected,
}
44
/// Shorthand alias for [`DatabaseConnection`].
pub type DbConn = DatabaseConnection;
47
48impl Default for DatabaseConnection {
49 fn default() -> Self {
50 Self::Disconnected
51 }
52}
53
/// The SQL dialect a connection speaks.
///
/// Besides the comparison traits, this derives `Hash` so a backend can be
/// used as a key in hash-based collections (e.g. `HashMap<DbBackend, _>`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DatabaseBackend {
    /// MySQL.
    MySql,
    /// PostgreSQL.
    Postgres,
    /// SQLite.
    Sqlite,
}
65
/// Shorthand alias for [`DatabaseBackend`].
pub type DbBackend = DatabaseBackend;
68
/// A single driver connection held directly rather than through the pool handle.
///
/// The sqlx variants hold an sqlx `PoolConnection` for their backend; the mock
/// and proxy variants share their connection through an `Arc`.
/// NOTE(review): presumably used to pin one connection for the lifetime of a
/// transaction — confirm against `DatabaseTransaction`.
#[derive(Debug)]
pub(crate) enum InnerConnection {
    #[cfg(feature = "sqlx-mysql")]
    MySql(PoolConnection<sqlx::MySql>),
    #[cfg(feature = "sqlx-postgres")]
    Postgres(PoolConnection<sqlx::Postgres>),
    #[cfg(feature = "sqlx-sqlite")]
    Sqlite(PoolConnection<sqlx::Sqlite>),
    #[cfg(feature = "mock")]
    Mock(Arc<crate::MockDatabaseConnection>),
    #[cfg(feature = "proxy")]
    Proxy(Arc<crate::ProxyDatabaseConnection>),
}
82
83impl std::fmt::Debug for DatabaseConnection {
84 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
85 write!(
86 f,
87 "{}",
88 match self {
89 #[cfg(feature = "sqlx-mysql")]
90 Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
91 #[cfg(feature = "sqlx-postgres")]
92 Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
93 #[cfg(feature = "sqlx-sqlite")]
94 Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection",
95 #[cfg(feature = "mock")]
96 Self::MockDatabaseConnection(_) => "MockDatabaseConnection",
97 #[cfg(feature = "proxy")]
98 Self::ProxyDatabaseConnection(_) => "ProxyDatabaseConnection",
99 Self::Disconnected => "Disconnected",
100 }
101 )
102 }
103}
104
105#[async_trait::async_trait]
106impl ConnectionTrait for DatabaseConnection {
107 fn get_database_backend(&self) -> DbBackend {
108 match self {
109 #[cfg(feature = "sqlx-mysql")]
110 DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
111 #[cfg(feature = "sqlx-postgres")]
112 DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
113 #[cfg(feature = "sqlx-sqlite")]
114 DatabaseConnection::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite,
115 #[cfg(feature = "mock")]
116 DatabaseConnection::MockDatabaseConnection(conn) => conn.get_database_backend(),
117 #[cfg(feature = "proxy")]
118 DatabaseConnection::ProxyDatabaseConnection(conn) => conn.get_database_backend(),
119 DatabaseConnection::Disconnected => panic!("Disconnected"),
120 }
121 }
122
123 #[instrument(level = "trace")]
124 #[allow(unused_variables)]
125 async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
126 match self {
127 #[cfg(feature = "sqlx-mysql")]
128 DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await,
129 #[cfg(feature = "sqlx-postgres")]
130 DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await,
131 #[cfg(feature = "sqlx-sqlite")]
132 DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await,
133 #[cfg(feature = "mock")]
134 DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt),
135 #[cfg(feature = "proxy")]
136 DatabaseConnection::ProxyDatabaseConnection(conn) => conn.execute(stmt).await,
137 DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
138 }
139 }
140
141 #[instrument(level = "trace")]
142 #[allow(unused_variables)]
143 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
144 match self {
145 #[cfg(feature = "sqlx-mysql")]
146 DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute_unprepared(sql).await,
147 #[cfg(feature = "sqlx-postgres")]
148 DatabaseConnection::SqlxPostgresPoolConnection(conn) => {
149 conn.execute_unprepared(sql).await
150 }
151 #[cfg(feature = "sqlx-sqlite")]
152 DatabaseConnection::SqlxSqlitePoolConnection(conn) => {
153 conn.execute_unprepared(sql).await
154 }
155 #[cfg(feature = "mock")]
156 DatabaseConnection::MockDatabaseConnection(conn) => {
157 let db_backend = conn.get_database_backend();
158 let stmt = Statement::from_string(db_backend, sql);
159 conn.execute(stmt)
160 }
161 #[cfg(feature = "proxy")]
162 DatabaseConnection::ProxyDatabaseConnection(conn) => {
163 let db_backend = conn.get_database_backend();
164 let stmt = Statement::from_string(db_backend, sql);
165 conn.execute(stmt).await
166 }
167 DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
168 }
169 }
170
171 #[instrument(level = "trace")]
172 #[allow(unused_variables)]
173 async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
174 match self {
175 #[cfg(feature = "sqlx-mysql")]
176 DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await,
177 #[cfg(feature = "sqlx-postgres")]
178 DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await,
179 #[cfg(feature = "sqlx-sqlite")]
180 DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await,
181 #[cfg(feature = "mock")]
182 DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt),
183 #[cfg(feature = "proxy")]
184 DatabaseConnection::ProxyDatabaseConnection(conn) => conn.query_one(stmt).await,
185 DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
186 }
187 }
188
189 #[instrument(level = "trace")]
190 #[allow(unused_variables)]
191 async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
192 match self {
193 #[cfg(feature = "sqlx-mysql")]
194 DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await,
195 #[cfg(feature = "sqlx-postgres")]
196 DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await,
197 #[cfg(feature = "sqlx-sqlite")]
198 DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await,
199 #[cfg(feature = "mock")]
200 DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt),
201 #[cfg(feature = "proxy")]
202 DatabaseConnection::ProxyDatabaseConnection(conn) => conn.query_all(stmt).await,
203 DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
204 }
205 }
206
207 #[cfg(feature = "mock")]
208 fn is_mock_connection(&self) -> bool {
209 matches!(self, DatabaseConnection::MockDatabaseConnection(_))
210 }
211}
212
213#[async_trait::async_trait]
214impl StreamTrait for DatabaseConnection {
215 type Stream<'a> = crate::QueryStream;
216
217 #[instrument(level = "trace")]
218 #[allow(unused_variables)]
219 fn stream<'a>(
220 &'a self,
221 stmt: Statement,
222 ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
223 Box::pin(async move {
224 match self {
225 #[cfg(feature = "sqlx-mysql")]
226 DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await,
227 #[cfg(feature = "sqlx-postgres")]
228 DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await,
229 #[cfg(feature = "sqlx-sqlite")]
230 DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await,
231 #[cfg(feature = "mock")]
232 DatabaseConnection::MockDatabaseConnection(conn) => {
233 Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
234 }
235 #[cfg(feature = "proxy")]
236 DatabaseConnection::ProxyDatabaseConnection(conn) => {
237 Ok(crate::QueryStream::from((Arc::clone(conn), stmt, None)))
238 }
239 DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
240 }
241 })
242 }
243}
244
245#[async_trait::async_trait]
246impl TransactionTrait for DatabaseConnection {
247 #[instrument(level = "trace")]
248 async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
249 match self {
250 #[cfg(feature = "sqlx-mysql")]
251 DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin(None, None).await,
252 #[cfg(feature = "sqlx-postgres")]
253 DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin(None, None).await,
254 #[cfg(feature = "sqlx-sqlite")]
255 DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin(None, None).await,
256 #[cfg(feature = "mock")]
257 DatabaseConnection::MockDatabaseConnection(conn) => {
258 DatabaseTransaction::new_mock(Arc::clone(conn), None).await
259 }
260 #[cfg(feature = "proxy")]
261 DatabaseConnection::ProxyDatabaseConnection(conn) => {
262 DatabaseTransaction::new_proxy(conn.clone(), None).await
263 }
264 DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
265 }
266 }
267
268 #[instrument(level = "trace")]
269 async fn begin_with_config(
270 &self,
271 _isolation_level: Option<IsolationLevel>,
272 _access_mode: Option<AccessMode>,
273 ) -> Result<DatabaseTransaction, DbErr> {
274 match self {
275 #[cfg(feature = "sqlx-mysql")]
276 DatabaseConnection::SqlxMySqlPoolConnection(conn) => {
277 conn.begin(_isolation_level, _access_mode).await
278 }
279 #[cfg(feature = "sqlx-postgres")]
280 DatabaseConnection::SqlxPostgresPoolConnection(conn) => {
281 conn.begin(_isolation_level, _access_mode).await
282 }
283 #[cfg(feature = "sqlx-sqlite")]
284 DatabaseConnection::SqlxSqlitePoolConnection(conn) => {
285 conn.begin(_isolation_level, _access_mode).await
286 }
287 #[cfg(feature = "mock")]
288 DatabaseConnection::MockDatabaseConnection(conn) => {
289 DatabaseTransaction::new_mock(Arc::clone(conn), None).await
290 }
291 #[cfg(feature = "proxy")]
292 DatabaseConnection::ProxyDatabaseConnection(conn) => {
293 DatabaseTransaction::new_proxy(conn.clone(), None).await
294 }
295 DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
296 }
297 }
298
299 #[instrument(level = "trace", skip(_callback))]
302 async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
303 where
304 F: for<'c> FnOnce(
305 &'c DatabaseTransaction,
306 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
307 + Send,
308 T: Send,
309 E: std::fmt::Display + std::fmt::Debug + Send,
310 {
311 match self {
312 #[cfg(feature = "sqlx-mysql")]
313 DatabaseConnection::SqlxMySqlPoolConnection(conn) => {
314 conn.transaction(_callback, None, None).await
315 }
316 #[cfg(feature = "sqlx-postgres")]
317 DatabaseConnection::SqlxPostgresPoolConnection(conn) => {
318 conn.transaction(_callback, None, None).await
319 }
320 #[cfg(feature = "sqlx-sqlite")]
321 DatabaseConnection::SqlxSqlitePoolConnection(conn) => {
322 conn.transaction(_callback, None, None).await
323 }
324 #[cfg(feature = "mock")]
325 DatabaseConnection::MockDatabaseConnection(conn) => {
326 let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
327 .await
328 .map_err(TransactionError::Connection)?;
329 transaction.run(_callback).await
330 }
331 #[cfg(feature = "proxy")]
332 DatabaseConnection::ProxyDatabaseConnection(conn) => {
333 let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
334 .await
335 .map_err(TransactionError::Connection)?;
336 transaction.run(_callback).await
337 }
338 DatabaseConnection::Disconnected => Err(conn_err("Disconnected").into()),
339 }
340 }
341
342 #[instrument(level = "trace", skip(_callback))]
345 async fn transaction_with_config<F, T, E>(
346 &self,
347 _callback: F,
348 _isolation_level: Option<IsolationLevel>,
349 _access_mode: Option<AccessMode>,
350 ) -> Result<T, TransactionError<E>>
351 where
352 F: for<'c> FnOnce(
353 &'c DatabaseTransaction,
354 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
355 + Send,
356 T: Send,
357 E: std::fmt::Display + std::fmt::Debug + Send,
358 {
359 match self {
360 #[cfg(feature = "sqlx-mysql")]
361 DatabaseConnection::SqlxMySqlPoolConnection(conn) => {
362 conn.transaction(_callback, _isolation_level, _access_mode)
363 .await
364 }
365 #[cfg(feature = "sqlx-postgres")]
366 DatabaseConnection::SqlxPostgresPoolConnection(conn) => {
367 conn.transaction(_callback, _isolation_level, _access_mode)
368 .await
369 }
370 #[cfg(feature = "sqlx-sqlite")]
371 DatabaseConnection::SqlxSqlitePoolConnection(conn) => {
372 conn.transaction(_callback, _isolation_level, _access_mode)
373 .await
374 }
375 #[cfg(feature = "mock")]
376 DatabaseConnection::MockDatabaseConnection(conn) => {
377 let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
378 .await
379 .map_err(TransactionError::Connection)?;
380 transaction.run(_callback).await
381 }
382 #[cfg(feature = "proxy")]
383 DatabaseConnection::ProxyDatabaseConnection(conn) => {
384 let transaction = DatabaseTransaction::new_proxy(conn.clone(), None)
385 .await
386 .map_err(TransactionError::Connection)?;
387 transaction.run(_callback).await
388 }
389 DatabaseConnection::Disconnected => Err(conn_err("Disconnected").into()),
390 }
391 }
392}
393
#[cfg(feature = "mock")]
impl DatabaseConnection {
    /// Borrows the mock connection wrapped by this handle.
    ///
    /// # Panics
    ///
    /// Panics with "Not mock connection" for any non-mock variant.
    pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection {
        let Self::MockDatabaseConnection(mock_conn) = self else {
            panic!("Not mock connection");
        };
        mock_conn
    }

    /// Consumes the handle and drains the transaction log recorded by the mocker.
    ///
    /// # Panics
    ///
    /// Panics if this is not a mock connection, or with "Fail to acquire mocker"
    /// if the mocker mutex cannot be locked.
    pub fn into_transaction_log(self) -> Vec<crate::Transaction> {
        let mocker_mutex = self.as_mock_connection().get_mocker_mutex();
        let mut mocker = mocker_mutex.lock().expect("Fail to acquire mocker");
        mocker.drain_transaction_log()
    }
}
422
#[cfg(feature = "proxy")]
impl DatabaseConnection {
    /// Borrows the proxy connection wrapped by this handle.
    ///
    /// # Panics
    ///
    /// Panics with "Not proxy connection" for any non-proxy variant.
    pub fn as_proxy_connection(&self) -> &crate::ProxyDatabaseConnection {
        let Self::ProxyDatabaseConnection(proxy_conn) = self else {
            panic!("Not proxy connection");
        };
        proxy_conn
    }
}
437
438impl DatabaseConnection {
439 pub fn set_metric_callback<F>(&mut self, _callback: F)
441 where
442 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
443 {
444 match self {
445 #[cfg(feature = "sqlx-mysql")]
446 DatabaseConnection::SqlxMySqlPoolConnection(conn) => {
447 conn.set_metric_callback(_callback)
448 }
449 #[cfg(feature = "sqlx-postgres")]
450 DatabaseConnection::SqlxPostgresPoolConnection(conn) => {
451 conn.set_metric_callback(_callback)
452 }
453 #[cfg(feature = "sqlx-sqlite")]
454 DatabaseConnection::SqlxSqlitePoolConnection(conn) => {
455 conn.set_metric_callback(_callback)
456 }
457 _ => {}
458 }
459 }
460
461 pub async fn ping(&self) -> Result<(), DbErr> {
463 match self {
464 #[cfg(feature = "sqlx-mysql")]
465 DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.ping().await,
466 #[cfg(feature = "sqlx-postgres")]
467 DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.ping().await,
468 #[cfg(feature = "sqlx-sqlite")]
469 DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.ping().await,
470 #[cfg(feature = "mock")]
471 DatabaseConnection::MockDatabaseConnection(conn) => conn.ping(),
472 #[cfg(feature = "proxy")]
473 DatabaseConnection::ProxyDatabaseConnection(conn) => conn.ping().await,
474 DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
475 }
476 }
477
478 pub async fn close(self) -> Result<(), DbErr> {
481 self.close_by_ref().await
482 }
483
484 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
486 match self {
487 #[cfg(feature = "sqlx-mysql")]
488 DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.close_by_ref().await,
489 #[cfg(feature = "sqlx-postgres")]
490 DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.close_by_ref().await,
491 #[cfg(feature = "sqlx-sqlite")]
492 DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.close_by_ref().await,
493 #[cfg(feature = "mock")]
494 DatabaseConnection::MockDatabaseConnection(_) => {
495 Ok(())
497 }
498 #[cfg(feature = "proxy")]
499 DatabaseConnection::ProxyDatabaseConnection(_) => {
500 Ok(())
502 }
503 DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
504 }
505 }
506}
507
508impl DatabaseConnection {
509 #[cfg(feature = "sqlx-mysql")]
515 pub fn get_mysql_connection_pool(&self) -> &sqlx::MySqlPool {
516 match self {
517 DatabaseConnection::SqlxMySqlPoolConnection(conn) => &conn.pool,
518 _ => panic!("Not MySQL Connection"),
519 }
520 }
521
522 #[cfg(feature = "sqlx-postgres")]
528 pub fn get_postgres_connection_pool(&self) -> &sqlx::PgPool {
529 match self {
530 DatabaseConnection::SqlxPostgresPoolConnection(conn) => &conn.pool,
531 _ => panic!("Not Postgres Connection"),
532 }
533 }
534
535 #[cfg(feature = "sqlx-sqlite")]
541 pub fn get_sqlite_connection_pool(&self) -> &sqlx::SqlitePool {
542 match self {
543 DatabaseConnection::SqlxSqlitePoolConnection(conn) => &conn.pool,
544 _ => panic!("Not SQLite Connection"),
545 }
546 }
547}
548
549impl DbBackend {
550 pub fn is_prefix_of(self, base_url: &str) -> bool {
557 let base_url_parsed = Url::parse(base_url).expect("Fail to parse database URL");
558 match self {
559 Self::Postgres => {
560 base_url_parsed.scheme() == "postgres" || base_url_parsed.scheme() == "postgresql"
561 }
562 Self::MySql => base_url_parsed.scheme() == "mysql",
563 Self::Sqlite => base_url_parsed.scheme() == "sqlite",
564 }
565 }
566
567 pub fn build<S>(&self, statement: &S) -> Statement
569 where
570 S: StatementBuilder,
571 {
572 statement.build(self)
573 }
574
575 pub fn get_query_builder(&self) -> Box<dyn QueryBuilder> {
577 match self {
578 Self::MySql => Box::new(MysqlQueryBuilder),
579 Self::Postgres => Box::new(PostgresQueryBuilder),
580 Self::Sqlite => Box::new(SqliteQueryBuilder),
581 }
582 }
583
584 pub fn support_returning(&self) -> bool {
586 match self {
587 Self::Postgres => true,
588 Self::Sqlite if cfg!(feature = "sqlite-use-returning-for-3_35") => true,
589 _ => false,
590 }
591 }
592
593 pub fn boolean_value(&self, boolean: bool) -> sea_query::Value {
595 match self {
596 Self::MySql | Self::Postgres | Self::Sqlite => boolean.into(),
597 }
598 }
599}
600
#[cfg(test)]
mod tests {
    use crate::DatabaseConnection;

    /// Compile-time guarantee that a connection handle can be shared across threads.
    #[test]
    fn assert_database_connection_traits() {
        fn requires_send_and_sync<T: Send + Sync>() {}

        requires_send_and_sync::<DatabaseConnection>();
    }
}