1use crate::DbPool;
5use crate::error::DbError;
6
/// Settings used by [`DbConfig::connect`] to open a database connection pool.
pub struct DbConfig {
    /// Connection target.
    ///
    /// Handed to the PostgreSQL connector as a connection string, or treated
    /// as a SQLite database file path (with `:memory:` selecting an in-memory
    /// database), depending on which backend feature is compiled in.
    pub url: String,
    /// Connection limit that is only consulted by the SQLite backend, where
    /// the pool is sized to the larger of this value and `pool_size`.
    pub max_connections: u32,
    /// Pool size: the connection limit for the PostgreSQL backend, and one of
    /// the two inputs to the SQLite pool limit.
    pub pool_size: u32,
}
19
20impl Default for DbConfig {
21 fn default() -> Self {
22 Self {
23 url: String::new(),
24 max_connections: 5,
25 pool_size: 5,
26 }
27 }
28}
29
30impl DbConfig {
31 pub async fn connect(&self) -> Result<DbPool, DbError> {
37 #[cfg(all(feature = "sqlite", not(feature = "postgres")))]
38 {
39 Self::connect_sqlite(&self.url, self.max_connections, self.pool_size).await
40 }
41 #[cfg(feature = "postgres")]
42 {
43 Self::connect_postgres(&self.url, self.pool_size).await
44 }
45 }
46
47 #[cfg(all(feature = "sqlite", not(feature = "postgres")))]
48 async fn connect_sqlite(
49 path: &str,
50 max_connections: u32,
51 pool_size: u32,
52 ) -> Result<DbPool, DbError> {
53 use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
54 use std::str::FromStr;
55
56 let url = if path == ":memory:" {
57 "sqlite::memory:".to_string()
58 } else {
59 if let Some(parent) = std::path::Path::new(path).parent()
60 && !parent.as_os_str().is_empty()
61 {
62 std::fs::create_dir_all(parent)?;
63 }
64 format!("sqlite:{path}?mode=rwc")
65 };
66
67 let opts = SqliteConnectOptions::from_str(&url)
68 .map_err(DbError::Sqlx)?
69 .create_if_missing(true)
70 .foreign_keys(true)
71 .busy_timeout(std::time::Duration::from_secs(5))
72 .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
73 .synchronous(sqlx::sqlite::SqliteSynchronous::Normal);
74
75 let effective_max = if path == ":memory:" {
80 1
81 } else {
82 max_connections.max(pool_size)
83 };
84 let pool = SqlitePoolOptions::new()
85 .max_connections(effective_max)
86 .min_connections(1)
87 .acquire_timeout(std::time::Duration::from_secs(30))
88 .connect_with(opts)
89 .await
90 .map_err(DbError::Sqlx)?;
91
92 crate::migrate::run_migrations(&pool).await?;
93
94 #[cfg(unix)]
96 if path != ":memory:" {
97 use std::os::unix::fs::PermissionsExt;
98 if let Ok(metadata) = std::fs::metadata(path) {
99 let mut perms = metadata.permissions();
100 perms.set_mode(0o600);
101 let _ = std::fs::set_permissions(path, perms);
102 }
103 }
104
105 if path != ":memory:" {
108 sqlx::query("PRAGMA wal_checkpoint(PASSIVE)")
109 .execute(&pool)
110 .await
111 .map_err(DbError::Sqlx)?;
112 }
113
114 Ok(pool)
115 }
116
117 #[cfg(feature = "postgres")]
118 async fn connect_postgres(url: &str, pool_size: u32) -> Result<DbPool, DbError> {
119 use sqlx::postgres::PgPoolOptions;
120
121 if !url.contains("sslmode=") {
122 tracing::warn!(
123 "postgres connection string has no sslmode; plaintext connections are allowed"
124 );
125 }
126
127 let pool = PgPoolOptions::new()
128 .max_connections(pool_size)
129 .acquire_timeout(std::time::Duration::from_secs(30))
130 .connect(url)
131 .await
132 .map_err(|e| DbError::Connection {
133 url: redact_url(url).unwrap_or_else(|| "[redacted]".into()),
134 source: e,
135 })?;
136
137 crate::migrate::run_migrations(&pool).await?;
138
139 Ok(pool)
140 }
141}
142
/// Replaces the credentials embedded in a connection URL with `[redacted]`.
///
/// For `scheme://user:password@host/...` this returns
/// `scheme://[redacted]@host/...`. The user-info part is everything between
/// `://` and the last `@` of the authority (the part before the first `/`,
/// `?` or `#`), so passwords that themselves contain `@` and URLs with an
/// empty user name are fully redacted.
///
/// If the authority holds no `@` but one appears later (for example an
/// unescaped `/` inside the password), everything up to the last `@` in the
/// URL is redacted instead: over-redacting a malformed URL is preferred to
/// leaking part of a secret.
///
/// Returns `None` when the URL has no `://`, no `@`, or no `:` in the
/// user-info part (that is, no password to hide).
#[must_use]
pub fn redact_url(url: &str) -> Option<String> {
    let (scheme, rest) = url.split_once("://")?;

    // The authority component ends at the first path, query or fragment
    // delimiter.
    let authority_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
    let at = rest[..authority_end]
        .rfind('@')
        .or_else(|| rest.rfind('@'))?;

    // Without a `:` there is only a user name, which is not a secret.
    if !rest[..at].contains(':') {
        return None;
    }

    // `rest[at..]` starts at the `@`, so the host and the tail are kept.
    Some(format!("{scheme}://[redacted]{}", &rest[at..]))
}
160
#[cfg(test)]
mod tests {
    use super::redact_url;

    /// A `user:password@` prefix is replaced and the secret is gone.
    #[test]
    fn redact_url_replaces_credentials() {
        let redacted = redact_url("postgres://user:secret@localhost:5432/zeph").unwrap();
        assert!(!redacted.contains("secret"));
        assert_eq!(redacted, "postgres://[redacted]@localhost:5432/zeph");
    }

    /// A URL without user-info yields `None`.
    #[test]
    fn redact_url_returns_none_for_no_credentials() {
        let outcome = redact_url("postgres://localhost:5432/zeph");
        assert!(outcome.is_none());
    }

    /// A SQLite-style path yields `None`.
    #[test]
    fn redact_url_handles_sqlite_path() {
        let outcome = redact_url("sqlite:/path/to/db");
        assert!(outcome.is_none());
    }
}