scirs2_integrate/ode/utils/linear_solvers/mod.rs

//! Linear solvers for ODE systems
//!
//! This module provides linear system solvers for use within ODE solvers.
//! These replace the need for external linear algebra libraries like ndarray-linalg.

use crate::error::{IntegrateError, IntegrateResult};
use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

/// Enum for different types of linear solvers
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LinearSolverType {
    /// Direct solver using LU decomposition
    Direct,
    /// Iterative solver (GMRES, etc.)
    Iterative,
    /// Automatic selection based on problem size
    Auto,
}

/// Solve a linear system Ax = b using Gaussian elimination with partial pivoting
///
/// # Arguments
/// * `a` - The coefficient matrix A
/// * `b` - The right-hand side vector b
///
/// # Returns
/// * `Result<Array1<F>, IntegrateError>` - The solution vector x
#[allow(dead_code)]
pub fn solve_linear_system<F>(a: &ArrayView2<F>, b: &ArrayView1<F>) -> IntegrateResult<Array1<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign,
{
    // Get dimensions
    let n = a.shape()[0];

    // Check that A is square
    if a.shape()[0] != a.shape()[1] {
        return Err(IntegrateError::ValueError(format!(
            "Matrix must be square to solve linear system, got shape {:?}",
            a.shape()
        )));
    }

    // Check that b has compatible dimensions
    if b.len() != n {
        return Err(IntegrateError::ValueError(format!(
            "Right-hand side vector dimensions incompatible with matrix: matrix has {} rows but vector has {} elements",
            n,
            b.len()
        )));
    }

    // Create copies of A and b that we can modify
    let mut a_copy = a.to_owned();
    let mut b_copy = b.to_owned();

    // Gaussian elimination with partial pivoting
    for k in 0..n {
        // Find pivot
        let mut pivot_idx = k;
        let mut max_val = a_copy[[k, k]].abs();

        for i in (k + 1)..n {
            let val = a_copy[[i, k]].abs();
            if val > max_val {
                max_val = val;
                pivot_idx = i;
            }
        }

        // Check for singularity
        if max_val < F::from_f64(1e-14).unwrap() {
            return Err(IntegrateError::ValueError(
                "Matrix is singular or nearly singular".to_string(),
            ));
        }

        // Swap rows if necessary
        if pivot_idx != k {
            // Swap rows in A
            for j in k..n {
                let temp = a_copy[[k, j]];
                a_copy[[k, j]] = a_copy[[pivot_idx, j]];
                a_copy[[pivot_idx, j]] = temp;
            }

            // Swap elements in b
            let temp = b_copy[k];
            b_copy[k] = b_copy[pivot_idx];
            b_copy[pivot_idx] = temp;
        }

        // Eliminate below the pivot
        for i in (k + 1)..n {
            let factor = a_copy[[i, k]] / a_copy[[k, k]];

            // Update the right-hand side
            b_copy[i] = b_copy[i] - factor * b_copy[k];

            // Update the matrix
            a_copy[[i, k]] = F::zero(); // Explicitly set to zero to avoid numerical issues

            for j in (k + 1)..n {
                a_copy[[i, j]] = a_copy[[i, j]] - factor * a_copy[[k, j]];
            }
        }
    }

    // Back-substitution
    let mut x = Array1::<F>::zeros(n);

    for i in (0..n).rev() {
        let mut sum = b_copy[i];

        for j in (i + 1)..n {
            sum -= a_copy[[i, j]] * x[j];
        }

        x[i] = sum / a_copy[[i, i]];
    }

    Ok(x)
}
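
// A minimal sanity-check sketch for `solve_linear_system` on a 2x2 system with a
// known solution. The test name and tolerance are illustrative, and it assumes
// `Array2` is available from `scirs2_core::ndarray` alongside the types imported above.
#[cfg(test)]
#[test]
fn solve_linear_system_small_example() {
    use scirs2_core::ndarray::Array2;

    // A = [[2, 1], [1, 3]], b = [3, 5]  =>  x = [0.8, 1.4]
    let a = Array2::from_shape_vec((2, 2), vec![2.0_f64, 1.0, 1.0, 3.0]).unwrap();
    let b = Array1::from_vec(vec![3.0_f64, 5.0]);

    let x = solve_linear_system(&a.view(), &b.view()).unwrap();
    assert!((x[0] - 0.8).abs() < 1e-12);
    assert!((x[1] - 1.4).abs() < 1e-12);
}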

/// Compute the norm of a vector
///
/// # Arguments
/// * `v` - The vector
///
/// # Returns
/// * The L2 norm of the vector
#[allow(dead_code)]
pub fn vector_norm<F>(v: &ArrayView1<F>) -> F
where
    F: Float,
{
    let mut sum = F::zero();
    for &val in v.iter() {
        sum = sum + val * val;
    }
    sum.sqrt()
}

/// Compute the Frobenius norm of a matrix
///
/// # Arguments
/// * `m` - The matrix
///
/// # Returns
/// * The Frobenius norm of the matrix
#[allow(dead_code)]
pub fn matrix_norm<F>(m: &ArrayView2<F>) -> F
where
    F: Float,
{
    let mut sum = F::zero();
    for &val in m.iter() {
        sum = sum + val * val;
    }
    sum.sqrt()
}
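
// A small illustrative check of the norm helpers above: the L2 norm of (3, 4) is 5,
// and the Frobenius norm of the 2x2 identity is sqrt(2). This is a sketch, not an
// exhaustive test; it assumes `Array2` is re-exported by `scirs2_core::ndarray`.
#[cfg(test)]
#[test]
fn norm_helpers_example() {
    use scirs2_core::ndarray::Array2;

    let v = Array1::from_vec(vec![3.0_f64, 4.0]);
    assert!((vector_norm(&v.view()) - 5.0).abs() < 1e-12);

    let eye = Array2::<f64>::eye(2);
    assert!((matrix_norm(&eye.view()) - 2.0_f64.sqrt()).abs() < 1e-12);
}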

/// Solve a linear system using automatic method selection
#[allow(dead_code)]
pub fn auto_solve_linear_system<F>(
    a: &ArrayView2<F>,
    b: &ArrayView1<F>,
    solver_type: LinearSolverType,
) -> IntegrateResult<Array1<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::default::Default
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + std::ops::DivAssign,
{
    match solver_type {
        LinearSolverType::Direct => solve_linear_system(a, b),
        LinearSolverType::Iterative => {
            // Use GMRES iterative solver
            solve_gmres(a, b, None, None, None)
        }
        LinearSolverType::Auto => {
            // Use direct solver for small problems, iterative for large
            let n = a.shape()[0];
            if n < 100 {
                solve_linear_system(a, b)
            } else {
                // Use GMRES for large systems
                solve_gmres(a, b, None, None, None)
            }
        }
    }
}
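
// A sketch of how the solver-selection enum is intended to be used: for a small
// system (n < 100), `Auto` falls back to the direct solver, so both calls below
// should agree to within rounding. The matrix values and test name are illustrative,
// and `Array2` is assumed to be available from `scirs2_core::ndarray`.
#[cfg(test)]
#[test]
fn auto_solve_matches_direct_for_small_system() {
    use scirs2_core::ndarray::Array2;

    let a = Array2::from_shape_vec((2, 2), vec![4.0_f64, 1.0, 2.0, 3.0]).unwrap();
    let b = Array1::from_vec(vec![1.0_f64, 2.0]);

    let x_direct =
        auto_solve_linear_system(&a.view(), &b.view(), LinearSolverType::Direct).unwrap();
    let x_auto = auto_solve_linear_system(&a.view(), &b.view(), LinearSolverType::Auto).unwrap();

    for i in 0..2 {
        assert!((x_direct[i] - x_auto[i]).abs() < 1e-12);
    }
}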

/// Solve a linear system using LU decomposition (alias for `solve_linear_system`, kept for compatibility)
#[allow(dead_code)]
pub fn solve_lu<F>(a: &ArrayView2<F>, b: &ArrayView1<F>) -> IntegrateResult<Array1<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign,
{
    solve_linear_system(a, b)
}

/// Solve a linear system using GMRES (Generalized Minimal Residual) method
///
/// GMRES is a robust iterative method for solving general linear systems.
///
/// # Arguments
/// * `a` - The coefficient matrix A
/// * `b` - The right-hand side vector b
/// * `max_iter` - Maximum number of iterations (default: min(n, 50))
/// * `tol` - Convergence tolerance (default: 1e-10)
/// * `restart` - Restart parameter for GMRES(m) (default: min(n, 20))
///
/// # Returns
/// * `Result<Array1<F>, IntegrateError>` - The solution vector x
#[allow(dead_code)]
pub fn solve_gmres<F>(
    a: &ArrayView2<F>,
    b: &ArrayView1<F>,
    max_iter: Option<usize>,
    tol: Option<F>,
    restart: Option<usize>,
) -> IntegrateResult<Array1<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + Default
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + std::ops::DivAssign,
{
    let n = a.nrows();
    if n != a.ncols() {
        return Err(IntegrateError::ValueError(
            "Matrix must be square".to_string(),
        ));
    }
    if n != b.len() {
        return Err(IntegrateError::ValueError(
            "Matrix and vector dimensions must match".to_string(),
        ));
    }

    let max_iter = max_iter.unwrap_or(std::cmp::min(n, 50));
    let tol = tol.unwrap_or_else(|| F::from_f64(1e-10).unwrap());
    let restart = restart.unwrap_or(std::cmp::min(n, 20));

    // Initial guess: zero vector
    let mut x = Array1::<F>::zeros(n);

    // Compute initial residual: r0 = b - A*x0
    let mut r = b.to_owned();
    for i in 0..n {
        let mut ax_i = F::zero();
        for j in 0..n {
            ax_i += a[[i, j]] * x[j];
        }
        r[i] -= ax_i;
    }

    let initial_norm = (r.iter().map(|&val| val * val).sum::<F>()).sqrt();
    if initial_norm < tol {
        return Ok(x); // Already converged
    }

    let mut outer_iter = 0;
    while outer_iter < max_iter {
        // GMRES restart cycle
        let m = std::cmp::min(restart, max_iter - outer_iter);

        // Normalize r to get v1
        let beta = (r.iter().map(|&val| val * val).sum::<F>()).sqrt();
        if beta < tol {
            break; // Converged
        }

        let mut v = vec![Array1::<F>::zeros(n); m + 1];
        v[0] = &r / beta;

        let mut h = vec![vec![F::zero(); m]; m + 1];
        let mut g = vec![F::zero(); m + 1];
        g[0] = beta;

        // Cosine and sine of each Givens rotation, stored so that previously
        // computed rotations can be re-applied to later columns of the Hessenberg
        // matrix H.
        let mut cs = vec![F::zero(); m];
        let mut sn = vec![F::zero(); m];

        let mut j = 0;
        while j < m {
            // Compute w = A * v[j]
            let mut w = Array1::<F>::zeros(n);
            for i in 0..n {
                for k in 0..n {
                    w[i] += a[[i, k]] * v[j][k];
                }
            }

            // Modified Gram-Schmidt orthogonalization
            for i in 0..=j {
                h[i][j] = v[i].dot(&w);
                for k in 0..n {
                    w[k] -= h[i][j] * v[i][k];
                }
            }

            h[j + 1][j] = (w.iter().map(|&val| val * val).sum::<F>()).sqrt();

            if h[j + 1][j] < F::from_f64(1e-14).unwrap() {
                // Linear dependence, stop early
                break;
            }

            v[j + 1] = &w / h[j + 1][j];

            // Apply previously computed Givens rotations to the new column of H
            for i in 0..j {
                let temp = cs[i] * h[i][j] + sn[i] * h[i + 1][j];
                h[i + 1][j] = -sn[i] * h[i][j] + cs[i] * h[i + 1][j];
                h[i][j] = temp;
            }

            // Compute a new Givens rotation that annihilates h[j+1][j]
            let denom = (h[j][j] * h[j][j] + h[j + 1][j] * h[j + 1][j]).sqrt();
            cs[j] = h[j][j] / denom;
            sn[j] = h[j + 1][j] / denom;

            // Apply the new rotation to H and to the right-hand side g
            h[j][j] = cs[j] * h[j][j] + sn[j] * h[j + 1][j];
            h[j + 1][j] = F::zero();

            let temp = cs[j] * g[j];
            g[j + 1] = -sn[j] * g[j];
            g[j] = temp;

            // Check convergence
            if g[j + 1].abs() < tol * initial_norm {
                j += 1;
                break;
            }

            j += 1;
        }

        // Solve the upper triangular system H*y = g
        let mut y = vec![F::zero(); j];
        for i in (0..j).rev() {
            let mut sum = g[i];
            for k in (i + 1)..j {
                sum -= h[i][k] * y[k];
            }
            y[i] = sum / h[i][i];
        }

        // Update solution: x = x + V*y
        for i in 0..n {
            for k in 0..j {
                x[i] += y[k] * v[k][i];
            }
        }

        // Compute new residual
        r = b.to_owned();
        for i in 0..n {
            let mut ax_i = F::zero();
            for k in 0..n {
                ax_i += a[[i, k]] * x[k];
            }
            r[i] -= ax_i;
        }

        let residual_norm = (r.iter().map(|&val| val * val).sum::<F>()).sqrt();
        if residual_norm < tol * initial_norm {
            break; // Converged
        }

        outer_iter += m;
    }

    Ok(x)
}
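
// A minimal sketch comparing GMRES against the direct solver on a small, diagonally
// dominant (hence well-conditioned) system. The matrix entries, tolerance, and test
// name are illustrative; for a 3x3 system the default iteration and restart limits
// give GMRES enough Arnoldi steps to match the direct answer. `Array2` is assumed to
// be available from `scirs2_core::ndarray`.
#[cfg(test)]
#[test]
fn gmres_agrees_with_direct_solver_on_small_system() {
    use scirs2_core::ndarray::Array2;

    let a = Array2::from_shape_vec(
        (3, 3),
        vec![4.0_f64, 1.0, 0.0, 1.0, 5.0, 1.0, 0.0, 1.0, 3.0],
    )
    .unwrap();
    let b = Array1::from_vec(vec![1.0_f64, 2.0, 3.0]);

    let x_direct = solve_linear_system(&a.view(), &b.view()).unwrap();
    let x_gmres = solve_gmres(&a.view(), &b.view(), None, None, None).unwrap();

    for i in 0..3 {
        assert!((x_direct[i] - x_gmres[i]).abs() < 1e-8);
    }
}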