Skip to main content

tensorlogic_infer/causal/
graph.rs

//! Causal graph (DAG) structure and graph-theoretic queries.
//!
//! Defines [`CausalGraph`] plus its impl block: d-separation, ancestors,
//! descendants, and backdoor-path reachability primitives.
5
6use std::collections::{HashMap, HashSet, VecDeque};
7
8use super::error::CausalError;
9
// ---------------------------------------------------------------------------
// CausalGraph
// ---------------------------------------------------------------------------
13
14/// A Directed Acyclic Graph (DAG) representing causal structure among variables.
15///
16/// Nodes are identified by string names; edges encode direct causal relationships
17/// (parent → child). The graph enforces acyclicity lazily via [`CausalGraph::is_acyclic`].
18#[derive(Debug, Clone)]
19pub struct CausalGraph {
20    pub(super) nodes: Vec<String>,
21    /// Directed edges stored as (parent_idx, child_idx) index pairs.
22    pub(super) edges: Vec<(usize, usize)>,
23}
24
25impl CausalGraph {
26    /// Create a new causal graph with the given variable names.
27    pub fn new(nodes: Vec<String>) -> Self {
28        Self {
29            nodes,
30            edges: Vec::new(),
31        }
32    }
33
34    /// Return the index of a node by name, or `None` if it does not exist.
35    pub fn node_index(&self, name: &str) -> Option<usize> {
36        self.nodes.iter().position(|n| n == name)
37    }
38
39    /// Add a directed edge `parent → child`.
40    ///
41    /// Returns [`CausalError::NodeNotFound`] if either node is absent.
42    /// Does not check for cycles — call [`CausalGraph::is_acyclic`] separately.
43    pub fn add_edge(&mut self, parent: &str, child: &str) -> Result<(), CausalError> {
44        let p = self
45            .node_index(parent)
46            .ok_or_else(|| CausalError::NodeNotFound(parent.to_string()))?;
47        let c = self
48            .node_index(child)
49            .ok_or_else(|| CausalError::NodeNotFound(child.to_string()))?;
50        self.edges.push((p, c));
51        Ok(())
52    }
53
54    /// Return direct parents of `node`.
55    pub fn parents_of(&self, node: &str) -> Vec<String> {
56        match self.node_index(node) {
57            None => vec![],
58            Some(idx) => self
59                .edges
60                .iter()
61                .filter(|&&(_, c)| c == idx)
62                .map(|&(p, _)| self.nodes[p].clone())
63                .collect(),
64        }
65    }
66
67    /// Return direct children of `node`.
68    pub fn children_of(&self, node: &str) -> Vec<String> {
69        match self.node_index(node) {
70            None => vec![],
71            Some(idx) => self
72                .edges
73                .iter()
74                .filter(|&&(p, _)| p == idx)
75                .map(|&(_, c)| self.nodes[c].clone())
76                .collect(),
77        }
78    }
79
80    /// Return all ancestors of `node` (transitive parents), excluding the node itself.
81    pub fn ancestors_of(&self, node: &str) -> Vec<String> {
82        let mut visited = HashSet::new();
83        let mut queue = VecDeque::new();
84        if let Some(start) = self.node_index(node) {
85            queue.push_back(start);
86        }
87        while let Some(cur) = queue.pop_front() {
88            for &(p, c) in &self.edges {
89                if c == cur && !visited.contains(&p) {
90                    visited.insert(p);
91                    queue.push_back(p);
92                }
93            }
94        }
95        visited.into_iter().map(|i| self.nodes[i].clone()).collect()
96    }
97
98    /// Return all descendants of `node` (transitive children), excluding the node itself.
99    pub fn descendants_of(&self, node: &str) -> Vec<String> {
100        let mut visited = HashSet::new();
101        let mut queue = VecDeque::new();
102        if let Some(start) = self.node_index(node) {
103            queue.push_back(start);
104        }
105        while let Some(cur) = queue.pop_front() {
106            for &(p, c) in &self.edges {
107                if p == cur && !visited.contains(&c) {
108                    visited.insert(c);
109                    queue.push_back(c);
110                }
111            }
112        }
113        visited.into_iter().map(|i| self.nodes[i].clone()).collect()
114    }
115
116    /// Check whether the graph is acyclic using Kahn's BFS topological sort algorithm.
117    ///
118    /// Returns `true` if the graph is a valid DAG.
119    pub fn is_acyclic(&self) -> bool {
120        let n = self.nodes.len();
121        let mut in_degree = vec![0usize; n];
122        for &(_, c) in &self.edges {
123            in_degree[c] += 1;
124        }
125        let mut queue: VecDeque<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
126        let mut processed = 0usize;
127        while let Some(cur) = queue.pop_front() {
128            processed += 1;
129            for &(p, c) in &self.edges {
130                if p == cur {
131                    in_degree[c] -= 1;
132                    if in_degree[c] == 0 {
133                        queue.push_back(c);
134                    }
135                }
136            }
137        }
138        processed == n
139    }
140
141    /// Return all node names.
142    pub fn nodes(&self) -> &[String] {
143        &self.nodes
144    }
145
146    /// Return the number of nodes.
147    pub fn node_count(&self) -> usize {
148        self.nodes.len()
149    }
150
151    /// Return the number of directed edges.
152    pub fn edge_count(&self) -> usize {
153        self.edges.len()
154    }
155
156    /// Test d-separation: is `x` d-separated from `y` given the observed set `observed`?
157    ///
158    /// Uses the Bayes-Ball algorithm on the moral graph / active path traversal.
159    /// A path is *active* given `observed` if:
160    /// - At every non-collider on the path, the node is NOT in `observed`.
161    /// - At every collider, the collider OR one of its descendants IS in `observed`.
162    pub fn d_separated(&self, x: &str, y: &str, observed: &[&str]) -> bool {
163        let x_idx = match self.node_index(x) {
164            Some(i) => i,
165            None => return true,
166        };
167        let y_idx = match self.node_index(y) {
168            Some(i) => i,
169            None => return true,
170        };
171        if x_idx == y_idx {
172            return false;
173        }
174
175        let obs_set: HashSet<usize> = observed
176            .iter()
177            .filter_map(|&name| self.node_index(name))
178            .collect();
179
180        // Pre-compute descendants of all observed nodes (needed for collider check).
181        let mut obs_or_desc: HashSet<usize> = obs_set.clone();
182        for &o in &obs_set {
183            let node_name = &self.nodes[o].clone();
184            for desc in self.descendants_of(node_name) {
185                if let Some(di) = self.node_index(&desc) {
186                    obs_or_desc.insert(di);
187                }
188            }
189        }
190
191        // State: (node_idx, arrived_via_child: bool)
192        // arrived_via_child = true  → we arrived at this node from one of its children (going "up")
193        // arrived_via_child = false → we arrived from a parent (going "down")
194        let mut visited: HashSet<(usize, bool)> = HashSet::new();
195        let mut queue: VecDeque<(usize, bool)> = VecDeque::new();
196
197        // We can start from x going both up and down.
198        queue.push_back((x_idx, true));
199        queue.push_back((x_idx, false));
200
201        while let Some((cur, via_child)) = queue.pop_front() {
202            if !visited.insert((cur, via_child)) {
203                continue;
204            }
205            if cur == y_idx {
206                return false; // active path found → NOT d-separated
207            }
208
209            if via_child && !obs_set.contains(&cur) {
210                // Traversing up (non-collider direction): pass through parents and children
211                // go up to parents
212                for &(p, c) in &self.edges {
213                    if c == cur {
214                        let state = (p, true);
215                        if !visited.contains(&state) {
216                            queue.push_back(state);
217                        }
218                    }
219                }
220                // go down to children
221                for &(p, c) in &self.edges {
222                    if p == cur {
223                        let state = (c, false);
224                        if !visited.contains(&state) {
225                            queue.push_back(state);
226                        }
227                    }
228                }
229            }
230
231            if !via_child {
232                // Arriving from above (going down)
233                if !obs_set.contains(&cur) {
234                    // Non-collider going down: continue downward
235                    for &(p, c) in &self.edges {
236                        if p == cur {
237                            let state = (c, false);
238                            if !visited.contains(&state) {
239                                queue.push_back(state);
240                            }
241                        }
242                    }
243                }
244                // Collider activation: if cur (collider) or descendant is observed, go up
245                if obs_or_desc.contains(&cur) {
246                    for &(p, c) in &self.edges {
247                        if c == cur {
248                            let state = (p, true);
249                            if !visited.contains(&state) {
250                                queue.push_back(state);
251                            }
252                        }
253                    }
254                }
255            }
256        }
257
258        true // no active path found → d-separated
259    }
260
261    /// Internal helper: collect all undirected (bidirectional) adjacency paths from `src` to `dst`
262    /// that are *backdoor paths* (i.e. paths that enter `src` via a parent of `src`).
263    /// Returns true if there exists at least one unblocked backdoor path given `adjustment_set`.
264    pub(super) fn has_unblocked_backdoor_path(
265        &self,
266        src: usize,
267        dst: usize,
268        adjustment_set: &HashSet<usize>,
269    ) -> bool {
270        // A backdoor path from src to dst is an undirected path that starts by going
271        // "upward" from src (i.e. first step is via a parent of src).
272        // We block a path by conditioning on a non-collider on the path,
273        // or by NOT conditioning on a collider / its descendant.
274        //
275        // We use a simplified reachability check:
276        // A node Z blocks a path if it is a non-collider on the path AND Z is in adjustment_set,
277        // or it is a collider not in adjustment_set and none of its descendants are.
278        //
279        // State: (current_node, previous_node, direction: true=going_up)
280        // We only consider paths that leave src going upward (backdoor).
281
282        // Compute descendants for collider check
283        let mut desc_map: HashMap<usize, HashSet<usize>> = HashMap::new();
284        for i in 0..self.nodes.len() {
285            let desc_names = self.descendants_of(&self.nodes[i].clone());
286            let desc_idxs: HashSet<usize> = desc_names
287                .iter()
288                .filter_map(|n| self.node_index(n))
289                .collect();
290            desc_map.insert(i, desc_idxs);
291        }
292
293        let is_in_adj_or_desc = |node: usize| -> bool {
294            if adjustment_set.contains(&node) {
295                return true;
296            }
297            if let Some(descs) = desc_map.get(&node) {
298                return descs.iter().any(|d| adjustment_set.contains(d));
299            }
300            false
301        };
302
303        // State: (current_node, prev_node, arrived_via_up: bool)
304        let mut visited: HashSet<(usize, usize, bool)> = HashSet::new();
305        let mut queue: VecDeque<(usize, usize, bool)> = VecDeque::new();
306
307        // Only start on parents of src (backdoor = entering src from above)
308        for &(p, c) in &self.edges {
309            if c == src {
310                // parent p of src: going up (from src to p)
311                // The first step is upward. p is a non-collider relative to src→p.
312                // Block if p is in adjustment_set
313                if !adjustment_set.contains(&p) {
314                    queue.push_back((p, src, true));
315                }
316            }
317        }
318
319        while let Some((cur, prev, going_up)) = queue.pop_front() {
320            if !visited.insert((cur, prev, going_up)) {
321                continue;
322            }
323            if cur == dst {
324                return true;
325            }
326
327            // Explore neighbors
328            // Build set of parents and children of cur
329            let parents: Vec<usize> = self
330                .edges
331                .iter()
332                .filter(|&&(_, c)| c == cur)
333                .map(|&(p, _)| p)
334                .collect();
335            let children: Vec<usize> = self
336                .edges
337                .iter()
338                .filter(|&&(p, _)| p == cur)
339                .map(|&(_, c)| c)
340                .collect();
341
342            for &next in parents.iter().chain(children.iter()) {
343                if next == prev {
344                    continue;
345                }
346                // Determine if cur is a collider on the segment prev→cur→next
347                // cur is a collider iff both prev and next are parents of cur
348                let prev_is_parent_of_cur = parents.contains(&prev);
349                let next_is_parent_of_cur = parents.contains(&next);
350                let is_collider = prev_is_parent_of_cur && next_is_parent_of_cur;
351
352                let blocked = if is_collider {
353                    // Collider: blocked unless cur or its descendant is in adjustment set
354                    !is_in_adj_or_desc(cur)
355                } else {
356                    // Non-collider: blocked if cur is in adjustment set
357                    adjustment_set.contains(&cur)
358                };
359
360                if !blocked {
361                    let next_going_up = parents.contains(&next);
362                    let state = (next, cur, next_going_up);
363                    if !visited.contains(&state) {
364                        queue.push_back(state);
365                    }
366                }
367            }
368        }
369
370        false
371    }
372
373    /// Check whether there is a directed path from `src` to `dst`.
374    pub fn has_directed_path(&self, src: &str, dst: &str) -> bool {
375        let src_idx = match self.node_index(src) {
376            Some(i) => i,
377            None => return false,
378        };
379        let dst_idx = match self.node_index(dst) {
380            Some(i) => i,
381            None => return false,
382        };
383        let mut visited = HashSet::new();
384        let mut queue = VecDeque::new();
385        queue.push_back(src_idx);
386        while let Some(cur) = queue.pop_front() {
387            if cur == dst_idx {
388                return true;
389            }
390            if !visited.insert(cur) {
391                continue;
392            }
393            for &(p, c) in &self.edges {
394                if p == cur && !visited.contains(&c) {
395                    queue.push_back(c);
396                }
397            }
398        }
399        false
400    }
401}