1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
4use scirs2_core::numeric::{Float, NumAssign, One};
5use std::iter::Sum;
6
7use crate::basic::inv;
8use crate::decomposition::{lu, qr, svd};
9use crate::error::{LinalgError, LinalgResult};
10use crate::validation::{
11 validate_finite_vector, validate_finitematrix, validate_least_squares, validate_linear_system,
12 validate_multiple_linear_systems, validate_not_empty_vector, validate_not_emptymatrix,
13 validate_squarematrix, validatematrix_vector_dimensions,
14};
15
16pub struct LstsqResult<F: Float> {
18 pub x: Array1<F>,
20 pub residuals: F,
22 pub rank: usize,
24 pub s: Array1<F>,
26}
27
28#[allow(dead_code)]
55pub fn solve<F>(
56 a: &ArrayView2<F>,
57 b: &ArrayView1<F>,
58 workers: Option<usize>,
59) -> LinalgResult<Array1<F>>
60where
61 F: Float + NumAssign + One + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
62{
63 validate_linear_system(a, b, "Linear system solve")?;
65
66 if a.nrows() <= 4 {
68 let a_inv = inv(a, None)?;
69 let mut x = Array1::zeros(a.nrows());
71 for i in 0..a.nrows() {
72 for j in 0..a.nrows() {
73 x[i] += a_inv[[i, j]] * b[j];
74 }
75 }
76 return Ok(x);
77 }
78
79 if let Some(num_workers) = workers {
81 std::env::set_var("OMP_NUM_THREADS", num_workers.to_string());
82 }
83
84 let (p, l, u) = match lu(a, workers) {
86 Err(LinalgError::SingularMatrixError(_)) => {
87 return Err(LinalgError::singularmatrix_with_suggestions(
88 "linear system solve",
89 a.dim(),
90 None,
91 ))
92 }
93 Err(e) => return Err(e),
94 Ok(result) => result,
95 };
96
97 let mut pb = Array1::zeros(b.len());
99 for i in 0..p.nrows() {
100 for j in 0..p.ncols() {
101 pb[i] += p[[i, j]] * b[j];
102 }
103 }
104
105 let y = solve_triangular(&l.view(), &pb.view(), true, true)?;
107
108 let x = solve_triangular(&u.view(), &y.view(), false, false)?;
110
111 Ok(x)
112}
113
114#[allow(dead_code)]
141pub fn solve_triangular<F>(
142 a: &ArrayView2<F>,
143 b: &ArrayView1<F>,
144 lower: bool,
145 unit_diagonal: bool,
146) -> LinalgResult<Array1<F>>
147where
148 F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
149{
150 validate_not_emptymatrix(a, "Triangular system solve")?;
152 validate_not_empty_vector(b, "Triangular system solve")?;
153 validate_squarematrix(a, "Triangular system solve")?;
154 validatematrix_vector_dimensions(a, b, "Triangular system solve")?;
155 validate_finitematrix(a, "Triangular system solve")?;
156 validate_finite_vector(b, "Triangular system solve")?;
157
158 let n = a.nrows();
159 let mut x = Array1::zeros(n);
160
161 if lower {
162 for i in 0..n {
164 let mut sum = b[i];
165 for j in 0..i {
166 sum -= a[[i, j]] * x[j];
167 }
168 if unit_diagonal {
169 x[i] = sum;
170 } else {
171 if a[[i, i]].abs() < F::epsilon() {
172 return Err(LinalgError::singularmatrix_with_suggestions(
173 "triangular system solve (forward substitution)",
174 a.dim(),
175 Some(1e16), ));
177 }
178 x[i] = sum / a[[i, i]];
179 }
180 }
181 } else {
182 for i in (0..n).rev() {
184 let mut sum = b[i];
185 for j in (i + 1)..n {
186 sum -= a[[i, j]] * x[j];
187 }
188 if unit_diagonal {
189 x[i] = sum;
190 } else {
191 if a[[i, i]].abs() < F::epsilon() {
192 return Err(LinalgError::singularmatrix_with_suggestions(
193 "triangular system solve (back substitution)",
194 a.dim(),
195 Some(1e16), ));
197 }
198 x[i] = sum / a[[i, i]];
199 }
200 }
201 }
202
203 Ok(x)
204}
205
206#[allow(dead_code)]
237pub fn lstsq<F>(
238 a: &ArrayView2<F>,
239 b: &ArrayView1<F>,
240 workers: Option<usize>,
241) -> LinalgResult<LstsqResult<F>>
242where
243 F: Float + NumAssign + Sum + One + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
244{
245 validate_least_squares(a, b, "Least squares solve")?;
247
248 if let Some(num_workers) = workers {
250 std::env::set_var("OMP_NUM_THREADS", num_workers.to_string());
251 }
252
253 if a.nrows() >= a.ncols() {
255 let (q, r) = qr(a, workers)?;
257
258 let qt = q.t().to_owned();
260 let mut qt_b = Array1::zeros(qt.nrows());
261 for i in 0..qt.nrows() {
262 for j in 0..qt.ncols() {
263 qt_b[i] += qt[[i, j]] * b[j];
264 }
265 }
266
267 let rank = a.ncols(); let qt_b_truncated = qt_b.slice(scirs2_core::ndarray::s![0..rank]).to_owned();
272
273 let r_truncated = r
275 .slice(scirs2_core::ndarray::s![0..rank, 0..a.ncols()])
276 .to_owned();
277 let x = solve_triangular(&r_truncated.view(), &qt_b_truncated.view(), false, false)?;
278
279 let mut residuals = F::zero();
281 for i in 0..a.nrows() {
282 let mut a_x_i = F::zero();
283 for j in 0..a.ncols() {
284 a_x_i += a[[i, j]] * x[j];
285 }
286 let diff = b[i] - a_x_i;
287 residuals += diff * diff;
288 }
289
290 let s = Array1::zeros(0);
292
293 Ok(LstsqResult {
294 x,
295 residuals,
296 rank,
297 s,
298 })
299 } else {
300 let (u, s, vt) = svd(a, false, workers)?;
302
303 let max_dim = a.nrows().max(a.ncols());
305 let max_dim_f = F::from(max_dim).ok_or_else(|| {
306 LinalgError::NumericalError(format!(
307 "Failed to convert matrix dimension {max_dim} to numeric type"
308 ))
309 })?;
310 let threshold = s[0] * max_dim_f * F::epsilon();
311 let rank = s.iter().filter(|&&val| val > threshold).count();
312
313 let ut = u.t().to_owned();
315 let mut ut_b = Array1::zeros(ut.nrows());
316 for i in 0..ut.nrows() {
317 for j in 0..ut.ncols() {
318 ut_b[i] += ut[[i, j]] * b[j];
319 }
320 }
321
322 let mut x = Array1::zeros(a.ncols());
324
325 for i in 0..rank {
327 let s_inv = F::one() / s[i];
328 for j in 0..a.ncols() {
329 x[j] += vt[[i, j]] * ut_b[i] * s_inv;
330 }
331 }
332
333 let mut residuals = F::zero();
335 for i in 0..a.nrows() {
336 let mut a_x_i = F::zero();
337 for j in 0..a.ncols() {
338 a_x_i += a[[i, j]] * x[j];
339 }
340 let diff = b[i] - a_x_i;
341 residuals += diff * diff;
342 }
343
344 Ok(LstsqResult {
345 x,
346 residuals,
347 rank,
348 s,
349 })
350 }
351}
352
353#[allow(dead_code)]
378pub fn solve_multiple<F>(
379 a: &ArrayView2<F>,
380 b: &ArrayView2<F>,
381 workers: Option<usize>,
382) -> LinalgResult<Array2<F>>
383where
384 F: Float + NumAssign + One + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
385{
386 validate_multiple_linear_systems(a, b, "Multiple linear systems solve")?;
388
389 if let Some(num_workers) = workers {
391 std::env::set_var("OMP_NUM_THREADS", num_workers.to_string());
392 }
393
394 let (p, l, u) = match lu(a, workers) {
396 Err(LinalgError::SingularMatrixError(_)) => {
397 return Err(LinalgError::singularmatrix_with_suggestions(
398 "multiple linear systems solve",
399 a.dim(),
400 None,
401 ))
402 }
403 Err(e) => return Err(e),
404 Ok(result) => result,
405 };
406
407 let mut x = Array2::zeros((a.ncols(), b.ncols()));
409
410 for j in 0..b.ncols() {
412 let b_j = b.column(j).to_owned();
414
415 let mut pb = Array1::zeros(b_j.len());
417 for i in 0..p.nrows() {
418 for k in 0..p.ncols() {
419 pb[i] += p[[i, k]] * b_j[k];
420 }
421 }
422
423 let y = solve_triangular(&l.view(), &pb.view(), true, true)?;
425
426 let x_j = solve_triangular(&u.view(), &y.view(), false, false)?;
428
429 for i in 0..x_j.len() {
431 x[[i, j]] = x_j[i];
432 }
433 }
434
435 Ok(x)
436}
437
438#[allow(dead_code)]
442pub fn solve_default<F>(a: &ArrayView2<F>, b: &ArrayView1<F>) -> LinalgResult<Array1<F>>
443where
444 F: Float + NumAssign + One + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
445{
446 solve(a, b, None)
447}
448
449#[allow(dead_code)]
451pub fn lstsq_default<F>(a: &ArrayView2<F>, b: &ArrayView1<F>) -> LinalgResult<LstsqResult<F>>
452where
453 F: Float + NumAssign + Sum + One + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
454{
455 lstsq(a, b, None)
456}
457
458#[allow(dead_code)]
460pub fn solve_multiple_default<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<Array2<F>>
461where
462 F: Float + NumAssign + One + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
463{
464 solve_multiple(a, b, None)
465}
466
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_solve() {
        // Identity matrix: the solution equals the right-hand side.
        let a = array![[1.0, 0.0], [0.0, 1.0]];
        let b = array![2.0, 3.0];
        let x =
            solve(&a.view(), &b.view(), None).expect("Solve should succeed for identity matrix");
        assert_relative_eq!(x[0], 2.0);
        assert_relative_eq!(x[1], 3.0);

        // General 2x2 system with solution (1, 2).
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![5.0, 11.0];
        let x =
            solve(&a.view(), &b.view(), None).expect("Solve should succeed for this test system");
        assert_relative_eq!(x[0], 1.0);
        assert_relative_eq!(x[1], 2.0);
    }

    /// Systems larger than 4x4 skip the explicit-inverse shortcut and go through
    /// the LU factorisation. The matrix is strictly diagonally dominant by
    /// columns, so partial pivoting performs no row interchanges.
    #[test]
    fn test_solve_lu_path() {
        let a = array![
            [10.0, 1.0, 2.0, 0.0, 1.0],
            [1.0, 10.0, 1.0, 2.0, 0.0],
            [2.0, 1.0, 10.0, 1.0, 2.0],
            [0.0, 2.0, 1.0, 10.0, 1.0],
            [1.0, 0.0, 2.0, 1.0, 10.0]
        ];
        // b = A * (1, 2, 3, 4, 5)
        let b = array![23.0, 32.0, 48.0, 52.0, 61.0];
        let x = solve(&a.view(), &b.view(), None)
            .expect("Solve should succeed for a 5x5 diagonally dominant system");
        for (i, &expected) in [1.0_f64, 2.0, 3.0, 4.0, 5.0].iter().enumerate() {
            assert_relative_eq!(x[i], expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_solve_multiple() {
        // Column 0 of b has solution (1, 1); column 1 has solution (0.2, 0.6).
        let a = array![[2.0, 1.0], [1.0, 3.0]];
        let b = array![[3.0, 1.0], [4.0, 2.0]];
        let x = solve_multiple(&a.view(), &b.view(), None)
            .expect("Multiple solve should succeed for this test system");
        assert_relative_eq!(x[[0, 0]], 1.0, epsilon = 1e-12);
        assert_relative_eq!(x[[1, 0]], 1.0, epsilon = 1e-12);
        assert_relative_eq!(x[[0, 1]], 0.2, epsilon = 1e-12);
        assert_relative_eq!(x[[1, 1]], 0.6, epsilon = 1e-12);
    }

    #[test]
    fn test_solve_triangular_lower() {
        // Non-unit lower triangular.
        let a = array![[1.0, 0.0], [2.0, 3.0]];
        let b = array![2.0, 8.0];
        let x = solve_triangular(&a.view(), &b.view(), true, false)
            .expect("Lower triangular solve should succeed");
        assert_relative_eq!(x[0], 2.0);
        assert_relative_eq!(x[1], 4.0 / 3.0);

        // Unit-diagonal lower triangular.
        let a = array![[1.0, 0.0], [2.0, 1.0]];
        let b = array![2.0, 6.0];
        let x = solve_triangular(&a.view(), &b.view(), true, true)
            .expect("Lower triangular unit diagonal solve should succeed");
        assert_relative_eq!(x[0], 2.0);
        assert_relative_eq!(x[1], 2.0);
    }

    #[test]
    fn test_solve_triangular_upper() {
        // Non-unit upper triangular.
        let a = array![[3.0, 2.0], [0.0, 1.0]];
        let b = array![8.0, 2.0];
        let x = solve_triangular(&a.view(), &b.view(), false, false)
            .expect("Upper triangular solve should succeed");
        assert_relative_eq!(x[0], 4.0 / 3.0);
        assert_relative_eq!(x[1], 2.0);

        // Unit-diagonal upper triangular.
        let a = array![[1.0, 2.0], [0.0, 1.0]];
        let b = array![6.0, 2.0];
        let x = solve_triangular(&a.view(), &b.view(), false, true)
            .expect("Upper triangular unit diagonal solve should succeed");
        assert_relative_eq!(x[0], 2.0);
        assert_relative_eq!(x[1], 2.0);
    }
}