Skip to main content

waypoint_core/
db.rs

1//! Database connection, TLS support, advisory locking, and transaction execution.
2
3use fastrand;
4use tokio_postgres::Client;
5
6use crate::config::SslMode;
7use crate::error::{Result, WaypointError};
8
/// Quote a SQL identifier to prevent SQL injection.
///
/// Every embedded `"` is escaped by doubling it, and the result is wrapped in
/// a pair of double quotes, e.g. `my"table` becomes `"my""table"`.
pub fn quote_ident(name: &str) -> String {
    let escaped = name.replace('"', "\"\"");
    let mut quoted = String::with_capacity(escaped.len() + 2);
    quoted.push('"');
    quoted.push_str(&escaped);
    quoted.push('"');
    quoted
}
15
16/// Validate that a SQL identifier contains only safe characters.
17///
18/// Returns an error for names with characters outside `[a-zA-Z0-9_]`.
19/// Even with quoting (defense in depth), we reject suspicious identifiers early.
20pub fn validate_identifier(name: &str) -> Result<()> {
21    if name.is_empty() {
22        return Err(WaypointError::ConfigError(
23            "Identifier cannot be empty".to_string(),
24        ));
25    }
26    if !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
27        return Err(WaypointError::ConfigError(format!(
28            "Identifier '{}' contains invalid characters. Only [a-zA-Z0-9_] are allowed.",
29            name
30        )));
31    }
32    Ok(())
33}
34
35/// Build a rustls ClientConfig using the Mozilla CA bundle and ring crypto provider.
36fn make_rustls_config() -> rustls::ClientConfig {
37    let root_store =
38        rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
39    rustls::ClientConfig::builder_with_provider(std::sync::Arc::new(
40        rustls::crypto::ring::default_provider(),
41    ))
42    .with_safe_default_protocol_versions()
43    .unwrap()
44    .with_root_certificates(root_store)
45    .with_no_client_auth()
46}
47
48/// Check if a postgres error is a permanent authentication failure that should not be retried.
49fn is_permanent_error(e: &tokio_postgres::Error) -> bool {
50    if let Some(db_err) = e.as_db_error() {
51        let code = db_err.code().code();
52        // 28P01 = invalid_password, 28000 = invalid_authorization_specification
53        return code == "28P01" || code == "28000";
54    }
55    false
56}
57
/// Inject TCP keepalive parameters into a connection string if not already present.
///
/// For URL-style strings (`postgres://...` / `postgresql://...`), this appends
/// `keepalives=1&keepalives_idle=N` to the query string. The separator is
/// chosen as follows:
/// * `?` when there is no query yet;
/// * `&` when there is a query;
/// * nothing when the string already ends in `?` or `&`.
///
/// The last case keeps the result free of empty parameters such as `?&`,
/// which the URL parser rejects.
///
/// For key=value style, this appends ` keepalives=1 keepalives_idle=N`.
///
/// Returns the string unchanged if `keepalive_secs == 0`, or if a parameter
/// whose name starts with `keepalives` (e.g. `keepalives`, `keepalives_idle`)
/// is already set. Only parameter names are checked, so a host, user,
/// password or database name that merely contains "keepalives" does not
/// suppress injection.
pub fn inject_keepalive(conn_string: &str, keepalive_secs: u32) -> String {
    if keepalive_secs == 0 || has_keepalive_param(conn_string) {
        return conn_string.to_string();
    }
    if is_url_style(conn_string) {
        let separator = if !conn_string.contains('?') {
            "?"
        } else if conn_string.ends_with('?') || conn_string.ends_with('&') {
            // A separator is already dangling at the end; adding another
            // would create an empty parameter.
            ""
        } else {
            "&"
        };
        format!(
            "{}{}keepalives=1&keepalives_idle={}",
            conn_string, separator, keepalive_secs
        )
    } else {
        // Key=value style
        format!(
            "{} keepalives=1 keepalives_idle={}",
            conn_string, keepalive_secs
        )
    }
}

/// Whether `conn_string` uses the `postgres://` / `postgresql://` URL form.
fn is_url_style(conn_string: &str) -> bool {
    conn_string.starts_with("postgres://") || conn_string.starts_with("postgresql://")
}

/// Whether `conn_string` already sets a parameter whose name starts with
/// `keepalives` (compared case-insensitively).
///
/// For URLs, only the query string (after the first `?`) is inspected. For
/// key=value strings, each whitespace-separated token is inspected. A quoted
/// value containing a word that begins with `keepalives` can still match;
/// this errs on the side of leaving the string untouched.
fn has_keepalive_param(conn_string: &str) -> bool {
    let is_keepalive_key =
        |part: &str| part.trim().to_ascii_lowercase().starts_with("keepalives");
    if is_url_style(conn_string) {
        match conn_string.split_once('?') {
            Some((_, query)) => query.split('&').any(is_keepalive_key),
            None => false,
        }
    } else {
        conn_string.split_whitespace().any(is_keepalive_key)
    }
}
86
87/// Spawn the background connection driver task.
88///
89/// Both TLS and non-TLS connections produce a future that resolves when the
90/// connection terminates.  This helper accepts any such future and runs it
91/// on the tokio runtime, logging errors.
92fn spawn_connection_task<F>(connection: F)
93where
94    F: std::future::Future<Output = std::result::Result<(), tokio_postgres::Error>>
95        + Send
96        + 'static,
97{
98    tokio::spawn(async move {
99        if let Err(e) = connection.await {
100            log::error!("Database connection error: {}", e);
101        }
102    });
103}
104
105/// Connect to the database using the provided connection string with TLS support.
106///
107/// Spawns the connection task on the tokio runtime.
108async fn connect_once(
109    conn_string: &str,
110    ssl_mode: &SslMode,
111    connect_timeout_secs: u32,
112) -> std::result::Result<Client, tokio_postgres::Error> {
113    let connect_fut = async {
114        match ssl_mode {
115            SslMode::Disable => {
116                let (client, connection) =
117                    tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
118                spawn_connection_task(connection);
119                Ok(client)
120            }
121            SslMode::Require => {
122                let tls_config = make_rustls_config();
123                let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
124                let (client, connection) = tokio_postgres::connect(conn_string, tls).await?;
125                spawn_connection_task(connection);
126                Ok(client)
127            }
128            SslMode::Prefer => {
129                // Try TLS first, fall back to plaintext
130                let tls_config = make_rustls_config();
131                let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
132                match tokio_postgres::connect(conn_string, tls).await {
133                    Ok((client, connection)) => {
134                        spawn_connection_task(connection);
135                        Ok(client)
136                    }
137                    Err(_) => {
138                        log::debug!("TLS connection failed, falling back to plaintext");
139                        let (client, connection) =
140                            tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
141                        spawn_connection_task(connection);
142                        Ok(client)
143                    }
144                }
145            }
146        }
147    };
148
149    if connect_timeout_secs > 0 {
150        match tokio::time::timeout(
151            std::time::Duration::from_secs(connect_timeout_secs as u64),
152            connect_fut,
153        )
154        .await
155        {
156            Ok(result) => result,
157            Err(_) => Err(tokio_postgres::Error::__private_api_timeout()),
158        }
159    } else {
160        connect_fut.await
161    }
162}
163
164/// Connect to the database using the provided connection string.
165///
166/// Spawns the connection task on the tokio runtime.
167pub async fn connect(conn_string: &str) -> Result<Client> {
168    connect_with_config(conn_string, &SslMode::Prefer, 0, 30, 0).await
169}
170
171/// Connect to the database, retrying up to `retries` times with exponential backoff + jitter.
172///
173/// Each retry waits `min(2^attempt, 30) + rand(0..1000ms)` before the next attempt.
174/// Permanent errors (authentication failures) are not retried.
175pub async fn connect_with_config(
176    conn_string: &str,
177    ssl_mode: &SslMode,
178    retries: u32,
179    connect_timeout_secs: u32,
180    statement_timeout_secs: u32,
181) -> Result<Client> {
182    connect_with_full_config(
183        conn_string,
184        ssl_mode,
185        retries,
186        connect_timeout_secs,
187        statement_timeout_secs,
188        120,
189    )
190    .await
191}
192
193/// Connect to the database with all configuration options including TCP keepalive.
194pub async fn connect_with_full_config(
195    conn_string: &str,
196    ssl_mode: &SslMode,
197    retries: u32,
198    connect_timeout_secs: u32,
199    statement_timeout_secs: u32,
200    keepalive_secs: u32,
201) -> Result<Client> {
202    let conn_string = inject_keepalive(conn_string, keepalive_secs);
203    let mut last_err = None;
204
205    for attempt in 0..=retries {
206        if attempt > 0 {
207            let base_delay = std::cmp::min(1u64 << attempt, 30);
208            let jitter_ms = fastrand::u64(0..1000);
209            let delay = std::time::Duration::from_secs(base_delay)
210                + std::time::Duration::from_millis(jitter_ms);
211            log::info!(
212                "Connection attempt failed, retrying; attempt={}, max_attempts={}, delay_ms={}",
213                attempt + 1,
214                retries + 1,
215                delay.as_millis() as u64
216            );
217            tokio::time::sleep(delay).await;
218        }
219
220        match connect_once(&conn_string, ssl_mode, connect_timeout_secs).await {
221            Ok(client) => {
222                if attempt > 0 {
223                    log::info!(
224                        "Connected successfully after retry; attempt={}, max_attempts={}",
225                        attempt + 1,
226                        retries + 1
227                    );
228                }
229
230                // Set statement timeout if configured
231                if statement_timeout_secs > 0 {
232                    let timeout_sql =
233                        format!("SET statement_timeout = '{}s'", statement_timeout_secs);
234                    client.batch_execute(&timeout_sql).await?;
235                }
236
237                return Ok(client);
238            }
239            Err(e) => {
240                // Don't retry permanent errors (e.g. bad credentials)
241                if is_permanent_error(&e) {
242                    log::error!("Permanent connection error, not retrying: {}", e);
243                    return Err(WaypointError::DatabaseError(e));
244                }
245                last_err = Some(e);
246            }
247        }
248    }
249
250    Err(WaypointError::DatabaseError(last_err.unwrap()))
251}
252
253/// Acquire a PostgreSQL advisory lock based on the history table name.
254///
255/// This prevents concurrent migration runs from interfering with each other.
256pub async fn acquire_advisory_lock(client: &Client, table_name: &str) -> Result<()> {
257    let lock_id = advisory_lock_id(table_name);
258    log::info!(
259        "Acquiring advisory lock; lock_id={}, table={}",
260        lock_id,
261        table_name
262    );
263
264    client
265        .execute("SELECT pg_advisory_lock($1)", &[&lock_id])
266        .await
267        .map_err(|e| WaypointError::LockError(format!("Failed to acquire advisory lock: {}", e)))?;
268
269    Ok(())
270}
271
272/// Try to acquire a PostgreSQL advisory lock with a timeout.
273///
274/// Uses `pg_try_advisory_lock()` in a polling loop with configurable timeout.
275/// Returns Ok(()) if lock acquired, or a LockError if the timeout expires.
276pub async fn acquire_advisory_lock_with_timeout(
277    client: &Client,
278    table_name: &str,
279    timeout_secs: u32,
280) -> Result<()> {
281    let lock_id = advisory_lock_id(table_name);
282    log::info!(
283        "Trying to acquire advisory lock with timeout; lock_id={}, table={}, timeout_secs={}",
284        lock_id,
285        table_name,
286        timeout_secs
287    );
288
289    let deadline = std::time::Instant::now() + std::time::Duration::from_secs(timeout_secs as u64);
290
291    loop {
292        let row = client
293            .query_one("SELECT pg_try_advisory_lock($1)", &[&lock_id])
294            .await
295            .map_err(|e| WaypointError::LockError(format!("Failed to try advisory lock: {}", e)))?;
296
297        let acquired: bool = row.get(0);
298        if acquired {
299            return Ok(());
300        }
301
302        if std::time::Instant::now() >= deadline {
303            return Err(WaypointError::LockError(format!(
304                "Timed out waiting for advisory lock after {}s (table: {}). Another migration may be running.",
305                timeout_secs, table_name
306            )));
307        }
308
309        // Wait 500ms before retrying
310        tokio::time::sleep(std::time::Duration::from_millis(500)).await;
311    }
312}
313
314/// Release the PostgreSQL advisory lock.
315pub async fn release_advisory_lock(client: &Client, table_name: &str) -> Result<()> {
316    let lock_id = advisory_lock_id(table_name);
317    log::info!(
318        "Releasing advisory lock; lock_id={}, table={}",
319        lock_id,
320        table_name
321    );
322
323    client
324        .execute("SELECT pg_advisory_unlock($1)", &[&lock_id])
325        .await
326        .map_err(|e| WaypointError::LockError(format!("Failed to release advisory lock: {}", e)))?;
327
328    Ok(())
329}
330
331/// Compute a stable i64 lock ID from the table name using CRC32.
332///
333/// Uses CRC32 instead of DefaultHasher for cross-version stability —
334/// DefaultHasher is not guaranteed to produce the same output across
335/// Rust compiler versions.
336fn advisory_lock_id(table_name: &str) -> i64 {
337    crc32fast::hash(table_name.as_bytes()) as i64
338}
339
340/// Get the current database user.
341pub async fn get_current_user(client: &Client) -> Result<String> {
342    let row = client.query_one("SELECT current_user", &[]).await?;
343    Ok(row.get::<_, String>(0))
344}
345
346/// Get the current database name.
347pub async fn get_current_database(client: &Client) -> Result<String> {
348    let row = client.query_one("SELECT current_database()", &[]).await?;
349    Ok(row.get::<_, String>(0))
350}
351
352/// Execute a SQL string within a transaction using SQL-level BEGIN/COMMIT.
353/// Returns the execution time in milliseconds.
354pub async fn execute_in_transaction(client: &Client, sql: &str) -> Result<i32> {
355    let start = std::time::Instant::now();
356
357    client.batch_execute("BEGIN").await?;
358
359    match client.batch_execute(sql).await {
360        Ok(()) => {
361            client.batch_execute("COMMIT").await?;
362        }
363        Err(e) => {
364            if let Err(rollback_err) = client.batch_execute("ROLLBACK").await {
365                log::warn!("Failed to rollback transaction: {}", rollback_err);
366            }
367            return Err(WaypointError::DatabaseError(e));
368        }
369    }
370
371    let elapsed = start.elapsed().as_millis() as i32;
372    Ok(elapsed)
373}
374
375/// Execute SQL without a transaction wrapper (for statements that can't run in a transaction).
376pub async fn execute_raw(client: &Client, sql: &str) -> Result<i32> {
377    let start = std::time::Instant::now();
378    client.batch_execute(sql).await?;
379    let elapsed = start.elapsed().as_millis() as i32;
380    Ok(elapsed)
381}
382
383/// Check if an error is a transient connection error that may be retried.
384///
385/// Detects PostgreSQL server shutdown codes, connection exception codes,
386/// closed connections, and common network error message patterns.
387pub fn is_transient_error(e: &WaypointError) -> bool {
388    match e {
389        WaypointError::DatabaseError(pg_err) => {
390            // Check if the connection is closed
391            if pg_err.is_closed() {
392                return true;
393            }
394            // Check PostgreSQL error codes
395            if let Some(db_err) = pg_err.as_db_error() {
396                let code = db_err.code().code();
397                // 57P01 = admin_shutdown, 57P02 = crash_shutdown, 57P03 = cannot_connect_now
398                // 08000 = connection_exception, 08003 = connection_does_not_exist,
399                // 08006 = connection_failure
400                return matches!(
401                    code,
402                    "57P01" | "57P02" | "57P03" | "08000" | "08003" | "08006"
403                );
404            }
405            // Check error message patterns for connection-related issues
406            let msg = pg_err.to_string().to_lowercase();
407            msg.contains("connection reset")
408                || msg.contains("broken pipe")
409                || msg.contains("connection closed")
410                || msg.contains("unexpected eof")
411        }
412        WaypointError::ConnectionLost { .. } => true,
413        _ => false,
414    }
415}
416
417/// Verify the database connection is still alive with a minimal round-trip.
418pub async fn check_connection(client: &Client) -> Result<()> {
419    client
420        .simple_query("")
421        .await
422        .map_err(|e| WaypointError::ConnectionLost {
423            operation: "health check".to_string(),
424            detail: e.to_string(),
425        })?;
426    Ok(())
427}
428
#[cfg(test)]
mod tests {
    // Unit tests for the pure (non-async, no-database) helpers in this module:
    // keepalive injection, transient-error classification, advisory lock id
    // derivation, and identifier validation/quoting.
    use super::*;

    // ── inject_keepalive tests ──

    // A URL without a query string gets `?` plus both keepalive params.
    #[test]
    fn test_inject_keepalive_url_style() {
        let result = inject_keepalive("postgres://user:pass@localhost/db", 120);
        assert_eq!(
            result,
            "postgres://user:pass@localhost/db?keepalives=1&keepalives_idle=120"
        );
    }

    // A URL that already has a query string is extended with `&`.
    #[test]
    fn test_inject_keepalive_url_with_existing_params() {
        let result = inject_keepalive("postgres://user:pass@localhost/db?sslmode=require", 60);
        assert_eq!(
            result,
            "postgres://user:pass@localhost/db?sslmode=require&keepalives=1&keepalives_idle=60"
        );
    }

    // key=value strings get space-separated params appended.
    #[test]
    fn test_inject_keepalive_kv_style() {
        let result = inject_keepalive("host=localhost port=5432 user=admin dbname=mydb", 90);
        assert_eq!(
            result,
            "host=localhost port=5432 user=admin dbname=mydb keepalives=1 keepalives_idle=90"
        );
    }

    // keepalive_secs == 0 means "leave the string alone".
    #[test]
    fn test_inject_keepalive_zero_disables() {
        let result = inject_keepalive("postgres://user:pass@localhost/db", 0);
        assert_eq!(result, "postgres://user:pass@localhost/db");
    }

    // A user-supplied keepalive setting is never overridden.
    #[test]
    fn test_inject_keepalive_already_present() {
        let result = inject_keepalive("postgres://user:pass@localhost/db?keepalives=1", 120);
        assert_eq!(result, "postgres://user:pass@localhost/db?keepalives=1");
    }

    // ── is_transient_error tests ──

    #[test]
    fn test_transient_error_connection_lost() {
        let err = WaypointError::ConnectionLost {
            operation: "test".to_string(),
            detail: "gone".to_string(),
        };
        assert!(is_transient_error(&err));
    }

    #[test]
    fn test_transient_error_config_is_not_transient() {
        let err = WaypointError::ConfigError("bad config".to_string());
        assert!(!is_transient_error(&err));
    }

    #[test]
    fn test_transient_error_migration_failed_is_not_transient() {
        let err = WaypointError::MigrationFailed {
            script: "V1__test.sql".to_string(),
            reason: "syntax error".to_string(),
        };
        assert!(!is_transient_error(&err));
    }

    // ── advisory_lock_id tests ──

    #[test]
    fn test_advisory_lock_id_stability() {
        // Ensure the same table name always produces the same lock ID
        let id1 = advisory_lock_id("waypoint_schema_history");
        let id2 = advisory_lock_id("waypoint_schema_history");
        assert_eq!(id1, id2);
        // Different table names should produce different lock IDs
        // (true for these two names; CRC32 collisions are possible in general).
        let id3 = advisory_lock_id("other_table");
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_transient_error_lock_error_is_not_transient() {
        let err = WaypointError::LockError("lock failed".to_string());
        assert!(!is_transient_error(&err));
    }

    #[test]
    fn test_transient_error_io_error_is_not_transient() {
        let err = WaypointError::IoError(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "file not found",
        ));
        assert!(!is_transient_error(&err));
    }

    // ── validate_identifier tests ──

    #[test]
    fn test_validate_identifier_valid() {
        assert!(validate_identifier("users").is_ok());
        assert!(validate_identifier("my_table").is_ok());
        assert!(validate_identifier("Table123").is_ok());
        assert!(validate_identifier("a").is_ok());
    }

    #[test]
    fn test_validate_identifier_invalid() {
        assert!(validate_identifier("").is_err());
        assert!(validate_identifier("my-table").is_err());
        assert!(validate_identifier("my table").is_err());
        assert!(validate_identifier("table.name").is_err());
        assert!(validate_identifier("table;drop").is_err());
    }

    // ── quote_ident tests ──

    #[test]
    fn test_quote_ident_simple() {
        assert_eq!(quote_ident("users"), "\"users\"");
    }

    // Embedded double quotes are escaped by doubling.
    #[test]
    fn test_quote_ident_embedded_quotes() {
        assert_eq!(quote_ident("my\"table"), "\"my\"\"table\"");
    }

    #[test]
    fn test_quote_ident_empty() {
        assert_eq!(quote_ident(""), "\"\"");
    }

    // The `postgresql://` scheme is treated the same as `postgres://`.
    #[test]
    fn test_inject_keepalive_postgresql_prefix() {
        let result = inject_keepalive("postgresql://user:pass@localhost/db", 120);
        assert_eq!(
            result,
            "postgresql://user:pass@localhost/db?keepalives=1&keepalives_idle=120"
        );
    }
}