1use std::collections::{HashMap, HashSet};
25
26use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
27
28use crate::causal_graph::dag::CausalDAG;
29use crate::error::{StatsError, StatsResult};
30
/// Output of one of the structure-learning algorithms in this module.
#[derive(Debug, Clone)]
pub struct StructureLearningResult {
    /// Learned graph over the input variables (one node per data column).
    pub dag: CausalDAG,
    /// Algorithm-specific score: `NaN` for the PC/FCI path (no score), the total BIC for the
    /// greedy BIC search, and the negated least-squares loss for NOTEARS.
    pub score: f64,
    /// Name of the algorithm that produced the result, e.g. `"PC"`.
    pub algorithm: String,
    /// Algorithm-specific work counter: conditional-independence tests for PC/FCI, search
    /// rounds for the greedy BIC search, outer iterations for NOTEARS.
    pub n_tests: usize,
    /// Edge annotations keyed by `(from, to)` column indices (left empty by the greedy BIC
    /// search).
    pub edge_info: HashMap<(usize, usize), EdgeType>,
}
50
/// Edge mark stored in `StructureLearningResult::edge_info`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EdgeType {
    /// Oriented edge `from -> to`.
    Directed,
    /// Adjacency whose orientation was not determined (stored once, as `(i, j)` with `i < j`).
    Undirected,
    /// Bidirected edge `a <-> b`; presumably latent confounding — none of the learners here
    /// emit it.
    Bidirected,
    /// Adjacency with undetermined endpoint marks (produced by the FCI path).
    PartiallyDirected,
}
63
64fn partial_correlation_test(
71 data: ArrayView2<f64>,
72 x: usize,
73 y: usize,
74 z_set: &[usize],
75) -> StatsResult<f64> {
76 let n = data.nrows();
77 if z_set.is_empty() {
78 let rho = pearson_r(data.column(x), data.column(y));
80 return Ok(pearson_p_value(rho, n));
81 }
82
83 let res_x = ols_residuals(data, x, z_set)?;
85 let res_y = ols_residuals(data, y, z_set)?;
86 let rho = pearson_r(res_x.view(), res_y.view());
87 Ok(pearson_p_value(rho, n.saturating_sub(z_set.len())))
88}
89
90fn pearson_r(
91 a: scirs2_core::ndarray::ArrayView1<f64>,
92 b: scirs2_core::ndarray::ArrayView1<f64>,
93) -> f64 {
94 let n = a.len() as f64;
95 let ma = a.mean().unwrap_or(0.0);
96 let mb = b.mean().unwrap_or(0.0);
97 let cov: f64 = a
98 .iter()
99 .zip(b.iter())
100 .map(|(&ai, &bi)| (ai - ma) * (bi - mb))
101 .sum::<f64>();
102 let va: f64 = a.iter().map(|&ai| (ai - ma).powi(2)).sum::<f64>();
103 let vb: f64 = b.iter().map(|&bi| (bi - mb).powi(2)).sum::<f64>();
104 cov / (va * vb).sqrt().max(f64::EPSILON)
105}
106
107fn pearson_p_value(rho: f64, n: usize) -> f64 {
108 if n < 3 {
109 return 1.0;
110 }
111 let df = (n - 2) as f64;
112 let t = rho * (df / (1.0 - rho * rho).max(1e-12)).sqrt();
113 t_dist_two_sided_p(t, df)
115}
116
117fn t_dist_two_sided_p(t: f64, df: f64) -> f64 {
120 if !t.is_finite() || !df.is_finite() || df < 1.0 {
121 return 1.0;
122 }
123 if df > 30.0 {
125 return 2.0 * (1.0 - normal_cdf(t.abs()));
126 }
127 let x = df / (df + t * t);
131 let p = inc_beta_series(df * 0.5, 0.5, x);
132 p.clamp(0.0, 1.0)
133}
134
135fn inc_beta_series(a: f64, b: f64, x: f64) -> f64 {
138 if !x.is_finite() || x <= 0.0 {
139 return 0.0;
140 }
141 if x >= 1.0 {
142 return 1.0;
143 }
144 let log_prefix = a * x.ln() + b * (1.0 - x).ln() - log_beta(a, b);
146 if !log_prefix.is_finite() {
147 return 0.5;
148 }
149 let prefix = log_prefix.exp();
150 if x < (a + 1.0) / (a + b + 2.0) {
152 let mut s = 0.0_f64;
155 let mut t_term = 1.0_f64 / a;
156 s += t_term;
157 for k in 1..200_usize {
158 t_term *= x * (a + b + k as f64 - 1.0) / ((a + k as f64) * k as f64);
159 s += t_term;
160 if t_term.abs() < 1e-12 {
161 break;
162 }
163 }
164 (prefix * s).clamp(0.0, 1.0)
165 } else {
166 1.0 - inc_beta_series(b, a, 1.0 - x)
168 }
169}
170
171fn log_beta(a: f64, b: f64) -> f64 {
172 lgamma(a) + lgamma(b) - lgamma(a + b)
173}
174
175fn normal_cdf(x: f64) -> f64 {
176 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
177}
178
/// Error function via the Abramowitz & Stegun 7.1.26 rational approximation.
///
/// Absolute error is about 1.5e-7 or less. The value is computed for `|x|`, and the sign is
/// then restored, so `erf(-x) == -erf(x)` holds exactly.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    const A: [f64; 5] = [
        0.254_829_592,
        -0.284_496_736,
        1.421_413_741,
        -1.453_152_027,
        1.061_405_429,
    ];
    let t = 1.0 / (1.0 + P * x.abs());
    // Horner scheme for t * (a1 + t*(a2 + t*(a3 + t*(a4 + t*a5)))).
    let poly = t * A.iter().rev().fold(0.0_f64, |acc, &c| c + t * acc);
    let magnitude = 1.0 - poly * (-x * x).exp();
    if x >= 0.0 {
        magnitude
    } else {
        -magnitude
    }
}
191
/// Natural log of `|Γ(x)|`.
///
/// - For `x < 0.5` it uses the reflection formula `Γ(x) Γ(1-x) = π / sin(πx)`.
/// - Otherwise it uses the Lanczos approximation with g = 7 and nine coefficients.
fn lgamma(x: f64) -> f64 {
    use std::f64::consts::PI;
    const LANCZOS: [f64; 9] = [
        0.999_999_999_999_809_9,
        676.520_368_121_885_1,
        -1_259.139_216_722_402_8,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_312e-7,
    ];
    // Keep the comparison as `x < 0.5`: a NaN must fall through to the non-recursive branch.
    if x < 0.5 {
        return PI.ln() - (PI * x).sin().abs().ln() - lgamma(1.0 - x);
    }
    let z = x - 1.0;
    let t = z + 7.5;
    let series = LANCZOS[1..]
        .iter()
        .enumerate()
        .fold(LANCZOS[0], |acc, (i, &c)| acc + c / (z + 1.0 + i as f64));
    0.5 * (2.0 * PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
218}
219
220fn ols_residuals(
222 data: ArrayView2<f64>,
223 target: usize,
224 predictors: &[usize],
225) -> StatsResult<Array1<f64>> {
226 let n = data.nrows();
227 let p = predictors.len();
228 let mut design = Array2::<f64>::ones((n, p + 1));
229 for (j, &pred) in predictors.iter().enumerate() {
230 for i in 0..n {
231 design[[i, j + 1]] = data[[i, pred]];
232 }
233 }
234 let y: Array1<f64> = data.column(target).to_owned();
235 let coef = ols_solve(design.view(), y.view())?;
237 let mut residuals = y.clone();
238 for i in 0..n {
239 let pred: f64 = (0..=p).map(|j| design[[i, j]] * coef[j]).sum();
240 residuals[i] -= pred;
241 }
242 Ok(residuals)
243}
244
245fn ols_solve(x: ArrayView2<f64>, y: ArrayView1<f64>) -> StatsResult<Array1<f64>> {
246 let (n, p) = x.dim();
247 let mut xtx = Array2::<f64>::zeros((p, p));
248 let mut xty = Array1::<f64>::zeros(p);
249 for i in 0..n {
250 for j in 0..p {
251 xty[j] += x[[i, j]] * y[i];
252 for k in 0..p {
253 xtx[[j, k]] += x[[i, j]] * x[[i, k]];
254 }
255 }
256 }
257 for j in 0..p {
259 xtx[[j, j]] += 1e-8;
260 }
261 gauss_jordan_solve(xtx, xty)
262}
263
/// Solves `a · x = b` by Gauss–Jordan elimination with partial pivoting and returns `x`.
///
/// Both inputs are consumed.
///
/// # Errors
/// Returns `ComputationError("Singular OLS system")` when the best available pivot has
/// absolute value below `1e-12`.
fn gauss_jordan_solve(mut a: Array2<f64>, mut b: Array1<f64>) -> StatsResult<Array1<f64>> {
    // NOTE(review): assumes `a` is n × n with n = b.len(); this is not checked here.
    let n = b.len();
    for col in 0..n {
        // Partial pivoting: pick the row at or below `col` with the largest |a[row, col]|.
        // `col..n` is never empty inside this loop, so the `ok_or_else` error is defensive.
        let pivot_row = (col..n)
            .max_by(|&i, &j| {
                a[[i, col]]
                    .abs()
                    .partial_cmp(&a[[j, col]].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .ok_or_else(|| StatsError::ComputationError("Singular matrix".to_owned()))?;
        // Swap the pivot row into position `col` in both `a` and `b`.
        for k in 0..n {
            let tmp = a[[col, k]];
            a[[col, k]] = a[[pivot_row, k]];
            a[[pivot_row, k]] = tmp;
        }
        let tmp = b[col];
        b[col] = b[pivot_row];
        b[pivot_row] = tmp;

        let pivot = a[[col, col]];
        if pivot.abs() < 1e-12 {
            return Err(StatsError::ComputationError(
                "Singular OLS system".to_owned(),
            ));
        }
        // Normalise the pivot row. Its entries left of `col` were already zeroed by earlier
        // eliminations, so only columns `col..n` need scaling.
        for k in col..n {
            a[[col, k]] /= pivot;
        }
        b[col] /= pivot;
        // Eliminate column `col` from every other row, above and below the pivot.
        for row in 0..n {
            if row != col {
                let factor = a[[row, col]];
                for k in col..n {
                    let av = a[[col, k]];
                    a[[row, k]] -= factor * av;
                }
                b[row] -= factor * b[col];
            }
        }
    }
    Ok(b)
}
308
/// Configuration of the PC constraint-based structure-learning algorithm.
pub struct PcAlgorithm {
    /// Significance level: an adjacency is removed when a CI test's p-value exceeds it.
    pub alpha: f64,
    /// Largest conditioning-set size tried during skeleton discovery.
    pub max_cond_set_size: usize,
    /// Gaussian-test flag. NOTE(review): `fit` does not currently read this field; the
    /// partial-correlation test is always used.
    pub gaussian: bool,
}

impl Default for PcAlgorithm {
    /// `alpha = 0.05`, conditioning sets of size up to 3, Gaussian test.
    fn default() -> Self {
        let alpha = 0.05;
        let max_cond_set_size = 3;
        Self {
            gaussian: true,
            max_cond_set_size,
            alpha,
        }
    }
}
337
338impl PcAlgorithm {
339 pub fn fit(
343 &self,
344 data: ArrayView2<f64>,
345 var_names: &[&str],
346 ) -> StatsResult<StructureLearningResult> {
347 let p = data.ncols();
348 if var_names.len() != p {
349 return Err(StatsError::DimensionMismatch(
350 "var_names length must equal number of columns in data".to_owned(),
351 ));
352 }
353
354 let mut adj: Vec<Vec<bool>> = vec![vec![true; p]; p];
357 for i in 0..p {
358 adj[i][i] = false;
359 }
360
361 let mut sep_sets: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
363 let mut n_tests = 0usize;
364
365 for ord in 0..=self.max_cond_set_size {
366 let edges: Vec<(usize, usize)> = (0..p)
367 .flat_map(|i| {
368 (0..p)
369 .filter(move |&j| i < j)
370 .collect::<Vec<_>>()
371 .into_iter()
372 .map(move |j| (i, j))
373 .collect::<Vec<_>>()
374 })
375 .filter(|&(i, j)| adj[i][j])
376 .collect();
377 for (x, y) in edges {
378 let z_candidates: Vec<usize> =
380 (0..p).filter(|&k| k != x && k != y && adj[x][k]).collect();
381 if z_candidates.len() < ord {
382 continue;
383 }
384 let mut found_sep = false;
386 'cond: for z_set in subsets(&z_candidates, ord) {
387 n_tests += 1;
388 let p_val = partial_correlation_test(data, x, y, &z_set).unwrap_or(1.0);
389 if p_val > self.alpha {
390 adj[x][y] = false;
392 adj[y][x] = false;
393 sep_sets.insert((x.min(y), x.max(y)), z_set);
394 found_sep = true;
395 break 'cond;
396 }
397 }
398 if found_sep {
399 break;
400 }
401 }
402 }
403
404 let mut directed: HashMap<(usize, usize), EdgeType> = HashMap::new();
406 for z in 0..p {
408 let neighbours: Vec<usize> = (0..p).filter(|&k| k != z && adj[z][k]).collect();
409 for i in 0..neighbours.len() {
410 for j in (i + 1)..neighbours.len() {
411 let x = neighbours[i];
412 let y = neighbours[j];
413 if adj[x][y] {
414 continue;
415 } let key = (x.min(y), x.max(y));
417 let sep = sep_sets.get(&key).cloned().unwrap_or_default();
418 if !sep.contains(&z) {
419 directed.insert((x, z), EdgeType::Directed);
421 directed.insert((y, z), EdgeType::Directed);
422 }
423 }
424 }
425 }
426
427 meek_rules(p, &adj, &mut directed);
429
430 let mut dag = CausalDAG::new();
432 for name in var_names {
433 dag.add_node(name);
434 }
435 let mut edge_info: HashMap<(usize, usize), EdgeType> = HashMap::new();
436
437 for i in 0..p {
438 for j in 0..p {
439 if i == j || !adj[i][j] {
440 continue;
441 }
442 let et = directed.get(&(i, j)).cloned();
443 match et {
444 Some(EdgeType::Directed) => {
445 let _ = dag.add_edge(var_names[i], var_names[j]);
447 edge_info.insert((i, j), EdgeType::Directed);
448 }
449 _ => {
450 if i < j {
452 let _ = dag.add_edge(var_names[i], var_names[j]);
453 edge_info.insert((i, j), EdgeType::Undirected);
454 }
455 }
456 }
457 }
458 }
459
460 Ok(StructureLearningResult {
461 dag,
462 score: f64::NAN,
463 algorithm: "PC".to_owned(),
464 n_tests,
465 edge_info,
466 })
467 }
468}
469
470fn meek_rules(p: usize, adj: &[Vec<bool>], directed: &mut HashMap<(usize, usize), EdgeType>) {
472 let mut changed = true;
473 let mut iters = 0;
474 while changed && iters < 100 {
475 changed = false;
476 iters += 1;
477 for b in 0..p {
479 for a in 0..p {
480 if !adj[a][b] {
481 continue;
482 }
483 if directed.get(&(a, b)) != Some(&EdgeType::Directed) {
484 continue;
485 }
486 for c in 0..p {
487 if c == a || !adj[b][c] {
488 continue;
489 }
490 if directed.contains_key(&(b, c)) {
491 continue;
492 }
493 if !adj[a][c] {
494 directed.insert((b, c), EdgeType::Directed);
495 changed = true;
496 }
497 }
498 }
499 }
500 for a in 0..p {
502 for b in 0..p {
503 if directed.get(&(a, b)) != Some(&EdgeType::Directed) {
504 continue;
505 }
506 for c in 0..p {
507 if directed.get(&(b, c)) != Some(&EdgeType::Directed) {
508 continue;
509 }
510 if adj[a][c] && !directed.contains_key(&(a, c)) {
511 directed.insert((a, c), EdgeType::Directed);
512 changed = true;
513 }
514 }
515 }
516 }
517 }
518}
519
/// Configuration of the simplified FCI learner (PC skeleton and orientations, with
/// unoriented adjacencies relabelled).
pub struct FciAlgorithm {
    /// Significance level for the conditional-independence tests.
    pub alpha: f64,
    /// Largest conditioning-set size tried during skeleton discovery.
    pub max_cond_set_size: usize,
}

impl Default for FciAlgorithm {
    /// `alpha = 0.05`, conditioning sets of size up to 3.
    fn default() -> Self {
        Self {
            max_cond_set_size: 3,
            alpha: 0.05,
        }
    }
}
544
545impl FciAlgorithm {
546 pub fn fit(
548 &self,
549 data: ArrayView2<f64>,
550 var_names: &[&str],
551 ) -> StatsResult<StructureLearningResult> {
552 let pc = PcAlgorithm {
554 alpha: self.alpha,
555 max_cond_set_size: self.max_cond_set_size,
556 gaussian: true,
557 };
558 let mut result = pc.fit(data, var_names)?;
559 result.algorithm = "FCI".to_owned();
560
561 let p = var_names.len();
566 let directed_clone = result.edge_info.clone();
567 for i in 0..p {
568 for j in 0..p {
569 if i == j {
570 continue;
571 }
572 let ij = directed_clone.get(&(i, j));
574 let ji = directed_clone.get(&(j, i));
575 if ij.is_none() && ji.is_none() {
576 if i < j {
578 result.edge_info.insert((i, j), EdgeType::PartiallyDirected);
579 }
580 }
581 }
582 }
583
584 Ok(result)
585 }
586}
587
588fn bic_score(data: ArrayView2<f64>, node: usize, parents: &[usize], bic_penalty: f64) -> f64 {
594 let n = data.nrows() as f64;
595 let k = parents.len() as f64;
596
597 let residuals = if parents.is_empty() {
599 let mean = data.column(node).mean().unwrap_or(0.0);
600 data.column(node)
601 .iter()
602 .map(|&y| y - mean)
603 .collect::<Vec<_>>()
604 } else {
605 match ols_residuals(data, node, parents) {
606 Ok(r) => r.to_vec(),
607 Err(_) => return f64::NEG_INFINITY,
608 }
609 };
610
611 let rss: f64 = residuals.iter().map(|r| r * r).sum();
612 let sigma2 = rss / n;
613 if sigma2 < 1e-12 {
614 return 0.0;
615 }
616 -(n * sigma2.ln() + bic_penalty * (k + 1.0) * n.ln())
618}
619
/// Configuration of greedy hill-climbing over DAGs scored by Gaussian BIC.
pub struct BicGreedySearch {
    /// Multiplier on the BIC complexity term (1.0 is standard BIC).
    pub penalty: f64,
    /// Maximum number of parents any node may have.
    pub max_parents: usize,
    /// Maximum number of add/delete rounds.
    pub max_iter: usize,
    /// Number of restarts. NOTE(review): `fit` does not currently read this field.
    pub n_restarts: usize,
}

impl Default for BicGreedySearch {
    /// Standard BIC penalty, at most 4 parents, 500 rounds, a single start.
    fn default() -> Self {
        let (penalty, max_parents) = (1.0, 4);
        let (max_iter, n_restarts) = (500, 1);
        Self {
            penalty,
            max_parents,
            max_iter,
            n_restarts,
        }
    }
}
642
643impl BicGreedySearch {
644 pub fn fit(
646 &self,
647 data: ArrayView2<f64>,
648 var_names: &[&str],
649 ) -> StatsResult<StructureLearningResult> {
650 let p = data.ncols();
651 if var_names.len() != p {
652 return Err(StatsError::DimensionMismatch(
653 "var_names length mismatch".to_owned(),
654 ));
655 }
656
657 let mut best_dag = CausalDAG::new();
658 for name in var_names {
659 best_dag.add_node(name);
660 }
661 let mut best_score = self.compute_total_bic(data, &vec![vec![]; p]);
662 let mut best_parents = vec![vec![]; p];
663
664 let mut iters = 0usize;
665 let mut current_parents = vec![vec![]; p];
666
667 let mut improved = true;
668 while improved && iters < self.max_iter {
669 improved = false;
670 iters += 1;
671
672 for i in 0..p {
674 for j in 0..p {
675 if i == j {
676 continue;
677 }
678 if current_parents[j].contains(&i) {
679 continue;
680 }
681 if current_parents[j].len() >= self.max_parents {
682 continue;
683 }
684 if self.creates_cycle(¤t_parents, i, j, p) {
686 continue;
687 }
688
689 let mut trial = current_parents.clone();
690 trial[j].push(i);
691 let score = self.compute_total_bic(data, &trial);
692 if score > best_score {
693 best_score = score;
694 best_parents = trial;
695 improved = true;
696 }
697 }
698 }
699
700 if improved {
701 current_parents = best_parents.clone();
702 }
703
704 improved = false;
706 for j in 0..p {
707 let pa = current_parents[j].clone();
708 for (k, &pi) in pa.iter().enumerate() {
709 let mut trial = current_parents.clone();
710 trial[j].remove(k);
711 let score = self.compute_total_bic(data, &trial);
712 if score > best_score {
713 best_score = score;
714 best_parents = trial;
715 improved = true;
716 }
717 let _ = pi;
718 }
719 }
720 if improved {
721 current_parents = best_parents.clone();
722 }
723 }
724
725 let mut dag = CausalDAG::new();
727 for name in var_names {
728 dag.add_node(name);
729 }
730 for (j, parents) in best_parents.iter().enumerate() {
731 for &i in parents {
732 let _ = dag.add_edge(var_names[i], var_names[j]);
733 }
734 }
735
736 Ok(StructureLearningResult {
737 dag,
738 score: best_score,
739 algorithm: "BIC Greedy".to_owned(),
740 n_tests: iters,
741 edge_info: HashMap::new(),
742 })
743 }
744
745 fn compute_total_bic(&self, data: ArrayView2<f64>, parents: &[Vec<usize>]) -> f64 {
746 (0..data.ncols())
747 .map(|j| bic_score(data, j, &parents[j], self.penalty))
748 .sum()
749 }
750
751 fn creates_cycle(
753 &self,
754 parents: &[Vec<usize>],
755 new_parent: usize,
756 child: usize,
757 p: usize,
758 ) -> bool {
759 let mut visited = HashSet::new();
761 let mut stack = vec![new_parent];
762 while let Some(cur) = stack.pop() {
763 if cur == child {
764 return true;
765 }
766 if !visited.insert(cur) {
767 continue;
768 }
769 for &pa in &parents[cur] {
770 stack.push(pa);
771 }
772 }
773 let _ = p;
774 false
775 }
776}
777
/// Configuration of ICA-based LiNGAM (linear non-Gaussian acyclic model) discovery.
pub struct LiNGAM {
    /// Maximum FastICA fixed-point iterations per component.
    pub max_iter: usize,
    /// FastICA convergence tolerance.
    pub tol: f64,
    /// Minimum `|B[i, j]|` for the edge `j -> i` to be added to the DAG.
    pub threshold: f64,
}

impl Default for LiNGAM {
    /// 500 ICA iterations, tolerance `1e-6`, edge threshold `0.1`.
    fn default() -> Self {
        Self {
            threshold: 0.1,
            tol: 1e-6,
            max_iter: 500,
        }
    }
}
808
/// Output of `LiNGAM::fit`.
#[derive(Debug, Clone)]
pub struct LiNGAMResult {
    /// Column indices in the order selected by `lingam_order` (presumably most exogenous
    /// first).
    pub causal_order: Vec<usize>,
    /// Estimated connection-strength matrix with a zero diagonal.
    ///
    /// Entry `(i, j)` is read as the effect of variable `j` on variable `i`.
    pub b_matrix: Array2<f64>,
    /// DAG containing `j -> i` for every `|b_matrix[[i, j]]| > threshold`.
    pub dag: CausalDAG,
}
819
820impl LiNGAM {
821 pub fn fit(&self, data: ArrayView2<f64>, var_names: &[&str]) -> StatsResult<LiNGAMResult> {
823 let (n, p) = data.dim();
824 if var_names.len() != p {
825 return Err(StatsError::DimensionMismatch(
826 "var_names must equal ncols".to_owned(),
827 ));
828 }
829
830 let means: Array1<f64> = (0..p)
832 .map(|j| data.column(j).mean().unwrap_or(0.0))
833 .collect();
834 let mut xc = data.to_owned();
835 for i in 0..n {
836 for j in 0..p {
837 xc[[i, j]] -= means[j];
838 }
839 }
840
841 let (xw, whitening_matrix) = whiten(xc.view())?;
843
844 let w_ica = fast_ica(xw.view(), self.max_iter, self.tol)?;
846
847 let a_hat = pseudo_inverse_2x2_general(&w_ica, p)?;
850
851 let b_matrix = normalise_lingam(a_hat, p);
853
854 let causal_order = lingam_order(&b_matrix, p);
856
857 let mut dag = CausalDAG::new();
859 for name in var_names {
860 dag.add_node(name);
861 }
862 for j in 0..p {
863 for i in 0..p {
864 if i == j {
865 continue;
866 }
867 if b_matrix[[i, j]].abs() > self.threshold {
868 let _ = dag.add_edge(var_names[j], var_names[i]);
870 }
871 }
872 }
873 let _ = whitening_matrix;
874
875 Ok(LiNGAMResult {
876 causal_order,
877 b_matrix,
878 dag,
879 })
880 }
881}
882
883fn whiten(data: ArrayView2<f64>) -> StatsResult<(Array2<f64>, Array2<f64>)> {
885 let (n, p) = data.dim();
886 let mut cov = Array2::<f64>::zeros((p, p));
888 for i in 0..n {
889 for j in 0..p {
890 for k in 0..p {
891 cov[[j, k]] += data[[i, j]] * data[[i, k]];
892 }
893 }
894 }
895 cov.mapv_inplace(|x| x / n as f64);
896
897 let (eigvals, eigvecs) = jacobi_eigen(cov.view(), 100)?;
899
900 let mut w = Array2::<f64>::zeros((p, p));
902 for i in 0..p {
903 let scale = if eigvals[i] > 1e-10 {
904 eigvals[i].sqrt().recip()
905 } else {
906 0.0
907 };
908 for j in 0..p {
909 w[[i, j]] = scale * eigvecs[[j, i]]; }
911 }
912
913 let mut xw = Array2::<f64>::zeros((n, p));
915 for i in 0..n {
916 for j in 0..p {
917 for k in 0..p {
918 xw[[i, j]] += w[[j, k]] * data[[i, k]];
919 }
920 }
921 }
922 Ok((xw, w))
923}
924
/// Eigen-decomposition of a symmetric matrix by the classical Jacobi method.
///
/// Each iteration finds the largest off-diagonal `|d[p, q]|` in the upper triangle. If it is
/// below `1e-12` the loop stops; otherwise the plane rotation that zeroes it is applied, and
/// the rotation is accumulated into `v`. At most `max_iter` rotations are performed. Hitting
/// the cap is not reported, and this function always returns `Ok`.
///
/// Returns `(eigenvalues, eigenvectors)`: eigenvector `i` is column `i` of the matrix and
/// pairs with `eigenvalues[i]`. Eigenvalues are not sorted.
fn jacobi_eigen(a: ArrayView2<f64>, max_iter: usize) -> StatsResult<(Array1<f64>, Array2<f64>)> {
    // NOTE(review): assumes `a` is square and symmetric; only symmetric updates are applied.
    let n = a.nrows();
    let mut d = a.to_owned();
    let mut v = Array2::<f64>::eye(n);
    for _ in 0..max_iter {
        // Locate the largest off-diagonal element (p, q), with p < q.
        let mut max_val = 0.0_f64;
        let (mut p, mut q) = (0, 1);
        for i in 0..n {
            for j in (i + 1)..n {
                if d[[i, j]].abs() > max_val {
                    max_val = d[[i, j]].abs();
                    p = i;
                    q = j;
                }
            }
        }
        if max_val < 1e-12 {
            break;
        }
        // Rotation angle with tan(2θ) = 2 d_pq / (d_qq − d_pp). Equal diagonal entries give
        // θ = π/4.
        let theta = if (d[[p, p]] - d[[q, q]]).abs() < 1e-12 {
            std::f64::consts::FRAC_PI_4
        } else {
            0.5 * ((2.0 * d[[p, q]]) / (d[[q, q]] - d[[p, p]])).atan()
        };
        let (s, c) = theta.sin_cos();
        let (dpp, dqq, dpq) = (d[[p, p]], d[[q, q]], d[[p, q]]);
        // Apply Jᵀ D J to rows and columns p and q; the (p, q) entry is zeroed explicitly.
        d[[p, p]] = c * c * dpp - 2.0 * s * c * dpq + s * s * dqq;
        d[[q, q]] = s * s * dpp + 2.0 * s * c * dpq + c * c * dqq;
        d[[p, q]] = 0.0;
        d[[q, p]] = 0.0;
        for k in 0..n {
            if k != p && k != q {
                let dpk = d[[p, k]];
                let dqk = d[[q, k]];
                d[[p, k]] = c * dpk - s * dqk;
                d[[k, p]] = d[[p, k]];
                d[[q, k]] = s * dpk + c * dqk;
                d[[k, q]] = d[[q, k]];
            }
            // Accumulate V ← V J (columns p and q).
            let vpk = v[[k, p]];
            let vqk = v[[k, q]];
            v[[k, p]] = c * vpk - s * vqk;
            v[[k, q]] = s * vpk + c * vqk;
        }
    }
    let eigvals: Array1<f64> = (0..n).map(|i| d[[i, i]]).collect();
    Ok((eigvals, v))
}
976
977fn fast_ica(xw: ArrayView2<f64>, max_iter: usize, tol: f64) -> StatsResult<Array2<f64>> {
979 let (n, p) = xw.dim();
980 let mut w_mat = Array2::<f64>::eye(p);
981
982 for comp in 0..p {
983 let mut w = Array1::<f64>::from_shape_fn(p, |i| if i == comp { 1.0 } else { 0.0 });
984
985 for _ in 0..max_iter {
986 let wx: Vec<f64> = (0..n)
988 .map(|i| {
989 w.iter()
990 .zip(xw.row(i).iter())
991 .map(|(a, b)| a * b)
992 .sum::<f64>()
993 })
994 .collect();
995
996 let g: Vec<f64> = wx.iter().map(|&u| u.tanh()).collect();
998 let gp: Vec<f64> = wx.iter().map(|&u| 1.0 - u.tanh().powi(2)).collect();
999
1000 let mut w_new = Array1::<f64>::zeros(p);
1001 for i in 0..n {
1002 for j in 0..p {
1003 w_new[j] += g[i] * xw[[i, j]];
1004 }
1005 }
1006 w_new.mapv_inplace(|x| x / n as f64);
1007 let gp_mean = gp.iter().sum::<f64>() / n as f64;
1008 for j in 0..p {
1009 w_new[j] -= gp_mean * w[j];
1010 }
1011
1012 for prev in 0..comp {
1014 let w_prev = w_mat.row(prev);
1015 let dot: f64 = w_new.iter().zip(w_prev.iter()).map(|(a, b)| a * b).sum();
1016 for j in 0..p {
1017 w_new[j] -= dot * w_prev[j];
1018 }
1019 }
1020
1021 let norm: f64 = w_new
1023 .iter()
1024 .map(|x| x * x)
1025 .sum::<f64>()
1026 .sqrt()
1027 .max(f64::EPSILON);
1028 w_new.mapv_inplace(|x| x / norm);
1029
1030 let diff: f64 = w
1031 .iter()
1032 .zip(w_new.iter())
1033 .map(|(a, b)| (a - b).powi(2))
1034 .sum::<f64>()
1035 .sqrt();
1036 w = w_new;
1037 if diff < tol {
1038 break;
1039 }
1040 }
1041 for j in 0..p {
1042 w_mat[[comp, j]] = w[j];
1043 }
1044 }
1045 Ok(w_mat)
1046}
1047
/// Inverse of the `p × p` matrix `w` by Gauss–Jordan elimination with partial pivoting.
///
/// The elimination runs on the augmented matrix `[w | I]`. Despite the name, this computes
/// an exact inverse of a general `p × p` matrix; it is neither a 2 × 2 routine nor a
/// Moore–Penrose pseudo-inverse.
///
/// # Errors
/// Returns `ComputationError("Singular")` when a pivot has absolute value below `1e-12`.
/// The "Singular ICA unmixing matrix" branch is defensive: `col..p` is never empty inside
/// the loop.
fn pseudo_inverse_2x2_general(w: &Array2<f64>, p: usize) -> StatsResult<Array2<f64>> {
    // Build [w | I].
    let mut aug = Array2::<f64>::zeros((p, 2 * p));
    for i in 0..p {
        for j in 0..p {
            aug[[i, j]] = w[[i, j]];
        }
        aug[[i, p + i]] = 1.0;
    }
    for col in 0..p {
        // Partial pivoting on column `col`.
        let pivot = (col..p)
            .max_by(|&i, &j| {
                aug[[i, col]]
                    .abs()
                    .partial_cmp(&aug[[j, col]].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .ok_or_else(|| {
                StatsError::ComputationError("Singular ICA unmixing matrix".to_owned())
            })?;
        for k in 0..(2 * p) {
            let tmp = aug[[col, k]];
            aug[[col, k]] = aug[[pivot, k]];
            aug[[pivot, k]] = tmp;
        }
        let piv_val = aug[[col, col]];
        if piv_val.abs() < 1e-12 {
            return Err(StatsError::ComputationError("Singular".to_owned()));
        }
        for k in 0..(2 * p) {
            aug[[col, k]] /= piv_val;
        }
        // Clear column `col` in all other rows.
        for row in 0..p {
            if row != col {
                let factor = aug[[row, col]];
                for k in 0..(2 * p) {
                    let av = aug[[col, k]];
                    aug[[row, k]] -= factor * av;
                }
            }
        }
    }
    // The right half of the reduced augmented matrix is w⁻¹.
    let mut inv = Array2::<f64>::zeros((p, p));
    for i in 0..p {
        for j in 0..p {
            inv[[i, j]] = aug[[i, p + j]];
        }
    }
    Ok(inv)
}
1099
1100fn normalise_lingam(mut b: Array2<f64>, p: usize) -> Array2<f64> {
1101 for i in 0..p {
1102 let diag = b[[i, i]];
1103 if diag.abs() > 1e-10 {
1104 for j in 0..p {
1105 b[[i, j]] /= diag;
1106 }
1107 }
1108 }
1109 for i in 0..p {
1110 b[[i, i]] = 0.0;
1111 }
1112 b
1113}
1114
1115fn lingam_order(b: &Array2<f64>, p: usize) -> Vec<usize> {
1116 let mut remaining: Vec<usize> = (0..p).collect();
1119 let mut order = Vec::with_capacity(p);
1120 while !remaining.is_empty() {
1121 let best = remaining
1122 .iter()
1123 .min_by(|&&i, &&j| {
1124 let li: f64 = remaining
1125 .iter()
1126 .filter(|&&k| k != i)
1127 .map(|&k| b[[i, k]].abs())
1128 .sum();
1129 let lj: f64 = remaining
1130 .iter()
1131 .filter(|&&k| k != j)
1132 .map(|&k| b[[j, k]].abs())
1133 .sum();
1134 li.partial_cmp(&lj).unwrap_or(std::cmp::Ordering::Equal)
1135 })
1136 .copied()
1137 .unwrap_or(remaining[0]);
1138 order.push(best);
1139 remaining.retain(|&x| x != best);
1140 }
1141 order
1142}
1143
/// Configuration of NOTEARS continuous-optimisation structure learning.
///
/// The objective is a least-squares fit with an L1 penalty, under an acyclicity constraint.
pub struct Notears {
    /// Weight of the L1 penalty on `W`.
    pub lambda: f64,
    /// Maximum number of augmented-Lagrangian (outer) iterations.
    pub max_iter: usize,
    /// Maximum proximal-gradient steps per outer iteration.
    pub max_inner_iter: usize,
    /// The outer loop stops once `|h(W)|` falls below this value.
    pub h_tol: f64,
    /// Minimum `|W[i, j]|` for the edge `i -> j` to be reported.
    pub w_threshold: f64,
}

impl Default for Notears {
    /// `lambda = 0.1`, 100 outer × 300 inner iterations, `h_tol = 1e-8`, edge threshold `0.3`.
    fn default() -> Self {
        let lambda = 0.1;
        let w_threshold = 0.3;
        Self {
            w_threshold,
            h_tol: 1e-8,
            max_inner_iter: 300,
            max_iter: 100,
            lambda,
        }
    }
}
1178
impl Notears {
    /// Learns a weighted DAG with NOTEARS.
    ///
    /// The least-squares loss plus an L1 penalty is minimised subject to `h(W) = 0`. An
    /// augmented-Lagrangian outer loop drives the constraint, and `inner_optim` performs the
    /// proximal-gradient solves.
    ///
    /// In the result:
    /// - edges `i -> j` are reported for `|W[i, j]| > w_threshold`;
    /// - `score` is the negated least-squares loss of the final `W`;
    /// - `n_tests` is the number of outer iterations run.
    ///
    /// # Errors
    /// Returns `DimensionMismatch` if `var_names.len()` differs from `data.ncols()`.
    pub fn fit(
        &self,
        data: ArrayView2<f64>,
        var_names: &[&str],
    ) -> StatsResult<StructureLearningResult> {
        let (n, p) = data.dim();
        if var_names.len() != p {
            return Err(StatsError::DimensionMismatch(
                "var_names mismatch".to_owned(),
            ));
        }

        // Centre every column.
        let means: Array1<f64> = (0..p)
            .map(|j| data.column(j).mean().unwrap_or(0.0))
            .collect();
        let mut xc = data.to_owned();
        for i in 0..n {
            for j in 0..p {
                xc[[i, j]] -= means[j];
            }
        }

        // Start from the empty graph. `alpha` is the Lagrange multiplier and `rho` the
        // weight of the quadratic penalty in the augmented Lagrangian.
        let mut w = Array2::<f64>::zeros((p, p));
        let mut alpha = 0.0_f64;
        let mut rho = 1.0_f64;
        let rho_max = 1e16_f64;
        let mut h_prev = f64::INFINITY;
        let mut outer_iters = 0usize;

        for _ in 0..self.max_iter {
            outer_iters += 1;
            w = self.inner_optim(xc.view(), &w, alpha, rho, n, p)?;
            let h_val = notears_h(&w, p);

            if h_val.abs() < self.h_tol {
                break;
            }

            // Dual ascent on the multiplier. The penalty grows tenfold (capped at `rho_max`)
            // when h did not fall below a quarter of its previous value.
            alpha += rho * h_val;
            if h_val > 0.25 * h_prev {
                rho = (rho * 10.0).min(rho_max);
            }
            h_prev = h_val;
        }

        let mut dag = CausalDAG::new();
        for name in var_names {
            dag.add_node(name);
        }
        let mut edge_info = HashMap::new();
        for i in 0..p {
            for j in 0..p {
                if i == j {
                    continue;
                }
                if w[[i, j]].abs() > self.w_threshold {
                    let _ = dag.add_edge(var_names[i], var_names[j]);
                    edge_info.insert((i, j), EdgeType::Directed);
                }
            }
        }

        Ok(StructureLearningResult {
            dag,
            score: -notears_loss(xc.view(), &w, n, p),
            algorithm: "NOTEARS".to_owned(),
            n_tests: outer_iters,
            edge_info,
        })
    }

    /// Proximal-gradient (ISTA-style) minimisation of the augmented Lagrangian for fixed
    /// `alpha` and `rho`, starting from `w_init`.
    ///
    /// Each step takes a gradient step of size `1e-3`, then soft-thresholds for the L1 term;
    /// the diagonal is kept at zero. The loop stops when the Frobenius change is below
    /// `1e-6`, or after `max_inner_iter` steps. This method always returns `Ok`.
    ///
    /// NOTE(review): the step size is fixed while `rho` can grow up to 1e16 in `fit`. For
    /// large `rho` the penalty gradient may make this step diverge, so confirm convergence
    /// on long runs.
    fn inner_optim(
        &self,
        x: ArrayView2<f64>,
        w_init: &Array2<f64>,
        alpha: f64,
        rho: f64,
        n: usize,
        p: usize,
    ) -> StatsResult<Array2<f64>> {
        let mut w = w_init.clone();
        let lr = 1e-3;

        for _step in 0..self.max_inner_iter {
            let grad = self.aug_lagrangian_gradient(x, &w, alpha, rho, n, p);
            let mut w_new = Array2::<f64>::zeros((p, p));
            for i in 0..p {
                for j in 0..p {
                    if i == j {
                        continue;
                    }
                    // Gradient step followed by soft-thresholding at lr * lambda.
                    let u = w[[i, j]] - lr * grad[[i, j]];
                    w_new[[i, j]] = if u > lr * self.lambda {
                        u - lr * self.lambda
                    } else if u < -lr * self.lambda {
                        u + lr * self.lambda
                    } else {
                        0.0
                    };
                }
            }
            // Frobenius norm of the update.
            let diff: f64 = {
                let mut d = 0.0_f64;
                for ii in 0..p {
                    for jj in 0..p {
                        d += (w_new[[ii, jj]] - w[[ii, jj]]).powi(2);
                    }
                }
                d.sqrt()
            };
            w = w_new;
            if diff < 1e-6 {
                break;
            }
        }
        Ok(w)
    }

    /// Gradient of the smooth part of the augmented Lagrangian, excluding the L1 term.
    ///
    /// It is computed as `(1/n) Xᵀ(XW − X)` on the off-diagonal entries, plus
    /// `(alpha + rho·h(W)) ∇h(W)` on all entries.
    fn aug_lagrangian_gradient(
        &self,
        x: ArrayView2<f64>,
        w: &Array2<f64>,
        alpha: f64,
        rho: f64,
        n: usize,
        p: usize,
    ) -> Array2<f64> {
        let mut grad = Array2::<f64>::zeros((p, p));

        // Least-squares part: grad[i, j] = (1/n) Σ_k x[k, i] (XW − X)[k, j].
        let xw = x_times_w(x, w, n, p);
        for i in 0..p {
            for j in 0..p {
                if i == j {
                    continue;
                }
                let mut g = 0.0_f64;
                for k in 0..n {
                    g += x[[k, i]] * (xw[[k, j]] - x[[k, j]]);
                }
                grad[[i, j]] = g / n as f64;
            }
        }

        // h(W) = tr(exp(W ∘ W)) − p. The diagonal is picked out of the row-major iteration
        // by flat index.
        let exp_ww = notears_exp_ww(w, p);
        let h = exp_ww
            .iter()
            .enumerate()
            .filter(|(i, _)| i / p == i % p)
            .map(|(_, &v)| v)
            .sum::<f64>()
            - p as f64;
        let dh_dw = notears_dh_dw(&exp_ww, w, p);
        for i in 0..p {
            for j in 0..p {
                grad[[i, j]] += (alpha + rho * h) * dh_dw[[i, j]];
            }
        }
        grad
    }
}
1352
1353fn x_times_w(x: ArrayView2<f64>, w: &Array2<f64>, n: usize, p: usize) -> Array2<f64> {
1354 let mut xw = Array2::<f64>::zeros((n, p));
1355 for i in 0..n {
1356 for j in 0..p {
1357 for k in 0..p {
1358 xw[[i, j]] += x[[i, k]] * w[[k, j]];
1359 }
1360 }
1361 }
1362 xw
1363}
1364
1365fn notears_h(w: &Array2<f64>, p: usize) -> f64 {
1366 let exp_ww = notears_exp_ww(w, p);
1368 (0..p).map(|i| exp_ww[[i, i]]).sum::<f64>() - p as f64
1369}
1370
1371fn notears_exp_ww(w: &Array2<f64>, p: usize) -> Array2<f64> {
1373 let ww: Array2<f64> = w.mapv(|x| x * x);
1375 let mut result = Array2::<f64>::eye(p);
1377 let mut term = Array2::<f64>::eye(p);
1378 let mut factorial = 1.0_f64;
1379 for k in 1..=15_usize {
1380 factorial *= k as f64;
1381 let mut new_term = Array2::<f64>::zeros((p, p));
1383 for i in 0..p {
1384 for j in 0..p {
1385 for l in 0..p {
1386 new_term[[i, j]] += term[[i, l]] * ww[[l, j]];
1387 }
1388 }
1389 }
1390 term = new_term;
1391 for i in 0..p {
1392 for j in 0..p {
1393 result[[i, j]] += term[[i, j]] / factorial;
1394 }
1395 }
1396 if term.iter().map(|x| x.abs()).fold(0.0_f64, f64::max) < 1e-12 {
1397 break;
1398 }
1399 }
1400 result
1401}
1402
1403fn notears_dh_dw(exp_ww: &Array2<f64>, w: &Array2<f64>, p: usize) -> Array2<f64> {
1404 let mut dh = Array2::<f64>::zeros((p, p));
1406 for i in 0..p {
1407 for j in 0..p {
1408 dh[[i, j]] = exp_ww[[j, i]] * 2.0 * w[[i, j]];
1409 }
1410 }
1411 dh
1412}
1413
1414fn notears_loss(x: ArrayView2<f64>, w: &Array2<f64>, n: usize, p: usize) -> f64 {
1415 let xw = x_times_w(x, w, n, p);
1416 let mut loss = 0.0_f64;
1417 for i in 0..n {
1418 for j in 0..p {
1419 loss += (xw[[i, j]] - x[[i, j]]).powi(2);
1420 }
1421 }
1422 loss / (2.0 * n as f64)
1423}
1424
/// All `k`-element combinations of `items`.
///
/// Combinations come in lexicographic order of element positions, and each one keeps the
/// original relative order of its elements. `k == 0` yields a single empty combination;
/// `k > items.len()` yields none.
fn subsets<T: Copy>(items: &[T], k: usize) -> Vec<Vec<T>> {
    if k == 0 {
        return vec![Vec::new()];
    }
    let n = items.len();
    if k > n {
        return Vec::new();
    }
    // Current combination as strictly increasing indices into `items`.
    let mut positions: Vec<usize> = (0..k).collect();
    let mut combos = Vec::new();
    loop {
        combos.push(positions.iter().map(|&i| items[i]).collect());
        // Rightmost slot that has not reached its maximum index (n - k + slot).
        let advance = (0..k).rev().find(|&slot| positions[slot] != slot + n - k);
        match advance {
            Some(slot) => {
                positions[slot] += 1;
                for next in slot + 1..k {
                    positions[next] = positions[next - 1] + 1;
                }
            }
            None => break,
        }
    }
    combos
}
1445
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// 100 deterministic samples from the linear chain X -> Y -> Z.
    ///
    /// Coefficients are 0.8 and the noise is scaled by 0.5. Normals come from a fixed-seed
    /// 64-bit LCG fed through the Box–Muller transform.
    fn chain_data() -> Array2<f64> {
        let n = 100;
        let mut data = Array2::<f64>::zeros((n, 3));
        let mut lcg: u64 = 12345;
        // Draws two uniforms in [0, 1) from the top 31 bits of the LCG state and returns
        // one standard-normal value.
        let next = |s: &mut u64| -> f64 {
            *s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
            let u = (*s >> 33) as f64 / (1u64 << 31) as f64;
            *s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
            // Floor v away from 0 so that ln(v) stays finite.
            let v = ((*s >> 33) as f64 / (1u64 << 31) as f64).max(1e-10);
            (-2.0 * v.ln()).sqrt() * (2.0 * std::f64::consts::PI * u).cos()
        };
        for i in 0..n {
            data[[i, 0]] = next(&mut lcg);
            data[[i, 1]] = 0.8 * data[[i, 0]] + next(&mut lcg) * 0.5;
            data[[i, 2]] = 0.8 * data[[i, 1]] + next(&mut lcg) * 0.5;
        }
        data
    }

    /// Smoke test: PC runs and keeps one node per column.
    #[test]
    fn test_pc_runs() {
        let data = chain_data();
        let pc = PcAlgorithm::default();
        let res = pc.fit(data.view(), &["X", "Y", "Z"]).unwrap();
        assert_eq!(res.algorithm, "PC");
        assert!(res.dag.n_nodes() == 3);
    }

    /// Smoke test: FCI runs and labels its result.
    #[test]
    fn test_fci_runs() {
        let data = chain_data();
        let fci = FciAlgorithm::default();
        let res = fci.fit(data.view(), &["X", "Y", "Z"]).unwrap();
        assert_eq!(res.algorithm, "FCI");
    }

    /// The greedy BIC search produces a numeric (non-NaN) score.
    #[test]
    fn test_bic_greedy() {
        let data = chain_data();
        let learner = BicGreedySearch {
            max_iter: 50,
            ..Default::default()
        };
        let res = learner.fit(data.view(), &["X", "Y", "Z"]).unwrap();
        assert!(!res.score.is_nan());
    }

    /// Smoke test: LiNGAM output dimensions match the number of variables.
    #[test]
    fn test_lingam_runs() {
        let data = chain_data();
        let ling = LiNGAM::default();
        let res = ling.fit(data.view(), &["X", "Y", "Z"]).unwrap();
        assert_eq!(res.causal_order.len(), 3);
        assert_eq!(res.b_matrix.nrows(), 3);
    }

    /// Smoke test: a short NOTEARS run keeps one node per column.
    #[test]
    fn test_notears_runs() {
        let data = chain_data();
        let nt = Notears {
            max_iter: 5,
            max_inner_iter: 10,
            ..Default::default()
        };
        let res = nt.fit(data.view(), &["X", "Y", "Z"]).unwrap();
        assert_eq!(res.dag.n_nodes(), 3);
    }

    /// In the chain, X ⫫ Z | Y, so the partial-correlation test should not reject at the
    /// 1% level.
    #[test]
    fn test_partial_correlation_independence() {
        let data = chain_data();
        let p_val = partial_correlation_test(data.view(), 0, 2, &[1]).unwrap();
        assert!(p_val > 0.01, "p={p_val}");
    }
}