rust_rabbit/
consumer.rs

1use crate::{
2    connection::ConnectionManager,
3    error::{RabbitError, Result},
4    metrics::RustRabbitMetrics,
5    publisher::CustomQueueDeclareOptions,
6    retry::RetryPolicy,
7};
8use async_trait::async_trait;
9use futures::StreamExt;
10use lapin::{
11    message::Delivery,
12    options::{
13        BasicAckOptions, BasicConsumeOptions, BasicNackOptions,
14        QueueDeclareOptions as LapinQueueDeclareOptions,
15    },
16    types::FieldTable,
17    Channel,
18};
19use serde::de::DeserializeOwned;
20use std::sync::Arc;
21use tokio::sync::Semaphore;
22use tracing::{debug, error, info, warn};
23
/// Message handler trait for processing consumed messages
///
/// Implementors receive each JSON-deserialized message together with its
/// [`MessageContext`] and report how the delivery should be settled.
#[async_trait]
pub trait MessageHandler<T>: Send + Sync + 'static
where
    T: DeserializeOwned + Send + Sync,
{
    /// Handle a received message
    ///
    /// The returned [`MessageResult`] tells the consumer whether to ack,
    /// retry, reject, or requeue the delivery.
    async fn handle(&self, message: T, context: MessageContext) -> MessageResult;
}
33
/// Context information for a received message
#[derive(Debug, Clone)]
pub struct MessageContext {
    /// AMQP `message-id` property, if the publisher set one
    pub message_id: Option<String>,
    /// AMQP `correlation-id` property, if set
    pub correlation_id: Option<String>,
    /// AMQP `reply-to` property (e.g. for RPC-style responses), if set
    pub reply_to: Option<String>,
    /// Broker-assigned delivery tag; used for ack/nack of this delivery
    pub delivery_tag: u64,
    /// True when the broker flagged this delivery as redelivered
    pub redelivered: bool,
    /// Exchange the message was published to
    pub exchange: String,
    /// Routing key the message was published with
    pub routing_key: String,
    /// Message headers (an empty table when the message carried none)
    pub headers: FieldTable,
    /// AMQP timestamp property, if set
    pub timestamp: Option<u64>,
    /// Retry count read from the `x-retry-count` header (0 when absent)
    pub retry_count: u32,
}
48
/// Result of message processing
///
/// Returned by [`MessageHandler::handle`] to tell the consumer how the
/// delivery should be settled. Derives `Copy`/`PartialEq`/`Eq` so callers
/// and tests can compare outcomes directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageResult {
    /// Message processed successfully
    Ack,
    /// Message processing failed, should be retried
    Retry,
    /// Message processing failed permanently, should be rejected
    Reject,
    /// Message processing failed, should be requeued
    Requeue,
}
61
/// Consumer options
///
/// Construct via [`ConsumerOptions::builder`]; `Default` mirrors the
/// builder's starting values but leaves `queue_name` empty.
#[derive(Debug, Clone)]
pub struct ConsumerOptions {
    /// Queue name to consume from
    pub queue_name: String,

    /// Consumer tag (optional; a unique tag is generated when `None`)
    pub consumer_tag: Option<String>,

    /// Number of concurrent message processors
    pub concurrency: usize,

    /// Prefetch count (QoS); `None` means no prefetch limit is applied
    pub prefetch_count: Option<u16>,

    /// Auto-declare queue before consuming
    pub auto_declare_queue: bool,

    /// Queue declaration options
    pub queue_options: CustomQueueDeclareOptions,

    /// Retry policy for failed messages
    pub retry_policy: Option<RetryPolicy>,

    /// Dead letter exchange for failed messages
    pub dead_letter_exchange: Option<String>,

    /// Auto-ack messages (not recommended for production)
    pub auto_ack: bool,

    /// Consumer exclusive mode
    pub exclusive: bool,

    /// Consumer arguments (passed to `basic_consume`)
    pub arguments: FieldTable,
}
98
99impl ConsumerOptions {
100    /// Create a new consumer options builder
101    pub fn builder<S: Into<String>>(queue_name: S) -> ConsumerOptionsBuilder {
102        ConsumerOptionsBuilder::new(queue_name.into())
103    }
104}
105
/// Builder for ConsumerOptions
///
/// Fields mirror [`ConsumerOptions`] one-to-one; see that struct for the
/// meaning of each field. Values are finalized by `build()`.
#[derive(Debug, Clone)]
pub struct ConsumerOptionsBuilder {
    queue_name: String,
    consumer_tag: Option<String>,
    concurrency: usize,
    prefetch_count: Option<u16>,
    auto_declare_queue: bool,
    queue_options: CustomQueueDeclareOptions,
    retry_policy: Option<RetryPolicy>,
    dead_letter_exchange: Option<String>,
    auto_ack: bool,
    exclusive: bool,
    arguments: FieldTable,
}
121
122impl ConsumerOptionsBuilder {
123    /// Create a new builder with default values
124    pub fn new(queue_name: String) -> Self {
125        Self {
126            queue_name,
127            consumer_tag: None,
128            concurrency: 1,
129            prefetch_count: Some(10),
130            auto_declare_queue: false,
131            queue_options: CustomQueueDeclareOptions::default(),
132            retry_policy: None,
133            dead_letter_exchange: None,
134            auto_ack: false,
135            exclusive: false,
136            arguments: FieldTable::default(),
137        }
138    }
139
140    /// Set consumer tag
141    pub fn consumer_tag<S: Into<String>>(mut self, tag: S) -> Self {
142        self.consumer_tag = Some(tag.into());
143        self
144    }
145
146    /// Set concurrency level
147    pub fn concurrency(mut self, concurrency: usize) -> Self {
148        self.concurrency = concurrency;
149        self
150    }
151
152    /// Set prefetch count
153    pub fn prefetch_count(mut self, count: u16) -> Self {
154        self.prefetch_count = Some(count);
155        self
156    }
157
158    /// Disable prefetch limit
159    pub fn no_prefetch_limit(mut self) -> Self {
160        self.prefetch_count = None;
161        self
162    }
163
164    /// Enable auto-declare queue
165    pub fn auto_declare_queue(mut self) -> Self {
166        self.auto_declare_queue = true;
167        self
168    }
169
170    /// Set queue options
171    pub fn queue_options(mut self, options: CustomQueueDeclareOptions) -> Self {
172        self.queue_options = options;
173        self
174    }
175
176    /// Set retry policy
177    pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
178        self.retry_policy = Some(policy);
179        self
180    }
181
182    /// Set dead letter exchange
183    pub fn dead_letter_exchange<S: Into<String>>(mut self, exchange: S) -> Self {
184        self.dead_letter_exchange = Some(exchange.into());
185        self
186    }
187
188    /// Enable auto-ack (not recommended for production)
189    pub fn auto_ack(mut self) -> Self {
190        self.auto_ack = true;
191        self
192    }
193
194    /// Enable manual ack (recommended for production)
195    pub fn manual_ack(mut self) -> Self {
196        self.auto_ack = false;
197        self
198    }
199
200    /// Enable exclusive mode
201    pub fn exclusive(mut self) -> Self {
202        self.exclusive = true;
203        self
204    }
205
206    /// Configure for high throughput
207    pub fn high_throughput(mut self) -> Self {
208        self.concurrency = 20;
209        self.prefetch_count = Some(50);
210        self.auto_ack = false;
211        self
212    }
213
214    /// Configure for reliability (lower throughput but safer)
215    pub fn reliable(mut self) -> Self {
216        self.concurrency = 1;
217        self.prefetch_count = Some(1);
218        self.auto_ack = false;
219        self
220    }
221
222    /// Configure for development (simpler settings)
223    pub fn development(mut self) -> Self {
224        self.concurrency = 1;
225        self.prefetch_count = Some(1);
226        self.auto_ack = true;
227        self.auto_declare_queue = true;
228        self
229    }
230
231    /// Build the final configuration
232    pub fn build(self) -> ConsumerOptions {
233        ConsumerOptions {
234            queue_name: self.queue_name,
235            consumer_tag: self.consumer_tag,
236            concurrency: self.concurrency,
237            prefetch_count: self.prefetch_count,
238            auto_declare_queue: self.auto_declare_queue,
239            queue_options: self.queue_options,
240            retry_policy: self.retry_policy,
241            dead_letter_exchange: self.dead_letter_exchange,
242            auto_ack: self.auto_ack,
243            exclusive: self.exclusive,
244            arguments: self.arguments,
245        }
246    }
247}
248
impl Default for ConsumerOptions {
    /// Defaults match `ConsumerOptionsBuilder::new`, except `queue_name`
    /// is empty and must be set before the options are usable.
    // NOTE(review): these values duplicate the builder's defaults — keep
    // the two in sync when changing either.
    fn default() -> Self {
        Self {
            queue_name: String::new(),
            consumer_tag: None,
            concurrency: 1,
            prefetch_count: Some(10),
            auto_declare_queue: false,
            queue_options: CustomQueueDeclareOptions::default(),
            retry_policy: None,
            dead_letter_exchange: None,
            auto_ack: false,
            exclusive: false,
            arguments: FieldTable::default(),
        }
    }
}
266
/// Consumer for receiving messages from RabbitMQ
pub struct Consumer {
    #[allow(dead_code)] // Will be used for connection health monitoring
    connection_manager: ConnectionManager,
    /// Options this consumer was created with
    options: ConsumerOptions,
    /// Dedicated channel for this consumer (QoS applied at construction)
    channel: Channel,
    /// Bounds the number of in-flight handler tasks to `options.concurrency`
    semaphore: Arc<Semaphore>,
    /// Optional metrics recorder, set via `set_metrics`
    metrics: Option<RustRabbitMetrics>,
}
276
277impl Consumer {
278    /// Create a new consumer
279    pub async fn new(
280        connection_manager: ConnectionManager,
281        options: ConsumerOptions,
282    ) -> Result<Self> {
283        let connection = connection_manager.get_connection().await?;
284        let channel = connection.create_channel().await?;
285
286        // Set QoS if prefetch_count is specified
287        if let Some(prefetch_count) = options.prefetch_count {
288            channel
289                .basic_qos(prefetch_count, lapin::options::BasicQosOptions::default())
290                .await?;
291        }
292
293        // Declare queue if auto_declare is enabled
294        if options.auto_declare_queue {
295            Self::declare_queue(&channel, &options).await?;
296        }
297
298        let semaphore = Arc::new(Semaphore::new(options.concurrency));
299
300        Ok(Self {
301            connection_manager,
302            options,
303            channel,
304            semaphore,
305            metrics: None,
306        })
307    }
308
309    /// Set metrics for this consumer
310    pub fn set_metrics(&mut self, metrics: RustRabbitMetrics) {
311        self.metrics = Some(metrics);
312    }
313
    /// Start consuming messages with the given handler
    ///
    /// Registers a consumer on the configured queue and processes each
    /// delivery in a spawned task, bounded by `options.concurrency` via the
    /// semaphore. Returns `Ok(())` when the delivery stream ends (e.g. the
    /// channel closes) or an `Err` if the stream yields an error.
    pub async fn consume<T, H>(&self, handler: Arc<H>) -> Result<()>
    where
        T: DeserializeOwned + Send + Sync + 'static,
        H: MessageHandler<T>,
    {
        // Use the configured tag, or generate a unique one.
        let consumer_tag = self
            .options
            .consumer_tag
            .clone()
            .unwrap_or_else(|| format!("rust-rabbit-{}", uuid::Uuid::new_v4()));

        let consume_options = BasicConsumeOptions {
            no_local: false,
            // auto_ack maps to AMQP no-ack: the broker treats deliveries as
            // settled as soon as they are sent.
            no_ack: self.options.auto_ack,
            exclusive: self.options.exclusive,
            nowait: false,
        };

        let mut consumer = self
            .channel
            .basic_consume(
                &self.options.queue_name,
                &consumer_tag,
                consume_options,
                self.options.arguments.clone(),
            )
            .await?;

        info!(
            "Started consuming from queue: {} with tag: {}",
            self.options.queue_name, consumer_tag
        );

        while let Some(delivery) = consumer.next().await {
            // A stream-level error aborts the consume loop.
            let delivery = delivery?;
            // Acquire a concurrency permit *before* spawning: this blocks
            // the loop once `concurrency` handler tasks are in flight.
            let permit = self
                .semaphore
                .clone()
                .acquire_owned()
                .await
                .map_err(|e| RabbitError::Generic(e.into()))?;

            // Clone per-task state; the channel clone lets the task settle
            // the delivery independently of this loop.
            let handler = handler.clone();
            let retry_policy = self.options.retry_policy.clone();
            let dead_letter_exchange = self.options.dead_letter_exchange.clone();
            let channel = self.channel.clone();

            // Process message in a separate task
            tokio::spawn(async move {
                let _permit = permit; // Hold the permit for the duration of processing

                if let Err(e) = Self::process_message::<T, H>(
                    delivery,
                    handler,
                    retry_policy,
                    dead_letter_exchange,
                    channel,
                )
                .await
                {
                    error!("Error processing message: {}", e);
                }
            });
        }

        warn!(
            "Consumer stream ended for queue: {}",
            self.options.queue_name
        );
        Ok(())
    }
386
    /// Process a single message
    ///
    /// Deserializes the payload as JSON, invokes the handler, and settles
    /// the delivery according to the returned [`MessageResult`]:
    /// - `Ack`     → basic_ack
    /// - `Retry`   → retry policy if configured, otherwise nack + requeue
    /// - `Reject`  → dead-letter path if configured, otherwise nack, no requeue
    /// - `Requeue` → nack + requeue
    ///
    /// Payloads that fail to deserialize are nacked without requeue.
    async fn process_message<T, H>(
        delivery: Delivery,
        handler: Arc<H>,
        retry_policy: Option<RetryPolicy>,
        dead_letter_exchange: Option<String>,
        channel: Channel,
    ) -> Result<()>
    where
        T: DeserializeOwned + Send + Sync,
        H: MessageHandler<T>,
    {
        let context = Self::build_message_context(&delivery);

        // Deserialize message; a malformed payload can never succeed on
        // redelivery, so it is rejected without requeue.
        let message: T = match serde_json::from_slice(&delivery.data) {
            Ok(msg) => msg,
            Err(e) => {
                error!("Failed to deserialize message: {}", e);
                Self::reject_message(&delivery, &channel, false).await?;
                return Ok(());
            }
        };

        // Handle message
        let result = handler.handle(message, context.clone()).await;

        match result {
            MessageResult::Ack => {
                Self::ack_message(&delivery, &channel).await?;
                debug!("Message acknowledged: {}", delivery.delivery_tag);
            }
            MessageResult::Retry => {
                if let Some(ref policy) = retry_policy {
                    Self::handle_retry(&delivery, &channel, &context, policy).await?;
                } else {
                    // No policy configured: fall back to a plain requeue.
                    Self::reject_message(&delivery, &channel, true).await?;
                }
            }
            MessageResult::Reject => {
                if let Some(ref dle) = dead_letter_exchange {
                    Self::send_to_dead_letter(&delivery, &channel, dle, &context).await?;
                } else {
                    Self::reject_message(&delivery, &channel, false).await?;
                }
            }
            MessageResult::Requeue => {
                Self::reject_message(&delivery, &channel, true).await?;
            }
        }

        Ok(())
    }
440
441    /// Build message context from delivery
442    fn build_message_context(delivery: &Delivery) -> MessageContext {
443        let properties = &delivery.properties;
444
445        MessageContext {
446            message_id: properties.message_id().as_ref().map(|s| s.to_string()),
447            correlation_id: properties.correlation_id().as_ref().map(|s| s.to_string()),
448            reply_to: properties.reply_to().as_ref().map(|s| s.to_string()),
449            delivery_tag: delivery.delivery_tag,
450            redelivered: delivery.redelivered,
451            exchange: delivery.exchange.to_string(),
452            routing_key: delivery.routing_key.to_string(),
453            headers: properties.headers().clone().unwrap_or_default(),
454            timestamp: *properties.timestamp(),
455            retry_count: Self::get_retry_count_from_headers(
456                properties
457                    .headers()
458                    .as_ref()
459                    .unwrap_or(&FieldTable::default()),
460            ),
461        }
462    }
463
464    /// Get retry count from message headers
465    fn get_retry_count_from_headers(headers: &FieldTable) -> u32 {
466        headers
467            .inner()
468            .get("x-retry-count")
469            .and_then(|v| match v {
470                lapin::types::AMQPValue::LongInt(count) => Some(*count as u32),
471                lapin::types::AMQPValue::LongLongInt(count) => Some(*count as u32),
472                _ => None,
473            })
474            .unwrap_or(0)
475    }
476
477    /// Acknowledge a message
478    async fn ack_message(delivery: &Delivery, channel: &Channel) -> Result<()> {
479        channel
480            .basic_ack(delivery.delivery_tag, BasicAckOptions::default())
481            .await?;
482        Ok(())
483    }
484
485    /// Reject a message
486    async fn reject_message(delivery: &Delivery, channel: &Channel, requeue: bool) -> Result<()> {
487        channel
488            .basic_nack(
489                delivery.delivery_tag,
490                BasicNackOptions {
491                    multiple: false,
492                    requeue,
493                },
494            )
495            .await?;
496        Ok(())
497    }
498
    /// Handle retry logic
    ///
    /// Gives up (nack without requeue) once `retry_count` reaches the
    /// policy's `max_retries`; otherwise nacks with requeue.
    ///
    /// NOTE(review): a plain requeue does not modify message headers, so
    /// `x-retry-count` — and therefore `context.retry_count` — presumably
    /// never increases across redeliveries. Confirm the max-retries guard
    /// can actually trigger, or republish with an incremented header.
    async fn handle_retry(
        delivery: &Delivery,
        channel: &Channel,
        context: &MessageContext,
        retry_policy: &RetryPolicy,
    ) -> Result<()> {
        if context.retry_count >= retry_policy.max_retries {
            warn!(
                "Max retries exceeded for message: {}",
                delivery.delivery_tag
            );
            Self::reject_message(delivery, channel, false).await?;
            return Ok(());
        }

        // Calculate delay for next retry (currently only logged, not applied).
        let delay = retry_policy.calculate_delay(context.retry_count);

        // For now, just requeue the message
        // In a production implementation, you would use the delayed message exchange
        // or implement a retry queue pattern
        info!(
            "Retrying message after {:?} (attempt {})",
            delay,
            context.retry_count + 1
        );
        Self::reject_message(delivery, channel, true).await?;

        Ok(())
    }
530
531    /// Send message to dead letter exchange
532    async fn send_to_dead_letter(
533        delivery: &Delivery,
534        channel: &Channel,
535        dead_letter_exchange: &str,
536        _context: &MessageContext,
537    ) -> Result<()> {
538        // In a real implementation, you would republish the message to the DLE
539        // For now, just reject without requeue
540        warn!(
541            "Sending message to dead letter exchange: {}",
542            dead_letter_exchange
543        );
544        Self::reject_message(delivery, channel, false).await?;
545        Ok(())
546    }
547
548    /// Declare queue with options
549    async fn declare_queue(channel: &Channel, options: &ConsumerOptions) -> Result<()> {
550        let queue_options = LapinQueueDeclareOptions {
551            passive: options.queue_options.passive,
552            durable: options.queue_options.durable,
553            exclusive: options.queue_options.exclusive,
554            auto_delete: options.queue_options.auto_delete,
555            nowait: false,
556        };
557
558        channel
559            .queue_declare(
560                &options.queue_name,
561                queue_options,
562                options.queue_options.arguments.clone(),
563            )
564            .await?;
565
566        debug!("Declared queue: {}", options.queue_name);
567        Ok(())
568    }
569
570    /// Stop consuming (close the consumer)
571    pub async fn stop(&self) -> Result<()> {
572        // The consumer will stop when the channel is closed
573        // or when the stream ends
574        info!("Stopping consumer for queue: {}", self.options.queue_name);
575        Ok(())
576    }
577}
578
// Example message handler implementation
/// Adapts a plain synchronous closure into a [`MessageHandler`].
pub struct SimpleMessageHandler<F, T>
where
    F: Fn(T, MessageContext) -> MessageResult + Send + Sync,
    T: DeserializeOwned + Send + Sync,
{
    /// Closure invoked for every received message
    handler_fn: F,
    /// Ties the handled message type `T` to this struct without storing one
    _phantom: std::marker::PhantomData<T>,
}
588
589impl<F, T> SimpleMessageHandler<F, T>
590where
591    F: Fn(T, MessageContext) -> MessageResult + Send + Sync + 'static,
592    T: DeserializeOwned + Send + Sync + 'static,
593{
594    pub fn new(handler_fn: F) -> Self {
595        Self {
596            handler_fn,
597            _phantom: std::marker::PhantomData,
598        }
599    }
600}
601
#[async_trait]
impl<F, T> MessageHandler<T> for SimpleMessageHandler<F, T>
where
    F: Fn(T, MessageContext) -> MessageResult + Send + Sync + 'static,
    T: DeserializeOwned + Send + Sync + 'static,
{
    /// Delegate to the wrapped closure; the closure runs synchronously
    /// within this async fn (no awaits inside).
    async fn handle(&self, message: T, context: MessageContext) -> MessageResult {
        (self.handler_fn)(message, context)
    }
}