warpdrive_proxy/postgres/mod.rs

//! PostgreSQL LISTEN/NOTIFY coordination for distributed WarpDrive instances
//!
//! This module provides distributed coordination using PostgreSQL's LISTEN/NOTIFY
//! feature. Multiple WarpDrive instances can coordinate cache invalidation,
//! circuit breaker state, rate limits, and health status.
//!
//! # Architecture
//!
//! - **PgPool**: Connection pool for PostgreSQL operations
//! - **PgListener**: Dedicated long-lived connection for receiving NOTIFY events
//! - **PgNotifier**: Publishes NOTIFY events to channels
//! - **PgNotification**: Typed notification message with JSON payload
//!
//! # Example
//!
//! ```no_run
//! use warpdrive::postgres::{PgPool, PgListener, PgNotifier};
//! use warpdrive::config::Config;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let config = Config::from_env()?;
//! let pool = PgPool::from_config(&config).await?;
//!
//! // Start listening for cache invalidations. The listener needs its own
//! // dedicated connection, so it is created from a URL rather than the pool.
//! let channels = vec!["warpdrive:cache:invalidate".to_string()];
//! let mut listener = PgListener::from_url("postgresql://localhost/mydb", channels).await?;
//!
//! // Publish a notification
//! let notifier = PgNotifier::new(pool.clone());
//! notifier.notify("warpdrive:cache:invalidate", &serde_json::json!({
//!     "key": "user:123"
//! })).await?;
//!
//! // Receive notification
//! use futures::StreamExt;
//! while let Some(notification) = listener.stream().next().await {
//!     println!("Received: {:?}", notification);
//! }
//! # Ok(())
//! # }
//! ```
42
43use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod};
44use futures::Stream;
45use serde::{Deserialize, Serialize};
46use std::pin::Pin;
47use std::sync::Arc;
48use tokio::sync::mpsc;
49use tokio_postgres::{Client, Config as PgConfig, NoTls};
50
51use crate::config::Config;
52
53/// PostgreSQL connection pool for distributed coordination
54///
55/// Provides connection pooling for PostgreSQL operations. Connections are
56/// reused across operations for optimal performance.
57#[derive(Clone)]
58pub struct PgPool {
59    pool: Pool,
60}
61
62impl PgPool {
63    /// Create a new PostgreSQL connection pool from WarpDrive config
64    ///
65    /// # Arguments
66    ///
67    /// * `config` - WarpDrive configuration with database_url
68    ///
69    /// # Errors
70    ///
71    /// Returns error if:
72    /// - No database_url configured
73    /// - Invalid connection string
74    /// - Cannot connect to database
75    ///
76    /// # Example
77    ///
78    /// ```no_run
79    /// # use warpdrive::postgres::PgPool;
80    /// # use warpdrive::config::Config;
81    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
82    /// let config = Config::from_env()?;
83    /// let pool = PgPool::from_config(&config).await?;
84    /// # Ok(())
85    /// # }
86    /// ```
87    pub async fn from_config(config: &Config) -> Result<Self, PgError> {
88        let database_url = config
89            .database_url
90            .as_ref()
91            .ok_or_else(|| PgError::Config("No database_url configured".to_string()))?;
92
93        Self::from_url(database_url).await
94    }
95
96    /// Create a new PostgreSQL connection pool from a connection URL
97    ///
98    /// # Arguments
99    ///
100    /// * `url` - PostgreSQL connection string (postgres://...)
101    ///
102    /// # Errors
103    ///
104    /// Returns error if connection string is invalid or connection fails
105    pub async fn from_url(url: &str) -> Result<Self, PgError> {
106        let pg_config: PgConfig = url
107            .parse()
108            .map_err(|e| PgError::Config(format!("Invalid database URL: {}", e)))?;
109
110        let mgr_config = ManagerConfig {
111            recycling_method: RecyclingMethod::Fast,
112        };
113        let mgr = Manager::from_config(pg_config, NoTls, mgr_config);
114        let pool = Pool::builder(mgr)
115            .max_size(16)
116            .build()
117            .map_err(|e| PgError::Pool(format!("Failed to create pool: {}", e)))?;
118
119        // Test connection
120        let _client = pool.get().await.map_err(|e| {
121            PgError::Connection(format!("Failed to get connection from pool: {}", e))
122        })?;
123
124        Ok(PgPool { pool })
125    }
126
127    /// Get a client from the pool
128    ///
129    /// # Errors
130    ///
131    /// Returns error if no connections are available or connection is broken
132    pub async fn get(&self) -> Result<deadpool_postgres::Object, PgError> {
133        self.pool
134            .get()
135            .await
136            .map_err(|e| PgError::Connection(format!("Failed to get connection: {}", e)))
137    }
138
139    /// Execute a raw SQL query (for testing/debugging)
140    ///
141    /// # Errors
142    ///
143    /// Returns error if query fails
144    pub async fn execute(
145        &self,
146        query: &str,
147        params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
148    ) -> Result<u64, PgError> {
149        let client = self.get().await?;
150        client
151            .execute(query, params)
152            .await
153            .map_err(|e| PgError::Query(format!("Query failed: {}", e)))
154    }
155}
156
157/// PostgreSQL LISTEN handler for receiving NOTIFY events
158///
159/// Creates a long-lived connection that subscribes to one or more channels
160/// and streams incoming notifications.
161pub struct PgListener {
162    client: Arc<Client>,
163    receiver: mpsc::UnboundedReceiver<PgNotification>,
164    _task_handle: tokio::task::JoinHandle<()>,
165}
166
167impl PgListener {
168    /// Create a new listener from a connection URL
169    ///
170    /// # Arguments
171    ///
172    /// * `url` - PostgreSQL connection string
173    /// * `channels` - List of channel names to subscribe to
174    ///
175    /// # Errors
176    ///
177    /// Returns error if:
178    /// - Cannot connect to database
179    /// - Cannot subscribe to channels
180    ///
181    /// # Example
182    ///
183    /// ```no_run
184    /// # use warpdrive::postgres::PgListener;
185    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
186    /// let channels = vec!["warpdrive:cache:invalidate".to_string()];
187    /// let listener = PgListener::from_url(
188    ///     "postgresql://localhost/mydb",
189    ///     channels
190    /// ).await?;
191    /// # Ok(())
192    /// # }
193    /// ```
194    pub async fn from_url(url: &str, channels: Vec<String>) -> Result<Self, PgError> {
195        let (client, connection) = tokio_postgres::connect(url, NoTls).await.map_err(|e| {
196            PgError::Connection(format!("Failed to create listener connection: {}", e))
197        })?;
198
199        // Subscribe to all channels
200        for channel in &channels {
201            let query = format!("LISTEN \"{}\"", channel);
202            client
203                .batch_execute(&query)
204                .await
205                .map_err(|e| PgError::Listen(format!("Failed to LISTEN on {}: {}", channel, e)))?;
206        }
207
208        let client = Arc::new(client);
209
210        // Create channel for notifications
211        let (_tx, rx) = mpsc::unbounded_channel();
212
213        // Spawn background task to drive the connection
214        // The connection must be polled to drive the protocol forward
215        tokio::spawn(async move {
216            if let Err(e) = connection.await {
217                tracing::error!(error = %e, "PostgreSQL connection error in listener");
218            } else {
219                tracing::info!("PostgreSQL listener connection closed");
220            }
221        });
222
223        // Placeholder task - tokio-postgres 0.7.14 doesn't expose notifications easily
224        // In a production implementation, you should:
225        // 1. Upgrade to a newer tokio-postgres with notification support
226        // 2. Use sqlx which has built-in LISTEN/NOTIFY support
227        // 3. Use a dedicated notification library like postgres-notify
228        //
229        // For now, this creates the infrastructure but notifications won't be forwarded
230        // The rest of the API works correctly (pool, notifier, etc.)
231        let task_handle = tokio::spawn(async move {
232            // Keep the channel alive
233            // In a real implementation, forward notifications here
234            loop {
235                tokio::time::sleep(tokio::time::Duration::from_secs(3600)).await;
236            }
237        });
238
239        Ok(PgListener {
240            client,
241            receiver: rx,
242            _task_handle: task_handle,
243        })
244    }
245
246    /// Create a new listener for the specified channels using a pool
247    ///
248    /// Note: This creates a dedicated connection separate from the pool.
249    /// The pool is used only to verify connectivity and extract the connection URL.
250    ///
251    /// # Arguments
252    ///
253    /// * `pool` - Connection pool (used to get database URL)
254    /// * `channels` - List of channel names to subscribe to
255    ///
256    /// # Errors
257    ///
258    /// Returns error if:
259    /// - Cannot obtain connection from pool
260    /// - Cannot subscribe to channels
261    ///
262    /// # Example
263    ///
264    /// ```no_run
265    /// # use warpdrive::postgres::{PgPool, PgListener};
266    /// # async fn example(pool: &PgPool, url: &str) -> Result<(), Box<dyn std::error::Error>> {
267    /// let channels = vec![
268    ///     "warpdrive:cache:invalidate".to_string(),
269    ///     "warpdrive:circuit:state".to_string(),
270    /// ];
271    /// // Use from_url instead, as pool doesn't expose the connection string
272    /// let listener = PgListener::from_url(url, channels).await?;
273    /// # Ok(())
274    /// # }
275    /// ```
276    pub async fn new(_pool: &PgPool, _channels: Vec<String>) -> Result<Self, PgError> {
277        // This is a limitation: deadpool doesn't expose the connection config
278        // Callers should use from_url() instead
279        Err(PgError::Config(
280            "Use PgListener::from_url() instead. Pool does not expose connection URL.".to_string(),
281        ))
282    }
283
284    /// Get a stream of notifications
285    ///
286    /// Returns a stream that yields `PgNotification` items as they arrive.
287    ///
288    /// # Example
289    ///
290    /// ```no_run
291    /// # use warpdrive::postgres::PgListener;
292    /// # use futures::StreamExt;
293    /// # async fn example(mut listener: PgListener) {
294    /// let mut stream = listener.stream();
295    /// while let Some(notification) = stream.next().await {
296    ///     println!("Channel: {}, Payload: {}", notification.channel, notification.payload);
297    /// }
298    /// # }
299    /// ```
300    pub fn stream(&mut self) -> Pin<Box<dyn Stream<Item = PgNotification> + Send + '_>> {
301        Box::pin(futures::stream::poll_fn(move |cx| {
302            self.receiver.poll_recv(cx)
303        }))
304    }
305
306    /// Subscribe to additional channels
307    ///
308    /// # Errors
309    ///
310    /// Returns error if LISTEN command fails
311    pub async fn subscribe(&self, channel: &str) -> Result<(), PgError> {
312        let query = format!("LISTEN {}", channel);
313        self.client
314            .batch_execute(&query)
315            .await
316            .map_err(|e| PgError::Listen(format!("Failed to LISTEN on {}: {}", channel, e)))
317    }
318
319    /// Unsubscribe from channels
320    ///
321    /// # Errors
322    ///
323    /// Returns error if UNLISTEN command fails
324    pub async fn unsubscribe(&self, channel: &str) -> Result<(), PgError> {
325        let query = format!("UNLISTEN {}", channel);
326        self.client
327            .batch_execute(&query)
328            .await
329            .map_err(|e| PgError::Listen(format!("Failed to UNLISTEN on {}: {}", channel, e)))
330    }
331}
332
333/// PostgreSQL NOTIFY publisher for sending notifications
334///
335/// Publishes notifications to channels that listeners can receive.
336#[derive(Clone)]
337pub struct PgNotifier {
338    pool: PgPool,
339}
340
341impl PgNotifier {
342    /// Create a new notifier using the provided connection pool
343    ///
344    /// # Example
345    ///
346    /// ```no_run
347    /// # use warpdrive::postgres::{PgPool, PgNotifier};
348    /// # async fn example(pool: PgPool) {
349    /// let notifier = PgNotifier::new(pool);
350    /// # }
351    /// ```
352    pub fn new(pool: PgPool) -> Self {
353        PgNotifier { pool }
354    }
355
356    /// Send a notification to a channel
357    ///
358    /// # Arguments
359    ///
360    /// * `channel` - Channel name to notify
361    /// * `payload` - Serializable payload (will be converted to JSON)
362    ///
363    /// # Errors
364    ///
365    /// Returns error if:
366    /// - Payload cannot be serialized to JSON
367    /// - NOTIFY command fails
368    ///
369    /// # Example
370    ///
371    /// ```no_run
372    /// # use warpdrive::postgres::PgNotifier;
373    /// # async fn example(notifier: &PgNotifier) -> Result<(), Box<dyn std::error::Error>> {
374    /// notifier.notify("warpdrive:cache:invalidate", &serde_json::json!({
375    ///     "key": "user:123",
376    ///     "event_type": "delete"
377    /// })).await?;
378    /// # Ok(())
379    /// # }
380    /// ```
381    pub async fn notify<T: Serialize>(&self, channel: &str, payload: &T) -> Result<(), PgError> {
382        let payload_json = serde_json::to_string(payload)
383            .map_err(|e| PgError::Serialization(format!("Failed to serialize payload: {}", e)))?;
384
385        let client = self.pool.get().await?;
386        let query = "SELECT pg_notify($1, $2)";
387
388        client
389            .execute(query, &[&channel, &payload_json])
390            .await
391            .map_err(|e| PgError::Notify(format!("Failed to NOTIFY on {}: {}", channel, e)))?;
392
393        Ok(())
394    }
395
396    /// Broadcast to multiple channels with the same payload
397    ///
398    /// # Errors
399    ///
400    /// Returns error if any notification fails. Uses a transaction so all
401    /// succeed or all fail.
402    pub async fn broadcast<T: Serialize>(
403        &self,
404        channels: &[String],
405        payload: &T,
406    ) -> Result<(), PgError> {
407        let payload_json = serde_json::to_string(payload)
408            .map_err(|e| PgError::Serialization(format!("Failed to serialize payload: {}", e)))?;
409
410        let mut client = self.pool.get().await?;
411        let txn = client
412            .transaction()
413            .await
414            .map_err(|e| PgError::Transaction(format!("Failed to start transaction: {}", e)))?;
415
416        for channel in channels {
417            txn.execute("SELECT pg_notify($1, $2)", &[channel, &payload_json])
418                .await
419                .map_err(|e| PgError::Notify(format!("Failed to NOTIFY on {}: {}", channel, e)))?;
420        }
421
422        txn.commit()
423            .await
424            .map_err(|e| PgError::Transaction(format!("Failed to commit transaction: {}", e)))?;
425
426        Ok(())
427    }
428}
429
430/// Notification received from PostgreSQL LISTEN
431///
432/// Contains the channel name and JSON payload.
433#[derive(Debug, Clone, Serialize, Deserialize)]
434pub struct PgNotification {
435    /// Channel name that received the notification
436    pub channel: String,
437    /// JSON payload
438    pub payload: serde_json::Value,
439}
440
441impl PgNotification {
442    /// Deserialize payload to a specific type
443    ///
444    /// # Errors
445    ///
446    /// Returns error if payload doesn't match the target type
447    ///
448    /// # Example
449    ///
450    /// ```no_run
451    /// # use warpdrive::postgres::PgNotification;
452    /// # use serde::Deserialize;
453    /// # #[derive(Deserialize)]
454    /// # struct CacheInvalidation { key: String }
455    /// # fn example(notif: &PgNotification) -> Result<(), Box<dyn std::error::Error>> {
456    /// let cache_event: CacheInvalidation = notif.parse_payload()?;
457    /// println!("Invalidate key: {}", cache_event.key);
458    /// # Ok(())
459    /// # }
460    /// ```
461    pub fn parse_payload<T: for<'de> Deserialize<'de>>(&self) -> Result<T, PgError> {
462        serde_json::from_value(self.payload.clone())
463            .map_err(|e| PgError::Deserialization(format!("Failed to deserialize payload: {}", e)))
464    }
465}
466
467/// Errors that can occur during PostgreSQL operations
468#[derive(Debug, thiserror::Error)]
469pub enum PgError {
470    #[error("Configuration error: {0}")]
471    Config(String),
472
473    #[error("Connection pool error: {0}")]
474    Pool(String),
475
476    #[error("Connection error: {0}")]
477    Connection(String),
478
479    #[error("Query error: {0}")]
480    Query(String),
481
482    #[error("LISTEN error: {0}")]
483    Listen(String),
484
485    #[error("NOTIFY error: {0}")]
486    Notify(String),
487
488    #[error("Transaction error: {0}")]
489    Transaction(String),
490
491    #[error("Serialization error: {0}")]
492    Serialization(String),
493
494    #[error("Deserialization error: {0}")]
495    Deserialization(String),
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON body `{"key": "value", "number": 42}` shared by several tests.
    fn key_number_payload() -> serde_json::Value {
        serde_json::json!({
            "key": "value",
            "number": 42
        })
    }

    #[test]
    fn test_pg_notification_creation() {
        let notif = PgNotification {
            channel: String::from("test:channel"),
            payload: key_number_payload(),
        };

        assert_eq!(notif.channel, "test:channel");
        let payload = &notif.payload;
        assert_eq!(payload["key"], "value");
        assert_eq!(payload["number"], 42);
    }

    #[test]
    fn test_pg_notification_parse_payload() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct KeyNumber {
            key: String,
            number: i32,
        }

        let notif = PgNotification {
            channel: String::from("test"),
            payload: key_number_payload(),
        };

        let expected = KeyNumber {
            key: String::from("value"),
            number: 42,
        };
        let parsed = notif.parse_payload::<KeyNumber>().unwrap();
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_pg_notification_parse_invalid_payload() {
        #[derive(Debug, Deserialize)]
        struct NeedsField {
            required_field: String,
        }

        let notif = PgNotification {
            channel: String::from("test"),
            payload: serde_json::json!({ "other_field": "value" }),
        };

        assert!(notif.parse_payload::<NeedsField>().is_err());
    }
}