1use crate::csr_array::CsrArray;
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::numeric::Float;
11use scirs2_core::SparseElement;
12use std::collections::HashMap;
13use std::fmt::Debug;
14use std::ops::{Add, Div, Mul, Sub};
15
16#[derive(Debug, Clone)]
18pub struct LUResult<T>
19where
20 T: Float + SparseElement + Debug + Copy + 'static,
21{
22 pub l: CsrArray<T>,
24 pub u: CsrArray<T>,
26 pub p: Array1<usize>,
28 pub success: bool,
30}
31
32#[derive(Debug, Clone)]
34pub struct QRResult<T>
35where
36 T: Float + SparseElement + Debug + Copy + 'static,
37{
38 pub q: CsrArray<T>,
40 pub r: CsrArray<T>,
42 pub success: bool,
44}
45
46#[derive(Debug, Clone)]
48pub struct CholeskyResult<T>
49where
50 T: Float + SparseElement + Debug + Copy + 'static,
51{
52 pub l: CsrArray<T>,
54 pub success: bool,
56}
57
58#[derive(Debug, Clone)]
60pub struct PivotedCholeskyResult<T>
61where
62 T: Float + SparseElement + Debug + Copy + 'static,
63{
64 pub l: CsrArray<T>,
66 pub p: Array1<usize>,
68 pub rank: usize,
70 pub success: bool,
72}
73
/// Rule used to pick the pivot at each LU elimination step.
#[derive(Debug, Clone, Default)]
pub enum PivotingStrategy {
    /// Always use the current diagonal position.
    None,
    /// Largest magnitude in the current column (the default).
    #[default]
    Partial,
    /// First candidate row whose magnitude reaches the given absolute
    /// threshold; falls back to the largest magnitude if none does.
    Threshold(f64),
    /// Largest magnitude relative to the row's largest original magnitude.
    ScaledPartial,
    /// Largest magnitude in the whole trailing submatrix (rows and columns).
    Complete,
    /// Row/column search that may move to another column when it offers a
    /// markedly larger entry.
    Rook,
}
91
92#[derive(Debug, Clone)]
94pub struct LUOptions {
95 pub pivoting: PivotingStrategy,
97 pub zero_threshold: f64,
99 pub check_singular: bool,
101}
102
103impl Default for LUOptions {
104 fn default() -> Self {
105 Self {
106 pivoting: PivotingStrategy::default(),
107 zero_threshold: 1e-14,
108 check_singular: true,
109 }
110 }
111}
112
113#[derive(Debug, Clone)]
115pub struct ILUOptions {
116 pub drop_tol: f64,
118 pub fill_factor: f64,
120 pub max_fill_per_row: usize,
122 pub pivoting: PivotingStrategy,
124}
125
126impl Default for ILUOptions {
127 fn default() -> Self {
128 Self {
129 drop_tol: 1e-4,
130 fill_factor: 2.0,
131 max_fill_per_row: 20,
132 pivoting: PivotingStrategy::default(),
133 }
134 }
135}
136
/// Configuration for `incomplete_cholesky`.
#[derive(Debug, Clone)]
pub struct ICOptions {
    /// Factor entries smaller than this in magnitude are dropped.
    pub drop_tol: f64,
    /// Intended fill limit relative to the input. NOTE(review): not read by
    /// `incomplete_cholesky` in this file.
    pub fill_factor: f64,
    /// Intended per-row fill cap. NOTE(review): not read by
    /// `incomplete_cholesky` in this file.
    pub max_fill_per_row: usize,
}
147
148impl Default for ICOptions {
149 fn default() -> Self {
150 Self {
151 drop_tol: 1e-4,
152 fill_factor: 2.0,
153 max_fill_per_row: 20,
154 }
155 }
156}
157
158#[allow(dead_code)]
187pub fn lu_decomposition<T, S>(_matrix: &S, pivotthreshold: f64) -> SparseResult<LUResult<T>>
188where
189 T: Float
190 + SparseElement
191 + Debug
192 + Copy
193 + Add<Output = T>
194 + Sub<Output = T>
195 + Mul<Output = T>
196 + Div<Output = T>,
197 S: SparseArray<T>,
198{
199 let options = LUOptions {
201 pivoting: PivotingStrategy::Threshold(pivotthreshold),
202 zero_threshold: 1e-14,
203 check_singular: true,
204 };
205
206 lu_decomposition_with_options(_matrix, Some(options))
207}
208
209#[allow(dead_code)]
244pub fn lu_decomposition_with_options<T, S>(
245 matrix: &S,
246 options: Option<LUOptions>,
247) -> SparseResult<LUResult<T>>
248where
249 T: Float
250 + SparseElement
251 + Debug
252 + Copy
253 + Add<Output = T>
254 + Sub<Output = T>
255 + Mul<Output = T>
256 + Div<Output = T>,
257 S: SparseArray<T>,
258{
259 let opts = options.unwrap_or_default();
260 let (n, m) = matrix.shape();
261 if n != m {
262 return Err(SparseError::ValueError(
263 "Matrix must be square for LU decomposition".to_string(),
264 ));
265 }
266
267 let (row_indices, col_indices, values) = matrix.find();
269 let mut working_matrix = SparseWorkingMatrix::from_triplets(
270 row_indices.as_slice().expect("Operation failed"),
271 col_indices.as_slice().expect("Operation failed"),
272 values.as_slice().expect("Operation failed"),
273 n,
274 );
275
276 let mut row_perm: Vec<usize> = (0..n).collect();
278 let mut col_perm: Vec<usize> = (0..n).collect();
279
280 let mut row_scales = vec![T::sparse_one(); n];
282 if matches!(opts.pivoting, PivotingStrategy::ScaledPartial) {
283 for (i, scale) in row_scales.iter_mut().enumerate().take(n) {
284 let row_data = working_matrix.get_row(i);
285 let max_val = row_data
286 .values()
287 .map(|&v| v.abs())
288 .fold(T::sparse_zero(), |a, b| if a > b { a } else { b });
289 if max_val > T::sparse_zero() {
290 *scale = max_val;
291 }
292 }
293 }
294
295 for k in 0..n - 1 {
297 let (pivot_row, pivot_col) =
299 find_enhanced_pivot(&working_matrix, k, &row_perm, &col_perm, &row_scales, &opts)?;
300
301 if pivot_row != k {
303 row_perm.swap(k, pivot_row);
304 }
305 if pivot_col != k
306 && matches!(
307 opts.pivoting,
308 PivotingStrategy::Complete | PivotingStrategy::Rook
309 )
310 {
311 col_perm.swap(k, pivot_col);
312 for &row_idx in row_perm.iter().take(n) {
314 let temp = working_matrix.get(row_idx, k);
315 working_matrix.set(row_idx, k, working_matrix.get(row_idx, pivot_col));
316 working_matrix.set(row_idx, pivot_col, temp);
317 }
318 }
319
320 let actual_pivot_row = row_perm[k];
321 let actual_pivot_col = col_perm[k];
322 let pivot_value = working_matrix.get(actual_pivot_row, actual_pivot_col);
323
324 if opts.check_singular
326 && pivot_value.abs() < T::from(opts.zero_threshold).expect("Operation failed")
327 {
328 return Ok(LUResult {
329 l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
330 u: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
331 p: Array1::from_vec(row_perm),
332 success: false,
333 });
334 }
335
336 for &actual_row_i in row_perm.iter().take(n).skip(k + 1) {
338 let factor = working_matrix.get(actual_row_i, actual_pivot_col) / pivot_value;
339
340 if !SparseElement::is_zero(&factor) {
341 working_matrix.set(actual_row_i, actual_pivot_col, factor);
343
344 let pivot_row_data = working_matrix.get_row(actual_pivot_row);
346 for (col, &value) in &pivot_row_data {
347 if *col > k {
348 let old_val = working_matrix.get(actual_row_i, *col);
349 working_matrix.set(actual_row_i, *col, old_val - factor * value);
350 }
351 }
352 }
353 }
354 }
355
356 let (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals) =
358 extract_lu_factors(&working_matrix, &row_perm, n);
359
360 let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
361 let u = CsrArray::from_triplets(&u_rows, &u_cols, &u_vals, (n, n), false)?;
362
363 Ok(LUResult {
364 l,
365 u,
366 p: Array1::from_vec(row_perm),
367 success: true,
368 })
369}
370
371#[allow(dead_code)]
398pub fn qr_decomposition<T, S>(matrix: &S) -> SparseResult<QRResult<T>>
399where
400 T: Float
401 + SparseElement
402 + Debug
403 + Copy
404 + Add<Output = T>
405 + Sub<Output = T>
406 + Mul<Output = T>
407 + Div<Output = T>,
408 S: SparseArray<T>,
409{
410 let (m, n) = matrix.shape();
411
412 let dense_matrix = matrix.to_array();
414
415 let mut q = Array2::zeros((m, n));
417 let mut r = Array2::zeros((n, n));
418
419 for j in 0..n {
420 for i in 0..m {
422 q[[i, j]] = dense_matrix[[i, j]];
423 }
424
425 for k in 0..j {
427 let mut dot = T::sparse_zero();
428 for i in 0..m {
429 dot = dot + q[[i, k]] * dense_matrix[[i, j]];
430 }
431 r[[k, j]] = dot;
432
433 for i in 0..m {
434 q[[i, j]] = q[[i, j]] - dot * q[[i, k]];
435 }
436 }
437
438 let mut norm = T::sparse_zero();
440 for i in 0..m {
441 norm = norm + q[[i, j]] * q[[i, j]];
442 }
443 norm = norm.sqrt();
444 r[[j, j]] = norm;
445
446 if !SparseElement::is_zero(&norm) {
447 for i in 0..m {
448 q[[i, j]] = q[[i, j]] / norm;
449 }
450 }
451 }
452
453 let q_sparse = dense_to_sparse(&q)?;
455 let r_sparse = dense_to_sparse(&r)?;
456
457 Ok(QRResult {
458 q: q_sparse,
459 r: r_sparse,
460 success: true,
461 })
462}
463
464#[allow(dead_code)]
492pub fn cholesky_decomposition<T, S>(matrix: &S) -> SparseResult<CholeskyResult<T>>
493where
494 T: Float
495 + SparseElement
496 + Debug
497 + Copy
498 + Add<Output = T>
499 + Sub<Output = T>
500 + Mul<Output = T>
501 + Div<Output = T>,
502 S: SparseArray<T>,
503{
504 let (n, m) = matrix.shape();
505 if n != m {
506 return Err(SparseError::ValueError(
507 "Matrix must be square for Cholesky decomposition".to_string(),
508 ));
509 }
510
511 let (row_indices, col_indices, values) = matrix.find();
513 let mut working_matrix = SparseWorkingMatrix::from_triplets(
514 row_indices.as_slice().expect("Operation failed"),
515 col_indices.as_slice().expect("Operation failed"),
516 values.as_slice().expect("Operation failed"),
517 n,
518 );
519
520 for k in 0..n {
522 let mut sum = T::sparse_zero();
524 for j in 0..k {
525 let l_kj = working_matrix.get(k, j);
526 sum = sum + l_kj * l_kj;
527 }
528
529 let a_kk = working_matrix.get(k, k);
530 let diag_val = a_kk - sum;
531
532 if diag_val <= T::sparse_zero() {
533 return Ok(CholeskyResult {
534 l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
535 success: false,
536 });
537 }
538
539 let l_kk = diag_val.sqrt();
540 working_matrix.set(k, k, l_kk);
541
542 for i in (k + 1)..n {
544 let mut sum = T::sparse_zero();
545 for j in 0..k {
546 sum = sum + working_matrix.get(i, j) * working_matrix.get(k, j);
547 }
548
549 let a_ik = working_matrix.get(i, k);
550 let l_ik = (a_ik - sum) / l_kk;
551 working_matrix.set(i, k, l_ik);
552 }
553 }
554
555 let (l_rows, l_cols, l_vals) = extract_lower_triangular(&working_matrix, n);
557 let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
558
559 Ok(CholeskyResult { l, success: true })
560}
561
562#[allow(dead_code)]
592pub fn pivoted_cholesky_decomposition<T, S>(
593 matrix: &S,
594 threshold: Option<T>,
595) -> SparseResult<PivotedCholeskyResult<T>>
596where
597 T: Float
598 + SparseElement
599 + Debug
600 + Copy
601 + Add<Output = T>
602 + Sub<Output = T>
603 + Mul<Output = T>
604 + Div<Output = T>,
605 S: SparseArray<T>,
606{
607 let (n, m) = matrix.shape();
608 if n != m {
609 return Err(SparseError::ValueError(
610 "Matrix must be square for Cholesky decomposition".to_string(),
611 ));
612 }
613
614 let threshold = threshold.unwrap_or_else(|| T::from(1e-12).expect("Operation failed"));
615
616 let (row_indices, col_indices, values) = matrix.find();
618 let mut working_matrix = SparseWorkingMatrix::from_triplets(
619 row_indices.as_slice().expect("Operation failed"),
620 col_indices.as_slice().expect("Operation failed"),
621 values.as_slice().expect("Operation failed"),
622 n,
623 );
624
625 let mut perm: Vec<usize> = (0..n).collect();
627 let mut rank = 0;
628
629 for k in 0..n {
631 let mut max_diag = T::sparse_zero();
633 let mut pivot_idx = k;
634
635 for i in k..n {
636 let mut diag_val = working_matrix.get(perm[i], perm[i]);
637 for j in 0..k {
638 let l_ij = working_matrix.get(perm[i], perm[j]);
639 diag_val = diag_val - l_ij * l_ij;
640 }
641 if diag_val > max_diag {
642 max_diag = diag_val;
643 pivot_idx = i;
644 }
645 }
646
647 if max_diag <= threshold {
649 break;
650 }
651
652 if pivot_idx != k {
654 perm.swap(k, pivot_idx);
655 }
656
657 let l_kk = max_diag.sqrt();
659 working_matrix.set(perm[k], perm[k], l_kk);
660 rank += 1;
661
662 for i in (k + 1)..n {
664 let mut sum = T::sparse_zero();
665 for j in 0..k {
666 sum = sum
667 + working_matrix.get(perm[i], perm[j]) * working_matrix.get(perm[k], perm[j]);
668 }
669
670 let a_ik = working_matrix.get(perm[i], perm[k]);
671 let l_ik = (a_ik - sum) / l_kk;
672 working_matrix.set(perm[i], perm[k], l_ik);
673 }
674 }
675
676 let mut l_rows = Vec::new();
678 let mut l_cols = Vec::new();
679 let mut l_vals = Vec::new();
680
681 for i in 0..rank {
682 for j in 0..=i {
683 let val = working_matrix.get(perm[i], perm[j]);
684 if val != T::sparse_zero() {
685 l_rows.push(i);
686 l_cols.push(j);
687 l_vals.push(val);
688 }
689 }
690 }
691
692 let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, rank), false)?;
693 let p = Array1::from_vec(perm);
694
695 Ok(PivotedCholeskyResult {
696 l,
697 p,
698 rank,
699 success: true,
700 })
701}
702
703#[derive(Debug, Clone)]
705pub struct LDLTResult<T>
706where
707 T: Float + SparseElement + Debug + Copy + 'static,
708{
709 pub l: CsrArray<T>,
711 pub d: Array1<T>,
713 pub p: Array1<usize>,
715 pub success: bool,
717}
718
719#[allow(dead_code)]
750pub fn ldlt_decomposition<T, S>(
751 matrix: &S,
752 pivoting: Option<bool>,
753 threshold: Option<T>,
754) -> SparseResult<LDLTResult<T>>
755where
756 T: Float
757 + SparseElement
758 + Debug
759 + Copy
760 + Add<Output = T>
761 + Sub<Output = T>
762 + Mul<Output = T>
763 + Div<Output = T>,
764 S: SparseArray<T>,
765{
766 let (n, m) = matrix.shape();
767 if n != m {
768 return Err(SparseError::ValueError(
769 "Matrix must be square for LDLT decomposition".to_string(),
770 ));
771 }
772
773 let use_pivoting = pivoting.unwrap_or(true);
774 let threshold = threshold.unwrap_or_else(|| T::from(1e-12).expect("Operation failed"));
775
776 let (row_indices, col_indices, values) = matrix.find();
778 let mut working_matrix = SparseWorkingMatrix::from_triplets(
779 row_indices.as_slice().expect("Operation failed"),
780 col_indices.as_slice().expect("Operation failed"),
781 values.as_slice().expect("Operation failed"),
782 n,
783 );
784
785 let mut perm: Vec<usize> = (0..n).collect();
787 let mut d_values = vec![T::sparse_zero(); n];
788
789 for k in 0..n {
791 if use_pivoting {
793 let pivot_idx = find_ldlt_pivot(&working_matrix, k, &perm, threshold);
794 if pivot_idx != k {
795 perm.swap(k, pivot_idx);
796 }
797 }
798
799 let actual_k = perm[k];
800
801 let mut diag_val = working_matrix.get(actual_k, actual_k);
803 for j in 0..k {
804 let l_kj = working_matrix.get(actual_k, perm[j]);
805 diag_val = diag_val - l_kj * l_kj * d_values[j];
806 }
807
808 d_values[k] = diag_val;
809
810 if diag_val.abs() < threshold {
812 return Ok(LDLTResult {
813 l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
814 d: Array1::from_vec(d_values),
815 p: Array1::from_vec(perm),
816 success: false,
817 });
818 }
819
820 for i in (k + 1)..n {
822 let actual_i = perm[i];
823 let mut l_ik = working_matrix.get(actual_i, actual_k);
824
825 for j in 0..k {
826 l_ik = l_ik
827 - working_matrix.get(actual_i, perm[j])
828 * working_matrix.get(actual_k, perm[j])
829 * d_values[j];
830 }
831
832 l_ik = l_ik / diag_val;
833 working_matrix.set(actual_i, actual_k, l_ik);
834 }
835
836 working_matrix.set(actual_k, actual_k, T::sparse_one());
838 }
839
840 let (l_rows, l_cols, l_vals) = extract_unit_lower_triangular(&working_matrix, &perm, n);
842 let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
843
844 Ok(LDLTResult {
845 l,
846 d: Array1::from_vec(d_values),
847 p: Array1::from_vec(perm),
848 success: true,
849 })
850}
851
852#[allow(dead_code)]
854fn find_ldlt_pivot<T>(
855 matrix: &SparseWorkingMatrix<T>,
856 k: usize,
857 perm: &[usize],
858 threshold: T,
859) -> usize
860where
861 T: Float + SparseElement + Debug + Copy,
862{
863 let n = matrix.n;
864 let mut max_val = T::sparse_zero();
865 let mut pivot_idx = k;
866
867 for (i, &actual_i) in perm.iter().enumerate().take(n).skip(k) {
869 let diag_val = matrix.get(actual_i, actual_i).abs();
870
871 if diag_val > max_val {
872 max_val = diag_val;
873 pivot_idx = i;
874 }
875 }
876
877 if max_val >= threshold {
879 pivot_idx
880 } else {
881 k }
883}
884
885#[allow(dead_code)]
887fn extract_unit_lower_triangular<T>(
888 matrix: &SparseWorkingMatrix<T>,
889 perm: &[usize],
890 n: usize,
891) -> (Vec<usize>, Vec<usize>, Vec<T>)
892where
893 T: Float + SparseElement + Debug + Copy,
894{
895 let mut rows = Vec::new();
896 let mut cols = Vec::new();
897 let mut vals = Vec::new();
898
899 for i in 0..n {
900 let actual_i = perm[i];
901
902 rows.push(i);
904 cols.push(i);
905 vals.push(T::sparse_one());
906
907 for (j, &perm_j) in perm.iter().enumerate().take(i) {
909 let val = matrix.get(actual_i, perm_j);
910 if val != T::sparse_zero() {
911 rows.push(i);
912 cols.push(j);
913 vals.push(val);
914 }
915 }
916 }
917
918 (rows, cols, vals)
919}
920
921#[allow(dead_code)]
935pub fn incomplete_lu<T, S>(matrix: &S, options: Option<ILUOptions>) -> SparseResult<LUResult<T>>
936where
937 T: Float
938 + SparseElement
939 + Debug
940 + Copy
941 + Add<Output = T>
942 + Sub<Output = T>
943 + Mul<Output = T>
944 + Div<Output = T>,
945 S: SparseArray<T>,
946{
947 let opts = options.unwrap_or_default();
948 let (n, m) = matrix.shape();
949
950 if n != m {
951 return Err(SparseError::ValueError(
952 "Matrix must be square for ILU decomposition".to_string(),
953 ));
954 }
955
956 let (row_indices, col_indices, values) = matrix.find();
958 let mut working_matrix = SparseWorkingMatrix::from_triplets(
959 row_indices.as_slice().expect("Operation failed"),
960 col_indices.as_slice().expect("Operation failed"),
961 values.as_slice().expect("Operation failed"),
962 n,
963 );
964
965 for k in 0..n - 1 {
967 let pivot_val = working_matrix.get(k, k);
968
969 if pivot_val.abs() < T::from(1e-14).expect("Operation failed") {
970 continue; }
972
973 let col_k_entries = working_matrix.get_column_below_diagonal(k);
975
976 for &row_i in &col_k_entries {
977 let factor = working_matrix.get(row_i, k) / pivot_val;
978
979 if factor.abs() < T::from(opts.drop_tol).expect("Operation failed") {
981 working_matrix.set(row_i, k, T::sparse_zero());
982 continue;
983 }
984
985 working_matrix.set(row_i, k, factor);
986
987 let row_k_entries = working_matrix.get_row_after_column(k, k);
989 for (col_j, &val_kj) in &row_k_entries {
990 if working_matrix.has_entry(row_i, *col_j) {
991 let old_val = working_matrix.get(row_i, *col_j);
992 let new_val = old_val - factor * val_kj;
993
994 if new_val.abs() < T::from(opts.drop_tol).expect("Operation failed") {
996 working_matrix.set(row_i, *col_j, T::sparse_zero());
997 } else {
998 working_matrix.set(row_i, *col_j, new_val);
999 }
1000 }
1001 }
1002 }
1003 }
1004
1005 let identity_p: Vec<usize> = (0..n).collect();
1007 let (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals) =
1008 extract_lu_factors(&working_matrix, &identity_p, n);
1009
1010 let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
1011 let u = CsrArray::from_triplets(&u_rows, &u_cols, &u_vals, (n, n), false)?;
1012
1013 Ok(LUResult {
1014 l,
1015 u,
1016 p: Array1::from_vec(identity_p),
1017 success: true,
1018 })
1019}
1020
1021#[allow(dead_code)]
1035pub fn incomplete_cholesky<T, S>(
1036 matrix: &S,
1037 options: Option<ICOptions>,
1038) -> SparseResult<CholeskyResult<T>>
1039where
1040 T: Float
1041 + SparseElement
1042 + Debug
1043 + Copy
1044 + Add<Output = T>
1045 + Sub<Output = T>
1046 + Mul<Output = T>
1047 + Div<Output = T>,
1048 S: SparseArray<T>,
1049{
1050 let opts = options.unwrap_or_default();
1051 let (n, m) = matrix.shape();
1052
1053 if n != m {
1054 return Err(SparseError::ValueError(
1055 "Matrix must be square for IC decomposition".to_string(),
1056 ));
1057 }
1058
1059 let (row_indices, col_indices, values) = matrix.find();
1061 let mut working_matrix = SparseWorkingMatrix::from_triplets(
1062 row_indices.as_slice().expect("Operation failed"),
1063 col_indices.as_slice().expect("Operation failed"),
1064 values.as_slice().expect("Operation failed"),
1065 n,
1066 );
1067
1068 for k in 0..n {
1070 let mut sum = T::sparse_zero();
1072 let row_k_before_k = working_matrix.get_row_before_column(k, k);
1073 for &val_kj in row_k_before_k.values() {
1074 sum = sum + val_kj * val_kj;
1075 }
1076
1077 let a_kk = working_matrix.get(k, k);
1078 let diag_val = a_kk - sum;
1079
1080 if diag_val <= T::sparse_zero() {
1081 return Ok(CholeskyResult {
1082 l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
1083 success: false,
1084 });
1085 }
1086
1087 let l_kk = diag_val.sqrt();
1088 working_matrix.set(k, k, l_kk);
1089
1090 let col_k_below = working_matrix.get_column_below_diagonal(k);
1092 for &row_i in &col_k_below {
1093 let mut sum = T::sparse_zero();
1094 let row_i_before_k = working_matrix.get_row_before_column(row_i, k);
1095 let row_k_before_k = working_matrix.get_row_before_column(k, k);
1096
1097 for (col_j, &val_ij) in &row_i_before_k {
1099 if let Some(&val_kj) = row_k_before_k.get(col_j) {
1100 sum = sum + val_ij * val_kj;
1101 }
1102 }
1103
1104 let a_ik = working_matrix.get(row_i, k);
1105 let l_ik = (a_ik - sum) / l_kk;
1106
1107 if l_ik.abs() < T::from(opts.drop_tol).expect("Operation failed") {
1109 working_matrix.set(row_i, k, T::sparse_zero());
1110 } else {
1111 working_matrix.set(row_i, k, l_ik);
1112 }
1113 }
1114 }
1115
1116 let (l_rows, l_cols, l_vals) = extract_lower_triangular(&working_matrix, n);
1118 let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
1119
1120 Ok(CholeskyResult { l, success: true })
1121}
1122
1123struct SparseWorkingMatrix<T>
1125where
1126 T: Float + SparseElement + Debug + Copy,
1127{
1128 data: HashMap<(usize, usize), T>,
1129 n: usize,
1130}
1131
1132impl<T> SparseWorkingMatrix<T>
1133where
1134 T: Float
1135 + SparseElement
1136 + Debug
1137 + Copy
1138 + Add<Output = T>
1139 + Sub<Output = T>
1140 + Mul<Output = T>
1141 + Div<Output = T>,
1142{
1143 fn from_triplets(rows: &[usize], cols: &[usize], values: &[T], n: usize) -> Self {
1144 let mut data = HashMap::new();
1145
1146 for (i, (&row, &col)) in rows.iter().zip(cols.iter()).enumerate() {
1147 data.insert((row, col), values[i]);
1148 }
1149
1150 Self { data, n }
1151 }
1152
1153 fn get(&self, row: usize, col: usize) -> T {
1154 self.data
1155 .get(&(row, col))
1156 .copied()
1157 .unwrap_or(T::sparse_zero())
1158 }
1159
1160 fn set(&mut self, row: usize, col: usize, value: T) {
1161 if SparseElement::is_zero(&value) {
1162 self.data.remove(&(row, col));
1163 } else {
1164 self.data.insert((row, col), value);
1165 }
1166 }
1167
1168 fn has_entry(&self, row: usize, col: usize) -> bool {
1169 self.data.contains_key(&(row, col))
1170 }
1171
1172 fn get_row(&self, row: usize) -> HashMap<usize, T> {
1173 let mut result = HashMap::new();
1174 for (&(r, c), &value) in &self.data {
1175 if r == row {
1176 result.insert(c, value);
1177 }
1178 }
1179 result
1180 }
1181
1182 fn get_row_after_column(&self, row: usize, col: usize) -> HashMap<usize, T> {
1183 let mut result = HashMap::new();
1184 for (&(r, c), &value) in &self.data {
1185 if r == row && c > col {
1186 result.insert(c, value);
1187 }
1188 }
1189 result
1190 }
1191
1192 fn get_row_before_column(&self, row: usize, col: usize) -> HashMap<usize, T> {
1193 let mut result = HashMap::new();
1194 for (&(r, c), &value) in &self.data {
1195 if r == row && c < col {
1196 result.insert(c, value);
1197 }
1198 }
1199 result
1200 }
1201
1202 fn get_column_below_diagonal(&self, col: usize) -> Vec<usize> {
1203 let mut result = Vec::new();
1204 for &(r, c) in self.data.keys() {
1205 if c == col && r > col {
1206 result.push(r);
1207 }
1208 }
1209 result.sort();
1210 result
1211 }
1212}
1213
1214#[allow(dead_code)]
1216fn find_pivot<T>(
1217 matrix: &SparseWorkingMatrix<T>,
1218 k: usize,
1219 p: &[usize],
1220 threshold: f64,
1221) -> SparseResult<usize>
1222where
1223 T: Float + SparseElement + Debug + Copy,
1224{
1225 let opts = LUOptions {
1227 pivoting: PivotingStrategy::Threshold(threshold),
1228 zero_threshold: 1e-14,
1229 check_singular: true,
1230 };
1231
1232 let row_scales = vec![T::sparse_one(); matrix.n];
1233 let col_perm: Vec<usize> = (0..matrix.n).collect();
1234
1235 let (pivot_row, pivot_col) = find_enhanced_pivot(matrix, k, p, &col_perm, &row_scales, &opts)?;
1236 Ok(pivot_row)
1237}
1238
1239#[allow(dead_code)]
1241fn find_enhanced_pivot<T>(
1242 matrix: &SparseWorkingMatrix<T>,
1243 k: usize,
1244 row_perm: &[usize],
1245 col_perm: &[usize],
1246 row_scales: &[T],
1247 opts: &LUOptions,
1248) -> SparseResult<(usize, usize)>
1249where
1250 T: Float + SparseElement + Debug + Copy,
1251{
1252 let n = matrix.n;
1253
1254 match &opts.pivoting {
1255 PivotingStrategy::None => {
1256 Ok((k, k))
1258 }
1259
1260 PivotingStrategy::Partial => {
1261 let mut max_val = T::sparse_zero();
1263 let mut pivot_row = k;
1264
1265 for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1266 let i = k + idx;
1267 let val = matrix.get(actual_row, col_perm[k]).abs();
1268 if val > max_val {
1269 max_val = val;
1270 pivot_row = i;
1271 }
1272 }
1273
1274 Ok((pivot_row, k))
1275 }
1276
1277 PivotingStrategy::Threshold(threshold) => {
1278 let threshold_val = T::from(*threshold).expect("Operation failed");
1280 let mut max_val = T::sparse_zero();
1281 let mut pivot_row = k;
1282
1283 for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1284 let i = k + idx;
1285 let val = matrix.get(actual_row, col_perm[k]).abs();
1286 if val > max_val {
1287 max_val = val;
1288 pivot_row = i;
1289 }
1290 if val >= threshold_val {
1292 pivot_row = i;
1293 break;
1294 }
1295 }
1296
1297 Ok((pivot_row, k))
1298 }
1299
1300 PivotingStrategy::ScaledPartial => {
1301 let mut max_ratio = T::sparse_zero();
1303 let mut pivot_row = k;
1304
1305 for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1306 let i = k + idx;
1307 let val = matrix.get(actual_row, col_perm[k]).abs();
1308 let scale = row_scales[actual_row];
1309
1310 let ratio = if scale > T::sparse_zero() {
1311 val / scale
1312 } else {
1313 val
1314 };
1315
1316 if ratio > max_ratio {
1317 max_ratio = ratio;
1318 pivot_row = i;
1319 }
1320 }
1321
1322 Ok((pivot_row, k))
1323 }
1324
1325 PivotingStrategy::Complete => {
1326 let mut max_val = T::sparse_zero();
1328 let mut pivot_row = k;
1329 let mut pivot_col = k;
1330
1331 for (i_idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1332 let i = k + i_idx;
1333 for (j_idx, &actual_col) in col_perm.iter().enumerate().skip(k).take(n - k) {
1334 let j = k + j_idx;
1335 let val = matrix.get(actual_row, actual_col).abs();
1336 if val > max_val {
1337 max_val = val;
1338 pivot_row = i;
1339 pivot_col = j;
1340 }
1341 }
1342 }
1343
1344 Ok((pivot_row, pivot_col))
1345 }
1346
1347 PivotingStrategy::Rook => {
1348 let mut best_row = k;
1350 let mut best_col = k;
1351 let mut max_val = T::sparse_zero();
1352
1353 for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1355 let i = k + idx;
1356 let val = matrix.get(actual_row, col_perm[k]).abs();
1357 if val > max_val {
1358 max_val = val;
1359 best_row = i;
1360 }
1361 }
1362
1363 if max_val > T::from(opts.zero_threshold).expect("Operation failed") {
1365 let actual_best_row = row_perm[best_row];
1366 let mut col_max = T::sparse_zero();
1367
1368 for (idx, &actual_col) in col_perm.iter().enumerate().skip(k).take(n - k) {
1369 let j = k + idx;
1370 let val = matrix.get(actual_best_row, actual_col).abs();
1371 if val > col_max {
1372 col_max = val;
1373 best_col = j;
1374 }
1375 }
1376
1377 let improvement_threshold = T::from(1.5).expect("Operation failed");
1379 if col_max > max_val * improvement_threshold {
1380 max_val = T::sparse_zero();
1382 for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1383 let i = k + idx;
1384 let val = matrix.get(actual_row, col_perm[best_col]).abs();
1385 if val > max_val {
1386 max_val = val;
1387 best_row = i;
1388 }
1389 }
1390 }
1391 }
1392
1393 Ok((best_row, best_col))
1394 }
1395 }
1396}
1397
/// COO buffers for the two LU factors, in the order
/// `(l_rows, l_cols, l_vals, u_rows, u_cols, u_vals)`.
type LuFactors<T> = (
    Vec<usize>, // L row indices
    Vec<usize>, // L column indices
    Vec<T>,     // L values
    Vec<usize>, // U row indices
    Vec<usize>, // U column indices
    Vec<T>,     // U values
);
1407
1408#[allow(dead_code)]
1409fn extract_lu_factors<T>(matrix: &SparseWorkingMatrix<T>, p: &[usize], n: usize) -> LuFactors<T>
1410where
1411 T: Float + SparseElement + Debug + Copy,
1412{
1413 let mut l_rows = Vec::new();
1414 let mut l_cols = Vec::new();
1415 let mut l_vals = Vec::new();
1416 let mut u_rows = Vec::new();
1417 let mut u_cols = Vec::new();
1418 let mut u_vals = Vec::new();
1419
1420 #[allow(clippy::needless_range_loop)]
1421 for i in 0..n {
1422 let actual_row = p[i];
1423
1424 l_rows.push(i);
1426 l_cols.push(i);
1427 l_vals.push(T::sparse_one());
1428
1429 for j in 0..n {
1430 let val = matrix.get(actual_row, j);
1431 if !SparseElement::is_zero(&val) {
1432 if j < i {
1433 l_rows.push(i);
1435 l_cols.push(j);
1436 l_vals.push(val);
1437 } else {
1438 u_rows.push(i);
1440 u_cols.push(j);
1441 u_vals.push(val);
1442 }
1443 }
1444 }
1445 }
1446
1447 (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals)
1448}
1449
1450#[allow(dead_code)]
1452fn extract_lower_triangular<T>(
1453 matrix: &SparseWorkingMatrix<T>,
1454 n: usize,
1455) -> (Vec<usize>, Vec<usize>, Vec<T>)
1456where
1457 T: Float + SparseElement + Debug + Copy,
1458{
1459 let mut rows = Vec::new();
1460 let mut cols = Vec::new();
1461 let mut vals = Vec::new();
1462
1463 for i in 0..n {
1464 for j in 0..=i {
1465 let val = matrix.get(i, j);
1466 if !SparseElement::is_zero(&val) {
1467 rows.push(i);
1468 cols.push(j);
1469 vals.push(val);
1470 }
1471 }
1472 }
1473
1474 (rows, cols, vals)
1475}
1476
1477#[allow(dead_code)]
1479fn dense_to_sparse<T>(matrix: &Array2<T>) -> SparseResult<CsrArray<T>>
1480where
1481 T: Float + SparseElement + Debug + Copy,
1482{
1483 let (m, n) = matrix.dim();
1484 let mut rows = Vec::new();
1485 let mut cols = Vec::new();
1486 let mut vals = Vec::new();
1487
1488 for i in 0..m {
1489 for j in 0..n {
1490 let val = matrix[[i, j]];
1491 if !SparseElement::is_zero(&val) {
1492 rows.push(i);
1493 cols.push(j);
1494 vals.push(val);
1495 }
1496 }
1497 }
1498
1499 CsrArray::from_triplets(&rows, &cols, &vals, (m, n), false)
1500}
1501
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Nonsymmetric 3×3 matrix `[[2, 1, 0], [1, 3, 0], [0, 2, 4]]`.
    fn create_test_matrix() -> CsrArray<f64> {
        let (rows, cols) = ([0, 0, 1, 1, 2, 2], [0, 1, 0, 1, 1, 2]);
        let data = [2.0, 1.0, 1.0, 3.0, 2.0, 4.0];
        CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed")
    }

    /// Lower triangle of the SPD matrix `[[4, 2, 1], [2, 5, 3], [1, 3, 6]]`.
    fn create_spd_matrix() -> CsrArray<f64> {
        let (rows, cols) = ([0, 1, 1, 2, 2, 2], [0, 0, 1, 0, 1, 2]);
        let data = [4.0, 2.0, 5.0, 1.0, 3.0, 6.0];
        CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed")
    }

    #[test]
    fn test_lu_decomposition() {
        let lu = lu_decomposition(&create_test_matrix(), 0.1).expect("Operation failed");

        assert!(lu.success);
        assert_eq!(lu.l.shape(), (3, 3));
        assert_eq!(lu.u.shape(), (3, 3));
        assert_eq!(lu.p.len(), 3);
    }

    #[test]
    fn test_qr_decomposition() {
        let (rows, cols, data) = ([0, 1, 2], [0, 0, 1], [1.0, 2.0, 3.0]);
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 2), false).expect("Operation failed");

        let qr = qr_decomposition(&matrix).expect("Operation failed");

        assert!(qr.success);
        assert_eq!(qr.q.shape(), (3, 2));
        assert_eq!(qr.r.shape(), (2, 2));
    }

    #[test]
    fn test_cholesky_decomposition() {
        let chol = cholesky_decomposition(&create_spd_matrix()).expect("Operation failed");

        assert!(chol.success);
        assert_eq!(chol.l.shape(), (3, 3));
    }

    #[test]
    fn test_incomplete_lu() {
        let options = ILUOptions {
            drop_tol: 1e-6,
            ..Default::default()
        };
        let ilu = incomplete_lu(&create_test_matrix(), Some(options)).expect("Operation failed");

        assert!(ilu.success);
        assert_eq!(ilu.l.shape(), (3, 3));
        assert_eq!(ilu.u.shape(), (3, 3));
    }

    #[test]
    fn test_incomplete_cholesky() {
        let options = ICOptions {
            drop_tol: 1e-6,
            ..Default::default()
        };
        let ic =
            incomplete_cholesky(&create_spd_matrix(), Some(options)).expect("Operation failed");

        assert!(ic.success);
        assert_eq!(ic.l.shape(), (3, 3));
    }

    #[test]
    fn test_sparse_working_matrix() {
        let (rows, cols, vals) = ([0, 1, 2], [0, 1, 2], [1.0, 2.0, 3.0]);
        let mut matrix = SparseWorkingMatrix::from_triplets(&rows, &cols, &vals, 3);

        for (i, expected) in [(0, 1.0), (1, 2.0), (2, 3.0)] {
            assert_eq!(matrix.get(i, i), expected);
        }
        assert_eq!(matrix.get(0, 1), 0.0);

        matrix.set(0, 1, 5.0);
        assert_eq!(matrix.get(0, 1), 5.0);

        // Writing zero must drop the entry from the pattern entirely.
        matrix.set(0, 1, 0.0);
        assert_eq!(matrix.get(0, 1), 0.0);
        assert!(!matrix.has_entry(0, 1));
    }

    #[test]
    fn test_dense_to_sparse_conversion() {
        let dense =
            Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 2.0, 3.0]).expect("Operation failed");
        let sparse = dense_to_sparse(&dense).expect("Operation failed");

        assert_eq!(sparse.nnz(), 3);
        let expected = [((0, 0), 1.0), ((0, 1), 0.0), ((1, 0), 2.0), ((1, 1), 3.0)];
        for ((i, j), value) in expected {
            assert_eq!(sparse.get(i, j), value);
        }
    }
}