stv_rs/parallelism/thread_pool.rs
1// Copyright 2023 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! A hand-rolled thread pool, customized for the vote counting problem.
16
17use super::range::{
18    FixedRangeFactory, Range, RangeFactory, RangeOrchestrator, WorkStealingRangeFactory,
19};
20use log::{debug, error, warn};
21// Platforms that support `libc::sched_setaffinity()`.
22#[cfg(any(
23    target_os = "android",
24    target_os = "dragonfly",
25    target_os = "freebsd",
26    target_os = "linux"
27))]
28use nix::{
29    sched::{sched_setaffinity, CpuSet},
30    unistd::Pid,
31};
32use std::cell::Cell;
33use std::num::NonZeroUsize;
34use std::sync::atomic::{AtomicUsize, Ordering};
35use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError};
36use std::thread::{Scope, ScopedJoinHandle};
37
/// Status of the main thread.
///
/// Derives [`Debug`] so the status can be rendered in diagnostics, consistent
/// with `RoundColor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MainStatus {
    /// The main thread is waiting for the worker threads to finish a round.
    Waiting,
    /// The main thread is ready to prepare the next round.
    Ready,
    /// One of the worker threads panicked.
    WorkerPanic,
}
48
49/// Status sent to the worker threads.
50#[derive(Clone, Copy, PartialEq, Eq)]
51enum WorkerStatus {
52    /// The threads need to compute a vote counting round of the given color.
53    Round(RoundColor),
54    /// There is nothing more to do and the threads must exit.
55    Finished,
56}
57
/// A 2-element enumeration used to tell successive rounds apart. The "colors"
/// carry no meaning beyond alternation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RoundColor {
    Blue,
    Red,
}

impl RoundColor {
    /// Switches this color to the opposite one (Blue <-> Red).
    fn toggle(&mut self) {
        *self = if matches!(self, RoundColor::Blue) {
            RoundColor::Red
        } else {
            RoundColor::Blue
        };
    }
}
75
76/// An ergonomic wrapper around a [`Mutex`]-[`Condvar`] pair.
77struct Status<T> {
78    mutex: Mutex<T>,
79    condvar: Condvar,
80}
81
82impl<T> Status<T> {
83    /// Creates a new status initialized with the given value.
84    fn new(t: T) -> Self {
85        Self {
86            mutex: Mutex::new(t),
87            condvar: Condvar::new(),
88        }
89    }
90
91    /// Attempts to set the status to the given value and notifies one waiting
92    /// thread.
93    ///
94    /// Fails if the [`Mutex`] is poisoned.
95    fn try_notify_one(&self, t: T) -> Result<(), PoisonError<MutexGuard<'_, T>>> {
96        *self.mutex.lock()? = t;
97        self.condvar.notify_one();
98        Ok(())
99    }
100
101    /// If the predicate is true on this status, sets the status to the given
102    /// value and notifies one waiting thread.
103    fn notify_one_if(&self, predicate: impl Fn(&T) -> bool, t: T) {
104        let mut locked = self.mutex.lock().unwrap();
105        if predicate(&*locked) {
106            *locked = t;
107            self.condvar.notify_one();
108        }
109    }
110
111    /// Sets the status to the given value and notifies all waiting threads.
112    fn notify_all(&self, t: T) {
113        *self.mutex.lock().unwrap() = t;
114        self.condvar.notify_all();
115    }
116
117    /// Waits until the predicate is true on this status.
118    ///
119    /// This returns a [`MutexGuard`], allowing to further inspect or modify the
120    /// status.
121    fn wait_while(&self, predicate: impl FnMut(&mut T) -> bool) -> MutexGuard<T> {
122        self.condvar
123            .wait_while(self.mutex.lock().unwrap(), predicate)
124            .unwrap()
125    }
126}
127
/// A thread pool tied to a scope, that can process inputs into the given output
/// type.
///
/// The pool synchronizes with its workers through two [`Status`] objects:
/// `worker_status` broadcasts round starts (and the exit signal) to the
/// workers, while `main_status` carries round completion — or a worker panic —
/// back to the main thread.
pub struct ThreadPool<'scope, Output> {
    /// Handles to all the worker threads in the pool.
    threads: Vec<WorkerThreadHandle<'scope, Output>>,
    /// Number of worker threads active in the current round. The worker that
    /// decrements this to zero is responsible for notifying the main thread.
    num_active_threads: Arc<AtomicUsize>,
    /// Color of the current round, toggled before each round starts.
    round: Cell<RoundColor>,
    /// Status of the worker threads.
    worker_status: Arc<Status<WorkerStatus>>,
    /// Status of the main thread.
    main_status: Arc<Status<MainStatus>>,
    /// Orchestrator for the work ranges distributed to the threads. This is a
    /// dynamic object to avoid making the range type a parameter of
    /// everything.
    range_orchestrator: Box<dyn RangeOrchestrator>,
}
146
/// Handle to a worker thread in the pool.
struct WorkerThreadHandle<'scope, Output> {
    /// Thread handle object.
    handle: ScopedJoinHandle<'scope, ()>,
    /// Storage for this thread's computation output. Set to `Some(_)` by the
    /// worker at the end of each round, and taken back by the main thread.
    output: Arc<Mutex<Option<Output>>>,
}
154
/// Strategy to distribute ranges of work items among threads.
///
/// Derives the common traits expected of a small public configuration enum
/// (`Debug`, `Clone`, `Copy`, `PartialEq`, `Eq`); all are purely additive for
/// callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RangeStrategy {
    /// Each thread processes a fixed range of items.
    Fixed,
    /// Threads can steal work from each other.
    WorkStealing,
}
162
impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> {
    /// Creates a new pool tied to the given scope, spawning the given number of
    /// threads and using the given input slice.
    ///
    /// The `range_strategy` selects how input indices are split among the
    /// threads, and `new_accumulator` is called once per worker thread to
    /// create that thread's [`ThreadAccumulator`].
    pub fn new<'env, Input: Sync, Accum: ThreadAccumulator<Input, Output> + Send + 'scope>(
        thread_scope: &'scope Scope<'scope, 'env>,
        num_threads: NonZeroUsize,
        range_strategy: RangeStrategy,
        input: &'env [Input],
        new_accumulator: impl Fn() -> Accum,
    ) -> Self {
        let num_threads: usize = num_threads.into();
        let input_len = input.len();
        // Dispatch to the generic constructor, monomorphized over the concrete
        // range factory for the requested strategy.
        match range_strategy {
            RangeStrategy::Fixed => Self::new_with_factory(
                thread_scope,
                num_threads,
                FixedRangeFactory::new(input_len, num_threads),
                input,
                new_accumulator,
            ),
            RangeStrategy::WorkStealing => Self::new_with_factory(
                thread_scope,
                num_threads,
                WorkStealingRangeFactory::new(input_len, num_threads),
                input,
                new_accumulator,
            ),
        }
    }

    /// Creates a new pool, using the given range factory to derive the per-
    /// thread work ranges and the shared range orchestrator.
    fn new_with_factory<
        'env,
        RnFactory: RangeFactory,
        Input: Sync,
        Accum: ThreadAccumulator<Input, Output> + Send + 'scope,
    >(
        thread_scope: &'scope Scope<'scope, 'env>,
        num_threads: usize,
        range_factory: RnFactory,
        input: &'env [Input],
        new_accumulator: impl Fn() -> Accum,
    ) -> Self
    where
        RnFactory::Rn: 'scope + Send,
        RnFactory::Orchestrator: 'static,
    {
        // Workers toggle their expected color *before* waiting, so with this
        // Blue initial status they block until `process_inputs()` toggles the
        // color and announces the first (Red) round.
        let color = RoundColor::Blue;
        let num_active_threads = Arc::new(AtomicUsize::new(0));
        let worker_status = Arc::new(Status::new(WorkerStatus::Round(color)));
        let main_status = Arc::new(Status::new(MainStatus::Waiting));

        #[cfg(not(any(
            target_os = "android",
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "linux"
        )))]
        warn!("Pinning threads to CPUs is not implemented on this platform.");
        let threads = (0..num_threads)
            .map(|id| {
                let output = Arc::new(Mutex::new(None));
                let context = ThreadContext {
                    id,
                    num_active_threads: num_active_threads.clone(),
                    worker_status: worker_status.clone(),
                    main_status: main_status.clone(),
                    range: range_factory.range(id),
                    input,
                    output: output.clone(),
                    accumulator: new_accumulator(),
                };
                WorkerThreadHandle {
                    handle: thread_scope.spawn(move || {
                        // Best effort: pin worker #id to CPU #id. Failures are
                        // logged but not fatal.
                        #[cfg(any(
                            target_os = "android",
                            target_os = "dragonfly",
                            target_os = "freebsd",
                            target_os = "linux"
                        ))]
                        {
                            let mut cpu_set = CpuSet::new();
                            if let Err(e) = cpu_set.set(id) {
                                warn!("Failed to set CPU affinity for thread #{id}: {e}");
                            } else if let Err(e) = sched_setaffinity(Pid::from_raw(0), &cpu_set) {
                                // Pid 0 applies the affinity to the calling
                                // thread itself.
                                warn!("Failed to set CPU affinity for thread #{id}: {e}");
                            } else {
                                debug!("Pinned thread #{id} to CPU #{id}");
                            }
                        }
                        context.run()
                    }),
                    output,
                }
            })
            .collect();
        debug!("[main thread] Spawned threads");

        Self {
            threads,
            num_active_threads,
            round: Cell::new(color),
            worker_status,
            main_status,
            range_orchestrator: Box::new(range_factory.orchestrator()),
        }
    }

    /// Performs a computation round, processing the input slice in parallel and
    /// returning an iterator over the threads' outputs.
    ///
    /// # Panics
    ///
    /// Panics if any worker thread panicked during this round.
    pub fn process_inputs(&self) -> impl Iterator<Item = Output> + '_ {
        self.range_orchestrator.reset_ranges();

        let num_threads = self.threads.len();
        self.num_active_threads.store(num_threads, Ordering::SeqCst);

        // Toggle the round color and broadcast it: this is the start signal
        // every worker is waiting for.
        let mut round = self.round.get();
        round.toggle();
        self.round.set(round);

        debug!("[main thread, round {round:?}] Ready to accumulate votes.");

        self.worker_status.notify_all(WorkerStatus::Round(round));

        debug!(
            "[main thread, round {round:?}] Waiting for all threads to finish accumulating votes."
        );

        // Block until the last worker flips the status to `Ready`, or a
        // panicking worker's `PanicNotifier` flips it to `WorkerPanic`.
        let mut guard = self
            .main_status
            .wait_while(|status| *status == MainStatus::Waiting);
        if *guard == MainStatus::WorkerPanic {
            error!("[main thread] A worker thread panicked!");
            panic!("A worker thread panicked!");
        }
        // Re-arm the status for the next round before releasing the lock.
        *guard = MainStatus::Waiting;
        drop(guard);

        debug!("[main thread, round {round:?}] All threads have now finished accumulating votes.");

        // Each worker stored `Some(output)` during the round; drain the
        // outputs lazily, one per worker thread.
        self.threads
            .iter()
            .map(move |t| t.output.lock().unwrap().take().unwrap())
    }
}
307
impl<Output> Drop for ThreadPool<'_, Output> {
    /// Joins all the threads in the pool.
    ///
    /// Workers are first told to exit via [`WorkerStatus::Finished`], then
    /// every handle is joined. A join error (i.e. a panicked worker) is logged
    /// but not propagated.
    fn drop(&mut self) {
        debug!("[main thread] Notifying threads to finish...");
        self.worker_status.notify_all(WorkerStatus::Finished);

        debug!("[main thread] Joining threads in the pool...");
        for (i, t) in self.threads.drain(..).enumerate() {
            let result = t.handle.join();
            match result {
                Ok(_) => debug!("[main thread] Thread {i} joined with result: {result:?}"),
                Err(_) => error!("[main thread] Thread {i} joined with result: {result:?}"),
            }
        }
        debug!("[main thread] Joined threads.");

        // Work-distribution statistics are only printed when the
        // `log_parallelism` feature is enabled.
        #[cfg(feature = "log_parallelism")]
        self.range_orchestrator.print_statistics();
    }
}
328
/// Trait representing a function to map and reduce inputs into an output.
pub trait ThreadAccumulator<Input, Output> {
    /// Type to accumulate inputs into.
    ///
    /// This is a generic associated type so that the accumulator may borrow
    /// from `self` for the duration of a round.
    type Accumulator<'a>
    where
        Self: 'a;

    /// Creates a new accumulator to process inputs.
    fn init(&self) -> Self::Accumulator<'_>;

    /// Accumulates the given input item.
    ///
    /// - `accumulator`: partial result, updated in place,
    /// - `index`: position of `item` in the input slice,
    /// - `item`: the input item to accumulate.
    fn process_item<'a>(
        &'a self,
        accumulator: &mut Self::Accumulator<'a>,
        index: usize,
        item: &Input,
    );

    /// Converts the given accumulator into an output.
    fn finalize<'a>(&'a self, accumulator: Self::Accumulator<'a>) -> Output;
}
350
/// Context object owned by a worker thread.
///
/// The `Arc`-wrapped fields are shared with the owning [`ThreadPool`], which
/// uses them to coordinate rounds with this worker.
struct ThreadContext<'env, Rn: Range, Input, Output, Accum: ThreadAccumulator<Input, Output>> {
    /// Thread index.
    id: usize,
    /// Number of worker threads active in the current round.
    num_active_threads: Arc<AtomicUsize>,
    /// Status of the worker threads.
    worker_status: Arc<Status<WorkerStatus>>,
    /// Status of the main thread.
    main_status: Arc<Status<MainStatus>>,
    /// Range of items that this worker thread needs to process.
    range: Rn,
    /// Reference to the inputs to process.
    input: &'env [Input],
    /// Output that this thread writes to at the end of each round.
    output: Arc<Mutex<Option<Output>>>,
    /// Function to map and reduce inputs into the output.
    accumulator: Accum,
}
370
impl<Rn: Range, Input, Output, Accum: ThreadAccumulator<Input, Output>>
    ThreadContext<'_, Rn, Input, Output, Accum>
{
    /// Main function run by this thread.
    ///
    /// Alternates between waiting for a start signal of the expected round
    /// color and processing this thread's range of inputs, until the
    /// [`WorkerStatus::Finished`] signal is received.
    fn run(&self) {
        let mut round = RoundColor::Blue;
        loop {
            // Toggle first: the pool is created with a Blue status, so the
            // first round actually processed is Red.
            round.toggle();
            debug!(
                "[thread {}, round {round:?}] Waiting for start signal",
                self.id
            );

            // Block until either the expected round color is announced, or the
            // pool asks all workers to exit.
            let worker_status: WorkerStatus =
                *self.worker_status.wait_while(|status| match status {
                    WorkerStatus::Finished => false,
                    WorkerStatus::Round(r) => *r != round,
                });
            match worker_status {
                WorkerStatus::Finished => {
                    debug!(
                        "[thread {}, round {round:?}] Received finish signal",
                        self.id
                    );
                    break;
                }
                WorkerStatus::Round(r) => {
                    assert_eq!(round, r);
                    debug!(
                        "[thread {}, round {round:?}] Received start signal. Processing...",
                        self.id
                    );

                    // Counting votes may panic, and we want to notify the main thread in that case
                    // to avoid a deadlock.
                    let panic_notifier = PanicNotifier {
                        id: self.id,
                        main_status: &self.main_status,
                    };
                    {
                        // Accumulate this thread's range of inputs and publish
                        // the finalized result to the shared output slot.
                        let mut accumulator = self.accumulator.init();
                        for i in self.range.iter() {
                            self.accumulator
                                .process_item(&mut accumulator, i, &self.input[i]);
                        }
                        *self.output.lock().unwrap() = Some(self.accumulator.finalize(accumulator));
                    }
                    // No panic happened: disarm the notifier so that its
                    // destructor doesn't run.
                    std::mem::forget(panic_notifier);

                    // `fetch_sub` returns the previous value, so a result of 1
                    // means this thread was the last active one in the round.
                    let thread_count = self.num_active_threads.fetch_sub(1, Ordering::SeqCst);
                    assert!(thread_count > 0);
                    debug!(
                        "[thread {}, round {round:?}] Decremented the counter: {}.",
                        self.id,
                        thread_count - 1
                    );
                    if thread_count == 1 {
                        // We're the last thread.
                        debug!(
                            "[thread {}, round {round:?}] We're the last thread. Notifying the main thread.",
                            self.id
                        );

                        // Only notify if the main thread is still `Waiting`, so
                        // that a `WorkerPanic` status set by another thread is
                        // never overwritten.
                        self.main_status.notify_one_if(
                            |&status| status == MainStatus::Waiting,
                            MainStatus::Ready,
                        );

                        debug!(
                            "[thread {}, round {round:?}] Notified the main thread.",
                            self.id
                        );
                    } else {
                        debug!(
                            "[thread {}, round {round:?}] Waiting for other threads to finish.",
                            self.id
                        );
                    }
                }
            }
        }
    }
}
454
/// Object whose destructor notifies the main thread that a panic happened.
///
/// The way to use this is to create an instance before a section that may
/// panic, and to [`std::mem::forget()`] it at the end of the section. That way:
/// - If a panic happens, the [`std::mem::forget()`] call will be skipped but
///   the destructor will run due to RAII (unwinding drops live locals).
/// - If no panic happens, the destructor won't run because this object will be
///   forgotten.
struct PanicNotifier<'a> {
    /// Thread index.
    id: usize,
    /// Status of the main thread.
    main_status: &'a Status<MainStatus>,
}
469
470impl Drop for PanicNotifier<'_> {
471    fn drop(&mut self) {
472        error!(
473            "[thread {}] Detected panic in this thread, notifying the main thread",
474            self.id
475        );
476        if let Err(e) = self.main_status.try_notify_one(MainStatus::WorkerPanic) {
477            error!(
478                "[thread {}] Failed to notify the main thread, the mutex was poisoned: {e:?}",
479                self.id
480            );
481        }
482    }
483}