1use fastrand;
4use tokio_postgres::Client;
5
6use crate::config::SslMode;
7use crate::error::{Result, WaypointError};
8
/// Wraps `name` in double quotes for use as a quoted SQL identifier.
///
/// Every `"` inside `name` is doubled so that it cannot terminate the quoted
/// identifier early. An empty `name` yields `""` (two quote characters).
pub fn quote_ident(name: &str) -> String {
    // Two extra bytes for the surrounding quotes; embedded quotes may grow it further.
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
15
16pub fn validate_identifier(name: &str) -> Result<()> {
21 if name.is_empty() {
22 return Err(WaypointError::ConfigError(
23 "Identifier cannot be empty".to_string(),
24 ));
25 }
26 if !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
27 return Err(WaypointError::ConfigError(format!(
28 "Identifier '{}' contains invalid characters. Only [a-zA-Z0-9_] are allowed.",
29 name
30 )));
31 }
32 Ok(())
33}
34
35fn make_rustls_config() -> rustls::ClientConfig {
37 let root_store =
38 rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
39 rustls::ClientConfig::builder_with_provider(std::sync::Arc::new(
40 rustls::crypto::ring::default_provider(),
41 ))
42 .with_safe_default_protocol_versions()
43 .unwrap()
44 .with_root_certificates(root_store)
45 .with_no_client_auth()
46}
47
48fn is_permanent_error(e: &tokio_postgres::Error) -> bool {
50 if let Some(db_err) = e.as_db_error() {
51 let code = db_err.code().code();
52 return code == "28P01" || code == "28000";
54 }
55 false
56}
57
/// Appends TCP keepalive settings to a PostgreSQL connection string.
///
/// Both URL-style (`postgres://` / `postgresql://`) and key/value-style
/// connection strings are supported.
///
/// The input is returned unchanged when:
/// * `keepalive_secs` is `0` (keepalive injection disabled), or
/// * the string already mentions `keepalives` anywhere (case-insensitive), so
///   explicit user settings are never overridden.
///
/// Otherwise `keepalives=1` and `keepalives_idle=<keepalive_secs>` are appended
/// using the separator appropriate for the connection-string style.
pub fn inject_keepalive(conn_string: &str, keepalive_secs: u32) -> String {
    if keepalive_secs == 0 {
        return conn_string.to_string();
    }
    if conn_string.to_lowercase().contains("keepalives") {
        return conn_string.to_string();
    }

    let is_url =
        conn_string.starts_with("postgres://") || conn_string.starts_with("postgresql://");
    if !is_url {
        return format!(
            "{} keepalives=1 keepalives_idle={}",
            conn_string, keepalive_secs
        );
    }

    // Pick the separator so that no empty parameter is produced: a URL that
    // already ends in `?` or `&` must not become `?&keepalives=…` or
    // `&&keepalives=…`, which a strict query-string parser may reject.
    let separator = if !conn_string.contains('?') {
        "?"
    } else if conn_string.ends_with('?') || conn_string.ends_with('&') {
        ""
    } else {
        "&"
    };

    format!(
        "{}{}keepalives=1&keepalives_idle={}",
        conn_string, separator, keepalive_secs
    )
}
86
87fn spawn_connection_task<F>(connection: F)
93where
94 F: std::future::Future<Output = std::result::Result<(), tokio_postgres::Error>>
95 + Send
96 + 'static,
97{
98 tokio::spawn(async move {
99 if let Err(e) = connection.await {
100 log::error!("Database connection error: {}", e);
101 }
102 });
103}
104
105async fn connect_once(
109 conn_string: &str,
110 ssl_mode: &SslMode,
111 connect_timeout_secs: u32,
112) -> std::result::Result<Client, tokio_postgres::Error> {
113 let connect_fut = async {
114 match ssl_mode {
115 SslMode::Disable => {
116 let (client, connection) =
117 tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
118 spawn_connection_task(connection);
119 Ok(client)
120 }
121 SslMode::Require => {
122 let tls_config = make_rustls_config();
123 let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
124 let (client, connection) = tokio_postgres::connect(conn_string, tls).await?;
125 spawn_connection_task(connection);
126 Ok(client)
127 }
128 SslMode::Prefer => {
129 let tls_config = make_rustls_config();
131 let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
132 match tokio_postgres::connect(conn_string, tls).await {
133 Ok((client, connection)) => {
134 spawn_connection_task(connection);
135 Ok(client)
136 }
137 Err(_) => {
138 log::debug!("TLS connection failed, falling back to plaintext");
139 let (client, connection) =
140 tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
141 spawn_connection_task(connection);
142 Ok(client)
143 }
144 }
145 }
146 }
147 };
148
149 if connect_timeout_secs > 0 {
150 match tokio::time::timeout(
151 std::time::Duration::from_secs(connect_timeout_secs as u64),
152 connect_fut,
153 )
154 .await
155 {
156 Ok(result) => result,
157 Err(_) => Err(tokio_postgres::Error::__private_api_timeout()),
158 }
159 } else {
160 connect_fut.await
161 }
162}
163
164pub async fn connect(conn_string: &str) -> Result<Client> {
168 connect_with_config(conn_string, &SslMode::Prefer, 0, 30, 0).await
169}
170
171pub async fn connect_with_config(
176 conn_string: &str,
177 ssl_mode: &SslMode,
178 retries: u32,
179 connect_timeout_secs: u32,
180 statement_timeout_secs: u32,
181) -> Result<Client> {
182 connect_with_full_config(
183 conn_string,
184 ssl_mode,
185 retries,
186 connect_timeout_secs,
187 statement_timeout_secs,
188 120,
189 )
190 .await
191}
192
193pub async fn connect_with_full_config(
195 conn_string: &str,
196 ssl_mode: &SslMode,
197 retries: u32,
198 connect_timeout_secs: u32,
199 statement_timeout_secs: u32,
200 keepalive_secs: u32,
201) -> Result<Client> {
202 let conn_string = inject_keepalive(conn_string, keepalive_secs);
203 let mut last_err = None;
204
205 for attempt in 0..=retries {
206 if attempt > 0 {
207 let base_delay = std::cmp::min(1u64 << attempt, 30);
208 let jitter_ms = fastrand::u64(0..1000);
209 let delay = std::time::Duration::from_secs(base_delay)
210 + std::time::Duration::from_millis(jitter_ms);
211 log::info!(
212 "Connection attempt failed, retrying; attempt={}, max_attempts={}, delay_ms={}",
213 attempt + 1,
214 retries + 1,
215 delay.as_millis() as u64
216 );
217 tokio::time::sleep(delay).await;
218 }
219
220 match connect_once(&conn_string, ssl_mode, connect_timeout_secs).await {
221 Ok(client) => {
222 if attempt > 0 {
223 log::info!(
224 "Connected successfully after retry; attempt={}, max_attempts={}",
225 attempt + 1,
226 retries + 1
227 );
228 }
229
230 if statement_timeout_secs > 0 {
232 let timeout_sql =
233 format!("SET statement_timeout = '{}s'", statement_timeout_secs);
234 client.batch_execute(&timeout_sql).await?;
235 }
236
237 return Ok(client);
238 }
239 Err(e) => {
240 if is_permanent_error(&e) {
242 log::error!("Permanent connection error, not retrying: {}", e);
243 return Err(WaypointError::DatabaseError(e));
244 }
245 last_err = Some(e);
246 }
247 }
248 }
249
250 Err(WaypointError::DatabaseError(last_err.unwrap()))
251}
252
253pub async fn acquire_advisory_lock(client: &Client, table_name: &str) -> Result<()> {
257 let lock_id = advisory_lock_id(table_name);
258 log::info!(
259 "Acquiring advisory lock; lock_id={}, table={}",
260 lock_id,
261 table_name
262 );
263
264 client
265 .execute("SELECT pg_advisory_lock($1)", &[&lock_id])
266 .await
267 .map_err(|e| WaypointError::LockError(format!("Failed to acquire advisory lock: {}", e)))?;
268
269 Ok(())
270}
271
272pub async fn acquire_advisory_lock_with_timeout(
277 client: &Client,
278 table_name: &str,
279 timeout_secs: u32,
280) -> Result<()> {
281 let lock_id = advisory_lock_id(table_name);
282 log::info!(
283 "Trying to acquire advisory lock with timeout; lock_id={}, table={}, timeout_secs={}",
284 lock_id,
285 table_name,
286 timeout_secs
287 );
288
289 let deadline = std::time::Instant::now() + std::time::Duration::from_secs(timeout_secs as u64);
290
291 loop {
292 let row = client
293 .query_one("SELECT pg_try_advisory_lock($1)", &[&lock_id])
294 .await
295 .map_err(|e| WaypointError::LockError(format!("Failed to try advisory lock: {}", e)))?;
296
297 let acquired: bool = row.get(0);
298 if acquired {
299 return Ok(());
300 }
301
302 if std::time::Instant::now() >= deadline {
303 return Err(WaypointError::LockError(format!(
304 "Timed out waiting for advisory lock after {}s (table: {}). Another migration may be running.",
305 timeout_secs, table_name
306 )));
307 }
308
309 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
311 }
312}
313
314pub async fn release_advisory_lock(client: &Client, table_name: &str) -> Result<()> {
316 let lock_id = advisory_lock_id(table_name);
317 log::info!(
318 "Releasing advisory lock; lock_id={}, table={}",
319 lock_id,
320 table_name
321 );
322
323 client
324 .execute("SELECT pg_advisory_unlock($1)", &[&lock_id])
325 .await
326 .map_err(|e| WaypointError::LockError(format!("Failed to release advisory lock: {}", e)))?;
327
328 Ok(())
329}
330
331fn advisory_lock_id(table_name: &str) -> i64 {
337 crc32fast::hash(table_name.as_bytes()) as i64
338}
339
340pub async fn get_current_user(client: &Client) -> Result<String> {
342 let row = client.query_one("SELECT current_user", &[]).await?;
343 Ok(row.get::<_, String>(0))
344}
345
346pub async fn get_current_database(client: &Client) -> Result<String> {
348 let row = client.query_one("SELECT current_database()", &[]).await?;
349 Ok(row.get::<_, String>(0))
350}
351
352pub async fn execute_in_transaction(client: &Client, sql: &str) -> Result<i32> {
355 let start = std::time::Instant::now();
356
357 client.batch_execute("BEGIN").await?;
358
359 match client.batch_execute(sql).await {
360 Ok(()) => {
361 client.batch_execute("COMMIT").await?;
362 }
363 Err(e) => {
364 if let Err(rollback_err) = client.batch_execute("ROLLBACK").await {
365 log::warn!("Failed to rollback transaction: {}", rollback_err);
366 }
367 return Err(WaypointError::DatabaseError(e));
368 }
369 }
370
371 let elapsed = start.elapsed().as_millis() as i32;
372 Ok(elapsed)
373}
374
375pub async fn execute_raw(client: &Client, sql: &str) -> Result<i32> {
377 let start = std::time::Instant::now();
378 client.batch_execute(sql).await?;
379 let elapsed = start.elapsed().as_millis() as i32;
380 Ok(elapsed)
381}
382
383pub fn is_transient_error(e: &WaypointError) -> bool {
388 match e {
389 WaypointError::DatabaseError(pg_err) => {
390 if pg_err.is_closed() {
392 return true;
393 }
394 if let Some(db_err) = pg_err.as_db_error() {
396 let code = db_err.code().code();
397 return matches!(
401 code,
402 "57P01" | "57P02" | "57P03" | "08000" | "08003" | "08006"
403 );
404 }
405 let msg = pg_err.to_string().to_lowercase();
407 msg.contains("connection reset")
408 || msg.contains("broken pipe")
409 || msg.contains("connection closed")
410 || msg.contains("unexpected eof")
411 }
412 WaypointError::ConnectionLost { .. } => true,
413 _ => false,
414 }
415}
416
417pub async fn check_connection(client: &Client) -> Result<()> {
419 client
420 .simple_query("")
421 .await
422 .map_err(|e| WaypointError::ConnectionLost {
423 operation: "health check".to_string(),
424 detail: e.to_string(),
425 })?;
426 Ok(())
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    // ---- inject_keepalive -------------------------------------------------

    /// A URL without a query string gets one started with `?`.
    #[test]
    fn test_inject_keepalive_url_style() {
        let expected = "postgres://user:pass@localhost/db?keepalives=1&keepalives_idle=120";
        assert_eq!(
            inject_keepalive("postgres://user:pass@localhost/db", 120),
            expected
        );
    }

    /// A URL that already has parameters gets the new ones appended with `&`.
    #[test]
    fn test_inject_keepalive_url_with_existing_params() {
        let expected =
            "postgres://user:pass@localhost/db?sslmode=require&keepalives=1&keepalives_idle=60";
        assert_eq!(
            inject_keepalive("postgres://user:pass@localhost/db?sslmode=require", 60),
            expected
        );
    }

    /// Key/value connection strings get space-separated settings.
    #[test]
    fn test_inject_keepalive_kv_style() {
        let expected =
            "host=localhost port=5432 user=admin dbname=mydb keepalives=1 keepalives_idle=90";
        assert_eq!(
            inject_keepalive("host=localhost port=5432 user=admin dbname=mydb", 90),
            expected
        );
    }

    /// A keepalive of zero leaves the connection string untouched.
    #[test]
    fn test_inject_keepalive_zero_disables() {
        let input = "postgres://user:pass@localhost/db";
        assert_eq!(inject_keepalive(input, 0), "postgres://user:pass@localhost/db");
    }

    /// Existing keepalive settings are never overridden.
    #[test]
    fn test_inject_keepalive_already_present() {
        let input = "postgres://user:pass@localhost/db?keepalives=1";
        assert_eq!(
            inject_keepalive(input, 120),
            "postgres://user:pass@localhost/db?keepalives=1"
        );
    }

    /// The `postgresql://` scheme is treated like `postgres://`.
    #[test]
    fn test_inject_keepalive_postgresql_prefix() {
        let expected = "postgresql://user:pass@localhost/db?keepalives=1&keepalives_idle=120";
        assert_eq!(
            inject_keepalive("postgresql://user:pass@localhost/db", 120),
            expected
        );
    }

    // ---- is_transient_error -----------------------------------------------

    #[test]
    fn test_transient_error_connection_lost() {
        let lost = WaypointError::ConnectionLost {
            operation: "test".to_string(),
            detail: "gone".to_string(),
        };
        assert!(is_transient_error(&lost));
    }

    #[test]
    fn test_transient_error_config_is_not_transient() {
        let config_err = WaypointError::ConfigError("bad config".to_string());
        assert!(!is_transient_error(&config_err));
    }

    #[test]
    fn test_transient_error_migration_failed_is_not_transient() {
        let failed = WaypointError::MigrationFailed {
            script: "V1__test.sql".to_string(),
            reason: "syntax error".to_string(),
        };
        assert!(!is_transient_error(&failed));
    }

    #[test]
    fn test_transient_error_lock_error_is_not_transient() {
        let lock_err = WaypointError::LockError("lock failed".to_string());
        assert!(!is_transient_error(&lock_err));
    }

    #[test]
    fn test_transient_error_io_error_is_not_transient() {
        let io = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        assert!(!is_transient_error(&WaypointError::IoError(io)));
    }

    // ---- advisory_lock_id -------------------------------------------------

    /// The same name maps to the same id; a different name maps elsewhere.
    #[test]
    fn test_advisory_lock_id_stability() {
        let first = advisory_lock_id("waypoint_schema_history");
        let second = advisory_lock_id("waypoint_schema_history");
        assert_eq!(first, second);

        let other = advisory_lock_id("other_table");
        assert_ne!(first, other);
    }

    // ---- validate_identifier ----------------------------------------------

    #[test]
    fn test_validate_identifier_valid() {
        for &name in &["users", "my_table", "Table123", "a"] {
            assert!(validate_identifier(name).is_ok());
        }
    }

    #[test]
    fn test_validate_identifier_invalid() {
        for &name in &["", "my-table", "my table", "table.name", "table;drop"] {
            assert!(validate_identifier(name).is_err());
        }
    }

    // ---- quote_ident ------------------------------------------------------

    #[test]
    fn test_quote_ident_simple() {
        let quoted = quote_ident("users");
        assert_eq!(quoted, "\"users\"");
    }

    #[test]
    fn test_quote_ident_embedded_quotes() {
        let quoted = quote_ident("my\"table");
        assert_eq!(quoted, "\"my\"\"table\"");
    }

    #[test]
    fn test_quote_ident_empty() {
        let quoted = quote_ident("");
        assert_eq!(quoted, "\"\"");
    }
}