1use crate::DbPool;
5use crate::error::DbError;
6
/// Connection settings consumed by [`DbConfig::connect`].
pub struct DbConfig {
    /// Database location.
    ///
    /// The SQLite backend treats this as a filesystem path (or the literal
    /// `:memory:`); the PostgreSQL backend hands it to the driver as a
    /// connection URL.
    pub url: String,
    /// Connection cap consulted by the SQLite backend only, where the larger
    /// of this value and `pool_size` is used for on-disk databases.
    pub max_connections: u32,
    /// Pool size: the connection cap for the PostgreSQL backend, and one of
    /// the two values the SQLite backend takes the larger of.
    pub pool_size: u32,
}
19
20impl Default for DbConfig {
21 fn default() -> Self {
22 Self {
23 url: String::new(),
24 max_connections: 5,
25 pool_size: 5,
26 }
27 }
28}
29
30impl DbConfig {
31 #[tracing::instrument(name = "db.pool.connect", skip_all, err)]
37 pub async fn connect(&self) -> Result<DbPool, DbError> {
38 #[cfg(all(feature = "sqlite", not(feature = "postgres")))]
39 {
40 Self::connect_sqlite(&self.url, self.max_connections, self.pool_size).await
41 }
42 #[cfg(feature = "postgres")]
43 {
44 Self::connect_postgres(&self.url, self.pool_size).await
45 }
46 }
47
48 #[cfg(all(feature = "sqlite", not(feature = "postgres")))]
49 async fn connect_sqlite(
50 path: &str,
51 max_connections: u32,
52 pool_size: u32,
53 ) -> Result<DbPool, DbError> {
54 use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
55 use std::str::FromStr;
56
57 let url = if path == ":memory:" {
58 "sqlite::memory:".to_string()
59 } else {
60 let db_path = std::path::Path::new(path);
61 if let Some(parent) = db_path.parent()
62 && !parent.as_os_str().is_empty()
63 {
64 std::fs::create_dir_all(parent)?;
65 }
66 if !db_path.exists() {
71 drop(zeph_common::fs_secure::open_private_truncate(db_path)?);
72 }
73 format!("sqlite:{path}?mode=rwc")
74 };
75
76 let opts = SqliteConnectOptions::from_str(&url)
77 .map_err(DbError::Sqlx)?
78 .create_if_missing(true)
79 .foreign_keys(true)
80 .busy_timeout(std::time::Duration::from_secs(5))
81 .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
82 .synchronous(sqlx::sqlite::SqliteSynchronous::Normal);
83
84 let effective_max = if path == ":memory:" {
89 1
90 } else {
91 max_connections.max(pool_size)
92 };
93 let pool = SqlitePoolOptions::new()
94 .max_connections(effective_max)
95 .min_connections(1)
96 .acquire_timeout(std::time::Duration::from_secs(30))
97 .connect_with(opts)
98 .await
99 .map_err(DbError::Sqlx)?;
100
101 crate::migrate::run_migrations(&pool).await?;
102
103 #[cfg(unix)]
109 if path != ":memory:" {
110 use std::os::unix::fs::PermissionsExt as _;
111 for suffix in &["", "-wal", "-shm", "-journal"] {
112 let p = format!("{path}{suffix}");
113 if let Ok(metadata) = std::fs::metadata(&p) {
114 let mut perms = metadata.permissions();
115 perms.set_mode(0o600);
116 let _ = std::fs::set_permissions(&p, perms);
117 }
118 }
119 }
120
121 if path != ":memory:" {
124 sqlx::query("PRAGMA wal_checkpoint(PASSIVE)")
125 .execute(&pool)
126 .await
127 .map_err(DbError::Sqlx)?;
128 }
129
130 Ok(pool)
131 }
132
133 #[cfg(feature = "postgres")]
134 async fn connect_postgres(url: &str, pool_size: u32) -> Result<DbPool, DbError> {
135 use sqlx::postgres::PgPoolOptions;
136
137 if !url.contains("sslmode=") {
138 tracing::warn!(
139 "postgres connection string has no sslmode; plaintext connections are allowed"
140 );
141 }
142
143 let pool = PgPoolOptions::new()
144 .max_connections(pool_size)
145 .acquire_timeout(std::time::Duration::from_secs(30))
146 .connect(url)
147 .await
148 .map_err(|e| DbError::Connection {
149 url: redact_url(url).unwrap_or_else(|| "[redacted]".into()),
150 source: e,
151 })?;
152
153 crate::migrate::run_migrations(&pool).await?;
154
155 Ok(pool)
156 }
157}
158
/// Replaces the `user:password` portion of a connection URL with `[redacted]`.
///
/// The credentials are looked for after a `://` separator: everything from
/// there up to the `@` that ends the user-info is replaced. When the password
/// itself contains `@`, the replacement extends to the last `@` that appears
/// before the first `/`, `?` or `#`, so no part of the password survives. An
/// empty user name (`://:password@host`) is redacted as well.
///
/// Returns `None` when no `user:password@` pair is found, e.g. for URLs
/// without credentials, with a user name only, or with an empty password.
#[must_use]
pub fn redact_url(url: &str) -> Option<String> {
    const SCHEME_SEPARATOR: &str = "://";

    // Try each `://` in turn and stop at the first one followed by credentials.
    url.match_indices(SCHEME_SEPARATOR)
        .find_map(|(idx, sep)| redact_userinfo_at(url, idx + sep.len()))
}

/// Redacts the `user:password@` run that starts at byte `authority_start`,
/// or returns `None` if the text there does not contain one.
fn redact_userinfo_at(url: &str, authority_start: usize) -> Option<String> {
    // All delimiters searched for are ASCII, so every index below falls on a
    // character boundary.
    let (prefix, rest) = url.split_at(authority_start);

    let colon = rest.find(':')?;
    let after_colon = &rest[colon + 1..];
    let first_at = after_colon.find('@')?;
    if first_at == 0 {
        // Empty password: nothing secret to hide.
        return None;
    }

    let mut remainder = &after_colon[first_at + 1..];
    // A password may itself contain '@'. Within the authority (which ends at
    // the first '/', '?' or '#') the user-info runs up to the LAST '@'.
    let authority_end = remainder
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(remainder.len());
    if let Some(last_at) = remainder[..authority_end].rfind('@') {
        remainder = &remainder[last_at + 1..];
    }

    Some(format!("{prefix}[redacted]@{remainder}"))
}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Credentials embedded in a PostgreSQL URL are replaced and the secret
    /// does not survive in the output.
    #[test]
    fn redact_url_replaces_credentials() {
        let output = redact_url("postgres://user:secret@localhost:5432/zeph")
            .expect("URL with credentials must be redacted");
        assert_eq!(output, "postgres://[redacted]@localhost:5432/zeph");
        assert!(!output.contains("secret"));
    }

    /// A URL without user-info yields `None`.
    #[test]
    fn redact_url_returns_none_for_no_credentials() {
        assert!(redact_url("postgres://localhost:5432/zeph").is_none());
    }

    /// A SQLite-style path has no credentials to redact.
    #[test]
    fn redact_url_handles_sqlite_path() {
        assert!(redact_url("sqlite:/path/to/db").is_none());
    }

    /// Connecting to a fresh on-disk SQLite database leaves the file with
    /// owner-only permissions.
    #[cfg(all(unix, feature = "sqlite", not(feature = "postgres")))]
    #[tokio::test]
    async fn sqlite_precreated_with_0600() {
        use std::os::unix::fs::PermissionsExt as _;

        let scratch = tempfile::tempdir().unwrap();
        let db_file = scratch.path().join("test.db");
        let config = DbConfig {
            url: db_file.to_str().unwrap().to_owned(),
            max_connections: 1,
            pool_size: 1,
        };

        config.connect().await.unwrap();

        let permissions = std::fs::metadata(&db_file).unwrap().permissions();
        // Only the permission bits matter; mask off the file-type bits.
        let mode = permissions.mode() & 0o777;
        assert_eq!(
            mode, 0o600,
            "SQLite DB file must be created with mode 0o600"
        );
    }
}