1use std::{
11 marker::PhantomData,
12 num::NonZeroUsize,
13 panic::{AssertUnwindSafe, UnwindSafe},
14 sync::{
15 Arc,
16 atomic::{AtomicUsize, Ordering},
17 },
18};
19
20use crossbeam_channel::{Receiver, Sender};
21use crossbeam_utils::sync::WaitGroup;
22
23use crate::{Builder, JoinHandle, ThreadIntent};
24
/// A fixed set of worker threads that execute boxed jobs received over a shared channel.
pub struct Pool {
    // Fields are dropped in declaration order. `job_sender` is listed first so the
    // job channel is disconnected (ending the workers' receive loops) before
    // `_handles` is dropped. NOTE(review): this ordering matters if
    // `crate::JoinHandle`'s `Drop` joins its thread — confirm; the workers are
    // built with `allow_leak(true)`.
    job_sender: Sender<Job>,
    // Not read anywhere in this view; held so the worker handles are owned by,
    // and dropped with, the pool.
    _handles: Box<[JoinHandle]>,
    // Number of jobs currently executing on a worker (incremented just before a
    // job runs, decremented after it returns); reported by `Pool::len`.
    extant_tasks: Arc<AtomicUsize>,
}
37
38struct Job {
39 requested_intent: ThreadIntent,
40 f: Box<dyn FnOnce() + Send + 'static>,
41}
42
43impl Pool {
44 #[must_use]
48 pub fn new(threads: NonZeroUsize) -> Self {
49 const STACK_SIZE: usize = 8 * 1024 * 1024;
50 const INITIAL_INTENT: ThreadIntent = ThreadIntent::Worker;
51
52 let (job_sender, job_receiver) = crossbeam_channel::unbounded();
53 let extant_tasks = Arc::new(AtomicUsize::new(0));
54
55 let mut handles = Vec::with_capacity(threads.into());
56 for idx in 0..threads.into() {
57 let handle = Builder::new(INITIAL_INTENT, format!("squawk:worker:{idx}",))
58 .stack_size(STACK_SIZE)
59 .allow_leak(true)
60 .spawn({
61 let extant_tasks = Arc::clone(&extant_tasks);
62 let job_receiver: Receiver<Job> = job_receiver.clone();
63 move || {
64 let mut current_intent = INITIAL_INTENT;
65 for job in job_receiver {
66 if job.requested_intent != current_intent {
67 job.requested_intent.apply_to_current_thread();
68 current_intent = job.requested_intent;
69 }
70 extant_tasks.fetch_add(1, Ordering::SeqCst);
71
72 if let Err(error) = std::panic::catch_unwind(AssertUnwindSafe(job.f)) {
79 if let Some(msg) = error.downcast_ref::<String>() {
80 tracing::error!("Worker thread panicked with: {msg}; aborting");
81 } else if let Some(msg) = error.downcast_ref::<&str>() {
82 tracing::error!("Worker thread panicked with: {msg}; aborting");
83 } else if let Some(cancelled) =
84 error.downcast_ref::<salsa::Cancelled>()
85 {
86 tracing::error!(
87 "Worker thread got cancelled: {cancelled}; aborting"
88 );
89 } else {
90 tracing::error!(
91 "Worker thread panicked with: {error:?}; aborting"
92 );
93 }
94
95 std::process::abort();
96 }
97
98 extant_tasks.fetch_sub(1, Ordering::SeqCst);
99 }
100 }
101 })
102 .expect("failed to spawn thread");
103
104 handles.push(handle);
105 }
106
107 Self {
108 _handles: handles.into_boxed_slice(),
109 extant_tasks,
110 job_sender,
111 }
112 }
113
114 pub fn spawn<F>(&self, intent: ThreadIntent, f: F)
115 where
116 F: FnOnce() + Send + 'static,
117 {
118 let f = Box::new(move || {
119 if cfg!(debug_assertions) {
120 intent.assert_is_used_on_current_thread();
121 }
122 f();
123 });
124
125 let job = Job {
126 requested_intent: intent,
127 f,
128 };
129 self.job_sender.send(job).unwrap();
130 }
131
132 pub fn scoped<'pool, 'scope, F, R>(&'pool self, f: F) -> R
133 where
134 F: FnOnce(&Scope<'pool, 'scope>) -> R,
135 {
136 let wg = WaitGroup::new();
137 let scope = Scope {
138 pool: self,
139 wg,
140 _marker: PhantomData,
141 };
142 let r = f(&scope);
143 scope.wg.wait();
144 r
145 }
146
147 #[must_use]
148 pub fn len(&self) -> usize {
149 self.extant_tasks.load(Ordering::SeqCst)
150 }
151
152 #[must_use]
153 pub fn is_empty(&self) -> bool {
154 self.len() == 0
155 }
156}
157
/// Handle given to the closure passed to `Pool::scoped`; jobs spawned through it
/// may borrow data that lives for `'scope`.
pub struct Scope<'pool, 'scope> {
    // Pool whose job channel scoped jobs are sent to.
    pool: &'pool Pool,
    // Each spawned job holds a clone of this group; `Pool::scoped` waits on it.
    wg: WaitGroup,
    // `fn(&'scope ()) -> &'scope ()` is contravariant in its argument and covariant
    // in its return, which makes `Scope` invariant in `'scope`. The lifetime
    // therefore cannot be shrunk or extended by subtyping. Function pointers are
    // `Send + Sync`, so this marker does not affect `Scope`'s auto traits.
    _marker: PhantomData<fn(&'scope ()) -> &'scope ()>,
}
163
164impl<'scope> Scope<'_, 'scope> {
165 pub fn spawn<F>(&self, intent: ThreadIntent, f: F)
166 where
167 F: 'scope + FnOnce() + Send + UnwindSafe,
168 {
169 let wg = self.wg.clone();
170 let f = Box::new(move || {
171 if cfg!(debug_assertions) {
172 intent.assert_is_used_on_current_thread();
173 }
174 f();
175 drop(wg);
176 });
177
178 let job = Job {
179 requested_intent: intent,
180 f: unsafe {
181 std::mem::transmute::<
182 Box<dyn 'scope + FnOnce() + Send + UnwindSafe>,
183 Box<dyn 'static + FnOnce() + Send + UnwindSafe>,
184 >(f)
185 },
186 };
187 self.pool.job_sender.send(job).unwrap();
188 }
189}