1use crate::astro::error::PropagationError;
2use crate::astro::integrators::tableau::DP54Tableau;
3use crate::astro::integrators::{DynamicsModel, Integrator};
4use crate::astro::propagator::api::{
5 validate_adaptive_integrator_options, validate_integrator_epoch, IntegratorOptions,
6 PropagationContext,
7};
8use crate::astro::propagator::controller::PIController;
9use crate::astro::propagator::dense_output::{DenseOutput, DenseSegment};
10use crate::astro::propagator::result::{
11 validate_propagation_result, PropagationPoint, PropagationResult, PropagationStats,
12};
13use crate::astro::state::{CartesianState, StateDerivative};
14use nalgebra::Vector3;
15
/// Adaptive Dormand–Prince 5(4) integrator.
///
/// Stateless marker type: all tunables (tolerances, step bounds, dense output) are
/// supplied per call through `IntegratorOptions`, and the Butcher coefficients come
/// from `DP54Tableau`.
pub struct DP54;
17
18impl Integrator for DP54 {
19 fn propagate(
20 &self,
21 initial: CartesianState,
22 t_end_seconds: f64,
23 rhs: &dyn DynamicsModel,
24 ctx: &PropagationContext,
25 opts: &IntegratorOptions,
26 ) -> Result<PropagationResult, PropagationError> {
27 validate_adaptive_integrator_options(opts)?;
28 validate_integrator_epoch(initial.epoch_tdb_seconds, "initial.epoch_tdb_seconds")?;
29 validate_integrator_epoch(t_end_seconds, "t_end_seconds")?;
30
31 let dt_target = t_end_seconds - initial.epoch_tdb_seconds;
32 let target_abs = dt_target.abs();
33 if target_abs == 0.0 {
34 let point = PropagationPoint {
35 epoch_tdb_seconds: initial.epoch_tdb_seconds,
36 position_km: initial.position_array(),
37 velocity_km_s: initial.velocity_array(),
38 };
39 let mut points = vec![point.clone()];
40 if !opts.dense_output {
41 points.push(point);
42 }
43 let dense = if opts.dense_output {
44 Some(DenseOutput {
45 segments: Vec::new(),
46 })
47 } else {
48 None
49 };
50
51 return validate_propagation_result(PropagationResult {
52 final_state: initial,
53 points,
54 events: Vec::new(),
55 stats: PropagationStats {
56 accepted_steps: 0,
57 rejected_steps: 0,
58 evaluations: 0,
59 },
60 dense,
61 });
62 }
63
64 let tableau = DP54Tableau::default();
65 let controller = PIController {
66 order: 5.0,
67 ..PIController::default()
68 };
69
70 let mut state = initial;
71 let mut t = initial.epoch_tdb_seconds;
72 let sign = dt_target.signum();
73
74 let mut h = crate::validate::clamp_magnitude(
75 opts.initial_step.min(target_abs) * sign,
76 opts.max_step,
77 );
78 let mut steps_accepted = 0;
79 let mut steps_rejected = 0;
80 let mut evals = 0;
81 let mut points = Vec::new();
82 let mut dense_segments = Vec::new();
83
84 points.push(PropagationPoint {
85 epoch_tdb_seconds: t,
86 position_km: state.position_array(),
87 velocity_km_s: state.velocity_array(),
88 });
89
90 let mut k1 = rhs.derivative(&state, ctx)?;
92 evals += 1;
93
94 while (t - initial.epoch_tdb_seconds).abs() < target_abs {
95 if steps_accepted + steps_rejected >= opts.max_steps {
96 return Err(PropagationError::MaxStepsExceeded);
97 }
98
99 let mut h_step = h;
100 if (t + h_step - initial.epoch_tdb_seconds).abs() > target_abs {
101 h_step = t_end_seconds - t;
102 }
103
104 let step_ctx = DP54StepContext {
106 rhs,
107 ctx,
108 tableau: &tableau,
109 capture_stages: opts.dense_output,
110 };
111 let step_res = self.step(state, h_step, k1, &step_ctx)?;
112
113 let r_scale = opts.abs_tol
115 + state
116 .position_km
117 .norm()
118 .max(step_res.next_state.position_km.norm())
119 * opts.rel_tol;
120 let v_scale = opts.abs_tol
121 + state
122 .velocity_km_s
123 .norm()
124 .max(step_res.next_state.velocity_km_s.norm())
125 * opts.rel_tol;
126
127 let err_r = step_res.r_err.norm() / r_scale;
128 let err_v = step_res.v_err.norm() / v_scale;
129 let err = err_r.max(err_v);
130
131 if err <= 1.0 {
132 if opts.dense_output {
134 if let Some(stages) = step_res.stages {
135 let ks_array: [StateDerivative; 7] = stages.try_into().map_err(|_| {
136 PropagationError::NumericalFailure(
137 "Failed to capture RK stages".to_string(),
138 )
139 })?;
140 dense_segments.push(DenseSegment::from_dp54_stages(
141 t,
142 h_step,
143 state,
144 step_res.next_state,
145 &ks_array,
146 ));
147 }
148 }
149
150 state = step_res.next_state;
151 t += h_step;
152 k1 = step_res.k_fsal; steps_accepted += 1;
154 evals += step_res.evals;
155
156 if opts.dense_output {
157 points.push(PropagationPoint {
158 epoch_tdb_seconds: t,
159 position_km: state.position_array(),
160 velocity_km_s: state.velocity_array(),
161 });
162 }
163
164 h = crate::validate::clamp_magnitude(
165 controller.next_step(h_step, err),
166 opts.max_step,
167 );
168 } else {
169 steps_rejected += 1;
170 evals += step_res.evals;
171 h = crate::validate::clamp_magnitude(
172 controller.next_step(h_step, err),
173 opts.max_step,
174 );
175
176 if h.abs() < opts.min_step {
177 return Err(PropagationError::NumericalFailure(
178 "Step size too small".to_string(),
179 ));
180 }
181 }
182 }
183
184 if !opts.dense_output {
185 points.push(PropagationPoint {
186 epoch_tdb_seconds: t,
187 position_km: state.position_array(),
188 velocity_km_s: state.velocity_array(),
189 });
190 }
191
192 let dense = if opts.dense_output {
193 Some(DenseOutput {
194 segments: dense_segments,
195 })
196 } else {
197 None
198 };
199
200 validate_propagation_result(PropagationResult {
201 final_state: state,
202 points,
203 events: Vec::new(),
204 stats: PropagationStats {
205 accepted_steps: steps_accepted,
206 rejected_steps: steps_rejected,
207 evaluations: evals,
208 },
209 dense,
210 })
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Free motion (zero acceleration) that counts derivative evaluations.
    struct CountingDynamics<'a> {
        calls: &'a AtomicUsize,
    }

    impl DynamicsModel for CountingDynamics<'_> {
        fn derivative(
            &self,
            state: &CartesianState,
            _ctx: &PropagationContext,
        ) -> Result<StateDerivative, PropagationError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(StateDerivative {
                dpos_km_s: state.velocity_km_s,
                dvel_km_s2: Vector3::zeros(),
            })
        }
    }

    /// Unit harmonic oscillator (`a = -r`) that counts derivative evaluations.
    struct CountingOscillator<'a> {
        calls: &'a AtomicUsize,
    }

    impl DynamicsModel for CountingOscillator<'_> {
        fn derivative(
            &self,
            state: &CartesianState,
            _ctx: &PropagationContext,
        ) -> Result<StateDerivative, PropagationError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(StateDerivative {
                dpos_km_s: state.velocity_km_s,
                dvel_km_s2: -state.position_km,
            })
        }
    }

    fn initial_state() -> CartesianState {
        CartesianState {
            epoch_tdb_seconds: 0.0,
            position_km: Vector3::new(7000.0, 0.0, 0.0),
            velocity_km_s: Vector3::new(0.0, 7.5, 0.0),
        }
    }

    #[test]
    fn rejects_invalid_tolerances_before_derivative_evaluation() {
        let cases = [
            ("abs_tol", "not positive", -1.0, 1.0e-12),
            ("abs_tol", "not positive", 0.0, 1.0e-12),
            ("abs_tol", "not finite", f64::NAN, 1.0e-12),
            ("rel_tol", "not positive", 1.0e-9, -1.0),
            ("rel_tol", "not positive", 1.0e-9, 0.0),
            ("rel_tol", "not finite", 1.0e-9, f64::NAN),
        ];

        for (field, reason, abs_tol, rel_tol) in cases {
            let calls = AtomicUsize::new(0);
            let dynamics = CountingDynamics { calls: &calls };
            let ctx = PropagationContext::default();
            let opts = IntegratorOptions {
                abs_tol,
                rel_tol,
                ..IntegratorOptions::default()
            };

            let error = DP54
                .propagate(initial_state(), 60.0, &dynamics, &ctx, &opts)
                .expect_err("invalid DP54 tolerance must fail validation");

            assert_invalid_input(error, field, reason);
            assert_eq!(
                calls.load(Ordering::SeqCst),
                0,
                "invalid {field} must be rejected before integration starts"
            );
        }
    }

    #[test]
    fn rejects_non_finite_epochs_before_derivative_evaluation() {
        let base = initial_state();
        let mut nan_initial = base;
        nan_initial.epoch_tdb_seconds = f64::NAN;
        let mut infinite_initial = base;
        infinite_initial.epoch_tdb_seconds = f64::INFINITY;
        let cases = [
            (nan_initial, 60.0, "initial.epoch_tdb_seconds"),
            (infinite_initial, 60.0, "initial.epoch_tdb_seconds"),
            (base, f64::NAN, "t_end_seconds"),
            (base, f64::INFINITY, "t_end_seconds"),
        ];

        for (initial, t_end_seconds, field) in cases {
            let calls = AtomicUsize::new(0);
            let dynamics = CountingDynamics { calls: &calls };
            let ctx = PropagationContext::default();
            let opts = IntegratorOptions::default();

            let error = DP54
                .propagate(initial, t_end_seconds, &dynamics, &ctx, &opts)
                .expect_err("non-finite DP54 epoch must fail validation");

            assert_invalid_input(error, field, "not finite");
            assert_eq!(
                calls.load(Ordering::SeqCst),
                0,
                "non-finite {field} must be rejected before integration starts"
            );
        }
    }

    #[test]
    fn accepts_positive_tolerances() {
        let calls = AtomicUsize::new(0);
        let dynamics = CountingDynamics { calls: &calls };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions {
            abs_tol: 1.0e-9,
            rel_tol: 1.0e-12,
            initial_step: 10.0,
            ..IntegratorOptions::default()
        };

        let result = DP54
            .propagate(initial_state(), 60.0, &dynamics, &ctx, &opts)
            .expect("positive DP54 tolerances must remain valid");

        assert_eq!(result.final_state.epoch_tdb_seconds, 60.0);
        assert!(calls.load(Ordering::SeqCst) > 0);
    }

    #[test]
    fn backward_propagation_lands_on_earlier_target() {
        let calls = AtomicUsize::new(0);
        let dynamics = CountingDynamics { calls: &calls };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions {
            initial_step: 10.0,
            ..IntegratorOptions::default()
        };

        let result = DP54
            .propagate(initial_state(), -60.0, &dynamics, &ctx, &opts)
            .expect("backward DP54 propagation should succeed");

        assert_eq!(result.final_state.epoch_tdb_seconds, -60.0);
        // Constant 7.5 km/s along +y, integrated 60 s into the past.
        assert!((result.final_state.position_km.y + 450.0).abs() < 1.0e-9);
        assert_eq!(result.final_state.position_km.x, 7000.0);
        assert_eq!(result.final_state.velocity_km_s, Vector3::new(0.0, 7.5, 0.0));
        let epochs: Vec<f64> = result.points.iter().map(|p| p.epoch_tdb_seconds).collect();
        assert_eq!(epochs, vec![0.0, -60.0]);
        assert!(result.stats.accepted_steps > 0);
    }

    #[test]
    fn dense_output_records_one_point_and_segment_per_accepted_step() {
        let calls = AtomicUsize::new(0);
        let dynamics = CountingOscillator { calls: &calls };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions {
            initial_step: 0.1,
            max_step: 0.25,
            dense_output: true,
            ..IntegratorOptions::default()
        };
        let initial = CartesianState {
            epoch_tdb_seconds: 0.0,
            position_km: Vector3::new(1.0, 0.0, 0.0),
            velocity_km_s: Vector3::new(0.0, 1.0, 0.0),
        };

        let result = DP54
            .propagate(initial, 1.0, &dynamics, &ctx, &opts)
            .expect("dense-output oscillator propagation should succeed");

        let accepted = result.stats.accepted_steps as usize;
        assert!(accepted > 0);
        assert_eq!(result.points.len(), accepted + 1);
        assert_eq!(result.points.first().map(|p| p.epoch_tdb_seconds), Some(0.0));
        assert_eq!(result.points.last().map(|p| p.epoch_tdb_seconds), Some(1.0));
        assert_eq!(result.final_state.epoch_tdb_seconds, 1.0);
        let dense = result.dense.expect("dense output was requested");
        assert_eq!(dense.segments.len(), accepted);
    }

    #[test]
    fn zero_duration_returns_initial_state_without_derivative_evaluation() {
        let calls = AtomicUsize::new(0);
        let dynamics = CountingDynamics { calls: &calls };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions::default();
        let initial = initial_state();

        let result = DP54
            .propagate(initial, initial.epoch_tdb_seconds, &dynamics, &ctx, &opts)
            .expect("zero-duration propagation should return the initial state");

        assert_eq!(result.final_state, initial);
        assert_eq!(result.stats.accepted_steps, 0);
        assert_eq!(result.stats.rejected_steps, 0);
        assert_eq!(result.stats.evaluations, 0);
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn zero_duration_rejects_non_finite_initial_state_output() {
        let calls = AtomicUsize::new(0);
        let dynamics = CountingDynamics { calls: &calls };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions::default();
        let mut initial = initial_state();
        initial.position_km.x = f64::INFINITY;

        let error = DP54
            .propagate(initial, initial.epoch_tdb_seconds, &dynamics, &ctx, &opts)
            .expect_err("zero-duration non-finite output must be rejected");

        assert_numerical_failure(error, "final_state.position_km", "not finite");
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn rejected_steps_count_every_derivative_evaluation() {
        let calls = AtomicUsize::new(0);
        let dynamics = CountingOscillator { calls: &calls };
        let ctx = PropagationContext::default();
        let opts = IntegratorOptions {
            abs_tol: 1.0e-12,
            rel_tol: 1.0e-12,
            initial_step: 1.0,
            max_step: 1.0,
            min_step: 1.0e-15,
            ..IntegratorOptions::default()
        };
        let initial = CartesianState {
            epoch_tdb_seconds: 0.0,
            position_km: Vector3::new(1.0, 0.0, 0.0),
            velocity_km_s: Vector3::new(0.0, 1.0, 0.0),
        };

        let result = DP54
            .propagate(initial, 1.0, &dynamics, &ctx, &opts)
            .expect("tight oscillator propagation should recover after rejected steps");

        assert!(
            result.stats.rejected_steps > 0,
            "test setup must force at least one rejected step"
        );
        assert_eq!(
            result.stats.evaluations,
            calls.load(Ordering::SeqCst) as u32
        );
        assert_eq!(
            result.stats.evaluations,
            1 + 6 * (result.stats.accepted_steps + result.stats.rejected_steps)
        );
    }

    fn assert_invalid_input(error: PropagationError, field: &str, reason: &str) {
        match error {
            PropagationError::InvalidInput(message) => {
                assert!(message.contains(field), "{message}");
                assert!(message.contains(reason), "{message}");
            }
            other => panic!("expected invalid propagation input for {field}, got {other:?}"),
        }
    }

    fn assert_numerical_failure(error: PropagationError, field: &str, reason: &str) {
        match error {
            PropagationError::NumericalFailure(message) => {
                assert!(message.contains(field), "{message}");
                assert!(message.contains(reason), "{message}");
            }
            other => panic!("expected numerical failure for {field}, got {other:?}"),
        }
    }
}
443
444struct DP54Step {
445 next_state: CartesianState,
446 k_fsal: StateDerivative,
447 r_err: Vector3<f64>,
448 v_err: Vector3<f64>,
449 evals: u32,
450 stages: Option<Vec<StateDerivative>>,
451}
452
453#[derive(Clone, Copy)]
457struct DP54StepContext<'a> {
458 rhs: &'a dyn DynamicsModel,
459 ctx: &'a PropagationContext,
460 tableau: &'a DP54Tableau,
461 capture_stages: bool,
462}
463
464impl DP54 {
465 fn step(
466 &self,
467 state: CartesianState,
468 h: f64,
469 k1: StateDerivative,
470 step_ctx: &DP54StepContext,
471 ) -> Result<DP54Step, PropagationError> {
472 let DP54StepContext {
473 rhs,
474 ctx,
475 tableau,
476 capture_stages,
477 } = *step_ctx;
478 let mut ks = Vec::with_capacity(7);
479 ks.push(k1);
480
481 for i in 1..6 {
482 let mut dpos = Vector3::zeros();
483 let mut dvel = Vector3::zeros();
484 for (j, k) in ks.iter().enumerate().take(i) {
485 dpos += k.dpos_km_s * tableau.a[i][j];
486 dvel += k.dvel_km_s2 * tableau.a[i][j];
487 }
488
489 let stage_state = CartesianState {
490 epoch_tdb_seconds: state.epoch_tdb_seconds + h * tableau.c[i],
491 position_km: state.position_km + dpos * h,
492 velocity_km_s: state.velocity_km_s + dvel * h,
493 };
494 ks.push(rhs.derivative(&stage_state, ctx)?);
495 }
496
497 let mut dpos5 = Vector3::zeros();
499 let mut dvel5 = Vector3::zeros();
500 for (i, k) in ks.iter().enumerate().take(6) {
501 dpos5 += k.dpos_km_s * tableau.b5[i];
502 dvel5 += k.dvel_km_s2 * tableau.b5[i];
503 }
504
505 let next_state = CartesianState {
506 epoch_tdb_seconds: state.epoch_tdb_seconds + h,
507 position_km: state.position_km + dpos5 * h,
508 velocity_km_s: state.velocity_km_s + dvel5 * h,
509 };
510
511 let k_fsal = rhs.derivative(&next_state, ctx)?;
513 ks.push(k_fsal);
514
515 let mut dpos4 = Vector3::zeros();
517 let mut dvel4 = Vector3::zeros();
518 for (i, k) in ks.iter().enumerate().take(7) {
519 dpos4 += k.dpos_km_s * tableau.b4[i];
520 dvel4 += k.dvel_km_s2 * tableau.b4[i];
521 }
522
523 let r_err = (dpos5 - dpos4) * h;
524 let v_err = (dvel5 - dvel4) * h;
525
526 let stages = if capture_stages { Some(ks) } else { None };
527
528 Ok(DP54Step {
529 next_state,
530 k_fsal,
531 r_err,
532 v_err,
533 evals: 6,
534 stages,
535 })
536 }
537}