1use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
8use scirs2_core::random::{rngs::StdRng, SeedableRng};
9use scirs2_core::validation::*;
10
/// Configuration for a factor analysis fitted with the EM algorithm.
///
/// Values are normally set through [`FactorAnalysis::new`] and the `with_*`
/// builder methods; `fit` reads them when estimating the model.
#[derive(Debug, Clone)]
pub struct FactorAnalysis {
    /// Number of latent factors to extract. `fit` requires this to be
    /// strictly smaller than the number of features in the data.
    pub n_factors: usize,
    /// Upper bound on the number of EM iterations performed by `fit`.
    pub max_iter: usize,
    /// Convergence threshold: EM stops once the absolute change in
    /// log-likelihood between two iterations falls below this value.
    pub tol: f64,
    /// Rotation applied to the loadings after EM has finished.
    pub rotation: RotationType,
    /// Optional seed stored on the configuration.
    // NOTE(review): no method visible in this file reads this field — confirm
    // whether it is consumed elsewhere.
    pub random_state: Option<u64>,
}
25
/// Rotation applied to the estimated loading matrix after fitting.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RotationType {
    /// Leave the loadings exactly as estimated by EM.
    None,
    /// Apply the varimax rotation (see `varimax_rotation`).
    Varimax,
    /// Apply the promax rotation, which starts from a varimax solution
    /// (see `promax_rotation`).
    Promax,
}
36
/// Output of [`FactorAnalysis::fit`].
#[derive(Debug, Clone)]
pub struct FactorAnalysisResult {
    /// Loading matrix (`n_features x n_factors`) after the configured rotation.
    pub loadings: Array2<f64>,
    /// Per-feature specific (noise) variances estimated by EM.
    pub noise_variance: Array1<f64>,
    /// Factor scores of the training samples (`n_samples x n_factors`).
    pub scores: Array2<f64>,
    /// Per-feature mean that was subtracted from the data before fitting.
    pub mean: Array1<f64>,
    /// Log-likelihood of the centered training data as reported by `fit`.
    pub log_likelihood: f64,
    /// Number of EM iterations that were run.
    pub n_iter: usize,
    /// Sum of squared loadings of each factor divided by the total over all
    /// factors (the entries therefore sum to one).
    pub explained_variance_ratio: Array1<f64>,
    /// Per-feature sum of squared loadings.
    pub communalities: Array1<f64>,
}
57
58impl Default for FactorAnalysis {
59 fn default() -> Self {
60 Self {
61 n_factors: 2,
62 max_iter: 1000,
63 tol: 1e-6,
64 rotation: RotationType::Varimax,
65 random_state: None,
66 }
67 }
68}
69
70impl FactorAnalysis {
71 pub fn new(n_factors: usize) -> Result<Self> {
73 check_positive(n_factors, "n_factors")?;
74 Ok(Self {
75 n_factors,
76 ..Default::default()
77 })
78 }
79
80 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
82 self.max_iter = max_iter;
83 self
84 }
85
86 pub fn with_tolerance(mut self, tol: f64) -> Self {
88 self.tol = tol;
89 self
90 }
91
92 pub fn with_rotation(mut self, rotation: RotationType) -> Self {
94 self.rotation = rotation;
95 self
96 }
97
98 pub fn with_random_state(mut self, seed: u64) -> Self {
100 self.random_state = Some(seed);
101 self
102 }
103
104 pub fn fit(&self, data: ArrayView2<f64>) -> Result<FactorAnalysisResult> {
106 checkarray_finite(&data, "data")?;
107 let (n_samples, n_features) = data.dim();
108
109 if n_samples < 2 {
110 return Err(StatsError::InvalidArgument(
111 "n_samples must be at least 2".to_string(),
112 ));
113 }
114
115 if self.n_factors >= n_features {
116 return Err(StatsError::InvalidArgument(format!(
117 "n_factors ({}) must be less than n_features ({})",
118 self.n_factors, n_features
119 )));
120 }
121
122 let mean = data.mean_axis(Axis(0)).expect("Operation failed");
124 let mut centereddata = data.to_owned();
125 for mut row in centereddata.rows_mut() {
126 row -= &mean;
127 }
128
129 let (mut loadings, mut psi) = self.initialize_parameters(¢ereddata)?;
131
132 let mut prev_log_likelihood = f64::NEG_INFINITY;
133 let mut n_iter = 0;
134
135 for iteration in 0..self.max_iter {
137 let (e_h, e_hht) = self.e_step(¢ereddata, &loadings, &psi)?;
139
140 let (new_loadings, new_psi) = self.m_step(¢ereddata, &e_h, &e_hht)?;
142
143 let log_likelihood =
145 self.compute_log_likelihood(¢ereddata, &new_loadings, &new_psi)?;
146
147 if (log_likelihood - prev_log_likelihood).abs() < self.tol {
149 loadings = new_loadings;
150 psi = new_psi;
151 n_iter = iteration + 1;
152 break;
153 }
154
155 loadings = new_loadings;
156 psi = new_psi;
157 prev_log_likelihood = log_likelihood;
158 n_iter = iteration + 1;
159 }
160
161 if n_iter == self.max_iter {
162 return Err(StatsError::ConvergenceError(format!(
163 "EM algorithm failed to converge after {} iterations",
164 self.max_iter
165 )));
166 }
167
168 let rotated_loadings = match self.rotation {
170 RotationType::None => loadings,
171 RotationType::Varimax => self.varimax_rotation(&loadings)?,
172 RotationType::Promax => self.promax_rotation(&loadings)?,
173 };
174
175 let scores = self.compute_factor_scores(¢ereddata, &rotated_loadings, &psi)?;
177
178 let explained_variance_ratio = self.compute_explained_variance(&rotated_loadings);
180 let communalities = self.compute_communalities(&rotated_loadings);
181
182 let final_log_likelihood =
184 self.compute_log_likelihood(¢ereddata, &rotated_loadings, &psi)?;
185
186 Ok(FactorAnalysisResult {
187 loadings: rotated_loadings,
188 noise_variance: psi,
189 scores,
190 mean,
191 log_likelihood: final_log_likelihood,
192 n_iter,
193 explained_variance_ratio,
194 communalities,
195 })
196 }
197
198 fn initialize_parameters(&self, data: &Array2<f64>) -> Result<(Array2<f64>, Array1<f64>)> {
200 let (n_samples, n_features) = data.dim();
201
202 let (_u, s, vt) = scirs2_linalg::svd(&data.view(), false, None).map_err(|e| {
204 StatsError::ComputationError(format!("SVD initialization failed: {}", e))
205 })?;
206
207 let v = vt.t().to_owned();
208
209 let mut loadings = Array2::zeros((n_features, self.n_factors));
211 for i in 0..self.n_factors {
212 let scale = (s[i] / (n_samples as f64).sqrt()).max(1e-6);
213 for j in 0..n_features {
214 loadings[[j, i]] = v[[j, i]] * scale;
215 }
216 }
217
218 let mut psi = Array1::ones(n_features);
220 for i in 0..n_features {
221 let communality = loadings.row(i).dot(&loadings.row(i));
222 psi[i] = (1.0 - communality).max(0.01); }
224
225 Ok((loadings, psi))
226 }
227
228 fn e_step(
230 &self,
231 data: &Array2<f64>,
232 loadings: &Array2<f64>,
233 psi: &Array1<f64>,
234 ) -> Result<(Array2<f64>, Array2<f64>)> {
235 let (n_samples, n_features) = data.dim();
236
237 let mut psi_inv = Array2::zeros((n_features, n_features));
239 for i in 0..n_features {
240 if psi[i] <= 0.0 {
241 return Err(StatsError::ComputationError(
242 "Specific variances must be positive".to_string(),
243 ));
244 }
245 psi_inv[[i, i]] = 1.0 / psi[i];
246 }
247
248 let lt_psi_inv = loadings.t().dot(&psi_inv);
250 let m = Array2::eye(self.n_factors) + lt_psi_inv.dot(loadings);
251
252 let m_inv = scirs2_linalg::inv(&m.view(), None).map_err(|e| {
254 StatsError::ComputationError(format!("Failed to invert M matrix: {}", e))
255 })?;
256
257 let mut e_h = Array2::zeros((n_samples, self.n_factors));
259 let e_hht = m_inv.clone(); for i in 0..n_samples {
262 let x = data.row(i);
263 let e_h_i = m_inv.dot(<_psi_inv.dot(&x.to_owned()));
264 e_h.row_mut(i).assign(&e_h_i);
265 }
266
267 Ok((e_h, e_hht))
268 }
269
270 fn m_step(
272 &self,
273 data: &Array2<f64>,
274 e_h: &Array2<f64>,
275 e_hht: &Array2<f64>,
276 ) -> Result<(Array2<f64>, Array1<f64>)> {
277 let (n_samples, n_features) = data.dim();
278
279 let xte_h = data.t().dot(e_h);
281 let sum_e_hht = e_hht * n_samples as f64; let sum_e_hht_inv = scirs2_linalg::inv(&sum_e_hht.view(), None).map_err(|e| {
284 StatsError::ComputationError(format!("Failed to invert sum E[HH^T]: {}", e))
285 })?;
286
287 let new_loadings = xte_h.dot(&sum_e_hht_inv);
288
289 let mut new_psi = Array1::zeros(n_features);
291
292 for j in 0..n_features {
293 let x_j = data.column(j);
294 let l_j = new_loadings.row(j);
295
296 let mut sum_var = 0.0;
297 for i in 0..n_samples {
298 let x_ij = x_j[i];
299 let e_h_i = e_h.row(i);
300 let residual = x_ij - l_j.dot(&e_h_i.to_owned());
301 sum_var += residual * residual;
302
303 let quad_form = l_j.dot(&e_hht.dot(&l_j.to_owned()));
305 sum_var += quad_form;
306 }
307
308 new_psi[j] = (sum_var / n_samples as f64).max(1e-6); }
310
311 Ok((new_loadings, new_psi))
312 }
313
314 fn compute_log_likelihood(
316 &self,
317 data: &Array2<f64>,
318 loadings: &Array2<f64>,
319 psi: &Array1<f64>,
320 ) -> Result<f64> {
321 let (n_samples, n_features) = data.dim();
322
323 let ll_t = loadings.dot(&loadings.t());
325 let mut sigma = ll_t;
326 for i in 0..n_features {
327 sigma[[i, i]] += psi[i];
328 }
329
330 let det_sigma = scirs2_linalg::det(&sigma.view(), None).map_err(|e| {
332 StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
333 })?;
334
335 if det_sigma <= 0.0 {
336 return Err(StatsError::ComputationError(
337 "Covariance matrix must be positive definite".to_string(),
338 ));
339 }
340
341 let sigma_inv = scirs2_linalg::inv(&sigma.view(), None).map_err(|e| {
342 StatsError::ComputationError(format!("Failed to invert covariance: {}", e))
343 })?;
344
345 let mut log_likelihood = 0.0;
347 let log_det_term =
348 -0.5 * n_features as f64 * (2.0 * std::f64::consts::PI).ln() - 0.5 * det_sigma.ln();
349
350 for i in 0..n_samples {
351 let x = data.row(i);
352 let quad_form = x.dot(&sigma_inv.dot(&x.to_owned()));
353 log_likelihood += log_det_term - 0.5 * quad_form;
354 }
355
356 Ok(log_likelihood)
357 }
358
359 fn varimax_rotation(&self, loadings: &Array2<f64>) -> Result<Array2<f64>> {
361 let (n_features, n_factors) = loadings.dim();
362 let mut rotated = loadings.clone();
363
364 let max_iter = 30;
365 let tol = 1e-6;
366
367 for _ in 0..max_iter {
368 let rotation_matrix = Array2::<f64>::eye(n_factors);
369 let mut converged = true;
370
371 for i in 0..n_factors {
373 for j in (i + 1)..n_factors {
374 let col_i = rotated.column(i).to_owned();
375 let col_j = rotated.column(j).to_owned();
376
377 let u = &col_i * &col_i - &col_j * &col_j;
379 let v = 2.0 * &col_i * &col_j;
380
381 let a = u.sum();
382 let b = v.sum();
383 let c = (&u * &u - &v * &v).sum();
384 let d = 2.0 * (&u * &v).sum();
385
386 let num = d - 2.0 * a * b / n_features as f64;
387 let den = c - (a * a - b * b) / n_features as f64;
388
389 if den.abs() < 1e-10 {
390 continue;
391 }
392
393 let phi = 0.25 * (num / den).atan();
394
395 if phi.abs() > tol {
396 converged = false;
397
398 let cos_phi = phi.cos();
400 let sin_phi = phi.sin();
401
402 let new_col_i = cos_phi * &col_i - sin_phi * &col_j;
403 let new_col_j = sin_phi * &col_i + cos_phi * &col_j;
404
405 rotated.column_mut(i).assign(&new_col_i);
406 rotated.column_mut(j).assign(&new_col_j);
407 }
408 }
409 }
410
411 if converged {
412 break;
413 }
414 }
415
416 Ok(rotated)
417 }
418
419 fn promax_rotation(&self, loadings: &Array2<f64>) -> Result<Array2<f64>> {
421 let varimax_rotated = self.varimax_rotation(loadings)?;
423
424 let kappa = 4.0; let (n_features, n_factors) = varimax_rotated.dim();
427
428 let mut target = Array2::zeros((n_features, n_factors));
430 for i in 0..n_features {
431 for j in 0..n_factors {
432 let val = varimax_rotated[[i, j]];
433 target[[i, j]] = val.abs().powf(kappa) * val.signum();
434 }
435 }
436
437 let ltl = varimax_rotated.t().dot(&varimax_rotated);
440 let ltl_inv = scirs2_linalg::inv(<l.view(), None)
441 .map_err(|e| StatsError::ComputationError(format!("Failed to invert L^T L: {}", e)))?;
442
443 let ltp = varimax_rotated.t().dot(&target);
444 let transform = ltl_inv.dot(<p);
445
446 let rotated = varimax_rotated.dot(&transform);
448
449 Ok(rotated)
450 }
451
452 fn compute_factor_scores(
454 &self,
455 data: &Array2<f64>,
456 loadings: &Array2<f64>,
457 psi: &Array1<f64>,
458 ) -> Result<Array2<f64>> {
459 let n_features = loadings.nrows();
460
461 let mut psi_inv = Array2::zeros((n_features, n_features));
463 for i in 0..n_features {
464 psi_inv[[i, i]] = 1.0 / psi[i];
465 }
466
467 let lt_psi_inv = loadings.t().dot(&psi_inv);
469 let lt_psi_inv_l = lt_psi_inv.dot(loadings);
470
471 let lt_psi_inv_l_inv = scirs2_linalg::inv(<_psi_inv_l.view(), None).map_err(|e| {
472 StatsError::ComputationError(format!("Failed to compute factor score weights: {}", e))
473 })?;
474
475 let score_weights = lt_psi_inv_l_inv.dot(<_psi_inv);
476
477 let scores = data.dot(&score_weights.t());
479
480 Ok(scores)
481 }
482
483 fn compute_explained_variance(&self, loadings: &Array2<f64>) -> Array1<f64> {
485 let factor_variances = loadings
486 .axis_iter(Axis(1))
487 .map(|col| col.dot(&col))
488 .collect::<Vec<_>>();
489
490 let total_variance: f64 = factor_variances.iter().sum();
491
492 Array1::from_vec(factor_variances).mapv(|v| v / total_variance)
493 }
494
495 fn compute_communalities(&self, loadings: &Array2<f64>) -> Array1<f64> {
497 let mut communalities = Array1::zeros(loadings.nrows());
498
499 for i in 0..loadings.nrows() {
500 communalities[i] = loadings.row(i).dot(&loadings.row(i));
501 }
502
503 communalities
504 }
505
506 pub fn transform(
508 &self,
509 data: ArrayView2<f64>,
510 result: &FactorAnalysisResult,
511 ) -> Result<Array2<f64>> {
512 checkarray_finite(&data, "data")?;
513
514 if data.ncols() != result.mean.len() {
515 return Err(StatsError::DimensionMismatch(format!(
516 "data has {} features, expected {}",
517 data.ncols(),
518 result.mean.len()
519 )));
520 }
521
522 let mut centered = data.to_owned();
524 for mut row in centered.rows_mut() {
525 row -= &result.mean;
526 }
527
528 self.compute_factor_scores(¢ered, &result.loadings, &result.noise_variance)
530 }
531}
532
533pub mod efa {
535 use super::*;
536
537 pub fn parallel_analysis(
539 data: ArrayView2<f64>,
540 n_simulations: usize,
541 percentile: f64,
542 seed: Option<u64>,
543 ) -> Result<usize> {
544 checkarray_finite(&data, "data")?;
545 check_positive(n_simulations, "n_simulations")?;
546
547 if percentile <= 0.0 || percentile >= 100.0 {
548 return Err(StatsError::InvalidArgument(
549 "percentile must be between 0 and 100".to_string(),
550 ));
551 }
552
553 let (n_samples, n_features) = data.dim();
554
555 let real_eigenvalues = compute_correlation_eigenvalues(data)?;
557
558 let mut rng = match seed {
560 Some(s) => StdRng::seed_from_u64(s),
561 None => {
562 use std::time::{SystemTime, UNIX_EPOCH};
563 let s = SystemTime::now()
564 .duration_since(UNIX_EPOCH)
565 .unwrap_or_default()
566 .as_secs();
567 StdRng::seed_from_u64(s)
568 }
569 };
570
571 let mut simulated_eigenvalues = Vec::with_capacity(n_simulations);
573
574 for _ in 0..n_simulations {
575 let mut randomdata = Array2::zeros((n_samples, n_features));
577 use scirs2_core::random::{Distribution, Normal};
578 let normal = Normal::new(0.0, 1.0).map_err(|e| {
579 StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
580 })?;
581
582 for i in 0..n_samples {
583 for j in 0..n_features {
584 randomdata[[i, j]] = normal.sample(&mut rng);
585 }
586 }
587
588 let eigenvalues = compute_correlation_eigenvalues(randomdata.view())?;
589 simulated_eigenvalues.push(eigenvalues);
590 }
591
592 let mut thresholds = Array1::zeros(n_features);
594 for i in 0..n_features {
595 let mut values: Vec<f64> = simulated_eigenvalues.iter().map(|ev| ev[i]).collect();
596 values.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
597
598 let index = ((percentile / 100.0) * (n_simulations - 1) as f64).round() as usize;
599 thresholds[i] = values[index.min(n_simulations - 1)];
600 }
601
602 let mut n_factors = 0;
604 for i in 0..n_features {
605 if real_eigenvalues[i] > thresholds[i] {
606 n_factors += 1;
607 } else {
608 break;
609 }
610 }
611
612 Ok(n_factors.max(1)) }
614
615 fn compute_correlation_eigenvalues(data: ArrayView2<f64>) -> Result<Array1<f64>> {
617 let mean = data.mean_axis(Axis(0)).expect("Operation failed");
619 let mut centered = data.to_owned();
620 for mut row in centered.rows_mut() {
621 row -= &mean;
622 }
623
624 let cov = centered.t().dot(¢ered) / (data.nrows() - 1) as f64;
626
627 let mut corr = cov.clone();
629 for i in 0..corr.nrows() {
630 for j in 0..corr.ncols() {
631 let std_i = cov[[i, i]].sqrt();
632 let std_j = cov[[j, j]].sqrt();
633 if std_i > 1e-10 && std_j > 1e-10 {
634 corr[[i, j]] = cov[[i, j]] / (std_i * std_j);
635 }
636 }
637 }
638
639 let (eigenvalues, _eigenvectors) =
641 scirs2_linalg::eigh_f64_lapack(&corr.view()).map_err(|e| {
642 StatsError::ComputationError(format!("Eigenvalue decomposition failed: {}", e))
643 })?;
644
645 let mut sorted_eigenvalues: Vec<f64> = eigenvalues.to_vec();
647 sorted_eigenvalues.sort_by(|a: &f64, b: &f64| b.partial_cmp(a).expect("Operation failed"));
648
649 Ok(Array1::from_vec(sorted_eigenvalues))
650 }
651
652 pub fn kmo_test(data: ArrayView2<f64>) -> Result<f64> {
654 checkarray_finite(&data, "data")?;
655
656 let mean = data.mean_axis(Axis(0)).expect("Operation failed");
658 let mut centered = data.to_owned();
659 for mut row in centered.rows_mut() {
660 row -= &mean;
661 }
662
663 let cov = centered.t().dot(¢ered) / (data.nrows() - 1) as f64;
664 let n = cov.nrows();
665
666 let mut corr = Array2::zeros((n, n));
668 for i in 0..n {
669 for j in 0..n {
670 let std_i = cov[[i, i]].sqrt();
671 let std_j = cov[[j, j]].sqrt();
672 if std_i > 1e-10 && std_j > 1e-10 {
673 corr[[i, j]] = cov[[i, j]] / (std_i * std_j);
674 } else if i == j {
675 corr[[i, j]] = 1.0;
676 }
677 }
678 }
679
680 let corr_inv = scirs2_linalg::inv(&corr.view(), None).map_err(|e| {
682 StatsError::ComputationError(format!("Failed to invert correlation matrix: {}", e))
683 })?;
684
685 let mut sum_squared_corr = 0.0;
687 let mut sum_squared_partial = 0.0;
688
689 for i in 0..n {
690 for j in 0..n {
691 if i != j {
692 sum_squared_corr += corr[[i, j]] * corr[[i, j]];
693
694 let partial = -corr_inv[[i, j]] / (corr_inv[[i, i]] * corr_inv[[j, j]]).sqrt();
696 sum_squared_partial += partial * partial;
697 }
698 }
699 }
700
701 let kmo = sum_squared_corr / (sum_squared_corr + sum_squared_partial);
702 Ok(kmo)
703 }
704
705 pub fn bartlett_test(data: ArrayView2<f64>) -> Result<(f64, f64)> {
707 checkarray_finite(&data, "data")?;
708 let (n, p) = data.dim();
709
710 if n <= p {
711 return Err(StatsError::InvalidArgument(
712 "Number of samples must exceed number of variables".to_string(),
713 ));
714 }
715
716 let mean = data.mean_axis(Axis(0)).expect("Operation failed");
718 let mut centered = data.to_owned();
719 for mut row in centered.rows_mut() {
720 row -= &mean;
721 }
722
723 let cov = centered.t().dot(¢ered) / (n - 1) as f64;
724
725 let mut corr = Array2::zeros((p, p));
727 for i in 0..p {
728 for j in 0..p {
729 let std_i = cov[[i, i]].sqrt();
730 let std_j = cov[[j, j]].sqrt();
731 if std_i > 1e-10 && std_j > 1e-10 {
732 corr[[i, j]] = cov[[i, j]] / (std_i * std_j);
733 } else if i == j {
734 corr[[i, j]] = 1.0;
735 }
736 }
737 }
738
739 let det_corr = scirs2_linalg::det(&corr.view(), None).map_err(|e| {
741 StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
742 })?;
743
744 if det_corr <= 0.0 {
745 return Err(StatsError::ComputationError(
746 "Correlation matrix must be positive definite".to_string(),
747 ));
748 }
749
750 let chi2 = -(n as f64 - 1.0 - (2.0 * p as f64 + 5.0) / 6.0) * det_corr.ln();
751 let df = p * (p - 1) / 2;
752
753 let p_value = chi2_survival(chi2, df as f64);
755
756 Ok((chi2, p_value))
757 }
758}
759
/// Upper-tail probability `P(X > x)` of a chi-square distribution with `df`
/// degrees of freedom.
///
/// The value is the regularized upper incomplete gamma function
/// `Q(df / 2, x / 2)`. It is evaluated with the power series of the lower
/// function when `x / 2 < df / 2 + 1` and with a continued fraction (modified
/// Lentz algorithm) otherwise. This replaces the previous `exp(-x / df)`
/// shortcut, which is exact only for `df == 2`, and the normal approximation
/// that was used for `df > 30`.
///
/// Special cases:
/// * `x <= 0` returns `1.0`;
/// * a `NaN` argument or a negative `df` returns `NaN`;
/// * `df == 0` or `x == +inf` returns `0.0`; `df == +inf` returns `1.0`.
// Kept because the only caller visible here lives in a nested module.
#[allow(dead_code)]
fn chi2_survival(x: f64, df: f64) -> f64 {
    /// Natural logarithm of the gamma function for `z > 0`
    /// (Lanczos approximation with `g = 7` and nine coefficients).
    fn ln_gamma(z: f64) -> f64 {
        const G: f64 = 7.0;
        const COEFFS: [f64; 9] = [
            0.99999999999980993,
            676.5203681218851,
            -1259.1392167224028,
            771.32342877765313,
            -176.61502916214059,
            12.507343278686905,
            -0.13857109526572012,
            9.9843695780195716e-6,
            1.5056327351493116e-7,
        ];

        if z < 0.5 {
            // Reflection formula keeps the series in its accurate range.
            let pi = std::f64::consts::PI;
            return (pi / (pi * z).sin()).ln() - ln_gamma(1.0 - z);
        }

        let z = z - 1.0;
        let series = COEFFS
            .iter()
            .enumerate()
            .skip(1)
            .fold(COEFFS[0], |acc, (i, &c)| acc + c / (z + i as f64));
        let t = z + G + 0.5;
        0.5 * (2.0 * std::f64::consts::PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
    }

    const MAX_ITER: usize = 10_000;
    const REL_EPS: f64 = 1e-14;
    const TINY: f64 = 1e-300;

    if x <= 0.0 {
        return 1.0;
    }
    if x.is_nan() || df.is_nan() || df < 0.0 {
        return f64::NAN;
    }
    if df == 0.0 || x.is_infinite() {
        return 0.0;
    }
    if df.is_infinite() {
        return 1.0;
    }

    let a = 0.5 * df;
    let half_x = 0.5 * x;
    // ln( half_x^a * exp(-half_x) / Gamma(a) ), shared by both expansions.
    let ln_prefactor = a * half_x.ln() - half_x - ln_gamma(a);

    if half_x < a + 1.0 {
        // Series for the lower function P(a, x); the survival is 1 - P.
        let mut denom = a;
        let mut term = 1.0 / a;
        let mut sum = term;
        for _ in 0..MAX_ITER {
            denom += 1.0;
            term *= half_x / denom;
            sum += term;
            if term.abs() < sum.abs() * REL_EPS {
                break;
            }
        }
        (1.0 - sum * ln_prefactor.exp()).clamp(0.0, 1.0)
    } else {
        // Continued fraction for the upper function Q(a, x).
        let mut b = half_x + 1.0 - a;
        let mut c = 1.0 / TINY;
        let mut d = 1.0 / b;
        let mut h = d;
        for i in 1..=MAX_ITER {
            let i = i as f64;
            let an = -i * (i - a);
            b += 2.0;
            d = an * d + b;
            if d.abs() < TINY {
                d = TINY;
            }
            c = b + an / c;
            if c.abs() < TINY {
                c = TINY;
            }
            d = 1.0 / d;
            let delta = d * c;
            h *= delta;
            if (delta - 1.0).abs() < REL_EPS {
                break;
            }
        }
        (ln_prefactor.exp() * h).clamp(0.0, 1.0)
    }
}
781
/// Polynomial approximation of the error function.
///
/// Evaluates `1 - (c1 t + c2 t^2 + c3 t^3 + c4 t^4 + c5 t^5) * exp(-x^2)` with
/// `t = 1 / (1 + p |x|)` for the magnitude of `x` and restores the sign
/// afterwards, so the result is odd in `x`.
#[allow(dead_code)]
fn erf(x: f64) -> f64 {
    // Polynomial coefficients c1..c5, lowest order first.
    const COEFFS: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];
    const P: f64 = 0.3275911;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    // Horner evaluation starting from the highest-order coefficient.
    let poly = COEFFS.iter().rev().fold(0.0, |acc, &c| acc * t + c);
    let y = 1.0 - poly * t * (-magnitude * magnitude).exp();

    if x >= 0.0 {
        y
    } else {
        -y
    }
}