scirs2_integrate/ode/methods/
simd_explicit.rs

1//! SIMD-accelerated explicit ODE solver methods
2//!
3//! This module provides SIMD-optimized versions of explicit ODE solvers,
4//! offering significant performance improvements for large systems of ODEs
5//! on modern processors with SIMD instruction sets.
6
7use crate::common::IntegrateFloat;
8use crate::error::IntegrateResult;
9use crate::ode::types::{ODEOptions, ODEResult};
10use crate::ode::utils::common::{estimate_initial_step, ODEState, StepResult};
11use scirs2_core::ndarray::{Array1, ArrayView1};
12
13#[cfg(feature = "simd")]
14use crate::ode::utils::simd_ops::SimdOdeOps;
15use scirs2_core::simd_ops::SimdUnifiedOps;
16
/// SIMD-accelerated 4th-order Runge-Kutta method
///
/// Classical fixed-step RK4 in which every vector update of the integration
/// step is performed through the `SimdUnifiedOps` primitives, which is where
/// large systems of ODEs benefit from SIMD instructions.
///
/// # Arguments
///
/// * `f` - ODE function dy/dt = f(t, y)
/// * `t_span` - Time span [t_start, t_end]
/// * `y0` - Initial condition
/// * `opts` - Solver options. `opts.h0` is used as the fixed step size; when it
///   is `None` a step is estimated from `opts.atol + opts.rtol`.
///
/// # Returns
///
/// The solution as an ODEResult or an error. The final step is shortened so
/// that the last stored time is exactly `t_end`.
///
/// # Errors
///
/// * `IntegrateError::StepSizeTooSmall` if the step size is not positive
///   (zero, negative or NaN) while `t_start < t_end`.
/// * `IntegrateError::ConvergenceError` if more than 1 000 000 steps are taken.
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simd_rk4_method<F, Func>(
    f: Func,
    t_span: [F; 2],
    y0: Array1<F>,
    opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
    F: IntegrateFloat + SimdUnifiedOps,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let [t_start, t_end] = t_span;

    // Determine step size
    let h = opts.h0.unwrap_or_else(|| {
        let dy0 = f(t_start, y0.view());
        let tol = opts.atol + opts.rtol;
        estimate_initial_step(&f, t_start, &y0, &dy0, tol, t_end)
    });

    // A zero, negative or NaN step can never carry `t` towards `t_end`: it would
    // either burn through the step limit (storing a copy of `y` every step) or,
    // for NaN, leave the loop with a NaN time reported as success.
    let step_is_valid = h > F::zero();
    if t_start < t_end && !step_is_valid {
        return Err(crate::error::IntegrateError::StepSizeTooSmall(
            "Step size must be positive in SIMD RK4 method".to_string(),
        ));
    }

    // Storage for solution
    let mut t_values = vec![t_start];
    let mut y_values = vec![y0.clone()];

    let mut t = t_start;
    let mut y = y0;
    let mut steps = 0;
    let mut func_evals = 0;

    while t < t_end {
        // Shorten the final step so it lands on t_end, and snap `t` to t_end
        // exactly so the last stored time carries no rounding error from
        // `t + (t_end - t)`.
        let (h_current, t_next) = if t + h >= t_end {
            (t_end - t, t_end)
        } else {
            (h, t + h)
        };

        // SIMD-accelerated RK4 step
        let (y_new, n_evals) = simd_rk4_step(&f, t, &y.view(), h_current)?;
        func_evals += n_evals;

        // Update state
        t = t_next;
        y = y_new;
        steps += 1;

        // Store solution
        t_values.push(t);
        y_values.push(y.clone());

        // Safety check
        if steps > 1_000_000 {
            return Err(crate::error::IntegrateError::ConvergenceError(
                "Maximum number of steps exceeded in SIMD RK4 method".to_string(),
            ));
        }
    }

    Ok(ODEResult {
        t: t_values,
        y: y_values,
        n_steps: steps,
        n_eval: func_evals,
        n_accepted: steps,
        n_rejected: 0,
        n_lu: 0,
        n_jac: 0,
        method: crate::ode::types::ODEMethod::RK4,
        success: true,
        message: Some("Integration completed successfully".to_string()),
    })
}
103
/// SIMD-accelerated adaptive Runge-Kutta method (RK45)
///
/// This method uses embedded Dormand-Prince 5(4) formulas with SIMD
/// acceleration for both the integration steps and error estimation.
///
/// # Arguments
///
/// * `f` - ODE function dy/dt = f(t, y)
/// * `t_span` - Time span [t_start, t_end]
/// * `y0` - Initial condition
/// * `opts` - Solver options including tolerances (`atol`, `rtol`), an optional
///   initial step `h0`, and optional `min_step` (default 1e-12) / `max_step`
///   (default one tenth of the interval)
///
/// # Returns
///
/// The solution as an ODEResult or an error. `n_steps` and `n_accepted` both
/// count accepted steps; `n_rejected` counts rejected attempts. The last stored
/// time is exactly `t_end`.
///
/// # Errors
///
/// * `IntegrateError::StepSizeTooSmall` if a rejected step shrinks the step
///   size below `min_step`. This is also how a right-hand side that keeps
///   producing non-finite values is reported.
/// * `IntegrateError::ConvergenceError` if more than 100 000 steps are accepted.
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simd_rk45_method<F, Func>(
    f: Func,
    t_span: [F; 2],
    y0: Array1<F>,
    opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
    F: IntegrateFloat + SimdUnifiedOps,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let [t_start, t_end] = t_span;

    // Initial step size
    let mut h = opts.h0.unwrap_or_else(|| {
        let dy0 = f(t_start, y0.view());
        let tol = opts.atol + opts.rtol;
        estimate_initial_step(&f, t_start, &y0, &dy0, tol, t_end)
    });

    let min_step = opts.min_step.unwrap_or(F::from_f64(1e-12).unwrap());
    let max_step = opts
        .max_step
        .unwrap_or((t_end - t_start) / F::from_f64(10.0).unwrap());
    let abs_tol = opts.atol;
    let rel_tol = opts.rtol;

    // Step size controller constants (matching non-SIMD version)
    let order = F::from_f64(5.0).unwrap();
    let exponent = F::one() / (order + F::one());
    let safety = F::from_f64(0.9).unwrap();
    let factor_min = F::from_f64(0.2).unwrap();
    let factor_max = F::from_f64(5.0).unwrap();
    let small_error = F::from_f64(0.1).unwrap();
    let min_growth = F::from_f64(2.0).unwrap();

    // Storage for solution
    let mut t_values = vec![t_start];
    let mut y_values = vec![y0.clone()];

    let mut t = t_start;
    let mut y = y0;
    let mut accepted_steps = 0;
    let mut rejected_steps = 0;
    let mut func_evals = 0;

    while t < t_end {
        // Enforce the step bounds first and only then truncate the step to the
        // end of the interval; the other order let `min_step` push the final
        // step past t_end.
        h = h.min(max_step).max(min_step);
        let reaches_end = t + h >= t_end;
        if reaches_end {
            h = t_end - t;
        }

        // SIMD-accelerated RK45 step with error estimation
        let (y_new, y_star, n_evals) = simd_rk45_step(&f, t, &y.view(), h)?;
        func_evals += n_evals;

        // Scaled max-norm of the embedded error estimate (matching non-SIMD
        // version). `Float::max` silently discards NaN, so NaN is propagated
        // explicitly: a non-finite solution must reject the step rather than
        // be accepted as error-free.
        let err_norm = y_new
            .iter()
            .zip(y_star.iter())
            .fold(F::zero(), |acc, (&hi, &lo)| {
                let diff = (hi - lo).abs();
                if diff == F::zero() {
                    // Zero error; also avoids 0/0 when atol == 0 and the component is 0.
                    return acc;
                }
                let err = diff / (abs_tol + rel_tol * hi.abs());
                if err.is_nan() || err > acc {
                    err
                } else {
                    acc
                }
            });

        // A non-finite error gives no usable scaling information: shrink hard.
        let factor = if err_norm.is_finite() {
            (safety * (F::one() / err_norm).powf(exponent))
                .min(factor_max)
                .max(factor_min)
        } else {
            factor_min
        };

        if err_norm <= F::one() {
            // Accept step; snap onto t_end on the final step so rounding in
            // `t + h` cannot leave a sliver of the interval.
            t = if reaches_end { t_end } else { t + h };
            y = y_new;
            accepted_steps += 1;

            // Store solution
            t_values.push(t);
            y_values.push(y.clone());

            // Grow at least 2x after a very accurate step
            if err_norm <= small_error {
                h *= factor.max(min_growth);
            } else {
                h *= factor;
            }
        } else {
            // Reject step (NaN error norms land here as well)
            rejected_steps += 1;
            h *= factor.min(F::one());

            // Check minimum step size
            if h < min_step {
                return Err(crate::error::IntegrateError::StepSizeTooSmall(
                    "Step size became too small in SIMD RK45 method".to_string(),
                ));
            }
        }

        // Safety check
        if accepted_steps > 100_000 {
            return Err(crate::error::IntegrateError::ConvergenceError(
                "Maximum number of steps exceeded in SIMD RK45 method".to_string(),
            ));
        }
    }

    Ok(ODEResult {
        t: t_values,
        y: y_values,
        n_steps: accepted_steps,
        n_eval: func_evals,
        // `accepted_steps` already excludes rejections; subtracting them again
        // undercounted and could underflow.
        n_accepted: accepted_steps,
        n_rejected: rejected_steps,
        n_lu: 0,
        n_jac: 0,
        method: crate::ode::types::ODEMethod::RK45,
        success: true,
        message: Some("Integration completed successfully".to_string()),
    })
}
238
/// Perform a single SIMD-accelerated RK4 step
///
/// Evaluates the four classical RK4 stages at `t`, `t + h/2` (twice) and
/// `t + h`. Returns `y + h/6 * (k1 + 2*k2 + 2*k3 + k4)` together with the
/// number of right-hand-side evaluations used (always 4).
#[cfg(feature = "simd")]
#[allow(dead_code)]
fn simd_rk4_step<F, Func>(
    f: &Func,
    t: F,
    y: &ArrayView1<F>,
    h: F,
) -> IntegrateResult<(Array1<F>, usize)>
where
    F: IntegrateFloat + SimdUnifiedOps,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let h_half = h * F::from_f64(0.5).unwrap();

    // k1 = f(t, y). Re-borrow the caller's view instead of copying `y` into a
    // fresh array on every step.
    let k1 = f(t, y.view());

    // k2 = f(t + h/2, y + h/2 * k1)
    let y_temp1 = F::simd_add(y, &F::simd_scalar_mul(&k1.view(), h_half).view());
    let k2 = f(t + h_half, y_temp1.view());

    // k3 = f(t + h/2, y + h/2 * k2)
    let y_temp2 = F::simd_add(y, &F::simd_scalar_mul(&k2.view(), h_half).view());
    let k3 = f(t + h_half, y_temp2.view());

    // k4 = f(t + h, y + h * k3)
    let y_temp3 = F::simd_add(y, &F::simd_scalar_mul(&k3.view(), h).view());
    let k4 = f(t + h, y_temp3.view());

    // y_new = y + h/6 * (k1 + 2*k2 + 2*k3 + k4)
    let c1 = F::one() / F::from_f64(6.0).unwrap();
    let c2 = F::from_f64(2.0).unwrap() / F::from_f64(6.0).unwrap();

    // Weighted increments h*k1/6, h*k2/3, h*k3/3, h*k4/6
    let term1 = F::simd_scalar_mul(&k1.view(), c1 * h);
    let term2 = F::simd_scalar_mul(&k2.view(), c2 * h);
    let term3 = F::simd_scalar_mul(&k3.view(), c2 * h);
    let term4 = F::simd_scalar_mul(&k4.view(), c1 * h);

    // Sum the (small) increments pairwise before adding them to `y`
    let sum12 = F::simd_add(&term1.view(), &term2.view());
    let sum34 = F::simd_add(&term3.view(), &term4.view());
    let weighted_sum = F::simd_add(&sum12.view(), &sum34.view());

    let y_new = F::simd_add(y, &weighted_sum.view());

    Ok((y_new, 4)) // 4 function evaluations
}
287
/// Perform a single SIMD-accelerated RK45 step with error estimation
///
/// Evaluates the seven Dormand-Prince stages and returns
/// `(y_new, y_star, 7)`:
/// * `y_new` is the 5th-order solution, which is propagated.
/// * `y_star` is the embedded 4th-order solution; `y_new - y_star` serves as
///   the local error estimate.
/// * `7` is the number of right-hand-side evaluations used.
#[cfg(feature = "simd")]
#[allow(dead_code)]
fn simd_rk45_step<F, Func>(
    f: &Func,
    t: F,
    y: &ArrayView1<F>,
    h: F,
) -> IntegrateResult<(Array1<F>, Array1<F>, usize)>
where
    F: IntegrateFloat + SimdUnifiedOps,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    // Dormand-Prince coefficients
    let a21 = F::from_f64(1.0 / 5.0).unwrap();
    let a31 = F::from_f64(3.0 / 40.0).unwrap();
    let a32 = F::from_f64(9.0 / 40.0).unwrap();
    let a41 = F::from_f64(44.0 / 45.0).unwrap();
    let a42 = F::from_f64(-56.0 / 15.0).unwrap();
    let a43 = F::from_f64(32.0 / 9.0).unwrap();
    let a51 = F::from_f64(19372.0 / 6561.0).unwrap();
    let a52 = F::from_f64(-25360.0 / 2187.0).unwrap();
    let a53 = F::from_f64(64448.0 / 6561.0).unwrap();
    let a54 = F::from_f64(-212.0 / 729.0).unwrap();
    let a61 = F::from_f64(9017.0 / 3168.0).unwrap();
    let a62 = F::from_f64(-355.0 / 33.0).unwrap();
    let a63 = F::from_f64(46732.0 / 5247.0).unwrap();
    let a64 = F::from_f64(49.0 / 176.0).unwrap();
    let a65 = F::from_f64(-5103.0 / 18656.0).unwrap();

    // k1 = f(t, y). Re-borrow the caller's view instead of copying `y` into a
    // fresh array on every step.
    let k1 = f(t, y.view());

    // k2 = f(t + h/5, y + h/5 * k1)
    let y2 = F::simd_add(y, &F::simd_scalar_mul(&k1.view(), h * a21).view());
    let k2 = f(t + h * a21, y2.view());

    // k3 = f(t + 3h/10, y + h * (3/40 * k1 + 9/40 * k2))
    let term1 = F::simd_scalar_mul(&k1.view(), a31 * h);
    let term2 = F::simd_scalar_mul(&k2.view(), a32 * h);
    let y3 = F::simd_add(y, &F::simd_add(&term1.view(), &term2.view()).view());
    let k3 = f(t + h * F::from_f64(3.0 / 10.0).unwrap(), y3.view());

    // k4 = f(t + 4h/5, y + h * (44/45 * k1 - 56/15 * k2 + 32/9 * k3))
    let t1 = F::simd_scalar_mul(&k1.view(), a41 * h);
    let t2 = F::simd_scalar_mul(&k2.view(), a42 * h);
    let t3 = F::simd_scalar_mul(&k3.view(), a43 * h);
    let y4 = F::simd_add(
        y,
        &F::simd_add(&F::simd_add(&t1.view(), &t2.view()).view(), &t3.view()).view(),
    );
    let k4 = f(t + h * F::from_f64(4.0 / 5.0).unwrap(), y4.view());

    // k5 = f(t + 8h/9, y + h * (a51*k1 + a52*k2 + a53*k3 + a54*k4))
    let r1 = F::simd_scalar_mul(&k1.view(), a51 * h);
    let r2 = F::simd_scalar_mul(&k2.view(), a52 * h);
    let r3 = F::simd_scalar_mul(&k3.view(), a53 * h);
    let r4 = F::simd_scalar_mul(&k4.view(), a54 * h);
    let sum1 = F::simd_add(&r1.view(), &r2.view());
    let sum2 = F::simd_add(&r3.view(), &r4.view());
    let y5 = F::simd_add(y, &F::simd_add(&sum1.view(), &sum2.view()).view());
    let k5 = f(t + h * F::from_f64(8.0 / 9.0).unwrap(), y5.view());

    // k6 = f(t + h, y + h * (a61*k1 + a62*k2 + a63*k3 + a64*k4 + a65*k5))
    let s1 = F::simd_scalar_mul(&k1.view(), a61 * h);
    let s2 = F::simd_scalar_mul(&k2.view(), a62 * h);
    let s3 = F::simd_scalar_mul(&k3.view(), a63 * h);
    let s4 = F::simd_scalar_mul(&k4.view(), a64 * h);
    let s5 = F::simd_scalar_mul(&k5.view(), a65 * h);
    let ssum1 = F::simd_add(&s1.view(), &s2.view());
    let ssum2 = F::simd_add(&s3.view(), &s4.view());
    let ssum3 = F::simd_add(&ssum1.view(), &ssum2.view());
    let y6 = F::simd_add(y, &F::simd_add(&ssum3.view(), &s5.view()).view());
    let k6 = f(t + h, y6.view());

    // 5th-order solution (the b2 weight is zero, so k2 does not appear)
    let b1 = F::from_f64(35.0 / 384.0).unwrap();
    let b3 = F::from_f64(500.0 / 1113.0).unwrap();
    let b4 = F::from_f64(125.0 / 192.0).unwrap();
    let b5 = F::from_f64(-2187.0 / 6784.0).unwrap();
    let b6 = F::from_f64(11.0 / 84.0).unwrap();

    let w1 = F::simd_scalar_mul(&k1.view(), b1 * h);
    let w3 = F::simd_scalar_mul(&k3.view(), b3 * h);
    let w4 = F::simd_scalar_mul(&k4.view(), b4 * h);
    let w5 = F::simd_scalar_mul(&k5.view(), b5 * h);
    let w6 = F::simd_scalar_mul(&k6.view(), b6 * h);
    let wsum1 = F::simd_add(&w1.view(), &w3.view());
    let wsum2 = F::simd_add(&w4.view(), &w5.view());
    let wsum3 = F::simd_add(&wsum1.view(), &wsum2.view());
    let y_new = F::simd_add(y, &F::simd_add(&wsum3.view(), &w6.view()).view());

    // k7 = f(t + h, y_new), needed by the embedded 4th-order solution. This is
    // the "first same as last" stage: after an accepted step it equals the
    // next step's k1, which is currently re-evaluated rather than reused.
    let k7 = f(t + h, y_new.view());

    // 4th-order solution for error estimation (includes k7)
    let b1_star = F::from_f64(5179.0 / 57600.0).unwrap();
    let b3_star = F::from_f64(7571.0 / 16695.0).unwrap();
    let b4_star = F::from_f64(393.0 / 640.0).unwrap();
    let b5_star = F::from_f64(-92097.0 / 339200.0).unwrap();
    let b6_star = F::from_f64(187.0 / 2100.0).unwrap();
    let b7_star = F::from_f64(1.0 / 40.0).unwrap();

    let v1 = F::simd_scalar_mul(&k1.view(), b1_star * h);
    let v3 = F::simd_scalar_mul(&k3.view(), b3_star * h);
    let v4 = F::simd_scalar_mul(&k4.view(), b4_star * h);
    let v5 = F::simd_scalar_mul(&k5.view(), b5_star * h);
    let v6 = F::simd_scalar_mul(&k6.view(), b6_star * h);
    let v7 = F::simd_scalar_mul(&k7.view(), b7_star * h);
    let vsum1 = F::simd_add(&v1.view(), &v3.view());
    let vsum2 = F::simd_add(&v4.view(), &v5.view());
    let vsum3 = F::simd_add(&v6.view(), &v7.view());
    let vsum4 = F::simd_add(&vsum1.view(), &vsum2.view());
    let y_star = F::simd_add(y, &F::simd_add(&vsum4.view(), &vsum3.view()).view());

    // Return both 5th and 4th order solutions for error estimation
    Ok((y_new, y_star, 7)) // 7 function evaluations
}
406
407/// Fallback methods when SIMD is not available
408#[cfg(not(feature = "simd"))]
409#[allow(dead_code)]
410pub fn simd_rk4_method<F, Func>(
411    f: Func,
412    t_span: [F; 2],
413    y0: Array1<F>,
414    opts: ODEOptions<F>,
415) -> IntegrateResult<ODEResult<F>>
416where
417    F: IntegrateFloat + SimdUnifiedOps,
418    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
419{
420    // Fallback to regular RK4 method
421    let h = opts.h0.unwrap_or_else(|| {
422        let dy0 = f(t_span[0], y0.view());
423        let tol = opts.atol + opts.rtol;
424        estimate_initial_step(&f, t_span[0], &y0, &dy0, tol, t_span[1])
425    });
426    crate::ode::methods::explicit::rk4_method(f, t_span, y0, h, opts)
427}
428
429#[cfg(not(feature = "simd"))]
430#[allow(dead_code)]
431pub fn simd_rk45_method<F, Func>(
432    f: Func,
433    t_span: [F; 2],
434    y0: Array1<F>,
435    opts: ODEOptions<F>,
436) -> IntegrateResult<ODEResult<F>>
437where
438    F: IntegrateFloat + SimdUnifiedOps,
439    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
440{
441    // Fallback to regular RK45 method
442    crate::ode::methods::adaptive::rk45_method(f, t_span, y0, opts)
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::arr1;

    /// Exponential decay dy/dt = -y, y(0) = 1, integrated to t = 1 with a
    /// fixed step of 0.1; the exact value there is exp(-1) ≈ 0.36788.
    #[test]
    #[cfg(feature = "simd")]
    fn test_simd_rk4_simple() {
        let decay = |_t: f64, state: ArrayView1<f64>| -> Array1<f64> { state.mapv(|v| -v) };

        let options = ODEOptions {
            h0: Some(0.1),
            ..Default::default()
        };

        let solution = simd_rk4_method(decay, [0.0, 1.0], arr1(&[1.0]), options).unwrap();

        let expected = (-1.0_f64).exp();
        let y_end = solution.y.last().unwrap()[0];

        assert!(solution.success);
        assert_relative_eq!(y_end, expected, epsilon = 1e-3);
    }

    /// Harmonic oscillator y'' + y = 0 written as the first-order system
    /// (y1' = y2, y2' = -y1), integrated over half a period: starting from
    /// (1, 0) the exact state at t = π is (-1, 0).
    #[test]
    #[cfg(feature = "simd")]
    fn test_simd_rk45_adaptive() {
        let oscillator =
            |_t: f64, state: ArrayView1<f64>| -> Array1<f64> { arr1(&[state[1], -state[0]]) };

        let options = ODEOptions {
            atol: 1e-8,
            rtol: 1e-8,
            h0: Some(0.1),
            ..Default::default()
        };

        let solution = simd_rk45_method(
            oscillator,
            [0.0, std::f64::consts::PI],
            arr1(&[1.0, 0.0]),
            options,
        )
        .unwrap();

        let state_end = solution.y.last().unwrap();

        assert!(solution.success);
        assert_relative_eq!(state_end[0], -1.0, epsilon = 1e-6);
        assert_relative_eq!(state_end[1], 0.0, epsilon = 1e-6);
    }
}