scirs2_integrate/symplectic/
euler.rs1use crate::common::IntegrateFloat;
14use crate::error::IntegrateResult;
15use crate::symplectic::{HamiltonianFn, SymplecticIntegrator};
16use scirs2_core::ndarray::Array1;
17use std::marker::PhantomData;
18
19#[derive(Debug, Clone)]
25pub struct SymplecticEulerA<F: IntegrateFloat> {
26 _marker: PhantomData<F>,
27}
28
29impl<F: IntegrateFloat> SymplecticEulerA<F> {
30 pub fn new() -> Self {
32 SymplecticEulerA {
33 _marker: PhantomData,
34 }
35 }
36}
37
38impl<F: IntegrateFloat> Default for SymplecticEulerA<F> {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl<F: IntegrateFloat> SymplecticIntegrator<F> for SymplecticEulerA<F> {
45 fn step(
46 &self,
47 system: &dyn HamiltonianFn<F>,
48 t: F,
49 q: &Array1<F>,
50 p: &Array1<F>,
51 dt: F,
52 ) -> IntegrateResult<(Array1<F>, Array1<F>)> {
53 let dq = system.dq_dt(t, q, p)?;
55 let q_new = q + &(&dq * dt);
56
57 let dp = system.dp_dt(t, &q_new, p)?;
59 let p_new = p + &(&dp * dt);
60
61 Ok((q_new, p_new))
62 }
63}
64
65#[derive(Debug, Clone)]
71pub struct SymplecticEulerB<F: IntegrateFloat> {
72 _marker: PhantomData<F>,
73}
74
75impl<F: IntegrateFloat> SymplecticEulerB<F> {
76 pub fn new() -> Self {
78 SymplecticEulerB {
79 _marker: PhantomData,
80 }
81 }
82}
83
84impl<F: IntegrateFloat> Default for SymplecticEulerB<F> {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl<F: IntegrateFloat> SymplecticIntegrator<F> for SymplecticEulerB<F> {
91 fn step(
92 &self,
93 system: &dyn HamiltonianFn<F>,
94 t: F,
95 q: &Array1<F>,
96 p: &Array1<F>,
97 dt: F,
98 ) -> IntegrateResult<(Array1<F>, Array1<F>)> {
99 let dp = system.dp_dt(t, q, p)?;
101 let p_new = p + &(&dp * dt);
102
103 let dq = system.dq_dt(t, q, &p_new)?;
105 let q_new = q + &(&dq * dt);
106
107 Ok((q_new, p_new))
108 }
109}
110
111#[allow(dead_code)]
116pub fn symplectic_euler<F: IntegrateFloat>(
117 system: &dyn HamiltonianFn<F>,
118 t: F,
119 q: &Array1<F>,
120 p: &Array1<F>,
121 dt: F,
122) -> IntegrateResult<(Array1<F>, Array1<F>)> {
123 SymplecticEulerA::new().step(system, t, q, p, dt)
124}
125
126#[allow(dead_code)]
128pub fn symplectic_euler_a<F: IntegrateFloat>(
129 system: &dyn HamiltonianFn<F>,
130 t: F,
131 q: &Array1<F>,
132 p: &Array1<F>,
133 dt: F,
134) -> IntegrateResult<(Array1<F>, Array1<F>)> {
135 SymplecticEulerA::new().step(system, t, q, p, dt)
136}
137
138#[allow(dead_code)]
140pub fn symplectic_euler_b<F: IntegrateFloat>(
141 system: &dyn HamiltonianFn<F>,
142 t: F,
143 q: &Array1<F>,
144 p: &Array1<F>,
145 dt: F,
146) -> IntegrateResult<(Array1<F>, Array1<F>)> {
147 SymplecticEulerB::new().step(system, t, q, p, dt)
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symplectic::potential::SeparableHamiltonian;
    use scirs2_core::ndarray::array;

    /// Checks one step of each variant on a system with kinetic term
    /// `0.5 * p·p` and potential term `0.5 * q·q`, starting from
    /// `q = 1`, `p = 0` with `dt = 0.1`.
    #[test]
    fn test_symplectic_euler() {
        let system = SeparableHamiltonian::new(
            |_t, p| -> f64 { 0.5 * p.dot(p) },
            |_t, q| -> f64 { 0.5 * q.dot(q) },
        );

        let t0 = 0.0;
        let dt = 0.1;
        let q0 = array![1.0];
        let p0 = array![0.0];

        // Absolute tolerance for the single-step comparisons.
        let tol = 1e-12;

        let (q1_a, p1_a) = symplectic_euler_a(&system, t0, &q0, &p0, dt).unwrap();
        let (q1_b, p1_b) = symplectic_euler_b(&system, t0, &q0, &p0, dt).unwrap();

        // Variant A: expected q = 1.0, p = -0.1.
        assert!((q1_a[0] - 1.0).abs() < tol);
        assert!((p1_a[0] + 0.1).abs() < tol);

        // Variant B: expected q = 0.99, p = -0.1.
        assert!((q1_b[0] - 0.99).abs() < tol);
        assert!((p1_b[0] + 0.1).abs() < tol);
    }

    /// Integrates the same system from `t = 0` to `t = 10` with variant A
    /// and checks that the trajectory stays close to the unit circle in
    /// `(q, p)` space.
    #[test]
    fn test_energy_conservation() {
        let system = SeparableHamiltonian::new(
            |_t, p| -> f64 { 0.5 * p.dot(p) },
            |_t, q| -> f64 { 0.5 * q.dot(q) },
        );

        let t0 = 0.0;
        let tf = 10.0;
        let dt = 0.1;
        let q0 = array![1.0];
        let p0 = array![0.0];

        let integrator = SymplecticEulerA::new();
        let result = integrator.integrate(&system, t0, tf, dt, q0, p0).unwrap();

        // The relative energy error is only checked when the result reports it.
        if let Some(error) = result.energy_relative_error {
            assert!(error < 0.1, "Energy error too large: {error}");
        }

        // Every stored point must satisfy |q² + p² - 1| < 0.1.
        let max_radius_deviation = 0.1;
        let sample_count = result.t.len();
        for idx in 0..sample_count {
            let q = &result.q[idx];
            let p = &result.p[idx];
            let radius_squared = q[0] * q[0] + p[0] * p[0];
            let deviation = (radius_squared - 1.0).abs();
            assert!(
                deviation < max_radius_deviation,
                "Point ({}, {}) is too far from unit circle, r² = {}",
                q[0],
                p[0],
                radius_squared
            );
        }
    }
}