use crate::error::{StatsError, StatsResult as Result};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use scirs2_core::random::{rngs::StdRng, SeedableRng};
use scirs2_core::validation::*;

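/// Maximum-likelihood factor analysis fitted with an EM algorithm.
///
/// A minimal usage sketch (marked `ignore` because it is illustrative rather
/// than a compiled doctest; `data` stands for any `Array2<f64>` with samples
/// in rows and features in columns):
///
/// ```ignore
/// let fa = FactorAnalysis::new(2)?
///     .with_rotation(RotationType::Varimax)
///     .with_max_iter(500);
/// let result = fa.fit(data.view())?;
/// println!("loadings:\n{:?}", result.loadings);
/// ```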
#[derive(Debug, Clone)]
pub struct FactorAnalysis {
    /// Number of latent factors to extract.
    pub n_factors: usize,
    /// Maximum number of EM iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the change in log-likelihood.
    pub tol: f64,
    /// Rotation applied to the loadings after fitting.
    pub rotation: RotationType,
    /// Optional random seed (stored for reproducibility; the SVD-based
    /// initialization used by `fit` is deterministic).
    pub random_state: Option<u64>,
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RotationType {
    /// Keep the unrotated maximum-likelihood loadings.
    None,
    /// Orthogonal varimax rotation.
    Varimax,
    /// Oblique promax rotation (varimax followed by a target transform).
    Promax,
}

#[derive(Debug, Clone)]
pub struct FactorAnalysisResult {
    /// Factor loadings, shape `(n_features, n_factors)`.
    pub loadings: Array2<f64>,
    /// Specific (noise) variance of each feature, the diagonal of `Psi`.
    pub noise_variance: Array1<f64>,
    /// Factor scores for the training data, shape `(n_samples, n_factors)`.
    pub scores: Array2<f64>,
    /// Per-feature mean used for centering.
    pub mean: Array1<f64>,
    /// Log-likelihood of the fitted model.
    pub log_likelihood: f64,
    /// Number of EM iterations actually performed.
    pub n_iter: usize,
    /// Share of the common variance explained by each factor.
    pub explained_variance_ratio: Array1<f64>,
    /// Communality (variance explained by the factors) of each feature.
    pub communalities: Array1<f64>,
}

impl Default for FactorAnalysis {
    fn default() -> Self {
        Self {
            n_factors: 2,
            max_iter: 1000,
            tol: 1e-6,
            rotation: RotationType::Varimax,
            random_state: None,
        }
    }
}

impl FactorAnalysis {
    /// Create a model with `n_factors` factors and default settings.
    pub fn new(n_factors: usize) -> Result<Self> {
        check_positive(n_factors, "n_factors")?;
        Ok(Self {
            n_factors,
            ..Default::default()
        })
    }

    /// Set the maximum number of EM iterations.
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance.
    pub fn with_tolerance(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the rotation applied after fitting.
    pub fn with_rotation(mut self, rotation: RotationType) -> Self {
        self.rotation = rotation;
        self
    }

    /// Set the random seed.
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

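    /// Fit the factor model to `data` (samples in rows, features in columns).
    ///
    /// Illustrative call sketch (marked `ignore`; `x` is a hypothetical
    /// `Array2<f64>` with more columns than `n_factors`):
    ///
    /// ```ignore
    /// let fa = FactorAnalysis::new(2)?;
    /// let result = fa.fit(x.view())?;
    /// let new_scores = fa.transform(x.view(), &result)?;
    /// ```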
    pub fn fit(&self, data: ArrayView2<f64>) -> Result<FactorAnalysisResult> {
        checkarray_finite(&data, "data")?;
        let (n_samples, n_features) = data.dim();

        if n_samples < 2 {
            return Err(StatsError::InvalidArgument(
                "n_samples must be at least 2".to_string(),
            ));
        }

        if self.n_factors >= n_features {
            return Err(StatsError::InvalidArgument(format!(
                "n_factors ({}) must be less than n_features ({})",
                self.n_factors, n_features
            )));
        }

        // Center the data on the per-feature mean.
        let mean = data.mean_axis(Axis(0)).unwrap();
        let mut centereddata = data.to_owned();
        for mut row in centereddata.rows_mut() {
            row -= &mean;
        }

        let (mut loadings, mut psi) = self.initialize_parameters(&centereddata)?;

        let mut prev_log_likelihood = f64::NEG_INFINITY;
        let mut n_iter = 0;
        let mut converged = false;

        for iteration in 0..self.max_iter {
            let (e_h, e_hht) = self.e_step(&centereddata, &loadings, &psi)?;
            let (new_loadings, new_psi) = self.m_step(&centereddata, &e_h, &e_hht)?;

            let log_likelihood =
                self.compute_log_likelihood(&centereddata, &new_loadings, &new_psi)?;

            loadings = new_loadings;
            psi = new_psi;
            n_iter = iteration + 1;

            if (log_likelihood - prev_log_likelihood).abs() < self.tol {
                // Track convergence explicitly so that converging on the very
                // last iteration is not misreported as a failure below.
                converged = true;
                break;
            }
            prev_log_likelihood = log_likelihood;
        }

        if !converged {
            return Err(StatsError::ConvergenceError(format!(
                "EM algorithm failed to converge after {} iterations",
                self.max_iter
            )));
        }

        let rotated_loadings = match self.rotation {
            RotationType::None => loadings,
            RotationType::Varimax => self.varimax_rotation(&loadings)?,
            RotationType::Promax => self.promax_rotation(&loadings)?,
        };

        let scores = self.compute_factor_scores(&centereddata, &rotated_loadings, &psi)?;

        let explained_variance_ratio = self.compute_explained_variance(&rotated_loadings);
        let communalities = self.compute_communalities(&rotated_loadings);

        let final_log_likelihood =
            self.compute_log_likelihood(&centereddata, &rotated_loadings, &psi)?;

        Ok(FactorAnalysisResult {
            loadings: rotated_loadings,
            noise_variance: psi,
            scores,
            mean,
            log_likelihood: final_log_likelihood,
            n_iter,
            explained_variance_ratio,
            communalities,
        })
    }

    /// Initialize loadings from a truncated SVD of the centered data and
    /// specific variances from the implied communalities.
    fn initialize_parameters(&self, data: &Array2<f64>) -> Result<(Array2<f64>, Array1<f64>)> {
        let (n_samples, n_features) = data.dim();

        use scirs2_core::ndarray::ndarray_linalg::SVD;
        let (_u, s, vt) = data.svd(false, true).map_err(|e| {
            StatsError::ComputationError(format!("SVD initialization failed: {}", e))
        })?;

        let v = vt.unwrap().t().to_owned();

        let mut loadings = Array2::zeros((n_features, self.n_factors));
        for i in 0..self.n_factors {
            // Guard against having fewer singular values than factors
            // (possible when n_samples < n_factors).
            let scale = if i < s.len() {
                (s[i] / (n_samples as f64).sqrt()).max(1e-6)
            } else {
                1e-6
            };
            for j in 0..n_features {
                loadings[[j, i]] = v[[j, i]] * scale;
            }
        }

        let mut psi = Array1::ones(n_features);
        for i in 0..n_features {
            let communality = loadings.row(i).dot(&loadings.row(i));
            // Floor the specific variance so Psi stays positive.
            psi[i] = (1.0 - communality).max(0.01);
        }

        Ok((loadings, psi))
    }

    /// E-step: compute the posterior mean of the factors for each sample and
    /// the shared posterior covariance `Cov(h | x) = M^{-1}`, where
    /// `M = I + L^T Psi^{-1} L`.
    fn e_step(
        &self,
        data: &Array2<f64>,
        loadings: &Array2<f64>,
        psi: &Array1<f64>,
    ) -> Result<(Array2<f64>, Array2<f64>)> {
        let (n_samples, n_features) = data.dim();

        let mut psi_inv = Array2::zeros((n_features, n_features));
        for i in 0..n_features {
            if psi[i] <= 0.0 {
                return Err(StatsError::ComputationError(
                    "Specific variances must be positive".to_string(),
                ));
            }
            psi_inv[[i, i]] = 1.0 / psi[i];
        }

        let lt_psi_inv = loadings.t().dot(&psi_inv);
        let m = Array2::eye(self.n_factors) + lt_psi_inv.dot(loadings);

        let m_inv = scirs2_linalg::inv(&m.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to invert M matrix: {}", e))
        })?;

        let mut e_h = Array2::zeros((n_samples, self.n_factors));
        // The posterior covariance is the same for every sample.
        let e_hht = m_inv.clone();
        for i in 0..n_samples {
            let x = data.row(i);
            // E[h_i | x_i] = M^{-1} L^T Psi^{-1} x_i
            let e_h_i = m_inv.dot(&lt_psi_inv.dot(&x.to_owned()));
            e_h.row_mut(i).assign(&e_h_i);
        }

        Ok((e_h, e_hht))
    }

    /// M-step: update loadings and specific variances from the posterior
    /// moments produced by the E-step.
    fn m_step(
        &self,
        data: &Array2<f64>,
        e_h: &Array2<f64>,
        e_hht: &Array2<f64>,
    ) -> Result<(Array2<f64>, Array1<f64>)> {
        let (n_samples, n_features) = data.dim();

        let xte_h = data.t().dot(e_h);

        // sum_i E[h_i h_i^T] = n * Cov(h | x) + sum_i E[h_i] E[h_i]^T.
        // `e_hht` holds only the shared posterior covariance, so the
        // outer-product term of the posterior means must be added here.
        let sum_e_hht = e_hht * n_samples as f64 + e_h.t().dot(e_h);
        let sum_e_hht_inv = scirs2_linalg::inv(&sum_e_hht.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to invert sum E[HH^T]: {}", e))
        })?;

        let new_loadings = xte_h.dot(&sum_e_hht_inv);

        let mut new_psi = Array1::zeros(n_features);

        for j in 0..n_features {
            let x_j = data.column(j);
            let l_j = new_loadings.row(j);

            // The posterior-covariance contribution l_j^T Cov(h|x) l_j is
            // identical for every sample, so compute it once.
            let quad_form = l_j.dot(&e_hht.dot(&l_j.to_owned()));

            let mut sum_var = 0.0;
            for i in 0..n_samples {
                let residual = x_j[i] - l_j.dot(&e_h.row(i).to_owned());
                sum_var += residual * residual + quad_form;
            }

            // Floor the specific variance to keep Psi positive definite.
            new_psi[j] = (sum_var / n_samples as f64).max(1e-6);
        }

        Ok((new_loadings, new_psi))
    }

    /// Log-likelihood of the centered data under `N(0, L L^T + Psi)`.
    fn compute_log_likelihood(
        &self,
        data: &Array2<f64>,
        loadings: &Array2<f64>,
        psi: &Array1<f64>,
    ) -> Result<f64> {
        let (n_samples, n_features) = data.dim();

        // Sigma = L L^T + diag(Psi)
        let ll_t = loadings.dot(&loadings.t());
        let mut sigma = ll_t;
        for i in 0..n_features {
            sigma[[i, i]] += psi[i];
        }

        let det_sigma = scirs2_linalg::det(&sigma.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
        })?;

        if det_sigma <= 0.0 {
            return Err(StatsError::ComputationError(
                "Covariance matrix must be positive definite".to_string(),
            ));
        }

        let sigma_inv = scirs2_linalg::inv(&sigma.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to invert covariance: {}", e))
        })?;

        let mut log_likelihood = 0.0;
        // Per-sample constant: -p/2 * ln(2*pi) - 1/2 * ln|Sigma|
        let log_det_term =
            -0.5 * n_features as f64 * (2.0 * std::f64::consts::PI).ln() - 0.5 * det_sigma.ln();

        for i in 0..n_samples {
            let x = data.row(i);
            let quad_form = x.dot(&sigma_inv.dot(&x.to_owned()));
            log_likelihood += log_det_term - 0.5 * quad_form;
        }

        Ok(log_likelihood)
    }

    /// Varimax rotation via pairwise sweeps over factor pairs.
    fn varimax_rotation(&self, loadings: &Array2<f64>) -> Result<Array2<f64>> {
        let (n_features, n_factors) = loadings.dim();
        let mut rotated = loadings.clone();

        let max_iter = 30;
        let tol = 1e-6;

        for _ in 0..max_iter {
            let mut converged = true;

            for i in 0..n_factors {
                for j in (i + 1)..n_factors {
                    let col_i = rotated.column(i).to_owned();
                    let col_j = rotated.column(j).to_owned();

                    let u = &col_i * &col_i - &col_j * &col_j;
                    let v = 2.0 * &col_i * &col_j;

                    let a = u.sum();
                    let b = v.sum();
                    let c = (&u * &u - &v * &v).sum();
                    let d = 2.0 * (&u * &v).sum();

                    let num = d - 2.0 * a * b / n_features as f64;
                    let den = c - (a * a - b * b) / n_features as f64;

                    // Kaiser's varimax angle; atan2 keeps the correct quadrant
                    // even when the denominator vanishes.
                    let phi = 0.25 * num.atan2(den);

                    if phi.abs() > tol {
                        converged = false;

                        let cos_phi = phi.cos();
                        let sin_phi = phi.sin();

                        let new_col_i = cos_phi * &col_i - sin_phi * &col_j;
                        let new_col_j = sin_phi * &col_i + cos_phi * &col_j;

                        rotated.column_mut(i).assign(&new_col_i);
                        rotated.column_mut(j).assign(&new_col_j);
                    }
                }
            }

            if converged {
                break;
            }
        }

        Ok(rotated)
    }

    /// Promax rotation: varimax first, then an oblique transform toward a
    /// power target of the varimax loadings.
    fn promax_rotation(&self, loadings: &Array2<f64>) -> Result<Array2<f64>> {
        let varimax_rotated = self.varimax_rotation(loadings)?;

        // Power parameter; kappa = 4 is the conventional choice.
        let kappa = 4.0;
        let (n_features, n_factors) = varimax_rotated.dim();

        // Target matrix: sign-preserving |loading|^kappa.
        let mut target = Array2::zeros((n_features, n_factors));
        for i in 0..n_features {
            for j in 0..n_factors {
                let val = varimax_rotated[[i, j]];
                target[[i, j]] = val.abs().powf(kappa) * val.signum();
            }
        }

        // Least-squares transform: (L^T L)^{-1} L^T P.
        let ltl = varimax_rotated.t().dot(&varimax_rotated);
        let ltl_inv = scirs2_linalg::inv(&ltl.view(), None)
            .map_err(|e| StatsError::ComputationError(format!("Failed to invert L^T L: {}", e)))?;

        let ltp = varimax_rotated.t().dot(&target);
        let transform = ltl_inv.dot(&ltp);

        let rotated = varimax_rotated.dot(&transform);

        Ok(rotated)
    }

    /// Bartlett factor scores: F = X W^T with
    /// W = (L^T Psi^{-1} L)^{-1} L^T Psi^{-1}.
    fn compute_factor_scores(
        &self,
        data: &Array2<f64>,
        loadings: &Array2<f64>,
        psi: &Array1<f64>,
    ) -> Result<Array2<f64>> {
        let n_features = loadings.nrows();

        let mut psi_inv = Array2::zeros((n_features, n_features));
        for i in 0..n_features {
            psi_inv[[i, i]] = 1.0 / psi[i];
        }

        let lt_psi_inv = loadings.t().dot(&psi_inv);
        let lt_psi_inv_l = lt_psi_inv.dot(loadings);

        let lt_psi_inv_l_inv = scirs2_linalg::inv(&lt_psi_inv_l.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to compute factor score weights: {}", e))
        })?;

        let score_weights = lt_psi_inv_l_inv.dot(&lt_psi_inv);

        let scores = data.dot(&score_weights.t());

        Ok(scores)
    }

    /// Share of the common (factor-explained) variance attributable to each
    /// factor; the ratios sum to one.
    fn compute_explained_variance(&self, loadings: &Array2<f64>) -> Array1<f64> {
        let factor_variances = loadings
            .axis_iter(Axis(1))
            .map(|col| col.dot(&col))
            .collect::<Vec<_>>();

        let total_variance: f64 = factor_variances.iter().sum();

        Array1::from_vec(factor_variances).mapv(|v| v / total_variance)
    }

    /// Communality of each feature: the squared row norm of the loadings.
    fn compute_communalities(&self, loadings: &Array2<f64>) -> Array1<f64> {
        let mut communalities = Array1::zeros(loadings.nrows());

        for i in 0..loadings.nrows() {
            communalities[i] = loadings.row(i).dot(&loadings.row(i));
        }

        communalities
    }

    /// Project new data onto the fitted factors using the stored mean,
    /// loadings, and noise variances.
    pub fn transform(
        &self,
        data: ArrayView2<f64>,
        result: &FactorAnalysisResult,
    ) -> Result<Array2<f64>> {
        checkarray_finite(&data, "data")?;

        if data.ncols() != result.mean.len() {
            return Err(StatsError::DimensionMismatch(format!(
                "data has {} features, expected {}",
                data.ncols(),
                result.mean.len()
            )));
        }

        let mut centered = data.to_owned();
        for mut row in centered.rows_mut() {
            row -= &result.mean;
        }

        self.compute_factor_scores(&centered, &result.loadings, &result.noise_variance)
    }
}

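/// Helpers for exploratory factor analysis: parallel analysis for choosing
/// the number of factors, the KMO sampling-adequacy index, and Bartlett's
/// test of sphericity.
///
/// A usage sketch (marked `ignore`; `data` is a hypothetical `Array2<f64>`):
///
/// ```ignore
/// let n_factors = efa::parallel_analysis(data.view(), 100, 95.0, Some(42))?;
/// let kmo = efa::kmo_test(data.view())?;
/// let (chi2, p_value) = efa::bartlett_test(data.view())?;
/// ```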
pub mod efa {
    use super::*;

    /// Horn's parallel analysis: retain factors whose sample eigenvalues
    /// exceed the chosen percentile of eigenvalues obtained from random
    /// normal data of the same shape.
    pub fn parallel_analysis(
        data: ArrayView2<f64>,
        n_simulations: usize,
        percentile: f64,
        seed: Option<u64>,
    ) -> Result<usize> {
        checkarray_finite(&data, "data")?;
        check_positive(n_simulations, "n_simulations")?;

        if percentile <= 0.0 || percentile >= 100.0 {
            return Err(StatsError::InvalidArgument(
                "percentile must be between 0 and 100".to_string(),
            ));
        }

        let (n_samples, n_features) = data.dim();

        let real_eigenvalues = compute_correlation_eigenvalues(data)?;

        let mut rng = match seed {
            Some(s) => StdRng::seed_from_u64(s),
            None => {
                use std::time::{SystemTime, UNIX_EPOCH};
                let s = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_secs();
                StdRng::seed_from_u64(s)
            }
        };

        // Build the sampling distribution once, outside the simulation loop.
        use scirs2_core::random::{Distribution, Normal};
        let normal = Normal::new(0.0, 1.0).map_err(|e| {
            StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
        })?;

        let mut simulated_eigenvalues = Vec::with_capacity(n_simulations);

        for _ in 0..n_simulations {
            let mut randomdata = Array2::zeros((n_samples, n_features));
            for i in 0..n_samples {
                for j in 0..n_features {
                    randomdata[[i, j]] = normal.sample(&mut rng);
                }
            }

            let eigenvalues = compute_correlation_eigenvalues(randomdata.view())?;
            simulated_eigenvalues.push(eigenvalues);
        }

        // Per-component percentile thresholds across the simulations.
        let mut thresholds = Array1::zeros(n_features);
        for i in 0..n_features {
            let mut values: Vec<f64> = simulated_eigenvalues.iter().map(|ev| ev[i]).collect();
            values.sort_by(|a, b| a.partial_cmp(b).unwrap());

            let index = ((percentile / 100.0) * (n_simulations - 1) as f64).round() as usize;
            thresholds[i] = values[index.min(n_simulations - 1)];
        }

        // Count leading components whose observed eigenvalue beats its threshold.
        let mut n_factors = 0;
        for i in 0..n_features {
            if real_eigenvalues[i] > thresholds[i] {
                n_factors += 1;
            } else {
                break;
            }
        }

        // Always suggest at least one factor.
        Ok(n_factors.max(1))
    }

    /// Eigenvalues of the sample correlation matrix, sorted in descending order.
    fn compute_correlation_eigenvalues(data: ArrayView2<f64>) -> Result<Array1<f64>> {
        let mean = data.mean_axis(Axis(0)).unwrap();
        let mut centered = data.to_owned();
        for mut row in centered.rows_mut() {
            row -= &mean;
        }

        let cov = centered.t().dot(&centered) / (data.nrows() - 1) as f64;

        // Convert the covariance matrix to a correlation matrix.
        let mut corr = cov.clone();
        for i in 0..corr.nrows() {
            for j in 0..corr.ncols() {
                let std_i = cov[[i, i]].sqrt();
                let std_j = cov[[j, j]].sqrt();
                if std_i > 1e-10 && std_j > 1e-10 {
                    corr[[i, j]] = cov[[i, j]] / (std_i * std_j);
                }
            }
        }

        use scirs2_core::ndarray::ndarray_linalg::Eigh;
        let eigenvalues = corr
            .eigh(scirs2_core::ndarray::ndarray_linalg::UPLO::Upper)
            .map_err(|e| {
                StatsError::ComputationError(format!("Eigenvalue decomposition failed: {}", e))
            })?
            .0;

        let mut sorted_eigenvalues = eigenvalues.to_vec();
        sorted_eigenvalues.sort_by(|a, b| b.partial_cmp(a).unwrap());

        Ok(Array1::from_vec(sorted_eigenvalues))
    }

    /// Kaiser-Meyer-Olkin measure of sampling adequacy: the ratio of summed
    /// squared correlations to summed squared correlations plus summed
    /// squared partial correlations (values near 1 favor factor analysis).
    pub fn kmo_test(data: ArrayView2<f64>) -> Result<f64> {
        checkarray_finite(&data, "data")?;

        let mean = data.mean_axis(Axis(0)).unwrap();
        let mut centered = data.to_owned();
        for mut row in centered.rows_mut() {
            row -= &mean;
        }

        let cov = centered.t().dot(&centered) / (data.nrows() - 1) as f64;
        let n = cov.nrows();

        let mut corr = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                let std_i = cov[[i, i]].sqrt();
                let std_j = cov[[j, j]].sqrt();
                if std_i > 1e-10 && std_j > 1e-10 {
                    corr[[i, j]] = cov[[i, j]] / (std_i * std_j);
                } else if i == j {
                    corr[[i, j]] = 1.0;
                }
            }
        }

        let corr_inv = scirs2_linalg::inv(&corr.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to invert correlation matrix: {}", e))
        })?;

        let mut sum_squared_corr = 0.0;
        let mut sum_squared_partial = 0.0;

        for i in 0..n {
            for j in 0..n {
                if i != j {
                    sum_squared_corr += corr[[i, j]] * corr[[i, j]];

                    // Partial correlation from the inverse correlation matrix:
                    // r_ij.rest = -inv_ij / sqrt(inv_ii * inv_jj)
                    let partial = -corr_inv[[i, j]] / (corr_inv[[i, i]] * corr_inv[[j, j]]).sqrt();
                    sum_squared_partial += partial * partial;
                }
            }
        }

        let kmo = sum_squared_corr / (sum_squared_corr + sum_squared_partial);
        Ok(kmo)
    }

    /// Bartlett's test of sphericity: tests whether the correlation matrix is
    /// the identity. Returns the chi-squared statistic and its p-value.
    pub fn bartlett_test(data: ArrayView2<f64>) -> Result<(f64, f64)> {
        checkarray_finite(&data, "data")?;
        let (n, p) = data.dim();

        if n <= p {
            return Err(StatsError::InvalidArgument(
                "Number of samples must exceed number of variables".to_string(),
            ));
        }

        let mean = data.mean_axis(Axis(0)).unwrap();
        let mut centered = data.to_owned();
        for mut row in centered.rows_mut() {
            row -= &mean;
        }

        let cov = centered.t().dot(&centered) / (n - 1) as f64;

        let mut corr = Array2::zeros((p, p));
        for i in 0..p {
            for j in 0..p {
                let std_i = cov[[i, i]].sqrt();
                let std_j = cov[[j, j]].sqrt();
                if std_i > 1e-10 && std_j > 1e-10 {
                    corr[[i, j]] = cov[[i, j]] / (std_i * std_j);
                } else if i == j {
                    corr[[i, j]] = 1.0;
                }
            }
        }

        let det_corr = scirs2_linalg::det(&corr.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
        })?;

        if det_corr <= 0.0 {
            return Err(StatsError::ComputationError(
                "Correlation matrix must be positive definite".to_string(),
            ));
        }

        // chi^2 = -(n - 1 - (2p + 5) / 6) * ln|R|, with p(p - 1)/2 degrees of freedom.
        let chi2 = -(n as f64 - 1.0 - (2.0 * p as f64 + 5.0) / 6.0) * det_corr.ln();
        let df = p * (p - 1) / 2;

        let p_value = chi2_survival(chi2, df as f64);

        Ok((chi2, p_value))
    }
}

/// Approximate survival function of the chi-squared distribution via the
/// Wilson-Hilferty cube-root normal approximation, which is reasonably
/// accurate across the whole range of degrees of freedom (this replaces a
/// cruder two-branch approximation).
#[allow(dead_code)]
fn chi2_survival(x: f64, df: f64) -> f64 {
    if x <= 0.0 {
        return 1.0;
    }

    // Wilson-Hilferty: (X / df)^(1/3) is approximately normal with
    // mean 1 - 2 / (9 df) and variance 2 / (9 df).
    let t = (x / df).powf(1.0 / 3.0);
    let mu = 1.0 - 2.0 / (9.0 * df);
    let sigma = (2.0 / (9.0 * df)).sqrt();
    let z = (t - mu) / sigma;

    0.5 * (1.0 - erf(z / std::f64::consts::SQRT_2))
}

/// Error function via the Abramowitz & Stegun 7.1.26 polynomial
/// approximation (absolute error below about 1.5e-7).
#[allow(dead_code)]
fn erf(x: f64) -> f64 {
    let a1 = 0.254829592;
    let a2 = -0.284496736;
    let a3 = 1.421413741;
    let a4 = -1.453152027;
    let a5 = 1.061405429;
    let p = 0.3275911;

    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();

    sign * y
}
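
#[cfg(test)]
mod tests {
    // Illustrative smoke tests sketching how the public API is meant to be
    // used. The synthetic dataset below is an assumption for demonstration,
    // not a fixture from the original test suite.
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Build a small dataset with three correlated columns plus one column
    /// that is only weakly related to the others.
    fn toy_data() -> Array2<f64> {
        let mut data = Array2::zeros((40, 4));
        for i in 0..40 {
            let t = i as f64 / 40.0;
            data[[i, 0]] = t;
            data[[i, 1]] = 0.9 * t + 0.01 * ((i % 7) as f64);
            data[[i, 2]] = -0.8 * t + 0.01 * ((i % 5) as f64);
            data[[i, 3]] = ((i % 3) as f64) * 0.5;
        }
        data
    }

    #[test]
    fn fit_produces_consistent_shapes() {
        let data = toy_data();
        let fa = FactorAnalysis::new(2)
            .unwrap()
            .with_rotation(RotationType::None)
            .with_max_iter(5000);

        // EM convergence on arbitrary synthetic data is not guaranteed, so
        // only check the output invariants when the fit succeeds.
        if let Ok(result) = fa.fit(data.view()) {
            assert_eq!(result.loadings.dim(), (4, 2));
            assert_eq!(result.noise_variance.len(), 4);
            assert_eq!(result.scores.dim(), (40, 2));
            // Explained-variance ratios are normalized to the common variance.
            let total: f64 = result.explained_variance_ratio.sum();
            assert!((total - 1.0).abs() < 1e-8);
        }
    }

    #[test]
    fn parallel_analysis_suggests_at_least_one_factor() {
        let data = toy_data();
        let n = efa::parallel_analysis(data.view(), 20, 95.0, Some(42)).unwrap();
        assert!(n >= 1);
    }
}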