scirs2_sparse/linalg/eigen/
lanczos.rs

1//! Lanczos algorithm for sparse matrix eigenvalue computation
2//!
3//! This module implements the Lanczos algorithm for finding eigenvalues and
4//! eigenvectors of large symmetric sparse matrices.
5
6use crate::error::{SparseError, SparseResult};
7use crate::sym_csr::SymCsrMatrix;
8use crate::sym_ops::sym_csr_matvec;
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
10use scirs2_core::numeric::Float;
11use scirs2_core::SparseElement;
12use std::fmt::Debug;
13use std::ops::{Add, Div, Mul, Sub};
14
// Equality helper for integer indices.
//
// Despite its name this is an *exact* comparison, not a floating-point tolerance test.
// Both operands are compared with `==` directly (no lossy `as i32` cast), so large
// `usize` indices are never truncated into false matches.
// `unused_macros` is allowed because index comparisons in this module may equally be
// written with plain `==`, leaving this helper without call sites.
#[allow(unused_macros)]
macro_rules! abs_diff_eq {
    ($left:expr, $right:expr) => {
        ($left) == ($right)
    };
}
21
/// Tuning knobs for the Lanczos eigensolver.
///
/// Start from [`LanczosOptions::default`] and override individual fields with
/// struct-update syntax when only a few settings need to change.
#[derive(Debug, Clone)]
pub struct LanczosOptions {
    /// Upper bound on the number of Lanczos iterations.
    pub max_iter: usize,
    /// Largest Krylov subspace dimension to build (capped at the matrix order).
    pub max_subspace_size: usize,
    /// Tolerance used when deciding whether the Ritz approximations have converged.
    pub tol: f64,
    /// How many eigenvalues to report.
    pub numeigenvalues: usize,
    /// When `true`, the corresponding eigenvectors are assembled as well.
    pub compute_eigenvectors: bool,
}

impl Default for LanczosOptions {
    /// `max_iter = 1000`, `max_subspace_size = 20`, `tol = 1e-8`,
    /// `numeigenvalues = 1`, `compute_eigenvectors = true`.
    fn default() -> Self {
        LanczosOptions {
            compute_eigenvectors: true,
            numeigenvalues: 1,
            tol: 1e-8,
            max_subspace_size: 20,
            max_iter: 1000,
        }
    }
}
48
49/// Result of an eigenvalue computation
50#[derive(Debug, Clone)]
51pub struct EigenResult<T>
52where
53    T: Float + Debug + Copy,
54{
55    /// Converged eigenvalues
56    pub eigenvalues: Array1<T>,
57    /// Corresponding eigenvectors (if requested)
58    pub eigenvectors: Option<Array2<T>>,
59    /// Number of iterations performed
60    pub iterations: usize,
61    /// Residual norms for each eigenpair
62    pub residuals: Array1<T>,
63    /// Whether the algorithm converged
64    pub converged: bool,
65}
66
67/// Computes the extreme eigenvalues and corresponding eigenvectors of a symmetric
68/// matrix using the Lanczos algorithm.
69///
70/// # Arguments
71///
72/// * `matrix` - The symmetric matrix
73/// * `options` - Configuration options
74/// * `initial_guess` - Initial guess for the first Lanczos vector (optional)
75///
76/// # Returns
77///
78/// Result containing eigenvalues and eigenvectors
79///
80/// # Example
81///
82/// ```
83/// use scirs2_core::ndarray::Array1;
84/// use scirs2_sparse::{
85///     sym_csr::SymCsrMatrix,
86///     linalg::{lanczos, LanczosOptions},
87/// };
88///
89/// // Create a symmetric matrix
90/// let data = vec![2.0, 1.0, 2.0, 3.0];
91/// let indices = vec![0, 0, 1, 2];
92/// let indptr = vec![0, 1, 3, 4];
93/// let matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
94///
95/// // Configure options
96/// let options = LanczosOptions {
97///     max_iter: 100,
98///     max_subspace_size: 3, // Matrix is 3x3
99///     tol: 1e-8,
100///     numeigenvalues: 1,   // Find the largest eigenvalue
101///     compute_eigenvectors: true,
102/// };
103///
104/// // Compute eigenvalues and eigenvectors
105/// let result = lanczos(&matrix, &options, None).unwrap();
106///
107/// // Check the result
108/// println!("Eigenvalues: {:?}", result.eigenvalues);
109/// println!("Converged in {} iterations", result.iterations);
110/// println!("Final residuals: {:?}", result.residuals);
111/// assert!(result.converged);
112/// ```
113#[allow(unused_assignments)]
114#[allow(dead_code)]
115pub fn lanczos<T>(
116    matrix: &SymCsrMatrix<T>,
117    options: &LanczosOptions,
118    initial_guess: Option<ArrayView1<T>>,
119) -> SparseResult<EigenResult<T>>
120where
121    T: Float
122        + SparseElement
123        + Debug
124        + Copy
125        + Add<Output = T>
126        + Sub<Output = T>
127        + Mul<Output = T>
128        + Div<Output = T>
129        + std::iter::Sum
130        + scirs2_core::simd_ops::SimdUnifiedOps
131        + Send
132        + Sync
133        + 'static,
134{
135    let (n, _) = matrix.shape();
136
137    // Ensure the subspace size is valid
138    let subspace_size = options.max_subspace_size.min(n);
139
140    // Ensure the number of eigenvalues requested is valid
141    let numeigenvalues = options.numeigenvalues.min(subspace_size);
142
143    // Initialize the first Lanczos vector
144    let mut v = match initial_guess {
145        Some(v) => {
146            if v.len() != n {
147                return Err(SparseError::DimensionMismatch {
148                    expected: n,
149                    found: v.len(),
150                });
151            }
152            // Create a copy of the initial guess
153            let mut v_arr = Array1::zeros(n);
154            for i in 0..n {
155                v_arr[i] = v[i];
156            }
157            v_arr
158        }
159        None => {
160            // Random initialization
161            let mut v_arr = Array1::zeros(n);
162            v_arr[0] = T::sparse_one(); // Simple initialization with [1, 0, 0, ...]
163            v_arr
164        }
165    };
166
167    // Normalize the initial vector
168    let norm = (v.iter().map(|&val| val * val).sum::<T>()).sqrt();
169    if !SparseElement::is_zero(&norm) {
170        for i in 0..n {
171            v[i] = v[i] / norm;
172        }
173    }
174
175    // Allocate space for Lanczos vectors
176    let mut v_vectors = Vec::with_capacity(subspace_size);
177    v_vectors.push(v.clone());
178
179    // Allocate space for tridiagonal matrix elements
180    let mut alpha = Vec::<T>::with_capacity(subspace_size); // Diagonal elements
181    let mut beta = Vec::<T>::with_capacity(subspace_size - 1); // Off-diagonal elements
182
183    // First iteration step
184    let mut w = sym_csr_matvec(matrix, &v.view())?;
185    let alpha_j = v.iter().zip(w.iter()).map(|(&vi, &wi)| vi * wi).sum::<T>();
186    alpha.push(alpha_j);
187
188    // Orthogonalize against previous vectors
189    for i in 0..n {
190        w[i] = w[i] - alpha_j * v[i];
191    }
192
193    // Compute beta (norm of w)
194    let beta_j = (w.iter().map(|&val| val * val).sum::<T>()).sqrt();
195
196    let mut iter = 1;
197    let mut converged = false;
198
199    while iter < options.max_iter && alpha.len() < subspace_size {
200        if SparseElement::is_zero(&beta_j) {
201            // Lucky breakdown - exact invariant subspace found
202            break;
203        }
204
205        beta.push(beta_j);
206
207        // Next Lanczos vector
208        let mut v_next = Array1::zeros(n);
209        for i in 0..n {
210            v_next[i] = w[i] / beta_j;
211        }
212
213        // Store the vector
214        v_vectors.push(v_next.clone());
215
216        // Next iteration step
217        w = sym_csr_matvec(matrix, &v_next.view())?;
218
219        // Full reorthogonalization (for numerical stability)
220        for v_j in v_vectors.iter() {
221            let proj = v_j
222                .iter()
223                .zip(w.iter())
224                .map(|(&vj, &wi)| vj * wi)
225                .sum::<T>();
226            for i in 0..n {
227                w[i] = w[i] - proj * v_j[i];
228            }
229        }
230
231        // Compute alpha
232        let alpha_j = v_next
233            .iter()
234            .zip(w.iter())
235            .map(|(&vi, &wi)| vi * wi)
236            .sum::<T>();
237        alpha.push(alpha_j);
238
239        // Update v for next iteration
240        for i in 0..n {
241            w[i] = w[i] - alpha_j * v_next[i];
242        }
243
244        // Compute beta for next iteration
245        let beta_j_next = (w.iter().map(|&val| val * val).sum::<T>()).sqrt();
246
247        // Check for convergence using the largest eigenvalue approx
248        if alpha.len() >= numeigenvalues {
249            // Build and solve the tridiagonal system
250            let (eigvals, _) = solve_tridiagonal_eigenproblem(&alpha, &beta, numeigenvalues)?;
251
252            // Check if the largest eigvals have converged (using beta as an error estimate)
253            if beta_j_next < T::from(options.tol).unwrap() * eigvals[0].abs() {
254                converged = true;
255                break;
256            }
257        }
258
259        v = v_next;
260        iter += 1;
261
262        // Update beta for next iteration
263        if iter < options.max_iter && alpha.len() < subspace_size {
264            let _beta_j = beta_j_next;
265        }
266    }
267
268    // Solve the final tridiagonal eigenproblem
269    let (eigvals, eigvecs) = solve_tridiagonal_eigenproblem(&alpha, &beta, numeigenvalues)?;
270
271    // Compute the Ritz vectors (eigenvectors in the original space) if requested
272    let eigenvectors = if options.compute_eigenvectors {
273        let mut ritz_vectors = Array2::zeros((n, numeigenvalues));
274
275        for k in 0..numeigenvalues {
276            for i in 0..n {
277                let mut sum = T::sparse_zero();
278                for j in 0..v_vectors.len() {
279                    if j < eigvecs.len() && k < eigvecs[j].len() {
280                        sum = sum + eigvecs[j][k] * v_vectors[j][i];
281                    }
282                }
283                ritz_vectors[[i, k]] = sum;
284            }
285        }
286
287        Some(ritz_vectors)
288    } else {
289        None
290    };
291
292    // Compute residuals
293    let actualeigenvalues = eigvals.len();
294    let mut residuals = Array1::zeros(actualeigenvalues);
295    if let Some(ref evecs) = eigenvectors {
296        for k in 0..actualeigenvalues {
297            let mut evec = Array1::zeros(n);
298            for i in 0..n {
299                evec[i] = evecs[[i, k]];
300            }
301
302            let ax = sym_csr_matvec(matrix, &evec.view())?;
303
304            let mut res = Array1::zeros(n);
305            for i in 0..n {
306                res[i] = ax[i] - eigvals[k] * evec[i];
307            }
308
309            residuals[k] = (res.iter().map(|&v| v * v).sum::<T>()).sqrt();
310        }
311    } else {
312        // If no eigenvectors were computed, use the Kaniel-Paige error bound
313        // (beta_j * last component of eigenvector in the Krylov basis)
314        for k in 0..numeigenvalues {
315            if k < eigvecs.len() && !beta.is_empty() {
316                residuals[k] = beta[beta.len() - 1] * eigvecs[eigvecs.len() - 1][k].abs();
317            }
318        }
319    }
320
321    // Create the result
322    let result = EigenResult {
323        eigenvalues: Array1::from_vec(eigvals),
324        eigenvectors,
325        iterations: iter,
326        residuals,
327        converged,
328    };
329
330    Ok(result)
331}
332
333/// Solves a symmetric tridiagonal eigenvalue problem.
334///
335/// This function computes the eigenvalues and eigenvectors of a symmetric
336/// tridiagonal matrix defined by its diagonal elements `alpha` and
337/// off-diagonal elements `beta`.
338///
339/// # Arguments
340///
341/// * `alpha` - Diagonal elements
342/// * `beta` - Off-diagonal elements
343/// * `numeigenvalues` - Number of eigenvalues to compute
344///
345/// # Returns
346///
347/// A tuple containing:
348/// - The eigenvalues in descending order
349/// - The corresponding eigenvectors
350#[allow(dead_code)]
351fn solve_tridiagonal_eigenproblem<T>(
352    alpha: &[T],
353    beta: &[T],
354    numeigenvalues: usize,
355) -> SparseResult<(Vec<T>, Vec<Vec<T>>)>
356where
357    T: Float
358        + SparseElement
359        + Debug
360        + Copy
361        + Add<Output = T>
362        + Sub<Output = T>
363        + Mul<Output = T>
364        + Div<Output = T>,
365{
366    let n = alpha.len();
367    if n == 0 {
368        return Err(SparseError::ValueError(
369            "Empty tridiagonal matrix".to_string(),
370        ));
371    }
372
373    if beta.len() != n - 1 {
374        return Err(SparseError::DimensionMismatch {
375            expected: n - 1,
376            found: beta.len(),
377        });
378    }
379
380    // For small matrices, use a simple algorithm for all eigenvalues
381    if n <= 3 {
382        return solve_small_tridiagonal(alpha, beta, numeigenvalues);
383    }
384
385    // For larger matrices, use the QL algorithm with implicit shifts
386    // This is a simplified implementation and could be optimized further
387
388    // Clone the diagonal and off-diagonal elements
389    let mut d = alpha.to_vec();
390    let mut e = beta.to_vec();
391    e.push(T::sparse_zero()); // Add a zero at the end
392
393    // Allocate space for eigenvectors
394    let mut z = vec![vec![T::sparse_zero(); n]; n];
395    #[allow(clippy::needless_range_loop)]
396    for i in 0..n {
397        z[i][i] = T::sparse_one(); // Initialize with identity matrix
398    }
399
400    // Run the QL algorithm with implicit shifts
401    for l in 0..n {
402        let mut iter = 0;
403        let max_iter = 30; // Typical value
404
405        loop {
406            // Look for a small off-diagonal element
407            let mut m = l;
408            while m < n - 1 {
409                if e[m].abs() <= T::from(1e-12).unwrap() * (d[m].abs() + d[m + 1].abs()) {
410                    break;
411                }
412                m += 1;
413            }
414
415            if m == l {
416                // No more work for this eigenvalue
417                break;
418            }
419
420            if iter >= max_iter {
421                // Too many iterations, return error
422                return Err(SparseError::IterativeSolverFailure(
423                    "QL algorithm did not converge".to_string(),
424                ));
425            }
426
427            let g = (d[l + 1] - d[l]) * T::from(0.5).unwrap() / e[l];
428            let r = (g * g + T::sparse_one()).sqrt();
429            let mut g = d[m] - d[l] + e[l] / (g + if g >= T::sparse_zero() { r } else { -r });
430
431            let mut s = T::sparse_one();
432            let mut c = T::sparse_one();
433            let mut p = T::sparse_zero();
434
435            let mut i = m - 1;
436            while i >= l && i < n {
437                // Handle unsigned underflow
438                let f = s * e[i];
439                let b = c * e[i];
440
441                // Compute the Givens rotation
442                let r = (f * f + g * g).sqrt();
443                e[i + 1] = r;
444
445                if SparseElement::is_zero(&r) {
446                    // Avoid division by zero
447                    d[i + 1] = d[i + 1] - p;
448                    e[m] = T::sparse_zero();
449                    break;
450                }
451
452                s = f / r;
453                c = g / r;
454
455                let _h = g * p;
456                p = s * (d[i] - d[i + 1]) + c * b;
457                d[i + 1] = d[i + 1] + p;
458                g = c * s - b;
459
460                // Update eigenvectors
461                #[allow(clippy::needless_range_loop)]
462                for k in 0..n {
463                    let t = z[k][i + 1];
464                    z[k][i + 1] = s * z[k][i] + c * t;
465                    z[k][i] = c * z[k][i] - s * t;
466                }
467
468                if i == 0 {
469                    break;
470                }
471                i -= 1;
472            }
473
474            if (i as i32) < (l as i32) || i >= n {
475                // Handle the case of i becoming invalid after decrement
476                break;
477            }
478
479            if SparseElement::is_zero(&r) {
480                if abs_diff_eq!(m, l + 1) {
481                    // Special case for m == l + 1
482                    break;
483                }
484                d[l] = d[l] - p;
485                e[l] = g;
486                e[m - 1] = T::sparse_zero();
487            }
488
489            iter += 1;
490        }
491    }
492
493    // Sort eigenvalues and eigenvectors in descending order
494    let mut indices: Vec<usize> = (0..n).collect();
495    indices.sort_by(|&i, &j| d[j].partial_cmp(&d[i]).unwrap_or(std::cmp::Ordering::Equal));
496
497    let mut sortedeigenvalues = Vec::with_capacity(numeigenvalues);
498    let mut sorted_eigenvectors = Vec::with_capacity(numeigenvalues);
499
500    #[allow(clippy::needless_range_loop)]
501    for k in 0..numeigenvalues.min(n) {
502        let idx = indices[k];
503        sortedeigenvalues.push(d[idx]);
504
505        let mut eigenvector = Vec::with_capacity(n);
506        #[allow(clippy::needless_range_loop)]
507        for i in 0..n {
508            eigenvector.push(z[i][idx]);
509        }
510        sorted_eigenvectors.push(eigenvector);
511    }
512
513    Ok((sortedeigenvalues, sorted_eigenvectors))
514}
515
516/// Solves a small (n ≤ 3) symmetric tridiagonal eigenvalue problem.
517#[allow(unused_assignments)]
518#[allow(dead_code)]
519fn solve_small_tridiagonal<T>(
520    alpha: &[T],
521    beta: &[T],
522    numeigenvalues: usize,
523) -> SparseResult<(Vec<T>, Vec<Vec<T>>)>
524where
525    T: Float
526        + SparseElement
527        + Debug
528        + Copy
529        + Add<Output = T>
530        + Sub<Output = T>
531        + Mul<Output = T>
532        + Div<Output = T>,
533{
534    let n = alpha.len();
535
536    if n == 1 {
537        // 1x1 case - just return the single value
538        return Ok((vec![alpha[0]], vec![vec![T::sparse_one()]]));
539    }
540
541    if n == 2 {
542        // 2x2 case - direct formula
543        let a = alpha[0];
544        let b = alpha[1];
545        let c = beta[0];
546
547        let trace = a + b;
548        let det = a * b - c * c;
549
550        // Calculate eigenvalues
551        let discriminant = (trace * trace - T::from(4.0).unwrap() * det).sqrt();
552        let lambda1 = (trace + discriminant) * T::from(0.5).unwrap();
553        let lambda2 = (trace - discriminant) * T::from(0.5).unwrap();
554
555        // Sort in descending order
556        let (lambda1, lambda2) = if lambda1 >= lambda2 {
557            (lambda1, lambda2)
558        } else {
559            (lambda2, lambda1)
560        };
561
562        // Calculate eigenvectors
563        let mut v1 = vec![T::sparse_zero(); 2];
564        let mut v2 = vec![T::sparse_zero(); 2];
565
566        if !SparseElement::is_zero(&c) {
567            v1[0] = c;
568            v1[1] = lambda1 - a;
569
570            v2[0] = c;
571            v2[1] = lambda2 - a;
572
573            // Normalize
574            let norm1 = (v1[0] * v1[0] + v1[1] * v1[1]).sqrt();
575            let norm2 = (v2[0] * v2[0] + v2[1] * v2[1]).sqrt();
576
577            if !SparseElement::is_zero(&norm1) {
578                v1[0] = v1[0] / norm1;
579                v1[1] = v1[1] / norm1;
580            }
581
582            if !SparseElement::is_zero(&norm2) {
583                v2[0] = v2[0] / norm2;
584                v2[1] = v2[1] / norm2;
585            }
586        } else {
587            // c is zero - diagonal matrix case
588            if a >= b {
589                v1[0] = T::sparse_one();
590                v1[1] = T::sparse_zero();
591
592                v2[0] = T::sparse_zero();
593                v2[1] = T::sparse_one();
594            } else {
595                v1[0] = T::sparse_zero();
596                v1[1] = T::sparse_one();
597
598                v2[0] = T::sparse_one();
599                v2[1] = T::sparse_zero();
600            }
601        }
602
603        let mut eigenvalues = vec![lambda1, lambda2];
604        let mut eigenvectors = vec![v1, v2];
605
606        // Return only the requested number of eigenvalues
607        eigenvalues.truncate(numeigenvalues);
608        eigenvectors.truncate(numeigenvalues);
609
610        return Ok((eigenvalues, eigenvectors));
611    }
612
613    if n == 3 {
614        // 3x3 case - use characteristic polynomial
615        let a = alpha[0];
616        let b = alpha[1];
617        let c = alpha[2];
618        let d = beta[0];
619        let e = beta[1];
620
621        // Characteristic polynomial coefficients
622        let p = -(a + b + c);
623        let q = a * b + a * c + b * c - d * d - e * e;
624        let r = -(a * b * c - a * e * e - c * d * d);
625
626        // Solve the cubic equation using the Vieta formulas
627        let eigenvalues = solve_cubic(p, q, r)?;
628
629        // Sort eigenvalues in descending order
630        let mut sortedeigenvalues = eigenvalues.clone();
631        sortedeigenvalues.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
632
633        // Compute eigenvectors
634        let mut eigenvectors = Vec::with_capacity(sortedeigenvalues.len());
635
636        for &lambda in &sortedeigenvalues[0..numeigenvalues.min(3)] {
637            // For each eigenvalue, construct the resulting linear system
638            // (A - lambda*I)v = 0, and solve for v
639
640            // Build the matrix (A - lambda*I)
641            let mut m00 = a - lambda;
642            let mut m01 = d;
643            let m02 = T::sparse_zero();
644
645            let mut m10 = d;
646            let mut m11 = b - lambda;
647            let mut m12 = e;
648
649            let m20 = T::sparse_zero();
650            let mut m21 = e;
651            let mut m22 = c - lambda;
652
653            // Find the largest absolute row to use as pivot
654            let r0_norm = (m00 * m00 + m01 * m01 + m02 * m02).sqrt();
655            let r1_norm = (m10 * m10 + m11 * m11 + m12 * m12).sqrt();
656            let r2_norm = (m20 * m20 + m21 * m21 + m22 * m22).sqrt();
657
658            let mut v = vec![T::sparse_zero(); 3];
659
660            if r0_norm >= r1_norm && r0_norm >= r2_norm && !SparseElement::is_zero(&r0_norm) {
661                // Use first row as pivot
662                let scale = T::sparse_one() / r0_norm;
663                m00 = m00 * scale;
664                m01 = m01 * scale;
665
666                // Eliminate first variable from second row
667                let factor = m10 / m00;
668                m11 = m11 - factor * m01;
669                m12 = m12 - factor * m02;
670
671                // Eliminate first variable from third row
672                let factor = m20 / m00;
673                m21 = m21 - factor * m01;
674                m22 = m22 - factor * m02;
675
676                // Back-substitute
677                v[2] = T::sparse_one(); // Set last component to 1
678                v[1] = -m12 * v[2] / m11;
679                v[0] = -(m01 * v[1] + m02 * v[2]) / m00;
680            } else if r1_norm >= r0_norm && r1_norm >= r2_norm && !SparseElement::is_zero(&r1_norm)
681            {
682                // Use second row as pivot
683                let scale = T::sparse_one() / r1_norm;
684                m10 = m10 * scale;
685                m11 = m11 * scale;
686                m12 = m12 * scale;
687
688                // Back-substitute
689                v[2] = T::sparse_one(); // Set last component to 1
690                v[0] = -m02 * v[2] / m00;
691                v[1] = -(m10 * v[0] + m12 * v[2]) / m11;
692            } else if !SparseElement::is_zero(&r2_norm) {
693                // Use third row as pivot
694                v[0] = T::sparse_one(); // Set first component to 1
695                v[1] = T::sparse_zero();
696                v[2] = T::sparse_zero();
697            } else {
698                // Degenerate case - just use unit vector
699                v[0] = T::sparse_one();
700                v[1] = T::sparse_zero();
701                v[2] = T::sparse_zero();
702            }
703
704            // Normalize the eigenvector
705            let norm = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
706            if !SparseElement::is_zero(&norm) {
707                v[0] = v[0] / norm;
708                v[1] = v[1] / norm;
709                v[2] = v[2] / norm;
710            }
711
712            eigenvectors.push(v);
713        }
714
715        // Return only the requested number of eigenvalues
716        sortedeigenvalues.truncate(numeigenvalues);
717        eigenvectors.truncate(numeigenvalues);
718
719        return Ok((sortedeigenvalues, eigenvectors));
720    }
721
722    // For n > 3, fallback to general algorithm (not implemented here)
723    Err(SparseError::ValueError(
724        "Tridiagonal eigenvalue problem for n > 3 not implemented".to_string(),
725    ))
726}
727
728/// Solves a cubic equation ax³ + bx² + cx + d = 0
729/// using Cardano's formula.
730fn solve_cubic<T>(p: T, q: T, r: T) -> SparseResult<Vec<T>>
731where
732    T: Float
733        + SparseElement
734        + Debug
735        + Copy
736        + Add<Output = T>
737        + Sub<Output = T>
738        + Mul<Output = T>
739        + Div<Output = T>,
740{
741    // The equation is x³ + px² + qx + r = 0
742
743    // Substitute x = y - p/3 to eliminate the quadratic term
744    let p_over_3 = p / T::from(3.0).unwrap();
745    let q_new = q - p * p / T::from(3.0).unwrap();
746    let r_new = r - p * q / T::from(3.0).unwrap()
747        + T::from(2.0).unwrap() * p * p * p / T::from(27.0).unwrap();
748
749    // Now solve y³ + q_new * y + r_new = 0
750    let discriminant =
751        -(T::from(4.0).unwrap() * q_new * q_new * q_new + T::from(27.0).unwrap() * r_new * r_new);
752
753    if discriminant > T::sparse_zero() {
754        // Three real roots
755        let theta = ((T::from(3.0).unwrap() * r_new) / (T::from(2.0).unwrap() * q_new)
756            * (-T::from(3.0).unwrap() / q_new).sqrt())
757        .acos();
758        let sqrt_term = T::from(2.0).unwrap() * (-q_new / T::from(3.0).unwrap()).sqrt();
759
760        let y1 = sqrt_term * (theta / T::from(3.0).unwrap()).cos();
761        let y2 = sqrt_term
762            * ((theta + T::from(2.0).unwrap() * T::from(std::f64::consts::PI).unwrap())
763                / T::from(3.0).unwrap())
764            .cos();
765        let y3 = sqrt_term
766            * ((theta + T::from(4.0).unwrap() * T::from(std::f64::consts::PI).unwrap())
767                / T::from(3.0).unwrap())
768            .cos();
769
770        let x1 = y1 - p_over_3;
771        let x2 = y2 - p_over_3;
772        let x3 = y3 - p_over_3;
773
774        Ok(vec![x1, x2, x3])
775    } else {
776        // One real root
777        let u = (-r_new / T::from(2.0).unwrap() + (discriminant / T::from(-108.0).unwrap()).sqrt())
778            .cbrt();
779        let v = if SparseElement::is_zero(&u) {
780            T::sparse_zero()
781        } else {
782            -q_new / (T::from(3.0).unwrap() * u)
783        };
784
785        let y = u + v;
786        let x = y - p_over_3;
787
788        Ok(vec![x])
789    }
790}
791
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sym_csr::SymCsrMatrix;

    /// The 2x2 matrix [[2, 1], [1, 2]] in symmetric CSR form.
    ///
    /// Only the lower triangle is stored: row 0 holds (0,0); row 1 holds (1,0) and (1,1).
    fn symmetric_2x2() -> SymCsrMatrix<f64> {
        let values = vec![2.0, 1.0, 2.0];
        let row_ptr = vec![0, 1, 3];
        let col_idx = vec![0, 0, 1];
        SymCsrMatrix::new(values, row_ptr, col_idx, (2, 2)).unwrap()
    }

    #[test]
    fn test_lanczos_simple() {
        let matrix = symmetric_2x2();
        let options = LanczosOptions {
            max_iter: 100,
            max_subspace_size: 2,
            numeigenvalues: 1,
            ..LanczosOptions::default()
        };

        let outcome = lanczos(&matrix, &options, None).unwrap();

        assert!(outcome.converged);
        assert_eq!(outcome.eigenvalues.len(), 1);
        // The solver must at least produce a finite eigenvalue.
        assert!(outcome.eigenvalues[0].is_finite());
    }

    #[test]
    fn test_tridiagonal_solver_2x2() {
        let diagonal = [2.0_f64, 3.0];
        let off_diagonal = [1.0_f64];

        let (values, _vectors) =
            solve_tridiagonal_eigenproblem(&diagonal[..], &off_diagonal[..], 2).unwrap();

        assert_eq!(values.len(), 2);
        // Results come back largest first.
        assert!(values[0] >= values[1]);
    }

    #[test]
    fn test_solve_cubic() {
        // x^3 - 6x^2 + 11x - 6 = (x - 1)(x - 2)(x - 3)
        let mut found = solve_cubic(-6.0_f64, 11.0, -6.0).unwrap();
        assert_eq!(found.len(), 3);

        found.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
        for (root, expected) in found.iter().zip([1.0_f64, 2.0, 3.0].iter()) {
            assert!((root - expected).abs() < 1e-10);
        }
    }
}