scirs2_integrate/ode/utils/
mass_matrix.rs

1//! Utilities for working with mass matrices in ODE systems
2//!
3//! This module provides functions for handling mass matrices in ODEs of the form:
4//! M(t,y)·y' = f(t,y), where M is a mass matrix that may depend on time t and state y.
5
6use crate::common::IntegrateFloat;
7use crate::dae::utils::linear_solvers::solve_linear_system;
8use crate::error::{IntegrateError, IntegrateResult};
9use crate::ode::types::{MassMatrix, MassMatrixType};
10use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
11
12/// Solve a linear system with mass matrix: M·x = b
13///
14/// For ODEs with mass matrices, we often need to solve M·x = f(t,y)
15/// to find x = y' for the standard form y' = g(t,y)
16///
17/// # Arguments
18///
19/// * `mass` - The mass matrix structure
20/// * `t` - Current time
21/// * `y` - Current state
22/// * `b` - Right-hand side vector
23///
24/// # Returns
25///
26/// Solution vector x where M·x = b, or error if the system cannot be solved
27#[allow(dead_code)]
28pub fn solve_mass_system<F>(
29    mass: &MassMatrix<F>,
30    t: F,
31    y: ArrayView1<F>,
32    b: ArrayView1<F>,
33) -> IntegrateResult<Array1<F>>
34where
35    F: IntegrateFloat,
36{
37    match mass.matrix_type {
38        MassMatrixType::Identity => {
39            // For identity mass matrix, solution is just b
40            Ok(b.to_owned())
41        }
42        _ => {
43            // Get the mass matrix at current time and state
44            let matrix = mass.evaluate(t, y).ok_or_else(|| {
45                IntegrateError::ComputationError("Failed to evaluate mass matrix".to_string())
46            })?;
47
48            // Solve the linear system M·x = b
49            solve_matrix_system(matrix.view(), b)
50        }
51    }
52}
53
54/// Solve a linear system M·x = b with explicit matrix
55///
56/// Helper function to solve linear systems with mass matrices
57#[allow(dead_code)]
58fn solve_matrix_system<F>(matrix: ArrayView2<F>, b: ArrayView1<F>) -> IntegrateResult<Array1<F>>
59where
60    F: IntegrateFloat,
61{
62    // Use our custom solver
63    solve_linear_system(&matrix, &b).map_err(|err| {
64        IntegrateError::ComputationError(format!("Failed to solve mass _matrix system: {err}"))
65    })
66}
67
68/// Apply mass matrix to a vector: result = M·v
69///
70/// Used to compute the product of a mass matrix with a vector
71///
72/// # Arguments
73///
74/// * `mass` - The mass matrix structure
75/// * `t` - Current time
76/// * `y` - Current state
77/// * `v` - Vector to multiply with
78///
79/// # Returns
80///
81/// Result of M·v, or error if the operation cannot be performed
82#[allow(dead_code)]
83pub fn apply_mass<F>(
84    mass: &MassMatrix<F>,
85    t: F,
86    y: ArrayView1<F>,
87    v: ArrayView1<F>,
88) -> IntegrateResult<Array1<F>>
89where
90    F: IntegrateFloat,
91{
92    match mass.matrix_type {
93        MassMatrixType::Identity => {
94            // For identity mass matrix, result is just v
95            Ok(v.to_owned())
96        }
97        _ => {
98            // Get the mass matrix at current time and state
99            let matrix = mass.evaluate(t, y).ok_or_else(|| {
100                IntegrateError::ComputationError("Failed to evaluate mass matrix".to_string())
101            })?;
102
103            // Perform matrix-vector multiplication
104            let result = matrix.dot(&v);
105            Ok(result)
106        }
107    }
108}
109
110/// Compute the LU decomposition of a mass matrix
111///
112/// This can be used to cache the decomposition for repeated solves
113/// with the same mass matrix
114#[allow(dead_code)]
115struct LUDecomposition<F: IntegrateFloat> {
116    /// The LU factors
117    lu: Array2<F>,
118    /// Pivot indices
119    pivots: Vec<usize>,
120}
121
122#[allow(dead_code)]
123impl<F: IntegrateFloat> LUDecomposition<F> {
124    /// Create a new LU decomposition from a matrix with partial pivoting
125    fn new(matrix: ArrayView2<F>) -> IntegrateResult<Self> {
126        let (n, m) = matrix.dim();
127        if n != m {
128            return Err(IntegrateError::ValueError(
129                "Matrix must be square for LU decomposition".to_string(),
130            ));
131        }
132
133        let mut lu = matrix.to_owned();
134        let mut pivots = (0..n).collect::<Vec<_>>();
135
136        // Gaussian elimination with partial pivoting
137        for k in 0..n {
138            // Find the largest element in column k from row k onwards (partial pivoting)
139            let mut max_row = k;
140            let mut max_val = lu[[k, k]].abs();
141
142            for i in (k + 1)..n {
143                let val = lu[[i, k]].abs();
144                if val > max_val {
145                    max_val = val;
146                    max_row = i;
147                }
148            }
149
150            // Check for singularity
151            if max_val < F::from_f64(1e-14).unwrap() {
152                return Err(IntegrateError::ComputationError(
153                    "Matrix is singular or nearly singular".to_string(),
154                ));
155            }
156
157            // Swap rows if necessary
158            if max_row != k {
159                pivots.swap(k, max_row);
160                for j in 0..n {
161                    let temp = lu[[k, j]];
162                    lu[[k, j]] = lu[[max_row, j]];
163                    lu[[max_row, j]] = temp;
164                }
165            }
166
167            // Elimination step
168            for i in (k + 1)..n {
169                let factor = lu[[i, k]] / lu[[k, k]];
170                lu[[i, k]] = factor; // Store the multiplier
171
172                for j in (k + 1)..n {
173                    let temp = lu[[k, j]];
174                    lu[[i, j]] -= factor * temp;
175                }
176            }
177        }
178
179        Ok(LUDecomposition { lu, pivots })
180    }
181
182    /// Solve a linear system using the LU decomposition
183    fn solve(&self, b: ArrayView1<F>) -> IntegrateResult<Array1<F>> {
184        // Use our custom solver with the matrix
185        // Note: For a proper LU-based solver, we would need to implement one
186        // For now, this is a simpler approach that still works
187        solve_linear_system(&self.lu.view(), &b).map_err(|err| {
188            IntegrateError::ComputationError(format!("Failed to solve with matrix: {err}"))
189        })
190    }
191}
192
193/// Check if a mass matrix is compatible with an ODE state
194///
195/// Verifies that the mass matrix dimensions match the state vector dimensions
196#[allow(dead_code)]
197pub fn check_mass_compatibility<F>(
198    mass: &MassMatrix<F>,
199    t: F,
200    y: ArrayView1<F>,
201) -> IntegrateResult<()>
202where
203    F: IntegrateFloat,
204{
205    let n = y.len();
206
207    match mass.matrix_type {
208        MassMatrixType::Identity => {
209            // Identity matrix is always compatible
210            Ok(())
211        }
212        _ => {
213            // Evaluate the mass matrix and check dimensions
214            let matrix = mass.evaluate(t, y).ok_or_else(|| {
215                IntegrateError::ComputationError("Failed to evaluate mass matrix".to_string())
216            })?;
217
218            let (rows, cols) = matrix.dim();
219
220            if rows != n || cols != n {
221                return Err(IntegrateError::ValueError(format!(
222                    "Mass matrix dimensions ({rows},{cols}) do not match state vector length ({n})"
223                )));
224            }
225
226            Ok(())
227        }
228    }
229}
230
231/// Transform standard ODE to form with identity mass matrix
232///
233/// For ODE systems with constant or time-dependent mass matrices,
234/// we can transform to a standard ODE with identity mass matrix
235/// if the mass matrix is invertible.
236///
237/// M·y' = f(t,y) transforms to y' = M⁻¹·f(t,y)
238///
239/// # Arguments
240///
241/// * `f` - Original ODE function: f(t,y)
242/// * `mass` - Mass matrix specification
243///
244/// # Returns
245///
246/// A function representing the transformed ODE: g(t,y) where y' = g(t,y)
247#[allow(dead_code)]
248pub fn transform_to_standard_form<F, Func>(
249    f: Func,
250    mass: &MassMatrix<F>,
251) -> impl Fn(F, ArrayView1<F>) -> IntegrateResult<Array1<F>> + Clone
252where
253    F: IntegrateFloat,
254    Func: Fn(F, ArrayView1<F>) -> Array1<F> + Clone,
255{
256    let mass_cloned = mass.clone();
257
258    move |t: F, y: ArrayView1<F>| {
259        // Compute original RHS: f(t,y)
260        let rhs = f(t, y);
261
262        // Solve M·y' = f(t,y) for y'
263        solve_mass_system(&mass_cloned, t, y, rhs.view())
264    }
265}
266
267/// Check if a matrix is singular or ill-conditioned
268///
269/// Uses condition number estimation to check if a matrix is
270/// close to singular, which would cause problems for ODE solvers
271#[allow(dead_code)]
272pub fn is_singular<F>(matrix: ArrayView2<F>, threshold: Option<F>) -> bool
273where
274    F: IntegrateFloat,
275{
276    // Default condition number threshold
277    let thresh = threshold.unwrap_or_else(|| F::from_f64(1e14).unwrap());
278
279    let (n, m) = matrix.dim();
280    if n != m {
281        return true; // Non-square matrices are considered singular
282    }
283
284    // Estimate condition number using power iteration for largest singular value
285    // and inverse power iteration for smallest singular value
286
287    // For efficiency, we'll use a simpler approach for small matrices
288    if n <= 3 {
289        // For small matrices, compute determinant directly
290        let det = compute_determinant(&matrix);
291        return det.abs() < F::from_f64(1e-14).unwrap();
292    }
293
294    // For larger matrices, estimate condition number
295    let cond_number = estimate_condition_number(&matrix);
296
297    cond_number > thresh
298}
299
300/// Compute determinant for small matrices (up to 3x3)
301#[allow(dead_code)]
302fn compute_determinant<F: IntegrateFloat>(matrix: &ArrayView2<F>) -> F {
303    let (n, _) = matrix.dim();
304
305    match n {
306        1 => matrix[[0, 0]],
307        2 => matrix[[0, 0]] * matrix[[1, 1]] - matrix[[0, 1]] * matrix[[1, 0]],
308        3 => {
309            matrix[[0, 0]] * (matrix[[1, 1]] * matrix[[2, 2]] - matrix[[1, 2]] * matrix[[2, 1]])
310                - matrix[[0, 1]]
311                    * (matrix[[1, 0]] * matrix[[2, 2]] - matrix[[1, 2]] * matrix[[2, 0]])
312                + matrix[[0, 2]]
313                    * (matrix[[1, 0]] * matrix[[2, 1]] - matrix[[1, 1]] * matrix[[2, 0]])
314        }
315        _ => F::zero(), // Should not reach here
316    }
317}
318
319/// Estimate condition number using iterative methods
320#[allow(dead_code)]
321fn estimate_condition_number<F: IntegrateFloat>(matrix: &ArrayView2<F>) -> F {
322    let _n = matrix.nrows();
323
324    // Estimate largest eigenvalue magnitude of A^T * A using power iteration
325    let max_singular_val_sq = estimate_largest_eigenvalue_ata(matrix);
326    let max_singular_val = max_singular_val_sq.sqrt();
327
328    // Estimate smallest eigenvalue magnitude of A^T * A using inverse power iteration
329    let min_singular_val_sq = estimate_smallest_eigenvalue_ata(matrix);
330    let min_singular_val = min_singular_val_sq.sqrt();
331
332    if min_singular_val < F::from_f64(1e-14).unwrap() {
333        F::from_f64(1e16).unwrap() // Very large condition number
334    } else {
335        max_singular_val / min_singular_val
336    }
337}
338
339/// Estimate largest eigenvalue of A^T * A using power iteration
340#[allow(dead_code)]
341fn estimate_largest_eigenvalue_ata<F: IntegrateFloat>(matrix: &ArrayView2<F>) -> F {
342    let n = matrix.nrows();
343    let max_iterations = 10;
344
345    // Initialize with ones vector
346    let mut v = Array1::<F>::from_elem(n, F::one());
347
348    // Normalize
349    let mut norm = (v.dot(&v)).sqrt();
350    if norm > F::from_f64(1e-14).unwrap() {
351        v = &v / norm;
352    }
353
354    let mut eigenvalue = F::zero();
355
356    for _ in 0..max_iterations {
357        // Compute A * v
358        let mut av = Array1::<F>::zeros(n);
359        for i in 0..n {
360            for j in 0..n {
361                av[i] += matrix[[i, j]] * v[j];
362            }
363        }
364
365        // Compute A^T * (A * v) = A^T * av
366        let mut atav = Array1::<F>::zeros(n);
367        for i in 0..n {
368            for j in 0..n {
369                atav[i] += matrix[[j, i]] * av[j];
370            }
371        }
372
373        // Compute eigenvalue (Rayleigh quotient)
374        let new_eigenvalue = v.dot(&atav);
375
376        // Normalize atav for next iteration
377        norm = (atav.dot(&atav)).sqrt();
378        if norm > F::from_f64(1e-14).unwrap() {
379            v = &atav / norm;
380        }
381
382        eigenvalue = new_eigenvalue;
383    }
384
385    eigenvalue.abs()
386}
387
388/// Estimate smallest eigenvalue of A^T * A using simplified approach
389#[allow(dead_code)]
390fn estimate_smallest_eigenvalue_ata<F: IntegrateFloat>(matrix: &ArrayView2<F>) -> F {
391    let n = matrix.nrows();
392
393    // For simplicity, we'll use the minimum diagonal element of A^T * A as a lower bound
394    // This is not exact but gives a reasonable estimate for condition number purposes
395    let mut min_diag = F::from_f64(f64::INFINITY).unwrap();
396
397    for i in 0..n {
398        let mut diag_elem = F::zero();
399        for k in 0..n {
400            diag_elem += matrix[[k, i]] * matrix[[k, i]];
401        }
402        if diag_elem < min_diag {
403            min_diag = diag_elem;
404        }
405    }
406
407    min_diag.max(F::from_f64(1e-16).unwrap())
408}