scirs2_integrate/ode/utils/linear_solvers/
mod.rs1use crate::error::{IntegrateError, IntegrateResult};
7use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::fmt::Debug;
10
/// Strategy used by [`auto_solve_linear_system`] to solve a dense system `A x = b`.
///
/// Besides equality, the type is `Eq + Hash`, so it can be used as a key in hash-based
/// collections (e.g. for caching per-strategy statistics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LinearSolverType {
    /// Gaussian elimination with partial pivoting ([`solve_linear_system`]).
    Direct,
    /// Restarted GMRES with default parameters ([`solve_gmres`]).
    Iterative,
    /// Choose based on the system size: direct elimination for fewer than 100 rows,
    /// GMRES-based solving otherwise (see [`auto_solve_linear_system`]).
    Auto,
}
21
22#[allow(dead_code)]
31pub fn solve_linear_system<F>(a: &ArrayView2<F>, b: &ArrayView1<F>) -> IntegrateResult<Array1<F>>
32where
33 F: Float
34 + FromPrimitive
35 + Debug
36 + std::ops::AddAssign
37 + std::ops::SubAssign
38 + std::ops::MulAssign,
39{
40 let n = a.shape()[0];
42
43 if a.shape()[0] != a.shape()[1] {
45 return Err(IntegrateError::ValueError(format!(
46 "Matrix must be square to solve linear system, got shape {:?}",
47 a.shape()
48 )));
49 }
50
51 if b.len() != n {
53 return Err(IntegrateError::ValueError(
54 format!("Right-hand side vector dimensions incompatible with matrix: matrix has {} rows but vector has {} elements",
55 n, b.len())
56 ));
57 }
58
59 let mut a_copy = a.to_owned();
61 let mut b_copy = b.to_owned();
62
63 for k in 0..n {
65 let mut pivot_idx = k;
67 let mut max_val = a_copy[[k, k]].abs();
68
69 for i in (k + 1)..n {
70 let val = a_copy[[i, k]].abs();
71 if val > max_val {
72 max_val = val;
73 pivot_idx = i;
74 }
75 }
76
77 if max_val < F::from_f64(1e-14).unwrap() {
79 return Err(IntegrateError::ValueError(
80 "Matrix is singular or nearly singular".to_string(),
81 ));
82 }
83
84 if pivot_idx != k {
86 for j in k..n {
88 let temp = a_copy[[k, j]];
89 a_copy[[k, j]] = a_copy[[pivot_idx, j]];
90 a_copy[[pivot_idx, j]] = temp;
91 }
92
93 let temp = b_copy[k];
95 b_copy[k] = b_copy[pivot_idx];
96 b_copy[pivot_idx] = temp;
97 }
98
99 for i in (k + 1)..n {
101 let factor = a_copy[[i, k]] / a_copy[[k, k]];
102
103 b_copy[i] = b_copy[i] - factor * b_copy[k];
105
106 a_copy[[i, k]] = F::zero(); for j in (k + 1)..n {
110 a_copy[[i, j]] = a_copy[[i, j]] - factor * a_copy[[k, j]];
111 }
112 }
113 }
114
115 let mut x = Array1::<F>::zeros(n);
117
118 for i in (0..n).rev() {
119 let mut sum = b_copy[i];
120
121 for j in (i + 1)..n {
122 sum -= a_copy[[i, j]] * x[j];
123 }
124
125 x[i] = sum / a_copy[[i, i]];
126 }
127
128 Ok(x)
129}
130
131#[allow(dead_code)]
139pub fn vector_norm<F>(v: &ArrayView1<F>) -> F
140where
141 F: Float,
142{
143 let mut sum = F::zero();
144 for &val in v.iter() {
145 sum = sum + val * val;
146 }
147 sum.sqrt()
148}
149
150#[allow(dead_code)]
158pub fn matrix_norm<F>(m: &ArrayView2<F>) -> F
159where
160 F: Float,
161{
162 let mut sum = F::zero();
163 for val in m.iter() {
164 sum = sum + (*val) * (*val);
165 }
166 sum.sqrt()
167}
168
169#[allow(dead_code)]
171pub fn auto_solve_linear_system<F>(
172 a: &ArrayView2<F>,
173 b: &ArrayView1<F>,
174 solver_type: LinearSolverType,
175) -> IntegrateResult<Array1<F>>
176where
177 F: Float
178 + FromPrimitive
179 + Debug
180 + std::ops::AddAssign
181 + std::ops::SubAssign
182 + std::ops::MulAssign
183 + std::default::Default
184 + std::iter::Sum
185 + scirs2_core::ndarray::ScalarOperand
186 + std::ops::DivAssign,
187{
188 match solver_type {
189 LinearSolverType::Direct => solve_linear_system(a, b),
190 LinearSolverType::Iterative => {
191 solve_gmres(a, b, None, None, None)
193 }
194 LinearSolverType::Auto => {
195 let n = a.shape()[0];
197 if n < 100 {
198 solve_linear_system(a, b)
199 } else {
200 solve_gmres(a, b, None, None, None)
202 }
203 }
204 }
205}
206
207#[allow(dead_code)]
209pub fn solve_lu<F>(a: &ArrayView2<F>, b: &ArrayView1<F>) -> IntegrateResult<Array1<F>>
210where
211 F: Float
212 + FromPrimitive
213 + Debug
214 + std::ops::AddAssign
215 + std::ops::SubAssign
216 + std::ops::MulAssign,
217{
218 solve_linear_system(a, b)
219}
220
221#[allow(dead_code)]
235pub fn solve_gmres<F>(
236 a: &ArrayView2<F>,
237 b: &ArrayView1<F>,
238 max_iter: Option<usize>,
239 tol: Option<F>,
240 restart: Option<usize>,
241) -> IntegrateResult<Array1<F>>
242where
243 F: Float
244 + FromPrimitive
245 + Debug
246 + std::ops::AddAssign
247 + std::ops::SubAssign
248 + std::ops::MulAssign
249 + Default
250 + std::iter::Sum
251 + scirs2_core::ndarray::ScalarOperand
252 + std::ops::DivAssign,
253{
254 let n = a.nrows();
255 if n != a.ncols() {
256 return Err(IntegrateError::ValueError(
257 "Matrix must be square".to_string(),
258 ));
259 }
260 if n != b.len() {
261 return Err(IntegrateError::ValueError(
262 "Matrix and vector dimensions must match".to_string(),
263 ));
264 }
265
266 let max_iter = max_iter.unwrap_or(std::cmp::min(n, 50));
267 let tol = tol.unwrap_or_else(|| F::from_f64(1e-10).unwrap());
268 let restart = restart.unwrap_or(std::cmp::min(n, 20));
269
270 let mut x = Array1::<F>::zeros(n);
272
273 let mut r = b.to_owned();
275 for i in 0..n {
276 let mut ax_i = F::zero();
277 for j in 0..n {
278 ax_i += a[[i, j]] * x[j];
279 }
280 r[i] -= ax_i;
281 }
282
283 let initial_norm = (r.iter().map(|&x| x * x).sum::<F>()).sqrt();
284 if initial_norm < tol {
285 return Ok(x); }
287
288 let mut outer_iter = 0;
289 while outer_iter < max_iter {
290 let m = std::cmp::min(restart, max_iter - outer_iter);
292
293 let beta = (r.iter().map(|&x| x * x).sum::<F>()).sqrt();
295 if beta < tol {
296 break; }
298
299 let mut v = vec![Array1::<F>::zeros(n); m + 1];
300 v[0] = &r / beta;
301
302 let mut h = vec![vec![F::zero(); m]; m + 1];
303 let mut g = vec![F::zero(); m + 1];
304 g[0] = beta;
305
306 let mut j = 0;
307 while j < m {
308 let mut w = Array1::<F>::zeros(n);
310 for i in 0..n {
311 for k in 0..n {
312 w[i] += a[[i, k]] * v[j][k];
313 }
314 }
315
316 for i in 0..=j {
318 h[i][j] = v[i].dot(&w);
319 for k in 0..n {
320 w[k] -= h[i][j] * v[i][k];
321 }
322 }
323
324 h[j + 1][j] = (w.iter().map(|&x| x * x).sum::<F>()).sqrt();
325
326 if h[j + 1][j] < F::from_f64(1e-14).unwrap() {
327 break;
329 }
330
331 v[j + 1] = &w / h[j + 1][j];
332
333 for i in 0..j {
335 let c = if i < g.len() - 1 {
336 h[i][j] / (h[i][j] * h[i][j] + h[i + 1][j] * h[i + 1][j]).sqrt()
337 } else {
338 F::one()
339 };
340 let s = if i < g.len() - 1 {
341 h[i + 1][j] / (h[i][j] * h[i][j] + h[i + 1][j] * h[i + 1][j]).sqrt()
342 } else {
343 F::zero()
344 };
345
346 let temp = c * h[i][j] + s * h[i + 1][j];
347 h[i + 1][j] = -s * h[i][j] + c * h[i + 1][j];
348 h[i][j] = temp;
349 }
350
351 let c = h[j][j] / (h[j][j] * h[j][j] + h[j + 1][j] * h[j + 1][j]).sqrt();
353 let s = h[j + 1][j] / (h[j][j] * h[j][j] + h[j + 1][j] * h[j + 1][j]).sqrt();
354
355 h[j][j] = c * h[j][j] + s * h[j + 1][j];
357 h[j + 1][j] = F::zero();
358
359 let temp = c * g[j];
360 g[j + 1] = -s * g[j];
361 g[j] = temp;
362
363 if g[j + 1].abs() < tol * initial_norm {
365 j += 1;
366 break;
367 }
368
369 j += 1;
370 }
371
372 let mut y = vec![F::zero(); j];
374 for i in (0..j).rev() {
375 let mut sum = g[i];
376 for k in (i + 1)..j {
377 sum -= h[i][k] * y[k];
378 }
379 y[i] = sum / h[i][i];
380 }
381
382 for i in 0..n {
384 for k in 0..j {
385 x[i] += y[k] * v[k][i];
386 }
387 }
388
389 r = b.to_owned();
391 for i in 0..n {
392 let mut ax_i = F::zero();
393 for k in 0..n {
394 ax_i += a[[i, k]] * x[k];
395 }
396 r[i] -= ax_i;
397 }
398
399 let residual_norm = (r.iter().map(|&x| x * x).sum::<F>()).sqrt();
400 if residual_norm < tol * initial_norm {
401 break; }
403
404 outer_iter += m;
405 }
406
407 Ok(x)
408}