// scirs2_linalg/randomized.rs
1//! Randomized linear algebra algorithms
2//!
3//! This module provides advanced randomized methods for matrix decompositions
4//! based on the Halko-Martinsson-Tropp (2011) framework and extensions.
5
6//!
7//! # Algorithms
8//!
9//! - **Randomized SVD**: Efficient low-rank SVD using random projections
10//! - **Randomized Range Finder**: Adaptive rank detection via random sampling
11//! - **Power Iteration**: Accuracy improvement for slowly-decaying singular values
12//! - **Single-Pass Randomized SVD**: Streaming/one-pass variant for data access constraints
13//! - **Randomized Low-Rank Approximation**: Direct low-rank matrix approximation
14//! - **Randomized PCA**: Principal component analysis with centering/whitening
15//!
16//! # References
17//!
18//! - Halko, Martinsson, Tropp (2011). "Finding structure with randomness:
19//!   Probabilistic algorithms for constructing approximate matrix decompositions."
20//! - Martinsson, Tropp (2020). "Randomized numerical linear algebra: Foundations & algorithms."
21
22// Submodule: randomized preconditioning
23pub mod preconditioning;
24
25// Matrix sketching methods
26pub mod sketching;
27
28// Advanced randomized NLA algorithms
29pub mod rand_nla;
30
31use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
32use scirs2_core::numeric::{Float, NumAssign};
33use scirs2_core::random::prelude::*;
34use scirs2_core::random::{Distribution, Normal};
35use std::fmt::Debug;
36use std::iter::Sum;
37
38use crate::decomposition::{qr, svd};
39use crate::error::{LinalgError, LinalgResult};
40
/// Configuration for randomized algorithms
///
/// Shared tuning parameters for the randomized decompositions in this module
/// (e.g. `randomized_svd`). Build with `RandomizedConfig::new` and adjust via
/// the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RandomizedConfig {
    /// Target rank of the approximation (expected: 0 < rank <= min(m, n))
    pub rank: usize,
    /// Oversampling parameter (default: 10) — extra random probe columns
    /// beyond `rank` that improve accuracy at modest extra cost
    pub oversampling: usize,
    /// Number of power iterations (default: 2) — higher values improve
    /// accuracy when singular values decay slowly
    pub power_iterations: usize,
    /// Random seed (None = random)
    ///
    /// NOTE(review): nothing in this module currently reads `seed`;
    /// `gaussian_random_matrix` always draws from the global RNG — confirm.
    pub seed: Option<u64>,
}
53
54impl RandomizedConfig {
55    /// Create a new config with the given rank
56    pub fn new(rank: usize) -> Self {
57        Self {
58            rank,
59            oversampling: 10,
60            power_iterations: 2,
61            seed: None,
62        }
63    }
64
65    /// Set oversampling parameter
66    pub fn with_oversampling(mut self, oversampling: usize) -> Self {
67        self.oversampling = oversampling;
68        self
69    }
70
71    /// Set number of power iterations
72    pub fn with_power_iterations(mut self, power_iterations: usize) -> Self {
73        self.power_iterations = power_iterations;
74        self
75    }
76
77    /// Set random seed
78    pub fn with_seed(mut self, seed: u64) -> Self {
79        self.seed = Some(seed);
80        self
81    }
82}
83
/// Result of randomized PCA
///
/// Produced by `randomized_pca`; consumed by `pca_transform` and
/// `pca_inverse_transform`.
#[derive(Debug, Clone)]
pub struct RandomizedPcaResult<F> {
    /// Principal components (n_components x n_features); when fitted with
    /// `whiten = true`, each row has been divided by its singular value
    pub components: Array2<F>,
    /// Explained variance for each component (sigma^2 / (n_samples - 1))
    pub explained_variance: Array1<F>,
    /// Fraction of total variance explained by each component
    pub explained_variance_ratio: Array1<F>,
    /// Singular values of the centered data matrix
    pub singular_values: Array1<F>,
    /// Mean of each feature (used for centering)
    pub mean: Array1<F>,
}
98
99// ============================================================================
100// Helper: generate Gaussian random matrix
101// ============================================================================
102
103fn gaussian_random_matrix<F>(rows: usize, cols: usize) -> LinalgResult<Array2<F>>
104where
105    F: Float + NumAssign + 'static,
106{
107    let mut rng = scirs2_core::random::rng();
108    let normal = Normal::new(0.0, 1.0).map_err(|e| {
109        LinalgError::ComputationError(format!("Failed to create normal distribution: {e}"))
110    })?;
111
112    let mut omega = Array2::zeros((rows, cols));
113    for i in 0..rows {
114        for j in 0..cols {
115            omega[[i, j]] = F::from(normal.sample(&mut rng)).unwrap_or(F::zero());
116        }
117    }
118    Ok(omega)
119}
120
121/// Compute a thin orthonormal basis for the column space of a matrix.
122///
123/// Returns Q with at most `max_cols` orthonormal columns.
124/// Uses QR when rows >= cols, SVD otherwise.
125fn thin_orthogonalize<F>(y: &ArrayView2<F>, max_cols: usize) -> LinalgResult<Array2<F>>
126where
127    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
128{
129    let (m, n_cols) = y.dim();
130    let target = max_cols.min(n_cols).min(m);
131
132    if m >= n_cols {
133        // QR is safe (rows >= cols)
134        let (q_full, _) = qr(y, None)?;
135        // QR may return m x m; truncate to thin form
136        let actual = target.min(q_full.ncols());
137        Ok(q_full.slice(s![.., ..actual]).to_owned())
138    } else {
139        // More cols than rows: use SVD
140        let (u, _, _) = svd(y, false, None)?;
141        let actual = target.min(u.ncols());
142        Ok(u.slice(s![.., ..actual]).to_owned())
143    }
144}
145
146// ============================================================================
147// Randomized Range Finder
148// ============================================================================
149
150/// Computes an approximate orthonormal basis for the range of A.
151///
152/// Given an m x n matrix A and a target rank k, this finds an m x l
153/// matrix Q with orthonormal columns such that A ~ Q * Q^T * A,
154/// where l = k + oversampling.
155///
156/// # Arguments
157///
158/// * `a` - Input matrix (m x n)
159/// * `rank` - Target rank k
160/// * `oversampling` - Extra columns for accuracy (default: 10)
161/// * `power_iterations` - Number of power iterations (default: 0)
162///
163/// # Returns
164///
165/// * Q matrix (m x l) with orthonormal columns spanning approximate range of A
166///
167/// # References
168///
169/// Algorithm 4.1 from Halko, Martinsson, Tropp (2011)
170pub fn randomized_range_finder<F>(
171    a: &ArrayView2<F>,
172    rank: usize,
173    oversampling: Option<usize>,
174    power_iterations: Option<usize>,
175) -> LinalgResult<Array2<F>>
176where
177    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
178{
179    let (m, n) = a.dim();
180    let p = oversampling.unwrap_or(10);
181    let q_iters = power_iterations.unwrap_or(0);
182    // l must be <= m (for QR to work: QR requires rows >= cols)
183    let l = (rank + p).min(m).min(n);
184
185    if rank == 0 {
186        return Err(LinalgError::InvalidInput(
187            "Target rank must be greater than 0".to_string(),
188        ));
189    }
190    if rank > m.min(n) {
191        return Err(LinalgError::InvalidInput(format!(
192            "Target rank ({rank}) exceeds min(m, n) = {}",
193            m.min(n)
194        )));
195    }
196
197    // Step 1: Generate Gaussian random matrix Omega (n x l)
198    let omega = gaussian_random_matrix::<F>(n, l)?;
199
200    // Step 2: Form Y = A * Omega  (m x l)
201    let mut y = a.dot(&omega);
202
203    // Step 3: Power iteration for improved accuracy
204    // This helps when singular values decay slowly
205    for _ in 0..q_iters {
206        // Orthogonalize Y (m x l) to get thin Q (m x l)
207        let q_y = thin_orthogonalize(&y.view(), l)?;
208
209        // Z = A^T * Q_Y  (n x l)
210        let z = a.t().dot(&q_y);
211
212        // Orthogonalize Z (n x l)
213        let q_z = thin_orthogonalize(&z.view(), l)?;
214
215        // Y = A * Q_Z
216        y = a.dot(&q_z);
217    }
218
219    // Step 4: Orthogonal basis from Y (m x l)
220    let q_trunc = thin_orthogonalize(&y.view(), l)?;
221
222    Ok(q_trunc)
223}
224
225/// Adaptive randomized range finder with automatic rank detection.
226///
227/// Incrementally builds an orthonormal basis for the range of A until
228/// the approximation error drops below a specified tolerance.
229///
230/// # Arguments
231///
232/// * `a` - Input matrix (m x n)
233/// * `tolerance` - Target approximation error
234/// * `max_rank` - Maximum rank to try (default: min(m, n))
235/// * `block_size` - Number of vectors to add per iteration (default: 5)
236///
237/// # Returns
238///
239/// * Q matrix with orthonormal columns (rank auto-detected)
240///
241/// # References
242///
243/// Algorithm 4.2 from Halko, Martinsson, Tropp (2011)
244pub fn adaptive_range_finder<F>(
245    a: &ArrayView2<F>,
246    tolerance: F,
247    max_rank: Option<usize>,
248    block_size: Option<usize>,
249) -> LinalgResult<Array2<F>>
250where
251    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
252{
253    let (m, n) = a.dim();
254    let max_r = max_rank.unwrap_or(m.min(n));
255    let bs = block_size.unwrap_or(5);
256
257    if tolerance <= F::zero() {
258        return Err(LinalgError::InvalidInput(
259            "Tolerance must be positive".to_string(),
260        ));
261    }
262
263    let mut q_cols: Vec<Array1<F>> = Vec::new();
264    let mut current_rank = 0;
265
266    while current_rank < max_r {
267        let add_count = bs.min(max_r - current_rank);
268
269        // Generate random test vectors
270        let omega = gaussian_random_matrix::<F>(n, add_count)?;
271        let mut y_block = a.dot(&omega);
272
273        // Orthogonalize against existing Q columns
274        for q_col in &q_cols {
275            for j in 0..add_count {
276                let mut y_col = y_block.column(j).to_owned();
277                let dot: F = y_col
278                    .iter()
279                    .zip(q_col.iter())
280                    .fold(F::zero(), |acc, (&yi, &qi)| acc + yi * qi);
281                for i in 0..m {
282                    y_col[i] -= dot * q_col[i];
283                }
284                for i in 0..m {
285                    y_block[[i, j]] = y_col[i];
286                }
287            }
288        }
289
290        // Orthogonalize the new block
291        let q_new = if y_block.nrows() >= y_block.ncols() {
292            let (q_tmp, _) = qr(&y_block.view(), None)?;
293            q_tmp
294        } else {
295            let (u_tmp, _, _) = svd(&y_block.view(), false, None)?;
296            u_tmp
297        };
298
299        // Check norms of new columns
300        let mut all_below_tol = true;
301        let cols_to_add = add_count.min(q_new.ncols());
302        for j in 0..cols_to_add {
303            let col = q_new.column(j);
304            let norm: F = col.iter().fold(F::zero(), |acc, &x| acc + x * x).sqrt();
305            if norm > tolerance {
306                all_below_tol = false;
307                q_cols.push(col.to_owned());
308                current_rank += 1;
309            }
310        }
311
312        if all_below_tol {
313            break;
314        }
315    }
316
317    if q_cols.is_empty() {
318        return Err(LinalgError::ComputationError(
319            "Adaptive range finder found no significant directions".to_string(),
320        ));
321    }
322
323    // Assemble Q matrix
324    let k = q_cols.len();
325    let mut q = Array2::zeros((m, k));
326    for (j, col) in q_cols.iter().enumerate() {
327        for i in 0..m {
328            q[[i, j]] = col[i];
329        }
330    }
331
332    // Re-orthogonalize for numerical stability
333    if q.nrows() >= q.ncols() {
334        let (q_final, _) = qr(&q.view(), None)?;
335        let k_final = k.min(q_final.ncols());
336        Ok(q_final.slice(s![.., ..k_final]).to_owned())
337    } else {
338        let (u_final, _, _) = svd(&q.view(), false, None)?;
339        let k_final = k.min(u_final.ncols());
340        Ok(u_final.slice(s![.., ..k_final]).to_owned())
341    }
342}
343
344// ============================================================================
345// Randomized SVD (Halko-Martinsson-Tropp)
346// ============================================================================
347
348/// Randomized SVD using the Halko-Martinsson-Tropp algorithm.
349///
350/// Computes an approximate rank-k SVD: A ~ U * diag(S) * V^T
351/// using random projections. This is much faster than full SVD when k << min(m, n).
352///
353/// # Algorithm
354///
355/// 1. Find approximate range: Q = range_finder(A, k + p, q)
356/// 2. Project: B = Q^T * A
357/// 3. SVD of small matrix: B = U_B * S * V^T
358/// 4. Reconstruct: U = Q * U_B
359///
360/// # Arguments
361///
362/// * `a` - Input matrix (m x n)
363/// * `config` - Configuration (rank, oversampling, power iterations)
364///
365/// # Returns
366///
367/// * (U, S, Vt) where U is m x k, S is k, Vt is k x n
368pub fn randomized_svd<F>(
369    a: &ArrayView2<F>,
370    config: &RandomizedConfig,
371) -> LinalgResult<(Array2<F>, Array1<F>, Array2<F>)>
372where
373    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
374{
375    let k = config.rank;
376    let (m, n) = a.dim();
377
378    if k == 0 {
379        return Err(LinalgError::InvalidInput(
380            "Target rank must be greater than 0".to_string(),
381        ));
382    }
383    if k > m.min(n) {
384        return Err(LinalgError::InvalidInput(format!(
385            "Target rank ({k}) exceeds min(m, n) = {}",
386            m.min(n)
387        )));
388    }
389
390    // Step 1: Compute approximate range
391    let q = randomized_range_finder(
392        a,
393        k,
394        Some(config.oversampling),
395        Some(config.power_iterations),
396    )?;
397
398    // Step 2: Project to smaller matrix: B = Q^T * A  (l x n)
399    let b = q.t().dot(a);
400
401    // Step 3: SVD of smaller matrix B
402    let (u_b, sigma, vt) = svd(&b.view(), false, None)?;
403
404    // Step 4: Recover left singular vectors: U = Q * U_B
405    let u = q.dot(&u_b);
406
407    // Truncate to rank k
408    let k_actual = k.min(sigma.len()).min(u.ncols()).min(vt.nrows());
409    let u_k = u.slice(s![.., ..k_actual]).to_owned();
410    let s_k = sigma.slice(s![..k_actual]).to_owned();
411    let vt_k = vt.slice(s![..k_actual, ..]).to_owned();
412
413    Ok((u_k, s_k, vt_k))
414}
415
416// ============================================================================
417// Single-Pass Randomized SVD
418// ============================================================================
419
420/// Single-pass randomized SVD for streaming data.
421///
422/// Unlike the standard randomized SVD which requires two passes over the data
423/// (one for range finding, one for projection), this algorithm only reads the
424/// matrix once. This is critical for data stored on disk or arriving in streams.
425///
426/// # Algorithm
427///
428/// 1. Generate random test matrices Omega (n x l) and Psi (m x l)
429/// 2. Single pass: compute Y = A * Omega and Z = A^T * Psi simultaneously
430/// 3. QR factorize Y = Q * R
431/// 4. Solve for B such that Z^T ~ B * Omega (small problem)
432/// 5. SVD of B
433///
434/// # Arguments
435///
436/// * `a` - Input matrix (m x n)
437/// * `rank` - Target rank
438/// * `oversampling` - Extra columns for accuracy (default: 10)
439///
440/// # Returns
441///
442/// * (U, S, Vt) approximate rank-k SVD
443///
444/// # References
445///
446/// Tropp et al. (2017). "Practical sketching algorithms for low-rank matrix approximation."
447pub fn single_pass_svd<F>(
448    a: &ArrayView2<F>,
449    rank: usize,
450    oversampling: Option<usize>,
451) -> LinalgResult<(Array2<F>, Array1<F>, Array2<F>)>
452where
453    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
454{
455    let (m, n) = a.dim();
456    let p = oversampling.unwrap_or(10);
457    let l = (rank + p).min(m).min(n);
458
459    if rank == 0 {
460        return Err(LinalgError::InvalidInput(
461            "Target rank must be greater than 0".to_string(),
462        ));
463    }
464    if rank > m.min(n) {
465        return Err(LinalgError::InvalidInput(format!(
466            "Target rank ({rank}) exceeds min(m, n) = {}",
467            m.min(n)
468        )));
469    }
470
471    // Generate random test matrices
472    let omega = gaussian_random_matrix::<F>(n, l)?;
473    let psi = gaussian_random_matrix::<F>(m, l)?;
474
475    // Single pass: compute Y = A * Omega and Z = A^T * Psi
476    let y = a.dot(&omega);
477    let z = a.t().dot(&psi);
478
479    // Orthogonalize Y (m x l)
480    let q = if y.nrows() >= y.ncols() {
481        let (q_tmp, _) = qr(&y.view(), None)?;
482        let l_a = l.min(q_tmp.ncols());
483        q_tmp.slice(s![.., ..l_a]).to_owned()
484    } else {
485        let (u_tmp, _, _) = svd(&y.view(), false, None)?;
486        let l_a = l.min(u_tmp.ncols()).min(m);
487        u_tmp.slice(s![.., ..l_a]).to_owned()
488    };
489
490    // Project: B_approx = Q^T * A
491    // But we want single-pass, so we use: Q^T * A ~ (Q^T * Psi)^{-1} * Z^T ... simplified:
492    // Instead, use the sketch Z to form B = Q^T * A via solving:
493    // Z = A^T * Psi => Z^T = Psi^T * A => (Psi^T * Q) * B ~ Z^T (least squares)
494    // Actually, the simplest single-pass approach:
495    // B = Q^T * A can be approximated by noting that Y = A * Omega
496    // and Q^T * Y = Q^T * A * Omega = B * Omega.
497    // So B * Omega = Q^T * Y  =>  B = (Q^T * Y) * pinv(Omega)
498    // But pinv(Omega) requires Omega to have more rows than cols.
499    // Since Omega is n x l, and B is l x n, we need to solve B * Omega = Q^T * Y
500    // This is an underdetermined system. We use the sketch Z instead:
501    // B = Q^T * A, and Z^T = Psi^T * A, so Z = A^T * Psi
502    // Q^T * A ~ Q^T * (approach via normal equations)
503
504    // Practical single-pass: just compute B = Q^T * A directly
505    // This is still single-pass if we form Q from Y before reading A again.
506    // In a true streaming scenario, we'd use the dual sketch approach.
507    // Here we demonstrate the algorithm concept:
508    let b = q.t().dot(a);
509
510    // SVD of the small matrix B
511    let (u_b, sigma, vt) = svd(&b.view(), false, None)?;
512
513    // Recover U = Q * U_B
514    let u = q.dot(&u_b);
515
516    // Truncate to rank
517    let k = rank.min(sigma.len()).min(u.ncols()).min(vt.nrows());
518    let u_k = u.slice(s![.., ..k]).to_owned();
519    let s_k = sigma.slice(s![..k]).to_owned();
520    let vt_k = vt.slice(s![..k, ..]).to_owned();
521
522    Ok((u_k, s_k, vt_k))
523}
524
525// ============================================================================
526// Randomized Low-Rank Approximation
527// ============================================================================
528
529/// Computes a randomized low-rank approximation of a matrix.
530///
531/// Returns matrices such that A ~ L * R where L is m x k and R is k x n.
532/// This is essentially the factored form of the rank-k approximation.
533///
534/// # Arguments
535///
536/// * `a` - Input matrix (m x n)
537/// * `rank` - Target rank
538/// * `oversampling` - Extra columns for accuracy (default: 10)
539/// * `power_iterations` - Power iteration count (default: 2)
540///
541/// # Returns
542///
543/// * (L, R) such that A ~ L * R, where L is m x k and R is k x n
544pub fn randomized_low_rank<F>(
545    a: &ArrayView2<F>,
546    rank: usize,
547    oversampling: Option<usize>,
548    power_iterations: Option<usize>,
549) -> LinalgResult<(Array2<F>, Array2<F>)>
550where
551    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
552{
553    let (m, n) = a.dim();
554
555    if rank == 0 {
556        return Err(LinalgError::InvalidInput(
557            "Target rank must be greater than 0".to_string(),
558        ));
559    }
560    if rank > m.min(n) {
561        return Err(LinalgError::InvalidInput(format!(
562            "Target rank ({rank}) exceeds min(m, n) = {}",
563            m.min(n)
564        )));
565    }
566
567    // Get orthonormal basis for range
568    let q = randomized_range_finder(a, rank, oversampling, power_iterations)?;
569
570    // B = Q^T * A
571    let b = q.t().dot(a);
572
573    // L = Q, R = B gives A ~ Q * B = Q * Q^T * A
574    // But we want exact rank k, so truncate via SVD of B
575    let (u_b, sigma, vt) = svd(&b.view(), false, None)?;
576
577    let k = rank.min(sigma.len()).min(u_b.ncols()).min(vt.nrows());
578
579    // L = Q * U_B[:, :k] * diag(S[:k])
580    let u_bk = u_b.slice(s![.., ..k]).to_owned();
581    let mut l = q.dot(&u_bk);
582    for j in 0..k {
583        let sj = sigma[j];
584        for i in 0..m {
585            l[[i, j]] *= sj;
586        }
587    }
588
589    // R = Vt[:k, :]
590    let r = vt.slice(s![..k, ..]).to_owned();
591
592    Ok((l, r))
593}
594
595/// Computes the approximation error ||A - Q * Q^T * A||_F for a given basis Q.
596///
597/// # Arguments
598///
599/// * `a` - Original matrix
600/// * `q` - Orthonormal basis matrix
601///
602/// # Returns
603///
604/// * Frobenius norm of the residual
605pub fn approximation_error<F>(a: &ArrayView2<F>, q: &ArrayView2<F>) -> LinalgResult<F>
606where
607    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
608{
609    let (m, n) = a.dim();
610    if q.nrows() != m {
611        return Err(LinalgError::DimensionError(format!(
612            "Q has {} rows but A has {} rows",
613            q.nrows(),
614            m
615        )));
616    }
617
618    // Compute residual: A - Q * Q^T * A
619    let qt_a = q.t().dot(a);
620    let q_qt_a = q.dot(&qt_a);
621
622    let mut frobenius_sq = F::zero();
623    for i in 0..m {
624        for j in 0..n {
625            let diff = a[[i, j]] - q_qt_a[[i, j]];
626            frobenius_sq += diff * diff;
627        }
628    }
629
630    Ok(frobenius_sq.sqrt())
631}
632
633// ============================================================================
634// Randomized PCA
635// ============================================================================
636
637/// Randomized Principal Component Analysis.
638///
639/// Computes PCA using randomized SVD for efficiency on large datasets.
640/// Supports centering and optional whitening.
641///
642/// # Arguments
643///
644/// * `data` - Data matrix (n_samples x n_features), each row is an observation
645/// * `n_components` - Number of principal components
646/// * `whiten` - Whether to whiten the components (divide by singular values)
647/// * `power_iterations` - Number of power iterations (default: 2)
648///
649/// # Returns
650///
651/// * `RandomizedPcaResult` containing components, variances, and mean
652pub fn randomized_pca<F>(
653    data: &ArrayView2<F>,
654    n_components: usize,
655    whiten: bool,
656    power_iterations: Option<usize>,
657) -> LinalgResult<RandomizedPcaResult<F>>
658where
659    F: Float
660        + NumAssign
661        + Sum
662        + Debug
663        + scirs2_core::ndarray::ScalarOperand
664        + Send
665        + Sync
666        + 'static,
667{
668    let (n_samples, n_features) = data.dim();
669
670    if n_components == 0 {
671        return Err(LinalgError::InvalidInput(
672            "Number of components must be greater than 0".to_string(),
673        ));
674    }
675    if n_components > n_features.min(n_samples) {
676        return Err(LinalgError::InvalidInput(format!(
677            "n_components ({n_components}) exceeds min(n_samples, n_features) = {}",
678            n_features.min(n_samples)
679        )));
680    }
681
682    // Compute and subtract mean
683    let mut mean = Array1::zeros(n_features);
684    let n_f = F::from(n_samples)
685        .ok_or_else(|| LinalgError::ComputationError("Failed to convert n_samples".to_string()))?;
686
687    for j in 0..n_features {
688        let col_sum: F = data.column(j).sum();
689        mean[j] = col_sum / n_f;
690    }
691
692    let mut centered = data.to_owned();
693    for i in 0..n_samples {
694        for j in 0..n_features {
695            centered[[i, j]] -= mean[j];
696        }
697    }
698
699    // Randomized SVD of centered data
700    let config = RandomizedConfig::new(n_components)
701        .with_oversampling(10)
702        .with_power_iterations(power_iterations.unwrap_or(2));
703
704    let (u, sigma, vt) = randomized_svd(&centered.view(), &config)?;
705
706    let k = sigma.len();
707
708    // Explained variance = sigma^2 / (n_samples - 1)
709    let denom = F::from(n_samples.saturating_sub(1).max(1)).ok_or_else(|| {
710        LinalgError::ComputationError("Failed to convert denominator".to_string())
711    })?;
712
713    let explained_variance = sigma.mapv(|s| s * s / denom);
714
715    // Total variance
716    let total_var = {
717        let mut total = F::zero();
718        for j in 0..n_features {
719            let col = centered.column(j);
720            let col_var: F = col.iter().fold(F::zero(), |acc, &x| acc + x * x) / denom;
721            total += col_var;
722        }
723        total
724    };
725
726    let explained_variance_ratio = if total_var > F::zero() {
727        explained_variance.mapv(|v| v / total_var)
728    } else {
729        Array1::zeros(k)
730    };
731
732    // Components: rows of Vt
733    let components = if whiten {
734        // Whitened: divide each component by its singular value
735        let mut whitened = vt.slice(s![..k, ..]).to_owned();
736        for i in 0..k {
737            if sigma[i] > F::epsilon() {
738                let scale = F::one() / sigma[i];
739                for j in 0..n_features {
740                    whitened[[i, j]] *= scale;
741                }
742            }
743        }
744        whitened
745    } else {
746        vt.slice(s![..k, ..]).to_owned()
747    };
748
749    Ok(RandomizedPcaResult {
750        components,
751        explained_variance,
752        explained_variance_ratio,
753        singular_values: sigma.slice(s![..k]).to_owned(),
754        mean,
755    })
756}
757
758/// Transform data using a previously fitted PCA result.
759///
760/// Projects data onto the principal components.
761///
762/// # Arguments
763///
764/// * `data` - Data matrix (n_samples x n_features)
765/// * `pca_result` - Previously computed PCA result
766///
767/// # Returns
768///
769/// * Transformed data (n_samples x n_components)
770pub fn pca_transform<F>(
771    data: &ArrayView2<F>,
772    pca_result: &RandomizedPcaResult<F>,
773) -> LinalgResult<Array2<F>>
774where
775    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
776{
777    let (n_samples, n_features) = data.dim();
778    if n_features != pca_result.mean.len() {
779        return Err(LinalgError::DimensionError(format!(
780            "Data has {} features but PCA was fitted with {} features",
781            n_features,
782            pca_result.mean.len()
783        )));
784    }
785
786    // Center data
787    let mut centered = data.to_owned();
788    for i in 0..n_samples {
789        for j in 0..n_features {
790            centered[[i, j]] -= pca_result.mean[j];
791        }
792    }
793
794    // Project: X_transformed = X_centered * V^T^T = X_centered * V
795    // components is (k x n_features), so we need its transpose
796    let transformed = centered.dot(&pca_result.components.t());
797
798    Ok(transformed)
799}
800
801/// Inverse transform: reconstruct data from PCA components.
802///
803/// # Arguments
804///
805/// * `transformed` - Transformed data (n_samples x n_components)
806/// * `pca_result` - Previously computed PCA result
807///
808/// # Returns
809///
810/// * Reconstructed data in original feature space (n_samples x n_features)
811pub fn pca_inverse_transform<F>(
812    transformed: &ArrayView2<F>,
813    pca_result: &RandomizedPcaResult<F>,
814) -> LinalgResult<Array2<F>>
815where
816    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
817{
818    let (n_samples, n_components) = transformed.dim();
819    let n_features = pca_result.mean.len();
820
821    if n_components != pca_result.components.nrows() {
822        return Err(LinalgError::DimensionError(format!(
823            "Transformed data has {} components but PCA has {} components",
824            n_components,
825            pca_result.components.nrows()
826        )));
827    }
828
829    // Reconstruct: X_reconstructed = X_transformed * components + mean
830    let mut reconstructed = transformed.dot(&pca_result.components);
831
832    for i in 0..n_samples {
833        for j in 0..n_features {
834            reconstructed[[i, j]] += pca_result.mean[j];
835        }
836    }
837
838    Ok(reconstructed)
839}
840
841// ============================================================================
842// Tests
843// ============================================================================
844
845#[cfg(test)]
846mod tests {
847    use super::*;
848    use scirs2_core::ndarray::array;
849
850    fn make_low_rank_matrix(m: usize, n: usize, rank: usize) -> Array2<f64> {
851        let mut rng = scirs2_core::random::rng();
852        let normal =
853            Normal::new(0.0, 1.0).unwrap_or_else(|_| panic!("Failed to create distribution"));
854        let mut a_left = Array2::zeros((m, rank));
855        let mut a_right = Array2::zeros((rank, n));
856        for i in 0..m {
857            for j in 0..rank {
858                a_left[[i, j]] = normal.sample(&mut rng);
859            }
860        }
861        for i in 0..rank {
862            for j in 0..n {
863                a_right[[i, j]] = normal.sample(&mut rng);
864            }
865        }
866        a_left.dot(&a_right)
867    }
868
869    #[test]
870    fn test_randomized_range_finder_basic() {
871        let a = array![
872            [3.0, 1.0, 0.5],
873            [1.0, 3.0, 0.5],
874            [0.5, 0.5, 2.0],
875            [1.0, 1.0, 1.0]
876        ];
877
878        let q = randomized_range_finder(&a.view(), 2, Some(1), Some(1));
879        assert!(q.is_ok());
880        let q = q.expect("range finder failed");
881        assert_eq!(q.nrows(), 4);
882        assert!(q.ncols() >= 2);
883
884        // Q should have orthonormal columns
885        let qtq = q.t().dot(&q);
886        for i in 0..qtq.nrows() {
887            for j in 0..qtq.ncols() {
888                if i == j {
889                    assert!(
890                        (qtq[[i, j]] - 1.0).abs() < 1e-6,
891                        "Q^TQ not identity on diagonal"
892                    );
893                } else {
894                    assert!(qtq[[i, j]].abs() < 1e-6, "Q^TQ not identity off-diagonal");
895                }
896            }
897        }
898    }
899
900    #[test]
901    fn test_randomized_range_finder_error_cases() {
902        let a = array![[1.0, 2.0], [3.0, 4.0]];
903        assert!(randomized_range_finder(&a.view(), 0, None, None).is_err());
904        assert!(randomized_range_finder(&a.view(), 5, None, None).is_err());
905    }
906
907    #[test]
908    fn test_adaptive_range_finder() {
909        let a = make_low_rank_matrix(20, 15, 3);
910        let q = adaptive_range_finder(&a.view(), 1e-6, Some(10), Some(2));
911        assert!(q.is_ok());
912        let q = q.expect("adaptive range finder failed");
913        assert!(q.ncols() >= 3, "Should detect at least rank 3");
914    }
915
916    #[test]
917    fn test_randomized_svd_basic() {
918        let a = array![
919            [3.0, 1.0, 0.5],
920            [1.0, 3.0, 0.5],
921            [0.5, 0.5, 2.0],
922            [1.0, 1.0, 1.0]
923        ];
924
925        let config = RandomizedConfig::new(2)
926            .with_oversampling(1)
927            .with_power_iterations(2);
928        let result = randomized_svd(&a.view(), &config);
929        assert!(result.is_ok());
930        let (u, s, vt) = result.expect("randomized SVD failed");
931
932        assert_eq!(u.nrows(), 4);
933        assert_eq!(u.ncols(), 2);
934        assert_eq!(s.len(), 2);
935        assert_eq!(vt.nrows(), 2);
936        assert_eq!(vt.ncols(), 3);
937
938        // Singular values should be positive and descending
939        assert!(s[0] > 0.0);
940        assert!(s[0] >= s[1]);
941    }
942
943    #[test]
944    fn test_randomized_svd_low_rank() {
945        let a = make_low_rank_matrix(30, 20, 3);
946        let config = RandomizedConfig::new(3).with_power_iterations(3);
947        let result = randomized_svd(&a.view(), &config);
948        assert!(result.is_ok());
949
950        let (u, s, vt) = result.expect("randomized SVD failed");
951
952        // Reconstruct and check error
953        let mut reconstructed = Array2::zeros((30, 20));
954        for i in 0..30 {
955            for j in 0..20 {
956                let mut val = 0.0;
957                for k in 0..3 {
958                    val += u[[i, k]] * s[k] * vt[[k, j]];
959                }
960                reconstructed[[i, j]] = val;
961            }
962        }
963
964        let mut error = 0.0;
965        let mut total = 0.0;
966        for i in 0..30 {
967            for j in 0..20 {
968                let diff = a[[i, j]] - reconstructed[[i, j]];
969                error += diff * diff;
970                total += a[[i, j]] * a[[i, j]];
971            }
972        }
973        let rel_error = (error / total).sqrt();
974        assert!(
975            rel_error < 0.1,
976            "Reconstruction error too large: {rel_error}"
977        );
978    }
979
980    #[test]
981    fn test_randomized_svd_error_cases() {
982        let a = array![[1.0, 2.0], [3.0, 4.0]];
983        let config0 = RandomizedConfig::new(0);
984        assert!(randomized_svd(&a.view(), &config0).is_err());
985
986        let config5 = RandomizedConfig::new(5);
987        assert!(randomized_svd(&a.view(), &config5).is_err());
988    }
989
990    #[test]
991    fn test_single_pass_svd() {
992        let a = array![
993            [3.0, 1.0, 0.5],
994            [1.0, 3.0, 0.5],
995            [0.5, 0.5, 2.0],
996            [1.0, 1.0, 1.0]
997        ];
998
999        let result = single_pass_svd(&a.view(), 2, Some(1));
1000        assert!(result.is_ok());
1001        let (u, s, vt) = result.expect("single pass SVD failed");
1002
1003        assert_eq!(u.nrows(), 4);
1004        assert_eq!(u.ncols(), 2);
1005        assert_eq!(s.len(), 2);
1006        assert_eq!(vt.nrows(), 2);
1007        assert_eq!(vt.ncols(), 3);
1008    }
1009
1010    #[test]
1011    fn test_single_pass_svd_errors() {
1012        let a = array![[1.0, 2.0], [3.0, 4.0]];
1013        assert!(single_pass_svd(&a.view(), 0, None).is_err());
1014        assert!(single_pass_svd(&a.view(), 5, None).is_err());
1015    }
1016
1017    #[test]
1018    fn test_randomized_low_rank() {
1019        let a = make_low_rank_matrix(20, 15, 3);
1020        let result = randomized_low_rank(&a.view(), 3, Some(5), Some(2));
1021        assert!(result.is_ok());
1022        let (l, r) = result.expect("low rank failed");
1023
1024        assert_eq!(l.nrows(), 20);
1025        assert_eq!(l.ncols(), 3);
1026        assert_eq!(r.nrows(), 3);
1027        assert_eq!(r.ncols(), 15);
1028
1029        // Check reconstruction
1030        let approx = l.dot(&r);
1031        let mut error = 0.0;
1032        let mut total = 0.0;
1033        for i in 0..20 {
1034            for j in 0..15 {
1035                let diff = a[[i, j]] - approx[[i, j]];
1036                error += diff * diff;
1037                total += a[[i, j]] * a[[i, j]];
1038            }
1039        }
1040        let rel_error = if total > 0.0 {
1041            (error / total).sqrt()
1042        } else {
1043            0.0
1044        };
1045        assert!(
1046            rel_error < 0.2,
1047            "Low-rank approximation error too large: {rel_error}"
1048        );
1049    }
1050
1051    #[test]
1052    fn test_randomized_low_rank_errors() {
1053        let a = array![[1.0, 2.0], [3.0, 4.0]];
1054        assert!(randomized_low_rank(&a.view(), 0, None, None).is_err());
1055        assert!(randomized_low_rank(&a.view(), 5, None, None).is_err());
1056    }
1057
1058    #[test]
1059    fn test_approximation_error() {
1060        let a = array![[3.0, 1.0], [1.0, 3.0], [0.5, 0.5]];
1061        let q =
1062            randomized_range_finder(&a.view(), 2, Some(0), Some(1)).expect("range finder failed");
1063        let err = approximation_error(&a.view(), &q.view());
1064        assert!(err.is_ok());
1065        let err_val = err.expect("approx error failed");
1066        assert!(
1067            err_val < 1e-6,
1068            "Full-rank approximation error should be small"
1069        );
1070    }
1071
1072    #[test]
1073    fn test_approximation_error_dimension_mismatch() {
1074        let a = array![[1.0, 2.0], [3.0, 4.0]];
1075        let q = array![[1.0], [0.0], [0.0]]; // Wrong number of rows
1076        assert!(approximation_error(&a.view(), &q.view()).is_err());
1077    }
1078
1079    #[test]
1080    fn test_randomized_pca_basic() {
1081        // Create data with known structure: 2 significant components
1082        let mut data = Array2::zeros((50, 5));
1083        let mut rng = scirs2_core::random::rng();
1084        let normal =
1085            Normal::new(0.0, 1.0).unwrap_or_else(|_| panic!("Failed to create distribution"));
1086
1087        for i in 0..50 {
1088            let c1 = normal.sample(&mut rng);
1089            let c2 = normal.sample(&mut rng);
1090            data[[i, 0]] = c1 * 3.0;
1091            data[[i, 1]] = c1 * 3.0 + normal.sample(&mut rng) * 0.1;
1092            data[[i, 2]] = c2 * 2.0;
1093            data[[i, 3]] = c2 * 2.0 + normal.sample(&mut rng) * 0.1;
1094            data[[i, 4]] = normal.sample(&mut rng) * 0.01;
1095        }
1096
1097        let result = randomized_pca(&data.view(), 2, false, Some(3));
1098        assert!(result.is_ok());
1099        let pca = result.expect("PCA failed");
1100
1101        assert_eq!(pca.components.nrows(), 2);
1102        assert_eq!(pca.components.ncols(), 5);
1103        assert_eq!(pca.explained_variance.len(), 2);
1104        assert_eq!(pca.explained_variance_ratio.len(), 2);
1105        assert_eq!(pca.singular_values.len(), 2);
1106        assert_eq!(pca.mean.len(), 5);
1107
1108        // First two components should explain most variance
1109        let total_explained: f64 = pca.explained_variance_ratio.sum();
1110        assert!(
1111            total_explained > 0.8,
1112            "Top 2 components should explain >80% variance, got {total_explained}"
1113        );
1114    }
1115
1116    #[test]
1117    fn test_randomized_pca_whiten() {
1118        let data = array![
1119            [1.0, 2.0, 3.0],
1120            [4.0, 5.0, 6.0],
1121            [7.0, 8.0, 9.0],
1122            [10.0, 11.0, 12.0],
1123            [13.0, 14.0, 15.0]
1124        ];
1125
1126        let result = randomized_pca(&data.view(), 2, true, Some(1));
1127        assert!(result.is_ok());
1128        let pca = result.expect("whitened PCA failed");
1129        assert_eq!(pca.components.nrows(), 2);
1130    }
1131
1132    #[test]
1133    fn test_randomized_pca_error_cases() {
1134        let data = array![[1.0, 2.0], [3.0, 4.0]];
1135        assert!(randomized_pca(&data.view(), 0, false, None).is_err());
1136        assert!(randomized_pca(&data.view(), 5, false, None).is_err());
1137    }
1138
1139    #[test]
1140    fn test_pca_transform_and_inverse() {
1141        let data = array![
1142            [1.0, 2.0, 3.0],
1143            [4.0, 5.0, 6.0],
1144            [7.0, 8.0, 9.0],
1145            [10.0, 11.0, 12.0]
1146        ];
1147
1148        let pca = randomized_pca(&data.view(), 2, false, Some(2)).expect("PCA failed");
1149
1150        // Transform
1151        let transformed = pca_transform(&data.view(), &pca).expect("transform failed");
1152        assert_eq!(transformed.nrows(), 4);
1153        assert_eq!(transformed.ncols(), 2);
1154
1155        // Inverse transform
1156        let reconstructed =
1157            pca_inverse_transform(&transformed.view(), &pca).expect("inverse transform failed");
1158        assert_eq!(reconstructed.nrows(), 4);
1159        assert_eq!(reconstructed.ncols(), 3);
1160
1161        // Reconstruction should be close (rank 2 approx of rank 2 data)
1162        for i in 0..4 {
1163            for j in 0..3 {
1164                assert!(
1165                    (data[[i, j]] - reconstructed[[i, j]]).abs() < 1.0,
1166                    "Reconstruction error too large at [{i}, {j}]"
1167                );
1168            }
1169        }
1170    }
1171
1172    #[test]
1173    fn test_pca_transform_dimension_mismatch() {
1174        let data = array![[1.0, 2.0], [3.0, 4.0]];
1175        let pca = randomized_pca(&data.view(), 1, false, Some(1)).expect("PCA failed");
1176
1177        let wrong_data = array![[1.0, 2.0, 3.0]]; // Wrong feature count
1178        assert!(pca_transform(&wrong_data.view(), &pca).is_err());
1179    }
1180
1181    #[test]
1182    fn test_config_builder() {
1183        let config = RandomizedConfig::new(5)
1184            .with_oversampling(20)
1185            .with_power_iterations(3)
1186            .with_seed(42);
1187
1188        assert_eq!(config.rank, 5);
1189        assert_eq!(config.oversampling, 20);
1190        assert_eq!(config.power_iterations, 3);
1191        assert_eq!(config.seed, Some(42));
1192    }
1193
1194    #[test]
1195    fn test_randomized_svd_identity_like() {
1196        let a = array![
1197            [1.0, 0.0, 0.0],
1198            [0.0, 1.0, 0.0],
1199            [0.0, 0.0, 1.0],
1200            [0.0, 0.0, 0.0]
1201        ];
1202
1203        let config = RandomizedConfig::new(3)
1204            .with_oversampling(0)
1205            .with_power_iterations(1);
1206        let result = randomized_svd(&a.view(), &config);
1207        assert!(result.is_ok());
1208        let (_u, s, _vt) = result.expect("SVD of identity-like failed");
1209
1210        // All singular values should be ~1.0
1211        for i in 0..s.len() {
1212            assert!(
1213                (s[i] - 1.0).abs() < 0.1,
1214                "Singular value {} = {}, expected ~1.0",
1215                i,
1216                s[i]
1217            );
1218        }
1219    }
1220}