Skip to main content

shape_runtime/simulation/
parallel.rs

1//! Parallel Parameter Sweeps
2//!
3//! This module provides parallel execution utilities for running multiple
4//! simulations with different parameter sets using rayon for parallelization.
5
6use super::dense_kernel::{DenseKernel, DenseKernelConfig, DenseKernelResult};
7use shape_ast::error::Result;
8use shape_value::DataTable;
9use std::sync::Arc;
10
11/// Result of a parallel parameter sweep.
12#[derive(Debug)]
13pub struct ParallelSweepResult<S, P> {
14    /// Results for each parameter set
15    pub results: Vec<(P, DenseKernelResult<S>)>,
16    /// Number of simulations run
17    pub simulations_run: usize,
18    /// Total ticks processed across all simulations
19    pub total_ticks: usize,
20}
21
22/// Run a parameter sweep in parallel using rayon.
23///
24/// This function runs multiple simulations in parallel, each with different
25/// parameter values. Data is shared (zero-copy) across all simulations using Arc.
26pub fn par_run<P, S, F>(
27    data: Arc<DataTable>,
28    param_sets: Vec<P>,
29    strategy_factory: F,
30) -> Result<ParallelSweepResult<S, P>>
31where
32    P: Send + Sync + Clone,
33    S: Send + Default,
34    F: Fn(&P) -> Box<dyn FnMut(usize, &[*const f64], &mut S) -> i32 + Send> + Send + Sync,
35{
36    let config = DenseKernelConfig::full(data.row_count());
37    par_run_with_config(data, param_sets, config, strategy_factory)
38}
39
40/// Run a parameter sweep with custom DenseKernelConfig.
41pub fn par_run_with_config<P, S, F>(
42    data: Arc<DataTable>,
43    param_sets: Vec<P>,
44    config: DenseKernelConfig,
45    strategy_factory: F,
46) -> Result<ParallelSweepResult<S, P>>
47where
48    P: Send + Sync + Clone,
49    S: Send + Default,
50    F: Fn(&P) -> Box<dyn FnMut(usize, &[*const f64], &mut S) -> i32 + Send> + Send + Sync,
51{
52    use rayon::prelude::*;
53
54    let results: Vec<(P, DenseKernelResult<S>)> = param_sets
55        .par_iter()
56        .map(|params| {
57            let kernel = DenseKernel::new(config.clone());
58            let mut strategy = strategy_factory(params);
59            let state = S::default();
60            let result = kernel.run(&data, state, |idx, ptrs, s| strategy(idx, ptrs, s));
61            // If simulation errored, create a default result
62            let result = result.unwrap_or(DenseKernelResult {
63                final_state: S::default(),
64                ticks_processed: 0,
65                completed: false,
66            });
67            (params.clone(), result)
68        })
69        .collect();
70
71    let simulations_run = results.len();
72    let total_ticks = results.iter().map(|(_, r)| r.ticks_processed).sum();
73
74    Ok(ParallelSweepResult {
75        results,
76        simulations_run,
77        total_ticks,
78    })
79}
80
/// Build a 2D parameter grid: the Cartesian product of `a_values` and `b_values`.
///
/// Pairs are produced in row-major order: every `b` is paired with the first
/// `a` before moving on to the next `a`. The result has
/// `a_values.len() * b_values.len()` entries and is empty if either input is
/// empty.
pub fn param_grid<A, B>(a_values: Vec<A>, b_values: Vec<B>) -> Vec<(A, B)>
where
    A: Clone,
    B: Clone,
{
    let total = a_values.len() * b_values.len();
    let mut pairs = Vec::with_capacity(total);
    pairs.extend(
        a_values
            .iter()
            .flat_map(|a| b_values.iter().map(move |b| (a.clone(), b.clone()))),
    );
    pairs
}
95
/// Build a 3D parameter grid: the Cartesian product of three value lists.
///
/// Triples are produced in row-major order: `c` varies fastest, then `b`,
/// then `a`. The result has
/// `a_values.len() * b_values.len() * c_values.len()` entries and is empty if
/// any input is empty.
pub fn param_grid3<A, B, C>(
    a_values: Vec<A>,
    b_values: Vec<B>,
    c_values: Vec<C>,
) -> Vec<(A, B, C)>
where
    A: Clone,
    B: Clone,
    C: Clone,
{
    let total = a_values.len() * b_values.len() * c_values.len();
    let mut triples = Vec::with_capacity(total);
    for a in a_values.iter() {
        // Emit the full (b, c) plane for this `a` before advancing.
        triples.extend(b_values.iter().flat_map(|b| {
            c_values
                .iter()
                .map(move |c| (a.clone(), b.clone(), c.clone()))
        }));
    }
    triples
}
113
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_param_grid() {
        let grid = param_grid(vec![1, 2], vec![10, 20, 30]);
        assert_eq!(grid.len(), 6);
        // Row-major order: the second list varies fastest.
        assert_eq!(
            grid,
            vec![(1, 10), (1, 20), (1, 30), (2, 10), (2, 20), (2, 30)]
        );
    }

    #[test]
    fn test_param_grid_empty_input() {
        // An empty axis collapses the whole product.
        assert!(param_grid(Vec::<i32>::new(), vec![1, 2, 3]).is_empty());
        assert!(param_grid(vec![1, 2, 3], Vec::<i32>::new()).is_empty());
    }

    #[test]
    fn test_param_grid3() {
        let grid = param_grid3(vec![1, 2], vec![10, 20], vec![100, 200]);
        assert_eq!(grid.len(), 8); // 2 × 2 × 2
        // Row-major order: the last list varies fastest, the first slowest.
        assert_eq!(
            grid,
            vec![
                (1, 10, 100),
                (1, 10, 200),
                (1, 20, 100),
                (1, 20, 200),
                (2, 10, 100),
                (2, 10, 200),
                (2, 20, 100),
                (2, 20, 200),
            ]
        );
    }

    #[test]
    fn test_param_grid3_empty_input() {
        assert!(param_grid3(Vec::<i32>::new(), vec![1], vec![2]).is_empty());
        assert!(param_grid3(vec![1], Vec::<i32>::new(), vec![2]).is_empty());
        assert!(param_grid3(vec![1], vec![2], Vec::<i32>::new()).is_empty());
    }
}