use crate::error::SvdLibError;
use ndarray::{Array, Array1, Array2};
use num_traits::real::Real;
use num_traits::{Float, FromPrimitive, One, Zero};
use rand::rngs::StdRng;
use rand::{thread_rng, Rng, SeedableRng};
use std::fmt::Debug;
use std::iter::Sum;
use std::mem;
use std::ops::{AddAssign, MulAssign, Neg, SubAssign};

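/// Trait for sparse (or otherwise matrix-free) SVD operands: the solver only
/// needs the shape, the non-zero count, and a matrix-vector product.
///
/// A minimal sketch of an implementation for a dense row-major matrix (the
/// `DenseMat` type here is hypothetical, not part of this crate):
///
/// ```ignore
/// struct DenseMat {
///     nrows: usize,
///     ncols: usize,
///     data: Vec<f64>, // row-major, nrows * ncols
/// }
///
/// impl SMat<f64> for DenseMat {
///     fn nrows(&self) -> usize { self.nrows }
///     fn ncols(&self) -> usize { self.ncols }
///     fn nnz(&self) -> usize { self.data.iter().filter(|v| **v != 0.0).count() }
///
///     // y = A * x, or y = A' * x when `transposed` is set
///     fn svd_opa(&self, x: &[f64], y: &mut [f64], transposed: bool) {
///         y.iter_mut().for_each(|v| *v = 0.0);
///         for r in 0..self.nrows {
///             for c in 0..self.ncols {
///                 let v = self.data[r * self.ncols + c];
///                 if transposed { y[c] += v * x[r] } else { y[r] += v * x[c] }
///             }
///         }
///     }
/// }
/// ```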
pub trait SMat<T: Float> {
    fn nrows(&self) -> usize;
    fn ncols(&self) -> usize;
    fn nnz(&self) -> usize;

    /// Computes `y = A * x`, or `y = A' * x` when `transposed` is set.
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool);
}

/// Result of a singular value decomposition.
#[derive(Debug, Clone, PartialEq)]
pub struct SvdRec<T: Float> {
    /// Number of significant singular triplets found
    pub d: usize,
    /// Transposed left singular vectors, `d` x `nrows`
    pub ut: Array2<T>,
    /// Singular values, length `d`
    pub s: Array1<T>,
    /// Transposed right singular vectors, `d` x `ncols`
    pub vt: Array2<T>,
    /// Details of the computation
    pub diagnostics: Diagnostics<T>,
}

/// Diagnostics collected while computing the decomposition.
#[derive(Debug, Clone, PartialEq)]
pub struct Diagnostics<T: Float> {
    /// Number of non-zero entries in the input matrix
    pub non_zero: usize,
    /// Number of dimensions requested
    pub dimensions: usize,
    /// Maximum number of Lanczos iterations allowed
    pub iterations: usize,
    /// Whether the matrix was transposed internally
    pub transposed: bool,
    /// Number of Lanczos steps actually taken
    pub lanczos_steps: usize,
    /// Number of Ritz values that stabilized
    pub ritz_values_stabilized: usize,
    /// Number of significant values discovered
    pub significant_values: usize,
    /// Number of singular values found
    pub singular_values: usize,
    /// The `end_interval` the computation ran with
    pub end_interval: [T; 2],
    /// The effective `kappa` the computation ran with
    pub kappa: T,
    /// The random seed that produced the starting vector
    pub random_seed: u32,
}

/// Floating-point scalar trait required by the solver.
pub trait SvdFloat:
    Float
    + FromPrimitive
    + Debug
    + Send
    + Sync
    + Zero
    + One
    + AddAssign
    + SubAssign
    + MulAssign
    + Neg<Output = Self>
    + Sum
{
    /// Machine epsilon
    fn eps() -> Self;
    /// Machine epsilon raised to the 3/4 power
    fn eps34() -> Self;
    /// Approximate equality within machine epsilon
    fn compare(a: Self, b: Self) -> bool;
}

impl SvdFloat for f32 {
    fn eps() -> Self {
        f32::EPSILON
    }

    fn eps34() -> Self {
        f32::EPSILON.powf(0.75)
    }

    fn compare(a: Self, b: Self) -> bool {
        (b - a).abs() < f32::EPSILON
    }
}

impl SvdFloat for f64 {
    fn eps() -> Self {
        f64::EPSILON
    }

    fn eps34() -> Self {
        f64::EPSILON.powf(0.75)
    }

    fn compare(a: Self, b: Self) -> bool {
        (b - a).abs() < f64::EPSILON
    }
}

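/// Computes the SVD of `a` with default settings: all dimensions, all
/// iterations, `kappa = 1.0e-6`, and a random starting seed.
///
/// A usage sketch (the 3x3 matrix is illustrative):
///
/// ```ignore
/// use nalgebra_sparse::coo::CooMatrix;
///
/// let mut a = CooMatrix::<f64>::new(3, 3);
/// a.push(0, 0, 1.0);
/// a.push(1, 1, 2.0);
/// a.push(2, 2, 3.0);
///
/// let rec = svd(&a)?;
/// // rec.ut is d x nrows, rec.s has length d, rec.vt is d x ncols
/// assert_eq!(rec.recompose().shape(), &[3, 3]);
/// ```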
pub fn svd<T, M>(a: &M) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let eps_small = T::from_f64(-1.0e-30).unwrap();
    let eps_large = T::from_f64(1.0e-30).unwrap();
    let kappa = T::from_f64(1.0e-6).unwrap();
    svd_las2(a, 0, 0, &[eps_small, eps_large], kappa, 0)
}

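/// Computes the SVD of `a` truncated to the top `dimensions` singular
/// triplets, with the same defaults as [`svd`] otherwise.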
pub fn svd_dim<T, M>(a: &M, dimensions: usize) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let eps_small = T::from_f64(-1.0e-30).unwrap();
    let eps_large = T::from_f64(1.0e-30).unwrap();
    let kappa = T::from_f64(1.0e-6).unwrap();

    svd_las2(a, dimensions, 0, &[eps_small, eps_large], kappa, 0)
}

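/// Like [`svd_dim`], but takes an explicit `random_seed` so the Lanczos
/// starting vector, and therefore the result, is reproducible; a seed of 0
/// selects a random seed.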
pub fn svd_dim_seed<T, M>(
    a: &M,
    dimensions: usize,
    random_seed: u32,
) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let eps_small = T::from_f64(-1.0e-30).unwrap();
    let eps_large = T::from_f64(1.0e-30).unwrap();
    let kappa = T::from_f64(1.0e-6).unwrap();

    svd_las2(
        a,
        dimensions,
        0,
        &[eps_small, eps_large],
        kappa,
        random_seed,
    )
}

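/// Computes the SVD with the Lanczos algorithm with selective
/// orthogonalization (LAS2), exposing every tuning parameter:
///
/// - `dimensions`: number of singular triplets requested (0 means all)
/// - `iterations`: cap on Lanczos steps (0 derives a cap from the matrix)
/// - `end_interval`: interval containing unwanted eigenvalues
/// - `kappa`: relative accuracy at which Ritz values are accepted as eigenvalues
/// - `random_seed`: seed for the starting vector (0 picks one at random)
///
/// A call sketch using the same values [`svd`] passes:
///
/// ```ignore
/// let rec = svd_las2(&a, 0, 0, &[-1.0e-30, 1.0e-30], 1.0e-6, 0)?;
/// ```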
pub fn svd_las2<T, M>(
    a: &M,
    dimensions: usize,
    iterations: usize,
    end_interval: &[T; 2],
    kappa: T,
    random_seed: u32,
) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let random_seed = match random_seed > 0 {
        true => random_seed,
        false => thread_rng().gen(),
    };

    let min_nrows_ncols = a.nrows().min(a.ncols());

    let dimensions = match dimensions {
        n if n == 0 || n > min_nrows_ncols => min_nrows_ncols,
        _ => dimensions,
    };

    let iterations = match iterations {
        n if n == 0 || n > min_nrows_ncols => min_nrows_ncols,
        n if n < dimensions => dimensions,
        _ => iterations,
    };

    if dimensions < 2 {
        return Err(SvdLibError::Las2Error(format!(
            "svd_las2: insufficient dimensions: {dimensions}"
        )));
    }

    assert!(dimensions > 1 && dimensions <= min_nrows_ncols);
    assert!(iterations >= dimensions && iterations <= min_nrows_ncols);

    // Operate on the transpose when the matrix is much wider than it is tall
    let transposed = (a.ncols() as f64) >= ((a.nrows() as f64) * 1.2);
    let nrows = if transposed { a.ncols() } else { a.nrows() };
    let ncols = if transposed { a.nrows() } else { a.ncols() };

    let mut wrk = WorkSpace::new(nrows, ncols, transposed, iterations)?;
    let mut store = Store::new(ncols)?;

    let mut neig = 0;
    let steps = lanso(
        a,
        dimensions,
        iterations,
        end_interval,
        &mut wrk,
        &mut neig,
        &mut store,
        random_seed,
    )?;

    let kappa = kappa.abs().max(T::eps34());
    let mut r = ritvec(a, dimensions, kappa, &mut wrk, steps, neig, &mut store)?;

    // Undo the internal transposition by swapping the singular vectors
    if transposed {
        mem::swap(&mut r.Ut, &mut r.Vt);
    }

    Ok(SvdRec {
        d: r.d,
        ut: Array2::from_shape_vec((r.d, r.Ut.cols), r.Ut.value)?,
        s: Array::from_shape_vec(r.d, r.S)?,
        vt: Array2::from_shape_vec((r.d, r.Vt.cols), r.Vt.value)?,
        diagnostics: Diagnostics {
            non_zero: a.nnz(),
            dimensions,
            iterations,
            transposed,
            lanczos_steps: steps + 1,
            ritz_values_stabilized: neig,
            significant_values: r.d,
            singular_values: r.nsig,
            end_interval: *end_interval,
            kappa,
            random_seed,
        },
    })
}

const MAXLL: usize = 2;

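/// Backing store for the Lanczos basis: `storq`/`retrq` hold the orthogonal
/// q-vectors (offset by `MAXLL` slots), while `storp`/`retrp` hold the first
/// few p-vectors used for re-orthogonalization.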
#[derive(Debug, Clone, PartialEq)]
struct Store<T: Float> {
    n: usize,
    vecs: Vec<Vec<T>>,
}

impl<T: Float + Zero + Clone> Store<T> {
    fn new(n: usize) -> Result<Self, SvdLibError> {
        Ok(Self { n, vecs: vec![] })
    }

    fn storq(&mut self, idx: usize, v: &[T]) {
        while idx + MAXLL >= self.vecs.len() {
            self.vecs.push(vec![T::zero(); self.n]);
        }
        self.vecs[idx + MAXLL].copy_from_slice(v);
    }

    fn storp(&mut self, idx: usize, v: &[T]) {
        while idx >= self.vecs.len() {
            self.vecs.push(vec![T::zero(); self.n]);
        }
        self.vecs[idx].copy_from_slice(v);
    }

    fn retrq(&mut self, idx: usize) -> &[T] {
        &self.vecs[idx + MAXLL]
    }

    fn retrp(&mut self, idx: usize) -> &[T] {
        &self.vecs[idx]
    }
}

/// Scratch vectors and Lanczos coefficients shared across the computation.
#[derive(Debug, Clone, PartialEq)]
struct WorkSpace<T: Float> {
    nrows: usize,
    ncols: usize,
    transposed: bool,
    w0: Vec<T>,
    w1: Vec<T>,
    w2: Vec<T>,
    w3: Vec<T>,
    w4: Vec<T>,
    w5: Vec<T>,
    alf: Vec<T>,
    eta: Vec<T>,
    oldeta: Vec<T>,
    bet: Vec<T>,
    bnd: Vec<T>,
    ritz: Vec<T>,
    temp: Vec<T>,
}

impl<T: Float + Zero + FromPrimitive> WorkSpace<T> {
    fn new(
        nrows: usize,
        ncols: usize,
        transposed: bool,
        iterations: usize,
    ) -> Result<Self, SvdLibError> {
        Ok(Self {
            nrows,
            ncols,
            transposed,
            w0: vec![T::zero(); ncols],
            w1: vec![T::zero(); ncols],
            w2: vec![T::zero(); ncols],
            w3: vec![T::zero(); ncols],
            w4: vec![T::zero(); ncols],
            w5: vec![T::zero(); ncols],
            alf: vec![T::zero(); iterations],
            eta: vec![T::zero(); iterations],
            oldeta: vec![T::zero(); iterations],
            bet: vec![T::zero(); 1 + iterations],
            ritz: vec![T::zero(); 1 + iterations],
            bnd: vec![T::from_f64(f64::MAX).unwrap(); 1 + iterations],
            temp: vec![T::zero(); nrows],
        })
    }
}

/// Dense row-major matrix holding singular vectors during the computation.
#[derive(Debug, Clone, PartialEq)]
struct DMat<T: Float> {
    cols: usize,
    value: Vec<T>,
}

#[allow(non_snake_case)]
#[derive(Debug, Clone, PartialEq)]
struct SVDRawRec<T: Float> {
    d: usize,
    nsig: usize,
    Ut: DMat<T>,
    S: Vec<T>,
    Vt: DMat<T>,
}

fn compare<T: SvdFloat>(computed: T, expected: T) -> bool {
    T::compare(computed, expected)
}

/// Sorts the first `n` elements of `array1` ascending, applying the same
/// swaps to `array2` so the two arrays stay paired.
fn insert_sort<T: PartialOrd>(n: usize, array1: &mut [T], array2: &mut [T]) {
    for i in 1..n {
        for j in (1..=i).rev() {
            if array1[j - 1] <= array1[j] {
                break;
            }
            array1.swap(j - 1, j);
            array2.swap(j - 1, j);
        }
    }
}

#[allow(non_snake_case)]
#[rustfmt::skip]
fn svd_opb<T: Float>(A: &dyn SMat<T>, x: &[T], y: &mut [T], temp: &mut [T], transposed: bool) {
    let nrows = if transposed { A.ncols() } else { A.nrows() };
    let ncols = if transposed { A.nrows() } else { A.ncols() };
    assert_eq!(x.len(), ncols, "svd_opb: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
    assert_eq!(y.len(), ncols, "svd_opb: y must be A.ncols() in length, y = {}, A.ncols = {}", y.len(), ncols);
    assert_eq!(temp.len(), nrows, "svd_opb: temp must be A.nrows() in length, temp = {}, A.nrows = {}", temp.len(), nrows);
    // y = A'A * x, computed as temp = A * x followed by y = A' * temp
    A.svd_opa(x, temp, transposed);
    A.svd_opa(temp, y, !transposed);
}

/// y += da * x (BLAS `daxpy`)
fn svd_daxpy<T: Float + AddAssign>(da: T, x: &[T], y: &mut [T]) {
    for (xval, yval) in x.iter().zip(y.iter_mut()) {
        *yval += da * *xval
    }
}

/// Index of the element with the largest absolute value among the first `n`
/// (BLAS `idamax`)
fn svd_idamax<T: Float>(n: usize, x: &[T]) -> usize {
    assert!(n > 0, "svd_idamax: unexpected inputs!");

    match n {
        1 => 0,
        _ => {
            let mut imax = 0;
            for (i, xval) in x.iter().enumerate().take(n).skip(1) {
                if xval.abs() > x[imax].abs() {
                    imax = i;
                }
            }
            imax
        }
    }
}

/// Returns `a` if `a` and `b` share a sign, `-a` otherwise
fn svd_fsign<T: Float>(a: T, b: T) -> T {
    match (a >= T::zero() && b >= T::zero()) || (a < T::zero() && b < T::zero()) {
        true => a,
        false => -a,
    }
}

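/// Robust `sqrt(a^2 + b^2)` that avoids overflow and underflow (EISPACK-style
/// `pythag` iteration); for example, `svd_pythag(3.0, 4.0)` evaluates to `5.0`.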
fn svd_pythag<T: SvdFloat + FromPrimitive>(a: T, b: T) -> T {
    match a.abs().max(b.abs()) {
        n if n > T::zero() => {
            let mut p = n;
            let mut r = (a.abs().min(b.abs()) / p).powi(2);
            let four = T::from_f64(4.0).unwrap();
            let two = T::from_f64(2.0).unwrap();
            let mut t = four + r;
            while !compare(t, four) {
                let s = r / t;
                let u = T::one() + two * s;
                p = p * u;
                r = (s / u).powi(2);
                t = four + r;
            }
            p
        }
        _ => T::zero(),
    }
}

/// Dot product of `x` and `y` (BLAS `ddot`)
fn svd_ddot<T: Float + Sum<T>>(x: &[T], y: &[T]) -> T {
    x.iter().zip(y).map(|(a, b)| *a * *b).sum()
}

/// Euclidean norm of `x`
fn svd_norm<T: Float + Sum<T>>(x: &[T]) -> T {
    svd_ddot(x, x).sqrt()
}

/// y = d * x
fn svd_datx<T: Float + Sum<T>>(d: T, x: &[T], y: &mut [T]) {
    for (i, xval) in x.iter().enumerate() {
        y[i] = d * *xval;
    }
}

/// x *= d (BLAS `dscal`)
fn svd_dscal<T: Float + MulAssign>(d: T, x: &mut [T]) {
    for elem in x.iter_mut() {
        *elem *= d;
    }
}

/// Copies the `n`-element window of `x` starting at `offset` into the same
/// window of `y`, reversing the order of the elements
fn svd_dcopy<T: Float + Copy>(n: usize, offset: usize, x: &[T], y: &mut [T]) {
    if n > 0 {
        let start = n - 1;
        for i in 0..n {
            y[offset + start - i] = x[offset + i];
        }
    }
}

const MAX_IMTQLB_ITERATIONS: usize = 100;

/// Finds the eigenvalues of a symmetric tridiagonal matrix (diagonal `d`,
/// off-diagonal `e`) by the implicit QL method, tracking the first components
/// of the eigenvectors in `bnd`.
fn imtqlb<T: SvdFloat>(
    n: usize,
    d: &mut [T],
    e: &mut [T],
    bnd: &mut [T],
    max_imtqlb: Option<usize>,
) -> Result<(), SvdLibError> {
    let max_imtqlb = max_imtqlb.unwrap_or(MAX_IMTQLB_ITERATIONS);
    if n == 1 {
        return Ok(());
    }

    let matrix_size_factor = T::from_f64((n as f64).sqrt()).unwrap();

    bnd[0] = T::one();
    let last = n - 1;
    for i in 1..=last {
        bnd[i] = T::zero();
        e[i - 1] = e[i];
    }
    e[last] = T::zero();

    let mut i = 0;
    let mut had_convergence_issues = false;

    for l in 0..=last {
        let mut iteration = 0;
        let mut p = d[l];
        let mut f = bnd[l];

        while iteration <= max_imtqlb {
            // Look for a small off-diagonal element to split the matrix
            let mut m = l;
            while m < n {
                if m == last {
                    break;
                }
                let test = d[m].abs() + d[m + 1].abs();
                let tol = T::epsilon()
                    * T::from_f64(100.0).unwrap()
                    * test.max(T::one())
                    * matrix_size_factor;
                if e[m].abs() <= tol {
                    break;
                }
                m += 1;
            }

            if m == l {
                // Eigenvalue found; insert it into d (and bnd) in sorted order
                let mut exchange = true;
                if l > 0 {
                    i = l;
                    while i >= 1 && exchange {
                        if p < d[i - 1] {
                            d[i] = d[i - 1];
                            bnd[i] = bnd[i - 1];
                            i -= 1;
                        } else {
                            exchange = false;
                        }
                    }
                }
                if exchange {
                    i = 0;
                }
                d[i] = p;
                bnd[i] = f;
                iteration = max_imtqlb + 1; // exit the while loop
            } else {
                if iteration == max_imtqlb {
                    // Could not fully converge; continue with best estimates
                    had_convergence_issues = true;
                    for idx in l..=m {
                        bnd[idx] = bnd[idx].max(T::from_f64(0.1).unwrap());
                    }
                    e[l] = T::zero();
                    break;
                }

                iteration += 1;
                // Form an implicit shift
                let two = T::from_f64(2.0).unwrap();
                let mut g = (d[l + 1] - p) / (two * e[l]);
                let mut r = svd_pythag(g, T::one());
                g = d[m] - p + e[l] / (g + svd_fsign(r, g));
                let mut s = T::one();
                let mut c = T::one();
                p = T::zero();

                assert!(m > 0, "imtqlb: expected 'm' to be non-zero");
                i = m - 1;
                let mut underflow = false;
                while !underflow && i >= l {
                    f = s * e[i];
                    let b = c * e[i];
                    r = svd_pythag(f, g);
                    e[i + 1] = r;

                    if r < T::epsilon() * T::from_f64(1000.0).unwrap() * (f.abs() + g.abs()) {
                        underflow = true;
                        break;
                    }

                    if r.abs() < T::epsilon() * T::from_f64(100.0).unwrap() {
                        r = T::epsilon() * T::from_f64(100.0).unwrap() * svd_fsign(T::one(), r);
                    }

                    s = f / r;
                    c = g / r;
                    g = d[i + 1] - p;
                    r = (d[i] - g) * s + T::from_f64(2.0).unwrap() * c * b;
                    p = s * r;
                    d[i + 1] = g + p;
                    g = c * r - b;
                    f = bnd[i + 1];
                    bnd[i + 1] = s * bnd[i] + c * f;
                    bnd[i] = c * bnd[i] - s * f;
                    if i == 0 {
                        break;
                    }
                    i -= 1;
                }
                // Recover from underflow
                if underflow {
                    d[i + 1] -= p;
                } else {
                    d[l] -= p;
                    e[l] = g;
                }
                e[m] = T::zero();
            }
        }
    }
    if had_convergence_issues {
        eprintln!("Warning: imtqlb had some convergence issues but continued with best estimates. Results may have reduced accuracy.");
    }
    Ok(())
}

/// Generates a starting vector for the Lanczos recursion, orthogonalized
/// against the vectors already stored, and returns its norm.
#[allow(non_snake_case)]
fn startv<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    step: usize,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<T, SvdLibError> {
    // Get initial vector; default is a seeded random vector
    let mut rnm2 = svd_ddot(&wrk.w0, &wrk.w0);
    for id in 0..3 {
        if id > 0 || step > 0 || compare(rnm2, T::zero()) {
            let mut bytes = [0; 32];
            for (i, b) in random_seed.to_le_bytes().iter().enumerate() {
                bytes[i] = *b;
            }
            let mut seeded_rng = StdRng::from_seed(bytes);
            for val in wrk.w0.iter_mut() {
                *val = T::from_f64(seeded_rng.gen_range(-1.0..1.0)).unwrap();
            }
        }
        wrk.w3.copy_from_slice(&wrk.w0);

        // Apply the operator to put w0 in the range of A'A
        svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
        wrk.w3.copy_from_slice(&wrk.w0);
        rnm2 = svd_ddot(&wrk.w3, &wrk.w3);
        if rnm2 > T::zero() {
            break;
        }
    }

    if rnm2 <= T::zero() {
        return Err(SvdLibError::StartvError(format!(
            "rnm2 <= 0.0, rnm2 = {rnm2:?}"
        )));
    }

    if step > 0 {
        // Orthogonalize against the previous Lanczos vectors
        for i in 0..step {
            let v = store.retrq(i);
            svd_daxpy(-svd_ddot(&wrk.w3, v), v, &mut wrk.w0);
        }

        svd_daxpy(-svd_ddot(&wrk.w4, &wrk.w0), &wrk.w2, &mut wrk.w0);
        wrk.w3.copy_from_slice(&wrk.w0);

        rnm2 = match svd_ddot(&wrk.w3, &wrk.w3) {
            dot if dot <= T::eps() * rnm2 => T::zero(),
            dot => dot,
        };
    }
    Ok(rnm2.sqrt())
}

/// Performs the first step of the Lanczos recursion: normalizes the starting
/// vector and computes `alf[0]`, returning the residual norm and tolerance.
#[allow(non_snake_case)]
fn stpone<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<(T, T), SvdLibError> {
    let mut rnm = startv(A, wrk, 0, store, random_seed)?;
    if compare(rnm, T::zero()) {
        return Err(SvdLibError::StponeError("rnm == 0.0".to_string()));
    }

    // Normalize the starting vector
    svd_datx(rnm.recip(), &wrk.w0, &mut wrk.w1);
    svd_dscal(rnm.recip(), &mut wrk.w3);

    // Take the first step, re-orthogonalizing alf[0]
    svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
    wrk.alf[0] = svd_ddot(&wrk.w0, &wrk.w3);
    svd_daxpy(-wrk.alf[0], &wrk.w1, &mut wrk.w0);
    let t = svd_ddot(&wrk.w0, &wrk.w3);
    wrk.alf[0] += t;
    svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
    wrk.w4.copy_from_slice(&wrk.w0);
    rnm = svd_norm(&wrk.w4);
    let anorm = rnm + wrk.alf[0].abs();
    Ok((rnm, T::eps().sqrt() * anorm))
}

#[allow(non_snake_case)]
#[allow(clippy::too_many_arguments)]
fn lanczos_step<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    first: usize,
    last: usize,
    ll: &mut usize,
    enough: &mut bool,
    rnm: &mut T,
    tol: &mut T,
    store: &mut Store<T>,
) -> Result<usize, SvdLibError> {
    let eps1 = T::eps() * T::from_f64(wrk.ncols as f64).unwrap().sqrt();
    let mut j = first;
    let four = T::from_f64(4.0).unwrap();

    while j < last {
        mem::swap(&mut wrk.w1, &mut wrk.w2);
        mem::swap(&mut wrk.w3, &mut wrk.w4);

        store.storq(j - 1, &wrk.w2);
        if j - 1 < MAXLL {
            store.storp(j - 1, &wrk.w4);
        }
        wrk.bet[j] = *rnm;

        // Restart if the invariant subspace was exhausted
        if compare(*rnm, T::zero()) {
            *rnm = startv(A, wrk, j, store, 0)?;
            if compare(*rnm, T::zero()) {
                *enough = true;
            }
        }

        if *enough {
            // Restore the old j-1 Lanczos vector
            mem::swap(&mut wrk.w1, &mut wrk.w2);
            break;
        }

        // Take a Lanczos step
        svd_datx(rnm.recip(), &wrk.w0, &mut wrk.w1);
        svd_dscal(rnm.recip(), &mut wrk.w3);
        svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
        svd_daxpy(-*rnm, &wrk.w2, &mut wrk.w0);
        wrk.alf[j] = svd_ddot(&wrk.w0, &wrk.w3);
        svd_daxpy(-wrk.alf[j], &wrk.w1, &mut wrk.w0);

        // Orthogonalize against the initial Lanczos vectors
        if j <= MAXLL && wrk.alf[j - 1].abs() > four * wrk.alf[j].abs() {
            *ll = j;
        }
        for i in 0..(j - 1).min(*ll) {
            let v1 = store.retrp(i);
            let t = svd_ddot(v1, &wrk.w0);
            let v2 = store.retrq(i);
            svd_daxpy(-t, v2, &mut wrk.w0);
            wrk.eta[i] = eps1;
            wrk.oldeta[i] = eps1;
        }

        // Extended local re-orthogonalization
        let t = svd_ddot(&wrk.w0, &wrk.w4);
        svd_daxpy(-t, &wrk.w2, &mut wrk.w0);
        if wrk.bet[j] > T::zero() {
            wrk.bet[j] += t;
        }
        let t = svd_ddot(&wrk.w0, &wrk.w3);
        svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
        wrk.alf[j] += t;
        wrk.w4.copy_from_slice(&wrk.w0);
        *rnm = svd_norm(&wrk.w4);
        let anorm = wrk.bet[j] + wrk.alf[j].abs() + *rnm;
        *tol = T::eps().sqrt() * anorm;

        // Update the orthogonality bounds
        ortbnd(wrk, j, *rnm, eps1);

        // Restore orthogonality where it has been compromised
        purge(wrk.ncols, *ll, wrk, j, rnm, *tol, store);
        if *rnm <= *tol {
            *rnm = T::zero();
        }
        j += 1;
    }
    Ok(j)
}

/// Re-orthogonalizes the two most recent Lanczos vectors against the stored
/// basis when the orthogonality estimates have degraded.
fn purge<T: SvdFloat>(
    n: usize,
    ll: usize,
    wrk: &mut WorkSpace<T>,
    step: usize,
    rnm: &mut T,
    tol: T,
    store: &mut Store<T>,
) {
    if step < ll + 2 {
        return;
    }

    let reps = T::eps().sqrt();
    let eps1 = T::eps() * T::from_f64(n as f64).unwrap().sqrt();

    let k = svd_idamax(step - (ll + 1), &wrk.eta) + ll;
    if wrk.eta[k].abs() > reps {
        let reps1 = eps1 / reps;
        let mut iteration = 0;
        let mut flag = true;
        while iteration < 2 && flag {
            if *rnm > tol {
                // Bring in a stored Lanczos vector and orthogonalize both
                // residuals against it
                let mut tq = T::zero();
                let mut tr = T::zero();
                for i in ll..step {
                    let v = store.retrq(i);
                    let t = svd_ddot(v, &wrk.w3);
                    tq += t.abs();
                    svd_daxpy(-t, v, &mut wrk.w1);
                    let t = svd_ddot(v, &wrk.w4);
                    tr += t.abs();
                    svd_daxpy(-t, v, &mut wrk.w0);
                }
                wrk.w3.copy_from_slice(&wrk.w1);
                let t = svd_ddot(&wrk.w0, &wrk.w3);
                tr += t.abs();
                svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
                wrk.w4.copy_from_slice(&wrk.w0);
                *rnm = svd_norm(&wrk.w4);
                if tq <= reps1 && tr <= *rnm * reps1 {
                    flag = false;
                }
            }
            iteration += 1;
        }
        for i in ll..=step {
            wrk.eta[i] = eps1;
            wrk.oldeta[i] = eps1;
        }
    }
}

/// Updates the eta recurrence that estimates the loss of orthogonality among
/// the Lanczos vectors.
fn ortbnd<T: SvdFloat>(wrk: &mut WorkSpace<T>, step: usize, rnm: T, eps1: T) {
    if step < 1 {
        return;
    }
    if !compare(rnm, T::zero()) && step > 1 {
        wrk.oldeta[0] = (wrk.bet[1] * wrk.eta[1] + (wrk.alf[0] - wrk.alf[step]) * wrk.eta[0]
            - wrk.bet[step] * wrk.oldeta[0])
            / rnm
            + eps1;
        if step > 2 {
            for i in 1..=step - 2 {
                wrk.oldeta[i] = (wrk.bet[i + 1] * wrk.eta[i + 1]
                    + (wrk.alf[i] - wrk.alf[step]) * wrk.eta[i]
                    + wrk.bet[i] * wrk.eta[i - 1]
                    - wrk.bet[step] * wrk.oldeta[i])
                    / rnm
                    + eps1;
            }
        }
    }
    wrk.oldeta[step - 1] = eps1;
    mem::swap(&mut wrk.oldeta, &mut wrk.eta);
    wrk.eta[step] = eps1;
}

/// Massages the error bounds of very close Ritz values and returns the number
/// of Ritz values accepted as eigenvalues.
fn error_bound<T: SvdFloat>(
    enough: &mut bool,
    endl: T,
    endr: T,
    ritz: &mut [T],
    bnd: &mut [T],
    step: usize,
    tol: T,
) -> usize {
    assert!(step > 0, "error_bound: expected 'step' to be non-zero");

    // Index of the Ritz value with the largest error bound
    let mid = svd_idamax(step + 1, bnd);
    let sixteen = T::from_f64(16.0).unwrap();

    // Fold the bounds of close Ritz values above `mid` ...
    let mut i = ((step + 1) + (step - 1)) / 2;
    while i > mid + 1 {
        if (ritz[i - 1] - ritz[i]).abs() < T::eps34() * ritz[i].abs()
            && bnd[i] > tol
            && bnd[i - 1] > tol
        {
            bnd[i - 1] = (bnd[i].powi(2) + bnd[i - 1].powi(2)).sqrt();
            bnd[i] = T::zero();
        }
        i -= 1;
    }

    // ... and below it
    let mut i = ((step + 1) - (step - 1)) / 2;
    while i + 1 < mid {
        if (ritz[i + 1] - ritz[i]).abs() < T::eps34() * ritz[i].abs()
            && bnd[i] > tol
            && bnd[i + 1] > tol
        {
            bnd[i + 1] = (bnd[i].powi(2) + bnd[i + 1].powi(2)).sqrt();
            bnd[i] = T::zero();
        }
        i += 1;
    }

    // Refine the bounds using the gaps between Ritz values and count the
    // values accepted as eigenvalues
    let mut neig = 0;
    let mut gapl = ritz[step] - ritz[0];
    for i in 0..=step {
        let mut gap = gapl;
        if i < step {
            gapl = ritz[i + 1] - ritz[i];
        }
        gap = gap.min(gapl);
        if gap > bnd[i] {
            bnd[i] *= bnd[i] / gap;
        }
        if bnd[i] <= sixteen * T::eps() * ritz[i].abs() {
            neig += 1;
            if !*enough {
                *enough = endl < ritz[i] && ritz[i] < endr;
            }
        }
    }
    neig
}

/// Finds the eigenvalues and eigenvectors of a symmetric tridiagonal matrix
/// (diagonal `d`, off-diagonal `e`) by the implicit QL method, accumulating
/// the rotations into `z`; eigenvalues are sorted ascending on return.
fn imtql2<T: SvdFloat>(
    nm: usize,
    n: usize,
    d: &mut [T],
    e: &mut [T],
    z: &mut [T],
    max_imtqlb: Option<usize>,
) -> Result<(), SvdLibError> {
    let max_imtqlb = max_imtqlb.unwrap_or(MAX_IMTQLB_ITERATIONS);
    if n == 1 {
        return Ok(());
    }
    assert!(n > 1, "imtql2: expected 'n' to be > 1");
    let two = T::from_f64(2.0).unwrap();

    let last = n - 1;

    for i in 1..n {
        e[i - 1] = e[i];
    }
    e[last] = T::zero();

    let nnm = n * nm;
    for l in 0..n {
        let mut iteration = 0;

        // Iterate until the eigenvalue at index l has converged
        while iteration <= max_imtqlb {
            // Look for a small off-diagonal element to split the matrix
            let mut m = l;
            while m < n {
                if m == last {
                    break;
                }
                let test = d[m].abs() + d[m + 1].abs();
                if compare(test, test + e[m].abs()) {
                    break;
                }
                m += 1;
            }
            if m == l {
                break;
            }

            if iteration == max_imtqlb {
                return Err(SvdLibError::Imtql2Error(format!(
                    "imtql2 no convergence to an eigenvalue after {} iterations",
                    max_imtqlb
                )));
            }
            iteration += 1;

            // Form an implicit shift
            let mut g = (d[l + 1] - d[l]) / (two * e[l]);
            let mut r = svd_pythag(g, T::one());
            g = d[m] - d[l] + e[l] / (g + svd_fsign(r, g));

            let mut s = T::one();
            let mut c = T::one();
            let mut p = T::zero();

            assert!(m > 0, "imtql2: expected 'm' to be non-zero");
            let mut i = m - 1;
            let mut underflow = false;
            while !underflow && i >= l {
                let mut f = s * e[i];
                let b = c * e[i];
                r = svd_pythag(f, g);
                e[i + 1] = r;
                if compare(r, T::zero()) {
                    underflow = true;
                } else {
                    s = f / r;
                    c = g / r;
                    g = d[i + 1] - p;
                    r = (d[i] - g) * s + two * c * b;
                    p = s * r;
                    d[i + 1] = g + p;
                    g = c * r - b;

                    // Accumulate the rotation into the eigenvector matrix
                    for k in (0..nnm).step_by(n) {
                        let index = k + i;
                        f = z[index + 1];
                        z[index + 1] = s * z[index] + c * f;
                        z[index] = c * z[index] - s * f;
                    }
                    if i == 0 {
                        break;
                    }
                    i -= 1;
                }
            }
            // Recover from underflow
            if underflow {
                d[i + 1] -= p;
            } else {
                d[l] -= p;
                e[l] = g;
            }
            e[m] = T::zero();
        }
    }

    // Sort the eigenvalues (and their eigenvectors) into ascending order
    for l in 1..n {
        let i = l - 1;
        let mut k = i;
        let mut p = d[i];
        for (j, item) in d.iter().enumerate().take(n).skip(l) {
            if *item < p {
                k = j;
                p = *item;
            }
        }

        if k != i {
            d[k] = d[i];
            d[i] = p;
            for j in (0..nnm).step_by(n) {
                z.swap(j + i, j + k);
            }
        }
    }

    Ok(())
}

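/// Rotates `a` left by `x` positions in place by cycle-following (O(n) time,
/// O(1) extra space); equivalent to `a.rotate_left(x)`. For example,
/// `[0, 1, 2, 3, 4]` rotated by `x = 2` becomes `[2, 3, 4, 0, 1]`.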
fn rotate_array<T: Float + Copy>(a: &mut [T], x: usize) {
    let n = a.len();
    let mut j = 0;
    let mut start = 0;
    let mut t1 = a[0];

    for _ in 0..n {
        j = match j >= x {
            true => j - x,
            false => j + n - x,
        };

        let t2 = a[j];
        a[j] = t1;

        if j == start {
            j += 1;
            start = j;
            t1 = a[j];
        } else {
            t1 = t2;
        }
    }
}

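/// Computes the singular triplets from the converged Ritz values: `imtql2`
/// diagonalizes the tridiagonal Lanczos matrix, the resulting eigenvectors
/// are combined with the stored Lanczos vectors to form the right singular
/// vectors, and the singular values and left singular vectors are then
/// recovered by applying the operator once more.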
#[allow(non_snake_case)]
fn ritvec<T: SvdFloat>(
    A: &dyn SMat<T>,
    dimensions: usize,
    kappa: T,
    wrk: &mut WorkSpace<T>,
    steps: usize,
    neig: usize,
    store: &mut Store<T>,
) -> Result<SVDRawRec<T>, SvdLibError> {
    let js = steps + 1;
    let jsq = js * js;

    // Loosen tolerances and iteration caps for very sparse matrices
    let sparsity = T::one()
        - (T::from_usize(A.nnz()).unwrap()
            / (T::from_usize(A.nrows()).unwrap() * T::from_usize(A.ncols()).unwrap()));

    let epsilon = T::epsilon();
    let adaptive_eps = if sparsity > T::from_f64(0.99).unwrap() {
        epsilon * T::from_f64(100.0).unwrap()
    } else if sparsity > T::from_f64(0.9).unwrap() {
        epsilon * T::from_f64(10.0).unwrap()
    } else {
        epsilon
    };

    let max_iterations_imtql2 = if sparsity > T::from_f64(0.999).unwrap() {
        Some(500)
    } else if sparsity > T::from_f64(0.99).unwrap() {
        Some(200)
    } else if sparsity > T::from_f64(0.9).unwrap() {
        Some(100)
    } else {
        Some(50)
    };

    // s starts as the js x js identity matrix
    let mut s = vec![T::zero(); jsq];
    for i in (0..jsq).step_by(js + 1) {
        s[i] = T::one();
    }

    let mut Vt = DMat {
        cols: wrk.ncols,
        value: vec![T::zero(); wrk.ncols * dimensions],
    };

    svd_dcopy(js, 0, &wrk.alf, &mut Vt.value);
    svd_dcopy(steps, 1, &wrk.bet, &mut wrk.w5);

    // Diagonalize the tridiagonal Lanczos matrix; s receives the eigenvectors
    imtql2(
        js,
        js,
        &mut Vt.value,
        &mut wrk.w5,
        &mut s,
        max_iterations_imtql2,
    )?;

    let max_eigenvalue = Vt
        .value
        .iter()
        .fold(T::zero(), |max, &val| max.max(val.abs()));

    let adaptive_kappa = if sparsity > T::from_f64(0.99).unwrap() {
        kappa * T::from_f64(10.0).unwrap()
    } else {
        kappa
    };

    // Rotate the stored Lanczos vectors by the significant eigenvectors to
    // form the right singular vectors
    let mut nsig = 0;
    let mut x = 0;
    let mut id2 = jsq - js;
    for k in 0..js {
        let relative_bound = adaptive_kappa * wrk.ritz[k].abs().max(max_eigenvalue * adaptive_eps);
        if wrk.bnd[k] <= relative_bound && k + 1 > js - neig {
            x = match x {
                0 => dimensions - 1,
                _ => x - 1,
            };

            let offset = x * Vt.cols;
            Vt.value[offset..offset + Vt.cols].fill(T::zero());
            let mut idx = id2 + js;

            for i in 0..js {
                idx -= js;
                if s[idx].abs() > adaptive_eps {
                    for (j, item) in store.retrq(i).iter().enumerate().take(Vt.cols) {
                        Vt.value[j + offset] += s[idx] * *item;
                    }
                }
            }
            nsig += 1;
        }
        id2 += 1;
    }

    if x > 0 {
        // Bring the rows of Vt into the right order
        rotate_array(&mut Vt.value, x * Vt.cols);
    }

    // Recover the singular values and left singular vectors
    let d = dimensions.min(nsig);
    let mut S = vec![T::zero(); d];
    let mut Ut = DMat {
        cols: wrk.nrows,
        value: vec![T::zero(); wrk.nrows * d],
    };
    Vt.value.resize(Vt.cols * d, T::zero());

    let mut tmp_vec = vec![T::zero(); Vt.cols];
    for (i, sval) in S.iter_mut().enumerate() {
        let vt_offset = i * Vt.cols;
        let ut_offset = i * Ut.cols;

        let vt_vec = &Vt.value[vt_offset..vt_offset + Vt.cols];
        let ut_vec = &mut Ut.value[ut_offset..ut_offset + Ut.cols];

        // Singular value: sqrt(v' * A'A * v)
        svd_opb(A, vt_vec, &mut tmp_vec, &mut wrk.temp, wrk.transposed);
        let t = svd_ddot(vt_vec, &tmp_vec);

        *sval = t.max(T::zero()).sqrt();

        if t > adaptive_eps {
            svd_daxpy(-t, vt_vec, &mut tmp_vec);
            if *sval > adaptive_eps {
                wrk.bnd[js] = svd_norm(&tmp_vec) / *sval;
            } else {
                wrk.bnd[js] = T::max_value() * T::from_f64(0.1).unwrap();
            }

            // u = A * v / sigma
            A.svd_opa(vt_vec, ut_vec, wrk.transposed);

            if *sval > adaptive_eps {
                svd_dscal(T::one() / *sval, ut_vec);
            } else {
                let dls = sval.max(adaptive_eps);
                let safe_scale = T::one() / dls;
                svd_dscal(safe_scale, ut_vec);
            }
        } else {
            // Degenerate direction: normalize A * v directly
            A.svd_opa(vt_vec, ut_vec, wrk.transposed);
            let norm = svd_norm(ut_vec);
            if norm > adaptive_eps {
                svd_dscal(T::one() / norm, ut_vec);
            }
            wrk.bnd[js] = T::max_value() * T::from_f64(0.01).unwrap();
        }
    }

    Ok(SVDRawRec { d, nsig, Ut, S, Vt })
}

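/// Drives the Lanczos iteration with selective orthogonalization until enough
/// Ritz values have stabilized inside the wanted interval or the iteration
/// budget is exhausted, returning the number of steps taken.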
#[allow(non_snake_case)]
#[allow(clippy::too_many_arguments)]
fn lanso<T: SvdFloat>(
    A: &dyn SMat<T>,
    dim: usize,
    iterations: usize,
    end_interval: &[T; 2],
    wrk: &mut WorkSpace<T>,
    neig: &mut usize,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<usize, SvdLibError> {
    // Loosen tolerances and iteration caps for very sparse matrices
    let sparsity = T::one()
        - (T::from_usize(A.nnz()).unwrap()
            / (T::from_usize(A.nrows()).unwrap() * T::from_usize(A.ncols()).unwrap()));
    let max_iterations_imtqlb = if sparsity > T::from_f64(0.999).unwrap() {
        Some(500)
    } else if sparsity > T::from_f64(0.99).unwrap() {
        Some(200)
    } else if sparsity > T::from_f64(0.9).unwrap() {
        Some(100)
    } else {
        Some(50)
    };

    let epsilon = T::epsilon();
    let adaptive_eps = if sparsity > T::from_f64(0.99).unwrap() {
        epsilon * T::from_f64(100.0).unwrap()
    } else if sparsity > T::from_f64(0.9).unwrap() {
        epsilon * T::from_f64(10.0).unwrap()
    } else {
        epsilon
    };

    let (endl, endr) = (end_interval[0], end_interval[1]);

    // Take the first step
    let (mut rnm, mut tol) = stpone(A, wrk, store, random_seed)?;

    let eps1 = adaptive_eps * T::from_f64(wrk.ncols as f64).unwrap().sqrt();
    wrk.eta[0] = eps1;
    wrk.oldeta[0] = eps1;
    let mut ll = 0;
    let mut first = 1;
    let mut last = iterations.min(dim.max(8) + dim);
    let mut enough = false;
    let mut j = 0;
    let mut intro = 0;

    while !enough {
        if rnm <= tol {
            rnm = T::zero();
        }

        // Take several Lanczos steps
        let steps = lanczos_step(
            A,
            wrk,
            first,
            last,
            &mut ll,
            &mut enough,
            &mut rnm,
            &mut tol,
            store,
        )?;
        j = match enough {
            true => steps - 1,
            false => last - 1,
        };

        first = j + 1;
        wrk.bet[first] = rnm;

        // Compute the Ritz values of each unreduced tridiagonal block
        let mut l = 0;
        for _ in 0..j {
            if l > j {
                break;
            }

            let mut i = l;
            while i <= j {
                if wrk.bet[i + 1].abs() <= adaptive_eps {
                    break;
                }
                i += 1;
            }
            i = i.min(j);

            // Copy the block into ritz/w5 and diagonalize it
            let sz = i - l;
            svd_dcopy(sz + 1, l, &wrk.alf, &mut wrk.ritz);
            svd_dcopy(sz, l + 1, &wrk.bet, &mut wrk.w5);

            imtqlb(
                sz + 1,
                &mut wrk.ritz[l..],
                &mut wrk.w5[l..],
                &mut wrk.bnd[l..],
                max_iterations_imtqlb,
            )?;

            for m in l..=i {
                wrk.bnd[m] = rnm * wrk.bnd[m].abs();
            }
            l = i + 1;
        }

        // Sort the Ritz values (and their bounds) into increasing order
        insert_sort(j + 1, &mut wrk.ritz, &mut wrk.bnd);

        *neig = error_bound(&mut enough, endl, endr, &mut wrk.ritz, &mut wrk.bnd, j, tol);

        // Decide how many more steps to take
        if *neig < dim {
            if *neig == 0 {
                last = first + 9;
                intro = first;
            } else {
                // Very sparse matrices may need a few extra steps to converge
                let extra_steps = if sparsity > T::from_f64(0.99).unwrap() {
                    5
                } else {
                    0
                };

                last = first + 3usize.max(1 + ((j - intro) * (dim - *neig)) / *neig) + extra_steps;
            }
            last = last.min(iterations);
        } else {
            enough = true
        }
        enough = enough || first >= iterations;
    }
    store.storq(j, &wrk.w1);
    Ok(j)
}

impl<T: SvdFloat + 'static> SvdRec<T> {
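    /// Reconstructs an approximation of the original matrix as
    /// `ut' * diag(s) * vt`.
    ///
    /// A usage sketch (matrix shapes are illustrative):
    ///
    /// ```ignore
    /// let rec = svd(&a)?;
    /// let approx = rec.recompose(); // nrows x ncols
    /// ```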
    pub fn recompose(&self) -> Array2<T> {
        let sdiag = Array2::from_diag(&self.s);
        self.ut.t().dot(&sdiag).dot(&self.vt)
    }
}

#[rustfmt::skip]
impl<T: Float + Zero + AddAssign + Clone> SMat<T> for nalgebra_sparse::csc::CscMatrix<T> {
    fn nrows(&self) -> usize { self.nrows() }
    fn ncols(&self) -> usize { self.ncols() }
    fn nnz(&self) -> usize { self.nnz() }

    /// Computes `y = A * x`, or `y = A' * x` when `transposed` is set
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed { self.ncols() } else { self.nrows() };
        let ncols = if transposed { self.nrows() } else { self.ncols() };
        assert_eq!(x.len(), ncols, "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
        assert_eq!(y.len(), nrows, "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}", y.len(), nrows);

        let (major_offsets, minor_indices, values) = self.csc_data();

        for y_val in y.iter_mut() {
            *y_val = T::zero();
        }

        if transposed {
            // A' * x: each column of A dotted with x
            for (i, yval) in y.iter_mut().enumerate() {
                for j in major_offsets[i]..major_offsets[i + 1] {
                    *yval += values[j] * x[minor_indices[j]];
                }
            }
        } else {
            // A * x: scatter each column of A scaled by x[i]
            for (i, xval) in x.iter().enumerate() {
                for j in major_offsets[i]..major_offsets[i + 1] {
                    y[minor_indices[j]] += values[j] * *xval;
                }
            }
        }
    }
}

#[rustfmt::skip]
impl<T: Float + Zero + AddAssign + Clone> SMat<T> for nalgebra_sparse::csr::CsrMatrix<T> {
    fn nrows(&self) -> usize { self.nrows() }
    fn ncols(&self) -> usize { self.ncols() }
    fn nnz(&self) -> usize { self.nnz() }

    /// Computes `y = A * x`, or `y = A' * x` when `transposed` is set
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed { self.ncols() } else { self.nrows() };
        let ncols = if transposed { self.nrows() } else { self.ncols() };
        assert_eq!(x.len(), ncols, "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
        assert_eq!(y.len(), nrows, "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}", y.len(), nrows);

        let (major_offsets, minor_indices, values) = self.csr_data();

        for y_val in y.iter_mut() {
            *y_val = T::zero();
        }

        if !transposed {
            // A * x: each row of A dotted with x
            for (i, yval) in y.iter_mut().enumerate() {
                for j in major_offsets[i]..major_offsets[i + 1] {
                    *yval += values[j] * x[minor_indices[j]];
                }
            }
        } else {
            // A' * x: scatter each row of A scaled by x[i]
            for (i, xval) in x.iter().enumerate() {
                for j in major_offsets[i]..major_offsets[i + 1] {
                    y[minor_indices[j]] += values[j] * *xval;
                }
            }
        }
    }
}

#[rustfmt::skip]
impl<T: Float + Zero + AddAssign + Clone> SMat<T> for nalgebra_sparse::coo::CooMatrix<T> {
    fn nrows(&self) -> usize { self.nrows() }
    fn ncols(&self) -> usize { self.ncols() }
    fn nnz(&self) -> usize { self.nnz() }

    /// Computes `y = A * x`, or `y = A' * x` when `transposed` is set
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed { self.ncols() } else { self.nrows() };
        let ncols = if transposed { self.nrows() } else { self.ncols() };
        assert_eq!(x.len(), ncols, "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
        assert_eq!(y.len(), nrows, "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}", y.len(), nrows);

        for y_val in y.iter_mut() {
            *y_val = T::zero();
        }

        if transposed {
            for (i, j, v) in self.triplet_iter() {
                y[j] += *v * x[i];
            }
        } else {
            for (i, j, v) in self.triplet_iter() {
                y[i] += *v * x[j];
            }
        }
    }
}
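
// A minimal smoke-test sketch. The diagonal example and its expected singular
// values are illustrative, and the tolerances are assumptions rather than
// guarantees of the solver's accuracy.
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra_sparse::coo::CooMatrix;

    #[test]
    fn svd_of_diagonal_matrix() {
        // diag(1, 2, 3) has singular values 3, 2, 1
        let mut a = CooMatrix::<f64>::new(3, 3);
        a.push(0, 0, 1.0);
        a.push(1, 1, 2.0);
        a.push(2, 2, 3.0);

        let rec = svd(&a).expect("svd should succeed on a small diagonal matrix");
        assert!(rec.d >= 1);

        // Singular values come back in non-increasing order
        let s = rec.s.to_vec();
        for w in s.windows(2) {
            assert!(w[0] >= w[1]);
        }

        // The largest singular value of diag(1, 2, 3) is 3
        assert!((s[0] - 3.0).abs() < 1e-6);

        // recompose() should reproduce the input's shape
        let approx = rec.recompose();
        assert_eq!(approx.shape(), &[3, 3]);
    }
}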