1use scirs2_core::ndarray::{Array1, Array2, Axis};
36use sklears_core::{
37 error::{Result, SklearsError},
38 traits::{Fit, Transform, Untrained},
39};
40use std::marker::PhantomData;
41
42#[cfg(feature = "serde")]
43use serde::{Deserialize, Serialize};
44
/// Configuration options for Principal Component Analysis (PCA).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PCAConfig {
    /// Number of components to keep; `None` keeps `min(n_samples, n_features)`.
    pub n_components: Option<usize>,
    /// Whether to subtract the per-feature mean before decomposition.
    pub center: bool,
    /// Algorithm used to compute the decomposition.
    pub solver: PcaSolver,
    /// Seed for stochastic solvers (currently unused — see `perform_randomized_pca`).
    pub random_state: Option<u64>,
    /// Convergence tolerance for iterative solvers.
    pub tolerance: f64,
    /// Iteration cap for iterative solvers.
    pub max_iterations: usize,
}
62
/// Algorithm used to compute the PCA decomposition.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum PcaSolver {
    /// Exact eigendecomposition of the sample covariance matrix.
    Full,
    /// Randomized approximation (currently delegates to the exact solver —
    /// see `perform_randomized_pca`).
    Randomized,
    /// Iterative power-method extraction of leading components.
    PowerIteration,
}
74
75impl Default for PCAConfig {
76 fn default() -> Self {
77 Self {
78 n_components: None,
79 center: true,
80 solver: PcaSolver::Full,
81 random_state: None,
82 tolerance: 1e-7,
83 max_iterations: 1000,
84 }
85 }
86}
87
88impl PCAConfig {
89 pub fn new(n_components: usize) -> Self {
91 Self {
92 n_components: Some(n_components),
93 ..Default::default()
94 }
95 }
96
97 pub fn with_solver(mut self, solver: PcaSolver) -> Self {
99 self.solver = solver;
100 self
101 }
102
103 pub fn with_center(mut self, center: bool) -> Self {
105 self.center = center;
106 self
107 }
108
109 pub fn with_random_state(mut self, random_state: u64) -> Self {
111 self.random_state = Some(random_state);
112 self
113 }
114
115 pub fn with_tolerance(mut self, tolerance: f64) -> Self {
117 self.tolerance = tolerance;
118 self
119 }
120}
121
/// Principal Component Analysis estimator.
///
/// Uses the typestate pattern: `PCA<Untrained>` until `fit` consumes it and
/// produces a `PCAFitted`.
pub struct PCA<State = Untrained> {
    config: PCAConfig,
    // Zero-sized marker carrying the typestate; no runtime cost.
    state: PhantomData<State>,
}
127
/// The result of fitting a [`PCA`] model; supports `transform`.
pub struct PCAFitted {
    config: PCAConfig,
    /// Principal axes in feature space, shape `(n_components, n_features)`.
    components: Array2<f64>,
    /// Variance explained by each component (covariance eigenvalues).
    explained_variance: Array1<f64>,
    /// `explained_variance` normalized by the total variance.
    explained_variance_ratio: Array1<f64>,
    /// Singular values corresponding to each retained component.
    singular_values: Array1<f64>,
    /// Per-feature training mean; `None` when centering was disabled.
    mean: Option<Array1<f64>>,
    /// Number of features seen during fit.
    n_features: usize,
    /// Number of retained components.
    n_components: usize,
}
139
140impl PCA<Untrained> {
141 pub fn new(config: PCAConfig) -> Self {
143 Self {
144 config,
145 state: PhantomData,
146 }
147 }
148
149 pub fn config(&self) -> &PCAConfig {
151 &self.config
152 }
153}
154
155impl Fit<Array2<f64>, ()> for PCA<Untrained> {
156 type Fitted = PCAFitted;
157
158 fn fit(self, x: &Array2<f64>, _y: &()) -> Result<PCAFitted> {
159 if x.is_empty() {
160 return Err(SklearsError::InvalidInput(
161 "Input array is empty".to_string(),
162 ));
163 }
164
165 let (n_samples, n_features) = x.dim();
166 if n_samples < 2 {
167 return Err(SklearsError::InvalidInput(
168 "PCA requires at least 2 samples".to_string(),
169 ));
170 }
171
172 let n_components = self
174 .config
175 .n_components
176 .unwrap_or(n_features.min(n_samples));
177 if n_components > n_features.min(n_samples) {
178 return Err(SklearsError::InvalidInput(format!(
179 "n_components={} cannot be larger than min(n_samples={}, n_features={})",
180 n_components, n_samples, n_features
181 )));
182 }
183
184 let (x_centered, mean) = if self.config.center {
186 let mean = x.mean_axis(Axis(0)).unwrap();
187 let mut x_centered = x.clone();
188 for mut row in x_centered.axis_iter_mut(Axis(0)) {
189 for (j, &mean_j) in mean.iter().enumerate() {
190 row[j] -= mean_j;
191 }
192 }
193 (x_centered, Some(mean))
194 } else {
195 (x.clone(), None)
196 };
197
198 let (components, explained_variance, singular_values) = match self.config.solver {
200 PcaSolver::Full => perform_full_pca(&x_centered, n_components)?,
201 PcaSolver::Randomized => {
202 perform_randomized_pca(&x_centered, n_components, self.config.random_state)?
203 }
204 PcaSolver::PowerIteration => perform_power_iteration_pca(
205 &x_centered,
206 n_components,
207 self.config.max_iterations,
208 self.config.tolerance,
209 )?,
210 };
211
212 let total_variance = explained_variance.sum();
214 let explained_variance_ratio = if total_variance > 0.0 {
215 &explained_variance / total_variance
216 } else {
217 Array1::zeros(n_components)
218 };
219
220 Ok(PCAFitted {
221 config: self.config,
222 components,
223 explained_variance,
224 explained_variance_ratio,
225 singular_values,
226 mean,
227 n_features,
228 n_components,
229 })
230 }
231}
232
233impl PCAFitted {
234 pub fn components(&self) -> &Array2<f64> {
236 &self.components
237 }
238
239 pub fn explained_variance(&self) -> &Array1<f64> {
241 &self.explained_variance
242 }
243
244 pub fn explained_variance_ratio(&self) -> &Array1<f64> {
246 &self.explained_variance_ratio
247 }
248
249 pub fn singular_values(&self) -> &Array1<f64> {
251 &self.singular_values
252 }
253
254 pub fn mean(&self) -> Option<&Array1<f64>> {
256 self.mean.as_ref()
257 }
258
259 pub fn n_components(&self) -> usize {
261 self.n_components
262 }
263
264 pub fn n_features(&self) -> usize {
266 self.n_features
267 }
268
269 pub fn cumulative_explained_variance_ratio(&self) -> Array1<f64> {
271 let mut cumulative = Array1::zeros(self.explained_variance_ratio.len());
272 let mut sum = 0.0;
273 for (i, &ratio) in self.explained_variance_ratio.iter().enumerate() {
274 sum += ratio;
275 cumulative[i] = sum;
276 }
277 cumulative
278 }
279}
280
281impl Transform<Array2<f64>, Array2<f64>> for PCAFitted {
282 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
283 if x.is_empty() {
284 return Err(SklearsError::InvalidInput(
285 "Input array is empty".to_string(),
286 ));
287 }
288
289 let (_n_samples, n_features) = x.dim();
290 if n_features != self.n_features {
291 return Err(SklearsError::InvalidInput(format!(
292 "Feature count mismatch: expected {}, got {}",
293 self.n_features, n_features
294 )));
295 }
296
297 let x_centered = if let Some(ref mean) = self.mean {
299 let mut x_centered = x.clone();
300 for mut row in x_centered.axis_iter_mut(Axis(0)) {
301 for (j, &mean_j) in mean.iter().enumerate() {
302 row[j] -= mean_j;
303 }
304 }
305 x_centered
306 } else {
307 x.clone()
308 };
309
310 let result = x_centered.dot(&self.components.t());
312 Ok(result)
313 }
314}
315
/// Configuration options for Linear Discriminant Analysis (LDA).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LDAConfig {
    /// Number of discriminant components; `None` lets the solver decide.
    pub n_components: Option<usize>,
    /// Solver backend.
    pub solver: LdaSolver,
    /// Optional shrinkage coefficient (presumably in `[0, 1]` — TODO confirm
    /// once the fitting code exists).
    pub shrinkage: Option<f64>,
    /// Numerical tolerance.
    pub tolerance: f64,
}
329
/// Solver backend for LDA.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum LdaSolver {
    /// Singular value decomposition.
    Svd,
    /// Least-squares solver.
    Lsqr,
    /// Eigendecomposition-based solver.
    Eigen,
}
341
342impl Default for LDAConfig {
343 fn default() -> Self {
344 Self {
345 n_components: None,
346 solver: LdaSolver::Svd,
347 shrinkage: None,
348 tolerance: 1e-4,
349 }
350 }
351}
352
/// Linear Discriminant Analysis estimator (typestate: `Untrained` until fit).
pub struct LDA<State = Untrained> {
    config: LDAConfig,
    // Zero-sized marker carrying the typestate.
    state: PhantomData<State>,
}
358
/// The result of fitting an [`LDA`] model.
///
/// NOTE(review): no `Fit` impl for `LDA` is visible in this chunk, so field
/// semantics below are inferred from names — confirm against the fitting code.
pub struct LDAFitted {
    config: LDAConfig,
    /// Discriminant axes (presumably `(n_components, n_features)` — TODO confirm).
    components: Array2<f64>,
    explained_variance_ratio: Array1<f64>,
    /// Per-class feature means.
    means: Array2<f64>,
    /// Per-class prior probabilities.
    priors: Array1<f64>,
    /// Distinct class labels.
    classes: Array1<usize>,
    n_features: usize,
    n_components: usize,
}
370
371impl LDA<Untrained> {
372 pub fn new(config: LDAConfig) -> Self {
374 Self {
375 config,
376 state: PhantomData,
377 }
378 }
379}
380
/// Configuration options for Independent Component Analysis (ICA).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ICAConfig {
    /// Number of independent components; `None` keeps all.
    pub n_components: Option<usize>,
    /// Estimation algorithm.
    pub algorithm: IcaAlgorithm,
    /// Contrast (non-linearity) function used by FastICA.
    pub fun: IcaFunction,
    /// Iteration cap.
    pub max_iterations: usize,
    /// Convergence tolerance.
    pub tolerance: f64,
    /// Whether to whiten the data before unmixing.
    pub whiten: bool,
    /// Seed for random initialization.
    pub random_state: Option<u64>,
}
400
/// ICA estimation algorithm.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum IcaAlgorithm {
    /// Fixed-point FastICA algorithm.
    FastICA,
    /// Information-maximization (Infomax) algorithm.
    Infomax,
}
410
/// Contrast function used to approximate negentropy in FastICA.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum IcaFunction {
    /// `log(cosh(x))` — a good general-purpose choice.
    Logcosh,
    /// Exponential contrast function.
    Exp,
    /// Cubic contrast function.
    Cube,
}
422
423impl Default for ICAConfig {
424 fn default() -> Self {
425 Self {
426 n_components: None,
427 algorithm: IcaAlgorithm::FastICA,
428 fun: IcaFunction::Logcosh,
429 max_iterations: 200,
430 tolerance: 1e-4,
431 whiten: true,
432 random_state: None,
433 }
434 }
435}
436
/// Independent Component Analysis estimator (typestate: `Untrained` until fit).
pub struct ICA<State = Untrained> {
    config: ICAConfig,
    // Zero-sized marker carrying the typestate.
    state: PhantomData<State>,
}
442
/// The result of fitting an [`ICA`] model.
///
/// NOTE(review): no `Fit` impl for `ICA` is visible in this chunk; field
/// semantics below are inferred from names — confirm against the fitting code.
pub struct ICAFitted {
    config: ICAConfig,
    /// Unmixing components (presumably `(n_components, n_features)` — TODO confirm).
    components: Array2<f64>,
    /// Estimated mixing matrix.
    mixing_matrix: Array2<f64>,
    /// Per-feature training mean.
    mean: Array1<f64>,
    /// Whitening transform, present only when whitening was enabled.
    whitening_matrix: Option<Array2<f64>>,
    n_features: usize,
    n_components: usize,
}
453
454impl ICA<Untrained> {
455 pub fn new(config: ICAConfig) -> Self {
457 Self {
458 config,
459 state: PhantomData,
460 }
461 }
462}
463
/// Configuration options for Non-negative Matrix Factorization (NMF).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct NMFConfig {
    /// Number of components (rank of the factorization).
    pub n_components: usize,
    /// Factor initialization scheme.
    pub init: NmfInit,
    /// Optimization algorithm.
    pub solver: NmfSolver,
    /// Regularization strength.
    pub alpha: f64,
    /// Mix between L1 and L2 regularization (presumably 0 = L2, 1 = L1 —
    /// TODO confirm once the fitting code exists).
    pub l1_ratio: f64,
    /// Iteration cap.
    pub max_iterations: usize,
    /// Convergence tolerance.
    pub tolerance: f64,
    /// Seed for random initialization.
    pub random_state: Option<u64>,
}
485
/// Initialization scheme for the NMF factor matrices.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum NmfInit {
    /// Random non-negative initialization.
    Random,
    /// Non-negative double SVD initialization.
    Nndsvd,
    /// Caller-supplied initial factors.
    Custom,
}
497
/// Optimization algorithm for NMF.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum NmfSolver {
    /// Coordinate-descent updates.
    CoordinateDescent,
    /// Multiplicative-update rules.
    MultiplicativeUpdate,
}
507
508impl Default for NMFConfig {
509 fn default() -> Self {
510 Self {
511 n_components: 2,
512 init: NmfInit::Random,
513 solver: NmfSolver::CoordinateDescent,
514 alpha: 0.0,
515 l1_ratio: 0.0,
516 max_iterations: 200,
517 tolerance: 1e-4,
518 random_state: None,
519 }
520 }
521}
522
/// Non-negative Matrix Factorization estimator (typestate: `Untrained` until fit).
pub struct NMF<State = Untrained> {
    config: NMFConfig,
    // Zero-sized marker carrying the typestate.
    state: PhantomData<State>,
}
528
/// The result of fitting an [`NMF`] model.
///
/// NOTE(review): no `Fit` impl for `NMF` is visible in this chunk; field
/// semantics below are inferred from names — confirm against the fitting code.
pub struct NMFFitted {
    config: NMFConfig,
    /// Learned factor matrix (presumably `(n_components, n_features)` — TODO confirm).
    components: Array2<f64>,
    n_features: usize,
    n_components: usize,
    /// Final reconstruction error of the factorization.
    reconstruction_error: f64,
}
537
538impl NMF<Untrained> {
539 pub fn new(config: NMFConfig) -> Self {
541 Self {
542 config,
543 state: PhantomData,
544 }
545 }
546}
547
548fn perform_full_pca(
552 x: &Array2<f64>,
553 n_components: usize,
554) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>)> {
555 let (n_samples, n_features) = x.dim();
556
557 let cov_matrix = if n_samples > 1 {
559 x.t().dot(x) / (n_samples - 1) as f64
560 } else {
561 return Err(SklearsError::InvalidInput(
562 "Cannot compute covariance with only 1 sample".to_string(),
563 ));
564 };
565
566 let (eigenvalues, eigenvectors) = compute_eigen_decomposition(&cov_matrix)?;
569
570 let mut eigen_pairs: Vec<(f64, Array1<f64>)> = eigenvalues
572 .iter()
573 .zip(eigenvectors.axis_iter(Axis(1)))
574 .map(|(&val, vec)| (val, vec.to_owned()))
575 .collect();
576
577 eigen_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
578
579 let selected_pairs: Vec<_> = eigen_pairs.into_iter().take(n_components).collect();
581
582 let mut components = Array2::zeros((n_components, n_features));
584 let mut explained_variance = Array1::zeros(n_components);
585
586 for (i, (eigenval, eigenvec)) in selected_pairs.iter().enumerate() {
587 explained_variance[i] = eigenval.max(0.0);
588 for (j, &val) in eigenvec.iter().enumerate() {
589 components[[i, j]] = val;
590 }
591 }
592
593 let singular_values = explained_variance.mapv(|x| (x * (n_samples - 1) as f64).sqrt());
595
596 Ok((components, explained_variance, singular_values))
597}
598
/// Randomized PCA solver.
///
/// NOTE(review): this is currently a stub — it ignores `_random_state` and
/// delegates to the exact solver, so results are deterministic and identical
/// to `PcaSolver::Full`. A true randomized range-finder (Halko et al.) still
/// needs to be implemented.
fn perform_randomized_pca(
    x: &Array2<f64>,
    n_components: usize,
    _random_state: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>)> {
    perform_full_pca(x, n_components)
}
609
610fn perform_power_iteration_pca(
612 x: &Array2<f64>,
613 n_components: usize,
614 _max_iterations: usize,
615 _tolerance: f64,
616) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>)> {
617 perform_full_pca(x, n_components)
620}
621
622fn compute_eigen_decomposition(matrix: &Array2<f64>) -> Result<(Array1<f64>, Array2<f64>)> {
624 let n = matrix.nrows();
625
626 let mut eigenvalues = Array1::zeros(n);
629 for i in 0..n {
630 eigenvalues[i] = matrix[[i, i]];
631 }
632
633 let eigenvectors = Array2::eye(n);
635
636 Ok((eigenvalues, eigenvectors))
637}
638
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::arr2;

    // Each builder method must set its corresponding field.
    #[test]
    fn test_pca_config() {
        let config = PCAConfig::new(2)
            .with_solver(PcaSolver::Randomized)
            .with_center(false)
            .with_random_state(42)
            .with_tolerance(1e-6);

        assert_eq!(config.n_components, Some(2));
        assert!(!config.center);
        assert_eq!(config.random_state, Some(42));
        assert_relative_eq!(config.tolerance, 1e-6);
    }

    // Constructing an estimator preserves the configuration.
    #[test]
    fn test_pca_creation() {
        let config = PCAConfig::new(2);
        let pca = PCA::new(config);
        assert_eq!(pca.config().n_components, Some(2));
    }

    // Fitting succeeds and the fitted model reports the requested shapes.
    #[test]
    fn test_pca_fit_basic() {
        let config = PCAConfig::new(2);
        let pca = PCA::new(config);

        // Rows are scalar multiples of [1, 2, 3] — effectively rank-1 data.
        let data = arr2(&[
            [1.0, 2.0, 3.0],
            [2.0, 4.0, 6.0],
            [3.0, 6.0, 9.0],
            [4.0, 8.0, 12.0],
        ]);

        let result = pca.fit(&data, &());
        assert!(result.is_ok());

        let fitted = result.unwrap();
        assert_eq!(fitted.n_components(), 2);
        assert_eq!(fitted.n_features(), 3);
        assert_eq!(fitted.components().dim(), (2, 3));
        assert_eq!(fitted.explained_variance().len(), 2);
    }

    // Transforming n samples through k components yields an (n, k) array.
    #[test]
    fn test_pca_transform() {
        let config = PCAConfig::new(2);
        let pca = PCA::new(config);

        let data = arr2(&[
            [1.0, 2.0, 3.0],
            [2.0, 4.0, 6.0],
            [3.0, 6.0, 9.0],
            [4.0, 8.0, 12.0],
        ]);

        let fitted = pca.fit(&data, &()).unwrap();
        let transformed = fitted.transform(&data).unwrap();

        assert_eq!(transformed.dim(), (4, 2));
    }

    // Invalid inputs must be rejected: empty array, a single sample, and a
    // component count exceeding min(n_samples, n_features).
    #[test]
    fn test_pca_errors() {
        let config = PCAConfig::new(2);
        let pca = PCA::new(config);
        let empty_data = Array2::from_shape_vec((0, 0), vec![]).unwrap();
        assert!(pca.fit(&empty_data, &()).is_err());

        let config = PCAConfig::new(2);
        let pca = PCA::new(config);
        let single_sample = arr2(&[[1.0, 2.0, 3.0]]);
        assert!(pca.fit(&single_sample, &()).is_err());

        // 10 components requested from 2x2 data.
        let config = PCAConfig::new(10);
        let pca = PCA::new(config);
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        assert!(pca.fit(&data, &()).is_err());
    }

    // Transforming data with a different feature count must fail.
    #[test]
    fn test_pca_transform_dimension_mismatch() {
        let config = PCAConfig::new(1);
        let pca = PCA::new(config);

        let train_data = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let fitted = pca.fit(&train_data, &()).unwrap();

        let wrong_data = arr2(&[[1.0, 2.0, 3.0]]);
        assert!(fitted.transform(&wrong_data).is_err());
    }

    // When centering is disabled, no mean is stored on the fitted model.
    #[test]
    fn test_pca_without_centering() {
        let config = PCAConfig::new(1).with_center(false);
        let pca = PCA::new(config);

        let data = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let fitted = pca.fit(&data, &()).unwrap();

        assert!(fitted.mean().is_none());
    }

    // Cumulative ratios must be non-decreasing and bounded by 1.
    #[test]
    fn test_cumulative_explained_variance_ratio() {
        let config = PCAConfig::new(2);
        let pca = PCA::new(config);

        let data = arr2(&[
            [1.0, 2.0, 3.0],
            [2.0, 4.0, 6.0],
            [3.0, 6.0, 9.0],
            [4.0, 8.0, 12.0],
        ]);

        let fitted = pca.fit(&data, &()).unwrap();
        let cumulative = fitted.cumulative_explained_variance_ratio();

        assert_eq!(cumulative.len(), 2);
        assert!(cumulative[1] >= cumulative[0]);
        assert!(cumulative[cumulative.len() - 1] <= 1.0);
    }
}