pub mod masked;

use crate::error::SvdLibError;
use crate::{Diagnostics, SMat, SvdFloat, SvdRec};
use nalgebra_sparse::na::{DMatrix, DVector};
use ndarray::{Array, Array2};
use num_traits::real::Real;
use num_traits::{Float, FromPrimitive, One, Zero};
use rand::rngs::StdRng;
use rand::{rng, Rng, RngCore, SeedableRng};
use rayon::iter::IndexedParallelIterator;
use rayon::iter::ParallelIterator;
use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator};
use std::fmt::Debug;
use std::iter::Sum;
use std::mem;
use std::ops::{AddAssign, MulAssign, Neg, SubAssign};

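/// Computes a truncated SVD with library defaults: all available dimensions,
/// a tiny end-interval around zero, and `kappa = 1.0e-6`.
///
/// A minimal usage sketch, assuming the crate is consumed through its public
/// re-exports (paths may differ in your build):
///
/// ```ignore
/// use nalgebra_sparse::coo::CooMatrix;
///
/// // 3 x 3 diagonal example; any type implementing `SMat<f64>` works.
/// let mut coo = CooMatrix::<f64>::new(3, 3);
/// coo.push(0, 0, 1.0);
/// coo.push(1, 1, 2.0);
/// coo.push(2, 2, 3.0);
///
/// let rec: SvdRec<f64> = svd(&coo)?;
/// println!("singular values: {:?}", rec.s);
/// ```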
pub fn svd<T, M>(a: &M) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let eps_small = T::from_f64(-1.0e-30).unwrap();
    let eps_large = T::from_f64(1.0e-30).unwrap();
    let kappa = T::from_f64(1.0e-6).unwrap();
    svd_las2(a, 0, 0, &[eps_small, eps_large], kappa, 0)
}

pub fn svd_dim<T, M>(a: &M, dimensions: usize) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let eps_small = T::from_f64(-1.0e-30).unwrap();
    let eps_large = T::from_f64(1.0e-30).unwrap();
    let kappa = T::from_f64(1.0e-6).unwrap();

    svd_las2(a, dimensions, 0, &[eps_small, eps_large], kappa, 0)
}

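/// Like [`svd_dim`], but with a fixed random seed so repeated runs produce
/// identical results; a seed of `0` falls back to OS entropy (see
/// [`svd_las2`]).
///
/// A hedged determinism sketch (matrix construction omitted):
///
/// ```ignore
/// let a = svd_dim_seed(&matrix, 10, 42)?;
/// let b = svd_dim_seed(&matrix, 10, 42)?;
/// assert_eq!(a.s, b.s); // same seed, same singular values
/// ```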
pub fn svd_dim_seed<T, M>(
    a: &M,
    dimensions: usize,
    random_seed: u32,
) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let eps_small = T::from_f64(-1.0e-30).unwrap();
    let eps_large = T::from_f64(1.0e-30).unwrap();
    let kappa = T::from_f64(1.0e-6).unwrap();

    svd_las2(
        a,
        dimensions,
        0,
        &[eps_small, eps_large],
        kappa,
        random_seed,
    )
}

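/// Full-control entry point for the Lanczos SVD (after SVDLIBC's `las2`).
///
/// - `dimensions`: number of singular triplets requested; `0` or anything
///   larger than `min(nrows, ncols)` means "as many as possible".
/// - `iterations`: cap on Lanczos steps; `0` lets the routine choose.
/// - `end_interval`: `[endl, endr]`, an interval around zero whose
///   eigenvalues are treated as unwanted (e.g. `[-1e-30, 1e-30]`).
/// - `kappa`: relative accuracy required of a Ritz value before it is
///   accepted as an eigenvalue.
/// - `random_seed`: `> 0` for reproducible runs, `0` for an OS-entropy seed.
///
/// A hedged example with explicit parameters (the values mirror the
/// defaults used by [`svd`]):
///
/// ```ignore
/// let rec = svd_las2(&matrix, 50, 0, &[-1.0e-30, 1.0e-30], 1.0e-6, 42)?;
/// assert!(rec.d <= 50);
/// ```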
pub fn svd_las2<T, M>(
    a: &M,
    dimensions: usize,
    iterations: usize,
    end_interval: &[T; 2],
    kappa: T,
    random_seed: u32,
) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let random_seed = match random_seed > 0 {
        true => random_seed,
        false => rng().next_u32(),
    };

    let min_nrows_ncols = a.nrows().min(a.ncols());

    let dimensions = match dimensions {
        n if n == 0 || n > min_nrows_ncols => min_nrows_ncols,
        _ => dimensions,
    };

    let iterations = match iterations {
        n if n == 0 || n > min_nrows_ncols => min_nrows_ncols,
        n if n < dimensions => dimensions,
        _ => iterations,
    };

    if dimensions < 2 {
        return Err(SvdLibError::Las2Error(format!(
            "svd_las2: insufficient dimensions: {dimensions}"
        )));
    }

    assert!(dimensions > 1 && dimensions <= min_nrows_ncols);
    assert!(iterations >= dimensions && iterations <= min_nrows_ncols);

    // Working on the transpose is cheaper when the matrix is much wider than tall.
    let transposed = (a.ncols() as f64) >= ((a.nrows() as f64) * 1.2);
    let nrows = if transposed { a.ncols() } else { a.nrows() };
    let ncols = if transposed { a.nrows() } else { a.ncols() };

    let mut wrk = WorkSpace::new(nrows, ncols, transposed, iterations)?;
    let mut store = Store::new(ncols)?;

    let mut neig = 0;
    let steps = lanso(
        a,
        dimensions,
        iterations,
        end_interval,
        &mut wrk,
        &mut neig,
        &mut store,
        random_seed,
    )?;

    let kappa = Float::max(Float::abs(kappa), T::eps34());
    let mut r = ritvec(a, dimensions, kappa, &mut wrk, steps, neig, &mut store)?;

    if transposed {
        mem::swap(&mut r.Ut, &mut r.Vt);
    }

    Ok(SvdRec {
        d: r.d,
        u: Array2::from_shape_vec((r.d, r.Ut.cols), r.Ut.value)?,
        s: Array::from_shape_vec(r.d, r.S)?,
        vt: Array2::from_shape_vec((r.d, r.Vt.cols), r.Vt.value)?,
        diagnostics: Diagnostics {
            non_zero: a.nnz(),
            dimensions,
            iterations,
            transposed,
            lanczos_steps: steps + 1,
            ritz_values_stabilized: neig,
            significant_values: r.d,
            singular_values: r.nsig,
            end_interval: *end_interval,
            kappa,
            random_seed,
        },
    })
}

const MAXLL: usize = 2;

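// Backing store for Lanczos vectors. Slots `0..MAXLL` hold the p-vectors
// (`storp`/`retrp`); the q-vectors are shifted up by `MAXLL`, so `storq`
// and `retrq` address slot `idx + MAXLL`.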
#[derive(Debug, Clone, PartialEq)]
struct Store<T: Float> {
    n: usize,
    vecs: Vec<Vec<T>>,
}

impl<T: Float + Zero + Clone> Store<T> {
    fn new(n: usize) -> Result<Self, SvdLibError> {
        Ok(Self { n, vecs: vec![] })
    }

    fn storq(&mut self, idx: usize, v: &[T]) {
        while idx + MAXLL >= self.vecs.len() {
            self.vecs.push(vec![T::zero(); self.n]);
        }
        self.vecs[idx + MAXLL].copy_from_slice(v);
    }

    fn storp(&mut self, idx: usize, v: &[T]) {
        while idx >= self.vecs.len() {
            self.vecs.push(vec![T::zero(); self.n]);
        }
        self.vecs[idx].copy_from_slice(v);
    }

    fn retrq(&mut self, idx: usize) -> &[T] {
        &self.vecs[idx + MAXLL]
    }

    fn retrp(&mut self, idx: usize) -> &[T] {
        &self.vecs[idx]
    }
}

#[derive(Debug, Clone, PartialEq)]
struct WorkSpace<T: Float> {
    nrows: usize,
    ncols: usize,
    transposed: bool,
    w0: Vec<T>,
    w1: Vec<T>,
    w2: Vec<T>,
    w3: Vec<T>,
    w4: Vec<T>,
    w5: Vec<T>,
    alf: Vec<T>,
    eta: Vec<T>,
    oldeta: Vec<T>,
    bet: Vec<T>,
    bnd: Vec<T>,
    ritz: Vec<T>,
    temp: Vec<T>,
}

impl<T: Float + Zero + FromPrimitive> WorkSpace<T> {
    fn new(
        nrows: usize,
        ncols: usize,
        transposed: bool,
        iterations: usize,
    ) -> Result<Self, SvdLibError> {
        Ok(Self {
            nrows,
            ncols,
            transposed,
            w0: vec![T::zero(); ncols],
            w1: vec![T::zero(); ncols],
            w2: vec![T::zero(); ncols],
            w3: vec![T::zero(); ncols],
            w4: vec![T::zero(); ncols],
            w5: vec![T::zero(); ncols],
            alf: vec![T::zero(); iterations],
            eta: vec![T::zero(); iterations],
            oldeta: vec![T::zero(); iterations],
            bet: vec![T::zero(); 1 + iterations],
            ritz: vec![T::zero(); 1 + iterations],
            bnd: vec![T::from_f64(f64::MAX).unwrap(); 1 + iterations],
            temp: vec![T::zero(); nrows],
        })
    }
}

#[derive(Debug, Clone, PartialEq)]
struct DMat<T: Float> {
    cols: usize,
    value: Vec<T>,
}

#[allow(non_snake_case)]
#[derive(Debug, Clone, PartialEq)]
struct SVDRawRec<T: Float> {
    d: usize,
    nsig: usize,
    Ut: DMat<T>,
    S: Vec<T>,
    Vt: DMat<T>,
}

fn compare<T: SvdFloat>(computed: T, expected: T) -> bool {
    T::compare(computed, expected)
}

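/// Insertion-sorts the first `n` entries of `array1` ascending, applying the
/// same swaps to `array2` so paired data (here: Ritz values and their error
/// bounds) stay aligned.
///
/// A small illustrative sketch:
///
/// ```ignore
/// let mut vals = [3.0, 1.0, 2.0];
/// let mut tags = [30, 10, 20];
/// insert_sort(3, &mut vals, &mut tags);
/// assert_eq!(vals, [1.0, 2.0, 3.0]);
/// assert_eq!(tags, [10, 20, 30]);
/// ```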
fn insert_sort<T: PartialOrd>(n: usize, array1: &mut [T], array2: &mut [T]) {
    for i in 1..n {
        for j in (1..=i).rev() {
            if array1[j - 1] <= array1[j] {
                break;
            }
            array1.swap(j - 1, j);
            array2.swap(j - 1, j);
        }
    }
}

#[allow(non_snake_case)]
#[rustfmt::skip]
fn svd_opb<T: Float>(A: &dyn SMat<T>, x: &[T], y: &mut [T], temp: &mut [T], transposed: bool) {
    let nrows = if transposed { A.ncols() } else { A.nrows() };
    let ncols = if transposed { A.nrows() } else { A.ncols() };
    assert_eq!(x.len(), ncols, "svd_opb: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
    assert_eq!(y.len(), ncols, "svd_opb: y must be A.ncols() in length, y = {}, A.ncols = {}", y.len(), ncols);
    assert_eq!(temp.len(), nrows, "svd_opb: temp must be A.nrows() in length, temp = {}, A.nrows = {}", temp.len(), nrows);
    // y = A'Ax: apply A, then A', reusing `temp` for the intermediate product.
    A.svd_opa(x, temp, transposed);
    A.svd_opa(temp, y, !transposed);
}

fn svd_daxpy<T: Float + AddAssign + Send + Sync>(da: T, x: &[T], y: &mut [T]) {
    // Sequential for short vectors; rayon only pays off above a size threshold.
    if x.len() < 1000 {
        for (xval, yval) in x.iter().zip(y.iter_mut()) {
            *yval += da * *xval
        }
    } else {
        y.par_iter_mut()
            .zip(x.par_iter())
            .for_each(|(yval, xval)| *yval += da * *xval);
    }
}

fn svd_idamax<T: Float>(n: usize, x: &[T]) -> usize {
    assert!(n > 0, "svd_idamax: unexpected inputs!");

    match n {
        1 => 0,
        _ => {
            let mut imax = 0;
            for (i, xval) in x.iter().enumerate().take(n).skip(1) {
                if xval.abs() > x[imax].abs() {
                    imax = i;
                }
            }
            imax
        }
    }
}

fn svd_fsign<T: Float>(a: T, b: T) -> T {
    match (a >= T::zero() && b >= T::zero()) || (a < T::zero() && b < T::zero()) {
        true => a,
        false => -a,
    }
}

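/// Computes `sqrt(a^2 + b^2)` without destructive overflow or underflow,
/// using the iterative scheme inherited from SVDLIBC's `pythag` (same
/// contract as the classic EISPACK routine), e.g. `svd_pythag(3.0, 4.0)`
/// returns `5.0`.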
fn svd_pythag<T: SvdFloat + FromPrimitive>(a: T, b: T) -> T {
    match Float::max(Float::abs(a), Float::abs(b)) {
        n if n > T::zero() => {
            let mut p = n;
            let mut r = Float::powi(Float::min(Float::abs(a), Float::abs(b)) / p, 2);
            let four = T::from_f64(4.0).unwrap();
            let two = T::from_f64(2.0).unwrap();
            let mut t = four + r;
            while !compare(t, four) {
                let s = r / t;
                let u = T::one() + two * s;
                p = p * u;
                r = Float::powi(s / u, 2);
                t = four + r;
            }
            p
        }
        _ => T::zero(),
    }
}

fn svd_ddot<T: Float + Sum<T> + Send + Sync>(x: &[T], y: &[T]) -> T {
    if x.len() < 1000 {
        x.iter().zip(y).map(|(a, b)| *a * *b).sum()
    } else {
        x.par_iter().zip(y.par_iter()).map(|(a, b)| *a * *b).sum()
    }
}

fn svd_norm<T: Float + Sum<T> + Send + Sync>(x: &[T]) -> T {
    svd_ddot(x, x).sqrt()
}

fn svd_datx<T: Float + Sum<T>>(d: T, x: &[T], y: &mut [T]) {
    for (i, xval) in x.iter().enumerate() {
        y[i] = d * *xval;
    }
}

fn svd_dscal<T: Float + MulAssign + Send + Sync>(d: T, x: &mut [T]) {
    if x.len() < 1000 {
        for elem in x.iter_mut() {
            *elem *= d;
        }
    } else {
        x.par_iter_mut().for_each(|elem| {
            *elem *= d;
        });
    }
}

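/// Copies `n` elements of `x` starting at `offset` into the same positions
/// of `y`, but in reverse order. For `x = [0, 1, 2, 3]`, `n = 3`,
/// `offset = 1`, this leaves `y[1..4] == [3, 2, 1]`.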
fn svd_dcopy<T: Float + Copy>(n: usize, offset: usize, x: &[T], y: &mut [T]) {
    if n > 0 {
        let start = n - 1;
        for i in 0..n {
            y[offset + start - i] = x[offset + i];
        }
    }
}

const MAX_IMTQLB_ITERATIONS: usize = 100;

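/// Implicit-QL eigenvalue iteration (after EISPACK's `imtqlb`) on a
/// symmetric tridiagonal matrix: `d` holds the diagonal (eigenvalues on
/// exit), `e` the subdiagonal, and `bnd` accumulates the first eigenvector
/// components, from which Ritz error bounds are derived. Unlike the classic
/// routine, a convergence failure degrades gracefully here: the affected
/// bounds are inflated and a warning is emitted instead of aborting.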
fn imtqlb<T: SvdFloat>(
    n: usize,
    d: &mut [T],
    e: &mut [T],
    bnd: &mut [T],
    max_imtqlb: Option<usize>,
) -> Result<(), SvdLibError> {
    let max_imtqlb = max_imtqlb.unwrap_or(MAX_IMTQLB_ITERATIONS);
    if n == 1 {
        return Ok(());
    }

    let matrix_size_factor = T::from_f64((n as f64).sqrt()).unwrap();

    bnd[0] = T::one();
    let last = n - 1;
    for i in 1..=last {
        bnd[i] = T::zero();
        e[i - 1] = e[i];
    }
    e[last] = T::zero();

    let mut i = 0;

    let mut had_convergence_issues = false;

    for l in 0..=last {
        let mut iteration = 0;
        let mut p = d[l];
        let mut f = bnd[l];

        while iteration <= max_imtqlb {
            // Look for a small off-diagonal element to split the matrix.
            let mut m = l;
            while m < n {
                if m == last {
                    break;
                }

                let test = Float::abs(d[m]) + Float::abs(d[m + 1]);
                let tol = <T as Float>::epsilon()
                    * T::from_f64(100.0).unwrap()
                    * Float::max(test, T::one())
                    * matrix_size_factor;

                if Float::abs(e[m]) <= tol {
                    break;
                }
                m += 1;
            }

            if m == l {
                // Eigenvalue converged: insertion-sort it into the ordered set.
                let mut exchange = true;
                if l > 0 {
                    i = l;
                    while i >= 1 && exchange {
                        if p < d[i - 1] {
                            d[i] = d[i - 1];
                            bnd[i] = bnd[i - 1];
                            i -= 1;
                        } else {
                            exchange = false;
                        }
                    }
                }
                if exchange {
                    i = 0;
                }
                d[i] = p;
                bnd[i] = f;
                iteration = max_imtqlb + 1; // exit the while loop
            } else {
                if iteration == max_imtqlb {
                    // No convergence: inflate the affected error bounds and move on.
                    had_convergence_issues = true;

                    for idx in l..=m {
                        bnd[idx] = Float::max(bnd[idx], T::from_f64(0.1).unwrap());
                    }

                    e[l] = T::zero();

                    break;
                }

                iteration += 1;
                // One implicit QL sweep with a Wilkinson-style shift.
                let two = T::from_f64(2.0).unwrap();
                let mut g = (d[l + 1] - p) / (two * e[l]);
                let mut r = svd_pythag(g, T::one());
                g = d[m] - p + e[l] / (g + svd_fsign(r, g));
                let mut s = T::one();
                let mut c = T::one();
                p = T::zero();

                assert!(m > 0, "imtqlb: expected 'm' to be non-zero");
                i = m - 1;
                let mut underflow = false;
                while !underflow && i >= l {
                    f = s * e[i];
                    let b = c * e[i];
                    r = svd_pythag(f, g);
                    e[i + 1] = r;

                    if r < <T as Float>::epsilon()
                        * T::from_f64(1000.0).unwrap()
                        * (Float::abs(f) + Float::abs(g))
                    {
                        underflow = true;
                        break;
                    }

                    if Float::abs(r) < <T as Float>::epsilon() * T::from_f64(100.0).unwrap() {
                        r = <T as Float>::epsilon()
                            * T::from_f64(100.0).unwrap()
                            * svd_fsign(T::one(), r);
                    }

                    s = f / r;
                    c = g / r;
                    g = d[i + 1] - p;
                    r = (d[i] - g) * s + two * c * b;
                    p = s * r;
                    d[i + 1] = g + p;
                    g = c * r - b;
                    f = bnd[i + 1];
                    bnd[i + 1] = s * bnd[i] + c * f;
                    bnd[i] = c * bnd[i] - s * f;
                    if i == 0 {
                        break;
                    }
                    i -= 1;
                }
                if underflow {
                    d[i + 1] -= p;
                } else {
                    d[l] -= p;
                    e[l] = g;
                }
                e[m] = T::zero();
            }
        }
    }
    if had_convergence_issues {
        eprintln!("Warning: imtqlb had some convergence issues but continued with best estimates. Results may have reduced accuracy.");
    }
    Ok(())
}

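/// Generates a starting vector for the Lanczos recursion: fills `w0` with
/// seeded random values, applies the operator once to pull the vector into
/// the range of `A'A`, and (on restarts, `step > 0`) orthogonalizes it
/// against the previously stored Lanczos vectors. Returns the norm of the
/// resulting vector, trying up to three fresh draws before giving up.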
#[allow(non_snake_case)]
fn startv<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    step: usize,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<T, SvdLibError> {
    // Seed the RNG once, outside the retry loop, so that each retry below
    // draws a fresh vector instead of regenerating the identical one from a
    // re-seeded generator.
    let mut bytes = [0; 32];
    for (i, b) in random_seed.to_le_bytes().iter().enumerate() {
        bytes[i] = *b;
    }
    let mut seeded_rng = StdRng::from_seed(bytes);

    let mut rnm2 = svd_ddot(&wrk.w0, &wrk.w0);
    for id in 0..3 {
        if id > 0 || step > 0 || compare(rnm2, T::zero()) {
            for val in wrk.w0.iter_mut() {
                *val = T::from_f64(seeded_rng.random_range(-1.0..1.0)).unwrap();
            }
        }
        wrk.w3.copy_from_slice(&wrk.w0);

        svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
        wrk.w3.copy_from_slice(&wrk.w0);
        rnm2 = svd_ddot(&wrk.w3, &wrk.w3);
        if rnm2 > T::zero() {
            break;
        }
    }

    if rnm2 <= T::zero() {
        return Err(SvdLibError::StartvError(format!(
            "rnm2 <= 0.0, rnm2 = {rnm2:?}"
        )));
    }

    if step > 0 {
        for i in 0..step {
            let v = store.retrq(i);
            svd_daxpy(-svd_ddot(&wrk.w3, v), v, &mut wrk.w0);
        }

        svd_daxpy(-svd_ddot(&wrk.w4, &wrk.w0), &wrk.w2, &mut wrk.w0);
        wrk.w3.copy_from_slice(&wrk.w0);

        rnm2 = match svd_ddot(&wrk.w3, &wrk.w3) {
            dot if dot <= T::eps() * rnm2 => T::zero(),
            dot => dot,
        };
    }
    Ok(rnm2.sqrt())
}

#[allow(non_snake_case)]
fn stpone<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<(T, T), SvdLibError> {
    let mut rnm = startv(A, wrk, 0, store, random_seed)?;
    if compare(rnm, T::zero()) {
        return Err(SvdLibError::StponeError("rnm == 0.0".to_string()));
    }

    svd_datx(Float::recip(rnm), &wrk.w0, &mut wrk.w1);
    svd_dscal(Float::recip(rnm), &mut wrk.w3);

    svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
    wrk.alf[0] = svd_ddot(&wrk.w0, &wrk.w3);
    svd_daxpy(-wrk.alf[0], &wrk.w1, &mut wrk.w0);
    let t = svd_ddot(&wrk.w0, &wrk.w3);
    wrk.alf[0] += t;
    svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
    wrk.w4.copy_from_slice(&wrk.w0);
    rnm = svd_norm(&wrk.w4);
    let anorm = rnm + Float::abs(wrk.alf[0]);
    Ok((rnm, T::eps().sqrt() * anorm))
}

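/// Runs Lanczos steps `first..last`, maintaining the three-term recurrence
/// and the `eta`/`oldeta` orthogonality estimates, re-orthogonalizing via
/// `purge` when they deteriorate. Returns the index of the last step taken,
/// which may be earlier than `last` if an invariant subspace is exhausted
/// and `enough` is set.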
#[allow(non_snake_case)]
#[allow(clippy::too_many_arguments)]
fn lanczos_step<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    first: usize,
    last: usize,
    ll: &mut usize,
    enough: &mut bool,
    rnm: &mut T,
    tol: &mut T,
    store: &mut Store<T>,
) -> Result<usize, SvdLibError> {
    let eps1 = T::eps() * T::from_f64(wrk.ncols as f64).unwrap().sqrt();
    let mut j = first;
    let four = T::from_f64(4.0).unwrap();

    while j < last {
        mem::swap(&mut wrk.w1, &mut wrk.w2);
        mem::swap(&mut wrk.w3, &mut wrk.w4);

        store.storq(j - 1, &wrk.w2);
        if j - 1 < MAXLL {
            store.storp(j - 1, &wrk.w4);
        }
        wrk.bet[j] = *rnm;

        // Restart with a new vector if the current subspace is exhausted.
        if compare(*rnm, T::zero()) {
            *rnm = startv(A, wrk, j, store, 0)?;
            if compare(*rnm, T::zero()) {
                *enough = true;
            }
        }

        if *enough {
            mem::swap(&mut wrk.w1, &mut wrk.w2);
            break;
        }

        // One step of the Lanczos recurrence.
        svd_datx(Float::recip(*rnm), &wrk.w0, &mut wrk.w1);
        svd_dscal(Float::recip(*rnm), &mut wrk.w3);
        svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
        svd_daxpy(-*rnm, &wrk.w2, &mut wrk.w0);
        wrk.alf[j] = svd_ddot(&wrk.w0, &wrk.w3);
        svd_daxpy(-wrk.alf[j], &wrk.w1, &mut wrk.w0);

        // Orthogonalize against the initial vectors when the recurrence drifts.
        if j <= MAXLL && Float::abs(wrk.alf[j - 1]) > four * Float::abs(wrk.alf[j]) {
            *ll = j;
        }
        for i in 0..(j - 1).min(*ll) {
            let v1 = store.retrp(i);
            let t = svd_ddot(v1, &wrk.w0);
            let v2 = store.retrq(i);
            svd_daxpy(-t, v2, &mut wrk.w0);
            wrk.eta[i] = eps1;
            wrk.oldeta[i] = eps1;
        }

        // Extended local re-orthogonalization against the two latest vectors.
        let t = svd_ddot(&wrk.w0, &wrk.w4);
        svd_daxpy(-t, &wrk.w2, &mut wrk.w0);
        if wrk.bet[j] > T::zero() {
            wrk.bet[j] += t;
        }
        let t = svd_ddot(&wrk.w0, &wrk.w3);
        svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
        wrk.alf[j] += t;
        wrk.w4.copy_from_slice(&wrk.w0);
        *rnm = svd_norm(&wrk.w4);
        let anorm = wrk.bet[j] + Float::abs(wrk.alf[j]) + *rnm;
        *tol = T::eps().sqrt() * anorm;

        // Update the orthogonality bounds and purge if they have degraded.
        ortbnd(wrk, j, *rnm, eps1);

        purge(wrk.ncols, *ll, wrk, j, rnm, *tol, store);
        if *rnm <= *tol {
            *rnm = T::zero();
        }
        j += 1;
    }
    Ok(j)
}

fn purge<T: SvdFloat>(
    n: usize,
    ll: usize,
    wrk: &mut WorkSpace<T>,
    step: usize,
    rnm: &mut T,
    tol: T,
    store: &mut Store<T>,
) {
    if step < ll + 2 {
        return;
    }

    let reps = T::eps().sqrt();
    let eps1 = T::eps() * T::from_f64(n as f64).unwrap().sqrt();

    let k = svd_idamax(step - (ll + 1), &wrk.eta) + ll;
    if Float::abs(wrk.eta[k]) > reps {
        let reps1 = eps1 / reps;
        let mut iteration = 0;
        let mut flag = true;
        while iteration < 2 && flag {
            if *rnm > tol {
                // Orthogonalize both residual vectors against the stored basis.
                let mut tq = T::zero();
                let mut tr = T::zero();
                for i in ll..step {
                    let v = store.retrq(i);
                    let t = svd_ddot(v, &wrk.w3);
                    tq += Float::abs(t);
                    svd_daxpy(-t, v, &mut wrk.w1);
                    let t = svd_ddot(v, &wrk.w4);
                    tr += Float::abs(t);
                    svd_daxpy(-t, v, &mut wrk.w0);
                }
                wrk.w3.copy_from_slice(&wrk.w1);
                let t = svd_ddot(&wrk.w0, &wrk.w3);
                tr += Float::abs(t);
                svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
                wrk.w4.copy_from_slice(&wrk.w0);
                *rnm = svd_norm(&wrk.w4);
                if tq <= reps1 && tr <= *rnm * reps1 {
                    flag = false;
                }
            }
            iteration += 1;
        }
        for i in ll..=step {
            wrk.eta[i] = eps1;
            wrk.oldeta[i] = eps1;
        }
    }
}

fn ortbnd<T: SvdFloat>(wrk: &mut WorkSpace<T>, step: usize, rnm: T, eps1: T) {
    if step < 1 {
        return;
    }
    if !compare(rnm, T::zero()) && step > 1 {
        wrk.oldeta[0] = (wrk.bet[1] * wrk.eta[1] + (wrk.alf[0] - wrk.alf[step]) * wrk.eta[0]
            - wrk.bet[step] * wrk.oldeta[0])
            / rnm
            + eps1;
        if step > 2 {
            for i in 1..=step - 2 {
                wrk.oldeta[i] = (wrk.bet[i + 1] * wrk.eta[i + 1]
                    + (wrk.alf[i] - wrk.alf[step]) * wrk.eta[i]
                    + wrk.bet[i] * wrk.eta[i - 1]
                    - wrk.bet[step] * wrk.oldeta[i])
                    / rnm
                    + eps1;
            }
        }
    }
    wrk.oldeta[step - 1] = eps1;
    mem::swap(&mut wrk.oldeta, &mut wrk.eta);
    wrk.eta[step] = eps1;
}

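/// Post-processes the Ritz values and their error bounds after a block of
/// Lanczos steps: nearby Ritz values with overlapping bounds are merged,
/// bounds are tightened using the gap to the nearest neighbor, and the
/// number of sufficiently converged Ritz values is returned. `enough` is
/// set once a converged value falls inside the open interval `(endl, endr)`.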
fn error_bound<T: SvdFloat>(
    enough: &mut bool,
    endl: T,
    endr: T,
    ritz: &mut [T],
    bnd: &mut [T],
    step: usize,
    tol: T,
) -> usize {
    assert!(step > 0, "error_bound: expected 'step' to be non-zero");

    let mid = svd_idamax(step + 1, bnd);
    let sixteen = T::from_f64(16.0).unwrap();

    // Sweep down from the top, merging Ritz values that have coalesced.
    let mut i = ((step + 1) + (step - 1)) / 2;
    while i > mid + 1 {
        if Float::abs(ritz[i - 1] - ritz[i]) < T::eps34() * Float::abs(ritz[i])
            && bnd[i] > tol
            && bnd[i - 1] > tol
        {
            bnd[i - 1] = (Float::powi(bnd[i], 2) + Float::powi(bnd[i - 1], 2)).sqrt();
            bnd[i] = T::zero();
        }
        i -= 1;
    }

    // The same sweep from the bottom up.
    let mut i = ((step + 1) - (step - 1)) / 2;
    while i + 1 < mid {
        if Float::abs(ritz[i + 1] - ritz[i]) < T::eps34() * Float::abs(ritz[i])
            && bnd[i] > tol
            && bnd[i + 1] > tol
        {
            bnd[i + 1] = (Float::powi(bnd[i], 2) + Float::powi(bnd[i + 1], 2)).sqrt();
            bnd[i] = T::zero();
        }
        i += 1;
    }

    // Refine the bounds using the local gap and count converged values.
    let mut neig = 0;
    let mut gapl = ritz[step] - ritz[0];
    for i in 0..=step {
        let mut gap = gapl;
        if i < step {
            gapl = ritz[i + 1] - ritz[i];
        }
        gap = Float::min(gap, gapl);
        if gap > bnd[i] {
            bnd[i] *= bnd[i] / gap;
        }
        if bnd[i] <= sixteen * T::eps() * Float::abs(ritz[i]) {
            neig += 1;
            if !*enough {
                *enough = endl < ritz[i] && ritz[i] < endr;
            }
        }
    }
    neig
}

/// Implicit-QL eigendecomposition (after EISPACK's `imtql2`) of a symmetric
/// tridiagonal matrix, accumulating the rotations into `z` so that the
/// eigenvectors come out alongside the eigenvalues, both sorted ascending.
fn imtql2<T: SvdFloat>(
    nm: usize,
    n: usize,
    d: &mut [T],
    e: &mut [T],
    z: &mut [T],
    max_imtqlb: Option<usize>,
) -> Result<(), SvdLibError> {
    let max_imtqlb = max_imtqlb.unwrap_or(MAX_IMTQLB_ITERATIONS);
    if n == 1 {
        return Ok(());
    }
    assert!(n > 1, "imtql2: expected 'n' to be > 1");
    let two = T::from_f64(2.0).unwrap();

    let last = n - 1;

    for i in 1..n {
        e[i - 1] = e[i];
    }
    e[last] = T::zero();

    let nnm = n * nm;
    for l in 0..n {
        let mut iteration = 0;

        while iteration <= max_imtqlb {
            // Look for a small off-diagonal element to split the matrix.
            let mut m = l;
            while m < n {
                if m == last {
                    break;
                }
                let test = Float::abs(d[m]) + Float::abs(d[m + 1]);
                if compare(test, test + Float::abs(e[m])) {
                    break;
                }
                m += 1;
            }
            if m == l {
                break;
            }

            if iteration == max_imtqlb {
                return Err(SvdLibError::Imtql2Error(format!(
                    "imtql2 no convergence to an eigenvalue after {} iterations",
                    max_imtqlb
                )));
            }
            iteration += 1;

            // Form the shift and run one implicit QL sweep.
            let mut g = (d[l + 1] - d[l]) / (two * e[l]);
            let mut r = svd_pythag(g, T::one());
            g = d[m] - d[l] + e[l] / (g + svd_fsign(r, g));

            let mut s = T::one();
            let mut c = T::one();
            let mut p = T::zero();

            assert!(m > 0, "imtql2: expected 'm' to be non-zero");
            let mut i = m - 1;
            let mut underflow = false;
            while !underflow && i >= l {
                let mut f = s * e[i];
                let b = c * e[i];
                r = svd_pythag(f, g);
                e[i + 1] = r;
                if compare(r, T::zero()) {
                    underflow = true;
                } else {
                    s = f / r;
                    c = g / r;
                    g = d[i + 1] - p;
                    r = (d[i] - g) * s + two * c * b;
                    p = s * r;
                    d[i + 1] = g + p;
                    g = c * r - b;

                    // Apply the rotation to the eigenvector matrix.
                    for k in (0..nnm).step_by(n) {
                        let index = k + i;
                        f = z[index + 1];
                        z[index + 1] = s * z[index] + c * f;
                        z[index] = c * z[index] - s * f;
                    }
                    if i == 0 {
                        break;
                    }
                    i -= 1;
                }
            }
            if underflow {
                d[i + 1] -= p;
            } else {
                d[l] -= p;
                e[l] = g;
            }
            e[m] = T::zero();
        }
    }

    // Sort eigenvalues (and their eigenvectors) into ascending order.
    for l in 1..n {
        let i = l - 1;
        let mut k = i;
        let mut p = d[i];
        for (j, item) in d.iter().enumerate().take(n).skip(l) {
            if *item < p {
                k = j;
                p = *item;
            }
        }

        if k != i {
            d[k] = d[i];
            d[i] = p;
            for j in (0..nnm).step_by(n) {
                z.swap(j + i, j + k);
            }
        }
    }

    Ok(())
}

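/// Rotates `a` left by `x` positions in place using the cycle-following
/// trick, so no second buffer is allocated. For example, rotating
/// `[1, 2, 3, 4, 5]` by 2 yields `[3, 4, 5, 1, 2]`.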
fn rotate_array<T: Float + Copy>(a: &mut [T], x: usize) {
    let n = a.len();
    let mut j = 0;
    let mut start = 0;
    let mut t1 = a[0];

    for _ in 0..n {
        j = match j >= x {
            true => j - x,
            false => j + n - x,
        };

        let t2 = a[j];
        a[j] = t1;

        if j == start {
            // Closed a cycle; start the next one.
            j += 1;
            start = j;
            t1 = a[j];
        } else {
            t1 = t2;
        }
    }
}

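/// Turns the converged Ritz values/vectors into singular triplets: solves
/// the eigenproblem of the accumulated tridiagonal (via `imtql2`), filters
/// the Ritz pairs that meet the (sparsity-adapted) `kappa` tolerance,
/// assembles `Vt` from the stored Lanczos vectors, and recovers `S` and
/// `Ut` by applying the operator to each right singular vector.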
#[allow(non_snake_case)]
fn ritvec<T: SvdFloat>(
    A: &dyn SMat<T>,
    dimensions: usize,
    kappa: T,
    wrk: &mut WorkSpace<T>,
    steps: usize,
    neig: usize,
    store: &mut Store<T>,
) -> Result<SVDRawRec<T>, SvdLibError> {
    let js = steps + 1;
    let jsq = js * js;

    let sparsity = T::one()
        - (T::from_usize(A.nnz()).unwrap()
            / (T::from_usize(A.nrows()).unwrap() * T::from_usize(A.ncols()).unwrap()));

    // Sparser matrices get looser tolerances and more QL iterations.
    let epsilon = <T as Float>::epsilon();
    let adaptive_eps = if sparsity > T::from_f64(0.99).unwrap() {
        epsilon * T::from_f64(100.0).unwrap()
    } else if sparsity > T::from_f64(0.9).unwrap() {
        epsilon * T::from_f64(10.0).unwrap()
    } else {
        epsilon
    };

    let max_iterations_imtql2 = if sparsity > T::from_f64(0.999).unwrap() {
        Some(500)
    } else if sparsity > T::from_f64(0.99).unwrap() {
        Some(300)
    } else if sparsity > T::from_f64(0.9).unwrap() {
        Some(200)
    } else {
        Some(50)
    };

    // `s` starts as the js x js identity; imtql2 accumulates the rotations
    // into it, turning its rows into eigenvectors of the tridiagonal.
    let mut s = vec![T::zero(); jsq];
    for i in (0..jsq).step_by(js + 1) {
        s[i] = T::one();
    }

    // Eigenvalues of the tridiagonal (alf = diagonal, bet = subdiagonal).
    let mut ritz_eigenvalues = vec![T::zero(); js];
    svd_dcopy(js, 0, &wrk.alf, &mut ritz_eigenvalues);
    svd_dcopy(steps, 1, &wrk.bet, &mut wrk.w5);

    imtql2(
        js,
        js,
        &mut ritz_eigenvalues,
        &mut wrk.w5,
        &mut s,
        max_iterations_imtql2,
    )?;

    let max_eigenvalue = ritz_eigenvalues
        .iter()
        .fold(T::zero(), |max, &val| Float::max(max, Float::abs(val)));

    let adaptive_kappa = if sparsity > T::from_f64(0.99).unwrap() {
        kappa * T::from_f64(10.0).unwrap()
    } else {
        kappa
    };

    let store_vectors: Vec<Vec<T>> = (0..js).map(|i| store.retrq(i).to_vec()).collect();

    // Keep the Ritz pairs whose error bound meets the relative tolerance.
    let significant_indices: Vec<usize> = (0..js)
        .into_par_iter()
        .filter(|&k| {
            let relative_bound =
                adaptive_kappa * Float::max(Float::abs(wrk.ritz[k]), max_eigenvalue * adaptive_eps);
            wrk.bnd[k] <= relative_bound && k + 1 > js - neig
        })
        .collect();

    let nsig = significant_indices.len();

    // Expand each significant eigenvector of the tridiagonal into a right
    // singular vector using the stored Lanczos basis.
    let mut vt_vectors: Vec<(usize, Vec<T>)> = significant_indices
        .into_par_iter()
        .map(|k| {
            let mut vec = vec![T::zero(); wrk.ncols];

            for i in 0..js {
                let idx = k * js + i;

                if Float::abs(s[idx]) > adaptive_eps {
                    for (j, item) in store_vectors[i].iter().enumerate().take(wrk.ncols) {
                        vec[j] += s[idx] * *item;
                    }
                }
            }

            (k, vec)
        })
        .collect();

    vt_vectors.sort_by_key(|(k, _)| *k);

    let d = dimensions.min(nsig);
    let mut S = vec![T::zero(); d];
    let mut Ut = DMat {
        cols: wrk.nrows,
        value: vec![T::zero(); wrk.nrows * d],
    };

    let mut Vt = DMat {
        cols: wrk.ncols,
        value: vec![T::zero(); wrk.ncols * d],
    };

    for (i, (_, vec)) in vt_vectors.into_iter().take(d).enumerate() {
        let vt_offset = i * Vt.cols;
        Vt.value[vt_offset..vt_offset + Vt.cols].copy_from_slice(&vec);
    }

    // Precompute A'Av and Av for each right singular vector.
    let mut ab_products = Vec::with_capacity(d);
    let mut a_products = Vec::with_capacity(d);

    for i in 0..d {
        let vt_offset = i * Vt.cols;
        let vt_vec = &Vt.value[vt_offset..vt_offset + Vt.cols];

        let mut tmp_vec = vec![T::zero(); Vt.cols];
        let mut ut_vec = vec![T::zero(); wrk.nrows];

        svd_opb(A, vt_vec, &mut tmp_vec, &mut wrk.temp, wrk.transposed);
        A.svd_opa(vt_vec, &mut ut_vec, wrk.transposed);

        ab_products.push(tmp_vec);
        a_products.push(ut_vec);
    }

    // Singular value: sigma = sqrt(v' A'A v).
    let results: Vec<(usize, T)> = (0..d)
        .into_par_iter()
        .map(|i| {
            let vt_offset = i * Vt.cols;
            let vt_vec = &Vt.value[vt_offset..vt_offset + Vt.cols];
            let tmp_vec = &ab_products[i];

            let t = svd_ddot(vt_vec, tmp_vec);
            let sval = Float::max(t, T::zero()).sqrt();

            (i, sval)
        })
        .collect();

    // Left singular vector: u = Av / sigma, guarding against tiny sigma.
    for (i, sval) in results {
        S[i] = sval;
        let ut_offset = i * Ut.cols;
        let mut ut_vec = a_products[i].clone();

        if sval > adaptive_eps {
            svd_dscal(T::one() / sval, &mut ut_vec);
        } else {
            let dls = Float::max(sval, adaptive_eps);
            let safe_scale = T::one() / dls;
            svd_dscal(safe_scale, &mut ut_vec);
        }

        Ut.value[ut_offset..ut_offset + Ut.cols].copy_from_slice(&ut_vec);
    }

    Ok(SVDRawRec {
        d,
        nsig,
        Ut,
        S,
        Vt,
    })
}

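/// Drives the Lanczos iteration (after SVDLIBC's `lanso`): alternates
/// blocks of `lanczos_step` calls with tridiagonal eigenvalue solves
/// (`imtqlb`) and `error_bound` checks, growing the number of steps until
/// at least `dim` Ritz values have converged or the iteration budget is
/// exhausted. Returns the index of the last Lanczos step taken.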
#[allow(non_snake_case)]
#[allow(clippy::too_many_arguments)]
fn lanso<T: SvdFloat>(
    A: &dyn SMat<T>,
    dim: usize,
    iterations: usize,
    end_interval: &[T; 2],
    wrk: &mut WorkSpace<T>,
    neig: &mut usize,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<usize, SvdLibError> {
    let sparsity = T::one()
        - (T::from_usize(A.nnz()).unwrap()
            / (T::from_usize(A.nrows()).unwrap() * T::from_usize(A.ncols()).unwrap()));
    let max_iterations_imtqlb = if sparsity > T::from_f64(0.999).unwrap() {
        Some(500)
    } else if sparsity > T::from_f64(0.99).unwrap() {
        Some(300)
    } else if sparsity > T::from_f64(0.9).unwrap() {
        Some(100)
    } else {
        Some(50)
    };

    let epsilon = <T as Float>::epsilon();
    let adaptive_eps = if sparsity > T::from_f64(0.99).unwrap() {
        epsilon * T::from_f64(100.0).unwrap()
    } else if sparsity > T::from_f64(0.9).unwrap() {
        epsilon * T::from_f64(10.0).unwrap()
    } else {
        epsilon
    };

    let (endl, endr) = (end_interval[0], end_interval[1]);

    // First step of the Lanczos recursion.
    let rnm_tol = stpone(A, wrk, store, random_seed)?;
    let mut rnm = rnm_tol.0;
    let mut tol = rnm_tol.1;

    let eps1 = adaptive_eps * T::from_f64(wrk.ncols as f64).unwrap().sqrt();
    wrk.eta[0] = eps1;
    wrk.oldeta[0] = eps1;
    let mut ll = 0;
    let mut first = 1;
    let mut last = iterations.min(dim.max(8) + dim);
    let mut enough = false;
    let mut j = 0;
    let mut intro = 0;

    while !enough {
        if rnm <= tol {
            rnm = T::zero();
        }

        // Take another block of Lanczos steps.
        let steps = lanczos_step(
            A,
            wrk,
            first,
            last,
            &mut ll,
            &mut enough,
            &mut rnm,
            &mut tol,
            store,
        )?;
        j = match enough {
            true => steps - 1,
            false => last - 1,
        };

        first = j + 1;
        wrk.bet[first] = rnm;

        // Solve each unreduced block of the accumulated tridiagonal.
        let mut l = 0;
        for _ in 0..j {
            if l > j {
                break;
            }

            let mut i = l;
            while i <= j {
                if Float::abs(wrk.bet[i + 1]) <= adaptive_eps {
                    break;
                }
                i += 1;
            }
            i = i.min(j);

            // Copy the block and solve for its Ritz values and error bounds.
            let sz = i - l;
            svd_dcopy(sz + 1, l, &wrk.alf, &mut wrk.ritz);
            svd_dcopy(sz, l + 1, &wrk.bet, &mut wrk.w5);

            imtqlb(
                sz + 1,
                &mut wrk.ritz[l..],
                &mut wrk.w5[l..],
                &mut wrk.bnd[l..],
                max_iterations_imtqlb,
            )?;

            for m in l..=i {
                wrk.bnd[m] = rnm * Float::abs(wrk.bnd[m]);
            }
            l = i + 1;
        }

        // Sort the Ritz values (with bounds) and count the converged ones.
        insert_sort(j + 1, &mut wrk.ritz, &mut wrk.bnd);

        *neig = error_bound(&mut enough, endl, endr, &mut wrk.ritz, &mut wrk.bnd, j, tol);

        // Not enough converged values yet: extend the number of steps.
        if *neig < dim {
            if *neig == 0 {
                last = first + 9;
                intro = first;
            } else {
                let extra_steps = if sparsity > T::from_f64(0.99).unwrap() {
                    5
                } else {
                    0
                };

                last = first + 3.max(1 + ((j - intro) * (dim - *neig)) / *neig) + extra_steps;
            }
            last = last.min(iterations);
        } else {
            enough = true
        }
        enough = enough || first >= iterations;
    }
    store.storq(j, &wrk.w1);
    Ok(j)
}

impl<T: SvdFloat + 'static> SvdRec<T> {
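    /// Reconstructs an approximation of the original matrix from the
    /// truncated factors: `U * diag(S) * Vt`.
    ///
    /// A hedged round-trip sketch:
    ///
    /// ```ignore
    /// let rec = svd(&matrix)?;
    /// let approx = rec.recompose();
    /// // With all dimensions retained, `approx` is close to the input.
    /// ```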
    pub fn recompose(&self) -> Array2<T> {
        let sdiag = Array2::from_diag(&self.s);
        self.u.dot(&sdiag).dot(&self.vt)
    }
}

impl<T: Float + Zero + AddAssign + Clone + Sync> SMat<T> for nalgebra_sparse::csc::CscMatrix<T> {
    fn nrows(&self) -> usize {
        self.nrows()
    }
    fn ncols(&self) -> usize {
        self.ncols()
    }
    fn nnz(&self) -> usize {
        self.nnz()
    }

    /// Computes `y = Ax` (or `y = A'x` when `transposed`).
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed {
            self.ncols()
        } else {
            self.nrows()
        };
        let ncols = if transposed {
            self.nrows()
        } else {
            self.ncols()
        };
        assert_eq!(
            x.len(),
            ncols,
            "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}",
            x.len(),
            ncols
        );
        assert_eq!(
            y.len(),
            nrows,
            "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}",
            y.len(),
            nrows
        );

        let (major_offsets, minor_indices, values) = self.csc_data();

        for y_val in y.iter_mut() {
            *y_val = T::zero();
        }

        // CSC stores columns contiguously: A'x gathers along each column,
        // while Ax scatters each x[i] down its column.
        if transposed {
            for (i, yval) in y.iter_mut().enumerate() {
                for j in major_offsets[i]..major_offsets[i + 1] {
                    *yval += values[j] * x[minor_indices[j]];
                }
            }
        } else {
            for (i, xval) in x.iter().enumerate() {
                for j in major_offsets[i]..major_offsets[i + 1] {
                    y[minor_indices[j]] += values[j] * *xval;
                }
            }
        }
    }

    fn compute_column_means(&self) -> Vec<T> {
        todo!()
    }

    fn multiply_with_dense(
        &self,
        dense: &DMatrix<T>,
        result: &mut DMatrix<T>,
        transpose_self: bool,
    ) {
        todo!()
    }

    fn multiply_with_dense_centered(
        &self,
        dense: &DMatrix<T>,
        result: &mut DMatrix<T>,
        transpose_self: bool,
        means: &DVector<T>,
    ) {
        todo!()
    }

    fn multiply_transposed_by_dense(&self, q: &DMatrix<T>, result: &mut DMatrix<T>) {
        todo!()
    }

    fn multiply_transposed_by_dense_centered(
        &self,
        q: &DMatrix<T>,
        result: &mut DMatrix<T>,
        means: &DVector<T>,
    ) {
        todo!()
    }
}

impl<T: Float + Zero + AddAssign + Clone + Sync + Send + std::ops::MulAssign> SMat<T>
    for nalgebra_sparse::csr::CsrMatrix<T>
{
    fn nrows(&self) -> usize {
        self.nrows()
    }
    fn ncols(&self) -> usize {
        self.ncols()
    }
    fn nnz(&self) -> usize {
        self.nnz()
    }

    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed {
            self.ncols()
        } else {
            self.nrows()
        };
        let ncols = if transposed {
            self.nrows()
        } else {
            self.ncols()
        };
        assert_eq!(
            x.len(),
            ncols,
            "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}",
            x.len(),
            ncols
        );
        assert_eq!(
            y.len(),
            nrows,
            "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}",
            y.len(),
            nrows
        );

        let (major_offsets, minor_indices, values) = self.csr_data();

        y.fill(T::zero());

        if !transposed {
            // y = Ax: each output row is an independent dot product.
            let nrows = self.nrows();

            let results: Vec<(usize, T)> = (0..nrows)
                .into_par_iter()
                .map(|i| {
                    let mut sum = T::zero();
                    for j in major_offsets[i]..major_offsets[i + 1] {
                        sum += values[j] * x[minor_indices[j]];
                    }
                    (i, sum)
                })
                .collect();

            for (i, val) in results {
                y[i] = val;
            }
        } else {
            // y = A'x: rows scatter into shared output columns, so work in
            // chunks with thread-local accumulators merged at the end.
            let nrows = self.nrows();
            let chunk_size = crate::utils::determine_chunk_size(nrows);

            let results: Vec<Vec<T>> = (0..((nrows + chunk_size - 1) / chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = (start + chunk_size).min(nrows);

                    let mut local_y = vec![T::zero(); y.len()];
                    for i in start..end {
                        let row_val = x[i];
                        for j in major_offsets[i]..major_offsets[i + 1] {
                            let col = minor_indices[j];
                            local_y[col] += values[j] * row_val;
                        }
                    }
                    local_y
                })
                .collect();

            for local_y in results {
                for (idx, val) in local_y.iter().enumerate() {
                    if !val.is_zero() {
                        y[idx] += *val;
                    }
                }
            }
        }
    }

    fn compute_column_means(&self) -> Vec<T> {
        let rows = self.nrows();
        let cols = self.ncols();
        let row_count_recip = T::one() / T::from(rows).unwrap();

        let mut col_sums = vec![T::zero(); cols];
        let (row_offsets, col_indices, values) = self.csr_data();

        for i in 0..rows {
            for j in row_offsets[i]..row_offsets[i + 1] {
                let col = col_indices[j];
                col_sums[col] += values[j];
            }
        }

        for j in 0..cols {
            col_sums[j] *= row_count_recip;
        }

        col_sums
    }

    fn multiply_with_dense(
        &self,
        dense: &DMatrix<T>,
        result: &mut DMatrix<T>,
        transpose_self: bool,
    ) {
        todo!()
    }

    fn multiply_with_dense_centered(
        &self,
        dense: &DMatrix<T>,
        result: &mut DMatrix<T>,
        transpose_self: bool,
        means: &DVector<T>,
    ) {
        todo!()
    }

    fn multiply_transposed_by_dense(&self, q: &DMatrix<T>, result: &mut DMatrix<T>) {
        todo!()
    }

    fn multiply_transposed_by_dense_centered(
        &self,
        q: &DMatrix<T>,
        result: &mut DMatrix<T>,
        means: &DVector<T>,
    ) {
        todo!()
    }
}

impl<T: Float + Zero + AddAssign + Clone + Sync> SMat<T> for nalgebra_sparse::coo::CooMatrix<T> {
    fn nrows(&self) -> usize {
        self.nrows()
    }
    fn ncols(&self) -> usize {
        self.ncols()
    }
    fn nnz(&self) -> usize {
        self.nnz()
    }

    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed {
            self.ncols()
        } else {
            self.nrows()
        };
        let ncols = if transposed {
            self.nrows()
        } else {
            self.ncols()
        };
        assert_eq!(
            x.len(),
            ncols,
            "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}",
            x.len(),
            ncols
        );
        assert_eq!(
            y.len(),
            nrows,
            "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}",
            y.len(),
            nrows
        );

        for y_val in y.iter_mut() {
            *y_val = T::zero();
        }

        // COO is an unordered triplet list, so both products are simple scans.
        if transposed {
            for (i, j, v) in self.triplet_iter() {
                y[j] += *v * x[i];
            }
        } else {
            for (i, j, v) in self.triplet_iter() {
                y[i] += *v * x[j];
            }
        }
    }

    fn compute_column_means(&self) -> Vec<T> {
        todo!()
    }

    fn multiply_with_dense(
        &self,
        dense: &DMatrix<T>,
        result: &mut DMatrix<T>,
        transpose_self: bool,
    ) {
        todo!()
    }

    fn multiply_with_dense_centered(
        &self,
        dense: &DMatrix<T>,
        result: &mut DMatrix<T>,
        transpose_self: bool,
        means: &DVector<T>,
    ) {
        todo!()
    }

    fn multiply_transposed_by_dense(&self, q: &DMatrix<T>, result: &mut DMatrix<T>) {
        todo!()
    }

    fn multiply_transposed_by_dense_centered(
        &self,
        q: &DMatrix<T>,
        result: &mut DMatrix<T>,
        means: &DVector<T>,
    ) {
        todo!()
    }
}