1use crate::{
8 config::PgmqNotifyConfig,
9 error::{PgmqNotifyError, Result},
10 listener::PgmqNotifyListener,
11 types::{ClientStatus, QueueMetrics},
12};
13use pgmq::{types::Message, PGMQueue};
14use regex::Regex;
15use sqlx::{PgPool, Row};
16use std::collections::HashMap;
17use tracing::{debug, error, info, instrument, warn};
18
/// Unified pgmq client: combines the `PGMQueue` API for standard queue
/// operations with direct `sqlx` access for the notify-aware SQL wrappers
/// (`pgmq_send_with_notify`, metrics, visibility-timeout updates).
#[derive(Debug, Clone)]
pub struct PgmqClient {
    // Underlying pgmq handle used for create/read/pop/delete/archive/purge.
    pgmq: PGMQueue,
    // Postgres pool used for raw SQL calls; shared with `pgmq` when built
    // via `new_with_pool*`.
    pool: sqlx::PgPool,
    // Notify-related settings (trigger enablement, queue naming pattern).
    config: PgmqNotifyConfig,
}
29
30impl PgmqClient {
31 pub async fn new(database_url: &str) -> Result<Self> {
33 Self::new_with_config(database_url, PgmqNotifyConfig::default()).await
34 }
35
36 pub async fn new_with_config(database_url: &str, config: PgmqNotifyConfig) -> Result<Self> {
38 info!("Connecting to pgmq using unified client");
39
40 let pgmq = PGMQueue::new(database_url.to_string()).await?;
41 let pool = sqlx::postgres::PgPoolOptions::new()
42 .max_connections(20)
43 .connect(database_url)
44 .await?;
45
46 info!("Connected to pgmq using unified client");
47 Ok(Self { pgmq, pool, config })
48 }
49
50 pub async fn new_with_pool(pool: sqlx::PgPool) -> Self {
52 Self::new_with_pool_and_config(pool, PgmqNotifyConfig::default()).await
53 }
54
55 pub async fn new_with_pool_and_config(pool: sqlx::PgPool, config: PgmqNotifyConfig) -> Self {
57 info!("Creating unified pgmq client with shared connection pool");
58
59 let pgmq = PGMQueue::new_with_pool(pool.clone()).await;
60
61 info!("Unified pgmq client created with shared pool");
62 Self { pgmq, pool, config }
63 }
64
65 #[instrument(skip(self), fields(queue = %queue_name))]
67 pub async fn create_queue(&self, queue_name: &str) -> Result<()> {
68 debug!("๐ Creating queue: {}", queue_name);
69
70 self.pgmq.create(queue_name).await?;
71
72 info!("Queue created: {}", queue_name);
73 Ok(())
74 }
75
76 #[instrument(skip(self, message), fields(queue = %queue_name))]
78 pub async fn send_json_message<T>(&self, queue_name: &str, message: &T) -> Result<i64>
79 where
80 T: serde::Serialize,
81 {
82 debug!("๐ค Sending JSON message to queue: {}", queue_name);
83
84 let serialized = serde_json::to_value(message)?;
85
86 let message_id = sqlx::query_scalar!(
88 "SELECT pgmq_send_with_notify($1, $2, $3)",
89 queue_name,
90 &serialized,
91 0i32
92 )
93 .fetch_one(&self.pool)
94 .await?;
95
96 let message_id = message_id.ok_or_else(|| {
97 PgmqNotifyError::Generic(anyhow::anyhow!("Wrapper function returned NULL message ID"))
98 })?;
99
100 info!(
101 "JSON message sent to queue: {} with ID: {} (with notification)",
102 queue_name, message_id
103 );
104 Ok(message_id)
105 }
106
107 #[instrument(skip(self, message), fields(queue = %queue_name, delay_seconds = %delay_seconds))]
109 pub async fn send_message_with_delay<T>(
110 &self,
111 queue_name: &str,
112 message: &T,
113 delay_seconds: u64,
114 ) -> Result<i64>
115 where
116 T: serde::Serialize,
117 {
118 debug!(
119 "๐ค Sending delayed message to queue: {} with delay: {}s",
120 queue_name, delay_seconds
121 );
122
123 let serialized = serde_json::to_value(message)?;
124
125 let message_id = sqlx::query_scalar!(
127 "SELECT pgmq_send_with_notify($1, $2, $3)",
128 queue_name,
129 &serialized,
130 delay_seconds as i32
131 )
132 .fetch_one(&self.pool)
133 .await?;
134
135 let message_id = message_id.ok_or_else(|| {
136 PgmqNotifyError::Generic(anyhow::anyhow!("Wrapper function returned NULL message ID"))
137 })?;
138
139 info!(
140 "Delayed message sent to queue: {} with ID: {} (with notification)",
141 queue_name, message_id
142 );
143 Ok(message_id)
144 }
145
146 #[instrument(skip(self), fields(queue = %queue_name, limit = ?limit))]
148 pub async fn read_messages(
149 &self,
150 queue_name: &str,
151 visibility_timeout: Option<i32>,
152 limit: Option<i32>,
153 ) -> Result<Vec<Message<serde_json::Value>>> {
154 debug!(
155 "๐ฅ Reading messages from queue: {} (limit: {:?})",
156 queue_name, limit
157 );
158
159 let messages = match limit {
160 Some(l) => self
161 .pgmq
162 .read_batch(queue_name, visibility_timeout, l)
163 .await?
164 .unwrap_or_default(),
165 None => match self.pgmq.read(queue_name, visibility_timeout).await? {
166 Some(msg) => vec![msg],
167 None => vec![],
168 },
169 };
170
171 debug!(
172 "๐จ Read {} messages from queue: {}",
173 messages.len(),
174 queue_name
175 );
176 Ok(messages)
177 }
178
179 #[instrument(skip(self), fields(queue = %queue_name))]
181 pub async fn pop_message(
182 &self,
183 queue_name: &str,
184 ) -> Result<Option<Message<serde_json::Value>>> {
185 debug!("๐ฅ Popping message from queue: {}", queue_name);
186
187 let message = self.pgmq.pop(queue_name).await?;
188
189 if message.is_some() {
190 debug!("๐จ Popped message from queue: {}", queue_name);
191 } else {
192 debug!("๐ญ No messages available in queue: {}", queue_name);
193 }
194 Ok(message)
195 }
196
    /// Fetch a single message by ID via the custom
    /// `pgmq_read_specific_message` SQL function, locking it for
    /// `visibility_timeout` seconds, and deserialize its JSON payload
    /// into `T`.
    ///
    /// Returns `Ok(None)` when no matching row is returned; returns
    /// `PgmqNotifyError::Serialization` when the payload cannot be decoded
    /// as `T`.
    #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id))]
    pub async fn read_specific_message<T>(
        &self,
        queue_name: &str,
        message_id: i64,
        visibility_timeout: i32,
    ) -> Result<Option<Message<T>>>
    where
        T: serde::de::DeserializeOwned,
    {
        debug!(
            "๐ฅ Reading specific message {} from queue: {}",
            message_id, queue_name
        );

        // NOTE(review): assumes the pgmq_read_specific_message SQL function
        // is installed and its columns mirror pgmq's read shape — confirm
        // against the migration that defines it.
        let query = "SELECT msg_id, read_ct, enqueued_at, vt, message FROM pgmq_read_specific_message($1, $2, $3)";

        let row = sqlx::query(query)
            .bind(queue_name)
            .bind(message_id)
            .bind(visibility_timeout)
            .fetch_optional(&self.pool)
            .await?;

        if let Some(row) = row {
            // Rebuild a typed pgmq `Message` from the raw row columns.
            let msg_id: i64 = row.get("msg_id");
            let read_ct: i32 = row.get("read_ct");
            let enqueued_at: chrono::DateTime<chrono::Utc> = row.get("enqueued_at");
            let vt: chrono::DateTime<chrono::Utc> = row.get("vt");
            let message_json: serde_json::Value = row.get("message");

            match serde_json::from_value::<T>(message_json) {
                Ok(deserialized) => {
                    let typed_message = Message {
                        msg_id,
                        read_ct,
                        enqueued_at,
                        vt,
                        message: deserialized,
                    };
                    debug!("Found and deserialized specific message {}", message_id);
                    Ok(Some(typed_message))
                }
                Err(e) => {
                    // Payload exists but does not match T: surface the
                    // decode failure instead of silently dropping it.
                    warn!("Failed to deserialize message {}: {}", message_id, e);
                    Err(PgmqNotifyError::Serialization(e))
                }
            }
        } else {
            debug!(
                "๐ญ Specific message {} not found in queue: {}",
                message_id, queue_name
            );
            Ok(None)
        }
    }
256
257 #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id))]
259 pub async fn delete_message(&self, queue_name: &str, message_id: i64) -> Result<()> {
260 debug!(
261 "๐Deleting message {} from queue: {}",
262 message_id, queue_name
263 );
264
265 self.pgmq.delete(queue_name, message_id).await?;
266
267 debug!("Message deleted: {}", message_id);
268 Ok(())
269 }
270
271 #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id))]
273 pub async fn archive_message(&self, queue_name: &str, message_id: i64) -> Result<()> {
274 debug!(
275 "Archiving message {} from queue: {}",
276 message_id, queue_name
277 );
278
279 self.pgmq.archive(queue_name, message_id).await?;
280
281 debug!("Message archived: {}", message_id);
282 Ok(())
283 }
284
285 #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id, vt_seconds = %vt_seconds))]
301 pub async fn set_visibility_timeout(
302 &self,
303 queue_name: &str,
304 message_id: i64,
305 vt_seconds: i32,
306 ) -> Result<()> {
307 debug!(
308 "Setting visibility timeout for message {} in queue {} to {} seconds",
309 message_id, queue_name, vt_seconds
310 );
311
312 sqlx::query_scalar!(
315 "SELECT msg_id FROM pgmq.set_vt($1::text, $2::bigint, $3::integer)",
316 queue_name,
317 message_id,
318 vt_seconds
319 )
320 .fetch_optional(&self.pool)
321 .await?;
322
323 debug!(
324 "Visibility timeout set for message {} to {} seconds",
325 message_id, vt_seconds
326 );
327 Ok(())
328 }
329
330 #[instrument(skip(self), fields(queue = %queue_name))]
332 pub async fn purge_queue(&self, queue_name: &str) -> Result<u64> {
333 warn!("๐งน Purging queue: {}", queue_name);
334
335 let purged_count = self.pgmq.purge(queue_name).await?;
336
337 warn!(
338 "๐Purged {} messages from queue: {}",
339 purged_count, queue_name
340 );
341 Ok(purged_count)
342 }
343
344 #[instrument(skip(self), fields(queue = %queue_name))]
346 pub async fn drop_queue(&self, queue_name: &str) -> Result<()> {
347 warn!("๐ฅ Dropping queue: {}", queue_name);
348
349 self.pgmq.destroy(queue_name).await?;
350
351 warn!("๐Queue dropped: {}", queue_name);
352 Ok(())
353 }
354
355 #[instrument(skip(self), fields(queue = %queue_name))]
357 pub async fn queue_metrics(&self, queue_name: &str) -> Result<QueueMetrics> {
358 debug!("Getting metrics for queue: {}", queue_name);
359
360 let row = sqlx::query!(
362 "SELECT queue_length, oldest_msg_age_sec FROM pgmq.metrics($1)",
363 queue_name
364 )
365 .fetch_optional(&self.pool)
366 .await?;
367
368 if let Some(row) = row {
369 Ok(QueueMetrics {
370 queue_name: queue_name.to_string(),
371 message_count: row.queue_length.unwrap_or(0),
372 consumer_count: None,
373 oldest_message_age_seconds: row.oldest_msg_age_sec.map(i64::from),
374 })
375 } else {
376 Ok(QueueMetrics {
378 queue_name: queue_name.to_string(),
379 message_count: 0,
380 consumer_count: None,
381 oldest_message_age_seconds: None,
382 })
383 }
384 }
385
    /// Shared Postgres pool, for callers that need raw SQL access.
    #[must_use]
    pub fn pool(&self) -> &sqlx::PgPool {
        &self.pool
    }

    /// Underlying `PGMQueue` handle.
    #[must_use]
    pub fn pgmq(&self) -> &PGMQueue {
        &self.pgmq
    }

    /// Notify configuration this client was built with.
    #[must_use]
    pub fn config(&self) -> &PgmqNotifyConfig {
        &self.config
    }

    /// Whether NOTIFY triggers are enabled in the configuration.
    #[must_use]
    pub fn has_notify_capabilities(&self) -> bool {
        self.config.enable_triggers
    }
410
411 #[instrument(skip(self))]
413 pub async fn health_check(&self) -> Result<bool> {
414 match sqlx::query!("SELECT 1 as health_check")
415 .fetch_one(&self.pool)
416 .await
417 {
418 Ok(_) => {
419 debug!("Health check passed");
420 Ok(true)
421 }
422 Err(e) => {
423 error!("Health check failed: {}", e);
424 Ok(false)
425 }
426 }
427 }
428
429 #[instrument(skip(self))]
431 pub async fn get_client_status(&self) -> Result<ClientStatus> {
432 let healthy = self.health_check().await.unwrap_or(false);
433
434 Ok(ClientStatus {
435 client_type: "pgmq-unified".to_string(),
436 connected: healthy,
437 connection_info: HashMap::from([
438 (
439 "backend".to_string(),
440 serde_json::Value::String("postgresql".to_string()),
441 ),
442 (
443 "queue_type".to_string(),
444 serde_json::Value::String("pgmq".to_string()),
445 ),
446 (
447 "has_notifications".to_string(),
448 serde_json::Value::Bool(true),
449 ),
450 (
451 "pool_size".to_string(),
452 serde_json::Value::Number(self.pool.size().into()),
453 ),
454 ]),
455 last_activity: Some(chrono::Utc::now()),
456 })
457 }
458
459 #[must_use]
461 pub fn extract_namespace(&self, queue_name: &str) -> Option<String> {
462 let pattern = &self.config.queue_naming_pattern;
463 if let Ok(regex) = Regex::new(pattern) {
464 if let Some(captures) = regex.captures(queue_name) {
465 if let Some(namespace_match) = captures.name("namespace") {
466 return Some(namespace_match.as_str().to_string());
467 }
468 }
469 }
470
471 queue_name
473 .ends_with("_queue")
474 .then(|| queue_name.trim_end_matches("_queue").to_string())
475 }
476
    /// Build a `PgmqNotifyListener` sharing this client's pool and config.
    /// `buffer_size` is forwarded to the listener — presumably the capacity
    /// of its internal notification channel; confirm in `PgmqNotifyListener`.
    pub async fn create_listener(&self, buffer_size: usize) -> Result<PgmqNotifyListener> {
        PgmqNotifyListener::new(self.pool.clone(), self.config.clone(), buffer_size).await
    }
490}
491
492impl PgmqClient {
494 #[instrument(skip(self), fields(namespace = %namespace, batch_size = %batch_size))]
496 pub async fn process_namespace_queue(
497 &self,
498 namespace: &str,
499 visibility_timeout: Option<i32>,
500 batch_size: i32,
501 ) -> Result<Vec<Message<serde_json::Value>>> {
502 let queue_name = format!("worker_{namespace}_queue");
503 self.read_messages(&queue_name, visibility_timeout, Some(batch_size))
504 .await
505 }
506
507 #[instrument(skip(self), fields(namespace = %namespace, message_id = %message_id))]
509 pub async fn complete_message(&self, namespace: &str, message_id: i64) -> Result<()> {
510 let queue_name = format!("worker_{namespace}_queue");
511 self.delete_message(&queue_name, message_id).await
512 }
513
514 #[instrument(skip(self, namespaces))]
516 pub async fn initialize_namespace_queues(&self, namespaces: &[&str]) -> Result<()> {
517 info!("Initializing {} namespace queues", namespaces.len());
518
519 for namespace in namespaces {
520 let queue_name = format!("worker_{namespace}_queue");
521 self.create_queue(&queue_name).await?;
522 }
523
524 info!("Initialized all namespace queues");
525 Ok(())
526 }
527
528 #[instrument(skip(self, message, tx), fields(queue = %queue_name))]
530 pub async fn send_with_transaction<T>(
531 &self,
532 queue_name: &str,
533 message: &T,
534 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
535 ) -> Result<i64>
536 where
537 T: serde::Serialize,
538 {
539 debug!(
540 "๐ค Sending message within transaction to queue: {}",
541 queue_name
542 );
543
544 let serialized = serde_json::to_value(message)?;
545
546 let message_id = sqlx::query_scalar!(
548 "SELECT pgmq_send_with_notify($1, $2, $3)",
549 queue_name,
550 &serialized,
551 0i32
552 )
553 .fetch_one(&mut **tx)
554 .await?;
555
556 let message_id = message_id.ok_or_else(|| {
557 PgmqNotifyError::Generic(anyhow::anyhow!("Wrapper function returned NULL message ID"))
558 })?;
559
560 debug!(
561 "Message sent in transaction with id: {} (with notification)",
562 message_id
563 );
564 Ok(message_id)
565 }
566}
567
/// Convenience constructors mirroring the [`PgmqClient`] `new*` methods.
#[derive(Debug)]
pub struct PgmqClientFactory;
571
impl PgmqClientFactory {
    /// Connect with the default configuration.
    pub async fn create(database_url: &str) -> Result<PgmqClient> {
        PgmqClient::new(database_url).await
    }

    /// Connect with an explicit notify configuration.
    pub async fn create_with_config(
        database_url: &str,
        config: PgmqNotifyConfig,
    ) -> Result<PgmqClient> {
        PgmqClient::new_with_config(database_url, config).await
    }

    /// Build a client over an existing pool with the default configuration.
    pub async fn create_with_pool(pool: PgPool) -> PgmqClient {
        PgmqClient::new_with_pool(pool).await
    }

    /// Build a client over an existing pool with an explicit configuration.
    pub async fn create_with_pool_and_config(pool: PgPool, config: PgmqNotifyConfig) -> PgmqClient {
        PgmqClient::new_with_pool_and_config(pool, config).await
    }
}
596
597pub use PgmqClient as PgmqNotifyClient;
599pub use PgmqClientFactory as PgmqNotifyClientFactory;
600
#[cfg(test)]
mod tests {
    use super::*;
    use dotenvy::dotenv;

    /// Smoke test: a client can be constructed against DATABASE_URL.
    /// Skips (instead of failing) when no database is configured or
    /// reachable, so the suite stays green in offline environments.
    #[tokio::test]
    async fn test_pgmq_client_creation() {
        dotenv().ok();
        if std::env::var("DATABASE_URL").is_err() {
            println!("Skipping pgmq test - no DATABASE_URL provided");
            return;
        }

        let database_url = std::env::var("DATABASE_URL").unwrap();
        match PgmqClient::new(&database_url).await {
            Ok(_) => {
                println!("PgmqClient created successfully");
            }
            Err(e) => {
                // Treat connection failures as a skip, not a test failure.
                println!(" Skipping test due to client creation error: {e:?}");
                return;
            }
        }
    }

    /// The configured naming pattern extracts the `namespace` capture
    /// group from `*_queue` names and rejects names without the suffix.
    /// Pure regex test; no database required.
    #[test]
    fn test_namespace_extraction() {
        dotenv().ok();
        let config = PgmqNotifyConfig::new().with_queue_naming_pattern(r"(?P<namespace>\w+)_queue");

        let pattern = &config.queue_naming_pattern;
        let regex = regex::Regex::new(pattern).unwrap();

        let captures = regex.captures("orders_queue").unwrap();
        let namespace = captures.name("namespace").unwrap().as_str();
        assert_eq!(namespace, "orders");

        let captures = regex.captures("inventory_queue").unwrap();
        let namespace = captures.name("namespace").unwrap().as_str();
        assert_eq!(namespace, "inventory");

        // No `_queue` suffix => the pattern should not match at all.
        assert!(regex.captures("invalid_name").is_none());
    }

    /// A client built via `new_with_pool` must share the caller's pool
    /// (same pool handle, hence identical size).
    #[tokio::test]
    async fn test_shared_pool_pattern() {
        dotenv().ok();
        if std::env::var("DATABASE_URL").is_err() {
            println!("Skipping shared pool test - no DATABASE_URL provided");
            return;
        }

        let database_url = std::env::var("DATABASE_URL").unwrap();

        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(5)
            .connect(&database_url)
            .await
            .expect("Failed to create connection pool");

        let client = PgmqClient::new_with_pool(pool.clone()).await;

        assert_eq!(client.pool().size(), pool.size());

        println!("Shared pool pattern working correctly");
    }
}