1#![allow(unused_variables)]
7#![allow(unused_assignments)]
8#![allow(unused_mut)]
9
10use crate::error::{SparseError, SparseResult};
11use crate::sparray::SparseArray;
12use scirs2_core::ndarray::{Array1, Array2};
13use scirs2_core::numeric::{Float, SparseElement};
14use std::fmt::Debug;
15use std::ops::{Add, Div, Mul, Sub};
16
/// Output of `solve_bidiagonal_svd`: the singular values in the matrix element type `T`,
/// followed by two `f64` coefficient tables for the left and right singular vectors
/// (consumed by `lanczos_bidiag_svd`, which indexes them as `[lanczos_vector][singular_index]`).
type BidiagonalSvdResult<T> = (Vec<T>, Vec<Vec<f64>>, Vec<Vec<f64>>);
19
/// Algorithm used by `svds` to compute a truncated singular value decomposition.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SVDMethod {
    /// Golub–Kahan–Lanczos bidiagonalization.
    Lanczos,
    /// Randomized range finder followed by a small dense SVD.
    Randomized,
    /// Power-iteration entry point (currently served by the Lanczos solver).
    Power,
    /// Cross/skeleton-approximation entry point (currently served by the Lanczos solver).
    CrossApproximation,
}
32
33impl SVDMethod {
34 pub fn from_str(s: &str) -> SparseResult<Self> {
35 match s.to_lowercase().as_str() {
36 "lanczos" => Ok(Self::Lanczos),
37 "randomized" | "random" => Ok(Self::Randomized),
38 "power" => Ok(Self::Power),
39 "cross" | "cross_approximation" => Ok(Self::CrossApproximation),
40 _ => Err(SparseError::ValueError(format!("Unknown SVD method: {s}"))),
41 }
42 }
43}
44
45#[derive(Debug, Clone)]
47pub struct SVDOptions {
48 pub k: usize,
50 pub maxiter: usize,
52 pub tol: f64,
54 pub n_oversamples: usize,
56 pub n_iter: usize,
58 pub method: SVDMethod,
60 pub random_seed: Option<u64>,
62 pub compute_u: bool,
64 pub compute_vt: bool,
66}
67
68impl Default for SVDOptions {
69 fn default() -> Self {
70 Self {
71 k: 6,
72 maxiter: 1000,
73 tol: 1e-10,
74 n_oversamples: 10,
75 n_iter: 2,
76 method: SVDMethod::Lanczos,
77 random_seed: None,
78 compute_u: true,
79 compute_vt: true,
80 }
81 }
82}
83
84#[derive(Debug, Clone)]
86pub struct SVDResult<T>
87where
88 T: Float + SparseElement + Debug + Copy,
89{
90 pub u: Option<Array2<T>>,
92 pub s: Array1<T>,
94 pub vt: Option<Array2<T>>,
96 pub iterations: usize,
98 pub converged: bool,
100}
101
102#[allow(dead_code)]
133pub fn svds<T, S>(
134 matrix: &S,
135 k: Option<usize>,
136 options: Option<SVDOptions>,
137) -> SparseResult<SVDResult<T>>
138where
139 T: Float
140 + SparseElement
141 + Debug
142 + Copy
143 + Add<Output = T>
144 + Sub<Output = T>
145 + Mul<Output = T>
146 + Div<Output = T>
147 + 'static
148 + std::iter::Sum,
149 S: SparseArray<T>,
150{
151 let opts = options.unwrap_or_default();
152 let k = k.unwrap_or(opts.k);
153
154 let (m, n) = matrix.shape();
155 if k >= m.min(n) {
156 return Err(SparseError::ValueError(
157 "Number of singular values k must be less than min(m, n)".to_string(),
158 ));
159 }
160
161 match opts.method {
162 SVDMethod::Lanczos => lanczos_bidiag_svd(matrix, k, &opts),
163 SVDMethod::Randomized => randomized_svd(matrix, k, &opts),
164 SVDMethod::Power => power_method_svd(matrix, k, &opts),
165 SVDMethod::CrossApproximation => cross_approximation_svd(matrix, k, &opts),
166 }
167}
168
169#[allow(dead_code)]
171pub fn svd_truncated<T, S>(
172 matrix: &S,
173 k: usize,
174 method: &str,
175 tol: Option<f64>,
176 maxiter: Option<usize>,
177) -> SparseResult<SVDResult<T>>
178where
179 T: Float
180 + SparseElement
181 + Debug
182 + Copy
183 + Add<Output = T>
184 + Sub<Output = T>
185 + Mul<Output = T>
186 + Div<Output = T>
187 + 'static
188 + std::iter::Sum,
189 S: SparseArray<T>,
190{
191 let method_enum = SVDMethod::from_str(method)?;
192
193 let options = SVDOptions {
194 k,
195 method: method_enum,
196 tol: tol.unwrap_or(1e-10),
197 maxiter: maxiter.unwrap_or(1000),
198 ..Default::default()
199 };
200
201 svds(matrix, Some(k), Some(options))
202}
203
204#[allow(dead_code)]
206fn lanczos_bidiag_svd<T, S>(
207 matrix: &S,
208 k: usize,
209 options: &SVDOptions,
210) -> SparseResult<SVDResult<T>>
211where
212 T: Float
213 + SparseElement
214 + Debug
215 + Copy
216 + Add<Output = T>
217 + Sub<Output = T>
218 + Mul<Output = T>
219 + Div<Output = T>
220 + 'static
221 + std::iter::Sum,
222 S: SparseArray<T>,
223{
224 let (m, n) = matrix.shape();
225 let max_lanczos_size = (2 * k + 10).min(m.min(n));
226
227 let mut u = Array1::zeros(m);
229 u[0] = T::sparse_one();
230
231 let norm = (u.iter().map(|&v| v * v).sum::<T>()).sqrt();
233 if !SparseElement::is_zero(&norm) {
234 for i in 0..m {
235 u[i] = u[i] / norm;
236 }
237 }
238
239 let mut alpha = Vec::<T>::new();
240 let mut beta = Vec::<T>::new();
241 let mut u_vectors = Vec::<Array1<T>>::with_capacity(max_lanczos_size);
242 let mut v_vectors = Vec::<Array1<T>>::with_capacity(max_lanczos_size);
243
244 u_vectors.push(u.clone());
245
246 let mut converged = false;
247 let mut iter = 0;
248
249 while iter < options.maxiter && alpha.len() < max_lanczos_size {
251 let av = matrix_transpose_vector_product(matrix, &u_vectors[iter])?;
253 let mut v = av;
254
255 if iter > 0 && !beta.is_empty() {
256 let prev_beta = beta[iter - 1];
257 for i in 0..n {
258 v[i] = v[i] - prev_beta * v_vectors[iter - 1][i];
259 }
260 }
261
262 let alpha_j = (v.iter().map(|&val| val * val).sum::<T>()).sqrt();
264 alpha.push(alpha_j);
265
266 if SparseElement::is_zero(&alpha_j) {
267 break;
268 }
269
270 for i in 0..n {
272 v[i] = v[i] / alpha_j;
273 }
274 v_vectors.push(v.clone());
275
276 let avu = matrix_vector_product(matrix, &v)?;
278 let mut u_next = avu;
279
280 for i in 0..m {
281 u_next[i] = u_next[i] - alpha_j * u_vectors[iter][i];
282 }
283
284 let beta_j = (u_next.iter().map(|&val| val * val).sum::<T>()).sqrt();
286 beta.push(beta_j);
287
288 if beta_j < T::from(options.tol).expect("Operation failed") {
289 converged = true;
290 break;
291 }
292
293 for i in 0..m {
295 u_next[i] = u_next[i] / beta_j;
296 }
297
298 u_vectors.push(u_next);
299 iter += 1;
300 }
301
302 let (singular_values, u_bidiag, vt_bidiag) = solve_bidiagonal_svd(&alpha, &beta, k)?;
304
305 let final_u = if options.compute_u {
307 let mut u_final = Array2::zeros((m, k.min(singular_values.len())));
308 for j in 0..k.min(singular_values.len()) {
309 for i in 0..m {
310 let mut sum = T::sparse_zero();
311 for l in 0..u_vectors.len().min(u_bidiag.len()) {
312 if j < u_bidiag[l].len() {
313 sum = sum
314 + T::from(u_bidiag[l][j]).expect("Operation failed") * u_vectors[l][i];
315 }
316 }
317 u_final[[i, j]] = sum;
318 }
319 }
320 Some(u_final)
321 } else {
322 None
323 };
324
325 let final_vt = if options.compute_vt {
326 let mut vt_final = Array2::zeros((k.min(singular_values.len()), n));
327 for j in 0..k.min(singular_values.len()) {
328 for i in 0..n {
329 let mut sum = T::sparse_zero();
330 for l in 0..v_vectors.len().min(vt_bidiag.len()) {
331 if j < vt_bidiag[l].len() {
332 sum = sum
333 + T::from(vt_bidiag[l][j]).expect("Operation failed") * v_vectors[l][i];
334 }
335 }
336 vt_final[[j, i]] = sum;
337 }
338 }
339 Some(vt_final)
340 } else {
341 None
342 };
343
344 Ok(SVDResult {
345 u: final_u,
346 s: Array1::from_vec(singular_values[..k.min(singular_values.len())].to_vec()),
347 vt: final_vt,
348 iterations: iter,
349 converged,
350 })
351}
352
353#[allow(dead_code)]
355fn randomized_svd<T, S>(matrix: &S, k: usize, options: &SVDOptions) -> SparseResult<SVDResult<T>>
356where
357 T: Float
358 + SparseElement
359 + Debug
360 + Copy
361 + Add<Output = T>
362 + Sub<Output = T>
363 + Mul<Output = T>
364 + Div<Output = T>
365 + 'static
366 + std::iter::Sum,
367 S: SparseArray<T>,
368{
369 let (m, n) = matrix.shape();
370 let l = (k + options.n_oversamples).min(m).min(n);
372
373 let mut omega = Array2::zeros((n, l));
375 for i in 0..n {
376 for j in 0..l {
377 let val = ((i * 17 + j * 13) % 1000) as f64 / 1000.0 - 0.5;
379 omega[[i, j]] = T::from(val).expect("Operation failed");
380 }
381 }
382
383 let mut y = Array2::zeros((m, l));
385 for j in 0..l {
386 let omega_col = omega.column(j).to_owned();
387 let y_col = matrix_vector_product(matrix, &omega_col)?;
388 for i in 0..m {
389 y[[i, j]] = y_col[i];
390 }
391 }
392
393 for _ in 0..options.n_iter {
395 let mut y_new = Array2::zeros((m, l));
397 for j in 0..l {
398 let y_col = y.column(j).to_owned();
399 let at_y_col = matrix_transpose_vector_product(matrix, &y_col)?;
400 let a_at_y_col = matrix_vector_product(matrix, &at_y_col)?;
401 for i in 0..m {
402 y_new[[i, j]] = a_at_y_col[i];
403 }
404 }
405 y = y_new;
406 }
407
408 let q = qr_decomposition_orthogonal(&y)?;
410
411 let mut b = Array2::zeros((l, n));
416 for i in 0..l {
417 let q_col = q.column(i).to_owned();
418
419 let b_row = matrix_transpose_vector_product(matrix, &q_col)?;
421 for j in 0..n {
422 b[[i, j]] = b_row[j];
423 }
424 }
425
426 let b_svd = dense_svd(&b, k)?;
428
429 let final_u = if options.compute_u {
433 if let Some(ref u_b) = b_svd.u {
434 let mut u_result = Array2::zeros((m, k));
435 for i in 0..m {
436 for j in 0..k {
437 let mut sum = T::sparse_zero();
438 for l_idx in 0..l {
439 sum = sum + q[[i, l_idx]] * u_b[[l_idx, j]];
440 }
441 u_result[[i, j]] = sum;
442 }
443 }
444 Some(u_result)
445 } else {
446 None
447 }
448 } else {
449 None
450 };
451
452 Ok(SVDResult {
453 u: final_u,
454 s: b_svd.s,
455 vt: b_svd.vt,
456 iterations: options.n_iter + 1,
457 converged: true,
458 })
459}
460
461#[allow(dead_code)]
463fn power_method_svd<T, S>(matrix: &S, k: usize, options: &SVDOptions) -> SparseResult<SVDResult<T>>
464where
465 T: Float
466 + SparseElement
467 + Debug
468 + Copy
469 + Add<Output = T>
470 + Sub<Output = T>
471 + Mul<Output = T>
472 + Div<Output = T>
473 + 'static
474 + std::iter::Sum,
475 S: SparseArray<T>,
476{
477 lanczos_bidiag_svd(matrix, k, options)
480}
481
482#[allow(dead_code)]
484fn cross_approximation_svd<T, S>(
485 matrix: &S,
486 k: usize,
487 options: &SVDOptions,
488) -> SparseResult<SVDResult<T>>
489where
490 T: Float
491 + SparseElement
492 + Debug
493 + Copy
494 + Add<Output = T>
495 + Sub<Output = T>
496 + Mul<Output = T>
497 + Div<Output = T>
498 + 'static
499 + std::iter::Sum,
500 S: SparseArray<T>,
501{
502 lanczos_bidiag_svd(matrix, k, options)
505}
506
507#[allow(dead_code)]
509fn matrix_vector_product<T, S>(matrix: &S, vector: &Array1<T>) -> SparseResult<Array1<T>>
510where
511 T: Float
512 + SparseElement
513 + Debug
514 + Copy
515 + Add<Output = T>
516 + Sub<Output = T>
517 + Mul<Output = T>
518 + Div<Output = T>
519 + 'static
520 + std::iter::Sum,
521 S: SparseArray<T>,
522{
523 let (m, n) = matrix.shape();
524 if vector.len() != n {
525 return Err(SparseError::DimensionMismatch {
526 expected: n,
527 found: vector.len(),
528 });
529 }
530
531 let mut result = Array1::zeros(m);
532 let (row_indices, col_indices, values) = matrix.find();
533
534 for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
535 result[i] = result[i] + values[k] * vector[j];
536 }
537
538 Ok(result)
539}
540
541#[allow(dead_code)]
543fn matrix_transpose_vector_product<T, S>(matrix: &S, vector: &Array1<T>) -> SparseResult<Array1<T>>
544where
545 T: Float
546 + SparseElement
547 + Debug
548 + Copy
549 + Add<Output = T>
550 + Sub<Output = T>
551 + Mul<Output = T>
552 + Div<Output = T>
553 + 'static
554 + std::iter::Sum,
555 S: SparseArray<T>,
556{
557 let (m, n) = matrix.shape();
558 if vector.len() != m {
559 return Err(SparseError::DimensionMismatch {
560 expected: m,
561 found: vector.len(),
562 });
563 }
564
565 let mut result = Array1::zeros(n);
566 let (row_indices, col_indices, values) = matrix.find();
567
568 for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
569 result[j] = result[j] + values[k] * vector[i];
570 }
571
572 Ok(result)
573}
574
575#[allow(dead_code)]
577fn solve_bidiagonal_svd<T>(
578 alpha: &[T],
579 beta: &[T],
580 k: usize,
581) -> SparseResult<BidiagonalSvdResult<T>>
582where
583 T: Float
584 + SparseElement
585 + Debug
586 + Copy
587 + Add<Output = T>
588 + Sub<Output = T>
589 + Mul<Output = T>
590 + Div<Output = T>
591 + 'static
592 + std::iter::Sum,
593{
594 let n = alpha.len();
595
596 let mut singular_values = Vec::with_capacity(k);
600 let mut u_vectors = Vec::with_capacity(k);
601 let mut vt_vectors = Vec::with_capacity(k);
602
603 if n > 0 {
605 let largest_sv = alpha
606 .iter()
607 .map(|&x| x.abs())
608 .fold(T::sparse_zero(), |a, b| if a > b { a } else { b });
609 singular_values.push(largest_sv);
610
611 let mut u_vec = vec![0.0; n];
613 let mut vt_vec = vec![0.0; n];
614
615 if n > 0 {
616 u_vec[0] = 1.0;
617 vt_vec[0] = 1.0;
618 }
619
620 u_vectors.push(u_vec);
621 vt_vectors.push(vt_vec);
622 }
623
624 while singular_values.len() < k && singular_values.len() < n {
626 singular_values.push(T::sparse_zero());
627 u_vectors.push(vec![0.0_f64; n]);
628 vt_vectors.push(vec![0.0_f64; n]);
629 }
630
631 Ok((singular_values, u_vectors, vt_vectors))
632}
633
634#[allow(dead_code)]
636fn qr_decomposition_orthogonal<T>(matrix: &Array2<T>) -> SparseResult<Array2<T>>
637where
638 T: Float
639 + SparseElement
640 + Debug
641 + Copy
642 + Add<Output = T>
643 + Sub<Output = T>
644 + Mul<Output = T>
645 + Div<Output = T>
646 + 'static
647 + std::iter::Sum,
648{
649 let (m, n) = matrix.dim();
650 let mut q = matrix.clone();
651
652 for j in 0..n {
654 let mut norm = T::sparse_zero();
656 for i in 0..m {
657 norm = norm + q[[i, j]] * q[[i, j]];
658 }
659 norm = norm.sqrt();
660
661 if !SparseElement::is_zero(&norm) {
662 for i in 0..m {
663 q[[i, j]] = q[[i, j]] / norm;
664 }
665 }
666
667 for k in (j + 1)..n {
669 let mut dot = T::sparse_zero();
670 for i in 0..m {
671 dot = dot + q[[i, j]] * q[[i, k]];
672 }
673
674 for i in 0..m {
675 q[[i, k]] = q[[i, k]] - dot * q[[i, j]];
676 }
677 }
678 }
679
680 Ok(q)
681}
682
683#[allow(dead_code)]
691fn dense_svd<T>(matrix: &Array2<T>, k: usize) -> SparseResult<SVDResult<T>>
692where
693 T: Float
694 + SparseElement
695 + Debug
696 + Copy
697 + Add<Output = T>
698 + Sub<Output = T>
699 + Mul<Output = T>
700 + Div<Output = T>
701 + 'static
702 + std::iter::Sum,
703{
704 let (m, n) = matrix.dim();
705 let rank = k.min(m).min(n);
706
707 let mut g = Array2::zeros((n, n));
709 for i in 0..n {
710 for j in i..n {
711 let mut s = T::sparse_zero();
712 for r in 0..m {
713 s = s + matrix[[r, i]] * matrix[[r, j]];
714 }
715 g[[i, j]] = s;
716 g[[j, i]] = s;
717 }
718 }
719
720 let mut v_mat = Array2::<T>::eye(n);
722 let max_sweeps = 100usize;
723 let tol = T::from(1e-14).unwrap_or_else(|| T::epsilon());
724
725 for _sweep in 0..max_sweeps {
726 let mut off_norm = T::sparse_zero();
727 for i in 0..n {
728 for j in (i + 1)..n {
729 off_norm = off_norm + g[[i, j]] * g[[i, j]];
730 }
731 }
732 if off_norm < tol * tol {
733 break;
734 }
735
736 for i in 0..n {
737 for j in (i + 1)..n {
738 let gij = g[[i, j]];
739 if gij.abs() < tol {
740 continue;
741 }
742 let diff = g[[j, j]] - g[[i, i]];
743 let tau = if diff.abs() < tol {
744 T::sparse_one()
745 } else {
746 let ratio = T::from(2.0).expect("conv") * gij / diff;
747 let sign_r = if ratio >= T::sparse_zero() {
748 T::sparse_one()
749 } else {
750 -T::sparse_one()
751 };
752 sign_r / (ratio.abs() + (ratio * ratio + T::sparse_one()).sqrt())
753 };
754 let cos_t = T::sparse_one() / (tau * tau + T::sparse_one()).sqrt();
755 let sin_t = tau * cos_t;
756
757 for r in 0..n {
758 if r == i || r == j {
759 continue;
760 }
761 let gri = g[[r, i]];
762 let grj = g[[r, j]];
763 g[[r, i]] = cos_t * gri - sin_t * grj;
764 g[[i, r]] = g[[r, i]];
765 g[[r, j]] = sin_t * gri + cos_t * grj;
766 g[[j, r]] = g[[r, j]];
767 }
768 let gii = g[[i, i]];
769 let gjj = g[[j, j]];
770 let two = T::from(2.0).expect("conv");
771 g[[i, i]] = cos_t * cos_t * gii - two * sin_t * cos_t * gij + sin_t * sin_t * gjj;
772 g[[j, j]] = sin_t * sin_t * gii + two * sin_t * cos_t * gij + cos_t * cos_t * gjj;
773 g[[i, j]] = T::sparse_zero();
774 g[[j, i]] = T::sparse_zero();
775
776 for r in 0..n {
777 let vri = v_mat[[r, i]];
778 let vrj = v_mat[[r, j]];
779 v_mat[[r, i]] = cos_t * vri - sin_t * vrj;
780 v_mat[[r, j]] = sin_t * vri + cos_t * vrj;
781 }
782 }
783 }
784 }
785
786 let mut sigma_sq: Vec<(T, usize)> = (0..n).map(|i| (g[[i, i]], i)).collect();
788 sigma_sq.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
789
790 let take = rank.min(sigma_sq.len());
791 let mut singular_values = Vec::with_capacity(take);
792 let mut vt_final = Array2::zeros((take, n));
793
794 for j in 0..take {
795 let (lam, idx) = sigma_sq[j];
796 let sv = if lam > T::sparse_zero() {
797 lam.sqrt()
798 } else {
799 T::sparse_zero()
800 };
801 singular_values.push(sv);
802 for col in 0..n {
803 vt_final[[j, col]] = v_mat[[col, idx]];
804 }
805 }
806
807 let mut u_final = Array2::zeros((m, take));
809 for j in 0..take {
810 let sv = singular_values[j];
811 if sv > tol {
812 for i in 0..m {
813 let mut dot = T::sparse_zero();
814 for l in 0..n {
815 dot = dot + matrix[[i, l]] * vt_final[[j, l]];
816 }
817 u_final[[i, j]] = dot / sv;
818 }
819 }
820 }
821
822 Ok(SVDResult {
823 u: Some(u_final),
824 s: Array1::from_vec(singular_values),
825 vt: Some(vt_final),
826 iterations: max_sweeps,
827 converged: true,
828 })
829}
830
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;
    use approx::assert_relative_eq;

    // 3×3 fixture, dense view:
    //   [3 0 2]
    //   [0 1 0]
    //   [4 0 5]
    fn create_test_matrix() -> CsrArray<f64> {
        let (rows, cols, data) = (
            vec![0, 0, 1, 2, 2],
            vec![0, 2, 1, 0, 2],
            vec![3.0, 2.0, 1.0, 4.0, 5.0],
        );
        CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed")
    }

    #[test]
    fn test_svds_basic() {
        let matrix = create_test_matrix();
        let result = svds(&matrix, Some(2), None).expect("Operation failed");

        assert_eq!(result.s.len(), 2);
        if let Some(u) = &result.u {
            assert_eq!(u.shape(), [3, 2]);
        }
        if let Some(vt) = &result.vt {
            assert_eq!(vt.shape(), [2, 3]);
        }

        // Singular values are non-negative and sorted in descending order.
        assert!(result.s[0] >= 0.0);
        if result.s.len() > 1 {
            assert!(result.s[0] >= result.s[1]);
        }
    }

    #[test]
    fn test_matrix_vector_product() {
        let matrix = create_test_matrix();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = matrix_vector_product(&matrix, &x).expect("Operation failed");

        assert_eq!(y.len(), 3);
        let expected = [3.0 * 1.0 + 2.0 * 3.0, 1.0 * 2.0, 4.0 * 1.0 + 5.0 * 3.0];
        for (got, want) in y.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_matrix_transpose_vector_product() {
        let matrix = create_test_matrix();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = matrix_transpose_vector_product(&matrix, &x).expect("Operation failed");

        assert_eq!(y.len(), 3);
        let expected = [3.0 * 1.0 + 4.0 * 3.0, 1.0 * 2.0, 2.0 * 1.0 + 5.0 * 3.0];
        for (got, want) in y.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_svd_options() {
        let matrix = create_test_matrix();
        let options = SVDOptions {
            k: 1,
            method: SVDMethod::Lanczos,
            compute_u: false,
            compute_vt: true,
            ..Default::default()
        };

        let result = svds(&matrix, Some(1), Some(options)).expect("Operation failed");

        assert_eq!(result.s.len(), 1);
        assert!(result.u.is_none());
        assert!(result.vt.is_some());
    }

    #[test]
    fn test_svd_truncated_api() {
        let matrix = create_test_matrix();
        let result =
            svd_truncated(&matrix, 2, "lanczos", Some(1e-8), Some(100)).expect("Operation failed");

        assert_eq!(result.s.len(), 2);
        assert!(result.u.is_some());
        assert!(result.vt.is_some());
    }

    #[test]
    fn test_randomized_svd() {
        let matrix = create_test_matrix();
        let options = SVDOptions {
            k: 2,
            method: SVDMethod::Randomized,
            n_oversamples: 5,
            n_iter: 1,
            ..Default::default()
        };

        let result = svds(&matrix, Some(2), Some(options)).expect("Operation failed");

        assert_eq!(result.s.len(), 2);
        assert!(result.converged);
    }

    #[test]
    fn test_qr_decomposition() {
        let matrix = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Operation failed");
        let q = qr_decomposition_orthogonal(&matrix).expect("Operation failed");

        assert_eq!(q.shape(), [3, 2]);
        // Every column of Q must have unit Euclidean norm.
        for j in 0..2 {
            let norm: f64 = q.column(j).iter().map(|v| v * v).sum();
            assert_relative_eq!(norm, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_svd_method_parsing() {
        let cases = [
            ("lanczos", SVDMethod::Lanczos),
            ("randomized", SVDMethod::Randomized),
            ("power", SVDMethod::Power),
        ];
        for (name, expected) in cases.iter() {
            assert_eq!(
                SVDMethod::from_str(name).expect("Operation failed"),
                *expected
            );
        }
        assert!(SVDMethod::from_str("invalid").is_err());
    }

    #[test]
    fn test_invalid_k() {
        let matrix = create_test_matrix();
        // k = 10 exceeds min(m, n) = 3 and must be rejected.
        assert!(svds(&matrix, Some(10), None).is_err());
    }
}