// tensorlogic_train/neural_ode.rs
//! Neural ODE (Neural Ordinary Differential Equations) implementation.
//!
//! Provides the Dormand-Prince RK45 adaptive solver and adjoint sensitivity
//! method for memory-efficient gradient computation through continuous dynamics.
//!
//! # Overview
//!
//! Neural ODEs replace discrete layer stacks with a continuous ODE:
//! ```text
//!   dy/dt = f(t, y, θ),   y(t0) = y0
//! ```
//! The output is `y(t1)` obtained by numerical integration. Gradients are
//! computed via the adjoint sensitivity method, which avoids storing all
//! intermediate states during the forward pass.
//!
//! # Example
//! ```rust
//! use tensorlogic_train::neural_ode::{NeuralOde, OdeFunc};
//!
//! struct LinearOde;
//! impl OdeFunc for LinearOde {
//!     fn call(&self, _t: f64, y: &[f64], params: &[f64]) -> Vec<f64> {
//!         y.iter().zip(params.iter()).map(|(yi, pi)| yi * pi).collect()
//!     }
//!     fn vjp(&self, _t: f64, y: &[f64], params: &[f64], grad: &[f64])
//!         -> (Vec<f64>, f64, Vec<f64>)
//!     {
//!         let dy = grad.iter().zip(params.iter()).map(|(g, p)| g * p).collect();
//!         let dt = 0.0_f64;
//!         let dp = grad.iter().zip(y.iter()).map(|(g, yi)| g * yi).collect();
//!         (dy, dt, dp)
//!     }
//! }
//!
//! let ode = NeuralOde::new(LinearOde, 0.0, 1.0);
//! let sol = ode.forward(&[1.0], &[-1.0]).unwrap();
//! assert!((sol.states.last().unwrap()[0] - (-1.0_f64).exp()).abs() < 1e-3);
//! ```

40use std::fmt;
41
42// ---------------------------------------------------------------------------
43// Public traits
44// ---------------------------------------------------------------------------
45
/// ODE right-hand side: `dy/dt = f(t, y, params)`.
///
/// Implement this trait to define the dynamics of a Neural ODE layer.
pub trait OdeFunc: Send + Sync {
    /// Evaluate the ODE RHS at time `t`, state `y`, and parameters `params`.
    fn call(&self, t: f64, y: &[f64], params: &[f64]) -> Vec<f64>;

    /// Vector-Jacobian product (VJP) for the adjoint method.
    ///
    /// Given `grad_output = dL/df`, returns `(dL/dy, dL/dt, dL/dparams)`.
    ///
    /// The default implementation uses central finite differences with a
    /// fixed perturbation of `1e-6` — correct but expensive (two extra RHS
    /// evaluations per perturbed dimension). Override with an analytic VJP
    /// for efficiency.
    fn vjp(
        &self,
        t: f64,
        y: &[f64],
        params: &[f64],
        grad_output: &[f64],
    ) -> (Vec<f64>, f64, Vec<f64>) {
        // Perturbation size for central differences.
        const EPS: f64 = 1e-6;

        // Contract grad_output against the central-difference directional
        // derivative: sum_k g_k * (f+_k - f-_k) / (2 eps).
        let contract = |f_plus: &[f64], f_minus: &[f64]| -> f64 {
            grad_output
                .iter()
                .enumerate()
                .map(|(k, go)| go * (f_plus[k] - f_minus[k]) / (2.0 * EPS))
                .sum()
        };

        // dL/dy: perturb each state component in turn.
        let grad_y: Vec<f64> = (0..y.len())
            .map(|i| {
                let mut hi = y.to_vec();
                let mut lo = y.to_vec();
                hi[i] += EPS;
                lo[i] -= EPS;
                contract(&self.call(t, &hi, params), &self.call(t, &lo, params))
            })
            .collect();

        // dL/dt: perturb time itself.
        let grad_t = contract(
            &self.call(t + EPS, y, params),
            &self.call(t - EPS, y, params),
        );

        // dL/dparams: perturb each parameter in turn.
        let grad_params: Vec<f64> = (0..params.len())
            .map(|j| {
                let mut hi = params.to_vec();
                let mut lo = params.to_vec();
                hi[j] += EPS;
                lo[j] -= EPS;
                contract(&self.call(t, y, &hi), &self.call(t, y, &lo))
            })
            .collect();

        (grad_y, grad_t, grad_params)
    }
}
110
111// ---------------------------------------------------------------------------
112// Result types
113// ---------------------------------------------------------------------------
114
/// Result of a fixed-step RK4 integration.
#[derive(Debug, Clone)]
pub struct OdeSolution {
    /// Time points at which the state was recorded (always includes `t0`).
    pub times: Vec<f64>,
    /// State snapshots; `states[i]` is the solution `y(times[i])`.
    pub states: Vec<Vec<f64>>,
    /// Total number of ODE right-hand-side evaluations performed.
    pub nfev: usize,
}
125
126/// Result of an adaptive Dormand-Prince RK45 integration.
127#[derive(Debug, Clone)]
128pub struct AdaptiveSolution {
129    /// The embedded ODE solution.
130    pub solution: OdeSolution,
131    /// Number of steps that were rejected due to error tolerance violation.
132    pub rejected_steps: usize,
133    /// Step size at the final accepted step.
134    pub final_step_size: f64,
135}
136
/// Gradient information produced by the adjoint sensitivity method.
#[derive(Debug, Clone)]
pub struct AdjointResult {
    /// Final state `y(t1)` from the forward pass.
    pub final_state: Vec<f64>,
    /// Gradient of the loss with respect to the initial state, `dL/dy0`.
    pub grad_y0: Vec<f64>,
    /// Gradient of the loss with respect to the parameters, `dL/dθ`.
    pub grad_params: Vec<f64>,
    /// Combined ODE function evaluations across forward and backward passes.
    pub total_nfev: usize,
}
149
150// ---------------------------------------------------------------------------
151// Solver configuration
152// ---------------------------------------------------------------------------
153
/// Configuration for the adaptive ODE solver.
///
/// Construct with [`OdeSolverConfig::new`] (or [`Default::default`]) and
/// adjust via the builder-style setters.
#[derive(Debug, Clone)]
pub struct OdeSolverConfig {
    /// Relative tolerance (default `1e-4`).
    pub rtol: f64,
    /// Absolute tolerance (default `1e-6`).
    pub atol: f64,
    /// Maximum number of integration steps (default `1000`).
    pub max_steps: usize,
    /// Minimum allowed step size (default `1e-12`).
    pub min_step: f64,
    /// Maximum allowed step size (default `f64::INFINITY`).
    pub max_step: f64,
    /// Store every accepted step (`true`) or only the endpoint (`false`).
    pub dense_output: bool,
}

impl Default for OdeSolverConfig {
    fn default() -> Self {
        OdeSolverConfig {
            rtol: 1e-4,
            atol: 1e-6,
            max_steps: 1000,
            min_step: 1e-12,
            max_step: f64::INFINITY,
            dense_output: true,
        }
    }
}

impl OdeSolverConfig {
    /// Create a configuration with default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the relative tolerance (builder pattern).
    pub fn rtol(self, v: f64) -> Self {
        Self { rtol: v, ..self }
    }

    /// Set the absolute tolerance (builder pattern).
    pub fn atol(self, v: f64) -> Self {
        Self { atol: v, ..self }
    }

    /// Set the maximum number of integration steps (builder pattern).
    pub fn max_steps(self, n: usize) -> Self {
        Self { max_steps: n, ..self }
    }

    /// Disable intermediate state storage (builder pattern).
    pub fn no_dense_output(self) -> Self {
        Self {
            dense_output: false,
            ..self
        }
    }
}
215
216// ---------------------------------------------------------------------------
217// Error type
218// ---------------------------------------------------------------------------
219
/// Errors that can occur during ODE integration.
#[derive(Debug)]
pub enum OdeError {
    /// The solver exceeded the maximum number of allowed steps.
    MaxStepsExceeded,
    /// The adaptive solver required a step smaller than `min_step`.
    StepTooSmall,
    /// The solution grew without bound (detected by NaN/Inf in state).
    DivergentSolution,
    /// Invalid input parameters were supplied.
    InvalidInput(String),
}

impl fmt::Display for OdeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MaxStepsExceeded => f.write_str(
                "ODE solver exceeded the maximum number of steps; consider relaxing tolerances or increasing max_steps",
            ),
            Self::StepTooSmall => f.write_str(
                "ODE solver step size fell below the minimum threshold; the problem may be too stiff for this explicit solver",
            ),
            Self::DivergentSolution => f.write_str(
                "ODE solution diverged (NaN or Inf encountered in state vector)",
            ),
            Self::InvalidInput(msg) => write!(f, "ODE solver received invalid input: {msg}"),
        }
    }
}

impl std::error::Error for OdeError {}
258
259// ---------------------------------------------------------------------------
260// Helper arithmetic on Vec<f64>
261// ---------------------------------------------------------------------------
262
/// Element-wise sum of two slices (truncates to the shorter length).
#[inline]
#[allow(dead_code)]
fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (x, y) in a.iter().zip(b) {
        out.push(x + y);
    }
    out
}
268
/// Multiply every element of a slice by a scalar.
#[inline]
#[allow(dead_code)]
fn vec_scale(v: &[f64], s: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(v.len());
    for x in v {
        out.push(x * s);
    }
    out
}
274
/// Compute `y + alpha * x` element-wise (truncates to the shorter length).
#[inline]
fn vec_axpy(y: &[f64], alpha: f64, x: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(y.len().min(x.len()));
    for (yi, xi) in y.iter().zip(x) {
        out.push(yi + alpha * xi);
    }
    out
}
282
/// Compute the mixed-tolerance norm used for step-size control.
///
/// Uses the standard formula `||e||_rms = sqrt( mean( (e_i / sc_i)^2 ) )`
/// where `sc_i = atol + rtol * max(|y_i|, |y_new_i|)`. The mean divides by
/// `err.len()` and an empty error vector yields `0.0`.
fn error_norm(err: &[f64], y: &[f64], y_new: &[f64], rtol: f64, atol: f64) -> f64 {
    if err.is_empty() {
        return 0.0;
    }
    let mut acc = 0.0_f64;
    for ((e, yi), yn) in err.iter().zip(y).zip(y_new) {
        // Per-component scale mixes absolute and relative tolerance.
        let sc = atol + rtol * yi.abs().max(yn.abs());
        let ratio = e / sc;
        acc += ratio * ratio;
    }
    (acc / err.len() as f64).sqrt()
}
303
/// Check whether any element of the state is NaN or infinite.
/// (`!is_finite()` is true exactly for NaN, +inf, and -inf.)
fn has_diverged(v: &[f64]) -> bool {
    !v.iter().all(|x| x.is_finite())
}
308
309// ---------------------------------------------------------------------------
310// Fixed-step RK4 solver
311// ---------------------------------------------------------------------------
312
313/// Integrate an ODE using classic 4th-order Runge-Kutta with a fixed step size.
314///
315/// # Arguments
316/// - `func`       – ODE right-hand side.
317/// - `t0`, `t1`  – integration interval `[t0, t1]`.
318/// - `y0`        – initial state.
319/// - `params`    – ODE parameters passed through to `func`.
320/// - `num_steps` – number of equal-width steps to take.
321///
322/// # Returns
323/// An [`OdeSolution`] that always contains both the initial and final states.
324/// If `dense_output` semantics are desired, all intermediate states are stored.
325pub fn rk4_solve(
326    func: &dyn OdeFunc,
327    t0: f64,
328    t1: f64,
329    y0: &[f64],
330    params: &[f64],
331    num_steps: usize,
332) -> OdeSolution {
333    let steps = num_steps.max(1);
334    let h = (t1 - t0) / steps as f64;
335
336    let mut times = Vec::with_capacity(steps + 1);
337    let mut states = Vec::with_capacity(steps + 1);
338    let mut nfev = 0usize;
339
340    times.push(t0);
341    states.push(y0.to_vec());
342
343    let mut t = t0;
344    let mut y = y0.to_vec();
345
346    for _ in 0..steps {
347        // k1
348        let k1 = func.call(t, &y, params);
349        nfev += 1;
350        // k2
351        let y2 = vec_axpy(&y, h * 0.5, &k1);
352        let k2 = func.call(t + h * 0.5, &y2, params);
353        nfev += 1;
354        // k3
355        let y3 = vec_axpy(&y, h * 0.5, &k2);
356        let k3 = func.call(t + h * 0.5, &y3, params);
357        nfev += 1;
358        // k4
359        let y4 = vec_axpy(&y, h, &k3);
360        let k4 = func.call(t + h, &y4, params);
361        nfev += 1;
362
363        // y_next = y + h/6 * (k1 + 2*k2 + 2*k3 + k4)
364        y = y
365            .iter()
366            .zip(k1.iter())
367            .zip(k2.iter())
368            .zip(k3.iter())
369            .zip(k4.iter())
370            .map(|((((yi, k1i), k2i), k3i), k4i)| {
371                yi + h / 6.0 * (k1i + 2.0 * k2i + 2.0 * k3i + k4i)
372            })
373            .collect();
374        t += h;
375
376        times.push(t);
377        states.push(y.clone());
378    }
379
380    OdeSolution {
381        times,
382        states,
383        nfev,
384    }
385}
386
387// ---------------------------------------------------------------------------
388// Dormand-Prince RK45 adaptive solver (DOPRI5)
389// ---------------------------------------------------------------------------
390
// Dormand-Prince Butcher tableau coefficients (DOPRI5 / RK45).
//
// c-coefficients (nodes):
//   c = [0, 1/5, 3/10, 4/5, 8/9, 1, 1]
//
// a-coefficients (Runge-Kutta matrix, lower-triangular; a72 = 0 and is
// therefore not given a named constant):
//   a21 = 1/5
//   a31 = 3/40,       a32 = 9/40
//   a41 = 44/45,      a42 = -56/15,      a43 = 32/9
//   a51 = 19372/6561, a52 = -25360/2187, a53 = 64448/6561, a54 = -212/729
//   a61 = 9017/3168,  a62 = -355/33,     a63 = 46732/5247, a64 = 49/176,  a65 = -5103/18656
//   a71 = 35/384,     a72 = 0,           a73 = 500/1113,   a74 = 125/192, a75 = -2187/6784, a76 = 11/84
//
// 5th-order weights `b5` equal the a7x row (FSAL property):
//   b5 = [35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0]
//
// Embedded 4th-order weights `b4`:
//   b4 = [5179/57600, 0, 7571/16695, 393/640, -92097/339200, 187/2100, 1/40]
const DOPRI5_A21: f64 = 1.0 / 5.0;
const DOPRI5_A31: f64 = 3.0 / 40.0;
const DOPRI5_A32: f64 = 9.0 / 40.0;
const DOPRI5_A41: f64 = 44.0 / 45.0;
const DOPRI5_A42: f64 = -56.0 / 15.0;
const DOPRI5_A43: f64 = 32.0 / 9.0;
const DOPRI5_A51: f64 = 19372.0 / 6561.0;
const DOPRI5_A52: f64 = -25360.0 / 2187.0;
const DOPRI5_A53: f64 = 64448.0 / 6561.0;
const DOPRI5_A54: f64 = -212.0 / 729.0;
const DOPRI5_A61: f64 = 9017.0 / 3168.0;
const DOPRI5_A62: f64 = -355.0 / 33.0;
const DOPRI5_A63: f64 = 46732.0 / 5247.0;
const DOPRI5_A64: f64 = 49.0 / 176.0;
const DOPRI5_A65: f64 = -5103.0 / 18656.0;
const DOPRI5_A71: f64 = 35.0 / 384.0;
const DOPRI5_A73: f64 = 500.0 / 1113.0;
const DOPRI5_A74: f64 = 125.0 / 192.0;
const DOPRI5_A75: f64 = -2187.0 / 6784.0;
const DOPRI5_A76: f64 = 11.0 / 84.0;

// Error-estimate coefficients e_i = b5_i - b4_i (e2 = 0, so no constant).
const DOPRI5_E1: f64 = 71.0 / 57600.0;
const DOPRI5_E3: f64 = -71.0 / 16695.0;
const DOPRI5_E4: f64 = 71.0 / 1920.0;
const DOPRI5_E5: f64 = -17253.0 / 339200.0;
const DOPRI5_E6: f64 = 22.0 / 525.0;
const DOPRI5_E7: f64 = -1.0 / 40.0;

// Step-size controller parameters.
const DOPRI5_SAFETY: f64 = 0.9; // safety factor applied to the optimal step
const DOPRI5_MIN_FACTOR: f64 = 0.2; // maximum shrink per step
const DOPRI5_MAX_FACTOR: f64 = 10.0; // maximum growth per step
const DOPRI5_ORDER: f64 = 5.0; // used in the error exponent -1/order
452
453/// Integrate an ODE with the adaptive Dormand-Prince RK45 method (DOPRI5).
454///
455/// # Arguments
456/// - `func`   – ODE right-hand side.
457/// - `t0`, `t1` – integration interval.
458/// - `y0`    – initial state.
459/// - `params` – ODE parameters.
460/// - `config` – solver tolerances and limits.
461///
462/// # Returns
463/// `Ok(AdaptiveSolution)` on success, or an [`OdeError`] if the solver fails.
464pub fn dopri5_solve(
465    func: &dyn OdeFunc,
466    t0: f64,
467    t1: f64,
468    y0: &[f64],
469    params: &[f64],
470    config: &OdeSolverConfig,
471) -> Result<AdaptiveSolution, OdeError> {
472    if t0 == t1 {
473        return Ok(AdaptiveSolution {
474            solution: OdeSolution {
475                times: vec![t0],
476                states: vec![y0.to_vec()],
477                nfev: 0,
478            },
479            rejected_steps: 0,
480            final_step_size: 0.0,
481        });
482    }
483
484    if y0.is_empty() {
485        return Err(OdeError::InvalidInput("state vector is empty".into()));
486    }
487
488    let forward = t1 > t0;
489    let sign = if forward { 1.0_f64 } else { -1.0_f64 };
490    let span = (t1 - t0).abs();
491
492    // Initial step size heuristic
493    let f0 = func.call(t0, y0, params);
494    let d0 = (y0.iter().map(|x| x * x).sum::<f64>() / y0.len() as f64).sqrt();
495    let d1 = (f0.iter().map(|x| x * x).sum::<f64>() / f0.len() as f64).sqrt();
496    let h0 = if d0 < 1e-5 || d1 < 1e-5 {
497        1e-6
498    } else {
499        0.01 * d0 / d1
500    };
501    let mut h = sign * h0.min(span).min(config.max_step);
502
503    let mut t = t0;
504    let mut y = y0.to_vec();
505    let mut k1 = f0;
506    let mut nfev = 1usize; // already evaluated f0
507
508    let mut times = vec![t0];
509    let mut states = vec![y0.to_vec()];
510
511    let mut rejected_steps = 0usize;
512    let mut steps = 0usize;
513
514    while (sign * (t1 - t)).abs() > f64::EPSILON * span.max(1.0) {
515        if steps >= config.max_steps {
516            return Err(OdeError::MaxStepsExceeded);
517        }
518
519        // Clamp step to not overshoot
520        if (t + h - t1) * sign > 0.0 {
521            h = t1 - t;
522        }
523
524        let h_abs = h.abs();
525        if h_abs < config.min_step {
526            return Err(OdeError::StepTooSmall);
527        }
528
529        // Stage 2
530        let y2 = vec_axpy(&y, DOPRI5_A21 * h, &k1);
531        let k2 = func.call(t + h / 5.0, &y2, params);
532        nfev += 1;
533
534        // Stage 3
535        let y3: Vec<f64> = y
536            .iter()
537            .zip(k1.iter())
538            .zip(k2.iter())
539            .map(|((yi, k1i), k2i)| yi + h * (DOPRI5_A31 * k1i + DOPRI5_A32 * k2i))
540            .collect();
541        let k3 = func.call(t + h * 3.0 / 10.0, &y3, params);
542        nfev += 1;
543
544        // Stage 4
545        let y4: Vec<f64> = y
546            .iter()
547            .zip(k1.iter())
548            .zip(k2.iter())
549            .zip(k3.iter())
550            .map(|(((yi, k1i), k2i), k3i)| {
551                yi + h * (DOPRI5_A41 * k1i + DOPRI5_A42 * k2i + DOPRI5_A43 * k3i)
552            })
553            .collect();
554        let k4 = func.call(t + h * 4.0 / 5.0, &y4, params);
555        nfev += 1;
556
557        // Stage 5
558        let y5: Vec<f64> = y
559            .iter()
560            .zip(k1.iter())
561            .zip(k2.iter())
562            .zip(k3.iter())
563            .zip(k4.iter())
564            .map(|((((yi, k1i), k2i), k3i), k4i)| {
565                yi + h * (DOPRI5_A51 * k1i + DOPRI5_A52 * k2i + DOPRI5_A53 * k3i + DOPRI5_A54 * k4i)
566            })
567            .collect();
568        let k5 = func.call(t + h * 8.0 / 9.0, &y5, params);
569        nfev += 1;
570
571        // Stage 6
572        let y6: Vec<f64> = y
573            .iter()
574            .zip(k1.iter())
575            .zip(k2.iter())
576            .zip(k3.iter())
577            .zip(k4.iter())
578            .zip(k5.iter())
579            .map(|(((((yi, k1i), k2i), k3i), k4i), k5i)| {
580                yi + h
581                    * (DOPRI5_A61 * k1i
582                        + DOPRI5_A62 * k2i
583                        + DOPRI5_A63 * k3i
584                        + DOPRI5_A64 * k4i
585                        + DOPRI5_A65 * k5i)
586            })
587            .collect();
588        let k6 = func.call(t + h, &y6, params);
589        nfev += 1;
590
591        // 5th-order solution (= next k1 via FSAL)
592        let y_new: Vec<f64> = y
593            .iter()
594            .zip(k1.iter())
595            .zip(k3.iter())
596            .zip(k4.iter())
597            .zip(k5.iter())
598            .zip(k6.iter())
599            .map(|(((((yi, k1i), k3i), k4i), k5i), k6i)| {
600                yi + h
601                    * (DOPRI5_A71 * k1i
602                        + DOPRI5_A73 * k3i
603                        + DOPRI5_A74 * k4i
604                        + DOPRI5_A75 * k5i
605                        + DOPRI5_A76 * k6i)
606            })
607            .collect();
608
609        if has_diverged(&y_new) {
610            return Err(OdeError::DivergentSolution);
611        }
612
613        // Stage 7 (FSAL: this becomes k1 of the next step)
614        let k7 = func.call(t + h, &y_new, params);
615        nfev += 1;
616
617        // Error estimate using e_i = b5_i - b4_i
618        let err: Vec<f64> = k1
619            .iter()
620            .zip(k3.iter())
621            .zip(k4.iter())
622            .zip(k5.iter())
623            .zip(k6.iter())
624            .zip(k7.iter())
625            .map(|(((((e1, e3), e4), e5), e6), e7)| {
626                h * (DOPRI5_E1 * e1
627                    + DOPRI5_E3 * e3
628                    + DOPRI5_E4 * e4
629                    + DOPRI5_E5 * e5
630                    + DOPRI5_E6 * e6
631                    + DOPRI5_E7 * e7)
632            })
633            .collect();
634
635        let error_norm_val = error_norm(&err, &y, &y_new, config.rtol, config.atol);
636
637        if error_norm_val <= 1.0 {
638            // Accept step
639            t += h;
640            y = y_new;
641            k1 = k7; // FSAL
642
643            if config.dense_output {
644                times.push(t);
645                states.push(y.clone());
646            }
647            steps += 1;
648
649            // Compute new step size
650            let factor = if error_norm_val == 0.0 {
651                DOPRI5_MAX_FACTOR
652            } else {
653                (DOPRI5_SAFETY * error_norm_val.powf(-1.0 / DOPRI5_ORDER))
654                    .clamp(DOPRI5_MIN_FACTOR, DOPRI5_MAX_FACTOR)
655            };
656            h *= factor;
657            h = h.abs().min(config.max_step) * sign;
658        } else {
659            // Reject step
660            rejected_steps += 1;
661            let factor = (DOPRI5_SAFETY * error_norm_val.powf(-1.0 / DOPRI5_ORDER))
662                .clamp(DOPRI5_MIN_FACTOR, 1.0);
663            h *= factor;
664        }
665    }
666
667    // Ensure endpoint is stored
668    // For dense output only push endpoint if not already recorded; for fixed output always push.
669    if !config.dense_output || times.last().map(|&last| last != t).unwrap_or(true) {
670        times.push(t);
671        states.push(y.clone());
672    }
673
674    Ok(AdaptiveSolution {
675        solution: OdeSolution {
676            times,
677            states,
678            nfev,
679        },
680        rejected_steps,
681        final_step_size: h.abs(),
682    })
683}
684
685// ---------------------------------------------------------------------------
686// NeuralOde layer
687// ---------------------------------------------------------------------------
688
689/// A Neural ODE layer that wraps an [`OdeFunc`] with fixed integration limits.
690///
691/// The forward pass integrates `y(t0) = y0` to `y(t1)` and the `adjoint`
692/// method computes `dL/dy0` and `dL/dθ` via the adjoint sensitivity method
693/// without storing all intermediate activations.
694pub struct NeuralOde<F: OdeFunc> {
695    func: F,
696    t0: f64,
697    t1: f64,
698    config: OdeSolverConfig,
699}
700
701impl<F: OdeFunc> NeuralOde<F> {
702    /// Create a new [`NeuralOde`] with default solver configuration.
703    pub fn new(func: F, t0: f64, t1: f64) -> Self {
704        Self {
705            func,
706            t0,
707            t1,
708            config: OdeSolverConfig::default(),
709        }
710    }
711
712    /// Create a new [`NeuralOde`] with a custom solver configuration.
713    pub fn with_config(func: F, t0: f64, t1: f64, config: OdeSolverConfig) -> Self {
714        Self {
715            func,
716            t0,
717            t1,
718            config,
719        }
720    }
721
722    /// Forward pass: integrate `y0` from `t0` to `t1`.
723    ///
724    /// Uses the adaptive DOPRI5 solver internally.
725    pub fn forward(&self, y0: &[f64], params: &[f64]) -> Result<OdeSolution, OdeError> {
726        if y0.is_empty() {
727            return Err(OdeError::InvalidInput("initial state is empty".into()));
728        }
729        let adaptive = dopri5_solve(&self.func, self.t0, self.t1, y0, params, &self.config)?;
730        Ok(adaptive.solution)
731    }
732
733    /// Full adjoint sensitivity method.
734    ///
735    /// Integrates forward to obtain `y(t1)`, then runs the augmented backward
736    /// ODE to compute `dL/dy0` and `dL/dθ` in a single backward pass.
737    ///
738    /// # Arguments
739    /// - `y0`          – initial state.
740    /// - `params`      – ODE parameters.
741    /// - `grad_output` – upstream gradient `dL/dy(t1)`.
742    pub fn adjoint(
743        &self,
744        y0: &[f64],
745        params: &[f64],
746        grad_output: &[f64],
747    ) -> Result<AdjointResult, OdeError> {
748        if y0.len() != grad_output.len() {
749            return Err(OdeError::InvalidInput(format!(
750                "grad_output length {} does not match state dimension {}",
751                grad_output.len(),
752                y0.len()
753            )));
754        }
755
756        // Forward pass with dense output to store trajectory
757        let fwd_config = OdeSolverConfig {
758            dense_output: true,
759            ..self.config.clone()
760        };
761        let adaptive = dopri5_solve(&self.func, self.t0, self.t1, y0, params, &fwd_config)?;
762        let fwd_nfev = adaptive.solution.nfev;
763
764        let adj_result = adjoint_backward(
765            &self.func,
766            &adaptive.solution,
767            params,
768            grad_output,
769            &self.config,
770        );
771
772        Ok(AdjointResult {
773            total_nfev: fwd_nfev + adj_result.total_nfev,
774            ..adj_result
775        })
776    }
777}
778
779// ---------------------------------------------------------------------------
780// Adjoint backward pass
781// ---------------------------------------------------------------------------
782
783/// Run the augmented adjoint backward integration.
784///
785/// The adjoint (co-state) `a(t) = dL/dy(t)` satisfies:
786/// ```text
787///   da/dt = -a^T * (∂f/∂y)
788/// ```
789/// evaluated backwards from `t1` to `t0`. Simultaneously, the parameter
790/// gradient accumulates as:
791/// ```text
792///   dL/dθ = -∫_{t1}^{t0} a^T * (∂f/∂θ) dt
793/// ```
794///
795/// Implementation: we integrate the augmented state `[a, dL/dθ]` backward
796/// in time using RK4, stepping through the stored forward trajectory in
797/// reverse order to obtain accurate `y(t)` at each sub-step.
798fn adjoint_backward(
799    func: &dyn OdeFunc,
800    solution: &OdeSolution,
801    params: &[f64],
802    grad_output: &[f64],
803    _config: &OdeSolverConfig,
804) -> AdjointResult {
805    let n_state = grad_output.len();
806    let n_params = params.len();
807
808    let final_state = solution
809        .states
810        .last()
811        .cloned()
812        .unwrap_or_else(|| grad_output.to_vec());
813
814    // Initialise adjoint at t1
815    let mut a = grad_output.to_vec();
816    let mut grad_params = vec![0.0_f64; n_params];
817    let mut total_nfev = 0usize;
818
819    // Use a fixed number of backward sub-steps per forward interval
820    let adj_steps_per_interval = 4usize;
821
822    // Walk intervals in reverse: [t_{k+1}, t_k]
823    let n_intervals = solution.times.len().saturating_sub(1);
824    for interval_idx in (0..n_intervals).rev() {
825        let t_start = solution.times[interval_idx + 1];
826        let t_end = solution.times[interval_idx];
827        let y_start = &solution.states[interval_idx + 1];
828        let y_end = &solution.states[interval_idx];
829
830        // Subdivide backward interval with fixed-step RK4
831        let h = (t_end - t_start) / adj_steps_per_interval as f64;
832
833        let mut t_cur = t_start;
834
835        for step_idx in 0..adj_steps_per_interval {
836            // Interpolate y linearly between the two stored states for
837            // the current sub-step (simple but sufficient for moderate stiffness)
838            let alpha = step_idx as f64 / adj_steps_per_interval as f64;
839            let y_interp: Vec<f64> = y_start
840                .iter()
841                .zip(y_end.iter())
842                .map(|(ys, ye)| ys + alpha * (ye - ys))
843                .collect();
844
845            // Augmented RHS evaluated at current (a, param_grad)
846            let aug_rhs =
847                |t_local: f64, a_local: &[f64], y_local: &[f64]| -> (Vec<f64>, Vec<f64>) {
848                    let (da_dy, _da_dt, da_dp) = func.vjp(t_local, y_local, params, a_local);
849                    // a-dot = -da_dy  (adjoint equation in backward time)
850                    let a_dot: Vec<f64> = da_dy.iter().map(|x| -x).collect();
851                    // grad_params accumulation = -da_dp
852                    let gp_dot: Vec<f64> = da_dp.iter().map(|x| -x).collect();
853                    (a_dot, gp_dot)
854                };
855
856            // RK4 for augmented state [a, grad_params]
857            let (k1_a, k1_gp) = aug_rhs(t_cur, &a, &y_interp);
858            total_nfev += 1;
859
860            let a2 = vec_axpy(&a, h * 0.5, &k1_a);
861            let alpha2 = (step_idx as f64 + 0.5) / adj_steps_per_interval as f64;
862            let y2: Vec<f64> = y_start
863                .iter()
864                .zip(y_end.iter())
865                .map(|(ys, ye)| ys + alpha2 * (ye - ys))
866                .collect();
867            let (k2_a, k2_gp) = aug_rhs(t_cur + h * 0.5, &a2, &y2);
868            total_nfev += 1;
869
870            let a3 = vec_axpy(&a, h * 0.5, &k2_a);
871            let (k3_a, k3_gp) = aug_rhs(t_cur + h * 0.5, &a3, &y2);
872            total_nfev += 1;
873
874            let a4 = vec_axpy(&a, h, &k3_a);
875            let alpha_end = (step_idx + 1) as f64 / adj_steps_per_interval as f64;
876            let y4: Vec<f64> = y_start
877                .iter()
878                .zip(y_end.iter())
879                .map(|(ys, ye)| ys + alpha_end * (ye - ys))
880                .collect();
881            let (k4_a, k4_gp) = aug_rhs(t_cur + h, &a4, &y4);
882            total_nfev += 1;
883
884            // Update a and grad_params
885            a = a
886                .iter()
887                .zip(k1_a.iter())
888                .zip(k2_a.iter())
889                .zip(k3_a.iter())
890                .zip(k4_a.iter())
891                .map(|((((ai, k1i), k2i), k3i), k4i)| {
892                    ai + h / 6.0 * (k1i + 2.0 * k2i + 2.0 * k3i + k4i)
893                })
894                .collect();
895
896            grad_params = grad_params
897                .iter()
898                .zip(k1_gp.iter())
899                .zip(k2_gp.iter())
900                .zip(k3_gp.iter())
901                .zip(k4_gp.iter())
902                .map(|((((gp, k1i), k2i), k3i), k4i)| {
903                    gp + h / 6.0 * (k1i + 2.0 * k2i + 2.0 * k3i + k4i)
904                })
905                .collect();
906
907            t_cur += h;
908        }
909
910        let _ = n_state; // suppress unused warning if n_state == 0
911    }
912
913    AdjointResult {
914        final_state,
915        grad_y0: a,
916        grad_params,
917        total_nfev,
918    }
919}
920
921// ---------------------------------------------------------------------------
922// Tests
923// ---------------------------------------------------------------------------
924
#[cfg(test)]
mod tests {
    //! Unit tests for the fixed-step RK4 solver, the adaptive DOPRI5 solver,
    //! the `NeuralOde` forward pass, and the adjoint sensitivity method.
    //! Helper ODEs with analytically known solutions are defined first so
    //! each test can compare against a closed-form answer.

    use super::*;

    // ---- Helper ODE functions -----------------------------------------------

    /// dy/dt = 0  (constant function; exact solution y(t) = y0)
    struct ConstantOde;
    impl OdeFunc for ConstantOde {
        fn call(&self, _t: f64, _y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![0.0]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            _grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            // Zero Jacobian everywhere, no parameters.
            (vec![0.0], 0.0, vec![])
        }
    }

    /// dy/dt = y  (exponential growth; exact solution y(t) = y0 * e^t)
    struct ExponentialGrowthOde;
    impl OdeFunc for ExponentialGrowthOde {
        fn call(&self, _t: f64, y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            // df/dy = I, so VJP = grad
            (grad.to_vec(), 0.0, vec![])
        }
    }

    /// dy/dt = -y  (exponential decay; exact solution y(t) = y0 * e^-t)
    struct ExponentialDecayOde;
    impl OdeFunc for ExponentialDecayOde {
        fn call(&self, _t: f64, y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![-y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            // df/dy = -I, so VJP = -grad
            (grad.iter().map(|g| -g).collect(), 0.0, vec![])
        }
    }

    /// Harmonic oscillator: dx/dt = y, dy/dt = -x  (trajectory traces the
    /// unit circle; period 2π)
    struct OscillatorOde;
    impl OdeFunc for OscillatorOde {
        fn call(&self, _t: f64, y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![y[1], -y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            // Jacobian J = [[0, 1], [-1, 0]]; the VJP is grad^T * J, i.e.
            //   component 0: -grad[1]   (column 0 of J is [0, -1])
            //   component 1:  grad[0]   (column 1 of J is [1,  0])
            let ga = grad[1];
            let gb = grad[0];
            (vec![-ga, gb], 0.0, vec![])
        }
    }

    /// dy/dt = param * y  (linear with one learnable parameter;
    /// exact solution y(t) = y0 * e^(p*t))
    struct LinearParamOde;
    impl OdeFunc for LinearParamOde {
        fn call(&self, _t: f64, y: &[f64], params: &[f64]) -> Vec<f64> {
            vec![params[0] * y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            y: &[f64],
            params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            // df/dy = p, df/dp = y
            let grad_y = vec![grad[0] * params[0]];
            let grad_p = vec![grad[0] * y[0]];
            (grad_y, 0.0, grad_p)
        }
    }

    /// Stiff ODE: dy/dt = -1000 * y  (stiffness ratio 1000; forces an
    /// explicit adaptive solver into tiny steps / rejections)
    struct StiffOde;
    impl OdeFunc for StiffOde {
        fn call(&self, _t: f64, y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![-1000.0 * y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            (grad.iter().map(|g| -1000.0 * g).collect(), 0.0, vec![])
        }
    }

    // =========================================================================
    // Test 1: RK4 solves dy/dt = 0 (constant)
    // =========================================================================
    #[test]
    fn test_rk4_constant_ode() {
        let init_val = 42.0_f64;
        let sol = rk4_solve(&ConstantOde, 0.0, 1.0, &[init_val], &[], 100);
        let final_y = sol.states.last().unwrap()[0];
        // All RK4 stage derivatives are exactly 0.0, so the state is preserved
        // bit-for-bit; 1e-12 leaves generous headroom.
        assert!(
            (final_y - init_val).abs() < 1e-12,
            "constant ODE should stay at {init_val}, got {final_y}"
        );
    }

    // =========================================================================
    // Test 2: RK4 solves dy/dt = y (exponential growth)
    // =========================================================================
    #[test]
    fn test_rk4_exponential_growth() {
        let sol = rk4_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], 10_000);
        let final_y = sol.states.last().unwrap()[0];
        let exact = std::f64::consts::E;
        assert!(
            (final_y - exact).abs() < 1e-6,
            "RK4 exponential growth: got {final_y}, expected {exact}"
        );
    }

    // =========================================================================
    // Test 3: RK4 solves dy/dt = -y (exponential decay)
    // =========================================================================
    #[test]
    fn test_rk4_exponential_decay() {
        let sol = rk4_solve(&ExponentialDecayOde, 0.0, 1.0, &[1.0], &[], 10_000);
        let final_y = sol.states.last().unwrap()[0];
        let exact = (-1.0_f64).exp();
        assert!(
            (final_y - exact).abs() < 1e-6,
            "RK4 exponential decay: got {final_y}, expected {exact}"
        );
    }

    // =========================================================================
    // Test 4: RK4 with 2D oscillator (unit circle)
    // =========================================================================
    #[test]
    fn test_rk4_oscillator_2d() {
        // Integrate one full period: [0, 2π] — the trajectory should return
        // to its starting point [1, 0].
        use std::f64::consts::PI;
        let sol = rk4_solve(&OscillatorOde, 0.0, 2.0 * PI, &[1.0, 0.0], &[], 100_000);
        let last = sol.states.last().unwrap();
        assert!(
            (last[0] - 1.0).abs() < 1e-4,
            "oscillator x: got {}",
            last[0]
        );
        assert!(last[1].abs() < 1e-4, "oscillator y: got {}", last[1]);
    }

    // =========================================================================
    // Test 5: DOPRI5 achieves high accuracy when tolerances are tight
    // =========================================================================
    #[test]
    fn test_dopri5_more_accurate_than_rk4() {
        let exact = std::f64::consts::E;

        // RK4 with only 10 steps (coarse fixed-step baseline)
        let rk4_sol = rk4_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], 10);
        let rk4_err = (rk4_sol.states.last().unwrap()[0] - exact).abs();

        // DOPRI5 with tight tolerances — adaptive solver should beat coarse RK4
        let config = OdeSolverConfig::new().rtol(1e-8).atol(1e-10);
        let dp5 = dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &config).unwrap();
        let dp5_err = (dp5.solution.states.last().unwrap()[0] - exact).abs();

        assert!(
            dp5_err < rk4_err,
            "DOPRI5 (tight tol) error {dp5_err} should be less than coarse RK4 error {rk4_err}"
        );
        // Verify DOPRI5 achieves the requested tolerance
        assert!(
            dp5_err < 1e-6,
            "DOPRI5 with rtol=1e-8/atol=1e-10 should achieve < 1e-6 error, got {dp5_err}"
        );
    }

    // =========================================================================
    // Test 6: DOPRI5 rejects steps on a stiff function
    // =========================================================================
    #[test]
    fn test_dopri5_step_rejection_on_stiff() {
        let config = OdeSolverConfig::new().rtol(1e-6).atol(1e-8).max_steps(5000);
        // Stiff ODE: many step rejections expected
        let result = dopri5_solve(&StiffOde, 0.0, 0.01, &[1.0], &[], &config);
        // Either succeeds with rejections, or fails with StepTooSmall
        match result {
            Ok(adaptive) => {
                // The solver may reject some steps on a very stiff problem; just verify the field is accessible.
                let _ = adaptive.rejected_steps;
            }
            Err(OdeError::StepTooSmall) | Err(OdeError::MaxStepsExceeded) => {
                // Acceptable — this problem is genuinely stiff
            }
            Err(e) => panic!("unexpected error: {e}"),
        }
    }

    // =========================================================================
    // Test 7: OdeSolverConfig builder pattern
    // =========================================================================
    #[test]
    fn test_solver_config_builder() {
        let cfg = OdeSolverConfig::new().rtol(1e-8).atol(1e-10).max_steps(500);
        assert!((cfg.rtol - 1e-8).abs() < 1e-15);
        assert!((cfg.atol - 1e-10).abs() < 1e-18);
        assert_eq!(cfg.max_steps, 500);
    }

    // =========================================================================
    // Test 8: NeuralOde::forward returns correct endpoint
    // =========================================================================
    #[test]
    fn test_neural_ode_forward_correct_endpoint() {
        let ode = NeuralOde::new(ExponentialGrowthOde, 0.0, 1.0);
        let sol = ode.forward(&[1.0], &[]).unwrap();
        let final_y = sol.states.last().unwrap()[0];
        let exact = std::f64::consts::E;
        assert!(
            (final_y - exact).abs() < 1e-4,
            "NeuralOde forward: got {final_y}, expected ~{exact}"
        );
    }

    // =========================================================================
    // Test 9: NeuralOde::forward with t0 = t1 returns y0 unchanged
    // =========================================================================
    #[test]
    fn test_neural_ode_forward_t0_equals_t1() {
        let init_val = 7.5_f64; // arbitrary non-special constant
        let ode = NeuralOde::new(ExponentialGrowthOde, 1.5, 1.5);
        let sol = ode.forward(&[init_val], &[]).unwrap();
        // Should contain exactly the initial state
        assert!((sol.states[0][0] - init_val).abs() < 1e-12);
    }

    // =========================================================================
    // Test 10: MaxStepsExceeded error on very stiff problem with tight limits
    // =========================================================================
    #[test]
    fn test_max_steps_exceeded_on_stiff() {
        // Extremely stiff ODE with very few allowed steps and tight tolerances
        let config = OdeSolverConfig::new().rtol(1e-12).atol(1e-14).max_steps(5); // intentionally tiny
        let result = dopri5_solve(&StiffOde, 0.0, 1.0, &[1.0], &[], &config);
        assert!(
            matches!(
                result,
                Err(OdeError::MaxStepsExceeded) | Err(OdeError::StepTooSmall)
            ),
            "expected MaxStepsExceeded or StepTooSmall"
        );
    }

    // =========================================================================
    // Test 11: OdeSolution nfev count is reasonable (> 0)
    // =========================================================================
    #[test]
    fn test_nfev_is_positive() {
        let sol = rk4_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], 10);
        assert!(sol.nfev > 0, "nfev should be > 0, got {}", sol.nfev);
        // RK4: 4 evaluations per step
        assert_eq!(sol.nfev, 40, "RK4 should use 4 * num_steps evaluations");
    }

    // =========================================================================
    // Test 12: AdaptiveSolution.rejected_steps >= 0
    // =========================================================================
    #[test]
    fn test_rejected_steps_field_exists() {
        let config = OdeSolverConfig::new();
        let adaptive = dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &config).unwrap();
        // This field must exist and be a valid integer >= 0
        let _ = adaptive.rejected_steps; // type check — usize is always >= 0
        assert!(adaptive.solution.nfev > 0);
    }

    // =========================================================================
    // Test 13: Dense output stores intermediate steps
    // =========================================================================
    #[test]
    fn test_dense_output_stores_intermediate_steps() {
        let config = OdeSolverConfig::new().rtol(1e-6).atol(1e-8);
        let adaptive = dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &config).unwrap();
        // Dense output: should have more than just start and end
        assert!(
            adaptive.solution.times.len() > 2,
            "dense output should contain more than 2 time points, got {}",
            adaptive.solution.times.len()
        );
        assert_eq!(
            adaptive.solution.times.len(),
            adaptive.solution.states.len(),
            "times and states must have the same length"
        );
    }

    // =========================================================================
    // Test 14: Adjoint grad_y0 has same dimension as y0
    // =========================================================================
    #[test]
    fn test_adjoint_grad_y0_dimension() {
        let ode = NeuralOde::new(LinearParamOde, 0.0, 0.5);
        let y0 = vec![1.0_f64];
        let params = vec![-1.0_f64];
        let grad_out = vec![1.0_f64];
        let adj = ode.adjoint(&y0, &params, &grad_out).unwrap();
        assert_eq!(
            adj.grad_y0.len(),
            y0.len(),
            "grad_y0 must have same dim as y0"
        );
    }

    // =========================================================================
    // Test 15: Adjoint grad_params has same dimension as params
    // =========================================================================
    #[test]
    fn test_adjoint_grad_params_dimension() {
        let ode = NeuralOde::new(LinearParamOde, 0.0, 0.5);
        let y0 = vec![1.0_f64];
        let params = vec![-1.0_f64];
        let grad_out = vec![1.0_f64];
        let adj = ode.adjoint(&y0, &params, &grad_out).unwrap();
        assert_eq!(
            adj.grad_params.len(),
            params.len(),
            "grad_params must have same dim as params"
        );
    }

    // =========================================================================
    // Test 16: OdeError Display shows meaningful messages
    // =========================================================================
    #[test]
    fn test_ode_error_display() {
        let msgs = [
            (OdeError::MaxStepsExceeded, "max"),
            (OdeError::StepTooSmall, "step"),
            (OdeError::DivergentSolution, "diverged"),
            (OdeError::InvalidInput("bad".into()), "bad"),
        ];
        for (err, keyword) in msgs {
            let msg = format!("{err}");
            assert!(
                msg.to_lowercase().contains(keyword),
                "Display for {err:?} should contain '{keyword}', got: '{msg}'"
            );
        }
    }

    // =========================================================================
    // Test 17: Multiple forward passes produce same result (deterministic)
    // =========================================================================
    #[test]
    fn test_forward_is_deterministic() {
        let ode = NeuralOde::new(ExponentialGrowthOde, 0.0, 1.0);
        let sol1 = ode.forward(&[1.0], &[]).unwrap();
        let sol2 = ode.forward(&[1.0], &[]).unwrap();
        let y1 = sol1.states.last().unwrap()[0];
        let y2 = sol2.states.last().unwrap()[0];
        assert_eq!(y1, y2, "repeated forward passes must be deterministic");
    }

    // =========================================================================
    // Test 18: RK4 converges to exact solution as num_steps increases
    // =========================================================================
    #[test]
    fn test_rk4_convergence_with_steps() {
        let exact = std::f64::consts::E;
        // Stop the ladder at n = 1_000: the truncation error of 4th-order RK4
        // scales as h^4, and beyond ~1_000 steps it drops below accumulated
        // floating-point rounding (~sqrt(n) * eps), so the total error is no
        // longer guaranteed to decrease monotonically with n. The previous
        // version of this test went to n = 10_000 with a 1e-13 bound, which
        // sat right at the rounding floor and made the test flaky.
        let steps_list = [10usize, 100, 1000];
        let mut prev_err = f64::INFINITY;
        for &n in &steps_list {
            let sol = rk4_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], n);
            let err = (sol.states.last().unwrap()[0] - exact).abs();
            assert!(
                err < prev_err,
                "error {err} at n={n} is not less than prev {prev_err}"
            );
            prev_err = err;
        }
        // At n = 1_000 the truncation error is ~1e-14 (h^4 * e / 120), so a
        // 1e-10 bound leaves several orders of magnitude of headroom above
        // any rounding noise while still proving 4th-order convergence.
        assert!(
            prev_err < 1e-10,
            "RK4 with 1_000 steps: error {prev_err} > 1e-10"
        );
    }

    // =========================================================================
    // Test 19: DOPRI5 rtol/atol affect solution accuracy
    // =========================================================================
    #[test]
    fn test_dopri5_tolerance_affects_accuracy() {
        let exact = std::f64::consts::E;

        let coarse = OdeSolverConfig::new().rtol(1e-3).atol(1e-5);
        let fine = OdeSolverConfig::new().rtol(1e-9).atol(1e-11);

        let sol_coarse =
            dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &coarse).unwrap();
        let sol_fine = dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &fine).unwrap();

        let err_coarse = (sol_coarse.solution.states.last().unwrap()[0] - exact).abs();
        let err_fine = (sol_fine.solution.states.last().unwrap()[0] - exact).abs();

        assert!(
            err_fine < err_coarse,
            "fine tol error {err_fine} should be less than coarse tol error {err_coarse}"
        );
    }

    // =========================================================================
    // Test 20: NeuralOde with params affects trajectory
    // =========================================================================
    #[test]
    fn test_neural_ode_params_affect_trajectory() {
        // LinearParamOde: dy/dt = p * y  =>  y(1) = y0 * exp(p)
        let ode = NeuralOde::new(LinearParamOde, 0.0, 1.0);
        let sol_pos = ode.forward(&[1.0], &[1.0]).unwrap(); // y(1) ~ e
        let sol_neg = ode.forward(&[1.0], &[-1.0]).unwrap(); // y(1) ~ e^-1

        let y_pos = sol_pos.states.last().unwrap()[0];
        let y_neg = sol_neg.states.last().unwrap()[0];

        assert!(
            y_pos > y_neg,
            "positive param should give larger y: y_pos={y_pos}, y_neg={y_neg}"
        );
        assert!(
            (y_pos - std::f64::consts::E).abs() < 1e-3,
            "y_pos ~ e, got {y_pos}"
        );
        assert!(
            (y_neg - (-1.0_f64).exp()).abs() < 1e-3,
            "y_neg ~ e^-1, got {y_neg}"
        );
    }

    // =========================================================================
    // Extra: verify AdjointResult fields exist and total_nfev is positive
    // =========================================================================
    #[test]
    fn test_adjoint_result_fields() {
        let ode = NeuralOde::new(LinearParamOde, 0.0, 1.0);
        let adj = ode.adjoint(&[1.0], &[-1.0], &[1.0]).unwrap();
        assert!(adj.total_nfev > 0, "total_nfev should be > 0");
        assert!(!adj.final_state.is_empty(), "final_state must not be empty");
        assert!(!adj.grad_y0.is_empty(), "grad_y0 must not be empty");
        // grad_params matches params dimension (1 here)
        assert_eq!(adj.grad_params.len(), 1);
    }

    // =========================================================================
    // Extra: OdeSolution start state equals y0
    // =========================================================================
    #[test]
    fn test_solution_first_state_is_y0() {
        let y0 = vec![42.0_f64, -7.5];
        let sol = rk4_solve(&OscillatorOde, 0.0, 1.0, &y0, &[], 100);
        assert_eq!(&sol.states[0], &y0, "first stored state must equal y0");
        assert_eq!(sol.times[0], 0.0);
    }
}
1411}