// ruvector_solver/cg.rs
1//! Conjugate Gradient solver for symmetric positive-definite systems.
2//!
3//! Solves `Ax = b` where `A` is a symmetric positive-definite (SPD) sparse
4//! matrix in CSR format. The algorithm converges in at most `n` iterations
5//! for an `n x n` system, but in practice converges in
6//! `O(sqrt(kappa) * log(1/eps))` iterations where `kappa = cond(A)`.
7//!
8//! # Algorithm
9//!
10//! Implements the Hestenes-Stiefel variant of Conjugate Gradient:
11//!
12//! ```text
13//! r = b - A*x
14//! z = M^{-1} * r        (preconditioner; z = r when disabled)
15//! p = z
16//! rz = r . z
17//!
18//! for k in 0..max_iterations:
19//!     Ap = A * p
20//!     alpha = rz / (p . Ap)
21//!     x  = x + alpha * p
22//!     r  = r - alpha * Ap
23//!     if ||r||_2 < tolerance * ||b||_2:
24//!         converged; break
25//!     z  = M^{-1} * r
26//!     rz_new = r . z
27//!     beta = rz_new / rz
28//!     p  = z + beta * p
29//!     rz = rz_new
30//! ```
31//!
32//! # Preconditioning
33//!
34//! When `use_preconditioner` is `true`, a diagonal (Jacobi) preconditioner is
35//! applied: `M = diag(A)`, so that `z_i = r_i / A_{ii}`. This reduces the
36//! effective condition number for diagonally-dominant systems without adding
37//! significant per-iteration cost.
38//!
39//! # Numerical precision
40//!
41//! All dot products and norm computations use `f64` accumulation even though
42//! the matrix may store `f32` values, preventing catastrophic cancellation in
43//! the inner products that drive the CG recurrence.
44//!
45//! # Convergence
46//!
47//! Theoretical bound:
48//! `||x_k - x*||_A <= 2 * ((sqrt(kappa) - 1)/(sqrt(kappa) + 1))^k * ||x_0 - x*||_A`
49//! where `kappa` is the 2-condition number of `A`.
50
51use std::time::Instant;
52
53use tracing::{debug, trace, warn};
54
55use crate::error::{SolverError, ValidationError};
56use crate::traits::SolverEngine;
57use crate::types::{
58    Algorithm, ComplexityClass, ComplexityEstimate, ComputeBudget, ConvergenceInfo, CsrMatrix,
59    SolverResult, SparsityProfile,
60};
61
62// ═══════════════════════════════════════════════════════════════════════════
63// Helper functions -- f64-accumulated linear algebra primitives
64// ═══════════════════════════════════════════════════════════════════════════
65
/// Compute the dot product of two `f32` slices using `f64` accumulation.
///
/// Uses a 4-wide accumulator to exploit instruction-level parallelism and
/// reduce the dependency chain length, preventing precision loss in the
/// inner products that drive the CG recurrence.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`. The check is an `assert_eq!`, so it is
/// active in release builds as well (not merely a debug assertion).
#[inline]
pub fn dot_product_f64(a: &[f32], b: &[f32]) -> f64 {
    assert_eq!(a.len(), b.len(), "dot_product_f64: length mismatch");

    let n = a.len();
    let chunks = n / 4;
    let remainder = n % 4;

    // Four independent accumulators shorten the add dependency chain and
    // let the CPU overlap the multiply-adds.
    let mut acc0: f64 = 0.0;
    let mut acc1: f64 = 0.0;
    let mut acc2: f64 = 0.0;
    let mut acc3: f64 = 0.0;

    for i in 0..chunks {
        let j = i * 4;
        acc0 += a[j] as f64 * b[j] as f64;
        acc1 += a[j + 1] as f64 * b[j + 1] as f64;
        acc2 += a[j + 2] as f64 * b[j + 2] as f64;
        acc3 += a[j + 3] as f64 * b[j + 3] as f64;
    }

    // Tail elements (n % 4) are folded into the first accumulator.
    let base = chunks * 4;
    for i in 0..remainder {
        acc0 += a[base + i] as f64 * b[base + i] as f64;
    }

    // Pairwise combine for a slightly more balanced final reduction.
    (acc0 + acc1) + (acc2 + acc3)
}
103
/// Compute the dot product of two `f64` slices with 4-wide accumulation.
///
/// Used internally when the working vectors are already `f64`. Keeps the
/// same accumulation order as the `f32` variant: four parallel lanes over
/// complete quads, the tail folded into lane 0, then a pairwise reduce.
#[inline]
fn dot_f64(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "dot_f64: length mismatch");

    let quads_a = a.chunks_exact(4);
    let quads_b = b.chunks_exact(4);
    let tail_a = quads_a.remainder();
    let tail_b = quads_b.remainder();

    let mut lane = [0.0f64; 4];

    // Complete 4-element quads feed four independent lanes.
    for (xs, ys) in quads_a.zip(quads_b) {
        lane[0] += xs[0] * ys[0];
        lane[1] += xs[1] * ys[1];
        lane[2] += xs[2] * ys[2];
        lane[3] += xs[3] * ys[3];
    }

    // Leftover elements (fewer than 4) go into lane 0, matching the
    // unrolled formulation exactly.
    for (x, y) in tail_a.iter().zip(tail_b) {
        lane[0] += x * y;
    }

    (lane[0] + lane[1]) + (lane[2] + lane[3])
}
135
/// Compute `y[i] += alpha * x[i]` for all `i` (AXPY operation, `f32`).
///
/// # Panics
///
/// Panics if `x.len() != y.len()`. The check is an `assert_eq!`, so it is
/// active in release builds as well (not merely a debug assertion).
#[inline]
pub fn axpy(alpha: f32, x: &[f32], y: &mut [f32]) {
    assert_eq!(x.len(), y.len(), "axpy: length mismatch");

    let n = x.len();
    let chunks = n / 4;
    let base = chunks * 4;

    // 4-wide unrolled main loop; updates are element-independent.
    for i in 0..chunks {
        let j = i * 4;
        y[j] += alpha * x[j];
        y[j + 1] += alpha * x[j + 1];
        y[j + 2] += alpha * x[j + 2];
        y[j + 3] += alpha * x[j + 3];
    }
    // Scalar tail for the final n % 4 elements.
    for i in base..n {
        y[i] += alpha * x[i];
    }
}
160
/// Compute `y[i] += alpha * x[i]` for all `i` (AXPY operation, `f64`).
///
/// Each element update is independent, so a simple zipped iteration
/// produces bit-identical results to an unrolled formulation while
/// avoiding per-access bounds checks.
#[inline]
fn axpy_f64(alpha: f64, x: &[f64], y: &mut [f64]) {
    assert_eq!(x.len(), y.len(), "axpy_f64: length mismatch");

    for (yi, &xi) in y.iter_mut().zip(x) {
        *yi += alpha * xi;
    }
}
181
182/// Compute the L2 norm of an `f32` slice using `f64` accumulation.
183///
184/// Returns `sqrt(sum(x_i^2))` computed entirely in `f64`.
185#[inline]
186pub fn norm2(x: &[f32]) -> f64 {
187    dot_product_f64(x, x).sqrt()
188}
189
190/// Compute the L2 norm of an `f64` slice.
191#[inline]
192fn norm2_f64(x: &[f64]) -> f64 {
193    dot_f64(x, x).sqrt()
194}
195
196// ═══════════════════════════════════════════════════════════════════════════
197// ConjugateGradientSolver
198// ═══════════════════════════════════════════════════════════════════════════
199
/// Conjugate Gradient solver for symmetric positive-definite sparse systems.
///
/// Stores the solver configuration (tolerance, iteration cap, preconditioning).
/// The solve itself is stateless and may be invoked concurrently on different
/// inputs from multiple threads.
///
/// Configuration invariants (positive finite `tolerance`,
/// `max_iterations >= 1`) are enforced by `validate` at solve time,
/// not at construction.
#[derive(Debug, Clone)]
pub struct ConjugateGradientSolver {
    /// Relative residual convergence tolerance.
    ///
    /// The solver declares convergence when `||r||_2 < tolerance * ||b||_2`.
    /// Must be positive and finite (checked in `validate`).
    tolerance: f64,

    /// Maximum number of CG iterations before declaring non-convergence.
    /// Must be >= 1 (checked in `validate`).
    max_iterations: usize,

    /// Whether to apply diagonal (Jacobi) preconditioning.
    ///
    /// When `true`, the preconditioner `M = diag(A)` is used, computing
    /// `z_i = r_i / A_{ii}` each iteration. Beneficial for diagonally-dominant
    /// systems at the cost of O(n) extra work per iteration.
    use_preconditioner: bool,
}
222
223impl ConjugateGradientSolver {
224    /// Create a new CG solver.
225    ///
226    /// # Arguments
227    ///
228    /// * `tolerance` -- Relative residual threshold for convergence. Must be
229    ///   positive and finite.
230    /// * `max_iterations` -- Upper bound on CG iterations. Must be >= 1.
231    /// * `use_preconditioner` -- Enable diagonal (Jacobi) preconditioning.
232    pub fn new(tolerance: f64, max_iterations: usize, use_preconditioner: bool) -> Self {
233        Self {
234            tolerance,
235            max_iterations,
236            use_preconditioner,
237        }
238    }
239
    /// Return the configured tolerance.
    ///
    /// Note: during a solve the effective tolerance is the minimum of this
    /// value and the budget's `tolerance` (the tighter of the two).
    #[inline]
    pub fn tolerance(&self) -> f64 {
        self.tolerance
    }

    /// Return the configured maximum iterations.
    ///
    /// Note: during a solve the effective cap is the minimum of this value
    /// and the budget's `max_iterations`.
    #[inline]
    pub fn max_iterations(&self) -> usize {
        self.max_iterations
    }

    /// Return whether diagonal (Jacobi) preconditioning is enabled.
    #[inline]
    pub fn use_preconditioner(&self) -> bool {
        self.use_preconditioner
    }
257
258    // -------------------------------------------------------------------
259    // Input validation
260    // -------------------------------------------------------------------
261
    /// Validate inputs before entering the CG loop.
    ///
    /// Checks, in order: matrix squareness, RHS length, CSR `row_ptr`
    /// consistency, tolerance positivity/finiteness, and a non-zero
    /// iteration cap. Returns the first violation wrapped in
    /// [`SolverError::InvalidInput`]; `Ok(())` if all checks pass.
    fn validate(&self, matrix: &CsrMatrix<f64>, rhs: &[f64]) -> Result<(), SolverError> {
        // CG is only defined for square (symmetric positive-definite) systems.
        if matrix.rows != matrix.cols {
            return Err(SolverError::InvalidInput(
                ValidationError::DimensionMismatch(format!(
                    "CG requires a square matrix but got {}x{}",
                    matrix.rows, matrix.cols,
                )),
            ));
        }

        // The right-hand side must match the system dimension.
        if rhs.len() != matrix.rows {
            return Err(SolverError::InvalidInput(
                ValidationError::DimensionMismatch(format!(
                    "rhs length {} does not match matrix rows {}",
                    rhs.len(),
                    matrix.rows,
                )),
            ));
        }

        // Canonical CSR: row_ptr has one entry per row plus a trailing
        // sentinel, so its length must be rows + 1.
        if matrix.row_ptr.len() != matrix.rows + 1 {
            return Err(SolverError::InvalidInput(
                ValidationError::DimensionMismatch(format!(
                    "row_ptr length {} does not equal rows + 1 = {}",
                    matrix.row_ptr.len(),
                    matrix.rows + 1,
                )),
            ));
        }

        // Reject NaN, infinities, zero, and negative tolerances.
        if !self.tolerance.is_finite() || self.tolerance <= 0.0 {
            return Err(SolverError::InvalidInput(
                ValidationError::ParameterOutOfRange {
                    name: "tolerance".into(),
                    value: self.tolerance.to_string(),
                    expected: "positive finite value".into(),
                },
            ));
        }

        // A zero iteration cap could never make progress.
        if self.max_iterations == 0 {
            return Err(SolverError::InvalidInput(
                ValidationError::ParameterOutOfRange {
                    name: "max_iterations".into(),
                    value: "0".into(),
                    expected: ">= 1".into(),
                },
            ));
        }

        Ok(())
    }
315
316    // -------------------------------------------------------------------
317    // Jacobi preconditioner
318    // -------------------------------------------------------------------
319
320    /// Build the Jacobi (diagonal) preconditioner from `A`.
321    ///
322    /// Returns `inv_diag[i] = 1.0 / A_{ii}`. Zero or near-zero diagonal
323    /// entries are replaced with `1.0` to prevent division by zero.
324    fn build_jacobi_preconditioner(matrix: &CsrMatrix<f64>) -> Vec<f64> {
325        let n = matrix.rows;
326        let mut inv_diag = vec![1.0f64; n];
327
328        for row in 0..n {
329            let start = matrix.row_ptr[row];
330            let end = matrix.row_ptr[row + 1];
331            for idx in start..end {
332                if matrix.col_indices[idx] == row {
333                    let diag_val = matrix.values[idx];
334                    if diag_val.abs() > f64::EPSILON {
335                        inv_diag[row] = 1.0 / diag_val;
336                    }
337                    break;
338                }
339            }
340        }
341
342        inv_diag
343    }
344
345    /// Apply the diagonal preconditioner: `z[i] = inv_diag[i] * r[i]`.
346    #[inline]
347    fn apply_preconditioner(inv_diag: &[f64], r: &[f64], z: &mut [f64]) {
348        assert_eq!(inv_diag.len(), r.len());
349        assert_eq!(r.len(), z.len());
350
351        let n = r.len();
352        let chunks = n / 4;
353        let base = chunks * 4;
354
355        for i in 0..chunks {
356            let j = i * 4;
357            z[j] = inv_diag[j] * r[j];
358            z[j + 1] = inv_diag[j + 1] * r[j + 1];
359            z[j + 2] = inv_diag[j + 2] * r[j + 2];
360            z[j + 3] = inv_diag[j + 3] * r[j + 3];
361        }
362        for i in base..n {
363            z[i] = inv_diag[i] * r[i];
364        }
365    }
366
367    // -------------------------------------------------------------------
368    // Core CG algorithm
369    // -------------------------------------------------------------------
370
    /// Core CG algorithm implementation.
    ///
    /// Works entirely in `f64` precision internally. The final solution is
    /// down-cast to `f32` in the returned [`SolverResult`] to match the
    /// crate's output contract.
    ///
    /// Assumes `validate` has already run: the matrix is square and
    /// `rhs.len() == matrix.rows`. Reported `iterations` equals the number
    /// of completed CG iterations (the length of the convergence history).
    fn solve_inner(
        &self,
        matrix: &CsrMatrix<f64>,
        rhs: &[f64],
        budget: &ComputeBudget,
    ) -> Result<SolverResult, SolverError> {
        let start_time = Instant::now();
        let n = matrix.rows;

        // Effective limits: take the tighter of our config and the budget.
        // `min` is correct for both: fewer iterations and a smaller
        // (stricter) tolerance are each the tighter constraint.
        let effective_max_iter = self.max_iterations.min(budget.max_iterations);
        let effective_tol = self.tolerance.min(budget.tolerance);

        // --- Trivial case: zero-dimensional system ---
        if n == 0 {
            return Ok(SolverResult {
                solution: vec![],
                iterations: 0,
                residual_norm: 0.0,
                wall_time: start_time.elapsed(),
                convergence_history: vec![],
                algorithm: Algorithm::CG,
            });
        }

        // --- Allocate working vectors (all f64) ---
        let mut x = vec![0.0f64; n]; // solution (initial guess = zero)
        let mut r = vec![0.0f64; n]; // residual
        let mut z = vec![0.0f64; n]; // preconditioned residual
        let mut p = vec![0.0f64; n]; // search direction
        let mut ap = vec![0.0f64; n]; // A * p scratch buffer

        // --- Build preconditioner (if enabled) ---
        let inv_diag = if self.use_preconditioner {
            Some(Self::build_jacobi_preconditioner(matrix))
        } else {
            None
        };

        // --- r = b - A*x. Since x = 0, r = b ---
        r.copy_from_slice(rhs);

        // --- Convergence threshold: ||r||_2 < tol * ||b||_2 (relative) ---
        let b_norm = norm2_f64(rhs);
        let abs_tolerance = effective_tol * b_norm;

        // Handle zero RHS: the solution is the zero vector. (This also
        // avoids a meaningless relative threshold when ||b|| ~ 0.)
        if b_norm < f64::EPSILON {
            debug!("CG: zero RHS detected, returning zero solution");
            return Ok(SolverResult {
                solution: vec![0.0f32; n],
                iterations: 0,
                residual_norm: 0.0,
                wall_time: start_time.elapsed(),
                convergence_history: vec![],
                algorithm: Algorithm::CG,
            });
        }

        // With x = 0 this equals b_norm; kept separate for the divergence
        // check below.
        let initial_residual_norm = norm2_f64(&r);

        // --- z = M^{-1} * r ---
        match &inv_diag {
            Some(diag) => Self::apply_preconditioner(diag, &r, &mut z),
            None => z.copy_from_slice(&r),
        }

        // --- p = z ---
        p.copy_from_slice(&z);

        // --- rz = r . z ---
        let mut rz = dot_f64(&r, &z);

        // Capacity capped at 256 to avoid a huge up-front allocation when
        // max_iter is large; Vec grows as needed past that.
        let mut convergence_history = Vec::with_capacity(effective_max_iter.min(256));
        let mut converged = false;

        debug!(
            "CG: n={}, nnz={}, tol={:.2e}, max_iter={}, precond={}",
            n,
            matrix.nnz(),
            effective_tol,
            effective_max_iter,
            self.use_preconditioner,
        );

        // ===============================================================
        // Main CG loop (Hestenes-Stiefel)
        // ===============================================================
        for k in 0..effective_max_iter {
            // --- Budget: wall-time check ---
            if start_time.elapsed() > budget.max_time {
                warn!("CG: wall-time budget exhausted at iteration {k}");
                return Err(SolverError::BudgetExhausted {
                    reason: format!(
                        "wall-time limit {:?} exceeded at iteration {k}",
                        budget.max_time,
                    ),
                    elapsed: start_time.elapsed(),
                });
            }

            // --- Ap = A * p  (sparse matrix-vector product) ---
            matrix.spmv(&p, &mut ap);

            // --- alpha = rz / (p . Ap) ---
            let p_dot_ap = dot_f64(&p, &ap);

            // Guard: if p.Ap <= 0 the matrix is not SPD or we hit numerical
            // breakdown. (For a true SPD matrix, p.Ap > 0 for any p != 0.)
            if p_dot_ap <= 0.0 {
                warn!("CG: non-positive p.Ap = {p_dot_ap:.4e} at iteration {k}");
                return Err(SolverError::NumericalInstability {
                    iteration: k,
                    detail: format!("p.Ap = {p_dot_ap:.6e} <= 0; matrix may not be SPD",),
                });
            }

            let alpha = rz / p_dot_ap;

            // --- x = x + alpha * p ---
            axpy_f64(alpha, &p, &mut x);

            // --- r = r - alpha * Ap ---
            axpy_f64(-alpha, &ap, &mut r);

            // --- Convergence check: ||r||_2 < tol * ||b||_2 ---
            let r_norm = norm2_f64(&r);

            convergence_history.push(ConvergenceInfo {
                iteration: k,
                residual_norm: r_norm,
            });

            trace!(
                "CG iter {k}: ||r|| = {r_norm:.6e}, rel = {:.6e}",
                r_norm / b_norm,
            );

            if r_norm < abs_tolerance {
                converged = true;
                debug!(
                    "CG converged at iteration {k}: ||r|| = {r_norm:.6e}, \
                     rel = {:.6e}",
                    r_norm / b_norm,
                );
                break;
            }

            // --- Divergence detection ---
            // If ||r|| has grown by 10x from the initial residual, the
            // system is likely indefinite or ill-conditioned beyond rescue.
            // (The 10x factor is a heuristic cutoff, not a theoretical bound.)
            if r_norm > 10.0 * initial_residual_norm {
                warn!(
                    "CG: divergence at iteration {k}: ||r|| = {r_norm:.6e} \
                     > 10 * ||r_0|| = {:.6e}",
                    10.0 * initial_residual_norm,
                );
                return Err(SolverError::NumericalInstability {
                    iteration: k,
                    detail: format!(
                        "residual diverged: ||r|| = {r_norm:.6e} exceeds \
                         10x initial residual {initial_residual_norm:.6e}",
                    ),
                });
            }

            // --- z = M^{-1} * r ---
            match &inv_diag {
                Some(diag) => Self::apply_preconditioner(diag, &r, &mut z),
                None => z.copy_from_slice(&r),
            }

            // --- rz_new = r . z ---
            let rz_new = dot_f64(&r, &z);

            // Guard: stagnation when rz is near-zero. Note this deliberately
            // checks the *previous* inner product `rz` (not `rz_new`) because
            // `rz` is the denominator of the upcoming beta division.
            if rz.abs() < f64::EPSILON * f64::EPSILON {
                warn!("CG: rz near zero at iteration {k}, stagnation");
                return Err(SolverError::NumericalInstability {
                    iteration: k,
                    detail: format!("rz = {rz:.6e} is near zero; solver stagnated",),
                });
            }

            // --- beta = rz_new / rz ---
            let beta = rz_new / rz;

            // --- p = z + beta * p ---
            for i in 0..n {
                p[i] = z[i] + beta * p[i];
            }

            // --- rz = rz_new ---
            rz = rz_new;
        }

        let wall_time = start_time.elapsed();
        let final_residual = norm2_f64(&r);

        if !converged {
            debug!(
                "CG: non-convergence after {} iterations, ||r|| = {final_residual:.6e}",
                effective_max_iter,
            );
            return Err(SolverError::NonConvergence {
                iterations: effective_max_iter,
                residual: final_residual,
                tolerance: abs_tolerance,
            });
        }

        // Down-cast the f64 solution to f32 for SolverResult.
        let solution_f32: Vec<f32> = x.iter().map(|&v| v as f32).collect();

        Ok(SolverResult {
            solution: solution_f32,
            iterations: convergence_history.len(),
            residual_norm: final_residual,
            wall_time,
            convergence_history,
            algorithm: Algorithm::CG,
        })
    }
599}
600
601// ═══════════════════════════════════════════════════════════════════════════
602// SolverEngine trait implementation
603// ═══════════════════════════════════════════════════════════════════════════
604
impl SolverEngine for ConjugateGradientSolver {
    /// Solve `Ax = b` using the Conjugate Gradient method.
    ///
    /// Validates inputs first, then delegates to the internal CG loop.
    ///
    /// # Errors
    ///
    /// * [`SolverError::InvalidInput`] -- dimension mismatch or invalid params.
    /// * [`SolverError::NumericalInstability`] -- divergence or non-SPD matrix.
    /// * [`SolverError::NonConvergence`] -- iteration limit exceeded.
    /// * [`SolverError::BudgetExhausted`] -- wall-time limit exceeded.
    fn solve(
        &self,
        matrix: &CsrMatrix<f64>,
        rhs: &[f64],
        budget: &ComputeBudget,
    ) -> Result<SolverResult, SolverError> {
        self.validate(matrix, rhs)?;
        self.solve_inner(matrix, rhs, budget)
    }

    /// Estimate CG complexity from the sparsity profile.
    ///
    /// CG converges in `O(sqrt(kappa))` iterations, each costing `O(nnz)` for
    /// the SpMV plus `O(n)` for the vector updates.
    fn estimate_complexity(&self, profile: &SparsityProfile, n: usize) -> ComplexityEstimate {
        // Estimated iterations from condition number, clamped to
        // [1, max_iterations]. (`as usize` truncates the sqrt toward zero;
        // the `.max(1)` floor guarantees at least one iteration.)
        let est_iters = (profile.estimated_condition.sqrt() as usize)
            .max(1)
            .min(self.max_iterations);

        // FLOPs per iteration: 2*nnz (SpMV) + 6*n (dot products, axpy ops).
        let flops_per_iter = 2 * profile.nnz as u64 + 6 * n as u64;
        let estimated_flops = est_iters as u64 * flops_per_iter;

        // Memory: 5 vectors of length n (x, r, z, p, Ap) plus preconditioner
        // (one extra length-n f64 vector when enabled).
        let vec_bytes = n * std::mem::size_of::<f64>();
        let precond_bytes = if self.use_preconditioner {
            vec_bytes
        } else {
            0
        };
        let estimated_memory_bytes = 5 * vec_bytes + precond_bytes;

        ComplexityEstimate {
            algorithm: Algorithm::CG,
            estimated_flops,
            estimated_iterations: est_iters,
            estimated_memory_bytes,
            complexity_class: ComplexityClass::SqrtCondition,
        }
    }

    /// Return the algorithm identifier.
    fn algorithm(&self) -> Algorithm {
        Algorithm::CG
    }
}
661
662// ═══════════════════════════════════════════════════════════════════════════
663// Tests
664// ═══════════════════════════════════════════════════════════════════════════
665
666#[cfg(test)]
667mod tests {
668    use super::*;
669    use std::time::Duration;
670
671    /// Build a symmetric tridiagonal SPD matrix (f64):
672    ///   diag = 4.0, off-diag = -1.0
673    fn tridiagonal_spd(n: usize) -> CsrMatrix<f64> {
674        let mut entries = Vec::with_capacity(3 * n);
675        for i in 0..n {
676            if i > 0 {
677                entries.push((i, i - 1, -1.0f64));
678            }
679            entries.push((i, i, 4.0f64));
680            if i + 1 < n {
681                entries.push((i, i + 1, -1.0f64));
682            }
683        }
684        CsrMatrix::<f64>::from_coo(n, n, entries)
685    }
686
687    /// Build a diagonal matrix from the given values.
688    fn diagonal_matrix(diag: &[f64]) -> CsrMatrix<f64> {
689        let n = diag.len();
690        let entries: Vec<_> = diag.iter().enumerate().map(|(i, &v)| (i, i, v)).collect();
691        CsrMatrix::<f64>::from_coo(n, n, entries)
692    }
693
694    /// Build an identity matrix of dimension n.
695    fn identity(n: usize) -> CsrMatrix<f64> {
696        CsrMatrix::<f64>::identity(n)
697    }
698
699    fn default_budget() -> ComputeBudget {
700        ComputeBudget {
701            max_time: Duration::from_secs(30),
702            max_iterations: 10_000,
703            tolerance: 1e-10,
704        }
705    }
706
707    // -----------------------------------------------------------------
708    // dot_product_f64
709    // -----------------------------------------------------------------
710
711    #[test]
712    fn dot_product_f64_basic() {
713        let a = vec![1.0f32, 2.0, 3.0];
714        let b = vec![4.0f32, 5.0, 6.0];
715        let result = dot_product_f64(&a, &b);
716        assert!((result - 32.0).abs() < 1e-10);
717    }
718
719    #[test]
720    fn dot_product_f64_empty() {
721        assert!((dot_product_f64(&[], &[]) - 0.0).abs() < 1e-10);
722    }
723
724    #[test]
725    fn dot_product_f64_precision() {
726        let n = 10_000;
727        let a = vec![1.0f32; n];
728        let b = vec![1.0f32; n];
729        assert!((dot_product_f64(&a, &b) - n as f64).abs() < 1e-10);
730    }
731
732    #[test]
733    fn dot_product_f64_odd_length() {
734        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
735        let b = vec![5.0f32, 4.0, 3.0, 2.0, 1.0];
736        // 5 + 8 + 9 + 8 + 5 = 35
737        assert!((dot_product_f64(&a, &b) - 35.0).abs() < 1e-10);
738    }
739
740    // -----------------------------------------------------------------
741    // axpy
742    // -----------------------------------------------------------------
743
744    #[test]
745    fn axpy_basic() {
746        let x = vec![1.0f32, 2.0, 3.0];
747        let mut y = vec![10.0f32, 20.0, 30.0];
748        axpy(2.0, &x, &mut y);
749        assert_eq!(y, vec![12.0, 24.0, 36.0]);
750    }
751
752    #[test]
753    fn axpy_negative_alpha() {
754        let x = vec![1.0f32, 1.0, 1.0];
755        let mut y = vec![5.0f32, 5.0, 5.0];
756        axpy(-3.0, &x, &mut y);
757        assert_eq!(y, vec![2.0, 2.0, 2.0]);
758    }
759
760    // -----------------------------------------------------------------
761    // norm2
762    // -----------------------------------------------------------------
763
764    #[test]
765    fn norm2_basic() {
766        let x = vec![3.0f32, 4.0];
767        assert!((norm2(&x) - 5.0).abs() < 1e-10);
768    }
769
770    #[test]
771    fn norm2_zero() {
772        assert!((norm2(&vec![0.0f32; 5]) - 0.0).abs() < 1e-10);
773    }
774
775    // -----------------------------------------------------------------
776    // CG solver: convergence on well-conditioned systems
777    // -----------------------------------------------------------------
778
779    #[test]
780    fn cg_identity_matrix() {
781        let n = 5;
782        let matrix = identity(n);
783        let rhs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
784        let budget = default_budget();
785
786        let solver = ConjugateGradientSolver::new(1e-10, 100, false);
787        let result = solver.solve(&matrix, &rhs, &budget).unwrap();
788
789        for i in 0..n {
790            assert!(
791                (result.solution[i] as f64 - rhs[i]).abs() < 1e-5,
792                "x[{i}] = {} != {}",
793                result.solution[i],
794                rhs[i],
795            );
796        }
797        // Identity should converge in at most 1 iteration.
798        assert!(result.iterations <= 1);
799    }
800
801    #[test]
802    fn cg_diagonal_matrix() {
803        let diag = vec![2.0, 3.0, 5.0, 7.0];
804        let matrix = diagonal_matrix(&diag);
805        let rhs = vec![4.0, 9.0, 25.0, 49.0];
806        let budget = default_budget();
807
808        let solver = ConjugateGradientSolver::new(1e-10, 100, false);
809        let result = solver.solve(&matrix, &rhs, &budget).unwrap();
810
811        let expected = [2.0, 3.0, 5.0, 7.0];
812        for i in 0..4 {
813            assert!(
814                (result.solution[i] as f64 - expected[i]).abs() < 1e-4,
815                "x[{i}] = {} != {}",
816                result.solution[i],
817                expected[i],
818            );
819        }
820    }
821
822    #[test]
823    fn cg_tridiagonal_small() {
824        let n = 10;
825        let matrix = tridiagonal_spd(n);
826        let rhs = vec![1.0f64; n];
827        let budget = default_budget();
828
829        let solver = ConjugateGradientSolver::new(1e-8, 200, false);
830        let result = solver.solve(&matrix, &rhs, &budget).unwrap();
831
832        assert!(
833            result.residual_norm < 1e-6,
834            "residual = {}",
835            result.residual_norm,
836        );
837        assert!(
838            result.iterations <= n,
839            "took {} iterations for n={}",
840            result.iterations,
841            n,
842        );
843    }
844
845    #[test]
846    fn cg_tridiagonal_large() {
847        let n = 500;
848        let matrix = tridiagonal_spd(n);
849        let rhs: Vec<f64> = (0..n).map(|i| (i as f64 + 1.0) / n as f64).collect();
850        let budget = default_budget();
851
852        let solver = ConjugateGradientSolver::new(1e-8, 2000, false);
853        let result = solver.solve(&matrix, &rhs, &budget).unwrap();
854
855        assert!(
856            result.residual_norm < 1e-5,
857            "residual = {}",
858            result.residual_norm,
859        );
860    }
861
862    // -----------------------------------------------------------------
863    // Preconditioning
864    // -----------------------------------------------------------------
865
866    #[test]
867    fn cg_preconditioned_converges_faster() {
868        let n = 100;
869        let matrix = tridiagonal_spd(n);
870        let rhs = vec![1.0f64; n];
871        let budget = default_budget();
872
873        let no_precond = ConjugateGradientSolver::new(1e-8, 500, false);
874        let with_precond = ConjugateGradientSolver::new(1e-8, 500, true);
875
876        let result_no = no_precond.solve(&matrix, &rhs, &budget).unwrap();
877        let result_yes = with_precond.solve(&matrix, &rhs, &budget).unwrap();
878
879        assert!(result_no.residual_norm < 1e-6);
880        assert!(result_yes.residual_norm < 1e-6);
881
882        assert!(
883            result_yes.iterations <= result_no.iterations,
884            "preconditioned ({}) should use <= iterations than \
885             unpreconditioned ({})",
886            result_yes.iterations,
887            result_no.iterations,
888        );
889    }
890
891    // -----------------------------------------------------------------
892    // Zero RHS / empty system
893    // -----------------------------------------------------------------
894
895    #[test]
896    fn cg_zero_rhs() {
897        let matrix = tridiagonal_spd(5);
898        let rhs = vec![0.0f64; 5];
899        let budget = default_budget();
900
901        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
902        let result = solver.solve(&matrix, &rhs, &budget).unwrap();
903
904        assert_eq!(result.iterations, 0);
905        for &v in &result.solution {
906            assert!((v as f64).abs() < 1e-10);
907        }
908    }
909
910    #[test]
911    fn cg_empty_system() {
912        let matrix = CsrMatrix {
913            row_ptr: vec![0],
914            col_indices: vec![],
915            values: Vec::<f64>::new(),
916            rows: 0,
917            cols: 0,
918        };
919        let rhs: Vec<f64> = vec![];
920        let budget = default_budget();
921
922        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
923        let result = solver.solve(&matrix, &rhs, &budget).unwrap();
924
925        assert_eq!(result.iterations, 0);
926        assert!(result.solution.is_empty());
927    }
928
929    // -----------------------------------------------------------------
930    // Error cases
931    // -----------------------------------------------------------------
932
933    #[test]
934    fn cg_dimension_mismatch() {
935        let matrix = tridiagonal_spd(3);
936        let rhs = vec![1.0f64; 5];
937        let budget = default_budget();
938
939        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
940        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
941        assert!(matches!(err, SolverError::InvalidInput(_)));
942    }
943
944    #[test]
945    fn cg_non_square_matrix() {
946        let matrix = CsrMatrix {
947            row_ptr: vec![0, 1, 2],
948            col_indices: vec![0, 1],
949            values: vec![1.0f64, 1.0],
950            rows: 2,
951            cols: 3,
952        };
953        let rhs = vec![1.0f64; 2];
954        let budget = default_budget();
955
956        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
957        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
958        assert!(matches!(err, SolverError::InvalidInput(_)));
959    }
960
961    #[test]
962    fn cg_non_convergence() {
963        let n = 50;
964        let matrix = tridiagonal_spd(n);
965        let rhs = vec![1.0f64; n];
966        let budget = ComputeBudget {
967            max_time: Duration::from_secs(30),
968            max_iterations: 1,
969            tolerance: 1e-15,
970        };
971
972        let solver = ConjugateGradientSolver::new(1e-15, 1, false);
973        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
974        assert!(matches!(err, SolverError::NonConvergence { .. }));
975    }
976
977    #[test]
978    fn cg_budget_iteration_limit() {
979        let n = 50;
980        let matrix = tridiagonal_spd(n);
981        let rhs = vec![1.0f64; n];
982
983        // Solver allows 1000, but budget allows only 2.
984        let solver = ConjugateGradientSolver::new(1e-15, 1000, false);
985        let budget = ComputeBudget {
986            max_time: Duration::from_secs(60),
987            max_iterations: 2,
988            tolerance: 1e-15,
989        };
990
991        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
992        assert!(matches!(err, SolverError::NonConvergence { .. }));
993    }
994
995    // -----------------------------------------------------------------
996    // Convergence history
997    // -----------------------------------------------------------------
998
999    #[test]
1000    fn cg_convergence_history_populated() {
1001        let n = 20;
1002        let matrix = tridiagonal_spd(n);
1003        let rhs = vec![1.0f64; n];
1004        let budget = default_budget();
1005
1006        let solver = ConjugateGradientSolver::new(1e-10, 200, false);
1007        let result = solver.solve(&matrix, &rhs, &budget).unwrap();
1008
1009        assert!(!result.convergence_history.is_empty());
1010
1011        // Final history entry should match the reported residual.
1012        let last = result.convergence_history.last().unwrap();
1013        assert!((last.residual_norm - result.residual_norm).abs() < 1e-12);
1014    }
1015
1016    #[test]
1017    fn cg_algorithm_field() {
1018        let matrix = identity(3);
1019        let rhs = vec![1.0f64; 3];
1020        let budget = default_budget();
1021
1022        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
1023        let result = solver.solve(&matrix, &rhs, &budget).unwrap();
1024        assert_eq!(result.algorithm, Algorithm::CG);
1025    }
1026
1027    // -----------------------------------------------------------------
1028    // Verify Ax = b for computed solution
1029    // -----------------------------------------------------------------
1030
1031    #[test]
1032    fn cg_solution_satisfies_system() {
1033        let n = 20;
1034        let matrix = tridiagonal_spd(n);
1035        let rhs = vec![1.0f64; n];
1036        let budget = default_budget();
1037
1038        let solver = ConjugateGradientSolver::new(1e-10, 200, true);
1039        let result = solver.solve(&matrix, &rhs, &budget).unwrap();
1040
1041        // Up-cast solution back to f64 and compute Ax.
1042        let x_f64: Vec<f64> = result.solution.iter().map(|&v| v as f64).collect();
1043        let mut ax = vec![0.0f64; n];
1044        matrix.spmv(&x_f64, &mut ax);
1045
1046        let mut max_err: f64 = 0.0;
1047        for i in 0..n {
1048            let err = (ax[i] - rhs[i]).abs();
1049            if err > max_err {
1050                max_err = err;
1051            }
1052        }
1053
1054        assert!(
1055            max_err < 1e-4,
1056            "max |Ax - b| = {max_err:.6e}, expected < 1e-4",
1057        );
1058    }
1059
1060    // -----------------------------------------------------------------
1061    // Estimate complexity
1062    // -----------------------------------------------------------------
1063
1064    #[test]
1065    fn estimate_complexity_returns_cg() {
1066        let solver = ConjugateGradientSolver::new(1e-8, 500, true);
1067        let profile = SparsityProfile {
1068            rows: 100,
1069            cols: 100,
1070            nnz: 298,
1071            density: 0.0298,
1072            is_diag_dominant: true,
1073            estimated_spectral_radius: 0.5,
1074            estimated_condition: 100.0,
1075            is_symmetric_structure: true,
1076            avg_nnz_per_row: 2.98,
1077            max_nnz_per_row: 3,
1078        };
1079
1080        let est = solver.estimate_complexity(&profile, 100);
1081        assert_eq!(est.algorithm, Algorithm::CG);
1082        assert_eq!(est.complexity_class, ComplexityClass::SqrtCondition);
1083        assert!(est.estimated_iterations > 0);
1084        assert!(est.estimated_flops > 0);
1085        assert!(est.estimated_memory_bytes > 0);
1086    }
1087
1088    // -----------------------------------------------------------------
1089    // Accessors
1090    // -----------------------------------------------------------------
1091
1092    #[test]
1093    fn accessors() {
1094        let solver = ConjugateGradientSolver::new(1e-6, 500, true);
1095        assert!((solver.tolerance() - 1e-6).abs() < 1e-15);
1096        assert_eq!(solver.max_iterations(), 500);
1097        assert!(solver.use_preconditioner());
1098    }
1099}