1use std::fmt::Debug;
2use std::iter::Sum;
3use std::mem;
4use std::ops::{AddAssign, MulAssign, Neg, SubAssign};
5use ndarray::{Array, Array1, Array2};
6use num_traits::{Float, FromPrimitive, One, Zero};
7use num_traits::real::Real;
8use rand::{thread_rng, Rng, SeedableRng};
9use rand::rngs::StdRng;
10use crate::error::SvdLibError;
11
/// Matrix-free interface required by the SVD routines.
///
/// Implementors expose their dimensions plus a single matrix-vector
/// product; the Lanczos code never touches the storage format directly.
pub trait SMat<T: Float> {
    /// Number of rows in the matrix.
    fn nrows(&self) -> usize;
    /// Number of columns in the matrix.
    fn ncols(&self) -> usize;
    /// Number of explicitly stored (non-zero) entries.
    fn nnz(&self) -> usize;
    /// Computes `y = A * x`, or `y = A' * x` when `transposed` is true.
    /// `x` must have `ncols()` entries and `y` `nrows()` entries
    /// (swapped when `transposed`).
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool);
}
18
/// Result of a singular value decomposition: `A ≈ ut' * diag(s) * vt`.
#[derive(Debug, Clone, PartialEq)]
pub struct SvdRec<T: Float> {
    /// Number of singular triplets actually computed.
    pub d: usize,
    /// Left singular vectors, stored row-wise (`d` rows of length nrows).
    pub ut: Array2<T>,
    /// Singular values (length `d`).
    pub s: Array1<T>,
    /// Right singular vectors, stored row-wise (`d` rows of length ncols).
    pub vt: Array2<T>,
    /// Details about the run that produced this decomposition.
    pub diagnostics: Diagnostics<T>
}
35
/// Run statistics reported alongside an [`SvdRec`].
#[derive(Debug, Clone, PartialEq)]
pub struct Diagnostics<T: Float> {
    /// Non-zero entries in the input matrix.
    pub non_zero: usize,
    /// Number of dimensions attempted (after clamping to min(nrows, ncols)).
    pub dimensions: usize,
    /// Iteration limit used for the Lanczos loop (after clamping).
    pub iterations: usize,
    /// Whether the problem was solved on the transpose of the input.
    pub transposed: bool,
    /// Number of Lanczos steps actually taken.
    pub lanczos_steps: usize,
    /// Count of Ritz values that stabilized during the Lanczos phase.
    pub ritz_values_stabilized: usize,
    /// Number of singular triplets returned (same as `SvdRec::d`).
    pub significant_values: usize,
    /// Number of Ritz vectors accepted by the `kappa` test in the
    /// Ritz-vector phase.
    pub singular_values: usize,
    /// Eigenvalue interval `[endl, endr]` treated as numerically zero.
    pub end_interval: [T; 2],
    /// Relative accuracy threshold actually used
    /// (`max(|kappa|, eps^(3/4))`).
    pub kappa: T,
    /// Seed used for the starting vector; callers passing 0 get a
    /// randomly drawn non-deterministic seed recorded here.
    pub random_seed: u32
}
64
/// Floating-point requirements for the SVD routines, plus a few
/// epsilon-based helpers that the Lanczos code needs.
pub trait SvdFloat: Float + FromPrimitive + Debug + Send + Sync + Zero + One + AddAssign + SubAssign + MulAssign + Neg<Output = Self> + Sum {
    /// Machine epsilon for this type.
    fn eps() -> Self;
    /// `eps()` raised to the 3/4 power; default convergence threshold.
    fn eps34() -> Self;
    /// Approximate equality: true when `|b - a|` is below machine epsilon.
    fn compare(a: Self, b: Self) -> bool;
}
71
72impl SvdFloat for f32 {
73 fn eps() -> Self {
74 f32::EPSILON
75 }
76
77 fn eps34() -> Self {
78 f32::EPSILON.powf(0.75)
79 }
80
81 fn compare(a: Self, b: Self) -> bool {
82 (b - a).abs() < f32::EPSILON
83 }
84}
85
86impl SvdFloat for f64 {
87 fn eps() -> Self {
88 f64::EPSILON
89 }
90
91 fn eps34() -> Self {
92 f64::EPSILON.powf(0.75)
93 }
94
95 fn compare(a: Self, b: Self) -> bool {
96 (b - a).abs() < f64::EPSILON
97 }
98}
99
100pub fn svd<T, M>(a: &M) -> Result<SvdRec<T>, SvdLibError>
107where
108 T: SvdFloat,
109 M: SMat<T>
110{
111 let eps_small = T::from_f64(-1.0e-30).unwrap();
112 let eps_large = T::from_f64(1.0e-30).unwrap();
113 let kappa = T::from_f64(1.0e-6).unwrap();
114 svd_las2(a, 0, 0, &[eps_small, eps_large], kappa, 0)
115}
116
117pub fn svd_dim<T, M>(a: &M, dimensions: usize) -> Result<SvdRec<T>, SvdLibError>
125where
126 T: SvdFloat,
127 M: SMat<T> {
128 let eps_small = T::from_f64(-1.0e-30).unwrap();
129 let eps_large = T::from_f64(1.0e-30).unwrap();
130 let kappa = T::from_f64(1.0e-6).unwrap();
131
132 svd_las2(a, dimensions, 0, &[eps_small, eps_large], kappa, 0)
133}
134
135pub fn svd_dim_seed<T, M>(a: &M, dimensions: usize, random_seed: u32) -> Result<SvdRec<T>, SvdLibError>
144where
145 T: SvdFloat,
146 M: SMat<T> {
147 let eps_small = T::from_f64(-1.0e-30).unwrap();
148 let eps_large = T::from_f64(1.0e-30).unwrap();
149 let kappa = T::from_f64(1.0e-6).unwrap();
150
151 svd_las2(a, dimensions, 0, &[eps_small, eps_large], kappa, random_seed)
152}
153
154pub fn svd_las2<T, M>(
171 a: &M,
172 dimensions: usize,
173 iterations: usize,
174 end_interval: &[T; 2],
175 kappa: T,
176 random_seed: u32
177) -> Result<SvdRec<T>, SvdLibError>
178where
179 T: SvdFloat,
180 M: SMat<T> {
181 let random_seed = match random_seed > 0 {
182 true => random_seed,
183 false => thread_rng().gen::<_>()
184 };
185
186 let min_nrows_ncols = a.nrows().min(a.ncols());
187
188 let dimensions = match dimensions {
189 n if n == 0 || n > min_nrows_ncols => min_nrows_ncols,
190 _ => dimensions
191 };
192
193 let iterations = match iterations {
194 n if n == 0 || n > min_nrows_ncols => min_nrows_ncols,
195 n if n < dimensions => dimensions,
196 _ => iterations
197 };
198
199 if dimensions < 2 {
200 return Err(SvdLibError::Las2Error(format!(
201 "svd_las2: insufficient dimensions: {dimensions}"
202 )))
203 }
204
205 assert!(dimensions > 1 && dimensions <= min_nrows_ncols);
206 assert!(iterations >= dimensions && iterations <= min_nrows_ncols);
207
208 let transposed = (a.ncols() as f64) >= ((a.nrows() as f64) * 1.2);
209 let nrows = if transposed { a.ncols() } else { a.nrows() };
210 let ncols = if transposed { a.nrows() } else { a.ncols() };
211
212 let mut wrk = WorkSpace::new(nrows, ncols, transposed, iterations)?;
213 let mut store = Store::new(ncols)?;
214
215 let mut neig = 0;
216 let steps = lanso(
217 a,
218 dimensions,
219 iterations,
220 end_interval,
221 &mut wrk,
222 &mut neig,
223 &mut store,
224 random_seed
225 )?;
226
227 let kappa = kappa.abs().max(T::eps34());
228 let mut r = ritvec(a, dimensions, kappa, &mut wrk, steps, neig, &mut store)?;
229
230 if transposed {
231 mem::swap(&mut r.Ut, &mut r.Vt);
232 }
233
234 Ok(SvdRec {
235 d: r.d,
237 ut: Array2::from_shape_vec((r.d, r.Ut.cols), r.Ut.value)?,
238 s: Array::from_shape_vec(r.d, r.S)?,
239 vt: Array2::from_shape_vec((r.d, r.Vt.cols), r.Vt.value)?,
240 diagnostics: Diagnostics {
241 non_zero: a.nnz(),
242 dimensions: dimensions,
243 iterations: iterations,
244 transposed: transposed,
245 lanczos_steps: steps + 1,
246 ritz_values_stabilized: neig,
247 significant_values: r.d,
248 singular_values: r.nsig,
249 end_interval: *end_interval,
250 kappa: kappa,
251 random_seed: random_seed,
252 },
253 })
254}
255
256const MAXLL: usize = 2;
257
258#[derive(Debug, Clone, PartialEq)]
259struct Store<T: Float> {
260 n: usize,
261 vecs: Vec<Vec<T>>
262}
263
264impl<T: Float + Zero + Clone> Store<T> {
265 fn new(n: usize) -> Result<Self, SvdLibError> {
266 Ok(Self { n, vecs: vec![] })
267 }
268
269 fn storq(&mut self, idx: usize, v: &[T]) {
270 while idx + MAXLL >= self.vecs.len() {
271 self.vecs.push(vec![T::zero(); self.n]);
272 }
273 self.vecs[idx + MAXLL].copy_from_slice(v);
274 }
275
276 fn storp(&mut self, idx: usize, v: &[T]) {
277 while idx >= self.vecs.len() {
278 self.vecs.push(vec![T::zero(); self.n]);
279 }
280 self.vecs[idx].copy_from_slice(v);
281 }
282
283 fn retrq(&mut self, idx: usize) -> &[T] {
284 &self.vecs[idx + MAXLL]
285 }
286
287 fn retrp(&mut self, idx: usize) -> &[T] {
288 &self.vecs[idx]
289 }
290}
291
/// Scratch buffers shared by every stage of the Lanczos iteration.
#[derive(Debug, Clone, PartialEq)]
struct WorkSpace<T: Float> {
    nrows: usize,
    ncols: usize,
    // Whether the operator is applied to the transpose of the input.
    transposed: bool,
    // Working vectors of length `ncols` used by the recurrence.
    w0: Vec<T>, w1: Vec<T>, w2: Vec<T>, w3: Vec<T>, w4: Vec<T>, w5: Vec<T>,
    alf: Vec<T>,    // diagonal of the tridiagonal matrix (length `iterations`)
    eta: Vec<T>,    // orthogonality estimates, current step
    oldeta: Vec<T>, // orthogonality estimates, previous step
    bet: Vec<T>,    // off-diagonal of the tridiagonal matrix (length `iterations + 1`)
    bnd: Vec<T>,    // error bounds on the Ritz values
    ritz: Vec<T>,   // computed Ritz values
    temp: Vec<T>,   // scratch of length `nrows` for the two-step operator product
}
311
312impl<T: Float + Zero + FromPrimitive> WorkSpace<T> {
313 fn new(nrows: usize, ncols: usize, transposed: bool, iterations: usize) -> Result<Self, SvdLibError> {
314 Ok(Self {
315 nrows,
316 ncols,
317 transposed,
318 w0: vec![T::zero(); ncols],
319 w1: vec![T::zero(); ncols],
320 w2: vec![T::zero(); ncols],
321 w3: vec![T::zero(); ncols],
322 w4: vec![T::zero(); ncols],
323 w5: vec![T::zero(); ncols],
324 alf: vec![T::zero(); iterations],
325 eta: vec![T::zero(); iterations],
326 oldeta: vec![T::zero(); iterations],
327 bet: vec![T::zero(); 1 + iterations],
328 ritz: vec![T::zero(); 1 + iterations],
329 bnd: vec![T::from_f64(f64::MAX).unwrap(); 1 + iterations],
330 temp: vec![T::zero(); nrows],
331 })
332 }
333}
334
/// Dense row-major matrix: `value` holds `value.len() / cols` rows of
/// `cols` entries each.
#[derive(Debug, Clone, PartialEq)]
struct DMat<T: Float> {
    cols: usize,
    value: Vec<T>,
}
341
/// Intermediate decomposition produced by `ritvec`, before the dense
/// matrices are converted into `ndarray` types.
#[allow(non_snake_case)]
#[derive(Debug, Clone, PartialEq)]
struct SVDRawRec<T: Float> {
    // Number of triplets kept (min of requested dims and accepted vectors).
    d: usize,
    // Number of Ritz vectors that passed the kappa acceptance test.
    nsig: usize,
    // Left singular vectors, row-wise.
    Ut: DMat<T>,
    // Singular values.
    S: Vec<T>,
    // Right singular vectors, row-wise.
    Vt: DMat<T>,
}
351
// Forwards to the type's epsilon-based approximate-equality test.
fn compare<T: SvdFloat>(computed: T, expected: T) -> bool {
    T::compare(computed, expected)
}
355
/// Insertion-sorts the first `n` elements of `array1` in ascending
/// order, mirroring every swap into `array2` so paired entries stay
/// aligned. Elements past `n` are untouched.
fn insert_sort<T: PartialOrd>(n: usize, array1: &mut [T], array2: &mut [T]) {
    for i in 1..n {
        let mut j = i;
        // Sink element `i` leftwards until it is in order.
        while j > 0 && array1[j - 1] > array1[j] {
            array1.swap(j - 1, j);
            array2.swap(j - 1, j);
            j -= 1;
        }
    }
}
368
// Applies the normal-matrix operator: y = A'A x, with A taken as its
// transpose when `transposed` is set. `temp` (length nrows) holds the
// intermediate product A x.
#[allow(non_snake_case)]
#[rustfmt::skip]
fn svd_opb<T: Float>(A: &dyn SMat<T>, x: &[T], y: &mut [T], temp: &mut [T], transposed: bool) {
    let nrows = if transposed { A.ncols() } else { A.nrows() };
    let ncols = if transposed { A.nrows() } else { A.ncols() };
    assert_eq!(x.len(), ncols, "svd_opb: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
    assert_eq!(y.len(), ncols, "svd_opb: y must be A.ncols() in length, y = {}, A.ncols = {}", y.len(), ncols);
    assert_eq!(temp.len(), nrows, "svd_opa: temp must be A.nrows() in length, temp = {}, A.nrows = {}", temp.len(), nrows);
    A.svd_opa(x, temp, transposed);  // temp = A x
    A.svd_opa(temp, y, !transposed); // y = A' temp
}
380
381fn svd_daxpy<T: Float + AddAssign>(da: T, x: &[T], y: &mut [T]) {
383 for (xval, yval) in x.iter().zip(y.iter_mut()) {
384 *yval += da * *xval
385 }
386}
387
388fn svd_idamax<T: Float>(n: usize, x: &[T]) -> usize {
390 assert!(n > 0, "svd_idamax: unexpected inputs!");
391
392 match n {
393 1 => 0,
394 _ => {
395 let mut imax = 0;
396 for (i, xval) in x.iter().enumerate().take(n).skip(1) {
397 if xval.abs() > x[imax].abs() {
398 imax = i;
399 }
400 }
401 imax
402 }
403 }
404}
405
406fn svd_fsign<T: Float>(a: T, b: T) -> T {
408 match (a >= T::zero() && b >= T::zero()) || (a < T::zero() && b < T::zero()) {
409 true => a,
410 false => -a,
411 }
412}
413
// Computes sqrt(a^2 + b^2) without destructive overflow or underflow,
// using the iterative Moler-Morrison-style scheme: scale by the larger
// magnitude and refine until the correction term vanishes.
fn svd_pythag<T: SvdFloat + FromPrimitive>(a: T, b: T) -> T {
    match a.abs().max(b.abs()) {
        n if n > T::zero() => {
            // p starts at the larger magnitude, r is the squared ratio
            // of the smaller to p (always <= 1, so no overflow).
            let mut p = n;
            let mut r = (a.abs().min(b.abs()) / p).powi(2);
            let four = T::from_f64(4.0).unwrap();
            let two = T::from_f64(2.0).unwrap();
            let mut t = four + r;
            // Stop once 4 + r is indistinguishable from 4 in this precision.
            while !compare(t, four) {
                let s = r / t;
                let u = T::one() + two * s;
                p = p * u;
                r = (s / u).powi(2);
                t = four + r;
            }
            p
        }
        // Both inputs are zero.
        _ => T::zero(),
    }
}
435
436fn svd_ddot<T: Float + Sum<T>>(x: &[T], y: &[T]) -> T {
438 x.iter().zip(y).map(|(a, b)| *a * *b).sum()
439}
440
441fn svd_norm<T: Float + Sum<T>>(x: &[T]) -> T {
443 svd_ddot(x, x).sqrt()
444}
445
446fn svd_datx<T: Float + Sum<T>>(d: T, x: &[T], y: &mut [T]) {
448 for (i, xval) in x.iter().enumerate() {
449 y[i] = d * *xval;
450 }
451}
452
453fn svd_dscal<T: Float + MulAssign>(d: T, x: &mut [T]) {
455 for elem in x.iter_mut() {
456 *elem *= d;
457 }
458}
459
460fn svd_dcopy<T: Float + Copy>(n: usize, offset: usize, x: &[T], y: &mut [T]) {
462 if n > 0 {
463 let start = n - 1;
464 for i in 0..n {
465 y[offset + start - i] = x[offset + i];
466 }
467 }
468}
469
// Implicit-shift QL iteration on a symmetric tridiagonal matrix
// (eigenvalues only) — a port of the classic EISPACK IMTQLB routine.
//
// `d` holds the diagonal, `e` the sub-diagonal (entry 0 unused on
// input). On success `d` contains the eigenvalues (each run sorted as
// it converges) and `bnd` the first component of each normalized
// eigenvector, which the caller uses for error bounds.
//
// Returns an error when any eigenvalue fails to converge within 30
// implicit QL sweeps.
fn imtqlb<T: SvdFloat>(n: usize, d: &mut [T], e: &mut [T], bnd: &mut [T]) -> Result<(), SvdLibError> {
    if n == 1 {
        return Ok(());
    }

    // Initialize bnd to the first row of the identity and shift the
    // sub-diagonal down one slot, EISPACK-style.
    bnd[0] = T::one();
    let last = n - 1;
    for i in 1..=last {
        bnd[i] = T::zero();
        e[i - 1] = e[i];
    }
    e[last] = T::zero();

    let mut i = 0;

    for l in 0..=last {
        let mut iteration = 0;
        while iteration <= 30 {
            // Find the first index m >= l where the off-diagonal is
            // negligible relative to its neighbours.
            let mut m = l;
            while m < n {
                if m == last {
                    break;
                }
                let test = d[m].abs() + d[m + 1].abs();
                if compare(test, test + e[m].abs()) {
                    break;
                }
                m += 1;
            }
            let mut p = d[l];
            let mut f = bnd[l];
            if m == l {
                // d[l] has converged: insertion-sort it (and its bnd
                // entry) into the already-converged prefix.
                let mut exchange = true;
                if l > 0 {
                    i = l;
                    while i >= 1 && exchange {
                        if p < d[i - 1] {
                            d[i] = d[i - 1];
                            bnd[i] = bnd[i - 1];
                            i -= 1;
                        } else {
                            exchange = false;
                        }
                    }
                }
                if exchange {
                    i = 0;
                }
                d[i] = p;
                bnd[i] = f;
                // Sentinel value: leave the iteration loop.
                iteration = 31;
            } else {
                if iteration == 30 {
                    return Err(SvdLibError::ImtqlbError(
                        "imtqlb no convergence to an eigenvalue after 30 iterations".to_string(),
                    ));
                }
                iteration += 1;
                // Form the implicit Wilkinson-style shift.
                let two = T::from_f64(2.0).unwrap();
                let mut g = (d[l + 1] - p) / (two * e[l]);
                let mut r = svd_pythag(g, T::one());
                g = d[m] - p + e[l] / (g + svd_fsign(r, g));
                let mut s = T::one();
                let mut c = T::one();
                p = T::zero();

                assert!(m > 0, "imtqlb: expected 'm' to be non-zero");
                i = m - 1;
                let mut underflow = false;
                // One QL sweep of Givens rotations from m-1 down to l,
                // also rotating the tracked first components in bnd.
                while !underflow && i >= l {
                    f = s * e[i];
                    let b = c * e[i];
                    r = svd_pythag(f, g);
                    e[i + 1] = r;
                    if compare(r, T::zero()) {
                        // Rotation denominator vanished; recover below.
                        underflow = true;
                        break;
                    }
                    s = f / r;
                    c = g / r;
                    g = d[i + 1] - p;
                    r = (d[i] - g) * s + two * c * b;
                    p = s * r;
                    d[i + 1] = g + p;
                    g = c * r - b;
                    f = bnd[i + 1];
                    bnd[i + 1] = s * bnd[i] + c * f;
                    bnd[i] = c * bnd[i] - s * f;
                    if i == 0 {
                        break;
                    }
                    i -= 1;
                }
                if underflow {
                    d[i + 1] -= p;
                } else {
                    d[l] -= p;
                    e[l] = g;
                }
                e[m] = T::zero();
            }
        }
    }
    Ok(())
}
578
// Generates a starting vector for Lanczos step `step` in wrk.w0 and
// returns its norm (after applying the operator and, for step > 0,
// orthogonalizing against all previously stored Lanczos vectors).
// Tries up to three candidates before giving up.
#[allow(non_snake_case)]
fn startv<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    step: usize,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<T, SvdLibError> {
    let mut rnm2 = svd_ddot(&wrk.w0, &wrk.w0);
    for id in 0..3 {
        // On the first attempt of step 0 an existing non-zero w0 is
        // reused; otherwise fill w0 with seeded uniform noise.
        if id > 0 || step > 0 || compare(rnm2, T::zero()) {
            // NOTE(review): the RNG is re-seeded from the same
            // `random_seed` on every retry, so all three attempts draw
            // the identical vector — confirm this is intentional.
            let mut bytes = [0; 32];
            for (i, b) in random_seed.to_le_bytes().iter().enumerate() {
                bytes[i] = *b;
            }
            let mut seeded_rng = StdRng::from_seed(bytes);
            for val in wrk.w0.iter_mut() {
                *val = T::from_f64(seeded_rng.gen_range(-1.0..1.0)).unwrap();
            }
        }
        wrk.w3.copy_from_slice(&wrk.w0);

        // Push the candidate through the operator so it lies in the
        // range of A'A, then measure it.
        svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
        wrk.w3.copy_from_slice(&wrk.w0);
        rnm2 = svd_ddot(&wrk.w3, &wrk.w3);
        if rnm2 > T::zero() {
            break;
        }
    }

    if rnm2 <= T::zero() {
        return Err(SvdLibError::StartvError(format!("rnm2 <= 0.0, rnm2 = {rnm2:?}")));
    }

    if step > 0 {
        // Orthogonalize against every stored Lanczos vector.
        for i in 0..step {
            let v = store.retrq(i);
            svd_daxpy(-svd_ddot(&wrk.w3, v), v, &mut wrk.w0);
        }

        // Also remove the component along the previous step's vectors.
        svd_daxpy(-svd_ddot(&wrk.w4, &wrk.w0), &wrk.w2, &mut wrk.w0);
        wrk.w3.copy_from_slice(&wrk.w0);

        // If orthogonalization annihilated the vector, report zero.
        rnm2 = match svd_ddot(&wrk.w3, &wrk.w3) {
            dot if dot <= T::eps() * rnm2 => T::zero(),
            dot => dot,
        }
    }
    Ok(rnm2.sqrt())
}
632
// Performs the first Lanczos step: builds the initial unit vector,
// computes alf[0] with one pass of reorthogonalization, and returns
// `(rnm, tol)` — the residual norm and the convergence tolerance
// derived from the running estimate of ||A'A||.
#[allow(non_snake_case)]
fn stpone<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<(T, T), SvdLibError> {
    let mut rnm = startv(A, wrk, 0, store, random_seed)?;
    if compare(rnm, T::zero()) {
        return Err(SvdLibError::StponeError("rnm == 0.0".to_string()));
    }

    // Normalize the starting vector into w1 / w3.
    svd_datx(rnm.recip(), &wrk.w0, &mut wrk.w1);
    svd_dscal(rnm.recip(), &mut wrk.w3);

    // w0 = B w3; alf[0] = <w0, w3>; subtract the Rayleigh component,
    // then repeat once to reduce rounding error (t is the correction).
    svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
    wrk.alf[0] = svd_ddot(&wrk.w0, &wrk.w3);
    svd_daxpy(-wrk.alf[0], &wrk.w1, &mut wrk.w0);
    let t = svd_ddot(&wrk.w0, &wrk.w3);
    wrk.alf[0] += t;
    svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
    wrk.w4.copy_from_slice(&wrk.w0);
    rnm = svd_norm(&wrk.w4);
    let anorm = rnm + wrk.alf[0].abs();
    Ok((rnm, T::eps().sqrt() * anorm))
}
662
// Runs Lanczos steps `first..last`, extending the tridiagonal system
// (wrk.alf / wrk.bet) with selective reorthogonalization. Updates the
// residual norm `rnm` and tolerance `tol` in place, sets `enough` when
// no further progress is possible, and returns the index reached.
#[allow(non_snake_case)]
#[allow(clippy::too_many_arguments)]
fn lanczos_step<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    first: usize,
    last: usize,
    ll: &mut usize,
    enough: &mut bool,
    rnm: &mut T,
    tol: &mut T,
    store: &mut Store<T>,
) -> Result<usize, SvdLibError> {
    let eps1 = T::eps() * T::from_f64(wrk.ncols as f64).unwrap().sqrt();
    let mut j = first;
    let four = T::from_f64(4.0).unwrap();

    while j < last {
        // Rotate the vector pipeline: (w1, w3) become the "previous"
        // pair (w2, w4) for this step.
        mem::swap(&mut wrk.w1, &mut wrk.w2);
        mem::swap(&mut wrk.w3, &mut wrk.w4);

        store.storq(j - 1, &wrk.w2);
        if j - 1 < MAXLL {
            store.storp(j - 1, &wrk.w4);
        }
        wrk.bet[j] = *rnm;

        // Restart with a fresh orthogonal vector when the residual
        // collapsed; if even that fails we are done.
        if compare(*rnm, T::zero()) {
            *rnm = startv(A, wrk, j, store, 0)?;
            if compare(*rnm, T::zero()) {
                *enough = true;
            }
        }

        if *enough {
            // Undo the pipeline rotation before bailing out.
            mem::swap(&mut wrk.w1, &mut wrk.w2);
            break;
        }

        // Core three-term recurrence: normalize, apply B = A'A,
        // subtract the beta and alpha components.
        svd_datx(rnm.recip(), &wrk.w0, &mut wrk.w1);
        svd_dscal(rnm.recip(), &mut wrk.w3);
        svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
        svd_daxpy(-*rnm, &wrk.w2, &mut wrk.w0);
        wrk.alf[j] = svd_ddot(&wrk.w0, &wrk.w3);
        svd_daxpy(-wrk.alf[j], &wrk.w1, &mut wrk.w0);

        // Track where the alphas drop sharply; vectors before `ll`
        // get explicit reorthogonalization below.
        if j <= MAXLL && wrk.alf[j - 1].abs() > four * wrk.alf[j].abs() {
            *ll = j;
        }
        for i in 0..(j - 1).min(*ll) {
            let v1 = store.retrp(i);
            let t = svd_ddot(v1, &wrk.w0);
            let v2 = store.retrq(i);
            svd_daxpy(-t, v2, &mut wrk.w0);
            wrk.eta[i] = eps1;
            wrk.oldeta[i] = eps1;
        }

        // One step of iterative refinement against the previous and
        // current Lanczos vectors, folding corrections into bet/alf.
        let t = svd_ddot(&wrk.w0, &wrk.w4);
        svd_daxpy(-t, &wrk.w2, &mut wrk.w0);
        if wrk.bet[j] > T::zero() {
            wrk.bet[j] += t;
        }
        let t = svd_ddot(&wrk.w0, &wrk.w3);
        svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
        wrk.alf[j] += t;
        wrk.w4.copy_from_slice(&wrk.w0);
        *rnm = svd_norm(&wrk.w4);
        let anorm = wrk.bet[j] + wrk.alf[j].abs() + *rnm;
        *tol = T::eps().sqrt() * anorm;

        // Update orthogonality-loss estimates and purge if needed.
        ortbnd(wrk, j, *rnm, eps1);

        purge(wrk.ncols, *ll, wrk, j, rnm, *tol, store);
        if *rnm <= *tol {
            *rnm = T::zero();
        }
        j += 1;
    }
    Ok(j)
}
750
751fn purge<T: SvdFloat>(
752 n: usize,
753 ll: usize,
754 wrk: &mut WorkSpace<T>,
755 step: usize,
756 rnm: &mut T,
757 tol: T,
758 store: &mut Store<T>,
759) {
760 if step < ll + 2 {
761 return;
762 }
763
764 let reps = T::eps().sqrt();
765 let eps1 = T::eps() * T::from_f64(n as f64).unwrap().sqrt();
766 let two = T::from_f64(2.0).unwrap();
767
768 let k = svd_idamax(step - (ll + 1), &wrk.eta) + ll;
769 if wrk.eta[k].abs() > reps {
770 let reps1 = eps1 / reps;
771 let mut iteration = 0;
772 let mut flag = true;
773 while iteration < 2 && flag {
774 if *rnm > tol {
775 let mut tq = T::zero();
777 let mut tr = T::zero();
778 for i in ll..step {
779 let v = store.retrq(i);
780 let t = svd_ddot(v, &wrk.w3);
781 tq += t.abs();
782 svd_daxpy(-t, v, &mut wrk.w1);
783 let t = svd_ddot(v, &wrk.w4);
784 tr += t.abs();
785 svd_daxpy(-t, v, &mut wrk.w0);
786 }
787 wrk.w3.copy_from_slice(&wrk.w1);
788 let t = svd_ddot(&wrk.w0, &wrk.w3);
789 tr += t.abs();
790 svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
791 wrk.w4.copy_from_slice(&wrk.w0);
792 *rnm = svd_norm(&wrk.w4);
793 if tq <= reps1 && tr <= *rnm * reps1 {
794 flag = false;
795 }
796 }
797 iteration += 1;
798 }
799 for i in ll..=step {
800 wrk.eta[i] = eps1;
801 wrk.oldeta[i] = eps1;
802 }
803 }
804}
805
// Updates the orthogonality-loss recurrences after Lanczos step `step`:
// wrk.oldeta is advanced one step using the tridiagonal coefficients,
// then swapped with wrk.eta so `eta` always holds the current
// estimates. `eps1` is the baseline roundoff level.
fn ortbnd<T: SvdFloat>(wrk: &mut WorkSpace<T>, step: usize, rnm: T, eps1: T) {
    if step < 1 {
        return;
    }
    if !compare(rnm, T::zero()) && step > 1 {
        // First entry of the recurrence uses only its right neighbour.
        wrk.oldeta[0] =
            (wrk.bet[1] * wrk.eta[1] + (wrk.alf[0] - wrk.alf[step]) * wrk.eta[0] - wrk.bet[step] * wrk.oldeta[0]) / rnm
            + eps1;
        if step > 2 {
            // Interior entries use both neighbours.
            for i in 1..=step - 2 {
                wrk.oldeta[i] = (wrk.bet[i + 1] * wrk.eta[i + 1]
                    + (wrk.alf[i] - wrk.alf[step]) * wrk.eta[i]
                    + wrk.bet[i] * wrk.eta[i - 1]
                    - wrk.bet[step] * wrk.oldeta[i])
                    / rnm
                    + eps1;
            }
        }
    }
    // The newest vector is orthogonal to roundoff level by construction.
    wrk.oldeta[step - 1] = eps1;
    mem::swap(&mut wrk.oldeta, &mut wrk.eta);
    wrk.eta[step] = eps1;
}
829
// Massages the Ritz-value error bounds in `bnd`: merges bounds of
// near-degenerate neighbours, tightens bounds using the gap to the
// nearest distinct Ritz value, then counts how many Ritz values have
// converged. Sets `enough` when a converged value falls strictly
// inside (endl, endr). Returns the converged count.
fn error_bound<T: SvdFloat>(
    enough: &mut bool,
    endl: T,
    endr: T,
    ritz: &mut [T],
    bnd: &mut [T],
    step: usize,
    tol: T,
) -> usize {
    assert!(step > 0, "error_bound: expected 'step' to be non-zero");

    // Index of the weakest bound; sweeps run toward it from each end.
    let mid = svd_idamax(step + 1, bnd);
    let sixteen = T::from_f64(16.0).unwrap();

    // Right-to-left sweep (this expression is just `step`): combine the
    // bounds of neighbouring Ritz values that are numerically equal.
    let mut i = ((step + 1) + (step - 1)) / 2;
    while i > mid + 1 {
        if (ritz[i - 1] - ritz[i]).abs() < T::eps34() * ritz[i].abs() && bnd[i] > tol && bnd[i - 1] > tol {
            bnd[i - 1] = (bnd[i].powi(2) + bnd[i - 1].powi(2)).sqrt();
            bnd[i] = T::zero();
        }
        i -= 1;
    }

    // Left-to-right sweep (this expression is just `1`), mirroring the above.
    let mut i = ((step + 1) - (step - 1)) / 2;
    while i + 1 < mid {
        if (ritz[i + 1] - ritz[i]).abs() < T::eps34() * ritz[i].abs() && bnd[i] > tol && bnd[i + 1] > tol {
            bnd[i + 1] = (bnd[i].powi(2) + bnd[i + 1].powi(2)).sqrt();
            bnd[i] = T::zero();
        }
        i += 1;
    }

    // Refine each bound by the gap to the nearest neighbouring Ritz
    // value, then count the ones that meet the convergence criterion.
    let mut neig = 0;
    let mut gapl = ritz[step] - ritz[0];
    for i in 0..=step {
        let mut gap = gapl;
        if i < step {
            gapl = ritz[i + 1] - ritz[i];
        }
        gap = gap.min(gapl);
        if gap > bnd[i] {
            bnd[i] *= bnd[i] / gap;
        }
        if bnd[i] <= sixteen * T::eps() * ritz[i].abs() {
            neig += 1;
            if !*enough {
                *enough = endl < ritz[i] && ritz[i] < endr;
            }
        }
    }
    neig
}
884
// Implicit-shift QL iteration on a symmetric tridiagonal matrix WITH
// eigenvector accumulation — a port of the EISPACK IMTQL2 routine.
//
// `d` holds the diagonal, `e` the sub-diagonal (entry 0 unused on
// input), and `z` an `n` x `nm` row-major matrix (initially identity)
// whose rows are rotated into the eigenvectors. On success `d` holds
// the eigenvalues in ascending order with matching rows of `z`.
fn imtql2<T: SvdFloat>(nm: usize, n: usize, d: &mut [T], e: &mut [T], z: &mut [T]) -> Result<(), SvdLibError> {
    if n == 1 {
        return Ok(());
    }
    assert!(n > 1, "imtql2: expected 'n' to be > 1");
    let two = T::from_f64(2.0).unwrap();

    let last = n - 1;

    // Shift the sub-diagonal down one slot, EISPACK-style.
    for i in 1..n {
        e[i - 1] = e[i];
    }
    e[last] = T::zero();

    let nnm = n * nm;
    for l in 0..n {
        let mut iteration = 0;

        while iteration <= 30 {
            // Find the first m >= l with a negligible off-diagonal.
            let mut m = l;
            while m < n {
                if m == last {
                    break;
                }
                let test = d[m].abs() + d[m + 1].abs();
                if compare(test, test + e[m].abs()) {
                    break;
                }
                m += 1;
            }
            // d[l] already isolated: converged.
            if m == l {
                break;
            }

            if iteration == 30 {
                return Err(SvdLibError::Imtql2Error(
                    "imtql2 no convergence to an eigenvalue after 30 iterations".to_string(),
                ));
            }
            iteration += 1;

            // Form the implicit Wilkinson-style shift.
            let mut g = (d[l + 1] - d[l]) / (two * e[l]);
            let mut r = svd_pythag(g, T::one());
            g = d[m] - d[l] + e[l] / (g + svd_fsign(r, g));

            let mut s = T::one();
            let mut c = T::one();
            let mut p = T::zero();

            assert!(m > 0, "imtql2: expected 'm' to be non-zero");
            let mut i = m - 1;
            let mut underflow = false;
            // One QL sweep of Givens rotations from m-1 down to l,
            // applying each rotation to the eigenvector rows in z.
            while !underflow && i >= l {
                let mut f = s * e[i];
                let b = c * e[i];
                r = svd_pythag(f, g);
                e[i + 1] = r;
                if compare(r, T::zero()) {
                    // Rotation denominator vanished; recover below.
                    underflow = true;
                } else {
                    s = f / r;
                    c = g / r;
                    g = d[i + 1] - p;
                    r = (d[i] - g) * s + two * c * b;
                    p = s * r;
                    d[i + 1] = g + p;
                    g = c * r - b;

                    // Rotate columns i and i+1 of every stored vector.
                    for k in (0..nnm).step_by(n) {
                        let index = k + i;
                        f = z[index + 1];
                        z[index + 1] = s * z[index] + c * f;
                        z[index] = c * z[index] - s * f;
                    }
                    if i == 0 {
                        break;
                    }
                    i -= 1;
                }
            } if underflow {
                d[i + 1] -= p;
            } else {
                d[l] -= p;
                e[l] = g;
            }
            e[m] = T::zero();
        }
    }

    // Selection-sort eigenvalues ascending, swapping the corresponding
    // eigenvector entries in z alongside.
    for l in 1..n {
        let i = l - 1;
        let mut k = i;
        let mut p = d[i];
        for (j, item) in d.iter().enumerate().take(n).skip(l) {
            if *item < p {
                k = j;
                p = *item;
            }
        }

        if k != i {
            d[k] = d[i];
            d[i] = p;
            for j in (0..nnm).step_by(n) {
                z.swap(j + i, j + k);
            }
        }
    }

    Ok(())
}
1004
1005fn rotate_array<T: Float + Copy>(a: &mut [T], x: usize) {
1006 let n = a.len();
1007 let mut j = 0;
1008 let mut start = 0;
1009 let mut t1 = a[0];
1010
1011 for _ in 0..n {
1012 j = match j >= x {
1013 true => j - x,
1014 false => j + n - x,
1015 };
1016
1017 let t2 = a[j];
1018 a[j] = t1;
1019
1020 if j == start {
1021 j += 1;
1022 start = j;
1023 t1 = a[j];
1024 } else {
1025 t1 = t2;
1026 }
1027 }
1028}
1029
// Computes the singular triplets from the Lanczos output: solves the
// (steps+1)-square tridiagonal eigenproblem, assembles the accepted
// Ritz vectors into right singular vectors Vt, then recovers the
// singular values S and left vectors Ut by applying the operator.
#[allow(non_snake_case)]
fn ritvec<T: SvdFloat>(
    A: &dyn SMat<T>,
    dimensions: usize,
    kappa: T,
    wrk: &mut WorkSpace<T>,
    steps: usize,
    neig: usize,
    store: &mut Store<T>,
) -> Result<SVDRawRec<T>, SvdLibError> {
    let js = steps + 1;
    let jsq = js * js;
    // s starts as the js x js identity (row-major); imtql2 rotates it
    // into the tridiagonal eigenvectors.
    let mut s = vec![T::zero(); jsq];

    for i in (0..jsq).step_by(js + 1) {
        s[i] = T::one();
    }

    let mut Vt = DMat {
        cols: wrk.ncols,
        value: vec![T::zero(); wrk.ncols * dimensions],
    };

    // svd_dcopy writes reversed copies: the tridiagonal diagonal goes
    // into the head of Vt.value (reused as scratch for imtql2) and the
    // off-diagonal into w5.
    svd_dcopy(js, 0, &wrk.alf, &mut Vt.value);
    svd_dcopy(steps, 1, &wrk.bet, &mut wrk.w5);

    imtql2(js, js, &mut Vt.value, &mut wrk.w5, &mut s)?;

    // Accept Ritz pairs whose bound passes the kappa test, filling Vt
    // rows from the back (x wraps through the row indices); each row is
    // the combination of stored Lanczos vectors weighted by the
    // eigenvector entries read column-wise from s.
    let mut nsig = 0;
    let mut x = 0;
    let mut id2 = jsq - js;
    for k in 0..js {
        if wrk.bnd[k] <= kappa * wrk.ritz[k].abs() && k + 1 > js - neig {
            x = match x {
                0 => dimensions - 1,
                _ => x - 1,
            };

            let offset = x * Vt.cols;
            Vt.value[offset..offset + Vt.cols].fill(T::zero());
            let mut idx = id2 + js;
            for i in 0..js {
                idx -= js;
                if s[idx] != T::zero() {
                    for (j, item) in store.retrq(i).iter().enumerate().take(Vt.cols) {
                        Vt.value[j + offset] += s[idx] * *item;
                    }
                }
            }
            nsig += 1;
        }
        id2 += 1;
    }

    // Rotate the rows so the accepted vectors start at row 0.
    if x > 0 {
        rotate_array(&mut Vt.value, x * Vt.cols);
    }

    // Keep at most `dimensions` triplets and shrink Vt accordingly.
    let d = dimensions.min(nsig);
    let mut S = vec![T::zero(); d];
    let mut Ut = DMat {
        cols: wrk.nrows,
        value: vec![T::zero(); wrk.nrows * d],
    };
    Vt.value.resize(Vt.cols * d, T::zero());

    let mut tmp_vec = vec![T::zero(); Vt.cols];
    for (i, sval) in S.iter_mut().enumerate() {
        let vt_offset = i * Vt.cols;
        let ut_offset = i * Ut.cols;

        let vt_vec = &Vt.value[vt_offset..vt_offset + Vt.cols];
        let ut_vec = &mut Ut.value[ut_offset..ut_offset + Ut.cols];

        // Rayleigh quotient of v against A'A gives sigma^2.
        svd_opb(A, vt_vec, &mut tmp_vec, &mut wrk.temp, wrk.transposed);
        let t = svd_ddot(vt_vec, &tmp_vec);

        *sval = t.sqrt();

        // Residual-based error estimate for this singular value.
        svd_daxpy(-t, vt_vec, &mut tmp_vec);
        wrk.bnd[js] = svd_norm(&tmp_vec) * sval.recip();

        // u = A v / sigma.
        A.svd_opa(vt_vec, ut_vec, wrk.transposed);
        svd_dscal(sval.recip(), ut_vec);
    }

    Ok(SVDRawRec {
        d,
        nsig,
        Ut,
        S,
        Vt,
    })
}
1144
// Drives the Lanczos iteration: repeatedly extends the tridiagonal
// system (lanczos_step), computes Ritz values and bounds for every
// unreduced sub-block (imtqlb), and checks convergence (error_bound)
// until `dim` values converge or `iterations` is exhausted. Returns
// the index of the last Lanczos step taken; `neig` receives the
// converged count.
#[allow(non_snake_case)]
#[allow(clippy::too_many_arguments)]
fn lanso<T: SvdFloat>(
    A: &dyn SMat<T>,
    dim: usize,
    iterations: usize,
    end_interval: &[T; 2],
    wrk: &mut WorkSpace<T>,
    neig: &mut usize,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<usize, SvdLibError> {
    let (endl, endr) = (end_interval[0], end_interval[1]);

    // First step yields the initial residual norm and tolerance.
    let rnm_tol = stpone(A, wrk, store, random_seed)?;
    let mut rnm = rnm_tol.0;
    let mut tol = rnm_tol.1;

    let eps1 = T::eps() * T::from_f64(wrk.ncols as f64).unwrap().sqrt();
    wrk.eta[0] = eps1;
    wrk.oldeta[0] = eps1;
    let mut ll = 0;
    let mut first = 1;
    // Initial batch size: a few steps beyond the requested dimensions.
    let mut last = iterations.min(dim.max(8) + dim);
    let mut enough = false;
    let mut j = 0;
    let mut intro = 0;

    while !enough {
        if rnm <= tol {
            rnm = T::zero();
        }

        // Extend the Lanczos factorization up to `last` steps.
        let steps = lanczos_step(A, wrk, first, last, &mut ll, &mut enough, &mut rnm, &mut tol, store)?;
        j = match enough {
            true => steps - 1,
            false => last - 1,
        };

        first = j + 1;
        wrk.bet[first] = rnm;

        // Diagonalize each unreduced block of the tridiagonal matrix
        // (blocks are separated by zero off-diagonal entries).
        let mut l = 0;
        for _ in 0..j {
            if l > j {
                break;
            }

            let mut i = l;
            while i <= j {
                if compare(wrk.bet[i + 1], T::zero()) {
                    break;
                }
                i += 1;
            }
            i = i.min(j);

            // Copy (reversed) the block's diagonal/off-diagonal into
            // scratch and get its Ritz values plus eigenvector first
            // components (used as error bounds).
            let sz = i - l;
            svd_dcopy(sz + 1, l, &wrk.alf, &mut wrk.ritz);
            svd_dcopy(sz, l + 1, &wrk.bet, &mut wrk.w5);

            imtqlb(sz + 1, &mut wrk.ritz[l..], &mut wrk.w5[l..], &mut wrk.bnd[l..])?;

            for m in l..=i {
                wrk.bnd[m] = rnm * wrk.bnd[m].abs();
            }
            l = i + 1;
        }

        // Sort Ritz values ascending, keeping bounds aligned.
        insert_sort(j + 1, &mut wrk.ritz, &mut wrk.bnd);

        *neig = error_bound(&mut enough, endl, endr, &mut wrk.ritz, &mut wrk.bnd, j, tol);

        // Choose how many more steps to take, proportional to how far
        // we are from `dim` converged values.
        if *neig < dim {
            if *neig == 0 {
                last = first + 9;
                intro = first;
            } else {
                last = first + 3.max(1 + ((j - intro) * (dim - *neig)) / *neig);
            }
            last = last.min(iterations);
        } else {
            enough = true
        }
        enough = enough || first >= iterations;
    }
    store.storq(j, &wrk.w1);
    Ok(j)
}
1240
1241impl<T: SvdFloat + 'static> SvdRec<T> {
1242 pub fn recompose(&self) -> Array2<T> {
1243 let sdiag = Array2::from_diag(&self.s);
1244 self.ut.t().dot(&sdiag).dot(&self.vt)
1245 }
1246}
1247
// Matrix-free operator for CSC matrices: major axis = columns, so the
// non-transposed product scatters column contributions into y, while
// the transposed product gathers per-column dot products.
#[rustfmt::skip]
impl<T: Float + Zero + AddAssign + Clone> SMat<T> for nalgebra_sparse::csc::CscMatrix<T> {
    fn nrows(&self) -> usize { self.nrows() }
    fn ncols(&self) -> usize { self.ncols() }
    fn nnz(&self) -> usize { self.nnz() }

    // Computes y = A x, or y = A' x when `transposed` is set.
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed { self.ncols() } else { self.nrows() };
        let ncols = if transposed { self.nrows() } else { self.ncols() };
        assert_eq!(x.len(), ncols, "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
        assert_eq!(y.len(), nrows, "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}", y.len(), nrows);

        // CSC data: column offsets, row indices, values.
        let (major_offsets, minor_indices, values) = self.csc_data();

        for y_val in y.iter_mut() {
            *y_val = T::zero();
        }

        if transposed {
            // y[col] = dot(column `col`, x)
            for (i, yval) in y.iter_mut().enumerate() {
                for j in major_offsets[i]..major_offsets[i + 1] {
                    *yval += values[j] * x[minor_indices[j]];
                }
            }
        } else {
            // y[row] += A[row, col] * x[col], column by column
            for (i, xval) in x.iter().enumerate() {
                for j in major_offsets[i]..major_offsets[i + 1] {
                    y[minor_indices[j]] += values[j] * *xval;
                }
            }
        }
    }
}
1282
// Matrix-free operator for CSR matrices: major axis = rows, the mirror
// image of the CSC implementation above.
#[rustfmt::skip]
impl<T: Float + Zero + AddAssign + Clone> SMat<T> for nalgebra_sparse::csr::CsrMatrix<T> {
    fn nrows(&self) -> usize { self.nrows() }
    fn ncols(&self) -> usize { self.ncols() }
    fn nnz(&self) -> usize { self.nnz() }

    // Computes y = A x, or y = A' x when `transposed` is set.
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed { self.ncols() } else { self.nrows() };
        let ncols = if transposed { self.nrows() } else { self.ncols() };
        assert_eq!(x.len(), ncols, "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
        assert_eq!(y.len(), nrows, "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}", y.len(), nrows);

        // CSR data: row offsets, column indices, values.
        let (major_offsets, minor_indices, values) = self.csr_data();

        for y_val in y.iter_mut() {
            *y_val = T::zero();
        }

        if !transposed {
            // y[row] = dot(row `row`, x)
            for (i, yval) in y.iter_mut().enumerate() {
                for j in major_offsets[i]..major_offsets[i + 1] {
                    *yval += values[j] * x[minor_indices[j]];
                }
            }
        } else {
            // y[col] += A[row, col] * x[row], row by row
            for (i, xval) in x.iter().enumerate() {
                for j in major_offsets[i]..major_offsets[i + 1] {
                    y[minor_indices[j]] += values[j] * *xval;
                }
            }
        }
    }
}
1317
1318#[rustfmt::skip]
1319impl<T: Float + Zero + AddAssign + Clone> SMat<T> for nalgebra_sparse::coo::CooMatrix<T> {
1320 fn nrows(&self) -> usize { self.nrows() }
1321 fn ncols(&self) -> usize { self.ncols() }
1322 fn nnz(&self) -> usize { self.nnz() }
1323
1324 fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
1326 let nrows = if transposed { self.ncols() } else { self.nrows() };
1327 let ncols = if transposed { self.nrows() } else { self.ncols() };
1328 assert_eq!(x.len(), ncols, "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
1329 assert_eq!(y.len(), nrows, "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}", y.len(), nrows);
1330
1331 for y_val in y.iter_mut() {
1332 *y_val = T::zero();
1333 }
1334
1335 if transposed {
1336 for (i, j, v) in self.triplet_iter() {
1337 y[j] += *v * x[i];
1338 }
1339 } else {
1340 for (i, j, v) in self.triplet_iter() {
1341 y[i] += *v * x[j];
1342 }
1343 }
1344 }
1345}
1346