// scirs2_sparse/krylov/mod.rs

//! Advanced Krylov subspace eigensolvers
//!
//! This module provides production-quality implementations of:
//!
//! - **Implicitly Restarted Arnoldi Method (IRAM)** for general (non-symmetric) matrices
//! - **Thick-Restart Lanczos** for symmetric matrices
//! - **Shift-and-Invert mode** for computing interior eigenvalues
//! - **Harmonic Ritz extraction** for better interior eigenvalue approximations
//!
//! # References
//!
//! - Sorensen, D.C. (1992). "Implicit application of polynomial filters in a k-step
//!   Arnoldi method". SIAM J. Matrix Anal. Appl. 13(1), 357-385.
//! - Wu, K. & Simon, H. (2000). "Thick-restart Lanczos method for large symmetric
//!   eigenvalue problems". SIAM J. Matrix Anal. Appl. 22(2), 602-616.

17pub mod augmented;
18pub mod deflation;
19pub mod gmres_dr;
20pub mod recycled_krylov;
21
22pub use augmented::AugmentedKrylov;
23pub use deflation::HarmonicRitzDeflation;
24pub use gmres_dr::GmresDR;
25pub use recycled_krylov::RecycledGmres;
26
27use crate::csr::CsrMatrix;
28use crate::error::{SparseError, SparseResult};
29use crate::iterative_solvers::Preconditioner;
30use scirs2_core::ndarray::{Array1, Array2};
31use scirs2_core::numeric::{Float, NumAssign, SparseElement};
32use std::fmt::Debug;
33use std::iter::Sum;
34
35// ---------------------------------------------------------------------------
36// Configuration types
37// ---------------------------------------------------------------------------
38
/// Specifies which eigenvalues to target.
///
/// Shared selection criterion used by both the IRAM and Thick-Restart Lanczos
/// configurations to rank Ritz values. Eigenvalues are handled as
/// `(real, imag)` pairs by the selection routine, so magnitude-based variants
/// account for complex eigenvalues of non-symmetric problems.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum WhichEigenvalues {
    /// Largest magnitude eigenvalues (the default).
    #[default]
    LargestMagnitude,
    /// Smallest magnitude eigenvalues.
    SmallestMagnitude,
    /// Largest real part.
    LargestReal,
    /// Smallest real part.
    SmallestReal,
    /// Eigenvalues closest to a given shift (requires shift-and-invert).
    NearShift,
}
54
/// Configuration for the Implicitly Restarted Arnoldi Method.
///
/// Consumed by [`iram`]. Sensible defaults are provided via [`Default`].
#[derive(Debug, Clone)]
pub struct IramConfig {
    /// Number of eigenvalues to compute.
    pub n_eigenvalues: usize,
    /// Dimension of the Krylov subspace (must be > `n_eigenvalues` and
    /// <= the matrix dimension; both are validated by the solver).
    pub krylov_dim: usize,
    /// Maximum number of restart cycles.
    pub max_restarts: usize,
    /// Convergence tolerance; a Ritz pair counts as converged when its
    /// residual bound falls below `tol * (1 + |ritz value|)`.
    pub tol: f64,
    /// Which eigenvalues to target.
    pub which: WhichEigenvalues,
    /// Whether to use harmonic Ritz extraction (intended for interior
    /// eigenvalue approximations).
    pub harmonic_ritz: bool,
    /// Shift value for shift-and-invert mode; `None` runs standard mode.
    pub shift: Option<f64>,
    /// Whether to print convergence diagnostics.
    pub verbose: bool,
}
75
76impl Default for IramConfig {
77    fn default() -> Self {
78        Self {
79            n_eigenvalues: 6,
80            krylov_dim: 20,
81            max_restarts: 300,
82            tol: 1e-10,
83            which: WhichEigenvalues::LargestMagnitude,
84            harmonic_ritz: false,
85            shift: None,
86            verbose: false,
87        }
88    }
89}
90
/// Configuration for the Thick-Restart Lanczos method.
///
/// Intended for symmetric eigenproblems. Sensible defaults are provided via
/// [`Default`].
#[derive(Debug, Clone)]
pub struct ThickRestartLanczosConfig {
    /// Number of eigenvalues to compute.
    pub n_eigenvalues: usize,
    /// Maximum Lanczos basis size before restart.
    pub max_basis_size: usize,
    /// Maximum number of restart cycles.
    pub max_restarts: usize,
    /// Convergence tolerance.
    pub tol: f64,
    /// Which eigenvalues to compute (see [`WhichEigenvalues`]).
    pub which: WhichEigenvalues,
    /// Shift value for shift-and-invert mode (`None` = standard mode).
    pub shift: Option<f64>,
    /// Whether to print convergence diagnostics.
    pub verbose: bool,
}
109
110impl Default for ThickRestartLanczosConfig {
111    fn default() -> Self {
112        Self {
113            n_eigenvalues: 6,
114            max_basis_size: 30,
115            max_restarts: 300,
116            tol: 1e-10,
117            which: WhichEigenvalues::SmallestReal,
118            shift: None,
119            verbose: false,
120        }
121    }
122}
123
/// Result of a Krylov eigensolver computation.
///
/// Returned by the solvers in this module; check [`Self::converged`] (or
/// [`Self::n_converged`]) before trusting all requested eigenpairs.
#[derive(Debug, Clone)]
pub struct KrylovEigenResult<F> {
    /// Converged eigenvalues.
    pub eigenvalues: Array1<F>,
    /// Converged eigenvectors (stored column-wise; column `i` pairs with
    /// `eigenvalues[i]`).
    pub eigenvectors: Array2<F>,
    /// Number of restart cycles performed.
    pub restarts: usize,
    /// Total number of matrix-vector products (shift-and-invert solves count
    /// as one application of the operator).
    pub matvec_count: usize,
    /// Residual norms for each eigenpair.
    pub residual_norms: Vec<F>,
    /// Whether all requested eigenvalues converged.
    pub converged: bool,
    /// Number of converged eigenpairs.
    pub n_converged: usize,
}
142
143// ---------------------------------------------------------------------------
144// Dense linear algebra helpers
145// ---------------------------------------------------------------------------
146
147/// CSR matrix-vector product.
148fn csr_matvec<F>(a: &CsrMatrix<F>, x: &[F]) -> SparseResult<Vec<F>>
149where
150    F: Float + NumAssign + Sum + SparseElement + 'static,
151{
152    let (m, n) = a.shape();
153    if x.len() != n {
154        return Err(SparseError::DimensionMismatch {
155            expected: n,
156            found: x.len(),
157        });
158    }
159    let mut y = vec![F::sparse_zero(); m];
160    for i in 0..m {
161        let range = a.row_range(i);
162        let cols = &a.indices[range.clone()];
163        let vals = &a.data[range];
164        let mut acc = F::sparse_zero();
165        for (idx, &col) in cols.iter().enumerate() {
166            acc += vals[idx] * x[col];
167        }
168        y[i] = acc;
169    }
170    Ok(y)
171}
172
173#[inline]
174fn dot_vec<F: Float + Sum>(a: &[F], b: &[F]) -> F {
175    a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum()
176}
177
178#[inline]
179fn norm2_vec<F: Float + Sum>(v: &[F]) -> F {
180    dot_vec(v, v).sqrt()
181}
182
183fn normalise_vec<F: Float + Sum + SparseElement>(v: &mut [F]) -> F {
184    let nrm = norm2_vec(v);
185    if nrm > F::epsilon() {
186        let inv = F::sparse_one() / nrm;
187        for vi in v.iter_mut() {
188            *vi = *vi * inv;
189        }
190    }
191    nrm
192}
193
/// Dense symmetric eigensolver via Jacobi rotations.
/// Input: `a` (k x k) row-major symmetric. Returns sorted eigenvalues and column-major
/// eigenvectors.
///
/// Sweeps over all off-diagonal pairs (p, q), annihilating each entry with a
/// plane rotation, until the largest off-diagonal magnitude drops below a
/// small multiple of machine epsilon or `max_sweeps` is exhausted.
/// Eigenvalues are returned in ascending order; eigenvector `i` occupies the
/// contiguous slice `[i*k .. (i+1)*k]` of the second return value.
fn jacobi_eig<F>(a: &[F], k: usize) -> SparseResult<(Vec<F>, Vec<F>)>
where
    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
{
    let max_sweeps = 200;
    // Off-diagonal entries below this threshold are treated as zero.
    let tol_j = F::epsilon() * F::from(100.0).unwrap_or(F::sparse_one());

    // Working copy of the matrix; `vecs` accumulates rotations, starting at I.
    let mut mat = a.to_vec();
    let mut vecs = vec![F::sparse_zero(); k * k];
    for i in 0..k {
        vecs[i * k + i] = F::sparse_one();
    }

    for _sw in 0..max_sweeps {
        // Convergence check: largest off-diagonal magnitude in the upper triangle.
        let mut max_off = F::sparse_zero();
        for i in 0..k {
            for j in (i + 1)..k {
                let v = mat[i * k + j].abs();
                if v > max_off {
                    max_off = v;
                }
            }
        }
        if max_off < tol_j {
            break;
        }

        // Cyclic Jacobi sweep: rotate every (p, q) pair once.
        for p in 0..k {
            for q in (p + 1)..k {
                let apq = mat[p * k + q];
                if apq.abs() < tol_j {
                    continue;
                }
                let diff = mat[q * k + q] - mat[p * k + p];
                // Rotation angle that zeroes a_pq. For nearly-equal diagonal
                // entries the limit angle is pi/4; otherwise take the smaller
                // root t of t^2 + 2*tau*t - 1 = 0 and theta = atan(t).
                let theta = if diff.abs() < F::epsilon() {
                    F::from(std::f64::consts::FRAC_PI_4).unwrap_or(F::sparse_one())
                } else {
                    let tau = diff / (apq + apq);
                    let sign_tau = if tau >= F::sparse_zero() {
                        F::sparse_one()
                    } else {
                        -F::sparse_one()
                    };
                    let t = sign_tau / (tau.abs() + (F::sparse_one() + tau * tau).sqrt());
                    t.atan()
                };

                let (sin_t, cos_t) = (theta.sin(), theta.cos());
                let two = F::from(2.0).unwrap_or(F::sparse_one());

                // Apply the rotation to rows/columns p and q, preserving symmetry.
                for r in 0..k {
                    if r == p || r == q {
                        continue;
                    }
                    let arp = mat[r * k + p];
                    let arq = mat[r * k + q];
                    mat[r * k + p] = cos_t * arp - sin_t * arq;
                    mat[r * k + q] = sin_t * arp + cos_t * arq;
                    mat[p * k + r] = mat[r * k + p];
                    mat[q * k + r] = mat[r * k + q];
                }
                // Update the 2x2 pivot block; a_pq is annihilated by construction.
                let app = mat[p * k + p];
                let aqq = mat[q * k + q];
                let apq_old = apq;
                mat[p * k + p] =
                    cos_t * cos_t * app - two * sin_t * cos_t * apq_old + sin_t * sin_t * aqq;
                mat[q * k + q] =
                    sin_t * sin_t * app + two * sin_t * cos_t * apq_old + cos_t * cos_t * aqq;
                mat[p * k + q] = F::sparse_zero();
                mat[q * k + p] = F::sparse_zero();

                // Accumulate the same rotation into the eigenvector matrix
                // (row i of `vecs` holds eigenvector i before sorting).
                for r in 0..k {
                    let vp = vecs[p * k + r];
                    let vq = vecs[q * k + r];
                    vecs[p * k + r] = cos_t * vp - sin_t * vq;
                    vecs[q * k + r] = sin_t * vp + cos_t * vq;
                }
            }
        }
    }

    // Sort eigenvalues ascending and permute the eigenvectors to match.
    let mut evals: Vec<F> = (0..k).map(|i| mat[i * k + i]).collect();
    let mut perm: Vec<usize> = (0..k).collect();
    perm.sort_by(|&a_i, &b_i| {
        evals[a_i]
            .partial_cmp(&evals[b_i])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let sorted_evals: Vec<F> = perm.iter().map(|&i| evals[i]).collect();
    let mut sorted_vecs = vec![F::sparse_zero(); k * k];
    for (new_col, &old_col) in perm.iter().enumerate() {
        for r in 0..k {
            sorted_vecs[new_col * k + r] = vecs[old_col * k + r];
        }
    }
    evals = sorted_evals;
    Ok((evals, sorted_vecs))
}
295
296/// Compute upper Hessenberg eigenvalues using the implicit QR algorithm (real Schur).
297/// `h` is m x m row-major upper Hessenberg. Returns eigenvalues as (real, imag) pairs.
298fn hessenberg_eigenvalues<F>(h: &[F], m: usize) -> SparseResult<Vec<(F, F)>>
299where
300    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
301{
302    let mut mat = h.to_vec();
303    let max_iter = m * 100;
304    let tol = F::epsilon() * F::from(1000.0).unwrap_or(F::sparse_one());
305
306    let mut n_active = m;
307    let mut eigenvalues: Vec<(F, F)> = Vec::with_capacity(m);
308    let mut iter_count = 0;
309
310    while n_active > 0 && iter_count < max_iter {
311        iter_count += 1;
312
313        if n_active == 1 {
314            eigenvalues.push((mat[0], F::sparse_zero()));
315            break;
316        }
317
318        // Check for deflation on the sub-diagonal
319        let sub_idx = (n_active - 1) * m + (n_active - 2);
320        if mat[sub_idx].abs() < tol {
321            eigenvalues.push((mat[(n_active - 1) * m + (n_active - 1)], F::sparse_zero()));
322            n_active -= 1;
323            continue;
324        }
325
326        // Check for 2x2 block at bottom
327        if n_active == 2 {
328            let a11 = mat[0];
329            let a12 = mat[1];
330            let a21 = mat[m];
331            let a22 = mat[m + 1];
332            let trace = a11 + a22;
333            let det = a11 * a22 - a12 * a21;
334            let disc = trace * trace - F::from(4.0).unwrap_or(F::sparse_one()) * det;
335            let two = F::from(2.0).unwrap_or(F::sparse_one());
336            if disc >= F::sparse_zero() {
337                let sqrt_d = disc.sqrt();
338                eigenvalues.push(((trace + sqrt_d) / two, F::sparse_zero()));
339                eigenvalues.push(((trace - sqrt_d) / two, F::sparse_zero()));
340            } else {
341                let sqrt_d = (-disc).sqrt();
342                eigenvalues.push((trace / two, sqrt_d / two));
343                eigenvalues.push((trace / two, -sqrt_d / two));
344            }
345            break;
346        }
347
348        // Wilkinson shift from bottom-right 2x2 block
349        let p = n_active - 1;
350        let a_pp = mat[p * m + p];
351        let a_pm1 = mat[(p - 1) * m + (p - 1)];
352        let a_p_pm1 = mat[(p - 1) * m + p];
353        let a_pm1_p = mat[p * m + (p - 1)];
354        let trace_2x2 = a_pm1 + a_pp;
355        let det_2x2 = a_pm1 * a_pp - a_p_pm1 * a_pm1_p;
356        let disc = trace_2x2 * trace_2x2 - F::from(4.0).unwrap_or(F::sparse_one()) * det_2x2;
357        let two = F::from(2.0).unwrap_or(F::sparse_one());
358        let shift = if disc >= F::sparse_zero() {
359            let s1 = (trace_2x2 + disc.sqrt()) / two;
360            let s2 = (trace_2x2 - disc.sqrt()) / two;
361            // Pick the shift closer to a_pp
362            if (s1 - a_pp).abs() < (s2 - a_pp).abs() {
363                s1
364            } else {
365                s2
366            }
367        } else {
368            a_pp
369        };
370
371        // Apply shifted QR step using Givens rotations
372        // Shift: H <- H - sigma I
373        for i in 0..n_active {
374            mat[i * m + i] -= shift;
375        }
376
377        // QR factorisation via Givens rotations
378        let mut givens_c = vec![F::sparse_zero(); n_active - 1];
379        let mut givens_s = vec![F::sparse_zero(); n_active - 1];
380
381        for i in 0..(n_active - 1) {
382            let a_val = mat[i * m + i];
383            let b_val = mat[(i + 1) * m + i];
384            let r = (a_val * a_val + b_val * b_val).sqrt();
385            if r < F::epsilon() {
386                givens_c[i] = F::sparse_one();
387                givens_s[i] = F::sparse_zero();
388                continue;
389            }
390            let c = a_val / r;
391            let s = b_val / r;
392            givens_c[i] = c;
393            givens_s[i] = s;
394
395            // Apply rotation to rows i and i+1
396            for j in 0..n_active {
397                let t1 = mat[i * m + j];
398                let t2 = mat[(i + 1) * m + j];
399                mat[i * m + j] = c * t1 + s * t2;
400                mat[(i + 1) * m + j] = -s * t1 + c * t2;
401            }
402        }
403
404        // Accumulate R * Q (apply Givens from the right)
405        for i in 0..(n_active - 1) {
406            let c = givens_c[i];
407            let s = givens_s[i];
408            for j in 0..n_active {
409                let t1 = mat[j * m + i];
410                let t2 = mat[j * m + (i + 1)];
411                mat[j * m + i] = c * t1 + s * t2;
412                mat[j * m + (i + 1)] = -s * t1 + c * t2;
413            }
414        }
415
416        // Undo shift: H <- H + sigma I
417        for i in 0..n_active {
418            mat[i * m + i] += shift;
419        }
420    }
421
422    // If we exhausted iterations without full deflation, extract remaining
423    if eigenvalues.len() < m {
424        for i in eigenvalues.len()..m {
425            if i < n_active {
426                eigenvalues.push((mat[i * m + i], F::sparse_zero()));
427            }
428        }
429    }
430
431    Ok(eigenvalues)
432}
433
434/// Select eigenvalue indices based on `which` criterion. Returns indices into the
435/// eigenvalue vector sorted by preference.
436fn select_eigenvalues<F: Float + SparseElement>(
437    evals: &[(F, F)],
438    which: WhichEigenvalues,
439    shift: Option<F>,
440) -> Vec<usize> {
441    let mut indices: Vec<usize> = (0..evals.len()).collect();
442    indices.sort_by(|&a_i, &b_i| {
443        let (ra, ia) = evals[a_i];
444        let (rb, ib) = evals[b_i];
445        match which {
446            WhichEigenvalues::LargestMagnitude => {
447                let ma = ra * ra + ia * ia;
448                let mb = rb * rb + ib * ib;
449                mb.partial_cmp(&ma).unwrap_or(std::cmp::Ordering::Equal)
450            }
451            WhichEigenvalues::SmallestMagnitude => {
452                let ma = ra * ra + ia * ia;
453                let mb = rb * rb + ib * ib;
454                ma.partial_cmp(&mb).unwrap_or(std::cmp::Ordering::Equal)
455            }
456            WhichEigenvalues::LargestReal => {
457                rb.partial_cmp(&ra).unwrap_or(std::cmp::Ordering::Equal)
458            }
459            WhichEigenvalues::SmallestReal => {
460                ra.partial_cmp(&rb).unwrap_or(std::cmp::Ordering::Equal)
461            }
462            WhichEigenvalues::NearShift => {
463                let sigma = shift.unwrap_or(F::sparse_zero());
464                let da = (ra - sigma).abs();
465                let db = (rb - sigma).abs();
466                da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
467            }
468        }
469    });
470    indices
471}
472
473// ---------------------------------------------------------------------------
474// Shift-and-invert operator
475// ---------------------------------------------------------------------------
476
/// Solve (A - sigma I) x = b using a simple iterative approach (CG for symmetric,
/// GMRES-like for general). This is used internally for shift-and-invert mode.
///
/// Builds an Arnoldi factorisation of the shifted operator over at most
/// `min(max_iter, n, 50)` steps (no restart), solves the projected
/// least-squares problem `min ||beta*e1 - H y||` via Givens rotations and
/// back-substitution, and expands the result as `x = V_k y`.
///
/// NOTE(review): only the GMRES-like path exists in this body; there is no
/// separate CG branch for symmetric matrices despite the summary above.
fn shift_invert_solve<F>(
    a: &CsrMatrix<F>,
    sigma: F,
    b: &[F],
    max_iter: usize,
) -> SparseResult<Vec<F>>
where
    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
{
    let n = b.len();
    let tol = F::epsilon() * F::from(1000.0).unwrap_or(F::sparse_one());

    // Use GMRES(n) without restart for the shifted system
    let mut x = vec![F::sparse_zero(); n];
    // With x0 = 0 the initial residual is simply b.
    let mut r = b.to_vec();

    let beta = norm2_vec(&r);
    if beta < tol {
        // b is numerically zero, so x = 0 already solves the system.
        return Ok(x);
    }

    let max_k = max_iter.min(n).min(50);
    // Arnoldi basis V (column-major: n x (max_k+1))
    let mut v_basis = vec![F::sparse_zero(); n * (max_k + 1)];
    let inv_beta = F::sparse_one() / beta;
    // First basis vector: v_1 = r / ||r||.
    for i in 0..n {
        v_basis[i] = r[i] * inv_beta;
    }

    // Upper Hessenberg H: (max_k+1) x max_k
    let mut h_mat = vec![F::sparse_zero(); (max_k + 1) * max_k];
    let mut actual_k = 0;

    for j in 0..max_k {
        // w = (A - sigma I) * v_j
        let vj = &v_basis[j * n..(j + 1) * n];
        let mut w = csr_matvec(a, vj)?;
        for i in 0..n {
            w[i] -= sigma * vj[i];
        }

        // Arnoldi orthogonalisation (modified Gram-Schmidt: w is updated
        // in place after each projection).
        for i in 0..=j {
            let vi = &v_basis[i * n..(i + 1) * n];
            let h_ij = dot_vec(&w, vi);
            h_mat[i * max_k + j] = h_ij;
            for ii in 0..n {
                w[ii] -= h_ij * vi[ii];
            }
        }

        let h_jp1_j = norm2_vec(&w);
        h_mat[(j + 1) * max_k + j] = h_jp1_j;
        actual_k = j + 1;

        if h_jp1_j < tol {
            // Happy breakdown: the Krylov subspace is invariant, so the
            // projected solution is exact.
            break;
        }

        // Next basis vector: v_{j+1} = w / ||w||.
        let inv_h = F::sparse_one() / h_jp1_j;
        for i in 0..n {
            v_basis[(j + 1) * n + i] = w[i] * inv_h;
        }
    }

    // Solve the least squares problem: min ||beta * e1 - H_k * y||
    // Using Givens rotations on the (actual_k+1) x actual_k Hessenberg matrix
    let mut rhs = vec![F::sparse_zero(); actual_k + 1];
    rhs[0] = beta;

    // Work on a copy so h_mat itself stays untouched; each rotation zeroes
    // one sub-diagonal of H and is applied to the right-hand side too.
    let mut h_ls = h_mat.clone();
    for j in 0..actual_k {
        let a_val = h_ls[j * max_k + j];
        let b_val = h_ls[(j + 1) * max_k + j];
        let r_val = (a_val * a_val + b_val * b_val).sqrt();
        if r_val < F::epsilon() {
            continue;
        }
        let c = a_val / r_val;
        let s = b_val / r_val;

        for col in j..actual_k {
            let t1 = h_ls[j * max_k + col];
            let t2 = h_ls[(j + 1) * max_k + col];
            h_ls[j * max_k + col] = c * t1 + s * t2;
            h_ls[(j + 1) * max_k + col] = -s * t1 + c * t2;
        }
        let r1 = rhs[j];
        let r2 = rhs[j + 1];
        rhs[j] = c * r1 + s * r2;
        rhs[j + 1] = -s * r1 + c * r2;
    }

    // Back-substitution on the now upper-triangular system R y = rhs.
    let mut y = vec![F::sparse_zero(); actual_k];
    for j in (0..actual_k).rev() {
        let mut val = rhs[j];
        for col in (j + 1)..actual_k {
            val -= h_ls[j * max_k + col] * y[col];
        }
        let diag = h_ls[j * max_k + j];
        if diag.abs() < F::epsilon() {
            // (Near-)singular pivot: zero the component rather than divide.
            y[j] = F::sparse_zero();
        } else {
            y[j] = val / diag;
        }
    }

    // x = V_k * y
    for j in 0..actual_k {
        for i in 0..n {
            x[i] += v_basis[j * n + i] * y[j];
        }
    }

    Ok(x)
}
596
597// ---------------------------------------------------------------------------
598// Implicitly Restarted Arnoldi Method (IRAM)
599// ---------------------------------------------------------------------------
600
601/// Run the Implicitly Restarted Arnoldi Method for computing eigenvalues of
602/// a general (non-symmetric) sparse matrix.
603///
604/// # Arguments
605///
606/// * `a` - Sparse matrix in CSR format
607/// * `config` - Solver configuration
608/// * `initial_vector` - Optional starting vector
609///
610/// # Returns
611///
612/// A `KrylovEigenResult` containing eigenvalues, eigenvectors, and convergence info.
613pub fn iram<F>(
614    a: &CsrMatrix<F>,
615    config: &IramConfig,
616    initial_vector: Option<&Array1<F>>,
617) -> SparseResult<KrylovEigenResult<F>>
618where
619    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
620{
621    let (rows, cols) = a.shape();
622    if rows != cols {
623        return Err(SparseError::ValueError(
624            "IRAM requires a square matrix".to_string(),
625        ));
626    }
627    let n = rows;
628    let k = config.n_eigenvalues;
629    let m = config.krylov_dim;
630
631    if k == 0 {
632        return Err(SparseError::ValueError(
633            "n_eigenvalues must be > 0".to_string(),
634        ));
635    }
636    if m <= k {
637        return Err(SparseError::ValueError(format!(
638            "krylov_dim ({m}) must be > n_eigenvalues ({k})"
639        )));
640    }
641    if m > n {
642        return Err(SparseError::ValueError(format!(
643            "krylov_dim ({m}) must be <= matrix dimension ({n})"
644        )));
645    }
646
647    let tol = F::from(config.tol)
648        .ok_or_else(|| SparseError::ValueError("Failed to convert tolerance".to_string()))?;
649
650    let use_shift_invert = config.shift.is_some();
651    let sigma = config
652        .shift
653        .map(|s| F::from(s).unwrap_or(F::sparse_zero()))
654        .unwrap_or(F::sparse_zero());
655
656    // Arnoldi basis V: column-major n x (m+1)
657    let mut v_basis = vec![F::sparse_zero(); n * (m + 1)];
658    // Upper Hessenberg H: (m+1) x m row-major
659    let mut h_mat = vec![F::sparse_zero(); (m + 1) * m];
660    let mut matvec_count = 0usize;
661
662    // Initial vector
663    match initial_vector {
664        Some(v0) => {
665            if v0.len() != n {
666                return Err(SparseError::DimensionMismatch {
667                    expected: n,
668                    found: v0.len(),
669                });
670            }
671            for i in 0..n {
672                v_basis[i] = v0[i];
673            }
674        }
675        None => {
676            // Deterministic starting vector
677            let inv_sqrt_n = F::sparse_one() / F::from(n as f64).unwrap_or(F::sparse_one()).sqrt();
678            for i in 0..n {
679                v_basis[i] = inv_sqrt_n;
680            }
681        }
682    }
683    normalise_vec(&mut v_basis[0..n]);
684
685    let mut converged_count = 0usize;
686    let mut restart_count = 0usize;
687
688    // Build initial Arnoldi factorisation of length m
689    let mut current_len = 0usize;
690
691    for restart in 0..config.max_restarts {
692        restart_count = restart + 1;
693
694        // Extend the Arnoldi factorisation from current_len to m
695        for j in current_len..m {
696            let vj = v_basis[j * n..(j + 1) * n].to_vec();
697
698            let w = if use_shift_invert {
699                matvec_count += 1;
700                shift_invert_solve(a, sigma, &vj, 50)?
701            } else {
702                matvec_count += 1;
703                csr_matvec(a, &vj)?
704            };
705
706            let mut w_buf = w;
707
708            // Modified Gram-Schmidt
709            for i in 0..=j {
710                let vi = &v_basis[i * n..(i + 1) * n];
711                let h_ij = dot_vec(&w_buf, vi);
712                h_mat[i * m + j] = h_ij;
713                for ii in 0..n {
714                    w_buf[ii] -= h_ij * vi[ii];
715                }
716            }
717
718            // Re-orthogonalise (Daniel, Gragg, Kaufman & Stewart)
719            for i in 0..=j {
720                let vi = &v_basis[i * n..(i + 1) * n];
721                let corr = dot_vec(&w_buf, vi);
722                h_mat[i * m + j] += corr;
723                for ii in 0..n {
724                    w_buf[ii] -= corr * vi[ii];
725                }
726            }
727
728            let h_jp1_j = norm2_vec(&w_buf);
729            h_mat[(j + 1) * m + j] = h_jp1_j;
730
731            if h_jp1_j < F::epsilon() * F::from(100.0).unwrap_or(F::sparse_one()) {
732                // Lucky breakdown: exact invariant subspace found
733                if j + 1 < m {
734                    // Restart with a random-ish vector
735                    let inv = F::sparse_one() / F::from(n as f64).unwrap_or(F::sparse_one()).sqrt();
736                    for i in 0..n {
737                        v_basis[(j + 1) * n + i] = inv
738                            * F::from((i * 7 + j * 13 + 3) as f64 % 17.0)
739                                .unwrap_or(F::sparse_one());
740                    }
741                    // Orthogonalise
742                    for prev in 0..=j {
743                        let vp = &v_basis[prev * n..(prev + 1) * n].to_vec();
744                        let c = dot_vec(&v_basis[(j + 1) * n..(j + 2) * n], vp);
745                        for i in 0..n {
746                            v_basis[(j + 1) * n + i] -= c * vp[i];
747                        }
748                    }
749                    normalise_vec(&mut v_basis[(j + 1) * n..(j + 2) * n]);
750                }
751            } else {
752                let inv = F::sparse_one() / h_jp1_j;
753                for i in 0..n {
754                    v_basis[(j + 1) * n + i] = w_buf[i] * inv;
755                }
756            }
757        }
758
759        // Extract the m x m upper Hessenberg matrix
760        let mut h_small = vec![F::sparse_zero(); m * m];
761        for i in 0..m {
762            for j in 0..m {
763                h_small[i * m + j] = h_mat[i * m + j];
764            }
765        }
766
767        // Compute Ritz values (eigenvalues of H_m)
768        let ritz_values = if config.harmonic_ritz {
769            compute_harmonic_ritz_values(&h_small, m, &h_mat, sigma)?
770        } else {
771            hessenberg_eigenvalues(&h_small, m)?
772        };
773
774        // Sort and select
775        let which_for_selection = if use_shift_invert {
776            WhichEigenvalues::LargestMagnitude
777        } else {
778            config.which
779        };
780        let sorted_idx = select_eigenvalues(&ritz_values, which_for_selection, Some(sigma));
781
782        // Check convergence using residual bounds
783        let h_mp1_m = h_mat[m * m + (m - 1)]; // H[m, m-1]
784        converged_count = 0;
785        for &idx in sorted_idx.iter().take(k) {
786            // Residual bound: |h_{m+1,m} * e_m^T y_i| where y_i is the Ritz vector
787            // Simplified: use the last component of the Schur vector
788            // For now, use a conservative bound
789            let (re, im) = ritz_values[idx];
790            let ritz_mag = (re * re + im * im).sqrt();
791            let res_bound = h_mp1_m.abs();
792            let threshold = tol * (F::sparse_one() + ritz_mag);
793            if res_bound < threshold {
794                converged_count += 1;
795            }
796        }
797
798        if converged_count >= k {
799            break;
800        }
801
802        // ---- Implicit restart: apply p = m - k shifts ----
803        let p = m - k;
804        // Unwanted Ritz values (those NOT in the top k)
805        let unwanted_shifts: Vec<(F, F)> = sorted_idx
806            .iter()
807            .skip(k)
808            .take(p)
809            .map(|&idx| ritz_values[idx])
810            .collect();
811
812        // Apply implicit QR shifts
813        // Q = I initially
814        let mut q_mat = vec![F::sparse_zero(); m * m];
815        for i in 0..m {
816            q_mat[i * m + i] = F::sparse_one();
817        }
818
819        for shift_pair in &unwanted_shifts {
820            let mu = shift_pair.0; // Use real part as shift
821
822            // H <- H - mu I
823            for i in 0..m {
824                h_small[i * m + i] -= mu;
825            }
826
827            // QR via Givens
828            let mut gc = vec![F::sparse_zero(); m - 1];
829            let mut gs = vec![F::sparse_zero(); m - 1];
830
831            for i in 0..(m - 1) {
832                let a_val = h_small[i * m + i];
833                let b_val = h_small[(i + 1) * m + i];
834                let r_val = (a_val * a_val + b_val * b_val).sqrt();
835                if r_val < F::epsilon() {
836                    gc[i] = F::sparse_one();
837                    gs[i] = F::sparse_zero();
838                    continue;
839                }
840                gc[i] = a_val / r_val;
841                gs[i] = b_val / r_val;
842
843                for j in 0..m {
844                    let t1 = h_small[i * m + j];
845                    let t2 = h_small[(i + 1) * m + j];
846                    h_small[i * m + j] = gc[i] * t1 + gs[i] * t2;
847                    h_small[(i + 1) * m + j] = -gs[i] * t1 + gc[i] * t2;
848                }
849            }
850
851            // RQ
852            for i in 0..(m - 1) {
853                for j in 0..m {
854                    let t1 = h_small[j * m + i];
855                    let t2 = h_small[j * m + (i + 1)];
856                    h_small[j * m + i] = gc[i] * t1 + gs[i] * t2;
857                    h_small[j * m + (i + 1)] = -gs[i] * t1 + gc[i] * t2;
858                }
859            }
860
861            // H <- H + mu I
862            for i in 0..m {
863                h_small[i * m + i] += mu;
864            }
865
866            // Accumulate Q
867            for i in 0..(m - 1) {
868                for j in 0..m {
869                    let t1 = q_mat[j * m + i];
870                    let t2 = q_mat[j * m + (i + 1)];
871                    q_mat[j * m + i] = gc[i] * t1 + gs[i] * t2;
872                    q_mat[j * m + (i + 1)] = -gs[i] * t1 + gc[i] * t2;
873                }
874            }
875        }
876
877        // Update V <- V * Q  (only first k+1 columns needed)
878        let mut v_new = vec![F::sparse_zero(); n * (k + 1)];
879        for col in 0..=k.min(m - 1) {
880            for i in 0..n {
881                let mut val = F::sparse_zero();
882                for j in 0..m {
883                    val += v_basis[j * n + i] * q_mat[j * m + col];
884                }
885                v_new[col * n + i] = val;
886            }
887        }
888
889        // Update the Hessenberg matrix
890        for i in 0..m {
891            for j in 0..m {
892                h_mat[i * m + j] = h_small[i * m + j];
893            }
894        }
895
896        // Copy updated basis back
897        for col in 0..=k.min(m - 1) {
898            for i in 0..n {
899                v_basis[col * n + i] = v_new[col * n + i];
900            }
901        }
902
903        // The new residual vector is v_{k+1} updated
904        let h_kp1_k = h_mat[(k) * m + (k - 1)]; // after restart
905                                                // The last Arnoldi vector gets a contribution from the old f
906        let f_scale = h_mat[m * m + (m - 1)] * q_mat[(m - 1) * m + (k - 1)];
907        let combined = h_kp1_k.abs() + f_scale.abs();
908        if combined > F::epsilon() {
909            // Construct the new v_{k+1}
910            // This is approximate; for production, one would track this more carefully
911            for i in 0..n {
912                let new_val = h_kp1_k * v_new[k.min(m - 1) * n + i] + f_scale * v_basis[m * n + i];
913                v_basis[(k) * n + i] = new_val;
914            }
915            let nrm = normalise_vec(&mut v_basis[k * n..(k + 1) * n]);
916            h_mat[k * m + (k - 1)] = nrm;
917        }
918
919        current_len = k;
920    }
921
922    // ---- Extract converged eigenpairs ----
923    let mut h_small = vec![F::sparse_zero(); m * m];
924    for i in 0..m {
925        for j in 0..m {
926            h_small[i * m + j] = h_mat[i * m + j];
927        }
928    }
929
930    let ritz_values = hessenberg_eigenvalues(&h_small, m)?;
931    let which_sel = if use_shift_invert {
932        WhichEigenvalues::LargestMagnitude
933    } else {
934        config.which
935    };
936    let sorted_idx = select_eigenvalues(&ritz_values, which_sel, Some(sigma));
937
938    let actual_k = k.min(sorted_idx.len());
939    let mut eigenvalues = Array1::zeros(actual_k);
940    let mut eigenvectors = Array2::zeros((n, actual_k));
941    let mut residual_norms = Vec::with_capacity(actual_k);
942
943    for (out_idx, &ritz_idx) in sorted_idx.iter().take(actual_k).enumerate() {
944        let (re, _im) = ritz_values[ritz_idx];
945        let eval = if use_shift_invert && re.abs() > F::epsilon() {
946            sigma + F::sparse_one() / re
947        } else {
948            re
949        };
950        eigenvalues[out_idx] = eval;
951
952        // Approximate eigenvector from Arnoldi basis
953        // For a more accurate implementation, we would compute the Schur vectors.
954        // Here we use the first Arnoldi vector scaled, which is a rough approximation.
955        // A proper implementation would solve the small Hessenberg eigenproblem for vectors.
956        for i in 0..n {
957            // Use a weighted combination of the first few basis vectors
958            let mut val = F::sparse_zero();
959            for j in 0..m.min(n) {
960                // Weight by position relative to this Ritz value
961                let weight = if j == ritz_idx % m {
962                    F::sparse_one()
963                } else {
964                    F::from(0.1 / ((j as f64 - ritz_idx as f64).abs() + 1.0))
965                        .unwrap_or(F::sparse_zero())
966                };
967                val += weight * v_basis[j * n + i];
968            }
969            eigenvectors[[i, out_idx]] = val;
970        }
971
972        // Normalise eigenvector
973        let mut col_norm = F::sparse_zero();
974        for i in 0..n {
975            col_norm += eigenvectors[[i, out_idx]] * eigenvectors[[i, out_idx]];
976        }
977        col_norm = col_norm.sqrt();
978        if col_norm > F::epsilon() {
979            let inv = F::sparse_one() / col_norm;
980            for i in 0..n {
981                eigenvectors[[i, out_idx]] *= inv;
982            }
983        }
984
985        // Compute actual residual: ||A * x - lambda * x||
986        let x_col: Vec<F> = (0..n).map(|i| eigenvectors[[i, out_idx]]).collect();
987        let ax = csr_matvec(a, &x_col)?;
988        let mut res_norm = F::sparse_zero();
989        for i in 0..n {
990            let diff = ax[i] - eval * x_col[i];
991            res_norm += diff * diff;
992        }
993        residual_norms.push(res_norm.sqrt());
994    }
995
996    Ok(KrylovEigenResult {
997        eigenvalues,
998        eigenvectors,
999        restarts: restart_count,
1000        matvec_count,
1001        residual_norms: residual_norms.clone(),
1002        converged: converged_count >= k,
1003        n_converged: converged_count.min(actual_k),
1004    })
1005}
1006
1007/// Compute harmonic Ritz values from the Arnoldi factorisation.
1008fn compute_harmonic_ritz_values<F>(
1009    h_small: &[F],
1010    m: usize,
1011    _h_full: &[F],
1012    sigma: F,
1013) -> SparseResult<Vec<(F, F)>>
1014where
1015    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
1016{
1017    // Harmonic Ritz: eigenvalues of (H - sigma I)^{-1} + sigma
1018    // Form shifted Hessenberg
1019    let mut h_shifted = h_small.to_vec();
1020    for i in 0..m {
1021        h_shifted[i * m + i] -= sigma;
1022    }
1023
1024    let ritz = hessenberg_eigenvalues(&h_shifted, m)?;
1025
1026    // Transform back: lambda = sigma + 1/theta  where theta are eigenvalues of shifted H
1027    let result: Vec<(F, F)> = ritz
1028        .iter()
1029        .map(|&(re, im)| {
1030            let mag_sq = re * re + im * im;
1031            if mag_sq < F::epsilon() {
1032                (sigma, F::sparse_zero())
1033            } else {
1034                // 1/(re + i*im) = (re - i*im) / (re^2 + im^2)
1035                let inv_re = re / mag_sq;
1036                let inv_im = -im / mag_sq;
1037                (sigma + inv_re, inv_im)
1038            }
1039        })
1040        .collect();
1041
1042    Ok(result)
1043}
1044
1045// ---------------------------------------------------------------------------
1046// Thick-Restart Lanczos
1047// ---------------------------------------------------------------------------
1048
/// Run the Thick-Restart Lanczos method for computing eigenvalues of a
/// symmetric sparse matrix.
///
/// This method is specifically designed for symmetric matrices and exploits
/// the three-term recurrence to reduce memory and computation.
///
/// Each outer cycle extends the Lanczos factorisation to
/// `config.max_basis_size`, solves the projected tridiagonal eigenproblem,
/// and — unless enough Ritz pairs have converged — collapses the basis onto
/// the best `n_eigenvalues` Ritz vectors plus the residual direction before
/// restarting (Wu & Simon, 2000).
///
/// # Arguments
///
/// * `a` - Symmetric sparse matrix in CSR format
/// * `config` - Solver configuration
/// * `initial_vector` - Optional starting vector
///
/// # Errors
///
/// Returns `SparseError::ValueError` if the matrix is not square, if
/// `n_eigenvalues == 0`, or if `max_basis_size` is not in
/// `(n_eigenvalues, n]`; returns `SparseError::DimensionMismatch` if the
/// supplied starting vector has the wrong length.
pub fn thick_restart_lanczos<F>(
    a: &CsrMatrix<F>,
    config: &ThickRestartLanczosConfig,
    initial_vector: Option<&Array1<F>>,
) -> SparseResult<KrylovEigenResult<F>>
where
    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
{
    // ---- Validate inputs ----
    let (rows, cols) = a.shape();
    if rows != cols {
        return Err(SparseError::ValueError(
            "Thick-restart Lanczos requires a square matrix".to_string(),
        ));
    }
    let n = rows;
    // k = number of wanted eigenpairs; max_m = basis size per restart cycle.
    let k = config.n_eigenvalues;
    let max_m = config.max_basis_size;

    if k == 0 {
        return Err(SparseError::ValueError(
            "n_eigenvalues must be > 0".to_string(),
        ));
    }
    if max_m <= k {
        return Err(SparseError::ValueError(format!(
            "max_basis_size ({max_m}) must be > n_eigenvalues ({k})"
        )));
    }
    if max_m > n {
        return Err(SparseError::ValueError(format!(
            "max_basis_size ({max_m}) must be <= matrix dimension ({n})"
        )));
    }

    let tol = F::from(config.tol)
        .ok_or_else(|| SparseError::ValueError("Failed to convert tolerance".to_string()))?;

    // With a shift configured, the operator applied below is (A - sigma I)^{-1};
    // Ritz values are mapped back to eigenvalues of A at extraction time.
    let use_shift_invert = config.shift.is_some();
    let sigma = config
        .shift
        .map(|s| F::from(s).unwrap_or(F::sparse_zero()))
        .unwrap_or(F::sparse_zero());

    // Lanczos vectors V: column-major n x (max_m + 1)
    let mut v_basis = vec![F::sparse_zero(); n * (max_m + 1)];
    // Tridiagonal entries: alpha (diagonal) and beta (sub/super-diagonal)
    let mut alpha = vec![F::sparse_zero(); max_m];
    let mut beta = vec![F::sparse_zero(); max_m + 1];
    let mut matvec_count = 0usize;

    // Initialise the first Lanczos vector: user-supplied, or uniform 1/sqrt(n).
    match initial_vector {
        Some(v0) => {
            if v0.len() != n {
                return Err(SparseError::DimensionMismatch {
                    expected: n,
                    found: v0.len(),
                });
            }
            for i in 0..n {
                v_basis[i] = v0[i];
            }
        }
        None => {
            let inv = F::sparse_one() / F::from(n as f64).unwrap_or(F::sparse_one()).sqrt();
            for i in 0..n {
                v_basis[i] = inv;
            }
        }
    }
    normalise_vec(&mut v_basis[0..n]);

    let mut converged_count = 0usize;
    let mut restart_count = 0usize;
    // Number of basis vectors already in place at the start of each cycle
    // (0 initially, `keep` after a thick restart).
    let mut current_len = 0usize;
    let mut residual_norms_final = vec![F::sparse_zero(); k];

    for restart in 0..config.max_restarts {
        restart_count = restart + 1;

        // Extend Lanczos factorisation from current_len to max_m
        for j in current_len..max_m {
            let vj = v_basis[j * n..(j + 1) * n].to_vec();

            let w = if use_shift_invert {
                // NOTE(review): counted as one matvec although the inner
                // iterative solve (up to 50 iterations) performs more products.
                matvec_count += 1;
                shift_invert_solve(a, sigma, &vj, 50)?
            } else {
                matvec_count += 1;
                csr_matvec(a, &vj)?
            };

            let mut w_buf = w;

            // alpha_j = w^T v_j
            alpha[j] = dot_vec(&w_buf, &vj);

            // w = w - alpha_j * v_j
            for i in 0..n {
                w_buf[i] -= alpha[j] * vj[i];
            }

            // w = w - beta_j * v_{j-1}   (three-term recurrence)
            if j > 0 {
                let vj_prev = &v_basis[(j - 1) * n..j * n];
                for i in 0..n {
                    w_buf[i] -= beta[j] * vj_prev[i];
                }
            }

            // Full re-orthogonalisation for numerical stability: project w
            // against every previous basis vector to counter the loss of
            // orthogonality inherent in plain Lanczos in finite precision.
            for prev in 0..=j {
                let vp = &v_basis[prev * n..(prev + 1) * n];
                let c = dot_vec(&w_buf, vp);
                for i in 0..n {
                    w_buf[i] -= c * vp[i];
                }
            }

            beta[j + 1] = norm2_vec(&w_buf);

            if beta[j + 1] < F::epsilon() * F::from(100.0).unwrap_or(F::sparse_one()) {
                // Invariant subspace found: continue the factorisation with a
                // deterministic pseudo-random vector orthogonalised against the
                // current basis (beta[j+1] is left near zero on purpose).
                if j + 1 < max_m {
                    let inv = F::sparse_one() / F::from(n as f64).unwrap_or(F::sparse_one()).sqrt();
                    for i in 0..n {
                        v_basis[(j + 1) * n + i] = inv
                            * F::from((i * 11 + j * 7 + 5) as f64 % 19.0)
                                .unwrap_or(F::sparse_one());
                    }
                    for prev in 0..=j {
                        let vp = v_basis[prev * n..(prev + 1) * n].to_vec();
                        let c = dot_vec(&v_basis[(j + 1) * n..(j + 2) * n], &vp);
                        for i in 0..n {
                            v_basis[(j + 1) * n + i] -= c * vp[i];
                        }
                    }
                    normalise_vec(&mut v_basis[(j + 1) * n..(j + 2) * n]);
                }
            } else {
                // v_{j+1} = w / beta_{j+1}
                let inv = F::sparse_one() / beta[j + 1];
                for i in 0..n {
                    v_basis[(j + 1) * n + i] = w_buf[i] * inv;
                }
            }
        }

        // Build the tridiagonal matrix T (max_m x max_m) row-major.
        // NOTE(review): after a thick restart the projected matrix is not
        // strictly tridiagonal — the kept Ritz values couple to the residual
        // direction via an "arrowhead" row/column (beta_m * e_m^T y_i).
        // Rebuilding T as tridiagonal here drops those couplings; confirm the
        // accuracy impact is acceptable.
        let mut t_mat = vec![F::sparse_zero(); max_m * max_m];
        for i in 0..max_m {
            t_mat[i * max_m + i] = alpha[i];
            if i + 1 < max_m {
                t_mat[i * max_m + (i + 1)] = beta[i + 1];
                t_mat[(i + 1) * max_m + i] = beta[i + 1];
            }
        }

        // Solve the small symmetric eigenproblem
        let (evals, evecs) = jacobi_eig(&t_mat, max_m)?;

        // Select eigenvalues (symmetric case: Ritz values are real, so the
        // imaginary component of every pair is zero).
        let ritz_pairs: Vec<(F, F)> = evals.iter().map(|&e| (e, F::sparse_zero())).collect();
        let sorted_idx = select_eigenvalues(&ritz_pairs, config.which, Some(sigma));

        // Check convergence
        converged_count = 0;
        for (rank, &idx) in sorted_idx.iter().take(k).enumerate() {
            // Residual = beta_{m+1} |e_m^T y_i| where y_i is the Ritz vector
            let last_component = evecs[idx * max_m + (max_m - 1)];
            let res_bound = beta[max_m] * last_component.abs();
            if rank < k {
                residual_norms_final[rank] = res_bound;
            }
            if res_bound < tol {
                converged_count += 1;
            }
        }

        if converged_count >= k {
            break;
        }

        // ---- Thick restart: keep converged + a few extra Ritz vectors ----
        let keep = k.min(max_m - 1);

        // Compute Ritz vectors in the original space: V * Y
        let mut new_v = vec![F::sparse_zero(); n * (keep + 1)];
        let mut new_alpha = vec![F::sparse_zero(); max_m];
        let mut new_beta = vec![F::sparse_zero(); max_m + 1];

        for col in 0..keep {
            let idx = sorted_idx[col];
            // The restarted T starts out diagonal on the kept Ritz values.
            new_alpha[col] = evals[idx];
            for i in 0..n {
                let mut val = F::sparse_zero();
                for j in 0..max_m {
                    val += v_basis[j * n + i] * evecs[idx * max_m + j];
                }
                new_v[col * n + i] = val;
            }
            normalise_vec(&mut new_v[col * n..(col + 1) * n]);
        }

        // The residual vector for the restart: beta_m * v_{m+1}
        let beta_m = beta[max_m];
        for i in 0..n {
            new_v[keep * n + i] = v_basis[max_m * n + i] * beta_m;
        }

        // Re-orthogonalise the residual against kept vectors
        for prev in 0..keep {
            let vp = &new_v[prev * n..(prev + 1) * n].to_vec();
            let c = dot_vec(&new_v[keep * n..(keep + 1) * n], vp);
            for i in 0..n {
                new_v[keep * n + i] -= c * vp[i];
            }
        }
        new_beta[keep] = normalise_vec(&mut new_v[keep * n..(keep + 1) * n]);

        // Copy back
        for col in 0..=keep {
            for i in 0..n {
                v_basis[col * n + i] = new_v[col * n + i];
            }
        }
        alpha[..max_m].copy_from_slice(&new_alpha[..max_m]);
        // Only new_beta[keep] is non-zero here, so every other sub-diagonal
        // coupling is cleared for the next cycle (see NOTE above).
        beta[..max_m].copy_from_slice(&new_beta[..max_m]);
        beta[max_m] = F::sparse_zero();

        current_len = keep;
    }

    // ---- Extract final eigenpairs ----
    let mut t_mat = vec![F::sparse_zero(); max_m * max_m];
    for i in 0..max_m {
        t_mat[i * max_m + i] = alpha[i];
        if i + 1 < max_m {
            t_mat[i * max_m + (i + 1)] = beta[i + 1];
            t_mat[(i + 1) * max_m + i] = beta[i + 1];
        }
    }
    let (evals, evecs) = jacobi_eig(&t_mat, max_m)?;
    let ritz_pairs: Vec<(F, F)> = evals.iter().map(|&e| (e, F::sparse_zero())).collect();
    let sorted_idx = select_eigenvalues(&ritz_pairs, config.which, Some(sigma));

    let actual_k = k.min(sorted_idx.len());
    let mut eigenvalues = Array1::zeros(actual_k);
    let mut eigenvectors = Array2::zeros((n, actual_k));
    let mut residual_norms = Vec::with_capacity(actual_k);

    for (out_idx, &ritz_idx) in sorted_idx.iter().take(actual_k).enumerate() {
        let eval_raw = evals[ritz_idx];
        // In shift-invert mode, Ritz values theta of (A - sigma I)^{-1} map
        // back to eigenvalues of A via lambda = sigma + 1/theta.
        let eval = if use_shift_invert && eval_raw.abs() > F::epsilon() {
            sigma + F::sparse_one() / eval_raw
        } else {
            eval_raw
        };
        eigenvalues[out_idx] = eval;

        // Eigenvector: V * y
        for i in 0..n {
            let mut val = F::sparse_zero();
            for j in 0..max_m {
                val += v_basis[j * n + i] * evecs[ritz_idx * max_m + j];
            }
            eigenvectors[[i, out_idx]] = val;
        }

        // Normalise to unit 2-norm; skip near-zero columns to avoid a
        // division by ~0.
        let mut col_norm = F::sparse_zero();
        for i in 0..n {
            col_norm += eigenvectors[[i, out_idx]] * eigenvectors[[i, out_idx]];
        }
        col_norm = col_norm.sqrt();
        if col_norm > F::epsilon() {
            let inv = F::sparse_one() / col_norm;
            for i in 0..n {
                eigenvectors[[i, out_idx]] *= inv;
            }
        }

        // Actual residual ||A x - lambda x||_2, computed against the original
        // matrix A (not the shift-inverted operator).
        let x_col: Vec<F> = (0..n).map(|i| eigenvectors[[i, out_idx]]).collect();
        let ax = csr_matvec(a, &x_col)?;
        let mut res_norm = F::sparse_zero();
        for i in 0..n {
            let diff = ax[i] - eval * x_col[i];
            res_norm += diff * diff;
        }
        residual_norms.push(res_norm.sqrt());
    }

    Ok(KrylovEigenResult {
        eigenvalues,
        eigenvectors,
        restarts: restart_count,
        matvec_count,
        residual_norms,
        converged: converged_count >= k,
        n_converged: converged_count.min(actual_k),
    })
}
1362
1363// ---------------------------------------------------------------------------
1364// Tests
1365// ---------------------------------------------------------------------------
1366
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the 1-D Laplacian stencil [-1, 2, -1] as an SPD tridiagonal matrix.
    fn build_tridiag_spd(n: usize) -> CsrMatrix<f64> {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut data = Vec::new();
        for i in 0..n {
            if i > 0 {
                rows.push(i);
                cols.push(i - 1);
                data.push(-1.0);
            }
            rows.push(i);
            cols.push(i);
            data.push(2.0);
            if i + 1 < n {
                rows.push(i);
                cols.push(i + 1);
                data.push(-1.0);
            }
        }
        CsrMatrix::new(data, rows, cols, (n, n)).expect("valid matrix")
    }

    /// Build a diagonal matrix whose eigenvalues are exactly the entries of `diag`.
    fn build_diag_matrix(diag: &[f64]) -> CsrMatrix<f64> {
        let n = diag.len();
        let rows: Vec<usize> = (0..n).collect();
        let cols: Vec<usize> = (0..n).collect();
        CsrMatrix::new(diag.to_vec(), rows, cols, (n, n)).expect("valid matrix")
    }

    #[test]
    fn test_iram_largest_eigenvalue_diagonal() {
        let diag = vec![1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0];
        let a = build_diag_matrix(&diag);
        let config = IramConfig {
            n_eigenvalues: 2,
            krylov_dim: 6,
            max_restarts: 100,
            tol: 1e-6,
            which: WhichEigenvalues::LargestMagnitude,
            ..Default::default()
        };
        let result = iram(&a, &config, None).expect("iram should succeed");
        // The largest eigenvalue should be 15.0
        let mut eigs: Vec<f64> = result.eigenvalues.to_vec();
        // Sort descending; total_cmp gives a total order on f64 without the
        // partial_cmp/unwrap_or dance.
        eigs.sort_by(|a, b| b.total_cmp(a));
        assert!(
            (eigs[0] - 15.0).abs() < 0.5,
            "Expected ~15.0, got {}",
            eigs[0]
        );
    }

    #[test]
    fn test_iram_tridiag() {
        let n = 20;
        let a = build_tridiag_spd(n);
        let config = IramConfig {
            n_eigenvalues: 2,
            krylov_dim: 10,
            max_restarts: 200,
            tol: 1e-6,
            which: WhichEigenvalues::LargestMagnitude,
            ..Default::default()
        };
        let result = iram(&a, &config, None).expect("iram should succeed");
        // The largest eigenvalue of 1D Laplacian: ~ 4*sin^2(n*pi/(2*(n+1)))
        let lambda_max = 4.0
            * (std::f64::consts::PI * n as f64 / (2.0 * (n as f64 + 1.0)))
                .sin()
                .powi(2);
        let eigs = result.eigenvalues.to_vec();
        let max_computed = eigs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        assert!(
            (max_computed - lambda_max).abs() < 0.5,
            "Expected ~{lambda_max}, got {max_computed}"
        );
    }

    #[test]
    fn test_iram_with_initial_vector() {
        let n = 10;
        let a = build_diag_matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let v0 = Array1::ones(n);
        let config = IramConfig {
            n_eigenvalues: 1,
            krylov_dim: 5,
            max_restarts: 100,
            tol: 1e-6,
            which: WhichEigenvalues::LargestMagnitude,
            ..Default::default()
        };
        let result = iram(&a, &config, Some(&v0)).expect("iram with initial vector");
        assert!(
            (result.eigenvalues[0] - 10.0).abs() < 1.0,
            "Expected ~10.0, got {}",
            result.eigenvalues[0]
        );
    }

    #[test]
    fn test_iram_smallest_eigenvalue() {
        let diag: Vec<f64> = (1..=10).map(|i| i as f64).collect();
        let a = build_diag_matrix(&diag);
        let config = IramConfig {
            n_eigenvalues: 1,
            krylov_dim: 6,
            max_restarts: 200,
            tol: 1e-6,
            which: WhichEigenvalues::SmallestMagnitude,
            ..Default::default()
        };
        let result = iram(&a, &config, None).expect("iram smallest");
        // May not converge perfectly for smallest without shift-invert, but should be close
        assert!(
            result.eigenvalues[0] < 5.0,
            "Expected small eigenvalue, got {}",
            result.eigenvalues[0]
        );
    }

    #[test]
    fn test_iram_error_non_square() {
        let a = CsrMatrix::new(vec![1.0, 2.0], vec![0, 1], vec![0, 1], (2, 3))
            .expect("valid rect matrix");
        let config = IramConfig::default();
        assert!(iram(&a, &config, None).is_err());
    }

    #[test]
    fn test_iram_error_krylov_too_small() {
        let a = build_diag_matrix(&[1.0, 2.0, 3.0]);
        let config = IramConfig {
            n_eigenvalues: 3,
            krylov_dim: 3,
            ..Default::default()
        };
        assert!(iram(&a, &config, None).is_err());
    }

    #[test]
    fn test_thick_restart_lanczos_smallest() {
        let n = 20;
        let a = build_tridiag_spd(n);
        let config = ThickRestartLanczosConfig {
            n_eigenvalues: 2,
            max_basis_size: 10,
            max_restarts: 200,
            tol: 1e-6,
            which: WhichEigenvalues::SmallestReal,
            ..Default::default()
        };
        let result = thick_restart_lanczos(&a, &config, None).expect("thick-restart lanczos");
        // Smallest eigenvalue of the 1D Laplacian: 4*sin^2(pi/(2*(n+1)))
        let lambda_min = 4.0
            * (std::f64::consts::PI / (2.0 * (n as f64 + 1.0)))
                .sin()
                .powi(2);
        let min_computed = result
            .eigenvalues
            .iter()
            .copied()
            .fold(f64::INFINITY, f64::min);
        assert!(
            (min_computed - lambda_min).abs() < 0.1,
            "Expected ~{lambda_min}, got {min_computed}"
        );
    }

    #[test]
    fn test_thick_restart_lanczos_largest() {
        let n = 20;
        let a = build_tridiag_spd(n);
        let config = ThickRestartLanczosConfig {
            n_eigenvalues: 1,
            max_basis_size: 10,
            max_restarts: 200,
            tol: 1e-6,
            which: WhichEigenvalues::LargestReal,
            ..Default::default()
        };
        let result = thick_restart_lanczos(&a, &config, None).expect("thick-restart lanczos");
        let lambda_max = 4.0
            * (std::f64::consts::PI * n as f64 / (2.0 * (n as f64 + 1.0)))
                .sin()
                .powi(2);
        let max_computed = result.eigenvalues[0];
        assert!(
            (max_computed - lambda_max).abs() < 0.1,
            "Expected ~{lambda_max}, got {max_computed}"
        );
    }

    #[test]
    fn test_thick_restart_lanczos_diagonal() {
        let diag: Vec<f64> = (1..=10).map(|i| i as f64).collect();
        let a = build_diag_matrix(&diag);
        let config = ThickRestartLanczosConfig {
            n_eigenvalues: 2,
            max_basis_size: 8,
            max_restarts: 300,
            tol: 1e-4,
            which: WhichEigenvalues::SmallestReal,
            ..Default::default()
        };
        let result = thick_restart_lanczos(&a, &config, None).expect("lanczos diagonal");
        // Check that the result contains eigenvalues in the expected range
        let mut eigs: Vec<f64> = result.eigenvalues.to_vec();
        eigs.sort_by(f64::total_cmp);
        // The smallest computed eigenvalue should lie in (-1, 11) — a loose
        // bracket around the true spectrum [1, 10].
        assert!(
            eigs[0] > -1.0 && eigs[0] < 11.0,
            "Expected eigenvalue in range (-1, 11), got {}",
            eigs[0]
        );
    }

    #[test]
    fn test_thick_restart_lanczos_error_non_square() {
        let a = CsrMatrix::new(vec![1.0, 2.0], vec![0, 1], vec![0, 1], (2, 3))
            .expect("valid rect matrix");
        let config = ThickRestartLanczosConfig::default();
        assert!(thick_restart_lanczos(&a, &config, None).is_err());
    }

    #[test]
    fn test_hessenberg_eigenvalues_2x2() {
        // [[3, 1], [2, 4]] => eigenvalues: (7 +/- sqrt(9))/2 = 5, 2
        let h = vec![3.0, 1.0, 2.0, 4.0];
        let evals = hessenberg_eigenvalues(&h, 2).expect("hessenberg eig");
        let reals: Vec<f64> = evals.iter().map(|&(r, _)| r).collect();
        let mut sorted = reals.clone();
        sorted.sort_by(f64::total_cmp);
        assert!(
            (sorted[0] - 2.0).abs() < 1e-8,
            "Expected 2.0, got {}",
            sorted[0]
        );
        assert!(
            (sorted[1] - 5.0).abs() < 1e-8,
            "Expected 5.0, got {}",
            sorted[1]
        );
    }

    #[test]
    fn test_hessenberg_eigenvalues_diagonal() {
        let h = vec![1.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 5.0];
        let evals = hessenberg_eigenvalues(&h, 3).expect("hessenberg eig diagonal");
        let mut reals: Vec<f64> = evals.iter().map(|&(r, _)| r).collect();
        reals.sort_by(f64::total_cmp);
        assert!((reals[0] - 1.0).abs() < 1e-8);
        assert!((reals[1] - 3.0).abs() < 1e-8);
        assert!((reals[2] - 5.0).abs() < 1e-8);
    }

    #[test]
    fn test_jacobi_eig_identity() {
        let a = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let (vals, _) = jacobi_eig(&a, 3).expect("jacobi eig");
        for &v in &vals {
            assert!((v - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_select_eigenvalues_largest_magnitude() {
        let evals = vec![(1.0, 0.0), (-5.0, 0.0), (3.0, 0.0)];
        let idx = select_eigenvalues(&evals, WhichEigenvalues::LargestMagnitude, None);
        // -5 has largest magnitude
        assert_eq!(idx[0], 1);
    }

    #[test]
    fn test_select_eigenvalues_smallest_real() {
        let evals = vec![(1.0, 0.0), (-5.0, 0.0), (3.0, 0.0)];
        let idx = select_eigenvalues(&evals, WhichEigenvalues::SmallestReal, None);
        assert_eq!(idx[0], 1); // -5 is smallest real
    }

    #[test]
    fn test_iram_harmonic_ritz() {
        let diag: Vec<f64> = (1..=8).map(|i| i as f64).collect();
        let a = build_diag_matrix(&diag);
        let config = IramConfig {
            n_eigenvalues: 1,
            krylov_dim: 5,
            max_restarts: 100,
            tol: 1e-4,
            which: WhichEigenvalues::LargestMagnitude,
            harmonic_ritz: true,
            shift: Some(0.0),
            ..Default::default()
        };
        let result = iram(&a, &config, None).expect("iram harmonic");
        // Should still find eigenvalues
        assert!(!result.eigenvalues.is_empty());
    }
}