Skip to main content

scirs2_integrate/ode/
types.rs

//! Types for the ODE solver module.
//!
//! This module defines the core types used by the ODE solvers: the solver-method
//! enum, mass-matrix descriptions, solver options, and integration results.
5
6use crate::common::IntegrateFloat;
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
8use std::fmt::Debug;
9use std::sync::Arc;
10
/// Type alias for a time-dependent matrix function `t -> M(t)`.
///
/// Stored behind an `Arc` (and required to be `Send + Sync`) so that cloning a
/// [`MassMatrix`] shares the same callback rather than copying it.
pub type TimeFunction<F> = Arc<dyn Fn(F) -> Array2<F> + Send + Sync>;
13
/// Type alias for a state-dependent matrix function `(t, y) -> M(t, y)`.
///
/// Receives the current time and a read-only view of the state vector. Like
/// [`TimeFunction`], it is stored behind an `Arc` so clones share the callback.
pub type StateFunction<F> = Arc<dyn Fn(F, ArrayView1<F>) -> Array2<F> + Send + Sync>;
16
/// ODE solver method.
///
/// Selects the integration algorithm used by the solvers. [`ODEMethod::RK45`] is the
/// default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ODEMethod {
    /// Euler method (first order).
    Euler,
    /// Classical fourth-order Runge-Kutta method (fixed step size).
    RK4,
    /// Dormand-Prince method (variable step size).
    /// 5th-order method with a 4th-order error estimate.
    /// This is the default method.
    #[default]
    RK45,
    /// Bogacki-Shampine method (variable step size).
    /// 3rd-order method with a 2nd-order error estimate.
    RK23,
    /// Backward Differentiation Formula (BDF) method.
    /// Implicit method for stiff equations.
    /// Default is BDF order 2; the maximum order can be capped via `ODEOptions::max_order`.
    Bdf,
    /// Dormand-Prince method of order 8(5,3).
    /// 8th-order explicit Runge-Kutta method with embedded lower-order (5th/3rd)
    /// error estimation; intended for high-accuracy non-stiff integration.
    DOP853,
    /// Implicit Runge-Kutta method of the Radau IIA family.
    /// 5th-order method with a 3rd-order error estimate.
    /// L-stable implicit method for stiff problems.
    Radau,
    /// Livermore Solver for Ordinary Differential Equations with Automatic method switching.
    /// Switches between Adams methods (non-stiff) and BDF (stiff), which makes it
    /// suitable for problems whose stiffness changes during integration.
    LSODA,
    /// Enhanced LSODA with improved stiffness detection and method switching.
    /// Features better Jacobian handling, adaptive order selection, and robust error control,
    /// and provides detailed diagnostics about method-switching decisions.
    EnhancedLSODA,
    /// Enhanced BDF method with improved Jacobian handling and error estimation.
    /// Selects a Jacobian strategy based on problem size, supports multiple Newton solver
    /// variants, and includes specialized handling for banded matrices and adaptive order
    /// selection.
    EnhancedBDF,
}
57
/// Kind of mass matrix `M` in an ODE system of the form `M(t, y)·y' = f(t, y)`.
///
/// Determines which field of [`MassMatrix`] supplies the matrix. `Identity` (a plain
/// ODE `y' = f(t, y)`) is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MassMatrixType {
    /// Identity mass matrix (standard ODE); no matrix data is stored.
    #[default]
    Identity,
    /// Constant mass matrix, independent of `t` and `y`.
    Constant,
    /// Time-dependent mass matrix `M(t)`.
    TimeDependent,
    /// State-dependent mass matrix `M(t, y)`.
    StateDependent,
}
71
/// Mass matrix for an ODE system of the form `M(t, y)·y' = f(t, y)`.
///
/// `matrix_type` selects which of `constant_matrix`, `time_function`, or `state_function`
/// is meaningful. The constructors (`identity`, `constant`, `time_dependent`,
/// `state_dependent`) keep these consistent. Because the fields are public, callers can
/// build an inconsistent value by hand. In that case `evaluate` returns `None` when the
/// field matching `matrix_type` is unset.
pub struct MassMatrix<F: IntegrateFloat> {
    /// Type of the mass matrix; selects which of the fields below is used.
    pub matrix_type: MassMatrixType,
    /// Constant mass matrix (used when `matrix_type` is `Constant`).
    pub constant_matrix: Option<scirs2_core::ndarray::Array2<F>>,
    /// Callback producing `M(t)` (used when `matrix_type` is `TimeDependent`).
    pub time_function: Option<TimeFunction<F>>,
    /// Callback producing `M(t, y)` (used when `matrix_type` is `StateDependent`).
    pub state_function: Option<StateFunction<F>>,
    /// Whether the mass matrix is marked as banded (set by `with_bandwidth`).
    pub is_banded: bool,
    /// Lower bandwidth for banded matrices (`None` unless banded).
    pub lower_bandwidth: Option<usize>,
    /// Upper bandwidth for banded matrices (`None` unless banded).
    pub upper_bandwidth: Option<usize>,
}
89
90impl<F: IntegrateFloat> Debug for MassMatrix<F> {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        f.debug_struct("MassMatrix")
93            .field("matrix_type", &self.matrix_type)
94            .field("constant_matrix", &self.constant_matrix)
95            .field("time_function", &self.time_function.is_some())
96            .field("state_function", &self.state_function.is_some())
97            .field("is_banded", &self.is_banded)
98            .field("lower_bandwidth", &self.lower_bandwidth)
99            .field("upper_bandwidth", &self.upper_bandwidth)
100            .finish()
101    }
102}
103
104impl<F: IntegrateFloat> Clone for MassMatrix<F> {
105    fn clone(&self) -> Self {
106        MassMatrix {
107            matrix_type: self.matrix_type,
108            constant_matrix: self.constant_matrix.clone(),
109            time_function: self.time_function.clone(),
110            state_function: self.state_function.clone(),
111            is_banded: self.is_banded,
112            lower_bandwidth: self.lower_bandwidth,
113            upper_bandwidth: self.upper_bandwidth,
114        }
115    }
116}
117
118impl<F: IntegrateFloat> MassMatrix<F> {
119    /// Create a new identity mass matrix (standard ODE)
120    pub fn identity() -> Self {
121        MassMatrix {
122            matrix_type: MassMatrixType::Identity,
123            constant_matrix: None,
124            time_function: None,
125            state_function: None,
126            is_banded: false,
127            lower_bandwidth: None,
128            upper_bandwidth: None,
129        }
130    }
131
132    /// Create a new constant mass matrix
133    pub fn constant(matrix: scirs2_core::ndarray::Array2<F>) -> Self {
134        MassMatrix {
135            matrix_type: MassMatrixType::Constant,
136            constant_matrix: Some(matrix),
137            time_function: None,
138            state_function: None,
139            is_banded: false,
140            lower_bandwidth: None,
141            upper_bandwidth: None,
142        }
143    }
144
145    /// Create a new time-dependent mass matrix M(t)
146    pub fn time_dependent<Func>(func: Func) -> Self
147    where
148        Func: Fn(F) -> scirs2_core::ndarray::Array2<F> + Send + Sync + 'static,
149    {
150        MassMatrix {
151            matrix_type: MassMatrixType::TimeDependent,
152            constant_matrix: None,
153            time_function: Some(Arc::new(func)),
154            state_function: None,
155            is_banded: false,
156            lower_bandwidth: None,
157            upper_bandwidth: None,
158        }
159    }
160
161    /// Create a new state-dependent mass matrix M(t,y)
162    pub fn state_dependent<Func>(func: Func) -> Self
163    where
164        Func: Fn(F, scirs2_core::ndarray::ArrayView1<F>) -> scirs2_core::ndarray::Array2<F>
165            + Send
166            + Sync
167            + 'static,
168    {
169        MassMatrix {
170            matrix_type: MassMatrixType::StateDependent,
171            constant_matrix: None,
172            time_function: None,
173            state_function: Some(Arc::new(func)),
174            is_banded: false,
175            lower_bandwidth: None,
176            upper_bandwidth: None,
177        }
178    }
179
180    /// Set the matrix as banded with specified bandwidths
181    pub fn with_bandwidth(&mut self, lower: usize, upper: usize) -> &mut Self {
182        self.is_banded = true;
183        self.lower_bandwidth = Some(lower);
184        self.upper_bandwidth = Some(upper);
185        self
186    }
187
188    /// Get the mass matrix at a given time and state
189    pub fn evaluate(
190        &self,
191        t: F,
192        y: scirs2_core::ndarray::ArrayView1<F>,
193    ) -> Option<scirs2_core::ndarray::Array2<F>> {
194        match self.matrix_type {
195            MassMatrixType::Identity => None, // Identity is handled specially
196            MassMatrixType::Constant => self.constant_matrix.clone(),
197            MassMatrixType::TimeDependent => self.time_function.as_ref().map(|f| f(t)),
198            MassMatrixType::StateDependent => self.state_function.as_ref().map(|f| f(t, y)),
199        }
200    }
201}
202
/// Options for controlling the behavior of ODE solvers.
///
/// See the `Default` implementation for the values used when a field is not set explicitly.
#[derive(Debug, Clone)]
pub struct ODEOptions<F: IntegrateFloat> {
    /// The ODE solver method to use.
    pub method: ODEMethod,
    /// Relative tolerance for error control.
    pub rtol: F,
    /// Absolute tolerance for error control.
    pub atol: F,
    /// Initial step size (optional; if not provided, it will be estimated).
    pub h0: Option<F>,
    /// Maximum number of steps to take.
    pub max_steps: usize,
    /// Maximum step size (optional).
    pub max_step: Option<F>,
    /// Minimum step size (optional).
    pub min_step: Option<F>,
    /// Whether to enable dense output.
    pub dense_output: bool,
    /// Maximum order for the BDF method (1-5).
    pub max_order: Option<usize>,
    /// Jacobian (optional, for implicit methods).
    // NOTE(review): documented as a Jacobian *matrix* but typed `Array1<F>`; confirm the
    // expected layout (e.g. flattened storage?) with the solvers that read this field.
    pub jac: Option<Array1<F>>,
    /// Whether to use a banded Jacobian matrix.
    pub use_banded_jacobian: bool,
    /// Number of lower diagonals for a banded Jacobian.
    pub ml: Option<usize>,
    /// Number of upper diagonals for a banded Jacobian.
    pub mu: Option<usize>,
    /// Mass matrix for the `M(t, y)·y' = f(t, y)` form (optional).
    pub mass_matrix: Option<MassMatrix<F>>,
    /// Strategy for Jacobian approximation/computation (optional).
    pub jacobian_strategy: Option<crate::ode::utils::jacobian::JacobianStrategy>,
}
237
238impl<F: IntegrateFloat> Default for ODEOptions<F> {
239    fn default() -> Self {
240        ODEOptions {
241            method: ODEMethod::default(),
242            rtol: F::from_f64(1e-3).expect("Operation failed"),
243            atol: F::from_f64(1e-6).expect("Operation failed"),
244            h0: None,
245            max_steps: 500,
246            max_step: None,
247            min_step: None,
248            dense_output: false,
249            max_order: None,
250            jac: None,
251            use_banded_jacobian: false,
252            ml: None,
253            mu: None,
254            mass_matrix: None,
255            jacobian_strategy: None, // Defaults to Adaptive in JacobianManager
256        }
257    }
258}
259
/// Result of an ODE integration.
///
/// `t` and `y` hold the output time points and the corresponding solution vectors. The
/// counters record the work performed by the solver.
#[derive(Debug, Clone)]
pub struct ODEResult<F: IntegrateFloat> {
    /// Time points.
    pub t: Vec<F>,
    /// Solution values at the time points in `t`.
    pub y: Vec<Array1<F>>,
    /// Whether the integration was successful.
    pub success: bool,
    /// Status message (optional).
    pub message: Option<String>,
    /// Number of right-hand-side function evaluations.
    pub n_eval: usize,
    /// Number of steps taken.
    pub n_steps: usize,
    /// Number of accepted steps.
    pub n_accepted: usize,
    /// Number of rejected steps.
    pub n_rejected: usize,
    /// Number of LU decompositions.
    pub n_lu: usize,
    /// Number of Jacobian evaluations.
    pub n_jac: usize,
    /// The solver method that was used.
    pub method: ODEMethod,
}