1use crate::common::IntegrateFloat;
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
8use std::fmt::Debug;
9use std::sync::Arc;
10
11pub type TimeFunction<F> = Arc<dyn Fn(F) -> Array2<F> + Send + Sync>;
13
14pub type StateFunction<F> = Arc<dyn Fn(F, ArrayView1<F>) -> Array2<F> + Send + Sync>;
16
/// Selects which integration algorithm the solver uses.
///
/// [`ODEMethod::RK45`] is the default variant.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ODEMethod {
    /// Euler method (first order).
    Euler,
    /// Classic fourth-order Runge–Kutta method.
    RK4,
    /// Embedded Runge–Kutta 4(5) pair; the default method.
    #[default]
    RK45,
    /// Embedded Runge–Kutta 2(3) pair.
    RK23,
    /// Backward differentiation formula (BDF) method.
    Bdf,
    /// Dormand–Prince 8(5,3) method.
    DOP853,
    /// Radau implicit Runge–Kutta method.
    Radau,
    /// LSODA-style method (presumably automatic stiff/non-stiff switching — confirm in the solver).
    LSODA,
    /// Enhanced LSODA variant; see the solver implementation for how it differs.
    EnhancedLSODA,
    /// Enhanced BDF variant; see the solver implementation for how it differs.
    EnhancedBDF,
}
57
/// Describes how a [`MassMatrix`] is obtained; `Identity` is the default.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum MassMatrixType {
    /// `M = I`: no explicit matrix is stored.
    #[default]
    Identity,
    /// A single fixed matrix, independent of `t` and `y`.
    Constant,
    /// A matrix computed from the time `t` only.
    TimeDependent,
    /// A matrix computed from both the time `t` and the state `y`.
    StateDependent,
}
71
72pub struct MassMatrix<F: IntegrateFloat> {
74 pub matrix_type: MassMatrixType,
76 pub constant_matrix: Option<scirs2_core::ndarray::Array2<F>>,
78 pub time_function: Option<TimeFunction<F>>,
80 pub state_function: Option<StateFunction<F>>,
82 pub is_banded: bool,
84 pub lower_bandwidth: Option<usize>,
86 pub upper_bandwidth: Option<usize>,
88}
89
90impl<F: IntegrateFloat> Debug for MassMatrix<F> {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 f.debug_struct("MassMatrix")
93 .field("matrix_type", &self.matrix_type)
94 .field("constant_matrix", &self.constant_matrix)
95 .field("time_function", &self.time_function.is_some())
96 .field("state_function", &self.state_function.is_some())
97 .field("is_banded", &self.is_banded)
98 .field("lower_bandwidth", &self.lower_bandwidth)
99 .field("upper_bandwidth", &self.upper_bandwidth)
100 .finish()
101 }
102}
103
104impl<F: IntegrateFloat> Clone for MassMatrix<F> {
105 fn clone(&self) -> Self {
106 MassMatrix {
107 matrix_type: self.matrix_type,
108 constant_matrix: self.constant_matrix.clone(),
109 time_function: self.time_function.clone(),
110 state_function: self.state_function.clone(),
111 is_banded: self.is_banded,
112 lower_bandwidth: self.lower_bandwidth,
113 upper_bandwidth: self.upper_bandwidth,
114 }
115 }
116}
117
118impl<F: IntegrateFloat> MassMatrix<F> {
119 pub fn identity() -> Self {
121 MassMatrix {
122 matrix_type: MassMatrixType::Identity,
123 constant_matrix: None,
124 time_function: None,
125 state_function: None,
126 is_banded: false,
127 lower_bandwidth: None,
128 upper_bandwidth: None,
129 }
130 }
131
132 pub fn constant(matrix: scirs2_core::ndarray::Array2<F>) -> Self {
134 MassMatrix {
135 matrix_type: MassMatrixType::Constant,
136 constant_matrix: Some(matrix),
137 time_function: None,
138 state_function: None,
139 is_banded: false,
140 lower_bandwidth: None,
141 upper_bandwidth: None,
142 }
143 }
144
145 pub fn time_dependent<Func>(func: Func) -> Self
147 where
148 Func: Fn(F) -> scirs2_core::ndarray::Array2<F> + Send + Sync + 'static,
149 {
150 MassMatrix {
151 matrix_type: MassMatrixType::TimeDependent,
152 constant_matrix: None,
153 time_function: Some(Arc::new(func)),
154 state_function: None,
155 is_banded: false,
156 lower_bandwidth: None,
157 upper_bandwidth: None,
158 }
159 }
160
161 pub fn state_dependent<Func>(func: Func) -> Self
163 where
164 Func: Fn(F, scirs2_core::ndarray::ArrayView1<F>) -> scirs2_core::ndarray::Array2<F>
165 + Send
166 + Sync
167 + 'static,
168 {
169 MassMatrix {
170 matrix_type: MassMatrixType::StateDependent,
171 constant_matrix: None,
172 time_function: None,
173 state_function: Some(Arc::new(func)),
174 is_banded: false,
175 lower_bandwidth: None,
176 upper_bandwidth: None,
177 }
178 }
179
180 pub fn with_bandwidth(&mut self, lower: usize, upper: usize) -> &mut Self {
182 self.is_banded = true;
183 self.lower_bandwidth = Some(lower);
184 self.upper_bandwidth = Some(upper);
185 self
186 }
187
188 pub fn evaluate(
190 &self,
191 t: F,
192 y: scirs2_core::ndarray::ArrayView1<F>,
193 ) -> Option<scirs2_core::ndarray::Array2<F>> {
194 match self.matrix_type {
195 MassMatrixType::Identity => None, MassMatrixType::Constant => self.constant_matrix.clone(),
197 MassMatrixType::TimeDependent => self.time_function.as_ref().map(|f| f(t)),
198 MassMatrixType::StateDependent => self.state_function.as_ref().map(|f| f(t, y)),
199 }
200 }
201}
202
203#[derive(Debug, Clone)]
205pub struct ODEOptions<F: IntegrateFloat> {
206 pub method: ODEMethod,
208 pub rtol: F,
210 pub atol: F,
212 pub h0: Option<F>,
214 pub max_steps: usize,
216 pub max_step: Option<F>,
218 pub min_step: Option<F>,
220 pub dense_output: bool,
222 pub max_order: Option<usize>,
224 pub jac: Option<Array1<F>>,
226 pub use_banded_jacobian: bool,
228 pub ml: Option<usize>,
230 pub mu: Option<usize>,
232 pub mass_matrix: Option<MassMatrix<F>>,
234 pub jacobian_strategy: Option<crate::ode::utils::jacobian::JacobianStrategy>,
236}
237
238impl<F: IntegrateFloat> Default for ODEOptions<F> {
239 fn default() -> Self {
240 ODEOptions {
241 method: ODEMethod::default(),
242 rtol: F::from_f64(1e-3).expect("Operation failed"),
243 atol: F::from_f64(1e-6).expect("Operation failed"),
244 h0: None,
245 max_steps: 500,
246 max_step: None,
247 min_step: None,
248 dense_output: false,
249 max_order: None,
250 jac: None,
251 use_banded_jacobian: false,
252 ml: None,
253 mu: None,
254 mass_matrix: None,
255 jacobian_strategy: None, }
257 }
258}
259
260#[derive(Debug, Clone)]
262pub struct ODEResult<F: IntegrateFloat> {
263 pub t: Vec<F>,
265 pub y: Vec<Array1<F>>,
267 pub success: bool,
269 pub message: Option<String>,
271 pub n_eval: usize,
273 pub n_steps: usize,
275 pub n_accepted: usize,
277 pub n_rejected: usize,
279 pub n_lu: usize,
281 pub n_jac: usize,
283 pub method: ODEMethod,
285}