1use super::CausalityResult;
40use crate::error::TimeSeriesError;
41
42use std::collections::{HashMap, HashSet};
43
44#[non_exhaustive]
50#[derive(Debug, Clone)]
51pub struct PCConfig {
52 pub significance_level: f64,
54 pub max_cond_set_size: usize,
58 pub test_type: IndependenceTest,
60}
61
62impl Default for PCConfig {
63 fn default() -> Self {
64 Self {
65 significance_level: 0.05,
66 max_cond_set_size: 4,
67 test_type: IndependenceTest::PartialCorrelation,
68 }
69 }
70}
71
/// Selects the conditional-independence test used by the PC algorithm.
///
/// The enum is a plain, fieldless selector, so it derives the full set of value
/// traits (`Eq` and `Hash` in addition to `PartialEq`), matching [`EdgeType`] and
/// allowing it to be used as a map or set key.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IndependenceTest {
    /// Fisher z-test on the (partial) correlation derived from the covariance
    /// matrix.
    PartialCorrelation,
    /// Likelihood-ratio style test built from the Gaussian mutual information
    /// `-0.5 * ln(1 - r^2)` of the partial correlation `r`.
    MutualInformation,
    /// Kernel-based test. The dispatcher in this file routes this variant to the
    /// same partial-correlation test as [`IndependenceTest::PartialCorrelation`].
    KernelBased,
}
83
/// Orientation status of an edge in a [`CausalGraph`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum EdgeType {
    /// Oriented edge pointing from `from` to `to`.
    Directed,
    /// Adjacency whose direction has not been determined.
    Undirected,
    /// Edge with arrowheads at both ends.
    // NOTE(review): nothing in the discovery code of this file emits this
    // variant; presumably reserved for conflicts / latent confounders - confirm.
    Bidirected,
}
95
96#[derive(Debug, Clone, PartialEq, Eq, Hash)]
98pub struct CausalEdge {
99 pub from: usize,
101 pub to: usize,
103 pub edge_type: EdgeType,
105}
106
107#[derive(Debug, Clone)]
109pub struct CausalGraph {
110 pub nodes: usize,
112 pub edges: Vec<CausalEdge>,
114 pub separation_sets: HashMap<(usize, usize), Vec<usize>>,
117}
118
119impl CausalGraph {
120 pub fn has_directed_edge(&self, from: usize, to: usize) -> bool {
122 self.edges
123 .iter()
124 .any(|e| e.from == from && e.to == to && e.edge_type == EdgeType::Directed)
125 }
126
127 pub fn has_edge(&self, a: usize, b: usize) -> bool {
129 self.edges
130 .iter()
131 .any(|e| (e.from == a && e.to == b) || (e.from == b && e.to == a))
132 }
133
134 pub fn count_edges(&self, edge_type: EdgeType) -> usize {
136 self.edges
137 .iter()
138 .filter(|e| e.edge_type == edge_type)
139 .count()
140 }
141}
142
143#[derive(Debug, Clone)]
149pub struct PCAlgorithm {
150 config: PCConfig,
151}
152
153impl PCAlgorithm {
154 pub fn new(config: PCConfig) -> Self {
156 Self { config }
157 }
158
159 pub fn discover(&self, data: &[Vec<f64>]) -> CausalityResult<CausalGraph> {
168 let n_samples = data.len();
169 if n_samples < 4 {
170 return Err(TimeSeriesError::InsufficientData {
171 message: "Need at least 4 samples for PC algorithm".to_string(),
172 required: 4,
173 actual: n_samples,
174 });
175 }
176
177 let n_vars = data[0].len();
178 if n_vars < 2 {
179 return Err(TimeSeriesError::InvalidInput(
180 "Need at least 2 variables for causal discovery".to_string(),
181 ));
182 }
183
184 for (i, sample) in data.iter().enumerate() {
186 if sample.len() != n_vars {
187 return Err(TimeSeriesError::DimensionMismatch {
188 expected: n_vars,
189 actual: sample.len(),
190 });
191 }
192 for &v in sample {
194 if !v.is_finite() {
195 return Err(TimeSeriesError::InvalidInput(format!(
196 "Non-finite value in sample {}",
197 i
198 )));
199 }
200 }
201 }
202
203 let cov_matrix = compute_covariance_matrix(data)?;
205
206 let (adjacency, separation_sets) =
208 self.discover_skeleton(n_vars, n_samples, &cov_matrix)?;
209
210 let mut edge_types = self.orient_v_structures(n_vars, &adjacency, &separation_sets);
212
213 self.apply_meek_rules(n_vars, &adjacency, &mut edge_types);
215
216 let mut edges = Vec::new();
218 for i in 0..n_vars {
219 for j in (i + 1)..n_vars {
220 if adjacency[i].contains(&j) {
221 let key = (i, j);
222 let et = edge_types
223 .get(&key)
224 .copied()
225 .unwrap_or(EdgeType::Undirected);
226 match et {
227 EdgeType::Directed => {
228 if let Some(&dir) = edge_types.get(&(i, j)) {
231 if dir == EdgeType::Directed {
232 edges.push(CausalEdge {
233 from: i,
234 to: j,
235 edge_type: EdgeType::Directed,
236 });
237 }
238 }
239 }
240 _ => {
241 edges.push(CausalEdge {
242 from: i,
243 to: j,
244 edge_type: et,
245 });
246 }
247 }
248 }
249 }
250 }
251
252 for (&(from, to), &et) in &edge_types {
254 if et == EdgeType::Directed && from > to {
255 edges.push(CausalEdge {
257 from,
258 to,
259 edge_type: EdgeType::Directed,
260 });
261 }
262 }
263
264 let mut seen = HashSet::new();
266 let deduped: Vec<CausalEdge> = edges
267 .into_iter()
268 .filter(|e| {
269 let key = (e.from, e.to, e.edge_type);
270 seen.insert(key)
271 })
272 .collect();
273
274 Ok(CausalGraph {
275 nodes: n_vars,
276 edges: deduped,
277 separation_sets,
278 })
279 }
280
281 fn discover_skeleton(
284 &self,
285 n_vars: usize,
286 n_samples: usize,
287 cov_matrix: &[Vec<f64>],
288 ) -> CausalityResult<(Vec<HashSet<usize>>, HashMap<(usize, usize), Vec<usize>>)> {
289 let mut adjacency: Vec<HashSet<usize>> = (0..n_vars)
291 .map(|i| (0..n_vars).filter(|&j| j != i).collect())
292 .collect();
293
294 let mut separation_sets: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
295
296 let mut p = 0usize;
297 loop {
298 if p > self.config.max_cond_set_size {
299 break;
300 }
301
302 let mut any_testable = false;
303 let mut removals: Vec<(usize, usize, Vec<usize>)> = Vec::new();
304
305 let adj_snapshot: Vec<Vec<usize>> = adjacency
307 .iter()
308 .map(|s| {
309 let mut v: Vec<usize> = s.iter().copied().collect();
310 v.sort();
311 v
312 })
313 .collect();
314
315 for i in 0..n_vars {
316 let neighbors_i = &adj_snapshot[i];
317 for &j in neighbors_i {
318 if j <= i {
319 continue; }
321
322 let cond_candidates: Vec<usize> =
324 neighbors_i.iter().copied().filter(|&k| k != j).collect();
325
326 if cond_candidates.len() < p {
327 continue;
328 }
329 any_testable = true;
330
331 let subsets = gen_combinations(&cond_candidates, p);
333 let mut found_independent = false;
334 let mut best_sep = Vec::new();
335
336 for subset in &subsets {
337 let p_value = self
338 .test_conditional_independence(i, j, subset, n_samples, cov_matrix)?;
339
340 if p_value > self.config.significance_level {
341 found_independent = true;
342 best_sep = subset.clone();
343 break;
344 }
345 }
346
347 if found_independent {
348 removals.push((i, j, best_sep));
349 }
350 }
351 }
352
353 for (i, j, sep_set) in removals {
355 adjacency[i].remove(&j);
356 adjacency[j].remove(&i);
357 let key = if i < j { (i, j) } else { (j, i) };
358 separation_sets.insert(key, sep_set);
359 }
360
361 if !any_testable {
362 break;
363 }
364 p += 1;
365 }
366
367 Ok((adjacency, separation_sets))
368 }
369
370 fn orient_v_structures(
373 &self,
374 n_vars: usize,
375 adjacency: &[HashSet<usize>],
376 separation_sets: &HashMap<(usize, usize), Vec<usize>>,
377 ) -> HashMap<(usize, usize), EdgeType> {
378 let mut edge_types: HashMap<(usize, usize), EdgeType> = HashMap::new();
379
380 for i in 0..n_vars {
382 for &j in &adjacency[i] {
383 if j > i {
384 edge_types.insert((i, j), EdgeType::Undirected);
385 }
386 }
387 }
388
389 for z in 0..n_vars {
391 let neighbors_z: Vec<usize> = adjacency[z].iter().copied().collect();
392 for idx_x in 0..neighbors_z.len() {
393 for idx_y in (idx_x + 1)..neighbors_z.len() {
394 let x = neighbors_z[idx_x];
395 let y = neighbors_z[idx_y];
396
397 if adjacency[x].contains(&y) {
399 continue;
400 }
401
402 let key = if x < y { (x, y) } else { (y, x) };
404 let sep_set = separation_sets.get(&key);
405
406 let z_in_sep = sep_set.map(|s| s.contains(&z)).unwrap_or(false);
408
409 if !z_in_sep {
410 edge_types.insert((x, z), EdgeType::Directed);
412 edge_types.insert((y, z), EdgeType::Directed);
414 let k1 = if x < z { (x, z) } else { (z, x) };
416 let k2 = if y < z { (y, z) } else { (z, y) };
417 edge_types.remove(&k1);
418 edge_types.remove(&k2);
419 edge_types.insert((x, z), EdgeType::Directed);
420 edge_types.insert((y, z), EdgeType::Directed);
421 }
422 }
423 }
424 }
425
426 edge_types
427 }
428
429 fn apply_meek_rules(
432 &self,
433 n_vars: usize,
434 adjacency: &[HashSet<usize>],
435 edge_types: &mut HashMap<(usize, usize), EdgeType>,
436 ) {
437 let max_iterations = n_vars * n_vars;
439 for _ in 0..max_iterations {
440 let mut changed = false;
441
442 for b in 0..n_vars {
444 let neighbors_b: Vec<usize> = adjacency[b].iter().copied().collect();
445 for &c in &neighbors_b {
446 if !is_undirected(edge_types, b, c) {
448 continue;
449 }
450
451 for &a in &neighbors_b {
452 if a == c {
453 continue;
454 }
455 if !is_directed(edge_types, a, b) {
457 continue;
458 }
459 if adjacency[a].contains(&c) {
461 continue;
462 }
463
464 orient_edge(edge_types, b, c);
466 changed = true;
467 }
468 }
469 }
470
471 for a in 0..n_vars {
473 let neighbors_a: Vec<usize> = adjacency[a].iter().copied().collect();
474 for &b in &neighbors_a {
475 if !is_undirected(edge_types, a, b) {
476 continue;
477 }
478
479 for &c in &neighbors_a {
481 if c == b {
482 continue;
483 }
484 if !is_directed(edge_types, a, c) {
485 continue;
486 }
487 if !adjacency[c].contains(&b) {
488 continue;
489 }
490 if !is_directed(edge_types, c, b) {
491 continue;
492 }
493
494 orient_edge(edge_types, a, b);
495 changed = true;
496 }
497 }
498 }
499
500 for a in 0..n_vars {
503 let neighbors_a: Vec<usize> = adjacency[a].iter().copied().collect();
504 for &b in &neighbors_a {
505 if !is_undirected(edge_types, a, b) {
506 continue;
507 }
508
509 let mut oriented = false;
511 for idx_c in 0..neighbors_a.len() {
512 if oriented {
513 break;
514 }
515 let c = neighbors_a[idx_c];
516 if c == b {
517 continue;
518 }
519 if !is_undirected(edge_types, a, c) {
520 continue;
521 }
522 if !adjacency[c].contains(&b) || !is_directed(edge_types, c, b) {
523 continue;
524 }
525
526 for idx_d in (idx_c + 1)..neighbors_a.len() {
527 let d = neighbors_a[idx_d];
528 if d == b || d == c {
529 continue;
530 }
531 if !is_undirected(edge_types, a, d) {
532 continue;
533 }
534 if !adjacency[d].contains(&b) || !is_directed(edge_types, d, b) {
535 continue;
536 }
537 if adjacency[c].contains(&d) {
539 continue;
540 }
541
542 orient_edge(edge_types, a, b);
543 changed = true;
544 oriented = true;
545 break;
546 }
547 }
548 }
549 }
550
551 if !changed {
552 break;
553 }
554 }
555 }
556
557 fn test_conditional_independence(
560 &self,
561 i: usize,
562 j: usize,
563 cond_set: &[usize],
564 n_samples: usize,
565 cov_matrix: &[Vec<f64>],
566 ) -> CausalityResult<f64> {
567 match self.config.test_type {
568 IndependenceTest::PartialCorrelation => {
569 partial_correlation_test(i, j, cond_set, n_samples, cov_matrix)
570 }
571 IndependenceTest::MutualInformation => {
572 mutual_information_test(i, j, cond_set, n_samples, cov_matrix)
573 }
574 IndependenceTest::KernelBased => {
575 partial_correlation_test(i, j, cond_set, n_samples, cov_matrix)
578 }
579 }
580 }
581}
582
583fn partial_correlation_test(
592 i: usize,
593 j: usize,
594 cond_set: &[usize],
595 n_samples: usize,
596 cov_matrix: &[Vec<f64>],
597) -> CausalityResult<f64> {
598 let parcorr = compute_partial_corr(i, j, cond_set, cov_matrix)?;
599
600 let df = n_samples as f64 - cond_set.len() as f64 - 2.0;
602 if df < 1.0 {
603 return Ok(1.0); }
605
606 let clamped = parcorr.clamp(-0.9999, 0.9999);
607 let z_stat = 0.5 * ((1.0 + clamped) / (1.0 - clamped)).ln() * df.sqrt();
608
609 let p_value = 2.0 * (1.0 - normal_cdf(z_stat.abs()));
611 Ok(p_value)
612}
613
614fn mutual_information_test(
619 i: usize,
620 j: usize,
621 cond_set: &[usize],
622 n_samples: usize,
623 cov_matrix: &[Vec<f64>],
624) -> CausalityResult<f64> {
625 let parcorr = compute_partial_corr(i, j, cond_set, cov_matrix)?;
626
627 let r_sq = parcorr * parcorr;
628 let mi = if r_sq < 1.0 {
629 -0.5 * (1.0 - r_sq).ln()
630 } else {
631 f64::INFINITY
632 };
633
634 let test_stat = 2.0 * n_samples as f64 * mi;
635 let p_value = chi_squared_p_value_1df(test_stat);
636 Ok(p_value)
637}
638
639fn compute_partial_corr(
644 i: usize,
645 j: usize,
646 cond_set: &[usize],
647 cov_matrix: &[Vec<f64>],
648) -> CausalityResult<f64> {
649 if cond_set.is_empty() {
650 let var_i = cov_matrix[i][i];
652 let var_j = cov_matrix[j][j];
653 let denom = (var_i * var_j).sqrt();
654 if denom < 1e-15 {
655 return Ok(0.0);
656 }
657 return Ok(cov_matrix[i][j] / denom);
658 }
659
660 let mut indices = vec![i, j];
662 indices.extend_from_slice(cond_set);
663 let k = indices.len();
664
665 let mut sub_cov = vec![vec![0.0; k]; k];
666 for (a_idx, &a) in indices.iter().enumerate() {
667 for (b_idx, &b) in indices.iter().enumerate() {
668 sub_cov[a_idx][b_idx] = cov_matrix[a][b];
669 }
670 }
671
672 for idx in 0..k {
674 sub_cov[idx][idx] += 1e-10;
675 }
676
677 let precision = invert_small_matrix(&sub_cov)?;
679
680 let denom = (precision[0][0] * precision[1][1]).sqrt();
681 if denom < 1e-15 {
682 return Ok(0.0);
683 }
684
685 Ok(-precision[0][1] / denom)
686}
687
688fn compute_covariance_matrix(data: &[Vec<f64>]) -> CausalityResult<Vec<Vec<f64>>> {
694 let n = data.len();
695 let p = data[0].len();
696
697 let mut means = vec![0.0; p];
699 for sample in data {
700 for (j, &v) in sample.iter().enumerate() {
701 means[j] += v;
702 }
703 }
704 for m in &mut means {
705 *m /= n as f64;
706 }
707
708 let mut cov = vec![vec![0.0; p]; p];
710 for sample in data {
711 for a in 0..p {
712 let da = sample[a] - means[a];
713 for b in a..p {
714 let db = sample[b] - means[b];
715 cov[a][b] += da * db;
716 }
717 }
718 }
719
720 let denom = (n as f64 - 1.0).max(1.0);
721 for a in 0..p {
722 for b in a..p {
723 cov[a][b] /= denom;
724 cov[b][a] = cov[a][b];
725 }
726 }
727
728 Ok(cov)
729}
730
731fn invert_small_matrix(mat: &[Vec<f64>]) -> CausalityResult<Vec<Vec<f64>>> {
733 let n = mat.len();
734 let mut augmented = vec![vec![0.0; 2 * n]; n];
735
736 for i in 0..n {
737 for j in 0..n {
738 augmented[i][j] = mat[i][j];
739 }
740 augmented[i][n + i] = 1.0;
741 }
742
743 for col in 0..n {
744 let mut max_val = augmented[col][col].abs();
745 let mut max_row = col;
746 for row in (col + 1)..n {
747 let val = augmented[row][col].abs();
748 if val > max_val {
749 max_val = val;
750 max_row = row;
751 }
752 }
753
754 if max_val < 1e-14 {
755 return Err(TimeSeriesError::NumericalInstability(
756 "Singular matrix in partial correlation computation".to_string(),
757 ));
758 }
759
760 if max_row != col {
761 augmented.swap(col, max_row);
762 }
763
764 let pivot = augmented[col][col];
765 for j in 0..(2 * n) {
766 augmented[col][j] /= pivot;
767 }
768
769 for row in 0..n {
770 if row != col {
771 let factor = augmented[row][col];
772 for j in 0..(2 * n) {
773 augmented[row][j] -= factor * augmented[col][j];
774 }
775 }
776 }
777 }
778
779 let mut inv = vec![vec![0.0; n]; n];
780 for i in 0..n {
781 for j in 0..n {
782 inv[i][j] = augmented[i][n + j];
783 }
784 }
785
786 Ok(inv)
787}
788
789fn is_directed(edge_types: &HashMap<(usize, usize), EdgeType>, from: usize, to: usize) -> bool {
794 edge_types
795 .get(&(from, to))
796 .map(|&et| et == EdgeType::Directed)
797 .unwrap_or(false)
798}
799
800fn is_undirected(edge_types: &HashMap<(usize, usize), EdgeType>, a: usize, b: usize) -> bool {
801 let k1 = (a, b);
802 let k2 = (b, a);
803 let k_canon = if a < b { (a, b) } else { (b, a) };
804
805 if is_directed(edge_types, a, b) || is_directed(edge_types, b, a) {
807 return false;
808 }
809
810 edge_types
812 .get(&k_canon)
813 .map(|&et| et == EdgeType::Undirected)
814 .unwrap_or(false)
815 || edge_types
816 .get(&k1)
817 .map(|&et| et == EdgeType::Undirected)
818 .unwrap_or(false)
819 || edge_types
820 .get(&k2)
821 .map(|&et| et == EdgeType::Undirected)
822 .unwrap_or(false)
823}
824
825fn orient_edge(edge_types: &mut HashMap<(usize, usize), EdgeType>, from: usize, to: usize) {
826 let k_canon = if from < to { (from, to) } else { (to, from) };
828 edge_types.remove(&k_canon);
829 edge_types.remove(&(to, from));
830 edge_types.remove(&(from, to));
831 edge_types.insert((from, to), EdgeType::Directed);
833}
834
835fn gen_combinations(items: &[usize], k: usize) -> Vec<Vec<usize>> {
840 if k == 0 {
841 return vec![vec![]];
842 }
843 if k > items.len() {
844 return vec![];
845 }
846 if k == items.len() {
847 return vec![items.to_vec()];
848 }
849
850 let mut result = Vec::new();
851 gen_combinations_rec(items, k, 0, &mut vec![], &mut result);
852 result
853}
854
/// Recursive worker for combination generation.
///
/// Extends the partial combination `current` with elements of `items` at
/// positions `start..`, pushing every completed `k`-element combination onto
/// `result`. `current` is restored to its incoming contents before returning.
fn gen_combinations_rec(
    items: &[usize],
    k: usize,
    start: usize,
    current: &mut Vec<usize>,
    result: &mut Vec<Vec<usize>>,
) {
    if current.len() == k {
        result.push(current.to_vec());
        return;
    }

    // Prune: not enough elements left to complete the combination.
    let still_needed = k - current.len();
    let left_over = items.len() - start;
    if left_over < still_needed {
        return;
    }

    for (position, &item) in items.iter().enumerate().skip(start) {
        current.push(item);
        gen_combinations_rec(items, k, position + 1, current, result);
        current.pop();
    }
}
877
878fn normal_cdf(x: f64) -> f64 {
883 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
884}
885
/// Polynomial approximation of the error function.
///
/// Evaluates `1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2)` with
/// `t = 1 / (1 + p*|x|)` and mirrors the result for negative `x`.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.327_591_1;
    // Coefficients a5 down to a1, ordered for Horner evaluation.
    const COEFFS: [f64; 5] = [
        1.061_405_429,
        -1.453_152_027,
        1.421_413_741,
        -0.284_496_736,
        0.254_829_592,
    ];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let poly = COEFFS[1..].iter().fold(COEFFS[0], |acc, &c| acc * t + c);
    let y = 1.0 - poly * t * (-magnitude * magnitude).exp();

    if x < 0.0 {
        -y
    } else {
        y
    }
}
900
901fn chi_squared_p_value_1df(chi2: f64) -> f64 {
903 if chi2 <= 0.0 {
904 return 1.0;
905 }
906 2.0 * (1.0 - normal_cdf(chi2.sqrt()))
909}
910
911pub fn fisher_z_transform(r: f64, n: usize, cond_size: usize) -> (f64, f64) {
921 let clamped = r.clamp(-0.9999, 0.9999);
922 let z = 0.5 * ((1.0 + clamped) / (1.0 - clamped)).ln();
923 let df = n as f64 - cond_size as f64 - 3.0;
924 let z_stat = if df > 0.0 { z * df.sqrt() } else { 0.0 };
925 let p_value = if df > 0.0 {
926 2.0 * (1.0 - normal_cdf(z_stat.abs()))
927 } else {
928 1.0
929 };
930 (z_stat, p_value)
931}
932
#[cfg(test)]
mod tests {
    use super::*;

    /// Advances a 64-bit linear congruential generator and maps the upper 32
    /// bits of the new state to a value in `[-0.5, 0.5]`.
    fn next_rand(state: &mut u64) -> f64 {
        const MULTIPLIER: u64 = 6364136223846793005;
        const INCREMENT: u64 = 1442695040888963407;

        *state = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        let high_bits = *state >> 32;
        (high_bits as f64) / (u32::MAX as f64) - 0.5
    }

    /// Builds `n` rows by repeatedly calling `row` with a generator state that
    /// starts at `seed`.
    fn sample_rows(
        n: usize,
        seed: u64,
        mut row: impl FnMut(&mut u64) -> Vec<f64>,
    ) -> Vec<Vec<f64>> {
        let mut state = seed;
        (0..n).map(|_| row(&mut state)).collect()
    }

    /// Chain `X -> Y -> Z`.
    fn generate_chain(n: usize, seed: u64) -> Vec<Vec<f64>> {
        sample_rows(n, seed, |state| {
            let x = next_rand(state);
            let y = 0.8 * x + next_rand(state) * 0.3;
            let z = 0.8 * y + next_rand(state) * 0.3;
            vec![x, y, z]
        })
    }

    /// Collider `X -> Z <- Y` with independent `X` and `Y`.
    fn generate_v_structure(n: usize, seed: u64) -> Vec<Vec<f64>> {
        sample_rows(n, seed, |state| {
            let x = next_rand(state);
            let y = next_rand(state);
            let z = 0.7 * x + 0.7 * y + next_rand(state) * 0.2;
            vec![x, y, z]
        })
    }

    /// Three mutually independent variables.
    fn generate_independent(n: usize, seed: u64) -> Vec<Vec<f64>> {
        sample_rows(n, seed, |state| {
            let x = next_rand(state);
            let y = next_rand(state);
            let z = next_rand(state);
            vec![x, y, z]
        })
    }

    /// Runs PC with the partial-correlation test at significance 0.05 and the
    /// given maximum conditioning-set size.
    fn run_pc(data: &[Vec<f64>], max_cond_set_size: usize) -> CausalGraph {
        let config = PCConfig {
            significance_level: 0.05,
            max_cond_set_size,
            test_type: IndependenceTest::PartialCorrelation,
        };
        PCAlgorithm::new(config).discover(data).expect("discovery")
    }

    #[test]
    fn test_pc_config_default() {
        let defaults = PCConfig::default();
        assert!((defaults.significance_level - 0.05).abs() < 1e-10);
        assert_eq!(defaults.max_cond_set_size, 4);
        assert_eq!(defaults.test_type, IndependenceTest::PartialCorrelation);
    }

    #[test]
    fn test_independent_variables_no_edge() {
        let graph = run_pc(&generate_independent(500, 42), 2);
        let edge_count = graph.edges.len();
        assert!(
            edge_count <= 1,
            "Independent vars should have ~0 edges, got {}",
            edge_count
        );
    }

    #[test]
    fn test_chain_skeleton_discovered() {
        let graph = run_pc(&generate_chain(1000, 123), 2);

        assert!(graph.has_edge(0, 1), "Should have X-Y edge in chain");
        assert!(graph.has_edge(1, 2), "Should have Y-Z edge in chain");
        assert!(
            !graph.has_edge(0, 2),
            "Should NOT have X-Z direct edge in chain"
        );
    }

    #[test]
    fn test_v_structure_orientation() {
        let graph = run_pc(&generate_v_structure(1000, 456), 2);

        assert!(
            !graph.has_edge(0, 1),
            "X and Y should not be adjacent in v-structure"
        );

        assert!(graph.has_edge(0, 2), "Should have X-Z edge");
        assert!(graph.has_edge(1, 2), "Should have Y-Z edge");

        assert!(
            graph.has_directed_edge(0, 2),
            "Should orient X -> Z in v-structure"
        );
        assert!(
            graph.has_directed_edge(1, 2),
            "Should orient Y -> Z in v-structure"
        );
    }

    #[test]
    fn test_causal_graph_node_edge_counts() {
        let data = generate_chain(500, 789);
        let graph = PCAlgorithm::new(PCConfig::default())
            .discover(&data)
            .expect("discovery");

        assert_eq!(graph.nodes, 3);
        let edge_count = graph.edges.len();
        assert!(
            edge_count >= 2,
            "Chain should have at least 2 edges, got {}",
            edge_count
        );
    }

    #[test]
    fn test_partial_correlation_independent() {
        let cov = compute_covariance_matrix(&generate_independent(500, 999)).expect("cov");
        let parcorr = compute_partial_corr(0, 1, &[], &cov).expect("parcorr");
        assert!(
            parcorr.abs() < 0.15,
            "Independent vars should have near-zero partial corr, got {}",
            parcorr
        );
    }

    #[test]
    fn test_partial_correlation_dependent() {
        let cov = compute_covariance_matrix(&generate_chain(500, 111)).expect("cov");
        let parcorr = compute_partial_corr(0, 1, &[], &cov).expect("parcorr");
        assert!(
            parcorr.abs() > 0.3,
            "Dependent vars should have significant partial corr, got {}",
            parcorr
        );
    }

    #[test]
    fn test_partial_correlation_conditional_independence() {
        let cov = compute_covariance_matrix(&generate_chain(1000, 222)).expect("cov");
        let parcorr_xz_given_y = compute_partial_corr(0, 2, &[1], &cov).expect("parcorr");
        assert!(
            parcorr_xz_given_y.abs() < 0.15,
            "X⊥Z|Y should hold in chain, parcorr={}",
            parcorr_xz_given_y
        );
    }

    #[test]
    fn test_fisher_z_transform_correct_pvalue() {
        let (_, p) = fisher_z_transform(0.0, 100, 0);
        assert!(
            (p - 1.0).abs() < 0.01,
            "Zero correlation should give p≈1, got {}",
            p
        );

        let (_, p2) = fisher_z_transform(0.9, 100, 0);
        assert!(
            p2 < 0.01,
            "Strong correlation should give small p-value, got {}",
            p2
        );
    }

    #[test]
    fn test_mutual_information_test() {
        let cov = compute_covariance_matrix(&generate_chain(500, 333)).expect("cov");
        let p = mutual_information_test(0, 1, &[], 500, &cov).expect("mi");
        assert!(
            p < 0.05,
            "MI test should detect dependence in chain, p={}",
            p
        );
    }

    #[test]
    fn test_pc_insufficient_data() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let outcome = PCAlgorithm::new(PCConfig::default()).discover(&data);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_edge_type_non_exhaustive() {
        // Only checks that every variant can be named.
        let _d = EdgeType::Directed;
        let _u = EdgeType::Undirected;
        let _b = EdgeType::Bidirected;
    }

    #[test]
    fn test_known_graph_recovery_synthetic() {
        // Diamond: X0 -> X1, X0 -> X2, X1 -> X3, X2 -> X3.
        let data = sample_rows(2000, 42, |state| {
            let x0 = next_rand(state);
            let x1 = 0.8 * x0 + next_rand(state) * 0.2;
            let x2 = 0.8 * x0 + next_rand(state) * 0.2;
            let x3 = 0.5 * x1 + 0.5 * x2 + next_rand(state) * 0.2;
            vec![x0, x1, x2, x3]
        });

        let graph = run_pc(&data, 3);

        assert_eq!(graph.nodes, 4);

        assert!(graph.has_edge(0, 1), "Should have X0-X1 edge");
        assert!(graph.has_edge(0, 2), "Should have X0-X2 edge");
        assert!(graph.has_edge(1, 3), "Should have X1-X3 edge");
        assert!(graph.has_edge(2, 3), "Should have X2-X3 edge");
    }
}