Skip to main content

scirs2_integrate/ode/ensemble/
types.rs

//! Types for the ensemble ODE solver.

/// Settings that control a batched, multi-threaded ensemble ODE integration.
///
/// Every field has a documented default, provided by the `Default` impl.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct EnsembleConfig {
    /// How many ensemble members (parameter sets or initial conditions) to run.
    /// Defaults to 100.
    pub n_ensemble: usize,
    /// Size of the worker-thread pool.  Defaults to the logical CPU count.
    pub n_threads: usize,
    /// Relative tolerance used by adaptive step-size control.  Defaults to 1e-6.
    pub rtol: f64,
    /// Absolute tolerance used by adaptive step-size control.  Defaults to 1e-9.
    pub atol: f64,
    /// Start and end of integration as `(t0, t_end)`.  Defaults to `(0.0, 1.0)`.
    pub t_span: (f64, f64),
    /// Upper bound on accepted steps for each member.  Defaults to 100_000.
    pub max_steps: usize,
    /// First step size to try; a value of 0 requests automatic selection.
    /// Defaults to 0.0.
    pub h_init: f64,
}

24impl Default for EnsembleConfig {
25    fn default() -> Self {
26        let n_threads = num_cpus::get().max(1);
27        Self {
28            n_ensemble: 100,
29            n_threads,
30            rtol: 1e-6,
31            atol: 1e-9,
32            t_span: (0.0, 1.0),
33            max_steps: 100_000,
34            h_init: 0.0,
35        }
36    }
37}
38
/// Result of an ensemble ODE integration.
#[derive(Debug, Clone)]
pub struct EnsembleResult {
    /// Trajectories: `trajectories[i][k][j]` = state component `j` of member `i`
    /// at time index `k`.
    pub trajectories: Vec<Vec<Vec<f64>>>,
    /// Adaptive time grids: `times[i][k]` = time of step `k` for member `i`.
    pub times: Vec<Vec<f64>>,
    /// Whether member `i` converged within `max_steps`.
    pub converged: Vec<bool>,
    /// Number of accepted steps taken by member `i`.
    pub n_steps: Vec<usize>,
}

impl EnsembleResult {
    /// Shape of the data shared by every member: `(min_len, n_state)`.
    ///
    /// `min_len` is the length of the shortest trajectory.  `n_state` is the
    /// state dimension of member 0's first step.
    ///
    /// Returns `None` in three cases:
    /// - the ensemble is empty;
    /// - some member has no recorded steps;
    /// - some step within the first `min_len` steps of any member has a state
    ///   vector whose length differs from `n_state`.
    fn common_shape(&self) -> Option<(usize, usize)> {
        // `min()` is `None` exactly when there are no members.
        let min_len = self.trajectories.iter().map(Vec::len).min()?;
        if min_len == 0 {
            return None;
        }
        let n_state = self.trajectories[0][0].len();
        // Validate up front so the statistics below can zip without
        // silently dropping components or indexing out of bounds.
        let consistent = self
            .trajectories
            .iter()
            .all(|traj| traj[..min_len].iter().all(|step| step.len() == n_state));
        consistent.then_some((min_len, n_state))
    }

    /// Compute the element-wise mean trajectory over all ensemble members.
    ///
    /// Members are aligned by **step index**, not interpolated in time.
    /// Entry `k` averages the `k`-th recorded state of every member.  Each
    /// member has its own adaptive time grid (see `times`), so these states
    /// may correspond to different times.  Only the first `min_len` steps are
    /// used, where `min_len` is the length of the shortest trajectory.
    ///
    /// Returns `None` in three cases:
    /// - the ensemble is empty;
    /// - some member has no recorded steps;
    /// - members disagree on the state dimension within that common prefix.
    pub fn mean_trajectory(&self) -> Option<Vec<Vec<f64>>> {
        let (min_len, n_state) = self.common_shape()?;

        let mut mean = vec![vec![0.0_f64; n_state]; min_len];
        for traj in &self.trajectories {
            // `zip` with `mean` stops after `min_len` steps.
            for (acc_step, step) in mean.iter_mut().zip(traj) {
                for (acc, &val) in acc_step.iter_mut().zip(step) {
                    *acc += val;
                }
            }
        }
        let n_f = self.trajectories.len() as f64;
        for val in mean.iter_mut().flatten() {
            *val /= n_f;
        }
        Some(mean)
    }

    /// Compute the element-wise sample standard deviation trajectory.
    ///
    /// The deviation is taken around [`Self::mean_trajectory`] with an
    /// `n - 1` denominator, using the same step-index alignment and common
    /// prefix.
    ///
    /// Returns `None` in two cases:
    /// - there are fewer than 2 members;
    /// - `mean_trajectory` returns `None`.
    pub fn std_trajectory(&self) -> Option<Vec<Vec<f64>>> {
        let n_members = self.trajectories.len();
        if n_members < 2 {
            return None;
        }
        let mean = self.mean_trajectory()?;

        let mut std: Vec<Vec<f64>> = mean.iter().map(|step| vec![0.0; step.len()]).collect();
        for traj in &self.trajectories {
            for ((acc_step, mean_step), step) in std.iter_mut().zip(&mean).zip(traj) {
                for ((acc, &mu), &val) in acc_step.iter_mut().zip(mean_step).zip(step) {
                    let diff = val - mu;
                    *acc += diff * diff;
                }
            }
        }
        let dof = (n_members - 1) as f64;
        for val in std.iter_mut().flatten() {
            *val = (*val / dof).sqrt();
        }
        Some(std)
    }

    /// Return quantile trajectories at quantile `q` (e.g. 0.5 = median).
    ///
    /// At each common step index and state component, the member values are
    /// ranked and the nearest-rank element is returned.  That element sits at
    /// sorted index `round(q * (n_members - 1))`, so no interpolation between
    /// members takes place.
    ///
    /// Out-of-range `q` saturates:
    /// - `q <= 0` (and a NaN `q`) selects the minimum;
    /// - `q >= 1` selects the maximum.
    ///
    /// Values are ranked with [`f64::total_cmp`], so NaN members (e.g.
    /// diverged runs) are ordered deterministically: positive NaN ranks above
    /// `+inf` and negative NaN below `-inf`.
    ///
    /// Returns `None` in three cases:
    /// - the ensemble is empty;
    /// - some member has no recorded steps;
    /// - members disagree on the state dimension within the common prefix.
    pub fn quantile_trajectories(&self, q: f64) -> Option<Vec<Vec<f64>>> {
        let (min_len, n_state) = self.common_shape()?;
        let n_members = self.trajectories.len();

        // Loop-invariant nearest-rank index; the float-to-usize cast
        // saturates negatives and NaN to 0, and `min` caps it at the maximum.
        let idx = ((q * (n_members - 1) as f64).round() as usize).min(n_members - 1);

        // One scratch buffer reused for every (k, j) cell.
        let mut vals: Vec<f64> = Vec::with_capacity(n_members);
        let mut result = vec![vec![0.0_f64; n_state]; min_len];
        for (k, out_step) in result.iter_mut().enumerate() {
            for (j, out) in out_step.iter_mut().enumerate() {
                vals.clear();
                vals.extend(self.trajectories.iter().map(|traj| traj[k][j]));
                // O(n) selection of the idx-th smallest under a total order.
                let (_, nth, _) = vals.select_nth_unstable_by(idx, f64::total_cmp);
                *out = *nth;
            }
        }
        Some(result)
    }
}