use std::time::Instant;

use tracing::{debug, trace, warn};

use crate::error::{SolverError, ValidationError};
use crate::traits::SolverEngine;
use crate::types::{
    Algorithm, ComplexityClass, ComplexityEstimate, ComputeBudget, ConvergenceInfo, CsrMatrix,
    SolverResult, SparsityProfile,
};
/// Dot product of two `f32` slices, accumulated in `f64` to limit
/// round-off error on long vectors.
///
/// Four independent accumulator lanes are used so the additions can be
/// pipelined/vectorized; the lanes are combined pairwise at the end.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn dot_product_f64(a: &[f32], b: &[f32]) -> f64 {
    assert_eq!(a.len(), b.len(), "dot_product_f64: length mismatch");

    let mut lanes = [0.0f64; 4];

    let head_a = a.chunks_exact(4);
    let head_b = b.chunks_exact(4);
    let (tail_a, tail_b) = (head_a.remainder(), head_b.remainder());

    // Bulk of the work: exact 4-wide chunks, one lane per slot.
    for (ca, cb) in head_a.zip(head_b) {
        for k in 0..4 {
            lanes[k] += ca[k] as f64 * cb[k] as f64;
        }
    }

    // Up to three trailing elements fold into lane 0, matching the
    // unrolled formulation.
    for (&xa, &xb) in tail_a.iter().zip(tail_b) {
        lanes[0] += xa as f64 * xb as f64;
    }

    (lanes[0] + lanes[1]) + (lanes[2] + lanes[3])
}
103
/// Dot product of two `f64` slices using four-lane accumulation.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[inline]
fn dot_f64(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "dot_f64: length mismatch");

    let mut lanes = [0.0f64; 4];

    let head_a = a.chunks_exact(4);
    let head_b = b.chunks_exact(4);
    let (tail_a, tail_b) = (head_a.remainder(), head_b.remainder());

    // Main body: exact 4-element chunks, one lane per position.
    for (ca, cb) in head_a.zip(head_b) {
        for k in 0..4 {
            lanes[k] += ca[k] * cb[k];
        }
    }

    // Remainder (fewer than four elements) goes into lane 0.
    for (&xa, &xb) in tail_a.iter().zip(tail_b) {
        lanes[0] += xa * xb;
    }

    (lanes[0] + lanes[1]) + (lanes[2] + lanes[3])
}
135
/// In-place scaled vector addition: `y[i] += alpha * x[i]` for all `i`.
///
/// Each element is independent, so a fused zip iteration is equivalent
/// to the unrolled form and auto-vectorizes cleanly.
///
/// # Panics
///
/// Panics if `x` and `y` have different lengths.
#[inline]
pub fn axpy(alpha: f32, x: &[f32], y: &mut [f32]) {
    assert_eq!(x.len(), y.len(), "axpy: length mismatch");

    for (yi, &xi) in y.iter_mut().zip(x) {
        *yi += alpha * xi;
    }
}
160
/// In-place scaled vector addition in `f64`: `y[i] += alpha * x[i]`.
///
/// Elements are independent, so a zipped iteration reproduces the
/// unrolled version's results exactly.
///
/// # Panics
///
/// Panics if `x` and `y` have different lengths.
#[inline]
fn axpy_f64(alpha: f64, x: &[f64], y: &mut [f64]) {
    assert_eq!(x.len(), y.len(), "axpy_f64: length mismatch");

    for (yi, &xi) in y.iter_mut().zip(x) {
        *yi += alpha * xi;
    }
}
181
182#[inline]
186pub fn norm2(x: &[f32]) -> f64 {
187 dot_product_f64(x, x).sqrt()
188}
189
190#[inline]
192fn norm2_f64(x: &[f64]) -> f64 {
193 dot_f64(x, x).sqrt()
194}
195
/// Conjugate Gradient solver for sparse linear systems in CSR format.
/// The algorithm assumes the matrix is symmetric positive-definite
/// (non-SPD inputs are reported as `NumericalInstability` during the
/// iteration).
#[derive(Debug, Clone)]
pub struct ConjugateGradientSolver {
    // Relative residual target: convergence is declared when
    // ||r|| < tolerance * ||b|| (the budget's tolerance may tighten this).
    tolerance: f64,

    // Upper bound on CG iterations (further capped by the compute budget).
    max_iterations: usize,

    // When true, a Jacobi (inverse-diagonal) preconditioner is applied.
    use_preconditioner: bool,
}
222
impl ConjugateGradientSolver {
    /// Creates a new CG solver.
    ///
    /// * `tolerance` — relative residual target; convergence is declared
    ///   when `||r|| < tolerance * ||b||` (subject to the compute budget).
    /// * `max_iterations` — upper bound on CG iterations.
    /// * `use_preconditioner` — enables the Jacobi preconditioner.
    pub fn new(tolerance: f64, max_iterations: usize, use_preconditioner: bool) -> Self {
        Self {
            tolerance,
            max_iterations,
            use_preconditioner,
        }
    }

    /// Configured relative tolerance.
    #[inline]
    pub fn tolerance(&self) -> f64 {
        self.tolerance
    }

    /// Configured iteration cap.
    #[inline]
    pub fn max_iterations(&self) -> usize {
        self.max_iterations
    }

    /// Whether Jacobi preconditioning is enabled.
    #[inline]
    pub fn use_preconditioner(&self) -> bool {
        self.use_preconditioner
    }

    /// Checks the system shape and solver parameters, returning
    /// `SolverError::InvalidInput` on the first violation found.
    fn validate(
        &self,
        matrix: &CsrMatrix<f64>,
        rhs: &[f64],
    ) -> Result<(), SolverError> {
        // CG is only defined for square systems.
        if matrix.rows != matrix.cols {
            return Err(SolverError::InvalidInput(
                ValidationError::DimensionMismatch(format!(
                    "CG requires a square matrix but got {}x{}",
                    matrix.rows, matrix.cols,
                )),
            ));
        }

        // Right-hand side must match the matrix dimension.
        if rhs.len() != matrix.rows {
            return Err(SolverError::InvalidInput(
                ValidationError::DimensionMismatch(format!(
                    "rhs length {} does not match matrix rows {}",
                    rhs.len(),
                    matrix.rows,
                )),
            ));
        }

        // Basic CSR structural invariant: one row-pointer per row plus
        // the terminating sentinel.
        if matrix.row_ptr.len() != matrix.rows + 1 {
            return Err(SolverError::InvalidInput(
                ValidationError::DimensionMismatch(format!(
                    "row_ptr length {} does not equal rows + 1 = {}",
                    matrix.row_ptr.len(),
                    matrix.rows + 1,
                )),
            ));
        }

        // Tolerance must be a positive finite number.
        if !self.tolerance.is_finite() || self.tolerance <= 0.0 {
            return Err(SolverError::InvalidInput(
                ValidationError::ParameterOutOfRange {
                    name: "tolerance".into(),
                    value: self.tolerance.to_string(),
                    expected: "positive finite value".into(),
                },
            ));
        }

        // At least one iteration must be allowed.
        if self.max_iterations == 0 {
            return Err(SolverError::InvalidInput(
                ValidationError::ParameterOutOfRange {
                    name: "max_iterations".into(),
                    value: "0".into(),
                    expected: ">= 1".into(),
                },
            ));
        }

        Ok(())
    }

    /// Builds the Jacobi (inverse-diagonal) preconditioner.
    ///
    /// Returns `inv_diag` with `inv_diag[i] = 1 / A[i][i]`. Rows whose
    /// diagonal entry is missing, or no larger than `f64::EPSILON` in
    /// magnitude, keep the identity fallback of `1.0`.
    fn build_jacobi_preconditioner(matrix: &CsrMatrix<f64>) -> Vec<f64> {
        let n = matrix.rows;
        let mut inv_diag = vec![1.0f64; n];

        for row in 0..n {
            let start = matrix.row_ptr[row];
            let end = matrix.row_ptr[row + 1];
            for idx in start..end {
                // Scan this row's entries for the diagonal position.
                if matrix.col_indices[idx] == row {
                    let diag_val = matrix.values[idx];
                    if diag_val.abs() > f64::EPSILON {
                        inv_diag[row] = 1.0 / diag_val;
                    }
                    break;
                }
            }
        }

        inv_diag
    }

    /// Applies the Jacobi preconditioner: `z = inv_diag ⊙ r` elementwise.
    /// The loop is unrolled 4-wide to match the other kernels.
    #[inline]
    fn apply_preconditioner(inv_diag: &[f64], r: &[f64], z: &mut [f64]) {
        assert_eq!(inv_diag.len(), r.len());
        assert_eq!(r.len(), z.len());

        let n = r.len();
        let chunks = n / 4;
        let base = chunks * 4;

        for i in 0..chunks {
            let j = i * 4;
            z[j] = inv_diag[j] * r[j];
            z[j + 1] = inv_diag[j + 1] * r[j + 1];
            z[j + 2] = inv_diag[j + 2] * r[j + 2];
            z[j + 3] = inv_diag[j + 3] * r[j + 3];
        }
        for i in base..n {
            z[i] = inv_diag[i] * r[i];
        }
    }

    /// Core (optionally preconditioned) Conjugate Gradient iteration.
    ///
    /// Assumes `validate` has already passed. Starts from `x = 0`, so the
    /// initial residual equals `b`. Returns the solution downcast to
    /// `f32` together with the convergence history, or an error on budget
    /// exhaustion, numerical instability (non-SPD / divergence /
    /// stagnation), or non-convergence within the iteration limit.
    fn solve_inner(
        &self,
        matrix: &CsrMatrix<f64>,
        rhs: &[f64],
        budget: &ComputeBudget,
    ) -> Result<SolverResult, SolverError> {
        let start_time = Instant::now();
        let n = matrix.rows;

        // Honor the stricter of the solver's own limits and the budget's.
        let effective_max_iter = self.max_iterations.min(budget.max_iterations);
        let effective_tol = self.tolerance.min(budget.tolerance);

        // Degenerate empty system: nothing to solve.
        if n == 0 {
            return Ok(SolverResult {
                solution: vec![],
                iterations: 0,
                residual_norm: 0.0,
                wall_time: start_time.elapsed(),
                convergence_history: vec![],
                algorithm: Algorithm::CG,
            });
        }

        // Workspace: iterate x, residual r, preconditioned residual z,
        // search direction p, and the SpMV product ap = A*p.
        let mut x = vec![0.0f64; n];
        let mut r = vec![0.0f64; n];
        let mut z = vec![0.0f64; n];
        let mut p = vec![0.0f64; n];
        let mut ap = vec![0.0f64; n];
        let inv_diag = if self.use_preconditioner {
            Some(Self::build_jacobi_preconditioner(matrix))
        } else {
            None
        };

        // x0 = 0, so r0 = b - A*x0 = b.
        r.copy_from_slice(rhs);

        // Convergence is measured relative to ||b||.
        let b_norm = norm2_f64(rhs);
        let abs_tolerance = effective_tol * b_norm;

        // A (near-)zero RHS has the exact solution x = 0.
        if b_norm < f64::EPSILON {
            debug!("CG: zero RHS detected, returning zero solution");
            return Ok(SolverResult {
                solution: vec![0.0f32; n],
                iterations: 0,
                residual_norm: 0.0,
                wall_time: start_time.elapsed(),
                convergence_history: vec![],
                algorithm: Algorithm::CG,
            });
        }

        let initial_residual_norm = norm2_f64(&r);

        // z0 = M^{-1} r0 (or just r0 when unpreconditioned).
        match &inv_diag {
            Some(diag) => Self::apply_preconditioner(diag, &r, &mut z),
            None => z.copy_from_slice(&r),
        }

        // Initial search direction p0 = z0.
        p.copy_from_slice(&z);

        // rz = <r, z>; reused across iterations for both alpha and beta.
        let mut rz = dot_f64(&r, &z);

        let mut convergence_history =
            Vec::with_capacity(effective_max_iter.min(256));
        let mut converged = false;

        debug!(
            "CG: n={}, nnz={}, tol={:.2e}, max_iter={}, precond={}",
            n,
            matrix.nnz(),
            effective_tol,
            effective_max_iter,
            self.use_preconditioner,
        );

        for k in 0..effective_max_iter {
            // Enforce the wall-clock budget before doing more work.
            if start_time.elapsed() > budget.max_time {
                warn!("CG: wall-time budget exhausted at iteration {k}");
                return Err(SolverError::BudgetExhausted {
                    reason: format!(
                        "wall-time limit {:?} exceeded at iteration {k}",
                        budget.max_time,
                    ),
                    elapsed: start_time.elapsed(),
                });
            }

            // ap = A * p — the dominant cost of each iteration.
            matrix.spmv(&p, &mut ap);

            let p_dot_ap = dot_f64(&p, &ap);

            // For an SPD matrix, <p, Ap> must be strictly positive; any
            // other value means the matrix is not SPD or numerics broke.
            if p_dot_ap <= 0.0 {
                warn!("CG: non-positive p.Ap = {p_dot_ap:.4e} at iteration {k}");
                return Err(SolverError::NumericalInstability {
                    iteration: k,
                    detail: format!(
                        "p.Ap = {p_dot_ap:.6e} <= 0; matrix may not be SPD",
                    ),
                });
            }

            // Optimal step length along the current search direction.
            let alpha = rz / p_dot_ap;

            // x += alpha * p
            axpy_f64(alpha, &p, &mut x);

            // r -= alpha * A*p
            axpy_f64(-alpha, &ap, &mut r);

            let r_norm = norm2_f64(&r);

            convergence_history.push(ConvergenceInfo {
                iteration: k,
                residual_norm: r_norm,
            });

            trace!(
                "CG iter {k}: ||r|| = {r_norm:.6e}, rel = {:.6e}",
                r_norm / b_norm,
            );

            // Converged once the residual drops below tol * ||b||.
            if r_norm < abs_tolerance {
                converged = true;
                debug!(
                    "CG converged at iteration {k}: ||r|| = {r_norm:.6e}, \
                     rel = {:.6e}",
                    r_norm / b_norm,
                );
                break;
            }

            // Bail out early if the residual has clearly blown up.
            if r_norm > 10.0 * initial_residual_norm {
                warn!(
                    "CG: divergence at iteration {k}: ||r|| = {r_norm:.6e} \
                     > 10 * ||r_0|| = {:.6e}",
                    10.0 * initial_residual_norm,
                );
                return Err(SolverError::NumericalInstability {
                    iteration: k,
                    detail: format!(
                        "residual diverged: ||r|| = {r_norm:.6e} exceeds \
                         10x initial residual {initial_residual_norm:.6e}",
                    ),
                });
            }

            // z = M^{-1} r for the next direction update.
            match &inv_diag {
                Some(diag) => Self::apply_preconditioner(diag, &r, &mut z),
                None => z.copy_from_slice(&r),
            }

            let rz_new = dot_f64(&r, &z);

            // Guard the division below: a vanishing rz means the solver
            // can make no further progress (stagnation).
            if rz.abs() < f64::EPSILON * f64::EPSILON {
                warn!("CG: rz near zero at iteration {k}, stagnation");
                return Err(SolverError::NumericalInstability {
                    iteration: k,
                    detail: format!(
                        "rz = {rz:.6e} is near zero; solver stagnated",
                    ),
                });
            }

            let beta = rz_new / rz;

            // p = z + beta * p — the new conjugate search direction.
            for i in 0..n {
                p[i] = z[i] + beta * p[i];
            }

            rz = rz_new;
        }

        let wall_time = start_time.elapsed();
        let final_residual = norm2_f64(&r);

        if !converged {
            debug!(
                "CG: non-convergence after {} iterations, ||r|| = {final_residual:.6e}",
                effective_max_iter,
            );
            return Err(SolverError::NonConvergence {
                iterations: effective_max_iter,
                residual: final_residual,
                tolerance: abs_tolerance,
            });
        }

        // Downcast the f64 iterate to the f32 solution vector.
        let solution_f32: Vec<f32> = x.iter().map(|&v| v as f32).collect();

        Ok(SolverResult {
            solution: solution_f32,
            iterations: convergence_history.len(),
            residual_norm: final_residual,
            wall_time,
            convergence_history,
            algorithm: Algorithm::CG,
        })
    }
}
609
610impl SolverEngine for ConjugateGradientSolver {
615 fn solve(
624 &self,
625 matrix: &CsrMatrix<f64>,
626 rhs: &[f64],
627 budget: &ComputeBudget,
628 ) -> Result<SolverResult, SolverError> {
629 self.validate(matrix, rhs)?;
630 self.solve_inner(matrix, rhs, budget)
631 }
632
633 fn estimate_complexity(
638 &self,
639 profile: &SparsityProfile,
640 n: usize,
641 ) -> ComplexityEstimate {
642 let est_iters = (profile.estimated_condition.sqrt() as usize)
644 .max(1)
645 .min(self.max_iterations);
646
647 let flops_per_iter = 2 * profile.nnz as u64 + 6 * n as u64;
649 let estimated_flops = est_iters as u64 * flops_per_iter;
650
651 let vec_bytes = n * std::mem::size_of::<f64>();
653 let precond_bytes = if self.use_preconditioner { vec_bytes } else { 0 };
654 let estimated_memory_bytes = 5 * vec_bytes + precond_bytes;
655
656 ComplexityEstimate {
657 algorithm: Algorithm::CG,
658 estimated_flops,
659 estimated_iterations: est_iters,
660 estimated_memory_bytes,
661 complexity_class: ComplexityClass::SqrtCondition,
662 }
663 }
664
665 fn algorithm(&self) -> Algorithm {
667 Algorithm::CG
668 }
669}
670
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// n x n SPD tridiagonal test matrix: 4 on the diagonal, -1 on the
    /// off-diagonals (strictly diagonally dominant, hence SPD).
    fn tridiagonal_spd(n: usize) -> CsrMatrix<f64> {
        let mut entries = Vec::with_capacity(3 * n);
        for i in 0..n {
            if i > 0 {
                entries.push((i, i - 1, -1.0f64));
            }
            entries.push((i, i, 4.0f64));
            if i + 1 < n {
                entries.push((i, i + 1, -1.0f64));
            }
        }
        CsrMatrix::<f64>::from_coo(n, n, entries)
    }

    /// Diagonal matrix with the given diagonal entries.
    fn diagonal_matrix(diag: &[f64]) -> CsrMatrix<f64> {
        let n = diag.len();
        let entries: Vec<_> = diag
            .iter()
            .enumerate()
            .map(|(i, &v)| (i, i, v))
            .collect();
        CsrMatrix::<f64>::from_coo(n, n, entries)
    }

    /// n x n identity matrix.
    fn identity(n: usize) -> CsrMatrix<f64> {
        CsrMatrix::<f64>::identity(n)
    }

    /// Generous budget so unit tests are not budget-limited unless a
    /// test overrides it on purpose.
    fn default_budget() -> ComputeBudget {
        ComputeBudget {
            max_time: Duration::from_secs(30),
            max_iterations: 10_000,
            tolerance: 1e-10,
        }
    }

    // --- dot product kernel -------------------------------------------

    #[test]
    fn dot_product_f64_basic() {
        let a = vec![1.0f32, 2.0, 3.0];
        let b = vec![4.0f32, 5.0, 6.0];
        let result = dot_product_f64(&a, &b);
        assert!((result - 32.0).abs() < 1e-10);
    }

    #[test]
    fn dot_product_f64_empty() {
        assert!((dot_product_f64(&[], &[]) - 0.0).abs() < 1e-10);
    }

    #[test]
    fn dot_product_f64_precision() {
        // f64 accumulation must stay exact for 10k unit products.
        let n = 10_000;
        let a = vec![1.0f32; n];
        let b = vec![1.0f32; n];
        assert!((dot_product_f64(&a, &b) - n as f64).abs() < 1e-10);
    }

    #[test]
    fn dot_product_f64_odd_length() {
        // Length 5 exercises the non-multiple-of-4 remainder path.
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = vec![5.0f32, 4.0, 3.0, 2.0, 1.0];
        assert!((dot_product_f64(&a, &b) - 35.0).abs() < 1e-10);
    }

    // --- axpy kernel --------------------------------------------------

    #[test]
    fn axpy_basic() {
        let x = vec![1.0f32, 2.0, 3.0];
        let mut y = vec![10.0f32, 20.0, 30.0];
        axpy(2.0, &x, &mut y);
        assert_eq!(y, vec![12.0, 24.0, 36.0]);
    }

    #[test]
    fn axpy_negative_alpha() {
        let x = vec![1.0f32, 1.0, 1.0];
        let mut y = vec![5.0f32, 5.0, 5.0];
        axpy(-3.0, &x, &mut y);
        assert_eq!(y, vec![2.0, 2.0, 2.0]);
    }

    // --- norm kernel --------------------------------------------------

    #[test]
    fn norm2_basic() {
        let x = vec![3.0f32, 4.0];
        assert!((norm2(&x) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn norm2_zero() {
        assert!((norm2(&vec![0.0f32; 5]) - 0.0).abs() < 1e-10);
    }

    // --- CG solver: happy paths ---------------------------------------

    #[test]
    fn cg_identity_matrix() {
        // With A = I the solution is b itself, in one iteration.
        let n = 5;
        let matrix = identity(n);
        let rhs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-10, 100, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        for i in 0..n {
            assert!(
                (result.solution[i] as f64 - rhs[i]).abs() < 1e-5,
                "x[{i}] = {} != {}",
                result.solution[i],
                rhs[i],
            );
        }
        assert!(result.iterations <= 1);
    }

    #[test]
    fn cg_diagonal_matrix() {
        // Diagonal system has the closed-form solution x[i] = b[i]/d[i].
        let diag = vec![2.0, 3.0, 5.0, 7.0];
        let matrix = diagonal_matrix(&diag);
        let rhs = vec![4.0, 9.0, 25.0, 49.0];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-10, 100, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        let expected = [2.0, 3.0, 5.0, 7.0];
        for i in 0..4 {
            assert!(
                (result.solution[i] as f64 - expected[i]).abs() < 1e-4,
                "x[{i}] = {} != {}",
                result.solution[i],
                expected[i],
            );
        }
    }

    #[test]
    fn cg_tridiagonal_small() {
        let n = 10;
        let matrix = tridiagonal_spd(n);
        let rhs = vec![1.0f64; n];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 200, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        assert!(
            result.residual_norm < 1e-6,
            "residual = {}",
            result.residual_norm,
        );
        // Exact arithmetic CG finishes in at most n iterations.
        assert!(
            result.iterations <= n,
            "took {} iterations for n={}",
            result.iterations,
            n,
        );
    }

    #[test]
    fn cg_tridiagonal_large() {
        let n = 500;
        let matrix = tridiagonal_spd(n);
        let rhs: Vec<f64> = (0..n).map(|i| (i as f64 + 1.0) / n as f64).collect();
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 2000, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        assert!(
            result.residual_norm < 1e-5,
            "residual = {}",
            result.residual_norm,
        );
    }

    #[test]
    fn cg_preconditioned_converges_faster() {
        // Jacobi preconditioning should never need more iterations than
        // plain CG on the same system.
        let n = 100;
        let matrix = tridiagonal_spd(n);
        let rhs = vec![1.0f64; n];
        let budget = default_budget();

        let no_precond = ConjugateGradientSolver::new(1e-8, 500, false);
        let with_precond = ConjugateGradientSolver::new(1e-8, 500, true);

        let result_no = no_precond.solve(&matrix, &rhs, &budget).unwrap();
        let result_yes = with_precond.solve(&matrix, &rhs, &budget).unwrap();

        assert!(result_no.residual_norm < 1e-6);
        assert!(result_yes.residual_norm < 1e-6);

        assert!(
            result_yes.iterations <= result_no.iterations,
            "preconditioned ({}) should use <= iterations than \
             unpreconditioned ({})",
            result_yes.iterations,
            result_no.iterations,
        );
    }

    // --- CG solver: edge cases ----------------------------------------

    #[test]
    fn cg_zero_rhs() {
        // Zero RHS short-circuits to the zero solution without iterating.
        let matrix = tridiagonal_spd(5);
        let rhs = vec![0.0f64; 5];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        assert_eq!(result.iterations, 0);
        for &v in &result.solution {
            assert!((v as f64).abs() < 1e-10);
        }
    }

    #[test]
    fn cg_empty_system() {
        // A 0x0 system returns an empty solution immediately.
        let matrix = CsrMatrix {
            row_ptr: vec![0],
            col_indices: vec![],
            values: Vec::<f64>::new(),
            rows: 0,
            cols: 0,
        };
        let rhs: Vec<f64> = vec![];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        assert_eq!(result.iterations, 0);
        assert!(result.solution.is_empty());
    }

    // --- CG solver: error paths ---------------------------------------

    #[test]
    fn cg_dimension_mismatch() {
        let matrix = tridiagonal_spd(3);
        let rhs = vec![1.0f64; 5];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
        assert!(matches!(err, SolverError::InvalidInput(_)));
    }

    #[test]
    fn cg_non_square_matrix() {
        let matrix = CsrMatrix {
            row_ptr: vec![0, 1, 2],
            col_indices: vec![0, 1],
            values: vec![1.0f64, 1.0],
            rows: 2,
            cols: 3,
        };
        let rhs = vec![1.0f64; 2];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
        assert!(matches!(err, SolverError::InvalidInput(_)));
    }

    #[test]
    fn cg_non_convergence() {
        // One iteration at a 1e-15 tolerance cannot converge.
        let n = 50;
        let matrix = tridiagonal_spd(n);
        let rhs = vec![1.0f64; n];
        let budget = ComputeBudget {
            max_time: Duration::from_secs(30),
            max_iterations: 1,
            tolerance: 1e-15,
        };

        let solver = ConjugateGradientSolver::new(1e-15, 1, false);
        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
        assert!(matches!(err, SolverError::NonConvergence { .. }));
    }

    #[test]
    fn cg_budget_iteration_limit() {
        // The budget's iteration cap overrides the solver's own limit.
        let n = 50;
        let matrix = tridiagonal_spd(n);
        let rhs = vec![1.0f64; n];

        let solver = ConjugateGradientSolver::new(1e-15, 1000, false);
        let budget = ComputeBudget {
            max_time: Duration::from_secs(60),
            max_iterations: 2,
            tolerance: 1e-15,
        };

        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
        assert!(matches!(err, SolverError::NonConvergence { .. }));
    }

    // --- CG solver: reporting -----------------------------------------

    #[test]
    fn cg_convergence_history_populated() {
        let n = 20;
        let matrix = tridiagonal_spd(n);
        let rhs = vec![1.0f64; n];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-10, 200, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        assert!(!result.convergence_history.is_empty());

        // The last history entry must agree with the reported residual.
        let last = result.convergence_history.last().unwrap();
        assert!((last.residual_norm - result.residual_norm).abs() < 1e-12);
    }

    #[test]
    fn cg_algorithm_field() {
        let matrix = identity(3);
        let rhs = vec![1.0f64; 3];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();
        assert_eq!(result.algorithm, Algorithm::CG);
    }

    #[test]
    fn cg_solution_satisfies_system() {
        // End-to-end check: the returned x must satisfy A x ≈ b.
        let n = 20;
        let matrix = tridiagonal_spd(n);
        let rhs = vec![1.0f64; n];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-10, 200, true);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        let x_f64: Vec<f64> = result.solution.iter().map(|&v| v as f64).collect();
        let mut ax = vec![0.0f64; n];
        matrix.spmv(&x_f64, &mut ax);

        let mut max_err: f64 = 0.0;
        for i in 0..n {
            let err = (ax[i] - rhs[i]).abs();
            if err > max_err {
                max_err = err;
            }
        }

        assert!(
            max_err < 1e-4,
            "max |Ax - b| = {max_err:.6e}, expected < 1e-4",
        );
    }

    // --- complexity estimation & accessors ----------------------------

    #[test]
    fn estimate_complexity_returns_cg() {
        let solver = ConjugateGradientSolver::new(1e-8, 500, true);
        let profile = SparsityProfile {
            rows: 100,
            cols: 100,
            nnz: 298,
            density: 0.0298,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.5,
            estimated_condition: 100.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 2.98,
            max_nnz_per_row: 3,
        };

        let est = solver.estimate_complexity(&profile, 100);
        assert_eq!(est.algorithm, Algorithm::CG);
        assert_eq!(est.complexity_class, ComplexityClass::SqrtCondition);
        assert!(est.estimated_iterations > 0);
        assert!(est.estimated_flops > 0);
        assert!(est.estimated_memory_bytes > 0);
    }

    #[test]
    fn accessors() {
        let solver = ConjugateGradientSolver::new(1e-6, 500, true);
        assert!((solver.tolerance() - 1e-6).abs() < 1e-15);
        assert_eq!(solver.max_iterations(), 500);
        assert!(solver.use_preconditioner());
    }
}