Skip to main content

zeph_db/
pool.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::DbPool;
5use crate::error::DbError;
6
/// Configuration for database pool construction.
///
/// Which fields are honoured depends on the backend selected at compile time;
/// see the per-field docs and [`DbConfig::connect`].
pub struct DbConfig {
    /// Database URL. `SQLite`: file path or `:memory:`. `PostgreSQL`: connection URL.
    pub url: String,
    /// Maximum number of connections in the pool.
    ///
    /// `SQLite`: the pool is sized to the larger of this value and
    /// [`Self::pool_size`] (always a single connection for `:memory:`).
    /// `PostgreSQL`: not consulted — the pool is sized by [`Self::pool_size`].
    pub max_connections: u32,
    /// Connection pool size. Default 5.
    ///
    /// `SQLite`: combined with [`Self::max_connections`] (the larger value wins).
    /// `BEGIN IMMEDIATE` serializes concurrent writers at the `SQLite` level, so
    /// additional connections only add read concurrency.
    /// `PostgreSQL`: used directly as the pool's maximum connection count.
    pub pool_size: u32,
}
19
20impl Default for DbConfig {
21    fn default() -> Self {
22        Self {
23            url: String::new(),
24            max_connections: 5,
25            pool_size: 5,
26        }
27    }
28}
29
30impl DbConfig {
31    /// Connect to the database and run migrations.
32    ///
33    /// # Errors
34    ///
35    /// Returns [`DbError`] if connection or migration fails.
36    #[tracing::instrument(name = "db.pool.connect", skip_all, err)]
37    pub async fn connect(&self) -> Result<DbPool, DbError> {
38        #[cfg(all(feature = "sqlite", not(feature = "postgres")))]
39        {
40            Self::connect_sqlite(&self.url, self.max_connections, self.pool_size).await
41        }
42        #[cfg(feature = "postgres")]
43        {
44            Self::connect_postgres(&self.url, self.pool_size).await
45        }
46    }
47
48    #[cfg(all(feature = "sqlite", not(feature = "postgres")))]
49    async fn connect_sqlite(
50        path: &str,
51        max_connections: u32,
52        pool_size: u32,
53    ) -> Result<DbPool, DbError> {
54        use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
55        use std::str::FromStr;
56
57        let url = if path == ":memory:" {
58            "sqlite::memory:".to_string()
59        } else {
60            let db_path = std::path::Path::new(path);
61            if let Some(parent) = db_path.parent()
62                && !parent.as_os_str().is_empty()
63            {
64                std::fs::create_dir_all(parent)?;
65            }
66            // Pre-create with 0o600 so sqlx inherits the mode rather than using the
67            // process umask. sqlx reopens the existing file via SQLITE_OPEN_CREATE.
68            // WAL/SHM sidecars are created by sqlx after the pool opens and will still
69            // inherit the process umask (sqlx limitation — best-effort chmod below).
70            if !db_path.exists() {
71                drop(zeph_common::fs_secure::open_private_truncate(db_path)?);
72            }
73            format!("sqlite:{path}?mode=rwc")
74        };
75
76        let opts = SqliteConnectOptions::from_str(&url)
77            .map_err(DbError::Sqlx)?
78            .create_if_missing(true)
79            .foreign_keys(true)
80            .busy_timeout(std::time::Duration::from_secs(5))
81            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
82            .synchronous(sqlx::sqlite::SqliteSynchronous::Normal);
83
84        // BEGIN IMMEDIATE serializes concurrent writers at the SQLite level.
85        // pool_size controls the connection count; max_connections is the upper bound.
86        // In-memory databases are connection-scoped: each new connection is a separate
87        // empty DB. Force a single connection so all queries share the migrated schema.
88        let effective_max = if path == ":memory:" {
89            1
90        } else {
91            max_connections.max(pool_size)
92        };
93        let pool = SqlitePoolOptions::new()
94            .max_connections(effective_max)
95            .min_connections(1)
96            .acquire_timeout(std::time::Duration::from_secs(30))
97            .connect_with(opts)
98            .await
99            .map_err(DbError::Sqlx)?;
100
101        crate::migrate::run_migrations(&pool).await?;
102
103        // Best-effort chmod for .db, .db-wal, and .db-shm. The .db itself was
104        // pre-created with 0o600 above; the WAL/SHM sidecars are created by sqlx
105        // after the pool opens and inherit the process umask, so we fix them here.
106        // There is a small race window between sidecar creation and this chmod;
107        // there is no way to close it without upstream sqlx support.
108        #[cfg(unix)]
109        if path != ":memory:" {
110            use std::os::unix::fs::PermissionsExt as _;
111            for suffix in &["", "-wal", "-shm", "-journal"] {
112                let p = format!("{path}{suffix}");
113                if let Ok(metadata) = std::fs::metadata(&p) {
114                    let mut perms = metadata.permissions();
115                    perms.set_mode(0o600);
116                    let _ = std::fs::set_permissions(&p, perms);
117                }
118            }
119        }
120
121        // Run a passive WAL checkpoint after migrations to avoid unbounded WAL growth.
122        // Skipped for in-memory databases (no WAL file).
123        if path != ":memory:" {
124            sqlx::query("PRAGMA wal_checkpoint(PASSIVE)")
125                .execute(&pool)
126                .await
127                .map_err(DbError::Sqlx)?;
128        }
129
130        Ok(pool)
131    }
132
133    #[cfg(feature = "postgres")]
134    async fn connect_postgres(url: &str, pool_size: u32) -> Result<DbPool, DbError> {
135        use sqlx::postgres::PgPoolOptions;
136
137        if !url.contains("sslmode=") {
138            tracing::warn!(
139                "postgres connection string has no sslmode; plaintext connections are allowed"
140            );
141        }
142
143        let pool = PgPoolOptions::new()
144            .max_connections(pool_size)
145            .acquire_timeout(std::time::Duration::from_secs(30))
146            .connect(url)
147            .await
148            .map_err(|e| DbError::Connection {
149                url: redact_url(url).unwrap_or_else(|| "[redacted]".into()),
150                source: e,
151            })?;
152
153        crate::migrate::run_migrations(&pool).await?;
154
155        Ok(pool)
156    }
157}
158
/// Strip credentials from a database URL for safe logging.
///
/// Replaces the userinfo section `://user:password@` with `://[redacted]@`. The
/// username may be empty (`://:password@`).
///
/// The userinfo is taken to end at the *last* `@` in the URL. A password containing
/// a literal (unescaped) `@` — which WHATWG-style URL parsers accept — is therefore
/// redacted in full rather than partially leaked. The trade-off is deliberate and
/// errs on the side of hiding too much: an `@` later in the URL (e.g. inside a query
/// parameter) also hides the host and path before it.
///
/// Only the userinfo section is inspected. Credentials passed in other ways, such as
/// query parameters, are not detected.
///
/// Returns `None` if the URL contains no `user:password@` section. This includes
/// URLs without `://`, `user@host` with no password, and an empty password.
/// Returns `Some(redacted)` if credentials were found and replaced.
#[must_use]
pub fn redact_url(url: &str) -> Option<String> {
    const SEPARATOR: &str = "://";

    let authority_start = url.find(SEPARATOR)? + SEPARATOR.len();
    let rest = &url[authority_start..];
    // The first `:` after the scheme separator ends the (possibly empty) username.
    let colon = rest.find(':')?;
    // The last `@` ends the userinfo; require a non-empty password between the two.
    let at = rest.rfind('@').filter(|&at| at > colon + 1)?;
    Some(format!(
        "{}[redacted]@{}",
        &url[..authority_start],
        &rest[at + 1..]
    ))
}
176
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn redact_url_replaces_credentials() {
        let redacted = redact_url("postgres://user:secret@localhost:5432/zeph")
            .expect("credentials should be detected");
        assert!(!redacted.contains("secret"));
        assert_eq!(redacted, "postgres://[redacted]@localhost:5432/zeph");
    }

    #[test]
    fn redact_url_returns_none_for_no_credentials() {
        // Host and port only: there is nothing to redact.
        assert_eq!(redact_url("postgres://localhost:5432/zeph"), None);
    }

    #[test]
    fn redact_url_handles_sqlite_path() {
        // SQLite paths have no `://` authority section at all.
        assert_eq!(redact_url("sqlite:/path/to/db"), None);
    }

    #[cfg(all(unix, feature = "sqlite", not(feature = "postgres")))]
    #[tokio::test]
    async fn sqlite_precreated_with_0600() {
        use std::os::unix::fs::PermissionsExt as _;

        let tmp = tempfile::tempdir().unwrap();
        let db_file = tmp.path().join("test.db");
        let config = DbConfig {
            url: db_file.to_str().unwrap().to_owned(),
            max_connections: 1,
            pool_size: 1,
        };
        config.connect().await.unwrap();

        let permission_bits = std::fs::metadata(&db_file).unwrap().permissions().mode() & 0o777;
        assert_eq!(
            permission_bits, 0o600,
            "SQLite DB file must be created with mode 0o600"
        );
    }
}