scirs2_stats/causal/id_algorithm.rs
1//! Shpitser-Pearl ID Algorithm for Causal Effect Identification
2//!
3//! Implements Algorithm 1 from Shpitser & Pearl (AAAI 2006):
4//!
5//! > **ID Algorithm**: Given a semi-Markovian causal model with DAG G,
6//! > observed variables V, and query P(y | do(x)), the algorithm either
7//! > returns a closed-form expression for P(y | do(x)) in terms of the
8//! > observational distribution P(V), or certifies non-identifiability
9//! > by returning a hedge certificate.
10//!
11//! # Algorithm Overview (Algorithm 1 of Shpitser-Pearl 2006)
12//!
13//! ```text
14//! ID(y, x, P, G):
15//! V = all nodes in G
16//! Line 1: if x = ∅, return Σ_{v \ y} P(V)
17//! Line 2: let W = An(Y)_G \ X; if W ≠ V \ X:
18//! return ID(y, x ∩ An(Y)_G, P(An(Y)_G), G[An(Y)_G])
19//! Line 3: let W = (V \ X) \ An(Y)_{G[V\X]}; if W ≠ ∅:
//!     return ID(y, x ∪ W, P, G)
21//! Line 4: if C(G[V\X]) = {S₁,...,Sk}: k > 1:
22//! return Σ_{v \ (y ∪ x)} ∏ ID(Sᵢ, V \ Sᵢ, P, G)
23//! Line 5: if C(G[V\X]) = {V\X}:
24//! if C(G) = {G}: FAIL(G, C(G)) [hedge found]
25//! if ∃ S ∈ C(G) : S ⊊ V\X:
26//! Line 6: return Σ_{v \ (y ∪ x) ∩ S} ∏_{Vᵢ ∈ S} P(Vᵢ | V_{π<i} ∩ S, V_{π<i} \ S)
27//! if S ∈ C(G) : S ⊃ V\X — impossible by construction
28//! Line 7: if ∃ S ∈ C(G[V\X]) s.t. ∃ S' ∈ C(G): S ⊊ S':
29//! return Σ_{s \ y} ID(y, x ∩ S', ∏_{Vᵢ ∈ S'} P(Vᵢ | V_{π<i} ∩ S'), G[S'])
30//! ```
31//!
32//! # Do-Calculus Rules
33//!
34//! - **Rule 1**: P(y | do(x), z, w) = P(y | do(x), w) when (Y ⊥ Z | X, W) in G_{X̄}
//! - **Rule 2**: P(y | do(x), do(z), w) = P(y | do(x), z, w) when (Y ⊥ Z | X, W) in G_{X̄, Z̲} (edges into X and edges out of Z removed)
//! - **Rule 3**: P(y | do(x), do(z), w) = P(y | do(x), w) when (Y ⊥ Z | X, W) in G_{X̄, \overline{Z(W)}},
//!   where Z(W) is the set of Z-nodes that are not ancestors of any W-node in G_{X̄}
37//!
38//! # References
39//!
40//! - Shpitser, I. & Pearl, J. (2006). Identification of Joint Interventional
41//! Distributions in Recursive Semi-Markovian Causal Models. *AAAI 2006*.
42//! - Tian, J. & Pearl, J. (2002). A General Identification Condition for
43//! Causal Effects. *AAAI 2002*.
44
45use std::collections::BTreeSet;
46
47use crate::causal::hedge::{
48 ancestors_of, c_components_in_subgraph, topological_order, HedgeCertificate,
49};
50use crate::causal::semi_markov_graph::SemiMarkovGraph;
51use crate::causal::symbolic_prob::ProbExpr;
52
53// ---------------------------------------------------------------------------
54// IdResult
55// ---------------------------------------------------------------------------
56
57/// Result of the ID algorithm.
58#[derive(Debug, Clone)]
59#[non_exhaustive]
60pub enum IdResult {
61 /// The query P(y | do(x)) is identifiable.
62 Identified(ProbExpr),
63 /// The query is NOT identifiable.
64 NotIdentifiable(HedgeCertificate),
65}
66
67impl IdResult {
68 /// Returns `true` if the effect is identifiable.
69 pub fn is_identified(&self) -> bool {
70 matches!(self, IdResult::Identified(_))
71 }
72
73 /// Return the expression if identified, or `None`.
74 pub fn expression(&self) -> Option<&ProbExpr> {
75 match self {
76 IdResult::Identified(e) => Some(e),
77 IdResult::NotIdentifiable(_) => None,
78 }
79 }
80
81 /// Return the hedge certificate if not identifiable, or `None`.
82 pub fn hedge(&self) -> Option<&HedgeCertificate> {
83 match self {
84 IdResult::Identified(_) => None,
85 IdResult::NotIdentifiable(h) => Some(h),
86 }
87 }
88}
89
90// ---------------------------------------------------------------------------
91// Do-calculus rule predicates
92// ---------------------------------------------------------------------------
93
94/// Predicate for do-calculus Rule 1 (insertion/deletion of observations).
95///
96/// Returns `true` iff (Y ⊥ Z | X, W) in G_{X̄}.
97pub fn do_calculus_rule1(
98 graph: &SemiMarkovGraph,
99 y: &BTreeSet<String>,
100 x: &BTreeSet<String>,
101 z: &BTreeSet<String>,
102 w: &BTreeSet<String>,
103) -> bool {
104 let g_xbar = graph.mutilate(x);
105 let conditioning: BTreeSet<String> = x.union(w).cloned().collect();
106 d_separated_set(&g_xbar, y, z, &conditioning)
107}
108
109/// Predicate for do-calculus Rule 2 (action/observation exchange).
110///
111/// Returns `true` iff (Y ⊥ Z | X, W) in G_{X̄, Z̄}.
112pub fn do_calculus_rule2(
113 graph: &SemiMarkovGraph,
114 y: &BTreeSet<String>,
115 x: &BTreeSet<String>,
116 z: &BTreeSet<String>,
117 w: &BTreeSet<String>,
118) -> bool {
119 let xz: BTreeSet<String> = x.union(z).cloned().collect();
120 let g_xbar_zbar = graph.mutilate(&xz);
121 let conditioning: BTreeSet<String> = x.union(w).cloned().collect();
122 d_separated_set(&g_xbar_zbar, y, z, &conditioning)
123}
124
125/// Predicate for do-calculus Rule 3 (insertion/deletion of actions).
126///
127/// Returns `true` iff (Y ⊥ Z | X, W) in G_{X̄, Z(W̄)}.
128pub fn do_calculus_rule3(
129 graph: &SemiMarkovGraph,
130 y: &BTreeSet<String>,
131 x: &BTreeSet<String>,
132 z: &BTreeSet<String>,
133 w: &BTreeSet<String>,
134) -> bool {
135 let mut g_modified = graph.mutilate(x);
136 let anc_w = g_modified.ancestors(w);
137 for z_node in z {
138 let parents: Vec<String> = g_modified.parents(z_node).collect();
139 for parent in parents {
140 if !anc_w.contains(&parent) {
141 g_modified.remove_directed(&parent, z_node);
142 }
143 }
144 }
145 let conditioning: BTreeSet<String> = x.union(w).cloned().collect();
146 d_separated_set(&g_modified, y, z, &conditioning)
147}
148
149// ---------------------------------------------------------------------------
150// IdAlgorithm
151// ---------------------------------------------------------------------------
152
153/// The Shpitser-Pearl ID algorithm for causal effect identification.
154pub struct IdAlgorithm;
155
156impl IdAlgorithm {
157 /// Run the ID algorithm to identify P(y | do(x)).
158 ///
159 /// # Parameters
160 ///
161 /// - `y` – outcome variable names
162 /// - `x` – intervention variable names (do(x))
163 /// - `obs_dist` – the observational joint distribution P(V)
164 /// - `dag` – the semi-Markovian causal graph
165 pub fn identify(
166 y: &[String],
167 x: &[String],
168 obs_dist: ProbExpr,
169 dag: &SemiMarkovGraph,
170 ) -> IdResult {
171 let v: BTreeSet<String> = dag.node_set();
172 let y_set: BTreeSet<String> = y.iter().cloned().collect();
173 let x_set: BTreeSet<String> = x.iter().cloned().collect();
174 id_recursive(&y_set, &x_set, &obs_dist, dag, &v, 0)
175 }
176}
177
178// ---------------------------------------------------------------------------
179// Core recursive ID procedure — Algorithm 1 of Shpitser-Pearl (AAAI 2006)
180// ---------------------------------------------------------------------------
181
182/// Recursive implementation.
183///
184/// Parameters follow Algorithm 1:
185/// - `y` — target variable set (what we want to observe)
186/// - `x` — intervention set (do(x))
187/// - `p` — current available distribution (symbolic)
188/// - `g` — current subgraph
189/// - `v` — current variable scope
190/// - `depth` — recursion depth guard
191fn id_recursive(
192 y: &BTreeSet<String>,
193 x: &BTreeSet<String>,
194 p: &ProbExpr,
195 g: &SemiMarkovGraph,
196 v: &BTreeSet<String>,
197 depth: usize,
198) -> IdResult {
199 const MAX_DEPTH: usize = 64;
200 if depth > MAX_DEPTH {
201 return IdResult::NotIdentifiable(HedgeCertificate {
202 s_component: v.clone(),
203 blocking_x: x.clone(),
204 outcome_y: y.clone(),
205 explanation: "Recursion depth exceeded — potential cycle in ID algorithm.".to_string(),
206 });
207 }
208
209 // -------------------------------------------------------------------
210 // Line 1: if x = ∅, return Σ_{v \ y} P(v)
211 // -------------------------------------------------------------------
212 if x.is_empty() {
213 return marginal_over(p, v, y);
214 }
215
216 // -------------------------------------------------------------------
217 // Line 2: W = An(Y)_G (ancestors of Y in G, including Y itself)
218 // if W ≠ V (not all variables are ancestors of Y):
219 // return ID(y, x ∩ W, P(W), G[W])
220 //
221 // This restricts the graph to the "relevant" part: variables that are
222 // actually on causal/confounding paths to Y.
223 // -------------------------------------------------------------------
224 let an_y: BTreeSet<String> = ancestors_of(g, &y.iter().cloned().collect::<Vec<_>>());
225
226 // V \ X (for use in Lines 3-7)
227 let v_minus_x: BTreeSet<String> = v.difference(x).cloned().collect();
228
229 if an_y != *v {
230 // Some variables in V are NOT ancestors of Y → restrict to An(Y)_G
231 let w = an_y; // = An(Y)_G (subset of V)
232 let g_w = g.subgraph(&w);
233 let new_x: BTreeSet<String> = x.intersection(&w).cloned().collect();
234 let p_w = marginal_to_scope(p, v, &w);
235 return id_recursive(y, &new_x, &p_w, &g_w, &w, depth + 1);
236 }
237
238 // -------------------------------------------------------------------
239 // Lines 4-7: C(G[V\X]) analysis (checked before Line 3 to correctly
240 // handle instrument variable (IV) identification patterns).
241 //
242 // When C(G[V\X]) has multiple components, we decompose immediately.
243 // This is crucial for graphs like IV (Z → X → Y, X ↔ Y) where
244 // C(G[{Z,Y}]) = {{Z},{Y}} correctly identifies the effect before
245 // non-ancestral variable removal (Line 3) can interfere.
246 // -------------------------------------------------------------------
247 let components_vmx = c_components_in_subgraph(g, &v_minus_x);
248
249 // -------------------------------------------------------------------
250 // Line 4: C(G[V\X]) = {S₁, ..., Sₖ} with k > 1
251 // return Σ_{v \ (y ∪ x)} ∏ ID(Sᵢ, V \ Sᵢ, P, G)
252 // -------------------------------------------------------------------
253 if components_vmx.len() > 1 {
254 let mut factor_results: Vec<ProbExpr> = Vec::new();
255
256 for si in &components_vmx {
257 let v_minus_si: BTreeSet<String> = v.difference(si).cloned().collect();
258 let sub = id_recursive(si, &v_minus_si, p, g, v, depth + 1);
259 match sub {
260 IdResult::Identified(expr) => factor_results.push(expr),
261 not_id => return not_id,
262 }
263 }
264
265 let product = make_product(factor_results);
266
267 // Marginalize over V \ (Y ∪ X): we want P(Y | do(X)) so sum out
268 // everything in (V \ X) \ Y
269 let sum_out: Vec<String> = {
270 let mut sv: Vec<String> = v_minus_x.difference(y).cloned().collect();
271 sv.sort();
272 sv
273 };
274
275 let result = if sum_out.is_empty() {
276 product
277 } else {
278 ProbExpr::Marginal {
279 expr: Box::new(product),
280 summand_vars: sum_out,
281 }
282 .simplify()
283 };
284
285 return IdResult::Identified(result);
286 }
287
288 // From this point: C(G[V\X]) has exactly 1 component.
289 // Before checking Lines 5-7, apply Line 3 to reduce scope.
290
291 // -------------------------------------------------------------------
292 // Line 3: W = (V \ X) \ An(Y)_{G[V\X]}
293 // if W ≠ ∅: ID(y, x ∪ W, P, G)
294 //
295 // Variables in V\X that are not ancestral to Y in G[V\X] can be
296 // safely "intervened on" without changing the identification result.
297 // Adding them to x strictly increases the intervention set, ensuring termination.
298 // -------------------------------------------------------------------
299 {
300 let g_v_minus_x = g.subgraph(&v_minus_x);
301 let an_y_in_g_vmx: BTreeSet<String> =
302 ancestors_of(&g_v_minus_x, &y.iter().cloned().collect::<Vec<_>>());
303 let an_y_vmx_restricted: BTreeSet<String> =
304 an_y_in_g_vmx.intersection(&v_minus_x).cloned().collect();
305 let w_line3: BTreeSet<String> = v_minus_x
306 .difference(&an_y_vmx_restricted)
307 .cloned()
308 .collect();
309
310 if !w_line3.is_empty() {
311 let new_x: BTreeSet<String> = x.union(&w_line3).cloned().collect();
312 return id_recursive(y, &new_x, p, g, v, depth + 1);
313 }
314 }
315
316 // -------------------------------------------------------------------
317 // Line 5: C(G[V\X]) = {V\X}
318 // if C(G) = {G} (G itself is a single c-component): FAIL (hedge)
319 // else: proceed to Lines 6-7
320 // -------------------------------------------------------------------
321 let components_full = c_components_in_subgraph(g, v);
322
323 if components_full.len() == 1 && components_full[0] == *v {
324 // The whole graph is one c-component AND V\X is also one c-component
325 // → hedge: there is no way to identify P(y | do(x))
326 return IdResult::NotIdentifiable(HedgeCertificate {
327 s_component: v.clone(),
328 blocking_x: x.clone(),
329 outcome_y: y.clone(),
330 explanation: format!(
331 "Hedge: the entire variable set {:?} forms a single c-component in G, \
332 and G[V\\X] = {:?} is also a single c-component. \
333 P({:?} | do({:?})) is not identifiable.",
334 v, v_minus_x, y, x
335 ),
336 });
337 }
338
339 // Lines 6-7: there are multiple c-components in G, or G has a proper
340 // c-component structure.
341 //
342 // V \ X is a single c-component (from Line 4 filter above).
343 // Find the c-component(s) in G that contain parts of V \ X.
344
345 // For the single component S in C(G[V\X]) (which equals V\X):
346 let s_vmx = &v_minus_x; // The single c-component of G[V\X]
347
348 // -------------------------------------------------------------------
349 // Line 6: if S ∈ C(G) (i.e., S is also a c-component in the full graph G)
350 // apply Tian-Pearl factorization within S
351 // -------------------------------------------------------------------
352 // Check if S_vmx is itself a c-component in the full graph
353 let s_is_full_comp = components_full.iter().any(|fc| fc == s_vmx);
354
355 if s_is_full_comp {
356 // Tian-Pearl sum-product formula:
357 // Σ_{S \ Y} ∏_{Vᵢ ∈ S} P(Vᵢ | V_{π<i} ∩ S, V_{π<i} \ S)
358 // where the ordering is the topological order of the full graph G.
359 let topo_full = topological_order(g);
360 let factors = build_tian_pearl_factors(s_vmx, &topo_full, v);
361 let product = make_product(factors);
362
363 let sum_out: Vec<String> = {
364 let mut sv: Vec<String> = s_vmx.difference(y).cloned().collect();
365 sv.sort();
366 sv
367 };
368
369 let result = if sum_out.is_empty() {
370 product
371 } else {
372 ProbExpr::Marginal {
373 expr: Box::new(product),
374 summand_vars: sum_out,
375 }
376 .simplify()
377 };
378
379 return IdResult::Identified(result);
380 }
381
382 // -------------------------------------------------------------------
383 // Line 7: ∃ S' ∈ C(G) such that S_vmx ⊊ S'
384 // recurse: ID(y, x ∩ S', ∏_{Vᵢ ∈ S'} P(Vᵢ | V_{π<i} ∩ S'), G[S'])
385 // -------------------------------------------------------------------
386 let s_prime_opt = components_full
387 .iter()
388 .find(|fc| s_vmx.is_subset(fc) && *fc != s_vmx);
389
390 if let Some(s_prime) = s_prime_opt {
391 let topo_full = topological_order(g);
392
393 // Build P(S') as Tian-Pearl product
394 let topo_sp: Vec<String> = topo_full
395 .iter()
396 .filter(|v| s_prime.contains(*v))
397 .cloned()
398 .collect();
399
400 let factors = build_tian_pearl_factors(s_prime, &topo_full, v);
401 let p_s_prime = make_product(factors);
402
403 let g_s_prime = g.subgraph(s_prime);
404 let new_x: BTreeSet<String> = x.intersection(s_prime).cloned().collect();
405
406 return id_recursive(y, &new_x, &p_s_prime, &g_s_prime, s_prime, depth + 1);
407 }
408
409 // If we reach here: C(G[V\X]) has 1 component = V\X,
410 // C(G) has multiple components but none properly contains V\X.
411 // Per the algorithm this is actually a hedge condition (C(G) intersects X).
412 // Find which c-component of G contains elements of X.
413 for fc in &components_full {
414 let x_in_fc: BTreeSet<String> = x.intersection(fc).cloned().collect();
415 if !x_in_fc.is_empty() {
416 // V\X is a subset of this component (it must be, since V\X is one component
417 // and every non-X node should be reachable)
418 return IdResult::NotIdentifiable(HedgeCertificate {
419 s_component: fc.clone(),
420 blocking_x: x_in_fc,
421 outcome_y: y.clone(),
422 explanation: format!(
423 "Hedge: c-component {:?} of G contains intervention variables {:?} \
424 and outcome variables {:?}. P(y|do(x)) is not identifiable.",
425 fc, x, y
426 ),
427 });
428 }
429 }
430
431 // Fallback (should not be reached in a well-formed call):
432 // Return marginal of P(V) over V \ Y
433 marginal_over(p, v, y)
434}
435
436// ---------------------------------------------------------------------------
437// Tian-Pearl factorization
438// ---------------------------------------------------------------------------
439
440/// Build Tian-Pearl factors: ∏_{Vᵢ ∈ scope} P(Vᵢ | V_{π<i})
441///
442/// where V_{π<i} = all variables before Vᵢ in the full topological order
443/// (intersected with the full variable scope `v_full`).
444fn build_tian_pearl_factors(
445 scope: &BTreeSet<String>,
446 topo_full: &[String],
447 _v_full: &BTreeSet<String>,
448) -> Vec<ProbExpr> {
449 // Build position map
450 let pos: std::collections::HashMap<&str, usize> = topo_full
451 .iter()
452 .enumerate()
453 .map(|(i, v)| (v.as_str(), i))
454 .collect();
455
456 let mut factors: Vec<ProbExpr> = Vec::new();
457
458 // Sort scope by topological position
459 let mut scope_sorted: Vec<&String> = scope.iter().collect();
460 scope_sorted.sort_by_key(|v| pos.get(v.as_str()).copied().unwrap_or(usize::MAX));
461
462 for vi in &scope_sorted {
463 let vi_pos = pos.get(vi.as_str()).copied().unwrap_or(0);
464
465 // All variables in the FULL topological order before vi
466 let preceding: Vec<String> = topo_full.iter().take(vi_pos).cloned().collect();
467
468 let factor = if preceding.is_empty() {
469 // P(Vi) — marginal (Vi has no predecessors in topological order)
470 ProbExpr::Joint(vec![(*vi).clone()])
471 } else {
472 // P(Vi | preceding)
473 // Represented as P(Vi, preceding...) / P(preceding...)
474 // which simplifies to the conditional form
475 ProbExpr::Conditional {
476 numerator: Box::new(ProbExpr::Joint({
477 let mut vars = vec![(*vi).clone()];
478 vars.extend(preceding.iter().cloned());
479 vars.sort();
480 vars
481 })),
482 denominator: Box::new(ProbExpr::Joint(preceding)),
483 }
484 };
485 factors.push(factor);
486 }
487
488 factors
489}
490
491// ---------------------------------------------------------------------------
492// Expression construction helpers
493// ---------------------------------------------------------------------------
494
495/// Build a product expression, collapsing singletons.
496fn make_product(factors: Vec<ProbExpr>) -> ProbExpr {
497 if factors.is_empty() {
498 ProbExpr::Joint(Vec::new()) // probability 1
499 } else if factors.len() == 1 {
500 factors.into_iter().next().expect("length checked")
501 } else {
502 ProbExpr::Product(factors).simplify()
503 }
504}
505
506/// Return Σ_{v \ y} P(v) — marginalize P(v) to only cover variables y.
507fn marginal_over(p: &ProbExpr, v: &BTreeSet<String>, y: &BTreeSet<String>) -> IdResult {
508 let sum_out: Vec<String> = {
509 let mut sv: Vec<String> = v.difference(y).cloned().collect();
510 sv.sort();
511 sv
512 };
513 if sum_out.is_empty() {
514 IdResult::Identified(p.clone())
515 } else {
516 let result = ProbExpr::Marginal {
517 expr: Box::new(p.clone()),
518 summand_vars: sum_out,
519 }
520 .simplify();
521 IdResult::Identified(result)
522 }
523}
524
525/// Marginalize P(v) down to scope `w` by summing out v \ w.
526fn marginal_to_scope(p: &ProbExpr, v: &BTreeSet<String>, w: &BTreeSet<String>) -> ProbExpr {
527 let sum_out: Vec<String> = {
528 let mut sv: Vec<String> = v.difference(w).cloned().collect();
529 sv.sort();
530 sv
531 };
532 if sum_out.is_empty() {
533 p.clone()
534 } else {
535 ProbExpr::Marginal {
536 expr: Box::new(p.clone()),
537 summand_vars: sum_out,
538 }
539 .simplify()
540 }
541}
542
543// ---------------------------------------------------------------------------
544// D-separation helpers (for do-calculus rule predicates)
545// ---------------------------------------------------------------------------
546
547/// Check d-separation between all pairs (yi, zi) given conditioning set.
548fn d_separated_set(
549 g: &SemiMarkovGraph,
550 y: &BTreeSet<String>,
551 z: &BTreeSet<String>,
552 conditioning: &BTreeSet<String>,
553) -> bool {
554 for yi in y {
555 for zi in z {
556 if !d_separated_pair(g, yi, zi, conditioning) {
557 return false;
558 }
559 }
560 }
561 true
562}
563
564/// Bayes-Ball d-separation for semi-Markovian graphs.
565///
566/// Bidirected edges A ↔ B are treated as paths via a latent H: A ← H → B.
567fn d_separated_pair(
568 g: &SemiMarkovGraph,
569 src: &str,
570 dst: &str,
571 conditioning: &BTreeSet<String>,
572) -> bool {
573 use std::collections::{HashSet, VecDeque};
574
575 if src == dst {
576 return conditioning.contains(src);
577 }
578
579 let ancestors_of_conditioning: BTreeSet<String> = g.ancestors(conditioning);
580
581 // Bayes-Ball state: (node, via_child: bool)
582 // via_child = true → ball arrived "upward" from a child
583 // via_child = false → ball arrived "downward" from a parent
584 let mut visited: HashSet<(String, bool)> = HashSet::new();
585 let mut queue: VecDeque<(String, bool)> = VecDeque::new();
586
587 queue.push_back((src.to_owned(), true));
588 queue.push_back((src.to_owned(), false));
589
590 while let Some((node, via_child)) = queue.pop_front() {
591 if !visited.insert((node.clone(), via_child)) {
592 continue;
593 }
594 if node == dst {
595 return false; // Active path found
596 }
597
598 let is_obs = conditioning.contains(&node);
599 let is_anc_obs = ancestors_of_conditioning.contains(&node);
600
601 if via_child {
602 if !is_obs {
603 // Chain/fork: propagate to parents (upward) and children (downward)
604 for parent in g.parents(&node) {
605 queue.push_back((parent, true));
606 }
607 for child in g.children(&node) {
608 queue.push_back((child, false));
609 }
610 // Bidirected edge: treat as common-cause path
611 for nb in g.bidirected_neighbors(&node) {
612 queue.push_back((nb, false));
613 }
614 }
615 // Collider activation: if this node (collider) is observed or
616 // is an ancestor of an observed node, activate by propagating upward
617 if is_obs || is_anc_obs {
618 for parent in g.parents(&node) {
619 queue.push_back((parent, true));
620 }
621 }
622 } else {
623 // via parent
624 if !is_obs {
625 // Chain: propagate downward to children
626 for child in g.children(&node) {
627 queue.push_back((child, false));
628 }
629 // Bidirected: propagate to bidirected neighbor (common cause link)
630 for nb in g.bidirected_neighbors(&node) {
631 queue.push_back((nb, false));
632 }
633 } else {
634 // Fork block: but v-structure activation upward
635 for parent in g.parents(&node) {
636 queue.push_back((parent, true));
637 }
638 }
639 }
640 }
641
642 true // No active path → d-separated
643}
644
645// ---------------------------------------------------------------------------
646// Unit tests
647// ---------------------------------------------------------------------------
648
#[cfg(test)]
mod tests {
    use super::*;
    use crate::causal::hedge::{c_components_in_subgraph, HedgeFinder};
    use crate::causal::semi_markov_graph::SemiMarkovGraph;
    use crate::causal::symbolic_prob::ProbExpr;

    /// Owned copy of a node name.
    fn s(s: &str) -> String {
        String::from(s)
    }

    /// Collects node names into an ordered set.
    fn name_set(names: &[&str]) -> BTreeSet<String> {
        names.iter().map(|n| n.to_string()).collect()
    }

    /// Builds a graph by adding the directed edges first, then the bidirected ones.
    fn graph_from(directed: &[(&str, &str)], bidirected: &[(&str, &str)]) -> SemiMarkovGraph {
        let mut g = SemiMarkovGraph::new();
        for &(from, to) in directed {
            g.add_directed(from, to);
        }
        for &(a, b) in bidirected {
            g.add_bidirected(a, b);
        }
        g
    }

    /// Chain X → Y → Z.
    fn chain_graph() -> SemiMarkovGraph {
        graph_from(&[("X", "Y"), ("Y", "Z")], &[])
    }

    /// X → Y together with X ↔ Y (pure confounder).
    fn confounded_graph() -> SemiMarkovGraph {
        graph_from(&[("X", "Y")], &[("X", "Y")])
    }

    /// Front-door: X → M → Y, X ↔ Y.
    fn frontdoor_graph() -> SemiMarkovGraph {
        graph_from(&[("X", "M"), ("M", "Y")], &[("X", "Y")])
    }

    /// IV: Z → X → Y, X ↔ Y.
    fn iv_graph() -> SemiMarkovGraph {
        graph_from(&[("Z", "X"), ("X", "Y")], &[("X", "Y")])
    }

    /// Back-door admissible: W → X → Y, W → Y, no bidirected edges.
    fn backdoor_graph() -> SemiMarkovGraph {
        graph_from(&[("W", "X"), ("W", "Y"), ("X", "Y")], &[])
    }

    /// Z → X → Y without bidirected edges, used by the rule-predicate smoke tests.
    fn zxy_graph() -> SemiMarkovGraph {
        graph_from(&[("Z", "X"), ("X", "Y")], &[])
    }

    // -----------------------------------------------------------------------
    // c-components
    // -----------------------------------------------------------------------

    #[test]
    fn test_c_components_chain_no_bidirected_via_id() {
        let comps = c_components_in_subgraph(&chain_graph(), &name_set(&["X", "Y", "Z"]));
        assert_eq!(comps.len(), 3, "Expected 3 singletons, got {}", comps.len());
    }

    #[test]
    fn test_c_components_bidirected_chain() {
        let g = graph_from(&[], &[("X", "Y"), ("Y", "Z")]);
        let comps = c_components_in_subgraph(&g, &name_set(&["X", "Y", "Z"]));
        assert_eq!(comps.len(), 1);
        assert_eq!(comps[0].len(), 3);
    }

    // -----------------------------------------------------------------------
    // topological order
    // -----------------------------------------------------------------------

    #[test]
    fn test_topological_order_chain() {
        let order = topological_order(&chain_graph());
        let x_pos = order.iter().position(|v| v == "X").expect("X missing");
        let y_pos = order.iter().position(|v| v == "Y").expect("Y missing");
        let z_pos = order.iter().position(|v| v == "Z").expect("Z missing");
        assert!(x_pos < y_pos);
        assert!(y_pos < z_pos);
    }

    // -----------------------------------------------------------------------
    // ancestors
    // -----------------------------------------------------------------------

    #[test]
    fn test_ancestors_of_chain() {
        let anc = ancestors_of(&chain_graph(), &[s("Z")]);
        for name in ["X", "Y", "Z"].iter() {
            assert!(anc.contains(*name));
        }
    }

    // -----------------------------------------------------------------------
    // ID with an empty intervention set
    // -----------------------------------------------------------------------

    #[test]
    fn test_id_no_intervention_returns_marginal() {
        let joint = ProbExpr::p(vec![s("X"), s("Y"), s("Z")]);
        let result = IdAlgorithm::identify(&[s("Z")], &[], joint, &chain_graph());
        assert!(
            result.is_identified(),
            "No intervention should be identifiable"
        );
    }

    // -----------------------------------------------------------------------
    // ID on the back-door graph (W → X → Y, W → Y)
    // -----------------------------------------------------------------------

    #[test]
    fn test_id_backdoor_admissible() {
        let joint = ProbExpr::p(vec![s("W"), s("X"), s("Y")]);
        let result = IdAlgorithm::identify(&[s("Y")], &[s("X")], joint, &backdoor_graph());
        assert!(
            result.is_identified(),
            "Backdoor admissible graph should be identifiable; hedge: {:?}",
            result.hedge()
        );
    }

    // -----------------------------------------------------------------------
    // ID on the bow-arc graph (X → Y, X ↔ Y): must not be identified
    // -----------------------------------------------------------------------

    #[test]
    fn test_id_simple_confounder_not_identifiable() {
        let joint = ProbExpr::p(vec![s("X"), s("Y")]);
        let result = IdAlgorithm::identify(&[s("Y")], &[s("X")], joint, &confounded_graph());
        assert!(
            !result.is_identified(),
            "Pure confounder X↔Y with no instrument should NOT be identifiable"
        );
    }

    // -----------------------------------------------------------------------
    // ID on the front-door graph (X → M → Y, X ↔ Y): must be identified
    // -----------------------------------------------------------------------

    #[test]
    fn test_id_frontdoor_identifiable() {
        let joint = ProbExpr::p(vec![s("X"), s("M"), s("Y")]);
        let result = IdAlgorithm::identify(&[s("Y")], &[s("X")], joint, &frontdoor_graph());
        assert!(
            result.is_identified(),
            "Front-door graph should be identifiable; hedge: {:?}",
            result.hedge()
        );
    }

    // -----------------------------------------------------------------------
    // IV graph (Z → X → Y, X ↔ Y)
    //
    // These two tests do not call `identify`. The first checks a graph-level
    // fact: with X removed, Z and Y fall into separate c-components. The
    // second only checks that the Rule 2 predicate can be evaluated on this
    // graph; its boolean result is not asserted.
    // -----------------------------------------------------------------------

    #[test]
    fn test_id_iv_line4_decomposes() {
        let comps = c_components_in_subgraph(&iv_graph(), &name_set(&["Z", "Y"]));
        assert_eq!(
            comps.len(),
            2,
            "IV graph: C(G[V\\X]) should have 2 components ({{Z}} and {{Y}}), got {:?}",
            comps
        );
    }

    #[test]
    fn test_id_iv_rule2_applies() {
        // Rule 2: P(y | do(x), do(z), w) = P(y | do(x), z, w) when its
        // d-separation condition holds. Here y = {Y}, x = {X}, z = {Z}, w = ∅.
        let g = iv_graph();
        let _rule2 = do_calculus_rule2(
            &g,
            &name_set(&["Y"]),
            &name_set(&["X"]),
            &name_set(&["Z"]),
            &BTreeSet::new(),
        );
    }

    // -----------------------------------------------------------------------
    // HedgeFinder
    // -----------------------------------------------------------------------

    #[test]
    fn test_hedge_finder_none_for_chain() {
        let cert = HedgeFinder::find(&chain_graph(), &[s("Z")], &[s("X")]);
        assert!(cert.is_none(), "Chain graph should have no hedge");
    }

    #[test]
    fn test_hedge_finder_certificate_for_confounded() {
        let found = HedgeFinder::find(&confounded_graph(), &[s("Y")], &[s("X")]);
        assert!(found.is_some(), "Confounded graph should have a hedge");
        let cert = found.expect("certificate");
        assert!(!cert.blocking_x.is_empty());
    }

    // -----------------------------------------------------------------------
    // ProbExpr display
    // -----------------------------------------------------------------------

    #[test]
    fn test_prob_expr_do_display() {
        let disp = format!("{}", ProbExpr::p_do(vec![s("Y")], vec![s("X")]));
        assert!(disp.contains("do(X)"), "Should show do(X): {disp}");
        assert!(disp.contains("Y"), "Should show Y: {disp}");
    }

    #[test]
    fn test_prob_expr_marginal_display() {
        let marg = ProbExpr::marginal(ProbExpr::p(vec![s("Y"), s("Z")]), vec![s("Z")]);
        let disp = format!("{}", marg);
        assert!(disp.contains("Σ_{Z}"), "Should contain Σ_{{Z}}: {disp}");
    }

    // -----------------------------------------------------------------------
    // Product simplification
    // -----------------------------------------------------------------------

    #[test]
    fn test_product_two_conditionals_simplify() {
        let terms = vec![
            ProbExpr::conditional(vec![s("Y")], vec![s("X")]),
            ProbExpr::conditional(vec![s("Z")], vec![s("M")]),
        ];
        let simplified = ProbExpr::product(terms).simplify();
        match simplified {
            ProbExpr::Product(ref terms) => assert_eq!(terms.len(), 2),
            other => panic!("Expected Product, got {other:?}"),
        }
    }

    // -----------------------------------------------------------------------
    // Tian-Pearl factors
    // -----------------------------------------------------------------------

    #[test]
    fn test_tian_pearl_factors_chain() {
        let topo = topological_order(&chain_graph());
        let scope = name_set(&["X", "Y", "Z"]);
        let v = scope.clone();
        let factors = build_tian_pearl_factors(&scope, &topo, &v);
        assert_eq!(factors.len(), 3, "One factor per variable in chain");
    }

    // -----------------------------------------------------------------------
    // Do-calculus rule predicates (smoke tests: results are not asserted)
    // -----------------------------------------------------------------------

    #[test]
    fn test_do_calculus_rule1_applies() {
        let g = zxy_graph();
        let _applies = do_calculus_rule1(
            &g,
            &name_set(&["Y"]),
            &name_set(&["X"]),
            &name_set(&["Z"]),
            &BTreeSet::new(),
        );
    }

    #[test]
    fn test_do_calculus_rule2_applies() {
        let g = zxy_graph();
        let _applies = do_calculus_rule2(
            &g,
            &name_set(&["Y"]),
            &name_set(&["X"]),
            &name_set(&["Z"]),
            &BTreeSet::new(),
        );
    }

    #[test]
    fn test_do_calculus_rule3_applies() {
        let g = zxy_graph();
        let _applies = do_calculus_rule3(
            &g,
            &name_set(&["Y"]),
            &name_set(&["X"]),
            &name_set(&["Z"]),
            &BTreeSet::new(),
        );
    }
}