scirs2_integrate/geometric/
structure_preserving.rs

1//! Structure-preserving integrators
2//!
3//! This module provides integrators that preserve various geometric structures
4//! such as energy, momentum, symplectic structure, and other invariants.
5
6use crate::error::{IntegrateResult, IntegrateResult as Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
8
9// Type aliases for complex function types
10type HamiltonianFn = Box<dyn Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64>;
11type ConstraintFn = Box<dyn Fn(&ArrayView1<f64>) -> Array1<f64>>;
12type ConstraintJacobianFn = Box<dyn Fn(&ArrayView1<f64>) -> Array2<f64>>;
13type GradientFn = dyn Fn(&ArrayView1<f64>) -> Array1<f64>;
14type ForceFn = Box<dyn Fn(&ArrayView1<f64>) -> Array1<f64>>;
15type KineticFn = Box<dyn Fn(&ArrayView1<f64>) -> f64>;
16type PotentialFn = Box<dyn Fn(&ArrayView1<f64>) -> f64>;
17type StiffnessFn = Box<dyn Fn(&ArrayView1<f64>) -> Array1<f64>>;
18
19/// Trait for geometric invariants
20pub trait GeometricInvariant {
21    /// Evaluate the invariant quantity
22    fn evaluate(&self, x: &ArrayView1<f64>, v: &ArrayView1<f64>, t: f64) -> f64;
23
24    /// Name of the invariant (for debugging)
25    fn name(&self) -> &'static str;
26}
27
28/// General structure-preserving integrator
29pub struct StructurePreservingIntegrator {
30    /// Time step
31    pub dt: f64,
32    /// Integration method
33    pub method: StructurePreservingMethod,
34    /// Invariants to preserve
35    pub invariants: Vec<Box<dyn GeometricInvariant>>,
36    /// Tolerance for invariant preservation
37    pub tol: f64,
38}
39
/// Available structure-preserving methods.
///
/// Besides being `Copy`, the enum derives `PartialEq`, `Eq` and `Hash` so
/// callers can compare the selected method with `==` or use it as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StructurePreservingMethod {
    /// Discrete gradient method
    DiscreteGradient,
    /// Average vector field method
    AverageVectorField,
    /// Energy-momentum method
    EnergyMomentum,
    /// Variational integrator
    Variational,
}
52
53impl StructurePreservingIntegrator {
54    /// Create a new structure-preserving integrator
55    pub fn new(dt: f64, method: StructurePreservingMethod) -> Self {
56        Self {
57            dt,
58            method,
59            invariants: Vec::new(),
60            tol: 1e-10,
61        }
62    }
63
64    /// Add an invariant to preserve
65    pub fn add_invariant(&mut self, invariant: Box<dyn GeometricInvariant>) -> &mut Self {
66        self.invariants.push(invariant);
67        self
68    }
69
70    /// Check invariant preservation
71    pub fn check_invariants(
72        &self,
73        x0: &ArrayView1<f64>,
74        v0: &ArrayView1<f64>,
75        x1: &ArrayView1<f64>,
76        v1: &ArrayView1<f64>,
77        t: f64,
78    ) -> Vec<(String, f64)> {
79        let mut errors = Vec::new();
80
81        for invariant in &self.invariants {
82            let i0 = invariant.evaluate(x0, v0, t);
83            let i1 = invariant.evaluate(x1, v1, t + self.dt);
84            let error = (i1 - i0).abs() / (1.0 + i0.abs());
85            errors.push((invariant.name().to_string(), error));
86        }
87
88        errors
89    }
90}
91
92/// Energy-preserving integrator for Hamiltonian systems
93pub struct EnergyPreservingMethod {
94    /// Hamiltonian function
95    hamiltonian: HamiltonianFn,
96    /// Dimension
97    dim: usize,
98}
99
100impl EnergyPreservingMethod {
101    /// Create a new energy-preserving integrator
102    pub fn new(hamiltonian: HamiltonianFn, dim: usize) -> Self {
103        Self { hamiltonian, dim }
104    }
105
106    /// Discrete gradient method
107    pub fn discrete_gradient_step(
108        &self,
109        q: &ArrayView1<f64>,
110        p: &ArrayView1<f64>,
111        dt: f64,
112    ) -> IntegrateResult<(Array1<f64>, Array1<f64>)> {
113        let h = 1e-8;
114
115        // Compute gradients using finite differences
116        let mut grad_q = Array1::zeros(self.dim);
117        let mut grad_p = Array1::zeros(self.dim);
118
119        for i in 0..self.dim {
120            let mut q_plus = q.to_owned();
121            let mut q_minus = q.to_owned();
122            q_plus[i] += h;
123            q_minus[i] -= h;
124            grad_q[i] = ((self.hamiltonian)(&q_plus.view(), p)
125                - (self.hamiltonian)(&q_minus.view(), p))
126                / (2.0 * h);
127
128            let mut p_plus = p.to_owned();
129            let mut p_minus = p.to_owned();
130            p_plus[i] += h;
131            p_minus[i] -= h;
132            grad_p[i] = ((self.hamiltonian)(q, &p_plus.view())
133                - (self.hamiltonian)(q, &p_minus.view()))
134                / (2.0 * h);
135        }
136
137        // Implicit midpoint with discrete gradient
138        let q_mid = q + &grad_p * (dt / 2.0);
139        let p_mid = p - &grad_q * (dt / 2.0);
140
141        // Compute discrete gradient at midpoint
142        let mut grad_q_mid = Array1::zeros(self.dim);
143        let mut grad_p_mid = Array1::zeros(self.dim);
144
145        for i in 0..self.dim {
146            let mut q_plus = q_mid.clone();
147            let mut q_minus = q_mid.clone();
148            q_plus[i] += h;
149            q_minus[i] -= h;
150            grad_q_mid[i] = ((self.hamiltonian)(&q_plus.view(), &p_mid.view())
151                - (self.hamiltonian)(&q_minus.view(), &p_mid.view()))
152                / (2.0 * h);
153
154            let mut p_plus = p_mid.clone();
155            let mut p_minus = p_mid.clone();
156            p_plus[i] += h;
157            p_minus[i] -= h;
158            grad_p_mid[i] = ((self.hamiltonian)(&q_mid.view(), &p_plus.view())
159                - (self.hamiltonian)(&q_mid.view(), &p_minus.view()))
160                / (2.0 * h);
161        }
162
163        let q_new = q + &grad_p_mid * dt;
164        let p_new = p - &grad_q_mid * dt;
165
166        Ok((q_new, p_new))
167    }
168
169    /// Average vector field method
170    pub fn average_vector_field_step(
171        &self,
172        q: &ArrayView1<f64>,
173        p: &ArrayView1<f64>,
174        dt: f64,
175    ) -> IntegrateResult<(Array1<f64>, Array1<f64>)> {
176        // Simplified AVF - uses quadrature to average the vector field
177        let _n_quad = 3; // Number of quadrature points
178        let weights = [1.0 / 6.0, 4.0 / 6.0, 1.0 / 6.0]; // Simpson's rule weights
179        let nodes = [0.0, 0.5, 1.0]; // Quadrature nodes
180
181        let mut q_avg = Array1::zeros(self.dim);
182        let mut p_avg = Array1::zeros(self.dim);
183
184        for (&w, &s) in weights.iter().zip(nodes.iter()) {
185            // Linear interpolation
186            let q_s = q * (1.0 - s) + &(q + &(p * dt)) * s;
187            let p_s = p * (1.0 - s) + &(p - &(q * dt)) * s; // Simplified
188
189            // Compute gradients at interpolated point
190            let h = 1e-8;
191            for j in 0..self.dim {
192                let mut q_plus = q_s.clone();
193                let mut q_minus = q_s.clone();
194                q_plus[j] += h;
195                q_minus[j] -= h;
196
197                p_avg[j] += w * (self.hamiltonian)(&q_plus.view(), &p_s.view())
198                    - (self.hamiltonian)(&q_minus.view(), &p_s.view()) / (2.0 * h);
199
200                let mut p_plus = p_s.clone();
201                let mut p_minus = p_s.clone();
202                p_plus[j] += h;
203                p_minus[j] -= h;
204
205                q_avg[j] += w * (self.hamiltonian)(&q_s.view(), &p_plus.view())
206                    - (self.hamiltonian)(&q_s.view(), &p_minus.view()) / (2.0 * h);
207            }
208        }
209
210        let q_new = q + &q_avg * dt;
211        let p_new = p - &p_avg * dt;
212
213        Ok((q_new, p_new))
214    }
215}
216
217/// Momentum-preserving integrator
218pub struct MomentumPreservingMethod {
219    /// System dimension
220    #[allow(dead_code)]
221    dim: usize,
222    /// Force function
223    force: ForceFn,
224    /// Mass matrix (diagonal)
225    mass: Array1<f64>,
226}
227
228impl MomentumPreservingMethod {
229    /// Create a new momentum-preserving integrator
230    pub fn new(dim: usize, force: ForceFn, mass: Array1<f64>) -> Self {
231        Self { dim, force, mass }
232    }
233
234    /// Integrate one step preserving total momentum
235    pub fn step(
236        &self,
237        x: &ArrayView1<f64>,
238        v: &ArrayView1<f64>,
239        dt: f64,
240    ) -> IntegrateResult<(Array1<f64>, Array1<f64>)> {
241        // Compute forces
242        let f = (self.force)(x);
243
244        // Check momentum conservation (for internal forces, total force should be zero)
245        let total_force: f64 = f.sum();
246        if total_force.abs() > 1e-10 {
247            // Apply momentum correction
248            let f_corrected = &f - total_force / self.dim as f64;
249
250            // Velocity Verlet with corrected forces
251            let a = &f_corrected / &self.mass;
252            let x_new = x + v * dt + &a * (dt * dt / 2.0);
253
254            let f_new = (self.force)(&x_new.view());
255            let f_new_corrected = &f_new - f_new.sum() / self.dim as f64;
256            let a_new = &f_new_corrected / &self.mass;
257
258            let v_new = v + (&a + &a_new) * (dt / 2.0);
259
260            Ok((x_new, v_new))
261        } else {
262            // Standard velocity Verlet
263            let a = &f / &self.mass;
264            let x_new = x + v * dt + &a * (dt * dt / 2.0);
265
266            let f_new = (self.force)(&x_new.view());
267            let a_new = &f_new / &self.mass;
268
269            let v_new = v + (&a + &a_new) * (dt / 2.0);
270
271            Ok((x_new, v_new))
272        }
273    }
274}
275
276/// Conservation checker for verifying invariant preservation
277pub struct ConservationChecker;
278
279impl ConservationChecker {
280    /// Check energy conservation
281    pub fn check_energy<H>(trajectory: &[(Array1<f64>, Array1<f64>)], hamiltonian: H) -> Vec<f64>
282    where
283        H: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
284    {
285        trajectory
286            .iter()
287            .map(|(q, p)| hamiltonian(&q.view(), &p.view()))
288            .collect()
289    }
290
291    /// Check momentum conservation
292    pub fn check_momentum(
293        trajectory: &[(Array1<f64>, Array1<f64>)],
294        masses: &ArrayView1<f64>,
295    ) -> Vec<Array1<f64>> {
296        trajectory.iter().map(|(_, v)| v * masses).collect()
297    }
298
299    /// Check angular momentum conservation
300    pub fn check_angular_momentum(
301        trajectory: &[(Array1<f64>, Array1<f64>)],
302        masses: &ArrayView1<f64>,
303    ) -> Vec<Array1<f64>> {
304        trajectory
305            .iter()
306            .map(|(x, v)| {
307                // For 3D systems, compute L = r × p
308                if x.len() == 3 && v.len() == 3 {
309                    let px = v[0] * masses[0];
310                    let py = v[1] * masses[1];
311                    let pz = v[2] * masses[2];
312
313                    Array1::from_vec(vec![
314                        x[1] * pz - x[2] * py,
315                        x[2] * px - x[0] * pz,
316                        x[0] * py - x[1] * px,
317                    ])
318                } else {
319                    Array1::zeros(3)
320                }
321            })
322            .collect()
323    }
324
325    /// Compute relative error in conservation
326    pub fn relative_error(values: &[f64]) -> f64 {
327        if values.is_empty() {
328            return 0.0;
329        }
330
331        let initial = values[0];
332        let max_deviation = values
333            .iter()
334            .map(|&v| (v - initial).abs())
335            .fold(0.0, f64::max);
336
337        max_deviation / (1.0 + initial.abs())
338    }
339}
340
341/// Splitting method for separable Hamiltonians
342pub struct SplittingIntegrator {
343    /// Kinetic energy part T(p)
344    kinetic: KineticFn,
345    /// Potential energy part V(q)
346    potential: PotentialFn,
347    /// System dimension
348    #[allow(dead_code)]
349    dim: usize,
350    /// Splitting coefficients
351    coefficients: Vec<(f64, f64)>,
352}
353
354impl SplittingIntegrator {
355    /// Create a new splitting integrator with Strang splitting
356    pub fn strang(kinetic: KineticFn, potential: PotentialFn, dim: usize) -> Self {
357        let coefficients = vec![(0.5, 1.0), (0.5, 0.0)];
358        Self {
359            kinetic,
360            potential,
361            dim,
362            coefficients,
363        }
364    }
365
366    /// Create with Yoshida 4th order coefficients
367    pub fn yoshida4(kinetic: KineticFn, potential: PotentialFn, dim: usize) -> Self {
368        let x1 = 1.0 / (2.0 - 2.0_f64.powf(1.0 / 3.0));
369        let x0 = -2.0_f64.powf(1.0 / 3.0) * x1;
370
371        let coefficients = vec![
372            (x1 / 2.0, x1),
373            ((x0 + x1) / 2.0, x0),
374            ((x0 + x1) / 2.0, x1),
375            (x1 / 2.0, 0.0),
376        ];
377
378        Self {
379            kinetic,
380            potential,
381            dim,
382            coefficients,
383        }
384    }
385
386    /// Perform one splitting step
387    pub fn step(
388        &self,
389        q: &ArrayView1<f64>,
390        p: &ArrayView1<f64>,
391        dt: f64,
392    ) -> IntegrateResult<(Array1<f64>, Array1<f64>)> {
393        let mut q_current = q.to_owned();
394        let mut p_current = p.to_owned();
395
396        for &(a, b) in &self.coefficients {
397            // Kick: p' = p - a*dt*∇V(q)
398            if a != 0.0 {
399                let grad_v = self.gradient_potential(&q_current.view());
400                p_current = p_current - grad_v * (a * dt);
401            }
402
403            // Drift: q' = q + b*dt*∇T(p')
404            if b != 0.0 {
405                let grad_t = self.gradient_kinetic(&p_current.view());
406                q_current = q_current + grad_t * (b * dt);
407            }
408        }
409
410        Ok((q_current, p_current))
411    }
412
413    /// Compute gradient of kinetic energy
414    fn gradient_kinetic(&self, p: &ArrayView1<f64>) -> Array1<f64> {
415        let h = 1e-8;
416        let mut grad = Array1::zeros(self.dim);
417
418        for i in 0..self.dim {
419            let mut p_plus = p.to_owned();
420            let mut p_minus = p.to_owned();
421            p_plus[i] += h;
422            p_minus[i] -= h;
423
424            grad[i] =
425                ((self.kinetic)(&p_plus.view()) - (self.kinetic)(&p_minus.view())) / (2.0 * h);
426        }
427
428        grad
429    }
430
431    /// Compute gradient of potential energy
432    fn gradient_potential(&self, q: &ArrayView1<f64>) -> Array1<f64> {
433        let h = 1e-8;
434        let mut grad = Array1::zeros(self.dim);
435
436        for i in 0..self.dim {
437            let mut q_plus = q.to_owned();
438            let mut q_minus = q.to_owned();
439            q_plus[i] += h;
440            q_minus[i] -= h;
441
442            grad[i] =
443                ((self.potential)(&q_plus.view()) - (self.potential)(&q_minus.view())) / (2.0 * h);
444        }
445
446        grad
447    }
448}
449
450/// Energy-momentum conserving integrator for nonlinear elastodynamics
451pub struct EnergyMomentumIntegrator {
452    /// Mass matrix
453    mass: Array1<f64>,
454    /// Stiffness function
455    stiffness: StiffnessFn,
456    /// System dimension
457    #[allow(dead_code)]
458    dim: usize,
459}
460
461impl EnergyMomentumIntegrator {
462    /// Create a new energy-momentum integrator
463    pub fn new(mass: Array1<f64>, stiffness: StiffnessFn) -> Self {
464        let dim = mass.len();
465        Self {
466            mass,
467            stiffness,
468            dim,
469        }
470    }
471
472    /// Integrate one step with energy-momentum conservation
473    pub fn step(
474        &self,
475        u: &ArrayView1<f64>,
476        v: &ArrayView1<f64>,
477        dt: f64,
478    ) -> IntegrateResult<(Array1<f64>, Array1<f64>)> {
479        // Predict displacement
480        let u_pred = u + v * dt;
481
482        // Compute average internal force
483        let f0 = (self.stiffness)(u);
484        let f1 = (self.stiffness)(&u_pred.view());
485        let f_avg = (&f0 + &f1) / 2.0;
486
487        // Algorithmic acceleration
488        let a_alg = &f_avg / &self.mass;
489
490        // Update
491        let u_new = u + v * dt + &a_alg * (dt * dt / 2.0);
492        let v_new = v + &a_alg * dt;
493
494        // Energy-momentum correction
495        let momentum_error: f64 = (&v_new * &self.mass).sum() - (v.to_owned() * &self.mass).sum();
496        if momentum_error.abs() > 1e-12 {
497            let v_corrected = &v_new - momentum_error / self.mass.sum();
498            Ok((u_new, v_corrected))
499        } else {
500            Ok((u_new, v_new))
501        }
502    }
503}
504
505/// Störmer-Verlet method for constrained systems
506pub struct ConstrainedIntegrator {
507    /// Constraint function g(q) = 0
508    constraints: ConstraintFn,
509    /// Constraint Jacobian
510    constraint_jacobian: ConstraintJacobianFn,
511    /// System dimension
512    #[allow(dead_code)]
513    dim: usize,
514    /// Number of constraints
515    #[allow(dead_code)]
516    n_constraints: usize,
517    /// Tolerance for constraint satisfaction
518    tol: f64,
519}
520
521impl ConstrainedIntegrator {
522    /// Create a new constrained integrator
523    pub fn new(
524        constraints: ConstraintFn,
525        constraint_jacobian: ConstraintJacobianFn,
526        dim: usize,
527        n_constraints: usize,
528    ) -> Self {
529        Self {
530            constraints,
531            constraint_jacobian,
532            dim,
533            n_constraints,
534            tol: 1e-10,
535        }
536    }
537
538    /// SHAKE algorithm for position constraints
539    pub fn shake_step(
540        &self,
541        q: &ArrayView1<f64>,
542        p: &ArrayView1<f64>,
543        dt: f64,
544        force: &Array1<f64>,
545    ) -> IntegrateResult<(Array1<f64>, Array1<f64>)> {
546        // Unconstrained step
547        let q_tilde = q + p * dt;
548        let p_tilde = p + force * dt;
549
550        // SHAKE iteration for position constraints
551        let mut q_new = q_tilde.clone();
552        let mut lambda;
553
554        for _ in 0..100 {
555            let g = (self.constraints)(&q_new.view());
556            if g.mapv(f64::abs).sum() < self.tol {
557                break;
558            }
559
560            let g_matrix = (self.constraint_jacobian)(&q_new.view());
561
562            // Solve for Lagrange multipliers
563            // G * G^T * λ = -g
564            let ggt = g_matrix.dot(&g_matrix.t());
565            lambda = self.solve_linear_system(&ggt, &(-&g))?;
566
567            // Update position
568            let correction = g_matrix.t().dot(&lambda);
569            q_new = &q_new + &correction * dt * dt;
570        }
571
572        // RATTLE for velocity constraints
573        let g_new = (self.constraint_jacobian)(&q_new.view());
574        let gv = g_new.dot(&p_tilde);
575
576        // Solve G * G^T * μ = -G * v
577        let ggt = g_new.dot(&g_new.t());
578        let mu = self.solve_linear_system(&ggt, &(-&gv))?;
579
580        let p_correction = g_new.t().dot(&mu);
581        let p_new = &p_tilde + &p_correction;
582
583        Ok((q_new, p_new))
584    }
585
586    /// Simple linear system solver (for small systems)
587    fn solve_linear_system(
588        &self,
589        a: &scirs2_core::ndarray::Array2<f64>,
590        b: &Array1<f64>,
591    ) -> IntegrateResult<Array1<f64>> {
592        // LU decomposition would be more robust
593        let n = b.len();
594        let mut x = Array1::zeros(n);
595
596        // Simplified Gaussian elimination
597        let mut a_copy = a.clone();
598        let mut b_copy = b.clone();
599
600        for i in 0..n {
601            // Pivot
602            let pivot = a_copy[[i, i]];
603            if pivot.abs() < 1e-14 {
604                return Err(crate::error::IntegrateError::ComputationError(
605                    "Singular constraint matrix".to_string(),
606                ));
607            }
608
609            // Eliminate
610            for j in (i + 1)..n {
611                let factor = a_copy[[j, i]] / pivot;
612                for k in i..n {
613                    a_copy[[j, k]] -= factor * a_copy[[i, k]];
614                }
615                b_copy[j] -= factor * b_copy[i];
616            }
617        }
618
619        // Back substitution
620        for i in (0..n).rev() {
621            let mut sum = b_copy[i];
622            for j in (i + 1)..n {
623                sum -= a_copy[[i, j]] * x[j];
624            }
625            x[i] = sum / a_copy[[i, i]];
626        }
627
628        Ok(x)
629    }
630}
631
632/// Multi-symplectic integrator for PDEs
633pub struct MultiSymplecticIntegrator {
634    /// Spatial dimension
635    #[allow(dead_code)]
636    spatial_dim: usize,
637    /// Number of fields
638    #[allow(dead_code)]
639    n_fields: usize,
640    /// Symplectic structure matrices
641    k: scirs2_core::ndarray::Array2<f64>,
642    l: scirs2_core::ndarray::Array2<f64>,
643}
644
impl MultiSymplecticIntegrator {
    /// Create a new multi-symplectic integrator
    ///
    /// `k` and `l` are the matrices of the system `K z_t + L z_x = ∇S(z)`.
    /// `spatial_dim` and `n_fields` are stored but not read by `preissman_step`.
    pub fn new(
        spatial_dim: usize,
        n_fields: usize,
        k: scirs2_core::ndarray::Array2<f64>,
        l: scirs2_core::ndarray::Array2<f64>,
    ) -> Self {
        Self {
            spatial_dim,
            n_fields,
            k,
            l,
        }
    }

    /// Preissman box scheme
    ///
    /// Each row `i` of `z` is the state at spatial grid point `i`. `s`
    /// returns the gradient `∇S` at a state vector. Rows `1..nx` are updated
    /// in order, left to right. Each update uses the already-updated row `i - 1`
    /// of `z_new`, so the sweep is sequential. Row 0 is copied unchanged and
    /// acts as the boundary.
    ///
    /// NOTE(review): each row receives a single damped residual correction
    /// (`row - 0.5 * residual`), not a converged Newton solve. The discrete
    /// conservation law is therefore only approximately satisfied.
    ///
    /// NOTE(review): the residual (length `k.nrows()`) is assigned into a row
    /// of `z_new`, so `k` and `l` are presumably `n_fields x n_fields`. Confirm
    /// this with the callers.
    ///
    /// Always returns `Ok` in the visible code.
    pub fn preissman_step(
        &self,
        z: &scirs2_core::ndarray::Array2<f64>,
        s: &GradientFn,
        dt: f64,
        dx: f64,
    ) -> IntegrateResult<scirs2_core::ndarray::Array2<f64>> {
        let (nx_, _) = z.dim();
        let mut z_new = z.clone();

        // Iterate through spatial grid
        for i in 1..nx_ {
            // Box average of the four corners (old/new time x left/right point)
            let z_avg = (&z.row(i - 1) + &z.row(i) + z_new.row(i - 1) + z_new.row(i)) / 4.0;

            // Compute gradient of S
            let grad_s = s(&z_avg.view());

            // Multi-symplectic conservation law
            // K(z_t) + L(z_x) = ∇S(z)
            // Box-scheme differences: z_t averages the two spatial points,
            // z_x averages the two time levels.
            let z_t = (&z_new.row(i) - &z.row(i) + z_new.row(i - 1) - z.row(i - 1)) / (2.0 * dt);
            let z_x = (&z_new.row(i) + &z.row(i) - z_new.row(i - 1) - z.row(i - 1)) / (2.0 * dx);

            let residual = self.k.dot(&z_t) + self.l.dot(&z_x) - grad_s;

            // Newton iteration (simplified): one residual step damped by 0.5,
            // with no Jacobian solve and no iteration to convergence.
            let current_row = z_new.row(i).to_owned();
            let update = &current_row - &residual * 0.5;
            z_new.row_mut(i).assign(&update);
        }

        Ok(z_new)
    }
}
696
697/// Example invariants
698pub mod invariants {
699    use super::*;
700
701    /// Energy invariant for Hamiltonian systems
702    pub struct EnergyInvariant {
703        hamiltonian: HamiltonianFn,
704    }
705
706    impl EnergyInvariant {
707        pub fn new(hamiltonian: HamiltonianFn) -> Self {
708            Self { hamiltonian }
709        }
710    }
711
712    impl GeometricInvariant for EnergyInvariant {
713        fn evaluate(&self, x: &ArrayView1<f64>, v: &ArrayView1<f64>, t: f64) -> f64 {
714            (self.hamiltonian)(x, v)
715        }
716
717        fn name(&self) -> &'static str {
718            "Energy"
719        }
720    }
721
722    /// Linear momentum invariant
723    pub struct LinearMomentumInvariant {
724        masses: Array1<f64>,
725        component: usize,
726    }
727
728    impl LinearMomentumInvariant {
729        pub fn new(masses: Array1<f64>, component: usize) -> Self {
730            Self { masses, component }
731        }
732    }
733
734    impl GeometricInvariant for LinearMomentumInvariant {
735        fn evaluate(&self, x: &ArrayView1<f64>, v: &ArrayView1<f64>, t: f64) -> f64 {
736            v[self.component] * self.masses[self.component]
737        }
738
739        fn name(&self) -> &'static str {
740            "Linear Momentum"
741        }
742    }
743
744    /// Angular momentum invariant (for 2D systems)
745    pub struct AngularMomentumInvariant2D {
746        masses: Array1<f64>,
747    }
748
749    impl AngularMomentumInvariant2D {
750        pub fn new(masses: Array1<f64>) -> Self {
751            Self { masses }
752        }
753    }
754
755    impl GeometricInvariant for AngularMomentumInvariant2D {
756        fn evaluate(&self, x: &ArrayView1<f64>, v: &ArrayView1<f64>, t: f64) -> f64 {
757            // L = m(xv_y - yv_x) for 2D
758            let n = x.len() / 2;
759            let mut l = 0.0;
760
761            for i in 0..n {
762                let xi = x[2 * i];
763                let yi = x[2 * i + 1];
764                let vxi = v[2 * i];
765                let vyi = v[2 * i + 1];
766                l += self.masses[i] * (xi * vyi - yi * vxi);
767            }
768
769            l
770        }
771
772        fn name(&self) -> &'static str {
773            "Angular Momentum"
774        }
775    }
776}
777
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EnergyPreservingMethod, MomentumPreservingMethod};
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{Array1, ArrayView1};

    /// Harmonic oscillator `H = p²/2m + kq²/2`, started at `q = 1, p = 0`.
    ///
    /// It is integrated for roughly one period (63 steps of 0.1 ≈ 2π). The
    /// energy must stay within 1e-3 of its initial value.
    #[test]
    fn test_energy_preservation() {
        let (m, k) = (1.0, 1.0);
        let hamiltonian = Box::new(move |q: &ArrayView1<f64>, p: &ArrayView1<f64>| {
            p[0] * p[0] / (2.0 * m) + k * q[0] * q[0] / 2.0
        });
        let integrator = EnergyPreservingMethod::new(hamiltonian.clone(), 1);

        let mut q = Array1::from_vec(vec![1.0]);
        let mut p = Array1::from_vec(vec![0.0]);
        let initial_energy = hamiltonian(&q.view(), &p.view());

        let dt = 0.1;
        for _ in 0..63 {
            let (q_next, p_next) = integrator
                .discrete_gradient_step(&q.view(), &p.view(), dt)
                .unwrap();
            q = q_next;
            p = p_next;
        }

        let final_energy = hamiltonian(&q.view(), &p.view());
        assert_relative_eq!(initial_energy, final_energy, epsilon = 1e-3);
    }

    /// Two particles in 2D, coupled by an internal force along their separation.
    ///
    /// After one step, the scalar total momentum `sum(m * v)` must be unchanged.
    #[test]
    fn test_momentum_preservation() {
        let dim = 4;
        let force = Box::new(|x: &ArrayView1<f64>| {
            let dx = x[2] - x[0];
            let dy = x[3] - x[1];
            let r = (dx * dx + dy * dy).sqrt();
            let f = if r > 0.0 { 1.0 / r } else { 0.0 };
            Array1::from_vec(vec![f * dx, f * dy, -f * dx, -f * dy])
        });

        let mass = Array1::from_vec(vec![1.0, 1.0, 2.0, 2.0]);
        let integrator = MomentumPreservingMethod::new(dim, force, mass.clone());

        let x0 = Array1::from_vec(vec![0.0, 0.0, 1.0, 0.0]);
        let v0 = Array1::from_vec(vec![0.0, 0.1, 0.0, -0.05]);
        let initial_momentum: f64 = (&v0 * &mass).sum();

        let (_x1, v1) = integrator.step(&x0.view(), &v0.view(), 0.01).unwrap();
        let final_momentum: f64 = (&v1 * &mass).sum();

        assert_relative_eq!(initial_momentum, final_momentum, epsilon = 1e-12);
    }
}