1use std::collections::HashMap;
36
37use scirs2_core::ndarray::ArrayView2;
38
39use super::conditional_independence::{ConditionalIndependenceTest, PartialCorrelationTest};
40use super::{CausalGraph, EdgeMark};
41use crate::error::{StatsError, StatsResult};
42
/// Configuration of the PC causal-discovery algorithm.
///
/// The fields are read by [`PcAlgorithm::fit`] / [`PcAlgorithm::fit_with_test`];
/// the struct itself holds no data-dependent state.
#[derive(Clone, Debug)]
pub struct PcAlgorithm {
    /// Significance level handed to the conditional-independence test.
    pub alpha: f64,
    /// Largest conditioning-set size explored during the skeleton search.
    pub max_cond_set_size: usize,
    /// `true` selects the order-independent ("stable") skeleton search,
    /// `false` the variant that removes edges as soon as they are found.
    pub stable: bool,
}
57
58impl Default for PcAlgorithm {
59 fn default() -> Self {
60 Self {
61 alpha: 0.05,
62 max_cond_set_size: 3,
63 stable: true,
64 }
65 }
66}
67
68#[derive(Debug, Clone)]
70pub struct PcResult {
71 pub graph: CausalGraph,
73 pub sep_sets: HashMap<(usize, usize), Vec<usize>>,
75 pub n_tests: usize,
77}
78
79impl PcAlgorithm {
80 pub fn new(alpha: f64) -> Self {
82 Self {
83 alpha,
84 ..Default::default()
85 }
86 }
87
88 pub fn with_params(alpha: f64, max_cond_set_size: usize, stable: bool) -> Self {
90 Self {
91 alpha,
92 max_cond_set_size,
93 stable,
94 }
95 }
96
97 pub fn fit(&self, data: ArrayView2<f64>, var_names: &[&str]) -> StatsResult<PcResult> {
99 let ci_test = PartialCorrelationTest::new(self.alpha);
100 self.fit_with_test(data, var_names, &ci_test)
101 }
102
103 pub fn fit_with_test<T: ConditionalIndependenceTest>(
105 &self,
106 data: ArrayView2<f64>,
107 var_names: &[&str],
108 ci_test: &T,
109 ) -> StatsResult<PcResult> {
110 let p = data.ncols();
111 if var_names.len() != p {
112 return Err(StatsError::DimensionMismatch(
113 "var_names length must match number of columns".to_owned(),
114 ));
115 }
116 if p == 0 {
117 return Ok(PcResult {
118 graph: CausalGraph::new(var_names),
119 sep_sets: HashMap::new(),
120 n_tests: 0,
121 });
122 }
123
124 let (adj, sep_sets, n_tests) = if self.stable {
126 self.skeleton_stable(data, p, ci_test)?
127 } else {
128 self.skeleton_standard(data, p, ci_test)?
129 };
130
131 let mut graph = CausalGraph::new(var_names);
133 for i in 0..p {
135 for j in (i + 1)..p {
136 if adj[i][j] {
137 graph.set_edge(i, j, EdgeMark::Tail, EdgeMark::Tail);
138 }
139 }
140 }
141
142 orient_v_structures(&mut graph, &adj, &sep_sets, p);
143
144 apply_meek_rules(&mut graph, p);
146
147 Ok(PcResult {
148 graph,
149 sep_sets,
150 n_tests,
151 })
152 }
153
154 fn skeleton_standard<T: ConditionalIndependenceTest>(
156 &self,
157 data: ArrayView2<f64>,
158 p: usize,
159 ci_test: &T,
160 ) -> StatsResult<(Vec<Vec<bool>>, HashMap<(usize, usize), Vec<usize>>, usize)> {
161 let mut adj = vec![vec![true; p]; p];
162 for i in 0..p {
163 adj[i][i] = false;
164 }
165 let mut sep_sets: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
166 let mut n_tests = 0usize;
167
168 for ord in 0..=self.max_cond_set_size {
169 let edges: Vec<(usize, usize)> = (0..p)
170 .flat_map(|i| ((i + 1)..p).map(move |j| (i, j)))
171 .filter(|&(i, j)| adj[i][j])
172 .collect();
173
174 for (x, y) in edges {
175 let z_candidates: Vec<usize> =
176 (0..p).filter(|&k| k != x && k != y && adj[x][k]).collect();
177 if z_candidates.len() < ord {
178 continue;
179 }
180
181 for z_set in subsets(&z_candidates, ord) {
182 n_tests += 1;
183 if ci_test.is_independent(x, y, &z_set, data, self.alpha)? {
184 adj[x][y] = false;
185 adj[y][x] = false;
186 let key = (x.min(y), x.max(y));
187 sep_sets.insert(key, z_set);
188 break;
189 }
190 }
191 }
192 }
193
194 Ok((adj, sep_sets, n_tests))
195 }
196
197 fn skeleton_stable<T: ConditionalIndependenceTest>(
202 &self,
203 data: ArrayView2<f64>,
204 p: usize,
205 ci_test: &T,
206 ) -> StatsResult<(Vec<Vec<bool>>, HashMap<(usize, usize), Vec<usize>>, usize)> {
207 let mut adj = vec![vec![true; p]; p];
208 for i in 0..p {
209 adj[i][i] = false;
210 }
211 let mut sep_sets: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
212 let mut n_tests = 0usize;
213
214 for ord in 0..=self.max_cond_set_size {
215 let adj_snapshot = adj.clone();
217
218 let edges: Vec<(usize, usize)> = (0..p)
219 .flat_map(|i| ((i + 1)..p).map(move |j| (i, j)))
220 .filter(|&(i, j)| adj_snapshot[i][j])
221 .collect();
222
223 let mut removals: Vec<(usize, usize, Vec<usize>)> = Vec::new();
225
226 for (x, y) in edges {
227 let z_candidates: Vec<usize> = (0..p)
229 .filter(|&k| k != x && k != y && adj_snapshot[x][k])
230 .collect();
231 if z_candidates.len() < ord {
232 continue;
233 }
234
235 let z_candidates_y: Vec<usize> = (0..p)
237 .filter(|&k| k != x && k != y && adj_snapshot[y][k])
238 .collect();
239
240 let mut found = false;
241 for z_set in subsets(&z_candidates, ord) {
243 n_tests += 1;
244 if ci_test.is_independent(x, y, &z_set, data, self.alpha)? {
245 removals.push((x, y, z_set));
246 found = true;
247 break;
248 }
249 }
250 if found {
251 continue;
252 }
253 if z_candidates_y.len() >= ord {
255 for z_set in subsets(&z_candidates_y, ord) {
256 n_tests += 1;
258 if ci_test.is_independent(x, y, &z_set, data, self.alpha)? {
259 removals.push((x, y, z_set));
260 break;
261 }
262 }
263 }
264 }
265
266 for (x, y, z_set) in removals {
268 adj[x][y] = false;
269 adj[y][x] = false;
270 let key = (x.min(y), x.max(y));
271 sep_sets.insert(key, z_set);
272 }
273 }
274
275 Ok((adj, sep_sets, n_tests))
276 }
277}
278
279fn orient_v_structures(
286 graph: &mut CausalGraph,
287 adj: &[Vec<bool>],
288 sep_sets: &HashMap<(usize, usize), Vec<usize>>,
289 p: usize,
290) {
291 for z in 0..p {
292 let neighbours: Vec<usize> = (0..p).filter(|&k| k != z && adj[z][k]).collect();
293 for i in 0..neighbours.len() {
294 for j in (i + 1)..neighbours.len() {
295 let x = neighbours[i];
296 let y = neighbours[j];
297 if adj[x][y] {
299 continue;
300 }
301 let key = (x.min(y), x.max(y));
302 let sep = sep_sets.get(&key).cloned().unwrap_or_default();
303 if !sep.contains(&z) {
304 graph.set_edge(x, z, EdgeMark::Tail, EdgeMark::Arrow);
306 graph.set_edge(y, z, EdgeMark::Tail, EdgeMark::Arrow);
307 }
308 }
309 }
310 }
311}
312
313pub fn apply_meek_rules(graph: &mut CausalGraph, p: usize) {
325 let max_iterations = p * p + 10;
326 let mut changed = true;
327 let mut iterations = 0;
328
329 while changed && iterations < max_iterations {
330 changed = false;
331 iterations += 1;
332
333 changed |= meek_r1(graph, p);
335
336 changed |= meek_r2(graph, p);
338
339 changed |= meek_r3(graph, p);
341
342 changed |= meek_r4(graph, p);
345 }
346}
347
348fn meek_r1(graph: &mut CausalGraph, p: usize) -> bool {
350 let mut changed = false;
351 for b in 0..p {
352 for a in 0..p {
353 if a == b {
354 continue;
355 }
356 if !graph.is_directed(a, b) {
358 continue;
359 }
360 for c in 0..p {
361 if c == a || c == b {
362 continue;
363 }
364 if !graph.is_undirected(b, c) {
366 continue;
367 }
368 if graph.is_adjacent(a, c) {
370 continue;
371 }
372 graph.set_edge(b, c, EdgeMark::Tail, EdgeMark::Arrow);
374 changed = true;
375 }
376 }
377 }
378 changed
379}
380
381fn meek_r2(graph: &mut CausalGraph, p: usize) -> bool {
383 let mut changed = false;
384 for a in 0..p {
385 for b in 0..p {
386 if a == b {
387 continue;
388 }
389 if !graph.is_directed(a, b) {
390 continue;
391 }
392 for c in 0..p {
393 if c == a || c == b {
394 continue;
395 }
396 if !graph.is_directed(b, c) {
397 continue;
398 }
399 if !graph.is_undirected(a, c) {
400 continue;
401 }
402 graph.set_edge(a, c, EdgeMark::Tail, EdgeMark::Arrow);
403 changed = true;
404 }
405 }
406 }
407 changed
408}
409
410fn meek_r3(graph: &mut CausalGraph, p: usize) -> bool {
413 let mut changed = false;
414 for a in 0..p {
415 for d in 0..p {
416 if a == d {
417 continue;
418 }
419 if !graph.is_undirected(a, d) {
421 continue;
422 }
423 let parents_of_d: Vec<usize> = (0..p)
425 .filter(|&k| k != a && k != d && graph.is_directed(k, d))
426 .collect();
427 let mut orient = false;
428 for i in 0..parents_of_d.len() {
429 for j in (i + 1)..parents_of_d.len() {
430 let b = parents_of_d[i];
431 let c = parents_of_d[j];
432 if graph.is_undirected(a, b)
433 && graph.is_undirected(a, c)
434 && !graph.is_adjacent(b, c)
435 {
436 orient = true;
437 break;
438 }
439 }
440 if orient {
441 break;
442 }
443 }
444 if orient {
445 graph.set_edge(a, d, EdgeMark::Tail, EdgeMark::Arrow);
446 changed = true;
447 }
448 }
449 }
450 changed
451}
452
453fn meek_r4(graph: &mut CausalGraph, p: usize) -> bool {
460 let mut changed = false;
461 for a in 0..p {
462 for b in 0..p {
463 if a == b {
464 continue;
465 }
466 if !graph.is_undirected(a, b) {
467 continue;
468 }
469 for c in 0..p {
472 if c == a || c == b {
473 continue;
474 }
475 if !graph.is_undirected(a, c) {
476 continue;
477 }
478 if !graph.is_directed(b, c) {
479 continue;
480 }
481 if has_directed_path(graph, c, a, p) {
483 graph.set_edge(a, b, EdgeMark::Tail, EdgeMark::Arrow);
484 changed = true;
485 break;
486 }
487 }
488 }
489 }
490 changed
491}
492
493fn has_directed_path(graph: &CausalGraph, src: usize, dst: usize, p: usize) -> bool {
495 let mut visited = vec![false; p];
496 let mut stack = vec![src];
497 while let Some(cur) = stack.pop() {
498 if cur == dst {
499 return true;
500 }
501 if visited[cur] {
502 continue;
503 }
504 visited[cur] = true;
505 for next in 0..p {
506 if !visited[next] && graph.is_directed(cur, next) {
507 stack.push(next);
508 }
509 }
510 }
511 false
512}
513
/// Enumerates every `k`-element subset of `items`, preserving the relative
/// order of elements and listing subsets in lexicographic order of their
/// positions.
///
/// `k == 0` yields a single empty subset; `k > items.len()` yields nothing.
pub(crate) fn subsets<T: Copy>(items: &[T], k: usize) -> Vec<Vec<T>> {
    let n = items.len();
    if k > n {
        return Vec::new();
    }

    // `picks` holds the strictly increasing positions of the current subset.
    let mut picks: Vec<usize> = (0..k).collect();
    let mut out = Vec::new();

    loop {
        out.push(picks.iter().map(|&ix| items[ix]).collect());

        // Rightmost slot that has not yet reached its maximum position.
        let advance = (0..k).rev().find(|&slot| picks[slot] < n - k + slot);
        match advance {
            Some(slot) => {
                picks[slot] += 1;
                for later in (slot + 1)..k {
                    picks[later] = picks[later - 1] + 1;
                }
            }
            None => break,
        }
    }

    out
}
535
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Advances the 64-bit LCG state `s` and maps it to a value in `[0, 1)`.
    fn lcg_uniform(s: &mut u64) -> f64 {
        *s = s
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((*s >> 11) as f64) / ((1u64 << 53) as f64)
    }

    /// Draws one value via the Box-Muller transform of two LCG uniforms.
    fn lcg_normal(s: &mut u64) -> f64 {
        // Clamp away from zero so that `ln` stays finite.
        let u1 = lcg_uniform(s).max(1e-15);
        let u2 = lcg_uniform(s);
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }

    /// Chain: column 0 -> column 1 -> column 2.
    fn chain_data(n: usize, seed: u64) -> Array2<f64> {
        let mut data = Array2::<f64>::zeros((n, 3));
        let mut lcg = seed;
        for i in 0..n {
            data[[i, 0]] = lcg_normal(&mut lcg);
            data[[i, 1]] = 0.9 * data[[i, 0]] + lcg_normal(&mut lcg) * 0.3;
            data[[i, 2]] = 0.9 * data[[i, 1]] + lcg_normal(&mut lcg) * 0.3;
        }
        data
    }

    /// Fork: column 1 is a common cause of columns 0 and 2.
    fn fork_data(n: usize, seed: u64) -> Array2<f64> {
        let mut data = Array2::<f64>::zeros((n, 3));
        let mut lcg = seed;
        for i in 0..n {
            let y = lcg_normal(&mut lcg);
            data[[i, 0]] = 0.9 * y + lcg_normal(&mut lcg) * 0.3;
            data[[i, 1]] = y;
            data[[i, 2]] = 0.9 * y + lcg_normal(&mut lcg) * 0.3;
        }
        data
    }

    /// Collider: columns 0 and 2 are independent causes of column 1.
    fn collider_data(n: usize, seed: u64) -> Array2<f64> {
        let mut data = Array2::<f64>::zeros((n, 3));
        let mut lcg = seed;
        for i in 0..n {
            data[[i, 0]] = lcg_normal(&mut lcg);
            data[[i, 2]] = lcg_normal(&mut lcg);
            data[[i, 1]] = 0.7 * data[[i, 0]] + 0.7 * data[[i, 2]] + lcg_normal(&mut lcg) * 0.3;
        }
        data
    }

    #[test]
    fn test_pc_chain() {
        let data = chain_data(300, 12345);
        let pc = PcAlgorithm::new(0.05);
        let result = pc.fit(data.view(), &["X", "Y", "Z"]).expect("PC failed");
        assert!(
            result.graph.is_adjacent(0, 1),
            "X-Y should be adjacent in chain"
        );
        assert!(
            result.graph.is_adjacent(1, 2),
            "Y-Z should be adjacent in chain"
        );
        assert!(
            !result.graph.is_adjacent(0, 2),
            "X-Z should not be adjacent in chain"
        );
    }

    #[test]
    fn test_pc_fork() {
        let data = fork_data(300, 54321);
        let pc = PcAlgorithm::new(0.05);
        let result = pc.fit(data.view(), &["X", "Y", "Z"]).expect("PC failed");
        assert!(result.graph.is_adjacent(0, 1), "X-Y should be adjacent");
        assert!(result.graph.is_adjacent(1, 2), "Y-Z should be adjacent");
        assert!(
            !result.graph.is_adjacent(0, 2),
            "X-Z should not be adjacent given Y"
        );
    }

    #[test]
    fn test_pc_collider() {
        let data = collider_data(300, 99999);
        let pc = PcAlgorithm::new(0.05);
        let result = pc.fit(data.view(), &["X", "Y", "Z"]).expect("PC failed");
        assert!(result.graph.is_adjacent(0, 1), "X-Y should be adjacent");
        assert!(result.graph.is_adjacent(1, 2), "Y-Z should be adjacent");
        assert!(
            !result.graph.is_adjacent(0, 2),
            "X-Z should not be adjacent (independent causes)"
        );
        assert!(
            result.graph.is_directed(0, 1) || result.graph.is_directed(2, 1),
            "At least one edge should point into Y (v-structure)"
        );
    }

    #[test]
    fn test_pc_meek_r1() {
        let mut graph = CausalGraph::new(&["A", "B", "C"]);
        // A -> B
        graph.set_edge(0, 1, EdgeMark::Tail, EdgeMark::Arrow);
        // B - C, with A and C non-adjacent
        graph.set_edge(1, 2, EdgeMark::Tail, EdgeMark::Tail);
        apply_meek_rules(&mut graph, 3);

        assert!(graph.is_directed(1, 2), "R1: b -> c expected");
    }

    #[test]
    fn test_pc_meek_r2() {
        let mut graph = CausalGraph::new(&["A", "B", "C"]);
        // A -> B -> C plus A - C
        graph.set_edge(0, 1, EdgeMark::Tail, EdgeMark::Arrow);
        graph.set_edge(1, 2, EdgeMark::Tail, EdgeMark::Arrow);
        graph.set_edge(0, 2, EdgeMark::Tail, EdgeMark::Tail);
        apply_meek_rules(&mut graph, 3);

        assert!(graph.is_directed(0, 2), "R2: a -> c expected");
    }

    #[test]
    fn test_pc_meek_r3() {
        let mut graph = CausalGraph::new(&["A", "B", "C", "D"]);
        // A - D, B -> D, C -> D, A - B, A - C; B and C non-adjacent
        graph.set_edge(0, 3, EdgeMark::Tail, EdgeMark::Tail);
        graph.set_edge(1, 3, EdgeMark::Tail, EdgeMark::Arrow);
        graph.set_edge(2, 3, EdgeMark::Tail, EdgeMark::Arrow);
        graph.set_edge(0, 1, EdgeMark::Tail, EdgeMark::Tail);
        graph.set_edge(0, 2, EdgeMark::Tail, EdgeMark::Tail);
        apply_meek_rules(&mut graph, 4);

        assert!(graph.is_directed(0, 3), "R3: a -> d expected");
    }

    #[test]
    fn test_pc_stable_vs_standard() {
        let data = chain_data(200, 77777);
        let pc_stable = PcAlgorithm::with_params(0.05, 3, true);
        let pc_standard = PcAlgorithm::with_params(0.05, 3, false);
        let r1 = pc_stable
            .fit(data.view(), &["X", "Y", "Z"])
            .expect("stable failed");
        let r2 = pc_standard
            .fit(data.view(), &["X", "Y", "Z"])
            .expect("standard failed");
        assert_eq!(
            r1.graph.is_adjacent(0, 2),
            r2.graph.is_adjacent(0, 2),
            "Skeleton should match for simple structures"
        );
    }

    #[test]
    fn test_pc_sep_sets() {
        let data = chain_data(300, 12345);
        let pc = PcAlgorithm::new(0.05);
        let result = pc.fit(data.view(), &["X", "Y", "Z"]).expect("PC failed");
        // Same data as `test_pc_chain`, where the X-Z edge is removed, so a
        // separating set must have been recorded. Require it instead of
        // silently passing when the entry is missing.
        let sep = result
            .sep_sets
            .get(&(0, 2))
            .expect("a separating set should be recorded for the removed X-Z edge");
        assert!(sep.contains(&1), "Sep set for X-Z should contain Y");
    }

    #[test]
    fn test_subsets() {
        let items = vec![0, 1, 2, 3];
        let s0 = subsets(&items, 0);
        assert_eq!(s0.len(), 1);
        assert!(s0[0].is_empty());

        let s1 = subsets(&items, 1);
        assert_eq!(s1.len(), 4);

        let s2 = subsets(&items, 2);
        assert_eq!(s2.len(), 6);

        let s3 = subsets(&items, 3);
        assert_eq!(s3.len(), 4);

        let s4 = subsets(&items, 4);
        assert_eq!(s4.len(), 1);

        let s5 = subsets(&items, 5);
        assert!(s5.is_empty());
    }

    #[test]
    fn test_directed_path_detection() {
        let mut graph = CausalGraph::new(&["A", "B", "C", "D"]);
        // A -> B -> C -> D
        graph.set_edge(0, 1, EdgeMark::Tail, EdgeMark::Arrow);
        graph.set_edge(1, 2, EdgeMark::Tail, EdgeMark::Arrow);
        graph.set_edge(2, 3, EdgeMark::Tail, EdgeMark::Arrow);
        assert!(has_directed_path(&graph, 0, 3, 4), "A -> B -> C -> D");
        assert!(has_directed_path(&graph, 0, 2, 4), "A -> B -> C");
        assert!(!has_directed_path(&graph, 3, 0, 4), "No path D -> A");
        assert!(!has_directed_path(&graph, 1, 0, 4), "No path B -> A");
    }
}