Skip to main content

sidereon_core/astro/integrators/
rk4.rs

1use crate::astro::error::PropagationError;
2use crate::astro::integrators::{DynamicsModel, Integrator};
3use crate::astro::propagator::api::{
4    validate_integrator_epoch, validate_integrator_options, IntegratorOptions, PropagationContext,
5};
6use crate::astro::propagator::result::{
7    validate_propagation_result, PropagationPoint, PropagationResult, PropagationStats,
8};
9use crate::astro::state::{CartesianState, StateDerivative};
10
/// Classical fourth-order Runge–Kutta integrator with a fixed nominal step size.
///
/// Stateless: all configuration comes from the [`IntegratorOptions`] passed to
/// [`Integrator::propagate`], so the type is a zero-sized marker that is free to copy.
#[derive(Debug, Clone, Copy, Default)]
pub struct RK4;
12
13impl Integrator for RK4 {
14    fn propagate(
15        &self,
16        initial: CartesianState,
17        t_end_seconds: f64,
18        rhs: &dyn DynamicsModel,
19        ctx: &PropagationContext,
20        opts: &IntegratorOptions,
21    ) -> Result<PropagationResult, PropagationError> {
22        validate_integrator_options(opts)?;
23        validate_integrator_epoch(initial.epoch_tdb_seconds, "initial.epoch_tdb_seconds")?;
24        validate_integrator_epoch(t_end_seconds, "t_end_seconds")?;
25
26        let mut state = initial;
27        let mut t = initial.epoch_tdb_seconds;
28        let dt_target = t_end_seconds - t;
29        let sign = dt_target.signum();
30        let target_abs = dt_target.abs();
31
32        let h_initial = opts.initial_step.min(target_abs) * sign;
33        let mut h = h_initial;
34        let mut steps = 0;
35        let mut points = Vec::new();
36
37        points.push(PropagationPoint {
38            epoch_tdb_seconds: t,
39            position_km: state.position_array(),
40            velocity_km_s: state.velocity_array(),
41        });
42
43        while (t - initial.epoch_tdb_seconds).abs() < target_abs {
44            if steps >= opts.max_steps {
45                return Err(PropagationError::MaxStepsExceeded);
46            }
47
48            if (t + h - initial.epoch_tdb_seconds).abs() > target_abs {
49                h = t_end_seconds - t;
50            }
51
52            let next_state = self.step(state, h, rhs, ctx)?;
53            state = next_state;
54            t += h;
55            steps += 1;
56
57            if opts.dense_output {
58                points.push(PropagationPoint {
59                    epoch_tdb_seconds: t,
60                    position_km: state.position_array(),
61                    velocity_km_s: state.velocity_array(),
62                });
63            }
64        }
65
66        if !opts.dense_output {
67            points.push(PropagationPoint {
68                epoch_tdb_seconds: t,
69                position_km: state.position_array(),
70                velocity_km_s: state.velocity_array(),
71            });
72        }
73
74        validate_propagation_result(PropagationResult {
75            final_state: state,
76            points,
77            events: Vec::new(),
78            stats: PropagationStats {
79                accepted_steps: steps,
80                rejected_steps: 0,
81                evaluations: steps * 4,
82            },
83            dense: None,
84        })
85    }
86}
87
88impl RK4 {
89    fn step(
90        &self,
91        state: CartesianState,
92        h: f64,
93        rhs: &dyn DynamicsModel,
94        ctx: &PropagationContext,
95    ) -> Result<CartesianState, PropagationError> {
96        let k1 = rhs.derivative(&state, ctx)?;
97
98        let s2 = self.advance(&state, &k1, h / 2.0);
99        let k2 = rhs.derivative(&s2, ctx)?;
100
101        let s3 = self.advance(&state, &k2, h / 2.0);
102        let k3 = rhs.derivative(&s3, ctx)?;
103
104        let s4 = self.advance(&state, &k3, h);
105        let k4 = rhs.derivative(&s4, ctx)?;
106
107        let dpos =
108            (k1.dpos_km_s + k2.dpos_km_s * 2.0 + k3.dpos_km_s * 2.0 + k4.dpos_km_s) * (h / 6.0);
109        let dvel =
110            (k1.dvel_km_s2 + k2.dvel_km_s2 * 2.0 + k3.dvel_km_s2 * 2.0 + k4.dvel_km_s2) * (h / 6.0);
111
112        Ok(CartesianState {
113            epoch_tdb_seconds: state.epoch_tdb_seconds + h,
114            position_km: state.position_km + dpos,
115            velocity_km_s: state.velocity_km_s + dvel,
116        })
117    }
118
119    fn advance(&self, state: &CartesianState, deriv: &StateDerivative, h: f64) -> CartesianState {
120        CartesianState {
121            epoch_tdb_seconds: state.epoch_tdb_seconds + h,
122            position_km: state.position_km + deriv.dpos_km_s * h,
123            velocity_km_s: state.velocity_km_s + deriv.dvel_km_s2 * h,
124        }
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::Vector3;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Force-free dynamics (constant velocity) that counts how often it is evaluated.
    struct CountingDynamics<'a> {
        calls: &'a AtomicUsize,
    }

    impl DynamicsModel for CountingDynamics<'_> {
        fn derivative(
            &self,
            state: &CartesianState,
            _ctx: &PropagationContext,
        ) -> Result<StateDerivative, PropagationError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(StateDerivative {
                dpos_km_s: state.velocity_km_s,
                dvel_km_s2: Vector3::zeros(),
            })
        }
    }

    /// Dynamics with an infinite x-acceleration, used to force non-finite outputs.
    struct InfiniteAcceleration;

    impl DynamicsModel for InfiniteAcceleration {
        fn derivative(
            &self,
            state: &CartesianState,
            _ctx: &PropagationContext,
        ) -> Result<StateDerivative, PropagationError> {
            Ok(StateDerivative {
                dpos_km_s: state.velocity_km_s,
                dvel_km_s2: Vector3::new(f64::INFINITY, 0.0, 0.0),
            })
        }
    }

    fn initial_state() -> CartesianState {
        CartesianState {
            epoch_tdb_seconds: 0.0,
            position_km: Vector3::new(7000.0, 0.0, 0.0),
            velocity_km_s: Vector3::new(0.0, 7.5, 0.0),
        }
    }

    #[test]
    fn rejects_non_finite_epochs_before_derivative_evaluation() {
        let base = initial_state();
        let with_epoch = |epoch_tdb_seconds: f64| CartesianState {
            epoch_tdb_seconds,
            ..base
        };
        let cases = [
            (with_epoch(f64::NAN), 60.0, "initial.epoch_tdb_seconds"),
            (with_epoch(f64::INFINITY), 60.0, "initial.epoch_tdb_seconds"),
            (with_epoch(f64::NEG_INFINITY), 60.0, "initial.epoch_tdb_seconds"),
            (base, f64::NAN, "t_end_seconds"),
            (base, f64::INFINITY, "t_end_seconds"),
            (base, f64::NEG_INFINITY, "t_end_seconds"),
        ];

        for (initial, t_end_seconds, field) in cases {
            let calls = AtomicUsize::new(0);
            let dynamics = CountingDynamics { calls: &calls };
            let ctx = PropagationContext::default();
            let opts = IntegratorOptions::default();

            let error = RK4
                .propagate(initial, t_end_seconds, &dynamics, &ctx, &opts)
                .expect_err("non-finite RK4 epoch must fail validation");

            assert_invalid_input(error, field, "not finite");
            assert_eq!(
                calls.load(Ordering::SeqCst),
                0,
                "non-finite {field} must be rejected before integration starts"
            );
        }
    }

    #[test]
    fn finite_epochs_integrate_as_before() {
        let calls = AtomicUsize::new(0);
        let dynamics = CountingDynamics { calls: &calls };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions {
            initial_step: 10.0,
            ..IntegratorOptions::default()
        };

        let result = RK4
            .propagate(initial_state(), 60.0, &dynamics, &ctx, &opts)
            .expect("finite RK4 epochs must remain valid");

        assert_eq!(result.final_state.epoch_tdb_seconds, 60.0);
        assert_eq!(
            result.final_state.position_km.x.to_bits(),
            7000.0f64.to_bits()
        );
        assert_eq!(
            result.final_state.position_km.y.to_bits(),
            450.0f64.to_bits()
        );
        assert_eq!(
            result.final_state.velocity_km_s.y.to_bits(),
            7.5f64.to_bits()
        );
        assert_eq!(result.stats.accepted_steps, 6);
        assert_eq!(calls.load(Ordering::SeqCst), 24);
    }

    #[test]
    fn backward_propagation_reaches_earlier_epoch() {
        let calls = AtomicUsize::new(0);
        let dynamics = CountingDynamics { calls: &calls };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions {
            initial_step: 10.0,
            ..IntegratorOptions::default()
        };

        let result = RK4
            .propagate(initial_state(), -60.0, &dynamics, &ctx, &opts)
            .expect("backward RK4 propagation must succeed");

        assert_eq!(result.final_state.epoch_tdb_seconds, -60.0);
        assert_eq!(
            result.final_state.position_km.x.to_bits(),
            7000.0f64.to_bits()
        );
        assert_eq!(
            result.final_state.position_km.y.to_bits(),
            (-450.0f64).to_bits()
        );
        assert_eq!(
            result.final_state.velocity_km_s.y.to_bits(),
            7.5f64.to_bits()
        );
        assert_eq!(result.stats.accepted_steps, 6);
        assert_eq!(calls.load(Ordering::SeqCst), 24);
    }

    #[test]
    fn rejects_non_finite_outputs() {
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions {
            initial_step: 1.0,
            ..IntegratorOptions::default()
        };

        let error = RK4
            .propagate(initial_state(), 1.0, &InfiniteAcceleration, &ctx, &opts)
            .expect_err("non-finite RK4 result must be rejected");

        assert_numerical_failure(error, "final_state.position_km", "not finite");
    }

    fn assert_invalid_input(error: PropagationError, field: &str, reason: &str) {
        match error {
            PropagationError::InvalidInput(message) => {
                assert!(message.contains(field), "{message}");
                assert!(message.contains(reason), "{message}");
            }
            other => panic!("expected invalid propagation input for {field}, got {other:?}"),
        }
    }

    fn assert_numerical_failure(error: PropagationError, field: &str, reason: &str) {
        match error {
            PropagationError::NumericalFailure(message) => {
                assert!(message.contains(field), "{message}");
                assert!(message.contains(reason), "{message}");
            }
            other => panic!("expected numerical failure for {field}, got {other:?}"),
        }
    }
}