Skip to main content

rustalign_concurrency/
thread_pool.rs

1//! Thread pool for parallel processing
2//!
3//! This matches the C++ thread_pool class and provides
4//! a simple thread pool with work queue and future support.
5
6use std::sync::{Arc, Mutex, mpsc};
7use std::thread::{self, JoinHandle, sleep};
8use std::time::Duration;
9
10use super::WorkQueue;
11
12/// Thread pool for executing work in parallel
13///
14/// This matches the C++ thread_pool class with similar functionality:
15/// - Fixed number of worker threads
16/// - Work queue
17/// - Parallel for loop support
18pub struct ThreadPool {
19    done: Arc<Mutex<bool>>,
20    nthreads: usize,
21    queue: WorkQueue<Box<dyn FnOnce() + Send + 'static>>,
22    threads: Vec<JoinHandle<()>>,
23}
24
25impl ThreadPool {
26    /// Create a new thread pool with the specified number of threads
27    pub fn new(nthreads: usize) -> Self {
28        assert!(nthreads > 0, "Thread pool must have at least one thread");
29
30        let done = Arc::new(Mutex::new(false));
31        let queue: WorkQueue<Box<dyn FnOnce() + Send + 'static>> = WorkQueue::new();
32        let mut threads = Vec::with_capacity(nthreads);
33
34        for _ in 0..nthreads {
35            let done = Arc::clone(&done);
36            let queue = queue.clone();
37
38            let handle = thread::spawn(move || {
39                while !*done.lock().unwrap() {
40                    if let Some(task) = queue.try_pop() {
41                        task();
42                    } else {
43                        // Small sleep to avoid busy-waiting
44                        sleep(Duration::from_millis(1));
45                    }
46                }
47            });
48
49            threads.push(handle);
50        }
51
52        Self {
53            done,
54            nthreads,
55            queue,
56            threads,
57        }
58    }
59
60    /// Submit a function to be executed by the thread pool
61    ///
62    /// Returns a receiver that can be used to get the result
63    pub fn submit<F, R, Args>(&self, f: F, args: Args) -> mpsc::Receiver<R>
64    where
65        F: FnOnce(Args) -> R + Send + 'static,
66        R: Send + 'static,
67        Args: Send + 'static,
68    {
69        let (tx, rx) = mpsc::channel();
70
71        let task = Box::new(move || {
72            let result = f(args);
73            let _ = tx.send(result);
74        });
75
76        self.queue.push(task);
77
78        rx
79    }
80
81    /// Execute a parallel for loop over a range
82    ///
83    /// This matches the C++ thread_pool::parallel_for method
84    pub fn parallel_for<F>(&self, start: usize, end: usize, stride: usize, f: F)
85    where
86        F: Fn(usize, usize, usize) + Send + Sync + Clone + 'static,
87    {
88        let range = end - start;
89        let block_size = range / self.nthreads;
90
91        if block_size == 0 {
92            // Handle case where range is smaller than number of threads
93            f(start, end, stride);
94            return;
95        }
96
97        let mut block_start = start;
98        let mut block_end = block_start + block_size;
99
100        while block_start < end {
101            if block_end >= end {
102                block_end = end;
103            }
104
105            // Execute this block
106            f(block_start, block_end, stride);
107
108            block_start = block_end;
109            block_end += block_size;
110        }
111    }
112
113    /// Get the number of threads in the pool
114    pub fn size(&self) -> usize {
115        self.nthreads
116    }
117
118    /// Check if the pool has been shut down
119    pub fn is_done(&self) -> bool {
120        *self.done.lock().unwrap()
121    }
122}
123
124impl Drop for ThreadPool {
125    fn drop(&mut self) {
126        // Signal threads to stop
127        *self.done.lock().unwrap() = true;
128
129        // Join all threads
130        for thread in self.threads.drain(..) {
131            let _ = thread.join();
132        }
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// A freshly built pool reports its size and is not shut down.
    #[test]
    fn test_thread_pool_basic() {
        let two_workers = ThreadPool::new(2);

        assert_eq!(two_workers.size(), 2);
        assert!(!two_workers.is_done());
    }

    /// A submitted closure runs and its result arrives on the receiver.
    #[test]
    fn test_thread_pool_submit() {
        let workers = ThreadPool::new(2);

        let doubled = workers.submit(|x: usize| x * 2, 21).recv().unwrap();

        assert_eq!(doubled, 42);
    }

    /// Every index of the range is visited exactly once across all blocks.
    #[test]
    fn test_thread_pool_parallel_for() {
        let workers = ThreadPool::new(4);
        let visited = Arc::new(Mutex::new(0usize));

        let tally = Arc::clone(&visited);
        workers.parallel_for(0, 100, 1, move |lo, hi, _step| {
            // One increment per index in the block.
            let mut total = tally.lock().unwrap();
            *total += (lo..hi).count();
        });

        // Check that work was done
        let seen = *visited.lock().unwrap();
        assert_eq!(seen, 100);
    }
}