1use crate::{
2 connection::ConnectionManager,
3 error::{RabbitError, Result},
4 metrics::RustRabbitMetrics,
5 publisher::CustomQueueDeclareOptions,
6 retry::RetryPolicy,
7};
8use async_trait::async_trait;
9use futures::StreamExt;
10use lapin::{
11 message::Delivery,
12 options::{
13 BasicAckOptions, BasicConsumeOptions, BasicNackOptions,
14 QueueDeclareOptions as LapinQueueDeclareOptions,
15 },
16 types::FieldTable,
17 Channel,
18};
19use serde::de::DeserializeOwned;
20use std::sync::Arc;
21use tokio::sync::Semaphore;
22use tracing::{debug, error, info, warn};
23
24#[async_trait]
26pub trait MessageHandler<T>: Send + Sync + 'static
27where
28 T: DeserializeOwned + Send + Sync,
29{
30 async fn handle(&self, message: T, context: MessageContext) -> MessageResult;
32}
33
34#[derive(Debug, Clone)]
36pub struct MessageContext {
37 pub message_id: Option<String>,
38 pub correlation_id: Option<String>,
39 pub reply_to: Option<String>,
40 pub delivery_tag: u64,
41 pub redelivered: bool,
42 pub exchange: String,
43 pub routing_key: String,
44 pub headers: FieldTable,
45 pub timestamp: Option<u64>,
46 pub retry_count: u32,
47}
48
/// Outcome returned by a [`MessageHandler`], telling the consumer how to settle the
/// delivery with the broker.
///
/// The type is a plain fieldless enum, so it is cheap to copy and compare; this lets
/// handlers and tests match or assert on results directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageResult {
    /// Processing succeeded; acknowledge the delivery.
    Ack,
    /// Processing failed transiently; retry according to the consumer's retry policy.
    /// Without a retry policy the delivery is nacked with requeue.
    Retry,
    /// Processing failed permanently; reject without requeue, using the consumer's
    /// dead-letter exchange when one is configured.
    Reject,
    /// Nack the delivery with requeue so the broker delivers it again.
    Requeue,
}
61
62#[derive(Debug, Clone)]
64pub struct ConsumerOptions {
65 pub queue_name: String,
67
68 pub consumer_tag: Option<String>,
70
71 pub concurrency: usize,
73
74 pub prefetch_count: Option<u16>,
76
77 pub auto_declare_queue: bool,
79
80 pub queue_options: CustomQueueDeclareOptions,
82
83 pub retry_policy: Option<RetryPolicy>,
85
86 pub dead_letter_exchange: Option<String>,
88
89 pub auto_ack: bool,
91
92 pub exclusive: bool,
94
95 pub arguments: FieldTable,
97}
98
99impl ConsumerOptions {
100 pub fn builder<S: Into<String>>(queue_name: S) -> ConsumerOptionsBuilder {
102 ConsumerOptionsBuilder::new(queue_name.into())
103 }
104}
105
106#[derive(Debug, Clone)]
108pub struct ConsumerOptionsBuilder {
109 queue_name: String,
110 consumer_tag: Option<String>,
111 concurrency: usize,
112 prefetch_count: Option<u16>,
113 auto_declare_queue: bool,
114 queue_options: CustomQueueDeclareOptions,
115 retry_policy: Option<RetryPolicy>,
116 dead_letter_exchange: Option<String>,
117 auto_ack: bool,
118 exclusive: bool,
119 arguments: FieldTable,
120}
121
122impl ConsumerOptionsBuilder {
123 pub fn new(queue_name: String) -> Self {
125 Self {
126 queue_name,
127 consumer_tag: None,
128 concurrency: 1,
129 prefetch_count: Some(10),
130 auto_declare_queue: false,
131 queue_options: CustomQueueDeclareOptions::default(),
132 retry_policy: None,
133 dead_letter_exchange: None,
134 auto_ack: false,
135 exclusive: false,
136 arguments: FieldTable::default(),
137 }
138 }
139
140 pub fn consumer_tag<S: Into<String>>(mut self, tag: S) -> Self {
142 self.consumer_tag = Some(tag.into());
143 self
144 }
145
146 pub fn concurrency(mut self, concurrency: usize) -> Self {
148 self.concurrency = concurrency;
149 self
150 }
151
152 pub fn prefetch_count(mut self, count: u16) -> Self {
154 self.prefetch_count = Some(count);
155 self
156 }
157
158 pub fn no_prefetch_limit(mut self) -> Self {
160 self.prefetch_count = None;
161 self
162 }
163
164 pub fn auto_declare_queue(mut self) -> Self {
166 self.auto_declare_queue = true;
167 self
168 }
169
170 pub fn queue_options(mut self, options: CustomQueueDeclareOptions) -> Self {
172 self.queue_options = options;
173 self
174 }
175
176 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
178 self.retry_policy = Some(policy);
179 self
180 }
181
182 pub fn dead_letter_exchange<S: Into<String>>(mut self, exchange: S) -> Self {
184 self.dead_letter_exchange = Some(exchange.into());
185 self
186 }
187
188 pub fn auto_ack(mut self) -> Self {
190 self.auto_ack = true;
191 self
192 }
193
194 pub fn manual_ack(mut self) -> Self {
196 self.auto_ack = false;
197 self
198 }
199
200 pub fn exclusive(mut self) -> Self {
202 self.exclusive = true;
203 self
204 }
205
206 pub fn high_throughput(mut self) -> Self {
208 self.concurrency = 20;
209 self.prefetch_count = Some(50);
210 self.auto_ack = false;
211 self
212 }
213
214 pub fn reliable(mut self) -> Self {
216 self.concurrency = 1;
217 self.prefetch_count = Some(1);
218 self.auto_ack = false;
219 self
220 }
221
222 pub fn development(mut self) -> Self {
224 self.concurrency = 1;
225 self.prefetch_count = Some(1);
226 self.auto_ack = true;
227 self.auto_declare_queue = true;
228 self
229 }
230
231 pub fn build(self) -> ConsumerOptions {
233 ConsumerOptions {
234 queue_name: self.queue_name,
235 consumer_tag: self.consumer_tag,
236 concurrency: self.concurrency,
237 prefetch_count: self.prefetch_count,
238 auto_declare_queue: self.auto_declare_queue,
239 queue_options: self.queue_options,
240 retry_policy: self.retry_policy,
241 dead_letter_exchange: self.dead_letter_exchange,
242 auto_ack: self.auto_ack,
243 exclusive: self.exclusive,
244 arguments: self.arguments,
245 }
246 }
247}
248
249impl Default for ConsumerOptions {
250 fn default() -> Self {
251 Self {
252 queue_name: String::new(),
253 consumer_tag: None,
254 concurrency: 1,
255 prefetch_count: Some(10),
256 auto_declare_queue: false,
257 queue_options: CustomQueueDeclareOptions::default(),
258 retry_policy: None,
259 dead_letter_exchange: None,
260 auto_ack: false,
261 exclusive: false,
262 arguments: FieldTable::default(),
263 }
264 }
265}
266
267pub struct Consumer {
269 #[allow(dead_code)] connection_manager: ConnectionManager,
271 options: ConsumerOptions,
272 channel: Channel,
273 semaphore: Arc<Semaphore>,
274 metrics: Option<RustRabbitMetrics>,
275}
276
277impl Consumer {
278 pub async fn new(
280 connection_manager: ConnectionManager,
281 options: ConsumerOptions,
282 ) -> Result<Self> {
283 let connection = connection_manager.get_connection().await?;
284 let channel = connection.create_channel().await?;
285
286 if let Some(prefetch_count) = options.prefetch_count {
288 channel
289 .basic_qos(prefetch_count, lapin::options::BasicQosOptions::default())
290 .await?;
291 }
292
293 if options.auto_declare_queue {
295 Self::declare_queue(&channel, &options).await?;
296 }
297
298 let semaphore = Arc::new(Semaphore::new(options.concurrency));
299
300 Ok(Self {
301 connection_manager,
302 options,
303 channel,
304 semaphore,
305 metrics: None,
306 })
307 }
308
309 pub fn set_metrics(&mut self, metrics: RustRabbitMetrics) {
311 self.metrics = Some(metrics);
312 }
313
314 pub async fn consume<T, H>(&self, handler: Arc<H>) -> Result<()>
316 where
317 T: DeserializeOwned + Send + Sync + 'static,
318 H: MessageHandler<T>,
319 {
320 let consumer_tag = self
321 .options
322 .consumer_tag
323 .clone()
324 .unwrap_or_else(|| format!("rust-rabbit-{}", uuid::Uuid::new_v4()));
325
326 let consume_options = BasicConsumeOptions {
327 no_local: false,
328 no_ack: self.options.auto_ack,
329 exclusive: self.options.exclusive,
330 nowait: false,
331 };
332
333 let mut consumer = self
334 .channel
335 .basic_consume(
336 &self.options.queue_name,
337 &consumer_tag,
338 consume_options,
339 self.options.arguments.clone(),
340 )
341 .await?;
342
343 info!(
344 "Started consuming from queue: {} with tag: {}",
345 self.options.queue_name, consumer_tag
346 );
347
348 while let Some(delivery) = consumer.next().await {
349 let delivery = delivery?;
350 let permit = self
351 .semaphore
352 .clone()
353 .acquire_owned()
354 .await
355 .map_err(|e| RabbitError::Generic(e.into()))?;
356
357 let handler = handler.clone();
358 let retry_policy = self.options.retry_policy.clone();
359 let dead_letter_exchange = self.options.dead_letter_exchange.clone();
360 let channel = self.channel.clone();
361
362 tokio::spawn(async move {
364 let _permit = permit; if let Err(e) = Self::process_message::<T, H>(
367 delivery,
368 handler,
369 retry_policy,
370 dead_letter_exchange,
371 channel,
372 )
373 .await
374 {
375 error!("Error processing message: {}", e);
376 }
377 });
378 }
379
380 warn!(
381 "Consumer stream ended for queue: {}",
382 self.options.queue_name
383 );
384 Ok(())
385 }
386
387 async fn process_message<T, H>(
389 delivery: Delivery,
390 handler: Arc<H>,
391 retry_policy: Option<RetryPolicy>,
392 dead_letter_exchange: Option<String>,
393 channel: Channel,
394 ) -> Result<()>
395 where
396 T: DeserializeOwned + Send + Sync,
397 H: MessageHandler<T>,
398 {
399 let context = Self::build_message_context(&delivery);
400
401 let message: T = match serde_json::from_slice(&delivery.data) {
403 Ok(msg) => msg,
404 Err(e) => {
405 error!("Failed to deserialize message: {}", e);
406 Self::reject_message(&delivery, &channel, false).await?;
407 return Ok(());
408 }
409 };
410
411 let result = handler.handle(message, context.clone()).await;
413
414 match result {
415 MessageResult::Ack => {
416 Self::ack_message(&delivery, &channel).await?;
417 debug!("Message acknowledged: {}", delivery.delivery_tag);
418 }
419 MessageResult::Retry => {
420 if let Some(ref policy) = retry_policy {
421 Self::handle_retry(&delivery, &channel, &context, policy).await?;
422 } else {
423 Self::reject_message(&delivery, &channel, true).await?;
424 }
425 }
426 MessageResult::Reject => {
427 if let Some(ref dle) = dead_letter_exchange {
428 Self::send_to_dead_letter(&delivery, &channel, dle, &context).await?;
429 } else {
430 Self::reject_message(&delivery, &channel, false).await?;
431 }
432 }
433 MessageResult::Requeue => {
434 Self::reject_message(&delivery, &channel, true).await?;
435 }
436 }
437
438 Ok(())
439 }
440
441 fn build_message_context(delivery: &Delivery) -> MessageContext {
443 let properties = &delivery.properties;
444
445 MessageContext {
446 message_id: properties.message_id().as_ref().map(|s| s.to_string()),
447 correlation_id: properties.correlation_id().as_ref().map(|s| s.to_string()),
448 reply_to: properties.reply_to().as_ref().map(|s| s.to_string()),
449 delivery_tag: delivery.delivery_tag,
450 redelivered: delivery.redelivered,
451 exchange: delivery.exchange.to_string(),
452 routing_key: delivery.routing_key.to_string(),
453 headers: properties.headers().clone().unwrap_or_default(),
454 timestamp: *properties.timestamp(),
455 retry_count: Self::get_retry_count_from_headers(
456 properties
457 .headers()
458 .as_ref()
459 .unwrap_or(&FieldTable::default()),
460 ),
461 }
462 }
463
464 fn get_retry_count_from_headers(headers: &FieldTable) -> u32 {
466 headers
467 .inner()
468 .get("x-retry-count")
469 .and_then(|v| match v {
470 lapin::types::AMQPValue::LongInt(count) => Some(*count as u32),
471 lapin::types::AMQPValue::LongLongInt(count) => Some(*count as u32),
472 _ => None,
473 })
474 .unwrap_or(0)
475 }
476
477 async fn ack_message(delivery: &Delivery, channel: &Channel) -> Result<()> {
479 channel
480 .basic_ack(delivery.delivery_tag, BasicAckOptions::default())
481 .await?;
482 Ok(())
483 }
484
485 async fn reject_message(delivery: &Delivery, channel: &Channel, requeue: bool) -> Result<()> {
487 channel
488 .basic_nack(
489 delivery.delivery_tag,
490 BasicNackOptions {
491 multiple: false,
492 requeue,
493 },
494 )
495 .await?;
496 Ok(())
497 }
498
499 async fn handle_retry(
501 delivery: &Delivery,
502 channel: &Channel,
503 context: &MessageContext,
504 retry_policy: &RetryPolicy,
505 ) -> Result<()> {
506 if context.retry_count >= retry_policy.max_retries {
507 warn!(
508 "Max retries exceeded for message: {}",
509 delivery.delivery_tag
510 );
511 Self::reject_message(delivery, channel, false).await?;
512 return Ok(());
513 }
514
515 let delay = retry_policy.calculate_delay(context.retry_count);
517
518 info!(
522 "Retrying message after {:?} (attempt {})",
523 delay,
524 context.retry_count + 1
525 );
526 Self::reject_message(delivery, channel, true).await?;
527
528 Ok(())
529 }
530
531 async fn send_to_dead_letter(
533 delivery: &Delivery,
534 channel: &Channel,
535 dead_letter_exchange: &str,
536 _context: &MessageContext,
537 ) -> Result<()> {
538 warn!(
541 "Sending message to dead letter exchange: {}",
542 dead_letter_exchange
543 );
544 Self::reject_message(delivery, channel, false).await?;
545 Ok(())
546 }
547
548 async fn declare_queue(channel: &Channel, options: &ConsumerOptions) -> Result<()> {
550 let queue_options = LapinQueueDeclareOptions {
551 passive: options.queue_options.passive,
552 durable: options.queue_options.durable,
553 exclusive: options.queue_options.exclusive,
554 auto_delete: options.queue_options.auto_delete,
555 nowait: false,
556 };
557
558 channel
559 .queue_declare(
560 &options.queue_name,
561 queue_options,
562 options.queue_options.arguments.clone(),
563 )
564 .await?;
565
566 debug!("Declared queue: {}", options.queue_name);
567 Ok(())
568 }
569
570 pub async fn stop(&self) -> Result<()> {
572 info!("Stopping consumer for queue: {}", self.options.queue_name);
575 Ok(())
576 }
577}
578
/// Adapter that turns a synchronous closure `Fn(T, MessageContext) -> MessageResult` into
/// a [`MessageHandler`].
///
/// The bounds live on the `impl` blocks rather than on the struct definition. This is
/// the idiomatic placement, and it avoids repeating them wherever the type is merely
/// named.
pub struct SimpleMessageHandler<F, T> {
    /// The wrapped closure, invoked once per decoded message.
    handler_fn: F,
    /// Ties the message type `T` to this handler without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
588
589impl<F, T> SimpleMessageHandler<F, T>
590where
591 F: Fn(T, MessageContext) -> MessageResult + Send + Sync + 'static,
592 T: DeserializeOwned + Send + Sync + 'static,
593{
594 pub fn new(handler_fn: F) -> Self {
595 Self {
596 handler_fn,
597 _phantom: std::marker::PhantomData,
598 }
599 }
600}
601
602#[async_trait]
603impl<F, T> MessageHandler<T> for SimpleMessageHandler<F, T>
604where
605 F: Fn(T, MessageContext) -> MessageResult + Send + Sync + 'static,
606 T: DeserializeOwned + Send + Sync + 'static,
607{
608 async fn handle(&self, message: T, context: MessageContext) -> MessageResult {
609 (self.handler_fn)(message, context)
610 }
611}