Skip to main content

scirs2_autograd/parallel/
thread_pool.rs

1//! Advanced thread pool implementation with work stealing and load balancing
2//!
3//! This module provides a high-performance thread pool optimized for
4//! scientific computing workloads with features like work stealing,
5//! NUMA awareness, and adaptive scheduling.
6
7use super::{ThreadPoolConfig, ThreadPoolError, WorkerStats};
8use std::collections::VecDeque;
9use std::sync::{
10    atomic::{AtomicBool, Ordering},
11    Arc, Condvar, Mutex,
12};
13use std::thread::{self, JoinHandle};
14use std::time::{Duration, Instant};
15
/// Thread pool whose workers pull tasks from a shared global queue.
///
/// Each worker also owns a local queue, but `submit` routes every task
/// through `global_queue`.
pub struct AdvancedThreadPool {
    /// Worker handles owned by the pool; joined by `shutdown`.
    workers: Vec<WorkStealingWorker>,
    /// FIFO of pending tasks shared with every worker.
    global_queue: Arc<Mutex<VecDeque<Task>>>,
    /// Configuration the pool was built with (`num_threads` is updated by `resize`).
    config: ThreadPoolConfig,
    /// Set to `false` by `shutdown`; worker loops run only while it is `true`.
    running: Arc<AtomicBool>,
    /// Execution statistics updated by the workers after each task.
    stats: Arc<Mutex<AdvancedThreadPoolStats>>,
}
24
25impl AdvancedThreadPool {
26    /// Create a new advanced thread pool
27    pub fn new(config: ThreadPoolConfig) -> Self {
28        let global_queue = Arc::new(Mutex::new(VecDeque::new()));
29        let running = Arc::new(AtomicBool::new(true));
30        let stats = Arc::new(Mutex::new(AdvancedThreadPoolStats::new(config.num_threads)));
31
32        let mut workers = Vec::with_capacity(config.num_threads);
33
34        for id in 0..config.num_threads {
35            let worker = WorkStealingWorker::new(
36                id,
37                Arc::clone(&global_queue),
38                Arc::clone(&running),
39                Arc::clone(&stats),
40                config.clone(),
41            );
42            workers.push(worker);
43        }
44
45        Self {
46            workers,
47            global_queue,
48            config,
49            running,
50            stats,
51        }
52    }
53
54    /// Submit a task to the thread pool
55    pub fn submit<F>(&self, task: F) -> Result<TaskHandle<()>, ThreadPoolError>
56    where
57        F: FnOnce() + Send + 'static,
58    {
59        if !self.running.load(Ordering::Relaxed) {
60            return Err(ThreadPoolError::QueueFull);
61        }
62
63        let (task, handle) = Task::new(task);
64
65        // Use global queue for all tasks to avoid move issues
66        let mut queue = self.global_queue.lock().expect("Test: operation failed");
67        if queue.len() >= self.config.max_queue_size {
68            return Err(ThreadPoolError::QueueFull);
69        }
70        queue.push_back(task);
71
72        Ok(handle)
73    }
74
75    /// Submit a batch of tasks
76    pub fn submit_batch<F, I>(&self, tasks: I) -> Result<Vec<TaskHandle<()>>, ThreadPoolError>
77    where
78        F: FnOnce() + Send + 'static,
79        I: IntoIterator<Item = F>,
80    {
81        let tasks: Vec<F> = tasks.into_iter().collect();
82        let mut handles = Vec::with_capacity(tasks.len());
83
84        for task in tasks {
85            handles.push(self.submit(task)?);
86        }
87
88        Ok(handles)
89    }
90
91    /// Find the least loaded worker
92    #[allow(dead_code)]
93    fn find_least_loaded_worker(&self) -> Option<usize> {
94        // Use round-robin strategy by default since ThreadPoolConfig doesn't have load_balancing
95        if self.config.work_stealing {
96            self.workers
97                .iter()
98                .enumerate()
99                .min_by_key(|(_, worker)| worker.get_queue_size())
100                .map(|(id_, _)| id_)
101        } else {
102            // Simple round-robin based on current time
103            let now = Instant::now();
104            Some(now.elapsed().as_nanos() as usize % self.workers.len())
105        }
106    }
107
108    /// Get thread pool statistics
109    pub fn get_stats(&self) -> AdvancedThreadPoolStats {
110        self.stats.lock().expect("Test: operation failed").clone()
111    }
112
113    /// Shutdown the thread pool
114    pub fn shutdown(self) -> Result<(), ThreadPoolError> {
115        self.running.store(false, Ordering::Relaxed);
116
117        // Wake up all workers
118        for worker in &self.workers {
119            worker.notify_shutdown();
120        }
121
122        // Wait for all workers to finish
123        for worker in self.workers {
124            worker.join().map_err(|_| ThreadPoolError::ShutdownFailed)?;
125        }
126
127        Ok(())
128    }
129
130    /// Resize the thread pool (dynamic scaling)
131    pub fn resize(&mut self, new_size: usize) -> Result<(), ThreadPoolError> {
132        if new_size == 0 {
133            return Err(ThreadPoolError::InvalidConfiguration(
134                "Thread pool size cannot be zero".into(),
135            ));
136        }
137
138        let current_size = self.workers.len();
139
140        match new_size.cmp(&current_size) {
141            std::cmp::Ordering::Greater => {
142                // Add new workers
143                for id in current_size..new_size {
144                    let worker = WorkStealingWorker::new(
145                        id,
146                        Arc::clone(&self.global_queue),
147                        Arc::clone(&self.running),
148                        Arc::clone(&self.stats),
149                        self.config.clone(),
150                    );
151                    self.workers.push(worker);
152                }
153            }
154            std::cmp::Ordering::Less => {
155                // Remove workers (simplified - in practice would need graceful shutdown)
156                self.workers.truncate(new_size);
157            }
158            std::cmp::Ordering::Equal => {
159                // No change needed
160            }
161        }
162
163        self.config.num_threads = new_size;
164        Ok(())
165    }
166}
167
/// Handle to one worker thread of the pool.
///
/// The thread itself runs `worker_loop`; this struct keeps what the pool
/// needs to signal and join it.
pub struct WorkStealingWorker {
    /// Worker index; the worker loop uses it as the slot in the shared
    /// per-worker statistics.
    #[allow(dead_code)]
    id: usize,
    /// Queue private to this worker. The worker loop polls it before the
    /// global queue; only `try_submit_local` pushes to it.
    #[allow(dead_code)]
    local_queue: Arc<Mutex<VecDeque<Task>>>,
    /// Join handle of the spawned thread; `join` takes it out of the `Option`.
    thread_handle: Option<JoinHandle<()>>,
    /// Shutdown flag plus the condition variable an idle worker waits on.
    shutdown_signal: Arc<(Mutex<bool>, Condvar)>,
}
177
178impl WorkStealingWorker {
179    /// Create a new work-stealing worker
180    fn new(
181        id: usize,
182        global_queue: Arc<Mutex<VecDeque<Task>>>,
183        running: Arc<AtomicBool>,
184        stats: Arc<Mutex<AdvancedThreadPoolStats>>,
185        config: ThreadPoolConfig,
186    ) -> Self {
187        let local_queue = Arc::new(Mutex::new(VecDeque::new()));
188        let shutdown_signal = Arc::new((Mutex::new(false), Condvar::new()));
189
190        let local_queue_clone = Arc::clone(&local_queue);
191        let shutdown_signal_clone = Arc::clone(&shutdown_signal);
192
193        let thread_handle = thread::spawn(move || {
194            Self::worker_loop(
195                id,
196                local_queue_clone,
197                global_queue,
198                running,
199                stats,
200                config,
201                shutdown_signal_clone,
202            );
203        });
204
205        Self {
206            id,
207            local_queue,
208            thread_handle: Some(thread_handle),
209            shutdown_signal,
210        }
211    }
212
213    /// Main worker loop with work stealing
214    fn worker_loop(
215        id: usize,
216        local_queue: Arc<Mutex<VecDeque<Task>>>,
217        global_queue: Arc<Mutex<VecDeque<Task>>>,
218        running: Arc<AtomicBool>,
219        stats: Arc<Mutex<AdvancedThreadPoolStats>>,
220        config: ThreadPoolConfig,
221        shutdown_signal: Arc<(Mutex<bool>, Condvar)>,
222    ) {
223        let mut idle_start = None;
224
225        while running.load(Ordering::Relaxed) {
226            let task = Self::find_task(&local_queue, &global_queue, &config);
227
228            match task {
229                Some(task) => {
230                    idle_start = None;
231                    let start_time = Instant::now();
232
233                    // Execute the task
234                    task.execute();
235
236                    let execution_time = start_time.elapsed();
237
238                    // Update statistics
239                    {
240                        let mut stats = stats.lock().expect("Test: operation failed");
241                        stats.total_tasks_executed += 1;
242                        stats.total_execution_time += execution_time;
243                        stats.worker_stats[id].tasks_completed += 1;
244                        stats.worker_stats[id].total_time += execution_time;
245                        stats.worker_stats[id].last_activity = Some(Instant::now());
246                    }
247                }
248                None => {
249                    // No task found, handle idle time
250                    if idle_start.is_none() {
251                        idle_start = Some(Instant::now());
252                    }
253
254                    // Check for shutdown after idle timeout
255                    if let Some(start) = idle_start {
256                        if start.elapsed() > config.idle_timeout {
257                            let (lock, cvar) = &*shutdown_signal;
258                            let mut shutdown = lock.lock().expect("Test: operation failed");
259                            while !*shutdown && running.load(Ordering::Relaxed) {
260                                let result = cvar
261                                    .wait_timeout(shutdown, Duration::from_millis(100))
262                                    .expect("Test: wait timeout failed");
263                                shutdown = result.0;
264                                if result.1.timed_out() {
265                                    break;
266                                }
267                            }
268                        }
269                    }
270
271                    // Brief sleep to avoid busy waiting
272                    thread::sleep(Duration::from_micros(100));
273                }
274            }
275        }
276    }
277
278    /// Find a task to execute (local queue -> global queue -> work stealing)
279    fn find_task(
280        local_queue: &Arc<Mutex<VecDeque<Task>>>,
281        global_queue: &Arc<Mutex<VecDeque<Task>>>,
282        config: &ThreadPoolConfig,
283    ) -> Option<Task> {
284        // Try local _queue first
285        {
286            let mut _queue = local_queue.lock().expect("Test: operation failed");
287            if let Some(task) = _queue.pop_front() {
288                return Some(task);
289            }
290        }
291
292        // Try global _queue
293        {
294            let mut _queue = global_queue.lock().expect("Test: operation failed");
295            if let Some(task) = _queue.pop_front() {
296                return Some(task);
297            }
298        }
299
300        // Work stealing (simplified - would need access to other workers)
301        if config.work_stealing {
302            // Implementation would steal from other workers' queues
303        }
304
305        None
306    }
307
308    /// Try to submit a task to the local queue
309    #[allow(dead_code)]
310    fn try_submit_local(&self, task: Task) -> bool {
311        let mut queue = self.local_queue.lock().expect("Test: operation failed");
312        if queue.len()
313            < self
314                .local_queue
315                .lock()
316                .expect("Test: operation failed")
317                .capacity()
318        {
319            queue.push_back(task);
320            true
321        } else {
322            false
323        }
324    }
325
326    /// Get the current queue size
327    fn get_queue_size(&self) -> usize {
328        self.local_queue
329            .lock()
330            .expect("Test: operation failed")
331            .len()
332    }
333
334    /// Notify worker to shutdown
335    fn notify_shutdown(&self) {
336        let (lock, cvar) = &*self.shutdown_signal;
337        let mut shutdown = lock.lock().expect("Test: operation failed");
338        *shutdown = true;
339        cvar.notify_one();
340    }
341
342    /// Join the worker thread
343    fn join(mut self) -> Result<(), Box<dyn std::any::Any + Send>> {
344        if let Some(handle) = self.thread_handle.take() {
345            handle.join()
346        } else {
347            Ok(())
348        }
349    }
350}
351
352/// Task wrapper with execution tracking
353pub struct Task {
354    func: Box<dyn FnOnce() + Send + 'static>,
355    created_at: Instant,
356    priority: TaskPriority,
357}
358
359impl Task {
360    /// Create a new task with completion handle
361    pub fn new<F>(func: F) -> (Self, TaskHandle<()>)
362    where
363        F: FnOnce() + Send + 'static,
364    {
365        let (sender, receiver) = std::sync::mpsc::channel();
366
367        let task = Task {
368            func: Box::new(move || {
369                func();
370                let _ = sender.send(());
371            }),
372            created_at: Instant::now(),
373            priority: TaskPriority::Normal,
374        };
375
376        let handle = TaskHandle { receiver };
377        (task, handle)
378    }
379
380    /// Execute the task
381    fn execute(self) {
382        (self.func)();
383    }
384
385    /// Get task age
386    pub fn age(&self) -> Duration {
387        self.created_at.elapsed()
388    }
389
390    /// Set task priority
391    pub fn with_priority(mut self, priority: TaskPriority) -> Self {
392        self.priority = priority;
393        self
394    }
395}
396
/// Task priority levels.
///
/// The derived ordering follows declaration order, so
/// `Low < Normal < High < Critical`.
/// NOTE(review): none of the queues visible in this file reorder tasks by
/// priority; confirm where it is consumed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    /// Lowest level in the ordering.
    Low,
    /// Level assigned by `Task::new`.
    Normal,
    /// Ranks above `Normal`.
    High,
    /// Highest level in the ordering.
    Critical,
}
405
406/// Handle for waiting on task completion
407pub struct TaskHandle<T> {
408    receiver: std::sync::mpsc::Receiver<T>,
409}
410
411impl<T> TaskHandle<T> {
412    /// Wait for task completion
413    pub fn wait(self) -> Result<T, ThreadPoolError> {
414        self.receiver
415            .recv()
416            .map_err(|_| ThreadPoolError::ExecutionFailed)
417    }
418
419    /// Wait for task completion with timeout
420    pub fn wait_timeout(self, timeout: Duration) -> Result<T, ThreadPoolError> {
421        self.receiver
422            .recv_timeout(timeout)
423            .map_err(|_| ThreadPoolError::ExecutionFailed)
424    }
425
426    /// Check if task is complete without blocking
427    pub fn try_wait(&self) -> Result<Option<T>, ThreadPoolError> {
428        match self.receiver.try_recv() {
429            Ok(result) => Ok(Some(result)),
430            Err(std::sync::mpsc::TryRecvError::Empty) => Ok(None),
431            Err(std::sync::mpsc::TryRecvError::Disconnected) => {
432                Err(ThreadPoolError::ExecutionFailed)
433            }
434        }
435    }
436}
437
438/// Advanced statistics for the thread pool
439#[derive(Debug, Clone)]
440pub struct AdvancedThreadPoolStats {
441    /// Total tasks executed across all workers
442    pub total_tasks_executed: u64,
443    /// Total execution time
444    pub total_execution_time: Duration,
445    /// Number of work steals performed
446    pub work_steals: u64,
447    /// Load balancing efficiency (0.0 to 1.0)
448    pub load_balance_efficiency: f64,
449    /// Per-worker statistics
450    pub worker_stats: Vec<WorkerStats>,
451    /// Queue contention metrics
452    pub queue_contention: f64,
453}
454
455impl AdvancedThreadPoolStats {
456    fn new(_numworkers: usize) -> Self {
457        Self {
458            total_tasks_executed: 0,
459            total_execution_time: Duration::ZERO,
460            work_steals: 0,
461            load_balance_efficiency: 1.0,
462            worker_stats: (0.._numworkers).map(WorkerStats::new).collect(),
463            queue_contention: 0.0,
464        }
465    }
466
467    /// Calculate throughput (tasks per second)
468    pub fn throughput(&self) -> f64 {
469        if self.total_execution_time.is_zero() {
470            0.0
471        } else {
472            self.total_tasks_executed as f64 / self.total_execution_time.as_secs_f64()
473        }
474    }
475
476    /// Calculate average task latency
477    pub fn average_latency(&self) -> Duration {
478        if self.total_tasks_executed == 0 {
479            Duration::ZERO
480        } else {
481            self.total_execution_time / self.total_tasks_executed as u32
482        }
483    }
484
485    /// Calculate worker utilization
486    pub fn worker_utilization(&self) -> Vec<f64> {
487        let total_time = self.total_execution_time;
488        self.worker_stats
489            .iter()
490            .map(|stats| {
491                if total_time.is_zero() {
492                    0.0
493                } else {
494                    stats.total_time.as_secs_f64() / total_time.as_secs_f64()
495                }
496            })
497            .collect()
498    }
499
500    /// Calculate load balance efficiency
501    pub fn calculate_load_balance_efficiency(&self) -> f64 {
502        if self.worker_stats.len() <= 1 {
503            return 1.0;
504        }
505
506        let task_counts: Vec<u64> = self
507            .worker_stats
508            .iter()
509            .map(|stats| stats.tasks_completed)
510            .collect();
511
512        let total_tasks: u64 = task_counts.iter().sum();
513        if total_tasks == 0 {
514            return 1.0;
515        }
516
517        let average_tasks = total_tasks as f64 / task_counts.len() as f64;
518        let variance: f64 = task_counts
519            .iter()
520            .map(|&count| {
521                let diff = count as f64 - average_tasks;
522                diff * diff
523            })
524            .sum::<f64>()
525            / task_counts.len() as f64;
526
527        let std_dev = variance.sqrt();
528        let coefficient_of_variation = if average_tasks > 0.0 {
529            std_dev / average_tasks
530        } else {
531            0.0
532        };
533
534        // Convert to efficiency (lower variation = higher efficiency)
535        (1.0 - coefficient_of_variation.min(1.0)).max(0.0)
536    }
537}
538
/// NUMA-aware thread pool for systems with multiple memory nodes.
///
/// Holds one `AdvancedThreadPool` per node of the detected topology; tasks
/// are routed to a node's pool by `submit_numa`.
pub struct NumaAwareThreadPool {
    /// One pool per NUMA node, indexed by node id.
    pools: Vec<AdvancedThreadPool>,
    /// Topology the pools were created from; kept but not read in this file.
    #[allow(dead_code)]
    numa_topology: NumaTopology,
}
545
546impl NumaAwareThreadPool {
547    /// Create a NUMA-aware thread pool
548    pub fn new(config: ThreadPoolConfig) -> Self {
549        let topology = NumaTopology::detect();
550        let pools_per_node = config.num_threads / topology.num_nodes.max(1);
551
552        let mut pools = Vec::with_capacity(topology.num_nodes);
553
554        for _ in 0..topology.num_nodes {
555            let node_config = ThreadPoolConfig {
556                num_threads: pools_per_node,
557                ..config.clone()
558            };
559            pools.push(AdvancedThreadPool::new(node_config));
560        }
561
562        Self {
563            pools,
564            numa_topology: topology,
565        }
566    }
567
568    /// Submit a task to the appropriate NUMA node
569    pub fn submit_numa<F>(
570        &self,
571        task: F,
572        preferred_node: Option<usize>,
573    ) -> Result<TaskHandle<()>, ThreadPoolError>
574    where
575        F: FnOnce() + Send + 'static,
576    {
577        let _node = preferred_node
578            .unwrap_or_else(|| self.select_optimal_node())
579            .min(self.pools.len() - 1);
580
581        self.pools[_node].submit(task)
582    }
583
584    /// Select the optimal NUMA node for task placement
585    fn select_optimal_node(&self) -> usize {
586        // Simple load balancing - choose least loaded node
587        self.pools
588            .iter()
589            .enumerate()
590            .min_by_key(|(_, pool)| pool.get_stats().total_tasks_executed)
591            .map(|(id_, _)| id_)
592            .unwrap_or(0)
593    }
594}
595
/// NUMA topology information
#[derive(Debug, Clone)]
pub struct NumaTopology {
    /// Number of NUMA nodes
    pub num_nodes: usize,
    /// CPU cores per node
    pub cores_per_node: Vec<usize>,
    /// Memory per node (in bytes)
    pub memory_per_node: Vec<usize>,
}

impl NumaTopology {
    /// Detect NUMA topology (simplified implementation).
    ///
    /// No system query is made: the result is always a single node that owns
    /// every CPU reported by `std::thread::available_parallelism` (4 if that
    /// call fails) and a nominal 8 GiB of memory.
    fn detect() -> Self {
        /// Core count assumed when the parallelism query fails.
        const FALLBACK_CPUS: usize = 4;
        /// Nominal memory per node; kept as `u64` so the constant itself cannot
        /// overflow on targets where `usize` is 32 bits.
        const DEFAULT_NODE_MEMORY_BYTES: u64 = 8 * 1024 * 1024 * 1024;

        let num_cpus = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(FALLBACK_CPUS);

        // Saturate instead of overflowing when 8 GiB does not fit in `usize`.
        let node_memory = if DEFAULT_NODE_MEMORY_BYTES > usize::MAX as u64 {
            usize::MAX
        } else {
            DEFAULT_NODE_MEMORY_BYTES as usize
        };

        Self {
            num_nodes: 1,
            cores_per_node: vec![num_cpus],
            memory_per_node: vec![node_memory],
        }
    }
}
623
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A single submitted task runs exactly once and its handle resolves.
    #[test]
    fn test_advanced_thread_pool() {
        let config = ThreadPoolConfig {
            num_threads: 2,
            work_stealing: true,
            ..Default::default()
        };

        let pool = AdvancedThreadPool::new(config);
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);

        let handle = pool
            .submit(move || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .expect("task submission should succeed");

        handle.wait().expect("task should run to completion");
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // Stop the workers so they do not keep polling during other tests.
        pool.shutdown().expect("pool should shut down cleanly");
    }

    /// Waiting with a timeout shorter than the task reports an error.
    #[test]
    fn test_task_handle_timeout() {
        let config = ThreadPoolConfig {
            num_threads: 1,
            ..Default::default()
        };

        let pool = AdvancedThreadPool::new(config);

        let handle = pool
            .submit(|| {
                std::thread::sleep(Duration::from_millis(200));
            })
            .expect("task submission should succeed");

        // Should timeout
        let result = handle.wait_timeout(Duration::from_millis(50));
        assert!(result.is_err());

        // Joins the worker once the sleeping task has finished.
        pool.shutdown().expect("pool should shut down cleanly");
    }

    /// Every task of a batch is executed.
    #[test]
    fn test_batch_submission() {
        let config = ThreadPoolConfig {
            num_threads: 2,
            ..Default::default()
        };

        let pool = AdvancedThreadPool::new(config);
        let counter = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..5)
            .map(|_| {
                let counter_clone = Arc::clone(&counter);
                move || {
                    counter_clone.fetch_add(1, Ordering::SeqCst);
                }
            })
            .collect();

        let handles = pool
            .submit_batch(tasks)
            .expect("batch submission should succeed");

        for handle in handles {
            handle.wait().expect("batch task should run to completion");
        }

        assert_eq!(counter.load(Ordering::SeqCst), 5);

        pool.shutdown().expect("pool should shut down cleanly");
    }

    /// A fresh pool reports zero executed tasks and one stats entry per worker.
    #[test]
    fn test_thread_pool_stats() {
        let config = ThreadPoolConfig {
            num_threads: 2,
            ..Default::default()
        };

        let pool = AdvancedThreadPool::new(config);
        let stats = pool.get_stats();

        assert_eq!(stats.total_tasks_executed, 0);
        assert_eq!(stats.worker_stats.len(), 2);

        pool.shutdown().expect("pool should shut down cleanly");
    }

    /// A task submitted to an explicit node is executed.
    #[test]
    fn test_numa_aware_thread_pool() {
        let config = ThreadPoolConfig {
            num_threads: 4,
            ..Default::default()
        };

        let numa_pool = NumaAwareThreadPool::new(config);
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);

        let handle = numa_pool
            .submit_numa(
                move || {
                    counter_clone.fetch_add(1, Ordering::SeqCst);
                },
                Some(0),
            )
            .expect("NUMA task submission should succeed");

        handle.wait().expect("NUMA task should run to completion");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// `with_priority` overrides the default priority.
    #[test]
    fn test_task_priority() {
        let task = Task::new(|| {}).0.with_priority(TaskPriority::High);
        assert_eq!(task.priority, TaskPriority::High);
    }

    /// The detected topology always describes at least one node.
    #[test]
    fn test_numa_topology() {
        let topology = NumaTopology::detect();
        assert!(topology.num_nodes > 0);
        assert!(!topology.cores_per_node.is_empty());
        assert!(!topology.memory_per_node.is_empty());
    }
}