warpdrive_proxy/cache/
invalidation.rs

1//! Cache invalidation via PostgreSQL LISTEN/NOTIFY
2//!
3//! This module provides distributed cache invalidation using PostgreSQL's
4//! LISTEN/NOTIFY feature. When one WarpDrive instance invalidates a cache entry,
5//! all other instances are notified via PostgreSQL and can clear their local caches.
6//!
7//! # Architecture
8//!
9//! ```text
10//! Instance A                    Instance B
11//! ┌────────┐                    ┌────────┐
12//! │ DELETE │                    │        │
13//! │  key   │                    │        │
14//! └───┬────┘                    └───▲────┘
15//!     │                             │
16//!     │    ┌──────────────┐         │
17//!     └───►│  PostgreSQL  │─────────┘
18//!          │    NOTIFY    │
19//!          │   "key=foo"  │
20//!          └──────────────┘
21//! ```
22//!
23//! # Example
24//!
25//! ```no_run
26//! use warpdrive::cache::coordinator::CacheCoordinator;
27//! use warpdrive::cache::invalidation::InvalidationListener;
28//! use warpdrive::config::Config;
29//!
30//! # async fn example() -> anyhow::Result<()> {
31//! let config = Config::from_env()?;
32//! let cache = CacheCoordinator::from_config(&config).await?;
33//!
34//! // Start listener in background (if PostgreSQL configured)
35//! if let Some(db_url) = &config.database_url {
36//!     InvalidationListener::spawn(
37//!         db_url.clone(),
38//!         config.pg_channel_cache_invalidation.clone(),
39//!         cache.clone(),
40//!     );
41//! }
42//! # Ok(())
43//! # }
44//! ```
45
46use futures::StreamExt;
47use serde::{Deserialize, Serialize};
48use std::sync::Arc;
49use tokio::task::JoinHandle;
50use tracing::{error, info, warn};
51
52use crate::cache::Cache;
53use crate::cache::coordinator::CacheCoordinator;
54use crate::postgres::PgListener;
55
56/// Cache invalidation message format
57///
58/// Sent via PostgreSQL NOTIFY when a cache entry is invalidated.
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct InvalidationMessage {
61    /// Cache key to invalidate
62    pub key: String,
63
64    /// Source instance ID (for debugging/metrics)
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub source: Option<String>,
67
68    /// Reason for invalidation (for debugging)
69    #[serde(skip_serializing_if = "Option::is_none")]
70    pub reason: Option<String>,
71}
72
73impl InvalidationMessage {
74    /// Create a new invalidation message
75    pub fn new(key: String) -> Self {
76        InvalidationMessage {
77            key,
78            source: None,
79            reason: None,
80        }
81    }
82
83    /// Create invalidation message with metadata
84    pub fn with_metadata(key: String, source: String, reason: String) -> Self {
85        InvalidationMessage {
86            key,
87            source: Some(source),
88            reason: Some(reason),
89        }
90    }
91}
92
93/// Cache invalidation listener
94///
95/// Subscribes to PostgreSQL NOTIFY events and invalidates local cache entries.
96pub struct InvalidationListener {
97    _handle: JoinHandle<()>,
98}
99
100impl InvalidationListener {
101    /// Spawn a cache invalidation listener in the background
102    ///
103    /// Creates a PostgreSQL listener and processes invalidation messages.
104    /// The listener runs in a background task and will continue until the
105    /// application exits or an unrecoverable error occurs.
106    ///
107    /// # Arguments
108    ///
109    /// * `database_url` - PostgreSQL connection string
110    /// * `channel` - Channel name to listen on (e.g., "warpdrive:cache:invalidate")
111    /// * `cache` - Cache coordinator to invalidate entries from
112    ///
113    /// # Panics
114    ///
115    /// Does not panic. Errors are logged but the listener will attempt to reconnect.
116    ///
117    /// # Example
118    ///
119    /// ```no_run
120    /// # use warpdrive::cache::coordinator::CacheCoordinator;
121    /// # use warpdrive::cache::invalidation::InvalidationListener;
122    /// # async fn example(cache: CacheCoordinator) {
123    /// let listener = InvalidationListener::spawn(
124    ///     "postgresql://localhost/warpdrive".to_string(),
125    ///     "warpdrive:cache:invalidate".to_string(),
126    ///     cache,
127    /// );
128    /// # }
129    /// ```
130    pub fn spawn(
131        database_url: String,
132        channel: String,
133        cache: CacheCoordinator,
134    ) -> InvalidationListener {
135        let handle =
136            tokio::spawn(
137                async move { Self::listen_loop(database_url, channel, Arc::new(cache)).await },
138            );
139
140        InvalidationListener { _handle: handle }
141    }
142
143    /// Main listen loop
144    ///
145    /// Connects to PostgreSQL, subscribes to the channel, and processes messages.
146    /// Automatically reconnects on connection loss.
147    async fn listen_loop(database_url: String, channel: String, cache: Arc<CacheCoordinator>) {
148        info!(
149            database_url = %database_url,
150            channel = %channel,
151            "Starting cache invalidation listener"
152        );
153
154        loop {
155            match Self::try_listen(&database_url, &channel, Arc::clone(&cache)).await {
156                Ok(()) => {
157                    warn!("Cache invalidation listener exited cleanly, restarting...");
158                }
159                Err(e) => {
160                    error!(
161                        error = %e,
162                        "Cache invalidation listener error, retrying in 5s..."
163                    );
164                    tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
165                }
166            }
167        }
168    }
169
170    /// Attempt to listen for invalidation messages
171    ///
172    /// Returns when connection is lost or an unrecoverable error occurs.
173    async fn try_listen(
174        database_url: &str,
175        channel: &str,
176        cache: Arc<CacheCoordinator>,
177    ) -> anyhow::Result<()> {
178        // Create listener
179        let mut listener = PgListener::from_url(database_url, vec![channel.to_string()]).await?;
180
181        info!(channel = %channel, "Cache invalidation listener connected");
182
183        // Process messages
184        let mut stream = listener.stream();
185        while let Some(notification) = stream.next().await {
186            if let Err(e) = Self::handle_notification(notification, Arc::clone(&cache)).await {
187                error!(
188                    error = %e,
189                    "Failed to handle cache invalidation notification"
190                );
191                // Continue processing other notifications
192            }
193        }
194
195        Ok(())
196    }
197
198    /// Handle a single invalidation notification
199    async fn handle_notification(
200        notification: crate::postgres::PgNotification,
201        cache: Arc<CacheCoordinator>,
202    ) -> anyhow::Result<()> {
203        // Parse message
204        let message: InvalidationMessage = notification.parse_payload()?;
205
206        info!(
207            key = %message.key,
208            source = ?message.source,
209            reason = ?message.reason,
210            "Received cache invalidation notification"
211        );
212
213        // Invalidate cache entry
214        cache.delete(&message.key).await?;
215
216        // Update metrics
217        crate::metrics::CACHE_INVALIDATIONS
218            .with_label_values(&["pg_notify"])
219            .inc();
220
221        Ok(())
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_invalidation_message_new() {
        let msg = InvalidationMessage::new("test_key".to_string());
        assert_eq!(msg.key, "test_key");
        assert!(msg.source.is_none());
        assert!(msg.reason.is_none());
    }

    #[test]
    fn test_invalidation_message_with_metadata() {
        let msg = InvalidationMessage::with_metadata(
            "test_key".to_string(),
            "instance-1".to_string(),
            "expired".to_string(),
        );
        assert_eq!(msg.key, "test_key");
        assert_eq!(msg.source, Some("instance-1".to_string()));
        assert_eq!(msg.reason, Some("expired".to_string()));
    }

    #[test]
    fn test_invalidation_message_serialization() {
        let msg = InvalidationMessage::new("test_key".to_string());
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: InvalidationMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.key, "test_key");
    }

    #[test]
    fn test_invalidation_message_with_metadata_serialization() {
        let msg = InvalidationMessage::with_metadata(
            "test_key".to_string(),
            "instance-1".to_string(),
            "expired".to_string(),
        );
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: InvalidationMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.key, "test_key");
        assert_eq!(parsed.source, Some("instance-1".to_string()));
        assert_eq!(parsed.reason, Some("expired".to_string()));
    }

    /// `None` metadata must not appear on the wire (`skip_serializing_if`).
    #[test]
    fn test_invalidation_message_omits_empty_metadata() {
        let msg = InvalidationMessage::new("test_key".to_string());
        let json = serde_json::to_string(&msg).unwrap();
        assert_eq!(json, r#"{"key":"test_key"}"#);
    }

    /// Publishers may send only the key; metadata then defaults to `None`.
    #[test]
    fn test_invalidation_message_deserializes_key_only_payload() {
        let parsed: InvalidationMessage = serde_json::from_str(r#"{"key":"test_key"}"#).unwrap();
        assert_eq!(parsed.key, "test_key");
        assert!(parsed.source.is_none());
        assert!(parsed.reason.is_none());
    }

    /// `key` is mandatory; a payload without it must be rejected.
    #[test]
    fn test_invalidation_message_rejects_payload_without_key() {
        let result = serde_json::from_str::<InvalidationMessage>(r#"{"source":"instance-1"}"#);
        assert!(result.is_err());
    }
}