1use std::fmt;
41
/// Right-hand side of an ODE system `dy/dt = f(t, y, params)`.
///
/// Implementors supply `call`; a central-finite-difference
/// vector-Jacobian product is provided as a default so the adjoint
/// method works without an analytic `vjp`.
pub trait OdeFunc: Send + Sync {
    /// Evaluate the derivative `f(t, y, params)`.
    fn call(&self, t: f64, y: &[f64], params: &[f64]) -> Vec<f64>;

    /// Vector-Jacobian products of `grad_output` against `df/dy`,
    /// `df/dt`, and `df/dparams`, approximated with central finite
    /// differences (step `1e-6`).
    ///
    /// Returns `(grad_y, grad_t, grad_params)`.
    fn vjp(
        &self,
        t: f64,
        y: &[f64],
        params: &[f64],
        grad_output: &[f64],
    ) -> (Vec<f64>, f64, Vec<f64>) {
        let eps = 1e-6_f64;

        // grad_output · (f_plus - f_minus) / (2 eps); indexes f_* so a
        // length mismatch panics just as the original loops would.
        let directional = |f_plus: &[f64], f_minus: &[f64]| -> f64 {
            grad_output
                .iter()
                .enumerate()
                .map(|(k, go)| go * (f_plus[k] - f_minus[k]) / (2.0 * eps))
                .sum()
        };

        // d/dy: perturb one state component at a time.
        let grad_y: Vec<f64> = (0..y.len())
            .map(|i| {
                let mut y_hi = y.to_vec();
                let mut y_lo = y.to_vec();
                y_hi[i] += eps;
                y_lo[i] -= eps;
                directional(&self.call(t, &y_hi, params), &self.call(t, &y_lo, params))
            })
            .collect();

        // d/dt: perturb the time argument.
        let grad_t = directional(
            &self.call(t + eps, y, params),
            &self.call(t - eps, y, params),
        );

        // d/dparams: perturb one parameter at a time.
        let grad_params: Vec<f64> = (0..params.len())
            .map(|j| {
                let mut p_hi = params.to_vec();
                let mut p_lo = params.to_vec();
                p_hi[j] += eps;
                p_lo[j] -= eps;
                directional(&self.call(t, y, &p_hi), &self.call(t, y, &p_lo))
            })
            .collect();

        (grad_y, grad_t, grad_params)
    }
}
110
/// Trajectory produced by an ODE solve.
#[derive(Debug, Clone)]
pub struct OdeSolution {
    /// Time points at which states were recorded; starts at `t0`.
    pub times: Vec<f64>,
    /// State vector at each entry of `times` (same length as `times`).
    pub states: Vec<Vec<f64>>,
    /// Total number of right-hand-side evaluations performed.
    pub nfev: usize,
}
125
/// Result of an adaptive-step solve: the trajectory plus step-control
/// statistics.
#[derive(Debug, Clone)]
pub struct AdaptiveSolution {
    /// The integrated trajectory.
    pub solution: OdeSolution,
    /// Number of trial steps rejected by the error controller.
    pub rejected_steps: usize,
    /// Magnitude of the last step size used.
    pub final_step_size: f64,
}
136
/// Gradients computed by the adjoint method.
#[derive(Debug, Clone)]
pub struct AdjointResult {
    /// State at the end of the forward integration.
    pub final_state: Vec<f64>,
    /// Gradient of the loss with respect to the initial state `y0`.
    pub grad_y0: Vec<f64>,
    /// Gradient of the loss with respect to the parameters.
    pub grad_params: Vec<f64>,
    /// Combined forward + backward right-hand-side evaluation count.
    pub total_nfev: usize,
}
149
/// Tunable parameters for the adaptive solver.
#[derive(Debug, Clone)]
pub struct OdeSolverConfig {
    /// Relative error tolerance per step.
    pub rtol: f64,
    /// Absolute error tolerance per step.
    pub atol: f64,
    /// Maximum number of accepted steps before aborting.
    pub max_steps: usize,
    /// Smallest permitted step magnitude; falling below aborts the solve.
    pub min_step: f64,
    /// Largest permitted step magnitude.
    pub max_step: f64,
    /// When true, every accepted step is stored in the solution; when
    /// false only the endpoints are kept.
    pub dense_output: bool,
}
171
impl Default for OdeSolverConfig {
    /// Moderate tolerances with dense output enabled and an effectively
    /// unbounded maximum step.
    fn default() -> Self {
        Self {
            rtol: 1e-4,
            atol: 1e-6,
            max_steps: 1000,
            min_step: 1e-12,
            max_step: f64::INFINITY,
            dense_output: true,
        }
    }
}
184
185impl OdeSolverConfig {
186 pub fn new() -> Self {
188 Self::default()
189 }
190
191 pub fn rtol(mut self, v: f64) -> Self {
193 self.rtol = v;
194 self
195 }
196
197 pub fn atol(mut self, v: f64) -> Self {
199 self.atol = v;
200 self
201 }
202
203 pub fn max_steps(mut self, n: usize) -> Self {
205 self.max_steps = n;
206 self
207 }
208
209 pub fn no_dense_output(mut self) -> Self {
211 self.dense_output = false;
212 self
213 }
214}
215
/// Failure modes of the ODE solvers.
#[derive(Debug)]
pub enum OdeError {
    /// The accepted-step count reached `OdeSolverConfig::max_steps`.
    MaxStepsExceeded,
    /// The step size fell below `OdeSolverConfig::min_step`.
    StepTooSmall,
    /// A NaN or infinity appeared in the state vector.
    DivergentSolution,
    /// Caller-supplied arguments were malformed; the payload describes why.
    InvalidInput(String),
}
232
233impl fmt::Display for OdeError {
234 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235 match self {
236 OdeError::MaxStepsExceeded => write!(
237 f,
238 "ODE solver exceeded the maximum number of steps; \
239 consider relaxing tolerances or increasing max_steps"
240 ),
241 OdeError::StepTooSmall => write!(
242 f,
243 "ODE solver step size fell below the minimum threshold; \
244 the problem may be too stiff for this explicit solver"
245 ),
246 OdeError::DivergentSolution => write!(
247 f,
248 "ODE solution diverged (NaN or Inf encountered in state vector)"
249 ),
250 OdeError::InvalidInput(msg) => {
251 write!(f, "ODE solver received invalid input: {msg}")
252 }
253 }
254 }
255}
256
// Marker impl: the Debug and Display impls above satisfy the std Error contract.
impl std::error::Error for OdeError {}
258
/// Element-wise sum of two slices (truncates to the shorter length).
#[inline]
#[allow(dead_code)]
fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (x, y) in a.iter().zip(b) {
        out.push(x + y);
    }
    out
}
268
/// Multiply every element of `v` by the scalar `s`.
#[inline]
#[allow(dead_code)]
fn vec_scale(v: &[f64], s: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(v.len());
    for x in v {
        out.push(x * s);
    }
    out
}
274
/// Non-mutating BLAS-style axpy: returns `y + alpha * x` element-wise
/// (truncates to the shorter length).
#[inline]
fn vec_axpy(y: &[f64], alpha: f64, x: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(y.len().min(x.len()));
    for (yi, xi) in y.iter().zip(x) {
        out.push(yi + alpha * xi);
    }
    out
}
282
/// Weighted RMS norm of a local error estimate.
///
/// Each component is divided by `atol + rtol * max(|y_i|, |y_new_i|)`
/// before the root-mean-square is taken; a result <= 1 means the step
/// is within tolerance. An empty estimate has norm 0.
fn error_norm(err: &[f64], y: &[f64], y_new: &[f64], rtol: f64, atol: f64) -> f64 {
    if err.is_empty() {
        return 0.0;
    }
    let mut acc = 0.0_f64;
    for ((e, yi), yn) in err.iter().zip(y).zip(y_new) {
        let scale = atol + rtol * yi.abs().max(yn.abs());
        let ratio = e / scale;
        acc += ratio * ratio;
    }
    (acc / err.len() as f64).sqrt()
}
303
/// True if any component of `v` is NaN or infinite.
fn has_diverged(v: &[f64]) -> bool {
    !v.iter().all(|x| x.is_finite())
}
308
309pub fn rk4_solve(
326 func: &dyn OdeFunc,
327 t0: f64,
328 t1: f64,
329 y0: &[f64],
330 params: &[f64],
331 num_steps: usize,
332) -> OdeSolution {
333 let steps = num_steps.max(1);
334 let h = (t1 - t0) / steps as f64;
335
336 let mut times = Vec::with_capacity(steps + 1);
337 let mut states = Vec::with_capacity(steps + 1);
338 let mut nfev = 0usize;
339
340 times.push(t0);
341 states.push(y0.to_vec());
342
343 let mut t = t0;
344 let mut y = y0.to_vec();
345
346 for _ in 0..steps {
347 let k1 = func.call(t, &y, params);
349 nfev += 1;
350 let y2 = vec_axpy(&y, h * 0.5, &k1);
352 let k2 = func.call(t + h * 0.5, &y2, params);
353 nfev += 1;
354 let y3 = vec_axpy(&y, h * 0.5, &k2);
356 let k3 = func.call(t + h * 0.5, &y3, params);
357 nfev += 1;
358 let y4 = vec_axpy(&y, h, &k3);
360 let k4 = func.call(t + h, &y4, params);
361 nfev += 1;
362
363 y = y
365 .iter()
366 .zip(k1.iter())
367 .zip(k2.iter())
368 .zip(k3.iter())
369 .zip(k4.iter())
370 .map(|((((yi, k1i), k2i), k3i), k4i)| {
371 yi + h / 6.0 * (k1i + 2.0 * k2i + 2.0 * k3i + k4i)
372 })
373 .collect();
374 t += h;
375
376 times.push(t);
377 states.push(y.clone());
378 }
379
380 OdeSolution {
381 times,
382 states,
383 nfev,
384 }
385}
386
// Butcher-tableau stage coefficients a_ij of the Dormand-Prince 5(4)
// pair. The seventh row (A71..A76) is also used below to form the
// 5th-order solution itself, so the last stage evaluation can be
// reused as the first stage of the next step.
const DOPRI5_A21: f64 = 1.0 / 5.0;
const DOPRI5_A31: f64 = 3.0 / 40.0;
const DOPRI5_A32: f64 = 9.0 / 40.0;
const DOPRI5_A41: f64 = 44.0 / 45.0;
const DOPRI5_A42: f64 = -56.0 / 15.0;
const DOPRI5_A43: f64 = 32.0 / 9.0;
const DOPRI5_A51: f64 = 19372.0 / 6561.0;
const DOPRI5_A52: f64 = -25360.0 / 2187.0;
const DOPRI5_A53: f64 = 64448.0 / 6561.0;
const DOPRI5_A54: f64 = -212.0 / 729.0;
const DOPRI5_A61: f64 = 9017.0 / 3168.0;
const DOPRI5_A62: f64 = -355.0 / 33.0;
const DOPRI5_A63: f64 = 46732.0 / 5247.0;
const DOPRI5_A64: f64 = 49.0 / 176.0;
const DOPRI5_A65: f64 = -5103.0 / 18656.0;
const DOPRI5_A71: f64 = 35.0 / 384.0;
const DOPRI5_A73: f64 = 500.0 / 1113.0;
const DOPRI5_A74: f64 = 125.0 / 192.0;
const DOPRI5_A75: f64 = -2187.0 / 6784.0;
const DOPRI5_A76: f64 = 11.0 / 84.0;

// Error-estimate weights: differences between the 5th- and 4th-order
// solution weights, applied to the stage derivatives k1..k7 below.
const DOPRI5_E1: f64 = 71.0 / 57600.0;
const DOPRI5_E3: f64 = -71.0 / 16695.0;
const DOPRI5_E4: f64 = 71.0 / 1920.0;
const DOPRI5_E5: f64 = -17253.0 / 339200.0;
const DOPRI5_E6: f64 = 22.0 / 525.0;
const DOPRI5_E7: f64 = -1.0 / 40.0;

// Step-size controller: safety factor, growth/shrink clamps, and the
// method order used in the error exponent.
const DOPRI5_SAFETY: f64 = 0.9;
const DOPRI5_MIN_FACTOR: f64 = 0.2;
const DOPRI5_MAX_FACTOR: f64 = 10.0;
const DOPRI5_ORDER: f64 = 5.0;
452
/// Adaptive Dormand-Prince 5(4) integration of
/// `dy/dt = func(t, y, params)` from `t0` to `t1` (either direction).
///
/// Returns the trajectory plus step-control statistics, or an
/// `OdeError` when the step-count/step-size limits are hit or the
/// state diverges.
pub fn dopri5_solve(
    func: &dyn OdeFunc,
    t0: f64,
    t1: f64,
    y0: &[f64],
    params: &[f64],
    config: &OdeSolverConfig,
) -> Result<AdaptiveSolution, OdeError> {
    // Degenerate span: return the initial state untouched.
    if t0 == t1 {
        return Ok(AdaptiveSolution {
            solution: OdeSolution {
                times: vec![t0],
                states: vec![y0.to_vec()],
                nfev: 0,
            },
            rejected_steps: 0,
            final_step_size: 0.0,
        });
    }

    if y0.is_empty() {
        return Err(OdeError::InvalidInput("state vector is empty".into()));
    }

    // Integration direction: sign is +1 forward in time, -1 backward.
    let forward = t1 > t0;
    let sign = if forward { 1.0_f64 } else { -1.0_f64 };
    let span = (t1 - t0).abs();

    // Initial step-size heuristic from the RMS magnitudes of the state
    // and its derivative at t0; falls back to 1e-6 near zero.
    let f0 = func.call(t0, y0, params);
    let d0 = (y0.iter().map(|x| x * x).sum::<f64>() / y0.len() as f64).sqrt();
    let d1 = (f0.iter().map(|x| x * x).sum::<f64>() / f0.len() as f64).sqrt();
    let h0 = if d0 < 1e-5 || d1 < 1e-5 {
        1e-6
    } else {
        0.01 * d0 / d1
    };
    let mut h = sign * h0.min(span).min(config.max_step);

    let mut t = t0;
    let mut y = y0.to_vec();
    // First-same-as-last: k1 of each step reuses k7 of the previous
    // accepted step, so only six fresh evaluations are needed per step.
    let mut k1 = f0;
    let mut nfev = 1usize;
    let mut times = vec![t0];
    let mut states = vec![y0.to_vec()];

    let mut rejected_steps = 0usize;
    let mut steps = 0usize;

    // Main adaptive loop: run until t reaches t1 to within roundoff.
    while (sign * (t1 - t)).abs() > f64::EPSILON * span.max(1.0) {
        if steps >= config.max_steps {
            return Err(OdeError::MaxStepsExceeded);
        }

        // Clamp the step so the final one lands exactly on t1.
        if (t + h - t1) * sign > 0.0 {
            h = t1 - t;
        }

        let h_abs = h.abs();
        if h_abs < config.min_step {
            return Err(OdeError::StepTooSmall);
        }

        // Stages 2..6 of the Dormand-Prince tableau.
        let y2 = vec_axpy(&y, DOPRI5_A21 * h, &k1);
        let k2 = func.call(t + h / 5.0, &y2, params);
        nfev += 1;

        let y3: Vec<f64> = y
            .iter()
            .zip(k1.iter())
            .zip(k2.iter())
            .map(|((yi, k1i), k2i)| yi + h * (DOPRI5_A31 * k1i + DOPRI5_A32 * k2i))
            .collect();
        let k3 = func.call(t + h * 3.0 / 10.0, &y3, params);
        nfev += 1;

        let y4: Vec<f64> = y
            .iter()
            .zip(k1.iter())
            .zip(k2.iter())
            .zip(k3.iter())
            .map(|(((yi, k1i), k2i), k3i)| {
                yi + h * (DOPRI5_A41 * k1i + DOPRI5_A42 * k2i + DOPRI5_A43 * k3i)
            })
            .collect();
        let k4 = func.call(t + h * 4.0 / 5.0, &y4, params);
        nfev += 1;

        let y5: Vec<f64> = y
            .iter()
            .zip(k1.iter())
            .zip(k2.iter())
            .zip(k3.iter())
            .zip(k4.iter())
            .map(|((((yi, k1i), k2i), k3i), k4i)| {
                yi + h * (DOPRI5_A51 * k1i + DOPRI5_A52 * k2i + DOPRI5_A53 * k3i + DOPRI5_A54 * k4i)
            })
            .collect();
        let k5 = func.call(t + h * 8.0 / 9.0, &y5, params);
        nfev += 1;

        let y6: Vec<f64> = y
            .iter()
            .zip(k1.iter())
            .zip(k2.iter())
            .zip(k3.iter())
            .zip(k4.iter())
            .zip(k5.iter())
            .map(|(((((yi, k1i), k2i), k3i), k4i), k5i)| {
                yi + h
                    * (DOPRI5_A61 * k1i
                        + DOPRI5_A62 * k2i
                        + DOPRI5_A63 * k3i
                        + DOPRI5_A64 * k4i
                        + DOPRI5_A65 * k5i)
            })
            .collect();
        let k6 = func.call(t + h, &y6, params);
        nfev += 1;

        // 5th-order candidate solution (the seventh tableau row; k2
        // has zero weight here).
        let y_new: Vec<f64> = y
            .iter()
            .zip(k1.iter())
            .zip(k3.iter())
            .zip(k4.iter())
            .zip(k5.iter())
            .zip(k6.iter())
            .map(|(((((yi, k1i), k3i), k4i), k5i), k6i)| {
                yi + h
                    * (DOPRI5_A71 * k1i
                        + DOPRI5_A73 * k3i
                        + DOPRI5_A74 * k4i
                        + DOPRI5_A75 * k5i
                        + DOPRI5_A76 * k6i)
            })
            .collect();

        if has_diverged(&y_new) {
            return Err(OdeError::DivergentSolution);
        }

        // Seventh stage, evaluated at the candidate endpoint; reused as
        // k1 if the step is accepted.
        let k7 = func.call(t + h, &y_new, params);
        nfev += 1;

        // Embedded 4th/5th-order error estimate.
        let err: Vec<f64> = k1
            .iter()
            .zip(k3.iter())
            .zip(k4.iter())
            .zip(k5.iter())
            .zip(k6.iter())
            .zip(k7.iter())
            .map(|(((((e1, e3), e4), e5), e6), e7)| {
                h * (DOPRI5_E1 * e1
                    + DOPRI5_E3 * e3
                    + DOPRI5_E4 * e4
                    + DOPRI5_E5 * e5
                    + DOPRI5_E6 * e6
                    + DOPRI5_E7 * e7)
            })
            .collect();

        let error_norm_val = error_norm(&err, &y, &y_new, config.rtol, config.atol);

        if error_norm_val <= 1.0 {
            // Accepted: advance, reuse k7 as the next k1, and record.
            t += h;
            y = y_new;
            k1 = k7;
            if config.dense_output {
                times.push(t);
                states.push(y.clone());
            }
            steps += 1;

            // Grow the step per the standard controller; a zero error
            // estimate allows the maximum growth factor.
            let factor = if error_norm_val == 0.0 {
                DOPRI5_MAX_FACTOR
            } else {
                (DOPRI5_SAFETY * error_norm_val.powf(-1.0 / DOPRI5_ORDER))
                    .clamp(DOPRI5_MIN_FACTOR, DOPRI5_MAX_FACTOR)
            };
            h *= factor;
            h = h.abs().min(config.max_step) * sign;
        } else {
            // Rejected: shrink the step and retry from the same t.
            rejected_steps += 1;
            let factor = (DOPRI5_SAFETY * error_norm_val.powf(-1.0 / DOPRI5_ORDER))
                .clamp(DOPRI5_MIN_FACTOR, 1.0);
            h *= factor;
        }
    }

    // Make sure the endpoint is recorded exactly once.
    if !config.dense_output || times.last().map(|&last| last != t).unwrap_or(true) {
        times.push(t);
        states.push(y.clone());
    }

    Ok(AdaptiveSolution {
        solution: OdeSolution {
            times,
            states,
            nfev,
        },
        rejected_steps,
        final_step_size: h.abs(),
    })
}
684
/// Pairs a dynamics function with a fixed integration window
/// `[t0, t1]` and a solver configuration, exposing forward and
/// adjoint (gradient) passes.
pub struct NeuralOde<F: OdeFunc> {
    // Dynamics: dy/dt = func(t, y, params).
    func: F,
    // Integration start time.
    t0: f64,
    // Integration end time.
    t1: f64,
    // Adaptive solver settings used by forward and adjoint passes.
    config: OdeSolverConfig,
}
700
701impl<F: OdeFunc> NeuralOde<F> {
702 pub fn new(func: F, t0: f64, t1: f64) -> Self {
704 Self {
705 func,
706 t0,
707 t1,
708 config: OdeSolverConfig::default(),
709 }
710 }
711
712 pub fn with_config(func: F, t0: f64, t1: f64, config: OdeSolverConfig) -> Self {
714 Self {
715 func,
716 t0,
717 t1,
718 config,
719 }
720 }
721
722 pub fn forward(&self, y0: &[f64], params: &[f64]) -> Result<OdeSolution, OdeError> {
726 if y0.is_empty() {
727 return Err(OdeError::InvalidInput("initial state is empty".into()));
728 }
729 let adaptive = dopri5_solve(&self.func, self.t0, self.t1, y0, params, &self.config)?;
730 Ok(adaptive.solution)
731 }
732
733 pub fn adjoint(
743 &self,
744 y0: &[f64],
745 params: &[f64],
746 grad_output: &[f64],
747 ) -> Result<AdjointResult, OdeError> {
748 if y0.len() != grad_output.len() {
749 return Err(OdeError::InvalidInput(format!(
750 "grad_output length {} does not match state dimension {}",
751 grad_output.len(),
752 y0.len()
753 )));
754 }
755
756 let fwd_config = OdeSolverConfig {
758 dense_output: true,
759 ..self.config.clone()
760 };
761 let adaptive = dopri5_solve(&self.func, self.t0, self.t1, y0, params, &fwd_config)?;
762 let fwd_nfev = adaptive.solution.nfev;
763
764 let adj_result = adjoint_backward(
765 &self.func,
766 &adaptive.solution,
767 params,
768 grad_output,
769 &self.config,
770 );
771
772 Ok(AdjointResult {
773 total_nfev: fwd_nfev + adj_result.total_nfev,
774 ..adj_result
775 })
776 }
777}
778
779fn adjoint_backward(
799 func: &dyn OdeFunc,
800 solution: &OdeSolution,
801 params: &[f64],
802 grad_output: &[f64],
803 _config: &OdeSolverConfig,
804) -> AdjointResult {
805 let n_state = grad_output.len();
806 let n_params = params.len();
807
808 let final_state = solution
809 .states
810 .last()
811 .cloned()
812 .unwrap_or_else(|| grad_output.to_vec());
813
814 let mut a = grad_output.to_vec();
816 let mut grad_params = vec![0.0_f64; n_params];
817 let mut total_nfev = 0usize;
818
819 let adj_steps_per_interval = 4usize;
821
822 let n_intervals = solution.times.len().saturating_sub(1);
824 for interval_idx in (0..n_intervals).rev() {
825 let t_start = solution.times[interval_idx + 1];
826 let t_end = solution.times[interval_idx];
827 let y_start = &solution.states[interval_idx + 1];
828 let y_end = &solution.states[interval_idx];
829
830 let h = (t_end - t_start) / adj_steps_per_interval as f64;
832
833 let mut t_cur = t_start;
834
835 for step_idx in 0..adj_steps_per_interval {
836 let alpha = step_idx as f64 / adj_steps_per_interval as f64;
839 let y_interp: Vec<f64> = y_start
840 .iter()
841 .zip(y_end.iter())
842 .map(|(ys, ye)| ys + alpha * (ye - ys))
843 .collect();
844
845 let aug_rhs =
847 |t_local: f64, a_local: &[f64], y_local: &[f64]| -> (Vec<f64>, Vec<f64>) {
848 let (da_dy, _da_dt, da_dp) = func.vjp(t_local, y_local, params, a_local);
849 let a_dot: Vec<f64> = da_dy.iter().map(|x| -x).collect();
851 let gp_dot: Vec<f64> = da_dp.iter().map(|x| -x).collect();
853 (a_dot, gp_dot)
854 };
855
856 let (k1_a, k1_gp) = aug_rhs(t_cur, &a, &y_interp);
858 total_nfev += 1;
859
860 let a2 = vec_axpy(&a, h * 0.5, &k1_a);
861 let alpha2 = (step_idx as f64 + 0.5) / adj_steps_per_interval as f64;
862 let y2: Vec<f64> = y_start
863 .iter()
864 .zip(y_end.iter())
865 .map(|(ys, ye)| ys + alpha2 * (ye - ys))
866 .collect();
867 let (k2_a, k2_gp) = aug_rhs(t_cur + h * 0.5, &a2, &y2);
868 total_nfev += 1;
869
870 let a3 = vec_axpy(&a, h * 0.5, &k2_a);
871 let (k3_a, k3_gp) = aug_rhs(t_cur + h * 0.5, &a3, &y2);
872 total_nfev += 1;
873
874 let a4 = vec_axpy(&a, h, &k3_a);
875 let alpha_end = (step_idx + 1) as f64 / adj_steps_per_interval as f64;
876 let y4: Vec<f64> = y_start
877 .iter()
878 .zip(y_end.iter())
879 .map(|(ys, ye)| ys + alpha_end * (ye - ys))
880 .collect();
881 let (k4_a, k4_gp) = aug_rhs(t_cur + h, &a4, &y4);
882 total_nfev += 1;
883
884 a = a
886 .iter()
887 .zip(k1_a.iter())
888 .zip(k2_a.iter())
889 .zip(k3_a.iter())
890 .zip(k4_a.iter())
891 .map(|((((ai, k1i), k2i), k3i), k4i)| {
892 ai + h / 6.0 * (k1i + 2.0 * k2i + 2.0 * k3i + k4i)
893 })
894 .collect();
895
896 grad_params = grad_params
897 .iter()
898 .zip(k1_gp.iter())
899 .zip(k2_gp.iter())
900 .zip(k3_gp.iter())
901 .zip(k4_gp.iter())
902 .map(|((((gp, k1i), k2i), k3i), k4i)| {
903 gp + h / 6.0 * (k1i + 2.0 * k2i + 2.0 * k3i + k4i)
904 })
905 .collect();
906
907 t_cur += h;
908 }
909
910 let _ = n_state; }
912
913 AdjointResult {
914 final_state,
915 grad_y0: a,
916 grad_params,
917 total_nfev,
918 }
919}
920
#[cfg(test)]
mod tests {
    use super::*;

    // dy/dt = 0 — the solution must stay constant.
    struct ConstantOde;
    impl OdeFunc for ConstantOde {
        fn call(&self, _t: f64, _y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![0.0]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            _grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            (vec![0.0], 0.0, vec![])
        }
    }

    // dy/dt = y — exact solution y0 * e^t.
    struct ExponentialGrowthOde;
    impl OdeFunc for ExponentialGrowthOde {
        fn call(&self, _t: f64, y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            (grad.to_vec(), 0.0, vec![])
        }
    }

    // dy/dt = -y — exact solution y0 * e^(-t).
    struct ExponentialDecayOde;
    impl OdeFunc for ExponentialDecayOde {
        fn call(&self, _t: f64, y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![-y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            (grad.iter().map(|g| -g).collect(), 0.0, vec![])
        }
    }

    // Harmonic oscillator: dx/dt = v, dv/dt = -x (period 2*pi).
    struct OscillatorOde;
    impl OdeFunc for OscillatorOde {
        fn call(&self, _t: f64, y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![y[1], -y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            // VJP against the transposed Jacobian [[0, -1], [1, 0]].
            let ga = grad[1];
            let gb = grad[0];
            (vec![-ga, gb], 0.0, vec![])
        }
    }

    // dy/dt = p * y with a single trainable parameter p.
    struct LinearParamOde;
    impl OdeFunc for LinearParamOde {
        fn call(&self, _t: f64, y: &[f64], params: &[f64]) -> Vec<f64> {
            vec![params[0] * y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            y: &[f64],
            params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            let grad_y = vec![grad[0] * params[0]];
            let grad_p = vec![grad[0] * y[0]];
            (grad_y, 0.0, grad_p)
        }
    }

    // Stiff linear system dy/dt = -1000 y, hard for explicit methods.
    struct StiffOde;
    impl OdeFunc for StiffOde {
        fn call(&self, _t: f64, y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![-1000.0 * y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            (grad.iter().map(|g| -1000.0 * g).collect(), 0.0, vec![])
        }
    }

    #[test]
    fn test_rk4_constant_ode() {
        let init_val = 42.0_f64;
        let sol = rk4_solve(&ConstantOde, 0.0, 1.0, &[init_val], &[], 100);
        let final_y = sol.states.last().unwrap()[0];
        assert!(
            (final_y - init_val).abs() < 1e-12,
            "constant ODE should stay at {init_val}, got {final_y}"
        );
    }

    #[test]
    fn test_rk4_exponential_growth() {
        let sol = rk4_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], 10_000);
        let final_y = sol.states.last().unwrap()[0];
        let exact = std::f64::consts::E;
        assert!(
            (final_y - exact).abs() < 1e-6,
            "RK4 exponential growth: got {final_y}, expected {exact}"
        );
    }

    #[test]
    fn test_rk4_exponential_decay() {
        let sol = rk4_solve(&ExponentialDecayOde, 0.0, 1.0, &[1.0], &[], 10_000);
        let final_y = sol.states.last().unwrap()[0];
        let exact = (-1.0_f64).exp();
        assert!(
            (final_y - exact).abs() < 1e-6,
            "RK4 exponential decay: got {final_y}, expected {exact}"
        );
    }

    #[test]
    fn test_rk4_oscillator_2d() {
        use std::f64::consts::PI;
        // One full period should return to the initial condition.
        let sol = rk4_solve(&OscillatorOde, 0.0, 2.0 * PI, &[1.0, 0.0], &[], 100_000);
        let last = sol.states.last().unwrap();
        assert!(
            (last[0] - 1.0).abs() < 1e-4,
            "oscillator x: got {}",
            last[0]
        );
        assert!(last[1].abs() < 1e-4, "oscillator y: got {}", last[1]);
    }

    #[test]
    fn test_dopri5_more_accurate_than_rk4() {
        let exact = std::f64::consts::E;

        let rk4_sol = rk4_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], 10);
        let rk4_err = (rk4_sol.states.last().unwrap()[0] - exact).abs();

        let config = OdeSolverConfig::new().rtol(1e-8).atol(1e-10);
        let dp5 = dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &config).unwrap();
        let dp5_err = (dp5.solution.states.last().unwrap()[0] - exact).abs();

        assert!(
            dp5_err < rk4_err,
            "DOPRI5 (tight tol) error {dp5_err} should be less than coarse RK4 error {rk4_err}"
        );
        assert!(
            dp5_err < 1e-6,
            "DOPRI5 with rtol=1e-8/atol=1e-10 should achieve < 1e-6 error, got {dp5_err}"
        );
    }

    #[test]
    fn test_dopri5_step_rejection_on_stiff() {
        let config = OdeSolverConfig::new().rtol(1e-6).atol(1e-8).max_steps(5000);
        let result = dopri5_solve(&StiffOde, 0.0, 0.01, &[1.0], &[], &config);
        // Either the solve succeeds (with some rejections) or it bails
        // out on one of the expected step-control errors.
        match result {
            Ok(adaptive) => {
                let _ = adaptive.rejected_steps;
            }
            Err(OdeError::StepTooSmall) | Err(OdeError::MaxStepsExceeded) => {}
            Err(e) => panic!("unexpected error: {e}"),
        }
    }

    #[test]
    fn test_solver_config_builder() {
        let cfg = OdeSolverConfig::new().rtol(1e-8).atol(1e-10).max_steps(500);
        assert!((cfg.rtol - 1e-8).abs() < 1e-15);
        assert!((cfg.atol - 1e-10).abs() < 1e-18);
        assert_eq!(cfg.max_steps, 500);
    }

    #[test]
    fn test_neural_ode_forward_correct_endpoint() {
        let ode = NeuralOde::new(ExponentialGrowthOde, 0.0, 1.0);
        let sol = ode.forward(&[1.0], &[]).unwrap();
        let final_y = sol.states.last().unwrap()[0];
        let exact = std::f64::consts::E;
        assert!(
            (final_y - exact).abs() < 1e-4,
            "NeuralOde forward: got {final_y}, expected ~{exact}"
        );
    }

    #[test]
    fn test_neural_ode_forward_t0_equals_t1() {
        let init_val = 7.5_f64;
        let ode = NeuralOde::new(ExponentialGrowthOde, 1.5, 1.5);
        let sol = ode.forward(&[init_val], &[]).unwrap();
        assert!((sol.states[0][0] - init_val).abs() < 1e-12);
    }

    #[test]
    fn test_max_steps_exceeded_on_stiff() {
        let config = OdeSolverConfig::new().rtol(1e-12).atol(1e-14).max_steps(5);
        let result = dopri5_solve(&StiffOde, 0.0, 1.0, &[1.0], &[], &config);
        assert!(
            matches!(
                result,
                Err(OdeError::MaxStepsExceeded) | Err(OdeError::StepTooSmall)
            ),
            "expected MaxStepsExceeded or StepTooSmall"
        );
    }

    #[test]
    fn test_nfev_is_positive() {
        let sol = rk4_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], 10);
        assert!(sol.nfev > 0, "nfev should be > 0, got {}", sol.nfev);
        assert_eq!(sol.nfev, 40, "RK4 should use 4 * num_steps evaluations");
    }

    #[test]
    fn test_rejected_steps_field_exists() {
        let config = OdeSolverConfig::new();
        let adaptive = dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &config).unwrap();
        let _ = adaptive.rejected_steps;
        assert!(adaptive.solution.nfev > 0);
    }

    #[test]
    fn test_dense_output_stores_intermediate_steps() {
        let config = OdeSolverConfig::new().rtol(1e-6).atol(1e-8);
        let adaptive = dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &config).unwrap();
        assert!(
            adaptive.solution.times.len() > 2,
            "dense output should contain more than 2 time points, got {}",
            adaptive.solution.times.len()
        );
        assert_eq!(
            adaptive.solution.times.len(),
            adaptive.solution.states.len(),
            "times and states must have the same length"
        );
    }

    #[test]
    fn test_adjoint_grad_y0_dimension() {
        let ode = NeuralOde::new(LinearParamOde, 0.0, 0.5);
        let y0 = vec![1.0_f64];
        let params = vec![-1.0_f64];
        let grad_out = vec![1.0_f64];
        let adj = ode.adjoint(&y0, &params, &grad_out).unwrap();
        assert_eq!(
            adj.grad_y0.len(),
            y0.len(),
            "grad_y0 must have same dim as y0"
        );
    }

    #[test]
    fn test_adjoint_grad_params_dimension() {
        let ode = NeuralOde::new(LinearParamOde, 0.0, 0.5);
        let y0 = vec![1.0_f64];
        let params = vec![-1.0_f64];
        let grad_out = vec![1.0_f64];
        let adj = ode.adjoint(&y0, &params, &grad_out).unwrap();
        assert_eq!(
            adj.grad_params.len(),
            params.len(),
            "grad_params must have same dim as params"
        );
    }

    #[test]
    fn test_ode_error_display() {
        let msgs = [
            (OdeError::MaxStepsExceeded, "max"),
            (OdeError::StepTooSmall, "step"),
            (OdeError::DivergentSolution, "diverged"),
            (OdeError::InvalidInput("bad".into()), "bad"),
        ];
        for (err, keyword) in msgs {
            let msg = format!("{err}");
            assert!(
                msg.to_lowercase().contains(keyword),
                "Display for {err:?} should contain '{keyword}', got: '{msg}'"
            );
        }
    }

    #[test]
    fn test_forward_is_deterministic() {
        let ode = NeuralOde::new(ExponentialGrowthOde, 0.0, 1.0);
        let sol1 = ode.forward(&[1.0], &[]).unwrap();
        let sol2 = ode.forward(&[1.0], &[]).unwrap();
        let y1 = sol1.states.last().unwrap()[0];
        let y2 = sol2.states.last().unwrap()[0];
        assert_eq!(y1, y2, "repeated forward passes must be deterministic");
    }

    #[test]
    fn test_rk4_convergence_with_steps() {
        let exact = std::f64::consts::E;
        let steps_list = [10usize, 100, 1000, 10_000];
        let mut prev_err = f64::INFINITY;
        for &n in &steps_list {
            let sol = rk4_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], n);
            let err = (sol.states.last().unwrap()[0] - exact).abs();
            assert!(
                err < prev_err,
                "error {err} at n={n} is not less than prev {prev_err}"
            );
            prev_err = err;
        }
        assert!(
            prev_err < 1e-13,
            "RK4 with 10_000 steps: error {prev_err} > 1e-13"
        );
    }

    #[test]
    fn test_dopri5_tolerance_affects_accuracy() {
        let exact = std::f64::consts::E;

        let coarse = OdeSolverConfig::new().rtol(1e-3).atol(1e-5);
        let fine = OdeSolverConfig::new().rtol(1e-9).atol(1e-11);

        let sol_coarse =
            dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &coarse).unwrap();
        let sol_fine = dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &fine).unwrap();

        let err_coarse = (sol_coarse.solution.states.last().unwrap()[0] - exact).abs();
        let err_fine = (sol_fine.solution.states.last().unwrap()[0] - exact).abs();

        assert!(
            err_fine < err_coarse,
            "fine tol error {err_fine} should be less than coarse tol error {err_coarse}"
        );
    }

    #[test]
    fn test_neural_ode_params_affect_trajectory() {
        let ode = NeuralOde::new(LinearParamOde, 0.0, 1.0);
        let sol_pos = ode.forward(&[1.0], &[1.0]).unwrap();
        let sol_neg = ode.forward(&[1.0], &[-1.0]).unwrap();
        let y_pos = sol_pos.states.last().unwrap()[0];
        let y_neg = sol_neg.states.last().unwrap()[0];

        assert!(
            y_pos > y_neg,
            "positive param should give larger y: y_pos={y_pos}, y_neg={y_neg}"
        );
        assert!(
            (y_pos - std::f64::consts::E).abs() < 1e-3,
            "y_pos ~ e, got {y_pos}"
        );
        assert!(
            (y_neg - (-1.0_f64).exp()).abs() < 1e-3,
            "y_neg ~ e^-1, got {y_neg}"
        );
    }

    #[test]
    fn test_adjoint_result_fields() {
        let ode = NeuralOde::new(LinearParamOde, 0.0, 1.0);
        let adj = ode.adjoint(&[1.0], &[-1.0], &[1.0]).unwrap();
        assert!(adj.total_nfev > 0, "total_nfev should be > 0");
        assert!(!adj.final_state.is_empty(), "final_state must not be empty");
        assert!(!adj.grad_y0.is_empty(), "grad_y0 must not be empty");
        assert_eq!(adj.grad_params.len(), 1);
    }

    #[test]
    fn test_solution_first_state_is_y0() {
        let y0 = vec![42.0_f64, -7.5];
        let sol = rk4_solve(&OscillatorOde, 0.0, 1.0, &y0, &[], 100);
        assert_eq!(&sol.states[0], &y0, "first stored state must equal y0");
        assert_eq!(sol.times[0], 0.0);
    }
}