Skip to main content

scirs2_integrate/ode/methods/
symplectic.rs

//! Symplectic integrators for Hamiltonian systems within the ODE framework
//!
//! This module provides symplectic integration methods that preserve the
//! geometric structure of Hamiltonian systems. They are particularly useful
//! for long-time integration of conservative systems, where good long-term
//! energy behaviour is critical.
//!
//! # Methods provided
//!
//! - **Stormer-Verlet (Leapfrog)**: 2nd order, the most widely used scheme
//! - **Velocity Verlet**: 2nd order, the kick-drift-kick form common in molecular dynamics
//! - **Yoshida 4th order**: higher accuracy via a 3-stage triple-jump composition
//! - **Yoshida 6th order**: even higher accuracy, 7-stage symmetric composition
//! - **Yoshida 8th order**: very high accuracy, 15-stage symmetric composition (Kahan & Li constants)
//!
//! # Energy monitoring
//!
//! When the system provides a closed-form Hamiltonian, integration records it at
//! every step in an `EnergyMonitor`, which tracks per-step absolute errors plus the
//! maximum absolute, maximum relative and mean absolute energy error, enabling early
//! detection of numerical instabilities.
21
22use crate::common::IntegrateFloat;
23use crate::error::{IntegrateError, IntegrateResult};
24use scirs2_core::ndarray::Array1;
25use std::marker::PhantomData;
26
27// ---------------------------------------------------------------------------
28// Hamiltonian system trait (ODE-methods flavour, independent of symplectic/)
29// ---------------------------------------------------------------------------
30
31/// Trait representing a Hamiltonian system for symplectic integration.
32///
33/// The system is described by Hamilton's equations:
34///   dq/dt =  dH/dp
35///   dp/dt = -dH/dq
36pub trait HamiltonianSystem<F: IntegrateFloat> {
37    /// Number of degrees of freedom (dimension of q or p).
38    fn ndof(&self) -> usize;
39
40    /// Compute dq/dt = dH/dp evaluated at (t, q, p).
41    fn dq_dt(&self, t: F, q: &Array1<F>, p: &Array1<F>) -> IntegrateResult<Array1<F>>;
42
43    /// Compute dp/dt = -dH/dq evaluated at (t, q, p).
44    fn dp_dt(&self, t: F, q: &Array1<F>, p: &Array1<F>) -> IntegrateResult<Array1<F>>;
45
46    /// Optionally compute the Hamiltonian H(t, q, p) for energy monitoring.
47    /// Returns `None` if no closed-form Hamiltonian is available.
48    fn hamiltonian(&self, _t: F, _q: &Array1<F>, _p: &Array1<F>) -> Option<F> {
49        None
50    }
51}
52
53/// A separable Hamiltonian H(q, p) = T(p) + V(q).
54///
55/// For separable systems the equations of motion are:
56///   dq/dt = dT/dp
57///   dp/dt = -dV/dq
58pub struct SeparableSystem<F: IntegrateFloat> {
59    ndof: usize,
60    /// dT/dp (kinetic gradient)
61    kinetic_grad: Box<dyn Fn(F, &Array1<F>) -> Array1<F> + Send + Sync>,
62    /// dV/dq (potential gradient)
63    potential_grad: Box<dyn Fn(F, &Array1<F>) -> Array1<F> + Send + Sync>,
64    /// Optional: T(p) for energy monitoring
65    kinetic_energy: Option<Box<dyn Fn(F, &Array1<F>) -> F + Send + Sync>>,
66    /// Optional: V(q) for energy monitoring
67    potential_energy: Option<Box<dyn Fn(F, &Array1<F>) -> F + Send + Sync>>,
68}
69
70impl<F: IntegrateFloat> SeparableSystem<F> {
71    /// Create a separable system from gradient functions.
72    ///
73    /// # Arguments
74    /// * `ndof` - number of degrees of freedom
75    /// * `kinetic_grad` - computes dT/dp
76    /// * `potential_grad` - computes dV/dq
77    pub fn new<KG, VG>(ndof: usize, kinetic_grad: KG, potential_grad: VG) -> Self
78    where
79        KG: Fn(F, &Array1<F>) -> Array1<F> + Send + Sync + 'static,
80        VG: Fn(F, &Array1<F>) -> Array1<F> + Send + Sync + 'static,
81    {
82        SeparableSystem {
83            ndof,
84            kinetic_grad: Box::new(kinetic_grad),
85            potential_grad: Box::new(potential_grad),
86            kinetic_energy: None,
87            potential_energy: None,
88        }
89    }
90
91    /// Attach energy functions for monitoring.
92    pub fn with_energy<KE, VE>(mut self, kinetic_energy: KE, potential_energy: VE) -> Self
93    where
94        KE: Fn(F, &Array1<F>) -> F + Send + Sync + 'static,
95        VE: Fn(F, &Array1<F>) -> F + Send + Sync + 'static,
96    {
97        self.kinetic_energy = Some(Box::new(kinetic_energy));
98        self.potential_energy = Some(Box::new(potential_energy));
99        self
100    }
101}
102
103impl<F: IntegrateFloat> HamiltonianSystem<F> for SeparableSystem<F> {
104    fn ndof(&self) -> usize {
105        self.ndof
106    }
107
108    fn dq_dt(&self, t: F, _q: &Array1<F>, p: &Array1<F>) -> IntegrateResult<Array1<F>> {
109        Ok((self.kinetic_grad)(t, p))
110    }
111
112    fn dp_dt(&self, t: F, q: &Array1<F>, _p: &Array1<F>) -> IntegrateResult<Array1<F>> {
113        // dp/dt = -dV/dq
114        let grad_v = (self.potential_grad)(t, q);
115        Ok(grad_v.mapv(|x| -x))
116    }
117
118    fn hamiltonian(&self, t: F, q: &Array1<F>, p: &Array1<F>) -> Option<F> {
119        match (&self.kinetic_energy, &self.potential_energy) {
120            (Some(ke), Some(ve)) => Some(ke(t, p) + ve(t, q)),
121            _ => None,
122        }
123    }
124}
125
126// ---------------------------------------------------------------------------
127// Energy monitor
128// ---------------------------------------------------------------------------
129
130/// Tracks energy drift during symplectic integration.
131#[derive(Debug, Clone)]
132pub struct EnergyMonitor<F: IntegrateFloat> {
133    /// Initial energy H_0
134    pub initial_energy: Option<F>,
135    /// Energy at each recorded step
136    pub energy_history: Vec<F>,
137    /// Absolute energy error |H(t) - H_0| at each recorded step
138    pub abs_errors: Vec<F>,
139    /// Maximum absolute energy error observed so far
140    pub max_abs_error: F,
141    /// Mean absolute energy error
142    pub mean_abs_error: F,
143    /// Relative energy error |H(t) - H_0| / |H_0| (only if H_0 != 0)
144    pub max_rel_error: F,
145    /// Number of samples recorded
146    sample_count: usize,
147    /// Running sum of absolute errors
148    error_sum: F,
149}
150
151impl<F: IntegrateFloat> EnergyMonitor<F> {
152    /// Create a new energy monitor.
153    pub fn new() -> Self {
154        EnergyMonitor {
155            initial_energy: None,
156            energy_history: Vec::new(),
157            abs_errors: Vec::new(),
158            max_abs_error: F::zero(),
159            mean_abs_error: F::zero(),
160            max_rel_error: F::zero(),
161            sample_count: 0,
162            error_sum: F::zero(),
163        }
164    }
165
166    /// Record an energy sample.
167    pub fn record(&mut self, energy: F) {
168        let h0 = match self.initial_energy {
169            Some(h) => h,
170            None => {
171                self.initial_energy = Some(energy);
172                self.energy_history.push(energy);
173                self.abs_errors.push(F::zero());
174                self.sample_count = 1;
175                return;
176            }
177        };
178
179        let abs_err = (energy - h0).abs();
180        self.energy_history.push(energy);
181        self.abs_errors.push(abs_err);
182        self.sample_count += 1;
183        self.error_sum += abs_err;
184
185        if abs_err > self.max_abs_error {
186            self.max_abs_error = abs_err;
187        }
188
189        let eps = F::from_f64(1e-300).unwrap_or_else(|| F::epsilon());
190        if h0.abs() > eps {
191            let rel_err = abs_err / h0.abs();
192            if rel_err > self.max_rel_error {
193                self.max_rel_error = rel_err;
194            }
195        }
196
197        if self.sample_count > 0 {
198            self.mean_abs_error =
199                self.error_sum / F::from_usize(self.sample_count).unwrap_or_else(|| F::one());
200        }
201    }
202}
203
204impl<F: IntegrateFloat> Default for EnergyMonitor<F> {
205    fn default() -> Self {
206        Self::new()
207    }
208}
209
210// ---------------------------------------------------------------------------
211// Result type
212// ---------------------------------------------------------------------------
213
214/// Result of symplectic ODE integration.
215#[derive(Debug, Clone)]
216pub struct SymplecticODEResult<F: IntegrateFloat> {
217    /// Time points
218    pub t: Vec<F>,
219    /// Position coordinates at each time
220    pub q: Vec<Array1<F>>,
221    /// Momentum coordinates at each time
222    pub p: Vec<Array1<F>>,
223    /// Number of steps taken
224    pub n_steps: usize,
225    /// Number of function evaluations
226    pub n_eval: usize,
227    /// Energy monitoring data (present only if Hamiltonian was available)
228    pub energy_monitor: Option<EnergyMonitor<F>>,
229}
230
231// ---------------------------------------------------------------------------
232// Trait for symplectic steppers
233// ---------------------------------------------------------------------------
234
235/// Trait for symplectic one-step methods.
236pub trait SymplecticStepper<F: IntegrateFloat> {
237    /// Order of the method.
238    fn order(&self) -> usize;
239
240    /// Name for diagnostic output.
241    fn name(&self) -> &str;
242
243    /// Perform a single symplectic step.
244    fn step(
245        &self,
246        sys: &dyn HamiltonianSystem<F>,
247        t: F,
248        q: &Array1<F>,
249        p: &Array1<F>,
250        dt: F,
251    ) -> IntegrateResult<(Array1<F>, Array1<F>)>;
252
253    /// Integrate from t0 to tf with fixed step size dt.
254    fn integrate(
255        &self,
256        sys: &dyn HamiltonianSystem<F>,
257        t0: F,
258        tf: F,
259        dt: F,
260        q0: Array1<F>,
261        p0: Array1<F>,
262    ) -> IntegrateResult<SymplecticODEResult<F>> {
263        if dt <= F::zero() {
264            return Err(IntegrateError::ValueError("dt must be positive".into()));
265        }
266        if q0.len() != p0.len() {
267            return Err(IntegrateError::DimensionMismatch(
268                "q and p must have the same length".into(),
269            ));
270        }
271
272        let span = tf - t0;
273        let n_steps_f = (span / dt).ceil();
274        let n_steps = n_steps_f
275            .to_f64()
276            .ok_or_else(|| IntegrateError::ValueError("Cannot convert n_steps to f64".into()))?
277            as usize;
278        let actual_dt = span
279            / F::from_usize(n_steps)
280                .ok_or_else(|| IntegrateError::ValueError("Cannot convert n_steps".into()))?;
281
282        let mut ts = Vec::with_capacity(n_steps + 1);
283        let mut qs = Vec::with_capacity(n_steps + 1);
284        let mut ps = Vec::with_capacity(n_steps + 1);
285
286        ts.push(t0);
287        qs.push(q0.clone());
288        ps.push(p0.clone());
289
290        let mut monitor = EnergyMonitor::new();
291        let has_hamiltonian = sys.hamiltonian(t0, &q0, &p0).is_some();
292        if let Some(h0) = sys.hamiltonian(t0, &q0, &p0) {
293            monitor.record(h0);
294        }
295
296        let mut cur_t = t0;
297        let mut cur_q = q0;
298        let mut cur_p = p0;
299        let mut n_eval: usize = 0;
300
301        for _ in 0..n_steps {
302            let (next_q, next_p) = self.step(sys, cur_t, &cur_q, &cur_p, actual_dt)?;
303            // Approximate evals per step (depends on method, conservative estimate)
304            n_eval += 2 * self.order();
305
306            cur_t += actual_dt;
307            if let Some(h) = sys.hamiltonian(cur_t, &next_q, &next_p) {
308                monitor.record(h);
309            }
310
311            ts.push(cur_t);
312            qs.push(next_q.clone());
313            ps.push(next_p.clone());
314
315            cur_q = next_q;
316            cur_p = next_p;
317        }
318
319        Ok(SymplecticODEResult {
320            t: ts,
321            q: qs,
322            p: ps,
323            n_steps,
324            n_eval,
325            energy_monitor: if has_hamiltonian { Some(monitor) } else { None },
326        })
327    }
328}
329
330// ---------------------------------------------------------------------------
331// Stormer-Verlet (Leapfrog) -- 2nd order
332// ---------------------------------------------------------------------------
333
334/// Stormer-Verlet (leapfrog) symplectic integrator, 2nd order.
335///
336/// Algorithm:
337/// 1. p_{1/2} = p_n + (dt/2) dp/dt(t_n, q_n, p_n)
338/// 2. q_{n+1} = q_n + dt  dq/dt(t_{n+1/2}, q_n, p_{1/2})
339/// 3. p_{n+1} = p_{1/2} + (dt/2) dp/dt(t_{n+1}, q_{n+1}, p_{1/2})
340#[derive(Debug, Clone)]
341pub struct StormerVerletODE<F: IntegrateFloat> {
342    _marker: PhantomData<F>,
343}
344
345impl<F: IntegrateFloat> StormerVerletODE<F> {
346    pub fn new() -> Self {
347        StormerVerletODE {
348            _marker: PhantomData,
349        }
350    }
351}
352
353impl<F: IntegrateFloat> Default for StormerVerletODE<F> {
354    fn default() -> Self {
355        Self::new()
356    }
357}
358
359impl<F: IntegrateFloat> SymplecticStepper<F> for StormerVerletODE<F> {
360    fn order(&self) -> usize {
361        2
362    }
363    fn name(&self) -> &str {
364        "Stormer-Verlet"
365    }
366
367    fn step(
368        &self,
369        sys: &dyn HamiltonianSystem<F>,
370        t: F,
371        q: &Array1<F>,
372        p: &Array1<F>,
373        dt: F,
374    ) -> IntegrateResult<(Array1<F>, Array1<F>)> {
375        let two = F::one() + F::one();
376        let half_dt = dt / two;
377
378        // Half-step momentum
379        let dp1 = sys.dp_dt(t, q, p)?;
380        let p_half = p + &(&dp1 * half_dt);
381
382        // Full-step position
383        let t_half = t + half_dt;
384        let dq = sys.dq_dt(t_half, q, &p_half)?;
385        let q_new = q + &(&dq * dt);
386
387        // Half-step momentum
388        let t_new = t + dt;
389        let dp2 = sys.dp_dt(t_new, &q_new, &p_half)?;
390        let p_new = &p_half + &(&dp2 * half_dt);
391
392        Ok((q_new, p_new))
393    }
394}
395
396// ---------------------------------------------------------------------------
397// Velocity-Verlet -- 2nd order variant
398// ---------------------------------------------------------------------------
399
400/// Velocity Verlet symplectic integrator, 2nd order.
401///
402/// Equivalent to Stormer-Verlet for separable Hamiltonians but
403/// formulated differently: updates position first using both
404/// velocity and acceleration, then updates momentum.
405#[derive(Debug, Clone)]
406pub struct VelocityVerletODE<F: IntegrateFloat> {
407    _marker: PhantomData<F>,
408}
409
410impl<F: IntegrateFloat> VelocityVerletODE<F> {
411    pub fn new() -> Self {
412        VelocityVerletODE {
413            _marker: PhantomData,
414        }
415    }
416}
417
418impl<F: IntegrateFloat> Default for VelocityVerletODE<F> {
419    fn default() -> Self {
420        Self::new()
421    }
422}
423
424impl<F: IntegrateFloat> SymplecticStepper<F> for VelocityVerletODE<F> {
425    fn order(&self) -> usize {
426        2
427    }
428    fn name(&self) -> &str {
429        "Velocity-Verlet"
430    }
431
432    fn step(
433        &self,
434        sys: &dyn HamiltonianSystem<F>,
435        t: F,
436        q: &Array1<F>,
437        p: &Array1<F>,
438        dt: F,
439    ) -> IntegrateResult<(Array1<F>, Array1<F>)> {
440        let two = F::one() + F::one();
441        let half_dt = dt / two;
442
443        // Compute acceleration (force) at current state
444        let dp_old = sys.dp_dt(t, q, p)?;
445
446        // Half-step momentum
447        let p_half = p + &(&dp_old * half_dt);
448
449        // Full-step position using half-step momentum
450        let dq = sys.dq_dt(t + half_dt, q, &p_half)?;
451        let q_new = q + &(&dq * dt);
452
453        // Compute acceleration at new position
454        let t_new = t + dt;
455        let dp_new = sys.dp_dt(t_new, &q_new, &p_half)?;
456
457        // Complete momentum step
458        let p_new = &p_half + &(&dp_new * half_dt);
459
460        Ok((q_new, p_new))
461    }
462}
463
464// ---------------------------------------------------------------------------
465// Yoshida composition methods
466// ---------------------------------------------------------------------------
467
468/// Yoshida 4th order symplectic integrator.
469///
470/// Constructed by triple-jump composition of a 2nd-order base method
471/// (Stormer-Verlet). Coefficients from Yoshida (1990):
472///   w_1 = w_3 = 1/(2 - 2^{1/3})
473///   w_0 = -2^{1/3}/(2 - 2^{1/3})
474#[derive(Debug, Clone)]
475pub struct Yoshida4<F: IntegrateFloat> {
476    base: StormerVerletODE<F>,
477    coefficients: [F; 3],
478}
479
480impl<F: IntegrateFloat> Yoshida4<F> {
481    pub fn new() -> Self {
482        let two = F::one() + F::one();
483        let cbrt2 = two.powf(
484            F::from_f64(1.0 / 3.0).unwrap_or_else(|| F::one() / (F::one() + F::one() + F::one())),
485        );
486        let w1 = F::one() / (two - cbrt2);
487        let w0 = -cbrt2 / (two - cbrt2);
488
489        Yoshida4 {
490            base: StormerVerletODE::new(),
491            coefficients: [w1, w0, w1],
492        }
493    }
494}
495
496impl<F: IntegrateFloat> Default for Yoshida4<F> {
497    fn default() -> Self {
498        Self::new()
499    }
500}
501
502impl<F: IntegrateFloat> SymplecticStepper<F> for Yoshida4<F> {
503    fn order(&self) -> usize {
504        4
505    }
506    fn name(&self) -> &str {
507        "Yoshida-4"
508    }
509
510    fn step(
511        &self,
512        sys: &dyn HamiltonianSystem<F>,
513        t: F,
514        q: &Array1<F>,
515        p: &Array1<F>,
516        dt: F,
517    ) -> IntegrateResult<(Array1<F>, Array1<F>)> {
518        let mut cur_t = t;
519        let mut cur_q = q.clone();
520        let mut cur_p = p.clone();
521
522        for &c in &self.coefficients {
523            let sub_dt = dt * c;
524            let (nq, np) = self.base.step(sys, cur_t, &cur_q, &cur_p, sub_dt)?;
525            cur_t += sub_dt;
526            cur_q = nq;
527            cur_p = np;
528        }
529
530        Ok((cur_q, cur_p))
531    }
532}
533
534/// Yoshida 6th order symplectic integrator.
535///
536/// 7-stage composition with coefficients from Yoshida (1990).
537#[derive(Debug, Clone)]
538pub struct Yoshida6<F: IntegrateFloat> {
539    base: StormerVerletODE<F>,
540    coefficients: [F; 7],
541}
542
543impl<F: IntegrateFloat> Yoshida6<F> {
544    pub fn new() -> Self {
545        // Coefficients from Yoshida (1990) for 6th order
546        let w1 = F::from_f64(0.784513610477560).unwrap_or_else(|| F::one());
547        let w2 = F::from_f64(0.235573213359357).unwrap_or_else(|| F::one());
548        let w3 = F::from_f64(-1.17767998417887).unwrap_or_else(|| -F::one());
549        let w4 = F::from_f64(1.31518632068391).unwrap_or_else(|| F::one());
550
551        Yoshida6 {
552            base: StormerVerletODE::new(),
553            coefficients: [w1, w2, w3, w4, w3, w2, w1],
554        }
555    }
556}
557
558impl<F: IntegrateFloat> Default for Yoshida6<F> {
559    fn default() -> Self {
560        Self::new()
561    }
562}
563
564impl<F: IntegrateFloat> SymplecticStepper<F> for Yoshida6<F> {
565    fn order(&self) -> usize {
566        6
567    }
568    fn name(&self) -> &str {
569        "Yoshida-6"
570    }
571
572    fn step(
573        &self,
574        sys: &dyn HamiltonianSystem<F>,
575        t: F,
576        q: &Array1<F>,
577        p: &Array1<F>,
578        dt: F,
579    ) -> IntegrateResult<(Array1<F>, Array1<F>)> {
580        let mut cur_t = t;
581        let mut cur_q = q.clone();
582        let mut cur_p = p.clone();
583
584        for &c in &self.coefficients {
585            let sub_dt = dt * c;
586            let (nq, np) = self.base.step(sys, cur_t, &cur_q, &cur_p, sub_dt)?;
587            cur_t += sub_dt;
588            cur_q = nq;
589            cur_p = np;
590        }
591
592        Ok((cur_q, cur_p))
593    }
594}
595
596/// Yoshida 8th order symplectic integrator.
597///
598/// 15-stage composition with coefficients from Yoshida (1990).
599#[derive(Debug, Clone)]
600pub struct Yoshida8<F: IntegrateFloat> {
601    base: StormerVerletODE<F>,
602    coefficients: Vec<F>,
603}
604
605impl<F: IntegrateFloat> Yoshida8<F> {
606    pub fn new() -> Self {
607        // Coefficients for 8th-order Yoshida composition (Kahan & Li, 1997)
608        let w = [
609            F::from_f64(0.74167036435061).unwrap_or_else(|| F::one()),
610            F::from_f64(-0.40910082580003).unwrap_or_else(|| -F::one()),
611            F::from_f64(0.19075471029623).unwrap_or_else(|| F::one()),
612            F::from_f64(-0.57386247111608).unwrap_or_else(|| -F::one()),
613            F::from_f64(0.29906418130365).unwrap_or_else(|| F::one()),
614            F::from_f64(0.33462491824529).unwrap_or_else(|| F::one()),
615            F::from_f64(0.31529309239676).unwrap_or_else(|| F::one()),
616            F::from_f64(-0.79688793935291).unwrap_or_else(|| -F::one()),
617        ];
618
619        // Symmetric composition: [w0..w7, w7..w0]
620        let mut coefficients = Vec::with_capacity(15);
621        for &c in &w {
622            coefficients.push(c);
623        }
624        // Mirror from w[6] down to w[0]
625        for i in (0..7).rev() {
626            coefficients.push(w[i]);
627        }
628
629        Yoshida8 {
630            base: StormerVerletODE::new(),
631            coefficients,
632        }
633    }
634}
635
636impl<F: IntegrateFloat> Default for Yoshida8<F> {
637    fn default() -> Self {
638        Self::new()
639    }
640}
641
642impl<F: IntegrateFloat> SymplecticStepper<F> for Yoshida8<F> {
643    fn order(&self) -> usize {
644        8
645    }
646    fn name(&self) -> &str {
647        "Yoshida-8"
648    }
649
650    fn step(
651        &self,
652        sys: &dyn HamiltonianSystem<F>,
653        t: F,
654        q: &Array1<F>,
655        p: &Array1<F>,
656        dt: F,
657    ) -> IntegrateResult<(Array1<F>, Array1<F>)> {
658        let mut cur_t = t;
659        let mut cur_q = q.clone();
660        let mut cur_p = p.clone();
661
662        for &c in &self.coefficients {
663            let sub_dt = dt * c;
664            let (nq, np) = self.base.step(sys, cur_t, &cur_q, &cur_p, sub_dt)?;
665            cur_t += sub_dt;
666            cur_q = nq;
667            cur_p = np;
668        }
669
670        Ok((cur_q, cur_p))
671    }
672}
673
674// ---------------------------------------------------------------------------
675// Convenience: solve_hamiltonian
676// ---------------------------------------------------------------------------
677
678/// Solve a Hamiltonian system using a specified symplectic method.
679///
680/// # Arguments
681/// * `sys` - the Hamiltonian system
682/// * `method` - the symplectic stepper
683/// * `t0` - initial time
684/// * `tf` - final time
685/// * `dt` - step size
686/// * `q0` - initial positions
687/// * `p0` - initial momenta
688pub fn solve_hamiltonian<F: IntegrateFloat>(
689    sys: &dyn HamiltonianSystem<F>,
690    method: &dyn SymplecticStepper<F>,
691    t0: F,
692    tf: F,
693    dt: F,
694    q0: Array1<F>,
695    p0: Array1<F>,
696) -> IntegrateResult<SymplecticODEResult<F>> {
697    method.integrate(sys, t0, tf, dt, q0, p0)
698}
699
/// Selector for the built-in symplectic steppers, used with [`create_stepper`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SymplecticMethod {
    /// Stormer-Verlet (leapfrog), order 2.
    StormerVerlet,
    /// Velocity Verlet, order 2.
    VelocityVerlet,
    /// Yoshida triple-jump composition, order 4.
    Yoshida4,
    /// Yoshida 7-stage composition, order 6.
    Yoshida6,
    /// 15-stage composition, order 8.
    Yoshida8,
}
714
715/// Create a boxed stepper from the method enum.
716pub fn create_stepper<F: IntegrateFloat>(
717    method: SymplecticMethod,
718) -> Box<dyn SymplecticStepper<F>> {
719    match method {
720        SymplecticMethod::StormerVerlet => Box::new(StormerVerletODE::<F>::new()),
721        SymplecticMethod::VelocityVerlet => Box::new(VelocityVerletODE::<F>::new()),
722        SymplecticMethod::Yoshida4 => Box::new(Yoshida4::<F>::new()),
723        SymplecticMethod::Yoshida6 => Box::new(Yoshida6::<F>::new()),
724        SymplecticMethod::Yoshida8 => Box::new(Yoshida8::<F>::new()),
725    }
726}
727
728// ---------------------------------------------------------------------------
729// Tests
730// ---------------------------------------------------------------------------
731
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Build a simple harmonic oscillator: H = p^2/2 + q^2/2
    fn harmonic_oscillator() -> SeparableSystem<f64> {
        SeparableSystem::new(
            1,
            |_t, p: &Array1<f64>| p.clone(), // dT/dp = p
            |_t, q: &Array1<f64>| q.clone(), // dV/dq = q
        )
        .with_energy(
            |_t, p: &Array1<f64>| 0.5 * p.dot(p), // T(p) = p^2/2
            |_t, q: &Array1<f64>| 0.5 * q.dot(q), // V(q) = q^2/2
        )
    }

    /// Build a 2D Kepler problem: H = |p|^2/2 - 1/|q|
    fn kepler_2d() -> SeparableSystem<f64> {
        SeparableSystem::new(
            2,
            |_t, p: &Array1<f64>| p.clone(),
            |_t, q: &Array1<f64>| {
                let r2 = q[0] * q[0] + q[1] * q[1];
                let r = r2.sqrt();
                if r < 1e-12 {
                    Array1::zeros(2)
                } else {
                    // dV/dq = -d/dq(-1/r) = q/r^3
                    let r3 = r * r2;
                    array![q[0] / r3, q[1] / r3]
                }
            },
        )
        .with_energy(
            |_t, p: &Array1<f64>| 0.5 * p.dot(p),
            |_t, q: &Array1<f64>| {
                let r = (q[0] * q[0] + q[1] * q[1]).sqrt();
                if r < 1e-12 {
                    0.0
                } else {
                    -1.0 / r
                }
            },
        )
    }

    /// Exact oscillator state `(q, p)` at time `t` for q(0) = 1, p(0) = 0.
    fn oscillator_exact(t: f64) -> (f64, f64) {
        (t.cos(), -t.sin())
    }

    /// Euclidean distance between the final `(q[0], p[0])` of `result` and `exact`.
    fn final_state_error(result: &SymplecticODEResult<f64>, exact: (f64, f64)) -> f64 {
        let q = result.q.last().expect("no q")[0];
        let p = result.p.last().expect("no p")[0];
        ((q - exact.0).powi(2) + (p - exact.1).powi(2)).sqrt()
    }

    #[test]
    fn test_stormer_verlet_harmonic() {
        let sys = harmonic_oscillator();
        let sv = StormerVerletODE::new();
        let q0 = array![1.0_f64];
        let p0 = array![0.0_f64];

        let result = sv
            .integrate(
                &sys,
                0.0,
                2.0 * std::f64::consts::PI,
                0.01,
                q0.clone(),
                p0.clone(),
            )
            .expect("integration should succeed");

        // After one full period, should return close to initial state
        let q_final = result.q.last().expect("should have final q");
        let p_final = result.p.last().expect("should have final p");
        assert!(
            (q_final[0] - 1.0).abs() < 0.01,
            "q should return near 1.0, got {}",
            q_final[0]
        );
        assert!(
            p_final[0].abs() < 0.01,
            "p should return near 0.0, got {}",
            p_final[0]
        );

        // Energy conservation
        let mon = result.energy_monitor.as_ref().expect("should have monitor");
        // Stormer-Verlet is 2nd order: energy error bounded by O(dt^2) ~ 1e-4 for dt=0.01
        assert!(
            mon.max_rel_error < 1e-3,
            "energy drift too large: {}",
            mon.max_rel_error
        );
    }

    #[test]
    fn test_velocity_verlet_harmonic() {
        let sys = harmonic_oscillator();
        let vv = VelocityVerletODE::new();
        let q0 = array![1.0_f64];
        let p0 = array![0.0_f64];

        let result = vv
            .integrate(&sys, 0.0, 2.0 * std::f64::consts::PI, 0.01, q0, p0)
            .expect("integration should succeed");

        let mon = result.energy_monitor.as_ref().expect("should have monitor");
        // Velocity-Verlet is 2nd order: energy error bounded by O(dt^2) ~ 1e-4 for dt=0.01
        assert!(
            mon.max_rel_error < 1e-3,
            "energy drift too large: {}",
            mon.max_rel_error
        );
    }

    #[test]
    fn test_yoshida4_convergence() {
        let sys = harmonic_oscillator();
        let y4 = Yoshida4::new();
        let sv = StormerVerletODE::new();

        let q0 = array![1.0_f64];
        let p0 = array![0.0_f64];
        let tf = 1.0;
        let exact = oscillator_exact(tf);

        // Compare errors at two step sizes to verify order
        let dts = [0.1, 0.05];
        let mut sv_errors = Vec::new();
        let mut y4_errors = Vec::new();

        for &dt in &dts {
            let sv_res = sv
                .integrate(&sys, 0.0, tf, dt, q0.clone(), p0.clone())
                .expect("sv integration failed");
            let y4_res = y4
                .integrate(&sys, 0.0, tf, dt, q0.clone(), p0.clone())
                .expect("y4 integration failed");

            sv_errors.push(final_state_error(&sv_res, exact));
            y4_errors.push(final_state_error(&y4_res, exact));
        }

        // When dt halved, 2nd-order error should decrease by ~4, 4th-order by ~16
        let sv_ratio = sv_errors[0] / sv_errors[1];
        let y4_ratio = y4_errors[0] / y4_errors[1];

        assert!(
            sv_ratio > 3.0 && sv_ratio < 5.0,
            "SV convergence ratio {sv_ratio} not ~4"
        );
        assert!(
            y4_ratio > 12.0 && y4_ratio < 20.0,
            "Y4 convergence ratio {y4_ratio} not ~16"
        );

        // Y4 should be more accurate than SV at same step size
        assert!(y4_errors[0] < sv_errors[0], "Y4 should beat SV accuracy");
    }

    #[test]
    fn test_yoshida6_better_than_yoshida4() {
        let sys = harmonic_oscillator();
        let y4 = Yoshida4::new();
        let y6 = Yoshida6::new();

        let q0 = array![1.0_f64];
        let p0 = array![0.0_f64];
        let dt = 0.1;
        let tf = 1.0;
        let exact = oscillator_exact(tf);

        let r4 = y4
            .integrate(&sys, 0.0, tf, dt, q0.clone(), p0.clone())
            .expect("y4 failed");
        let r6 = y6
            .integrate(&sys, 0.0, tf, dt, q0.clone(), p0.clone())
            .expect("y6 failed");

        let e4 = final_state_error(&r4, exact);
        let e6 = final_state_error(&r6, exact);

        assert!(
            e6 < e4,
            "Y6 error ({e6}) should be less than Y4 error ({e4})"
        );
    }

    #[test]
    fn test_yoshida8_high_accuracy() {
        let sys = harmonic_oscillator();
        let y8 = Yoshida8::new();

        let q0 = array![1.0_f64];
        let p0 = array![0.0_f64];
        let dt = 0.1;
        let tf = 1.0;

        let r8 = y8.integrate(&sys, 0.0, tf, dt, q0, p0).expect("y8 failed");
        let e8 = final_state_error(&r8, oscillator_exact(tf));

        // 8th order with dt=0.1 should be very accurate
        assert!(e8 < 1e-8, "Y8 error {e8} too large");
    }

    #[test]
    fn test_composition_weights_sum_to_one() {
        // The sub-steps of a consistent composition must add up to one full step.
        let sums = [
            ("Yoshida-4", Yoshida4::<f64>::new().coefficients.iter().sum::<f64>()),
            ("Yoshida-6", Yoshida6::<f64>::new().coefficients.iter().sum::<f64>()),
            ("Yoshida-8", Yoshida8::<f64>::new().coefficients.iter().sum::<f64>()),
        ];
        for (name, total) in sums {
            assert!((total - 1.0).abs() < 1e-12, "{name} weights sum to {total}");
        }
    }

    #[test]
    fn test_composition_weights_are_symmetric() {
        // Symmetric (palindromic) compositions are what give the even orders.
        let y6 = Yoshida6::<f64>::new();
        let y8 = Yoshida8::<f64>::new();
        assert_eq!(y6.coefficients.len(), 7);
        assert_eq!(y8.coefficients.len(), 15);
        assert!(y6.coefficients.iter().eq(y6.coefficients.iter().rev()));
        assert!(y8.coefficients.iter().eq(y8.coefficients.iter().rev()));
    }

    #[test]
    fn test_kepler_energy_conservation() {
        let sys = kepler_2d();
        let y4 = Yoshida4::new();

        // Circular orbit
        let q0 = array![1.0, 0.0];
        let p0 = array![0.0, 1.0];

        let result = y4
            .integrate(&sys, 0.0, 20.0, 0.01, q0, p0)
            .expect("kepler integration failed");

        let mon = result.energy_monitor.as_ref().expect("should have monitor");
        assert!(
            mon.max_rel_error < 1e-6,
            "Kepler energy drift too large: {}",
            mon.max_rel_error
        );
    }

    #[test]
    fn test_energy_monitor_recording() {
        let sys = harmonic_oscillator();
        let sv = StormerVerletODE::new();
        let q0 = array![1.0_f64];
        let p0 = array![0.0_f64];

        let result = sv
            .integrate(&sys, 0.0, 1.0, 0.1, q0, p0)
            .expect("integration failed");

        let mon = result.energy_monitor.as_ref().expect("should have monitor");
        assert!(mon.initial_energy.is_some());
        assert!(!mon.energy_history.is_empty());
        assert_eq!(mon.energy_history.len(), mon.abs_errors.len());

        // Initial energy should be 0.5 for q=1,p=0
        let h0 = mon.initial_energy.expect("should have initial energy");
        assert!((h0 - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_zero_length_interval_returns_initial_state() {
        let sys = harmonic_oscillator();
        let sv = StormerVerletODE::new();

        let result = sv
            .integrate(&sys, 0.0, 0.0, 0.1, array![1.0], array![0.0])
            .expect("empty interval should succeed");

        assert_eq!(result.n_steps, 0);
        assert_eq!(result.t, vec![0.0]);
        assert_eq!(result.q.len(), 1);
        assert_eq!(result.p.len(), 1);
    }

    #[test]
    fn test_create_stepper() {
        let expected = [
            (SymplecticMethod::StormerVerlet, 2, "Stormer-Verlet"),
            (SymplecticMethod::VelocityVerlet, 2, "Velocity-Verlet"),
            (SymplecticMethod::Yoshida4, 4, "Yoshida-4"),
            (SymplecticMethod::Yoshida6, 6, "Yoshida-6"),
            (SymplecticMethod::Yoshida8, 8, "Yoshida-8"),
        ];
        for (method, order, name) in expected {
            let stepper = create_stepper::<f64>(method);
            assert_eq!(stepper.order(), order);
            assert_eq!(stepper.name(), name);
        }
    }

    #[test]
    fn test_solve_hamiltonian_convenience() {
        let sys = harmonic_oscillator();
        let stepper = Yoshida4::<f64>::new();

        let result = solve_hamiltonian(
            &sys,
            &stepper,
            0.0,
            std::f64::consts::PI,
            0.01,
            array![1.0],
            array![0.0],
        )
        .expect("solve_hamiltonian failed");

        // At t=pi, q should be cos(pi) = -1, p should be -sin(pi) ~ 0
        let q_f = result.q.last().expect("no q");
        let p_f = result.p.last().expect("no p");
        assert!((q_f[0] + 1.0).abs() < 0.01, "q should be near -1");
        assert!(p_f[0].abs() < 0.01, "p should be near 0");
    }

    #[test]
    fn test_invalid_inputs() {
        let sys = harmonic_oscillator();
        let sv = StormerVerletODE::new();

        // Negative dt
        let res = sv.integrate(&sys, 0.0, 1.0, -0.1, array![1.0], array![0.0]);
        assert!(res.is_err());

        // Mismatched dimensions
        let res = sv.integrate(&sys, 0.0, 1.0, 0.1, array![1.0, 2.0], array![0.0]);
        assert!(res.is_err());
    }

    #[test]
    fn test_long_time_energy_bounded() {
        // Symplectic integrators should have bounded energy error, not growing
        let sys = harmonic_oscillator();
        let y4 = Yoshida4::new();

        let result = y4
            .integrate(&sys, 0.0, 100.0, 0.05, array![1.0], array![0.0])
            .expect("long integration failed");

        let mon = result.energy_monitor.as_ref().expect("monitor");
        // Energy error should stay bounded even after long integration
        assert!(
            mon.max_abs_error < 1e-6,
            "Energy error grew too large over long time: {}",
            mon.max_abs_error
        );
    }
}