Skip to main content

waypoint_core/
db.rs

1use rand::Rng;
2use tokio_postgres::Client;
3
4use crate::config::SslMode;
5use crate::error::{Result, WaypointError};
6
/// Wrap `name` in double quotes so it can be used as a SQL identifier.
///
/// Every double-quote character inside `name` is emitted twice, so the
/// quoted identifier cannot be terminated early by its own content.
pub fn quote_ident(name: &str) -> String {
    // Two extra bytes for the surrounding quotes; embedded quotes may grow it further.
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
13
14/// Validate that a SQL identifier contains only safe characters.
15///
16/// Returns an error for names with characters outside `[a-zA-Z0-9_]`.
17/// Even with quoting (defense in depth), we reject suspicious identifiers early.
18pub fn validate_identifier(name: &str) -> Result<()> {
19    if name.is_empty() {
20        return Err(WaypointError::ConfigError(
21            "Identifier cannot be empty".to_string(),
22        ));
23    }
24    if !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
25        return Err(WaypointError::ConfigError(format!(
26            "Identifier '{}' contains invalid characters. Only [a-zA-Z0-9_] are allowed.",
27            name
28        )));
29    }
30    Ok(())
31}
32
33/// Build a rustls ClientConfig using the Mozilla CA bundle.
34fn make_rustls_config() -> rustls::ClientConfig {
35    let root_store =
36        rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
37    rustls::ClientConfig::builder()
38        .with_root_certificates(root_store)
39        .with_no_client_auth()
40}
41
42/// Check if a postgres error is a permanent authentication failure that should not be retried.
43fn is_permanent_error(e: &tokio_postgres::Error) -> bool {
44    if let Some(db_err) = e.as_db_error() {
45        let code = db_err.code().code();
46        // 28P01 = invalid_password, 28000 = invalid_authorization_specification
47        return code == "28P01" || code == "28000";
48    }
49    false
50}
51
52/// Connect to the database using the provided connection string with TLS support.
53///
54/// Spawns the connection task on the tokio runtime.
55async fn connect_once(
56    conn_string: &str,
57    ssl_mode: &SslMode,
58    connect_timeout_secs: u32,
59) -> std::result::Result<Client, tokio_postgres::Error> {
60    let connect_fut = async {
61        match ssl_mode {
62            SslMode::Disable => {
63                let (client, connection) =
64                    tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
65                tokio::spawn(async move {
66                    if let Err(e) = connection.await {
67                        tracing::error!(error = %e, "Database connection error");
68                    }
69                });
70                Ok(client)
71            }
72            SslMode::Require => {
73                let tls_config = make_rustls_config();
74                let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
75                let (client, connection) = tokio_postgres::connect(conn_string, tls).await?;
76                tokio::spawn(async move {
77                    if let Err(e) = connection.await {
78                        tracing::error!(error = %e, "Database connection error");
79                    }
80                });
81                Ok(client)
82            }
83            SslMode::Prefer => {
84                // Try TLS first, fall back to plaintext
85                let tls_config = make_rustls_config();
86                let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
87                match tokio_postgres::connect(conn_string, tls).await {
88                    Ok((client, connection)) => {
89                        tokio::spawn(async move {
90                            if let Err(e) = connection.await {
91                                tracing::error!(error = %e, "Database connection error");
92                            }
93                        });
94                        Ok(client)
95                    }
96                    Err(_) => {
97                        tracing::debug!("TLS connection failed, falling back to plaintext");
98                        let (client, connection) =
99                            tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
100                        tokio::spawn(async move {
101                            if let Err(e) = connection.await {
102                                tracing::error!(error = %e, "Database connection error");
103                            }
104                        });
105                        Ok(client)
106                    }
107                }
108            }
109        }
110    };
111
112    if connect_timeout_secs > 0 {
113        match tokio::time::timeout(
114            std::time::Duration::from_secs(connect_timeout_secs as u64),
115            connect_fut,
116        )
117        .await
118        {
119            Ok(result) => result,
120            Err(_) => Err(tokio_postgres::Error::__private_api_timeout()),
121        }
122    } else {
123        connect_fut.await
124    }
125}
126
127/// Connect to the database using the provided connection string.
128///
129/// Spawns the connection task on the tokio runtime.
130pub async fn connect(conn_string: &str) -> Result<Client> {
131    connect_with_config(conn_string, &SslMode::Prefer, 0, 30, 0).await
132}
133
134/// Connect to the database, retrying up to `retries` times with exponential backoff + jitter.
135///
136/// Each retry waits `min(2^attempt, 30) + rand(0..1000ms)` before the next attempt.
137/// Permanent errors (authentication failures) are not retried.
138pub async fn connect_with_config(
139    conn_string: &str,
140    ssl_mode: &SslMode,
141    retries: u32,
142    connect_timeout_secs: u32,
143    statement_timeout_secs: u32,
144) -> Result<Client> {
145    let mut last_err = None;
146
147    for attempt in 0..=retries {
148        if attempt > 0 {
149            let base_delay = std::cmp::min(1u64 << attempt, 30);
150            let jitter_ms = rand::thread_rng().gen_range(0..1000);
151            let delay = std::time::Duration::from_secs(base_delay)
152                + std::time::Duration::from_millis(jitter_ms);
153            tracing::info!(
154                attempt = attempt + 1,
155                max_attempts = retries + 1,
156                delay_ms = delay.as_millis() as u64,
157                "Connection attempt failed, retrying"
158            );
159            tokio::time::sleep(delay).await;
160        }
161
162        match connect_once(conn_string, ssl_mode, connect_timeout_secs).await {
163            Ok(client) => {
164                if attempt > 0 {
165                    tracing::info!(
166                        attempt = attempt + 1,
167                        max_attempts = retries + 1,
168                        "Connected successfully after retry"
169                    );
170                }
171
172                // Set statement timeout if configured
173                if statement_timeout_secs > 0 {
174                    let timeout_sql =
175                        format!("SET statement_timeout = '{}s'", statement_timeout_secs);
176                    client.batch_execute(&timeout_sql).await?;
177                }
178
179                return Ok(client);
180            }
181            Err(e) => {
182                // Don't retry permanent errors (e.g. bad credentials)
183                if is_permanent_error(&e) {
184                    tracing::error!(error = %e, "Permanent connection error, not retrying");
185                    return Err(WaypointError::DatabaseError(e));
186                }
187                last_err = Some(e);
188            }
189        }
190    }
191
192    Err(WaypointError::DatabaseError(last_err.unwrap()))
193}
194
195/// Acquire a PostgreSQL advisory lock based on the history table name.
196///
197/// This prevents concurrent migration runs from interfering with each other.
198pub async fn acquire_advisory_lock(client: &Client, table_name: &str) -> Result<()> {
199    let lock_id = advisory_lock_id(table_name);
200    tracing::info!(lock_id = lock_id, table = %table_name, "Acquiring advisory lock");
201
202    client
203        .execute(&format!("SELECT pg_advisory_lock({})", lock_id), &[])
204        .await
205        .map_err(|e| WaypointError::LockError(format!("Failed to acquire advisory lock: {}", e)))?;
206
207    Ok(())
208}
209
210/// Release the PostgreSQL advisory lock.
211pub async fn release_advisory_lock(client: &Client, table_name: &str) -> Result<()> {
212    let lock_id = advisory_lock_id(table_name);
213    tracing::info!(lock_id = lock_id, table = %table_name, "Releasing advisory lock");
214
215    client
216        .execute(&format!("SELECT pg_advisory_unlock({})", lock_id), &[])
217        .await
218        .map_err(|e| WaypointError::LockError(format!("Failed to release advisory lock: {}", e)))?;
219
220    Ok(())
221}
222
223/// Compute a stable i64 lock ID from the table name using CRC32.
224///
225/// Uses CRC32 instead of DefaultHasher for cross-version stability —
226/// DefaultHasher is not guaranteed to produce the same output across
227/// Rust compiler versions.
228fn advisory_lock_id(table_name: &str) -> i64 {
229    crc32fast::hash(table_name.as_bytes()) as i64
230}
231
232/// Get the current database user.
233pub async fn get_current_user(client: &Client) -> Result<String> {
234    let row = client.query_one("SELECT current_user", &[]).await?;
235    Ok(row.get::<_, String>(0))
236}
237
238/// Get the current database name.
239pub async fn get_current_database(client: &Client) -> Result<String> {
240    let row = client.query_one("SELECT current_database()", &[]).await?;
241    Ok(row.get::<_, String>(0))
242}
243
244/// Execute a SQL string within a transaction using SQL-level BEGIN/COMMIT.
245/// Returns the execution time in milliseconds.
246pub async fn execute_in_transaction(client: &Client, sql: &str) -> Result<i32> {
247    let start = std::time::Instant::now();
248
249    client.batch_execute("BEGIN").await?;
250
251    match client.batch_execute(sql).await {
252        Ok(()) => {
253            client.batch_execute("COMMIT").await?;
254        }
255        Err(e) => {
256            if let Err(rollback_err) = client.batch_execute("ROLLBACK").await {
257                tracing::warn!(error = %rollback_err, "Failed to rollback transaction");
258            }
259            return Err(WaypointError::DatabaseError(e));
260        }
261    }
262
263    let elapsed = start.elapsed().as_millis() as i32;
264    Ok(elapsed)
265}
266
267/// Execute SQL without a transaction wrapper (for statements that can't run in a transaction).
268pub async fn execute_raw(client: &Client, sql: &str) -> Result<i32> {
269    let start = std::time::Instant::now();
270    client.batch_execute(sql).await?;
271    let elapsed = start.elapsed().as_millis() as i32;
272    Ok(elapsed)
273}