1use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
8use scirs2_core::validation::*;
9
10#[derive(Debug, Clone)]
12pub struct PCA {
13 pub n_components: Option<usize>,
15 pub svd_solver: SvdSolver,
17 pub center: bool,
19 pub scale: bool,
21 pub random_state: Option<u64>,
23}
24
/// Strategy used to compute the decomposition behind PCA.
///
/// The enum is fieldless, so it also derives `Eq` and `Hash`; this lets a
/// solver choice be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SvdSolver {
    /// Exact decomposition through a full SVD of the prepared data.
    Full,
    /// Approximate decomposition through a random projection refined by
    /// power iterations.
    Randomized,
    /// Let `fit` choose: the randomized solver is picked only when both
    /// dimensions are at least 500 and fewer than half of the possible
    /// components are requested; otherwise the full solver is used.
    Auto,
}
35
36#[derive(Debug, Clone)]
38pub struct PCAResult {
39 pub components: Array2<f64>,
41 pub explained_variance: Array1<f64>,
43 pub explained_variance_ratio: Array1<f64>,
45 pub singular_values: Array1<f64>,
47 pub mean: Array1<f64>,
49 pub scale: Option<Array1<f64>>,
51 pub n_samples_: usize,
53 pub n_features: usize,
55}
56
57impl Default for PCA {
58 fn default() -> Self {
59 Self {
60 n_components: None,
61 svd_solver: SvdSolver::Auto,
62 center: true,
63 scale: false,
64 random_state: None,
65 }
66 }
67}
68
69impl PCA {
70 pub fn new() -> Self {
72 Self::default()
73 }
74
75 pub fn with_n_components(mut self, n_components: usize) -> Self {
77 self.n_components = Some(n_components);
78 self
79 }
80
81 pub fn with_svd_solver(mut self, solver: SvdSolver) -> Self {
83 self.svd_solver = solver;
84 self
85 }
86
87 pub fn with_center(mut self, center: bool) -> Self {
89 self.center = center;
90 self
91 }
92
93 pub fn with_scale(mut self, scale: bool) -> Self {
95 self.scale = scale;
96 self
97 }
98
99 pub fn with_random_state(mut self, seed: u64) -> Self {
101 self.random_state = Some(seed);
102 self
103 }
104
105 pub fn fit(&self, data: ArrayView2<f64>) -> Result<PCAResult> {
107 checkarray_finite(&data, "data")?;
108 let (n_samples, n_features) = data.dim();
109 if n_samples < 2 {
110 return Err(StatsError::InvalidArgument(
111 "n_samples must be at least 2".to_string(),
112 ));
113 }
114 if n_features < 1 {
115 return Err(StatsError::InvalidArgument(
116 "n_features must be at least 1".to_string(),
117 ));
118 }
119
120 let max_components = n_samples.min(n_features);
122 let n_components = match self.n_components {
123 Some(k) => {
124 check_positive(k, "n_components")?;
125 if k > max_components {
126 return Err(StatsError::InvalidArgument(format!(
127 "n_components ({}) cannot be larger than min(n_samples, n_features) = {}",
128 k, max_components
129 )));
130 }
131 k
132 }
133 None => max_components,
134 };
135
136 let mean = if self.center {
138 data.mean_axis(Axis(0)).expect("Operation failed")
139 } else {
140 Array1::zeros(n_features)
141 };
142
143 let mut centereddata = data.to_owned();
144 if self.center {
145 for mut row in centereddata.rows_mut() {
146 row -= &mean;
147 }
148 }
149
150 let scale = if self.scale {
152 let std = centereddata.std_axis(Axis(0), 1.0);
153 let std = std.mapv(|s| if s > 1e-10 { s } else { 1.0 });
155
156 for (mut col, &s) in centereddata.columns_mut().into_iter().zip(std.iter()) {
157 col /= s;
158 }
159 Some(std)
160 } else {
161 None
162 };
163
164 let solver = match self.svd_solver {
166 SvdSolver::Auto => {
167 if n_samples >= 500 && n_features >= 500 && n_components < max_components / 2 {
168 SvdSolver::Randomized
169 } else {
170 SvdSolver::Full
171 }
172 }
173 solver => solver,
174 };
175
176 let result = match solver {
178 SvdSolver::Full => self.pca_svd(¢ereddata, n_components, n_samples)?,
179 SvdSolver::Randomized => self.pca_randomized(¢ereddata, n_components, n_samples)?,
180 _ => unreachable!(),
181 };
182
183 Ok(PCAResult {
184 components: result.0,
185 explained_variance: result.1,
186 explained_variance_ratio: result.2,
187 singular_values: result.3,
188 mean,
189 scale,
190 n_samples_: n_samples,
191 n_features,
192 })
193 }
194
195 fn pca_svd(
197 &self,
198 data: &Array2<f64>,
199 n_components: usize,
200 n_samples: usize,
201 ) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>, Array1<f64>)> {
202 let (_u, s, vt) = scirs2_linalg::svd(&data.view(), true, None)
204 .map_err(|e| StatsError::ComputationError(format!("SVD failed: {}", e)))?;
205 let v = vt.t().to_owned();
206
207 let components = v
209 .slice(scirs2_core::ndarray::s![.., ..n_components])
210 .to_owned();
211
212 let singular_values = s.slice(scirs2_core::ndarray::s![..n_components]).to_owned();
214 let explained_variance = &singular_values * &singular_values / (n_samples - 1) as f64;
215
216 let total_variance = explained_variance.sum();
218 let explained_variance_ratio = &explained_variance / total_variance;
219
220 Ok((
221 components.t().to_owned(),
222 explained_variance,
223 explained_variance_ratio,
224 singular_values,
225 ))
226 }
227
228 fn pca_randomized(
230 &self,
231 data: &Array2<f64>,
232 n_components: usize,
233 n_samples: usize,
234 ) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>, Array1<f64>)> {
235 use scirs2_core::random::{rngs::StdRng, SeedableRng};
236 use scirs2_core::random::{Distribution, Normal};
237
238 let n_features = data.ncols();
239 let n_oversamples = 10.min((n_features - n_components) / 2);
240 let n_random = n_components + n_oversamples;
241
242 let mut rng = match self.random_state {
244 Some(seed) => StdRng::seed_from_u64(seed),
245 None => {
246 use std::time::{SystemTime, UNIX_EPOCH};
248 let seed = SystemTime::now()
249 .duration_since(UNIX_EPOCH)
250 .unwrap_or_default()
251 .as_secs();
252 StdRng::seed_from_u64(seed)
253 }
254 };
255
256 let normal = Normal::new(0.0, 1.0).map_err(|e| {
258 StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
259 })?;
260 let omega = Array2::from_shape_fn((n_features, n_random), |_| normal.sample(&mut rng));
261
262 let n_iter = 4;
264 let mut q = data.dot(&omega);
265
266 for _ in 0..n_iter {
267 let (q_mat, _r) = scirs2_linalg::qr(&q.view(), None).map_err(|e| {
269 StatsError::ComputationError(format!("QR decomposition failed: {}", e))
270 })?;
271 q = q_mat;
272
273 let z = data.t().dot(&q);
275 let (q_mat, _r) = scirs2_linalg::qr(&z.view(), None).map_err(|e| {
276 StatsError::ComputationError(format!("QR decomposition failed: {}", e))
277 })?;
278 q = data.dot(&q_mat);
279 }
280
281 let (q_final, _r) = scirs2_linalg::qr(&q.view(), None).map_err(|e| {
283 StatsError::ComputationError(format!("Final QR decomposition failed: {}", e))
284 })?;
285
286 let b = q_final.t().dot(data);
288
289 let (_u_small, s, vt) = scirs2_linalg::svd(&b.view(), true, None).map_err(|e| {
291 StatsError::ComputationError(format!("SVD of projected matrix failed: {}", e))
292 })?;
293
294 let v = vt.t().to_owned();
295
296 let components = v
298 .slice(scirs2_core::ndarray::s![.., ..n_components])
299 .to_owned();
300
301 let singular_values = s.slice(scirs2_core::ndarray::s![..n_components]).to_owned();
303 let explained_variance = &singular_values * &singular_values / (n_samples - 1) as f64;
304
305 let total_variance = explained_variance.sum();
307 let explained_variance_ratio = &explained_variance / total_variance;
308
309 Ok((
310 components.t().to_owned(),
311 explained_variance,
312 explained_variance_ratio,
313 singular_values,
314 ))
315 }
316
317 pub fn transform(&self, data: ArrayView2<f64>, result: &PCAResult) -> Result<Array2<f64>> {
319 checkarray_finite(&data, "data")?;
320 if data.ncols() != result.n_features {
321 return Err(StatsError::DimensionMismatch(format!(
322 "data has {} features, expected {}",
323 data.ncols(),
324 result.n_features
325 )));
326 }
327
328 let mut transformed = data.to_owned();
329
330 if self.center {
332 for mut row in transformed.rows_mut() {
333 row -= &result.mean;
334 }
335 }
336
337 if let Some(ref scale) = result.scale {
339 for (mut col, &s) in transformed.columns_mut().into_iter().zip(scale.iter()) {
340 col /= s;
341 }
342 }
343
344 Ok(transformed.dot(&result.components.t()))
346 }
347
348 pub fn inverse_transform(
350 &self,
351 data: ArrayView2<f64>,
352 result: &PCAResult,
353 ) -> Result<Array2<f64>> {
354 checkarray_finite(&data, "data")?;
355 let n_components = result.components.nrows();
356 if data.ncols() != n_components {
357 return Err(StatsError::DimensionMismatch(format!(
358 "data has {} components, expected {}",
359 data.ncols(),
360 n_components
361 )));
362 }
363
364 let mut reconstructed = data.dot(&result.components);
366
367 if let Some(ref scale) = result.scale {
369 for (mut col, &s) in reconstructed.columns_mut().into_iter().zip(scale.iter()) {
370 col *= s;
371 }
372 }
373
374 if self.center {
376 for mut row in reconstructed.rows_mut() {
377 row += &result.mean;
378 }
379 }
380
381 Ok(reconstructed)
382 }
383
384 pub fn fit_transform(&self, data: ArrayView2<f64>) -> Result<(Array2<f64>, PCAResult)> {
386 let result = self.fit(data)?;
387 let transformed = self.transform(data, &result)?;
388 Ok((transformed, result))
389 }
390}
391
392#[allow(dead_code)]
394pub fn mle_components(data: ArrayView2<f64>, maxcomponents: Option<usize>) -> Result<usize> {
395 checkarray_finite(&data, "data")?;
396 let (n_samples, n_features) = data.dim();
397
398 let pca = PCA::new().with_n_components(maxcomponents.unwrap_or(n_features.min(n_samples)));
399 let result = pca.fit(data)?;
400
401 let eigenvalues = &result.explained_variance;
402 let n = n_samples as f64;
403 let p = n_features as f64;
404
405 let mut best_k = 0;
407 let mut best_ll = f64::NEG_INFINITY;
408
409 for k in 0..eigenvalues.len() {
410 let k_f64 = k as f64;
411
412 let sigma2 = if k < eigenvalues.len() - 1 {
414 eigenvalues.slice(scirs2_core::ndarray::s![k + 1..]).sum() / (p - k_f64 - 1.0)
415 } else {
416 1e-10
417 };
418
419 let ll = -n / 2.0
421 * (eigenvalues
422 .slice(scirs2_core::ndarray::s![..=k])
423 .mapv(f64::ln)
424 .sum()
425 + (p - k_f64 - 1.0) * sigma2.ln()
426 + p * (2.0 * std::f64::consts::PI).ln());
427
428 let aic_penalty = k_f64 * (2.0 * p - k_f64 - 1.0);
430 let aic = ll - aic_penalty;
431
432 if aic > best_ll {
433 best_ll = aic;
434 best_k = k + 1;
435 }
436 }
437
438 Ok(best_k)
439}
440
441#[derive(Debug, Clone)]
443pub struct IncrementalPCA {
444 pub pca: PCA,
446 pub batchsize: usize,
448 mean: Option<Array1<f64>>,
450 components: Option<Array2<f64>>,
452 singular_values: Option<Array1<f64>>,
454 n_samples_seen: usize,
456 svd_u: Option<Array2<f64>>,
458 svd_s: Option<Array1<f64>>,
459 svd_v: Option<Array2<f64>>,
460}
461
462impl IncrementalPCA {
463 pub fn new(n_components: usize, batchsize: usize) -> Result<Self> {
465 check_positive(n_components, "n_components")?;
466 check_positive(batchsize, "batchsize")?;
467
468 Ok(Self {
469 pca: PCA::new().with_n_components(n_components),
470 batchsize,
471 mean: None,
472 components: None,
473 singular_values: None,
474 n_samples_seen: 0,
475 svd_u: None,
476 svd_s: None,
477 svd_v: None,
478 })
479 }
480
481 pub fn partial_fit(&mut self, batch: ArrayView2<f64>) -> Result<()> {
483 checkarray_finite(&batch, "batch")?;
484 let (batchsize, n_features) = batch.dim();
485
486 let batch_mean = batch.mean_axis(Axis(0)).expect("Operation failed");
488 let old_n = self.n_samples_seen;
489 self.n_samples_seen += batchsize;
490
491 self.mean = match &self.mean {
492 None => Some(batch_mean.clone()),
493 Some(mean) => {
494 let updated = (mean * old_n as f64 + &batch_mean * batchsize as f64)
495 / self.n_samples_seen as f64;
496 Some(updated)
497 }
498 };
499
500 let mut centered_batch = batch.to_owned();
502 for mut row in centered_batch.rows_mut() {
503 row -= &batch_mean;
504 }
505
506 let n_components = self
508 .pca
509 .n_components
510 .unwrap_or(n_features.min(self.n_samples_seen));
511
512 if self.svd_u.is_none() {
513 let (u, s, vt) = scirs2_linalg::svd(¢ered_batch.view(), true, None)
515 .map_err(|e| StatsError::ComputationError(format!("Initial SVD failed: {}", e)))?;
516
517 self.svd_u = Some(
519 u.slice(scirs2_core::ndarray::s![.., ..n_components])
520 .to_owned(),
521 );
522 self.svd_s = Some(s.slice(scirs2_core::ndarray::s![..n_components]).to_owned());
523 self.svd_v = Some(
524 vt.slice(scirs2_core::ndarray::s![..n_components, ..])
525 .t()
526 .to_owned(),
527 );
528
529 self.components = Some(
530 self.svd_v
531 .as_ref()
532 .expect("Operation failed")
533 .t()
534 .to_owned(),
535 );
536 self.singular_values = Some(self.svd_s.as_ref().expect("Operation failed").clone());
537 } else {
538 let u_old = self.svd_u.as_ref().expect("Operation failed");
540 let s_old = self.svd_s.as_ref().expect("Operation failed");
541 let v_old = self.svd_v.as_ref().expect("Operation failed");
542
543 let projection = centered_batch.dot(v_old);
545 let residual = ¢ered_batch - &projection.dot(&v_old.t());
546
547 let (q_res, r_res) = scirs2_linalg::qr(&residual.view(), None).map_err(|e| {
549 StatsError::ComputationError(format!("QR decomposition failed: {}", e))
550 })?;
551
552 let k = s_old.len();
554 let p = r_res.ncols();
555
556 let mut augmented = Array2::zeros((k + p, k + p));
558 for i in 0..k {
559 augmented[[i, i]] = s_old[i];
560 }
561 for i in 0..projection.nrows() {
562 for j in 0..k {
563 augmented[[j, k + i]] = projection[[i, j]];
564 }
565 }
566 for i in 0..p {
567 for j in 0..p {
568 augmented[[k + i, k + j]] = r_res[[i, j]];
569 }
570 }
571
572 let (u_aug, s_aug, vt_aug) = scirs2_linalg::svd(&augmented.view(), true, None)
574 .map_err(|e| {
575 StatsError::ComputationError(format!("Augmented SVD failed: {}", e))
576 })?;
577
578 let mut u_new = Array2::zeros((old_n + batchsize, n_components));
580 let u_aug_slice = u_aug.slice(scirs2_core::ndarray::s![..n_components, ..n_components]);
581
582 let u_old_part = u_old.dot(&u_aug_slice.t());
584 u_new
585 .slice_mut(scirs2_core::ndarray::s![..old_n, ..])
586 .assign(&u_old_part);
587
588 let u_batch_part =
590 projection.dot(&u_aug_slice.slice(scirs2_core::ndarray::s![.., ..k]).t());
591 let u_res_part = q_res.dot(&u_aug_slice.slice(scirs2_core::ndarray::s![.., k..]).t());
592 u_new
593 .slice_mut(scirs2_core::ndarray::s![old_n.., ..])
594 .assign(&(&u_batch_part + &u_res_part));
595
596 self.svd_s = Some(
598 s_aug
599 .slice(scirs2_core::ndarray::s![..n_components])
600 .to_owned(),
601 );
602
603 let v_aug_slice =
605 vt_aug.slice(scirs2_core::ndarray::s![..n_components, ..n_components]);
606 let mut v_new = Array2::zeros((n_features, n_components));
607
608 let v_old_part = v_old.dot(&v_aug_slice.slice(scirs2_core::ndarray::s![.., ..k]).t());
609 let v_res_part = q_res
610 .t()
611 .dot(¢ered_batch)
612 .t()
613 .dot(&v_aug_slice.slice(scirs2_core::ndarray::s![.., k..]).t());
614 v_new.assign(&(&v_old_part + &v_res_part));
615
616 self.svd_u = Some(u_new);
617 self.svd_v = Some(v_new.clone());
618 self.components = Some(v_new.t().to_owned());
619 self.singular_values = Some(self.svd_s.as_ref().expect("Operation failed").clone());
620 }
621
622 Ok(())
623 }
624
625 pub fn transform(&self, data: ArrayView2<f64>) -> Result<Array2<f64>> {
627 if self.components.is_none() || self.mean.is_none() {
628 return Err(StatsError::ComputationError(
629 "IncrementalPCA must be fitted before transform".to_string(),
630 ));
631 }
632
633 let mut centered = data.to_owned();
634 for mut row in centered.rows_mut() {
635 row -= self.mean.as_ref().expect("Operation failed");
636 }
637
638 Ok(centered.dot(&self.components.as_ref().expect("Operation failed").t()))
639 }
640}