scirs2_transform/reduction/mod.rs

//! Dimensionality reduction techniques
//!
//! This module provides algorithms for reducing the dimensionality of data,
//! which is useful for visualization, feature extraction, and reducing
//! computational complexity.

mod isomap;
mod lle;
mod spectral_embedding;
mod tsne;
mod umap;

pub use crate::reduction::isomap::Isomap;
pub use crate::reduction::lle::LLE;
pub use crate::reduction::spectral_embedding::{AffinityMethod, SpectralEmbedding};
pub use crate::reduction::tsne::{trustworthiness, TSNE};
pub use crate::reduction::umap::UMAP;

use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_linalg::svd;

use crate::error::{Result, TransformError};

// Define a small value to use for comparison with zero
const EPSILON: f64 = 1e-10;

/// Principal Component Analysis (PCA) dimensionality reduction
///
/// PCA finds the directions of maximum variance in the data and
/// projects the data onto a lower dimensional space.
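///
/// # Examples
///
/// A minimal usage sketch (the crate path `scirs2_transform` is assumed from
/// this module's location; adjust the import to your setup):
///
/// ```no_run
/// use scirs2_core::ndarray::Array;
/// use scirs2_transform::reduction::PCA;
///
/// let x = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 7.0]).unwrap();
/// let mut pca = PCA::new(1, true, false); // keep 1 component; center, but don't scale
/// let projected = pca.fit_transform(&x).unwrap();
/// assert_eq!(projected.shape(), &[3, 1]);
/// ```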
#[derive(Debug, Clone)]
pub struct PCA {
    /// Number of components to keep
    n_components: usize,
    /// Whether to center the data before computing the SVD
    center: bool,
    /// Whether to scale the data before computing the SVD
    scale: bool,
    /// The principal components
    components: Option<Array2<f64>>,
    /// The mean of the training data
    mean: Option<Array1<f64>>,
    /// The standard deviation of the training data
    std: Option<Array1<f64>>,
    /// The singular values of the centered training data
    singular_values: Option<Array1<f64>>,
    /// The explained variance ratio
    explained_variance_ratio: Option<Array1<f64>>,
}

impl PCA {
    /// Creates a new PCA instance
    ///
    /// # Arguments
    /// * `n_components` - Number of components to keep
    /// * `center` - Whether to center the data before computing the SVD
    /// * `scale` - Whether to scale the data before computing the SVD
    ///
    /// # Returns
    /// * A new PCA instance
    pub fn new(n_components: usize, center: bool, scale: bool) -> Self {
        PCA {
            n_components,
            center,
            scale,
            components: None,
            mean: None,
            std: None,
            singular_values: None,
            explained_variance_ratio: None,
        }
    }

    /// Fits the PCA model to the input data
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<()>` - Ok if successful, Err otherwise
    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));

        let n_samples = x_f64.shape()[0];
        let n_features = x_f64.shape()[1];

        if n_samples == 0 || n_features == 0 {
            return Err(TransformError::InvalidInput("Empty input data".to_string()));
        }

        if self.n_components > n_features {
            return Err(TransformError::InvalidInput(format!(
                "n_components={} must be <= n_features={}",
                self.n_components, n_features
            )));
        }

        // Center and scale data if requested
        let mut x_processed = Array2::zeros((n_samples, n_features));
        let mut mean = Array1::zeros(n_features);
        let mut std = Array1::ones(n_features);

        if self.center {
            for j in 0..n_features {
                let col_mean = x_f64.column(j).sum() / n_samples as f64;
                mean[j] = col_mean;

                for i in 0..n_samples {
                    x_processed[[i, j]] = x_f64[[i, j]] - col_mean;
                }
            }
        } else {
            x_processed.assign(&x_f64);
        }

        if self.scale {
            for j in 0..n_features {
                let col_std =
                    (x_processed.column(j).mapv(|x| x * x).sum() / n_samples as f64).sqrt();
                if col_std > f64::EPSILON {
                    std[j] = col_std;

                    for i in 0..n_samples {
                        x_processed[[i, j]] /= col_std;
                    }
                }
            }
        }

        // Perform SVD
        let (_u, s, vt) = match svd::<f64>(&x_processed.view(), true, None) {
            Ok(result) => result,
            Err(e) => return Err(TransformError::LinalgError(e)),
        };

        // Extract components and singular values
        let mut components = Array2::zeros((self.n_components, n_features));
        let mut singular_values = Array1::zeros(self.n_components);

        for i in 0..self.n_components {
            singular_values[i] = s[i];
            for j in 0..n_features {
                components[[i, j]] = vt[[i, j]];
            }
        }

        // Compute explained variance ratio
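        // ratio_i = s_i^2 / sum_j s_j^2, where the s_j are the singular values
        // of the processed (centered/scaled) data matrix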
        let total_variance = s.mapv(|s| s * s).sum();
        let explained_variance_ratio = singular_values.mapv(|s| s * s / total_variance);

        self.components = Some(components);
        self.mean = Some(mean);
        self.std = Some(std);
        self.singular_values = Some(singular_values);
        self.explained_variance_ratio = Some(explained_variance_ratio);

        Ok(())
    }

    /// Transforms the input data using the fitted PCA model
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));

        let n_samples = x_f64.shape()[0];
        let n_features = x_f64.shape()[1];

        if self.components.is_none() {
            return Err(TransformError::TransformationError(
                "PCA model has not been fitted".to_string(),
            ));
        }

        let components = self.components.as_ref().unwrap();
        let mean = self.mean.as_ref().unwrap();
        let std = self.std.as_ref().unwrap();

        if n_features != components.shape()[1] {
            return Err(TransformError::InvalidInput(format!(
                "x has {} features, but PCA was fitted with {} features",
                n_features,
                components.shape()[1]
            )));
        }

        // Center and scale data if the model was fitted with centering/scaling
        let mut x_processed = Array2::zeros((n_samples, n_features));

        for i in 0..n_samples {
            for j in 0..n_features {
                let mut value = x_f64[[i, j]];

                if self.center {
                    value -= mean[j];
                }

                if self.scale {
                    value /= std[j];
                }

                x_processed[[i, j]] = value;
            }
        }

        // Project data onto principal components
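        // transformed = X_processed * components^T, i.e.
        // transformed[[i, j]] = sum_k x_processed[[i, k]] * components[[j, k]]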
        let mut transformed = Array2::zeros((n_samples, self.n_components));

        for i in 0..n_samples {
            for j in 0..self.n_components {
                let mut dot_product = 0.0;
                for k in 0..n_features {
                    dot_product += x_processed[[i, k]] * components[[j, k]];
                }
                transformed[[i, j]] = dot_product;
            }
        }

        Ok(transformed)
    }

    /// Fits the PCA model to the input data and transforms it
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x)?;
        self.transform(x)
    }

    /// Returns the principal components
    ///
    /// # Returns
    /// * `Option<&Array2<f64>>` - The principal components, shape (n_components, n_features)
    pub fn components(&self) -> Option<&Array2<f64>> {
        self.components.as_ref()
    }

    /// Returns the explained variance ratio
    ///
    /// # Returns
    /// * `Option<&Array1<f64>>` - The explained variance ratio
    pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
        self.explained_variance_ratio.as_ref()
    }
}

/// Truncated Singular Value Decomposition (SVD) for dimensionality reduction
///
/// This transformer performs linear dimensionality reduction by means of
/// truncated singular value decomposition (SVD). Unlike PCA, it does not
/// center the data before the decomposition, and it works on any data,
/// not just sparse matrices.
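///
/// # Examples
///
/// A minimal usage sketch (import path assumed to be
/// `scirs2_transform::reduction`; adjust to your setup):
///
/// ```no_run
/// use scirs2_core::ndarray::Array;
/// use scirs2_transform::reduction::TruncatedSVD;
///
/// let x = Array::from_shape_vec(
///     (4, 3),
///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
/// )
/// .unwrap();
/// let mut svd = TruncatedSVD::new(2);
/// let reduced = svd.fit_transform(&x).unwrap();
/// assert_eq!(reduced.shape(), &[4, 2]);
/// ```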
#[derive(Debug, Clone)]
pub struct TruncatedSVD {
    /// Number of components to keep
    n_components: usize,
    /// The singular values of the training data
    singular_values: Option<Array1<f64>>,
    /// The right singular vectors
    components: Option<Array2<f64>>,
    /// The explained variance ratio
    explained_variance_ratio: Option<Array1<f64>>,
}

impl TruncatedSVD {
    /// Creates a new TruncatedSVD instance
    ///
    /// # Arguments
    /// * `n_components` - Number of components to keep
    ///
    /// # Returns
    /// * A new TruncatedSVD instance
    pub fn new(n_components: usize) -> Self {
        TruncatedSVD {
            n_components,
            singular_values: None,
            components: None,
            explained_variance_ratio: None,
        }
    }

    /// Fits the TruncatedSVD model to the input data
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<()>` - Ok if successful, Err otherwise
    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));

        let n_samples = x_f64.shape()[0];
        let n_features = x_f64.shape()[1];

        if n_samples == 0 || n_features == 0 {
            return Err(TransformError::InvalidInput("Empty input data".to_string()));
        }

        if self.n_components > n_features {
            return Err(TransformError::InvalidInput(format!(
                "n_components={} must be <= n_features={}",
                self.n_components, n_features
            )));
        }

        // Perform SVD
        let (_u, s, vt) = match svd::<f64>(&x_f64.view(), true, None) {
            Ok(result) => result,
            Err(e) => return Err(TransformError::LinalgError(e)),
        };

        // Extract components and singular values
        let mut components = Array2::zeros((self.n_components, n_features));
        let mut singular_values = Array1::zeros(self.n_components);

        for i in 0..self.n_components {
            singular_values[i] = s[i];
            for j in 0..n_features {
                components[[i, j]] = vt[[i, j]];
            }
        }

        // Compute explained variance ratio
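        // The total variance here is the mean squared row norm (no centering),
        // and explained_variance_i = s_i^2 / n_samples, so the ratio reduces to
        // s_i^2 / sum over rows of ||x||^2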
        let total_variance =
            (x_f64.map_axis(Axis(1), |row| row.dot(&row)).sum()) / n_samples as f64;
        let explained_variance = singular_values.mapv(|s| s * s / n_samples as f64);
        let explained_variance_ratio = explained_variance.mapv(|v| v / total_variance);

        self.singular_values = Some(singular_values);
        self.components = Some(components);
        self.explained_variance_ratio = Some(explained_variance_ratio);

        Ok(())
    }

    /// Transforms the input data using the fitted TruncatedSVD model
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));

        let n_samples = x_f64.shape()[0];
        let n_features = x_f64.shape()[1];

        if self.components.is_none() {
            return Err(TransformError::TransformationError(
                "TruncatedSVD model has not been fitted".to_string(),
            ));
        }

        let components = self.components.as_ref().unwrap();

        if n_features != components.shape()[1] {
            return Err(TransformError::InvalidInput(format!(
                "x has {} features, but TruncatedSVD was fitted with {} features",
                n_features,
                components.shape()[1]
            )));
        }

        // Project data onto components
        let mut transformed = Array2::zeros((n_samples, self.n_components));

        for i in 0..n_samples {
            for j in 0..self.n_components {
                let mut dot_product = 0.0;
                for k in 0..n_features {
                    dot_product += x_f64[[i, k]] * components[[j, k]];
                }
                transformed[[i, j]] = dot_product;
            }
        }

        Ok(transformed)
    }

    /// Fits the TruncatedSVD model to the input data and transforms it
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x)?;
        self.transform(x)
    }

    /// Returns the components (right singular vectors)
    ///
    /// # Returns
    /// * `Option<&Array2<f64>>` - The components, shape (n_components, n_features)
    pub fn components(&self) -> Option<&Array2<f64>> {
        self.components.as_ref()
    }

    /// Returns the singular values
    ///
    /// # Returns
    /// * `Option<&Array1<f64>>` - The singular values
    pub fn singular_values(&self) -> Option<&Array1<f64>> {
        self.singular_values.as_ref()
    }

    /// Returns the explained variance ratio
    ///
    /// # Returns
    /// * `Option<&Array1<f64>>` - The explained variance ratio
    pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
        self.explained_variance_ratio.as_ref()
    }
}

/// Linear Discriminant Analysis (LDA) for dimensionality reduction
///
/// LDA finds the directions that maximize the separation between classes.
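///
/// # Examples
///
/// A minimal usage sketch (import path assumed to be
/// `scirs2_transform::reduction`; adjust to your setup):
///
/// ```no_run
/// use scirs2_core::ndarray::Array;
/// use scirs2_transform::reduction::LDA;
///
/// let x = Array::from_shape_vec(
///     (6, 2),
///     vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 4.0, 6.0, 5.0, 7.0, 4.0],
/// )
/// .unwrap();
/// let y = Array::from_vec(vec![0, 0, 0, 1, 1, 1]);
/// // With 2 classes, at most 1 discriminant component is available
/// let mut lda = LDA::new(1, "svd").unwrap();
/// let projected = lda.fit_transform(&x, &y).unwrap();
/// assert_eq!(projected.shape(), &[6, 1]);
/// ```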
#[derive(Debug, Clone)]
pub struct LDA {
    /// Number of components to keep
    n_components: usize,
    /// The solver to use ("svd" or "eigen")
    solver: String,
    /// The LDA components
    components: Option<Array2<f64>>,
    /// The class means
    means: Option<Array2<f64>>,
    /// The explained variance ratio
    explained_variance_ratio: Option<Array1<f64>>,
}

impl LDA {
    /// Creates a new LDA instance
    ///
    /// # Arguments
    /// * `n_components` - Number of components to keep
    /// * `solver` - The solver to use ('svd' or 'eigen')
    ///
    /// # Returns
    /// * A new LDA instance
    pub fn new(n_components: usize, solver: &str) -> Result<Self> {
        if solver != "svd" && solver != "eigen" {
            return Err(TransformError::InvalidInput(
                "solver must be 'svd' or 'eigen'".to_string(),
            ));
        }

        Ok(LDA {
            n_components,
            solver: solver.to_string(),
            components: None,
            means: None,
            explained_variance_ratio: None,
        })
    }

    /// Fits the LDA model to the input data
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    /// * `y` - The target labels, shape (n_samples,)
    ///
    /// # Returns
    /// * `Result<()>` - Ok if successful, Err otherwise
    pub fn fit<S1, S2>(&mut self, x: &ArrayBase<S1, Ix2>, y: &ArrayBase<S2, Ix1>) -> Result<()>
    where
        S1: Data,
        S2: Data,
        S1::Elem: Float + NumCast,
        S2::Elem: Copy + NumCast + Eq + std::hash::Hash,
    {
        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));

        let n_samples = x_f64.shape()[0];
        let n_features = x_f64.shape()[1];

        if n_samples == 0 || n_features == 0 {
            return Err(TransformError::InvalidInput("Empty input data".to_string()));
        }

        if n_samples != y.len() {
            return Err(TransformError::InvalidInput(format!(
                "x and y have incompatible shapes: x has {} samples, y has {} elements",
                n_samples,
                y.len()
            )));
        }

        // Convert y to class indices
        let mut class_indices = vec![];
        let mut class_map = std::collections::HashMap::new();
        let mut next_class_idx = 0;

        for &label in y.iter() {
            let label_key: i64 = NumCast::from(label).unwrap_or(0);

            if let std::collections::hash_map::Entry::Vacant(e) = class_map.entry(label_key) {
                e.insert(next_class_idx);
                next_class_idx += 1;
            }

            class_indices.push(class_map[&label_key]);
        }

        let n_classes = class_map.len();

        if n_classes <= 1 {
            return Err(TransformError::InvalidInput(
                "y has fewer than 2 classes; LDA requires at least 2 classes".to_string(),
            ));
        }

        let max_n_components = n_classes - 1;
        if self.n_components > max_n_components {
            return Err(TransformError::InvalidInput(format!(
                "n_components={} must be <= n_classes-1={}",
                self.n_components, max_n_components
            )));
        }

        // Compute class means
        let mut class_means = Array2::zeros((n_classes, n_features));
        let mut class_counts = vec![0; n_classes];

        for i in 0..n_samples {
            let class_idx = class_indices[i];
            class_counts[class_idx] += 1;

            for j in 0..n_features {
                class_means[[class_idx, j]] += x_f64[[i, j]];
            }
        }

        for i in 0..n_classes {
            if class_counts[i] > 0 {
                for j in 0..n_features {
                    class_means[[i, j]] /= class_counts[i] as f64;
                }
            }
        }

        // Compute global mean
        let mut global_mean = Array1::<f64>::zeros(n_features);
        for i in 0..n_samples {
            for j in 0..n_features {
                global_mean[j] += x_f64[[i, j]];
            }
        }
        global_mean.mapv_inplace(|x: f64| x / n_samples as f64);

        // Compute within-class scatter matrix
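        // Sw = sum over classes c of sum_{x in class c} (x - mu_c)(x - mu_c)^T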
        let mut sw = Array2::<f64>::zeros((n_features, n_features));
        for i in 0..n_samples {
            let class_idx = class_indices[i];
            let mut x_centered = Array1::<f64>::zeros(n_features);

            for j in 0..n_features {
                x_centered[j] = x_f64[[i, j]] - class_means[[class_idx, j]];
            }

            for j in 0..n_features {
                for k in 0..n_features {
                    sw[[j, k]] += x_centered[j] * x_centered[k];
                }
            }
        }

        // Compute between-class scatter matrix
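        // Sb = sum over classes c of n_c * (mu_c - mu)(mu_c - mu)^T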
        let mut sb = Array2::<f64>::zeros((n_features, n_features));
        for i in 0..n_classes {
            let mut mean_diff = Array1::<f64>::zeros(n_features);
            for j in 0..n_features {
                mean_diff[j] = class_means[[i, j]] - global_mean[j];
            }

            for j in 0..n_features {
                for k in 0..n_features {
                    sb[[j, k]] += class_counts[i] as f64 * mean_diff[j] * mean_diff[k];
                }
            }
        }

        // Solve the generalized eigenvalue problem
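        // Find directions v with Sb * v = lambda * Sw * v; the largest
        // eigenvalues lambda correspond to the most discriminative axes.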
        let mut components = Array2::<f64>::zeros((self.n_components, n_features));
        let mut eigenvalues = Array1::<f64>::zeros(self.n_components);

        if self.solver == "svd" {
            // SVD-based solver

            // Decompose the within-class scatter matrix
            let (u_sw, s_sw, vt_sw) = match svd::<f64>(&sw.view(), true, None) {
                Ok(result) => result,
                Err(e) => return Err(TransformError::LinalgError(e)),
            };

            // Compute the pseudoinverse of sw^(1/2)
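            // Sw = U * diag(s) * V^T  =>  Sw^(-1/2) ~ U * diag(1/sqrt(s_i)) * V^T,
            // skipping singular values s_i <= EPSILON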
            let mut sw_sqrt_inv = Array2::<f64>::zeros((n_features, n_features));
            for i in 0..n_features {
                if s_sw[i] > EPSILON {
                    let s_inv_sqrt = 1.0 / s_sw[i].sqrt();
                    for j in 0..n_features {
                        for k in 0..n_features {
                            sw_sqrt_inv[[j, k]] += u_sw[[j, i]] * s_inv_sqrt * vt_sw[[i, k]];
                        }
                    }
                }
            }

            // Transform the between-class scatter matrix
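            // Sb' = Sw^(-1/2) * Sb * Sw^(-1/2); eigenvectors u of Sb' map back to
            // discriminant directions via w = Sw^(-1/2) * u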
            let mut sb_transformed = Array2::<f64>::zeros((n_features, n_features));
            for i in 0..n_features {
                for j in 0..n_features {
                    for k in 0..n_features {
                        for l in 0..n_features {
                            sb_transformed[[i, j]] +=
                                sw_sqrt_inv[[i, k]] * sb[[k, l]] * sw_sqrt_inv[[l, j]];
                        }
                    }
                }
            }

            // Perform SVD on the transformed between-class scatter matrix
            let (u_sb, s_sb, _vt_sb) = match svd::<f64>(&sb_transformed.view(), true, None) {
                Ok(result) => result,
                Err(e) => return Err(TransformError::LinalgError(e)),
            };

            // Compute the LDA components
            for i in 0..self.n_components {
                eigenvalues[i] = s_sb[i];

                for j in 0..n_features {
                    for k in 0..n_features {
                        components[[i, j]] += sw_sqrt_inv[[k, j]] * u_sb[[k, i]];
                    }
                }
            }
        } else {
            // Eigen-based solver - proper generalized eigenvalue problem
            // Solve: Sb * v = λ * Sw * v

            // Step 1: Regularize Sw to ensure it's invertible
            let mut sw_reg = sw.clone();
            for i in 0..n_features {
                sw_reg[[i, i]] += EPSILON; // Add small regularization to diagonal
            }

            // Step 2: Decompose the regularized Sw so we can form Sw^(-1) * Sb
            let (u_sw, s_sw, vt_sw) = match svd::<f64>(&sw_reg.view(), true, None) {
                Ok(result) => result,
                Err(e) => return Err(TransformError::LinalgError(e)),
            };

            // Compute pseudoinverse of Sw
            let mut sw_inv = Array2::<f64>::zeros((n_features, n_features));
            for i in 0..n_features {
                if s_sw[i] > EPSILON {
                    for j in 0..n_features {
                        for k in 0..n_features {
                            sw_inv[[j, k]] += u_sw[[j, i]] * (1.0 / s_sw[i]) * vt_sw[[i, k]];
                        }
                    }
                }
            }

            // Step 3: Compute Sw^(-1) * Sb
            let mut sw_inv_sb = Array2::<f64>::zeros((n_features, n_features));
            for i in 0..n_features {
                for j in 0..n_features {
                    for k in 0..n_features {
                        sw_inv_sb[[i, j]] += sw_inv[[i, k]] * sb[[k, j]];
                    }
                }
            }

            // Step 4: Compute eigendecomposition of Sw^(-1) * Sb
            // Since this matrix may not be symmetric, we symmetrize it first:
            // (Sw^(-1) * Sb + (Sw^(-1) * Sb)^T) / 2
            let mut sym_matrix = Array2::<f64>::zeros((n_features, n_features));
            for i in 0..n_features {
                for j in 0..n_features {
                    sym_matrix[[i, j]] = (sw_inv_sb[[i, j]] + sw_inv_sb[[j, i]]) / 2.0;
                }
            }

            // Perform eigendecomposition on the symmetrized matrix
            let (eig_vals, eig_vecs) = match scirs2_linalg::eigh::<f64>(&sym_matrix.view(), None) {
                Ok(result) => result,
                Err(_) => {
                    // Fall back to SVD if the eigendecomposition fails
                    let (u, s, _vt) = match svd::<f64>(&sw_inv_sb.view(), true, None) {
                        Ok(result) => result,
                        Err(e) => return Err(TransformError::LinalgError(e)),
                    };
                    (s, u)
                }
            };

            // Sort eigenvalues and eigenvectors in descending order
            let mut indices: Vec<usize> = (0..n_features).collect();
            indices.sort_by(|&i, &j| {
                eig_vals[j]
                    .partial_cmp(&eig_vals[i])
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            // Select top n_components eigenvectors
            for i in 0..self.n_components {
                let idx = indices[i];
                eigenvalues[i] = eig_vals[idx].max(0.0); // Ensure non-negative

                for j in 0..n_features {
                    components[[i, j]] = eig_vecs[[j, idx]];
                }
            }

            // Normalize components to unit length
            for i in 0..self.n_components {
                let mut norm = 0.0;
                for j in 0..n_features {
                    norm += components[[i, j]] * components[[i, j]];
                }
                norm = norm.sqrt();

                if norm > EPSILON {
                    for j in 0..n_features {
                        components[[i, j]] /= norm;
                    }
                }
            }
        }

        // Compute explained variance ratio
        let total_eigenvalues = eigenvalues.iter().sum::<f64>();
        let explained_variance_ratio = eigenvalues.mapv(|e| e / total_eigenvalues);

        self.components = Some(components);
        self.means = Some(class_means);
        self.explained_variance_ratio = Some(explained_variance_ratio);

        Ok(())
    }

    /// Transforms the input data using the fitted LDA model
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));

        let n_samples = x_f64.shape()[0];
        let n_features = x_f64.shape()[1];

        if self.components.is_none() {
            return Err(TransformError::TransformationError(
                "LDA model has not been fitted".to_string(),
            ));
        }

        let components = self.components.as_ref().unwrap();

        if n_features != components.shape()[1] {
            return Err(TransformError::InvalidInput(format!(
                "x has {} features, but LDA was fitted with {} features",
                n_features,
                components.shape()[1]
            )));
        }

        // Project data onto LDA components
        let mut transformed = Array2::zeros((n_samples, self.n_components));

        for i in 0..n_samples {
            for j in 0..self.n_components {
                let mut dot_product = 0.0;
                for k in 0..n_features {
                    dot_product += x_f64[[i, k]] * components[[j, k]];
                }
                transformed[[i, j]] = dot_product;
            }
        }

        Ok(transformed)
    }

    /// Fits the LDA model to the input data and transforms it
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    /// * `y` - The target labels, shape (n_samples,)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
    pub fn fit_transform<S1, S2>(
        &mut self,
        x: &ArrayBase<S1, Ix2>,
        y: &ArrayBase<S2, Ix1>,
    ) -> Result<Array2<f64>>
    where
        S1: Data,
        S2: Data,
        S1::Elem: Float + NumCast,
        S2::Elem: Copy + NumCast + Eq + std::hash::Hash,
    {
        self.fit(x, y)?;
        self.transform(x)
    }

    /// Returns the LDA components
    ///
    /// # Returns
    /// * `Option<&Array2<f64>>` - The LDA components, shape (n_components, n_features)
    pub fn components(&self) -> Option<&Array2<f64>> {
        self.components.as_ref()
    }

    /// Returns the explained variance ratio
    ///
    /// # Returns
    /// * `Option<&Array1<f64>>` - The explained variance ratio
    pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
        self.explained_variance_ratio.as_ref()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_pca_transform() {
        // Create a simple dataset
        let x = Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();

        // Initialize and fit PCA with 2 components
        let mut pca = PCA::new(2, true, false);
        let x_transformed = pca.fit_transform(&x).unwrap();

        // Check that the shape is correct
        assert_eq!(x_transformed.shape(), &[4, 2]);

        // Check that we have the correct number of explained variance components
        let explained_variance = pca.explained_variance_ratio().unwrap();
        assert_eq!(explained_variance.len(), 2);

        // Check that the sum is a valid number (we don't need to enforce sum = 1)
        assert!(explained_variance.sum() > 0.0 && explained_variance.sum().is_finite());
    }

    #[test]
    fn test_truncated_svd() {
        // Create a simple dataset
        let x = Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();

        // Initialize and fit TruncatedSVD with 2 components
        let mut svd = TruncatedSVD::new(2);
        let x_transformed = svd.fit_transform(&x).unwrap();

        // Check that the shape is correct
        assert_eq!(x_transformed.shape(), &[4, 2]);

        // Check that we have the correct number of explained variance components
        let explained_variance = svd.explained_variance_ratio().unwrap();
        assert_eq!(explained_variance.len(), 2);

        // Check that the sum is a valid number (we don't need to enforce sum = 1)
        assert!(explained_variance.sum() > 0.0 && explained_variance.sum().is_finite());
    }

    #[test]
    fn test_lda() {
        // Create a simple dataset with 2 classes
        let x = Array::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 4.0, 6.0, 5.0, 7.0, 4.0],
        )
        .unwrap();

        let y = Array::from_vec(vec![0, 0, 0, 1, 1, 1]);

        // Initialize and fit LDA with 1 component (the maximum for 2 classes)
        let mut lda = LDA::new(1, "svd").unwrap();
        let x_transformed = lda.fit_transform(&x, &y).unwrap();

        // Check that the shape is correct
        assert_eq!(x_transformed.shape(), &[6, 1]);

        // Check that the explained variance ratio is 1.0 for a single component
        let explained_variance = lda.explained_variance_ratio().unwrap();
        assert_abs_diff_eq!(explained_variance[0], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_lda_eigen_solver() {
        // Create a simple dataset with 3 classes
        let x = Array::from_shape_vec(
            (9, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, // Class 0
                5.0, 4.0, 6.0, 5.0, 7.0, 4.0, // Class 1
                9.0, 8.0, 10.0, 9.0, 11.0, 10.0, // Class 2
            ],
        )
        .unwrap();

        let y = Array::from_vec(vec![0, 0, 0, 1, 1, 1, 2, 2, 2]);

        // Test eigen solver (2 components for 3 classes)
        let mut lda_eigen = LDA::new(2, "eigen").unwrap();
        let x_transformed_eigen = lda_eigen.fit_transform(&x, &y).unwrap();

        // Test SVD solver for comparison
        let mut lda_svd = LDA::new(2, "svd").unwrap();
        let x_transformed_svd = lda_svd.fit_transform(&x, &y).unwrap();

        // Check that both transformations have the correct shape
        assert_eq!(x_transformed_eigen.shape(), &[9, 2]);
        assert_eq!(x_transformed_svd.shape(), &[9, 2]);

        // Check that both produce finite results
        assert!(x_transformed_eigen.iter().all(|&x| x.is_finite()));
        assert!(x_transformed_svd.iter().all(|&x| x.is_finite()));

        // Check that explained variance ratios are valid for both solvers
        let explained_variance_eigen = lda_eigen.explained_variance_ratio().unwrap();
        let explained_variance_svd = lda_svd.explained_variance_ratio().unwrap();

        assert_eq!(explained_variance_eigen.len(), 2);
        assert_eq!(explained_variance_svd.len(), 2);

        // Both should sum to approximately 1.0
        assert_abs_diff_eq!(explained_variance_eigen.sum(), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(explained_variance_svd.sum(), 1.0, epsilon = 1e-10);

        // Eigenvalues should be non-negative
        assert!(explained_variance_eigen.iter().all(|&x| x >= 0.0));
        assert!(explained_variance_svd.iter().all(|&x| x >= 0.0));
    }

    #[test]
    fn test_lda_invalid_solver() {
        let result = LDA::new(1, "invalid");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("solver must be 'svd' or 'eigen'"));
    }
}
1002}