1use super::task::{Task, TaskResult};
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use tokio::task::JoinHandle;
6use uuid::Uuid;
7
8pub struct WorkerPool {
10 workers: Arc<Mutex<HashMap<Uuid, WorkerInfo>>>,
11 max_workers: usize,
12}
13
14#[derive(Debug)]
15struct WorkerInfo {
16 handle: JoinHandle<TaskResult>,
17 task_id: String,
18 start_time: std::time::Instant,
19 worker_type: WorkerType,
20}
21
/// Coarse resource classification of a task, derived from its task-type string.
///
/// The enum is fieldless, so it derives `Copy` (cheap to pass by value) and `Hash`
/// (usable as a map/set key, e.g. for per-type counters) in addition to equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WorkerType {
    /// Tasks classified as CPU-heavy.
    CpuIntensive,
    /// Tasks classified as I/O-heavy.
    IoIntensive,
    /// Tasks with a mixed profile, including any unrecognised task type.
    Mixed,
}
35
36impl WorkerPool {
37 pub fn new(max_workers: usize) -> Self {
43 Self {
44 workers: Arc::new(Mutex::new(HashMap::new())),
45 max_workers,
46 }
47 }
48
49 pub async fn execute(&self, task: Box<dyn Task + Send + Sync>) -> Result<TaskResult, String> {
51 let worker_id = Uuid::new_v4();
52 let task_id = task.task_id();
53 let worker_type = self.determine_worker_type(task.task_type());
54
55 {
56 let workers = self.workers.lock().unwrap();
57 if workers.len() >= self.max_workers {
58 return Err("Worker pool is full".to_string());
59 }
60 }
61
62 let handle = tokio::spawn(async move { task.execute().await });
63
64 {
65 let mut workers = self.workers.lock().unwrap();
66 workers.insert(
67 worker_id,
68 WorkerInfo {
69 handle,
70 task_id: task_id.clone(),
71 start_time: std::time::Instant::now(),
72 worker_type,
73 },
74 );
75 }
76
77 Ok(TaskResult::Success("Task submitted".to_string()))
79 }
80
81 fn determine_worker_type(&self, task_type: &str) -> WorkerType {
82 match task_type {
83 "convert" => WorkerType::CpuIntensive,
84 "sync" => WorkerType::Mixed,
85 "match" => WorkerType::IoIntensive,
86 "validate" => WorkerType::IoIntensive,
87 _ => WorkerType::Mixed,
88 }
89 }
90
91 pub fn get_active_count(&self) -> usize {
93 self.workers.lock().unwrap().len()
94 }
95
96 pub fn get_capacity(&self) -> usize {
98 self.max_workers
99 }
100
101 pub fn get_worker_stats(&self) -> WorkerStats {
103 let workers = self.workers.lock().unwrap();
104 let mut cpu = 0;
105 let mut io = 0;
106 let mut mixed = 0;
107 for w in workers.values() {
108 match w.worker_type {
109 WorkerType::CpuIntensive => cpu += 1,
110 WorkerType::IoIntensive => io += 1,
111 WorkerType::Mixed => mixed += 1,
112 }
113 }
114 WorkerStats {
115 total_active: workers.len(),
116 cpu_intensive_count: cpu,
117 io_intensive_count: io,
118 mixed_count: mixed,
119 max_capacity: self.max_workers,
120 }
121 }
122
123 pub async fn shutdown(&self) {
125 let workers = { std::mem::take(&mut *self.workers.lock().unwrap()) };
126 for (id, info) in workers {
127 println!(
128 "Waiting for worker {} to complete task {}",
129 id, info.task_id
130 );
131 let _ = info.handle.await;
132 }
133 }
134
135 pub fn list_active_workers(&self) -> Vec<ActiveWorkerInfo> {
137 let workers = self.workers.lock().unwrap();
138 workers
139 .iter()
140 .map(|(id, info)| ActiveWorkerInfo {
141 worker_id: *id,
142 task_id: info.task_id.clone(),
143 worker_type: info.worker_type.clone(),
144 runtime: info.start_time.elapsed(),
145 })
146 .collect()
147 }
148}
149
150impl Clone for WorkerPool {
151 fn clone(&self) -> Self {
152 Self {
153 workers: Arc::clone(&self.workers),
154 max_workers: self.max_workers,
155 }
156 }
157}
158
/// Point-in-time counts of registered workers, as produced by `WorkerPool::get_worker_stats`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkerStats {
    /// Total number of registered workers.
    pub total_active: usize,
    /// Workers whose task was classified as `WorkerType::CpuIntensive`.
    pub cpu_intensive_count: usize,
    /// Workers whose task was classified as `WorkerType::IoIntensive`.
    pub io_intensive_count: usize,
    /// Workers whose task was classified as `WorkerType::Mixed`.
    pub mixed_count: usize,
    /// The pool's configured capacity.
    pub max_capacity: usize,
}
176
177#[derive(Debug, Clone)]
181pub struct ActiveWorkerInfo {
182 pub worker_id: Uuid,
184 pub task_id: String,
186 pub worker_type: WorkerType,
188 pub runtime: std::time::Duration,
190}
191
192pub struct Worker {
194 id: Uuid,
195 status: WorkerStatus,
196}
197
/// Lifecycle state of a [`Worker`].
///
/// Derives `PartialEq`/`Eq` so callers can compare statuses directly
/// (e.g. `*worker.status() == WorkerStatus::Idle`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WorkerStatus {
    /// Not processing anything.
    Idle,
    /// Processing; the string presumably identifies the work item (confirm with callers).
    Busy(String),
    /// Stopped and no longer accepting work.
    Stopped,
    /// Failed; the string presumably carries the error message (confirm with callers).
    Error(String),
}
213
214impl Worker {
215 pub fn new() -> Self {
217 Self {
218 id: Uuid::new_v4(),
219 status: WorkerStatus::Idle,
220 }
221 }
222
223 pub fn id(&self) -> Uuid {
225 self.id
226 }
227
228 pub fn status(&self) -> &WorkerStatus {
230 &self.status
231 }
232
233 pub fn set_status(&mut self, status: WorkerStatus) {
239 self.status = status;
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::parallel::task::{Task, TaskResult};
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    /// Test task that sleeps for `delay_ms` and then succeeds with its own id.
    ///
    /// The delay keeps the task registered while a test inspects the pool, so assertions
    /// about active workers hold no matter how soon the runtime polls the spawned task.
    struct SleepyTask {
        id: String,
        task_type: &'static str,
        delay_ms: u64,
    }

    #[async_trait]
    impl Task for SleepyTask {
        async fn execute(&self) -> TaskResult {
            tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
            TaskResult::Success(self.id.clone())
        }

        fn task_type(&self) -> &'static str {
            self.task_type
        }

        fn task_id(&self) -> String {
            self.id.clone()
        }
    }

    /// Boxes a [`SleepyTask`] in the shape `WorkerPool::execute` accepts.
    fn sleepy(id: &str, task_type: &'static str, delay_ms: u64) -> Box<dyn Task + Send + Sync> {
        Box::new(SleepyTask {
            id: id.to_string(),
            task_type,
            delay_ms,
        })
    }

    #[tokio::test]
    async fn test_worker_pool_capacity() {
        let pool = WorkerPool::new(2);
        assert_eq!(pool.get_capacity(), 2);
        assert_eq!(pool.get_active_count(), 0);
        let stats = pool.get_worker_stats();
        assert_eq!(stats.max_capacity, 2);
        assert_eq!(stats.total_active, 0);
    }

    #[tokio::test]
    async fn test_execute_and_active_count() {
        let pool = WorkerPool::new(1);
        let res = pool.execute(sleepy("t1", "convert", 20)).await;
        assert!(matches!(res, Ok(TaskResult::Success(_))));
        assert_eq!(pool.get_active_count(), 1);

        pool.shutdown().await;
        assert_eq!(pool.get_active_count(), 0);
    }

    #[tokio::test]
    async fn test_reject_when_full() {
        let pool = WorkerPool::new(1);
        // The first task is still sleeping when the second arrives, so the only slot is taken.
        let first = pool.execute(sleepy("first", "match", 50)).await;
        assert!(first.is_ok());
        let second = pool.execute(sleepy("second", "match", 50)).await;
        assert!(matches!(second, Err(ref msg) if msg == "Worker pool is full"));

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_list_active_workers_and_stats() {
        let pool = WorkerPool::new(2);
        pool.execute(sleepy("tok2", "sync", 20)).await.unwrap();

        let workers = pool.list_active_workers();
        assert_eq!(workers.len(), 1);
        let info = &workers[0];
        assert_eq!(info.task_id, "tok2");
        assert_eq!(info.worker_type, WorkerType::Mixed);

        let stats = pool.get_worker_stats();
        assert_eq!(stats.total_active, 1);
        assert_eq!(stats.mixed_count, 1);

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_worker_job_distribution() {
        struct CountingTask {
            id: String,
            counter: Arc<AtomicUsize>,
        }

        #[async_trait]
        impl Task for CountingTask {
            async fn execute(&self) -> TaskResult {
                self.counter.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(10)).await;
                TaskResult::Success(format!("task-{}", self.id))
            }
            fn task_type(&self) -> &'static str {
                "convert"
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }

        let pool = WorkerPool::new(4);
        let counter = Arc::new(AtomicUsize::new(0));

        // Submit concurrently from separate tasks, each holding its own pool clone.
        let handles: Vec<_> = (0..4)
            .map(|i| {
                let task = CountingTask {
                    id: format!("task-{}", i),
                    counter: Arc::clone(&counter),
                };
                let pool = pool.clone();
                tokio::spawn(async move { pool.execute(Box::new(task)).await })
            })
            .collect();

        for handle in handles {
            let result = handle.await.unwrap();
            assert!(result.is_ok(), "Task submission should succeed");
        }

        // Shutdown awaits every registered task, so no timing-based sleep is needed.
        pool.shutdown().await;
        assert_eq!(
            counter.load(Ordering::SeqCst),
            4,
            "All 4 tasks should have been executed"
        );
    }

    #[tokio::test]
    async fn test_worker_error_recovery() {
        struct FailingTask {
            id: String,
            should_fail: bool,
        }

        #[async_trait]
        impl Task for FailingTask {
            async fn execute(&self) -> TaskResult {
                if self.should_fail {
                    TaskResult::Failed("Intentional failure".to_string())
                } else {
                    TaskResult::Success(format!("success-{}", self.id))
                }
            }
            fn task_type(&self) -> &'static str {
                "sync"
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }

        let pool = WorkerPool::new(2);

        let success_task = FailingTask {
            id: "success".to_string(),
            should_fail: false,
        };
        let result = pool.execute(Box::new(success_task)).await;
        assert!(result.is_ok(), "Successful task should be submitted");

        let fail_task = FailingTask {
            id: "fail".to_string(),
            should_fail: true,
        };
        let result = pool.execute(Box::new(fail_task)).await;
        assert!(
            result.is_ok(),
            "Failing task should still be submitted successfully"
        );

        assert!(
            pool.get_active_count() <= 2,
            "Active count should be within limits"
        );

        pool.shutdown().await;
        assert_eq!(pool.get_active_count(), 0);
    }

    #[tokio::test]
    async fn test_submission_does_not_wait_for_completion() {
        struct FlagTask {
            done: Arc<AtomicBool>,
        }

        #[async_trait]
        impl Task for FlagTask {
            async fn execute(&self) -> TaskResult {
                tokio::time::sleep(Duration::from_millis(20)).await;
                self.done.store(true, Ordering::SeqCst);
                TaskResult::Success("flag".into())
            }
            fn task_type(&self) -> &'static str {
                "convert"
            }
            fn task_id(&self) -> String {
                "flag".into()
            }
        }

        let pool = WorkerPool::new(1);
        let done = Arc::new(AtomicBool::new(false));
        pool.execute(Box::new(FlagTask {
            done: Arc::clone(&done),
        }))
        .await
        .unwrap();

        // `execute` returns once the task is spawned, before the task itself completes.
        assert!(!done.load(Ordering::SeqCst));

        pool.shutdown().await;
        assert!(done.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_resource_management() {
        let pool = WorkerPool::new(3);

        assert_eq!(
            pool.determine_worker_type("convert"),
            WorkerType::CpuIntensive
        );
        assert_eq!(pool.determine_worker_type("sync"), WorkerType::Mixed);
        assert_eq!(pool.determine_worker_type("match"), WorkerType::IoIntensive);
        assert_eq!(
            pool.determine_worker_type("validate"),
            WorkerType::IoIntensive
        );
        assert_eq!(pool.determine_worker_type("unknown"), WorkerType::Mixed);

        let stats = pool.get_worker_stats();
        assert_eq!(stats.total_active, 0);
        assert_eq!(stats.max_capacity, 3);
        assert_eq!(stats.cpu_intensive_count, 0);
        assert_eq!(stats.io_intensive_count, 0);
        assert_eq!(stats.mixed_count, 0);

        assert_eq!(pool.get_capacity(), 3);
        assert_eq!(pool.get_active_count(), 0);
    }

    #[tokio::test]
    async fn test_worker_pool_shutdown() {
        let pool = WorkerPool::new(2);

        for i in 0..2 {
            pool.execute(sleepy(&format!("slow-{}", i), "mixed", 50))
                .await
                .unwrap();
        }
        assert_eq!(pool.get_active_count(), 2);

        let start = Instant::now();
        pool.shutdown().await;
        assert!(start.elapsed() >= Duration::from_millis(30));

        assert_eq!(pool.get_active_count(), 0);
    }

    #[tokio::test]
    async fn test_active_worker_tracking() {
        let pool = WorkerPool::new(3);
        let submitted_at = Instant::now();

        for (id, task_type) in [
            ("cpu-task", "convert"),
            ("io-task", "match"),
            ("mixed-task", "sync"),
        ] {
            pool.execute(sleepy(id, task_type, 100)).await.unwrap();
        }

        let active_workers = pool.list_active_workers();
        assert_eq!(active_workers.len(), 3, "All three tasks should still be running");

        let mut task_ids: Vec<&str> = active_workers
            .iter()
            .map(|w| w.task_id.as_str())
            .collect();
        task_ids.sort_unstable();
        assert_eq!(task_ids, ["cpu-task", "io-task", "mixed-task"]);

        // Each worker was registered after `submitted_at`, so its runtime cannot exceed this.
        let since_submission = submitted_at.elapsed();
        for worker in &active_workers {
            assert!(
                worker.runtime <= since_submission,
                "Runtime should not predate submission"
            );
        }

        let stats = pool.get_worker_stats();
        assert_eq!(stats.total_active, 3);
        assert_eq!(stats.cpu_intensive_count, 1);
        assert_eq!(stats.io_intensive_count, 1);
        assert_eq!(stats.mixed_count, 1);
        assert_eq!(stats.max_capacity, 3);

        pool.shutdown().await;
    }
}