1use serde::{Deserialize, Serialize};
37use std::collections::{HashMap, VecDeque};
38use std::sync::{Arc, Mutex};
39use thiserror::Error;
40
/// Errors produced by the parallel scheduler.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum ParallelError {
    /// A worker queue has reached `ParallelConfig::max_queue_size`.
    #[error("Task queue is full")]
    QueueFull,

    /// A submitted task's dependency list forms a cycle.
    #[error("Task dependency cycle detected")]
    DependencyCycle,

    /// The referenced task id is unknown to the scheduler.
    #[error("Task {0} not found")]
    TaskNotFound(String),

    /// A worker count of zero was requested.
    #[error("Invalid worker count: {0}")]
    InvalidWorkerCount(usize),

    /// Catch-all for failures during parallel execution.
    #[error("Parallel execution failed: {0}")]
    ExecutionFailed(String),

    /// A NUMA-aware allocation could not be satisfied.
    #[error("NUMA allocation failed: {0}")]
    NumaAllocationFailed(String),
}
62
/// Policy a worker uses to pick a victim queue when stealing work.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum StealStrategy {
    /// Arbitrary victim.
    /// NOTE(review): `select_victim` currently implements this
    /// deterministically as "next worker in ring order" — confirm
    /// whether true randomness is intended.
    Random,
    /// Steal from the most-loaded other worker.
    MaxLoad,
    /// Least-recently-used victim (currently same as ring order in
    /// `select_victim`).
    LRU,
    /// Cycle through victims in ring order.
    RoundRobin,
}
75
/// Identifier of a NUMA node (newtype over the node index).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NumaNode(pub usize);
79
/// NUMA placement policy for task scheduling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NumaStrategy {
    /// NUMA-agnostic placement.
    None,
    /// Prefer the task's local node, fall back elsewhere.
    LocalPreferred,
    /// Only place on the task's local node.
    LocalStrict,
    /// Interleave placement across nodes.
    Interleave,
}
92
/// Tunable parameters of the work-stealing scheduler.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParallelConfig {
    /// Number of worker queues (must be non-zero; see `new`).
    pub num_workers: usize,

    /// Victim-selection policy for work stealing.
    pub steal_strategy: StealStrategy,

    /// NUMA placement policy.
    /// NOTE(review): `select_worker` honors a task's `numa_node` pin
    /// regardless of this setting — confirm intent.
    pub numa_strategy: NumaStrategy,

    /// Whether priority scheduling is requested.
    /// NOTE(review): `submit` currently enqueues FIFO and never reads
    /// this flag.
    pub enable_priority: bool,

    /// Whether global `SchedulerStats` are recorded.
    pub enable_stats: bool,

    /// Per-worker queue capacity; `submit` fails with `QueueFull`
    /// beyond this.
    pub max_queue_size: usize,

    /// Whether worker queues should be cache-line padded.
    /// NOTE(review): `WorkerQueue` is unconditionally
    /// `#[repr(align(64))]`; this flag is never consulted.
    pub cache_line_padding: bool,
}
117
118impl Default for ParallelConfig {
119 fn default() -> Self {
120 Self {
121 num_workers: num_cpus::get(),
122 steal_strategy: StealStrategy::Random,
123 numa_strategy: NumaStrategy::None,
124 enable_priority: false,
125 enable_stats: true,
126 max_queue_size: 10000,
127 cache_line_padding: true,
128 }
129 }
130}
131
132impl ParallelConfig {
133 pub fn new(num_workers: usize) -> Result<Self, ParallelError> {
135 if num_workers == 0 {
136 return Err(ParallelError::InvalidWorkerCount(num_workers));
137 }
138
139 Ok(Self {
140 num_workers,
141 ..Default::default()
142 })
143 }
144
145 pub fn with_num_workers(mut self, num_workers: usize) -> Self {
147 self.num_workers = num_workers;
148 self
149 }
150
151 pub fn with_steal_strategy(mut self, strategy: StealStrategy) -> Self {
153 self.steal_strategy = strategy;
154 self
155 }
156
157 pub fn with_numa_strategy(mut self, strategy: NumaStrategy) -> Self {
159 self.numa_strategy = strategy;
160 self
161 }
162
163 pub fn with_priority(mut self, enabled: bool) -> Self {
165 self.enable_priority = enabled;
166 self
167 }
168
169 pub fn with_stats(mut self, enabled: bool) -> Self {
171 self.enable_stats = enabled;
172 self
173 }
174}
175
/// Unique identifier of a task.
pub type TaskId = String;

/// Scheduling priority; `Ord` follows the discriminants, so
/// `Low < Normal < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum TaskPriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
191
/// A schedulable unit of work (metadata only; no payload closure).
#[derive(Debug, Clone)]
pub struct Task {
    /// Unique id used for dependency tracking.
    pub id: TaskId,

    /// Scheduling priority (defaults to `Normal`).
    pub priority: TaskPriority,

    /// Ids of tasks that must complete before this one runs.
    pub dependencies: Vec<TaskId>,

    /// Optional NUMA node pin; routes the task to a specific worker.
    pub numa_node: Option<NumaNode>,

    /// Estimated execution time in microseconds; `execute_all` uses
    /// 1000 µs when absent.
    pub estimated_time_us: Option<u64>,
}
210
211impl Task {
212 pub fn new(id: TaskId) -> Self {
214 Self {
215 id,
216 priority: TaskPriority::Normal,
217 dependencies: Vec::new(),
218 numa_node: None,
219 estimated_time_us: None,
220 }
221 }
222
223 pub fn with_priority(mut self, priority: TaskPriority) -> Self {
225 self.priority = priority;
226 self
227 }
228
229 pub fn with_dependency(mut self, dep: TaskId) -> Self {
231 self.dependencies.push(dep);
232 self
233 }
234
235 pub fn with_numa_node(mut self, node: NumaNode) -> Self {
237 self.numa_node = Some(node);
238 self
239 }
240
241 pub fn with_estimated_time(mut self, time_us: u64) -> Self {
243 self.estimated_time_us = Some(time_us);
244 self
245 }
246}
247
// Per-worker task queue plus per-worker counters.
//
// `#[repr(align(64))]` aligns each queue to a cache line to reduce
// false sharing between workers.
// NOTE(review): the alignment is unconditional — the
// `ParallelConfig::cache_line_padding` flag is never consulted.
#[repr(align(64))]
struct WorkerQueue {
    queue: VecDeque<Task>,
    // Number of times a thief called `steal` on this queue.
    steal_count: usize,
    // Tasks this worker has (simulated-)executed.
    tasks_executed: usize,
    // Accumulated simulated execution time, in microseconds.
    total_execution_time_us: u64,
}
256
257impl WorkerQueue {
258 fn new() -> Self {
259 Self {
260 queue: VecDeque::new(),
261 steal_count: 0,
262 tasks_executed: 0,
263 total_execution_time_us: 0,
264 }
265 }
266
267 fn push(&mut self, task: Task) {
268 self.queue.push_back(task);
269 }
270
271 fn pop(&mut self) -> Option<Task> {
272 self.queue.pop_front()
273 }
274
275 fn steal(&mut self) -> Option<Task> {
276 self.steal_count += 1;
277 self.queue.pop_back()
278 }
279
280 fn len(&self) -> usize {
281 self.queue.len()
282 }
283}
284
/// Work-stealing task scheduler: one mutex-protected queue per worker
/// plus shared bookkeeping for completed tasks and statistics.
pub struct WorkStealingScheduler {
    config: ParallelConfig,
    // One cache-line-aligned queue per worker.
    workers: Vec<Arc<Mutex<WorkerQueue>>>,
    // Completed task id -> recorded (estimated) execution time in µs.
    completed_tasks: Arc<Mutex<HashMap<TaskId, u64>>>,
    // Global counters, updated when `config.enable_stats` is set.
    stats: Arc<Mutex<SchedulerStats>>,
}
292
impl WorkStealingScheduler {
    /// Creates a scheduler with one empty queue per configured worker.
    pub fn new(config: ParallelConfig) -> Self {
        let workers = (0..config.num_workers)
            .map(|_| Arc::new(Mutex::new(WorkerQueue::new())))
            .collect();

        Self {
            config,
            workers,
            completed_tasks: Arc::new(Mutex::new(HashMap::new())),
            stats: Arc::new(Mutex::new(SchedulerStats::default())),
        }
    }
308
309 pub fn submit(&self, task: Task) -> Result<(), ParallelError> {
311 self.validate_dependencies(&task)?;
313
314 let worker_idx = self.select_worker(&task);
316
317 let mut worker = self.workers[worker_idx]
319 .lock()
320 .expect("lock should not be poisoned");
321 if worker.len() >= self.config.max_queue_size {
322 return Err(ParallelError::QueueFull);
323 }
324
325 worker.push(task);
326
327 Ok(())
328 }
329
330 pub fn submit_batch(&self, tasks: Vec<Task>) -> Result<(), ParallelError> {
332 for task in tasks {
333 self.submit(task)?;
334 }
335 Ok(())
336 }
337
338 pub fn execute_all(&self) -> Result<Vec<TaskId>, ParallelError> {
340 let mut completed = Vec::new();
341
342 for worker in &self.workers {
344 let mut worker = worker.lock().expect("lock should not be poisoned");
345 while let Some(task) = worker.pop() {
346 if self.dependencies_satisfied(&task)? {
348 let execution_time = task.estimated_time_us.unwrap_or(1000);
350 worker.tasks_executed += 1;
351 worker.total_execution_time_us += execution_time;
352
353 self.completed_tasks
355 .lock()
356 .expect("lock should not be poisoned")
357 .insert(task.id.clone(), execution_time);
358
359 completed.push(task.id);
360
361 if self.config.enable_stats {
363 let mut stats = self.stats.lock().expect("lock should not be poisoned");
364 stats.tasks_executed += 1;
365 stats.total_execution_time_us += execution_time;
366 }
367 } else {
368 worker.push(task);
370 }
371 }
372 }
373
374 Ok(completed)
375 }
376
377 pub fn stats(&self) -> SchedulerStats {
379 self.stats
380 .lock()
381 .expect("lock should not be poisoned")
382 .clone()
383 }
384
385 pub fn reset(&self) {
387 for worker in &self.workers {
388 let mut worker = worker.lock().expect("lock should not be poisoned");
389 worker.queue.clear();
390 worker.steal_count = 0;
391 worker.tasks_executed = 0;
392 worker.total_execution_time_us = 0;
393 }
394
395 self.completed_tasks
396 .lock()
397 .expect("lock should not be poisoned")
398 .clear();
399 *self.stats.lock().expect("lock should not be poisoned") = SchedulerStats::default();
400 }
401
402 fn validate_dependencies(&self, task: &Task) -> Result<(), ParallelError> {
405 let mut visited = std::collections::HashSet::new();
407 self.check_cycle(&task.id, &task.dependencies, &mut visited)
408 }
409
410 fn check_cycle(
411 &self,
412 current: &TaskId,
413 dependencies: &[TaskId],
414 visited: &mut std::collections::HashSet<TaskId>,
415 ) -> Result<(), ParallelError> {
416 if visited.contains(current) {
417 return Err(ParallelError::DependencyCycle);
418 }
419
420 visited.insert(current.clone());
421
422 for _dep in dependencies {
423 }
426
427 Ok(())
428 }
429
430 fn dependencies_satisfied(&self, task: &Task) -> Result<bool, ParallelError> {
431 let completed = self
432 .completed_tasks
433 .lock()
434 .expect("lock should not be poisoned");
435 Ok(task
436 .dependencies
437 .iter()
438 .all(|dep| completed.contains_key(dep)))
439 }
440
441 fn select_worker(&self, task: &Task) -> usize {
442 if let Some(numa_node) = task.numa_node {
444 return self.numa_node_to_worker(numa_node);
445 }
446
447 let mut min_load = usize::MAX;
449 let mut selected = 0;
450
451 for (idx, worker) in self.workers.iter().enumerate() {
452 let worker = worker.lock().expect("lock should not be poisoned");
453 let load = worker.len();
454 if load < min_load {
455 min_load = load;
456 selected = idx;
457 }
458 }
459
460 selected
461 }
462
    /// Maps a NUMA node id onto a worker index by simple modulo over
    /// the worker count.
    fn numa_node_to_worker(&self, node: NumaNode) -> usize {
        node.0 % self.config.num_workers
    }
468
469 pub fn try_steal(&self, thief_idx: usize) -> Option<Task> {
471 let victim_idx = self.select_victim(thief_idx);
472 if victim_idx == thief_idx {
473 return None;
474 }
475
476 let mut victim = self.workers[victim_idx]
477 .lock()
478 .expect("lock should not be poisoned");
479 let stolen = victim.steal();
480
481 if stolen.is_some() && self.config.enable_stats {
482 let mut stats = self.stats.lock().expect("lock should not be poisoned");
483 stats.steal_count += 1;
484 }
485
486 stolen
487 }
488
489 fn select_victim(&self, thief_idx: usize) -> usize {
490 match self.config.steal_strategy {
491 StealStrategy::Random => {
492 (thief_idx + 1) % self.config.num_workers
494 }
495 StealStrategy::MaxLoad => {
496 let mut max_load = 0;
498 let mut victim = thief_idx;
499
500 for (idx, worker) in self.workers.iter().enumerate() {
501 if idx == thief_idx {
502 continue;
503 }
504 let worker = worker.lock().expect("lock should not be poisoned");
505 let load = worker.len();
506 if load > max_load {
507 max_load = load;
508 victim = idx;
509 }
510 }
511
512 victim
513 }
514 StealStrategy::LRU | StealStrategy::RoundRobin => {
515 (thief_idx + 1) % self.config.num_workers
517 }
518 }
519 }
520
521 pub fn load_balance_stats(&self) -> LoadBalanceStats {
523 let mut worker_loads = Vec::new();
524 let mut total_tasks = 0;
525
526 for worker in &self.workers {
527 let worker = worker.lock().expect("lock should not be poisoned");
528 let load = worker.tasks_executed;
529 worker_loads.push(load);
530 total_tasks += load;
531 }
532
533 let avg_load = total_tasks as f64 / self.config.num_workers as f64;
534 let variance = worker_loads
535 .iter()
536 .map(|&load| (load as f64 - avg_load).powi(2))
537 .sum::<f64>()
538 / self.config.num_workers as f64;
539
540 let std_dev = variance.sqrt();
541 let cv = if avg_load > 0.0 {
542 std_dev / avg_load
543 } else {
544 0.0
545 };
546 let max_load = *worker_loads.iter().max().unwrap_or(&0);
547
548 LoadBalanceStats {
549 worker_loads,
550 avg_load,
551 std_dev,
552 coefficient_of_variation: cv,
553 imbalance_ratio: if avg_load > 0.0 {
554 max_load as f64 / avg_load
555 } else {
556 1.0
557 },
558 }
559 }
560}
561
/// Global counters accumulated by the scheduler (when
/// `ParallelConfig::enable_stats` is set).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SchedulerStats {
    /// Total tasks executed across all workers.
    pub tasks_executed: usize,

    /// Total simulated execution time, in microseconds.
    pub total_execution_time_us: u64,

    /// Successful steal attempts.
    pub steal_count: usize,

    /// Steal attempts that found the victim's queue empty.
    pub failed_steals: usize,
}
577
578impl SchedulerStats {
579 pub fn avg_execution_time_us(&self) -> f64 {
581 if self.tasks_executed > 0 {
582 self.total_execution_time_us as f64 / self.tasks_executed as f64
583 } else {
584 0.0
585 }
586 }
587
588 pub fn steal_success_rate(&self) -> f64 {
590 let total_attempts = self.steal_count + self.failed_steals;
591 if total_attempts > 0 {
592 self.steal_count as f64 / total_attempts as f64
593 } else {
594 0.0
595 }
596 }
597}
598
// Human-readable multi-line report of the scheduler statistics.
impl std::fmt::Display for SchedulerStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Scheduler Statistics")?;
        writeln!(f, "====================")?;
        writeln!(f, "Tasks executed: {}", self.tasks_executed)?;
        writeln!(
            f,
            "Total time: {:.2} ms",
            self.total_execution_time_us as f64 / 1000.0
        )?;
        writeln!(
            f,
            "Avg time/task: {:.2} µs",
            self.avg_execution_time_us()
        )?;
        writeln!(f, "Steal count: {}", self.steal_count)?;
        writeln!(f, "Failed steals: {}", self.failed_steals)?;
        writeln!(
            f,
            "Steal success rate: {:.2}%",
            self.steal_success_rate() * 100.0
        )?;
        Ok(())
    }
}
624
/// Summary of how evenly executed tasks were spread across workers.
#[derive(Debug, Clone, PartialEq)]
pub struct LoadBalanceStats {
    /// Executed-task count per worker, indexed by worker.
    pub worker_loads: Vec<usize>,

    /// Mean executed-task count per worker.
    pub avg_load: f64,

    /// Population standard deviation of the per-worker counts.
    pub std_dev: f64,

    /// `std_dev / avg_load` (0.0 when `avg_load` is zero).
    pub coefficient_of_variation: f64,

    /// `max_load / avg_load` (1.0 when `avg_load` is zero).
    pub imbalance_ratio: f64,
}
643
644impl LoadBalanceStats {
645 pub fn is_well_balanced(&self) -> bool {
647 self.coefficient_of_variation < 0.2
648 }
649}
650
// Human-readable multi-line report of the load-balance statistics.
impl std::fmt::Display for LoadBalanceStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Load Balance Statistics")?;
        writeln!(f, "=======================")?;
        writeln!(f, "Worker loads: {:?}", self.worker_loads)?;
        writeln!(f, "Average load: {:.2}", self.avg_load)?;
        writeln!(f, "Std deviation: {:.2}", self.std_dev)?;
        writeln!(f, "CV: {:.4}", self.coefficient_of_variation)?;
        writeln!(f, "Imbalance: {:.2}x", self.imbalance_ratio)?;
        writeln!(
            f,
            "Well balanced: {}",
            if self.is_well_balanced() { "Yes" } else { "No" }
        )?;
        Ok(())
    }
}
668
#[cfg(test)]
mod tests {
    use super::*;

    // Default config: positive worker count, random stealing, stats on.
    #[test]
    fn test_parallel_config_default() {
        let config = ParallelConfig::default();
        assert!(config.num_workers > 0);
        assert_eq!(config.steal_strategy, StealStrategy::Random);
        assert!(config.enable_stats);
    }

    // Builder setters should round-trip onto the config fields.
    #[test]
    fn test_parallel_config_builder() {
        let config = ParallelConfig::new(4)
            .expect("unwrap")
            .with_steal_strategy(StealStrategy::MaxLoad)
            .with_numa_strategy(NumaStrategy::LocalPreferred)
            .with_priority(true);

        assert_eq!(config.num_workers, 4);
        assert_eq!(config.steal_strategy, StealStrategy::MaxLoad);
        assert_eq!(config.numa_strategy, NumaStrategy::LocalPreferred);
        assert!(config.enable_priority);
    }

    // Task builder should record priority, dependencies and estimate.
    #[test]
    fn test_task_creation() {
        let task = Task::new("task1".to_string())
            .with_priority(TaskPriority::High)
            .with_dependency("task0".to_string())
            .with_estimated_time(1000);

        assert_eq!(task.id, "task1");
        assert_eq!(task.priority, TaskPriority::High);
        assert_eq!(task.dependencies.len(), 1);
        assert_eq!(task.estimated_time_us, Some(1000));
    }

    // One worker queue is allocated per configured worker.
    #[test]
    fn test_scheduler_creation() {
        let config = ParallelConfig::new(4).expect("unwrap");
        let scheduler = WorkStealingScheduler::new(config);

        assert_eq!(scheduler.workers.len(), 4);
    }

    #[test]
    fn test_scheduler_submit() {
        let config = ParallelConfig::new(2).expect("unwrap");
        let scheduler = WorkStealingScheduler::new(config);

        let task = Task::new("task1".to_string());
        assert!(scheduler.submit(task).is_ok());
    }

    // Two independent tasks should both complete.
    #[test]
    fn test_scheduler_execute_simple() {
        let config = ParallelConfig::new(2).expect("unwrap");
        let scheduler = WorkStealingScheduler::new(config);

        let task1 = Task::new("task1".to_string()).with_estimated_time(100);
        let task2 = Task::new("task2".to_string()).with_estimated_time(200);

        scheduler.submit(task1).expect("unwrap");
        scheduler.submit(task2).expect("unwrap");

        let completed = scheduler.execute_all().expect("unwrap");
        assert_eq!(completed.len(), 2);
    }

    // task2 depends on task1; at minimum task1 must complete.
    #[test]
    fn test_scheduler_dependencies() {
        let config = ParallelConfig::new(2).expect("unwrap");
        let scheduler = WorkStealingScheduler::new(config);

        let task1 = Task::new("task1".to_string());
        let task2 = Task::new("task2".to_string()).with_dependency("task1".to_string());

        scheduler.submit(task1).expect("unwrap");
        scheduler.submit(task2).expect("unwrap");

        let completed = scheduler.execute_all().expect("unwrap");
        assert!(completed.contains(&"task1".to_string()));
    }

    // Global stats accumulate counts and estimated times.
    #[test]
    fn test_scheduler_stats() {
        let config = ParallelConfig::new(2).expect("unwrap");
        let scheduler = WorkStealingScheduler::new(config);

        let task1 = Task::new("task1".to_string()).with_estimated_time(1000);
        let task2 = Task::new("task2".to_string()).with_estimated_time(2000);

        scheduler.submit(task1).expect("unwrap");
        scheduler.submit(task2).expect("unwrap");
        scheduler.execute_all().expect("unwrap");

        let stats = scheduler.stats();
        assert_eq!(stats.tasks_executed, 2);
        assert_eq!(stats.total_execution_time_us, 3000);
    }

    // 8 tasks over 4 workers -> average load of 2 per worker.
    #[test]
    fn test_load_balance_stats() {
        let config = ParallelConfig::new(4).expect("unwrap");
        let scheduler = WorkStealingScheduler::new(config);

        for i in 0..8 {
            let task = Task::new(format!("task{}", i)).with_estimated_time(100);
            scheduler.submit(task).expect("unwrap");
        }

        scheduler.execute_all().expect("unwrap");

        let stats = scheduler.load_balance_stats();
        assert!((stats.avg_load - 2.0).abs() < 0.1);
    }

    // reset() must zero the global statistics.
    #[test]
    fn test_scheduler_reset() {
        let config = ParallelConfig::new(2).expect("unwrap");
        let scheduler = WorkStealingScheduler::new(config);

        let task = Task::new("task1".to_string());
        scheduler.submit(task).expect("unwrap");
        scheduler.execute_all().expect("unwrap");

        let stats_before = scheduler.stats();
        assert_eq!(stats_before.tasks_executed, 1);

        scheduler.reset();

        let stats_after = scheduler.stats();
        assert_eq!(stats_after.tasks_executed, 0);
    }

    // Priority ordering follows the enum discriminants.
    #[test]
    fn test_task_priority() {
        assert!(TaskPriority::Critical > TaskPriority::High);
        assert!(TaskPriority::High > TaskPriority::Normal);
        assert!(TaskPriority::Normal > TaskPriority::Low);
    }

    #[test]
    fn test_numa_node() {
        let node = NumaNode(0);
        assert_eq!(node.0, 0);
    }

    // StealStrategy is Copy + PartialEq.
    #[test]
    fn test_steal_strategy() {
        let s1 = StealStrategy::Random;
        let s2 = s1;
        let s3 = s1;
        assert_eq!(s2, s3);
    }
}