// sciforge_hub/engine/simulation/integrator.rs

use super::model::DynamicalSystem;
use crate::domain::common::errors::{HubError, HubResult};

/// Fixed-step explicit schemes understood by [`integrate`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum IntegrationMethod {
    /// Forward Euler: a single derivative evaluation at the start of each step.
    Euler,
    /// Heun's method: an Euler predictor, then the average of the start and end slopes
    /// (two derivative evaluations per step).
    Heun,
    /// Classical fourth-order Runge–Kutta (four derivative evaluations per step).
    RungeKutta4,
    /// Explicit midpoint rule: the slope is sampled at the half step
    /// (two derivative evaluations per step).
    Midpoint,
}

13/// Configuration for the ODE integrator.
14#[derive(Debug, Clone)]
15pub struct IntegratorConfig {
16    /// Integration scheme.
17    pub method: IntegrationMethod,
18    /// Time step.
19    pub dt: f64,
20    /// Maximum number of steps.
21    pub max_steps: usize,
22    /// Error tolerance.
23    pub tolerance: f64,
24}
25
26impl IntegratorConfig {
27    /// Creates a config with default max_steps and tolerance.
28    pub fn new(method: IntegrationMethod, dt: f64) -> Self {
29        Self {
30            method,
31            dt,
32            max_steps: 1_000_000,
33            tolerance: 1e-8,
34        }
35    }
36}
37
38/// Integrates the system from `t0` to `tf` starting at `y0`.
39pub fn integrate(
40    config: &IntegratorConfig,
41    system: &dyn DynamicalSystem,
42    y0: &[f64],
43    t0: f64,
44    tf: f64,
45) -> HubResult<(Vec<f64>, Vec<Vec<f64>>)> {
46    if config.dt <= 0.0 {
47        return Err(HubError::InvalidInput("dt must be positive".into()));
48    }
49    let steps = ((tf - t0) / config.dt).ceil() as usize;
50    if steps > config.max_steps {
51        return Err(HubError::InvalidInput("too many steps required".into()));
52    }
53
54    let dim = y0.len();
55    let mut times = Vec::with_capacity(steps + 1);
56    let mut states = Vec::with_capacity(steps + 1);
57    let mut y = y0.to_vec();
58    let mut t = t0;
59    let mut dy = vec![0.0; dim];
60
61    times.push(t);
62    states.push(y.clone());
63
64    for _ in 0..steps {
65        match config.method {
66            IntegrationMethod::Euler => {
67                system.derivatives(t, &y, &mut dy);
68                for (yi, &dyi) in y.iter_mut().zip(dy.iter()) {
69                    *yi += config.dt * dyi;
70                }
71            }
72            IntegrationMethod::Heun => {
73                system.derivatives(t, &y, &mut dy);
74                let k1 = dy.clone();
75                let y_pred: Vec<f64> = y
76                    .iter()
77                    .zip(k1.iter())
78                    .map(|(&yi, &ki)| yi + config.dt * ki)
79                    .collect();
80                system.derivatives(t + config.dt, &y_pred, &mut dy);
81                for (i, yi) in y.iter_mut().enumerate() {
82                    *yi += 0.5 * config.dt * (k1[i] + dy[i]);
83                }
84            }
85            IntegrationMethod::Midpoint => {
86                system.derivatives(t, &y, &mut dy);
87                let y_mid: Vec<f64> = y
88                    .iter()
89                    .zip(dy.iter())
90                    .map(|(&yi, &ki)| yi + 0.5 * config.dt * ki)
91                    .collect();
92                system.derivatives(t + 0.5 * config.dt, &y_mid, &mut dy);
93                for (yi, &dyi) in y.iter_mut().zip(dy.iter()) {
94                    *yi += config.dt * dyi;
95                }
96            }
97            IntegrationMethod::RungeKutta4 => {
98                system.derivatives(t, &y, &mut dy);
99                let k1 = dy.clone();
100                let y2: Vec<f64> = y
101                    .iter()
102                    .zip(k1.iter())
103                    .map(|(&yi, &ki)| yi + 0.5 * config.dt * ki)
104                    .collect();
105                system.derivatives(t + 0.5 * config.dt, &y2, &mut dy);
106                let k2 = dy.clone();
107                let y3: Vec<f64> = y
108                    .iter()
109                    .zip(k2.iter())
110                    .map(|(&yi, &ki)| yi + 0.5 * config.dt * ki)
111                    .collect();
112                system.derivatives(t + 0.5 * config.dt, &y3, &mut dy);
113                let k3 = dy.clone();
114                let y4: Vec<f64> = y
115                    .iter()
116                    .zip(k3.iter())
117                    .map(|(&yi, &ki)| yi + config.dt * ki)
118                    .collect();
119                system.derivatives(t + config.dt, &y4, &mut dy);
120                for i in 0..dim {
121                    y[i] += config.dt / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + dy[i]);
122                }
123            }
124        }
125        t += config.dt;
126        times.push(t);
127        states.push(y.clone());
128    }
129
130    Ok((times, states))
131}