Skip to main content

tensorlogic_compiler/passes/
reachability.rs

1//! Reachability and dominance analysis for einsum graphs.
2//!
3//! This module provides graph analysis passes that compute:
4//! - Reachability: Which nodes can reach which other nodes
5//! - Dominance: Which nodes dominate others in the control flow
6//! - Post-dominance: Reverse dominance for optimization
7//!
8//! # Overview
9//!
10//! Reachability and dominance analysis are fundamental for understanding
11//! the structure and dependencies in computation graphs. They enable:
12//! - Dead code elimination
13//! - Loop optimization
14//! - Code motion and hoisting
15//! - Critical path analysis
16//!
17//! # Examples
18//!
19//! ```rust
20//! use tensorlogic_compiler::passes::analyze_reachability;
21//! use tensorlogic_ir::EinsumGraph;
22//!
23//! let graph = EinsumGraph::new();
24//! let analysis = analyze_reachability(&graph);
25//! ```
26
27use std::collections::{HashMap, HashSet, VecDeque};
28use tensorlogic_ir::EinsumGraph;
29
/// Result of reachability analysis over an einsum graph.
///
/// Node identifiers are indices into the graph's node list. All maps are sparse: a node
/// without an entry simply has no recorded information.
#[derive(Debug, Clone, Default)]
pub struct ReachabilityAnalysis {
    /// For each node, the set of nodes that can be reached from it.
    pub reachable_from: HashMap<usize, HashSet<usize>>,
    /// For each node, the set of nodes that can reach it.
    pub can_reach: HashMap<usize, HashSet<usize>>,
    /// Strongly connected components, one set of node ids per component.
    pub sccs: Vec<HashSet<usize>>,
    /// Topological ordering of the nodes; `None` when no ordering was recorded
    /// (e.g. because the graph is cyclic).
    pub topo_order: Option<Vec<usize>>,
}

impl ReachabilityAnalysis {
    /// Create an empty analysis with no reachability information recorded.
    pub fn new() -> Self {
        Self::default()
    }

    /// Return `true` when `to` is in the reachable set recorded for `from`.
    ///
    /// A node with no recorded entry reaches nothing.
    pub fn is_reachable(&self, from: usize, to: usize) -> bool {
        matches!(self.reachable_from.get(&from), Some(targets) if targets.contains(&to))
    }

    /// Return an owned copy of the nodes reachable from `from` (empty when unknown).
    pub fn get_reachable(&self, from: usize) -> HashSet<usize> {
        match self.reachable_from.get(&from) {
            Some(targets) => targets.clone(),
            None => HashSet::new(),
        }
    }

    /// Return an owned copy of the nodes that can reach `to` (empty when unknown).
    pub fn get_predecessors(&self, to: usize) -> HashSet<usize> {
        self.can_reach
            .get(&to)
            .map_or_else(HashSet::new, Clone::clone)
    }

    /// Return `true` when a topological order was recorded, i.e. the graph is a DAG.
    pub fn is_dag(&self) -> bool {
        matches!(self.topo_order, Some(_))
    }

    /// Borrow the recorded topological order, if any.
    pub fn get_topo_order(&self) -> Option<&[usize]> {
        self.topo_order.as_ref().map(Vec::as_slice)
    }

    /// Return the first recorded strongly connected component containing `node`.
    pub fn get_scc(&self, node: usize) -> Option<&HashSet<usize>> {
        for component in &self.sccs {
            if component.contains(&node) {
                return Some(component);
            }
        }
        None
    }
}
93
/// Dominance analysis result.
///
/// Node identifiers are indices into the graph's node list. All maps are sparse: a node
/// without an entry simply has no recorded information.
#[derive(Debug, Clone, Default)]
pub struct DominanceAnalysis {
    /// Immediate dominator of each node. Nodes without an entry have no recorded
    /// immediate dominator.
    pub idom: HashMap<usize, usize>,
    /// Dominance frontier of each node.
    pub dominance_frontier: HashMap<usize, HashSet<usize>>,
    /// Post-dominator set of each node.
    pub post_dominators: HashMap<usize, HashSet<usize>>,
}

impl DominanceAnalysis {
    /// Create an empty dominance analysis.
    pub fn new() -> Self {
        Self::default()
    }

    /// Get the immediate dominator of `node`, if one is recorded.
    pub fn get_idom(&self, node: usize) -> Option<usize> {
        self.idom.get(&node).copied()
    }

    /// Check whether `dom` appears on the immediate-dominator chain above `node`.
    ///
    /// The walk starts at `idom(node)`, so `dominates(n, n)` is `false` unless `n` is
    /// recorded as its own immediate dominator. It stops at a node with no recorded
    /// dominator, or at one recorded as its own dominator.
    ///
    /// The walk is also bounded by `idom.len()` steps. Because `idom` is a public field it
    /// may contain a cycle longer than one node; without the bound such a map would make
    /// this loop forever.
    pub fn dominates(&self, dom: usize, node: usize) -> bool {
        let mut current = node;
        // Each step consumes one distinct `idom` entry in a well-formed tree, so
        // `idom.len()` steps are always enough to reach the top of the chain.
        for _ in 0..self.idom.len() {
            match self.get_idom(current) {
                Some(parent) if parent == dom => return true,
                Some(parent) if parent != current => current = parent,
                // No dominator recorded, or a self-referential entry: nothing above.
                _ => return false,
            }
        }
        false
    }

    /// Get an owned copy of the dominance frontier of `node` (empty when unknown).
    pub fn get_frontier(&self, node: usize) -> HashSet<usize> {
        self.dominance_frontier
            .get(&node)
            .cloned()
            .unwrap_or_default()
    }

    /// Get an owned copy of the post-dominators of `node` (empty when unknown).
    pub fn get_post_dominators(&self, node: usize) -> HashSet<usize> {
        self.post_dominators.get(&node).cloned().unwrap_or_default()
    }
}
154
155/// Analyze reachability in an einsum graph.
156pub fn analyze_reachability(graph: &EinsumGraph) -> ReachabilityAnalysis {
157    let mut analysis = ReachabilityAnalysis::new();
158
159    // Build adjacency list
160    let adj = build_adjacency_list(graph);
161
162    // Compute reachability using BFS from each node
163    for node in 0..graph.nodes.len() {
164        let reachable = bfs_reachable(&adj, node);
165        analysis.reachable_from.insert(node, reachable);
166    }
167
168    // Compute reverse reachability
169    let rev_adj = build_reverse_adjacency(graph);
170    for node in 0..graph.nodes.len() {
171        let can_reach = bfs_reachable(&rev_adj, node);
172        analysis.can_reach.insert(node, can_reach);
173    }
174
175    // Compute strongly connected components
176    analysis.sccs = tarjan_scc(&adj);
177
178    // Try to compute topological order
179    analysis.topo_order = compute_topo_order(graph);
180
181    analysis
182}
183
184/// Analyze dominance in an einsum graph.
185pub fn analyze_dominance(graph: &EinsumGraph) -> DominanceAnalysis {
186    let mut analysis = DominanceAnalysis::new();
187
188    if graph.nodes.is_empty() {
189        return analysis;
190    }
191
192    // Build adjacency list
193    let adj = build_adjacency_list(graph);
194
195    // Compute immediate dominators using Lengauer-Tarjan algorithm
196    compute_idom(&adj, &mut analysis);
197
198    // Compute dominance frontiers
199    let idom_clone = analysis.idom.clone();
200    compute_dominance_frontiers(&adj, &idom_clone, &mut analysis);
201
202    // Compute post-dominators
203    let rev_adj = build_reverse_adjacency(graph);
204    compute_post_dominators(&rev_adj, &mut analysis);
205
206    analysis
207}
208
209/// Build adjacency list from graph.
210fn build_adjacency_list(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
211    let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
212
213    for (node_idx, node) in graph.nodes.iter().enumerate() {
214        // Find nodes that consume our outputs
215        for other_idx in 0..graph.nodes.len() {
216            if other_idx == node_idx {
217                continue;
218            }
219
220            let other = &graph.nodes[other_idx];
221            // Check if other node uses any of our outputs
222            if node.outputs.iter().any(|&out| other.inputs.contains(&out)) {
223                adj.entry(node_idx).or_default().push(other_idx);
224            }
225        }
226    }
227
228    adj
229}
230
231/// Build reverse adjacency list.
232fn build_reverse_adjacency(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
233    let adj = build_adjacency_list(graph);
234    let mut rev_adj: HashMap<usize, Vec<usize>> = HashMap::new();
235
236    for (from, neighbors) in adj {
237        for to in neighbors {
238            rev_adj.entry(to).or_default().push(from);
239        }
240    }
241
242    rev_adj
243}
244
/// Collect every node reachable from `start` by following `adj` edges, `start` included.
///
/// Traversal is breadth-first. Nodes absent from `adj` are treated as having no successors.
fn bfs_reachable(adj: &HashMap<usize, Vec<usize>>, start: usize) -> HashSet<usize> {
    let mut seen: HashSet<usize> = std::iter::once(start).collect();
    let mut frontier: VecDeque<usize> = std::iter::once(start).collect();

    while let Some(current) = frontier.pop_front() {
        let successors = match adj.get(&current) {
            Some(list) => list,
            None => continue,
        };
        for &next in successors {
            if !seen.contains(&next) {
                seen.insert(next);
                frontier.push_back(next);
            }
        }
    }

    seen
}
264
/// Tarjan's algorithm for finding strongly connected components.
///
/// Every node that appears in `adj`, as a key or as a neighbour, is considered. Nodes absent
/// from `adj` entirely are not reported.
///
/// Start nodes are visited in ascending order, so for a given `adj` the order of the
/// returned components is deterministic and does not depend on hash-map iteration order.
/// As usual for Tarjan's algorithm, a component is emitted only after every component
/// reachable from it.
fn tarjan_scc(adj: &HashMap<usize, Vec<usize>>) -> Vec<HashSet<usize>> {
    let mut sccs = Vec::new();
    let mut index = 0;
    let mut stack = Vec::new();
    let mut indices: HashMap<usize, usize> = HashMap::new();
    let mut lowlinks: HashMap<usize, usize> = HashMap::new();
    let mut on_stack: HashSet<usize> = HashSet::new();

    // Every node mentioned in the graph, sorted, so traversal order (and therefore output
    // order) is reproducible.
    let mut nodes: Vec<usize> = adj.keys().chain(adj.values().flatten()).copied().collect();
    nodes.sort_unstable();
    nodes.dedup();

    for node in nodes {
        if !indices.contains_key(&node) {
            strongconnect(
                node,
                adj,
                &mut index,
                &mut stack,
                &mut indices,
                &mut lowlinks,
                &mut on_stack,
                &mut sccs,
            );
        }
    }

    sccs
}

/// Recursive step of Tarjan's algorithm.
///
/// Visits `v`, explores its successors, and pushes the component rooted at `v` onto `sccs`
/// if `v` turns out to be a component root (its lowlink equals its own index).
///
/// Recursion depth grows with the longest DFS path, so extremely deep graphs could exhaust
/// the thread's stack.
// The DFS state is threaded through explicit parameters rather than a struct, hence the
// large argument list.
#[allow(clippy::too_many_arguments)]
fn strongconnect(
    v: usize,
    adj: &HashMap<usize, Vec<usize>>,
    index: &mut usize,
    stack: &mut Vec<usize>,
    indices: &mut HashMap<usize, usize>,
    lowlinks: &mut HashMap<usize, usize>,
    on_stack: &mut HashSet<usize>,
    sccs: &mut Vec<HashSet<usize>>,
) {
    // Assign `v` the next DFS index; its lowlink starts at the same value.
    indices.insert(v, *index);
    lowlinks.insert(v, *index);
    *index += 1;
    stack.push(v);
    on_stack.insert(v);

    if let Some(neighbors) = adj.get(&v) {
        for &w in neighbors {
            if !indices.contains_key(&w) {
                // Tree edge: recurse, then inherit the child's lowlink.
                strongconnect(w, adj, index, stack, indices, lowlinks, on_stack, sccs);
                let w_lowlink = lowlinks[&w];
                let v_lowlink = lowlinks
                    .get_mut(&v)
                    .expect("v was assigned a lowlink on entry");
                *v_lowlink = (*v_lowlink).min(w_lowlink);
            } else if on_stack.contains(&w) {
                // Edge to a node on the stack (part of the current component).
                let w_index = indices[&w];
                let v_lowlink = lowlinks
                    .get_mut(&v)
                    .expect("v was assigned a lowlink on entry");
                *v_lowlink = (*v_lowlink).min(w_index);
            }
        }
    }

    // `v` is a component root: pop the stack down to and including `v`.
    if lowlinks.get(&v) == indices.get(&v) {
        let mut scc = HashSet::new();
        loop {
            let w = stack.pop().expect("v is still on the stack");
            on_stack.remove(&w);
            scc.insert(w);
            if w == v {
                break;
            }
        }
        sccs.push(scc);
    }
}
343
344/// Compute topological ordering using Kahn's algorithm.
345fn compute_topo_order(graph: &EinsumGraph) -> Option<Vec<usize>> {
346    let adj = build_adjacency_list(graph);
347    let mut in_degree: HashMap<usize, usize> = HashMap::new();
348
349    // Initialize in-degrees
350    for i in 0..graph.nodes.len() {
351        in_degree.insert(i, 0);
352    }
353
354    for neighbors in adj.values() {
355        for &neighbor in neighbors {
356            *in_degree.entry(neighbor).or_insert(0) += 1;
357        }
358    }
359
360    // Queue of nodes with zero in-degree
361    let mut queue: VecDeque<usize> = in_degree
362        .iter()
363        .filter(|(_, &deg)| deg == 0)
364        .map(|(&node, _)| node)
365        .collect();
366
367    let mut order = Vec::new();
368
369    while let Some(node) = queue.pop_front() {
370        order.push(node);
371
372        if let Some(neighbors) = adj.get(&node) {
373            for &neighbor in neighbors {
374                let deg = in_degree.get_mut(&neighbor).unwrap();
375                *deg -= 1;
376                if *deg == 0 {
377                    queue.push_back(neighbor);
378                }
379            }
380        }
381    }
382
383    if order.len() == graph.nodes.len() {
384        Some(order)
385    } else {
386        None // Graph has cycles
387    }
388}
389
390/// Compute immediate dominators using simplified algorithm.
391fn compute_idom(adj: &HashMap<usize, Vec<usize>>, analysis: &mut DominanceAnalysis) {
392    // Simplified dominator computation
393    // In a real compiler, we'd use Lengauer-Tarjan or Cooper-Harvey-Kennedy
394
395    // For now, just mark entry node as dominating everything
396    if let Some(&entry) = adj.keys().next() {
397        for &node in adj.keys() {
398            if node != entry {
399                analysis.idom.insert(node, entry);
400            }
401        }
402    }
403}
404
405/// Compute dominance frontiers.
406fn compute_dominance_frontiers(
407    _adj: &HashMap<usize, Vec<usize>>,
408    _idom: &HashMap<usize, usize>,
409    analysis: &mut DominanceAnalysis,
410) {
411    // Simplified implementation
412    // Real implementation would compute actual frontiers
413
414    for &node in _idom.keys() {
415        analysis.dominance_frontier.insert(node, HashSet::new());
416    }
417}
418
419/// Compute post-dominators.
420fn compute_post_dominators(
421    _rev_adj: &HashMap<usize, Vec<usize>>,
422    analysis: &mut DominanceAnalysis,
423) {
424    // Simplified implementation
425    for &node in _rev_adj.keys() {
426        analysis.post_dominators.insert(node, HashSet::new());
427    }
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    /// Graph with tensors `t0` (id 0) and `t1` (id 1) and no nodes.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let _t0 = graph.add_tensor("t0");
        let _t1 = graph.add_tensor("t1");
        graph
    }

    /// Two-node chain `t0 -exp-> t1 -log-> t2`.
    ///
    /// Returns the graph and the indices of the `exp` and `log` nodes.
    fn create_chain_graph() -> (EinsumGraph, usize, usize) {
        let mut graph = create_test_graph();
        let _t2 = graph.add_tensor("t2");
        let n0 = graph
            .add_node(EinsumNode::elem_unary("exp", 0, 1))
            .unwrap();
        let n1 = graph
            .add_node(EinsumNode::elem_unary("log", 1, 2))
            .unwrap();
        (graph, n0, n1)
    }

    #[test]
    fn test_reachability_empty_graph() {
        let graph = EinsumGraph::new();
        let analysis = analyze_reachability(&graph);
        assert!(analysis.reachable_from.is_empty());
        assert!(analysis.can_reach.is_empty());
        assert!(analysis.is_dag());
    }

    #[test]
    fn test_reachability_single_node() {
        let mut graph = create_test_graph();
        let n0 = graph
            .add_node(EinsumNode::elem_unary("exp", 0, 1))
            .unwrap();

        let analysis = analyze_reachability(&graph);
        assert_eq!(analysis.reachable_from.len(), 1);
        // A node trivially reaches itself; with no other nodes it reaches nothing else.
        assert!(analysis.is_reachable(n0, n0));
        assert_eq!(analysis.get_reachable(n0).len(), 1);
        assert_eq!(analysis.get_topo_order(), Some(&[n0][..]));
    }

    #[test]
    fn test_dominance_empty_graph() {
        let graph = EinsumGraph::new();
        let analysis = analyze_dominance(&graph);
        assert!(analysis.idom.is_empty());
        assert!(analysis.dominance_frontier.is_empty());
        assert!(analysis.post_dominators.is_empty());
    }

    #[test]
    fn test_is_dag() {
        let graph = create_test_graph();
        let analysis = analyze_reachability(&graph);

        // A graph without nodes is trivially acyclic with an empty topological order.
        assert!(analysis.is_dag());
        assert_eq!(analysis.get_topo_order().map(|order| order.len()), Some(0));
    }

    #[test]
    fn test_dominates() {
        let graph = create_test_graph();
        let analysis = analyze_dominance(&graph);
        // No nodes means no dominator information at all.
        assert!(!analysis.dominates(0, 1));

        // Hand-built dominator chain 0 -> 1 -> 2.
        let mut manual = DominanceAnalysis::new();
        manual.idom.insert(1, 0);
        manual.idom.insert(2, 1);
        assert!(manual.dominates(0, 2));
        assert!(manual.dominates(1, 2));
        assert!(!manual.dominates(2, 0));
    }

    #[test]
    fn test_build_adjacency() {
        let (graph, n0, n1) = create_chain_graph();

        let adj = build_adjacency_list(&graph);
        assert_eq!(adj.get(&n0), Some(&vec![n1]));
        // The last node's output is not consumed, so it has no successors.
        assert!(!adj.contains_key(&n1));
    }

    #[test]
    fn test_scc_computation() {
        let mut adj = HashMap::new();
        adj.insert(0, vec![1]);
        adj.insert(1, vec![2]);
        adj.insert(2, vec![0]);

        let sccs = tarjan_scc(&adj);
        let expected: HashSet<usize> = (0..3).collect();
        assert_eq!(sccs, vec![expected]);
    }

    #[test]
    fn test_topo_order() {
        let (graph, n0, n1) = create_chain_graph();
        assert_eq!(compute_topo_order(&graph), Some(vec![n0, n1]));
    }

    #[test]
    fn test_reachability_chain() {
        let (graph, n0, n1) = create_chain_graph();

        let analysis = analyze_reachability(&graph);
        assert!(analysis.is_reachable(n0, n1));
        assert!(!analysis.is_reachable(n1, n0));
        assert!(analysis.get_predecessors(n1).contains(&n0));
        assert!(!analysis.get_predecessors(n0).contains(&n1));
    }

    #[test]
    fn test_get_predecessors() {
        let graph = create_test_graph();
        let analysis = analyze_reachability(&graph);

        // No nodes, so nothing is recorded for any index.
        assert!(analysis.get_predecessors(0).is_empty());
    }
}
562}