Skip to main content

zeph_db/
pool.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::DbPool;
5use crate::error::DbError;
6
/// Configuration for database pool construction.
///
/// `url` may embed credentials; pass it through [`redact_url`] before logging it.
pub struct DbConfig {
    /// Database URL. `SQLite`: file path or `:memory:`. `PostgreSQL`: connection URL.
    ///
    /// May contain `user:password@` credentials.
    pub url: String,
    /// Requested maximum number of pooled connections. Default 5.
    ///
    /// For `SQLite` the effective pool bound is derived from the larger of this
    /// and [`pool_size`](Self::pool_size); an in-memory `SQLite` database always
    /// uses a single connection.
    pub max_connections: u32,
    /// Connection pool size. Default 5.
    ///
    /// Read by both backends when sizing the pool (for `SQLite`, combined with
    /// `max_connections` as described there).
    ///
    /// With `SQLite`, `BEGIN IMMEDIATE` serializes concurrent writers at the
    /// database level, so the pool size effectively controls read concurrency only.
    pub pool_size: u32,
}
19
20impl Default for DbConfig {
21    fn default() -> Self {
22        Self {
23            url: String::new(),
24            max_connections: 5,
25            pool_size: 5,
26        }
27    }
28}
29
30impl DbConfig {
31    /// Connect to the database and run migrations.
32    ///
33    /// # Errors
34    ///
35    /// Returns [`DbError`] if connection or migration fails.
36    pub async fn connect(&self) -> Result<DbPool, DbError> {
37        #[cfg(all(feature = "sqlite", not(feature = "postgres")))]
38        {
39            Self::connect_sqlite(&self.url, self.max_connections, self.pool_size).await
40        }
41        #[cfg(feature = "postgres")]
42        {
43            Self::connect_postgres(&self.url, self.pool_size).await
44        }
45    }
46
47    #[cfg(all(feature = "sqlite", not(feature = "postgres")))]
48    async fn connect_sqlite(
49        path: &str,
50        max_connections: u32,
51        pool_size: u32,
52    ) -> Result<DbPool, DbError> {
53        use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
54        use std::str::FromStr;
55
56        let url = if path == ":memory:" {
57            "sqlite::memory:".to_string()
58        } else {
59            let db_path = std::path::Path::new(path);
60            if let Some(parent) = db_path.parent()
61                && !parent.as_os_str().is_empty()
62            {
63                std::fs::create_dir_all(parent)?;
64            }
65            // Pre-create with 0o600 so sqlx inherits the mode rather than using the
66            // process umask. sqlx reopens the existing file via SQLITE_OPEN_CREATE.
67            // WAL/SHM sidecars are created by sqlx after the pool opens and will still
68            // inherit the process umask (sqlx limitation — best-effort chmod below).
69            if !db_path.exists() {
70                drop(zeph_common::fs_secure::open_private_truncate(db_path)?);
71            }
72            format!("sqlite:{path}?mode=rwc")
73        };
74
75        let opts = SqliteConnectOptions::from_str(&url)
76            .map_err(DbError::Sqlx)?
77            .create_if_missing(true)
78            .foreign_keys(true)
79            .busy_timeout(std::time::Duration::from_secs(5))
80            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
81            .synchronous(sqlx::sqlite::SqliteSynchronous::Normal);
82
83        // BEGIN IMMEDIATE serializes concurrent writers at the SQLite level.
84        // pool_size controls the connection count; max_connections is the upper bound.
85        // In-memory databases are connection-scoped: each new connection is a separate
86        // empty DB. Force a single connection so all queries share the migrated schema.
87        let effective_max = if path == ":memory:" {
88            1
89        } else {
90            max_connections.max(pool_size)
91        };
92        let pool = SqlitePoolOptions::new()
93            .max_connections(effective_max)
94            .min_connections(1)
95            .acquire_timeout(std::time::Duration::from_secs(30))
96            .connect_with(opts)
97            .await
98            .map_err(DbError::Sqlx)?;
99
100        crate::migrate::run_migrations(&pool).await?;
101
102        // Best-effort chmod for .db, .db-wal, and .db-shm. The .db itself was
103        // pre-created with 0o600 above; the WAL/SHM sidecars are created by sqlx
104        // after the pool opens and inherit the process umask, so we fix them here.
105        // There is a small race window between sidecar creation and this chmod;
106        // there is no way to close it without upstream sqlx support.
107        #[cfg(unix)]
108        if path != ":memory:" {
109            use std::os::unix::fs::PermissionsExt as _;
110            for suffix in &["", "-wal", "-shm", "-journal"] {
111                let p = format!("{path}{suffix}");
112                if let Ok(metadata) = std::fs::metadata(&p) {
113                    let mut perms = metadata.permissions();
114                    perms.set_mode(0o600);
115                    let _ = std::fs::set_permissions(&p, perms);
116                }
117            }
118        }
119
120        // Run a passive WAL checkpoint after migrations to avoid unbounded WAL growth.
121        // Skipped for in-memory databases (no WAL file).
122        if path != ":memory:" {
123            sqlx::query("PRAGMA wal_checkpoint(PASSIVE)")
124                .execute(&pool)
125                .await
126                .map_err(DbError::Sqlx)?;
127        }
128
129        Ok(pool)
130    }
131
132    #[cfg(feature = "postgres")]
133    async fn connect_postgres(url: &str, pool_size: u32) -> Result<DbPool, DbError> {
134        use sqlx::postgres::PgPoolOptions;
135
136        if !url.contains("sslmode=") {
137            tracing::warn!(
138                "postgres connection string has no sslmode; plaintext connections are allowed"
139            );
140        }
141
142        let pool = PgPoolOptions::new()
143            .max_connections(pool_size)
144            .acquire_timeout(std::time::Duration::from_secs(30))
145            .connect(url)
146            .await
147            .map_err(|e| DbError::Connection {
148                url: redact_url(url).unwrap_or_else(|| "[redacted]".into()),
149                source: e,
150            })?;
151
152        crate::migrate::run_migrations(&pool).await?;
153
154        Ok(pool)
155    }
156}
157
/// Strip the password-bearing userinfo from a database URL for safe logging.
///
/// Replaces `://user:password@` with `://[redacted]@`. The userinfo is everything
/// between `://` and the *last* `@` of the authority (the part before the first
/// `/`, `?` or `#`). As a result:
/// - passwords containing an unescaped `@` are redacted in full;
/// - an `@` appearing later in the path or query is never mistaken for a
///   credential separator.
///
/// Returns `Some(redacted)` if the userinfo contains a `:` (a password component
/// is present, even with an empty user name), and `None` otherwise.
///
/// Only the userinfo is inspected. Secrets passed as query parameters (for
/// example `?password=...`) are left untouched, so `None` does not by itself
/// guarantee the URL is safe to log.
#[must_use]
pub fn redact_url(url: &str) -> Option<String> {
    let (scheme, rest) = url.split_once("://")?;
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];
    let at = authority.rfind('@')?;
    if !authority[..at].contains(':') {
        return None;
    }
    Some(format!("{scheme}://[redacted]{}", &rest[at..]))
}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn redact_url_replaces_credentials() {
        let redacted = redact_url("postgres://user:secret@localhost:5432/zeph").unwrap();
        assert_eq!(redacted, "postgres://[redacted]@localhost:5432/zeph");
        assert!(!redacted.contains("secret"));
    }

    #[test]
    fn redact_url_returns_none_for_no_credentials() {
        // No `user:password@` segment, so there is nothing to redact.
        assert_eq!(redact_url("postgres://localhost:5432/zeph"), None);
    }

    #[test]
    fn redact_url_handles_sqlite_path() {
        // A `sqlite:` path has no `://` authority section at all.
        assert_eq!(redact_url("sqlite:/path/to/db"), None);
    }

    #[cfg(all(unix, feature = "sqlite", not(feature = "postgres")))]
    #[tokio::test]
    async fn sqlite_precreated_with_0600() {
        use std::os::unix::fs::PermissionsExt as _;

        let tmp = tempfile::tempdir().unwrap();
        let db_file = tmp.path().join("test.db");
        DbConfig {
            url: db_file.to_str().unwrap().to_owned(),
            max_connections: 1,
            pool_size: 1,
        }
        .connect()
        .await
        .unwrap();

        let permission_bits = std::fs::metadata(&db_file)
            .unwrap()
            .permissions()
            .mode()
            & 0o777;
        assert_eq!(
            permission_bits, 0o600,
            "SQLite DB file must be created with mode 0o600"
        );
    }
}