scirs2_integrate/ode/utils/
step_control.rs

//! Step size control algorithms for adaptive methods.
//!
//! This module provides utility functions for controlling step sizes
//! in adaptive ODE solvers.
5
6use crate::IntegrateFloat;
7use scirs2_core::ndarray::{Array1, ArrayView1};
8
9/// Calculate the error norm based on relative and absolute tolerances
10///
11/// # Arguments
12///
13/// * `error` - Error estimate from the ODE step
14/// * `y` - Solution at the current step
15/// * `rtol` - Relative tolerance
16/// * `atol` - Absolute tolerance
17///
18/// # Returns
19///
20/// The normalized error
21#[allow(dead_code)]
22pub fn error_norm<F: IntegrateFloat>(error: &Array1<F>, y: &Array1<F>, rtol: F, atol: F) -> F {
23    // Calculate the denominator for normalization
24    let scale = y
25        .iter()
26        .zip(error.iter())
27        .map(|(y_i_, _)| rtol * y_i_.abs() + atol)
28        .collect::<Array1<F>>();
29
30    // Calculate RMS of scaled error
31    let mut sum_sq = F::zero();
32    for (e, s) in error.iter().zip(scale.iter()) {
33        sum_sq += (*e / *s).powi(2);
34    }
35
36    let n = F::from_usize(error.len()).unwrap();
37    (sum_sq / n).sqrt()
38}
39
40/// Calculate a new step size based on error estimate
41///
42/// # Arguments
43///
44/// * `h_current` - Current step size
45/// * `error_norm` - Current error norm
46/// * `error_order` - Order of the error estimator
47/// * `safety` - Safety factor (typically 0.8-0.9)
48///
49/// # Returns
50///
51/// The suggested new step size
52#[allow(dead_code)]
53pub fn calculate_new_step_size<F: IntegrateFloat>(
54    h_current: F,
55    error_norm: F,
56    error_order: usize,
57    safety: F,
58) -> F {
59    // If error is zero, increase step size significantly but safely
60    if error_norm == F::zero() {
61        return h_current * F::from_f64(10.0).unwrap();
62    }
63
64    // Standard step size calculation based on error estimate
65    let _order = F::from_usize(error_order).unwrap();
66    let error_ratio = F::one() / error_norm;
67
68    // Calculate factor using the formula: safety * error_ratio^(1/_order)
69    let factor = safety * error_ratio.powf(F::one() / _order);
70
71    // Limit factor to reasonable bounds to prevent too large or small step sizes
72    let factor_max = F::from_f64(10.0).unwrap();
73    let factor_min = F::from_f64(0.1).unwrap();
74
75    let factor = if factor > factor_max {
76        factor_max
77    } else if factor < factor_min {
78        factor_min
79    } else {
80        factor
81    };
82
83    // Apply factor to _current step size
84    h_current * factor
85}
86
87/// Select an initial step size for ODE solving
88///
89/// # Arguments
90///
91/// * `f` - ODE function
92/// * `t` - Initial time
93/// * `y` - Initial state
94/// * `direction` - Direction of integration (1.0 for forward, -1.0 for backward)
95/// * `rtol` - Relative tolerance
96/// * `atol` - Absolute tolerance
97///
98/// # Returns
99///
100/// Suggested initial step size
101#[allow(dead_code)]
102pub fn select_initial_step<F, Func>(
103    f: &Func,
104    t: F,
105    y: &Array1<F>,
106    direction: F,
107    rtol: F,
108    atol: F,
109) -> F
110where
111    F: IntegrateFloat,
112    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
113{
114    // Calculate scale based on tolerances
115    let scale = y
116        .iter()
117        .map(|y_i| rtol * y_i.abs() + atol)
118        .collect::<Array1<F>>();
119
120    // Initial derivatives
121    let f0 = f(t, y.view());
122
123    // Estimate using the derivatives
124    let d0 = f0
125        .iter()
126        .zip(scale.iter())
127        .map(|(f, s)| *f / *s)
128        .fold(F::zero(), |acc, x| acc + x * x);
129
130    let d0 = d0.sqrt() / F::from_f64(y.len() as f64).unwrap().sqrt();
131
132    let step_size = if d0 < F::from_f64(1.0e-5).unwrap() {
133        // If derivatives are very small, use a default small step
134        F::from_f64(1.0e-6).unwrap()
135    } else {
136        // Otherwise, use a step size based on the derivatives
137        F::from_f64(0.01).unwrap() / d0
138    };
139
140    // Return step size with the correct sign
141    step_size * direction.signum()
142}