1use super::range::{
18 FixedRangeFactory, Range, RangeFactory, RangeOrchestrator, WorkStealingRangeFactory,
19};
20use log::{debug, error, warn};
21#[cfg(any(
23 target_os = "android",
24 target_os = "dragonfly",
25 target_os = "freebsd",
26 target_os = "linux"
27))]
28use nix::{
29 sched::{sched_setaffinity, CpuSet},
30 unistd::Pid,
31};
32use std::cell::Cell;
33use std::num::NonZeroUsize;
34use std::sync::atomic::{AtomicUsize, Ordering};
35use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError};
36use std::thread::{Scope, ScopedJoinHandle};
37
/// Status of the main thread, used by workers to signal the end of a round.
#[derive(Clone, Copy, PartialEq, Eq)]
enum MainStatus {
    /// The main thread is waiting for workers to finish the current round.
    Waiting,
    /// All workers finished the round; outputs are ready to collect.
    Ready,
    /// A worker thread panicked while processing the round.
    WorkerPanic,
}
48
/// Status broadcast from the main thread to the worker threads.
#[derive(Clone, Copy, PartialEq, Eq)]
enum WorkerStatus {
    /// Workers should process the round identified by this color.
    Round(RoundColor),
    /// Workers should exit their processing loop.
    Finished,
}
57
/// Alternating tag that distinguishes consecutive rounds, so a worker can
/// tell a newly started round apart from the one it just completed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RoundColor {
    Blue,
    Red,
}
65
66impl RoundColor {
67 fn toggle(&mut self) {
69 *self = match self {
70 RoundColor::Blue => RoundColor::Red,
71 RoundColor::Red => RoundColor::Blue,
72 }
73 }
74}
75
/// A value of type `T` protected by a mutex, paired with a condition variable
/// so that changes to the value can be signaled across threads.
struct Status<T> {
    /// Mutex protecting the current value.
    mutex: Mutex<T>,
    /// Condition variable notified whenever the value is updated.
    condvar: Condvar,
}
81
82impl<T> Status<T> {
83 fn new(t: T) -> Self {
85 Self {
86 mutex: Mutex::new(t),
87 condvar: Condvar::new(),
88 }
89 }
90
91 fn try_notify_one(&self, t: T) -> Result<(), PoisonError<MutexGuard<'_, T>>> {
96 *self.mutex.lock()? = t;
97 self.condvar.notify_one();
98 Ok(())
99 }
100
101 fn notify_one_if(&self, predicate: impl Fn(&T) -> bool, t: T) {
104 let mut locked = self.mutex.lock().unwrap();
105 if predicate(&*locked) {
106 *locked = t;
107 self.condvar.notify_one();
108 }
109 }
110
111 fn notify_all(&self, t: T) {
113 *self.mutex.lock().unwrap() = t;
114 self.condvar.notify_all();
115 }
116
117 fn wait_while(&self, predicate: impl FnMut(&mut T) -> bool) -> MutexGuard<T> {
122 self.condvar
123 .wait_while(self.mutex.lock().unwrap(), predicate)
124 .unwrap()
125 }
126}
127
/// A pool of worker threads that process a slice of inputs in parallel,
/// bound to the lifetime of a [`std::thread::Scope`].
pub struct ThreadPool<'scope, Output> {
    /// Handles to the worker threads together with their output slots.
    threads: Vec<WorkerThreadHandle<'scope, Output>>,
    /// Number of workers that have not yet finished the current round.
    num_active_threads: Arc<AtomicUsize>,
    /// Color of the current (or last started) round.
    round: Cell<RoundColor>,
    /// Status broadcast to the workers (start a round / finish).
    worker_status: Arc<Status<WorkerStatus>>,
    /// Status reported back by the workers (ready / panic).
    main_status: Arc<Status<MainStatus>>,
    /// Orchestrator for the per-thread input ranges.
    range_orchestrator: Box<dyn RangeOrchestrator>,
}
146
/// Handle to one worker thread, together with the slot into which it deposits
/// the output of each round.
struct WorkerThreadHandle<'scope, Output> {
    /// Join handle of the spawned scoped thread.
    handle: ScopedJoinHandle<'scope, ()>,
    /// Output produced by this worker for the current round, if any.
    output: Arc<Mutex<Option<Output>>>,
}
154
/// Strategy used to split the input among the worker threads.
pub enum RangeStrategy {
    /// Each worker processes a fixed range of the input (see
    /// `FixedRangeFactory`).
    Fixed,
    /// Ranges are assigned dynamically via work stealing (see
    /// `WorkStealingRangeFactory`).
    WorkStealing,
}
162
impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> {
    /// Creates a pool of `num_threads` worker threads within `thread_scope`,
    /// splitting `input` among them according to `range_strategy`. Each
    /// worker gets its own accumulator created by `new_accumulator`.
    pub fn new<'env, Input: Sync, Accum: ThreadAccumulator<Input, Output> + Send + 'scope>(
        thread_scope: &'scope Scope<'scope, 'env>,
        num_threads: NonZeroUsize,
        range_strategy: RangeStrategy,
        input: &'env [Input],
        new_accumulator: impl Fn() -> Accum,
    ) -> Self {
        let num_threads: usize = num_threads.into();
        let input_len = input.len();
        // Dispatch to the concrete range factory for the chosen strategy.
        match range_strategy {
            RangeStrategy::Fixed => Self::new_with_factory(
                thread_scope,
                num_threads,
                FixedRangeFactory::new(input_len, num_threads),
                input,
                new_accumulator,
            ),
            RangeStrategy::WorkStealing => Self::new_with_factory(
                thread_scope,
                num_threads,
                WorkStealingRangeFactory::new(input_len, num_threads),
                input,
                new_accumulator,
            ),
        }
    }

    /// Creates the pool given a concrete `range_factory`: sets up the shared
    /// synchronization state and spawns the worker threads, pinning each to a
    /// CPU on platforms where that is supported.
    fn new_with_factory<
        'env,
        RnFactory: RangeFactory,
        Input: Sync,
        Accum: ThreadAccumulator<Input, Output> + Send + 'scope,
    >(
        thread_scope: &'scope Scope<'scope, 'env>,
        num_threads: usize,
        range_factory: RnFactory,
        input: &'env [Input],
        new_accumulator: impl Fn() -> Accum,
    ) -> Self
    where
        RnFactory::Rn: 'scope + Send,
        RnFactory::Orchestrator: 'static,
    {
        // Initial color. `process_inputs()` toggles before broadcasting, so
        // the first actual round is Red; workers mirror the same sequence.
        let color = RoundColor::Blue;
        let num_active_threads = Arc::new(AtomicUsize::new(0));
        let worker_status = Arc::new(Status::new(WorkerStatus::Round(color)));
        let main_status = Arc::new(Status::new(MainStatus::Waiting));

        #[cfg(not(any(
            target_os = "android",
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "linux"
        )))]
        warn!("Pinning threads to CPUs is not implemented on this platform.");
        let threads = (0..num_threads)
            .map(|id| {
                let output = Arc::new(Mutex::new(None));
                // Everything a worker needs, moved into its thread below.
                let context = ThreadContext {
                    id,
                    num_active_threads: num_active_threads.clone(),
                    worker_status: worker_status.clone(),
                    main_status: main_status.clone(),
                    range: range_factory.range(id),
                    input,
                    output: output.clone(),
                    accumulator: new_accumulator(),
                };
                WorkerThreadHandle {
                    handle: thread_scope.spawn(move || {
                        // Best effort: pin worker `id` to CPU `id`. Failures
                        // are logged but not fatal.
                        #[cfg(any(
                            target_os = "android",
                            target_os = "dragonfly",
                            target_os = "freebsd",
                            target_os = "linux"
                        ))]
                        {
                            let mut cpu_set = CpuSet::new();
                            if let Err(e) = cpu_set.set(id) {
                                warn!("Failed to set CPU affinity for thread #{id}: {e}");
                            } else if let Err(e) = sched_setaffinity(Pid::from_raw(0), &cpu_set) {
                                warn!("Failed to set CPU affinity for thread #{id}: {e}");
                            } else {
                                debug!("Pinned thread #{id} to CPU #{id}");
                            }
                        }
                        context.run()
                    }),
                    output,
                }
            })
            .collect();
        debug!("[main thread] Spawned threads");

        Self {
            threads,
            num_active_threads,
            round: Cell::new(color),
            worker_status,
            main_status,
            range_orchestrator: Box::new(range_factory.orchestrator()),
        }
    }

    /// Runs one round: makes every worker process its share of the input,
    /// waits for all of them to finish, and returns an iterator over the
    /// per-worker outputs, in worker order.
    ///
    /// # Panics
    ///
    /// Panics if any worker thread panicked during the round.
    pub fn process_inputs(&self) -> impl Iterator<Item = Output> + '_ {
        self.range_orchestrator.reset_ranges();

        // Every worker is active until it decrements this counter; this must
        // be set before the round is announced.
        let num_threads = self.threads.len();
        self.num_active_threads.store(num_threads, Ordering::SeqCst);

        // Flip the round color so workers can distinguish this round from
        // the previous one.
        let mut round = self.round.get();
        round.toggle();
        self.round.set(round);

        debug!("[main thread, round {round:?}] Ready to accumulate votes.");

        self.worker_status.notify_all(WorkerStatus::Round(round));

        debug!(
            "[main thread, round {round:?}] Waiting for all threads to finish accumulating votes."
        );

        // Block until the last worker (or a panicking worker) flips the
        // status away from Waiting.
        let mut guard = self
            .main_status
            .wait_while(|status| *status == MainStatus::Waiting);
        if *guard == MainStatus::WorkerPanic {
            error!("[main thread] A worker thread panicked!");
            panic!("A worker thread panicked!");
        }
        // Reset for the next round while still holding the lock.
        *guard = MainStatus::Waiting;
        drop(guard);

        debug!("[main thread, round {round:?}] All threads have now finished accumulating votes.");

        // Lazily drain each worker's output slot. Every slot is expected to
        // be Some(_) here because all workers completed the round.
        self.threads
            .iter()
            .map(move |t| t.output.lock().unwrap().take().unwrap())
    }
}
307
308impl<Output> Drop for ThreadPool<'_, Output> {
309 fn drop(&mut self) {
311 debug!("[main thread] Notifying threads to finish...");
312 self.worker_status.notify_all(WorkerStatus::Finished);
313
314 debug!("[main thread] Joining threads in the pool...");
315 for (i, t) in self.threads.drain(..).enumerate() {
316 let result = t.handle.join();
317 match result {
318 Ok(_) => debug!("[main thread] Thread {i} joined with result: {result:?}"),
319 Err(_) => error!("[main thread] Thread {i} joined with result: {result:?}"),
320 }
321 }
322 debug!("[main thread] Joined threads.");
323
324 #[cfg(feature = "log_parallelism")]
325 self.range_orchestrator.print_statistics();
326 }
327}
328
/// Operations a worker thread performs on its share of the input: create an
/// accumulator, feed items into it, and turn it into an output.
pub trait ThreadAccumulator<Input, Output> {
    /// Type of the per-round accumulator state.
    type Accumulator<'a>
    where
        Self: 'a;

    /// Creates a fresh accumulator for a new round.
    fn init(&self) -> Self::Accumulator<'_>;

    /// Folds the `item` found at position `index` of the input into
    /// `accumulator`.
    fn process_item<'a>(
        &'a self,
        accumulator: &mut Self::Accumulator<'a>,
        index: usize,
        item: &Input,
    );

    /// Consumes the accumulator and produces this round's output.
    fn finalize<'a>(&'a self, accumulator: Self::Accumulator<'a>) -> Output;
}
350
/// Everything a single worker thread needs to participate in rounds.
struct ThreadContext<'env, Rn: Range, Input, Output, Accum: ThreadAccumulator<Input, Output>> {
    /// Index of this worker in the pool (also the CPU it is pinned to, on
    /// platforms where pinning is implemented).
    id: usize,
    /// Shared counter of workers still active in the current round.
    num_active_threads: Arc<AtomicUsize>,
    /// Status broadcast by the main thread.
    worker_status: Arc<Status<WorkerStatus>>,
    /// Status reported back to the main thread.
    main_status: Arc<Status<MainStatus>>,
    /// Input indices assigned to this worker.
    range: Rn,
    /// The full input slice; this worker only reads the indices in `range`.
    input: &'env [Input],
    /// Slot where this worker deposits its output each round.
    output: Arc<Mutex<Option<Output>>>,
    /// Accumulator implementation used to process items.
    accumulator: Accum,
}
370
impl<Rn: Range, Input, Output, Accum: ThreadAccumulator<Input, Output>>
    ThreadContext<'_, Rn, Input, Output, Accum>
{
    /// Main loop of a worker thread: wait for each round to start, process
    /// this worker's range of the input, publish the output, and notify the
    /// main thread when this is the last worker to finish. Returns when the
    /// main thread broadcasts [`WorkerStatus::Finished`].
    fn run(&self) {
        // Mirrors the main thread's round color: toggled at the top of each
        // iteration so it matches the color announced for the next round.
        let mut round = RoundColor::Blue;
        loop {
            round.toggle();
            debug!(
                "[thread {}, round {round:?}] Waiting for start signal",
                self.id
            );

            // Sleep until either the expected round starts or we are told to
            // finish.
            let worker_status: WorkerStatus =
                *self.worker_status.wait_while(|status| match status {
                    WorkerStatus::Finished => false,
                    WorkerStatus::Round(r) => *r != round,
                });
            match worker_status {
                WorkerStatus::Finished => {
                    debug!(
                        "[thread {}, round {round:?}] Received finish signal",
                        self.id
                    );
                    break;
                }
                WorkerStatus::Round(r) => {
                    assert_eq!(round, r);
                    debug!(
                        "[thread {}, round {round:?}] Received start signal. Processing...",
                        self.id
                    );

                    // Armed guard: if any of the processing below panics, the
                    // guard's Drop reports the panic to the main thread.
                    let panic_notifier = PanicNotifier {
                        id: self.id,
                        main_status: &self.main_status,
                    };
                    {
                        let mut accumulator = self.accumulator.init();
                        for i in self.range.iter() {
                            self.accumulator
                                .process_item(&mut accumulator, i, &self.input[i]);
                        }
                        *self.output.lock().unwrap() = Some(self.accumulator.finalize(accumulator));
                    }
                    // Processing succeeded: disarm the guard so its Drop (the
                    // panic report) never runs.
                    std::mem::forget(panic_notifier);

                    // fetch_sub returns the previous value, so a result of 1
                    // means this was the last active worker.
                    let thread_count = self.num_active_threads.fetch_sub(1, Ordering::SeqCst);
                    assert!(thread_count > 0);
                    debug!(
                        "[thread {}, round {round:?}] Decremented the counter: {}.",
                        self.id,
                        thread_count - 1
                    );
                    if thread_count == 1 {
                        debug!(
                            "[thread {}, round {round:?}] We're the last thread. Notifying the main thread.",
                            self.id
                        );

                        // Only flip the status if the main thread is still
                        // Waiting, so a WorkerPanic report from another
                        // thread is not overwritten.
                        self.main_status.notify_one_if(
                            |&status| status == MainStatus::Waiting,
                            MainStatus::Ready,
                        );

                        debug!(
                            "[thread {}, round {round:?}] Notified the main thread.",
                            self.id
                        );
                    } else {
                        debug!(
                            "[thread {}, round {round:?}] Waiting for other threads to finish.",
                            self.id
                        );
                    }
                }
            }
        }
    }
}
454
/// Guard that reports a panic to the main thread when dropped. A worker
/// creates one before processing and `std::mem::forget`s it on success, so
/// its `Drop` only runs if the worker unwinds due to a panic.
struct PanicNotifier<'a> {
    /// Identifier of the worker thread owning this guard.
    id: usize,
    /// Status through which the panic is reported to the main thread.
    main_status: &'a Status<MainStatus>,
}
469
470impl Drop for PanicNotifier<'_> {
471 fn drop(&mut self) {
472 error!(
473 "[thread {}] Detected panic in this thread, notifying the main thread",
474 self.id
475 );
476 if let Err(e) = self.main_status.try_notify_one(MainStatus::WorkerPanic) {
477 error!(
478 "[thread {}] Failed to notify the main thread, the mutex was poisoned: {e:?}",
479 self.id
480 );
481 }
482 }
483}