tensorlogic_infer/causal/graph.rs
1//! Causal graph (DAG) structure and graph-theoretic queries.
2//!
3//! Defines [`CausalGraph`] plus its impl block: d-separation, ancestors,
4//! descendants, and backdoor-path reachability primitives.
5
6use std::collections::{HashMap, HashSet, VecDeque};
7
8use super::error::CausalError;
9
10// ---------------------------------------------------------------------------
11// CausalGraph
12// ---------------------------------------------------------------------------
13
/// A Directed Acyclic Graph (DAG) representing causal structure among variables.
///
/// Nodes are identified by string names; edges encode direct causal relationships
/// (parent → child). Acyclicity is not checked when edges are inserted; callers
/// verify it on demand via [`CausalGraph::is_acyclic`].
#[derive(Debug, Clone)]
pub struct CausalGraph {
    /// Variable names. A node's position in this vector is the index used by `edges`.
    pub(super) nodes: Vec<String>,
    /// Directed edges stored as (parent_idx, child_idx) index pairs.
    pub(super) edges: Vec<(usize, usize)>,
}
24
25impl CausalGraph {
26 /// Create a new causal graph with the given variable names.
27 pub fn new(nodes: Vec<String>) -> Self {
28 Self {
29 nodes,
30 edges: Vec::new(),
31 }
32 }
33
34 /// Return the index of a node by name, or `None` if it does not exist.
35 pub fn node_index(&self, name: &str) -> Option<usize> {
36 self.nodes.iter().position(|n| n == name)
37 }
38
39 /// Add a directed edge `parent → child`.
40 ///
41 /// Returns [`CausalError::NodeNotFound`] if either node is absent.
42 /// Does not check for cycles — call [`CausalGraph::is_acyclic`] separately.
43 pub fn add_edge(&mut self, parent: &str, child: &str) -> Result<(), CausalError> {
44 let p = self
45 .node_index(parent)
46 .ok_or_else(|| CausalError::NodeNotFound(parent.to_string()))?;
47 let c = self
48 .node_index(child)
49 .ok_or_else(|| CausalError::NodeNotFound(child.to_string()))?;
50 self.edges.push((p, c));
51 Ok(())
52 }
53
54 /// Return direct parents of `node`.
55 pub fn parents_of(&self, node: &str) -> Vec<String> {
56 match self.node_index(node) {
57 None => vec![],
58 Some(idx) => self
59 .edges
60 .iter()
61 .filter(|&&(_, c)| c == idx)
62 .map(|&(p, _)| self.nodes[p].clone())
63 .collect(),
64 }
65 }
66
67 /// Return direct children of `node`.
68 pub fn children_of(&self, node: &str) -> Vec<String> {
69 match self.node_index(node) {
70 None => vec![],
71 Some(idx) => self
72 .edges
73 .iter()
74 .filter(|&&(p, _)| p == idx)
75 .map(|&(_, c)| self.nodes[c].clone())
76 .collect(),
77 }
78 }
79
80 /// Return all ancestors of `node` (transitive parents), excluding the node itself.
81 pub fn ancestors_of(&self, node: &str) -> Vec<String> {
82 let mut visited = HashSet::new();
83 let mut queue = VecDeque::new();
84 if let Some(start) = self.node_index(node) {
85 queue.push_back(start);
86 }
87 while let Some(cur) = queue.pop_front() {
88 for &(p, c) in &self.edges {
89 if c == cur && !visited.contains(&p) {
90 visited.insert(p);
91 queue.push_back(p);
92 }
93 }
94 }
95 visited.into_iter().map(|i| self.nodes[i].clone()).collect()
96 }
97
98 /// Return all descendants of `node` (transitive children), excluding the node itself.
99 pub fn descendants_of(&self, node: &str) -> Vec<String> {
100 let mut visited = HashSet::new();
101 let mut queue = VecDeque::new();
102 if let Some(start) = self.node_index(node) {
103 queue.push_back(start);
104 }
105 while let Some(cur) = queue.pop_front() {
106 for &(p, c) in &self.edges {
107 if p == cur && !visited.contains(&c) {
108 visited.insert(c);
109 queue.push_back(c);
110 }
111 }
112 }
113 visited.into_iter().map(|i| self.nodes[i].clone()).collect()
114 }
115
116 /// Check whether the graph is acyclic using Kahn's BFS topological sort algorithm.
117 ///
118 /// Returns `true` if the graph is a valid DAG.
119 pub fn is_acyclic(&self) -> bool {
120 let n = self.nodes.len();
121 let mut in_degree = vec![0usize; n];
122 for &(_, c) in &self.edges {
123 in_degree[c] += 1;
124 }
125 let mut queue: VecDeque<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
126 let mut processed = 0usize;
127 while let Some(cur) = queue.pop_front() {
128 processed += 1;
129 for &(p, c) in &self.edges {
130 if p == cur {
131 in_degree[c] -= 1;
132 if in_degree[c] == 0 {
133 queue.push_back(c);
134 }
135 }
136 }
137 }
138 processed == n
139 }
140
141 /// Return all node names.
142 pub fn nodes(&self) -> &[String] {
143 &self.nodes
144 }
145
146 /// Return the number of nodes.
147 pub fn node_count(&self) -> usize {
148 self.nodes.len()
149 }
150
151 /// Return the number of directed edges.
152 pub fn edge_count(&self) -> usize {
153 self.edges.len()
154 }
155
156 /// Test d-separation: is `x` d-separated from `y` given the observed set `observed`?
157 ///
158 /// Uses the Bayes-Ball algorithm on the moral graph / active path traversal.
159 /// A path is *active* given `observed` if:
160 /// - At every non-collider on the path, the node is NOT in `observed`.
161 /// - At every collider, the collider OR one of its descendants IS in `observed`.
162 pub fn d_separated(&self, x: &str, y: &str, observed: &[&str]) -> bool {
163 let x_idx = match self.node_index(x) {
164 Some(i) => i,
165 None => return true,
166 };
167 let y_idx = match self.node_index(y) {
168 Some(i) => i,
169 None => return true,
170 };
171 if x_idx == y_idx {
172 return false;
173 }
174
175 let obs_set: HashSet<usize> = observed
176 .iter()
177 .filter_map(|&name| self.node_index(name))
178 .collect();
179
180 // Pre-compute descendants of all observed nodes (needed for collider check).
181 let mut obs_or_desc: HashSet<usize> = obs_set.clone();
182 for &o in &obs_set {
183 let node_name = &self.nodes[o].clone();
184 for desc in self.descendants_of(node_name) {
185 if let Some(di) = self.node_index(&desc) {
186 obs_or_desc.insert(di);
187 }
188 }
189 }
190
191 // State: (node_idx, arrived_via_child: bool)
192 // arrived_via_child = true → we arrived at this node from one of its children (going "up")
193 // arrived_via_child = false → we arrived from a parent (going "down")
194 let mut visited: HashSet<(usize, bool)> = HashSet::new();
195 let mut queue: VecDeque<(usize, bool)> = VecDeque::new();
196
197 // We can start from x going both up and down.
198 queue.push_back((x_idx, true));
199 queue.push_back((x_idx, false));
200
201 while let Some((cur, via_child)) = queue.pop_front() {
202 if !visited.insert((cur, via_child)) {
203 continue;
204 }
205 if cur == y_idx {
206 return false; // active path found → NOT d-separated
207 }
208
209 if via_child && !obs_set.contains(&cur) {
210 // Traversing up (non-collider direction): pass through parents and children
211 // go up to parents
212 for &(p, c) in &self.edges {
213 if c == cur {
214 let state = (p, true);
215 if !visited.contains(&state) {
216 queue.push_back(state);
217 }
218 }
219 }
220 // go down to children
221 for &(p, c) in &self.edges {
222 if p == cur {
223 let state = (c, false);
224 if !visited.contains(&state) {
225 queue.push_back(state);
226 }
227 }
228 }
229 }
230
231 if !via_child {
232 // Arriving from above (going down)
233 if !obs_set.contains(&cur) {
234 // Non-collider going down: continue downward
235 for &(p, c) in &self.edges {
236 if p == cur {
237 let state = (c, false);
238 if !visited.contains(&state) {
239 queue.push_back(state);
240 }
241 }
242 }
243 }
244 // Collider activation: if cur (collider) or descendant is observed, go up
245 if obs_or_desc.contains(&cur) {
246 for &(p, c) in &self.edges {
247 if c == cur {
248 let state = (p, true);
249 if !visited.contains(&state) {
250 queue.push_back(state);
251 }
252 }
253 }
254 }
255 }
256 }
257
258 true // no active path found → d-separated
259 }
260
261 /// Internal helper: collect all undirected (bidirectional) adjacency paths from `src` to `dst`
262 /// that are *backdoor paths* (i.e. paths that enter `src` via a parent of `src`).
263 /// Returns true if there exists at least one unblocked backdoor path given `adjustment_set`.
264 pub(super) fn has_unblocked_backdoor_path(
265 &self,
266 src: usize,
267 dst: usize,
268 adjustment_set: &HashSet<usize>,
269 ) -> bool {
270 // A backdoor path from src to dst is an undirected path that starts by going
271 // "upward" from src (i.e. first step is via a parent of src).
272 // We block a path by conditioning on a non-collider on the path,
273 // or by NOT conditioning on a collider / its descendant.
274 //
275 // We use a simplified reachability check:
276 // A node Z blocks a path if it is a non-collider on the path AND Z is in adjustment_set,
277 // or it is a collider not in adjustment_set and none of its descendants are.
278 //
279 // State: (current_node, previous_node, direction: true=going_up)
280 // We only consider paths that leave src going upward (backdoor).
281
282 // Compute descendants for collider check
283 let mut desc_map: HashMap<usize, HashSet<usize>> = HashMap::new();
284 for i in 0..self.nodes.len() {
285 let desc_names = self.descendants_of(&self.nodes[i].clone());
286 let desc_idxs: HashSet<usize> = desc_names
287 .iter()
288 .filter_map(|n| self.node_index(n))
289 .collect();
290 desc_map.insert(i, desc_idxs);
291 }
292
293 let is_in_adj_or_desc = |node: usize| -> bool {
294 if adjustment_set.contains(&node) {
295 return true;
296 }
297 if let Some(descs) = desc_map.get(&node) {
298 return descs.iter().any(|d| adjustment_set.contains(d));
299 }
300 false
301 };
302
303 // State: (current_node, prev_node, arrived_via_up: bool)
304 let mut visited: HashSet<(usize, usize, bool)> = HashSet::new();
305 let mut queue: VecDeque<(usize, usize, bool)> = VecDeque::new();
306
307 // Only start on parents of src (backdoor = entering src from above)
308 for &(p, c) in &self.edges {
309 if c == src {
310 // parent p of src: going up (from src to p)
311 // The first step is upward. p is a non-collider relative to src→p.
312 // Block if p is in adjustment_set
313 if !adjustment_set.contains(&p) {
314 queue.push_back((p, src, true));
315 }
316 }
317 }
318
319 while let Some((cur, prev, going_up)) = queue.pop_front() {
320 if !visited.insert((cur, prev, going_up)) {
321 continue;
322 }
323 if cur == dst {
324 return true;
325 }
326
327 // Explore neighbors
328 // Build set of parents and children of cur
329 let parents: Vec<usize> = self
330 .edges
331 .iter()
332 .filter(|&&(_, c)| c == cur)
333 .map(|&(p, _)| p)
334 .collect();
335 let children: Vec<usize> = self
336 .edges
337 .iter()
338 .filter(|&&(p, _)| p == cur)
339 .map(|&(_, c)| c)
340 .collect();
341
342 for &next in parents.iter().chain(children.iter()) {
343 if next == prev {
344 continue;
345 }
346 // Determine if cur is a collider on the segment prev→cur→next
347 // cur is a collider iff both prev and next are parents of cur
348 let prev_is_parent_of_cur = parents.contains(&prev);
349 let next_is_parent_of_cur = parents.contains(&next);
350 let is_collider = prev_is_parent_of_cur && next_is_parent_of_cur;
351
352 let blocked = if is_collider {
353 // Collider: blocked unless cur or its descendant is in adjustment set
354 !is_in_adj_or_desc(cur)
355 } else {
356 // Non-collider: blocked if cur is in adjustment set
357 adjustment_set.contains(&cur)
358 };
359
360 if !blocked {
361 let next_going_up = parents.contains(&next);
362 let state = (next, cur, next_going_up);
363 if !visited.contains(&state) {
364 queue.push_back(state);
365 }
366 }
367 }
368 }
369
370 false
371 }
372
373 /// Check whether there is a directed path from `src` to `dst`.
374 pub fn has_directed_path(&self, src: &str, dst: &str) -> bool {
375 let src_idx = match self.node_index(src) {
376 Some(i) => i,
377 None => return false,
378 };
379 let dst_idx = match self.node_index(dst) {
380 Some(i) => i,
381 None => return false,
382 };
383 let mut visited = HashSet::new();
384 let mut queue = VecDeque::new();
385 queue.push_back(src_idx);
386 while let Some(cur) = queue.pop_front() {
387 if cur == dst_idx {
388 return true;
389 }
390 if !visited.insert(cur) {
391 continue;
392 }
393 for &(p, c) in &self.edges {
394 if p == cur && !visited.contains(&c) {
395 queue.push_back(c);
396 }
397 }
398 }
399 false
400 }
401}