1use std::sync::{
2 atomic::{AtomicUsize, Ordering},
3 Arc,
4};
5
6use crate::{error::TaskQueueError, queue::queue_names, RedisBroker, TaskScheduler, TaskWrapper};
7use serde::Serialize;
8use tokio::sync::{mpsc, Semaphore};
9use tokio::task::JoinHandle;
10
11#[cfg(feature = "tracing")]
12use tracing::Instrument;
13
14pub struct Worker {
15 id: String,
16 broker: Arc<RedisBroker>,
17 #[allow(dead_code)]
18 scheduler: Arc<TaskScheduler>,
19 task_registry: Arc<crate::TaskRegistry>,
20 shutdown_tx: Option<mpsc::Sender<()>>,
21 handle: Option<JoinHandle<()>>,
22 max_concurrent_tasks: usize,
24 task_semaphore: Option<Arc<Semaphore>>,
25 active_tasks: Arc<AtomicUsize>,
27}
28
/// Concurrency / backpressure settings for a `Worker`.
///
/// Only `max_concurrent_tasks` is consumed (by `Worker::with_backpressure_config`);
/// the remaining fields are carried along but not read by the worker in this
/// module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkerBackpressureConfig {
    /// Maximum number of tasks the worker may execute concurrently.
    pub max_concurrent_tasks: usize,
    /// Queue-size threshold for applying backpressure.
    /// NOTE(review): not read by the worker in this module.
    pub queue_size_threshold: usize,
    /// Backpressure delay in milliseconds.
    /// NOTE(review): not read by the worker in this module.
    pub backpressure_delay_ms: u64,
}
36
37#[derive(Clone)]
39struct TaskExecutionContext {
40 broker: Arc<RedisBroker>,
41 task_registry: Arc<crate::TaskRegistry>,
42 worker_id: String,
43 semaphore: Option<Arc<Semaphore>>,
44 active_tasks: Arc<AtomicUsize>,
45}
46
47#[derive(Debug)]
49enum SpawnResult {
50 Spawned,
51 Rejected(TaskWrapper),
52 Failed(TaskQueueError),
53}
54
55impl Worker {
56 pub fn new(id: String, broker: Arc<RedisBroker>, scheduler: Arc<TaskScheduler>) -> Self {
57 let max_concurrent_tasks = 10;
58 Self {
59 id,
60 broker,
61 scheduler,
62 task_registry: Arc::new(crate::TaskRegistry::new()),
63 shutdown_tx: None,
64 handle: None,
65 max_concurrent_tasks,
66 task_semaphore: Some(Arc::new(Semaphore::new(max_concurrent_tasks))),
67 active_tasks: Arc::new(AtomicUsize::new(0)),
68 }
69 }
70
71 pub fn with_task_registry(mut self, registry: Arc<crate::TaskRegistry>) -> Self {
72 self.task_registry = registry;
73 self
74 }
75
76 pub fn with_backpressure_config(mut self, config: WorkerBackpressureConfig) -> Self {
77 self.max_concurrent_tasks = config.max_concurrent_tasks;
78 self.task_semaphore = Some(Arc::new(Semaphore::new(config.max_concurrent_tasks)));
79 self
80 }
81
82 pub fn with_max_concurrent_tasks(mut self, max_tasks: usize) -> Self {
83 self.max_concurrent_tasks = max_tasks;
84 self.task_semaphore = Some(Arc::new(Semaphore::new(max_tasks)));
85 self
86 }
87
88 pub fn active_task_count(&self) -> usize {
90 self.active_tasks.load(Ordering::Relaxed)
91 }
92
93 pub async fn start(mut self) -> Result<Self, TaskQueueError> {
94 let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
95
96 let execution_context = TaskExecutionContext {
97 broker: self.broker.clone(),
98 task_registry: self.task_registry.clone(),
99 worker_id: self.id.clone(),
100 semaphore: self.task_semaphore.clone(),
101 active_tasks: self.active_tasks.clone(),
102 };
103
104 execution_context
106 .broker
107 .register_worker(&execution_context.worker_id)
108 .await?;
109
110 let handle = tokio::spawn(async move {
111 let queues = vec![
112 queue_names::DEFAULT.to_string(),
113 queue_names::HIGH_PRIORITY.to_string(),
114 queue_names::LOW_PRIORITY.to_string(),
115 ];
116
117 #[cfg(feature = "tracing")]
118 tracing::info!(
119 worker_id = %execution_context.worker_id,
120 queues = ?queues,
121 semaphore_permits = execution_context.semaphore.as_ref().map(|s| s.available_permits()),
122 "Worker main loop started"
123 );
124
125 loop {
126 tokio::select! {
127 _ = shutdown_rx.recv() => {
129 #[cfg(feature = "tracing")]
130 tracing::info!(
131 worker_id = %execution_context.worker_id,
132 active_tasks = execution_context.active_tasks.load(Ordering::Relaxed),
133 "Worker received shutdown signal"
134 );
135
136 Self::graceful_shutdown(&execution_context).await;
138
139 if let Err(_e) = execution_context.broker.unregister_worker(&execution_context.worker_id).await {
140 #[cfg(feature = "tracing")]
141 tracing::error!(
142 worker_id = %execution_context.worker_id,
143 error = %_e,
144 "Failed to unregister worker during shutdown"
145 );
146 } else {
147 #[cfg(feature = "tracing")]
148 tracing::info!(
149 worker_id = %execution_context.worker_id,
150 "Worker unregistered successfully during shutdown"
151 );
152 }
153 break;
154 }
155
156 task_result = execution_context.broker.dequeue_task(&queues) => {
158 let current_active_tasks = execution_context.active_tasks.load(Ordering::Relaxed);
159
160 match task_result {
161 Ok(Some(task_wrapper)) => {
162 #[cfg(feature = "tracing")]
163 tracing::debug!(
164 worker_id = %execution_context.worker_id,
165 task_id = %task_wrapper.metadata.id,
166 task_name = %task_wrapper.metadata.name,
167 queue_source = "unknown", active_tasks_before = current_active_tasks,
169 "Task received for processing"
170 );
171
172 match Self::handle_task_execution(execution_context.clone(), task_wrapper).await {
173 SpawnResult::Spawned => {
174 #[cfg(feature = "tracing")]
175 tracing::debug!(
176 worker_id = %execution_context.worker_id,
177 active_tasks = execution_context.active_tasks.load(Ordering::Relaxed),
178 semaphore_permits = execution_context.semaphore.as_ref().map(|s| s.available_permits()),
179 "Task spawned successfully"
180 );
181 }
182 SpawnResult::Rejected(rejected_task) => {
183 #[cfg(feature = "tracing")]
184 tracing::warn!(
185 worker_id = %execution_context.worker_id,
186 task_id = %rejected_task.metadata.id,
187 task_name = %rejected_task.metadata.name,
188 active_tasks = current_active_tasks,
189 semaphore_permits = execution_context.semaphore.as_ref().map(|s| s.available_permits()),
190 "Task rejected due to backpressure"
191 );
192 }
193 SpawnResult::Failed(_e) => {
194 #[cfg(feature = "tracing")]
195 tracing::error!(
196 worker_id = %execution_context.worker_id,
197 error = %_e,
198 active_tasks = current_active_tasks,
199 "Failed to handle task execution"
200 );
201 }
202 }
203 }
204 Ok(None) => {
205 #[cfg(feature = "tracing")]
207 tracing::trace!(
208 worker_id = %execution_context.worker_id,
209 active_tasks = current_active_tasks,
210 "No tasks available, updating heartbeat"
211 );
212
213 if let Err(_e) = execution_context.broker.update_worker_heartbeat(&execution_context.worker_id).await {
214 #[cfg(feature = "tracing")]
215 tracing::error!(
216 worker_id = %execution_context.worker_id,
217 error = %_e,
218 "Failed to update worker heartbeat"
219 );
220 }
221 }
222 Err(_e) => {
223 #[cfg(feature = "tracing")]
224 tracing::error!(
225 worker_id = %execution_context.worker_id,
226 error = %_e,
227 active_tasks = current_active_tasks,
228 "Error dequeuing task, backing off"
229 );
230
231 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
232 }
233 }
234 }
235 }
236 }
237
238 #[cfg(feature = "tracing")]
239 tracing::info!(
240 worker_id = %execution_context.worker_id,
241 "Worker main loop ended gracefully"
242 );
243 });
244
245 self.shutdown_tx = Some(shutdown_tx);
246 self.handle = Some(handle);
247
248 #[cfg(feature = "tracing")]
249 tracing::info!("Started worker {}", self.id);
250
251 Ok(self)
252 }
253
254 async fn handle_task_execution(
256 context: TaskExecutionContext,
257 task_wrapper: TaskWrapper,
258 ) -> SpawnResult {
259 let semaphore_opt = context.semaphore.clone();
261
262 match semaphore_opt {
263 Some(semaphore) => {
264 let semaphore_clone = semaphore.clone();
266
267 match semaphore.try_acquire() {
269 Ok(_permit) => {
270 drop(_permit);
273 Self::spawn_task_with_semaphore(context, task_wrapper, semaphore_clone)
274 .await;
275 SpawnResult::Spawned
276 }
277 Err(_) => {
278 Self::handle_backpressure(context, task_wrapper).await
280 }
281 }
282 }
283 None => {
284 Self::execute_task_directly(context, task_wrapper).await;
286 SpawnResult::Spawned
287 }
288 }
289 }
290
291 async fn graceful_shutdown(context: &TaskExecutionContext) {
293 let shutdown_timeout = tokio::time::Duration::from_secs(30);
294 let check_interval = tokio::time::Duration::from_millis(100);
295
296 let start_time = tokio::time::Instant::now();
297
298 while context.active_tasks.load(Ordering::Relaxed) > 0
299 && start_time.elapsed() < shutdown_timeout
300 {
301 #[cfg(feature = "tracing")]
302 tracing::debug!(
303 "Worker {} waiting for {} active tasks to complete",
304 context.worker_id,
305 context.active_tasks.load(Ordering::Relaxed)
306 );
307
308 tokio::time::sleep(check_interval).await;
309 }
310
311 if context.active_tasks.load(Ordering::Relaxed) > 0 {
312 #[cfg(feature = "tracing")]
313 tracing::warn!(
314 "Worker {} shutdown timeout reached with {} active tasks",
315 context.worker_id,
316 context.active_tasks.load(Ordering::Relaxed)
317 );
318 }
319 }
320
321 async fn spawn_task_with_semaphore(
323 context: TaskExecutionContext,
324 task_wrapper: TaskWrapper,
325 semaphore: Arc<Semaphore>,
326 ) {
327 context.active_tasks.fetch_add(1, Ordering::Relaxed);
329
330 tokio::spawn(async move {
331 let _permit = semaphore
333 .acquire()
334 .await
335 .expect("Semaphore should not be closed");
336
337 if let Err(_e) = Self::process_task(
338 &context.broker,
339 &context.task_registry,
340 &context.worker_id,
341 task_wrapper,
342 )
343 .await
344 {
345 #[cfg(feature = "tracing")]
346 tracing::error!("Task processing failed: {}", _e);
347 }
348
349 context.active_tasks.fetch_sub(1, Ordering::Relaxed);
351 });
353 }
354
355 async fn handle_backpressure(
357 context: TaskExecutionContext,
358 task_wrapper: TaskWrapper,
359 ) -> SpawnResult {
360 match Self::requeue_task(&context.broker, task_wrapper.clone()).await {
362 Ok(_) => {
363 #[cfg(feature = "tracing")]
364 tracing::debug!(
365 "Task {} re-queued due to backpressure",
366 task_wrapper.metadata.id
367 );
368 SpawnResult::Rejected(task_wrapper)
369 }
370 Err(e) => SpawnResult::Failed(e),
371 }
372 }
373
374 async fn execute_task_directly(context: TaskExecutionContext, task_wrapper: TaskWrapper) {
376 context.active_tasks.fetch_add(1, Ordering::Relaxed);
378
379 tokio::spawn(async move {
380 if let Err(_e) = Self::process_task(
381 &context.broker,
382 &context.task_registry,
383 &context.worker_id,
384 task_wrapper,
385 )
386 .await
387 {
388 #[cfg(feature = "tracing")]
389 tracing::error!("Task processing failed: {}", _e);
390 }
391
392 context.active_tasks.fetch_sub(1, Ordering::Relaxed);
394 });
395 }
396
397 async fn requeue_task(
399 broker: &RedisBroker,
400 task_wrapper: TaskWrapper,
401 ) -> Result<(), TaskQueueError> {
402 broker
403 .enqueue_task_wrapper(task_wrapper, queue_names::DEFAULT)
404 .await?;
405 Ok(())
406 }
407
408 pub async fn stop(self) {
409 if let Some(tx) = self.shutdown_tx {
410 if let Err(_e) = tx.send(()).await {
411 #[cfg(feature = "tracing")]
412 tracing::error!("Failed to send shutdown signal");
413 }
414 }
415
416 if let Some(handle) = self.handle {
417 let _ = handle.await;
418 }
419 }
420
421 async fn process_task(
422 broker: &RedisBroker,
423 task_registry: &crate::TaskRegistry,
424 worker_id: &str,
425 mut task_wrapper: TaskWrapper,
426 ) -> Result<(), TaskQueueError> {
427 let task_id = task_wrapper.metadata.id;
428 let task_name = &task_wrapper.metadata.name;
429
430 let span = tracing::info_span!(
432 "process_task",
433 task_id = %task_id,
434 task_name = task_name,
435 worker_id = worker_id,
436 attempt = task_wrapper.metadata.attempts + 1,
437 max_retries = task_wrapper.metadata.max_retries
438 );
439
440 async move {
441 #[cfg(feature = "tracing")]
442 tracing::info!(
443 created_at = %task_wrapper.metadata.created_at,
444 timeout_seconds = task_wrapper.metadata.timeout_seconds,
445 payload_size_bytes = task_wrapper.payload.len(),
446 "Starting task processing"
447 );
448
449 let execution_start = std::time::Instant::now();
450 task_wrapper.metadata.attempts += 1;
451
452 #[cfg(feature = "tracing")]
454 tracing::debug!(
455 attempt = task_wrapper.metadata.attempts,
456 created_at = %task_wrapper.metadata.created_at,
457 "Task execution attempt started"
458 );
459
460 match Self::execute_task_with_registry(task_registry, &task_wrapper).await {
462 Ok(result) => {
463 let execution_duration = execution_start.elapsed();
464
465 #[cfg(feature = "tracing")]
466 tracing::info!(
467 duration_ms = execution_duration.as_millis(),
468 result_size_bytes = result.len(),
469 success = true,
470 "Task completed successfully"
471 );
472
473 broker
474 .mark_task_completed(task_id, queue_names::DEFAULT)
475 .await?;
476 }
477 Err(error) => {
478 let execution_duration = execution_start.elapsed();
479 let error_msg = error.to_string();
480
481 #[cfg(feature = "tracing")]
482 tracing::error!(
483 duration_ms = execution_duration.as_millis(),
484 error = %error,
485 error_source = error.source().map(|e| e.to_string()).as_deref(),
486 success = false,
487 "Task execution failed"
488 );
489
490 if task_wrapper.metadata.attempts < task_wrapper.metadata.max_retries {
492 let remaining_retries =
493 task_wrapper.metadata.max_retries - task_wrapper.metadata.attempts;
494
495 #[cfg(feature = "tracing")]
496 tracing::warn!(
497 remaining_retries = remaining_retries,
498 retry_delay_ms = 1000 * task_wrapper.metadata.attempts as u64, "Re-queuing task for retry"
500 );
501
502 broker
504 .enqueue_task_wrapper(task_wrapper, queue_names::DEFAULT)
505 .await?;
506 } else {
507 #[cfg(feature = "tracing")]
508 tracing::error!(
509 final_error = %error_msg,
510 total_attempts = task_wrapper.metadata.attempts,
511 "Task failed permanently - maximum retries exceeded"
512 );
513
514 broker
515 .mark_task_failed_with_reason(
516 task_id,
517 queue_names::DEFAULT,
518 Some(error_msg),
519 )
520 .await?;
521 }
522 }
523 }
524
525 Ok(())
526 }
527 .instrument(span)
528 .await
529 }
530
531 async fn execute_task_with_registry(
532 task_registry: &crate::TaskRegistry,
533 task_wrapper: &TaskWrapper,
534 ) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
535 let task_name = &task_wrapper.metadata.name;
536 let task_id = task_wrapper.metadata.id;
537 let execution_start = std::time::Instant::now();
538
539 #[cfg(feature = "tracing")]
540 tracing::debug!(
541 task_id = %task_id,
542 task_name = task_name,
543 payload_size_bytes = task_wrapper.payload.len(),
544 "Attempting task execution with registry"
545 );
546
547 match task_registry
549 .execute(task_name, task_wrapper.payload.clone())
550 .await
551 {
552 Ok(result) => {
553 #[cfg(feature = "tracing")]
554 tracing::info!(
555 task_id = %task_id,
556 task_name = task_name,
557 duration_ms = execution_start.elapsed().as_millis(),
558 result_size_bytes = result.len(),
559 execution_method = "registry",
560 "Task executed successfully with registered executor"
561 );
562
563 Ok(result)
564 }
565 Err(error) => {
566 let execution_duration = execution_start.elapsed();
567
568 let error_msg = error.to_string();
570 if error_msg.contains("Unknown task type") {
571 #[cfg(feature = "tracing")]
572 tracing::warn!(
573 task_id = %task_id,
574 task_name = task_name,
575 duration_ms = execution_duration.as_millis(),
576 error = %error,
577 "No executor found for task type, using fallback"
578 );
579
580 #[derive(Serialize)]
582 struct FallbackResponse {
583 status: String,
584 message: String,
585 timestamp: String,
586 task_id: String,
587 task_name: String,
588 }
589
590 let response = FallbackResponse {
591 status: "completed".to_string(),
592 message: format!("Task {} completed with fallback executor", task_name),
593 timestamp: chrono::Utc::now().to_rfc3339(),
594 task_id: task_id.to_string(),
595 task_name: task_name.to_string(),
596 };
597
598 let serialized = serde_json::to_vec(&response)
599 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
600
601 #[cfg(feature = "tracing")]
602 tracing::info!(
603 task_id = %task_id,
604 task_name = task_name,
605 duration_ms = execution_duration.as_millis(),
606 result_size_bytes = serialized.len(),
607 execution_method = "fallback",
608 "Task completed with fallback executor"
609 );
610
611 Ok(serialized)
612 } else {
613 #[cfg(feature = "tracing")]
615 tracing::error!(
616 task_id = %task_id,
617 task_name = task_name,
618 duration_ms = execution_duration.as_millis(),
619 error = %error,
620 error_source = error.source().map(|e| e.to_string()).as_deref(),
621 execution_method = "registry",
622 "Task execution failed in registered executor"
623 );
624
625 Err(error)
626 }
627 }
628 }
629 }
630}
631
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TaskId, TaskMetadata};
    use std::time::Duration;
    use tokio::time::timeout;

    /// Builds a broker for `REDIS_TEST_URL`, defaulting to database 15 on localhost.
    fn create_test_broker() -> Arc<RedisBroker> {
        let redis_url = std::env::var("REDIS_TEST_URL")
            .unwrap_or_else(|_| "redis://127.0.0.1:6379/15".to_string());

        let config = deadpool_redis::Config::from_url(&redis_url);
        let pool = config
            .create_pool(Some(deadpool_redis::Runtime::Tokio1))
            .expect("Failed to create test pool");

        Arc::new(RedisBroker { pool })
    }

    /// Builds a scheduler backed by its own test broker.
    fn create_test_scheduler() -> Arc<TaskScheduler> {
        let broker = create_test_broker();
        Arc::new(TaskScheduler::new(broker))
    }

    /// A fresh `test_task` wrapper with no attempts and three allowed retries.
    fn create_test_task_wrapper() -> TaskWrapper {
        TaskWrapper {
            metadata: TaskMetadata {
                id: TaskId::new_v4(),
                name: "test_task".to_string(),
                created_at: chrono::Utc::now(),
                attempts: 0,
                max_retries: 3,
                timeout_seconds: 60,
            },
            payload: b"test payload".to_vec(),
        }
    }

    async fn get_test_connection(
        broker: &RedisBroker,
    ) -> Result<deadpool_redis::Connection, deadpool_redis::PoolError> {
        broker.pool.get().await
    }

    #[test]
    fn test_worker_creation() {
        let broker = create_test_broker();
        let scheduler = create_test_scheduler();
        let worker_id = "test_worker_001".to_string();

        let worker = Worker::new(worker_id.clone(), broker, scheduler);

        assert_eq!(worker.id, worker_id);
        assert_eq!(worker.max_concurrent_tasks, 10);
        assert_eq!(worker.active_task_count(), 0);
        assert!(worker.task_semaphore.is_some());
        assert!(worker.shutdown_tx.is_none());
        assert!(worker.handle.is_none());
    }

    #[test]
    fn test_worker_with_task_registry() {
        let broker = create_test_broker();
        let scheduler = create_test_scheduler();
        let worker_id = "test_worker_002".to_string();
        let registry = Arc::new(crate::TaskRegistry::new());

        let _worker =
            Worker::new(worker_id, broker, scheduler).with_task_registry(registry.clone());

        // One handle held here, one held by the worker.
        assert_eq!(Arc::strong_count(&registry), 2);
    }

    #[test]
    fn test_worker_with_backpressure_config() {
        let broker = create_test_broker();
        let scheduler = create_test_scheduler();
        let worker_id = "test_worker_003".to_string();

        let config = WorkerBackpressureConfig {
            max_concurrent_tasks: 20,
            queue_size_threshold: 100,
            backpressure_delay_ms: 500,
        };

        let worker =
            Worker::new(worker_id, broker, scheduler).with_backpressure_config(config.clone());

        assert_eq!(worker.max_concurrent_tasks, config.max_concurrent_tasks);
        assert!(worker.task_semaphore.is_some());

        if let Some(semaphore) = &worker.task_semaphore {
            assert_eq!(semaphore.available_permits(), config.max_concurrent_tasks);
        }
    }

    #[test]
    fn test_worker_with_max_concurrent_tasks() {
        let broker = create_test_broker();
        let scheduler = create_test_scheduler();
        let worker_id = "test_worker_004".to_string();
        let max_tasks = 15;

        let worker = Worker::new(worker_id, broker, scheduler).with_max_concurrent_tasks(max_tasks);

        assert_eq!(worker.max_concurrent_tasks, max_tasks);
        assert!(worker.task_semaphore.is_some());

        if let Some(semaphore) = &worker.task_semaphore {
            assert_eq!(semaphore.available_permits(), max_tasks);
        }
    }

    #[test]
    fn test_worker_backpressure_config_clone() {
        let original = WorkerBackpressureConfig {
            max_concurrent_tasks: 25,
            queue_size_threshold: 200,
            backpressure_delay_ms: 1000,
        };

        let cloned = original.clone();

        assert_eq!(original.max_concurrent_tasks, cloned.max_concurrent_tasks);
        assert_eq!(original.queue_size_threshold, cloned.queue_size_threshold);
        assert_eq!(original.backpressure_delay_ms, cloned.backpressure_delay_ms);
    }

    #[test]
    fn test_worker_backpressure_config_debug() {
        let config = WorkerBackpressureConfig {
            max_concurrent_tasks: 8,
            queue_size_threshold: 50,
            backpressure_delay_ms: 250,
        };

        let debug_str = format!("{:?}", config);

        assert!(debug_str.contains("WorkerBackpressureConfig"));
        assert!(debug_str.contains("max_concurrent_tasks: 8"));
        assert!(debug_str.contains("queue_size_threshold: 50"));
        assert!(debug_str.contains("backpressure_delay_ms: 250"));
    }

    #[test]
    fn test_spawn_result_debug() {
        let spawned = SpawnResult::Spawned;
        let rejected = SpawnResult::Rejected(create_test_task_wrapper());
        let failed = SpawnResult::Failed(TaskQueueError::Serialization(
            rmp_serde::encode::Error::Syntax("test error".to_string()),
        ));

        let spawned_debug = format!("{:?}", spawned);
        let rejected_debug = format!("{:?}", rejected);
        let failed_debug = format!("{:?}", failed);

        assert!(spawned_debug.contains("Spawned"));
        assert!(rejected_debug.contains("Rejected"));
        assert!(failed_debug.contains("Failed"));
    }

    #[test]
    fn test_task_execution_context_clone() {
        let broker = create_test_broker();
        let task_registry = Arc::new(crate::TaskRegistry::new());
        let worker_id = "test_worker_005".to_string();
        let semaphore = Some(Arc::new(Semaphore::new(10)));
        let active_tasks = Arc::new(AtomicUsize::new(0));

        let context = TaskExecutionContext {
            broker: broker.clone(),
            task_registry: task_registry.clone(),
            worker_id: worker_id.clone(),
            semaphore: semaphore.clone(),
            active_tasks: active_tasks.clone(),
        };

        let cloned = context.clone();

        assert_eq!(cloned.worker_id, worker_id);
        assert_eq!(cloned.active_tasks.load(Ordering::Relaxed), 0);
        assert!(cloned.semaphore.is_some());
    }

    #[tokio::test]
    async fn test_active_task_count_tracking() {
        let broker = create_test_broker();
        let scheduler = create_test_scheduler();
        let worker_id = "test_worker_006".to_string();

        let worker = Worker::new(worker_id, broker, scheduler);

        assert_eq!(worker.active_task_count(), 0);

        worker.active_tasks.fetch_add(1, Ordering::Relaxed);
        assert_eq!(worker.active_task_count(), 1);

        worker.active_tasks.fetch_add(2, Ordering::Relaxed);
        assert_eq!(worker.active_task_count(), 3);

        worker.active_tasks.fetch_sub(1, Ordering::Relaxed);
        assert_eq!(worker.active_task_count(), 2);
    }

    /// Requires a reachable Redis (see `create_test_broker`).
    #[tokio::test]
    async fn test_requeue_task() {
        let broker = create_test_broker();
        let task_wrapper = create_test_task_wrapper();

        if let Ok(mut conn) = get_test_connection(&broker).await {
            let _: Result<String, _> = redis::cmd("FLUSHDB").query_async(&mut conn).await;
        }

        let result = Worker::requeue_task(&broker, task_wrapper).await;
        assert!(result.is_ok());

        let queue_size = broker
            .get_queue_size(queue_names::DEFAULT)
            .await
            .expect("Failed to get queue size");
        assert!(queue_size > 0);
    }

    #[tokio::test]
    async fn test_execute_task_with_registry_fallback() {
        let task_registry = crate::TaskRegistry::new();
        let task_wrapper = create_test_task_wrapper();

        let result = Worker::execute_task_with_registry(&task_registry, &task_wrapper).await;
        assert!(result.is_ok());

        let output = result.unwrap();
        assert!(!output.is_empty());

        let parsed: serde_json::Value =
            serde_json::from_slice(&output).expect("Should be valid JSON");

        assert_eq!(parsed["status"], "completed");
        assert!(parsed["message"].as_str().unwrap().contains("test_task"));
        assert!(parsed["timestamp"].is_string());
    }

    /// Requires a reachable Redis (see `create_test_broker`).
    #[tokio::test]
    async fn test_process_task_success() {
        let broker = create_test_broker();
        let task_registry = crate::TaskRegistry::new();
        let worker_id = "test_worker_007";
        let task_wrapper = create_test_task_wrapper();

        if let Ok(mut conn) = get_test_connection(&broker).await {
            let _: Result<String, _> = redis::cmd("FLUSHDB").query_async(&mut conn).await;
        }

        let result = Worker::process_task(&broker, &task_registry, worker_id, task_wrapper).await;
        assert!(result.is_ok());

        let metrics = broker
            .get_queue_metrics(queue_names::DEFAULT)
            .await
            .expect("Failed to get metrics");
        assert_eq!(metrics.processed_tasks, 1);
    }

    #[tokio::test]
    async fn test_execute_task_directly() {
        let broker = create_test_broker();
        let task_registry = Arc::new(crate::TaskRegistry::new());
        let worker_id = "test_worker_012".to_string();
        let active_tasks = Arc::new(AtomicUsize::new(0));
        let task_wrapper = create_test_task_wrapper();

        let context = TaskExecutionContext {
            broker,
            task_registry,
            worker_id,
            semaphore: None,
            active_tasks: active_tasks.clone(),
        };

        assert_eq!(active_tasks.load(Ordering::Relaxed), 0);

        Worker::execute_task_directly(context, task_wrapper).await;

        let mut attempts = 0;
        while active_tasks.load(Ordering::Relaxed) == 0 && attempts < 50 {
            tokio::time::sleep(Duration::from_millis(10)).await;
            attempts += 1;
        }

        assert!(
            active_tasks.load(Ordering::Relaxed) >= 1,
            "Task should have started and incremented active count"
        );

        // Wait for the spawned task to finish and release its count.
        let mut attempts = 0;
        while active_tasks.load(Ordering::Relaxed) > 0 && attempts < 100 {
            tokio::time::sleep(Duration::from_millis(10)).await;
            attempts += 1;
        }

        assert_eq!(active_tasks.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_graceful_shutdown() {
        let broker = create_test_broker();
        let task_registry = Arc::new(crate::TaskRegistry::new());
        let worker_id = "test_worker_008".to_string();
        let active_tasks = Arc::new(AtomicUsize::new(2));

        let context = TaskExecutionContext {
            broker,
            task_registry,
            worker_id,
            semaphore: None,
            active_tasks: active_tasks.clone(),
        };

        // Simulate two tasks finishing 50 ms apart.
        let active_tasks_clone = active_tasks.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            active_tasks_clone.fetch_sub(1, Ordering::Relaxed);
            tokio::time::sleep(Duration::from_millis(50)).await;
            active_tasks_clone.fetch_sub(1, Ordering::Relaxed);
        });

        let start = std::time::Instant::now();
        Worker::graceful_shutdown(&context).await;
        let elapsed = start.elapsed();

        assert_eq!(context.active_tasks.load(Ordering::Relaxed), 0);
        assert!(
            elapsed < Duration::from_millis(500),
            "Shutdown should complete in reasonable time"
        );
    }

    #[test]
    fn test_worker_default_configuration() {
        let broker = create_test_broker();
        let scheduler = create_test_scheduler();
        let worker_id = "test_worker_011".to_string();

        let worker = Worker::new(worker_id.clone(), broker.clone(), scheduler.clone());

        assert_eq!(worker.id, worker_id);
        assert_eq!(worker.max_concurrent_tasks, 10);
        assert_eq!(worker.active_task_count(), 0);
        assert!(worker.task_semaphore.is_some());
        assert!(worker.shutdown_tx.is_none());
        assert!(worker.handle.is_none());

        // One handle held here, one held by the worker.
        assert_eq!(Arc::strong_count(&broker), 2);
        assert_eq!(Arc::strong_count(&scheduler), 2);
    }

    #[test]
    fn test_worker_backpressure_config_defaults() {
        let config = WorkerBackpressureConfig {
            max_concurrent_tasks: 50,
            queue_size_threshold: 1000,
            backpressure_delay_ms: 100,
        };

        assert_eq!(config.max_concurrent_tasks, 50);
        assert_eq!(config.queue_size_threshold, 1000);
        assert_eq!(config.backpressure_delay_ms, 100);
    }

    #[tokio::test]
    async fn test_worker_method_chaining() {
        let broker = create_test_broker();
        let scheduler = create_test_scheduler();
        let worker_id = "test_worker_013".to_string();
        let registry = Arc::new(crate::TaskRegistry::new());

        let config = WorkerBackpressureConfig {
            max_concurrent_tasks: 25,
            queue_size_threshold: 500,
            backpressure_delay_ms: 200,
        };

        let worker = Worker::new(worker_id.clone(), broker, scheduler)
            .with_task_registry(registry.clone())
            .with_backpressure_config(config.clone())
            .with_max_concurrent_tasks(30);

        assert_eq!(worker.id, worker_id);
        // The later `with_max_concurrent_tasks` call overrides the config's limit.
        assert_eq!(worker.max_concurrent_tasks, 30);
        assert_eq!(Arc::strong_count(&registry), 2);
        if let Some(semaphore) = &worker.task_semaphore {
            assert_eq!(semaphore.available_permits(), 30);
        }
    }

    #[tokio::test]
    async fn test_graceful_shutdown_timeout() {
        let broker = create_test_broker();
        let task_registry = Arc::new(crate::TaskRegistry::new());
        let worker_id = "test_worker_009".to_string();
        // A task that never finishes.
        let active_tasks = Arc::new(AtomicUsize::new(1));

        let context = TaskExecutionContext {
            broker,
            task_registry,
            worker_id,
            semaphore: None,
            active_tasks,
        };

        let result = timeout(
            Duration::from_millis(200),
            Worker::graceful_shutdown(&context),
        )
        .await;

        // Shutdown is still waiting when the outer timeout elapses.
        assert!(result.is_err());
        assert_eq!(context.active_tasks.load(Ordering::Relaxed), 1);
    }

    /// Requires a reachable Redis (see `create_test_broker`).
    #[tokio::test]
    async fn test_handle_backpressure() {
        let broker = create_test_broker();
        let task_registry = Arc::new(crate::TaskRegistry::new());
        let worker_id = "test_worker_010".to_string();
        let task_wrapper = create_test_task_wrapper();

        if let Ok(mut conn) = get_test_connection(&broker).await {
            let _: Result<String, _> = redis::cmd("FLUSHDB").query_async(&mut conn).await;
        }

        let context = TaskExecutionContext {
            broker: broker.clone(),
            task_registry,
            worker_id,
            semaphore: None,
            active_tasks: Arc::new(AtomicUsize::new(0)),
        };

        let result = Worker::handle_backpressure(context, task_wrapper.clone()).await;

        match result {
            SpawnResult::Rejected(rejected_wrapper) => {
                assert_eq!(rejected_wrapper.metadata.id, task_wrapper.metadata.id);
            }
            _ => panic!("Expected rejected result"),
        }

        let queue_size = broker
            .get_queue_size(queue_names::DEFAULT)
            .await
            .expect("Failed to get queue size");
        assert!(queue_size > 0);
    }
}