shape_runtime/simulation/
parallel.rs1use super::dense_kernel::{DenseKernel, DenseKernelConfig, DenseKernelResult};
7use shape_ast::error::Result;
8use shape_value::DataTable;
9use std::sync::Arc;
10
11#[derive(Debug)]
13pub struct ParallelSweepResult<S, P> {
14 pub results: Vec<(P, DenseKernelResult<S>)>,
16 pub simulations_run: usize,
18 pub total_ticks: usize,
20}
21
22pub fn par_run<P, S, F>(
27 data: Arc<DataTable>,
28 param_sets: Vec<P>,
29 strategy_factory: F,
30) -> Result<ParallelSweepResult<S, P>>
31where
32 P: Send + Sync + Clone,
33 S: Send + Default,
34 F: Fn(&P) -> Box<dyn FnMut(usize, &[*const f64], &mut S) -> i32 + Send> + Send + Sync,
35{
36 let config = DenseKernelConfig::full(data.row_count());
37 par_run_with_config(data, param_sets, config, strategy_factory)
38}
39
40pub fn par_run_with_config<P, S, F>(
42 data: Arc<DataTable>,
43 param_sets: Vec<P>,
44 config: DenseKernelConfig,
45 strategy_factory: F,
46) -> Result<ParallelSweepResult<S, P>>
47where
48 P: Send + Sync + Clone,
49 S: Send + Default,
50 F: Fn(&P) -> Box<dyn FnMut(usize, &[*const f64], &mut S) -> i32 + Send> + Send + Sync,
51{
52 use rayon::prelude::*;
53
54 let results: Vec<(P, DenseKernelResult<S>)> = param_sets
55 .par_iter()
56 .map(|params| {
57 let kernel = DenseKernel::new(config.clone());
58 let mut strategy = strategy_factory(params);
59 let state = S::default();
60 let result = kernel.run(&data, state, |idx, ptrs, s| strategy(idx, ptrs, s));
61 let result = result.unwrap_or(DenseKernelResult {
63 final_state: S::default(),
64 ticks_processed: 0,
65 completed: false,
66 });
67 (params.clone(), result)
68 })
69 .collect();
70
71 let simulations_run = results.len();
72 let total_ticks = results.iter().map(|(_, r)| r.ticks_processed).sum();
73
74 Ok(ParallelSweepResult {
75 results,
76 simulations_run,
77 total_ticks,
78 })
79}
80
/// Builds the Cartesian product of two parameter axes.
///
/// The first axis varies slowest, so the output is
/// `[(a0, b0), (a0, b1), …, (aN, bM)]`. It has
/// `a_values.len() * b_values.len()` entries. An empty axis yields an empty
/// grid.
pub fn param_grid<A, B>(a_values: Vec<A>, b_values: Vec<B>) -> Vec<(A, B)>
where
    A: Clone,
    B: Clone,
{
    let bs = b_values.as_slice();
    let mut pairs = Vec::with_capacity(a_values.len() * bs.len());
    pairs.extend(
        a_values
            .iter()
            .flat_map(move |a| bs.iter().map(move |b| (a.clone(), b.clone()))),
    );
    pairs
}
95
/// Builds the Cartesian product of three parameter axes.
///
/// The first axis varies slowest and the third fastest, so the output is
/// `[(a0, b0, c0), (a0, b0, c1), …, (aN, bM, cK)]`. It has
/// `a_values.len() * b_values.len() * c_values.len()` entries. Any empty axis
/// yields an empty grid.
pub fn param_grid3<A, B, C>(a_values: Vec<A>, b_values: Vec<B>, c_values: Vec<C>) -> Vec<(A, B, C)>
where
    A: Clone,
    B: Clone,
    C: Clone,
{
    let (bs, cs) = (b_values.as_slice(), c_values.as_slice());
    let mut combos = Vec::with_capacity(a_values.len() * bs.len() * cs.len());
    combos.extend(a_values.iter().flat_map(move |a| {
        bs.iter()
            .flat_map(move |b| cs.iter().map(move |c| (a.clone(), b.clone(), c.clone())))
    }));
    combos
}
113
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_param_grid() {
        let grid = param_grid(vec![1, 2], vec![10, 20, 30]);
        assert_eq!(grid.len(), 6);
        assert_eq!(grid[0], (1, 10));
        assert_eq!(grid[5], (2, 30));
        // The first axis varies slowest and the second axis fastest.
        assert_eq!(
            grid,
            vec![(1, 10), (1, 20), (1, 30), (2, 10), (2, 20), (2, 30)]
        );
    }

    #[test]
    fn test_param_grid_empty_axis() {
        assert!(param_grid(Vec::<i32>::new(), vec![1, 2]).is_empty());
        assert!(param_grid(vec![1, 2], Vec::<i32>::new()).is_empty());
    }

    #[test]
    fn test_param_grid3() {
        let grid = param_grid3(vec![1, 2], vec![10, 20], vec![100, 200]);
        assert_eq!(grid.len(), 8);
        assert_eq!(grid[0], (1, 10, 100));
        // The third axis varies fastest.
        assert_eq!(grid[1], (1, 10, 200));
        assert_eq!(grid[7], (2, 20, 200));
    }

    #[test]
    fn test_param_grid3_empty_axis() {
        assert!(param_grid3(vec![1], Vec::<i32>::new(), vec![3]).is_empty());
    }
}