1mod isomap;
8mod lle;
9mod spectral_embedding;
10mod tsne;
11mod umap;
12
13pub use crate::reduction::isomap::Isomap;
14pub use crate::reduction::lle::LLE;
15pub use crate::reduction::spectral_embedding::{AffinityMethod, SpectralEmbedding};
16pub use crate::reduction::tsne::{trustworthiness, TSNE};
17pub use crate::reduction::umap::UMAP;
18
19use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
20use scirs2_core::numeric::{Float, NumCast};
21use scirs2_linalg::svd;
22
23use crate::error::{Result, TransformError};
24
/// Cut-off below which singular values / vector norms are treated as zero; the `"eigen"`
/// LDA solver also adds it to the diagonal of the within-class scatter as a small ridge.
const EPSILON: f64 = 1.0e-10;
27
/// Principal Component Analysis computed from the SVD of the (optionally centered and
/// scaled) data matrix.
///
/// Call `fit` (or `fit_transform`) before `transform`; every `Option` field below is
/// `None` until the model has been fitted.
#[derive(Debug, Clone)]
pub struct PCA {
    /// Number of principal components to keep.
    n_components: usize,
    /// Whether each feature is mean-centered before the SVD.
    center: bool,
    /// Whether each feature is divided by its root-mean-square (taken after centering when
    /// `center` is true) before the SVD.
    scale: bool,
    /// Principal axes, shape `(n_components, n_features)`: the leading rows of `Vt`.
    components: Option<Array2<f64>>,
    /// Per-feature means used for centering (all zeros when `center` is false).
    mean: Option<Array1<f64>>,
    /// Per-feature scale divisors (all ones when `scale` is false, and left at one for
    /// features whose spread is not above `f64::EPSILON`).
    std: Option<Array1<f64>>,
    /// The `n_components` leading singular values of the preprocessed training data.
    singular_values: Option<Array1<f64>>,
    /// Each kept component's squared singular value divided by the sum of all squared
    /// singular values.
    explained_variance_ratio: Option<Array1<f64>>,
}
51
52impl PCA {
53 pub fn new(ncomponents: usize, center: bool, scale: bool) -> Self {
63 PCA {
64 n_components: ncomponents,
65 center,
66 scale,
67 components: None,
68 mean: None,
69 std: None,
70 singular_values: None,
71 explained_variance_ratio: None,
72 }
73 }
74
75 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
83 where
84 S: Data,
85 S::Elem: Float + NumCast,
86 {
87 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
88
89 let n_samples = x_f64.shape()[0];
90 let n_features = x_f64.shape()[1];
91
92 if n_samples == 0 || n_features == 0 {
93 return Err(TransformError::InvalidInput("Empty input data".to_string()));
94 }
95
96 if self.n_components > n_features {
97 return Err(TransformError::InvalidInput(format!(
98 "n_components={} must be <= n_features={}",
99 self.n_components, n_features
100 )));
101 }
102
103 let mut x_processed = Array2::zeros((n_samples, n_features));
105 let mut mean = Array1::zeros(n_features);
106 let mut std = Array1::ones(n_features);
107
108 if self.center {
109 for j in 0..n_features {
110 let col_mean = x_f64.column(j).sum() / n_samples as f64;
111 mean[j] = col_mean;
112
113 for i in 0..n_samples {
114 x_processed[[i, j]] = x_f64[[i, j]] - col_mean;
115 }
116 }
117 } else {
118 x_processed.assign(&x_f64);
119 }
120
121 if self.scale {
122 for j in 0..n_features {
123 let col_std =
124 (x_processed.column(j).mapv(|x| x * x).sum() / n_samples as f64).sqrt();
125 if col_std > f64::EPSILON {
126 std[j] = col_std;
127
128 for i in 0..n_samples {
129 x_processed[[i, j]] /= col_std;
130 }
131 }
132 }
133 }
134
135 let (_u, s, vt) = match svd::<f64>(&x_processed.view(), true, None) {
137 Ok(result) => result,
138 Err(e) => return Err(TransformError::LinalgError(e)),
139 };
140
141 let mut components = Array2::zeros((self.n_components, n_features));
143 let mut singular_values = Array1::zeros(self.n_components);
144
145 for i in 0..self.n_components {
146 singular_values[i] = s[i];
147 for j in 0..n_features {
148 components[[i, j]] = vt[[i, j]];
149 }
150 }
151
152 let total_variance = s.mapv(|s| s * s).sum();
154 let explained_variance_ratio = singular_values.mapv(|s| s * s / total_variance);
155
156 self.components = Some(components);
157 self.mean = Some(mean);
158 self.std = Some(std);
159 self.singular_values = Some(singular_values);
160 self.explained_variance_ratio = Some(explained_variance_ratio);
161
162 Ok(())
163 }
164
165 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
173 where
174 S: Data,
175 S::Elem: Float + NumCast,
176 {
177 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
178
179 let n_samples = x_f64.shape()[0];
180 let n_features = x_f64.shape()[1];
181
182 if self.components.is_none() {
183 return Err(TransformError::TransformationError(
184 "PCA model has not been fitted".to_string(),
185 ));
186 }
187
188 let components = self.components.as_ref().unwrap();
189 let mean = self.mean.as_ref().unwrap();
190 let std = self.std.as_ref().unwrap();
191
192 if n_features != components.shape()[1] {
193 return Err(TransformError::InvalidInput(format!(
194 "x has {} features, but PCA was fitted with {} features",
195 n_features,
196 components.shape()[1]
197 )));
198 }
199
200 let mut x_processed = Array2::zeros((n_samples, n_features));
202
203 for i in 0..n_samples {
204 for j in 0..n_features {
205 let mut value = x_f64[[i, j]];
206
207 if self.center {
208 value -= mean[j];
209 }
210
211 if self.scale {
212 value /= std[j];
213 }
214
215 x_processed[[i, j]] = value;
216 }
217 }
218
219 let mut transformed = Array2::zeros((n_samples, self.n_components));
221
222 for i in 0..n_samples {
223 for j in 0..self.n_components {
224 let mut dot_product = 0.0;
225 for k in 0..n_features {
226 dot_product += x_processed[[i, k]] * components[[j, k]];
227 }
228 transformed[[i, j]] = dot_product;
229 }
230 }
231
232 Ok(transformed)
233 }
234
235 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
243 where
244 S: Data,
245 S::Elem: Float + NumCast,
246 {
247 self.fit(x)?;
248 self.transform(x)
249 }
250
251 pub fn components(&self) -> Option<&Array2<f64>> {
256 self.components.as_ref()
257 }
258
259 pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
264 self.explained_variance_ratio.as_ref()
265 }
266}
267
/// Truncated singular value decomposition of the raw (uncentered) data matrix.
///
/// Call `fit` (or `fit_transform`) before `transform`; the `Option` fields below are
/// `None` until the model has been fitted.
#[derive(Debug, Clone)]
pub struct TruncatedSVD {
    /// Number of singular vectors to keep.
    n_components: usize,
    /// The `n_components` leading singular values of the training data.
    singular_values: Option<Array1<f64>>,
    /// Right singular vectors, shape `(n_components, n_features)`: the leading rows of `Vt`.
    components: Option<Array2<f64>>,
    /// Each component's `s^2 / n_samples` divided by the mean squared row norm of the data,
    /// i.e. relative to the uncentered total variance.
    explained_variance_ratio: Option<Array1<f64>>,
}
284
285impl TruncatedSVD {
286 pub fn new(ncomponents: usize) -> Self {
294 TruncatedSVD {
295 n_components: ncomponents,
296 singular_values: None,
297 components: None,
298 explained_variance_ratio: None,
299 }
300 }
301
302 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
310 where
311 S: Data,
312 S::Elem: Float + NumCast,
313 {
314 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
315
316 let n_samples = x_f64.shape()[0];
317 let n_features = x_f64.shape()[1];
318
319 if n_samples == 0 || n_features == 0 {
320 return Err(TransformError::InvalidInput("Empty input data".to_string()));
321 }
322
323 if self.n_components > n_features {
324 return Err(TransformError::InvalidInput(format!(
325 "n_components={} must be <= n_features={}",
326 self.n_components, n_features
327 )));
328 }
329
330 let (_u, s, vt) = match svd::<f64>(&x_f64.view(), true, None) {
332 Ok(result) => result,
333 Err(e) => return Err(TransformError::LinalgError(e)),
334 };
335
336 let mut components = Array2::zeros((self.n_components, n_features));
338 let mut singular_values = Array1::zeros(self.n_components);
339
340 for i in 0..self.n_components {
341 singular_values[i] = s[i];
342 for j in 0..n_features {
343 components[[i, j]] = vt[[i, j]];
344 }
345 }
346
347 let total_variance =
349 (x_f64.map_axis(Axis(1), |row| row.dot(&row)).sum()) / n_samples as f64;
350 let explained_variance = singular_values.mapv(|s| s * s / n_samples as f64);
351 let explained_variance_ratio = explained_variance.mapv(|v| v / total_variance);
352
353 self.singular_values = Some(singular_values);
354 self.components = Some(components);
355 self.explained_variance_ratio = Some(explained_variance_ratio);
356
357 Ok(())
358 }
359
360 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
368 where
369 S: Data,
370 S::Elem: Float + NumCast,
371 {
372 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
373
374 let n_samples = x_f64.shape()[0];
375 let n_features = x_f64.shape()[1];
376
377 if self.components.is_none() {
378 return Err(TransformError::TransformationError(
379 "TruncatedSVD model has not been fitted".to_string(),
380 ));
381 }
382
383 let components = self.components.as_ref().unwrap();
384
385 if n_features != components.shape()[1] {
386 return Err(TransformError::InvalidInput(format!(
387 "x has {} features, but TruncatedSVD was fitted with {} features",
388 n_features,
389 components.shape()[1]
390 )));
391 }
392
393 let mut transformed = Array2::zeros((n_samples, self.n_components));
395
396 for i in 0..n_samples {
397 for j in 0..self.n_components {
398 let mut dot_product = 0.0;
399 for k in 0..n_features {
400 dot_product += x_f64[[i, k]] * components[[j, k]];
401 }
402 transformed[[i, j]] = dot_product;
403 }
404 }
405
406 Ok(transformed)
407 }
408
409 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
417 where
418 S: Data,
419 S::Elem: Float + NumCast,
420 {
421 self.fit(x)?;
422 self.transform(x)
423 }
424
425 pub fn components(&self) -> Option<&Array2<f64>> {
430 self.components.as_ref()
431 }
432
433 pub fn singular_values(&self) -> Option<&Array1<f64>> {
438 self.singular_values.as_ref()
439 }
440
441 pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
446 self.explained_variance_ratio.as_ref()
447 }
448}
449
/// Linear Discriminant Analysis for supervised dimensionality reduction.
///
/// Finds directions that separate the classes by comparing the between-class scatter with
/// the within-class scatter; `solver` selects how that problem is solved. Call `fit` (or
/// `fit_transform`) before `transform`; the `Option` fields are `None` until then.
#[derive(Debug, Clone)]
pub struct LDA {
    /// Number of discriminant directions to keep (`fit` rejects more than `n_classes - 1`).
    n_components: usize,
    /// Solver name; `LDA::new` only accepts `"svd"` or `"eigen"`.
    solver: String,
    /// Discriminant directions, shape `(n_components, n_features)`.
    components: Option<Array2<f64>>,
    /// Per-class feature means, shape `(n_classes, n_features)`; rows follow the order in
    /// which each class label first appears in `y`.
    means: Option<Array2<f64>>,
    /// Each kept direction's eigenvalue divided by the sum of the kept eigenvalues.
    explained_variance_ratio: Option<Array1<f64>>,
}
466
467impl LDA {
468 pub fn new(ncomponents: usize, solver: &str) -> Result<Self> {
477 if solver != "svd" && solver != "eigen" {
478 return Err(TransformError::InvalidInput(
479 "solver must be 'svd' or 'eigen'".to_string(),
480 ));
481 }
482
483 Ok(LDA {
484 n_components: ncomponents,
485 solver: solver.to_string(),
486 components: None,
487 means: None,
488 explained_variance_ratio: None,
489 })
490 }
491
492 pub fn fit<S1, S2>(&mut self, x: &ArrayBase<S1, Ix2>, y: &ArrayBase<S2, Ix1>) -> Result<()>
501 where
502 S1: Data,
503 S2: Data,
504 S1::Elem: Float + NumCast,
505 S2::Elem: Copy + NumCast + Eq + std::hash::Hash,
506 {
507 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
508
509 let n_samples = x_f64.shape()[0];
510 let n_features = x_f64.shape()[1];
511
512 if n_samples == 0 || n_features == 0 {
513 return Err(TransformError::InvalidInput("Empty input data".to_string()));
514 }
515
516 if n_samples != y.len() {
517 return Err(TransformError::InvalidInput(format!(
518 "x and y have incompatible shapes: x has {} samples, y has {} elements",
519 n_samples,
520 y.len()
521 )));
522 }
523
524 let mut class_indices = vec![];
526 let mut class_map = std::collections::HashMap::new();
527 let mut next_class_idx = 0;
528
529 for &label in y.iter() {
530 let label_u64 = NumCast::from(label).unwrap_or(0);
531
532 if let std::collections::hash_map::Entry::Vacant(e) = class_map.entry(label_u64) {
533 e.insert(next_class_idx);
534 next_class_idx += 1;
535 }
536
537 class_indices.push(class_map[&label_u64]);
538 }
539
540 let n_classes = class_map.len();
541
542 if n_classes <= 1 {
543 return Err(TransformError::InvalidInput(
544 "y has less than 2 classes, LDA requires at least 2 classes".to_string(),
545 ));
546 }
547
548 let maxn_components = n_classes - 1;
549 if self.n_components > maxn_components {
550 return Err(TransformError::InvalidInput(format!(
551 "n_components={} must be <= n_classes-1={}",
552 self.n_components, maxn_components
553 )));
554 }
555
556 let mut class_means = Array2::zeros((n_classes, n_features));
558 let mut class_counts = vec![0; n_classes];
559
560 for i in 0..n_samples {
561 let class_idx = class_indices[i];
562 class_counts[class_idx] += 1;
563
564 for j in 0..n_features {
565 class_means[[class_idx, j]] += x_f64[[i, j]];
566 }
567 }
568
569 for i in 0..n_classes {
570 if class_counts[i] > 0 {
571 for j in 0..n_features {
572 class_means[[i, j]] /= class_counts[i] as f64;
573 }
574 }
575 }
576
577 let mut global_mean = Array1::<f64>::zeros(n_features);
579 for i in 0..n_samples {
580 for j in 0..n_features {
581 global_mean[j] += x_f64[[i, j]];
582 }
583 }
584 global_mean.mapv_inplace(|x: f64| x / n_samples as f64);
585
586 let mut sw = Array2::<f64>::zeros((n_features, n_features));
588 for i in 0..n_samples {
589 let class_idx = class_indices[i];
590 let mut x_centered = Array1::<f64>::zeros(n_features);
591
592 for j in 0..n_features {
593 x_centered[j] = x_f64[[i, j]] - class_means[[class_idx, j]];
594 }
595
596 for j in 0..n_features {
597 for k in 0..n_features {
598 sw[[j, k]] += x_centered[j] * x_centered[k];
599 }
600 }
601 }
602
603 let mut sb = Array2::<f64>::zeros((n_features, n_features));
605 for i in 0..n_classes {
606 let mut mean_diff = Array1::<f64>::zeros(n_features);
607 for j in 0..n_features {
608 mean_diff[j] = class_means[[i, j]] - global_mean[j];
609 }
610
611 for j in 0..n_features {
612 for k in 0..n_features {
613 sb[[j, k]] += class_counts[i] as f64 * mean_diff[j] * mean_diff[k];
614 }
615 }
616 }
617
618 let mut components = Array2::<f64>::zeros((self.n_components, n_features));
620 let mut eigenvalues = Array1::<f64>::zeros(self.n_components);
621
622 if self.solver == "svd" {
623 let (u_sw, s_sw, vt_sw) = match svd::<f64>(&sw.view(), true, None) {
627 Ok(result) => result,
628 Err(e) => return Err(TransformError::LinalgError(e)),
629 };
630
631 let mut sw_sqrt_inv = Array2::<f64>::zeros((n_features, n_features));
633 for i in 0..n_features {
634 if s_sw[i] > EPSILON {
635 for j in 0..n_features {
636 for k in 0..n_features {
637 let s_inv_sqrt = 1.0 / s_sw[i].sqrt();
638 sw_sqrt_inv[[j, k]] += u_sw[[j, i]] * s_inv_sqrt * vt_sw[[i, k]];
639 }
640 }
641 }
642 }
643
644 let mut sb_transformed = Array2::<f64>::zeros((n_features, n_features));
646 for i in 0..n_features {
647 for j in 0..n_features {
648 for k in 0..n_features {
649 for l in 0..n_features {
650 sb_transformed[[i, j]] +=
651 sw_sqrt_inv[[i, k]] * sb[[k, l]] * sw_sqrt_inv[[l, j]];
652 }
653 }
654 }
655 }
656
657 let (u_sb, s_sb, vt_sb) = match svd::<f64>(&sb_transformed.view(), true, None) {
659 Ok(result) => result,
660 Err(e) => return Err(TransformError::LinalgError(e)),
661 };
662
663 for i in 0..self.n_components {
665 eigenvalues[i] = s_sb[i];
666
667 for j in 0..n_features {
668 for k in 0..n_features {
669 components[[i, j]] += sw_sqrt_inv[[k, j]] * u_sb[[k, i]];
670 }
671 }
672 }
673 } else {
674 let mut sw_reg = sw.clone();
679 for i in 0..n_features {
680 sw_reg[[i, i]] += EPSILON; }
682
683 let (u_sw, s_sw, vt_sw) = match svd::<f64>(&sw_reg.view(), true, None) {
686 Ok(result) => result,
687 Err(e) => return Err(TransformError::LinalgError(e)),
688 };
689
690 let mut sw_inv = Array2::<f64>::zeros((n_features, n_features));
692 for i in 0..n_features {
693 if s_sw[i] > EPSILON {
694 for j in 0..n_features {
695 for k in 0..n_features {
696 sw_inv[[j, k]] += u_sw[[j, i]] * (1.0 / s_sw[i]) * vt_sw[[i, k]];
697 }
698 }
699 }
700 }
701
702 let mut sw_inv_sb = Array2::<f64>::zeros((n_features, n_features));
704 for i in 0..n_features {
705 for j in 0..n_features {
706 for k in 0..n_features {
707 sw_inv_sb[[i, j]] += sw_inv[[i, k]] * sb[[k, j]];
708 }
709 }
710 }
711
712 let mut sym_matrix = Array2::<f64>::zeros((n_features, n_features));
716 for i in 0..n_features {
717 for j in 0..n_features {
718 sym_matrix[[i, j]] = (sw_inv_sb[[i, j]] + sw_inv_sb[[j, i]]) / 2.0;
719 }
720 }
721
722 let (eig_vals, eig_vecs) = match scirs2_linalg::eigh::<f64>(&sym_matrix.view(), None) {
724 Ok(result) => result,
725 Err(_) => {
726 let (u, s, vt) = match svd::<f64>(&sw_inv_sb.view(), true, None) {
728 Ok(result) => result,
729 Err(e) => return Err(TransformError::LinalgError(e)),
730 };
731 (s, u)
732 }
733 };
734
735 let mut indices: Vec<usize> = (0..n_features).collect();
737 indices.sort_by(|&i, &j| eig_vals[j].partial_cmp(&eig_vals[i]).unwrap());
738
739 for i in 0..self.n_components {
741 let idx = indices[i];
742 eigenvalues[i] = eig_vals[idx].max(0.0); for j in 0..n_features {
745 components[[i, j]] = eig_vecs[[j, idx]];
746 }
747 }
748
749 for i in 0..self.n_components {
751 let mut norm = 0.0;
752 for j in 0..n_features {
753 norm += components[[i, j]] * components[[i, j]];
754 }
755 norm = norm.sqrt();
756
757 if norm > EPSILON {
758 for j in 0..n_features {
759 components[[i, j]] /= norm;
760 }
761 }
762 }
763 }
764
765 let total_eigenvalues = eigenvalues.iter().sum::<f64>();
767 let explained_variance_ratio = eigenvalues.mapv(|e| e / total_eigenvalues);
768
769 self.components = Some(components);
770 self.means = Some(class_means);
771 self.explained_variance_ratio = Some(explained_variance_ratio);
772
773 Ok(())
774 }
775
776 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
784 where
785 S: Data,
786 S::Elem: Float + NumCast,
787 {
788 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
789
790 let n_samples = x_f64.shape()[0];
791 let n_features = x_f64.shape()[1];
792
793 if self.components.is_none() {
794 return Err(TransformError::TransformationError(
795 "LDA model has not been fitted".to_string(),
796 ));
797 }
798
799 let components = self.components.as_ref().unwrap();
800
801 if n_features != components.shape()[1] {
802 return Err(TransformError::InvalidInput(format!(
803 "x has {} features, but LDA was fitted with {} features",
804 n_features,
805 components.shape()[1]
806 )));
807 }
808
809 let mut transformed = Array2::zeros((n_samples, self.n_components));
811
812 for i in 0..n_samples {
813 for j in 0..self.n_components {
814 let mut dot_product = 0.0;
815 for k in 0..n_features {
816 dot_product += x_f64[[i, k]] * components[[j, k]];
817 }
818 transformed[[i, j]] = dot_product;
819 }
820 }
821
822 Ok(transformed)
823 }
824
825 pub fn fit_transform<S1, S2>(
834 &mut self,
835 x: &ArrayBase<S1, Ix2>,
836 y: &ArrayBase<S2, Ix1>,
837 ) -> Result<Array2<f64>>
838 where
839 S1: Data,
840 S2: Data,
841 S1::Elem: Float + NumCast,
842 S2::Elem: Copy + NumCast + Eq + std::hash::Hash,
843 {
844 self.fit(x, y)?;
845 self.transform(x)
846 }
847
848 pub fn components(&self) -> Option<&Array2<f64>> {
853 self.components.as_ref()
854 }
855
856 pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
861 self.explained_variance_ratio.as_ref()
862 }
863}
864
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// 4x3 fixture holding 1.0, 2.0, ..., 12.0 in row-major order.
    fn sequential_4x3() -> Array2<f64> {
        Array::from_shape_vec((4, 3), (1..=12_i32).map(f64::from).collect())
            .expect("twelve values fill a 4x3 matrix")
    }

    #[test]
    fn test_pca_transform() {
        let data = sequential_4x3();

        let mut model = PCA::new(2, true, false);
        let projected = model.fit_transform(&data).unwrap();
        assert_eq!(projected.shape(), &[4, 2]);

        let ratios = model.explained_variance_ratio().unwrap();
        assert_eq!(ratios.len(), 2);

        let total = ratios.sum();
        assert!(total > 0.0 && total.is_finite());
    }

    #[test]
    fn test_truncated_svd() {
        let data = sequential_4x3();

        let mut model = TruncatedSVD::new(2);
        let projected = model.fit_transform(&data).unwrap();
        assert_eq!(projected.shape(), &[4, 2]);

        let ratios = model.explained_variance_ratio().unwrap();
        assert_eq!(ratios.len(), 2);

        let total = ratios.sum();
        assert!(total > 0.0 && total.is_finite());
    }

    #[test]
    fn test_lda() {
        let data = Array::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 4.0, 6.0, 5.0, 7.0, 4.0],
        )
        .unwrap();
        let labels = Array::from_vec(vec![0, 0, 0, 1, 1, 1]);

        let mut model = LDA::new(1, "svd").unwrap();
        let projected = model.fit_transform(&data, &labels).unwrap();
        assert_eq!(projected.shape(), &[6, 1]);

        // A single kept direction carries the whole (kept-eigenvalue) total.
        let ratios = model.explained_variance_ratio().unwrap();
        assert_abs_diff_eq!(ratios[0], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_lda_eigen_solver() {
        let data = Array::from_shape_vec(
            (9, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, // class 0
                5.0, 4.0, 6.0, 5.0, 7.0, 4.0, // class 1
                9.0, 8.0, 10.0, 9.0, 11.0, 10.0, // class 2
            ],
        )
        .unwrap();
        let labels = Array::from_vec(vec![0, 0, 0, 1, 1, 1, 2, 2, 2]);

        let mut eigen_model = LDA::new(2, "eigen").unwrap();
        let eigen_projected = eigen_model.fit_transform(&data, &labels).unwrap();

        let mut svd_model = LDA::new(2, "svd").unwrap();
        let svd_projected = svd_model.fit_transform(&data, &labels).unwrap();

        let eigen_ratios = eigen_model.explained_variance_ratio().unwrap();
        let svd_ratios = svd_model.explained_variance_ratio().unwrap();

        // Both solvers must satisfy the same structural guarantees.
        for (projected, ratios) in [
            (&eigen_projected, eigen_ratios),
            (&svd_projected, svd_ratios),
        ] {
            assert_eq!(projected.shape(), &[9, 2]);
            assert!(projected.iter().all(|v| v.is_finite()));

            assert_eq!(ratios.len(), 2);
            assert_abs_diff_eq!(ratios.sum(), 1.0, epsilon = 1e-10);
            assert!(ratios.iter().all(|&r| r >= 0.0));
        }
    }

    #[test]
    fn test_lda_invalid_solver() {
        match LDA::new(1, "invalid") {
            Ok(_) => panic!("an unknown solver name must be rejected"),
            Err(err) => assert!(err
                .to_string()
                .contains("solver must be 'svd' or 'eigen'")),
        }
    }
}