1use crate::csr::CsrMatrix;
29use crate::error::{SparseError, SparseResult};
30use scirs2_core::numeric::{Float, NumAssign, SparseElement};
31use std::fmt::Debug;
32use std::iter::Sum;
33
/// Configuration options for the sparse QR factorization.
#[derive(Debug, Clone)]
pub struct SparseQrConfig {
    /// Enable column pivoting; required for numerical rank detection.
    pub pivoting: bool,
    /// Relative tolerance (vs. the largest column norm) used to declare
    /// a pivot column negligible and stop the factorization early.
    pub rank_tol: f64,
    /// Store only the leading `rank` rows of R instead of all `m` rows.
    pub economy: bool,
}
48
impl Default for SparseQrConfig {
    /// Pivoting enabled, relative rank tolerance `1e-12`, economy-size R.
    fn default() -> Self {
        Self {
            pivoting: true,
            rank_tol: 1e-12,
            economy: true,
        }
    }
}
58
/// Result of a sparse QR factorization `A * P = Q * R`.
///
/// Q is not formed explicitly; it is stored implicitly as Householder
/// reflectors (`householder_v`, `tau`) and can be applied with
/// `apply_q` / `apply_qt` or materialized with `extract_q_dense`.
#[derive(Debug, Clone)]
pub struct SparseQrResult<F> {
    /// Householder vectors, one per elimination step, each of length `m`.
    pub householder_v: Vec<Vec<F>>,
    /// Householder scalars; `tau == 0` marks an identity reflector.
    pub tau: Vec<F>,
    /// Dense rows of the upper-triangular factor R.
    pub r_data: Vec<Vec<F>>,
    /// `col_perm[j]` is the original index of the column placed at position `j`.
    pub col_perm: Vec<usize>,
    /// Number of rows of A.
    pub m: usize,
    /// Number of columns of A.
    pub n: usize,
    /// Detected numerical rank (equals `min(m, n)` when no deficiency found).
    pub rank: usize,
}
87
/// Result of a sparse least-squares solve `min ||Ax - b||`.
#[derive(Debug, Clone)]
pub struct SparseLeastSquaresResult<F> {
    /// Solution vector `x` of length `n`.
    pub solution: Vec<F>,
    /// Euclidean norm of the residual `b - Ax`.
    pub residual_norm: F,
    /// Numerical rank of A detected during factorization.
    pub rank: usize,
}
98
99pub fn sparse_qr<F>(
108 matrix: &CsrMatrix<F>,
109 config: &SparseQrConfig,
110) -> SparseResult<SparseQrResult<F>>
111where
112 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
113{
114 let m = matrix.rows();
115 let n = matrix.cols();
116
117 if m == 0 || n == 0 {
118 return Ok(SparseQrResult {
119 householder_v: Vec::new(),
120 tau: Vec::new(),
121 r_data: Vec::new(),
122 col_perm: (0..n).collect(),
123 m,
124 n,
125 rank: 0,
126 });
127 }
128
129 let mut cols_data = extract_columns(matrix, m, n);
133
134 let mut col_perm: Vec<usize> = (0..n).collect();
136
137 let mut col_norms: Vec<F> = (0..n).map(|j| column_norm(&cols_data[j])).collect();
139
140 let k_max = m.min(n);
141 let mut householder_v: Vec<Vec<F>> = Vec::with_capacity(k_max);
142 let mut tau_vec: Vec<F> = Vec::with_capacity(k_max);
143 let mut rank = k_max;
144
145 let rank_tol = F::from(config.rank_tol).unwrap_or_else(|| F::epsilon());
146
147 let max_norm = col_norms
149 .iter()
150 .copied()
151 .fold(F::sparse_zero(), |a, b| if b > a { b } else { a });
152
153 for k in 0..k_max {
154 if config.pivoting {
156 let mut best_j = k;
157 let mut best_norm = col_norms[k];
158 for j in (k + 1)..n {
159 if col_norms[j] > best_norm {
160 best_norm = col_norms[j];
161 best_j = j;
162 }
163 }
164 if best_j != k {
165 cols_data.swap(k, best_j);
166 col_perm.swap(k, best_j);
167 col_norms.swap(k, best_j);
168 }
169
170 if best_norm < rank_tol * max_norm {
172 rank = k;
173 for _i in k..k_max {
175 let v = vec![F::sparse_zero(); m];
176 householder_v.push(v);
177 tau_vec.push(F::sparse_zero());
178 }
179 break;
180 }
181 }
182
183 let (v, tau) = householder_vector(&cols_data[k], k, m);
185
186 for j in k..n {
190 let dot = dot_from(m, &v, &cols_data[j], k);
191 let scale = tau * dot;
192 for i in k..m {
193 cols_data[j][i] -= scale * v[i];
194 }
195 }
196
197 for j in (k + 1)..n {
199 let r_kj = cols_data[j][k];
200 let old_norm_sq = col_norms[j] * col_norms[j];
201 let new_norm_sq = old_norm_sq - r_kj * r_kj;
202 col_norms[j] = if new_norm_sq > F::sparse_zero() {
203 new_norm_sq.sqrt()
204 } else {
205 column_norm_from(&cols_data[j], k + 1, m)
207 };
208 }
209
210 householder_v.push(v);
211 tau_vec.push(tau);
212 }
213
214 let r_rows = if config.economy { rank } else { m };
216 let mut r_data = vec![vec![F::sparse_zero(); n]; r_rows];
217 for j in 0..n {
218 for i in 0..r_rows.min(j + 1) {
219 r_data[i][j] = cols_data[j][i];
220 }
221 }
222
223 Ok(SparseQrResult {
224 householder_v,
225 tau: tau_vec,
226 r_data,
227 col_perm,
228 m,
229 n,
230 rank,
231 })
232}
233
234pub fn sparse_least_squares<F>(
243 matrix: &CsrMatrix<F>,
244 b: &[F],
245 config: Option<&SparseQrConfig>,
246) -> SparseResult<SparseLeastSquaresResult<F>>
247where
248 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
249{
250 let m = matrix.rows();
251 let n = matrix.cols();
252
253 if b.len() != m {
254 return Err(SparseError::DimensionMismatch {
255 expected: m,
256 found: b.len(),
257 });
258 }
259
260 let default_config = SparseQrConfig::default();
261 let cfg = config.unwrap_or(&default_config);
262
263 let qr = sparse_qr(matrix, cfg)?;
265
266 let qt_b = apply_qt(b, &qr.householder_v, &qr.tau, m)?;
268
269 let rank = qr.rank;
270 if rank == 0 {
271 return Ok(SparseLeastSquaresResult {
272 solution: vec![F::sparse_zero(); n],
273 residual_norm: vector_norm(b),
274 rank: 0,
275 });
276 }
277
278 let mut y = vec![F::sparse_zero(); n];
280 for i in (0..rank).rev() {
281 let mut sum = qt_b[i];
282 for j in (i + 1)..rank {
283 sum -= qr.r_data[i][j] * y[j];
284 }
285 let diag = qr.r_data[i][i];
286 if diag.abs() < F::epsilon() {
287 return Err(SparseError::SingularMatrix(format!(
288 "Zero diagonal in R at position {i}"
289 )));
290 }
291 y[i] = sum / diag;
292 }
293
294 let mut x = vec![F::sparse_zero(); n];
296 for j in 0..n {
297 x[qr.col_perm[j]] = y[j];
298 }
299
300 let residual_norm = if rank < m {
302 let mut sum_sq = F::sparse_zero();
303 for i in rank..m {
304 sum_sq += qt_b[i] * qt_b[i];
305 }
306 sum_sq.sqrt()
307 } else {
308 F::sparse_zero()
309 };
310
311 Ok(SparseLeastSquaresResult {
312 solution: x,
313 residual_norm,
314 rank,
315 })
316}
317
318pub fn apply_qt<F>(b: &[F], householder_v: &[Vec<F>], tau: &[F], m: usize) -> SparseResult<Vec<F>>
331where
332 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
333{
334 if b.len() != m {
335 return Err(SparseError::DimensionMismatch {
336 expected: m,
337 found: b.len(),
338 });
339 }
340
341 let mut result = b.to_vec();
342 let k = householder_v.len();
343
344 for i in 0..k {
345 if tau[i] == F::sparse_zero() {
346 continue;
347 }
348 let v = &householder_v[i];
349 let dot: F = (0..m).map(|row| v[row] * result[row]).sum();
350 let scale = tau[i] * dot;
351 for row in 0..m {
352 result[row] -= scale * v[row];
353 }
354 }
355
356 Ok(result)
357}
358
359pub fn apply_q<F>(b: &[F], householder_v: &[Vec<F>], tau: &[F], m: usize) -> SparseResult<Vec<F>>
363where
364 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
365{
366 if b.len() != m {
367 return Err(SparseError::DimensionMismatch {
368 expected: m,
369 found: b.len(),
370 });
371 }
372
373 let mut result = b.to_vec();
374 let k = householder_v.len();
375
376 for i in (0..k).rev() {
377 if tau[i] == F::sparse_zero() {
378 continue;
379 }
380 let v = &householder_v[i];
381 let dot: F = (0..m).map(|row| v[row] * result[row]).sum();
382 let scale = tau[i] * dot;
383 for row in 0..m {
384 result[row] -= scale * v[row];
385 }
386 }
387
388 Ok(result)
389}
390
391pub fn extract_q_dense<F>(
395 householder_v: &[Vec<F>],
396 tau: &[F],
397 m: usize,
398 rank: usize,
399) -> SparseResult<Vec<Vec<F>>>
400where
401 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
402{
403 let mut q = vec![vec![F::sparse_zero(); rank]; m];
404
405 for j in 0..rank {
407 let mut ej = vec![F::sparse_zero(); m];
408 if j < m {
409 ej[j] = F::sparse_one();
410 }
411 let col = apply_q(&ej, householder_v, tau, m)?;
412 for i in 0..m {
413 q[i][j] = col[i];
414 }
415 }
416
417 Ok(q)
418}
419
420pub fn numerical_rank<F>(matrix: &CsrMatrix<F>, tol: Option<f64>) -> SparseResult<usize>
428where
429 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
430{
431 let config = SparseQrConfig {
432 pivoting: true,
433 rank_tol: tol.unwrap_or(1e-12),
434 economy: true,
435 };
436 let qr = sparse_qr(matrix, &config)?;
437 Ok(qr.rank)
438}
439
440fn extract_columns<F>(matrix: &CsrMatrix<F>, m: usize, n: usize) -> Vec<Vec<F>>
446where
447 F: Float + SparseElement + Debug + 'static,
448{
449 let mut cols = vec![vec![F::sparse_zero(); m]; n];
450 for i in 0..m {
451 let start = matrix.indptr[i];
452 let end = matrix.indptr[i + 1];
453 for idx in start..end {
454 let j = matrix.indices[idx];
455 cols[j][i] = matrix.data[idx];
456 }
457 }
458 cols
459}
460
461fn householder_vector<F>(col: &[F], k: usize, m: usize) -> (Vec<F>, F)
468where
469 F: Float + NumAssign + SparseElement + Debug + 'static,
470{
471 let mut v = vec![F::sparse_zero(); m];
472
473 let mut sigma_sq = F::sparse_zero();
475 for i in k..m {
476 sigma_sq += col[i] * col[i];
477 }
478 let sigma = sigma_sq.sqrt();
479
480 if sigma < F::epsilon() {
481 return (v, F::sparse_zero());
482 }
483
484 let alpha = if col[k] >= F::sparse_zero() {
486 -sigma
487 } else {
488 sigma
489 };
490
491 v[k] = col[k] - alpha;
493 v[(k + 1)..m].copy_from_slice(&col[(k + 1)..m]);
494
495 let mut v_norm_sq = F::sparse_zero();
497 for i in k..m {
498 v_norm_sq += v[i] * v[i];
499 }
500
501 if v_norm_sq < F::epsilon() {
502 return (v, F::sparse_zero());
503 }
504
505 let tau = F::from(2.0).unwrap_or_else(|| F::sparse_one() + F::sparse_one()) / v_norm_sq;
506
507 (v, tau)
508}
509
510fn dot_from<F: Float + SparseElement>(m: usize, a: &[F], b: &[F], start: usize) -> F {
512 let mut sum = F::sparse_zero();
513 for i in start..m {
514 sum = sum + a[i] * b[i];
515 }
516 sum
517}
518
519fn column_norm<F: Float + SparseElement>(col: &[F]) -> F {
521 let mut sum_sq = F::sparse_zero();
522 for &v in col {
523 sum_sq = sum_sq + v * v;
524 }
525 sum_sq.sqrt()
526}
527
528fn column_norm_from<F: Float + SparseElement>(col: &[F], start: usize, end: usize) -> F {
530 let mut sum_sq = F::sparse_zero();
531 for i in start..end {
532 sum_sq = sum_sq + col[i] * col[i];
533 }
534 sum_sq.sqrt()
535}
536
537fn vector_norm<F: Float + SparseElement>(v: &[F]) -> F {
539 let mut sum_sq = F::sparse_zero();
540 for &x in v {
541 sum_sq = sum_sq + x * x;
542 }
543 sum_sq.sqrt()
544}
545
#[cfg(test)]
mod tests {
    use super::*;

    // Full-rank 3x3 fixture: [[1,2,3],[4,5,6],[7,8,10]] (det = -3).
    fn create_3x3() -> CsrMatrix<f64> {
        let rows = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let cols = vec![0, 1, 2, 0, 1, 2, 0, 1, 2];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0];
        CsrMatrix::new(data, rows, cols, (3, 3)).expect("Failed")
    }

    // Overdetermined 4x2 fixture with full column rank.
    fn create_overdetermined() -> CsrMatrix<f64> {
        let rows = vec![0, 0, 1, 1, 2, 2, 3, 3];
        let cols = vec![0, 1, 0, 1, 0, 1, 0, 1];
        let data = vec![1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0];
        CsrMatrix::new(data, rows, cols, (4, 2)).expect("Failed")
    }

    // Rank-2 fixture: row 2 equals row 0 + row 1.
    fn create_rank_deficient() -> CsrMatrix<f64> {
        let rows = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let cols = vec![0, 1, 2, 0, 1, 2, 0, 1, 2];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 5.0, 7.0, 9.0];
        CsrMatrix::new(data, rows, cols, (3, 3)).expect("Failed")
    }

    // Shapes and reported rank for an unpivoted economy factorization.
    #[test]
    fn test_sparse_qr_basic() {
        let mat = create_3x3();
        let config = SparseQrConfig {
            pivoting: false,
            economy: true,
            ..Default::default()
        };
        let qr = sparse_qr(&mat, &config).expect("QR failed");
        assert_eq!(qr.m, 3);
        assert_eq!(qr.n, 3);
        assert_eq!(qr.rank, 3);
        assert_eq!(qr.householder_v.len(), 3);
        assert_eq!(qr.r_data.len(), 3);
    }

    // A 0x0 matrix must factor cleanly with rank 0.
    #[test]
    fn test_sparse_qr_empty() {
        let mat = CsrMatrix::<f64>::new(vec![], vec![], vec![], (0, 0)).expect("Failed");
        let config = SparseQrConfig::default();
        let qr = sparse_qr(&mat, &config).expect("QR on empty failed");
        assert_eq!(qr.rank, 0);
    }

    // Q^T Q should equal the identity within tolerance.
    #[test]
    fn test_qr_orthogonality() {
        let mat = create_3x3();
        let config = SparseQrConfig {
            pivoting: false,
            economy: true,
            ..Default::default()
        };
        let qr = sparse_qr(&mat, &config).expect("QR failed");
        let q = extract_q_dense(&qr.householder_v, &qr.tau, 3, 3).expect("Q extraction failed");

        for i in 0..3 {
            for j in 0..3 {
                let mut dot = 0.0;
                for k in 0..3 {
                    dot += q[k][i] * q[k][j];
                }
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (dot - expected).abs() < 1e-10,
                    "Q^T Q[{i},{j}] = {dot}, expected {expected}"
                );
            }
        }
    }

    // Q * R should reproduce the (column-permuted) original matrix.
    #[test]
    fn test_qr_factorization_accuracy() {
        let mat = create_3x3();
        let config = SparseQrConfig {
            pivoting: false,
            economy: true,
            ..Default::default()
        };
        let qr = sparse_qr(&mat, &config).expect("QR failed");
        let q = extract_q_dense(&qr.householder_v, &qr.tau, 3, 3).expect("Q extraction failed");

        let dense = mat.to_dense();
        for i in 0..3 {
            for j in 0..3 {
                let mut qr_val = 0.0;
                for k in 0..3 {
                    qr_val += q[i][k] * qr.r_data[k][j];
                }
                // Compare against the column the permutation maps j to.
                let orig_col = qr.col_perm[j];
                let a_val = dense[i][orig_col];
                assert!(
                    (qr_val - a_val).abs() < 1e-10,
                    "QR[{i},{j}] = {qr_val}, A*P[{i},{j}] = {a_val}"
                );
            }
        }
    }

    // For a square full-rank system, least squares solves Ax = b exactly.
    #[test]
    fn test_least_squares_square() {
        let mat = create_3x3();
        let b = vec![6.0, 15.0, 25.0];
        let result = sparse_least_squares(&mat, &b, None).expect("LS failed");
        assert_eq!(result.solution.len(), 3);
        assert_eq!(result.rank, 3);

        // Verify A x == b row by row.
        let dense = mat.to_dense();
        for i in 0..3 {
            let mut sum = 0.0;
            for j in 0..3 {
                sum += dense[i][j] * result.solution[j];
            }
            assert!(
                (sum - b[i]).abs() < 1e-8,
                "Row {i}: residual {}",
                (sum - b[i]).abs()
            );
        }
    }

    // Overdetermined solve must satisfy the normal equations A^T A x = A^T b.
    #[test]
    fn test_least_squares_overdetermined() {
        let mat = create_overdetermined();
        let b = vec![1.0, 2.0, 1.0, 2.0];
        let result = sparse_least_squares(&mat, &b, None).expect("LS overdetermined failed");
        assert_eq!(result.solution.len(), 2);
        assert_eq!(result.rank, 2);

        // Build A^T A and A^T b explicitly and check the normal equations.
        let dense = mat.to_dense();
        let mut ata = vec![vec![0.0; 2]; 2];
        let mut atb = [0.0; 2];
        for i in 0..4 {
            for j in 0..2 {
                atb[j] += dense[i][j] * b[i];
                for k in 0..2 {
                    ata[j][k] += dense[i][j] * dense[i][k];
                }
            }
        }
        for j in 0..2 {
            let mut sum = 0.0;
            for k in 0..2 {
                sum += ata[j][k] * result.solution[k];
            }
            assert!(
                (sum - atb[j]).abs() < 1e-8,
                "Normal eq {j}: {sum} vs {}",
                atb[j]
            );
        }
    }

    // A right-hand side of the wrong length is rejected.
    #[test]
    fn test_least_squares_dimension_mismatch() {
        let mat = create_3x3();
        let result = sparse_least_squares(&mat, &[1.0, 2.0], None);
        assert!(result.is_err());
    }

    // Pivoted QR should detect the deficiency of the rank-2 fixture.
    #[test]
    fn test_pivoted_qr_rank_deficient() {
        let mat = create_rank_deficient();
        let config = SparseQrConfig {
            pivoting: true,
            rank_tol: 1e-10,
            economy: true,
        };
        let qr = sparse_qr(&mat, &config).expect("QR rank deficient failed");
        assert!(qr.rank <= 2, "Expected rank <= 2, got {}", qr.rank);
    }

    // numerical_rank agrees with the known ranks of both fixtures.
    #[test]
    fn test_numerical_rank() {
        let mat = create_3x3();
        let rank = numerical_rank(&mat, None).expect("Rank computation failed");
        assert_eq!(rank, 3);

        let mat2 = create_rank_deficient();
        let rank2 = numerical_rank(&mat2, Some(1e-10)).expect("Rank computation failed");
        assert!(rank2 <= 2, "Expected rank <= 2, got {rank2}");
    }

    // Q applied after Q^T must round-trip a vector (Q Q^T = I).
    #[test]
    fn test_apply_q_qt_inverse() {
        let mat = create_3x3();
        let config = SparseQrConfig {
            pivoting: false,
            economy: true,
            ..Default::default()
        };
        let qr = sparse_qr(&mat, &config).expect("QR failed");

        let b = vec![1.0, 2.0, 3.0];
        let qt_b = apply_qt(&b, &qr.householder_v, &qr.tau, 3).expect("Q^T failed");
        let q_qt_b = apply_q(&qt_b, &qr.householder_v, &qr.tau, 3).expect("Q failed");

        for i in 0..3 {
            assert!(
                (q_qt_b[i] - b[i]).abs() < 1e-10,
                "Q*Q^T*b[{i}] = {}, expected {}",
                q_qt_b[i],
                b[i]
            );
        }
    }

    // Entries of R below the diagonal must be (numerically) zero.
    #[test]
    fn test_qr_r_upper_triangular() {
        let mat = create_3x3();
        let config = SparseQrConfig {
            pivoting: false,
            economy: true,
            ..Default::default()
        };
        let qr = sparse_qr(&mat, &config).expect("QR failed");

        for i in 0..qr.r_data.len() {
            for j in 0..i {
                assert!(
                    qr.r_data[i][j].abs() < 1e-10,
                    "R[{i},{j}] = {} should be zero",
                    qr.r_data[i][j]
                );
            }
        }
    }

    // 1x1 matrix: rank 1 and |R[0][0]| equal to the single entry.
    #[test]
    fn test_sparse_qr_single_element() {
        let mat = CsrMatrix::new(vec![5.0], vec![0], vec![0], (1, 1)).expect("Failed");
        let config = SparseQrConfig::default();
        let qr = sparse_qr(&mat, &config).expect("QR single failed");
        assert_eq!(qr.rank, 1);
        assert!((qr.r_data[0][0].abs() - 5.0).abs() < 1e-10);
    }

    // Solving with the identity returns b itself with zero residual.
    #[test]
    fn test_least_squares_identity() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![1.0, 1.0, 1.0];
        let mat = CsrMatrix::new(data, rows, cols, (3, 3)).expect("Failed");
        let b = vec![1.0, 2.0, 3.0];
        let result = sparse_least_squares(&mat, &b, None).expect("LS identity failed");
        for i in 0..3 {
            assert!(
                (result.solution[i] - b[i]).abs() < 1e-10,
                "x[{i}] = {}, expected {}",
                result.solution[i],
                b[i]
            );
        }
        assert!(result.residual_norm < 1e-10);
    }

    // 5x2 tall system solves with full column rank.
    #[test]
    fn test_least_squares_tall_skinny() {
        let rows = vec![0, 1, 1, 2, 2, 3, 3, 4, 4];
        let cols = vec![0, 0, 1, 0, 1, 0, 1, 0, 1];
        let data = vec![1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0];
        let mat = CsrMatrix::new(data, rows, cols, (5, 2)).expect("Failed");
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0];
        let result = sparse_least_squares(&mat, &b, None).expect("LS tall failed");
        assert_eq!(result.solution.len(), 2);
        assert_eq!(result.rank, 2);
    }

    // A zero column yields the identity reflector (v = 0, tau = 0).
    #[test]
    fn test_householder_vector_zero() {
        let col = vec![0.0, 0.0, 0.0];
        let (v, tau) = householder_vector(&col, 0, 3);
        assert!((tau).abs() < 1e-15);
        for &vi in &v {
            assert!((vi).abs() < 1e-15);
        }
    }

    // The column permutation must be a permutation of 0..n.
    #[test]
    fn test_col_perm_valid() {
        let mat = create_3x3();
        let config = SparseQrConfig {
            pivoting: true,
            ..Default::default()
        };
        let qr = sparse_qr(&mat, &config).expect("QR failed");
        let mut sorted = qr.col_perm.clone();
        sorted.sort();
        assert_eq!(sorted, vec![0, 1, 2]);
    }

    // Each extracted column of Q has unit norm.
    #[test]
    fn test_extract_q_dense_columns_orthonormal() {
        let mat = create_3x3();
        let config = SparseQrConfig {
            pivoting: false,
            economy: true,
            ..Default::default()
        };
        let qr = sparse_qr(&mat, &config).expect("QR failed");
        let q = extract_q_dense(&qr.householder_v, &qr.tau, 3, 3).expect("Q failed");

        for j in 0..3 {
            let mut norm_sq = 0.0;
            for i in 0..3 {
                norm_sq += q[i][j] * q[i][j];
            }
            assert!(
                (norm_sq - 1.0).abs() < 1e-10,
                "Column {j} norm = {}",
                norm_sq.sqrt()
            );
        }
    }
}