sidereon_core/astro/integrators/
rk4.rs1use crate::astro::error::PropagationError;
2use crate::astro::integrators::{DynamicsModel, Integrator};
3use crate::astro::propagator::api::{
4 validate_integrator_epoch, validate_integrator_options, IntegratorOptions, PropagationContext,
5};
6use crate::astro::propagator::result::{
7 validate_propagation_result, PropagationPoint, PropagationResult, PropagationStats,
8};
9use crate::astro::state::{CartesianState, StateDerivative};
10
/// Classical fixed-step fourth-order Runge–Kutta integrator.
///
/// The integrator carries no state of its own. Each call to
/// [`Integrator::propagate`] steps with `IntegratorOptions::initial_step`.
/// The step is only shortened on the final step, so that the end epoch is hit
/// exactly. Steps are never rejected or adapted.
#[derive(Debug, Clone, Copy, Default)]
pub struct RK4;
12
13impl Integrator for RK4 {
14 fn propagate(
15 &self,
16 initial: CartesianState,
17 t_end_seconds: f64,
18 rhs: &dyn DynamicsModel,
19 ctx: &PropagationContext,
20 opts: &IntegratorOptions,
21 ) -> Result<PropagationResult, PropagationError> {
22 validate_integrator_options(opts)?;
23 validate_integrator_epoch(initial.epoch_tdb_seconds, "initial.epoch_tdb_seconds")?;
24 validate_integrator_epoch(t_end_seconds, "t_end_seconds")?;
25
26 let mut state = initial;
27 let mut t = initial.epoch_tdb_seconds;
28 let dt_target = t_end_seconds - t;
29 let sign = dt_target.signum();
30 let target_abs = dt_target.abs();
31
32 let h_initial = opts.initial_step.min(target_abs) * sign;
33 let mut h = h_initial;
34 let mut steps = 0;
35 let mut points = Vec::new();
36
37 points.push(PropagationPoint {
38 epoch_tdb_seconds: t,
39 position_km: state.position_array(),
40 velocity_km_s: state.velocity_array(),
41 });
42
43 while (t - initial.epoch_tdb_seconds).abs() < target_abs {
44 if steps >= opts.max_steps {
45 return Err(PropagationError::MaxStepsExceeded);
46 }
47
48 if (t + h - initial.epoch_tdb_seconds).abs() > target_abs {
49 h = t_end_seconds - t;
50 }
51
52 let next_state = self.step(state, h, rhs, ctx)?;
53 state = next_state;
54 t += h;
55 steps += 1;
56
57 if opts.dense_output {
58 points.push(PropagationPoint {
59 epoch_tdb_seconds: t,
60 position_km: state.position_array(),
61 velocity_km_s: state.velocity_array(),
62 });
63 }
64 }
65
66 if !opts.dense_output {
67 points.push(PropagationPoint {
68 epoch_tdb_seconds: t,
69 position_km: state.position_array(),
70 velocity_km_s: state.velocity_array(),
71 });
72 }
73
74 validate_propagation_result(PropagationResult {
75 final_state: state,
76 points,
77 events: Vec::new(),
78 stats: PropagationStats {
79 accepted_steps: steps,
80 rejected_steps: 0,
81 evaluations: steps * 4,
82 },
83 dense: None,
84 })
85 }
86}
87
88impl RK4 {
89 fn step(
90 &self,
91 state: CartesianState,
92 h: f64,
93 rhs: &dyn DynamicsModel,
94 ctx: &PropagationContext,
95 ) -> Result<CartesianState, PropagationError> {
96 let k1 = rhs.derivative(&state, ctx)?;
97
98 let s2 = self.advance(&state, &k1, h / 2.0);
99 let k2 = rhs.derivative(&s2, ctx)?;
100
101 let s3 = self.advance(&state, &k2, h / 2.0);
102 let k3 = rhs.derivative(&s3, ctx)?;
103
104 let s4 = self.advance(&state, &k3, h);
105 let k4 = rhs.derivative(&s4, ctx)?;
106
107 let dpos =
108 (k1.dpos_km_s + k2.dpos_km_s * 2.0 + k3.dpos_km_s * 2.0 + k4.dpos_km_s) * (h / 6.0);
109 let dvel =
110 (k1.dvel_km_s2 + k2.dvel_km_s2 * 2.0 + k3.dvel_km_s2 * 2.0 + k4.dvel_km_s2) * (h / 6.0);
111
112 Ok(CartesianState {
113 epoch_tdb_seconds: state.epoch_tdb_seconds + h,
114 position_km: state.position_km + dpos,
115 velocity_km_s: state.velocity_km_s + dvel,
116 })
117 }
118
119 fn advance(&self, state: &CartesianState, deriv: &StateDerivative, h: f64) -> CartesianState {
120 CartesianState {
121 epoch_tdb_seconds: state.epoch_tdb_seconds + h,
122 position_km: state.position_km + deriv.dpos_km_s * h,
123 velocity_km_s: state.velocity_km_s + deriv.dvel_km_s2 * h,
124 }
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::Vector3;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Constant-velocity dynamics (zero acceleration) that counts every
    /// derivative evaluation.
    struct CountingDynamics<'a> {
        calls: &'a AtomicUsize,
    }

    impl DynamicsModel for CountingDynamics<'_> {
        fn derivative(
            &self,
            state: &CartesianState,
            _ctx: &PropagationContext,
        ) -> Result<StateDerivative, PropagationError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            let dpos_km_s = state.velocity_km_s;
            let dvel_km_s2 = Vector3::zeros();
            Ok(StateDerivative {
                dpos_km_s,
                dvel_km_s2,
            })
        }
    }

    /// Dynamics whose acceleration is +∞ along x. Used to force a non-finite result.
    struct InfiniteAcceleration;

    impl DynamicsModel for InfiniteAcceleration {
        fn derivative(
            &self,
            state: &CartesianState,
            _ctx: &PropagationContext,
        ) -> Result<StateDerivative, PropagationError> {
            let dvel_km_s2 = Vector3::new(f64::INFINITY, 0.0, 0.0);
            Ok(StateDerivative {
                dpos_km_s: state.velocity_km_s,
                dvel_km_s2,
            })
        }
    }

    /// Starting state at epoch 0: 7000 km along +x, moving at 7.5 km/s along +y.
    fn initial_state() -> CartesianState {
        CartesianState {
            epoch_tdb_seconds: 0.0,
            position_km: Vector3::new(7000.0, 0.0, 0.0),
            velocity_km_s: Vector3::new(0.0, 7.5, 0.0),
        }
    }

    #[test]
    fn rejects_non_finite_epochs_before_derivative_evaluation() {
        let base = initial_state();
        let cases = [
            (
                CartesianState {
                    epoch_tdb_seconds: f64::NAN,
                    ..base
                },
                60.0,
                "initial.epoch_tdb_seconds",
            ),
            (
                CartesianState {
                    epoch_tdb_seconds: f64::INFINITY,
                    ..base
                },
                60.0,
                "initial.epoch_tdb_seconds",
            ),
            (base, f64::NAN, "t_end_seconds"),
            (base, f64::INFINITY, "t_end_seconds"),
        ];
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions::default();

        for (initial, t_end_seconds, field) in cases {
            let counter = AtomicUsize::new(0);
            let dynamics = CountingDynamics { calls: &counter };

            let outcome = RK4.propagate(initial, t_end_seconds, &dynamics, &ctx, &opts);
            let error = outcome.expect_err("non-finite RK4 epoch must fail validation");

            assert_invalid_input(error, field, "not finite");
            assert_eq!(
                counter.load(Ordering::SeqCst),
                0,
                "non-finite {field} must be rejected before integration starts"
            );
        }
    }

    #[test]
    fn finite_epochs_integrate_as_before() {
        let counter = AtomicUsize::new(0);
        let dynamics = CountingDynamics { calls: &counter };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions {
            initial_step: 10.0,
            ..IntegratorOptions::default()
        };

        let result = RK4
            .propagate(initial_state(), 60.0, &dynamics, &ctx, &opts)
            .expect("finite RK4 epochs must remain valid");
        let end = result.final_state;

        assert_eq!(end.epoch_tdb_seconds, 60.0);
        // Constant-velocity motion must be reproduced bit-for-bit.
        for (actual, expected) in [
            (end.position_km.x, 7000.0f64),
            (end.position_km.y, 450.0),
            (end.velocity_km_s.y, 7.5),
        ] {
            assert_eq!(actual.to_bits(), expected.to_bits());
        }
        assert_eq!(result.stats.accepted_steps, 6);
        assert_eq!(counter.load(Ordering::SeqCst), 24);
    }

    #[test]
    fn rejects_non_finite_outputs() {
        let opts = IntegratorOptions {
            initial_step: 1.0,
            ..IntegratorOptions::default()
        };

        let error = RK4
            .propagate(
                initial_state(),
                1.0,
                &InfiniteAcceleration,
                &PropagationContext::default(),
                &opts,
            )
            .expect_err("non-finite RK4 result must be rejected");

        assert_numerical_failure(error, "final_state.position_km", "not finite");
    }

    /// Asserts `error` is `InvalidInput` and its message mentions both `field` and `reason`.
    fn assert_invalid_input(error: PropagationError, field: &str, reason: &str) {
        match &error {
            PropagationError::InvalidInput(message) => {
                assert!(message.contains(field), "{message}");
                assert!(message.contains(reason), "{message}");
            }
            _ => panic!("expected invalid propagation input for {field}, got {error:?}"),
        }
    }

    /// Asserts `error` is `NumericalFailure` and its message mentions both `field` and `reason`.
    fn assert_numerical_failure(error: PropagationError, field: &str, reason: &str) {
        match &error {
            PropagationError::NumericalFailure(message) => {
                assert!(message.contains(field), "{message}");
                assert!(message.contains(reason), "{message}");
            }
            _ => panic!("expected numerical failure for {field}, got {error:?}"),
        }
    }
}