Skip to main content

sciforge_lib/maths/ode/
solvers.rs

1use crate::constants::{
2    ADAMS_BASHFORTH_4, RK45_A2, RK45_A3, RK45_A4, RK45_A5, RK45_A6, RK45_B21, RK45_B31, RK45_B32,
3    RK45_B41, RK45_B42, RK45_B43, RK45_B51, RK45_B52, RK45_B53, RK45_B54, RK45_B61, RK45_B62,
4    RK45_B63, RK45_B64, RK45_B65, RK45_C1, RK45_C3, RK45_C4, RK45_C5, RK45_C6, RK45_D1, RK45_D3,
5    RK45_D4, RK45_D5,
6};
7
/// Trajectory returned by the ODE solvers in this module.
///
/// `y[k]` is the state vector at time `t[k]`; the first entry of each is the
/// initial condition.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct OdeResult {
    /// Time points, in the order the solver reached them.
    pub t: Vec<f64>,
    /// State vectors; `y[k]` belongs to `t[k]`.
    pub y: Vec<Vec<f64>>,
}
12
13pub fn euler(
14    f: impl Fn(f64, &[f64]) -> Vec<f64>,
15    t_span: (f64, f64),
16    y0: &[f64],
17    dt: f64,
18) -> OdeResult {
19    let mut t = t_span.0;
20    let mut y = y0.to_vec();
21    let mut ts = vec![t];
22    let mut ys = vec![y.clone()];
23
24    while t < t_span.1 - 1e-12 {
25        let h = dt.min(t_span.1 - t);
26        let dy = f(t, &y);
27        for i in 0..y.len() {
28            y[i] += h * dy[i];
29        }
30        t += h;
31        ts.push(t);
32        ys.push(y.clone());
33    }
34    OdeResult { t: ts, y: ys }
35}
36
37pub fn rk4(
38    f: impl Fn(f64, &[f64]) -> Vec<f64>,
39    t_span: (f64, f64),
40    y0: &[f64],
41    dt: f64,
42) -> OdeResult {
43    let n = y0.len();
44    let mut t = t_span.0;
45    let mut y = y0.to_vec();
46    let mut ts = vec![t];
47    let mut ys = vec![y.clone()];
48
49    while t < t_span.1 - 1e-12 {
50        let h = dt.min(t_span.1 - t);
51        let k1 = f(t, &y);
52        let y2: Vec<f64> = (0..n).map(|i| y[i] + 0.5 * h * k1[i]).collect();
53        let k2 = f(t + 0.5 * h, &y2);
54        let y3: Vec<f64> = (0..n).map(|i| y[i] + 0.5 * h * k2[i]).collect();
55        let k3 = f(t + 0.5 * h, &y3);
56        let y4: Vec<f64> = (0..n).map(|i| y[i] + h * k3[i]).collect();
57        let k4 = f(t + h, &y4);
58
59        for i in 0..n {
60            y[i] += h / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
61        }
62        t += h;
63        ts.push(t);
64        ys.push(y.clone());
65    }
66    OdeResult { t: ts, y: ys }
67}
68
69pub fn rk45_adaptive(
70    f: impl Fn(f64, &[f64]) -> Vec<f64>,
71    t_span: (f64, f64),
72    y0: &[f64],
73    tol: f64,
74    dt_init: f64,
75) -> OdeResult {
76    let n = y0.len();
77    let mut t = t_span.0;
78    let mut y = y0.to_vec();
79    let mut h = dt_init;
80    let mut ts = vec![t];
81    let mut ys = vec![y.clone()];
82
83    while t < t_span.1 - 1e-12 {
84        h = h.min(t_span.1 - t);
85        let k1 = f(t, &y);
86        let y2: Vec<f64> = (0..n).map(|i| y[i] + h * RK45_B21 * k1[i]).collect();
87        let k2 = f(t + RK45_A2 * h, &y2);
88        let y3: Vec<f64> = (0..n)
89            .map(|i| y[i] + h * (RK45_B31 * k1[i] + RK45_B32 * k2[i]))
90            .collect();
91        let k3 = f(t + RK45_A3 * h, &y3);
92        let y4: Vec<f64> = (0..n)
93            .map(|i| y[i] + h * (RK45_B41 * k1[i] + RK45_B42 * k2[i] + RK45_B43 * k3[i]))
94            .collect();
95        let k4 = f(t + RK45_A4 * h, &y4);
96        let y5: Vec<f64> = (0..n)
97            .map(|i| {
98                y[i] + h
99                    * (RK45_B51 * k1[i] + RK45_B52 * k2[i] + RK45_B53 * k3[i] + RK45_B54 * k4[i])
100            })
101            .collect();
102        let k5 = f(t + RK45_A5 * h, &y5);
103        let y6: Vec<f64> = (0..n)
104            .map(|i| {
105                y[i] + h
106                    * (RK45_B61 * k1[i]
107                        + RK45_B62 * k2[i]
108                        + RK45_B63 * k3[i]
109                        + RK45_B64 * k4[i]
110                        + RK45_B65 * k5[i])
111            })
112            .collect();
113        let k6 = f(t + RK45_A6 * h, &y6);
114
115        let y5th: Vec<f64> = (0..n)
116            .map(|i| {
117                y[i] + h
118                    * (RK45_C1 * k1[i]
119                        + RK45_C3 * k3[i]
120                        + RK45_C4 * k4[i]
121                        + RK45_C5 * k5[i]
122                        + RK45_C6 * k6[i])
123            })
124            .collect();
125        let y4th: Vec<f64> = (0..n)
126            .map(|i| {
127                y[i] + h * (RK45_D1 * k1[i] + RK45_D3 * k3[i] + RK45_D4 * k4[i] + RK45_D5 * k5[i])
128            })
129            .collect();
130
131        let err: f64 = (0..n)
132            .map(|i| (y5th[i] - y4th[i]).powi(2))
133            .sum::<f64>()
134            .sqrt();
135
136        if err < tol || h < 1e-15 {
137            y = y5th;
138            t += h;
139            ts.push(t);
140            ys.push(y.clone());
141            if err > 1e-30 {
142                h *= 0.9 * (tol / err).powf(0.2);
143            } else {
144                h *= 2.0;
145            }
146        } else {
147            h *= 0.9 * (tol / err).powf(0.25);
148        }
149        h = h.max(1e-15);
150    }
151    OdeResult { t: ts, y: ys }
152}
153
154pub fn implicit_euler(
155    f: impl Fn(f64, &[f64]) -> Vec<f64>,
156    t_span: (f64, f64),
157    y0: &[f64],
158    dt: f64,
159    newton_iters: usize,
160) -> OdeResult {
161    let n = y0.len();
162    let mut t = t_span.0;
163    let mut y = y0.to_vec();
164    let mut ts = vec![t];
165    let mut ys = vec![y.clone()];
166
167    while t < t_span.1 - 1e-12 {
168        let h = dt.min(t_span.1 - t);
169        let mut y_new = y.clone();
170        let dy = f(t, &y);
171        for i in 0..n {
172            y_new[i] = y[i] + h * dy[i];
173        }
174
175        for _ in 0..newton_iters {
176            let f_new = f(t + h, &y_new);
177            for i in 0..n {
178                y_new[i] = y[i] + h * f_new[i];
179            }
180        }
181
182        y = y_new;
183        t += h;
184        ts.push(t);
185        ys.push(y.clone());
186    }
187    OdeResult { t: ts, y: ys }
188}
189
/// Integrates `x'' = accel(t, x)` with the velocity Verlet scheme.
///
/// Each step of size `h` performs:
/// - `x ← x + v h + a h² / 2`;
/// - `a_new = accel(t + h, x)`;
/// - `v ← v + (a + a_new) h / 2`.
///
/// The end-of-step acceleration is reused as the next step's starting
/// acceleration, so `accel` is evaluated once per step plus once for the
/// initial state. Steps are `dt` long except the last, which is shortened so
/// it does not pass `t_span.1`. Integration stops once `t` is within `1e-12`
/// of `t_span.1`.
///
/// Returns `(t, x, v)`: the time points and, for each, the position and
/// velocity vectors.
///
/// # Panics
///
/// Panics if `dt` is not strictly positive (zero, negative or NaN). Once a
/// step is taken, indexing also panics if `v0` or a vector returned by
/// `accel` has fewer components than `x0`.
pub fn velocity_verlet(
    accel: impl Fn(f64, &[f64]) -> Vec<f64>,
    t_span: (f64, f64),
    x0: &[f64],
    v0: &[f64],
    dt: f64,
) -> (Vec<f64>, Vec<Vec<f64>>, Vec<Vec<f64>>) {
    // A non-positive (or NaN) step never advances `t`: the loop would never end.
    assert!(dt > 0.0, "velocity_verlet: dt must be positive, got {}", dt);

    let n = x0.len();
    let mut t = t_span.0;
    let mut x = x0.to_vec();
    let mut v = v0.to_vec();
    let mut ts = vec![t];
    let mut xs = vec![x.clone()];
    let mut vs = vec![v.clone()];
    let mut a = accel(t, &x);

    while t < t_span.1 - 1e-12 {
        let h = dt.min(t_span.1 - t);
        for i in 0..n {
            x[i] += v[i] * h + 0.5 * a[i] * h * h;
        }
        t += h;
        // Velocity update averages the accelerations at both ends of the step.
        let a_new = accel(t, &x);
        for i in 0..n {
            v[i] += 0.5 * (a[i] + a_new[i]) * h;
        }
        a = a_new;
        ts.push(t);
        xs.push(x.clone());
        vs.push(v.clone());
    }
    (ts, xs, vs)
}
223
224pub fn midpoint_method(
225    f: impl Fn(f64, &[f64]) -> Vec<f64>,
226    t_span: (f64, f64),
227    y0: &[f64],
228    dt: f64,
229) -> OdeResult {
230    let n = y0.len();
231    let mut t = t_span.0;
232    let mut y = y0.to_vec();
233    let mut ts = vec![t];
234    let mut ys = vec![y.clone()];
235
236    while t < t_span.1 - 1e-12 {
237        let h = dt.min(t_span.1 - t);
238        let k1 = f(t, &y);
239        let ymid: Vec<f64> = (0..n).map(|i| y[i] + 0.5 * h * k1[i]).collect();
240        let k2 = f(t + 0.5 * h, &ymid);
241        for i in 0..n {
242            y[i] += h * k2[i];
243        }
244        t += h;
245        ts.push(t);
246        ys.push(y.clone());
247    }
248    OdeResult { t: ts, y: ys }
249}
250
251pub fn heun(
252    f: impl Fn(f64, &[f64]) -> Vec<f64>,
253    t_span: (f64, f64),
254    y0: &[f64],
255    dt: f64,
256) -> OdeResult {
257    let n = y0.len();
258    let mut t = t_span.0;
259    let mut y = y0.to_vec();
260    let mut ts = vec![t];
261    let mut ys = vec![y.clone()];
262
263    while t < t_span.1 - 1e-12 {
264        let h = dt.min(t_span.1 - t);
265        let k1 = f(t, &y);
266        let y_pred: Vec<f64> = (0..n).map(|i| y[i] + h * k1[i]).collect();
267        let k2 = f(t + h, &y_pred);
268        for i in 0..n {
269            y[i] += 0.5 * h * (k1[i] + k2[i]);
270        }
271        t += h;
272        ts.push(t);
273        ys.push(y.clone());
274    }
275    OdeResult { t: ts, y: ys }
276}
277
278pub fn rk38(
279    f: impl Fn(f64, &[f64]) -> Vec<f64>,
280    t_span: (f64, f64),
281    y0: &[f64],
282    dt: f64,
283) -> OdeResult {
284    let n = y0.len();
285    let mut t = t_span.0;
286    let mut y = y0.to_vec();
287    let mut ts = vec![t];
288    let mut ys = vec![y.clone()];
289
290    while t < t_span.1 - 1e-12 {
291        let h = dt.min(t_span.1 - t);
292        let k1 = f(t, &y);
293        let y2: Vec<f64> = (0..n).map(|i| y[i] + h / 3.0 * k1[i]).collect();
294        let k2 = f(t + h / 3.0, &y2);
295        let y3: Vec<f64> = (0..n).map(|i| y[i] + h * (-k1[i] / 3.0 + k2[i])).collect();
296        let k3 = f(t + 2.0 * h / 3.0, &y3);
297        let y4: Vec<f64> = (0..n).map(|i| y[i] + h * (k1[i] - k2[i] + k3[i])).collect();
298        let k4 = f(t + h, &y4);
299        for i in 0..n {
300            y[i] += h / 8.0 * (k1[i] + 3.0 * k2[i] + 3.0 * k3[i] + k4[i]);
301        }
302        t += h;
303        ts.push(t);
304        ys.push(y.clone());
305    }
306    OdeResult { t: ts, y: ys }
307}
308
309pub fn adams_bashforth_4(
310    f: impl Fn(f64, &[f64]) -> Vec<f64>,
311    t_span: (f64, f64),
312    y0: &[f64],
313    dt: f64,
314) -> OdeResult {
315    let n = y0.len();
316    let bootstrap = rk4(&f, (t_span.0, t_span.0 + 3.0 * dt), y0, dt);
317    let mut ts = bootstrap.t.clone();
318    let mut ys = bootstrap.y.clone();
319    let mut fs: Vec<Vec<f64>> = ts.iter().zip(&ys).map(|(&ti, yi)| f(ti, yi)).collect();
320    let mut t = *ts.last().unwrap();
321    let mut y = ys.last().unwrap().clone();
322
323    while t < t_span.1 - 1e-12 {
324        let h = dt.min(t_span.1 - t);
325        let m = fs.len();
326        let mut y_new = vec![0.0; n];
327        for i in 0..n {
328            y_new[i] = y[i]
329                + h * (ADAMS_BASHFORTH_4[0] * fs[m - 1][i]
330                    + ADAMS_BASHFORTH_4[1] * fs[m - 2][i]
331                    + ADAMS_BASHFORTH_4[2] * fs[m - 3][i]
332                    + ADAMS_BASHFORTH_4[3] * fs[m - 4][i]);
333        }
334        t += h;
335        y = y_new;
336        ts.push(t);
337        ys.push(y.clone());
338        fs.push(f(t, &y));
339    }
340    OdeResult { t: ts, y: ys }
341}
342
/// Integrates a Hamiltonian system `q' = dqdt(q, p)`, `p' = dpdt(q, p)` with symplectic Euler.
///
/// Each step of size `h` first updates the momentum from the current state,
/// `p ← p + h dpdt(q, p)`. It then updates the position using the *new*
/// momentum, `q ← q + h dqdt(q, p)`. Steps are `dt` long except the last,
/// which is shortened so it does not pass `t_span.1`. Integration stops once
/// `t` is within `1e-12` of `t_span.1`.
///
/// Returns `(t, q, p)`: the time points and, for each, the position and
/// momentum vectors.
///
/// # Panics
///
/// Panics if `dt` is not strictly positive (zero, negative or NaN). Once a
/// step is taken, indexing also panics if `p0` or a vector returned by
/// `dqdt`/`dpdt` has fewer components than `q0`.
pub fn symplectic_euler(
    dqdt: impl Fn(&[f64], &[f64]) -> Vec<f64>,
    dpdt: impl Fn(&[f64], &[f64]) -> Vec<f64>,
    t_span: (f64, f64),
    q0: &[f64],
    p0: &[f64],
    dt: f64,
) -> (Vec<f64>, Vec<Vec<f64>>, Vec<Vec<f64>>) {
    // A non-positive (or NaN) step never advances `t`: the loop would never end.
    assert!(dt > 0.0, "symplectic_euler: dt must be positive, got {}", dt);

    let n = q0.len();
    let mut t = t_span.0;
    let mut q = q0.to_vec();
    let mut p = p0.to_vec();
    let mut ts = vec![t];
    let mut qs = vec![q.clone()];
    let mut ps = vec![p.clone()];

    while t < t_span.1 - 1e-12 {
        let h = dt.min(t_span.1 - t);
        let dp = dpdt(&q, &p);
        for i in 0..n {
            p[i] += h * dp[i];
        }
        // The position update deliberately sees the already-updated momentum.
        let dq = dqdt(&q, &p);
        for i in 0..n {
            q[i] += h * dq[i];
        }
        t += h;
        ts.push(t);
        qs.push(q.clone());
        ps.push(p.clone());
    }
    (ts, qs, ps)
}
376
377pub fn stiff_bdf2(
378    f: impl Fn(f64, &[f64]) -> Vec<f64>,
379    t_span: (f64, f64),
380    y0: &[f64],
381    dt: f64,
382    newton_iters: usize,
383) -> OdeResult {
384    let n = y0.len();
385    let first = implicit_euler(&f, (t_span.0, t_span.0 + dt), y0, dt, newton_iters);
386    let mut ts = first.t.clone();
387    let mut ys = first.y.clone();
388    let mut t = *ts.last().unwrap();
389
390    while t < t_span.1 - 1e-12 {
391        let h = dt.min(t_span.1 - t);
392        let m = ys.len();
393        let y_n = &ys[m - 1];
394        let y_nm1 = &ys[m - 2];
395        let mut y_new: Vec<f64> = (0..n)
396            .map(|i| 4.0 / 3.0 * y_n[i] - 1.0 / 3.0 * y_nm1[i])
397            .collect();
398        for _ in 0..newton_iters {
399            let fv = f(t + h, &y_new);
400            for i in 0..n {
401                y_new[i] = 4.0 / 3.0 * y_n[i] - 1.0 / 3.0 * y_nm1[i] + 2.0 / 3.0 * h * fv[i];
402            }
403        }
404        t += h;
405        ts.push(t);
406        ys.push(y_new);
407    }
408    OdeResult { t: ts, y: ys }
409}