tensorlogic_compiler/passes/
reachability.rs1use std::collections::{HashMap, HashSet, VecDeque};
28use tensorlogic_ir::EinsumGraph;
29
/// Reachability facts about an einsum graph, as produced by
/// `analyze_reachability`.
///
/// All node identifiers are indices into the analysed graph's node list.
#[derive(Debug, Clone, Default)]
pub struct ReachabilityAnalysis {
    /// Node -> every node reachable from it (`analyze_reachability` also
    /// counts each node as reachable from itself).
    pub reachable_from: HashMap<usize, HashSet<usize>>,
    /// Node -> every node that can reach it (likewise including itself).
    pub can_reach: HashMap<usize, HashSet<usize>>,
    /// Strongly connected components, one set of node indices per component.
    pub sccs: Vec<HashSet<usize>>,
    /// A topological ordering of the nodes, or `None` when a cycle prevents one.
    pub topo_order: Option<Vec<usize>>,
}

impl ReachabilityAnalysis {
    /// Returns an analysis with no recorded facts.
    pub fn new() -> Self {
        Self::default()
    }

    /// Looks up `key` in `table`, returning an owned copy of its set, or an
    /// empty set when the key is absent.
    fn owned_set(table: &HashMap<usize, HashSet<usize>>, key: usize) -> HashSet<usize> {
        table.get(&key).cloned().unwrap_or_default()
    }

    /// Reports whether `to` is in the reachable set recorded for `from`.
    /// Returns `false` when nothing is recorded for `from`.
    pub fn is_reachable(&self, from: usize, to: usize) -> bool {
        if let Some(targets) = self.reachable_from.get(&from) {
            targets.contains(&to)
        } else {
            false
        }
    }

    /// Owned copy of the set of nodes reachable from `from` (empty if unknown).
    pub fn get_reachable(&self, from: usize) -> HashSet<usize> {
        Self::owned_set(&self.reachable_from, from)
    }

    /// Owned copy of the set of nodes that can reach `to` (empty if unknown).
    pub fn get_predecessors(&self, to: usize) -> HashSet<usize> {
        Self::owned_set(&self.can_reach, to)
    }

    /// `true` exactly when a topological order has been recorded.
    pub fn is_dag(&self) -> bool {
        self.topo_order.is_some()
    }

    /// Borrows the recorded topological order, if any.
    pub fn get_topo_order(&self) -> Option<&[usize]> {
        self.topo_order.as_deref()
    }

    /// First recorded strongly connected component that contains `node`.
    pub fn get_scc(&self, node: usize) -> Option<&HashSet<usize>> {
        self.sccs.iter().find(|component| component.contains(&node))
    }
}
93
/// Dominance facts about an einsum graph, as produced by `analyze_dominance`.
///
/// All node identifiers are indices into the analysed graph's node list.
#[derive(Debug, Clone, Default)]
pub struct DominanceAnalysis {
    /// Node -> its immediate dominator. Nodes with no recorded immediate
    /// dominator (such as entry nodes) are absent.
    pub idom: HashMap<usize, usize>,
    /// Node -> its dominance frontier.
    pub dominance_frontier: HashMap<usize, HashSet<usize>>,
    /// Node -> its post-dominator set.
    pub post_dominators: HashMap<usize, HashSet<usize>>,
}

impl DominanceAnalysis {
    /// Creates an empty analysis.
    pub fn new() -> Self {
        Self::default()
    }

    /// Immediate dominator of `node`, if one is recorded.
    pub fn get_idom(&self, node: usize) -> Option<usize> {
        self.idom.get(&node).copied()
    }

    /// Returns `true` if `dom` appears on the immediate-dominator chain above
    /// `node`.
    ///
    /// This is strict dominance: a node is not reported as dominating itself
    /// unless the map records it as its own immediate dominator.
    ///
    /// `idom` is a public field that callers may edit, so the walk is bounded.
    /// An acyclic chain takes at most one hop per map entry, and exceeding
    /// that means the chain cycles without reaching `dom`. The walk then
    /// returns `false` instead of looping forever.
    pub fn dominates(&self, dom: usize, node: usize) -> bool {
        let mut current = node;
        for _ in 0..self.idom.len() {
            match self.get_idom(current) {
                Some(parent) if parent == dom => return true,
                // A self-loop can never lead anywhere new.
                Some(parent) if parent != current => current = parent,
                _ => return false,
            }
        }
        false
    }

    /// Owned copy of the dominance frontier of `node` (empty if unknown).
    pub fn get_frontier(&self, node: usize) -> HashSet<usize> {
        self.dominance_frontier
            .get(&node)
            .cloned()
            .unwrap_or_default()
    }

    /// Owned copy of the post-dominator set of `node` (empty if unknown).
    pub fn get_post_dominators(&self, node: usize) -> HashSet<usize> {
        self.post_dominators.get(&node).cloned().unwrap_or_default()
    }
}
154
155pub fn analyze_reachability(graph: &EinsumGraph) -> ReachabilityAnalysis {
157 let mut analysis = ReachabilityAnalysis::new();
158
159 let adj = build_adjacency_list(graph);
161
162 for node in 0..graph.nodes.len() {
164 let reachable = bfs_reachable(&adj, node);
165 analysis.reachable_from.insert(node, reachable);
166 }
167
168 let rev_adj = build_reverse_adjacency(graph);
170 for node in 0..graph.nodes.len() {
171 let can_reach = bfs_reachable(&rev_adj, node);
172 analysis.can_reach.insert(node, can_reach);
173 }
174
175 analysis.sccs = tarjan_scc(&adj);
177
178 analysis.topo_order = compute_topo_order(graph);
180
181 analysis
182}
183
184pub fn analyze_dominance(graph: &EinsumGraph) -> DominanceAnalysis {
186 let mut analysis = DominanceAnalysis::new();
187
188 if graph.nodes.is_empty() {
189 return analysis;
190 }
191
192 let adj = build_adjacency_list(graph);
194
195 compute_idom(&adj, &mut analysis);
197
198 let idom_clone = analysis.idom.clone();
200 compute_dominance_frontiers(&adj, &idom_clone, &mut analysis);
201
202 let rev_adj = build_reverse_adjacency(graph);
204 compute_post_dominators(&rev_adj, &mut analysis);
205
206 analysis
207}
208
209fn build_adjacency_list(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
211 let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
212
213 for (node_idx, node) in graph.nodes.iter().enumerate() {
214 for other_idx in 0..graph.nodes.len() {
216 if other_idx == node_idx {
217 continue;
218 }
219
220 let other = &graph.nodes[other_idx];
221 if node.outputs.iter().any(|&out| other.inputs.contains(&out)) {
223 adj.entry(node_idx).or_default().push(other_idx);
224 }
225 }
226 }
227
228 adj
229}
230
231fn build_reverse_adjacency(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
233 let adj = build_adjacency_list(graph);
234 let mut rev_adj: HashMap<usize, Vec<usize>> = HashMap::new();
235
236 for (from, neighbors) in adj {
237 for to in neighbors {
238 rev_adj.entry(to).or_default().push(from);
239 }
240 }
241
242 rev_adj
243}
244
/// Collects every node reachable from `start` by following `adj` edges,
/// breadth first.
///
/// `start` is always part of the result, even when it has no adjacency
/// entry.
fn bfs_reachable(adj: &HashMap<usize, Vec<usize>>, start: usize) -> HashSet<usize> {
    let mut seen: HashSet<usize> = std::iter::once(start).collect();
    let mut frontier: VecDeque<usize> = std::iter::once(start).collect();

    while let Some(current) = frontier.pop_front() {
        let successors = match adj.get(&current) {
            Some(list) => list,
            None => continue,
        };
        // `insert` returns false for already-discovered nodes, so each node
        // is enqueued at most once.
        frontier.extend(successors.iter().copied().filter(|&next| seen.insert(next)));
    }

    seen
}
264
/// Computes the strongly connected components of the graph described by
/// `adj` using Tarjan's algorithm.
///
/// Every node that appears in `adj`, as a key or as a successor, is assigned
/// to exactly one component. DFS roots are tried in ascending node order, so
/// the order of the returned components depends only on the contents of
/// `adj`, not on `HashMap` iteration order. Components are emitted in the
/// order Tarjan's algorithm completes them, which is reverse topological
/// order of the component graph.
fn tarjan_scc(adj: &HashMap<usize, Vec<usize>>) -> Vec<HashSet<usize>> {
    let mut nodes: Vec<usize> = adj
        .iter()
        .flat_map(|(&node, successors)| std::iter::once(node).chain(successors.iter().copied()))
        .collect();
    // Sorting makes the traversal independent of the per-instance randomized
    // hash iteration order.
    nodes.sort_unstable();
    nodes.dedup();

    let mut sccs = Vec::new();
    let mut index = 0;
    let mut stack = Vec::new();
    let mut indices: HashMap<usize, usize> = HashMap::new();
    let mut lowlinks: HashMap<usize, usize> = HashMap::new();
    let mut on_stack: HashSet<usize> = HashSet::new();

    for node in nodes {
        if !indices.contains_key(&node) {
            strongconnect(
                node,
                adj,
                &mut index,
                &mut stack,
                &mut indices,
                &mut lowlinks,
                &mut on_stack,
                &mut sccs,
            );
        }
    }

    sccs
}

/// One recursive step of Tarjan's algorithm.
///
/// Assigns `v` the next DFS index, explores its successors, and, when `v`
/// turns out to be the root of a component, pops that component off `stack`
/// into `sccs`.
///
/// Recursion depth equals the length of the current DFS path, so very long
/// dependency chains consume a proportional amount of native stack.
// Tarjan's state is threaded through explicitly rather than bundled in a struct.
#[allow(clippy::too_many_arguments)]
fn strongconnect(
    v: usize,
    adj: &HashMap<usize, Vec<usize>>,
    index: &mut usize,
    stack: &mut Vec<usize>,
    indices: &mut HashMap<usize, usize>,
    lowlinks: &mut HashMap<usize, usize>,
    on_stack: &mut HashSet<usize>,
    sccs: &mut Vec<HashSet<usize>>,
) {
    indices.insert(v, *index);
    lowlinks.insert(v, *index);
    *index += 1;
    stack.push(v);
    on_stack.insert(v);

    if let Some(neighbors) = adj.get(&v) {
        for &w in neighbors {
            if !indices.contains_key(&w) {
                // Tree edge: finish `w` first, then inherit its low-link.
                strongconnect(w, adj, index, stack, indices, lowlinks, on_stack, sccs);
                let low = lowlinks[&v].min(lowlinks[&w]);
                lowlinks.insert(v, low);
            } else if on_stack.contains(&w) {
                // Edge back into a node of the still-open component.
                let low = lowlinks[&v].min(indices[&w]);
                lowlinks.insert(v, low);
            }
        }
    }

    // `v` is the root of a component: everything above it on the stack
    // (inclusive) belongs to that component.
    if lowlinks[&v] == indices[&v] {
        let mut scc = HashSet::new();
        loop {
            let w = stack
                .pop()
                .expect("v stays on the Tarjan stack until its component is popped");
            on_stack.remove(&w);
            scc.insert(w);
            if w == v {
                break;
            }
        }
        sccs.push(scc);
    }
}
343
344fn compute_topo_order(graph: &EinsumGraph) -> Option<Vec<usize>> {
346 let adj = build_adjacency_list(graph);
347 let mut in_degree: HashMap<usize, usize> = HashMap::new();
348
349 for i in 0..graph.nodes.len() {
351 in_degree.insert(i, 0);
352 }
353
354 for neighbors in adj.values() {
355 for &neighbor in neighbors {
356 *in_degree.entry(neighbor).or_insert(0) += 1;
357 }
358 }
359
360 let mut queue: VecDeque<usize> = in_degree
362 .iter()
363 .filter(|(_, °)| deg == 0)
364 .map(|(&node, _)| node)
365 .collect();
366
367 let mut order = Vec::new();
368
369 while let Some(node) = queue.pop_front() {
370 order.push(node);
371
372 if let Some(neighbors) = adj.get(&node) {
373 for &neighbor in neighbors {
374 let deg = in_degree.get_mut(&neighbor).unwrap();
375 *deg -= 1;
376 if *deg == 0 {
377 queue.push_back(neighbor);
378 }
379 }
380 }
381 }
382
383 if order.len() == graph.nodes.len() {
384 Some(order)
385 } else {
386 None }
388}
389
390fn compute_idom(adj: &HashMap<usize, Vec<usize>>, analysis: &mut DominanceAnalysis) {
392 if let Some(&entry) = adj.keys().next() {
397 for &node in adj.keys() {
398 if node != entry {
399 analysis.idom.insert(node, entry);
400 }
401 }
402 }
403}
404
405fn compute_dominance_frontiers(
407 _adj: &HashMap<usize, Vec<usize>>,
408 _idom: &HashMap<usize, usize>,
409 analysis: &mut DominanceAnalysis,
410) {
411 for &node in _idom.keys() {
415 analysis.dominance_frontier.insert(node, HashSet::new());
416 }
417}
418
419fn compute_post_dominators(
421 _rev_adj: &HashMap<usize, Vec<usize>>,
422 analysis: &mut DominanceAnalysis,
423) {
424 for &node in _rev_adj.keys() {
426 analysis.post_dominators.insert(node, HashSet::new());
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph holding tensors `t0` and `t1` and no nodes.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let _t0 = graph.add_tensor("t0");
        let _t1 = graph.add_tensor("t1");
        graph
    }

    /// Builds the chain `t0 -exp-> t1 -log-> t2` and returns it together with
    /// the indices of its two nodes.
    ///
    /// Assumes `EinsumNode::elem_unary(op, input, output)` argument order.
    /// TODO confirm against `tensorlogic_ir`.
    fn create_chain_graph() -> (EinsumGraph, usize, usize) {
        let mut graph = create_test_graph();
        graph.add_tensor("t2");
        let n0 = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", 0, 1))
            .unwrap();
        let n1 = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("log", 1, 2))
            .unwrap();
        (graph, n0, n1)
    }

    #[test]
    fn test_reachability_empty_graph() {
        let graph = EinsumGraph::new();
        let analysis = analyze_reachability(&graph);
        assert!(analysis.reachable_from.is_empty());
        assert!(analysis.can_reach.is_empty());
        assert!(analysis.sccs.is_empty());
        assert_eq!(analysis.topo_order, Some(Vec::new()));
    }

    #[test]
    fn test_reachability_single_node() {
        let mut graph = create_test_graph();
        let node = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", 0, 1))
            .unwrap();

        let analysis = analyze_reachability(&graph);
        // Every node reaches itself and nothing else.
        assert!(analysis.is_reachable(node, node));
        assert_eq!(analysis.get_reachable(node).len(), 1);
        assert!(analysis.is_dag());
        assert_eq!(analysis.get_topo_order(), Some(&[node][..]));
    }

    #[test]
    fn test_dominance_empty_graph() {
        let graph = EinsumGraph::new();
        let analysis = analyze_dominance(&graph);
        assert!(analysis.idom.is_empty());
        assert!(analysis.dominance_frontier.is_empty());
        assert!(analysis.post_dominators.is_empty());
    }

    #[test]
    fn test_is_dag() {
        let graph = create_test_graph();
        let analysis = analyze_reachability(&graph);
        // A graph without nodes is trivially acyclic, with an empty order.
        assert!(analysis.is_dag());
        assert_eq!(analysis.get_topo_order().map(|order| order.len()), Some(0));
    }

    #[test]
    fn test_dominates() {
        let graph = create_test_graph();
        let analysis = analyze_dominance(&graph);
        assert!(analysis.idom.is_empty());
        assert!(!analysis.dominates(0, 1));
    }

    #[test]
    fn test_build_adjacency() {
        let mut graph = create_test_graph();
        graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", 0, 1))
            .unwrap();

        // A lone node is never recorded as feeding itself, so there are no edges.
        let adj = build_adjacency_list(&graph);
        assert!(adj.is_empty());
    }

    #[test]
    fn test_build_adjacency_chain() {
        let (graph, n0, n1) = create_chain_graph();
        let adj = build_adjacency_list(&graph);
        assert_eq!(adj.get(&n0), Some(&vec![n1]));
        assert!(!adj.contains_key(&n1));
    }

    #[test]
    fn test_scc_computation() {
        let mut adj = HashMap::new();
        adj.insert(0, vec![1]);
        adj.insert(1, vec![2]);
        adj.insert(2, vec![0]);

        let sccs = tarjan_scc(&adj);
        assert_eq!(sccs.len(), 1);
        assert_eq!(sccs[0], [0, 1, 2].iter().copied().collect::<HashSet<_>>());
    }

    #[test]
    fn test_topo_order() {
        let (graph, n0, n1) = create_chain_graph();
        assert_eq!(compute_topo_order(&graph), Some(vec![n0, n1]));
    }

    #[test]
    fn test_reachability_chain() {
        let (graph, n0, n1) = create_chain_graph();
        let analysis = analyze_reachability(&graph);

        assert!(analysis.is_reachable(n0, n1));
        assert!(!analysis.is_reachable(n1, n0));
        assert!(analysis.get_predecessors(n1).contains(&n0));
        assert!(!analysis.get_predecessors(n0).contains(&n1));
    }

    #[test]
    fn test_get_predecessors() {
        let graph = create_test_graph();
        let analysis = analyze_reachability(&graph);

        // No nodes means nothing is recorded, so lookups fall back to empty.
        assert!(analysis.get_predecessors(0).is_empty());
    }
}