
// SPDX-License-Identifier: MIT OR Apache-2.0
//! Principal Component Analysis (PCA) — pure-Rust eigendecomposition.
//!
//! Reduces dimensionality by projecting data onto the directions of
//! maximum variance.  Uses Jacobi rotation for eigendecomposition of
//! the covariance matrix — no BLAS / LAPACK required.
//!
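//! The fitted transform computes `Y = (X − μ) · Wᵀ`, where the rows of `W`
//! are the top-k eigenvectors of the covariance matrix; whitening further
//! scales column i of `Y` by `1 / √λᵢ` (λᵢ is the i-th eigenvalue).
//!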
//! # Example
//!
//! ```ignore
//! use scry_learn::prelude::*;
//!
//! let mut pca = Pca::with_n_components(2).whiten(true);
//! pca.fit_transform(&mut dataset)?;
//!
//! // Inspect variance explained
//! println!("{:?}", pca.explained_variance_ratio());
//! ```

use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::preprocess::Transformer;

// ── Public types ──────────────────────────────────────────────────

/// Principal Component Analysis.
///
/// Projects data onto the top-k eigenvectors of the covariance matrix.
/// Optionally whitens the output so each component has unit variance.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Pca {
    n_components: Option<usize>,
    do_whiten: bool,
    // — fitted state —
    mean: Vec<f64>,
    /// Rows = components (top-k eigenvectors), each of length n_features.
    components: Vec<Vec<f64>>,
    explained_variance: Vec<f64>,
    explained_variance_ratio: Vec<f64>,
    total_variance: f64,
    fitted: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

// ── Builder ───────────────────────────────────────────────────────

impl Pca {
    /// Create a PCA that retains **all** components.
    pub fn new() -> Self {
        Self {
            n_components: None,
            do_whiten: false,
            mean: Vec::new(),
            components: Vec::new(),
            explained_variance: Vec::new(),
            explained_variance_ratio: Vec::new(),
            total_variance: 0.0,
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    /// Create a PCA that retains the top `k` components (clamped to the
    /// number of features at fit time).
    pub fn with_n_components(k: usize) -> Self {
        Self {
            n_components: Some(k),
            ..Self::new()
        }
    }

    /// Enable whitening (scale components to unit variance).
    pub fn whiten(mut self, yes: bool) -> Self {
        self.do_whiten = yes;
        self
    }

    // ── Accessors ─────────────────────────────────────────────────

    /// Fraction of total variance explained by each retained component.
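    ///
    /// A minimal usage sketch (illustrative only, in the style of the
    /// module-level example):
    ///
    /// ```ignore
    /// let cumulative: f64 = pca.explained_variance_ratio().iter().take(2).sum();
    /// println!("first two PCs explain {:.1}% of variance", cumulative * 100.0);
    /// ```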
    pub fn explained_variance_ratio(&self) -> &[f64] {
        &self.explained_variance_ratio
    }

    /// Absolute variance (eigenvalue) of each retained component.
    pub fn explained_variance(&self) -> &[f64] {
        &self.explained_variance
    }

    /// Principal axes in feature space — `[n_components][n_features]`.
    pub fn components(&self) -> &[Vec<f64>] {
        &self.components
    }

    /// Number of components actually retained after fitting.
    pub fn n_components_fitted(&self) -> usize {
        self.components.len()
    }
}

impl Default for Pca {
    fn default() -> Self {
        Self::new()
    }
}

// ── Transformer impl ─────────────────────────────────────────────

impl Transformer for Pca {
    fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        let n = data.n_samples();
        let m = data.n_features();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if m == 0 {
            return Err(ScryLearnError::InvalidParameter(
                "dataset has no features".into(),
            ));
        }

        let k = self.n_components.unwrap_or(m).min(m);

        // 1. Column means — contiguous column slices from DenseMatrix.
        let mat = data.matrix();
        let mut mean = vec![0.0; m];
        for j in 0..m {
            let col = mat.col(j);
            let s: f64 = col.iter().sum();
            mean[j] = s / n as f64;
        }

        // 2. Covariance matrix (m × m), stored flat row-major:
        //    cov[i*m+j] = (1/(n-1)) * Σ_s (x_si − μ_i)(x_sj − μ_j)
        let denom = if n > 1 { (n - 1) as f64 } else { 1.0 };
        let mut cov = vec![0.0; m * m];
        for i in 0..m {
            let col_i = mat.col(i);
            for j in i..m {
                let col_j = mat.col(j);
                let mut s = 0.0;
                for s_idx in 0..n {
                    s += (col_i[s_idx] - mean[i]) * (col_j[s_idx] - mean[j]);
                }
                let v = s / denom;
                cov[i * m + j] = v;
                cov[j * m + i] = v;
            }
        }

        // 3. Jacobi eigendecomposition → eigenvalues + eigenvectors.
        let (eigenvalues, eigenvectors) = jacobi_eigen(m, &mut cov);

        // 4. Sort by descending eigenvalue.
        let mut order: Vec<usize> = (0..m).collect();
        order.sort_by(|&a, &b| eigenvalues[b].total_cmp(&eigenvalues[a]));

        let total: f64 = eigenvalues.iter().filter(|&&v| v > 0.0).sum();

        self.mean = mean;
        self.total_variance = total;
        self.explained_variance = order[..k]
            .iter()
            .map(|&i| eigenvalues[i].max(0.0))
            .collect();
        self.explained_variance_ratio = if total > crate::constants::NEAR_ZERO {
            self.explained_variance.iter().map(|v| v / total).collect()
        } else {
            vec![0.0; k]
        };
        self.components = order[..k]
            .iter()
            .map(|&i| {
                // eigenvector column i → row in components
                (0..m).map(|r| eigenvectors[r * m + i]).collect()
            })
            .collect();
        self.fitted = true;
        Ok(())
    }

    fn transform(&self, data: &mut Dataset) -> Result<()> {
        const BLOCK: usize = 32;
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n = data.n_samples();
        let m = self.mean.len();
        let k = self.components.len();
        // Guard: the incoming feature count must match the fitted mean vector.
        if data.n_features() != m {
            return Err(ScryLearnError::InvalidParameter(format!(
                "expected {m} features, got {}",
                data.n_features()
            )));
        }

        // ── Cache-friendly blocked matrix multiply ──
        //
        // Pre-center into contiguous row-major buffer, transpose components
        // for column-major access, then blocked matmul with slice-based
        // bound elision for SIMD (iter_mut().zip() proves equal lengths →
        // LLVM emits vfmadd231pd).

        // 1. Center the data into a contiguous row-major n×m buffer.
        let mat = data.matrix();
        let mut centered = vec![0.0; n * m];
        for j in 0..m {
            let col = mat.col(j);
            let mean_j = self.mean[j];
            for i in 0..n {
                centered[i * m + j] = col[i] - mean_j;
            }
        }

        // 2. Transpose components to column-major layout: comp_t[j*k + c]
        //    so that the inner dot-product loop reads contiguously.
        let mut comp_t = vec![0.0; m * k];
        for (c, comp) in self.components.iter().enumerate() {
            for (j, &v) in comp.iter().enumerate() {
                comp_t[j * k + c] = v;
            }
        }

        // 3. Blocked matmul: centered(n×m) × comp_t(m×k) → result(n×k)
        //    Slice extraction elides bounds checks → LLVM vectorises to FMA.
        let mut result = vec![0.0; n * k];
        for ib in (0..n).step_by(BLOCK) {
            let i_end = (ib + BLOCK).min(n);
            for jb in (0..m).step_by(BLOCK) {
                let j_end = (jb + BLOCK).min(m);
                for i in ib..i_end {
                    let r_row = i * k;
                    for j in jb..j_end {
                        let c_val = centered[i * m + j];

                        let r_slice = &mut result[r_row..r_row + k];
                        let c_slice = &comp_t[j * k..j * k + k];

                        // iter_mut().zip() proves equal lengths → no
                        // bounds checks → LLVM vectorises to vfmadd231pd.
                        for (r, &w) in r_slice.iter_mut().zip(c_slice) {
                            *r += c_val * w;
                        }
                    }
                }
            }
        }

        // 4. Apply whitening if enabled, then split into column-major features.
        let mut new_features: Vec<Vec<f64>> = Vec::with_capacity(k);
        for c in 0..k {
            let scale = if self.do_whiten {
                let ev = self.explained_variance[c];
                if ev > crate::constants::NEAR_ZERO {
                    1.0 / ev.sqrt()
                } else {
                    1.0
                }
            } else {
                1.0
            };
            let mut col = Vec::with_capacity(n);
            for i in 0..n {
                col.push(result[i * k + c] * scale);
            }
            new_features.push(col);
        }

        // Replace dataset features.
        data.features = new_features;
        data.feature_names = (0..k).map(|i| format!("PC{}", i + 1)).collect();
        data.invalidate_matrix();
        Ok(())
    }

    fn inverse_transform(&self, data: &mut Dataset) -> Result<()> {
        const BLOCK: usize = 64;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n = data.n_samples();
        let k = self.components.len();
        let m = self.mean.len();
        // Guard: the incoming feature count must match the retained components.
        if data.n_features() != k {
            return Err(ScryLearnError::InvalidParameter(format!(
                "expected {k} component columns, got {}",
                data.n_features()
            )));
        }

        // ── Cache-friendly blocked inverse transform ──
        //
        // Reconstructs X_recon ≈ Y · W + μ using contiguous flat buffers
        // and blocked matmul with slice-based bound elision for SIMD.

        // Step 1: Flatten scores from column-major Vec<Vec<f64>> into
        //         row-major n×k buffer, un-whitening inline.
        let mut scores_flat = vec![0.0; n * k];
        for c in 0..k {
            let col = &data.features[c];
            let scale = if self.do_whiten {
                self.explained_variance[c].sqrt()
            } else {
                1.0
            };
            for i in 0..n {
                scores_flat[i * k + c] = col[i] * scale;
            }
        }

        // Step 2: Flatten components (k × m) — natively row-major,
        //         so copy_from_slice lowers to a fast memcpy.
        let mut comp_flat = vec![0.0; k * m];
        for (c, comp) in self.components.iter().enumerate() {
            comp_flat[c * m..c * m + m].copy_from_slice(comp);
        }

        // Step 3: Pre-fill row-major n×m recon buffer with means.
        //         Avoids a separate addition pass in the hot loop.
        let mut recon_flat = vec![0.0; n * m];
        for i in 0..n {
            recon_flat[i * m..i * m + m].copy_from_slice(&self.mean);
        }

        // Step 4: Blocked DAXPY matmul — scores_flat(n×k) · comp_flat(k×m)
        //         Loop order ib→jb→cb keeps a BLOCK×BLOCK recon tile in L1.
        //         Slice extraction elides bounds checks → LLVM emits SIMD FMA.
        for ib in (0..n).step_by(BLOCK) {
            let i_end = (ib + BLOCK).min(n);
            for jb in (0..m).step_by(BLOCK) {
                let j_end = (jb + BLOCK).min(m);
                for cb in (0..k).step_by(BLOCK) {
                    let c_end = (cb + BLOCK).min(k);

                    for i in ib..i_end {
                        for c in cb..c_end {
                            let y_val = scores_flat[i * k + c];

                            let recon_slice = &mut recon_flat[i * m + jb..i * m + j_end];
                            let comp_slice = &comp_flat[c * m + jb..c * m + j_end];

                            // iter_mut().zip() proves equal lengths → no
                            // bounds checks → LLVM vectorises to vfmadd231pd.
                            for (r, &w) in recon_slice.iter_mut().zip(comp_slice) {
                                *r += y_val * w;
                            }
                        }
                    }
                }
            }
        }

        // Step 5: Blocked scatter from row-major flat → column-major Vec<Vec<f64>>.
        let mut reconstructed: Vec<Vec<f64>> = vec![vec![0.0; n]; m];
        for ib in (0..n).step_by(BLOCK) {
            let i_end = (ib + BLOCK).min(n);
            for jb in (0..m).step_by(BLOCK) {
                let j_end = (jb + BLOCK).min(m);
                for j in jb..j_end {
                    let col = &mut reconstructed[j];
                    for i in ib..i_end {
                        col[i] = recon_flat[i * m + j];
                    }
                }
            }
        }

        data.features = reconstructed;
        data.feature_names = (0..m).map(|i| format!("x{i}")).collect();
        data.invalidate_matrix();
        Ok(())
    }
}

// ── Jacobi eigendecomposition ─────────────────────────────────────
//
// For a real symmetric n×n matrix, iterates 2×2 rotations to
// diagonalise it.  Returns (eigenvalues, eigenvectors_flat).
// eigenvectors_flat is row-major n×n where column j is eigenvector j.

fn jacobi_eigen(n: usize, a: &mut [f64]) -> (Vec<f64>, Vec<f64>) {
    // Identity matrix for eigenvectors (row-major).
    let mut v = vec![0.0; n * n];
    for i in 0..n {
        v[i * n + i] = 1.0;
    }

    let max_sweeps = crate::constants::JACOBI_MAX_SWEEPS;
    let tol = crate::constants::JACOBI_TOL;

    for _sweep in 0..max_sweeps {
        // Squared off-diagonal Frobenius norm — the convergence criterion.
        let mut off = 0.0;
        for i in 0..n {
            for j in (i + 1)..n {
                off += a[i * n + j] * a[i * n + j];
            }
        }
        if off < tol {
            break;
        }

        for p in 0..n {
            for q in (p + 1)..n {
                let apq = a[p * n + q];
                if apq.abs() < crate::constants::NEAR_ZERO {
                    continue;
                }

                let diff = a[q * n + q] - a[p * n + p];
                let t = if diff.abs() < crate::constants::NEAR_ZERO {
                    // Rotation angle φ = π/4 → t = tan φ = 1.
                    1.0
                } else {
                    // tau = cot(2φ); pick the smaller root of
                    // t² + 2·tau·t − 1 = 0 for numerical stability.
                    let tau = diff / (2.0 * apq);
                    let sign = if tau >= 0.0 { 1.0 } else { -1.0 };
                    sign / (tau.abs() + (1.0 + tau * tau).sqrt())
                };

                let c = 1.0 / (1.0 + t * t).sqrt();
                let s = t * c;

                // Update matrix A.  tau_val = s / (1 + c) = tan(φ/2).
                let tau_val = s / (1.0 + c);

                a[p * n + p] -= t * apq;
                a[q * n + q] += t * apq;
                a[p * n + q] = 0.0;
                a[q * n + p] = 0.0;

                // Rotate rows/columns (only upper triangle elements are needed
                // but we keep it symmetric for simplicity).
                for r in 0..n {
                    if r == p || r == q {
                        continue;
                    }
                    let arp = a[r * n + p];
                    let arq = a[r * n + q];
                    a[r * n + p] = arp - s * (arq + tau_val * arp);
                    a[p * n + r] = a[r * n + p];
                    a[r * n + q] = arq + s * (arp - tau_val * arq);
                    a[q * n + r] = a[r * n + q];
                }

                // Rotate eigenvector columns.
                for r in 0..n {
                    let vp = v[r * n + p];
                    let vq = v[r * n + q];
                    v[r * n + p] = vp - s * (vq + tau_val * vp);
                    v[r * n + q] = vq + s * (vp - tau_val * vq);
                }
            }
        }
    }

    let eigenvalues: Vec<f64> = (0..n).map(|i| a[i * n + i]).collect();
    (eigenvalues, v)
}

// ── Tests ─────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    fn iris_4d_subset() -> Dataset {
        // 12 samples from Iris (3 classes × 4 samples), 4 features.
        Dataset::new(
            vec![
                vec![5.1, 4.9, 4.7, 4.6, 7.0, 6.4, 6.9, 5.5, 6.3, 5.8, 7.1, 6.3],
                vec![3.5, 3.0, 3.2, 3.1, 3.2, 3.2, 3.1, 2.3, 3.3, 2.7, 3.0, 2.9],
                vec![1.4, 1.4, 1.3, 1.5, 4.7, 4.5, 4.9, 4.0, 6.0, 5.1, 5.9, 5.6],
                vec![0.2, 0.2, 0.2, 0.2, 1.4, 1.5, 1.5, 1.3, 2.5, 1.9, 2.1, 1.8],
            ],
            vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0],
            vec![
                "sepal_length".into(),
                "sepal_width".into(),
                "petal_length".into(),
                "petal_width".into(),
            ],
            "species",
        )
    }

    #[test]
    fn pca_identity_no_reduction() {
        let ds = iris_4d_subset();
        let mut pca = Pca::new();
        pca.fit(&ds).unwrap();
        assert_eq!(pca.n_components_fitted(), 4);
    }

    #[test]
    fn pca_variance_explained_sums_to_one() {
        let ds = iris_4d_subset();
        let mut pca = Pca::new();
        pca.fit(&ds).unwrap();
        let sum: f64 = pca.explained_variance_ratio().iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "variance ratios should sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn pca_reduces_dimensions() {
        let mut ds = iris_4d_subset();
        let mut pca = Pca::with_n_components(2);
        pca.fit_transform(&mut ds).unwrap();
        assert_eq!(ds.n_features(), 2);
        assert_eq!(ds.feature_names[0], "PC1");
        assert_eq!(ds.feature_names[1], "PC2");
    }

    #[test]
    fn pca_roundtrip_inverse() {
        let original = iris_4d_subset();
        let mut ds = original.clone();
        let mut pca = Pca::new(); // keep all → perfect roundtrip.
        pca.fit_transform(&mut ds).unwrap();
        pca.inverse_transform(&mut ds).unwrap();

        for j in 0..original.n_features() {
            for i in 0..original.n_samples() {
                assert!(
                    (ds.features[j][i] - original.features[j][i]).abs() < 1e-6,
                    "roundtrip mismatch at feature {j}, sample {i}: {} vs {}",
                    ds.features[j][i],
                    original.features[j][i],
                );
            }
        }
    }

    #[test]
    fn pca_whiten_unit_variance() {
        let mut ds = iris_4d_subset();
        let mut pca = Pca::with_n_components(2).whiten(true);
        pca.fit_transform(&mut ds).unwrap();

        // Each component should have variance ≈ 1.
        for j in 0..2 {
            let col = &ds.features[j];
            let mean = col.iter().sum::<f64>() / col.len() as f64;
            let var = col.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (col.len() - 1) as f64;
            assert!(
                (var - 1.0).abs() < 0.15,
                "whitened PC{} variance should be ~1.0, got {var}",
                j + 1,
            );
        }
    }

    #[test]
    fn pca_not_fitted_error() {
        let pca = Pca::new();
        let mut ds = iris_4d_subset();
        assert!(pca.transform(&mut ds).is_err());
    }
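
    #[test]
    fn pca_clamps_n_components_to_n_features() {
        // Grounded in `fit`: k = n_components.unwrap_or(m).min(m), so
        // requesting more components than features retains only n_features.
        let ds = iris_4d_subset();
        let mut pca = Pca::with_n_components(10);
        pca.fit(&ds).unwrap();
        assert_eq!(pca.n_components_fitted(), 4);
    }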

    #[test]
    fn pca_empty_dataset_error() {
        let ds = Dataset::new(vec![vec![]], vec![], vec!["x".into()], "y");
        let mut pca = Pca::new();
        assert!(pca.fit(&ds).is_err());
    }

    #[test]
    fn pca_components_orthogonal() {
        let ds = iris_4d_subset();
        let mut pca = Pca::new();
        pca.fit(&ds).unwrap();

        let comps = pca.components();
        let k = comps.len();
        for i in 0..k {
            for j in (i + 1)..k {
                let dot: f64 = comps[i]
                    .iter()
                    .zip(comps[j].iter())
                    .map(|(a, b)| a * b)
                    .sum();
                assert!(
                    dot.abs() < 1e-6,
                    "components {i} and {j} should be orthogonal, dot = {dot}"
                );
            }
            // Unit norm.
            let norm: f64 = comps[i].iter().map(|x| x * x).sum::<f64>().sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-6,
                "component {i} should have unit norm, got {norm}"
            );
        }
    }
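
    #[test]
    fn jacobi_eigen_known_2x2() {
        // Sanity check for the private helper on a worked example: the
        // symmetric matrix [[2, 1], [1, 2]] has eigenvalues 3 and 1
        // (eigenvectors along [1, 1] and [1, -1]).
        let mut a = vec![2.0, 1.0, 1.0, 2.0];
        let (mut vals, _vecs) = jacobi_eigen(2, &mut a);
        vals.sort_by(|x, y| y.total_cmp(x));
        assert!((vals[0] - 3.0).abs() < 1e-9, "expected λ₁ = 3, got {}", vals[0]);
        assert!((vals[1] - 1.0).abs() < 1e-9, "expected λ₂ = 1, got {}", vals[1]);
    }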
}