rust_task_queue/
broker.rs

1use crate::{Task, TaskId, TaskMetadata, TaskQueueError, TaskWrapper};
2use deadpool_redis::{Config, Pool, Runtime};
3use redis::AsyncCommands;
4use serde::{Deserialize, Serialize};
5
6pub struct RedisBroker {
7    pub(crate) pool: Pool,
8}
9
10impl RedisBroker {
11    pub async fn new(redis_url: &str) -> Result<Self, TaskQueueError> {
12        Self::new_with_config(redis_url, None).await
13    }
14
15    pub async fn new_with_config(
16        redis_url: &str,
17        pool_size: Option<usize>,
18    ) -> Result<Self, TaskQueueError> {
19        let mut config = Config::from_url(redis_url);
20        if let Some(size) = pool_size {
21            config.pool = Some(deadpool_redis::PoolConfig::new(size));
22        }
23
24        let pool = config.create_pool(Some(Runtime::Tokio1)).map_err(|e| {
25            TaskQueueError::Connection(format!("Failed to create Redis pool: {}", e))
26        })?;
27
28        // Test connection
29        let mut conn = pool.get().await.map_err(|e| {
30            TaskQueueError::Connection(format!("Failed to connect to Redis: {}", e))
31        })?;
32
33        // Verify Redis connection with a simple ping
34        redis::cmd("PING")
35            .query_async::<_, String>(&mut conn)
36            .await
37            .map_err(|e| {
38                TaskQueueError::Connection(format!("Redis connection test failed: {}", e))
39            })?;
40
41        Ok(Self { pool })
42    }
43
44    async fn get_conn(&self) -> Result<deadpool_redis::Connection, TaskQueueError> {
45        self.pool
46            .get()
47            .await
48            .map_err(|e| TaskQueueError::Connection(e.to_string()))
49    }
50
51    pub async fn enqueue_task<T: Task>(
52        &self,
53        task: T,
54        queue: &str,
55    ) -> Result<TaskId, TaskQueueError> {
56        let task_name = task.name();
57        let priority = task.priority();
58        let max_retries = task.max_retries();
59        let timeout_seconds = task.timeout_seconds();
60        let enqueue_start = std::time::Instant::now();
61
62        #[cfg(feature = "tracing")]
63        tracing::info!(
64            task_name = task_name,
65            queue = queue,
66            priority = ?priority,
67            max_retries = max_retries,
68            timeout_seconds = timeout_seconds,
69            "Enqueuing task"
70        );
71
72        let task_id = TaskId::new_v4();
73
74        // Create task metadata
75        let metadata = TaskMetadata {
76            id: task_id,
77            name: task.name().to_string(),
78            created_at: chrono::Utc::now(),
79            attempts: 0,
80            max_retries: task.max_retries(),
81            timeout_seconds: task.timeout_seconds(),
82        };
83
84        // Serialize the task
85        let payload = rmp_serde::to_vec(&task)?;
86        let payload_len = payload.len(); // Capture length before move
87
88        let task_wrapper = TaskWrapper {
89            metadata: metadata.clone(),
90            payload,
91        };
92
93        self.enqueue_task_wrapper(task_wrapper, queue).await?;
94
95        #[cfg(feature = "tracing")]
96        tracing::info!(
97            task_id = %task_id,
98            task_name = task_name,
99            queue = queue,
100            priority = ?priority,
101            duration_ms = enqueue_start.elapsed().as_millis(),
102            payload_size_bytes = payload_len,
103            "Task enqueued successfully"
104        );
105
106        Ok(task_id)
107    }
108
109    /// Validate and sanitize queue name to prevent Redis injection
110    fn validate_queue_name(queue: &str) -> Result<(), TaskQueueError> {
111        if queue.is_empty() {
112            return Err(TaskQueueError::Queue(
113                "Queue name cannot be empty".to_string(),
114            ));
115        }
116
117        if queue.len() > 255 {
118            return Err(TaskQueueError::Queue(
119                "Queue name too long (max 255 characters)".to_string(),
120            ));
121        }
122
123        // Only allow alphanumeric, dash, underscore, and colon
124        if !queue
125            .chars()
126            .all(|c| c.is_alphanumeric() || matches!(c, '-' | '_' | ':'))
127        {
128            return Err(TaskQueueError::Queue(
129                "Queue name contains invalid characters. Only alphanumeric, dash, underscore, and colon allowed".to_string()
130            ));
131        }
132
133        // Prevent Redis command injection patterns
134        let lowercase = queue.to_lowercase();
135        let dangerous_patterns = [
136            "eval",
137            "script",
138            "flushall",
139            "flushdb",
140            "shutdown",
141            "debug",
142            "config",
143            "info",
144            "monitor",
145            "sync",
146            "psync",
147            "slaveof",
148            "replicaof",
149        ];
150
151        for pattern in dangerous_patterns {
152            if lowercase.contains(pattern) {
153                return Err(TaskQueueError::Queue(format!(
154                    "Queue name contains potentially dangerous pattern: {}",
155                    pattern
156                )));
157            }
158        }
159
160        Ok(())
161    }
162
163    /// Validate task payload size to prevent DoS
164    fn validate_task_payload(payload: &[u8]) -> Result<(), TaskQueueError> {
165        const MAX_PAYLOAD_SIZE: usize = 16 * 1024 * 1024; // 16MB limit
166
167        if payload.len() > MAX_PAYLOAD_SIZE {
168            return Err(TaskQueueError::TaskExecution(format!(
169                "Task payload too large: {} bytes (max: {} bytes)",
170                payload.len(),
171                MAX_PAYLOAD_SIZE
172            )));
173        }
174
175        // Check for malformed MessagePack data
176        if payload.is_empty() {
177            return Err(TaskQueueError::TaskExecution(
178                "Task payload cannot be empty".to_string(),
179            ));
180        }
181
182        Ok(())
183    }
184
185    /// Enhanced enqueue with security validation
186    pub async fn enqueue_task_wrapper(
187        &self,
188        task_wrapper: TaskWrapper,
189        queue: &str,
190    ) -> Result<TaskId, TaskQueueError> {
191        let operation_start = std::time::Instant::now();
192        let task_id = task_wrapper.metadata.id;
193        let task_name = &task_wrapper.metadata.name;
194
195        #[cfg(feature = "tracing")]
196        tracing::debug!(
197            task_id = %task_id,
198            task_name = task_name,
199            queue = queue,
200            attempts = task_wrapper.metadata.attempts,
201            max_retries = task_wrapper.metadata.max_retries,
202            "Enqueuing task wrapper"
203        );
204
205        // SECURITY: Validate inputs
206        Self::validate_queue_name(queue)?;
207
208        // Serialize and validate task wrapper
209        let serialized = rmp_serde::to_vec(&task_wrapper)?;
210        Self::validate_task_payload(&serialized)?;
211
212        // Validate metadata
213        if task_wrapper.metadata.name.is_empty() {
214            #[cfg(feature = "tracing")]
215            tracing::error!(
216                task_id = %task_id,
217                "Task name validation failed: empty name"
218            );
219            return Err(TaskQueueError::TaskExecution(
220                "Task name cannot be empty".to_string(),
221            ));
222        }
223
224        if task_wrapper.metadata.name.len() > 255 {
225            #[cfg(feature = "tracing")]
226            tracing::error!(
227                task_id = %task_id,
228                task_name = task_name,
229                name_length = task_wrapper.metadata.name.len(),
230                "Task name validation failed: name too long"
231            );
232            return Err(TaskQueueError::TaskExecution(
233                "Task name too long (max 255 characters)".to_string(),
234            ));
235        }
236
237        let mut conn = self.get_conn().await?;
238
239        // FIXED: Use Redis pipeline without manual queue size tracking to avoid inconsistencies
240        let _pipeline_result: Vec<()> = redis::pipe()
241            .atomic()
242            // Push to the queue (left push for FIFO with right pop)
243            .lpush(queue, &serialized)
244            // Store task metadata for tracking
245            .set_ex(
246                format!("task:{}:metadata", task_wrapper.metadata.id),
247                rmp_serde::to_vec(&task_wrapper.metadata)?,
248                3600, // 1 hour TTL
249            )
250            .query_async(&mut *conn)
251            .await?;
252
253        #[cfg(feature = "tracing")]
254        tracing::info!(
255            task_id = %task_wrapper.metadata.id,
256            task_name = task_name,
257            queue = queue,
258            duration_ms = operation_start.elapsed().as_millis(),
259            payload_size_bytes = serialized.len(),
260            metadata_ttl_seconds = 3600,
261            "Task wrapper enqueued successfully using pipeline"
262        );
263
264        Ok(task_wrapper.metadata.id)
265    }
266
267    /// Batch enqueue multiple tasks for better throughput
268    pub async fn enqueue_tasks_batch<T: Task>(
269        &self,
270        tasks: Vec<(T, &str)>,
271    ) -> Result<Vec<TaskId>, TaskQueueError> {
272        if tasks.is_empty() {
273            #[cfg(feature = "tracing")]
274            tracing::warn!("Batch enqueue called with empty task list");
275            return Ok(Vec::new());
276        }
277
278        let batch_start = std::time::Instant::now();
279        let batch_size = tasks.len();
280
281        #[cfg(feature = "tracing")]
282        tracing::info!(batch_size = batch_size, "Starting batch enqueue operation");
283
284        let mut conn = self.get_conn().await?;
285        let mut pipeline = redis::pipe();
286        pipeline.atomic();
287
288        let mut task_ids = Vec::with_capacity(tasks.len());
289        let mut queue_distribution = std::collections::HashMap::new();
290
291        for (task, queue) in tasks {
292            // Validate queue name for each task
293            Self::validate_queue_name(queue)?;
294
295            let task_id = TaskId::new_v4();
296            task_ids.push(task_id);
297
298            // Track queue distribution for logging
299            *queue_distribution.entry(queue.to_string()).or_insert(0) += 1;
300
301            let metadata = TaskMetadata {
302                id: task_id,
303                name: task.name().to_string(),
304                created_at: chrono::Utc::now(),
305                attempts: 0,
306                max_retries: task.max_retries(),
307                timeout_seconds: task.timeout_seconds(),
308            };
309
310            let payload = rmp_serde::to_vec(&task)?;
311            let payload_len = payload.len();
312            let task_wrapper = TaskWrapper {
313                metadata: metadata.clone(),
314                payload,
315            };
316            let serialized = rmp_serde::to_vec(&task_wrapper)?;
317
318            // Validate task payload
319            Self::validate_task_payload(&serialized)?;
320
321            // Add to pipeline - only LPUSH and metadata storage
322            pipeline.lpush(queue, &serialized);
323            pipeline.set_ex(
324                format!("task:{}:metadata", task_id),
325                rmp_serde::to_vec(&metadata)?,
326                3600,
327            );
328
329            #[cfg(feature = "tracing")]
330            tracing::debug!(
331                task_id = %task_id,
332                task_name = task.name(),
333                queue = queue,
334                payload_size_bytes = payload_len,
335                "Task added to batch pipeline"
336            );
337        }
338
339        // Execute all operations atomically
340        let _: Vec<()> = pipeline.query_async(&mut *conn).await?;
341
342        #[cfg(feature = "tracing")]
343        tracing::info!(
344            batch_size = task_ids.len(),
345            duration_ms = batch_start.elapsed().as_millis(),
346            queue_distribution = ?queue_distribution,
347            total_task_ids = task_ids.len(),
348            "Batch enqueue completed successfully using pipeline"
349        );
350
351        Ok(task_ids)
352    }
353
354    pub async fn dequeue_task(
355        &self,
356        queues: &[String],
357    ) -> Result<Option<TaskWrapper>, TaskQueueError> {
358        let dequeue_start = std::time::Instant::now();
359
360        #[cfg(feature = "tracing")]
361        tracing::debug!(
362            queues = ?queues,
363            queue_count = queues.len(),
364            "Starting dequeue operation"
365        );
366
367        let mut conn = self.get_conn().await?;
368
369        // Use BRPOP for blocking right pop (FIFO with LPUSH)
370        let result: Option<(String, Vec<u8>)> = conn.brpop(queues, 5f64).await?;
371
372        if let Some((queue, serialized)) = result {
373            let task_wrapper: TaskWrapper = rmp_serde::from_slice(&serialized)?;
374
375            #[cfg(feature = "tracing")]
376            tracing::info!(
377                task_id = %task_wrapper.metadata.id,
378                task_name = %task_wrapper.metadata.name,
379                queue = queue,
380                duration_ms = dequeue_start.elapsed().as_millis(),
381                payload_size_bytes = serialized.len(),
382                attempts = task_wrapper.metadata.attempts,
383                max_retries = task_wrapper.metadata.max_retries,
384                created_at = %task_wrapper.metadata.created_at,
385                "Task dequeued successfully"
386            );
387
388            Ok(Some(task_wrapper))
389        } else {
390            #[cfg(feature = "tracing")]
391            tracing::trace!(
392                duration_ms = dequeue_start.elapsed().as_millis(),
393                queues = ?queues,
394                "Dequeue operation timed out - no tasks available"
395            );
396            Ok(None)
397        }
398    }
399
400    pub async fn get_queue_size(&self, queue: &str) -> Result<i64, TaskQueueError> {
401        let mut conn = self.get_conn().await?;
402        let size: i64 = conn.llen(queue).await?;
403        Ok(size)
404    }
405
406    pub async fn get_queue_metrics(&self, queue: &str) -> Result<QueueMetrics, TaskQueueError> {
407        let operation_start = std::time::Instant::now();
408
409        #[cfg(feature = "tracing")]
410        tracing::debug!(queue = queue, "Retrieving queue metrics");
411
412        let mut conn = self.get_conn().await?;
413
414        let size: i64 = conn.llen(queue).await?;
415        let processed_key = format!("queue:{}:processed", queue);
416        let failed_key = format!("queue:{}:failed", queue);
417
418        let processed: i64 = conn.get(&processed_key).await.unwrap_or(0);
419        let failed: i64 = conn.get(&failed_key).await.unwrap_or(0);
420
421        let metrics = QueueMetrics {
422            queue_name: queue.to_string(),
423            pending_tasks: size,
424            processed_tasks: processed,
425            failed_tasks: failed,
426        };
427
428        #[cfg(feature = "tracing")]
429        tracing::debug!(
430            queue = queue,
431            pending_tasks = metrics.pending_tasks,
432            processed_tasks = metrics.processed_tasks,
433            failed_tasks = metrics.failed_tasks,
434            total_tasks = metrics.pending_tasks + metrics.processed_tasks + metrics.failed_tasks,
435            success_rate = if (metrics.processed_tasks + metrics.failed_tasks) > 0 {
436                metrics.processed_tasks as f64
437                    / (metrics.processed_tasks + metrics.failed_tasks) as f64
438            } else {
439                0.0
440            },
441            duration_ms = operation_start.elapsed().as_millis(),
442            "Queue metrics retrieved"
443        );
444
445        Ok(metrics)
446    }
447
448    pub async fn mark_task_completed(
449        &self,
450        task_id: TaskId,
451        queue: &str,
452    ) -> Result<(), TaskQueueError> {
453        let operation_start = std::time::Instant::now();
454
455        #[cfg(feature = "tracing")]
456        tracing::debug!(
457            task_id = %task_id,
458            queue = queue,
459            "Marking task as completed"
460        );
461
462        let mut conn = self.get_conn().await?;
463        let processed_key = format!("queue:{}:processed", queue);
464        conn.incr::<_, _, ()>(&processed_key, 1).await?;
465
466        // Remove task metadata
467        let metadata_key = format!("task:{}:metadata", task_id);
468        conn.del::<_, ()>(&metadata_key).await?;
469
470        #[cfg(feature = "tracing")]
471        tracing::info!(
472            task_id = %task_id,
473            queue = queue,
474            duration_ms = operation_start.elapsed().as_millis(),
475            processed_key = processed_key,
476            "Task marked as completed successfully"
477        );
478
479        Ok(())
480    }
481
482    pub async fn mark_task_failed_with_reason(
483        &self,
484        task_id: TaskId,
485        queue: &str,
486        reason: Option<String>,
487    ) -> Result<(), TaskQueueError> {
488        let mut conn = self.get_conn().await?;
489
490        // Increment the failed counter for queue metrics
491        let failed_key = format!("queue:{}:failed", queue);
492        conn.incr::<_, _, ()>(&failed_key, 1).await?;
493
494        let default_reason = reason.unwrap_or_else(|| "Unknown error".to_string());
495
496        // Store detailed failure information
497        let failure_key = format!("task:{}:failure", task_id);
498        let failure_info = TaskFailureInfo {
499            task_id,
500            queue: queue.to_string(),
501            failed_at: chrono::Utc::now().to_rfc3339(),
502            reason: default_reason.clone(),
503            status: "failed".to_string(),
504        };
505
506        // Store failure info with expiration
507        conn.set::<_, _, ()>(&failure_key, rmp_serde::to_vec(&failure_info)?)
508            .await?;
509        conn.expire::<_, ()>(&failure_key, 86400).await?;
510
511        // Add to failed tasks set for monitoring
512        let queue_failed_set = format!("queue:{}:failed_tasks", queue);
513        conn.sadd::<_, _, ()>(&queue_failed_set, task_id.to_string())
514            .await?;
515        conn.expire::<_, ()>(&queue_failed_set, 86400).await?;
516
517        // Clean up task metadata
518        let metadata_key = format!("task:{}:metadata", task_id);
519        conn.del::<_, ()>(&metadata_key).await?;
520
521        #[cfg(feature = "tracing")]
522        tracing::warn!(
523            "Task {} marked as failed in queue {} - Reason: {}",
524            task_id,
525            queue,
526            default_reason
527        );
528
529        Ok(())
530    }
531
532    // Keep the original method for backwards compatibility
533    pub async fn mark_task_failed(
534        &self,
535        task_id: TaskId,
536        queue: &str,
537    ) -> Result<(), TaskQueueError> {
538        self.mark_task_failed_with_reason(task_id, queue, None)
539            .await
540    }
541
542    pub async fn get_active_workers(&self) -> Result<i64, TaskQueueError> {
543        let mut conn = self.get_conn().await?;
544        let count: i64 = conn.scard("active_workers").await?;
545        Ok(count)
546    }
547
548    pub async fn register_worker(&self, worker_id: &str) -> Result<(), TaskQueueError> {
549        let operation_start = std::time::Instant::now();
550
551        #[cfg(feature = "tracing")]
552        tracing::info!(worker_id = worker_id, "Registering worker");
553
554        let mut conn = self.get_conn().await?;
555        conn.sadd::<_, _, ()>("active_workers", worker_id).await?;
556
557        // Set heartbeat
558        let heartbeat_key = format!("worker:{}:heartbeat", worker_id);
559        let heartbeat_timestamp = chrono::Utc::now().to_rfc3339();
560        conn.set::<_, _, ()>(&heartbeat_key, &heartbeat_timestamp)
561            .await?;
562        conn.expire::<_, ()>(&heartbeat_key, 60).await?;
563
564        let active_workers = self.get_active_workers().await.unwrap_or(0);
565
566        #[cfg(feature = "tracing")]
567        tracing::info!(
568            worker_id = worker_id,
569            duration_ms = operation_start.elapsed().as_millis(),
570            heartbeat_key = heartbeat_key,
571            heartbeat_timestamp = heartbeat_timestamp,
572            total_active_workers = active_workers,
573            "Worker registered successfully"
574        );
575
576        Ok(())
577    }
578
579    pub async fn unregister_worker(&self, worker_id: &str) -> Result<(), TaskQueueError> {
580        let operation_start = std::time::Instant::now();
581
582        #[cfg(feature = "tracing")]
583        tracing::info!(worker_id = worker_id, "Unregistering worker");
584
585        let mut conn = self.get_conn().await?;
586        conn.srem::<_, _, ()>("active_workers", worker_id).await?;
587
588        // Clean up heartbeat
589        let heartbeat_key = format!("worker:{}:heartbeat", worker_id);
590        conn.del::<_, ()>(&heartbeat_key).await?;
591
592        let active_workers = self.get_active_workers().await.unwrap_or(0);
593
594        #[cfg(feature = "tracing")]
595        tracing::info!(
596            worker_id = worker_id,
597            duration_ms = operation_start.elapsed().as_millis(),
598            heartbeat_key = heartbeat_key,
599            total_active_workers = active_workers,
600            "Worker unregistered successfully"
601        );
602
603        Ok(())
604    }
605
606    pub async fn update_worker_heartbeat(&self, worker_id: &str) -> Result<(), TaskQueueError> {
607        let operation_start = std::time::Instant::now();
608
609        let mut conn = self.get_conn().await?;
610        let heartbeat_key = format!("worker:{}:heartbeat", worker_id);
611        let heartbeat_timestamp = chrono::Utc::now().to_rfc3339();
612        conn.set::<_, _, ()>(&heartbeat_key, &heartbeat_timestamp)
613            .await?;
614        conn.expire::<_, ()>(&heartbeat_key, 60).await?;
615
616        #[cfg(feature = "tracing")]
617        tracing::trace!(
618            worker_id = worker_id,
619            duration_ms = operation_start.elapsed().as_millis(),
620            heartbeat_timestamp = heartbeat_timestamp,
621            "Worker heartbeat updated"
622        );
623
624        Ok(())
625    }
626
627    pub async fn get_task_failure_info(
628        &self,
629        task_id: TaskId,
630    ) -> Result<Option<TaskFailureInfo>, TaskQueueError> {
631        let mut conn = self.get_conn().await?;
632        let failure_key = format!("task:{}:failure", task_id);
633
634        if let Ok(data) = conn.get::<_, Vec<u8>>(&failure_key).await {
635            match rmp_serde::from_slice::<TaskFailureInfo>(&data) {
636                Ok(info) => Ok(Some(info)),
637                Err(_) => Ok(None),
638            }
639        } else {
640            Ok(None)
641        }
642    }
643
644    pub async fn get_failed_tasks(&self, queue: &str) -> Result<Vec<String>, TaskQueueError> {
645        let mut conn = self.get_conn().await?;
646        let queue_failed_set = format!("queue:{}:failed_tasks", queue);
647        let failed_tasks: Vec<String> = conn.smembers(&queue_failed_set).await.unwrap_or_default();
648        Ok(failed_tasks)
649    }
650}
651
652#[derive(Debug, Clone, Serialize, Deserialize)]
653pub struct TaskFailureInfo {
654    pub task_id: TaskId,
655    pub queue: String,
656    pub failed_at: String,
657    pub reason: String,
658    pub status: String,
659}
660
661#[derive(Debug, Clone, Serialize, Deserialize)]
662pub struct QueueMetrics {
663    pub queue_name: String,
664    pub pending_tasks: i64,
665    pub processed_tasks: i64,
666    pub failed_tasks: i64,
667}
668
669#[cfg(test)]
670mod tests {
671    use super::*;
672    use serde::{Deserialize, Serialize};
673
674    #[derive(Debug, Serialize, Deserialize, Clone)]
675    struct TestTask {
676        data: String,
677    }
678
679    #[async_trait::async_trait]
680    impl Task for TestTask {
681        async fn execute(&self) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
682            Ok(self.data.as_bytes().to_vec())
683        }
684
685        fn name(&self) -> &str {
686            "test_task"
687        }
688    }
689
690    fn get_test_redis_url() -> String {
691        // Use a combination of thread ID and timestamp for better uniqueness
692        use std::collections::hash_map::DefaultHasher;
693        use std::hash::{Hash, Hasher};
694
695        let mut hasher = DefaultHasher::new();
696        std::thread::current().id().hash(&mut hasher);
697        std::time::SystemTime::now()
698            .duration_since(std::time::UNIX_EPOCH)
699            .unwrap_or_default()
700            .as_nanos()
701            .hash(&mut hasher);
702
703        let db_num = (hasher.finish() % 16) as u8; // Redis has 16 DBs by default (0-15)
704        std::env::var("REDIS_TEST_URL")
705            .unwrap_or_else(|_| format!("redis://127.0.0.1:6379/{}", db_num))
706    }
707
708    async fn create_test_broker() -> RedisBroker {
709        let redis_url = get_test_redis_url();
710        RedisBroker::new(&redis_url)
711            .await
712            .expect("Failed to create test broker")
713    }
714
715    async fn cleanup_test_data(broker: &RedisBroker) {
716        if let Ok(mut conn) = broker.get_conn().await {
717            // Use FLUSHDB to clear only this database, then wait a bit for consistency
718            let _: Result<String, _> = redis::cmd("FLUSHDB").query_async(&mut conn).await;
719            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
720        }
721    }
722
723    #[tokio::test]
724    async fn test_broker_creation() {
725        let broker = create_test_broker().await;
726        cleanup_test_data(&broker).await; // Clean before test
727
728        // Broker should be created successfully and connection should work
729        assert!(broker.get_conn().await.is_ok());
730
731        cleanup_test_data(&broker).await; // Clean after test
732    }
733
734    #[tokio::test]
735    async fn test_broker_creation_with_config() {
736        let redis_url = get_test_redis_url();
737        let broker = RedisBroker::new_with_config(&redis_url, Some(5))
738            .await
739            .expect("Failed to create broker");
740        cleanup_test_data(&broker).await; // Clean before test
741
742        assert!(broker.get_conn().await.is_ok());
743
744        cleanup_test_data(&broker).await; // Clean after test
745    }
746
747    #[tokio::test]
748    async fn test_broker_invalid_url() {
749        let result = RedisBroker::new("redis://invalid-host:6379").await;
750        assert!(result.is_err());
751
752        if let Err(e) = result {
753            assert!(matches!(e, TaskQueueError::Connection(_)));
754        }
755    }
756
757    #[tokio::test]
758    async fn test_enqueue_task() {
759        let broker = create_test_broker().await;
760        cleanup_test_data(&broker).await; // Clean before test
761
762        let task = TestTask {
763            data: "test data".to_string(),
764        };
765
766        let task_id = broker
767            .enqueue_task(task, "test_queue")
768            .await
769            .expect("Failed to enqueue task");
770
771        // Verify task was enqueued
772        let queue_size = broker
773            .get_queue_size("test_queue")
774            .await
775            .expect("Failed to get queue size");
776        assert_eq!(queue_size, 1);
777
778        // Verify task ID was generated
779        assert!(!task_id.to_string().is_empty());
780
781        cleanup_test_data(&broker).await; // Clean after test
782    }
783
784    #[tokio::test]
785    async fn test_dequeue_task() {
786        let broker = create_test_broker().await;
787        cleanup_test_data(&broker).await; // Clean before test
788
789        let task = TestTask {
790            data: "test data".to_string(),
791        };
792
793        let task_id = broker
794            .enqueue_task(task, "test_queue")
795            .await
796            .expect("Failed to enqueue task");
797
798        // Verify exactly one task is in the queue
799        let queue_size = broker
800            .get_queue_size("test_queue")
801            .await
802            .expect("Failed to get queue size");
803        assert_eq!(
804            queue_size, 1,
805            "Queue should have exactly 1 task before dequeue"
806        );
807
808        // Dequeue the task
809        let queues = vec!["test_queue".to_string()];
810        let dequeued = broker
811            .dequeue_task(&queues)
812            .await
813            .expect("Failed to dequeue task");
814
815        assert!(dequeued.is_some(), "Should have dequeued a task");
816        let task_wrapper = dequeued.unwrap();
817        assert_eq!(
818            task_wrapper.metadata.id, task_id,
819            "Task ID should match the enqueued task"
820        );
821        assert_eq!(
822            task_wrapper.metadata.name, "test_task",
823            "Task name should match"
824        );
825
826        // Queue should be empty now
827        let queue_size = broker
828            .get_queue_size("test_queue")
829            .await
830            .expect("Failed to get queue size");
831        assert_eq!(queue_size, 0, "Queue should be empty after dequeue");
832
833        cleanup_test_data(&broker).await; // Clean after test
834    }
835
836    #[tokio::test]
837    async fn test_dequeue_from_empty_queue() {
838        let broker = create_test_broker().await;
839        cleanup_test_data(&broker).await;
840
841        let queues = vec!["empty_queue".to_string()];
842
843        // Should timeout and return None
844        let start = std::time::Instant::now();
845        let result = broker
846            .dequeue_task(&queues)
847            .await
848            .expect("Failed to dequeue from empty queue");
849        let elapsed = start.elapsed();
850
851        assert!(result.is_none());
852        // Should have waited approximately 5 seconds (the timeout)
853        assert!(elapsed.as_secs() >= 4 && elapsed.as_secs() <= 6);
854
855        cleanup_test_data(&broker).await;
856    }
857
858    #[tokio::test]
859    async fn test_queue_metrics() {
860        let broker = create_test_broker().await;
861        cleanup_test_data(&broker).await; // Clean before test
862
863        // Initial metrics should be zero
864        let metrics = broker
865            .get_queue_metrics("test_queue")
866            .await
867            .expect("Failed to get metrics");
868        assert_eq!(metrics.pending_tasks, 0);
869        assert_eq!(metrics.processed_tasks, 0);
870        assert_eq!(metrics.failed_tasks, 0);
871
872        // Add a task
873        let task = TestTask {
874            data: "test".to_string(),
875        };
876        broker
877            .enqueue_task(task, "test_queue")
878            .await
879            .expect("Failed to enqueue task");
880
881        let metrics = broker
882            .get_queue_metrics("test_queue")
883            .await
884            .expect("Failed to get metrics");
885        assert_eq!(metrics.pending_tasks, 1);
886
887        cleanup_test_data(&broker).await; // Clean after test
888    }
889
890    #[tokio::test]
891    async fn test_mark_task_completed() {
892        let broker = create_test_broker().await;
893        cleanup_test_data(&broker).await;
894
895        let task = TestTask {
896            data: "test".to_string(),
897        };
898        let task_id = broker
899            .enqueue_task(task, "test_queue")
900            .await
901            .expect("Failed to enqueue task");
902
903        // Mark as completed
904        broker
905            .mark_task_completed(task_id, "test_queue")
906            .await
907            .expect("Failed to mark completed");
908
909        let metrics = broker
910            .get_queue_metrics("test_queue")
911            .await
912            .expect("Failed to get metrics");
913        assert_eq!(metrics.processed_tasks, 1);
914
915        cleanup_test_data(&broker).await;
916    }
917
918    #[tokio::test]
919    async fn test_mark_task_failed() {
920        let broker = create_test_broker().await;
921        cleanup_test_data(&broker).await;
922
923        let task = TestTask {
924            data: "test".to_string(),
925        };
926        let task_id = broker
927            .enqueue_task(task, "test_queue")
928            .await
929            .expect("Failed to enqueue task");
930
931        // Mark as failed
932        broker
933            .mark_task_failed(task_id, "test_queue")
934            .await
935            .expect("Failed to mark failed");
936
937        let metrics = broker
938            .get_queue_metrics("test_queue")
939            .await
940            .expect("Failed to get metrics");
941        assert_eq!(metrics.failed_tasks, 1);
942
943        cleanup_test_data(&broker).await;
944    }
945
946    #[tokio::test]
947    async fn test_mark_task_failed_with_reason() {
948        let broker = create_test_broker().await;
949        cleanup_test_data(&broker).await;
950
951        let task = TestTask {
952            data: "test".to_string(),
953        };
954        let task_id = broker
955            .enqueue_task(task, "test_queue")
956            .await
957            .expect("Failed to enqueue task");
958
959        let reason = "Custom failure reason".to_string();
960        broker
961            .mark_task_failed_with_reason(task_id, "test_queue", Some(reason.clone()))
962            .await
963            .expect("Failed to mark failed with reason");
964
965        // Verify failure info was stored
966        let failure_info = broker
967            .get_task_failure_info(task_id)
968            .await
969            .expect("Failed to get failure info");
970        assert!(failure_info.is_some());
971
972        let info = failure_info.unwrap();
973        assert_eq!(info.task_id, task_id);
974        assert_eq!(info.queue, "test_queue");
975        assert_eq!(info.reason, reason);
976        assert_eq!(info.status, "failed");
977
978        cleanup_test_data(&broker).await;
979    }
980
981    #[tokio::test]
982    async fn test_worker_registration() {
983        let broker = create_test_broker().await;
984        cleanup_test_data(&broker).await; // Clean before test
985
986        let worker_id = "test_worker_001";
987
988        // Register worker
989        broker
990            .register_worker(worker_id)
991            .await
992            .expect("Failed to register worker");
993
994        let active_workers = broker
995            .get_active_workers()
996            .await
997            .expect("Failed to get active workers");
998        assert_eq!(active_workers, 1);
999
1000        // Update heartbeat
1001        broker
1002            .update_worker_heartbeat(worker_id)
1003            .await
1004            .expect("Failed to update heartbeat");
1005
1006        // Unregister worker
1007        broker
1008            .unregister_worker(worker_id)
1009            .await
1010            .expect("Failed to unregister worker");
1011
1012        let active_workers = broker
1013            .get_active_workers()
1014            .await
1015            .expect("Failed to get active workers");
1016        assert_eq!(active_workers, 0);
1017
1018        cleanup_test_data(&broker).await; // Clean after test
1019    }
1020
1021    #[tokio::test]
1022    async fn test_multiple_workers() {
1023        let broker = create_test_broker().await;
1024        cleanup_test_data(&broker).await;
1025
1026        // Register multiple workers
1027        for i in 0..5 {
1028            let worker_id = format!("worker_{}", i);
1029            broker
1030                .register_worker(&worker_id)
1031                .await
1032                .expect("Failed to register worker");
1033        }
1034
1035        let active_workers = broker
1036            .get_active_workers()
1037            .await
1038            .expect("Failed to get active workers");
1039        assert_eq!(active_workers, 5);
1040
1041        cleanup_test_data(&broker).await;
1042    }
1043
1044    #[tokio::test]
1045    async fn test_failed_tasks_tracking() {
1046        let broker = create_test_broker().await;
1047        cleanup_test_data(&broker).await;
1048
1049        // Enqueue and fail multiple tasks
1050        for i in 0..3 {
1051            let task = TestTask {
1052                data: format!("task_{}", i),
1053            };
1054            let task_id = broker
1055                .enqueue_task(task, "test_queue")
1056                .await
1057                .expect("Failed to enqueue task");
1058            broker
1059                .mark_task_failed(task_id, "test_queue")
1060                .await
1061                .expect("Failed to mark failed");
1062        }
1063
1064        let failed_tasks = broker
1065            .get_failed_tasks("test_queue")
1066            .await
1067            .expect("Failed to get failed tasks");
1068        assert_eq!(failed_tasks.len(), 3);
1069
1070        cleanup_test_data(&broker).await;
1071    }
1072
1073    #[tokio::test]
1074    async fn test_queue_metrics_comprehensive() {
1075        let broker = create_test_broker().await;
1076        cleanup_test_data(&broker).await;
1077
1078        // Add pending tasks
1079        for i in 0..3 {
1080            let task = TestTask {
1081                data: format!("pending_{}", i),
1082            };
1083            broker
1084                .enqueue_task(task, "test_queue")
1085                .await
1086                .expect("Failed to enqueue task");
1087        }
1088
1089        // Add processed tasks
1090        for i in 0..2 {
1091            let task = TestTask {
1092                data: format!("processed_{}", i),
1093            };
1094            let task_id = broker
1095                .enqueue_task(task, "temp_queue")
1096                .await
1097                .expect("Failed to enqueue task");
1098            broker
1099                .mark_task_completed(task_id, "test_queue")
1100                .await
1101                .expect("Failed to mark completed");
1102        }
1103
1104        // Add failed tasks
1105        for i in 0..1 {
1106            let task = TestTask {
1107                data: format!("failed_{}", i),
1108            };
1109            let task_id = broker
1110                .enqueue_task(task, "temp_queue")
1111                .await
1112                .expect("Failed to enqueue task");
1113            broker
1114                .mark_task_failed(task_id, "test_queue")
1115                .await
1116                .expect("Failed to mark failed");
1117        }
1118
1119        let metrics = broker
1120            .get_queue_metrics("test_queue")
1121            .await
1122            .expect("Failed to get metrics");
1123        assert_eq!(metrics.pending_tasks, 3);
1124        assert_eq!(metrics.processed_tasks, 2);
1125        assert_eq!(metrics.failed_tasks, 1);
1126        assert_eq!(metrics.queue_name, "test_queue");
1127
1128        cleanup_test_data(&broker).await;
1129    }
1130
1131    #[tokio::test]
1132    async fn test_task_failure_info_serialization() {
1133        let task_id = TaskId::new_v4();
1134        let failure_info = TaskFailureInfo {
1135            task_id,
1136            queue: "test_queue".to_string(),
1137            failed_at: chrono::Utc::now().to_rfc3339(),
1138            reason: "Test failure".to_string(),
1139            status: "failed".to_string(),
1140        };
1141
1142        // Test serialization
1143        let serialized = rmp_serde::to_vec(&failure_info).expect("Failed to serialize");
1144        let deserialized: TaskFailureInfo =
1145            rmp_serde::from_slice(&serialized).expect("Failed to deserialize");
1146
1147        assert_eq!(deserialized.task_id, failure_info.task_id);
1148        assert_eq!(deserialized.queue, failure_info.queue);
1149        assert_eq!(deserialized.reason, failure_info.reason);
1150        assert_eq!(deserialized.status, failure_info.status);
1151    }
1152
1153    #[tokio::test]
1154    async fn test_queue_metrics_serialization() {
1155        let metrics = QueueMetrics {
1156            queue_name: "test_queue".to_string(),
1157            pending_tasks: 10,
1158            processed_tasks: 100,
1159            failed_tasks: 5,
1160        };
1161
1162        // Test serialization
1163        let serialized = rmp_serde::to_vec(&metrics).expect("Failed to serialize");
1164        let deserialized: QueueMetrics =
1165            rmp_serde::from_slice(&serialized).expect("Failed to deserialize");
1166
1167        assert_eq!(deserialized.queue_name, metrics.queue_name);
1168        assert_eq!(deserialized.pending_tasks, metrics.pending_tasks);
1169        assert_eq!(deserialized.processed_tasks, metrics.processed_tasks);
1170        assert_eq!(deserialized.failed_tasks, metrics.failed_tasks);
1171    }
1172
1173    #[tokio::test]
1174    async fn test_enqueue_task_wrapper() {
1175        let broker = create_test_broker().await;
1176        cleanup_test_data(&broker).await; // Clean before test
1177
1178        let task_id = TaskId::new_v4();
1179        let metadata = TaskMetadata {
1180            id: task_id,
1181            name: "custom_task".to_string(),
1182            created_at: chrono::Utc::now(),
1183            attempts: 0,
1184            max_retries: 5,
1185            timeout_seconds: 600,
1186        };
1187
1188        let task_wrapper = TaskWrapper {
1189            metadata,
1190            payload: b"custom payload".to_vec(),
1191        };
1192
1193        let returned_id = broker
1194            .enqueue_task_wrapper(task_wrapper, "test_queue")
1195            .await
1196            .expect("Failed to enqueue task wrapper");
1197
1198        assert_eq!(returned_id, task_id);
1199
1200        let queue_size = broker
1201            .get_queue_size("test_queue")
1202            .await
1203            .expect("Failed to get queue size");
1204        assert_eq!(queue_size, 1);
1205
1206        cleanup_test_data(&broker).await; // Clean after test
1207    }
1208}