Skip to main content

squawk_thread/
pool.rs

//! [`Pool`] implements a basic custom thread pool
//! inspired by the [`threadpool` crate](https://docs.rs/threadpool).
//! When you spawn a task you specify a thread intent
//! so the pool can schedule it to run on a thread with that intent.
//! This lets callers prioritize work based on latency requirements.
//!
//! The thread pool is built on the threading utilities in [`crate`],
//! plus `crossbeam` channels and wait groups.
9
10use std::{
11    marker::PhantomData,
12    num::NonZeroUsize,
13    panic::{AssertUnwindSafe, UnwindSafe},
14    sync::{
15        Arc,
16        atomic::{AtomicUsize, Ordering},
17    },
18};
19
20use crossbeam_channel::{Receiver, Sender};
21use crossbeam_utils::sync::WaitGroup;
22
23use crate::{Builder, JoinHandle, ThreadIntent};
24
/// A fixed-size pool of worker threads.
///
/// Jobs are delivered to the workers over an unbounded channel, and each job
/// carries the [`ThreadIntent`] it should run with.
pub struct Pool {
    // `_handles` is never read: the field is present
    // only for its `Drop` impl.

    // The worker threads exit once the channel closes;
    // make sure to keep `job_sender` above `_handles`
    // so that the channel is actually closed
    // before the handles are dropped (struct fields are
    // dropped in declaration order).
    // NOTE(review): the handles are created with
    // `allow_leak(true)` in `Pool::new`; whether dropping
    // them actually joins the worker threads depends on
    // `JoinHandle`'s `Drop` impl — confirm.
    /// Sending half of the job channel; every worker holds a receiver clone.
    job_sender: Sender<Job>,
    _handles: Box<[JoinHandle]>,
    /// Number of jobs a worker has picked up and not yet finished.
    extant_tasks: Arc<AtomicUsize>,
}
37
/// A unit of work sent over the pool's job channel.
struct Job {
    /// Intent the worker applies to its thread (if not already in effect)
    /// before running `f`.
    requested_intent: ThreadIntent,
    // `UnwindSafe` is deliberately not required here: a worker aborts the
    // whole process if `f` panics, so no state is observed after an unwind.
    f: Box<dyn FnOnce() + Send + 'static>,
}
42
43impl Pool {
44    /// # Panics
45    ///
46    /// Panics if job panics
47    #[must_use]
48    pub fn new(threads: NonZeroUsize) -> Self {
49        const STACK_SIZE: usize = 8 * 1024 * 1024;
50        const INITIAL_INTENT: ThreadIntent = ThreadIntent::Worker;
51
52        let (job_sender, job_receiver) = crossbeam_channel::unbounded();
53        let extant_tasks = Arc::new(AtomicUsize::new(0));
54
55        let mut handles = Vec::with_capacity(threads.into());
56        for idx in 0..threads.into() {
57            let handle = Builder::new(INITIAL_INTENT, format!("squawk:worker:{idx}",))
58                .stack_size(STACK_SIZE)
59                .allow_leak(true)
60                .spawn({
61                    let extant_tasks = Arc::clone(&extant_tasks);
62                    let job_receiver: Receiver<Job> = job_receiver.clone();
63                    move || {
64                        let mut current_intent = INITIAL_INTENT;
65                        for job in job_receiver {
66                            if job.requested_intent != current_intent {
67                                job.requested_intent.apply_to_current_thread();
68                                current_intent = job.requested_intent;
69                            }
70                            extant_tasks.fetch_add(1, Ordering::SeqCst);
71
72                            // SAFETY: it's safe to assume that `job.f` is unwind safe because we always
73                            // abort the process if it panics.
74                            // Panicking here ensures that we don't swallow errors and is the same as
75                            // what rayon does.
76                            // Any recovery should be implemented outside the thread pool (e.g. when
77                            // dispatching requests/notifications etc).
78                            if let Err(error) = std::panic::catch_unwind(AssertUnwindSafe(job.f)) {
79                                if let Some(msg) = error.downcast_ref::<String>() {
80                                    tracing::error!("Worker thread panicked with: {msg}; aborting");
81                                } else if let Some(msg) = error.downcast_ref::<&str>() {
82                                    tracing::error!("Worker thread panicked with: {msg}; aborting");
83                                } else if let Some(cancelled) =
84                                    error.downcast_ref::<salsa::Cancelled>()
85                                {
86                                    tracing::error!(
87                                        "Worker thread got cancelled: {cancelled}; aborting"
88                                    );
89                                } else {
90                                    tracing::error!(
91                                        "Worker thread panicked with: {error:?}; aborting"
92                                    );
93                                }
94
95                                std::process::abort();
96                            }
97
98                            extant_tasks.fetch_sub(1, Ordering::SeqCst);
99                        }
100                    }
101                })
102                .expect("failed to spawn thread");
103
104            handles.push(handle);
105        }
106
107        Self {
108            _handles: handles.into_boxed_slice(),
109            extant_tasks,
110            job_sender,
111        }
112    }
113
114    pub fn spawn<F>(&self, intent: ThreadIntent, f: F)
115    where
116        F: FnOnce() + Send + /* UnwindSafe + */ 'static,
117    {
118        let f = Box::new(move || {
119            if cfg!(debug_assertions) {
120                intent.assert_is_used_on_current_thread();
121            }
122            f();
123        });
124
125        let job = Job {
126            requested_intent: intent,
127            f,
128        };
129        self.job_sender.send(job).unwrap();
130    }
131
132    pub fn scoped<'pool, 'scope, F, R>(&'pool self, f: F) -> R
133    where
134        F: FnOnce(&Scope<'pool, 'scope>) -> R,
135    {
136        let wg = WaitGroup::new();
137        let scope = Scope {
138            pool: self,
139            wg,
140            _marker: PhantomData,
141        };
142        let r = f(&scope);
143        scope.wg.wait();
144        r
145    }
146
147    #[must_use]
148    pub fn len(&self) -> usize {
149        self.extant_tasks.load(Ordering::SeqCst)
150    }
151
152    #[must_use]
153    pub fn is_empty(&self) -> bool {
154        self.len() == 0
155    }
156}
157
/// Handle given to the closure passed to [`Pool::scoped`]; spawns jobs that
/// may borrow data living for `'scope`.
pub struct Scope<'pool, 'scope> {
    /// Pool whose worker threads run the scoped jobs.
    pool: &'pool Pool,
    /// Each spawned job holds a clone of this wait group until it finishes;
    /// `Pool::scoped` waits on it before returning.
    wg: WaitGroup,
    /// `fn(&'scope ()) -> &'scope ()` makes `Scope` invariant over `'scope`,
    /// so subtyping can neither shorten nor lengthen that lifetime.
    _marker: PhantomData<fn(&'scope ()) -> &'scope ()>,
}
163
164impl<'scope> Scope<'_, 'scope> {
165    pub fn spawn<F>(&self, intent: ThreadIntent, f: F)
166    where
167        F: 'scope + FnOnce() + Send + UnwindSafe,
168    {
169        let wg = self.wg.clone();
170        let f = Box::new(move || {
171            if cfg!(debug_assertions) {
172                intent.assert_is_used_on_current_thread();
173            }
174            f();
175            drop(wg);
176        });
177
178        let job = Job {
179            requested_intent: intent,
180            f: unsafe {
181                std::mem::transmute::<
182                    Box<dyn 'scope + FnOnce() + Send + UnwindSafe>,
183                    Box<dyn 'static + FnOnce() + Send + UnwindSafe>,
184                >(f)
185            },
186        };
187        self.pool.job_sender.send(job).unwrap();
188    }
189}