// scirs2_sparse/direct_solver.rs
1//! Sparse direct solvers
2//!
3//! This module provides production-quality direct solvers for sparse linear systems:
4//!
5//! - **Sparse LU factorization**: Dense kernel with partial pivoting, AMD column ordering
6//! - **Sparse Cholesky**: For symmetric positive definite (SPD) matrices
7//! - **Fill-reducing orderings**: AMD (Approximate Minimum Degree), nested dissection
8//! - **Symbolic analysis**: Determines the fill-in pattern before numeric factorization
9//! - **Numeric factorization**: Computes the actual factor values
10//! - **Triangular solves**: Forward and backward substitution
11//!
12//! # Architecture
13//!
14//! Factorization is split into two phases:
15//! 1. **Ordering phase** — Computes a fill-reducing permutation (AMD or nested dissection)
16//! 2. **Numeric phase** — Applies the permutation, then factors the reordered matrix
17//!
18//! # References
19//!
20//! - Davis, T.A. (2006). "Direct Methods for Sparse Linear Systems". SIAM.
21//! - Amestoy, P.R., Davis, T.A., & Duff, I.S. (1996). "An approximate minimum
22//!   degree ordering algorithm". SIAM J. Matrix Anal. Appl. 17(4), 886-905.
23//! - George, A. & Liu, J.W. (1981). "Computer Solution of Large Sparse Positive
24//!   Definite Systems". Prentice-Hall.
25
26use crate::csr::CsrMatrix;
27use crate::error::{SparseError, SparseResult};
28use scirs2_core::numeric::{Float, NumAssign, SparseElement};
29use std::collections::BTreeSet;
30use std::fmt::Debug;
31use std::iter::Sum;
32
33// ---------------------------------------------------------------------------
34// SparseSolver trait
35// ---------------------------------------------------------------------------
36
37/// Trait for sparse direct solvers.
38///
39/// A solver follows the factorize-then-solve paradigm:
40/// 1. Call `factorize()` to compute the decomposition.
41/// 2. Call `solve()` (or `solve_multi()`) one or more times.
42pub trait SparseSolver<F: Float> {
43    /// Compute the factorization of the given matrix.
44    fn factorize(&mut self, matrix: &CsrMatrix<F>) -> SparseResult<()>;
45
46    /// Solve `A x = b` using the stored factorization.
47    fn solve(&self, b: &[F]) -> SparseResult<Vec<F>>;
48
49    /// Solve `A X = B` for multiple right-hand sides (columns of B).
50    fn solve_multi(&self, b_columns: &[Vec<F>]) -> SparseResult<Vec<Vec<F>>> {
51        let mut results = Vec::with_capacity(b_columns.len());
52        for b in b_columns {
53            results.push(self.solve(b)?);
54        }
55        Ok(results)
56    }
57}
58
59// ---------------------------------------------------------------------------
60// Symbolic analysis result
61// ---------------------------------------------------------------------------
62
/// Result of the symbolic analysis phase.
///
/// Contains the fill-reducing permutation, the elimination tree, and the
/// predicted non-zero pattern of the factors. Produced by
/// [`symbolic_cholesky`]; intended to be consumed by a numeric
/// factorization phase.
#[derive(Debug, Clone)]
pub struct SymbolicAnalysis {
    /// Fill-reducing permutation (row/column index mapping).
    pub perm: Vec<usize>,
    /// Inverse permutation.
    pub perm_inv: Vec<usize>,
    /// Elimination tree: `parent[i]` is the parent of node i, or `usize::MAX` for roots.
    pub etree: Vec<usize>,
    /// Column pointers for the L factor non-zero pattern.
    pub l_colptr: Vec<usize>,
    /// Row indices for the L factor non-zero pattern.
    pub l_rowind: Vec<usize>,
    /// Column pointers for the U factor non-zero pattern (LU only).
    pub u_colptr: Vec<usize>,
    /// Row indices for the U factor non-zero pattern (LU only).
    pub u_rowind: Vec<usize>,
    /// Matrix dimension.
    pub n: usize,
}
85
86// ---------------------------------------------------------------------------
87// AMD Ordering
88// ---------------------------------------------------------------------------
89
90/// Approximate Minimum Degree (AMD) ordering.
91///
92/// Computes a fill-reducing permutation for a symmetric matrix (or the
93/// sparsity pattern A + A^T for an unsymmetric matrix). The algorithm
94/// greedily selects the node with the smallest approximate external degree
95/// at each step.
96///
97/// Returns a permutation vector `perm` such that `A[perm, perm]` has
98/// reduced fill during factorization.
99pub fn amd_ordering<F>(matrix: &CsrMatrix<F>) -> SparseResult<Vec<usize>>
100where
101    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
102{
103    let n = matrix.rows();
104    if n != matrix.cols() {
105        return Err(SparseError::ValueError(
106            "AMD ordering requires a square matrix".to_string(),
107        ));
108    }
109    if n == 0 {
110        return Ok(Vec::new());
111    }
112
113    // Build adjacency list for A + A^T (symmetric structure)
114    let mut adj: Vec<BTreeSet<usize>> = vec![BTreeSet::new(); n];
115    for i in 0..n {
116        let range = i_row_range(matrix, i);
117        for idx in range {
118            let j = matrix.indices[idx];
119            if i != j {
120                adj[i].insert(j);
121                adj[j].insert(i);
122            }
123        }
124    }
125
126    // Degree of each node
127    let mut degree: Vec<usize> = (0..n).map(|i| adj[i].len()).collect();
128    let mut eliminated = vec![false; n];
129    let mut perm = Vec::with_capacity(n);
130
131    for _ in 0..n {
132        // Find node with minimum approximate degree among non-eliminated nodes
133        let mut min_deg = usize::MAX;
134        let mut pivot = 0;
135        for (node, &deg) in degree.iter().enumerate() {
136            if !eliminated[node] && deg < min_deg {
137                min_deg = deg;
138                pivot = node;
139            }
140        }
141
142        eliminated[pivot] = true;
143        perm.push(pivot);
144
145        // Collect neighbours of pivot that are not yet eliminated
146        let neighbours: Vec<usize> = adj[pivot]
147            .iter()
148            .copied()
149            .filter(|&nb| !eliminated[nb])
150            .collect();
151
152        // "Absorb" pivot: connect all its neighbours to each other (clique)
153        for i in 0..neighbours.len() {
154            let u = neighbours[i];
155            adj[u].remove(&pivot);
156            for j in (i + 1)..neighbours.len() {
157                let v = neighbours[j];
158                adj[u].insert(v);
159                adj[v].insert(u);
160            }
161            degree[u] = adj[u].iter().filter(|&&nb| !eliminated[nb]).count();
162        }
163    }
164
165    Ok(perm)
166}
167
/// Compute the inverse permutation.
///
/// If `perm[i] == p`, the returned vector satisfies `inv[p] == i`.
pub fn inverse_perm(perm: &[usize]) -> Vec<usize> {
    let mut inv = vec![0usize; perm.len()];
    for (position, &target) in perm.iter().enumerate() {
        inv[target] = position;
    }
    inv
}
177
178// ---------------------------------------------------------------------------
179// Nested Dissection Ordering
180// ---------------------------------------------------------------------------
181
182/// Nested dissection ordering.
183///
184/// Recursively bisects the graph of the matrix using a simple graph
185/// partitioning heuristic (BFS-based bisection), numbering separators
186/// last. This is effective for matrices arising from 2D/3D discretisations.
187pub fn nested_dissection_ordering<F>(matrix: &CsrMatrix<F>) -> SparseResult<Vec<usize>>
188where
189    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
190{
191    let n = matrix.rows();
192    if n != matrix.cols() {
193        return Err(SparseError::ValueError(
194            "Nested dissection requires a square matrix".to_string(),
195        ));
196    }
197    if n == 0 {
198        return Ok(Vec::new());
199    }
200
201    // Build adjacency list for A + A^T
202    let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
203    for i in 0..n {
204        let range = i_row_range(matrix, i);
205        for idx in range {
206            let j = matrix.indices[idx];
207            if i != j {
208                if !adj[i].contains(&j) {
209                    adj[i].push(j);
210                }
211                if !adj[j].contains(&i) {
212                    adj[j].push(i);
213                }
214            }
215        }
216    }
217
218    let nodes: Vec<usize> = (0..n).collect();
219    let mut perm = Vec::with_capacity(n);
220    nd_recurse(&adj, &nodes, &mut perm);
221
222    if perm.len() != n {
223        let in_perm: BTreeSet<usize> = perm.iter().copied().collect();
224        for i in 0..n {
225            if !in_perm.contains(&i) {
226                perm.push(i);
227            }
228        }
229    }
230
231    Ok(perm)
232}
233
234fn nd_recurse(adj: &[Vec<usize>], nodes: &[usize], perm: &mut Vec<usize>) {
235    if nodes.len() <= 64 {
236        perm.extend_from_slice(nodes);
237        return;
238    }
239
240    let start = find_pseudo_peripheral(adj, nodes);
241    let (part_a, separator, part_b) = bfs_bisect(adj, nodes, start);
242
243    if !part_a.is_empty() {
244        nd_recurse(adj, &part_a, perm);
245    }
246    if !part_b.is_empty() {
247        nd_recurse(adj, &part_b, perm);
248    }
249    perm.extend_from_slice(&separator);
250}
251
252fn find_pseudo_peripheral(adj: &[Vec<usize>], nodes: &[usize]) -> usize {
253    if nodes.is_empty() {
254        return 0;
255    }
256    let node_set: BTreeSet<usize> = nodes.iter().copied().collect();
257    let mut current = nodes[0];
258    for _ in 0..2 {
259        let levels = bfs_levels(adj, current, &node_set);
260        if let Some(last_level) = levels.last() {
261            if !last_level.is_empty() {
262                current = last_level[0];
263            }
264        }
265    }
266    current
267}
268
/// Breadth-first traversal restricted to `allowed`, returning nodes
/// grouped by distance from `start`. The start node is always placed in
/// level 0, even if it is absent from `allowed`.
fn bfs_levels(adj: &[Vec<usize>], start: usize, allowed: &BTreeSet<usize>) -> Vec<Vec<usize>> {
    let mut seen: BTreeSet<usize> = BTreeSet::new();
    seen.insert(start);
    let mut levels: Vec<Vec<usize>> = vec![vec![start]];

    loop {
        let frontier = levels.last().cloned().unwrap_or_default();
        let mut next: Vec<usize> = Vec::new();
        for &node in &frontier {
            for &nb in &adj[node] {
                // `insert` returns true only for newly discovered nodes,
                // which both deduplicates and marks them visited.
                if allowed.contains(&nb) && seen.insert(nb) {
                    next.push(nb);
                }
            }
        }
        if next.is_empty() {
            return levels;
        }
        levels.push(next);
    }
}
296
297fn bfs_bisect(
298    adj: &[Vec<usize>],
299    nodes: &[usize],
300    start: usize,
301) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
302    let node_set: BTreeSet<usize> = nodes.iter().copied().collect();
303    let levels = bfs_levels(adj, start, &node_set);
304
305    let total = nodes.len();
306    let half = total / 2;
307
308    let mut count = 0;
309    let mut cut_level = 0;
310    for (li, level) in levels.iter().enumerate() {
311        count += level.len();
312        if count >= half {
313            cut_level = li;
314            break;
315        }
316    }
317
318    let mut part_a = Vec::new();
319    let mut separator = Vec::new();
320    let mut part_b = Vec::new();
321
322    for (li, level) in levels.iter().enumerate() {
323        if li < cut_level {
324            part_a.extend_from_slice(level);
325        } else if li == cut_level {
326            separator.extend_from_slice(level);
327        } else {
328            part_b.extend_from_slice(level);
329        }
330    }
331
332    let reached: BTreeSet<usize> = part_a
333        .iter()
334        .chain(separator.iter())
335        .chain(part_b.iter())
336        .copied()
337        .collect();
338    for &node in nodes {
339        if !reached.contains(&node) {
340            part_b.push(node);
341        }
342    }
343
344    (part_a, separator, part_b)
345}
346
347// ---------------------------------------------------------------------------
348// Elimination tree
349// ---------------------------------------------------------------------------
350
/// Compute the elimination tree of a symmetric matrix.
///
/// The elimination tree encodes column dependencies during factorization:
/// column `i` updates its parent column `parent[i]`. Roots are marked
/// with `usize::MAX`. `perm` gives the ordering applied to the matrix:
/// logical index `k` corresponds to original row/column `perm[k]`.
///
/// Uses the ancestor-based construction with path compression, giving
/// near-linear cost in the number of non-zeros.
pub fn elimination_tree<F>(matrix: &CsrMatrix<F>, perm: &[usize]) -> Vec<usize>
where
    F: Float + SparseElement + Debug + 'static,
{
    let n = matrix.rows();
    let perm_inv = inverse_perm(perm);
    // parent[i] == usize::MAX means "no parent assigned yet" (root).
    let mut parent = vec![usize::MAX; n];
    // ancestor[] implements path compression: it shortcuts chains of
    // already-processed nodes up toward the current column k.
    let mut ancestor = vec![0usize; n];

    for k in 0..n {
        ancestor[k] = k;
        let orig_row = perm[k];
        let range = i_row_range(matrix, orig_row);
        for idx in range {
            let orig_col = matrix.indices[idx];
            let j = perm_inv[orig_col];
            // Only entries strictly below the diagonal (in permuted order)
            // create dependencies on earlier columns.
            if j < k {
                // Walk up from j, compressing the path so later walks jump
                // straight toward k.
                let mut node = j;
                loop {
                    let next = ancestor[node];
                    if next == k {
                        break; // already linked to the current column
                    }
                    ancestor[node] = k;
                    // Defensive guard: only (re)assign a parent that is
                    // unset or would decrease — in practice parents are
                    // assigned once with increasing k.
                    if parent[node] == usize::MAX || parent[node] > k {
                        parent[node] = k;
                    }
                    if next == node {
                        break; // reached a root of the partial forest
                    }
                    node = next;
                }
            }
        }
    }
    parent
}
389
390// ---------------------------------------------------------------------------
391// Symbolic Cholesky
392// ---------------------------------------------------------------------------
393
/// Perform symbolic analysis for Cholesky factorization.
///
/// For the permuted matrix `P*A*P^T`, computes the elimination tree and a
/// per-column non-zero count of the factor L (by walking elimination-tree
/// paths), from which the column pointers `l_colptr` are built.
///
/// NOTE(review): `l_rowind` is allocated to the predicted size but left
/// zero-filled — the concrete row-index pattern is not computed here.
/// Confirm that downstream consumers rely only on `l_colptr`.
pub fn symbolic_cholesky<F>(matrix: &CsrMatrix<F>, perm: &[usize]) -> SparseResult<SymbolicAnalysis>
where
    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
{
    let n = matrix.rows();
    if n != matrix.cols() {
        return Err(SparseError::ValueError(
            "Symbolic Cholesky requires a square matrix".to_string(),
        ));
    }
    let perm_inv = inverse_perm(perm);
    let etree = elimination_tree(matrix, perm);

    // Every column holds at least its diagonal entry.
    let mut l_col_count = vec![1usize; n];
    // visited[node] == k marks that `node` was already counted for row k,
    // so shared elimination-tree paths are not double-counted.
    let mut visited = vec![usize::MAX; n];

    for k in 0..n {
        visited[k] = k;
        let orig_row = perm[k];
        let range = i_row_range(matrix, orig_row);
        for idx in range {
            let orig_col = matrix.indices[idx];
            let j = perm_inv[orig_col];
            if j < k {
                // Row k of L is non-zero in every column on the path from
                // j up the elimination tree, until an already-visited node
                // or a root is reached.
                let mut node = j;
                while visited[node] != k {
                    visited[node] = k;
                    l_col_count[node] += 1;
                    if etree[node] == usize::MAX || etree[node] >= n {
                        break;
                    }
                    node = etree[node];
                }
            }
        }
    }

    // Prefix-sum the per-column counts into CSC-style column pointers.
    let mut l_colptr = vec![0usize; n + 1];
    for j in 0..n {
        l_colptr[j + 1] = l_colptr[j] + l_col_count[j];
    }
    let total_nnz = l_colptr[n];
    // Placeholder row indices (zero-filled); see note in the doc comment.
    let l_rowind = vec![0usize; total_nnz];

    Ok(SymbolicAnalysis {
        perm: perm.to_vec(),
        perm_inv,
        etree,
        l_colptr,
        l_rowind,
        u_colptr: Vec::new(),
        u_rowind: Vec::new(),
        n,
    })
}
450
451// ---------------------------------------------------------------------------
452// Sparse Cholesky factorization (dense kernel)
453// ---------------------------------------------------------------------------
454
/// Result of sparse Cholesky factorization (L * L^T = P*A*P^T).
///
/// The factor is stored densely: the current implementation applies a
/// fill-reducing permutation and then runs a dense kernel.
#[derive(Debug, Clone)]
pub struct SparseCholResult<F> {
    /// Dense lower-triangular factor L (row-major, n x n).
    pub l_dense: Vec<Vec<F>>,
    /// Permutation vector.
    pub perm: Vec<usize>,
    /// Inverse permutation.
    pub perm_inv: Vec<usize>,
    /// Dimension.
    pub n: usize,
}
467
/// Sparse Cholesky solver for symmetric positive definite matrices.
///
/// Follows the factorize-then-solve paradigm of [`SparseSolver`].
pub struct SparseCholeskySolver<F> {
    // Populated by `factorize`; `None` until a successful factorization.
    result: Option<SparseCholResult<F>>,
}
472
473impl<F: Float + NumAssign + Sum + SparseElement + Debug + 'static> SparseCholeskySolver<F> {
474    /// Create a new Cholesky solver (unfactorized).
475    pub fn new() -> Self {
476        Self { result: None }
477    }
478
479    /// Access the factorization result (if available).
480    pub fn factorization(&self) -> Option<&SparseCholResult<F>> {
481        self.result.as_ref()
482    }
483}
484
485impl<F: Float + NumAssign + Sum + SparseElement + Debug + 'static> Default
486    for SparseCholeskySolver<F>
487{
488    fn default() -> Self {
489        Self::new()
490    }
491}
492
493impl<F: Float + NumAssign + Sum + SparseElement + Debug + 'static> SparseSolver<F>
494    for SparseCholeskySolver<F>
495{
496    fn factorize(&mut self, matrix: &CsrMatrix<F>) -> SparseResult<()> {
497        let n = matrix.rows();
498        if n != matrix.cols() {
499            return Err(SparseError::ValueError(
500                "Cholesky requires a square matrix".to_string(),
501            ));
502        }
503        if n == 0 {
504            self.result = Some(SparseCholResult {
505                l_dense: Vec::new(),
506                perm: Vec::new(),
507                perm_inv: Vec::new(),
508                n: 0,
509            });
510            return Ok(());
511        }
512
513        // AMD ordering
514        let perm = amd_ordering(matrix)?;
515        let perm_inv = inverse_perm(&perm);
516
517        // Build dense permuted matrix B = P * A * P^T
518        let mut b_dense = vec![vec![F::sparse_zero(); n]; n];
519        for i in 0..n {
520            let orig_row = perm[i];
521            let range = i_row_range(matrix, orig_row);
522            for idx in range {
523                let orig_col = matrix.indices[idx];
524                let j = perm_inv[orig_col];
525                b_dense[i][j] += matrix.data[idx];
526            }
527        }
528
529        // Dense Cholesky: L * L^T = B (row-by-row, lower triangular)
530        let mut l = vec![vec![F::sparse_zero(); n]; n];
531        for i in 0..n {
532            for j in 0..=i {
533                let mut sum = b_dense[i][j];
534                for k in 0..j {
535                    sum -= l[i][k] * l[j][k];
536                }
537                if i == j {
538                    if sum <= F::sparse_zero() {
539                        return Err(SparseError::ValueError(format!(
540                            "Matrix is not positive definite: non-positive diagonal at row {i}"
541                        )));
542                    }
543                    l[i][j] = sum.sqrt();
544                } else {
545                    let ljj = l[j][j];
546                    if ljj.abs() < F::epsilon() {
547                        return Err(SparseError::SingularMatrix(format!(
548                            "Zero diagonal in L at row {j}"
549                        )));
550                    }
551                    l[i][j] = sum / ljj;
552                }
553            }
554        }
555
556        self.result = Some(SparseCholResult {
557            l_dense: l,
558            perm,
559            perm_inv,
560            n,
561        });
562        Ok(())
563    }
564
565    fn solve(&self, b: &[F]) -> SparseResult<Vec<F>> {
566        let res = self.result.as_ref().ok_or_else(|| {
567            SparseError::ValueError("Cholesky factorization not computed".to_string())
568        })?;
569        let n = res.n;
570        if b.len() != n {
571            return Err(SparseError::DimensionMismatch {
572                expected: n,
573                found: b.len(),
574            });
575        }
576        if n == 0 {
577            return Ok(Vec::new());
578        }
579
580        // bp[i] = b[perm[i]]
581        let mut y = vec![F::sparse_zero(); n];
582        for i in 0..n {
583            y[i] = b[res.perm[i]];
584        }
585
586        // Forward solve: L y = bp
587        for i in 0..n {
588            for j in 0..i {
589                y[i] = y[i] - res.l_dense[i][j] * y[j];
590            }
591            let d = res.l_dense[i][i];
592            if d.abs() < F::epsilon() {
593                return Err(SparseError::SingularMatrix(
594                    "Zero diagonal in L during solve".to_string(),
595                ));
596            }
597            y[i] /= d;
598        }
599
600        // Backward solve: L^T xp = y
601        for i in (0..n).rev() {
602            for j in (i + 1)..n {
603                y[i] = y[i] - res.l_dense[j][i] * y[j];
604            }
605            let d = res.l_dense[i][i];
606            if d.abs() < F::epsilon() {
607                return Err(SparseError::SingularMatrix(
608                    "Zero diagonal in L^T during solve".to_string(),
609                ));
610            }
611            y[i] /= d;
612        }
613
614        // x[perm[i]] = y[i]
615        let mut x = vec![F::sparse_zero(); n];
616        for i in 0..n {
617            x[res.perm[i]] = y[i];
618        }
619        Ok(x)
620    }
621}
622
623// ---------------------------------------------------------------------------
624// Sparse LU factorization (dense kernel)
625// ---------------------------------------------------------------------------
626
/// Result of sparse LU factorization (P*A*Q = L*U).
///
/// Both factors live in one dense matrix: entries strictly below the
/// diagonal belong to L (its unit diagonal is implicit), entries on and
/// above the diagonal belong to U.
#[derive(Debug, Clone)]
pub struct SparseLuResult<F> {
    /// Dense LU factors in-place (L below diagonal with unit diagonal, U on+above diagonal).
    pub lu_dense: Vec<Vec<F>>,
    /// Row permutation (pivoting).
    pub row_perm: Vec<usize>,
    /// Column permutation (fill-reducing ordering).
    pub col_perm: Vec<usize>,
    /// Dimension.
    pub n: usize,
}
639
/// Sparse LU solver with partial pivoting.
///
/// Follows the factorize-then-solve paradigm of [`SparseSolver`].
pub struct SparseLuSolver<F> {
    // Populated by `factorize`; `None` until a successful factorization.
    result: Option<SparseLuResult<F>>,
}
644
645impl<F: Float + NumAssign + Sum + SparseElement + Debug + 'static> SparseLuSolver<F> {
646    /// Create a new LU solver (unfactorized).
647    pub fn new() -> Self {
648        Self { result: None }
649    }
650
651    /// Access the factorization result (if available).
652    pub fn factorization(&self) -> Option<&SparseLuResult<F>> {
653        self.result.as_ref()
654    }
655}
656
657impl<F: Float + NumAssign + Sum + SparseElement + Debug + 'static> Default for SparseLuSolver<F> {
658    fn default() -> Self {
659        Self::new()
660    }
661}
662
impl<F: Float + NumAssign + Sum + SparseElement + Debug + 'static> SparseSolver<F>
    for SparseLuSolver<F>
{
    /// Factor `A` so that `P*A*Q = L*U` using a dense kernel.
    ///
    /// `Q` (column permutation) comes from AMD ordering to limit fill;
    /// `P` (row permutation) comes from partial pivoting for stability.
    /// The factors are stored in one dense matrix: strictly below the
    /// diagonal is L (unit diagonal implied), on and above is U.
    fn factorize(&mut self, matrix: &CsrMatrix<F>) -> SparseResult<()> {
        let n = matrix.rows();
        if n != matrix.cols() {
            return Err(SparseError::ValueError(
                "LU requires a square matrix".to_string(),
            ));
        }
        if n == 0 {
            // Trivial factorization of the empty matrix.
            self.result = Some(SparseLuResult {
                lu_dense: Vec::new(),
                row_perm: Vec::new(),
                col_perm: Vec::new(),
                n: 0,
            });
            return Ok(());
        }

        // Column ordering (AMD)
        let col_perm = amd_ordering(matrix)?;
        let col_perm_inv = inverse_perm(&col_perm);

        // Build dense matrix: a[i][j] = A[i][col_perm[j]]
        // (duplicate CSR entries are accumulated via +=).
        let mut a = vec![vec![F::sparse_zero(); n]; n];
        for i in 0..n {
            let range = i_row_range(matrix, i);
            for idx in range {
                let orig_col = matrix.indices[idx];
                let j = col_perm_inv[orig_col];
                a[i][j] += matrix.data[idx];
            }
        }

        // Dense LU with partial pivoting (in-place)
        let mut row_perm: Vec<usize> = (0..n).collect();

        for k in 0..n {
            // Find pivot: largest |a[i][k]| on or below the diagonal.
            let mut max_abs = F::sparse_zero();
            let mut pivot = k;
            for i in k..n {
                if a[i][k].abs() > max_abs {
                    max_abs = a[i][k].abs();
                    pivot = i;
                }
            }

            if pivot != k {
                a.swap(k, pivot);
                row_perm.swap(k, pivot);
            }

            let akk = a[k][k];
            if akk.abs() < F::epsilon() {
                // Skip elimination for this column; `solve` will later
                // report SingularMatrix when it hits this zero U diagonal.
                continue; // near-singular column
            }

            // Eliminate below the pivot, storing multipliers in the L part.
            for i in (k + 1)..n {
                let lik = a[i][k] / akk;
                a[i][k] = lik; // L part
                for j in (k + 1)..n {
                    let ukj = a[k][j];
                    a[i][j] -= lik * ukj;
                }
            }
        }

        self.result = Some(SparseLuResult {
            lu_dense: a,
            row_perm,
            col_perm,
            n,
        });
        Ok(())
    }

    /// Solve `A x = b` using the stored factors: permute `b` by P,
    /// forward-solve with unit-diagonal L, back-solve with U, then undo
    /// the column permutation Q.
    fn solve(&self, b: &[F]) -> SparseResult<Vec<F>> {
        let res = self
            .result
            .as_ref()
            .ok_or_else(|| SparseError::ValueError("LU factorization not computed".to_string()))?;
        let n = res.n;
        if b.len() != n {
            return Err(SparseError::DimensionMismatch {
                expected: n,
                found: b.len(),
            });
        }
        if n == 0 {
            return Ok(Vec::new());
        }

        // Apply row permutation
        let mut x = vec![F::sparse_zero(); n];
        for i in 0..n {
            x[i] = b[res.row_perm[i]];
        }

        // Forward solve: L y = Pb (unit diagonal)
        for i in 0..n {
            for j in 0..i {
                x[i] = x[i] - res.lu_dense[i][j] * x[j];
            }
        }

        // Backward solve: U z = y
        for i in (0..n).rev() {
            for j in (i + 1)..n {
                x[i] = x[i] - res.lu_dense[i][j] * x[j];
            }
            let d = res.lu_dense[i][i];
            if d.abs() < F::epsilon() {
                return Err(SparseError::SingularMatrix(format!(
                    "Zero diagonal in U at row {i}"
                )));
            }
            x[i] /= d;
        }

        // Apply inverse column permutation: result[col_perm[j]] = x[j]
        let mut result = vec![F::sparse_zero(); n];
        for j in 0..n {
            result[res.col_perm[j]] = x[j];
        }
        Ok(result)
    }
}
792
793// ---------------------------------------------------------------------------
794// Convenience functions
795// ---------------------------------------------------------------------------
796
797/// Solve Ax = b using sparse LU factorization with AMD ordering.
798pub fn sparse_lu_solve<F>(matrix: &CsrMatrix<F>, b: &[F]) -> SparseResult<Vec<F>>
799where
800    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
801{
802    let mut solver = SparseLuSolver::new();
803    solver.factorize(matrix)?;
804    solver.solve(b)
805}
806
807/// Solve Ax = b using sparse Cholesky (matrix must be SPD).
808pub fn sparse_cholesky_solve<F>(matrix: &CsrMatrix<F>, b: &[F]) -> SparseResult<Vec<F>>
809where
810    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
811{
812    let mut solver = SparseCholeskySolver::new();
813    solver.factorize(matrix)?;
814    solver.solve(b)
815}
816
817// ---------------------------------------------------------------------------
818// Internal helpers
819// ---------------------------------------------------------------------------
820
821/// Safe row range extraction for CsrMatrix.
822fn i_row_range<F: SparseElement + Clone + Copy + scirs2_core::numeric::Zero + PartialEq>(
823    matrix: &CsrMatrix<F>,
824    row: usize,
825) -> std::ops::Range<usize> {
826    if row >= matrix.rows() {
827        return 0..0;
828    }
829    matrix.indptr[row]..matrix.indptr[row + 1]
830}
831
832// ---------------------------------------------------------------------------
833// Tests
834// ---------------------------------------------------------------------------
835
836#[cfg(test)]
837mod tests {
838    use super::*;
839
840    /// A = [[4, 1, 0], [1, 5, 2], [0, 2, 6]]
841    fn create_spd_3x3() -> CsrMatrix<f64> {
842        let rows = vec![0, 0, 1, 1, 1, 2, 2];
843        let cols = vec![0, 1, 0, 1, 2, 1, 2];
844        let data = vec![4.0, 1.0, 1.0, 5.0, 2.0, 2.0, 6.0];
845        CsrMatrix::new(data, rows, cols, (3, 3)).expect("Failed to create SPD matrix")
846    }
847
848    fn create_general_3x3() -> CsrMatrix<f64> {
849        let rows = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];
850        let cols = vec![0, 1, 2, 0, 1, 2, 0, 1, 2];
851        let data = vec![3.0, 1.0, 2.0, 1.0, 4.0, 1.0, 0.0, 1.0, 5.0];
852        CsrMatrix::new(data, rows, cols, (3, 3)).expect("Failed to create matrix")
853    }
854
855    fn create_spd_4x4() -> CsrMatrix<f64> {
856        let rows = vec![0, 0, 1, 1, 1, 2, 2, 2, 3, 3];
857        let cols = vec![0, 1, 0, 1, 2, 1, 2, 3, 2, 3];
858        let data = vec![4.0, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0];
859        CsrMatrix::new(data, rows, cols, (4, 4)).expect("Failed to create SPD 4x4")
860    }
861
862    fn verify_solve(mat: &CsrMatrix<f64>, x: &[f64], b: &[f64], tol: f64) {
863        let dense = mat.to_dense();
864        let n = b.len();
865        for i in 0..n {
866            let mut row_sum = 0.0;
867            for j in 0..n {
868                row_sum += dense[i][j] * x[j];
869            }
870            assert!(
871                (row_sum - b[i]).abs() < tol,
872                "Row {i}: residual {}",
873                (row_sum - b[i]).abs()
874            );
875        }
876    }
877
878    #[test]
879    fn test_amd_ordering_basic() {
880        let mat = create_spd_3x3();
881        let perm = amd_ordering(&mat).expect("AMD failed");
882        assert_eq!(perm.len(), 3);
883        let mut sorted = perm.clone();
884        sorted.sort();
885        assert_eq!(sorted, vec![0, 1, 2]);
886    }
887
888    #[test]
889    fn test_amd_ordering_empty() {
890        let mat =
891            CsrMatrix::<f64>::new(vec![], vec![], vec![], (0, 0)).expect("Failed to create empty");
892        let perm = amd_ordering(&mat).expect("AMD failed on empty");
893        assert!(perm.is_empty());
894    }
895
896    #[test]
897    fn test_inverse_perm() {
898        let perm = vec![2, 0, 1];
899        let inv = inverse_perm(&perm);
900        assert_eq!(inv, vec![1, 2, 0]);
901        for i in 0..3 {
902            assert_eq!(perm[inv[i]], i);
903        }
904    }
905
906    #[test]
907    fn test_nested_dissection_basic() {
908        let mat = create_spd_4x4();
909        let perm = nested_dissection_ordering(&mat).expect("ND failed");
910        assert_eq!(perm.len(), 4);
911        let mut sorted = perm.clone();
912        sorted.sort();
913        assert_eq!(sorted, vec![0, 1, 2, 3]);
914    }
915
916    #[test]
917    fn test_elimination_tree() {
918        let mat = create_spd_3x3();
919        let perm: Vec<usize> = (0..3).collect();
920        let etree = elimination_tree(&mat, &perm);
921        assert_eq!(etree.len(), 3);
922    }
923
924    #[test]
925    fn test_cholesky_solve_3x3() {
926        let mat = create_spd_3x3();
927        let b = vec![5.0, 8.0, 8.0];
928        let x = sparse_cholesky_solve(&mat, &b).expect("Cholesky solve failed");
929        assert_eq!(x.len(), 3);
930        for (i, &xi) in x.iter().enumerate() {
931            assert!((xi - 1.0).abs() < 1e-10, "x[{i}] = {xi}, expected 1.0");
932        }
933    }
934
935    #[test]
936    fn test_cholesky_solve_4x4() {
937        let mat = create_spd_4x4();
938        let b = vec![5.0, 6.0, 6.0, 5.0];
939        let x = sparse_cholesky_solve(&mat, &b).expect("Cholesky solve 4x4 failed");
940        verify_solve(&mat, &x, &b, 1e-10);
941    }
942
943    #[test]
944    fn test_cholesky_non_spd() {
945        let rows = vec![0, 1, 2];
946        let cols = vec![0, 1, 2];
947        let data = vec![-1.0, 1.0, 1.0];
948        let mat = CsrMatrix::new(data, rows, cols, (3, 3)).expect("Failed to create matrix");
949        let result = sparse_cholesky_solve(&mat, &[1.0, 1.0, 1.0]);
950        assert!(result.is_err());
951    }
952
953    #[test]
954    fn test_lu_solve_3x3() {
955        let mat = create_general_3x3();
956        let b = vec![6.0, 6.0, 6.0];
957        let x = sparse_lu_solve(&mat, &b).expect("LU solve failed");
958        verify_solve(&mat, &x, &b, 1e-9);
959    }
960
961    #[test]
962    fn test_lu_solve_identity() {
963        let rows = vec![0, 1, 2, 3];
964        let cols = vec![0, 1, 2, 3];
965        let data = vec![1.0, 1.0, 1.0, 1.0];
966        let mat = CsrMatrix::new(data, rows, cols, (4, 4)).expect("Failed to create identity");
967        let b = vec![1.0, 2.0, 3.0, 4.0];
968        let x = sparse_lu_solve(&mat, &b).expect("LU solve on identity failed");
969        for i in 0..4 {
970            assert!(
971                (x[i] - b[i]).abs() < 1e-12,
972                "x[{i}] = {}, expected {}",
973                x[i],
974                b[i]
975            );
976        }
977    }
978
979    #[test]
980    fn test_lu_solve_multi() {
981        let mat = create_general_3x3();
982        let mut solver = SparseLuSolver::new();
983        solver.factorize(&mat).expect("LU factorize failed");
984
985        let b1 = vec![6.0, 6.0, 6.0];
986        let b2 = vec![3.0, 1.0, 2.0];
987        let results = solver
988            .solve_multi(&[b1.clone(), b2.clone()])
989            .expect("Solve multi failed");
990
991        verify_solve(&mat, &results[0], &b1, 1e-9);
992        verify_solve(&mat, &results[1], &b2, 1e-9);
993    }
994
995    #[test]
996    fn test_cholesky_solver_trait() {
997        let mat = create_spd_3x3();
998        let mut solver = SparseCholeskySolver::new();
999        solver.factorize(&mat).expect("Factorize failed");
1000        assert!(solver.factorization().is_some());
1001
1002        let b = vec![5.0, 8.0, 8.0];
1003        let x = solver.solve(&b).expect("Solve failed");
1004        for (i, xi) in x.iter().enumerate() {
1005            assert!((xi - 1.0).abs() < 1e-10, "x[{i}] = {xi}");
1006        }
1007    }
1008
1009    #[test]
1010    fn test_lu_empty_matrix() {
1011        let mat =
1012            CsrMatrix::<f64>::new(vec![], vec![], vec![], (0, 0)).expect("Failed to create empty");
1013        let mut solver = SparseLuSolver::new();
1014        solver
1015            .factorize(&mat)
1016            .expect("LU factorize on empty failed");
1017        let x = solver.solve(&[]).expect("LU solve on empty failed");
1018        assert!(x.is_empty());
1019    }
1020
1021    #[test]
1022    fn test_cholesky_dimension_mismatch() {
1023        let mat = create_spd_3x3();
1024        let mut solver = SparseCholeskySolver::new();
1025        solver.factorize(&mat).expect("Factorize failed");
1026        let result = solver.solve(&[1.0, 2.0]);
1027        assert!(result.is_err());
1028    }
1029
1030    #[test]
1031    fn test_lu_solve_5x5_diag_dominant() {
1032        let mut rows = Vec::new();
1033        let mut cols = Vec::new();
1034        let mut data = Vec::new();
1035        for i in 0..5 {
1036            for j in 0..5 {
1037                if i == j {
1038                    rows.push(i);
1039                    cols.push(j);
1040                    data.push(10.0);
1041                } else if (i as isize - j as isize).unsigned_abs() <= 1 {
1042                    rows.push(i);
1043                    cols.push(j);
1044                    data.push(1.0);
1045                }
1046            }
1047        }
1048        let mat = CsrMatrix::new(data, rows, cols, (5, 5)).expect("Failed to create 5x5");
1049        let b = vec![12.0, 12.0, 12.0, 12.0, 12.0];
1050        let x = sparse_lu_solve(&mat, &b).expect("LU 5x5 failed");
1051        verify_solve(&mat, &x, &b, 1e-8);
1052    }
1053
1054    #[test]
1055    fn test_symbolic_cholesky() {
1056        let mat = create_spd_3x3();
1057        let perm: Vec<usize> = (0..3).collect();
1058        let analysis = symbolic_cholesky(&mat, &perm).expect("Symbolic Cholesky failed");
1059        assert_eq!(analysis.n, 3);
1060        assert_eq!(analysis.l_colptr.len(), 4);
1061        assert!(analysis.l_colptr[3] >= 3);
1062    }
1063
1064    #[test]
1065    fn test_lu_non_square_error() {
1066        let rows = vec![0, 1];
1067        let cols = vec![0, 0];
1068        let data = vec![1.0, 2.0];
1069        let mat = CsrMatrix::new(data, rows, cols, (2, 3)).expect("Failed to create non-square");
1070        let result = sparse_lu_solve(&mat, &[1.0, 2.0]);
1071        assert!(result.is_err());
1072    }
1073
1074    #[test]
1075    fn test_cholesky_non_square_error() {
1076        let rows = vec![0, 1, 2];
1077        let cols = vec![0, 0, 0];
1078        let data = vec![1.0, 2.0, 3.0];
1079        let mat = CsrMatrix::new(data, rows, cols, (3, 4)).expect("Failed to create non-square");
1080        let result = sparse_cholesky_solve(&mat, &[1.0, 2.0, 3.0]);
1081        assert!(result.is_err());
1082    }
1083
1084    #[test]
1085    fn test_lu_solve_with_zeros() {
1086        let rows = vec![0, 0, 1, 2, 2];
1087        let cols = vec![0, 2, 1, 0, 2];
1088        let data = vec![2.0, 1.0, 3.0, 1.0, 4.0];
1089        let mat = CsrMatrix::new(data, rows, cols, (3, 3)).expect("Failed");
1090        let b = vec![3.0, 3.0, 5.0];
1091        let x = sparse_lu_solve(&mat, &b).expect("LU solve sparse matrix failed");
1092        verify_solve(&mat, &x, &b, 1e-9);
1093    }
1094
1095    #[test]
1096    fn test_amd_non_square_error() {
1097        let rows = vec![0];
1098        let cols = vec![0];
1099        let data = vec![1.0];
1100        let mat = CsrMatrix::new(data, rows, cols, (2, 3)).expect("Failed");
1101        let result = amd_ordering(&mat);
1102        assert!(result.is_err());
1103    }
1104}