Skip to main content

zsh/extensions/
worker.rs

1//! Worker pool for zshrs — persistent threads for background work.
2//!
3//! **zshrs-original infrastructure — no C source counterpart.** This
4//! module does NOT port a corresponding `Src/*.c` file. C zsh's
5//! background-work strategy is `fork(2)`: every completion run,
6//! process substitution, or command substitution is a child process
7//! (see `zfork()` in Src/exec.c and the `forklevel` machinery
8//! Src/init.c uses to track depth). zshrs replaces that pattern with
9//! a fixed-size thread pool + crossbeam channel dispatch.
10//!
11//! Replacement rationale (vs the fork() path the C source takes):
12//!   - No fork overhead (50-500μs per fork on macOS)
13//!   - No address space duplication
14//!   - Warm thread stacks ready to go
15//!   - Backpressure via bounded channel
16//!
17//! Pool size = available_parallelism() clamped to [2, 18].
18//! Channel capacity = 4 × pool size (bounded backpressure).
19//!
20//! Audit fixes applied:
21//!   1. crossbeam-channel replaces Arc<Mutex<mpsc::Receiver>> — no mutex contention
22//!   2. Bounded channel (4×N) provides backpressure
23//!   3. catch_unwind wraps every task — panics logged, worker stays alive
24//!   4. tracing events on submit + worker loop (no spans are opened)
25//!   5. Queue depth metric on submit
26//!   6. Task cancellation via AtomicBool flag
27
28use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
29use std::sync::Arc;
30use std::thread;
31
/// One queued job: a boxed, run-once closure that may be handed to
/// another thread. The `'static` bound is the default for a boxed trait
/// object, so it is not spelled out.
type Task = Box<dyn FnOnce() + Send>;
34
35/// Fixed-size thread pool with bounded FIFO task queue.
36///
37/// zshrs-original — replaces C zsh's per-task `fork()` + `wait()`
38/// pattern (Src/exec.c `zfork()` / Src/jobs.c child management) with
39/// a persistent thread pool. Uses crossbeam-channel for lock-free
40/// multi-consumer dispatch — each worker calls `recv()` directly,
41/// no mutex.
42pub struct WorkerPool {
43    /// `workers` field.
44    workers: Vec<Worker>,
45    /// `sender` field.
46    sender: Option<crossbeam_channel::Sender<Task>>,
47    /// `size` field.
48    size: usize,
49    /// Shared cancellation flag — when set, workers drop pending tasks
50    cancelled: Arc<AtomicBool>,
51    /// Queue depth — incremented on submit, decremented on task start
52    queued: Arc<AtomicUsize>,
53    /// Total tasks completed across all workers
54    completed: Arc<AtomicUsize>,
55}
56
/// Book-keeping for a single spawned worker thread.
struct Worker {
    /// Index of the worker within the pool. Stored for diagnostics only
    /// and never read back, hence the lint allowance.
    #[allow(dead_code)]
    id: usize,
    /// Join handle of the thread; taken (left as `None`) when the pool
    /// is dropped.
    handle: Option<thread::JoinHandle<()>>,
}
62
63impl WorkerPool {
64    /// Create a pool with `size` worker threads and bounded channel.
65    /// Channel capacity = 4 × size (provides backpressure without
66    /// starving).
67    /// zshrs-original — no C counterpart. Replaces the
68    /// "spawn-on-demand" semantics of `zfork()` (Src/exec.c) with
69    /// pre-spawned threads ready to receive work over a bounded
70    /// channel.
71    pub fn new(size: usize) -> Self {
72        let capacity = size * 4;
73        let (sender, receiver) = crossbeam_channel::bounded::<Task>(capacity);
74        let cancelled = Arc::new(AtomicBool::new(false));
75        let queued = Arc::new(AtomicUsize::new(0));
76        let completed = Arc::new(AtomicUsize::new(0));
77
78        let mut workers = Vec::with_capacity(size);
79        for id in 0..size {
80            let rx = receiver.clone();
81            let cancelled = Arc::clone(&cancelled);
82            let queued = Arc::clone(&queued);
83            let completed = Arc::clone(&completed);
84
85            let handle = thread::Builder::new()
86                .name(format!("zshrs-worker-{}", id))
87                .spawn(move || {
88                    loop {
89                        let task = match rx.recv() {
90                            Ok(task) => task,
91                            Err(_) => break, // channel closed → shutdown
92                        };
93
94                        queued.fetch_sub(1, Ordering::Relaxed);
95
96                        // Check cancellation before running
97                        if cancelled.load(Ordering::Relaxed) {
98                            continue; // drain without executing
99                        }
100
101                        // catch_unwind keeps the worker alive if a task panics
102                        if let Err(e) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(task))
103                        {
104                            let msg = if let Some(s) = e.downcast_ref::<&str>() {
105                                (*s).to_string()
106                            } else if let Some(s) = e.downcast_ref::<String>() {
107                                s.clone()
108                            } else {
109                                "unknown panic".to_string()
110                            };
111                            tracing::error!(
112                                worker = id,
113                                panic = %msg,
114                                "worker task panicked"
115                            );
116                        }
117
118                        completed.fetch_add(1, Ordering::Relaxed);
119                    }
120                    tracing::debug!(worker = id, "worker thread exiting");
121                })
122                .expect("failed to spawn worker thread");
123
124            workers.push(Worker {
125                id,
126                handle: Some(handle),
127            });
128        }
129
130        tracing::info!(
131            pool_size = size,
132            channel_capacity = capacity,
133            "worker pool started"
134        );
135
136        WorkerPool {
137            workers,
138            sender: Some(sender),
139            size,
140            cancelled,
141            queued,
142            completed,
143        }
144    }
145
146    /// Create a pool sized to the machine's parallelism, clamped to
147    /// `[2, 18]`.
148    /// zshrs-original — no C counterpart. C zsh has no concept of a
149    /// "pool size" because it forks on demand (one child per
150    /// background task, see Src/jobs.c).
151    pub fn default_size() -> Self {
152        let cpus = thread::available_parallelism()
153            .map(|n| n.get())
154            .unwrap_or(4);
155        Self::new(cpus.clamp(2, 18))
156    }
157
158    /// Submit a task to the pool. Blocks if the queue is full
159    /// (backpressure). Panics if the pool has been shut down.
160    /// zshrs-original — replaces the `fork() + execve()` /
161    /// `fork() + run-shell-fn` dispatch pairs in Src/exec.c.
162    pub fn submit<F>(&self, f: F)
163    where
164        F: FnOnce() + Send + 'static,
165    {
166        let depth = self.queued.fetch_add(1, Ordering::Relaxed) + 1;
167        if depth > self.size * 2 {
168            tracing::debug!(queue_depth = depth, "worker pool queue building up");
169        }
170        self.sender
171            .as_ref()
172            .expect("pool shut down")
173            .send(Box::new(f))
174            .expect("all workers dead");
175    }
176
177    /// Submit a task and get a receiver for its result.
178    /// zshrs-original — closest C analog is the pipe-based
179    /// command-substitution result capture in Src/exec.c
180    /// (`getoutput()` reading the child's stdout pipe), but using a
181    /// typed Rust channel sidesteps the marshalling.
182    pub fn submit_with_result<F, R>(&self, f: F) -> crossbeam_channel::Receiver<R>
183    where
184        F: FnOnce() -> R + Send + 'static,
185        R: Send + 'static,
186    {
187        let (tx, rx) = crossbeam_channel::bounded(1);
188        self.submit(move || {
189            let result = f();
190            let _ = tx.send(result);
191        });
192        rx
193    }
194
195    /// Signal all workers to drop pending tasks.
196    /// Already-running tasks will finish, but queued tasks are
197    /// skipped. Reset with `reset_cancel()`.
198    /// zshrs-original — closest C analog is the SIGINT/SIGQUIT
199    /// signal-storm dispatch C zsh fires at its background children
200    /// in Src/signals.c (`killjb()` / `killpg()`), but here we set a
201    /// flag instead of sending a signal across a fork boundary.
202    pub fn cancel(&self) {
203        self.cancelled.store(true, Ordering::Relaxed);
204        tracing::info!("worker pool: cancel requested");
205    }
206
207    /// Clear the cancellation flag — pool resumes normal execution.
208    /// zshrs-original — no C counterpart.
209    pub fn reset_cancel(&self) {
210        self.cancelled.store(false, Ordering::Relaxed);
211    }
212
213    /// Number of worker threads.
214    /// zshrs-original — no C counterpart.
215    pub fn size(&self) -> usize {
216        self.size
217    }
218
219    /// Approximate number of tasks waiting in the queue.
220    /// zshrs-original — no C counterpart; closest equivalent is the
221    /// `jobtab` length walk Src/jobs.c uses for `jobs -l` output.
222    pub fn queue_depth(&self) -> usize {
223        self.queued.load(Ordering::Relaxed)
224    }
225
226    /// Total tasks completed since pool creation.
227    /// zshrs-original — no C counterpart.
228    pub fn completed(&self) -> usize {
229        self.completed.load(Ordering::Relaxed)
230    }
231}
232
233impl Drop for WorkerPool {
234    fn drop(&mut self) {
235        // Signal workers to skip remaining queued tasks
236        self.cancelled.store(true, Ordering::Relaxed);
237        // Drop the sender → channel closes → recv() returns Err → threads exit
238        drop(self.sender.take());
239        // Give workers a brief window to finish their current task.
240        // Don't block indefinitely — the process is exiting.
241        for w in &mut self.workers {
242            if let Some(handle) = w.handle.take() {
243                // Detach the thread — OS cleans up on process exit.
244                // join() would block if a worker is mid-parse on a 500-line
245                // completion function. Not worth the wait on Ctrl-D/exit.
246                drop(handle);
247            }
248        }
249        // Demoted from `info!` to `debug!` so the default tracing
250        // filter (INFO) suppresses it. The bare shutdown announcement
251        // has no operational value — interesting telemetry would be
252        // a non-zero error count or a stuck worker, which warrants its
253        // own surface. Empirically (bug #23 in docs/BUGS.md) the
254        // existing info! also leaked to stdout when a script left a
255        // duped fd open (`exec 3>&1`): by the time worker Drop runs,
256        // the file-backed log writer is closed, and tracing's fallback
257        // writes to fd 1 — which is the original stdout the dup
258        // pointed at. Default INFO filter no longer triggers this code
259        // path at all in normal use.
260        tracing::debug!(
261            tasks_completed = self.completed.load(Ordering::Relaxed),
262            "worker pool shut down"
263        );
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Spin-wait helper for tests: poll `counter` until it reaches
    /// `target` or the deadline elapses (then panic). Needed because
    /// production Drop sets cancelled=true and detaches the workers, so
    /// `drop(pool)` neither drains the queue nor waits for tasks.
    fn wait_for_count(counter: &AtomicUsize, target: usize, max_wait_ms: u64) {
        let deadline = std::time::Instant::now() + std::time::Duration::from_millis(max_wait_ms);
        while counter.load(Ordering::Relaxed) < target {
            if std::time::Instant::now() >= deadline {
                panic!(
                    "wait_for_count timed out: counter={} target={} after {}ms",
                    counter.load(Ordering::Relaxed),
                    target,
                    max_wait_ms
                );
            }
            std::thread::sleep(std::time::Duration::from_millis(2));
        }
    }

    #[test]
    fn test_pool_executes_tasks() {
        let _g = crate::test_util::global_state_lock();
        let pool = WorkerPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..100 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        // Drain explicitly — production Drop sets cancelled=true and
        // skips queued tasks (intentional for shell exit), so the test
        // can't rely on `drop(pool)` to wait.
        wait_for_count(&counter, 100, 5_000);
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 100);
    }

    #[test]
    fn test_submit_with_result() {
        let _g = crate::test_util::global_state_lock();
        let pool = WorkerPool::new(2);
        let rx = pool.submit_with_result(|| 42);
        assert_eq!(rx.recv().unwrap(), 42);
    }

    #[test]
    fn test_default_size() {
        let _g = crate::test_util::global_state_lock();
        let pool = WorkerPool::default_size();
        assert!(pool.size() >= 2);
        assert!(pool.size() <= 18);
    }

    #[test]
    fn test_panic_does_not_kill_worker() {
        let _g = crate::test_util::global_state_lock();
        let pool = WorkerPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));

        // Submit a task that panics
        pool.submit(|| panic!("intentional test panic"));

        // Submit tasks after the panic — they should still run
        for _ in 0..10 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        wait_for_count(&counter, 10, 5_000);
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 10);
    }

    #[test]
    fn test_cancel_skips_queued_tasks() {
        let _g = crate::test_util::global_state_lock();
        let pool = WorkerPool::new(1); // single worker to control ordering
        let barrier = Arc::new(std::sync::Barrier::new(2));
        // Signal the worker fires when it ENTERS the barrier task. Lets
        // the main thread wait until the worker is provably blocked
        // inside the barrier BEFORE calling cancel(). Without this, a
        // pre-empted worker that hasn't yet pulled task #1 would see the
        // cancel flag, skip task #1, and the main thread's barrier.wait()
        // below would deadlock waiting for a second party that never
        // arrives.
        let started = Arc::new(std::sync::Mutex::new(false));
        let started_cv = Arc::new(std::sync::Condvar::new());
        let counter = Arc::new(AtomicUsize::new(0));

        let b = Arc::clone(&barrier);
        let started_clone = Arc::clone(&started);
        let cv_clone = Arc::clone(&started_cv);
        pool.submit(move || {
            // Mark "task entered" + notify before blocking.
            *started_clone.lock().unwrap() = true;
            cv_clone.notify_one();
            b.wait();
        });

        // Wait until the worker is provably inside the task (and thus
        // committed to calling b.wait() — no race with cancel below).
        // 5s timeout is a safety net; in practice this fires within μs.
        let mut g = started.lock().unwrap();
        let timeout = std::time::Duration::from_secs(5);
        while !*g {
            let (gg, wait_result) = started_cv.wait_timeout(g, timeout).unwrap();
            g = gg;
            if wait_result.timed_out() && !*g {
                panic!("worker never started task #1 within 5s — test scaffolding broken");
            }
        }
        drop(g);

        // Queue tasks that should be skipped (worker is parked at b.wait()).
        // Cap at channel capacity (size * 4 = 4 for a 1-worker pool) MINUS 1
        // for safety. Submitting more than the channel holds while the
        // worker is blocked deadlocks `submit` itself, since the bounded
        // crossbeam channel back-pressures `send()`. 3 skipped tasks is
        // enough to prove "queued tasks get cancelled" — the count isn't
        // load-bearing.
        for _ in 0..3 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        // Cancel, then unblock the worker — it'll return from b.wait(),
        // loop, see cancelled=true, drain the 3 queued tasks without
        // executing them.
        pool.cancel();
        barrier.wait();

        // Give workers time to drain
        std::thread::sleep(std::time::Duration::from_millis(50));

        // Queued tasks should have been skipped
        assert_eq!(counter.load(Ordering::Relaxed), 0);

        // Reset and verify pool still works
        pool.reset_cancel();
        let c = Arc::clone(&counter);
        pool.submit(move || {
            c.fetch_add(1, Ordering::Relaxed);
        });
        // Wait for the post-reset task to complete BEFORE drop, since
        // production Drop sets cancelled=true again and would skip
        // any not-yet-pulled task.
        wait_for_count(&counter, 1, 5_000);
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_metrics() {
        let _g = crate::test_util::global_state_lock();
        let pool = WorkerPool::new(2);
        assert_eq!(pool.completed(), 0);

        for _ in 0..10 {
            pool.submit(|| {});
        }

        // Drop does not wait for (or even run) queued tasks, so poll the
        // pool's own completed counter before asserting on it.
        wait_for_count(&pool.completed, 10, 5_000);
        assert_eq!(pool.completed(), 10);
        drop(pool);
    }

    #[test]
    fn test_backpressure_bounded() {
        let _g = crate::test_util::global_state_lock();
        // Pool of 1 with capacity 4 — 5th submit blocks (back-pressure)
        // until the worker drains one. With 20 submits + 1 worker the
        // pool's submit() call blocks naturally; by the time the loop
        // exits, ~16 are completed and ~4 are still queued / in-flight.
        let pool = WorkerPool::new(1);
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..20 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        wait_for_count(&counter, 20, 5_000);
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 20);
    }
}