1use rand::Rng;
2use tokio_postgres::Client;
3
4use crate::config::SslMode;
5use crate::error::{Result, WaypointError};
6
/// Returns `name` wrapped in double quotes, in the form of a quoted PostgreSQL identifier.
///
/// Every embedded `"` is doubled (`a"b` becomes `"a""b"`), so the result always stays a
/// single quoted identifier regardless of its contents.
pub fn quote_ident(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            // A literal quote inside a quoted identifier is written twice.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
13
14pub fn validate_identifier(name: &str) -> Result<()> {
19 if name.is_empty() {
20 return Err(WaypointError::ConfigError(
21 "Identifier cannot be empty".to_string(),
22 ));
23 }
24 if !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
25 return Err(WaypointError::ConfigError(format!(
26 "Identifier '{}' contains invalid characters. Only [a-zA-Z0-9_] are allowed.",
27 name
28 )));
29 }
30 Ok(())
31}
32
33fn make_rustls_config() -> rustls::ClientConfig {
35 let root_store =
36 rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
37 rustls::ClientConfig::builder()
38 .with_root_certificates(root_store)
39 .with_no_client_auth()
40}
41
42fn is_permanent_error(e: &tokio_postgres::Error) -> bool {
44 if let Some(db_err) = e.as_db_error() {
45 let code = db_err.code().code();
46 return code == "28P01" || code == "28000";
48 }
49 false
50}
51
52async fn connect_once(
56 conn_string: &str,
57 ssl_mode: &SslMode,
58 connect_timeout_secs: u32,
59) -> std::result::Result<Client, tokio_postgres::Error> {
60 let connect_fut = async {
61 match ssl_mode {
62 SslMode::Disable => {
63 let (client, connection) =
64 tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
65 tokio::spawn(async move {
66 if let Err(e) = connection.await {
67 tracing::error!(error = %e, "Database connection error");
68 }
69 });
70 Ok(client)
71 }
72 SslMode::Require => {
73 let tls_config = make_rustls_config();
74 let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
75 let (client, connection) = tokio_postgres::connect(conn_string, tls).await?;
76 tokio::spawn(async move {
77 if let Err(e) = connection.await {
78 tracing::error!(error = %e, "Database connection error");
79 }
80 });
81 Ok(client)
82 }
83 SslMode::Prefer => {
84 let tls_config = make_rustls_config();
86 let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
87 match tokio_postgres::connect(conn_string, tls).await {
88 Ok((client, connection)) => {
89 tokio::spawn(async move {
90 if let Err(e) = connection.await {
91 tracing::error!(error = %e, "Database connection error");
92 }
93 });
94 Ok(client)
95 }
96 Err(_) => {
97 tracing::debug!("TLS connection failed, falling back to plaintext");
98 let (client, connection) =
99 tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
100 tokio::spawn(async move {
101 if let Err(e) = connection.await {
102 tracing::error!(error = %e, "Database connection error");
103 }
104 });
105 Ok(client)
106 }
107 }
108 }
109 }
110 };
111
112 if connect_timeout_secs > 0 {
113 match tokio::time::timeout(
114 std::time::Duration::from_secs(connect_timeout_secs as u64),
115 connect_fut,
116 )
117 .await
118 {
119 Ok(result) => result,
120 Err(_) => Err(tokio_postgres::Error::__private_api_timeout()),
121 }
122 } else {
123 connect_fut.await
124 }
125}
126
127pub async fn connect(conn_string: &str) -> Result<Client> {
131 connect_with_config(conn_string, &SslMode::Prefer, 0, 30, 0).await
132}
133
134pub async fn connect_with_config(
139 conn_string: &str,
140 ssl_mode: &SslMode,
141 retries: u32,
142 connect_timeout_secs: u32,
143 statement_timeout_secs: u32,
144) -> Result<Client> {
145 let mut last_err = None;
146
147 for attempt in 0..=retries {
148 if attempt > 0 {
149 let base_delay = std::cmp::min(1u64 << attempt, 30);
150 let jitter_ms = rand::thread_rng().gen_range(0..1000);
151 let delay = std::time::Duration::from_secs(base_delay)
152 + std::time::Duration::from_millis(jitter_ms);
153 tracing::info!(
154 attempt = attempt + 1,
155 max_attempts = retries + 1,
156 delay_ms = delay.as_millis() as u64,
157 "Connection attempt failed, retrying"
158 );
159 tokio::time::sleep(delay).await;
160 }
161
162 match connect_once(conn_string, ssl_mode, connect_timeout_secs).await {
163 Ok(client) => {
164 if attempt > 0 {
165 tracing::info!(
166 attempt = attempt + 1,
167 max_attempts = retries + 1,
168 "Connected successfully after retry"
169 );
170 }
171
172 if statement_timeout_secs > 0 {
174 let timeout_sql =
175 format!("SET statement_timeout = '{}s'", statement_timeout_secs);
176 client.batch_execute(&timeout_sql).await?;
177 }
178
179 return Ok(client);
180 }
181 Err(e) => {
182 if is_permanent_error(&e) {
184 tracing::error!(error = %e, "Permanent connection error, not retrying");
185 return Err(WaypointError::DatabaseError(e));
186 }
187 last_err = Some(e);
188 }
189 }
190 }
191
192 Err(WaypointError::DatabaseError(last_err.unwrap()))
193}
194
195pub async fn acquire_advisory_lock(client: &Client, table_name: &str) -> Result<()> {
199 let lock_id = advisory_lock_id(table_name);
200 tracing::info!(lock_id = lock_id, table = %table_name, "Acquiring advisory lock");
201
202 client
203 .execute(&format!("SELECT pg_advisory_lock({})", lock_id), &[])
204 .await
205 .map_err(|e| WaypointError::LockError(format!("Failed to acquire advisory lock: {}", e)))?;
206
207 Ok(())
208}
209
210pub async fn release_advisory_lock(client: &Client, table_name: &str) -> Result<()> {
212 let lock_id = advisory_lock_id(table_name);
213 tracing::info!(lock_id = lock_id, table = %table_name, "Releasing advisory lock");
214
215 client
216 .execute(&format!("SELECT pg_advisory_unlock({})", lock_id), &[])
217 .await
218 .map_err(|e| WaypointError::LockError(format!("Failed to release advisory lock: {}", e)))?;
219
220 Ok(())
221}
222
223fn advisory_lock_id(table_name: &str) -> i64 {
229 crc32fast::hash(table_name.as_bytes()) as i64
230}
231
232pub async fn get_current_user(client: &Client) -> Result<String> {
234 let row = client.query_one("SELECT current_user", &[]).await?;
235 Ok(row.get::<_, String>(0))
236}
237
238pub async fn get_current_database(client: &Client) -> Result<String> {
240 let row = client.query_one("SELECT current_database()", &[]).await?;
241 Ok(row.get::<_, String>(0))
242}
243
244pub async fn execute_in_transaction(client: &Client, sql: &str) -> Result<i32> {
247 let start = std::time::Instant::now();
248
249 client.batch_execute("BEGIN").await?;
250
251 match client.batch_execute(sql).await {
252 Ok(()) => {
253 client.batch_execute("COMMIT").await?;
254 }
255 Err(e) => {
256 if let Err(rollback_err) = client.batch_execute("ROLLBACK").await {
257 tracing::warn!(error = %rollback_err, "Failed to rollback transaction");
258 }
259 return Err(WaypointError::DatabaseError(e));
260 }
261 }
262
263 let elapsed = start.elapsed().as_millis() as i32;
264 Ok(elapsed)
265}
266
267pub async fn execute_raw(client: &Client, sql: &str) -> Result<i32> {
269 let start = std::time::Instant::now();
270 client.batch_execute(sql).await?;
271 let elapsed = start.elapsed().as_millis() as i32;
272 Ok(elapsed)
273}