stv_rs/parallelism/
mod.rs

1// Copyright 2023 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Hand-rolled parallelism utilities for vote counting.
16
17mod range;
18mod thread_pool;
19
20pub use thread_pool::{RangeStrategy, ThreadAccumulator, ThreadPool};
21
#[cfg(test)]
mod test {
    use super::*;
    use std::num::NonZeroUsize;

    /// Largest value of the test input, which is the range `0..=INPUT_MAX`.
    const INPUT_MAX: u64 = 10_000;

    /// Closed-form sum of `0..=INPUT_MAX`, used as an oracle that is
    /// independent of the thread pool under test.
    const EXPECTED_SUM: u64 = INPUT_MAX * (INPUT_MAX + 1) / 2;

    /// Number of worker threads used by the multi-threaded test cases.
    const NUM_THREADS: usize = 4;

    /// Example of accumulator that computes a sum of integers.
    struct SumAccumulator;

    impl ThreadAccumulator<u64, u64> for SumAccumulator {
        type Accumulator<'a> = u64;

        fn init(&self) -> u64 {
            0
        }

        fn process_item(&self, accumulator: &mut u64, _index: usize, x: &u64) {
            *accumulator += *x;
        }

        fn finalize(&self, accumulator: u64) -> u64 {
            accumulator
        }
    }

    /// Generates a module named `$mod` containing one `#[test]` per listed
    /// `$case`, each of which calls the same-named helper function of this
    /// module with `$range_strategy`.
    macro_rules! parallelism_tests {
        ( $mod:ident, $range_strategy:expr, $($case:ident,)+ ) => {
            mod $mod {
                use super::*;

                $(
                #[test]
                fn $case() {
                    // The generated test shadows the helper's name inside this
                    // module, so the helper is reached through a qualified
                    // path. `super` is relative, so this keeps working if the
                    // parent module is moved or renamed.
                    super::$case($range_strategy);
                }
                )+
            }
        };
    }

    /// Instantiates every test case below for the given range strategy.
    macro_rules! all_parallelism_tests {
        ( $mod:ident, $range_strategy:expr ) => {
            parallelism_tests!(
                $mod,
                $range_strategy,
                test_sum_integers,
                test_sum_twice,
                test_sum_single_thread,
            );
        };
    }

    all_parallelism_tests!(fixed, RangeStrategy::Fixed);
    all_parallelism_tests!(work_stealing, RangeStrategy::WorkStealing);

    /// Builds a thread pool with `num_threads` workers over the input
    /// `0..=INPUT_MAX` and processes that input `passes` times on the same
    /// pool. It returns one total per pass, obtained by adding up the values
    /// yielded by `process_inputs()`.
    ///
    /// # Panics
    ///
    /// Panics if `num_threads` is zero, or if a pass yields no values.
    fn sum_on_pool(range_strategy: RangeStrategy, num_threads: usize, passes: usize) -> Vec<u64> {
        let input = (0..=INPUT_MAX).collect::<Vec<u64>>();
        let num_threads =
            NonZeroUsize::new(num_threads).expect("test thread count must be non-zero");
        std::thread::scope(|scope| {
            let thread_pool = ThreadPool::new(scope, num_threads, range_strategy, &input, || {
                SumAccumulator
            });
            // All passes run before the pool is dropped at the end of the scope.
            (0..passes)
                .map(|_| thread_pool.process_inputs().reduce(|a, b| a + b).unwrap())
                .collect()
        })
    }

    fn test_sum_integers(range_strategy: RangeStrategy) {
        assert_eq!(sum_on_pool(range_strategy, NUM_THREADS, 1), [EXPECTED_SUM]);
    }

    fn test_sum_twice(range_strategy: RangeStrategy) {
        // The same input can be processed multiple times on the thread pool.
        assert_eq!(
            sum_on_pool(range_strategy, NUM_THREADS, 2),
            [EXPECTED_SUM, EXPECTED_SUM]
        );
    }

    fn test_sum_single_thread(range_strategy: RangeStrategy) {
        // Boundary case: a pool with the minimum allowed number of threads.
        assert_eq!(sum_on_pool(range_strategy, 1, 2), [EXPECTED_SUM, EXPECTED_SUM]);
    }
}