// single_svdlib/randomized/mod.rs

1use crate::error::SvdLibError;
2use crate::utils::determine_chunk_size;
3use crate::{Diagnostics, SMat, SvdFloat, SvdRec};
4use nalgebra_sparse::na::{ComplexField, DMatrix, DVector, RealField};
5use ndarray::Array1;
6use nshare::IntoNdarray2;
7use rand::prelude::{Distribution, StdRng};
8use rand::SeedableRng;
9use rand_distr::Normal;
10use rayon::iter::ParallelIterator;
11use rayon::prelude::{IndexedParallelIterator, IntoParallelIterator};
12use std::ops::Mul;
13
/// Selects how the intermediate sketch matrices are re-normalized during the
/// power iterations of [`randomized_svd`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PowerIterationNormalizer {
    /// Replace the sketch with the `Q` factor of its QR decomposition after
    /// every multiplication.
    QR,
    /// Rescale every column of the sketch to unit Euclidean norm after every
    /// multiplication. Despite the name, no LU factorization is performed.
    LU,
    /// Apply no normalization between power iterations.
    None,
}
20
/// Row count at or above which `normalize_columns` computes column norms in parallel.
const PARALLEL_THRESHOLD_ROWS: usize = 5_000;
/// Column count at or above which `normalize_columns` computes column norms in parallel.
const PARALLEL_THRESHOLD_COLS: usize = 1_000;
/// Element-count threshold for parallel work.
/// NOTE(review): not referenced anywhere in this file — confirm it is still needed.
const PARALLEL_THRESHOLD_ELEMENTS: usize = 100_000;
24
25pub fn randomized_svd<T, M>(
26    m: &M,
27    target_rank: usize,
28    n_oversamples: usize,
29    n_power_iters: usize,
30    power_iteration_normalizer: PowerIterationNormalizer,
31    mean_center: bool,
32    seed: Option<u64>,
33) -> anyhow::Result<SvdRec<T>>
34where
35    T: SvdFloat + RealField,
36    M: SMat<T>,
37    T: ComplexField,
38{
39    let m_rows = m.nrows();
40    let m_cols = m.ncols();
41
42    let rank = target_rank.min(m_rows.min(m_cols));
43    let l = rank + n_oversamples;
44
45    let column_means = if mean_center {
46        compute_column_means(m)
47    } else {
48        None
49    };
50
51    let mut omega = generate_random_matrix(m_cols, l, seed);
52
53    let mut y = DMatrix::<T>::zeros(m_rows, l);
54    multiply_matrix_centered(m, &omega, &mut y, false, &column_means);
55
56    if n_power_iters > 0 {
57        let mut z = DMatrix::<T>::zeros(m_cols, l);
58
59        for _ in 0..n_power_iters {
60            multiply_matrix_centered(m, &y, &mut z, true, &column_means);
61            match power_iteration_normalizer {
62                PowerIterationNormalizer::QR => {
63                    let qr = z.qr();
64                    z = qr.q();
65                }
66                PowerIterationNormalizer::LU => {
67                    normalize_columns(&mut z);
68                }
69                PowerIterationNormalizer::None => {}
70            }
71
72            multiply_matrix_centered(m, &z, &mut y, false, &column_means);
73            match power_iteration_normalizer {
74                PowerIterationNormalizer::QR => {
75                    let qr = y.qr();
76                    y = qr.q();
77                }
78                PowerIterationNormalizer::LU => normalize_columns(&mut y),
79                PowerIterationNormalizer::None => {}
80            }
81        }
82    }
83
84    let qr = y.qr();
85    let y = qr.q();
86
87    let mut b = DMatrix::<T>::zeros(y.ncols(), m_cols);
88    multiply_transposed_by_matrix_centered(&y, m, &mut b, &column_means);
89    let svd = b.svd(true, true);
90    let u_b = svd
91        .u
92        .ok_or_else(|| SvdLibError::Las2Error("SVD U computation failed".to_string()))?;
93    let singular_values = svd.singular_values;
94    let vt = svd
95        .v_t
96        .ok_or_else(|| SvdLibError::Las2Error("SVD V_t computation failed".to_string()))?;
97
98    let u = y.mul(&u_b);
99    let actual_rank = target_rank.min(singular_values.len());
100
101    let u_subset = u.columns(0, actual_rank);
102    let s = convert_singular_values(
103        <DVector<T>>::from(singular_values.rows(0, actual_rank)),
104        actual_rank,
105    );
106    let vt_subset = vt.rows(0, actual_rank).into_owned();
107    let u = u_subset.into_owned().into_ndarray2();
108    let vt = vt_subset.into_ndarray2();
109    Ok(SvdRec {
110        d: actual_rank,
111        u,
112        s,
113        vt,
114        diagnostics: create_diagnostics(
115            m,
116            actual_rank,
117            target_rank,
118            n_power_iters,
119            seed.unwrap_or(0) as u32,
120        ),
121    })
122}
123
124fn convert_singular_values<T: SvdFloat + ComplexField>(
125    values: DVector<T::RealField>,
126    size: usize,
127) -> Array1<T> {
128    let mut array = Array1::zeros(size);
129
130    for i in 0..size {
131        array[i] = T::from_real(values[i].clone());
132    }
133
134    array
135}
136
137fn compute_column_means<T, M>(m: &M) -> Option<DVector<T>>
138where
139    T: SvdFloat + RealField,
140    M: SMat<T>,
141{
142    let m_rows = m.nrows();
143    let m_cols = m.ncols();
144
145    let mut means = DVector::zeros(m_cols);
146
147    for j in 0..m_cols {
148        let mut col_vec = vec![T::zero(); m_cols];
149        let mut result_vec = vec![T::zero(); m_rows];
150
151        col_vec[j] = T::one();
152
153        m.svd_opa(&col_vec, &mut result_vec, false);
154
155        let mut sum = T::zero();
156        for &val in &result_vec {
157            sum += val;
158        }
159
160        means[j] = sum / T::from_f64(m_rows as f64).unwrap();
161    }
162
163    Some(means)
164}
165
166fn create_diagnostics<T, M: SMat<T>>(
167    a: &M,
168    d: usize,
169    target_rank: usize,
170    power_iterations: usize,
171    seed: u32,
172) -> Diagnostics<T>
173where
174    T: SvdFloat,
175{
176    Diagnostics {
177        non_zero: a.nnz(),
178        dimensions: target_rank,
179        iterations: power_iterations,
180        transposed: false,
181        lanczos_steps: 0, // we dont do that
182        ritz_values_stabilized: d,
183        significant_values: d,
184        singular_values: d,
185        end_interval: [T::from(-1e-30).unwrap(), T::from(1e-30).unwrap()],
186        kappa: T::from(1e-6).unwrap(),
187        random_seed: seed,
188    }
189}
190
191fn normalize_columns<T: SvdFloat + RealField + Send + Sync>(matrix: &mut DMatrix<T>) {
192    let rows = matrix.nrows();
193    let cols = matrix.ncols();
194
195    if rows < PARALLEL_THRESHOLD_ROWS && cols < PARALLEL_THRESHOLD_COLS {
196        for j in 0..cols {
197            let mut norm = T::zero();
198
199            // Calculate column norm
200            for i in 0..rows {
201                norm += ComplexField::powi(matrix[(i, j)], 2);
202            }
203            norm = ComplexField::sqrt(norm);
204
205            if norm > T::from_f64(1e-10).unwrap() {
206                let scale = T::one() / norm;
207                for i in 0..rows {
208                    matrix[(i, j)] *= scale;
209                }
210            }
211        }
212        return;
213    }
214
215    let norms: Vec<T> = (0..cols)
216        .into_par_iter()
217        .map(|j| {
218            let mut norm = T::zero();
219            for i in 0..rows {
220                let val = unsafe { *matrix.get_unchecked((i, j)) };
221                norm += ComplexField::powi(val, 2);
222            }
223            ComplexField::sqrt(norm)
224        })
225        .collect();
226
227    let scales: Vec<(usize, T)> = norms
228        .into_iter()
229        .enumerate()
230        .filter_map(|(j, norm)| {
231            if norm > T::from_f64(1e-10).unwrap() {
232                Some((j, T::one() / norm))
233            } else {
234                None // Skip columns with too small norms
235            }
236        })
237        .collect();
238
239    scales.iter().for_each(|(j, scale)| {
240        for i in 0..rows {
241            let value = matrix.get_mut((i, *j)).unwrap();
242            *value = value.clone() * scale.clone();
243        }
244    });
245}
246
247// ----------------------------------------
248// Utils Functions
249// ----------------------------------------
250
251fn generate_random_matrix<T: SvdFloat + RealField>(
252    rows: usize,
253    cols: usize,
254    seed: Option<u64>,
255) -> DMatrix<T> {
256    let mut rng = match seed {
257        Some(s) => StdRng::seed_from_u64(s),
258        None => StdRng::seed_from_u64(0),
259    };
260
261    let normal = Normal::new(0.0, 1.0).unwrap();
262    DMatrix::from_fn(rows, cols, |_, _| {
263        T::from_f64(normal.sample(&mut rng)).unwrap()
264    })
265}
266
267fn multiply_matrix<T: SvdFloat, M: SMat<T>>(
268    sparse: &M,
269    dense: &DMatrix<T>,
270    result: &mut DMatrix<T>,
271    transpose_sparse: bool,
272) {
273    let cols = dense.ncols();
274
275    let results: Vec<(usize, Vec<T>)> = (0..cols)
276        .into_par_iter()
277        .map(|j| {
278            let mut col_vec = vec![T::zero(); dense.nrows()];
279            let mut result_vec = vec![T::zero(); result.nrows()];
280
281            for i in 0..dense.nrows() {
282                col_vec[i] = dense[(i, j)];
283            }
284
285            sparse.svd_opa(&col_vec, &mut result_vec, transpose_sparse);
286
287            (j, result_vec)
288        })
289        .collect();
290
291    for (j, col_result) in results {
292        for i in 0..result.nrows() {
293            result[(i, j)] = col_result[i];
294        }
295    }
296}
297
298fn multiply_transposed_by_matrix<T: SvdFloat, M: SMat<T>>(
299    q: &DMatrix<T>,
300    sparse: &M,
301    result: &mut DMatrix<T>,
302) {
303    let q_rows = q.nrows();
304    let q_cols = q.ncols();
305    let sparse_rows = sparse.nrows();
306    let sparse_cols = sparse.ncols();
307
308    eprintln!("Q dimensions: {} x {}", q_rows, q_cols);
309    eprintln!("Sparse dimensions: {} x {}", sparse_rows, sparse_cols);
310    eprintln!("Result dimensions: {} x {}", result.nrows(), result.ncols());
311
312    assert_eq!(
313        q_rows, sparse_rows,
314        "Dimension mismatch: Q has {} rows but sparse has {} rows",
315        q_rows, sparse_rows
316    );
317
318    assert_eq!(
319        result.nrows(),
320        q_cols,
321        "Result matrix has incorrect row count: expected {}, got {}",
322        q_cols,
323        result.nrows()
324    );
325    assert_eq!(
326        result.ncols(),
327        sparse_cols,
328        "Result matrix has incorrect column count: expected {}, got {}",
329        sparse_cols,
330        result.ncols()
331    );
332
333    let chunk_size = determine_chunk_size(q_cols);
334
335    let chunk_results: Vec<Vec<(usize, Vec<T>)>> = (0..q_cols)
336        .into_par_iter()
337        .chunks(chunk_size)
338        .map(|chunk| {
339            let mut chunk_results = Vec::with_capacity(chunk.len());
340
341            for &col_idx in &chunk {
342                let mut q_col = vec![T::zero(); q_rows];
343                for i in 0..q_rows {
344                    q_col[i] = q[(i, col_idx)];
345                }
346
347                let mut result_row = vec![T::zero(); sparse_cols];
348
349                sparse.svd_opa(&q_col, &mut result_row, true);
350
351                chunk_results.push((col_idx, result_row));
352            }
353            chunk_results
354        })
355        .collect();
356
357    for chunk_result in chunk_results {
358        for (row_idx, row_values) in chunk_result {
359            for j in 0..sparse_cols {
360                result[(row_idx, j)] = row_values[j];
361            }
362        }
363    }
364}
365
366pub fn svd_flip<T: SvdFloat + 'static>(
367    u: Option<&mut DMatrix<T>>,
368    v: Option<&mut DMatrix<T>>,
369    u_based_decision: bool,
370) -> Result<(), SvdLibError> {
371    if u.is_none() && v.is_none() {
372        return Err(SvdLibError::Las2Error(
373            "Both u and v cannot be None".to_string(),
374        ));
375    }
376
377    if u_based_decision {
378        if u.is_none() {
379            return Err(SvdLibError::Las2Error(
380                "u cannot be None when u_based_decision is true".to_string(),
381            ));
382        }
383
384        let u = u.unwrap();
385        let ncols = u.ncols();
386        let nrows = u.nrows();
387
388        let mut signs = DVector::from_element(ncols, T::one());
389
390        for j in 0..ncols {
391            let mut max_abs = T::zero();
392            let mut max_idx = 0;
393
394            for i in 0..nrows {
395                let abs_val = u[(i, j)].abs();
396                if abs_val > max_abs {
397                    max_abs = abs_val;
398                    max_idx = i;
399                }
400            }
401
402            if u[(max_idx, j)] < T::zero() {
403                signs[j] = -T::one();
404            }
405        }
406
407        for j in 0..ncols {
408            for i in 0..nrows {
409                u[(i, j)] *= signs[j];
410            }
411        }
412
413        if let Some(v) = v {
414            let v_nrows = v.nrows();
415            let v_ncols = v.ncols();
416
417            for i in 0..v_nrows.min(signs.len()) {
418                for j in 0..v_ncols {
419                    v[(i, j)] *= signs[i];
420                }
421            }
422        }
423    } else {
424        if v.is_none() {
425            return Err(SvdLibError::Las2Error(
426                "v cannot be None when u_based_decision is false".to_string(),
427            ));
428        }
429
430        let v = v.unwrap();
431        let nrows = v.nrows();
432        let ncols = v.ncols();
433
434        let mut signs = DVector::from_element(nrows, T::one());
435
436        for i in 0..nrows {
437            let mut max_abs = T::zero();
438            let mut max_idx = 0;
439
440            for j in 0..ncols {
441                let abs_val = v[(i, j)].abs();
442                if abs_val > max_abs {
443                    max_abs = abs_val;
444                    max_idx = j;
445                }
446            }
447
448            if v[(i, max_idx)] < T::zero() {
449                signs[i] = -T::one();
450            }
451        }
452
453        for i in 0..nrows {
454            for j in 0..ncols {
455                v[(i, j)] *= signs[i];
456            }
457        }
458
459        if let Some(u) = u {
460            let u_nrows = u.nrows();
461            let u_ncols = u.ncols();
462
463            for j in 0..u_ncols.min(signs.len()) {
464                for i in 0..u_nrows {
465                    u[(i, j)] *= signs[j];
466                }
467            }
468        }
469    }
470
471    Ok(())
472}
473
474fn multiply_matrix_centered<T: SvdFloat, M: SMat<T>>(
475    sparse: &M,
476    dense: &DMatrix<T>,
477    result: &mut DMatrix<T>,
478    transpose_sparse: bool,
479    column_means: &Option<DVector<T>>,
480) {
481    if column_means.is_none() {
482        multiply_matrix(sparse, dense, result, transpose_sparse);
483        return;
484    }
485
486    let means = column_means.as_ref().unwrap();
487    let cols = dense.ncols();
488
489    let results: Vec<(usize, Vec<T>)> = (0..cols)
490        .into_par_iter()
491        .map(|j| {
492            let mut col_vec = vec![T::zero(); dense.nrows()];
493            let mut result_vec = vec![T::zero(); result.nrows()];
494
495            for i in 0..dense.nrows() {
496                col_vec[i] = dense[(i, j)];
497            }
498
499            sparse.svd_opa(&col_vec, &mut result_vec, transpose_sparse);
500
501            if !transpose_sparse {
502                let mut dot_product = T::zero();
503                for &val in &col_vec {
504                    dot_product += val;
505                }
506
507                for i in 0..result_vec.len() {
508                    for (j, &mean) in means.iter().enumerate() {
509                        if !transpose_sparse {
510                            result_vec[i] -= mean * dot_product;
511                        }
512                    }
513                }
514            } else {
515                let mut sum_x = T::zero();
516                for &val in &col_vec {
517                    sum_x += val;
518                }
519
520                for (i, mean) in means.iter().enumerate() {
521                    result_vec[i] -= *mean * sum_x;
522                }
523            }
524
525            (j, result_vec)
526        })
527        .collect();
528
529    for (j, col_result) in results {
530        for i in 0..result.nrows() {
531            result[(i, j)] = col_result[i];
532        }
533    }
534}
535
536fn multiply_transposed_by_matrix_centered<T: SvdFloat, M: SMat<T>>(
537    q: &DMatrix<T>,
538    sparse: &M,
539    result: &mut DMatrix<T>,
540    column_means: &Option<DVector<T>>,
541) {
542    if column_means.is_none() {
543        multiply_transposed_by_matrix(q, sparse, result);
544        return;
545    }
546
547    let means = column_means.as_ref().unwrap();
548    let q_rows = q.nrows();
549    let q_cols = q.ncols();
550    let sparse_rows = sparse.nrows();
551    let sparse_cols = sparse.ncols();
552
553    assert_eq!(
554        q_rows, sparse_rows,
555        "Dimension mismatch: Q has {} rows but sparse has {} rows",
556        q_rows, sparse_rows
557    );
558
559    assert_eq!(
560        result.nrows(),
561        q_cols,
562        "Result matrix has incorrect row count: expected {}, got {}",
563        q_cols,
564        result.nrows()
565    );
566    assert_eq!(
567        result.ncols(),
568        sparse_cols,
569        "Result matrix has incorrect column count: expected {}, got {}",
570        sparse_cols,
571        result.ncols()
572    );
573
574    let chunk_size = determine_chunk_size(q_cols);
575
576    let chunk_results: Vec<Vec<(usize, Vec<T>)>> = (0..q_cols)
577        .into_par_iter()
578        .chunks(chunk_size)
579        .map(|chunk| {
580            let mut chunk_results = Vec::with_capacity(chunk.len());
581
582            for &col_idx in &chunk {
583                let mut q_col = vec![T::zero(); q_rows];
584                for i in 0..q_rows {
585                    q_col[i] = q[(i, col_idx)];
586                }
587
588                let mut result_row = vec![T::zero(); sparse_cols];
589
590                sparse.svd_opa(&q_col, &mut result_row, true);
591
592                let mut q_sum = T::zero();
593                for &val in &q_col {
594                    q_sum += val;
595                }
596
597                for j in 0..sparse_cols {
598                    result_row[j] -= means[j] * q_sum;
599                }
600
601                chunk_results.push((col_idx, result_row));
602            }
603            chunk_results
604        })
605        .collect();
606
607    for chunk_result in chunk_results {
608        for (row_idx, row_values) in chunk_result {
609            for j in 0..sparse_cols {
610                result[(row_idx, j)] = row_values[j];
611            }
612        }
613    }
614}
615
#[cfg(test)]
mod randomized_svd_tests {
    use super::*;
    use crate::randomized::{randomized_svd, PowerIterationNormalizer};
    use nalgebra_sparse::coo::CooMatrix;
    use nalgebra_sparse::CsrMatrix;
    use ndarray::Array2;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use rayon::ThreadPoolBuilder;
    use std::sync::Once;

    static INIT: Once = Once::new();

    /// Tries to install a 16-thread global rayon pool, once per test binary.
    ///
    /// `build_global` returns an error when a global pool already exists (for
    /// example, one installed by another test module). The tests run fine on that
    /// pool, so the error is reported rather than turned into a panic.
    fn setup_thread_pool() {
        INIT.call_once(|| {
            match ThreadPoolBuilder::new().num_threads(16).build_global() {
                Ok(()) => println!("Initialized thread pool with {} threads", 16),
                Err(err) => println!("Global thread pool already initialized: {}", err),
            }
        });
    }

    /// Builds a `rows x cols` COO matrix with `round(rows*cols*density)` (at least
    /// one) distinct random positions holding non-zero values in `[-10, 10)`.
    /// The RNG seed is fixed (42), so the matrix is deterministic.
    fn create_sparse_matrix(
        rows: usize,
        cols: usize,
        density: f64,
    ) -> nalgebra_sparse::coo::CooMatrix<f64> {
        use std::collections::HashSet;

        let mut coo = nalgebra_sparse::coo::CooMatrix::new(rows, cols);

        let mut rng = StdRng::seed_from_u64(42);

        let nnz = (rows as f64 * cols as f64 * density).round() as usize;

        let nnz = nnz.max(1);

        let mut positions = HashSet::new();

        while positions.len() < nnz {
            let i = rng.gen_range(0..rows);
            let j = rng.gen_range(0..cols);

            if positions.insert((i, j)) {
                let val = loop {
                    let v: f64 = rng.gen_range(-10.0..10.0);
                    if v.abs() > 1e-10 {
                        break v;
                    }
                };

                coo.push(i, j, val);
            }
        }

        let actual_density = coo.nnz() as f64 / (rows as f64 * cols as f64);
        println!("Created sparse matrix: {} x {}", rows, cols);
        println!("  - Requested density: {:.6}", density);
        println!("  - Actual density: {:.6}", actual_density);
        println!("  - Sparsity: {:.4}%", (1.0 - actual_density) * 100.0);
        println!("  - Non-zeros: {}", coo.nnz());

        coo
    }

    #[test]
    fn test_randomized_svd_accuracy() {
        setup_thread_pool();

        let mut coo = CooMatrix::<f64>::new(20, 15);

        for i in 0..20 {
            for j in 0..5 {
                let val = (i as f64) * 0.5 + (j as f64) * 2.0;
                coo.push(i, j, val);
            }
        }

        let csr = CsrMatrix::from(&coo);

        let mut std_svd = crate::lanczos::svd_dim(&csr, 10).unwrap();

        let rand_svd = randomized_svd(
            &csr,
            10,
            5,
            2,
            PowerIterationNormalizer::QR,
            false,
            Some(42),
        )
        .unwrap();

        assert_eq!(rand_svd.d, 10, "Expected rank of 10");

        let rel_tol = 0.3;
        let compare_count = std::cmp::min(2, std::cmp::min(std_svd.d, rand_svd.d));
        println!("Standard SVD has {} dimensions", std_svd.d);
        println!("Randomized SVD has {} dimensions", rand_svd.d);

        for i in 0..compare_count {
            let rel_diff = (std_svd.s[i] - rand_svd.s[i]).abs() / std_svd.s[i];
            println!(
                "Singular value {}: standard={}, randomized={}, rel_diff={}",
                i, std_svd.s[i], rand_svd.s[i], rel_diff
            );
            assert!(
                rel_diff < rel_tol,
                "Dominant singular value {} differs too much: rel diff = {}, standard = {}, randomized = {}",
                i, rel_diff, std_svd.s[i], rand_svd.s[i]
            );
        }

        std_svd.u = std_svd.u.t().into_owned();
        let std_recon = std_svd.recompose();
        let rand_recon = rand_svd.recompose();

        let mut diff_norm = 0.0;
        let mut orig_norm = 0.0;

        for i in 0..20 {
            for j in 0..15 {
                diff_norm += (std_recon[[i, j]] - rand_recon[[i, j]]).powi(2);
                orig_norm += std_recon[[i, j]].powi(2);
            }
        }

        diff_norm = diff_norm.sqrt();
        orig_norm = orig_norm.sqrt();

        let rel_error = diff_norm / orig_norm;
        assert!(
            rel_error < 0.2,
            "Reconstruction difference too large: {}",
            rel_error
        );
    }

    // Test with mean centering: the centered randomized SVD must match the exact
    // singular values of the explicitly centered dense matrix.
    #[test]
    fn test_randomized_svd_with_mean_centering() {
        setup_thread_pool();

        let mut coo = CooMatrix::<f64>::new(30, 10);
        let mut dense = DMatrix::<f64>::zeros(30, 10);
        let mut rng = StdRng::seed_from_u64(123);

        let column_means: Vec<f64> = (0..10).map(|i| i as f64 * 2.0).collect();

        let mut u = vec![vec![0.0; 3]; 30]; // 3 factors
        let mut v = vec![vec![0.0; 3]; 10];

        for i in 0..30 {
            for j in 0..3 {
                u[i][j] = rng.gen_range(-1.0..1.0);
            }
        }

        for i in 0..10 {
            for j in 0..3 {
                v[i][j] = rng.gen_range(-1.0..1.0);
            }
        }

        for i in 0..30 {
            for j in 0..10 {
                let mut val = 0.0;
                for k in 0..3 {
                    val += u[i][k] * v[j][k];
                }
                val = val + column_means[j] + rng.gen_range(-0.1..0.1);
                dense[(i, j)] = val;
                coo.push(i, j, val);
            }
        }

        let csr = CsrMatrix::from(&coo);

        let svd_no_center =
            randomized_svd(&csr, 3, 3, 2, PowerIterationNormalizer::QR, false, Some(42)).unwrap();

        let svd_with_center =
            randomized_svd(&csr, 3, 3, 2, PowerIterationNormalizer::QR, true, Some(42)).unwrap();

        println!("Singular values without centering: {:?}", svd_no_center.s);
        println!("Singular values with centering: {:?}", svd_with_center.s);

        // Exact reference: singular values of the explicitly centered dense matrix.
        let mut centered = dense;
        for j in 0..10 {
            let mean = centered.column(j).sum() / 30.0;
            for i in 0..30 {
                centered[(i, j)] -= mean;
            }
        }
        let mut expected: Vec<f64> = centered
            .svd(false, false)
            .singular_values
            .iter()
            .copied()
            .collect();
        expected.sort_by(|a, b| b.partial_cmp(a).unwrap());
        println!("Exact centered singular values: {:?}", &expected[..3]);

        // The large column means dominate the raw matrix, so removing them must
        // shrink the leading singular value.
        assert!(
            svd_with_center.s[0] < svd_no_center.s[0],
            "Centering should reduce the leading singular value: centered = {}, raw = {}",
            svd_with_center.s[0],
            svd_no_center.s[0]
        );

        for k in 0..3 {
            let rel_diff = (svd_with_center.s[k] - expected[k]).abs() / expected[k];
            assert!(
                rel_diff < 0.05,
                "Centered singular value {} differs from exact: randomized = {}, exact = {}, rel diff = {}",
                k,
                svd_with_center.s[k],
                expected[k],
                rel_diff
            );
        }
    }

    #[test]
    fn test_randomized_svd_large_sparse() {
        setup_thread_pool();

        let test_matrix = create_sparse_matrix(5000, 1000, 0.01);

        let csr = CsrMatrix::from(&test_matrix);

        let result = randomized_svd(
            &csr,
            20,
            10,
            2,
            PowerIterationNormalizer::QR,
            false,
            Some(42),
        );

        let svd = result.unwrap_or_else(|err| {
            panic!("Randomized SVD failed on large sparse matrix: {:?}", err)
        });

        assert_eq!(svd.d, 20, "Expected rank of 20");
        assert_eq!(svd.u.ncols(), 20, "Expected 20 left singular vectors");
        assert_eq!(svd.u.nrows(), 5000, "Expected 5000 rows in U");
        assert_eq!(svd.vt.nrows(), 20, "Expected 20 right singular vectors");
        assert_eq!(svd.vt.ncols(), 1000, "Expected 1000 columns in V transpose");

        for (i, &value) in svd.s.iter().enumerate() {
            assert!(
                value > 0.0,
                "Singular value {} should be positive, got {}",
                i,
                value
            );
        }
        for i in 1..svd.s.len() {
            assert!(
                svd.s[i - 1] >= svd.s[i],
                "Singular values should be in descending order"
            );
        }
    }

    // Test with different power iteration settings
    #[test]
    fn test_power_iteration_impact() {
        setup_thread_pool();

        let mut coo = CooMatrix::<f64>::new(100, 50);
        let mut rng = StdRng::seed_from_u64(987);

        let mut u = vec![vec![0.0; 10]; 100];
        let mut v = vec![vec![0.0; 10]; 50];

        for i in 0..100 {
            for j in 0..10 {
                u[i][j] = rng.gen_range(-1.0..1.0);
            }
        }

        for i in 0..50 {
            for j in 0..10 {
                v[i][j] = rng.gen_range(-1.0..1.0);
            }
        }

        for i in 0..100 {
            for j in 0..50 {
                let mut val = 0.0;
                for k in 0..10 {
                    val += u[i][k] * v[j][k];
                }
                val += rng.gen_range(-0.01..0.01);
                coo.push(i, j, val);
            }
        }

        let csr = CsrMatrix::from(&coo);

        let powers = [0, 1, 3, 5];
        let mut errors = Vec::new();

        let mut dense_mat = Array2::<f64>::zeros((100, 50));
        for (i, j, val) in csr.triplet_iter() {
            dense_mat[[i, j]] = *val;
        }
        let matrix_norm = dense_mat.iter().map(|x| x.powi(2)).sum::<f64>().sqrt();

        for &power in &powers {
            let svd = randomized_svd(
                &csr,
                10,
                5,
                power,
                PowerIterationNormalizer::QR,
                false,
                Some(42),
            )
            .unwrap();

            let recon = svd.recompose();
            let mut error = 0.0;

            for i in 0..100 {
                for j in 0..50 {
                    error += (dense_mat[[i, j]] - recon[[i, j]]).powi(2);
                }
            }

            error = error.sqrt() / matrix_norm;
            errors.push(error);

            println!("Power iterations: {}, Relative error: {}", power, error);
        }

        let improved = errors[1..].iter().any(|&e| e < errors[0] * 0.9);

        assert!(
            improved,
            "Power iterations did not improve accuracy as expected"
        );
    }
}
930}