Skip to main content

scirs2_autograd/parallel/
mod.rs

1//! Parallel processing and thread pool optimizations
2//!
3//! This module provides thread pool management and parallel execution
4//! optimizations for tensor operations, particularly targeting CPU performance
5//! improvements for large-scale computations.
6
7use std::sync::mpsc::{channel, Receiver, Sender};
8use std::sync::{Arc, Mutex};
9use std::thread::{self, JoinHandle};
10use std::time::{Duration, Instant};
11
12pub mod parallel_ops;
13pub mod thread_pool;
14pub mod work_stealing;
15
16/// Global thread pool manager
17static GLOBAL_THREAD_POOL: std::sync::LazyLock<Arc<Mutex<Option<ThreadPool>>>> =
18    std::sync::LazyLock::new(|| Arc::new(Mutex::new(None)));
19
/// Tunable settings for building a thread pool.
///
/// NOTE(review): within this file only `num_threads` changes the shape of the pool;
/// `priority` and `cpu_affinity` are handed to worker hooks that are currently no-ops, and the
/// remaining fields are carried along without being enforced.
#[derive(Debug, Clone)]
pub struct ThreadPoolConfig {
    /// How many worker threads to spawn.
    pub num_threads: usize,
    /// Intended upper bound on queued jobs per thread.
    pub max_queue_size: usize,
    /// Whether idle workers may steal work from busy ones.
    pub work_stealing: bool,
    /// Scheduling priority requested for worker threads.
    pub priority: ThreadPriority,
    /// How worker threads should be placed on CPU cores.
    pub cpu_affinity: CpuAffinity,
    /// How long an idle worker may wait before being retired.
    pub idle_timeout: Duration,
    /// Whether scheduling may adapt at runtime.
    pub adaptive_scheduling: bool,
}

impl Default for ThreadPoolConfig {
    /// One worker per hardware thread reported by `available_parallelism` (4 if that query
    /// fails), a queue bound of 1000, work stealing and adaptive scheduling on, normal priority,
    /// automatic affinity and a 60-second idle timeout.
    fn default() -> Self {
        let detected_threads =
            std::thread::available_parallelism().map_or(4, std::num::NonZeroUsize::get);

        Self {
            num_threads: detected_threads,
            max_queue_size: 1000,
            work_stealing: true,
            priority: ThreadPriority::Normal,
            cpu_affinity: CpuAffinity::Auto,
            idle_timeout: Duration::from_secs(60),
            adaptive_scheduling: true,
        }
    }
}

/// Requested scheduling priority for worker threads, from lowest to highest.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ThreadPriority {
    /// Below-normal priority.
    Low,
    /// The platform's default priority.
    Normal,
    /// Above-normal priority.
    High,
    /// Highest priority level offered by this enum.
    Critical,
}

/// How worker threads should be placed on CPU cores.
#[derive(Debug, Clone)]
pub enum CpuAffinity {
    /// Let the operating system decide.
    Auto,
    /// Restrict workers to the listed core indices.
    Cores(Vec<usize>),
    /// NUMA-aware placement.
    Numa,
}
74
75/// Thread pool for parallel execution
76pub struct ThreadPool {
77    workers: Vec<Worker>,
78    sender: Sender<Job>,
79    config: ThreadPoolConfig,
80    stats: Arc<Mutex<ThreadPoolStats>>,
81}
82
83impl ThreadPool {
84    /// Create a new thread pool with default configuration
85    pub fn new() -> Self {
86        Self::with_config(ThreadPoolConfig::default())
87    }
88
89    /// Create a new thread pool with custom configuration
90    pub fn with_config(config: ThreadPoolConfig) -> Self {
91        let (sender, receiver) = channel();
92        let receiver = Arc::new(Mutex::new(receiver));
93
94        // Initialize stats with proper worker_stats vector
95        let mut stats_data = ThreadPoolStats::new();
96        stats_data.worker_stats = (0..config.num_threads).map(WorkerStats::new).collect();
97        let stats = Arc::new(Mutex::new(stats_data));
98
99        let mut workers = Vec::with_capacity(config.num_threads);
100
101        for id in 0..config.num_threads {
102            workers.push(Worker::new(
103                id,
104                Arc::clone(&receiver),
105                Arc::clone(&stats),
106                config.clone(),
107            ));
108        }
109
110        ThreadPool {
111            workers,
112            sender,
113            config,
114            stats,
115        }
116    }
117
118    /// Execute a closure on the thread pool
119    pub fn execute<F>(&self, f: F) -> Result<(), ThreadPoolError>
120    where
121        F: FnOnce() + Send + 'static,
122    {
123        let job = Box::new(f);
124        self.sender
125            .send(job)
126            .map_err(|_| ThreadPoolError::QueueFull)
127    }
128
129    /// Execute a closure and wait for completion
130    pub fn execute_and_wait<F, R>(&self, f: F) -> Result<R, ThreadPoolError>
131    where
132        F: FnOnce() -> R + Send + 'static,
133        R: Send + 'static,
134    {
135        let (tx, rx) = std::sync::mpsc::channel();
136
137        self.execute(move || {
138            let result = f();
139            let _ = tx.send(result);
140        })?;
141
142        rx.recv().map_err(|_| ThreadPoolError::ExecutionFailed)
143    }
144
145    /// Execute multiple tasks in parallel
146    pub fn execute_parallel<F, I>(&self, tasks: I) -> Result<Vec<()>, ThreadPoolError>
147    where
148        F: FnOnce() + Send + 'static,
149        I: IntoIterator<Item = F>,
150    {
151        let tasks: Vec<F> = tasks.into_iter().collect();
152        let mut handles = Vec::with_capacity(tasks.len());
153
154        for task in tasks {
155            let (tx, rx) = std::sync::mpsc::channel();
156
157            self.execute(move || {
158                task();
159                let _ = tx.send(());
160            })?;
161
162            handles.push(rx);
163        }
164
165        // Wait for all tasks to complete
166        let num_handles = handles.len();
167        for rx in handles {
168            rx.recv().map_err(|_| ThreadPoolError::ExecutionFailed)?;
169        }
170
171        Ok(vec![(); num_handles])
172    }
173
174    /// Get thread pool statistics
175    pub fn get_stats(&self) -> ThreadPoolStats {
176        self.stats
177            .lock()
178            .unwrap_or_else(|poisoned| poisoned.into_inner())
179            .clone()
180    }
181
182    /// Get current configuration
183    pub fn get_config(&self) -> &ThreadPoolConfig {
184        &self.config
185    }
186
187    /// Resize the thread pool
188    pub fn resize(&mut self, newsize: usize) -> Result<(), ThreadPoolError> {
189        if newsize == 0 {
190            return Err(ThreadPoolError::InvalidConfiguration(
191                "Thread pool size cannot be zero".into(),
192            ));
193        }
194
195        // Implementation would recreate the thread pool with new size
196        // For now, just update the config
197        self.config.num_threads = newsize;
198        Ok(())
199    }
200
201    /// Shutdown the thread pool gracefully
202    pub fn shutdown(self) -> Result<(), ThreadPoolError> {
203        // Drop the sender to signal shutdown
204        drop(self.sender);
205
206        // Wait for all workers to finish
207        for worker in self.workers {
208            if let Some(thread) = worker.thread {
209                thread.join().map_err(|_| ThreadPoolError::ShutdownFailed)?;
210            }
211        }
212
213        Ok(())
214    }
215}
216
217impl Default for ThreadPool {
218    fn default() -> Self {
219        Self::new()
220    }
221}
222
223/// Worker thread in the thread pool
224struct Worker {
225    #[allow(dead_code)]
226    id: usize,
227    thread: Option<JoinHandle<()>>,
228}
229
230impl Worker {
231    fn new(
232        id: usize,
233        receiver: Arc<Mutex<Receiver<Job>>>,
234        stats: Arc<Mutex<ThreadPoolStats>>,
235        config: ThreadPoolConfig,
236    ) -> Worker {
237        let thread = thread::spawn(move || {
238            // Set thread priority if supported
239            Self::set_thread_priority(config.priority);
240
241            // Set CPU affinity if specified
242            Self::set_cpu_affinity(id, &config.cpu_affinity);
243
244            loop {
245                let job = {
246                    let receiver = receiver.lock().expect("Test: operation failed");
247                    receiver.recv()
248                };
249
250                match job {
251                    Ok(job) => {
252                        let start = Instant::now();
253                        job();
254                        let duration = start.elapsed();
255
256                        // Update statistics
257                        {
258                            if let Ok(mut stats) = stats.lock() {
259                                stats.tasks_completed += 1;
260                                stats.total_execution_time += duration;
261                                if id < stats.worker_stats.len() {
262                                    stats.worker_stats[id].tasks_completed += 1;
263                                    stats.worker_stats[id].total_time += duration;
264                                }
265                            }
266                        }
267                    }
268                    Err(_) => {
269                        // Channel closed, shutdown
270                        break;
271                    }
272                }
273            }
274        });
275
276        Worker {
277            id,
278            thread: Some(thread),
279        }
280    }
281
282    fn set_thread_priority(priority: ThreadPriority) {
283        // Platform-specific thread _priority setting would go here
284        // For now, this is a no-op
285    }
286
287    fn set_cpu_affinity(_worker_id: usize, affinity: &CpuAffinity) {
288        // Platform-specific CPU _affinity setting would go here
289        // For now, this is a no-op
290    }
291}
292
/// Unit of work accepted by the pool: a boxed closure that can be sent to another thread and
/// is run at most once. (`Box<dyn Trait>` already defaults to a `'static` object lifetime.)
type Job = Box<dyn FnOnce() + Send>;
295
/// Aggregate counters describing the work a thread pool has performed.
///
/// NOTE(review): in this file only `tasks_completed`, `total_execution_time` and the
/// `worker_stats` entries are written after jobs run; `current_queue_size`, `max_queue_size`,
/// `active_threads` and `load_balance_ratio` keep their initial values.
#[derive(Debug, Clone)]
pub struct ThreadPoolStats {
    /// Total tasks completed
    pub tasks_completed: u64,
    /// Total execution time summed across all threads (busy time, not wall-clock time)
    pub total_execution_time: Duration,
    /// Current queue size
    pub current_queue_size: usize,
    /// Maximum queue size reached
    pub max_queue_size: usize,
    /// Number of active threads
    pub active_threads: usize,
    /// Per-worker statistics, indexed by worker id
    pub worker_stats: Vec<WorkerStats>,
    /// Load balancing efficiency
    pub load_balance_ratio: f32,
}

impl ThreadPoolStats {
    /// Empty record: all counters zero, no per-worker entries, `load_balance_ratio` of 1.0.
    fn new() -> Self {
        Self {
            tasks_completed: 0,
            total_execution_time: Duration::ZERO,
            current_queue_size: 0,
            max_queue_size: 0,
            active_threads: 0,
            worker_stats: Vec::new(),
            load_balance_ratio: 1.0,
        }
    }

    /// Mean execution time per completed task, or `Duration::ZERO` when none have completed.
    pub fn average_execution_time(&self) -> Duration {
        if self.tasks_completed == 0 {
            return Duration::ZERO;
        }
        match u32::try_from(self.tasks_completed) {
            // Exact `Duration / u32` division covers every realistic count.
            Ok(count) => self.total_execution_time / count,
            // Counts past `u32::MAX` cannot be narrowed (an `as` cast would truncate, possibly
            // to zero), so divide the nanosecond total instead.
            Err(_) => {
                let nanos =
                    self.total_execution_time.as_nanos() / u128::from(self.tasks_completed);
                Duration::from_nanos(u64::try_from(nanos).unwrap_or(u64::MAX))
            }
        }
    }

    /// Completed tasks per second of summed busy time (not wall-clock throughput).
    /// Returns 0.0 when no time has been recorded.
    pub fn tasks_per_second(&self) -> f64 {
        if self.total_execution_time.is_zero() {
            0.0
        } else {
            self.tasks_completed as f64 / self.total_execution_time.as_secs_f64()
        }
    }

    /// Mean worker busy time divided by the busiest worker's busy time.
    ///
    /// 1.0 means perfectly even load. Returns 0.0 when there are no workers or none has
    /// recorded any time.
    pub fn thread_utilization(&self) -> f64 {
        if self.worker_stats.is_empty() {
            return 0.0;
        }

        let total_time: Duration = self.worker_stats.iter().map(|stats| stats.total_time).sum();

        let max_time = self
            .worker_stats
            .iter()
            .map(|stats| stats.total_time)
            .max()
            .unwrap_or(Duration::ZERO);

        if max_time.is_zero() {
            0.0
        } else {
            total_time.as_secs_f64() / (max_time.as_secs_f64() * self.worker_stats.len() as f64)
        }
    }
}

/// Statistics for individual worker threads
#[derive(Debug, Clone)]
pub struct WorkerStats {
    /// Worker ID
    pub worker_id: usize,
    /// Tasks completed by this worker
    pub tasks_completed: u64,
    /// Total execution time for this worker
    pub total_time: Duration,
    /// Current queue size for this worker
    pub queue_size: usize,
    /// Last activity timestamp, if one has been recorded
    pub last_activity: Option<Instant>,
}

impl WorkerStats {
    /// Zeroed statistics for the worker with the given id.
    fn new(worker_id: usize) -> Self {
        Self {
            worker_id,
            tasks_completed: 0,
            total_time: Duration::ZERO,
            queue_size: 0,
            last_activity: None,
        }
    }
}
395
396/// Parallel execution scheduler for tensor operations
397pub struct ParallelScheduler {
398    thread_pool: Arc<ThreadPool>,
399    config: SchedulerConfig,
400}
401
402impl ParallelScheduler {
403    /// Create a new parallel scheduler
404    pub fn new() -> Self {
405        let thread_pool = Arc::new(ThreadPool::new());
406        Self {
407            thread_pool,
408            config: SchedulerConfig::default(),
409        }
410    }
411
412    /// Create a scheduler with custom thread pool
413    pub fn with_thread_pool(_threadpool: Arc<ThreadPool>) -> Self {
414        Self {
415            thread_pool: _threadpool,
416            config: SchedulerConfig::default(),
417        }
418    }
419
420    /// Schedule a parallel tensor operation
421    pub fn schedule_operation<F, R>(&self, operation: F) -> Result<R, ThreadPoolError>
422    where
423        F: FnOnce() -> R + Send + 'static,
424        R: Send + 'static,
425    {
426        if ParallelScheduler::should_parallelize(&operation) {
427            self.thread_pool.execute_and_wait(operation)
428        } else {
429            // Execute on current thread for small operations
430            Ok(operation())
431        }
432    }
433
434    /// Schedule multiple parallel operations
435    pub fn schedule_batch<F>(&self, operations: Vec<F>) -> Result<Vec<()>, ThreadPoolError>
436    where
437        F: FnOnce() + Send + 'static,
438    {
439        if operations.len() <= 1 || !self.config.enable_batching {
440            // Execute sequentially for small batches
441            for op in operations {
442                op();
443            }
444            Ok(vec![])
445        } else {
446            self.thread_pool.execute_parallel(operations)
447        }
448    }
449
450    /// Check if an operation should be parallelized
451    fn should_parallelize<F, R>(operation: &F) -> bool
452    where
453        F: FnOnce() -> R + Send + 'static,
454        R: Send + 'static,
455    {
456        // Heuristics for deciding whether to parallelize:
457        // - Operation complexity
458        // - Data size
459        // - Current thread pool load
460        // - Overhead considerations
461
462        true // Simplified decision
463    }
464
465    /// Get scheduler statistics
466    pub fn get_stats(&self) -> ThreadPoolStats {
467        self.thread_pool.get_stats()
468    }
469}
470
471impl Default for ParallelScheduler {
472    fn default() -> Self {
473        Self::new()
474    }
475}
476
/// Tuning knobs for a `ParallelScheduler`.
///
/// NOTE(review): within this file only `enable_batching` is read (by
/// `ParallelScheduler::schedule_batch`); the other fields are not consulted yet.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Whether multi-operation batches may be dispatched to the pool.
    pub enable_batching: bool,
    /// Minimum operation size considered worth parallelizing.
    pub min_parallel_size: usize,
    /// Largest number of operations to dispatch as one batch.
    pub max_batch_size: usize,
    /// Strategy for distributing work across workers.
    pub load_balancing: LoadBalancingStrategy,
}

impl Default for SchedulerConfig {
    /// Batching enabled, `min_parallel_size` of 1000, `max_batch_size` of 100, round-robin
    /// load balancing.
    fn default() -> Self {
        let load_balancing = LoadBalancingStrategy::RoundRobin;
        let (min_parallel_size, max_batch_size) = (1000, 100);
        Self {
            load_balancing,
            max_batch_size,
            min_parallel_size,
            enable_batching: true,
        }
    }
}

/// Ways a scheduler could distribute work across workers.
#[derive(Debug, Clone, Copy)]
pub enum LoadBalancingStrategy {
    /// Hand jobs to workers in turn.
    RoundRobin,
    /// Hand each job to the worker with the least outstanding work.
    LeastLoaded,
    /// Let idle workers take jobs from busy ones.
    WorkStealing,
    /// Choose based on the characteristics of each operation.
    Adaptive,
}
513
514/// Errors that can occur in thread pool operations
515#[derive(Debug, thiserror::Error)]
516pub enum ThreadPoolError {
517    #[error("Thread pool queue is full")]
518    QueueFull,
519    #[error("Task execution failed")]
520    ExecutionFailed,
521    #[error("Thread pool shutdown failed")]
522    ShutdownFailed,
523    #[error("Invalid configuration: {0}")]
524    InvalidConfiguration(String),
525    #[error("Worker thread panicked")]
526    WorkerPanic,
527}
528
529/// Public API functions for thread pool management
530/// Initialize the global thread pool with default configuration
531#[allow(dead_code)]
532pub fn init_thread_pool() -> Result<(), ThreadPoolError> {
533    let mut pool = GLOBAL_THREAD_POOL.lock().expect("Test: operation failed");
534    if pool.is_none() {
535        *pool = Some(ThreadPool::new());
536    }
537    Ok(())
538}
539
540/// Initialize the global thread pool with custom configuration
541#[allow(dead_code)]
542pub fn init_thread_pool_with_config(config: ThreadPoolConfig) -> Result<(), ThreadPoolError> {
543    let mut pool = GLOBAL_THREAD_POOL.lock().expect("Test: operation failed");
544    *pool = Some(ThreadPool::with_config(config));
545    Ok(())
546}
547
548/// Execute a task on the global thread pool
549#[allow(dead_code)]
550pub fn execute_global<F>(f: F) -> Result<(), ThreadPoolError>
551where
552    F: FnOnce() + Send + 'static,
553{
554    let pool = GLOBAL_THREAD_POOL.lock().expect("Test: operation failed");
555    if let Some(ref pool) = *pool {
556        pool.execute(f)
557    } else {
558        Err(ThreadPoolError::InvalidConfiguration(
559            "Thread pool not initialized".into(),
560        ))
561    }
562}
563
564/// Execute a task and wait for completion on the global thread pool
565#[allow(dead_code)]
566pub fn execute_and_wait_global<F, R>(f: F) -> Result<R, ThreadPoolError>
567where
568    F: FnOnce() -> R + Send + 'static,
569    R: Send + 'static,
570{
571    let pool = GLOBAL_THREAD_POOL.lock().expect("Test: operation failed");
572    if let Some(ref pool) = *pool {
573        pool.execute_and_wait(f)
574    } else {
575        Err(ThreadPoolError::InvalidConfiguration(
576            "Thread pool not initialized".into(),
577        ))
578    }
579}
580
581/// Get global thread pool statistics
582#[allow(dead_code)]
583pub fn get_global_thread_pool_stats() -> Option<ThreadPoolStats> {
584    let pool = GLOBAL_THREAD_POOL.lock().expect("Test: operation failed");
585    pool.as_ref().map(|p| p.get_stats())
586}
587
588/// Shutdown the global thread pool
589#[allow(dead_code)]
590pub fn shutdown_global_thread_pool() -> Result<(), ThreadPoolError> {
591    let mut pool = GLOBAL_THREAD_POOL.lock().expect("Test: operation failed");
592    if let Some(pool) = pool.take() {
593        pool.shutdown()
594    } else {
595        Ok(())
596    }
597}
598
599/// Set the number of threads for the global thread pool
600#[allow(dead_code)]
601pub fn set_global_thread_count(count: usize) -> Result<(), ThreadPoolError> {
602    let config = ThreadPoolConfig {
603        num_threads: count,
604        ..Default::default()
605    };
606    init_thread_pool_with_config(config)
607}
608
609/// Get the current number of threads in the global thread pool
610#[allow(dead_code)]
611pub fn get_global_thread_count() -> usize {
612    let pool = GLOBAL_THREAD_POOL.lock().expect("Test: operation failed");
613    pool.as_ref()
614        .map(|p| p.get_config().num_threads)
615        .unwrap_or(0)
616}
617
618/// Check if the global thread pool is initialized
619#[allow(dead_code)]
620pub fn is_thread_pool_initialized() -> bool {
621    let pool = GLOBAL_THREAD_POOL.lock().expect("Test: operation failed");
622    pool.is_some()
623}
624
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new();
        assert!(pool.get_config().num_threads > 0);

        let config = ThreadPoolConfig {
            num_threads: 2,
            work_stealing: false,
            ..Default::default()
        };
        let custom_pool = ThreadPool::with_config(config);
        assert_eq!(custom_pool.get_config().num_threads, 2);
        assert!(!custom_pool.get_config().work_stealing);
    }

    #[test]
    fn test_thread_pool_execution() {
        let pool = ThreadPool::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        // Signal completion explicitly instead of sleeping and hoping the task has run.
        let (done_tx, done_rx) = std::sync::mpsc::channel();

        pool.execute(move || {
            counter_clone.fetch_add(1, Ordering::SeqCst);
            let _ = done_tx.send(());
        })
        .expect("Test: thread spawn failed");

        done_rx
            .recv_timeout(Duration::from_secs(10))
            .expect("Test: task did not complete in time");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_thread_pool_execute_and_wait() {
        let pool = ThreadPool::new();

        let result = pool
            .execute_and_wait(|| 42)
            .expect("Test: operation failed");

        assert_eq!(result, 42);
    }

    #[test]
    fn test_thread_pool_parallel_execution() {
        let pool = ThreadPool::new();
        let counter = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..5)
            .map(|_| {
                let counter_clone = Arc::clone(&counter);
                move || {
                    counter_clone.fetch_add(1, Ordering::SeqCst);
                }
            })
            .collect();

        pool.execute_parallel(tasks)
            .expect("Test: operation failed");
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    #[test]
    fn test_thread_pool_stats() {
        let pool = ThreadPool::new();
        let stats = pool.get_stats();

        // Initially no tasks completed
        assert_eq!(stats.tasks_completed, 0);
        assert_eq!(stats.average_execution_time(), Duration::ZERO);
    }

    #[test]
    fn test_parallel_scheduler() {
        let scheduler = ParallelScheduler::new();

        let result = scheduler
            .schedule_operation(|| 100)
            .expect("Test: operation failed");

        assert_eq!(result, 100);
    }

    #[test]
    fn test_global_thread_pool() {
        // Clean shutdown first in case of previous test failures
        let _ = shutdown_global_thread_pool();

        // Initialize fresh thread pool
        init_thread_pool().expect("Test: operation failed");
        assert!(is_thread_pool_initialized());

        // Test execute and wait (more reliable than async execute)
        let result = execute_and_wait_global(|| 42);
        assert!(result.is_ok());
        assert_eq!(result.expect("Test: result failed"), 42);

        // Stats are available while the pool is initialized
        let stats = get_global_thread_pool_stats();
        assert!(stats.is_some());

        // Get thread count
        let thread_count = get_global_thread_count();
        assert!(thread_count > 0);

        // Shutdown
        shutdown_global_thread_pool().expect("Test: operation failed");
        assert!(!is_thread_pool_initialized());
    }

    #[test]
    fn test_thread_pool_config() {
        let config = ThreadPoolConfig::default();
        assert!(config.num_threads > 0);
        assert!(config.work_stealing);
        assert_eq!(config.priority, ThreadPriority::Normal);

        let custom_config = ThreadPoolConfig {
            num_threads: 8,
            max_queue_size: 500,
            work_stealing: false,
            priority: ThreadPriority::High,
            cpu_affinity: CpuAffinity::Cores(vec![0, 1, 2, 3]),
            idle_timeout: Duration::from_secs(30),
            adaptive_scheduling: false,
        };

        assert_eq!(custom_config.num_threads, 8);
        assert_eq!(custom_config.max_queue_size, 500);
        assert!(!custom_config.work_stealing);
        assert_eq!(custom_config.priority, ThreadPriority::High);
    }

    #[test]
    fn test_scheduler_config() {
        let config = SchedulerConfig::default();
        assert!(config.enable_batching);
        assert_eq!(config.min_parallel_size, 1000);
        assert_eq!(config.max_batch_size, 100);
        assert!(matches!(
            config.load_balancing,
            LoadBalancingStrategy::RoundRobin
        ));
    }

    #[test]
    fn test_worker_stats() {
        let stats = WorkerStats::new(0);
        assert_eq!(stats.worker_id, 0);
        assert_eq!(stats.tasks_completed, 0);
        assert_eq!(stats.total_time, Duration::ZERO);
        assert_eq!(stats.queue_size, 0);
        assert!(stats.last_activity.is_none());
    }

    #[test]
    fn test_thread_pool_stats_calculations() {
        let mut stats = ThreadPoolStats::new();
        stats.tasks_completed = 10;
        stats.total_execution_time = Duration::from_secs(5);

        assert_eq!(stats.average_execution_time(), Duration::from_millis(500));
        assert_eq!(stats.tasks_per_second(), 2.0);
    }
}