1use std::{sync::Arc, time::Duration};
2
3use futures_util::future::BoxFuture;
4#[cfg(feature = "sqlx-mysql")]
5use sqlx::mysql::MySqlConnectOptions;
6#[cfg(feature = "sqlx-postgres")]
7use sqlx::postgres::PgConnectOptions;
8#[cfg(feature = "sqlx-sqlite")]
9use sqlx::sqlite::SqliteConnectOptions;
10
11mod connection;
12mod db_connection;
13mod executor;
14#[cfg(feature = "mock")]
15#[cfg_attr(docsrs, doc(cfg(feature = "mock")))]
16mod mock;
17#[cfg(feature = "proxy")]
18#[cfg_attr(docsrs, doc(cfg(feature = "proxy")))]
19mod proxy;
20#[cfg(feature = "rbac")]
21mod restricted_connection;
22#[cfg(all(feature = "schema-sync", feature = "rusqlite"))]
23mod sea_schema_rusqlite;
24#[cfg(all(feature = "schema-sync", feature = "sqlx-dep"))]
25mod sea_schema_shim;
26mod statement;
27mod stream;
28mod tracing_spans;
29mod transaction;
30
31pub use connection::*;
32pub use db_connection::*;
33pub use executor::*;
34#[cfg(feature = "mock")]
35#[cfg_attr(docsrs, doc(cfg(feature = "mock")))]
36pub use mock::*;
37#[cfg(feature = "proxy")]
38#[cfg_attr(docsrs, doc(cfg(feature = "proxy")))]
39pub use proxy::*;
40#[cfg(feature = "rbac")]
41pub use restricted_connection::*;
42pub use statement::*;
43use std::borrow::Cow;
44pub use stream::*;
45use tracing::instrument;
46pub use transaction::*;
47
48use crate::error::*;
49
/// Entry point for opening database connections.
///
/// This is a field-less marker type. Connections are created through its
/// associated functions, such as `Database::connect`.
#[derive(Default, Debug)]
pub struct Database;
53
/// Synchronous stand-in for the async `BoxFuture` when the `sync` feature is on.
///
/// `BoxFuture<'a, T>` is simply `T`, so signatures written against
/// `BoxFuture` (for example the `AfterConnectCallback` return type) resolve to
/// plain return values. The lifetime `'a` is unused and exists only so the
/// alias has the same shape as the async `BoxFuture<'a, T>`.
// NOTE(review): `futures_util::future::BoxFuture` is also imported at the top
// of this file without a `cfg` gate; with `sync` enabled both names would be in
// scope (E0255). Confirm that import is gated with `#[cfg(not(feature = "sync"))]`.
#[cfg(feature = "sync")]
type BoxFuture<'a, T> = T;
56
57type AfterConnectCallback = Option<
58 Arc<
59 dyn Fn(DatabaseConnection) -> BoxFuture<'static, Result<(), DbErr>> + Send + Sync + 'static,
60 >,
61>;
62
/// Settings used by `Database::connect` to open a [`DatabaseConnection`].
///
/// Create one with `ConnectOptions::new` (or via `From` any `Into<String>`
/// URL), then adjust it with the chainable `&mut self` setters. Callback
/// fields are excluded from the derived `Debug` output via `#[debug(skip)]`.
#[derive(derive_more::Debug, Clone)]
pub struct ConnectOptions {
    /// Connection URL; `Database::connect` matches its prefix to pick a driver.
    pub(crate) url: String,
    /// Maximum connection count; `None` means not set on these options.
    pub(crate) max_connections: Option<u32>,
    /// Minimum connection count; `None` means not set on these options.
    pub(crate) min_connections: Option<u32>,
    /// Timeout for establishing a connection; `None` means not set.
    pub(crate) connect_timeout: Option<Duration>,
    /// Idle timeout. Outer `None`: never set. `Some(None)`: the caller
    /// explicitly passed `None` (presumably "no timeout"; confirm in the driver).
    pub(crate) idle_timeout: Option<Option<Duration>>,
    /// Timeout for acquiring a connection; `None` means not set.
    pub(crate) acquire_timeout: Option<Duration>,
    /// Maximum connection lifetime, with the same outer/inner `None`
    /// distinction as `idle_timeout`.
    pub(crate) max_lifetime: Option<Option<Duration>>,
    /// Whether SQLx statement logging is enabled.
    pub(crate) sqlx_logging: bool,
    /// Log level used for SQLx statement logging.
    pub(crate) sqlx_logging_level: log::LevelFilter,
    /// Log level used for statements that exceed the slow-statement threshold.
    pub(crate) sqlx_slow_statements_logging_level: log::LevelFilter,
    /// Duration above which a statement counts as slow for logging purposes.
    pub(crate) sqlx_slow_statements_logging_threshold: Duration,
    /// Encryption key, presumably for SQLCipher-enabled SQLite (confirm).
    pub(crate) sqlcipher_key: Option<Cow<'static, str>>,
    /// Schema search path, presumably Postgres `search_path` (confirm).
    pub(crate) schema_search_path: Option<String>,
    /// Application name reported to the server, where the backend supports it.
    pub(crate) application_name: Option<String>,
    /// Whether a connection is tested before being handed out.
    pub(crate) test_before_acquire: bool,
    /// Whether connecting is deferred until first use (interpretation is up to
    /// the driver; confirm).
    pub(crate) connect_lazy: bool,

    /// User hook registered through `ConnectOptions::after_connect`.
    #[debug(skip)]
    pub(crate) after_connect: AfterConnectCallback,

    /// Optional transform applied to the SQLx MySQL connect options.
    #[cfg(feature = "sqlx-mysql")]
    #[debug(skip)]
    pub(crate) mysql_opts_fn:
        Option<Arc<dyn Fn(MySqlConnectOptions) -> MySqlConnectOptions + Send + Sync>>,
    /// Optional transform applied to the SQLx Postgres connect options.
    #[cfg(feature = "sqlx-postgres")]
    #[debug(skip)]
    pub(crate) pg_opts_fn: Option<Arc<dyn Fn(PgConnectOptions) -> PgConnectOptions + Send + Sync>>,
    /// Optional transform applied to the SQLx SQLite connect options.
    #[cfg(feature = "sqlx-sqlite")]
    #[debug(skip)]
    pub(crate) sqlite_opts_fn:
        Option<Arc<dyn Fn(SqliteConnectOptions) -> SqliteConnectOptions + Send + Sync>>,
}
116
117impl Database {
118 #[instrument(level = "trace", skip(opt))]
121 pub async fn connect<C>(opt: C) -> Result<DatabaseConnection, DbErr>
122 where
123 C: Into<ConnectOptions>,
124 {
125 let opt: ConnectOptions = opt.into();
126
127 if url::Url::parse(&opt.url).is_err() {
128 return Err(conn_err(format!(
129 "The connection string '{}' cannot be parsed.",
130 opt.url
131 )));
132 }
133
134 #[cfg(feature = "sqlx-mysql")]
135 if DbBackend::MySql.is_prefix_of(&opt.url) {
136 return crate::SqlxMySqlConnector::connect(opt).await;
137 }
138 #[cfg(feature = "sqlx-postgres")]
139 if DbBackend::Postgres.is_prefix_of(&opt.url) {
140 return crate::SqlxPostgresConnector::connect(opt).await;
141 }
142 #[cfg(feature = "sqlx-sqlite")]
143 if DbBackend::Sqlite.is_prefix_of(&opt.url) {
144 return crate::SqlxSqliteConnector::connect(opt).await;
145 }
146 #[cfg(feature = "rusqlite")]
147 if DbBackend::Sqlite.is_prefix_of(&opt.url) {
148 return crate::driver::rusqlite::RusqliteConnector::connect(opt);
149 }
150 #[cfg(feature = "mock")]
151 if crate::MockDatabaseConnector::accepts(&opt.url) {
152 return crate::MockDatabaseConnector::connect(&opt.url).await;
153 }
154
155 Err(conn_err(format!(
156 "The connection string '{}' has no supporting driver.",
157 opt.url
158 )))
159 }
160
161 #[cfg(feature = "proxy")]
163 #[instrument(level = "trace", skip(proxy_func_arc))]
164 pub async fn connect_proxy(
165 db_type: DbBackend,
166 proxy_func_arc: std::sync::Arc<Box<dyn ProxyDatabaseTrait>>,
167 ) -> Result<DatabaseConnection, DbErr> {
168 match db_type {
169 DbBackend::MySql => {
170 return crate::ProxyDatabaseConnector::connect(
171 DbBackend::MySql,
172 proxy_func_arc.to_owned(),
173 );
174 }
175 DbBackend::Postgres => {
176 return crate::ProxyDatabaseConnector::connect(
177 DbBackend::Postgres,
178 proxy_func_arc.to_owned(),
179 );
180 }
181 DbBackend::Sqlite => {
182 return crate::ProxyDatabaseConnector::connect(
183 DbBackend::Sqlite,
184 proxy_func_arc.to_owned(),
185 );
186 }
187 }
188 }
189}
190
191impl<T> From<T> for ConnectOptions
192where
193 T: Into<String>,
194{
195 fn from(s: T) -> ConnectOptions {
196 ConnectOptions::new(s.into())
197 }
198}
199
200impl ConnectOptions {
201 pub fn new<T>(url: T) -> Self
203 where
204 T: Into<String>,
205 {
206 Self {
207 url: url.into(),
208 max_connections: None,
209 min_connections: None,
210 connect_timeout: None,
211 idle_timeout: None,
212 acquire_timeout: None,
213 max_lifetime: None,
214 sqlx_logging: true,
215 sqlx_logging_level: log::LevelFilter::Info,
216 sqlx_slow_statements_logging_level: log::LevelFilter::Off,
217 sqlx_slow_statements_logging_threshold: Duration::from_secs(1),
218 sqlcipher_key: None,
219 schema_search_path: None,
220 application_name: None,
221 test_before_acquire: true,
222 connect_lazy: false,
223 after_connect: None,
224 #[cfg(feature = "sqlx-mysql")]
225 mysql_opts_fn: None,
226 #[cfg(feature = "sqlx-postgres")]
227 pg_opts_fn: None,
228 #[cfg(feature = "sqlx-sqlite")]
229 sqlite_opts_fn: None,
230 }
231 }
232
233 pub fn get_url(&self) -> &str {
235 &self.url
236 }
237
238 pub fn max_connections(&mut self, value: u32) -> &mut Self {
240 self.max_connections = Some(value);
241 self
242 }
243
244 pub fn get_max_connections(&self) -> Option<u32> {
246 self.max_connections
247 }
248
249 pub fn min_connections(&mut self, value: u32) -> &mut Self {
251 self.min_connections = Some(value);
252 self
253 }
254
255 pub fn get_min_connections(&self) -> Option<u32> {
257 self.min_connections
258 }
259
260 pub fn connect_timeout(&mut self, value: Duration) -> &mut Self {
262 self.connect_timeout = Some(value);
263 self
264 }
265
266 pub fn get_connect_timeout(&self) -> Option<Duration> {
268 self.connect_timeout
269 }
270
271 pub fn idle_timeout<T>(&mut self, value: T) -> &mut Self
273 where
274 T: Into<Option<Duration>>,
275 {
276 self.idle_timeout = Some(value.into());
277 self
278 }
279
280 pub fn get_idle_timeout(&self) -> Option<Option<Duration>> {
282 self.idle_timeout
283 }
284
285 pub fn acquire_timeout(&mut self, value: Duration) -> &mut Self {
287 self.acquire_timeout = Some(value);
288 self
289 }
290
291 pub fn get_acquire_timeout(&self) -> Option<Duration> {
293 self.acquire_timeout
294 }
295
296 pub fn max_lifetime<T>(&mut self, lifetime: T) -> &mut Self
298 where
299 T: Into<Option<Duration>>,
300 {
301 self.max_lifetime = Some(lifetime.into());
302 self
303 }
304
305 pub fn get_max_lifetime(&self) -> Option<Option<Duration>> {
307 self.max_lifetime
308 }
309
310 pub fn sqlx_logging(&mut self, value: bool) -> &mut Self {
312 self.sqlx_logging = value;
313 self
314 }
315
316 pub fn get_sqlx_logging(&self) -> bool {
318 self.sqlx_logging
319 }
320
321 pub fn sqlx_logging_level(&mut self, level: log::LevelFilter) -> &mut Self {
324 self.sqlx_logging_level = level;
325 self
326 }
327
328 pub fn sqlx_slow_statements_logging_settings(
331 &mut self,
332 level: log::LevelFilter,
333 duration: Duration,
334 ) -> &mut Self {
335 self.sqlx_slow_statements_logging_level = level;
336 self.sqlx_slow_statements_logging_threshold = duration;
337 self
338 }
339
340 pub fn get_sqlx_logging_level(&self) -> log::LevelFilter {
342 self.sqlx_logging_level
343 }
344
345 pub fn get_sqlx_slow_statements_logging_settings(&self) -> (log::LevelFilter, Duration) {
347 (
348 self.sqlx_slow_statements_logging_level,
349 self.sqlx_slow_statements_logging_threshold,
350 )
351 }
352
353 pub fn sqlcipher_key<T>(&mut self, value: T) -> &mut Self
355 where
356 T: Into<Cow<'static, str>>,
357 {
358 self.sqlcipher_key = Some(value.into());
359 self
360 }
361
362 pub fn set_schema_search_path<T>(&mut self, schema_search_path: T) -> &mut Self
364 where
365 T: Into<String>,
366 {
367 self.schema_search_path = Some(schema_search_path.into());
368 self
369 }
370
371 pub fn set_application_name<T>(&mut self, application_name: T) -> &mut Self
373 where
374 T: Into<String>,
375 {
376 self.application_name = Some(application_name.into());
377 self
378 }
379
380 pub fn test_before_acquire(&mut self, value: bool) -> &mut Self {
382 self.test_before_acquire = value;
383 self
384 }
385
386 pub fn connect_lazy(&mut self, value: bool) -> &mut Self {
389 self.connect_lazy = value;
390 self
391 }
392
393 pub fn get_connect_lazy(&self) -> bool {
395 self.connect_lazy
396 }
397
398 pub fn after_connect<F>(&mut self, f: F) -> &mut Self
400 where
401 F: Fn(DatabaseConnection) -> BoxFuture<'static, Result<(), DbErr>> + Send + Sync + 'static,
402 {
403 self.after_connect = Some(Arc::new(f));
404
405 self
406 }
407
408 #[cfg(feature = "sqlx-mysql")]
409 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-mysql")))]
410 pub fn map_sqlx_mysql_opts<F>(&mut self, f: F) -> &mut Self
413 where
414 F: Fn(MySqlConnectOptions) -> MySqlConnectOptions + Send + Sync + 'static,
415 {
416 self.mysql_opts_fn = Some(Arc::new(f));
417 self
418 }
419
420 #[cfg(feature = "sqlx-postgres")]
421 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-postgres")))]
422 pub fn map_sqlx_postgres_opts<F>(&mut self, f: F) -> &mut Self
425 where
426 F: Fn(PgConnectOptions) -> PgConnectOptions + Send + Sync + 'static,
427 {
428 self.pg_opts_fn = Some(Arc::new(f));
429 self
430 }
431
432 #[cfg(feature = "sqlx-sqlite")]
433 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-sqlite")))]
434 pub fn map_sqlx_sqlite_opts<F>(&mut self, f: F) -> &mut Self
437 where
438 F: Fn(SqliteConnectOptions) -> SqliteConnectOptions + Send + Sync + 'static,
439 {
440 self.sqlite_opts_fn = Some(Arc::new(f));
441 self
442 }
443}