use crate::csr_array::CsrArray;
use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;
use std::ops::{Add, Div, Mul, Sub};

/// Result of a sparse LU decomposition.
#[derive(Debug, Clone)]
pub struct LUResult<T>
where
    T: Float + Debug + Copy + 'static,
{
    /// Lower triangular factor L (unit diagonal).
    pub l: CsrArray<T>,
    /// Upper triangular factor U.
    pub u: CsrArray<T>,
    /// Row permutation applied during factorization.
    pub p: Array1<usize>,
    /// Whether the factorization completed successfully.
    pub success: bool,
}

/// Result of a sparse QR decomposition.
#[derive(Debug, Clone)]
pub struct QRResult<T>
where
    T: Float + Debug + Copy + 'static,
{
    /// Orthonormal factor Q.
    pub q: CsrArray<T>,
    /// Upper triangular factor R.
    pub r: CsrArray<T>,
    /// Whether the factorization completed successfully.
    pub success: bool,
}

/// Result of a sparse Cholesky decomposition.
#[derive(Debug, Clone)]
pub struct CholeskyResult<T>
where
    T: Float + Debug + Copy + 'static,
{
    /// Lower triangular factor L with A = L * L^T.
    pub l: CsrArray<T>,
    /// Whether the factorization completed successfully.
    pub success: bool,
}

/// Result of a pivoted (rank-revealing) Cholesky decomposition.
#[derive(Debug, Clone)]
pub struct PivotedCholeskyResult<T>
where
    T: Float + Debug + Copy + 'static,
{
    /// Lower trapezoidal factor L (n x rank).
    pub l: CsrArray<T>,
    /// Symmetric permutation applied during factorization.
    pub p: Array1<usize>,
    /// Numerical rank detected during factorization.
    pub rank: usize,
    /// Whether the factorization completed successfully.
    pub success: bool,
}

/// Pivoting strategy used during LU factorization.
#[derive(Debug, Clone, Default)]
pub enum PivotingStrategy {
    /// No pivoting; use the natural diagonal entry.
    None,
    /// Partial pivoting: largest entry in the current column.
    #[default]
    Partial,
    /// Threshold pivoting: accept the first entry whose magnitude reaches the given value.
    Threshold(f64),
    /// Partial pivoting scaled by the largest entry in each row.
    ScaledPartial,
    /// Complete pivoting: largest entry in the remaining submatrix.
    Complete,
    /// Rook pivoting: entry that is maximal in both its row and its column.
    Rook,
}

/// Options controlling the LU factorization.
#[derive(Debug, Clone)]
pub struct LUOptions {
    /// Pivoting strategy.
    pub pivoting: PivotingStrategy,
    /// Magnitudes below this value are treated as zero.
    pub zero_threshold: f64,
    /// Whether to return `success = false` when a singular pivot is found.
    pub check_singular: bool,
}

impl Default for LUOptions {
    fn default() -> Self {
        Self {
            pivoting: PivotingStrategy::default(),
            zero_threshold: 1e-14,
            check_singular: true,
        }
    }
}

/// Options controlling the incomplete LU (ILU) factorization.
#[derive(Debug, Clone)]
pub struct ILUOptions {
    /// Entries with magnitude below this tolerance are dropped.
    pub drop_tol: f64,
    /// Allowed fill growth relative to the input matrix.
    pub fill_factor: f64,
    /// Maximum number of fill-in entries kept per row.
    pub max_fill_per_row: usize,
    /// Pivoting strategy.
    pub pivoting: PivotingStrategy,
}

impl Default for ILUOptions {
    fn default() -> Self {
        Self {
            drop_tol: 1e-4,
            fill_factor: 2.0,
            max_fill_per_row: 20,
            pivoting: PivotingStrategy::default(),
        }
    }
}

/// Options controlling the incomplete Cholesky (IC) factorization.
#[derive(Debug, Clone)]
pub struct ICOptions {
    /// Entries with magnitude below this tolerance are dropped.
    pub drop_tol: f64,
    /// Allowed fill growth relative to the input matrix.
    pub fill_factor: f64,
    /// Maximum number of fill-in entries kept per row.
    pub max_fill_per_row: usize,
}

impl Default for ICOptions {
    fn default() -> Self {
        Self {
            drop_tol: 1e-4,
            fill_factor: 2.0,
            max_fill_per_row: 20,
        }
    }
}

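/// Computes a sparse LU factorization with threshold row pivoting; the row
/// permutation used is returned in the result.
///
/// This is a convenience wrapper around [`lu_decomposition_with_options`] that
/// selects [`PivotingStrategy::Threshold`] with the given `pivot_threshold`.
///
/// A minimal usage sketch, mirroring this module's tests (import paths are
/// assumptions and may differ in your crate layout):
///
/// ```ignore
/// // 3x3 matrix [[2, 1, 0], [1, 3, 0], [0, 2, 4]] in triplet form.
/// let rows = vec![0, 0, 1, 1, 2, 2];
/// let cols = vec![0, 1, 0, 1, 1, 2];
/// let data = vec![2.0, 1.0, 1.0, 3.0, 2.0, 4.0];
/// let a = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false)?;
///
/// let lu = lu_decomposition(&a, 0.1)?;
/// assert!(lu.success);
/// assert_eq!(lu.l.shape(), (3, 3));
/// assert_eq!(lu.u.shape(), (3, 3));
/// ```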
#[allow(dead_code)]
pub fn lu_decomposition<T, S>(matrix: &S, pivot_threshold: f64) -> SparseResult<LUResult<T>>
where
    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
    S: SparseArray<T>,
{
    let options = LUOptions {
        pivoting: PivotingStrategy::Threshold(pivot_threshold),
        zero_threshold: 1e-14,
        check_singular: true,
    };

    lu_decomposition_with_options(matrix, Some(options))
}

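/// Computes a sparse LU factorization with a configurable [`PivotingStrategy`].
///
/// The elimination runs on a hash-map based working copy of the matrix, so
/// fill-in is allowed; the row permutation used is returned in the result.
///
/// A minimal usage sketch (hypothetical matrix `a`; mirrors the wrapper above):
///
/// ```ignore
/// let opts = LUOptions {
///     pivoting: PivotingStrategy::Partial,
///     ..LUOptions::default()
/// };
/// let lu = lu_decomposition_with_options(&a, Some(opts))?;
/// if lu.success {
///     // lu.l, lu.u and lu.p hold the factors and the row permutation.
/// }
/// ```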
#[allow(dead_code)]
pub fn lu_decomposition_with_options<T, S>(
    matrix: &S,
    options: Option<LUOptions>,
) -> SparseResult<LUResult<T>>
where
    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
    S: SparseArray<T>,
{
    let opts = options.unwrap_or_default();
    let (n, m) = matrix.shape();
    if n != m {
        return Err(SparseError::ValueError(
            "Matrix must be square for LU decomposition".to_string(),
        ));
    }

    // Copy the input into a hash-map based working matrix that supports fill-in.
    let (row_indices, col_indices, values) = matrix.find();
    let mut working_matrix = SparseWorkingMatrix::from_triplets(
        row_indices.as_slice().unwrap(),
        col_indices.as_slice().unwrap(),
        values.as_slice().unwrap(),
        n,
    );

    // Row and column permutations (columns are only permuted for complete/rook pivoting).
    let mut row_perm: Vec<usize> = (0..n).collect();
    let mut col_perm: Vec<usize> = (0..n).collect();

    // Row scales used by scaled partial pivoting: the largest magnitude in each row.
    let mut row_scales = vec![T::one(); n];
    if matches!(opts.pivoting, PivotingStrategy::ScaledPartial) {
        for (i, scale) in row_scales.iter_mut().enumerate().take(n) {
            let row_data = working_matrix.get_row(i);
            let max_val =
                row_data
                    .values()
                    .map(|&v| v.abs())
                    .fold(T::zero(), |a, b| if a > b { a } else { b });
            if max_val > T::zero() {
                *scale = max_val;
            }
        }
    }

    // Gaussian elimination with the selected pivoting strategy.
    for k in 0..n.saturating_sub(1) {
        let (pivot_row, pivot_col) =
            find_enhanced_pivot(&working_matrix, k, &row_perm, &col_perm, &row_scales, &opts)?;

        if pivot_row != k {
            row_perm.swap(k, pivot_row);
        }
        if pivot_col != k
            && matches!(
                opts.pivoting,
                PivotingStrategy::Complete | PivotingStrategy::Rook
            )
        {
            col_perm.swap(k, pivot_col);
            // Physically swap the two columns in the working matrix.
            for &row_idx in row_perm.iter().take(n) {
                let temp = working_matrix.get(row_idx, k);
                working_matrix.set(row_idx, k, working_matrix.get(row_idx, pivot_col));
                working_matrix.set(row_idx, pivot_col, temp);
            }
        }

        let actual_pivot_row = row_perm[k];
        let actual_pivot_col = col_perm[k];
        let pivot_value = working_matrix.get(actual_pivot_row, actual_pivot_col);

        // Bail out with success = false if the pivot is numerically zero.
        if opts.check_singular && pivot_value.abs() < T::from(opts.zero_threshold).unwrap() {
            return Ok(LUResult {
                l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
                u: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
                p: Array1::from_vec(row_perm),
                success: false,
            });
        }

        // Eliminate the entries below the pivot, storing the multipliers in place.
        for &actual_row_i in row_perm.iter().take(n).skip(k + 1) {
            let factor = working_matrix.get(actual_row_i, actual_pivot_col) / pivot_value;

            if !factor.is_zero() {
                working_matrix.set(actual_row_i, actual_pivot_col, factor);

                let pivot_row_data = working_matrix.get_row(actual_pivot_row);
                for (col, &value) in &pivot_row_data {
                    if *col > k {
                        let old_val = working_matrix.get(actual_row_i, *col);
                        working_matrix.set(actual_row_i, *col, old_val - factor * value);
                    }
                }
            }
        }
    }

    // Split the in-place factorization into explicit L and U triangles.
    let (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals) =
        extract_lu_factors(&working_matrix, &row_perm, n);

    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
    let u = CsrArray::from_triplets(&u_rows, &u_cols, &u_vals, (n, n), false)?;

    Ok(LUResult {
        l,
        u,
        p: Array1::from_vec(row_perm),
        success: true,
    })
}

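/// Computes a QR decomposition `A = Q * R` via classical Gram-Schmidt.
///
/// The input is densified internally, so this is intended for small or
/// moderately sized matrices; `Q` is `m x n` and `R` is `n x n`.
///
/// A minimal usage sketch, mirroring this module's tests:
///
/// ```ignore
/// let rows = vec![0, 1, 2];
/// let cols = vec![0, 0, 1];
/// let data = vec![1.0, 2.0, 3.0];
/// let a = CsrArray::from_triplets(&rows, &cols, &data, (3, 2), false)?;
///
/// let qr = qr_decomposition(&a)?;
/// assert!(qr.success);
/// assert_eq!(qr.q.shape(), (3, 2));
/// assert_eq!(qr.r.shape(), (2, 2));
/// ```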
#[allow(dead_code)]
pub fn qr_decomposition<T, S>(matrix: &S) -> SparseResult<QRResult<T>>
where
    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
    S: SparseArray<T>,
{
    let (m, n) = matrix.shape();

    // Classical Gram-Schmidt on a dense copy of the matrix.
    let dense_matrix = matrix.to_array();

    let mut q = Array2::zeros((m, n));
    let mut r = Array2::zeros((n, n));

    for j in 0..n {
        // Start from the j-th column of A.
        for i in 0..m {
            q[[i, j]] = dense_matrix[[i, j]];
        }

        // Subtract the projections onto the previously computed columns of Q.
        for k in 0..j {
            let mut dot = T::zero();
            for i in 0..m {
                dot = dot + q[[i, k]] * dense_matrix[[i, j]];
            }
            r[[k, j]] = dot;

            for i in 0..m {
                q[[i, j]] = q[[i, j]] - dot * q[[i, k]];
            }
        }

        // Normalize to obtain the j-th column of Q.
        let mut norm = T::zero();
        for i in 0..m {
            norm = norm + q[[i, j]] * q[[i, j]];
        }
        norm = norm.sqrt();
        r[[j, j]] = norm;

        if !norm.is_zero() {
            for i in 0..m {
                q[[i, j]] = q[[i, j]] / norm;
            }
        }
    }

    // Convert the dense factors back to sparse CSR form.
    let q_sparse = dense_to_sparse(&q)?;
    let r_sparse = dense_to_sparse(&r)?;

    Ok(QRResult {
        q: q_sparse,
        r: r_sparse,
        success: true,
    })
}

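/// Computes a sparse Cholesky decomposition `A = L * L^T` for a symmetric
/// positive definite matrix.
///
/// Only the lower triangle of the input is read. If a non-positive pivot is
/// encountered, the result is returned with `success = false`.
///
/// A minimal usage sketch, mirroring this module's tests:
///
/// ```ignore
/// // Lower triangle of [[4, 2, 1], [2, 5, 3], [1, 3, 6]].
/// let rows = vec![0, 1, 1, 2, 2, 2];
/// let cols = vec![0, 0, 1, 0, 1, 2];
/// let data = vec![4.0, 2.0, 5.0, 1.0, 3.0, 6.0];
/// let a = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false)?;
///
/// let chol = cholesky_decomposition(&a)?;
/// assert!(chol.success);
/// ```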
#[allow(dead_code)]
pub fn cholesky_decomposition<T, S>(matrix: &S) -> SparseResult<CholeskyResult<T>>
where
    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
    S: SparseArray<T>,
{
    let (n, m) = matrix.shape();
    if n != m {
        return Err(SparseError::ValueError(
            "Matrix must be square for Cholesky decomposition".to_string(),
        ));
    }

    let (row_indices, col_indices, values) = matrix.find();
    let mut working_matrix = SparseWorkingMatrix::from_triplets(
        row_indices.as_slice().unwrap(),
        col_indices.as_slice().unwrap(),
        values.as_slice().unwrap(),
        n,
    );

    for k in 0..n {
        // Diagonal entry: l_kk = sqrt(a_kk - sum_j l_kj^2).
        let mut sum = T::zero();
        for j in 0..k {
            let l_kj = working_matrix.get(k, j);
            sum = sum + l_kj * l_kj;
        }

        let a_kk = working_matrix.get(k, k);
        let diag_val = a_kk - sum;

        // A non-positive value means the matrix is not positive definite.
        if diag_val <= T::zero() {
            return Ok(CholeskyResult {
                l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
                success: false,
            });
        }

        let l_kk = diag_val.sqrt();
        working_matrix.set(k, k, l_kk);

        // Entries below the diagonal: l_ik = (a_ik - sum_j l_ij * l_kj) / l_kk.
        for i in (k + 1)..n {
            let mut sum = T::zero();
            for j in 0..k {
                sum = sum + working_matrix.get(i, j) * working_matrix.get(k, j);
            }

            let a_ik = working_matrix.get(i, k);
            let l_ik = (a_ik - sum) / l_kk;
            working_matrix.set(i, k, l_ik);
        }
    }

    let (l_rows, l_cols, l_vals) = extract_lower_triangular(&working_matrix, n);
    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;

    Ok(CholeskyResult { l, success: true })
}

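/// Computes a pivoted (rank-revealing) Cholesky decomposition of a symmetric
/// positive semi-definite matrix.
///
/// Pivots are chosen greedily by the largest remaining diagonal entry of the
/// Schur complement; the factorization stops once that entry is at or below
/// `threshold`, and the number of completed steps is reported as `rank`. The
/// returned factor is `n x rank`, and the symmetric permutation is returned in `p`.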
#[allow(dead_code)]
pub fn pivoted_cholesky_decomposition<T, S>(
    matrix: &S,
    threshold: Option<T>,
) -> SparseResult<PivotedCholeskyResult<T>>
where
    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
    S: SparseArray<T>,
{
    let (n, m) = matrix.shape();
    if n != m {
        return Err(SparseError::ValueError(
            "Matrix must be square for Cholesky decomposition".to_string(),
        ));
    }

    let threshold = threshold.unwrap_or_else(|| T::from(1e-12).unwrap());

    let (row_indices, col_indices, values) = matrix.find();
    let mut working_matrix = SparseWorkingMatrix::from_triplets(
        row_indices.as_slice().unwrap(),
        col_indices.as_slice().unwrap(),
        values.as_slice().unwrap(),
        n,
    );

    // Symmetric permutation chosen greedily by the largest remaining diagonal entry.
    let mut perm: Vec<usize> = (0..n).collect();
    let mut rank = 0;

    for k in 0..n {
        // Find the largest updated diagonal entry among the not-yet-factored columns.
        let mut max_diag = T::zero();
        let mut pivot_idx = k;

        for i in k..n {
            let mut diag_val = working_matrix.get(perm[i], perm[i]);
            for j in 0..k {
                let l_ij = working_matrix.get(perm[i], perm[j]);
                diag_val = diag_val - l_ij * l_ij;
            }
            if diag_val > max_diag {
                max_diag = diag_val;
                pivot_idx = i;
            }
        }

        // Stop once the remaining diagonal is numerically negligible: the rank is k.
        if max_diag <= threshold {
            break;
        }

        if pivot_idx != k {
            perm.swap(k, pivot_idx);
        }

        let l_kk = max_diag.sqrt();
        working_matrix.set(perm[k], perm[k], l_kk);
        rank += 1;

        for i in (k + 1)..n {
            let mut sum = T::zero();
            for j in 0..k {
                sum = sum
                    + working_matrix.get(perm[i], perm[j]) * working_matrix.get(perm[k], perm[j]);
            }

            let a_ik = working_matrix.get(perm[i], perm[k]);
            let l_ik = (a_ik - sum) / l_kk;
            working_matrix.set(perm[i], perm[k], l_ik);
        }
    }

    // Assemble the n x rank trapezoidal factor in the permuted ordering.
    let mut l_rows = Vec::new();
    let mut l_cols = Vec::new();
    let mut l_vals = Vec::new();

    for i in 0..n {
        // Row i has entries in columns 0..=min(i, rank - 1).
        let j_end = if i < rank { i + 1 } else { rank };
        for j in 0..j_end {
            let val = working_matrix.get(perm[i], perm[j]);
            if val != T::zero() {
                l_rows.push(i);
                l_cols.push(j);
                l_vals.push(val);
            }
        }
    }

    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, rank), false)?;
    let p = Array1::from_vec(perm);

    Ok(PivotedCholeskyResult {
        l,
        p,
        rank,
        success: true,
    })
}

/// Result of a sparse LDL^T decomposition.
#[derive(Debug, Clone)]
pub struct LDLTResult<T>
where
    T: Float + Debug + Copy + 'static,
{
    /// Unit lower triangular factor L.
    pub l: CsrArray<T>,
    /// Diagonal factor D, stored as a vector.
    pub d: Array1<T>,
    /// Permutation applied during factorization.
    pub p: Array1<usize>,
    /// Whether the factorization completed successfully.
    pub success: bool,
}

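/// Computes an LDL^T decomposition with optional diagonal pivoting; the
/// symmetric permutation applied is returned in `p`.
///
/// `L` has a unit diagonal and `D` is returned as a vector. Unlike Cholesky,
/// this can handle symmetric indefinite matrices as long as no pivot magnitude
/// falls below `threshold`.
///
/// A minimal usage sketch (hypothetical matrix `a`, assumed symmetric):
///
/// ```ignore
/// let ldlt = ldlt_decomposition(&a, Some(true), None)?;
/// if ldlt.success {
///     // ldlt.l (unit lower triangular), ldlt.d (diagonal), ldlt.p (permutation)
/// }
/// ```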
#[allow(dead_code)]
pub fn ldlt_decomposition<T, S>(
    matrix: &S,
    pivoting: Option<bool>,
    threshold: Option<T>,
) -> SparseResult<LDLTResult<T>>
where
    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
    S: SparseArray<T>,
{
    let (n, m) = matrix.shape();
    if n != m {
        return Err(SparseError::ValueError(
            "Matrix must be square for LDLT decomposition".to_string(),
        ));
    }

    let use_pivoting = pivoting.unwrap_or(true);
    let threshold = threshold.unwrap_or_else(|| T::from(1e-12).unwrap());

    let (row_indices, col_indices, values) = matrix.find();
    let mut working_matrix = SparseWorkingMatrix::from_triplets(
        row_indices.as_slice().unwrap(),
        col_indices.as_slice().unwrap(),
        values.as_slice().unwrap(),
        n,
    );

    let mut perm: Vec<usize> = (0..n).collect();
    let mut d_values = vec![T::zero(); n];

    for k in 0..n {
        // Optionally pick the largest remaining diagonal entry as the pivot.
        if use_pivoting {
            let pivot_idx = find_ldlt_pivot(&working_matrix, k, &perm, threshold);
            if pivot_idx != k {
                perm.swap(k, pivot_idx);
            }
        }

        let actual_k = perm[k];

        // d_k = a_kk - sum_j l_kj^2 * d_j.
        let mut diag_val = working_matrix.get(actual_k, actual_k);
        for j in 0..k {
            let l_kj = working_matrix.get(actual_k, perm[j]);
            diag_val = diag_val - l_kj * l_kj * d_values[j];
        }

        d_values[k] = diag_val;

        // A (near-)zero pivot means the factorization cannot proceed.
        if diag_val.abs() < threshold {
            return Ok(LDLTResult {
                l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
                d: Array1::from_vec(d_values),
                p: Array1::from_vec(perm),
                success: false,
            });
        }

        // l_ik = (a_ik - sum_j l_ij * l_kj * d_j) / d_k.
        for i in (k + 1)..n {
            let actual_i = perm[i];
            let mut l_ik = working_matrix.get(actual_i, actual_k);

            for j in 0..k {
                l_ik = l_ik
                    - working_matrix.get(actual_i, perm[j])
                        * working_matrix.get(actual_k, perm[j])
                        * d_values[j];
            }

            l_ik = l_ik / diag_val;
            working_matrix.set(actual_i, actual_k, l_ik);
        }

        // The diagonal of L is implicitly one.
        working_matrix.set(actual_k, actual_k, T::one());
    }

    let (l_rows, l_cols, l_vals) = extract_unit_lower_triangular(&working_matrix, &perm, n);
    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;

    Ok(LDLTResult {
        l,
        d: Array1::from_vec(d_values),
        p: Array1::from_vec(perm),
        success: true,
    })
}

#[allow(dead_code)]
fn find_ldlt_pivot<T>(
    matrix: &SparseWorkingMatrix<T>,
    k: usize,
    perm: &[usize],
    threshold: T,
) -> usize
where
    T: Float + Debug + Copy,
{
    let n = matrix.n;
    let mut max_val = T::zero();
    let mut pivot_idx = k;

    // Scan the remaining (permuted) diagonal for the entry of largest magnitude.
    for (i, &actual_i) in perm.iter().enumerate().take(n).skip(k) {
        let diag_val = matrix.get(actual_i, actual_i).abs();

        if diag_val > max_val {
            max_val = diag_val;
            pivot_idx = i;
        }
    }

    // Only pivot if the best candidate clears the threshold; otherwise keep the current row.
    if max_val >= threshold {
        pivot_idx
    } else {
        k
    }
}

#[allow(dead_code)]
fn extract_unit_lower_triangular<T>(
    matrix: &SparseWorkingMatrix<T>,
    perm: &[usize],
    n: usize,
) -> (Vec<usize>, Vec<usize>, Vec<T>)
where
    T: Float + Debug + Copy,
{
    let mut rows = Vec::new();
    let mut cols = Vec::new();
    let mut vals = Vec::new();

    for i in 0..n {
        let actual_i = perm[i];

        // Unit diagonal.
        rows.push(i);
        cols.push(i);
        vals.push(T::one());

        // Strictly lower triangular entries, mapped back through the permutation.
        for (j, &perm_j) in perm.iter().enumerate().take(i) {
            let val = matrix.get(actual_i, perm_j);
            if val != T::zero() {
                rows.push(i);
                cols.push(j);
                vals.push(val);
            }
        }
    }

    (rows, cols, vals)
}

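/// Computes an incomplete LU (ILU) factorization for use as a preconditioner.
///
/// The elimination only updates positions that already hold entries and drops
/// values below `drop_tol`, so the factors stay sparse at the cost of being
/// approximate: `L * U ≈ A`.
///
/// A minimal usage sketch, mirroring this module's tests (hypothetical matrix `a`):
///
/// ```ignore
/// let opts = ILUOptions { drop_tol: 1e-6, ..Default::default() };
/// let ilu = incomplete_lu(&a, Some(opts))?;
/// assert!(ilu.success);
/// ```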
#[allow(dead_code)]
pub fn incomplete_lu<T, S>(matrix: &S, options: Option<ILUOptions>) -> SparseResult<LUResult<T>>
where
    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
    S: SparseArray<T>,
{
    let opts = options.unwrap_or_default();
    let (n, m) = matrix.shape();

    if n != m {
        return Err(SparseError::ValueError(
            "Matrix must be square for ILU decomposition".to_string(),
        ));
    }

    let (row_indices, col_indices, values) = matrix.find();
    let mut working_matrix = SparseWorkingMatrix::from_triplets(
        row_indices.as_slice().unwrap(),
        col_indices.as_slice().unwrap(),
        values.as_slice().unwrap(),
        n,
    );

    // ILU(0)-style elimination: update only positions that already hold an entry,
    // and drop multipliers and updates whose magnitude falls below `drop_tol`.
    for k in 0..n.saturating_sub(1) {
        let pivot_val = working_matrix.get(k, k);

        // Skip columns with a (near-)zero pivot instead of failing.
        if pivot_val.abs() < T::from(1e-14).unwrap() {
            continue;
        }

        let col_k_entries = working_matrix.get_column_below_diagonal(k);

        for &row_i in &col_k_entries {
            let factor = working_matrix.get(row_i, k) / pivot_val;

            // Drop small multipliers.
            if factor.abs() < T::from(opts.drop_tol).unwrap() {
                working_matrix.set(row_i, k, T::zero());
                continue;
            }

            working_matrix.set(row_i, k, factor);

            // Update only existing entries in row i to preserve the sparsity pattern.
            let row_k_entries = working_matrix.get_row_after_column(k, k);
            for (col_j, &val_kj) in &row_k_entries {
                if working_matrix.has_entry(row_i, *col_j) {
                    let old_val = working_matrix.get(row_i, *col_j);
                    let new_val = old_val - factor * val_kj;

                    // Drop small updated entries.
                    if new_val.abs() < T::from(opts.drop_tol).unwrap() {
                        working_matrix.set(row_i, *col_j, T::zero());
                    } else {
                        working_matrix.set(row_i, *col_j, new_val);
                    }
                }
            }
        }
    }

    // No pivoting is applied, so the permutation is the identity.
    let identity_p: Vec<usize> = (0..n).collect();
    let (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals) =
        extract_lu_factors(&working_matrix, &identity_p, n);

    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
    let u = CsrArray::from_triplets(&u_rows, &u_cols, &u_vals, (n, n), false)?;

    Ok(LUResult {
        l,
        u,
        p: Array1::from_vec(identity_p),
        success: true,
    })
}

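/// Computes an incomplete Cholesky (IC) factorization for use as a preconditioner.
///
/// Works like [`cholesky_decomposition`] but drops entries below `drop_tol`,
/// producing a sparser, approximate factor `L` with `L * L^T ≈ A`.
///
/// A minimal usage sketch, mirroring this module's tests (hypothetical SPD matrix `a`):
///
/// ```ignore
/// let opts = ICOptions { drop_tol: 1e-6, ..Default::default() };
/// let ic = incomplete_cholesky(&a, Some(opts))?;
/// assert!(ic.success);
/// ```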
#[allow(dead_code)]
pub fn incomplete_cholesky<T, S>(
    matrix: &S,
    options: Option<ICOptions>,
) -> SparseResult<CholeskyResult<T>>
where
    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
    S: SparseArray<T>,
{
    let opts = options.unwrap_or_default();
    let (n, m) = matrix.shape();

    if n != m {
        return Err(SparseError::ValueError(
            "Matrix must be square for IC decomposition".to_string(),
        ));
    }

    let (row_indices, col_indices, values) = matrix.find();
    let mut working_matrix = SparseWorkingMatrix::from_triplets(
        row_indices.as_slice().unwrap(),
        col_indices.as_slice().unwrap(),
        values.as_slice().unwrap(),
        n,
    );

    for k in 0..n {
        // Diagonal entry: l_kk = sqrt(a_kk - sum_j l_kj^2) over the retained entries.
        let mut sum = T::zero();
        let row_k_before_k = working_matrix.get_row_before_column(k, k);
        for &val_kj in row_k_before_k.values() {
            sum = sum + val_kj * val_kj;
        }

        let a_kk = working_matrix.get(k, k);
        let diag_val = a_kk - sum;

        if diag_val <= T::zero() {
            return Ok(CholeskyResult {
                l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
                success: false,
            });
        }

        let l_kk = diag_val.sqrt();
        working_matrix.set(k, k, l_kk);

        // Update existing entries below the diagonal, dropping those under `drop_tol`.
        let col_k_below = working_matrix.get_column_below_diagonal(k);
        for &row_i in &col_k_below {
            let mut sum = T::zero();
            let row_i_before_k = working_matrix.get_row_before_column(row_i, k);
            let row_k_before_k = working_matrix.get_row_before_column(k, k);

            // Accumulate only over columns where both rows have retained entries.
            for (col_j, &val_ij) in &row_i_before_k {
                if let Some(&val_kj) = row_k_before_k.get(col_j) {
                    sum = sum + val_ij * val_kj;
                }
            }

            let a_ik = working_matrix.get(row_i, k);
            let l_ik = (a_ik - sum) / l_kk;

            if l_ik.abs() < T::from(opts.drop_tol).unwrap() {
                working_matrix.set(row_i, k, T::zero());
            } else {
                working_matrix.set(row_i, k, l_ik);
            }
        }
    }

    let (l_rows, l_cols, l_vals) = extract_lower_triangular(&working_matrix, n);
    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;

    Ok(CholeskyResult { l, success: true })
}

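/// Mutable sparse working storage used by the factorization routines.
///
/// Entries are kept in a `HashMap` keyed by `(row, col)`, which makes random
/// reads, writes, and fill-in cheap at the cost of unordered iteration.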
struct SparseWorkingMatrix<T>
where
    T: Float + Debug + Copy,
{
    data: HashMap<(usize, usize), T>,
    n: usize,
}

impl<T> SparseWorkingMatrix<T>
where
    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
{
    /// Builds the working matrix from COO triplets; later duplicates overwrite earlier ones.
    fn from_triplets(rows: &[usize], cols: &[usize], values: &[T], n: usize) -> Self {
        let mut data = HashMap::new();

        for (i, (&row, &col)) in rows.iter().zip(cols.iter()).enumerate() {
            data.insert((row, col), values[i]);
        }

        Self { data, n }
    }

    /// Returns the entry at (row, col), or zero if it is not stored.
    fn get(&self, row: usize, col: usize) -> T {
        self.data.get(&(row, col)).copied().unwrap_or(T::zero())
    }

    /// Stores a value, removing the entry entirely when the value is zero.
    fn set(&mut self, row: usize, col: usize, value: T) {
        if value.is_zero() {
            self.data.remove(&(row, col));
        } else {
            self.data.insert((row, col), value);
        }
    }

    fn has_entry(&self, row: usize, col: usize) -> bool {
        self.data.contains_key(&(row, col))
    }

    /// Returns all stored entries of a row as a column -> value map.
    fn get_row(&self, row: usize) -> HashMap<usize, T> {
        let mut result = HashMap::new();
        for (&(r, c), &value) in &self.data {
            if r == row {
                result.insert(c, value);
            }
        }
        result
    }

    /// Returns the stored entries of a row strictly to the right of `col`.
    fn get_row_after_column(&self, row: usize, col: usize) -> HashMap<usize, T> {
        let mut result = HashMap::new();
        for (&(r, c), &value) in &self.data {
            if r == row && c > col {
                result.insert(c, value);
            }
        }
        result
    }

    /// Returns the stored entries of a row strictly to the left of `col`.
    fn get_row_before_column(&self, row: usize, col: usize) -> HashMap<usize, T> {
        let mut result = HashMap::new();
        for (&(r, c), &value) in &self.data {
            if r == row && c < col {
                result.insert(c, value);
            }
        }
        result
    }

    /// Returns the sorted row indices of stored entries strictly below the diagonal in `col`.
    fn get_column_below_diagonal(&self, col: usize) -> Vec<usize> {
        let mut result = Vec::new();
        for &(r, c) in self.data.keys() {
            if c == col && r > col {
                result.push(r);
            }
        }
        result.sort();
        result
    }
}

#[allow(dead_code)]
fn find_pivot<T>(
    matrix: &SparseWorkingMatrix<T>,
    k: usize,
    p: &[usize],
    threshold: f64,
) -> SparseResult<usize>
where
    T: Float + Debug + Copy,
{
    // Delegate to the enhanced pivot search with a threshold strategy and identity column order.
    let opts = LUOptions {
        pivoting: PivotingStrategy::Threshold(threshold),
        zero_threshold: 1e-14,
        check_singular: true,
    };

    let row_scales = vec![T::one(); matrix.n];
    let col_perm: Vec<usize> = (0..matrix.n).collect();

    let (pivot_row, _) = find_enhanced_pivot(matrix, k, p, &col_perm, &row_scales, &opts)?;
    Ok(pivot_row)
}

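/// Selects a pivot position for elimination step `k` according to `opts.pivoting`.
///
/// Returns `(pivot_row, pivot_col)` as positions in the current `row_perm` /
/// `col_perm` orderings; only the complete and rook strategies ever move the
/// column away from `k`.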
#[allow(dead_code)]
fn find_enhanced_pivot<T>(
    matrix: &SparseWorkingMatrix<T>,
    k: usize,
    row_perm: &[usize],
    col_perm: &[usize],
    row_scales: &[T],
    opts: &LUOptions,
) -> SparseResult<(usize, usize)>
where
    T: Float + Debug + Copy,
{
    let n = matrix.n;

    match &opts.pivoting {
        PivotingStrategy::None => {
            // Keep the natural pivot position.
            Ok((k, k))
        }

        PivotingStrategy::Partial => {
            // Largest magnitude in column k among the remaining rows.
            let mut max_val = T::zero();
            let mut pivot_row = k;

            for (i, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
                let val = matrix.get(actual_row, col_perm[k]).abs();
                if val > max_val {
                    max_val = val;
                    pivot_row = i;
                }
            }

            Ok((pivot_row, k))
        }

        PivotingStrategy::Threshold(threshold) => {
            // Accept the first entry whose magnitude reaches the threshold,
            // otherwise fall back to the largest entry found.
            let threshold_val = T::from(*threshold).unwrap();
            let mut max_val = T::zero();
            let mut pivot_row = k;

            for (i, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
                let val = matrix.get(actual_row, col_perm[k]).abs();
                if val > max_val {
                    max_val = val;
                    pivot_row = i;
                }
                if val >= threshold_val {
                    pivot_row = i;
                    break;
                }
            }

            Ok((pivot_row, k))
        }

        PivotingStrategy::ScaledPartial => {
            // Largest magnitude relative to each row's own scale.
            let mut max_ratio = T::zero();
            let mut pivot_row = k;

            for (i, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
                let val = matrix.get(actual_row, col_perm[k]).abs();
                let scale = row_scales[actual_row];

                let ratio = if scale > T::zero() { val / scale } else { val };

                if ratio > max_ratio {
                    max_ratio = ratio;
                    pivot_row = i;
                }
            }

            Ok((pivot_row, k))
        }

        PivotingStrategy::Complete => {
            // Largest magnitude in the whole remaining submatrix.
            let mut max_val = T::zero();
            let mut pivot_row = k;
            let mut pivot_col = k;

            for (i, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
                for (j, &actual_col) in col_perm.iter().enumerate().skip(k).take(n - k) {
                    let val = matrix.get(actual_row, actual_col).abs();
                    if val > max_val {
                        max_val = val;
                        pivot_row = i;
                        pivot_col = j;
                    }
                }
            }

            Ok((pivot_row, pivot_col))
        }

        PivotingStrategy::Rook => {
            // Alternate row and column searches to find an entry that is large in both.
            let mut best_row = k;
            let mut best_col = k;
            let mut max_val = T::zero();

            // Step 1: largest entry in column k.
            for (i, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
                let val = matrix.get(actual_row, col_perm[k]).abs();
                if val > max_val {
                    max_val = val;
                    best_row = i;
                }
            }

            if max_val > T::from(opts.zero_threshold).unwrap() {
                // Step 2: largest entry in the selected row.
                let actual_best_row = row_perm[best_row];
                let mut col_max = T::zero();

                for (j, &actual_col) in col_perm.iter().enumerate().skip(k).take(n - k) {
                    let val = matrix.get(actual_best_row, actual_col).abs();
                    if val > col_max {
                        col_max = val;
                        best_col = j;
                    }
                }

                // Step 3: if the row search found a clearly better column, redo the column search there.
                let improvement_threshold = T::from(1.5).unwrap();
                if col_max > max_val * improvement_threshold {
                    max_val = T::zero();
                    for (i, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
                        let val = matrix.get(actual_row, col_perm[best_col]).abs();
                        if val > max_val {
                            max_val = val;
                            best_row = i;
                        }
                    }
                }
            }

            Ok((best_row, best_col))
        }
    }
}

/// Triplet buffers `(l_rows, l_cols, l_vals, u_rows, u_cols, u_vals)` for the L and U factors.
type LuFactors<T> = (
    Vec<usize>,
    Vec<usize>,
    Vec<T>,
    Vec<usize>,
    Vec<usize>,
    Vec<T>,
);

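/// Splits the in-place elimination result into triplet buffers for L and U.
///
/// Row `i` of the output corresponds to permuted row `p[i]`; entries left of
/// the diagonal become L (a unit diagonal is added explicitly), while the
/// diagonal and everything to its right become U.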
#[allow(dead_code)]
fn extract_lu_factors<T>(matrix: &SparseWorkingMatrix<T>, p: &[usize], n: usize) -> LuFactors<T>
where
    T: Float + Debug + Copy,
{
    let mut l_rows = Vec::new();
    let mut l_cols = Vec::new();
    let mut l_vals = Vec::new();
    let mut u_rows = Vec::new();
    let mut u_cols = Vec::new();
    let mut u_vals = Vec::new();

    #[allow(clippy::needless_range_loop)]
    for i in 0..n {
        let actual_row = p[i];

        // L has a unit diagonal.
        l_rows.push(i);
        l_cols.push(i);
        l_vals.push(T::one());

        for j in 0..n {
            let val = matrix.get(actual_row, j);
            if !val.is_zero() {
                if j < i {
                    // Strictly lower triangle: multipliers stored during elimination.
                    l_rows.push(i);
                    l_cols.push(j);
                    l_vals.push(val);
                } else {
                    // Diagonal and upper triangle belong to U.
                    u_rows.push(i);
                    u_cols.push(j);
                    u_vals.push(val);
                }
            }
        }
    }

    (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals)
}

#[allow(dead_code)]
fn extract_lower_triangular<T>(
    matrix: &SparseWorkingMatrix<T>,
    n: usize,
) -> (Vec<usize>, Vec<usize>, Vec<T>)
where
    T: Float + Debug + Copy,
{
    let mut rows = Vec::new();
    let mut cols = Vec::new();
    let mut vals = Vec::new();

    for i in 0..n {
        for j in 0..=i {
            let val = matrix.get(i, j);
            if !val.is_zero() {
                rows.push(i);
                cols.push(j);
                vals.push(val);
            }
        }
    }

    (rows, cols, vals)
}

#[allow(dead_code)]
fn dense_to_sparse<T>(matrix: &Array2<T>) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy,
{
    let (m, n) = matrix.dim();
    let mut rows = Vec::new();
    let mut cols = Vec::new();
    let mut vals = Vec::new();

    for i in 0..m {
        for j in 0..n {
            let val = matrix[[i, j]];
            if !val.is_zero() {
                rows.push(i);
                cols.push(j);
                vals.push(val);
            }
        }
    }

    CsrArray::from_triplets(&rows, &cols, &vals, (m, n), false)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// 3x3 test matrix:
    /// [[2, 1, 0],
    ///  [1, 3, 0],
    ///  [0, 2, 4]]
    fn create_test_matrix() -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 1, 2];
        let data = vec![2.0, 1.0, 1.0, 3.0, 2.0, 4.0];

        CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap()
    }

    /// Lower triangle of a symmetric positive definite matrix:
    /// [[4, 2, 1],
    ///  [2, 5, 3],
    ///  [1, 3, 6]]
    fn create_spd_matrix() -> CsrArray<f64> {
        let rows = vec![0, 1, 1, 2, 2, 2];
        let cols = vec![0, 0, 1, 0, 1, 2];
        let data = vec![4.0, 2.0, 5.0, 1.0, 3.0, 6.0];

        CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap()
    }

    #[test]
    fn test_lu_decomposition() {
        let matrix = create_test_matrix();
        let lu_result = lu_decomposition(&matrix, 0.1).unwrap();

        assert!(lu_result.success);
        assert_eq!(lu_result.l.shape(), (3, 3));
        assert_eq!(lu_result.u.shape(), (3, 3));
        assert_eq!(lu_result.p.len(), 3);
    }

    #[test]
    fn test_qr_decomposition() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 0, 1];
        let data = vec![1.0, 2.0, 3.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 2), false).unwrap();

        let qr_result = qr_decomposition(&matrix).unwrap();

        assert!(qr_result.success);
        assert_eq!(qr_result.q.shape(), (3, 2));
        assert_eq!(qr_result.r.shape(), (2, 2));
    }

    #[test]
    fn test_cholesky_decomposition() {
        let matrix = create_spd_matrix();
        let chol_result = cholesky_decomposition(&matrix).unwrap();

        assert!(chol_result.success);
        assert_eq!(chol_result.l.shape(), (3, 3));
    }

    #[test]
    fn test_incomplete_lu() {
        let matrix = create_test_matrix();
        let options = ILUOptions {
            drop_tol: 1e-6,
            ..Default::default()
        };

        let ilu_result = incomplete_lu(&matrix, Some(options)).unwrap();

        assert!(ilu_result.success);
        assert_eq!(ilu_result.l.shape(), (3, 3));
        assert_eq!(ilu_result.u.shape(), (3, 3));
    }

    #[test]
    fn test_incomplete_cholesky() {
        let matrix = create_spd_matrix();
        let options = ICOptions {
            drop_tol: 1e-6,
            ..Default::default()
        };

        let ic_result = incomplete_cholesky(&matrix, Some(options)).unwrap();

        assert!(ic_result.success);
        assert_eq!(ic_result.l.shape(), (3, 3));
    }

    #[test]
    fn test_sparse_working_matrix() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let vals = vec![1.0, 2.0, 3.0];

        let mut matrix = SparseWorkingMatrix::from_triplets(&rows, &cols, &vals, 3);

        assert_eq!(matrix.get(0, 0), 1.0);
        assert_eq!(matrix.get(1, 1), 2.0);
        assert_eq!(matrix.get(2, 2), 3.0);
        assert_eq!(matrix.get(0, 1), 0.0);

        matrix.set(0, 1, 5.0);
        assert_eq!(matrix.get(0, 1), 5.0);

        matrix.set(0, 1, 0.0);
        assert_eq!(matrix.get(0, 1), 0.0);
        assert!(!matrix.has_entry(0, 1));
    }

    #[test]
    fn test_dense_to_sparse_conversion() {
        let dense = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 2.0, 3.0]).unwrap();
        let sparse = dense_to_sparse(&dense).unwrap();

        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.get(0, 0), 1.0);
        assert_eq!(sparse.get(0, 1), 0.0);
        assert_eq!(sparse.get(1, 0), 2.0);
        assert_eq!(sparse.get(1, 1), 3.0);
    }
}