Skip to main content

scivex_optim/ode/
bdf.rs

1//! BDF (Backward Differentiation Formula) method for stiff ODE systems.
2//!
3//! Implements BDF-2: `y_{n+1} = (4/3)*y_n - (1/3)*y_{n-1} + (2/3)*h*f(t_{n+1}, y_{n+1})`.
4//! Implicit method suitable for stiff problems. Uses fixed-point iteration
5//! to solve the implicit equation at each step.
6
7use scivex_core::Float;
8
9use super::{OdeOptions, OdeResult};
10use crate::error::{OptimError, Result};
11
12const FIXED_POINT_MAX_ITER: usize = 50;
13const FIXED_POINT_TOL: f64 = 1e-10;
14
15/// Solve a stiff ODE system using the BDF-2 method.
16///
17/// `f(t, y) -> dy/dt` defines the system. `y0` is the initial state vector.
18/// Uses BDF-1 (backward Euler) for the first step, then BDF-2 for subsequent steps.
19///
20/// # Examples
21///
22/// ```
23/// # use scivex_optim::ode::{bdf2, OdeOptions};
24/// // dy/dt = -50*y (stiff equation), y(0) = 1
25/// let result = bdf2(
26///     |_t: f64, y: &[f64]| vec![-50.0 * y[0]],
27///     [0.0, 1.0],
28///     &[1.0],
29///     &OdeOptions::default(),
30/// ).unwrap();
31/// assert!(result.success);
32/// ```
33#[allow(clippy::too_many_lines)]
34pub fn bdf2<T, F>(f: F, t_span: [T; 2], y0: &[T], options: &OdeOptions<T>) -> Result<OdeResult<T>>
35where
36    T: Float,
37    F: Fn(T, &[T]) -> Vec<T>,
38{
39    let t0 = t_span[0];
40    let tf = t_span[1];
41    let n = y0.len();
42    let h = options
43        .first_step
44        .unwrap_or_else(|| (tf - t0) / T::from_f64(200.0));
45    let max_steps = options.max_steps;
46
47    let mut t = t0;
48    let mut t_values = vec![t];
49    let mut y_values = vec![y0.to_vec()];
50    let mut n_evals: usize = 0;
51    let mut n_steps: usize = 0;
52
53    // BDF-1 (backward Euler) for first step: y1 = y0 + h*f(t1, y1)
54    // Solve by fixed-point iteration: y1^{k+1} = y0 + h*f(t1, y1^k)
55    let t1 = t + h.min(tf - t);
56    let h_actual = t1 - t;
57    let mut y1 = y0.to_vec();
58
59    // Initial guess using forward Euler
60    let dy0 = f(t, &y1);
61    n_evals += 1;
62    for i in 0..n {
63        y1[i] += h_actual * dy0[i];
64    }
65
66    // Fixed-point iteration for backward Euler
67    for _ in 0..FIXED_POINT_MAX_ITER {
68        let dy = f(t1, &y1);
69        n_evals += 1;
70        let mut max_diff = T::zero();
71        for i in 0..n {
72            let y_new = y0[i] + h_actual * dy[i];
73            let diff = (y_new - y1[i]).abs();
74            if diff > max_diff {
75                max_diff = diff;
76            }
77            y1[i] = y_new;
78        }
79        if max_diff < T::from_f64(FIXED_POINT_TOL) {
80            break;
81        }
82    }
83
84    t = t1;
85    t_values.push(t);
86    y_values.push(y1.clone());
87    n_steps += 1;
88
89    if t >= tf {
90        return Ok(OdeResult {
91            t: t_values,
92            y: y_values,
93            n_evals,
94            n_steps,
95            success: true,
96        });
97    }
98
99    // BDF-2 for subsequent steps
100    // y_{n+1} = (4/3)*y_n - (1/3)*y_{n-1} + (2/3)*h*f(t_{n+1}, y_{n+1})
101    let four_thirds = T::from_f64(4.0 / 3.0);
102    let one_third = T::from_f64(1.0 / 3.0);
103    let two_thirds = T::from_f64(2.0 / 3.0);
104
105    while t < tf {
106        if n_steps >= max_steps {
107            return Err(OptimError::ConvergenceFailure {
108                iterations: n_steps,
109            });
110        }
111
112        let step = h.min(tf - t);
113        let t_next = t + step;
114
115        let y_nm1 = &y_values[y_values.len() - 2];
116        let y_n = &y_values[y_values.len() - 1];
117
118        // Predictor: extrapolate from previous two points
119        let mut y_next = vec![T::zero(); n];
120        for i in 0..n {
121            y_next[i] = four_thirds * y_n[i] - one_third * y_nm1[i];
122        }
123
124        // Fixed-point iteration for BDF-2
125        let mut converged = false;
126        for _ in 0..FIXED_POINT_MAX_ITER {
127            let dy = f(t_next, &y_next);
128            n_evals += 1;
129            let mut max_diff = T::zero();
130            for i in 0..n {
131                let y_new = four_thirds * y_n[i] - one_third * y_nm1[i] + two_thirds * step * dy[i];
132                let diff = (y_new - y_next[i]).abs();
133                if diff > max_diff {
134                    max_diff = diff;
135                }
136                y_next[i] = y_new;
137            }
138            if max_diff < T::from_f64(FIXED_POINT_TOL) {
139                converged = true;
140                break;
141            }
142        }
143
144        if !converged {
145            return Err(OptimError::ConvergenceFailure {
146                iterations: n_steps,
147            });
148        }
149
150        t = t_next;
151        t_values.push(t);
152        y_values.push(y_next);
153        n_steps += 1;
154
155        // Event detection
156        if let Some(ref event_fn) = options.event_fn {
157            let y_cur = &y_values[y_values.len() - 1];
158            let val = event_fn(t, y_cur);
159            if val.abs() < T::from_f64(1e-12)
160                || (t_values.len() > 1 && {
161                    let prev_y = &y_values[y_values.len() - 2];
162                    let prev_t = t_values[t_values.len() - 2];
163                    let prev_val = event_fn(prev_t, prev_y);
164                    (prev_val > T::zero()) != (val > T::zero())
165                })
166            {
167                break;
168            }
169        }
170    }
171
172    Ok(OdeResult {
173        t: t_values,
174        y: y_values,
175        n_evals,
176        n_steps,
177        success: true,
178    })
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar decay `y' = -y`, `y(0) = 1`; the exact solution is `exp(-t)`.
    #[test]
    fn test_bdf2_exponential_decay() {
        let rhs = |_t: f64, y: &[f64]| vec![-y[0]];
        let opts = OdeOptions::default();

        let result = bdf2(rhs, [0.0, 1.0], &[1.0], &opts).unwrap();

        let expected = (-1.0_f64).exp();
        let y_final = result.y.last().unwrap()[0];
        assert!(
            (y_final - expected).abs() < 1e-4,
            "y_final={y_final}, expected={expected}"
        );
    }

    /// Stiff decay `y' = -50*y`, `y(0) = 1`; the exact solution is `exp(-50*t)`.
    #[test]
    fn test_bdf2_stiff_system() {
        let rhs = |_t: f64, y: &[f64]| vec![-50.0 * y[0]];
        let opts = OdeOptions {
            first_step: Some(0.002),
            max_steps: 5000,
            ..OdeOptions::default()
        };

        let result = bdf2(rhs, [0.0, 0.5], &[1.0], &opts).unwrap();

        let expected = (-25.0_f64).exp();
        let y_final = result.y.last().unwrap()[0];
        let err = (y_final - expected).abs();
        // A moderate tolerance is enough for BDF-2 on this stiff problem.
        assert!(
            err < 1e-3,
            "y_final={y_final}, expected={expected}, err={}",
            err
        );
        assert!(result.success);
    }

    /// Constant slope `y' = 1`, `y(0) = 0`; the exact solution is `y(t) = t`.
    #[test]
    fn test_bdf2_linear() {
        let rhs = |_t: f64, _y: &[f64]| vec![1.0];
        let opts = OdeOptions::default();

        let result = bdf2(rhs, [0.0, 2.0], &[0.0], &opts).unwrap();

        let y_final = result.y.last().unwrap()[0];
        assert!(
            (y_final - 2.0).abs() < 1e-6,
            "y_final={y_final}, expected=2.0"
        );
    }

    /// Coupled stiff pair `y0' = -20*y0 + y1`, `y1' = y0 - 20*y1`; both decay fast.
    #[test]
    fn test_bdf2_system() {
        let rhs = |_t: f64, y: &[f64]| vec![-20.0 * y[0] + y[1], y[0] - 20.0 * y[1]];
        let opts = OdeOptions {
            first_step: Some(0.002),
            max_steps: 5000,
            ..OdeOptions::default()
        };

        let result = bdf2(rhs, [0.0, 1.0], &[1.0, 0.0], &opts).unwrap();

        // By t = 1 both components should have decayed to almost nothing.
        let y_final = result.y.last().unwrap();
        let (first, second) = (y_final[0], y_final[1]);
        assert!(first.abs() < 1e-3, "y0={} should be near zero", first);
        assert!(second.abs() < 1e-3, "y1={} should be near zero", second);
    }
}