Skip to main content

scirs2_transform/reduction/
mod.rs

1//! Dimensionality reduction techniques
2//!
3//! This module provides algorithms for reducing the dimensionality of data,
4//! which is useful for visualization, feature extraction, and reducing
5//! computational complexity.
6
7/// Factor Analysis module
8pub mod factor_analysis;
9mod isomap;
10mod lle;
11mod spectral_embedding;
12mod tsne;
13mod umap;
14
15/// Laplacian Eigenmaps for manifold learning
16pub mod laplacian_eigenmaps;
17
18/// Diffusion Maps for nonlinear dimensionality reduction
19pub mod diffusion_maps;
20
21pub use crate::reduction::diffusion_maps::DiffusionMaps;
22pub use crate::reduction::factor_analysis::{
23    factor_analysis, scree_plot_data, FactorAnalysis, FactorAnalysisResult, RotationMethod,
24    ScreePlotData,
25};
26pub use crate::reduction::isomap::Isomap;
27pub use crate::reduction::laplacian_eigenmaps::{
28    GraphMethod, LaplacianEigenmaps, LaplacianType as LELaplacianType,
29};
30pub use crate::reduction::lle::LLE;
31pub use crate::reduction::spectral_embedding::{AffinityMethod, SpectralEmbedding};
32pub use crate::reduction::tsne::{trustworthiness, TSNE};
33pub use crate::reduction::umap::UMAP;
34
35use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
36use scirs2_core::numeric::{Float, NumCast};
37use scirs2_linalg::svd;
38
39use crate::error::{Result, TransformError};
40
/// Small tolerance used when comparing quantities such as singular values and
/// vector norms against zero; the LDA eigen solver also adds it to the diagonal
/// of the within-class scatter matrix as a regularizer.
const EPSILON: f64 = 1.0e-10;
43
44/// Principal Component Analysis (PCA) dimensionality reduction
45///
46/// PCA finds the directions of maximum variance in the data and
47/// projects the data onto a lower dimensional space.
48#[derive(Debug, Clone)]
49pub struct PCA {
50    /// Number of components to keep
51    n_components: usize,
52    /// Whether to center the data before computing the SVD
53    center: bool,
54    /// Whether to scale the data before computing the SVD
55    scale: bool,
56    /// The principal components
57    components: Option<Array2<f64>>,
58    /// The mean of the training data
59    mean: Option<Array1<f64>>,
60    /// The standard deviation of the training data
61    std: Option<Array1<f64>>,
62    /// The singular values of the centered training data
63    singular_values: Option<Array1<f64>>,
64    /// The explained variance ratio
65    explained_variance_ratio: Option<Array1<f64>>,
66}
67
68impl PCA {
69    /// Creates a new PCA instance
70    ///
71    /// # Arguments
72    /// * `n_components` - Number of components to keep
73    /// * `center` - Whether to center the data before computing the SVD
74    /// * `scale` - Whether to scale the data before computing the SVD
75    ///
76    /// # Returns
77    /// * A new PCA instance
78    pub fn new(ncomponents: usize, center: bool, scale: bool) -> Self {
79        PCA {
80            n_components: ncomponents,
81            center,
82            scale,
83            components: None,
84            mean: None,
85            std: None,
86            singular_values: None,
87            explained_variance_ratio: None,
88        }
89    }
90
91    /// Fits the PCA model to the input data
92    ///
93    /// # Arguments
94    /// * `x` - The input data, shape (n_samples, n_features)
95    ///
96    /// # Returns
97    /// * `Result<()>` - Ok if successful, Err otherwise
98    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
99    where
100        S: Data,
101        S::Elem: Float + NumCast,
102    {
103        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
104
105        let n_samples = x_f64.shape()[0];
106        let n_features = x_f64.shape()[1];
107
108        if n_samples == 0 || n_features == 0 {
109            return Err(TransformError::InvalidInput("Empty input data".to_string()));
110        }
111
112        if self.n_components > n_features {
113            return Err(TransformError::InvalidInput(format!(
114                "n_components={} must be <= n_features={}",
115                self.n_components, n_features
116            )));
117        }
118
119        // Center and scale data if requested
120        let mut x_processed = Array2::zeros((n_samples, n_features));
121        let mut mean = Array1::zeros(n_features);
122        let mut std = Array1::ones(n_features);
123
124        if self.center {
125            for j in 0..n_features {
126                let col_mean = x_f64.column(j).sum() / n_samples as f64;
127                mean[j] = col_mean;
128
129                for i in 0..n_samples {
130                    x_processed[[i, j]] = x_f64[[i, j]] - col_mean;
131                }
132            }
133        } else {
134            x_processed.assign(&x_f64);
135        }
136
137        if self.scale {
138            for j in 0..n_features {
139                let col_std =
140                    (x_processed.column(j).mapv(|x| x * x).sum() / n_samples as f64).sqrt();
141                if col_std > f64::EPSILON {
142                    std[j] = col_std;
143
144                    for i in 0..n_samples {
145                        x_processed[[i, j]] /= col_std;
146                    }
147                }
148            }
149        }
150
151        // Perform SVD
152        let (_u, s, vt) = match svd::<f64>(&x_processed.view(), true, None) {
153            Ok(result) => result,
154            Err(e) => return Err(TransformError::LinalgError(e)),
155        };
156
157        // Extract components and singular values
158        let mut components = Array2::zeros((self.n_components, n_features));
159        let mut singular_values = Array1::zeros(self.n_components);
160
161        for i in 0..self.n_components {
162            singular_values[i] = s[i];
163            for j in 0..n_features {
164                components[[i, j]] = vt[[i, j]];
165            }
166        }
167
168        // Compute explained variance ratio
169        let total_variance = s.mapv(|s| s * s).sum();
170        let explained_variance_ratio = singular_values.mapv(|s| s * s / total_variance);
171
172        self.components = Some(components);
173        self.mean = Some(mean);
174        self.std = Some(std);
175        self.singular_values = Some(singular_values);
176        self.explained_variance_ratio = Some(explained_variance_ratio);
177
178        Ok(())
179    }
180
181    /// Transforms the input data using the fitted PCA model
182    ///
183    /// # Arguments
184    /// * `x` - The input data, shape (n_samples, n_features)
185    ///
186    /// # Returns
187    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
188    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
189    where
190        S: Data,
191        S::Elem: Float + NumCast,
192    {
193        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
194
195        let n_samples = x_f64.shape()[0];
196        let n_features = x_f64.shape()[1];
197
198        if self.components.is_none() {
199            return Err(TransformError::TransformationError(
200                "PCA model has not been fitted".to_string(),
201            ));
202        }
203
204        let components = self.components.as_ref().expect("Operation failed");
205        let mean = self.mean.as_ref().expect("Operation failed");
206        let std = self.std.as_ref().expect("Operation failed");
207
208        if n_features != components.shape()[1] {
209            return Err(TransformError::InvalidInput(format!(
210                "x has {} features, but PCA was fitted with {} features",
211                n_features,
212                components.shape()[1]
213            )));
214        }
215
216        // Center and scale data if the model was fitted with centering/scaling
217        let mut x_processed = Array2::zeros((n_samples, n_features));
218
219        for i in 0..n_samples {
220            for j in 0..n_features {
221                let mut value = x_f64[[i, j]];
222
223                if self.center {
224                    value -= mean[j];
225                }
226
227                if self.scale {
228                    value /= std[j];
229                }
230
231                x_processed[[i, j]] = value;
232            }
233        }
234
235        // Project data onto principal components
236        let mut transformed = Array2::zeros((n_samples, self.n_components));
237
238        for i in 0..n_samples {
239            for j in 0..self.n_components {
240                let mut dot_product = 0.0;
241                for k in 0..n_features {
242                    dot_product += x_processed[[i, k]] * components[[j, k]];
243                }
244                transformed[[i, j]] = dot_product;
245            }
246        }
247
248        Ok(transformed)
249    }
250
251    /// Fits the PCA model to the input data and transforms it
252    ///
253    /// # Arguments
254    /// * `x` - The input data, shape (n_samples, n_features)
255    ///
256    /// # Returns
257    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
258    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
259    where
260        S: Data,
261        S::Elem: Float + NumCast,
262    {
263        self.fit(x)?;
264        self.transform(x)
265    }
266
267    /// Returns the principal components
268    ///
269    /// # Returns
270    /// * `Option<&Array2<f64>>` - The principal components, shape (n_components, n_features)
271    pub fn components(&self) -> Option<&Array2<f64>> {
272        self.components.as_ref()
273    }
274
275    /// Returns the mean of the training data (used for centering)
276    ///
277    /// # Returns
278    /// * `Option<&Array1<f64>>` - The per-feature mean, shape (n_features,)
279    pub fn mean(&self) -> Option<&Array1<f64>> {
280        self.mean.as_ref()
281    }
282
283    /// Returns the explained variance ratio
284    ///
285    /// # Returns
286    /// * `Option<&Array1<f64>>` - The explained variance ratio
287    pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
288        self.explained_variance_ratio.as_ref()
289    }
290}
291
292/// Truncated Singular Value Decomposition (SVD) for dimensionality reduction
293///
294/// This transformer performs linear dimensionality reduction by means of
295/// truncated singular value decomposition (SVD). It works on any data and
296/// not just sparse matrices.
297#[derive(Debug, Clone)]
298pub struct TruncatedSVD {
299    /// Number of components to keep
300    n_components: usize,
301    /// The singular values of the training data
302    singular_values: Option<Array1<f64>>,
303    /// The right singular vectors
304    components: Option<Array2<f64>>,
305    /// The explained variance ratio
306    explained_variance_ratio: Option<Array1<f64>>,
307}
308
309impl TruncatedSVD {
310    /// Creates a new TruncatedSVD instance
311    ///
312    /// # Arguments
313    /// * `n_components` - Number of components to keep
314    ///
315    /// # Returns
316    /// * A new TruncatedSVD instance
317    pub fn new(ncomponents: usize) -> Self {
318        TruncatedSVD {
319            n_components: ncomponents,
320            singular_values: None,
321            components: None,
322            explained_variance_ratio: None,
323        }
324    }
325
326    /// Fits the TruncatedSVD model to the input data
327    ///
328    /// # Arguments
329    /// * `x` - The input data, shape (n_samples, n_features)
330    ///
331    /// # Returns
332    /// * `Result<()>` - Ok if successful, Err otherwise
333    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
334    where
335        S: Data,
336        S::Elem: Float + NumCast,
337    {
338        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
339
340        let n_samples = x_f64.shape()[0];
341        let n_features = x_f64.shape()[1];
342
343        if n_samples == 0 || n_features == 0 {
344            return Err(TransformError::InvalidInput("Empty input data".to_string()));
345        }
346
347        if self.n_components > n_features {
348            return Err(TransformError::InvalidInput(format!(
349                "n_components={} must be <= n_features={}",
350                self.n_components, n_features
351            )));
352        }
353
354        // Perform SVD
355        let (_u, s, vt) = match svd::<f64>(&x_f64.view(), true, None) {
356            Ok(result) => result,
357            Err(e) => return Err(TransformError::LinalgError(e)),
358        };
359
360        // Extract components and singular values
361        let mut components = Array2::zeros((self.n_components, n_features));
362        let mut singular_values = Array1::zeros(self.n_components);
363
364        for i in 0..self.n_components {
365            singular_values[i] = s[i];
366            for j in 0..n_features {
367                components[[i, j]] = vt[[i, j]];
368            }
369        }
370
371        // Compute explained variance ratio
372        let total_variance =
373            (x_f64.map_axis(Axis(1), |row| row.dot(&row)).sum()) / n_samples as f64;
374        let explained_variance = singular_values.mapv(|s| s * s / n_samples as f64);
375        let explained_variance_ratio = explained_variance.mapv(|v| v / total_variance);
376
377        self.singular_values = Some(singular_values);
378        self.components = Some(components);
379        self.explained_variance_ratio = Some(explained_variance_ratio);
380
381        Ok(())
382    }
383
384    /// Transforms the input data using the fitted TruncatedSVD model
385    ///
386    /// # Arguments
387    /// * `x` - The input data, shape (n_samples, n_features)
388    ///
389    /// # Returns
390    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
391    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
392    where
393        S: Data,
394        S::Elem: Float + NumCast,
395    {
396        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
397
398        let n_samples = x_f64.shape()[0];
399        let n_features = x_f64.shape()[1];
400
401        if self.components.is_none() {
402            return Err(TransformError::TransformationError(
403                "TruncatedSVD model has not been fitted".to_string(),
404            ));
405        }
406
407        let components = self.components.as_ref().expect("Operation failed");
408
409        if n_features != components.shape()[1] {
410            return Err(TransformError::InvalidInput(format!(
411                "x has {} features, but TruncatedSVD was fitted with {} features",
412                n_features,
413                components.shape()[1]
414            )));
415        }
416
417        // Project data onto components
418        let mut transformed = Array2::zeros((n_samples, self.n_components));
419
420        for i in 0..n_samples {
421            for j in 0..self.n_components {
422                let mut dot_product = 0.0;
423                for k in 0..n_features {
424                    dot_product += x_f64[[i, k]] * components[[j, k]];
425                }
426                transformed[[i, j]] = dot_product;
427            }
428        }
429
430        Ok(transformed)
431    }
432
433    /// Fits the TruncatedSVD model to the input data and transforms it
434    ///
435    /// # Arguments
436    /// * `x` - The input data, shape (n_samples, n_features)
437    ///
438    /// # Returns
439    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
440    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
441    where
442        S: Data,
443        S::Elem: Float + NumCast,
444    {
445        self.fit(x)?;
446        self.transform(x)
447    }
448
449    /// Returns the components (right singular vectors)
450    ///
451    /// # Returns
452    /// * `Option<&Array2<f64>>` - The components, shape (n_components, n_features)
453    pub fn components(&self) -> Option<&Array2<f64>> {
454        self.components.as_ref()
455    }
456
457    /// Returns the singular values
458    ///
459    /// # Returns
460    /// * `Option<&Array1<f64>>` - The singular values
461    pub fn singular_values(&self) -> Option<&Array1<f64>> {
462        self.singular_values.as_ref()
463    }
464
465    /// Returns the explained variance ratio
466    ///
467    /// # Returns
468    /// * `Option<&Array1<f64>>` - The explained variance ratio
469    pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
470        self.explained_variance_ratio.as_ref()
471    }
472}
473
474/// Linear Discriminant Analysis (LDA) for dimensionality reduction
475///
476/// LDA finds the directions that maximize the separation between classes.
477#[derive(Debug, Clone)]
478pub struct LDA {
479    /// Number of components to keep
480    n_components: usize,
481    /// Whether to use Singular Value Decomposition
482    solver: String,
483    /// The LDA components
484    components: Option<Array2<f64>>,
485    /// The class means
486    means: Option<Array2<f64>>,
487    /// The explained variance ratio
488    explained_variance_ratio: Option<Array1<f64>>,
489}
490
491impl LDA {
492    /// Creates a new LDA instance
493    ///
494    /// # Arguments
495    /// * `n_components` - Number of components to keep
496    /// * `solver` - The solver to use ('svd' or 'eigen')
497    ///
498    /// # Returns
499    /// * A new LDA instance
500    pub fn new(ncomponents: usize, solver: &str) -> Result<Self> {
501        if solver != "svd" && solver != "eigen" {
502            return Err(TransformError::InvalidInput(
503                "solver must be 'svd' or 'eigen'".to_string(),
504            ));
505        }
506
507        Ok(LDA {
508            n_components: ncomponents,
509            solver: solver.to_string(),
510            components: None,
511            means: None,
512            explained_variance_ratio: None,
513        })
514    }
515
516    /// Fits the LDA model to the input data
517    ///
518    /// # Arguments
519    /// * `x` - The input data, shape (n_samples, n_features)
520    /// * `y` - The target labels, shape (n_samples,)
521    ///
522    /// # Returns
523    /// * `Result<()>` - Ok if successful, Err otherwise
524    pub fn fit<S1, S2>(&mut self, x: &ArrayBase<S1, Ix2>, y: &ArrayBase<S2, Ix1>) -> Result<()>
525    where
526        S1: Data,
527        S2: Data,
528        S1::Elem: Float + NumCast,
529        S2::Elem: Copy + NumCast + Eq + std::hash::Hash,
530    {
531        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
532
533        let n_samples = x_f64.shape()[0];
534        let n_features = x_f64.shape()[1];
535
536        if n_samples == 0 || n_features == 0 {
537            return Err(TransformError::InvalidInput("Empty input data".to_string()));
538        }
539
540        if n_samples != y.len() {
541            return Err(TransformError::InvalidInput(format!(
542                "x and y have incompatible shapes: x has {} samples, y has {} elements",
543                n_samples,
544                y.len()
545            )));
546        }
547
548        // Convert y to class indices
549        let mut class_indices = vec![];
550        let mut class_map = std::collections::HashMap::new();
551        let mut next_class_idx = 0;
552
553        for &label in y.iter() {
554            let label_u64 = NumCast::from(label).unwrap_or(0);
555
556            if let std::collections::hash_map::Entry::Vacant(e) = class_map.entry(label_u64) {
557                e.insert(next_class_idx);
558                next_class_idx += 1;
559            }
560
561            class_indices.push(class_map[&label_u64]);
562        }
563
564        let n_classes = class_map.len();
565
566        if n_classes <= 1 {
567            return Err(TransformError::InvalidInput(
568                "y has less than 2 classes, LDA requires at least 2 classes".to_string(),
569            ));
570        }
571
572        let maxn_components = n_classes - 1;
573        if self.n_components > maxn_components {
574            return Err(TransformError::InvalidInput(format!(
575                "n_components={} must be <= n_classes-1={}",
576                self.n_components, maxn_components
577            )));
578        }
579
580        // Compute class means
581        let mut class_means = Array2::zeros((n_classes, n_features));
582        let mut class_counts = vec![0; n_classes];
583
584        for i in 0..n_samples {
585            let class_idx = class_indices[i];
586            class_counts[class_idx] += 1;
587
588            for j in 0..n_features {
589                class_means[[class_idx, j]] += x_f64[[i, j]];
590            }
591        }
592
593        for i in 0..n_classes {
594            if class_counts[i] > 0 {
595                for j in 0..n_features {
596                    class_means[[i, j]] /= class_counts[i] as f64;
597                }
598            }
599        }
600
601        // Compute global mean
602        let mut global_mean = Array1::<f64>::zeros(n_features);
603        for i in 0..n_samples {
604            for j in 0..n_features {
605                global_mean[j] += x_f64[[i, j]];
606            }
607        }
608        global_mean.mapv_inplace(|x: f64| x / n_samples as f64);
609
610        // Compute within-class scatter matrix
611        let mut sw = Array2::<f64>::zeros((n_features, n_features));
612        for i in 0..n_samples {
613            let class_idx = class_indices[i];
614            let mut x_centered = Array1::<f64>::zeros(n_features);
615
616            for j in 0..n_features {
617                x_centered[j] = x_f64[[i, j]] - class_means[[class_idx, j]];
618            }
619
620            for j in 0..n_features {
621                for k in 0..n_features {
622                    sw[[j, k]] += x_centered[j] * x_centered[k];
623                }
624            }
625        }
626
627        // Compute between-class scatter matrix
628        let mut sb = Array2::<f64>::zeros((n_features, n_features));
629        for i in 0..n_classes {
630            let mut mean_diff = Array1::<f64>::zeros(n_features);
631            for j in 0..n_features {
632                mean_diff[j] = class_means[[i, j]] - global_mean[j];
633            }
634
635            for j in 0..n_features {
636                for k in 0..n_features {
637                    sb[[j, k]] += class_counts[i] as f64 * mean_diff[j] * mean_diff[k];
638                }
639            }
640        }
641
642        // Solve the generalized eigenvalue problem
643        let mut components = Array2::<f64>::zeros((self.n_components, n_features));
644        let mut eigenvalues = Array1::<f64>::zeros(self.n_components);
645
646        if self.solver == "svd" {
647            // SVD-based solver
648
649            // Decompose the within-class scatter matrix
650            let (u_sw, s_sw, vt_sw) = match svd::<f64>(&sw.view(), true, None) {
651                Ok(result) => result,
652                Err(e) => return Err(TransformError::LinalgError(e)),
653            };
654
655            // Compute the pseudoinverse of sw^(1/2)
656            let mut sw_sqrt_inv = Array2::<f64>::zeros((n_features, n_features));
657            for i in 0..n_features {
658                if s_sw[i] > EPSILON {
659                    for j in 0..n_features {
660                        for k in 0..n_features {
661                            let s_inv_sqrt = 1.0 / s_sw[i].sqrt();
662                            sw_sqrt_inv[[j, k]] += u_sw[[j, i]] * s_inv_sqrt * vt_sw[[i, k]];
663                        }
664                    }
665                }
666            }
667
668            // Transform the between-class scatter matrix
669            let mut sb_transformed = Array2::<f64>::zeros((n_features, n_features));
670            for i in 0..n_features {
671                for j in 0..n_features {
672                    for k in 0..n_features {
673                        for l in 0..n_features {
674                            sb_transformed[[i, j]] +=
675                                sw_sqrt_inv[[i, k]] * sb[[k, l]] * sw_sqrt_inv[[l, j]];
676                        }
677                    }
678                }
679            }
680
681            // Perform SVD on the transformed between-class scatter matrix
682            let (u_sb, s_sb, vt_sb) = match svd::<f64>(&sb_transformed.view(), true, None) {
683                Ok(result) => result,
684                Err(e) => return Err(TransformError::LinalgError(e)),
685            };
686
687            // Compute the LDA components
688            for i in 0..self.n_components {
689                eigenvalues[i] = s_sb[i];
690
691                for j in 0..n_features {
692                    for k in 0..n_features {
693                        components[[i, j]] += sw_sqrt_inv[[k, j]] * u_sb[[k, i]];
694                    }
695                }
696            }
697        } else {
698            // Eigen-based solver - proper generalized eigenvalue problem
699            // Solve: Sb * v = λ * Sw * v
700
701            // Step 1: Regularize Sw to ensure it's invertible
702            let mut sw_reg = sw.clone();
703            for i in 0..n_features {
704                sw_reg[[i, i]] += EPSILON; // Add small regularization to diagonal
705            }
706
707            // Step 2: Compute Cholesky decomposition of regularized Sw
708            // We'll use a simpler approach: Sw^(-1) * Sb
709            let (u_sw, s_sw, vt_sw) = match svd::<f64>(&sw_reg.view(), true, None) {
710                Ok(result) => result,
711                Err(e) => return Err(TransformError::LinalgError(e)),
712            };
713
714            // Compute pseudoinverse of Sw
715            let mut sw_inv = Array2::<f64>::zeros((n_features, n_features));
716            for i in 0..n_features {
717                if s_sw[i] > EPSILON {
718                    for j in 0..n_features {
719                        for k in 0..n_features {
720                            sw_inv[[j, k]] += u_sw[[j, i]] * (1.0 / s_sw[i]) * vt_sw[[i, k]];
721                        }
722                    }
723                }
724            }
725
726            // Step 3: Compute Sw^(-1) * Sb
727            let mut sw_inv_sb = Array2::<f64>::zeros((n_features, n_features));
728            for i in 0..n_features {
729                for j in 0..n_features {
730                    for k in 0..n_features {
731                        sw_inv_sb[[i, j]] += sw_inv[[i, k]] * sb[[k, j]];
732                    }
733                }
734            }
735
736            // Step 4: Compute eigendecomposition of Sw^(-1) * Sb
737            // Since this matrix may not be symmetric, we use the approach where we
738            // symmetrize it by computing (Sw^(-1) * Sb + (Sw^(-1) * Sb)^T) / 2
739            let mut sym_matrix = Array2::<f64>::zeros((n_features, n_features));
740            for i in 0..n_features {
741                for j in 0..n_features {
742                    sym_matrix[[i, j]] = (sw_inv_sb[[i, j]] + sw_inv_sb[[j, i]]) / 2.0;
743                }
744            }
745
746            // Perform eigendecomposition on the symmetrized matrix
747            let (eig_vals, eig_vecs) = match scirs2_linalg::eigh::<f64>(&sym_matrix.view(), None) {
748                Ok(result) => result,
749                Err(_) => {
750                    // Fallback to SVD if eigendecomposition fails
751                    let (u, s, vt) = match svd::<f64>(&sw_inv_sb.view(), true, None) {
752                        Ok(result) => result,
753                        Err(e) => return Err(TransformError::LinalgError(e)),
754                    };
755                    (s, u)
756                }
757            };
758
759            // Sort eigenvalues and eigenvectors in descending order
760            let mut indices: Vec<usize> = (0..n_features).collect();
761            indices.sort_by(|&i, &j| {
762                eig_vals[j]
763                    .partial_cmp(&eig_vals[i])
764                    .expect("Operation failed")
765            });
766
767            // Select top n_components eigenvectors
768            for i in 0..self.n_components {
769                let idx = indices[i];
770                eigenvalues[i] = eig_vals[idx].max(0.0); // Ensure non-negative
771
772                for j in 0..n_features {
773                    components[[i, j]] = eig_vecs[[j, idx]];
774                }
775            }
776
777            // Normalize components
778            for i in 0..self.n_components {
779                let mut norm = 0.0;
780                for j in 0..n_features {
781                    norm += components[[i, j]] * components[[i, j]];
782                }
783                norm = norm.sqrt();
784
785                if norm > EPSILON {
786                    for j in 0..n_features {
787                        components[[i, j]] /= norm;
788                    }
789                }
790            }
791        }
792
793        // Compute explained variance ratio
794        let total_eigenvalues = eigenvalues.iter().sum::<f64>();
795        let explained_variance_ratio = eigenvalues.mapv(|e| e / total_eigenvalues);
796
797        self.components = Some(components);
798        self.means = Some(class_means);
799        self.explained_variance_ratio = Some(explained_variance_ratio);
800
801        Ok(())
802    }
803
804    /// Transforms the input data using the fitted LDA model
805    ///
806    /// # Arguments
807    /// * `x` - The input data, shape (n_samples, n_features)
808    ///
809    /// # Returns
810    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
811    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
812    where
813        S: Data,
814        S::Elem: Float + NumCast,
815    {
816        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
817
818        let n_samples = x_f64.shape()[0];
819        let n_features = x_f64.shape()[1];
820
821        if self.components.is_none() {
822            return Err(TransformError::TransformationError(
823                "LDA model has not been fitted".to_string(),
824            ));
825        }
826
827        let components = self.components.as_ref().expect("Operation failed");
828
829        if n_features != components.shape()[1] {
830            return Err(TransformError::InvalidInput(format!(
831                "x has {} features, but LDA was fitted with {} features",
832                n_features,
833                components.shape()[1]
834            )));
835        }
836
837        // Project data onto LDA components
838        let mut transformed = Array2::zeros((n_samples, self.n_components));
839
840        for i in 0..n_samples {
841            for j in 0..self.n_components {
842                let mut dot_product = 0.0;
843                for k in 0..n_features {
844                    dot_product += x_f64[[i, k]] * components[[j, k]];
845                }
846                transformed[[i, j]] = dot_product;
847            }
848        }
849
850        Ok(transformed)
851    }
852
853    /// Fits the LDA model to the input data and transforms it
854    ///
855    /// # Arguments
856    /// * `x` - The input data, shape (n_samples, n_features)
857    /// * `y` - The target labels, shape (n_samples,)
858    ///
859    /// # Returns
860    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
861    pub fn fit_transform<S1, S2>(
862        &mut self,
863        x: &ArrayBase<S1, Ix2>,
864        y: &ArrayBase<S2, Ix1>,
865    ) -> Result<Array2<f64>>
866    where
867        S1: Data,
868        S2: Data,
869        S1::Elem: Float + NumCast,
870        S2::Elem: Copy + NumCast + Eq + std::hash::Hash,
871    {
872        self.fit(x, y)?;
873        self.transform(x)
874    }
875
876    /// Returns the LDA components
877    ///
878    /// # Returns
879    /// * `Option<&Array2<f64>>` - The LDA components, shape (n_components, n_features)
880    pub fn components(&self) -> Option<&Array2<f64>> {
881        self.components.as_ref()
882    }
883
884    /// Returns the explained variance ratio
885    ///
886    /// # Returns
887    /// * `Option<&Array1<f64>>` - The explained variance ratio
888    pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
889        self.explained_variance_ratio.as_ref()
890    }
891}
892
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array, Array2};

    /// Builds the 4x3 matrix `[[1,2,3],[4,5,6],[7,8,9],[10,11,12]]` shared by the
    /// PCA and TruncatedSVD smoke tests.
    fn sample_matrix_4x3() -> Array2<f64> {
        Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .expect("12 values fill a 4x3 matrix")
    }

    #[test]
    fn test_pca_transform() {
        let x = sample_matrix_4x3();

        // Initialize and fit PCA with 2 components
        let mut pca = PCA::new(2, true, false);
        let x_transformed = pca.fit_transform(&x).expect("PCA fit_transform failed");

        // Check that the shape is correct
        assert_eq!(x_transformed.shape(), &[4, 2]);

        // Check that we have the correct number of explained variance components
        let explained_variance = pca
            .explained_variance_ratio()
            .expect("PCA explained variance ratio missing after fit");
        assert_eq!(explained_variance.len(), 2);

        // Only require a positive, finite total; the ratios need not sum to 1 here.
        let total = explained_variance.sum();
        assert!(
            total > 0.0 && total.is_finite(),
            "invalid PCA explained variance sum: {}",
            total
        );
    }

    #[test]
    fn test_truncated_svd() {
        let x = sample_matrix_4x3();

        // Initialize and fit TruncatedSVD with 2 components
        let mut svd = TruncatedSVD::new(2);
        let x_transformed = svd
            .fit_transform(&x)
            .expect("TruncatedSVD fit_transform failed");

        // Check that the shape is correct
        assert_eq!(x_transformed.shape(), &[4, 2]);

        // Check that we have the correct number of explained variance components
        let explained_variance = svd
            .explained_variance_ratio()
            .expect("TruncatedSVD explained variance ratio missing after fit");
        assert_eq!(explained_variance.len(), 2);

        // Only require a positive, finite total; the ratios need not sum to 1 here.
        let total = explained_variance.sum();
        assert!(
            total > 0.0 && total.is_finite(),
            "invalid TruncatedSVD explained variance sum: {}",
            total
        );
    }

    #[test]
    fn test_lda() {
        // Two classes of three 2-D samples each
        let x = Array::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 4.0, 6.0, 5.0, 7.0, 4.0],
        )
        .expect("12 values fill a 6x2 matrix");

        let y = Array::from_vec(vec![0, 0, 0, 1, 1, 1]);

        // Initialize and fit LDA with 1 component (max for 2 classes)
        let mut lda = LDA::new(1, "svd").expect("'svd' is a valid LDA solver");
        let x_transformed = lda
            .fit_transform(&x, &y)
            .expect("LDA (svd) fit_transform failed");

        // Check that the shape is correct
        assert_eq!(x_transformed.shape(), &[6, 1]);

        // A single retained component accounts for all of the explained variance
        let explained_variance = lda
            .explained_variance_ratio()
            .expect("LDA explained variance ratio missing after fit");
        assert_abs_diff_eq!(explained_variance[0], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_lda_eigen_solver() {
        // Three classes of three 2-D samples each
        let x = Array::from_shape_vec(
            (9, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, // Class 0
                5.0, 4.0, 6.0, 5.0, 7.0, 4.0, // Class 1
                9.0, 8.0, 10.0, 9.0, 11.0, 10.0, // Class 2
            ],
        )
        .expect("18 values fill a 9x2 matrix");

        let y = Array::from_vec(vec![0, 0, 0, 1, 1, 1, 2, 2, 2]);

        // Eigen solver: 2 components is the maximum for 3 classes
        let mut lda_eigen = LDA::new(2, "eigen").expect("'eigen' is a valid LDA solver");
        let x_transformed_eigen = lda_eigen
            .fit_transform(&x, &y)
            .expect("LDA (eigen) fit_transform failed");

        // SVD solver on the same data for comparison
        let mut lda_svd = LDA::new(2, "svd").expect("'svd' is a valid LDA solver");
        let x_transformed_svd = lda_svd
            .fit_transform(&x, &y)
            .expect("LDA (svd) fit_transform failed");

        // Check that both transformations have correct shape
        assert_eq!(x_transformed_eigen.shape(), &[9, 2]);
        assert_eq!(x_transformed_svd.shape(), &[9, 2]);

        // Check that both produce finite projections
        assert!(x_transformed_eigen.iter().all(|&v| v.is_finite()));
        assert!(x_transformed_svd.iter().all(|&v| v.is_finite()));

        // Check that explained variance ratios are valid for both solvers
        let explained_variance_eigen = lda_eigen
            .explained_variance_ratio()
            .expect("LDA (eigen) explained variance ratio missing after fit");
        let explained_variance_svd = lda_svd
            .explained_variance_ratio()
            .expect("LDA (svd) explained variance ratio missing after fit");

        assert_eq!(explained_variance_eigen.len(), 2);
        assert_eq!(explained_variance_svd.len(), 2);

        // Both should sum to approximately 1.0
        assert_abs_diff_eq!(explained_variance_eigen.sum(), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(explained_variance_svd.sum(), 1.0, epsilon = 1e-10);

        // Explained variance ratios should be non-negative
        assert!(explained_variance_eigen.iter().all(|&r| r >= 0.0));
        assert!(explained_variance_svd.iter().all(|&r| r >= 0.0));
    }

    #[test]
    fn test_lda_invalid_solver() {
        let err = LDA::new(1, "invalid").expect_err("unknown solver names must be rejected");
        assert!(err
            .to_string()
            .contains("solver must be 'svd' or 'eigen'"));
    }
}