Skip to main content

scirs2_transform/reduction/
mod.rs

1//! Dimensionality reduction techniques
2//!
3//! This module provides algorithms for reducing the dimensionality of data,
4//! which is useful for visualization, feature extraction, and reducing
5//! computational complexity.
6
7/// Factor Analysis module
8pub mod factor_analysis;
9mod isomap;
10mod lle;
11mod spectral_embedding;
12mod tsne;
13mod umap;
14
15/// Laplacian Eigenmaps for manifold learning
16pub mod laplacian_eigenmaps;
17
18/// Diffusion Maps for nonlinear dimensionality reduction
19pub mod diffusion_maps;
20
21pub use crate::reduction::diffusion_maps::DiffusionMaps;
22pub use crate::reduction::factor_analysis::{
23    factor_analysis, scree_plot_data, FactorAnalysis, FactorAnalysisResult, RotationMethod,
24    ScreePlotData,
25};
26pub use crate::reduction::isomap::Isomap;
27pub use crate::reduction::laplacian_eigenmaps::{
28    GraphMethod, LaplacianEigenmaps, LaplacianType as LELaplacianType,
29};
30pub use crate::reduction::lle::LLE;
31pub use crate::reduction::spectral_embedding::{AffinityMethod, SpectralEmbedding};
32pub use crate::reduction::tsne::{trustworthiness, TSNE};
33pub use crate::reduction::umap::UMAP;
34
35use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
36use scirs2_core::numeric::{Float, NumCast};
37use scirs2_linalg::svd;
38
39use crate::error::{Result, TransformError};
40
41// Define a small value to use for comparison with zero
42const EPSILON: f64 = 1e-10;
43
/// Principal Component Analysis (PCA) dimensionality reduction
///
/// PCA finds the directions of maximum variance in the data and
/// projects the data onto a lower dimensional space.
///
/// All `Option` fields are `None` on a freshly constructed instance and are
/// populated by a successful `fit`/`fit_transform`.
#[derive(Debug, Clone)]
pub struct PCA {
    /// Number of components to keep
    n_components: usize,
    /// Whether to center the data (subtract the per-feature mean) before computing the SVD
    center: bool,
    /// Whether to scale the data (divide by the per-feature std) before computing the SVD
    scale: bool,
    /// The principal components, shape (n_components, n_features); `None` until fitted
    components: Option<Array2<f64>>,
    /// The mean of the training data, length n_features; `None` until fitted
    mean: Option<Array1<f64>>,
    /// The standard deviation of the training data, length n_features; `None` until fitted
    std: Option<Array1<f64>>,
    /// The singular values of the centered training data; `None` until fitted
    singular_values: Option<Array1<f64>>,
    /// The fraction of total variance explained by each kept component; `None` until fitted
    explained_variance_ratio: Option<Array1<f64>>,
}
67
68impl PCA {
69    /// Creates a new PCA instance
70    ///
71    /// # Arguments
72    /// * `n_components` - Number of components to keep
73    /// * `center` - Whether to center the data before computing the SVD
74    /// * `scale` - Whether to scale the data before computing the SVD
75    ///
76    /// # Returns
77    /// * A new PCA instance
78    pub fn new(ncomponents: usize, center: bool, scale: bool) -> Self {
79        PCA {
80            n_components: ncomponents,
81            center,
82            scale,
83            components: None,
84            mean: None,
85            std: None,
86            singular_values: None,
87            explained_variance_ratio: None,
88        }
89    }
90
91    /// Fits the PCA model to the input data
92    ///
93    /// # Arguments
94    /// * `x` - The input data, shape (n_samples, n_features)
95    ///
96    /// # Returns
97    /// * `Result<()>` - Ok if successful, Err otherwise
98    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
99    where
100        S: Data,
101        S::Elem: Float + NumCast,
102    {
103        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
104
105        let n_samples = x_f64.shape()[0];
106        let n_features = x_f64.shape()[1];
107
108        if n_samples == 0 || n_features == 0 {
109            return Err(TransformError::InvalidInput("Empty input data".to_string()));
110        }
111
112        if self.n_components > n_features {
113            return Err(TransformError::InvalidInput(format!(
114                "n_components={} must be <= n_features={}",
115                self.n_components, n_features
116            )));
117        }
118
119        // Center and scale data if requested
120        let mut x_processed = Array2::zeros((n_samples, n_features));
121        let mut mean = Array1::zeros(n_features);
122        let mut std = Array1::ones(n_features);
123
124        if self.center {
125            for j in 0..n_features {
126                let col_mean = x_f64.column(j).sum() / n_samples as f64;
127                mean[j] = col_mean;
128
129                for i in 0..n_samples {
130                    x_processed[[i, j]] = x_f64[[i, j]] - col_mean;
131                }
132            }
133        } else {
134            x_processed.assign(&x_f64);
135        }
136
137        if self.scale {
138            for j in 0..n_features {
139                let col_std =
140                    (x_processed.column(j).mapv(|x| x * x).sum() / n_samples as f64).sqrt();
141                if col_std > f64::EPSILON {
142                    std[j] = col_std;
143
144                    for i in 0..n_samples {
145                        x_processed[[i, j]] /= col_std;
146                    }
147                }
148            }
149        }
150
151        // Perform SVD
152        let (_u, s, vt) = match svd::<f64>(&x_processed.view(), true, None) {
153            Ok(result) => result,
154            Err(e) => return Err(TransformError::LinalgError(e)),
155        };
156
157        // Extract components and singular values
158        let mut components = Array2::zeros((self.n_components, n_features));
159        let mut singular_values = Array1::zeros(self.n_components);
160
161        for i in 0..self.n_components {
162            singular_values[i] = s[i];
163            for j in 0..n_features {
164                components[[i, j]] = vt[[i, j]];
165            }
166        }
167
168        // Compute explained variance ratio
169        let total_variance = s.mapv(|s| s * s).sum();
170        let explained_variance_ratio = singular_values.mapv(|s| s * s / total_variance);
171
172        self.components = Some(components);
173        self.mean = Some(mean);
174        self.std = Some(std);
175        self.singular_values = Some(singular_values);
176        self.explained_variance_ratio = Some(explained_variance_ratio);
177
178        Ok(())
179    }
180
181    /// Transforms the input data using the fitted PCA model
182    ///
183    /// # Arguments
184    /// * `x` - The input data, shape (n_samples, n_features)
185    ///
186    /// # Returns
187    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
188    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
189    where
190        S: Data,
191        S::Elem: Float + NumCast,
192    {
193        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
194
195        let n_samples = x_f64.shape()[0];
196        let n_features = x_f64.shape()[1];
197
198        if self.components.is_none() {
199            return Err(TransformError::TransformationError(
200                "PCA model has not been fitted".to_string(),
201            ));
202        }
203
204        let components = self.components.as_ref().expect("Operation failed");
205        let mean = self.mean.as_ref().expect("Operation failed");
206        let std = self.std.as_ref().expect("Operation failed");
207
208        if n_features != components.shape()[1] {
209            return Err(TransformError::InvalidInput(format!(
210                "x has {} features, but PCA was fitted with {} features",
211                n_features,
212                components.shape()[1]
213            )));
214        }
215
216        // Center and scale data if the model was fitted with centering/scaling
217        let mut x_processed = Array2::zeros((n_samples, n_features));
218
219        for i in 0..n_samples {
220            for j in 0..n_features {
221                let mut value = x_f64[[i, j]];
222
223                if self.center {
224                    value -= mean[j];
225                }
226
227                if self.scale {
228                    value /= std[j];
229                }
230
231                x_processed[[i, j]] = value;
232            }
233        }
234
235        // Project data onto principal components
236        let mut transformed = Array2::zeros((n_samples, self.n_components));
237
238        for i in 0..n_samples {
239            for j in 0..self.n_components {
240                let mut dot_product = 0.0;
241                for k in 0..n_features {
242                    dot_product += x_processed[[i, k]] * components[[j, k]];
243                }
244                transformed[[i, j]] = dot_product;
245            }
246        }
247
248        Ok(transformed)
249    }
250
251    /// Fits the PCA model to the input data and transforms it
252    ///
253    /// # Arguments
254    /// * `x` - The input data, shape (n_samples, n_features)
255    ///
256    /// # Returns
257    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
258    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
259    where
260        S: Data,
261        S::Elem: Float + NumCast,
262    {
263        self.fit(x)?;
264        self.transform(x)
265    }
266
    /// Returns the principal components
    ///
    /// `None` until `fit`/`fit_transform` has completed successfully.
    ///
    /// # Returns
    /// * `Option<&Array2<f64>>` - The principal components, shape (n_components, n_features)
    pub fn components(&self) -> Option<&Array2<f64>> {
        self.components.as_ref()
    }
274
    /// Returns the explained variance ratio
    ///
    /// `None` until `fit`/`fit_transform` has completed successfully.
    ///
    /// # Returns
    /// * `Option<&Array1<f64>>` - The explained variance ratio, one entry per kept component
    pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
        self.explained_variance_ratio.as_ref()
    }
282}
283
/// Truncated Singular Value Decomposition (SVD) for dimensionality reduction
///
/// This transformer performs linear dimensionality reduction by means of
/// truncated singular value decomposition (SVD). It works on any data and
/// not just sparse matrices.
///
/// Unlike `PCA`, the data is not centered before the decomposition.
/// All `Option` fields are `None` until `fit`/`fit_transform` succeeds.
#[derive(Debug, Clone)]
pub struct TruncatedSVD {
    /// Number of components to keep
    n_components: usize,
    /// The leading singular values of the training data; `None` until fitted
    singular_values: Option<Array1<f64>>,
    /// The right singular vectors, shape (n_components, n_features); `None` until fitted
    components: Option<Array2<f64>>,
    /// The explained variance ratio per kept component; `None` until fitted
    explained_variance_ratio: Option<Array1<f64>>,
}
300
301impl TruncatedSVD {
302    /// Creates a new TruncatedSVD instance
303    ///
304    /// # Arguments
305    /// * `n_components` - Number of components to keep
306    ///
307    /// # Returns
308    /// * A new TruncatedSVD instance
309    pub fn new(ncomponents: usize) -> Self {
310        TruncatedSVD {
311            n_components: ncomponents,
312            singular_values: None,
313            components: None,
314            explained_variance_ratio: None,
315        }
316    }
317
318    /// Fits the TruncatedSVD model to the input data
319    ///
320    /// # Arguments
321    /// * `x` - The input data, shape (n_samples, n_features)
322    ///
323    /// # Returns
324    /// * `Result<()>` - Ok if successful, Err otherwise
325    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
326    where
327        S: Data,
328        S::Elem: Float + NumCast,
329    {
330        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
331
332        let n_samples = x_f64.shape()[0];
333        let n_features = x_f64.shape()[1];
334
335        if n_samples == 0 || n_features == 0 {
336            return Err(TransformError::InvalidInput("Empty input data".to_string()));
337        }
338
339        if self.n_components > n_features {
340            return Err(TransformError::InvalidInput(format!(
341                "n_components={} must be <= n_features={}",
342                self.n_components, n_features
343            )));
344        }
345
346        // Perform SVD
347        let (_u, s, vt) = match svd::<f64>(&x_f64.view(), true, None) {
348            Ok(result) => result,
349            Err(e) => return Err(TransformError::LinalgError(e)),
350        };
351
352        // Extract components and singular values
353        let mut components = Array2::zeros((self.n_components, n_features));
354        let mut singular_values = Array1::zeros(self.n_components);
355
356        for i in 0..self.n_components {
357            singular_values[i] = s[i];
358            for j in 0..n_features {
359                components[[i, j]] = vt[[i, j]];
360            }
361        }
362
363        // Compute explained variance ratio
364        let total_variance =
365            (x_f64.map_axis(Axis(1), |row| row.dot(&row)).sum()) / n_samples as f64;
366        let explained_variance = singular_values.mapv(|s| s * s / n_samples as f64);
367        let explained_variance_ratio = explained_variance.mapv(|v| v / total_variance);
368
369        self.singular_values = Some(singular_values);
370        self.components = Some(components);
371        self.explained_variance_ratio = Some(explained_variance_ratio);
372
373        Ok(())
374    }
375
376    /// Transforms the input data using the fitted TruncatedSVD model
377    ///
378    /// # Arguments
379    /// * `x` - The input data, shape (n_samples, n_features)
380    ///
381    /// # Returns
382    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
383    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
384    where
385        S: Data,
386        S::Elem: Float + NumCast,
387    {
388        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
389
390        let n_samples = x_f64.shape()[0];
391        let n_features = x_f64.shape()[1];
392
393        if self.components.is_none() {
394            return Err(TransformError::TransformationError(
395                "TruncatedSVD model has not been fitted".to_string(),
396            ));
397        }
398
399        let components = self.components.as_ref().expect("Operation failed");
400
401        if n_features != components.shape()[1] {
402            return Err(TransformError::InvalidInput(format!(
403                "x has {} features, but TruncatedSVD was fitted with {} features",
404                n_features,
405                components.shape()[1]
406            )));
407        }
408
409        // Project data onto components
410        let mut transformed = Array2::zeros((n_samples, self.n_components));
411
412        for i in 0..n_samples {
413            for j in 0..self.n_components {
414                let mut dot_product = 0.0;
415                for k in 0..n_features {
416                    dot_product += x_f64[[i, k]] * components[[j, k]];
417                }
418                transformed[[i, j]] = dot_product;
419            }
420        }
421
422        Ok(transformed)
423    }
424
425    /// Fits the TruncatedSVD model to the input data and transforms it
426    ///
427    /// # Arguments
428    /// * `x` - The input data, shape (n_samples, n_features)
429    ///
430    /// # Returns
431    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
432    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
433    where
434        S: Data,
435        S::Elem: Float + NumCast,
436    {
437        self.fit(x)?;
438        self.transform(x)
439    }
440
    /// Returns the components (right singular vectors)
    ///
    /// `None` until `fit`/`fit_transform` has completed successfully.
    ///
    /// # Returns
    /// * `Option<&Array2<f64>>` - The components, shape (n_components, n_features)
    pub fn components(&self) -> Option<&Array2<f64>> {
        self.components.as_ref()
    }
448
    /// Returns the singular values
    ///
    /// `None` until `fit`/`fit_transform` has completed successfully.
    ///
    /// # Returns
    /// * `Option<&Array1<f64>>` - The leading `n_components` singular values
    pub fn singular_values(&self) -> Option<&Array1<f64>> {
        self.singular_values.as_ref()
    }
456
    /// Returns the explained variance ratio
    ///
    /// `None` until `fit`/`fit_transform` has completed successfully.
    ///
    /// # Returns
    /// * `Option<&Array1<f64>>` - The explained variance ratio, one entry per kept component
    pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
        self.explained_variance_ratio.as_ref()
    }
464}
465
/// Linear Discriminant Analysis (LDA) for dimensionality reduction
///
/// LDA finds the directions that maximize the separation between classes.
///
/// All `Option` fields are `None` until `fit`/`fit_transform` succeeds.
#[derive(Debug, Clone)]
pub struct LDA {
    /// Number of components to keep (at most n_classes - 1)
    n_components: usize,
    /// The solver to use: "svd" or "eigen" (validated in `new`)
    solver: String,
    /// The LDA components, shape (n_components, n_features); `None` until fitted
    components: Option<Array2<f64>>,
    /// The per-class means, shape (n_classes, n_features); `None` until fitted
    means: Option<Array2<f64>>,
    /// The explained variance ratio per kept component; `None` until fitted
    explained_variance_ratio: Option<Array1<f64>>,
}
482
483impl LDA {
484    /// Creates a new LDA instance
485    ///
486    /// # Arguments
487    /// * `n_components` - Number of components to keep
488    /// * `solver` - The solver to use ('svd' or 'eigen')
489    ///
490    /// # Returns
491    /// * A new LDA instance
492    pub fn new(ncomponents: usize, solver: &str) -> Result<Self> {
493        if solver != "svd" && solver != "eigen" {
494            return Err(TransformError::InvalidInput(
495                "solver must be 'svd' or 'eigen'".to_string(),
496            ));
497        }
498
499        Ok(LDA {
500            n_components: ncomponents,
501            solver: solver.to_string(),
502            components: None,
503            means: None,
504            explained_variance_ratio: None,
505        })
506    }
507
    /// Fits the LDA model to the input data
    ///
    /// Builds the within-class (Sw) and between-class (Sb) scatter matrices
    /// and approximately solves the generalized eigenvalue problem
    /// Sb * v = λ * Sw * v, using either the SVD-based whitening solver or
    /// the eigen solver chosen at construction time.
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    /// * `y` - The target labels, shape (n_samples,)
    ///
    /// # Returns
    /// * `Result<()>` - Ok if successful, Err otherwise
    ///
    /// # Errors
    /// * `TransformError::InvalidInput` if the input is empty, `x`/`y` lengths
    ///   disagree, fewer than 2 classes are present, or
    ///   `n_components > n_classes - 1`
    /// * `TransformError::LinalgError` if an underlying decomposition fails
    pub fn fit<S1, S2>(&mut self, x: &ArrayBase<S1, Ix2>, y: &ArrayBase<S2, Ix1>) -> Result<()>
    where
        S1: Data,
        S2: Data,
        S1::Elem: Float + NumCast,
        S2::Elem: Copy + NumCast + Eq + std::hash::Hash,
    {
        // Lossy conversion to f64; unrepresentable values fall back to 0.0.
        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));

        let n_samples = x_f64.shape()[0];
        let n_features = x_f64.shape()[1];

        if n_samples == 0 || n_features == 0 {
            return Err(TransformError::InvalidInput("Empty input data".to_string()));
        }

        if n_samples != y.len() {
            return Err(TransformError::InvalidInput(format!(
                "x and y have incompatible shapes: x has {} samples, y has {} elements",
                n_samples,
                y.len()
            )));
        }

        // Convert y to dense class indices 0..n_classes in first-seen order
        let mut class_indices = vec![];
        let mut class_map = std::collections::HashMap::new();
        let mut next_class_idx = 0;

        for &label in y.iter() {
            // NOTE(review): despite the `_u64` name, the key type here is
            // inferred from `unwrap_or(0)` (default integer type), and labels
            // that fail the cast all collapse into the same class — confirm
            // this is intended for non-integer label types.
            let label_u64 = NumCast::from(label).unwrap_or(0);

            if let std::collections::hash_map::Entry::Vacant(e) = class_map.entry(label_u64) {
                e.insert(next_class_idx);
                next_class_idx += 1;
            }

            class_indices.push(class_map[&label_u64]);
        }

        let n_classes = class_map.len();

        if n_classes <= 1 {
            return Err(TransformError::InvalidInput(
                "y has less than 2 classes, LDA requires at least 2 classes".to_string(),
            ));
        }

        // LDA can produce at most n_classes - 1 discriminant directions.
        let maxn_components = n_classes - 1;
        if self.n_components > maxn_components {
            return Err(TransformError::InvalidInput(format!(
                "n_components={} must be <= n_classes-1={}",
                self.n_components, maxn_components
            )));
        }

        // Compute class means (sum then divide by per-class count)
        let mut class_means = Array2::zeros((n_classes, n_features));
        let mut class_counts = vec![0; n_classes];

        for i in 0..n_samples {
            let class_idx = class_indices[i];
            class_counts[class_idx] += 1;

            for j in 0..n_features {
                class_means[[class_idx, j]] += x_f64[[i, j]];
            }
        }

        for i in 0..n_classes {
            if class_counts[i] > 0 {
                for j in 0..n_features {
                    class_means[[i, j]] /= class_counts[i] as f64;
                }
            }
        }

        // Compute global mean
        let mut global_mean = Array1::<f64>::zeros(n_features);
        for i in 0..n_samples {
            for j in 0..n_features {
                global_mean[j] += x_f64[[i, j]];
            }
        }
        global_mean.mapv_inplace(|x: f64| x / n_samples as f64);

        // Compute within-class scatter matrix: Sw = Σ_i (x_i - μ_class(i))(x_i - μ_class(i))^T
        let mut sw = Array2::<f64>::zeros((n_features, n_features));
        for i in 0..n_samples {
            let class_idx = class_indices[i];
            let mut x_centered = Array1::<f64>::zeros(n_features);

            for j in 0..n_features {
                x_centered[j] = x_f64[[i, j]] - class_means[[class_idx, j]];
            }

            for j in 0..n_features {
                for k in 0..n_features {
                    sw[[j, k]] += x_centered[j] * x_centered[k];
                }
            }
        }

        // Compute between-class scatter matrix: Sb = Σ_c n_c (μ_c - μ)(μ_c - μ)^T
        let mut sb = Array2::<f64>::zeros((n_features, n_features));
        for i in 0..n_classes {
            let mut mean_diff = Array1::<f64>::zeros(n_features);
            for j in 0..n_features {
                mean_diff[j] = class_means[[i, j]] - global_mean[j];
            }

            for j in 0..n_features {
                for k in 0..n_features {
                    sb[[j, k]] += class_counts[i] as f64 * mean_diff[j] * mean_diff[k];
                }
            }
        }

        // Solve the generalized eigenvalue problem
        let mut components = Array2::<f64>::zeros((self.n_components, n_features));
        let mut eigenvalues = Array1::<f64>::zeros(self.n_components);

        if self.solver == "svd" {
            // SVD-based solver: whiten with Sw^(-1/2), then decompose the
            // whitened between-class scatter.

            // Decompose the within-class scatter matrix
            let (u_sw, s_sw, vt_sw) = match svd::<f64>(&sw.view(), true, None) {
                Ok(result) => result,
                Err(e) => return Err(TransformError::LinalgError(e)),
            };

            // Compute the pseudoinverse of sw^(1/2); singular values below
            // EPSILON are treated as zero and skipped.
            let mut sw_sqrt_inv = Array2::<f64>::zeros((n_features, n_features));
            for i in 0..n_features {
                if s_sw[i] > EPSILON {
                    for j in 0..n_features {
                        for k in 0..n_features {
                            let s_inv_sqrt = 1.0 / s_sw[i].sqrt();
                            sw_sqrt_inv[[j, k]] += u_sw[[j, i]] * s_inv_sqrt * vt_sw[[i, k]];
                        }
                    }
                }
            }

            // Transform the between-class scatter matrix:
            // sb_transformed = sw_sqrt_inv * sb * sw_sqrt_inv
            // NOTE(review): this quadruple loop is O(n_features^4); fine for
            // small feature counts, but worth revisiting for wide data.
            let mut sb_transformed = Array2::<f64>::zeros((n_features, n_features));
            for i in 0..n_features {
                for j in 0..n_features {
                    for k in 0..n_features {
                        for l in 0..n_features {
                            sb_transformed[[i, j]] +=
                                sw_sqrt_inv[[i, k]] * sb[[k, l]] * sw_sqrt_inv[[l, j]];
                        }
                    }
                }
            }

            // Perform SVD on the transformed between-class scatter matrix
            let (u_sb, s_sb, vt_sb) = match svd::<f64>(&sb_transformed.view(), true, None) {
                Ok(result) => result,
                Err(e) => return Err(TransformError::LinalgError(e)),
            };

            // Map the leading directions back through the whitening transform
            // to obtain the LDA components.
            for i in 0..self.n_components {
                eigenvalues[i] = s_sb[i];

                for j in 0..n_features {
                    for k in 0..n_features {
                        components[[i, j]] += sw_sqrt_inv[[k, j]] * u_sb[[k, i]];
                    }
                }
            }
        } else {
            // Eigen-based solver - proper generalized eigenvalue problem
            // Solve: Sb * v = λ * Sw * v

            // Step 1: Regularize Sw to ensure it's invertible
            let mut sw_reg = sw.clone();
            for i in 0..n_features {
                sw_reg[[i, i]] += EPSILON; // Add small regularization to diagonal
            }

            // Step 2: Compute Cholesky decomposition of regularized Sw
            // We'll use a simpler approach: Sw^(-1) * Sb
            let (u_sw, s_sw, vt_sw) = match svd::<f64>(&sw_reg.view(), true, None) {
                Ok(result) => result,
                Err(e) => return Err(TransformError::LinalgError(e)),
            };

            // Compute pseudoinverse of Sw (skip near-zero singular values)
            let mut sw_inv = Array2::<f64>::zeros((n_features, n_features));
            for i in 0..n_features {
                if s_sw[i] > EPSILON {
                    for j in 0..n_features {
                        for k in 0..n_features {
                            sw_inv[[j, k]] += u_sw[[j, i]] * (1.0 / s_sw[i]) * vt_sw[[i, k]];
                        }
                    }
                }
            }

            // Step 3: Compute Sw^(-1) * Sb
            let mut sw_inv_sb = Array2::<f64>::zeros((n_features, n_features));
            for i in 0..n_features {
                for j in 0..n_features {
                    for k in 0..n_features {
                        sw_inv_sb[[i, j]] += sw_inv[[i, k]] * sb[[k, j]];
                    }
                }
            }

            // Step 4: Compute eigendecomposition of Sw^(-1) * Sb
            // Since this matrix may not be symmetric, we use the approach where we
            // symmetrize it by computing (Sw^(-1) * Sb + (Sw^(-1) * Sb)^T) / 2
            // NOTE(review): symmetrizing changes the spectrum in general — the
            // result is an approximation of the true generalized eigenproblem.
            let mut sym_matrix = Array2::<f64>::zeros((n_features, n_features));
            for i in 0..n_features {
                for j in 0..n_features {
                    sym_matrix[[i, j]] = (sw_inv_sb[[i, j]] + sw_inv_sb[[j, i]]) / 2.0;
                }
            }

            // Perform eigendecomposition on the symmetrized matrix
            let (eig_vals, eig_vecs) = match scirs2_linalg::eigh::<f64>(&sym_matrix.view(), None) {
                Ok(result) => result,
                Err(_) => {
                    // Fallback to SVD if eigendecomposition fails
                    // (singular values stand in for eigenvalues, U for eigenvectors)
                    let (u, s, vt) = match svd::<f64>(&sw_inv_sb.view(), true, None) {
                        Ok(result) => result,
                        Err(e) => return Err(TransformError::LinalgError(e)),
                    };
                    (s, u)
                }
            };

            // Sort eigenvalues and eigenvectors in descending order
            let mut indices: Vec<usize> = (0..n_features).collect();
            indices.sort_by(|&i, &j| {
                eig_vals[j]
                    .partial_cmp(&eig_vals[i])
                    .expect("Operation failed")
            });

            // Select top n_components eigenvectors (columns of eig_vecs)
            for i in 0..self.n_components {
                let idx = indices[i];
                eigenvalues[i] = eig_vals[idx].max(0.0); // Ensure non-negative

                for j in 0..n_features {
                    components[[i, j]] = eig_vecs[[j, idx]];
                }
            }

            // Normalize components to unit length (skip near-zero vectors)
            for i in 0..self.n_components {
                let mut norm = 0.0;
                for j in 0..n_features {
                    norm += components[[i, j]] * components[[i, j]];
                }
                norm = norm.sqrt();

                if norm > EPSILON {
                    for j in 0..n_features {
                        components[[i, j]] /= norm;
                    }
                }
            }
        }

        // Compute explained variance ratio
        // NOTE(review): if every eigenvalue is zero this divides by zero and
        // the ratios become NaN — confirm whether that case needs a guard.
        let total_eigenvalues = eigenvalues.iter().sum::<f64>();
        let explained_variance_ratio = eigenvalues.mapv(|e| e / total_eigenvalues);

        self.components = Some(components);
        self.means = Some(class_means);
        self.explained_variance_ratio = Some(explained_variance_ratio);

        Ok(())
    }
795
796    /// Transforms the input data using the fitted LDA model
797    ///
798    /// # Arguments
799    /// * `x` - The input data, shape (n_samples, n_features)
800    ///
801    /// # Returns
802    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
803    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
804    where
805        S: Data,
806        S::Elem: Float + NumCast,
807    {
808        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
809
810        let n_samples = x_f64.shape()[0];
811        let n_features = x_f64.shape()[1];
812
813        if self.components.is_none() {
814            return Err(TransformError::TransformationError(
815                "LDA model has not been fitted".to_string(),
816            ));
817        }
818
819        let components = self.components.as_ref().expect("Operation failed");
820
821        if n_features != components.shape()[1] {
822            return Err(TransformError::InvalidInput(format!(
823                "x has {} features, but LDA was fitted with {} features",
824                n_features,
825                components.shape()[1]
826            )));
827        }
828
829        // Project data onto LDA components
830        let mut transformed = Array2::zeros((n_samples, self.n_components));
831
832        for i in 0..n_samples {
833            for j in 0..self.n_components {
834                let mut dot_product = 0.0;
835                for k in 0..n_features {
836                    dot_product += x_f64[[i, k]] * components[[j, k]];
837                }
838                transformed[[i, j]] = dot_product;
839            }
840        }
841
842        Ok(transformed)
843    }
844
845    /// Fits the LDA model to the input data and transforms it
846    ///
847    /// # Arguments
848    /// * `x` - The input data, shape (n_samples, n_features)
849    /// * `y` - The target labels, shape (n_samples,)
850    ///
851    /// # Returns
852    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
853    pub fn fit_transform<S1, S2>(
854        &mut self,
855        x: &ArrayBase<S1, Ix2>,
856        y: &ArrayBase<S2, Ix1>,
857    ) -> Result<Array2<f64>>
858    where
859        S1: Data,
860        S2: Data,
861        S1::Elem: Float + NumCast,
862        S2::Elem: Copy + NumCast + Eq + std::hash::Hash,
863    {
864        self.fit(x, y)?;
865        self.transform(x)
866    }
867
    /// Returns the LDA components
    ///
    /// `None` until `fit`/`fit_transform` has completed successfully.
    ///
    /// # Returns
    /// * `Option<&Array2<f64>>` - The LDA components, shape (n_components, n_features)
    pub fn components(&self) -> Option<&Array2<f64>> {
        self.components.as_ref()
    }
875
    /// Returns the explained variance ratio
    ///
    /// `None` until `fit`/`fit_transform` has completed successfully.
    ///
    /// # Returns
    /// * `Option<&Array1<f64>>` - The explained variance ratio, one entry per kept component
    pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
        self.explained_variance_ratio.as_ref()
    }
883}
884
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_pca_transform() {
        // Simple 4x3 dataset (rows are evenly spaced along one direction).
        let x = Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .expect("12 elements match a (4, 3) shape");

        // Initialize and fit PCA with 2 components (centered, unscaled).
        let mut pca = PCA::new(2, true, false);
        let x_transformed = pca
            .fit_transform(&x)
            .expect("PCA should fit a small well-formed dataset");

        // Check that the shape is correct
        assert_eq!(x_transformed.shape(), &[4, 2]);

        // Check that we have the correct number of explained variance components
        let explained_variance = pca
            .explained_variance_ratio()
            .expect("explained variance is available after fitting");
        assert_eq!(explained_variance.len(), 2);

        // Check that the sum is a valid number (we don't need to enforce sum = 1)
        assert!(explained_variance.sum() > 0.0 && explained_variance.sum().is_finite());
    }

    #[test]
    fn test_truncated_svd() {
        // Same simple 4x3 dataset as the PCA test.
        let x = Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .expect("12 elements match a (4, 3) shape");

        // Initialize and fit TruncatedSVD with 2 components
        let mut svd = TruncatedSVD::new(2);
        let x_transformed = svd
            .fit_transform(&x)
            .expect("TruncatedSVD should fit a small well-formed dataset");

        // Check that the shape is correct
        assert_eq!(x_transformed.shape(), &[4, 2]);

        // Check that we have the correct number of explained variance components
        let explained_variance = svd
            .explained_variance_ratio()
            .expect("explained variance is available after fitting");
        assert_eq!(explained_variance.len(), 2);

        // Check that the sum is a valid number (we don't need to enforce sum = 1)
        assert!(explained_variance.sum() > 0.0 && explained_variance.sum().is_finite());
    }

    #[test]
    fn test_lda() {
        // Two linearly separable classes in 2-D (three samples each).
        let x = Array::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 4.0, 6.0, 5.0, 7.0, 4.0],
        )
        .expect("12 elements match a (6, 2) shape");

        let y = Array::from_vec(vec![0, 0, 0, 1, 1, 1]);

        // With 2 classes, LDA supports at most 1 discriminant component.
        let mut lda = LDA::new(1, "svd").expect("'svd' is a supported solver name");
        let x_transformed = lda
            .fit_transform(&x, &y)
            .expect("LDA should fit two separable classes");

        // Check that the shape is correct
        assert_eq!(x_transformed.shape(), &[6, 1]);

        // A single component necessarily captures all between-class variance.
        let explained_variance = lda
            .explained_variance_ratio()
            .expect("explained variance is available after fitting");
        assert_abs_diff_eq!(explained_variance[0], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_lda_eigen_solver() {
        // Three well-separated classes in 2-D (three samples each).
        let x = Array::from_shape_vec(
            (9, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, // Class 0
                5.0, 4.0, 6.0, 5.0, 7.0, 4.0, // Class 1
                9.0, 8.0, 10.0, 9.0, 11.0, 10.0, // Class 2
            ],
        )
        .expect("18 elements match a (9, 2) shape");

        let y = Array::from_vec(vec![0, 0, 0, 1, 1, 1, 2, 2, 2]);

        // Test eigen solver (2 components is the max for 3 classes).
        let mut lda_eigen = LDA::new(2, "eigen").expect("'eigen' is a supported solver name");
        let x_transformed_eigen = lda_eigen
            .fit_transform(&x, &y)
            .expect("eigen-solver LDA should fit three separable classes");

        // Test SVD solver for comparison
        let mut lda_svd = LDA::new(2, "svd").expect("'svd' is a supported solver name");
        let x_transformed_svd = lda_svd
            .fit_transform(&x, &y)
            .expect("svd-solver LDA should fit three separable classes");

        // Check that both transformations have correct shape
        assert_eq!(x_transformed_eigen.shape(), &[9, 2]);
        assert_eq!(x_transformed_svd.shape(), &[9, 2]);

        // Check that both produce valid (finite) results
        assert!(x_transformed_eigen.iter().all(|&x| x.is_finite()));
        assert!(x_transformed_svd.iter().all(|&x| x.is_finite()));

        // Check that explained variance ratios are valid for both solvers
        let explained_variance_eigen = lda_eigen
            .explained_variance_ratio()
            .expect("explained variance is available after fitting (eigen)");
        let explained_variance_svd = lda_svd
            .explained_variance_ratio()
            .expect("explained variance is available after fitting (svd)");

        assert_eq!(explained_variance_eigen.len(), 2);
        assert_eq!(explained_variance_svd.len(), 2);

        // Both solvers should report ratios that sum to approximately 1.0
        assert_abs_diff_eq!(explained_variance_eigen.sum(), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(explained_variance_svd.sum(), 1.0, epsilon = 1e-10);

        // Eigenvalues (and hence ratios) should be non-negative
        assert!(explained_variance_eigen.iter().all(|&x| x >= 0.0));
        assert!(explained_variance_svd.iter().all(|&x| x >= 0.0));
    }

    #[test]
    fn test_lda_invalid_solver() {
        // An unrecognized solver name must be rejected at construction time.
        let result = LDA::new(1, "invalid");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("solver must be 'svd' or 'eigen'"));
    }
}