Skip to main content

scirs2_integrate/ode/ensemble/
solver.rs

1//! Dormand-Prince RK45 ensemble ODE solver with FSAL optimisation.
2//!
3//! Each ensemble member is solved independently with adaptive step control.
4//! Members are distributed across threads using `std::thread::scope`.
5
6use super::types::{EnsembleConfig, EnsembleResult};
7use crate::error::{IntegrateError, IntegrateResult};
8
// ── Dormand-Prince RK45 Butcher tableau ──────────────────────────────────────
//
// Dormand, J.R.; Prince, P.J. (1980).  "A family of embedded Runge-Kutta
// formulae". J. Comput. Appl. Math. 6(1): 19-26.
//
// Nodes: c2=1/5, c3=3/10, c4=4/5, c5=8/9, c6=1, c7=1.
//
// The local error estimate is y5 − y4, the difference between the 5th- and
// 4th-order solutions, which equals h · Σ_j (b5_j − b4_j) · k_j.

/// RK45 Dormand-Prince: stage coefficients `a`.
///
/// Stored as a 6×6 array: row `i` holds the coefficients that build the input
/// of stage `i + 2` from `k1..k_{i+1}`, so only its first `i + 1` entries may
/// be non-zero. Each row sums to the node of the stage it builds
/// (`DP_C[i + 1]`). The last row equals the 5th-order weights `DP_B5`, so the
/// 7th stage is evaluated at the new solution itself (the FSAL property).
const DP_A: [[f64; 6]; 6] = [
    [1.0 / 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [3.0 / 40.0, 9.0 / 40.0, 0.0, 0.0, 0.0, 0.0],
    [44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0, 0.0, 0.0, 0.0],
    [
        19_372.0 / 6_561.0,
        -25_360.0 / 2_187.0,
        64_448.0 / 6_561.0,
        -212.0 / 729.0,
        0.0,
        0.0,
    ],
    [
        9_017.0 / 3_168.0,
        -355.0 / 33.0,
        46_732.0 / 5_247.0,
        49.0 / 176.0,
        -5_103.0 / 18_656.0,
        0.0,
    ],
    [
        35.0 / 384.0,
        0.0,
        500.0 / 1_113.0,
        125.0 / 192.0,
        -2_187.0 / 6_784.0,
        11.0 / 84.0,
    ],
];

/// RK45 Dormand-Prince: 5th-order weights `b`, used to advance the solution.
///
/// The first six entries equal the last row of `DP_A` (FSAL). The weight of
/// `k7` is zero.
const DP_B5: [f64; 7] = [
    35.0 / 384.0,
    0.0,
    500.0 / 1_113.0,
    125.0 / 192.0,
    -2_187.0 / 6_784.0,
    11.0 / 84.0,
    0.0,
];

/// RK45 Dormand-Prince: embedded 4th-order weights `b̂`, used only for the
/// local error estimate. Unlike `DP_B5`, these give `k7` a non-zero weight.
const DP_B4: [f64; 7] = [
    5_179.0 / 57_600.0,
    0.0,
    7_571.0 / 16_695.0,
    393.0 / 640.0,
    -92_097.0 / 339_200.0,
    187.0 / 2_100.0,
    1.0 / 40.0,
];

/// RK45 Dormand-Prince: nodes `c1..c7`; stage `j` is evaluated at `t + c_j · h`.
const DP_C: [f64; 7] = [0.0, 1.0 / 5.0, 3.0 / 10.0, 4.0 / 5.0, 8.0 / 9.0, 1.0, 1.0];
73
74// ── Per-member integrator ─────────────────────────────────────────────────────
75
/// Mutable integration state of one ensemble member under RK45 with FSAL.
///
/// `k1` caches the derivative at the current point so that the last stage of
/// an accepted step can be reused as the first stage of the next one.
struct Rk45State {
    /// Current time.
    t: f64,
    /// Solution vector at `t`.
    y: Vec<f64>,
    /// Derivative at `(t, y)`; first stage of the next attempted step.
    k1: Vec<f64>,
    /// Step size to attempt next.
    h: f64,
}

impl Rk45State {
    /// Bundle a starting point, its derivative and a first trial step size.
    fn new(t0: f64, y0: Vec<f64>, k1: Vec<f64>, h_init: f64) -> Self {
        Rk45State {
            h: h_init,
            k1,
            y: y0,
            t: t0,
        }
    }
}
95
/// Evaluate the right-hand side `f(t, y, param)` for one Runge–Kutta stage.
fn stage<F: Fn(f64, &[f64], &P) -> Vec<f64>, P>(f: &F, t: f64, y: &[f64], param: &P) -> Vec<f64> {
    (*f)(t, y, param)
}
103
/// Element-wise `a + scale * b`.
///
/// The result has the length of the shorter input; extra trailing elements
/// of the longer one are ignored.
fn axpy(a: &[f64], scale: f64, b: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (x, y) in a.iter().copied().zip(b.iter().copied()) {
        out.push(x + scale * y);
    }
    out
}
111
/// Weighted Runge–Kutta combination `base + h * Σ_j weights[j] * ks[j]`.
///
/// `weights` and `ks` are paired up; extra entries in the longer one are
/// ignored. Exactly-zero weights, which are common in Butcher tableaux, are
/// skipped. Small non-zero weights are *not* skipped: the previous
/// `|w| < f64::EPSILON` test silently dropped them and broke the formula.
///
/// # Panics
///
/// Panics if a `ks[j]` paired with a non-zero weight is shorter than `base`.
fn rk_sum(base: &[f64], h: f64, weights: &[f64], ks: &[Vec<f64>]) -> Vec<f64> {
    let n = base.len();
    let mut result = base.to_vec();
    for (&w, k) in weights.iter().zip(ks) {
        if w == 0.0 {
            continue;
        }
        // Same evaluation order as `h * w * k[i]`, with `h * w` hoisted.
        let hw = h * w;
        for (r, &kv) in result.iter_mut().zip(&k[..n]) {
            *r += hw * kv;
        }
    }
    result
}
126
/// Weighted RMS norm of the local error estimate `e`.
///
/// Each component is scaled by `atol + rtol * max(|y_i|, |y_new_i|)`, and the
/// result is `sqrt(mean((e_i / scale_i)^2))`, so a value `<= 1` means the step
/// meets the requested tolerances. Components whose error is exactly zero
/// contribute nothing, even when their scale is zero (for example `atol == 0`
/// with `y_i == y_new_i == 0`). Without this, `0 / 0 = NaN` would poison the
/// norm and every step would be rejected. Returns `0.0` for an empty state.
///
/// # Panics
///
/// Panics if `y_new` or `e` is shorter than `y`.
fn error_norm(y: &[f64], y_new: &[f64], e: &[f64], rtol: f64, atol: f64) -> f64 {
    let n = y.len();
    if n == 0 {
        return 0.0;
    }
    let sum = y
        .iter()
        .zip(&y_new[..n])
        .zip(&e[..n])
        .filter(|&(_, &ei)| ei != 0.0)
        .map(|((&yi, &yni), &ei)| {
            let scaled = ei / (atol + rtol * yi.abs().max(yni.abs()));
            scaled * scaled
        })
        .fold(0.0_f64, |acc, sq| acc + sq);
    (sum / n as f64).sqrt()
}
141
142/// Take one adaptive RK45-FSAL step.
143///
144/// Returns `(y_new, k1_new, h_next, accepted)`.
145fn rk45_step<F, P>(
146    f: &F,
147    state: &Rk45State,
148    t_end: f64,
149    rtol: f64,
150    atol: f64,
151    param: &P,
152) -> (Vec<f64>, Vec<f64>, f64, bool)
153where
154    F: Fn(f64, &[f64], &P) -> Vec<f64>,
155{
156    let t = state.t;
157    let y = &state.y;
158    let h = state.h.min(t_end - t);
159
160    // Stages k1..k7
161    let k1 = state.k1.clone();
162    let y2 = axpy(y, h * DP_A[0][0], &k1);
163    let k2 = stage(f, t + DP_C[1] * h, &y2, param);
164
165    let y3 = {
166        let mut v = y.to_vec();
167        for i in 0..v.len() {
168            v[i] += h * (DP_A[1][0] * k1[i] + DP_A[1][1] * k2[i]);
169        }
170        v
171    };
172    let k3 = stage(f, t + DP_C[2] * h, &y3, param);
173
174    let y4 = {
175        let mut v = y.to_vec();
176        for i in 0..v.len() {
177            v[i] += h * (DP_A[2][0] * k1[i] + DP_A[2][1] * k2[i] + DP_A[2][2] * k3[i]);
178        }
179        v
180    };
181    let k4 = stage(f, t + DP_C[3] * h, &y4, param);
182
183    let y5 = {
184        let mut v = y.to_vec();
185        for i in 0..v.len() {
186            v[i] += h
187                * (DP_A[3][0] * k1[i]
188                    + DP_A[3][1] * k2[i]
189                    + DP_A[3][2] * k3[i]
190                    + DP_A[3][3] * k4[i]);
191        }
192        v
193    };
194    let k5 = stage(f, t + DP_C[4] * h, &y5, param);
195
196    let y6 = {
197        let mut v = y.to_vec();
198        for i in 0..v.len() {
199            v[i] += h
200                * (DP_A[4][0] * k1[i]
201                    + DP_A[4][1] * k2[i]
202                    + DP_A[4][2] * k3[i]
203                    + DP_A[4][3] * k4[i]
204                    + DP_A[4][4] * k5[i]);
205        }
206        v
207    };
208    let k6 = stage(f, t + DP_C[5] * h, &y6, param);
209
210    // 5th-order solution (FSAL: k7 = f(t+h, y6) = next k1)
211    let y_new = rk_sum(
212        y,
213        h,
214        &DP_B5[..6],
215        &[
216            k1.clone(),
217            k2.clone(),
218            k3.clone(),
219            k4.clone(),
220            k5.clone(),
221            k6.clone(),
222        ],
223    );
224    let k7 = stage(f, t + h, &y_new, param);
225
226    // 4th-order solution for error estimate
227    let y4_ord = rk_sum(y, h, &DP_B4, &[k1, k2, k3, k4, k5, k6, k7.clone()]);
228
229    // Error = 5th - 4th
230    let e: Vec<f64> = y_new
231        .iter()
232        .zip(y4_ord.iter())
233        .map(|(&a, &b)| a - b)
234        .collect();
235    let err = error_norm(y, &y_new, &e, rtol, atol);
236
237    // Step-size control (PI controller, safety factor 0.9)
238    let factor = if err == 0.0 {
239        5.0
240    } else {
241        0.9 * err.powf(-0.2)
242    };
243    let factor = factor.clamp(0.2, 5.0);
244    let h_next = h * factor;
245
246    if err <= 1.0 {
247        // Accepted
248        (y_new, k7, h_next, true)
249    } else {
250        // Rejected — return unchanged y, propose smaller h
251        (y.clone(), k7, h_next, false)
252    }
253}
254
255/// Integrate a single ODE member from `t0` to `t_end`.
256fn integrate_member<F, P>(
257    f: &F,
258    t0: f64,
259    t_end: f64,
260    y0: Vec<f64>,
261    param: &P,
262    rtol: f64,
263    atol: f64,
264    h_init: f64,
265    max_steps: usize,
266) -> (Vec<Vec<f64>>, Vec<f64>, bool, usize)
267where
268    F: Fn(f64, &[f64], &P) -> Vec<f64>,
269{
270    let n_state = y0.len();
271
272    // Choose initial step size
273    let h0 = if h_init > 0.0 {
274        h_init
275    } else {
276        // Estimate: h ~ 0.01 * (t_end - t0) but bounded
277        ((t_end - t0) * 0.01).max(1e-8).min((t_end - t0) / 10.0)
278    };
279
280    let k1_0 = f(t0, &y0, param);
281    let mut state = Rk45State::new(t0, y0.clone(), k1_0, h0);
282
283    let mut traj = vec![y0];
284    let mut times = vec![t0];
285    let mut n_steps = 0_usize;
286
287    while state.t < t_end - 1e-14 * (t_end - t0) && n_steps < max_steps {
288        let (y_new, k_new, h_next, accepted) = rk45_step(f, &state, t_end, rtol, atol, param);
289
290        if accepted {
291            state.t = (state.t + state.h).min(t_end);
292            state.y = y_new.clone();
293            state.k1 = k_new;
294            state.h = h_next.max(1e-14);
295            n_steps += 1;
296            traj.push(y_new);
297            times.push(state.t);
298        } else {
299            // Step rejected; update step size only
300            state.h = h_next.max(1e-14);
301        }
302
303        // Avoid step-size going below machine epsilon
304        if state.h < 1e-14 * state.t.abs().max(1.0) {
305            break;
306        }
307    }
308
309    let converged = if (state.t - t_end).abs() < 1e-12 * (t_end - t0 + 1.0) {
310        true
311    } else if n_steps == max_steps {
312        // Reached max; not fully converged
313        false
314    } else {
315        state.t >= t_end - 1e-10 * (t_end - t0)
316    };
317
318    // Ensure state vector isn't empty
319    if traj.is_empty() {
320        traj.push(vec![0.0; n_state]);
321        times.push(t0);
322    }
323
324    (traj, times, converged, n_steps)
325}
326
327// ── Public solver ─────────────────────────────────────────────────────────────
328
329/// Batched parallel ODE ensemble integrator.
330///
331/// Solves `n_ensemble` ODE IVPs in parallel.  Each member may have different
332/// initial conditions and/or parameters.
333pub struct OdeEnsembleSolver {
334    /// Configuration for the ensemble.
335    pub config: EnsembleConfig,
336}
337
338impl OdeEnsembleSolver {
339    /// Create a new solver with the given configuration.
340    pub fn new(config: EnsembleConfig) -> Self {
341        Self { config }
342    }
343
344    /// Integrate the ensemble.
345    ///
346    /// # Type parameters
347    ///
348    /// * `F` — RHS function `f(t, y, &param) -> Vec<f64>`.  Must be `Fn + Sync`.
349    /// * `P` — Parameter type.  Must be `Sync`.
350    ///
351    /// # Arguments
352    ///
353    /// * `f`       - ODE right-hand side.
354    /// * `params`  - Slice of parameters, one per member.
355    /// * `y0s`     - Slice of initial conditions, one per member.
356    /// * `config`  - Ensemble configuration (can differ from `self.config`).
357    ///
358    /// # Errors
359    ///
360    /// Returns `IntegrateError::InvalidInput` if `params.len() != y0s.len()`
361    /// or if `t_span` is invalid.
362    pub fn solve<F, P>(
363        &self,
364        f: F,
365        params: &[P],
366        y0s: &[Vec<f64>],
367        config: &EnsembleConfig,
368    ) -> IntegrateResult<EnsembleResult>
369    where
370        F: Fn(f64, &[f64], &P) -> Vec<f64> + Sync,
371        P: Sync,
372    {
373        if params.len() != y0s.len() {
374            return Err(IntegrateError::InvalidInput(format!(
375                "params.len() ({}) != y0s.len() ({})",
376                params.len(),
377                y0s.len()
378            )));
379        }
380
381        let (t0, t_end) = config.t_span;
382        if t0 >= t_end {
383            return Err(IntegrateError::InvalidInput(
384                "t_span must satisfy t0 < t_end".to_string(),
385            ));
386        }
387
388        let n = params.len();
389        if n == 0 {
390            return Ok(EnsembleResult {
391                trajectories: vec![],
392                times: vec![],
393                converged: vec![],
394                n_steps: vec![],
395            });
396        }
397
398        let rtol = config.rtol;
399        let atol = config.atol;
400        let h_init = config.h_init;
401        let max_steps = config.max_steps;
402        let n_threads = config.n_threads.max(1).min(n);
403
404        // Pre-allocate result storage
405        let mut trajectories: Vec<Vec<Vec<f64>>> = vec![Vec::new(); n];
406        let mut times_out: Vec<Vec<f64>> = vec![Vec::new(); n];
407        let mut converged: Vec<bool> = vec![false; n];
408        let mut n_steps_out: Vec<usize> = vec![0; n];
409
410        // We use indices to distribute work across threads.
411        // Build index chunks.
412        let chunk_size = n.div_ceil(n_threads);
413
414        // Shared result slots — declared OUTSIDE scope so they outlive the threads.
415        let results: Vec<std::sync::Mutex<Option<(Vec<Vec<f64>>, Vec<f64>, bool, usize)>>> =
416            (0..n).map(|_| std::sync::Mutex::new(None)).collect();
417
418        // Use thread::scope so we can borrow f, params, y0s safely.
419        std::thread::scope(|scope| {
420            let results_ref = &results;
421            let f_ref = &f;
422
423            // Spawn one thread per chunk.
424            for tid in 0..n_threads {
425                let start = tid * chunk_size;
426                if start >= n {
427                    break;
428                }
429                let end = (start + chunk_size).min(n);
430                let params_slice = &params[start..end];
431                let y0s_slice = &y0s[start..end];
432
433                scope.spawn(move || {
434                    for (local_idx, (param, y0)) in
435                        params_slice.iter().zip(y0s_slice.iter()).enumerate()
436                    {
437                        let global_idx = start + local_idx;
438                        let (traj, ts, conv, ns) = integrate_member(
439                            f_ref,
440                            t0,
441                            t_end,
442                            y0.clone(),
443                            param,
444                            rtol,
445                            atol,
446                            h_init,
447                            max_steps,
448                        );
449                        // Write result
450                        if let Ok(mut slot) = results_ref[global_idx].lock() {
451                            *slot = Some((traj, ts, conv, ns));
452                        }
453                    }
454                });
455            }
456            // scope drops here, joining all spawned threads.
457        });
458
459        // Collect results
460        for (i, slot) in results.into_iter().enumerate() {
461            if let Ok(Some((traj, ts, conv, ns))) = slot.into_inner() {
462                trajectories[i] = traj;
463                times_out[i] = ts;
464                converged[i] = conv;
465                n_steps_out[i] = ns;
466            }
467        }
468
469        Ok(EnsembleResult {
470            trajectories,
471            times: times_out,
472            converged,
473            n_steps: n_steps_out,
474        })
475    }
476}
477
478// ── Tests ─────────────────────────────────────────────────────────────────────
479
#[cfg(test)]
mod tests {
    use super::*;

    /// `true` when `a` and `b` differ by at most `tol` (absolute).
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (b - a).abs() <= tol
    }

    /// Right-hand side of the scalar linear decay ODE `dy/dt = -p · y`.
    fn decay(_t: f64, y: &[f64], p: &f64) -> Vec<f64> {
        vec![-*p * y[0]]
    }

    /// Decay-test configuration on `(0, t_end)` with `rtol = 1e-8`, `atol = 1e-10`.
    ///
    /// `n_ensemble` comes from the default; each caller overrides it.
    fn decay_config(n_threads: usize, t_end: f64, max_steps: usize, h_init: f64) -> EnsembleConfig {
        EnsembleConfig {
            n_threads,
            rtol: 1e-8,
            atol: 1e-10,
            t_span: (0.0, t_end),
            max_steps,
            h_init,
            ..EnsembleConfig::default()
        }
    }

    /// dx/dt = -x, x(0)=1 → x(t) = e^{-t}.  Check 10 identical members.
    #[test]
    fn test_ensemble_exponential_decay() {
        let cfg = EnsembleConfig {
            n_ensemble: 10,
            ..decay_config(2, 1.0, 10_000, 0.0)
        };
        let params = vec![1.0_f64; 10];
        let y0s = vec![vec![1.0]; 10];

        let result = OdeEnsembleSolver::new(cfg.clone())
            .solve(decay, &params, &y0s, &cfg)
            .expect("solve failed");

        assert_eq!(result.trajectories.len(), 10);
        for (i, (traj, ts)) in result.trajectories.iter().zip(&result.times).enumerate() {
            let t_final = *ts.last().expect("no times");
            let y_final = traj.last().expect("no trajectory")[0];
            let expected = (-t_final).exp();
            assert!(
                approx_eq(y_final, expected, 1e-5),
                "member {i}: y(t={t_final:.4}) = {y_final:.8}, expected {expected:.8}"
            );
        }
    }

    /// Every member of the simple decay ensemble must report convergence.
    #[test]
    fn test_ensemble_all_converged() {
        let cfg = EnsembleConfig {
            n_ensemble: 10,
            ..decay_config(4, 2.0, 50_000, 0.0)
        };
        let params = vec![1.0_f64; 10];
        let y0s = vec![vec![1.0]; 10];

        let result = OdeEnsembleSolver::new(cfg.clone())
            .solve(decay, &params, &y0s, &cfg)
            .expect("solve failed");

        for (i, conv) in result.converged.iter().copied().enumerate() {
            assert!(conv, "member {i} did not converge");
        }
    }

    /// Members started from different initial values end at different values.
    #[test]
    fn test_ensemble_different_ics() {
        let cfg = EnsembleConfig {
            n_ensemble: 5,
            ..decay_config(2, 1.0, 10_000, 0.0)
        };
        let params = vec![1.0_f64; 5];
        // y0 = 1.0, 2.0, 3.0, 4.0, 5.0
        let y0s: Vec<Vec<f64>> = (1_u8..=5).map(|i| vec![f64::from(i)]).collect();

        let result = OdeEnsembleSolver::new(cfg.clone())
            .solve(decay, &params, &y0s, &cfg)
            .expect("solve failed");

        let finals: Vec<f64> = result
            .trajectories
            .iter()
            .map(|traj| traj.last().expect("no traj")[0])
            .collect();

        let first = finals[0];
        for (i, &fi) in finals.iter().enumerate().skip(1) {
            assert!(
                (fi - first).abs() > 0.1,
                "members 0 and {i} should differ: {} vs {}",
                first,
                fi
            );
        }
    }

    /// `EnsembleConfig::default()` yields a usable configuration.
    #[test]
    fn test_ensemble_config_default() {
        let cfg = EnsembleConfig::default();
        let (start, end) = cfg.t_span;
        assert!(cfg.n_ensemble > 0);
        assert!(cfg.n_threads > 0);
        assert!(cfg.rtol > 0.0);
        assert!(cfg.atol > 0.0);
        assert!(start < end);
    }

    /// Running on one thread or two must give the same final values.
    #[test]
    fn test_ensemble_parallel_same_as_serial() {
        let params = vec![0.5, 1.0, 1.5, 2.0];
        let y0s = vec![vec![1.0]; 4];

        let run = |n_threads: usize| {
            let cfg = EnsembleConfig {
                n_ensemble: 4,
                ..decay_config(n_threads, 1.0, 10_000, 0.0)
            };
            OdeEnsembleSolver::new(cfg.clone()).solve(decay, &params, &y0s, &cfg)
        };

        let res1 = run(1).expect("solve 1 failed");
        let res2 = run(2).expect("solve 2 failed");

        for i in 0..4 {
            let y1 = res1.trajectories[i].last().expect("no traj1")[0];
            let y2 = res2.trajectories[i].last().expect("no traj2")[0];
            assert!(
                approx_eq(y1, y2, 1e-10),
                "member {i}: thread-1={y1}, thread-2={y2}"
            );
        }
    }

    /// The mean of identical members equals any single member's trajectory.
    #[test]
    fn test_ensemble_mean_trajectory() {
        let cfg = EnsembleConfig {
            n_ensemble: 5,
            ..decay_config(2, 1.0, 10_000, 1e-3)
        };
        let params = vec![1.0_f64; 5];
        let y0s = vec![vec![1.0]; 5];

        let result = OdeEnsembleSolver::new(cfg.clone())
            .solve(decay, &params, &y0s, &cfg)
            .expect("solve failed");

        let mean = result.mean_trajectory().expect("mean failed");
        let single = &result.trajectories[0];
        for k in 0..mean.len().min(single.len()) {
            assert!(
                approx_eq(mean[k][0], single[k][0], 1e-10),
                "mean[{k}]={}, single[{k}]={}",
                mean[k][0],
                single[k][0]
            );
        }
    }
}