rustalign_concurrency/
thread_pool.rs1use std::sync::{Arc, Mutex, mpsc};
7use std::thread::{self, JoinHandle, sleep};
8use std::time::Duration;
9
10use super::WorkQueue;
11
12pub struct ThreadPool {
19 done: Arc<Mutex<bool>>,
20 nthreads: usize,
21 queue: WorkQueue<Box<dyn FnOnce() + Send + 'static>>,
22 threads: Vec<JoinHandle<()>>,
23}
24
25impl ThreadPool {
26 pub fn new(nthreads: usize) -> Self {
28 assert!(nthreads > 0, "Thread pool must have at least one thread");
29
30 let done = Arc::new(Mutex::new(false));
31 let queue: WorkQueue<Box<dyn FnOnce() + Send + 'static>> = WorkQueue::new();
32 let mut threads = Vec::with_capacity(nthreads);
33
34 for _ in 0..nthreads {
35 let done = Arc::clone(&done);
36 let queue = queue.clone();
37
38 let handle = thread::spawn(move || {
39 while !*done.lock().unwrap() {
40 if let Some(task) = queue.try_pop() {
41 task();
42 } else {
43 sleep(Duration::from_millis(1));
45 }
46 }
47 });
48
49 threads.push(handle);
50 }
51
52 Self {
53 done,
54 nthreads,
55 queue,
56 threads,
57 }
58 }
59
60 pub fn submit<F, R, Args>(&self, f: F, args: Args) -> mpsc::Receiver<R>
64 where
65 F: FnOnce(Args) -> R + Send + 'static,
66 R: Send + 'static,
67 Args: Send + 'static,
68 {
69 let (tx, rx) = mpsc::channel();
70
71 let task = Box::new(move || {
72 let result = f(args);
73 let _ = tx.send(result);
74 });
75
76 self.queue.push(task);
77
78 rx
79 }
80
81 pub fn parallel_for<F>(&self, start: usize, end: usize, stride: usize, f: F)
85 where
86 F: Fn(usize, usize, usize) + Send + Sync + Clone + 'static,
87 {
88 let range = end - start;
89 let block_size = range / self.nthreads;
90
91 if block_size == 0 {
92 f(start, end, stride);
94 return;
95 }
96
97 let mut block_start = start;
98 let mut block_end = block_start + block_size;
99
100 while block_start < end {
101 if block_end >= end {
102 block_end = end;
103 }
104
105 f(block_start, block_end, stride);
107
108 block_start = block_end;
109 block_end += block_size;
110 }
111 }
112
113 pub fn size(&self) -> usize {
115 self.nthreads
116 }
117
118 pub fn is_done(&self) -> bool {
120 *self.done.lock().unwrap()
121 }
122}
123
124impl Drop for ThreadPool {
125 fn drop(&mut self) {
126 *self.done.lock().unwrap() = true;
128
129 for thread in self.threads.drain(..) {
131 let _ = thread.join();
132 }
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// A freshly created pool reports its size and is not shut down.
    #[test]
    fn test_thread_pool_basic() {
        let workers = ThreadPool::new(2);
        assert!(!workers.is_done());
        assert_eq!(workers.size(), 2);
    }

    /// A submitted task runs and its return value arrives on the receiver.
    #[test]
    fn test_thread_pool_submit() {
        let workers = ThreadPool::new(2);
        let doubled = workers.submit(|x: usize| x * 2, 21).recv().unwrap();

        assert_eq!(doubled, 42);
    }

    /// The blocks handed to the callback cover every index exactly once.
    #[test]
    fn test_thread_pool_parallel_for() {
        let workers = ThreadPool::new(4);
        let total = Arc::new(Mutex::new(0usize));

        let shared = Arc::clone(&total);
        workers.parallel_for(0, 100, 1, move |start, end, _stride| {
            // Add one per index in the block.
            *shared.lock().unwrap() += (start..end).count();
        });

        assert_eq!(*total.lock().unwrap(), 100);
    }
}