1use anyhow::Result;
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use std::cmp::Ordering;
5use std::collections::{BinaryHeap, HashMap, VecDeque};
6use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant};
8use tokio::sync::Notify;
9use tracing::{debug, error, info, warn};
10use uuid::Uuid;
11
12use crate::error::RustRabbitError;
13
14#[derive(
16 Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize,
17)]
18pub enum Priority {
19 Low = 1,
21 #[default]
23 Normal = 5,
24 High = 8,
26 Critical = 10,
28}
29
30impl Priority {
31 pub fn value(&self) -> u8 {
32 *self as u8
33 }
34
35 pub fn from_value(value: u8) -> Self {
36 match value {
37 0 => Priority::Low,
38 1..=2 => Priority::Low,
39 3..=6 => Priority::Normal,
40 7..=9 => Priority::High,
41 10.. => Priority::Critical,
42 }
43 }
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct PriorityMessage {
49 pub message_id: String,
50 pub priority: Priority,
51 pub payload: Vec<u8>,
52 pub headers: HashMap<String, String>,
53 pub timestamp: DateTime<Utc>,
54 pub expiry: Option<DateTime<Utc>>,
55 pub retry_count: u32,
56 pub max_retries: u32,
57}
58
59impl PriorityMessage {
60 pub fn new(payload: Vec<u8>, priority: Priority) -> Self {
61 Self {
62 message_id: Uuid::new_v4().to_string(),
63 priority,
64 payload,
65 headers: HashMap::new(),
66 timestamp: Utc::now(),
67 expiry: None,
68 retry_count: 0,
69 max_retries: 3,
70 }
71 }
72
73 pub fn with_expiry(mut self, expiry: DateTime<Utc>) -> Self {
74 self.expiry = Some(expiry);
75 self
76 }
77
78 pub fn with_header(mut self, key: String, value: String) -> Self {
79 self.headers.insert(key, value);
80 self
81 }
82
83 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
84 self.max_retries = max_retries;
85 self
86 }
87
88 pub fn is_expired(&self) -> bool {
89 if let Some(expiry) = self.expiry {
90 Utc::now() > expiry
91 } else {
92 false
93 }
94 }
95
96 pub fn can_retry(&self) -> bool {
97 self.retry_count < self.max_retries
98 }
99
100 pub fn increment_retry(&mut self) {
101 self.retry_count += 1;
102 }
103}
104
105#[derive(Debug, Clone)]
107struct PriorityMessageWrapper {
108 message: PriorityMessage,
109 enqueue_time: Instant,
110}
111
112impl PartialEq for PriorityMessageWrapper {
113 fn eq(&self, other: &Self) -> bool {
114 self.message.priority == other.message.priority && self.enqueue_time == other.enqueue_time
115 }
116}
117
118impl Eq for PriorityMessageWrapper {}
119
120impl PartialOrd for PriorityMessageWrapper {
121 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
122 Some(self.cmp(other))
123 }
124}
125
126impl Ord for PriorityMessageWrapper {
127 fn cmp(&self, other: &Self) -> Ordering {
128 match self.message.priority.cmp(&other.message.priority) {
130 Ordering::Equal => other.enqueue_time.cmp(&self.enqueue_time),
131 other => other,
132 }
133 }
134}
135
136#[derive(Debug, Clone)]
138pub struct PriorityQueueConfig {
139 pub max_queue_size: usize,
140 pub dead_letter_enabled: bool,
141 pub dead_letter_threshold: u32,
142 pub cleanup_interval: Duration,
143 pub metrics_enabled: bool,
144}
145
146impl Default for PriorityQueueConfig {
147 fn default() -> Self {
148 Self {
149 max_queue_size: 10_000,
150 dead_letter_enabled: true,
151 dead_letter_threshold: 3,
152 cleanup_interval: Duration::from_secs(60),
153 metrics_enabled: true,
154 }
155 }
156}
157
158#[derive(Debug, Clone)]
160pub struct PriorityQueueStats {
161 pub total_messages: usize,
162 pub messages_by_priority: HashMap<Priority, usize>,
163 pub dead_letter_count: usize,
164 pub expired_count: usize,
165 pub average_wait_time: Duration,
166 pub throughput_per_second: f64,
167}
168
169#[derive(Debug)]
171pub struct PriorityQueue {
172 config: PriorityQueueConfig,
173 queue: Arc<Mutex<BinaryHeap<PriorityMessageWrapper>>>,
174 dead_letter_queue: Arc<Mutex<VecDeque<PriorityMessage>>>,
175 stats: Arc<Mutex<PriorityQueueStats>>,
176 notify: Arc<Notify>,
177}
178
179impl PriorityQueue {
180 pub fn new(config: PriorityQueueConfig) -> Self {
181 let queue = Self {
182 config: config.clone(),
183 queue: Arc::new(Mutex::new(BinaryHeap::new())),
184 dead_letter_queue: Arc::new(Mutex::new(VecDeque::new())),
185 stats: Arc::new(Mutex::new(PriorityQueueStats {
186 total_messages: 0,
187 messages_by_priority: HashMap::new(),
188 dead_letter_count: 0,
189 expired_count: 0,
190 average_wait_time: Duration::ZERO,
191 throughput_per_second: 0.0,
192 })),
193 notify: Arc::new(Notify::new()),
194 };
195
196 if config.cleanup_interval > Duration::ZERO {
198 queue.start_cleanup_task();
199 }
200
201 queue
202 }
203
204 pub fn enqueue(&self, message: PriorityMessage) -> Result<()> {
206 let priority = message.priority;
207
208 debug!(
209 message_id = %message.message_id,
210 priority = ?priority,
211 "Enqueuing priority message"
212 );
213
214 {
215 let mut queue = self.queue.lock().unwrap();
216
217 if queue.len() >= self.config.max_queue_size {
219 warn!(
220 queue_size = queue.len(),
221 max_size = self.config.max_queue_size,
222 "Priority queue is full"
223 );
224 return Err(RustRabbitError::QueueFull.into());
225 }
226
227 let wrapper = PriorityMessageWrapper {
228 message,
229 enqueue_time: Instant::now(),
230 };
231
232 queue.push(wrapper);
233 }
234
235 {
237 let mut stats = self.stats.lock().unwrap();
238 stats.total_messages += 1;
239 *stats.messages_by_priority.entry(priority).or_insert(0) += 1;
240 }
241
242 self.notify.notify_one();
244
245 Ok(())
246 }
247
248 pub fn dequeue(&self) -> Option<PriorityMessage> {
250 let mut queue = self.queue.lock().unwrap();
251
252 while let Some(wrapper) = queue.pop() {
253 let message = wrapper.message;
254
255 if message.is_expired() {
257 warn!(
258 message_id = %message.message_id,
259 "Message expired, moving to dead letter queue"
260 );
261
262 self.move_to_dead_letter(message);
263 continue;
264 }
265
266 debug!(
267 message_id = %message.message_id,
268 priority = ?message.priority,
269 wait_time_ms = wrapper.enqueue_time.elapsed().as_millis(),
270 "Dequeued priority message"
271 );
272
273 {
275 let mut stats = self.stats.lock().unwrap();
276 stats.total_messages = stats.total_messages.saturating_sub(1);
277 if let Some(count) = stats.messages_by_priority.get_mut(&message.priority) {
278 *count = count.saturating_sub(1);
279 }
280 }
281
282 return Some(message);
283 }
284
285 None
286 }
287
288 pub async fn dequeue_async(&self) -> Option<PriorityMessage> {
290 loop {
291 if let Some(message) = self.dequeue() {
292 return Some(message);
293 }
294
295 self.notify.notified().await;
297 }
298 }
299
300 pub async fn dequeue_timeout(&self, timeout: Duration) -> Option<PriorityMessage> {
302 tokio::select! {
303 message = self.dequeue_async() => message,
304 _ = tokio::time::sleep(timeout) => None,
305 }
306 }
307
308 pub fn peek(&self) -> Option<PriorityMessage> {
310 let queue = self.queue.lock().unwrap();
311 queue.peek().map(|wrapper| wrapper.message.clone())
312 }
313
314 pub fn size(&self) -> usize {
316 self.queue.lock().unwrap().len()
317 }
318
319 pub fn is_empty(&self) -> bool {
321 self.queue.lock().unwrap().is_empty()
322 }
323
324 pub fn stats(&self) -> PriorityQueueStats {
326 self.stats.lock().unwrap().clone()
327 }
328
329 pub fn requeue(&self, mut message: PriorityMessage) -> Result<()> {
331 if message.can_retry() {
332 message.increment_retry();
333
334 info!(
335 message_id = %message.message_id,
336 retry_count = message.retry_count,
337 max_retries = message.max_retries,
338 "Requeuing message for retry"
339 );
340
341 self.enqueue(message)
342 } else {
343 warn!(
344 message_id = %message.message_id,
345 retry_count = message.retry_count,
346 "Message exceeded max retries, moving to dead letter queue"
347 );
348
349 self.move_to_dead_letter(message);
350 Ok(())
351 }
352 }
353
354 fn move_to_dead_letter(&self, message: PriorityMessage) {
356 if self.config.dead_letter_enabled {
357 let mut dead_letter = self.dead_letter_queue.lock().unwrap();
358 dead_letter.push_back(message);
359
360 let mut stats = self.stats.lock().unwrap();
362 stats.dead_letter_count += 1;
363 }
364 }
365
366 pub fn dead_letter_messages(&self) -> Vec<PriorityMessage> {
368 self.dead_letter_queue
369 .lock()
370 .unwrap()
371 .iter()
372 .cloned()
373 .collect()
374 }
375
376 pub fn clear_dead_letter(&self) -> usize {
378 let mut dead_letter = self.dead_letter_queue.lock().unwrap();
379 let count = dead_letter.len();
380 dead_letter.clear();
381
382 {
384 let mut stats = self.stats.lock().unwrap();
385 stats.dead_letter_count = 0;
386 }
387
388 count
389 }
390
391 fn start_cleanup_task(&self) {
393 let queue = self.queue.clone();
394 let dead_letter = self.dead_letter_queue.clone();
395 let stats = self.stats.clone();
396 let cleanup_interval = self.config.cleanup_interval;
397 let dead_letter_enabled = self.config.dead_letter_enabled;
398
399 tokio::spawn(async move {
400 let mut interval = tokio::time::interval(cleanup_interval);
401
402 loop {
403 interval.tick().await;
404
405 let mut expired_count = 0;
406
407 {
409 let mut queue_guard = queue.lock().unwrap();
410 let mut temp_queue = BinaryHeap::new();
411
412 while let Some(wrapper) = queue_guard.pop() {
413 if wrapper.message.is_expired() {
414 expired_count += 1;
415
416 if dead_letter_enabled {
417 let mut dead_letter_guard = dead_letter.lock().unwrap();
418 dead_letter_guard.push_back(wrapper.message);
419 }
420 } else {
421 temp_queue.push(wrapper);
422 }
423 }
424
425 *queue_guard = temp_queue;
426 }
427
428 if expired_count > 0 {
430 let mut stats_guard = stats.lock().unwrap();
431 stats_guard.expired_count += expired_count;
432 stats_guard.total_messages =
433 stats_guard.total_messages.saturating_sub(expired_count);
434
435 debug!(
436 expired_count = expired_count,
437 "Cleanup task removed expired messages"
438 );
439 }
440 }
441 });
442 }
443}
444
445#[derive(Debug)]
447pub struct PriorityRouter {
448 queues: HashMap<String, Arc<PriorityQueue>>,
449 default_queue: String,
450}
451
452impl PriorityRouter {
453 pub fn new(default_queue: String) -> Self {
454 Self {
455 queues: HashMap::new(),
456 default_queue,
457 }
458 }
459
460 pub fn add_queue(&mut self, name: String, queue: Arc<PriorityQueue>) {
462 self.queues.insert(name, queue);
463 }
464
465 pub fn route_message(
467 &self,
468 queue_name: Option<String>,
469 message: PriorityMessage,
470 ) -> Result<()> {
471 let queue_name = queue_name.unwrap_or_else(|| self.default_queue.clone());
472
473 if let Some(queue) = self.queues.get(&queue_name) {
474 queue.enqueue(message)
475 } else {
476 error!(queue_name = %queue_name, "Priority queue not found");
477 Err(RustRabbitError::QueueNotFound(queue_name).into())
478 }
479 }
480
481 pub async fn dequeue_any(&self) -> Option<(String, PriorityMessage)> {
483 for (queue_name, queue) in &self.queues {
486 if let Some(message) = queue.dequeue() {
487 return Some((queue_name.clone(), message));
488 }
489 }
490 None
491 }
492
493 pub fn get_queue(&self, name: &str) -> Option<Arc<PriorityQueue>> {
495 self.queues.get(name).cloned()
496 }
497
498 pub fn queue_names(&self) -> Vec<String> {
500 self.queues.keys().cloned().collect()
501 }
502
503 pub fn aggregate_stats(&self) -> HashMap<String, PriorityQueueStats> {
505 self.queues
506 .iter()
507 .map(|(name, queue)| (name.clone(), queue.stats()))
508 .collect()
509 }
510}
511
512#[derive(Debug)]
514pub struct PriorityConsumer {
515 queue: Arc<PriorityQueue>,
516 batch_size: usize,
517 processing_timeout: Duration,
518}
519
520impl PriorityConsumer {
521 pub fn new(queue: Arc<PriorityQueue>) -> Self {
522 Self {
523 queue,
524 batch_size: 1,
525 processing_timeout: Duration::from_secs(30),
526 }
527 }
528
529 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
530 self.batch_size = batch_size;
531 self
532 }
533
534 pub fn with_timeout(mut self, timeout: Duration) -> Self {
535 self.processing_timeout = timeout;
536 self
537 }
538
539 pub async fn consume_batch(&self) -> Vec<PriorityMessage> {
541 let mut batch = Vec::new();
542
543 for _ in 0..self.batch_size {
544 if let Some(message) = self.queue.dequeue_timeout(Duration::from_millis(100)).await {
545 batch.push(message);
546 } else {
547 break; }
549 }
550
551 debug!(batch_size = batch.len(), "Consumed priority message batch");
552 batch
553 }
554
555 pub async fn consume_one(&self) -> Option<PriorityMessage> {
557 self.queue.dequeue_timeout(self.processing_timeout).await
558 }
559
560 pub async fn start_consuming<F, Fut>(&self, mut handler: F) -> Result<()>
562 where
563 F: FnMut(PriorityMessage) -> Fut + Send,
564 Fut: std::future::Future<Output = Result<()>> + Send,
565 {
566 info!("Starting priority consumer");
567
568 loop {
569 if let Some(message) = self.queue.dequeue_async().await {
570 let message_id = message.message_id.clone();
571
572 debug!(
573 message_id = %message_id,
574 priority = ?message.priority,
575 "Processing priority message"
576 );
577
578 match handler(message.clone()).await {
579 Ok(()) => {
580 debug!(message_id = %message_id, "Message processed successfully");
581 }
582 Err(err) => {
583 error!(
584 message_id = %message_id,
585 error = %err,
586 "Message processing failed"
587 );
588
589 if let Err(requeue_err) = self.queue.requeue(message) {
591 error!(
592 message_id = %message_id,
593 error = %requeue_err,
594 "Failed to requeue message"
595 );
596 }
597 }
598 }
599 }
600 }
601 }
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// Builds a message holding a copy of `payload` at `priority`.
    fn msg(payload: &[u8], priority: Priority) -> PriorityMessage {
        PriorityMessage::new(payload.to_vec(), priority)
    }

    #[test]
    fn test_priority_ordering() {
        assert!(Priority::Critical > Priority::High);
        assert!(Priority::High > Priority::Normal);
        assert!(Priority::Normal > Priority::Low);
    }

    #[test]
    fn test_priority_from_value() {
        assert_eq!(Priority::from_value(1), Priority::Low);
        assert_eq!(Priority::from_value(5), Priority::Normal);
        assert_eq!(Priority::from_value(8), Priority::High);
        assert_eq!(Priority::from_value(10), Priority::Critical);
    }

    #[test]
    fn test_priority_from_value_band_boundaries() {
        let cases = [
            (0, Priority::Low),
            (2, Priority::Low),
            (3, Priority::Normal),
            (6, Priority::Normal),
            (7, Priority::High),
            (9, Priority::High),
            (10, Priority::Critical),
            (u8::MAX, Priority::Critical),
        ];
        for (value, expected) in cases {
            assert_eq!(Priority::from_value(value), expected, "value {value}");
        }
    }

    #[test]
    fn test_priority_value_round_trips() {
        for priority in [
            Priority::Low,
            Priority::Normal,
            Priority::High,
            Priority::Critical,
        ] {
            assert_eq!(Priority::from_value(priority.value()), priority);
        }
    }

    #[tokio::test]
    async fn test_priority_queue_ordering() {
        let queue = PriorityQueue::new(PriorityQueueConfig::default());

        for (payload, priority) in [
            ("low", Priority::Low),
            ("critical", Priority::Critical),
            ("normal", Priority::Normal),
            ("high", Priority::High),
        ] {
            queue.enqueue(msg(payload.as_bytes(), priority)).unwrap();
        }

        for expected in [
            Priority::Critical,
            Priority::High,
            Priority::Normal,
            Priority::Low,
        ] {
            assert_eq!(queue.dequeue().unwrap().priority, expected);
        }
    }

    #[tokio::test]
    async fn test_fifo_within_same_priority() {
        let queue = PriorityQueue::new(PriorityQueueConfig::default());
        let payloads = ["first", "second", "third"];

        // Short sleeps keep the enqueue instants distinct.
        for payload in payloads {
            queue
                .enqueue(msg(payload.as_bytes(), Priority::Normal))
                .unwrap();
            sleep(Duration::from_millis(1)).await;
        }

        for expected in payloads {
            assert_eq!(queue.dequeue().unwrap().payload, expected.as_bytes());
        }
    }

    #[tokio::test]
    async fn test_message_expiry() {
        let queue = PriorityQueue::new(PriorityQueueConfig::default());

        let expired_message = msg(b"expired", Priority::Normal)
            .with_expiry(Utc::now() - chrono::Duration::seconds(1));
        queue.enqueue(expired_message).unwrap();

        // The only message is expired, so nothing is handed out...
        assert!(queue.dequeue().is_none());
        // ...and it ends up in the dead-letter queue instead.
        assert_eq!(queue.dead_letter_messages().len(), 1);
    }

    #[tokio::test]
    async fn test_retry_logic() {
        let queue = PriorityQueue::new(PriorityQueueConfig::default());
        let message = msg(b"retry", Priority::Normal).with_max_retries(2);

        queue.requeue(message).unwrap();
        assert_eq!(queue.size(), 1);
        let requeued = queue.dequeue().unwrap();
        assert_eq!(requeued.retry_count, 1);

        queue.requeue(requeued).unwrap();
        assert_eq!(queue.size(), 1);
        let requeued = queue.dequeue().unwrap();
        assert_eq!(requeued.retry_count, 2);

        // Budget exhausted: the message goes to the dead-letter queue.
        queue.requeue(requeued).unwrap();
        assert_eq!(queue.size(), 0);
        assert_eq!(queue.dead_letter_messages().len(), 1);
    }

    #[tokio::test]
    async fn test_stats_track_enqueue_and_dequeue() {
        let queue = PriorityQueue::new(PriorityQueueConfig::default());
        queue.enqueue(msg(b"a", Priority::High)).unwrap();
        queue.enqueue(msg(b"b", Priority::Low)).unwrap();

        let stats = queue.stats();
        assert_eq!(stats.total_messages, 2);
        assert_eq!(stats.messages_by_priority.get(&Priority::High), Some(&1));
        assert_eq!(stats.messages_by_priority.get(&Priority::Low), Some(&1));

        assert_eq!(queue.dequeue().unwrap().priority, Priority::High);

        let stats = queue.stats();
        assert_eq!(stats.total_messages, 1);
        assert_eq!(stats.messages_by_priority.get(&Priority::High), Some(&0));
        assert_eq!(stats.messages_by_priority.get(&Priority::Low), Some(&1));
    }

    #[tokio::test]
    async fn test_clear_dead_letter_resets_count() {
        let queue = PriorityQueue::new(PriorityQueueConfig::default());

        queue
            .requeue(msg(b"dead", Priority::High).with_max_retries(0))
            .unwrap();
        assert_eq!(queue.stats().dead_letter_count, 1);

        assert_eq!(queue.clear_dead_letter(), 1);
        assert_eq!(queue.stats().dead_letter_count, 0);
        assert!(queue.dead_letter_messages().is_empty());
    }

    #[tokio::test]
    async fn test_dead_letter_disabled_discards_messages() {
        let config = PriorityQueueConfig {
            dead_letter_enabled: false,
            ..Default::default()
        };
        let queue = PriorityQueue::new(config);

        queue
            .requeue(msg(b"gone", Priority::Low).with_max_retries(0))
            .unwrap();

        assert!(queue.dead_letter_messages().is_empty());
        assert_eq!(queue.stats().dead_letter_count, 0);
        assert!(queue.is_empty());
    }

    #[tokio::test]
    async fn test_dequeue_timeout_on_empty_queue() {
        let queue = PriorityQueue::new(PriorityQueueConfig::default());
        assert!(queue
            .dequeue_timeout(Duration::from_millis(10))
            .await
            .is_none());
        assert!(queue.is_empty());
    }

    #[tokio::test]
    async fn test_priority_router() {
        let mut router = PriorityRouter::new("default".to_string());

        let config = PriorityQueueConfig::default();
        let queue1 = Arc::new(PriorityQueue::new(config.clone()));
        let queue2 = Arc::new(PriorityQueue::new(config));

        router.add_queue("queue1".to_string(), queue1.clone());
        router.add_queue("queue2".to_string(), queue2.clone());

        router
            .route_message(Some("queue1".to_string()), msg(b"msg1", Priority::High))
            .unwrap();
        router
            .route_message(Some("queue2".to_string()), msg(b"msg2", Priority::Normal))
            .unwrap();

        // Each message landed only in the queue it was routed to.
        assert_eq!(queue1.size(), 1);
        assert_eq!(queue2.size(), 1);

        assert_eq!(queue1.dequeue().unwrap().payload, b"msg1");
        assert_eq!(queue2.dequeue().unwrap().payload, b"msg2");
    }

    #[tokio::test]
    async fn test_priority_consumer() {
        let queue = Arc::new(PriorityQueue::new(PriorityQueueConfig::default()));
        let consumer = PriorityConsumer::new(queue.clone()).with_batch_size(2);

        queue.enqueue(msg(b"msg1", Priority::High)).unwrap();
        queue.enqueue(msg(b"msg2", Priority::Normal)).unwrap();

        let batch = consumer.consume_batch().await;
        assert_eq!(batch.len(), 2);
        // Batches preserve priority order.
        assert_eq!(batch[0].priority, Priority::High);
        assert_eq!(batch[1].priority, Priority::Normal);
    }

    #[tokio::test]
    async fn test_queue_full_behavior() {
        let config = PriorityQueueConfig {
            max_queue_size: 2,
            ..Default::default()
        };
        let queue = PriorityQueue::new(config);

        queue.enqueue(msg(b"msg1", Priority::Normal)).unwrap();
        queue.enqueue(msg(b"msg2", Priority::Normal)).unwrap();

        // A third message exceeds capacity and is rejected.
        assert!(queue.enqueue(msg(b"msg3", Priority::Normal)).is_err());
    }
}