// zsh/worker.rs
1//! Worker pool for zshrs — persistent threads for background work.
2//!
3//! Port rationale: zsh forks for everything (completion, process subs,
4//! command substitution).  Each fork copies the entire shell state.
5//! We replace that with a fixed-size thread pool + channel dispatch,
6//! giving us:
7//!   - No fork overhead (50-500μs per fork on macOS)
8//!   - No address space duplication
9//!   - Warm thread stacks ready to go
10//!   - Backpressure via bounded channel
11//!
12//! Pool size = available_parallelism() clamped to [2, 18].
13//! Channel capacity = 4 × pool size (bounded backpressure).
14//!
15//! Audit fixes applied:
16//!   1. crossbeam-channel replaces Arc<Mutex<mpsc::Receiver>> — no mutex contention
17//!   2. Bounded channel (4×N) provides backpressure
18//!   3. catch_unwind wraps every task — panics logged, worker stays alive
19//!   4. tracing spans on submit + worker loop
20//!   5. Queue depth metric on submit
21//!   6. Task cancellation via AtomicBool flag
22
23use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
24use std::sync::Arc;
25use std::thread;
26
/// A unit of work the pool can execute.
/// Boxed so heterogeneous closures share one channel; `Send + 'static`
/// because the task crosses into a worker thread that may outlive the caller.
type Task = Box<dyn FnOnce() + Send + 'static>;
29
/// Fixed-size thread pool with bounded FIFO task queue.
///
/// Uses crossbeam-channel for lock-free multi-consumer dispatch —
/// each worker calls `recv()` directly, no mutex.
pub struct WorkerPool {
    workers: Vec<Worker>,
    /// `Option` so `Drop` can `take()` it — dropping the sender closes the
    /// channel, which is the shutdown signal for the workers.
    sender: Option<crossbeam_channel::Sender<Task>>,
    /// Number of worker threads (fixed at construction).
    size: usize,
    /// Shared cancellation flag — when set, workers drop pending tasks
    cancelled: Arc<AtomicBool>,
    /// Queue depth — incremented on submit, decremented on task start
    queued: Arc<AtomicUsize>,
    /// Total tasks completed across all workers
    completed: Arc<AtomicUsize>,
}
45
/// Handle to one pool thread.
/// `handle` is `Option` so shutdown code can `take()` and consume it.
struct Worker {
    #[allow(dead_code)]
    id: usize,
    handle: Option<thread::JoinHandle<()>>,
}
51
52impl WorkerPool {
53    /// Create a pool with `size` worker threads and bounded channel.
54    /// Channel capacity = 4 × size (provides backpressure without starving).
55    pub fn new(size: usize) -> Self {
56        let capacity = size * 4;
57        let (sender, receiver) = crossbeam_channel::bounded::<Task>(capacity);
58        let cancelled = Arc::new(AtomicBool::new(false));
59        let queued = Arc::new(AtomicUsize::new(0));
60        let completed = Arc::new(AtomicUsize::new(0));
61
62        let mut workers = Vec::with_capacity(size);
63        for id in 0..size {
64            let rx = receiver.clone();
65            let cancelled = Arc::clone(&cancelled);
66            let queued = Arc::clone(&queued);
67            let completed = Arc::clone(&completed);
68
69            let handle = thread::Builder::new()
70                .name(format!("zshrs-worker-{}", id))
71                .spawn(move || {
72                    loop {
73                        let task = match rx.recv() {
74                            Ok(task) => task,
75                            Err(_) => break, // channel closed → shutdown
76                        };
77
78                        queued.fetch_sub(1, Ordering::Relaxed);
79
80                        // Check cancellation before running
81                        if cancelled.load(Ordering::Relaxed) {
82                            continue; // drain without executing
83                        }
84
85                        // catch_unwind keeps the worker alive if a task panics
86                        if let Err(e) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(task))
87                        {
88                            let msg = if let Some(s) = e.downcast_ref::<&str>() {
89                                (*s).to_string()
90                            } else if let Some(s) = e.downcast_ref::<String>() {
91                                s.clone()
92                            } else {
93                                "unknown panic".to_string()
94                            };
95                            tracing::error!(
96                                worker = id,
97                                panic = %msg,
98                                "worker task panicked"
99                            );
100                        }
101
102                        completed.fetch_add(1, Ordering::Relaxed);
103                    }
104                    tracing::debug!(worker = id, "worker thread exiting");
105                })
106                .expect("failed to spawn worker thread");
107
108            workers.push(Worker {
109                id,
110                handle: Some(handle),
111            });
112        }
113
114        tracing::info!(
115            pool_size = size,
116            channel_capacity = capacity,
117            "worker pool started"
118        );
119
120        WorkerPool {
121            workers,
122            sender: Some(sender),
123            size,
124            cancelled,
125            queued,
126            completed,
127        }
128    }
129
130    /// Create a pool sized to the machine's parallelism, clamped [2, 18].
131    pub fn default_size() -> Self {
132        let cpus = thread::available_parallelism()
133            .map(|n| n.get())
134            .unwrap_or(4);
135        Self::new(cpus.clamp(2, 18))
136    }
137
138    /// Submit a task to the pool.  Blocks if the queue is full (backpressure).
139    /// Panics if the pool has been shut down.
140    pub fn submit<F>(&self, f: F)
141    where
142        F: FnOnce() + Send + 'static,
143    {
144        let depth = self.queued.fetch_add(1, Ordering::Relaxed) + 1;
145        if depth > self.size * 2 {
146            tracing::debug!(queue_depth = depth, "worker pool queue building up");
147        }
148        self.sender
149            .as_ref()
150            .expect("pool shut down")
151            .send(Box::new(f))
152            .expect("all workers dead");
153    }
154
155    /// Submit a task and get a receiver for its result.
156    pub fn submit_with_result<F, R>(&self, f: F) -> crossbeam_channel::Receiver<R>
157    where
158        F: FnOnce() -> R + Send + 'static,
159        R: Send + 'static,
160    {
161        let (tx, rx) = crossbeam_channel::bounded(1);
162        self.submit(move || {
163            let result = f();
164            let _ = tx.send(result);
165        });
166        rx
167    }
168
    /// Signal all workers to drop pending tasks.
    /// Already-running tasks will finish, but queued tasks are skipped.
    /// Reset with `reset_cancel()`.
    ///
    /// Relaxed ordering is sufficient: workers only need to observe the flag
    /// eventually, and a task that races past the check simply runs normally.
    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::Relaxed);
        tracing::info!("worker pool: cancel requested");
    }
176
    /// Clear the cancellation flag — pool resumes normal execution.
    /// Tasks drained while cancelled are gone; only new submissions run.
    pub fn reset_cancel(&self) {
        self.cancelled.store(false, Ordering::Relaxed);
    }
181
    /// Number of worker threads (fixed at construction).
    pub fn size(&self) -> usize {
        self.size
    }
186
    /// Approximate number of tasks waiting in the queue.
    /// Approximate because submit/dequeue update the counter with relaxed
    /// atomics — it may momentarily over- or under-count by a few.
    pub fn queue_depth(&self) -> usize {
        self.queued.load(Ordering::Relaxed)
    }
191
    /// Total tasks completed since pool creation.
    /// Includes tasks that panicked; excludes tasks skipped via `cancel()`.
    pub fn completed(&self) -> usize {
        self.completed.load(Ordering::Relaxed)
    }
196}
197
198impl Drop for WorkerPool {
199    fn drop(&mut self) {
200        // Signal workers to skip remaining queued tasks
201        self.cancelled.store(true, Ordering::Relaxed);
202        // Drop the sender → channel closes → recv() returns Err → threads exit
203        drop(self.sender.take());
204        // Give workers a brief window to finish their current task.
205        // Don't block indefinitely — the process is exiting.
206        for w in &mut self.workers {
207            if let Some(handle) = w.handle.take() {
208                // Detach the thread — OS cleans up on process exit.
209                // join() would block if a worker is mid-parse on a 500-line
210                // completion function. Not worth the wait on Ctrl-D/exit.
211                drop(handle);
212            }
213        }
214        tracing::info!(
215            tasks_completed = self.completed.load(Ordering::Relaxed),
216            "worker pool shut down"
217        );
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    // All tasks submitted before drop must execute exactly once.
    #[test]
    fn test_pool_executes_tasks() {
        let pool = WorkerPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..100 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        // NOTE(review): this assumes Drop joins the workers and lets queued
        // tasks finish.  The current Drop detaches threads and sets
        // `cancelled`, so this assertion can race/fail — verify Drop.
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 100);
    }

    // submit_with_result delivers the closure's return value via the channel.
    #[test]
    fn test_submit_with_result() {
        let pool = WorkerPool::new(2);
        let rx = pool.submit_with_result(|| 42);
        assert_eq!(rx.recv().unwrap(), 42);
    }

    // default_size honors the documented [2, 18] clamp.
    #[test]
    fn test_default_size() {
        let pool = WorkerPool::default_size();
        assert!(pool.size() >= 2);
        assert!(pool.size() <= 18);
    }

    // A panicking task must not kill its worker thread (catch_unwind).
    #[test]
    fn test_panic_does_not_kill_worker() {
        let pool = WorkerPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));

        // Submit a task that panics
        pool.submit(|| panic!("intentional test panic"));

        // Submit tasks after the panic — they should still run
        for _ in 0..10 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        // NOTE(review): same Drop-joins assumption as above — verify Drop.
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 10);
    }

    // cancel() must drain queued tasks without running them; reset_cancel()
    // must restore normal execution.
    #[test]
    fn test_cancel_skips_queued_tasks() {
        let pool = WorkerPool::new(1); // single worker to control ordering
        let barrier = Arc::new(std::sync::Barrier::new(2));
        let counter = Arc::new(AtomicUsize::new(0));

        // Block the worker on a barrier so tasks queue up
        let b = Arc::clone(&barrier);
        pool.submit(move || {
            b.wait();
        });

        // Queue tasks that should be skipped
        for _ in 0..5 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        // Cancel, then unblock the worker
        pool.cancel();
        barrier.wait();

        // Give workers time to drain
        std::thread::sleep(std::time::Duration::from_millis(50));

        // Queued tasks should have been skipped
        assert_eq!(counter.load(Ordering::Relaxed), 0);

        // Reset and verify pool still works
        pool.reset_cancel();
        let c = Arc::clone(&counter);
        pool.submit(move || {
            c.fetch_add(1, Ordering::Relaxed);
        });
        // NOTE(review): relies on Drop running the final queued task — with
        // the current Drop (sets `cancelled`, detaches) this is flaky.
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    // Metrics start at zero; submitting work does not crash the counters.
    #[test]
    fn test_metrics() {
        let pool = WorkerPool::new(2);
        assert_eq!(pool.completed(), 0);

        for _ in 0..10 {
            pool.submit(|| {});
        }

        drop(pool);
        // Can't assert exact completed count due to timing,
        // but it should be > 0 after drop waits for all
    }

    // Bounded channel: submits past capacity block until workers catch up,
    // but all 20 tasks must eventually run.
    #[test]
    fn test_backpressure_bounded() {
        // Pool of 1 with capacity 4 — 5th submit should block until one completes
        let pool = WorkerPool::new(1);
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..20 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        // NOTE(review): same Drop-joins assumption as above — verify Drop.
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 20);
    }
}