scirs2_stats/multivariate/factor_analysis.rs

//! Factor Analysis
//!
//! Factor analysis is a dimensionality reduction technique that identifies latent factors
//! that explain the correlations among observed variables.
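//!
//! The model assumes each observed p-dimensional vector x is generated as
//!
//!     x = mu + L h + eps,    h ~ N(0, I_k),    eps ~ N(0, Psi),
//!
//! where L is the p x k loadings matrix and Psi is a diagonal matrix of
//! specific variances, so that Cov(x) = L L^T + Psi.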

use crate::error::{StatsError, StatsResult as Result};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use scirs2_core::random::{rngs::StdRng, SeedableRng};
use scirs2_core::validation::*;

/// Factor Analysis model
#[derive(Debug, Clone)]
pub struct FactorAnalysis {
    /// Number of factors to extract
    pub n_factors: usize,
    /// Maximum number of iterations for EM algorithm
    pub max_iter: usize,
    /// Convergence tolerance
    pub tol: f64,
    /// Type of factor rotation to apply
    pub rotation: RotationType,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
}

/// Type of factor rotation
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RotationType {
    /// No rotation
    None,
    /// Varimax rotation (orthogonal)
    Varimax,
    /// Promax rotation (oblique)
    Promax,
}

/// Result of factor analysis
#[derive(Debug, Clone)]
pub struct FactorAnalysisResult {
    /// Factor loadings matrix (p x k)
    pub loadings: Array2<f64>,
    /// Specific variances (unique factors)
    pub noise_variance: Array1<f64>,
    /// Factor scores for training data (n x k)
    pub scores: Array2<f64>,
    /// Mean of training data
    pub mean: Array1<f64>,
    /// Log-likelihood of the model
    pub log_likelihood: f64,
    /// Number of iterations until convergence
    pub n_iter: usize,
    /// Share of the common variance attributed to each factor (sums to 1)
    pub explained_variance_ratio: Array1<f64>,
    /// Communalities (proportion of variance in each variable explained by factors)
    pub communalities: Array1<f64>,
}

impl Default for FactorAnalysis {
    fn default() -> Self {
        Self {
            n_factors: 2,
            max_iter: 1000,
            tol: 1e-6,
            rotation: RotationType::Varimax,
            random_state: None,
        }
    }
}

impl FactorAnalysis {
    /// Create a new factor analysis instance
    pub fn new(n_factors: usize) -> Result<Self> {
        check_positive(n_factors, "n_factors")?;
        Ok(Self {
            n_factors,
            ..Default::default()
        })
    }

    /// Set maximum iterations
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set convergence tolerance
    pub fn with_tolerance(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set rotation type
    pub fn with_rotation(mut self, rotation: RotationType) -> Self {
        self.rotation = rotation;
        self
    }

    /// Set random state
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Fit the factor analysis model
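    ///
    /// # Example
    ///
    /// A minimal usage sketch (the import paths assume this module is exposed
    /// as `scirs2_stats::multivariate::factor_analysis`):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array2;
    /// use scirs2_stats::multivariate::factor_analysis::{FactorAnalysis, RotationType};
    ///
    /// // 100 observations of 6 variables; fill with real data in practice
    /// let data: Array2<f64> = Array2::from_shape_fn((100, 6), |(i, j)| ((i + j) as f64).sin());
    /// let fa = FactorAnalysis::new(2).unwrap().with_rotation(RotationType::Varimax);
    /// let result = fa.fit(data.view()).unwrap();
    /// assert_eq!(result.loadings.dim(), (6, 2));
    /// ```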
    pub fn fit(&self, data: ArrayView2<f64>) -> Result<FactorAnalysisResult> {
        checkarray_finite(&data, "data")?;
        let (n_samples, n_features) = data.dim();

        if n_samples < 2 {
            return Err(StatsError::InvalidArgument(
                "n_samples must be at least 2".to_string(),
            ));
        }

        if self.n_factors >= n_features {
            return Err(StatsError::InvalidArgument(format!(
                "n_factors ({}) must be less than n_features ({})",
                self.n_factors, n_features
            )));
        }

        // Center the data
        let mean = data.mean_axis(Axis(0)).unwrap();
        let mut centered_data = data.to_owned();
        for mut row in centered_data.rows_mut() {
            row -= &mean;
        }

        // Initialize parameters
        let (mut loadings, mut psi) = self.initialize_parameters(&centered_data)?;

        let mut prev_log_likelihood = f64::NEG_INFINITY;
        let mut n_iter = 0;

        // EM algorithm
        let mut converged = false;
        for iteration in 0..self.max_iter {
            // E-step: compute expected sufficient statistics
            let (e_h, e_hht) = self.e_step(&centered_data, &loadings, &psi)?;

            // M-step: update parameters
            let (new_loadings, new_psi) = self.m_step(&centered_data, &e_h, &e_hht)?;

            // Compute log-likelihood
            let log_likelihood =
                self.compute_log_likelihood(&centered_data, &new_loadings, &new_psi)?;

            loadings = new_loadings;
            psi = new_psi;
            n_iter = iteration + 1;

            // Check convergence (a flag avoids misreporting failure when
            // convergence happens exactly on the final iteration)
            if (log_likelihood - prev_log_likelihood).abs() < self.tol {
                converged = true;
                break;
            }
            prev_log_likelihood = log_likelihood;
        }

        if !converged {
            return Err(StatsError::ConvergenceError(format!(
                "EM algorithm failed to converge after {} iterations",
                self.max_iter
            )));
        }

        // Apply rotation if specified
        let rotated_loadings = match self.rotation {
            RotationType::None => loadings,
            RotationType::Varimax => self.varimax_rotation(&loadings)?,
            RotationType::Promax => self.promax_rotation(&loadings)?,
        };

        // Compute factor scores
        let scores = self.compute_factor_scores(&centered_data, &rotated_loadings, &psi)?;

        // Compute explained variance and communalities
        let explained_variance_ratio = self.compute_explained_variance(&rotated_loadings);
        let communalities = self.compute_communalities(&rotated_loadings);

        // Final log-likelihood
        let final_log_likelihood =
            self.compute_log_likelihood(&centered_data, &rotated_loadings, &psi)?;

        Ok(FactorAnalysisResult {
            loadings: rotated_loadings,
            noise_variance: psi,
            scores,
            mean,
            log_likelihood: final_log_likelihood,
            n_iter,
            explained_variance_ratio,
            communalities,
        })
    }

    /// Initialize factor loadings and specific variances
    fn initialize_parameters(&self, data: &Array2<f64>) -> Result<(Array2<f64>, Array1<f64>)> {
        let (n_samples, n_features) = data.dim();

        // Initialize from the SVD of the centered data
        use scirs2_core::ndarray::ndarray_linalg::SVD;
        let (_, s, vt) = data.svd(false, true).map_err(|e| {
            StatsError::ComputationError(format!("SVD initialization failed: {}", e))
        })?;

        let v = vt.unwrap().t().to_owned();

        // Initial loadings from the first k right singular vectors, scaled by
        // the corresponding singular values (converted to sample standard deviations)
        let mut loadings = Array2::zeros((n_features, self.n_factors));
        for i in 0..self.n_factors {
            let scale = (s[i] / (n_samples as f64).sqrt()).max(1e-6);
            for j in 0..n_features {
                loadings[[j, i]] = v[[j, i]] * scale;
            }
        }

        // Initialize specific variances; the `1 - communality` heuristic
        // implicitly assumes roughly unit-variance variables
        let mut psi = Array1::ones(n_features);
        for i in 0..n_features {
            let communality = loadings.row(i).dot(&loadings.row(i));
            psi[i] = (1.0 - communality).max(0.01); // Ensure positive
        }

        Ok((loadings, psi))
    }

    /// E-step of EM algorithm
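    ///
    /// The posterior of the latent factors given an observation x is Gaussian:
    ///
    ///     h | x ~ N(M^{-1} L^T Psi^{-1} x, M^{-1}),    M = I + L^T Psi^{-1} L,
    ///
    /// so E[h | x] is computed per sample while Cov[h | x] = M^{-1} is shared
    /// across samples.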
    fn e_step(
        &self,
        data: &Array2<f64>,
        loadings: &Array2<f64>,
        psi: &Array1<f64>,
    ) -> Result<(Array2<f64>, Array2<f64>)> {
        let (n_samples, n_features) = data.dim();

        // Construct precision matrix: Psi^{-1}
        let mut psi_inv = Array2::zeros((n_features, n_features));
        for i in 0..n_features {
            if psi[i] <= 0.0 {
                return Err(StatsError::ComputationError(
                    "Specific variances must be positive".to_string(),
                ));
            }
            psi_inv[[i, i]] = 1.0 / psi[i];
        }

        // Compute M = I + L^T Psi^{-1} L
        let lt_psi_inv = loadings.t().dot(&psi_inv);
        let m = Array2::eye(self.n_factors) + lt_psi_inv.dot(loadings);

        // Invert M
        let m_inv = scirs2_linalg::inv(&m.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to invert M matrix: {}", e))
        })?;

        // Compute conditional expectations
        let mut e_h = Array2::zeros((n_samples, self.n_factors));
        let e_hht = m_inv.clone(); // Posterior covariance Cov[h | x] = M^{-1}, shared across samples

        for i in 0..n_samples {
            let x = data.row(i);
            let e_h_i = m_inv.dot(&lt_psi_inv.dot(&x.to_owned()));
            e_h.row_mut(i).assign(&e_h_i);
        }

        Ok((e_h, e_hht))
    }

    /// M-step of EM algorithm
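    ///
    /// Maximizing the expected complete-data log-likelihood gives the updates
    ///
    ///     L_new = (sum_i x_i E[h_i]^T) (sum_i E[h_i h_i^T])^{-1},
    ///     psi_j = (1/n) sum_i E[(x_ij - l_j^T h_i)^2],
    ///
    /// where E[h_i h_i^T] = Cov[h | x_i] + E[h_i] E[h_i]^T.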
    fn m_step(
        &self,
        data: &Array2<f64>,
        e_h: &Array2<f64>,
        e_hht: &Array2<f64>,
    ) -> Result<(Array2<f64>, Array1<f64>)> {
        let (n_samples, n_features) = data.dim();

        // Update loadings: L = (sum_i x_i E[h_i]^T) (sum_i E[h_i h_i^T])^{-1},
        // where E[h_i h_i^T] = Cov[h | x_i] + E[h_i] E[h_i]^T; the second
        // moment therefore needs both the shared posterior covariance and the
        // outer products of the posterior means
        let xte_h = data.t().dot(e_h);
        let sum_e_hht = e_hht * n_samples as f64 + e_h.t().dot(e_h);

        let sum_e_hht_inv = scirs2_linalg::inv(&sum_e_hht.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to invert sum E[HH^T]: {}", e))
        })?;

        let new_loadings = xte_h.dot(&sum_e_hht_inv);

        // Update specific variances
        let mut new_psi = Array1::zeros(n_features);

        for j in 0..n_features {
            let x_j = data.column(j);
            let l_j = new_loadings.row(j);

            // Per-sample contribution of the posterior covariance: l_j^T Cov[h|x] l_j
            // (constant across samples, so computed once)
            let quad_form = l_j.dot(&e_hht.dot(&l_j.to_owned()));

            let mut sum_var = 0.0;
            for i in 0..n_samples {
                let x_ij = x_j[i];
                let e_h_i = e_h.row(i);
                let residual = x_ij - l_j.dot(&e_h_i.to_owned());
                sum_var += residual * residual + quad_form;
            }

            new_psi[j] = (sum_var / n_samples as f64).max(1e-6); // Ensure positive
        }

        Ok((new_loadings, new_psi))
    }

    /// Compute log-likelihood
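    ///
    /// Under the model, each centered sample is marginally N(0, Sigma) with
    /// Sigma = L L^T + Psi, so
    ///
    ///     log p(X) = sum_i [ -(p/2) ln(2 pi) - (1/2) ln|Sigma| - (1/2) x_i^T Sigma^{-1} x_i ].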
    fn compute_log_likelihood(
        &self,
        data: &Array2<f64>,
        loadings: &Array2<f64>,
        psi: &Array1<f64>,
    ) -> Result<f64> {
        let (n_samples, n_features) = data.dim();

        // Construct covariance matrix: Sigma = L L^T + Psi
        let ll_t = loadings.dot(&loadings.t());
        let mut sigma = ll_t;
        for i in 0..n_features {
            sigma[[i, i]] += psi[i];
        }

        // Compute determinant and inverse
        let det_sigma = scirs2_linalg::det(&sigma.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
        })?;

        if det_sigma <= 0.0 {
            return Err(StatsError::ComputationError(
                "Covariance matrix must be positive definite".to_string(),
            ));
        }

        let sigma_inv = scirs2_linalg::inv(&sigma.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to invert covariance: {}", e))
        })?;

        // Compute log-likelihood
        let mut log_likelihood = 0.0;
        let log_det_term =
            -0.5 * n_features as f64 * (2.0 * std::f64::consts::PI).ln() - 0.5 * det_sigma.ln();

        for i in 0..n_samples {
            let x = data.row(i);
            let quad_form = x.dot(&sigma_inv.dot(&x.to_owned()));
            log_likelihood += log_det_term - 0.5 * quad_form;
        }

        Ok(log_likelihood)
    }

    /// Varimax rotation
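    ///
    /// Varimax seeks an orthogonal rotation maximizing the variance of the
    /// squared loadings within each column, pushing loadings toward 0 or
    /// large magnitudes so each variable loads mainly on one factor. This
    /// implementation sweeps over factor pairs, applying a planar rotation
    /// with angle phi = (1/4) atan(num / den) (Kaiser's pairwise update)
    /// until no pair needs rotating.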
    fn varimax_rotation(&self, loadings: &Array2<f64>) -> Result<Array2<f64>> {
        let (n_features, n_factors) = loadings.dim();
        let mut rotated = loadings.clone();

        let max_iter = 30;
        let tol = 1e-6;

        for _ in 0..max_iter {
            let mut converged = true;

            // Rotate each pair of factors
            for i in 0..n_factors {
                for j in (i + 1)..n_factors {
                    let col_i = rotated.column(i).to_owned();
                    let col_j = rotated.column(j).to_owned();

                    // Compute rotation angle
                    let u = &col_i * &col_i - &col_j * &col_j;
                    let v = 2.0 * &col_i * &col_j;

                    let a = u.sum();
                    let b = v.sum();
                    let c = (&u * &u - &v * &v).sum();
                    let d = 2.0 * (&u * &v).sum();

                    let num = d - 2.0 * a * b / n_features as f64;
                    let den = c - (a * a - b * b) / n_features as f64;

                    if den.abs() < 1e-10 {
                        continue;
                    }

                    let phi = 0.25 * (num / den).atan();

                    if phi.abs() > tol {
                        converged = false;

                        // Apply the planar rotation to the pair of columns
                        let cos_phi = phi.cos();
                        let sin_phi = phi.sin();

                        let new_col_i = cos_phi * &col_i - sin_phi * &col_j;
                        let new_col_j = sin_phi * &col_i + cos_phi * &col_j;

                        rotated.column_mut(i).assign(&new_col_i);
                        rotated.column_mut(j).assign(&new_col_j);
                    }
                }
            }

            if converged {
                break;
            }
        }

        Ok(rotated)
    }

    /// Promax rotation (oblique)
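    ///
    /// A simplified promax: the varimax-rotated loadings are raised to the
    /// power kappa (preserving sign) to form a target with exaggerated
    /// simple structure, and the loadings are then mapped onto that target
    /// by least squares. Unlike a full promax implementation, the
    /// transformation columns are not re-normalized and the implied factor
    /// correlation matrix is not computed.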
    fn promax_rotation(&self, loadings: &Array2<f64>) -> Result<Array2<f64>> {
        // First apply varimax rotation
        let varimax_rotated = self.varimax_rotation(loadings)?;

        // Then apply the promax transformation
        let kappa = 4.0; // Power parameter
        let (n_features, n_factors) = varimax_rotated.dim();

        // Compute target matrix by raising loadings to the power kappa
        let mut target = Array2::zeros((n_features, n_factors));
        for i in 0..n_features {
            for j in 0..n_factors {
                let val = varimax_rotated[[i, j]];
                target[[i, j]] = val.abs().powf(kappa) * val.signum();
            }
        }

        // Solve for the transformation matrix using least squares:
        // T = (L^T L)^{-1} L^T P, where P is the target
        let ltl = varimax_rotated.t().dot(&varimax_rotated);
        let ltl_inv = scirs2_linalg::inv(&ltl.view(), None)
            .map_err(|e| StatsError::ComputationError(format!("Failed to invert L^T L: {}", e)))?;

        let ltp = varimax_rotated.t().dot(&target);
        let transform = ltl_inv.dot(&ltp);

        // Apply transformation
        let rotated = varimax_rotated.dot(&transform);

        Ok(rotated)
    }

    /// Compute factor scores using Bartlett's weighted least squares method
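    ///
    /// Bartlett's estimator solves a weighted least squares problem per sample:
    ///
    ///     h_hat = (L^T Psi^{-1} L)^{-1} L^T Psi^{-1} x.
    ///
    /// It is conditionally unbiased for h, at the cost of slightly higher
    /// variance than the regression (Thomson) estimator L^T Sigma^{-1} x.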
    fn compute_factor_scores(
        &self,
        data: &Array2<f64>,
        loadings: &Array2<f64>,
        psi: &Array1<f64>,
    ) -> Result<Array2<f64>> {
        let n_features = loadings.nrows();

        // Construct precision matrix
        let mut psi_inv = Array2::zeros((n_features, n_features));
        for i in 0..n_features {
            psi_inv[[i, i]] = 1.0 / psi[i];
        }

        // Compute factor score coefficient matrix: (L^T Psi^{-1} L)^{-1} L^T Psi^{-1}
        let lt_psi_inv = loadings.t().dot(&psi_inv);
        let lt_psi_inv_l = lt_psi_inv.dot(loadings);

        let lt_psi_inv_l_inv = scirs2_linalg::inv(&lt_psi_inv_l.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to compute factor score weights: {}", e))
        })?;

        let score_weights = lt_psi_inv_l_inv.dot(&lt_psi_inv);

        // Compute scores
        let scores = data.dot(&score_weights.t());

        Ok(scores)
    }

    /// Compute the share of common variance attributed to each factor
    fn compute_explained_variance(&self, loadings: &Array2<f64>) -> Array1<f64> {
        let factor_variances = loadings
            .axis_iter(Axis(1))
            .map(|col| col.dot(&col))
            .collect::<Vec<_>>();

        let total_variance: f64 = factor_variances.iter().sum();

        Array1::from_vec(factor_variances).mapv(|v| v / total_variance)
    }

    /// Compute communalities (proportion of variance explained for each variable)
    fn compute_communalities(&self, loadings: &Array2<f64>) -> Array1<f64> {
        let mut communalities = Array1::zeros(loadings.nrows());

        for i in 0..loadings.nrows() {
            communalities[i] = loadings.row(i).dot(&loadings.row(i));
        }

        communalities
    }

    /// Transform new data to factor space
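    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `fa` and `result` come from a previous
    /// call to [`FactorAnalysis::fit`]:
    ///
    /// ```ignore
    /// // Project held-out observations onto the fitted factor space
    /// let new_scores = fa.transform(new_data.view(), &result).unwrap();
    /// assert_eq!(new_scores.ncols(), fa.n_factors);
    /// ```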
    pub fn transform(
        &self,
        data: ArrayView2<f64>,
        result: &FactorAnalysisResult,
    ) -> Result<Array2<f64>> {
        checkarray_finite(&data, "data")?;

        if data.ncols() != result.mean.len() {
            return Err(StatsError::DimensionMismatch(format!(
                "data has {} features, expected {}",
                data.ncols(),
                result.mean.len()
            )));
        }

        // Center the data
        let mut centered = data.to_owned();
        for mut row in centered.rows_mut() {
            row -= &result.mean;
        }

        // Compute factor scores
        self.compute_factor_scores(&centered, &result.loadings, &result.noise_variance)
    }
}

/// Exploratory Factor Analysis (EFA) utilities
pub mod efa {
    use super::*;

    /// Determine the optimal number of factors using parallel analysis
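    ///
    /// Horn's parallel analysis retains a factor only while the corresponding
    /// eigenvalue of the observed correlation matrix exceeds the given
    /// percentile of the eigenvalues obtained from `n_simulations` datasets
    /// of independent standard-normal noise of the same shape. This corrects
    /// the tendency of the simpler "eigenvalue > 1" rule to over-extract
    /// factors.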
    pub fn parallel_analysis(
        data: ArrayView2<f64>,
        n_simulations: usize,
        percentile: f64,
        seed: Option<u64>,
    ) -> Result<usize> {
        checkarray_finite(&data, "data")?;
        check_positive(n_simulations, "n_simulations")?;

        if percentile <= 0.0 || percentile >= 100.0 {
            return Err(StatsError::InvalidArgument(
                "percentile must be between 0 and 100".to_string(),
            ));
        }

        let (n_samples, n_features) = data.dim();

        // Compute eigenvalues of the real data's correlation matrix
        let real_eigenvalues = compute_correlation_eigenvalues(data)?;

        // Initialize RNG
        let mut rng = match seed {
            Some(s) => StdRng::seed_from_u64(s),
            None => {
                use std::time::{SystemTime, UNIX_EPOCH};
                let s = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_secs();
                StdRng::seed_from_u64(s)
            }
        };

        // Generate random data and compute eigenvalues
        let mut simulated_eigenvalues = Vec::with_capacity(n_simulations);

        for _ in 0..n_simulations {
            // Generate random normal data with the same dimensions
            let mut random_data = Array2::zeros((n_samples, n_features));
            use scirs2_core::random::{Distribution, Normal};
            let normal = Normal::new(0.0, 1.0).map_err(|e| {
                StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
            })?;

            for i in 0..n_samples {
                for j in 0..n_features {
                    random_data[[i, j]] = normal.sample(&mut rng);
                }
            }

            let eigenvalues = compute_correlation_eigenvalues(random_data.view())?;
            simulated_eigenvalues.push(eigenvalues);
        }

        // Compute percentile thresholds
        let mut thresholds = Array1::zeros(n_features);
        for i in 0..n_features {
            let mut values: Vec<f64> = simulated_eigenvalues.iter().map(|ev| ev[i]).collect();
            values.sort_by(|a, b| a.partial_cmp(b).unwrap());

            let index = ((percentile / 100.0) * (n_simulations - 1) as f64).round() as usize;
            thresholds[i] = values[index.min(n_simulations - 1)];
        }

        // Count leading factors whose real eigenvalue exceeds the threshold
        let mut n_factors = 0;
        for i in 0..n_features {
            if real_eigenvalues[i] > thresholds[i] {
                n_factors += 1;
            } else {
                break;
            }
        }

        Ok(n_factors.max(1)) // At least 1 factor
    }

    /// Compute eigenvalues of the correlation matrix
    fn compute_correlation_eigenvalues(data: ArrayView2<f64>) -> Result<Array1<f64>> {
        // Center data
        let mean = data.mean_axis(Axis(0)).unwrap();
        let mut centered = data.to_owned();
        for mut row in centered.rows_mut() {
            row -= &mean;
        }

        // Compute covariance matrix
        let cov = centered.t().dot(&centered) / (data.nrows() - 1) as f64;

        // Standardize to correlation
        let mut corr = cov.clone();
        for i in 0..corr.nrows() {
            for j in 0..corr.ncols() {
                let std_i = cov[[i, i]].sqrt();
                let std_j = cov[[j, j]].sqrt();
                if std_i > 1e-10 && std_j > 1e-10 {
                    corr[[i, j]] = cov[[i, j]] / (std_i * std_j);
                }
            }
        }

        // Compute eigenvalues
        use scirs2_core::ndarray::ndarray_linalg::Eigh;
        let eigenvalues = corr
            .eigh(scirs2_core::ndarray::ndarray_linalg::UPLO::Upper)
            .map_err(|e| {
                StatsError::ComputationError(format!("Eigenvalue decomposition failed: {}", e))
            })?
            .0;

        // Sort in descending order
        let mut sorted_eigenvalues = eigenvalues.to_vec();
        sorted_eigenvalues.sort_by(|a, b| b.partial_cmp(a).unwrap());

        Ok(Array1::from_vec(sorted_eigenvalues))
    }

    /// Kaiser-Meyer-Olkin (KMO) measure of sampling adequacy
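    ///
    /// KMO compares observed correlations r_ij with the partial correlations
    /// p_ij implied by the inverse correlation matrix:
    ///
    ///     KMO = sum_{i != j} r_ij^2 / (sum_{i != j} r_ij^2 + sum_{i != j} p_ij^2).
    ///
    /// Values near 1 suggest the data are well suited to factor analysis; by
    /// Kaiser's commonly cited guideline, values below about 0.5 are
    /// considered unacceptable.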
    pub fn kmo_test(data: ArrayView2<f64>) -> Result<f64> {
        checkarray_finite(&data, "data")?;

        // Compute covariance, then standardize to correlation
        let mean = data.mean_axis(Axis(0)).unwrap();
        let mut centered = data.to_owned();
        for mut row in centered.rows_mut() {
            row -= &mean;
        }

        let cov = centered.t().dot(&centered) / (data.nrows() - 1) as f64;
        let n = cov.nrows();

        let mut corr = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                let std_i = cov[[i, i]].sqrt();
                let std_j = cov[[j, j]].sqrt();
                if std_i > 1e-10 && std_j > 1e-10 {
                    corr[[i, j]] = cov[[i, j]] / (std_i * std_j);
                } else if i == j {
                    corr[[i, j]] = 1.0;
                }
            }
        }

        // Invert the correlation matrix to obtain partial correlations
        let corr_inv = scirs2_linalg::inv(&corr.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to invert correlation matrix: {}", e))
        })?;

        // Compute KMO statistic
        let mut sum_squared_corr = 0.0;
        let mut sum_squared_partial = 0.0;

        for i in 0..n {
            for j in 0..n {
                if i != j {
                    sum_squared_corr += corr[[i, j]] * corr[[i, j]];

                    // Partial correlation
                    let partial = -corr_inv[[i, j]] / (corr_inv[[i, i]] * corr_inv[[j, j]]).sqrt();
                    sum_squared_partial += partial * partial;
                }
            }
        }

        let kmo = sum_squared_corr / (sum_squared_corr + sum_squared_partial);
        Ok(kmo)
    }

    /// Bartlett's test of sphericity
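    ///
    /// Tests the null hypothesis that the correlation matrix is the identity
    /// (i.e., the variables are uncorrelated and factor analysis is
    /// unwarranted). The statistic
    ///
    ///     chi^2 = -(n - 1 - (2p + 5) / 6) ln|R|
    ///
    /// is approximately chi-square with p(p - 1)/2 degrees of freedom under
    /// the null. Returns the statistic and an approximate p-value.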
    pub fn bartlett_test(data: ArrayView2<f64>) -> Result<(f64, f64)> {
        checkarray_finite(&data, "data")?;
        let (n, p) = data.dim();

        if n <= p {
            return Err(StatsError::InvalidArgument(
                "Number of samples must exceed number of variables".to_string(),
            ));
        }

        // Compute covariance, then standardize to correlation
        let mean = data.mean_axis(Axis(0)).unwrap();
        let mut centered = data.to_owned();
        for mut row in centered.rows_mut() {
            row -= &mean;
        }

        let cov = centered.t().dot(&centered) / (n - 1) as f64;

        let mut corr = Array2::zeros((p, p));
        for i in 0..p {
            for j in 0..p {
                let std_i = cov[[i, i]].sqrt();
                let std_j = cov[[j, j]].sqrt();
                if std_i > 1e-10 && std_j > 1e-10 {
                    corr[[i, j]] = cov[[i, j]] / (std_i * std_j);
                } else if i == j {
                    corr[[i, j]] = 1.0;
                }
            }
        }

        // Compute test statistic
        let det_corr = scirs2_linalg::det(&corr.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
        })?;

        if det_corr <= 0.0 {
            return Err(StatsError::ComputationError(
                "Correlation matrix must be positive definite".to_string(),
            ));
        }

        let chi2 = -(n as f64 - 1.0 - (2.0 * p as f64 + 5.0) / 6.0) * det_corr.ln();
        let df = p * (p - 1) / 2;

        // Approximate p-value using the chi-square distribution
        let p_value = chi2_survival(chi2, df as f64);

        Ok((chi2, p_value))
    }
}

/// Approximate survival function for the chi-square distribution
fn chi2_survival(x: f64, df: f64) -> f64 {
    if x <= 0.0 {
        return 1.0;
    }

    // Wilson-Hilferty approximation: if X ~ chi^2(df), then (X/df)^{1/3} is
    // approximately normal with mean 1 - 2/(9 df) and variance 2/(9 df).
    // Reasonable across small and large df; for exact p-values use a proper
    // regularized incomplete gamma implementation.
    let c = 2.0 / (9.0 * df);
    let z = ((x / df).powf(1.0 / 3.0) - (1.0 - c)) / c.sqrt();
    0.5 * (1.0 - erf(z / std::f64::consts::SQRT_2))
}

/// Error function approximation (Abramowitz & Stegun 7.1.26, max abs error ~1.5e-7)
fn erf(x: f64) -> f64 {
    let a1 = 0.254829592;
    let a2 = -0.284496736;
    let a3 = 1.421413741;
    let a4 = -1.453152027;
    let a5 = 1.061405429;
    let p = 0.3275911;

    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();

    sign * y
}