1use std::{sync::Arc, time::Duration};
2
3#[cfg(feature = "sqlx-mysql")]
4use sqlx::mysql::MySqlConnectOptions;
5#[cfg(feature = "sqlx-postgres")]
6use sqlx::postgres::PgConnectOptions;
7#[cfg(feature = "sqlx-sqlite")]
8use sqlx::sqlite::SqliteConnectOptions;
9
10mod connection;
11mod db_connection;
12mod executor;
13#[cfg(feature = "mock")]
14#[cfg_attr(docsrs, doc(cfg(feature = "mock")))]
15mod mock;
16#[cfg(feature = "proxy")]
17#[cfg_attr(docsrs, doc(cfg(feature = "proxy")))]
18mod proxy;
19#[cfg(feature = "rbac")]
20mod restricted_connection;
21#[cfg(all(feature = "schema-sync", feature = "rusqlite"))]
22mod sea_schema_rusqlite;
23#[cfg(all(feature = "schema-sync", feature = "sqlx-dep"))]
24mod sea_schema_shim;
25mod statement;
26mod stream;
27mod tracing_spans;
28mod transaction;
29
30pub use connection::*;
31pub use db_connection::*;
32pub use executor::*;
33#[cfg(feature = "mock")]
34#[cfg_attr(docsrs, doc(cfg(feature = "mock")))]
35pub use mock::*;
36#[cfg(feature = "proxy")]
37#[cfg_attr(docsrs, doc(cfg(feature = "proxy")))]
38pub use proxy::*;
39#[cfg(feature = "rbac")]
40pub use restricted_connection::*;
41pub use statement::*;
42use std::borrow::Cow;
43pub use stream::*;
44use tracing::instrument;
45pub use transaction::*;
46
47use crate::error::*;
48
/// Zero-sized entry point for opening database connections.
///
/// It carries no state; use the associated functions such as `Database::connect`.
#[derive(Default, Debug)]
pub struct Database;
52
/// Synchronous stand-in for a boxed future: the "future" is simply its output
/// value. This is presumably so that signatures written against an async
/// `BoxFuture<'_, T>` compile unchanged in the sync build. The lifetime
/// parameter is unused and only keeps the alias's arity.
///
/// NOTE(review): this file has no definition for builds without `sync`, yet
/// `AfterConnectCallback` and `ConnectOptions::after_connect` use the alias
/// unconditionally. Confirm that this module is only ever compiled with `sync`.
#[cfg(feature = "sync")]
type BoxFuture<'fut, Output> = Output;
55
56type AfterConnectCallback =
57 Option<Arc<dyn Fn(DatabaseConnection) -> BoxFuture<'static, Result<(), DbErr>> + 'static>>;
58
59#[derive(derive_more::Debug, Clone)]
61pub struct ConnectOptions {
62 pub(crate) url: String,
64 pub(crate) max_connections: Option<u32>,
66 pub(crate) min_connections: Option<u32>,
68 pub(crate) connect_timeout: Option<Duration>,
70 pub(crate) idle_timeout: Option<Option<Duration>>,
73 pub(crate) acquire_timeout: Option<Duration>,
75 pub(crate) max_lifetime: Option<Option<Duration>>,
77 pub(crate) sqlx_logging: bool,
79 pub(crate) sqlx_logging_level: log::LevelFilter,
81 pub(crate) sqlx_slow_statements_logging_level: log::LevelFilter,
83 pub(crate) sqlx_slow_statements_logging_threshold: Duration,
85 pub(crate) sqlcipher_key: Option<Cow<'static, str>>,
87 pub(crate) schema_search_path: Option<String>,
89 pub(crate) application_name: Option<String>,
91 pub(crate) test_before_acquire: bool,
92 pub(crate) connect_lazy: bool,
96
97 #[debug(skip)]
98 pub(crate) after_connect: AfterConnectCallback,
99
100 #[cfg(feature = "sqlx-mysql")]
101 #[debug(skip)]
102 pub(crate) mysql_opts_fn: Option<Arc<dyn Fn(MySqlConnectOptions) -> MySqlConnectOptions>>,
103 #[cfg(feature = "sqlx-postgres")]
104 #[debug(skip)]
105 pub(crate) pg_opts_fn: Option<Arc<dyn Fn(PgConnectOptions) -> PgConnectOptions>>,
106 #[cfg(feature = "sqlx-sqlite")]
107 #[debug(skip)]
108 pub(crate) sqlite_opts_fn: Option<Arc<dyn Fn(SqliteConnectOptions) -> SqliteConnectOptions>>,
109}
110
111impl Database {
112 #[instrument(level = "trace", skip(opt))]
115 pub fn connect<C>(opt: C) -> Result<DatabaseConnection, DbErr>
116 where
117 C: Into<ConnectOptions>,
118 {
119 let opt: ConnectOptions = opt.into();
120
121 if url::Url::parse(&opt.url).is_err() {
122 return Err(conn_err(format!(
123 "The connection string '{}' cannot be parsed.",
124 opt.url
125 )));
126 }
127
128 #[cfg(feature = "sqlx-mysql")]
129 if DbBackend::MySql.is_prefix_of(&opt.url) {
130 return crate::SqlxMySqlConnector::connect(opt);
131 }
132 #[cfg(feature = "sqlx-postgres")]
133 if DbBackend::Postgres.is_prefix_of(&opt.url) {
134 return crate::SqlxPostgresConnector::connect(opt);
135 }
136 #[cfg(feature = "sqlx-sqlite")]
137 if DbBackend::Sqlite.is_prefix_of(&opt.url) {
138 return crate::SqlxSqliteConnector::connect(opt);
139 }
140 #[cfg(feature = "rusqlite")]
141 if DbBackend::Sqlite.is_prefix_of(&opt.url) {
142 return crate::driver::rusqlite::RusqliteConnector::connect(opt);
143 }
144 #[cfg(feature = "mock")]
145 if crate::MockDatabaseConnector::accepts(&opt.url) {
146 return crate::MockDatabaseConnector::connect(&opt.url);
147 }
148
149 Err(conn_err(format!(
150 "The connection string '{}' has no supporting driver.",
151 opt.url
152 )))
153 }
154
155 #[cfg(feature = "proxy")]
157 #[instrument(level = "trace", skip(proxy_func_arc))]
158 pub fn connect_proxy(
159 db_type: DbBackend,
160 proxy_func_arc: std::sync::Arc<Box<dyn ProxyDatabaseTrait>>,
161 ) -> Result<DatabaseConnection, DbErr> {
162 match db_type {
163 DbBackend::MySql => {
164 return crate::ProxyDatabaseConnector::connect(
165 DbBackend::MySql,
166 proxy_func_arc.to_owned(),
167 );
168 }
169 DbBackend::Postgres => {
170 return crate::ProxyDatabaseConnector::connect(
171 DbBackend::Postgres,
172 proxy_func_arc.to_owned(),
173 );
174 }
175 DbBackend::Sqlite => {
176 return crate::ProxyDatabaseConnector::connect(
177 DbBackend::Sqlite,
178 proxy_func_arc.to_owned(),
179 );
180 }
181 }
182 }
183}
184
185impl<T> From<T> for ConnectOptions
186where
187 T: Into<String>,
188{
189 fn from(s: T) -> ConnectOptions {
190 ConnectOptions::new(s.into())
191 }
192}
193
194impl ConnectOptions {
195 pub fn new<T>(url: T) -> Self
197 where
198 T: Into<String>,
199 {
200 Self {
201 url: url.into(),
202 max_connections: None,
203 min_connections: None,
204 connect_timeout: None,
205 idle_timeout: None,
206 acquire_timeout: None,
207 max_lifetime: None,
208 sqlx_logging: true,
209 sqlx_logging_level: log::LevelFilter::Info,
210 sqlx_slow_statements_logging_level: log::LevelFilter::Off,
211 sqlx_slow_statements_logging_threshold: Duration::from_secs(1),
212 sqlcipher_key: None,
213 schema_search_path: None,
214 application_name: None,
215 test_before_acquire: true,
216 connect_lazy: false,
217 after_connect: None,
218 #[cfg(feature = "sqlx-mysql")]
219 mysql_opts_fn: None,
220 #[cfg(feature = "sqlx-postgres")]
221 pg_opts_fn: None,
222 #[cfg(feature = "sqlx-sqlite")]
223 sqlite_opts_fn: None,
224 }
225 }
226
227 pub fn get_url(&self) -> &str {
229 &self.url
230 }
231
232 pub fn max_connections(&mut self, value: u32) -> &mut Self {
234 self.max_connections = Some(value);
235 self
236 }
237
238 pub fn get_max_connections(&self) -> Option<u32> {
240 self.max_connections
241 }
242
243 pub fn min_connections(&mut self, value: u32) -> &mut Self {
245 self.min_connections = Some(value);
246 self
247 }
248
249 pub fn get_min_connections(&self) -> Option<u32> {
251 self.min_connections
252 }
253
254 pub fn connect_timeout(&mut self, value: Duration) -> &mut Self {
256 self.connect_timeout = Some(value);
257 self
258 }
259
260 pub fn get_connect_timeout(&self) -> Option<Duration> {
262 self.connect_timeout
263 }
264
265 pub fn idle_timeout<T>(&mut self, value: T) -> &mut Self
267 where
268 T: Into<Option<Duration>>,
269 {
270 self.idle_timeout = Some(value.into());
271 self
272 }
273
274 pub fn get_idle_timeout(&self) -> Option<Option<Duration>> {
276 self.idle_timeout
277 }
278
279 pub fn acquire_timeout(&mut self, value: Duration) -> &mut Self {
281 self.acquire_timeout = Some(value);
282 self
283 }
284
285 pub fn get_acquire_timeout(&self) -> Option<Duration> {
287 self.acquire_timeout
288 }
289
290 pub fn max_lifetime<T>(&mut self, lifetime: T) -> &mut Self
292 where
293 T: Into<Option<Duration>>,
294 {
295 self.max_lifetime = Some(lifetime.into());
296 self
297 }
298
299 pub fn get_max_lifetime(&self) -> Option<Option<Duration>> {
301 self.max_lifetime
302 }
303
304 pub fn sqlx_logging(&mut self, value: bool) -> &mut Self {
306 self.sqlx_logging = value;
307 self
308 }
309
310 pub fn get_sqlx_logging(&self) -> bool {
312 self.sqlx_logging
313 }
314
315 pub fn sqlx_logging_level(&mut self, level: log::LevelFilter) -> &mut Self {
318 self.sqlx_logging_level = level;
319 self
320 }
321
322 pub fn sqlx_slow_statements_logging_settings(
325 &mut self,
326 level: log::LevelFilter,
327 duration: Duration,
328 ) -> &mut Self {
329 self.sqlx_slow_statements_logging_level = level;
330 self.sqlx_slow_statements_logging_threshold = duration;
331 self
332 }
333
334 pub fn get_sqlx_logging_level(&self) -> log::LevelFilter {
336 self.sqlx_logging_level
337 }
338
339 pub fn get_sqlx_slow_statements_logging_settings(&self) -> (log::LevelFilter, Duration) {
341 (
342 self.sqlx_slow_statements_logging_level,
343 self.sqlx_slow_statements_logging_threshold,
344 )
345 }
346
347 pub fn sqlcipher_key<T>(&mut self, value: T) -> &mut Self
349 where
350 T: Into<Cow<'static, str>>,
351 {
352 self.sqlcipher_key = Some(value.into());
353 self
354 }
355
356 pub fn set_schema_search_path<T>(&mut self, schema_search_path: T) -> &mut Self
358 where
359 T: Into<String>,
360 {
361 self.schema_search_path = Some(schema_search_path.into());
362 self
363 }
364
365 pub fn set_application_name<T>(&mut self, application_name: T) -> &mut Self
367 where
368 T: Into<String>,
369 {
370 self.application_name = Some(application_name.into());
371 self
372 }
373
374 pub fn test_before_acquire(&mut self, value: bool) -> &mut Self {
376 self.test_before_acquire = value;
377 self
378 }
379
380 pub fn connect_lazy(&mut self, value: bool) -> &mut Self {
383 self.connect_lazy = value;
384 self
385 }
386
387 pub fn get_connect_lazy(&self) -> bool {
389 self.connect_lazy
390 }
391
392 pub fn after_connect<F>(&mut self, f: F) -> &mut Self
394 where
395 F: Fn(DatabaseConnection) -> BoxFuture<'static, Result<(), DbErr>> + 'static,
396 {
397 self.after_connect = Some(Arc::new(f));
398
399 self
400 }
401
402 #[cfg(feature = "sqlx-mysql")]
403 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-mysql")))]
404 pub fn map_sqlx_mysql_opts<F>(&mut self, f: F) -> &mut Self
407 where
408 F: Fn(MySqlConnectOptions) -> MySqlConnectOptions + 'static,
409 {
410 self.mysql_opts_fn = Some(Arc::new(f));
411 self
412 }
413
414 #[cfg(feature = "sqlx-postgres")]
415 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-postgres")))]
416 pub fn map_sqlx_postgres_opts<F>(&mut self, f: F) -> &mut Self
419 where
420 F: Fn(PgConnectOptions) -> PgConnectOptions + 'static,
421 {
422 self.pg_opts_fn = Some(Arc::new(f));
423 self
424 }
425
426 #[cfg(feature = "sqlx-sqlite")]
427 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-sqlite")))]
428 pub fn map_sqlx_sqlite_opts<F>(&mut self, f: F) -> &mut Self
431 where
432 F: Fn(SqliteConnectOptions) -> SqliteConnectOptions + 'static,
433 {
434 self.sqlite_opts_fn = Some(Arc::new(f));
435 self
436 }
437}