single_svdlib/randomized/
mod.rs

1use crate::error::SvdLibError;
2use crate::utils::determine_chunk_size;
3use crate::{Diagnostics, SMat, SvdFloat, SvdRec};
4use nalgebra_sparse::na::{ComplexField, DMatrix, DVector, RealField};
5use ndarray::Array1;
6use nshare::IntoNdarray2;
7use rand::prelude::{Distribution, StdRng};
8use rand::SeedableRng;
9use rand_distr::Normal;
10use rayon::iter::ParallelIterator;
11use rayon::prelude::{IndexedParallelIterator, IntoParallelIterator};
12use std::ops::Mul;
13use std::time::Instant;
14
/// Strategy used to re-orthogonalise the subspace basis between power
/// iterations, trading numerical stability for speed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PowerIterationNormalizer {
    /// Full QR decomposition at each half-step: most stable, most expensive.
    QR,
    /// Cheaper alternative (labelled LU; implemented here as per-column
    /// scaling to unit norm via `normalize_columns`).
    LU,
    /// No normalisation: fastest, but accuracy can degrade through round-off.
    None,
}
21
// Heuristic size cutoffs below which serial code paths are preferred,
// since rayon's scheduling overhead outweighs the per-item work for
// small matrices.
const PARALLEL_THRESHOLD_ROWS: usize = 5000;
const PARALLEL_THRESHOLD_COLS: usize = 1000;
// NOTE(review): not referenced anywhere in this file — confirm it is used
// elsewhere before removing.
const PARALLEL_THRESHOLD_ELEMENTS: usize = 100_000;
25
26pub fn randomized_svd<T, M>(
27    m: &M,
28    target_rank: usize,
29    n_oversamples: usize,
30    n_power_iters: usize,
31    power_iteration_normalizer: PowerIterationNormalizer,
32    mean_center: bool,
33    seed: Option<u64>,
34    verbose: bool,
35) -> anyhow::Result<SvdRec<T>>
36where
37    T: SvdFloat + RealField,
38    M: SMat<T>,
39    T: ComplexField,
40{
41    let start = Instant::now();
42    let m_rows = m.nrows();
43    let m_cols = m.ncols();
44
45    let rank = target_rank.min(m_rows.min(m_cols));
46    let l = rank + n_oversamples;
47
48    let column_means: Option<DVector<T>> = if mean_center {
49        if verbose {
50            println!("SVD | Randomized | Computing column means....");
51        }
52        Some(DVector::from(m.compute_column_means()))
53    } else {
54        None
55    };
56    if verbose && mean_center {
57        println!(
58            "SVD | Randomized | Computed column means, took: {:?} of total running time",
59            start.elapsed()
60        );
61    }
62
63    let omega = generate_random_matrix(m_cols, l, seed);
64
65    let mut y = DMatrix::<T>::zeros(m_rows, l);
66    if verbose {
67        println!("SVD | Randomized | Multiplying m with omega matrix....");
68    }
69    multiply_matrix_centered(m, &omega, &mut y, false, &column_means);
70    if verbose {
71        println!(
72            "SVD | Randomized | Multiplication done, took: {:?} of total running time",
73            start.elapsed()
74        );
75    }
76    if verbose {
77        println!("SVD | Randomized | Starting power iterations....");
78    }
79    if n_power_iters > 0 {
80        let mut z = DMatrix::<T>::zeros(m_cols, l);
81
82        for i in 0..n_power_iters {
83            if verbose {
84                println!(
85                    "SVD | Randomized | Running power-iteration: {:?}, current time: {:?}",
86                    i,
87                    start.elapsed()
88                );
89            }
90            multiply_matrix_centered(m, &y, &mut z, true, &column_means);
91            if verbose {
92                println!(
93                    "SVD | Randomized | Forward Multiplication {:?}",
94                    start.elapsed()
95                );
96            }
97            match power_iteration_normalizer {
98                PowerIterationNormalizer::QR => {
99                    let qr = z.qr();
100                    z = qr.q();
101                }
102                PowerIterationNormalizer::LU => {
103                    normalize_columns(&mut z);
104                }
105                PowerIterationNormalizer::None => {}
106            }
107            if verbose {
108                println!(
109                    "SVD | Randomized | Power Iteration Normalization Forward-Step {:?}",
110                    start.elapsed()
111                );
112            }
113
114            multiply_matrix_centered(m, &z, &mut y, false, &column_means);
115            if verbose {
116                println!(
117                    "SVD | Randomized | Backward Multiplication {:?}",
118                    start.elapsed()
119                );
120            }
121            match power_iteration_normalizer {
122                PowerIterationNormalizer::QR => {
123                    let qr = y.qr();
124                    y = qr.q();
125                }
126                PowerIterationNormalizer::LU => normalize_columns(&mut y),
127                PowerIterationNormalizer::None => {}
128            }
129            if verbose {
130                println!(
131                    "SVD | Randomized | Power Iteration Normalization Backward-Step {:?}",
132                    start.elapsed()
133                );
134            }
135        }
136    }
137    if verbose {
138        println!(
139            "SVD | Randomized | Running QR-Normalization after Power-Iterations {:?}",
140            start.elapsed()
141        );
142    }
143    let qr = y.qr();
144    let y = qr.q();
145    if verbose {
146        println!(
147            "SVD | Randomized | Finished QR-Normalization after Power-Iterations {:?}",
148            start.elapsed()
149        );
150    }
151
152    let mut b = DMatrix::<T>::zeros(y.ncols(), m_cols);
153    multiply_transposed_by_matrix_centered(&y, m, &mut b, &column_means);
154    if verbose {
155        println!(
156            "SVD | Randomized | Transposed Matrix Multiplication {:?}",
157            start.elapsed()
158        );
159    }
160    let svd = b.svd(true, true);
161    if verbose {
162        println!(
163            "SVD | Randomized | Running Singular Value Decomposition, took {:?}",
164            start.elapsed()
165        );
166    }
167    let u_b = svd
168        .u
169        .ok_or_else(|| SvdLibError::Las2Error("SVD U computation failed".to_string()))?;
170    let singular_values = svd.singular_values;
171    let vt = svd
172        .v_t
173        .ok_or_else(|| SvdLibError::Las2Error("SVD V_t computation failed".to_string()))?;
174
175    let u = y.mul(&u_b);
176    let actual_rank = target_rank.min(singular_values.len());
177
178    let u_subset = u.columns(0, actual_rank);
179    let s = convert_singular_values(
180        <DVector<T>>::from(singular_values.rows(0, actual_rank)),
181        actual_rank,
182    );
183    let vt_subset = vt.rows(0, actual_rank).into_owned();
184    let u = u_subset.into_owned().into_ndarray2();
185    let vt = vt_subset.into_ndarray2();
186    Ok(SvdRec {
187        d: actual_rank,
188        u,
189        s,
190        vt,
191        diagnostics: create_diagnostics(
192            m,
193            actual_rank,
194            target_rank,
195            n_power_iters,
196            seed.unwrap_or(0) as u32,
197        ),
198    })
199}
200
201fn convert_singular_values<T: SvdFloat + ComplexField>(
202    values: DVector<T::RealField>,
203    size: usize,
204) -> Array1<T> {
205    let mut array = Array1::zeros(size);
206
207    for i in 0..size {
208        array[i] = T::from_real(values[i].clone());
209    }
210
211    array
212}
213
/// Computes the per-column mean of `m` by probing it with unit basis
/// vectors: `m * e_j` recovers column `j`, whose entries are then averaged
/// over the row count. One sparse mat-vec per column, parallelised with rayon.
///
/// NOTE(review): this free function appears unused in this module —
/// `randomized_svd` calls the `SMat::compute_column_means` method instead;
/// confirm against other callers before removing.
fn compute_column_means<T, M>(m: &M) -> Option<DVector<T>>
where
    T: SvdFloat + RealField + Send + Sync,
    M: SMat<T> + Sync,
{
    let m_rows = m.nrows();
    let m_cols = m.ncols();

    let means: Vec<T> = (0..m_cols)
        .into_par_iter()
        .map(|j| {
            // Basis vector e_j selects column j of the sparse matrix.
            let mut col_vec = vec![T::zero(); m_cols];
            let mut result_vec = vec![T::zero(); m_rows];

            col_vec[j] = T::one();

            m.svd_opa(&col_vec, &mut result_vec, false);

            // Column mean = sum of the column's entries / number of rows.
            let sum: T = result_vec.iter().copied().sum();
            sum / T::from_f64(m_rows as f64).unwrap()
        })
        .collect();

    Some(DVector::from_vec(means))
}
239
/// Assembles a `Diagnostics` record for a randomized SVD run.
///
/// Lanczos-specific fields are filled with placeholder values because this
/// code path never runs Lanczos iterations.
fn create_diagnostics<T, M: SMat<T>>(
    a: &M,
    d: usize,
    target_rank: usize,
    power_iterations: usize,
    seed: u32,
) -> Diagnostics<T>
where
    T: SvdFloat,
{
    Diagnostics {
        non_zero: a.nnz(),
        dimensions: target_rank,
        iterations: power_iterations,
        transposed: false,
        lanczos_steps: 0, // we dont do that
        ritz_values_stabilized: d,
        significant_values: d,
        singular_values: d,
        // Interval and kappa are informational only for the randomized path.
        end_interval: [T::from(-1e-30).unwrap(), T::from(1e-30).unwrap()],
        kappa: T::from(1e-6).unwrap(),
        random_seed: seed,
    }
}
264
265fn normalize_columns<T: SvdFloat + RealField + Send + Sync>(matrix: &mut DMatrix<T>) {
266    let rows = matrix.nrows();
267    let cols = matrix.ncols();
268
269    if rows < PARALLEL_THRESHOLD_ROWS && cols < PARALLEL_THRESHOLD_COLS {
270        for j in 0..cols {
271            let mut norm = T::zero();
272
273            // Calculate column norm
274            for i in 0..rows {
275                norm += ComplexField::powi(matrix[(i, j)], 2);
276            }
277            norm = ComplexField::sqrt(norm);
278
279            if norm > T::from_f64(1e-10).unwrap() {
280                let scale = T::one() / norm;
281                for i in 0..rows {
282                    matrix[(i, j)] *= scale;
283                }
284            }
285        }
286        return;
287    }
288
289    let norms: Vec<T> = (0..cols)
290        .into_par_iter()
291        .map(|j| {
292            let mut norm = T::zero();
293            for i in 0..rows {
294                let val = unsafe { *matrix.get_unchecked((i, j)) };
295                norm += ComplexField::powi(val, 2);
296            }
297            ComplexField::sqrt(norm)
298        })
299        .collect();
300
301    let scales: Vec<(usize, T)> = norms
302        .into_iter()
303        .enumerate()
304        .filter_map(|(j, norm)| {
305            if norm > T::from_f64(1e-10).unwrap() {
306                Some((j, T::one() / norm))
307            } else {
308                None // Skip columns with too small norms
309            }
310        })
311        .collect();
312
313    scales.iter().for_each(|(j, scale)| {
314        for i in 0..rows {
315            let value = matrix.get_mut((i, *j)).unwrap();
316            *value = value.clone() * scale.clone();
317        }
318    });
319}
320
321// ----------------------------------------
322// Utils Functions
323// ----------------------------------------
324
325fn generate_random_matrix<T: SvdFloat + RealField>(
326    rows: usize,
327    cols: usize,
328    seed: Option<u64>,
329) -> DMatrix<T> {
330    let mut rng = match seed {
331        Some(s) => StdRng::seed_from_u64(s),
332        None => StdRng::seed_from_u64(0),
333    };
334
335    let normal = Normal::new(0.0, 1.0).unwrap();
336    DMatrix::from_fn(rows, cols, |_, _| {
337        T::from_f64(normal.sample(&mut rng)).unwrap()
338    })
339}
340
341fn multiply_matrix<T: SvdFloat, M: SMat<T>>(
342    sparse: &M,
343    dense: &DMatrix<T>,
344    result: &mut DMatrix<T>,
345    transpose_sparse: bool,
346) {
347    let cols = dense.ncols();
348
349    let results: Vec<(usize, Vec<T>)> = (0..cols)
350        .into_par_iter()
351        .map(|j| {
352            let mut col_vec = vec![T::zero(); dense.nrows()];
353            let mut result_vec = vec![T::zero(); result.nrows()];
354
355            for i in 0..dense.nrows() {
356                col_vec[i] = dense[(i, j)];
357            }
358
359            sparse.svd_opa(&col_vec, &mut result_vec, transpose_sparse);
360
361            (j, result_vec)
362        })
363        .collect();
364
365    for (j, col_result) in results {
366        for i in 0..result.nrows() {
367            result[(i, j)] = col_result[i];
368        }
369    }
370}
371
372fn multiply_transposed_by_matrix<T: SvdFloat, M: SMat<T>>(
373    q: &DMatrix<T>,
374    sparse: &M,
375    result: &mut DMatrix<T>,
376) {
377    let q_rows = q.nrows();
378    let q_cols = q.ncols();
379    let sparse_rows = sparse.nrows();
380    let sparse_cols = sparse.ncols();
381
382    eprintln!("Q dimensions: {} x {}", q_rows, q_cols);
383    eprintln!("Sparse dimensions: {} x {}", sparse_rows, sparse_cols);
384    eprintln!("Result dimensions: {} x {}", result.nrows(), result.ncols());
385
386    assert_eq!(
387        q_rows, sparse_rows,
388        "Dimension mismatch: Q has {} rows but sparse has {} rows",
389        q_rows, sparse_rows
390    );
391
392    assert_eq!(
393        result.nrows(),
394        q_cols,
395        "Result matrix has incorrect row count: expected {}, got {}",
396        q_cols,
397        result.nrows()
398    );
399    assert_eq!(
400        result.ncols(),
401        sparse_cols,
402        "Result matrix has incorrect column count: expected {}, got {}",
403        sparse_cols,
404        result.ncols()
405    );
406
407    let chunk_size = determine_chunk_size(q_cols);
408
409    let chunk_results: Vec<Vec<(usize, Vec<T>)>> = (0..q_cols)
410        .into_par_iter()
411        .chunks(chunk_size)
412        .map(|chunk| {
413            let mut chunk_results = Vec::with_capacity(chunk.len());
414
415            for &col_idx in &chunk {
416                let mut q_col = vec![T::zero(); q_rows];
417                for i in 0..q_rows {
418                    q_col[i] = q[(i, col_idx)];
419                }
420
421                let mut result_row = vec![T::zero(); sparse_cols];
422
423                sparse.svd_opa(&q_col, &mut result_row, true);
424
425                chunk_results.push((col_idx, result_row));
426            }
427            chunk_results
428        })
429        .collect();
430
431    for chunk_result in chunk_results {
432        for (row_idx, row_values) in chunk_result {
433            for j in 0..sparse_cols {
434                result[(row_idx, j)] = row_values[j];
435            }
436        }
437    }
438}
439
/// Resolves the sign ambiguity of an SVD so results are deterministic:
/// for each singular vector, the entry with the largest absolute value is
/// forced positive, and the matching vector in the other factor is flipped
/// to keep the product unchanged (same convention as scikit-learn's
/// `svd_flip`).
///
/// * `u_based_decision` - when true, signs are derived from the columns of
///   `u` and also applied to the rows of `v`; otherwise signs are derived
///   from the rows of `v` and also applied to the columns of `u`.
///
/// NOTE(review): signs are indexed by *row* of `v`, i.e. `v` is expected
/// to hold V^T (each row a right singular vector) — confirm against callers.
///
/// # Errors
/// Returns an error when both factors are `None`, or when the factor the
/// decision is based on is `None`.
pub fn svd_flip<T: SvdFloat + 'static>(
    u: Option<&mut DMatrix<T>>,
    v: Option<&mut DMatrix<T>>,
    u_based_decision: bool,
) -> Result<(), SvdLibError> {
    if u.is_none() && v.is_none() {
        return Err(SvdLibError::Las2Error(
            "Both u and v cannot be None".to_string(),
        ));
    }

    if u_based_decision {
        if u.is_none() {
            return Err(SvdLibError::Las2Error(
                "u cannot be None when u_based_decision is true".to_string(),
            ));
        }

        let u = u.unwrap();
        let ncols = u.ncols();
        let nrows = u.nrows();

        let mut signs = DVector::from_element(ncols, T::one());

        // One sign per column of u: find the entry with the largest
        // absolute value (first occurrence wins on ties, strict `>`).
        for j in 0..ncols {
            let mut max_abs = T::zero();
            let mut max_idx = 0;

            for i in 0..nrows {
                let abs_val = num_traits::Float::abs(u[(i, j)]);
                if abs_val > max_abs {
                    max_abs = abs_val;
                    max_idx = i;
                }
            }

            // Flip the column if its dominant entry is negative.
            if u[(max_idx, j)] < T::zero() {
                signs[j] = -T::one();
            }
        }

        // Apply the chosen signs to the columns of u.
        for j in 0..ncols {
            for i in 0..nrows {
                u[(i, j)] *= signs[j];
            }
        }

        // Mirror the flips on the corresponding rows of v (= V^T) so that
        // u * s * v stays unchanged.
        if let Some(v) = v {
            let v_nrows = v.nrows();
            let v_ncols = v.ncols();

            for i in 0..v_nrows.min(signs.len()) {
                for j in 0..v_ncols {
                    v[(i, j)] *= signs[i];
                }
            }
        }
    } else {
        if v.is_none() {
            return Err(SvdLibError::Las2Error(
                "v cannot be None when u_based_decision is false".to_string(),
            ));
        }

        let v = v.unwrap();
        let nrows = v.nrows();
        let ncols = v.ncols();

        let mut signs = DVector::from_element(nrows, T::one());

        // One sign per row of v: locate the dominant entry of each row
        // (first occurrence wins on ties, strict `>`).
        for i in 0..nrows {
            let mut max_abs = T::zero();
            let mut max_idx = 0;

            for j in 0..ncols {
                let abs_val = num_traits::Float::abs(v[(i, j)]);
                if abs_val > max_abs {
                    max_abs = abs_val;
                    max_idx = j;
                }
            }

            // Flip the row if its dominant entry is negative.
            if v[(i, max_idx)] < T::zero() {
                signs[i] = -T::one();
            }
        }

        // Apply the chosen signs to the rows of v.
        for i in 0..nrows {
            for j in 0..ncols {
                v[(i, j)] *= signs[i];
            }
        }

        // Mirror the flips on the corresponding columns of u.
        if let Some(u) = u {
            let u_nrows = u.nrows();
            let u_ncols = u.ncols();

            for j in 0..u_ncols.min(signs.len()) {
                for i in 0..u_nrows {
                    u[(i, j)] *= signs[j];
                }
            }
        }
    }

    Ok(())
}
547
548fn multiply_matrix_centered<T: SvdFloat, M: SMat<T>>(
549    sparse: &M,
550    dense: &DMatrix<T>,
551    result: &mut DMatrix<T>,
552    transpose_sparse: bool,
553    column_means: &Option<DVector<T>>,
554) {
555    if column_means.is_none() {
556        multiply_matrix(sparse, dense, result, transpose_sparse);
557        return;
558    }
559
560    let means = column_means.as_ref().unwrap();
561    let cols = dense.ncols();
562    let dense_rows = dense.nrows();
563    let result_rows = result.nrows();
564
565    let results: Vec<(usize, Vec<T>)> = (0..cols)
566        .into_par_iter()
567        .map(|j| {
568            let mut col_vec = vec![T::zero(); dense_rows];
569            let mut result_vec = vec![T::zero(); result_rows];
570
571            for i in 0..dense_rows {
572                col_vec[i] = dense[(i, j)];
573            }
574
575            let col_sum = col_vec.iter().fold(T::zero(), |acc, &val| acc + val);
576
577            sparse.svd_opa(&col_vec, &mut result_vec, transpose_sparse);
578
579            if !transpose_sparse {
580                for i in 0..result_rows {
581                    let mean_adjustment = means.iter().map(|&mean| mean * col_sum).sum();
582                    result_vec[i] -= mean_adjustment;
583                }
584            } else {
585                for (i, &mean) in means.iter().enumerate() {
586                    result_vec[i] -= mean * col_sum;
587                }
588            }
589
590            (j, result_vec)
591        })
592        .collect();
593
594    for (j, col_result) in results {
595        for i in 0..result_rows {
596            result[(i, j)] = col_result[i];
597        }
598    }
599}
600
/// Computes `result = q^T * (sparse - 1 * means^T)` without materialising
/// the centered matrix. Falls back to the un-centered variant when no
/// column means are supplied.
///
/// Identity used per output row (y = a column of q, mu = column means):
/// (y^T (A - 1 mu^T))_j = (y^T A)_j - mu_j * sum(y)
///
/// # Panics
/// Panics if the dimensions of `q`, `sparse` and `result` are inconsistent.
fn multiply_transposed_by_matrix_centered<T: SvdFloat, M: SMat<T>>(
    q: &DMatrix<T>,
    sparse: &M,
    result: &mut DMatrix<T>,
    column_means: &Option<DVector<T>>,
) {
    if column_means.is_none() {
        multiply_transposed_by_matrix(q, sparse, result);
        return;
    }

    let means = column_means.as_ref().unwrap();
    let q_rows = q.nrows();
    let q_cols = q.ncols();
    let sparse_rows = sparse.nrows();
    let sparse_cols = sparse.ncols();

    assert_eq!(
        q_rows, sparse_rows,
        "Dimension mismatch: Q has {} rows but sparse has {} rows",
        q_rows, sparse_rows
    );

    assert_eq!(
        result.nrows(),
        q_cols,
        "Result matrix has incorrect row count: expected {}, got {}",
        q_cols,
        result.nrows()
    );
    assert_eq!(
        result.ncols(),
        sparse_cols,
        "Result matrix has incorrect column count: expected {}, got {}",
        sparse_cols,
        result.ncols()
    );

    // Chunk the columns so each rayon task amortises scheduling overhead.
    let chunk_size = determine_chunk_size(q_cols);

    let chunk_results: Vec<Vec<(usize, Vec<T>)>> = (0..q_cols)
        .into_par_iter()
        .chunks(chunk_size)
        .map(|chunk| {
            let mut chunk_results = Vec::with_capacity(chunk.len());

            for &col_idx in &chunk {
                // Contiguous copy of q[:, col_idx] for the sparse operator.
                let mut q_col = vec![T::zero(); q_rows];
                for i in 0..q_rows {
                    q_col[i] = q[(i, col_idx)];
                }

                let mut result_row = vec![T::zero(); sparse_cols];

                // Adjoint product: result_row = sparse^T * q_col.
                sparse.svd_opa(&q_col, &mut result_row, true);

                // sum(y) for the centering correction below.
                let mut q_sum = T::zero();
                for &val in &q_col {
                    q_sum += val;
                }

                // Subtract mu_j * sum(y) from every output entry.
                for j in 0..sparse_cols {
                    result_row[j] -= means[j] * q_sum;
                }

                chunk_results.push((col_idx, result_row));
            }
            chunk_results
        })
        .collect();

    // Serial write-back: a column index of q becomes a row of the result.
    for chunk_result in chunk_results {
        for (row_idx, row_values) in chunk_result {
            for j in 0..sparse_cols {
                result[(row_idx, j)] = row_values[j];
            }
        }
    }
}
680
681#[cfg(test)]
682mod randomized_svd_tests {
683    use super::*;
684    use crate::randomized::{randomized_svd, PowerIterationNormalizer};
685    use nalgebra_sparse::coo::CooMatrix;
686    use nalgebra_sparse::CsrMatrix;
687    use ndarray::Array2;
688    use rand::rngs::StdRng;
689    use rand::{Rng, SeedableRng};
690    use rayon::ThreadPoolBuilder;
691    use std::sync::Once;
692
693    static INIT: Once = Once::new();
694
    /// Installs a 16-thread global rayon pool exactly once for the whole
    /// test binary; subsequent calls are no-ops thanks to `Once`.
    fn setup_thread_pool() {
        INIT.call_once(|| {
            ThreadPoolBuilder::new()
                .num_threads(16)
                .build_global()
                .expect("Failed to build global thread pool");

            println!("Initialized thread pool with {} threads", 16);
        });
    }
705
    /// Builds a reproducible random sparse COO matrix of the given shape
    /// with approximately `density` fraction of non-zero entries, each
    /// drawn uniformly from (-10, 10) excluding near-zero values.
    fn create_sparse_matrix(
        rows: usize,
        cols: usize,
        density: f64,
    ) -> nalgebra_sparse::coo::CooMatrix<f64> {
        use std::collections::HashSet;

        let mut coo = nalgebra_sparse::coo::CooMatrix::new(rows, cols);

        // Fixed seed keeps the fixture deterministic across runs.
        let mut rng = StdRng::seed_from_u64(42);

        let nnz = (rows as f64 * cols as f64 * density).round() as usize;

        // At least one non-zero, even for tiny density values.
        let nnz = nnz.max(1);

        // Track positions to avoid duplicate (i, j) entries in the COO.
        let mut positions = HashSet::new();

        while positions.len() < nnz {
            let i = rng.gen_range(0..rows);
            let j = rng.gen_range(0..cols);

            if positions.insert((i, j)) {
                // Re-draw until the value is safely away from zero.
                let val = loop {
                    let v: f64 = rng.gen_range(-10.0..10.0);
                    if v.abs() > 1e-10 {
                        break v;
                    }
                };

                coo.push(i, j, val);
            }
        }

        let actual_density = coo.nnz() as f64 / (rows as f64 * cols as f64);
        println!("Created sparse matrix: {} x {}", rows, cols);
        println!("  - Requested density: {:.6}", density);
        println!("  - Actual density: {:.6}", actual_density);
        println!("  - Sparsity: {:.4}%", (1.0 - actual_density) * 100.0);
        println!("  - Non-zeros: {}", coo.nnz());

        coo
    }
748
    /// Compares the randomized SVD against the Lanczos SVD on a small
    /// rank-deficient 20x15 matrix: dominant singular values must agree
    /// within 30%, and the rank-10 reconstructions within 20% Frobenius
    /// relative error.
    #[test]
    fn test_randomized_svd_accuracy() {
        setup_thread_pool();

        let mut coo = CooMatrix::<f64>::new(20, 15);

        // Only the first 5 columns are populated, with a simple linear
        // pattern, so the matrix has low effective rank.
        for i in 0..20 {
            for j in 0..5 {
                let val = (i as f64) * 0.5 + (j as f64) * 2.0;
                coo.push(i, j, val);
            }
        }

        let csr = CsrMatrix::from(&coo);

        let mut std_svd = crate::lanczos::svd_dim(&csr, 10).unwrap();

        let rand_svd = randomized_svd(
            &csr,
            10,
            5,
            2,
            PowerIterationNormalizer::QR,
            false,
            Some(42),
            true,
        )
        .unwrap();

        assert_eq!(rand_svd.d, 10, "Expected rank of 10");

        // Only the dominant singular values are expected to match well.
        let rel_tol = 0.3;
        let compare_count = std::cmp::min(2, std::cmp::min(std_svd.d, rand_svd.d));
        println!("Standard SVD has {} dimensions", std_svd.d);
        println!("Randomized SVD has {} dimensions", rand_svd.d);

        for i in 0..compare_count {
            let rel_diff = (std_svd.s[i] - rand_svd.s[i]).abs() / std_svd.s[i];
            println!(
                "Singular value {}: standard={}, randomized={}, rel_diff={}",
                i, std_svd.s[i], rand_svd.s[i], rel_diff
            );
            assert!(
                rel_diff < rel_tol,
                "Dominant singular value {} differs too much: rel diff = {}, standard = {}, randomized = {}",
                i, rel_diff, std_svd.s[i], rand_svd.s[i]
            );
        }

        // NOTE(review): aligns the Lanczos U layout with the randomized
        // output before recomposing — confirm against SvdRec conventions.
        std_svd.u = std_svd.u.t().into_owned();
        let std_recon = std_svd.recompose();
        let rand_recon = rand_svd.recompose();

        // Frobenius norm of the difference vs. the reference reconstruction.
        let mut diff_norm = 0.0;
        let mut orig_norm = 0.0;

        for i in 0..20 {
            for j in 0..15 {
                diff_norm += (std_recon[[i, j]] - rand_recon[[i, j]]).powi(2);
                orig_norm += std_recon[[i, j]].powi(2);
            }
        }

        diff_norm = diff_norm.sqrt();
        orig_norm = orig_norm.sqrt();

        let rel_error = diff_norm / orig_norm;
        assert!(
            rel_error < 0.2,
            "Reconstruction difference too large: {}",
            rel_error
        );
    }
822
823    // Test with mean centering
824    #[test]
825    fn test_randomized_svd_with_mean_centering() {
826        setup_thread_pool();
827
828        let mut coo = CooMatrix::<f64>::new(30, 10);
829        let mut rng = StdRng::seed_from_u64(123);
830
831        let column_means: Vec<f64> = (0..10).map(|i| i as f64 * 2.0).collect();
832
833        let mut u = vec![vec![0.0; 3]; 30]; // 3 factors
834        let mut v = vec![vec![0.0; 3]; 10];
835
836        for i in 0..30 {
837            for j in 0..3 {
838                u[i][j] = rng.gen_range(-1.0..1.0);
839            }
840        }
841
842        for i in 0..10 {
843            for j in 0..3 {
844                v[i][j] = rng.gen_range(-1.0..1.0);
845            }
846        }
847
848        for i in 0..30 {
849            for j in 0..10 {
850                let mut val = 0.0;
851                for k in 0..3 {
852                    val += u[i][k] * v[j][k];
853                }
854                val = val + column_means[j] + rng.gen_range(-0.1..0.1);
855                coo.push(i, j, val);
856            }
857        }
858
859        let csr = CsrMatrix::from(&coo);
860
861        let svd_no_center = randomized_svd(
862            &csr,
863            3,
864            3,
865            2,
866            PowerIterationNormalizer::QR,
867            false,
868            Some(42),
869            false,
870        )
871        .unwrap();
872
873        let svd_with_center = randomized_svd(
874            &csr,
875            3,
876            3,
877            2,
878            PowerIterationNormalizer::QR,
879            true,
880            Some(42),
881            false,
882        )
883        .unwrap();
884
885        println!("Singular values without centering: {:?}", svd_no_center.s);
886        println!("Singular values with centering: {:?}", svd_with_center.s);
887    }
888
    /// Smoke-tests the randomized SVD on a 5000x1000 matrix at 1% density:
    /// the call must succeed, factor shapes must match the requested rank,
    /// and singular values must be positive and sorted in descending order.
    #[test]
    fn test_randomized_svd_large_sparse() {
        setup_thread_pool();

        let test_matrix = create_sparse_matrix(5000, 1000, 0.01);

        let csr = CsrMatrix::from(&test_matrix);

        let result = randomized_svd(
            &csr,
            20,
            10,
            2,
            PowerIterationNormalizer::QR,
            false,
            Some(42),
            false,
        );

        assert!(
            result.is_ok(),
            "Randomized SVD failed on large sparse matrix: {:?}",
            result.err().unwrap()
        );

        let svd = result.unwrap();
        assert_eq!(svd.d, 20, "Expected rank of 20");
        assert_eq!(svd.u.ncols(), 20, "Expected 20 left singular vectors");
        assert_eq!(svd.u.nrows(), 5000, "Expected 5000 columns in U transpose");
        assert_eq!(svd.vt.nrows(), 20, "Expected 20 right singular vectors");
        assert_eq!(svd.vt.ncols(), 1000, "Expected 1000 columns in V transpose");

        // Spectrum sanity: positive values, non-increasing order.
        for i in 1..svd.s.len() {
            assert!(svd.s[i] > 0.0, "Singular values should be positive");
            assert!(
                svd.s[i - 1] >= svd.s[i],
                "Singular values should be in descending order"
            );
        }
    }
929
    /// Verifies that adding power iterations improves reconstruction
    /// accuracy on a noisy rank-10 matrix: at least one non-zero setting
    /// must beat the 0-iteration baseline by 10% or more.
    #[test]
    fn test_power_iteration_impact() {
        setup_thread_pool();

        let mut coo = CooMatrix::<f64>::new(100, 50);
        let mut rng = StdRng::seed_from_u64(987);

        // Random rank-10 factors for the synthetic signal.
        let mut u = vec![vec![0.0; 10]; 100];
        let mut v = vec![vec![0.0; 10]; 50];

        for i in 0..100 {
            for j in 0..10 {
                u[i][j] = rng.gen_range(-1.0..1.0);
            }
        }

        for i in 0..50 {
            for j in 0..10 {
                v[i][j] = rng.gen_range(-1.0..1.0);
            }
        }

        // Entry = rank-10 signal + small noise.
        for i in 0..100 {
            for j in 0..50 {
                let mut val = 0.0;
                for k in 0..10 {
                    val += u[i][k] * v[j][k];
                }
                val += rng.gen_range(-0.01..0.01);
                coo.push(i, j, val);
            }
        }

        let csr = CsrMatrix::from(&coo);

        let powers = [0, 1, 3, 5];
        let mut errors = Vec::new();

        // Dense copy for computing exact reconstruction errors.
        let mut dense_mat = Array2::<f64>::zeros((100, 50));
        for (i, j, val) in csr.triplet_iter() {
            dense_mat[[i, j]] = *val;
        }
        let matrix_norm = dense_mat.iter().map(|x| x.powi(2)).sum::<f64>().sqrt();

        for &power in &powers {
            let svd = randomized_svd(
                &csr,
                10,
                5,
                power,
                PowerIterationNormalizer::QR,
                false,
                Some(42),
                false,
            )
            .unwrap();

            let recon = svd.recompose();
            let mut error = 0.0;

            for i in 0..100 {
                for j in 0..50 {
                    error += (dense_mat[[i, j]] - recon[[i, j]]).powi(2);
                }
            }

            // Relative Frobenius error against the original matrix.
            error = error.sqrt() / matrix_norm;
            errors.push(error);

            println!("Power iterations: {}, Relative error: {}", power, error);
        }

        // At least one power-iteration setting must beat the baseline
        // (0 iterations) by 10% or more.
        let mut improved = false;
        for i in 1..errors.len() {
            if errors[i] < errors[0] * 0.9 {
                improved = true;
                break;
            }
        }

        assert!(
            improved,
            "Power iterations did not improve accuracy as expected"
        );
    }
1016}