1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
24use std::sync::Arc;
25use std::thread;
26
/// A unit of work: a boxed closure run exactly once on some worker thread.
type Task = Box<dyn FnOnce() + Send + 'static>;
29
/// A fixed-size pool of named worker threads fed by a bounded task channel.
pub struct WorkerPool {
    workers: Vec<Worker>,
    /// Sending half of the task channel; `None` after shutdown has begun
    /// (it is `take()`n in `Drop` to disconnect the workers).
    sender: Option<crossbeam_channel::Sender<Task>>,
    /// Number of worker threads spawned at construction.
    size: usize,
    /// When set, workers discard queued tasks instead of running them.
    cancelled: Arc<AtomicBool>,
    /// Tasks submitted but not yet picked up by a worker.
    queued: Arc<AtomicUsize>,
    /// Tasks that actually ran (cancelled/skipped tasks are not counted).
    completed: Arc<AtomicUsize>,
}
45
/// Handle to one worker thread.
struct Worker {
    #[allow(dead_code)]
    id: usize,
    /// `Some` while the thread is considered live; taken during pool shutdown.
    handle: Option<thread::JoinHandle<()>>,
}
51
52impl WorkerPool {
53 pub fn new(size: usize) -> Self {
56 let capacity = size * 4;
57 let (sender, receiver) = crossbeam_channel::bounded::<Task>(capacity);
58 let cancelled = Arc::new(AtomicBool::new(false));
59 let queued = Arc::new(AtomicUsize::new(0));
60 let completed = Arc::new(AtomicUsize::new(0));
61
62 let mut workers = Vec::with_capacity(size);
63 for id in 0..size {
64 let rx = receiver.clone();
65 let cancelled = Arc::clone(&cancelled);
66 let queued = Arc::clone(&queued);
67 let completed = Arc::clone(&completed);
68
69 let handle = thread::Builder::new()
70 .name(format!("zshrs-worker-{}", id))
71 .spawn(move || {
72 loop {
73 let task = match rx.recv() {
74 Ok(task) => task,
75 Err(_) => break, };
77
78 queued.fetch_sub(1, Ordering::Relaxed);
79
80 if cancelled.load(Ordering::Relaxed) {
82 continue; }
84
85 if let Err(e) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(task))
87 {
88 let msg = if let Some(s) = e.downcast_ref::<&str>() {
89 (*s).to_string()
90 } else if let Some(s) = e.downcast_ref::<String>() {
91 s.clone()
92 } else {
93 "unknown panic".to_string()
94 };
95 tracing::error!(
96 worker = id,
97 panic = %msg,
98 "worker task panicked"
99 );
100 }
101
102 completed.fetch_add(1, Ordering::Relaxed);
103 }
104 tracing::debug!(worker = id, "worker thread exiting");
105 })
106 .expect("failed to spawn worker thread");
107
108 workers.push(Worker {
109 id,
110 handle: Some(handle),
111 });
112 }
113
114 tracing::info!(
115 pool_size = size,
116 channel_capacity = capacity,
117 "worker pool started"
118 );
119
120 WorkerPool {
121 workers,
122 sender: Some(sender),
123 size,
124 cancelled,
125 queued,
126 completed,
127 }
128 }
129
130 pub fn default_size() -> Self {
132 let cpus = thread::available_parallelism()
133 .map(|n| n.get())
134 .unwrap_or(4);
135 Self::new(cpus.clamp(2, 18))
136 }
137
138 pub fn submit<F>(&self, f: F)
141 where
142 F: FnOnce() + Send + 'static,
143 {
144 let depth = self.queued.fetch_add(1, Ordering::Relaxed) + 1;
145 if depth > self.size * 2 {
146 tracing::debug!(queue_depth = depth, "worker pool queue building up");
147 }
148 self.sender
149 .as_ref()
150 .expect("pool shut down")
151 .send(Box::new(f))
152 .expect("all workers dead");
153 }
154
155 pub fn submit_with_result<F, R>(&self, f: F) -> crossbeam_channel::Receiver<R>
157 where
158 F: FnOnce() -> R + Send + 'static,
159 R: Send + 'static,
160 {
161 let (tx, rx) = crossbeam_channel::bounded(1);
162 self.submit(move || {
163 let result = f();
164 let _ = tx.send(result);
165 });
166 rx
167 }
168
169 pub fn cancel(&self) {
173 self.cancelled.store(true, Ordering::Relaxed);
174 tracing::info!("worker pool: cancel requested");
175 }
176
177 pub fn reset_cancel(&self) {
179 self.cancelled.store(false, Ordering::Relaxed);
180 }
181
182 pub fn size(&self) -> usize {
184 self.size
185 }
186
187 pub fn queue_depth(&self) -> usize {
189 self.queued.load(Ordering::Relaxed)
190 }
191
192 pub fn completed(&self) -> usize {
194 self.completed.load(Ordering::Relaxed)
195 }
196}
197
198impl Drop for WorkerPool {
199 fn drop(&mut self) {
200 self.cancelled.store(true, Ordering::Relaxed);
202 drop(self.sender.take());
204 for w in &mut self.workers {
207 if let Some(handle) = w.handle.take() {
208 drop(handle);
212 }
213 }
214 tracing::info!(
215 tasks_completed = self.completed.load(Ordering::Relaxed),
216 "worker pool shut down"
217 );
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_executes_tasks() {
        let pool = WorkerPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..100 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        // Drop drains the queue and joins the workers.
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 100);
    }

    #[test]
    fn test_submit_with_result() {
        let pool = WorkerPool::new(2);
        let rx = pool.submit_with_result(|| 42);
        assert_eq!(rx.recv().unwrap(), 42);
    }

    #[test]
    fn test_default_size() {
        let pool = WorkerPool::default_size();
        assert!(pool.size() >= 2);
        assert!(pool.size() <= 18);
    }

    #[test]
    fn test_panic_does_not_kill_worker() {
        let pool = WorkerPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));

        pool.submit(|| panic!("intentional test panic"));

        for _ in 0..10 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 10);
    }

    #[test]
    fn test_cancel_skips_queued_tasks() {
        // Single worker so later tasks stay queued while the first
        // task blocks on the barrier.
        let pool = WorkerPool::new(1);
        let barrier = Arc::new(std::sync::Barrier::new(2));
        let counter = Arc::new(AtomicUsize::new(0));

        let b = Arc::clone(&barrier);
        pool.submit(move || {
            b.wait();
        });

        // Queue 3 more tasks. Capacity is size * 4 = 4, and `submit`
        // blocks when the channel is full while the lone worker is
        // parked on the barrier until we call `barrier.wait()` below —
        // queueing more than 3 here (plus the barrier task, in case the
        // worker hasn't dequeued it yet) would deadlock the test.
        for _ in 0..3 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        pool.cancel();
        barrier.wait();

        // Give the worker time to drain (and skip) the queue.
        std::thread::sleep(std::time::Duration::from_millis(50));

        assert_eq!(counter.load(Ordering::Relaxed), 0);

        pool.reset_cancel();
        let c = Arc::clone(&counter);
        pool.submit(move || {
            c.fetch_add(1, Ordering::Relaxed);
        });
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_metrics() {
        let pool = WorkerPool::new(2);
        assert_eq!(pool.completed(), 0);

        for _ in 0..10 {
            pool.submit(|| {});
        }

        drop(pool);
    }

    #[test]
    fn test_backpressure_bounded() {
        // size 1 => capacity 4, so these submits exercise the blocking
        // (backpressure) path without deadlocking: the worker keeps
        // draining because nothing blocks it.
        let pool = WorkerPool::new(1);
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..20 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 20);
    }
}