Skip to main content

sidereon_core/astro/integrators/
dp54.rs

1use crate::astro::error::PropagationError;
2use crate::astro::integrators::tableau::DP54Tableau;
3use crate::astro::integrators::{DynamicsModel, Integrator};
4use crate::astro::propagator::api::{
5    validate_adaptive_integrator_options, validate_integrator_epoch, IntegratorOptions,
6    PropagationContext,
7};
8use crate::astro::propagator::controller::PIController;
9use crate::astro::propagator::dense_output::{DenseOutput, DenseSegment};
10use crate::astro::propagator::result::{
11    validate_propagation_result, PropagationPoint, PropagationResult, PropagationStats,
12};
13use crate::astro::state::{CartesianState, StateDerivative};
14use nalgebra::Vector3;
15
/// Adaptive Dormand–Prince 5(4) integrator.
///
/// The type carries no configuration of its own. Tolerances, step limits, and the dense-output
/// flag are read from the [`IntegratorOptions`] passed to [`Integrator::propagate`].
#[derive(Debug, Clone, Copy, Default)]
pub struct DP54;
17
18impl Integrator for DP54 {
19    fn propagate(
20        &self,
21        initial: CartesianState,
22        t_end_seconds: f64,
23        rhs: &dyn DynamicsModel,
24        ctx: &PropagationContext,
25        opts: &IntegratorOptions,
26    ) -> Result<PropagationResult, PropagationError> {
27        validate_adaptive_integrator_options(opts)?;
28        validate_integrator_epoch(initial.epoch_tdb_seconds, "initial.epoch_tdb_seconds")?;
29        validate_integrator_epoch(t_end_seconds, "t_end_seconds")?;
30
31        let dt_target = t_end_seconds - initial.epoch_tdb_seconds;
32        let target_abs = dt_target.abs();
33        if target_abs == 0.0 {
34            let point = PropagationPoint {
35                epoch_tdb_seconds: initial.epoch_tdb_seconds,
36                position_km: initial.position_array(),
37                velocity_km_s: initial.velocity_array(),
38            };
39            let mut points = vec![point.clone()];
40            if !opts.dense_output {
41                points.push(point);
42            }
43            let dense = if opts.dense_output {
44                Some(DenseOutput {
45                    segments: Vec::new(),
46                })
47            } else {
48                None
49            };
50
51            return validate_propagation_result(PropagationResult {
52                final_state: initial,
53                points,
54                events: Vec::new(),
55                stats: PropagationStats {
56                    accepted_steps: 0,
57                    rejected_steps: 0,
58                    evaluations: 0,
59                },
60                dense,
61            });
62        }
63
64        let tableau = DP54Tableau::default();
65        let controller = PIController {
66            order: 5.0,
67            ..PIController::default()
68        };
69
70        let mut state = initial;
71        let mut t = initial.epoch_tdb_seconds;
72        let sign = dt_target.signum();
73
74        let mut h = crate::validate::clamp_magnitude(
75            opts.initial_step.min(target_abs) * sign,
76            opts.max_step,
77        );
78        let mut steps_accepted = 0;
79        let mut steps_rejected = 0;
80        let mut evals = 0;
81        let mut points = Vec::new();
82        let mut dense_segments = Vec::new();
83
84        points.push(PropagationPoint {
85            epoch_tdb_seconds: t,
86            position_km: state.position_array(),
87            velocity_km_s: state.velocity_array(),
88        });
89
90        // FSAL: k1
91        let mut k1 = rhs.derivative(&state, ctx)?;
92        evals += 1;
93
94        while (t - initial.epoch_tdb_seconds).abs() < target_abs {
95            if steps_accepted + steps_rejected >= opts.max_steps {
96                return Err(PropagationError::MaxStepsExceeded);
97            }
98
99            let mut h_step = h;
100            if (t + h_step - initial.epoch_tdb_seconds).abs() > target_abs {
101                h_step = t_end_seconds - t;
102            }
103
104            // Step using DP54
105            let step_ctx = DP54StepContext {
106                rhs,
107                ctx,
108                tableau: &tableau,
109                capture_stages: opts.dense_output,
110            };
111            let step_res = self.step(state, h_step, k1, &step_ctx)?;
112
113            // Error estimation
114            let r_scale = opts.abs_tol
115                + state
116                    .position_km
117                    .norm()
118                    .max(step_res.next_state.position_km.norm())
119                    * opts.rel_tol;
120            let v_scale = opts.abs_tol
121                + state
122                    .velocity_km_s
123                    .norm()
124                    .max(step_res.next_state.velocity_km_s.norm())
125                    * opts.rel_tol;
126
127            let err_r = step_res.r_err.norm() / r_scale;
128            let err_v = step_res.v_err.norm() / v_scale;
129            let err = err_r.max(err_v);
130
131            if err <= 1.0 {
132                // Accepted
133                if opts.dense_output {
134                    if let Some(stages) = step_res.stages {
135                        let ks_array: [StateDerivative; 7] = stages.try_into().map_err(|_| {
136                            PropagationError::NumericalFailure(
137                                "Failed to capture RK stages".to_string(),
138                            )
139                        })?;
140                        dense_segments.push(DenseSegment::from_dp54_stages(
141                            t,
142                            h_step,
143                            state,
144                            step_res.next_state,
145                            &ks_array,
146                        ));
147                    }
148                }
149
150                state = step_res.next_state;
151                t += h_step;
152                k1 = step_res.k_fsal; // FSAL
153                steps_accepted += 1;
154                evals += step_res.evals;
155
156                if opts.dense_output {
157                    points.push(PropagationPoint {
158                        epoch_tdb_seconds: t,
159                        position_km: state.position_array(),
160                        velocity_km_s: state.velocity_array(),
161                    });
162                }
163
164                h = crate::validate::clamp_magnitude(
165                    controller.next_step(h_step, err),
166                    opts.max_step,
167                );
168            } else {
169                steps_rejected += 1;
170                evals += step_res.evals;
171                h = crate::validate::clamp_magnitude(
172                    controller.next_step(h_step, err),
173                    opts.max_step,
174                );
175
176                if h.abs() < opts.min_step {
177                    return Err(PropagationError::NumericalFailure(
178                        "Step size too small".to_string(),
179                    ));
180                }
181            }
182        }
183
184        if !opts.dense_output {
185            points.push(PropagationPoint {
186                epoch_tdb_seconds: t,
187                position_km: state.position_array(),
188                velocity_km_s: state.velocity_array(),
189            });
190        }
191
192        let dense = if opts.dense_output {
193            Some(DenseOutput {
194                segments: dense_segments,
195            })
196        } else {
197            None
198        };
199
200        validate_propagation_result(PropagationResult {
201            final_state: state,
202            points,
203            events: Vec::new(),
204            stats: PropagationStats {
205                accepted_steps: steps_accepted,
206                rejected_steps: steps_rejected,
207                evaluations: evals,
208            },
209            dense,
210        })
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Uniform-motion dynamics (zero acceleration) that counts derivative evaluations.
    struct CountingDynamics<'a> {
        calls: &'a AtomicUsize,
    }

    impl DynamicsModel for CountingDynamics<'_> {
        fn derivative(
            &self,
            state: &CartesianState,
            _ctx: &PropagationContext,
        ) -> Result<StateDerivative, PropagationError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(StateDerivative {
                dpos_km_s: state.velocity_km_s,
                dvel_km_s2: Vector3::zeros(),
            })
        }
    }

    /// Unit harmonic oscillator (`a = -r`) that counts derivative evaluations.
    struct CountingOscillator<'a> {
        calls: &'a AtomicUsize,
    }

    impl DynamicsModel for CountingOscillator<'_> {
        fn derivative(
            &self,
            state: &CartesianState,
            _ctx: &PropagationContext,
        ) -> Result<StateDerivative, PropagationError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(StateDerivative {
                dpos_km_s: state.velocity_km_s,
                dvel_km_s2: -state.position_km,
            })
        }
    }

    fn initial_state() -> CartesianState {
        CartesianState {
            epoch_tdb_seconds: 0.0,
            position_km: Vector3::new(7000.0, 0.0, 0.0),
            velocity_km_s: Vector3::new(0.0, 7.5, 0.0),
        }
    }

    #[test]
    fn rejects_invalid_tolerances_before_derivative_evaluation() {
        let cases = [
            ("abs_tol", "not positive", -1.0, 1.0e-12),
            ("abs_tol", "not positive", 0.0, 1.0e-12),
            ("abs_tol", "not finite", f64::NAN, 1.0e-12),
            ("rel_tol", "not positive", 1.0e-9, -1.0),
            ("rel_tol", "not positive", 1.0e-9, 0.0),
            ("rel_tol", "not finite", 1.0e-9, f64::NAN),
        ];

        for (field, reason, abs_tol, rel_tol) in cases {
            let calls = AtomicUsize::new(0);
            let dynamics = CountingDynamics { calls: &calls };
            let ctx = PropagationContext::default();
            let opts = IntegratorOptions {
                abs_tol,
                rel_tol,
                ..IntegratorOptions::default()
            };

            let error = DP54
                .propagate(initial_state(), 60.0, &dynamics, &ctx, &opts)
                .expect_err("invalid DP54 tolerance must fail validation");

            assert_invalid_input(error, field, reason);
            assert_eq!(
                calls.load(Ordering::SeqCst),
                0,
                "invalid {field} must be rejected before integration starts"
            );
        }
    }

    #[test]
    fn rejects_non_finite_epochs_before_derivative_evaluation() {
        let base = initial_state();
        let mut nan_initial = base;
        nan_initial.epoch_tdb_seconds = f64::NAN;
        let mut infinite_initial = base;
        infinite_initial.epoch_tdb_seconds = f64::INFINITY;
        let cases = [
            (nan_initial, 60.0, "initial.epoch_tdb_seconds"),
            (infinite_initial, 60.0, "initial.epoch_tdb_seconds"),
            (base, f64::NAN, "t_end_seconds"),
            (base, f64::INFINITY, "t_end_seconds"),
        ];

        for (initial, t_end_seconds, field) in cases {
            let calls = AtomicUsize::new(0);
            let dynamics = CountingDynamics { calls: &calls };
            let ctx = PropagationContext::default();
            let opts = IntegratorOptions::default();

            let error = DP54
                .propagate(initial, t_end_seconds, &dynamics, &ctx, &opts)
                .expect_err("non-finite DP54 epoch must fail validation");

            assert_invalid_input(error, field, "not finite");
            assert_eq!(
                calls.load(Ordering::SeqCst),
                0,
                "non-finite {field} must be rejected before integration starts"
            );
        }
    }

    #[test]
    fn accepts_positive_tolerances() {
        let calls = AtomicUsize::new(0);
        let dynamics = CountingDynamics { calls: &calls };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions {
            abs_tol: 1.0e-9,
            rel_tol: 1.0e-12,
            initial_step: 10.0,
            ..IntegratorOptions::default()
        };

        let result = DP54
            .propagate(initial_state(), 60.0, &dynamics, &ctx, &opts)
            .expect("positive DP54 tolerances must remain valid");

        assert_eq!(result.final_state.epoch_tdb_seconds, 60.0);
        assert!(calls.load(Ordering::SeqCst) > 0);
    }

    #[test]
    fn propagates_backward_in_time() {
        let calls = AtomicUsize::new(0);
        let dynamics = CountingDynamics { calls: &calls };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions {
            abs_tol: 1.0e-9,
            rel_tol: 1.0e-12,
            initial_step: 10.0,
            ..IntegratorOptions::default()
        };
        let initial = initial_state();

        let result = DP54
            .propagate(initial, -60.0, &dynamics, &ctx, &opts)
            .expect("backward DP54 propagation should succeed");

        assert_eq!(result.final_state.epoch_tdb_seconds, -60.0);
        // Uniform motion: the along-track coordinate moves by v * dt with dt = -60 s.
        let expected_y = initial.velocity_km_s.y * -60.0;
        assert!((result.final_state.position_km.y - expected_y).abs() < 1.0e-6);
        assert!((result.final_state.position_km.x - 7000.0).abs() < 1.0e-9);
        // Without dense output only the start and end states are reported.
        assert_eq!(result.points.len(), 2);
        assert!(calls.load(Ordering::SeqCst) > 0);
    }

    #[test]
    fn dense_output_records_every_accepted_step() {
        let calls = AtomicUsize::new(0);
        let dynamics = CountingOscillator { calls: &calls };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions {
            abs_tol: 1.0e-9,
            rel_tol: 1.0e-9,
            initial_step: 0.1,
            dense_output: true,
            ..IntegratorOptions::default()
        };
        let initial = CartesianState {
            epoch_tdb_seconds: 0.0,
            position_km: Vector3::new(1.0, 0.0, 0.0),
            velocity_km_s: Vector3::new(0.0, 1.0, 0.0),
        };

        let result = DP54
            .propagate(initial, 1.0, &dynamics, &ctx, &opts)
            .expect("dense oscillator propagation should succeed");

        let accepted = result.stats.accepted_steps as usize;
        assert!(accepted > 0);
        // The initial state plus one point per accepted step.
        assert_eq!(result.points.len(), accepted + 1);
        let dense = result.dense.as_ref().expect("dense output was requested");
        assert_eq!(dense.segments.len(), accepted);
        let last = result.points.last().expect("points always include the initial state");
        assert_eq!(last.epoch_tdb_seconds, result.final_state.epoch_tdb_seconds);
    }

    #[test]
    fn zero_duration_returns_initial_state_without_derivative_evaluation() {
        let calls = AtomicUsize::new(0);
        let dynamics = CountingDynamics { calls: &calls };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions::default();
        let initial = initial_state();

        let result = DP54
            .propagate(initial, initial.epoch_tdb_seconds, &dynamics, &ctx, &opts)
            .expect("zero-duration propagation should return the initial state");

        assert_eq!(result.final_state, initial);
        assert_eq!(result.stats.accepted_steps, 0);
        assert_eq!(result.stats.rejected_steps, 0);
        assert_eq!(result.stats.evaluations, 0);
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn zero_duration_rejects_non_finite_initial_state_output() {
        let calls = AtomicUsize::new(0);
        let dynamics = CountingDynamics { calls: &calls };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions::default();
        let mut initial = initial_state();
        initial.position_km.x = f64::INFINITY;

        let error = DP54
            .propagate(initial, initial.epoch_tdb_seconds, &dynamics, &ctx, &opts)
            .expect_err("zero-duration non-finite output must be rejected");

        assert_numerical_failure(error, "final_state.position_km", "not finite");
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn rejected_steps_count_every_derivative_evaluation() {
        let calls = AtomicUsize::new(0);
        let dynamics = CountingOscillator { calls: &calls };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions {
            abs_tol: 1.0e-12,
            rel_tol: 1.0e-12,
            initial_step: 1.0,
            max_step: 1.0,
            min_step: 1.0e-15,
            ..IntegratorOptions::default()
        };
        let initial = CartesianState {
            epoch_tdb_seconds: 0.0,
            position_km: Vector3::new(1.0, 0.0, 0.0),
            velocity_km_s: Vector3::new(0.0, 1.0, 0.0),
        };

        let result = DP54
            .propagate(initial, 1.0, &dynamics, &ctx, &opts)
            .expect("tight oscillator propagation should recover after rejected steps");

        assert!(
            result.stats.rejected_steps > 0,
            "test setup must force at least one rejected step"
        );
        let observed_calls = u32::try_from(calls.load(Ordering::SeqCst))
            .expect("derivative call count fits in u32");
        assert_eq!(result.stats.evaluations, observed_calls);
        // One initial k1, then six evaluations per trial step, accepted or rejected.
        assert_eq!(
            result.stats.evaluations,
            1 + 6 * (result.stats.accepted_steps + result.stats.rejected_steps)
        );
    }

    /// Asserts `error` is `InvalidInput` whose message names `field` and `reason`.
    fn assert_invalid_input(error: PropagationError, field: &str, reason: &str) {
        match error {
            PropagationError::InvalidInput(message) => {
                assert!(message.contains(field), "{message}");
                assert!(message.contains(reason), "{message}");
            }
            other => panic!("expected invalid propagation input for {field}, got {other:?}"),
        }
    }

    /// Asserts `error` is `NumericalFailure` whose message names `field` and `reason`.
    fn assert_numerical_failure(error: PropagationError, field: &str, reason: &str) {
        match error {
            PropagationError::NumericalFailure(message) => {
                assert!(message.contains(field), "{message}");
                assert!(message.contains(reason), "{message}");
            }
            other => panic!("expected numerical failure for {field}, got {other:?}"),
        }
    }
}
443
444struct DP54Step {
445    next_state: CartesianState,
446    k_fsal: StateDerivative,
447    r_err: Vector3<f64>,
448    v_err: Vector3<f64>,
449    evals: u32,
450    stages: Option<Vec<StateDerivative>>,
451}
452
453/// Per-step invariants shared across every Dormand-Prince stage evaluation:
454/// the dynamics model, propagation context, Butcher tableau, and whether to
455/// retain the intermediate stages for dense output.
456#[derive(Clone, Copy)]
457struct DP54StepContext<'a> {
458    rhs: &'a dyn DynamicsModel,
459    ctx: &'a PropagationContext,
460    tableau: &'a DP54Tableau,
461    capture_stages: bool,
462}
463
464impl DP54 {
465    fn step(
466        &self,
467        state: CartesianState,
468        h: f64,
469        k1: StateDerivative,
470        step_ctx: &DP54StepContext,
471    ) -> Result<DP54Step, PropagationError> {
472        let DP54StepContext {
473            rhs,
474            ctx,
475            tableau,
476            capture_stages,
477        } = *step_ctx;
478        let mut ks = Vec::with_capacity(7);
479        ks.push(k1);
480
481        for i in 1..6 {
482            let mut dpos = Vector3::zeros();
483            let mut dvel = Vector3::zeros();
484            for (j, k) in ks.iter().enumerate().take(i) {
485                dpos += k.dpos_km_s * tableau.a[i][j];
486                dvel += k.dvel_km_s2 * tableau.a[i][j];
487            }
488
489            let stage_state = CartesianState {
490                epoch_tdb_seconds: state.epoch_tdb_seconds + h * tableau.c[i],
491                position_km: state.position_km + dpos * h,
492                velocity_km_s: state.velocity_km_s + dvel * h,
493            };
494            ks.push(rhs.derivative(&stage_state, ctx)?);
495        }
496
497        // 5th order solution
498        let mut dpos5 = Vector3::zeros();
499        let mut dvel5 = Vector3::zeros();
500        for (i, k) in ks.iter().enumerate().take(6) {
501            dpos5 += k.dpos_km_s * tableau.b5[i];
502            dvel5 += k.dvel_km_s2 * tableau.b5[i];
503        }
504
505        let next_state = CartesianState {
506            epoch_tdb_seconds: state.epoch_tdb_seconds + h,
507            position_km: state.position_km + dpos5 * h,
508            velocity_km_s: state.velocity_km_s + dvel5 * h,
509        };
510
511        // FSAL
512        let k_fsal = rhs.derivative(&next_state, ctx)?;
513        ks.push(k_fsal);
514
515        // 4th order for error estimate
516        let mut dpos4 = Vector3::zeros();
517        let mut dvel4 = Vector3::zeros();
518        for (i, k) in ks.iter().enumerate().take(7) {
519            dpos4 += k.dpos_km_s * tableau.b4[i];
520            dvel4 += k.dvel_km_s2 * tableau.b4[i];
521        }
522
523        let r_err = (dpos5 - dpos4) * h;
524        let v_err = (dvel5 - dvel4) * h;
525
526        let stages = if capture_stages { Some(ks) } else { None };
527
528        Ok(DP54Step {
529            next_state,
530            k_fsal,
531            r_err,
532            v_err,
533            evals: 6,
534            stages,
535        })
536    }
537}