Skip to main content

scirs2_stats/causal/
pc_algorithm.rs

1//! PC Algorithm for Causal Discovery
2//!
3//! The Peter-Clark (PC) algorithm is a constraint-based method that learns
4//! the Markov equivalence class of a DAG from observational data, represented
5//! as a Completed Partially Directed Acyclic Graph (CPDAG).
6//!
7//! # Phases
8//!
9//! 1. **Skeleton discovery** (PC-stable variant): start with a fully connected
10//!    undirected graph and iteratively remove edges when a conditional
11//!    independence is found, recording separation sets.
12//!
13//! 2. **V-structure orientation**: for each triple X - Z - Y where X and Y
14//!    are not adjacent, orient X -> Z <- Y if Z is not in sep(X, Y).
15//!
16//! 3. **Meek's rules** (R1-R4): propagate orientations to avoid new
17//!    v-structures and cycles.
18//!
19//! # PC-stable variant
20//!
21//! In the standard PC algorithm, edge removals during skeleton discovery can
22//! depend on the order in which edges are tested. The PC-stable variant
23//! (Colombo & Maathuis, 2014) fixes this by computing all removals at each
24//! conditioning-set size before applying them.
25//!
26//! # References
27//!
28//! - Spirtes, P., Glymour, C. & Scheines, R. (2000). *Causation, Prediction,
29//!   and Search* (2nd ed.). MIT Press.
30//! - Colombo, D. & Maathuis, M.H. (2014). Order-independent constraint-based
31//!   causal structure learning. *JMLR* 15, 3741-3782.
32//! - Meek, C. (1995). Causal inference and causal explanation with background
33//!   knowledge. *UAI 1995*, pp. 403-410.
34
35use std::collections::HashMap;
36
37use scirs2_core::ndarray::ArrayView2;
38
39use super::conditional_independence::{ConditionalIndependenceTest, PartialCorrelationTest};
40use super::{CausalGraph, EdgeMark};
41use crate::error::{StatsError, StatsResult};
42
43// ---------------------------------------------------------------------------
44// PC Algorithm
45// ---------------------------------------------------------------------------
46
/// Tunable settings for a run of the PC algorithm.
///
/// Build one with [`PcAlgorithm::new`], [`PcAlgorithm::with_params`] or
/// [`Default::default`], then call `fit` / `fit_with_test`.
#[derive(Clone, Debug)]
pub struct PcAlgorithm {
    /// Significance level (alpha) handed to the conditional-independence
    /// test. Defaults to `0.05`.
    pub alpha: f64,
    /// Largest conditioning-set size tried during skeleton discovery.
    /// Defaults to `3`.
    pub max_cond_set_size: usize,
    /// Selects the order-independent PC-stable skeleton phase when `true`
    /// (the default); the classic order-dependent phase otherwise.
    pub stable: bool,
}

impl Default for PcAlgorithm {
    /// `alpha = 0.05`, conditioning sets of up to 3 variables, PC-stable on.
    fn default() -> Self {
        PcAlgorithm {
            alpha: 0.05,
            max_cond_set_size: 3,
            stable: true,
        }
    }
}
67
68/// Result of the PC algorithm.
69#[derive(Debug, Clone)]
70pub struct PcResult {
71    /// The learned CPDAG.
72    pub graph: CausalGraph,
73    /// Separation sets: sep_sets[(i,j)] = conditioning set that made i and j independent.
74    pub sep_sets: HashMap<(usize, usize), Vec<usize>>,
75    /// Number of CI tests performed.
76    pub n_tests: usize,
77}
78
79impl PcAlgorithm {
80    /// Create a PC algorithm with the given significance level.
81    pub fn new(alpha: f64) -> Self {
82        Self {
83            alpha,
84            ..Default::default()
85        }
86    }
87
88    /// Create a PC algorithm with custom parameters.
89    pub fn with_params(alpha: f64, max_cond_set_size: usize, stable: bool) -> Self {
90        Self {
91            alpha,
92            max_cond_set_size,
93            stable,
94        }
95    }
96
97    /// Run the PC algorithm using the default partial correlation CI test.
98    pub fn fit(&self, data: ArrayView2<f64>, var_names: &[&str]) -> StatsResult<PcResult> {
99        let ci_test = PartialCorrelationTest::new(self.alpha);
100        self.fit_with_test(data, var_names, &ci_test)
101    }
102
103    /// Run the PC algorithm with a custom CI test.
104    pub fn fit_with_test<T: ConditionalIndependenceTest>(
105        &self,
106        data: ArrayView2<f64>,
107        var_names: &[&str],
108        ci_test: &T,
109    ) -> StatsResult<PcResult> {
110        let p = data.ncols();
111        if var_names.len() != p {
112            return Err(StatsError::DimensionMismatch(
113                "var_names length must match number of columns".to_owned(),
114            ));
115        }
116        if p == 0 {
117            return Ok(PcResult {
118                graph: CausalGraph::new(var_names),
119                sep_sets: HashMap::new(),
120                n_tests: 0,
121            });
122        }
123
124        // Phase 1: Skeleton discovery
125        let (adj, sep_sets, n_tests) = if self.stable {
126            self.skeleton_stable(data, p, ci_test)?
127        } else {
128            self.skeleton_standard(data, p, ci_test)?
129        };
130
131        // Phase 2: Orient v-structures
132        let mut graph = CausalGraph::new(var_names);
133        // Set up adjacency from skeleton
134        for i in 0..p {
135            for j in (i + 1)..p {
136                if adj[i][j] {
137                    graph.set_edge(i, j, EdgeMark::Tail, EdgeMark::Tail);
138                }
139            }
140        }
141
142        orient_v_structures(&mut graph, &adj, &sep_sets, p);
143
144        // Phase 3: Apply Meek's rules R1-R4
145        apply_meek_rules(&mut graph, p);
146
147        Ok(PcResult {
148            graph,
149            sep_sets,
150            n_tests,
151        })
152    }
153
154    /// Standard (order-dependent) skeleton discovery.
155    fn skeleton_standard<T: ConditionalIndependenceTest>(
156        &self,
157        data: ArrayView2<f64>,
158        p: usize,
159        ci_test: &T,
160    ) -> StatsResult<(Vec<Vec<bool>>, HashMap<(usize, usize), Vec<usize>>, usize)> {
161        let mut adj = vec![vec![true; p]; p];
162        for i in 0..p {
163            adj[i][i] = false;
164        }
165        let mut sep_sets: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
166        let mut n_tests = 0usize;
167
168        for ord in 0..=self.max_cond_set_size {
169            let edges: Vec<(usize, usize)> = (0..p)
170                .flat_map(|i| ((i + 1)..p).map(move |j| (i, j)))
171                .filter(|&(i, j)| adj[i][j])
172                .collect();
173
174            for (x, y) in edges {
175                let z_candidates: Vec<usize> =
176                    (0..p).filter(|&k| k != x && k != y && adj[x][k]).collect();
177                if z_candidates.len() < ord {
178                    continue;
179                }
180
181                for z_set in subsets(&z_candidates, ord) {
182                    n_tests += 1;
183                    if ci_test.is_independent(x, y, &z_set, data, self.alpha)? {
184                        adj[x][y] = false;
185                        adj[y][x] = false;
186                        let key = (x.min(y), x.max(y));
187                        sep_sets.insert(key, z_set);
188                        break;
189                    }
190                }
191            }
192        }
193
194        Ok((adj, sep_sets, n_tests))
195    }
196
197    /// PC-stable skeleton discovery (order-independent).
198    ///
199    /// At each conditioning-set size, all removals are computed on the
200    /// adjacency from the *previous* level, then applied simultaneously.
201    fn skeleton_stable<T: ConditionalIndependenceTest>(
202        &self,
203        data: ArrayView2<f64>,
204        p: usize,
205        ci_test: &T,
206    ) -> StatsResult<(Vec<Vec<bool>>, HashMap<(usize, usize), Vec<usize>>, usize)> {
207        let mut adj = vec![vec![true; p]; p];
208        for i in 0..p {
209            adj[i][i] = false;
210        }
211        let mut sep_sets: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
212        let mut n_tests = 0usize;
213
214        for ord in 0..=self.max_cond_set_size {
215            // Snapshot adjacency at start of this order
216            let adj_snapshot = adj.clone();
217
218            let edges: Vec<(usize, usize)> = (0..p)
219                .flat_map(|i| ((i + 1)..p).map(move |j| (i, j)))
220                .filter(|&(i, j)| adj_snapshot[i][j])
221                .collect();
222
223            // Collect all removals to apply at end
224            let mut removals: Vec<(usize, usize, Vec<usize>)> = Vec::new();
225
226            for (x, y) in edges {
227                // Use adjacency from snapshot (PC-stable)
228                let z_candidates: Vec<usize> = (0..p)
229                    .filter(|&k| k != x && k != y && adj_snapshot[x][k])
230                    .collect();
231                if z_candidates.len() < ord {
232                    continue;
233                }
234
235                // Also check neighbours of y (both sides, PC-stable)
236                let z_candidates_y: Vec<usize> = (0..p)
237                    .filter(|&k| k != x && k != y && adj_snapshot[y][k])
238                    .collect();
239
240                let mut found = false;
241                // Test from x's side
242                for z_set in subsets(&z_candidates, ord) {
243                    n_tests += 1;
244                    if ci_test.is_independent(x, y, &z_set, data, self.alpha)? {
245                        removals.push((x, y, z_set));
246                        found = true;
247                        break;
248                    }
249                }
250                if found {
251                    continue;
252                }
253                // Test from y's side (PC-stable considers both neighbour sets)
254                if z_candidates_y.len() >= ord {
255                    for z_set in subsets(&z_candidates_y, ord) {
256                        // Skip sets already tested from x's side
257                        n_tests += 1;
258                        if ci_test.is_independent(x, y, &z_set, data, self.alpha)? {
259                            removals.push((x, y, z_set));
260                            break;
261                        }
262                    }
263                }
264            }
265
266            // Apply removals simultaneously
267            for (x, y, z_set) in removals {
268                adj[x][y] = false;
269                adj[y][x] = false;
270                let key = (x.min(y), x.max(y));
271                sep_sets.insert(key, z_set);
272            }
273        }
274
275        Ok((adj, sep_sets, n_tests))
276    }
277}
278
279// ---------------------------------------------------------------------------
280// V-structure orientation
281// ---------------------------------------------------------------------------
282
283/// Orient v-structures: for each X - Z - Y where X-Y not adjacent,
284/// if Z not in sep(X,Y), orient X -> Z <- Y.
285fn orient_v_structures(
286    graph: &mut CausalGraph,
287    adj: &[Vec<bool>],
288    sep_sets: &HashMap<(usize, usize), Vec<usize>>,
289    p: usize,
290) {
291    for z in 0..p {
292        let neighbours: Vec<usize> = (0..p).filter(|&k| k != z && adj[z][k]).collect();
293        for i in 0..neighbours.len() {
294            for j in (i + 1)..neighbours.len() {
295                let x = neighbours[i];
296                let y = neighbours[j];
297                // x - y must not be adjacent
298                if adj[x][y] {
299                    continue;
300                }
301                let key = (x.min(y), x.max(y));
302                let sep = sep_sets.get(&key).cloned().unwrap_or_default();
303                if !sep.contains(&z) {
304                    // Orient X -> Z <- Y
305                    graph.set_edge(x, z, EdgeMark::Tail, EdgeMark::Arrow);
306                    graph.set_edge(y, z, EdgeMark::Tail, EdgeMark::Arrow);
307                }
308            }
309        }
310    }
311}
312
313// ---------------------------------------------------------------------------
314// Meek's Rules R1-R4
315// ---------------------------------------------------------------------------
316
317/// Apply Meek's orientation rules R1-R4 until no more orientations change.
318///
319/// - R1: a -> b - c and a not adj c => orient b -> c
320/// - R2: a -> b -> c and a - c => orient a -> c
321/// - R3: a - b, a - c, a - d, b -> d, c -> d, b not adj c => orient a -> d
322/// - R4: a - b, b -> c, a - c, a -> d not adj c, d -> c exists on some
323///       directed path => orient a -> c (acyclicity preservation)
324pub fn apply_meek_rules(graph: &mut CausalGraph, p: usize) {
325    let max_iterations = p * p + 10;
326    let mut changed = true;
327    let mut iterations = 0;
328
329    while changed && iterations < max_iterations {
330        changed = false;
331        iterations += 1;
332
333        // R1: a -> b - c, a not adj c => b -> c
334        changed |= meek_r1(graph, p);
335
336        // R2: a -> b -> c, a - c => a -> c
337        changed |= meek_r2(graph, p);
338
339        // R3: a - d, b -> d, c -> d, a - b, a - c, b not adj c => a -> d
340        changed |= meek_r3(graph, p);
341
342        // R4: a - b, b -> c -> ... -> a (directed path), b not adj a through
343        //     the path => orient a -> b to avoid cycle
344        changed |= meek_r4(graph, p);
345    }
346}
347
348/// R1: If a -> b - c and a is not adjacent to c, orient b -> c.
349fn meek_r1(graph: &mut CausalGraph, p: usize) -> bool {
350    let mut changed = false;
351    for b in 0..p {
352        for a in 0..p {
353            if a == b {
354                continue;
355            }
356            // Check a -> b (directed from a to b)
357            if !graph.is_directed(a, b) {
358                continue;
359            }
360            for c in 0..p {
361                if c == a || c == b {
362                    continue;
363                }
364                // Check b - c (undirected)
365                if !graph.is_undirected(b, c) {
366                    continue;
367                }
368                // Check a not adjacent to c
369                if graph.is_adjacent(a, c) {
370                    continue;
371                }
372                // Orient b -> c
373                graph.set_edge(b, c, EdgeMark::Tail, EdgeMark::Arrow);
374                changed = true;
375            }
376        }
377    }
378    changed
379}
380
381/// R2: If a -> b -> c and a - c, orient a -> c.
382fn meek_r2(graph: &mut CausalGraph, p: usize) -> bool {
383    let mut changed = false;
384    for a in 0..p {
385        for b in 0..p {
386            if a == b {
387                continue;
388            }
389            if !graph.is_directed(a, b) {
390                continue;
391            }
392            for c in 0..p {
393                if c == a || c == b {
394                    continue;
395                }
396                if !graph.is_directed(b, c) {
397                    continue;
398                }
399                if !graph.is_undirected(a, c) {
400                    continue;
401                }
402                graph.set_edge(a, c, EdgeMark::Tail, EdgeMark::Arrow);
403                changed = true;
404            }
405        }
406    }
407    changed
408}
409
410/// R3: If a - d, and there exist b, c such that b -> d, c -> d,
411/// a - b, a - c, and b not adj c, orient a -> d.
412fn meek_r3(graph: &mut CausalGraph, p: usize) -> bool {
413    let mut changed = false;
414    for a in 0..p {
415        for d in 0..p {
416            if a == d {
417                continue;
418            }
419            // a - d (undirected)
420            if !graph.is_undirected(a, d) {
421                continue;
422            }
423            // Find b, c where: b -> d, c -> d, a - b, a - c, b not adj c
424            let parents_of_d: Vec<usize> = (0..p)
425                .filter(|&k| k != a && k != d && graph.is_directed(k, d))
426                .collect();
427            let mut orient = false;
428            for i in 0..parents_of_d.len() {
429                for j in (i + 1)..parents_of_d.len() {
430                    let b = parents_of_d[i];
431                    let c = parents_of_d[j];
432                    if graph.is_undirected(a, b)
433                        && graph.is_undirected(a, c)
434                        && !graph.is_adjacent(b, c)
435                    {
436                        orient = true;
437                        break;
438                    }
439                }
440                if orient {
441                    break;
442                }
443            }
444            if orient {
445                graph.set_edge(a, d, EdgeMark::Tail, EdgeMark::Arrow);
446                changed = true;
447            }
448        }
449    }
450    changed
451}
452
453/// R4: If a - b and there exists a directed path from b to a through
454/// some node c where b -> c and a - c, orient a -> b.
455///
456/// More precisely: if a - b, a - c, c -> ... -> b (directed path of length >= 1),
457/// and b -> c, then orient a -> b. This is needed to prevent creating a new
458/// v-structure or a directed cycle.
459fn meek_r4(graph: &mut CausalGraph, p: usize) -> bool {
460    let mut changed = false;
461    for a in 0..p {
462        for b in 0..p {
463            if a == b {
464                continue;
465            }
466            if !graph.is_undirected(a, b) {
467                continue;
468            }
469            // Check: exists c such that a - c, b -> c, and directed path c ->* b
470            // (which would create cycle if we orient a <- b)
471            for c in 0..p {
472                if c == a || c == b {
473                    continue;
474                }
475                if !graph.is_undirected(a, c) {
476                    continue;
477                }
478                if !graph.is_directed(b, c) {
479                    continue;
480                }
481                // Check directed path from c to a
482                if has_directed_path(graph, c, a, p) {
483                    graph.set_edge(a, b, EdgeMark::Tail, EdgeMark::Arrow);
484                    changed = true;
485                    break;
486                }
487            }
488        }
489    }
490    changed
491}
492
493/// Check if there is a directed path from `src` to `dst` in the graph.
494fn has_directed_path(graph: &CausalGraph, src: usize, dst: usize, p: usize) -> bool {
495    let mut visited = vec![false; p];
496    let mut stack = vec![src];
497    while let Some(cur) = stack.pop() {
498        if cur == dst {
499            return true;
500        }
501        if visited[cur] {
502            continue;
503        }
504        visited[cur] = true;
505        for next in 0..p {
506            if !visited[next] && graph.is_directed(cur, next) {
507                stack.push(next);
508            }
509        }
510    }
511    false
512}
513
514// ---------------------------------------------------------------------------
515// Helpers
516// ---------------------------------------------------------------------------
517
/// Generate all subsets of `items` of size `k`.
///
/// Subsets are returned in lexicographic order of the chosen positions, and
/// each subset keeps its elements in their original relative order. The
/// result is a single empty subset for `k == 0` and empty when
/// `k > items.len()`.
///
/// The enumeration is iterative: it advances a vector of `k` strictly
/// increasing positions, so no intermediate subset lists are built.
pub(crate) fn subsets<T: Copy>(items: &[T], k: usize) -> Vec<Vec<T>> {
    let n = items.len();
    if k > n {
        return Vec::new();
    }

    // Positions of the current combination; starts at 0, 1, ..., k - 1.
    let mut picks: Vec<usize> = (0..k).collect();
    let mut result = Vec::new();

    loop {
        result.push(picks.iter().map(|&i| items[i]).collect());

        // Rightmost slot that can still move right; slot `s` may go up to
        // `n - k + s` while leaving room for the slots after it.
        let advance = (0..k).rev().find(|&slot| picks[slot] < n - k + slot);
        match advance {
            Some(slot) => {
                picks[slot] += 1;
                // Pack all later slots immediately after the moved one.
                for later in (slot + 1)..k {
                    picks[later] = picks[later - 1] + 1;
                }
            }
            None => break,
        }
    }
    result
}
535
536// ---------------------------------------------------------------------------
537// Tests
538// ---------------------------------------------------------------------------
539
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Advance the 64-bit LCG state `s` and return a uniform value in [0, 1).
    fn lcg_uniform(s: &mut u64) -> f64 {
        *s = s
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((*s >> 11) as f64) / ((1u64 << 53) as f64)
    }

    /// Draw one standard-normal sample via the Box-Muller transform.
    fn lcg_normal(s: &mut u64) -> f64 {
        // Clamp away from zero so the logarithm stays finite.
        let u1 = lcg_uniform(s).max(1e-15);
        let u2 = lcg_uniform(s);
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }

    /// Generate chain X -> Y -> Z data.
    fn chain_data(n: usize, seed: u64) -> Array2<f64> {
        let mut data = Array2::<f64>::zeros((n, 3));
        let mut lcg = seed;
        for i in 0..n {
            data[[i, 0]] = lcg_normal(&mut lcg);
            data[[i, 1]] = 0.9 * data[[i, 0]] + lcg_normal(&mut lcg) * 0.3;
            data[[i, 2]] = 0.9 * data[[i, 1]] + lcg_normal(&mut lcg) * 0.3;
        }
        data
    }

    /// Generate fork X <- Y -> Z data.
    fn fork_data(n: usize, seed: u64) -> Array2<f64> {
        let mut data = Array2::<f64>::zeros((n, 3));
        let mut lcg = seed;
        for i in 0..n {
            let y = lcg_normal(&mut lcg);
            data[[i, 0]] = 0.9 * y + lcg_normal(&mut lcg) * 0.3;
            data[[i, 1]] = y;
            data[[i, 2]] = 0.9 * y + lcg_normal(&mut lcg) * 0.3;
        }
        data
    }

    /// Generate collider X -> Y <- Z data.
    fn collider_data(n: usize, seed: u64) -> Array2<f64> {
        let mut data = Array2::<f64>::zeros((n, 3));
        let mut lcg = seed;
        for i in 0..n {
            data[[i, 0]] = lcg_normal(&mut lcg);
            data[[i, 2]] = lcg_normal(&mut lcg);
            data[[i, 1]] = 0.7 * data[[i, 0]] + 0.7 * data[[i, 2]] + lcg_normal(&mut lcg) * 0.3;
        }
        data
    }

    #[test]
    fn test_pc_chain() {
        let data = chain_data(300, 12345);
        let pc = PcAlgorithm::new(0.05);
        let result = pc.fit(data.view(), &["X", "Y", "Z"]).expect("PC failed");
        // Chain: X - Y - Z, X not adj Z
        assert!(
            result.graph.is_adjacent(0, 1),
            "X-Y should be adjacent in chain"
        );
        assert!(
            result.graph.is_adjacent(1, 2),
            "Y-Z should be adjacent in chain"
        );
        // X-Z should NOT be adjacent (conditionally independent given Y)
        assert!(
            !result.graph.is_adjacent(0, 2),
            "X-Z should not be adjacent in chain"
        );
    }

    #[test]
    fn test_pc_fork() {
        let data = fork_data(300, 54321);
        let pc = PcAlgorithm::new(0.05);
        let result = pc.fit(data.view(), &["X", "Y", "Z"]).expect("PC failed");
        // Fork: X <- Y -> Z
        // X-Y, Y-Z adjacent; X-Z not adjacent (given Y)
        assert!(result.graph.is_adjacent(0, 1), "X-Y should be adjacent");
        assert!(result.graph.is_adjacent(1, 2), "Y-Z should be adjacent");
        assert!(
            !result.graph.is_adjacent(0, 2),
            "X-Z should not be adjacent given Y"
        );
    }

    #[test]
    fn test_pc_collider() {
        let data = collider_data(300, 99999);
        let pc = PcAlgorithm::new(0.05);
        let result = pc.fit(data.view(), &["X", "Y", "Z"]).expect("PC failed");
        // Collider: X -> Y <- Z
        assert!(result.graph.is_adjacent(0, 1), "X-Y should be adjacent");
        assert!(result.graph.is_adjacent(1, 2), "Y-Z should be adjacent");
        // X and Z should be marginally independent
        assert!(
            !result.graph.is_adjacent(0, 2),
            "X-Z should not be adjacent (independent causes)"
        );
        // A v-structure needs BOTH arrowheads at Y. Once X-Z was removed
        // marginally, Y is not in sep(X, Z), so both edges must be oriented.
        assert!(
            result.graph.is_directed(0, 1),
            "X -> Y expected (v-structure)"
        );
        assert!(
            result.graph.is_directed(2, 1),
            "Z -> Y expected (v-structure)"
        );
    }

    #[test]
    fn test_pc_meek_r1() {
        // Test Meek's R1: if a -> b - c and a not adj c, orient b -> c
        let mut graph = CausalGraph::new(&["A", "B", "C"]);
        graph.set_edge(0, 1, EdgeMark::Tail, EdgeMark::Arrow); // a -> b
        graph.set_edge(1, 2, EdgeMark::Tail, EdgeMark::Tail); // b - c
                                                              // a not adj c (no edge)

        apply_meek_rules(&mut graph, 3);

        assert!(graph.is_directed(1, 2), "R1: b -> c expected");
    }

    #[test]
    fn test_pc_meek_r2() {
        // R2: a -> b -> c and a - c => a -> c
        let mut graph = CausalGraph::new(&["A", "B", "C"]);
        graph.set_edge(0, 1, EdgeMark::Tail, EdgeMark::Arrow); // a -> b
        graph.set_edge(1, 2, EdgeMark::Tail, EdgeMark::Arrow); // b -> c
        graph.set_edge(0, 2, EdgeMark::Tail, EdgeMark::Tail); // a - c

        apply_meek_rules(&mut graph, 3);

        assert!(graph.is_directed(0, 2), "R2: a -> c expected");
    }

    #[test]
    fn test_pc_meek_r3() {
        // R3: a - d, b -> d, c -> d, a - b, a - c, b not adj c => a -> d
        let mut graph = CausalGraph::new(&["A", "B", "C", "D"]);
        graph.set_edge(0, 3, EdgeMark::Tail, EdgeMark::Tail); // a - d
        graph.set_edge(1, 3, EdgeMark::Tail, EdgeMark::Arrow); // b -> d
        graph.set_edge(2, 3, EdgeMark::Tail, EdgeMark::Arrow); // c -> d
        graph.set_edge(0, 1, EdgeMark::Tail, EdgeMark::Tail); // a - b
        graph.set_edge(0, 2, EdgeMark::Tail, EdgeMark::Tail); // a - c
                                                              // b and c NOT adjacent

        apply_meek_rules(&mut graph, 4);

        assert!(graph.is_directed(0, 3), "R3: a -> d expected");
    }

    #[test]
    fn test_pc_stable_vs_standard() {
        let data = chain_data(200, 77777);
        let pc_stable = PcAlgorithm::with_params(0.05, 3, true);
        let pc_standard = PcAlgorithm::with_params(0.05, 3, false);
        let r1 = pc_stable
            .fit(data.view(), &["X", "Y", "Z"])
            .expect("stable failed");
        let r2 = pc_standard
            .fit(data.view(), &["X", "Y", "Z"])
            .expect("standard failed");
        // Both should find the same skeleton for a simple chain
        assert_eq!(
            r1.graph.is_adjacent(0, 2),
            r2.graph.is_adjacent(0, 2),
            "Skeleton should match for simple structures"
        );
    }

    #[test]
    fn test_pc_sep_sets() {
        // Same data and settings as `test_pc_chain`, which requires the X-Z
        // edge to be removed, so a separation set must have been recorded.
        let data = chain_data(300, 12345);
        let pc = PcAlgorithm::new(0.05);
        let result = pc.fit(data.view(), &["X", "Y", "Z"]).expect("PC failed");
        let sep = result
            .sep_sets
            .get(&(0, 2))
            .expect("a separation set for X-Z should be recorded");
        assert!(sep.contains(&1), "Sep set for X-Z should contain Y");
    }

    #[test]
    fn test_subsets() {
        let items = vec![0, 1, 2, 3];
        let s0 = subsets(&items, 0);
        assert_eq!(s0.len(), 1);
        assert!(s0[0].is_empty());

        let s1 = subsets(&items, 1);
        assert_eq!(s1.len(), 4);

        let s2 = subsets(&items, 2);
        assert_eq!(s2.len(), 6);

        let s3 = subsets(&items, 3);
        assert_eq!(s3.len(), 4);

        let s4 = subsets(&items, 4);
        assert_eq!(s4.len(), 1);

        let s5 = subsets(&items, 5);
        assert!(s5.is_empty());
    }

    #[test]
    fn test_directed_path_detection() {
        let mut graph = CausalGraph::new(&["A", "B", "C", "D"]);
        graph.set_edge(0, 1, EdgeMark::Tail, EdgeMark::Arrow); // A -> B
        graph.set_edge(1, 2, EdgeMark::Tail, EdgeMark::Arrow); // B -> C
        graph.set_edge(2, 3, EdgeMark::Tail, EdgeMark::Arrow); // C -> D

        assert!(has_directed_path(&graph, 0, 3, 4), "A -> B -> C -> D");
        assert!(has_directed_path(&graph, 0, 2, 4), "A -> B -> C");
        assert!(!has_directed_path(&graph, 3, 0, 4), "No path D -> A");
        assert!(!has_directed_path(&graph, 1, 0, 4), "No path B -> A");
    }
}