1use crate::{Task, TaskId, TaskMetadata, TaskQueueError, TaskWrapper};
2use deadpool_redis::{Config, Pool, Runtime};
3use redis::AsyncCommands;
4use serde::{Deserialize, Serialize};
5
/// Redis-backed message broker for the task queue.
///
/// Wraps a deadpool connection pool; all operations check out a pooled
/// connection per call, so the broker itself is cheap to share.
pub struct RedisBroker {
    // Crate-visible so sibling modules (and tests) can issue raw commands.
    pub(crate) pool: Pool,
}
9
10impl RedisBroker {
    /// Create a broker with deadpool's default pool configuration.
    ///
    /// Delegates to [`Self::new_with_config`]; see it for error cases.
    pub async fn new(redis_url: &str) -> Result<Self, TaskQueueError> {
        Self::new_with_config(redis_url, None).await
    }
14
15 pub async fn new_with_config(
16 redis_url: &str,
17 pool_size: Option<usize>,
18 ) -> Result<Self, TaskQueueError> {
19 let mut config = Config::from_url(redis_url);
20 if let Some(size) = pool_size {
21 config.pool = Some(deadpool_redis::PoolConfig::new(size));
22 }
23
24 let pool = config.create_pool(Some(Runtime::Tokio1)).map_err(|e| {
25 TaskQueueError::Connection(format!("Failed to create Redis pool: {}", e))
26 })?;
27
28 let mut conn = pool.get().await.map_err(|e| {
30 TaskQueueError::Connection(format!("Failed to connect to Redis: {}", e))
31 })?;
32
33 redis::cmd("PING")
35 .query_async::<_, String>(&mut conn)
36 .await
37 .map_err(|e| {
38 TaskQueueError::Connection(format!("Redis connection test failed: {}", e))
39 })?;
40
41 Ok(Self { pool })
42 }
43
44 async fn get_conn(&self) -> Result<deadpool_redis::Connection, TaskQueueError> {
45 self.pool
46 .get()
47 .await
48 .map_err(|e| TaskQueueError::Connection(e.to_string()))
49 }
50
51 pub async fn enqueue_task<T: Task>(
52 &self,
53 task: T,
54 queue: &str,
55 ) -> Result<TaskId, TaskQueueError> {
56 let task_name = task.name();
57 let priority = task.priority();
58 let max_retries = task.max_retries();
59 let timeout_seconds = task.timeout_seconds();
60 let enqueue_start = std::time::Instant::now();
61
62 #[cfg(feature = "tracing")]
63 tracing::info!(
64 task_name = task_name,
65 queue = queue,
66 priority = ?priority,
67 max_retries = max_retries,
68 timeout_seconds = timeout_seconds,
69 "Enqueuing task"
70 );
71
72 let task_id = TaskId::new_v4();
73
74 let metadata = TaskMetadata {
76 id: task_id,
77 name: task.name().to_string(),
78 created_at: chrono::Utc::now(),
79 attempts: 0,
80 max_retries: task.max_retries(),
81 timeout_seconds: task.timeout_seconds(),
82 };
83
84 let payload = rmp_serde::to_vec(&task)?;
86 let payload_len = payload.len(); let task_wrapper = TaskWrapper {
89 metadata: metadata.clone(),
90 payload,
91 };
92
93 self.enqueue_task_wrapper(task_wrapper, queue).await?;
94
95 #[cfg(feature = "tracing")]
96 tracing::info!(
97 task_id = %task_id,
98 task_name = task_name,
99 queue = queue,
100 priority = ?priority,
101 duration_ms = enqueue_start.elapsed().as_millis(),
102 payload_size_bytes = payload_len,
103 "Task enqueued successfully"
104 );
105
106 Ok(task_id)
107 }
108
109 fn validate_queue_name(queue: &str) -> Result<(), TaskQueueError> {
111 if queue.is_empty() {
112 return Err(TaskQueueError::Queue(
113 "Queue name cannot be empty".to_string(),
114 ));
115 }
116
117 if queue.len() > 255 {
118 return Err(TaskQueueError::Queue(
119 "Queue name too long (max 255 characters)".to_string(),
120 ));
121 }
122
123 if !queue
125 .chars()
126 .all(|c| c.is_alphanumeric() || matches!(c, '-' | '_' | ':'))
127 {
128 return Err(TaskQueueError::Queue(
129 "Queue name contains invalid characters. Only alphanumeric, dash, underscore, and colon allowed".to_string()
130 ));
131 }
132
133 let lowercase = queue.to_lowercase();
135 let dangerous_patterns = [
136 "eval",
137 "script",
138 "flushall",
139 "flushdb",
140 "shutdown",
141 "debug",
142 "config",
143 "info",
144 "monitor",
145 "sync",
146 "psync",
147 "slaveof",
148 "replicaof",
149 ];
150
151 for pattern in dangerous_patterns {
152 if lowercase.contains(pattern) {
153 return Err(TaskQueueError::Queue(format!(
154 "Queue name contains potentially dangerous pattern: {}",
155 pattern
156 )));
157 }
158 }
159
160 Ok(())
161 }
162
163 fn validate_task_payload(payload: &[u8]) -> Result<(), TaskQueueError> {
165 const MAX_PAYLOAD_SIZE: usize = 16 * 1024 * 1024; if payload.len() > MAX_PAYLOAD_SIZE {
168 return Err(TaskQueueError::TaskExecution(format!(
169 "Task payload too large: {} bytes (max: {} bytes)",
170 payload.len(),
171 MAX_PAYLOAD_SIZE
172 )));
173 }
174
175 if payload.is_empty() {
177 return Err(TaskQueueError::TaskExecution(
178 "Task payload cannot be empty".to_string(),
179 ));
180 }
181
182 Ok(())
183 }
184
185 pub async fn enqueue_task_wrapper(
187 &self,
188 task_wrapper: TaskWrapper,
189 queue: &str,
190 ) -> Result<TaskId, TaskQueueError> {
191 let operation_start = std::time::Instant::now();
192 let task_id = task_wrapper.metadata.id;
193 let task_name = &task_wrapper.metadata.name;
194
195 #[cfg(feature = "tracing")]
196 tracing::debug!(
197 task_id = %task_id,
198 task_name = task_name,
199 queue = queue,
200 attempts = task_wrapper.metadata.attempts,
201 max_retries = task_wrapper.metadata.max_retries,
202 "Enqueuing task wrapper"
203 );
204
205 Self::validate_queue_name(queue)?;
207
208 let serialized = rmp_serde::to_vec(&task_wrapper)?;
210 Self::validate_task_payload(&serialized)?;
211
212 if task_wrapper.metadata.name.is_empty() {
214 #[cfg(feature = "tracing")]
215 tracing::error!(
216 task_id = %task_id,
217 "Task name validation failed: empty name"
218 );
219 return Err(TaskQueueError::TaskExecution(
220 "Task name cannot be empty".to_string(),
221 ));
222 }
223
224 if task_wrapper.metadata.name.len() > 255 {
225 #[cfg(feature = "tracing")]
226 tracing::error!(
227 task_id = %task_id,
228 task_name = task_name,
229 name_length = task_wrapper.metadata.name.len(),
230 "Task name validation failed: name too long"
231 );
232 return Err(TaskQueueError::TaskExecution(
233 "Task name too long (max 255 characters)".to_string(),
234 ));
235 }
236
237 let mut conn = self.get_conn().await?;
238
239 let _pipeline_result: Vec<()> = redis::pipe()
241 .atomic()
242 .lpush(queue, &serialized)
244 .set_ex(
246 format!("task:{}:metadata", task_wrapper.metadata.id),
247 rmp_serde::to_vec(&task_wrapper.metadata)?,
248 3600, )
250 .query_async(&mut *conn)
251 .await?;
252
253 #[cfg(feature = "tracing")]
254 tracing::info!(
255 task_id = %task_wrapper.metadata.id,
256 task_name = task_name,
257 queue = queue,
258 duration_ms = operation_start.elapsed().as_millis(),
259 payload_size_bytes = serialized.len(),
260 metadata_ttl_seconds = 3600,
261 "Task wrapper enqueued successfully using pipeline"
262 );
263
264 Ok(task_wrapper.metadata.id)
265 }
266
267 pub async fn enqueue_tasks_batch<T: Task>(
269 &self,
270 tasks: Vec<(T, &str)>,
271 ) -> Result<Vec<TaskId>, TaskQueueError> {
272 if tasks.is_empty() {
273 #[cfg(feature = "tracing")]
274 tracing::warn!("Batch enqueue called with empty task list");
275 return Ok(Vec::new());
276 }
277
278 let batch_start = std::time::Instant::now();
279 let batch_size = tasks.len();
280
281 #[cfg(feature = "tracing")]
282 tracing::info!(batch_size = batch_size, "Starting batch enqueue operation");
283
284 let mut conn = self.get_conn().await?;
285 let mut pipeline = redis::pipe();
286 pipeline.atomic();
287
288 let mut task_ids = Vec::with_capacity(tasks.len());
289 let mut queue_distribution = std::collections::HashMap::new();
290
291 for (task, queue) in tasks {
292 Self::validate_queue_name(queue)?;
294
295 let task_id = TaskId::new_v4();
296 task_ids.push(task_id);
297
298 *queue_distribution.entry(queue.to_string()).or_insert(0) += 1;
300
301 let metadata = TaskMetadata {
302 id: task_id,
303 name: task.name().to_string(),
304 created_at: chrono::Utc::now(),
305 attempts: 0,
306 max_retries: task.max_retries(),
307 timeout_seconds: task.timeout_seconds(),
308 };
309
310 let payload = rmp_serde::to_vec(&task)?;
311 let payload_len = payload.len();
312 let task_wrapper = TaskWrapper {
313 metadata: metadata.clone(),
314 payload,
315 };
316 let serialized = rmp_serde::to_vec(&task_wrapper)?;
317
318 Self::validate_task_payload(&serialized)?;
320
321 pipeline.lpush(queue, &serialized);
323 pipeline.set_ex(
324 format!("task:{}:metadata", task_id),
325 rmp_serde::to_vec(&metadata)?,
326 3600,
327 );
328
329 #[cfg(feature = "tracing")]
330 tracing::debug!(
331 task_id = %task_id,
332 task_name = task.name(),
333 queue = queue,
334 payload_size_bytes = payload_len,
335 "Task added to batch pipeline"
336 );
337 }
338
339 let _: Vec<()> = pipeline.query_async(&mut *conn).await?;
341
342 #[cfg(feature = "tracing")]
343 tracing::info!(
344 batch_size = task_ids.len(),
345 duration_ms = batch_start.elapsed().as_millis(),
346 queue_distribution = ?queue_distribution,
347 total_task_ids = task_ids.len(),
348 "Batch enqueue completed successfully using pipeline"
349 );
350
351 Ok(task_ids)
352 }
353
354 pub async fn dequeue_task(
355 &self,
356 queues: &[String],
357 ) -> Result<Option<TaskWrapper>, TaskQueueError> {
358 let dequeue_start = std::time::Instant::now();
359
360 #[cfg(feature = "tracing")]
361 tracing::debug!(
362 queues = ?queues,
363 queue_count = queues.len(),
364 "Starting dequeue operation"
365 );
366
367 let mut conn = self.get_conn().await?;
368
369 let result: Option<(String, Vec<u8>)> = conn.brpop(queues, 5f64).await?;
371
372 if let Some((queue, serialized)) = result {
373 let task_wrapper: TaskWrapper = rmp_serde::from_slice(&serialized)?;
374
375 #[cfg(feature = "tracing")]
376 tracing::info!(
377 task_id = %task_wrapper.metadata.id,
378 task_name = %task_wrapper.metadata.name,
379 queue = queue,
380 duration_ms = dequeue_start.elapsed().as_millis(),
381 payload_size_bytes = serialized.len(),
382 attempts = task_wrapper.metadata.attempts,
383 max_retries = task_wrapper.metadata.max_retries,
384 created_at = %task_wrapper.metadata.created_at,
385 "Task dequeued successfully"
386 );
387
388 Ok(Some(task_wrapper))
389 } else {
390 #[cfg(feature = "tracing")]
391 tracing::trace!(
392 duration_ms = dequeue_start.elapsed().as_millis(),
393 queues = ?queues,
394 "Dequeue operation timed out - no tasks available"
395 );
396 Ok(None)
397 }
398 }
399
400 pub async fn get_queue_size(&self, queue: &str) -> Result<i64, TaskQueueError> {
401 let mut conn = self.get_conn().await?;
402 let size: i64 = conn.llen(queue).await?;
403 Ok(size)
404 }
405
406 pub async fn get_queue_metrics(&self, queue: &str) -> Result<QueueMetrics, TaskQueueError> {
407 let operation_start = std::time::Instant::now();
408
409 #[cfg(feature = "tracing")]
410 tracing::debug!(queue = queue, "Retrieving queue metrics");
411
412 let mut conn = self.get_conn().await?;
413
414 let size: i64 = conn.llen(queue).await?;
415 let processed_key = format!("queue:{}:processed", queue);
416 let failed_key = format!("queue:{}:failed", queue);
417
418 let processed: i64 = conn.get(&processed_key).await.unwrap_or(0);
419 let failed: i64 = conn.get(&failed_key).await.unwrap_or(0);
420
421 let metrics = QueueMetrics {
422 queue_name: queue.to_string(),
423 pending_tasks: size,
424 processed_tasks: processed,
425 failed_tasks: failed,
426 };
427
428 #[cfg(feature = "tracing")]
429 tracing::debug!(
430 queue = queue,
431 pending_tasks = metrics.pending_tasks,
432 processed_tasks = metrics.processed_tasks,
433 failed_tasks = metrics.failed_tasks,
434 total_tasks = metrics.pending_tasks + metrics.processed_tasks + metrics.failed_tasks,
435 success_rate = if (metrics.processed_tasks + metrics.failed_tasks) > 0 {
436 metrics.processed_tasks as f64
437 / (metrics.processed_tasks + metrics.failed_tasks) as f64
438 } else {
439 0.0
440 },
441 duration_ms = operation_start.elapsed().as_millis(),
442 "Queue metrics retrieved"
443 );
444
445 Ok(metrics)
446 }
447
448 pub async fn mark_task_completed(
449 &self,
450 task_id: TaskId,
451 queue: &str,
452 ) -> Result<(), TaskQueueError> {
453 let operation_start = std::time::Instant::now();
454
455 #[cfg(feature = "tracing")]
456 tracing::debug!(
457 task_id = %task_id,
458 queue = queue,
459 "Marking task as completed"
460 );
461
462 let mut conn = self.get_conn().await?;
463 let processed_key = format!("queue:{}:processed", queue);
464 conn.incr::<_, _, ()>(&processed_key, 1).await?;
465
466 let metadata_key = format!("task:{}:metadata", task_id);
468 conn.del::<_, ()>(&metadata_key).await?;
469
470 #[cfg(feature = "tracing")]
471 tracing::info!(
472 task_id = %task_id,
473 queue = queue,
474 duration_ms = operation_start.elapsed().as_millis(),
475 processed_key = processed_key,
476 "Task marked as completed successfully"
477 );
478
479 Ok(())
480 }
481
482 pub async fn mark_task_failed_with_reason(
483 &self,
484 task_id: TaskId,
485 queue: &str,
486 reason: Option<String>,
487 ) -> Result<(), TaskQueueError> {
488 let mut conn = self.get_conn().await?;
489
490 let failed_key = format!("queue:{}:failed", queue);
492 conn.incr::<_, _, ()>(&failed_key, 1).await?;
493
494 let default_reason = reason.unwrap_or_else(|| "Unknown error".to_string());
495
496 let failure_key = format!("task:{}:failure", task_id);
498 let failure_info = TaskFailureInfo {
499 task_id,
500 queue: queue.to_string(),
501 failed_at: chrono::Utc::now().to_rfc3339(),
502 reason: default_reason.clone(),
503 status: "failed".to_string(),
504 };
505
506 conn.set::<_, _, ()>(&failure_key, rmp_serde::to_vec(&failure_info)?)
508 .await?;
509 conn.expire::<_, ()>(&failure_key, 86400).await?;
510
511 let queue_failed_set = format!("queue:{}:failed_tasks", queue);
513 conn.sadd::<_, _, ()>(&queue_failed_set, task_id.to_string())
514 .await?;
515 conn.expire::<_, ()>(&queue_failed_set, 86400).await?;
516
517 let metadata_key = format!("task:{}:metadata", task_id);
519 conn.del::<_, ()>(&metadata_key).await?;
520
521 #[cfg(feature = "tracing")]
522 tracing::warn!(
523 "Task {} marked as failed in queue {} - Reason: {}",
524 task_id,
525 queue,
526 default_reason
527 );
528
529 Ok(())
530 }
531
    /// Record a task failure with the default "Unknown error" reason.
    ///
    /// Shorthand for [`Self::mark_task_failed_with_reason`] with `None`.
    pub async fn mark_task_failed(
        &self,
        task_id: TaskId,
        queue: &str,
    ) -> Result<(), TaskQueueError> {
        self.mark_task_failed_with_reason(task_id, queue, None)
            .await
    }
541
542 pub async fn get_active_workers(&self) -> Result<i64, TaskQueueError> {
543 let mut conn = self.get_conn().await?;
544 let count: i64 = conn.scard("active_workers").await?;
545 Ok(count)
546 }
547
548 pub async fn register_worker(&self, worker_id: &str) -> Result<(), TaskQueueError> {
549 let operation_start = std::time::Instant::now();
550
551 #[cfg(feature = "tracing")]
552 tracing::info!(worker_id = worker_id, "Registering worker");
553
554 let mut conn = self.get_conn().await?;
555 conn.sadd::<_, _, ()>("active_workers", worker_id).await?;
556
557 let heartbeat_key = format!("worker:{}:heartbeat", worker_id);
559 let heartbeat_timestamp = chrono::Utc::now().to_rfc3339();
560 conn.set::<_, _, ()>(&heartbeat_key, &heartbeat_timestamp)
561 .await?;
562 conn.expire::<_, ()>(&heartbeat_key, 60).await?;
563
564 let active_workers = self.get_active_workers().await.unwrap_or(0);
565
566 #[cfg(feature = "tracing")]
567 tracing::info!(
568 worker_id = worker_id,
569 duration_ms = operation_start.elapsed().as_millis(),
570 heartbeat_key = heartbeat_key,
571 heartbeat_timestamp = heartbeat_timestamp,
572 total_active_workers = active_workers,
573 "Worker registered successfully"
574 );
575
576 Ok(())
577 }
578
579 pub async fn unregister_worker(&self, worker_id: &str) -> Result<(), TaskQueueError> {
580 let operation_start = std::time::Instant::now();
581
582 #[cfg(feature = "tracing")]
583 tracing::info!(worker_id = worker_id, "Unregistering worker");
584
585 let mut conn = self.get_conn().await?;
586 conn.srem::<_, _, ()>("active_workers", worker_id).await?;
587
588 let heartbeat_key = format!("worker:{}:heartbeat", worker_id);
590 conn.del::<_, ()>(&heartbeat_key).await?;
591
592 let active_workers = self.get_active_workers().await.unwrap_or(0);
593
594 #[cfg(feature = "tracing")]
595 tracing::info!(
596 worker_id = worker_id,
597 duration_ms = operation_start.elapsed().as_millis(),
598 heartbeat_key = heartbeat_key,
599 total_active_workers = active_workers,
600 "Worker unregistered successfully"
601 );
602
603 Ok(())
604 }
605
606 pub async fn update_worker_heartbeat(&self, worker_id: &str) -> Result<(), TaskQueueError> {
607 let operation_start = std::time::Instant::now();
608
609 let mut conn = self.get_conn().await?;
610 let heartbeat_key = format!("worker:{}:heartbeat", worker_id);
611 let heartbeat_timestamp = chrono::Utc::now().to_rfc3339();
612 conn.set::<_, _, ()>(&heartbeat_key, &heartbeat_timestamp)
613 .await?;
614 conn.expire::<_, ()>(&heartbeat_key, 60).await?;
615
616 #[cfg(feature = "tracing")]
617 tracing::trace!(
618 worker_id = worker_id,
619 duration_ms = operation_start.elapsed().as_millis(),
620 heartbeat_timestamp = heartbeat_timestamp,
621 "Worker heartbeat updated"
622 );
623
624 Ok(())
625 }
626
627 pub async fn get_task_failure_info(
628 &self,
629 task_id: TaskId,
630 ) -> Result<Option<TaskFailureInfo>, TaskQueueError> {
631 let mut conn = self.get_conn().await?;
632 let failure_key = format!("task:{}:failure", task_id);
633
634 if let Ok(data) = conn.get::<_, Vec<u8>>(&failure_key).await {
635 match rmp_serde::from_slice::<TaskFailureInfo>(&data) {
636 Ok(info) => Ok(Some(info)),
637 Err(_) => Ok(None),
638 }
639 } else {
640 Ok(None)
641 }
642 }
643
644 pub async fn get_failed_tasks(&self, queue: &str) -> Result<Vec<String>, TaskQueueError> {
645 let mut conn = self.get_conn().await?;
646 let queue_failed_set = format!("queue:{}:failed_tasks", queue);
647 let failed_tasks: Vec<String> = conn.smembers(&queue_failed_set).await.unwrap_or_default();
648 Ok(failed_tasks)
649 }
650}
651
/// Record persisted under `task:{id}:failure` when a task is marked failed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskFailureInfo {
    pub task_id: TaskId,
    pub queue: String,
    // RFC 3339 timestamp string (see `mark_task_failed_with_reason`).
    pub failed_at: String,
    pub reason: String,
    // Currently always "failed" as written by `mark_task_failed_with_reason`.
    pub status: String,
}
660
/// Point-in-time counters for one queue, as returned by `get_queue_metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueueMetrics {
    pub queue_name: String,
    // Current list length of the queue.
    pub pending_tasks: i64,
    // Lifetime counters maintained by mark_task_completed / mark_task_failed.
    pub processed_tasks: i64,
    pub failed_tasks: i64,
}
668
669#[cfg(test)]
670mod tests {
671 use super::*;
672 use serde::{Deserialize, Serialize};
673
    // Minimal serializable task used by the integration tests below.
    #[derive(Debug, Serialize, Deserialize, Clone)]
    struct TestTask {
        data: String,
    }
678
    #[async_trait::async_trait]
    impl Task for TestTask {
        // Echoes the task's data back as bytes; never fails.
        async fn execute(&self) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
            Ok(self.data.as_bytes().to_vec())
        }

        fn name(&self) -> &str {
            "test_task"
        }
    }
689
690 fn get_test_redis_url() -> String {
691 use std::collections::hash_map::DefaultHasher;
693 use std::hash::{Hash, Hasher};
694
695 let mut hasher = DefaultHasher::new();
696 std::thread::current().id().hash(&mut hasher);
697 std::time::SystemTime::now()
698 .duration_since(std::time::UNIX_EPOCH)
699 .unwrap_or_default()
700 .as_nanos()
701 .hash(&mut hasher);
702
703 let db_num = (hasher.finish() % 16) as u8; std::env::var("REDIS_TEST_URL")
705 .unwrap_or_else(|_| format!("redis://127.0.0.1:6379/{}", db_num))
706 }
707
    // Construct a broker against the test Redis; panics (failing the test)
    // when Redis is unreachable.
    async fn create_test_broker() -> RedisBroker {
        let redis_url = get_test_redis_url();
        RedisBroker::new(&redis_url)
            .await
            .expect("Failed to create test broker")
    }
714
    // Best-effort FLUSHDB of the (per-test) database; wipes ALL keys in it,
    // which is safe only because each test targets its own db index.
    async fn cleanup_test_data(broker: &RedisBroker) {
        if let Ok(mut conn) = broker.get_conn().await {
            let _: Result<String, _> = redis::cmd("FLUSHDB").query_async(&mut conn).await;
            // Brief pause so the flush is visible before the test proceeds.
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
    }
722
    // Smoke test: the broker builds and hands out pooled connections.
    #[tokio::test]
    async fn test_broker_creation() {
        let broker = create_test_broker().await;
        cleanup_test_data(&broker).await;
        assert!(broker.get_conn().await.is_ok());

        cleanup_test_data(&broker).await;
    }
733
    // Same as above, but with an explicit pool size of 5.
    #[tokio::test]
    async fn test_broker_creation_with_config() {
        let redis_url = get_test_redis_url();
        let broker = RedisBroker::new_with_config(&redis_url, Some(5))
            .await
            .expect("Failed to create broker");
        cleanup_test_data(&broker).await;
        assert!(broker.get_conn().await.is_ok());

        cleanup_test_data(&broker).await;
    }
746
    // An unresolvable host must surface as a Connection error (the PING
    // probe in new_with_config fails fast).
    #[tokio::test]
    async fn test_broker_invalid_url() {
        let result = RedisBroker::new("redis://invalid-host:6379").await;
        assert!(result.is_err());

        if let Err(e) = result {
            assert!(matches!(e, TaskQueueError::Connection(_)));
        }
    }
756
    // Enqueue one task and check the queue length goes to 1.
    #[tokio::test]
    async fn test_enqueue_task() {
        let broker = create_test_broker().await;
        cleanup_test_data(&broker).await;
        let task = TestTask {
            data: "test data".to_string(),
        };

        let task_id = broker
            .enqueue_task(task, "test_queue")
            .await
            .expect("Failed to enqueue task");

        let queue_size = broker
            .get_queue_size("test_queue")
            .await
            .expect("Failed to get queue size");
        assert_eq!(queue_size, 1);

        assert!(!task_id.to_string().is_empty());

        cleanup_test_data(&broker).await;
    }
783
    // Round-trip: enqueue one task, dequeue it, verify its metadata, and
    // check the queue is drained afterwards.
    #[tokio::test]
    async fn test_dequeue_task() {
        let broker = create_test_broker().await;
        cleanup_test_data(&broker).await;
        let task = TestTask {
            data: "test data".to_string(),
        };

        let task_id = broker
            .enqueue_task(task, "test_queue")
            .await
            .expect("Failed to enqueue task");

        let queue_size = broker
            .get_queue_size("test_queue")
            .await
            .expect("Failed to get queue size");
        assert_eq!(
            queue_size, 1,
            "Queue should have exactly 1 task before dequeue"
        );

        let queues = vec!["test_queue".to_string()];
        let dequeued = broker
            .dequeue_task(&queues)
            .await
            .expect("Failed to dequeue task");

        assert!(dequeued.is_some(), "Should have dequeued a task");
        let task_wrapper = dequeued.unwrap();
        assert_eq!(
            task_wrapper.metadata.id, task_id,
            "Task ID should match the enqueued task"
        );
        assert_eq!(
            task_wrapper.metadata.name, "test_task",
            "Task name should match"
        );

        let queue_size = broker
            .get_queue_size("test_queue")
            .await
            .expect("Failed to get queue size");
        assert_eq!(queue_size, 0, "Queue should be empty after dequeue");

        cleanup_test_data(&broker).await;
    }
835
    // An empty queue should make dequeue_task block for its ~5 s BRPOP
    // timeout and then return None. This test is deliberately slow.
    #[tokio::test]
    async fn test_dequeue_from_empty_queue() {
        let broker = create_test_broker().await;
        cleanup_test_data(&broker).await;

        let queues = vec!["empty_queue".to_string()];

        let start = std::time::Instant::now();
        let result = broker
            .dequeue_task(&queues)
            .await
            .expect("Failed to dequeue from empty queue");
        let elapsed = start.elapsed();

        assert!(result.is_none());
        // Loose bounds around the 5-second BRPOP timeout.
        assert!(elapsed.as_secs() >= 4 && elapsed.as_secs() <= 6);

        cleanup_test_data(&broker).await;
    }
857
    // Metrics start at zero on a fresh database and pending_tasks tracks
    // the queue length after an enqueue.
    #[tokio::test]
    async fn test_queue_metrics() {
        let broker = create_test_broker().await;
        cleanup_test_data(&broker).await;
        let metrics = broker
            .get_queue_metrics("test_queue")
            .await
            .expect("Failed to get metrics");
        assert_eq!(metrics.pending_tasks, 0);
        assert_eq!(metrics.processed_tasks, 0);
        assert_eq!(metrics.failed_tasks, 0);

        let task = TestTask {
            data: "test".to_string(),
        };
        broker
            .enqueue_task(task, "test_queue")
            .await
            .expect("Failed to enqueue task");

        let metrics = broker
            .get_queue_metrics("test_queue")
            .await
            .expect("Failed to get metrics");
        assert_eq!(metrics.pending_tasks, 1);

        cleanup_test_data(&broker).await;
    }
889
    // mark_task_completed must bump the queue's processed counter.
    #[tokio::test]
    async fn test_mark_task_completed() {
        let broker = create_test_broker().await;
        cleanup_test_data(&broker).await;

        let task = TestTask {
            data: "test".to_string(),
        };
        let task_id = broker
            .enqueue_task(task, "test_queue")
            .await
            .expect("Failed to enqueue task");

        broker
            .mark_task_completed(task_id, "test_queue")
            .await
            .expect("Failed to mark completed");

        let metrics = broker
            .get_queue_metrics("test_queue")
            .await
            .expect("Failed to get metrics");
        assert_eq!(metrics.processed_tasks, 1);

        cleanup_test_data(&broker).await;
    }
917
    // mark_task_failed must bump the queue's failed counter.
    #[tokio::test]
    async fn test_mark_task_failed() {
        let broker = create_test_broker().await;
        cleanup_test_data(&broker).await;

        let task = TestTask {
            data: "test".to_string(),
        };
        let task_id = broker
            .enqueue_task(task, "test_queue")
            .await
            .expect("Failed to enqueue task");

        broker
            .mark_task_failed(task_id, "test_queue")
            .await
            .expect("Failed to mark failed");

        let metrics = broker
            .get_queue_metrics("test_queue")
            .await
            .expect("Failed to get metrics");
        assert_eq!(metrics.failed_tasks, 1);

        cleanup_test_data(&broker).await;
    }
945
    // A custom failure reason must round-trip through the stored
    // TaskFailureInfo record.
    #[tokio::test]
    async fn test_mark_task_failed_with_reason() {
        let broker = create_test_broker().await;
        cleanup_test_data(&broker).await;

        let task = TestTask {
            data: "test".to_string(),
        };
        let task_id = broker
            .enqueue_task(task, "test_queue")
            .await
            .expect("Failed to enqueue task");

        let reason = "Custom failure reason".to_string();
        broker
            .mark_task_failed_with_reason(task_id, "test_queue", Some(reason.clone()))
            .await
            .expect("Failed to mark failed with reason");

        let failure_info = broker
            .get_task_failure_info(task_id)
            .await
            .expect("Failed to get failure info");
        assert!(failure_info.is_some());

        let info = failure_info.unwrap();
        assert_eq!(info.task_id, task_id);
        assert_eq!(info.queue, "test_queue");
        assert_eq!(info.reason, reason);
        assert_eq!(info.status, "failed");

        cleanup_test_data(&broker).await;
    }
980
    // Full worker lifecycle: register, heartbeat, unregister; the active
    // worker count must follow along.
    #[tokio::test]
    async fn test_worker_registration() {
        let broker = create_test_broker().await;
        cleanup_test_data(&broker).await;
        let worker_id = "test_worker_001";

        broker
            .register_worker(worker_id)
            .await
            .expect("Failed to register worker");

        let active_workers = broker
            .get_active_workers()
            .await
            .expect("Failed to get active workers");
        assert_eq!(active_workers, 1);

        broker
            .update_worker_heartbeat(worker_id)
            .await
            .expect("Failed to update heartbeat");

        broker
            .unregister_worker(worker_id)
            .await
            .expect("Failed to unregister worker");

        let active_workers = broker
            .get_active_workers()
            .await
            .expect("Failed to get active workers");
        assert_eq!(active_workers, 0);

        cleanup_test_data(&broker).await;
    }
1020
    // Registering five distinct worker ids yields an active count of five.
    #[tokio::test]
    async fn test_multiple_workers() {
        let broker = create_test_broker().await;
        cleanup_test_data(&broker).await;

        for i in 0..5 {
            let worker_id = format!("worker_{}", i);
            broker
                .register_worker(&worker_id)
                .await
                .expect("Failed to register worker");
        }

        let active_workers = broker
            .get_active_workers()
            .await
            .expect("Failed to get active workers");
        assert_eq!(active_workers, 5);

        cleanup_test_data(&broker).await;
    }
1043
    // Each failed task id must land in the queue's failed-task set.
    #[tokio::test]
    async fn test_failed_tasks_tracking() {
        let broker = create_test_broker().await;
        cleanup_test_data(&broker).await;

        for i in 0..3 {
            let task = TestTask {
                data: format!("task_{}", i),
            };
            let task_id = broker
                .enqueue_task(task, "test_queue")
                .await
                .expect("Failed to enqueue task");
            broker
                .mark_task_failed(task_id, "test_queue")
                .await
                .expect("Failed to mark failed");
        }

        let failed_tasks = broker
            .get_failed_tasks("test_queue")
            .await
            .expect("Failed to get failed tasks");
        assert_eq!(failed_tasks.len(), 3);

        cleanup_test_data(&broker).await;
    }
1072
    // Exercises all three counters at once: 3 pending on test_queue, plus
    // 2 completed and 1 failed (enqueued on temp_queue but accounted
    // against test_queue's counters).
    #[tokio::test]
    async fn test_queue_metrics_comprehensive() {
        let broker = create_test_broker().await;
        cleanup_test_data(&broker).await;

        for i in 0..3 {
            let task = TestTask {
                data: format!("pending_{}", i),
            };
            broker
                .enqueue_task(task, "test_queue")
                .await
                .expect("Failed to enqueue task");
        }

        for i in 0..2 {
            let task = TestTask {
                data: format!("processed_{}", i),
            };
            let task_id = broker
                .enqueue_task(task, "temp_queue")
                .await
                .expect("Failed to enqueue task");
            broker
                .mark_task_completed(task_id, "test_queue")
                .await
                .expect("Failed to mark completed");
        }

        for i in 0..1 {
            let task = TestTask {
                data: format!("failed_{}", i),
            };
            let task_id = broker
                .enqueue_task(task, "temp_queue")
                .await
                .expect("Failed to enqueue task");
            broker
                .mark_task_failed(task_id, "test_queue")
                .await
                .expect("Failed to mark failed");
        }

        let metrics = broker
            .get_queue_metrics("test_queue")
            .await
            .expect("Failed to get metrics");
        assert_eq!(metrics.pending_tasks, 3);
        assert_eq!(metrics.processed_tasks, 2);
        assert_eq!(metrics.failed_tasks, 1);
        assert_eq!(metrics.queue_name, "test_queue");

        cleanup_test_data(&broker).await;
    }
1130
    // TaskFailureInfo must survive a MessagePack round trip unchanged.
    #[tokio::test]
    async fn test_task_failure_info_serialization() {
        let task_id = TaskId::new_v4();
        let failure_info = TaskFailureInfo {
            task_id,
            queue: "test_queue".to_string(),
            failed_at: chrono::Utc::now().to_rfc3339(),
            reason: "Test failure".to_string(),
            status: "failed".to_string(),
        };

        let serialized = rmp_serde::to_vec(&failure_info).expect("Failed to serialize");
        let deserialized: TaskFailureInfo =
            rmp_serde::from_slice(&serialized).expect("Failed to deserialize");

        assert_eq!(deserialized.task_id, failure_info.task_id);
        assert_eq!(deserialized.queue, failure_info.queue);
        assert_eq!(deserialized.reason, failure_info.reason);
        assert_eq!(deserialized.status, failure_info.status);
    }
1152
    // QueueMetrics must survive a MessagePack round trip unchanged.
    #[tokio::test]
    async fn test_queue_metrics_serialization() {
        let metrics = QueueMetrics {
            queue_name: "test_queue".to_string(),
            pending_tasks: 10,
            processed_tasks: 100,
            failed_tasks: 5,
        };

        let serialized = rmp_serde::to_vec(&metrics).expect("Failed to serialize");
        let deserialized: QueueMetrics =
            rmp_serde::from_slice(&serialized).expect("Failed to deserialize");

        assert_eq!(deserialized.queue_name, metrics.queue_name);
        assert_eq!(deserialized.pending_tasks, metrics.pending_tasks);
        assert_eq!(deserialized.processed_tasks, metrics.processed_tasks);
        assert_eq!(deserialized.failed_tasks, metrics.failed_tasks);
    }
1172
    // enqueue_task_wrapper must honor a caller-supplied TaskId and push
    // the wrapper onto the queue.
    #[tokio::test]
    async fn test_enqueue_task_wrapper() {
        let broker = create_test_broker().await;
        cleanup_test_data(&broker).await;
        let task_id = TaskId::new_v4();
        let metadata = TaskMetadata {
            id: task_id,
            name: "custom_task".to_string(),
            created_at: chrono::Utc::now(),
            attempts: 0,
            max_retries: 5,
            timeout_seconds: 600,
        };

        let task_wrapper = TaskWrapper {
            metadata,
            payload: b"custom payload".to_vec(),
        };

        let returned_id = broker
            .enqueue_task_wrapper(task_wrapper, "test_queue")
            .await
            .expect("Failed to enqueue task wrapper");

        assert_eq!(returned_id, task_id);

        let queue_size = broker
            .get_queue_size("test_queue")
            .await
            .expect("Failed to get queue size");
        assert_eq!(queue_size, 1);

        cleanup_test_data(&broker).await;
    }
1208}