stv_rs/parallelism/
mod.rs1mod range;
18mod thread_pool;
19
20pub use thread_pool::{RangeStrategy, ThreadAccumulator, ThreadPool};
21
#[cfg(test)]
mod test {
    use super::*;
    use std::num::NonZeroUsize;

    /// Largest integer fed to the thread pool; the inputs are `0..=LAST_INPUT`.
    const LAST_INPUT: u64 = 10_000;

    /// Sum of every integer in `0..=LAST_INPUT` (i.e. `5_000 * 10_001`).
    const EXPECTED_SUM: u64 = LAST_INPUT * (LAST_INPUT + 1) / 2;

    /// Accumulator that adds up the `u64` items it is handed.
    struct SumAccumulator;

    impl ThreadAccumulator<u64, u64> for SumAccumulator {
        type Accumulator<'a> = u64;

        /// A fresh accumulation starts from zero.
        fn init(&self) -> u64 {
            0
        }

        /// Adds `item` to the running total; the item's index is not used.
        fn process_item(&self, total: &mut u64, _index: usize, item: &u64) {
            *total += item;
        }

        /// The output is the running total itself.
        fn finalize(&self, total: u64) -> u64 {
            total
        }
    }

    /// Sums `0..=LAST_INPUT` on a 4-thread pool using `range_strategy`, and checks the
    /// combined result against the closed-form sum.
    fn test_sum_integers(range_strategy: RangeStrategy) {
        let input: Vec<u64> = (0..=LAST_INPUT).collect();
        let num_threads = NonZeroUsize::new(4).unwrap();

        let sum = std::thread::scope(|scope| {
            let pool =
                ThreadPool::new(scope, num_threads, range_strategy, &input, || SumAccumulator);
            // Combine all the values yielded by the pool into a single total.
            pool.process_inputs()
                .reduce(|acc, partial| acc + partial)
                .unwrap()
        });

        assert_eq!(sum, EXPECTED_SUM);
    }

    /// Same as [`test_sum_integers`], but calls `process_inputs()` twice on one pool and
    /// checks that both passes produce the expected sum.
    fn test_sum_twice(range_strategy: RangeStrategy) {
        let input: Vec<u64> = (0..=LAST_INPUT).collect();
        let num_threads = NonZeroUsize::new(4).unwrap();

        let (first, second) = std::thread::scope(|scope| {
            let pool =
                ThreadPool::new(scope, num_threads, range_strategy, &input, || SumAccumulator);
            let first = pool
                .process_inputs()
                .reduce(|acc, partial| acc + partial)
                .unwrap();
            // Run the very same pool a second time over the same input.
            let second = pool
                .process_inputs()
                .reduce(|acc, partial| acc + partial)
                .unwrap();
            (first, second)
        });

        assert_eq!(first, EXPECTED_SUM);
        assert_eq!(second, EXPECTED_SUM);
    }

    /// Generates a sub-module named `$suite` containing one `#[test]` per listed case.
    /// Each generated test forwards `$strategy` to the function of the same name defined
    /// in the enclosing module.
    macro_rules! parallelism_tests {
        ( $suite:ident, $strategy:expr, $( $name:ident ),+ $(,)? ) => {
            mod $suite {
                use super::*;

                $(
                    #[test]
                    fn $name() {
                        // `super::` is required: a bare `$name` would resolve to this
                        // wrapper itself rather than to the parent's test body.
                        super::$name($strategy);
                    }
                )+
            }
        };
    }

    /// Instantiates every test case of this module for the given range strategy.
    macro_rules! all_parallelism_tests {
        ( $suite:ident, $strategy:expr ) => {
            parallelism_tests!($suite, $strategy, test_sum_integers, test_sum_twice);
        };
    }

    all_parallelism_tests!(fixed, RangeStrategy::Fixed);
    all_parallelism_tests!(work_stealing, RangeStrategy::WorkStealing);
}