1use crate::DbPool;
5use crate::error::DbError;
6
/// Connection settings consumed by `DbConfig::connect`.
pub struct DbConfig {
    /// Database location.
    ///
    /// For the SQLite backend this is a filesystem path, or the literal `:memory:` for an
    /// in-memory database. For the PostgreSQL backend it is the connection URL.
    pub url: String,
    /// Requested connection limit.
    ///
    /// Only the SQLite backend reads this value; for a file-backed database the pool is
    /// capped at the larger of `max_connections` and `pool_size`.
    pub max_connections: u32,
    /// Pool size.
    ///
    /// Used as the maximum pool size for PostgreSQL, and combined with `max_connections`
    /// (the larger of the two wins) for file-backed SQLite.
    pub pool_size: u32,
}
19
20impl Default for DbConfig {
21 fn default() -> Self {
22 Self {
23 url: String::new(),
24 max_connections: 5,
25 pool_size: 5,
26 }
27 }
28}
29
30impl DbConfig {
31 pub async fn connect(&self) -> Result<DbPool, DbError> {
37 #[cfg(all(feature = "sqlite", not(feature = "postgres")))]
38 {
39 Self::connect_sqlite(&self.url, self.max_connections, self.pool_size).await
40 }
41 #[cfg(feature = "postgres")]
42 {
43 Self::connect_postgres(&self.url, self.pool_size).await
44 }
45 }
46
47 #[cfg(all(feature = "sqlite", not(feature = "postgres")))]
48 async fn connect_sqlite(
49 path: &str,
50 max_connections: u32,
51 pool_size: u32,
52 ) -> Result<DbPool, DbError> {
53 use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
54 use std::str::FromStr;
55
56 let url = if path == ":memory:" {
57 "sqlite::memory:".to_string()
58 } else {
59 let db_path = std::path::Path::new(path);
60 if let Some(parent) = db_path.parent()
61 && !parent.as_os_str().is_empty()
62 {
63 std::fs::create_dir_all(parent)?;
64 }
65 if !db_path.exists() {
70 drop(zeph_common::fs_secure::open_private_truncate(db_path)?);
71 }
72 format!("sqlite:{path}?mode=rwc")
73 };
74
75 let opts = SqliteConnectOptions::from_str(&url)
76 .map_err(DbError::Sqlx)?
77 .create_if_missing(true)
78 .foreign_keys(true)
79 .busy_timeout(std::time::Duration::from_secs(5))
80 .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
81 .synchronous(sqlx::sqlite::SqliteSynchronous::Normal);
82
83 let effective_max = if path == ":memory:" {
88 1
89 } else {
90 max_connections.max(pool_size)
91 };
92 let pool = SqlitePoolOptions::new()
93 .max_connections(effective_max)
94 .min_connections(1)
95 .acquire_timeout(std::time::Duration::from_secs(30))
96 .connect_with(opts)
97 .await
98 .map_err(DbError::Sqlx)?;
99
100 crate::migrate::run_migrations(&pool).await?;
101
102 #[cfg(unix)]
108 if path != ":memory:" {
109 use std::os::unix::fs::PermissionsExt as _;
110 for suffix in &["", "-wal", "-shm", "-journal"] {
111 let p = format!("{path}{suffix}");
112 if let Ok(metadata) = std::fs::metadata(&p) {
113 let mut perms = metadata.permissions();
114 perms.set_mode(0o600);
115 let _ = std::fs::set_permissions(&p, perms);
116 }
117 }
118 }
119
120 if path != ":memory:" {
123 sqlx::query("PRAGMA wal_checkpoint(PASSIVE)")
124 .execute(&pool)
125 .await
126 .map_err(DbError::Sqlx)?;
127 }
128
129 Ok(pool)
130 }
131
132 #[cfg(feature = "postgres")]
133 async fn connect_postgres(url: &str, pool_size: u32) -> Result<DbPool, DbError> {
134 use sqlx::postgres::PgPoolOptions;
135
136 if !url.contains("sslmode=") {
137 tracing::warn!(
138 "postgres connection string has no sslmode; plaintext connections are allowed"
139 );
140 }
141
142 let pool = PgPoolOptions::new()
143 .max_connections(pool_size)
144 .acquire_timeout(std::time::Duration::from_secs(30))
145 .connect(url)
146 .await
147 .map_err(|e| DbError::Connection {
148 url: redact_url(url).unwrap_or_else(|| "[redacted]".into()),
149 source: e,
150 })?;
151
152 crate::migrate::run_migrations(&pool).await?;
153
154 Ok(pool)
155 }
156}
157
/// Replaces the `user:password` part of a connection URL with `[redacted]`.
///
/// Returns `Some(redacted_url)` when the URL carries a password, and `None` when it has no
/// `://` separator or no `user:password@` section. A user name without a password
/// (`scheme://user@host`) is not treated as a credential.
///
/// The userinfo ends at the *last* `@` of the authority (the text between `://` and the first
/// `/`, `?` or `#`), so a password containing an unencoded `@` is removed completely and an
/// empty user name (`scheme://:password@host`) is handled too.
///
/// If the authority contains no `@` but a later part of the URL does (for example a password
/// with an unencoded `/`), everything up to the last `@` is redacted. This can over-redact a
/// credential-free URL, which is preferred over leaking part of a secret.
///
/// Secrets passed as query parameters are not touched.
#[must_use]
pub fn redact_url(url: &str) -> Option<String> {
    const SCHEME_SEPARATOR: &str = "://";

    let authority_start = url.find(SCHEME_SEPARATOR)? + SCHEME_SEPARATOR.len();
    let (prefix, rest) = url.split_at(authority_start);

    // Prefer the last `@` inside the authority; fall back to the last `@` anywhere for
    // malformed URLs whose password contains an unencoded delimiter.
    let authority_len = rest.find(['/', '?', '#']).unwrap_or(rest.len());
    let userinfo_len = rest[..authority_len]
        .rfind('@')
        .or_else(|| rest.rfind('@'))?;

    let (userinfo, tail) = rest.split_at(userinfo_len);
    // Without a `:` there is no password component to hide.
    if !userinfo.contains(':') {
        return None;
    }

    // `tail` still starts with the `@` that separates userinfo from the host.
    Some(format!("{prefix}[redacted]{tail}"))
}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// A `user:password@` pair is replaced by the `[redacted]` marker.
    #[test]
    fn redact_url_replaces_credentials() {
        let redacted = redact_url("postgres://user:secret@localhost:5432/zeph").unwrap();

        assert_eq!(redacted, "postgres://[redacted]@localhost:5432/zeph");
        assert!(!redacted.contains("secret"));
    }

    /// A URL without userinfo yields `None`.
    #[test]
    fn redact_url_returns_none_for_no_credentials() {
        assert_eq!(redact_url("postgres://localhost:5432/zeph"), None);
    }

    /// A SQLite-style path has no `user:password@` section and yields `None`.
    #[test]
    fn redact_url_handles_sqlite_path() {
        assert_eq!(redact_url("sqlite:/path/to/db"), None);
    }

    /// Connecting to a fresh file-backed SQLite database leaves the file with mode `0o600`.
    #[cfg(all(unix, feature = "sqlite", not(feature = "postgres")))]
    #[tokio::test]
    async fn sqlite_precreated_with_0600() {
        use std::os::unix::fs::PermissionsExt as _;

        let tmp = tempfile::tempdir().unwrap();
        let db_file = tmp.path().join("test.db");
        let config = DbConfig {
            url: db_file.to_str().unwrap().to_owned(),
            max_connections: 1,
            pool_size: 1,
        };
        config.connect().await.unwrap();

        // Mask off the file-type bits; only the permission bits are under test.
        let permissions = std::fs::metadata(&db_file).unwrap().permissions();
        assert_eq!(
            permissions.mode() & 0o777,
            0o600,
            "SQLite DB file must be created with mode 0o600"
        );
    }
}