Skip to main content

scirs2_integrate/
gpu_ode_ensemble.rs

1//! Batched ODE ensemble integration across parameter sets.
2//!
3//! This module provides a mechanism to solve many ODE initial-value problems
4//! simultaneously, each with its own set of parameters and initial conditions
5//! but sharing the same right-hand-side structure.  This pattern arises in
6//! parameter sweeps, uncertainty quantification, and neural-ODE training.
7//!
8//! ## Algorithm
9//!
10//! Each ensemble member is integrated independently using the Dormand-Prince
11//! adaptive RK45 scheme (the same pair used by `scipy.integrate.solve_ivp`
//! with `method='RK45'`).  Step-size control follows the standard elementary
//! (integral-only) controller formula:
14//!
15//! ```text
16//! h_new = h * min(facmax, max(facmin, fac * (1/err)^(1/5)))
17//! ```
18//!
19//! with `fac = 0.9`, `facmax = 10.0`, `facmin = 0.2`.
20//!
21//! ## Dispatch
22//!
23//! [`EnsembleDispatch::Sequential`] processes members one at a time on the CPU.
24//! [`EnsembleDispatch::Simulated`] represents a conceptual GPU batched dispatch
25//! (same numerics, different conceptual path) and is provided for API
26//! compatibility with future hardware acceleration.
27//!
28//! ## Example
29//!
30//! ```rust
31//! use scirs2_integrate::gpu_ode_ensemble::{
32//!     OdeEnsemble, OdeEnsembleConfig, EnsembleMember, EnsembleDispatch,
33//! };
34//!
35//! // Solve y' = -k * y for several values of k
36//! let config = OdeEnsembleConfig {
37//!     t_span: [0.0, 1.0],
38//!     rtol: 1e-6,
39//!     atol: 1e-9,
40//!     max_steps: 10_000,
41//!     dispatch: EnsembleDispatch::Sequential,
42//! };
43//! let members: Vec<EnsembleMember> = (1..=5)
44//!     .map(|k| EnsembleMember {
45//!         params: vec![k as f64],
46//!         y0: vec![1.0],
47//!     })
48//!     .collect();
49//!
50//! let ensemble = OdeEnsemble::new(config);
51//! let result = ensemble.integrate(&members, &|t, y, p| vec![-p[0] * y[0]]);
52//! assert!(result.success.iter().all(|&s| s));
53//! ```
54
/// Convenience type alias for a boxed right-hand-side function.
///
/// The callable receives `(t, y, params)` — the current time, the current
/// state slice and the member's parameter slice — and returns the derivative
/// `dydt` as a `Vec<f64>`.  The `Send + Sync` bounds permit the boxed callable
/// to be moved to, and shared between, threads.
///
/// NOTE(review): the functions visible in this file take a borrowed
/// `&dyn Fn(f64, &[f64], &[f64]) -> Vec<f64>` directly rather than this alias.
pub type OdeRhsFn = Box<dyn Fn(f64, &[f64], &[f64]) -> Vec<f64> + Send + Sync>;
59
/// Execution strategy for the ensemble.
///
/// Selected through the `dispatch` field of [`OdeEnsembleConfig`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnsembleDispatch {
    /// Integrate members one at a time on the CPU.
    Sequential,
    /// Simulated GPU-batched execution (same numerics as `Sequential`).
    ///
    /// NOTE(review): `OdeEnsemble::integrate` as visible in this file does not
    /// branch on the dispatch value, so both variants follow the same code path.
    Simulated,
}
68
/// Configuration for [`OdeEnsemble`].
///
/// The same settings are applied to every member of the ensemble.
#[derive(Debug, Clone)]
pub struct OdeEnsembleConfig {
    /// Integration interval `[t_start, t_end]`.
    ///
    /// `t_end` may be smaller than `t_start`, in which case the integrator
    /// steps backward in time.
    pub t_span: [f64; 2],
    /// Relative tolerance for the adaptive stepper.
    ///
    /// Enters the per-component error scale
    /// `atol + rtol * max(|y_i|, |y_new_i|)`.
    pub rtol: f64,
    /// Absolute tolerance for the adaptive stepper (see `rtol` for the scale).
    pub atol: f64,
    /// Maximum number of steps per member before declaring failure.
    ///
    /// Only accepted steps are counted; rejected trial steps do not consume
    /// this budget.
    pub max_steps: usize,
    /// Execution dispatch strategy.
    pub dispatch: EnsembleDispatch,
}
83
/// One member of the ensemble: its parameters and initial conditions.
///
/// Members are integrated independently of one another.
#[derive(Debug, Clone)]
pub struct EnsembleMember {
    /// Parameters passed to the RHS as the third argument.
    ///
    /// Forwarded unchanged on every RHS evaluation.
    pub params: Vec<f64>,
    /// Initial condition `y(t_start)`.
    ///
    /// Its length is the dimension of the system.  An empty vector is returned
    /// immediately as a successful result with zero steps.
    pub y0: Vec<f64>,
}
92
/// Result of integrating a full ensemble.
///
/// Every vector holds one entry per member, in the order the members were
/// passed to `OdeEnsemble::integrate`.
#[derive(Debug, Clone)]
pub struct EnsembleResult {
    /// Final state for each member.
    ///
    /// This is `y(t_end)` when the corresponding `success` entry is `true`;
    /// otherwise it is the last accepted state before integration stopped.
    pub solutions: Vec<Vec<f64>>,
    /// Number of accepted steps taken per member.
    pub n_steps: Vec<usize>,
    /// Whether each member reached `t_end` (within the integrator's endpoint
    /// tolerance) before exhausting `max_steps` or collapsing the step size.
    pub success: Vec<bool>,
    /// Final time reported for each member; equals `t_span[1]` on success.
    pub t_final: Vec<f64>,
}
105
106/// Ensemble ODE integrator.
107pub struct OdeEnsemble {
108    config: OdeEnsembleConfig,
109}
110
111// ─────────────────────────────────────────────────────────────────────────────
112// Dormand-Prince RK45 Butcher tableau
113// ─────────────────────────────────────────────────────────────────────────────
114//
115//   0    |
116//  1/5   | 1/5
117//  3/10  | 3/40        9/40
118//  4/5   | 44/45      -56/15      32/9
119//  8/9   | 19372/6561 -25360/2187  64448/6561  -212/729
120//  1     | 9017/3168  -355/33      46732/5247   49/176   -5103/18656
121//  1     | 35/384      0           500/1113     125/192  -2187/6784   11/84
122//
123//  Order-4 error estimate coefficients (difference: 5th − 4th order)
124//  e = y5 − y4
125//  e1 = 71/57600,  e3 = -71/16695, e4 = 71/1920, e5 = -17253/339200, e6 = 22/525, e7 = -1/40
126
// Stage coefficients a_ij of the Dormand-Prince tableau (row i, column j).
const A21: f64 = 1.0 / 5.0;
const A31: f64 = 3.0 / 40.0;
const A32: f64 = 9.0 / 40.0;
const A41: f64 = 44.0 / 45.0;
const A42: f64 = -56.0 / 15.0;
const A43: f64 = 32.0 / 9.0;
const A51: f64 = 19372.0 / 6561.0;
const A52: f64 = -25360.0 / 2187.0;
const A53: f64 = 64448.0 / 6561.0;
const A54: f64 = -212.0 / 729.0;
const A61: f64 = 9017.0 / 3168.0;
const A62: f64 = -355.0 / 33.0;
const A63: f64 = 46732.0 / 5247.0;
const A64: f64 = 49.0 / 176.0;
const A65: f64 = -5103.0 / 18656.0;

// 5th-order solution weights (the weight of stage 2 is zero and is omitted).
const B1: f64 = 35.0 / 384.0;
const B3: f64 = 500.0 / 1113.0;
const B4: f64 = 125.0 / 192.0;
const B5: f64 = -2187.0 / 6784.0;
const B6: f64 = 11.0 / 84.0;

// Error coefficients (5th − 4th order); the stage-2 coefficient is zero.
const E1: f64 = 71.0 / 57600.0;
const E3: f64 = -71.0 / 16695.0;
const E4: f64 = 71.0 / 1920.0;
const E5: f64 = -17253.0 / 339200.0;
const E6: f64 = 22.0 / 525.0;
const E7: f64 = -1.0 / 40.0;

// Node positions (c values); stages 6 and 7 sit at c = 1 and use `t + h`.
const C2: f64 = 1.0 / 5.0;
const C3: f64 = 3.0 / 10.0;
const C4: f64 = 4.0 / 5.0;
const C5: f64 = 8.0 / 9.0;

// ─────────────────────────────────────────────────────────────────────────────
// Core RK45 step
// ─────────────────────────────────────────────────────────────────────────────

/// Dormand-Prince RK45 trial step.
///
/// Advances the state from `(t, y)` by a single step of size `h` using the
/// Dormand-Prince 5(4) pair.  The step is neither accepted nor rejected here;
/// the caller decides based on the returned error norm.
///
/// The error norm is the RMS of the componentwise scaled errors
/// `err_i / (atol + rtol * max(|y_i|, |y5_i|))`.  A component whose error and
/// scale are both exactly zero contributes zero instead of `0/0 = NaN`, so a
/// component that stays at zero with `atol == 0` does not poison the norm.
///
/// # Arguments
///
/// * `t`      — current time.
/// * `y`      — current state (length `n`).
/// * `params` — parameters forwarded verbatim to `rhs`.
/// * `h`      — step size (may be positive or negative).
/// * `rhs`    — right-hand side `f(t, y, params) -> dydt`.
/// * `rtol`   — relative tolerance (for error scaling).
/// * `atol`   — absolute tolerance (for error scaling).
///
/// # Returns
///
/// `(y5, err_vec, err_norm)` where `y5` is the 5th-order solution at `t + h`,
/// `err_vec` is the componentwise difference between the 5th- and 4th-order
/// solutions, and `err_norm` is the scaled RMS norm of `err_vec`.
///
/// For an empty state the result is two empty vectors and `err_norm == 0.0`,
/// and `rhs` is not called.
///
/// # Panics
///
/// Panics if `rhs` returns fewer than `y.len()` components.
pub fn rk45_step(
    t: f64,
    y: &[f64],
    params: &[f64],
    h: f64,
    rhs: &dyn Fn(f64, &[f64], &[f64]) -> Vec<f64>,
    rtol: f64,
    atol: f64,
) -> (Vec<f64>, Vec<f64>, f64) {
    let n = y.len();

    // Nothing to integrate: report zero error rather than the NaN the RMS
    // formula would produce from dividing an empty sum by zero.
    if n == 0 {
        return (Vec::new(), Vec::new(), 0.0);
    }

    // Stage 1
    let k1 = rhs(t, y, params);

    // Stage 2
    let y2: Vec<f64> = (0..n).map(|i| y[i] + h * A21 * k1[i]).collect();
    let k2 = rhs(t + C2 * h, &y2, params);

    // Stage 3
    let y3: Vec<f64> = (0..n)
        .map(|i| y[i] + h * (A31 * k1[i] + A32 * k2[i]))
        .collect();
    let k3 = rhs(t + C3 * h, &y3, params);

    // Stage 4
    let y4: Vec<f64> = (0..n)
        .map(|i| y[i] + h * (A41 * k1[i] + A42 * k2[i] + A43 * k3[i]))
        .collect();
    let k4 = rhs(t + C4 * h, &y4, params);

    // Stage 5
    let y5_tmp: Vec<f64> = (0..n)
        .map(|i| y[i] + h * (A51 * k1[i] + A52 * k2[i] + A53 * k3[i] + A54 * k4[i]))
        .collect();
    let k5 = rhs(t + C5 * h, &y5_tmp, params);

    // Stage 6
    let y6_tmp: Vec<f64> = (0..n)
        .map(|i| y[i] + h * (A61 * k1[i] + A62 * k2[i] + A63 * k3[i] + A64 * k4[i] + A65 * k5[i]))
        .collect();
    let k6 = rhs(t + h, &y6_tmp, params);

    // 5th-order solution
    let y_new: Vec<f64> = (0..n)
        .map(|i| y[i] + h * (B1 * k1[i] + B3 * k3[i] + B4 * k4[i] + B5 * k5[i] + B6 * k6[i]))
        .collect();

    // Stage 7: derivative at the new state.  It is only needed for the error
    // estimate; this function does not carry it over to the next step.
    let k7 = rhs(t + h, &y_new, params);

    // Error estimate: e = y5 - y4  (using the E coefficients)
    let err_vec: Vec<f64> = (0..n)
        .map(|i| h * (E1 * k1[i] + E3 * k3[i] + E4 * k4[i] + E5 * k5[i] + E6 * k6[i] + E7 * k7[i]))
        .collect();

    // RMS error norm (scaled)
    let sum_sq: f64 = (0..n)
        .map(|i| {
            let sc = atol + rtol * y[i].abs().max(y_new[i].abs());
            if sc == 0.0 && err_vec[i] == 0.0 {
                // No error and no scale: the component is exact.
                0.0
            } else {
                let e = err_vec[i] / sc;
                e * e
            }
        })
        .sum();
    let err_norm = (sum_sq / n as f64).sqrt();

    (y_new, err_vec, err_norm)
}
256
257// ─────────────────────────────────────────────────────────────────────────────
258// OdeEnsemble implementation
259// ─────────────────────────────────────────────────────────────────────────────
260
261impl OdeEnsemble {
262    /// Create a new ensemble integrator with the given configuration.
263    pub fn new(config: OdeEnsembleConfig) -> Self {
264        Self { config }
265    }
266
267    /// Integrate all members from `t_span[0]` to `t_span[1]`.
268    ///
269    /// # Arguments
270    ///
271    /// * `members` — slice of ensemble members (parameters + initial conditions).
272    /// * `rhs`     — right-hand side `f(t, y, params) -> dydt`.
273    ///
274    /// # Returns
275    ///
276    /// An [`EnsembleResult`] containing the final state for each member.
277    pub fn integrate(
278        &self,
279        members: &[EnsembleMember],
280        rhs: &dyn Fn(f64, &[f64], &[f64]) -> Vec<f64>,
281    ) -> EnsembleResult {
282        let n = members.len();
283        let mut solutions = Vec::with_capacity(n);
284        let mut n_steps_vec = Vec::with_capacity(n);
285        let mut success_vec = Vec::with_capacity(n);
286        let mut t_final_vec = Vec::with_capacity(n);
287
288        for member in members {
289            let (y_final, n_steps, ok) = self.integrate_single(member, rhs);
290            let t_reached = if ok {
291                self.config.t_span[1]
292            } else {
293                // Report partial progress: we don't track intermediate times in
294                // the current implementation, so report t_start on failure.
295                self.config.t_span[0]
296            };
297            solutions.push(y_final);
298            n_steps_vec.push(n_steps);
299            success_vec.push(ok);
300            t_final_vec.push(t_reached);
301        }
302
303        EnsembleResult {
304            solutions,
305            n_steps: n_steps_vec,
306            success: success_vec,
307            t_final: t_final_vec,
308        }
309    }
310
311    /// Integrate a single ensemble member using adaptive RK45.
312    ///
313    /// Returns `(final_y, n_steps, converged)`.
314    fn integrate_single(
315        &self,
316        member: &EnsembleMember,
317        rhs: &dyn Fn(f64, &[f64], &[f64]) -> Vec<f64>,
318    ) -> (Vec<f64>, usize, bool) {
319        let t_start = self.config.t_span[0];
320        let t_end = self.config.t_span[1];
321        let rtol = self.config.rtol;
322        let atol = self.config.atol;
323        let max_steps = self.config.max_steps;
324
325        let mut t = t_start;
326        let mut y = member.y0.clone();
327        let n = y.len();
328
329        if n == 0 {
330            return (y, 0, true);
331        }
332
333        // Initial step size heuristic
334        let span = (t_end - t_start).abs();
335        let mut h = span * 1e-3;
336        // Clamp to avoid overshooting on the first step
337        h = h.min(span);
338
339        let direction = if t_end >= t_start { 1.0_f64 } else { -1.0 };
340        h *= direction;
341
342        let fac = 0.9_f64;
343        let fac_max = 10.0_f64;
344        let fac_min = 0.2_f64;
345
346        let mut steps = 0_usize;
347        let mut converged = false;
348
349        while (direction * (t_end - t)).abs() > 1e-12 * span.max(f64::EPSILON) {
350            if steps >= max_steps {
351                break;
352            }
353
354            // Don't overshoot the endpoint
355            if direction * (t + h - t_end) > 0.0 {
356                h = t_end - t;
357            }
358            if h.abs() < f64::EPSILON * span {
359                // Step size collapsed — declare failure
360                break;
361            }
362
363            let (y_new, _err_vec, err_norm) = rk45_step(t, &y, &member.params, h, rhs, rtol, atol);
364
365            // Accept or reject step
366            if err_norm <= 1.0 || err_norm.is_nan() {
367                // Accept
368                t += h;
369                y = y_new;
370                steps += 1;
371
372                if (direction * (t_end - t)).abs() < 1e-12 * span.max(f64::EPSILON) {
373                    converged = true;
374                    break;
375                }
376            }
377
378            // Adjust step size
379            let err_safe = err_norm.max(f64::EPSILON);
380            let factor = fac * err_safe.powf(-0.2);
381            let factor = factor.clamp(fac_min, fac_max);
382            h *= factor;
383
384            // Safety: if we accepted, count this step towards the limit already
385            // (done above via `steps += 1`).
386        }
387
388        // If we've reached t_end within tolerance, mark as converged
389        if (t - t_end).abs() < 1e-8 * span.max(f64::EPSILON) {
390            converged = true;
391        }
392
393        (y, steps, converged)
394    }
395}
396
397// ─────────────────────────────────────────────────────────────────────────────
398// Tests
399// ─────────────────────────────────────────────────────────────────────────────
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration on the unit interval used by most tests.
    fn default_config() -> OdeEnsembleConfig {
        OdeEnsembleConfig {
            t_span: [0.0, 1.0],
            rtol: 1e-7,
            atol: 1e-9,
            max_steps: 100_000,
            dispatch: EnsembleDispatch::Sequential,
        }
    }

    /// Right-hand side of the scalar decay problem `y' = -k * y`, `k = p[0]`.
    fn decay_rhs(_t: f64, y: &[f64], p: &[f64]) -> Vec<f64> {
        vec![-p[0] * y[0]]
    }

    /// Scalar ensemble member with decay rate `k` and initial value `y0`.
    fn decay_member(k: f64, y0: f64) -> EnsembleMember {
        EnsembleMember {
            params: vec![k],
            y0: vec![y0],
        }
    }

    /// Five decay members with rates 1 through 5, all starting at `y = 1`.
    fn unit_rate_members() -> Vec<EnsembleMember> {
        (1..=5).map(|k| decay_member(k as f64, 1.0)).collect()
    }

    /// Members that share parameters and initial conditions must end up with
    /// the same solution.
    #[test]
    fn test_identical_params_same_solution() {
        let ensemble = OdeEnsemble::new(default_config());
        let members = vec![decay_member(2.0, 1.0); 5];
        let result = ensemble.integrate(&members, &decay_rhs);

        let reference = result.solutions[0][0];
        for i in 1..result.solutions.len() {
            let value = result.solutions[i][0];
            assert!(
                (value - reference).abs() < 1e-14,
                "member {i} diverges from member 0: {:.6e} vs {:.6e}",
                value,
                reference
            );
        }
    }

    /// A larger decay rate must give a strictly smaller final value.
    #[test]
    fn test_different_params_different_solutions() {
        let ensemble = OdeEnsemble::new(default_config());
        let ks = [0.5, 1.0, 2.0, 4.0, 8.0];
        let members: Vec<EnsembleMember> = ks.iter().map(|&k| decay_member(k, 1.0)).collect();
        let result = ensemble.integrate(&members, &decay_rhs);

        // Higher k → smaller y(1)
        let finals: Vec<f64> = result.solutions.iter().map(|sol| sol[0]).collect();
        for (k_pair, y_pair) in ks.windows(2).zip(finals.windows(2)) {
            assert!(
                y_pair[1] < y_pair[0],
                "k={} solution ({:.6e}) should be < k={} solution ({:.6e})",
                k_pair[1],
                y_pair[1],
                k_pair[0],
                y_pair[0]
            );
        }
    }

    /// Exponential decay `y' = -k*y`, `y(0) = y0` is compared against the
    /// closed form `y(t) = y0 * exp(-k*t)`.
    #[test]
    fn test_exponential_decay_analytical() {
        let ensemble = OdeEnsemble::new(OdeEnsembleConfig {
            t_span: [0.0, 2.0],
            rtol: 1e-8,
            atol: 1e-10,
            max_steps: 100_000,
            dispatch: EnsembleDispatch::Sequential,
        });
        let k = 3.0_f64;
        let y0 = 2.5_f64;

        let result = ensemble.integrate(&[decay_member(k, y0)], &decay_rhs);

        let y_numerical = result.solutions[0][0];
        let y_analytical = y0 * (-k * 2.0_f64).exp();
        assert!(
            (y_numerical - y_analytical).abs() < 1e-6,
            "y_numerical = {y_numerical:.8e}, y_analytical = {y_analytical:.8e}"
        );
    }

    /// Every member of a well-behaved system must report success.
    #[test]
    fn test_all_converge() {
        let ensemble = OdeEnsemble::new(default_config());
        let result = ensemble.integrate(&unit_rate_members(), &decay_rhs);
        for (i, ok) in result.success.iter().copied().enumerate() {
            assert!(ok, "member {i} did not converge");
        }
    }

    /// Every member must have taken at least one step.
    #[test]
    fn test_n_steps_positive() {
        let ensemble = OdeEnsemble::new(default_config());
        let result = ensemble.integrate(&unit_rate_members(), &decay_rhs);
        for (i, ns) in result.n_steps.iter().copied().enumerate() {
            assert!(ns > 0, "member {i} took 0 steps");
        }
    }

    /// 2-D system: the van-der-Pol oscillator at low μ must integrate to a
    /// finite state.
    #[test]
    fn test_2d_system_vanderpol() {
        /// Van-der-Pol right-hand side with `mu = p[0]`.
        fn vanderpol_rhs(_t: f64, y: &[f64], p: &[f64]) -> Vec<f64> {
            let mu = p[0];
            vec![y[1], mu * (1.0 - y[0] * y[0]) * y[1] - y[0]]
        }

        let ensemble = OdeEnsemble::new(OdeEnsembleConfig {
            t_span: [0.0, 5.0],
            rtol: 1e-6,
            atol: 1e-8,
            max_steps: 500_000,
            dispatch: EnsembleDispatch::Sequential,
        });
        // mu = 0.1  (weak non-linearity)
        let member = EnsembleMember {
            params: vec![0.1],
            y0: vec![2.0, 0.0],
        };

        let result = ensemble.integrate(&[member], &vanderpol_rhs);

        assert!(result.success[0], "van-der-Pol did not converge");
        // Final state should be finite
        for v in result.solutions[0].iter().copied() {
            assert!(v.is_finite(), "van-der-Pol solution is non-finite");
        }
    }

    /// The simulated dispatch must reproduce the sequential solutions.
    #[test]
    fn test_simulated_dispatch_matches_sequential() {
        let sequential_cfg = OdeEnsembleConfig {
            max_steps: 50_000,
            ..default_config()
        };
        let simulated_cfg = OdeEnsembleConfig {
            dispatch: EnsembleDispatch::Simulated,
            ..sequential_cfg.clone()
        };
        let members = vec![decay_member(1.0, 1.0), decay_member(2.0, 3.0)];

        let res_seq = OdeEnsemble::new(sequential_cfg).integrate(&members, &decay_rhs);
        let res_sim = OdeEnsemble::new(simulated_cfg).integrate(&members, &decay_rhs);

        for i in 0..members.len() {
            let seq_value = res_seq.solutions[i][0];
            let sim_value = res_sim.solutions[i][0];
            assert!(
                (seq_value - sim_value).abs() < 1e-14,
                "member {i}: sequential={:.6e}, simulated={:.6e}",
                seq_value,
                sim_value
            );
        }
    }
}