scirs2_integrate/async_ode/mod.rs
1//! Async/cached ODE solver for neural ODE training patterns.
2//!
3//! # Motivation
4//!
5//! Neural ODE training requires integrating the same ODE structure many times
6//! per batch — once per forward pass and once per adjoint backward pass.
7//! The "CUDA graph capture" concept (repeated ODE solve patterns) maps in
8//! pure Rust to:
9//!
//! 1. **Pre-allocated scratch buffers**: each solve allocates its RK4 scratch
//!    buffers once up front, so no heap allocation happens inside the step loop.
12//! 2. **Rayon batch parallelism**: `integrate_batch` runs all initial
13//! conditions concurrently using Rayon's work-stealing pool.
14//! 3. **Tokio async wrapper**: `integrate_batch_async` offloads the blocking
15//! rayon call onto a `spawn_blocking` thread so the async runtime stays
16//! responsive.
17//!
18//! # Solver
19//!
//! A fixed-step RK4 integrator is used. The number of steps and the step size
//! are fixed at construction time from the interval `[t0, t1]` and the
//! requested `dt`. This predictable structure (unlike adaptive step-size
//! methods) is what enables "graph capture": the sequence of RHS evaluations
//! is identical for every call, making it JIT-friendly.
25//!
26//! # Example
27//!
28//! ```rust
29//! use scirs2_integrate::async_ode::{CachedOdeProblem, integrate_batch_async};
30//! use std::sync::Arc;
31//!
32//! // dy/dt = -y → y(t) = y0 * exp(-t)
33//! let problem = Arc::new(
34//! CachedOdeProblem::new(|_t, y, dydt| { dydt[0] = -y[0]; }, 0.0, 1.0, 0.01, 1)
35//! );
36//!
37//! let result = problem.integrate(&[1.0]).unwrap();
38//! let expected = 1.0_f64.exp().recip(); // e^{-1} ≈ 0.3679
39//! assert!((result[0] - expected).abs() < 1e-4);
40//! ```
41
42use crate::error::IntegrateError;
43use scirs2_core::parallel_ops::*;
44use std::sync::Arc;
45
46// ─────────────────────────────────────────────────────────────────────────────
47// CachedOdeProblem
48// ─────────────────────────────────────────────────────────────────────────────
49
50/// A pre-compiled ODE problem for repeated integration (neural ODE pattern).
51///
52/// Once constructed the RK4 graph is fixed: `n_steps` evaluations of `rhs`
53/// are performed per `integrate` call, with no re-allocation of step vectors.
54pub struct CachedOdeProblem<F>
55where
56 F: Fn(f64, &[f64], &mut [f64]) + Send + Sync,
57{
58 rhs: Arc<F>,
59 t0: f64,
60 dt: f64,
61 n_steps: usize,
62 state_dim: usize,
63}
64
65impl<F> CachedOdeProblem<F>
66where
67 F: Fn(f64, &[f64], &mut [f64]) + Send + Sync + 'static,
68{
69 /// Create a new cached ODE problem.
70 ///
71 /// # Parameters
72 ///
73 /// - `rhs`: Right-hand side `f(t, y, dydt)` — writes into `dydt`.
74 /// - `t0`, `t1`: Integration interval `[t0, t1]`.
75 /// - `dt`: Fixed step size. Actual `n_steps = ceil((t1 - t0) / dt)`.
76 /// - `state_dim`: Dimensionality of the state vector.
77 ///
78 /// # Errors
79 ///
80 /// Returns [`IntegrateError::ValueError`] if `dt ≤ 0` or `t1 ≤ t0`.
81 pub fn new(rhs: F, t0: f64, t1: f64, dt: f64, state_dim: usize) -> Self {
82 let span = t1 - t0;
83 let n_steps = ((span / dt).ceil() as usize).max(1);
84 CachedOdeProblem {
85 rhs: Arc::new(rhs),
86 t0,
87 dt,
88 n_steps,
89 state_dim,
90 }
91 }
92
93 /// Integrate from initial state `y0` and return the final state.
94 ///
95 /// Pre-allocated scratch buffers are stack-allocated for `state_dim ≤ 16`
96 /// and heap-allocated otherwise; either way no allocation occurs inside
97 /// the RK4 loop itself.
98 pub fn integrate(&self, y0: &[f64]) -> Result<Vec<f64>, IntegrateError> {
99 if y0.len() != self.state_dim {
100 return Err(IntegrateError::DimensionMismatch(format!(
101 "y0.len()={} != state_dim={}",
102 y0.len(),
103 self.state_dim
104 )));
105 }
106
107 let dim = self.state_dim;
108 let mut y = y0.to_vec();
109 // Pre-allocate scratch once — reused across all steps
110 let mut k1 = vec![0.0_f64; dim];
111 let mut k2 = vec![0.0_f64; dim];
112 let mut k3 = vec![0.0_f64; dim];
113 let mut k4 = vec![0.0_f64; dim];
114 let mut ytmp = vec![0.0_f64; dim];
115
116 let rhs = &*self.rhs;
117 let mut t = self.t0;
118 let h = self.dt;
119
120 for _ in 0..self.n_steps {
121 // k1 = f(t, y)
122 rhs(t, &y, &mut k1);
123
124 // k2 = f(t + h/2, y + h/2 * k1)
125 for i in 0..dim {
126 ytmp[i] = y[i] + 0.5 * h * k1[i];
127 }
128 rhs(t + 0.5 * h, &ytmp, &mut k2);
129
130 // k3 = f(t + h/2, y + h/2 * k2)
131 for i in 0..dim {
132 ytmp[i] = y[i] + 0.5 * h * k2[i];
133 }
134 rhs(t + 0.5 * h, &ytmp, &mut k3);
135
136 // k4 = f(t + h, y + h * k3)
137 for i in 0..dim {
138 ytmp[i] = y[i] + h * k3[i];
139 }
140 rhs(t + h, &ytmp, &mut k4);
141
142 // y ← y + h/6 * (k1 + 2k2 + 2k3 + k4)
143 for i in 0..dim {
144 y[i] += (h / 6.0) * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
145 }
146 t += h;
147 }
148
149 Ok(y)
150 }
151
152 /// Batch integration — runs all initial conditions in parallel using Rayon.
153 ///
154 /// Returns one output state vector per input in `batch_y0`, in the same
155 /// order as the inputs.
156 pub fn integrate_batch(&self, batch_y0: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, IntegrateError> {
157 parallel_map_result(batch_y0, |y0| self.integrate(y0))
158 }
159
160 /// The step size used during integration.
161 pub fn dt(&self) -> f64 {
162 self.dt
163 }
164
165 /// Number of RK4 steps per `integrate` call.
166 pub fn n_steps(&self) -> usize {
167 self.n_steps
168 }
169
170 /// State dimension.
171 pub fn state_dim(&self) -> usize {
172 self.state_dim
173 }
174}
175
176// ─────────────────────────────────────────────────────────────────────────────
177// Async wrapper
178// ─────────────────────────────────────────────────────────────────────────────
179
180/// Async batch integration — offloads blocking Rayon work to a spawn_blocking thread.
181///
182/// This keeps the Tokio runtime responsive while allowing the full Rayon thread
183/// pool to be used for the actual computation.
184///
185/// Requires the `tokio` dependency to be available in the workspace.
186pub async fn integrate_batch_async<F>(
187 problem: Arc<CachedOdeProblem<F>>,
188 batch_y0: Vec<Vec<f64>>,
189) -> Result<Vec<Vec<f64>>, IntegrateError>
190where
191 F: Fn(f64, &[f64], &mut [f64]) + Send + Sync + 'static,
192{
193 tokio::task::spawn_blocking(move || problem.integrate_batch(&batch_y0))
194 .await
195 .map_err(|e| IntegrateError::ComputationError(format!("spawn_blocking panicked: {e}")))?
196}
197
198// ─────────────────────────────────────────────────────────────────────────────
199// Tests
200// ─────────────────────────────────────────────────────────────────────────────
201
#[cfg(test)]
mod tests {
    use super::*;

    /// dy/dt = -y, exact solution y(t) = e^{-t}
    fn exponential_decay() -> impl Fn(f64, &[f64], &mut [f64]) + Send + Sync + 'static {
        |_t, y, dydt| {
            dydt[0] = -y[0];
        }
    }

    #[test]
    fn test_cached_ode_exponential_decay() {
        let problem = CachedOdeProblem::new(exponential_decay(), 0.0, 1.0, 0.001, 1);
        let result = problem.integrate(&[1.0]).expect("integration failed");
        let expected = std::f64::consts::E.recip(); // e^{-1}
        assert!(
            (result[0] - expected).abs() < 1e-5,
            "Expected ≈{expected:.6}, got {:.6}",
            result[0]
        );
    }

    #[test]
    fn test_batch_integration_matches_serial() {
        let problem = CachedOdeProblem::new(exponential_decay(), 0.0, 0.5, 0.001, 1);
        let batch_y0 = vec![vec![1.0], vec![2.0], vec![0.5]];
        let batch_result = problem.integrate_batch(&batch_y0).expect("batch failed");

        // `zip` stops at the shorter side, so pin the length first; otherwise a
        // short (or empty) batch result would make the loop below pass vacuously.
        assert_eq!(
            batch_result.len(),
            batch_y0.len(),
            "Batch must return one result per input"
        );
        for (y0, yr) in batch_y0.iter().zip(batch_result.iter()) {
            let serial = problem.integrate(y0).expect("serial failed");
            assert!(
                (serial[0] - yr[0]).abs() < 1e-14,
                "Batch/serial mismatch: serial={:.10} batch={:.10}",
                serial[0],
                yr[0]
            );
        }
    }

    #[test]
    fn test_neural_ode_repeated_forward_same_result() {
        // Repeated calls with the same y0 must return identical results (deterministic)
        let problem = CachedOdeProblem::new(
            |_t, y, dydt| {
                dydt[0] = -y[0];
                dydt[1] = -2.0 * y[1];
            },
            0.0,
            1.0,
            0.01,
            2,
        );
        let y0 = vec![1.0, 1.0];
        let r1 = problem.integrate(&y0).expect("first forward failed");
        let r2 = problem.integrate(&y0).expect("second forward failed");
        let r3 = problem.integrate(&y0).expect("third forward failed");
        assert_eq!(r1, r2, "Results differ between calls 1 and 2");
        assert_eq!(r1, r3, "Results differ between calls 1 and 3");
    }

    #[test]
    fn test_dimension_mismatch_returns_error() {
        let problem = CachedOdeProblem::new(exponential_decay(), 0.0, 1.0, 0.01, 1);
        // Pass 2-element state to 1-dim problem; must be the dimension error,
        // not some unrelated failure.
        assert!(matches!(
            problem.integrate(&[1.0, 2.0]),
            Err(IntegrateError::DimensionMismatch(_))
        ));
    }

    #[tokio::test]
    async fn test_async_batch_returns_correct_shape() {
        let problem = Arc::new(CachedOdeProblem::new(
            exponential_decay(),
            0.0,
            0.5,
            0.01,
            1,
        ));
        let batch_y0 = vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]];
        let expected_len = batch_y0.len();
        let results = integrate_batch_async(problem, batch_y0)
            .await
            .expect("async batch failed");
        assert_eq!(results.len(), expected_len);
        for r in &results {
            assert_eq!(r.len(), 1, "Each result must have state_dim=1 entries");
        }
    }

    #[tokio::test]
    async fn test_async_matches_sync() {
        let problem_async = Arc::new(CachedOdeProblem::new(
            exponential_decay(),
            0.0,
            1.0,
            0.001,
            1,
        ));
        let problem_sync = CachedOdeProblem::new(exponential_decay(), 0.0, 1.0, 0.001, 1);
        let batch_y0 = vec![vec![1.0], vec![0.5], vec![2.0]];
        let async_results = integrate_batch_async(problem_async, batch_y0.clone())
            .await
            .expect("async failed");
        let sync_results = problem_sync
            .integrate_batch(&batch_y0)
            .expect("sync failed");

        // Guard against a vacuous pass of the zip-based comparison below.
        assert_eq!(async_results.len(), batch_y0.len());
        assert_eq!(sync_results.len(), batch_y0.len());
        for (a, s) in async_results.iter().zip(sync_results.iter()) {
            assert!(
                (a[0] - s[0]).abs() < 1e-14,
                "Async/sync mismatch: {:.10} vs {:.10}",
                a[0],
                s[0]
            );
        }
    }
}
328}