Skip to main content

scirs2_stats/multivariate/
factor_analysis.rs

1//! Factor Analysis
2//!
3//! Factor analysis is a dimensionality reduction technique that identifies latent factors
4//! that explain the correlations among observed variables.
5
6use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
8use scirs2_core::random::{rngs::StdRng, SeedableRng};
9use scirs2_core::validation::*;
10
11/// Factor Analysis model
12#[derive(Debug, Clone)]
13pub struct FactorAnalysis {
14    /// Number of factors to extract
15    pub n_factors: usize,
16    /// Maximum number of iterations for EM algorithm
17    pub max_iter: usize,
18    /// Convergence tolerance
19    pub tol: f64,
20    /// Whether to perform varimax rotation
21    pub rotation: RotationType,
22    /// Random state for reproducibility
23    pub random_state: Option<u64>,
24}
25
/// Rotation strategy applied to the extracted factor loadings.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum RotationType {
    /// Keep the loadings exactly as estimated.
    None,
    /// Orthogonal varimax rotation.
    Varimax,
    /// Oblique promax rotation (varimax followed by a power-target transform).
    Promax,
}
36
37/// Result of factor analysis
38#[derive(Debug, Clone)]
39pub struct FactorAnalysisResult {
40    /// Factor loadings matrix (p x k)
41    pub loadings: Array2<f64>,
42    /// Specific variances (unique factors)
43    pub noise_variance: Array1<f64>,
44    /// Factors scores for training data (n x k)
45    pub scores: Array2<f64>,
46    /// Mean of training data
47    pub mean: Array1<f64>,
48    /// Log-likelihood of the model
49    pub log_likelihood: f64,
50    /// Number of iterations until convergence
51    pub n_iter: usize,
52    /// Proportion of variance explained by each factor
53    pub explained_variance_ratio: Array1<f64>,
54    /// Communalities (proportion of variance in each variable explained by factors)
55    pub communalities: Array1<f64>,
56}
57
58impl Default for FactorAnalysis {
59    fn default() -> Self {
60        Self {
61            n_factors: 2,
62            max_iter: 1000,
63            tol: 1e-6,
64            rotation: RotationType::Varimax,
65            random_state: None,
66        }
67    }
68}
69
70impl FactorAnalysis {
71    /// Create a new factor analysis instance
72    pub fn new(n_factors: usize) -> Result<Self> {
73        check_positive(n_factors, "n_factors")?;
74        Ok(Self {
75            n_factors,
76            ..Default::default()
77        })
78    }
79
80    /// Set maximum iterations
81    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
82        self.max_iter = max_iter;
83        self
84    }
85
86    /// Set convergence tolerance
87    pub fn with_tolerance(mut self, tol: f64) -> Self {
88        self.tol = tol;
89        self
90    }
91
92    /// Set rotation type
93    pub fn with_rotation(mut self, rotation: RotationType) -> Self {
94        self.rotation = rotation;
95        self
96    }
97
98    /// Set random state
99    pub fn with_random_state(mut self, seed: u64) -> Self {
100        self.random_state = Some(seed);
101        self
102    }
103
104    /// Fit the factor analysis model
105    pub fn fit(&self, data: ArrayView2<f64>) -> Result<FactorAnalysisResult> {
106        checkarray_finite(&data, "data")?;
107        let (n_samples, n_features) = data.dim();
108
109        if n_samples < 2 {
110            return Err(StatsError::InvalidArgument(
111                "n_samples must be at least 2".to_string(),
112            ));
113        }
114
115        if self.n_factors >= n_features {
116            return Err(StatsError::InvalidArgument(format!(
117                "n_factors ({}) must be less than n_features ({})",
118                self.n_factors, n_features
119            )));
120        }
121
122        // Center the data
123        let mean = data.mean_axis(Axis(0)).expect("Operation failed");
124        let mut centereddata = data.to_owned();
125        for mut row in centereddata.rows_mut() {
126            row -= &mean;
127        }
128
129        // Initialize parameters
130        let (mut loadings, mut psi) = self.initialize_parameters(&centereddata)?;
131
132        let mut prev_log_likelihood = f64::NEG_INFINITY;
133        let mut n_iter = 0;
134
135        // EM algorithm
136        for iteration in 0..self.max_iter {
137            // E-step: compute expected sufficient statistics
138            let (e_h, e_hht) = self.e_step(&centereddata, &loadings, &psi)?;
139
140            // M-step: update parameters
141            let (new_loadings, new_psi) = self.m_step(&centereddata, &e_h, &e_hht)?;
142
143            // Compute log-likelihood
144            let log_likelihood =
145                self.compute_log_likelihood(&centereddata, &new_loadings, &new_psi)?;
146
147            // Check convergence
148            if (log_likelihood - prev_log_likelihood).abs() < self.tol {
149                loadings = new_loadings;
150                psi = new_psi;
151                n_iter = iteration + 1;
152                break;
153            }
154
155            loadings = new_loadings;
156            psi = new_psi;
157            prev_log_likelihood = log_likelihood;
158            n_iter = iteration + 1;
159        }
160
161        if n_iter == self.max_iter {
162            return Err(StatsError::ConvergenceError(format!(
163                "EM algorithm failed to converge after {} iterations",
164                self.max_iter
165            )));
166        }
167
168        // Apply rotation if specified
169        let rotated_loadings = match self.rotation {
170            RotationType::None => loadings,
171            RotationType::Varimax => self.varimax_rotation(&loadings)?,
172            RotationType::Promax => self.promax_rotation(&loadings)?,
173        };
174
175        // Compute factor scores
176        let scores = self.compute_factor_scores(&centereddata, &rotated_loadings, &psi)?;
177
178        // Compute explained variance and communalities
179        let explained_variance_ratio = self.compute_explained_variance(&rotated_loadings);
180        let communalities = self.compute_communalities(&rotated_loadings);
181
182        // Final log-likelihood
183        let final_log_likelihood =
184            self.compute_log_likelihood(&centereddata, &rotated_loadings, &psi)?;
185
186        Ok(FactorAnalysisResult {
187            loadings: rotated_loadings,
188            noise_variance: psi,
189            scores,
190            mean,
191            log_likelihood: final_log_likelihood,
192            n_iter,
193            explained_variance_ratio,
194            communalities,
195        })
196    }
197
198    /// Initialize factor loadings and specific variances
199    fn initialize_parameters(&self, data: &Array2<f64>) -> Result<(Array2<f64>, Array1<f64>)> {
200        let (n_samples, n_features) = data.dim();
201
202        // Initialize using SVD of data (using scirs2_linalg)
203        let (_u, s, vt) = scirs2_linalg::svd(&data.view(), false, None).map_err(|e| {
204            StatsError::ComputationError(format!("SVD initialization failed: {}", e))
205        })?;
206
207        let v = vt.t().to_owned();
208
209        // Initial loadings from first k components
210        let mut loadings = Array2::zeros((n_features, self.n_factors));
211        for i in 0..self.n_factors {
212            let scale = (s[i] / (n_samples as f64).sqrt()).max(1e-6);
213            for j in 0..n_features {
214                loadings[[j, i]] = v[[j, i]] * scale;
215            }
216        }
217
218        // Initialize specific variances
219        let mut psi = Array1::ones(n_features);
220        for i in 0..n_features {
221            let communality = loadings.row(i).dot(&loadings.row(i));
222            psi[i] = (1.0 - communality).max(0.01); // Ensure positive
223        }
224
225        Ok((loadings, psi))
226    }
227
228    /// E-step of EM algorithm
229    fn e_step(
230        &self,
231        data: &Array2<f64>,
232        loadings: &Array2<f64>,
233        psi: &Array1<f64>,
234    ) -> Result<(Array2<f64>, Array2<f64>)> {
235        let (n_samples, n_features) = data.dim();
236
237        // Construct precision matrix: Psi^{-1}
238        let mut psi_inv = Array2::zeros((n_features, n_features));
239        for i in 0..n_features {
240            if psi[i] <= 0.0 {
241                return Err(StatsError::ComputationError(
242                    "Specific variances must be positive".to_string(),
243                ));
244            }
245            psi_inv[[i, i]] = 1.0 / psi[i];
246        }
247
248        // Compute M = I + L^T Psi^{-1} L
249        let lt_psi_inv = loadings.t().dot(&psi_inv);
250        let m = Array2::eye(self.n_factors) + lt_psi_inv.dot(loadings);
251
252        // Invert M
253        let m_inv = scirs2_linalg::inv(&m.view(), None).map_err(|e| {
254            StatsError::ComputationError(format!("Failed to invert M matrix: {}", e))
255        })?;
256
257        // Compute conditional expectations
258        let mut e_h = Array2::zeros((n_samples, self.n_factors));
259        let e_hht = m_inv.clone(); // This is E[h h^T | x]
260
261        for i in 0..n_samples {
262            let x = data.row(i);
263            let e_h_i = m_inv.dot(&lt_psi_inv.dot(&x.to_owned()));
264            e_h.row_mut(i).assign(&e_h_i);
265        }
266
267        Ok((e_h, e_hht))
268    }
269
270    /// M-step of EM algorithm
271    fn m_step(
272        &self,
273        data: &Array2<f64>,
274        e_h: &Array2<f64>,
275        e_hht: &Array2<f64>,
276    ) -> Result<(Array2<f64>, Array1<f64>)> {
277        let (n_samples, n_features) = data.dim();
278
279        // Update loadings: L = (X^T E[H]) (E[H^T H])^{-1}
280        let xte_h = data.t().dot(e_h);
281        let sum_e_hht = e_hht * n_samples as f64; // Sum over samples
282
283        let sum_e_hht_inv = scirs2_linalg::inv(&sum_e_hht.view(), None).map_err(|e| {
284            StatsError::ComputationError(format!("Failed to invert sum E[HH^T]: {}", e))
285        })?;
286
287        let new_loadings = xte_h.dot(&sum_e_hht_inv);
288
289        // Update specific variances
290        let mut new_psi = Array1::zeros(n_features);
291
292        for j in 0..n_features {
293            let x_j = data.column(j);
294            let l_j = new_loadings.row(j);
295
296            let mut sum_var = 0.0;
297            for i in 0..n_samples {
298                let x_ij = x_j[i];
299                let e_h_i = e_h.row(i);
300                let residual = x_ij - l_j.dot(&e_h_i.to_owned());
301                sum_var += residual * residual;
302
303                // Add E[h h^T] term
304                let quad_form = l_j.dot(&e_hht.dot(&l_j.to_owned()));
305                sum_var += quad_form;
306            }
307
308            new_psi[j] = (sum_var / n_samples as f64).max(1e-6); // Ensure positive
309        }
310
311        Ok((new_loadings, new_psi))
312    }
313
314    /// Compute log-likelihood
315    fn compute_log_likelihood(
316        &self,
317        data: &Array2<f64>,
318        loadings: &Array2<f64>,
319        psi: &Array1<f64>,
320    ) -> Result<f64> {
321        let (n_samples, n_features) = data.dim();
322
323        // Construct covariance matrix: Sigma = L L^T + Psi
324        let ll_t = loadings.dot(&loadings.t());
325        let mut sigma = ll_t;
326        for i in 0..n_features {
327            sigma[[i, i]] += psi[i];
328        }
329
330        // Compute determinant and inverse
331        let det_sigma = scirs2_linalg::det(&sigma.view(), None).map_err(|e| {
332            StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
333        })?;
334
335        if det_sigma <= 0.0 {
336            return Err(StatsError::ComputationError(
337                "Covariance matrix must be positive definite".to_string(),
338            ));
339        }
340
341        let sigma_inv = scirs2_linalg::inv(&sigma.view(), None).map_err(|e| {
342            StatsError::ComputationError(format!("Failed to invert covariance: {}", e))
343        })?;
344
345        // Compute log-likelihood
346        let mut log_likelihood = 0.0;
347        let log_det_term =
348            -0.5 * n_features as f64 * (2.0 * std::f64::consts::PI).ln() - 0.5 * det_sigma.ln();
349
350        for i in 0..n_samples {
351            let x = data.row(i);
352            let quad_form = x.dot(&sigma_inv.dot(&x.to_owned()));
353            log_likelihood += log_det_term - 0.5 * quad_form;
354        }
355
356        Ok(log_likelihood)
357    }
358
359    /// Varimax rotation
360    fn varimax_rotation(&self, loadings: &Array2<f64>) -> Result<Array2<f64>> {
361        let (n_features, n_factors) = loadings.dim();
362        let mut rotated = loadings.clone();
363
364        let max_iter = 30;
365        let tol = 1e-6;
366
367        for _ in 0..max_iter {
368            let rotation_matrix = Array2::<f64>::eye(n_factors);
369            let mut converged = true;
370
371            // Rotate each pair of factors
372            for i in 0..n_factors {
373                for j in (i + 1)..n_factors {
374                    let col_i = rotated.column(i).to_owned();
375                    let col_j = rotated.column(j).to_owned();
376
377                    // Compute rotation angle
378                    let u = &col_i * &col_i - &col_j * &col_j;
379                    let v = 2.0 * &col_i * &col_j;
380
381                    let a = u.sum();
382                    let b = v.sum();
383                    let c = (&u * &u - &v * &v).sum();
384                    let d = 2.0 * (&u * &v).sum();
385
386                    let num = d - 2.0 * a * b / n_features as f64;
387                    let den = c - (a * a - b * b) / n_features as f64;
388
389                    if den.abs() < 1e-10 {
390                        continue;
391                    }
392
393                    let phi = 0.25 * (num / den).atan();
394
395                    if phi.abs() > tol {
396                        converged = false;
397
398                        // Apply rotation
399                        let cos_phi = phi.cos();
400                        let sin_phi = phi.sin();
401
402                        let new_col_i = cos_phi * &col_i - sin_phi * &col_j;
403                        let new_col_j = sin_phi * &col_i + cos_phi * &col_j;
404
405                        rotated.column_mut(i).assign(&new_col_i);
406                        rotated.column_mut(j).assign(&new_col_j);
407                    }
408                }
409            }
410
411            if converged {
412                break;
413            }
414        }
415
416        Ok(rotated)
417    }
418
419    /// Promax rotation (oblique)
420    fn promax_rotation(&self, loadings: &Array2<f64>) -> Result<Array2<f64>> {
421        // First apply varimax rotation
422        let varimax_rotated = self.varimax_rotation(loadings)?;
423
424        // Then apply promax transformation
425        let kappa = 4.0; // Power parameter
426        let (n_features, n_factors) = varimax_rotated.dim();
427
428        // Compute target matrix by raising loadings to power kappa
429        let mut target = Array2::zeros((n_features, n_factors));
430        for i in 0..n_features {
431            for j in 0..n_factors {
432                let val = varimax_rotated[[i, j]];
433                target[[i, j]] = val.abs().powf(kappa) * val.signum();
434            }
435        }
436
437        // Solve for transformation matrix using least squares
438        // T = (L^T L)^{-1} L^T P where P is target
439        let ltl = varimax_rotated.t().dot(&varimax_rotated);
440        let ltl_inv = scirs2_linalg::inv(&ltl.view(), None)
441            .map_err(|e| StatsError::ComputationError(format!("Failed to invert L^T L: {}", e)))?;
442
443        let ltp = varimax_rotated.t().dot(&target);
444        let transform = ltl_inv.dot(&ltp);
445
446        // Apply transformation
447        let rotated = varimax_rotated.dot(&transform);
448
449        Ok(rotated)
450    }
451
452    /// Compute factor scores using regression method
453    fn compute_factor_scores(
454        &self,
455        data: &Array2<f64>,
456        loadings: &Array2<f64>,
457        psi: &Array1<f64>,
458    ) -> Result<Array2<f64>> {
459        let n_features = loadings.nrows();
460
461        // Construct precision matrix
462        let mut psi_inv = Array2::zeros((n_features, n_features));
463        for i in 0..n_features {
464            psi_inv[[i, i]] = 1.0 / psi[i];
465        }
466
467        // Compute factor score coefficient matrix: (L^T Psi^{-1} L)^{-1} L^T Psi^{-1}
468        let lt_psi_inv = loadings.t().dot(&psi_inv);
469        let lt_psi_inv_l = lt_psi_inv.dot(loadings);
470
471        let lt_psi_inv_l_inv = scirs2_linalg::inv(&lt_psi_inv_l.view(), None).map_err(|e| {
472            StatsError::ComputationError(format!("Failed to compute factor score weights: {}", e))
473        })?;
474
475        let score_weights = lt_psi_inv_l_inv.dot(&lt_psi_inv);
476
477        // Compute scores
478        let scores = data.dot(&score_weights.t());
479
480        Ok(scores)
481    }
482
483    /// Compute explained variance ratio for each factor
484    fn compute_explained_variance(&self, loadings: &Array2<f64>) -> Array1<f64> {
485        let factor_variances = loadings
486            .axis_iter(Axis(1))
487            .map(|col| col.dot(&col))
488            .collect::<Vec<_>>();
489
490        let total_variance: f64 = factor_variances.iter().sum();
491
492        Array1::from_vec(factor_variances).mapv(|v| v / total_variance)
493    }
494
495    /// Compute communalities (proportion of variance explained for each variable)
496    fn compute_communalities(&self, loadings: &Array2<f64>) -> Array1<f64> {
497        let mut communalities = Array1::zeros(loadings.nrows());
498
499        for i in 0..loadings.nrows() {
500            communalities[i] = loadings.row(i).dot(&loadings.row(i));
501        }
502
503        communalities
504    }
505
506    /// Transform new data to factor space
507    pub fn transform(
508        &self,
509        data: ArrayView2<f64>,
510        result: &FactorAnalysisResult,
511    ) -> Result<Array2<f64>> {
512        checkarray_finite(&data, "data")?;
513
514        if data.ncols() != result.mean.len() {
515            return Err(StatsError::DimensionMismatch(format!(
516                "data has {} features, expected {}",
517                data.ncols(),
518                result.mean.len()
519            )));
520        }
521
522        // Center the data
523        let mut centered = data.to_owned();
524        for mut row in centered.rows_mut() {
525            row -= &result.mean;
526        }
527
528        // Compute factor scores
529        self.compute_factor_scores(&centered, &result.loadings, &result.noise_variance)
530    }
531}
532
533/// Exploratory Factor Analysis (EFA) utilities
534pub mod efa {
535    use super::*;
536
537    /// Determine optimal number of factors using parallel analysis
538    pub fn parallel_analysis(
539        data: ArrayView2<f64>,
540        n_simulations: usize,
541        percentile: f64,
542        seed: Option<u64>,
543    ) -> Result<usize> {
544        checkarray_finite(&data, "data")?;
545        check_positive(n_simulations, "n_simulations")?;
546
547        if percentile <= 0.0 || percentile >= 100.0 {
548            return Err(StatsError::InvalidArgument(
549                "percentile must be between 0 and 100".to_string(),
550            ));
551        }
552
553        let (n_samples, n_features) = data.dim();
554
555        // Compute eigenvalues of real data correlation matrix
556        let real_eigenvalues = compute_correlation_eigenvalues(data)?;
557
558        // Initialize RNG
559        let mut rng = match seed {
560            Some(s) => StdRng::seed_from_u64(s),
561            None => {
562                use std::time::{SystemTime, UNIX_EPOCH};
563                let s = SystemTime::now()
564                    .duration_since(UNIX_EPOCH)
565                    .unwrap_or_default()
566                    .as_secs();
567                StdRng::seed_from_u64(s)
568            }
569        };
570
571        // Generate random data and compute eigenvalues
572        let mut simulated_eigenvalues = Vec::with_capacity(n_simulations);
573
574        for _ in 0..n_simulations {
575            // Generate random normal data with same dimensions
576            let mut randomdata = Array2::zeros((n_samples, n_features));
577            use scirs2_core::random::{Distribution, Normal};
578            let normal = Normal::new(0.0, 1.0).map_err(|e| {
579                StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
580            })?;
581
582            for i in 0..n_samples {
583                for j in 0..n_features {
584                    randomdata[[i, j]] = normal.sample(&mut rng);
585                }
586            }
587
588            let eigenvalues = compute_correlation_eigenvalues(randomdata.view())?;
589            simulated_eigenvalues.push(eigenvalues);
590        }
591
592        // Compute percentile thresholds
593        let mut thresholds = Array1::zeros(n_features);
594        for i in 0..n_features {
595            let mut values: Vec<f64> = simulated_eigenvalues.iter().map(|ev| ev[i]).collect();
596            values.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
597
598            let index = ((percentile / 100.0) * (n_simulations - 1) as f64).round() as usize;
599            thresholds[i] = values[index.min(n_simulations - 1)];
600        }
601
602        // Count factors where real eigenvalue > threshold
603        let mut n_factors = 0;
604        for i in 0..n_features {
605            if real_eigenvalues[i] > thresholds[i] {
606                n_factors += 1;
607            } else {
608                break;
609            }
610        }
611
612        Ok(n_factors.max(1)) // At least 1 factor
613    }
614
615    /// Compute eigenvalues of correlation matrix
616    fn compute_correlation_eigenvalues(data: ArrayView2<f64>) -> Result<Array1<f64>> {
617        // Center data
618        let mean = data.mean_axis(Axis(0)).expect("Operation failed");
619        let mut centered = data.to_owned();
620        for mut row in centered.rows_mut() {
621            row -= &mean;
622        }
623
624        // Compute correlation matrix
625        let cov = centered.t().dot(&centered) / (data.nrows() - 1) as f64;
626
627        // Standardize to correlation
628        let mut corr = cov.clone();
629        for i in 0..corr.nrows() {
630            for j in 0..corr.ncols() {
631                let std_i = cov[[i, i]].sqrt();
632                let std_j = cov[[j, j]].sqrt();
633                if std_i > 1e-10 && std_j > 1e-10 {
634                    corr[[i, j]] = cov[[i, j]] / (std_i * std_j);
635                }
636            }
637        }
638
639        // Compute eigenvalues using scirs2_linalg
640        let (eigenvalues, _eigenvectors) =
641            scirs2_linalg::eigh_f64_lapack(&corr.view()).map_err(|e| {
642                StatsError::ComputationError(format!("Eigenvalue decomposition failed: {}", e))
643            })?;
644
645        // Sort in descending order
646        let mut sorted_eigenvalues: Vec<f64> = eigenvalues.to_vec();
647        sorted_eigenvalues.sort_by(|a: &f64, b: &f64| b.partial_cmp(a).expect("Operation failed"));
648
649        Ok(Array1::from_vec(sorted_eigenvalues))
650    }
651
652    /// Kaiser-Meyer-Olkin (KMO) measure of sampling adequacy
653    pub fn kmo_test(data: ArrayView2<f64>) -> Result<f64> {
654        checkarray_finite(&data, "data")?;
655
656        // Compute correlation matrix
657        let mean = data.mean_axis(Axis(0)).expect("Operation failed");
658        let mut centered = data.to_owned();
659        for mut row in centered.rows_mut() {
660            row -= &mean;
661        }
662
663        let cov = centered.t().dot(&centered) / (data.nrows() - 1) as f64;
664        let n = cov.nrows();
665
666        // Standardize to correlation
667        let mut corr = Array2::zeros((n, n));
668        for i in 0..n {
669            for j in 0..n {
670                let std_i = cov[[i, i]].sqrt();
671                let std_j = cov[[j, j]].sqrt();
672                if std_i > 1e-10 && std_j > 1e-10 {
673                    corr[[i, j]] = cov[[i, j]] / (std_i * std_j);
674                } else if i == j {
675                    corr[[i, j]] = 1.0;
676                }
677            }
678        }
679
680        // Compute anti-image correlation matrix
681        let corr_inv = scirs2_linalg::inv(&corr.view(), None).map_err(|e| {
682            StatsError::ComputationError(format!("Failed to invert correlation matrix: {}", e))
683        })?;
684
685        // Compute KMO statistic
686        let mut sum_squared_corr = 0.0;
687        let mut sum_squared_partial = 0.0;
688
689        for i in 0..n {
690            for j in 0..n {
691                if i != j {
692                    sum_squared_corr += corr[[i, j]] * corr[[i, j]];
693
694                    // Partial correlation
695                    let partial = -corr_inv[[i, j]] / (corr_inv[[i, i]] * corr_inv[[j, j]]).sqrt();
696                    sum_squared_partial += partial * partial;
697                }
698            }
699        }
700
701        let kmo = sum_squared_corr / (sum_squared_corr + sum_squared_partial);
702        Ok(kmo)
703    }
704
705    /// Bartlett's test of sphericity
706    pub fn bartlett_test(data: ArrayView2<f64>) -> Result<(f64, f64)> {
707        checkarray_finite(&data, "data")?;
708        let (n, p) = data.dim();
709
710        if n <= p {
711            return Err(StatsError::InvalidArgument(
712                "Number of samples must exceed number of variables".to_string(),
713            ));
714        }
715
716        // Compute correlation matrix
717        let mean = data.mean_axis(Axis(0)).expect("Operation failed");
718        let mut centered = data.to_owned();
719        for mut row in centered.rows_mut() {
720            row -= &mean;
721        }
722
723        let cov = centered.t().dot(&centered) / (n - 1) as f64;
724
725        // Standardize to correlation
726        let mut corr = Array2::zeros((p, p));
727        for i in 0..p {
728            for j in 0..p {
729                let std_i = cov[[i, i]].sqrt();
730                let std_j = cov[[j, j]].sqrt();
731                if std_i > 1e-10 && std_j > 1e-10 {
732                    corr[[i, j]] = cov[[i, j]] / (std_i * std_j);
733                } else if i == j {
734                    corr[[i, j]] = 1.0;
735                }
736            }
737        }
738
739        // Compute test statistic
740        let det_corr = scirs2_linalg::det(&corr.view(), None).map_err(|e| {
741            StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
742        })?;
743
744        if det_corr <= 0.0 {
745            return Err(StatsError::ComputationError(
746                "Correlation matrix must be positive definite".to_string(),
747            ));
748        }
749
750        let chi2 = -(n as f64 - 1.0 - (2.0 * p as f64 + 5.0) / 6.0) * det_corr.ln();
751        let df = p * (p - 1) / 2;
752
753        // Approximate p-value using chi-square distribution
754        let p_value = chi2_survival(chi2, df as f64);
755
756        Ok((chi2, p_value))
757    }
758}
759
/// Survival function `P(X > x)` of a chi-square distribution with `df` degrees
/// of freedom.
///
/// Evaluated as the regularised upper incomplete gamma function
/// `Q(df / 2, x / 2)`. When `x / 2 < df / 2 + 1` it uses a power series for `P`;
/// otherwise it uses a continued fraction for `Q` (the classic split from
/// Numerical Recipes, section 6.2). Accuracy is limited mainly by the Lanczos
/// `ln_gamma` (about 1e-10 relative).
///
/// Edge cases:
/// * `1.0` for `x <= 0`;
/// * `NaN` if either argument is `NaN` (and `x > 0`);
/// * `0.0` for `df <= 0` (a point mass at zero) or `x = +inf`;
/// * `1.0` for `df = +inf`.
fn chi2_survival(x: f64, df: f64) -> f64 {
    if x <= 0.0 {
        return 1.0;
    }
    if x.is_nan() || df.is_nan() {
        return f64::NAN;
    }
    if df <= 0.0 || x.is_infinite() {
        return 0.0;
    }
    if df.is_infinite() {
        return 1.0;
    }

    let a = 0.5 * df;
    let z = 0.5 * x;
    let q = if z < a + 1.0 {
        1.0 - regularized_gamma_p_series(a, z)
    } else {
        regularized_gamma_q_continued_fraction(a, z)
    };
    // Guard against tiny rounding excursions outside [0, 1].
    q.clamp(0.0, 1.0)
}

/// Natural logarithm of the gamma function for `x > 0` (Lanczos approximation
/// from Numerical Recipes `gammln`, relative error below about 2e-10).
fn ln_gamma(x: f64) -> f64 {
    const COEFFS: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];
    let tmp = x + 5.5;
    let tmp = tmp - (x + 0.5) * tmp.ln();
    let series = COEFFS
        .iter()
        .enumerate()
        .fold(1.000000000190015, |acc, (k, &c)| acc + c / (x + 1.0 + k as f64));
    -tmp + (2.5066282746310005 * series / x).ln()
}

/// Regularised lower incomplete gamma `P(a, x)` by its power series; used for
/// `x < a + 1`, where it converges quickly.
fn regularized_gamma_p_series(a: f64, x: f64) -> f64 {
    const MAX_ITER: usize = 100_000;
    let mut denom = a;
    let mut term = 1.0 / a;
    let mut sum = term;
    for _ in 0..MAX_ITER {
        denom += 1.0;
        term *= x / denom;
        sum += term;
        if term.abs() < sum.abs() * f64::EPSILON {
            break;
        }
    }
    sum * (a * x.ln() - x - ln_gamma(a)).exp()
}

/// Regularised upper incomplete gamma `Q(a, x)` by its continued fraction
/// (modified Lentz method); used for `x >= a + 1`.
fn regularized_gamma_q_continued_fraction(a: f64, x: f64) -> f64 {
    const MAX_ITER: usize = 100_000;
    const FPMIN: f64 = 1e-300;
    let mut b = x + 1.0 - a;
    let mut c = 1.0 / FPMIN;
    let mut d = 1.0 / b;
    let mut h = d;
    for i in 1..=MAX_ITER {
        let step = i as f64;
        let an = -step * (step - a);
        b += 2.0;
        d = an * d + b;
        if d.abs() < FPMIN {
            d = FPMIN;
        }
        c = b + an / c;
        if c.abs() < FPMIN {
            c = FPMIN;
        }
        d = 1.0 / d;
        let delta = d * c;
        h *= delta;
        if (delta - 1.0).abs() < f64::EPSILON {
            break;
        }
    }
    (a * x.ln() - x - ln_gamma(a)).exp() * h
}
781
/// Error function `erf(x)` via Abramowitz & Stegun formula 7.1.26
/// (maximum absolute error about 1.5e-7).
///
/// Odd symmetry is applied explicitly: the approximation is evaluated at `|x|`
/// and the sign is restored afterwards.
#[allow(dead_code)] // silences the lint when no caller in this module uses it
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // Coefficients a5, a4, a3, a2, a1 in Horner evaluation order.
    const HORNER: [f64; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];

    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
    let magnitude = x.abs();

    let t = 1.0 / (1.0 + P * magnitude);
    let poly = HORNER[1..]
        .iter()
        .fold(HORNER[0], |acc, &coeff| acc * t + coeff);

    sign * (1.0 - poly * t * (-magnitude * magnitude).exp())
}