Skip to main content

tensorlogic_compiler/passes/
reachability.rs

//! Reachability and dominance analysis for einsum graphs.
//!
//! This module provides graph analysis passes that compute:
//! - Reachability: which nodes can reach which other nodes
//! - Dominance: which nodes dominate others in the dataflow graph
//! - Post-dominance: the reverse-graph counterpart of dominance
//!
//! # Overview
//!
//! Reachability and dominance analysis are fundamental for understanding
//! the structure and dependencies in computation graphs. They enable:
//! - Dead code elimination
//! - Loop optimization
//! - Code motion and hoisting
//! - Critical path analysis
//!
//! # Examples
//!
//! ```rust
//! use tensorlogic_compiler::passes::analyze_reachability;
//! use tensorlogic_ir::EinsumGraph;
//!
//! let graph = EinsumGraph::new();
//! let analysis = analyze_reachability(&graph);
//! ```

27use std::collections::{HashMap, HashSet, VecDeque};
28use tensorlogic_ir::EinsumGraph;
29
/// Outcome of [`analyze_reachability`]: per-node reachability sets, strongly
/// connected components, and (when one was found) a topological order.
#[derive(Debug, Clone)]
pub struct ReachabilityAnalysis {
    /// For each node, the set of nodes reachable from it.
    pub reachable_from: HashMap<usize, HashSet<usize>>,
    /// For each node, the set of nodes that can reach it.
    pub can_reach: HashMap<usize, HashSet<usize>>,
    /// Strongly connected components, each given as a set of node indices.
    pub sccs: Vec<HashSet<usize>>,
    /// A topological ordering of the nodes, or `None` when none was found.
    pub topo_order: Option<Vec<usize>>,
}

impl ReachabilityAnalysis {
    /// Build an analysis whose tables are all empty and which has no order.
    pub fn new() -> Self {
        ReachabilityAnalysis {
            reachable_from: HashMap::default(),
            can_reach: HashMap::default(),
            sccs: Vec::default(),
            topo_order: None,
        }
    }

    /// Return `true` when `to` is recorded in the reachable set of `from`.
    ///
    /// A node with no recorded set reaches nothing.
    pub fn is_reachable(&self, from: usize, to: usize) -> bool {
        self.reachable_from
            .get(&from)
            .map_or(false, |targets| targets.contains(&to))
    }

    /// Return an owned copy of the reachable set of `from` (empty if unknown).
    pub fn get_reachable(&self, from: usize) -> HashSet<usize> {
        match self.reachable_from.get(&from) {
            Some(targets) => targets.clone(),
            None => HashSet::new(),
        }
    }

    /// Return an owned copy of the set of nodes that reach `to` (empty if unknown).
    pub fn get_predecessors(&self, to: usize) -> HashSet<usize> {
        match self.can_reach.get(&to) {
            Some(sources) => sources.clone(),
            None => HashSet::new(),
        }
    }

    /// Return `true` exactly when a topological order was recorded.
    pub fn is_dag(&self) -> bool {
        matches!(self.topo_order, Some(_))
    }

    /// Borrow the recorded topological order as a slice, if there is one.
    pub fn get_topo_order(&self) -> Option<&[usize]> {
        self.topo_order.as_deref()
    }

    /// Find the first recorded strongly connected component that holds `node`.
    pub fn get_scc(&self, node: usize) -> Option<&HashSet<usize>> {
        self.sccs.iter().find(|component| component.contains(&node))
    }
}

impl Default for ReachabilityAnalysis {
    fn default() -> Self {
        ReachabilityAnalysis::new()
    }
}
93
/// Dominance information produced by [`analyze_dominance`].
#[derive(Debug, Clone)]
pub struct DominanceAnalysis {
    /// Immediate dominator of each node. A node without an entry has no
    /// recorded immediate dominator.
    pub idom: HashMap<usize, usize>,
    /// Dominance frontier of each node.
    pub dominance_frontier: HashMap<usize, HashSet<usize>>,
    /// Post-dominator set of each node.
    pub post_dominators: HashMap<usize, HashSet<usize>>,
}

impl DominanceAnalysis {
    /// Create a new, empty dominance analysis.
    pub fn new() -> Self {
        Self {
            idom: HashMap::new(),
            dominance_frontier: HashMap::new(),
            post_dominators: HashMap::new(),
        }
    }

    /// Get the recorded immediate dominator of `node`, if any.
    pub fn get_idom(&self, node: usize) -> Option<usize> {
        self.idom.get(&node).copied()
    }

    /// Check whether `dom` dominates `node` by walking the immediate-dominator
    /// chain upward from `node`.
    ///
    /// Returns `true` if `dom` appears on that chain. This is strict dominance:
    /// `dominates(x, x)` is `false` unless the map records `x` as its own
    /// immediate dominator.
    pub fn dominates(&self, dom: usize, node: usize) -> bool {
        let mut current = node;
        // A well-formed idom map is a forest, so no chain has more than
        // `idom.len()` links. Bounding the walk keeps a malformed map (the field
        // is public, e.g. a 2-cycle `1 -> 2 -> 1`) from looping forever.
        for _ in 0..self.idom.len() {
            match self.get_idom(current) {
                Some(parent) if parent == dom => return true,
                // Continue upward unless the node is its own idom (a self-cycle).
                Some(parent) if parent != current => current = parent,
                _ => return false,
            }
        }
        false
    }

    /// Get the dominance frontier of `node` (empty if none was recorded).
    pub fn get_frontier(&self, node: usize) -> HashSet<usize> {
        self.dominance_frontier
            .get(&node)
            .cloned()
            .unwrap_or_default()
    }

    /// Get the post-dominators of `node` (empty if none were recorded).
    pub fn get_post_dominators(&self, node: usize) -> HashSet<usize> {
        self.post_dominators.get(&node).cloned().unwrap_or_default()
    }
}

impl Default for DominanceAnalysis {
    fn default() -> Self {
        Self::new()
    }
}
154
155/// Analyze reachability in an einsum graph.
156pub fn analyze_reachability(graph: &EinsumGraph) -> ReachabilityAnalysis {
157    let mut analysis = ReachabilityAnalysis::new();
158
159    // Build adjacency list
160    let adj = build_adjacency_list(graph);
161
162    // Compute reachability using BFS from each node
163    for node in 0..graph.nodes.len() {
164        let reachable = bfs_reachable(&adj, node);
165        analysis.reachable_from.insert(node, reachable);
166    }
167
168    // Compute reverse reachability
169    let rev_adj = build_reverse_adjacency(graph);
170    for node in 0..graph.nodes.len() {
171        let can_reach = bfs_reachable(&rev_adj, node);
172        analysis.can_reach.insert(node, can_reach);
173    }
174
175    // Compute strongly connected components
176    analysis.sccs = tarjan_scc(&adj);
177
178    // Try to compute topological order
179    analysis.topo_order = compute_topo_order(graph);
180
181    analysis
182}
183
184/// Analyze dominance in an einsum graph.
185pub fn analyze_dominance(graph: &EinsumGraph) -> DominanceAnalysis {
186    let mut analysis = DominanceAnalysis::new();
187
188    if graph.nodes.is_empty() {
189        return analysis;
190    }
191
192    // Build adjacency list
193    let adj = build_adjacency_list(graph);
194
195    // Compute immediate dominators using Lengauer-Tarjan algorithm
196    compute_idom(&adj, &mut analysis);
197
198    // Compute dominance frontiers
199    let idom_clone = analysis.idom.clone();
200    compute_dominance_frontiers(&adj, &idom_clone, &mut analysis);
201
202    // Compute post-dominators
203    let rev_adj = build_reverse_adjacency(graph);
204    compute_post_dominators(&rev_adj, &mut analysis);
205
206    analysis
207}
208
209/// Build adjacency list from graph.
210fn build_adjacency_list(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
211    let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
212
213    for (node_idx, node) in graph.nodes.iter().enumerate() {
214        // Find nodes that consume our outputs
215        for other_idx in 0..graph.nodes.len() {
216            if other_idx == node_idx {
217                continue;
218            }
219
220            let other = &graph.nodes[other_idx];
221            // Check if other node uses any of our outputs
222            if node.outputs.iter().any(|&out| other.inputs.contains(&out)) {
223                adj.entry(node_idx).or_default().push(other_idx);
224            }
225        }
226    }
227
228    adj
229}
230
231/// Build reverse adjacency list.
232fn build_reverse_adjacency(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
233    let adj = build_adjacency_list(graph);
234    let mut rev_adj: HashMap<usize, Vec<usize>> = HashMap::new();
235
236    for (from, neighbors) in adj {
237        for to in neighbors {
238            rev_adj.entry(to).or_default().push(from);
239        }
240    }
241
242    rev_adj
243}
244
/// Collect every node reachable from `start` by following `adj` breadth-first.
///
/// The result always contains `start` itself. A node with no entry in `adj` is
/// treated as having no successors.
fn bfs_reachable(adj: &HashMap<usize, Vec<usize>>, start: usize) -> HashSet<usize> {
    let mut seen: HashSet<usize> = std::iter::once(start).collect();
    let mut frontier: VecDeque<usize> = std::iter::once(start).collect();

    while let Some(current) = frontier.pop_front() {
        let successors = adj.get(&current).into_iter().flatten();
        for &next in successors {
            // `insert` returns true only the first time a node is seen.
            if seen.insert(next) {
                frontier.push_back(next);
            }
        }
    }

    seen
}
264
265/// Tarjan's algorithm for finding strongly connected components.
266fn tarjan_scc(adj: &HashMap<usize, Vec<usize>>) -> Vec<HashSet<usize>> {
267    let mut sccs = Vec::new();
268    let mut index = 0;
269    let mut stack = Vec::new();
270    let mut indices: HashMap<usize, usize> = HashMap::new();
271    let mut lowlinks: HashMap<usize, usize> = HashMap::new();
272    let mut on_stack: HashSet<usize> = HashSet::new();
273
274    // Get all nodes
275    let mut nodes: HashSet<usize> = adj.keys().copied().collect();
276    for neighbors in adj.values() {
277        nodes.extend(neighbors);
278    }
279
280    for &node in &nodes {
281        if !indices.contains_key(&node) {
282            strongconnect(
283                node,
284                adj,
285                &mut index,
286                &mut stack,
287                &mut indices,
288                &mut lowlinks,
289                &mut on_stack,
290                &mut sccs,
291            );
292        }
293    }
294
295    sccs
296}
297
/// Run one Tarjan depth-first search rooted at `v`, appending every strongly
/// connected component it completes to `sccs`.
///
/// The caller owns the shared state (`index`, `stack`, `indices`, `lowlinks`,
/// `on_stack`) and must carry it across calls for different roots. The search
/// keeps an explicit frame stack instead of recursing once per DFS level, so a
/// very long dependency chain cannot overflow the thread's call stack. SCCs are
/// emitted in the same order as the textbook recursive formulation.
#[allow(clippy::too_many_arguments)]
fn strongconnect(
    v: usize,
    adj: &HashMap<usize, Vec<usize>>,
    index: &mut usize,
    stack: &mut Vec<usize>,
    indices: &mut HashMap<usize, usize>,
    lowlinks: &mut HashMap<usize, usize>,
    on_stack: &mut HashSet<usize>,
    sccs: &mut Vec<HashSet<usize>>,
) {
    /// Assign the next DFS index to `node` and push it onto Tarjan's stack.
    fn enter(
        node: usize,
        index: &mut usize,
        stack: &mut Vec<usize>,
        indices: &mut HashMap<usize, usize>,
        lowlinks: &mut HashMap<usize, usize>,
        on_stack: &mut HashSet<usize>,
    ) {
        indices.insert(node, *index);
        lowlinks.insert(node, *index);
        *index += 1;
        stack.push(node);
        on_stack.insert(node);
    }

    enter(v, index, stack, indices, lowlinks, on_stack);
    // Each frame is (node, position of the next neighbor to examine).
    let mut frames: Vec<(usize, usize)> = vec![(v, 0)];

    while let Some(&(node, cursor)) = frames.last() {
        let neighbors = adj.get(&node).map(Vec::as_slice).unwrap_or(&[]);

        if let Some(&w) = neighbors.get(cursor) {
            // Advance this frame's cursor before (possibly) descending into `w`.
            if let Some(frame) = frames.last_mut() {
                frame.1 += 1;
            }
            if !indices.contains_key(&w) {
                // Equivalent of the recursive call `strongconnect(w, ...)`.
                enter(w, index, stack, indices, lowlinks, on_stack);
                frames.push((w, 0));
            } else if on_stack.contains(&w) {
                let w_index = *indices.get(&w).expect("w visited before, so in indices");
                let node_lowlink = lowlinks
                    .get_mut(&node)
                    .expect("node is on the DFS path, so in lowlinks");
                *node_lowlink = (*node_lowlink).min(w_index);
            }
            continue;
        }

        // All neighbors handled: this is where the recursive version returns.
        frames.pop();

        if lowlinks.get(&node) == indices.get(&node) {
            // `node` is the root of an SCC: pop its members off Tarjan's stack.
            let mut scc = HashSet::new();
            loop {
                let w = stack
                    .pop()
                    .expect("stack is non-empty while searching for SCC root");
                on_stack.remove(&w);
                scc.insert(w);
                if w == node {
                    break;
                }
            }
            sccs.push(scc);
        }

        // Propagate the finished node's lowlink to its caller, as the recursive
        // version does right after the call returns.
        if let Some(&(parent, _)) = frames.last() {
            let child_lowlink = *lowlinks.get(&node).expect("node visited, so in lowlinks");
            let parent_lowlink = lowlinks
                .get_mut(&parent)
                .expect("parent is on the DFS path, so in lowlinks");
            *parent_lowlink = (*parent_lowlink).min(child_lowlink);
        }
    }
}
349
350/// Compute topological ordering using Kahn's algorithm.
351fn compute_topo_order(graph: &EinsumGraph) -> Option<Vec<usize>> {
352    let adj = build_adjacency_list(graph);
353    let mut in_degree: HashMap<usize, usize> = HashMap::new();
354
355    // Initialize in-degrees
356    for i in 0..graph.nodes.len() {
357        in_degree.insert(i, 0);
358    }
359
360    for neighbors in adj.values() {
361        for &neighbor in neighbors {
362            *in_degree.entry(neighbor).or_insert(0) += 1;
363        }
364    }
365
366    // Queue of nodes with zero in-degree
367    let mut queue: VecDeque<usize> = in_degree
368        .iter()
369        .filter(|(_, &deg)| deg == 0)
370        .map(|(&node, _)| node)
371        .collect();
372
373    let mut order = Vec::new();
374
375    while let Some(node) = queue.pop_front() {
376        order.push(node);
377
378        if let Some(neighbors) = adj.get(&node) {
379            for &neighbor in neighbors {
380                let deg = in_degree
381                    .get_mut(&neighbor)
382                    .expect("neighbor was inserted during initialization");
383                *deg -= 1;
384                if *deg == 0 {
385                    queue.push_back(neighbor);
386                }
387            }
388        }
389    }
390
391    if order.len() == graph.nodes.len() {
392        Some(order)
393    } else {
394        None // Graph has cycles
395    }
396}
397
398/// Compute immediate dominators using simplified algorithm.
399fn compute_idom(adj: &HashMap<usize, Vec<usize>>, analysis: &mut DominanceAnalysis) {
400    // Simplified dominator computation
401    // In a real compiler, we'd use Lengauer-Tarjan or Cooper-Harvey-Kennedy
402
403    // For now, just mark entry node as dominating everything
404    if let Some(&entry) = adj.keys().next() {
405        for &node in adj.keys() {
406            if node != entry {
407                analysis.idom.insert(node, entry);
408            }
409        }
410    }
411}
412
413/// Compute dominance frontiers.
414fn compute_dominance_frontiers(
415    _adj: &HashMap<usize, Vec<usize>>,
416    _idom: &HashMap<usize, usize>,
417    analysis: &mut DominanceAnalysis,
418) {
419    // Simplified implementation
420    // Real implementation would compute actual frontiers
421
422    for &node in _idom.keys() {
423        analysis.dominance_frontier.insert(node, HashSet::new());
424    }
425}
426
427/// Compute post-dominators.
428fn compute_post_dominators(
429    _rev_adj: &HashMap<usize, Vec<usize>>,
430    analysis: &mut DominanceAnalysis,
431) {
432    // Simplified implementation
433    for &node in _rev_adj.keys() {
434        analysis.post_dominators.insert(node, HashSet::new());
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a graph holding tensors `t0` and `t1` (indices 0 and 1) and no nodes.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let _t0 = graph.add_tensor("t0");
        let _t1 = graph.add_tensor("t1");
        graph
    }

    /// Build the two-node chain `exp(t0) -> t1`, `log(t1) -> t2`.
    ///
    /// Returns the graph plus the indices of the `exp` and `log` nodes. Like the
    /// original tests, this assumes `elem_unary(op, input, output)` argument order.
    fn create_chain_graph() -> (EinsumGraph, usize, usize) {
        let mut graph = create_test_graph();
        let _t2 = graph.add_tensor("t2");
        let n0 = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", 0, 1))
            .expect("unwrap");
        let n1 = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("log", 1, 2))
            .expect("unwrap");
        (graph, n0, n1)
    }

    #[test]
    fn test_reachability_empty_graph() {
        let graph = EinsumGraph::new();
        let analysis = analyze_reachability(&graph);
        assert!(analysis.reachable_from.is_empty());
        assert!(analysis.can_reach.is_empty());
        assert!(analysis.sccs.is_empty());
        assert_eq!(analysis.topo_order, Some(Vec::new()));
    }

    #[test]
    fn test_reachability_single_node() {
        let mut graph = create_test_graph();
        graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", 0, 1))
            .expect("unwrap");

        let analysis = analyze_reachability(&graph);
        assert_eq!(analysis.reachable_from.len(), 1);
        assert_eq!(analysis.can_reach.len(), 1);
        assert_eq!(analysis.topo_order, Some(vec![0]));
    }

    #[test]
    fn test_dominance_empty_graph() {
        let graph = EinsumGraph::new();
        let analysis = analyze_dominance(&graph);
        assert!(analysis.idom.is_empty());
        assert!(analysis.dominance_frontier.is_empty());
        assert!(analysis.post_dominators.is_empty());
    }

    #[test]
    fn test_is_dag() {
        // A graph without nodes trivially has a (empty) topological order.
        let analysis = analyze_reachability(&create_test_graph());
        assert!(analysis.is_dag());

        let (chain, _, _) = create_chain_graph();
        assert!(analyze_reachability(&chain).is_dag());
    }

    #[test]
    fn test_dominates() {
        let graph = create_test_graph();
        let analysis = analyze_dominance(&graph);
        assert!(!analysis.dominates(0, 1));
    }

    #[test]
    fn test_build_adjacency() {
        let mut graph = create_test_graph();
        graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", 0, 1))
            .expect("unwrap");
        // A single node has nobody consuming its output.
        assert!(build_adjacency_list(&graph).is_empty());

        let (chain, n0, n1) = create_chain_graph();
        let adj = build_adjacency_list(&chain);
        assert_eq!(adj.len(), 1);
        assert_eq!(adj.get(&n0), Some(&vec![n1]));
    }

    #[test]
    fn test_scc_computation() {
        let mut adj = HashMap::new();
        adj.insert(0, vec![1]);
        adj.insert(1, vec![2]);
        adj.insert(2, vec![0]);

        let sccs = tarjan_scc(&adj);
        assert_eq!(sccs.len(), 1);
        assert_eq!(sccs[0].len(), 3);

        let mut chain = HashMap::new();
        chain.insert(0, vec![1]);
        assert_eq!(tarjan_scc(&chain).len(), 2);
    }

    #[test]
    fn test_topo_order() {
        let (graph, n0, n1) = create_chain_graph();
        assert_eq!(compute_topo_order(&graph), Some(vec![n0, n1]));
    }

    #[test]
    fn test_reachability_chain() {
        let (graph, n0, n1) = create_chain_graph();
        let analysis = analyze_reachability(&graph);

        assert!(analysis.is_reachable(n0, n1));
        assert!(!analysis.is_reachable(n1, n0));
        assert!(analysis.get_predecessors(n1).contains(&n0));
        assert!(!analysis.get_predecessors(n0).contains(&n1));
    }

    #[test]
    fn test_get_predecessors() {
        let graph = create_test_graph();
        let analysis = analyze_reachability(&graph);

        assert!(analysis.get_predecessors(0).is_empty());
    }
}
570}