// tensorlogic_compiler/passes/reachability.rs
use std::collections::{HashMap, HashSet, VecDeque};
use tensorlogic_ir::EinsumGraph;
29
/// Reachability facts about a graph whose nodes are identified by index.
///
/// A freshly constructed value is empty; `analyze_reachability` fills it in.
#[derive(Debug, Clone, Default)]
pub struct ReachabilityAnalysis {
    /// Maps a node to the set of nodes reachable from it.
    pub reachable_from: HashMap<usize, HashSet<usize>>,
    /// Maps a node to the set of nodes that can reach it.
    pub can_reach: HashMap<usize, HashSet<usize>>,
    /// Strongly connected components, each stored as a set of node indices.
    pub sccs: Vec<HashSet<usize>>,
    /// A topological order of the nodes, or `None` when none was recorded
    /// (which is how a cyclic graph is represented).
    pub topo_order: Option<Vec<usize>>,
}

impl ReachabilityAnalysis {
    /// Creates an empty analysis: no reachability sets, no components and no
    /// topological order.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` when `to` is recorded as reachable from `from`.
    ///
    /// A `from` node with no recorded set reaches nothing.
    pub fn is_reachable(&self, from: usize, to: usize) -> bool {
        matches!(self.reachable_from.get(&from), Some(targets) if targets.contains(&to))
    }

    /// Returns a copy of the set of nodes reachable from `from`; the set is
    /// empty when nothing is recorded for `from`.
    pub fn get_reachable(&self, from: usize) -> HashSet<usize> {
        match self.reachable_from.get(&from) {
            Some(targets) => targets.clone(),
            None => HashSet::new(),
        }
    }

    /// Returns a copy of the set of nodes that can reach `to`; the set is
    /// empty when nothing is recorded for `to`.
    pub fn get_predecessors(&self, to: usize) -> HashSet<usize> {
        match self.can_reach.get(&to) {
            Some(sources) => sources.clone(),
            None => HashSet::new(),
        }
    }

    /// Returns `true` when a topological order is available.
    pub fn is_dag(&self) -> bool {
        self.get_topo_order().is_some()
    }

    /// Borrows the topological order, if one is available.
    pub fn get_topo_order(&self) -> Option<&[usize]> {
        self.topo_order.as_deref()
    }

    /// Returns the first recorded strongly connected component containing
    /// `node`, if any.
    pub fn get_scc(&self, node: usize) -> Option<&HashSet<usize>> {
        self.sccs
            .iter()
            .find(|component| component.contains(&node))
    }
}
93
/// Dominance facts about a graph whose nodes are identified by index.
///
/// A freshly constructed value is empty; `analyze_dominance` fills it in.
#[derive(Debug, Clone, Default)]
pub struct DominanceAnalysis {
    /// Maps a node to its immediate dominator.
    pub idom: HashMap<usize, usize>,
    /// Maps a node to its dominance frontier.
    pub dominance_frontier: HashMap<usize, HashSet<usize>>,
    /// Maps a node to its set of post-dominators.
    pub post_dominators: HashMap<usize, HashSet<usize>>,
}

impl DominanceAnalysis {
    /// Creates an empty analysis.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the immediate dominator recorded for `node`, if any.
    pub fn get_idom(&self, node: usize) -> Option<usize> {
        self.idom.get(&node).copied()
    }

    /// Returns `true` when `dom` appears on the immediate-dominator chain
    /// above `node`.
    ///
    /// The walk starts at `node`'s immediate dominator, so a node is not
    /// reported as dominating itself unless the map says so (for example a
    /// root recorded as its own immediate dominator).
    ///
    /// `idom` is a public field and may therefore contain a cycle. The walk
    /// remembers the nodes it has visited and stops as soon as one repeats,
    /// so it always terminates.
    pub fn dominates(&self, dom: usize, node: usize) -> bool {
        let mut visited = HashSet::new();
        let mut current = node;
        // `insert` returns false once `current` repeats, i.e. the chain has
        // entered a cycle (a self-loop is the shortest such cycle).
        while visited.insert(current) {
            match self.get_idom(current) {
                Some(parent) if parent == dom => return true,
                Some(parent) => current = parent,
                None => break,
            }
        }
        false
    }

    /// Returns a copy of the dominance frontier of `node`; empty when nothing
    /// is recorded for it.
    pub fn get_frontier(&self, node: usize) -> HashSet<usize> {
        self.dominance_frontier
            .get(&node)
            .cloned()
            .unwrap_or_default()
    }

    /// Returns a copy of the post-dominator set of `node`; empty when nothing
    /// is recorded for it.
    pub fn get_post_dominators(&self, node: usize) -> HashSet<usize> {
        self.post_dominators.get(&node).cloned().unwrap_or_default()
    }
}
154
155pub fn analyze_reachability(graph: &EinsumGraph) -> ReachabilityAnalysis {
157 let mut analysis = ReachabilityAnalysis::new();
158
159 let adj = build_adjacency_list(graph);
161
162 for node in 0..graph.nodes.len() {
164 let reachable = bfs_reachable(&adj, node);
165 analysis.reachable_from.insert(node, reachable);
166 }
167
168 let rev_adj = build_reverse_adjacency(graph);
170 for node in 0..graph.nodes.len() {
171 let can_reach = bfs_reachable(&rev_adj, node);
172 analysis.can_reach.insert(node, can_reach);
173 }
174
175 analysis.sccs = tarjan_scc(&adj);
177
178 analysis.topo_order = compute_topo_order(graph);
180
181 analysis
182}
183
184pub fn analyze_dominance(graph: &EinsumGraph) -> DominanceAnalysis {
186 let mut analysis = DominanceAnalysis::new();
187
188 if graph.nodes.is_empty() {
189 return analysis;
190 }
191
192 let adj = build_adjacency_list(graph);
194
195 compute_idom(&adj, &mut analysis);
197
198 let idom_clone = analysis.idom.clone();
200 compute_dominance_frontiers(&adj, &idom_clone, &mut analysis);
201
202 let rev_adj = build_reverse_adjacency(graph);
204 compute_post_dominators(&rev_adj, &mut analysis);
205
206 analysis
207}
208
209fn build_adjacency_list(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
211 let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
212
213 for (node_idx, node) in graph.nodes.iter().enumerate() {
214 for other_idx in 0..graph.nodes.len() {
216 if other_idx == node_idx {
217 continue;
218 }
219
220 let other = &graph.nodes[other_idx];
221 if node.outputs.iter().any(|&out| other.inputs.contains(&out)) {
223 adj.entry(node_idx).or_default().push(other_idx);
224 }
225 }
226 }
227
228 adj
229}
230
231fn build_reverse_adjacency(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
233 let adj = build_adjacency_list(graph);
234 let mut rev_adj: HashMap<usize, Vec<usize>> = HashMap::new();
235
236 for (from, neighbors) in adj {
237 for to in neighbors {
238 rev_adj.entry(to).or_default().push(from);
239 }
240 }
241
242 rev_adj
243}
244
/// Returns every node reachable from `start` by following edges of `adj`.
///
/// The result always contains `start` itself. Nodes that are missing from
/// `adj` are treated as having no outgoing edges.
fn bfs_reachable(adj: &HashMap<usize, Vec<usize>>, start: usize) -> HashSet<usize> {
    let mut seen: HashSet<usize> = HashSet::new();
    seen.insert(start);

    let mut frontier: VecDeque<usize> = VecDeque::new();
    frontier.push_back(start);

    while let Some(current) = frontier.pop_front() {
        for &next in adj.get(&current).into_iter().flatten() {
            // `insert` is true only the first time a node is discovered.
            if seen.insert(next) {
                frontier.push_back(next);
            }
        }
    }

    seen
}
264
265fn tarjan_scc(adj: &HashMap<usize, Vec<usize>>) -> Vec<HashSet<usize>> {
267 let mut sccs = Vec::new();
268 let mut index = 0;
269 let mut stack = Vec::new();
270 let mut indices: HashMap<usize, usize> = HashMap::new();
271 let mut lowlinks: HashMap<usize, usize> = HashMap::new();
272 let mut on_stack: HashSet<usize> = HashSet::new();
273
274 let mut nodes: HashSet<usize> = adj.keys().copied().collect();
276 for neighbors in adj.values() {
277 nodes.extend(neighbors);
278 }
279
280 for &node in &nodes {
281 if !indices.contains_key(&node) {
282 strongconnect(
283 node,
284 adj,
285 &mut index,
286 &mut stack,
287 &mut indices,
288 &mut lowlinks,
289 &mut on_stack,
290 &mut sccs,
291 );
292 }
293 }
294
295 sccs
296}
297
/// Runs one depth-first search of Tarjan's algorithm starting at `v`, which
/// must not have been visited yet (no entry in `indices`).
///
/// Every component completed during the search is appended to `sccs`, in the
/// order the components are completed. `index`, `stack`, `indices`,
/// `lowlinks` and `on_stack` are the algorithm's shared state and are updated
/// in place so that later calls continue where this one stopped.
///
/// The search keeps its own explicit work stack instead of recursing, so its
/// depth is limited by heap memory rather than by the call stack; a long
/// chain of nodes cannot overflow the thread's stack.
// The eight parameters are Tarjan's per-run state; the signature is kept so
// existing callers continue to work.
#[allow(clippy::too_many_arguments)]
fn strongconnect(
    v: usize,
    adj: &HashMap<usize, Vec<usize>>,
    index: &mut usize,
    stack: &mut Vec<usize>,
    indices: &mut HashMap<usize, usize>,
    lowlinks: &mut HashMap<usize, usize>,
    on_stack: &mut HashSet<usize>,
    sccs: &mut Vec<HashSet<usize>>,
) {
    // Each frame is (vertex, position of the next neighbor to examine).
    let mut work: Vec<(usize, usize)> = Vec::new();
    // Vertex that has just been reached for the first time and still needs
    // its index/lowlink assigned before the search descends into it.
    let mut to_discover = Some(v);

    loop {
        if let Some(node) = to_discover.take() {
            indices.insert(node, *index);
            lowlinks.insert(node, *index);
            *index += 1;
            stack.push(node);
            on_stack.insert(node);
            work.push((node, 0));
        }

        let (node, pos) = match work.last() {
            Some(&frame) => frame,
            None => break,
        };

        let neighbors = adj.get(&node).map(Vec::as_slice).unwrap_or_default();
        if let Some(&w) = neighbors.get(pos) {
            // Advance this frame first so the neighbor is examined only once.
            if let Some(frame) = work.last_mut() {
                frame.1 = pos + 1;
            }
            if !indices.contains_key(&w) {
                // Tree edge: descend into `w` on the next loop iteration.
                to_discover = Some(w);
            } else if on_stack.contains(&w) {
                // Edge back into the component currently being built.
                let w_index = *indices.get(&w).expect("w visited before, so in indices");
                let node_lowlink = lowlinks
                    .get_mut(&node)
                    .expect("node is on the search path, so in lowlinks");
                *node_lowlink = (*node_lowlink).min(w_index);
            }
            continue;
        }

        // All neighbors of `node` are done: leave it.
        work.pop();
        let node_lowlink = *lowlinks
            .get(&node)
            .expect("node is on the search path, so in lowlinks");

        if indices.get(&node) == Some(&node_lowlink) {
            // `node` is the root of a component: everything above it on the
            // stack, including itself, belongs to that component.
            let mut scc = HashSet::new();
            loop {
                let w = stack
                    .pop()
                    .expect("stack is non-empty while searching for SCC root");
                on_stack.remove(&w);
                scc.insert(w);
                if w == node {
                    break;
                }
            }
            sccs.push(scc);
        }

        // Hand the lowlink up to the vertex the search descended from.
        if let Some(&(parent, _)) = work.last() {
            let parent_lowlink = lowlinks
                .get_mut(&parent)
                .expect("parent is on the search path, so in lowlinks");
            *parent_lowlink = (*parent_lowlink).min(node_lowlink);
        }
    }
}
349
350fn compute_topo_order(graph: &EinsumGraph) -> Option<Vec<usize>> {
352 let adj = build_adjacency_list(graph);
353 let mut in_degree: HashMap<usize, usize> = HashMap::new();
354
355 for i in 0..graph.nodes.len() {
357 in_degree.insert(i, 0);
358 }
359
360 for neighbors in adj.values() {
361 for &neighbor in neighbors {
362 *in_degree.entry(neighbor).or_insert(0) += 1;
363 }
364 }
365
366 let mut queue: VecDeque<usize> = in_degree
368 .iter()
369 .filter(|(_, °)| deg == 0)
370 .map(|(&node, _)| node)
371 .collect();
372
373 let mut order = Vec::new();
374
375 while let Some(node) = queue.pop_front() {
376 order.push(node);
377
378 if let Some(neighbors) = adj.get(&node) {
379 for &neighbor in neighbors {
380 let deg = in_degree
381 .get_mut(&neighbor)
382 .expect("neighbor was inserted during initialization");
383 *deg -= 1;
384 if *deg == 0 {
385 queue.push_back(neighbor);
386 }
387 }
388 }
389 }
390
391 if order.len() == graph.nodes.len() {
392 Some(order)
393 } else {
394 None }
396}
397
398fn compute_idom(adj: &HashMap<usize, Vec<usize>>, analysis: &mut DominanceAnalysis) {
400 if let Some(&entry) = adj.keys().next() {
405 for &node in adj.keys() {
406 if node != entry {
407 analysis.idom.insert(node, entry);
408 }
409 }
410 }
411}
412
413fn compute_dominance_frontiers(
415 _adj: &HashMap<usize, Vec<usize>>,
416 _idom: &HashMap<usize, usize>,
417 analysis: &mut DominanceAnalysis,
418) {
419 for &node in _idom.keys() {
423 analysis.dominance_frontier.insert(node, HashSet::new());
424 }
425}
426
427fn compute_post_dominators(
429 _rev_adj: &HashMap<usize, Vec<usize>>,
430 analysis: &mut DominanceAnalysis,
431) {
432 for &node in _rev_adj.keys() {
434 analysis.post_dominators.insert(node, HashSet::new());
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph that holds two tensors (`t0`, `t1`) and no nodes.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let _t0 = graph.add_tensor("t0");
        let _t1 = graph.add_tensor("t1");
        graph
    }

    #[test]
    fn test_reachability_empty_graph() {
        let graph = EinsumGraph::new();
        let analysis = analyze_reachability(&graph);
        assert!(analysis.reachable_from.is_empty());
        assert!(analysis.can_reach.is_empty());
        assert!(analysis.sccs.is_empty());
        // Zero nodes are trivially ordered.
        assert_eq!(analysis.topo_order, Some(Vec::new()));
    }

    #[test]
    fn test_reachability_single_node() {
        let mut graph = create_test_graph();
        let t0 = 0;
        let t1 = 1;
        let n0 = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", t0, t1))
            .expect("unwrap");

        let analysis = analyze_reachability(&graph);
        assert_eq!(analysis.reachable_from.len(), 1);
        // A lone node reaches itself and nothing else.
        assert!(analysis.is_reachable(n0, n0));
        assert_eq!(analysis.get_reachable(n0), HashSet::from([n0]));
        assert!(analysis.is_dag());
    }

    #[test]
    fn test_dominance_empty_graph() {
        let graph = EinsumGraph::new();
        let analysis = analyze_dominance(&graph);
        assert!(analysis.idom.is_empty());
        assert!(analysis.dominance_frontier.is_empty());
        assert!(analysis.post_dominators.is_empty());
    }

    #[test]
    fn test_is_dag() {
        let graph = create_test_graph();
        let analysis = analyze_reachability(&graph);

        // No node has been added, so there is no edge that could form a cycle.
        assert!(analysis.is_dag());
        assert_eq!(analysis.is_dag(), analysis.get_topo_order().is_some());
    }

    #[test]
    fn test_dominates() {
        let graph = create_test_graph();
        let analysis = analyze_dominance(&graph);

        // Without edges no immediate dominators exist, so nothing dominates.
        assert!(analysis.idom.is_empty());
        assert!(!analysis.dominates(0, 1));
    }

    #[test]
    fn test_build_adjacency() {
        let mut graph = create_test_graph();
        let t0 = 0;
        let t1 = 1;
        graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", t0, t1))
            .expect("unwrap");

        // A single node has no other node to connect to.
        let adj = build_adjacency_list(&graph);
        assert!(adj.is_empty());
    }

    #[test]
    fn test_scc_computation() {
        let mut adj = HashMap::new();
        adj.insert(0, vec![1]);
        adj.insert(1, vec![2]);
        adj.insert(2, vec![0]);

        // A three-node cycle is exactly one component.
        let sccs = tarjan_scc(&adj);
        assert_eq!(sccs, vec![HashSet::from([0, 1, 2])]);
    }

    #[test]
    fn test_topo_order() {
        let mut graph = create_test_graph();
        let t0 = 0;
        let t1 = 1;
        let t2 = 2;
        graph.add_tensor("t2");

        graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", t0, t1))
            .expect("unwrap");
        graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("log", t1, t2))
            .expect("unwrap");

        let order = compute_topo_order(&graph).expect("a two-node chain is acyclic");
        assert_eq!(order.len(), graph.nodes.len());

        // Every edge must point forward in the returned order.
        let position: HashMap<usize, usize> = order
            .iter()
            .enumerate()
            .map(|(pos, &node)| (node, pos))
            .collect();
        for (from, successors) in build_adjacency_list(&graph) {
            for to in successors {
                assert!(position[&from] < position[&to]);
            }
        }
    }

    #[test]
    fn test_reachability_chain() {
        let mut graph = create_test_graph();
        let t0 = 0;
        let t1 = 1;
        let t2 = 2;
        graph.add_tensor("t2");

        let n0 = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", t0, t1))
            .expect("unwrap");
        let n1 = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("log", t1, t2))
            .expect("unwrap");

        let analysis = analyze_reachability(&graph);

        // `exp` writes t1 and `log` reads it, so n0 feeds n1 and not the
        // other way round.
        assert!(analysis.is_reachable(n0, n1));
        assert!(!analysis.is_reachable(n1, n0));
        assert!(analysis.get_predecessors(n1).contains(&n0));
    }

    #[test]
    fn test_get_predecessors() {
        let graph = create_test_graph();
        let analysis = analyze_reachability(&graph);

        // No nodes were analyzed, so nothing is recorded for index 0.
        let preds = analysis.get_predecessors(0);
        assert!(preds.is_empty());
    }
}