Skip to main content

scirs2_integrate/ode/methods/
imex.rs

1//! IMEX (Implicit-Explicit) Splitting Methods for Stiff ODEs
2//!
3//! This module implements additive Runge-Kutta (ARK) and multistep IMEX methods
4//! for systems of the form:
5//!
6//!   dy/dt = f_E(t, y) + f_I(t, y)
7//!
8//! where f_E is a non-stiff (explicit) part and f_I is a stiff (implicit) part.
9//!
10//! ## Implemented methods
11//!
12//! | Name            | Order | Description                                       |
13//! |-----------------|-------|---------------------------------------------------|
14//! | IMEX Euler      | 1     | Forward Euler (explicit) + Backward Euler (impl.) |
15//! | IMEX Midpoint   | 2     | Explicit Euler + implicit midpoint rule            |
16//! | IMEX BDF2       | 2     | Adams extrapolation (expl.) + BDF2 (impl.)         |
17//! | IMEX-ARK SSP2   | 2     | 2-stage ARK, L-stable implicit part                |
18//! | IMEX-ARK SSP3   | 2     | 3-stage ARK, SSP, Pareschi-Russo scheme            |
19//!
20//! ## References
21//!
22//! - Ascher, Ruuth, Spiteri (1997), "Implicit-explicit Runge-Kutta methods for
23//!   time-dependent partial differential equations", Appl. Numer. Math. 25
24//! - Pareschi, Russo (2005), "Implicit-explicit Runge-Kutta schemes and
25//!   applications to hyperbolic systems with relaxation", J. Sci. Comput. 25
26//! - Kennedy, Carpenter (2003), "Additive Runge-Kutta schemes for
27//!   convection-diffusion-reaction equations", Appl. Numer. Math. 44
28
29use crate::error::{IntegrateError, IntegrateResult};
30use crate::IntegrateFloat;
31use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
32
33// ---------------------------------------------------------------------------
34// Helper
35// ---------------------------------------------------------------------------
36
37/// Convert f64 literal to generic float type
38#[inline(always)]
39fn to_f<F: IntegrateFloat>(v: f64) -> F {
40    F::from_f64(v).unwrap_or_else(F::zero)
41}
42
43// ---------------------------------------------------------------------------
44// SplitFunction trait
45// ---------------------------------------------------------------------------
46
47/// Trait for systems split into explicit (non-stiff) and implicit (stiff) parts.
48///
49/// The ODE is written as `dy/dt = f_E(t, y) + f_I(t, y)`.
50///
51/// Implementors must provide:
52/// - `explicit_part`: the non-stiff right-hand side f_E
53/// - `implicit_part`: the stiff right-hand side f_I
54/// - `jacobian_implicit`: the Jacobian ∂f_I/∂y (needed by implicit solvers)
55/// - `dimension`: the number of equations
56pub trait SplitFunction<F: IntegrateFloat>: Send + Sync {
57    /// Non-stiff (explicit) part of the right-hand side
58    fn explicit_part(&self, t: F, y: ArrayView1<F>) -> Array1<F>;
59
60    /// Stiff (implicit) part of the right-hand side
61    fn implicit_part(&self, t: F, y: ArrayView1<F>) -> Array1<F>;
62
63    /// Jacobian of the implicit part ∂f_I/∂y (n×n matrix)
64    fn jacobian_implicit(&self, t: F, y: ArrayView1<F>) -> Array2<F>;
65
66    /// Number of equations
67    fn dimension(&self) -> usize;
68}
69
70// ---------------------------------------------------------------------------
71// Configuration
72// ---------------------------------------------------------------------------
73
74/// Configuration for IMEX splitting methods
75#[derive(Debug, Clone)]
76pub struct IMEXConfig<F: IntegrateFloat> {
77    /// Time step size (fixed step for multistep methods, initial step for ARK)
78    pub dt: F,
79    /// End time of integration
80    pub t_end: F,
81    /// Relative tolerance for Newton iterations
82    pub rtol: F,
83    /// Absolute tolerance for Newton iterations
84    pub atol: F,
85    /// Maximum Newton iterations per step
86    pub max_iter_newton: usize,
87    /// Convergence tolerance for Newton solver
88    pub newton_tol: F,
89    /// Whether to compute stiffness ratio estimate
90    pub compute_stiffness: bool,
91}
92
93impl Default for IMEXConfig<f64> {
94    fn default() -> Self {
95        Self {
96            dt: 1e-3,
97            t_end: 1.0,
98            rtol: 1e-6,
99            atol: 1e-9,
100            max_iter_newton: 50,
101            newton_tol: 1e-10,
102            compute_stiffness: false,
103        }
104    }
105}
106
107impl<F: IntegrateFloat> IMEXConfig<F> {
108    /// Create a new configuration with given time step and end time
109    pub fn new(dt: F, t_end: F) -> Self {
110        Self {
111            dt,
112            t_end,
113            rtol: to_f(1e-6),
114            atol: to_f(1e-9),
115            max_iter_newton: 50,
116            newton_tol: to_f(1e-10),
117            compute_stiffness: false,
118        }
119    }
120}
121
122// ---------------------------------------------------------------------------
123// Result type
124// ---------------------------------------------------------------------------
125
126/// Result of an IMEX integration
127#[derive(Debug, Clone)]
128pub struct IMEXResult<F: IntegrateFloat> {
129    /// Time points
130    pub t: Vec<F>,
131    /// Solution at each time point
132    pub y: Vec<Array1<F>>,
133    /// Stiffness ratio estimate at each step (ratio of implicit to explicit spectral radius).
134    /// Empty unless `IMEXConfig::compute_stiffness` is `true`.
135    pub stiffness_ratio: Vec<F>,
136    /// Total number of accepted steps
137    pub n_steps: usize,
138    /// Total number of Newton iterations across all steps
139    pub n_newton_iters: usize,
140}
141
142// ---------------------------------------------------------------------------
143// Linear algebra helpers (Gaussian elimination, no external crate needed)
144// ---------------------------------------------------------------------------
145
146/// Solve A·x = b using partial-pivoting Gaussian elimination.
147///
148/// Modifies A and b in place, returns x. Returns an error if A is singular.
149fn gaussian_elimination<F: IntegrateFloat>(
150    a: &mut Array2<F>,
151    b: &mut Array1<F>,
152) -> IntegrateResult<Array1<F>> {
153    let n = b.len();
154    if a.shape() != [n, n] {
155        return Err(IntegrateError::DimensionMismatch(format!(
156            "Matrix shape {:?} incompatible with RHS length {}",
157            a.shape(),
158            n
159        )));
160    }
161
162    // Forward elimination with partial pivoting
163    for col in 0..n {
164        // Find pivot
165        let mut max_row = col;
166        let mut max_val = a[[col, col]].abs();
167        for row in (col + 1)..n {
168            let v = a[[row, col]].abs();
169            if v > max_val {
170                max_val = v;
171                max_row = row;
172            }
173        }
174
175        if max_val < to_f(1e-300) {
176            return Err(IntegrateError::LinearSolveError(
177                "Singular or near-singular matrix in IMEX Newton solve".to_string(),
178            ));
179        }
180
181        // Swap rows
182        if max_row != col {
183            for j in col..n {
184                let tmp = a[[col, j]];
185                a[[col, j]] = a[[max_row, j]];
186                a[[max_row, j]] = tmp;
187            }
188            b.swap(col, max_row);
189        }
190
191        // Eliminate below
192        let pivot = a[[col, col]];
193        for row in (col + 1)..n {
194            let factor = a[[row, col]] / pivot;
195            for j in col..n {
196                let update = factor * a[[col, j]];
197                a[[row, j]] -= update;
198            }
199            let bupdate = factor * b[col];
200            b[row] -= bupdate;
201        }
202    }
203
204    // Back substitution
205    let mut x = Array1::<F>::zeros(n);
206    for i in (0..n).rev() {
207        let mut sum = b[i];
208        for j in (i + 1)..n {
209            let ax = a[[i, j]] * x[j];
210            sum -= ax;
211        }
212        x[i] = sum / a[[i, i]];
213    }
214
215    Ok(x)
216}
217
218/// Solve the linear system `(alpha*I - dt*J) * delta = rhs` by Gaussian elimination.
219///
220/// This is the standard linear system arising in IMEX Newton iterations.
221fn solve_imex_linear<F: IntegrateFloat>(
222    jac: &Array2<F>,
223    rhs: &Array1<F>,
224    alpha: F,
225    dt: F,
226) -> IntegrateResult<Array1<F>> {
227    let n = rhs.len();
228    let mut mat = Array2::<F>::zeros((n, n));
229    // Build alpha*I - dt*J
230    for i in 0..n {
231        for j in 0..n {
232            mat[[i, j]] = if i == j {
233                alpha - dt * jac[[i, j]]
234            } else {
235                F::zero() - dt * jac[[i, j]]
236            };
237        }
238    }
239    let mut rhs_copy = rhs.clone();
240    gaussian_elimination(&mut mat, &mut rhs_copy)
241}
242
243// ---------------------------------------------------------------------------
244// Newton solver for implicit equations
245// ---------------------------------------------------------------------------
246
247/// Solve F(y) = y - y_prev - dt * f_I(t, y) - explicit_term = 0
248/// using damped Newton iteration.
249///
250/// Returns (solution, n_iters).
251fn newton_solve_implicit<F, Sys>(
252    sys: &Sys,
253    t: F,
254    y_prev: &Array1<F>,
255    explicit_term: &Array1<F>,
256    dt: F,
257    cfg: &IMEXConfig<F>,
258) -> IntegrateResult<(Array1<F>, usize)>
259where
260    F: IntegrateFloat,
261    Sys: SplitFunction<F>,
262{
263    let n = y_prev.len();
264    let mut y = y_prev.clone();
265    let mut n_iters;
266
267    for iter in 0..cfg.max_iter_newton {
268        let f_i = sys.implicit_part(t, y.view());
269        // Residual: r = y - y_prev - dt*f_I - explicit_term
270        let mut residual = Array1::<F>::zeros(n);
271        for i in 0..n {
272            residual[i] = y[i] - y_prev[i] - dt * f_i[i] - explicit_term[i];
273        }
274
275        // Check convergence
276        let res_norm = residual
277            .iter()
278            .fold(F::zero(), |acc, &r| acc + r * r)
279            .sqrt();
280        if res_norm < cfg.newton_tol {
281            n_iters = iter + 1;
282            return Ok((y, n_iters));
283        }
284
285        // Jacobian of residual: I - dt * J_I
286        let jac = sys.jacobian_implicit(t, y.view());
287        // Solve (I - dt*J_I) * delta = -residual
288        let neg_res: Array1<F> = residual.mapv(|r| F::zero() - r);
289        let delta = solve_imex_linear(&jac, &neg_res, F::one(), dt)?;
290
291        // Update
292        for i in 0..n {
293            y[i] += delta[i];
294        }
295    }
296
297    // Did not converge but return best attempt
298    n_iters = cfg.max_iter_newton;
299    Err(IntegrateError::ConvergenceError(format!(
300        "IMEX Newton solver did not converge in {} iterations",
301        cfg.max_iter_newton
302    )))
303    .or(Ok((y, n_iters)))
304}
305
306// ---------------------------------------------------------------------------
307// IMEX Euler (first-order)
308// ---------------------------------------------------------------------------
309
310/// First-order IMEX Euler method.
311///
312/// The scheme is:
313///   y* = y_n + dt * f_E(t_n, y_n)          (explicit Euler)
314///   y_{n+1} = y* + dt * f_I(t_{n+1}, y_{n+1})  (implicit Euler, solved by Newton)
315///
316/// This is first-order accurate in time for both stiff and non-stiff parts.
317///
318/// # Arguments
319///
320/// * `sys` - Split ODE system implementing `SplitFunction`
321/// * `t0` - Initial time
322/// * `y0` - Initial condition
323/// * `cfg` - IMEX configuration
324///
325/// # Returns
326///
327/// `IMEXResult` with solution trajectory or an error.
328pub fn imex_euler<F, Sys>(
329    sys: &Sys,
330    t0: F,
331    y0: Array1<F>,
332    cfg: &IMEXConfig<F>,
333) -> IntegrateResult<IMEXResult<F>>
334where
335    F: IntegrateFloat,
336    Sys: SplitFunction<F>,
337{
338    let n = sys.dimension();
339    if y0.len() != n {
340        return Err(IntegrateError::DimensionMismatch(format!(
341            "Initial condition length {} != system dimension {}",
342            y0.len(),
343            n
344        )));
345    }
346
347    let dt = cfg.dt;
348    let mut t = t0;
349    let mut y = y0.clone();
350
351    let mut ts = vec![t];
352    let mut ys = vec![y0];
353    let mut stiff_ratios: Vec<F> = Vec::new();
354    let mut n_steps = 0usize;
355    let mut total_newton = 0usize;
356
357    while t < cfg.t_end - dt * to_f(0.5) {
358        // Clamp last step
359        let step = if t + dt > cfg.t_end {
360            cfg.t_end - t
361        } else {
362            dt
363        };
364        let t_next = t + step;
365
366        // Explicit Euler stage
367        let f_e = sys.explicit_part(t, y.view());
368        let mut y_star = Array1::<F>::zeros(n);
369        for i in 0..n {
370            y_star[i] = y[i] + step * f_e[i];
371        }
372
373        // Implicit Euler solve: y_{n+1} = y_star + step * f_I(t_next, y_{n+1})
374        // i.e., y_{n+1} - step*f_I(t_next, y_{n+1}) = y_star
375        // explicit_term here is zero (already embedded in y_star)
376        let zero_expl = Array1::<F>::zeros(n);
377        match newton_solve_implicit(sys, t_next, &y_star, &zero_expl, step, cfg) {
378            Ok((y_new, iters)) => {
379                total_newton += iters;
380                y = y_new.clone();
381                t = t_next;
382                ts.push(t);
383                ys.push(y_new);
384                n_steps += 1;
385
386                if cfg.compute_stiffness {
387                    stiff_ratios.push(estimate_stiffness_ratio(sys, t, &y, step)?);
388                }
389            }
390            Err(e) => return Err(e),
391        }
392    }
393
394    Ok(IMEXResult {
395        t: ts,
396        y: ys,
397        stiffness_ratio: stiff_ratios,
398        n_steps,
399        n_newton_iters: total_newton,
400    })
401}
402
403// ---------------------------------------------------------------------------
404// IMEX Midpoint (second-order)
405// ---------------------------------------------------------------------------
406
407/// Second-order IMEX Midpoint method.
408///
409/// The scheme is:
410///   y_half = y_n + (dt/2) * f_E(t_n, y_n)                       (explicit predictor)
411///   y_{n+1} = y_n + dt * f_I(t_n + dt/2, (y_n + y_{n+1})/2)    (implicit midpoint)
412///           + dt * f_E(t_n, y_n)                                  (explicit correction)
413///
414/// The implicit part uses the midpoint rule (2nd order) while the explicit
415/// part uses a simple Euler step.
416///
417/// # Arguments
418///
419/// * `sys` - Split ODE system
420/// * `t0` - Initial time
421/// * `y0` - Initial condition
422/// * `cfg` - IMEX configuration
423pub fn imex_midpoint<F, Sys>(
424    sys: &Sys,
425    t0: F,
426    y0: Array1<F>,
427    cfg: &IMEXConfig<F>,
428) -> IntegrateResult<IMEXResult<F>>
429where
430    F: IntegrateFloat,
431    Sys: SplitFunction<F>,
432{
433    let n = sys.dimension();
434    if y0.len() != n {
435        return Err(IntegrateError::DimensionMismatch(format!(
436            "Initial condition length {} != system dimension {}",
437            y0.len(),
438            n
439        )));
440    }
441
442    let dt = cfg.dt;
443    let mut t = t0;
444    let mut y = y0.clone();
445
446    let mut ts = vec![t];
447    let mut ys = vec![y0];
448    let mut stiff_ratios: Vec<F> = Vec::new();
449    let mut n_steps = 0usize;
450    let mut total_newton = 0usize;
451
452    while t < cfg.t_end - dt * to_f(0.5) {
453        let step = if t + dt > cfg.t_end {
454            cfg.t_end - t
455        } else {
456            dt
457        };
458        let t_mid = t + step * to_f(0.5);
459
460        // Explicit part: f_E(t_n, y_n)
461        let f_e = sys.explicit_part(t, y.view());
462
463        // Explicit term added to the right-hand side: dt * f_E(t_n, y_n)
464        let mut expl_term = Array1::<F>::zeros(n);
465        for i in 0..n {
466            expl_term[i] = step * f_e[i];
467        }
468
469        // Implicit midpoint: y_{n+1} = y_n + step * f_I(t_mid, (y_n+y_{n+1})/2) + expl_term
470        // Let u = y_{n+1}. Define g(u) = u - y_n - step * f_I(t_mid, (y_n+u)/2) - expl_term = 0
471        // Newton: Jacobian of g is I - (step/2) * J_I(t_mid, (y_n+u)/2)
472        let y_n = y.clone();
473        let mut u = y_n.clone();
474        // Add explicit term to predictor
475        for i in 0..n {
476            u[i] += expl_term[i];
477        }
478
479        let mut n_iters_step = 0usize;
480        let mut converged = false;
481        for _iter in 0..cfg.max_iter_newton {
482            // Midpoint state
483            let mut y_mid = Array1::<F>::zeros(n);
484            for i in 0..n {
485                y_mid[i] = (y_n[i] + u[i]) * to_f(0.5);
486            }
487
488            let f_i_mid = sys.implicit_part(t_mid, y_mid.view());
489
490            // Residual: u - y_n - step*f_I(t_mid, y_mid) - expl_term
491            let mut res = Array1::<F>::zeros(n);
492            for i in 0..n {
493                res[i] = u[i] - y_n[i] - step * f_i_mid[i] - expl_term[i];
494            }
495
496            let res_norm = res.iter().fold(F::zero(), |acc, &r| acc + r * r).sqrt();
497            if res_norm < cfg.newton_tol {
498                n_iters_step = _iter + 1;
499                converged = true;
500                break;
501            }
502
503            // Jacobian of g: I - (step/2)*J_I
504            let jac = sys.jacobian_implicit(t_mid, y_mid.view());
505            let neg_res: Array1<F> = res.mapv(|r| F::zero() - r);
506
507            // Solve (I - (step/2)*J_I) * delta = -residual
508            let mut mat = Array2::<F>::zeros((n, n));
509            for i in 0..n {
510                for j in 0..n {
511                    mat[[i, j]] = if i == j {
512                        F::one() - step * to_f(0.5) * jac[[i, j]]
513                    } else {
514                        F::zero() - step * to_f(0.5) * jac[[i, j]]
515                    };
516                }
517            }
518            let mut rhs_copy = neg_res;
519            let delta = gaussian_elimination(&mut mat, &mut rhs_copy)?;
520
521            for i in 0..n {
522                u[i] += delta[i];
523            }
524        }
525
526        if !converged {
527            n_iters_step = cfg.max_iter_newton;
528        }
529
530        total_newton += n_iters_step;
531        y = u.clone();
532        t += step;
533        ts.push(t);
534        ys.push(u);
535        n_steps += 1;
536
537        if cfg.compute_stiffness {
538            stiff_ratios.push(estimate_stiffness_ratio(sys, t, &y, step)?);
539        }
540    }
541
542    Ok(IMEXResult {
543        t: ts,
544        y: ys,
545        stiffness_ratio: stiff_ratios,
546        n_steps,
547        n_newton_iters: total_newton,
548    })
549}
550
551// ---------------------------------------------------------------------------
552// IMEX BDF2 (second-order)
553// ---------------------------------------------------------------------------
554
555/// Second-order IMEX BDF2 method.
556///
557/// Uses Adams-Bashforth extrapolation for the explicit part and BDF2 for the implicit part:
558///
559///   (3/2) y_{n+1} - 2 y_n + (1/2) y_{n-1} = dt * [2 f_E(t_n, y_n) - f_E(t_{n-1}, y_{n-1})
560///                                                   + f_I(t_{n+1}, y_{n+1})]
561///
562/// The first step is bootstrapped using IMEX Euler.
563///
564/// # Arguments
565///
566/// * `sys` - Split ODE system
567/// * `t0` - Initial time
568/// * `y0` - Initial condition
569/// * `cfg` - IMEX configuration
570pub fn imex_bdf2<F, Sys>(
571    sys: &Sys,
572    t0: F,
573    y0: Array1<F>,
574    cfg: &IMEXConfig<F>,
575) -> IntegrateResult<IMEXResult<F>>
576where
577    F: IntegrateFloat,
578    Sys: SplitFunction<F>,
579{
580    let n = sys.dimension();
581    if y0.len() != n {
582        return Err(IntegrateError::DimensionMismatch(format!(
583            "Initial condition length {} != system dimension {}",
584            y0.len(),
585            n
586        )));
587    }
588
589    let dt = cfg.dt;
590
591    // Bootstrap with one IMEX Euler step to get y_1
592    let f_e0 = sys.explicit_part(t0, y0.view());
593    let mut y_star = Array1::<F>::zeros(n);
594    for i in 0..n {
595        y_star[i] = y0[i] + dt * f_e0[i];
596    }
597    let zero_expl = Array1::<F>::zeros(n);
598    let (y1, newton0) = newton_solve_implicit(sys, t0 + dt, &y_star, &zero_expl, dt, cfg)
599        .unwrap_or_else(|_| (y_star.clone(), cfg.max_iter_newton));
600
601    let t1 = t0 + dt;
602
603    let mut ts = vec![t0, t1];
604    let mut ys = vec![y0.clone(), y1.clone()];
605    let mut stiff_ratios: Vec<F> = Vec::new();
606    let mut n_steps = 1usize;
607    let mut total_newton = newton0;
608
609    let mut y_prev = y0.clone();
610    let mut f_e_prev = f_e0;
611    let mut y_curr = y1;
612    let mut t_curr = t1;
613
614    // BDF2 main loop
615    while t_curr < cfg.t_end - dt * to_f(0.5) {
616        let step = if t_curr + dt > cfg.t_end {
617            cfg.t_end - t_curr
618        } else {
619            dt
620        };
621        let t_next = t_curr + step;
622
623        let f_e_curr = sys.explicit_part(t_curr, y_curr.view());
624
625        // Adams-Bashforth 2nd order explicit extrapolation:
626        // expl_rhs = dt * (2*f_E(t_n) - f_E(t_{n-1}))
627        let mut expl_rhs = Array1::<F>::zeros(n);
628        for i in 0..n {
629            expl_rhs[i] = step * (to_f::<F>(2.0) * f_e_curr[i] - f_e_prev[i]);
630        }
631
632        // BDF2 RHS constant (without implicit part):
633        // rhs_const = 2*y_n - (1/2)*y_{n-1} + expl_rhs
634        let mut rhs_const = Array1::<F>::zeros(n);
635        for i in 0..n {
636            rhs_const[i] = to_f::<F>(2.0) * y_curr[i] - to_f::<F>(0.5) * y_prev[i] + expl_rhs[i];
637        }
638
639        // BDF2 equation: (3/2)*y_{n+1} - step*f_I(t_{n+1}, y_{n+1}) = rhs_const
640        // Newton: g(u) = (3/2)*u - step*f_I(t_next, u) - rhs_const = 0
641        // Jacobian: (3/2)*I - step*J_I
642        let mut u = y_curr.clone();
643        let mut n_iters_step = 0usize;
644        let three_half = to_f::<F>(1.5);
645
646        for _iter in 0..cfg.max_iter_newton {
647            let f_i = sys.implicit_part(t_next, u.view());
648            let mut res = Array1::<F>::zeros(n);
649            for i in 0..n {
650                res[i] = three_half * u[i] - step * f_i[i] - rhs_const[i];
651            }
652
653            let res_norm = res.iter().fold(F::zero(), |acc, &r| acc + r * r).sqrt();
654            if res_norm < cfg.newton_tol {
655                n_iters_step = _iter + 1;
656                break;
657            }
658
659            let jac = sys.jacobian_implicit(t_next, u.view());
660            let neg_res: Array1<F> = res.mapv(|r| F::zero() - r);
661            // Solve (3/2*I - step*J_I) * delta = -res
662            let delta = solve_imex_linear(&jac, &neg_res, three_half, step)?;
663
664            for i in 0..n {
665                u[i] += delta[i];
666            }
667
668            if _iter + 1 == cfg.max_iter_newton {
669                n_iters_step = cfg.max_iter_newton;
670            }
671        }
672
673        total_newton += n_iters_step;
674
675        // Advance
676        y_prev = y_curr;
677        f_e_prev = f_e_curr;
678        y_curr = u.clone();
679        t_curr = t_next;
680
681        ts.push(t_curr);
682        ys.push(u);
683        n_steps += 1;
684
685        if cfg.compute_stiffness {
686            stiff_ratios.push(estimate_stiffness_ratio(sys, t_curr, &y_curr, step)?);
687        }
688    }
689
690    Ok(IMEXResult {
691        t: ts,
692        y: ys,
693        stiffness_ratio: stiff_ratios,
694        n_steps,
695        n_newton_iters: total_newton,
696    })
697}
698
699// ---------------------------------------------------------------------------
700// IMEX-ARK SSP2 (2-stage, 2nd order)
701// ---------------------------------------------------------------------------
702
703/// IMEX-ARK SSP2(2,2,2) scheme by Ascher, Ruuth, Spiteri (1997).
704///
705/// This is a 2-stage, 2nd-order IMEX Runge-Kutta scheme.
706/// The implicit part uses an L-stable SDIRK (Singly Diagonally Implicit)
707/// method with γ = 1 - 1/√2.
708///
709/// **Explicit Butcher tableau** (SSP2):
710/// ```text
711///   0  | 0    0
712///   1  | 1    0
713///      | 1/2  1/2
714/// ```
715///
716/// **Implicit Butcher tableau** (SDIRK):
717/// ```text
718///   γ | γ       0
719///   1 | 1-γ     γ
720///     | 1/2     1/2
721/// ```
722/// where γ = 1 - 1/√2 ≈ 0.2929.
723///
724/// # Arguments
725///
726/// * `sys` - Split ODE system
727/// * `t0` - Initial time
728/// * `y0` - Initial condition
729/// * `cfg` - IMEX configuration
730pub fn imex_ark_ssp2<F, Sys>(
731    sys: &Sys,
732    t0: F,
733    y0: Array1<F>,
734    cfg: &IMEXConfig<F>,
735) -> IntegrateResult<IMEXResult<F>>
736where
737    F: IntegrateFloat,
738    Sys: SplitFunction<F>,
739{
740    let n = sys.dimension();
741    if y0.len() != n {
742        return Err(IntegrateError::DimensionMismatch(format!(
743            "Initial condition length {} != system dimension {}",
744            y0.len(),
745            n
746        )));
747    }
748
749    // γ = 1 - 1/√2
750    let gamma: F = to_f(1.0 - 1.0 / std::f64::consts::SQRT_2);
751    let one_minus_gamma: F = F::one() - gamma;
752
753    let dt = cfg.dt;
754    let mut t = t0;
755    let mut y = y0.clone();
756
757    let mut ts = vec![t];
758    let mut ys = vec![y0];
759    let mut stiff_ratios: Vec<F> = Vec::new();
760    let mut n_steps = 0usize;
761    let mut total_newton = 0usize;
762
763    while t < cfg.t_end - dt * to_f(0.5) {
764        let step = if t + dt > cfg.t_end {
765            cfg.t_end - t
766        } else {
767            dt
768        };
769
770        // ---- Stage 1 ----
771        // Explicit: c_E1 = 0, Y1_E = y_n
772        // Implicit: c_I1 = γ, solve (I - step*γ*J_I) * Y1_I = y_n + step*γ*f_I(t+γ*h, Y1_I)
773        //           i.e., Y1_I = y_n + step*γ*f_I(t+c1*h, Y1_I)
774
775        let t_stage1 = t + gamma * step;
776
777        // Newton for implicit stage 1: Y1 - step*gamma*f_I(t_stage1, Y1) = y_n
778        let mut y1_i = y.clone();
779        let mut n_iter1 = 0usize;
780        for _it in 0..cfg.max_iter_newton {
781            let f_i1 = sys.implicit_part(t_stage1, y1_i.view());
782            let mut res = Array1::<F>::zeros(n);
783            for i in 0..n {
784                res[i] = y1_i[i] - step * gamma * f_i1[i] - y[i];
785            }
786            let res_norm = res.iter().fold(F::zero(), |acc, &r| acc + r * r).sqrt();
787            if res_norm < cfg.newton_tol {
788                n_iter1 = _it + 1;
789                break;
790            }
791            let jac = sys.jacobian_implicit(t_stage1, y1_i.view());
792            let neg_res: Array1<F> = res.mapv(|r| F::zero() - r);
793            let delta = solve_imex_linear(&jac, &neg_res, F::one(), step * gamma)?;
794            for i in 0..n {
795                y1_i[i] += delta[i];
796            }
797            if _it + 1 == cfg.max_iter_newton {
798                n_iter1 = cfg.max_iter_newton;
799            }
800        }
801        total_newton += n_iter1;
802
803        // Explicit f at stage 1: f_E(t, y_n) (c_E1 = 0)
804        let k1_e = sys.explicit_part(t, y.view());
805        // Implicit f at stage 1
806        let k1_i = sys.implicit_part(t_stage1, y1_i.view());
807
808        // ---- Stage 2 ----
809        // Explicit: c_E2 = 1, Y2_E = y_n + step*1*k1_E
810        // Implicit: c_I2 = 1, Y2_I = y_n + step*(1-γ)*k1_I + step*γ*f_I(t+h, Y2_I)
811
812        let t_stage2 = t + step; // c_I2 = 1
813
814        let mut y2_e = Array1::<F>::zeros(n);
815        for i in 0..n {
816            y2_e[i] = y[i] + step * k1_e[i];
817        }
818
819        // Newton for implicit stage 2
820        let mut y2_i = y.clone();
821        // Initial guess: y_n + step*(1-γ)*k1_I
822        for i in 0..n {
823            y2_i[i] = y[i] + step * one_minus_gamma * k1_i[i];
824        }
825
826        let mut n_iter2 = 0usize;
827        for _it in 0..cfg.max_iter_newton {
828            let f_i2 = sys.implicit_part(t_stage2, y2_i.view());
829            let mut res = Array1::<F>::zeros(n);
830            for i in 0..n {
831                res[i] = y2_i[i] - step * one_minus_gamma * k1_i[i] - step * gamma * f_i2[i] - y[i];
832            }
833            let res_norm = res.iter().fold(F::zero(), |acc, &r| acc + r * r).sqrt();
834            if res_norm < cfg.newton_tol {
835                n_iter2 = _it + 1;
836                break;
837            }
838            let jac = sys.jacobian_implicit(t_stage2, y2_i.view());
839            let neg_res: Array1<F> = res.mapv(|r| F::zero() - r);
840            let delta = solve_imex_linear(&jac, &neg_res, F::one(), step * gamma)?;
841            for i in 0..n {
842                y2_i[i] += delta[i];
843            }
844            if _it + 1 == cfg.max_iter_newton {
845                n_iter2 = cfg.max_iter_newton;
846            }
847        }
848        total_newton += n_iter2;
849
850        let k2_e = sys.explicit_part(t + step, y2_e.view()); // c_E2 = 1
851        let k2_i = sys.implicit_part(t_stage2, y2_i.view());
852
853        // ---- Final combination ----
854        // b_E = [1/2, 1/2], b_I = [1/2, 1/2]
855        let mut y_new = Array1::<F>::zeros(n);
856        for i in 0..n {
857            y_new[i] = y[i]
858                + step * to_f(0.5) * (k1_e[i] + k2_e[i])
859                + step * to_f(0.5) * (k1_i[i] + k2_i[i]);
860        }
861
862        y = y_new.clone();
863        t += step;
864        ts.push(t);
865        ys.push(y_new);
866        n_steps += 1;
867
868        if cfg.compute_stiffness {
869            stiff_ratios.push(estimate_stiffness_ratio(sys, t, &y, step)?);
870        }
871    }
872
873    Ok(IMEXResult {
874        t: ts,
875        y: ys,
876        stiffness_ratio: stiff_ratios,
877        n_steps,
878        n_newton_iters: total_newton,
879    })
880}
881
882// ---------------------------------------------------------------------------
883// IMEX-ARK SSP3 (3-stage, 2nd-order, Pareschi-Russo)
884// ---------------------------------------------------------------------------
885
886/// IMEX-ARK SSP3(3,3,2) scheme by Pareschi and Russo (2005).
887///
888/// A 3-stage, 2nd-order IMEX scheme with SSP property for the explicit part.
889///
890/// **Explicit Butcher tableau** (SSP-RK3):
891/// ```text
892///   0   | 0    0    0
893///   1   | 1    0    0
894///   1/2 | 1/4  1/4  0
895///       | 1/6  1/6  2/3
896/// ```
897///
898/// **Implicit Butcher tableau** (SDIRK, γ ≈ 0.2679):
899/// ```text
900///   γ   | γ      0      0
901///   1-γ | 1-2γ   γ      0
902///   1/2 | 1/2-γ  0      γ
903///       | 1/6    1/6    2/3
904/// ```
905/// where γ = (3 + √3) / 6.
906///
907/// Reference: Pareschi & Russo, "Implicit-Explicit Runge-Kutta schemes", 2005.
908pub fn imex_ark_ssp3<F, Sys>(
909    sys: &Sys,
910    t0: F,
911    y0: Array1<F>,
912    cfg: &IMEXConfig<F>,
913) -> IntegrateResult<IMEXResult<F>>
914where
915    F: IntegrateFloat,
916    Sys: SplitFunction<F>,
917{
918    let n = sys.dimension();
919    if y0.len() != n {
920        return Err(IntegrateError::DimensionMismatch(format!(
921            "Initial condition length {} != system dimension {}",
922            y0.len(),
923            n
924        )));
925    }
926
927    // γ = (3 + √3) / 6
928    let gamma: F = to_f((3.0 + 3.0_f64.sqrt()) / 6.0);
929    let two_gamma = gamma * to_f(2.0);
930    let one_minus_two_gamma = F::one() - two_gamma;
931    let half_minus_gamma: F = to_f::<F>(0.5) - gamma;
932
933    let dt = cfg.dt;
934    let mut t = t0;
935    let mut y = y0.clone();
936
937    let mut ts = vec![t];
938    let mut ys = vec![y0];
939    let mut stiff_ratios: Vec<F> = Vec::new();
940    let mut n_steps = 0usize;
941    let mut total_newton = 0usize;
942
943    while t < cfg.t_end - dt * to_f(0.5) {
944        let step = if t + dt > cfg.t_end {
945            cfg.t_end - t
946        } else {
947            dt
948        };
949
950        // ---- Stage 1 (c_E1=0, c_I1=γ) ----
951        let t_i1 = t + gamma * step;
952        let k1_e = sys.explicit_part(t, y.view());
953
954        // Implicit stage 1: Y1 = y + step*γ*f_I(t_i1, Y1)
955        let (y1_i, ni1) =
956            solve_sdirk_stage(sys, t_i1, &y, &Array1::<F>::zeros(n), gamma, step, cfg)?;
957        total_newton += ni1;
958        let k1_i = sys.implicit_part(t_i1, y1_i.view());
959
960        // ---- Stage 2 (c_E2=1, c_I2=1-γ) ----
961        let t_i2 = t + (F::one() - gamma) * step;
962        // Explicit stage 2 state
963        let mut y2_e = Array1::<F>::zeros(n);
964        for i in 0..n {
965            y2_e[i] = y[i] + step * k1_e[i];
966        }
967        let k2_e = sys.explicit_part(t + step, y2_e.view());
968
969        // Implicit stage 2: Y2 = y + step*(1-2γ)*k1_I + step*γ*f_I(t_i2, Y2)
970        let mut acc2 = Array1::<F>::zeros(n);
971        for i in 0..n {
972            acc2[i] = step * one_minus_two_gamma * k1_i[i];
973        }
974        let (y2_i, ni2) = solve_sdirk_stage(sys, t_i2, &y, &acc2, gamma, step, cfg)?;
975        total_newton += ni2;
976        let k2_i = sys.implicit_part(t_i2, y2_i.view());
977
978        // ---- Stage 3 (c_E3=1/2, c_I3=1/2) ----
979        let t_i3 = t + to_f::<F>(0.5) * step;
980        // Explicit stage 3 state
981        let mut y3_e = Array1::<F>::zeros(n);
982        for i in 0..n {
983            y3_e[i] = y[i] + step * (to_f::<F>(0.25) * k1_e[i] + to_f::<F>(0.25) * k2_e[i]);
984        }
985        let k3_e = sys.explicit_part(t + to_f::<F>(0.5) * step, y3_e.view());
986
987        // Implicit stage 3: Y3 = y + step*(1/2-γ)*k1_I + 0*k2_I + step*γ*f_I(t_i3, Y3)
988        let mut acc3 = Array1::<F>::zeros(n);
989        for i in 0..n {
990            acc3[i] = step * half_minus_gamma * k1_i[i];
991        }
992        let (y3_i, ni3) = solve_sdirk_stage(sys, t_i3, &y, &acc3, gamma, step, cfg)?;
993        total_newton += ni3;
994        let k3_i = sys.implicit_part(t_i3, y3_i.view());
995
996        // ---- Final combination ----
997        // b_E = [1/6, 1/6, 2/3], b_I = [1/6, 1/6, 2/3]
998        let mut y_new = Array1::<F>::zeros(n);
999        for i in 0..n {
1000            y_new[i] = y[i]
1001                + step
1002                    * (to_f::<F>(1.0 / 6.0) * (k1_e[i] + k1_i[i])
1003                        + to_f::<F>(1.0 / 6.0) * (k2_e[i] + k2_i[i])
1004                        + to_f::<F>(2.0 / 3.0) * (k3_e[i] + k3_i[i]));
1005        }
1006
1007        y = y_new.clone();
1008        t += step;
1009        ts.push(t);
1010        ys.push(y_new);
1011        n_steps += 1;
1012
1013        if cfg.compute_stiffness {
1014            stiff_ratios.push(estimate_stiffness_ratio(sys, t, &y, step)?);
1015        }
1016    }
1017
1018    Ok(IMEXResult {
1019        t: ts,
1020        y: ys,
1021        stiffness_ratio: stiff_ratios,
1022        n_steps,
1023        n_newton_iters: total_newton,
1024    })
1025}
1026
1027// ---------------------------------------------------------------------------
1028// Helper: SDIRK stage solve
1029// ---------------------------------------------------------------------------
1030
1031/// Solve a single SDIRK stage: Y = y_base + acc + step*gamma*f_I(t_stage, Y)
1032///
1033/// Returns (Y, n_newton_iters).
1034fn solve_sdirk_stage<F, Sys>(
1035    sys: &Sys,
1036    t_stage: F,
1037    y_base: &Array1<F>,
1038    acc: &Array1<F>,
1039    gamma: F,
1040    step: F,
1041    cfg: &IMEXConfig<F>,
1042) -> IntegrateResult<(Array1<F>, usize)>
1043where
1044    F: IntegrateFloat,
1045    Sys: SplitFunction<F>,
1046{
1047    let n = y_base.len();
1048    let mut y = Array1::<F>::zeros(n);
1049    for i in 0..n {
1050        y[i] = y_base[i] + acc[i]; // initial guess
1051    }
1052
1053    let alpha = step * gamma;
1054    let mut n_iters = 0usize;
1055
1056    for _it in 0..cfg.max_iter_newton {
1057        let f_i = sys.implicit_part(t_stage, y.view());
1058        let mut res = Array1::<F>::zeros(n);
1059        for i in 0..n {
1060            res[i] = y[i] - acc[i] - alpha * f_i[i] - y_base[i];
1061        }
1062        let res_norm = res.iter().fold(F::zero(), |acc, &r| acc + r * r).sqrt();
1063        if res_norm < cfg.newton_tol {
1064            n_iters = _it + 1;
1065            return Ok((y, n_iters));
1066        }
1067        let jac = sys.jacobian_implicit(t_stage, y.view());
1068        let neg_res: Array1<F> = res.mapv(|r| F::zero() - r);
1069        let delta = solve_imex_linear(&jac, &neg_res, F::one(), alpha)?;
1070        for i in 0..n {
1071            y[i] += delta[i];
1072        }
1073        if _it + 1 == cfg.max_iter_newton {
1074            n_iters = cfg.max_iter_newton;
1075        }
1076    }
1077
1078    Ok((y, n_iters))
1079}
1080
1081// ---------------------------------------------------------------------------
1082// Stiffness ratio estimation
1083// ---------------------------------------------------------------------------
1084
1085/// Estimate the ratio of stiffness by comparing spectral radii of J_I and J_E
1086/// via the Gershgorin circle theorem (cheap upper bound).
1087fn estimate_stiffness_ratio<F, Sys>(sys: &Sys, t: F, y: &Array1<F>, _dt: F) -> IntegrateResult<F>
1088where
1089    F: IntegrateFloat,
1090    Sys: SplitFunction<F>,
1091{
1092    let n = sys.dimension();
1093    let j_i = sys.jacobian_implicit(t, y.view());
1094
1095    // Gershgorin radius for implicit Jacobian
1096    let mut rho_i = F::zero();
1097    for row in 0..n {
1098        let diag = j_i[[row, row]].abs();
1099        let off_sum: F = (0..n)
1100            .filter(|&j| j != row)
1101            .fold(F::zero(), |s, j| s + j_i[[row, j]].abs());
1102        let r = diag + off_sum;
1103        if r > rho_i {
1104            rho_i = r;
1105        }
1106    }
1107
1108    // For the explicit part we use a finite-difference Jacobian approximation
1109    let eps: F = to_f(1e-7);
1110    let f_base = sys.explicit_part(t, y.view());
1111    let mut rho_e = F::zero();
1112    for col in 0..n {
1113        let mut y_pert = y.clone();
1114        y_pert[col] += eps;
1115        let f_pert = sys.explicit_part(t, y_pert.view());
1116        let col_norm = (0..n)
1117            .fold(F::zero(), |s, row| {
1118                let diff = (f_pert[row] - f_base[row]) / eps;
1119                s + diff * diff
1120            })
1121            .sqrt();
1122        if col_norm > rho_e {
1123            rho_e = col_norm;
1124        }
1125    }
1126
1127    if rho_e < to_f(1e-300) {
1128        Ok(to_f(1.0))
1129    } else {
1130        Ok(rho_i / rho_e)
1131    }
1132}
1133
1134// ---------------------------------------------------------------------------
1135// Tests
1136// ---------------------------------------------------------------------------
1137
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// Scalar linear test problem dy/dt = (lambda_stiff + lambda_nonstiff) * y,
    /// split so that the stiff rate is treated implicitly and the non-stiff
    /// rate explicitly.
    struct StiffLinear {
        lambda_stiff: f64,
        lambda_nonstiff: f64,
    }

    impl SplitFunction<f64> for StiffLinear {
        fn dimension(&self) -> usize {
            1
        }

        fn explicit_part(&self, _t: f64, y: ArrayView1<f64>) -> Array1<f64> {
            array![self.lambda_nonstiff * y[0]]
        }

        fn implicit_part(&self, _t: f64, y: ArrayView1<f64>) -> Array1<f64> {
            array![self.lambda_stiff * y[0]]
        }

        fn jacobian_implicit(&self, _t: f64, _y: ArrayView1<f64>) -> Array2<f64> {
            Array2::from_elem((1, 1), self.lambda_stiff)
        }
    }

    /// Problem shared by the second-order tests: y' = -5y (implicit) - y (explicit).
    fn mixed_decay() -> StiffLinear {
        StiffLinear {
            lambda_stiff: -5.0,
            lambda_nonstiff: -1.0,
        }
    }

    /// Fixed-step configuration with a tight Newton tolerance; every other
    /// field keeps its default.
    fn tight_cfg(dt: f64, t_end: f64) -> IMEXConfig<f64> {
        IMEXConfig {
            dt,
            t_end,
            newton_tol: 1e-12,
            ..IMEXConfig::default()
        }
    }

    /// Last stored time and the first solution component at that time.
    fn final_point(result: &IMEXResult<f64>) -> (f64, f64) {
        let t_final = *result.t.last().expect("no time points");
        let y_final = result.y.last().expect("no solution")[0];
        (t_final, y_final)
    }

    #[test]
    fn test_imex_euler_decay() {
        // dy/dt = -10*y (stiff) + 0*y (nonstiff), y(0) = 1
        // exact: y(t) = exp(-10*t)
        let sys = StiffLinear {
            lambda_stiff: -10.0,
            lambda_nonstiff: 0.0,
        };
        let cfg = tight_cfg(0.01, 1.0);
        let result = imex_euler(&sys, 0.0, array![1.0], &cfg).expect("IMEX Euler failed");

        let (t_final, y_final) = final_point(&result);
        let exact = (-10.0_f64 * t_final).exp();
        let err = (y_final - exact).abs();

        assert!(
            err < 0.05,
            "IMEX Euler: y={} exact={} err={}",
            y_final,
            exact,
            err
        );
    }

    #[test]
    fn test_imex_bdf2_decay() {
        let sys = mixed_decay();
        let cfg = tight_cfg(0.01, 0.5);
        let result = imex_bdf2(&sys, 0.0, array![1.0], &cfg).expect("IMEX BDF2 failed");

        let (t_final, y_final) = final_point(&result);
        let exact = (-6.0_f64 * t_final).exp();
        let err = (y_final - exact).abs();

        assert!(
            err < 0.02,
            "IMEX BDF2: y={} exact={} err={}",
            y_final,
            exact,
            err
        );
    }

    #[test]
    fn test_imex_ark_ssp2_decay() {
        let sys = mixed_decay();
        let cfg = tight_cfg(0.01, 0.5);
        let result = imex_ark_ssp2(&sys, 0.0, array![1.0], &cfg).expect("IMEX ARK SSP2 failed");

        let (t_final, y_final) = final_point(&result);
        let exact = (-6.0_f64 * t_final).exp();
        let err = (y_final - exact).abs();

        assert!(
            err < 0.01,
            "IMEX ARK SSP2: y={} exact={} err={}",
            y_final,
            exact,
            err
        );
    }

    #[test]
    fn test_imex_ark_ssp3_decay() {
        let sys = mixed_decay();
        let cfg = tight_cfg(0.01, 0.5);
        let result = imex_ark_ssp3(&sys, 0.0, array![1.0], &cfg).expect("IMEX ARK SSP3 failed");

        let (t_final, y_final) = final_point(&result);
        let exact = (-6.0_f64 * t_final).exp();
        let err = (y_final - exact).abs();

        assert!(
            err < 0.01,
            "IMEX ARK SSP3: y={} exact={} err={}",
            y_final,
            exact,
            err
        );
    }

    #[test]
    fn test_imex_midpoint_decay() {
        let sys = mixed_decay();
        let cfg = tight_cfg(0.01, 0.5);
        let result = imex_midpoint(&sys, 0.0, array![1.0], &cfg).expect("IMEX Midpoint failed");

        let (t_final, y_final) = final_point(&result);
        let exact = (-6.0_f64 * t_final).exp();
        let err = (y_final - exact).abs();

        assert!(
            err < 0.01,
            "IMEX Midpoint: y={} exact={} err={}",
            y_final,
            exact,
            err
        );
    }
}