// single_svdlib/lanczos/mod.rs

1pub mod masked;
2
3use crate::error::SvdLibError;
4use crate::{Diagnostics, SMat, SvdFloat, SvdRec};
5use ndarray::{Array, Array2};
6use num_traits::real::Real;
7use num_traits::{Float, FromPrimitive, One, Zero};
8use rand::rngs::StdRng;
9use rand::{rng, Rng, RngCore, SeedableRng};
10use rayon::iter::IndexedParallelIterator;
11use rayon::iter::ParallelIterator;
12use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator};
13use std::fmt::Debug;
14use std::iter::Sum;
15use std::mem;
16use std::ops::{AddAssign, MulAssign, Neg, SubAssign};
17
18/// Trait for floating point types that can be used with the SVD algorithm
19
20/// SVD at full dimensionality, calls `svdLAS2` with the highlighted defaults
21///
22/// svdLAS2(A, `0`, `0`, `&[-1.0e-30, 1.0e-30]`, `1.0e-6`, `0`)
23///
24/// # Parameters
25/// - A: Sparse matrix
26pub fn svd<T, M>(a: &M) -> Result<SvdRec<T>, SvdLibError>
27where
28    T: SvdFloat,
29    M: SMat<T>,
30{
31    let eps_small = T::from_f64(-1.0e-30).unwrap();
32    let eps_large = T::from_f64(1.0e-30).unwrap();
33    let kappa = T::from_f64(1.0e-6).unwrap();
34    svd_las2(a, 0, 0, &[eps_small, eps_large], kappa, 0)
35}
36
37/// SVD at desired dimensionality, calls `svdLAS2` with the highlighted defaults
38///
39/// svdLAS2(A, dimensions, `0`, `&[-1.0e-30, 1.0e-30]`, `1.0e-6`, `0`)
40///
41/// # Parameters
42/// - A: Sparse matrix
43/// - dimensions: Upper limit of desired number of dimensions, bounded by the matrix shape
44pub fn svd_dim<T, M>(a: &M, dimensions: usize) -> Result<SvdRec<T>, SvdLibError>
45where
46    T: SvdFloat,
47    M: SMat<T>,
48{
49    let eps_small = T::from_f64(-1.0e-30).unwrap();
50    let eps_large = T::from_f64(1.0e-30).unwrap();
51    let kappa = T::from_f64(1.0e-6).unwrap();
52
53    svd_las2(a, dimensions, 0, &[eps_small, eps_large], kappa, 0)
54}
55
56/// SVD at desired dimensionality with supplied seed, calls `svdLAS2` with the highlighted defaults
57///
58/// svdLAS2(A, dimensions, `0`, `&[-1.0e-30, 1.0e-30]`, `1.0e-6`, random_seed)
59///
60/// # Parameters
61/// - A: Sparse matrix
62/// - dimensions: Upper limit of desired number of dimensions, bounded by the matrix shape
63/// - random_seed: A supplied seed `if > 0`, otherwise an internal seed will be generated
64pub fn svd_dim_seed<T, M>(
65    a: &M,
66    dimensions: usize,
67    random_seed: u32,
68) -> Result<SvdRec<T>, SvdLibError>
69where
70    T: SvdFloat,
71    M: SMat<T>,
72{
73    let eps_small = T::from_f64(-1.0e-30).unwrap();
74    let eps_large = T::from_f64(1.0e-30).unwrap();
75    let kappa = T::from_f64(1.0e-6).unwrap();
76
77    svd_las2(
78        a,
79        dimensions,
80        0,
81        &[eps_small, eps_large],
82        kappa,
83        random_seed,
84    )
85}
86
87/// Compute a singular value decomposition
88///
89/// # Parameters
90///
91/// - A: Sparse matrix
92/// - dimensions: Upper limit of desired number of dimensions (0 = max),
93///       where "max" is a value bounded by the matrix shape, the smaller of
94///       the matrix rows or columns. e.g. `A.nrows().min(A.ncols())`
95/// - iterations: Upper limit of desired number of lanczos steps (0 = max),
96///       where "max" is a value bounded by the matrix shape, the smaller of
97///       the matrix rows or columns. e.g. `A.nrows().min(A.ncols())`
98///       iterations must also be in range [`dimensions`, `A.nrows().min(A.ncols())`]
99/// - end_interval: Left, right end of interval containing unwanted eigenvalues,
100///       typically small values centered around zero, e.g. `[-1.0e-30, 1.0e-30]`
101/// - kappa: Relative accuracy of ritz values acceptable as eigenvalues, e.g. `1.0e-6`
102/// - random_seed: A supplied seed `if > 0`, otherwise an internal seed will be generated
103pub fn svd_las2<T, M>(
104    a: &M,
105    dimensions: usize,
106    iterations: usize,
107    end_interval: &[T; 2],
108    kappa: T,
109    random_seed: u32,
110) -> Result<SvdRec<T>, SvdLibError>
111where
112    T: SvdFloat,
113    M: SMat<T>,
114{
115    let random_seed = match random_seed > 0 {
116        true => random_seed,
117        false => rng().next_u32(),
118    };
119
120    let min_nrows_ncols = a.nrows().min(a.ncols());
121
122    let dimensions = match dimensions {
123        n if n == 0 || n > min_nrows_ncols => min_nrows_ncols,
124        _ => dimensions,
125    };
126
127    let iterations = match iterations {
128        n if n == 0 || n > min_nrows_ncols => min_nrows_ncols,
129        n if n < dimensions => dimensions,
130        _ => iterations,
131    };
132
133    if dimensions < 2 {
134        return Err(SvdLibError::Las2Error(format!(
135            "svd_las2: insufficient dimensions: {dimensions}"
136        )));
137    }
138
139    assert!(dimensions > 1 && dimensions <= min_nrows_ncols);
140    assert!(iterations >= dimensions && iterations <= min_nrows_ncols);
141
142    let transposed = (a.ncols() as f64) >= ((a.nrows() as f64) * 1.2);
143    let nrows = if transposed { a.ncols() } else { a.nrows() };
144    let ncols = if transposed { a.nrows() } else { a.ncols() };
145
146    let mut wrk = WorkSpace::new(nrows, ncols, transposed, iterations)?;
147    let mut store = Store::new(ncols)?;
148
149    let mut neig = 0;
150    let steps = lanso(
151        a,
152        dimensions,
153        iterations,
154        end_interval,
155        &mut wrk,
156        &mut neig,
157        &mut store,
158        random_seed,
159    )?;
160
161    let kappa = Float::max(Float::abs(kappa), T::eps34());
162    let mut r = ritvec(a, dimensions, kappa, &mut wrk, steps, neig, &mut store)?;
163
164    if transposed {
165        mem::swap(&mut r.Ut, &mut r.Vt);
166    }
167
168    Ok(SvdRec {
169        // Dimensionality (number of Ut,Vt rows & length of S)
170        d: r.d,
171        u: Array2::from_shape_vec((r.d, r.Ut.cols), r.Ut.value)?,
172        s: Array::from_shape_vec(r.d, r.S)?,
173        vt: Array2::from_shape_vec((r.d, r.Vt.cols), r.Vt.value)?,
174        diagnostics: Diagnostics {
175            non_zero: a.nnz(),
176            dimensions: dimensions,
177            iterations: iterations,
178            transposed: transposed,
179            lanczos_steps: steps + 1,
180            ritz_values_stabilized: neig,
181            significant_values: r.d,
182            singular_values: r.nsig,
183            end_interval: *end_interval,
184            kappa: kappa,
185            random_seed: random_seed,
186        },
187    })
188}
189
190const MAXLL: usize = 2;
191
192#[derive(Debug, Clone, PartialEq)]
193struct Store<T: Float> {
194    n: usize,
195    vecs: Vec<Vec<T>>,
196}
197
198impl<T: Float + Zero + Clone> Store<T> {
199    fn new(n: usize) -> Result<Self, SvdLibError> {
200        Ok(Self { n, vecs: vec![] })
201    }
202
203    fn storq(&mut self, idx: usize, v: &[T]) {
204        while idx + MAXLL >= self.vecs.len() {
205            self.vecs.push(vec![T::zero(); self.n]);
206        }
207        self.vecs[idx + MAXLL].copy_from_slice(v);
208    }
209
210    fn storp(&mut self, idx: usize, v: &[T]) {
211        while idx >= self.vecs.len() {
212            self.vecs.push(vec![T::zero(); self.n]);
213        }
214        self.vecs[idx].copy_from_slice(v);
215    }
216
217    fn retrq(&mut self, idx: usize) -> &[T] {
218        &self.vecs[idx + MAXLL]
219    }
220
221    fn retrp(&mut self, idx: usize) -> &[T] {
222        &self.vecs[idx]
223    }
224}
225
226#[derive(Debug, Clone, PartialEq)]
227struct WorkSpace<T: Float> {
228    nrows: usize,
229    ncols: usize,
230    transposed: bool,
231    w0: Vec<T>,     // workspace 0
232    w1: Vec<T>,     // workspace 1
233    w2: Vec<T>,     // workspace 2
234    w3: Vec<T>,     // workspace 3
235    w4: Vec<T>,     // workspace 4
236    w5: Vec<T>,     // workspace 5
237    alf: Vec<T>,    // array to hold diagonal of the tridiagonal matrix T
238    eta: Vec<T>,    // orthogonality estimate of Lanczos vectors at step j
239    oldeta: Vec<T>, // orthogonality estimate of Lanczos vectors at step j-1
240    bet: Vec<T>,    // array to hold off-diagonal of T
241    bnd: Vec<T>,    // array to hold the error bounds
242    ritz: Vec<T>,   // array to hold the ritz values
243    temp: Vec<T>,   // array to hold the temp values
244}
245
246impl<T: Float + Zero + FromPrimitive> WorkSpace<T> {
247    fn new(
248        nrows: usize,
249        ncols: usize,
250        transposed: bool,
251        iterations: usize,
252    ) -> Result<Self, SvdLibError> {
253        Ok(Self {
254            nrows,
255            ncols,
256            transposed,
257            w0: vec![T::zero(); ncols],
258            w1: vec![T::zero(); ncols],
259            w2: vec![T::zero(); ncols],
260            w3: vec![T::zero(); ncols],
261            w4: vec![T::zero(); ncols],
262            w5: vec![T::zero(); ncols],
263            alf: vec![T::zero(); iterations],
264            eta: vec![T::zero(); iterations],
265            oldeta: vec![T::zero(); iterations],
266            bet: vec![T::zero(); 1 + iterations],
267            ritz: vec![T::zero(); 1 + iterations],
268            bnd: vec![T::from_f64(f64::MAX).unwrap(); 1 + iterations],
269            temp: vec![T::zero(); nrows],
270        })
271    }
272}
273
/* Row-major dense matrix.  Rows are consecutive vectors. */
#[derive(Debug, Clone, PartialEq)]
struct DMat<T: Float> {
    cols: usize,  // row length; logical row count is value.len() / cols
    value: Vec<T>, // row-major storage
}

/// Raw SVD result assembled by `ritvec` before reshaping into `SvdRec`:
/// `d` retained dimensions, `nsig` singular values found, `Ut`/`Vt` with
/// singular vectors as rows, and `S` the singular values.
#[allow(non_snake_case)]
#[derive(Debug, Clone, PartialEq)]
struct SVDRawRec<T: Float> {
    d: usize,
    nsig: usize,
    Ut: DMat<T>,
    S: Vec<T>,
    Vt: DMat<T>,
}
290
/// Approximate floating-point equality, delegating to the `SvdFloat`
/// implementation (presumably an epsilon-based comparison — defined elsewhere).
fn compare<T: SvdFloat>(computed: T, expected: T) -> bool {
    T::compare(computed, expected)
}
294
295/* Function sorts array1 and array2 into increasing order for array1 */
296fn insert_sort<T: PartialOrd>(n: usize, array1: &mut [T], array2: &mut [T]) {
297    for i in 1..n {
298        for j in (1..i + 1).rev() {
299            if array1[j - 1] <= array1[j] {
300                break;
301            }
302            array1.swap(j - 1, j);
303            array2.swap(j - 1, j);
304        }
305    }
306}
307
308#[allow(non_snake_case)]
309#[rustfmt::skip]
310fn svd_opb<T: Float>(A: &dyn SMat<T>, x: &[T], y: &mut [T], temp: &mut [T], transposed: bool) {
311    let nrows = if transposed { A.ncols() } else { A.nrows() };
312    let ncols = if transposed { A.nrows() } else { A.ncols() };
313    assert_eq!(x.len(), ncols, "svd_opb: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
314    assert_eq!(y.len(), ncols, "svd_opb: y must be A.ncols() in length, y = {}, A.ncols = {}", y.len(), ncols);
315    assert_eq!(temp.len(), nrows, "svd_opa: temp must be A.nrows() in length, temp = {}, A.nrows = {}", temp.len(), nrows);
316    A.svd_opa(x, temp, transposed); // temp = (A * x)
317    A.svd_opa(temp, y, !transposed); // y = A' * (A * x) = A' * temp
318}
319
320// constant times a vector plus a vector
321fn svd_daxpy<T: Float + AddAssign + Send + Sync>(da: T, x: &[T], y: &mut [T]) {
322    if x.len() < 1000 {
323        for (xval, yval) in x.iter().zip(y.iter_mut()) {
324            *yval += da * *xval
325        }
326    } else {
327        y.par_iter_mut()
328            .zip(x.par_iter())
329            .for_each(|(yval, xval)| *yval += da * *xval);
330    }
331}
332
333// finds the index of element having max absolute value
334fn svd_idamax<T: Float>(n: usize, x: &[T]) -> usize {
335    assert!(n > 0, "svd_idamax: unexpected inputs!");
336
337    match n {
338        1 => 0,
339        _ => {
340            let mut imax = 0;
341            for (i, xval) in x.iter().enumerate().take(n).skip(1) {
342                if xval.abs() > x[imax].abs() {
343                    imax = i;
344                }
345            }
346            imax
347        }
348    }
349}
350
351// returns |a| if b is positive; else fsign returns -|a|
352fn svd_fsign<T: Float>(a: T, b: T) -> T {
353    match (a >= T::zero() && b >= T::zero()) || (a < T::zero() && b < T::zero()) {
354        true => a,
355        false => -a,
356    }
357}
358
359// finds sqrt(a^2 + b^2) without overflow or destructive underflow
360fn svd_pythag<T: SvdFloat + FromPrimitive>(a: T, b: T) -> T {
361    match Float::max(Float::abs(a), Float::abs(b)) {
362        n if n > T::zero() => {
363            let mut p = n;
364            let mut r = Float::powi(Float::min(Float::abs(a), Float::abs(b)) / p, 2);
365            let four = T::from_f64(4.0).unwrap();
366            let two = T::from_f64(2.0).unwrap();
367            let mut t = four + r;
368            while !compare(t, four) {
369                let s = r / t;
370                let u = T::one() + two * s;
371                p = p * u;
372                r = Float::powi((s / u), 2);
373                t = four + r;
374            }
375            p
376        }
377        _ => T::zero(),
378    }
379}
380
381// dot product of two vectors
382fn svd_ddot<T: Float + Sum<T> + Send + Sync>(x: &[T], y: &[T]) -> T {
383    if x.len() < 1000 {
384        x.iter().zip(y).map(|(a, b)| *a * *b).sum()
385    } else {
386        x.par_iter().zip(y.par_iter()).map(|(a, b)| *a * *b).sum()
387    }
388}
389
390// norm (length) of a vector
391fn svd_norm<T: Float + Sum<T> + Send + Sync>(x: &[T]) -> T {
392    svd_ddot(x, x).sqrt()
393}
394
395// scales an input vector 'x', by a constant, storing in 'y'
396fn svd_datx<T: Float + Sum<T>>(d: T, x: &[T], y: &mut [T]) {
397    for (i, xval) in x.iter().enumerate() {
398        y[i] = d * *xval;
399    }
400}
401
402// scales an input vector 'x' by a constant, modifying 'x'
403fn svd_dscal<T: Float + MulAssign + Send + Sync>(d: T, x: &mut [T]) {
404    if x.len() < 1000 {
405        for elem in x.iter_mut() {
406            *elem *= d;
407        }
408    } else {
409        x.par_iter_mut().for_each(|elem| {
410            *elem *= d;
411        });
412    }
413}
414
415// copies a vector x to a vector y (reversed direction)
416fn svd_dcopy<T: Float + Copy>(n: usize, offset: usize, x: &[T], y: &mut [T]) {
417    if n > 0 {
418        let start = n - 1;
419        for i in 0..n {
420            y[offset + start - i] = x[offset + i];
421        }
422    }
423}
424
// Per-eigenvalue cap on implicit-QL sweeps before forcing convergence.
const MAX_IMTQLB_ITERATIONS: usize = 100;

/// Computes the eigenvalues of a symmetric tridiagonal matrix by implicit-shift
/// QL iteration (eigenvalues only — an EISPACK `imtqlb`-style routine).
///
/// - `d`: diagonal on entry; eigenvalues in ascending order on exit
/// - `e`: subdiagonal in `e[1..n]`; destroyed during computation
/// - `bnd`: rotated alongside `d`; carries the error-bound weights the caller uses
/// - `max_imtqlb`: optional per-eigenvalue iteration cap (default `MAX_IMTQLB_ITERATIONS`)
///
/// Unlike the classic routine, slow convergence is not fatal here: the error
/// bounds are widened, the offending subdiagonal entry is zeroed, and the sweep
/// continues (a warning is printed at the end).
fn imtqlb<T: SvdFloat>(
    n: usize,
    d: &mut [T],
    e: &mut [T],
    bnd: &mut [T],
    max_imtqlb: Option<usize>,
) -> Result<(), SvdLibError> {
    let max_imtqlb = max_imtqlb.unwrap_or(MAX_IMTQLB_ITERATIONS);
    // A 1x1 matrix is already diagonal.
    if n == 1 {
        return Ok(());
    }

    // Used to loosen the convergence tolerance as the matrix grows.
    let matrix_size_factor = T::from_f64((n as f64).sqrt()).unwrap();

    // bnd starts as the first unit vector; the QL rotations accumulate into it.
    bnd[0] = T::one();
    let last = n - 1;
    for i in 1..=last {
        bnd[i] = T::zero();
        e[i - 1] = e[i]; // shift subdiagonal down so e[0..n-1] is usable
    }
    e[last] = T::zero();

    // Insertion index reused by the eigenvalue-ordering step below.
    let mut i = 0;

    let mut had_convergence_issues = false;

    for l in 0..=last {
        let mut iteration = 0;
        let mut p = d[l];
        let mut f = bnd[l];

        while iteration <= max_imtqlb {
            // Find the first effectively-zero subdiagonal element at or after l.
            let mut m = l;
            while m < n {
                if m == last {
                    break;
                }

                // More forgiving convergence test for large/sparse matrices
                let test = Float::abs(d[m]) + Float::abs(d[m + 1]);
                // Scale tolerance with matrix size and magnitude
                let tol = <T as Float>::epsilon()
                    * T::from_f64(100.0).unwrap()
                    * Float::max(test, T::one())
                    * matrix_size_factor;

                if Float::abs(e[m]) <= tol {
                    break; // Convergence achieved for this element
                }
                m += 1;
            }

            if m == l {
                // Eigenvalue p has converged: insertion-sort it into d[0..=l].
                // Order the eigenvalues
                let mut exchange = true;
                if l > 0 {
                    i = l;
                    while i >= 1 && exchange {
                        if p < d[i - 1] {
                            d[i] = d[i - 1];
                            bnd[i] = bnd[i - 1];
                            i -= 1;
                        } else {
                            exchange = false;
                        }
                    }
                }
                if exchange {
                    i = 0;
                }
                d[i] = p;
                bnd[i] = f;
                iteration = max_imtqlb + 1; // Exit the loop
            } else {
                // Check if we've reached max iterations without convergence
                if iteration == max_imtqlb {
                    // CRITICAL CHANGE: Don't fail, just note the issue and continue
                    had_convergence_issues = true;

                    // Set conservative error bounds for non-converged values
                    for idx in l..=m {
                        bnd[idx] = Float::max(bnd[idx], T::from_f64(0.1).unwrap());
                    }

                    // Force "convergence" by zeroing the problematic subdiagonal element
                    e[l] = T::zero();

                    // Break out of the iteration loop and move to next eigenvalue
                    break;
                }

                iteration += 1;
                // ........ form shift ........
                let two = T::from_f64(2.0).unwrap();
                let mut g = (d[l + 1] - p) / (two * e[l]);
                let mut r = svd_pythag(g, T::one());
                g = d[m] - p + e[l] / (g + svd_fsign(r, g));
                let mut s = T::one();
                let mut c = T::one();
                p = T::zero();

                // One QL sweep: apply Givens rotations from m-1 down to l.
                assert!(m > 0, "imtqlb: expected 'm' to be non-zero");
                i = m - 1;
                let mut underflow = false;
                while !underflow && i >= l {
                    f = s * e[i];
                    let b = c * e[i];
                    r = svd_pythag(f, g);
                    e[i + 1] = r;

                    // More forgiving underflow detection for sparse matrices
                    if r < <T as Float>::epsilon()
                        * T::from_f64(1000.0).unwrap()
                        * (Float::abs(f) + Float::abs(g))
                    {
                        underflow = true;
                        break;
                    }

                    // Safety check for division by very small numbers
                    if Float::abs(r) < <T as Float>::epsilon() * T::from_f64(100.0).unwrap() {
                        r = <T as Float>::epsilon()
                            * T::from_f64(100.0).unwrap()
                            * svd_fsign(T::one(), r);
                    }

                    s = f / r;
                    c = g / r;
                    g = d[i + 1] - p;
                    r = (d[i] - g) * s + T::from_f64(2.0).unwrap() * c * b;
                    p = s * r;
                    d[i + 1] = g + p;
                    g = c * r - b;
                    // Rotate the bound weights along with the diagonal.
                    f = bnd[i + 1];
                    bnd[i + 1] = s * bnd[i] + c * f;
                    bnd[i] = c * bnd[i] - s * f;
                    if i == 0 {
                        break;
                    }
                    i -= 1;
                }
                // ........ recover from underflow .........
                if underflow {
                    d[i + 1] -= p;
                } else {
                    d[l] -= p;
                    e[l] = g;
                }
                e[m] = T::zero();
            }
        }
    }
    if had_convergence_issues {
        eprintln!("Warning: imtqlb had some convergence issues but continued with best estimates. Results may have reduced accuracy.");
    }
    Ok(())
}
584
/// Generates a starting vector for the Lanczos recursion in `wrk.w0` and
/// returns its norm after projection through the operator A'A.
///
/// Up to three attempts are made: fill `w0` with pseudo-random values in
/// (-1, 1), apply the operator (to pull the vector into the operator's range),
/// and accept if the squared norm is positive. When `step > 0` the candidate
/// is also orthogonalized against all previously stored Lanczos vectors.
///
/// NOTE(review): each retry re-seeds `StdRng` from the same `random_seed`, so
/// every attempt regenerates the identical pseudo-random vector — the 3-pass
/// retry loop can only succeed or fail identically each time. Confirm whether
/// the attempt index was meant to perturb the seed.
#[allow(non_snake_case)]
fn startv<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    step: usize,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<T, SvdLibError> {
    // get initial vector; default is random
    let mut rnm2 = svd_ddot(&wrk.w0, &wrk.w0);
    for id in 0..3 {
        // On the first pass of step 0, reuse whatever is already in w0 unless
        // it is (approximately) the zero vector.
        if id > 0 || step > 0 || compare(rnm2, T::zero()) {
            // Expand the u32 seed into the 32-byte seed StdRng requires
            // (little-endian bytes, zero-padded).
            let mut bytes = [0; 32];
            for (i, b) in random_seed.to_le_bytes().iter().enumerate() {
                bytes[i] = *b;
            }
            let mut seeded_rng = StdRng::from_seed(bytes);
            for val in wrk.w0.iter_mut() {
                *val = T::from_f64(seeded_rng.random_range(-1.0..1.0)).unwrap();
            }
        }
        wrk.w3.copy_from_slice(&wrk.w0);

        // apply operator to put r in range (essential if m singular)
        svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
        wrk.w3.copy_from_slice(&wrk.w0);
        rnm2 = svd_ddot(&wrk.w3, &wrk.w3);
        if rnm2 > T::zero() {
            break;
        }
    }

    // All attempts produced a (numerically) zero vector.
    if rnm2 <= T::zero() {
        return Err(SvdLibError::StartvError(format!(
            "rnm2 <= 0.0, rnm2 = {rnm2:?}"
        )));
    }

    if step > 0 {
        // Orthogonalize the candidate against every stored Lanczos vector.
        for i in 0..step {
            let v = store.retrq(i);
            svd_daxpy(-svd_ddot(&wrk.w3, v), v, &mut wrk.w0);
        }

        // make sure q[step] is orthogonal to q[step-1]
        svd_daxpy(-svd_ddot(&wrk.w4, &wrk.w0), &wrk.w2, &mut wrk.w0);
        wrk.w3.copy_from_slice(&wrk.w0);

        // If orthogonalization wiped out essentially all of the vector,
        // report a zero norm so the caller can restart.
        rnm2 = match svd_ddot(&wrk.w3, &wrk.w3) {
            dot if dot <= T::eps() * rnm2 => T::zero(),
            dot => dot,
        }
    }
    Ok(rnm2.sqrt())
}
640
/// Performs the first step of the Lanczos recursion: obtains a starting vector,
/// normalizes it, applies the operator once, and computes `alf[0]` with one
/// round of local reorthogonalization.
///
/// Returns `(rnm, tol)` — the norm of the residual vector (next `bet`) and the
/// convergence tolerance `sqrt(eps) * anorm` derived from the operator-norm
/// estimate.
#[allow(non_snake_case)]
fn stpone<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<(T, T), SvdLibError> {
    // get initial vector; default is random
    let mut rnm = startv(A, wrk, 0, store, random_seed)?;
    if compare(rnm, T::zero()) {
        return Err(SvdLibError::StponeError("rnm == 0.0".to_string()));
    }

    // normalize starting vector
    svd_datx(Float::recip(rnm), &wrk.w0, &mut wrk.w1);
    svd_dscal(Float::recip(rnm), &mut wrk.w3);

    // take the first step
    svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
    wrk.alf[0] = svd_ddot(&wrk.w0, &wrk.w3);
    svd_daxpy(-wrk.alf[0], &wrk.w1, &mut wrk.w0);
    // One pass of iterative refinement on alf[0] (classical Gram-Schmidt
    // correction) before accepting the residual.
    let t = svd_ddot(&wrk.w0, &wrk.w3);
    wrk.alf[0] += t;
    svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
    wrk.w4.copy_from_slice(&wrk.w0);
    rnm = svd_norm(&wrk.w4);
    // anorm approximates ||A'A||; tol scales with it.
    let anorm = rnm + Float::abs(wrk.alf[0]);
    Ok((rnm, T::eps().sqrt() * anorm))
}
670
/// Runs Lanczos steps `first..last`, extending the tridiagonal matrix
/// (`wrk.alf`/`wrk.bet`) and maintaining selective orthogonality of the
/// Lanczos vectors (held in `wrk.w0`–`wrk.w4` and `store`).
///
/// - `ll`: in/out count of initial vectors to keep reorthogonalizing against
/// - `enough`: set when an invariant subspace is found and no restart succeeds
/// - `rnm`: in/out residual norm (the next off-diagonal entry)
/// - `tol`: in/out convergence tolerance, refreshed from the norm estimate
///
/// Returns the index of the first step NOT taken.
#[allow(non_snake_case)]
#[allow(clippy::too_many_arguments)]
fn lanczos_step<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    first: usize,
    last: usize,
    ll: &mut usize,
    enough: &mut bool,
    rnm: &mut T,
    tol: &mut T,
    store: &mut Store<T>,
) -> Result<usize, SvdLibError> {
    let eps1 = T::eps() * T::from_f64(wrk.ncols as f64).unwrap().sqrt();
    let mut j = first;
    let four = T::from_f64(4.0).unwrap();

    while j < last {
        // Rotate the vector roles: (w1,w2) and (w3,w4) swap so the previous
        // step's vectors become "old" for this step.
        mem::swap(&mut wrk.w1, &mut wrk.w2);
        mem::swap(&mut wrk.w3, &mut wrk.w4);

        store.storq(j - 1, &wrk.w2);
        // Only the first MAXLL "p" vectors are ever kept.
        if j - 1 < MAXLL {
            store.storp(j - 1, &wrk.w4);
        }
        wrk.bet[j] = *rnm;

        // restart if invariant subspace is found
        if compare(*rnm, T::zero()) {
            *rnm = startv(A, wrk, j, store, 0)?;
            if compare(*rnm, T::zero()) {
                *enough = true;
            }
        }

        if *enough {
            // Undo the swap so the caller sees a consistent vector layout.
            mem::swap(&mut wrk.w1, &mut wrk.w2);
            break;
        }

        // take a lanczos step
        svd_datx(Float::recip(*rnm), &wrk.w0, &mut wrk.w1);
        svd_dscal(Float::recip(*rnm), &mut wrk.w3);
        svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
        svd_daxpy(-*rnm, &wrk.w2, &mut wrk.w0);
        wrk.alf[j] = svd_ddot(&wrk.w0, &wrk.w3);
        svd_daxpy(-wrk.alf[j], &wrk.w1, &mut wrk.w0);

        // orthogonalize against initial lanczos vectors
        if j <= MAXLL && Float::abs(wrk.alf[j - 1]) > four * Float::abs(wrk.alf[j]) {
            *ll = j;
        }
        for i in 0..(j - 1).min(*ll) {
            let v1 = store.retrp(i);
            let t = svd_ddot(v1, &wrk.w0);
            let v2 = store.retrq(i);
            svd_daxpy(-t, v2, &mut wrk.w0);
            wrk.eta[i] = eps1;
            wrk.oldeta[i] = eps1;
        }

        // extended local reorthogonalization
        let t = svd_ddot(&wrk.w0, &wrk.w4);
        svd_daxpy(-t, &wrk.w2, &mut wrk.w0);
        if wrk.bet[j] > T::zero() {
            wrk.bet[j] += t;
        }
        let t = svd_ddot(&wrk.w0, &wrk.w3);
        svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
        wrk.alf[j] += t;
        wrk.w4.copy_from_slice(&wrk.w0);
        *rnm = svd_norm(&wrk.w4);
        // Refresh the tolerance from the running operator-norm estimate.
        let anorm = wrk.bet[j] + Float::abs(wrk.alf[j]) + *rnm;
        *tol = T::eps().sqrt() * anorm;

        // update the orthogonality bounds
        ortbnd(wrk, j, *rnm, eps1);

        // restore the orthogonality state when needed
        purge(wrk.ncols, *ll, wrk, j, rnm, *tol, store);
        if *rnm <= *tol {
            *rnm = T::zero();
        }
        j += 1;
    }
    Ok(j)
}
758
/// Restores orthogonality of the current Lanczos vectors (`wrk.w0`/`wrk.w1`)
/// when the accumulated orthogonality estimates (`wrk.eta`) have grown past
/// `sqrt(eps)`: re-orthogonalizes against the stored vectors `ll..step` for up
/// to two passes, then resets the estimates to the baseline `eps1`.
///
/// `n` is the vector length (used only for the eps1 scaling); `rnm` is updated
/// in place with the purged residual norm.
fn purge<T: SvdFloat>(
    n: usize,
    ll: usize,
    wrk: &mut WorkSpace<T>,
    step: usize,
    rnm: &mut T,
    tol: T,
    store: &mut Store<T>,
) {
    // Nothing to purge until at least two vectors exist past the ll prefix.
    if step < ll + 2 {
        return;
    }

    let reps = T::eps().sqrt();
    let eps1 = T::eps() * T::from_f64(n as f64).unwrap().sqrt();
    let two = T::from_f64(2.0).unwrap();

    // Worst (largest-magnitude) orthogonality estimate decides if we purge.
    let k = svd_idamax(step - (ll + 1), &wrk.eta) + ll;
    if Float::abs(wrk.eta[k]) > reps {
        let reps1 = eps1 / reps;
        let mut iteration = 0;
        let mut flag = true;
        // At most two full Gram-Schmidt passes.
        while iteration < 2 && flag {
            if *rnm > tol {
                // bring in a lanczos vector t and orthogonalize both r and q against it
                let mut tq = T::zero();
                let mut tr = T::zero();
                for i in ll..step {
                    let v = store.retrq(i);
                    let t = svd_ddot(v, &wrk.w3);
                    tq += Float::abs(t);
                    svd_daxpy(-t, v, &mut wrk.w1);
                    let t = svd_ddot(v, &wrk.w4);
                    tr += Float::abs(t);
                    svd_daxpy(-t, v, &mut wrk.w0);
                }
                wrk.w3.copy_from_slice(&wrk.w1);
                let t = svd_ddot(&wrk.w0, &wrk.w3);
                tr += Float::abs(t);
                svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
                wrk.w4.copy_from_slice(&wrk.w0);
                *rnm = svd_norm(&wrk.w4);
                // Stop early once the removed components are negligible.
                if tq <= reps1 && tr <= *rnm * reps1 {
                    flag = false;
                }
            }
            iteration += 1;
        }
        // Orthogonality restored: reset the estimates to the baseline.
        for i in ll..=step {
            wrk.eta[i] = eps1;
            wrk.oldeta[i] = eps1;
        }
    }
}
813
/// Updates the running orthogonality estimates (`wrk.eta`/`wrk.oldeta`) after
/// Lanczos step `step`, using the three-term recurrence on the tridiagonal
/// entries (`alf`, `bet`). `rnm` is the current residual norm and `eps1` the
/// baseline roundoff level; `purge` consults these estimates to decide when
/// reorthogonalization is required.
fn ortbnd<T: SvdFloat>(wrk: &mut WorkSpace<T>, step: usize, rnm: T, eps1: T) {
    if step < 1 {
        return;
    }
    if !compare(rnm, T::zero()) && step > 1 {
        // Recurrence for the estimate against the first Lanczos vector.
        wrk.oldeta[0] = (wrk.bet[1] * wrk.eta[1] + (wrk.alf[0] - wrk.alf[step]) * wrk.eta[0]
            - wrk.bet[step] * wrk.oldeta[0])
            / rnm
            + eps1;
        if step > 2 {
            // Interior vectors use the full three-term recurrence.
            for i in 1..=step - 2 {
                wrk.oldeta[i] = (wrk.bet[i + 1] * wrk.eta[i + 1]
                    + (wrk.alf[i] - wrk.alf[step]) * wrk.eta[i]
                    + wrk.bet[i] * wrk.eta[i - 1]
                    - wrk.bet[step] * wrk.oldeta[i])
                    / rnm
                    + eps1;
            }
        }
    }
    // The two most recent vectors are orthogonal to roundoff by construction.
    wrk.oldeta[step - 1] = eps1;
    mem::swap(&mut wrk.oldeta, &mut wrk.eta);
    wrk.eta[step] = eps1;
}
838
/// Massages and refines the ritz-value error bounds after `step` Lanczos steps
/// and counts how many ritz values have stabilized as eigenvalues.
///
/// Near-duplicate ritz values have their bounds merged (root-sum-square) on
/// both sides of the worst-bounded entry; remaining bounds are then tightened
/// by the local eigenvalue gap. Sets `*enough` when an accepted ritz value
/// falls inside the unwanted interval `(endl, endr)`.
///
/// Returns the number of accepted (converged) ritz values.
fn error_bound<T: SvdFloat>(
    enough: &mut bool,
    endl: T,
    endr: T,
    ritz: &mut [T],
    bnd: &mut [T],
    step: usize,
    tol: T,
) -> usize {
    assert!(step > 0, "error_bound: expected 'step' to be non-zero");

    // massage error bounds for very close ritz values
    let mid = svd_idamax(step + 1, bnd);
    let sixteen = T::from_f64(16.0).unwrap();

    // Sweep downward from the top (this expression equals `step`).
    let mut i = ((step + 1) + (step - 1)) / 2;
    while i > mid + 1 {
        if Float::abs(ritz[i - 1] - ritz[i]) < T::eps34() * Float::abs(ritz[i])
            && bnd[i] > tol
            && bnd[i - 1] > tol
        {
            // Merge bounds of the near-degenerate pair into the lower index.
            bnd[i - 1] = (Float::powi(bnd[i], 2) + Float::powi(bnd[i - 1], 2)).sqrt();
            bnd[i] = T::zero();
        }
        i -= 1;
    }

    // Sweep upward from the bottom (this expression equals 1).
    let mut i = ((step + 1) - (step - 1)) / 2;
    while i + 1 < mid {
        if Float::abs(ritz[i + 1] - ritz[i]) < T::eps34() * Float::abs(ritz[i])
            && bnd[i] > tol
            && bnd[i + 1] > tol
        {
            // Merge bounds of the near-degenerate pair into the upper index.
            bnd[i + 1] = (Float::powi(bnd[i], 2) + Float::powi(bnd[i + 1], 2)).sqrt();
            bnd[i] = T::zero();
        }
        i += 1;
    }

    // refine the error bounds
    let mut neig = 0;
    let mut gapl = ritz[step] - ritz[0];
    for i in 0..=step {
        // gap = distance to the nearest neighboring ritz value.
        let mut gap = gapl;
        if i < step {
            gapl = ritz[i + 1] - ritz[i];
        }
        gap = Float::min(gap, gapl);
        if gap > bnd[i] {
            // Well-separated value: tighten the bound quadratically.
            bnd[i] *= bnd[i] / gap;
        }
        if bnd[i] <= sixteen * T::eps() * Float::abs(ritz[i]) {
            // Accepted as a converged eigenvalue.
            neig += 1;
            if !*enough {
                *enough = endl < ritz[i] && ritz[i] < endr;
            }
        }
    }
    neig
}
899
/// Computes all eigenvalues and eigenvectors of a symmetric tridiagonal
/// matrix by the implicit-shift QL method (a Rust port of the EISPACK
/// IMTQL2 routine).
///
/// # Parameters
/// - `nm`: leading dimension used to size the backing storage of `z`
///   (total storage scanned is `n * nm` — NOTE(review): the inner loops
///   step by `n`, not `nm`; presumably all call sites pass `nm == n`,
///   confirm against callers
/// - `n`: order of the tridiagonal matrix
/// - `d`: on entry the diagonal; on exit the eigenvalues in ascending order
/// - `e`: on entry the sub-diagonal in `e[1..n]`; destroyed on exit
/// - `z`: on entry typically the identity matrix; on exit the columns hold
///   the corresponding eigenvectors
/// - `max_imtqlb`: per-eigenvalue iteration cap; `None` uses
///   `MAX_IMTQLB_ITERATIONS`
///
/// # Errors
/// Returns `SvdLibError::Imtql2Error` if an eigenvalue does not converge
/// within `max_imtqlb` iterations.
fn imtql2<T: SvdFloat>(
    nm: usize,
    n: usize,
    d: &mut [T],
    e: &mut [T],
    z: &mut [T],
    max_imtqlb: Option<usize>,
) -> Result<(), SvdLibError> {
    let max_imtqlb = max_imtqlb.unwrap_or(MAX_IMTQLB_ITERATIONS);
    // A 1x1 matrix is already diagonal: d[0] is the eigenvalue, z is its vector.
    if n == 1 {
        return Ok(());
    }
    // After the early return above this only fires for n == 0.
    assert!(n > 1, "imtql2: expected 'n' to be > 1");
    let two = T::from_f64(2.0).unwrap();

    let last = n - 1;

    // Shift the sub-diagonal down one slot (EISPACK stores it in e[1..n];
    // the algorithm below wants it in e[0..n-1]) and clear the last entry.
    for i in 1..n {
        e[i - 1] = e[i];
    }
    e[last] = T::zero();

    let nnm = n * nm;
    for l in 0..n {
        let mut iteration = 0;

        // look for small sub-diagonal element
        while iteration <= max_imtqlb {
            let mut m = l;
            while m < n {
                if m == last {
                    break;
                }
                // e[m] is negligible when adding it does not change the sum of
                // the neighbouring diagonal magnitudes (floating-point test).
                let test = Float::abs(d[m]) + Float::abs(d[m + 1]);
                if compare(test, test + Float::abs(e[m])) {
                    break; // convergence = true;
                }
                m += 1;
            }
            // The submatrix starting at l is already 1x1: eigenvalue isolated.
            if m == l {
                break;
            }

            // error -- no convergence to an eigenvalue after 30 iterations.
            if iteration == max_imtqlb {
                return Err(SvdLibError::Imtql2Error(format!(
                    "imtql2 no convergence to an eigenvalue after {} iterations",
                    max_imtqlb
                )));
            }
            iteration += 1;

            // form shift
            let mut g = (d[l + 1] - d[l]) / (two * e[l]);
            let mut r = svd_pythag(g, T::one());
            g = d[m] - d[l] + e[l] / (g + svd_fsign(r, g));

            let mut s = T::one();
            let mut c = T::one();
            let mut p = T::zero();

            assert!(m > 0, "imtql2: expected 'm' to be non-zero");
            // Apply Givens rotations from row m-1 back down to row l.
            let mut i = m - 1;
            let mut underflow = false;
            while !underflow && i >= l {
                let mut f = s * e[i];
                let b = c * e[i];
                r = svd_pythag(f, g);
                e[i + 1] = r;
                if compare(r, T::zero()) {
                    // r underflowed to zero; abort this sweep and recover below.
                    underflow = true;
                } else {
                    s = f / r;
                    c = g / r;
                    g = d[i + 1] - p;
                    r = (d[i] - g) * s + two * c * b;
                    p = s * r;
                    d[i + 1] = g + p;
                    g = c * r - b;

                    // form vector
                    // Rotate columns i and i+1 of the eigenvector storage.
                    for k in (0..nnm).step_by(n) {
                        let index = k + i;
                        f = z[index + 1];
                        z[index + 1] = s * z[index] + c * f;
                        z[index] = c * z[index] - s * f;
                    }
                    // i is unsigned: stop before decrementing past zero.
                    if i == 0 {
                        break;
                    }
                    i -= 1;
                }
            } /* end while (underflow != FALSE && i >= l) */
            /*........ recover from underflow .........*/
            if underflow {
                d[i + 1] -= p;
            } else {
                d[l] -= p;
                e[l] = g;
            }
            e[m] = T::zero();
        }
    }

    // order the eigenvalues
    // Selection sort: cheap relative to the QL sweeps and swaps whole
    // eigenvector columns alongside the eigenvalues.
    for l in 1..n {
        let i = l - 1;
        let mut k = i;
        let mut p = d[i];
        for (j, item) in d.iter().enumerate().take(n).skip(l) {
            if *item < p {
                k = j;
                p = *item;
            }
        }

        // ...and corresponding eigenvectors
        if k != i {
            d[k] = d[i];
            d[i] = p;
            for j in (0..nnm).step_by(n) {
                z.swap(j + i, j + k);
            }
        }
    }

    Ok(())
}
1028
/// Rotates `a` left by `x` positions in place, so the element originally at
/// index `x % a.len()` ends up at index 0.
///
/// Replaces a hand-rolled juggling rotation with `slice::rotate_left`.
/// Improvements over the original:
/// - no panic on an empty slice (the original read `a[0]` unconditionally)
/// - `x` larger than `a.len()` is reduced modulo the length instead of
///   underflowing the index arithmetic
/// - the unnecessary `Float` bound is dropped (`Copy` is sufficient), which
///   is backward compatible for all existing callers
fn rotate_array<T: Copy>(a: &mut [T], x: usize) {
    let n = a.len();
    if n > 0 {
        // rotate_left requires mid <= len, so reduce first.
        a.rotate_left(x % n);
    }
}
1053
1054#[allow(non_snake_case)]
1055fn ritvec<T: SvdFloat>(
1056    A: &dyn SMat<T>,
1057    dimensions: usize,
1058    kappa: T,
1059    wrk: &mut WorkSpace<T>,
1060    steps: usize,
1061    neig: usize,
1062    store: &mut Store<T>,
1063) -> Result<SVDRawRec<T>, SvdLibError> {
1064    let js = steps + 1;
1065    let jsq = js * js;
1066
1067    let sparsity = T::one()
1068        - (T::from_usize(A.nnz()).unwrap()
1069            / (T::from_usize(A.nrows()).unwrap() * T::from_usize(A.ncols()).unwrap()));
1070
1071    let epsilon = <T as Float>::epsilon();
1072    let adaptive_eps = if sparsity > T::from_f64(0.99).unwrap() {
1073        // For very sparse matrices (>99%), use a more relaxed tolerance
1074        epsilon * T::from_f64(100.0).unwrap()
1075    } else if sparsity > T::from_f64(0.9).unwrap() {
1076        // For moderately sparse matrices (>90%), use a somewhat relaxed tolerance
1077        epsilon * T::from_f64(10.0).unwrap()
1078    } else {
1079        // For less sparse matrices, use standard epsilon
1080        epsilon
1081    };
1082
1083    let max_iterations_imtql2 = if sparsity > T::from_f64(0.999).unwrap() {
1084        // Ultra sparse (>99.9%) - needs many more iterations
1085        Some(500)
1086    } else if sparsity > T::from_f64(0.99).unwrap() {
1087        // Very sparse (>99%) - needs more iterations
1088        Some(300)
1089    } else if sparsity > T::from_f64(0.9).unwrap() {
1090        // Moderately sparse (>90%) - needs somewhat more iterations
1091        Some(200)
1092    } else {
1093        // Default iterations for less sparse matrices
1094        Some(50)
1095    };
1096
1097    let mut s = vec![T::zero(); jsq];
1098    // initialize s to an identity matrix
1099    for i in (0..jsq).step_by(js + 1) {
1100        s[i] = T::one();
1101    }
1102
1103    let mut Vt = DMat {
1104        cols: wrk.ncols,
1105        value: vec![T::zero(); wrk.ncols * dimensions],
1106    };
1107
1108    svd_dcopy(js, 0, &wrk.alf, &mut Vt.value);
1109    svd_dcopy(steps, 1, &wrk.bet, &mut wrk.w5);
1110
1111    // on return from imtql2(), `R.Vt.value` contains eigenvalues in
1112    // ascending order and `s` contains the corresponding eigenvectors
1113    imtql2(
1114        js,
1115        js,
1116        &mut Vt.value,
1117        &mut wrk.w5,
1118        &mut s,
1119        max_iterations_imtql2,
1120    )?;
1121
1122    let max_eigenvalue = Vt
1123        .value
1124        .iter()
1125        .fold(T::zero(), |max, &val| Float::max(max, Float::abs(val)));
1126
1127    let adaptive_kappa = if sparsity > T::from_f64(0.99).unwrap() {
1128        // More relaxed kappa for very sparse matrices
1129        kappa * T::from_f64(10.0).unwrap()
1130    } else {
1131        kappa
1132    };
1133
1134    let mut x = dimensions - 1;
1135
1136    let store_vectors: Vec<Vec<T>> = (0..js).map(|i| store.retrq(i).to_vec()).collect();
1137
1138    let significant_indices: Vec<usize> = (0..js)
1139        .into_par_iter()
1140        .filter(|&k| {
1141            let relative_bound =
1142                adaptive_kappa * Float::max(Float::abs(wrk.ritz[k]), max_eigenvalue * adaptive_eps);
1143            wrk.bnd[k] <= relative_bound && k + 1 > js - neig
1144        })
1145        .collect();
1146
1147    let nsig = significant_indices.len();
1148
1149    let mut vt_vectors: Vec<(usize, Vec<T>)> = significant_indices
1150        .into_par_iter()
1151        .map(|k| {
1152            let mut vec = vec![T::zero(); wrk.ncols];
1153
1154            for i in 0..js {
1155                let idx = k * js + i;
1156
1157                if Float::abs(s[idx]) > adaptive_eps {
1158                    for (j, item) in store_vectors[i].iter().enumerate().take(wrk.ncols) {
1159                        vec[j] += s[idx] * *item;
1160                    }
1161                }
1162            }
1163
1164            (k, vec)
1165        })
1166        .collect();
1167
1168    // Sort by k value to maintain original order
1169    vt_vectors.sort_by_key(|(k, _)| *k);
1170
1171    // final dimension size
1172    let d = dimensions.min(nsig);
1173    let mut S = vec![T::zero(); d];
1174    let mut Ut = DMat {
1175        cols: wrk.nrows,
1176        value: vec![T::zero(); wrk.nrows * d],
1177    };
1178
1179    // Create new Vt with the correct size
1180    let mut Vt = DMat {
1181        cols: wrk.ncols,
1182        value: vec![T::zero(); wrk.ncols * d],
1183    };
1184
1185    // Fill Vt with the vectors we computed
1186    for (i, (_, vec)) in vt_vectors.into_iter().take(d).enumerate() {
1187        let vt_offset = i * Vt.cols;
1188        Vt.value[vt_offset..vt_offset + Vt.cols].copy_from_slice(&vec);
1189    }
1190
1191    // Prepare for parallel computation of S and Ut
1192    let mut ab_products = Vec::with_capacity(d);
1193    let mut a_products = Vec::with_capacity(d);
1194
1195    // First compute all matrix-vector products sequentially
1196    for i in 0..d {
1197        let vt_offset = i * Vt.cols;
1198        let vt_vec = &Vt.value[vt_offset..vt_offset + Vt.cols];
1199
1200        let mut tmp_vec = vec![T::zero(); Vt.cols];
1201        let mut ut_vec = vec![T::zero(); wrk.nrows];
1202
1203        // Matrix-vector products with A and A'A
1204        svd_opb(A, vt_vec, &mut tmp_vec, &mut wrk.temp, wrk.transposed);
1205        A.svd_opa(vt_vec, &mut ut_vec, wrk.transposed);
1206
1207        ab_products.push(tmp_vec);
1208        a_products.push(ut_vec);
1209    }
1210
1211    let results: Vec<(usize, T)> = (0..d)
1212        .into_par_iter()
1213        .map(|i| {
1214            let vt_offset = i * Vt.cols;
1215            let vt_vec = &Vt.value[vt_offset..vt_offset + Vt.cols];
1216            let tmp_vec = &ab_products[i];
1217
1218            // Compute singular value
1219            let t = svd_ddot(vt_vec, tmp_vec);
1220            let sval = Float::max(t, T::zero()).sqrt();
1221
1222            (i, sval)
1223        })
1224        .collect();
1225
1226    // Process results and scale the vectors
1227    for (i, sval) in results {
1228        S[i] = sval;
1229        let ut_offset = i * Ut.cols;
1230        let mut ut_vec = a_products[i].clone();
1231
1232        if sval > adaptive_eps {
1233            svd_dscal(T::one() / sval, &mut ut_vec);
1234        } else {
1235            let dls = Float::max(sval, adaptive_eps);
1236            let safe_scale = T::one() / dls;
1237            svd_dscal(safe_scale, &mut ut_vec);
1238        }
1239
1240        // Copy to output
1241        Ut.value[ut_offset..ut_offset + Ut.cols].copy_from_slice(&ut_vec);
1242    }
1243
1244    Ok(SVDRawRec {
1245        // Dimensionality (rank)
1246        d,
1247        // Significant values
1248        nsig,
1249        // DMat Ut  Transpose of left singular vectors. (d by m)
1250        //          The vectors are the rows of Ut.
1251        Ut,
1252        // Array of singular values. (length d)
1253        S,
1254        // DMat Vt  Transpose of right singular vectors. (d by n)
1255        //          The vectors are the rows of Vt.
1256        Vt,
1257    })
1258}
1259
#[allow(non_snake_case)]
#[allow(clippy::too_many_arguments)]
/// Drives the Lanczos recursion (LANSO): repeatedly extends the Krylov
/// subspace, analyzes the resulting tridiagonal matrix, and stops once at
/// least `dim` eigenvalues inside `end_interval` are accepted (or the
/// iteration budget runs out).
///
/// # Parameters
/// - `A`: sparse matrix whose `A'A` eigenpairs are sought
/// - `dim`: desired number of converged eigenvalues
/// - `iterations`: hard cap on Lanczos steps
/// - `end_interval`: `[endl, endr]` — eigenvalues strictly inside this
///   interval count toward the "enough" test in `error_bound`
/// - `wrk`: workspace (alf/bet/ritz/bnd/eta/oldeta/…) mutated throughout
/// - `neig`: out-parameter, number of accepted eigenvalues
/// - `store`: stores the Lanczos basis vectors for later use by `ritvec`
/// - `random_seed`: seed for the random starting vector in `stpone`
///
/// Returns the index of the last Lanczos step taken.
///
/// # Errors
/// Propagates failures from `stpone`, `lanczos_step` and `imtqlb`.
fn lanso<T: SvdFloat>(
    A: &dyn SMat<T>,
    dim: usize,
    iterations: usize,
    end_interval: &[T; 2],
    wrk: &mut WorkSpace<T>,
    neig: &mut usize,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<usize, SvdLibError> {
    // Fraction of zero entries; used to scale tolerances and iteration caps.
    let sparsity = T::one()
        - (T::from_usize(A.nnz()).unwrap()
            / (T::from_usize(A.nrows()).unwrap() * T::from_usize(A.ncols()).unwrap()));
    let max_iterations_imtqlb = if sparsity > T::from_f64(0.999).unwrap() {
        // Ultra sparse (>99.9%) - needs many more iterations
        Some(500)
    } else if sparsity > T::from_f64(0.99).unwrap() {
        // Very sparse (>99%) - needs more iterations
        Some(300)
    } else if sparsity > T::from_f64(0.9).unwrap() {
        // Moderately sparse (>90%) - needs somewhat more iterations
        Some(100)
    } else {
        // Default iterations for less sparse matrices
        Some(50)
    };

    let epsilon = <T as Float>::epsilon();
    let adaptive_eps = if sparsity > T::from_f64(0.99).unwrap() {
        // For very sparse matrices (>99%), use a more relaxed tolerance
        epsilon * T::from_f64(100.0).unwrap()
    } else if sparsity > T::from_f64(0.9).unwrap() {
        // For moderately sparse matrices (>90%), use a somewhat relaxed tolerance
        epsilon * T::from_f64(10.0).unwrap()
    } else {
        // For less sparse matrices, use standard epsilon
        epsilon
    };

    let (endl, endr) = (end_interval[0], end_interval[1]);

    /* take the first step */
    // stpone builds the random starting vector and returns the initial
    // residual norm and convergence tolerance.
    let rnm_tol = stpone(A, wrk, store, random_seed)?;
    let mut rnm = rnm_tol.0;
    let mut tol = rnm_tol.1;

    // eps1 scales with sqrt(ncols): orthogonality-loss threshold seed.
    let eps1 = adaptive_eps * T::from_f64(wrk.ncols as f64).unwrap().sqrt();
    wrk.eta[0] = eps1;
    wrk.oldeta[0] = eps1;
    let mut ll = 0;
    let mut first = 1;
    let mut last = iterations.min(dim.max(8) + dim);
    let mut enough = false;
    let mut j = 0;
    let mut intro = 0;

    while !enough {
        // A residual below tolerance is treated as exactly zero.
        if rnm <= tol {
            rnm = T::zero();
        }

        // the actual lanczos loop
        let steps = lanczos_step(
            A,
            wrk,
            first,
            last,
            &mut ll,
            &mut enough,
            &mut rnm,
            &mut tol,
            store,
        )?;
        // On early termination `steps` is one past the last valid step.
        j = match enough {
            true => steps - 1,
            false => last - 1,
        };

        first = j + 1;
        wrk.bet[first] = rnm;

        // analyze T
        // Split T at negligible off-diagonals and eigen-solve each
        // unreduced tridiagonal submatrix with imtqlb.
        let mut l = 0;
        for _ in 0..j {
            if l > j {
                break;
            }

            let mut i = l;
            while i <= j {
                if Float::abs(wrk.bet[i + 1]) <= adaptive_eps {
                    break;
                }
                i += 1;
            }
            i = i.min(j);

            // now i is at the end of an unreduced submatrix
            let sz = i - l;
            svd_dcopy(sz + 1, l, &wrk.alf, &mut wrk.ritz);
            svd_dcopy(sz, l + 1, &wrk.bet, &mut wrk.w5);

            imtqlb(
                sz + 1,
                &mut wrk.ritz[l..],
                &mut wrk.w5[l..],
                &mut wrk.bnd[l..],
                max_iterations_imtqlb,
            )?;

            // Scale the error bounds by the current residual norm.
            for m in l..=i {
                wrk.bnd[m] = rnm * Float::abs(wrk.bnd[m]);
            }
            l = i + 1;
        }

        // sort eigenvalues into increasing order
        insert_sort(j + 1, &mut wrk.ritz, &mut wrk.bnd);

        *neig = error_bound(&mut enough, endl, endr, &mut wrk.ritz, &mut wrk.bnd, j, tol);

        // should we stop?
        if *neig < dim {
            if *neig == 0 {
                // Nothing converged yet: take a fixed number of extra steps.
                last = first + 9;
                intro = first;
            } else {
                let extra_steps = if sparsity > T::from_f64(0.99).unwrap() {
                    5 // For very sparse matrices, add extra steps
                } else {
                    0
                };

                // Heuristic: extrapolate how many more steps are needed from
                // the current convergence rate.
                last = first + 3.max(1 + ((j - intro) * (dim - *neig)) / *neig) + extra_steps;
            }
            last = last.min(iterations);
        } else {
            enough = true
        }
        enough = enough || first >= iterations;
    }
    // Save the final Lanczos vector for ritvec.
    store.storq(j, &wrk.w1);
    Ok(j)
}
1406
1407impl<T: SvdFloat + 'static> SvdRec<T> {
1408    pub fn recompose(&self) -> Array2<T> {
1409        let sdiag = Array2::from_diag(&self.s);
1410        self.u.dot(&sdiag).dot(&self.vt)
1411    }
1412}
1413
1414impl<T: Float + Zero + AddAssign + Clone + Sync> SMat<T> for nalgebra_sparse::csc::CscMatrix<T> {
1415    fn nrows(&self) -> usize {
1416        self.nrows()
1417    }
1418    fn ncols(&self) -> usize {
1419        self.ncols()
1420    }
1421    fn nnz(&self) -> usize {
1422        self.nnz()
1423    }
1424
1425    /// takes an n-vector x and returns A*x in y
1426    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
1427        let nrows = if transposed {
1428            self.ncols()
1429        } else {
1430            self.nrows()
1431        };
1432        let ncols = if transposed {
1433            self.nrows()
1434        } else {
1435            self.ncols()
1436        };
1437        assert_eq!(
1438            x.len(),
1439            ncols,
1440            "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}",
1441            x.len(),
1442            ncols
1443        );
1444        assert_eq!(
1445            y.len(),
1446            nrows,
1447            "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}",
1448            y.len(),
1449            nrows
1450        );
1451
1452        let (major_offsets, minor_indices, values) = self.csc_data();
1453
1454        for y_val in y.iter_mut() {
1455            *y_val = T::zero();
1456        }
1457
1458        if transposed {
1459            for (i, yval) in y.iter_mut().enumerate() {
1460                for j in major_offsets[i]..major_offsets[i + 1] {
1461                    *yval += values[j] * x[minor_indices[j]];
1462                }
1463            }
1464        } else {
1465            for (i, xval) in x.iter().enumerate() {
1466                for j in major_offsets[i]..major_offsets[i + 1] {
1467                    y[minor_indices[j]] += values[j] * *xval;
1468                }
1469            }
1470        }
1471    }
1472
1473    fn compute_column_means(&self) -> Vec<T> {
1474        todo!()
1475    }
1476}
1477
1478impl<T: Float + Zero + AddAssign + Clone + Sync + Send + std::ops::MulAssign > SMat<T>
1479    for nalgebra_sparse::csr::CsrMatrix<T>
1480{
1481    fn nrows(&self) -> usize {
1482        self.nrows()
1483    }
1484    fn ncols(&self) -> usize {
1485        self.ncols()
1486    }
1487    fn nnz(&self) -> usize {
1488        self.nnz()
1489    }
1490
1491    /// takes an n-vector x and returns A*x in y
1492    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
1493        //TODO parallelize me please
1494        let nrows = if transposed {
1495            self.ncols()
1496        } else {
1497            self.nrows()
1498        };
1499        let ncols = if transposed {
1500            self.nrows()
1501        } else {
1502            self.ncols()
1503        };
1504        assert_eq!(
1505            x.len(),
1506            ncols,
1507            "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}",
1508            x.len(),
1509            ncols
1510        );
1511        assert_eq!(
1512            y.len(),
1513            nrows,
1514            "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}",
1515            y.len(),
1516            nrows
1517        );
1518
1519        let (major_offsets, minor_indices, values) = self.csr_data();
1520
1521        y.fill(T::zero());
1522
1523        if !transposed {
1524            let nrows = self.nrows();
1525            let chunk_size = crate::utils::determine_chunk_size(nrows);
1526
1527            // Create thread-local vectors with results
1528            let results: Vec<(usize, T)> = (0..nrows)
1529                .into_par_iter()
1530                .map(|i| {
1531                    let mut sum = T::zero();
1532                    for j in major_offsets[i]..major_offsets[i + 1] {
1533                        sum += values[j] * x[minor_indices[j]];
1534                    }
1535                    (i, sum)
1536                })
1537                .collect();
1538
1539            // Apply the results to y
1540            for (i, val) in results {
1541                y[i] = val;
1542            }
1543        } else {
1544            let nrows = self.nrows();
1545            let chunk_size = crate::utils::determine_chunk_size(nrows);
1546
1547            // Process input in chunks and create partial results
1548            let results: Vec<Vec<T>> = (0..((nrows + chunk_size - 1) / chunk_size))
1549                .into_par_iter()
1550                .map(|chunk_idx| {
1551                    let start = chunk_idx * chunk_size;
1552                    let end = (start + chunk_size).min(nrows);
1553
1554                    let mut local_y = vec![T::zero(); y.len()];
1555                    for i in start..end {
1556                        let row_val = x[i];
1557                        for j in major_offsets[i]..major_offsets[i + 1] {
1558                            let col = minor_indices[j];
1559                            local_y[col] += values[j] * row_val;
1560                        }
1561                    }
1562                    local_y
1563                })
1564                .collect();
1565
1566            // Combine partial results
1567            for local_y in results {
1568                for (idx, val) in local_y.iter().enumerate() {
1569                    if !val.is_zero() {
1570                        y[idx] += *val;
1571                    }
1572                }
1573            }
1574        }
1575    }
1576
1577    fn compute_column_means(&self) -> Vec<T> {
1578        let rows = self.nrows();
1579        let cols = self.ncols();
1580        let row_count_recip = T::one() / T::from(rows).unwrap();
1581
1582        let mut col_sums = vec![T::zero(); cols];
1583        let (row_offsets, col_indices, values) = self.csr_data();
1584
1585        // Directly accumulate column sums from sparse representation
1586        for i in 0..rows {
1587            for j in row_offsets[i]..row_offsets[i + 1] {
1588                let col = col_indices[j];
1589                col_sums[col] += values[j];
1590            }
1591        }
1592
1593        // Convert to means
1594        for j in 0..cols {
1595            col_sums[j] *= row_count_recip;
1596        }
1597
1598        col_sums
1599    }
1600}
1601
1602impl<T: Float + Zero + AddAssign + Clone + Sync> SMat<T> for nalgebra_sparse::coo::CooMatrix<T> {
1603    fn nrows(&self) -> usize { self.nrows() }
1604    fn ncols(&self) -> usize { self.ncols() }
1605    fn nnz(&self) -> usize { self.nnz() }
1606
1607    /// takes an n-vector x and returns A*x in y
1608    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
1609        let nrows = if transposed { self.ncols() } else { self.nrows() };
1610        let ncols = if transposed { self.nrows() } else { self.ncols() };
1611        assert_eq!(x.len(), ncols, "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
1612        assert_eq!(y.len(), nrows, "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}", y.len(), nrows);
1613
1614        for y_val in y.iter_mut() {
1615            *y_val = T::zero();
1616        }
1617
1618        if transposed {
1619            for (i, j, v) in self.triplet_iter() {
1620                y[j] += *v * x[i];
1621            }
1622        } else {
1623            for (i, j, v) in self.triplet_iter() {
1624                y[i] += *v * x[j];
1625            }
1626        }
1627    }
1628
1629    fn compute_column_means(&self) -> Vec<T> {
1630        todo!()
1631    }
1632}