Skip to main content

scirs2_integrate/async_ode/
mod.rs

1//! Async/cached ODE solver for neural ODE training patterns.
2//!
3//! # Motivation
4//!
5//! Neural ODE training requires integrating the same ODE structure many times
6//! per batch — once per forward pass and once per adjoint backward pass.
7//! The "CUDA graph capture" concept (repeated ODE solve patterns) maps in
8//! pure Rust to:
9//!
//! 1. **Pre-allocated scratch buffers**: each solve allocates its output and
//!    RK4 scratch vectors once, up front; the step loop itself performs no
//!    heap allocation.
12//! 2. **Rayon batch parallelism**: `integrate_batch` runs all initial
13//!    conditions concurrently using Rayon's work-stealing pool.
14//! 3. **Tokio async wrapper**: `integrate_batch_async` offloads the blocking
15//!    rayon call onto a `spawn_blocking` thread so the async runtime stays
16//!    responsive.
17//!
18//! # Solver
19//!
//! A fixed-step RK4 integrator is used. The number of steps and the step size
//! are fixed at construction time from the interval `[t0, t1]` and the
//! requested `dt`. This predictable structure (unlike adaptive step-size
//! methods) is what enables true "graph capture": the sequence of RHS
//! evaluations (four per step) is identical for every call, making it
//! JIT-friendly.
25//!
26//! # Example
27//!
28//! ```rust
29//! use scirs2_integrate::async_ode::{CachedOdeProblem, integrate_batch_async};
30//! use std::sync::Arc;
31//!
32//! // dy/dt = -y  →  y(t) = y0 * exp(-t)
33//! let problem = Arc::new(
34//!     CachedOdeProblem::new(|_t, y, dydt| { dydt[0] = -y[0]; }, 0.0, 1.0, 0.01, 1)
35//! );
36//!
37//! let result = problem.integrate(&[1.0]).unwrap();
38//! let expected = 1.0_f64.exp().recip(); // e^{-1} ≈ 0.3679
39//! assert!((result[0] - expected).abs() < 1e-4);
40//! ```
41
42use crate::error::IntegrateError;
43use scirs2_core::parallel_ops::*;
44use std::sync::Arc;
45
46// ─────────────────────────────────────────────────────────────────────────────
47// CachedOdeProblem
48// ─────────────────────────────────────────────────────────────────────────────
49
/// A pre-built ODE problem for repeated fixed-step RK4 integration
/// (the neural-ODE pattern of solving the same system many times).
///
/// The integration schedule is fixed at construction: every `integrate`
/// call performs `n_steps` RK4 steps — i.e. `4 * n_steps` evaluations of
/// `rhs` — starting at `t0` with step `dt`. Scratch vectors are allocated
/// once per call, not once per step.
pub struct CachedOdeProblem<F>
where
    F: Fn(f64, &[f64], &mut [f64]) + Send + Sync,
{
    /// Right-hand side `f(t, y, dydt)`; writes the derivative into `dydt`.
    rhs: Arc<F>,
    /// Start of the integration interval.
    t0: f64,
    /// Step size used by every RK4 step.
    dt: f64,
    /// Number of RK4 steps per `integrate` call.
    n_steps: usize,
    /// Required length of the state vector passed to `integrate`.
    state_dim: usize,
}
64
65impl<F> CachedOdeProblem<F>
66where
67    F: Fn(f64, &[f64], &mut [f64]) + Send + Sync + 'static,
68{
69    /// Create a new cached ODE problem.
70    ///
71    /// # Parameters
72    ///
73    /// - `rhs`: Right-hand side `f(t, y, dydt)` — writes into `dydt`.
74    /// - `t0`, `t1`: Integration interval `[t0, t1]`.
75    /// - `dt`: Fixed step size.  Actual `n_steps = ceil((t1 - t0) / dt)`.
76    /// - `state_dim`: Dimensionality of the state vector.
77    ///
78    /// # Errors
79    ///
80    /// Returns [`IntegrateError::ValueError`] if `dt ≤ 0` or `t1 ≤ t0`.
81    pub fn new(rhs: F, t0: f64, t1: f64, dt: f64, state_dim: usize) -> Self {
82        let span = t1 - t0;
83        let n_steps = ((span / dt).ceil() as usize).max(1);
84        CachedOdeProblem {
85            rhs: Arc::new(rhs),
86            t0,
87            dt,
88            n_steps,
89            state_dim,
90        }
91    }
92
93    /// Integrate from initial state `y0` and return the final state.
94    ///
95    /// Pre-allocated scratch buffers are stack-allocated for `state_dim ≤ 16`
96    /// and heap-allocated otherwise; either way no allocation occurs inside
97    /// the RK4 loop itself.
98    pub fn integrate(&self, y0: &[f64]) -> Result<Vec<f64>, IntegrateError> {
99        if y0.len() != self.state_dim {
100            return Err(IntegrateError::DimensionMismatch(format!(
101                "y0.len()={} != state_dim={}",
102                y0.len(),
103                self.state_dim
104            )));
105        }
106
107        let dim = self.state_dim;
108        let mut y = y0.to_vec();
109        // Pre-allocate scratch once — reused across all steps
110        let mut k1 = vec![0.0_f64; dim];
111        let mut k2 = vec![0.0_f64; dim];
112        let mut k3 = vec![0.0_f64; dim];
113        let mut k4 = vec![0.0_f64; dim];
114        let mut ytmp = vec![0.0_f64; dim];
115
116        let rhs = &*self.rhs;
117        let mut t = self.t0;
118        let h = self.dt;
119
120        for _ in 0..self.n_steps {
121            // k1 = f(t, y)
122            rhs(t, &y, &mut k1);
123
124            // k2 = f(t + h/2, y + h/2 * k1)
125            for i in 0..dim {
126                ytmp[i] = y[i] + 0.5 * h * k1[i];
127            }
128            rhs(t + 0.5 * h, &ytmp, &mut k2);
129
130            // k3 = f(t + h/2, y + h/2 * k2)
131            for i in 0..dim {
132                ytmp[i] = y[i] + 0.5 * h * k2[i];
133            }
134            rhs(t + 0.5 * h, &ytmp, &mut k3);
135
136            // k4 = f(t + h, y + h * k3)
137            for i in 0..dim {
138                ytmp[i] = y[i] + h * k3[i];
139            }
140            rhs(t + h, &ytmp, &mut k4);
141
142            // y ← y + h/6 * (k1 + 2k2 + 2k3 + k4)
143            for i in 0..dim {
144                y[i] += (h / 6.0) * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
145            }
146            t += h;
147        }
148
149        Ok(y)
150    }
151
152    /// Batch integration — runs all initial conditions in parallel using Rayon.
153    ///
154    /// Returns one output state vector per input in `batch_y0`, in the same
155    /// order as the inputs.
156    pub fn integrate_batch(&self, batch_y0: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, IntegrateError> {
157        parallel_map_result(batch_y0, |y0| self.integrate(y0))
158    }
159
160    /// The step size used during integration.
161    pub fn dt(&self) -> f64 {
162        self.dt
163    }
164
165    /// Number of RK4 steps per `integrate` call.
166    pub fn n_steps(&self) -> usize {
167        self.n_steps
168    }
169
170    /// State dimension.
171    pub fn state_dim(&self) -> usize {
172        self.state_dim
173    }
174}
175
176// ─────────────────────────────────────────────────────────────────────────────
177// Async wrapper
178// ─────────────────────────────────────────────────────────────────────────────
179
180/// Async batch integration — offloads blocking Rayon work to a spawn_blocking thread.
181///
182/// This keeps the Tokio runtime responsive while allowing the full Rayon thread
183/// pool to be used for the actual computation.
184///
185/// Requires the `tokio` dependency to be available in the workspace.
186pub async fn integrate_batch_async<F>(
187    problem: Arc<CachedOdeProblem<F>>,
188    batch_y0: Vec<Vec<f64>>,
189) -> Result<Vec<Vec<f64>>, IntegrateError>
190where
191    F: Fn(f64, &[f64], &mut [f64]) + Send + Sync + 'static,
192{
193    tokio::task::spawn_blocking(move || problem.integrate_batch(&batch_y0))
194        .await
195        .map_err(|e| IntegrateError::ComputationError(format!("spawn_blocking panicked: {e}")))?
196}
197
198// ─────────────────────────────────────────────────────────────────────────────
199// Tests
200// ─────────────────────────────────────────────────────────────────────────────
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Right-hand side of dy/dt = -y (exact solution: y(t) = y0 · e^{-t}).
    fn exponential_decay() -> impl Fn(f64, &[f64], &mut [f64]) + Send + Sync + 'static {
        |_t, y, dydt| {
            dydt[0] = -y[0];
        }
    }

    /// One-dimensional exponential-decay problem on `[0, t1]` with step `dt`.
    fn decay_problem(
        t1: f64,
        dt: f64,
    ) -> CachedOdeProblem<impl Fn(f64, &[f64], &mut [f64]) + Send + Sync + 'static> {
        CachedOdeProblem::new(exponential_decay(), 0.0, t1, dt, 1)
    }

    #[test]
    fn test_cached_ode_exponential_decay() {
        let expected = std::f64::consts::E.recip(); // e^{-1}
        let result = decay_problem(1.0, 0.001)
            .integrate(&[1.0])
            .expect("integration failed");
        let error = (result[0] - expected).abs();
        assert!(error < 1e-5, "Expected ≈{expected:.6}, got {:.6}", result[0]);
    }

    #[test]
    fn test_batch_integration_matches_serial() {
        let problem = decay_problem(0.5, 0.001);
        let batch_y0 = vec![vec![1.0], vec![2.0], vec![0.5]];
        let batch_result = problem.integrate_batch(&batch_y0).expect("batch failed");

        batch_y0.iter().zip(&batch_result).for_each(|(y0, yr)| {
            let serial = problem.integrate(y0).expect("serial failed");
            assert!(
                (serial[0] - yr[0]).abs() < 1e-14,
                "Batch/serial mismatch: serial={:.10} batch={:.10}",
                serial[0],
                yr[0]
            );
        });
    }

    #[test]
    fn test_neural_ode_repeated_forward_same_result() {
        // The same y0 fed through the same problem must give bit-identical output.
        let problem = CachedOdeProblem::new(
            |_t, y, dydt| {
                dydt[0] = -y[0];
                dydt[1] = -2.0 * y[1];
            },
            0.0,
            1.0,
            0.01,
            2,
        );
        let y0 = [1.0, 1.0];
        let first = problem.integrate(&y0).expect("first forward failed");
        let second = problem.integrate(&y0).expect("second forward failed");
        let third = problem.integrate(&y0).expect("third forward failed");
        assert_eq!(first, second, "Results differ between calls 1 and 2");
        assert_eq!(first, third, "Results differ between calls 1 and 3");
    }

    #[test]
    fn test_dimension_mismatch_returns_error() {
        // A 2-element state handed to a 1-dimensional problem must be rejected.
        let outcome = decay_problem(1.0, 0.01).integrate(&[1.0, 2.0]);
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_async_batch_returns_correct_shape() {
        let problem = Arc::new(decay_problem(0.5, 0.01));
        let batch_y0: Vec<Vec<f64>> = (1..=4).map(|i| vec![f64::from(i)]).collect();
        let expected_len = batch_y0.len();

        let results = integrate_batch_async(problem, batch_y0)
            .await
            .expect("async batch failed");

        assert_eq!(results.len(), expected_len);
        results.iter().for_each(|r| {
            assert_eq!(r.len(), 1, "Each result must have state_dim=1 entries");
        });
    }

    #[tokio::test]
    async fn test_async_matches_sync() {
        let batch_y0 = vec![vec![1.0], vec![0.5], vec![2.0]];

        let async_results =
            integrate_batch_async(Arc::new(decay_problem(1.0, 0.001)), batch_y0.clone())
                .await
                .expect("async failed");
        let sync_results = decay_problem(1.0, 0.001)
            .integrate_batch(&batch_y0)
            .expect("sync failed");

        for (a, s) in async_results.iter().zip(&sync_results) {
            assert!(
                (a[0] - s[0]).abs() < 1e-14,
                "Async/sync mismatch: {:.10} vs {:.10}",
                a[0],
                s[0]
            );
        }
    }
}