1use std::collections::{HashMap, HashSet};
22
23use crate::causal_graph::dag::CausalDAG;
24use crate::error::{StatsError, StatsResult};
25
/// Selects which of Pearl's do-calculus rules `check_do_calculus_rule` should test.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DoCalculusRule {
    /// Rule 1: insertion/deletion of observations.
    Rule1,
    /// Rule 2: action/observation exchange.
    Rule2,
    /// Rule 3: insertion/deletion of actions.
    Rule3,
    /// No rule selected; `check_do_calculus_rule` returns `false` for this variant.
    None,
}
48
/// Outcome of the backdoor adjustment-set search performed by `find_backdoor_sets`.
#[derive(Debug, Clone)]
pub struct BackdoorResult {
    /// `true` when the search found at least one set satisfying the backdoor criterion.
    pub is_admissible: bool,
    /// Node names of the first admissible set found (empty when none was found,
    /// and also when the empty set itself is admissible).
    pub adjustment_set: Vec<String>,
    /// Node names of the admissible sets collected by the search.
    pub all_minimal_sets: Vec<Vec<String>>,
}
59
/// Outcome of the frontdoor mediator search performed by `find_frontdoor_set`.
#[derive(Debug, Clone)]
pub struct FrontdoorResult {
    /// `true` when a mediator set satisfying the frontdoor criterion was found.
    pub is_applicable: bool,
    /// Node names of the mediator set that was found (empty when none was found).
    pub mediator_set: Vec<String>,
    /// Textual frontdoor adjustment formula, or a "not identifiable" message
    /// when no mediator set was found.
    pub formula: String,
}
70
/// Result of an identification attempt for an interventional query.
#[derive(Debug, Clone)]
pub struct IdResult {
    /// Whether the query was identified.
    pub identifiable: bool,
    /// Textual estimand in terms of observational distributions
    /// (`id_algorithm` leaves it empty when identification fails).
    pub expression: String,
    /// Human-readable note on which strategy produced the result.
    pub explanation: String,
}
81
/// A confounded component (c-component): a group of nodes linked by bidirected edges.
#[derive(Debug, Clone)]
pub struct CComponent {
    /// Indices (as used by the DAG) of the nodes belonging to this component.
    pub nodes: HashSet<usize>,
}
88
89pub fn check_do_calculus_rule(
104 dag: &CausalDAG,
105 y: &[&str],
106 x: &[&str],
107 z: &[&str],
108 w: &[&str],
109 rule: DoCalculusRule,
110) -> bool {
111 match rule {
112 DoCalculusRule::Rule1 => {
113 let mut g_xbar = dag.clone();
115 remove_incoming_edges(&mut g_xbar, x);
116 let mut conditioning: Vec<&str> = Vec::new();
117 conditioning.extend_from_slice(x);
118 conditioning.extend_from_slice(w);
119 check_d_separation_all(&g_xbar, y, z, &conditioning)
120 }
121 DoCalculusRule::Rule2 => {
122 let mut g = dag.clone();
124 remove_incoming_edges(&mut g, x);
125 remove_incoming_edges(&mut g, z);
126 let mut conditioning: Vec<&str> = Vec::new();
127 conditioning.extend_from_slice(x);
128 conditioning.extend_from_slice(w);
129 check_d_separation_all(&g, y, z, &conditioning)
130 }
131 DoCalculusRule::Rule3 => {
132 let mut g_xbar = dag.clone();
135 remove_incoming_edges(&mut g_xbar, x);
136
137 let w_ancestors = ancestors_of_names(&g_xbar, w);
138 let z_not_w_anc: Vec<&str> = z
139 .iter()
140 .filter(|&&zz| {
141 let idx = dag.node_index(zz);
142 !idx.map(|i| w_ancestors.contains(&i)).unwrap_or(false)
143 })
144 .copied()
145 .collect();
146
147 let mut g = g_xbar;
149 remove_outgoing_edges(&mut g, &z_not_w_anc);
150
151 let mut conditioning: Vec<&str> = Vec::new();
152 conditioning.extend_from_slice(x);
153 conditioning.extend_from_slice(w);
154 check_d_separation_all(&g, y, z, &conditioning)
155 }
156 DoCalculusRule::None => false,
157 }
158}
159
160pub fn satisfies_backdoor(dag: &CausalDAG, x: &str, y: &str, z_set: &[&str]) -> bool {
172 let desc_x = dag.descendants(x);
174 for &z in z_set {
175 if let Some(zi) = dag.node_index(z) {
176 if desc_x.contains(&zi) {
177 return false;
178 }
179 }
180 }
181 let mut g = dag.clone();
183 remove_outgoing_edges(&mut g, &[x]);
184 g.is_d_separated(x, y, z_set)
185}
186
187pub fn find_backdoor_sets(
192 dag: &CausalDAG,
193 x: &str,
194 y: &str,
195 max_set_size: usize,
196) -> BackdoorResult {
197 let desc_x = dag.descendants(x);
198 let xi = dag.node_index(x).unwrap_or(usize::MAX);
199 let yi = dag.node_index(y).unwrap_or(usize::MAX);
200
201 let candidates: Vec<usize> = (0..dag.n_nodes())
203 .filter(|&i| i != xi && i != yi && !desc_x.contains(&i))
204 .collect();
205
206 let mut all_minimal: Vec<Vec<String>> = Vec::new();
207 let mut found_any = false;
208
209 'outer: for size in 0..=max_set_size.min(candidates.len()) {
211 for subset in subsets(&candidates, size) {
212 let z_names: Vec<&str> = subset.iter().filter_map(|&i| dag.node_name(i)).collect();
213 if satisfies_backdoor(dag, x, y, &z_names) {
214 let z_strings: Vec<String> = z_names.iter().map(|s| s.to_string()).collect();
215 all_minimal.push(z_strings);
216 found_any = true;
217 if all_minimal.len() >= 20 {
218 break 'outer;
220 }
221 }
222 }
223 if found_any && size < max_set_size {
226 }
228 }
229
230 let best = all_minimal.first().cloned().unwrap_or_default();
231 BackdoorResult {
232 is_admissible: found_any,
233 adjustment_set: best,
234 all_minimal_sets: all_minimal,
235 }
236}
237
238pub fn satisfies_frontdoor(dag: &CausalDAG, x: &str, y: &str, m_set: &[&str]) -> bool {
250 if !intercepts_all_paths(dag, x, y, m_set) {
252 return false;
253 }
254 let mut g_xbar = dag.clone();
259 remove_incoming_edges(&mut g_xbar, &[x]);
260 remove_outgoing_edges(&mut g_xbar, &[x]);
261 for &m in m_set {
262 if !g_xbar.is_d_separated(x, m, &[]) {
263 return false;
264 }
265 }
266 for &m in m_set {
268 if !satisfies_backdoor(dag, m, y, &[x]) {
269 return false;
270 }
271 }
272 true
273}
274
275pub fn find_frontdoor_set(dag: &CausalDAG, x: &str, y: &str) -> FrontdoorResult {
277 let xi = dag.node_index(x).unwrap_or(usize::MAX);
278 let yi = dag.node_index(y).unwrap_or(usize::MAX);
279
280 let descendants_x = dag.descendants(x);
281 let candidates: Vec<usize> = descendants_x
283 .iter()
284 .filter(|&&i| i != yi && i != xi)
285 .copied()
286 .collect();
287
288 for size in 1..=candidates.len() {
289 for subset in subsets(&candidates, size) {
290 let m_names: Vec<&str> = subset.iter().filter_map(|&i| dag.node_name(i)).collect();
291 if satisfies_frontdoor(dag, x, y, &m_names) {
292 let formula = frontdoor_formula(x, y, &m_names);
293 return FrontdoorResult {
294 is_applicable: true,
295 mediator_set: m_names.iter().map(|s| s.to_string()).collect(),
296 formula,
297 };
298 }
299 }
300 }
301
302 FrontdoorResult {
303 is_applicable: false,
304 mediator_set: Vec::new(),
305 formula: "Not identifiable via frontdoor".to_owned(),
306 }
307}
308
309pub fn id_algorithm(dag: &CausalDAG, y: &[&str], x: &[&str]) -> IdResult {
321 if x.is_empty() {
323 return IdResult {
324 identifiable: true,
325 expression: format!("P({})", y.join(", ")),
326 explanation: "No intervention; trivially identified as the observational distribution."
327 .to_owned(),
328 };
329 }
330
331 if x.len() == 1 && y.len() == 1 {
334 let xv = x[0];
335 let yv = y[0];
336
337 if satisfies_backdoor(dag, xv, yv, &[]) {
339 return IdResult {
340 identifiable: true,
341 expression: format!("P({yv} | {xv})"),
342 explanation: "Identified via empty backdoor set (no confounding).".to_owned(),
343 };
344 }
345
346 let bd = find_backdoor_sets(dag, xv, yv, 5);
348 if bd.is_admissible {
349 let z_str = bd.adjustment_set.join(", ");
350 return IdResult {
351 identifiable: true,
352 expression: format!("Σ_{{{}}} P({yv} | {xv}, {z_str}) P({z_str})", z_str,),
353 explanation: format!("Identified via backdoor adjustment on {{{z_str}}}."),
354 };
355 }
356
357 let fd = find_frontdoor_set(dag, xv, yv);
359 if fd.is_applicable {
360 return IdResult {
361 identifiable: true,
362 expression: fd.formula,
363 explanation: format!(
364 "Identified via frontdoor criterion through mediators: {:?}.",
365 fd.mediator_set
366 ),
367 };
368 }
369 }
370
371 let tian = tian_pearl_id(dag, y, x);
373 if tian.identifiable {
374 return tian;
375 }
376
377 IdResult {
378 identifiable: false,
379 expression: String::new(),
380 explanation: format!(
381 "P({y} | do({x})) is not identifiable by the ID algorithm with the given DAG.",
382 y = y.join(", "),
383 x = x.join(", ")
384 ),
385 }
386}
387
388pub fn tian_pearl_id(dag: &CausalDAG, y: &[&str], x: &[&str]) -> IdResult {
397 let topo = dag.topological_sort();
399 let n = dag.n_nodes();
400
401 let topo_pos: HashMap<&str, usize> = topo
403 .iter()
404 .enumerate()
405 .map(|(i, &name)| (name, i))
406 .collect();
407
408 let y_set: HashSet<&str> = y.iter().copied().collect();
414 let x_set: HashSet<&str> = x.iter().copied().collect();
415
416 let sum_over: Vec<&str> = topo
418 .iter()
419 .copied()
420 .filter(|&v| !y_set.contains(v) && !x_set.contains(v))
421 .collect();
422
423 let mut numerator_parts: Vec<String> = Vec::new();
425 let mut denominator_parts: Vec<String> = Vec::new();
426
427 for &node in &topo {
428 let pos = topo_pos[node];
429 let pa: Vec<&str> = dag.parents(node);
430 let prior: Vec<&str> = topo[..pos].to_vec();
432
433 let cond: Vec<String> = pa
434 .iter()
435 .map(|s| s.to_string())
436 .chain(prior.iter().map(|s| s.to_string()))
437 .collect();
438
439 let cond_str = if cond.is_empty() {
440 String::new()
441 } else {
442 format!(" | {}", cond.join(", "))
443 };
444
445 if !x_set.contains(node) {
446 numerator_parts.push(format!("P({node}{cond_str})"));
448 }
449 if pa.iter().any(|p| x_set.contains(*p)) || prior.iter().any(|p| x_set.contains(*p)) {
451 denominator_parts.push(format!("P({node}{cond_str})"));
452 }
453 }
454
455 let sum_str = if sum_over.is_empty() {
456 String::new()
457 } else {
458 format!("Σ_{{{}}}", sum_over.join(","))
459 };
460
461 let num_str = numerator_parts.join(" ");
462 let expr = if denominator_parts.is_empty() {
463 format!("{sum_str} {num_str}")
464 } else {
465 format!("{sum_str} {num_str} / ({})", denominator_parts.join(" "))
466 };
467
468 IdResult {
469 identifiable: n > 0,
470 expression: expr.trim().to_owned(),
471 explanation: "Tian-Pearl c-component factorization (DAG, no hidden variables).".to_owned(),
472 }
473}
474
475pub fn c_components_with_hidden(dag: &CausalDAG, bidirected: &[(&str, &str)]) -> Vec<CComponent> {
480 let n = dag.n_nodes();
481 let mut union_find: Vec<usize> = (0..n).collect();
482
483 fn find(uf: &mut Vec<usize>, mut i: usize) -> usize {
484 while uf[i] != i {
485 uf[i] = uf[uf[i]]; i = uf[i];
487 }
488 i
489 }
490
491 fn union(uf: &mut Vec<usize>, a: usize, b: usize) {
492 let ra = find(uf, a);
493 let rb = find(uf, b);
494 if ra != rb {
495 uf[ra] = rb;
496 }
497 }
498
499 for &(u, v) in bidirected {
500 if let (Some(ui), Some(vi)) = (dag.node_index(u), dag.node_index(v)) {
501 union(&mut union_find, ui, vi);
502 }
503 }
504
505 let mut comp_map: HashMap<usize, HashSet<usize>> = HashMap::new();
507 for i in 0..n {
508 let root = find(&mut union_find, i);
509 comp_map.entry(root).or_default().insert(i);
510 }
511
512 comp_map
513 .into_values()
514 .map(|nodes| CComponent { nodes })
515 .collect()
516}
517
518fn remove_incoming_edges(dag: &mut CausalDAG, targets: &[&str]) {
524 let target_idxs: HashSet<usize> = targets.iter().filter_map(|&t| dag.node_index(t)).collect();
525 dag.remove_incoming_edges_for(&target_idxs);
526}
527
528fn remove_outgoing_edges(dag: &mut CausalDAG, targets: &[&str]) {
530 let target_idxs: HashSet<usize> = targets.iter().filter_map(|&t| dag.node_index(t)).collect();
531 dag.remove_outgoing_edges_for(&target_idxs);
532}
533
534fn check_d_separation_all(dag: &CausalDAG, y: &[&str], z: &[&str], conditioning: &[&str]) -> bool {
536 for &yi in y {
537 for &zi in z {
538 if !dag.is_d_separated(yi, zi, conditioning) {
539 return false;
540 }
541 }
542 }
543 true
544}
545
546fn ancestors_of_names(dag: &CausalDAG, names: &[&str]) -> HashSet<usize> {
548 let mut all_anc = HashSet::new();
549 for &name in names {
550 for anc in dag.ancestors(name) {
551 all_anc.insert(anc);
552 }
553 }
554 all_anc
555}
556
557fn intercepts_all_paths(dag: &CausalDAG, x: &str, y: &str, m_set: &[&str]) -> bool {
559 let xi = match dag.node_index(x) {
561 None => return true,
562 Some(i) => i,
563 };
564 let yi = match dag.node_index(y) {
565 None => return true,
566 Some(i) => i,
567 };
568 let m_idxs: HashSet<usize> = m_set.iter().filter_map(|&m| dag.node_index(m)).collect();
569
570 let mut stack: Vec<usize> = vec![xi];
572 let mut visited: HashSet<usize> = HashSet::new();
573 while let Some(cur) = stack.pop() {
574 if cur == yi {
575 return false; }
577 if !visited.insert(cur) {
578 continue;
579 }
580 for c in dag.children(dag.node_name(cur).unwrap_or("")) {
581 if let Some(ci) = dag.node_index(c) {
582 if !m_idxs.contains(&ci) {
583 stack.push(ci);
584 }
585 }
586 }
587 }
588 true
589}
590
/// Renders the frontdoor adjustment formula for treatment `x`, outcome `y` and
/// mediators `m_set` as text. Mediators are joined with `", "`, and the inner
/// sum ranges over a primed copy of `x`.
fn frontdoor_formula(x: &str, y: &str, m_set: &[&str]) -> String {
    let mediators = m_set.join(", ");
    // Outer sum over the mediators, weighted by P(M | x).
    let outer = format!("Σ_{{{mediators}}} P({mediators} | {x})");
    // Inner sum over x', averaging P(y | x', M) against P(x').
    let inner = format!("Σ_{{{x}'}} P({y} | {x}', {mediators}) P({x}')");
    format!("{outer} {inner}")
}
601
/// Enumerates every `k`-element combination of `items`, preserving the input
/// order inside each combination and listing combinations in lexicographic
/// order of their positions.
///
/// `k == 0` yields a single empty combination; `k > items.len()` yields none.
fn subsets<T: Copy>(items: &[T], k: usize) -> Vec<Vec<T>> {
    /// Backtracking step: extends `chosen` with `remaining` more elements of `pool`.
    fn choose<U: Copy>(pool: &[U], remaining: usize, chosen: &mut Vec<U>, out: &mut Vec<Vec<U>>) {
        if remaining == 0 {
            out.push(chosen.clone());
            return;
        }
        if remaining > pool.len() {
            return;
        }
        // The last usable start leaves exactly `remaining - 1` elements after it.
        for start in 0..=(pool.len() - remaining) {
            chosen.push(pool[start]);
            choose(&pool[start + 1..], remaining - 1, chosen, out);
            chosen.pop();
        }
    }

    let mut out = Vec::new();
    let mut chosen = Vec::with_capacity(k.min(items.len()));
    choose(items, k, &mut chosen, &mut out);
    out
}
619
#[cfg(test)]
mod tests {
    use super::*;
    use crate::causal_graph::dag::CausalDAG;

    /// Chain `X -> M -> Y` with no confounding.
    fn smoke_dag() -> CausalDAG {
        let mut dag = CausalDAG::new();
        dag.add_edge("X", "M").unwrap();
        dag.add_edge("M", "Y").unwrap();
        dag
    }

    /// `X -> Y` confounded by a common cause `Z` (`Z -> X`, `Z -> Y`).
    fn confounded_dag() -> CausalDAG {
        let mut dag = CausalDAG::new();
        dag.add_edge("Z", "X").unwrap();
        dag.add_edge("Z", "Y").unwrap();
        dag.add_edge("X", "Y").unwrap();
        dag
    }

    #[test]
    fn test_backdoor_with_z() {
        let dag = confounded_dag();
        // Adjusting for the common cause blocks the backdoor path X <- Z -> Y.
        assert!(satisfies_backdoor(&dag, "X", "Y", &["Z"]));
        // Without adjustment the backdoor path stays open.
        assert!(!satisfies_backdoor(&dag, "X", "Y", &[]));
    }

    #[test]
    fn test_find_backdoor_set() {
        let dag = confounded_dag();
        let res = find_backdoor_sets(&dag, "X", "Y", 3);
        assert!(res.is_admissible);
        assert!(res.adjustment_set.contains(&"Z".to_string()));
    }

    #[test]
    fn test_frontdoor() {
        let dag = smoke_dag();
        assert!(satisfies_frontdoor(&dag, "X", "Y", &["M"]));
        let fd = find_frontdoor_set(&dag, "X", "Y");
        assert!(fd.is_applicable);
    }

    #[test]
    fn test_id_trivial() {
        let dag = smoke_dag();
        let res = id_algorithm(&dag, &["Y"], &[]);
        assert!(res.identifiable);
        assert!(res.expression.contains('P'));
    }

    #[test]
    fn test_tian_pearl() {
        let dag = smoke_dag();
        let res = tian_pearl_id(&dag, &["Y"], &["X"]);
        assert!(res.identifiable);
    }

    #[test]
    fn test_c_components_with_hidden() {
        let dag = smoke_dag();
        // Without bidirected edges every node is its own component.
        let comps = c_components_with_hidden(&dag, &[]);
        assert_eq!(comps.len(), dag.n_nodes());
        // A bidirected edge merges its two endpoints into one component.
        let comps2 = c_components_with_hidden(&dag, &[("X", "Y")]);
        assert!(comps2.len() < dag.n_nodes());
    }

    #[test]
    fn test_do_calculus_rule1() {
        let dag = confounded_dag();
        let applies =
            check_do_calculus_rule(&dag, &["Y"], &["X"], &["Z"], &[], DoCalculusRule::Rule1);
        // Z -> Y is a direct edge that survives removing the edges into X, so
        // Y and Z are never d-separated and the observation of Z cannot be dropped.
        assert!(!applies);
    }

    #[test]
    fn test_subsets() {
        let items = vec![1, 2, 3];
        assert_eq!(subsets(&items, 0).len(), 1);
        assert_eq!(subsets(&items, 1).len(), 3);
        assert_eq!(subsets(&items, 2).len(), 3);
        assert_eq!(subsets(&items, 3).len(), 1);
    }
}