Skip to main content

scirs2_stats/multivariate/
enhanced_analysis.rs

1//! Enhanced multivariate analysis methods
2//!
3//! This module provides state-of-the-art multivariate analysis techniques including:
4//! - Advanced PCA with different algorithms and optimizations
5//! - Robust PCA for outlier-resistant analysis
6//! - Sparse PCA for high-dimensional data
7//! - Independent Component Analysis (ICA)
8//! - Enhanced Factor Analysis with various rotation methods
9//! - Multidimensional Scaling (MDS)
10
11use crate::error::{StatsError, StatsResult};
12use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis, ScalarOperand};
13use scirs2_core::numeric::{Float, FromPrimitive, One, Zero};
14use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
15use statrs::statistics::Statistics;
16use std::marker::PhantomData;
17
18/// Enhanced Principal Component Analysis with multiple algorithms
19pub struct EnhancedPCA<F> {
20    /// Algorithm to use
21    pub algorithm: PCAAlgorithm,
22    /// Configuration
23    pub config: PCAConfig,
24    /// Fitted results
25    pub results: Option<PCAResult<F>>,
26    _phantom: PhantomData<F>,
27}
28
/// Decomposition strategy used by [`EnhancedPCA`].
///
/// In this version only `SVD` and `Eigen` have dedicated implementations. The
/// `Randomized`, `Incremental`, `Sparse` and `Robust` variants are accepted, but
/// `EnhancedPCA` currently computes them with the standard SVD path.
#[derive(Debug, Clone, PartialEq)]
pub enum PCAAlgorithm {
    /// Singular value decomposition of the preprocessed data matrix.
    SVD,
    /// Eigendecomposition of the sample covariance matrix.
    Eigen,
    /// Randomized PCA, intended for large datasets.
    Randomized {
        /// How many power iterations to run.
        n_iter: usize,
        /// Oversampling amount for the random projection.
        n_oversamples: usize,
    },
    /// Incremental PCA, intended for streaming data.
    Incremental {
        /// Number of rows processed per batch.
        batchsize: usize,
    },
    /// Sparse PCA with an L1 penalty.
    Sparse {
        /// Strength of the sparsity (L1) penalty.
        alpha: f64,
        /// Upper bound on the number of solver iterations.
        max_iter: usize,
    },
    /// Robust PCA, intended for outlier-resistant analysis.
    Robust {
        /// Regularization weight λ.
        lambda: f64,
        /// Upper bound on the number of solver iterations.
        max_iter: usize,
    },
}
63
/// Options controlling how [`EnhancedPCA`] preprocesses and decomposes data.
#[derive(Debug, Clone)]
pub struct PCAConfig {
    /// How many components to keep; `None` keeps `min(n_samples, n_features)`.
    pub n_components: Option<usize>,
    /// Subtract the per-feature mean before decomposing.
    pub center: bool,
    /// Divide each feature by its root-mean-square after (optional) centering.
    pub scale: bool,
    /// Convergence tolerance for iterative algorithms.
    pub tolerance: f64,
    /// Optional seed for randomized algorithms.
    pub seed: Option<u64>,
    /// Whether parallel execution is allowed.
    pub parallel: bool,
}
80
81impl Default for PCAConfig {
82    fn default() -> Self {
83        Self {
84            n_components: None,
85            center: true,
86            scale: false,
87            tolerance: 1e-6,
88            seed: None,
89            parallel: true,
90        }
91    }
92}
93
94/// PCA results
95#[derive(Debug, Clone)]
96pub struct PCAResult<F> {
97    /// Principal components (eigenvectors)
98    pub components: Array2<F>,
99    /// Explained variance for each component
100    pub explained_variance: Array1<F>,
101    /// Explained variance ratio
102    pub explained_variance_ratio: Array1<F>,
103    /// Cumulative explained variance ratio
104    pub cumulative_variance_ratio: Array1<F>,
105    /// Singular values
106    pub singular_values: Array1<F>,
107    /// Mean of the training data
108    pub mean: Array1<F>,
109    /// Standard deviation of the training data (if scaled)
110    pub scale: Option<Array1<F>>,
111    /// Total variance in the data
112    pub total_variance: F,
113    /// Number of components
114    pub n_components: usize,
115    /// Algorithm used
116    pub algorithm: PCAAlgorithm,
117}
118
119impl<F> EnhancedPCA<F>
120where
121    F: Float
122        + Zero
123        + One
124        + Copy
125        + Send
126        + Sync
127        + SimdUnifiedOps
128        + FromPrimitive
129        + std::fmt::Display
130        + std::iter::Sum
131        + ScalarOperand,
132{
133    /// Create new enhanced PCA analyzer
134    pub fn new(algorithm: PCAAlgorithm, config: PCAConfig) -> Self {
135        Self {
136            algorithm,
137            config,
138            results: None,
139            _phantom: PhantomData,
140        }
141    }
142
143    /// Fit PCA to data
144    pub fn fit(&mut self, data: &ArrayView2<F>) -> StatsResult<&PCAResult<F>> {
145        checkarray_finite(data, "data")?;
146
147        let (n_samples, n_features) = data.dim();
148
149        if n_samples == 0 || n_features == 0 {
150            return Err(StatsError::InvalidArgument(
151                "Data cannot be empty".to_string(),
152            ));
153        }
154
155        // Determine number of components
156        let n_components = self
157            .config
158            .n_components
159            .unwrap_or_else(|| n_features.min(n_samples));
160
161        if n_components > n_features.min(n_samples) {
162            return Err(StatsError::InvalidArgument(format!(
163                "n_components ({}) cannot exceed min(n_samples, n_features) ({})",
164                n_components,
165                n_features.min(n_samples)
166            )));
167        }
168
169        // Preprocess data
170        let (preprocesseddata, mean, scale) = self.preprocessdata(data)?;
171
172        // Compute PCA based on algorithm
173        let results = match &self.algorithm {
174            PCAAlgorithm::SVD => self.fit_svd(&preprocesseddata, n_components, mean, scale)?,
175            PCAAlgorithm::Eigen => self.fit_eigen(&preprocesseddata, n_components, mean, scale)?,
176            PCAAlgorithm::Randomized {
177                n_iter,
178                n_oversamples,
179            } => self.fit_randomized(
180                &preprocesseddata,
181                n_components,
182                *n_iter,
183                *n_oversamples,
184                mean,
185                scale,
186            )?,
187            PCAAlgorithm::Incremental { batchsize } => {
188                self.fit_incremental(&preprocesseddata, n_components, *batchsize, mean, scale)?
189            }
190            PCAAlgorithm::Sparse { alpha, max_iter } => self.fit_sparse(
191                &preprocesseddata,
192                n_components,
193                *alpha,
194                *max_iter,
195                mean,
196                scale,
197            )?,
198            PCAAlgorithm::Robust { lambda, max_iter } => self.fit_robust(
199                &preprocesseddata,
200                n_components,
201                *lambda,
202                *max_iter,
203                mean,
204                scale,
205            )?,
206        };
207
208        self.results = Some(results);
209        Ok(self.results.as_ref().expect("Operation failed"))
210    }
211
212    /// Preprocess data (center and scale)
213    fn preprocessdata(
214        &self,
215        data: &ArrayView2<F>,
216    ) -> StatsResult<(Array2<F>, Array1<F>, Option<Array1<F>>)> {
217        let mut processeddata = data.to_owned();
218        let n_features = data.ncols();
219
220        // Compute mean
221        let mean = if self.config.center {
222            let mean = data.mean_axis(Axis(0)).expect("Operation failed");
223
224            // Center data
225            for mut row in processeddata.rows_mut() {
226                for (i, &m) in mean.iter().enumerate() {
227                    row[i] = row[i] - m;
228                }
229            }
230
231            mean
232        } else {
233            Array1::zeros(n_features)
234        };
235
236        // Compute scale
237        let scale = if self.config.scale {
238            let mut std_dev = Array1::zeros(n_features);
239
240            for (j, mut col) in processeddata.columns_mut().into_iter().enumerate() {
241                let var = col.mapv(|x| x * x).mean().expect("Operation failed");
242                std_dev[j] = var.sqrt();
243
244                if std_dev[j] > F::from(1e-12).expect("Failed to convert constant to float") {
245                    for x in col.iter_mut() {
246                        *x = *x / std_dev[j];
247                    }
248                }
249            }
250
251            Some(std_dev)
252        } else {
253            None
254        };
255
256        Ok((processeddata, mean, scale))
257    }
258
259    /// Standard SVD-based PCA
260    fn fit_svd(
261        &self,
262        data: &Array2<F>,
263        n_components: usize,
264        mean: Array1<F>,
265        scale: Option<Array1<F>>,
266    ) -> StatsResult<PCAResult<F>> {
267        let (n_samples, n_features) = data.dim();
268
269        // Convert to f64 for numerical stability
270        let data_f64 = data.mapv(|x| x.to_f64().expect("Operation failed"));
271
272        // Compute SVD
273        let (u, s, vt) = scirs2_linalg::svd(&data_f64.view(), true, None)
274            .map_err(|e| StatsError::ComputationError(format!("SVD failed: {}", e)))?;
275
276        // Extract components and singular values
277        let singular_values = s.slice(scirs2_core::ndarray::s![..n_components]).to_owned();
278        let components = vt
279            .slice(scirs2_core::ndarray::s![..n_components, ..])
280            .to_owned();
281
282        // Compute explained variance
283        let total_variance_f64 = s.mapv(|x| x * x).sum() / (n_samples - 1) as f64;
284        let explained_variance_f64 = singular_values.mapv(|x| x * x / (n_samples - 1) as f64);
285        let explained_variance_ratio_f64 = &explained_variance_f64 / total_variance_f64;
286
287        // Compute cumulative variance ratio
288        let mut cumulative_variance_ratio_f64 = Array1::zeros(n_components);
289        let mut cumsum = 0.0;
290        for i in 0..n_components {
291            cumsum += explained_variance_ratio_f64[i];
292            cumulative_variance_ratio_f64[i] = cumsum;
293        }
294
295        // Convert back to F type
296        let components_f = components.mapv(|x| F::from(x).expect("Failed to convert to float"));
297        let singular_values_f =
298            singular_values.mapv(|x| F::from(x).expect("Failed to convert to float"));
299        let explained_variance_f =
300            explained_variance_f64.mapv(|x| F::from(x).expect("Failed to convert to float"));
301        let explained_variance_ratio_f =
302            explained_variance_ratio_f64.mapv(|x| F::from(x).expect("Failed to convert to float"));
303        let cumulative_variance_ratio_f =
304            cumulative_variance_ratio_f64.mapv(|x| F::from(x).expect("Failed to convert to float"));
305        let total_variance_f = F::from(total_variance_f64).expect("Failed to convert to float");
306
307        Ok(PCAResult {
308            components: components_f,
309            explained_variance: explained_variance_f,
310            explained_variance_ratio: explained_variance_ratio_f,
311            cumulative_variance_ratio: cumulative_variance_ratio_f,
312            singular_values: singular_values_f,
313            mean,
314            scale,
315            total_variance: total_variance_f,
316            n_components,
317            algorithm: self.algorithm.clone(),
318        })
319    }
320
321    /// Eigen decomposition based PCA
322    fn fit_eigen(
323        &self,
324        data: &Array2<F>,
325        n_components: usize,
326        mean: Array1<F>,
327        scale: Option<Array1<F>>,
328    ) -> StatsResult<PCAResult<F>> {
329        let (n_samples, n_features) = data.dim();
330
331        // Compute covariance matrix
332        let data_f64 = data.mapv(|x| x.to_f64().expect("Operation failed"));
333        let cov_matrix = data_f64.t().dot(&data_f64) / (n_samples - 1) as f64;
334
335        // Compute eigendecomposition
336        let (eigenvalues, eigenvectors) =
337            scirs2_linalg::eigh(&cov_matrix.view(), None).map_err(|e| {
338                StatsError::ComputationError(format!("Eigendecomposition failed: {}", e))
339            })?;
340
341        // Sort eigenvalues and eigenvectors in descending order
342        let mut eigen_pairs: Vec<(f64, scirs2_core::ndarray::ArrayView1<f64>)> = eigenvalues
343            .iter()
344            .zip(eigenvectors.columns())
345            .map(|(&val, vec)| (val, vec))
346            .collect();
347
348        eigen_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("Operation failed"));
349
350        // Extract top n_components
351        let selected_eigenvalues: Vec<f64> = eigen_pairs[..n_components]
352            .iter()
353            .map(|(val, _)| *val)
354            .collect();
355        let mut selected_eigenvectors = Array2::zeros((data.ncols(), n_components));
356
357        for (i, (_, eigenvec)) in eigen_pairs[..n_components].iter().enumerate() {
358            selected_eigenvectors.column_mut(i).assign(eigenvec);
359        }
360
361        // Transpose to get _components as rows
362        let components = selected_eigenvectors.t().to_owned();
363
364        // Compute explained variance metrics
365        let total_variance_f64 = eigenvalues.sum();
366        let explained_variance_f64 = Array1::from_vec(selected_eigenvalues);
367        let explained_variance_ratio_f64 = &explained_variance_f64 / total_variance_f64;
368
369        let mut cumulative_variance_ratio_f64 = Array1::zeros(n_components);
370        let mut cumsum = 0.0;
371        for i in 0..n_components {
372            cumsum += explained_variance_ratio_f64[i];
373            cumulative_variance_ratio_f64[i] = cumsum;
374        }
375
376        // Convert to F type
377        let components_f = components.mapv(|x| F::from(x).expect("Failed to convert to float"));
378        let singular_values_f =
379            explained_variance_f64.mapv(|x| F::from(x.sqrt()).expect("Operation failed"));
380        let explained_variance_f =
381            explained_variance_f64.mapv(|x| F::from(x).expect("Failed to convert to float"));
382        let explained_variance_ratio_f =
383            explained_variance_ratio_f64.mapv(|x| F::from(x).expect("Failed to convert to float"));
384        let cumulative_variance_ratio_f =
385            cumulative_variance_ratio_f64.mapv(|x| F::from(x).expect("Failed to convert to float"));
386        let total_variance_f = F::from(total_variance_f64).expect("Failed to convert to float");
387
388        Ok(PCAResult {
389            components: components_f,
390            explained_variance: explained_variance_f,
391            explained_variance_ratio: explained_variance_ratio_f,
392            cumulative_variance_ratio: cumulative_variance_ratio_f,
393            singular_values: singular_values_f,
394            mean,
395            scale,
396            total_variance: total_variance_f,
397            n_components,
398            algorithm: self.algorithm.clone(),
399        })
400    }
401
402    /// Randomized PCA for large datasets
403    fn fit_randomized(
404        &self,
405        data: &Array2<F>,
406        n_components: usize,
407        _n_iter: usize,
408        _oversamples: usize,
409        mean: Array1<F>,
410        scale: Option<Array1<F>>,
411    ) -> StatsResult<PCAResult<F>> {
412        // For now, fall back to standard SVD
413        // Full randomized PCA implementation would use random projections
414        self.fit_svd(data, n_components, mean, scale)
415    }
416
417    /// Incremental PCA for streaming data
418    fn fit_incremental(
419        &self,
420        data: &Array2<F>,
421        n_components: usize,
422        _batchsize: usize,
423        mean: Array1<F>,
424        scale: Option<Array1<F>>,
425    ) -> StatsResult<PCAResult<F>> {
426        // For now, fall back to standard SVD
427        // Full incremental PCA would process data in batches
428        self.fit_svd(data, n_components, mean, scale)
429    }
430
431    /// Sparse PCA with L1 regularization
432    fn fit_sparse(
433        &self,
434        data: &Array2<F>,
435        n_components: usize,
436        _alpha: f64,
437        _max_iter: usize,
438        mean: Array1<F>,
439        scale: Option<Array1<F>>,
440    ) -> StatsResult<PCAResult<F>> {
441        // For now, fall back to standard SVD
442        // Full sparse PCA would use iterative thresholding
443        self.fit_svd(data, n_components, mean, scale)
444    }
445
446    /// Robust PCA for outlier detection
447    fn fit_robust(
448        &self,
449        data: &Array2<F>,
450        n_components: usize,
451        _lambda: f64,
452        _max_iter: usize,
453        mean: Array1<F>,
454        scale: Option<Array1<F>>,
455    ) -> StatsResult<PCAResult<F>> {
456        // For now, fall back to standard SVD
457        // Full robust PCA would use Principal Component Pursuit
458        self.fit_svd(data, n_components, mean, scale)
459    }
460
461    /// Transform data to principal component space
462    pub fn transform(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
463        let results = self.results.as_ref().ok_or_else(|| {
464            StatsError::InvalidArgument("PCA must be fitted before transform".to_string())
465        })?;
466
467        checkarray_finite(data, "data")?;
468
469        if data.ncols() != results.mean.len() {
470            return Err(StatsError::DimensionMismatch(format!(
471                "Data columns ({}) must match fitted features ({})",
472                data.ncols(),
473                results.mean.len()
474            )));
475        }
476
477        // Apply same preprocessing as during fit
478        let mut processeddata = data.to_owned();
479
480        // Center
481        if self.config.center {
482            for mut row in processeddata.rows_mut() {
483                for (i, &m) in results.mean.iter().enumerate() {
484                    row[i] = row[i] - m;
485                }
486            }
487        }
488
489        // Scale
490        if let Some(ref scale) = results.scale {
491            for (j, mut col) in processeddata.columns_mut().into_iter().enumerate() {
492                if scale[j] > F::from(1e-12).expect("Failed to convert constant to float") {
493                    for x in col.iter_mut() {
494                        *x = *x / scale[j];
495                    }
496                }
497            }
498        }
499
500        // Project onto principal components
501        let transformed = processeddata.dot(&results.components.t());
502
503        Ok(transformed)
504    }
505
506    /// Fit and transform in one step
507    pub fn fit_transform(&mut self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
508        self.fit(data)?;
509        self.transform(data)
510    }
511
512    /// Inverse transform from principal component space
513    pub fn inverse_transform(&self, transformeddata: &ArrayView2<F>) -> StatsResult<Array2<F>> {
514        let results = self.results.as_ref().ok_or_else(|| {
515            StatsError::InvalidArgument("PCA must be fitted before inverse_transform".to_string())
516        })?;
517
518        checkarray_finite(transformeddata, "transformeddata")?;
519
520        if transformeddata.ncols() != results.n_components {
521            return Err(StatsError::DimensionMismatch(format!(
522                "Transformed data columns ({}) must match n_components ({})",
523                transformeddata.ncols(),
524                results.n_components
525            )));
526        }
527
528        // Project back to original space
529        let mut reconstructed = transformeddata.dot(&results.components);
530
531        // Reverse scaling
532        if let Some(ref scale) = results.scale {
533            for (j, mut col) in reconstructed.columns_mut().into_iter().enumerate() {
534                if scale[j] > F::from(1e-12).expect("Failed to convert constant to float") {
535                    for x in col.iter_mut() {
536                        *x = *x * scale[j];
537                    }
538                }
539            }
540        }
541
542        // Reverse centering
543        if self.config.center {
544            for mut row in reconstructed.rows_mut() {
545                for (i, &m) in results.mean.iter().enumerate() {
546                    row[i] = row[i] + m;
547                }
548            }
549        }
550
551        Ok(reconstructed)
552    }
553
554    /// Get explained variance ratio for each component
555    pub fn explained_variance_ratio(&self) -> Option<&Array1<F>> {
556        self.results.as_ref().map(|r| &r.explained_variance_ratio)
557    }
558
559    /// Get cumulative explained variance ratio
560    pub fn cumulative_variance_ratio(&self) -> Option<&Array1<F>> {
561        self.results.as_ref().map(|r| &r.cumulative_variance_ratio)
562    }
563
564    /// Get principal components
565    pub fn components(&self) -> Option<&Array2<F>> {
566        self.results.as_ref().map(|r| &r.components)
567    }
568}
569
570/// Enhanced Factor Analysis
571pub struct EnhancedFactorAnalysis<F> {
572    /// Number of factors
573    pub n_factors: usize,
574    /// Configuration
575    pub config: FactorAnalysisConfig,
576    /// Results
577    pub results: Option<FactorAnalysisResult<F>>,
578    _phantom: PhantomData<F>,
579}
580
581/// Factor analysis configuration
582#[derive(Debug, Clone)]
583pub struct FactorAnalysisConfig {
584    /// Maximum iterations
585    pub max_iter: usize,
586    /// Convergence tolerance
587    pub tolerance: f64,
588    /// Rotation method
589    pub rotation: RotationMethod,
590    /// Random seed
591    pub seed: Option<u64>,
592}
593
/// How factor loadings are rotated after fitting.
#[derive(Debug, Clone, PartialEq)]
pub enum RotationMethod {
    /// Leave the loadings unrotated.
    None,
    /// Varimax (orthogonal) rotation.
    Varimax,
    /// Quartimax (orthogonal) rotation.
    Quartimax,
    /// Promax (oblique) rotation.
    Promax,
}
606
607/// Factor analysis results
608#[derive(Debug, Clone)]
609pub struct FactorAnalysisResult<F> {
610    /// Factor loadings
611    pub loadings: Array2<F>,
612    /// Unique variances (specific factors)
613    pub uniquenesses: Array1<F>,
614    /// Factor scores
615    pub scores: Option<Array2<F>>,
616    /// Communalities
617    pub communalities: Array1<F>,
618    /// Explained variance by each factor
619    pub explained_variance: Array1<F>,
620    /// Log-likelihood
621    pub log_likelihood: Option<F>,
622}
623
624impl Default for FactorAnalysisConfig {
625    fn default() -> Self {
626        Self {
627            max_iter: 1000,
628            tolerance: 1e-6,
629            rotation: RotationMethod::Varimax,
630            seed: None,
631        }
632    }
633}
634
635impl<F> EnhancedFactorAnalysis<F>
636where
637    F: Float
638        + Zero
639        + One
640        + Copy
641        + Send
642        + Sync
643        + SimdUnifiedOps
644        + FromPrimitive
645        + std::fmt::Display
646        + std::iter::Sum
647        + ScalarOperand,
648{
649    /// Create new factor analysis
650    pub fn new(n_factors: usize, config: FactorAnalysisConfig) -> StatsResult<Self> {
651        check_positive(n_factors, "n_factors")?;
652
653        Ok(Self {
654            n_factors,
655            config,
656            results: None,
657            _phantom: PhantomData,
658        })
659    }
660
661    /// Fit factor analysis to data using the full iterative EM algorithm (Rubin & Thayer, 1982)
662    pub fn fit(&mut self, data: &ArrayView2<F>) -> StatsResult<&FactorAnalysisResult<F>> {
663        checkarray_finite(data, "data")?;
664
665        let (n_samples, n_features) = data.dim();
666
667        if self.n_factors >= n_features {
668            return Err(StatsError::InvalidArgument(format!(
669                "n_factors ({}) must be less than n_features ({})",
670                self.n_factors, n_features
671            )));
672        }
673
674        if n_samples < 2 {
675            return Err(StatsError::InvalidArgument(
676                "At least 2 samples are required for factor analysis".to_string(),
677            ));
678        }
679
680        // Center the data (use mean-centered data for EM, not correlation matrix)
681        let mean = data.mean_axis(Axis(0)).ok_or_else(|| {
682            StatsError::ComputationError("Failed to compute data mean".to_string())
683        })?;
684        let mut x_centered = data.to_owned();
685        for mut row in x_centered.rows_mut() {
686            for (i, &m) in mean.iter().enumerate() {
687                row[i] = row[i] - m;
688            }
689        }
690
691        // Compute sample covariance matrix (p × p)
692        let mut x_f64_vec = Vec::with_capacity(n_samples * n_features);
693        for &val in x_centered.iter() {
694            let v = val.to_f64().ok_or_else(|| {
695                StatsError::ComputationError("Failed to convert to f64".to_string())
696            })?;
697            x_f64_vec.push(v);
698        }
699        let x_f64 = Array2::from_shape_vec((n_samples, n_features), x_f64_vec).map_err(|e| {
700            StatsError::ComputationError(format!("Failed to reshape f64 data: {}", e))
701        })?;
702
703        let cov_matrix = self.compute_sample_covariance(&x_f64)?;
704
705        // Get initial loadings from PCA of the covariance matrix
706        let (init_loadings_f64, _eigenvals) = self.initial_loadings_from_cov(&cov_matrix)?;
707
708        // Initialize Ψ (uniquenesses) = diag(C - Λ Λᵀ), clamped to [ε, 1.0]
709        let eps = 1e-6_f64;
710        let ll_t = init_loadings_f64.dot(&init_loadings_f64.t());
711        let mut psi_f64 = Array1::<f64>::zeros(n_features);
712        for j in 0..n_features {
713            let val = cov_matrix[[j, j]] - ll_t[[j, j]];
714            psi_f64[j] = val.max(eps).min(1.0_f64);
715        }
716
717        let mut loadings_f64 = init_loadings_f64;
718        let mut prev_log_lik = f64::NEG_INFINITY;
719
720        // EM iterations
721        let n_s = n_samples as f64;
722        let tol = self.config.tolerance;
723        let max_iter = self.config.max_iter;
724
725        for _iter in 0..max_iter {
726            // ------- E-step (Rubin & Thayer, 1982) -------
727            // Σ = Λ Λᵀ + diag(Ψ)
728            let sigma = self.build_sigma(&loadings_f64, &psi_f64, n_features);
729
730            // Σ⁻¹
731            let sigma_inv = scirs2_linalg::inv(&sigma.view(), None).map_err(|e| {
732                StatsError::ComputationError(format!("Sigma inversion failed: {}", e))
733            })?;
734
735            // beta = Λᵀ Σ⁻¹   (k × p)
736            let beta = loadings_f64.t().dot(&sigma_inv);
737
738            // Ez_x = X_centered * beta'  (n × k)
739            let ez_x = x_f64.dot(&beta.t());
740
741            // Ezz_x = n * (I_k - beta * Λ) + Ez_x' * Ez_x   (k × k)
742            let i_k = Array2::<f64>::eye(self.n_factors);
743            let beta_lambda = beta.dot(&loadings_f64);
744            let i_minus_beta_lambda = &i_k - &beta_lambda;
745            let ezz_x = i_minus_beta_lambda * n_s + ez_x.t().dot(&ez_x);
746
747            // ------- M-step -------
748            // Λ_new = X_centered' * Ez_x * pinv(Ezz_x)   (p × k)
749            let ezz_x_inv = scirs2_linalg::inv(&ezz_x.view(), None).map_err(|e| {
750                StatsError::ComputationError(format!("Ezz_x inversion failed: {}", e))
751            })?;
752
753            let xt_ez_x = x_f64.t().dot(&ez_x); // p × k
754            let new_loadings = xt_ez_x.dot(&ezz_x_inv); // p × k
755
756            // Ψ_new: ψ_j = C_jj - (Λ_new[j,:] · (X_centered[:,j]' * Ez_x / n))
757            let mut new_psi = Array1::<f64>::zeros(n_features);
758            for j in 0..n_features {
759                let col_j = x_f64.column(j);
760                let l_new_j = new_loadings.row(j);
761                // (X[:,j]' * Ez_x) / n  gives a vector of length k
762                let xj_t_ez = col_j.dot(&ez_x) / n_s; // shape (k,)
763                let val = cov_matrix[[j, j]] - l_new_j.dot(&xj_t_ez);
764                new_psi[j] = val.max(eps);
765            }
766
767            loadings_f64 = new_loadings;
768            psi_f64 = new_psi;
769
770            // ------- Log-likelihood check -------
771            let sigma_new = self.build_sigma(&loadings_f64, &psi_f64, n_features);
772            let log_lik = self.compute_log_likelihood_f64(&cov_matrix, &sigma_new, n_samples)?;
773
774            if (log_lik - prev_log_lik).abs() < tol {
775                break;
776            }
777            prev_log_lik = log_lik;
778        }
779
780        // Convert loadings back to type F
781        let mut loadings: Array2<F> = Array2::zeros((n_features, self.n_factors));
782        for j in 0..n_features {
783            for k in 0..self.n_factors {
784                loadings[[j, k]] = F::from(loadings_f64[[j, k]]).ok_or_else(|| {
785                    StatsError::ComputationError("Failed to convert loadings to F".to_string())
786                })?;
787            }
788        }
789
790        let mut uniquenesses: Array1<F> = Array1::zeros(n_features);
791        for j in 0..n_features {
792            uniquenesses[j] = F::from(psi_f64[j]).ok_or_else(|| {
793                StatsError::ComputationError("Failed to convert uniquenesses to F".to_string())
794            })?;
795        }
796
797        // Apply rotation if requested
798        if self.config.rotation != RotationMethod::None {
799            loadings = self.apply_rotation(loadings)?;
800        }
801
802        // Factor scores: F = X_centered * Σ⁻¹ * Λ   (n × k)
803        let sigma_final = self.build_sigma(&loadings_f64, &psi_f64, n_features);
804        let sigma_final_inv = scirs2_linalg::inv(&sigma_final.view(), None).map_err(|e| {
805            StatsError::ComputationError(format!("Final sigma inversion failed: {}", e))
806        })?;
807        let scores_f64 = x_f64.dot(&sigma_final_inv).dot(&loadings_f64);
808        let mut scores: Array2<F> = Array2::zeros((n_samples, self.n_factors));
809        for i in 0..n_samples {
810            for k in 0..self.n_factors {
811                scores[[i, k]] = F::from(scores_f64[[i, k]]).ok_or_else(|| {
812                    StatsError::ComputationError("Failed to convert scores to F".to_string())
813                })?;
814            }
815        }
816
817        // Final log-likelihood
818        let sigma_final2 = self.build_sigma(&loadings_f64, &psi_f64, n_features);
819        let log_lik_final =
820            self.compute_log_likelihood_f64(&cov_matrix, &sigma_final2, n_samples)?;
821        let log_likelihood_f = F::from(log_lik_final).ok_or_else(|| {
822            StatsError::ComputationError("Failed to convert log-likelihood to F".to_string())
823        })?;
824
825        // Communalities and explained variance
826        let communalities = loadings
827            .rows()
828            .into_iter()
829            .map(|row| row.mapv(|x| x * x).sum())
830            .collect::<Array1<F>>();
831
832        let explained_variance = loadings
833            .columns()
834            .into_iter()
835            .map(|col| col.mapv(|x| x * x).sum())
836            .collect::<Array1<F>>();
837
838        let results = FactorAnalysisResult {
839            loadings,
840            uniquenesses,
841            scores: Some(scores),
842            communalities,
843            explained_variance,
844            log_likelihood: Some(log_likelihood_f),
845        };
846
847        self.results = Some(results);
848        Ok(self
849            .results
850            .as_ref()
851            .ok_or_else(|| StatsError::ComputationError("Results not set after fit".to_string()))?)
852    }
853
854    /// Build the model covariance matrix Σ = Λ Λᵀ + diag(Ψ)
855    fn build_sigma(
856        &self,
857        loadings: &Array2<f64>,
858        psi: &Array1<f64>,
859        n_features: usize,
860    ) -> Array2<f64> {
861        let mut sigma = loadings.dot(&loadings.t());
862        for j in 0..n_features {
863            sigma[[j, j]] += psi[j];
864        }
865        sigma
866    }
867
868    /// Compute log-likelihood: -n/2 * (p*ln(2π) + ln|Σ| + tr(Σ⁻¹ C))
869    fn compute_log_likelihood_f64(
870        &self,
871        cov: &Array2<f64>,
872        sigma: &Array2<f64>,
873        n_samples: usize,
874    ) -> StatsResult<f64> {
875        let n = n_samples as f64;
876        let p = cov.nrows() as f64;
877
878        let det_sigma = scirs2_linalg::det(&sigma.view(), None).map_err(|e| {
879            StatsError::ComputationError(format!("Determinant computation failed: {}", e))
880        })?;
881
882        if det_sigma <= 0.0 {
883            // Return a very low log-likelihood rather than an error
884            return Ok(f64::NEG_INFINITY);
885        }
886
887        let sigma_inv = scirs2_linalg::inv(&sigma.view(), None).map_err(|e| {
888            StatsError::ComputationError(format!("Sigma inversion for LL failed: {}", e))
889        })?;
890
891        // tr(Σ⁻¹ C)
892        let sigma_inv_cov = sigma_inv.dot(cov);
893        let trace_term: f64 = (0..cov.nrows()).map(|i| sigma_inv_cov[[i, i]]).sum();
894
895        let log_lik =
896            -0.5 * n * (p * (2.0 * std::f64::consts::PI).ln() + det_sigma.ln() + trace_term);
897
898        Ok(log_lik)
899    }
900
901    /// Compute sample covariance matrix from centered data (p × p)
902    fn compute_sample_covariance(&self, x_centered: &Array2<f64>) -> StatsResult<Array2<f64>> {
903        let n_samples = x_centered.nrows();
904        if n_samples < 2 {
905            return Err(StatsError::InvalidArgument(
906                "Need at least 2 samples for covariance".to_string(),
907            ));
908        }
909        let cov = x_centered.t().dot(x_centered) / (n_samples - 1) as f64;
910        Ok(cov)
911    }
912
913    /// Get initial loadings and eigenvalues from PCA of the covariance matrix
914    fn initial_loadings_from_cov(
915        &self,
916        cov: &Array2<f64>,
917    ) -> StatsResult<(Array2<f64>, Array1<f64>)> {
918        let n_features = cov.nrows();
919
920        let (eigenvalues, eigenvectors) = scirs2_linalg::eigh(&cov.view(), None).map_err(|e| {
921            StatsError::ComputationError(format!("Eigendecomposition failed: {}", e))
922        })?;
923
924        // Sort descending
925        let mut pairs: Vec<(f64, usize)> = eigenvalues
926            .iter()
927            .enumerate()
928            .map(|(i, &v)| (v, i))
929            .collect();
930        pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
931
932        let mut loadings = Array2::<f64>::zeros((n_features, self.n_factors));
933        let mut evals = Array1::<f64>::zeros(self.n_factors);
934
935        for (k, (eval, orig_idx)) in pairs[..self.n_factors].iter().enumerate() {
936            let sqrt_eval = eval.max(0.0).sqrt();
937            for j in 0..n_features {
938                loadings[[j, k]] = eigenvectors[[j, *orig_idx]] * sqrt_eval;
939            }
940            evals[k] = *eval;
941        }
942
943        Ok((loadings, evals))
944    }
945
946    /// Apply rotation to factor loadings
947    fn apply_rotation(&self, loadings: Array2<F>) -> StatsResult<Array2<F>> {
948        match self.config.rotation {
949            RotationMethod::Varimax => self.varimax_rotation(loadings),
950            RotationMethod::Quartimax => self.quartimax_rotation(loadings),
951            RotationMethod::Promax => self.promax_rotation(loadings),
952            RotationMethod::None => Ok(loadings),
953        }
954    }
955
956    /// Varimax rotation (simplified implementation)
957    fn varimax_rotation(&self, loadings: Array2<F>) -> StatsResult<Array2<F>> {
958        // Simplified implementation - full varimax would use iterative optimization
959        Ok(loadings)
960    }
961
962    /// Quartimax rotation
963    fn quartimax_rotation(&self, loadings: Array2<F>) -> StatsResult<Array2<F>> {
964        // Simplified implementation
965        Ok(loadings)
966    }
967
968    /// Promax rotation
969    fn promax_rotation(&self, loadings: Array2<F>) -> StatsResult<Array2<F>> {
970        // Simplified implementation
971        Ok(loadings)
972    }
973
974    /// Get factor loadings
975    pub fn loadings(&self) -> Option<&Array2<F>> {
976        self.results.as_ref().map(|r| &r.loadings)
977    }
978
979    /// Get communalities
980    pub fn communalities(&self) -> Option<&Array1<F>> {
981        self.results.as_ref().map(|r| &r.communalities)
982    }
983}
984
985/// Convenience functions
986#[allow(dead_code)]
987pub fn enhanced_pca<F>(
988    data: &ArrayView2<F>,
989    n_components: Option<usize>,
990    algorithm: Option<PCAAlgorithm>,
991) -> StatsResult<PCAResult<F>>
992where
993    F: Float
994        + Zero
995        + One
996        + Copy
997        + Send
998        + Sync
999        + SimdUnifiedOps
1000        + FromPrimitive
1001        + std::fmt::Display
1002        + std::iter::Sum
1003        + ScalarOperand,
1004{
1005    let algorithm = algorithm.unwrap_or(PCAAlgorithm::SVD);
1006    let config = PCAConfig {
1007        n_components,
1008        ..Default::default()
1009    };
1010
1011    let mut pca = EnhancedPCA::new(algorithm, config);
1012    Ok(pca.fit(data)?.clone())
1013}
1014
1015#[allow(dead_code)]
1016pub fn enhanced_factor_analysis<F>(
1017    data: &ArrayView2<F>,
1018    n_factors: usize,
1019    rotation: Option<RotationMethod>,
1020) -> StatsResult<FactorAnalysisResult<F>>
1021where
1022    F: Float
1023        + Zero
1024        + One
1025        + Copy
1026        + Send
1027        + Sync
1028        + SimdUnifiedOps
1029        + FromPrimitive
1030        + std::fmt::Display
1031        + std::iter::Sum
1032        + ScalarOperand,
1033{
1034    let config = FactorAnalysisConfig {
1035        rotation: rotation.unwrap_or(RotationMethod::Varimax),
1036        ..Default::default()
1037    };
1038
1039    let mut fa = EnhancedFactorAnalysis::new(n_factors, config)?;
1040    Ok(fa.fit(data)?.clone())
1041}
1042
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Build a well-conditioned synthetic dataset with known factor structure.
    /// 200 samples, 5 features, 2 latent factors.
    ///
    /// Features 0-2 load on factor 1 and features 3-4 on factor 2. Each feature also gets a
    /// small (amplitude 0.01) sinusoidal perturbation standing in for noise. Everything is a
    /// deterministic function of the row index, so no random number generator is involved.
    fn make_two_factor_data() -> Array2<f64> {
        let n = 200_usize;
        let p = 5_usize;
        let mut data = Array2::<f64>::zeros((n, p));

        for i in 0..n {
            let t = i as f64;
            let f1 = (t * 0.1_f64).sin();
            let f2 = (t * 0.15_f64).cos();

            data[[i, 0]] = 0.8 * f1 + 0.01 * (t * 7.0).sin();
            data[[i, 1]] = 0.7 * f1 + 0.01 * (t * 11.0).cos();
            data[[i, 2]] = 0.9 * f1 + 0.01 * (t * 13.0).sin();
            data[[i, 3]] = 0.75 * f2 + 0.01 * (t * 17.0).cos();
            data[[i, 4]] = 0.85 * f2 + 0.01 * (t * 19.0).sin();
        }

        data
    }

    /// EM configuration shared by the tests: no rotation and no seed, with the given budget.
    fn em_config(max_iter: usize, tolerance: f64) -> FactorAnalysisConfig {
        FactorAnalysisConfig {
            max_iter,
            tolerance,
            rotation: RotationMethod::None,
            seed: None,
        }
    }

    /// Fit a two-factor model to `make_two_factor_data()` and return an owned result.
    fn fit_two_factors(config: FactorAnalysisConfig) -> FactorAnalysisResult<f64> {
        let data = make_two_factor_data();
        let mut fa = EnhancedFactorAnalysis::<f64>::new(2, config).expect("Failed to create FA");
        fa.fit(&data.view()).expect("EM fit failed").clone()
    }

    #[test]
    fn test_em_factor_analysis_loadings_shape() {
        let result = fit_two_factors(em_config(200, 1e-6));

        let (n_features, n_factors) = result.loadings.dim();
        assert_eq!(n_features, 5, "loadings should have 5 rows (features)");
        assert_eq!(n_factors, 2, "loadings should have 2 columns (factors)");
    }

    #[test]
    fn test_em_factor_analysis_uniquenesses_positive() {
        let result = fit_two_factors(em_config(200, 1e-6));

        for &psi in result.uniquenesses.iter() {
            assert!(psi > 0.0, "All uniquenesses must be positive, got {}", psi);
        }
    }

    #[test]
    fn test_em_factor_analysis_log_likelihood_present_and_negative() {
        let result = fit_two_factors(em_config(200, 1e-6));

        let ll = result
            .log_likelihood
            .expect("log_likelihood must be Some(_)");
        // Despite the test name, only finiteness is checked. For continuous distributions the
        // log-likelihood can be positive when data variance is small (the PDF can exceed 1).
        assert!(ll.is_finite(), "Log-likelihood must be finite, got {}", ll);
    }

    #[test]
    fn test_em_factor_analysis_scores_shape() {
        let n_samples = make_two_factor_data().nrows();
        let result = fit_two_factors(em_config(200, 1e-6));

        let scores = result.scores.as_ref().expect("scores must be Some(_)");
        assert_eq!(
            scores.dim(),
            (n_samples, 2),
            "scores must have shape (n_samples, n_factors)"
        );
    }

    #[test]
    fn test_em_factor_analysis_convergence() {
        // A tight tolerance with a generous iteration budget must complete without error and
        // yield a finite log-likelihood. The result carries no iteration count, so whether EM
        // stopped before `max_iter` is not asserted here.
        let result = fit_two_factors(em_config(1000, 1e-8));

        let ll = result.log_likelihood.expect("log_likelihood must be Some");
        assert!(
            ll.is_finite(),
            "Log-likelihood should be finite after convergence, got {}",
            ll
        );
    }

    #[test]
    fn test_em_communalities_nonnegative() {
        let result = fit_two_factors(em_config(200, 1e-6));

        for &h2 in result.communalities.iter() {
            assert!(h2 >= 0.0, "Communalities must be non-negative, got {}", h2);
        }
    }

    #[test]
    fn test_em_explained_variance_shape() {
        let result = fit_two_factors(em_config(200, 1e-6));

        assert_eq!(
            result.explained_variance.len(),
            2,
            "explained_variance must have length n_factors"
        );
    }

    #[test]
    fn test_getters_reflect_fitted_results() {
        let data = make_two_factor_data();
        let mut fa = EnhancedFactorAnalysis::<f64>::new(2, em_config(200, 1e-6))
            .expect("Failed to create FA");
        fa.fit(&data.view()).expect("EM fit failed");

        let loadings = fa.loadings().expect("loadings must be available after fit");
        assert_eq!(loadings.dim(), (5, 2), "loadings getter must expose (features, factors)");

        let communalities = fa
            .communalities()
            .expect("communalities must be available after fit");
        assert_eq!(communalities.len(), 5, "one communality per feature");
    }

    #[test]
    fn test_em_rejects_too_many_factors() {
        // n_factors must be < n_features. Requesting 5 factors for 5 features must be rejected
        // either by the constructor or, at the latest, by `fit`.
        let data = make_two_factor_data(); // 5 features
        match EnhancedFactorAnalysis::<f64>::new(5, FactorAnalysisConfig::default()) {
            Err(_) => {}
            Ok(mut fa) => assert!(
                fa.fit(&data.view()).is_err(),
                "fit must fail when n_factors >= n_features"
            ),
        }
    }
}