// scirs2_sparse/iterative_solvers.rs

//! Enhanced iterative solvers for sparse linear systems
//!
//! This module provides production-quality iterative solvers with ndarray-based
//! interfaces, preconditioner abstractions, and sparse matrix utility functions.
//!
//! # Solvers
//!
//! - **CG**: Conjugate Gradient for symmetric positive definite (SPD) systems
//! - **BiCGSTAB**: Biconjugate Gradient Stabilized for general square systems
//! - **GMRES(m)**: Generalized Minimal Residual with restarts for general systems
//! - **Chebyshev**: Chebyshev iteration for SPD systems with known eigenvalue bounds
//!
//! # Preconditioners
//!
//! - **Jacobi**: Diagonal (inverse) preconditioner
//! - **ILU(0)**: Incomplete LU with zero fill-in
//! - **SSOR**: Symmetric Successive Over-Relaxation
//!
//! # Utility Functions
//!
//! - `estimate_spectral_radius`: Power iteration based spectral radius estimation
//! - `sparse_diagonal`: Extract diagonal of a sparse matrix
//! - `sparse_trace`: Compute trace of a sparse matrix
//! - `sparse_norm`: Compute Frobenius, infinity, or 1-norm of a sparse matrix
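//!
//! # Example
//!
//! A minimal sketch of solving an SPD system with CG. The triplet form of
//! `CsrMatrix::new` matches the one used in this module's tests; the import
//! paths are assumptions about the crate layout, so the block is marked
//! `ignore`:
//!
//! ```ignore
//! use scirs2_core::ndarray::Array1;
//! use scirs2_sparse::csr::CsrMatrix;
//! use scirs2_sparse::iterative_solvers::{cg, IterativeSolverConfig};
//!
//! // Build the SPD matrix [[4, -1], [-1, 4]] from (data, rows, cols) triplets.
//! let a = CsrMatrix::new(
//!     vec![4.0, -1.0, -1.0, 4.0],
//!     vec![0, 0, 1, 1],
//!     vec![0, 1, 0, 1],
//!     (2, 2),
//! )?;
//! let b = Array1::from_vec(vec![1.0, 2.0]);
//! let result = cg(&a, &b, &IterativeSolverConfig::default(), None)?;
//! assert!(result.converged);
//! ```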

use crate::csr::CsrMatrix;
use crate::error::{SparseError, SparseResult};
use scirs2_core::ndarray::{Array1, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign, SparseElement};
use std::fmt::Debug;
use std::iter::Sum;
use std::ops::{AddAssign, MulAssign};

// ---------------------------------------------------------------------------
// Configuration and result types
// ---------------------------------------------------------------------------

/// Configuration for iterative solvers.
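///
/// A minimal sketch of overriding a single field while keeping the other
/// defaults (standard struct-update syntax; nothing beyond what this module
/// defines):
///
/// ```ignore
/// let cfg = IterativeSolverConfig { tol: 1e-8, ..Default::default() };
/// ```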
#[derive(Debug, Clone)]
pub struct IterativeSolverConfig {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Relative convergence tolerance.
    pub tol: f64,
    /// Whether to print convergence information.
    pub verbose: bool,
}

impl Default for IterativeSolverConfig {
    fn default() -> Self {
        Self {
            max_iter: 1000,
            tol: 1e-10,
            verbose: false,
        }
    }
}

/// Result returned by an iterative solver.
#[derive(Debug, Clone)]
pub struct SolverResult<F> {
    /// The computed solution vector.
    pub solution: Array1<F>,
    /// Number of iterations performed.
    pub n_iter: usize,
    /// Final residual norm ||b - Ax||.
    pub residual_norm: F,
    /// Whether the solver converged within the tolerance.
    pub converged: bool,
}

// ---------------------------------------------------------------------------
// Preconditioner trait and implementations
// ---------------------------------------------------------------------------

/// Trait for preconditioners used with iterative solvers.
///
/// A preconditioner approximates `M^{-1}` so that `M^{-1} A` has a more
/// clustered spectrum, accelerating convergence.
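///
/// A minimal sketch of a custom implementation (an identity preconditioner,
/// shown purely for illustration):
///
/// ```ignore
/// struct IdentityPreconditioner;
///
/// impl<F: Float> Preconditioner<F> for IdentityPreconditioner {
///     fn apply(&self, r: &Array1<F>) -> SparseResult<Array1<F>> {
///         // M = I, so M^{-1} r = r.
///         Ok(r.clone())
///     }
/// }
/// ```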
pub trait Preconditioner<F: Float> {
    /// Apply the preconditioner to vector `r`, returning `M^{-1} r`.
    fn apply(&self, r: &Array1<F>) -> SparseResult<Array1<F>>;
}

/// Jacobi (diagonal) preconditioner.
///
/// Uses `M = diag(A)`, so `M^{-1} r = r ./ diag(A)`.
/// Effective when the matrix is diagonally dominant.
pub struct JacobiPreconditioner<F> {
    diagonal_inv: Array1<F>,
}

impl<F: Float + SparseElement + Debug> JacobiPreconditioner<F> {
    /// Create a Jacobi preconditioner from a CSR matrix.
    ///
    /// Returns an error if any diagonal element is zero or near-zero.
    pub fn new(matrix: &CsrMatrix<F>) -> SparseResult<Self> {
        let n = matrix.rows();
        if n != matrix.cols() {
            return Err(SparseError::ValueError(
                "Matrix must be square for Jacobi preconditioner".to_string(),
            ));
        }
        let mut diag_inv = Array1::zeros(n);
        for i in 0..n {
            let d = matrix.get(i, i);
            if d.abs() < F::epsilon() {
                return Err(SparseError::ValueError(format!(
                    "Zero diagonal element at row {i} prevents Jacobi preconditioner"
                )));
            }
            diag_inv[i] = F::sparse_one() / d;
        }
        Ok(Self {
            diagonal_inv: diag_inv,
        })
    }

    /// Create a Jacobi preconditioner from an explicit diagonal vector.
    pub fn from_diagonal(diagonal: Array1<F>) -> SparseResult<Self> {
        let n = diagonal.len();
        let mut diag_inv = Array1::zeros(n);
        for i in 0..n {
            if diagonal[i].abs() < F::epsilon() {
                return Err(SparseError::ValueError(format!(
                    "Zero diagonal element at position {i}"
                )));
            }
            diag_inv[i] = F::sparse_one() / diagonal[i];
        }
        Ok(Self {
            diagonal_inv: diag_inv,
        })
    }
}

impl<F: Float + SparseElement> Preconditioner<F> for JacobiPreconditioner<F> {
    fn apply(&self, r: &Array1<F>) -> SparseResult<Array1<F>> {
        if r.len() != self.diagonal_inv.len() {
            return Err(SparseError::DimensionMismatch {
                expected: self.diagonal_inv.len(),
                found: r.len(),
            });
        }
        Ok(r * &self.diagonal_inv)
    }
}

/// ILU(0) preconditioner (Incomplete LU with zero fill-in).
///
/// Computes an approximate factorization `A ~ L U` where L and U
/// retain only the sparsity pattern of A.
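///
/// A minimal usage sketch, pairing ILU(0) with CG as in this module's tests:
///
/// ```ignore
/// let pc = ILU0Preconditioner::new(&a)?;
/// let result = cg(&a, &b, &IterativeSolverConfig::default(), Some(&pc))?;
/// ```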
pub struct ILU0Preconditioner<F> {
    // Store the combined LU data in CSR-like arrays.
    l_data: Vec<F>,
    u_data: Vec<F>,
    l_indices: Vec<usize>,
    u_indices: Vec<usize>,
    l_indptr: Vec<usize>,
    u_indptr: Vec<usize>,
    n: usize,
}

impl<F: Float + NumAssign + Sum + Debug + SparseElement + 'static> ILU0Preconditioner<F> {
    /// Construct an ILU(0) preconditioner from a CSR matrix.
    pub fn new(matrix: &CsrMatrix<F>) -> SparseResult<Self> {
        let n = matrix.rows();
        if n != matrix.cols() {
            return Err(SparseError::ValueError(
                "Matrix must be square for ILU(0) preconditioner".to_string(),
            ));
        }

        // Copy matrix data for in-place modification
        let mut data = matrix.data.clone();
        let indices = matrix.indices.clone();
        let indptr = matrix.indptr.clone();

        // ILU(0) factorization (Gaussian elimination restricted to A's sparsity pattern)
        for k in 0..n {
            let k_diag_idx = find_csr_diag_index(&indices, &indptr, k)?;
            let k_diag = data[k_diag_idx];
            if k_diag.abs() < F::epsilon() {
                return Err(SparseError::ValueError(format!(
                    "Zero pivot at row {k} in ILU(0) factorization"
                )));
            }

            for i in (k + 1)..n {
                let row_start = indptr[i];
                let row_end = indptr[i + 1];

                // Find column k in row i
                let mut k_pos = None;
                for pos in row_start..row_end {
                    if indices[pos] == k {
                        k_pos = Some(pos);
                        break;
                    }
                    if indices[pos] > k {
                        break;
                    }
                }

                if let Some(ki_idx) = k_pos {
                    let mult = data[ki_idx] / k_diag;
                    data[ki_idx] = mult;

                    // Update remaining entries in row i that also appear in row k
                    let k_row_start = indptr[k];
                    let k_row_end = indptr[k + 1];

                    for kj_idx in k_row_start..k_row_end {
                        let j = indices[kj_idx];
                        if j <= k {
                            continue;
                        }
                        // Find position of column j in row i
                        for ij_idx in row_start..row_end {
                            if indices[ij_idx] == j {
                                let kj_val = data[kj_idx];
                                data[ij_idx] -= mult * kj_val;
                                break;
                            }
                        }
                    }
                }
            }
        }

        // Split into L (unit lower) and U (upper including diagonal)
        let mut l_data = Vec::new();
        let mut u_data = Vec::new();
        let mut l_indices = Vec::new();
        let mut u_indices = Vec::new();
        let mut l_indptr = vec![0usize];
        let mut u_indptr = vec![0usize];

        for i in 0..n {
            let row_start = indptr[i];
            let row_end = indptr[i + 1];
            for pos in row_start..row_end {
                let col = indices[pos];
                let val = data[pos];
                if col < i {
                    l_indices.push(col);
                    l_data.push(val);
                } else {
                    u_indices.push(col);
                    u_data.push(val);
                }
            }
            l_indptr.push(l_indices.len());
            u_indptr.push(u_indices.len());
        }

        Ok(Self {
            l_data,
            u_data,
            l_indices,
            u_indices,
            l_indptr,
            u_indptr,
            n,
        })
    }
}

impl<F: Float + NumAssign + Sum + Debug + SparseElement + 'static> Preconditioner<F>
    for ILU0Preconditioner<F>
{
    fn apply(&self, r: &Array1<F>) -> SparseResult<Array1<F>> {
        if r.len() != self.n {
            return Err(SparseError::DimensionMismatch {
                expected: self.n,
                found: r.len(),
            });
        }

        // Forward solve: L y = r  (L is unit-lower triangular)
        let mut y = Array1::zeros(self.n);
        for i in 0..self.n {
            y[i] = r[i];
            let start = self.l_indptr[i];
            let end = self.l_indptr[i + 1];
            for pos in start..end {
                let col = self.l_indices[pos];
                y[i] = y[i] - self.l_data[pos] * y[col];
            }
        }

        // Backward solve: U z = y
        let mut z = Array1::zeros(self.n);
        for i in (0..self.n).rev() {
            z[i] = y[i];
            let start = self.u_indptr[i];
            let end = self.u_indptr[i + 1];
            let mut diag_val = F::sparse_one();
            for pos in start..end {
                let col = self.u_indices[pos];
                if col == i {
                    diag_val = self.u_data[pos];
                } else if col > i {
                    z[i] = z[i] - self.u_data[pos] * z[col];
                }
            }
            z[i] /= diag_val;
        }

        Ok(z)
    }
}

/// SSOR (Symmetric Successive Over-Relaxation) preconditioner.
///
/// Uses the splitting `M = (D + omega L) D^{-1} (D + omega U) / (omega (2 - omega))`,
/// with relaxation parameter `omega in (0, 2)`.
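///
/// A minimal construction sketch (`omega = 1.0` reduces SSOR to symmetric
/// Gauss-Seidel; `new` takes the matrix by value, hence the clone):
///
/// ```ignore
/// let pc = SSORPreconditioner::new(a.clone(), 1.0)?;
/// let z = pc.apply(&r)?;
/// ```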
pub struct SSORPreconditioner<F> {
    omega: F,
    matrix: CsrMatrix<F>,
    diagonal: Vec<F>,
}

impl<F: Float + NumAssign + Sum + Debug + SparseElement + 'static> SSORPreconditioner<F> {
    /// Create an SSOR preconditioner.
    ///
    /// `omega` must lie in the open interval (0, 2).
    pub fn new(matrix: CsrMatrix<F>, omega: F) -> SparseResult<Self> {
        let two = F::from(2.0).ok_or_else(|| {
            SparseError::ValueError("Failed to convert 2.0 to float type".to_string())
        })?;
        if omega <= F::sparse_zero() || omega >= two {
            return Err(SparseError::ValueError(
                "SSOR omega must be in the open interval (0, 2)".to_string(),
            ));
        }
        let n = matrix.rows();
        if n != matrix.cols() {
            return Err(SparseError::ValueError(
                "Matrix must be square for SSOR preconditioner".to_string(),
            ));
        }
        let mut diagonal = vec![F::sparse_zero(); n];
        for i in 0..n {
            let d = matrix.get(i, i);
            if d.abs() < F::epsilon() {
                return Err(SparseError::ValueError(format!(
                    "Zero diagonal element at row {i} prevents SSOR preconditioner"
                )));
            }
            diagonal[i] = d;
        }
        Ok(Self {
            omega,
            matrix,
            diagonal,
        })
    }
}

impl<F: Float + NumAssign + Sum + Debug + SparseElement + 'static> Preconditioner<F>
    for SSORPreconditioner<F>
{
    fn apply(&self, r: &Array1<F>) -> SparseResult<Array1<F>> {
        let n = self.matrix.rows();
        if r.len() != n {
            return Err(SparseError::DimensionMismatch {
                expected: n,
                found: r.len(),
            });
        }

        // Forward sweep: (D + omega L) y = omega (2 - omega) r
        let two = F::from(2.0)
            .ok_or_else(|| SparseError::ValueError("Failed to convert 2.0 to float".to_string()))?;
        let scale = self.omega * (two - self.omega);
        let mut y = Array1::zeros(n);
        for i in 0..n {
            let mut sum = r[i] * scale;
            let range = self.matrix.row_range(i);
            let row_indices = &self.matrix.indices[range.clone()];
            let row_data = &self.matrix.data[range];
            for (idx, &col) in row_indices.iter().enumerate() {
                if col < i {
                    sum -= self.omega * row_data[idx] * y[col];
                }
            }
            y[i] = sum / self.diagonal[i];
        }

        // Diagonal scaling between the sweeps: z = D y. Combined with the
        // forward and backward sweeps, this applies
        // M^{-1} = omega (2 - omega) (D + omega U)^{-1} D (D + omega L)^{-1}.
        let mut z = Array1::zeros(n);
        for i in 0..n {
            z[i] = y[i] * self.diagonal[i];
        }

        // Backward sweep: (D + omega U) w = z
        let mut w = Array1::zeros(n);
        for i in (0..n).rev() {
            let mut sum = z[i];
            let range = self.matrix.row_range(i);
            let row_indices = &self.matrix.indices[range.clone()];
            let row_data = &self.matrix.data[range];
            for (idx, &col) in row_indices.iter().enumerate() {
                if col > i {
                    sum -= self.omega * row_data[idx] * w[col];
                }
            }
            w[i] = sum / self.diagonal[i];
        }

        Ok(w)
    }
}

// ---------------------------------------------------------------------------
// Sparse matrix-vector multiplication helper
// ---------------------------------------------------------------------------

/// Compute y = A * x  for a CSR matrix `A` and dense vector `x`.
fn spmv<F: Float + NumAssign + Sum + SparseElement + 'static>(
    a: &CsrMatrix<F>,
    x: &Array1<F>,
) -> SparseResult<Array1<F>> {
    let (m, n) = a.shape();
    if x.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: x.len(),
        });
    }
    let mut y = Array1::zeros(m);
    for i in 0..m {
        let range = a.row_range(i);
        let cols = &a.indices[range.clone()];
        let vals = &a.data[range];
        let mut acc = F::sparse_zero();
        for (idx, &col) in cols.iter().enumerate() {
            acc += vals[idx] * x[col];
        }
        y[i] = acc;
    }
    Ok(y)
}

/// Compute the dot product of two Array1 vectors.
#[inline]
fn dot_arr<F: Float + Sum>(a: &Array1<F>, b: &Array1<F>) -> F {
    a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum()
}

/// Compute the 2-norm of an Array1 vector.
#[inline]
fn norm2_arr<F: Float + Sum>(v: &Array1<F>) -> F {
    dot_arr(v, v).sqrt()
}

/// axpy: y = y + alpha * x
#[inline]
fn axpy<F: Float>(y: &mut Array1<F>, alpha: F, x: &Array1<F>) {
    for (yi, &xi) in y.iter_mut().zip(x.iter()) {
        *yi = *yi + alpha * xi;
    }
}

// ---------------------------------------------------------------------------
// Conjugate Gradient solver
// ---------------------------------------------------------------------------

/// Conjugate Gradient solver for symmetric positive definite systems.
///
/// Solves `A x = b` where `A` is SPD. Optionally accepts a preconditioner.
///
/// # Errors
///
/// Returns an error if dimensions are incompatible or the matrix is detected
/// to be non-positive-definite (non-positive `p^T A p`).
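///
/// # Example
///
/// A minimal sketch with a Jacobi preconditioner, mirroring this module's
/// tests:
///
/// ```ignore
/// let pc = JacobiPreconditioner::new(&a)?;
/// let cfg = IterativeSolverConfig::default();
/// let result = cg(&a, &b, &cfg, Some(&pc))?;
/// assert!(result.converged);
/// ```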
pub fn cg<F>(
    a: &CsrMatrix<F>,
    b: &Array1<F>,
    config: &IterativeSolverConfig,
    precond: Option<&dyn Preconditioner<F>>,
) -> SparseResult<SolverResult<F>>
where
    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
{
    let (m, n) = a.shape();
    if m != n {
        return Err(SparseError::ValueError(
            "CG requires a square matrix".to_string(),
        ));
    }
    if b.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: b.len(),
        });
    }

    let tol = F::from(config.tol).ok_or_else(|| {
        SparseError::ValueError("Failed to convert tolerance to float type".to_string())
    })?;

    let mut x = Array1::zeros(n);

    // r = b - A x  (x = 0 initially, so r = b)
    let mut r = b.clone();
    let bnorm = norm2_arr(b);
    if bnorm <= F::epsilon() {
        return Ok(SolverResult {
            solution: x,
            n_iter: 0,
            residual_norm: F::sparse_zero(),
            converged: true,
        });
    }

    let tolerance = tol * bnorm;

    // z = M^{-1} r
    let mut z = match precond {
        Some(pc) => pc.apply(&r)?,
        None => r.clone(),
    };

    let mut p = z.clone();
    let mut rz = dot_arr(&r, &z);

    for k in 0..config.max_iter {
        let ap = spmv(a, &p)?;
        let pap = dot_arr(&p, &ap);
        if pap <= F::sparse_zero() {
            return Ok(SolverResult {
                solution: x,
                n_iter: k,
                residual_norm: norm2_arr(&r),
                converged: false,
            });
        }

        let alpha = rz / pap;
        axpy(&mut x, alpha, &p);

        // r = r - alpha * ap
        axpy(&mut r, -alpha, &ap);

        let rnorm = norm2_arr(&r);
        if config.verbose {
            // Intentionally not printing; the flag is available for future use
        }
        if rnorm <= tolerance {
            return Ok(SolverResult {
                solution: x,
                n_iter: k + 1,
                residual_norm: rnorm,
                converged: true,
            });
        }

        z = match precond {
            Some(pc) => pc.apply(&r)?,
            None => r.clone(),
        };

        let rz_new = dot_arr(&r, &z);
        let beta = rz_new / rz;

        // p = z + beta * p
        for (pi, &zi) in p.iter_mut().zip(z.iter()) {
            *pi = zi + beta * *pi;
        }

        rz = rz_new;
    }

    let rnorm = norm2_arr(&r);
    Ok(SolverResult {
        solution: x,
        n_iter: config.max_iter,
        residual_norm: rnorm,
        converged: rnorm <= tolerance,
    })
}

// ---------------------------------------------------------------------------
// BiCGSTAB solver
// ---------------------------------------------------------------------------

/// Biconjugate Gradient Stabilized solver for general square systems.
///
/// Solves `A x = b` for non-symmetric `A`. This method is more stable
/// than vanilla BiCG and avoids the irregular convergence of CGS.
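///
/// # Example
///
/// A minimal sketch for a non-symmetric system, without a preconditioner:
///
/// ```ignore
/// let cfg = IterativeSolverConfig::default();
/// let result = bicgstab(&a, &b, &cfg, None)?;
/// assert!(result.converged);
/// ```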
pub fn bicgstab<F>(
    a: &CsrMatrix<F>,
    b: &Array1<F>,
    config: &IterativeSolverConfig,
    precond: Option<&dyn Preconditioner<F>>,
) -> SparseResult<SolverResult<F>>
where
    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
{
    let (m, n) = a.shape();
    if m != n {
        return Err(SparseError::ValueError(
            "BiCGSTAB requires a square matrix".to_string(),
        ));
    }
    if b.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: b.len(),
        });
    }

    let tol = F::from(config.tol).ok_or_else(|| {
        SparseError::ValueError("Failed to convert tolerance to float type".to_string())
    })?;

    let mut x = Array1::zeros(n);
    let mut r = b.clone();
    let bnorm = norm2_arr(b);
    if bnorm <= F::epsilon() {
        return Ok(SolverResult {
            solution: x,
            n_iter: 0,
            residual_norm: F::sparse_zero(),
            converged: true,
        });
    }
    let tolerance = tol * bnorm;

    let r_hat = r.clone(); // shadow residual

    let mut rho = F::sparse_one();
    let mut alpha = F::sparse_one();
    let mut omega = F::sparse_one();

    let mut v = Array1::zeros(n);
    let mut p = Array1::zeros(n);

    let ten_eps = F::epsilon()
        * F::from(10.0).ok_or_else(|| {
            SparseError::ValueError("Failed to convert 10.0 to float".to_string())
        })?;

    for k in 0..config.max_iter {
        let rho_new = dot_arr(&r_hat, &r);
        if rho_new.abs() < ten_eps {
            return Ok(SolverResult {
                solution: x,
                n_iter: k,
                residual_norm: norm2_arr(&r),
                converged: false,
            });
        }

        let beta = (rho_new / rho) * (alpha / omega);

        // p = r + beta * (p - omega * v)
        for i in 0..n {
            p[i] = r[i] + beta * (p[i] - omega * v[i]);
        }

        // Apply preconditioner
        let p_hat = match precond {
            Some(pc) => pc.apply(&p)?,
            None => p.clone(),
        };

        v = spmv(a, &p_hat)?;

        let den = dot_arr(&r_hat, &v);
        if den.abs() < ten_eps {
            return Ok(SolverResult {
                solution: x,
                n_iter: k,
                residual_norm: norm2_arr(&r),
                converged: false,
            });
        }
        alpha = rho_new / den;

        // s = r - alpha * v
        let mut s = r.clone();
        axpy(&mut s, -alpha, &v);

        let snorm = norm2_arr(&s);
        if snorm <= tolerance {
            axpy(&mut x, alpha, &p_hat);
            return Ok(SolverResult {
                solution: x,
                n_iter: k + 1,
                residual_norm: snorm,
                converged: true,
            });
        }

        // Apply preconditioner to s
        let s_hat = match precond {
            Some(pc) => pc.apply(&s)?,
            None => s.clone(),
        };

        let t = spmv(a, &s_hat)?;

        let tt = dot_arr(&t, &t);
        if tt < ten_eps {
            axpy(&mut x, alpha, &p_hat);
            return Ok(SolverResult {
                solution: x,
                n_iter: k + 1,
                residual_norm: snorm,
                converged: false,
            });
        }
        omega = dot_arr(&t, &s) / tt;

        // x = x + alpha * p_hat + omega * s_hat
        axpy(&mut x, alpha, &p_hat);
        axpy(&mut x, omega, &s_hat);

        // r = s - omega * t
        r = s;
        axpy(&mut r, -omega, &t);

        let rnorm = norm2_arr(&r);
        if rnorm <= tolerance {
            return Ok(SolverResult {
                solution: x,
                n_iter: k + 1,
                residual_norm: rnorm,
                converged: true,
            });
        }

        if omega.abs() < ten_eps {
            return Ok(SolverResult {
                solution: x,
                n_iter: k + 1,
                residual_norm: rnorm,
                converged: false,
            });
        }

        rho = rho_new;
    }

    let rnorm = norm2_arr(&r);
    Ok(SolverResult {
        solution: x,
        n_iter: config.max_iter,
        residual_norm: rnorm,
        converged: rnorm <= tolerance,
    })
}

// ---------------------------------------------------------------------------
// GMRES(m) solver
// ---------------------------------------------------------------------------

/// Restarted Generalized Minimal Residual solver for general square systems.
///
/// GMRES minimizes the residual over a Krylov subspace using Arnoldi
/// iteration with Givens rotations. The `restart` parameter controls
/// the dimension of the Krylov subspace before a restart.
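///
/// # Example
///
/// A minimal sketch using GMRES(30), a common default restart length:
///
/// ```ignore
/// let cfg = IterativeSolverConfig::default();
/// let result = gmres(&a, &b, &cfg, 30, None)?;
/// assert!(result.converged);
/// ```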
pub fn gmres<F>(
    a: &CsrMatrix<F>,
    b: &Array1<F>,
    config: &IterativeSolverConfig,
    restart: usize,
    precond: Option<&dyn Preconditioner<F>>,
) -> SparseResult<SolverResult<F>>
where
    F: Float + NumAssign + Sum + SparseElement + ScalarOperand + Debug + 'static,
{
    let (m_rows, n) = a.shape();
    if m_rows != n {
        return Err(SparseError::ValueError(
            "GMRES requires a square matrix".to_string(),
        ));
    }
    if b.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: b.len(),
        });
    }

    let tol = F::from(config.tol).ok_or_else(|| {
        SparseError::ValueError("Failed to convert tolerance to float type".to_string())
    })?;

    let restart_dim = restart.min(n).max(1);

    let mut x = Array1::zeros(n);
    let bnorm = norm2_arr(b);
    if bnorm <= F::epsilon() {
        return Ok(SolverResult {
            solution: x,
            n_iter: 0,
            residual_norm: F::sparse_zero(),
            converged: true,
        });
    }
    let tolerance = tol * bnorm;
    let ten_eps = F::epsilon()
        * F::from(10.0).ok_or_else(|| {
            SparseError::ValueError("Failed to convert 10.0 to float".to_string())
        })?;

    let mut total_iter = 0usize;

    // Outer restart loop
    while total_iter < config.max_iter {
        // Compute residual
        let ax = spmv(a, &x)?;
        let mut r = b.clone();
        axpy(&mut r, -F::sparse_one(), &ax);

        // Apply preconditioner
        r = match precond {
            Some(pc) => pc.apply(&r)?,
            None => r,
        };

        let mut rnorm = norm2_arr(&r);
        if rnorm <= tolerance {
            return Ok(SolverResult {
                solution: x,
                n_iter: total_iter,
                residual_norm: rnorm,
                converged: true,
            });
        }

        // Arnoldi basis
        let mut v_basis: Vec<Array1<F>> = Vec::with_capacity(restart_dim + 1);
        v_basis.push(&r / rnorm);

        // Hessenberg matrix entries h[i][j]; one column is filled per Arnoldi step
        let mut h = vec![vec![F::sparse_zero(); restart_dim]; restart_dim + 1];
        // Givens rotation cosines and sines
        let mut cs = vec![F::sparse_zero(); restart_dim];
        let mut sn = vec![F::sparse_zero(); restart_dim];
        // Right-hand side of the least-squares problem
        let mut g = vec![F::sparse_zero(); restart_dim + 1];
        g[0] = rnorm;

        let mut inner_iter = 0usize;
        while inner_iter < restart_dim && total_iter + inner_iter < config.max_iter {
            let j = inner_iter;

            // w = M^{-1} * A * v_j  (left preconditioning, matching the
            // preconditioned residual computed at the top of the restart loop)
            let mut w = spmv(a, &v_basis[j])?;
            w = match precond {
                Some(pc) => pc.apply(&w)?,
                None => w,
            };

            // Modified Gram-Schmidt orthogonalization
            for i in 0..=j {
                h[i][j] = dot_arr(&v_basis[i], &w);
                axpy(&mut w, -h[i][j], &v_basis[i]);
            }
            h[j + 1][j] = norm2_arr(&w);

            if h[j + 1][j] < ten_eps {
                // Lucky breakdown: residual is in the Krylov subspace
                inner_iter += 1;
                break;
            }

            v_basis.push(&w / h[j + 1][j]);

            // Apply previous Givens rotations to column j
            for i in 0..j {
                let temp = cs[i] * h[i][j] + sn[i] * h[i + 1][j];
                h[i + 1][j] = -sn[i] * h[i][j] + cs[i] * h[i + 1][j];
                h[i][j] = temp;
            }

            // Compute new Givens rotation for row j
            let (c_val, s_val, r_val) = givens_rotation(h[j][j], h[j + 1][j]);
            cs[j] = c_val;
            sn[j] = s_val;
            h[j][j] = r_val;
            h[j + 1][j] = F::sparse_zero();

            // Apply rotation to g
            let temp = cs[j] * g[j] + sn[j] * g[j + 1];
            g[j + 1] = -sn[j] * g[j] + cs[j] * g[j + 1];
            g[j] = temp;

            inner_iter += 1;

            rnorm = g[j + 1].abs();
            if rnorm <= tolerance {
                break;
            }
        }

        // Solve the upper triangular system H * y = g
        let m_dim = inner_iter;
        let mut y_vec = vec![F::sparse_zero(); m_dim];
        for i in (0..m_dim).rev() {
            y_vec[i] = g[i];
            for jj in (i + 1)..m_dim {
                y_vec[i] = y_vec[i] - h[i][jj] * y_vec[jj];
            }
            if h[i][i].abs() < ten_eps {
                // Skip near-zero diagonal; keep current best
                y_vec[i] = F::sparse_zero();
            } else {
                y_vec[i] /= h[i][i];
            }
        }

        // Update solution: x = x + V * y
        for (i, &yi) in y_vec.iter().enumerate() {
            axpy(&mut x, yi, &v_basis[i]);
        }

        total_iter += inner_iter;

        if rnorm <= tolerance {
            return Ok(SolverResult {
                solution: x,
                n_iter: total_iter,
                residual_norm: rnorm,
                converged: true,
            });
        }
    }

    let ax = spmv(a, &x)?;
    let mut r_final = b.clone();
    axpy(&mut r_final, -F::sparse_one(), &ax);
    let rnorm = norm2_arr(&r_final);

    Ok(SolverResult {
        solution: x,
        n_iter: total_iter,
        residual_norm: rnorm,
        converged: rnorm <= tolerance,
    })
}

/// Compute a Givens rotation (c, s, r) such that:
///   [ c  s ] [ a ] = [ r ]
///   [-s  c ] [ b ]   [ 0 ]
fn givens_rotation<F: Float + SparseElement>(a: F, b: F) -> (F, F, F) {
    let zero = F::sparse_zero();
    let one = F::sparse_one();
    if b == zero {
        let c = if a >= zero { one } else { -one };
        return (c, zero, a.abs());
    }
    if a == zero {
        let s = if b >= zero { one } else { -one };
        return (zero, s, b.abs());
    }
    if b.abs() > a.abs() {
        let tau = a / b;
        let s_sign = if b >= zero { one } else { -one };
        let s = s_sign / (one + tau * tau).sqrt();
        let c = s * tau;
        let r = b / s;
        (c, s, r)
    } else {
        let tau = b / a;
        let c_sign = if a >= zero { one } else { -one };
        let c = c_sign / (one + tau * tau).sqrt();
        let s = c * tau;
        let r = a / c;
        (c, s, r)
    }
}

// ---------------------------------------------------------------------------
// Chebyshev iteration
// ---------------------------------------------------------------------------

/// Chebyshev iteration for SPD systems with known eigenvalue bounds.
///
/// Accelerates stationary iteration using Chebyshev polynomials. Requires
/// estimates `lambda_min` and `lambda_max` of the smallest and largest
/// eigenvalues of `A`. The convergence rate depends on the ratio
/// `lambda_max / lambda_min`.
///
/// Unlike CG, Chebyshev iteration does not require inner products,
/// making it attractive for massively parallel environments.
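///
/// # Example
///
/// A minimal sketch using the SPD test matrix from this module (diagonal 4,
/// off-diagonal -1; its eigenvalues are 2 and 5, so the bounds below bracket
/// the spectrum):
///
/// ```ignore
/// let cfg = IterativeSolverConfig { max_iter: 200, tol: 1e-8, verbose: false };
/// let result = chebyshev(&a, &b, &cfg, 1.5, 5.5)?;
/// assert!(result.converged);
/// ```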
pub fn chebyshev<F>(
    a: &CsrMatrix<F>,
    b: &Array1<F>,
    config: &IterativeSolverConfig,
    lambda_min: F,
    lambda_max: F,
) -> SparseResult<SolverResult<F>>
where
    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
{
    let (m, n) = a.shape();
    if m != n {
        return Err(SparseError::ValueError(
            "Chebyshev iteration requires a square matrix".to_string(),
        ));
    }
    if b.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: b.len(),
        });
    }
    if lambda_min <= F::sparse_zero() || lambda_max <= F::sparse_zero() {
        return Err(SparseError::ValueError(
            "Eigenvalue bounds must be positive for Chebyshev iteration".to_string(),
        ));
    }
    if lambda_min >= lambda_max {
        return Err(SparseError::ValueError(
            "lambda_min must be strictly less than lambda_max".to_string(),
        ));
    }

    let tol = F::from(config.tol).ok_or_else(|| {
        SparseError::ValueError("Failed to convert tolerance to float type".to_string())
    })?;
    let two = F::from(2.0)
        .ok_or_else(|| SparseError::ValueError("Failed to convert 2.0 to float".to_string()))?;

    let bnorm = norm2_arr(b);
    if bnorm <= F::epsilon() {
        return Ok(SolverResult {
            solution: Array1::zeros(n),
            n_iter: 0,
            residual_norm: F::sparse_zero(),
            converged: true,
        });
    }
    let tolerance = tol * bnorm;

    // Chebyshev parameters
    let d = (lambda_max + lambda_min) / two;
    let c = (lambda_max - lambda_min) / two;

    let mut x = Array1::zeros(n);
    let mut r = b.clone();
    let mut rnorm = norm2_arr(&r);

    if rnorm <= tolerance {
        return Ok(SolverResult {
            solution: x,
            n_iter: 0,
            residual_norm: rnorm,
            converged: true,
        });
    }

    // First iteration: x_1 = x_0 + (1/d) * r_0
    let inv_d = F::sparse_one() / d;
    let mut p = Array1::zeros(n);
    for i in 0..n {
        p[i] = inv_d * r[i];
    }
    axpy(&mut x, F::sparse_one(), &p);

    let ax = spmv(a, &x)?;
    for i in 0..n {
        r[i] = b[i] - ax[i];
    }
    rnorm = norm2_arr(&r);

    if rnorm <= tolerance {
        return Ok(SolverResult {
            solution: x,
            n_iter: 1,
            residual_norm: rnorm,
            converged: true,
        });
    }

    // Subsequent iterations
    let mut alpha;
    let mut beta;
    let half = F::sparse_one() / two;

    // Three-term Chebyshev recurrence, with d and c as defined above:
    //   x_{k+1} = x_k + alpha_k (b - A x_k) + beta_k (x_k - x_{k-1})
    // where alpha_0 = 1/d, beta_0 = 0, and for k >= 1:
    //   beta_k  = (c * alpha_{k-1} / 2)^2
    //   alpha_k = 1 / (d - beta_k / alpha_{k-1})
    // The loop below folds the difference x_k - x_{k-1} into `p`
    // (p_k = alpha_k r_k + beta_k p_{k-1}, then x_{k+1} = x_k + p_k).

    let mut alpha_prev = inv_d;
    let c_half = c * half;

    for k in 1..config.max_iter {
        beta = (c_half * alpha_prev) * (c_half * alpha_prev);
        let denom = d - beta / alpha_prev;
        if denom.abs() < F::epsilon() {
            break;
        }
        alpha = F::sparse_one() / denom;

        // p = alpha * r + beta * p
        for i in 0..n {
            p[i] = alpha * r[i] + beta * p[i];
        }
        axpy(&mut x, F::sparse_one(), &p);

        let ax_k = spmv(a, &x)?;
        for i in 0..n {
            r[i] = b[i] - ax_k[i];
        }
        rnorm = norm2_arr(&r);

        if rnorm <= tolerance {
            return Ok(SolverResult {
                solution: x,
                n_iter: k + 1,
                residual_norm: rnorm,
                converged: true,
            });
        }

        alpha_prev = alpha;
    }

    Ok(SolverResult {
        solution: x,
        n_iter: config.max_iter,
        residual_norm: rnorm,
        converged: rnorm <= tolerance,
    })
}

// ---------------------------------------------------------------------------
// Sparse utility functions
// ---------------------------------------------------------------------------

/// Norm type for `sparse_norm`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormType {
    /// Frobenius norm: sqrt( sum |a_ij|^2 ).
    Frobenius,
    /// Infinity norm: max_i sum_j |a_ij|  (maximum absolute row sum).
    Inf,
    /// 1-norm: max_j sum_i |a_ij|  (maximum absolute column sum).
    One,
}

/// Estimate the spectral radius of `A` via power iteration.
///
/// Performs `n_iter` steps of the power method on `A` (falling back to 50
/// steps when `n_iter == 0`) and returns the absolute value of the final
/// Rayleigh quotient as the estimate.
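///
/// A minimal usage sketch; one common use is feeding the estimate to
/// `chebyshev` as an upper eigenvalue bound:
///
/// ```ignore
/// let rho = estimate_spectral_radius(&a, 100)?;
/// assert!(rho > 0.0);
/// ```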
pub fn estimate_spectral_radius<F>(a: &CsrMatrix<F>, n_iter: usize) -> SparseResult<F>
where
    F: Float + NumAssign + Sum + SparseElement + ScalarOperand + Debug + 'static,
{
    let (m, n) = a.shape();
    if m != n {
        return Err(SparseError::ValueError(
            "Matrix must be square to estimate spectral radius".to_string(),
        ));
    }
    if n == 0 {
        return Ok(F::sparse_zero());
    }

    // Initial vector: all ones, normalized
    let inv_sqrt_n = F::sparse_one()
        / F::from(n)
            .ok_or_else(|| {
                SparseError::ValueError("Failed to convert matrix size to float".to_string())
            })?
            .sqrt();
    let mut v = Array1::from_elem(n, inv_sqrt_n);

    let mut lambda = F::sparse_zero();
    let iters = if n_iter == 0 { 50 } else { n_iter };

    for _ in 0..iters {
        let w = spmv(a, &v)?;
        lambda = dot_arr(&v, &w);
        let wnorm = norm2_arr(&w);
        if wnorm < F::epsilon() {
            return Ok(F::sparse_zero());
        }
        v = &w / wnorm;
    }

    Ok(lambda.abs())
}

/// Extract the diagonal of a CSR matrix as an `Array1`.
pub fn sparse_diagonal<F>(a: &CsrMatrix<F>) -> Array1<F>
where
    F: Float + SparseElement,
{
    let (m, n) = a.shape();
    let dim = m.min(n);
    let mut diag = Array1::zeros(dim);
    for i in 0..dim {
        diag[i] = a.get(i, i);
    }
    diag
}

/// Compute the trace of a CSR matrix (sum of diagonal elements).
pub fn sparse_trace<F>(a: &CsrMatrix<F>) -> F
where
    F: Float + SparseElement,
{
    let (m, n) = a.shape();
    let dim = m.min(n);
    let mut tr = F::sparse_zero();
    for i in 0..dim {
        tr = tr + a.get(i, i);
    }
    tr
}

/// Compute a matrix norm of a CSR matrix.
///
/// Supports Frobenius, infinity (max row sum), and 1-norm (max column sum).
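///
/// A minimal usage sketch:
///
/// ```ignore
/// let fro = sparse_norm(&a, NormType::Frobenius);
/// let inf = sparse_norm(&a, NormType::Inf);
/// let one = sparse_norm(&a, NormType::One);
/// ```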
pub fn sparse_norm<F>(a: &CsrMatrix<F>, norm_type: NormType) -> F
where
    F: Float + NumAssign + SparseElement + AddAssign + MulAssign + 'static,
{
    match norm_type {
        NormType::Frobenius => {
            let mut sum_sq = F::sparse_zero();
            for &val in &a.data {
                sum_sq += val * val;
            }
            sum_sq.sqrt()
        }
        NormType::Inf => {
            let m = a.rows();
            let mut max_sum = F::sparse_zero();
            for i in 0..m {
                let range = a.row_range(i);
                let row_data = &a.data[range];
                let mut row_sum = F::sparse_zero();
                for &v in row_data {
                    let abs_v: F = v.abs();
                    row_sum += abs_v;
                }
                if row_sum > max_sum {
                    max_sum = row_sum;
                }
            }
            max_sum
        }
        NormType::One => {
            let n = a.cols();
            let mut col_sums = vec![F::sparse_zero(); n];
            for (&col, &val) in a.indices.iter().zip(a.data.iter()) {
                if col < n {
                    let abs_val: F = val.abs();
                    col_sums[col] += abs_val;
                }
            }
            let mut max_sum = F::sparse_zero();
            for &s in &col_sums {
                if s > max_sum {
                    max_sum = s;
                }
            }
            max_sum
        }
    }
}

// ---------------------------------------------------------------------------
// Internal helper
// ---------------------------------------------------------------------------

/// Find the index of the diagonal entry in a CSR row.
fn find_csr_diag_index(indices: &[usize], indptr: &[usize], row: usize) -> SparseResult<usize> {
    let start = indptr[row];
    let end = indptr[row + 1];
    for pos in start..end {
        if indices[pos] == row {
            return Ok(pos);
        }
    }
    Err(SparseError::ValueError(format!(
        "Missing diagonal element at row {row}"
    )))
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: build an SPD 3x3 matrix:
    ///   [4 -1 -1]
    ///   [-1  4 -1]
    ///   [-1 -1  4]
    fn spd_3x3() -> CsrMatrix<f64> {
        let rows = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let cols = vec![0, 1, 2, 0, 1, 2, 0, 1, 2];
        let data = vec![4.0, -1.0, -1.0, -1.0, 4.0, -1.0, -1.0, -1.0, 4.0];
        CsrMatrix::new(data, rows, cols, (3, 3)).expect("failed to build test matrix")
    }

    /// Helper: build a non-symmetric 3x3 matrix:
    ///   [5 -1  0]
    ///   [-2  5 -1]
    ///   [0  -2  5]
    fn nonsym_3x3() -> CsrMatrix<f64> {
        let rows = vec![0, 0, 1, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 2, 1, 2];
        let data = vec![5.0, -1.0, -2.0, 5.0, -1.0, -2.0, 5.0];
        CsrMatrix::new(data, rows, cols, (3, 3)).expect("failed to build test matrix")
    }

    /// Helper: build a larger 5x5 SPD tridiagonal matrix
    fn spd_5x5() -> CsrMatrix<f64> {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut data = Vec::new();
        for i in 0..5 {
            rows.push(i);
            cols.push(i);
            data.push(4.0);
            if i > 0 {
                rows.push(i);
                cols.push(i - 1);
                data.push(-1.0);
            }
            if i < 4 {
                rows.push(i);
                cols.push(i + 1);
                data.push(-1.0);
            }
        }
        CsrMatrix::new(data, rows, cols, (5, 5)).expect("failed to build test matrix")
    }

    fn rhs_3() -> Array1<f64> {
        Array1::from_vec(vec![1.0, 2.0, 3.0])
    }

    fn rhs_5() -> Array1<f64> {
        Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0])
    }

    fn verify_solution(a: &CsrMatrix<f64>, x: &Array1<f64>, b: &Array1<f64>, tol: f64) {
        let ax = spmv(a, x).expect("spmv failed in verification");
        for (i, (&axi, &bi)) in ax.iter().zip(b.iter()).enumerate() {
            assert!(
                (axi - bi).abs() < tol,
                "Mismatch at index {i}: Ax[{i}]={axi}, b[{i}]={bi}"
            );
        }
    }

    // --- CG tests ---

    #[test]
    fn test_cg_spd_3x3() {
        let a = spd_3x3();
        let b = rhs_3();
        let cfg = IterativeSolverConfig::default();
        let res = cg(&a, &b, &cfg, None).expect("CG failed");
        assert!(res.converged, "CG did not converge");
        verify_solution(&a, &res.solution, &b, 1e-8);
    }

    #[test]
    fn test_cg_spd_5x5() {
        let a = spd_5x5();
        let b = rhs_5();
        let cfg = IterativeSolverConfig::default();
        let res = cg(&a, &b, &cfg, None).expect("CG failed");
        assert!(res.converged);
        verify_solution(&a, &res.solution, &b, 1e-8);
    }

    #[test]
    fn test_cg_with_jacobi_precond() {
        let a = spd_3x3();
        let b = rhs_3();
        let pc = JacobiPreconditioner::new(&a).expect("Jacobi failed");
        let cfg = IterativeSolverConfig::default();
        let res = cg(&a, &b, &cfg, Some(&pc)).expect("CG + Jacobi failed");
        assert!(res.converged);
        verify_solution(&a, &res.solution, &b, 1e-8);
    }

    #[test]
    fn test_cg_zero_rhs() {
        let a = spd_3x3();
        let b = Array1::zeros(3);
        let cfg = IterativeSolverConfig::default();
        let res = cg(&a, &b, &cfg, None).expect("CG failed");
        assert!(res.converged);
        assert!(res.residual_norm <= 1e-14);
    }

    #[test]
    fn test_cg_dimension_mismatch() {
        let a = spd_3x3();
        let b = Array1::from_vec(vec![1.0, 2.0]);
        let cfg = IterativeSolverConfig::default();
        assert!(cg(&a, &b, &cfg, None).is_err());
    }

    // --- BiCGSTAB tests ---

    #[test]
    fn test_bicgstab_nonsym_3x3() {
        let a = nonsym_3x3();
        let b = rhs_3();
        let cfg = IterativeSolverConfig::default();
        let res = bicgstab(&a, &b, &cfg, None).expect("BiCGSTAB failed");
        assert!(res.converged, "BiCGSTAB did not converge");
        verify_solution(&a, &res.solution, &b, 1e-8);
    }

    #[test]
    fn test_bicgstab_spd_3x3() {
        let a = spd_3x3();
        let b = rhs_3();
        let cfg = IterativeSolverConfig::default();
        let res = bicgstab(&a, &b, &cfg, None).expect("BiCGSTAB failed");
        assert!(res.converged);
        verify_solution(&a, &res.solution, &b, 1e-8);
    }

    #[test]
    fn test_bicgstab_with_jacobi() {
        let a = nonsym_3x3();
        let b = rhs_3();
        let pc = JacobiPreconditioner::new(&a).expect("Jacobi failed");
        let cfg = IterativeSolverConfig::default();
        let res = bicgstab(&a, &b, &cfg, Some(&pc)).expect("BiCGSTAB + Jacobi failed");
        assert!(res.converged);
        verify_solution(&a, &res.solution, &b, 1e-8);
    }

    #[test]
    fn test_bicgstab_zero_rhs() {
        let a = nonsym_3x3();
        let b = Array1::zeros(3);
        let cfg = IterativeSolverConfig::default();
        let res = bicgstab(&a, &b, &cfg, None).expect("BiCGSTAB failed");
        assert!(res.converged);
    }

    // --- GMRES tests ---

    #[test]
    fn test_gmres_nonsym_3x3() {
        let a = nonsym_3x3();
        let b = rhs_3();
        let cfg = IterativeSolverConfig::default();
        let res = gmres(&a, &b, &cfg, 30, None).expect("GMRES failed");
        assert!(res.converged, "GMRES did not converge");
        verify_solution(&a, &res.solution, &b, 1e-8);
    }

    #[test]
    fn test_gmres_spd_5x5() {
        let a = spd_5x5();
        let b = rhs_5();
        let cfg = IterativeSolverConfig::default();
        let res = gmres(&a, &b, &cfg, 10, None).expect("GMRES failed");
        assert!(res.converged);
        verify_solution(&a, &res.solution, &b, 1e-8);
    }

    #[test]
    fn test_gmres_with_jacobi() {
        let a = nonsym_3x3();
        let b = rhs_3();
        let pc = JacobiPreconditioner::new(&a).expect("Jacobi failed");
        let cfg = IterativeSolverConfig::default();
        let res = gmres(&a, &b, &cfg, 30, Some(&pc)).expect("GMRES + Jacobi failed");
        assert!(res.converged);
        verify_solution(&a, &res.solution, &b, 1e-8);
    }

    #[test]
    fn test_gmres_restart_small() {
        // Test with very small restart value (forces outer restarts)
        let a = spd_5x5();
        let b = rhs_5();
        let cfg = IterativeSolverConfig {
            max_iter: 200,
            tol: 1e-8,
            verbose: false,
        };
        let res = gmres(&a, &b, &cfg, 2, None).expect("GMRES failed");
        assert!(res.converged, "GMRES(2) did not converge");
        verify_solution(&a, &res.solution, &b, 1e-6);
    }

    // --- Chebyshev tests ---

    #[test]
    fn test_chebyshev_spd_3x3() {
        let a = spd_3x3();
        let b = rhs_3();
        // Eigenvalues: 2 (once), 5 (twice). Use bounds that bracket them.
        let cfg = IterativeSolverConfig {
            max_iter: 200,
            tol: 1e-8,
            verbose: false,
        };
        let res = chebyshev(&a, &b, &cfg, 1.5, 5.5).expect("Chebyshev failed");
        assert!(res.converged, "Chebyshev did not converge");
        verify_solution(&a, &res.solution, &b, 1e-6);
    }

    #[test]
    fn test_chebyshev_spd_5x5() {
        let a = spd_5x5();
        let b = rhs_5();
        // Tridiagonal 4,-1,-1: eigenvalues in [4 - 2cos(pi/6), 4 + 2cos(pi/6)] ~ [2.27, 5.73]
        let cfg = IterativeSolverConfig {
            max_iter: 300,
            tol: 1e-8,
            verbose: false,
        };
        let res = chebyshev(&a, &b, &cfg, 2.0, 6.0).expect("Chebyshev failed");
        assert!(res.converged, "Chebyshev did not converge");
        verify_solution(&a, &res.solution, &b, 1e-6);
    }

    #[test]
    fn test_chebyshev_invalid_bounds() {
        let a = spd_3x3();
        let b = rhs_3();
        let cfg = IterativeSolverConfig::default();
        // lambda_min >= lambda_max should fail
        assert!(chebyshev(&a, &b, &cfg, 5.0, 3.0).is_err());
        // Negative lambda_min should fail
        assert!(chebyshev(&a, &b, &cfg, -1.0, 5.0).is_err());
    }

    // --- Preconditioner tests ---

    #[test]
    fn test_jacobi_from_matrix() {
        let a = spd_3x3();
        let pc = JacobiPreconditioner::new(&a).expect("Jacobi creation failed");
        let r = Array1::from_vec(vec![4.0, 8.0, 12.0]);
        let z = pc.apply(&r).expect("Jacobi apply failed");
        // diagonal is 4.0, so z = r/4
        assert!((z[0] - 1.0).abs() < 1e-12);
        assert!((z[1] - 2.0).abs() < 1e-12);
        assert!((z[2] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_ilu0_preconditioner() {
        let a = spd_3x3();
        let pc = ILU0Preconditioner::new(&a).expect("ILU0 creation failed");
        let r = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let z = pc.apply(&r).expect("ILU0 apply failed");
        // ILU(0) on a fully dense 3x3 pattern is the exact LU factorization,
        // so z = M^{-1} r = A^{-1} r, and therefore A z should reproduce r.
        let az = spmv(&a, &z).expect("spmv failed");
        for i in 0..3 {
            assert!(
                (az[i] - r[i]).abs() < 1e-10,
                "ILU0 did not produce exact inverse on dense matrix at index {i}"
            );
        }
    }

    #[test]
    fn test_ssor_preconditioner() {
        let a = spd_3x3();
        let pc = SSORPreconditioner::new(a.clone(), 1.0).expect("SSOR creation failed");
        let r = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let z = pc.apply(&r).expect("SSOR apply failed");
        // Just check that the output is finite and has the right length
        assert_eq!(z.len(), 3);
        for &val in z.iter() {
            assert!(val.is_finite(), "SSOR produced non-finite value");
        }
    }

    #[test]
    fn test_ssor_invalid_omega() {
        let a = spd_3x3();
        assert!(SSORPreconditioner::new(a.clone(), 0.0).is_err());
        assert!(SSORPreconditioner::new(a.clone(), 2.0).is_err());
        assert!(SSORPreconditioner::new(a.clone(), -0.5).is_err());
    }

    #[test]
    fn test_cg_with_ilu0() {
        let a = spd_3x3();
        let b = rhs_3();
        let pc = ILU0Preconditioner::new(&a).expect("ILU0 creation failed");
        let cfg = IterativeSolverConfig::default();
        let res = cg(&a, &b, &cfg, Some(&pc)).expect("CG + ILU0 failed");
        assert!(res.converged, "CG + ILU0 did not converge");
        verify_solution(&a, &res.solution, &b, 1e-8);
    }

    // --- Utility function tests ---

    #[test]
    fn test_sparse_diagonal() {
        let a = spd_3x3();
        let d = sparse_diagonal(&a);
        assert_eq!(d.len(), 3);
        assert!((d[0] - 4.0).abs() < 1e-12);
        assert!((d[1] - 4.0).abs() < 1e-12);
        assert!((d[2] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_sparse_trace() {
        let a = spd_3x3();
        let tr = sparse_trace(&a);
        assert!((tr - 12.0).abs() < 1e-12);
    }

    #[test]
    fn test_sparse_norm_frobenius() {
        let a = spd_3x3();
        // ||A||_F = sqrt(sum of squares of all elements)
        // 3*16 + 6*1 = 48+6 = 54,  sqrt(54) ~ 7.3484692...
        let nf = sparse_norm(&a, NormType::Frobenius);
        assert!((nf - 54.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_sparse_norm_inf() {
        let a = spd_3x3();
        // Each row sums to |4| + |-1| + |-1| = 6
        let ni = sparse_norm(&a, NormType::Inf);
        assert!((ni - 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_sparse_norm_one() {
        let a = spd_3x3();
        // Each column sums to |4| + |-1| + |-1| = 6
        let n1 = sparse_norm(&a, NormType::One);
        assert!((n1 - 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_estimate_spectral_radius() {
        let a = spd_3x3();
        // Eigenvalues of [[4,-1,-1],[-1,4,-1],[-1,-1,4]]: 2 (once), 5 (twice)
        // Spectral radius = 5.0
        let rho = estimate_spectral_radius(&a, 100).expect("spectral radius estimation failed");
        assert!(
            (rho - 5.0).abs() < 0.5,
            "Expected spectral radius near 5.0, got {rho}"
        );
    }

    #[test]
    fn test_sparse_diagonal_rectangular() {
        // Test diagonal extraction on non-square (4x3) matrix
        let rows = vec![0, 1, 2, 3];
        let cols = vec![0, 1, 2, 0];
        let data = vec![10.0, 20.0, 30.0, 99.0];
        let a = CsrMatrix::new(data, rows, cols, (4, 3)).expect("failed to build matrix");
        let d = sparse_diagonal(&a);
        assert_eq!(d.len(), 3);
        assert!((d[0] - 10.0).abs() < 1e-12);
        assert!((d[1] - 20.0).abs() < 1e-12);
        assert!((d[2] - 30.0).abs() < 1e-12);
    }

    #[test]
    fn test_solver_config_default() {
        let cfg = IterativeSolverConfig::default();
        assert_eq!(cfg.max_iter, 1000);
        assert!((cfg.tol - 1e-10).abs() < 1e-15);
        assert!(!cfg.verbose);
    }

    #[test]
    fn test_gmres_dimension_mismatch() {
        let a = spd_3x3();
        let b = Array1::from_vec(vec![1.0, 2.0]);
        let cfg = IterativeSolverConfig::default();
        assert!(gmres(&a, &b, &cfg, 10, None).is_err());
    }

    #[test]
    fn test_bicgstab_5x5() {
        let a = spd_5x5();
        let b = rhs_5();
        let cfg = IterativeSolverConfig::default();
        let res = bicgstab(&a, &b, &cfg, None).expect("BiCGSTAB failed");
        assert!(res.converged);
        verify_solution(&a, &res.solution, &b, 1e-8);
    }

    #[test]
    fn test_cg_with_ssor_precond() {
        let a = spd_5x5();
        let b = rhs_5();
        let pc = SSORPreconditioner::new(a.clone(), 1.2).expect("SSOR creation failed");
        let cfg = IterativeSolverConfig::default();
        let res = cg(&a, &b, &cfg, Some(&pc)).expect("CG + SSOR failed");
        assert!(res.converged);
        verify_solution(&a, &res.solution, &b, 1e-8);
    }

    #[test]
    fn test_nonsquare_matrix_error() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 0, 1];
        let data = vec![1.0, 2.0, 3.0];
        let a = CsrMatrix::new(data, rows, cols, (3, 2)).expect("failed to build matrix");
        let b = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let cfg = IterativeSolverConfig::default();
        assert!(cg(&a, &b, &cfg, None).is_err());
        assert!(bicgstab(&a, &b, &cfg, None).is_err());
        assert!(gmres(&a, &b, &cfg, 10, None).is_err());
    }
}
1741}