scirs2_integrate/dae/utils/
linear_solvers.rs

1//! Linear solvers for DAE systems
2//!
3//! This module provides linear system solvers for use within DAE solvers.
4//! These replace the need for external linear algebra libraries like ndarray-linalg.
5
6use crate::error::{IntegrateError, IntegrateResult};
7use crate::IntegrateFloat;
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
9// use scirs2_core::numeric::{Float, FromPrimitive};
10
11/// Solve a linear system Ax = b using Gaussian elimination with partial pivoting
12///
13/// # Arguments
14/// * `a` - The coefficient matrix A
15/// * `b` - The right-hand side vector b
16///
17/// # Returns
18/// * `Result<Array1<F>, IntegrateError>` - The solution vector x
19#[allow(dead_code)]
20pub fn solve_linear_system<F>(a: &ArrayView2<F>, b: &ArrayView1<F>) -> IntegrateResult<Array1<F>>
21where
22    F: IntegrateFloat,
23{
24    // Get dimensions
25    let n = a.shape()[0];
26
27    // Check that A is square
28    if a.shape()[0] != a.shape()[1] {
29        return Err(IntegrateError::ValueError(format!(
30            "Matrix must be square to solve linear system, got shape {:?}",
31            a.shape()
32        )));
33    }
34
35    // Check that b has compatible dimensions
36    if b.len() != n {
37        return Err(IntegrateError::ValueError(
38            format!("Right-hand side vector dimensions incompatible with matrix: matrix has {} rows but vector has {} elements", 
39                n, b.len())
40        ));
41    }
42
43    // Create copies of A and b that we can modify
44    let mut a_copy = a.to_owned();
45    let mut b_copy = b.to_owned();
46
47    // Gaussian elimination with partial pivoting
48    for k in 0..n {
49        // Find pivot
50        let mut pivot_idx = k;
51        let mut max_val = a_copy[[k, k]].abs();
52
53        for i in (k + 1)..n {
54            let val = a_copy[[i, k]].abs();
55            if val > max_val {
56                max_val = val;
57                pivot_idx = i;
58            }
59        }
60
61        // Check for singularity
62        if max_val < F::from_f64(1e-14).unwrap() {
63            return Err(IntegrateError::ValueError(
64                "Matrix is singular or nearly singular".to_string(),
65            ));
66        }
67
68        // Swap rows if necessary
69        if pivot_idx != k {
70            // Swap rows in A
71            for j in k..n {
72                let temp = a_copy[[k, j]];
73                a_copy[[k, j]] = a_copy[[pivot_idx, j]];
74                a_copy[[pivot_idx, j]] = temp;
75            }
76
77            // Swap elements in b
78            let temp = b_copy[k];
79            b_copy[k] = b_copy[pivot_idx];
80            b_copy[pivot_idx] = temp;
81        }
82
83        // Eliminate below the pivot
84        for i in (k + 1)..n {
85            let factor = a_copy[[i, k]] / a_copy[[k, k]];
86
87            // Update the right-hand side
88            b_copy[i] = b_copy[i] - factor * b_copy[k];
89
90            // Update the matrix
91            a_copy[[i, k]] = F::zero(); // Explicitly set to zero to avoid numerical issues
92
93            for j in (k + 1)..n {
94                a_copy[[i, j]] = a_copy[[i, j]] - factor * a_copy[[k, j]];
95            }
96        }
97    }
98
99    // Back-substitution
100    let mut x = Array1::<F>::zeros(n);
101
102    for i in (0..n).rev() {
103        let mut sum = b_copy[i];
104
105        for j in (i + 1)..n {
106            sum -= a_copy[[i, j]] * x[j];
107        }
108
109        x[i] = sum / a_copy[[i, i]];
110    }
111
112    Ok(x)
113}
114
115/// Solve a linear system Ax = b using LU decomposition
116///
117/// # Arguments
118/// * `a` - The coefficient matrix A
119/// * `b` - The right-hand side vector b
120///
121/// # Returns
122/// * `Result<Array1<F>, IntegrateError>` - The solution vector x
123#[allow(dead_code)]
124pub fn solve_lu<F>(a: &ArrayView2<F>, b: &ArrayView1<F>) -> IntegrateResult<Array1<F>>
125where
126    F: IntegrateFloat,
127{
128    // For small systems, just use Gaussian elimination
129    if a.shape()[0] <= 10 {
130        return solve_linear_system(a, b);
131    }
132
133    // Get dimensions
134    let n = a.shape()[0];
135
136    // Create copies of A that we can modify
137    let mut a_copy = a.to_owned();
138
139    // Arrays to store the LU decomposition
140    let mut l = Array2::<F>::eye(n);
141    let mut u = Array2::<F>::zeros((n, n));
142
143    // Array to store permutation
144    let mut p = vec![0; n];
145    for (i, p_elem) in p.iter_mut().enumerate().take(n) {
146        *p_elem = i;
147    }
148
149    // Perform LU decomposition with partial pivoting
150    for k in 0..n {
151        // Find pivot
152        let mut pivot_idx = k;
153        let mut max_val = a_copy[[k, k]].abs();
154
155        for i in (k + 1)..n {
156            let val = a_copy[[i, k]].abs();
157            if val > max_val {
158                max_val = val;
159                pivot_idx = i;
160            }
161        }
162
163        // Check for singularity
164        if max_val < F::from_f64(1e-14).unwrap() {
165            return Err(IntegrateError::ValueError(
166                "Matrix is singular or nearly singular".to_string(),
167            ));
168        }
169
170        // Swap rows if necessary
171        if pivot_idx != k {
172            // Swap rows in A
173            for j in 0..n {
174                let temp = a_copy[[k, j]];
175                a_copy[[k, j]] = a_copy[[pivot_idx, j]];
176                a_copy[[pivot_idx, j]] = temp;
177            }
178
179            // Update permutation
180            p.swap(k, pivot_idx);
181
182            // If k > 0, swap rows in L for columns 0 to k-1
183            if k > 0 {
184                for j in 0..k {
185                    let temp = l[[k, j]];
186                    l[[k, j]] = l[[pivot_idx, j]];
187                    l[[pivot_idx, j]] = temp;
188                }
189            }
190        }
191
192        // Compute elements of U
193        for j in k..n {
194            u[[k, j]] = a_copy[[k, j]];
195            for p in 0..k {
196                u[[k, j]] = u[[k, j]] - l[[k, p]] * u[[p, j]];
197            }
198        }
199
200        // Compute elements of L
201        for i in (k + 1)..n {
202            if u[[k, k]].abs() < F::from_f64(1e-14).unwrap() {
203                return Err(IntegrateError::ValueError(
204                    "LU decomposition failed: division by zero".to_string(),
205                ));
206            }
207
208            l[[i, k]] = a_copy[[i, k]];
209            for p in 0..k {
210                l[[i, k]] = l[[i, k]] - l[[i, p]] * u[[p, k]];
211            }
212            l[[i, k]] /= u[[k, k]];
213        }
214    }
215
216    // Solve Ly = Pb
217    let mut y = Array1::<F>::zeros(n);
218    let mut pb = Array1::<F>::zeros(n);
219
220    // Permute b
221    for i in 0..n {
222        pb[i] = b[p[i]];
223    }
224
225    // Forward substitution
226    for i in 0..n {
227        y[i] = pb[i];
228        for j in 0..i {
229            y[i] = y[i] - l[[i, j]] * y[j];
230        }
231    }
232
233    // Solve Ux = y
234    let mut x = Array1::<F>::zeros(n);
235
236    // Back substitution
237    for i in (0..n).rev() {
238        if u[[i, i]].abs() < F::from_f64(1e-14).unwrap() {
239            return Err(IntegrateError::ValueError(
240                "LU decomposition: singular matrix detected during back substitution".to_string(),
241            ));
242        }
243
244        x[i] = y[i];
245        for j in (i + 1)..n {
246            x[i] = x[i] - u[[i, j]] * x[j];
247        }
248        x[i] /= u[[i, i]];
249    }
250
251    Ok(x)
252}
253
254/// Compute the norm of a vector
255///
256/// # Arguments
257/// * `v` - The vector
258///
259/// # Returns
260/// * The L2 norm of the vector
261#[allow(dead_code)]
262pub fn vector_norm<F>(v: &ArrayView1<F>) -> F
263where
264    F: IntegrateFloat,
265{
266    let mut sum = F::zero();
267    for &val in v.iter() {
268        sum += val * val;
269    }
270    sum.sqrt()
271}
272
273/// Compute the Frobenius norm of a matrix
274///
275/// # Arguments
276/// * `m` - The matrix
277///
278/// # Returns
279/// * The Frobenius norm of the matrix
280#[allow(dead_code)]
281pub fn matrix_norm<F>(m: &ArrayView2<F>) -> F
282where
283    F: IntegrateFloat,
284{
285    let mut sum = F::zero();
286    for val in m.iter() {
287        sum += (*val) * (*val);
288    }
289    sum.sqrt()
290}