use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use thiserror::Error;

#[derive(Error, Debug, Clone, PartialEq)]
pub enum ParallelError {
    #[error("Task queue is full")]
    QueueFull,

    #[error("Task dependency cycle detected")]
    DependencyCycle,

    #[error("Task {0} not found")]
    TaskNotFound(String),

    #[error("Invalid worker count: {0}")]
    InvalidWorkerCount(usize),

    #[error("Parallel execution failed: {0}")]
    ExecutionFailed(String),

    #[error("NUMA allocation failed: {0}")]
    NumaAllocationFailed(String),
}

/// Victim-selection strategy used when a worker runs out of local work.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum StealStrategy {
    /// Steal from an arbitrary other worker.
    Random,
    /// Steal from the worker with the longest queue.
    MaxLoad,
    /// Steal from the least recently used victim.
    LRU,
    /// Cycle through victims in index order.
    RoundRobin,
}

/// Identifier of a NUMA node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NumaNode(pub usize);

/// Placement policy for NUMA-aware scheduling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NumaStrategy {
    /// Ignore NUMA topology.
    None,
    /// Prefer the local node, but fall back to remote nodes.
    LocalPreferred,
    /// Allocate strictly on the local node.
    LocalStrict,
    /// Interleave allocations across nodes.
    Interleave,
}

/// Configuration for the work-stealing scheduler.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParallelConfig {
    /// Number of worker queues (must be non-zero).
    pub num_workers: usize,

    /// Strategy for choosing a steal victim.
    pub steal_strategy: StealStrategy,

    /// NUMA placement policy.
    pub numa_strategy: NumaStrategy,

    /// Whether task priorities are honored.
    pub enable_priority: bool,

    /// Whether execution statistics are collected.
    pub enable_stats: bool,

    /// Maximum number of queued tasks per worker.
    pub max_queue_size: usize,

    /// Whether per-worker state is padded to cache-line size.
    pub cache_line_padding: bool,
}

impl Default for ParallelConfig {
    fn default() -> Self {
        Self {
            num_workers: num_cpus::get(),
            steal_strategy: StealStrategy::Random,
            numa_strategy: NumaStrategy::None,
            enable_priority: false,
            enable_stats: true,
            max_queue_size: 10000,
            cache_line_padding: true,
        }
    }
}

impl ParallelConfig {
    /// Creates a config with `num_workers` workers; rejects zero.
    pub fn new(num_workers: usize) -> Result<Self, ParallelError> {
        if num_workers == 0 {
            return Err(ParallelError::InvalidWorkerCount(num_workers));
        }

        Ok(Self {
            num_workers,
            ..Default::default()
        })
    }

    pub fn with_num_workers(mut self, num_workers: usize) -> Self {
        self.num_workers = num_workers;
        self
    }

    pub fn with_steal_strategy(mut self, strategy: StealStrategy) -> Self {
        self.steal_strategy = strategy;
        self
    }

    pub fn with_numa_strategy(mut self, strategy: NumaStrategy) -> Self {
        self.numa_strategy = strategy;
        self
    }

    pub fn with_priority(mut self, enabled: bool) -> Self {
        self.enable_priority = enabled;
        self
    }

    pub fn with_stats(mut self, enabled: bool) -> Self {
        self.enable_stats = enabled;
        self
    }
}

/// Unique identifier for a task.
pub type TaskId = String;

/// Task priority levels, ordered from lowest to highest.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum TaskPriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}

/// A schedulable unit of work with optional dependencies and placement hints.
#[derive(Debug, Clone)]
pub struct Task {
    /// Unique task identifier.
    pub id: TaskId,

    /// Scheduling priority.
    pub priority: TaskPriority,

    /// Tasks that must complete before this one runs.
    pub dependencies: Vec<TaskId>,

    /// Preferred NUMA node, if any.
    pub numa_node: Option<NumaNode>,

    /// Estimated execution time in microseconds, if known.
    pub estimated_time_us: Option<u64>,
}

impl Task {
    pub fn new(id: TaskId) -> Self {
        Self {
            id,
            priority: TaskPriority::Normal,
            dependencies: Vec::new(),
            numa_node: None,
            estimated_time_us: None,
        }
    }

    pub fn with_priority(mut self, priority: TaskPriority) -> Self {
        self.priority = priority;
        self
    }

    pub fn with_dependency(mut self, dep: TaskId) -> Self {
        self.dependencies.push(dep);
        self
    }

    pub fn with_numa_node(mut self, node: NumaNode) -> Self {
        self.numa_node = Some(node);
        self
    }

    pub fn with_estimated_time(mut self, time_us: u64) -> Self {
        self.estimated_time_us = Some(time_us);
        self
    }
}

/// Per-worker task queue. Aligned to a cache line (64 bytes) to reduce
/// false sharing between workers.
#[repr(align(64))]
struct WorkerQueue {
    queue: VecDeque<Task>,
    /// Number of tasks successfully stolen from this queue.
    steal_count: usize,
    tasks_executed: usize,
    total_execution_time_us: u64,
}

impl WorkerQueue {
    fn new() -> Self {
        Self {
            queue: VecDeque::new(),
            steal_count: 0,
            tasks_executed: 0,
            total_execution_time_us: 0,
        }
    }

    fn push(&mut self, task: Task) {
        self.queue.push_back(task);
    }

    fn pop(&mut self) -> Option<Task> {
        self.queue.pop_front()
    }

    /// Steals from the back of the queue, opposite the end `pop` uses,
    /// counting only successful steals.
    fn steal(&mut self) -> Option<Task> {
        let task = self.queue.pop_back();
        if task.is_some() {
            self.steal_count += 1;
        }
        task
    }

    fn len(&self) -> usize {
        self.queue.len()
    }
}

/// A work-stealing task scheduler with one queue per worker.
pub struct WorkStealingScheduler {
    config: ParallelConfig,
    workers: Vec<Arc<Mutex<WorkerQueue>>>,
    completed_tasks: Arc<Mutex<HashMap<TaskId, u64>>>,
    stats: Arc<Mutex<SchedulerStats>>,
}

impl WorkStealingScheduler {
    pub fn new(config: ParallelConfig) -> Self {
        let mut workers = Vec::with_capacity(config.num_workers);
        for _ in 0..config.num_workers {
            workers.push(Arc::new(Mutex::new(WorkerQueue::new())));
        }

        Self {
            config,
            workers,
            completed_tasks: Arc::new(Mutex::new(HashMap::new())),
            stats: Arc::new(Mutex::new(SchedulerStats::default())),
        }
    }

    /// Validates the task and queues it on the selected worker, failing if
    /// that worker's queue is already at `max_queue_size`.
    pub fn submit(&self, task: Task) -> Result<(), ParallelError> {
        self.validate_dependencies(&task)?;

        let worker_idx = self.select_worker(&task);

        let mut worker = self.workers[worker_idx].lock().unwrap();
        if worker.len() >= self.config.max_queue_size {
            return Err(ParallelError::QueueFull);
        }

        worker.push(task);

        Ok(())
    }

    pub fn submit_batch(&self, tasks: Vec<Task>) -> Result<(), ParallelError> {
        for task in tasks {
            self.submit(task)?;
        }
        Ok(())
    }

    /// Drains the worker queues, simulating execution of every task whose
    /// dependencies are already satisfied; returns the completed task IDs.
    pub fn execute_all(&self) -> Result<Vec<TaskId>, ParallelError> {
        let mut completed = Vec::new();

        for worker in &self.workers {
            let mut worker = worker.lock().unwrap();
            // Consecutive deferrals for unmet dependencies; once every
            // queued task has been deferred without progress, stop rather
            // than spin forever on unsatisfiable dependencies.
            let mut deferred = 0;
            while let Some(task) = worker.pop() {
                if self.dependencies_satisfied(&task)? {
                    // Execution is simulated via the task's time estimate.
                    let execution_time = task.estimated_time_us.unwrap_or(1000);
                    worker.tasks_executed += 1;
                    worker.total_execution_time_us += execution_time;
                    deferred = 0;

                    self.completed_tasks
                        .lock()
                        .unwrap()
                        .insert(task.id.clone(), execution_time);

                    completed.push(task.id);

                    if self.config.enable_stats {
                        let mut stats = self.stats.lock().unwrap();
                        stats.tasks_executed += 1;
                        stats.total_execution_time_us += execution_time;
                    }
                } else {
                    // Re-queue at the back so later entries get a chance to
                    // satisfy this task's dependencies within the same pass.
                    worker.push(task);
                    deferred += 1;
                    if deferred >= worker.len() {
                        break;
                    }
                }
            }
        }

        Ok(completed)
    }

    pub fn stats(&self) -> SchedulerStats {
        self.stats.lock().unwrap().clone()
    }

    pub fn reset(&self) {
        for worker in &self.workers {
            let mut worker = worker.lock().unwrap();
            worker.queue.clear();
            worker.steal_count = 0;
            worker.tasks_executed = 0;
            worker.total_execution_time_us = 0;
        }

        self.completed_tasks.lock().unwrap().clear();
        *self.stats.lock().unwrap() = SchedulerStats::default();
    }

    fn validate_dependencies(&self, task: &Task) -> Result<(), ParallelError> {
        let mut visited = std::collections::HashSet::new();
        self.check_cycle(&task.id, &task.dependencies, &mut visited)
    }

    fn check_cycle(
        &self,
        current: &TaskId,
        dependencies: &[TaskId],
        visited: &mut std::collections::HashSet<TaskId>,
    ) -> Result<(), ParallelError> {
        if visited.contains(current) {
            return Err(ParallelError::DependencyCycle);
        }

        visited.insert(current.clone());

        // Only direct self-cycles are detectable at submit time: the
        // scheduler does not retain a full task graph, so transitive cycles
        // across tasks cannot be traversed here.
        for dep in dependencies {
            if visited.contains(dep) {
                return Err(ParallelError::DependencyCycle);
            }
        }

        Ok(())
    }

    fn dependencies_satisfied(&self, task: &Task) -> Result<bool, ParallelError> {
        let completed = self.completed_tasks.lock().unwrap();
        Ok(task
            .dependencies
            .iter()
            .all(|dep| completed.contains_key(dep)))
    }

    /// Picks the target worker: a NUMA hint wins, otherwise the
    /// least-loaded queue.
    fn select_worker(&self, task: &Task) -> usize {
        if let Some(numa_node) = task.numa_node {
            return self.numa_node_to_worker(numa_node);
        }

        let mut min_load = usize::MAX;
        let mut selected = 0;

        for (idx, worker) in self.workers.iter().enumerate() {
            let worker = worker.lock().unwrap();
            let load = worker.len();
            if load < min_load {
                min_load = load;
                selected = idx;
            }
        }

        selected
    }

    fn numa_node_to_worker(&self, node: NumaNode) -> usize {
        // Simple modulo mapping from NUMA node to a worker index.
        node.0 % self.config.num_workers
    }

    /// Attempts to steal one task for worker `thief_idx` from a victim
    /// chosen by the configured strategy.
    pub fn try_steal(&self, thief_idx: usize) -> Option<Task> {
        let victim_idx = self.select_victim(thief_idx);
        if victim_idx == thief_idx {
            return None;
        }

        let mut victim = self.workers[victim_idx].lock().unwrap();
        let stolen = victim.steal();

        if self.config.enable_stats {
            let mut stats = self.stats.lock().unwrap();
            if stolen.is_some() {
                stats.steal_count += 1;
            } else {
                stats.failed_steals += 1;
            }
        }

        stolen
    }

    fn select_victim(&self, thief_idx: usize) -> usize {
        match self.config.steal_strategy {
            StealStrategy::Random => {
                // Deterministic stand-in for random victim selection.
                (thief_idx + 1) % self.config.num_workers
            }
            StealStrategy::MaxLoad => {
                let mut max_load = 0;
                let mut victim = thief_idx;

                for (idx, worker) in self.workers.iter().enumerate() {
                    if idx == thief_idx {
                        continue;
                    }
                    let worker = worker.lock().unwrap();
                    let load = worker.len();
                    if load > max_load {
                        max_load = load;
                        victim = idx;
                    }
                }

                victim
            }
            StealStrategy::LRU | StealStrategy::RoundRobin => {
                // LRU falls back to round-robin victim selection.
                (thief_idx + 1) % self.config.num_workers
            }
        }
    }

    pub fn load_balance_stats(&self) -> LoadBalanceStats {
        let mut worker_loads = Vec::new();
        let mut total_tasks = 0;

        for worker in &self.workers {
            let worker = worker.lock().unwrap();
            let load = worker.tasks_executed;
            worker_loads.push(load);
            total_tasks += load;
        }

        let avg_load = total_tasks as f64 / self.config.num_workers as f64;
        let variance = worker_loads
            .iter()
            .map(|&load| (load as f64 - avg_load).powi(2))
            .sum::<f64>()
            / self.config.num_workers as f64;

        let std_dev = variance.sqrt();
        let cv = if avg_load > 0.0 {
            std_dev / avg_load
        } else {
            0.0
        };
        let max_load = *worker_loads.iter().max().unwrap_or(&0);

        LoadBalanceStats {
            worker_loads,
            avg_load,
            std_dev,
            coefficient_of_variation: cv,
            imbalance_ratio: if avg_load > 0.0 {
                max_load as f64 / avg_load
            } else {
                1.0
            },
        }
    }
}

/// Aggregate execution statistics for the scheduler.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SchedulerStats {
    /// Total number of tasks executed.
    pub tasks_executed: usize,

    /// Total (simulated) execution time across tasks, in microseconds.
    pub total_execution_time_us: u64,

    /// Number of successful steals.
    pub steal_count: usize,

    /// Number of steal attempts that found no work.
    pub failed_steals: usize,
}

impl SchedulerStats {
    /// Mean execution time per task in microseconds (0.0 if none ran).
    pub fn avg_execution_time_us(&self) -> f64 {
        if self.tasks_executed > 0 {
            self.total_execution_time_us as f64 / self.tasks_executed as f64
        } else {
            0.0
        }
    }

    /// Fraction of steal attempts that succeeded (0.0 if none).
    pub fn steal_success_rate(&self) -> f64 {
        let total_attempts = self.steal_count + self.failed_steals;
        if total_attempts > 0 {
            self.steal_count as f64 / total_attempts as f64
        } else {
            0.0
        }
    }
}

impl std::fmt::Display for SchedulerStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Scheduler Statistics")?;
        writeln!(f, "====================")?;
        writeln!(f, "Tasks executed: {}", self.tasks_executed)?;
        writeln!(
            f,
            "Total time: {:.2} ms",
            self.total_execution_time_us as f64 / 1000.0
        )?;
        writeln!(
            f,
            "Avg time/task: {:.2} µs",
            self.avg_execution_time_us()
        )?;
        writeln!(f, "Steal count: {}", self.steal_count)?;
        writeln!(f, "Failed steals: {}", self.failed_steals)?;
        writeln!(
            f,
            "Steal success rate: {:.2}%",
            self.steal_success_rate() * 100.0
        )?;
        Ok(())
    }
}

/// Distribution of executed tasks across workers.
#[derive(Debug, Clone, PartialEq)]
pub struct LoadBalanceStats {
    /// Tasks executed per worker.
    pub worker_loads: Vec<usize>,

    /// Mean tasks per worker.
    pub avg_load: f64,

    /// Standard deviation of per-worker load.
    pub std_dev: f64,

    /// Standard deviation divided by the mean load.
    pub coefficient_of_variation: f64,

    /// Maximum load divided by the mean load.
    pub imbalance_ratio: f64,
}

impl LoadBalanceStats {
    /// Loads are considered balanced when the coefficient of variation
    /// stays below 0.2.
    pub fn is_well_balanced(&self) -> bool {
        self.coefficient_of_variation < 0.2
    }
}

impl std::fmt::Display for LoadBalanceStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Load Balance Statistics")?;
        writeln!(f, "=======================")?;
        writeln!(f, "Worker loads: {:?}", self.worker_loads)?;
        writeln!(f, "Average load: {:.2}", self.avg_load)?;
        writeln!(f, "Std deviation: {:.2}", self.std_dev)?;
        writeln!(f, "CV: {:.4}", self.coefficient_of_variation)?;
        writeln!(f, "Imbalance: {:.2}x", self.imbalance_ratio)?;
        writeln!(
            f,
            "Well balanced: {}",
            if self.is_well_balanced() { "Yes" } else { "No" }
        )?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_config_default() {
        let config = ParallelConfig::default();
        assert!(config.num_workers > 0);
        assert_eq!(config.steal_strategy, StealStrategy::Random);
        assert!(config.enable_stats);
    }

    #[test]
    fn test_parallel_config_builder() {
        let config = ParallelConfig::new(4)
            .unwrap()
            .with_steal_strategy(StealStrategy::MaxLoad)
            .with_numa_strategy(NumaStrategy::LocalPreferred)
            .with_priority(true);

        assert_eq!(config.num_workers, 4);
        assert_eq!(config.steal_strategy, StealStrategy::MaxLoad);
        assert_eq!(config.numa_strategy, NumaStrategy::LocalPreferred);
        assert!(config.enable_priority);
    }

    #[test]
    fn test_task_creation() {
        let task = Task::new("task1".to_string())
            .with_priority(TaskPriority::High)
            .with_dependency("task0".to_string())
            .with_estimated_time(1000);

        assert_eq!(task.id, "task1");
        assert_eq!(task.priority, TaskPriority::High);
        assert_eq!(task.dependencies.len(), 1);
        assert_eq!(task.estimated_time_us, Some(1000));
    }

    #[test]
    fn test_scheduler_creation() {
        let config = ParallelConfig::new(4).unwrap();
        let scheduler = WorkStealingScheduler::new(config);

        assert_eq!(scheduler.workers.len(), 4);
    }

    #[test]
    fn test_scheduler_submit() {
        let config = ParallelConfig::new(2).unwrap();
        let scheduler = WorkStealingScheduler::new(config);

        let task = Task::new("task1".to_string());
        assert!(scheduler.submit(task).is_ok());
    }
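
    // Hedged sketch of batch submission: submit_batch() forwards each task
    // through submit(), so every task in the batch should be queued and
    // then executed.
    #[test]
    fn test_submit_batch() {
        let config = ParallelConfig::new(2).unwrap();
        let scheduler = WorkStealingScheduler::new(config);

        let tasks = vec![
            Task::new("a".to_string()),
            Task::new("b".to_string()),
            Task::new("c".to_string()),
        ];
        scheduler.submit_batch(tasks).unwrap();

        let completed = scheduler.execute_all().unwrap();
        assert_eq!(completed.len(), 3);
    }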

    #[test]
    fn test_scheduler_execute_simple() {
        let config = ParallelConfig::new(2).unwrap();
        let scheduler = WorkStealingScheduler::new(config);

        let task1 = Task::new("task1".to_string()).with_estimated_time(100);
        let task2 = Task::new("task2".to_string()).with_estimated_time(200);

        scheduler.submit(task1).unwrap();
        scheduler.submit(task2).unwrap();

        let completed = scheduler.execute_all().unwrap();
        assert_eq!(completed.len(), 2);
    }

    #[test]
    fn test_scheduler_dependencies() {
        let config = ParallelConfig::new(2).unwrap();
        let scheduler = WorkStealingScheduler::new(config);

        let task1 = Task::new("task1".to_string());
        let task2 = Task::new("task2".to_string()).with_dependency("task1".to_string());

        scheduler.submit(task1).unwrap();
        scheduler.submit(task2).unwrap();

        let completed = scheduler.execute_all().unwrap();
        assert!(completed.contains(&"task1".to_string()));
    }
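
    // Hedged sketch of the submit-time cycle check: validate_dependencies()
    // only sees a single task's dependency list, so the detectable case is a
    // task that names itself as a dependency.
    #[test]
    fn test_self_dependency_rejected() {
        let config = ParallelConfig::new(2).unwrap();
        let scheduler = WorkStealingScheduler::new(config);

        let task = Task::new("task1".to_string()).with_dependency("task1".to_string());
        assert_eq!(scheduler.submit(task), Err(ParallelError::DependencyCycle));
    }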

    #[test]
    fn test_scheduler_stats() {
        let config = ParallelConfig::new(2).unwrap();
        let scheduler = WorkStealingScheduler::new(config);

        let task1 = Task::new("task1".to_string()).with_estimated_time(1000);
        let task2 = Task::new("task2".to_string()).with_estimated_time(2000);

        scheduler.submit(task1).unwrap();
        scheduler.submit(task2).unwrap();
        scheduler.execute_all().unwrap();

        let stats = scheduler.stats();
        assert_eq!(stats.tasks_executed, 2);
        assert_eq!(stats.total_execution_time_us, 3000);
    }

    #[test]
    fn test_load_balance_stats() {
        let config = ParallelConfig::new(4).unwrap();
        let scheduler = WorkStealingScheduler::new(config);

        for i in 0..8 {
            let task = Task::new(format!("task{}", i)).with_estimated_time(100);
            scheduler.submit(task).unwrap();
        }

        scheduler.execute_all().unwrap();

        let stats = scheduler.load_balance_stats();
        assert!((stats.avg_load - 2.0).abs() < 0.1);
    }

    #[test]
    fn test_scheduler_reset() {
        let config = ParallelConfig::new(2).unwrap();
        let scheduler = WorkStealingScheduler::new(config);

        let task = Task::new("task1".to_string());
        scheduler.submit(task).unwrap();
        scheduler.execute_all().unwrap();

        let stats_before = scheduler.stats();
        assert_eq!(stats_before.tasks_executed, 1);

        scheduler.reset();

        let stats_after = scheduler.stats();
        assert_eq!(stats_after.tasks_executed, 0);
    }
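
    // Hedged sketch of the stealing path: with two workers, submit()
    // load-balances one task onto each queue, so worker 0's steal attempt
    // finds work on worker 1 and one successful steal is recorded.
    #[test]
    fn test_try_steal() {
        let config = ParallelConfig::new(2).unwrap();
        let scheduler = WorkStealingScheduler::new(config);

        scheduler.submit(Task::new("task1".to_string())).unwrap();
        scheduler.submit(Task::new("task2".to_string())).unwrap();

        let stolen = scheduler.try_steal(0);
        assert!(stolen.is_some());
        assert_eq!(scheduler.stats().steal_count, 1);
    }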

    #[test]
    fn test_task_priority() {
        assert!(TaskPriority::Critical > TaskPriority::High);
        assert!(TaskPriority::High > TaskPriority::Normal);
        assert!(TaskPriority::Normal > TaskPriority::Low);
    }

    #[test]
    fn test_numa_node() {
        let node = NumaNode(0);
        assert_eq!(node.0, 0);
    }
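
    // Hedged sketch of NUMA-hinted placement: a task pinned to node 3 on a
    // two-worker pool maps to worker 3 % 2 == 1, bypassing load balancing.
    #[test]
    fn test_numa_placement() {
        let config = ParallelConfig::new(2).unwrap();
        let scheduler = WorkStealingScheduler::new(config);

        let task = Task::new("task1".to_string()).with_numa_node(NumaNode(3));
        scheduler.submit(task).unwrap();

        assert_eq!(scheduler.workers[0].lock().unwrap().len(), 0);
        assert_eq!(scheduler.workers[1].lock().unwrap().len(), 1);
    }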

    #[test]
    fn test_steal_strategy() {
        let s1 = StealStrategy::Random;
        let s2 = s1;
        let s3 = s1;
        assert_eq!(s2, s3);
    }
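
    // Hedged sketch of the failure path: with every queue empty the steal
    // attempt finds nothing, so only failed_steals advances and the success
    // rate stays at zero.
    #[test]
    fn test_failed_steal_recorded() {
        let config = ParallelConfig::new(2).unwrap();
        let scheduler = WorkStealingScheduler::new(config);

        assert!(scheduler.try_steal(0).is_none());

        let stats = scheduler.stats();
        assert_eq!(stats.failed_steals, 1);
        assert_eq!(stats.steal_success_rate(), 0.0);
    }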
}