Skip to main content

scirs2_stats/causal/
fci_algorithm.rs

1//! FCI Algorithm for Causal Discovery with Latent Confounders
2//!
3//! The Fast Causal Inference (FCI) algorithm extends the PC algorithm to
4//! handle latent (unmeasured) confounders. Instead of a CPDAG, FCI outputs
5//! a Partial Ancestral Graph (PAG) that uses additional edge marks:
6//!
7//! - **Circle (o)**: the mark is unknown (could be tail or arrow)
8//! - **Arrow (>)**: definite arrowhead
9//! - **Tail (-)**: definite tail
10//!
11//! Edge types in a PAG:
12//! - `->` : definite direct cause
13//! - `<->` : latent common cause (bidirected)
14//! - `o->` : possible direct cause or latent common cause
15//! - `o-o` : completely uncertain
16//!
17//! # Algorithm outline
18//!
19//! 1. Run PC skeleton discovery
20//! 2. Orient unshielded colliders (v-structures)
21//! 3. Compute Possible-D-SEP sets and refine adjacency
22//! 4. Re-orient v-structures on refined skeleton
23//! 5. Apply FCI orientation rules R1-R10
24//!
25//! # References
26//!
27//! - Spirtes, P., Glymour, C. & Scheines, R. (2000). *Causation, Prediction,
28//!   and Search* (2nd ed.). MIT Press.
29//! - Richardson, T. & Spirtes, P. (2002). Ancestral graph Markov models.
30//!   *Ann. Statist.* 30, 962-1030.
31//! - Zhang, J. (2008). On the completeness of orientation rules for causal
32//!   discovery in the presence of latent confounders and selection variables.
33//!   *Artificial Intelligence* 172, 1873-1896.
34
35use std::collections::{HashMap, HashSet, VecDeque};
36
37use scirs2_core::ndarray::ArrayView2;
38
39use super::conditional_independence::{ConditionalIndependenceTest, PartialCorrelationTest};
40use super::pc_algorithm::subsets;
41use super::{CausalGraph, EdgeMark};
42use crate::error::{StatsError, StatsResult};
43
44// ---------------------------------------------------------------------------
45// FCI Algorithm
46// ---------------------------------------------------------------------------
47
/// Tuning knobs for the FCI (Fast Causal Inference) search.
///
/// Obtain one via `FciAlgorithm::default()`, `FciAlgorithm::new` or
/// `FciAlgorithm::with_params`, then call `fit` / `fit_with_test`.
#[derive(Clone, Debug)]
pub struct FciAlgorithm {
    /// Significance level handed to every conditional-independence test
    /// (0.05 unless overridden).
    pub alpha: f64,
    /// Largest conditioning set tried during skeleton discovery
    /// (4 unless overridden).
    pub max_cond_set_size: usize,
    /// Largest conditioning set drawn from Possible-D-SEP during the
    /// refinement phase (4 unless overridden).
    pub max_pdsep_size: usize,
}

impl Default for FciAlgorithm {
    /// `alpha = 0.05`, `max_cond_set_size = 4`, `max_pdsep_size = 4`.
    fn default() -> Self {
        const DEFAULT_ALPHA: f64 = 0.05;
        const DEFAULT_SET_LIMIT: usize = 4;
        FciAlgorithm {
            alpha: DEFAULT_ALPHA,
            max_cond_set_size: DEFAULT_SET_LIMIT,
            max_pdsep_size: DEFAULT_SET_LIMIT,
        }
    }
}
68
69/// Result of the FCI algorithm.
70#[derive(Debug, Clone)]
71pub struct FciResult {
72    /// The learned PAG (Partial Ancestral Graph).
73    pub graph: CausalGraph,
74    /// Separation sets.
75    pub sep_sets: HashMap<(usize, usize), Vec<usize>>,
76    /// Number of CI tests performed.
77    pub n_tests: usize,
78    /// Whether latent confounders were detected (any bidirected edges).
79    pub has_latent_confounders: bool,
80}
81
82impl FciAlgorithm {
83    /// Create an FCI algorithm with the given significance level.
84    pub fn new(alpha: f64) -> Self {
85        Self {
86            alpha,
87            ..Default::default()
88        }
89    }
90
91    /// Create with custom parameters.
92    pub fn with_params(alpha: f64, max_cond_set_size: usize, max_pdsep_size: usize) -> Self {
93        Self {
94            alpha,
95            max_cond_set_size,
96            max_pdsep_size,
97        }
98    }
99
100    /// Run FCI using the default partial correlation CI test.
101    pub fn fit(&self, data: ArrayView2<f64>, var_names: &[&str]) -> StatsResult<FciResult> {
102        let ci_test = PartialCorrelationTest::new(self.alpha);
103        self.fit_with_test(data, var_names, &ci_test)
104    }
105
106    /// Run FCI with a custom CI test.
107    pub fn fit_with_test<T: ConditionalIndependenceTest>(
108        &self,
109        data: ArrayView2<f64>,
110        var_names: &[&str],
111        ci_test: &T,
112    ) -> StatsResult<FciResult> {
113        let p = data.ncols();
114        if var_names.len() != p {
115            return Err(StatsError::DimensionMismatch(
116                "var_names length must match number of columns".to_owned(),
117            ));
118        }
119        if p == 0 {
120            return Ok(FciResult {
121                graph: CausalGraph::new(var_names),
122                sep_sets: HashMap::new(),
123                n_tests: 0,
124                has_latent_confounders: false,
125            });
126        }
127
128        // Step 1: Initial skeleton discovery (same as PC)
129        let (mut adj, mut sep_sets, mut n_tests) =
130            skeleton_discovery(data, p, self.alpha, self.max_cond_set_size, ci_test)?;
131
132        // Step 2: Initial v-structure orientation
133        let mut graph = CausalGraph::new(var_names);
134        // Initialise with circle marks (FCI uses o-o for unoriented)
135        for i in 0..p {
136            for j in (i + 1)..p {
137                if adj[i][j] {
138                    graph.set_edge(i, j, EdgeMark::Circle, EdgeMark::Circle);
139                }
140            }
141        }
142        orient_unshielded_colliders(&mut graph, &adj, &sep_sets, p);
143
144        // Step 3: Possible-D-SEP refinement
145        let pdsep_removals = possible_dsep_phase(
146            &graph,
147            data,
148            &adj,
149            p,
150            self.alpha,
151            self.max_pdsep_size,
152            ci_test,
153            &mut n_tests,
154        )?;
155
156        // Apply removals
157        for (x, y, z_set) in pdsep_removals {
158            adj[x][y] = false;
159            adj[y][x] = false;
160            graph.remove_edge(x, y);
161            let key = (x.min(y), x.max(y));
162            sep_sets.insert(key, z_set);
163        }
164
165        // Step 4: Re-orient v-structures on refined skeleton
166        // Reset all remaining edges to circle-circle
167        for i in 0..p {
168            for j in (i + 1)..p {
169                if adj[i][j] {
170                    graph.set_edge(i, j, EdgeMark::Circle, EdgeMark::Circle);
171                }
172            }
173        }
174        orient_unshielded_colliders(&mut graph, &adj, &sep_sets, p);
175
176        // Step 5: Apply FCI orientation rules R1-R10
177        apply_fci_rules(&mut graph, &adj, &sep_sets, p);
178
179        // Detect latent confounders (bidirected edges)
180        let has_latent_confounders =
181            (0..p).any(|i| (0..p).any(|j| i != j && graph.is_bidirected(i, j)));
182
183        Ok(FciResult {
184            graph,
185            sep_sets,
186            n_tests,
187            has_latent_confounders,
188        })
189    }
190}
191
192// ---------------------------------------------------------------------------
193// Skeleton Discovery (shared with PC)
194// ---------------------------------------------------------------------------
195
196/// PC-stable skeleton discovery.
197fn skeleton_discovery<T: ConditionalIndependenceTest>(
198    data: ArrayView2<f64>,
199    p: usize,
200    alpha: f64,
201    max_cond_set_size: usize,
202    ci_test: &T,
203) -> StatsResult<(Vec<Vec<bool>>, HashMap<(usize, usize), Vec<usize>>, usize)> {
204    let mut adj = vec![vec![true; p]; p];
205    for i in 0..p {
206        adj[i][i] = false;
207    }
208    let mut sep_sets: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
209    let mut n_tests = 0usize;
210
211    for ord in 0..=max_cond_set_size {
212        let adj_snapshot = adj.clone();
213        let edges: Vec<(usize, usize)> = (0..p)
214            .flat_map(|i| ((i + 1)..p).map(move |j| (i, j)))
215            .filter(|&(i, j)| adj_snapshot[i][j])
216            .collect();
217
218        let mut removals = Vec::new();
219
220        for (x, y) in edges {
221            let z_x: Vec<usize> = (0..p)
222                .filter(|&k| k != x && k != y && adj_snapshot[x][k])
223                .collect();
224            let z_y: Vec<usize> = (0..p)
225                .filter(|&k| k != x && k != y && adj_snapshot[y][k])
226                .collect();
227
228            let mut found = false;
229            if z_x.len() >= ord {
230                for z_set in subsets(&z_x, ord) {
231                    n_tests += 1;
232                    if ci_test.is_independent(x, y, &z_set, data, alpha)? {
233                        removals.push((x, y, z_set));
234                        found = true;
235                        break;
236                    }
237                }
238            }
239            if !found && z_y.len() >= ord {
240                for z_set in subsets(&z_y, ord) {
241                    n_tests += 1;
242                    if ci_test.is_independent(x, y, &z_set, data, alpha)? {
243                        removals.push((x, y, z_set));
244                        break;
245                    }
246                }
247            }
248        }
249
250        for (x, y, z_set) in removals {
251            adj[x][y] = false;
252            adj[y][x] = false;
253            let key = (x.min(y), x.max(y));
254            sep_sets.insert(key, z_set);
255        }
256    }
257
258    Ok((adj, sep_sets, n_tests))
259}
260
261// ---------------------------------------------------------------------------
262// Unshielded Collider Orientation
263// ---------------------------------------------------------------------------
264
265/// Orient unshielded colliders: X *-o Z o-* Y where X-Y not adjacent
266/// and Z not in sep(X,Y) => X *-> Z <-* Y.
267fn orient_unshielded_colliders(
268    graph: &mut CausalGraph,
269    adj: &[Vec<bool>],
270    sep_sets: &HashMap<(usize, usize), Vec<usize>>,
271    p: usize,
272) {
273    for z in 0..p {
274        let neighbours: Vec<usize> = (0..p).filter(|&k| k != z && adj[z][k]).collect();
275        for i in 0..neighbours.len() {
276            for j in (i + 1)..neighbours.len() {
277                let x = neighbours[i];
278                let y = neighbours[j];
279                if adj[x][y] {
280                    continue; // shielded
281                }
282                let key = (x.min(y), x.max(y));
283                let sep = sep_sets.get(&key).cloned().unwrap_or_default();
284                if !sep.contains(&z) {
285                    // Orient: arrowhead at z from both x and y
286                    let mark_xz_from = graph.get_mark_from(x, z).unwrap_or(EdgeMark::Circle);
287                    graph.set_edge(x, z, mark_xz_from, EdgeMark::Arrow);
288                    let mark_yz_from = graph.get_mark_from(y, z).unwrap_or(EdgeMark::Circle);
289                    graph.set_edge(y, z, mark_yz_from, EdgeMark::Arrow);
290                }
291            }
292        }
293    }
294}
295
296// ---------------------------------------------------------------------------
297// Possible-D-SEP
298// ---------------------------------------------------------------------------
299
300/// Compute Possible-D-SEP(a, b) in the partially oriented graph.
301///
302/// Possible-D-SEP(a, b) is the set of all nodes that can be reached from a
303/// by a path on which every non-endpoint node is a collider or has an
304/// undetermined edge mark.
305fn possible_dsep(graph: &CausalGraph, a: usize, b: usize, p: usize) -> HashSet<usize> {
306    let mut pdsep = HashSet::new();
307    let mut visited = HashSet::new();
308    let mut queue = VecDeque::new();
309
310    // Start from a's neighbours
311    for k in 0..p {
312        if k != a && k != b && graph.is_adjacent(a, k) {
313            queue.push_back((k, a)); // (current, previous)
314        }
315    }
316
317    while let Some((cur, prev)) = queue.pop_front() {
318        if !visited.insert((cur, prev)) {
319            continue;
320        }
321        pdsep.insert(cur);
322
323        // Continue along the path if cur is a collider on the path or
324        // has a circle mark from the previous node
325        for next in 0..p {
326            if next == prev || next == a || !graph.is_adjacent(cur, next) {
327                continue;
328            }
329            // Check if cur is a "possible collider" on the path prev - cur - next
330            // A node is on Possible-D-SEP if:
331            // 1. It has an arrowhead from prev (prev *-> cur)
332            // 2. Or it has a circle mark from prev (prev o-* cur)
333            let mark_at_cur_from_prev = graph.get_mark_at(prev, cur);
334            let is_possible_collider = match mark_at_cur_from_prev {
335                Some(EdgeMark::Arrow) | Some(EdgeMark::Circle) => true,
336                _ => false,
337            };
338
339            if is_possible_collider {
340                queue.push_back((next, cur));
341            }
342        }
343    }
344
345    pdsep
346}
347
348/// Possible-D-SEP phase: for each adjacent pair, test independence
349/// conditioning on subsets of Possible-D-SEP.
350fn possible_dsep_phase<T: ConditionalIndependenceTest>(
351    graph: &CausalGraph,
352    data: ArrayView2<f64>,
353    adj: &[Vec<bool>],
354    p: usize,
355    alpha: f64,
356    max_pdsep_size: usize,
357    ci_test: &T,
358    n_tests: &mut usize,
359) -> StatsResult<Vec<(usize, usize, Vec<usize>)>> {
360    let mut removals = Vec::new();
361
362    for x in 0..p {
363        for y in (x + 1)..p {
364            if !adj[x][y] {
365                continue;
366            }
367
368            let pdsep_x = possible_dsep(graph, x, y, p);
369            let pdsep_y = possible_dsep(graph, y, x, p);
370            let combined: Vec<usize> = pdsep_x
371                .union(&pdsep_y)
372                .copied()
373                .filter(|&k| k != x && k != y)
374                .collect();
375
376            if combined.is_empty() {
377                continue;
378            }
379
380            let max_size = max_pdsep_size.min(combined.len());
381            let mut found = false;
382            for ord in 0..=max_size {
383                if found {
384                    break;
385                }
386                for z_set in subsets(&combined, ord) {
387                    *n_tests += 1;
388                    if ci_test.is_independent(x, y, &z_set, data, alpha)? {
389                        removals.push((x, y, z_set));
390                        found = true;
391                        break;
392                    }
393                }
394            }
395        }
396    }
397
398    Ok(removals)
399}
400
401// ---------------------------------------------------------------------------
402// FCI Orientation Rules R1-R10
403// ---------------------------------------------------------------------------
404
405/// Apply FCI orientation rules R1-R10 until convergence.
406///
407/// These rules orient remaining circle marks into tails or arrowheads.
408fn apply_fci_rules(
409    graph: &mut CausalGraph,
410    adj: &[Vec<bool>],
411    sep_sets: &HashMap<(usize, usize), Vec<usize>>,
412    p: usize,
413) {
414    let max_iterations = p * p * 2 + 10;
415    let mut changed = true;
416    let mut iterations = 0;
417
418    while changed && iterations < max_iterations {
419        changed = false;
420        iterations += 1;
421
422        changed |= fci_r1(graph, p);
423        changed |= fci_r2(graph, p);
424        changed |= fci_r3(graph, adj, p);
425        changed |= fci_r4(graph, adj, sep_sets, p);
426        changed |= fci_r5(graph, adj, p);
427        changed |= fci_r6(graph, p);
428        changed |= fci_r7(graph, p);
429        changed |= fci_r8(graph, p);
430        changed |= fci_r9(graph, p);
431        changed |= fci_r10(graph, p);
432    }
433}
434
435/// R1: If a *-> b o-* c, and a is not adjacent to c, orient b *-> c.
436fn fci_r1(graph: &mut CausalGraph, p: usize) -> bool {
437    let mut changed = false;
438    for b in 0..p {
439        for a in 0..p {
440            if a == b {
441                continue;
442            }
443            // a *-> b: arrowhead at b from a
444            if graph.get_mark_at(a, b) != Some(EdgeMark::Arrow) {
445                continue;
446            }
447            for c in 0..p {
448                if c == a || c == b {
449                    continue;
450                }
451                if !graph.is_adjacent(b, c) {
452                    continue;
453                }
454                if graph.is_adjacent(a, c) {
455                    continue;
456                }
457                // b o-* c: circle mark at b on edge b-c
458                if graph.get_mark_from(b, c) != Some(EdgeMark::Circle) {
459                    continue;
460                }
461                // Orient: change circle at b to tail => b -> c (preserve mark at c)
462                let mark_at_c = graph.get_mark_at(b, c).unwrap_or(EdgeMark::Circle);
463                graph.set_edge(b, c, EdgeMark::Tail, mark_at_c);
464                changed = true;
465            }
466        }
467    }
468    changed
469}
470
471/// R2: If a -> b *-> c or a *-> b -> c, and a *-o c, orient a *-> c.
472fn fci_r2(graph: &mut CausalGraph, p: usize) -> bool {
473    let mut changed = false;
474    for a in 0..p {
475        for c in 0..p {
476            if a == c || !graph.is_adjacent(a, c) {
477                continue;
478            }
479            // a *-o c: circle mark at c from a
480            if graph.get_mark_at(a, c) != Some(EdgeMark::Circle) {
481                continue;
482            }
483            for b in 0..p {
484                if b == a || b == c {
485                    continue;
486                }
487                // Case 1: a -> b *-> c
488                let case1 = graph.get_mark_from(a, b) == Some(EdgeMark::Tail)
489                    && graph.get_mark_at(a, b) == Some(EdgeMark::Arrow)
490                    && graph.get_mark_at(b, c) == Some(EdgeMark::Arrow);
491                // Case 2: a *-> b -> c
492                let case2 = graph.get_mark_at(a, b) == Some(EdgeMark::Arrow)
493                    && graph.get_mark_from(b, c) == Some(EdgeMark::Tail)
494                    && graph.get_mark_at(b, c) == Some(EdgeMark::Arrow);
495
496                if case1 || case2 {
497                    // Orient a *-> c
498                    let mark_from_a = graph.get_mark_from(a, c).unwrap_or(EdgeMark::Circle);
499                    graph.set_edge(a, c, mark_from_a, EdgeMark::Arrow);
500                    changed = true;
501                    break;
502                }
503            }
504        }
505    }
506    changed
507}
508
509/// R3: If a *-> b <-* c, a *-o d o-* c, a not adj c, d *-o b,
510/// orient d *-> b.
511fn fci_r3(graph: &mut CausalGraph, adj: &[Vec<bool>], p: usize) -> bool {
512    let mut changed = false;
513    for d in 0..p {
514        for b in 0..p {
515            if d == b || !graph.is_adjacent(d, b) {
516                continue;
517            }
518            // d *-o b
519            if graph.get_mark_at(d, b) != Some(EdgeMark::Circle) {
520                continue;
521            }
522            // Find a, c such that: a *-> b <-* c, a *-o d o-* c, a not adj c
523            let parents_b: Vec<usize> = (0..p)
524                .filter(|&k| {
525                    k != b
526                        && k != d
527                        && graph.is_adjacent(k, b)
528                        && graph.get_mark_at(k, b) == Some(EdgeMark::Arrow)
529                })
530                .collect();
531            let mut orient = false;
532            for i in 0..parents_b.len() {
533                for j in (i + 1)..parents_b.len() {
534                    let a = parents_b[i];
535                    let c = parents_b[j];
536                    if adj[a][c] {
537                        continue;
538                    }
539                    // a *-o d
540                    if !graph.is_adjacent(a, d) {
541                        continue;
542                    }
543                    if graph.get_mark_at(a, d) != Some(EdgeMark::Circle) {
544                        continue;
545                    }
546                    // d o-* c (i.e., c *-o d from c's perspective)
547                    if !graph.is_adjacent(c, d) {
548                        continue;
549                    }
550                    if graph.get_mark_at(c, d) != Some(EdgeMark::Circle) {
551                        continue;
552                    }
553                    orient = true;
554                    break;
555                }
556                if orient {
557                    break;
558                }
559            }
560            if orient {
561                let mark_from = graph.get_mark_from(d, b).unwrap_or(EdgeMark::Circle);
562                graph.set_edge(d, b, mark_from, EdgeMark::Arrow);
563                changed = true;
564            }
565        }
566    }
567    changed
568}
569
570/// R4: Discriminating path rule.
571///
572/// If there is a discriminating path <a, ..., b, c> for b, and b is a
573/// collider on the path, and c is in sep(a, c_end), orient b <-> c;
574/// otherwise orient b -> c.
575fn fci_r4(
576    graph: &mut CausalGraph,
577    _adj: &[Vec<bool>],
578    sep_sets: &HashMap<(usize, usize), Vec<usize>>,
579    p: usize,
580) -> bool {
581    let mut changed = false;
582    // For each triple b, c where b *-> c o-* ? (circle at b on b-c edge)
583    for c in 0..p {
584        for b in 0..p {
585            if b == c || !graph.is_adjacent(b, c) {
586                continue;
587            }
588            if graph.get_mark_at(b, c) != Some(EdgeMark::Arrow) {
589                continue;
590            }
591            // b o-* c (circle at c end from b side)
592            if graph.get_mark_from(b, c) != Some(EdgeMark::Circle) {
593                continue;
594            }
595
596            // Try to find a discriminating path for b
597            // A discriminating path is: a - ... - v_k - b - c where
598            // a is not adjacent to c, every v_i is a collider with arrow into c (parent of c),
599            // and v_i *-> ... (collider on subpath)
600            for a in 0..p {
601                if a == b || a == c || !graph.is_adjacent(a, c) {
602                    continue;
603                }
604                // Simple case: length-3 discriminating path a - b - c
605                // a must not be adjacent to c... wait, a IS adjacent to c in this check
606                // Actually for a discriminating path, a is NOT adjacent to c
607                // Let me re-check: we need a not adjacent to c
608            }
609
610            // Simplified discriminating path: look for path a -> b *-> c
611            // where a is not adjacent to c, a -> b, and a is a parent of c
612            // This is a simplified version of R4
613            for a in 0..p {
614                if a == b || a == c {
615                    continue;
616                }
617                if graph.is_adjacent(a, c) {
618                    continue; // a must not be adj c
619                }
620                if !graph.is_adjacent(a, b) {
621                    continue;
622                }
623                // Check a *-> b
624                if graph.get_mark_at(a, b) != Some(EdgeMark::Arrow) {
625                    continue;
626                }
627
628                // Found discriminating path candidate
629                let key = (a.min(c), a.max(c));
630                let sep = sep_sets.get(&key).cloned().unwrap_or_default();
631
632                if sep.contains(&b) {
633                    // b is in sep(a, c) => orient b - c as tail
634                    let mark_from_b = graph.get_mark_from(b, c).unwrap_or(EdgeMark::Circle);
635                    let _mark_at_c = EdgeMark::Arrow;
636                    // Actually orient the circle: replace circle at b-side with tail
637                    graph.set_edge(b, c, EdgeMark::Tail, EdgeMark::Arrow);
638                    let _ = mark_from_b;
639                } else {
640                    // b not in sep(a,c) => orient as bidirected b <-> c
641                    graph.set_edge(b, c, EdgeMark::Arrow, EdgeMark::Arrow);
642                }
643                changed = true;
644                break;
645            }
646        }
647    }
648    changed
649}
650
651/// R5: If a o-o b and there is an uncovered circle path from a to b
652/// (all edges are o-o and consecutive nodes on the path are non-adjacent
653/// except via the path), orient a o-o b as a - b (tail-tail).
654fn fci_r5(graph: &mut CausalGraph, _adj: &[Vec<bool>], p: usize) -> bool {
655    let mut changed = false;
656    for a in 0..p {
657        for b in (a + 1)..p {
658            if !graph.is_adjacent(a, b) {
659                continue;
660            }
661            // Check a o-o b
662            if graph.get_mark_from(a, b) != Some(EdgeMark::Circle)
663                || graph.get_mark_at(a, b) != Some(EdgeMark::Circle)
664            {
665                continue;
666            }
667            // Look for uncovered circle path from a to b (length >= 3)
668            if has_uncovered_circle_path(graph, a, b, p) {
669                graph.set_edge(a, b, EdgeMark::Tail, EdgeMark::Tail);
670                changed = true;
671            }
672        }
673    }
674    changed
675}
676
677/// R6: If a - b o-* c, orient b -* c (change circle at b to tail).
678fn fci_r6(graph: &mut CausalGraph, p: usize) -> bool {
679    let mut changed = false;
680    for b in 0..p {
681        for a in 0..p {
682            if a == b || !graph.is_adjacent(a, b) {
683                continue;
684            }
685            // a - b (tail-tail, undirected)
686            if graph.get_mark_from(a, b) != Some(EdgeMark::Tail)
687                || graph.get_mark_at(a, b) != Some(EdgeMark::Tail)
688            {
689                continue;
690            }
691            for c in 0..p {
692                if c == a || c == b || !graph.is_adjacent(b, c) {
693                    continue;
694                }
695                // b o-* c
696                if graph.get_mark_from(b, c) != Some(EdgeMark::Circle) {
697                    continue;
698                }
699                // Orient: change circle at b-side to tail
700                let mark_at_c = graph.get_mark_at(b, c).unwrap_or(EdgeMark::Circle);
701                graph.set_edge(b, c, EdgeMark::Tail, mark_at_c);
702                changed = true;
703            }
704        }
705    }
706    changed
707}
708
709/// R7: If a -o b o-* c, a not adj c, orient b -* c (circle at b -> tail).
710fn fci_r7(graph: &mut CausalGraph, p: usize) -> bool {
711    let mut changed = false;
712    for b in 0..p {
713        for a in 0..p {
714            if a == b || !graph.is_adjacent(a, b) {
715                continue;
716            }
717            // a -o b: tail at a, circle at b
718            if graph.get_mark_from(a, b) != Some(EdgeMark::Tail)
719                || graph.get_mark_at(a, b) != Some(EdgeMark::Circle)
720            {
721                continue;
722            }
723            for c in 0..p {
724                if c == a || c == b || !graph.is_adjacent(b, c) {
725                    continue;
726                }
727                // a not adj c
728                if graph.is_adjacent(a, c) {
729                    continue;
730                }
731                // b o-* c
732                if graph.get_mark_from(b, c) != Some(EdgeMark::Circle) {
733                    continue;
734                }
735                let mark_at_c = graph.get_mark_at(b, c).unwrap_or(EdgeMark::Circle);
736                graph.set_edge(b, c, EdgeMark::Tail, mark_at_c);
737                changed = true;
738            }
739        }
740    }
741    changed
742}
743
744/// R8: If a -> b -> c or a -o b -> c, and a o-> c, orient a -> c.
745fn fci_r8(graph: &mut CausalGraph, p: usize) -> bool {
746    let mut changed = false;
747    for a in 0..p {
748        for c in 0..p {
749            if a == c || !graph.is_adjacent(a, c) {
750                continue;
751            }
752            // a o-> c
753            if graph.get_mark_from(a, c) != Some(EdgeMark::Circle)
754                || graph.get_mark_at(a, c) != Some(EdgeMark::Arrow)
755            {
756                continue;
757            }
758            for b in 0..p {
759                if b == a || b == c {
760                    continue;
761                }
762                // b -> c
763                if graph.get_mark_from(b, c) != Some(EdgeMark::Tail)
764                    || graph.get_mark_at(b, c) != Some(EdgeMark::Arrow)
765                {
766                    continue;
767                }
768                // a -> b or a -o b
769                let mark_at_b = graph.get_mark_at(a, b);
770                let mark_from_a_to_b = graph.get_mark_from(a, b);
771                let valid = match (mark_from_a_to_b, mark_at_b) {
772                    (Some(EdgeMark::Tail), Some(EdgeMark::Arrow)) => true, // a -> b
773                    (Some(EdgeMark::Tail), Some(EdgeMark::Circle)) => true, // a -o b
774                    _ => false,
775                };
776                if valid {
777                    graph.set_edge(a, c, EdgeMark::Tail, EdgeMark::Arrow);
778                    changed = true;
779                    break;
780                }
781            }
782        }
783    }
784    changed
785}
786
787/// R9: If a o-> c and there is a directed path from a to c
788/// (a -> ... -> c through intermediate nodes), orient a -> c.
789fn fci_r9(graph: &mut CausalGraph, p: usize) -> bool {
790    let mut changed = false;
791    for a in 0..p {
792        for c in 0..p {
793            if a == c || !graph.is_adjacent(a, c) {
794                continue;
795            }
796            // a o-> c
797            if graph.get_mark_from(a, c) != Some(EdgeMark::Circle)
798                || graph.get_mark_at(a, c) != Some(EdgeMark::Arrow)
799            {
800                continue;
801            }
802            // Check for directed path from a to c not through the direct edge
803            if has_directed_path_excluding_direct(graph, a, c, p) {
804                graph.set_edge(a, c, EdgeMark::Tail, EdgeMark::Arrow);
805                changed = true;
806            }
807        }
808    }
809    changed
810}
811
812/// R10: If a o-> c, b -> c, d -> c, a o-o b, a o-o d,
813/// and there exists a directed path from b to a or from d to a,
814/// orient a -> c.
815fn fci_r10(graph: &mut CausalGraph, p: usize) -> bool {
816    let mut changed = false;
817    for a in 0..p {
818        for c in 0..p {
819            if a == c || !graph.is_adjacent(a, c) {
820                continue;
821            }
822            // a o-> c
823            if graph.get_mark_from(a, c) != Some(EdgeMark::Circle)
824                || graph.get_mark_at(a, c) != Some(EdgeMark::Arrow)
825            {
826                continue;
827            }
828            // Find b, d: b -> c, d -> c, a o-o b, a o-o d
829            let parents_c: Vec<usize> = (0..p)
830                .filter(|&k| {
831                    k != a
832                        && k != c
833                        && graph.get_mark_from(k, c) == Some(EdgeMark::Tail)
834                        && graph.get_mark_at(k, c) == Some(EdgeMark::Arrow)
835                })
836                .collect();
837
838            let mut orient = false;
839            for i in 0..parents_c.len() {
840                for j in (i + 1)..parents_c.len() {
841                    let b = parents_c[i];
842                    let d = parents_c[j];
843                    // a o-o b and a o-o d
844                    let a_oo_b = graph.get_mark_from(a, b) == Some(EdgeMark::Circle)
845                        && graph.get_mark_at(a, b) == Some(EdgeMark::Circle);
846                    let a_oo_d = graph.get_mark_from(a, d) == Some(EdgeMark::Circle)
847                        && graph.get_mark_at(a, d) == Some(EdgeMark::Circle);
848                    if !a_oo_b || !a_oo_d {
849                        continue;
850                    }
851                    // Directed path from b to a or d to a
852                    if has_directed_path_general(graph, b, a, p)
853                        || has_directed_path_general(graph, d, a, p)
854                    {
855                        orient = true;
856                        break;
857                    }
858                }
859                if orient {
860                    break;
861                }
862            }
863            if orient {
864                graph.set_edge(a, c, EdgeMark::Tail, EdgeMark::Arrow);
865                changed = true;
866            }
867        }
868    }
869    changed
870}
871
872// ---------------------------------------------------------------------------
873// Path helpers
874// ---------------------------------------------------------------------------
875
876/// Check if there is an uncovered circle path (all o-o edges) from src to dst.
877fn has_uncovered_circle_path(graph: &CausalGraph, src: usize, dst: usize, p: usize) -> bool {
878    // BFS for o-o paths of length >= 3
879    let mut visited = vec![false; p];
880    visited[src] = true;
881    let mut queue = VecDeque::new();
882
883    // Start: neighbours of src connected by o-o edges
884    for k in 0..p {
885        if k == dst || k == src {
886            continue;
887        }
888        if graph.is_adjacent(src, k)
889            && graph.get_mark_from(src, k) == Some(EdgeMark::Circle)
890            && graph.get_mark_at(src, k) == Some(EdgeMark::Circle)
891        {
892            queue.push_back((k, 2usize)); // (node, path_length)
893        }
894    }
895
896    while let Some((cur, len)) = queue.pop_front() {
897        if visited[cur] {
898            continue;
899        }
900        visited[cur] = true;
901
902        // Check if cur connects to dst via o-o
903        if graph.is_adjacent(cur, dst)
904            && graph.get_mark_from(cur, dst) == Some(EdgeMark::Circle)
905            && graph.get_mark_at(cur, dst) == Some(EdgeMark::Circle)
906            && len + 1 >= 3
907        {
908            return true;
909        }
910
911        // Continue along o-o edges
912        for next in 0..p {
913            if visited[next] || next == src || next == dst {
914                continue;
915            }
916            if graph.is_adjacent(cur, next)
917                && graph.get_mark_from(cur, next) == Some(EdgeMark::Circle)
918                && graph.get_mark_at(cur, next) == Some(EdgeMark::Circle)
919            {
920                queue.push_back((next, len + 1));
921            }
922        }
923    }
924    false
925}
926
927/// Check for directed path from src to dst (using -> edges only), excluding
928/// the direct edge src-dst.
929fn has_directed_path_excluding_direct(
930    graph: &CausalGraph,
931    src: usize,
932    dst: usize,
933    p: usize,
934) -> bool {
935    let mut visited = vec![false; p];
936    let mut stack = Vec::new();
937    // Start from src's directed children (excluding dst directly)
938    for k in 0..p {
939        if k != dst
940            && graph.get_mark_from(src, k) == Some(EdgeMark::Tail)
941            && graph.get_mark_at(src, k) == Some(EdgeMark::Arrow)
942        {
943            stack.push(k);
944        }
945    }
946
947    while let Some(cur) = stack.pop() {
948        if cur == dst {
949            return true;
950        }
951        if visited[cur] {
952            continue;
953        }
954        visited[cur] = true;
955        for next in 0..p {
956            if !visited[next]
957                && graph.get_mark_from(cur, next) == Some(EdgeMark::Tail)
958                && graph.get_mark_at(cur, next) == Some(EdgeMark::Arrow)
959            {
960                stack.push(next);
961            }
962        }
963    }
964    false
965}
966
967/// Check for directed path from src to dst (general, including all -> edges).
968fn has_directed_path_general(graph: &CausalGraph, src: usize, dst: usize, p: usize) -> bool {
969    let mut visited = vec![false; p];
970    let mut stack = vec![src];
971    while let Some(cur) = stack.pop() {
972        if cur == dst && cur != src {
973            return true;
974        }
975        if visited[cur] {
976            continue;
977        }
978        visited[cur] = true;
979        for next in 0..p {
980            if !visited[next]
981                && graph.get_mark_from(cur, next) == Some(EdgeMark::Tail)
982                && graph.get_mark_at(cur, next) == Some(EdgeMark::Arrow)
983            {
984                stack.push(next);
985            }
986        }
987    }
988    false
989}
990
991// ---------------------------------------------------------------------------
992// Tests
993// ---------------------------------------------------------------------------
994
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Deterministic LCG producing uniforms in [0, 1) from the top 53 bits.
    fn lcg_uniform(s: &mut u64) -> f64 {
        *s = s
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((*s >> 11) as f64) / ((1u64 << 53) as f64)
    }

    /// Standard normal sample via Box-Muller (u1 clamped away from 0 for `ln`).
    fn lcg_normal(s: &mut u64) -> f64 {
        let u1 = lcg_uniform(s).max(1e-15);
        let u2 = lcg_uniform(s);
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }

    /// Generate chain X -> Y -> Z data.
    fn chain_data(n: usize, seed: u64) -> Array2<f64> {
        let mut data = Array2::<f64>::zeros((n, 3));
        let mut lcg = seed;
        for i in 0..n {
            data[[i, 0]] = lcg_normal(&mut lcg);
            data[[i, 1]] = 0.9 * data[[i, 0]] + lcg_normal(&mut lcg) * 0.3;
            data[[i, 2]] = 0.9 * data[[i, 1]] + lcg_normal(&mut lcg) * 0.3;
        }
        data
    }

    /// Generate data with a latent confounder: X <- L -> Y, X -> Z, Y -> Z.
    /// (L is not observed, so we only have X, Y, Z)
    fn latent_confounder_data(n: usize, seed: u64) -> Array2<f64> {
        let mut data = Array2::<f64>::zeros((n, 3));
        let mut lcg = seed;
        for i in 0..n {
            let latent = lcg_normal(&mut lcg);
            data[[i, 0]] = 0.8 * latent + lcg_normal(&mut lcg) * 0.3;
            data[[i, 1]] = 0.8 * latent + lcg_normal(&mut lcg) * 0.3;
            data[[i, 2]] = 0.5 * data[[i, 0]] + 0.5 * data[[i, 1]] + lcg_normal(&mut lcg) * 0.3;
        }
        data
    }

    #[test]
    fn test_fci_chain() {
        let data = chain_data(300, 12345);
        let fci = FciAlgorithm::new(0.05);
        let result = fci.fit(data.view(), &["X", "Y", "Z"]).expect("FCI failed");
        // X-Y adjacent, Y-Z adjacent, X-Z not adjacent
        assert!(
            result.graph.is_adjacent(0, 1),
            "X-Y should be adjacent in chain"
        );
        assert!(
            result.graph.is_adjacent(1, 2),
            "Y-Z should be adjacent in chain"
        );
        assert!(
            !result.graph.is_adjacent(0, 2),
            "X-Z should not be adjacent"
        );
    }

    #[test]
    fn test_fci_latent_confounder() {
        let data = latent_confounder_data(500, 54321);
        let fci = FciAlgorithm::new(0.05);
        let result = fci.fit(data.view(), &["X", "Y", "Z"]).expect("FCI failed");
        // With a latent confounder between X and Y, FCI should detect
        // something unusual in the graph structure
        assert!(
            result.graph.is_adjacent(0, 1) || result.graph.is_adjacent(0, 2),
            "Should find some adjacency"
        );
        assert!(result.n_tests > 0, "Should perform CI tests");
    }

    #[test]
    fn test_fci_produces_pag() {
        let data = chain_data(200, 99999);
        let fci = FciAlgorithm::new(0.05);
        let result = fci.fit(data.view(), &["X", "Y", "Z"]).expect("FCI failed");
        // PAG should have 3 nodes
        assert_eq!(result.graph.num_vars(), 3);
    }

    #[test]
    fn test_fci_collider_detection() {
        // X -> Z <- Y, X and Y independent
        let n = 300;
        let mut data = Array2::<f64>::zeros((n, 3));
        let mut lcg: u64 = 77777;
        for i in 0..n {
            data[[i, 0]] = lcg_normal(&mut lcg);
            data[[i, 1]] = lcg_normal(&mut lcg);
            data[[i, 2]] = 0.7 * data[[i, 0]] + 0.7 * data[[i, 1]] + lcg_normal(&mut lcg) * 0.3;
        }
        let fci = FciAlgorithm::new(0.05);
        let result = fci.fit(data.view(), &["X", "Y", "Z"]).expect("FCI failed");
        // X-Z and Y-Z should be adjacent, X-Y not
        assert!(result.graph.is_adjacent(0, 2), "X-Z should be adjacent");
        assert!(result.graph.is_adjacent(1, 2), "Y-Z should be adjacent");
        assert!(
            !result.graph.is_adjacent(0, 1),
            "X-Y should not be adjacent"
        );
        // Z should have arrowheads from both X and Y (v-structure X *-> Z <-* Y);
        // a half-oriented collider must not pass.
        assert_eq!(
            result.graph.get_mark_at(0, 2),
            Some(EdgeMark::Arrow),
            "Should detect v-structure at Z (arrowhead at Z on X-Z)"
        );
        assert_eq!(
            result.graph.get_mark_at(1, 2),
            Some(EdgeMark::Arrow),
            "Should detect v-structure at Z (arrowhead at Z on Y-Z)"
        );
    }

    #[test]
    fn test_fci_possible_dsep() {
        let mut graph = CausalGraph::new(&["A", "B", "C", "D"]);
        // A o-> B, B o-o C, C o-> D
        graph.set_edge(0, 1, EdgeMark::Circle, EdgeMark::Arrow);
        graph.set_edge(1, 2, EdgeMark::Circle, EdgeMark::Circle);
        graph.set_edge(2, 3, EdgeMark::Circle, EdgeMark::Arrow);

        let pdsep = possible_dsep(&graph, 0, 3, 4);
        // Should include B and C as possible d-separating nodes
        assert!(
            pdsep.contains(&1) || pdsep.contains(&2),
            "Possible-D-SEP should contain intermediate nodes"
        );
    }

    #[test]
    fn test_fci_r1_orientation() {
        let mut graph = CausalGraph::new(&["A", "B", "C"]);
        // a *-> b o-* c, a not adj c
        graph.set_edge(0, 1, EdgeMark::Tail, EdgeMark::Arrow); // a -> b
        graph.set_edge(1, 2, EdgeMark::Circle, EdgeMark::Circle); // b o-o c
        // a not adjacent to c

        let changed = fci_r1(&mut graph, 3);
        // R1 should orient b -> c (change circle at b to tail)
        assert!(changed, "R1 should make a change");
        assert_eq!(
            graph.get_mark_from(1, 2),
            Some(EdgeMark::Tail),
            "R1: b side should be tail"
        );
    }

    #[test]
    fn test_fci_edge_marks() {
        let mut graph = CausalGraph::new(&["A", "B", "C"]);
        graph.set_edge(0, 1, EdgeMark::Tail, EdgeMark::Arrow);
        graph.set_edge(1, 2, EdgeMark::Arrow, EdgeMark::Arrow);

        assert!(graph.is_directed(0, 1), "A -> B");
        assert!(graph.is_bidirected(1, 2), "B <-> C");
        assert!(!graph.is_undirected(0, 1), "A -> B is not undirected");
    }

    #[test]
    fn test_fci_empty_graph() {
        let data = Array2::<f64>::zeros((10, 0));
        let fci = FciAlgorithm::new(0.05);
        let result = fci.fit(data.view(), &[]).expect("FCI should handle empty");
        assert_eq!(result.graph.num_vars(), 0);
        assert_eq!(result.n_tests, 0);
    }

    #[test]
    fn test_fci_two_vars() {
        let n = 200;
        let mut data = Array2::<f64>::zeros((n, 2));
        let mut lcg: u64 = 11111;
        for i in 0..n {
            data[[i, 0]] = lcg_normal(&mut lcg);
            data[[i, 1]] = 0.9 * data[[i, 0]] + lcg_normal(&mut lcg) * 0.3;
        }
        let fci = FciAlgorithm::new(0.05);
        let result = fci.fit(data.view(), &["X", "Y"]).expect("FCI with 2 vars");
        assert!(result.graph.is_adjacent(0, 1), "X-Y should be adjacent");
    }
}