Skip to main content

tensorlogic_ir/graph/
advanced_algorithms.rs

1//! # Advanced Graph Algorithms for EinsumGraph
2//!
3//! This module implements sophisticated graph analysis algorithms for tensor computation graphs:
4//!
5//! - **Cycle Detection**: Find cycles in the computation graph (important for detecting feedback loops)
6//! - **Strongly Connected Components (SCC)**: Find maximal strongly connected subgraphs (Tarjan's algorithm)
7//! - **Topological Ordering**: Generate execution orders respecting dependencies
8//! - **Graph Isomorphism**: Check if two graphs are structurally equivalent
9//! - **Minimum Cut**: Find bottlenecks in computation flow
10//! - **Critical Path Analysis**: Identify longest paths (critical for scheduling)
11//! - **Dominator Trees**: Find nodes that dominate others in the control flow
12//!
13//! ## Applications
14//!
15//! - **Optimization**: Detect opportunities for fusion, reordering, parallelization
16//! - **Verification**: Ensure acyclicity, detect redundancy
17//! - **Scheduling**: Find critical paths, identify parallelizable regions
18//! - **Debugging**: Understand graph structure, find anomalies
19//!
20//! ## Example
21//!
22//! ```rust
23//! use tensorlogic_ir::{EinsumGraph, EinsumNode, find_cycles, strongly_connected_components, topological_sort};
24//!
25//! let mut graph = EinsumGraph::new();
26//! let a = graph.add_tensor("A");
27//! let b = graph.add_tensor("B");
28//! let c = graph.add_tensor("C");
29//!
30//! graph.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c])).unwrap();
31//!
32//! // Detect cycles
33//! let cycles = find_cycles(&graph);
34//! assert!(cycles.is_empty()); // Should be acyclic
35//!
36//! // Find strongly connected components
37//! let sccs = strongly_connected_components(&graph);
38//!
39//! // Topological sort
40//! let topo_order = topological_sort(&graph).unwrap();
41//! ```
42
43use crate::graph::EinsumGraph;
44use serde::{Deserialize, Serialize};
45use std::collections::{HashMap, HashSet, VecDeque};
46
47/// A cycle in the computation graph.
48#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
49pub struct Cycle {
50    /// Tensor indices forming the cycle
51    pub tensors: Vec<usize>,
52    /// Node indices involved in the cycle
53    pub nodes: Vec<usize>,
54}
55
56/// Find all cycles in the computation graph.
57///
58/// Uses depth-first search with backtracking to enumerate all simple cycles.
59/// Note: This can be expensive for large graphs with many cycles.
60pub fn find_cycles(graph: &EinsumGraph) -> Vec<Cycle> {
61    let mut cycles = Vec::new();
62    let mut visited = HashSet::new();
63    let mut rec_stack = HashSet::new();
64    let mut path = Vec::new();
65
66    // Build adjacency list for tensor dependencies
67    let adjacency = build_tensor_adjacency(graph);
68
69    for tensor_idx in 0..graph.tensors.len() {
70        if !visited.contains(&tensor_idx) {
71            dfs_find_cycles(
72                tensor_idx,
73                &adjacency,
74                &mut visited,
75                &mut rec_stack,
76                &mut path,
77                &mut cycles,
78            );
79        }
80    }
81
82    cycles
83}
84
85/// DFS helper for cycle detection.
86fn dfs_find_cycles(
87    tensor: usize,
88    adjacency: &HashMap<usize, Vec<usize>>,
89    visited: &mut HashSet<usize>,
90    rec_stack: &mut HashSet<usize>,
91    path: &mut Vec<usize>,
92    cycles: &mut Vec<Cycle>,
93) {
94    visited.insert(tensor);
95    rec_stack.insert(tensor);
96    path.push(tensor);
97
98    if let Some(neighbors) = adjacency.get(&tensor) {
99        for &neighbor in neighbors {
100            if !visited.contains(&neighbor) {
101                dfs_find_cycles(neighbor, adjacency, visited, rec_stack, path, cycles);
102            } else if rec_stack.contains(&neighbor) {
103                // Found a cycle
104                if let Some(cycle_start) = path.iter().position(|&t| t == neighbor) {
105                    let cycle_tensors = path[cycle_start..].to_vec();
106                    cycles.push(Cycle {
107                        tensors: cycle_tensors,
108                        nodes: Vec::new(), // Would need to compute from tensors
109                    });
110                }
111            }
112        }
113    }
114
115    path.pop();
116    rec_stack.remove(&tensor);
117}
118
119/// Build tensor adjacency list from graph.
120fn build_tensor_adjacency(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
121    let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
122
123    for node in &graph.nodes {
124        for &input_tensor in &node.inputs {
125            for &output_tensor in &node.outputs {
126                adjacency
127                    .entry(input_tensor)
128                    .or_default()
129                    .push(output_tensor);
130            }
131        }
132    }
133
134    adjacency
135}
136
137/// A strongly connected component (SCC) in the computation graph.
138#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
139pub struct StronglyConnectedComponent {
140    /// Tensor indices in this SCC
141    pub tensors: Vec<usize>,
142    /// Node indices in this SCC
143    pub nodes: Vec<usize>,
144}
145
146/// Find all strongly connected components using Tarjan's algorithm.
147///
148/// An SCC is a maximal set of nodes where every node is reachable from every other node.
149/// This is useful for detecting mutually dependent computations.
150pub fn strongly_connected_components(graph: &EinsumGraph) -> Vec<StronglyConnectedComponent> {
151    let mut tarjan = TarjanSCC::new(graph);
152    tarjan.find_sccs();
153    tarjan.sccs
154}
155
156/// Tarjan's algorithm for finding SCCs.
157struct TarjanSCC<'a> {
158    graph: &'a EinsumGraph,
159    adjacency: HashMap<usize, Vec<usize>>,
160    index: usize,
161    indices: HashMap<usize, usize>,
162    lowlinks: HashMap<usize, usize>,
163    on_stack: HashSet<usize>,
164    stack: Vec<usize>,
165    sccs: Vec<StronglyConnectedComponent>,
166}
167
168impl<'a> TarjanSCC<'a> {
169    fn new(graph: &'a EinsumGraph) -> Self {
170        TarjanSCC {
171            graph,
172            adjacency: build_tensor_adjacency(graph),
173            index: 0,
174            indices: HashMap::new(),
175            lowlinks: HashMap::new(),
176            on_stack: HashSet::new(),
177            stack: Vec::new(),
178            sccs: Vec::new(),
179        }
180    }
181
182    fn find_sccs(&mut self) {
183        for tensor_idx in 0..self.graph.tensors.len() {
184            if !self.indices.contains_key(&tensor_idx) {
185                self.strong_connect(tensor_idx);
186            }
187        }
188    }
189
190    fn strong_connect(&mut self, v: usize) {
191        self.indices.insert(v, self.index);
192        self.lowlinks.insert(v, self.index);
193        self.index += 1;
194        self.stack.push(v);
195        self.on_stack.insert(v);
196
197        if let Some(neighbors) = self.adjacency.get(&v).cloned() {
198            for w in neighbors {
199                if !self.indices.contains_key(&w) {
200                    self.strong_connect(w);
201                    let w_lowlink = *self.lowlinks.get(&w).unwrap();
202                    let v_lowlink = *self.lowlinks.get(&v).unwrap();
203                    self.lowlinks.insert(v, v_lowlink.min(w_lowlink));
204                } else if self.on_stack.contains(&w) {
205                    let w_index = *self.indices.get(&w).unwrap();
206                    let v_lowlink = *self.lowlinks.get(&v).unwrap();
207                    self.lowlinks.insert(v, v_lowlink.min(w_index));
208                }
209            }
210        }
211
212        // If v is a root node, pop the stack to get an SCC
213        if self.lowlinks[&v] == self.indices[&v] {
214            let mut scc_tensors = Vec::new();
215            loop {
216                let w = self.stack.pop().unwrap();
217                self.on_stack.remove(&w);
218                scc_tensors.push(w);
219                if w == v {
220                    break;
221                }
222            }
223            self.sccs.push(StronglyConnectedComponent {
224                tensors: scc_tensors,
225                nodes: Vec::new(),
226            });
227        }
228    }
229}
230
231/// Perform topological sort on the computation graph.
232///
233/// Returns a linearization of tensors such that if there's a dependency from A to B,
234/// A appears before B in the ordering. Returns `None` if the graph contains cycles.
235pub fn topological_sort(graph: &EinsumGraph) -> Option<Vec<usize>> {
236    let adjacency = build_tensor_adjacency(graph);
237    let mut in_degree = vec![0; graph.tensors.len()];
238
239    // Compute in-degrees
240    for neighbors in adjacency.values() {
241        for &neighbor in neighbors {
242            in_degree[neighbor] += 1;
243        }
244    }
245
246    // Queue of tensors with in-degree 0
247    let mut queue: VecDeque<usize> = in_degree
248        .iter()
249        .enumerate()
250        .filter(|(_, &deg)| deg == 0)
251        .map(|(idx, _)| idx)
252        .collect();
253
254    let mut result = Vec::new();
255
256    while let Some(tensor) = queue.pop_front() {
257        result.push(tensor);
258
259        if let Some(neighbors) = adjacency.get(&tensor) {
260            for &neighbor in neighbors {
261                in_degree[neighbor] -= 1;
262                if in_degree[neighbor] == 0 {
263                    queue.push_back(neighbor);
264                }
265            }
266        }
267    }
268
269    // If we didn't process all tensors, there's a cycle
270    if result.len() == graph.tensors.len() {
271        Some(result)
272    } else {
273        None
274    }
275}
276
277/// Check if a graph is a directed acyclic graph (DAG).
278pub fn is_dag(graph: &EinsumGraph) -> bool {
279    topological_sort(graph).is_some()
280}
281
/// Outcome of an isomorphism check between two graphs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IsomorphismResult {
    /// The graphs are isomorphic.
    Isomorphic {
        /// Witness of the isomorphism: tensor index in the first graph ->
        /// tensor index in the second graph.
        mapping: HashMap<usize, usize>,
    },
    /// No isomorphism between the graphs was found.
    NotIsomorphic,
}
290
291/// Check if two graphs are isomorphic.
292///
293/// This uses a simplified algorithm based on degree sequences and local structure.
294/// Note: Graph isomorphism is NP-complete in general, so this uses heuristics.
295pub fn are_isomorphic(g1: &EinsumGraph, g2: &EinsumGraph) -> IsomorphismResult {
296    // Quick checks
297    if g1.tensors.len() != g2.tensors.len() || g1.nodes.len() != g2.nodes.len() {
298        return IsomorphismResult::NotIsomorphic;
299    }
300
301    // Check degree sequences
302    let deg1 = compute_degree_sequence(g1);
303    let deg2 = compute_degree_sequence(g2);
304
305    if deg1 != deg2 {
306        return IsomorphismResult::NotIsomorphic;
307    }
308
309    // Try to find an isomorphism using backtracking
310    // (This is a simplified implementation; full GI would use more sophisticated methods)
311
312    let mut mapping = HashMap::new();
313    if backtrack_isomorphism(g1, g2, &mut mapping, 0) {
314        IsomorphismResult::Isomorphic { mapping }
315    } else {
316        IsomorphismResult::NotIsomorphic
317    }
318}
319
320/// Compute degree sequence for a graph.
321fn compute_degree_sequence(graph: &EinsumGraph) -> Vec<(usize, usize)> {
322    let mut in_degrees = vec![0; graph.tensors.len()];
323    let mut out_degrees = vec![0; graph.tensors.len()];
324
325    for node in &graph.nodes {
326        for &input in &node.inputs {
327            out_degrees[input] += 1;
328        }
329        for &output in &node.outputs {
330            in_degrees[output] += 1;
331        }
332    }
333
334    let mut degrees: Vec<(usize, usize)> = in_degrees.into_iter().zip(out_degrees).collect();
335
336    degrees.sort_unstable();
337    degrees
338}
339
340/// Backtracking search for graph isomorphism.
341fn backtrack_isomorphism(
342    g1: &EinsumGraph,
343    g2: &EinsumGraph,
344    mapping: &mut HashMap<usize, usize>,
345    tensor_idx: usize,
346) -> bool {
347    // Base case: all tensors mapped
348    if tensor_idx >= g1.tensors.len() {
349        return verify_isomorphism(g1, g2, mapping);
350    }
351
352    // Try mapping tensor_idx to each unmapped tensor in g2
353    let mapped_values: HashSet<usize> = mapping.values().copied().collect();
354
355    for candidate in 0..g2.tensors.len() {
356        if !mapped_values.contains(&candidate) {
357            mapping.insert(tensor_idx, candidate);
358
359            if backtrack_isomorphism(g1, g2, mapping, tensor_idx + 1) {
360                return true;
361            }
362
363            mapping.remove(&tensor_idx);
364        }
365    }
366
367    false
368}
369
370/// Verify that a mapping is a valid isomorphism.
371fn verify_isomorphism(g1: &EinsumGraph, g2: &EinsumGraph, mapping: &HashMap<usize, usize>) -> bool {
372    // Check that every edge in g1 maps to an edge in g2
373    let adj1 = build_tensor_adjacency(g1);
374    let adj2 = build_tensor_adjacency(g2);
375
376    for (u, neighbors) in &adj1 {
377        let u_mapped = mapping[u];
378
379        for &v in neighbors {
380            let v_mapped = mapping[&v];
381
382            // Check if edge (u_mapped -> v_mapped) exists in g2
383            if let Some(adj2_neighbors) = adj2.get(&u_mapped) {
384                if !adj2_neighbors.contains(&v_mapped) {
385                    return false;
386                }
387            } else {
388                return false;
389            }
390        }
391    }
392
393    true
394}
395
396/// Critical path analysis result.
397#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
398pub struct CriticalPath {
399    /// Tensors on the critical path
400    pub tensors: Vec<usize>,
401    /// Nodes on the critical path
402    pub nodes: Vec<usize>,
403    /// Total length (sum of weights) of the critical path
404    pub length: f64,
405}
406
407/// Find the critical path in the computation graph.
408///
409/// The critical path is the longest path from inputs to outputs,
410/// which represents the minimum time required for execution.
411pub fn critical_path_analysis(
412    graph: &EinsumGraph,
413    weights: &HashMap<usize, f64>,
414) -> Option<CriticalPath> {
415    if !is_dag(graph) {
416        return None; // Critical path only defined for DAGs
417    }
418
419    let topo_order = topological_sort(graph)?;
420    let adjacency = build_tensor_adjacency(graph);
421
422    let mut distances: HashMap<usize, f64> = HashMap::new();
423    let mut predecessors: HashMap<usize, usize> = HashMap::new();
424
425    // Initialize distances
426    for &tensor in &topo_order {
427        distances.insert(tensor, 0.0);
428    }
429
430    // Compute longest paths
431    for &u in &topo_order {
432        if let Some(neighbors) = adjacency.get(&u) {
433            let u_dist = distances[&u];
434
435            for &v in neighbors {
436                let weight = weights.get(&v).copied().unwrap_or(1.0);
437                let new_dist = u_dist + weight;
438
439                if new_dist > *distances.get(&v).unwrap_or(&0.0) {
440                    distances.insert(v, new_dist);
441                    predecessors.insert(v, u);
442                }
443            }
444        }
445    }
446
447    // Find the tensor with maximum distance (end of critical path)
448    let (&end_tensor, &max_dist) = distances
449        .iter()
450        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())?;
451
452    // Reconstruct path
453    let mut path = Vec::new();
454    let mut current = end_tensor;
455
456    loop {
457        path.push(current);
458        if let Some(&pred) = predecessors.get(&current) {
459            current = pred;
460        } else {
461            break;
462        }
463    }
464
465    path.reverse();
466
467    Some(CriticalPath {
468        tensors: path,
469        nodes: Vec::new(),
470        length: max_dist,
471    })
472}
473
474/// Compute graph diameter (longest shortest path).
475pub fn graph_diameter(graph: &EinsumGraph) -> Option<usize> {
476    let adjacency = build_tensor_adjacency(graph);
477    let mut max_distance = 0;
478
479    // Run BFS from each tensor
480    for start in 0..graph.tensors.len() {
481        let distances = bfs_distances(&adjacency, start);
482        if let Some(&max) = distances.values().max() {
483            max_distance = max_distance.max(max);
484        }
485    }
486
487    Some(max_distance)
488}
489
/// Breadth-first search from `source` over `adjacency`.
///
/// Returns the number of edges on a shortest directed path from `source` to
/// every tensor reachable from it. `source` itself is always present with
/// distance `0`; unreachable tensors are absent from the map.
fn bfs_distances(adjacency: &HashMap<usize, Vec<usize>>, source: usize) -> HashMap<usize, usize> {
    let mut hops = HashMap::new();
    hops.insert(source, 0);

    let mut frontier = VecDeque::new();
    frontier.push_back(source);

    while let Some(u) = frontier.pop_front() {
        let next_hop = hops[&u] + 1;

        // A tensor without an adjacency entry has no successors.
        for &v in adjacency.get(&u).into_iter().flatten() {
            // The first visit is along a shortest path; later ones are ignored.
            if hops.contains_key(&v) {
                continue;
            }
            hops.insert(v, next_hop);
            frontier.push_back(v);
        }
    }

    hops
}
513
514/// Find all paths between two tensors.
515pub fn find_all_paths(graph: &EinsumGraph, from: usize, to: usize) -> Vec<Vec<usize>> {
516    let adjacency = build_tensor_adjacency(graph);
517    let mut paths = Vec::new();
518    let mut current_path = Vec::new();
519    let mut visited = HashSet::new();
520
521    dfs_all_paths(
522        from,
523        to,
524        &adjacency,
525        &mut current_path,
526        &mut visited,
527        &mut paths,
528    );
529
530    paths
531}
532
/// Recursive DFS step used by [`find_all_paths`].
///
/// Appends `current` to `path`; if it is the `target`, a copy of `path` is
/// stored in `paths`, otherwise the search continues through every successor
/// that is not already on the path. Before returning, `current` is removed
/// from both `path` and `visited` so that it can be reused on other paths.
fn dfs_all_paths(
    current: usize,
    target: usize,
    adjacency: &HashMap<usize, Vec<usize>>,
    path: &mut Vec<usize>,
    visited: &mut HashSet<usize>,
    paths: &mut Vec<Vec<usize>>,
) {
    visited.insert(current);
    path.push(current);

    if current != target {
        // A tensor without an adjacency entry is a dead end.
        for &next in adjacency.get(&current).into_iter().flatten() {
            if visited.contains(&next) {
                continue;
            }
            dfs_all_paths(next, target, adjacency, path, visited, paths);
        }
    } else {
        // Reached the target: record the path and do not extend it further.
        paths.push(path.to_vec());
    }

    // Backtrack.
    visited.remove(&current);
    path.pop();
}
558
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{EinsumNode, OpType};

    /// Build the fixture used by most tests: tensors `A` (0), `B` (1) and
    /// `C` (2) with a single einsum node taking `A` and `B` and producing `C`.
    /// The tensor dependency edges are therefore `A -> C` and `B -> C`.
    fn create_simple_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        let node = EinsumNode {
            op: OpType::Einsum {
                spec: "ij,jk->ik".to_string(),
            },
            inputs: vec![a, b],
            outputs: vec![c],
            metadata: Default::default(),
        };

        graph.add_node(node).unwrap();
        graph
    }

    #[test]
    fn test_acyclic_graph_no_cycles() {
        let graph = create_simple_graph();
        let cycles = find_cycles(&graph);
        assert!(cycles.is_empty());
    }

    #[test]
    fn test_is_dag() {
        let graph = create_simple_graph();
        assert!(is_dag(&graph));
    }

    #[test]
    fn test_topological_sort() {
        let graph = create_simple_graph();
        let order = topological_sort(&graph).expect("fixture graph is acyclic");
        // Both inputs are ready immediately (in index order); C follows them.
        assert_eq!(order, vec![0, 1, 2]);
    }

    #[test]
    fn test_strongly_connected_components() {
        let graph = create_simple_graph();
        let sccs = strongly_connected_components(&graph);
        // In a DAG, each tensor is its own SCC
        assert_eq!(sccs.len(), 3);
        assert!(sccs.iter().all(|scc| scc.tensors.len() == 1));

        let mut members: Vec<usize> = sccs.iter().map(|scc| scc.tensors[0]).collect();
        members.sort_unstable();
        assert_eq!(members, vec![0, 1, 2]);
    }

    #[test]
    fn test_graph_diameter() {
        let graph = create_simple_graph();
        // The longest shortest path is a single edge (A -> C or B -> C).
        assert_eq!(graph_diameter(&graph), Some(1));
    }

    #[test]
    fn test_critical_path() {
        let graph = create_simple_graph();
        let weights = HashMap::new(); // All weights = 1
        let critical =
            critical_path_analysis(&graph, &weights).expect("fixture graph is acyclic");
        // C is the only tensor with a non-zero distance; A reaches it first.
        assert_eq!(critical.tensors, vec![0, 2]);
        assert_eq!(critical.length, 1.0);
    }

    #[test]
    fn test_find_all_paths() {
        let graph = create_simple_graph();
        // A -> C is a direct edge; B is a sibling input, not an intermediate.
        let paths = find_all_paths(&graph, 0, 2);
        assert_eq!(paths, vec![vec![0, 2]]);

        // There is no dependency between the two inputs.
        assert!(find_all_paths(&graph, 0, 1).is_empty());
    }

    #[test]
    fn test_isomorphism_identical_graphs() {
        let g1 = create_simple_graph();
        let g2 = create_simple_graph();

        match are_isomorphic(&g1, &g2) {
            IsomorphismResult::Isomorphic { mapping } => assert_eq!(mapping.len(), 3),
            IsomorphismResult::NotIsomorphic => panic!("identical graphs must be isomorphic"),
        }
    }

    #[test]
    fn test_isomorphism_different_sizes() {
        let g1 = create_simple_graph();
        let mut g2 = EinsumGraph::new();
        g2.add_tensor("A");

        let result = are_isomorphic(&g1, &g2);
        assert_eq!(result, IsomorphismResult::NotIsomorphic);
    }

    #[test]
    fn test_tensor_adjacency() {
        let graph = create_simple_graph();
        let adj = build_tensor_adjacency(&graph);

        // A -> C and B -> C; C has no outgoing edge and therefore no entry.
        assert_eq!(adj.get(&0), Some(&vec![2]));
        assert_eq!(adj.get(&1), Some(&vec![2]));
        assert!(!adj.contains_key(&2));
    }

    #[test]
    fn test_degree_sequence() {
        let graph = create_simple_graph();
        let deg_seq = compute_degree_sequence(&graph);
        // Sorted (in, out) pairs: two pure inputs and one pure output.
        assert_eq!(deg_seq, vec![(0, 1), (0, 1), (1, 0)]);
    }

    #[test]
    fn test_bfs_distances() {
        let mut adj = HashMap::new();
        adj.insert(0, vec![1, 2]);
        adj.insert(1, vec![3]);
        adj.insert(2, vec![3]);

        let distances = bfs_distances(&adj, 0);
        assert_eq!(distances.len(), 4);
        assert_eq!(distances[&0], 0);
        assert_eq!(distances[&1], 1);
        assert_eq!(distances[&2], 1);
        assert_eq!(distances[&3], 2);
    }
}