Skip to main content

sora_execution/
lib.rs

1use std::sync::Arc;
2
3use rayon::{ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder, prelude::*};
4
/// Caller-facing knobs that decide how an [`ExecutionContext`] runs work.
///
/// This is a small `Copy` value type; it derives `Hash` alongside `Eq` so it
/// can be used directly as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ExecutionOptions {
    /// When `false`, [`ExecutionContext::map`] always processes items one
    /// after another on the calling thread.
    pub parallel: bool,
    /// Thread count handed to the dedicated pool that
    /// [`ExecutionContext::new`] builds when `parallel` is `true`.
    ///
    /// `None` means no dedicated pool is built.
    pub jobs: Option<usize>,
}
10
11impl Default for ExecutionOptions {
12    fn default() -> Self {
13        Self {
14            parallel: true,
15            jobs: None,
16        }
17    }
18}
19
20#[derive(Clone)]
21pub struct ExecutionContext {
22    options: ExecutionOptions,
23    pool: Option<Arc<ThreadPool>>,
24}
25
26impl ExecutionContext {
27    pub fn new(options: ExecutionOptions) -> Result<Self, ThreadPoolBuildError> {
28        let pool = match (options.parallel, options.jobs) {
29            (true, Some(jobs)) => Some(Arc::new(
30                ThreadPoolBuilder::new().num_threads(jobs).build()?,
31            )),
32            _ => None,
33        };
34
35        Ok(Self { options, pool })
36    }
37
38    pub fn serial() -> Self {
39        Self {
40            options: ExecutionOptions {
41                parallel: false,
42                jobs: Some(1),
43            },
44            pool: None,
45        }
46    }
47
48    pub fn options(&self) -> ExecutionOptions {
49        self.options
50    }
51
52    pub fn map<T, R, E, F>(&self, items: Vec<T>, map_item: F) -> Result<Vec<R>, E>
53    where
54        T: Send,
55        R: Send,
56        E: Send,
57        F: Fn(T) -> Result<R, E> + Send + Sync,
58    {
59        if !self.options.parallel || items.len() <= 1 {
60            return items.into_iter().map(map_item).collect();
61        }
62
63        match &self.pool {
64            Some(pool) => pool.install(|| items.into_par_iter().map(map_item).collect()),
65            None => items.into_par_iter().map(map_item).collect(),
66        }
67    }
68}
69
70impl Default for ExecutionContext {
71    fn default() -> Self {
72        Self::new(ExecutionOptions::default())
73            .expect("default execution context must use a valid thread pool configuration")
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    /// Doubles every element of `input` through `context`, panicking if the
    /// mapping reports an error.
    fn doubled_by(context: &ExecutionContext, input: Vec<i32>) -> Vec<i32> {
        context
            .map(input, |value| Ok::<_, ()>(value * 2))
            .unwrap()
    }

    /// The serial context must hand results back in input order.
    #[test]
    fn serial_context_preserves_order() {
        let serial = ExecutionContext::serial();

        assert_eq!(doubled_by(&serial, vec![1, 2, 3]), vec![2, 4, 6]);
    }

    /// A two-job parallel context must also hand results back in input order.
    #[test]
    fn parallel_context_preserves_order() {
        let two_jobs = ExecutionOptions {
            parallel: true,
            jobs: Some(2),
        };
        let parallel = ExecutionContext::new(two_jobs).unwrap();

        assert_eq!(doubled_by(&parallel, vec![1, 2, 3]), vec![2, 4, 6]);
    }
}