Skip to main content

scirs2_transform/reduction/
mod.rs

1//! Dimensionality reduction techniques
2//!
3//! This module provides algorithms for reducing the dimensionality of data,
4//! which is useful for visualization, feature extraction, and reducing
5//! computational complexity.
6
7/// Factor Analysis module
8pub mod factor_analysis;
9mod isomap;
10mod lle;
11mod spectral_embedding;
12mod tsne;
13mod umap;
14
15/// Laplacian Eigenmaps for manifold learning
16pub mod laplacian_eigenmaps;
17
18/// Diffusion Maps for nonlinear dimensionality reduction
19pub mod diffusion_maps;
20
21pub use crate::reduction::diffusion_maps::DiffusionMaps;
22pub use crate::reduction::factor_analysis::{
23    factor_analysis, scree_plot_data, FactorAnalysis, FactorAnalysisResult, RotationMethod,
24    ScreePlotData,
25};
26pub use crate::reduction::isomap::Isomap;
27pub use crate::reduction::laplacian_eigenmaps::{
28    GraphMethod, LaplacianEigenmaps, LaplacianType as LELaplacianType,
29};
30pub use crate::reduction::lle::LLE;
31pub use crate::reduction::spectral_embedding::{AffinityMethod, SpectralEmbedding};
32pub use crate::reduction::tsne::{trustworthiness, TSNE};
33pub use crate::reduction::umap::UMAP;
34
35use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
36use scirs2_core::numeric::{Float, NumCast};
37use scirs2_linalg::svd;
38
39use crate::error::{Result, TransformError};
40
/// Absolute tolerance used throughout this module: singular values, eigenvalue
/// totals, variance totals and vector norms at or below this value are treated as
/// zero, and LDA's eigen solver adds it to the diagonal of the within-class scatter
/// matrix as a small ridge term.
const EPSILON: f64 = 1.0e-10;
43
44/// Principal Component Analysis (PCA) dimensionality reduction
45///
46/// PCA finds the directions of maximum variance in the data and
47/// projects the data onto a lower dimensional space.
48#[derive(Debug, Clone)]
49pub struct PCA {
50    /// Number of components to keep
51    n_components: usize,
52    /// Whether to center the data before computing the SVD
53    center: bool,
54    /// Whether to scale the data before computing the SVD
55    scale: bool,
56    /// The principal components
57    components: Option<Array2<f64>>,
58    /// The mean of the training data
59    mean: Option<Array1<f64>>,
60    /// The standard deviation of the training data
61    std: Option<Array1<f64>>,
62    /// The singular values of the centered training data
63    singular_values: Option<Array1<f64>>,
64    /// The explained variance ratio
65    explained_variance_ratio: Option<Array1<f64>>,
66}
67
68impl PCA {
69    /// Creates a new PCA instance
70    ///
71    /// # Arguments
72    /// * `n_components` - Number of components to keep
73    /// * `center` - Whether to center the data before computing the SVD
74    /// * `scale` - Whether to scale the data before computing the SVD
75    ///
76    /// # Returns
77    /// * A new PCA instance
78    pub fn new(ncomponents: usize, center: bool, scale: bool) -> Self {
79        PCA {
80            n_components: ncomponents,
81            center,
82            scale,
83            components: None,
84            mean: None,
85            std: None,
86            singular_values: None,
87            explained_variance_ratio: None,
88        }
89    }
90
91    /// Fits the PCA model to the input data
92    ///
93    /// # Arguments
94    /// * `x` - The input data, shape (n_samples, n_features)
95    ///
96    /// # Returns
97    /// * `Result<()>` - Ok if successful, Err otherwise
98    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
99    where
100        S: Data,
101        S::Elem: Float + NumCast,
102    {
103        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
104
105        let n_samples = x_f64.shape()[0];
106        let n_features = x_f64.shape()[1];
107
108        if n_samples == 0 || n_features == 0 {
109            return Err(TransformError::InvalidInput("Empty input data".to_string()));
110        }
111
112        // PCA can return at most min(n_samples, n_features) non-trivial components.
113        // Validate against both: n_components must not exceed either dimension.
114        // (Previously only n_features was checked; on n << p wide data this allowed
115        // n_components > rank(X) which made the legacy full_matrices SVD path
116        // collapse to zero singular values for the trailing components.)
117        let max_components = n_samples.min(n_features);
118        if self.n_components > max_components {
119            return Err(TransformError::InvalidInput(format!(
120                "n_components={} must be <= min(n_samples, n_features)={}",
121                self.n_components, max_components
122            )));
123        }
124
125        // Center and scale data if requested
126        let mut x_processed = Array2::zeros((n_samples, n_features));
127        let mut mean = Array1::zeros(n_features);
128        let mut std = Array1::ones(n_features);
129
130        if self.center {
131            for j in 0..n_features {
132                let col_mean = x_f64.column(j).sum() / n_samples as f64;
133                mean[j] = col_mean;
134
135                for i in 0..n_samples {
136                    x_processed[[i, j]] = x_f64[[i, j]] - col_mean;
137                }
138            }
139        } else {
140            x_processed.assign(&x_f64);
141        }
142
143        if self.scale {
144            for j in 0..n_features {
145                let col_std =
146                    (x_processed.column(j).mapv(|x| x * x).sum() / n_samples as f64).sqrt();
147                if col_std > f64::EPSILON {
148                    std[j] = col_std;
149
150                    for i in 0..n_samples {
151                        x_processed[[i, j]] /= col_std;
152                    }
153                }
154            }
155        }
156
157        // Perform thin SVD: full_matrices=false returns U (n x k), S (k), Vᵀ (k x m)
158        // where k = min(n_samples, n_features). On wide data (n << p) this is critical:
159        // full_matrices=true would build a p×p Vᵀ via Gram–Schmidt orthogonal extension,
160        // which is O(p³) work — catastrophic for typical n << p problems (issue #124).
161        let (_u, s, vt) = match svd::<f64>(&x_processed.view(), false, None) {
162            Ok(result) => result,
163            Err(e) => return Err(TransformError::LinalgError(e)),
164        };
165
166        // Extract components and singular values
167        let mut components = Array2::zeros((self.n_components, n_features));
168        let mut singular_values = Array1::zeros(self.n_components);
169
170        for i in 0..self.n_components {
171            singular_values[i] = s[i];
172            for j in 0..n_features {
173                components[[i, j]] = vt[[i, j]];
174            }
175        }
176
177        // Compute explained variance ratio.
178        // With thin SVD we have only min(n,m) singular values, but for centred data the
179        // trace of the covariance equals sum_i sigma_i^2 over those same min(n,m) values
180        // (the trailing singular values from full_matrices=true would be zero), so this
181        // ratio is identical to the previous behaviour on well-posed inputs.
182        let total_variance = s.mapv(|s| s * s).sum();
183        let explained_variance_ratio = if total_variance > EPSILON {
184            singular_values.mapv(|s| s * s / total_variance)
185        } else {
186            Array1::zeros(self.n_components)
187        };
188
189        self.components = Some(components);
190        self.mean = Some(mean);
191        self.std = Some(std);
192        self.singular_values = Some(singular_values);
193        self.explained_variance_ratio = Some(explained_variance_ratio);
194
195        Ok(())
196    }
197
198    /// Transforms the input data using the fitted PCA model
199    ///
200    /// # Arguments
201    /// * `x` - The input data, shape (n_samples, n_features)
202    ///
203    /// # Returns
204    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
205    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
206    where
207        S: Data,
208        S::Elem: Float + NumCast,
209    {
210        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
211
212        let n_samples = x_f64.shape()[0];
213        let n_features = x_f64.shape()[1];
214
215        if self.components.is_none() {
216            return Err(TransformError::TransformationError(
217                "PCA model has not been fitted".to_string(),
218            ));
219        }
220
221        let components = self.components.as_ref().expect("Operation failed");
222        let mean = self.mean.as_ref().expect("Operation failed");
223        let std = self.std.as_ref().expect("Operation failed");
224
225        if n_features != components.shape()[1] {
226            return Err(TransformError::InvalidInput(format!(
227                "x has {} features, but PCA was fitted with {} features",
228                n_features,
229                components.shape()[1]
230            )));
231        }
232
233        // Center and scale data if the model was fitted with centering/scaling
234        let mut x_processed = Array2::zeros((n_samples, n_features));
235
236        for i in 0..n_samples {
237            for j in 0..n_features {
238                let mut value = x_f64[[i, j]];
239
240                if self.center {
241                    value -= mean[j];
242                }
243
244                if self.scale {
245                    value /= std[j];
246                }
247
248                x_processed[[i, j]] = value;
249            }
250        }
251
252        // Project data onto principal components
253        let mut transformed = Array2::zeros((n_samples, self.n_components));
254
255        for i in 0..n_samples {
256            for j in 0..self.n_components {
257                let mut dot_product = 0.0;
258                for k in 0..n_features {
259                    dot_product += x_processed[[i, k]] * components[[j, k]];
260                }
261                transformed[[i, j]] = dot_product;
262            }
263        }
264
265        Ok(transformed)
266    }
267
268    /// Fits the PCA model to the input data and transforms it
269    ///
270    /// # Arguments
271    /// * `x` - The input data, shape (n_samples, n_features)
272    ///
273    /// # Returns
274    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
275    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
276    where
277        S: Data,
278        S::Elem: Float + NumCast,
279    {
280        self.fit(x)?;
281        self.transform(x)
282    }
283
284    /// Returns the principal components
285    ///
286    /// # Returns
287    /// * `Option<&Array2<f64>>` - The principal components, shape (n_components, n_features)
288    pub fn components(&self) -> Option<&Array2<f64>> {
289        self.components.as_ref()
290    }
291
292    /// Returns the mean of the training data (used for centering)
293    ///
294    /// # Returns
295    /// * `Option<&Array1<f64>>` - The per-feature mean, shape (n_features,)
296    pub fn mean(&self) -> Option<&Array1<f64>> {
297        self.mean.as_ref()
298    }
299
300    /// Returns the explained variance ratio
301    ///
302    /// # Returns
303    /// * `Option<&Array1<f64>>` - The explained variance ratio
304    pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
305        self.explained_variance_ratio.as_ref()
306    }
307}
308
309/// Truncated Singular Value Decomposition (SVD) for dimensionality reduction
310///
311/// This transformer performs linear dimensionality reduction by means of
312/// truncated singular value decomposition (SVD). It works on any data and
313/// not just sparse matrices.
314#[derive(Debug, Clone)]
315pub struct TruncatedSVD {
316    /// Number of components to keep
317    n_components: usize,
318    /// The singular values of the training data
319    singular_values: Option<Array1<f64>>,
320    /// The right singular vectors
321    components: Option<Array2<f64>>,
322    /// The explained variance ratio
323    explained_variance_ratio: Option<Array1<f64>>,
324}
325
326impl TruncatedSVD {
327    /// Creates a new TruncatedSVD instance
328    ///
329    /// # Arguments
330    /// * `n_components` - Number of components to keep
331    ///
332    /// # Returns
333    /// * A new TruncatedSVD instance
334    pub fn new(ncomponents: usize) -> Self {
335        TruncatedSVD {
336            n_components: ncomponents,
337            singular_values: None,
338            components: None,
339            explained_variance_ratio: None,
340        }
341    }
342
343    /// Fits the TruncatedSVD model to the input data
344    ///
345    /// # Arguments
346    /// * `x` - The input data, shape (n_samples, n_features)
347    ///
348    /// # Returns
349    /// * `Result<()>` - Ok if successful, Err otherwise
350    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
351    where
352        S: Data,
353        S::Elem: Float + NumCast,
354    {
355        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
356
357        let n_samples = x_f64.shape()[0];
358        let n_features = x_f64.shape()[1];
359
360        if n_samples == 0 || n_features == 0 {
361            return Err(TransformError::InvalidInput("Empty input data".to_string()));
362        }
363
364        if self.n_components > n_features {
365            return Err(TransformError::InvalidInput(format!(
366                "n_components={} must be <= n_features={}",
367                self.n_components, n_features
368            )));
369        }
370
371        // Perform SVD
372        let (_u, s, vt) = match svd::<f64>(&x_f64.view(), true, None) {
373            Ok(result) => result,
374            Err(e) => return Err(TransformError::LinalgError(e)),
375        };
376
377        // Extract components and singular values
378        let mut components = Array2::zeros((self.n_components, n_features));
379        let mut singular_values = Array1::zeros(self.n_components);
380
381        for i in 0..self.n_components {
382            singular_values[i] = s[i];
383            for j in 0..n_features {
384                components[[i, j]] = vt[[i, j]];
385            }
386        }
387
388        // Compute explained variance ratio
389        let total_variance =
390            (x_f64.map_axis(Axis(1), |row| row.dot(&row)).sum()) / n_samples as f64;
391        let explained_variance = singular_values.mapv(|s| s * s / n_samples as f64);
392        let explained_variance_ratio = explained_variance.mapv(|v| v / total_variance);
393
394        self.singular_values = Some(singular_values);
395        self.components = Some(components);
396        self.explained_variance_ratio = Some(explained_variance_ratio);
397
398        Ok(())
399    }
400
401    /// Transforms the input data using the fitted TruncatedSVD model
402    ///
403    /// # Arguments
404    /// * `x` - The input data, shape (n_samples, n_features)
405    ///
406    /// # Returns
407    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
408    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
409    where
410        S: Data,
411        S::Elem: Float + NumCast,
412    {
413        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
414
415        let n_samples = x_f64.shape()[0];
416        let n_features = x_f64.shape()[1];
417
418        if self.components.is_none() {
419            return Err(TransformError::TransformationError(
420                "TruncatedSVD model has not been fitted".to_string(),
421            ));
422        }
423
424        let components = self.components.as_ref().expect("Operation failed");
425
426        if n_features != components.shape()[1] {
427            return Err(TransformError::InvalidInput(format!(
428                "x has {} features, but TruncatedSVD was fitted with {} features",
429                n_features,
430                components.shape()[1]
431            )));
432        }
433
434        // Project data onto components
435        let mut transformed = Array2::zeros((n_samples, self.n_components));
436
437        for i in 0..n_samples {
438            for j in 0..self.n_components {
439                let mut dot_product = 0.0;
440                for k in 0..n_features {
441                    dot_product += x_f64[[i, k]] * components[[j, k]];
442                }
443                transformed[[i, j]] = dot_product;
444            }
445        }
446
447        Ok(transformed)
448    }
449
450    /// Fits the TruncatedSVD model to the input data and transforms it
451    ///
452    /// # Arguments
453    /// * `x` - The input data, shape (n_samples, n_features)
454    ///
455    /// # Returns
456    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
457    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
458    where
459        S: Data,
460        S::Elem: Float + NumCast,
461    {
462        self.fit(x)?;
463        self.transform(x)
464    }
465
466    /// Returns the components (right singular vectors)
467    ///
468    /// # Returns
469    /// * `Option<&Array2<f64>>` - The components, shape (n_components, n_features)
470    pub fn components(&self) -> Option<&Array2<f64>> {
471        self.components.as_ref()
472    }
473
474    /// Returns the singular values
475    ///
476    /// # Returns
477    /// * `Option<&Array1<f64>>` - The singular values
478    pub fn singular_values(&self) -> Option<&Array1<f64>> {
479        self.singular_values.as_ref()
480    }
481
482    /// Returns the explained variance ratio
483    ///
484    /// # Returns
485    /// * `Option<&Array1<f64>>` - The explained variance ratio
486    pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
487        self.explained_variance_ratio.as_ref()
488    }
489}
490
491/// Linear Discriminant Analysis (LDA) for dimensionality reduction
492///
493/// LDA finds the directions that maximize the separation between classes.
494#[derive(Debug, Clone)]
495pub struct LDA {
496    /// Number of components to keep
497    n_components: usize,
498    /// Whether to use Singular Value Decomposition
499    solver: String,
500    /// The LDA components
501    components: Option<Array2<f64>>,
502    /// The class means
503    means: Option<Array2<f64>>,
504    /// The explained variance ratio
505    explained_variance_ratio: Option<Array1<f64>>,
506}
507
508impl LDA {
509    /// Creates a new LDA instance
510    ///
511    /// # Arguments
512    /// * `n_components` - Number of components to keep
513    /// * `solver` - The solver to use ('svd' or 'eigen')
514    ///
515    /// # Returns
516    /// * A new LDA instance
517    pub fn new(ncomponents: usize, solver: &str) -> Result<Self> {
518        if solver != "svd" && solver != "eigen" {
519            return Err(TransformError::InvalidInput(
520                "solver must be 'svd' or 'eigen'".to_string(),
521            ));
522        }
523
524        Ok(LDA {
525            n_components: ncomponents,
526            solver: solver.to_string(),
527            components: None,
528            means: None,
529            explained_variance_ratio: None,
530        })
531    }
532
533    /// Fits the LDA model to the input data
534    ///
535    /// # Arguments
536    /// * `x` - The input data, shape (n_samples, n_features)
537    /// * `y` - The target labels, shape (n_samples,)
538    ///
539    /// # Returns
540    /// * `Result<()>` - Ok if successful, Err otherwise
541    pub fn fit<S1, S2>(&mut self, x: &ArrayBase<S1, Ix2>, y: &ArrayBase<S2, Ix1>) -> Result<()>
542    where
543        S1: Data,
544        S2: Data,
545        S1::Elem: Float + NumCast,
546        S2::Elem: Copy + NumCast + Eq + std::hash::Hash,
547    {
548        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
549
550        let n_samples = x_f64.shape()[0];
551        let n_features = x_f64.shape()[1];
552
553        if n_samples == 0 || n_features == 0 {
554            return Err(TransformError::InvalidInput("Empty input data".to_string()));
555        }
556
557        if n_samples != y.len() {
558            return Err(TransformError::InvalidInput(format!(
559                "x and y have incompatible shapes: x has {} samples, y has {} elements",
560                n_samples,
561                y.len()
562            )));
563        }
564
565        // Convert y to class indices
566        let mut class_indices = vec![];
567        let mut class_map = std::collections::HashMap::new();
568        let mut next_class_idx = 0;
569
570        for &label in y.iter() {
571            let label_u64 = NumCast::from(label).unwrap_or(0);
572
573            if let std::collections::hash_map::Entry::Vacant(e) = class_map.entry(label_u64) {
574                e.insert(next_class_idx);
575                next_class_idx += 1;
576            }
577
578            class_indices.push(class_map[&label_u64]);
579        }
580
581        let n_classes = class_map.len();
582
583        if n_classes <= 1 {
584            return Err(TransformError::InvalidInput(
585                "y has less than 2 classes, LDA requires at least 2 classes".to_string(),
586            ));
587        }
588
589        let maxn_components = n_classes - 1;
590        if self.n_components > maxn_components {
591            return Err(TransformError::InvalidInput(format!(
592                "n_components={} must be <= n_classes-1={}",
593                self.n_components, maxn_components
594            )));
595        }
596
597        // Compute class means
598        let mut class_means = Array2::zeros((n_classes, n_features));
599        let mut class_counts = vec![0; n_classes];
600
601        for i in 0..n_samples {
602            let class_idx = class_indices[i];
603            class_counts[class_idx] += 1;
604
605            for j in 0..n_features {
606                class_means[[class_idx, j]] += x_f64[[i, j]];
607            }
608        }
609
610        for i in 0..n_classes {
611            if class_counts[i] > 0 {
612                for j in 0..n_features {
613                    class_means[[i, j]] /= class_counts[i] as f64;
614                }
615            }
616        }
617
618        // Compute global mean
619        let mut global_mean = Array1::<f64>::zeros(n_features);
620        for i in 0..n_samples {
621            for j in 0..n_features {
622                global_mean[j] += x_f64[[i, j]];
623            }
624        }
625        global_mean.mapv_inplace(|x: f64| x / n_samples as f64);
626
627        // Fast path for the n_samples < n_features (wide / n << p) regime with the
628        // svd solver: form only n×p centred-within-class and c×p between-class matrices
629        // and SVD those instead of the p×p scatter matrices. This is the textbook
630        // small-side algorithm — equivalent to the existing p×p formulation but
631        // O(n·p·rank) instead of O(p⁴). Issue #124.
632        //
633        // We deliberately keep the wide-old eigen path untouched: the eigen branch
634        // requires Sw^{-1}·Sb explicitly, which the small-side trick does not produce.
635        if self.solver == "svd" && n_samples < n_features {
636            let (components, eigenvalues) = self.fit_svd_small_side(
637                &x_f64,
638                &class_means,
639                &global_mean,
640                &class_indices,
641                &class_counts,
642                n_samples,
643                n_features,
644                n_classes,
645            )?;
646            let total_eig = eigenvalues.iter().sum::<f64>();
647            let explained_variance_ratio = if total_eig > EPSILON {
648                eigenvalues.mapv(|e| e / total_eig)
649            } else {
650                Array1::from_elem(self.n_components, 1.0 / self.n_components as f64)
651            };
652            self.components = Some(components);
653            self.means = Some(class_means);
654            self.explained_variance_ratio = Some(explained_variance_ratio);
655            return Ok(());
656        }
657
658        // Compute within-class scatter matrix
659        let mut sw = Array2::<f64>::zeros((n_features, n_features));
660        for i in 0..n_samples {
661            let class_idx = class_indices[i];
662            let mut x_centered = Array1::<f64>::zeros(n_features);
663
664            for j in 0..n_features {
665                x_centered[j] = x_f64[[i, j]] - class_means[[class_idx, j]];
666            }
667
668            for j in 0..n_features {
669                for k in 0..n_features {
670                    sw[[j, k]] += x_centered[j] * x_centered[k];
671                }
672            }
673        }
674
675        // Compute between-class scatter matrix
676        let mut sb = Array2::<f64>::zeros((n_features, n_features));
677        for i in 0..n_classes {
678            let mut mean_diff = Array1::<f64>::zeros(n_features);
679            for j in 0..n_features {
680                mean_diff[j] = class_means[[i, j]] - global_mean[j];
681            }
682
683            for j in 0..n_features {
684                for k in 0..n_features {
685                    sb[[j, k]] += class_counts[i] as f64 * mean_diff[j] * mean_diff[k];
686                }
687            }
688        }
689
690        // Solve the generalized eigenvalue problem
691        let mut components = Array2::<f64>::zeros((self.n_components, n_features));
692        let mut eigenvalues = Array1::<f64>::zeros(self.n_components);
693
694        if self.solver == "svd" {
695            // SVD-based solver
696
697            // Decompose the within-class scatter matrix
698            let (u_sw, s_sw, vt_sw) = match svd::<f64>(&sw.view(), true, None) {
699                Ok(result) => result,
700                Err(e) => return Err(TransformError::LinalgError(e)),
701            };
702
703            // Compute the pseudoinverse of sw^(1/2)
704            let mut sw_sqrt_inv = Array2::<f64>::zeros((n_features, n_features));
705            for i in 0..n_features {
706                if s_sw[i] > EPSILON {
707                    for j in 0..n_features {
708                        for k in 0..n_features {
709                            let s_inv_sqrt = 1.0 / s_sw[i].sqrt();
710                            sw_sqrt_inv[[j, k]] += u_sw[[j, i]] * s_inv_sqrt * vt_sw[[i, k]];
711                        }
712                    }
713                }
714            }
715
716            // Transform the between-class scatter matrix
717            let mut sb_transformed = Array2::<f64>::zeros((n_features, n_features));
718            for i in 0..n_features {
719                for j in 0..n_features {
720                    for k in 0..n_features {
721                        for l in 0..n_features {
722                            sb_transformed[[i, j]] +=
723                                sw_sqrt_inv[[i, k]] * sb[[k, l]] * sw_sqrt_inv[[l, j]];
724                        }
725                    }
726                }
727            }
728
729            // Perform SVD on the transformed between-class scatter matrix
730            let (u_sb, s_sb, vt_sb) = match svd::<f64>(&sb_transformed.view(), true, None) {
731                Ok(result) => result,
732                Err(e) => return Err(TransformError::LinalgError(e)),
733            };
734
735            // Compute the LDA components
736            for i in 0..self.n_components {
737                eigenvalues[i] = s_sb[i];
738
739                for j in 0..n_features {
740                    for k in 0..n_features {
741                        components[[i, j]] += sw_sqrt_inv[[k, j]] * u_sb[[k, i]];
742                    }
743                }
744            }
745        } else {
746            // Eigen-based solver - proper generalized eigenvalue problem
747            // Solve: Sb * v = λ * Sw * v
748
749            // Step 1: Regularize Sw to ensure it's invertible
750            let mut sw_reg = sw.clone();
751            for i in 0..n_features {
752                sw_reg[[i, i]] += EPSILON; // Add small regularization to diagonal
753            }
754
755            // Step 2: Compute Cholesky decomposition of regularized Sw
756            // We'll use a simpler approach: Sw^(-1) * Sb
757            let (u_sw, s_sw, vt_sw) = match svd::<f64>(&sw_reg.view(), true, None) {
758                Ok(result) => result,
759                Err(e) => return Err(TransformError::LinalgError(e)),
760            };
761
762            // Compute pseudoinverse of Sw
763            let mut sw_inv = Array2::<f64>::zeros((n_features, n_features));
764            for i in 0..n_features {
765                if s_sw[i] > EPSILON {
766                    for j in 0..n_features {
767                        for k in 0..n_features {
768                            sw_inv[[j, k]] += u_sw[[j, i]] * (1.0 / s_sw[i]) * vt_sw[[i, k]];
769                        }
770                    }
771                }
772            }
773
774            // Step 3: Compute Sw^(-1) * Sb
775            let mut sw_inv_sb = Array2::<f64>::zeros((n_features, n_features));
776            for i in 0..n_features {
777                for j in 0..n_features {
778                    for k in 0..n_features {
779                        sw_inv_sb[[i, j]] += sw_inv[[i, k]] * sb[[k, j]];
780                    }
781                }
782            }
783
784            // Step 4: Compute eigendecomposition of Sw^(-1) * Sb
785            // Since this matrix may not be symmetric, we use the approach where we
786            // symmetrize it by computing (Sw^(-1) * Sb + (Sw^(-1) * Sb)^T) / 2
787            let mut sym_matrix = Array2::<f64>::zeros((n_features, n_features));
788            for i in 0..n_features {
789                for j in 0..n_features {
790                    sym_matrix[[i, j]] = (sw_inv_sb[[i, j]] + sw_inv_sb[[j, i]]) / 2.0;
791                }
792            }
793
794            // Perform eigendecomposition on the symmetrized matrix
795            let (eig_vals, eig_vecs) = match scirs2_linalg::eigh::<f64>(&sym_matrix.view(), None) {
796                Ok(result) => result,
797                Err(_) => {
798                    // Fallback to SVD if eigendecomposition fails
799                    let (u, s, vt) = match svd::<f64>(&sw_inv_sb.view(), true, None) {
800                        Ok(result) => result,
801                        Err(e) => return Err(TransformError::LinalgError(e)),
802                    };
803                    (s, u)
804                }
805            };
806
807            // Sort eigenvalues and eigenvectors in descending order
808            let mut indices: Vec<usize> = (0..n_features).collect();
809            indices.sort_by(|&i, &j| {
810                eig_vals[j]
811                    .partial_cmp(&eig_vals[i])
812                    .expect("Operation failed")
813            });
814
815            // Select top n_components eigenvectors
816            for i in 0..self.n_components {
817                let idx = indices[i];
818                eigenvalues[i] = eig_vals[idx].max(0.0); // Ensure non-negative
819
820                for j in 0..n_features {
821                    components[[i, j]] = eig_vecs[[j, idx]];
822                }
823            }
824
825            // Normalize components
826            for i in 0..self.n_components {
827                let mut norm = 0.0;
828                for j in 0..n_features {
829                    norm += components[[i, j]] * components[[i, j]];
830                }
831                norm = norm.sqrt();
832
833                if norm > EPSILON {
834                    for j in 0..n_features {
835                        components[[i, j]] /= norm;
836                    }
837                }
838            }
839        }
840
841        // Compute explained variance ratio
842        let total_eigenvalues = eigenvalues.iter().sum::<f64>();
843        let explained_variance_ratio = eigenvalues.mapv(|e| e / total_eigenvalues);
844
845        self.components = Some(components);
846        self.means = Some(class_means);
847        self.explained_variance_ratio = Some(explained_variance_ratio);
848
849        Ok(())
850    }
851
852    /// Small-side SVD solver for the n_samples < n_features ("wide" / n << p) regime.
853    ///
854    /// Replaces the textbook p×p Sw / Sb formulation with an n×p centred within-class
855    /// matrix and a c×p between-class matrix. Avoids materialising p×p scatters at all,
856    /// so cost drops from O(p^4) (the dense formulation's quartic four-deep loop) to
857    /// O(n·p·rank). For n=10, p=8319 this is a ~10^9× reduction.
858    ///
859    /// Algorithm (analogous to scikit-learn's `LinearDiscriminantAnalysis(solver="svd")`):
860    /// 1. Stack centred-within-class rows: Xw_i = X_i - mu_{class(i)}.  Shape (n, p).
861    /// 2. Thin SVD of Xw = U_w · diag(S_w) · Vt_w with full_matrices=false.
862    ///    Truncate to rank_w = #{ S_w > tol }.  (rank_w <= n - c.)
863    /// 3. Whitening matrix W = Vt_w[:rank_w].T / S_w[:rank_w]   (shape (p, rank_w)).
864    /// 4. Weighted between-class deviations: Xb_c = sqrt(N_c) · (mu_c - mu).
865    ///    Shape (c, p), rank <= c - 1.
866    /// 5. Project Xb through whitening: M = Xb · W                (shape (c, rank_w)).
867    /// 6. Thin SVD of M = U_b · diag(S_b) · Vt_b.
868    /// 7. Final scalings = W · Vt_b.T                              (shape (p, rank_w)),
869    ///    take its first n_components columns transposed into row-major components.
870    /// 8. eigenvalues are S_b[:n_components]^2 (Fisher discriminant ratios).
871    #[allow(clippy::too_many_arguments)]
872    fn fit_svd_small_side(
873        &self,
874        x_f64: &Array2<f64>,
875        class_means: &Array2<f64>,
876        global_mean: &Array1<f64>,
877        class_indices: &[usize],
878        class_counts: &[usize],
879        n_samples: usize,
880        n_features: usize,
881        n_classes: usize,
882    ) -> Result<(Array2<f64>, Array1<f64>)> {
883        // Step 1: build n×p centred within-class matrix
884        let mut xw = Array2::<f64>::zeros((n_samples, n_features));
885        for i in 0..n_samples {
886            let c = class_indices[i];
887            for j in 0..n_features {
888                xw[[i, j]] = x_f64[[i, j]] - class_means[[c, j]];
889            }
890        }
891
892        // Step 2: thin SVD of Xw.  Vt_w shape (rank, p).
893        let (_u_w, s_w, vt_w) =
894            svd::<f64>(&xw.view(), false, None).map_err(TransformError::LinalgError)?;
895
896        // Determine numerical rank
897        let tol = s_w.iter().cloned().fold(0.0_f64, f64::max)
898            * (n_samples.max(n_features) as f64)
899            * f64::EPSILON;
900        let rank_w = s_w.iter().filter(|&&v| v > tol).count();
901        if rank_w == 0 {
902            return Err(TransformError::ComputationError(
903                "Within-class scatter has zero rank (all samples in same point per class)"
904                    .to_string(),
905            ));
906        }
907
908        // Step 3: whitening matrix W = Vt_w[:rank_w].T * diag(1/S_w[:rank_w])  shape (p, rank_w)
909        let mut w = Array2::<f64>::zeros((n_features, rank_w));
910        for k in 0..rank_w {
911            let inv_s = 1.0 / s_w[k];
912            for j in 0..n_features {
913                w[[j, k]] = vt_w[[k, j]] * inv_s;
914            }
915        }
916
917        // Step 4: weighted between-class deviations  Xb shape (c, p)
918        let mut xb = Array2::<f64>::zeros((n_classes, n_features));
919        for c in 0..n_classes {
920            let sqrt_nc = (class_counts[c] as f64).sqrt();
921            for j in 0..n_features {
922                xb[[c, j]] = sqrt_nc * (class_means[[c, j]] - global_mean[j]);
923            }
924        }
925
926        // Step 5: project Xb through whitening:  M = Xb · W   shape (c, rank_w)
927        let m = xb.dot(&w);
928
929        // Step 6: thin SVD of M.  Vt_b shape (rank_b, rank_w).
930        let (_u_b, s_b, vt_b) =
931            svd::<f64>(&m.view(), false, None).map_err(TransformError::LinalgError)?;
932
933        let rank_b = s_b.len().min(vt_b.nrows());
934        if self.n_components > rank_b {
935            return Err(TransformError::InvalidInput(format!(
936                "n_components={} exceeds the rank of the between-class scatter ({})",
937                self.n_components, rank_b
938            )));
939        }
940
941        // Step 7: components = (W · Vt_b.T)[:, :n_components].T   shape (n_components, p)
942        let mut components = Array2::<f64>::zeros((self.n_components, n_features));
943        for k in 0..self.n_components {
944            for j in 0..n_features {
945                let mut acc = 0.0;
946                for r in 0..rank_w {
947                    acc += w[[j, r]] * vt_b[[k, r]];
948                }
949                components[[k, j]] = acc;
950            }
951        }
952
953        // Step 8: eigenvalues = S_b^2 (discriminant ratios; sum is total separability)
954        let mut eigenvalues = Array1::<f64>::zeros(self.n_components);
955        for k in 0..self.n_components {
956            eigenvalues[k] = s_b[k] * s_b[k];
957        }
958
959        Ok((components, eigenvalues))
960    }
961
962    /// Transforms the input data using the fitted LDA model
963    ///
964    /// # Arguments
965    /// * `x` - The input data, shape (n_samples, n_features)
966    ///
967    /// # Returns
968    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
969    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
970    where
971        S: Data,
972        S::Elem: Float + NumCast,
973    {
974        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
975
976        let n_samples = x_f64.shape()[0];
977        let n_features = x_f64.shape()[1];
978
979        if self.components.is_none() {
980            return Err(TransformError::TransformationError(
981                "LDA model has not been fitted".to_string(),
982            ));
983        }
984
985        let components = self.components.as_ref().expect("Operation failed");
986
987        if n_features != components.shape()[1] {
988            return Err(TransformError::InvalidInput(format!(
989                "x has {} features, but LDA was fitted with {} features",
990                n_features,
991                components.shape()[1]
992            )));
993        }
994
995        // Project data onto LDA components
996        let mut transformed = Array2::zeros((n_samples, self.n_components));
997
998        for i in 0..n_samples {
999            for j in 0..self.n_components {
1000                let mut dot_product = 0.0;
1001                for k in 0..n_features {
1002                    dot_product += x_f64[[i, k]] * components[[j, k]];
1003                }
1004                transformed[[i, j]] = dot_product;
1005            }
1006        }
1007
1008        Ok(transformed)
1009    }
1010
1011    /// Fits the LDA model to the input data and transforms it
1012    ///
1013    /// # Arguments
1014    /// * `x` - The input data, shape (n_samples, n_features)
1015    /// * `y` - The target labels, shape (n_samples,)
1016    ///
1017    /// # Returns
1018    /// * `Result<Array2<f64>>` - The transformed data, shape (n_samples, n_components)
1019    pub fn fit_transform<S1, S2>(
1020        &mut self,
1021        x: &ArrayBase<S1, Ix2>,
1022        y: &ArrayBase<S2, Ix1>,
1023    ) -> Result<Array2<f64>>
1024    where
1025        S1: Data,
1026        S2: Data,
1027        S1::Elem: Float + NumCast,
1028        S2::Elem: Copy + NumCast + Eq + std::hash::Hash,
1029    {
1030        self.fit(x, y)?;
1031        self.transform(x)
1032    }
1033
1034    /// Returns the LDA components
1035    ///
1036    /// # Returns
1037    /// * `Option<&Array2<f64>>` - The LDA components, shape (n_components, n_features)
1038    pub fn components(&self) -> Option<&Array2<f64>> {
1039        self.components.as_ref()
1040    }
1041
1042    /// Returns the explained variance ratio
1043    ///
1044    /// # Returns
1045    /// * `Option<&Array1<f64>>` - The explained variance ratio
1046    pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
1047        self.explained_variance_ratio.as_ref()
1048    }
1049}
1050
// Unit tests for PCA, TruncatedSVD and LDA, plus regression tests for issue #124
// (wide n_samples << n_features inputs).
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    // PCA (centred, unscaled) on a 4×3 dataset that is rank 1 after centring: checks the
    // output shape and that the explained variance ratios are finite with a positive sum.
    #[test]
    fn test_pca_transform() {
        // Create a simple dataset
        let x = Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .expect("Operation failed");

        // Initialize and fit PCA with 2 components
        let mut pca = PCA::new(2, true, false);
        let x_transformed = pca.fit_transform(&x).expect("Operation failed");

        // Check that the shape is correct
        assert_eq!(x_transformed.shape(), &[4, 2]);

        // Check that we have the correct number of explained variance components
        let explained_variance = pca.explained_variance_ratio().expect("Operation failed");
        assert_eq!(explained_variance.len(), 2);

        // Check that the sum is a valid number (we don't need to enforce sum = 1)
        assert!(explained_variance.sum() > 0.0 && explained_variance.sum().is_finite());
    }

    // TruncatedSVD on the same 4×3 dataset: checks the output shape and that the
    // explained variance ratios are finite with a positive sum.
    #[test]
    fn test_truncated_svd() {
        // Create a simple dataset
        let x = Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .expect("Operation failed");

        // Initialize and fit TruncatedSVD with 2 components
        let mut svd = TruncatedSVD::new(2);
        let x_transformed = svd.fit_transform(&x).expect("Operation failed");

        // Check that the shape is correct
        assert_eq!(x_transformed.shape(), &[4, 2]);

        // Check that we have the correct number of explained variance components
        let explained_variance = svd.explained_variance_ratio().expect("Operation failed");
        assert_eq!(explained_variance.len(), 2);

        // Check that the sum is a valid number (we don't need to enforce sum = 1)
        assert!(explained_variance.sum() > 0.0 && explained_variance.sum().is_finite());
    }

    // Two-class LDA with the "svd" solver: a single component must explain all of the
    // (normalised) discriminant variance.
    #[test]
    fn test_lda() {
        // Create a simple dataset with 2 classes
        let x = Array::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 4.0, 6.0, 5.0, 7.0, 4.0],
        )
        .expect("Operation failed");

        let y = Array::from_vec(vec![0, 0, 0, 1, 1, 1]);

        // Initialize and fit LDA with 1 component (max for 2 classes)
        let mut lda = LDA::new(1, "svd").expect("Operation failed");
        let x_transformed = lda.fit_transform(&x, &y).expect("Operation failed");

        // Check that the shape is correct
        assert_eq!(x_transformed.shape(), &[6, 1]);

        // Check that the explained variance ratio is 1.0 for a single component
        let explained_variance = lda.explained_variance_ratio().expect("Operation failed");
        assert_abs_diff_eq!(explained_variance[0], 1.0, epsilon = 1e-10);
    }

    // Three-class LDA with both solvers: shapes, finiteness, and explained variance
    // ratios that are non-negative and sum to 1.
    #[test]
    fn test_lda_eigen_solver() {
        // Create a simple dataset with 3 classes
        let x = Array::from_shape_vec(
            (9, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, // Class 0
                5.0, 4.0, 6.0, 5.0, 7.0, 4.0, // Class 1
                9.0, 8.0, 10.0, 9.0, 11.0, 10.0, // Class 2
            ],
        )
        .expect("Operation failed");

        let y = Array::from_vec(vec![0, 0, 0, 1, 1, 1, 2, 2, 2]);

        // Test eigen solver
        let mut lda_eigen = LDA::new(2, "eigen").expect("Operation failed"); // 2 components for 3 classes
        let x_transformed_eigen = lda_eigen.fit_transform(&x, &y).expect("Operation failed");

        // Test SVD solver for comparison
        let mut lda_svd = LDA::new(2, "svd").expect("Operation failed");
        let x_transformed_svd = lda_svd.fit_transform(&x, &y).expect("Operation failed");

        // Check that both transformations have correct shape
        assert_eq!(x_transformed_eigen.shape(), &[9, 2]);
        assert_eq!(x_transformed_svd.shape(), &[9, 2]);

        // Check that both produce valid results
        assert!(x_transformed_eigen.iter().all(|&x| x.is_finite()));
        assert!(x_transformed_svd.iter().all(|&x| x.is_finite()));

        // Check that explained variance ratios are valid for both solvers
        let explained_variance_eigen = lda_eigen
            .explained_variance_ratio()
            .expect("Operation failed");
        let explained_variance_svd = lda_svd
            .explained_variance_ratio()
            .expect("Operation failed");

        assert_eq!(explained_variance_eigen.len(), 2);
        assert_eq!(explained_variance_svd.len(), 2);

        // Both should sum to approximately 1.0
        assert_abs_diff_eq!(explained_variance_eigen.sum(), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(explained_variance_svd.sum(), 1.0, epsilon = 1e-10);

        // Explained variance ratios should be non-negative
        assert!(explained_variance_eigen.iter().all(|&x| x >= 0.0));
        assert!(explained_variance_svd.iter().all(|&x| x >= 0.0));
    }

    // An unknown solver name is rejected at construction time with a descriptive error.
    #[test]
    fn test_lda_invalid_solver() {
        let result = LDA::new(1, "invalid");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("solver must be 'svd' or 'eigen'"));
    }

    // ─────────────────────────────────────────────────────────────────────────
    // Regression tests for issue #124
    //
    // PCA and LDA on wide data (n_samples << n_features) used to be O(p^3) /
    // O(p^4) because the full-matrices SVD path extended Vᵀ to p×p via Gram–
    // Schmidt, and the LDA `svd` solver materialised p×p Sw / Sb scatters. The
    // fixes (thin SVD in PCA, small-side algorithm in LDA::fit) bring this to
    // O(n·p·rank). We verify both correctness and that the wide path completes
    // quickly enough for tests; the timing budget is generous so it isn't flaky.
    // ─────────────────────────────────────────────────────────────────────────

    // Deterministic pseudo-random data: x[i, j] = ((i * 7919 + j) % 97) / 97, which lies
    // in [0, 1). Labels cycle through 0..n_classes by sample index.
    fn issue_124_synthetic_data(
        n_samples: usize,
        n_features: usize,
        n_classes: usize,
    ) -> (Array2<f64>, scirs2_core::ndarray::Array1<i32>) {
        let x = Array2::from_shape_fn((n_samples, n_features), |(i, j)| {
            ((i * 7919 + j) % 97) as f64 / 97.0
        });
        let y = scirs2_core::ndarray::Array1::from_shape_fn(n_samples, |i| (i % n_classes) as i32);
        (x, y)
    }

    #[test]
    fn test_issue_124_pca_wide_data_completes_quickly() {
        // n_samples << n_features regime: 12 samples × 400 features. Pre-fix, this
        // path runs the full-matrices SVD that extends Vᵀ to 400×400 via Gram–Schmidt
        // — measurable seconds. Post-fix, it should be <100 ms.
        let (x, _) = issue_124_synthetic_data(12, 400, 3);
        let start = std::time::Instant::now();
        let mut pca = PCA::new(2, true, false);
        let result = pca
            .fit_transform(&x)
            .expect("PCA on wide data must succeed");
        let elapsed = start.elapsed();

        assert_eq!(result.shape(), &[12, 2]);
        assert!(
            result.iter().all(|v| v.is_finite()),
            "PCA wide output must be finite"
        );
        // Generous bound to keep CI stable: pre-fix this took seconds, post-fix < 1s
        // is the expected behaviour on any reasonable machine. 5 s gives ample margin.
        assert!(
            elapsed.as_secs_f64() < 5.0,
            "PCA on 12x400 wide data took {:.3}s — far slower than expected; \
             possible regression of issue #124",
            elapsed.as_secs_f64()
        );

        // Sanity: exactly two singular values are stored, in non-increasing order.
        let sv = pca
            .singular_values
            .as_ref()
            .expect("singular values present");
        assert!(sv.len() == 2);
        assert!(sv[0] >= sv[1]);
    }

    #[test]
    fn test_issue_124_pca_validates_n_components_against_min_dim() {
        // n_components > min(n_samples, n_features) must now error rather than
        // silently zero-pad the trailing components (legacy behaviour).
        let (x, _) = issue_124_synthetic_data(5, 100, 2);
        // min(5, 100) = 5; 6 components is invalid
        let mut pca_bad = PCA::new(6, true, false);
        let err = pca_bad
            .fit(&x)
            .expect_err("PCA with n_components > min(n,p) must fail");
        assert!(
            err.to_string().contains("min(n_samples, n_features)"),
            "error message should reference min(n_samples, n_features): {err}"
        );

        // The boundary case (n_components == min) must still succeed.
        let mut pca_ok = PCA::new(5, true, false);
        let r = pca_ok
            .fit_transform(&x)
            .expect("PCA with n_components == min(n,p) must succeed");
        assert_eq!(r.shape(), &[5, 5]);
    }

    #[test]
    fn test_issue_124_lda_wide_data_svd_solver_completes_quickly() {
        // 12 samples × 400 features × 3 classes — the wide regime that the small-side
        // svd solver now handles in O(n·p·rank). Pre-fix, this path materialised two
        // 400×400 scatter matrices and a four-deep loop ⇒ O(p^4).
        let (x, y) = issue_124_synthetic_data(12, 400, 3);

        let start = std::time::Instant::now();
        let mut lda = LDA::new(2, "svd").expect("LDA::new");
        let z = lda
            .fit_transform(&x, &y)
            .expect("LDA svd on wide data must succeed");
        let elapsed = start.elapsed();

        assert_eq!(z.shape(), &[12, 2]);
        assert!(z.iter().all(|v| v.is_finite()));
        // Pre-fix this was many seconds; post-fix < 1 s. 5 s gives ample CI margin.
        assert!(
            elapsed.as_secs_f64() < 5.0,
            "LDA(svd) on 12x400 wide data took {:.3}s — possible regression of issue #124",
            elapsed.as_secs_f64()
        );

        // Each component row must be finite and non-zero. Unit norm is deliberately not
        // asserted: the small-side solver builds components from whitened directions.
        let components = lda.components().expect("components present");
        assert_eq!(components.shape(), &[2, 400]);
        for row in components.outer_iter() {
            let norm: f64 = row.iter().map(|v| v * v).sum::<f64>().sqrt();
            assert!(norm.is_finite() && norm > 0.0, "component norm = {norm}");
        }
    }

    #[test]
    fn test_issue_124_lda_svd_small_side_separates_classes() {
        // Correctness: in the wide regime, the small-side SVD solver must still
        // produce projections that separate classes. Use 3 well-separated Gaussian-
        // ish clusters in 200-dim space with n=9 samples.
        let mut data = vec![0.0; 9 * 200];
        for i in 0..9 {
            let class = i / 3;
            let center = match class {
                0 => 0.0,
                1 => 10.0,
                _ => -10.0,
            };
            for j in 0..200 {
                // small per-sample jitter + per-class centre offset on the first 10 dims
                let jitter = ((i * 31 + j) % 7) as f64 * 0.01;
                let centre = if j < 10 { center } else { 0.0 };
                data[i * 200 + j] = centre + jitter;
            }
        }
        let x = Array2::from_shape_vec((9, 200), data).expect("shape ok");
        let y = scirs2_core::ndarray::Array1::from_vec(vec![0, 0, 0, 1, 1, 1, 2, 2, 2]);

        let mut lda = LDA::new(2, "svd").expect("LDA::new");
        let z = lda
            .fit_transform(&x, &y)
            .expect("LDA svd on wide data must succeed");
        assert_eq!(z.shape(), &[9, 2]);

        // Within-class spread should be much smaller than between-class spread on
        // the leading discriminant axis.
        let class_means: Vec<f64> = (0..3)
            .map(|c| {
                let idxs: Vec<usize> = (0..9).filter(|&i| (i / 3) == c).collect();
                idxs.iter().map(|&i| z[[i, 0]]).sum::<f64>() / idxs.len() as f64
            })
            .collect();
        // `between`: root of the summed squared differences over all ordered pairs of
        // class means. `within`: root of the summed squared deviations of each sample
        // from its own class mean, both measured on discriminant axis 0.
        let mut between: f64 = 0.0;
        for a in 0..3 {
            for b in 0..3 {
                between += (class_means[a] - class_means[b]).powi(2);
            }
        }
        let between = between.sqrt();
        let within: f64 = (0..9)
            .map(|i| (z[[i, 0]] - class_means[i / 3]).powi(2))
            .sum::<f64>()
            .sqrt();
        assert!(
            between > 5.0 * within,
            "between/within ratio = {} (between={}, within={}) — class separation too weak",
            between / within.max(1e-12),
            between,
            within,
        );
    }
}