1use crate::csr::CsrMatrix;
27use crate::error::{SparseError, SparseResult};
28use scirs2_core::numeric::{Float, NumAssign, SparseElement};
29use std::collections::BTreeSet;
30use std::fmt::Debug;
31use std::iter::Sum;
32
33pub trait SparseSolver<F: Float> {
43 fn factorize(&mut self, matrix: &CsrMatrix<F>) -> SparseResult<()>;
45
46 fn solve(&self, b: &[F]) -> SparseResult<Vec<F>>;
48
49 fn solve_multi(&self, b_columns: &[Vec<F>]) -> SparseResult<Vec<Vec<F>>> {
51 let mut results = Vec::with_capacity(b_columns.len());
52 for b in b_columns {
53 results.push(self.solve(b)?);
54 }
55 Ok(results)
56 }
57}
58
/// Result of the symbolic analysis phase of a sparse factorization:
/// permutations, the elimination tree, and preallocated sparsity-pattern
/// buffers for the triangular factors.
#[derive(Debug, Clone)]
pub struct SymbolicAnalysis {
    /// Fill-reducing permutation: position `k` holds original index `perm[k]`.
    pub perm: Vec<usize>,
    /// Inverse of `perm` (`perm_inv[perm[k]] == k`).
    pub perm_inv: Vec<usize>,
    /// Elimination-tree parent of each permuted column (`usize::MAX` = root).
    pub etree: Vec<usize>,
    /// Column pointers (CSC-style prefix sums) for the predicted `L` pattern.
    pub l_colptr: Vec<usize>,
    /// Row-index buffer for `L`; preallocated to the predicted nnz but
    /// zero-filled — `symbolic_cholesky` does not populate it.
    pub l_rowind: Vec<usize>,
    /// Column pointers for `U`; left empty by the Cholesky-only path.
    pub u_colptr: Vec<usize>,
    /// Row indices for `U`; left empty by the Cholesky-only path.
    pub u_rowind: Vec<usize>,
    /// Matrix dimension (number of rows/columns).
    pub n: usize,
}
85
86pub fn amd_ordering<F>(matrix: &CsrMatrix<F>) -> SparseResult<Vec<usize>>
100where
101 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
102{
103 let n = matrix.rows();
104 if n != matrix.cols() {
105 return Err(SparseError::ValueError(
106 "AMD ordering requires a square matrix".to_string(),
107 ));
108 }
109 if n == 0 {
110 return Ok(Vec::new());
111 }
112
113 let mut adj: Vec<BTreeSet<usize>> = vec![BTreeSet::new(); n];
115 for i in 0..n {
116 let range = i_row_range(matrix, i);
117 for idx in range {
118 let j = matrix.indices[idx];
119 if i != j {
120 adj[i].insert(j);
121 adj[j].insert(i);
122 }
123 }
124 }
125
126 let mut degree: Vec<usize> = (0..n).map(|i| adj[i].len()).collect();
128 let mut eliminated = vec![false; n];
129 let mut perm = Vec::with_capacity(n);
130
131 for _ in 0..n {
132 let mut min_deg = usize::MAX;
134 let mut pivot = 0;
135 for (node, °) in degree.iter().enumerate() {
136 if !eliminated[node] && deg < min_deg {
137 min_deg = deg;
138 pivot = node;
139 }
140 }
141
142 eliminated[pivot] = true;
143 perm.push(pivot);
144
145 let neighbours: Vec<usize> = adj[pivot]
147 .iter()
148 .copied()
149 .filter(|&nb| !eliminated[nb])
150 .collect();
151
152 for i in 0..neighbours.len() {
154 let u = neighbours[i];
155 adj[u].remove(&pivot);
156 for j in (i + 1)..neighbours.len() {
157 let v = neighbours[j];
158 adj[u].insert(v);
159 adj[v].insert(u);
160 }
161 degree[u] = adj[u].iter().filter(|&&nb| !eliminated[nb]).count();
162 }
163 }
164
165 Ok(perm)
166}
167
/// Inverts a permutation: if `perm[i] == p`, then the result maps `p` back
/// to `i` (`inverse_perm(perm)[perm[i]] == i`).
pub fn inverse_perm(perm: &[usize]) -> Vec<usize> {
    let mut inv = vec![0usize; perm.len()];
    perm.iter()
        .enumerate()
        .for_each(|(position, &target)| inv[target] = position);
    inv
}
177
178pub fn nested_dissection_ordering<F>(matrix: &CsrMatrix<F>) -> SparseResult<Vec<usize>>
188where
189 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
190{
191 let n = matrix.rows();
192 if n != matrix.cols() {
193 return Err(SparseError::ValueError(
194 "Nested dissection requires a square matrix".to_string(),
195 ));
196 }
197 if n == 0 {
198 return Ok(Vec::new());
199 }
200
201 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
203 for i in 0..n {
204 let range = i_row_range(matrix, i);
205 for idx in range {
206 let j = matrix.indices[idx];
207 if i != j {
208 if !adj[i].contains(&j) {
209 adj[i].push(j);
210 }
211 if !adj[j].contains(&i) {
212 adj[j].push(i);
213 }
214 }
215 }
216 }
217
218 let nodes: Vec<usize> = (0..n).collect();
219 let mut perm = Vec::with_capacity(n);
220 nd_recurse(&adj, &nodes, &mut perm);
221
222 if perm.len() != n {
223 let in_perm: BTreeSet<usize> = perm.iter().copied().collect();
224 for i in 0..n {
225 if !in_perm.contains(&i) {
226 perm.push(i);
227 }
228 }
229 }
230
231 Ok(perm)
232}
233
234fn nd_recurse(adj: &[Vec<usize>], nodes: &[usize], perm: &mut Vec<usize>) {
235 if nodes.len() <= 64 {
236 perm.extend_from_slice(nodes);
237 return;
238 }
239
240 let start = find_pseudo_peripheral(adj, nodes);
241 let (part_a, separator, part_b) = bfs_bisect(adj, nodes, start);
242
243 if !part_a.is_empty() {
244 nd_recurse(adj, &part_a, perm);
245 }
246 if !part_b.is_empty() {
247 nd_recurse(adj, &part_b, perm);
248 }
249 perm.extend_from_slice(&separator);
250}
251
252fn find_pseudo_peripheral(adj: &[Vec<usize>], nodes: &[usize]) -> usize {
253 if nodes.is_empty() {
254 return 0;
255 }
256 let node_set: BTreeSet<usize> = nodes.iter().copied().collect();
257 let mut current = nodes[0];
258 for _ in 0..2 {
259 let levels = bfs_levels(adj, current, &node_set);
260 if let Some(last_level) = levels.last() {
261 if !last_level.is_empty() {
262 current = last_level[0];
263 }
264 }
265 }
266 current
267}
268
/// Breadth-first search from `start`, restricted to the `allowed` node set.
/// Returns the BFS level structure: `levels[d]` holds the nodes at distance
/// `d` from `start` (level 0 is `[start]`); empty levels are never emitted.
fn bfs_levels(adj: &[Vec<usize>], start: usize, allowed: &BTreeSet<usize>) -> Vec<Vec<usize>> {
    let mut seen = BTreeSet::new();
    seen.insert(start);
    let mut levels: Vec<Vec<usize>> = vec![vec![start]];
    let mut frontier = vec![start];

    while !frontier.is_empty() {
        let mut next: Vec<usize> = Vec::new();
        for &node in &frontier {
            for &nb in &adj[node] {
                // `insert` returns false for already-seen nodes.
                if allowed.contains(&nb) && seen.insert(nb) {
                    next.push(nb);
                }
            }
        }
        if next.is_empty() {
            break;
        }
        levels.push(next.clone());
        frontier = next;
    }
    levels
}
296
297fn bfs_bisect(
298 adj: &[Vec<usize>],
299 nodes: &[usize],
300 start: usize,
301) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
302 let node_set: BTreeSet<usize> = nodes.iter().copied().collect();
303 let levels = bfs_levels(adj, start, &node_set);
304
305 let total = nodes.len();
306 let half = total / 2;
307
308 let mut count = 0;
309 let mut cut_level = 0;
310 for (li, level) in levels.iter().enumerate() {
311 count += level.len();
312 if count >= half {
313 cut_level = li;
314 break;
315 }
316 }
317
318 let mut part_a = Vec::new();
319 let mut separator = Vec::new();
320 let mut part_b = Vec::new();
321
322 for (li, level) in levels.iter().enumerate() {
323 if li < cut_level {
324 part_a.extend_from_slice(level);
325 } else if li == cut_level {
326 separator.extend_from_slice(level);
327 } else {
328 part_b.extend_from_slice(level);
329 }
330 }
331
332 let reached: BTreeSet<usize> = part_a
333 .iter()
334 .chain(separator.iter())
335 .chain(part_b.iter())
336 .copied()
337 .collect();
338 for &node in nodes {
339 if !reached.contains(&node) {
340 part_b.push(node);
341 }
342 }
343
344 (part_a, separator, part_b)
345}
346
/// Computes the elimination tree of `matrix` under the permutation `perm`.
///
/// `perm[k]` is the original index placed at permuted position `k`. The
/// returned vector gives, for each permuted column `j`, its parent in the
/// elimination tree, or `usize::MAX` if `j` is a root.
///
/// The walk resembles Liu's elimination-tree algorithm with path compression
/// through the `ancestor` array; only entries strictly below the diagonal in
/// permuted order contribute.
pub fn elimination_tree<F>(matrix: &CsrMatrix<F>, perm: &[usize]) -> Vec<usize>
where
    F: Float + SparseElement + Debug + 'static,
{
    let n = matrix.rows();
    let perm_inv = inverse_perm(perm);
    // parent[j] = etree parent of j; usize::MAX means "no parent assigned".
    let mut parent = vec![usize::MAX; n];
    // ancestor[] compresses paths: it points towards the most recently
    // processed row k reachable from each node.
    let mut ancestor = vec![0usize; n];

    for k in 0..n {
        ancestor[k] = k;
        let orig_row = perm[k];
        let range = i_row_range(matrix, orig_row);
        for idx in range {
            let orig_col = matrix.indices[idx];
            let j = perm_inv[orig_col];
            // Only strictly-lower entries (permuted column < current row) matter.
            if j < k {
                let mut node = j;
                loop {
                    let next = ancestor[node];
                    if next == k {
                        // Path already compressed to the current row; done.
                        break;
                    }
                    // Compress the path so future walks terminate quickly.
                    ancestor[node] = k;
                    // Since k only increases, the `> k` case should not occur
                    // in practice; the guard keeps the smallest parent.
                    if parent[node] == usize::MAX || parent[node] > k {
                        parent[node] = k;
                    }
                    if next == node {
                        // Reached the end of the compressed chain.
                        break;
                    }
                    node = next;
                }
            }
        }
    }
    parent
}
389
/// Performs a symbolic Cholesky analysis: predicts the per-column nonzero
/// counts of the factor `L` for `matrix` under the ordering `perm`, using
/// the elimination tree.
///
/// The returned [`SymbolicAnalysis`] has `l_colptr` built from the predicted
/// counts; `l_rowind` is only *preallocated* (zero-filled, not populated),
/// and the `U` pattern fields are left empty.
///
/// # Errors
/// Returns `SparseError::ValueError` if `matrix` is not square.
pub fn symbolic_cholesky<F>(matrix: &CsrMatrix<F>, perm: &[usize]) -> SparseResult<SymbolicAnalysis>
where
    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
{
    let n = matrix.rows();
    if n != matrix.cols() {
        return Err(SparseError::ValueError(
            "Symbolic Cholesky requires a square matrix".to_string(),
        ));
    }
    let perm_inv = inverse_perm(perm);
    let etree = elimination_tree(matrix, perm);

    // Every column has at least its diagonal entry.
    let mut l_col_count = vec![1usize; n];
    // visited[node] == k marks that `node` was already counted for row k.
    let mut visited = vec![usize::MAX; n];

    for k in 0..n {
        visited[k] = k;
        let orig_row = perm[k];
        let range = i_row_range(matrix, orig_row);
        for idx in range {
            let orig_col = matrix.indices[idx];
            let j = perm_inv[orig_col];
            if j < k {
                // Walk up the elimination tree from j towards row k; each
                // newly visited node gains one predicted nonzero.
                let mut node = j;
                while visited[node] != k {
                    visited[node] = k;
                    l_col_count[node] += 1;
                    if etree[node] == usize::MAX || etree[node] >= n {
                        break;
                    }
                    node = etree[node];
                }
            }
        }
    }

    // Prefix sums turn per-column counts into CSC-style column pointers.
    let mut l_colptr = vec![0usize; n + 1];
    for j in 0..n {
        l_colptr[j + 1] = l_colptr[j] + l_col_count[j];
    }
    let total_nnz = l_colptr[n];
    let l_rowind = vec![0usize; total_nnz];

    Ok(SymbolicAnalysis {
        perm: perm.to_vec(),
        perm_inv,
        etree,
        l_colptr,
        l_rowind,
        u_colptr: Vec::new(),
        u_rowind: Vec::new(),
        n,
    })
}
450
/// Output of [`SparseCholeskySolver::factorize`]: the dense lower-triangular
/// Cholesky factor together with the fill-reducing permutation applied.
#[derive(Debug, Clone)]
pub struct SparseCholResult<F> {
    /// Dense lower-triangular factor `L` of the permuted matrix (row-major).
    pub l_dense: Vec<Vec<F>>,
    /// Permutation: position `k` holds original index `perm[k]`.
    pub perm: Vec<usize>,
    /// Inverse permutation (`perm_inv[perm[k]] == k`).
    pub perm_inv: Vec<usize>,
    /// Matrix dimension.
    pub n: usize,
}
467
/// Direct solver for symmetric positive-definite systems via Cholesky
/// factorization (dense working storage) with an AMD fill-reducing ordering.
pub struct SparseCholeskySolver<F> {
    // Populated by `factorize`; `None` until a factorization succeeds.
    result: Option<SparseCholResult<F>>,
}
472
impl<F: Float + NumAssign + Sum + SparseElement + Debug + 'static> SparseCholeskySolver<F> {
    /// Creates a solver with no factorization computed yet.
    pub fn new() -> Self {
        Self { result: None }
    }

    /// Returns the stored factorization, or `None` if [`SparseSolver::factorize`]
    /// has not completed successfully.
    pub fn factorization(&self) -> Option<&SparseCholResult<F>> {
        self.result.as_ref()
    }
}
484
impl<F: Float + NumAssign + Sum + SparseElement + Debug + 'static> Default
    for SparseCholeskySolver<F>
{
    /// Equivalent to [`SparseCholeskySolver::new`].
    fn default() -> Self {
        Self::new()
    }
}
492
impl<F: Float + NumAssign + Sum + SparseElement + Debug + 'static> SparseSolver<F>
    for SparseCholeskySolver<F>
{
    /// Factorizes `matrix` as `P * A * P^T = L * L^T`, with `P` from an AMD
    /// ordering, and stores the dense lower factor `L`.
    ///
    /// NOTE(review): the permuted matrix is densified, so memory is O(n^2)
    /// regardless of sparsity. This assumes `matrix` is symmetric with both
    /// triangles stored — only stored entries are scattered; confirm callers
    /// pass full symmetric storage.
    ///
    /// # Errors
    /// - `ValueError` if the matrix is not square or a non-positive pivot is
    ///   found (i.e. the matrix is not positive definite).
    /// - `SingularMatrix` if a (numerically) zero diagonal appears in `L`.
    fn factorize(&mut self, matrix: &CsrMatrix<F>) -> SparseResult<()> {
        let n = matrix.rows();
        if n != matrix.cols() {
            return Err(SparseError::ValueError(
                "Cholesky requires a square matrix".to_string(),
            ));
        }
        if n == 0 {
            // Empty system: store a trivial factorization so `solve` works.
            self.result = Some(SparseCholResult {
                l_dense: Vec::new(),
                perm: Vec::new(),
                perm_inv: Vec::new(),
                n: 0,
            });
            return Ok(());
        }

        let perm = amd_ordering(matrix)?;
        let perm_inv = inverse_perm(&perm);

        // Scatter the permuted matrix into dense storage:
        // b_dense[i][j] = A[perm[i]][perm[j]] (duplicate entries accumulate).
        let mut b_dense = vec![vec![F::sparse_zero(); n]; n];
        for i in 0..n {
            let orig_row = perm[i];
            let range = i_row_range(matrix, orig_row);
            for idx in range {
                let orig_col = matrix.indices[idx];
                let j = perm_inv[orig_col];
                b_dense[i][j] += matrix.data[idx];
            }
        }

        // Standard dense Cholesky: L computed row by row, left-looking.
        let mut l = vec![vec![F::sparse_zero(); n]; n];
        for i in 0..n {
            for j in 0..=i {
                let mut sum = b_dense[i][j];
                for k in 0..j {
                    sum -= l[i][k] * l[j][k];
                }
                if i == j {
                    // A non-positive pivot means the matrix is not SPD.
                    if sum <= F::sparse_zero() {
                        return Err(SparseError::ValueError(format!(
                            "Matrix is not positive definite: non-positive diagonal at row {i}"
                        )));
                    }
                    l[i][j] = sum.sqrt();
                } else {
                    let ljj = l[j][j];
                    if ljj.abs() < F::epsilon() {
                        return Err(SparseError::SingularMatrix(format!(
                            "Zero diagonal in L at row {j}"
                        )));
                    }
                    l[i][j] = sum / ljj;
                }
            }
        }

        self.result = Some(SparseCholResult {
            l_dense: l,
            perm,
            perm_inv,
            n,
        });
        Ok(())
    }

    /// Solves `A * x = b` using the stored factorization: permute `b`,
    /// forward-substitute with `L`, back-substitute with `L^T`, un-permute.
    ///
    /// # Errors
    /// - `ValueError` if `factorize` has not been called successfully.
    /// - `DimensionMismatch` if `b.len()` differs from the matrix dimension.
    /// - `SingularMatrix` on a numerically zero diagonal of `L`.
    fn solve(&self, b: &[F]) -> SparseResult<Vec<F>> {
        let res = self.result.as_ref().ok_or_else(|| {
            SparseError::ValueError("Cholesky factorization not computed".to_string())
        })?;
        let n = res.n;
        if b.len() != n {
            return Err(SparseError::DimensionMismatch {
                expected: n,
                found: b.len(),
            });
        }
        if n == 0 {
            return Ok(Vec::new());
        }

        // Apply the permutation: y = P * b.
        let mut y = vec![F::sparse_zero(); n];
        for i in 0..n {
            y[i] = b[res.perm[i]];
        }

        // Forward substitution: solve L * y' = P * b (in place).
        for i in 0..n {
            for j in 0..i {
                y[i] = y[i] - res.l_dense[i][j] * y[j];
            }
            let d = res.l_dense[i][i];
            if d.abs() < F::epsilon() {
                return Err(SparseError::SingularMatrix(
                    "Zero diagonal in L during solve".to_string(),
                ));
            }
            y[i] /= d;
        }

        // Backward substitution: solve L^T * z = y' (in place; L^T[i][j] = L[j][i]).
        for i in (0..n).rev() {
            for j in (i + 1)..n {
                y[i] = y[i] - res.l_dense[j][i] * y[j];
            }
            let d = res.l_dense[i][i];
            if d.abs() < F::epsilon() {
                return Err(SparseError::SingularMatrix(
                    "Zero diagonal in L^T during solve".to_string(),
                ));
            }
            y[i] /= d;
        }

        // Undo the permutation: x = P^T * z.
        let mut x = vec![F::sparse_zero(); n];
        for i in 0..n {
            x[res.perm[i]] = y[i];
        }
        Ok(x)
    }
}
622
/// Output of [`SparseLuSolver::factorize`]: packed LU factors plus the row
/// and column permutations applied during factorization.
#[derive(Debug, Clone)]
pub struct SparseLuResult<F> {
    /// Packed factors: `L` strictly below the diagonal (implicit unit
    /// diagonal) and `U` on/above the diagonal, in one dense matrix.
    pub lu_dense: Vec<Vec<F>>,
    /// Row permutation from partial pivoting: position `k` holds the
    /// original row index stored at row `k`.
    pub row_perm: Vec<usize>,
    /// Column permutation (from AMD ordering): position `j` holds the
    /// original column index stored at column `j`.
    pub col_perm: Vec<usize>,
    /// Matrix dimension.
    pub n: usize,
}
639
/// Direct solver for general square systems via LU factorization with
/// partial pivoting (dense working storage) and an AMD column ordering.
pub struct SparseLuSolver<F> {
    // Populated by `factorize`; `None` until a factorization succeeds.
    result: Option<SparseLuResult<F>>,
}
644
impl<F: Float + NumAssign + Sum + SparseElement + Debug + 'static> SparseLuSolver<F> {
    /// Creates a solver with no factorization computed yet.
    pub fn new() -> Self {
        Self { result: None }
    }

    /// Returns the stored factorization, or `None` if [`SparseSolver::factorize`]
    /// has not completed successfully.
    pub fn factorization(&self) -> Option<&SparseLuResult<F>> {
        self.result.as_ref()
    }
}
656
impl<F: Float + NumAssign + Sum + SparseElement + Debug + 'static> Default for SparseLuSolver<F> {
    /// Equivalent to [`SparseLuSolver::new`].
    fn default() -> Self {
        Self::new()
    }
}
662
663impl<F: Float + NumAssign + Sum + SparseElement + Debug + 'static> SparseSolver<F>
664 for SparseLuSolver<F>
665{
666 fn factorize(&mut self, matrix: &CsrMatrix<F>) -> SparseResult<()> {
667 let n = matrix.rows();
668 if n != matrix.cols() {
669 return Err(SparseError::ValueError(
670 "LU requires a square matrix".to_string(),
671 ));
672 }
673 if n == 0 {
674 self.result = Some(SparseLuResult {
675 lu_dense: Vec::new(),
676 row_perm: Vec::new(),
677 col_perm: Vec::new(),
678 n: 0,
679 });
680 return Ok(());
681 }
682
683 let col_perm = amd_ordering(matrix)?;
685 let col_perm_inv = inverse_perm(&col_perm);
686
687 let mut a = vec![vec![F::sparse_zero(); n]; n];
689 for i in 0..n {
690 let range = i_row_range(matrix, i);
691 for idx in range {
692 let orig_col = matrix.indices[idx];
693 let j = col_perm_inv[orig_col];
694 a[i][j] += matrix.data[idx];
695 }
696 }
697
698 let mut row_perm: Vec<usize> = (0..n).collect();
700
701 for k in 0..n {
702 let mut max_abs = F::sparse_zero();
704 let mut pivot = k;
705 for i in k..n {
706 if a[i][k].abs() > max_abs {
707 max_abs = a[i][k].abs();
708 pivot = i;
709 }
710 }
711
712 if pivot != k {
713 a.swap(k, pivot);
714 row_perm.swap(k, pivot);
715 }
716
717 let akk = a[k][k];
718 if akk.abs() < F::epsilon() {
719 continue; }
721
722 for i in (k + 1)..n {
723 let lik = a[i][k] / akk;
724 a[i][k] = lik; for j in (k + 1)..n {
726 let ukj = a[k][j];
727 a[i][j] -= lik * ukj;
728 }
729 }
730 }
731
732 self.result = Some(SparseLuResult {
733 lu_dense: a,
734 row_perm,
735 col_perm,
736 n,
737 });
738 Ok(())
739 }
740
741 fn solve(&self, b: &[F]) -> SparseResult<Vec<F>> {
742 let res = self
743 .result
744 .as_ref()
745 .ok_or_else(|| SparseError::ValueError("LU factorization not computed".to_string()))?;
746 let n = res.n;
747 if b.len() != n {
748 return Err(SparseError::DimensionMismatch {
749 expected: n,
750 found: b.len(),
751 });
752 }
753 if n == 0 {
754 return Ok(Vec::new());
755 }
756
757 let mut x = vec![F::sparse_zero(); n];
759 for i in 0..n {
760 x[i] = b[res.row_perm[i]];
761 }
762
763 for i in 0..n {
765 for j in 0..i {
766 x[i] = x[i] - res.lu_dense[i][j] * x[j];
767 }
768 }
769
770 for i in (0..n).rev() {
772 for j in (i + 1)..n {
773 x[i] = x[i] - res.lu_dense[i][j] * x[j];
774 }
775 let d = res.lu_dense[i][i];
776 if d.abs() < F::epsilon() {
777 return Err(SparseError::SingularMatrix(format!(
778 "Zero diagonal in U at row {i}"
779 )));
780 }
781 x[i] /= d;
782 }
783
784 let mut result = vec![F::sparse_zero(); n];
786 for j in 0..n {
787 result[res.col_perm[j]] = x[j];
788 }
789 Ok(result)
790 }
791}
792
793pub fn sparse_lu_solve<F>(matrix: &CsrMatrix<F>, b: &[F]) -> SparseResult<Vec<F>>
799where
800 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
801{
802 let mut solver = SparseLuSolver::new();
803 solver.factorize(matrix)?;
804 solver.solve(b)
805}
806
807pub fn sparse_cholesky_solve<F>(matrix: &CsrMatrix<F>, b: &[F]) -> SparseResult<Vec<F>>
809where
810 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
811{
812 let mut solver = SparseCholeskySolver::new();
813 solver.factorize(matrix)?;
814 solver.solve(b)
815}
816
817fn i_row_range<F: SparseElement + Clone + Copy + scirs2_core::numeric::Zero + PartialEq>(
823 matrix: &CsrMatrix<F>,
824 row: usize,
825) -> std::ops::Range<usize> {
826 if row >= matrix.rows() {
827 return 0..0;
828 }
829 matrix.indptr[row]..matrix.indptr[row + 1]
830}
831
#[cfg(test)]
mod tests {
    //! Unit tests for orderings, symbolic analysis, and the dense
    //! Cholesky/LU direct solvers.

    use super::*;

    // 3x3 symmetric positive-definite matrix (both triangles stored).
    fn create_spd_3x3() -> CsrMatrix<f64> {
        let rows = vec![0, 0, 1, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 2, 1, 2];
        let data = vec![4.0, 1.0, 1.0, 5.0, 2.0, 2.0, 6.0];
        CsrMatrix::new(data, rows, cols, (3, 3)).expect("Failed to create SPD matrix")
    }

    // Fully dense, non-symmetric 3x3 matrix for LU tests.
    fn create_general_3x3() -> CsrMatrix<f64> {
        let rows = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let cols = vec![0, 1, 2, 0, 1, 2, 0, 1, 2];
        let data = vec![3.0, 1.0, 2.0, 1.0, 4.0, 1.0, 0.0, 1.0, 5.0];
        CsrMatrix::new(data, rows, cols, (3, 3)).expect("Failed to create matrix")
    }

    // 4x4 SPD tridiagonal matrix (both triangles stored).
    fn create_spd_4x4() -> CsrMatrix<f64> {
        let rows = vec![0, 0, 1, 1, 1, 2, 2, 2, 3, 3];
        let cols = vec![0, 1, 0, 1, 2, 1, 2, 3, 2, 3];
        let data = vec![4.0, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0];
        CsrMatrix::new(data, rows, cols, (4, 4)).expect("Failed to create SPD 4x4")
    }

    // Asserts that mat * x == b elementwise within `tol` (residual check).
    fn verify_solve(mat: &CsrMatrix<f64>, x: &[f64], b: &[f64], tol: f64) {
        let dense = mat.to_dense();
        let n = b.len();
        for i in 0..n {
            let mut row_sum = 0.0;
            for j in 0..n {
                row_sum += dense[i][j] * x[j];
            }
            assert!(
                (row_sum - b[i]).abs() < tol,
                "Row {i}: residual {}",
                (row_sum - b[i]).abs()
            );
        }
    }

    // AMD must return a full permutation of 0..n.
    #[test]
    fn test_amd_ordering_basic() {
        let mat = create_spd_3x3();
        let perm = amd_ordering(&mat).expect("AMD failed");
        assert_eq!(perm.len(), 3);
        let mut sorted = perm.clone();
        sorted.sort();
        assert_eq!(sorted, vec![0, 1, 2]);
    }

    // AMD on a 0x0 matrix is a no-op.
    #[test]
    fn test_amd_ordering_empty() {
        let mat =
            CsrMatrix::<f64>::new(vec![], vec![], vec![], (0, 0)).expect("Failed to create empty");
        let perm = amd_ordering(&mat).expect("AMD failed on empty");
        assert!(perm.is_empty());
    }

    // inverse_perm composed with perm yields the identity.
    #[test]
    fn test_inverse_perm() {
        let perm = vec![2, 0, 1];
        let inv = inverse_perm(&perm);
        assert_eq!(inv, vec![1, 2, 0]);
        for i in 0..3 {
            assert_eq!(perm[inv[i]], i);
        }
    }

    // Nested dissection must also return a full permutation.
    #[test]
    fn test_nested_dissection_basic() {
        let mat = create_spd_4x4();
        let perm = nested_dissection_ordering(&mat).expect("ND failed");
        assert_eq!(perm.len(), 4);
        let mut sorted = perm.clone();
        sorted.sort();
        assert_eq!(sorted, vec![0, 1, 2, 3]);
    }

    // Smoke test: etree has one parent slot per column.
    #[test]
    fn test_elimination_tree() {
        let mat = create_spd_3x3();
        let perm: Vec<usize> = (0..3).collect();
        let etree = elimination_tree(&mat, &perm);
        assert_eq!(etree.len(), 3);
    }

    // b was chosen so the exact solution is the all-ones vector.
    #[test]
    fn test_cholesky_solve_3x3() {
        let mat = create_spd_3x3();
        let b = vec![5.0, 8.0, 8.0];
        let x = sparse_cholesky_solve(&mat, &b).expect("Cholesky solve failed");
        assert_eq!(x.len(), 3);
        for (i, &xi) in x.iter().enumerate() {
            assert!((xi - 1.0).abs() < 1e-10, "x[{i}] = {xi}, expected 1.0");
        }
    }

    #[test]
    fn test_cholesky_solve_4x4() {
        let mat = create_spd_4x4();
        let b = vec![5.0, 6.0, 6.0, 5.0];
        let x = sparse_cholesky_solve(&mat, &b).expect("Cholesky solve 4x4 failed");
        verify_solve(&mat, &x, &b, 1e-10);
    }

    // A negative diagonal entry must be rejected as non-SPD.
    #[test]
    fn test_cholesky_non_spd() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![-1.0, 1.0, 1.0];
        let mat = CsrMatrix::new(data, rows, cols, (3, 3)).expect("Failed to create matrix");
        let result = sparse_cholesky_solve(&mat, &[1.0, 1.0, 1.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_lu_solve_3x3() {
        let mat = create_general_3x3();
        let b = vec![6.0, 6.0, 6.0];
        let x = sparse_lu_solve(&mat, &b).expect("LU solve failed");
        verify_solve(&mat, &x, &b, 1e-9);
    }

    // Solving I * x = b must return b exactly (up to rounding).
    #[test]
    fn test_lu_solve_identity() {
        let rows = vec![0, 1, 2, 3];
        let cols = vec![0, 1, 2, 3];
        let data = vec![1.0, 1.0, 1.0, 1.0];
        let mat = CsrMatrix::new(data, rows, cols, (4, 4)).expect("Failed to create identity");
        let b = vec![1.0, 2.0, 3.0, 4.0];
        let x = sparse_lu_solve(&mat, &b).expect("LU solve on identity failed");
        for i in 0..4 {
            assert!(
                (x[i] - b[i]).abs() < 1e-12,
                "x[{i}] = {}, expected {}",
                x[i],
                b[i]
            );
        }
    }

    // One factorization reused for two right-hand sides via solve_multi.
    #[test]
    fn test_lu_solve_multi() {
        let mat = create_general_3x3();
        let mut solver = SparseLuSolver::new();
        solver.factorize(&mat).expect("LU factorize failed");

        let b1 = vec![6.0, 6.0, 6.0];
        let b2 = vec![3.0, 1.0, 2.0];
        let results = solver
            .solve_multi(&[b1.clone(), b2.clone()])
            .expect("Solve multi failed");

        verify_solve(&mat, &results[0], &b1, 1e-9);
        verify_solve(&mat, &results[1], &b2, 1e-9);
    }

    // Exercises the SparseSolver trait interface end to end.
    #[test]
    fn test_cholesky_solver_trait() {
        let mat = create_spd_3x3();
        let mut solver = SparseCholeskySolver::new();
        solver.factorize(&mat).expect("Factorize failed");
        assert!(solver.factorization().is_some());

        let b = vec![5.0, 8.0, 8.0];
        let x = solver.solve(&b).expect("Solve failed");
        for (i, xi) in x.iter().enumerate() {
            assert!((xi - 1.0).abs() < 1e-10, "x[{i}] = {xi}");
        }
    }

    // 0x0 systems factorize and solve to an empty vector.
    #[test]
    fn test_lu_empty_matrix() {
        let mat =
            CsrMatrix::<f64>::new(vec![], vec![], vec![], (0, 0)).expect("Failed to create empty");
        let mut solver = SparseLuSolver::new();
        solver
            .factorize(&mat)
            .expect("LU factorize on empty failed");
        let x = solver.solve(&[]).expect("LU solve on empty failed");
        assert!(x.is_empty());
    }

    // Wrong-length right-hand side must be rejected.
    #[test]
    fn test_cholesky_dimension_mismatch() {
        let mat = create_spd_3x3();
        let mut solver = SparseCholeskySolver::new();
        solver.factorize(&mat).expect("Factorize failed");
        let result = solver.solve(&[1.0, 2.0]);
        assert!(result.is_err());
    }

    // 5x5 diagonally dominant tridiagonal system; b chosen so x is all-ones.
    #[test]
    fn test_lu_solve_5x5_diag_dominant() {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut data = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                if i == j {
                    rows.push(i);
                    cols.push(j);
                    data.push(10.0);
                } else if (i as isize - j as isize).unsigned_abs() <= 1 {
                    rows.push(i);
                    cols.push(j);
                    data.push(1.0);
                }
            }
        }
        let mat = CsrMatrix::new(data, rows, cols, (5, 5)).expect("Failed to create 5x5");
        let b = vec![12.0, 12.0, 12.0, 12.0, 12.0];
        let x = sparse_lu_solve(&mat, &b).expect("LU 5x5 failed");
        verify_solve(&mat, &x, &b, 1e-8);
    }

    // Predicted L pattern must have n+1 column pointers and nnz >= n
    // (at least one diagonal entry per column).
    #[test]
    fn test_symbolic_cholesky() {
        let mat = create_spd_3x3();
        let perm: Vec<usize> = (0..3).collect();
        let analysis = symbolic_cholesky(&mat, &perm).expect("Symbolic Cholesky failed");
        assert_eq!(analysis.n, 3);
        assert_eq!(analysis.l_colptr.len(), 4);
        assert!(analysis.l_colptr[3] >= 3);
    }

    #[test]
    fn test_lu_non_square_error() {
        let rows = vec![0, 1];
        let cols = vec![0, 0];
        let data = vec![1.0, 2.0];
        let mat = CsrMatrix::new(data, rows, cols, (2, 3)).expect("Failed to create non-square");
        let result = sparse_lu_solve(&mat, &[1.0, 2.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_cholesky_non_square_error() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 0, 0];
        let data = vec![1.0, 2.0, 3.0];
        let mat = CsrMatrix::new(data, rows, cols, (3, 4)).expect("Failed to create non-square");
        let result = sparse_cholesky_solve(&mat, &[1.0, 2.0, 3.0]);
        assert!(result.is_err());
    }

    // LU on a matrix with structurally missing entries.
    #[test]
    fn test_lu_solve_with_zeros() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let data = vec![2.0, 1.0, 3.0, 1.0, 4.0];
        let mat = CsrMatrix::new(data, rows, cols, (3, 3)).expect("Failed");
        let b = vec![3.0, 3.0, 5.0];
        let x = sparse_lu_solve(&mat, &b).expect("LU solve sparse matrix failed");
        verify_solve(&mat, &x, &b, 1e-9);
    }

    #[test]
    fn test_amd_non_square_error() {
        let rows = vec![0];
        let cols = vec![0];
        let data = vec![1.0];
        let mat = CsrMatrix::new(data, rows, cols, (2, 3)).expect("Failed");
        let result = amd_ordering(&mat);
        assert!(result.is_err());
    }
}