1use crate::{
8 config::PgmqNotifyConfig,
9 error::{PgmqNotifyError, Result},
10 listener::PgmqNotifyListener,
11 types::{ClientStatus, QueueMetrics},
12};
13use pgmq::{types::Message, PGMQueueExt};
14use regex::Regex;
15use sqlx::{PgPool, Row};
16use std::collections::HashMap;
17use tracing::{debug, error, info, instrument, warn};
18
/// Unified pgmq client: wraps the pgmq extension API and a shared connection
/// pool, plus the notify configuration used by the NOTIFY-aware send wrappers.
#[derive(Debug, Clone)]
pub struct PgmqClient {
    // Extension-backed pgmq handle used for queue lifecycle and read/pop/ack.
    pgmq: PGMQueueExt,
    // Shared Postgres pool; cloned from `pgmq.connection` or supplied by the
    // caller, and used for the raw SQL this client issues directly.
    pool: sqlx::PgPool,
    // Notification/trigger configuration (queue naming pattern, trigger flag).
    config: PgmqNotifyConfig,
}
29
30impl PgmqClient {
31 pub async fn new(database_url: &str) -> Result<Self> {
33 Self::new_with_config(database_url, PgmqNotifyConfig::default()).await
34 }
35
36 pub async fn new_with_config(database_url: &str, config: PgmqNotifyConfig) -> Result<Self> {
38 info!("Connecting to pgmq using unified client");
39
40 let max_connections = 20;
41 let pgmq = PGMQueueExt::new(database_url.to_string(), max_connections).await?;
42 let pool = pgmq.connection.clone();
43
44 info!("Connected to pgmq using unified client");
45 Ok(Self { pgmq, pool, config })
46 }
47
48 pub async fn new_with_pool(pool: sqlx::PgPool) -> Self {
50 Self::new_with_pool_and_config(pool, PgmqNotifyConfig::default()).await
51 }
52
53 pub async fn new_with_pool_and_config(pool: sqlx::PgPool, config: PgmqNotifyConfig) -> Self {
55 info!("Creating unified pgmq client with shared connection pool");
56
57 let pgmq = PGMQueueExt::new_with_pool(pool.clone()).await;
58
59 info!("Unified pgmq client created with shared pool");
60 Self { pgmq, pool, config }
61 }
62
63 #[instrument(skip(self), fields(queue = %queue_name))]
65 pub async fn create_queue(&self, queue_name: &str) -> Result<()> {
66 debug!("๐ Creating queue: {}", queue_name);
67
68 let created = self.pgmq.create(queue_name).await?;
69
70 if created {
71 info!("Queue created: {}", queue_name);
72 } else {
73 debug!("Queue already exists: {}", queue_name);
74 }
75 Ok(())
76 }
77
78 #[instrument(skip(self, message), fields(queue = %queue_name))]
80 pub async fn send_json_message<T>(&self, queue_name: &str, message: &T) -> Result<i64>
81 where
82 T: serde::Serialize,
83 {
84 debug!("๐ค Sending JSON message to queue: {}", queue_name);
85
86 let serialized = serde_json::to_value(message)?;
87
88 let message_id = sqlx::query_scalar!(
90 "SELECT pgmq_send_with_notify($1, $2, $3)",
91 queue_name,
92 &serialized,
93 0i32
94 )
95 .fetch_one(&self.pool)
96 .await?;
97
98 let message_id = message_id.ok_or_else(|| {
99 PgmqNotifyError::Generic(anyhow::anyhow!("Wrapper function returned NULL message ID"))
100 })?;
101
102 info!(
103 "JSON message sent to queue: {} with ID: {} (with notification)",
104 queue_name, message_id
105 );
106 Ok(message_id)
107 }
108
109 #[instrument(skip(self, message), fields(queue = %queue_name, delay_seconds = %delay_seconds))]
111 pub async fn send_message_with_delay<T>(
112 &self,
113 queue_name: &str,
114 message: &T,
115 delay_seconds: u64,
116 ) -> Result<i64>
117 where
118 T: serde::Serialize,
119 {
120 debug!(
121 "๐ค Sending delayed message to queue: {} with delay: {}s",
122 queue_name, delay_seconds
123 );
124
125 let serialized = serde_json::to_value(message)?;
126
127 let message_id = sqlx::query_scalar!(
129 "SELECT pgmq_send_with_notify($1, $2, $3)",
130 queue_name,
131 &serialized,
132 delay_seconds as i32
133 )
134 .fetch_one(&self.pool)
135 .await?;
136
137 let message_id = message_id.ok_or_else(|| {
138 PgmqNotifyError::Generic(anyhow::anyhow!("Wrapper function returned NULL message ID"))
139 })?;
140
141 info!(
142 "Delayed message sent to queue: {} with ID: {} (with notification)",
143 queue_name, message_id
144 );
145 Ok(message_id)
146 }
147
148 #[instrument(skip(self), fields(queue = %queue_name, limit = ?limit))]
154 pub async fn read_messages(
155 &self,
156 queue_name: &str,
157 visibility_timeout: Option<i32>,
158 limit: Option<i32>,
159 ) -> Result<Vec<Message<serde_json::Value>>> {
160 debug!(
161 "๐ฅ Reading messages from queue: {} (limit: {:?})",
162 queue_name, limit
163 );
164
165 let vt = visibility_timeout.unwrap_or(30);
166 let messages = match limit {
167 Some(l) => self
168 .pgmq
169 .read_batch_with_poll::<serde_json::Value>(queue_name, vt, l, None, None)
170 .await?
171 .unwrap_or_default(),
172 None => match self.pgmq.read::<serde_json::Value>(queue_name, vt).await? {
173 Some(msg) => vec![msg],
174 None => vec![],
175 },
176 };
177
178 debug!(
179 "๐จ Read {} messages from queue: {}",
180 messages.len(),
181 queue_name
182 );
183 Ok(messages)
184 }
185
186 #[instrument(skip(self), fields(queue = %queue_name))]
188 pub async fn pop_message(
189 &self,
190 queue_name: &str,
191 ) -> Result<Option<Message<serde_json::Value>>> {
192 debug!("๐ฅ Popping message from queue: {}", queue_name);
193
194 let message = self.pgmq.pop::<serde_json::Value>(queue_name).await?;
195
196 if message.is_some() {
197 debug!("๐จ Popped message from queue: {}", queue_name);
198 } else {
199 debug!("๐ญ No messages available in queue: {}", queue_name);
200 }
201 Ok(message)
202 }
203
204 #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id))]
206 pub async fn read_specific_message<T>(
207 &self,
208 queue_name: &str,
209 message_id: i64,
210 visibility_timeout: i32,
211 ) -> Result<Option<Message<T>>>
212 where
213 T: serde::de::DeserializeOwned,
214 {
215 debug!(
216 "๐ฅ Reading specific message {} from queue: {}",
217 message_id, queue_name
218 );
219
220 let query = "SELECT msg_id, read_ct, enqueued_at, vt, message FROM pgmq_read_specific_message($1, $2, $3)";
222
223 let row = sqlx::query(query)
224 .bind(queue_name)
225 .bind(message_id)
226 .bind(visibility_timeout)
227 .fetch_optional(&self.pool)
228 .await?;
229
230 if let Some(row) = row {
231 let msg_id: i64 = row.get("msg_id");
232 let read_ct: i32 = row.get("read_ct");
233 let enqueued_at: chrono::DateTime<chrono::Utc> = row.get("enqueued_at");
234 let vt: chrono::DateTime<chrono::Utc> = row.get("vt");
235 let message_json: serde_json::Value = row.get("message");
236
237 match serde_json::from_value::<T>(message_json) {
239 Ok(deserialized) => {
240 let typed_message = Message {
241 msg_id,
242 read_ct,
243 enqueued_at,
244 vt,
245 message: deserialized,
246 };
247 debug!("Found and deserialized specific message {}", message_id);
248 Ok(Some(typed_message))
249 }
250 Err(e) => {
251 warn!("Failed to deserialize message {}: {}", message_id, e);
252 Err(PgmqNotifyError::Serialization(e))
253 }
254 }
255 } else {
256 debug!(
257 "๐ญ Specific message {} not found in queue: {}",
258 message_id, queue_name
259 );
260 Ok(None)
261 }
262 }
263
264 #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id))]
266 pub async fn delete_message(&self, queue_name: &str, message_id: i64) -> Result<()> {
267 debug!(
268 "๐Deleting message {} from queue: {}",
269 message_id, queue_name
270 );
271
272 let _ = self.pgmq.delete(queue_name, message_id).await?;
273
274 debug!("Message deleted: {}", message_id);
275 Ok(())
276 }
277
278 #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id))]
280 pub async fn archive_message(&self, queue_name: &str, message_id: i64) -> Result<()> {
281 debug!(
282 "Archiving message {} from queue: {}",
283 message_id, queue_name
284 );
285
286 let _ = self.pgmq.archive(queue_name, message_id).await?;
287
288 debug!("Message archived: {}", message_id);
289 Ok(())
290 }
291
292 #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id, vt_seconds = %vt_seconds))]
308 pub async fn set_visibility_timeout(
309 &self,
310 queue_name: &str,
311 message_id: i64,
312 vt_seconds: i32,
313 ) -> Result<()> {
314 debug!(
315 "Setting visibility timeout for message {} in queue {} to {} seconds",
316 message_id, queue_name, vt_seconds
317 );
318
319 sqlx::query_scalar!(
322 "SELECT msg_id FROM pgmq.set_vt($1::text, $2::bigint, $3::integer)",
323 queue_name,
324 message_id,
325 vt_seconds
326 )
327 .fetch_optional(&self.pool)
328 .await?;
329
330 debug!(
331 "Visibility timeout set for message {} to {} seconds",
332 message_id, vt_seconds
333 );
334 Ok(())
335 }
336
337 #[instrument(skip(self), fields(queue = %queue_name))]
339 pub async fn purge_queue(&self, queue_name: &str) -> Result<u64> {
340 warn!("๐งน Purging queue: {}", queue_name);
341
342 let purged_count = self.pgmq.purge_queue(queue_name).await?;
343
344 warn!(
345 "๐Purged {} messages from queue: {}",
346 purged_count, queue_name
347 );
348 Ok(purged_count as u64)
349 }
350
351 #[instrument(skip(self), fields(queue = %queue_name))]
353 pub async fn drop_queue(&self, queue_name: &str) -> Result<()> {
354 warn!("๐ฅ Dropping queue: {}", queue_name);
355
356 self.pgmq.drop_queue(queue_name).await?;
357
358 warn!("๐Queue dropped: {}", queue_name);
359 Ok(())
360 }
361
362 #[instrument(skip(self), fields(queue = %queue_name))]
364 pub async fn queue_metrics(&self, queue_name: &str) -> Result<QueueMetrics> {
365 debug!("Getting metrics for queue: {}", queue_name);
366
367 let row = sqlx::query!(
369 "SELECT queue_length, oldest_msg_age_sec FROM pgmq.metrics($1)",
370 queue_name
371 )
372 .fetch_optional(&self.pool)
373 .await?;
374
375 if let Some(row) = row {
376 Ok(QueueMetrics {
377 queue_name: queue_name.to_string(),
378 message_count: row.queue_length.unwrap_or(0),
379 consumer_count: None,
380 oldest_message_age_seconds: row.oldest_msg_age_sec.map(i64::from),
381 })
382 } else {
383 Ok(QueueMetrics {
385 queue_name: queue_name.to_string(),
386 message_count: 0,
387 consumer_count: None,
388 oldest_message_age_seconds: None,
389 })
390 }
391 }
392
393 #[must_use]
395 pub fn pool(&self) -> &sqlx::PgPool {
396 &self.pool
397 }
398
399 #[must_use]
401 pub fn pgmq(&self) -> &PGMQueueExt {
402 &self.pgmq
403 }
404
405 #[must_use]
407 pub fn config(&self) -> &PgmqNotifyConfig {
408 &self.config
409 }
410
411 #[must_use]
413 pub fn has_notify_capabilities(&self) -> bool {
414 self.config.enable_triggers
416 }
417
418 #[instrument(skip(self))]
420 pub async fn health_check(&self) -> Result<bool> {
421 match sqlx::query!("SELECT 1 as health_check")
422 .fetch_one(&self.pool)
423 .await
424 {
425 Ok(_) => {
426 debug!("Health check passed");
427 Ok(true)
428 }
429 Err(e) => {
430 error!("Health check failed: {}", e);
431 Ok(false)
432 }
433 }
434 }
435
436 #[instrument(skip(self))]
438 pub async fn get_client_status(&self) -> Result<ClientStatus> {
439 let healthy = self.health_check().await.unwrap_or(false);
440
441 Ok(ClientStatus {
442 client_type: "pgmq-unified".to_string(),
443 connected: healthy,
444 connection_info: HashMap::from([
445 (
446 "backend".to_string(),
447 serde_json::Value::String("postgresql".to_string()),
448 ),
449 (
450 "queue_type".to_string(),
451 serde_json::Value::String("pgmq".to_string()),
452 ),
453 (
454 "has_notifications".to_string(),
455 serde_json::Value::Bool(true),
456 ),
457 (
458 "pool_size".to_string(),
459 serde_json::Value::Number(self.pool.size().into()),
460 ),
461 ]),
462 last_activity: Some(chrono::Utc::now()),
463 })
464 }
465
466 #[must_use]
468 pub fn extract_namespace(&self, queue_name: &str) -> Option<String> {
469 let pattern = &self.config.queue_naming_pattern;
470 if let Ok(regex) = Regex::new(pattern) {
471 if let Some(captures) = regex.captures(queue_name) {
472 if let Some(namespace_match) = captures.name("namespace") {
473 return Some(namespace_match.as_str().to_string());
474 }
475 }
476 }
477
478 queue_name
480 .ends_with("_queue")
481 .then(|| queue_name.trim_end_matches("_queue").to_string())
482 }
483
484 pub async fn create_listener(&self, buffer_size: usize) -> Result<PgmqNotifyListener> {
495 PgmqNotifyListener::new(self.pool.clone(), self.config.clone(), buffer_size).await
496 }
497}
498
499impl PgmqClient {
501 #[instrument(skip(self), fields(namespace = %namespace, batch_size = %batch_size))]
503 pub async fn process_namespace_queue(
504 &self,
505 namespace: &str,
506 visibility_timeout: Option<i32>,
507 batch_size: i32,
508 ) -> Result<Vec<Message<serde_json::Value>>> {
509 let queue_name = format!("worker_{namespace}_queue");
510 self.read_messages(&queue_name, visibility_timeout, Some(batch_size))
511 .await
512 }
513
514 #[instrument(skip(self), fields(namespace = %namespace, message_id = %message_id))]
516 pub async fn complete_message(&self, namespace: &str, message_id: i64) -> Result<()> {
517 let queue_name = format!("worker_{namespace}_queue");
518 self.delete_message(&queue_name, message_id).await
519 }
520
521 #[instrument(skip(self, namespaces))]
523 pub async fn initialize_namespace_queues(&self, namespaces: &[&str]) -> Result<()> {
524 info!("Initializing {} namespace queues", namespaces.len());
525
526 for namespace in namespaces {
527 let queue_name = format!("worker_{namespace}_queue");
528 self.create_queue(&queue_name).await?;
529 }
530
531 info!("Initialized all namespace queues");
532 Ok(())
533 }
534
535 #[instrument(skip(self, message, tx), fields(queue = %queue_name))]
537 pub async fn send_with_transaction<T>(
538 &self,
539 queue_name: &str,
540 message: &T,
541 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
542 ) -> Result<i64>
543 where
544 T: serde::Serialize,
545 {
546 debug!(
547 "๐ค Sending message within transaction to queue: {}",
548 queue_name
549 );
550
551 let serialized = serde_json::to_value(message)?;
552
553 let message_id = sqlx::query_scalar!(
555 "SELECT pgmq_send_with_notify($1, $2, $3)",
556 queue_name,
557 &serialized,
558 0i32
559 )
560 .fetch_one(&mut **tx)
561 .await?;
562
563 let message_id = message_id.ok_or_else(|| {
564 PgmqNotifyError::Generic(anyhow::anyhow!("Wrapper function returned NULL message ID"))
565 })?;
566
567 debug!(
568 "Message sent in transaction with id: {} (with notification)",
569 message_id
570 );
571 Ok(message_id)
572 }
573}
574
/// Stateless factory offering free functions for the various `PgmqClient`
/// construction paths (URL vs. shared pool, default vs. explicit config).
#[derive(Debug)]
pub struct PgmqClientFactory;
578
579impl PgmqClientFactory {
580 pub async fn create(database_url: &str) -> Result<PgmqClient> {
582 PgmqClient::new(database_url).await
583 }
584
585 pub async fn create_with_config(
587 database_url: &str,
588 config: PgmqNotifyConfig,
589 ) -> Result<PgmqClient> {
590 PgmqClient::new_with_config(database_url, config).await
591 }
592
593 pub async fn create_with_pool(pool: PgPool) -> PgmqClient {
595 PgmqClient::new_with_pool(pool).await
596 }
597
598 pub async fn create_with_pool_and_config(pool: PgPool, config: PgmqNotifyConfig) -> PgmqClient {
600 PgmqClient::new_with_pool_and_config(pool, config).await
601 }
602}
603
// Backwards-compatible aliases kept for callers using the pre-unification names.
pub use PgmqClient as PgmqNotifyClient;
pub use PgmqClientFactory as PgmqNotifyClientFactory;
607
#[cfg(test)]
mod tests {
    use super::*;
    use dotenvy::dotenv;

    /// Integration smoke test: only runs when DATABASE_URL is configured.
    #[tokio::test]
    async fn test_pgmq_client_creation() {
        dotenv().ok();
        let database_url = match std::env::var("DATABASE_URL") {
            Ok(url) => url,
            Err(_) => {
                println!("Skipping pgmq test - no DATABASE_URL provided");
                return;
            }
        };

        if let Err(e) = PgmqClient::new(&database_url).await {
            println!(" Skipping test due to client creation error: {e:?}");
        } else {
            println!("PgmqClient created successfully");
        }
    }

    /// Pure test of the configured queue-naming regex: the `namespace`
    /// capture group must extract the prefix before `_queue`.
    #[test]
    fn test_namespace_extraction() {
        dotenv().ok();
        let config = PgmqNotifyConfig::new().with_queue_naming_pattern(r"(?P<namespace>\w+)_queue");

        let regex = regex::Regex::new(&config.queue_naming_pattern).unwrap();

        for (queue, expected) in [("orders_queue", "orders"), ("inventory_queue", "inventory")] {
            let captures = regex.captures(queue).unwrap();
            assert_eq!(captures.name("namespace").unwrap().as_str(), expected);
        }

        // Names without the `_queue` suffix must not match at all.
        assert!(regex.captures("invalid_name").is_none());
    }

    /// Integration test: a client built from a caller-supplied pool must
    /// share that pool. Skipped when DATABASE_URL is not configured.
    #[tokio::test]
    async fn test_shared_pool_pattern() {
        dotenv().ok();
        let database_url = match std::env::var("DATABASE_URL") {
            Ok(url) => url,
            Err(_) => {
                println!("Skipping shared pool test - no DATABASE_URL provided");
                return;
            }
        };

        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(5)
            .connect(&database_url)
            .await
            .expect("Failed to create connection pool");

        let client = PgmqClient::new_with_pool(pool.clone()).await;

        assert_eq!(client.pool().size(), pool.size());

        println!("Shared pool pattern working correctly");
    }
}