// scirs2_autograd/parallel/thread_pool.rs
use super::{ThreadPoolConfig, ThreadPoolError, WorkerStats};
8use std::collections::VecDeque;
9use std::sync::{
10 atomic::{AtomicBool, Ordering},
11 Arc, Condvar, Mutex,
12};
13use std::thread::{self, JoinHandle};
14use std::time::{Duration, Instant};
15
/// Work-stealing thread pool with a shared global task queue and per-worker
/// local queues.
pub struct AdvancedThreadPool {
    /// One work-stealing worker per configured thread.
    workers: Vec<WorkStealingWorker>,
    /// Shared FIFO of pending tasks; workers pull from here when their local
    /// queue is empty.
    global_queue: Arc<Mutex<VecDeque<Task>>>,
    /// Pool configuration (thread count, queue bound, stealing policy).
    config: ThreadPoolConfig,
    /// Cleared on shutdown; worker loops poll this flag to know when to exit.
    running: Arc<AtomicBool>,
    /// Aggregated execution statistics, shared with every worker.
    stats: Arc<Mutex<AdvancedThreadPoolStats>>,
}
24
25impl AdvancedThreadPool {
26 pub fn new(config: ThreadPoolConfig) -> Self {
28 let global_queue = Arc::new(Mutex::new(VecDeque::new()));
29 let running = Arc::new(AtomicBool::new(true));
30 let stats = Arc::new(Mutex::new(AdvancedThreadPoolStats::new(config.num_threads)));
31
32 let mut workers = Vec::with_capacity(config.num_threads);
33
34 for id in 0..config.num_threads {
35 let worker = WorkStealingWorker::new(
36 id,
37 Arc::clone(&global_queue),
38 Arc::clone(&running),
39 Arc::clone(&stats),
40 config.clone(),
41 );
42 workers.push(worker);
43 }
44
45 Self {
46 workers,
47 global_queue,
48 config,
49 running,
50 stats,
51 }
52 }
53
54 pub fn submit<F>(&self, task: F) -> Result<TaskHandle<()>, ThreadPoolError>
56 where
57 F: FnOnce() + Send + 'static,
58 {
59 if !self.running.load(Ordering::Relaxed) {
60 return Err(ThreadPoolError::QueueFull);
61 }
62
63 let (task, handle) = Task::new(task);
64
65 let mut queue = self.global_queue.lock().expect("Test: operation failed");
67 if queue.len() >= self.config.max_queue_size {
68 return Err(ThreadPoolError::QueueFull);
69 }
70 queue.push_back(task);
71
72 Ok(handle)
73 }
74
75 pub fn submit_batch<F, I>(&self, tasks: I) -> Result<Vec<TaskHandle<()>>, ThreadPoolError>
77 where
78 F: FnOnce() + Send + 'static,
79 I: IntoIterator<Item = F>,
80 {
81 let tasks: Vec<F> = tasks.into_iter().collect();
82 let mut handles = Vec::with_capacity(tasks.len());
83
84 for task in tasks {
85 handles.push(self.submit(task)?);
86 }
87
88 Ok(handles)
89 }
90
91 #[allow(dead_code)]
93 fn find_least_loaded_worker(&self) -> Option<usize> {
94 if self.config.work_stealing {
96 self.workers
97 .iter()
98 .enumerate()
99 .min_by_key(|(_, worker)| worker.get_queue_size())
100 .map(|(id_, _)| id_)
101 } else {
102 let now = Instant::now();
104 Some(now.elapsed().as_nanos() as usize % self.workers.len())
105 }
106 }
107
108 pub fn get_stats(&self) -> AdvancedThreadPoolStats {
110 self.stats.lock().expect("Test: operation failed").clone()
111 }
112
113 pub fn shutdown(self) -> Result<(), ThreadPoolError> {
115 self.running.store(false, Ordering::Relaxed);
116
117 for worker in &self.workers {
119 worker.notify_shutdown();
120 }
121
122 for worker in self.workers {
124 worker.join().map_err(|_| ThreadPoolError::ShutdownFailed)?;
125 }
126
127 Ok(())
128 }
129
130 pub fn resize(&mut self, new_size: usize) -> Result<(), ThreadPoolError> {
132 if new_size == 0 {
133 return Err(ThreadPoolError::InvalidConfiguration(
134 "Thread pool size cannot be zero".into(),
135 ));
136 }
137
138 let current_size = self.workers.len();
139
140 match new_size.cmp(¤t_size) {
141 std::cmp::Ordering::Greater => {
142 for id in current_size..new_size {
144 let worker = WorkStealingWorker::new(
145 id,
146 Arc::clone(&self.global_queue),
147 Arc::clone(&self.running),
148 Arc::clone(&self.stats),
149 self.config.clone(),
150 );
151 self.workers.push(worker);
152 }
153 }
154 std::cmp::Ordering::Less => {
155 self.workers.truncate(new_size);
157 }
158 std::cmp::Ordering::Equal => {
159 }
161 }
162
163 self.config.num_threads = new_size;
164 Ok(())
165 }
166}
167
/// A single pool worker: owns a local task queue and the OS thread that
/// drains it (falling back to the shared global queue).
pub struct WorkStealingWorker {
    /// Worker index within the pool; also indexes the shared worker stats.
    #[allow(dead_code)]
    id: usize,
    /// Worker-local task queue. Currently only drained — no code path in this
    /// file pushes into it yet (see `try_submit_local`).
    #[allow(dead_code)]
    local_queue: Arc<Mutex<VecDeque<Task>>>,
    /// Handle to the worker's OS thread; taken (set to None) on `join`.
    thread_handle: Option<JoinHandle<()>>,
    /// Flag + condvar pair used to wake an idle worker during shutdown.
    shutdown_signal: Arc<(Mutex<bool>, Condvar)>,
}
177
178impl WorkStealingWorker {
179 fn new(
181 id: usize,
182 global_queue: Arc<Mutex<VecDeque<Task>>>,
183 running: Arc<AtomicBool>,
184 stats: Arc<Mutex<AdvancedThreadPoolStats>>,
185 config: ThreadPoolConfig,
186 ) -> Self {
187 let local_queue = Arc::new(Mutex::new(VecDeque::new()));
188 let shutdown_signal = Arc::new((Mutex::new(false), Condvar::new()));
189
190 let local_queue_clone = Arc::clone(&local_queue);
191 let shutdown_signal_clone = Arc::clone(&shutdown_signal);
192
193 let thread_handle = thread::spawn(move || {
194 Self::worker_loop(
195 id,
196 local_queue_clone,
197 global_queue,
198 running,
199 stats,
200 config,
201 shutdown_signal_clone,
202 );
203 });
204
205 Self {
206 id,
207 local_queue,
208 thread_handle: Some(thread_handle),
209 shutdown_signal,
210 }
211 }
212
213 fn worker_loop(
215 id: usize,
216 local_queue: Arc<Mutex<VecDeque<Task>>>,
217 global_queue: Arc<Mutex<VecDeque<Task>>>,
218 running: Arc<AtomicBool>,
219 stats: Arc<Mutex<AdvancedThreadPoolStats>>,
220 config: ThreadPoolConfig,
221 shutdown_signal: Arc<(Mutex<bool>, Condvar)>,
222 ) {
223 let mut idle_start = None;
224
225 while running.load(Ordering::Relaxed) {
226 let task = Self::find_task(&local_queue, &global_queue, &config);
227
228 match task {
229 Some(task) => {
230 idle_start = None;
231 let start_time = Instant::now();
232
233 task.execute();
235
236 let execution_time = start_time.elapsed();
237
238 {
240 let mut stats = stats.lock().expect("Test: operation failed");
241 stats.total_tasks_executed += 1;
242 stats.total_execution_time += execution_time;
243 stats.worker_stats[id].tasks_completed += 1;
244 stats.worker_stats[id].total_time += execution_time;
245 stats.worker_stats[id].last_activity = Some(Instant::now());
246 }
247 }
248 None => {
249 if idle_start.is_none() {
251 idle_start = Some(Instant::now());
252 }
253
254 if let Some(start) = idle_start {
256 if start.elapsed() > config.idle_timeout {
257 let (lock, cvar) = &*shutdown_signal;
258 let mut shutdown = lock.lock().expect("Test: operation failed");
259 while !*shutdown && running.load(Ordering::Relaxed) {
260 let result = cvar
261 .wait_timeout(shutdown, Duration::from_millis(100))
262 .expect("Test: wait timeout failed");
263 shutdown = result.0;
264 if result.1.timed_out() {
265 break;
266 }
267 }
268 }
269 }
270
271 thread::sleep(Duration::from_micros(100));
273 }
274 }
275 }
276 }
277
278 fn find_task(
280 local_queue: &Arc<Mutex<VecDeque<Task>>>,
281 global_queue: &Arc<Mutex<VecDeque<Task>>>,
282 config: &ThreadPoolConfig,
283 ) -> Option<Task> {
284 {
286 let mut _queue = local_queue.lock().expect("Test: operation failed");
287 if let Some(task) = _queue.pop_front() {
288 return Some(task);
289 }
290 }
291
292 {
294 let mut _queue = global_queue.lock().expect("Test: operation failed");
295 if let Some(task) = _queue.pop_front() {
296 return Some(task);
297 }
298 }
299
300 if config.work_stealing {
302 }
304
305 None
306 }
307
308 #[allow(dead_code)]
310 fn try_submit_local(&self, task: Task) -> bool {
311 let mut queue = self.local_queue.lock().expect("Test: operation failed");
312 if queue.len()
313 < self
314 .local_queue
315 .lock()
316 .expect("Test: operation failed")
317 .capacity()
318 {
319 queue.push_back(task);
320 true
321 } else {
322 false
323 }
324 }
325
326 fn get_queue_size(&self) -> usize {
328 self.local_queue
329 .lock()
330 .expect("Test: operation failed")
331 .len()
332 }
333
334 fn notify_shutdown(&self) {
336 let (lock, cvar) = &*self.shutdown_signal;
337 let mut shutdown = lock.lock().expect("Test: operation failed");
338 *shutdown = true;
339 cvar.notify_one();
340 }
341
342 fn join(mut self) -> Result<(), Box<dyn std::any::Any + Send>> {
344 if let Some(handle) = self.thread_handle.take() {
345 handle.join()
346 } else {
347 Ok(())
348 }
349 }
350}
351
/// A unit of work queued on the pool: a boxed closure plus metadata.
pub struct Task {
    /// The work to run; the wrapper built in `Task::new` also signals the
    /// paired `TaskHandle` when it finishes.
    func: Box<dyn FnOnce() + Send + 'static>,
    /// Creation timestamp; used by `age()`.
    created_at: Instant,
    /// Scheduling priority; recorded but not currently consulted by the
    /// FIFO queues in this file.
    priority: TaskPriority,
}
358
359impl Task {
360 pub fn new<F>(func: F) -> (Self, TaskHandle<()>)
362 where
363 F: FnOnce() + Send + 'static,
364 {
365 let (sender, receiver) = std::sync::mpsc::channel();
366
367 let task = Task {
368 func: Box::new(move || {
369 func();
370 let _ = sender.send(());
371 }),
372 created_at: Instant::now(),
373 priority: TaskPriority::Normal,
374 };
375
376 let handle = TaskHandle { receiver };
377 (task, handle)
378 }
379
380 fn execute(self) {
382 (self.func)();
383 }
384
385 pub fn age(&self) -> Duration {
387 self.created_at.elapsed()
388 }
389
390 pub fn with_priority(mut self, priority: TaskPriority) -> Self {
392 self.priority = priority;
393 self
394 }
395}
396
/// Task scheduling priority, ordered from lowest to highest.
///
/// NOTE(review): priorities are stored on tasks but the FIFO queues in this
/// file do not yet order by them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    Low,
    Normal,
    High,
    Critical,
}
405
/// Completion handle for a submitted task; the executing task sends the
/// result through the paired mpsc channel exactly once.
pub struct TaskHandle<T> {
    /// Receives the task's result when it finishes.
    receiver: std::sync::mpsc::Receiver<T>,
}
410
411impl<T> TaskHandle<T> {
412 pub fn wait(self) -> Result<T, ThreadPoolError> {
414 self.receiver
415 .recv()
416 .map_err(|_| ThreadPoolError::ExecutionFailed)
417 }
418
419 pub fn wait_timeout(self, timeout: Duration) -> Result<T, ThreadPoolError> {
421 self.receiver
422 .recv_timeout(timeout)
423 .map_err(|_| ThreadPoolError::ExecutionFailed)
424 }
425
426 pub fn try_wait(&self) -> Result<Option<T>, ThreadPoolError> {
428 match self.receiver.try_recv() {
429 Ok(result) => Ok(Some(result)),
430 Err(std::sync::mpsc::TryRecvError::Empty) => Ok(None),
431 Err(std::sync::mpsc::TryRecvError::Disconnected) => {
432 Err(ThreadPoolError::ExecutionFailed)
433 }
434 }
435 }
436}
437
/// Aggregated runtime statistics for an `AdvancedThreadPool`.
#[derive(Debug, Clone)]
pub struct AdvancedThreadPoolStats {
    /// Total tasks completed across all workers.
    pub total_tasks_executed: u64,
    /// Sum of per-task execution times across all workers.
    pub total_execution_time: Duration,
    /// Count of successful work-steal operations.
    /// NOTE(review): never incremented in this file — stealing is not
    /// implemented yet.
    pub work_steals: u64,
    /// Load-balance score in [0, 1]; see `calculate_load_balance_efficiency`.
    pub load_balance_efficiency: f64,
    /// Per-worker counters, indexed by worker id.
    pub worker_stats: Vec<WorkerStats>,
    /// Queue contention metric; initialized to 0.0 and not updated here.
    pub queue_contention: f64,
}
454
455impl AdvancedThreadPoolStats {
456 fn new(_numworkers: usize) -> Self {
457 Self {
458 total_tasks_executed: 0,
459 total_execution_time: Duration::ZERO,
460 work_steals: 0,
461 load_balance_efficiency: 1.0,
462 worker_stats: (0.._numworkers).map(WorkerStats::new).collect(),
463 queue_contention: 0.0,
464 }
465 }
466
467 pub fn throughput(&self) -> f64 {
469 if self.total_execution_time.is_zero() {
470 0.0
471 } else {
472 self.total_tasks_executed as f64 / self.total_execution_time.as_secs_f64()
473 }
474 }
475
476 pub fn average_latency(&self) -> Duration {
478 if self.total_tasks_executed == 0 {
479 Duration::ZERO
480 } else {
481 self.total_execution_time / self.total_tasks_executed as u32
482 }
483 }
484
485 pub fn worker_utilization(&self) -> Vec<f64> {
487 let total_time = self.total_execution_time;
488 self.worker_stats
489 .iter()
490 .map(|stats| {
491 if total_time.is_zero() {
492 0.0
493 } else {
494 stats.total_time.as_secs_f64() / total_time.as_secs_f64()
495 }
496 })
497 .collect()
498 }
499
500 pub fn calculate_load_balance_efficiency(&self) -> f64 {
502 if self.worker_stats.len() <= 1 {
503 return 1.0;
504 }
505
506 let task_counts: Vec<u64> = self
507 .worker_stats
508 .iter()
509 .map(|stats| stats.tasks_completed)
510 .collect();
511
512 let total_tasks: u64 = task_counts.iter().sum();
513 if total_tasks == 0 {
514 return 1.0;
515 }
516
517 let average_tasks = total_tasks as f64 / task_counts.len() as f64;
518 let variance: f64 = task_counts
519 .iter()
520 .map(|&count| {
521 let diff = count as f64 - average_tasks;
522 diff * diff
523 })
524 .sum::<f64>()
525 / task_counts.len() as f64;
526
527 let std_dev = variance.sqrt();
528 let coefficient_of_variation = if average_tasks > 0.0 {
529 std_dev / average_tasks
530 } else {
531 0.0
532 };
533
534 (1.0 - coefficient_of_variation.min(1.0)).max(0.0)
536 }
537}
538
/// Thread pool that partitions workers across NUMA nodes, with one inner
/// `AdvancedThreadPool` per detected node.
pub struct NumaAwareThreadPool {
    /// One pool per detected NUMA node, indexed by node id.
    pools: Vec<AdvancedThreadPool>,
    /// Detected host topology; kept for future placement decisions.
    #[allow(dead_code)]
    numa_topology: NumaTopology,
}
545
546impl NumaAwareThreadPool {
547 pub fn new(config: ThreadPoolConfig) -> Self {
549 let topology = NumaTopology::detect();
550 let pools_per_node = config.num_threads / topology.num_nodes.max(1);
551
552 let mut pools = Vec::with_capacity(topology.num_nodes);
553
554 for _ in 0..topology.num_nodes {
555 let node_config = ThreadPoolConfig {
556 num_threads: pools_per_node,
557 ..config.clone()
558 };
559 pools.push(AdvancedThreadPool::new(node_config));
560 }
561
562 Self {
563 pools,
564 numa_topology: topology,
565 }
566 }
567
568 pub fn submit_numa<F>(
570 &self,
571 task: F,
572 preferred_node: Option<usize>,
573 ) -> Result<TaskHandle<()>, ThreadPoolError>
574 where
575 F: FnOnce() + Send + 'static,
576 {
577 let _node = preferred_node
578 .unwrap_or_else(|| self.select_optimal_node())
579 .min(self.pools.len() - 1);
580
581 self.pools[_node].submit(task)
582 }
583
584 fn select_optimal_node(&self) -> usize {
586 self.pools
588 .iter()
589 .enumerate()
590 .min_by_key(|(_, pool)| pool.get_stats().total_tasks_executed)
591 .map(|(id_, _)| id_)
592 .unwrap_or(0)
593 }
594}
595
/// Detected NUMA layout of the host machine.
#[derive(Debug, Clone)]
pub struct NumaTopology {
    /// Number of NUMA nodes.
    pub num_nodes: usize,
    /// CPU core count per node, indexed by node id.
    pub cores_per_node: Vec<usize>,
    /// Memory per node in bytes, indexed by node id.
    pub memory_per_node: Vec<usize>,
}

impl NumaTopology {
    /// Detects the host topology.
    ///
    /// Placeholder implementation: always reports a single node holding
    /// every available CPU (falling back to 4 when detection fails) and a
    /// fixed 8 GiB of memory.
    fn detect() -> Self {
        let core_count = match std::thread::available_parallelism() {
            Ok(n) => n.get(),
            Err(_) => 4,
        };

        NumaTopology {
            num_nodes: 1,
            cores_per_node: vec![core_count],
            memory_per_node: vec![8 * 1024 * 1024 * 1024],
        }
    }
}
623
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A single submitted task executes and its handle resolves.
    #[test]
    fn test_advanced_thread_pool() {
        let config = ThreadPoolConfig {
            num_threads: 2,
            work_stealing: true,
            ..Default::default()
        };

        let pool = AdvancedThreadPool::new(config);
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);

        let handle = pool
            .submit(move || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .expect("task submission failed");

        handle.wait().expect("task did not complete");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// `wait_timeout` errors when the task outlives the timeout.
    #[test]
    fn test_task_handle_timeout() {
        let config = ThreadPoolConfig {
            num_threads: 1,
            ..Default::default()
        };

        let pool = AdvancedThreadPool::new(config);

        let handle = pool
            .submit(|| {
                std::thread::sleep(Duration::from_millis(200));
            })
            .expect("task submission failed");

        // 50ms timeout against a 200ms task: the wait must time out.
        let result = handle.wait_timeout(Duration::from_millis(50));
        assert!(result.is_err());
    }

    /// Every task in a batch executes exactly once.
    #[test]
    fn test_batch_submission() {
        let config = ThreadPoolConfig {
            num_threads: 2,
            ..Default::default()
        };

        let pool = AdvancedThreadPool::new(config);
        let counter = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..5)
            .map(|_| {
                let counter_clone = Arc::clone(&counter);
                move || {
                    counter_clone.fetch_add(1, Ordering::SeqCst);
                }
            })
            .collect();

        let handles = pool.submit_batch(tasks).expect("batch submission failed");

        for handle in handles {
            handle.wait().expect("task did not complete");
        }

        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    /// A fresh pool reports zeroed stats with one entry per worker.
    #[test]
    fn test_thread_pool_stats() {
        let config = ThreadPoolConfig {
            num_threads: 2,
            ..Default::default()
        };

        let pool = AdvancedThreadPool::new(config);
        let stats = pool.get_stats();

        assert_eq!(stats.total_tasks_executed, 0);
        assert_eq!(stats.worker_stats.len(), 2);
    }

    /// Submission to an explicit NUMA node executes the task.
    #[test]
    fn test_numa_aware_thread_pool() {
        let config = ThreadPoolConfig {
            num_threads: 4,
            ..Default::default()
        };

        let numa_pool = NumaAwareThreadPool::new(config);
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);

        let handle = numa_pool
            .submit_numa(
                move || {
                    counter_clone.fetch_add(1, Ordering::SeqCst);
                },
                Some(0),
            )
            .expect("NUMA task submission failed");

        handle.wait().expect("task did not complete");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// `with_priority` records the requested priority on the task.
    #[test]
    fn test_task_priority() {
        let task = Task::new(|| {}).0.with_priority(TaskPriority::High);
        assert_eq!(task.priority, TaskPriority::High);
    }

    /// Topology detection reports a sane, non-empty layout.
    #[test]
    fn test_numa_topology() {
        let topology = NumaTopology::detect();
        assert!(topology.num_nodes > 0);
        assert!(!topology.cores_per_node.is_empty());
        assert!(!topology.memory_per_node.is_empty());
    }
}