1use super::types::{EnsembleConfig, EnsembleResult};
7use crate::error::{IntegrateError, IntegrateResult};
8
/// Stage coefficients `a[i][j]` of the Dormand–Prince 5(4) Butcher tableau.
///
/// Row `r` holds the weights applied to the stage derivatives `k1..` when
/// building the input of the next stage; entries past the diagonal are zero
/// padding. The last row repeats the first six 5th-order weights of `DP_B5`.
// Kept in tabular form so each row can be compared against the published tableau.
#[rustfmt::skip]
const DP_A: [[f64; 6]; 6] = [
    [1.0 / 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [3.0 / 40.0, 9.0 / 40.0, 0.0, 0.0, 0.0, 0.0],
    [44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0, 0.0, 0.0, 0.0],
    [19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0, 0.0, 0.0],
    [9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0, 0.0],
    [35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0],
];
48
/// 5th-order solution weights of the Dormand–Prince pair, one per stage.
///
/// The seventh weight is zero: the last stage contributes only to the
/// embedded 4th-order solution.
const DP_B5: [f64; 7] = [
    35.0 / 384.0,     // k1
    0.0,              // k2
    500.0 / 1113.0,   // k3
    125.0 / 192.0,    // k4
    -2187.0 / 6784.0, // k5
    11.0 / 84.0,      // k6
    0.0,              // k7
];
59
/// Embedded 4th-order solution weights of the Dormand–Prince pair, one per stage.
///
/// The difference between the solutions built from `DP_B5` and from these
/// weights is what the stepper uses as its local error estimate.
const DP_B4: [f64; 7] = [
    5179.0 / 57600.0,    // k1
    0.0,                 // k2
    7571.0 / 16695.0,    // k3
    393.0 / 640.0,       // k4
    -92097.0 / 339200.0, // k5
    187.0 / 2100.0,      // k6
    1.0 / 40.0,          // k7
];
70
/// Stage nodes `c[i]`: stage `i + 1` is evaluated at `t + c[i] * h`.
const DP_C: [f64; 7] = [
    0.0,        // stage 1
    1.0 / 5.0,  // stage 2
    3.0 / 10.0, // stage 3
    4.0 / 5.0,  // stage 4
    8.0 / 9.0,  // stage 5
    1.0,        // stage 6
    1.0,        // stage 7
];
73
/// Mutable integration state carried from one RK45 step to the next.
struct Rk45State {
    /// Current time.
    t: f64,
    /// Solution vector at `t`.
    y: Vec<f64>,
    /// Derivative used as the first stage of the next step.
    k1: Vec<f64>,
    /// Step size proposed for the next attempt.
    h: f64,
}

impl Rk45State {
    /// Bundles the initial time, state, first-stage derivative and step size.
    fn new(t0: f64, y0: Vec<f64>, k1: Vec<f64>, h_init: f64) -> Self {
        let (t, y, h) = (t0, y0, h_init);
        Rk45State { t, y, k1, h }
    }
}
95
/// Evaluates the right-hand side `f` at `(t, y)` for the given parameter set.
///
/// Returns whatever vector `f` produces; no validation is performed here.
fn stage<F, P>(f: &F, t: f64, y: &[f64], param: &P) -> Vec<f64>
where
    F: Fn(f64, &[f64], &P) -> Vec<f64>,
{
    (*f)(t, y, param)
}
103
/// Returns the element-wise combination `a + scale * b`.
///
/// The result is as long as the shorter of the two inputs.
fn axpy(a: &[f64], scale: f64, b: &[f64]) -> Vec<f64> {
    let shared = a.len().min(b.len());
    let mut out = a[..shared].to_vec();
    for (slot, &bi) in out.iter_mut().zip(b) {
        *slot += scale * bi;
    }
    out
}
111
/// Computes `base + h * Σ weights[j] * ks[j]`, component by component.
///
/// Weights whose magnitude is below `f64::EPSILON` are skipped, so the
/// matching stage vector is never read. Weights and stages are paired in
/// order and the shorter list decides how many terms are used.
///
/// # Panics
///
/// Panics if a stage vector that is actually used is shorter than `base`.
fn rk_sum(base: &[f64], h: f64, weights: &[f64], ks: &[Vec<f64>]) -> Vec<f64> {
    let dim = base.len();
    let mut acc = base.to_vec();
    // Written as a negated `<` so that a NaN weight is still applied rather than skipped.
    let active = weights
        .iter()
        .zip(ks)
        .filter(|(w, _)| !(w.abs() < f64::EPSILON));
    for (&weight, slope) in active {
        for (out, &ki) in acc.iter_mut().zip(&slope[..dim]) {
            *out += h * weight * ki;
        }
    }
    acc
}
126
/// Root-mean-square norm of the error vector `e`, scaled per component.
///
/// Component `i` is divided by `atol + rtol * max(|y[i]|, |y_new[i]|)`
/// before squaring. An empty state yields `0.0`.
///
/// # Panics
///
/// Panics if `y_new` or `e` is shorter than `y`.
fn error_norm(y: &[f64], y_new: &[f64], e: &[f64], rtol: f64, atol: f64) -> f64 {
    let dim = y.len();
    if dim == 0 {
        return 0.0;
    }
    let (y_new, e) = (&y_new[..dim], &e[..dim]);
    let sum_sq = y
        .iter()
        .zip(y_new)
        .zip(e)
        .fold(0.0_f64, |acc, ((&old, &new), &err)| {
            let ratio = err / (atol + rtol * old.abs().max(new.abs()));
            acc + ratio * ratio
        });
    (sum_sq / dim as f64).sqrt()
}
141
142fn rk45_step<F, P>(
146 f: &F,
147 state: &Rk45State,
148 t_end: f64,
149 rtol: f64,
150 atol: f64,
151 param: &P,
152) -> (Vec<f64>, Vec<f64>, f64, bool)
153where
154 F: Fn(f64, &[f64], &P) -> Vec<f64>,
155{
156 let t = state.t;
157 let y = &state.y;
158 let h = state.h.min(t_end - t);
159
160 let k1 = state.k1.clone();
162 let y2 = axpy(y, h * DP_A[0][0], &k1);
163 let k2 = stage(f, t + DP_C[1] * h, &y2, param);
164
165 let y3 = {
166 let mut v = y.to_vec();
167 for i in 0..v.len() {
168 v[i] += h * (DP_A[1][0] * k1[i] + DP_A[1][1] * k2[i]);
169 }
170 v
171 };
172 let k3 = stage(f, t + DP_C[2] * h, &y3, param);
173
174 let y4 = {
175 let mut v = y.to_vec();
176 for i in 0..v.len() {
177 v[i] += h * (DP_A[2][0] * k1[i] + DP_A[2][1] * k2[i] + DP_A[2][2] * k3[i]);
178 }
179 v
180 };
181 let k4 = stage(f, t + DP_C[3] * h, &y4, param);
182
183 let y5 = {
184 let mut v = y.to_vec();
185 for i in 0..v.len() {
186 v[i] += h
187 * (DP_A[3][0] * k1[i]
188 + DP_A[3][1] * k2[i]
189 + DP_A[3][2] * k3[i]
190 + DP_A[3][3] * k4[i]);
191 }
192 v
193 };
194 let k5 = stage(f, t + DP_C[4] * h, &y5, param);
195
196 let y6 = {
197 let mut v = y.to_vec();
198 for i in 0..v.len() {
199 v[i] += h
200 * (DP_A[4][0] * k1[i]
201 + DP_A[4][1] * k2[i]
202 + DP_A[4][2] * k3[i]
203 + DP_A[4][3] * k4[i]
204 + DP_A[4][4] * k5[i]);
205 }
206 v
207 };
208 let k6 = stage(f, t + DP_C[5] * h, &y6, param);
209
210 let y_new = rk_sum(
212 y,
213 h,
214 &DP_B5[..6],
215 &[
216 k1.clone(),
217 k2.clone(),
218 k3.clone(),
219 k4.clone(),
220 k5.clone(),
221 k6.clone(),
222 ],
223 );
224 let k7 = stage(f, t + h, &y_new, param);
225
226 let y4_ord = rk_sum(y, h, &DP_B4, &[k1, k2, k3, k4, k5, k6, k7.clone()]);
228
229 let e: Vec<f64> = y_new
231 .iter()
232 .zip(y4_ord.iter())
233 .map(|(&a, &b)| a - b)
234 .collect();
235 let err = error_norm(y, &y_new, &e, rtol, atol);
236
237 let factor = if err == 0.0 {
239 5.0
240 } else {
241 0.9 * err.powf(-0.2)
242 };
243 let factor = factor.clamp(0.2, 5.0);
244 let h_next = h * factor;
245
246 if err <= 1.0 {
247 (y_new, k7, h_next, true)
249 } else {
250 (y.clone(), k7, h_next, false)
252 }
253}
254
255fn integrate_member<F, P>(
257 f: &F,
258 t0: f64,
259 t_end: f64,
260 y0: Vec<f64>,
261 param: &P,
262 rtol: f64,
263 atol: f64,
264 h_init: f64,
265 max_steps: usize,
266) -> (Vec<Vec<f64>>, Vec<f64>, bool, usize)
267where
268 F: Fn(f64, &[f64], &P) -> Vec<f64>,
269{
270 let n_state = y0.len();
271
272 let h0 = if h_init > 0.0 {
274 h_init
275 } else {
276 ((t_end - t0) * 0.01).max(1e-8).min((t_end - t0) / 10.0)
278 };
279
280 let k1_0 = f(t0, &y0, param);
281 let mut state = Rk45State::new(t0, y0.clone(), k1_0, h0);
282
283 let mut traj = vec![y0];
284 let mut times = vec![t0];
285 let mut n_steps = 0_usize;
286
287 while state.t < t_end - 1e-14 * (t_end - t0) && n_steps < max_steps {
288 let (y_new, k_new, h_next, accepted) = rk45_step(f, &state, t_end, rtol, atol, param);
289
290 if accepted {
291 state.t = (state.t + state.h).min(t_end);
292 state.y = y_new.clone();
293 state.k1 = k_new;
294 state.h = h_next.max(1e-14);
295 n_steps += 1;
296 traj.push(y_new);
297 times.push(state.t);
298 } else {
299 state.h = h_next.max(1e-14);
301 }
302
303 if state.h < 1e-14 * state.t.abs().max(1.0) {
305 break;
306 }
307 }
308
309 let converged = if (state.t - t_end).abs() < 1e-12 * (t_end - t0 + 1.0) {
310 true
311 } else if n_steps == max_steps {
312 false
314 } else {
315 state.t >= t_end - 1e-10 * (t_end - t0)
316 };
317
318 if traj.is_empty() {
320 traj.push(vec![0.0; n_state]);
321 times.push(t0);
322 }
323
324 (traj, times, converged, n_steps)
325}
326
327pub struct OdeEnsembleSolver {
334 pub config: EnsembleConfig,
336}
337
338impl OdeEnsembleSolver {
339 pub fn new(config: EnsembleConfig) -> Self {
341 Self { config }
342 }
343
344 pub fn solve<F, P>(
363 &self,
364 f: F,
365 params: &[P],
366 y0s: &[Vec<f64>],
367 config: &EnsembleConfig,
368 ) -> IntegrateResult<EnsembleResult>
369 where
370 F: Fn(f64, &[f64], &P) -> Vec<f64> + Sync,
371 P: Sync,
372 {
373 if params.len() != y0s.len() {
374 return Err(IntegrateError::InvalidInput(format!(
375 "params.len() ({}) != y0s.len() ({})",
376 params.len(),
377 y0s.len()
378 )));
379 }
380
381 let (t0, t_end) = config.t_span;
382 if t0 >= t_end {
383 return Err(IntegrateError::InvalidInput(
384 "t_span must satisfy t0 < t_end".to_string(),
385 ));
386 }
387
388 let n = params.len();
389 if n == 0 {
390 return Ok(EnsembleResult {
391 trajectories: vec![],
392 times: vec![],
393 converged: vec![],
394 n_steps: vec![],
395 });
396 }
397
398 let rtol = config.rtol;
399 let atol = config.atol;
400 let h_init = config.h_init;
401 let max_steps = config.max_steps;
402 let n_threads = config.n_threads.max(1).min(n);
403
404 let mut trajectories: Vec<Vec<Vec<f64>>> = vec![Vec::new(); n];
406 let mut times_out: Vec<Vec<f64>> = vec![Vec::new(); n];
407 let mut converged: Vec<bool> = vec![false; n];
408 let mut n_steps_out: Vec<usize> = vec![0; n];
409
410 let chunk_size = n.div_ceil(n_threads);
413
414 let results: Vec<std::sync::Mutex<Option<(Vec<Vec<f64>>, Vec<f64>, bool, usize)>>> =
416 (0..n).map(|_| std::sync::Mutex::new(None)).collect();
417
418 std::thread::scope(|scope| {
420 let results_ref = &results;
421 let f_ref = &f;
422
423 for tid in 0..n_threads {
425 let start = tid * chunk_size;
426 if start >= n {
427 break;
428 }
429 let end = (start + chunk_size).min(n);
430 let params_slice = ¶ms[start..end];
431 let y0s_slice = &y0s[start..end];
432
433 scope.spawn(move || {
434 for (local_idx, (param, y0)) in
435 params_slice.iter().zip(y0s_slice.iter()).enumerate()
436 {
437 let global_idx = start + local_idx;
438 let (traj, ts, conv, ns) = integrate_member(
439 f_ref,
440 t0,
441 t_end,
442 y0.clone(),
443 param,
444 rtol,
445 atol,
446 h_init,
447 max_steps,
448 );
449 if let Ok(mut slot) = results_ref[global_idx].lock() {
451 *slot = Some((traj, ts, conv, ns));
452 }
453 }
454 });
455 }
456 });
458
459 for (i, slot) in results.into_iter().enumerate() {
461 if let Ok(Some((traj, ts, conv, ns))) = slot.into_inner() {
462 trajectories[i] = traj;
463 times_out[i] = ts;
464 converged[i] = conv;
465 n_steps_out[i] = ns;
466 }
467 }
468
469 Ok(EnsembleResult {
470 trajectories,
471 times: times_out,
472 converged,
473 n_steps: n_steps_out,
474 })
475 }
476}
477
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() <= tol
    }

    /// y' = -y with y(0) = 1 must end at exp(-1) for every member.
    #[test]
    fn test_ensemble_exponential_decay() {
        let cfg = EnsembleConfig {
            n_ensemble: 10,
            n_threads: 2,
            rtol: 1e-8,
            atol: 1e-10,
            t_span: (0.0, 1.0),
            max_steps: 10_000,
            h_init: 0.0,
        };

        let solver = OdeEnsembleSolver::new(cfg.clone());
        let params: Vec<f64> = vec![1.0; 10];
        let y0s: Vec<Vec<f64>> = vec![vec![1.0]; 10];

        let result = solver
            .solve(|_t, y, &p| vec![-p * y[0]], &params, &y0s, &cfg)
            .expect("solve failed");

        assert_eq!(result.trajectories.len(), 10);
        for (i, (traj, ts)) in result
            .trajectories
            .iter()
            .zip(result.times.iter())
            .enumerate()
        {
            let t_final = *ts.last().expect("no times");
            // Without this check a solver that never leaves t0 would pass,
            // because y(0) = exp(-0) = 1.
            assert!(
                approx_eq(t_final, 1.0, 1e-12),
                "member {i}: integration stopped at t={t_final}"
            );
            let y_final = traj.last().expect("no trajectory")[0];
            let expected = (-t_final).exp();
            assert!(
                approx_eq(y_final, expected, 1e-5),
                "member {i}: y(t={t_final:.4}) = {y_final:.8}, expected {expected:.8}"
            );
        }
    }

    /// Every member of a well-behaved problem must be flagged as converged.
    #[test]
    fn test_ensemble_all_converged() {
        let cfg = EnsembleConfig {
            n_ensemble: 10,
            n_threads: 4,
            rtol: 1e-8,
            atol: 1e-10,
            t_span: (0.0, 2.0),
            max_steps: 50_000,
            h_init: 0.0,
        };
        let solver = OdeEnsembleSolver::new(cfg.clone());
        let params: Vec<f64> = vec![1.0; 10];
        let y0s: Vec<Vec<f64>> = vec![vec![1.0]; 10];

        let result = solver
            .solve(|_t, y, &p| vec![-p * y[0]], &params, &y0s, &cfg)
            .expect("solve failed");

        for (i, &conv) in result.converged.iter().enumerate() {
            assert!(conv, "member {i} did not converge");
        }
    }

    /// Different initial conditions must produce different final states.
    #[test]
    fn test_ensemble_different_ics() {
        let cfg = EnsembleConfig {
            n_ensemble: 5,
            n_threads: 2,
            rtol: 1e-8,
            atol: 1e-10,
            t_span: (0.0, 1.0),
            max_steps: 10_000,
            h_init: 0.0,
        };
        let solver = OdeEnsembleSolver::new(cfg.clone());
        let params: Vec<f64> = vec![1.0; 5];
        let y0s: Vec<Vec<f64>> = (1..=5).map(|i| vec![i as f64]).collect();

        let result = solver
            .solve(|_t, y, &p| vec![-p * y[0]], &params, &y0s, &cfg)
            .expect("solve failed");

        let finals: Vec<f64> = result
            .trajectories
            .iter()
            .map(|traj| traj.last().expect("no traj")[0])
            .collect();

        for i in 1..finals.len() {
            assert!(
                (finals[i] - finals[0]).abs() > 0.1,
                "members 0 and {i} should differ: {} vs {}",
                finals[0],
                finals[i]
            );
        }
    }

    /// The default configuration must be usable as-is.
    #[test]
    fn test_ensemble_config_default() {
        let cfg = EnsembleConfig::default();
        assert!(cfg.n_ensemble > 0);
        assert!(cfg.n_threads > 0);
        assert!(cfg.rtol > 0.0);
        assert!(cfg.atol > 0.0);
        let (t0, t1) = cfg.t_span;
        assert!(t0 < t1);
    }

    /// The thread count must not influence the numerical result.
    #[test]
    fn test_ensemble_parallel_same_as_serial() {
        let mk_cfg = |n_threads: usize| EnsembleConfig {
            n_ensemble: 4,
            n_threads,
            rtol: 1e-8,
            atol: 1e-10,
            t_span: (0.0, 1.0),
            max_steps: 10_000,
            h_init: 0.0,
        };

        let params: Vec<f64> = vec![0.5, 1.0, 1.5, 2.0];
        let y0s: Vec<Vec<f64>> = vec![vec![1.0]; 4];

        let f = |_t: f64, y: &[f64], &p: &f64| vec![-p * y[0]];

        let cfg1 = mk_cfg(1);
        let solver1 = OdeEnsembleSolver::new(cfg1.clone());
        let res1 = solver1
            .solve(f, &params, &y0s, &cfg1)
            .expect("solve 1 failed");

        let cfg2 = mk_cfg(2);
        let solver2 = OdeEnsembleSolver::new(cfg2.clone());
        let res2 = solver2
            .solve(f, &params, &y0s, &cfg2)
            .expect("solve 2 failed");

        for i in 0..4 {
            let y1 = res1.trajectories[i].last().expect("no traj1")[0];
            let y2 = res2.trajectories[i].last().expect("no traj2")[0];
            assert!(
                approx_eq(y1, y2, 1e-10),
                "member {i}: thread-1={y1}, thread-2={y2}"
            );
        }
    }

    /// With identical members the ensemble mean must equal any single member.
    #[test]
    fn test_ensemble_mean_trajectory() {
        let cfg = EnsembleConfig {
            n_ensemble: 5,
            n_threads: 2,
            rtol: 1e-8,
            atol: 1e-10,
            t_span: (0.0, 1.0),
            max_steps: 10_000,
            h_init: 1e-3,
        };
        let solver = OdeEnsembleSolver::new(cfg.clone());
        let params: Vec<f64> = vec![1.0; 5];
        let y0s: Vec<Vec<f64>> = vec![vec![1.0]; 5];

        let result = solver
            .solve(|_t, y, &p| vec![-p * y[0]], &params, &y0s, &cfg)
            .expect("solve failed");

        let mean = result.mean_trajectory().expect("mean failed");
        let single = &result.trajectories[0];
        // Compare only the common prefix in case the lengths differ.
        let min_len = mean.len().min(single.len());
        for k in 0..min_len {
            assert!(
                approx_eq(mean[k][0], single[k][0], 1e-10),
                "mean[{k}]={}, single[{k}]={}",
                mean[k][0],
                single[k][0]
            );
        }
    }
}