scirs2_integrate/ode/utils/
mass_matrix.rs1use crate::common::IntegrateFloat;
7use crate::dae::utils::linear_solvers::solve_linear_system;
8use crate::error::{IntegrateError, IntegrateResult};
9use crate::ode::types::{MassMatrix, MassMatrixType};
10use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
11
12#[allow(dead_code)]
28pub fn solve_mass_system<F>(
29 mass: &MassMatrix<F>,
30 t: F,
31 y: ArrayView1<F>,
32 b: ArrayView1<F>,
33) -> IntegrateResult<Array1<F>>
34where
35 F: IntegrateFloat,
36{
37 match mass.matrix_type {
38 MassMatrixType::Identity => {
39 Ok(b.to_owned())
41 }
42 _ => {
43 let matrix = mass.evaluate(t, y).ok_or_else(|| {
45 IntegrateError::ComputationError("Failed to evaluate mass matrix".to_string())
46 })?;
47
48 solve_matrix_system(matrix.view(), b)
50 }
51 }
52}
53
54#[allow(dead_code)]
58fn solve_matrix_system<F>(matrix: ArrayView2<F>, b: ArrayView1<F>) -> IntegrateResult<Array1<F>>
59where
60 F: IntegrateFloat,
61{
62 solve_linear_system(&matrix, &b).map_err(|err| {
64 IntegrateError::ComputationError(format!("Failed to solve mass _matrix system: {err}"))
65 })
66}
67
68#[allow(dead_code)]
83pub fn apply_mass<F>(
84 mass: &MassMatrix<F>,
85 t: F,
86 y: ArrayView1<F>,
87 v: ArrayView1<F>,
88) -> IntegrateResult<Array1<F>>
89where
90 F: IntegrateFloat,
91{
92 match mass.matrix_type {
93 MassMatrixType::Identity => {
94 Ok(v.to_owned())
96 }
97 _ => {
98 let matrix = mass.evaluate(t, y).ok_or_else(|| {
100 IntegrateError::ComputationError("Failed to evaluate mass matrix".to_string())
101 })?;
102
103 let result = matrix.dot(&v);
105 Ok(result)
106 }
107 }
108}
109
110#[allow(dead_code)]
115struct LUDecomposition<F: IntegrateFloat> {
116 lu: Array2<F>,
118 pivots: Vec<usize>,
120}
121
122#[allow(dead_code)]
123impl<F: IntegrateFloat> LUDecomposition<F> {
124 fn new(matrix: ArrayView2<F>) -> IntegrateResult<Self> {
126 let (n, m) = matrix.dim();
127 if n != m {
128 return Err(IntegrateError::ValueError(
129 "Matrix must be square for LU decomposition".to_string(),
130 ));
131 }
132
133 let mut lu = matrix.to_owned();
134 let mut pivots = (0..n).collect::<Vec<_>>();
135
136 for k in 0..n {
138 let mut max_row = k;
140 let mut max_val = lu[[k, k]].abs();
141
142 for i in (k + 1)..n {
143 let val = lu[[i, k]].abs();
144 if val > max_val {
145 max_val = val;
146 max_row = i;
147 }
148 }
149
150 if max_val < F::from_f64(1e-14).unwrap() {
152 return Err(IntegrateError::ComputationError(
153 "Matrix is singular or nearly singular".to_string(),
154 ));
155 }
156
157 if max_row != k {
159 pivots.swap(k, max_row);
160 for j in 0..n {
161 let temp = lu[[k, j]];
162 lu[[k, j]] = lu[[max_row, j]];
163 lu[[max_row, j]] = temp;
164 }
165 }
166
167 for i in (k + 1)..n {
169 let factor = lu[[i, k]] / lu[[k, k]];
170 lu[[i, k]] = factor; for j in (k + 1)..n {
173 let temp = lu[[k, j]];
174 lu[[i, j]] -= factor * temp;
175 }
176 }
177 }
178
179 Ok(LUDecomposition { lu, pivots })
180 }
181
182 fn solve(&self, b: ArrayView1<F>) -> IntegrateResult<Array1<F>> {
184 solve_linear_system(&self.lu.view(), &b).map_err(|err| {
188 IntegrateError::ComputationError(format!("Failed to solve with matrix: {err}"))
189 })
190 }
191}
192
193#[allow(dead_code)]
197pub fn check_mass_compatibility<F>(
198 mass: &MassMatrix<F>,
199 t: F,
200 y: ArrayView1<F>,
201) -> IntegrateResult<()>
202where
203 F: IntegrateFloat,
204{
205 let n = y.len();
206
207 match mass.matrix_type {
208 MassMatrixType::Identity => {
209 Ok(())
211 }
212 _ => {
213 let matrix = mass.evaluate(t, y).ok_or_else(|| {
215 IntegrateError::ComputationError("Failed to evaluate mass matrix".to_string())
216 })?;
217
218 let (rows, cols) = matrix.dim();
219
220 if rows != n || cols != n {
221 return Err(IntegrateError::ValueError(format!(
222 "Mass matrix dimensions ({rows},{cols}) do not match state vector length ({n})"
223 )));
224 }
225
226 Ok(())
227 }
228 }
229}
230
231#[allow(dead_code)]
248pub fn transform_to_standard_form<F, Func>(
249 f: Func,
250 mass: &MassMatrix<F>,
251) -> impl Fn(F, ArrayView1<F>) -> IntegrateResult<Array1<F>> + Clone
252where
253 F: IntegrateFloat,
254 Func: Fn(F, ArrayView1<F>) -> Array1<F> + Clone,
255{
256 let mass_cloned = mass.clone();
257
258 move |t: F, y: ArrayView1<F>| {
259 let rhs = f(t, y);
261
262 solve_mass_system(&mass_cloned, t, y, rhs.view())
264 }
265}
266
267#[allow(dead_code)]
272pub fn is_singular<F>(matrix: ArrayView2<F>, threshold: Option<F>) -> bool
273where
274 F: IntegrateFloat,
275{
276 let thresh = threshold.unwrap_or_else(|| F::from_f64(1e14).unwrap());
278
279 let (n, m) = matrix.dim();
280 if n != m {
281 return true; }
283
284 if n <= 3 {
289 let det = compute_determinant(&matrix);
291 return det.abs() < F::from_f64(1e-14).unwrap();
292 }
293
294 let cond_number = estimate_condition_number(&matrix);
296
297 cond_number > thresh
298}
299
300#[allow(dead_code)]
302fn compute_determinant<F: IntegrateFloat>(matrix: &ArrayView2<F>) -> F {
303 let (n, _) = matrix.dim();
304
305 match n {
306 1 => matrix[[0, 0]],
307 2 => matrix[[0, 0]] * matrix[[1, 1]] - matrix[[0, 1]] * matrix[[1, 0]],
308 3 => {
309 matrix[[0, 0]] * (matrix[[1, 1]] * matrix[[2, 2]] - matrix[[1, 2]] * matrix[[2, 1]])
310 - matrix[[0, 1]]
311 * (matrix[[1, 0]] * matrix[[2, 2]] - matrix[[1, 2]] * matrix[[2, 0]])
312 + matrix[[0, 2]]
313 * (matrix[[1, 0]] * matrix[[2, 1]] - matrix[[1, 1]] * matrix[[2, 0]])
314 }
315 _ => F::zero(), }
317}
318
319#[allow(dead_code)]
321fn estimate_condition_number<F: IntegrateFloat>(matrix: &ArrayView2<F>) -> F {
322 let _n = matrix.nrows();
323
324 let max_singular_val_sq = estimate_largest_eigenvalue_ata(matrix);
326 let max_singular_val = max_singular_val_sq.sqrt();
327
328 let min_singular_val_sq = estimate_smallest_eigenvalue_ata(matrix);
330 let min_singular_val = min_singular_val_sq.sqrt();
331
332 if min_singular_val < F::from_f64(1e-14).unwrap() {
333 F::from_f64(1e16).unwrap() } else {
335 max_singular_val / min_singular_val
336 }
337}
338
339#[allow(dead_code)]
341fn estimate_largest_eigenvalue_ata<F: IntegrateFloat>(matrix: &ArrayView2<F>) -> F {
342 let n = matrix.nrows();
343 let max_iterations = 10;
344
345 let mut v = Array1::<F>::from_elem(n, F::one());
347
348 let mut norm = (v.dot(&v)).sqrt();
350 if norm > F::from_f64(1e-14).unwrap() {
351 v = &v / norm;
352 }
353
354 let mut eigenvalue = F::zero();
355
356 for _ in 0..max_iterations {
357 let mut av = Array1::<F>::zeros(n);
359 for i in 0..n {
360 for j in 0..n {
361 av[i] += matrix[[i, j]] * v[j];
362 }
363 }
364
365 let mut atav = Array1::<F>::zeros(n);
367 for i in 0..n {
368 for j in 0..n {
369 atav[i] += matrix[[j, i]] * av[j];
370 }
371 }
372
373 let new_eigenvalue = v.dot(&atav);
375
376 norm = (atav.dot(&atav)).sqrt();
378 if norm > F::from_f64(1e-14).unwrap() {
379 v = &atav / norm;
380 }
381
382 eigenvalue = new_eigenvalue;
383 }
384
385 eigenvalue.abs()
386}
387
388#[allow(dead_code)]
390fn estimate_smallest_eigenvalue_ata<F: IntegrateFloat>(matrix: &ArrayView2<F>) -> F {
391 let n = matrix.nrows();
392
393 let mut min_diag = F::from_f64(f64::INFINITY).unwrap();
396
397 for i in 0..n {
398 let mut diag_elem = F::zero();
399 for k in 0..n {
400 diag_elem += matrix[[k, i]] * matrix[[k, i]];
401 }
402 if diag_elem < min_diag {
403 min_diag = diag_elem;
404 }
405 }
406
407 min_diag.max(F::from_f64(1e-16).unwrap())
408}