scirs2_integrate/ode/utils/
common.rs

1//! Common utilities for ODE solvers
2//!
3//! This module provides common utilities used by multiple ODE solvers.
4
5use crate::common::IntegrateFloat;
6use crate::error::{IntegrateError, IntegrateResult};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
8
9/// Result of a single integration step
10pub enum StepResult<F: IntegrateFloat> {
11    /// Step accepted with the given solution
12    Accepted(Array1<F>),
13    /// Step rejected
14    Rejected,
15    /// Should switch methods (for LSODA)
16    ShouldSwitch,
17}
18
19/// State information for ODE solvers
20pub struct ODEState<F: IntegrateFloat> {
21    /// Current time
22    pub t: F,
23    /// Current solution
24    pub y: Array1<F>,
25    /// Current derivative
26    pub dy: Array1<F>,
27    /// Current step size
28    pub h: F,
29    /// Function evaluations
30    pub func_evals: usize,
31    /// Steps taken
32    pub steps: usize,
33    /// Accepted steps
34    pub accepted_steps: usize,
35    /// Rejected steps
36    pub rejected_steps: usize,
37}
38
/// Classification of an ODE problem by stiffness.
///
/// This is a plain fieldless enum, so it is cheap to copy and can be compared,
/// hashed and printed for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ODEType {
    /// Non-stiff problem.
    NonStiff,
    /// Stiff problem.
    Stiff,
    /// Problem whose character changes between stiff and non-stiff.
    Mixed,
}
48
49/// Calculate a safe step size based on function derivatives
50#[allow(dead_code)]
51pub fn estimate_initial_step<F, Func>(
52    f: &Func,
53    t: F,
54    y: &Array1<F>,
55    dy: &Array1<F>,
56    tol: F,
57    tend: F,
58) -> F
59where
60    F: IntegrateFloat,
61    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
62{
63    // Calculate a scaling factor based on the solution magnitude
64    let mut d0 = F::zero();
65    for i in 0..y.len() {
66        let sc = (y[i].abs() + dy[i].abs()).max(tol);
67        d0 = d0.max(dy[i].abs() / sc);
68    }
69
70    // Set a reasonable default
71    if d0 < tol {
72        d0 = F::one();
73    }
74
75    // Initial step size
76    let dt = (F::from_f64(0.01).unwrap() / d0).min(F::from_f64(0.1).unwrap() * (tend - t).abs());
77
78    // Evaluate f at t + small step to estimate second derivative
79    let t_new = t + dt * F::from_f64(0.001).unwrap();
80    let y_new = y + &(dy * (t_new - t));
81    let dy_new = f(t_new, y_new.view());
82
83    // Calculate an estimate of the second derivative
84    let mut d1 = F::zero();
85    for i in 0..y.len() {
86        let sc = (y[i].abs() + dy[i].abs()).max(tol);
87        d1 = d1.max((dy_new[i] - dy[i]).abs() / (sc * (t_new - t)));
88    }
89
90    // Prevent division by zero
91    if d1 < tol {
92        d1 = tol;
93    }
94
95    // Calculate step size based on error tolerance
96    let h1 = (F::from_f64(0.01).unwrap() / d1).sqrt();
97
98    // Choose the smaller of the two estimates
99    let mut h = h1.min(dt * F::from_f64(100.0).unwrap());
100
101    // Make sure step size is not too large
102    h = h.min((tend - t).abs() * F::from_f64(0.1).unwrap());
103
104    // Ensure the step is in the correct direction
105    if tend < t {
106        h = -h;
107    }
108
109    h
110}
111
112/// Calculate finite difference approximation of the jacobian matrix
113#[allow(dead_code)]
114pub fn finite_difference_jacobian<F, Func>(
115    f: &Func,
116    t: F,
117    y: &Array1<F>,
118    f_eval: &Array1<F>,
119    _perturbation_scale: F,
120) -> Array2<F>
121where
122    F: IntegrateFloat,
123    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
124{
125    let n_dim = y.len();
126    let mut jacobian = Array2::<F>::zeros((n_dim, n_dim));
127
128    // Calculate appropriate perturbation size
129    let eps_base = F::from_f64(1e-8).unwrap();
130
131    for i in 0..n_dim {
132        // Scale perturbation by variable magnitude
133        let eps = eps_base * (F::one() + y[i].abs()).max(F::one());
134
135        // Perturb the i-th component
136        let mut y_perturbed = y.clone();
137        y_perturbed[i] += eps;
138
139        // Evaluate function at perturbed point
140        let f_perturbed = f(t, y_perturbed.view());
141
142        // Calculate the i-th column of the Jacobian using finite differences
143        for j in 0..n_dim {
144            jacobian[[j, i]] = (f_perturbed[j] - f_eval[j]) / eps;
145        }
146    }
147
148    jacobian
149}
150
151/// Apply a scaled norm to an array
152#[allow(dead_code)]
153pub fn scaled_norm<F: IntegrateFloat>(v: &Array1<F>, scale: &Array1<F>) -> F {
154    let mut max_err = F::zero();
155    for i in 0..v.len() {
156        let err = v[i].abs() / scale[i];
157        max_err = max_err.max(err);
158    }
159    max_err
160}
161
162/// Calculate scaling factors for error control
163#[allow(dead_code)]
164pub fn calculate_error_weights<F: IntegrateFloat>(y: &Array1<F>, atol: F, rtol: F) -> Array1<F> {
165    let mut weights = Array1::<F>::zeros(y.len());
166    for i in 0..y.len() {
167        weights[i] = atol + rtol * y[i].abs();
168    }
169    weights
170}
171
172/// Solve a linear system Ax = b using Gaussian elimination with partial pivoting
173#[allow(dead_code)]
174pub fn solve_linear_system<F: IntegrateFloat>(
175    a: &Array2<F>,
176    b: &Array1<F>,
177) -> IntegrateResult<Array1<F>> {
178    let n = a.shape()[0];
179    if n != a.shape()[1] || n != b.len() {
180        return Err(IntegrateError::DimensionMismatch(
181            "Matrix dimensions do not match for linear solve".to_string(),
182        ));
183    }
184
185    // Create augmented matrix [A|b]
186    let mut aug = Array2::<F>::zeros((n, n + 1));
187    for i in 0..n {
188        for j in 0..n {
189            aug[[i, j]] = a[[i, j]];
190        }
191        aug[[i, n]] = b[i];
192    }
193
194    // Gaussian elimination with partial pivoting
195    for i in 0..n {
196        // Find pivot
197        let mut max_idx = i;
198        let mut max_val = aug[[i, i]].abs();
199
200        for j in i + 1..n {
201            if aug[[j, i]].abs() > max_val {
202                max_idx = j;
203                max_val = aug[[j, i]].abs();
204            }
205        }
206
207        // Check if matrix is singular
208        if max_val < F::from_f64(1e-10).unwrap() {
209            return Err(IntegrateError::LinearSolveError(
210                "Matrix is singular".to_string(),
211            ));
212        }
213
214        // Swap rows if necessary
215        if max_idx != i {
216            for j in 0..n + 1 {
217                let temp = aug[[i, j]];
218                aug[[i, j]] = aug[[max_idx, j]];
219                aug[[max_idx, j]] = temp;
220            }
221        }
222
223        // Eliminate below
224        for j in i + 1..n {
225            let factor = aug[[j, i]] / aug[[i, i]];
226            aug[[j, i]] = F::zero();
227
228            for k in i + 1..n + 1 {
229                aug[[j, k]] = aug[[j, k]] - factor * aug[[i, k]];
230            }
231        }
232    }
233
234    // Back substitution
235    let mut x = Array1::<F>::zeros(n);
236
237    for i in (0..n).rev() {
238        let mut sum = aug[[i, n]];
239
240        for j in i + 1..n {
241            sum -= aug[[i, j]] * x[j];
242        }
243
244        x[i] = sum / aug[[i, i]];
245    }
246
247    Ok(x)
248}
249
250/// Extrapolate solution values for use as initial guess
251#[allow(dead_code)]
252pub fn extrapolate<F: IntegrateFloat>(
253    times: &[F],
254    values: &[Array1<F>],
255    t_target: F,
256) -> IntegrateResult<Array1<F>> {
257    let n = values.len();
258
259    if n == 0 {
260        return Err(IntegrateError::ValueError(
261            "Cannot extrapolate from empty values".to_string(),
262        ));
263    }
264
265    if n == 1 {
266        return Ok(values[0].clone());
267    }
268
269    // Linear extrapolation if we have 2 points
270    if n == 2 {
271        let dt = times[1] - times[0];
272        if dt.abs() < F::from_f64(1e-10).unwrap() {
273            return Ok(values[1].clone());
274        }
275
276        let t_ratio = (t_target - times[1]) / dt;
277        return Ok(&values[1] + &((&values[1] - &values[0]) * t_ratio));
278    }
279
280    // Quadratic extrapolation if we have 3 or more points
281    let t0 = times[n - 3];
282    let t1 = times[n - 2];
283    let t2 = times[n - 1];
284
285    let y0 = &values[n - 3];
286    let y1 = &values[n - 2];
287    let y2 = &values[n - 1];
288
289    // Compute quadratic Lagrange extrapolation
290    let dt0 = t_target - t0;
291    let dt1 = t_target - t1;
292    let dt2 = t_target - t2;
293
294    let dt01 = t0 - t1;
295    let dt02 = t0 - t2;
296    let dt12 = t1 - t2;
297
298    let c0 = dt1 * dt2 / (dt01 * dt02);
299    let c1 = dt0 * dt2 / (-dt01 * dt12);
300    let c2 = dt0 * dt1 / (dt02 * dt12);
301
302    Ok(y0 * c0 + y1 * c1 + y2 * c2)
303}