scirs2_integrate/ode/utils/step_control.rs
1//! Step size control algorithms for adaptive methods.
2//!
3//! This module provides utility functions for controlling step sizes
4//! in adaptive ODE solvers.
5
6use crate::IntegrateFloat;
7use scirs2_core::ndarray::{Array1, ArrayView1};
8
9/// Calculate the error norm based on relative and absolute tolerances
10///
11/// # Arguments
12///
13/// * `error` - Error estimate from the ODE step
14/// * `y` - Solution at the current step
15/// * `rtol` - Relative tolerance
16/// * `atol` - Absolute tolerance
17///
18/// # Returns
19///
20/// The normalized error
21#[allow(dead_code)]
22pub fn error_norm<F: IntegrateFloat>(error: &Array1<F>, y: &Array1<F>, rtol: F, atol: F) -> F {
23 // Calculate the denominator for normalization
24 let scale = y
25 .iter()
26 .zip(error.iter())
27 .map(|(y_i_, _)| rtol * y_i_.abs() + atol)
28 .collect::<Array1<F>>();
29
30 // Calculate RMS of scaled error
31 let mut sum_sq = F::zero();
32 for (e, s) in error.iter().zip(scale.iter()) {
33 sum_sq += (*e / *s).powi(2);
34 }
35
36 let n = F::from_usize(error.len()).unwrap();
37 (sum_sq / n).sqrt()
38}
39
40/// Calculate a new step size based on error estimate
41///
42/// # Arguments
43///
44/// * `h_current` - Current step size
45/// * `error_norm` - Current error norm
46/// * `error_order` - Order of the error estimator
47/// * `safety` - Safety factor (typically 0.8-0.9)
48///
49/// # Returns
50///
51/// The suggested new step size
52#[allow(dead_code)]
53pub fn calculate_new_step_size<F: IntegrateFloat>(
54 h_current: F,
55 error_norm: F,
56 error_order: usize,
57 safety: F,
58) -> F {
59 // If error is zero, increase step size significantly but safely
60 if error_norm == F::zero() {
61 return h_current * F::from_f64(10.0).unwrap();
62 }
63
64 // Standard step size calculation based on error estimate
65 let _order = F::from_usize(error_order).unwrap();
66 let error_ratio = F::one() / error_norm;
67
68 // Calculate factor using the formula: safety * error_ratio^(1/_order)
69 let factor = safety * error_ratio.powf(F::one() / _order);
70
71 // Limit factor to reasonable bounds to prevent too large or small step sizes
72 let factor_max = F::from_f64(10.0).unwrap();
73 let factor_min = F::from_f64(0.1).unwrap();
74
75 let factor = if factor > factor_max {
76 factor_max
77 } else if factor < factor_min {
78 factor_min
79 } else {
80 factor
81 };
82
83 // Apply factor to _current step size
84 h_current * factor
85}
86
87/// Select an initial step size for ODE solving
88///
89/// # Arguments
90///
91/// * `f` - ODE function
92/// * `t` - Initial time
93/// * `y` - Initial state
94/// * `direction` - Direction of integration (1.0 for forward, -1.0 for backward)
95/// * `rtol` - Relative tolerance
96/// * `atol` - Absolute tolerance
97///
98/// # Returns
99///
100/// Suggested initial step size
101#[allow(dead_code)]
102pub fn select_initial_step<F, Func>(
103 f: &Func,
104 t: F,
105 y: &Array1<F>,
106 direction: F,
107 rtol: F,
108 atol: F,
109) -> F
110where
111 F: IntegrateFloat,
112 Func: Fn(F, ArrayView1<F>) -> Array1<F>,
113{
114 // Calculate scale based on tolerances
115 let scale = y
116 .iter()
117 .map(|y_i| rtol * y_i.abs() + atol)
118 .collect::<Array1<F>>();
119
120 // Initial derivatives
121 let f0 = f(t, y.view());
122
123 // Estimate using the derivatives
124 let d0 = f0
125 .iter()
126 .zip(scale.iter())
127 .map(|(f, s)| *f / *s)
128 .fold(F::zero(), |acc, x| acc + x * x);
129
130 let d0 = d0.sqrt() / F::from_f64(y.len() as f64).unwrap().sqrt();
131
132 let step_size = if d0 < F::from_f64(1.0e-5).unwrap() {
133 // If derivatives are very small, use a default small step
134 F::from_f64(1.0e-6).unwrap()
135 } else {
136 // Otherwise, use a step size based on the derivatives
137 F::from_f64(0.01).unwrap() / d0
138 };
139
140 // Return step size with the correct sign
141 step_size * direction.signum()
142}