// tasker_pgmq/client.rs
1//! # Unified PGMQ Client
2//!
3//! This module provides a unified `PostgreSQL` Message Queue (PGMQ) client that combines
4//! all functionality from both the original `PgmqNotifyClient` and the tasker-shared `PgmqClient`.
5//! It includes both notification capabilities and comprehensive PGMQ operations.
6
7use crate::{
8    config::PgmqNotifyConfig,
9    error::{PgmqNotifyError, Result},
10    listener::PgmqNotifyListener,
11    types::{ClientStatus, QueueMetrics},
12};
13use pgmq::{types::Message, PGMQueue};
14use regex::Regex;
15use sqlx::{PgPool, Row};
16use std::collections::HashMap;
17use tracing::{debug, error, info, instrument, warn};
18
/// Unified PGMQ client with comprehensive functionality and notification capabilities
///
/// Combines plain PGMQ queue operations (create/send/read/delete/archive) with
/// notification-aware sends performed through the `pgmq_send_with_notify` SQL
/// wrapper. Derives `Clone` so a single client can be shared across tasks;
/// both inner handles are pool-backed (see `new_with_pool`).
#[derive(Debug, Clone)]
pub struct PgmqClient {
    /// Underlying PGMQ client (pgmq crate handle) used for standard queue ops
    pgmq: PGMQueue,
    /// Database connection pool for advanced operations (wrapper functions,
    /// metrics, specific-message reads) and health checks
    pool: sqlx::PgPool,
    /// Configuration for notifications and queue naming
    config: PgmqNotifyConfig,
}
29
30impl PgmqClient {
31    /// Create new unified PGMQ client using connection string
32    pub async fn new(database_url: &str) -> Result<Self> {
33        Self::new_with_config(database_url, PgmqNotifyConfig::default()).await
34    }
35
36    /// Create new unified PGMQ client with custom configuration
37    pub async fn new_with_config(database_url: &str, config: PgmqNotifyConfig) -> Result<Self> {
38        info!("Connecting to pgmq using unified client");
39
40        let pgmq = PGMQueue::new(database_url.to_string()).await?;
41        let pool = sqlx::postgres::PgPoolOptions::new()
42            .max_connections(20)
43            .connect(database_url)
44            .await?;
45
46        info!("Connected to pgmq using unified client");
47        Ok(Self { pgmq, pool, config })
48    }
49
50    /// Create new unified PGMQ client using existing connection pool (BYOP - Bring Your Own Pool)
51    pub async fn new_with_pool(pool: sqlx::PgPool) -> Self {
52        Self::new_with_pool_and_config(pool, PgmqNotifyConfig::default()).await
53    }
54
55    /// Create new unified PGMQ client with existing pool and custom configuration
56    pub async fn new_with_pool_and_config(pool: sqlx::PgPool, config: PgmqNotifyConfig) -> Self {
57        info!("Creating unified pgmq client with shared connection pool");
58
59        let pgmq = PGMQueue::new_with_pool(pool.clone()).await;
60
61        info!("Unified pgmq client created with shared pool");
62        Self { pgmq, pool, config }
63    }
64
    /// Create queue if it doesn't exist
    ///
    /// Delegates to the pgmq crate's `create`; all side effects happen in the
    /// database.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying pgmq create call fails
    /// (e.g. database unreachable).
    #[instrument(skip(self), fields(queue = %queue_name))]
    pub async fn create_queue(&self, queue_name: &str) -> Result<()> {
        debug!("๐Ÿ“‹ Creating queue: {}", queue_name);

        self.pgmq.create(queue_name).await?;

        info!("Queue created: {}", queue_name);
        Ok(())
    }
75
76    /// Send generic JSON message to queue
77    #[instrument(skip(self, message), fields(queue = %queue_name))]
78    pub async fn send_json_message<T>(&self, queue_name: &str, message: &T) -> Result<i64>
79    where
80        T: serde::Serialize,
81    {
82        debug!("๐Ÿ“ค Sending JSON message to queue: {}", queue_name);
83
84        let serialized = serde_json::to_value(message)?;
85
86        // Use wrapper function for atomic message sending + notification
87        let message_id = sqlx::query_scalar!(
88            "SELECT pgmq_send_with_notify($1, $2, $3)",
89            queue_name,
90            &serialized,
91            0i32
92        )
93        .fetch_one(&self.pool)
94        .await?;
95
96        let message_id = message_id.ok_or_else(|| {
97            PgmqNotifyError::Generic(anyhow::anyhow!("Wrapper function returned NULL message ID"))
98        })?;
99
100        info!(
101            "JSON message sent to queue: {} with ID: {} (with notification)",
102            queue_name, message_id
103        );
104        Ok(message_id)
105    }
106
107    /// Send message with visibility timeout (delay)
108    #[instrument(skip(self, message), fields(queue = %queue_name, delay_seconds = %delay_seconds))]
109    pub async fn send_message_with_delay<T>(
110        &self,
111        queue_name: &str,
112        message: &T,
113        delay_seconds: u64,
114    ) -> Result<i64>
115    where
116        T: serde::Serialize,
117    {
118        debug!(
119            "๐Ÿ“ค Sending delayed message to queue: {} with delay: {}s",
120            queue_name, delay_seconds
121        );
122
123        let serialized = serde_json::to_value(message)?;
124
125        // Use wrapper function for atomic message sending + notification
126        let message_id = sqlx::query_scalar!(
127            "SELECT pgmq_send_with_notify($1, $2, $3)",
128            queue_name,
129            &serialized,
130            delay_seconds as i32
131        )
132        .fetch_one(&self.pool)
133        .await?;
134
135        let message_id = message_id.ok_or_else(|| {
136            PgmqNotifyError::Generic(anyhow::anyhow!("Wrapper function returned NULL message ID"))
137        })?;
138
139        info!(
140            "Delayed message sent to queue: {} with ID: {} (with notification)",
141            queue_name, message_id
142        );
143        Ok(message_id)
144    }
145
146    /// Read messages from queue
147    #[instrument(skip(self), fields(queue = %queue_name, limit = ?limit))]
148    pub async fn read_messages(
149        &self,
150        queue_name: &str,
151        visibility_timeout: Option<i32>,
152        limit: Option<i32>,
153    ) -> Result<Vec<Message<serde_json::Value>>> {
154        debug!(
155            "๐Ÿ“ฅ Reading messages from queue: {} (limit: {:?})",
156            queue_name, limit
157        );
158
159        let messages = match limit {
160            Some(l) => self
161                .pgmq
162                .read_batch(queue_name, visibility_timeout, l)
163                .await?
164                .unwrap_or_default(),
165            None => match self.pgmq.read(queue_name, visibility_timeout).await? {
166                Some(msg) => vec![msg],
167                None => vec![],
168            },
169        };
170
171        debug!(
172            "๐Ÿ“จ Read {} messages from queue: {}",
173            messages.len(),
174            queue_name
175        );
176        Ok(messages)
177    }
178
179    /// Read messages from queue with pop (single read and delete)
180    #[instrument(skip(self), fields(queue = %queue_name))]
181    pub async fn pop_message(
182        &self,
183        queue_name: &str,
184    ) -> Result<Option<Message<serde_json::Value>>> {
185        debug!("๐Ÿ“ฅ Popping message from queue: {}", queue_name);
186
187        let message = self.pgmq.pop(queue_name).await?;
188
189        if message.is_some() {
190            debug!("๐Ÿ“จ Popped message from queue: {}", queue_name);
191        } else {
192            debug!("๐Ÿ“ญ No messages available in queue: {}", queue_name);
193        }
194        Ok(message)
195    }
196
197    /// Read a specific message by ID using custom SQL function (for notification event handling)
198    #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id))]
199    pub async fn read_specific_message<T>(
200        &self,
201        queue_name: &str,
202        message_id: i64,
203        visibility_timeout: i32,
204    ) -> Result<Option<Message<T>>>
205    where
206        T: serde::de::DeserializeOwned,
207    {
208        debug!(
209            "๐Ÿ“ฅ Reading specific message {} from queue: {}",
210            message_id, queue_name
211        );
212
213        // Use the custom SQL function pgmq_read_specific_message for efficient specific message reading
214        let query = "SELECT msg_id, read_ct, enqueued_at, vt, message FROM pgmq_read_specific_message($1, $2, $3)";
215
216        let row = sqlx::query(query)
217            .bind(queue_name)
218            .bind(message_id)
219            .bind(visibility_timeout)
220            .fetch_optional(&self.pool)
221            .await?;
222
223        if let Some(row) = row {
224            let msg_id: i64 = row.get("msg_id");
225            let read_ct: i32 = row.get("read_ct");
226            let enqueued_at: chrono::DateTime<chrono::Utc> = row.get("enqueued_at");
227            let vt: chrono::DateTime<chrono::Utc> = row.get("vt");
228            let message_json: serde_json::Value = row.get("message");
229
230            // Try to deserialize the message to the expected type
231            match serde_json::from_value::<T>(message_json) {
232                Ok(deserialized) => {
233                    let typed_message = Message {
234                        msg_id,
235                        read_ct,
236                        enqueued_at,
237                        vt,
238                        message: deserialized,
239                    };
240                    debug!("Found and deserialized specific message {}", message_id);
241                    Ok(Some(typed_message))
242                }
243                Err(e) => {
244                    warn!("Failed to deserialize message {}: {}", message_id, e);
245                    Err(PgmqNotifyError::Serialization(e))
246                }
247            }
248        } else {
249            debug!(
250                "๐Ÿ“ญ Specific message {} not found in queue: {}",
251                message_id, queue_name
252            );
253            Ok(None)
254        }
255    }
256
    /// Delete message from queue
    ///
    /// Permanently removes the message; use [`Self::archive_message`] instead
    /// if the message should be retained in the archive table.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying pgmq delete call fails.
    #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id))]
    pub async fn delete_message(&self, queue_name: &str, message_id: i64) -> Result<()> {
        debug!(
            "๐Ÿ—‘Deleting message {} from queue: {}",
            message_id, queue_name
        );

        self.pgmq.delete(queue_name, message_id).await?;

        debug!("Message deleted: {}", message_id);
        Ok(())
    }
270
    /// Archive message (move to archive)
    ///
    /// Moves the message out of the active queue into pgmq's archive, so it is
    /// no longer readable but not destroyed (unlike [`Self::delete_message`]).
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying pgmq archive call fails.
    #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id))]
    pub async fn archive_message(&self, queue_name: &str, message_id: i64) -> Result<()> {
        debug!(
            "Archiving message {} from queue: {}",
            message_id, queue_name
        );

        self.pgmq.archive(queue_name, message_id).await?;

        debug!("Message archived: {}", message_id);
        Ok(())
    }
284
285    /// Set visibility timeout for a message
286    ///
287    /// Extends or resets the visibility timeout for a message. This is useful for:
288    /// - Heartbeat during long-running step processing
289    /// - Returning a message to the queue immediately (vt_seconds = 0)
290    ///
291    /// # Arguments
292    ///
293    /// * `queue_name` - The queue containing the message
294    /// * `message_id` - The message ID to update
295    /// * `vt_seconds` - New visibility timeout in seconds from now
296    ///
297    /// # Returns
298    ///
299    /// Returns `Ok(())` on success, or an error if the message doesn't exist.
300    #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id, vt_seconds = %vt_seconds))]
301    pub async fn set_visibility_timeout(
302        &self,
303        queue_name: &str,
304        message_id: i64,
305        vt_seconds: i32,
306    ) -> Result<()> {
307        debug!(
308            "Setting visibility timeout for message {} in queue {} to {} seconds",
309            message_id, queue_name, vt_seconds
310        );
311
312        // Use pgmq.set_vt SQL function directly for precise control
313        // The function signature is: pgmq.set_vt(queue_name text, msg_id bigint, vt_offset integer)
314        sqlx::query_scalar!(
315            "SELECT msg_id FROM pgmq.set_vt($1::text, $2::bigint, $3::integer)",
316            queue_name,
317            message_id,
318            vt_seconds
319        )
320        .fetch_optional(&self.pool)
321        .await?;
322
323        debug!(
324            "Visibility timeout set for message {} to {} seconds",
325            message_id, vt_seconds
326        );
327        Ok(())
328    }
329
    /// Purge queue (delete all messages)
    ///
    /// Returns the number of messages removed. Logged at `warn` level because
    /// this is a destructive bulk operation.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying pgmq purge call fails.
    #[instrument(skip(self), fields(queue = %queue_name))]
    pub async fn purge_queue(&self, queue_name: &str) -> Result<u64> {
        warn!("๐Ÿงน Purging queue: {}", queue_name);

        let purged_count = self.pgmq.purge(queue_name).await?;

        warn!(
            "๐Ÿ—‘Purged {} messages from queue: {}",
            purged_count, queue_name
        );
        Ok(purged_count)
    }
343
    /// Drop queue completely
    ///
    /// Destroys the queue and everything in it. Logged at `warn` level because
    /// this is irreversible.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying pgmq destroy call fails.
    #[instrument(skip(self), fields(queue = %queue_name))]
    pub async fn drop_queue(&self, queue_name: &str) -> Result<()> {
        warn!("๐Ÿ’ฅ Dropping queue: {}", queue_name);

        self.pgmq.destroy(queue_name).await?;

        warn!("๐Ÿ—‘Queue dropped: {}", queue_name);
        Ok(())
    }
354
355    /// Get queue metrics/statistics
356    #[instrument(skip(self), fields(queue = %queue_name))]
357    pub async fn queue_metrics(&self, queue_name: &str) -> Result<QueueMetrics> {
358        debug!("Getting metrics for queue: {}", queue_name);
359
360        // Query actual pgmq metrics from the database using pgmq.metrics() function
361        let row = sqlx::query!(
362            "SELECT queue_length, oldest_msg_age_sec FROM pgmq.metrics($1)",
363            queue_name
364        )
365        .fetch_optional(&self.pool)
366        .await?;
367
368        if let Some(row) = row {
369            Ok(QueueMetrics {
370                queue_name: queue_name.to_string(),
371                message_count: row.queue_length.unwrap_or(0),
372                consumer_count: None,
373                oldest_message_age_seconds: row.oldest_msg_age_sec.map(i64::from),
374            })
375        } else {
376            // Queue doesn't exist or has no metrics
377            Ok(QueueMetrics {
378                queue_name: queue_name.to_string(),
379                message_count: 0,
380                consumer_count: None,
381                oldest_message_age_seconds: None,
382            })
383        }
384    }
385
    /// Get reference to underlying connection pool for advanced operations
    ///
    /// Useful for running custom SQL alongside queue operations on the same
    /// connections.
    #[must_use]
    pub fn pool(&self) -> &sqlx::PgPool {
        &self.pool
    }
391
    /// Get reference to pgmq client for direct access
    ///
    /// Escape hatch for pgmq-crate operations this wrapper does not expose.
    #[must_use]
    pub fn pgmq(&self) -> &PGMQueue {
        &self.pgmq
    }
397
    /// Get the configuration (notification and queue-naming settings)
    #[must_use]
    pub fn config(&self) -> &PgmqNotifyConfig {
        &self.config
    }
403
    /// Check if this client has notification capabilities enabled
    ///
    /// Notifications are driven by database triggers, so this simply reflects
    /// whether triggers are enabled in the configuration.
    #[must_use]
    pub fn has_notify_capabilities(&self) -> bool {
        // Check if the config has triggers enabled which enable notifications
        self.config.enable_triggers
    }
410
411    /// Health check - verify database connectivity
412    #[instrument(skip(self))]
413    pub async fn health_check(&self) -> Result<bool> {
414        match sqlx::query!("SELECT 1 as health_check")
415            .fetch_one(&self.pool)
416            .await
417        {
418            Ok(_) => {
419                debug!("Health check passed");
420                Ok(true)
421            }
422            Err(e) => {
423                error!("Health check failed: {}", e);
424                Ok(false)
425            }
426        }
427    }
428
429    /// Get client status information
430    #[instrument(skip(self))]
431    pub async fn get_client_status(&self) -> Result<ClientStatus> {
432        let healthy = self.health_check().await.unwrap_or(false);
433
434        Ok(ClientStatus {
435            client_type: "pgmq-unified".to_string(),
436            connected: healthy,
437            connection_info: HashMap::from([
438                (
439                    "backend".to_string(),
440                    serde_json::Value::String("postgresql".to_string()),
441                ),
442                (
443                    "queue_type".to_string(),
444                    serde_json::Value::String("pgmq".to_string()),
445                ),
446                (
447                    "has_notifications".to_string(),
448                    serde_json::Value::Bool(true),
449                ),
450                (
451                    "pool_size".to_string(),
452                    serde_json::Value::Number(self.pool.size().into()),
453                ),
454            ]),
455            last_activity: Some(chrono::Utc::now()),
456        })
457    }
458
459    /// Extract namespace from queue name using configured pattern
460    #[must_use]
461    pub fn extract_namespace(&self, queue_name: &str) -> Option<String> {
462        let pattern = &self.config.queue_naming_pattern;
463        if let Ok(regex) = Regex::new(pattern) {
464            if let Some(captures) = regex.captures(queue_name) {
465                if let Some(namespace_match) = captures.name("namespace") {
466                    return Some(namespace_match.as_str().to_string());
467                }
468            }
469        }
470
471        // Fallback: assume queue name is "{namespace}_queue"
472        queue_name
473            .ends_with("_queue")
474            .then(|| queue_name.trim_end_matches("_queue").to_string())
475    }
476
    /// Create a notify listener for this client
    ///
    /// The listener shares this client's pool and configuration.
    ///
    /// # Arguments
    /// * `buffer_size` - MPSC channel buffer size (TAS-51: bounded channels)
    ///
    /// # Errors
    ///
    /// Returns an error if the listener cannot be constructed.
    ///
    /// # Note
    /// TAS-51: Migrated from unbounded to bounded channel to prevent OOM during notification bursts.
    /// Buffer size should come from configuration based on context:
    /// - Orchestration: `config.mpsc_channels.orchestration.event_listeners.pgmq_event_buffer_size`
    /// - Worker: `config.mpsc_channels.worker.event_listeners.pgmq_event_buffer_size`
    pub async fn create_listener(&self, buffer_size: usize) -> Result<PgmqNotifyListener> {
        PgmqNotifyListener::new(self.pool.clone(), self.config.clone(), buffer_size).await
    }
490}
491
492/// Helper methods for common queue operations
493impl PgmqClient {
494    /// Process messages from namespace queue
495    #[instrument(skip(self), fields(namespace = %namespace, batch_size = %batch_size))]
496    pub async fn process_namespace_queue(
497        &self,
498        namespace: &str,
499        visibility_timeout: Option<i32>,
500        batch_size: i32,
501    ) -> Result<Vec<Message<serde_json::Value>>> {
502        let queue_name = format!("worker_{namespace}_queue");
503        self.read_messages(&queue_name, visibility_timeout, Some(batch_size))
504            .await
505    }
506
507    /// Complete message processing (delete from queue)
508    #[instrument(skip(self), fields(namespace = %namespace, message_id = %message_id))]
509    pub async fn complete_message(&self, namespace: &str, message_id: i64) -> Result<()> {
510        let queue_name = format!("worker_{namespace}_queue");
511        self.delete_message(&queue_name, message_id).await
512    }
513
514    /// Initialize standard namespace queues
515    #[instrument(skip(self, namespaces))]
516    pub async fn initialize_namespace_queues(&self, namespaces: &[&str]) -> Result<()> {
517        info!("Initializing {} namespace queues", namespaces.len());
518
519        for namespace in namespaces {
520            let queue_name = format!("worker_{namespace}_queue");
521            self.create_queue(&queue_name).await?;
522        }
523
524        info!("Initialized all namespace queues");
525        Ok(())
526    }
527
528    /// Send message within a transaction (for atomic operations)
529    #[instrument(skip(self, message, tx), fields(queue = %queue_name))]
530    pub async fn send_with_transaction<T>(
531        &self,
532        queue_name: &str,
533        message: &T,
534        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
535    ) -> Result<i64>
536    where
537        T: serde::Serialize,
538    {
539        debug!(
540            "๐Ÿ“ค Sending message within transaction to queue: {}",
541            queue_name
542        );
543
544        let serialized = serde_json::to_value(message)?;
545
546        // Use wrapper function within the transaction for atomic message sending + notification
547        let message_id = sqlx::query_scalar!(
548            "SELECT pgmq_send_with_notify($1, $2, $3)",
549            queue_name,
550            &serialized,
551            0i32
552        )
553        .fetch_one(&mut **tx)
554        .await?;
555
556        let message_id = message_id.ok_or_else(|| {
557            PgmqNotifyError::Generic(anyhow::anyhow!("Wrapper function returned NULL message ID"))
558        })?;
559
560        debug!(
561            "Message sent in transaction with id: {} (with notification)",
562            message_id
563        );
564        Ok(message_id)
565    }
566}
567
/// Factory for creating `PgmqClient` instances
///
/// Thin, stateless wrappers around the `PgmqClient` constructors; provided so
/// callers can depend on a factory type instead of the constructors directly.
#[derive(Debug)]
pub struct PgmqClientFactory;

impl PgmqClientFactory {
    /// Create new client from database URL
    ///
    /// # Errors
    ///
    /// Returns an error if connecting to the database fails.
    pub async fn create(database_url: &str) -> Result<PgmqClient> {
        PgmqClient::new(database_url).await
    }

    /// Create new client with configuration
    ///
    /// # Errors
    ///
    /// Returns an error if connecting to the database fails.
    pub async fn create_with_config(
        database_url: &str,
        config: PgmqNotifyConfig,
    ) -> Result<PgmqClient> {
        PgmqClient::new_with_config(database_url, config).await
    }

    /// Create new client with existing pool (infallible; reuses connections)
    pub async fn create_with_pool(pool: PgPool) -> PgmqClient {
        PgmqClient::new_with_pool(pool).await
    }

    /// Create new client with existing pool and configuration
    pub async fn create_with_pool_and_config(pool: PgPool, config: PgmqNotifyConfig) -> PgmqClient {
        PgmqClient::new_with_pool_and_config(pool, config).await
    }
}
596
597// Re-export for backward compatibility
598pub use PgmqClient as PgmqNotifyClient;
599pub use PgmqClientFactory as PgmqNotifyClientFactory;
600
#[cfg(test)]
mod tests {
    use super::*;
    use dotenvy::dotenv;

    #[tokio::test]
    async fn test_pgmq_client_creation() {
        dotenv().ok();
        // This test requires a PostgreSQL database with pgmq extension.
        // Skip in CI or when database is not available.
        // FIX: DATABASE_URL was previously read twice (is_err check + unwrap),
        // and both match arms ended with needless trailing `return`s.
        let Ok(database_url) = std::env::var("DATABASE_URL") else {
            println!("Skipping pgmq test - no DATABASE_URL provided");
            return;
        };

        match PgmqClient::new(&database_url).await {
            Ok(_) => println!("PgmqClient created successfully"),
            // Tolerate URL-parsing/connection errors so the test passes in
            // environments without a proper database setup.
            Err(e) => println!(" Skipping test due to client creation error: {e:?}"),
        }
    }

    #[test]
    fn test_namespace_extraction() {
        dotenv().ok();
        let config = PgmqNotifyConfig::new().with_queue_naming_pattern(r"(?P<namespace>\w+)_queue");

        // Test the pattern matching directly without needing a full client.
        let regex = regex::Regex::new(&config.queue_naming_pattern).unwrap();

        // Valid queue names capture the namespace prefix.
        let captures = regex.captures("orders_queue").unwrap();
        assert_eq!(captures.name("namespace").unwrap().as_str(), "orders");

        let captures = regex.captures("inventory_queue").unwrap();
        assert_eq!(captures.name("namespace").unwrap().as_str(), "inventory");

        // Names without the `_queue` suffix do not match.
        assert!(regex.captures("invalid_name").is_none());
    }

    #[tokio::test]
    async fn test_shared_pool_pattern() {
        dotenv().ok();
        // Skip test if no database URL provided.
        let Ok(database_url) = std::env::var("DATABASE_URL") else {
            println!("Skipping shared pool test - no DATABASE_URL provided");
            return;
        };

        // Create a connection pool.
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(5)
            .connect(&database_url)
            .await
            .expect("Failed to create connection pool");

        // Create pgmq client with shared pool.
        let client = PgmqClient::new_with_pool(pool.clone()).await;

        // Verify we can access the pool.
        assert_eq!(client.pool().size(), pool.size());

        println!("Shared pool pattern working correctly");
    }
}