// ruvector_solver/cg.rs

//! Conjugate Gradient solver for symmetric positive-definite systems.
//!
//! Solves `Ax = b` where `A` is a symmetric positive-definite (SPD) sparse
//! matrix in CSR format. The algorithm converges in at most `n` iterations
//! for an `n x n` system, but in practice converges in
//! `O(sqrt(kappa) * log(1/eps))` iterations where `kappa = cond(A)`.
//!
//! # Algorithm
//!
//! Implements the Hestenes-Stiefel variant of Conjugate Gradient:
//!
//! ```text
//! r = b - A*x
//! z = M^{-1} * r        (preconditioner; z = r when disabled)
//! p = z
//! rz = r . z
//!
//! for k in 0..max_iterations:
//!     Ap = A * p
//!     alpha = rz / (p . Ap)
//!     x  = x + alpha * p
//!     r  = r - alpha * Ap
//!     if ||r||_2 < tolerance * ||b||_2:
//!         converged; break
//!     z  = M^{-1} * r
//!     rz_new = r . z
//!     beta = rz_new / rz
//!     p  = z + beta * p
//!     rz = rz_new
//! ```
//!
//! # Preconditioning
//!
//! When `use_preconditioner` is `true`, a diagonal (Jacobi) preconditioner is
//! applied: `M = diag(A)`, so that `z_i = r_i / A_{ii}`. This reduces the
//! effective condition number for diagonally-dominant systems without adding
//! significant per-iteration cost.
//!
//! # Numerical precision
//!
//! All dot products and norm computations use `f64` accumulation even though
//! the matrix may store `f32` values, preventing catastrophic cancellation in
//! the inner products that drive the CG recurrence.
//!
//! # Convergence
//!
//! Theoretical bound:
//! `||x_k - x*||_A <= 2 * ((sqrt(kappa) - 1)/(sqrt(kappa) + 1))^k * ||x_0 - x*||_A`
//! where `kappa` is the 2-condition number of `A`.

51use std::time::Instant;
52
53use tracing::{debug, trace, warn};
54
55use crate::error::{SolverError, ValidationError};
56use crate::traits::SolverEngine;
57use crate::types::{
58    Algorithm, ComplexityClass, ComplexityEstimate, ComputeBudget, ConvergenceInfo, CsrMatrix,
59    SolverResult, SparsityProfile,
60};
61
// ═══════════════════════════════════════════════════════════════════════════
// Helper functions -- f64-accumulated linear algebra primitives
// ═══════════════════════════════════════════════════════════════════════════

/// Compute the dot product of two `f32` slices using `f64` accumulation.
///
/// Uses a 4-wide accumulator to exploit instruction-level parallelism and
/// reduce the dependency chain length, preventing precision loss in the
/// inner products that drive the CG recurrence. When the length is not a
/// multiple of 4, the tail elements are folded into the first accumulator.
///
/// # Panics
///
/// Panics (in all build profiles, not just debug) if `a.len() != b.len()`.
#[inline]
pub fn dot_product_f64(a: &[f32], b: &[f32]) -> f64 {
    assert_eq!(a.len(), b.len(), "dot_product_f64: length mismatch");

    let n = a.len();
    let chunks = n / 4;
    let remainder = n % 4;

    let mut acc0: f64 = 0.0;
    let mut acc1: f64 = 0.0;
    let mut acc2: f64 = 0.0;
    let mut acc3: f64 = 0.0;

    for i in 0..chunks {
        let j = i * 4;
        acc0 += a[j] as f64 * b[j] as f64;
        acc1 += a[j + 1] as f64 * b[j + 1] as f64;
        acc2 += a[j + 2] as f64 * b[j + 2] as f64;
        acc3 += a[j + 3] as f64 * b[j + 3] as f64;
    }

    // Tail (< 4 elements): accumulate into lane 0.
    let base = chunks * 4;
    for i in 0..remainder {
        acc0 += a[base + i] as f64 * b[base + i] as f64;
    }

    // Pairwise combine to keep the final reduction balanced.
    (acc0 + acc1) + (acc2 + acc3)
}
103
/// Dot product of two `f64` slices, accumulated in four independent lanes.
///
/// The 4-wide lane split shortens the floating-point dependency chain; the
/// tail elements (length not a multiple of 4) are folded into lane 0, and
/// the lanes are combined pairwise, matching the accumulation order of
/// [`dot_product_f64`].
///
/// Panics if the slices differ in length.
#[inline]
fn dot_f64(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "dot_f64: length mismatch");

    let mut lanes = [0.0f64; 4];

    let mut quads_a = a.chunks_exact(4);
    let mut quads_b = b.chunks_exact(4);
    for (qa, qb) in quads_a.by_ref().zip(quads_b.by_ref()) {
        for lane in 0..4 {
            lanes[lane] += qa[lane] * qb[lane];
        }
    }

    // Fold the leftover elements into the first lane.
    for (&xa, &xb) in quads_a.remainder().iter().zip(quads_b.remainder()) {
        lanes[0] += xa * xb;
    }

    (lanes[0] + lanes[1]) + (lanes[2] + lanes[3])
}
135
/// Compute `y[i] += alpha * x[i]` for all `i` (AXPY operation, `f32`).
///
/// Unrolled 4-wide; the tail is handled by a scalar loop. Since each element
/// is updated independently, the unrolling does not change the results.
///
/// # Panics
///
/// Panics (in all build profiles, not just debug) if `x.len() != y.len()`.
#[inline]
pub fn axpy(alpha: f32, x: &[f32], y: &mut [f32]) {
    assert_eq!(x.len(), y.len(), "axpy: length mismatch");

    let n = x.len();
    let chunks = n / 4;
    let base = chunks * 4;

    for i in 0..chunks {
        let j = i * 4;
        y[j] += alpha * x[j];
        y[j + 1] += alpha * x[j + 1];
        y[j + 2] += alpha * x[j + 2];
        y[j + 3] += alpha * x[j + 3];
    }
    // Scalar tail for the last n % 4 elements.
    for i in base..n {
        y[i] += alpha * x[i];
    }
}
160
/// In-place AXPY for `f64` slices: `y[i] += alpha * x[i]`.
///
/// Each element is updated independently, so a simple fused iteration over
/// both slices is sufficient; the optimizer is free to unroll/vectorize.
///
/// Panics if the slices differ in length.
#[inline]
fn axpy_f64(alpha: f64, x: &[f64], y: &mut [f64]) {
    assert_eq!(x.len(), y.len(), "axpy_f64: length mismatch");

    for (yi, &xi) in y.iter_mut().zip(x) {
        *yi += alpha * xi;
    }
}
181
182/// Compute the L2 norm of an `f32` slice using `f64` accumulation.
183///
184/// Returns `sqrt(sum(x_i^2))` computed entirely in `f64`.
185#[inline]
186pub fn norm2(x: &[f32]) -> f64 {
187    dot_product_f64(x, x).sqrt()
188}
189
190/// Compute the L2 norm of an `f64` slice.
191#[inline]
192fn norm2_f64(x: &[f64]) -> f64 {
193    dot_f64(x, x).sqrt()
194}
195
// ═══════════════════════════════════════════════════════════════════════════
// ConjugateGradientSolver
// ═══════════════════════════════════════════════════════════════════════════

/// Conjugate Gradient solver for symmetric positive-definite sparse systems.
///
/// Stores only the solver configuration (tolerance, iteration cap,
/// preconditioning flag); none of the fields are mutated during a solve.
/// The solve itself is stateless and may be invoked concurrently on different
/// inputs from multiple threads.
#[derive(Debug, Clone)]
pub struct ConjugateGradientSolver {
    /// Relative residual convergence tolerance.
    ///
    /// The solver declares convergence when `||r||_2 < tolerance * ||b||_2`.
    /// Validated (positive, finite) at solve time, not at construction.
    tolerance: f64,

    /// Maximum number of CG iterations before declaring non-convergence.
    /// Validated (>= 1) at solve time, not at construction.
    max_iterations: usize,

    /// Whether to apply diagonal (Jacobi) preconditioning.
    ///
    /// When `true`, the preconditioner `M = diag(A)` is used, computing
    /// `z_i = r_i / A_{ii}` each iteration. Beneficial for diagonally-dominant
    /// systems at the cost of O(n) extra work per iteration.
    use_preconditioner: bool,
}
222
223impl ConjugateGradientSolver {
224    /// Create a new CG solver.
225    ///
226    /// # Arguments
227    ///
228    /// * `tolerance` -- Relative residual threshold for convergence. Must be
229    ///   positive and finite.
230    /// * `max_iterations` -- Upper bound on CG iterations. Must be >= 1.
231    /// * `use_preconditioner` -- Enable diagonal (Jacobi) preconditioning.
232    pub fn new(tolerance: f64, max_iterations: usize, use_preconditioner: bool) -> Self {
233        Self {
234            tolerance,
235            max_iterations,
236            use_preconditioner,
237        }
238    }
239
    /// Return the configured relative residual tolerance.
    #[inline]
    pub fn tolerance(&self) -> f64 {
        self.tolerance
    }

    /// Return the configured maximum number of CG iterations.
    #[inline]
    pub fn max_iterations(&self) -> usize {
        self.max_iterations
    }

    /// Return whether Jacobi (diagonal) preconditioning is enabled.
    #[inline]
    pub fn use_preconditioner(&self) -> bool {
        self.use_preconditioner
    }
257
258    // -------------------------------------------------------------------
259    // Input validation
260    // -------------------------------------------------------------------
261
262    /// Validate inputs before entering the CG loop.
263    fn validate(
264        &self,
265        matrix: &CsrMatrix<f64>,
266        rhs: &[f64],
267    ) -> Result<(), SolverError> {
268        if matrix.rows != matrix.cols {
269            return Err(SolverError::InvalidInput(
270                ValidationError::DimensionMismatch(format!(
271                    "CG requires a square matrix but got {}x{}",
272                    matrix.rows, matrix.cols,
273                )),
274            ));
275        }
276
277        if rhs.len() != matrix.rows {
278            return Err(SolverError::InvalidInput(
279                ValidationError::DimensionMismatch(format!(
280                    "rhs length {} does not match matrix rows {}",
281                    rhs.len(),
282                    matrix.rows,
283                )),
284            ));
285        }
286
287        if matrix.row_ptr.len() != matrix.rows + 1 {
288            return Err(SolverError::InvalidInput(
289                ValidationError::DimensionMismatch(format!(
290                    "row_ptr length {} does not equal rows + 1 = {}",
291                    matrix.row_ptr.len(),
292                    matrix.rows + 1,
293                )),
294            ));
295        }
296
297        if !self.tolerance.is_finite() || self.tolerance <= 0.0 {
298            return Err(SolverError::InvalidInput(
299                ValidationError::ParameterOutOfRange {
300                    name: "tolerance".into(),
301                    value: self.tolerance.to_string(),
302                    expected: "positive finite value".into(),
303                },
304            ));
305        }
306
307        if self.max_iterations == 0 {
308            return Err(SolverError::InvalidInput(
309                ValidationError::ParameterOutOfRange {
310                    name: "max_iterations".into(),
311                    value: "0".into(),
312                    expected: ">= 1".into(),
313                },
314            ));
315        }
316
317        Ok(())
318    }
319
320    // -------------------------------------------------------------------
321    // Jacobi preconditioner
322    // -------------------------------------------------------------------
323
324    /// Build the Jacobi (diagonal) preconditioner from `A`.
325    ///
326    /// Returns `inv_diag[i] = 1.0 / A_{ii}`. Zero or near-zero diagonal
327    /// entries are replaced with `1.0` to prevent division by zero.
328    fn build_jacobi_preconditioner(matrix: &CsrMatrix<f64>) -> Vec<f64> {
329        let n = matrix.rows;
330        let mut inv_diag = vec![1.0f64; n];
331
332        for row in 0..n {
333            let start = matrix.row_ptr[row];
334            let end = matrix.row_ptr[row + 1];
335            for idx in start..end {
336                if matrix.col_indices[idx] == row {
337                    let diag_val = matrix.values[idx];
338                    if diag_val.abs() > f64::EPSILON {
339                        inv_diag[row] = 1.0 / diag_val;
340                    }
341                    break;
342                }
343            }
344        }
345
346        inv_diag
347    }
348
349    /// Apply the diagonal preconditioner: `z[i] = inv_diag[i] * r[i]`.
350    #[inline]
351    fn apply_preconditioner(inv_diag: &[f64], r: &[f64], z: &mut [f64]) {
352        assert_eq!(inv_diag.len(), r.len());
353        assert_eq!(r.len(), z.len());
354
355        let n = r.len();
356        let chunks = n / 4;
357        let base = chunks * 4;
358
359        for i in 0..chunks {
360            let j = i * 4;
361            z[j] = inv_diag[j] * r[j];
362            z[j + 1] = inv_diag[j + 1] * r[j + 1];
363            z[j + 2] = inv_diag[j + 2] * r[j + 2];
364            z[j + 3] = inv_diag[j + 3] * r[j + 3];
365        }
366        for i in base..n {
367            z[i] = inv_diag[i] * r[i];
368        }
369    }
370
    // -------------------------------------------------------------------
    // Core CG algorithm
    // -------------------------------------------------------------------

    /// Core CG algorithm implementation.
    ///
    /// Works entirely in `f64` precision internally. The final solution is
    /// down-cast to `f32` in the returned [`SolverResult`] to match the
    /// crate's output contract.
    ///
    /// # Errors
    ///
    /// * [`SolverError::BudgetExhausted`] -- wall-time limit hit mid-loop.
    /// * [`SolverError::NumericalInstability`] -- `p.Ap <= 0` (matrix likely
    ///   not SPD), residual growth past 10x the initial residual, or a
    ///   near-zero `rz` (stagnation).
    /// * [`SolverError::NonConvergence`] -- iteration cap reached without
    ///   meeting the tolerance.
    fn solve_inner(
        &self,
        matrix: &CsrMatrix<f64>,
        rhs: &[f64],
        budget: &ComputeBudget,
    ) -> Result<SolverResult, SolverError> {
        let start_time = Instant::now();
        let n = matrix.rows;

        // Effective limits: take the tighter of our config and the budget.
        let effective_max_iter = self.max_iterations.min(budget.max_iterations);
        let effective_tol = self.tolerance.min(budget.tolerance);

        // --- Trivial case: zero-dimensional system ---
        if n == 0 {
            return Ok(SolverResult {
                solution: vec![],
                iterations: 0,
                residual_norm: 0.0,
                wall_time: start_time.elapsed(),
                convergence_history: vec![],
                algorithm: Algorithm::CG,
            });
        }

        // --- Allocate working vectors (all f64) ---
        let mut x = vec![0.0f64; n]; // solution (initial guess = zero)
        let mut r = vec![0.0f64; n]; // residual
        let mut z = vec![0.0f64; n]; // preconditioned residual
        let mut p = vec![0.0f64; n]; // search direction
        let mut ap = vec![0.0f64; n]; // A * p scratch buffer

        // --- Build preconditioner (if enabled) ---
        let inv_diag = if self.use_preconditioner {
            Some(Self::build_jacobi_preconditioner(matrix))
        } else {
            None
        };

        // --- r = b - A*x. Since x = 0, r = b ---
        r.copy_from_slice(rhs);

        // --- Convergence threshold: ||r||_2 < tol * ||b||_2 (relative) ---
        let b_norm = norm2_f64(rhs);
        let abs_tolerance = effective_tol * b_norm;

        // Handle zero RHS: the solution is the zero vector.
        if b_norm < f64::EPSILON {
            debug!("CG: zero RHS detected, returning zero solution");
            return Ok(SolverResult {
                solution: vec![0.0f32; n],
                iterations: 0,
                residual_norm: 0.0,
                wall_time: start_time.elapsed(),
                convergence_history: vec![],
                algorithm: Algorithm::CG,
            });
        }

        // With x0 = 0 this equals b_norm (r = b); kept as a separate binding
        // as the baseline for the divergence test inside the loop.
        let initial_residual_norm = norm2_f64(&r);

        // --- z = M^{-1} * r ---
        match &inv_diag {
            Some(diag) => Self::apply_preconditioner(diag, &r, &mut z),
            None => z.copy_from_slice(&r),
        }

        // --- p = z ---
        p.copy_from_slice(&z);

        // --- rz = r . z ---
        let mut rz = dot_f64(&r, &z);

        // Capacity capped at 256 to avoid a huge upfront allocation when
        // effective_max_iter is large; Vec grows as needed beyond that.
        let mut convergence_history =
            Vec::with_capacity(effective_max_iter.min(256));
        let mut converged = false;

        debug!(
            "CG: n={}, nnz={}, tol={:.2e}, max_iter={}, precond={}",
            n,
            matrix.nnz(),
            effective_tol,
            effective_max_iter,
            self.use_preconditioner,
        );

        // ===============================================================
        // Main CG loop (Hestenes-Stiefel)
        // ===============================================================
        for k in 0..effective_max_iter {
            // --- Budget: wall-time check ---
            if start_time.elapsed() > budget.max_time {
                warn!("CG: wall-time budget exhausted at iteration {k}");
                return Err(SolverError::BudgetExhausted {
                    reason: format!(
                        "wall-time limit {:?} exceeded at iteration {k}",
                        budget.max_time,
                    ),
                    elapsed: start_time.elapsed(),
                });
            }

            // --- Ap = A * p  (sparse matrix-vector product) ---
            matrix.spmv(&p, &mut ap);

            // --- alpha = rz / (p . Ap) ---
            let p_dot_ap = dot_f64(&p, &ap);

            // Guard: if p.Ap <= 0 the matrix is not SPD or we hit numerical
            // breakdown. Also protects the division below.
            if p_dot_ap <= 0.0 {
                warn!("CG: non-positive p.Ap = {p_dot_ap:.4e} at iteration {k}");
                return Err(SolverError::NumericalInstability {
                    iteration: k,
                    detail: format!(
                        "p.Ap = {p_dot_ap:.6e} <= 0; matrix may not be SPD",
                    ),
                });
            }

            let alpha = rz / p_dot_ap;

            // --- x = x + alpha * p ---
            axpy_f64(alpha, &p, &mut x);

            // --- r = r - alpha * Ap ---
            axpy_f64(-alpha, &ap, &mut r);

            // --- Convergence check: ||r||_2 < tol * ||b||_2 ---
            let r_norm = norm2_f64(&r);

            // Record every completed iteration; the final result reports
            // iterations = convergence_history.len().
            convergence_history.push(ConvergenceInfo {
                iteration: k,
                residual_norm: r_norm,
            });

            trace!(
                "CG iter {k}: ||r|| = {r_norm:.6e}, rel = {:.6e}",
                r_norm / b_norm,
            );

            if r_norm < abs_tolerance {
                converged = true;
                debug!(
                    "CG converged at iteration {k}: ||r|| = {r_norm:.6e}, \
                     rel = {:.6e}",
                    r_norm / b_norm,
                );
                break;
            }

            // --- Divergence detection ---
            // If ||r|| has grown by 10x from the initial residual, the
            // system is likely indefinite or ill-conditioned beyond rescue.
            if r_norm > 10.0 * initial_residual_norm {
                warn!(
                    "CG: divergence at iteration {k}: ||r|| = {r_norm:.6e} \
                     > 10 * ||r_0|| = {:.6e}",
                    10.0 * initial_residual_norm,
                );
                return Err(SolverError::NumericalInstability {
                    iteration: k,
                    detail: format!(
                        "residual diverged: ||r|| = {r_norm:.6e} exceeds \
                         10x initial residual {initial_residual_norm:.6e}",
                    ),
                });
            }

            // --- z = M^{-1} * r ---
            match &inv_diag {
                Some(diag) => Self::apply_preconditioner(diag, &r, &mut z),
                None => z.copy_from_slice(&r),
            }

            // --- rz_new = r . z ---
            let rz_new = dot_f64(&r, &z);

            // Guard: stagnation when rz is near-zero. Note this tests the
            // PREVIOUS rz -- the divisor of beta below -- not rz_new.
            if rz.abs() < f64::EPSILON * f64::EPSILON {
                warn!("CG: rz near zero at iteration {k}, stagnation");
                return Err(SolverError::NumericalInstability {
                    iteration: k,
                    detail: format!(
                        "rz = {rz:.6e} is near zero; solver stagnated",
                    ),
                });
            }

            // --- beta = rz_new / rz ---
            let beta = rz_new / rz;

            // --- p = z + beta * p ---
            for i in 0..n {
                p[i] = z[i] + beta * p[i];
            }

            // --- rz = rz_new ---
            rz = rz_new;
        }

        let wall_time = start_time.elapsed();
        // Recompute from r (same value as the last in-loop r_norm).
        let final_residual = norm2_f64(&r);

        if !converged {
            debug!(
                "CG: non-convergence after {} iterations, ||r|| = {final_residual:.6e}",
                effective_max_iter,
            );
            return Err(SolverError::NonConvergence {
                iterations: effective_max_iter,
                residual: final_residual,
                tolerance: abs_tolerance,
            });
        }

        // Down-cast the f64 solution to f32 for SolverResult.
        let solution_f32: Vec<f32> = x.iter().map(|&v| v as f32).collect();

        Ok(SolverResult {
            solution: solution_f32,
            iterations: convergence_history.len(),
            residual_norm: final_residual,
            wall_time,
            convergence_history,
            algorithm: Algorithm::CG,
        })
    }
608}
609
610// ═══════════════════════════════════════════════════════════════════════════
611// SolverEngine trait implementation
612// ═══════════════════════════════════════════════════════════════════════════
613
614impl SolverEngine for ConjugateGradientSolver {
615    /// Solve `Ax = b` using the Conjugate Gradient method.
616    ///
617    /// # Errors
618    ///
619    /// * [`SolverError::InvalidInput`] -- dimension mismatch or invalid params.
620    /// * [`SolverError::NumericalInstability`] -- divergence or non-SPD matrix.
621    /// * [`SolverError::NonConvergence`] -- iteration limit exceeded.
622    /// * [`SolverError::BudgetExhausted`] -- wall-time limit exceeded.
623    fn solve(
624        &self,
625        matrix: &CsrMatrix<f64>,
626        rhs: &[f64],
627        budget: &ComputeBudget,
628    ) -> Result<SolverResult, SolverError> {
629        self.validate(matrix, rhs)?;
630        self.solve_inner(matrix, rhs, budget)
631    }
632
633    /// Estimate CG complexity from the sparsity profile.
634    ///
635    /// CG converges in `O(sqrt(kappa))` iterations, each costing `O(nnz)` for
636    /// the SpMV plus `O(n)` for the vector updates.
637    fn estimate_complexity(
638        &self,
639        profile: &SparsityProfile,
640        n: usize,
641    ) -> ComplexityEstimate {
642        // Estimated iterations from condition number, clamped to max_iterations.
643        let est_iters = (profile.estimated_condition.sqrt() as usize)
644            .max(1)
645            .min(self.max_iterations);
646
647        // FLOPs per iteration: 2*nnz (SpMV) + 6*n (dot products, axpy ops).
648        let flops_per_iter = 2 * profile.nnz as u64 + 6 * n as u64;
649        let estimated_flops = est_iters as u64 * flops_per_iter;
650
651        // Memory: 5 vectors of length n (x, r, z, p, Ap) plus preconditioner.
652        let vec_bytes = n * std::mem::size_of::<f64>();
653        let precond_bytes = if self.use_preconditioner { vec_bytes } else { 0 };
654        let estimated_memory_bytes = 5 * vec_bytes + precond_bytes;
655
656        ComplexityEstimate {
657            algorithm: Algorithm::CG,
658            estimated_flops,
659            estimated_iterations: est_iters,
660            estimated_memory_bytes,
661            complexity_class: ComplexityClass::SqrtCondition,
662        }
663    }
664
    /// Return the algorithm identifier for this engine ([`Algorithm::CG`]).
    fn algorithm(&self) -> Algorithm {
        Algorithm::CG
    }
669}
670
671// ═══════════════════════════════════════════════════════════════════════════
672// Tests
673// ═══════════════════════════════════════════════════════════════════════════
674
675#[cfg(test)]
676mod tests {
677    use super::*;
678    use std::time::Duration;
679
680    /// Build a symmetric tridiagonal SPD matrix (f64):
681    ///   diag = 4.0, off-diag = -1.0
682    fn tridiagonal_spd(n: usize) -> CsrMatrix<f64> {
683        let mut entries = Vec::with_capacity(3 * n);
684        for i in 0..n {
685            if i > 0 {
686                entries.push((i, i - 1, -1.0f64));
687            }
688            entries.push((i, i, 4.0f64));
689            if i + 1 < n {
690                entries.push((i, i + 1, -1.0f64));
691            }
692        }
693        CsrMatrix::<f64>::from_coo(n, n, entries)
694    }
695
696    /// Build a diagonal matrix from the given values.
697    fn diagonal_matrix(diag: &[f64]) -> CsrMatrix<f64> {
698        let n = diag.len();
699        let entries: Vec<_> = diag
700            .iter()
701            .enumerate()
702            .map(|(i, &v)| (i, i, v))
703            .collect();
704        CsrMatrix::<f64>::from_coo(n, n, entries)
705    }
706
    /// Build an identity matrix of dimension n (delegates to `CsrMatrix::identity`).
    fn identity(n: usize) -> CsrMatrix<f64> {
        CsrMatrix::<f64>::identity(n)
    }
711
    /// Default compute budget for tests: generous wall-time and iteration
    /// limits so only genuinely broken solves exhaust them.
    fn default_budget() -> ComputeBudget {
        ComputeBudget {
            max_time: Duration::from_secs(30),
            max_iterations: 10_000,
            tolerance: 1e-10,
        }
    }
719
720    // -----------------------------------------------------------------
721    // dot_product_f64
722    // -----------------------------------------------------------------
723
724    #[test]
725    fn dot_product_f64_basic() {
726        let a = vec![1.0f32, 2.0, 3.0];
727        let b = vec![4.0f32, 5.0, 6.0];
728        let result = dot_product_f64(&a, &b);
729        assert!((result - 32.0).abs() < 1e-10);
730    }
731
732    #[test]
733    fn dot_product_f64_empty() {
734        assert!((dot_product_f64(&[], &[]) - 0.0).abs() < 1e-10);
735    }
736
737    #[test]
738    fn dot_product_f64_precision() {
739        let n = 10_000;
740        let a = vec![1.0f32; n];
741        let b = vec![1.0f32; n];
742        assert!((dot_product_f64(&a, &b) - n as f64).abs() < 1e-10);
743    }
744
745    #[test]
746    fn dot_product_f64_odd_length() {
747        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
748        let b = vec![5.0f32, 4.0, 3.0, 2.0, 1.0];
749        // 5 + 8 + 9 + 8 + 5 = 35
750        assert!((dot_product_f64(&a, &b) - 35.0).abs() < 1e-10);
751    }
752
753    // -----------------------------------------------------------------
754    // axpy
755    // -----------------------------------------------------------------
756
757    #[test]
758    fn axpy_basic() {
759        let x = vec![1.0f32, 2.0, 3.0];
760        let mut y = vec![10.0f32, 20.0, 30.0];
761        axpy(2.0, &x, &mut y);
762        assert_eq!(y, vec![12.0, 24.0, 36.0]);
763    }
764
765    #[test]
766    fn axpy_negative_alpha() {
767        let x = vec![1.0f32, 1.0, 1.0];
768        let mut y = vec![5.0f32, 5.0, 5.0];
769        axpy(-3.0, &x, &mut y);
770        assert_eq!(y, vec![2.0, 2.0, 2.0]);
771    }
772
773    // -----------------------------------------------------------------
774    // norm2
775    // -----------------------------------------------------------------
776
777    #[test]
778    fn norm2_basic() {
779        let x = vec![3.0f32, 4.0];
780        assert!((norm2(&x) - 5.0).abs() < 1e-10);
781    }
782
783    #[test]
784    fn norm2_zero() {
785        assert!((norm2(&vec![0.0f32; 5]) - 0.0).abs() < 1e-10);
786    }
787
    // -----------------------------------------------------------------
    // CG solver: convergence on well-conditioned systems
    // -----------------------------------------------------------------

    #[test]
    fn cg_identity_matrix() {
        // With A = I, the solution is b itself.
        let n = 5;
        let matrix = identity(n);
        let rhs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-10, 100, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        for i in 0..n {
            assert!(
                (result.solution[i] as f64 - rhs[i]).abs() < 1e-5,
                "x[{i}] = {} != {}",
                result.solution[i],
                rhs[i],
            );
        }
        // Identity should converge in at most 1 iteration.
        assert!(result.iterations <= 1);
    }

    #[test]
    fn cg_diagonal_matrix() {
        // For diagonal A, x[i] = b[i] / A_{ii}; here each quotient is d_i.
        let diag = vec![2.0, 3.0, 5.0, 7.0];
        let matrix = diagonal_matrix(&diag);
        let rhs = vec![4.0, 9.0, 25.0, 49.0];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-10, 100, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        let expected = [2.0, 3.0, 5.0, 7.0];
        for i in 0..4 {
            assert!(
                (result.solution[i] as f64 - expected[i]).abs() < 1e-4,
                "x[{i}] = {} != {}",
                result.solution[i],
                expected[i],
            );
        }
    }

    #[test]
    fn cg_tridiagonal_small() {
        // Small SPD system: CG must hit the tolerance within n iterations
        // (exact-arithmetic termination bound).
        let n = 10;
        let matrix = tridiagonal_spd(n);
        let rhs = vec![1.0f64; n];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 200, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        assert!(
            result.residual_norm < 1e-6,
            "residual = {}",
            result.residual_norm,
        );
        assert!(
            result.iterations <= n,
            "took {} iterations for n={}",
            result.iterations,
            n,
        );
    }
857
    #[test]
    fn cg_tridiagonal_large() {
        // Larger well-conditioned system with a non-uniform RHS; checks the
        // residual only (no iteration bound asserted here).
        let n = 500;
        let matrix = tridiagonal_spd(n);
        let rhs: Vec<f64> = (0..n).map(|i| (i as f64 + 1.0) / n as f64).collect();
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 2000, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        assert!(
            result.residual_norm < 1e-5,
            "residual = {}",
            result.residual_norm,
        );
    }
874
    // -----------------------------------------------------------------
    // Preconditioning
    // -----------------------------------------------------------------

    #[test]
    fn cg_preconditioned_converges_faster() {
        // Jacobi preconditioning should never need MORE iterations than the
        // unpreconditioned solve on this diagonally-dominant system.
        let n = 100;
        let matrix = tridiagonal_spd(n);
        let rhs = vec![1.0f64; n];
        let budget = default_budget();

        let no_precond = ConjugateGradientSolver::new(1e-8, 500, false);
        let with_precond = ConjugateGradientSolver::new(1e-8, 500, true);

        let result_no = no_precond.solve(&matrix, &rhs, &budget).unwrap();
        let result_yes = with_precond.solve(&matrix, &rhs, &budget).unwrap();

        // Both variants must still converge to the tolerance.
        assert!(result_no.residual_norm < 1e-6);
        assert!(result_yes.residual_norm < 1e-6);

        assert!(
            result_yes.iterations <= result_no.iterations,
            "preconditioned ({}) should use <= iterations than \
             unpreconditioned ({})",
            result_yes.iterations,
            result_no.iterations,
        );
    }
903
    // -----------------------------------------------------------------
    // Zero RHS / empty system
    // -----------------------------------------------------------------

    #[test]
    fn cg_zero_rhs() {
        // b = 0 short-circuits: zero solution, zero iterations.
        let matrix = tridiagonal_spd(5);
        let rhs = vec![0.0f64; 5];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        assert_eq!(result.iterations, 0);
        for &v in &result.solution {
            assert!((v as f64).abs() < 1e-10);
        }
    }

    #[test]
    fn cg_empty_system() {
        // A 0x0 system is trivially solved with an empty solution vector.
        let matrix = CsrMatrix {
            row_ptr: vec![0],
            col_indices: vec![],
            values: Vec::<f64>::new(),
            rows: 0,
            cols: 0,
        };
        let rhs: Vec<f64> = vec![];
        let budget = default_budget();

        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
        let result = solver.solve(&matrix, &rhs, &budget).unwrap();

        assert_eq!(result.iterations, 0);
        assert!(result.solution.is_empty());
    }
941
942    // -----------------------------------------------------------------
943    // Error cases
944    // -----------------------------------------------------------------
945
946    #[test]
947    fn cg_dimension_mismatch() {
948        let matrix = tridiagonal_spd(3);
949        let rhs = vec![1.0f64; 5];
950        let budget = default_budget();
951
952        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
953        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
954        assert!(matches!(err, SolverError::InvalidInput(_)));
955    }
956
957    #[test]
958    fn cg_non_square_matrix() {
959        let matrix = CsrMatrix {
960            row_ptr: vec![0, 1, 2],
961            col_indices: vec![0, 1],
962            values: vec![1.0f64, 1.0],
963            rows: 2,
964            cols: 3,
965        };
966        let rhs = vec![1.0f64; 2];
967        let budget = default_budget();
968
969        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
970        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
971        assert!(matches!(err, SolverError::InvalidInput(_)));
972    }
973
974    #[test]
975    fn cg_non_convergence() {
976        let n = 50;
977        let matrix = tridiagonal_spd(n);
978        let rhs = vec![1.0f64; n];
979        let budget = ComputeBudget {
980            max_time: Duration::from_secs(30),
981            max_iterations: 1,
982            tolerance: 1e-15,
983        };
984
985        let solver = ConjugateGradientSolver::new(1e-15, 1, false);
986        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
987        assert!(matches!(err, SolverError::NonConvergence { .. }));
988    }
989
990    #[test]
991    fn cg_budget_iteration_limit() {
992        let n = 50;
993        let matrix = tridiagonal_spd(n);
994        let rhs = vec![1.0f64; n];
995
996        // Solver allows 1000, but budget allows only 2.
997        let solver = ConjugateGradientSolver::new(1e-15, 1000, false);
998        let budget = ComputeBudget {
999            max_time: Duration::from_secs(60),
1000            max_iterations: 2,
1001            tolerance: 1e-15,
1002        };
1003
1004        let err = solver.solve(&matrix, &rhs, &budget).unwrap_err();
1005        assert!(matches!(err, SolverError::NonConvergence { .. }));
1006    }
1007
1008    // -----------------------------------------------------------------
1009    // Convergence history
1010    // -----------------------------------------------------------------
1011
1012    #[test]
1013    fn cg_convergence_history_populated() {
1014        let n = 20;
1015        let matrix = tridiagonal_spd(n);
1016        let rhs = vec![1.0f64; n];
1017        let budget = default_budget();
1018
1019        let solver = ConjugateGradientSolver::new(1e-10, 200, false);
1020        let result = solver.solve(&matrix, &rhs, &budget).unwrap();
1021
1022        assert!(!result.convergence_history.is_empty());
1023
1024        // Final history entry should match the reported residual.
1025        let last = result.convergence_history.last().unwrap();
1026        assert!((last.residual_norm - result.residual_norm).abs() < 1e-12);
1027    }
1028
1029    #[test]
1030    fn cg_algorithm_field() {
1031        let matrix = identity(3);
1032        let rhs = vec![1.0f64; 3];
1033        let budget = default_budget();
1034
1035        let solver = ConjugateGradientSolver::new(1e-8, 100, false);
1036        let result = solver.solve(&matrix, &rhs, &budget).unwrap();
1037        assert_eq!(result.algorithm, Algorithm::CG);
1038    }
1039
1040    // -----------------------------------------------------------------
1041    // Verify Ax = b for computed solution
1042    // -----------------------------------------------------------------
1043
1044    #[test]
1045    fn cg_solution_satisfies_system() {
1046        let n = 20;
1047        let matrix = tridiagonal_spd(n);
1048        let rhs = vec![1.0f64; n];
1049        let budget = default_budget();
1050
1051        let solver = ConjugateGradientSolver::new(1e-10, 200, true);
1052        let result = solver.solve(&matrix, &rhs, &budget).unwrap();
1053
1054        // Up-cast solution back to f64 and compute Ax.
1055        let x_f64: Vec<f64> = result.solution.iter().map(|&v| v as f64).collect();
1056        let mut ax = vec![0.0f64; n];
1057        matrix.spmv(&x_f64, &mut ax);
1058
1059        let mut max_err: f64 = 0.0;
1060        for i in 0..n {
1061            let err = (ax[i] - rhs[i]).abs();
1062            if err > max_err {
1063                max_err = err;
1064            }
1065        }
1066
1067        assert!(
1068            max_err < 1e-4,
1069            "max |Ax - b| = {max_err:.6e}, expected < 1e-4",
1070        );
1071    }
1072
1073    // -----------------------------------------------------------------
1074    // Estimate complexity
1075    // -----------------------------------------------------------------
1076
1077    #[test]
1078    fn estimate_complexity_returns_cg() {
1079        let solver = ConjugateGradientSolver::new(1e-8, 500, true);
1080        let profile = SparsityProfile {
1081            rows: 100,
1082            cols: 100,
1083            nnz: 298,
1084            density: 0.0298,
1085            is_diag_dominant: true,
1086            estimated_spectral_radius: 0.5,
1087            estimated_condition: 100.0,
1088            is_symmetric_structure: true,
1089            avg_nnz_per_row: 2.98,
1090            max_nnz_per_row: 3,
1091        };
1092
1093        let est = solver.estimate_complexity(&profile, 100);
1094        assert_eq!(est.algorithm, Algorithm::CG);
1095        assert_eq!(est.complexity_class, ComplexityClass::SqrtCondition);
1096        assert!(est.estimated_iterations > 0);
1097        assert!(est.estimated_flops > 0);
1098        assert!(est.estimated_memory_bytes > 0);
1099    }
1100
1101    // -----------------------------------------------------------------
1102    // Accessors
1103    // -----------------------------------------------------------------
1104
1105    #[test]
1106    fn accessors() {
1107        let solver = ConjugateGradientSolver::new(1e-6, 500, true);
1108        assert!((solver.tolerance() - 1e-6).abs() < 1e-15);
1109        assert_eq!(solver.max_iterations(), 500);
1110        assert!(solver.use_preconditioner());
1111    }
1112}