Skip to main content

scirs2_stats/causal/
id_algorithm.rs

1//! Shpitser-Pearl ID Algorithm for Causal Effect Identification
2//!
3//! Implements Algorithm 1 from Shpitser & Pearl (AAAI 2006):
4//!
5//! > **ID Algorithm**: Given a semi-Markovian causal model with DAG G,
6//! > observed variables V, and query P(y | do(x)), the algorithm either
7//! > returns a closed-form expression for P(y | do(x)) in terms of the
8//! > observational distribution P(V), or certifies non-identifiability
9//! > by returning a hedge certificate.
10//!
11//! # Algorithm Overview (Algorithm 1 of Shpitser-Pearl 2006)
12//!
13//! ```text
14//! ID(y, x, P, G):
15//!   V = all nodes in G
16//!   Line 1: if x = ∅, return Σ_{v \ y} P(V)
//!   Line 2: if V ≠ An(Y)_G:
//!             return ID(y, x ∩ An(Y)_G, Σ_{v \ An(Y)_G} P, G[An(Y)_G])
//!   Line 3: let W = (V \ X) \ An(Y)_{G_X̄}   (equivalently An(Y) in G[V\X]); if W ≠ ∅:
//!             return ID(y, x ∪ W, P, G)
//!   Line 4: if C(G[V\X]) = {S₁,...,Sₖ} with k > 1:
//!             return Σ_{v \ (y ∪ x)} ∏ᵢ ID(Sᵢ, V \ Sᵢ, P, G)
//!           otherwise C(G[V\X]) = {S}:
//!   Line 5:   if C(G) = {V}: FAIL(G, S)   [hedge found]
//!   Line 6:   if S ∈ C(G):
//!             return Σ_{s \ y} ∏_{Vᵢ ∈ S} P(vᵢ | v_π^{(i-1)})
//!   Line 7:   if ∃ S' ∈ C(G) with S ⊊ S':
//!             return ID(y, x ∩ S', ∏_{Vᵢ ∈ S'} P(Vᵢ | V_π^{(i-1)} ∩ S', v_π^{(i-1)} \ S'), G[S'])
//!
//! Here π is a topological order of G, v_π^{(i-1)} are the variables preceding Vᵢ,
//! and P always denotes the distribution passed to the current call.
30//! ```
31//!
32//! # Do-Calculus Rules
33//!
34//! - **Rule 1**: P(y | do(x), z, w) = P(y | do(x), w) when (Y ⊥ Z | X, W) in G_{X̄}
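//! - **Rule 2**: P(y | do(x), do(z), w) = P(y | do(x), z, w) when (Y ⊥ Z | X, W) in G_{X̄, Z̲}
//!   (edges into X and edges out of Z removed)
//! - **Rule 3**: P(y | do(x), do(z), w) = P(y | do(x), w) when (Y ⊥ Z | X, W) in G_{X̄, \overline{Z(W)}},
//!   where Z(W) is the set of Z-nodes that are not ancestors of any W-node in G_{X̄}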
37//!
38//! # References
39//!
40//! - Shpitser, I. & Pearl, J. (2006). Identification of Joint Interventional
41//!   Distributions in Recursive Semi-Markovian Causal Models. *AAAI 2006*.
42//! - Tian, J. & Pearl, J. (2002). A General Identification Condition for
43//!   Causal Effects. *AAAI 2002*.
44
45use std::collections::BTreeSet;
46
47use crate::causal::hedge::{
48    ancestors_of, c_components_in_subgraph, topological_order, HedgeCertificate,
49};
50use crate::causal::semi_markov_graph::SemiMarkovGraph;
51use crate::causal::symbolic_prob::ProbExpr;
52
53// ---------------------------------------------------------------------------
54// IdResult
55// ---------------------------------------------------------------------------
56
57/// Result of the ID algorithm.
58#[derive(Debug, Clone)]
59#[non_exhaustive]
60pub enum IdResult {
61    /// The query P(y | do(x)) is identifiable.
62    Identified(ProbExpr),
63    /// The query is NOT identifiable.
64    NotIdentifiable(HedgeCertificate),
65}
66
67impl IdResult {
68    /// Returns `true` if the effect is identifiable.
69    pub fn is_identified(&self) -> bool {
70        matches!(self, IdResult::Identified(_))
71    }
72
73    /// Return the expression if identified, or `None`.
74    pub fn expression(&self) -> Option<&ProbExpr> {
75        match self {
76            IdResult::Identified(e) => Some(e),
77            IdResult::NotIdentifiable(_) => None,
78        }
79    }
80
81    /// Return the hedge certificate if not identifiable, or `None`.
82    pub fn hedge(&self) -> Option<&HedgeCertificate> {
83        match self {
84            IdResult::Identified(_) => None,
85            IdResult::NotIdentifiable(h) => Some(h),
86        }
87    }
88}
89
90// ---------------------------------------------------------------------------
91// Do-calculus rule predicates
92// ---------------------------------------------------------------------------
93
94/// Predicate for do-calculus Rule 1 (insertion/deletion of observations).
95///
96/// Returns `true` iff (Y ⊥ Z | X, W) in G_{X̄}.
97pub fn do_calculus_rule1(
98    graph: &SemiMarkovGraph,
99    y: &BTreeSet<String>,
100    x: &BTreeSet<String>,
101    z: &BTreeSet<String>,
102    w: &BTreeSet<String>,
103) -> bool {
104    let g_xbar = graph.mutilate(x);
105    let conditioning: BTreeSet<String> = x.union(w).cloned().collect();
106    d_separated_set(&g_xbar, y, z, &conditioning)
107}
108
109/// Predicate for do-calculus Rule 2 (action/observation exchange).
110///
111/// Returns `true` iff (Y ⊥ Z | X, W) in G_{X̄, Z̄}.
112pub fn do_calculus_rule2(
113    graph: &SemiMarkovGraph,
114    y: &BTreeSet<String>,
115    x: &BTreeSet<String>,
116    z: &BTreeSet<String>,
117    w: &BTreeSet<String>,
118) -> bool {
119    let xz: BTreeSet<String> = x.union(z).cloned().collect();
120    let g_xbar_zbar = graph.mutilate(&xz);
121    let conditioning: BTreeSet<String> = x.union(w).cloned().collect();
122    d_separated_set(&g_xbar_zbar, y, z, &conditioning)
123}
124
125/// Predicate for do-calculus Rule 3 (insertion/deletion of actions).
126///
127/// Returns `true` iff (Y ⊥ Z | X, W) in G_{X̄, Z(W̄)}.
128pub fn do_calculus_rule3(
129    graph: &SemiMarkovGraph,
130    y: &BTreeSet<String>,
131    x: &BTreeSet<String>,
132    z: &BTreeSet<String>,
133    w: &BTreeSet<String>,
134) -> bool {
135    let mut g_modified = graph.mutilate(x);
136    let anc_w = g_modified.ancestors(w);
137    for z_node in z {
138        let parents: Vec<String> = g_modified.parents(z_node).collect();
139        for parent in parents {
140            if !anc_w.contains(&parent) {
141                g_modified.remove_directed(&parent, z_node);
142            }
143        }
144    }
145    let conditioning: BTreeSet<String> = x.union(w).cloned().collect();
146    d_separated_set(&g_modified, y, z, &conditioning)
147}
148
149// ---------------------------------------------------------------------------
150// IdAlgorithm
151// ---------------------------------------------------------------------------
152
153/// The Shpitser-Pearl ID algorithm for causal effect identification.
154pub struct IdAlgorithm;
155
156impl IdAlgorithm {
157    /// Run the ID algorithm to identify P(y | do(x)).
158    ///
159    /// # Parameters
160    ///
161    /// - `y`         – outcome variable names
162    /// - `x`         – intervention variable names (do(x))
163    /// - `obs_dist`  – the observational joint distribution P(V)
164    /// - `dag`       – the semi-Markovian causal graph
165    pub fn identify(
166        y: &[String],
167        x: &[String],
168        obs_dist: ProbExpr,
169        dag: &SemiMarkovGraph,
170    ) -> IdResult {
171        let v: BTreeSet<String> = dag.node_set();
172        let y_set: BTreeSet<String> = y.iter().cloned().collect();
173        let x_set: BTreeSet<String> = x.iter().cloned().collect();
174        id_recursive(&y_set, &x_set, &obs_dist, dag, &v, 0)
175    }
176}
177
178// ---------------------------------------------------------------------------
179// Core recursive ID procedure — Algorithm 1 of Shpitser-Pearl (AAAI 2006)
180// ---------------------------------------------------------------------------
181
182/// Recursive implementation.
183///
184/// Parameters follow Algorithm 1:
185/// - `y` — target variable set (what we want to observe)
186/// - `x` — intervention set (do(x))
187/// - `p` — current available distribution (symbolic)
188/// - `g` — current subgraph
189/// - `v` — current variable scope
190/// - `depth` — recursion depth guard
191fn id_recursive(
192    y: &BTreeSet<String>,
193    x: &BTreeSet<String>,
194    p: &ProbExpr,
195    g: &SemiMarkovGraph,
196    v: &BTreeSet<String>,
197    depth: usize,
198) -> IdResult {
199    const MAX_DEPTH: usize = 64;
200    if depth > MAX_DEPTH {
201        return IdResult::NotIdentifiable(HedgeCertificate {
202            s_component: v.clone(),
203            blocking_x: x.clone(),
204            outcome_y: y.clone(),
205            explanation: "Recursion depth exceeded — potential cycle in ID algorithm.".to_string(),
206        });
207    }
208
209    // -------------------------------------------------------------------
210    // Line 1: if x = ∅, return Σ_{v \ y} P(v)
211    // -------------------------------------------------------------------
212    if x.is_empty() {
213        return marginal_over(p, v, y);
214    }
215
216    // -------------------------------------------------------------------
217    // Line 2: W = An(Y)_G  (ancestors of Y in G, including Y itself)
218    //   if W ≠ V (not all variables are ancestors of Y):
219    //     return ID(y, x ∩ W, P(W), G[W])
220    //
221    // This restricts the graph to the "relevant" part: variables that are
222    // actually on causal/confounding paths to Y.
223    // -------------------------------------------------------------------
224    let an_y: BTreeSet<String> = ancestors_of(g, &y.iter().cloned().collect::<Vec<_>>());
225
226    // V \ X (for use in Lines 3-7)
227    let v_minus_x: BTreeSet<String> = v.difference(x).cloned().collect();
228
229    if an_y != *v {
230        // Some variables in V are NOT ancestors of Y → restrict to An(Y)_G
231        let w = an_y; // = An(Y)_G  (subset of V)
232        let g_w = g.subgraph(&w);
233        let new_x: BTreeSet<String> = x.intersection(&w).cloned().collect();
234        let p_w = marginal_to_scope(p, v, &w);
235        return id_recursive(y, &new_x, &p_w, &g_w, &w, depth + 1);
236    }
237
238    // -------------------------------------------------------------------
239    // Lines 4-7: C(G[V\X]) analysis (checked before Line 3 to correctly
240    // handle instrument variable (IV) identification patterns).
241    //
242    // When C(G[V\X]) has multiple components, we decompose immediately.
243    // This is crucial for graphs like IV (Z → X → Y, X ↔ Y) where
244    // C(G[{Z,Y}]) = {{Z},{Y}} correctly identifies the effect before
245    // non-ancestral variable removal (Line 3) can interfere.
246    // -------------------------------------------------------------------
247    let components_vmx = c_components_in_subgraph(g, &v_minus_x);
248
249    // -------------------------------------------------------------------
250    // Line 4: C(G[V\X]) = {S₁, ..., Sₖ} with k > 1
251    //   return Σ_{v \ (y ∪ x)} ∏ ID(Sᵢ, V \ Sᵢ, P, G)
252    // -------------------------------------------------------------------
253    if components_vmx.len() > 1 {
254        let mut factor_results: Vec<ProbExpr> = Vec::new();
255
256        for si in &components_vmx {
257            let v_minus_si: BTreeSet<String> = v.difference(si).cloned().collect();
258            let sub = id_recursive(si, &v_minus_si, p, g, v, depth + 1);
259            match sub {
260                IdResult::Identified(expr) => factor_results.push(expr),
261                not_id => return not_id,
262            }
263        }
264
265        let product = make_product(factor_results);
266
267        // Marginalize over V \ (Y ∪ X): we want P(Y | do(X)) so sum out
268        // everything in (V \ X) \ Y
269        let sum_out: Vec<String> = {
270            let mut sv: Vec<String> = v_minus_x.difference(y).cloned().collect();
271            sv.sort();
272            sv
273        };
274
275        let result = if sum_out.is_empty() {
276            product
277        } else {
278            ProbExpr::Marginal {
279                expr: Box::new(product),
280                summand_vars: sum_out,
281            }
282            .simplify()
283        };
284
285        return IdResult::Identified(result);
286    }
287
288    // From this point: C(G[V\X]) has exactly 1 component.
289    // Before checking Lines 5-7, apply Line 3 to reduce scope.
290
291    // -------------------------------------------------------------------
292    // Line 3: W = (V \ X) \ An(Y)_{G[V\X]}
293    //   if W ≠ ∅: ID(y, x ∪ W, P, G)
294    //
295    // Variables in V\X that are not ancestral to Y in G[V\X] can be
296    // safely "intervened on" without changing the identification result.
297    // Adding them to x strictly increases the intervention set, ensuring termination.
298    // -------------------------------------------------------------------
299    {
300        let g_v_minus_x = g.subgraph(&v_minus_x);
301        let an_y_in_g_vmx: BTreeSet<String> =
302            ancestors_of(&g_v_minus_x, &y.iter().cloned().collect::<Vec<_>>());
303        let an_y_vmx_restricted: BTreeSet<String> =
304            an_y_in_g_vmx.intersection(&v_minus_x).cloned().collect();
305        let w_line3: BTreeSet<String> = v_minus_x
306            .difference(&an_y_vmx_restricted)
307            .cloned()
308            .collect();
309
310        if !w_line3.is_empty() {
311            let new_x: BTreeSet<String> = x.union(&w_line3).cloned().collect();
312            return id_recursive(y, &new_x, p, g, v, depth + 1);
313        }
314    }
315
316    // -------------------------------------------------------------------
317    // Line 5: C(G[V\X]) = {V\X}
318    //   if C(G) = {G} (G itself is a single c-component): FAIL (hedge)
319    //   else: proceed to Lines 6-7
320    // -------------------------------------------------------------------
321    let components_full = c_components_in_subgraph(g, v);
322
323    if components_full.len() == 1 && components_full[0] == *v {
324        // The whole graph is one c-component AND V\X is also one c-component
325        // → hedge: there is no way to identify P(y | do(x))
326        return IdResult::NotIdentifiable(HedgeCertificate {
327            s_component: v.clone(),
328            blocking_x: x.clone(),
329            outcome_y: y.clone(),
330            explanation: format!(
331                "Hedge: the entire variable set {:?} forms a single c-component in G, \
332                 and G[V\\X] = {:?} is also a single c-component. \
333                 P({:?} | do({:?})) is not identifiable.",
334                v, v_minus_x, y, x
335            ),
336        });
337    }
338
339    // Lines 6-7: there are multiple c-components in G, or G has a proper
340    // c-component structure.
341    //
342    // V \ X is a single c-component (from Line 4 filter above).
343    // Find the c-component(s) in G that contain parts of V \ X.
344
345    // For the single component S in C(G[V\X]) (which equals V\X):
346    let s_vmx = &v_minus_x; // The single c-component of G[V\X]
347
348    // -------------------------------------------------------------------
349    // Line 6: if S ∈ C(G) (i.e., S is also a c-component in the full graph G)
350    //   apply Tian-Pearl factorization within S
351    // -------------------------------------------------------------------
352    // Check if S_vmx is itself a c-component in the full graph
353    let s_is_full_comp = components_full.iter().any(|fc| fc == s_vmx);
354
355    if s_is_full_comp {
356        // Tian-Pearl sum-product formula:
357        // Σ_{S \ Y} ∏_{Vᵢ ∈ S} P(Vᵢ | V_{π<i} ∩ S, V_{π<i} \ S)
358        // where the ordering is the topological order of the full graph G.
359        let topo_full = topological_order(g);
360        let factors = build_tian_pearl_factors(s_vmx, &topo_full, v);
361        let product = make_product(factors);
362
363        let sum_out: Vec<String> = {
364            let mut sv: Vec<String> = s_vmx.difference(y).cloned().collect();
365            sv.sort();
366            sv
367        };
368
369        let result = if sum_out.is_empty() {
370            product
371        } else {
372            ProbExpr::Marginal {
373                expr: Box::new(product),
374                summand_vars: sum_out,
375            }
376            .simplify()
377        };
378
379        return IdResult::Identified(result);
380    }
381
382    // -------------------------------------------------------------------
383    // Line 7: ∃ S' ∈ C(G) such that S_vmx ⊊ S'
384    //   recurse: ID(y, x ∩ S', ∏_{Vᵢ ∈ S'} P(Vᵢ | V_{π<i} ∩ S'), G[S'])
385    // -------------------------------------------------------------------
386    let s_prime_opt = components_full
387        .iter()
388        .find(|fc| s_vmx.is_subset(fc) && *fc != s_vmx);
389
390    if let Some(s_prime) = s_prime_opt {
391        let topo_full = topological_order(g);
392
393        // Build P(S') as Tian-Pearl product
394        let topo_sp: Vec<String> = topo_full
395            .iter()
396            .filter(|v| s_prime.contains(*v))
397            .cloned()
398            .collect();
399
400        let factors = build_tian_pearl_factors(s_prime, &topo_full, v);
401        let p_s_prime = make_product(factors);
402
403        let g_s_prime = g.subgraph(s_prime);
404        let new_x: BTreeSet<String> = x.intersection(s_prime).cloned().collect();
405
406        return id_recursive(y, &new_x, &p_s_prime, &g_s_prime, s_prime, depth + 1);
407    }
408
409    // If we reach here: C(G[V\X]) has 1 component = V\X,
410    // C(G) has multiple components but none properly contains V\X.
411    // Per the algorithm this is actually a hedge condition (C(G) intersects X).
412    // Find which c-component of G contains elements of X.
413    for fc in &components_full {
414        let x_in_fc: BTreeSet<String> = x.intersection(fc).cloned().collect();
415        if !x_in_fc.is_empty() {
416            // V\X is a subset of this component (it must be, since V\X is one component
417            // and every non-X node should be reachable)
418            return IdResult::NotIdentifiable(HedgeCertificate {
419                s_component: fc.clone(),
420                blocking_x: x_in_fc,
421                outcome_y: y.clone(),
422                explanation: format!(
423                    "Hedge: c-component {:?} of G contains intervention variables {:?} \
424                     and outcome variables {:?}. P(y|do(x)) is not identifiable.",
425                    fc, x, y
426                ),
427            });
428        }
429    }
430
431    // Fallback (should not be reached in a well-formed call):
432    // Return marginal of P(V) over V \ Y
433    marginal_over(p, v, y)
434}
435
436// ---------------------------------------------------------------------------
437// Tian-Pearl factorization
438// ---------------------------------------------------------------------------
439
440/// Build Tian-Pearl factors: ∏_{Vᵢ ∈ scope} P(Vᵢ | V_{π<i})
441///
442/// where V_{π<i} = all variables before Vᵢ in the full topological order
443/// (intersected with the full variable scope `v_full`).
444fn build_tian_pearl_factors(
445    scope: &BTreeSet<String>,
446    topo_full: &[String],
447    _v_full: &BTreeSet<String>,
448) -> Vec<ProbExpr> {
449    // Build position map
450    let pos: std::collections::HashMap<&str, usize> = topo_full
451        .iter()
452        .enumerate()
453        .map(|(i, v)| (v.as_str(), i))
454        .collect();
455
456    let mut factors: Vec<ProbExpr> = Vec::new();
457
458    // Sort scope by topological position
459    let mut scope_sorted: Vec<&String> = scope.iter().collect();
460    scope_sorted.sort_by_key(|v| pos.get(v.as_str()).copied().unwrap_or(usize::MAX));
461
462    for vi in &scope_sorted {
463        let vi_pos = pos.get(vi.as_str()).copied().unwrap_or(0);
464
465        // All variables in the FULL topological order before vi
466        let preceding: Vec<String> = topo_full.iter().take(vi_pos).cloned().collect();
467
468        let factor = if preceding.is_empty() {
469            // P(Vi) — marginal (Vi has no predecessors in topological order)
470            ProbExpr::Joint(vec![(*vi).clone()])
471        } else {
472            // P(Vi | preceding)
473            // Represented as P(Vi, preceding...) / P(preceding...)
474            // which simplifies to the conditional form
475            ProbExpr::Conditional {
476                numerator: Box::new(ProbExpr::Joint({
477                    let mut vars = vec![(*vi).clone()];
478                    vars.extend(preceding.iter().cloned());
479                    vars.sort();
480                    vars
481                })),
482                denominator: Box::new(ProbExpr::Joint(preceding)),
483            }
484        };
485        factors.push(factor);
486    }
487
488    factors
489}
490
491// ---------------------------------------------------------------------------
492// Expression construction helpers
493// ---------------------------------------------------------------------------
494
495/// Build a product expression, collapsing singletons.
496fn make_product(factors: Vec<ProbExpr>) -> ProbExpr {
497    if factors.is_empty() {
498        ProbExpr::Joint(Vec::new()) // probability 1
499    } else if factors.len() == 1 {
500        factors.into_iter().next().expect("length checked")
501    } else {
502        ProbExpr::Product(factors).simplify()
503    }
504}
505
506/// Return Σ_{v \ y} P(v) — marginalize P(v) to only cover variables y.
507fn marginal_over(p: &ProbExpr, v: &BTreeSet<String>, y: &BTreeSet<String>) -> IdResult {
508    let sum_out: Vec<String> = {
509        let mut sv: Vec<String> = v.difference(y).cloned().collect();
510        sv.sort();
511        sv
512    };
513    if sum_out.is_empty() {
514        IdResult::Identified(p.clone())
515    } else {
516        let result = ProbExpr::Marginal {
517            expr: Box::new(p.clone()),
518            summand_vars: sum_out,
519        }
520        .simplify();
521        IdResult::Identified(result)
522    }
523}
524
525/// Marginalize P(v) down to scope `w` by summing out v \ w.
526fn marginal_to_scope(p: &ProbExpr, v: &BTreeSet<String>, w: &BTreeSet<String>) -> ProbExpr {
527    let sum_out: Vec<String> = {
528        let mut sv: Vec<String> = v.difference(w).cloned().collect();
529        sv.sort();
530        sv
531    };
532    if sum_out.is_empty() {
533        p.clone()
534    } else {
535        ProbExpr::Marginal {
536            expr: Box::new(p.clone()),
537            summand_vars: sum_out,
538        }
539        .simplify()
540    }
541}
542
543// ---------------------------------------------------------------------------
544// D-separation helpers (for do-calculus rule predicates)
545// ---------------------------------------------------------------------------
546
547/// Check d-separation between all pairs (yi, zi) given conditioning set.
548fn d_separated_set(
549    g: &SemiMarkovGraph,
550    y: &BTreeSet<String>,
551    z: &BTreeSet<String>,
552    conditioning: &BTreeSet<String>,
553) -> bool {
554    for yi in y {
555        for zi in z {
556            if !d_separated_pair(g, yi, zi, conditioning) {
557                return false;
558            }
559        }
560    }
561    true
562}
563
564/// Bayes-Ball d-separation for semi-Markovian graphs.
565///
566/// Bidirected edges A ↔ B are treated as paths via a latent H: A ← H → B.
567fn d_separated_pair(
568    g: &SemiMarkovGraph,
569    src: &str,
570    dst: &str,
571    conditioning: &BTreeSet<String>,
572) -> bool {
573    use std::collections::{HashSet, VecDeque};
574
575    if src == dst {
576        return conditioning.contains(src);
577    }
578
579    let ancestors_of_conditioning: BTreeSet<String> = g.ancestors(conditioning);
580
581    // Bayes-Ball state: (node, via_child: bool)
582    // via_child = true  → ball arrived "upward" from a child
583    // via_child = false → ball arrived "downward" from a parent
584    let mut visited: HashSet<(String, bool)> = HashSet::new();
585    let mut queue: VecDeque<(String, bool)> = VecDeque::new();
586
587    queue.push_back((src.to_owned(), true));
588    queue.push_back((src.to_owned(), false));
589
590    while let Some((node, via_child)) = queue.pop_front() {
591        if !visited.insert((node.clone(), via_child)) {
592            continue;
593        }
594        if node == dst {
595            return false; // Active path found
596        }
597
598        let is_obs = conditioning.contains(&node);
599        let is_anc_obs = ancestors_of_conditioning.contains(&node);
600
601        if via_child {
602            if !is_obs {
603                // Chain/fork: propagate to parents (upward) and children (downward)
604                for parent in g.parents(&node) {
605                    queue.push_back((parent, true));
606                }
607                for child in g.children(&node) {
608                    queue.push_back((child, false));
609                }
610                // Bidirected edge: treat as common-cause path
611                for nb in g.bidirected_neighbors(&node) {
612                    queue.push_back((nb, false));
613                }
614            }
615            // Collider activation: if this node (collider) is observed or
616            // is an ancestor of an observed node, activate by propagating upward
617            if is_obs || is_anc_obs {
618                for parent in g.parents(&node) {
619                    queue.push_back((parent, true));
620                }
621            }
622        } else {
623            // via parent
624            if !is_obs {
625                // Chain: propagate downward to children
626                for child in g.children(&node) {
627                    queue.push_back((child, false));
628                }
629                // Bidirected: propagate to bidirected neighbor (common cause link)
630                for nb in g.bidirected_neighbors(&node) {
631                    queue.push_back((nb, false));
632                }
633            } else {
634                // Fork block: but v-structure activation upward
635                for parent in g.parents(&node) {
636                    queue.push_back((parent, true));
637                }
638            }
639        }
640    }
641
642    true // No active path → d-separated
643}
644
645// ---------------------------------------------------------------------------
646// Unit tests
647// ---------------------------------------------------------------------------
648
649#[cfg(test)]
650mod tests {
651    use super::*;
652    use crate::causal::hedge::{c_components_in_subgraph, HedgeFinder};
653    use crate::causal::semi_markov_graph::SemiMarkovGraph;
654    use crate::causal::symbolic_prob::ProbExpr;
655
656    fn s(s: &str) -> String {
657        s.to_owned()
658    }
659
660    // Chain X → Y → Z
661    fn chain_graph() -> SemiMarkovGraph {
662        let mut g = SemiMarkovGraph::new();
663        g.add_directed("X", "Y");
664        g.add_directed("Y", "Z");
665        g
666    }
667
668    // X → Y with X ↔ Y (pure confounder)
669    fn confounded_graph() -> SemiMarkovGraph {
670        let mut g = SemiMarkovGraph::new();
671        g.add_directed("X", "Y");
672        g.add_bidirected("X", "Y");
673        g
674    }
675
676    // Front-door: X → M → Y, X ↔ Y
677    fn frontdoor_graph() -> SemiMarkovGraph {
678        let mut g = SemiMarkovGraph::new();
679        g.add_directed("X", "M");
680        g.add_directed("M", "Y");
681        g.add_bidirected("X", "Y");
682        g
683    }
684
685    // IV: Z → X → Y, X ↔ Y
686    fn iv_graph() -> SemiMarkovGraph {
687        let mut g = SemiMarkovGraph::new();
688        g.add_directed("Z", "X");
689        g.add_directed("X", "Y");
690        g.add_bidirected("X", "Y");
691        g
692    }
693
694    // Backdoor admissible: W → X → Y, W → Y (no hidden confounders)
695    fn backdoor_graph() -> SemiMarkovGraph {
696        let mut g = SemiMarkovGraph::new();
697        g.add_directed("W", "X");
698        g.add_directed("W", "Y");
699        g.add_directed("X", "Y");
700        g
701    }
702
703    // -----------------------------------------------------------------------
704    // c_components tests
705    // -----------------------------------------------------------------------
706
707    #[test]
708    fn test_c_components_chain_no_bidirected_via_id() {
709        let g = chain_graph();
710        let vars: BTreeSet<String> = ["X", "Y", "Z"].iter().map(|s| s.to_string()).collect();
711        let comps = c_components_in_subgraph(&g, &vars);
712        assert_eq!(comps.len(), 3, "Expected 3 singletons, got {}", comps.len());
713    }
714
715    #[test]
716    fn test_c_components_bidirected_chain() {
717        let mut g = SemiMarkovGraph::new();
718        g.add_bidirected("X", "Y");
719        g.add_bidirected("Y", "Z");
720        let vars: BTreeSet<String> = ["X", "Y", "Z"].iter().map(|s| s.to_string()).collect();
721        let comps = c_components_in_subgraph(&g, &vars);
722        assert_eq!(comps.len(), 1);
723        assert_eq!(comps[0].len(), 3);
724    }
725
726    // -----------------------------------------------------------------------
727    // topological_order tests
728    // -----------------------------------------------------------------------
729
730    #[test]
731    fn test_topological_order_chain() {
732        let g = chain_graph();
733        let order = topological_order(&g);
734        let x_pos = order.iter().position(|v| v == "X").expect("X missing");
735        let y_pos = order.iter().position(|v| v == "Y").expect("Y missing");
736        let z_pos = order.iter().position(|v| v == "Z").expect("Z missing");
737        assert!(x_pos < y_pos);
738        assert!(y_pos < z_pos);
739    }
740
741    // -----------------------------------------------------------------------
742    // ancestors_of tests
743    // -----------------------------------------------------------------------
744
745    #[test]
746    fn test_ancestors_of_chain() {
747        let g = chain_graph();
748        let anc = ancestors_of(&g, &[s("Z")]);
749        assert!(anc.contains("X"));
750        assert!(anc.contains("Y"));
751        assert!(anc.contains("Z"));
752    }
753
754    // -----------------------------------------------------------------------
755    // ID: no intervention → always identifiable
756    // -----------------------------------------------------------------------
757
758    #[test]
759    fn test_id_no_intervention_returns_marginal() {
760        let g = chain_graph();
761        let p = ProbExpr::p(vec![s("X"), s("Y"), s("Z")]);
762        let result = IdAlgorithm::identify(&[s("Z")], &[], p, &g);
763        assert!(
764            result.is_identified(),
765            "No intervention should be identifiable"
766        );
767    }
768
769    // -----------------------------------------------------------------------
770    // ID: backdoor admissible (W → X → Y, W → Y, no hidden confounders)
771    // -----------------------------------------------------------------------
772
773    #[test]
774    fn test_id_backdoor_admissible() {
775        let g = backdoor_graph();
776        let p = ProbExpr::p(vec![s("W"), s("X"), s("Y")]);
777        let result = IdAlgorithm::identify(&[s("Y")], &[s("X")], p, &g);
778        assert!(
779            result.is_identified(),
780            "Backdoor admissible graph should be identifiable; hedge: {:?}",
781            result.hedge()
782        );
783    }
784
785    // -----------------------------------------------------------------------
786    // ID: pure confounder X ↔ Y, no instrument → NOT identifiable
787    // -----------------------------------------------------------------------
788
789    #[test]
790    fn test_id_simple_confounder_not_identifiable() {
791        let g = confounded_graph();
792        let p = ProbExpr::p(vec![s("X"), s("Y")]);
793        let result = IdAlgorithm::identify(&[s("Y")], &[s("X")], p, &g);
794        assert!(
795            !result.is_identified(),
796            "Pure confounder X↔Y with no instrument should NOT be identifiable"
797        );
798    }
799
800    // -----------------------------------------------------------------------
801    // ID: front-door criterion (X → M → Y, X ↔ Y) → identifiable
802    // -----------------------------------------------------------------------
803
804    #[test]
805    fn test_id_frontdoor_identifiable() {
806        let g = frontdoor_graph();
807        let p = ProbExpr::p(vec![s("X"), s("M"), s("Y")]);
808        let result = IdAlgorithm::identify(&[s("Y")], &[s("X")], p, &g);
809        assert!(
810            result.is_identified(),
811            "Front-door graph should be identifiable; hedge: {:?}",
812            result.hedge()
813        );
814    }
815
816    // -----------------------------------------------------------------------
    // ID: IV graph (Z → X → Y, X ↔ Y)
    //
    // Without parametric assumptions (e.g. linearity), P(Y | do(X)) is NOT
    // identifiable in the IV graph. {X, Y} is a C-component of G, and the
    // C-forests F = {X, Y} and F' = {Y} form a hedge for this query. The
    // expression Σ_z P(Y | X, Z=z) P(Z=z) is not a valid adjustment here,
    // because conditioning on Z leaves the X ↔ Y confounding path open.
    //
    // The test below checks only a structural fact the recursion relies on.
    // With V \ X = {Z, Y}, the induced subgraph G[V \ X] has no edges, so it
    // splits into the two C-components {Z} and {Y}. Per Algorithm 1, the
    // recursion still ends at the hedge in G[{X, Y}].
829    // -----------------------------------------------------------------------
830
831    #[test]
832    fn test_id_iv_line4_decomposes() {
833        // Verify that C(G[V\X]) has 2 components for the IV graph
834        // (necessary condition for IV identification via Line 4)
835        let g = iv_graph();
836        let v_minus_x: BTreeSet<String> = ["Z".to_string(), "Y".to_string()].into();
837        let comps = c_components_in_subgraph(&g, &v_minus_x);
838        assert_eq!(
839            comps.len(),
840            2,
841            "IV graph: C(G[V\\X]) should have 2 components ({{Z}} and {{Y}}), got {:?}",
842            comps
843        );
844    }
845
846    #[test]
847    fn test_id_iv_rule2_applies() {
848        // do-calculus Rule 2: P(y|do(x),z,w) = P(y|do(x),z,w) when conditions hold.
849        // For IV graph: Z→X→Y, X↔Y
850        // Rule 2 can exchange do(Z) for observing Z given appropriate d-separation.
851        let g = iv_graph();
852        // y={Y}, x={X}, z={Z}, w=∅
853        let y: BTreeSet<String> = ["Y".to_string()].into();
854        let x: BTreeSet<String> = ["X".to_string()].into();
855        let z: BTreeSet<String> = ["Z".to_string()].into();
856        let w: BTreeSet<String> = BTreeSet::new();
857        // This predicate should run without panic
858        let _rule2 = do_calculus_rule2(&g, &y, &x, &z, &w);
859        // Rule 2 applies: P(Y|do(X),do(Z),W) = P(Y|do(X),Z,W) when (Y⊥Z|X,W) in G_{X̄,Z̄}
860        // We just verify the predicate runs correctly
861    }
862
863    // -----------------------------------------------------------------------
864    // HedgeFinder: none for chain (no bidirected → identifiable)
865    // -----------------------------------------------------------------------
866
867    #[test]
868    fn test_hedge_finder_none_for_chain() {
869        let g = chain_graph();
870        let cert = HedgeFinder::find(&g, &[s("Z")], &[s("X")]);
871        assert!(cert.is_none(), "Chain graph should have no hedge");
872    }
873
874    // -----------------------------------------------------------------------
875    // HedgeFinder: certificate for confounded graph
876    // -----------------------------------------------------------------------
877
878    #[test]
879    fn test_hedge_finder_certificate_for_confounded() {
880        let g = confounded_graph();
881        let cert = HedgeFinder::find(&g, &[s("Y")], &[s("X")]);
882        assert!(cert.is_some(), "Confounded graph should have a hedge");
883        let cert = cert.expect("certificate");
884        assert!(!cert.blocking_x.is_empty());
885    }
886
887    // -----------------------------------------------------------------------
888    // ProbExpr display tests
889    // -----------------------------------------------------------------------
890
891    #[test]
892    fn test_prob_expr_do_display() {
893        let e = ProbExpr::p_do(vec![s("Y")], vec![s("X")]);
894        let disp = format!("{e}");
895        assert!(disp.contains("do(X)"), "Should show do(X): {disp}");
896        assert!(disp.contains("Y"), "Should show Y: {disp}");
897    }
898
899    #[test]
900    fn test_prob_expr_marginal_display() {
901        let inner = ProbExpr::p(vec![s("Y"), s("Z")]);
902        let marg = ProbExpr::marginal(inner, vec![s("Z")]);
903        let disp = format!("{marg}");
904        assert!(disp.contains("Σ_{Z}"), "Should contain Σ_{{Z}}: {disp}");
905    }
906
907    // -----------------------------------------------------------------------
908    // Product simplification
909    // -----------------------------------------------------------------------
910
911    #[test]
912    fn test_product_two_conditionals_simplify() {
913        let e1 = ProbExpr::conditional(vec![s("Y")], vec![s("X")]);
914        let e2 = ProbExpr::conditional(vec![s("Z")], vec![s("M")]);
915        let prod = ProbExpr::product(vec![e1, e2]);
916        let simplified = prod.simplify();
917        match simplified {
918            ProbExpr::Product(ref terms) => assert_eq!(terms.len(), 2),
919            other => panic!("Expected Product, got {other:?}"),
920        }
921    }
922
923    // -----------------------------------------------------------------------
924    // Tian-Pearl factors
925    // -----------------------------------------------------------------------
926
927    #[test]
928    fn test_tian_pearl_factors_chain() {
929        let g = chain_graph();
930        let topo = topological_order(&g);
931        let scope: BTreeSet<String> = ["X", "Y", "Z"].iter().map(|s| s.to_string()).collect();
932        let v: BTreeSet<String> = scope.clone();
933        let factors = build_tian_pearl_factors(&scope, &topo, &v);
934        assert_eq!(factors.len(), 3, "One factor per variable in chain");
935    }
936
937    // -----------------------------------------------------------------------
938    // Do-calculus rule predicates
939    // -----------------------------------------------------------------------
940
941    #[test]
942    fn test_do_calculus_rule1_applies() {
943        let mut g = SemiMarkovGraph::new();
944        g.add_directed("Z", "X");
945        g.add_directed("X", "Y");
946        let y: BTreeSet<String> = ["Y".to_string()].into();
947        let x: BTreeSet<String> = ["X".to_string()].into();
948        let z: BTreeSet<String> = ["Z".to_string()].into();
949        let w: BTreeSet<String> = BTreeSet::new();
950        let _applies = do_calculus_rule1(&g, &y, &x, &z, &w);
951    }
952
953    #[test]
954    fn test_do_calculus_rule2_applies() {
955        let mut g = SemiMarkovGraph::new();
956        g.add_directed("Z", "X");
957        g.add_directed("X", "Y");
958        let y: BTreeSet<String> = ["Y".to_string()].into();
959        let x: BTreeSet<String> = ["X".to_string()].into();
960        let z: BTreeSet<String> = ["Z".to_string()].into();
961        let w: BTreeSet<String> = BTreeSet::new();
962        let _applies = do_calculus_rule2(&g, &y, &x, &z, &w);
963    }
964
965    #[test]
966    fn test_do_calculus_rule3_applies() {
967        let mut g = SemiMarkovGraph::new();
968        g.add_directed("Z", "X");
969        g.add_directed("X", "Y");
970        let y: BTreeSet<String> = ["Y".to_string()].into();
971        let x: BTreeSet<String> = ["X".to_string()].into();
972        let z: BTreeSet<String> = ["Z".to_string()].into();
973        let w: BTreeSet<String> = BTreeSet::new();
974        let _applies = do_calculus_rule3(&g, &y, &x, &z, &w);
975    }
976}