1use crate::csr_array::CsrArray;
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::numeric::Float;
11use scirs2_core::SparseElement;
12use std::collections::HashMap;
13use std::fmt::Debug;
14use std::ops::{Add, Div, Mul, Sub};
15
16#[derive(Debug, Clone)]
18pub struct LUResult<T>
19where
20 T: Float + SparseElement + Debug + Copy + 'static,
21{
22 pub l: CsrArray<T>,
24 pub u: CsrArray<T>,
26 pub p: Array1<usize>,
28 pub success: bool,
30}
31
32#[derive(Debug, Clone)]
34pub struct QRResult<T>
35where
36 T: Float + SparseElement + Debug + Copy + 'static,
37{
38 pub q: CsrArray<T>,
40 pub r: CsrArray<T>,
42 pub success: bool,
44}
45
46#[derive(Debug, Clone)]
48pub struct CholeskyResult<T>
49where
50 T: Float + SparseElement + Debug + Copy + 'static,
51{
52 pub l: CsrArray<T>,
54 pub success: bool,
56}
57
58#[derive(Debug, Clone)]
60pub struct PivotedCholeskyResult<T>
61where
62 T: Float + SparseElement + Debug + Copy + 'static,
63{
64 pub l: CsrArray<T>,
66 pub p: Array1<usize>,
68 pub rank: usize,
70 pub success: bool,
72}
73
/// Pivot-selection rule used by the LU routines.
#[derive(Debug, Clone, Default)]
pub enum PivotingStrategy {
    /// Always use the diagonal entry; no row or column exchanges.
    None,
    /// Largest magnitude in the current column (the default rule).
    #[default]
    Partial,
    /// First entry in the current column whose magnitude reaches this absolute
    /// value; falls back to the largest entry when none does.
    Threshold(f64),
    /// Largest magnitude relative to the largest entry of its own row.
    ScaledPartial,
    /// Largest magnitude anywhere in the remaining submatrix.
    Complete,
    /// Column search followed by a row search for a locally dominant entry.
    Rook,
}
91
92#[derive(Debug, Clone)]
94pub struct LUOptions {
95 pub pivoting: PivotingStrategy,
97 pub zero_threshold: f64,
99 pub check_singular: bool,
101}
102
103impl Default for LUOptions {
104 fn default() -> Self {
105 Self {
106 pivoting: PivotingStrategy::default(),
107 zero_threshold: 1e-14,
108 check_singular: true,
109 }
110 }
111}
112
113#[derive(Debug, Clone)]
115pub struct ILUOptions {
116 pub drop_tol: f64,
118 pub fill_factor: f64,
120 pub max_fill_per_row: usize,
122 pub pivoting: PivotingStrategy,
124}
125
126impl Default for ILUOptions {
127 fn default() -> Self {
128 Self {
129 drop_tol: 1e-4,
130 fill_factor: 2.0,
131 max_fill_per_row: 20,
132 pivoting: PivotingStrategy::default(),
133 }
134 }
135}
136
/// Tuning knobs for `incomplete_cholesky`.
#[derive(Debug, Clone)]
pub struct ICOptions {
    /// Off-diagonal factor entries with magnitude below this value are dropped.
    pub drop_tol: f64,
    /// Intended fill ratio; NOTE(review): not read by `incomplete_cholesky`.
    pub fill_factor: f64,
    /// Intended per-row fill cap; NOTE(review): not read by `incomplete_cholesky`.
    pub max_fill_per_row: usize,
}

impl Default for ICOptions {
    /// `drop_tol = 1e-4`, `fill_factor = 2.0`, `max_fill_per_row = 20`.
    fn default() -> Self {
        ICOptions {
            max_fill_per_row: 20,
            fill_factor: 2.0,
            drop_tol: 1e-4,
        }
    }
}
157
158#[allow(dead_code)]
187pub fn lu_decomposition<T, S>(_matrix: &S, pivotthreshold: f64) -> SparseResult<LUResult<T>>
188where
189 T: Float
190 + SparseElement
191 + Debug
192 + Copy
193 + Add<Output = T>
194 + Sub<Output = T>
195 + Mul<Output = T>
196 + Div<Output = T>,
197 S: SparseArray<T>,
198{
199 let options = LUOptions {
201 pivoting: PivotingStrategy::Threshold(pivotthreshold),
202 zero_threshold: 1e-14,
203 check_singular: true,
204 };
205
206 lu_decomposition_with_options(_matrix, Some(options))
207}
208
209#[allow(dead_code)]
244pub fn lu_decomposition_with_options<T, S>(
245 matrix: &S,
246 options: Option<LUOptions>,
247) -> SparseResult<LUResult<T>>
248where
249 T: Float
250 + SparseElement
251 + Debug
252 + Copy
253 + Add<Output = T>
254 + Sub<Output = T>
255 + Mul<Output = T>
256 + Div<Output = T>,
257 S: SparseArray<T>,
258{
259 let opts = options.unwrap_or_default();
260 let (n, m) = matrix.shape();
261 if n != m {
262 return Err(SparseError::ValueError(
263 "Matrix must be square for LU decomposition".to_string(),
264 ));
265 }
266
267 let (row_indices, col_indices, values) = matrix.find();
269 let mut working_matrix = SparseWorkingMatrix::from_triplets(
270 row_indices.as_slice().unwrap(),
271 col_indices.as_slice().unwrap(),
272 values.as_slice().unwrap(),
273 n,
274 );
275
276 let mut row_perm: Vec<usize> = (0..n).collect();
278 let mut col_perm: Vec<usize> = (0..n).collect();
279
280 let mut row_scales = vec![T::sparse_one(); n];
282 if matches!(opts.pivoting, PivotingStrategy::ScaledPartial) {
283 for (i, scale) in row_scales.iter_mut().enumerate().take(n) {
284 let row_data = working_matrix.get_row(i);
285 let max_val = row_data
286 .values()
287 .map(|&v| v.abs())
288 .fold(T::sparse_zero(), |a, b| if a > b { a } else { b });
289 if max_val > T::sparse_zero() {
290 *scale = max_val;
291 }
292 }
293 }
294
295 for k in 0..n - 1 {
297 let (pivot_row, pivot_col) =
299 find_enhanced_pivot(&working_matrix, k, &row_perm, &col_perm, &row_scales, &opts)?;
300
301 if pivot_row != k {
303 row_perm.swap(k, pivot_row);
304 }
305 if pivot_col != k
306 && matches!(
307 opts.pivoting,
308 PivotingStrategy::Complete | PivotingStrategy::Rook
309 )
310 {
311 col_perm.swap(k, pivot_col);
312 for &row_idx in row_perm.iter().take(n) {
314 let temp = working_matrix.get(row_idx, k);
315 working_matrix.set(row_idx, k, working_matrix.get(row_idx, pivot_col));
316 working_matrix.set(row_idx, pivot_col, temp);
317 }
318 }
319
320 let actual_pivot_row = row_perm[k];
321 let actual_pivot_col = col_perm[k];
322 let pivot_value = working_matrix.get(actual_pivot_row, actual_pivot_col);
323
324 if opts.check_singular && pivot_value.abs() < T::from(opts.zero_threshold).unwrap() {
326 return Ok(LUResult {
327 l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
328 u: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
329 p: Array1::from_vec(row_perm),
330 success: false,
331 });
332 }
333
334 for &actual_row_i in row_perm.iter().take(n).skip(k + 1) {
336 let factor = working_matrix.get(actual_row_i, actual_pivot_col) / pivot_value;
337
338 if !SparseElement::is_zero(&factor) {
339 working_matrix.set(actual_row_i, actual_pivot_col, factor);
341
342 let pivot_row_data = working_matrix.get_row(actual_pivot_row);
344 for (col, &value) in &pivot_row_data {
345 if *col > k {
346 let old_val = working_matrix.get(actual_row_i, *col);
347 working_matrix.set(actual_row_i, *col, old_val - factor * value);
348 }
349 }
350 }
351 }
352 }
353
354 let (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals) =
356 extract_lu_factors(&working_matrix, &row_perm, n);
357
358 let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
359 let u = CsrArray::from_triplets(&u_rows, &u_cols, &u_vals, (n, n), false)?;
360
361 Ok(LUResult {
362 l,
363 u,
364 p: Array1::from_vec(row_perm),
365 success: true,
366 })
367}
368
369#[allow(dead_code)]
396pub fn qr_decomposition<T, S>(matrix: &S) -> SparseResult<QRResult<T>>
397where
398 T: Float
399 + SparseElement
400 + Debug
401 + Copy
402 + Add<Output = T>
403 + Sub<Output = T>
404 + Mul<Output = T>
405 + Div<Output = T>,
406 S: SparseArray<T>,
407{
408 let (m, n) = matrix.shape();
409
410 let dense_matrix = matrix.to_array();
412
413 let mut q = Array2::zeros((m, n));
415 let mut r = Array2::zeros((n, n));
416
417 for j in 0..n {
418 for i in 0..m {
420 q[[i, j]] = dense_matrix[[i, j]];
421 }
422
423 for k in 0..j {
425 let mut dot = T::sparse_zero();
426 for i in 0..m {
427 dot = dot + q[[i, k]] * dense_matrix[[i, j]];
428 }
429 r[[k, j]] = dot;
430
431 for i in 0..m {
432 q[[i, j]] = q[[i, j]] - dot * q[[i, k]];
433 }
434 }
435
436 let mut norm = T::sparse_zero();
438 for i in 0..m {
439 norm = norm + q[[i, j]] * q[[i, j]];
440 }
441 norm = norm.sqrt();
442 r[[j, j]] = norm;
443
444 if !SparseElement::is_zero(&norm) {
445 for i in 0..m {
446 q[[i, j]] = q[[i, j]] / norm;
447 }
448 }
449 }
450
451 let q_sparse = dense_to_sparse(&q)?;
453 let r_sparse = dense_to_sparse(&r)?;
454
455 Ok(QRResult {
456 q: q_sparse,
457 r: r_sparse,
458 success: true,
459 })
460}
461
462#[allow(dead_code)]
490pub fn cholesky_decomposition<T, S>(matrix: &S) -> SparseResult<CholeskyResult<T>>
491where
492 T: Float
493 + SparseElement
494 + Debug
495 + Copy
496 + Add<Output = T>
497 + Sub<Output = T>
498 + Mul<Output = T>
499 + Div<Output = T>,
500 S: SparseArray<T>,
501{
502 let (n, m) = matrix.shape();
503 if n != m {
504 return Err(SparseError::ValueError(
505 "Matrix must be square for Cholesky decomposition".to_string(),
506 ));
507 }
508
509 let (row_indices, col_indices, values) = matrix.find();
511 let mut working_matrix = SparseWorkingMatrix::from_triplets(
512 row_indices.as_slice().unwrap(),
513 col_indices.as_slice().unwrap(),
514 values.as_slice().unwrap(),
515 n,
516 );
517
518 for k in 0..n {
520 let mut sum = T::sparse_zero();
522 for j in 0..k {
523 let l_kj = working_matrix.get(k, j);
524 sum = sum + l_kj * l_kj;
525 }
526
527 let a_kk = working_matrix.get(k, k);
528 let diag_val = a_kk - sum;
529
530 if diag_val <= T::sparse_zero() {
531 return Ok(CholeskyResult {
532 l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
533 success: false,
534 });
535 }
536
537 let l_kk = diag_val.sqrt();
538 working_matrix.set(k, k, l_kk);
539
540 for i in (k + 1)..n {
542 let mut sum = T::sparse_zero();
543 for j in 0..k {
544 sum = sum + working_matrix.get(i, j) * working_matrix.get(k, j);
545 }
546
547 let a_ik = working_matrix.get(i, k);
548 let l_ik = (a_ik - sum) / l_kk;
549 working_matrix.set(i, k, l_ik);
550 }
551 }
552
553 let (l_rows, l_cols, l_vals) = extract_lower_triangular(&working_matrix, n);
555 let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
556
557 Ok(CholeskyResult { l, success: true })
558}
559
560#[allow(dead_code)]
590pub fn pivoted_cholesky_decomposition<T, S>(
591 matrix: &S,
592 threshold: Option<T>,
593) -> SparseResult<PivotedCholeskyResult<T>>
594where
595 T: Float
596 + SparseElement
597 + Debug
598 + Copy
599 + Add<Output = T>
600 + Sub<Output = T>
601 + Mul<Output = T>
602 + Div<Output = T>,
603 S: SparseArray<T>,
604{
605 let (n, m) = matrix.shape();
606 if n != m {
607 return Err(SparseError::ValueError(
608 "Matrix must be square for Cholesky decomposition".to_string(),
609 ));
610 }
611
612 let threshold = threshold.unwrap_or_else(|| T::from(1e-12).unwrap());
613
614 let (row_indices, col_indices, values) = matrix.find();
616 let mut working_matrix = SparseWorkingMatrix::from_triplets(
617 row_indices.as_slice().unwrap(),
618 col_indices.as_slice().unwrap(),
619 values.as_slice().unwrap(),
620 n,
621 );
622
623 let mut perm: Vec<usize> = (0..n).collect();
625 let mut rank = 0;
626
627 for k in 0..n {
629 let mut max_diag = T::sparse_zero();
631 let mut pivot_idx = k;
632
633 for i in k..n {
634 let mut diag_val = working_matrix.get(perm[i], perm[i]);
635 for j in 0..k {
636 let l_ij = working_matrix.get(perm[i], perm[j]);
637 diag_val = diag_val - l_ij * l_ij;
638 }
639 if diag_val > max_diag {
640 max_diag = diag_val;
641 pivot_idx = i;
642 }
643 }
644
645 if max_diag <= threshold {
647 break;
648 }
649
650 if pivot_idx != k {
652 perm.swap(k, pivot_idx);
653 }
654
655 let l_kk = max_diag.sqrt();
657 working_matrix.set(perm[k], perm[k], l_kk);
658 rank += 1;
659
660 for i in (k + 1)..n {
662 let mut sum = T::sparse_zero();
663 for j in 0..k {
664 sum = sum
665 + working_matrix.get(perm[i], perm[j]) * working_matrix.get(perm[k], perm[j]);
666 }
667
668 let a_ik = working_matrix.get(perm[i], perm[k]);
669 let l_ik = (a_ik - sum) / l_kk;
670 working_matrix.set(perm[i], perm[k], l_ik);
671 }
672 }
673
674 let mut l_rows = Vec::new();
676 let mut l_cols = Vec::new();
677 let mut l_vals = Vec::new();
678
679 for i in 0..rank {
680 for j in 0..=i {
681 let val = working_matrix.get(perm[i], perm[j]);
682 if val != T::sparse_zero() {
683 l_rows.push(i);
684 l_cols.push(j);
685 l_vals.push(val);
686 }
687 }
688 }
689
690 let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, rank), false)?;
691 let p = Array1::from_vec(perm);
692
693 Ok(PivotedCholeskyResult {
694 l,
695 p,
696 rank,
697 success: true,
698 })
699}
700
701#[derive(Debug, Clone)]
703pub struct LDLTResult<T>
704where
705 T: Float + SparseElement + Debug + Copy + 'static,
706{
707 pub l: CsrArray<T>,
709 pub d: Array1<T>,
711 pub p: Array1<usize>,
713 pub success: bool,
715}
716
717#[allow(dead_code)]
748pub fn ldlt_decomposition<T, S>(
749 matrix: &S,
750 pivoting: Option<bool>,
751 threshold: Option<T>,
752) -> SparseResult<LDLTResult<T>>
753where
754 T: Float
755 + SparseElement
756 + Debug
757 + Copy
758 + Add<Output = T>
759 + Sub<Output = T>
760 + Mul<Output = T>
761 + Div<Output = T>,
762 S: SparseArray<T>,
763{
764 let (n, m) = matrix.shape();
765 if n != m {
766 return Err(SparseError::ValueError(
767 "Matrix must be square for LDLT decomposition".to_string(),
768 ));
769 }
770
771 let use_pivoting = pivoting.unwrap_or(true);
772 let threshold = threshold.unwrap_or_else(|| T::from(1e-12).unwrap());
773
774 let (row_indices, col_indices, values) = matrix.find();
776 let mut working_matrix = SparseWorkingMatrix::from_triplets(
777 row_indices.as_slice().unwrap(),
778 col_indices.as_slice().unwrap(),
779 values.as_slice().unwrap(),
780 n,
781 );
782
783 let mut perm: Vec<usize> = (0..n).collect();
785 let mut d_values = vec![T::sparse_zero(); n];
786
787 for k in 0..n {
789 if use_pivoting {
791 let pivot_idx = find_ldlt_pivot(&working_matrix, k, &perm, threshold);
792 if pivot_idx != k {
793 perm.swap(k, pivot_idx);
794 }
795 }
796
797 let actual_k = perm[k];
798
799 let mut diag_val = working_matrix.get(actual_k, actual_k);
801 for j in 0..k {
802 let l_kj = working_matrix.get(actual_k, perm[j]);
803 diag_val = diag_val - l_kj * l_kj * d_values[j];
804 }
805
806 d_values[k] = diag_val;
807
808 if diag_val.abs() < threshold {
810 return Ok(LDLTResult {
811 l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
812 d: Array1::from_vec(d_values),
813 p: Array1::from_vec(perm),
814 success: false,
815 });
816 }
817
818 for i in (k + 1)..n {
820 let actual_i = perm[i];
821 let mut l_ik = working_matrix.get(actual_i, actual_k);
822
823 for j in 0..k {
824 l_ik = l_ik
825 - working_matrix.get(actual_i, perm[j])
826 * working_matrix.get(actual_k, perm[j])
827 * d_values[j];
828 }
829
830 l_ik = l_ik / diag_val;
831 working_matrix.set(actual_i, actual_k, l_ik);
832 }
833
834 working_matrix.set(actual_k, actual_k, T::sparse_one());
836 }
837
838 let (l_rows, l_cols, l_vals) = extract_unit_lower_triangular(&working_matrix, &perm, n);
840 let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
841
842 Ok(LDLTResult {
843 l,
844 d: Array1::from_vec(d_values),
845 p: Array1::from_vec(perm),
846 success: true,
847 })
848}
849
850#[allow(dead_code)]
852fn find_ldlt_pivot<T>(
853 matrix: &SparseWorkingMatrix<T>,
854 k: usize,
855 perm: &[usize],
856 threshold: T,
857) -> usize
858where
859 T: Float + SparseElement + Debug + Copy,
860{
861 let n = matrix.n;
862 let mut max_val = T::sparse_zero();
863 let mut pivot_idx = k;
864
865 for (i, &actual_i) in perm.iter().enumerate().take(n).skip(k) {
867 let diag_val = matrix.get(actual_i, actual_i).abs();
868
869 if diag_val > max_val {
870 max_val = diag_val;
871 pivot_idx = i;
872 }
873 }
874
875 if max_val >= threshold {
877 pivot_idx
878 } else {
879 k }
881}
882
883#[allow(dead_code)]
885fn extract_unit_lower_triangular<T>(
886 matrix: &SparseWorkingMatrix<T>,
887 perm: &[usize],
888 n: usize,
889) -> (Vec<usize>, Vec<usize>, Vec<T>)
890where
891 T: Float + SparseElement + Debug + Copy,
892{
893 let mut rows = Vec::new();
894 let mut cols = Vec::new();
895 let mut vals = Vec::new();
896
897 for i in 0..n {
898 let actual_i = perm[i];
899
900 rows.push(i);
902 cols.push(i);
903 vals.push(T::sparse_one());
904
905 for (j, &perm_j) in perm.iter().enumerate().take(i) {
907 let val = matrix.get(actual_i, perm_j);
908 if val != T::sparse_zero() {
909 rows.push(i);
910 cols.push(j);
911 vals.push(val);
912 }
913 }
914 }
915
916 (rows, cols, vals)
917}
918
919#[allow(dead_code)]
933pub fn incomplete_lu<T, S>(matrix: &S, options: Option<ILUOptions>) -> SparseResult<LUResult<T>>
934where
935 T: Float
936 + SparseElement
937 + Debug
938 + Copy
939 + Add<Output = T>
940 + Sub<Output = T>
941 + Mul<Output = T>
942 + Div<Output = T>,
943 S: SparseArray<T>,
944{
945 let opts = options.unwrap_or_default();
946 let (n, m) = matrix.shape();
947
948 if n != m {
949 return Err(SparseError::ValueError(
950 "Matrix must be square for ILU decomposition".to_string(),
951 ));
952 }
953
954 let (row_indices, col_indices, values) = matrix.find();
956 let mut working_matrix = SparseWorkingMatrix::from_triplets(
957 row_indices.as_slice().unwrap(),
958 col_indices.as_slice().unwrap(),
959 values.as_slice().unwrap(),
960 n,
961 );
962
963 for k in 0..n - 1 {
965 let pivot_val = working_matrix.get(k, k);
966
967 if pivot_val.abs() < T::from(1e-14).unwrap() {
968 continue; }
970
971 let col_k_entries = working_matrix.get_column_below_diagonal(k);
973
974 for &row_i in &col_k_entries {
975 let factor = working_matrix.get(row_i, k) / pivot_val;
976
977 if factor.abs() < T::from(opts.drop_tol).unwrap() {
979 working_matrix.set(row_i, k, T::sparse_zero());
980 continue;
981 }
982
983 working_matrix.set(row_i, k, factor);
984
985 let row_k_entries = working_matrix.get_row_after_column(k, k);
987 for (col_j, &val_kj) in &row_k_entries {
988 if working_matrix.has_entry(row_i, *col_j) {
989 let old_val = working_matrix.get(row_i, *col_j);
990 let new_val = old_val - factor * val_kj;
991
992 if new_val.abs() < T::from(opts.drop_tol).unwrap() {
994 working_matrix.set(row_i, *col_j, T::sparse_zero());
995 } else {
996 working_matrix.set(row_i, *col_j, new_val);
997 }
998 }
999 }
1000 }
1001 }
1002
1003 let identity_p: Vec<usize> = (0..n).collect();
1005 let (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals) =
1006 extract_lu_factors(&working_matrix, &identity_p, n);
1007
1008 let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
1009 let u = CsrArray::from_triplets(&u_rows, &u_cols, &u_vals, (n, n), false)?;
1010
1011 Ok(LUResult {
1012 l,
1013 u,
1014 p: Array1::from_vec(identity_p),
1015 success: true,
1016 })
1017}
1018
1019#[allow(dead_code)]
1033pub fn incomplete_cholesky<T, S>(
1034 matrix: &S,
1035 options: Option<ICOptions>,
1036) -> SparseResult<CholeskyResult<T>>
1037where
1038 T: Float
1039 + SparseElement
1040 + Debug
1041 + Copy
1042 + Add<Output = T>
1043 + Sub<Output = T>
1044 + Mul<Output = T>
1045 + Div<Output = T>,
1046 S: SparseArray<T>,
1047{
1048 let opts = options.unwrap_or_default();
1049 let (n, m) = matrix.shape();
1050
1051 if n != m {
1052 return Err(SparseError::ValueError(
1053 "Matrix must be square for IC decomposition".to_string(),
1054 ));
1055 }
1056
1057 let (row_indices, col_indices, values) = matrix.find();
1059 let mut working_matrix = SparseWorkingMatrix::from_triplets(
1060 row_indices.as_slice().unwrap(),
1061 col_indices.as_slice().unwrap(),
1062 values.as_slice().unwrap(),
1063 n,
1064 );
1065
1066 for k in 0..n {
1068 let mut sum = T::sparse_zero();
1070 let row_k_before_k = working_matrix.get_row_before_column(k, k);
1071 for &val_kj in row_k_before_k.values() {
1072 sum = sum + val_kj * val_kj;
1073 }
1074
1075 let a_kk = working_matrix.get(k, k);
1076 let diag_val = a_kk - sum;
1077
1078 if diag_val <= T::sparse_zero() {
1079 return Ok(CholeskyResult {
1080 l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
1081 success: false,
1082 });
1083 }
1084
1085 let l_kk = diag_val.sqrt();
1086 working_matrix.set(k, k, l_kk);
1087
1088 let col_k_below = working_matrix.get_column_below_diagonal(k);
1090 for &row_i in &col_k_below {
1091 let mut sum = T::sparse_zero();
1092 let row_i_before_k = working_matrix.get_row_before_column(row_i, k);
1093 let row_k_before_k = working_matrix.get_row_before_column(k, k);
1094
1095 for (col_j, &val_ij) in &row_i_before_k {
1097 if let Some(&val_kj) = row_k_before_k.get(col_j) {
1098 sum = sum + val_ij * val_kj;
1099 }
1100 }
1101
1102 let a_ik = working_matrix.get(row_i, k);
1103 let l_ik = (a_ik - sum) / l_kk;
1104
1105 if l_ik.abs() < T::from(opts.drop_tol).unwrap() {
1107 working_matrix.set(row_i, k, T::sparse_zero());
1108 } else {
1109 working_matrix.set(row_i, k, l_ik);
1110 }
1111 }
1112 }
1113
1114 let (l_rows, l_cols, l_vals) = extract_lower_triangular(&working_matrix, n);
1116 let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
1117
1118 Ok(CholeskyResult { l, success: true })
1119}
1120
1121struct SparseWorkingMatrix<T>
1123where
1124 T: Float + SparseElement + Debug + Copy,
1125{
1126 data: HashMap<(usize, usize), T>,
1127 n: usize,
1128}
1129
1130impl<T> SparseWorkingMatrix<T>
1131where
1132 T: Float
1133 + SparseElement
1134 + Debug
1135 + Copy
1136 + Add<Output = T>
1137 + Sub<Output = T>
1138 + Mul<Output = T>
1139 + Div<Output = T>,
1140{
1141 fn from_triplets(rows: &[usize], cols: &[usize], values: &[T], n: usize) -> Self {
1142 let mut data = HashMap::new();
1143
1144 for (i, (&row, &col)) in rows.iter().zip(cols.iter()).enumerate() {
1145 data.insert((row, col), values[i]);
1146 }
1147
1148 Self { data, n }
1149 }
1150
1151 fn get(&self, row: usize, col: usize) -> T {
1152 self.data
1153 .get(&(row, col))
1154 .copied()
1155 .unwrap_or(T::sparse_zero())
1156 }
1157
1158 fn set(&mut self, row: usize, col: usize, value: T) {
1159 if SparseElement::is_zero(&value) {
1160 self.data.remove(&(row, col));
1161 } else {
1162 self.data.insert((row, col), value);
1163 }
1164 }
1165
1166 fn has_entry(&self, row: usize, col: usize) -> bool {
1167 self.data.contains_key(&(row, col))
1168 }
1169
1170 fn get_row(&self, row: usize) -> HashMap<usize, T> {
1171 let mut result = HashMap::new();
1172 for (&(r, c), &value) in &self.data {
1173 if r == row {
1174 result.insert(c, value);
1175 }
1176 }
1177 result
1178 }
1179
1180 fn get_row_after_column(&self, row: usize, col: usize) -> HashMap<usize, T> {
1181 let mut result = HashMap::new();
1182 for (&(r, c), &value) in &self.data {
1183 if r == row && c > col {
1184 result.insert(c, value);
1185 }
1186 }
1187 result
1188 }
1189
1190 fn get_row_before_column(&self, row: usize, col: usize) -> HashMap<usize, T> {
1191 let mut result = HashMap::new();
1192 for (&(r, c), &value) in &self.data {
1193 if r == row && c < col {
1194 result.insert(c, value);
1195 }
1196 }
1197 result
1198 }
1199
1200 fn get_column_below_diagonal(&self, col: usize) -> Vec<usize> {
1201 let mut result = Vec::new();
1202 for &(r, c) in self.data.keys() {
1203 if c == col && r > col {
1204 result.push(r);
1205 }
1206 }
1207 result.sort();
1208 result
1209 }
1210}
1211
1212#[allow(dead_code)]
1214fn find_pivot<T>(
1215 matrix: &SparseWorkingMatrix<T>,
1216 k: usize,
1217 p: &[usize],
1218 threshold: f64,
1219) -> SparseResult<usize>
1220where
1221 T: Float + SparseElement + Debug + Copy,
1222{
1223 let opts = LUOptions {
1225 pivoting: PivotingStrategy::Threshold(threshold),
1226 zero_threshold: 1e-14,
1227 check_singular: true,
1228 };
1229
1230 let row_scales = vec![T::sparse_one(); matrix.n];
1231 let col_perm: Vec<usize> = (0..matrix.n).collect();
1232
1233 let (pivot_row, pivot_col) = find_enhanced_pivot(matrix, k, p, &col_perm, &row_scales, &opts)?;
1234 Ok(pivot_row)
1235}
1236
1237#[allow(dead_code)]
1239fn find_enhanced_pivot<T>(
1240 matrix: &SparseWorkingMatrix<T>,
1241 k: usize,
1242 row_perm: &[usize],
1243 col_perm: &[usize],
1244 row_scales: &[T],
1245 opts: &LUOptions,
1246) -> SparseResult<(usize, usize)>
1247where
1248 T: Float + SparseElement + Debug + Copy,
1249{
1250 let n = matrix.n;
1251
1252 match &opts.pivoting {
1253 PivotingStrategy::None => {
1254 Ok((k, k))
1256 }
1257
1258 PivotingStrategy::Partial => {
1259 let mut max_val = T::sparse_zero();
1261 let mut pivot_row = k;
1262
1263 for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1264 let i = k + idx;
1265 let val = matrix.get(actual_row, col_perm[k]).abs();
1266 if val > max_val {
1267 max_val = val;
1268 pivot_row = i;
1269 }
1270 }
1271
1272 Ok((pivot_row, k))
1273 }
1274
1275 PivotingStrategy::Threshold(threshold) => {
1276 let threshold_val = T::from(*threshold).unwrap();
1278 let mut max_val = T::sparse_zero();
1279 let mut pivot_row = k;
1280
1281 for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1282 let i = k + idx;
1283 let val = matrix.get(actual_row, col_perm[k]).abs();
1284 if val > max_val {
1285 max_val = val;
1286 pivot_row = i;
1287 }
1288 if val >= threshold_val {
1290 pivot_row = i;
1291 break;
1292 }
1293 }
1294
1295 Ok((pivot_row, k))
1296 }
1297
1298 PivotingStrategy::ScaledPartial => {
1299 let mut max_ratio = T::sparse_zero();
1301 let mut pivot_row = k;
1302
1303 for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1304 let i = k + idx;
1305 let val = matrix.get(actual_row, col_perm[k]).abs();
1306 let scale = row_scales[actual_row];
1307
1308 let ratio = if scale > T::sparse_zero() {
1309 val / scale
1310 } else {
1311 val
1312 };
1313
1314 if ratio > max_ratio {
1315 max_ratio = ratio;
1316 pivot_row = i;
1317 }
1318 }
1319
1320 Ok((pivot_row, k))
1321 }
1322
1323 PivotingStrategy::Complete => {
1324 let mut max_val = T::sparse_zero();
1326 let mut pivot_row = k;
1327 let mut pivot_col = k;
1328
1329 for (i_idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1330 let i = k + i_idx;
1331 for (j_idx, &actual_col) in col_perm.iter().enumerate().skip(k).take(n - k) {
1332 let j = k + j_idx;
1333 let val = matrix.get(actual_row, actual_col).abs();
1334 if val > max_val {
1335 max_val = val;
1336 pivot_row = i;
1337 pivot_col = j;
1338 }
1339 }
1340 }
1341
1342 Ok((pivot_row, pivot_col))
1343 }
1344
1345 PivotingStrategy::Rook => {
1346 let mut best_row = k;
1348 let mut best_col = k;
1349 let mut max_val = T::sparse_zero();
1350
1351 for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1353 let i = k + idx;
1354 let val = matrix.get(actual_row, col_perm[k]).abs();
1355 if val > max_val {
1356 max_val = val;
1357 best_row = i;
1358 }
1359 }
1360
1361 if max_val > T::from(opts.zero_threshold).unwrap() {
1363 let actual_best_row = row_perm[best_row];
1364 let mut col_max = T::sparse_zero();
1365
1366 for (idx, &actual_col) in col_perm.iter().enumerate().skip(k).take(n - k) {
1367 let j = k + idx;
1368 let val = matrix.get(actual_best_row, actual_col).abs();
1369 if val > col_max {
1370 col_max = val;
1371 best_col = j;
1372 }
1373 }
1374
1375 let improvement_threshold = T::from(1.5).unwrap();
1377 if col_max > max_val * improvement_threshold {
1378 max_val = T::sparse_zero();
1380 for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1381 let i = k + idx;
1382 let val = matrix.get(actual_row, col_perm[best_col]).abs();
1383 if val > max_val {
1384 max_val = val;
1385 best_row = i;
1386 }
1387 }
1388 }
1389 }
1390
1391 Ok((best_row, best_col))
1392 }
1393 }
1394}
1395
/// COO buffers for both LU factors, in the order
/// `(l_rows, l_cols, l_vals, u_rows, u_cols, u_vals)`.
type LuFactors<T> = (Vec<usize>, Vec<usize>, Vec<T>, Vec<usize>, Vec<usize>, Vec<T>);
1405
1406#[allow(dead_code)]
1407fn extract_lu_factors<T>(matrix: &SparseWorkingMatrix<T>, p: &[usize], n: usize) -> LuFactors<T>
1408where
1409 T: Float + SparseElement + Debug + Copy,
1410{
1411 let mut l_rows = Vec::new();
1412 let mut l_cols = Vec::new();
1413 let mut l_vals = Vec::new();
1414 let mut u_rows = Vec::new();
1415 let mut u_cols = Vec::new();
1416 let mut u_vals = Vec::new();
1417
1418 #[allow(clippy::needless_range_loop)]
1419 for i in 0..n {
1420 let actual_row = p[i];
1421
1422 l_rows.push(i);
1424 l_cols.push(i);
1425 l_vals.push(T::sparse_one());
1426
1427 for j in 0..n {
1428 let val = matrix.get(actual_row, j);
1429 if !SparseElement::is_zero(&val) {
1430 if j < i {
1431 l_rows.push(i);
1433 l_cols.push(j);
1434 l_vals.push(val);
1435 } else {
1436 u_rows.push(i);
1438 u_cols.push(j);
1439 u_vals.push(val);
1440 }
1441 }
1442 }
1443 }
1444
1445 (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals)
1446}
1447
1448#[allow(dead_code)]
1450fn extract_lower_triangular<T>(
1451 matrix: &SparseWorkingMatrix<T>,
1452 n: usize,
1453) -> (Vec<usize>, Vec<usize>, Vec<T>)
1454where
1455 T: Float + SparseElement + Debug + Copy,
1456{
1457 let mut rows = Vec::new();
1458 let mut cols = Vec::new();
1459 let mut vals = Vec::new();
1460
1461 for i in 0..n {
1462 for j in 0..=i {
1463 let val = matrix.get(i, j);
1464 if !SparseElement::is_zero(&val) {
1465 rows.push(i);
1466 cols.push(j);
1467 vals.push(val);
1468 }
1469 }
1470 }
1471
1472 (rows, cols, vals)
1473}
1474
1475#[allow(dead_code)]
1477fn dense_to_sparse<T>(matrix: &Array2<T>) -> SparseResult<CsrArray<T>>
1478where
1479 T: Float + SparseElement + Debug + Copy,
1480{
1481 let (m, n) = matrix.dim();
1482 let mut rows = Vec::new();
1483 let mut cols = Vec::new();
1484 let mut vals = Vec::new();
1485
1486 for i in 0..m {
1487 for j in 0..n {
1488 let val = matrix[[i, j]];
1489 if !SparseElement::is_zero(&val) {
1490 rows.push(i);
1491 cols.push(j);
1492 vals.push(val);
1493 }
1494 }
1495 }
1496
1497 CsrArray::from_triplets(&rows, &cols, &vals, (m, n), false)
1498}
1499
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Non-symmetric 3x3 fixture: `[[2, 1, 0], [1, 3, 0], [0, 2, 4]]`.
    fn create_test_matrix() -> CsrArray<f64> {
        let (rows, cols, data) = (
            vec![0, 0, 1, 1, 2, 2],
            vec![0, 1, 0, 1, 1, 2],
            vec![2.0, 1.0, 1.0, 3.0, 2.0, 4.0],
        );
        CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap()
    }

    /// Lower triangle of the SPD matrix `[[4, 2, 1], [2, 5, 3], [1, 3, 6]]`.
    fn create_spd_matrix() -> CsrArray<f64> {
        let (rows, cols, data) = (
            vec![0, 1, 1, 2, 2, 2],
            vec![0, 0, 1, 0, 1, 2],
            vec![4.0, 2.0, 5.0, 1.0, 3.0, 6.0],
        );
        CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap()
    }

    #[test]
    fn test_lu_decomposition() {
        let input = create_test_matrix();
        let factors = lu_decomposition(&input, 0.1).unwrap();

        assert!(factors.success);
        assert_eq!(factors.l.shape(), (3, 3));
        assert_eq!(factors.u.shape(), (3, 3));
        assert_eq!(factors.p.len(), 3);
    }

    #[test]
    fn test_qr_decomposition() {
        // 3x2 input with entries (0,0)=1, (1,0)=2, (2,1)=3.
        let (rows, cols, data) = (vec![0, 1, 2], vec![0, 0, 1], vec![1.0, 2.0, 3.0]);
        let input = CsrArray::from_triplets(&rows, &cols, &data, (3, 2), false).unwrap();

        let factors = qr_decomposition(&input).unwrap();

        assert!(factors.success);
        assert_eq!(factors.q.shape(), (3, 2));
        assert_eq!(factors.r.shape(), (2, 2));
    }

    #[test]
    fn test_cholesky_decomposition() {
        let input = create_spd_matrix();
        let factor = cholesky_decomposition(&input).unwrap();

        assert!(factor.success);
        assert_eq!(factor.l.shape(), (3, 3));
    }

    #[test]
    fn test_incomplete_lu() {
        let input = create_test_matrix();
        let options = ILUOptions {
            drop_tol: 1e-6,
            ..Default::default()
        };

        let factors = incomplete_lu(&input, Some(options)).unwrap();

        assert!(factors.success);
        assert_eq!(factors.l.shape(), (3, 3));
        assert_eq!(factors.u.shape(), (3, 3));
    }

    #[test]
    fn test_incomplete_cholesky() {
        let input = create_spd_matrix();
        let options = ICOptions {
            drop_tol: 1e-6,
            ..Default::default()
        };

        let factor = incomplete_cholesky(&input, Some(options)).unwrap();

        assert!(factor.success);
        assert_eq!(factor.l.shape(), (3, 3));
    }

    #[test]
    fn test_sparse_working_matrix() {
        let (rows, cols, vals) = (vec![0, 1, 2], vec![0, 1, 2], vec![1.0, 2.0, 3.0]);
        let mut scratch = SparseWorkingMatrix::from_triplets(&rows, &cols, &vals, 3);

        // Diagonal entries round-trip; empty positions read as zero.
        assert_eq!(scratch.get(0, 0), 1.0);
        assert_eq!(scratch.get(1, 1), 2.0);
        assert_eq!(scratch.get(2, 2), 3.0);
        assert_eq!(scratch.get(0, 1), 0.0);

        scratch.set(0, 1, 5.0);
        assert_eq!(scratch.get(0, 1), 5.0);

        // Writing zero removes the entry from the pattern.
        scratch.set(0, 1, 0.0);
        assert_eq!(scratch.get(0, 1), 0.0);
        assert!(!scratch.has_entry(0, 1));
    }

    #[test]
    fn test_dense_to_sparse_conversion() {
        let dense = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 2.0, 3.0]).unwrap();
        let csr = dense_to_sparse(&dense).unwrap();

        assert_eq!(csr.nnz(), 3);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(0, 1), 0.0);
        assert_eq!(csr.get(1, 0), 2.0);
        assert_eq!(csr.get(1, 1), 3.0);
    }
}