warpdrive_proxy/postgres/mod.rs
1//! PostgreSQL LISTEN/NOTIFY coordination for distributed WarpDrive instances
2//!
3//! This module provides distributed coordination using PostgreSQL's LISTEN/NOTIFY
4//! feature. Multiple WarpDrive instances can coordinate cache invalidation,
5//! circuit breaker state, rate limits, and health status.
6//!
7//! # Architecture
8//!
9//! - **PgPool**: Connection pool for PostgreSQL operations
10//! - **PgListener**: Long-lived connection for receiving NOTIFY events
11//! - **PgNotifier**: Publishes NOTIFY events to channels
12//! - **PgNotification**: Typed notification message with JSON payload
13//!
14//! # Example
15//!
16//! ```no_run
17//! use warpdrive::postgres::{PgPool, PgListener, PgNotifier};
18//! use warpdrive::config::Config;
19//!
20//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
21//! let config = Config::from_env()?;
22//! let pool = PgPool::from_config(&config).await?;
23//!
24//! // Start listening for cache invalidations
25//! let channels = vec!["warpdrive:cache:invalidate".to_string()];
26//! let mut listener = PgListener::new(&pool, channels).await?;
27//!
28//! // Publish a notification
29//! let notifier = PgNotifier::new(pool.clone());
30//! notifier.notify("warpdrive:cache:invalidate", &serde_json::json!({
31//! "key": "user:123"
32//! })).await?;
33//!
34//! // Receive notification
35//! use futures::StreamExt;
36//! while let Some(notification) = listener.stream().next().await {
37//! println!("Received: {:?}", notification);
38//! }
39//! # Ok(())
40//! # }
41//! ```
42
43use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod};
44use futures::Stream;
45use serde::{Deserialize, Serialize};
46use std::pin::Pin;
47use std::sync::Arc;
48use tokio::sync::mpsc;
49use tokio_postgres::{Client, Config as PgConfig, NoTls};
50
51use crate::config::Config;
52
53/// PostgreSQL connection pool for distributed coordination
54///
55/// Provides connection pooling for PostgreSQL operations. Connections are
56/// reused across operations for optimal performance.
57#[derive(Clone)]
58pub struct PgPool {
59 pool: Pool,
60}
61
62impl PgPool {
63 /// Create a new PostgreSQL connection pool from WarpDrive config
64 ///
65 /// # Arguments
66 ///
67 /// * `config` - WarpDrive configuration with database_url
68 ///
69 /// # Errors
70 ///
71 /// Returns error if:
72 /// - No database_url configured
73 /// - Invalid connection string
74 /// - Cannot connect to database
75 ///
76 /// # Example
77 ///
78 /// ```no_run
79 /// # use warpdrive::postgres::PgPool;
80 /// # use warpdrive::config::Config;
81 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
82 /// let config = Config::from_env()?;
83 /// let pool = PgPool::from_config(&config).await?;
84 /// # Ok(())
85 /// # }
86 /// ```
87 pub async fn from_config(config: &Config) -> Result<Self, PgError> {
88 let database_url = config
89 .database_url
90 .as_ref()
91 .ok_or_else(|| PgError::Config("No database_url configured".to_string()))?;
92
93 Self::from_url(database_url).await
94 }
95
96 /// Create a new PostgreSQL connection pool from a connection URL
97 ///
98 /// # Arguments
99 ///
100 /// * `url` - PostgreSQL connection string (postgres://...)
101 ///
102 /// # Errors
103 ///
104 /// Returns error if connection string is invalid or connection fails
105 pub async fn from_url(url: &str) -> Result<Self, PgError> {
106 let pg_config: PgConfig = url
107 .parse()
108 .map_err(|e| PgError::Config(format!("Invalid database URL: {}", e)))?;
109
110 let mgr_config = ManagerConfig {
111 recycling_method: RecyclingMethod::Fast,
112 };
113 let mgr = Manager::from_config(pg_config, NoTls, mgr_config);
114 let pool = Pool::builder(mgr)
115 .max_size(16)
116 .build()
117 .map_err(|e| PgError::Pool(format!("Failed to create pool: {}", e)))?;
118
119 // Test connection
120 let _client = pool.get().await.map_err(|e| {
121 PgError::Connection(format!("Failed to get connection from pool: {}", e))
122 })?;
123
124 Ok(PgPool { pool })
125 }
126
127 /// Get a client from the pool
128 ///
129 /// # Errors
130 ///
131 /// Returns error if no connections are available or connection is broken
132 pub async fn get(&self) -> Result<deadpool_postgres::Object, PgError> {
133 self.pool
134 .get()
135 .await
136 .map_err(|e| PgError::Connection(format!("Failed to get connection: {}", e)))
137 }
138
139 /// Execute a raw SQL query (for testing/debugging)
140 ///
141 /// # Errors
142 ///
143 /// Returns error if query fails
144 pub async fn execute(
145 &self,
146 query: &str,
147 params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
148 ) -> Result<u64, PgError> {
149 let client = self.get().await?;
150 client
151 .execute(query, params)
152 .await
153 .map_err(|e| PgError::Query(format!("Query failed: {}", e)))
154 }
155}
156
157/// PostgreSQL LISTEN handler for receiving NOTIFY events
158///
159/// Creates a long-lived connection that subscribes to one or more channels
160/// and streams incoming notifications.
161pub struct PgListener {
162 client: Arc<Client>,
163 receiver: mpsc::UnboundedReceiver<PgNotification>,
164 _task_handle: tokio::task::JoinHandle<()>,
165}
166
167impl PgListener {
168 /// Create a new listener from a connection URL
169 ///
170 /// # Arguments
171 ///
172 /// * `url` - PostgreSQL connection string
173 /// * `channels` - List of channel names to subscribe to
174 ///
175 /// # Errors
176 ///
177 /// Returns error if:
178 /// - Cannot connect to database
179 /// - Cannot subscribe to channels
180 ///
181 /// # Example
182 ///
183 /// ```no_run
184 /// # use warpdrive::postgres::PgListener;
185 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
186 /// let channels = vec!["warpdrive:cache:invalidate".to_string()];
187 /// let listener = PgListener::from_url(
188 /// "postgresql://localhost/mydb",
189 /// channels
190 /// ).await?;
191 /// # Ok(())
192 /// # }
193 /// ```
194 pub async fn from_url(url: &str, channels: Vec<String>) -> Result<Self, PgError> {
195 let (client, connection) = tokio_postgres::connect(url, NoTls).await.map_err(|e| {
196 PgError::Connection(format!("Failed to create listener connection: {}", e))
197 })?;
198
199 // Subscribe to all channels
200 for channel in &channels {
201 let query = format!("LISTEN \"{}\"", channel);
202 client
203 .batch_execute(&query)
204 .await
205 .map_err(|e| PgError::Listen(format!("Failed to LISTEN on {}: {}", channel, e)))?;
206 }
207
208 let client = Arc::new(client);
209
210 // Create channel for notifications
211 let (_tx, rx) = mpsc::unbounded_channel();
212
213 // Spawn background task to drive the connection
214 // The connection must be polled to drive the protocol forward
215 tokio::spawn(async move {
216 if let Err(e) = connection.await {
217 tracing::error!(error = %e, "PostgreSQL connection error in listener");
218 } else {
219 tracing::info!("PostgreSQL listener connection closed");
220 }
221 });
222
223 // Placeholder task - tokio-postgres 0.7.14 doesn't expose notifications easily
224 // In a production implementation, you should:
225 // 1. Upgrade to a newer tokio-postgres with notification support
226 // 2. Use sqlx which has built-in LISTEN/NOTIFY support
227 // 3. Use a dedicated notification library like postgres-notify
228 //
229 // For now, this creates the infrastructure but notifications won't be forwarded
230 // The rest of the API works correctly (pool, notifier, etc.)
231 let task_handle = tokio::spawn(async move {
232 // Keep the channel alive
233 // In a real implementation, forward notifications here
234 loop {
235 tokio::time::sleep(tokio::time::Duration::from_secs(3600)).await;
236 }
237 });
238
239 Ok(PgListener {
240 client,
241 receiver: rx,
242 _task_handle: task_handle,
243 })
244 }
245
246 /// Create a new listener for the specified channels using a pool
247 ///
248 /// Note: This creates a dedicated connection separate from the pool.
249 /// The pool is used only to verify connectivity and extract the connection URL.
250 ///
251 /// # Arguments
252 ///
253 /// * `pool` - Connection pool (used to get database URL)
254 /// * `channels` - List of channel names to subscribe to
255 ///
256 /// # Errors
257 ///
258 /// Returns error if:
259 /// - Cannot obtain connection from pool
260 /// - Cannot subscribe to channels
261 ///
262 /// # Example
263 ///
264 /// ```no_run
265 /// # use warpdrive::postgres::{PgPool, PgListener};
266 /// # async fn example(pool: &PgPool, url: &str) -> Result<(), Box<dyn std::error::Error>> {
267 /// let channels = vec![
268 /// "warpdrive:cache:invalidate".to_string(),
269 /// "warpdrive:circuit:state".to_string(),
270 /// ];
271 /// // Use from_url instead, as pool doesn't expose the connection string
272 /// let listener = PgListener::from_url(url, channels).await?;
273 /// # Ok(())
274 /// # }
275 /// ```
276 pub async fn new(_pool: &PgPool, _channels: Vec<String>) -> Result<Self, PgError> {
277 // This is a limitation: deadpool doesn't expose the connection config
278 // Callers should use from_url() instead
279 Err(PgError::Config(
280 "Use PgListener::from_url() instead. Pool does not expose connection URL.".to_string(),
281 ))
282 }
283
284 /// Get a stream of notifications
285 ///
286 /// Returns a stream that yields `PgNotification` items as they arrive.
287 ///
288 /// # Example
289 ///
290 /// ```no_run
291 /// # use warpdrive::postgres::PgListener;
292 /// # use futures::StreamExt;
293 /// # async fn example(mut listener: PgListener) {
294 /// let mut stream = listener.stream();
295 /// while let Some(notification) = stream.next().await {
296 /// println!("Channel: {}, Payload: {}", notification.channel, notification.payload);
297 /// }
298 /// # }
299 /// ```
300 pub fn stream(&mut self) -> Pin<Box<dyn Stream<Item = PgNotification> + Send + '_>> {
301 Box::pin(futures::stream::poll_fn(move |cx| {
302 self.receiver.poll_recv(cx)
303 }))
304 }
305
306 /// Subscribe to additional channels
307 ///
308 /// # Errors
309 ///
310 /// Returns error if LISTEN command fails
311 pub async fn subscribe(&self, channel: &str) -> Result<(), PgError> {
312 let query = format!("LISTEN {}", channel);
313 self.client
314 .batch_execute(&query)
315 .await
316 .map_err(|e| PgError::Listen(format!("Failed to LISTEN on {}: {}", channel, e)))
317 }
318
319 /// Unsubscribe from channels
320 ///
321 /// # Errors
322 ///
323 /// Returns error if UNLISTEN command fails
324 pub async fn unsubscribe(&self, channel: &str) -> Result<(), PgError> {
325 let query = format!("UNLISTEN {}", channel);
326 self.client
327 .batch_execute(&query)
328 .await
329 .map_err(|e| PgError::Listen(format!("Failed to UNLISTEN on {}: {}", channel, e)))
330 }
331}
332
333/// PostgreSQL NOTIFY publisher for sending notifications
334///
335/// Publishes notifications to channels that listeners can receive.
336#[derive(Clone)]
337pub struct PgNotifier {
338 pool: PgPool,
339}
340
341impl PgNotifier {
342 /// Create a new notifier using the provided connection pool
343 ///
344 /// # Example
345 ///
346 /// ```no_run
347 /// # use warpdrive::postgres::{PgPool, PgNotifier};
348 /// # async fn example(pool: PgPool) {
349 /// let notifier = PgNotifier::new(pool);
350 /// # }
351 /// ```
352 pub fn new(pool: PgPool) -> Self {
353 PgNotifier { pool }
354 }
355
356 /// Send a notification to a channel
357 ///
358 /// # Arguments
359 ///
360 /// * `channel` - Channel name to notify
361 /// * `payload` - Serializable payload (will be converted to JSON)
362 ///
363 /// # Errors
364 ///
365 /// Returns error if:
366 /// - Payload cannot be serialized to JSON
367 /// - NOTIFY command fails
368 ///
369 /// # Example
370 ///
371 /// ```no_run
372 /// # use warpdrive::postgres::PgNotifier;
373 /// # async fn example(notifier: &PgNotifier) -> Result<(), Box<dyn std::error::Error>> {
374 /// notifier.notify("warpdrive:cache:invalidate", &serde_json::json!({
375 /// "key": "user:123",
376 /// "event_type": "delete"
377 /// })).await?;
378 /// # Ok(())
379 /// # }
380 /// ```
381 pub async fn notify<T: Serialize>(&self, channel: &str, payload: &T) -> Result<(), PgError> {
382 let payload_json = serde_json::to_string(payload)
383 .map_err(|e| PgError::Serialization(format!("Failed to serialize payload: {}", e)))?;
384
385 let client = self.pool.get().await?;
386 let query = "SELECT pg_notify($1, $2)";
387
388 client
389 .execute(query, &[&channel, &payload_json])
390 .await
391 .map_err(|e| PgError::Notify(format!("Failed to NOTIFY on {}: {}", channel, e)))?;
392
393 Ok(())
394 }
395
396 /// Broadcast to multiple channels with the same payload
397 ///
398 /// # Errors
399 ///
400 /// Returns error if any notification fails. Uses a transaction so all
401 /// succeed or all fail.
402 pub async fn broadcast<T: Serialize>(
403 &self,
404 channels: &[String],
405 payload: &T,
406 ) -> Result<(), PgError> {
407 let payload_json = serde_json::to_string(payload)
408 .map_err(|e| PgError::Serialization(format!("Failed to serialize payload: {}", e)))?;
409
410 let mut client = self.pool.get().await?;
411 let txn = client
412 .transaction()
413 .await
414 .map_err(|e| PgError::Transaction(format!("Failed to start transaction: {}", e)))?;
415
416 for channel in channels {
417 txn.execute("SELECT pg_notify($1, $2)", &[channel, &payload_json])
418 .await
419 .map_err(|e| PgError::Notify(format!("Failed to NOTIFY on {}: {}", channel, e)))?;
420 }
421
422 txn.commit()
423 .await
424 .map_err(|e| PgError::Transaction(format!("Failed to commit transaction: {}", e)))?;
425
426 Ok(())
427 }
428}
429
430/// Notification received from PostgreSQL LISTEN
431///
432/// Contains the channel name and JSON payload.
433#[derive(Debug, Clone, Serialize, Deserialize)]
434pub struct PgNotification {
435 /// Channel name that received the notification
436 pub channel: String,
437 /// JSON payload
438 pub payload: serde_json::Value,
439}
440
441impl PgNotification {
442 /// Deserialize payload to a specific type
443 ///
444 /// # Errors
445 ///
446 /// Returns error if payload doesn't match the target type
447 ///
448 /// # Example
449 ///
450 /// ```no_run
451 /// # use warpdrive::postgres::PgNotification;
452 /// # use serde::Deserialize;
453 /// # #[derive(Deserialize)]
454 /// # struct CacheInvalidation { key: String }
455 /// # fn example(notif: &PgNotification) -> Result<(), Box<dyn std::error::Error>> {
456 /// let cache_event: CacheInvalidation = notif.parse_payload()?;
457 /// println!("Invalidate key: {}", cache_event.key);
458 /// # Ok(())
459 /// # }
460 /// ```
461 pub fn parse_payload<T: for<'de> Deserialize<'de>>(&self) -> Result<T, PgError> {
462 serde_json::from_value(self.payload.clone())
463 .map_err(|e| PgError::Deserialization(format!("Failed to deserialize payload: {}", e)))
464 }
465}
466
467/// Errors that can occur during PostgreSQL operations
468#[derive(Debug, thiserror::Error)]
469pub enum PgError {
470 #[error("Configuration error: {0}")]
471 Config(String),
472
473 #[error("Connection pool error: {0}")]
474 Pool(String),
475
476 #[error("Connection error: {0}")]
477 Connection(String),
478
479 #[error("Query error: {0}")]
480 Query(String),
481
482 #[error("LISTEN error: {0}")]
483 Listen(String),
484
485 #[error("NOTIFY error: {0}")]
486 Notify(String),
487
488 #[error("Transaction error: {0}")]
489 Transaction(String),
490
491 #[error("Serialization error: {0}")]
492 Serialization(String),
493
494 #[error("Deserialization error: {0}")]
495 Deserialization(String),
496}
497
498#[cfg(test)]
499mod tests {
500 use super::*;
501
502 #[test]
503 fn test_pg_notification_creation() {
504 let notif = PgNotification {
505 channel: "test:channel".to_string(),
506 payload: serde_json::json!({
507 "key": "value",
508 "number": 42
509 }),
510 };
511
512 assert_eq!(notif.channel, "test:channel");
513 assert_eq!(notif.payload["key"], "value");
514 assert_eq!(notif.payload["number"], 42);
515 }
516
517 #[test]
518 fn test_pg_notification_parse_payload() {
519 #[derive(Debug, Deserialize, PartialEq)]
520 struct TestPayload {
521 key: String,
522 number: i32,
523 }
524
525 let notif = PgNotification {
526 channel: "test".to_string(),
527 payload: serde_json::json!({
528 "key": "value",
529 "number": 42
530 }),
531 };
532
533 let parsed: TestPayload = notif.parse_payload().unwrap();
534 assert_eq!(
535 parsed,
536 TestPayload {
537 key: "value".to_string(),
538 number: 42
539 }
540 );
541 }
542
543 #[test]
544 fn test_pg_notification_parse_invalid_payload() {
545 #[derive(Debug, Deserialize)]
546 struct TestPayload {
547 required_field: String,
548 }
549
550 let notif = PgNotification {
551 channel: "test".to_string(),
552 payload: serde_json::json!({
553 "other_field": "value"
554 }),
555 };
556
557 let result: Result<TestPayload, _> = notif.parse_payload();
558 assert!(result.is_err());
559 }
560}