1use std::sync::Arc;
2
3use rayon::{ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder, prelude::*};
4
/// Controls how an [`ExecutionContext`] schedules a batch of work.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ExecutionOptions {
    /// When `false`, batches are processed sequentially on the calling thread.
    pub parallel: bool,
    /// Thread count for a dedicated pool; only consulted when `parallel` is
    /// `true`. `None` means no dedicated pool is built.
    pub jobs: Option<usize>,
}

impl Default for ExecutionOptions {
    /// Parallel execution with no dedicated pool (`parallel: true`, `jobs: None`).
    fn default() -> Self {
        let jobs = None;
        ExecutionOptions {
            jobs,
            parallel: true,
        }
    }
}
19
20#[derive(Clone)]
21pub struct ExecutionContext {
22 options: ExecutionOptions,
23 pool: Option<Arc<ThreadPool>>,
24}
25
26impl ExecutionContext {
27 pub fn new(options: ExecutionOptions) -> Result<Self, ThreadPoolBuildError> {
28 let pool = match (options.parallel, options.jobs) {
29 (true, Some(jobs)) => Some(Arc::new(
30 ThreadPoolBuilder::new().num_threads(jobs).build()?,
31 )),
32 _ => None,
33 };
34
35 Ok(Self { options, pool })
36 }
37
38 pub fn serial() -> Self {
39 Self {
40 options: ExecutionOptions {
41 parallel: false,
42 jobs: Some(1),
43 },
44 pool: None,
45 }
46 }
47
48 pub fn options(&self) -> ExecutionOptions {
49 self.options
50 }
51
52 pub fn map<T, R, E, F>(&self, items: Vec<T>, map_item: F) -> Result<Vec<R>, E>
53 where
54 T: Send,
55 R: Send,
56 E: Send,
57 F: Fn(T) -> Result<R, E> + Send + Sync,
58 {
59 if !self.options.parallel || items.len() <= 1 {
60 return items.into_iter().map(map_item).collect();
61 }
62
63 match &self.pool {
64 Some(pool) => pool.install(|| items.into_par_iter().map(map_item).collect()),
65 None => items.into_par_iter().map(map_item).collect(),
66 }
67 }
68}
69
70impl Default for ExecutionContext {
71 fn default() -> Self {
72 Self::new(ExecutionOptions::default())
73 .expect("default execution context must use a valid thread pool configuration")
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serial_context_preserves_order() {
        let context = ExecutionContext::serial();

        let values = context
            .map(vec![1, 2, 3], |value| Ok::<_, ()>(value * 2))
            .unwrap();

        assert_eq!(values, vec![2, 4, 6]);
    }

    #[test]
    fn parallel_context_preserves_order() {
        let context = ExecutionContext::new(ExecutionOptions {
            parallel: true,
            jobs: Some(2),
        })
        .unwrap();

        let values = context
            .map(vec![1, 2, 3], |value| Ok::<_, ()>(value * 2))
            .unwrap();

        assert_eq!(values, vec![2, 4, 6]);
    }

    #[test]
    fn global_pool_context_preserves_order() {
        // The default options carry `jobs: None`, so no dedicated pool is
        // built and `map` takes the rayon global-pool branch.
        let context = ExecutionContext::default();
        let input: Vec<u64> = (0..64).collect();
        let expected: Vec<u64> = input.iter().map(|value| value * 2).collect();

        let values = context
            .map(input, |value| Ok::<_, ()>(value * 2))
            .unwrap();

        assert_eq!(values, expected);
    }

    #[test]
    fn empty_and_single_item_inputs() {
        let context = ExecutionContext::default();

        let empty = context
            .map(Vec::<i32>::new(), |value| Ok::<_, ()>(value * 2))
            .unwrap();
        assert!(empty.is_empty());

        let single = context
            .map(vec![7], |value| Ok::<_, ()>(value + 1))
            .unwrap();
        assert_eq!(single, vec![8]);
    }

    #[test]
    fn serial_context_returns_first_error_in_input_order() {
        let context = ExecutionContext::serial();

        let result = context.map(vec![1, 2, 3, 4], |value| {
            if value >= 2 {
                Err(value)
            } else {
                Ok(value)
            }
        });

        assert_eq!(result, Err(2));
    }

    #[test]
    fn parallel_context_propagates_an_error() {
        let context = ExecutionContext::new(ExecutionOptions {
            parallel: true,
            jobs: Some(2),
        })
        .unwrap();

        let result = context.map(vec![1, 2, 3, 4], |value| {
            if value % 2 == 0 {
                Err(value)
            } else {
                Ok(value)
            }
        });

        // rayon does not pin down which failing item's error is returned.
        assert!(matches!(result, Err(2) | Err(4)));
    }

    #[test]
    fn options_reports_construction_settings() {
        let options = ExecutionOptions {
            parallel: true,
            jobs: Some(3),
        };
        let context = ExecutionContext::new(options).unwrap();
        assert_eq!(context.options(), options);

        assert_eq!(
            ExecutionContext::serial().options(),
            ExecutionOptions {
                parallel: false,
                jobs: Some(1),
            }
        );
    }
}