1use std::{sync::Arc, time::Duration};
2
3use futures_util::future::BoxFuture;
4#[cfg(feature = "sqlx-mysql")]
5use sqlx::mysql::MySqlConnectOptions;
6#[cfg(feature = "sqlx-postgres")]
7use sqlx::postgres::PgConnectOptions;
8#[cfg(feature = "sqlx-sqlite")]
9use sqlx::sqlite::SqliteConnectOptions;
10
11mod connection;
12mod db_connection;
13mod executor;
14#[cfg(feature = "mock")]
15#[cfg_attr(docsrs, doc(cfg(feature = "mock")))]
16mod mock;
17#[cfg(feature = "proxy")]
18#[cfg_attr(docsrs, doc(cfg(feature = "proxy")))]
19mod proxy;
20#[cfg(feature = "rbac")]
21mod restricted_connection;
22#[cfg(all(feature = "schema-sync", feature = "rusqlite"))]
23mod sea_schema_rusqlite;
24#[cfg(all(feature = "schema-sync", feature = "sqlx-dep"))]
25mod sea_schema_shim;
26mod statement;
27mod stream;
28mod tracing_spans;
29mod transaction;
30
31pub use connection::*;
32pub use db_connection::*;
33pub use executor::*;
34#[cfg(feature = "mock")]
35#[cfg_attr(docsrs, doc(cfg(feature = "mock")))]
36pub use mock::*;
37#[cfg(feature = "proxy")]
38#[cfg_attr(docsrs, doc(cfg(feature = "proxy")))]
39pub use proxy::*;
40#[cfg(feature = "rbac")]
41pub use restricted_connection::*;
42pub use statement::*;
43use std::borrow::Cow;
44pub use stream::*;
45use tracing::instrument;
46pub use transaction::*;
47
48use crate::error::*;
49
/// Entry point for opening database connections; see [`Database::connect`] (and
/// `Database::connect_proxy` when the `proxy` feature is enabled).
///
/// This is a field-less unit struct: it carries no state and only namespaces the
/// associated connection functions.
#[derive(Debug, Default)]
pub struct Database;
53
/// Under the `sync` feature, callback "futures" (e.g. the value returned by an
/// `after_connect` hook) are just their output value: `BoxFuture<'a, T>` collapses to `T`.
/// The lifetime parameter is unused; it only keeps the alias signature-compatible with the
/// async `BoxFuture<'a, T>` used in non-`sync` builds.
///
/// NOTE(review): `futures_util::future::BoxFuture` is imported at the top of this module
/// without a `cfg` gate. If both names are in scope under `sync`, they collide. Confirm the
/// import is excluded for `sync` builds.
#[cfg(feature = "sync")]
type BoxFuture<'a, T> = T;
56
57type AfterConnectCallback = Option<
58 Arc<
59 dyn Fn(DatabaseConnection) -> BoxFuture<'static, Result<(), DbErr>> + Send + Sync + 'static,
60 >,
61>;
62
63#[derive(derive_more::Debug, Clone)]
65pub struct ConnectOptions {
66 pub(crate) url: String,
68 pub(crate) max_connections: Option<u32>,
70 pub(crate) min_connections: Option<u32>,
72 pub(crate) connect_timeout: Option<Duration>,
74 pub(crate) idle_timeout: Option<Option<Duration>>,
77 pub(crate) acquire_timeout: Option<Duration>,
79 pub(crate) max_lifetime: Option<Option<Duration>>,
81 pub(crate) sqlx_logging: bool,
83 pub(crate) sqlx_logging_level: log::LevelFilter,
85 pub(crate) sqlx_slow_statements_logging_level: log::LevelFilter,
87 pub(crate) sqlx_slow_statements_logging_threshold: Duration,
89 pub(crate) sqlcipher_key: Option<Cow<'static, str>>,
91 pub(crate) schema_search_path: Option<String>,
93 pub(crate) test_before_acquire: bool,
94 pub(crate) connect_lazy: bool,
98
99 #[debug(skip)]
100 pub(crate) after_connect: AfterConnectCallback,
101
102 #[cfg(feature = "sqlx-mysql")]
103 #[debug(skip)]
104 pub(crate) mysql_opts_fn:
105 Option<Arc<dyn Fn(MySqlConnectOptions) -> MySqlConnectOptions + Send + Sync>>,
106 #[cfg(feature = "sqlx-postgres")]
107 #[debug(skip)]
108 pub(crate) pg_opts_fn: Option<Arc<dyn Fn(PgConnectOptions) -> PgConnectOptions + Send + Sync>>,
109 #[cfg(feature = "sqlx-sqlite")]
110 #[debug(skip)]
111 pub(crate) sqlite_opts_fn:
112 Option<Arc<dyn Fn(SqliteConnectOptions) -> SqliteConnectOptions + Send + Sync>>,
113}
114
115impl Database {
116 #[instrument(level = "trace", skip(opt))]
119 pub async fn connect<C>(opt: C) -> Result<DatabaseConnection, DbErr>
120 where
121 C: Into<ConnectOptions>,
122 {
123 let opt: ConnectOptions = opt.into();
124
125 if url::Url::parse(&opt.url).is_err() {
126 return Err(conn_err(format!(
127 "The connection string '{}' cannot be parsed.",
128 opt.url
129 )));
130 }
131
132 #[cfg(feature = "sqlx-mysql")]
133 if DbBackend::MySql.is_prefix_of(&opt.url) {
134 return crate::SqlxMySqlConnector::connect(opt).await;
135 }
136 #[cfg(feature = "sqlx-postgres")]
137 if DbBackend::Postgres.is_prefix_of(&opt.url) {
138 return crate::SqlxPostgresConnector::connect(opt).await;
139 }
140 #[cfg(feature = "sqlx-sqlite")]
141 if DbBackend::Sqlite.is_prefix_of(&opt.url) {
142 return crate::SqlxSqliteConnector::connect(opt).await;
143 }
144 #[cfg(feature = "rusqlite")]
145 if DbBackend::Sqlite.is_prefix_of(&opt.url) {
146 return crate::driver::rusqlite::RusqliteConnector::connect(opt);
147 }
148 #[cfg(feature = "mock")]
149 if crate::MockDatabaseConnector::accepts(&opt.url) {
150 return crate::MockDatabaseConnector::connect(&opt.url).await;
151 }
152
153 Err(conn_err(format!(
154 "The connection string '{}' has no supporting driver.",
155 opt.url
156 )))
157 }
158
159 #[cfg(feature = "proxy")]
161 #[instrument(level = "trace", skip(proxy_func_arc))]
162 pub async fn connect_proxy(
163 db_type: DbBackend,
164 proxy_func_arc: std::sync::Arc<Box<dyn ProxyDatabaseTrait>>,
165 ) -> Result<DatabaseConnection, DbErr> {
166 match db_type {
167 DbBackend::MySql => {
168 return crate::ProxyDatabaseConnector::connect(
169 DbBackend::MySql,
170 proxy_func_arc.to_owned(),
171 );
172 }
173 DbBackend::Postgres => {
174 return crate::ProxyDatabaseConnector::connect(
175 DbBackend::Postgres,
176 proxy_func_arc.to_owned(),
177 );
178 }
179 DbBackend::Sqlite => {
180 return crate::ProxyDatabaseConnector::connect(
181 DbBackend::Sqlite,
182 proxy_func_arc.to_owned(),
183 );
184 }
185 }
186 }
187}
188
189impl<T> From<T> for ConnectOptions
190where
191 T: Into<String>,
192{
193 fn from(s: T) -> ConnectOptions {
194 ConnectOptions::new(s.into())
195 }
196}
197
198impl ConnectOptions {
199 pub fn new<T>(url: T) -> Self
201 where
202 T: Into<String>,
203 {
204 Self {
205 url: url.into(),
206 max_connections: None,
207 min_connections: None,
208 connect_timeout: None,
209 idle_timeout: None,
210 acquire_timeout: None,
211 max_lifetime: None,
212 sqlx_logging: true,
213 sqlx_logging_level: log::LevelFilter::Info,
214 sqlx_slow_statements_logging_level: log::LevelFilter::Off,
215 sqlx_slow_statements_logging_threshold: Duration::from_secs(1),
216 sqlcipher_key: None,
217 schema_search_path: None,
218 test_before_acquire: true,
219 connect_lazy: false,
220 after_connect: None,
221 #[cfg(feature = "sqlx-mysql")]
222 mysql_opts_fn: None,
223 #[cfg(feature = "sqlx-postgres")]
224 pg_opts_fn: None,
225 #[cfg(feature = "sqlx-sqlite")]
226 sqlite_opts_fn: None,
227 }
228 }
229
230 pub fn get_url(&self) -> &str {
232 &self.url
233 }
234
235 pub fn max_connections(&mut self, value: u32) -> &mut Self {
237 self.max_connections = Some(value);
238 self
239 }
240
241 pub fn get_max_connections(&self) -> Option<u32> {
243 self.max_connections
244 }
245
246 pub fn min_connections(&mut self, value: u32) -> &mut Self {
248 self.min_connections = Some(value);
249 self
250 }
251
252 pub fn get_min_connections(&self) -> Option<u32> {
254 self.min_connections
255 }
256
257 pub fn connect_timeout(&mut self, value: Duration) -> &mut Self {
259 self.connect_timeout = Some(value);
260 self
261 }
262
263 pub fn get_connect_timeout(&self) -> Option<Duration> {
265 self.connect_timeout
266 }
267
268 pub fn idle_timeout<T>(&mut self, value: T) -> &mut Self
270 where
271 T: Into<Option<Duration>>,
272 {
273 self.idle_timeout = Some(value.into());
274 self
275 }
276
277 pub fn get_idle_timeout(&self) -> Option<Option<Duration>> {
279 self.idle_timeout
280 }
281
282 pub fn acquire_timeout(&mut self, value: Duration) -> &mut Self {
284 self.acquire_timeout = Some(value);
285 self
286 }
287
288 pub fn get_acquire_timeout(&self) -> Option<Duration> {
290 self.acquire_timeout
291 }
292
293 pub fn max_lifetime<T>(&mut self, lifetime: T) -> &mut Self
295 where
296 T: Into<Option<Duration>>,
297 {
298 self.max_lifetime = Some(lifetime.into());
299 self
300 }
301
302 pub fn get_max_lifetime(&self) -> Option<Option<Duration>> {
304 self.max_lifetime
305 }
306
307 pub fn sqlx_logging(&mut self, value: bool) -> &mut Self {
309 self.sqlx_logging = value;
310 self
311 }
312
313 pub fn get_sqlx_logging(&self) -> bool {
315 self.sqlx_logging
316 }
317
318 pub fn sqlx_logging_level(&mut self, level: log::LevelFilter) -> &mut Self {
321 self.sqlx_logging_level = level;
322 self
323 }
324
325 pub fn sqlx_slow_statements_logging_settings(
328 &mut self,
329 level: log::LevelFilter,
330 duration: Duration,
331 ) -> &mut Self {
332 self.sqlx_slow_statements_logging_level = level;
333 self.sqlx_slow_statements_logging_threshold = duration;
334 self
335 }
336
337 pub fn get_sqlx_logging_level(&self) -> log::LevelFilter {
339 self.sqlx_logging_level
340 }
341
342 pub fn get_sqlx_slow_statements_logging_settings(&self) -> (log::LevelFilter, Duration) {
344 (
345 self.sqlx_slow_statements_logging_level,
346 self.sqlx_slow_statements_logging_threshold,
347 )
348 }
349
350 pub fn sqlcipher_key<T>(&mut self, value: T) -> &mut Self
352 where
353 T: Into<Cow<'static, str>>,
354 {
355 self.sqlcipher_key = Some(value.into());
356 self
357 }
358
359 pub fn set_schema_search_path<T>(&mut self, schema_search_path: T) -> &mut Self
361 where
362 T: Into<String>,
363 {
364 self.schema_search_path = Some(schema_search_path.into());
365 self
366 }
367
368 pub fn test_before_acquire(&mut self, value: bool) -> &mut Self {
370 self.test_before_acquire = value;
371 self
372 }
373
374 pub fn connect_lazy(&mut self, value: bool) -> &mut Self {
377 self.connect_lazy = value;
378 self
379 }
380
381 pub fn get_connect_lazy(&self) -> bool {
383 self.connect_lazy
384 }
385
386 pub fn after_connect<F>(&mut self, f: F) -> &mut Self
388 where
389 F: Fn(DatabaseConnection) -> BoxFuture<'static, Result<(), DbErr>> + Send + Sync + 'static,
390 {
391 self.after_connect = Some(Arc::new(f));
392
393 self
394 }
395
396 #[cfg(feature = "sqlx-mysql")]
397 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-mysql")))]
398 pub fn map_sqlx_mysql_opts<F>(&mut self, f: F) -> &mut Self
401 where
402 F: Fn(MySqlConnectOptions) -> MySqlConnectOptions + Send + Sync + 'static,
403 {
404 self.mysql_opts_fn = Some(Arc::new(f));
405 self
406 }
407
408 #[cfg(feature = "sqlx-postgres")]
409 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-postgres")))]
410 pub fn map_sqlx_postgres_opts<F>(&mut self, f: F) -> &mut Self
413 where
414 F: Fn(PgConnectOptions) -> PgConnectOptions + Send + Sync + 'static,
415 {
416 self.pg_opts_fn = Some(Arc::new(f));
417 self
418 }
419
420 #[cfg(feature = "sqlx-sqlite")]
421 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-sqlite")))]
422 pub fn map_sqlx_sqlite_opts<F>(&mut self, f: F) -> &mut Self
425 where
426 F: Fn(SqliteConnectOptions) -> SqliteConnectOptions + Send + Sync + 'static,
427 {
428 self.sqlite_opts_fn = Some(Arc::new(f));
429 self
430 }
431}