1use async_trait::async_trait;
4use futures::StreamExt;
5use sqlx::postgres::PgListener;
6use sqlx::PgPool;
7use std::collections::HashSet;
8use std::sync::{Arc, RwLock};
9use std::time::SystemTime;
10use tokio::sync::mpsc;
11use tracing::{debug, error, info, instrument, warn};
12
13use crate::channel_metrics::ChannelMonitor;
14use crate::config::PgmqNotifyConfig;
15use crate::error::{PgmqNotifyError, Result};
16use crate::events::PgmqNotifyEvent;
17
/// Point-in-time counters and flags describing a [`PgmqNotifyListener`].
///
/// The default value reports a disconnected listener with every counter at
/// zero and no timestamps recorded.
#[derive(Clone, Debug, Default)]
pub struct ListenerStats {
    /// Set when `connect` succeeds; cleared by `disconnect` and when a
    /// listen loop observes a connection error.
    pub connected: bool,
    /// Number of channels currently recorded as subscribed.
    pub channels_listening: usize,
    /// Notifications received, counted before the payload is parsed.
    pub events_received: u64,
    /// Notifications whose payload could not be deserialized.
    pub parse_errors: u64,
    /// Errors yielded by the notification stream.
    pub connection_errors: u64,
    /// Wall-clock time at which the most recent notification arrived.
    pub last_event_at: Option<SystemTime>,
    /// Wall-clock time of the most recent parse or connection error.
    pub last_error_at: Option<SystemTime>,
}
29
30#[async_trait]
32pub trait PgmqEventHandler: Send + Sync {
33 async fn handle_event(&self, event: PgmqNotifyEvent) -> Result<()>;
35
36 async fn handle_parse_error(&self, channel: &str, payload: &str, error: PgmqNotifyError) {
38 warn!(
39 "Failed to parse notification from channel {}: {} - payload: {}",
40 channel, error, payload
41 );
42 }
43
44 async fn handle_connection_error(&self, error: PgmqNotifyError) {
46 error!("Connection error in PGMQ listener: {}", error);
47 }
48}
49
50pub struct PgmqNotifyListener {
52 pool: PgPool,
53 config: PgmqNotifyConfig,
54 listener: Option<PgListener>,
55 listening_channels: Arc<RwLock<HashSet<String>>>,
56 stats: Arc<RwLock<ListenerStats>>,
57 event_sender: Option<mpsc::Sender<PgmqNotifyEvent>>,
58 event_receiver: Option<mpsc::Receiver<PgmqNotifyEvent>>,
59 channel_monitor: ChannelMonitor,
60}
61
62impl std::fmt::Debug for PgmqNotifyListener {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 f.debug_struct("PgmqNotifyListener")
65 .field("config", &self.config)
66 .field("listener_connected", &self.listener.is_some())
67 .field("listening_channels", &self.listening_channels)
68 .field("stats", &self.stats)
69 .field("has_event_sender", &self.event_sender.is_some())
70 .field("has_event_receiver", &self.event_receiver.is_some())
71 .field("channel_monitor", &self.channel_monitor)
72 .finish()
73 }
74}
75
76impl PgmqNotifyListener {
77 pub async fn new(pool: PgPool, config: PgmqNotifyConfig, buffer_size: usize) -> Result<Self> {
93 config.validate()?;
94
95 let (event_sender, event_receiver) = mpsc::channel(buffer_size);
96
97 let channel_monitor = ChannelMonitor::new("pgmq_notify_listener", buffer_size);
99
100 Ok(Self {
101 pool,
102 config,
103 listener: None,
104 listening_channels: Arc::new(RwLock::new(HashSet::new())),
105 stats: Arc::new(RwLock::new(ListenerStats::default())),
106 event_sender: Some(event_sender),
107 event_receiver: Some(event_receiver),
108 channel_monitor,
109 })
110 }
111
112 #[must_use]
114 pub fn config(&self) -> &PgmqNotifyConfig {
115 &self.config
116 }
117
118 #[must_use]
120 pub fn stats(&self) -> ListenerStats {
121 self.stats.read().unwrap_or_else(|p| p.into_inner()).clone()
122 }
123
124 #[instrument(skip(self))]
126 pub async fn connect(&mut self) -> Result<()> {
127 if self.listener.is_some() {
128 debug!("Already connected to database");
129 return Ok(());
130 }
131
132 info!("Connecting PGMQ notification listener to database");
133
134 let listener = PgListener::connect_with(&self.pool).await?;
135 self.listener = Some(listener);
136
137 {
139 let mut stats = self.stats.write().unwrap_or_else(|p| p.into_inner());
140 stats.connected = true;
141 };
142
143 info!("Successfully connected PGMQ notification listener");
144 Ok(())
145 }
146
147 #[instrument(skip(self))]
149 pub async fn disconnect(&mut self) -> Result<()> {
150 if let Some(listener) = self.listener.take() {
151 info!("Disconnecting PGMQ notification listener");
152 drop(listener);
154 }
155
156 {
158 let mut channels = self
159 .listening_channels
160 .write()
161 .unwrap_or_else(|p| p.into_inner());
162 channels.clear();
163 };
164
165 {
167 let mut stats = self.stats.write().unwrap_or_else(|p| p.into_inner());
168 stats.connected = false;
169 stats.channels_listening = 0;
170 };
171
172 info!("Disconnected PGMQ notification listener");
173 Ok(())
174 }
175
176 #[instrument(skip(self), fields(channel = %channel))]
178 pub async fn listen_channel(&mut self, channel: &str) -> Result<()> {
179 if self.listener.is_none() {
180 return Err(PgmqNotifyError::NotConnected);
181 }
182
183 {
185 let channels = self
186 .listening_channels
187 .read()
188 .unwrap_or_else(|p| p.into_inner());
189 if channels.contains(channel) {
190 warn!("Already listening on channel {channel}");
191 return Ok(());
192 }
193 }
194
195 debug!("Starting to listen to channel: {}", channel);
196
197 if let Some(ref mut listener) = self.listener {
198 listener.listen(channel).await?;
199 }
200
201 {
203 let mut channels = self
204 .listening_channels
205 .write()
206 .unwrap_or_else(|p| p.into_inner());
207 channels.insert(channel.to_string())
208 };
209
210 {
212 let mut stats = self.stats.write().unwrap_or_else(|p| p.into_inner());
213 stats.channels_listening = self
214 .listening_channels
215 .read()
216 .unwrap_or_else(|p| p.into_inner())
217 .len();
218 };
219
220 info!("Now listening to channel: {}", channel);
221 Ok(())
222 }
223
224 #[instrument(skip(self), fields(channel = %channel))]
226 pub async fn unlisten_channel(&mut self, channel: &str) -> Result<()> {
227 if self.listener.is_none() {
228 return Err(PgmqNotifyError::NotConnected);
229 }
230
231 debug!("Stopping listening to channel: {}", channel);
232
233 if let Some(ref mut listener) = self.listener {
234 listener.unlisten(channel).await?;
235 }
236
237 {
239 let mut channels = self
240 .listening_channels
241 .write()
242 .unwrap_or_else(|p| p.into_inner());
243 channels.remove(channel)
244 };
245
246 {
248 let mut stats = self.stats.write().unwrap_or_else(|p| p.into_inner());
249 stats.channels_listening = self
250 .listening_channels
251 .read()
252 .unwrap_or_else(|p| p.into_inner())
253 .len();
254 };
255
256 info!("Stopped listening to channel: {}", channel);
257 Ok(())
258 }
259
260 pub async fn listen_queue_created(&mut self) -> Result<()> {
262 let channel = self.config.queue_created_channel()?;
263 self.listen_channel(&channel).await
264 }
265
266 pub async fn listen_message_ready_for_namespace(&mut self, namespace: &str) -> Result<()> {
268 let channel = self.config.message_ready_channel(namespace)?;
269 self.listen_channel(&channel).await
270 }
271
272 pub async fn listen_message_ready_global(&mut self) -> Result<()> {
274 let channel = self.config.global_message_ready_channel()?;
275 self.listen_channel(&channel).await
276 }
277
278 pub async fn listen_default_namespaces(&mut self) -> Result<()> {
280 let namespaces: Vec<String> = self.config.default_namespaces.iter().cloned().collect();
281
282 for namespace in namespaces {
283 self.listen_message_ready_for_namespace(&namespace).await?;
284 }
285
286 Ok(())
287 }
288
289 pub async fn next_event(&mut self) -> Result<Option<PgmqNotifyEvent>> {
291 if let Some(ref mut receiver) = self.event_receiver {
292 Ok(receiver.recv().await)
293 } else {
294 Err(PgmqNotifyError::NotConnected)
295 }
296 }
297
298 #[instrument(skip(self, handler))]
304 pub async fn listen_with_handler<H>(&mut self, handler: H) -> Result<()>
305 where
306 H: PgmqEventHandler + 'static,
307 {
308 if self.listener.is_none() {
309 return Err(PgmqNotifyError::NotConnected);
310 }
311
312 let handler = Arc::new(handler);
313
314 info!("Starting PGMQ notification listener loop");
315
316 if let Some(listener) = self.listener.take() {
317 let stats = Arc::clone(&self.stats);
318 let _listening_channels = Arc::clone(&self.listening_channels);
319
320 let mut stream = listener.into_stream();
323
324 while let Some(notification) = stream.next().await {
325 match notification {
326 Ok(notification) => {
327 debug!(
328 "Received notification from channel: {} with payload: {}",
329 notification.channel(),
330 notification.payload()
331 );
332
333 {
335 let mut stats = stats.write().unwrap_or_else(|p| p.into_inner());
336 stats.events_received += 1;
337 stats.last_event_at = Some(SystemTime::now());
338 };
339
340 match serde_json::from_str::<PgmqNotifyEvent>(notification.payload()) {
342 Ok(event) => {
343 if let Err(e) = handler.handle_event(event).await {
344 error!("Event handler failed: {}", e);
345 }
346 }
347 Err(e) => {
348 let parse_error = PgmqNotifyError::Serialization(e);
349
350 {
352 let mut stats =
353 stats.write().unwrap_or_else(|p| p.into_inner());
354 stats.parse_errors += 1;
355 stats.last_error_at = Some(SystemTime::now());
356 };
357
358 handler
359 .handle_parse_error(
360 notification.channel(),
361 notification.payload(),
362 parse_error,
363 )
364 .await;
365 }
366 }
367 }
368 Err(e) => {
369 let conn_error = PgmqNotifyError::Database(e);
370
371 {
373 let mut stats = stats.write().unwrap_or_else(|p| p.into_inner());
374 stats.connection_errors += 1;
375 stats.last_error_at = Some(SystemTime::now());
376 stats.connected = false;
377 };
378
379 handler.handle_connection_error(conn_error).await;
380
381 break;
383 }
384 }
385 }
386
387 info!("PGMQ notification listener loop ended");
388 }
389
390 Ok(())
391 }
392
393 #[instrument(skip(self, handler))]
398 pub async fn start_listening_with_handler<H>(
399 &mut self,
400 handler: H,
401 ) -> Result<tokio::task::JoinHandle<Result<()>>>
402 where
403 H: PgmqEventHandler + 'static,
404 {
405 if self.listener.is_none() {
406 return Err(PgmqNotifyError::NotConnected);
407 }
408
409 let handler = Arc::new(handler);
410
411 info!("Starting PGMQ notification listener in background task");
412
413 if let Some(listener) = self.listener.take() {
414 let stats = Arc::clone(&self.stats);
415 let _listening_channels = Arc::clone(&self.listening_channels);
416
417 let handle = tokio::spawn(async move {
419 let mut stream = listener.into_stream();
420 info!("Started listening for notifications");
421
422 while let Some(notification) = stream.next().await {
423 match notification {
424 Ok(notification) => {
425 debug!(
426 "Received notification from channel: {} with payload: {}",
427 notification.channel(),
428 notification.payload()
429 );
430
431 {
433 let mut stats = stats.write().unwrap_or_else(|p| p.into_inner());
434 stats.events_received += 1;
435 stats.last_event_at = Some(SystemTime::now());
436 };
437
438 match serde_json::from_str::<PgmqNotifyEvent>(notification.payload()) {
440 Ok(event) => {
441 if let Err(e) = handler.handle_event(event).await {
442 error!("Event handler failed: {}", e);
443 }
444 }
445 Err(e) => {
446 let parse_error = PgmqNotifyError::Serialization(e);
447
448 {
450 let mut stats =
451 stats.write().unwrap_or_else(|p| p.into_inner());
452 stats.parse_errors += 1;
453 stats.last_error_at = Some(SystemTime::now());
454 };
455
456 handler
457 .handle_parse_error(
458 notification.channel(),
459 notification.payload(),
460 parse_error,
461 )
462 .await;
463 }
464 }
465 }
466 Err(e) => {
467 let conn_error = PgmqNotifyError::Database(e);
468
469 {
471 let mut stats = stats.write().unwrap_or_else(|p| p.into_inner());
472 stats.connection_errors += 1;
473 stats.last_error_at = Some(SystemTime::now());
474 stats.connected = false;
475 };
476
477 handler.handle_connection_error(conn_error).await;
478
479 break;
481 }
482 }
483 }
484
485 info!("PGMQ notification listener loop ended");
486 Ok(())
487 });
488
489 return Ok(handle);
490 }
491
492 Err(PgmqNotifyError::NotConnected)
493 }
494
495 pub async fn start_listening(&mut self) -> Result<()> {
497 if self.listener.is_none() {
498 return Err(PgmqNotifyError::NotConnected);
499 }
500
501 let event_sender = self.event_sender.take();
502 if let (Some(listener), Some(sender)) = (self.listener.take(), event_sender) {
503 let stats = Arc::clone(&self.stats);
504 let monitor = self.channel_monitor.clone();
506
507 info!(
508 channel_monitor = %monitor.channel_name(),
509 buffer_size = monitor.buffer_size(),
510 "Starting PGMQ notification listener with event queue and channel monitoring"
511 );
512
513 tokio::spawn(async move {
514 let mut stream = listener.into_stream();
515
516 while let Some(notification) = stream.next().await {
517 match notification {
518 Ok(notification) => {
519 debug!(
520 "Received notification from channel: {} with payload: {}",
521 notification.channel(),
522 notification.payload()
523 );
524
525 {
527 let mut stats = stats.write().unwrap_or_else(|p| p.into_inner());
528 stats.events_received += 1;
529 stats.last_event_at = Some(SystemTime::now());
530 };
531
532 match serde_json::from_str::<PgmqNotifyEvent>(notification.payload()) {
534 Ok(event) => {
535 if let Ok(()) = sender.send(event).await {
537 if monitor.record_send_success() {
539 monitor.check_and_warn_saturation(sender.capacity());
540 }
541 } else {
542 warn!("Event receiver dropped, stopping listener");
543 break;
544 }
545 }
546 Err(e) => {
547 {
549 let mut stats =
550 stats.write().unwrap_or_else(|p| p.into_inner());
551 stats.parse_errors += 1;
552 stats.last_error_at = Some(SystemTime::now());
553 };
554
555 warn!(
556 "Failed to parse notification from channel {}: {} - payload: {}",
557 notification.channel(),
558 e,
559 notification.payload()
560 );
561 }
562 }
563 }
564 Err(e) => {
565 {
567 let mut stats = stats.write().unwrap_or_else(|p| p.into_inner());
568 stats.connection_errors += 1;
569 stats.last_error_at = Some(SystemTime::now());
570 stats.connected = false;
571 };
572
573 error!("Connection error in listener: {}", e);
574 break;
575 }
576 }
577 }
578
579 info!("PGMQ notification listener stopped");
580 });
581 }
582
583 Ok(())
584 }
585
586 pub async fn is_healthy(&self) -> bool {
588 let stats = self.stats.read().unwrap_or_else(|p| p.into_inner());
589 stats.connected
590 }
591
592 #[must_use]
594 pub fn listening_channels(&self) -> Vec<String> {
595 self.listening_channels
596 .read()
597 .unwrap_or_else(|p| p.into_inner())
598 .iter()
599 .cloned()
600 .collect()
601 }
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::QueueCreatedEvent;

    /// Handler that stores every event it is given so tests can inspect them.
    struct MockEventHandler {
        events_received: Arc<RwLock<Vec<PgmqNotifyEvent>>>,
    }

    impl MockEventHandler {
        /// Creates a handler with an empty event log.
        fn new() -> Self {
            let events_received = Arc::new(RwLock::new(Vec::new()));
            Self { events_received }
        }

        /// Returns a copy of the events recorded so far.
        fn received_events(&self) -> Vec<PgmqNotifyEvent> {
            let log = self.events_received.read().unwrap();
            log.clone()
        }
    }

    #[async_trait]
    impl PgmqEventHandler for MockEventHandler {
        async fn handle_event(&self, event: PgmqNotifyEvent) -> Result<()> {
            let mut log = self.events_received.write().unwrap();
            log.push(event);
            Ok(())
        }
    }

    /// Default statistics describe a disconnected, idle listener.
    #[test]
    fn test_listener_stats() {
        let fresh = ListenerStats::default();
        assert!(!fresh.connected);
        assert_eq!(fresh.channels_listening, 0);
        assert_eq!(fresh.events_received, 0);
    }

    /// The default configuration yields the expected channel names.
    #[test]
    fn test_channel_management() {
        let cfg = PgmqNotifyConfig::default();

        let created = cfg.queue_created_channel().unwrap();
        assert_eq!(created, "pgmq_queue_created");

        let orders = cfg.message_ready_channel("orders").unwrap();
        assert_eq!(orders, "pgmq_message_ready.orders");

        let global = cfg.global_message_ready_channel().unwrap();
        assert_eq!(global, "pgmq_message_ready");
    }

    /// A freshly built mock handler has recorded nothing.
    #[test]
    fn test_mock_event_handler() {
        let mock = MockEventHandler::new();
        // Built only to check that the event constructor is usable; it is
        // never dispatched to the handler.
        let _event = PgmqNotifyEvent::QueueCreated(QueueCreatedEvent::new("test_queue", "test"));

        assert!(mock.received_events().is_empty());
    }
}