1use std::{sync::Arc, time::Duration};
2
3#[cfg(feature = "sqlx-mysql")]
4use sqlx::mysql::MySqlConnectOptions;
5#[cfg(feature = "sqlx-postgres")]
6use sqlx::postgres::PgConnectOptions;
7#[cfg(feature = "sqlx-sqlite")]
8use sqlx::sqlite::SqliteConnectOptions;
9
10mod connection;
11mod db_connection;
12mod executor;
13#[cfg(feature = "mock")]
14#[cfg_attr(docsrs, doc(cfg(feature = "mock")))]
15mod mock;
16#[cfg(feature = "proxy")]
17#[cfg_attr(docsrs, doc(cfg(feature = "proxy")))]
18mod proxy;
19#[cfg(feature = "rbac")]
20mod restricted_connection;
21#[cfg(all(feature = "schema-sync", feature = "rusqlite"))]
22mod sea_schema_rusqlite;
23#[cfg(all(feature = "schema-sync", feature = "sqlx-dep"))]
24mod sea_schema_shim;
25mod statement;
26mod stream;
27mod tracing_spans;
28mod transaction;
29
30pub use connection::*;
31pub use db_connection::*;
32pub use executor::*;
33#[cfg(feature = "mock")]
34#[cfg_attr(docsrs, doc(cfg(feature = "mock")))]
35pub use mock::*;
36#[cfg(feature = "proxy")]
37#[cfg_attr(docsrs, doc(cfg(feature = "proxy")))]
38pub use proxy::*;
39#[cfg(feature = "rbac")]
40pub use restricted_connection::*;
41pub use statement::*;
42use std::borrow::Cow;
43pub use stream::*;
44use tracing::instrument;
45pub use transaction::*;
46
47use crate::error::*;
48
/// Entry point for opening database connections.
///
/// This is a stateless unit type; use [`Database::connect`] to obtain a
/// [`DatabaseConnection`].
#[derive(Default, Debug)]
pub struct Database;
52
/// Synchronous stand-in for a boxed future: under the `sync` feature the
/// "future" is just the output value itself.
///
/// The lifetime parameter is unused; presumably it exists so signatures shared
/// with the async build (`BoxFuture<'a, T>`) compile unchanged.
#[cfg(feature = "sync")]
type BoxFuture<'fut, Output> = Output;
55
56type AfterConnectCallback =
57 Option<Arc<dyn Fn(DatabaseConnection) -> BoxFuture<'static, Result<(), DbErr>> + 'static>>;
58
59#[derive(derive_more::Debug, Clone)]
61pub struct ConnectOptions {
62 pub(crate) url: String,
64 pub(crate) max_connections: Option<u32>,
66 pub(crate) min_connections: Option<u32>,
68 pub(crate) connect_timeout: Option<Duration>,
70 pub(crate) idle_timeout: Option<Option<Duration>>,
73 pub(crate) acquire_timeout: Option<Duration>,
75 pub(crate) max_lifetime: Option<Option<Duration>>,
77 pub(crate) sqlx_logging: bool,
79 pub(crate) sqlx_logging_level: log::LevelFilter,
81 pub(crate) sqlx_slow_statements_logging_level: log::LevelFilter,
83 pub(crate) sqlx_slow_statements_logging_threshold: Duration,
85 pub(crate) sqlcipher_key: Option<Cow<'static, str>>,
87 pub(crate) schema_search_path: Option<String>,
89 pub(crate) application_name: Option<String>,
91 pub(crate) statement_timeout: Option<Duration>,
93 pub(crate) test_before_acquire: bool,
94 pub(crate) connect_lazy: bool,
98
99 #[debug(skip)]
100 pub(crate) after_connect: AfterConnectCallback,
101
102 #[cfg(feature = "sqlx-mysql")]
103 #[debug(skip)]
104 pub(crate) mysql_opts_fn: Option<Arc<dyn Fn(MySqlConnectOptions) -> MySqlConnectOptions>>,
105 #[cfg(feature = "sqlx-postgres")]
106 #[debug(skip)]
107 pub(crate) pg_opts_fn: Option<Arc<dyn Fn(PgConnectOptions) -> PgConnectOptions>>,
108 #[cfg(feature = "sqlx-sqlite")]
109 #[debug(skip)]
110 pub(crate) sqlite_opts_fn: Option<Arc<dyn Fn(SqliteConnectOptions) -> SqliteConnectOptions>>,
111}
112
113impl Database {
114 #[instrument(level = "trace", skip(opt))]
117 pub fn connect<C>(opt: C) -> Result<DatabaseConnection, DbErr>
118 where
119 C: Into<ConnectOptions>,
120 {
121 let opt: ConnectOptions = opt.into();
122
123 if url::Url::parse(&opt.url).is_err() {
124 return Err(conn_err(format!(
125 "The connection string '{}' cannot be parsed.",
126 opt.url
127 )));
128 }
129
130 #[cfg(feature = "sqlx-mysql")]
131 if DbBackend::MySql.is_prefix_of(&opt.url) {
132 return crate::SqlxMySqlConnector::connect(opt);
133 }
134 #[cfg(feature = "sqlx-postgres")]
135 if DbBackend::Postgres.is_prefix_of(&opt.url) {
136 return crate::SqlxPostgresConnector::connect(opt);
137 }
138 #[cfg(feature = "sqlx-sqlite")]
139 if DbBackend::Sqlite.is_prefix_of(&opt.url) {
140 return crate::SqlxSqliteConnector::connect(opt);
141 }
142 #[cfg(feature = "rusqlite")]
143 if DbBackend::Sqlite.is_prefix_of(&opt.url) {
144 return crate::driver::rusqlite::RusqliteConnector::connect(opt);
145 }
146 #[cfg(feature = "mock")]
147 if crate::MockDatabaseConnector::accepts(&opt.url) {
148 return crate::MockDatabaseConnector::connect(&opt.url);
149 }
150
151 Err(conn_err(format!(
152 "The connection string '{}' has no supporting driver.",
153 opt.url
154 )))
155 }
156
157 #[cfg(feature = "proxy")]
159 #[instrument(level = "trace", skip(proxy_func_arc))]
160 pub fn connect_proxy(
161 db_type: DbBackend,
162 proxy_func_arc: std::sync::Arc<Box<dyn ProxyDatabaseTrait>>,
163 ) -> Result<DatabaseConnection, DbErr> {
164 match db_type {
165 DbBackend::MySql => {
166 return crate::ProxyDatabaseConnector::connect(
167 DbBackend::MySql,
168 proxy_func_arc.to_owned(),
169 );
170 }
171 DbBackend::Postgres => {
172 return crate::ProxyDatabaseConnector::connect(
173 DbBackend::Postgres,
174 proxy_func_arc.to_owned(),
175 );
176 }
177 DbBackend::Sqlite => {
178 return crate::ProxyDatabaseConnector::connect(
179 DbBackend::Sqlite,
180 proxy_func_arc.to_owned(),
181 );
182 }
183 }
184 }
185}
186
187impl<T> From<T> for ConnectOptions
188where
189 T: Into<String>,
190{
191 fn from(s: T) -> ConnectOptions {
192 ConnectOptions::new(s.into())
193 }
194}
195
196impl ConnectOptions {
197 pub fn new<T>(url: T) -> Self
199 where
200 T: Into<String>,
201 {
202 Self {
203 url: url.into(),
204 max_connections: None,
205 min_connections: None,
206 connect_timeout: None,
207 idle_timeout: None,
208 acquire_timeout: None,
209 max_lifetime: None,
210 sqlx_logging: true,
211 sqlx_logging_level: log::LevelFilter::Info,
212 sqlx_slow_statements_logging_level: log::LevelFilter::Off,
213 sqlx_slow_statements_logging_threshold: Duration::from_secs(1),
214 sqlcipher_key: None,
215 schema_search_path: None,
216 application_name: None,
217 statement_timeout: None,
218 test_before_acquire: true,
219 connect_lazy: false,
220 after_connect: None,
221 #[cfg(feature = "sqlx-mysql")]
222 mysql_opts_fn: None,
223 #[cfg(feature = "sqlx-postgres")]
224 pg_opts_fn: None,
225 #[cfg(feature = "sqlx-sqlite")]
226 sqlite_opts_fn: None,
227 }
228 }
229
230 pub fn get_url(&self) -> &str {
232 &self.url
233 }
234
235 pub fn max_connections(&mut self, value: u32) -> &mut Self {
237 self.max_connections = Some(value);
238 self
239 }
240
241 pub fn get_max_connections(&self) -> Option<u32> {
243 self.max_connections
244 }
245
246 pub fn min_connections(&mut self, value: u32) -> &mut Self {
248 self.min_connections = Some(value);
249 self
250 }
251
252 pub fn get_min_connections(&self) -> Option<u32> {
254 self.min_connections
255 }
256
257 pub fn connect_timeout(&mut self, value: Duration) -> &mut Self {
259 self.connect_timeout = Some(value);
260 self
261 }
262
263 pub fn get_connect_timeout(&self) -> Option<Duration> {
265 self.connect_timeout
266 }
267
268 pub fn idle_timeout<T>(&mut self, value: T) -> &mut Self
270 where
271 T: Into<Option<Duration>>,
272 {
273 self.idle_timeout = Some(value.into());
274 self
275 }
276
277 pub fn get_idle_timeout(&self) -> Option<Option<Duration>> {
279 self.idle_timeout
280 }
281
282 pub fn acquire_timeout(&mut self, value: Duration) -> &mut Self {
284 self.acquire_timeout = Some(value);
285 self
286 }
287
288 pub fn get_acquire_timeout(&self) -> Option<Duration> {
290 self.acquire_timeout
291 }
292
293 pub fn max_lifetime<T>(&mut self, lifetime: T) -> &mut Self
295 where
296 T: Into<Option<Duration>>,
297 {
298 self.max_lifetime = Some(lifetime.into());
299 self
300 }
301
302 pub fn get_max_lifetime(&self) -> Option<Option<Duration>> {
304 self.max_lifetime
305 }
306
307 pub fn sqlx_logging(&mut self, value: bool) -> &mut Self {
309 self.sqlx_logging = value;
310 self
311 }
312
313 pub fn get_sqlx_logging(&self) -> bool {
315 self.sqlx_logging
316 }
317
318 pub fn sqlx_logging_level(&mut self, level: log::LevelFilter) -> &mut Self {
321 self.sqlx_logging_level = level;
322 self
323 }
324
325 pub fn sqlx_slow_statements_logging_settings(
328 &mut self,
329 level: log::LevelFilter,
330 duration: Duration,
331 ) -> &mut Self {
332 self.sqlx_slow_statements_logging_level = level;
333 self.sqlx_slow_statements_logging_threshold = duration;
334 self
335 }
336
337 pub fn get_sqlx_logging_level(&self) -> log::LevelFilter {
339 self.sqlx_logging_level
340 }
341
342 pub fn get_sqlx_slow_statements_logging_settings(&self) -> (log::LevelFilter, Duration) {
344 (
345 self.sqlx_slow_statements_logging_level,
346 self.sqlx_slow_statements_logging_threshold,
347 )
348 }
349
350 pub fn sqlcipher_key<T>(&mut self, value: T) -> &mut Self
352 where
353 T: Into<Cow<'static, str>>,
354 {
355 self.sqlcipher_key = Some(value.into());
356 self
357 }
358
359 pub fn set_schema_search_path<T>(&mut self, schema_search_path: T) -> &mut Self
361 where
362 T: Into<String>,
363 {
364 self.schema_search_path = Some(schema_search_path.into());
365 self
366 }
367
368 pub fn set_application_name<T>(&mut self, application_name: T) -> &mut Self
370 where
371 T: Into<String>,
372 {
373 self.application_name = Some(application_name.into());
374 self
375 }
376
377 pub fn statement_timeout(&mut self, value: Duration) -> &mut Self {
385 self.statement_timeout = Some(value);
386 self
387 }
388
389 pub fn get_statement_timeout(&self) -> Option<Duration> {
391 self.statement_timeout
392 }
393
394 pub fn test_before_acquire(&mut self, value: bool) -> &mut Self {
396 self.test_before_acquire = value;
397 self
398 }
399
400 pub fn connect_lazy(&mut self, value: bool) -> &mut Self {
403 self.connect_lazy = value;
404 self
405 }
406
407 pub fn get_connect_lazy(&self) -> bool {
409 self.connect_lazy
410 }
411
412 pub fn after_connect<F>(&mut self, f: F) -> &mut Self
414 where
415 F: Fn(DatabaseConnection) -> BoxFuture<'static, Result<(), DbErr>> + 'static,
416 {
417 self.after_connect = Some(Arc::new(f));
418
419 self
420 }
421
422 #[cfg(feature = "sqlx-mysql")]
423 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-mysql")))]
424 pub fn map_sqlx_mysql_opts<F>(&mut self, f: F) -> &mut Self
427 where
428 F: Fn(MySqlConnectOptions) -> MySqlConnectOptions + 'static,
429 {
430 self.mysql_opts_fn = Some(Arc::new(f));
431 self
432 }
433
434 #[cfg(feature = "sqlx-postgres")]
435 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-postgres")))]
436 pub fn map_sqlx_postgres_opts<F>(&mut self, f: F) -> &mut Self
439 where
440 F: Fn(PgConnectOptions) -> PgConnectOptions + 'static,
441 {
442 self.pg_opts_fn = Some(Arc::new(f));
443 self
444 }
445
446 #[cfg(feature = "sqlx-sqlite")]
447 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-sqlite")))]
448 pub fn map_sqlx_sqlite_opts<F>(&mut self, f: F) -> &mut Self
451 where
452 F: Fn(SqliteConnectOptions) -> SqliteConnectOptions + 'static,
453 {
454 self.sqlite_opts_fn = Some(Arc::new(f));
455 self
456 }
457}