scirs2_stats/multivariate/
pca.rs

1//! Principal Component Analysis (PCA)
2//!
3//! PCA is a dimensionality reduction technique that finds the directions of maximum variance
4//! in high-dimensional data and projects the data onto a lower-dimensional subspace.
5
6use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
8use scirs2_core::validation::*;
9
10/// Principal Component Analysis
11#[derive(Debug, Clone)]
12pub struct PCA {
13    /// Number of components to keep
14    pub n_components: Option<usize>,
15    /// Whether to use SVD instead of eigendecomposition  
16    pub svd_solver: SvdSolver,
17    /// Whether to center the data
18    pub center: bool,
19    /// Whether to scale the data to unit variance
20    pub scale: bool,
21    /// Random state for randomized solver
22    pub random_state: Option<u64>,
23}
24
/// Algorithm used to compute the singular value decomposition in [`PCA::fit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SvdSolver {
    /// Exact SVD of the full (centered/scaled) data matrix.
    Full,
    /// Randomized (sketched) SVD; cheaper for large matrices when few components are kept.
    Randomized,
    /// Pick [`SvdSolver::Randomized`] when there are at least 500 samples and 500
    /// features and fewer than half of the possible components are requested, and
    /// [`SvdSolver::Full`] otherwise.
    Auto,
}
35
36/// Result of PCA fit
37#[derive(Debug, Clone)]
38pub struct PCAResult {
39    /// Principal components (eigenvectors)
40    pub components: Array2<f64>,
41    /// Explained variance for each component
42    pub explained_variance: Array1<f64>,
43    /// Explained variance ratio for each component
44    pub explained_variance_ratio: Array1<f64>,
45    /// Singular values corresponding to each component
46    pub singular_values: Array1<f64>,
47    /// Mean of the training data
48    pub mean: Array1<f64>,
49    /// Standard deviation of the training data (if scaling was used)
50    pub scale: Option<Array1<f64>>,
51    /// Number of samples used for fitting
52    pub n_samples_: usize,
53    /// Number of features
54    pub n_features: usize,
55}
56
57impl Default for PCA {
58    fn default() -> Self {
59        Self {
60            n_components: None,
61            svd_solver: SvdSolver::Auto,
62            center: true,
63            scale: false,
64            random_state: None,
65        }
66    }
67}
68
69impl PCA {
70    /// Create a new PCA instance
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    /// Set the number of components to keep
76    pub fn with_n_components(mut self, n_components: usize) -> Self {
77        self.n_components = Some(n_components);
78        self
79    }
80
81    /// Set the SVD solver
82    pub fn with_svd_solver(mut self, solver: SvdSolver) -> Self {
83        self.svd_solver = solver;
84        self
85    }
86
87    /// Enable or disable centering
88    pub fn with_center(mut self, center: bool) -> Self {
89        self.center = center;
90        self
91    }
92
93    /// Enable or disable scaling
94    pub fn with_scale(mut self, scale: bool) -> Self {
95        self.scale = scale;
96        self
97    }
98
99    /// Set random state for reproducibility
100    pub fn with_random_state(mut self, seed: u64) -> Self {
101        self.random_state = Some(seed);
102        self
103    }
104
105    /// Fit the PCA model to the data
106    pub fn fit(&self, data: ArrayView2<f64>) -> Result<PCAResult> {
107        checkarray_finite(&data, "data")?;
108        let (n_samples, n_features) = data.dim();
109        if n_samples < 2 {
110            return Err(StatsError::InvalidArgument(
111                "n_samples must be at least 2".to_string(),
112            ));
113        }
114        if n_features < 1 {
115            return Err(StatsError::InvalidArgument(
116                "n_features must be at least 1".to_string(),
117            ));
118        }
119
120        // Determine number of components
121        let max_components = n_samples.min(n_features);
122        let n_components = match self.n_components {
123            Some(k) => {
124                check_positive(k, "n_components")?;
125                if k > max_components {
126                    return Err(StatsError::InvalidArgument(format!(
127                        "n_components ({}) cannot be larger than min(n_samples, n_features) = {}",
128                        k, max_components
129                    )));
130                }
131                k
132            }
133            None => max_components,
134        };
135
136        // Center the data
137        let mean = if self.center {
138            data.mean_axis(Axis(0)).unwrap()
139        } else {
140            Array1::zeros(n_features)
141        };
142
143        let mut centereddata = data.to_owned();
144        if self.center {
145            for mut row in centereddata.rows_mut() {
146                row -= &mean;
147            }
148        }
149
150        // Scale the data
151        let scale = if self.scale {
152            let std = centereddata.std_axis(Axis(0), 1.0);
153            // Avoid division by zero
154            let std = std.mapv(|s| if s > 1e-10 { s } else { 1.0 });
155
156            for (mut col, &s) in centereddata.columns_mut().into_iter().zip(std.iter()) {
157                col /= s;
158            }
159            Some(std)
160        } else {
161            None
162        };
163
164        // Choose solver
165        let solver = match self.svd_solver {
166            SvdSolver::Auto => {
167                if n_samples >= 500 && n_features >= 500 && n_components < max_components / 2 {
168                    SvdSolver::Randomized
169                } else {
170                    SvdSolver::Full
171                }
172            }
173            solver => solver,
174        };
175
176        // Perform PCA
177        let result = match solver {
178            SvdSolver::Full => self.pca_svd(&centereddata, n_components, n_samples)?,
179            SvdSolver::Randomized => self.pca_randomized(&centereddata, n_components, n_samples)?,
180            _ => unreachable!(),
181        };
182
183        Ok(PCAResult {
184            components: result.0,
185            explained_variance: result.1,
186            explained_variance_ratio: result.2,
187            singular_values: result.3,
188            mean,
189            scale,
190            n_samples_: n_samples,
191            n_features,
192        })
193    }
194
195    /// Perform PCA using SVD
196    fn pca_svd(
197        &self,
198        data: &Array2<f64>,
199        n_components: usize,
200        n_samples: usize,
201    ) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>, Array1<f64>)> {
202        use scirs2_core::ndarray::ndarray_linalg::SVD;
203
204        // Perform SVD: X = U * S * V^T
205        let (_u, s, vt) = data
206            .svd(true, true)
207            .map_err(|e| StatsError::ComputationError(format!("SVD failed: {}", e)))?;
208        let v = vt.unwrap().t().to_owned();
209
210        // Extract _components
211        let components = v
212            .slice(scirs2_core::ndarray::s![.., ..n_components])
213            .to_owned();
214
215        // Compute explained variance
216        let singular_values = s.slice(scirs2_core::ndarray::s![..n_components]).to_owned();
217        let explained_variance = &singular_values * &singular_values / (n_samples - 1) as f64;
218
219        // Compute explained variance ratio
220        let total_variance = explained_variance.sum();
221        let explained_variance_ratio = &explained_variance / total_variance;
222
223        Ok((
224            components.t().to_owned(),
225            explained_variance,
226            explained_variance_ratio,
227            singular_values,
228        ))
229    }
230
231    /// Perform PCA using randomized SVD
232    fn pca_randomized(
233        &self,
234        data: &Array2<f64>,
235        n_components: usize,
236        n_samples: usize,
237    ) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>, Array1<f64>)> {
238        use scirs2_core::ndarray::ndarray_linalg::{QR, SVD};
239        use scirs2_core::random::{rngs::StdRng, SeedableRng};
240        use scirs2_core::random::{Distribution, Normal};
241
242        let n_features = data.ncols();
243        let n_oversamples = 10.min((n_features - n_components) / 2);
244        let n_random = n_components + n_oversamples;
245
246        // Initialize RNG
247        let mut rng = match self.random_state {
248            Some(seed) => StdRng::seed_from_u64(seed),
249            None => {
250                // Use a simple fallback seed based on current time or a fixed seed
251                use std::time::{SystemTime, UNIX_EPOCH};
252                let seed = SystemTime::now()
253                    .duration_since(UNIX_EPOCH)
254                    .unwrap_or_default()
255                    .as_secs();
256                StdRng::seed_from_u64(seed)
257            }
258        };
259
260        // Generate random matrix
261        let normal = Normal::new(0.0, 1.0).map_err(|e| {
262            StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
263        })?;
264        let omega = Array2::from_shape_fn((n_features, n_random), |_| normal.sample(&mut rng));
265
266        // Power iterations for better approximation
267        let n_iter = 4;
268        let mut q = data.dot(&omega);
269
270        for _ in 0..n_iter {
271            // QR decomposition
272            let (q_mat, r) = q.qr().map_err(|e| {
273                StatsError::ComputationError(format!("QR decomposition failed: {}", e))
274            })?;
275            q = q_mat;
276
277            // Project back
278            let z = data.t().dot(&q);
279            let (q_mat, r) = z.qr().map_err(|e| {
280                StatsError::ComputationError(format!("QR decomposition failed: {}", e))
281            })?;
282            q = data.dot(&q_mat);
283        }
284
285        // Final QR decomposition
286        let (q_final, r) = q.qr().map_err(|e| {
287            StatsError::ComputationError(format!("Final QR decomposition failed: {}", e))
288        })?;
289
290        // Project data onto subspace
291        let b = q_final.t().dot(data);
292
293        // SVD of small matrix B
294        let (_u_small, s, vt) = b.svd(true, true).map_err(|e| {
295            StatsError::ComputationError(format!("SVD of projected matrix failed: {}", e))
296        })?;
297
298        let v = vt.unwrap().t().to_owned();
299
300        // Extract _components
301        let components = v
302            .slice(scirs2_core::ndarray::s![.., ..n_components])
303            .to_owned();
304
305        // Compute explained variance
306        let singular_values = s.slice(scirs2_core::ndarray::s![..n_components]).to_owned();
307        let explained_variance = &singular_values * &singular_values / (n_samples - 1) as f64;
308
309        // Compute explained variance ratio
310        let total_variance = explained_variance.sum();
311        let explained_variance_ratio = &explained_variance / total_variance;
312
313        Ok((
314            components.t().to_owned(),
315            explained_variance,
316            explained_variance_ratio,
317            singular_values,
318        ))
319    }
320
321    /// Transform data using the fitted PCA model
322    pub fn transform(&self, data: ArrayView2<f64>, result: &PCAResult) -> Result<Array2<f64>> {
323        checkarray_finite(&data, "data")?;
324        if data.ncols() != result.n_features {
325            return Err(StatsError::DimensionMismatch(format!(
326                "data has {} features, expected {}",
327                data.ncols(),
328                result.n_features
329            )));
330        }
331
332        let mut transformed = data.to_owned();
333
334        // Center
335        if self.center {
336            for mut row in transformed.rows_mut() {
337                row -= &result.mean;
338            }
339        }
340
341        // Scale
342        if let Some(ref scale) = result.scale {
343            for (mut col, &s) in transformed.columns_mut().into_iter().zip(scale.iter()) {
344                col /= s;
345            }
346        }
347
348        // Project onto components
349        Ok(transformed.dot(&result.components.t()))
350    }
351
352    /// Inverse transform from component space back to original space
353    pub fn inverse_transform(
354        &self,
355        data: ArrayView2<f64>,
356        result: &PCAResult,
357    ) -> Result<Array2<f64>> {
358        checkarray_finite(&data, "data")?;
359        let n_components = result.components.nrows();
360        if data.ncols() != n_components {
361            return Err(StatsError::DimensionMismatch(format!(
362                "data has {} components, expected {}",
363                data.ncols(),
364                n_components
365            )));
366        }
367
368        // Project back to original space
369        let mut reconstructed = data.dot(&result.components);
370
371        // Inverse scale
372        if let Some(ref scale) = result.scale {
373            for (mut col, &s) in reconstructed.columns_mut().into_iter().zip(scale.iter()) {
374                col *= s;
375            }
376        }
377
378        // Add mean back
379        if self.center {
380            for mut row in reconstructed.rows_mut() {
381                row += &result.mean;
382            }
383        }
384
385        Ok(reconstructed)
386    }
387
388    /// Fit and transform in one step
389    pub fn fit_transform(&self, data: ArrayView2<f64>) -> Result<(Array2<f64>, PCAResult)> {
390        let result = self.fit(data)?;
391        let transformed = self.transform(data, &result)?;
392        Ok((transformed, result))
393    }
394}
395
396/// Compute the optimal number of components using Minka's MLE
397#[allow(dead_code)]
398pub fn mle_components(data: ArrayView2<f64>, maxcomponents: Option<usize>) -> Result<usize> {
399    checkarray_finite(&data, "data")?;
400    let (n_samples, n_features) = data.dim();
401
402    let pca = PCA::new().with_n_components(maxcomponents.unwrap_or(n_features.min(n_samples)));
403    let result = pca.fit(data)?;
404
405    let eigenvalues = &result.explained_variance;
406    let n = n_samples as f64;
407    let p = n_features as f64;
408
409    // Minka's MLE for PCA
410    let mut best_k = 0;
411    let mut best_ll = f64::NEG_INFINITY;
412
413    for k in 0..eigenvalues.len() {
414        let k_f64 = k as f64;
415
416        // Average of remaining eigenvalues
417        let sigma2 = if k < eigenvalues.len() - 1 {
418            eigenvalues.slice(scirs2_core::ndarray::s![k + 1..]).sum() / (p - k_f64 - 1.0)
419        } else {
420            1e-10
421        };
422
423        // Log-likelihood
424        let ll = -n / 2.0
425            * (eigenvalues
426                .slice(scirs2_core::ndarray::s![..=k])
427                .mapv(f64::ln)
428                .sum()
429                + (p - k_f64 - 1.0) * sigma2.ln()
430                + p * (2.0 * std::f64::consts::PI).ln());
431
432        // AIC penalty
433        let aic_penalty = k_f64 * (2.0 * p - k_f64 - 1.0);
434        let aic = ll - aic_penalty;
435
436        if aic > best_ll {
437            best_ll = aic;
438            best_k = k + 1;
439        }
440    }
441
442    Ok(best_k)
443}
444
445/// Incremental PCA for large datasets that don't fit in memory
446#[derive(Debug, Clone)]
447pub struct IncrementalPCA {
448    /// Base PCA configuration
449    pub pca: PCA,
450    /// Batch size for incremental updates
451    pub batchsize: usize,
452    /// Running mean
453    mean: Option<Array1<f64>>,
454    /// Running components
455    components: Option<Array2<f64>>,
456    /// Singular values
457    singular_values: Option<Array1<f64>>,
458    /// Number of samples seen
459    n_samples_seen: usize,
460    /// Incremental SVD state
461    svd_u: Option<Array2<f64>>,
462    svd_s: Option<Array1<f64>>,
463    svd_v: Option<Array2<f64>>,
464}
465
466impl IncrementalPCA {
467    /// Create a new incremental PCA instance
468    pub fn new(n_components: usize, batchsize: usize) -> Result<Self> {
469        check_positive(n_components, "n_components")?;
470        check_positive(batchsize, "batchsize")?;
471
472        Ok(Self {
473            pca: PCA::new().with_n_components(n_components),
474            batchsize,
475            mean: None,
476            components: None,
477            singular_values: None,
478            n_samples_seen: 0,
479            svd_u: None,
480            svd_s: None,
481            svd_v: None,
482        })
483    }
484
485    /// Partial fit on a batch of data
486    pub fn partial_fit(&mut self, batch: ArrayView2<f64>) -> Result<()> {
487        checkarray_finite(&batch, "batch")?;
488        let (batchsize, n_features) = batch.dim();
489
490        // Update mean incrementally
491        let batch_mean = batch.mean_axis(Axis(0)).unwrap();
492        let old_n = self.n_samples_seen;
493        self.n_samples_seen += batchsize;
494
495        self.mean = match &self.mean {
496            None => Some(batch_mean.clone()),
497            Some(mean) => {
498                let updated = (mean * old_n as f64 + &batch_mean * batchsize as f64)
499                    / self.n_samples_seen as f64;
500                Some(updated)
501            }
502        };
503
504        // Center the batch
505        let mut centered_batch = batch.to_owned();
506        for mut row in centered_batch.rows_mut() {
507            row -= &batch_mean;
508        }
509
510        // Incremental SVD update using Brand's algorithm
511        let n_components = self
512            .pca
513            .n_components
514            .unwrap_or(n_features.min(self.n_samples_seen));
515
516        if self.svd_u.is_none() {
517            // First batch - initialize with standard SVD
518            use scirs2_core::ndarray::ndarray_linalg::SVD;
519            let (u, s, vt) = centered_batch
520                .svd(true, true)
521                .map_err(|e| StatsError::ComputationError(format!("Initial SVD failed: {}", e)))?;
522
523            let u = u.unwrap();
524            let vt = vt.unwrap();
525
526            // Keep only n_components
527            self.svd_u = Some(
528                u.slice(scirs2_core::ndarray::s![.., ..n_components])
529                    .to_owned(),
530            );
531            self.svd_s = Some(s.slice(scirs2_core::ndarray::s![..n_components]).to_owned());
532            self.svd_v = Some(
533                vt.slice(scirs2_core::ndarray::s![..n_components, ..])
534                    .t()
535                    .to_owned(),
536            );
537
538            self.components = Some(self.svd_v.as_ref().unwrap().t().to_owned());
539            self.singular_values = Some(self.svd_s.as_ref().unwrap().clone());
540        } else {
541            // Incremental update
542            let u_old = self.svd_u.as_ref().unwrap();
543            let s_old = self.svd_s.as_ref().unwrap();
544            let v_old = self.svd_v.as_ref().unwrap();
545
546            // Project new data onto existing components
547            let projection = centered_batch.dot(v_old);
548            let residual = &centered_batch - &projection.dot(&v_old.t());
549
550            // QR decomposition of residual
551            use scirs2_core::ndarray::ndarray_linalg::QR;
552            let (q_res, r_res) = residual.qr().map_err(|e| {
553                StatsError::ComputationError(format!("QR decomposition failed: {}", e))
554            })?;
555
556            // Build augmented matrix
557            let k = s_old.len();
558            let p = r_res.ncols();
559
560            // Create block matrix [diag(s_old), projection^T; 0, r_res]
561            let mut augmented = Array2::zeros((k + p, k + p));
562            for i in 0..k {
563                augmented[[i, i]] = s_old[i];
564            }
565            for i in 0..projection.nrows() {
566                for j in 0..k {
567                    augmented[[j, k + i]] = projection[[i, j]];
568                }
569            }
570            for i in 0..p {
571                for j in 0..p {
572                    augmented[[k + i, k + j]] = r_res[[i, j]];
573                }
574            }
575
576            // SVD of augmented matrix
577            use scirs2_core::ndarray::ndarray_linalg::SVD;
578            let (u_aug, s_aug, vt_aug) = augmented.svd(true, true).map_err(|e| {
579                StatsError::ComputationError(format!("Augmented SVD failed: {}", e))
580            })?;
581
582            let u_aug = u_aug.unwrap();
583            let vt_aug = vt_aug.unwrap();
584
585            // Update U
586            let mut u_new = Array2::zeros((old_n + batchsize, n_components));
587            let u_aug_slice = u_aug.slice(scirs2_core::ndarray::s![..n_components, ..n_components]);
588
589            // Update old samples part
590            let u_old_part = u_old.dot(&u_aug_slice.t());
591            u_new
592                .slice_mut(scirs2_core::ndarray::s![..old_n, ..])
593                .assign(&u_old_part);
594
595            // Update new samples part
596            let u_batch_part =
597                projection.dot(&u_aug_slice.slice(scirs2_core::ndarray::s![.., ..k]).t());
598            let u_res_part = q_res.dot(&u_aug_slice.slice(scirs2_core::ndarray::s![.., k..]).t());
599            u_new
600                .slice_mut(scirs2_core::ndarray::s![old_n.., ..])
601                .assign(&(&u_batch_part + &u_res_part));
602
603            // Update singular values
604            self.svd_s = Some(
605                s_aug
606                    .slice(scirs2_core::ndarray::s![..n_components])
607                    .to_owned(),
608            );
609
610            // Update V
611            let v_aug_slice =
612                vt_aug.slice(scirs2_core::ndarray::s![..n_components, ..n_components]);
613            let mut v_new = Array2::zeros((n_features, n_components));
614
615            let v_old_part = v_old.dot(&v_aug_slice.slice(scirs2_core::ndarray::s![.., ..k]).t());
616            let v_res_part = q_res
617                .t()
618                .dot(&centered_batch)
619                .t()
620                .dot(&v_aug_slice.slice(scirs2_core::ndarray::s![.., k..]).t());
621            v_new.assign(&(&v_old_part + &v_res_part));
622
623            self.svd_u = Some(u_new);
624            self.svd_v = Some(v_new.clone());
625            self.components = Some(v_new.t().to_owned());
626            self.singular_values = Some(self.svd_s.as_ref().unwrap().clone());
627        }
628
629        Ok(())
630    }
631
632    /// Transform new data
633    pub fn transform(&self, data: ArrayView2<f64>) -> Result<Array2<f64>> {
634        if self.components.is_none() || self.mean.is_none() {
635            return Err(StatsError::ComputationError(
636                "IncrementalPCA must be fitted before transform".to_string(),
637            ));
638        }
639
640        let mut centered = data.to_owned();
641        for mut row in centered.rows_mut() {
642            row -= self.mean.as_ref().unwrap();
643        }
644
645        Ok(centered.dot(&self.components.as_ref().unwrap().t()))
646    }
647}