/// Boxed right-hand-side function of an ODE system.
///
/// Has the same call shape as the `rhs` argument of [`rk45_step`] and
/// [`OdeEnsemble::integrate`]: `f(t, y, params)`, returning one derivative
/// component per entry of `y`.
///
/// The `Send + Sync` bounds allow the boxed closure to be shared across threads.
pub type OdeRhsFn = Box<dyn Fn(f64, &[f64], &[f64]) -> Vec<f64> + Send + Sync>;
59
/// Requested execution mode for an [`OdeEnsemble`].
///
/// NOTE(review): `OdeEnsemble::integrate` does not read this value.
/// Both variants integrate the members one after another and yield identical results.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnsembleDispatch {
    /// Integrate the members one after another.
    Sequential,
    /// Presumably a stand-in for a batched/parallel dispatch; currently
    /// behaves exactly like `Sequential` — TODO confirm intended semantics.
    Simulated,
}
68
/// Integration settings shared by every member of an [`OdeEnsemble`].
#[derive(Debug, Clone)]
pub struct OdeEnsembleConfig {
    /// Start and end time `[t0, t1]`; `t1 < t0` integrates backwards in time.
    pub t_span: [f64; 2],
    /// Relative tolerance used to scale the embedded error estimate.
    pub rtol: f64,
    /// Absolute tolerance used to scale the embedded error estimate.
    pub atol: f64,
    /// Upper bound on *accepted* steps per member (rejected steps are not counted).
    pub max_steps: usize,
    /// Requested dispatch mode; not consulted by `OdeEnsemble::integrate`.
    pub dispatch: EnsembleDispatch,
}
83
/// One member of an ensemble: its own parameters and initial state.
#[derive(Debug, Clone)]
pub struct EnsembleMember {
    /// Parameter vector, passed as the third argument of the right-hand side.
    pub params: Vec<f64>,
    /// Initial state at `t_span[0]`; its length fixes the system dimension.
    pub y0: Vec<f64>,
}
92
/// Per-member output of [`OdeEnsemble::integrate`].
///
/// Every vector is index-aligned with the members passed to `integrate`.
#[derive(Debug, Clone)]
pub struct EnsembleResult {
    /// Last accepted state of each member (the end state when it succeeded).
    pub solutions: Vec<Vec<f64>>,
    /// Number of accepted integration steps per member.
    pub n_steps: Vec<usize>,
    /// Whether each member reached the end of `t_span`.
    pub success: Vec<bool>,
    /// Reported end time per member; exactly `t_span[1]` when it succeeded.
    pub t_final: Vec<f64>,
}
105
/// Integrates a batch of initial-value problems that share one right-hand side
/// and one [`OdeEnsembleConfig`], using adaptive Dormand–Prince 5(4) steps
/// ([`rk45_step`]).
pub struct OdeEnsemble {
    /// Settings applied to every member.
    config: OdeEnsembleConfig,
}
110
// Dormand–Prince 5(4) Butcher tableau.
//
// `Aij`  – stage coefficients.
// `Ci`   – stage time fractions (c1 = 0, c6 = c7 = 1 are used inline).
// `Bi`   – fifth-order solution weights (B2 = 0 is omitted).
// `Ei`   – differences between the fifth- and fourth-order weights (b - b*),
//          which give the embedded error estimate (E2 = 0 is omitted).
const A21: f64 = 1.0 / 5.0;
const A31: f64 = 3.0 / 40.0;
const A32: f64 = 9.0 / 40.0;
const A41: f64 = 44.0 / 45.0;
const A42: f64 = -56.0 / 15.0;
const A43: f64 = 32.0 / 9.0;
const A51: f64 = 19372.0 / 6561.0;
const A52: f64 = -25360.0 / 2187.0;
const A53: f64 = 64448.0 / 6561.0;
const A54: f64 = -212.0 / 729.0;
const A61: f64 = 9017.0 / 3168.0;
const A62: f64 = -355.0 / 33.0;
const A63: f64 = 46732.0 / 5247.0;
const A64: f64 = 49.0 / 176.0;
const A65: f64 = -5103.0 / 18656.0;

const B1: f64 = 35.0 / 384.0;
const B3: f64 = 500.0 / 1113.0;
const B4: f64 = 125.0 / 192.0;
const B5: f64 = -2187.0 / 6784.0;
const B6: f64 = 11.0 / 84.0;

const E1: f64 = 71.0 / 57600.0;
const E3: f64 = -71.0 / 16695.0;
const E4: f64 = 71.0 / 1920.0;
const E5: f64 = -17253.0 / 339200.0;
const E6: f64 = 22.0 / 525.0;
const E7: f64 = -1.0 / 40.0;

const C2: f64 = 1.0 / 5.0;
const C3: f64 = 3.0 / 10.0;
const C4: f64 = 4.0 / 5.0;
const C5: f64 = 8.0 / 9.0;

/// Advances `y` from `t` to `t + h` with a single Dormand–Prince 5(4) step.
///
/// `rhs` is called as `rhs(t, y, params)` seven times and must return at least
/// `y.len()` components.
///
/// Returns `(y_new, err_vec, err_norm)`:
/// * `y_new` – fifth-order solution at `t + h`;
/// * `err_vec` – per-component difference between the fifth- and embedded
///   fourth-order solutions;
/// * `err_norm` – root-mean-square of `err_vec`, each component divided by
///   `atol + rtol * max(|y_i|, |y_new_i|)`. A value `<= 1.0` means the step
///   satisfies the tolerances. An empty state yields `0.0`.
///
/// # Panics
/// Panics if `rhs` returns fewer than `y.len()` components.
pub fn rk45_step(
    t: f64,
    y: &[f64],
    params: &[f64],
    h: f64,
    rhs: &dyn Fn(f64, &[f64], &[f64]) -> Vec<f64>,
    rtol: f64,
    atol: f64,
) -> (Vec<f64>, Vec<f64>, f64) {
    let n = y.len();

    let k1 = rhs(t, y, params);

    let y2: Vec<f64> = (0..n).map(|i| y[i] + h * A21 * k1[i]).collect();
    let k2 = rhs(t + C2 * h, &y2, params);

    let y3: Vec<f64> = (0..n)
        .map(|i| y[i] + h * (A31 * k1[i] + A32 * k2[i]))
        .collect();
    let k3 = rhs(t + C3 * h, &y3, params);

    let y4: Vec<f64> = (0..n)
        .map(|i| y[i] + h * (A41 * k1[i] + A42 * k2[i] + A43 * k3[i]))
        .collect();
    let k4 = rhs(t + C4 * h, &y4, params);

    let y5: Vec<f64> = (0..n)
        .map(|i| y[i] + h * (A51 * k1[i] + A52 * k2[i] + A53 * k3[i] + A54 * k4[i]))
        .collect();
    let k5 = rhs(t + C5 * h, &y5, params);

    let y6: Vec<f64> = (0..n)
        .map(|i| {
            y[i] + h * (A61 * k1[i] + A62 * k2[i] + A63 * k3[i] + A64 * k4[i] + A65 * k5[i])
        })
        .collect();
    let k6 = rhs(t + h, &y6, params);

    // Fifth-order solution (the b2 weight is zero).
    let y_new: Vec<f64> = (0..n)
        .map(|i| y[i] + h * (B1 * k1[i] + B3 * k3[i] + B4 * k4[i] + B5 * k5[i] + B6 * k6[i]))
        .collect();

    // Seventh stage at (t + h, y_new). Dormand–Prince is FSAL, but this
    // function evaluates it fresh and uses it only for the error estimate.
    let k7 = rhs(t + h, &y_new, params);

    let err_vec: Vec<f64> = (0..n)
        .map(|i| {
            h * (E1 * k1[i] + E3 * k3[i] + E4 * k4[i] + E5 * k5[i] + E6 * k6[i] + E7 * k7[i])
        })
        .collect();

    // Mixed absolute/relative scaled RMS norm. An empty system carries no
    // error; without this guard the mean below would be 0/0 = NaN.
    let err_norm = if n == 0 {
        0.0
    } else {
        let sum_sq: f64 = err_vec
            .iter()
            .zip(y.iter().zip(&y_new))
            .map(|(&e, (&y_old, &y_next))| {
                let scale = atol + rtol * y_old.abs().max(y_next.abs());
                let r = e / scale;
                r * r
            })
            .sum();
        (sum_sq / n as f64).sqrt()
    };

    (y_new, err_vec, err_norm)
}
256
257impl OdeEnsemble {
262 pub fn new(config: OdeEnsembleConfig) -> Self {
264 Self { config }
265 }
266
267 pub fn integrate(
278 &self,
279 members: &[EnsembleMember],
280 rhs: &dyn Fn(f64, &[f64], &[f64]) -> Vec<f64>,
281 ) -> EnsembleResult {
282 let n = members.len();
283 let mut solutions = Vec::with_capacity(n);
284 let mut n_steps_vec = Vec::with_capacity(n);
285 let mut success_vec = Vec::with_capacity(n);
286 let mut t_final_vec = Vec::with_capacity(n);
287
288 for member in members {
289 let (y_final, n_steps, ok) = self.integrate_single(member, rhs);
290 let t_reached = if ok {
291 self.config.t_span[1]
292 } else {
293 self.config.t_span[0]
296 };
297 solutions.push(y_final);
298 n_steps_vec.push(n_steps);
299 success_vec.push(ok);
300 t_final_vec.push(t_reached);
301 }
302
303 EnsembleResult {
304 solutions,
305 n_steps: n_steps_vec,
306 success: success_vec,
307 t_final: t_final_vec,
308 }
309 }
310
311 fn integrate_single(
315 &self,
316 member: &EnsembleMember,
317 rhs: &dyn Fn(f64, &[f64], &[f64]) -> Vec<f64>,
318 ) -> (Vec<f64>, usize, bool) {
319 let t_start = self.config.t_span[0];
320 let t_end = self.config.t_span[1];
321 let rtol = self.config.rtol;
322 let atol = self.config.atol;
323 let max_steps = self.config.max_steps;
324
325 let mut t = t_start;
326 let mut y = member.y0.clone();
327 let n = y.len();
328
329 if n == 0 {
330 return (y, 0, true);
331 }
332
333 let span = (t_end - t_start).abs();
335 let mut h = span * 1e-3;
336 h = h.min(span);
338
339 let direction = if t_end >= t_start { 1.0_f64 } else { -1.0 };
340 h *= direction;
341
342 let fac = 0.9_f64;
343 let fac_max = 10.0_f64;
344 let fac_min = 0.2_f64;
345
346 let mut steps = 0_usize;
347 let mut converged = false;
348
349 while (direction * (t_end - t)).abs() > 1e-12 * span.max(f64::EPSILON) {
350 if steps >= max_steps {
351 break;
352 }
353
354 if direction * (t + h - t_end) > 0.0 {
356 h = t_end - t;
357 }
358 if h.abs() < f64::EPSILON * span {
359 break;
361 }
362
363 let (y_new, _err_vec, err_norm) = rk45_step(t, &y, &member.params, h, rhs, rtol, atol);
364
365 if err_norm <= 1.0 || err_norm.is_nan() {
367 t += h;
369 y = y_new;
370 steps += 1;
371
372 if (direction * (t_end - t)).abs() < 1e-12 * span.max(f64::EPSILON) {
373 converged = true;
374 break;
375 }
376 }
377
378 let err_safe = err_norm.max(f64::EPSILON);
380 let factor = fac * err_safe.powf(-0.2);
381 let factor = factor.clamp(fac_min, fac_max);
382 h *= factor;
383
384 }
387
388 if (t - t_end).abs() < 1e-8 * span.max(f64::EPSILON) {
390 converged = true;
391 }
392
393 (y, steps, converged)
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline settings: integrate over [0, 1] with tight tolerances.
    fn default_config() -> OdeEnsembleConfig {
        OdeEnsembleConfig {
            t_span: [0.0, 1.0],
            rtol: 1e-7,
            atol: 1e-9,
            max_steps: 100_000,
            dispatch: EnsembleDispatch::Sequential,
        }
    }

    /// Scalar exponential decay `dy/dt = -k * y` with rate `k = p[0]`.
    fn decay(_t: f64, y: &[f64], p: &[f64]) -> Vec<f64> {
        vec![-p[0] * y[0]]
    }

    /// One decay member per rate constant, each starting from `y0 = 1`.
    fn unit_decay_members(rates: &[f64]) -> Vec<EnsembleMember> {
        rates
            .iter()
            .map(|&k| EnsembleMember {
                params: vec![k],
                y0: vec![1.0],
            })
            .collect()
    }

    #[test]
    fn test_identical_params_same_solution() {
        let result =
            OdeEnsemble::new(default_config()).integrate(&unit_decay_members(&[2.0; 5]), &decay);
        let reference = result.solutions[0][0];
        for (i, sol) in result.solutions.iter().enumerate().skip(1) {
            let value = sol[0];
            assert!(
                (value - reference).abs() < 1e-14,
                "member {i} diverges from member 0: {:.6e} vs {:.6e}",
                value,
                reference
            );
        }
    }

    #[test]
    fn test_different_params_different_solutions() {
        let ks = [0.5, 1.0, 2.0, 4.0, 8.0];
        let result = OdeEnsemble::new(default_config()).integrate(&unit_decay_members(&ks), &decay);
        let finals: Vec<f64> = result.solutions.iter().map(|sol| sol[0]).collect();
        // A faster decay rate must leave a smaller value at t = 1.
        for (k_pair, y_pair) in ks.windows(2).zip(finals.windows(2)) {
            let (k_prev, k_curr) = (k_pair[0], k_pair[1]);
            let (y_prev, y_curr) = (y_pair[0], y_pair[1]);
            assert!(
                y_curr < y_prev,
                "k={} solution ({:.6e}) should be < k={} solution ({:.6e})",
                k_curr,
                y_curr,
                k_prev,
                y_prev
            );
        }
    }

    #[test]
    fn test_exponential_decay_analytical() {
        let (k, y0, t_end) = (3.0_f64, 2.5_f64, 2.0_f64);
        let ensemble = OdeEnsemble::new(OdeEnsembleConfig {
            t_span: [0.0, t_end],
            rtol: 1e-8,
            atol: 1e-10,
            max_steps: 100_000,
            dispatch: EnsembleDispatch::Sequential,
        });
        let member = EnsembleMember {
            params: vec![k],
            y0: vec![y0],
        };
        let result = ensemble.integrate(&[member], &decay);
        let y_numerical = result.solutions[0][0];
        let y_analytical = y0 * (-k * t_end).exp();
        assert!(
            (y_numerical - y_analytical).abs() < 1e-6,
            "y_numerical = {y_numerical:.8e}, y_analytical = {y_analytical:.8e}"
        );
    }

    #[test]
    fn test_all_converge() {
        let rates = [1.0, 2.0, 3.0, 4.0, 5.0];
        let result =
            OdeEnsemble::new(default_config()).integrate(&unit_decay_members(&rates), &decay);
        if let Some(i) = result.success.iter().position(|&ok| !ok) {
            panic!("member {i} did not converge");
        }
    }

    #[test]
    fn test_n_steps_positive() {
        let rates = [1.0, 2.0, 3.0, 4.0, 5.0];
        let result =
            OdeEnsemble::new(default_config()).integrate(&unit_decay_members(&rates), &decay);
        if let Some(i) = result.n_steps.iter().position(|&ns| ns == 0) {
            panic!("member {i} took 0 steps");
        }
    }

    #[test]
    fn test_2d_system_vanderpol() {
        let ensemble = OdeEnsemble::new(OdeEnsembleConfig {
            t_span: [0.0, 5.0],
            rtol: 1e-6,
            atol: 1e-8,
            max_steps: 500_000,
            dispatch: EnsembleDispatch::Sequential,
        });
        // Van der Pol oscillator with damping mu = p[0]:
        // x' = v, v' = mu * (1 - x^2) * v - x.
        let van_der_pol = |_t: f64, y: &[f64], p: &[f64]| -> Vec<f64> {
            let (x, v, mu) = (y[0], y[1], p[0]);
            vec![v, mu * (1.0 - x * x) * v - x]
        };
        let member = EnsembleMember {
            params: vec![0.1],
            y0: vec![2.0, 0.0],
        };
        let result = ensemble.integrate(&[member], &van_der_pol);
        assert!(result.success[0], "van-der-Pol did not converge");
        assert!(
            result.solutions[0].iter().all(|v| v.is_finite()),
            "van-der-Pol solution is non-finite"
        );
    }

    #[test]
    fn test_simulated_dispatch_matches_sequential() {
        let sequential = OdeEnsembleConfig {
            t_span: [0.0, 1.0],
            rtol: 1e-7,
            atol: 1e-9,
            max_steps: 50_000,
            dispatch: EnsembleDispatch::Sequential,
        };
        let simulated = OdeEnsembleConfig {
            dispatch: EnsembleDispatch::Simulated,
            ..sequential.clone()
        };
        let members = [
            EnsembleMember {
                params: vec![1.0],
                y0: vec![1.0],
            },
            EnsembleMember {
                params: vec![2.0],
                y0: vec![3.0],
            },
        ];
        let res_seq = OdeEnsemble::new(sequential).integrate(&members, &decay);
        let res_sim = OdeEnsemble::new(simulated).integrate(&members, &decay);
        for (i, (seq, sim)) in res_seq
            .solutions
            .iter()
            .zip(&res_sim.solutions)
            .enumerate()
        {
            assert!(
                (seq[0] - sim[0]).abs() < 1e-14,
                "member {i}: sequential={:.6e}, simulated={:.6e}",
                seq[0],
                sim[0]
            );
        }
    }
}