1pub mod factor_analysis;
9mod isomap;
10mod lle;
11mod spectral_embedding;
12mod tsne;
13mod umap;
14
15pub mod laplacian_eigenmaps;
17
18pub mod diffusion_maps;
20
21pub use crate::reduction::diffusion_maps::DiffusionMaps;
22pub use crate::reduction::factor_analysis::{
23 factor_analysis, scree_plot_data, FactorAnalysis, FactorAnalysisResult, RotationMethod,
24 ScreePlotData,
25};
26pub use crate::reduction::isomap::Isomap;
27pub use crate::reduction::laplacian_eigenmaps::{
28 GraphMethod, LaplacianEigenmaps, LaplacianType as LELaplacianType,
29};
30pub use crate::reduction::lle::LLE;
31pub use crate::reduction::spectral_embedding::{AffinityMethod, SpectralEmbedding};
32pub use crate::reduction::tsne::{trustworthiness, TSNE};
33pub use crate::reduction::umap::UMAP;
34
35use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
36use scirs2_core::numeric::{Float, NumCast};
37use scirs2_linalg::svd;
38
39use crate::error::{Result, TransformError};
40
/// Small numerical tolerance used by the LDA solvers: singular values at or
/// below it are dropped when building pseudo-inverses, it is added to the
/// diagonal of the within-class scatter matrix as a ridge in the "eigen"
/// solver, and component vectors with a norm at or below it are not normalized.
const EPSILON: f64 = 1e-10;
43
/// Principal Component Analysis via a singular value decomposition of the
/// (optionally centered and scaled) input matrix.
#[derive(Debug, Clone)]
pub struct PCA {
    /// Number of principal components kept by `fit`.
    n_components: usize,
    /// Whether each feature has its mean subtracted before the SVD.
    center: bool,
    /// Whether each (possibly centered) feature is divided by its
    /// root-mean-square before the SVD.
    scale: bool,
    /// `(n_components, n_features)` matrix whose rows are the leading right
    /// singular vectors; `None` until fitted.
    components: Option<Array2<f64>>,
    /// Per-feature means computed by `fit` (all zeros when `center` is false).
    mean: Option<Array1<f64>>,
    /// Per-feature scale divisors computed by `fit` (ones when `scale` is false
    /// or when a column's scale is below `f64::EPSILON`).
    std: Option<Array1<f64>>,
    /// The leading `n_components` singular values.
    singular_values: Option<Array1<f64>>,
    /// Squared singular value of each kept component divided by the sum of all
    /// squared singular values.
    explained_variance_ratio: Option<Array1<f64>>,
}
67
68impl PCA {
69 pub fn new(ncomponents: usize, center: bool, scale: bool) -> Self {
79 PCA {
80 n_components: ncomponents,
81 center,
82 scale,
83 components: None,
84 mean: None,
85 std: None,
86 singular_values: None,
87 explained_variance_ratio: None,
88 }
89 }
90
91 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
99 where
100 S: Data,
101 S::Elem: Float + NumCast,
102 {
103 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
104
105 let n_samples = x_f64.shape()[0];
106 let n_features = x_f64.shape()[1];
107
108 if n_samples == 0 || n_features == 0 {
109 return Err(TransformError::InvalidInput("Empty input data".to_string()));
110 }
111
112 if self.n_components > n_features {
113 return Err(TransformError::InvalidInput(format!(
114 "n_components={} must be <= n_features={}",
115 self.n_components, n_features
116 )));
117 }
118
119 let mut x_processed = Array2::zeros((n_samples, n_features));
121 let mut mean = Array1::zeros(n_features);
122 let mut std = Array1::ones(n_features);
123
124 if self.center {
125 for j in 0..n_features {
126 let col_mean = x_f64.column(j).sum() / n_samples as f64;
127 mean[j] = col_mean;
128
129 for i in 0..n_samples {
130 x_processed[[i, j]] = x_f64[[i, j]] - col_mean;
131 }
132 }
133 } else {
134 x_processed.assign(&x_f64);
135 }
136
137 if self.scale {
138 for j in 0..n_features {
139 let col_std =
140 (x_processed.column(j).mapv(|x| x * x).sum() / n_samples as f64).sqrt();
141 if col_std > f64::EPSILON {
142 std[j] = col_std;
143
144 for i in 0..n_samples {
145 x_processed[[i, j]] /= col_std;
146 }
147 }
148 }
149 }
150
151 let (_u, s, vt) = match svd::<f64>(&x_processed.view(), true, None) {
153 Ok(result) => result,
154 Err(e) => return Err(TransformError::LinalgError(e)),
155 };
156
157 let mut components = Array2::zeros((self.n_components, n_features));
159 let mut singular_values = Array1::zeros(self.n_components);
160
161 for i in 0..self.n_components {
162 singular_values[i] = s[i];
163 for j in 0..n_features {
164 components[[i, j]] = vt[[i, j]];
165 }
166 }
167
168 let total_variance = s.mapv(|s| s * s).sum();
170 let explained_variance_ratio = singular_values.mapv(|s| s * s / total_variance);
171
172 self.components = Some(components);
173 self.mean = Some(mean);
174 self.std = Some(std);
175 self.singular_values = Some(singular_values);
176 self.explained_variance_ratio = Some(explained_variance_ratio);
177
178 Ok(())
179 }
180
181 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
189 where
190 S: Data,
191 S::Elem: Float + NumCast,
192 {
193 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
194
195 let n_samples = x_f64.shape()[0];
196 let n_features = x_f64.shape()[1];
197
198 if self.components.is_none() {
199 return Err(TransformError::TransformationError(
200 "PCA model has not been fitted".to_string(),
201 ));
202 }
203
204 let components = self.components.as_ref().expect("Operation failed");
205 let mean = self.mean.as_ref().expect("Operation failed");
206 let std = self.std.as_ref().expect("Operation failed");
207
208 if n_features != components.shape()[1] {
209 return Err(TransformError::InvalidInput(format!(
210 "x has {} features, but PCA was fitted with {} features",
211 n_features,
212 components.shape()[1]
213 )));
214 }
215
216 let mut x_processed = Array2::zeros((n_samples, n_features));
218
219 for i in 0..n_samples {
220 for j in 0..n_features {
221 let mut value = x_f64[[i, j]];
222
223 if self.center {
224 value -= mean[j];
225 }
226
227 if self.scale {
228 value /= std[j];
229 }
230
231 x_processed[[i, j]] = value;
232 }
233 }
234
235 let mut transformed = Array2::zeros((n_samples, self.n_components));
237
238 for i in 0..n_samples {
239 for j in 0..self.n_components {
240 let mut dot_product = 0.0;
241 for k in 0..n_features {
242 dot_product += x_processed[[i, k]] * components[[j, k]];
243 }
244 transformed[[i, j]] = dot_product;
245 }
246 }
247
248 Ok(transformed)
249 }
250
251 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
259 where
260 S: Data,
261 S::Elem: Float + NumCast,
262 {
263 self.fit(x)?;
264 self.transform(x)
265 }
266
267 pub fn components(&self) -> Option<&Array2<f64>> {
272 self.components.as_ref()
273 }
274
275 pub fn mean(&self) -> Option<&Array1<f64>> {
280 self.mean.as_ref()
281 }
282
283 pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
288 self.explained_variance_ratio.as_ref()
289 }
290}
291
/// Truncated singular value decomposition of the raw (uncentered) input
/// matrix.
#[derive(Debug, Clone)]
pub struct TruncatedSVD {
    /// Number of singular components kept by `fit`.
    n_components: usize,
    /// The leading `n_components` singular values.
    singular_values: Option<Array1<f64>>,
    /// `(n_components, n_features)` matrix whose rows are the leading right
    /// singular vectors; `None` until fitted.
    components: Option<Array2<f64>>,
    /// Per-component `s_i^2 / n_samples` divided by the mean squared row norm
    /// of the training data.
    explained_variance_ratio: Option<Array1<f64>>,
}
308
309impl TruncatedSVD {
310 pub fn new(ncomponents: usize) -> Self {
318 TruncatedSVD {
319 n_components: ncomponents,
320 singular_values: None,
321 components: None,
322 explained_variance_ratio: None,
323 }
324 }
325
326 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
334 where
335 S: Data,
336 S::Elem: Float + NumCast,
337 {
338 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
339
340 let n_samples = x_f64.shape()[0];
341 let n_features = x_f64.shape()[1];
342
343 if n_samples == 0 || n_features == 0 {
344 return Err(TransformError::InvalidInput("Empty input data".to_string()));
345 }
346
347 if self.n_components > n_features {
348 return Err(TransformError::InvalidInput(format!(
349 "n_components={} must be <= n_features={}",
350 self.n_components, n_features
351 )));
352 }
353
354 let (_u, s, vt) = match svd::<f64>(&x_f64.view(), true, None) {
356 Ok(result) => result,
357 Err(e) => return Err(TransformError::LinalgError(e)),
358 };
359
360 let mut components = Array2::zeros((self.n_components, n_features));
362 let mut singular_values = Array1::zeros(self.n_components);
363
364 for i in 0..self.n_components {
365 singular_values[i] = s[i];
366 for j in 0..n_features {
367 components[[i, j]] = vt[[i, j]];
368 }
369 }
370
371 let total_variance =
373 (x_f64.map_axis(Axis(1), |row| row.dot(&row)).sum()) / n_samples as f64;
374 let explained_variance = singular_values.mapv(|s| s * s / n_samples as f64);
375 let explained_variance_ratio = explained_variance.mapv(|v| v / total_variance);
376
377 self.singular_values = Some(singular_values);
378 self.components = Some(components);
379 self.explained_variance_ratio = Some(explained_variance_ratio);
380
381 Ok(())
382 }
383
384 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
392 where
393 S: Data,
394 S::Elem: Float + NumCast,
395 {
396 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
397
398 let n_samples = x_f64.shape()[0];
399 let n_features = x_f64.shape()[1];
400
401 if self.components.is_none() {
402 return Err(TransformError::TransformationError(
403 "TruncatedSVD model has not been fitted".to_string(),
404 ));
405 }
406
407 let components = self.components.as_ref().expect("Operation failed");
408
409 if n_features != components.shape()[1] {
410 return Err(TransformError::InvalidInput(format!(
411 "x has {} features, but TruncatedSVD was fitted with {} features",
412 n_features,
413 components.shape()[1]
414 )));
415 }
416
417 let mut transformed = Array2::zeros((n_samples, self.n_components));
419
420 for i in 0..n_samples {
421 for j in 0..self.n_components {
422 let mut dot_product = 0.0;
423 for k in 0..n_features {
424 dot_product += x_f64[[i, k]] * components[[j, k]];
425 }
426 transformed[[i, j]] = dot_product;
427 }
428 }
429
430 Ok(transformed)
431 }
432
433 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
441 where
442 S: Data,
443 S::Elem: Float + NumCast,
444 {
445 self.fit(x)?;
446 self.transform(x)
447 }
448
449 pub fn components(&self) -> Option<&Array2<f64>> {
454 self.components.as_ref()
455 }
456
457 pub fn singular_values(&self) -> Option<&Array1<f64>> {
462 self.singular_values.as_ref()
463 }
464
465 pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
470 self.explained_variance_ratio.as_ref()
471 }
472}
473
/// Linear Discriminant Analysis projecting samples onto directions that
/// separate the labelled classes.
#[derive(Debug, Clone)]
pub struct LDA {
    /// Number of discriminant components kept by `fit`.
    n_components: usize,
    /// Solver name; `new` only accepts `"svd"` or `"eigen"`.
    solver: String,
    /// `(n_components, n_features)` matrix of discriminant directions (rows);
    /// `None` until fitted.
    components: Option<Array2<f64>>,
    /// `(n_classes, n_features)` per-class feature means, with classes ordered
    /// by first appearance of their label in `y`.
    means: Option<Array2<f64>>,
    /// Each kept eigenvalue divided by the sum of the kept eigenvalues.
    explained_variance_ratio: Option<Array1<f64>>,
}
490
491impl LDA {
492 pub fn new(ncomponents: usize, solver: &str) -> Result<Self> {
501 if solver != "svd" && solver != "eigen" {
502 return Err(TransformError::InvalidInput(
503 "solver must be 'svd' or 'eigen'".to_string(),
504 ));
505 }
506
507 Ok(LDA {
508 n_components: ncomponents,
509 solver: solver.to_string(),
510 components: None,
511 means: None,
512 explained_variance_ratio: None,
513 })
514 }
515
516 pub fn fit<S1, S2>(&mut self, x: &ArrayBase<S1, Ix2>, y: &ArrayBase<S2, Ix1>) -> Result<()>
525 where
526 S1: Data,
527 S2: Data,
528 S1::Elem: Float + NumCast,
529 S2::Elem: Copy + NumCast + Eq + std::hash::Hash,
530 {
531 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
532
533 let n_samples = x_f64.shape()[0];
534 let n_features = x_f64.shape()[1];
535
536 if n_samples == 0 || n_features == 0 {
537 return Err(TransformError::InvalidInput("Empty input data".to_string()));
538 }
539
540 if n_samples != y.len() {
541 return Err(TransformError::InvalidInput(format!(
542 "x and y have incompatible shapes: x has {} samples, y has {} elements",
543 n_samples,
544 y.len()
545 )));
546 }
547
548 let mut class_indices = vec![];
550 let mut class_map = std::collections::HashMap::new();
551 let mut next_class_idx = 0;
552
553 for &label in y.iter() {
554 let label_u64 = NumCast::from(label).unwrap_or(0);
555
556 if let std::collections::hash_map::Entry::Vacant(e) = class_map.entry(label_u64) {
557 e.insert(next_class_idx);
558 next_class_idx += 1;
559 }
560
561 class_indices.push(class_map[&label_u64]);
562 }
563
564 let n_classes = class_map.len();
565
566 if n_classes <= 1 {
567 return Err(TransformError::InvalidInput(
568 "y has less than 2 classes, LDA requires at least 2 classes".to_string(),
569 ));
570 }
571
572 let maxn_components = n_classes - 1;
573 if self.n_components > maxn_components {
574 return Err(TransformError::InvalidInput(format!(
575 "n_components={} must be <= n_classes-1={}",
576 self.n_components, maxn_components
577 )));
578 }
579
580 let mut class_means = Array2::zeros((n_classes, n_features));
582 let mut class_counts = vec![0; n_classes];
583
584 for i in 0..n_samples {
585 let class_idx = class_indices[i];
586 class_counts[class_idx] += 1;
587
588 for j in 0..n_features {
589 class_means[[class_idx, j]] += x_f64[[i, j]];
590 }
591 }
592
593 for i in 0..n_classes {
594 if class_counts[i] > 0 {
595 for j in 0..n_features {
596 class_means[[i, j]] /= class_counts[i] as f64;
597 }
598 }
599 }
600
601 let mut global_mean = Array1::<f64>::zeros(n_features);
603 for i in 0..n_samples {
604 for j in 0..n_features {
605 global_mean[j] += x_f64[[i, j]];
606 }
607 }
608 global_mean.mapv_inplace(|x: f64| x / n_samples as f64);
609
610 let mut sw = Array2::<f64>::zeros((n_features, n_features));
612 for i in 0..n_samples {
613 let class_idx = class_indices[i];
614 let mut x_centered = Array1::<f64>::zeros(n_features);
615
616 for j in 0..n_features {
617 x_centered[j] = x_f64[[i, j]] - class_means[[class_idx, j]];
618 }
619
620 for j in 0..n_features {
621 for k in 0..n_features {
622 sw[[j, k]] += x_centered[j] * x_centered[k];
623 }
624 }
625 }
626
627 let mut sb = Array2::<f64>::zeros((n_features, n_features));
629 for i in 0..n_classes {
630 let mut mean_diff = Array1::<f64>::zeros(n_features);
631 for j in 0..n_features {
632 mean_diff[j] = class_means[[i, j]] - global_mean[j];
633 }
634
635 for j in 0..n_features {
636 for k in 0..n_features {
637 sb[[j, k]] += class_counts[i] as f64 * mean_diff[j] * mean_diff[k];
638 }
639 }
640 }
641
642 let mut components = Array2::<f64>::zeros((self.n_components, n_features));
644 let mut eigenvalues = Array1::<f64>::zeros(self.n_components);
645
646 if self.solver == "svd" {
647 let (u_sw, s_sw, vt_sw) = match svd::<f64>(&sw.view(), true, None) {
651 Ok(result) => result,
652 Err(e) => return Err(TransformError::LinalgError(e)),
653 };
654
655 let mut sw_sqrt_inv = Array2::<f64>::zeros((n_features, n_features));
657 for i in 0..n_features {
658 if s_sw[i] > EPSILON {
659 for j in 0..n_features {
660 for k in 0..n_features {
661 let s_inv_sqrt = 1.0 / s_sw[i].sqrt();
662 sw_sqrt_inv[[j, k]] += u_sw[[j, i]] * s_inv_sqrt * vt_sw[[i, k]];
663 }
664 }
665 }
666 }
667
668 let mut sb_transformed = Array2::<f64>::zeros((n_features, n_features));
670 for i in 0..n_features {
671 for j in 0..n_features {
672 for k in 0..n_features {
673 for l in 0..n_features {
674 sb_transformed[[i, j]] +=
675 sw_sqrt_inv[[i, k]] * sb[[k, l]] * sw_sqrt_inv[[l, j]];
676 }
677 }
678 }
679 }
680
681 let (u_sb, s_sb, vt_sb) = match svd::<f64>(&sb_transformed.view(), true, None) {
683 Ok(result) => result,
684 Err(e) => return Err(TransformError::LinalgError(e)),
685 };
686
687 for i in 0..self.n_components {
689 eigenvalues[i] = s_sb[i];
690
691 for j in 0..n_features {
692 for k in 0..n_features {
693 components[[i, j]] += sw_sqrt_inv[[k, j]] * u_sb[[k, i]];
694 }
695 }
696 }
697 } else {
698 let mut sw_reg = sw.clone();
703 for i in 0..n_features {
704 sw_reg[[i, i]] += EPSILON; }
706
707 let (u_sw, s_sw, vt_sw) = match svd::<f64>(&sw_reg.view(), true, None) {
710 Ok(result) => result,
711 Err(e) => return Err(TransformError::LinalgError(e)),
712 };
713
714 let mut sw_inv = Array2::<f64>::zeros((n_features, n_features));
716 for i in 0..n_features {
717 if s_sw[i] > EPSILON {
718 for j in 0..n_features {
719 for k in 0..n_features {
720 sw_inv[[j, k]] += u_sw[[j, i]] * (1.0 / s_sw[i]) * vt_sw[[i, k]];
721 }
722 }
723 }
724 }
725
726 let mut sw_inv_sb = Array2::<f64>::zeros((n_features, n_features));
728 for i in 0..n_features {
729 for j in 0..n_features {
730 for k in 0..n_features {
731 sw_inv_sb[[i, j]] += sw_inv[[i, k]] * sb[[k, j]];
732 }
733 }
734 }
735
736 let mut sym_matrix = Array2::<f64>::zeros((n_features, n_features));
740 for i in 0..n_features {
741 for j in 0..n_features {
742 sym_matrix[[i, j]] = (sw_inv_sb[[i, j]] + sw_inv_sb[[j, i]]) / 2.0;
743 }
744 }
745
746 let (eig_vals, eig_vecs) = match scirs2_linalg::eigh::<f64>(&sym_matrix.view(), None) {
748 Ok(result) => result,
749 Err(_) => {
750 let (u, s, vt) = match svd::<f64>(&sw_inv_sb.view(), true, None) {
752 Ok(result) => result,
753 Err(e) => return Err(TransformError::LinalgError(e)),
754 };
755 (s, u)
756 }
757 };
758
759 let mut indices: Vec<usize> = (0..n_features).collect();
761 indices.sort_by(|&i, &j| {
762 eig_vals[j]
763 .partial_cmp(&eig_vals[i])
764 .expect("Operation failed")
765 });
766
767 for i in 0..self.n_components {
769 let idx = indices[i];
770 eigenvalues[i] = eig_vals[idx].max(0.0); for j in 0..n_features {
773 components[[i, j]] = eig_vecs[[j, idx]];
774 }
775 }
776
777 for i in 0..self.n_components {
779 let mut norm = 0.0;
780 for j in 0..n_features {
781 norm += components[[i, j]] * components[[i, j]];
782 }
783 norm = norm.sqrt();
784
785 if norm > EPSILON {
786 for j in 0..n_features {
787 components[[i, j]] /= norm;
788 }
789 }
790 }
791 }
792
793 let total_eigenvalues = eigenvalues.iter().sum::<f64>();
795 let explained_variance_ratio = eigenvalues.mapv(|e| e / total_eigenvalues);
796
797 self.components = Some(components);
798 self.means = Some(class_means);
799 self.explained_variance_ratio = Some(explained_variance_ratio);
800
801 Ok(())
802 }
803
804 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
812 where
813 S: Data,
814 S::Elem: Float + NumCast,
815 {
816 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
817
818 let n_samples = x_f64.shape()[0];
819 let n_features = x_f64.shape()[1];
820
821 if self.components.is_none() {
822 return Err(TransformError::TransformationError(
823 "LDA model has not been fitted".to_string(),
824 ));
825 }
826
827 let components = self.components.as_ref().expect("Operation failed");
828
829 if n_features != components.shape()[1] {
830 return Err(TransformError::InvalidInput(format!(
831 "x has {} features, but LDA was fitted with {} features",
832 n_features,
833 components.shape()[1]
834 )));
835 }
836
837 let mut transformed = Array2::zeros((n_samples, self.n_components));
839
840 for i in 0..n_samples {
841 for j in 0..self.n_components {
842 let mut dot_product = 0.0;
843 for k in 0..n_features {
844 dot_product += x_f64[[i, k]] * components[[j, k]];
845 }
846 transformed[[i, j]] = dot_product;
847 }
848 }
849
850 Ok(transformed)
851 }
852
853 pub fn fit_transform<S1, S2>(
862 &mut self,
863 x: &ArrayBase<S1, Ix2>,
864 y: &ArrayBase<S2, Ix1>,
865 ) -> Result<Array2<f64>>
866 where
867 S1: Data,
868 S2: Data,
869 S1::Elem: Float + NumCast,
870 S2::Elem: Copy + NumCast + Eq + std::hash::Hash,
871 {
872 self.fit(x, y)?;
873 self.transform(x)
874 }
875
876 pub fn components(&self) -> Option<&Array2<f64>> {
881 self.components.as_ref()
882 }
883
884 pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
889 self.explained_variance_ratio.as_ref()
890 }
891}
892
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// 4x3 matrix filled row-major with 1.0..=12.0, shared by the PCA and
    /// truncated-SVD tests.
    fn consecutive_4x3() -> scirs2_core::ndarray::Array2<f64> {
        Array::from_shape_vec((4, 3), (1..=12).map(f64::from).collect())
            .expect("Operation failed")
    }

    #[test]
    fn test_pca_transform() {
        let data = consecutive_4x3();

        let mut model = PCA::new(2, true, false);
        let projected = model.fit_transform(&data).expect("Operation failed");
        assert_eq!(projected.shape(), &[4, 2]);

        let ratios = model
            .explained_variance_ratio()
            .expect("Operation failed");
        assert_eq!(ratios.len(), 2);

        let total = ratios.sum();
        assert!(total > 0.0 && total.is_finite());
    }

    #[test]
    fn test_truncated_svd() {
        let data = consecutive_4x3();

        let mut model = TruncatedSVD::new(2);
        let projected = model.fit_transform(&data).expect("Operation failed");
        assert_eq!(projected.shape(), &[4, 2]);

        let ratios = model
            .explained_variance_ratio()
            .expect("Operation failed");
        assert_eq!(ratios.len(), 2);

        let total = ratios.sum();
        assert!(total > 0.0 && total.is_finite());
    }

    #[test]
    fn test_lda() {
        let samples = Array::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 4.0, 6.0, 5.0, 7.0, 4.0],
        )
        .expect("Operation failed");
        let labels = Array::from_vec(vec![0, 0, 0, 1, 1, 1]);

        let mut model = LDA::new(1, "svd").expect("Operation failed");
        let projected = model
            .fit_transform(&samples, &labels)
            .expect("Operation failed");
        assert_eq!(projected.shape(), &[6, 1]);

        let ratios = model
            .explained_variance_ratio()
            .expect("Operation failed");
        assert_abs_diff_eq!(ratios[0], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_lda_eigen_solver() {
        let samples = Array::from_shape_vec(
            (9, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, // class 0
                5.0, 4.0, 6.0, 5.0, 7.0, 4.0, // class 1
                9.0, 8.0, 10.0, 9.0, 11.0, 10.0, // class 2
            ],
        )
        .expect("Operation failed");
        let labels = Array::from_vec(vec![0, 0, 0, 1, 1, 1, 2, 2, 2]);

        // Both solvers must satisfy the same invariants on this data.
        for &solver in &["eigen", "svd"] {
            let mut model = LDA::new(2, solver).expect("Operation failed");
            let projected = model
                .fit_transform(&samples, &labels)
                .expect("Operation failed");

            assert_eq!(projected.shape(), &[9, 2]);
            assert!(projected.iter().all(|&v| v.is_finite()));

            let ratios = model
                .explained_variance_ratio()
                .expect("Operation failed");
            assert_eq!(ratios.len(), 2);
            assert_abs_diff_eq!(ratios.sum(), 1.0, epsilon = 1e-10);
            assert!(ratios.iter().all(|&r| r >= 0.0));
        }
    }

    #[test]
    fn test_lda_invalid_solver() {
        let outcome = LDA::new(1, "invalid");
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("solver must be 'svd' or 'eigen'"));
    }
}