// single_svdlib/lanczos/mod.rs

pub mod masked;

use crate::error::SvdLibError;
use crate::{Diagnostics, SMat, SvdFloat, SvdRec};
use ndarray::{Array, Array2};
use num_traits::real::Real;
use num_traits::{Float, FromPrimitive, One, Zero};
use rand::rngs::StdRng;
use rand::{thread_rng, Rng, SeedableRng};
use rayon::iter::IndexedParallelIterator;
use rayon::iter::ParallelIterator;
use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator};
use std::fmt::Debug;
use std::iter::Sum;
use std::mem;
use std::ops::{AddAssign, MulAssign, Neg, SubAssign};
/// SVD at full dimensionality, calls `svdLAS2` with the highlighted defaults
///
/// svdLAS2(A, `0`, `0`, `&[-1.0e-30, 1.0e-30]`, `1.0e-6`, `0`)
///
/// # Parameters
/// - A: Sparse matrix
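///
/// # Example
///
/// A minimal sketch, assuming this crate is imported as `single_svdlib` (the
/// `nalgebra_sparse` matrix types implement [`SMat`] at the bottom of this module):
///
/// ```ignore
/// use nalgebra_sparse::coo::CooMatrix;
///
/// let mut a = CooMatrix::<f64>::new(3, 3);
/// a.push(0, 0, 3.0);
/// a.push(1, 1, 2.0);
/// a.push(2, 2, 1.0);
///
/// let svd = single_svdlib::lanczos::svd(&a)?;
/// // svd.u, svd.s and svd.vt hold the factorization; svd.d is the rank found
/// ```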
pub fn svd<T, M>(a: &M) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let eps_small = T::from_f64(-1.0e-30).unwrap();
    let eps_large = T::from_f64(1.0e-30).unwrap();
    let kappa = T::from_f64(1.0e-6).unwrap();
    svd_las2(a, 0, 0, &[eps_small, eps_large], kappa, 0)
}

/// SVD at desired dimensionality, calls `svdLAS2` with the highlighted defaults
///
/// svdLAS2(A, dimensions, `0`, `&[-1.0e-30, 1.0e-30]`, `1.0e-6`, `0`)
///
/// # Parameters
/// - A: Sparse matrix
/// - dimensions: Upper limit of desired number of dimensions, bounded by the matrix shape
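///
/// A short sketch (crate path assumed), truncating to at most `k` triplets:
///
/// ```ignore
/// let svd = single_svdlib::lanczos::svd_dim(&a, 2)?; // k = 2
/// assert!(svd.d <= 2);
/// ```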
pub fn svd_dim<T, M>(a: &M, dimensions: usize) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let eps_small = T::from_f64(-1.0e-30).unwrap();
    let eps_large = T::from_f64(1.0e-30).unwrap();
    let kappa = T::from_f64(1.0e-6).unwrap();

    svd_las2(a, dimensions, 0, &[eps_small, eps_large], kappa, 0)
}

/// SVD at desired dimensionality with supplied seed, calls `svdLAS2` with the highlighted defaults
///
/// svdLAS2(A, dimensions, `0`, `&[-1.0e-30, 1.0e-30]`, `1.0e-6`, random_seed)
///
/// # Parameters
/// - A: Sparse matrix
/// - dimensions: Upper limit of desired number of dimensions, bounded by the matrix shape
/// - random_seed: A supplied seed `if > 0`, otherwise an internal seed will be generated
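///
/// A hedged sketch (crate path assumed): a fixed seed makes the random Lanczos
/// starting vector, and hence the decomposition, reproducible across runs:
///
/// ```ignore
/// let s1 = single_svdlib::lanczos::svd_dim_seed(&a, 2, 42)?;
/// let s2 = single_svdlib::lanczos::svd_dim_seed(&a, 2, 42)?;
/// assert_eq!(s1.s, s2.s); // same seed, same singular values
/// ```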
pub fn svd_dim_seed<T, M>(
    a: &M,
    dimensions: usize,
    random_seed: u32,
) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let eps_small = T::from_f64(-1.0e-30).unwrap();
    let eps_large = T::from_f64(1.0e-30).unwrap();
    let kappa = T::from_f64(1.0e-6).unwrap();

    svd_las2(
        a,
        dimensions,
        0,
        &[eps_small, eps_large],
        kappa,
        random_seed,
    )
}

/// Compute a singular value decomposition
///
/// # Parameters
///
/// - A: Sparse matrix
/// - dimensions: Upper limit of desired number of dimensions (0 = max),
///       where "max" is a value bounded by the matrix shape, the smaller of
///       the matrix rows or columns, e.g. `A.nrows().min(A.ncols())`
/// - iterations: Upper limit of desired number of Lanczos steps (0 = max),
///       where "max" is a value bounded by the matrix shape, the smaller of
///       the matrix rows or columns, e.g. `A.nrows().min(A.ncols())`.
///       `iterations` must also be in the range [`dimensions`, `A.nrows().min(A.ncols())`]
/// - end_interval: Left, right end of interval containing unwanted eigenvalues,
///       typically small values centered around zero, e.g. `[-1.0e-30, 1.0e-30]`
/// - kappa: Relative accuracy of Ritz values acceptable as eigenvalues, e.g. `1.0e-6`
/// - random_seed: A supplied seed `if > 0`, otherwise an internal seed will be generated
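///
/// # Example
///
/// A sketch with every parameter spelled out (the values shown are the defaults
/// used by [`svd`]; crate path assumed):
///
/// ```ignore
/// let svd = single_svdlib::lanczos::svd_las2(
///     &a,                   // sparse matrix implementing `SMat`
///     2,                    // want at most 2 singular triplets
///     0,                    // 0 = let the routine pick the Lanczos step limit
///     &[-1.0e-30, 1.0e-30], // interval of unwanted eigenvalues
///     1.0e-6,               // kappa: relative accuracy of accepted Ritz values
///     0,                    // 0 = random seed chosen internally
/// )?;
/// ```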
pub fn svd_las2<T, M>(
    a: &M,
    dimensions: usize,
    iterations: usize,
    end_interval: &[T; 2],
    kappa: T,
    random_seed: u32,
) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let random_seed = if random_seed > 0 {
        random_seed
    } else {
        thread_rng().gen()
    };

    let min_nrows_ncols = a.nrows().min(a.ncols());

    let dimensions = match dimensions {
        n if n == 0 || n > min_nrows_ncols => min_nrows_ncols,
        _ => dimensions,
    };

    let iterations = match iterations {
        n if n == 0 || n > min_nrows_ncols => min_nrows_ncols,
        n if n < dimensions => dimensions,
        _ => iterations,
    };

    if dimensions < 2 {
        return Err(SvdLibError::Las2Error(format!(
            "svd_las2: insufficient dimensions: {dimensions}"
        )));
    }

    assert!(dimensions > 1 && dimensions <= min_nrows_ncols);
    assert!(iterations >= dimensions && iterations <= min_nrows_ncols);

    // work with the transpose when the matrix is distinctly wider than tall
    let transposed = (a.ncols() as f64) >= ((a.nrows() as f64) * 1.2);
    let nrows = if transposed { a.ncols() } else { a.nrows() };
    let ncols = if transposed { a.nrows() } else { a.ncols() };

    let mut wrk = WorkSpace::new(nrows, ncols, transposed, iterations)?;
    let mut store = Store::new(ncols)?;

    let mut neig = 0;
    let steps = lanso(
        a,
        dimensions,
        iterations,
        end_interval,
        &mut wrk,
        &mut neig,
        &mut store,
        random_seed,
    )?;

    let kappa = kappa.abs().max(T::eps34());
    let mut r = ritvec(a, dimensions, kappa, &mut wrk, steps, neig, &mut store)?;

    if transposed {
        mem::swap(&mut r.Ut, &mut r.Vt);
    }

    Ok(SvdRec {
        // Dimensionality (number of Ut,Vt rows & length of S)
        d: r.d,
        u: Array2::from_shape_vec((r.d, r.Ut.cols), r.Ut.value)?,
        s: Array::from_shape_vec(r.d, r.S)?,
        vt: Array2::from_shape_vec((r.d, r.Vt.cols), r.Vt.value)?,
        diagnostics: Diagnostics {
            non_zero: a.nnz(),
            dimensions,
            iterations,
            transposed,
            lanczos_steps: steps + 1,
            ritz_values_stabilized: neig,
            significant_values: r.d,
            singular_values: r.nsig,
            end_interval: *end_interval,
            kappa,
            random_seed,
        },
    })
}

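// Maximum number of the initial Lanczos vectors kept aside (via `storp`/`retrp`)
// for the extended re-orthogonalization in `lanczos_step`.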
const MAXLL: usize = 2;

#[derive(Debug, Clone, PartialEq)]
struct Store<T: Float> {
    n: usize,
    vecs: Vec<Vec<T>>,
}

impl<T: Float + Zero + Clone> Store<T> {
    fn new(n: usize) -> Result<Self, SvdLibError> {
        Ok(Self { n, vecs: vec![] })
    }

    /// store the Lanczos vector `q[idx]` (kept after the first `MAXLL` slots)
    fn storq(&mut self, idx: usize, v: &[T]) {
        while idx + MAXLL >= self.vecs.len() {
            self.vecs.push(vec![T::zero(); self.n]);
        }
        self.vecs[idx + MAXLL].copy_from_slice(v);
    }

    /// store the auxiliary vector `p[idx]` (only the first `MAXLL` are used)
    fn storp(&mut self, idx: usize, v: &[T]) {
        while idx >= self.vecs.len() {
            self.vecs.push(vec![T::zero(); self.n]);
        }
        self.vecs[idx].copy_from_slice(v);
    }

    /// retrieve the Lanczos vector `q[idx]`
    fn retrq(&mut self, idx: usize) -> &[T] {
        &self.vecs[idx + MAXLL]
    }

    /// retrieve the auxiliary vector `p[idx]`
    fn retrp(&mut self, idx: usize) -> &[T] {
        &self.vecs[idx]
    }
}

#[derive(Debug, Clone, PartialEq)]
struct WorkSpace<T: Float> {
    nrows: usize,
    ncols: usize,
    transposed: bool,
    w0: Vec<T>,     // workspace 0
    w1: Vec<T>,     // workspace 1
    w2: Vec<T>,     // workspace 2
    w3: Vec<T>,     // workspace 3
    w4: Vec<T>,     // workspace 4
    w5: Vec<T>,     // workspace 5
    alf: Vec<T>,    // array to hold diagonal of the tridiagonal matrix T
    eta: Vec<T>,    // orthogonality estimate of Lanczos vectors at step j
    oldeta: Vec<T>, // orthogonality estimate of Lanczos vectors at step j-1
    bet: Vec<T>,    // array to hold off-diagonal of T
    bnd: Vec<T>,    // array to hold the error bounds
    ritz: Vec<T>,   // array to hold the ritz values
    temp: Vec<T>,   // array to hold the temp values
}

impl<T: Float + Zero + FromPrimitive> WorkSpace<T> {
    fn new(
        nrows: usize,
        ncols: usize,
        transposed: bool,
        iterations: usize,
    ) -> Result<Self, SvdLibError> {
        Ok(Self {
            nrows,
            ncols,
            transposed,
            w0: vec![T::zero(); ncols],
            w1: vec![T::zero(); ncols],
            w2: vec![T::zero(); ncols],
            w3: vec![T::zero(); ncols],
            w4: vec![T::zero(); ncols],
            w5: vec![T::zero(); ncols],
            alf: vec![T::zero(); iterations],
            eta: vec![T::zero(); iterations],
            oldeta: vec![T::zero(); iterations],
            bet: vec![T::zero(); 1 + iterations],
            ritz: vec![T::zero(); 1 + iterations],
            bnd: vec![T::from_f64(f64::MAX).unwrap(); 1 + iterations],
            temp: vec![T::zero(); nrows],
        })
    }
}

/* Row-major dense matrix.  Rows are consecutive vectors. */
#[derive(Debug, Clone, PartialEq)]
struct DMat<T: Float> {
    cols: usize,
    value: Vec<T>,
}

#[allow(non_snake_case)]
#[derive(Debug, Clone, PartialEq)]
struct SVDRawRec<T: Float> {
    d: usize,
    nsig: usize,
    Ut: DMat<T>,
    S: Vec<T>,
    Vt: DMat<T>,
}

fn compare<T: SvdFloat>(computed: T, expected: T) -> bool {
    T::compare(computed, expected)
}

/* Sorts array1 and array2 (in tandem) into increasing order of array1 */
fn insert_sort<T: PartialOrd>(n: usize, array1: &mut [T], array2: &mut [T]) {
    for i in 1..n {
        for j in (1..=i).rev() {
            if array1[j - 1] <= array1[j] {
                break;
            }
            array1.swap(j - 1, j);
            array2.swap(j - 1, j);
        }
    }
}

#[allow(non_snake_case)]
#[rustfmt::skip]
fn svd_opb<T: Float>(A: &dyn SMat<T>, x: &[T], y: &mut [T], temp: &mut [T], transposed: bool) {
    let nrows = if transposed { A.ncols() } else { A.nrows() };
    let ncols = if transposed { A.nrows() } else { A.ncols() };
    assert_eq!(x.len(), ncols, "svd_opb: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
    assert_eq!(y.len(), ncols, "svd_opb: y must be A.ncols() in length, y = {}, A.ncols = {}", y.len(), ncols);
    assert_eq!(temp.len(), nrows, "svd_opb: temp must be A.nrows() in length, temp = {}, A.nrows = {}", temp.len(), nrows);
    A.svd_opa(x, temp, transposed); // temp = (A * x)
    A.svd_opa(temp, y, !transposed); // y = A' * (A * x) = A' * temp
}

// constant times a vector plus a vector
fn svd_daxpy<T: Float + AddAssign + Send + Sync>(da: T, x: &[T], y: &mut [T]) {
    if x.len() < 1000 {
        for (xval, yval) in x.iter().zip(y.iter_mut()) {
            *yval += da * *xval
        }
    } else {
        y.par_iter_mut()
            .zip(x.par_iter())
            .for_each(|(yval, xval)| *yval += da * *xval);
    }
}

// finds the index of the element having maximum absolute value
fn svd_idamax<T: Float>(n: usize, x: &[T]) -> usize {
    assert!(n > 0, "svd_idamax: n must be > 0");

    match n {
        1 => 0,
        _ => {
            let mut imax = 0;
            for (i, xval) in x.iter().enumerate().take(n).skip(1) {
                if xval.abs() > x[imax].abs() {
                    imax = i;
                }
            }
            imax
        }
    }
}

// returns |a| if b is positive; else fsign returns -|a|
fn svd_fsign<T: Float>(a: T, b: T) -> T {
    match (a >= T::zero() && b >= T::zero()) || (a < T::zero() && b < T::zero()) {
        true => a,
        false => -a,
    }
}

// finds sqrt(a^2 + b^2) without overflow or destructive underflow
fn svd_pythag<T: SvdFloat + FromPrimitive>(a: T, b: T) -> T {
    match a.abs().max(b.abs()) {
        n if n > T::zero() => {
            let mut p = n;
            let mut r = (a.abs().min(b.abs()) / p).powi(2);
            let four = T::from_f64(4.0).unwrap();
            let two = T::from_f64(2.0).unwrap();
            let mut t = four + r;
            while !compare(t, four) {
                let s = r / t;
                let u = T::one() + two * s;
                p = p * u;
                // r tracks (min/max)^2 for the current iterate; it shrinks
                // by the factor (s/u)^2 at every step
                r = r * (s / u).powi(2);
                t = four + r;
            }
            p
        }
        _ => T::zero(),
    }
}

// dot product of two vectors
fn svd_ddot<T: Float + Sum<T> + Send + Sync>(x: &[T], y: &[T]) -> T {
    if x.len() < 1000 {
        x.iter().zip(y).map(|(a, b)| *a * *b).sum()
    } else {
        x.par_iter().zip(y.par_iter()).map(|(a, b)| *a * *b).sum()
    }
}

// norm (length) of a vector
fn svd_norm<T: Float + Sum<T> + Send + Sync>(x: &[T]) -> T {
    svd_ddot(x, x).sqrt()
}

// scales an input vector 'x' by a constant, storing in 'y'
fn svd_datx<T: Float + Sum<T>>(d: T, x: &[T], y: &mut [T]) {
    for (i, xval) in x.iter().enumerate() {
        y[i] = d * *xval;
    }
}

// scales an input vector 'x' by a constant, modifying 'x'
fn svd_dscal<T: Float + MulAssign + Send + Sync>(d: T, x: &mut [T]) {
    if x.len() < 1000 {
        for elem in x.iter_mut() {
            *elem *= d;
        }
    } else {
        x.par_iter_mut().for_each(|elem| {
            *elem *= d;
        });
    }
}

// copies the n elements of x starting at `offset` into y at the same
// offset, reversing their order
fn svd_dcopy<T: Float + Copy>(n: usize, offset: usize, x: &[T], y: &mut [T]) {
    if n > 0 {
        let start = n - 1;
        for i in 0..n {
            y[offset + start - i] = x[offset + i];
        }
    }
}

const MAX_IMTQLB_ITERATIONS: usize = 100;

fn imtqlb<T: SvdFloat>(
    n: usize,
    d: &mut [T],
    e: &mut [T],
    bnd: &mut [T],
    max_imtqlb: Option<usize>,
) -> Result<(), SvdLibError> {
    let max_imtqlb = max_imtqlb.unwrap_or(MAX_IMTQLB_ITERATIONS);
    if n == 1 {
        return Ok(());
    }

    let matrix_size_factor = T::from_f64((n as f64).sqrt()).unwrap();

    bnd[0] = T::one();
    let last = n - 1;
    for i in 1..=last {
        bnd[i] = T::zero();
        e[i - 1] = e[i];
    }
    e[last] = T::zero();

    let mut i = 0;

    let mut had_convergence_issues = false;

    for l in 0..=last {
        let mut iteration = 0;
        let mut p = d[l];
        let mut f = bnd[l];

        while iteration <= max_imtqlb {
            let mut m = l;
            while m < n {
                if m == last {
                    break;
                }

                // More forgiving convergence test for large/sparse matrices
                let test = d[m].abs() + d[m + 1].abs();
                // Scale tolerance with matrix size and magnitude
                let tol = T::epsilon()
                    * T::from_f64(100.0).unwrap()
                    * test.max(T::one())
                    * matrix_size_factor;

                if e[m].abs() <= tol {
                    break; // Convergence achieved for this element
                }
                m += 1;
            }

            if m == l {
                // Order the eigenvalues
                let mut exchange = true;
                if l > 0 {
                    i = l;
                    while i >= 1 && exchange {
                        if p < d[i - 1] {
                            d[i] = d[i - 1];
                            bnd[i] = bnd[i - 1];
                            i -= 1;
                        } else {
                            exchange = false;
                        }
                    }
                }
                if exchange {
                    i = 0;
                }
                d[i] = p;
                bnd[i] = f;
                iteration = max_imtqlb + 1; // Exit the loop
            } else {
                // Check if we've reached max iterations without convergence
                if iteration == max_imtqlb {
                    // Don't fail here: note the issue and continue with best estimates
                    had_convergence_issues = true;

                    // Set conservative error bounds for non-converged values
                    for idx in l..=m {
                        bnd[idx] = bnd[idx].max(T::from_f64(0.1).unwrap());
                    }

                    // Force "convergence" by zeroing the problematic subdiagonal element
                    e[l] = T::zero();

                    // Break out of the iteration loop and move to next eigenvalue
                    break;
                }

                iteration += 1;
                // ........ form shift ........
                let two = T::from_f64(2.0).unwrap();
                let mut g = (d[l + 1] - p) / (two * e[l]);
                let mut r = svd_pythag(g, T::one());
                g = d[m] - p + e[l] / (g + svd_fsign(r, g));
                let mut s = T::one();
                let mut c = T::one();
                p = T::zero();

                assert!(m > 0, "imtqlb: expected 'm' to be non-zero");
                i = m - 1;
                let mut underflow = false;
                while !underflow && i >= l {
                    f = s * e[i];
                    let b = c * e[i];
                    r = svd_pythag(f, g);
                    e[i + 1] = r;

                    // More forgiving underflow detection for sparse matrices
                    if r < T::epsilon() * T::from_f64(1000.0).unwrap() * (f.abs() + g.abs()) {
                        underflow = true;
                        break;
                    }

                    // Safety check for division by very small numbers
                    if r.abs() < T::epsilon() * T::from_f64(100.0).unwrap() {
                        r = T::epsilon() * T::from_f64(100.0).unwrap() * svd_fsign(T::one(), r);
                    }

                    s = f / r;
                    c = g / r;
                    g = d[i + 1] - p;
                    r = (d[i] - g) * s + T::from_f64(2.0).unwrap() * c * b;
                    p = s * r;
                    d[i + 1] = g + p;
                    g = c * r - b;
                    f = bnd[i + 1];
                    bnd[i + 1] = s * bnd[i] + c * f;
                    bnd[i] = c * bnd[i] - s * f;
                    if i == 0 {
                        break;
                    }
                    i -= 1;
                }
                // ........ recover from underflow .........
                if underflow {
                    d[i + 1] -= p;
                } else {
                    d[l] -= p;
                    e[l] = g;
                }
                e[m] = T::zero();
            }
        }
    }
    if had_convergence_issues {
        eprintln!("Warning: imtqlb had some convergence issues but continued with best estimates. Results may have reduced accuracy.");
    }
    Ok(())
}

#[allow(non_snake_case)]
fn startv<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    step: usize,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<T, SvdLibError> {
    // seed the generator once, so every retry below draws a fresh vector
    let mut bytes = [0u8; 32];
    bytes[..4].copy_from_slice(&random_seed.to_le_bytes());
    let mut seeded_rng = StdRng::from_seed(bytes);

    // get initial vector; default is random
    let mut rnm2 = svd_ddot(&wrk.w0, &wrk.w0);
    for id in 0..3 {
        if id > 0 || step > 0 || compare(rnm2, T::zero()) {
            for val in wrk.w0.iter_mut() {
                *val = T::from_f64(seeded_rng.gen_range(-1.0..1.0)).unwrap();
            }
        }
        wrk.w3.copy_from_slice(&wrk.w0);

        // apply operator to put r in range (essential if m singular)
        svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
        wrk.w3.copy_from_slice(&wrk.w0);
        rnm2 = svd_ddot(&wrk.w3, &wrk.w3);
        if rnm2 > T::zero() {
            break;
        }
    }

    if rnm2 <= T::zero() {
        return Err(SvdLibError::StartvError(format!(
            "rnm2 <= 0.0, rnm2 = {rnm2:?}"
        )));
    }

    if step > 0 {
        for i in 0..step {
            let v = store.retrq(i);
            svd_daxpy(-svd_ddot(&wrk.w3, v), v, &mut wrk.w0);
        }

        // make sure q[step] is orthogonal to q[step-1]
        svd_daxpy(-svd_ddot(&wrk.w4, &wrk.w0), &wrk.w2, &mut wrk.w0);
        wrk.w3.copy_from_slice(&wrk.w0);

        rnm2 = match svd_ddot(&wrk.w3, &wrk.w3) {
            dot if dot <= T::eps() * rnm2 => T::zero(),
            dot => dot,
        }
    }
    Ok(rnm2.sqrt())
}

#[allow(non_snake_case)]
fn stpone<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<(T, T), SvdLibError> {
    // get initial vector; default is random
    let mut rnm = startv(A, wrk, 0, store, random_seed)?;
    if compare(rnm, T::zero()) {
        return Err(SvdLibError::StponeError("rnm == 0.0".to_string()));
    }

    // normalize starting vector
    svd_datx(rnm.recip(), &wrk.w0, &mut wrk.w1);
    svd_dscal(rnm.recip(), &mut wrk.w3);

    // take the first step
    svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
    wrk.alf[0] = svd_ddot(&wrk.w0, &wrk.w3);
    svd_daxpy(-wrk.alf[0], &wrk.w1, &mut wrk.w0);
    let t = svd_ddot(&wrk.w0, &wrk.w3);
    wrk.alf[0] += t;
    svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
    wrk.w4.copy_from_slice(&wrk.w0);
    rnm = svd_norm(&wrk.w4);
    let anorm = rnm + wrk.alf[0].abs();
    Ok((rnm, T::eps().sqrt() * anorm))
}

#[allow(non_snake_case)]
#[allow(clippy::too_many_arguments)]
fn lanczos_step<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    first: usize,
    last: usize,
    ll: &mut usize,
    enough: &mut bool,
    rnm: &mut T,
    tol: &mut T,
    store: &mut Store<T>,
) -> Result<usize, SvdLibError> {
    let eps1 = T::eps() * T::from_f64(wrk.ncols as f64).unwrap().sqrt();
    let mut j = first;
    let four = T::from_f64(4.0).unwrap();

    while j < last {
        mem::swap(&mut wrk.w1, &mut wrk.w2);
        mem::swap(&mut wrk.w3, &mut wrk.w4);

        store.storq(j - 1, &wrk.w2);
        if j - 1 < MAXLL {
            store.storp(j - 1, &wrk.w4);
        }
        wrk.bet[j] = *rnm;

        // restart if invariant subspace is found
        if compare(*rnm, T::zero()) {
            *rnm = startv(A, wrk, j, store, 0)?;
            if compare(*rnm, T::zero()) {
                *enough = true;
            }
        }

        if *enough {
            mem::swap(&mut wrk.w1, &mut wrk.w2);
            break;
        }

        // take a lanczos step
        svd_datx(rnm.recip(), &wrk.w0, &mut wrk.w1);
        svd_dscal(rnm.recip(), &mut wrk.w3);
        svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
        svd_daxpy(-*rnm, &wrk.w2, &mut wrk.w0);
        wrk.alf[j] = svd_ddot(&wrk.w0, &wrk.w3);
        svd_daxpy(-wrk.alf[j], &wrk.w1, &mut wrk.w0);

        // orthogonalize against initial lanczos vectors
        if j <= MAXLL && wrk.alf[j - 1].abs() > four * wrk.alf[j].abs() {
            *ll = j;
        }
        for i in 0..(j - 1).min(*ll) {
            let v1 = store.retrp(i);
            let t = svd_ddot(v1, &wrk.w0);
            let v2 = store.retrq(i);
            svd_daxpy(-t, v2, &mut wrk.w0);
            wrk.eta[i] = eps1;
            wrk.oldeta[i] = eps1;
        }

        // extended local reorthogonalization
        let t = svd_ddot(&wrk.w0, &wrk.w4);
        svd_daxpy(-t, &wrk.w2, &mut wrk.w0);
        if wrk.bet[j] > T::zero() {
            wrk.bet[j] += t;
        }
        let t = svd_ddot(&wrk.w0, &wrk.w3);
        svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
        wrk.alf[j] += t;
        wrk.w4.copy_from_slice(&wrk.w0);
        *rnm = svd_norm(&wrk.w4);
        let anorm = wrk.bet[j] + wrk.alf[j].abs() + *rnm;
        *tol = T::eps().sqrt() * anorm;

        // update the orthogonality bounds
        ortbnd(wrk, j, *rnm, eps1);

        // restore the orthogonality state when needed
        purge(wrk.ncols, *ll, wrk, j, rnm, *tol, store);
        if *rnm <= *tol {
            *rnm = T::zero();
        }
        j += 1;
    }
    Ok(j)
}

fn purge<T: SvdFloat>(
    n: usize,
    ll: usize,
    wrk: &mut WorkSpace<T>,
    step: usize,
    rnm: &mut T,
    tol: T,
    store: &mut Store<T>,
) {
    if step < ll + 2 {
        return;
    }

    let reps = T::eps().sqrt();
    let eps1 = T::eps() * T::from_f64(n as f64).unwrap().sqrt();

    let k = svd_idamax(step - (ll + 1), &wrk.eta) + ll;
    if wrk.eta[k].abs() > reps {
        let reps1 = eps1 / reps;
        let mut iteration = 0;
        let mut flag = true;
        while iteration < 2 && flag {
            if *rnm > tol {
                // bring in a lanczos vector t and orthogonalize both r and q against it
                let mut tq = T::zero();
                let mut tr = T::zero();
                for i in ll..step {
                    let v = store.retrq(i);
                    let t = svd_ddot(v, &wrk.w3);
                    tq += t.abs();
                    svd_daxpy(-t, v, &mut wrk.w1);
                    let t = svd_ddot(v, &wrk.w4);
                    tr += t.abs();
                    svd_daxpy(-t, v, &mut wrk.w0);
                }
                wrk.w3.copy_from_slice(&wrk.w1);
                let t = svd_ddot(&wrk.w0, &wrk.w3);
                tr += t.abs();
                svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
                wrk.w4.copy_from_slice(&wrk.w0);
                *rnm = svd_norm(&wrk.w4);
                if tq <= reps1 && tr <= *rnm * reps1 {
                    flag = false;
                }
            }
            iteration += 1;
        }
        for i in ll..=step {
            wrk.eta[i] = eps1;
            wrk.oldeta[i] = eps1;
        }
    }
}

fn ortbnd<T: SvdFloat>(wrk: &mut WorkSpace<T>, step: usize, rnm: T, eps1: T) {
    if step < 1 {
        return;
    }
    if !compare(rnm, T::zero()) && step > 1 {
        wrk.oldeta[0] = (wrk.bet[1] * wrk.eta[1] + (wrk.alf[0] - wrk.alf[step]) * wrk.eta[0]
            - wrk.bet[step] * wrk.oldeta[0])
            / rnm
            + eps1;
        if step > 2 {
            for i in 1..=step - 2 {
                wrk.oldeta[i] = (wrk.bet[i + 1] * wrk.eta[i + 1]
                    + (wrk.alf[i] - wrk.alf[step]) * wrk.eta[i]
                    + wrk.bet[i] * wrk.eta[i - 1]
                    - wrk.bet[step] * wrk.oldeta[i])
                    / rnm
                    + eps1;
            }
        }
    }
    wrk.oldeta[step - 1] = eps1;
    mem::swap(&mut wrk.oldeta, &mut wrk.eta);
    wrk.eta[step] = eps1;
}

fn error_bound<T: SvdFloat>(
    enough: &mut bool,
    endl: T,
    endr: T,
    ritz: &mut [T],
    bnd: &mut [T],
    step: usize,
    tol: T,
) -> usize {
    assert!(step > 0, "error_bound: expected 'step' to be non-zero");

    // massage error bounds for very close ritz values
    let mid = svd_idamax(step + 1, bnd);
    let sixteen = T::from_f64(16.0).unwrap();

    let mut i = ((step + 1) + (step - 1)) / 2;
    while i > mid + 1 {
        if (ritz[i - 1] - ritz[i]).abs() < T::eps34() * ritz[i].abs()
            && bnd[i] > tol
            && bnd[i - 1] > tol
        {
            bnd[i - 1] = (bnd[i].powi(2) + bnd[i - 1].powi(2)).sqrt();
            bnd[i] = T::zero();
        }
        i -= 1;
    }

    let mut i = ((step + 1) - (step - 1)) / 2;
    while i + 1 < mid {
        if (ritz[i + 1] - ritz[i]).abs() < T::eps34() * ritz[i].abs()
            && bnd[i] > tol
            && bnd[i + 1] > tol
        {
            bnd[i + 1] = (bnd[i].powi(2) + bnd[i + 1].powi(2)).sqrt();
            bnd[i] = T::zero();
        }
        i += 1;
    }

    // refine the error bounds
    let mut neig = 0;
    let mut gapl = ritz[step] - ritz[0];
    for i in 0..=step {
        let mut gap = gapl;
        if i < step {
            gapl = ritz[i + 1] - ritz[i];
        }
        gap = gap.min(gapl);
        if gap > bnd[i] {
            bnd[i] *= bnd[i] / gap;
        }
        if bnd[i] <= sixteen * T::eps() * ritz[i].abs() {
            neig += 1;
            if !*enough {
                *enough = endl < ritz[i] && ritz[i] < endr;
            }
        }
    }
    neig
}

fn imtql2<T: SvdFloat>(
    nm: usize,
    n: usize,
    d: &mut [T],
    e: &mut [T],
    z: &mut [T],
    max_imtqlb: Option<usize>,
) -> Result<(), SvdLibError> {
    let max_imtqlb = max_imtqlb.unwrap_or(MAX_IMTQLB_ITERATIONS);
    if n == 1 {
        return Ok(());
    }
    assert!(n > 1, "imtql2: expected 'n' to be > 1");
    let two = T::from_f64(2.0).unwrap();

    let last = n - 1;

    for i in 1..n {
        e[i - 1] = e[i];
    }
    e[last] = T::zero();

    let nnm = n * nm;
    for l in 0..n {
        let mut iteration = 0;

        // look for small sub-diagonal element
        while iteration <= max_imtqlb {
            let mut m = l;
            while m < n {
                if m == last {
                    break;
                }
                let test = d[m].abs() + d[m + 1].abs();
                if compare(test, test + e[m].abs()) {
                    break; // convergence = true;
                }
                m += 1;
            }
            if m == l {
                break;
            }

            // error -- no convergence to an eigenvalue after `max_imtqlb` iterations
            if iteration == max_imtqlb {
                return Err(SvdLibError::Imtql2Error(format!(
                    "imtql2 no convergence to an eigenvalue after {} iterations",
                    max_imtqlb
                )));
            }
            iteration += 1;

            // form shift
            let mut g = (d[l + 1] - d[l]) / (two * e[l]);
            let mut r = svd_pythag(g, T::one());
            g = d[m] - d[l] + e[l] / (g + svd_fsign(r, g));

            let mut s = T::one();
            let mut c = T::one();
            let mut p = T::zero();

            assert!(m > 0, "imtql2: expected 'm' to be non-zero");
            let mut i = m - 1;
            let mut underflow = false;
            while !underflow && i >= l {
                let mut f = s * e[i];
                let b = c * e[i];
                r = svd_pythag(f, g);
                e[i + 1] = r;
                if compare(r, T::zero()) {
                    underflow = true;
                } else {
                    s = f / r;
                    c = g / r;
                    g = d[i + 1] - p;
                    r = (d[i] - g) * s + two * c * b;
                    p = s * r;
                    d[i + 1] = g + p;
                    g = c * r - b;

                    // form vector
                    for k in (0..nnm).step_by(n) {
                        let index = k + i;
                        f = z[index + 1];
                        z[index + 1] = s * z[index] + c * f;
                        z[index] = c * z[index] - s * f;
                    }
                    if i == 0 {
                        break;
                    }
                    i -= 1;
                }
            } // end while (!underflow && i >= l)

            // ........ recover from underflow .........
            if underflow {
                d[i + 1] -= p;
            } else {
                d[l] -= p;
                e[l] = g;
            }
            e[m] = T::zero();
        }
    }

    // order the eigenvalues
    for l in 1..n {
        let i = l - 1;
        let mut k = i;
        let mut p = d[i];
        for (j, item) in d.iter().enumerate().take(n).skip(l) {
            if *item < p {
                k = j;
                p = *item;
            }
        }

        // ...and corresponding eigenvectors
        if k != i {
            d[k] = d[i];
            d[i] = p;
            for j in (0..nnm).step_by(n) {
                z.swap(j + i, j + k);
            }
        }
    }

    Ok(())
}

// rotates the array left by `x` positions in place by following the
// permutation cycles: afterwards a[i] == old a[(i + x) % a.len()]
fn rotate_array<T: Float + Copy>(a: &mut [T], x: usize) {
    let n = a.len();
    let mut j = 0;
    let mut start = 0;
    let mut t1 = a[0];

    for _ in 0..n {
        j = match j >= x {
            true => j - x,
            false => j + n - x,
        };

        let t2 = a[j];
        a[j] = t1;

        if j == start {
            j += 1;
            start = j;
            t1 = a[j];
        } else {
            t1 = t2;
        }
    }
}

#[allow(non_snake_case)]
fn ritvec<T: SvdFloat>(
    A: &dyn SMat<T>,
    dimensions: usize,
    kappa: T,
    wrk: &mut WorkSpace<T>,
    steps: usize,
    neig: usize,
    store: &mut Store<T>,
) -> Result<SVDRawRec<T>, SvdLibError> {
    let js = steps + 1;
    let jsq = js * js;

    let sparsity = T::one()
        - (T::from_usize(A.nnz()).unwrap()
            / (T::from_usize(A.nrows()).unwrap() * T::from_usize(A.ncols()).unwrap()));

    let epsilon = T::epsilon();
    let adaptive_eps = if sparsity > T::from_f64(0.99).unwrap() {
        // For very sparse matrices (>99%), use a more relaxed tolerance
        epsilon * T::from_f64(100.0).unwrap()
    } else if sparsity > T::from_f64(0.9).unwrap() {
        // For moderately sparse matrices (>90%), use a somewhat relaxed tolerance
        epsilon * T::from_f64(10.0).unwrap()
    } else {
        // For less sparse matrices, use standard epsilon
        epsilon
    };

    let max_iterations_imtql2 = if sparsity > T::from_f64(0.999).unwrap() {
        // Ultra sparse (>99.9%) - needs many more iterations
        Some(500)
    } else if sparsity > T::from_f64(0.99).unwrap() {
        // Very sparse (>99%) - needs more iterations
        Some(300)
    } else if sparsity > T::from_f64(0.9).unwrap() {
        // Moderately sparse (>90%) - needs somewhat more iterations
        Some(200)
    } else {
        // Default iterations for less sparse matrices
        Some(50)
    };

    let mut s = vec![T::zero(); jsq];
    // initialize s to an identity matrix
    for i in (0..jsq).step_by(js + 1) {
        s[i] = T::one();
    }

    // eigenvalues of the tridiagonal matrix T
    let mut ritz_values = vec![T::zero(); js];
    svd_dcopy(js, 0, &wrk.alf, &mut ritz_values);
    svd_dcopy(steps, 1, &wrk.bet, &mut wrk.w5);

    // on return from imtql2(), `ritz_values` contains eigenvalues in
    // ascending order and `s` contains the corresponding eigenvectors
    imtql2(
        js,
        js,
        &mut ritz_values,
        &mut wrk.w5,
        &mut s,
        max_iterations_imtql2,
    )?;

    let max_eigenvalue = ritz_values
        .iter()
        .fold(T::zero(), |max, &val| max.max(val.abs()));

    let adaptive_kappa = if sparsity > T::from_f64(0.99).unwrap() {
        // More relaxed kappa for very sparse matrices
        kappa * T::from_f64(10.0).unwrap()
    } else {
        kappa
    };

    let store_vectors: Vec<Vec<T>> = (0..js).map(|i| store.retrq(i).to_vec()).collect();

    let significant_indices: Vec<usize> = (0..js)
        .into_par_iter()
        .filter(|&k| {
            let relative_bound =
                adaptive_kappa * wrk.ritz[k].abs().max(max_eigenvalue * adaptive_eps);
            wrk.bnd[k] <= relative_bound && k + 1 > js - neig
        })
        .collect();

    let nsig = significant_indices.len();

    let mut vt_vectors: Vec<(usize, Vec<T>)> = significant_indices
        .into_par_iter()
        .map(|k| {
            let mut vec = vec![T::zero(); wrk.ncols];

            for i in 0..js {
                let idx = k * js + i;

                if s[idx].abs() > adaptive_eps {
                    for (j, item) in store_vectors[i].iter().enumerate().take(wrk.ncols) {
                        vec[j] += s[idx] * *item;
                    }
                }
            }

            (k, vec)
        })
        .collect();

    // Sort by k value to maintain original order
    vt_vectors.sort_by_key(|(k, _)| *k);

    // final dimension size
    let d = dimensions.min(nsig);
    let mut S = vec![T::zero(); d];
    let mut Ut = DMat {
        cols: wrk.nrows,
        value: vec![T::zero(); wrk.nrows * d],
    };
    let mut Vt = DMat {
        cols: wrk.ncols,
        value: vec![T::zero(); wrk.ncols * d],
    };

    // Fill Vt with the vectors we computed
    for (i, (_, vec)) in vt_vectors.into_iter().take(d).enumerate() {
        let vt_offset = i * Vt.cols;
        Vt.value[vt_offset..vt_offset + Vt.cols].copy_from_slice(&vec);
    }

    // Prepare for parallel computation of S and Ut
    let mut ab_products = Vec::with_capacity(d);
    let mut a_products = Vec::with_capacity(d);

    // First compute all matrix-vector products sequentially
    for i in 0..d {
        let vt_offset = i * Vt.cols;
        let vt_vec = &Vt.value[vt_offset..vt_offset + Vt.cols];

        let mut tmp_vec = vec![T::zero(); Vt.cols];
        let mut ut_vec = vec![T::zero(); wrk.nrows];

        // Matrix-vector products with A and A'A
        svd_opb(A, vt_vec, &mut tmp_vec, &mut wrk.temp, wrk.transposed);
        A.svd_opa(vt_vec, &mut ut_vec, wrk.transposed);

        ab_products.push(tmp_vec);
        a_products.push(ut_vec);
    }

    let results: Vec<(usize, T)> = (0..d)
        .into_par_iter()
        .map(|i| {
            let vt_offset = i * Vt.cols;
            let vt_vec = &Vt.value[vt_offset..vt_offset + Vt.cols];
            let tmp_vec = &ab_products[i];

            // Compute singular value
            let t = svd_ddot(vt_vec, tmp_vec);
            let sval = t.max(T::zero()).sqrt();

            (i, sval)
        })
        .collect();

    // Process results and scale the vectors
    for (i, sval) in results {
        S[i] = sval;
        let ut_offset = i * Ut.cols;
        let mut ut_vec = a_products[i].clone();

        if sval > adaptive_eps {
            svd_dscal(T::one() / sval, &mut ut_vec);
        } else {
            // guard against division by a vanishing singular value
            let dls = sval.max(adaptive_eps);
            let safe_scale = T::one() / dls;
            svd_dscal(safe_scale, &mut ut_vec);
        }

        // Copy to output
        Ut.value[ut_offset..ut_offset + Ut.cols].copy_from_slice(&ut_vec);
    }

    Ok(SVDRawRec {
        // Dimensionality (rank)
        d,
        // Significant values
        nsig,
        // DMat Ut  Transpose of left singular vectors. (d by m)
        //          The vectors are the rows of Ut.
        Ut,
        // Array of singular values. (length d)
        S,
        // DMat Vt  Transpose of right singular vectors. (d by n)
        //          The vectors are the rows of Vt.
        Vt,
    })
}

#[allow(non_snake_case)]
#[allow(clippy::too_many_arguments)]
fn lanso<T: SvdFloat>(
    A: &dyn SMat<T>,
    dim: usize,
    iterations: usize,
    end_interval: &[T; 2],
    wrk: &mut WorkSpace<T>,
    neig: &mut usize,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<usize, SvdLibError> {
    let sparsity = T::one()
        - (T::from_usize(A.nnz()).unwrap()
            / (T::from_usize(A.nrows()).unwrap() * T::from_usize(A.ncols()).unwrap()));
    let max_iterations_imtqlb = if sparsity > T::from_f64(0.999).unwrap() {
        // Ultra sparse (>99.9%) - needs many more iterations
        Some(500)
    } else if sparsity > T::from_f64(0.99).unwrap() {
        // Very sparse (>99%) - needs more iterations
        Some(300)
    } else if sparsity > T::from_f64(0.9).unwrap() {
        // Moderately sparse (>90%) - needs somewhat more iterations
        Some(100)
    } else {
        // Default iterations for less sparse matrices
        Some(50)
    };

    let epsilon = T::epsilon();
    let adaptive_eps = if sparsity > T::from_f64(0.99).unwrap() {
        // For very sparse matrices (>99%), use a more relaxed tolerance
        epsilon * T::from_f64(100.0).unwrap()
    } else if sparsity > T::from_f64(0.9).unwrap() {
        // For moderately sparse matrices (>90%), use a somewhat relaxed tolerance
        epsilon * T::from_f64(10.0).unwrap()
    } else {
        // For less sparse matrices, use standard epsilon
        epsilon
    };

    let (endl, endr) = (end_interval[0], end_interval[1]);

    /* take the first step */
    let (mut rnm, mut tol) = stpone(A, wrk, store, random_seed)?;

    let eps1 = adaptive_eps * T::from_f64(wrk.ncols as f64).unwrap().sqrt();
    wrk.eta[0] = eps1;
    wrk.oldeta[0] = eps1;
    let mut ll = 0;
    let mut first = 1;
    let mut last = iterations.min(dim.max(8) + dim);
    let mut enough = false;
    let mut j = 0;
    let mut intro = 0;

    while !enough {
        if rnm <= tol {
            rnm = T::zero();
        }

        // the actual lanczos loop
        let steps = lanczos_step(
            A,
            wrk,
            first,
            last,
            &mut ll,
            &mut enough,
            &mut rnm,
            &mut tol,
            store,
        )?;
        j = match enough {
            true => steps - 1,
            false => last - 1,
        };

        first = j + 1;
        wrk.bet[first] = rnm;

        // analyze T
        let mut l = 0;
        for _ in 0..j {
            if l > j {
                break;
            }

            let mut i = l;
            while i <= j {
                if wrk.bet[i + 1].abs() <= adaptive_eps {
                    break;
                }
                i += 1;
            }
            i = i.min(j);

            // now i is at the end of an unreduced submatrix
            let sz = i - l;
            svd_dcopy(sz + 1, l, &wrk.alf, &mut wrk.ritz);
            svd_dcopy(sz, l + 1, &wrk.bet, &mut wrk.w5);

            imtqlb(
                sz + 1,
                &mut wrk.ritz[l..],
                &mut wrk.w5[l..],
                &mut wrk.bnd[l..],
                max_iterations_imtqlb,
            )?;

            for m in l..=i {
                wrk.bnd[m] = rnm * wrk.bnd[m].abs();
            }
            l = i + 1;
        }

        // sort eigenvalues into increasing order
        insert_sort(j + 1, &mut wrk.ritz, &mut wrk.bnd);

        *neig = error_bound(&mut enough, endl, endr, &mut wrk.ritz, &mut wrk.bnd, j, tol);

        // should we stop?
        if *neig < dim {
            if *neig == 0 {
                last = first + 9;
                intro = first;
            } else {
                let extra_steps = if sparsity > T::from_f64(0.99).unwrap() {
                    5 // For very sparse matrices, add extra steps
                } else {
                    0
                };

                last = first + 3.max(1 + ((j - intro) * (dim - *neig)) / *neig) + extra_steps;
            }
            last = last.min(iterations);
        } else {
            enough = true
        }
        enough = enough || first >= iterations;
    }
    store.storq(j, &wrk.w1);
    Ok(j)
}

impl<T: SvdFloat + 'static> SvdRec<T> {
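    /// Reconstructs a dense approximation of the original matrix as
    /// `U * diag(S) * Vt`, using the `d` singular triplets that were kept.
    ///
    /// A usage sketch, assuming `svd` came from [`svd`] above:
    ///
    /// ```ignore
    /// let approx = svd.recompose(); // ndarray::Array2<T> with the input's shape
    /// ```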
    pub fn recompose(&self) -> Array2<T> {
        let sdiag = Array2::from_diag(&self.s);
        // `u` is stored row-wise as Ut (d x m), so transpose before multiplying
        self.u.t().dot(&sdiag).dot(&self.vt)
    }
}

#[rustfmt::skip]
impl<T: Float + Zero + AddAssign + Clone + Sync> SMat<T> for nalgebra_sparse::csc::CscMatrix<T> {
    fn nrows(&self) -> usize { self.nrows() }
    fn ncols(&self) -> usize { self.ncols() }
    fn nnz(&self) -> usize { self.nnz() }

    /// takes an n-vector x and returns A*x in y
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed { self.ncols() } else { self.nrows() };
        let ncols = if transposed { self.nrows() } else { self.ncols() };
        assert_eq!(x.len(), ncols, "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
        assert_eq!(y.len(), nrows, "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}", y.len(), nrows);

        let (major_offsets, minor_indices, values) = self.csc_data();

        y.fill(T::zero());

        if transposed {
            for (i, yval) in y.iter_mut().enumerate() {
                for j in major_offsets[i]..major_offsets[i + 1] {
                    *yval += values[j] * x[minor_indices[j]];
                }
            }
        } else {
            for (i, xval) in x.iter().enumerate() {
                for j in major_offsets[i]..major_offsets[i + 1] {
                    y[minor_indices[j]] += values[j] * *xval;
                }
            }
        }
    }
}

#[rustfmt::skip]
impl<T: Float + Zero + AddAssign + Clone + Sync + Send> SMat<T> for nalgebra_sparse::csr::CsrMatrix<T> {
    fn nrows(&self) -> usize { self.nrows() }
    fn ncols(&self) -> usize { self.ncols() }
    fn nnz(&self) -> usize { self.nnz() }

    /// takes an n-vector x and returns A*x in y
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed { self.ncols() } else { self.nrows() };
        let ncols = if transposed { self.nrows() } else { self.ncols() };
        assert_eq!(x.len(), ncols, "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
        assert_eq!(y.len(), nrows, "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}", y.len(), nrows);

        let (major_offsets, minor_indices, values) = self.csr_data();

        y.fill(T::zero());

        if !transposed {
            // y[i] = row_i(A) . x; rows are independent, so compute them in parallel
            let nrows = self.nrows();
            let results: Vec<(usize, T)> = (0..nrows)
                .into_par_iter()
                .map(|i| {
                    let mut sum = T::zero();
                    for j in major_offsets[i]..major_offsets[i + 1] {
                        sum += values[j] * x[minor_indices[j]];
                    }
                    (i, sum)
                })
                .collect();

            // Apply the results to y
            for (i, val) in results {
                y[i] = val;
            }
        } else {
            // y = A' * x: process rows in chunks, each chunk accumulating into a
            // private copy of y, then combine the partial results
            let nrows = self.nrows();
            let chunk_size = crate::utils::determine_chunk_size(nrows);

            let results: Vec<Vec<T>> = (0..((nrows + chunk_size - 1) / chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = (start + chunk_size).min(nrows);

                    let mut local_y = vec![T::zero(); y.len()];
                    for i in start..end {
                        let row_val = x[i];
                        for j in major_offsets[i]..major_offsets[i + 1] {
                            let col = minor_indices[j];
                            local_y[col] += values[j] * row_val;
                        }
                    }
                    local_y
                })
                .collect();

            // Combine partial results
            for local_y in results {
                for (idx, val) in local_y.iter().enumerate() {
                    if !val.is_zero() {
                        y[idx] += *val;
                    }
                }
            }
        }
    }
}

#[rustfmt::skip]
impl<T: Float + Zero + AddAssign + Clone + Sync> SMat<T> for nalgebra_sparse::coo::CooMatrix<T> {
    fn nrows(&self) -> usize { self.nrows() }
    fn ncols(&self) -> usize { self.ncols() }
    fn nnz(&self) -> usize { self.nnz() }

    /// takes an n-vector x and returns A*x in y
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed { self.ncols() } else { self.nrows() };
        let ncols = if transposed { self.nrows() } else { self.ncols() };
        assert_eq!(x.len(), ncols, "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
        assert_eq!(y.len(), nrows, "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}", y.len(), nrows);

        y.fill(T::zero());

        if transposed {
            for (i, j, v) in self.triplet_iter() {
                y[j] += *v * x[i];
            }
        } else {
            for (i, j, v) in self.triplet_iter() {
                y[i] += *v * x[j];
            }
        }
    }
}
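
// A few minimal sanity checks (a sketch, not a full test suite) for the small
// numeric helpers above. They assume only that `SvdFloat` is implemented for
// `f64`, as the rest of the crate already relies on.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pythag_is_hypot_without_overflow() {
        // classic 3-4-5 triangle
        let h: f64 = svd_pythag(3.0, 4.0);
        assert!((h - 5.0).abs() < 1e-12);
    }

    #[test]
    fn dcopy_reverses_the_copied_window() {
        let x = [1.0f64, 2.0, 3.0, 4.0];
        let mut y = [0.0f64; 4];
        svd_dcopy(3, 1, &x, &mut y); // copy 3 elements from offset 1, reversed
        assert_eq!(y, [0.0, 4.0, 3.0, 2.0]);
    }

    #[test]
    fn insert_sort_orders_both_arrays_by_the_first() {
        let mut keys = [3.0f64, 1.0, 2.0];
        let mut vals = [30.0f64, 10.0, 20.0];
        insert_sort(3, &mut keys, &mut vals);
        assert_eq!(keys, [1.0, 2.0, 3.0]);
        assert_eq!(vals, [10.0, 20.0, 30.0]);
    }

    #[test]
    fn rotate_array_rotates_left_by_x() {
        let mut a = [0.0f64, 1.0, 2.0, 3.0, 4.0];
        rotate_array(&mut a, 2);
        assert_eq!(a, [2.0, 3.0, 4.0, 0.0, 1.0]);
    }
}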