1use std::time::Instant;
52
53use tracing::{debug, trace, warn};
54
55use crate::error::{SolverError, ValidationError};
56use crate::traits::SolverEngine;
57use crate::types::{
58 Algorithm, ComplexityClass, ComplexityEstimate, ComputeBudget, ConvergenceInfo, CsrMatrix,
59 SolverResult, SparsityProfile,
60};
61
/// Dot product of two `f32` slices, accumulated in `f64` for precision.
///
/// Processes exact chunks of four elements through four independent
/// accumulator lanes and folds the tail into the first lane, reproducing
/// the original summation order exactly: `(s0 + s1) + (s2 + s3)`.
///
/// # Panics
/// Panics if `a` and `b` differ in length.
#[inline]
pub fn dot_product_f64(a: &[f32], b: &[f32]) -> f64 {
    assert_eq!(a.len(), b.len(), "dot_product_f64: length mismatch");

    let mut sums = [0.0f64; 4];

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    for (ca, cb) in a_chunks.zip(b_chunks) {
        for lane in 0..4 {
            sums[lane] += ca[lane] as f64 * cb[lane] as f64;
        }
    }

    // Leftover (< 4) elements all accumulate into lane 0, as before.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        sums[0] += x as f64 * y as f64;
    }

    (sums[0] + sums[1]) + (sums[2] + sums[3])
}
103
/// Dot product of two `f64` slices with four-way accumulator splitting.
///
/// Mirrors `dot_product_f64` but operates natively on `f64`; the final
/// reduction order `(s0 + s1) + (s2 + s3)` matches the unrolled original
/// bit-for-bit.
///
/// # Panics
/// Panics if `a` and `b` differ in length.
#[inline]
fn dot_f64(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "dot_f64: length mismatch");

    let mut sums = [0.0f64; 4];

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    for (ca, cb) in a_chunks.zip(b_chunks) {
        for lane in 0..4 {
            sums[lane] += ca[lane] * cb[lane];
        }
    }

    // Tail elements fold into lane 0, matching the scalar cleanup loop.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        sums[0] += x * y;
    }

    (sums[0] + sums[1]) + (sums[2] + sums[3])
}
135
/// In-place scaled vector addition: `y[i] += alpha * x[i]` for every `i`.
///
/// Each element is updated independently with the same single
/// multiply-add as the original, so results are bit-identical; the
/// zipped iterator form lets the compiler elide bounds checks and
/// vectorize.
///
/// # Panics
/// Panics if `x` and `y` differ in length.
#[inline]
pub fn axpy(alpha: f32, x: &[f32], y: &mut [f32]) {
    assert_eq!(x.len(), y.len(), "axpy: length mismatch");

    for (yi, &xi) in y.iter_mut().zip(x) {
        *yi += alpha * xi;
    }
}
160
/// In-place scaled vector addition on `f64` data: `y[i] += alpha * x[i]`.
///
/// Element updates are independent and use the identical multiply-add as
/// the unrolled original, so the numerical result is unchanged.
///
/// # Panics
/// Panics if `x` and `y` differ in length.
#[inline]
fn axpy_f64(alpha: f64, x: &[f64], y: &mut [f64]) {
    assert_eq!(x.len(), y.len(), "axpy_f64: length mismatch");

    for (yi, &xi) in y.iter_mut().zip(x) {
        *yi += alpha * xi;
    }
}
181
182#[inline]
186pub fn norm2(x: &[f32]) -> f64 {
187 dot_product_f64(x, x).sqrt()
188}
189
190#[inline]
192fn norm2_f64(x: &[f64]) -> f64 {
193 dot_f64(x, x).sqrt()
194}
195
/// Conjugate Gradient (CG) solver for sparse linear systems, with optional
/// Jacobi (inverse-diagonal) preconditioning.
///
/// The iteration assumes the matrix is symmetric positive-definite; a
/// non-positive `p.Ap` during the solve is reported as numerical
/// instability (see `solve_inner`).
#[derive(Debug, Clone)]
pub struct ConjugateGradientSolver {
    // Relative convergence tolerance: the solve stops once
    // ||r|| < tolerance * ||b||. Must be positive and finite.
    tolerance: f64,

    // Hard cap on CG iterations; further limited by the compute budget
    // at solve time.
    max_iterations: usize,

    // When true, a Jacobi preconditioner (reciprocal of the matrix
    // diagonal) is built and applied each iteration.
    use_preconditioner: bool,
}
222
impl ConjugateGradientSolver {
    /// Creates a CG solver.
    ///
    /// * `tolerance` — relative tolerance against the RHS norm.
    /// * `max_iterations` — iteration cap (intersected with the budget).
    /// * `use_preconditioner` — enables Jacobi preconditioning.
    ///
    /// Parameters are validated at solve time, not here.
    pub fn new(tolerance: f64, max_iterations: usize, use_preconditioner: bool) -> Self {
        Self {
            tolerance,
            max_iterations,
            use_preconditioner,
        }
    }

    /// Configured relative tolerance.
    #[inline]
    pub fn tolerance(&self) -> f64 {
        self.tolerance
    }

    /// Configured iteration cap.
    #[inline]
    pub fn max_iterations(&self) -> usize {
        self.max_iterations
    }

    /// Whether Jacobi preconditioning is enabled.
    #[inline]
    pub fn use_preconditioner(&self) -> bool {
        self.use_preconditioner
    }

    /// Validates shapes and solver parameters before a solve.
    ///
    /// # Errors
    /// Returns `SolverError::InvalidInput` when: the matrix is not square,
    /// the RHS length does not match the row count, `row_ptr` does not have
    /// `rows + 1` entries, the tolerance is not a positive finite value, or
    /// `max_iterations` is zero.
    fn validate(&self, matrix: &CsrMatrix<f64>, rhs: &[f64]) -> Result<(), SolverError> {
        // CG is only defined for square systems.
        if matrix.rows != matrix.cols {
            return Err(SolverError::InvalidInput(
                ValidationError::DimensionMismatch(format!(
                    "CG requires a square matrix but got {}x{}",
                    matrix.rows, matrix.cols,
                )),
            ));
        }

        if rhs.len() != matrix.rows {
            return Err(SolverError::InvalidInput(
                ValidationError::DimensionMismatch(format!(
                    "rhs length {} does not match matrix rows {}",
                    rhs.len(),
                    matrix.rows,
                )),
            ));
        }

        // Structural sanity of the CSR row-pointer array.
        if matrix.row_ptr.len() != matrix.rows + 1 {
            return Err(SolverError::InvalidInput(
                ValidationError::DimensionMismatch(format!(
                    "row_ptr length {} does not equal rows + 1 = {}",
                    matrix.row_ptr.len(),
                    matrix.rows + 1,
                )),
            ));
        }

        if !self.tolerance.is_finite() || self.tolerance <= 0.0 {
            return Err(SolverError::InvalidInput(
                ValidationError::ParameterOutOfRange {
                    name: "tolerance".into(),
                    value: self.tolerance.to_string(),
                    expected: "positive finite value".into(),
                },
            ));
        }

        if self.max_iterations == 0 {
            return Err(SolverError::InvalidInput(
                ValidationError::ParameterOutOfRange {
                    name: "max_iterations".into(),
                    value: "0".into(),
                    expected: ">= 1".into(),
                },
            ));
        }

        Ok(())
    }

    /// Builds the Jacobi preconditioner: the element-wise reciprocal of the
    /// matrix diagonal.
    ///
    /// Rows with a missing or near-zero diagonal keep the default factor of
    /// 1.0 (identity), which avoids division by zero for that row.
    fn build_jacobi_preconditioner(matrix: &CsrMatrix<f64>) -> Vec<f64> {
        let n = matrix.rows;
        let mut inv_diag = vec![1.0f64; n];

        for row in 0..n {
            let start = matrix.row_ptr[row];
            let end = matrix.row_ptr[row + 1];
            // Scan the row's nonzeros for the diagonal entry.
            for idx in start..end {
                if matrix.col_indices[idx] == row {
                    let diag_val = matrix.values[idx];
                    if diag_val.abs() > f64::EPSILON {
                        inv_diag[row] = 1.0 / diag_val;
                    }
                    break;
                }
            }
        }

        inv_diag
    }

    /// Applies the Jacobi preconditioner: `z = M^{-1} r = inv_diag .* r`.
    ///
    /// Unrolled by four to match the style of the vector kernels above.
    #[inline]
    fn apply_preconditioner(inv_diag: &[f64], r: &[f64], z: &mut [f64]) {
        assert_eq!(inv_diag.len(), r.len());
        assert_eq!(r.len(), z.len());

        let n = r.len();
        let chunks = n / 4;
        let base = chunks * 4;

        for i in 0..chunks {
            let j = i * 4;
            z[j] = inv_diag[j] * r[j];
            z[j + 1] = inv_diag[j + 1] * r[j + 1];
            z[j + 2] = inv_diag[j + 2] * r[j + 2];
            z[j + 3] = inv_diag[j + 3] * r[j + 3];
        }
        // Scalar cleanup for the final n % 4 elements.
        for i in base..n {
            z[i] = inv_diag[i] * r[i];
        }
    }

    /// Core (preconditioned) CG iteration, assuming inputs already validated.
    ///
    /// Starts from x = 0, so the initial residual is the RHS itself.
    /// Convergence is declared when ||r|| < effective_tol * ||b||.
    ///
    /// # Errors
    /// * `BudgetExhausted` — wall-clock budget exceeded mid-iteration.
    /// * `NumericalInstability` — non-positive `p.Ap` (matrix likely not
    ///   SPD), residual divergence, or stagnation (`rz` near zero).
    /// * `NonConvergence` — iteration cap reached without meeting tolerance.
    fn solve_inner(
        &self,
        matrix: &CsrMatrix<f64>,
        rhs: &[f64],
        budget: &ComputeBudget,
    ) -> Result<SolverResult, SolverError> {
        let start_time = Instant::now();
        let n = matrix.rows;

        // The budget further constrains both the iteration cap and the
        // tolerance (the stricter of the two tolerances is used).
        let effective_max_iter = self.max_iterations.min(budget.max_iterations);
        let effective_tol = self.tolerance.min(budget.tolerance);

        // Empty system: trivially solved.
        if n == 0 {
            return Ok(SolverResult {
                solution: vec![],
                iterations: 0,
                residual_norm: 0.0,
                wall_time: start_time.elapsed(),
                convergence_history: vec![],
                algorithm: Algorithm::CG,
            });
        }

        // Work vectors: solution x, residual r, preconditioned residual z,
        // search direction p, and the SpMV product ap.
        let mut x = vec![0.0f64; n];
        let mut r = vec![0.0f64; n];
        let mut z = vec![0.0f64; n];
        let mut p = vec![0.0f64; n];
        let mut ap = vec![0.0f64; n];
        let inv_diag = if self.use_preconditioner {
            Some(Self::build_jacobi_preconditioner(matrix))
        } else {
            None
        };

        // x starts at zero, so r_0 = b - A*0 = b.
        r.copy_from_slice(rhs);

        let b_norm = norm2_f64(rhs);
        // Absolute threshold derived from the relative tolerance.
        let abs_tolerance = effective_tol * b_norm;

        // b == 0 implies x == 0 exactly; skip the iteration entirely.
        if b_norm < f64::EPSILON {
            debug!("CG: zero RHS detected, returning zero solution");
            return Ok(SolverResult {
                solution: vec![0.0f32; n],
                iterations: 0,
                residual_norm: 0.0,
                wall_time: start_time.elapsed(),
                convergence_history: vec![],
                algorithm: Algorithm::CG,
            });
        }

        // Baseline for the divergence guard below (equals b_norm here since
        // x_0 = 0, but computed from r for clarity).
        let initial_residual_norm = norm2_f64(&r);

        // z_0 = M^{-1} r_0 (or just r_0 when unpreconditioned).
        match &inv_diag {
            Some(diag) => Self::apply_preconditioner(diag, &r, &mut z),
            None => z.copy_from_slice(&r),
        }

        // Initial search direction p_0 = z_0.
        p.copy_from_slice(&z);

        // rz = r.z, the CG scalar reused across iterations.
        let mut rz = dot_f64(&r, &z);

        let mut convergence_history = Vec::with_capacity(effective_max_iter.min(256));
        let mut converged = false;

        debug!(
            "CG: n={}, nnz={}, tol={:.2e}, max_iter={}, precond={}",
            n,
            matrix.nnz(),
            effective_tol,
            effective_max_iter,
            self.use_preconditioner,
        );

        for k in 0..effective_max_iter {
            // Enforce the wall-clock budget at the top of each iteration.
            if start_time.elapsed() > budget.max_time {
                warn!("CG: wall-time budget exhausted at iteration {k}");
                return Err(SolverError::BudgetExhausted {
                    reason: format!(
                        "wall-time limit {:?} exceeded at iteration {k}",
                        budget.max_time,
                    ),
                    elapsed: start_time.elapsed(),
                });
            }

            // ap = A * p (the one SpMV per iteration).
            matrix.spmv(&p, &mut ap);

            let p_dot_ap = dot_f64(&p, &ap);

            // For an SPD matrix, p.Ap > 0 whenever p != 0; a violation
            // means the SPD assumption is broken.
            if p_dot_ap <= 0.0 {
                warn!("CG: non-positive p.Ap = {p_dot_ap:.4e} at iteration {k}");
                return Err(SolverError::NumericalInstability {
                    iteration: k,
                    detail: format!("p.Ap = {p_dot_ap:.6e} <= 0; matrix may not be SPD",),
                });
            }

            // Optimal step length along p.
            let alpha = rz / p_dot_ap;

            // x_{k+1} = x_k + alpha * p
            axpy_f64(alpha, &p, &mut x);

            // r_{k+1} = r_k - alpha * A p (recurrence, no extra SpMV).
            axpy_f64(-alpha, &ap, &mut r);

            let r_norm = norm2_f64(&r);

            convergence_history.push(ConvergenceInfo {
                iteration: k,
                residual_norm: r_norm,
            });

            trace!(
                "CG iter {k}: ||r|| = {r_norm:.6e}, rel = {:.6e}",
                r_norm / b_norm,
            );

            if r_norm < abs_tolerance {
                converged = true;
                debug!(
                    "CG converged at iteration {k}: ||r|| = {r_norm:.6e}, \
                     rel = {:.6e}",
                    r_norm / b_norm,
                );
                break;
            }

            // Divergence guard: a 10x growth over the initial residual is
            // treated as numerical breakdown rather than slow convergence.
            if r_norm > 10.0 * initial_residual_norm {
                warn!(
                    "CG: divergence at iteration {k}: ||r|| = {r_norm:.6e} \
                     > 10 * ||r_0|| = {:.6e}",
                    10.0 * initial_residual_norm,
                );
                return Err(SolverError::NumericalInstability {
                    iteration: k,
                    detail: format!(
                        "residual diverged: ||r|| = {r_norm:.6e} exceeds \
                         10x initial residual {initial_residual_norm:.6e}",
                    ),
                });
            }

            // z_{k+1} = M^{-1} r_{k+1}
            match &inv_diag {
                Some(diag) => Self::apply_preconditioner(diag, &r, &mut z),
                None => z.copy_from_slice(&r),
            }

            let rz_new = dot_f64(&r, &z);

            // Guard the upcoming beta = rz_new / rz division: a vanishing
            // denominator indicates stagnation.
            if rz.abs() < f64::EPSILON * f64::EPSILON {
                warn!("CG: rz near zero at iteration {k}, stagnation");
                return Err(SolverError::NumericalInstability {
                    iteration: k,
                    detail: format!("rz = {rz:.6e} is near zero; solver stagnated",),
                });
            }

            // Fletcher-Reeves style update of the search direction.
            let beta = rz_new / rz;

            // p_{k+1} = z_{k+1} + beta * p_k
            for i in 0..n {
                p[i] = z[i] + beta * p[i];
            }

            rz = rz_new;
        }

        let wall_time = start_time.elapsed();
        let final_residual = norm2_f64(&r);

        if !converged {
            debug!(
                "CG: non-convergence after {} iterations, ||r|| = {final_residual:.6e}",
                effective_max_iter,
            );
            return Err(SolverError::NonConvergence {
                iterations: effective_max_iter,
                residual: final_residual,
                tolerance: abs_tolerance,
            });
        }

        // Public result carries an f32 solution; iteration runs in f64.
        let solution_f32: Vec<f32> = x.iter().map(|&v| v as f32).collect();

        Ok(SolverResult {
            solution: solution_f32,
            iterations: convergence_history.len(),
            residual_norm: final_residual,
            wall_time,
            convergence_history,
            algorithm: Algorithm::CG,
        })
    }
}
600
impl SolverEngine for ConjugateGradientSolver {
    /// Validates inputs, then runs the CG iteration.
    ///
    /// # Errors
    /// Propagates validation errors from `validate` and any
    /// `NonConvergence` / `NumericalInstability` / `BudgetExhausted`
    /// errors from `solve_inner`.
    fn solve(
        &self,
        matrix: &CsrMatrix<f64>,
        rhs: &[f64],
        budget: &ComputeBudget,
    ) -> Result<SolverResult, SolverError> {
        self.validate(matrix, rhs)?;
        self.solve_inner(matrix, rhs, budget)
    }

    /// Estimates CG cost from the sparsity profile.
    ///
    /// Iterations scale with sqrt(condition number) (clamped to
    /// [1, max_iterations]); each iteration costs one SpMV plus a handful
    /// of length-n vector operations.
    fn estimate_complexity(&self, profile: &SparsityProfile, n: usize) -> ComplexityEstimate {
        // CG converges in O(sqrt(kappa)) iterations for SPD systems.
        let est_iters = (profile.estimated_condition.sqrt() as usize)
            .max(1)
            .min(self.max_iterations);

        // Per iteration: SpMV (~2 flops per nonzero) plus dot products and
        // axpy updates (~6n flops).
        let flops_per_iter = 2 * profile.nnz as u64 + 6 * n as u64;
        let estimated_flops = est_iters as u64 * flops_per_iter;

        // Five f64 work vectors (x, r, z, p, ap), plus inv_diag when
        // preconditioning is enabled.
        let vec_bytes = n * std::mem::size_of::<f64>();
        let precond_bytes = if self.use_preconditioner {
            vec_bytes
        } else {
            0
        };
        let estimated_memory_bytes = 5 * vec_bytes + precond_bytes;

        ComplexityEstimate {
            algorithm: Algorithm::CG,
            estimated_flops,
            estimated_iterations: est_iters,
            estimated_memory_bytes,
            complexity_class: ComplexityClass::SqrtCondition,
        }
    }

    /// Identifies this engine as conjugate gradient.
    fn algorithm(&self) -> Algorithm {
        Algorithm::CG
    }
}
661
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    // Builds an n x n tridiagonal matrix with 4 on the diagonal and -1 on
    // both off-diagonals: strictly diagonally dominant, hence SPD.
    fn tridiagonal_spd(n: usize) -> CsrMatrix<f64> {
        let mut entries = Vec::with_capacity(3 * n);
        for i in 0..n {
            if i > 0 {
                entries.push((i, i - 1, -1.0f64));
            }
            entries.push((i, i, 4.0f64));
            if i + 1 < n {
                entries.push((i, i + 1, -1.0f64));
            }
        }
        CsrMatrix::<f64>::from_coo(n, n, entries)
    }

    // Diagonal matrix built from the given diagonal values.
    fn diagonal_matrix(diag: &[f64]) -> CsrMatrix<f64> {
        let n = diag.len();
        let entries: Vec<_> = diag.iter().enumerate().map(|(i, &v)| (i, i, v)).collect();
        CsrMatrix::<f64>::from_coo(n, n, entries)
    }

    // n x n identity matrix.
    fn identity(n: usize) -> CsrMatrix<f64> {
        CsrMatrix::<f64>::identity(n)
    }

    // Generous budget that should never be the limiting factor in tests.
    fn default_budget() -> ComputeBudget {
        ComputeBudget {
            max_time: Duration::from_secs(30),
            max_iterations: 10_000,
            tolerance: 1e-10,
        }
    }

    // --- kernel tests: dot product ---

    #[test]
    fn dot_product_f64_basic() {
        let a = vec![1.0f32, 2.0, 3.0];
        let b = vec![4.0f32, 5.0, 6.0];
        let result = dot_product_f64(&a, &b);
        assert!((result - 32.0).abs() < 1e-10);
    }

    #[test]
    fn dot_product_f64_empty() {
        assert!((dot_product_f64(&[], &[]) - 0.0).abs() < 1e-10);
    }

    // f64 accumulation should stay exact for sums an f32 accumulator
    // would start to lose.
    #[test]
    fn dot_product_f64_precision() {
        let n = 10_000;
        let a = vec![1.0f32; n];
        let b = vec![1.0f32; n];
        assert!((dot_product_f64(&a, &b) - n as f64).abs() < 1e-10);
    }

    // Length 5 exercises the scalar remainder path after the 4-wide chunks.
    #[test]
    fn dot_product_f64_odd_length() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = vec![5.0f32, 4.0, 3.0, 2.0, 1.0];
        assert!((dot_product_f64(&a, &b) - 35.0).abs() < 1e-10);
    }

    // --- kernel tests: axpy ---

    #[test]
    fn axpy_basic() {
        let x = vec![1.0f32, 2.0, 3.0];
        let mut y = vec![10.0f32, 20.0, 30.0];
        axpy(2.0, &x, &mut y);
        assert_eq!(y, vec![12.0, 24.0, 36.0]);
    }

    #[test]
    fn axpy_negative_alpha() {
        let x = vec![1.0f32, 1.0, 1.0];
        let mut y = vec![5.0f32, 5.0, 5.0];
        axpy(-3.0, &x, &mut y);
        assert_eq!(y, vec![2.0, 2.0, 2.0]);
    }

    // --- kernel tests: norm ---

    #[test]
    fn norm2_basic() {
        let x = vec![3.0f32, 4.0];
        assert!((norm2(&x) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn norm2_zero() {
        assert!((norm2(&vec![0.0f32; 5]) - 0.0).abs() < 1e-10);
    }

    // --- solver tests: convergence on well-behaved systems ---

    // Identity system: solution equals the RHS, ideally in one iteration.
    #[test]
    fn cg_identity_matrix() {
        let n = 5;
        let matrix = identity(n);
        let rhs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-10, 100, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        for i in 0..n {
            assert!(
                (result.solution[i] as f64 - rhs[i]).abs() < 1e-5,
                "x[{i}] = {} != {}",
                result.solution[i],
                rhs[i],
            );
        }
        assert!(result.iterations <= 1);
    }

    // Diagonal system: x[i] = rhs[i] / diag[i], checked element-wise.
    #[test]
    fn cg_diagonal_matrix() {
        let diag = vec![2.0, 3.0, 5.0, 7.0];
        let matrix = diagonal_matrix(&diag);
        let rhs = vec![4.0, 9.0, 25.0, 49.0];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-10, 100, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        let expected = [2.0, 3.0, 5.0, 7.0];
        for i in 0..4 {
            assert!(
                (result.solution[i] as f64 - expected[i]).abs() < 1e-4,
                "x[{i}] = {} != {}",
                result.solution[i],
                expected[i],
            );
        }
    }

    // Exact CG terminates within n iterations for an n x n SPD matrix.
    #[test]
    fn cg_tridiagonal_small() {
        let n = 10;
        let matrix = tridiagonal_spd(n);
        let rhs = vec![1.0f64; n];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 200, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        assert!(
            result.residual_norm < 1e-6,
            "residual = {}",
            result.residual_norm,
        );
        assert!(
            result.iterations <= n,
            "took {} iterations for n={}",
            result.iterations,
            n,
        );
    }

    #[test]
    fn cg_tridiagonal_large() {
        let n = 500;
        let matrix = tridiagonal_spd(n);
        let rhs: Vec<f64> = (0..n).map(|i| (i as f64 + 1.0) / n as f64).collect();
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 2000, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        assert!(
            result.residual_norm < 1e-5,
            "residual = {}",
            result.residual_norm,
        );
    }

    // Jacobi preconditioning should never need more iterations than the
    // unpreconditioned run on the same system.
    #[test]
    fn cg_preconditioned_converges_faster() {
        let n = 100;
        let matrix = tridiagonal_spd(n);
        let rhs = vec![1.0f64; n];
        let budget = default_budget();

        let no_precond = ConjugateGradientSolver::new(1e-8, 500, false);
        let with_precond = ConjugateGradientSolver::new(1e-8, 500, true);

        let result_no = no_precond.solve(&matrix, &rhs, &budget).unwrap();
        let result_yes = with_precond.solve(&matrix, &rhs, &budget).unwrap();

        assert!(result_no.residual_norm < 1e-6);
        assert!(result_yes.residual_norm < 1e-6);

        assert!(
            result_yes.iterations <= result_no.iterations,
            "preconditioned ({}) should use <= iterations than \
             unpreconditioned ({})",
            result_yes.iterations,
            result_no.iterations,
        );
    }

    // --- solver tests: edge cases ---

    // b = 0 short-circuits to the zero solution with no iterations.
    #[test]
    fn cg_zero_rhs() {
        let matrix = tridiagonal_spd(5);
        let rhs = vec![0.0f64; 5];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        assert_eq!(result.iterations, 0);
        for &v in &result.solution {
            assert!((v as f64).abs() < 1e-10);
        }
    }

    // A 0x0 system is valid and trivially solved.
    #[test]
    fn cg_empty_system() {
        let matrix = CsrMatrix {
            row_ptr: vec![0],
            col_indices: vec![],
            values: Vec::<f64>::new(),
            rows: 0,
            cols: 0,
        };
        let rhs: Vec<f64> = vec![];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        assert_eq!(result.iterations, 0);
        assert!(result.solution.is_empty());
    }

    // --- solver tests: validation errors ---

    #[test]
    fn cg_dimension_mismatch() {
        let matrix = tridiagonal_spd(3);
        let rhs = vec![1.0f64; 5];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
        assert!(matches!(err, SolverError::InvalidInput(_)));
    }

    #[test]
    fn cg_non_square_matrix() {
        let matrix = CsrMatrix {
            row_ptr: vec![0, 1, 2],
            col_indices: vec![0, 1],
            values: vec![1.0f64, 1.0],
            rows: 2,
            cols: 3,
        };
        let rhs = vec![1.0f64; 2];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
        assert!(matches!(err, SolverError::InvalidInput(_)));
    }

    // --- solver tests: non-convergence and budget limits ---

    #[test]
    fn cg_non_convergence() {
        let n = 50;
        let matrix = tridiagonal_spd(n);
        let rhs = vec![1.0f64; n];
        let budget = ComputeBudget {
            max_time: Duration::from_secs(30),
            max_iterations: 1,
            tolerance: 1e-15,
        };

        let solver = ConjugateGradientSolver::new(1e-15, 1, false);
        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
        assert!(matches!(err, SolverError::NonConvergence { .. }));
    }

    // The budget's max_iterations caps the solver's own (larger) limit.
    #[test]
    fn cg_budget_iteration_limit() {
        let n = 50;
        let matrix = tridiagonal_spd(n);
        let rhs = vec![1.0f64; n];

        let solver = ConjugateGradientSolver::new(1e-15, 1000, false);
        let budget = ComputeBudget {
            max_time: Duration::from_secs(60),
            max_iterations: 2,
            tolerance: 1e-15,
        };

        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
        assert!(matches!(err, SolverError::NonConvergence { .. }));
    }

    // --- solver tests: result metadata ---

    // History is recorded per iteration; its last entry matches the
    // reported final residual.
    #[test]
    fn cg_convergence_history_populated() {
        let n = 20;
        let matrix = tridiagonal_spd(n);
        let rhs = vec![1.0f64; n];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-10, 200, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        assert!(!result.convergence_history.is_empty());

        let last = result.convergence_history.last().unwrap();
        assert!((last.residual_norm - result.residual_norm).abs() < 1e-12);
    }

    #[test]
    fn cg_algorithm_field() {
        let matrix = identity(3);
        let rhs = vec![1.0f64; 3];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();
        assert_eq!(result.algorithm, Algorithm::CG);
    }

    // End-to-end check: plug the returned (f32) solution back into the
    // system and verify A x ~= b within the f32 round-trip error.
    #[test]
    fn cg_solution_satisfies_system() {
        let n = 20;
        let matrix = tridiagonal_spd(n);
        let rhs = vec![1.0f64; n];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-10, 200, true);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        let x_f64: Vec<f64> = result.solution.iter().map(|&v| v as f64).collect();
        let mut ax = vec![0.0f64; n];
        matrix.spmv(&x_f64, &mut ax);

        let mut max_err: f64 = 0.0;
        for i in 0..n {
            let err = (ax[i] - rhs[i]).abs();
            if err > max_err {
                max_err = err;
            }
        }

        assert!(
            max_err < 1e-4,
            "max |Ax - b| = {max_err:.6e}, expected < 1e-4",
        );
    }

    // --- complexity estimate and accessors ---

    #[test]
    fn estimate_complexity_returns_cg() {
        let solver = ConjugateGradientSolver::new(1e-8, 500, true);
        let profile = SparsityProfile {
            rows: 100,
            cols: 100,
            nnz: 298,
            density: 0.0298,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.5,
            estimated_condition: 100.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 2.98,
            max_nnz_per_row: 3,
        };

        let est = solver.estimate_complexity(&profile, 100);
        assert_eq!(est.algorithm, Algorithm::CG);
        assert_eq!(est.complexity_class, ComplexityClass::SqrtCondition);
        assert!(est.estimated_iterations > 0);
        assert!(est.estimated_flops > 0);
        assert!(est.estimated_memory_bytes > 0);
    }

    #[test]
    fn accessors() {
        let solver = ConjugateGradientSolver::new(1e-6, 500, true);
        assert!((solver.tolerance() - 1e-6).abs() < 1e-15);
        assert_eq!(solver.max_iterations(), 500);
        assert!(solver.use_preconditioner());
    }
}