rust_task_queue/worker.rs

use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};

use crate::{error::TaskQueueError, queue::queue_names, RedisBroker, TaskScheduler, TaskWrapper};
use serde::Serialize;
use tokio::sync::{mpsc, Semaphore};
use tokio::task::JoinHandle;

// Brings `Error::source` into scope for the error logging below.
#[cfg(feature = "tracing")]
use std::error::Error as _;

#[cfg(feature = "tracing")]
use tracing::Instrument;

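/// A worker that pulls tasks from Redis-backed queues and executes them on
/// background Tokio tasks, with optional semaphore-based backpressure.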
pub struct Worker {
    id: String,
    broker: Arc<RedisBroker>,
    #[allow(dead_code)]
    scheduler: Arc<TaskScheduler>,
    task_registry: Arc<crate::TaskRegistry>,
    shutdown_tx: Option<mpsc::Sender<()>>,
    handle: Option<JoinHandle<()>>,
    // Added for backpressure control
    max_concurrent_tasks: usize,
    task_semaphore: Option<Arc<Semaphore>>,
    // Task tracking for better observability
    active_tasks: Arc<AtomicUsize>,
}

/// Configuration for worker backpressure and performance tuning
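///
/// A minimal configuration sketch (the `broker` and `scheduler` values are
/// assumed to be constructed elsewhere):
///
/// ```rust,ignore
/// let config = WorkerBackpressureConfig {
///     max_concurrent_tasks: 20,
///     queue_size_threshold: 100,
///     backpressure_delay_ms: 500,
/// };
/// let worker = Worker::new("worker-1".to_string(), broker, scheduler)
///     .with_backpressure_config(config);
/// ```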
#[derive(Debug, Clone)]
pub struct WorkerBackpressureConfig {
    pub max_concurrent_tasks: usize,
    pub queue_size_threshold: usize,
    pub backpressure_delay_ms: u64,
}

/// Context for task execution spawning
#[derive(Clone)]
struct TaskExecutionContext {
    broker: Arc<RedisBroker>,
    task_registry: Arc<crate::TaskRegistry>,
    worker_id: String,
    semaphore: Option<Arc<Semaphore>>,
    active_tasks: Arc<AtomicUsize>,
}

/// Result of task spawning attempt
#[derive(Debug)]
enum SpawnResult {
    Spawned,
    Rejected(TaskWrapper),
    Failed(TaskQueueError),
}

impl Worker {
    pub fn new(id: String, broker: Arc<RedisBroker>, scheduler: Arc<TaskScheduler>) -> Self {
        let max_concurrent_tasks = 10;
        Self {
            id,
            broker,
            scheduler,
            task_registry: Arc::new(crate::TaskRegistry::new()),
            shutdown_tx: None,
            handle: None,
            max_concurrent_tasks,
            task_semaphore: Some(Arc::new(Semaphore::new(max_concurrent_tasks))),
            active_tasks: Arc::new(AtomicUsize::new(0)),
        }
    }

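    /// Replace the default (empty) task registry with one that has executors
    /// registered for the task types this worker should handle.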
    pub fn with_task_registry(mut self, registry: Arc<crate::TaskRegistry>) -> Self {
        self.task_registry = registry;
        self
    }

    pub fn with_backpressure_config(mut self, config: WorkerBackpressureConfig) -> Self {
        self.max_concurrent_tasks = config.max_concurrent_tasks;
        self.task_semaphore = Some(Arc::new(Semaphore::new(config.max_concurrent_tasks)));
        self
    }

    pub fn with_max_concurrent_tasks(mut self, max_tasks: usize) -> Self {
        self.max_concurrent_tasks = max_tasks;
        self.task_semaphore = Some(Arc::new(Semaphore::new(max_tasks)));
        self
    }

    /// Get the current number of active tasks
    pub fn active_task_count(&self) -> usize {
        self.active_tasks.load(Ordering::Relaxed)
    }

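    /// Start the worker's main dequeue loop on a background Tokio task.
    ///
    /// A minimal lifecycle sketch (assuming a reachable Redis instance behind
    /// `broker`):
    ///
    /// ```rust,ignore
    /// let worker = Worker::new("worker-1".to_string(), broker, scheduler)
    ///     .with_max_concurrent_tasks(16)
    ///     .start()
    ///     .await?;
    /// // ... later: signal shutdown and wait for in-flight tasks to drain
    /// worker.stop().await;
    /// ```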
    pub async fn start(mut self) -> Result<Self, TaskQueueError> {
        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);

        let execution_context = TaskExecutionContext {
            broker: self.broker.clone(),
            task_registry: self.task_registry.clone(),
            worker_id: self.id.clone(),
            semaphore: self.task_semaphore.clone(),
            active_tasks: self.active_tasks.clone(),
        };

        // Register worker
        execution_context
            .broker
            .register_worker(&execution_context.worker_id)
            .await?;

        let handle = tokio::spawn(async move {
            let queues = vec![
                queue_names::DEFAULT.to_string(),
                queue_names::HIGH_PRIORITY.to_string(),
                queue_names::LOW_PRIORITY.to_string(),
            ];

            #[cfg(feature = "tracing")]
            tracing::info!(
                worker_id = %execution_context.worker_id,
                queues = ?queues,
                semaphore_permits = execution_context.semaphore.as_ref().map(|s| s.available_permits()),
                "Worker main loop started"
            );

            loop {
                tokio::select! {
                    // Check for shutdown signal
                    _ = shutdown_rx.recv() => {
                        #[cfg(feature = "tracing")]
                        tracing::info!(
                            worker_id = %execution_context.worker_id,
                            active_tasks = execution_context.active_tasks.load(Ordering::Relaxed),
                            "Worker received shutdown signal"
                        );

                        // Wait for active tasks to complete before shutting down
                        Self::graceful_shutdown(&execution_context).await;

                        if let Err(_e) = execution_context.broker.unregister_worker(&execution_context.worker_id).await {
                            #[cfg(feature = "tracing")]
                            tracing::error!(
                                worker_id = %execution_context.worker_id,
                                error = %_e,
                                "Failed to unregister worker during shutdown"
                            );
                        } else {
                            #[cfg(feature = "tracing")]
                            tracing::info!(
                                worker_id = %execution_context.worker_id,
                                "Worker unregistered successfully during shutdown"
                            );
                        }
                        break;
                    }

                    // Process tasks with improved spawning logic
                    task_result = execution_context.broker.dequeue_task(&queues) => {
                        let current_active_tasks = execution_context.active_tasks.load(Ordering::Relaxed);

                        match task_result {
                            Ok(Some(task_wrapper)) => {
                                #[cfg(feature = "tracing")]
                                tracing::debug!(
                                    worker_id = %execution_context.worker_id,
                                    task_id = %task_wrapper.metadata.id,
                                    task_name = %task_wrapper.metadata.name,
                                    queue_source = "unknown", // Could be enhanced to track which queue
                                    active_tasks_before = current_active_tasks,
                                    "Task received for processing"
                                );

                                match Self::handle_task_execution(execution_context.clone(), task_wrapper).await {
                                    SpawnResult::Spawned => {
                                        #[cfg(feature = "tracing")]
                                        tracing::debug!(
                                            worker_id = %execution_context.worker_id,
                                            active_tasks = execution_context.active_tasks.load(Ordering::Relaxed),
                                            semaphore_permits = execution_context.semaphore.as_ref().map(|s| s.available_permits()),
                                            "Task spawned successfully"
                                        );
                                    }
                                    SpawnResult::Rejected(rejected_task) => {
                                        #[cfg(feature = "tracing")]
                                        tracing::warn!(
                                            worker_id = %execution_context.worker_id,
                                            task_id = %rejected_task.metadata.id,
                                            task_name = %rejected_task.metadata.name,
                                            active_tasks = current_active_tasks,
                                            semaphore_permits = execution_context.semaphore.as_ref().map(|s| s.available_permits()),
                                            "Task rejected due to backpressure"
                                        );
                                    }
                                    SpawnResult::Failed(_e) => {
                                        #[cfg(feature = "tracing")]
                                        tracing::error!(
                                            worker_id = %execution_context.worker_id,
                                            error = %_e,
                                            active_tasks = current_active_tasks,
                                            "Failed to handle task execution"
                                        );
                                    }
                                }
                            }
                            Ok(None) => {
                                // No tasks available, update heartbeat
                                #[cfg(feature = "tracing")]
                                tracing::trace!(
                                    worker_id = %execution_context.worker_id,
                                    active_tasks = current_active_tasks,
                                    "No tasks available, updating heartbeat"
                                );

                                if let Err(_e) = execution_context.broker.update_worker_heartbeat(&execution_context.worker_id).await {
                                    #[cfg(feature = "tracing")]
                                    tracing::error!(
                                        worker_id = %execution_context.worker_id,
                                        error = %_e,
                                        "Failed to update worker heartbeat"
                                    );
                                }
                            }
                            Err(_e) => {
                                #[cfg(feature = "tracing")]
                                tracing::error!(
                                    worker_id = %execution_context.worker_id,
                                    error = %_e,
                                    active_tasks = current_active_tasks,
                                    "Error dequeuing task, backing off"
                                );

                                tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
                            }
                        }
                    }
                }
            }

            #[cfg(feature = "tracing")]
            tracing::info!(
                worker_id = %execution_context.worker_id,
                "Worker main loop ended gracefully"
            );
        });

        self.shutdown_tx = Some(shutdown_tx);
        self.handle = Some(handle);

        #[cfg(feature = "tracing")]
        tracing::info!("Started worker {}", self.id);

        Ok(self)
    }

    /// Improved async task spawning with proper resource management
    async fn handle_task_execution(
        context: TaskExecutionContext,
        task_wrapper: TaskWrapper,
    ) -> SpawnResult {
        // Extract semaphore to avoid borrowing issues
        let semaphore_opt = context.semaphore.clone();

        match semaphore_opt {
            Some(semaphore) => {
                // Clone semaphore before use to avoid borrow checker issues
                let semaphore_clone = semaphore.clone();

                // Try to acquire permit without blocking the main worker loop
                match semaphore.try_acquire() {
                    Ok(_permit) => {
                        // Drop permit and rely on async acquisition in spawned task
                        // This maintains backpressure while avoiding lifetime issues
                        drop(_permit);
                        Self::spawn_task_with_semaphore(context, task_wrapper, semaphore_clone)
                            .await;
                        SpawnResult::Spawned
                    }
                    Err(_) => {
                        // At capacity, handle backpressure
                        Self::handle_backpressure(context, task_wrapper).await
                    }
                }
            }
            None => {
                // No backpressure control, execute directly
                Self::execute_task_directly(context, task_wrapper).await;
                SpawnResult::Spawned
            }
        }
    }

    /// Wait for active tasks to complete during shutdown
    async fn graceful_shutdown(context: &TaskExecutionContext) {
        let shutdown_timeout = tokio::time::Duration::from_secs(30);
        let check_interval = tokio::time::Duration::from_millis(100);

        let start_time = tokio::time::Instant::now();

        while context.active_tasks.load(Ordering::Relaxed) > 0
            && start_time.elapsed() < shutdown_timeout
        {
            #[cfg(feature = "tracing")]
            tracing::debug!(
                "Worker {} waiting for {} active tasks to complete",
                context.worker_id,
                context.active_tasks.load(Ordering::Relaxed)
            );

            tokio::time::sleep(check_interval).await;
        }

        if context.active_tasks.load(Ordering::Relaxed) > 0 {
            #[cfg(feature = "tracing")]
            tracing::warn!(
                "Worker {} shutdown timeout reached with {} active tasks",
                context.worker_id,
                context.active_tasks.load(Ordering::Relaxed)
            );
        }
    }

    /// Spawn task execution with semaphore backpressure
    async fn spawn_task_with_semaphore(
        context: TaskExecutionContext,
        task_wrapper: TaskWrapper,
        semaphore: Arc<Semaphore>,
    ) {
        // Increment active task counter
        context.active_tasks.fetch_add(1, Ordering::Relaxed);

        tokio::spawn(async move {
            // Acquire permit for the full duration of execution - this is the correct approach
            let _permit = semaphore
                .acquire()
                .await
                .expect("Semaphore should not be closed");

            if let Err(_e) = Self::process_task(
                &context.broker,
                &context.task_registry,
                &context.worker_id,
                task_wrapper,
            )
            .await
            {
                #[cfg(feature = "tracing")]
                tracing::error!("Task processing failed: {}", _e);
            }

            // Decrement active task counter
            context.active_tasks.fetch_sub(1, Ordering::Relaxed);
            // Permit is automatically released when dropped here
        });
    }

    /// Handle backpressure by re-queuing task
    async fn handle_backpressure(
        context: TaskExecutionContext,
        task_wrapper: TaskWrapper,
    ) -> SpawnResult {
        // Attempt to re-queue the task
        match Self::requeue_task(&context.broker, task_wrapper.clone()).await {
            Ok(_) => {
                #[cfg(feature = "tracing")]
                tracing::debug!(
                    "Task {} re-queued due to backpressure",
                    task_wrapper.metadata.id
                );
                SpawnResult::Rejected(task_wrapper)
            }
            Err(e) => SpawnResult::Failed(e),
        }
    }

    /// Execute task directly without semaphore control
    async fn execute_task_directly(context: TaskExecutionContext, task_wrapper: TaskWrapper) {
        // Increment active task counter
        context.active_tasks.fetch_add(1, Ordering::Relaxed);

        tokio::spawn(async move {
            if let Err(_e) = Self::process_task(
                &context.broker,
                &context.task_registry,
                &context.worker_id,
                task_wrapper,
            )
            .await
            {
                #[cfg(feature = "tracing")]
                tracing::error!("Task processing failed: {}", _e);
            }

            // Decrement active task counter
            context.active_tasks.fetch_sub(1, Ordering::Relaxed);
        });
    }

    /// Re-queue a task for later processing
    async fn requeue_task(
        broker: &RedisBroker,
        task_wrapper: TaskWrapper,
    ) -> Result<(), TaskQueueError> {
        broker
            .enqueue_task_wrapper(task_wrapper, queue_names::DEFAULT)
            .await?;
        Ok(())
    }

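    /// Signal the worker to shut down, then wait for its main loop (including
    /// the graceful drain of in-flight tasks) to finish.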
    pub async fn stop(self) {
        if let Some(tx) = self.shutdown_tx {
            if let Err(_e) = tx.send(()).await {
                #[cfg(feature = "tracing")]
                tracing::error!("Failed to send shutdown signal");
            }
        }

        if let Some(handle) = self.handle {
            let _ = handle.await;
        }
    }

    async fn process_task(
        broker: &RedisBroker,
        task_registry: &crate::TaskRegistry,
        worker_id: &str,
        mut task_wrapper: TaskWrapper,
    ) -> Result<(), TaskQueueError> {
        let task_id = task_wrapper.metadata.id;
        let task_name = &task_wrapper.metadata.name;

        // Create a span for the entire task lifecycle (only when tracing is compiled in)
        #[cfg(feature = "tracing")]
        let span = tracing::info_span!(
            "process_task",
            task_id = %task_id,
            task_name = task_name,
            worker_id = worker_id,
            attempt = task_wrapper.metadata.attempts + 1,
            max_retries = task_wrapper.metadata.max_retries
        );

        let task_future = async move {
            #[cfg(feature = "tracing")]
            tracing::info!(
                created_at = %task_wrapper.metadata.created_at,
                timeout_seconds = task_wrapper.metadata.timeout_seconds,
                payload_size_bytes = task_wrapper.payload.len(),
                "Starting task processing"
            );

            let execution_start = std::time::Instant::now();
            task_wrapper.metadata.attempts += 1;

            // Record attempt start
            #[cfg(feature = "tracing")]
            tracing::debug!(
                attempt = task_wrapper.metadata.attempts,
                created_at = %task_wrapper.metadata.created_at,
                "Task execution attempt started"
            );

            // Execute the task
            match Self::execute_task_with_registry(task_registry, &task_wrapper).await {
                Ok(result) => {
                    let execution_duration = execution_start.elapsed();

                    #[cfg(feature = "tracing")]
                    tracing::info!(
                        duration_ms = execution_duration.as_millis(),
                        result_size_bytes = result.len(),
                        success = true,
                        "Task completed successfully"
                    );

                    broker
                        .mark_task_completed(task_id, queue_names::DEFAULT)
                        .await?;
                }
                Err(error) => {
                    let execution_duration = execution_start.elapsed();
                    let error_msg = error.to_string();

                    #[cfg(feature = "tracing")]
                    tracing::error!(
                        duration_ms = execution_duration.as_millis(),
                        error = %error,
                        error_source = error.source().map(|e| e.to_string()).as_deref(),
                        success = false,
                        "Task execution failed"
                    );

                    // Check if we should retry
                    if task_wrapper.metadata.attempts < task_wrapper.metadata.max_retries {
                        let remaining_retries =
                            task_wrapper.metadata.max_retries - task_wrapper.metadata.attempts;

                        #[cfg(feature = "tracing")]
                        tracing::warn!(
                            remaining_retries = remaining_retries,
                            retry_delay_ms = 1000 * task_wrapper.metadata.attempts as u64, // Simple linear backoff (logged only)
                            "Re-queuing task for retry"
                        );

                        // Re-enqueue for retry
                        broker
                            .enqueue_task_wrapper(task_wrapper, queue_names::DEFAULT)
                            .await?;
                    } else {
                        #[cfg(feature = "tracing")]
                        tracing::error!(
                            final_error = %error_msg,
                            total_attempts = task_wrapper.metadata.attempts,
                            "Task failed permanently - maximum retries exceeded"
                        );

                        broker
                            .mark_task_failed_with_reason(
                                task_id,
                                queue_names::DEFAULT,
                                Some(error_msg),
                            )
                            .await?;
                    }
                }
            }

            Ok(())
        };

        // Attach the span only when the `tracing` feature is compiled in;
        // otherwise await the task future directly.
        #[cfg(feature = "tracing")]
        let result = task_future.instrument(span).await;
        #[cfg(not(feature = "tracing"))]
        let result = task_future.await;

        result
    }

    async fn execute_task_with_registry(
        task_registry: &crate::TaskRegistry,
        task_wrapper: &TaskWrapper,
    ) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
        let task_name = &task_wrapper.metadata.name;
        let task_id = task_wrapper.metadata.id;
        let execution_start = std::time::Instant::now();

        #[cfg(feature = "tracing")]
        tracing::debug!(
            task_id = %task_id,
            task_name = task_name,
            payload_size_bytes = task_wrapper.payload.len(),
            "Attempting task execution with registry"
        );

        // Try to execute using the task registry
        match task_registry
            .execute(task_name, task_wrapper.payload.clone())
            .await
        {
            Ok(result) => {
                #[cfg(feature = "tracing")]
                tracing::info!(
                    task_id = %task_id,
                    task_name = task_name,
                    duration_ms = execution_start.elapsed().as_millis(),
                    result_size_bytes = result.len(),
                    execution_method = "registry",
                    "Task executed successfully with registered executor"
                );

                Ok(result)
            }
            Err(error) => {
                let execution_duration = execution_start.elapsed();

                // Check if this is a "task not found" error vs a task execution failure
                let error_msg = error.to_string();
                if error_msg.contains("Unknown task type") {
                    #[cfg(feature = "tracing")]
                    tracing::warn!(
                        task_id = %task_id,
                        task_name = task_name,
                        duration_ms = execution_duration.as_millis(),
                        error = %error,
                        "No executor found for task type, using fallback"
                    );

                    // Fallback: serialize task metadata as response
                    #[derive(Serialize)]
                    struct FallbackResponse {
                        status: String,
                        message: String,
                        timestamp: String,
                        task_id: String,
                        task_name: String,
                    }

                    let response = FallbackResponse {
                        status: "completed".to_string(),
                        message: format!("Task {} completed with fallback executor", task_name),
                        timestamp: chrono::Utc::now().to_rfc3339(),
                        task_id: task_id.to_string(),
                        task_name: task_name.to_string(),
                    };

                    let serialized = serde_json::to_vec(&response)
                        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;

                    #[cfg(feature = "tracing")]
                    tracing::info!(
                        task_id = %task_id,
                        task_name = task_name,
                        duration_ms = execution_duration.as_millis(),
                        result_size_bytes = serialized.len(),
                        execution_method = "fallback",
                        "Task completed with fallback executor"
                    );

                    Ok(serialized)
                } else {
                    // This is an actual task execution failure - propagate it
                    #[cfg(feature = "tracing")]
                    tracing::error!(
                        task_id = %task_id,
                        task_name = task_name,
                        duration_ms = execution_duration.as_millis(),
                        error = %error,
                        error_source = error.source().map(|e| e.to_string()).as_deref(),
                        execution_method = "registry",
                        "Task execution failed in registered executor"
                    );

                    Err(error)
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TaskId, TaskMetadata};
    use std::time::Duration;
    use tokio::time::timeout;

    fn create_test_broker() -> Arc<RedisBroker> {
        // Create a mock broker for testing - this is a simplified version
        let redis_url = std::env::var("REDIS_TEST_URL")
            .unwrap_or_else(|_| "redis://127.0.0.1:6379/15".to_string());

        // For unit tests, we'll create a minimal broker
        let config = deadpool_redis::Config::from_url(&redis_url);
        let pool = config
            .create_pool(Some(deadpool_redis::Runtime::Tokio1))
            .expect("Failed to create test pool");

        Arc::new(RedisBroker { pool })
    }

    fn create_test_scheduler() -> Arc<TaskScheduler> {
        let broker = create_test_broker();
        Arc::new(TaskScheduler::new(broker))
    }

    fn create_test_task_wrapper() -> TaskWrapper {
        TaskWrapper {
            metadata: TaskMetadata {
                id: TaskId::new_v4(),
                name: "test_task".to_string(),
                created_at: chrono::Utc::now(),
                attempts: 0,
                max_retries: 3,
                timeout_seconds: 60,
            },
            payload: b"test payload".to_vec(),
        }
    }

    // Helper function to get a connection for tests since get_conn is private
    async fn get_test_connection(
        broker: &RedisBroker,
    ) -> Result<deadpool_redis::Connection, deadpool_redis::PoolError> {
        broker.pool.get().await
    }

    #[test]
    fn test_worker_creation() {
        let broker = create_test_broker();
        let scheduler = create_test_scheduler();
        let worker_id = "test_worker_001".to_string();

        let worker = Worker::new(worker_id.clone(), broker, scheduler);

        assert_eq!(worker.id, worker_id);
        assert_eq!(worker.max_concurrent_tasks, 10);
        assert_eq!(worker.active_task_count(), 0);
        assert!(worker.task_semaphore.is_some());
        assert!(worker.shutdown_tx.is_none());
        assert!(worker.handle.is_none());
    }

    #[test]
    fn test_worker_with_task_registry() {
        let broker = create_test_broker();
        let scheduler = create_test_scheduler();
        let worker_id = "test_worker_002".to_string();
        let registry = Arc::new(crate::TaskRegistry::new());

        let _worker =
            Worker::new(worker_id, broker, scheduler).with_task_registry(registry.clone());

        // The registry should be set
        assert_eq!(Arc::strong_count(&registry), 2); // One in worker, one here
    }

    #[test]
    fn test_worker_with_backpressure_config() {
        let broker = create_test_broker();
        let scheduler = create_test_scheduler();
        let worker_id = "test_worker_003".to_string();

        let config = WorkerBackpressureConfig {
            max_concurrent_tasks: 20,
            queue_size_threshold: 100,
            backpressure_delay_ms: 500,
        };

        let worker =
            Worker::new(worker_id, broker, scheduler).with_backpressure_config(config.clone());

        assert_eq!(worker.max_concurrent_tasks, config.max_concurrent_tasks);
        assert!(worker.task_semaphore.is_some());

        if let Some(semaphore) = &worker.task_semaphore {
            assert_eq!(semaphore.available_permits(), config.max_concurrent_tasks);
        }
    }

    #[test]
    fn test_worker_with_max_concurrent_tasks() {
        let broker = create_test_broker();
        let scheduler = create_test_scheduler();
        let worker_id = "test_worker_004".to_string();
        let max_tasks = 15;

        let worker = Worker::new(worker_id, broker, scheduler).with_max_concurrent_tasks(max_tasks);

        assert_eq!(worker.max_concurrent_tasks, max_tasks);
        assert!(worker.task_semaphore.is_some());

        if let Some(semaphore) = &worker.task_semaphore {
            assert_eq!(semaphore.available_permits(), max_tasks);
        }
    }

    #[test]
    fn test_worker_backpressure_config_clone() {
        let original = WorkerBackpressureConfig {
            max_concurrent_tasks: 25,
            queue_size_threshold: 200,
            backpressure_delay_ms: 1000,
        };

        let cloned = original.clone();

        assert_eq!(original.max_concurrent_tasks, cloned.max_concurrent_tasks);
        assert_eq!(original.queue_size_threshold, cloned.queue_size_threshold);
        assert_eq!(original.backpressure_delay_ms, cloned.backpressure_delay_ms);
    }

    #[test]
    fn test_worker_backpressure_config_debug() {
        let config = WorkerBackpressureConfig {
            max_concurrent_tasks: 8,
            queue_size_threshold: 50,
            backpressure_delay_ms: 250,
        };

        let debug_str = format!("{:?}", config);

        assert!(debug_str.contains("WorkerBackpressureConfig"));
        assert!(debug_str.contains("max_concurrent_tasks: 8"));
        assert!(debug_str.contains("queue_size_threshold: 50"));
        assert!(debug_str.contains("backpressure_delay_ms: 250"));
    }

    #[test]
    fn test_spawn_result_debug() {
        let spawned = SpawnResult::Spawned;
        let rejected = SpawnResult::Rejected(create_test_task_wrapper());
        let failed = SpawnResult::Failed(TaskQueueError::Serialization(
            rmp_serde::encode::Error::Syntax("test error".to_string()),
        ));

        let spawned_debug = format!("{:?}", spawned);
        let rejected_debug = format!("{:?}", rejected);
        let failed_debug = format!("{:?}", failed);

        assert!(spawned_debug.contains("Spawned"));
        assert!(rejected_debug.contains("Rejected"));
        assert!(failed_debug.contains("Failed"));
    }

    #[test]
    fn test_task_execution_context_clone() {
        let broker = create_test_broker();
        let task_registry = Arc::new(crate::TaskRegistry::new());
        let worker_id = "test_worker_005".to_string();
        let semaphore = Some(Arc::new(Semaphore::new(10)));
        let active_tasks = Arc::new(AtomicUsize::new(0));

        let context = TaskExecutionContext {
            broker: broker.clone(),
            task_registry: task_registry.clone(),
            worker_id: worker_id.clone(),
            semaphore: semaphore.clone(),
            active_tasks: active_tasks.clone(),
        };

        let cloned = context.clone();

        assert_eq!(cloned.worker_id, worker_id);
        assert_eq!(cloned.active_tasks.load(Ordering::Relaxed), 0);
        assert!(cloned.semaphore.is_some());
    }

    #[tokio::test]
    async fn test_active_task_count_tracking() {
        let broker = create_test_broker();
        let scheduler = create_test_scheduler();
        let worker_id = "test_worker_006".to_string();

        let worker = Worker::new(worker_id, broker, scheduler);

        assert_eq!(worker.active_task_count(), 0);

        // Simulate task processing
        worker.active_tasks.fetch_add(1, Ordering::Relaxed);
        assert_eq!(worker.active_task_count(), 1);

        worker.active_tasks.fetch_add(2, Ordering::Relaxed);
        assert_eq!(worker.active_task_count(), 3);

        worker.active_tasks.fetch_sub(1, Ordering::Relaxed);
        assert_eq!(worker.active_task_count(), 2);
    }

    #[tokio::test]
    async fn test_requeue_task() {
        let broker = create_test_broker();
        let task_wrapper = create_test_task_wrapper();

        // Clean up any existing data
        if let Ok(mut conn) = get_test_connection(&broker).await {
            let _: Result<String, _> = redis::cmd("FLUSHDB").query_async(&mut conn).await;
        }

        let result = Worker::requeue_task(&broker, task_wrapper).await;
        assert!(result.is_ok());

        // Verify task was requeued
        let queue_size = broker
            .get_queue_size(queue_names::DEFAULT)
            .await
            .expect("Failed to get queue size");
        assert!(queue_size > 0);
    }

    #[tokio::test]
    async fn test_execute_task_with_registry_fallback() {
        let task_registry = crate::TaskRegistry::new();
        let task_wrapper = create_test_task_wrapper();

        let result = Worker::execute_task_with_registry(&task_registry, &task_wrapper).await;
        assert!(result.is_ok());

        let output = result.unwrap();
        assert!(!output.is_empty());

        // Verify it's valid JSON (fallback response)
        let parsed: serde_json::Value =
            serde_json::from_slice(&output).expect("Should be valid JSON");

        assert_eq!(parsed["status"], "completed");
        assert!(parsed["message"].as_str().unwrap().contains("test_task"));
        assert!(parsed["timestamp"].is_string());
    }

    #[tokio::test]
    async fn test_process_task_success() {
        let broker = create_test_broker();
        let task_registry = crate::TaskRegistry::new();
        let worker_id = "test_worker_007";
        let task_wrapper = create_test_task_wrapper();

        // Clean up any existing data
        if let Ok(mut conn) = get_test_connection(&broker).await {
            let _: Result<String, _> = redis::cmd("FLUSHDB").query_async(&mut conn).await;
        }

        let result = Worker::process_task(&broker, &task_registry, worker_id, task_wrapper).await;
        assert!(result.is_ok());

        // Verify metrics were updated
        let metrics = broker
            .get_queue_metrics(queue_names::DEFAULT)
            .await
            .expect("Failed to get metrics");
        assert_eq!(metrics.processed_tasks, 1);
    }

    #[tokio::test]
    async fn test_execute_task_directly() {
        let broker = create_test_broker();
        let task_registry = Arc::new(crate::TaskRegistry::new());
        let worker_id = "test_worker_012".to_string();
        let active_tasks = Arc::new(AtomicUsize::new(0));
        let task_wrapper = create_test_task_wrapper();

        let context = TaskExecutionContext {
            broker,
            task_registry,
            worker_id,
            semaphore: None,
            active_tasks: active_tasks.clone(),
        };

        assert_eq!(active_tasks.load(Ordering::Relaxed), 0);

        Worker::execute_task_directly(context, task_wrapper).await;

        // Wait longer and poll for the task to start (more robust timing)
        let mut attempts = 0;
        while active_tasks.load(Ordering::Relaxed) == 0 && attempts < 50 {
            tokio::time::sleep(Duration::from_millis(10)).await;
            attempts += 1;
        }

        // The task should have incremented the counter
        assert!(
            active_tasks.load(Ordering::Relaxed) >= 1,
            "Task should have started and incremented active count"
        );

        // Wait for task to complete with longer timeout
        let mut attempts = 0;
        while active_tasks.load(Ordering::Relaxed) > 0 && attempts < 100 {
            tokio::time::sleep(Duration::from_millis(10)).await;
            attempts += 1;
        }

        // The task should have decremented the counter when it completed
        assert_eq!(active_tasks.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_graceful_shutdown() {
        let broker = create_test_broker();
        let task_registry = Arc::new(crate::TaskRegistry::new());
        let worker_id = "test_worker_008".to_string();
        let active_tasks = Arc::new(AtomicUsize::new(2));

        let context = TaskExecutionContext {
            broker,
            task_registry,
            worker_id,
            semaphore: None,
            active_tasks: active_tasks.clone(),
        };

        // Simulate tasks completing during shutdown
        let active_tasks_clone = active_tasks.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            active_tasks_clone.fetch_sub(1, Ordering::Relaxed);
            tokio::time::sleep(Duration::from_millis(50)).await;
            active_tasks_clone.fetch_sub(1, Ordering::Relaxed);
        });

        // Test graceful shutdown with more generous timeout for system variations
        let start = std::time::Instant::now();
        Worker::graceful_shutdown(&context).await;
        let elapsed = start.elapsed();

        assert_eq!(context.active_tasks.load(Ordering::Relaxed), 0);
        assert!(
            elapsed < Duration::from_millis(500),
            "Shutdown should complete in reasonable time"
        ); // More generous timeout
    }

    #[test]
    fn test_worker_default_configuration() {
        let broker = create_test_broker();
        let scheduler = create_test_scheduler();
        let worker_id = "test_worker_011".to_string();

        let worker = Worker::new(worker_id.clone(), broker.clone(), scheduler.clone());

        // Test default values
        assert_eq!(worker.id, worker_id);
        assert_eq!(worker.max_concurrent_tasks, 10);
        assert_eq!(worker.active_task_count(), 0);
        assert!(worker.task_semaphore.is_some());
        assert!(worker.shutdown_tx.is_none());
        assert!(worker.handle.is_none());

        // Test that broker and scheduler are properly stored
        assert_eq!(Arc::strong_count(&broker), 2); // One in worker, one here
        assert_eq!(Arc::strong_count(&scheduler), 2); // One in worker, one here
    }

    #[test]
    fn test_worker_backpressure_config_defaults() {
        let config = WorkerBackpressureConfig {
            max_concurrent_tasks: 50,
            queue_size_threshold: 1000,
            backpressure_delay_ms: 100,
        };

        assert_eq!(config.max_concurrent_tasks, 50);
        assert_eq!(config.queue_size_threshold, 1000);
        assert_eq!(config.backpressure_delay_ms, 100);
    }

    #[tokio::test]
    async fn test_worker_method_chaining() {
        let broker = create_test_broker();
        let scheduler = create_test_scheduler();
        let worker_id = "test_worker_013".to_string();
        let registry = Arc::new(crate::TaskRegistry::new());

        let config = WorkerBackpressureConfig {
            max_concurrent_tasks: 25,
            queue_size_threshold: 500,
            backpressure_delay_ms: 200,
        };

        let worker = Worker::new(worker_id.clone(), broker, scheduler)
            .with_task_registry(registry.clone())
            .with_backpressure_config(config.clone())
            .with_max_concurrent_tasks(30); // This should override the config value

        assert_eq!(worker.id, worker_id);
        assert_eq!(worker.max_concurrent_tasks, 30); // Should be overridden
        assert_eq!(Arc::strong_count(&registry), 2); // One in worker, one here

        if let Some(semaphore) = &worker.task_semaphore {
            assert_eq!(semaphore.available_permits(), 30);
        }
    }

    #[tokio::test]
    async fn test_graceful_shutdown_timeout() {
        let broker = create_test_broker();
        let task_registry = Arc::new(crate::TaskRegistry::new());
        let worker_id = "test_worker_009".to_string();
        let active_tasks = Arc::new(AtomicUsize::new(1)); // Task that never completes

        let context = TaskExecutionContext {
            broker,
            task_registry,
            worker_id,
            semaphore: None,
            active_tasks,
        };

        // Test shutdown timeout (this would normally wait 30s, but we'll test with timeout)
        let result = timeout(
            Duration::from_millis(200),
            Worker::graceful_shutdown(&context),
        )
        .await;
        assert!(result.is_err()); // Should timeout
        assert_eq!(context.active_tasks.load(Ordering::Relaxed), 1); // Task still active
    }

    #[tokio::test]
    async fn test_handle_backpressure() {
        let broker = create_test_broker();
        let task_registry = Arc::new(crate::TaskRegistry::new());
        let worker_id = "test_worker_010".to_string();
        let task_wrapper = create_test_task_wrapper();

        // Clean up any existing data
        if let Ok(mut conn) = get_test_connection(&broker).await {
            let _: Result<String, _> = redis::cmd("FLUSHDB").query_async(&mut conn).await;
        }

        let context = TaskExecutionContext {
            broker: broker.clone(),
            task_registry,
            worker_id,
            semaphore: None,
            active_tasks: Arc::new(AtomicUsize::new(0)),
        };

        let result = Worker::handle_backpressure(context, task_wrapper.clone()).await;

        match result {
            SpawnResult::Rejected(rejected_wrapper) => {
                assert_eq!(rejected_wrapper.metadata.id, task_wrapper.metadata.id);
            }
            _ => panic!("Expected rejected result"),
        }

        // Verify task was requeued
        let queue_size = broker
            .get_queue_size(queue_names::DEFAULT)
            .await
            .expect("Failed to get queue size");
        assert!(queue_size > 0);
    }
}