// scirs2_linalg/solve.rs

1//! Linear equation solvers
2
3use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
4use scirs2_core::numeric::{Float, NumAssign, One};
5use std::iter::Sum;
6
7use crate::basic::inv;
8use crate::decomposition::{lu, qr, svd};
9use crate::error::{LinalgError, LinalgResult};
10use crate::validation::{
11    validate_finite_vector, validate_finitematrix, validate_least_squares, validate_linear_system,
12    validate_multiple_linear_systems, validate_not_empty_vector, validate_not_emptymatrix,
13    validate_squarematrix, validatematrix_vector_dimensions,
14};
15
16/// Solution to a least-squares problem
17pub struct LstsqResult<F: Float> {
18    /// Least-squares solution
19    pub x: Array1<F>,
20    /// Sum of squared residuals
21    pub residuals: F,
22    /// Rank of coefficient matrix
23    pub rank: usize,
24    /// Singular values
25    pub s: Array1<F>,
26}
27
28/// Solve a linear system of equations.
29///
30/// Solves the equation a x = b for x, assuming a is a square matrix.
31///
32/// # Arguments
33///
34/// * `a` - Coefficient matrix
35/// * `b` - Ordinate or "dependent variable" values
36/// * `workers` - Number of worker threads (None = use default)
37///
38/// # Returns
39///
40/// * Solution vector x
41///
42/// # Examples
43///
44/// ```
45/// use scirs2_core::ndarray::{array, ScalarOperand};
46/// use scirs2_linalg::solve;
47///
48/// let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
49/// let b = array![2.0_f64, 3.0];
50/// let x = solve(&a.view(), &b.view(), None).unwrap();
51/// assert!((x[0] - 2.0).abs() < 1e-10);
52/// assert!((x[1] - 3.0).abs() < 1e-10);
53/// ```
54#[allow(dead_code)]
55pub fn solve<F>(
56    a: &ArrayView2<F>,
57    b: &ArrayView1<F>,
58    workers: Option<usize>,
59) -> LinalgResult<Array1<F>>
60where
61    F: Float + NumAssign + One + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
62{
63    // Parameter validation using helper function
64    validate_linear_system(a, b, "Linear system solve")?;
65
66    // For small matrices, we can solve directly using the inverse
67    if a.nrows() <= 4 {
68        let a_inv = inv(a, None)?;
69        // Compute x = a_inv * b
70        let mut x = Array1::zeros(a.nrows());
71        for i in 0..a.nrows() {
72            for j in 0..a.nrows() {
73                x[i] += a_inv[[i, j]] * b[j];
74            }
75        }
76        return Ok(x);
77    }
78
79    // Configure OpenMP thread count if workers specified
80    if let Some(num_workers) = workers {
81        std::env::set_var("OMP_NUM_THREADS", num_workers.to_string());
82    }
83
84    // For larger systems, use LU decomposition
85    let (p, l, u) = match lu(a, workers) {
86        Err(LinalgError::SingularMatrixError(_)) => {
87            return Err(LinalgError::singularmatrix_with_suggestions(
88                "linear system solve",
89                a.dim(),
90                None,
91            ))
92        }
93        Err(e) => return Err(e),
94        Ok(result) => result,
95    };
96
97    // Compute P*b
98    let mut pb = Array1::zeros(b.len());
99    for i in 0..p.nrows() {
100        for j in 0..p.ncols() {
101            pb[i] += p[[i, j]] * b[j];
102        }
103    }
104
105    // Solve L*y = P*b by forward substitution
106    let y = solve_triangular(&l.view(), &pb.view(), true, true)?;
107
108    // Solve U*x = y by back substitution
109    let x = solve_triangular(&u.view(), &y.view(), false, false)?;
110
111    Ok(x)
112}
113
114/// Solve a linear system with a lower or upper triangular coefficient matrix.
115///
116/// # Arguments
117///
118/// * `a` - Triangular coefficient matrix
119/// * `b` - Ordinate or "dependent variable" values
120/// * `lower` - If true, the matrix is lower triangular, if false, upper triangular
121/// * `unit_diagonal` - If true, the diagonal elements of a are assumed to be 1
122///
123/// # Returns
124///
125/// * Solution vector x
126///
127/// # Examples
128///
129/// ```
130/// use scirs2_core::ndarray::{array, ScalarOperand};
131/// use scirs2_linalg::solve_triangular;
132///
133/// // Lower triangular system
134/// let a = array![[1.0_f64, 0.0], [2.0, 3.0]];
135/// let b = array![2.0_f64, 8.0];
136/// let x = solve_triangular(&a.view(), &b.view(), true, false).unwrap();
137/// assert!((x[0] - 2.0).abs() < 1e-10);
138/// assert!((x[1] - 4.0/3.0).abs() < 1e-10);
139/// ```
140#[allow(dead_code)]
141pub fn solve_triangular<F>(
142    a: &ArrayView2<F>,
143    b: &ArrayView1<F>,
144    lower: bool,
145    unit_diagonal: bool,
146) -> LinalgResult<Array1<F>>
147where
148    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
149{
150    // Parameter validation using helper functions
151    validate_not_emptymatrix(a, "Triangular system solve")?;
152    validate_not_empty_vector(b, "Triangular system solve")?;
153    validate_squarematrix(a, "Triangular system solve")?;
154    validatematrix_vector_dimensions(a, b, "Triangular system solve")?;
155    validate_finitematrix(a, "Triangular system solve")?;
156    validate_finite_vector(b, "Triangular system solve")?;
157
158    let n = a.nrows();
159    let mut x = Array1::zeros(n);
160
161    if lower {
162        // Forward substitution for lower triangular matrix
163        for i in 0..n {
164            let mut sum = b[i];
165            for j in 0..i {
166                sum -= a[[i, j]] * x[j];
167            }
168            if unit_diagonal {
169                x[i] = sum;
170            } else {
171                if a[[i, i]].abs() < F::epsilon() {
172                    return Err(LinalgError::singularmatrix_with_suggestions(
173                        "triangular system solve (forward substitution)",
174                        a.dim(),
175                        Some(1e16), // Very high condition number due to zero _diagonal
176                    ));
177                }
178                x[i] = sum / a[[i, i]];
179            }
180        }
181    } else {
182        // Back substitution for upper triangular matrix
183        for i in (0..n).rev() {
184            let mut sum = b[i];
185            for j in (i + 1)..n {
186                sum -= a[[i, j]] * x[j];
187            }
188            if unit_diagonal {
189                x[i] = sum;
190            } else {
191                if a[[i, i]].abs() < F::epsilon() {
192                    return Err(LinalgError::singularmatrix_with_suggestions(
193                        "triangular system solve (back substitution)",
194                        a.dim(),
195                        Some(1e16), // Very high condition number due to zero _diagonal
196                    ));
197                }
198                x[i] = sum / a[[i, i]];
199            }
200        }
201    }
202
203    Ok(x)
204}
205
/// Compute least-squares solution to a linear matrix equation.
///
/// Computes the vector x that solves the least squares equation
/// a x = b by computing the full least squares solution. Overdetermined
/// or square systems (m >= n) are solved via QR; underdetermined systems
/// (m < n) via a truncated SVD (pseudoinverse).
///
/// # Arguments
///
/// * `a` - Coefficient matrix
/// * `b` - Ordinate or "dependent variable" values
/// * `workers` - Number of worker threads (None = use default)
///
/// # Returns
///
/// * A LstsqResult struct containing:
///   * x: Least-squares solution
///   * residuals: Sum of squared residuals
///   * rank: Rank of matrix a (assumed full, i.e. `a.ncols()`, in the QR path)
///   * s: Singular values of a (empty when the QR path is taken)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{array, ScalarOperand};
/// use scirs2_linalg::lstsq;
///
/// let a = array![[1.0_f64, 1.0], [1.0, 2.0], [1.0, 3.0]];
/// let b = array![6.0_f64, 9.0, 12.0];
/// let result = lstsq(&a.view(), &b.view(), None).unwrap();
/// // result.x should be approximately [3.0, 3.0]
/// ```
#[allow(dead_code)]
pub fn lstsq<F>(
    a: &ArrayView2<F>,
    b: &ArrayView1<F>,
    workers: Option<usize>,
) -> LinalgResult<LstsqResult<F>>
where
    F: Float + NumAssign + Sum + One + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
    // Parameter validation using helper function
    validate_least_squares(a, b, "Least squares solve")?;

    // Configure OpenMP thread count if workers specified.
    // NOTE(review): mutating the process environment is global and not
    // thread-safe — confirm callers never race on this.
    if let Some(num_workers) = workers {
        std::env::set_var("OMP_NUM_THREADS", num_workers.to_string());
    }

    // Overdetermined or square system (m >= n): solve R x = Qᵀ b via QR.
    if a.nrows() >= a.ncols() {
        // QR decomposition approach
        let (q, r) = qr(a, workers)?;

        // Compute Q^T * b
        let qt = q.t().to_owned();
        let mut qt_b = Array1::zeros(qt.nrows());
        for i in 0..qt.nrows() {
            for j in 0..qt.ncols() {
                qt_b[i] += qt[[i, j]] * b[j];
            }
        }

        // Get the effective rank
        let rank = a.ncols(); // Assume full rank for now

        // Extract the first part of Q^T * b corresponding to the rank
        let qt_b_truncated = qt_b.slice(scirs2_core::ndarray::s![0..rank]).to_owned();

        // Solve R * x = Q^T * b using back substitution on the leading
        // rank-by-ncols block of R.
        let r_truncated = r
            .slice(scirs2_core::ndarray::s![0..rank, 0..a.ncols()])
            .to_owned();
        let x = solve_triangular(&r_truncated.view(), &qt_b_truncated.view(), false, false)?;

        // Compute residuals: ||Ax - b||²
        let mut residuals = F::zero();
        for i in 0..a.nrows() {
            let mut a_x_i = F::zero();
            for j in 0..a.ncols() {
                a_x_i += a[[i, j]] * x[j];
            }
            let diff = b[i] - a_x_i;
            residuals += diff * diff;
        }

        // Create singular values (empty for QR approach)
        let s = Array1::zeros(0);

        Ok(LstsqResult {
            x,
            residuals,
            rank,
            s,
        })
    } else {
        // Underdetermined system (m < n): SVD-based pseudoinverse solution.
        let (u, s, vt) = svd(a, false, workers)?;

        // Determine effective rank by thresholding singular values against
        // max(m, n) * eps * largest singular value (standard cutoff).
        let max_dim = a.nrows().max(a.ncols());
        let max_dim_f = F::from(max_dim).ok_or_else(|| {
            LinalgError::NumericalError(format!(
                "Failed to convert matrix dimension {max_dim} to numeric type"
            ))
        })?;
        let threshold = s[0] * max_dim_f * F::epsilon();
        let rank = s.iter().filter(|&&val| val > threshold).count();

        // Compute U^T * b
        let ut = u.t().to_owned();
        let mut ut_b = Array1::zeros(ut.nrows());
        for i in 0..ut.nrows() {
            for j in 0..ut.ncols() {
                ut_b[i] += ut[[i, j]] * b[j];
            }
        }

        // Initialize solution vector
        let mut x = Array1::zeros(a.ncols());

        // x = V_r Σ_r⁻¹ U_rᵀ b, accumulated over the `rank` retained
        // singular triplets.
        for i in 0..rank {
            let s_inv = F::one() / s[i];
            for j in 0..a.ncols() {
                x[j] += vt[[i, j]] * ut_b[i] * s_inv;
            }
        }

        // Compute residuals: ||Ax - b||²
        let mut residuals = F::zero();
        for i in 0..a.nrows() {
            let mut a_x_i = F::zero();
            for j in 0..a.ncols() {
                a_x_i += a[[i, j]] * x[j];
            }
            let diff = b[i] - a_x_i;
            residuals += diff * diff;
        }

        Ok(LstsqResult {
            x,
            residuals,
            rank,
            s,
        })
    }
}
352
353/// Solve the linear system Ax = B for x with multiple right-hand sides.
354///
355/// # Arguments
356///
357/// * `a` - Coefficient matrix
358/// * `b` - Matrix of right-hand sides where each column is a different right-hand side
359/// * `workers` - Number of worker threads (None = use default)
360///
361/// # Returns
362///
363/// * Solution matrix x where each column is a solution vector
364///
365/// # Examples
366///
367/// ```
368/// use scirs2_core::ndarray::{array, ScalarOperand};
369/// use scirs2_linalg::solve_multiple;
370///
371/// let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
372/// let b = array![[2.0_f64, 4.0], [3.0, 5.0]];
373/// let x = solve_multiple(&a.view(), &b.view(), None).unwrap();
374/// // First column of x should be [2.0, 3.0]
375/// // Second column of x should be [4.0, 5.0]
376/// ```
377#[allow(dead_code)]
378pub fn solve_multiple<F>(
379    a: &ArrayView2<F>,
380    b: &ArrayView2<F>,
381    workers: Option<usize>,
382) -> LinalgResult<Array2<F>>
383where
384    F: Float + NumAssign + One + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
385{
386    // Parameter validation using helper function
387    validate_multiple_linear_systems(a, b, "Multiple linear systems solve")?;
388
389    // Configure OpenMP thread count if workers specified
390    if let Some(num_workers) = workers {
391        std::env::set_var("OMP_NUM_THREADS", num_workers.to_string());
392    }
393
394    // For efficiency, perform LU decomposition once
395    let (p, l, u) = match lu(a, workers) {
396        Err(LinalgError::SingularMatrixError(_)) => {
397            return Err(LinalgError::singularmatrix_with_suggestions(
398                "multiple linear systems solve",
399                a.dim(),
400                None,
401            ))
402        }
403        Err(e) => return Err(e),
404        Ok(result) => result,
405    };
406
407    // Initialize solution matrix
408    let mut x = Array2::zeros((a.ncols(), b.ncols()));
409
410    // Solve for each right-hand side
411    for j in 0..b.ncols() {
412        // Extract j-th right-hand side
413        let b_j = b.column(j).to_owned();
414
415        // Compute P*b
416        let mut pb = Array1::zeros(b_j.len());
417        for i in 0..p.nrows() {
418            for k in 0..p.ncols() {
419                pb[i] += p[[i, k]] * b_j[k];
420            }
421        }
422
423        // Solve L*y = P*b by forward substitution
424        let y = solve_triangular(&l.view(), &pb.view(), true, true)?;
425
426        // Solve U*x = y by back substitution
427        let x_j = solve_triangular(&u.view(), &y.view(), false, false)?;
428
429        // Store solution in the j-th column of x
430        for i in 0..x_j.len() {
431            x[[i, j]] = x_j[i];
432        }
433    }
434
435    Ok(x)
436}
437
438// Convenience wrapper functions for backward compatibility
439
440/// Solve linear system using default thread count
441#[allow(dead_code)]
442pub fn solve_default<F>(a: &ArrayView2<F>, b: &ArrayView1<F>) -> LinalgResult<Array1<F>>
443where
444    F: Float + NumAssign + One + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
445{
446    solve(a, b, None)
447}
448
449/// Compute least-squares solution using default thread count
450#[allow(dead_code)]
451pub fn lstsq_default<F>(a: &ArrayView2<F>, b: &ArrayView1<F>) -> LinalgResult<LstsqResult<F>>
452where
453    F: Float + NumAssign + Sum + One + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
454{
455    lstsq(a, b, None)
456}
457
458/// Solve multiple linear systems using default thread count
459#[allow(dead_code)]
460pub fn solve_multiple_default<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<Array2<F>>
461where
462    F: Float + NumAssign + One + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
463{
464    solve_multiple(a, b, None)
465}
466
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_solve() {
        // Identity matrix: the solution equals the right-hand side.
        let a = array![[1.0, 0.0], [0.0, 1.0]];
        let b = array![2.0, 3.0];
        let x =
            solve(&a.view(), &b.view(), None).expect("Solve should succeed for identity matrix");
        assert_relative_eq!(x[0], 2.0);
        assert_relative_eq!(x[1], 3.0);

        // General 2x2 matrix
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![5.0, 11.0];
        let x =
            solve(&a.view(), &b.view(), None).expect("Solve should succeed for this test system");
        assert_relative_eq!(x[0], 1.0);
        assert_relative_eq!(x[1], 2.0);
    }

    #[test]
    fn test_solve_triangular_lower() {
        // Lower triangular system
        let a = array![[1.0, 0.0], [2.0, 3.0]];
        let b = array![2.0, 8.0];
        let x = solve_triangular(&a.view(), &b.view(), true, false)
            .expect("Lower triangular solve should succeed");
        assert_relative_eq!(x[0], 2.0);
        assert_relative_eq!(x[1], 4.0 / 3.0);

        // With unit diagonal
        // FIX: expect message previously said "Upper triangular" for this
        // lower-triangular case.
        let a = array![[1.0, 0.0], [2.0, 1.0]];
        let b = array![2.0, 6.0];
        let x = solve_triangular(&a.view(), &b.view(), true, true)
            .expect("Lower triangular unit diagonal solve should succeed");
        assert_relative_eq!(x[0], 2.0);
        assert_relative_eq!(x[1], 2.0);
    }

    #[test]
    fn test_solve_triangular_upper() {
        // Upper triangular system
        // FIX: expect message previously said "Lower triangular unit diagonal"
        // for this upper-triangular, non-unit case.
        let a = array![[3.0, 2.0], [0.0, 1.0]];
        let b = array![8.0, 2.0];
        let x = solve_triangular(&a.view(), &b.view(), false, false)
            .expect("Upper triangular solve should succeed");
        assert_relative_eq!(x[0], 4.0 / 3.0);
        assert_relative_eq!(x[1], 2.0);

        // With unit diagonal
        let a = array![[1.0, 2.0], [0.0, 1.0]];
        let b = array![6.0, 2.0];
        let x = solve_triangular(&a.view(), &b.view(), false, true)
            .expect("Upper triangular unit diagonal solve should succeed");
        assert_relative_eq!(x[0], 2.0);
        assert_relative_eq!(x[1], 2.0);
    }
}
529}