1use std::{sync::Arc, time::Duration};
2
3use futures_util::future::BoxFuture;
4#[cfg(feature = "sqlx-mysql")]
5use sqlx::mysql::MySqlConnectOptions;
6#[cfg(feature = "sqlx-postgres")]
7use sqlx::postgres::PgConnectOptions;
8#[cfg(feature = "sqlx-sqlite")]
9use sqlx::sqlite::SqliteConnectOptions;
10
11mod connection;
12mod db_connection;
13mod executor;
14#[cfg(feature = "mock")]
15#[cfg_attr(docsrs, doc(cfg(feature = "mock")))]
16mod mock;
17#[cfg(feature = "proxy")]
18#[cfg_attr(docsrs, doc(cfg(feature = "proxy")))]
19mod proxy;
20#[cfg(feature = "rbac")]
21mod restricted_connection;
22#[cfg(all(feature = "schema-sync", feature = "rusqlite"))]
23mod sea_schema_rusqlite;
24#[cfg(all(feature = "schema-sync", feature = "sqlx-dep"))]
25mod sea_schema_shim;
26mod statement;
27mod stream;
28mod transaction;
29
30pub use connection::*;
31pub use db_connection::*;
32pub use executor::*;
33#[cfg(feature = "mock")]
34#[cfg_attr(docsrs, doc(cfg(feature = "mock")))]
35pub use mock::*;
36#[cfg(feature = "proxy")]
37#[cfg_attr(docsrs, doc(cfg(feature = "proxy")))]
38pub use proxy::*;
39#[cfg(feature = "rbac")]
40pub use restricted_connection::*;
41pub use statement::*;
42use std::borrow::Cow;
43pub use stream::*;
44use tracing::instrument;
45pub use transaction::*;
46
47use crate::error::*;
48
/// Entry point for opening database connections.
///
/// `Database` carries no state of its own; all of the work happens in its
/// associated functions (`Database::connect`, and `Database::connect_proxy`
/// when the `proxy` feature is enabled).
#[derive(Default, Debug)]
pub struct Database;

/// With the `sync` feature, a "future" is simply the value itself: anything
/// typed as returning `BoxFuture<'_, T>` returns a plain `T`.
///
/// NOTE(review): this alias has the same name as the unconditional
/// `use futures_util::future::BoxFuture;` at the top of this file. A `use` and
/// an item with the same name in one module conflict, so confirm that import is
/// excluded whenever `sync` is enabled.
#[cfg(feature = "sync")]
type BoxFuture<'fut, T> = T;
55
56type AfterConnectCallback = Option<
57 Arc<
58 dyn Fn(DatabaseConnection) -> BoxFuture<'static, Result<(), DbErr>> + Send + Sync + 'static,
59 >,
60>;
61
/// Settings consumed by `Database::connect` to open a connection.
///
/// Create one with `ConnectOptions::new` (or via `From` any `Into<String>` URL)
/// and adjust it with the chainable `&mut self` setters. Fields are `pub(crate)`;
/// callers outside the crate go through the setters and the `get_*` accessors.
/// Cloning shares the callback closures, which are held in `Arc`.
///
/// NOTE(review): the derived `Debug` output includes `url` and `sqlcipher_key`
/// verbatim. If either can carry credentials, consider redacting them before the
/// struct is logged.
#[derive(derive_more::Debug, Clone)]
pub struct ConnectOptions {
    /// Connection URL. `Database::connect` rejects it if `url::Url::parse`
    /// fails and otherwise uses it to pick a driver.
    pub(crate) url: String,
    /// Maximum number of connections; `None` until `max_connections` is called.
    pub(crate) max_connections: Option<u32>,
    /// Minimum number of connections; `None` until `min_connections` is called.
    pub(crate) min_connections: Option<u32>,
    /// Connect timeout; `None` until `connect_timeout` is called.
    pub(crate) connect_timeout: Option<Duration>,
    /// Idle timeout. Outer `None`: never configured. `Some(inner)`: set through
    /// `idle_timeout`, where `inner` is whatever `Option<Duration>` the caller
    /// passed (possibly `None`).
    pub(crate) idle_timeout: Option<Option<Duration>>,
    /// Acquire timeout; `None` until `acquire_timeout` is called.
    pub(crate) acquire_timeout: Option<Duration>,
    /// Maximum connection lifetime, with the same two-level `Option` encoding
    /// as `idle_timeout`.
    pub(crate) max_lifetime: Option<Option<Duration>>,
    /// Whether sqlx statement logging is enabled (default `true`).
    pub(crate) sqlx_logging: bool,
    /// Log level for sqlx statement logging (default `Info`).
    pub(crate) sqlx_logging_level: log::LevelFilter,
    /// Log level for slow statements (default `Off`).
    pub(crate) sqlx_slow_statements_logging_level: log::LevelFilter,
    /// Slow-statement threshold (default 1 second). Presumably statements
    /// exceeding it are logged at `sqlx_slow_statements_logging_level`; the
    /// consumer is not visible here.
    pub(crate) sqlx_slow_statements_logging_threshold: Duration,
    /// SQLCipher key; `None` unless set through `sqlcipher_key`.
    pub(crate) sqlcipher_key: Option<Cow<'static, str>>,
    /// Schema search path; `None` unless set through `set_schema_search_path`.
    pub(crate) schema_search_path: Option<String>,
    /// Test-before-acquire flag (default `true`).
    pub(crate) test_before_acquire: bool,
    /// Lazy-connect flag (default `false`).
    pub(crate) connect_lazy: bool,

    // `dyn Fn` closures have no `Debug` impl, so they are skipped in the output.
    /// Hook registered through `after_connect`; `None` when unset.
    #[debug(skip)]
    pub(crate) after_connect: AfterConnectCallback,

    /// Optional mapper over sqlx `MySqlConnectOptions`, registered through
    /// `map_sqlx_mysql_opts`. Presumably applied by the MySQL connector; confirm.
    #[cfg(feature = "sqlx-mysql")]
    #[debug(skip)]
    pub(crate) mysql_opts_fn:
        Option<Arc<dyn Fn(MySqlConnectOptions) -> MySqlConnectOptions + Send + Sync>>,
    /// Optional mapper over sqlx `PgConnectOptions`, registered through
    /// `map_sqlx_postgres_opts`. Presumably applied by the Postgres connector; confirm.
    #[cfg(feature = "sqlx-postgres")]
    #[debug(skip)]
    pub(crate) pg_opts_fn: Option<Arc<dyn Fn(PgConnectOptions) -> PgConnectOptions + Send + Sync>>,
    /// Optional mapper over sqlx `SqliteConnectOptions`, registered through
    /// `map_sqlx_sqlite_opts`. Presumably applied by the SQLite connector; confirm.
    #[cfg(feature = "sqlx-sqlite")]
    #[debug(skip)]
    pub(crate) sqlite_opts_fn:
        Option<Arc<dyn Fn(SqliteConnectOptions) -> SqliteConnectOptions + Send + Sync>>,
}
113
114impl Database {
115 #[instrument(level = "trace", skip(opt))]
118 pub async fn connect<C>(opt: C) -> Result<DatabaseConnection, DbErr>
119 where
120 C: Into<ConnectOptions>,
121 {
122 let opt: ConnectOptions = opt.into();
123
124 if url::Url::parse(&opt.url).is_err() {
125 return Err(conn_err(format!(
126 "The connection string '{}' cannot be parsed.",
127 opt.url
128 )));
129 }
130
131 #[cfg(feature = "sqlx-mysql")]
132 if DbBackend::MySql.is_prefix_of(&opt.url) {
133 return crate::SqlxMySqlConnector::connect(opt).await;
134 }
135 #[cfg(feature = "sqlx-postgres")]
136 if DbBackend::Postgres.is_prefix_of(&opt.url) {
137 return crate::SqlxPostgresConnector::connect(opt).await;
138 }
139 #[cfg(feature = "sqlx-sqlite")]
140 if DbBackend::Sqlite.is_prefix_of(&opt.url) {
141 return crate::SqlxSqliteConnector::connect(opt).await;
142 }
143 #[cfg(feature = "rusqlite")]
144 if DbBackend::Sqlite.is_prefix_of(&opt.url) {
145 return crate::driver::rusqlite::RusqliteConnector::connect(opt);
146 }
147 #[cfg(feature = "mock")]
148 if crate::MockDatabaseConnector::accepts(&opt.url) {
149 return crate::MockDatabaseConnector::connect(&opt.url).await;
150 }
151
152 Err(conn_err(format!(
153 "The connection string '{}' has no supporting driver.",
154 opt.url
155 )))
156 }
157
158 #[cfg(feature = "proxy")]
160 #[instrument(level = "trace", skip(proxy_func_arc))]
161 pub async fn connect_proxy(
162 db_type: DbBackend,
163 proxy_func_arc: std::sync::Arc<Box<dyn ProxyDatabaseTrait>>,
164 ) -> Result<DatabaseConnection, DbErr> {
165 match db_type {
166 DbBackend::MySql => {
167 return crate::ProxyDatabaseConnector::connect(
168 DbBackend::MySql,
169 proxy_func_arc.to_owned(),
170 );
171 }
172 DbBackend::Postgres => {
173 return crate::ProxyDatabaseConnector::connect(
174 DbBackend::Postgres,
175 proxy_func_arc.to_owned(),
176 );
177 }
178 DbBackend::Sqlite => {
179 return crate::ProxyDatabaseConnector::connect(
180 DbBackend::Sqlite,
181 proxy_func_arc.to_owned(),
182 );
183 }
184 }
185 }
186}
187
188impl<T> From<T> for ConnectOptions
189where
190 T: Into<String>,
191{
192 fn from(s: T) -> ConnectOptions {
193 ConnectOptions::new(s.into())
194 }
195}
196
197impl ConnectOptions {
198 pub fn new<T>(url: T) -> Self
200 where
201 T: Into<String>,
202 {
203 Self {
204 url: url.into(),
205 max_connections: None,
206 min_connections: None,
207 connect_timeout: None,
208 idle_timeout: None,
209 acquire_timeout: None,
210 max_lifetime: None,
211 sqlx_logging: true,
212 sqlx_logging_level: log::LevelFilter::Info,
213 sqlx_slow_statements_logging_level: log::LevelFilter::Off,
214 sqlx_slow_statements_logging_threshold: Duration::from_secs(1),
215 sqlcipher_key: None,
216 schema_search_path: None,
217 test_before_acquire: true,
218 connect_lazy: false,
219 after_connect: None,
220 #[cfg(feature = "sqlx-mysql")]
221 mysql_opts_fn: None,
222 #[cfg(feature = "sqlx-postgres")]
223 pg_opts_fn: None,
224 #[cfg(feature = "sqlx-sqlite")]
225 sqlite_opts_fn: None,
226 }
227 }
228
229 pub fn get_url(&self) -> &str {
231 &self.url
232 }
233
234 pub fn max_connections(&mut self, value: u32) -> &mut Self {
236 self.max_connections = Some(value);
237 self
238 }
239
240 pub fn get_max_connections(&self) -> Option<u32> {
242 self.max_connections
243 }
244
245 pub fn min_connections(&mut self, value: u32) -> &mut Self {
247 self.min_connections = Some(value);
248 self
249 }
250
251 pub fn get_min_connections(&self) -> Option<u32> {
253 self.min_connections
254 }
255
256 pub fn connect_timeout(&mut self, value: Duration) -> &mut Self {
258 self.connect_timeout = Some(value);
259 self
260 }
261
262 pub fn get_connect_timeout(&self) -> Option<Duration> {
264 self.connect_timeout
265 }
266
267 pub fn idle_timeout<T>(&mut self, value: T) -> &mut Self
269 where
270 T: Into<Option<Duration>>,
271 {
272 self.idle_timeout = Some(value.into());
273 self
274 }
275
276 pub fn get_idle_timeout(&self) -> Option<Option<Duration>> {
278 self.idle_timeout
279 }
280
281 pub fn acquire_timeout(&mut self, value: Duration) -> &mut Self {
283 self.acquire_timeout = Some(value);
284 self
285 }
286
287 pub fn get_acquire_timeout(&self) -> Option<Duration> {
289 self.acquire_timeout
290 }
291
292 pub fn max_lifetime<T>(&mut self, lifetime: T) -> &mut Self
294 where
295 T: Into<Option<Duration>>,
296 {
297 self.max_lifetime = Some(lifetime.into());
298 self
299 }
300
301 pub fn get_max_lifetime(&self) -> Option<Option<Duration>> {
303 self.max_lifetime
304 }
305
306 pub fn sqlx_logging(&mut self, value: bool) -> &mut Self {
308 self.sqlx_logging = value;
309 self
310 }
311
312 pub fn get_sqlx_logging(&self) -> bool {
314 self.sqlx_logging
315 }
316
317 pub fn sqlx_logging_level(&mut self, level: log::LevelFilter) -> &mut Self {
320 self.sqlx_logging_level = level;
321 self
322 }
323
324 pub fn sqlx_slow_statements_logging_settings(
327 &mut self,
328 level: log::LevelFilter,
329 duration: Duration,
330 ) -> &mut Self {
331 self.sqlx_slow_statements_logging_level = level;
332 self.sqlx_slow_statements_logging_threshold = duration;
333 self
334 }
335
336 pub fn get_sqlx_logging_level(&self) -> log::LevelFilter {
338 self.sqlx_logging_level
339 }
340
341 pub fn get_sqlx_slow_statements_logging_settings(&self) -> (log::LevelFilter, Duration) {
343 (
344 self.sqlx_slow_statements_logging_level,
345 self.sqlx_slow_statements_logging_threshold,
346 )
347 }
348
349 pub fn sqlcipher_key<T>(&mut self, value: T) -> &mut Self
351 where
352 T: Into<Cow<'static, str>>,
353 {
354 self.sqlcipher_key = Some(value.into());
355 self
356 }
357
358 pub fn set_schema_search_path<T>(&mut self, schema_search_path: T) -> &mut Self
360 where
361 T: Into<String>,
362 {
363 self.schema_search_path = Some(schema_search_path.into());
364 self
365 }
366
367 pub fn test_before_acquire(&mut self, value: bool) -> &mut Self {
369 self.test_before_acquire = value;
370 self
371 }
372
373 pub fn connect_lazy(&mut self, value: bool) -> &mut Self {
376 self.connect_lazy = value;
377 self
378 }
379
380 pub fn get_connect_lazy(&self) -> bool {
382 self.connect_lazy
383 }
384
385 pub fn after_connect<F>(&mut self, f: F) -> &mut Self
387 where
388 F: Fn(DatabaseConnection) -> BoxFuture<'static, Result<(), DbErr>> + Send + Sync + 'static,
389 {
390 self.after_connect = Some(Arc::new(f));
391
392 self
393 }
394
395 #[cfg(feature = "sqlx-mysql")]
396 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-mysql")))]
397 pub fn map_sqlx_mysql_opts<F>(&mut self, f: F) -> &mut Self
400 where
401 F: Fn(MySqlConnectOptions) -> MySqlConnectOptions + Send + Sync + 'static,
402 {
403 self.mysql_opts_fn = Some(Arc::new(f));
404 self
405 }
406
407 #[cfg(feature = "sqlx-postgres")]
408 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-postgres")))]
409 pub fn map_sqlx_postgres_opts<F>(&mut self, f: F) -> &mut Self
412 where
413 F: Fn(PgConnectOptions) -> PgConnectOptions + Send + Sync + 'static,
414 {
415 self.pg_opts_fn = Some(Arc::new(f));
416 self
417 }
418
419 #[cfg(feature = "sqlx-sqlite")]
420 #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-sqlite")))]
421 pub fn map_sqlx_sqlite_opts<F>(&mut self, f: F) -> &mut Self
424 where
425 F: Fn(SqliteConnectOptions) -> SqliteConnectOptions + Send + Sync + 'static,
426 {
427 self.sqlite_opts_fn = Some(Arc::new(f));
428 self
429 }
430}