1use super::task::{Task, TaskResult};
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use tokio::task::JoinHandle;
6use uuid::Uuid;
7
8pub struct WorkerPool {
10 workers: Arc<Mutex<HashMap<Uuid, WorkerInfo>>>,
11 max_workers: usize,
12}
13
14#[derive(Debug)]
15struct WorkerInfo {
16 handle: JoinHandle<TaskResult>,
17 task_id: String,
18 start_time: std::time::Instant,
19 worker_type: WorkerType,
20}
21
/// Coarse resource profile assigned to a task by
/// `WorkerPool::determine_worker_type` from its `task_type()` string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WorkerType {
    /// Assigned to `"convert"` tasks.
    CpuIntensive,
    /// Assigned to `"match"` and `"validate"` tasks.
    IoIntensive,
    /// Assigned to `"sync"` tasks and to any unrecognised task type.
    Mixed,
}
35
36impl WorkerPool {
37 pub fn new(max_workers: usize) -> Self {
43 Self {
44 workers: Arc::new(Mutex::new(HashMap::new())),
45 max_workers,
46 }
47 }
48
49 pub async fn execute(&self, task: Box<dyn Task + Send + Sync>) -> Result<TaskResult, String> {
51 let worker_id = Uuid::now_v7();
52 let task_id = task.task_id();
53 let worker_type = self.determine_worker_type(task.task_type());
54
55 {
56 let workers = self.workers.lock().unwrap();
57 if workers.len() >= self.max_workers {
58 return Err("Worker pool is full".to_string());
59 }
60 }
61
62 let handle = tokio::spawn(async move { task.execute().await });
63
64 {
65 let mut workers = self.workers.lock().unwrap();
66 workers.insert(
67 worker_id,
68 WorkerInfo {
69 handle,
70 task_id: task_id.clone(),
71 start_time: std::time::Instant::now(),
72 worker_type,
73 },
74 );
75 }
76
77 Ok(TaskResult::Success("Task submitted".to_string()))
79 }
80
81 fn determine_worker_type(&self, task_type: &str) -> WorkerType {
82 match task_type {
83 "convert" => WorkerType::CpuIntensive,
84 "sync" => WorkerType::Mixed,
85 "match" => WorkerType::IoIntensive,
86 "validate" => WorkerType::IoIntensive,
87 _ => WorkerType::Mixed,
88 }
89 }
90
91 pub fn get_active_count(&self) -> usize {
93 self.workers.lock().unwrap().len()
94 }
95
96 pub fn get_capacity(&self) -> usize {
98 self.max_workers
99 }
100
101 pub fn get_worker_stats(&self) -> WorkerStats {
103 let workers = self.workers.lock().unwrap();
104 let mut cpu = 0;
105 let mut io = 0;
106 let mut mixed = 0;
107 for w in workers.values() {
108 match w.worker_type {
109 WorkerType::CpuIntensive => cpu += 1,
110 WorkerType::IoIntensive => io += 1,
111 WorkerType::Mixed => mixed += 1,
112 }
113 }
114 WorkerStats {
115 total_active: workers.len(),
116 cpu_intensive_count: cpu,
117 io_intensive_count: io,
118 mixed_count: mixed,
119 max_capacity: self.max_workers,
120 }
121 }
122
123 pub async fn shutdown(&self) {
125 let workers = { std::mem::take(&mut *self.workers.lock().unwrap()) };
126 for (id, info) in workers {
127 if !crate::cli::output::is_quiet() && !crate::cli::output::active_mode().is_json() {
132 eprintln!(
133 "Waiting for worker {} to complete task {}",
134 id, info.task_id
135 );
136 }
137 let _ = info.handle.await;
138 }
139 }
140
141 pub fn list_active_workers(&self) -> Vec<ActiveWorkerInfo> {
143 let workers = self.workers.lock().unwrap();
144 workers
145 .iter()
146 .map(|(id, info)| ActiveWorkerInfo {
147 worker_id: *id,
148 task_id: info.task_id.clone(),
149 worker_type: info.worker_type.clone(),
150 runtime: info.start_time.elapsed(),
151 })
152 .collect()
153 }
154}
155
156impl Clone for WorkerPool {
157 fn clone(&self) -> Self {
158 Self {
159 workers: Arc::clone(&self.workers),
160 max_workers: self.max_workers,
161 }
162 }
163}
164
/// Point-in-time occupancy snapshot returned by `WorkerPool::get_worker_stats`.
#[derive(Debug, Clone)]
pub struct WorkerStats {
    /// Number of workers counted when the snapshot was taken.
    pub total_active: usize,
    /// Workers classified as `WorkerType::CpuIntensive`.
    pub cpu_intensive_count: usize,
    /// Workers classified as `WorkerType::IoIntensive`.
    pub io_intensive_count: usize,
    /// Workers classified as `WorkerType::Mixed`.
    pub mixed_count: usize,
    /// The pool's configured maximum number of workers.
    pub max_capacity: usize,
}
182
183#[derive(Debug, Clone)]
187pub struct ActiveWorkerInfo {
188 pub worker_id: Uuid,
190 pub task_id: String,
192 pub worker_type: WorkerType,
194 pub runtime: std::time::Duration,
196}
197
198pub struct Worker {
200 id: Uuid,
201 status: WorkerStatus,
202}
203
/// Lifecycle state of a [`Worker`].
///
/// Derives `PartialEq`/`Eq` so statuses can be compared directly
/// (e.g. `assert_eq!(worker.status(), &WorkerStatus::Idle)`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WorkerStatus {
    /// Initial state of a newly created `Worker`.
    Idle,
    /// Busy; the string presumably identifies the task in progress
    /// (NOTE(review): confirm with callers).
    Busy(String),
    /// Stopped.
    Stopped,
    /// Errored; the string presumably describes the failure
    /// (NOTE(review): confirm with callers).
    Error(String),
}
219
220impl Worker {
221 pub fn new() -> Self {
223 Self {
224 id: Uuid::now_v7(),
225 status: WorkerStatus::Idle,
226 }
227 }
228
229 pub fn id(&self) -> Uuid {
231 self.id
232 }
233
234 pub fn status(&self) -> &WorkerStatus {
236 &self.status
237 }
238
239 pub fn set_status(&mut self, status: WorkerStatus) {
245 self.status = status;
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::parallel::task::{Task, TaskResult};
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    // `#[tokio::test]` runs on a current-thread runtime, and `WorkerPool::execute`
    // never awaits internally. A freshly spawned task therefore cannot start
    // until the test itself awaits something that yields. The exact-count
    // assertions below rely on this.

    #[tokio::test]
    async fn test_worker_pool_capacity() {
        let pool = WorkerPool::new(2);
        assert_eq!(pool.get_capacity(), 2);
        assert_eq!(pool.get_active_count(), 0);
        let stats = pool.get_worker_stats();
        assert_eq!(stats.max_capacity, 2);
        assert_eq!(stats.total_active, 0);
    }

    #[tokio::test]
    async fn test_execute_and_active_count() {
        #[derive(Clone)]
        struct DummyTask {
            id: String,
            tp: &'static str,
        }

        #[async_trait]
        impl Task for DummyTask {
            async fn execute(&self) -> TaskResult {
                TaskResult::Success(self.id.clone())
            }
            fn task_type(&self) -> &'static str {
                self.tp
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }

        let pool = WorkerPool::new(1);
        let task = DummyTask {
            id: "t1".into(),
            tp: "convert",
        };
        let res = pool.execute(Box::new(task.clone())).await;
        assert!(matches!(res, Ok(TaskResult::Success(_))));
        assert_eq!(pool.get_active_count(), 1);
    }

    #[tokio::test]
    async fn test_reject_when_full() {
        #[derive(Clone)]
        struct DummyTask;

        #[async_trait]
        impl Task for DummyTask {
            async fn execute(&self) -> TaskResult {
                TaskResult::Success("".into())
            }
            fn task_type(&self) -> &'static str {
                "match"
            }
            fn task_id(&self) -> String {
                "".into()
            }
        }

        let pool = WorkerPool::new(1);
        assert!(pool.execute(Box::new(DummyTask)).await.is_ok());
        let second = pool.execute(Box::new(DummyTask)).await;
        assert_eq!(second.err().as_deref(), Some("Worker pool is full"));
    }

    #[tokio::test]
    async fn test_list_active_workers_and_stats() {
        #[derive(Clone)]
        struct DummyTask2;

        #[async_trait]
        impl Task for DummyTask2 {
            async fn execute(&self) -> TaskResult {
                TaskResult::Success("".into())
            }
            fn task_type(&self) -> &'static str {
                "sync"
            }
            fn task_id(&self) -> String {
                "tok2".into()
            }
        }

        let pool = WorkerPool::new(2);
        assert!(pool.execute(Box::new(DummyTask2)).await.is_ok());
        let workers = pool.list_active_workers();
        assert_eq!(workers.len(), 1);
        let info = &workers[0];
        assert_eq!(info.task_id, "tok2");
        assert_eq!(info.worker_type, WorkerType::Mixed);
        let stats = pool.get_worker_stats();
        assert_eq!(stats.total_active, 1);
        assert_eq!(stats.mixed_count, 1);
    }

    #[tokio::test]
    async fn test_worker_job_distribution() {
        #[derive(Clone)]
        struct CountingTask {
            id: String,
            counter: Arc<AtomicUsize>,
        }

        #[async_trait]
        impl Task for CountingTask {
            async fn execute(&self) -> TaskResult {
                self.counter.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                TaskResult::Success(format!("task-{}", self.id))
            }
            fn task_type(&self) -> &'static str {
                "convert"
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }

        let pool = WorkerPool::new(4);
        let counter = Arc::new(AtomicUsize::new(0));

        // Submit concurrently from separate tasks through cloned pool handles.
        let handles: Vec<_> = (0..4)
            .map(|i| {
                let task = CountingTask {
                    id: format!("task-{}", i),
                    counter: Arc::clone(&counter),
                };
                let pool_clone = pool.clone();
                tokio::spawn(async move { pool_clone.execute(Box::new(task)).await })
            })
            .collect();

        for handle in handles {
            let result = handle.await.unwrap();
            assert!(result.is_ok(), "Task submission should succeed");
        }

        // Wait for the submitted tasks deterministically instead of sleeping.
        pool.shutdown().await;

        assert_eq!(
            counter.load(Ordering::SeqCst),
            4,
            "All 4 tasks should have been executed"
        );
    }

    #[tokio::test]
    async fn test_worker_error_recovery() {
        #[derive(Clone)]
        struct FailingTask {
            id: String,
            should_fail: bool,
        }

        #[async_trait]
        impl Task for FailingTask {
            async fn execute(&self) -> TaskResult {
                if self.should_fail {
                    TaskResult::Failed("Intentional failure".to_string())
                } else {
                    TaskResult::Success(format!("success-{}", self.id))
                }
            }
            fn task_type(&self) -> &'static str {
                "sync"
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }

        let pool = WorkerPool::new(2);

        let success_task = FailingTask {
            id: "success".to_string(),
            should_fail: false,
        };
        let result = pool.execute(Box::new(success_task)).await;
        assert!(result.is_ok(), "Successful task should be submitted");

        let fail_task = FailingTask {
            id: "fail".to_string(),
            should_fail: true,
        };
        let result = pool.execute(Box::new(fail_task)).await;
        assert!(
            result.is_ok(),
            "Failing task should still be submitted successfully"
        );

        assert!(
            pool.get_active_count() <= 2,
            "Active count should be within limits"
        );

        // A failing task must not wedge the pool: shutdown drains everything.
        pool.shutdown().await;
        assert_eq!(pool.get_active_count(), 0);
    }

    #[tokio::test]
    async fn test_capacity_limits_concurrent_submissions() {
        #[derive(Clone)]
        struct CpuIntensiveTask {
            id: String,
            duration_ms: u64,
        }

        #[async_trait]
        impl Task for CpuIntensiveTask {
            async fn execute(&self) -> TaskResult {
                tokio::time::sleep(tokio::time::Duration::from_millis(self.duration_ms)).await;
                TaskResult::Success(format!("completed-{}", self.id))
            }
            fn task_type(&self) -> &'static str {
                "convert"
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }

        // A single-slot pool accepts one pending task and rejects the next.
        let sequential_pool = WorkerPool::new(1);
        let mut accepted = 0;
        for i in 0..2 {
            let task = CpuIntensiveTask {
                id: format!("seq-{}", i),
                duration_ms: 10,
            };
            if sequential_pool.execute(Box::new(task)).await.is_ok() {
                accepted += 1;
            }
        }
        assert_eq!(accepted, 1, "Single-slot pool should reject the second task");

        // A two-slot pool runs both at once.
        let parallel_pool = WorkerPool::new(2);
        for i in 0..2 {
            let task = CpuIntensiveTask {
                id: format!("par-{}", i),
                duration_ms: 10,
            };
            assert!(
                parallel_pool.execute(Box::new(task)).await.is_ok(),
                "Two-slot pool should accept both tasks"
            );
        }
        assert_eq!(parallel_pool.get_active_count(), 2);

        sequential_pool.shutdown().await;
        parallel_pool.shutdown().await;
        assert_eq!(sequential_pool.get_active_count(), 0);
        assert_eq!(parallel_pool.get_active_count(), 0);
    }

    #[tokio::test]
    async fn test_resource_management() {
        let pool = WorkerPool::new(3);

        assert_eq!(
            pool.determine_worker_type("convert"),
            WorkerType::CpuIntensive
        );
        assert_eq!(pool.determine_worker_type("sync"), WorkerType::Mixed);
        assert_eq!(pool.determine_worker_type("match"), WorkerType::IoIntensive);
        assert_eq!(
            pool.determine_worker_type("validate"),
            WorkerType::IoIntensive
        );
        assert_eq!(pool.determine_worker_type("unknown"), WorkerType::Mixed);

        let stats = pool.get_worker_stats();
        assert_eq!(stats.total_active, 0);
        assert_eq!(stats.max_capacity, 3);
        assert_eq!(stats.cpu_intensive_count, 0);
        assert_eq!(stats.io_intensive_count, 0);
        assert_eq!(stats.mixed_count, 0);

        assert_eq!(pool.get_capacity(), 3);
        assert_eq!(pool.get_active_count(), 0);
    }

    #[tokio::test]
    async fn test_worker_pool_shutdown() {
        #[derive(Clone)]
        struct SlowTask {
            id: String,
        }

        #[async_trait]
        impl Task for SlowTask {
            async fn execute(&self) -> TaskResult {
                tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
                TaskResult::Success(format!("slow-{}", self.id))
            }
            fn task_type(&self) -> &'static str {
                "mixed"
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }

        let pool = WorkerPool::new(2);

        for i in 0..2 {
            let task = SlowTask {
                id: format!("slow-{}", i),
            };
            pool.execute(Box::new(task)).await.unwrap();
        }

        assert_eq!(pool.get_active_count(), 2);

        let start = std::time::Instant::now();
        pool.shutdown().await;
        let shutdown_time = start.elapsed();

        // Shutdown must actually wait for the 50 ms tasks.
        assert!(shutdown_time >= std::time::Duration::from_millis(30));
        assert_eq!(pool.get_active_count(), 0);
    }

    #[tokio::test]
    async fn test_active_worker_tracking() {
        #[derive(Clone)]
        struct TrackableTask {
            id: String,
            task_type: &'static str,
        }

        #[async_trait]
        impl Task for TrackableTask {
            async fn execute(&self) -> TaskResult {
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
                TaskResult::Success(format!("tracked-{}", self.id))
            }
            fn task_type(&self) -> &'static str {
                self.task_type
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }

        let pool = WorkerPool::new(3);

        for &(id, task_type) in &[
            ("cpu-task", "convert"),
            ("io-task", "match"),
            ("mixed-task", "sync"),
        ] {
            let task = TrackableTask {
                id: id.to_string(),
                task_type,
            };
            pool.execute(Box::new(task)).await.unwrap();
        }

        let active_workers = pool.list_active_workers();
        assert_eq!(active_workers.len(), 3);

        for worker in &active_workers {
            let expected = match worker.task_id.as_str() {
                "cpu-task" => WorkerType::CpuIntensive,
                "io-task" => WorkerType::IoIntensive,
                "mixed-task" => WorkerType::Mixed,
                other => panic!("unexpected task id {:?}", other),
            };
            assert_eq!(worker.worker_type, expected);
        }

        let mut task_ids: Vec<String> = active_workers
            .into_iter()
            .map(|worker| worker.task_id)
            .collect();
        task_ids.sort();
        assert_eq!(task_ids, ["cpu-task", "io-task", "mixed-task"]);

        let stats = pool.get_worker_stats();
        assert_eq!(stats.total_active, 3);
        assert_eq!(stats.cpu_intensive_count, 1);
        assert_eq!(stats.io_intensive_count, 1);
        assert_eq!(stats.mixed_count, 1);
        assert_eq!(stats.max_capacity, 3);
    }

    #[test]
    fn worker_id_is_uuidv7() {
        let w = Worker::new();
        assert_eq!(w.id().get_version_num(), 7);
    }

    #[test]
    fn consecutive_workers_have_distinct_ids() {
        let a = Worker::new();
        let b = Worker::new();
        assert_ne!(a.id(), b.id());
    }

    #[test]
    fn worker_status_can_be_updated() {
        let mut w = Worker::new();
        assert!(matches!(w.status(), WorkerStatus::Idle));
        w.set_status(WorkerStatus::Busy("job-1".to_string()));
        assert!(matches!(w.status(), WorkerStatus::Busy(task) if task == "job-1"));
        w.set_status(WorkerStatus::Stopped);
        assert!(matches!(w.status(), WorkerStatus::Stopped));
    }

    #[tokio::test]
    async fn worker_pool_execute_dispatches_uuidv7_worker_id() {
        struct DummyTask;

        #[async_trait]
        impl Task for DummyTask {
            async fn execute(&self) -> TaskResult {
                TaskResult::Success("done".into())
            }
            fn task_type(&self) -> &'static str {
                "match"
            }
            fn task_id(&self) -> String {
                "dummy".into()
            }
        }

        let pool = WorkerPool::new(1);
        let res = pool.execute(Box::new(DummyTask)).await;
        assert!(matches!(res, Ok(TaskResult::Success(_))));

        let workers = pool.list_active_workers();
        assert_eq!(workers.len(), 1);
        assert_eq!(workers[0].worker_id.get_version_num(), 7);
    }
}