Skip to main content

tensorlogic_ir/graph/
advanced_algorithms.rs

1//! # Advanced Graph Algorithms for EinsumGraph
2//!
3//! This module implements sophisticated graph analysis algorithms for tensor computation graphs:
4//!
//! - **Cycle Detection**: Find cycles in the computation graph (important for detecting feedback loops)
//! - **Strongly Connected Components (SCC)**: Find maximal strongly connected subgraphs (Tarjan's algorithm)
//! - **Topological Ordering**: Generate an execution order that respects dependencies (Kahn's algorithm)
//! - **Graph Isomorphism**: Check whether two graphs are structurally equivalent (exhaustive search; small graphs only)
//! - **Critical Path Analysis**: Identify the longest weighted path (critical for scheduling)
//! - **Diameter and Path Enumeration**: Longest shortest path, and all simple paths between two tensors
//! - **Minimum Cut** and **Dominator Trees**: listed as goals, but no implementation is present in this file
12//!
13//! ## Applications
14//!
15//! - **Optimization**: Detect opportunities for fusion, reordering, parallelization
16//! - **Verification**: Ensure acyclicity, detect redundancy
17//! - **Scheduling**: Find critical paths, identify parallelizable regions
18//! - **Debugging**: Understand graph structure, find anomalies
19//!
20//! ## Example
21//!
22//! ```rust
23//! use tensorlogic_ir::{EinsumGraph, EinsumNode, find_cycles, strongly_connected_components, topological_sort};
24//!
25//! let mut graph = EinsumGraph::new();
26//! let a = graph.add_tensor("A");
27//! let b = graph.add_tensor("B");
28//! let c = graph.add_tensor("C");
29//!
30//! graph.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c])).expect("unwrap");
31//!
32//! // Detect cycles
33//! let cycles = find_cycles(&graph);
34//! assert!(cycles.is_empty()); // Should be acyclic
35//!
36//! // Find strongly connected components
37//! let sccs = strongly_connected_components(&graph);
38//!
39//! // Topological sort
40//! let topo_order = topological_sort(&graph).expect("unwrap");
41//! ```
42
43use crate::graph::EinsumGraph;
44use serde::{Deserialize, Serialize};
45use std::collections::{HashMap, HashSet, VecDeque};
46
/// A cycle in the computation graph.
///
/// Cycles are described at tensor granularity: each entry of `tensors` has a
/// dependency edge to the next entry, and the last entry has an edge back to
/// the first.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Cycle {
    /// Tensor indices forming the cycle, in traversal order
    pub tensors: Vec<usize>,
    /// Node indices involved in the cycle.
    ///
    /// NOTE(review): the cycle search in this file never fills this in; it
    /// always stores an empty vector here.
    pub nodes: Vec<usize>,
}
55
56/// Find all cycles in the computation graph.
57///
58/// Uses depth-first search with backtracking to enumerate all simple cycles.
59/// Note: This can be expensive for large graphs with many cycles.
60pub fn find_cycles(graph: &EinsumGraph) -> Vec<Cycle> {
61    let mut cycles = Vec::new();
62    let mut visited = HashSet::new();
63    let mut rec_stack = HashSet::new();
64    let mut path = Vec::new();
65
66    // Build adjacency list for tensor dependencies
67    let adjacency = build_tensor_adjacency(graph);
68
69    for tensor_idx in 0..graph.tensors.len() {
70        if !visited.contains(&tensor_idx) {
71            dfs_find_cycles(
72                tensor_idx,
73                &adjacency,
74                &mut visited,
75                &mut rec_stack,
76                &mut path,
77                &mut cycles,
78            );
79        }
80    }
81
82    cycles
83}
84
85/// DFS helper for cycle detection.
86fn dfs_find_cycles(
87    tensor: usize,
88    adjacency: &HashMap<usize, Vec<usize>>,
89    visited: &mut HashSet<usize>,
90    rec_stack: &mut HashSet<usize>,
91    path: &mut Vec<usize>,
92    cycles: &mut Vec<Cycle>,
93) {
94    visited.insert(tensor);
95    rec_stack.insert(tensor);
96    path.push(tensor);
97
98    if let Some(neighbors) = adjacency.get(&tensor) {
99        for &neighbor in neighbors {
100            if !visited.contains(&neighbor) {
101                dfs_find_cycles(neighbor, adjacency, visited, rec_stack, path, cycles);
102            } else if rec_stack.contains(&neighbor) {
103                // Found a cycle
104                if let Some(cycle_start) = path.iter().position(|&t| t == neighbor) {
105                    let cycle_tensors = path[cycle_start..].to_vec();
106                    cycles.push(Cycle {
107                        tensors: cycle_tensors,
108                        nodes: Vec::new(), // Would need to compute from tensors
109                    });
110                }
111            }
112        }
113    }
114
115    path.pop();
116    rec_stack.remove(&tensor);
117}
118
119/// Build tensor adjacency list from graph.
120fn build_tensor_adjacency(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
121    let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
122
123    for node in &graph.nodes {
124        for &input_tensor in &node.inputs {
125            for &output_tensor in &node.outputs {
126                adjacency
127                    .entry(input_tensor)
128                    .or_default()
129                    .push(output_tensor);
130            }
131        }
132    }
133
134    adjacency
135}
136
/// A strongly connected component (SCC) in the computation graph.
///
/// Components are expressed at tensor granularity: every tensor listed in
/// `tensors` can reach every other one through dependency edges.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct StronglyConnectedComponent {
    /// Tensor indices in this SCC, in the order they were popped off the
    /// Tarjan stack (most recently discovered first)
    pub tensors: Vec<usize>,
    /// Node indices in this SCC.
    ///
    /// NOTE(review): the SCC search in this file never fills this in; it
    /// always stores an empty vector here.
    pub nodes: Vec<usize>,
}
145
146/// Find all strongly connected components using Tarjan's algorithm.
147///
148/// An SCC is a maximal set of nodes where every node is reachable from every other node.
149/// This is useful for detecting mutually dependent computations.
150pub fn strongly_connected_components(graph: &EinsumGraph) -> Vec<StronglyConnectedComponent> {
151    let mut tarjan = TarjanSCC::new(graph);
152    tarjan.find_sccs();
153    tarjan.sccs
154}
155
156/// Tarjan's algorithm for finding SCCs.
157struct TarjanSCC<'a> {
158    graph: &'a EinsumGraph,
159    adjacency: HashMap<usize, Vec<usize>>,
160    index: usize,
161    indices: HashMap<usize, usize>,
162    lowlinks: HashMap<usize, usize>,
163    on_stack: HashSet<usize>,
164    stack: Vec<usize>,
165    sccs: Vec<StronglyConnectedComponent>,
166}
167
168impl<'a> TarjanSCC<'a> {
169    fn new(graph: &'a EinsumGraph) -> Self {
170        TarjanSCC {
171            graph,
172            adjacency: build_tensor_adjacency(graph),
173            index: 0,
174            indices: HashMap::new(),
175            lowlinks: HashMap::new(),
176            on_stack: HashSet::new(),
177            stack: Vec::new(),
178            sccs: Vec::new(),
179        }
180    }
181
182    fn find_sccs(&mut self) {
183        for tensor_idx in 0..self.graph.tensors.len() {
184            if !self.indices.contains_key(&tensor_idx) {
185                self.strong_connect(tensor_idx);
186            }
187        }
188    }
189
190    fn strong_connect(&mut self, v: usize) {
191        self.indices.insert(v, self.index);
192        self.lowlinks.insert(v, self.index);
193        self.index += 1;
194        self.stack.push(v);
195        self.on_stack.insert(v);
196
197        if let Some(neighbors) = self.adjacency.get(&v).cloned() {
198            for w in neighbors {
199                if !self.indices.contains_key(&w) {
200                    self.strong_connect(w);
201                    let w_lowlink = *self
202                        .lowlinks
203                        .get(&w)
204                        .expect("lowlink must exist for visited node");
205                    let v_lowlink = *self
206                        .lowlinks
207                        .get(&v)
208                        .expect("lowlink must exist for visited node");
209                    self.lowlinks.insert(v, v_lowlink.min(w_lowlink));
210                } else if self.on_stack.contains(&w) {
211                    let w_index = *self
212                        .indices
213                        .get(&w)
214                        .expect("index must exist for visited node");
215                    let v_lowlink = *self
216                        .lowlinks
217                        .get(&v)
218                        .expect("lowlink must exist for visited node");
219                    self.lowlinks.insert(v, v_lowlink.min(w_index));
220                }
221            }
222        }
223
224        // If v is a root node, pop the stack to get an SCC
225        if self.lowlinks[&v] == self.indices[&v] {
226            let mut scc_tensors = Vec::new();
227            loop {
228                let w = self
229                    .stack
230                    .pop()
231                    .expect("stack must be non-empty when processing SCC");
232                self.on_stack.remove(&w);
233                scc_tensors.push(w);
234                if w == v {
235                    break;
236                }
237            }
238            self.sccs.push(StronglyConnectedComponent {
239                tensors: scc_tensors,
240                nodes: Vec::new(),
241            });
242        }
243    }
244}
245
246/// Perform topological sort on the computation graph.
247///
248/// Returns a linearization of tensors such that if there's a dependency from A to B,
249/// A appears before B in the ordering. Returns `None` if the graph contains cycles.
250pub fn topological_sort(graph: &EinsumGraph) -> Option<Vec<usize>> {
251    let adjacency = build_tensor_adjacency(graph);
252    let mut in_degree = vec![0; graph.tensors.len()];
253
254    // Compute in-degrees
255    for neighbors in adjacency.values() {
256        for &neighbor in neighbors {
257            in_degree[neighbor] += 1;
258        }
259    }
260
261    // Queue of tensors with in-degree 0
262    let mut queue: VecDeque<usize> = in_degree
263        .iter()
264        .enumerate()
265        .filter(|(_, &deg)| deg == 0)
266        .map(|(idx, _)| idx)
267        .collect();
268
269    let mut result = Vec::new();
270
271    while let Some(tensor) = queue.pop_front() {
272        result.push(tensor);
273
274        if let Some(neighbors) = adjacency.get(&tensor) {
275            for &neighbor in neighbors {
276                in_degree[neighbor] -= 1;
277                if in_degree[neighbor] == 0 {
278                    queue.push_back(neighbor);
279                }
280            }
281        }
282    }
283
284    // If we didn't process all tensors, there's a cycle
285    if result.len() == graph.tensors.len() {
286        Some(result)
287    } else {
288        None
289    }
290}
291
292/// Check if a graph is a directed acyclic graph (DAG).
293pub fn is_dag(graph: &EinsumGraph) -> bool {
294    topological_sort(graph).is_some()
295}
296
/// Graph isomorphism result.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum IsomorphismResult {
    /// Graphs are isomorphic with the given mapping.
    ///
    /// Keys are tensor indices of the first graph, values are the tensor
    /// indices of the second graph they correspond to.
    Isomorphic { mapping: HashMap<usize, usize> },
    /// Graphs are not isomorphic
    NotIsomorphic,
}
305
306/// Check if two graphs are isomorphic.
307///
308/// This uses a simplified algorithm based on degree sequences and local structure.
309/// Note: Graph isomorphism is NP-complete in general, so this uses heuristics.
310pub fn are_isomorphic(g1: &EinsumGraph, g2: &EinsumGraph) -> IsomorphismResult {
311    // Quick checks
312    if g1.tensors.len() != g2.tensors.len() || g1.nodes.len() != g2.nodes.len() {
313        return IsomorphismResult::NotIsomorphic;
314    }
315
316    // Check degree sequences
317    let deg1 = compute_degree_sequence(g1);
318    let deg2 = compute_degree_sequence(g2);
319
320    if deg1 != deg2 {
321        return IsomorphismResult::NotIsomorphic;
322    }
323
324    // Try to find an isomorphism using backtracking
325    // (This is a simplified implementation; full GI would use more sophisticated methods)
326
327    let mut mapping = HashMap::new();
328    if backtrack_isomorphism(g1, g2, &mut mapping, 0) {
329        IsomorphismResult::Isomorphic { mapping }
330    } else {
331        IsomorphismResult::NotIsomorphic
332    }
333}
334
335/// Compute degree sequence for a graph.
336fn compute_degree_sequence(graph: &EinsumGraph) -> Vec<(usize, usize)> {
337    let mut in_degrees = vec![0; graph.tensors.len()];
338    let mut out_degrees = vec![0; graph.tensors.len()];
339
340    for node in &graph.nodes {
341        for &input in &node.inputs {
342            out_degrees[input] += 1;
343        }
344        for &output in &node.outputs {
345            in_degrees[output] += 1;
346        }
347    }
348
349    let mut degrees: Vec<(usize, usize)> = in_degrees.into_iter().zip(out_degrees).collect();
350
351    degrees.sort_unstable();
352    degrees
353}
354
355/// Backtracking search for graph isomorphism.
356fn backtrack_isomorphism(
357    g1: &EinsumGraph,
358    g2: &EinsumGraph,
359    mapping: &mut HashMap<usize, usize>,
360    tensor_idx: usize,
361) -> bool {
362    // Base case: all tensors mapped
363    if tensor_idx >= g1.tensors.len() {
364        return verify_isomorphism(g1, g2, mapping);
365    }
366
367    // Try mapping tensor_idx to each unmapped tensor in g2
368    let mapped_values: HashSet<usize> = mapping.values().copied().collect();
369
370    for candidate in 0..g2.tensors.len() {
371        if !mapped_values.contains(&candidate) {
372            mapping.insert(tensor_idx, candidate);
373
374            if backtrack_isomorphism(g1, g2, mapping, tensor_idx + 1) {
375                return true;
376            }
377
378            mapping.remove(&tensor_idx);
379        }
380    }
381
382    false
383}
384
385/// Verify that a mapping is a valid isomorphism.
386fn verify_isomorphism(g1: &EinsumGraph, g2: &EinsumGraph, mapping: &HashMap<usize, usize>) -> bool {
387    // Check that every edge in g1 maps to an edge in g2
388    let adj1 = build_tensor_adjacency(g1);
389    let adj2 = build_tensor_adjacency(g2);
390
391    for (u, neighbors) in &adj1 {
392        let u_mapped = mapping[u];
393
394        for &v in neighbors {
395            let v_mapped = mapping[&v];
396
397            // Check if edge (u_mapped -> v_mapped) exists in g2
398            if let Some(adj2_neighbors) = adj2.get(&u_mapped) {
399                if !adj2_neighbors.contains(&v_mapped) {
400                    return false;
401                }
402            } else {
403                return false;
404            }
405        }
406    }
407
408    true
409}
410
/// Critical path analysis result.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CriticalPath {
    /// Tensors on the critical path, from its first tensor to its last
    pub tensors: Vec<usize>,
    /// Nodes on the critical path.
    ///
    /// NOTE(review): the analysis in this file never fills this in; it always
    /// stores an empty vector here.
    pub nodes: Vec<usize>,
    /// Total length (sum of weights) of the critical path
    pub length: f64,
}
421
422/// Find the critical path in the computation graph.
423///
424/// The critical path is the longest path from inputs to outputs,
425/// which represents the minimum time required for execution.
426pub fn critical_path_analysis(
427    graph: &EinsumGraph,
428    weights: &HashMap<usize, f64>,
429) -> Option<CriticalPath> {
430    if !is_dag(graph) {
431        return None; // Critical path only defined for DAGs
432    }
433
434    let topo_order = topological_sort(graph)?;
435    let adjacency = build_tensor_adjacency(graph);
436
437    let mut distances: HashMap<usize, f64> = HashMap::new();
438    let mut predecessors: HashMap<usize, usize> = HashMap::new();
439
440    // Initialize distances
441    for &tensor in &topo_order {
442        distances.insert(tensor, 0.0);
443    }
444
445    // Compute longest paths
446    for &u in &topo_order {
447        if let Some(neighbors) = adjacency.get(&u) {
448            let u_dist = distances[&u];
449
450            for &v in neighbors {
451                let weight = weights.get(&v).copied().unwrap_or(1.0);
452                let new_dist = u_dist + weight;
453
454                if new_dist > *distances.get(&v).unwrap_or(&0.0) {
455                    distances.insert(v, new_dist);
456                    predecessors.insert(v, u);
457                }
458            }
459        }
460    }
461
462    // Find the tensor with maximum distance (end of critical path)
463    let (&end_tensor, &max_dist) = distances
464        .iter()
465        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))?;
466
467    // Reconstruct path
468    let mut path = Vec::new();
469    let mut current = end_tensor;
470
471    loop {
472        path.push(current);
473        if let Some(&pred) = predecessors.get(&current) {
474            current = pred;
475        } else {
476            break;
477        }
478    }
479
480    path.reverse();
481
482    Some(CriticalPath {
483        tensors: path,
484        nodes: Vec::new(),
485        length: max_dist,
486    })
487}
488
489/// Compute graph diameter (longest shortest path).
490pub fn graph_diameter(graph: &EinsumGraph) -> Option<usize> {
491    let adjacency = build_tensor_adjacency(graph);
492    let mut max_distance = 0;
493
494    // Run BFS from each tensor
495    for start in 0..graph.tensors.len() {
496        let distances = bfs_distances(&adjacency, start);
497        if let Some(&max) = distances.values().max() {
498            max_distance = max_distance.max(max);
499        }
500    }
501
502    Some(max_distance)
503}
504
/// Breadth-first search from `source` over `adjacency`.
///
/// Returns, for every tensor reachable from `source` (including `source`
/// itself at distance `0`), the minimum number of edges needed to reach it.
/// Tensors that cannot be reached are absent from the result.
fn bfs_distances(adjacency: &HashMap<usize, Vec<usize>>, source: usize) -> HashMap<usize, usize> {
    let mut hops = HashMap::new();
    hops.insert(source, 0);

    let mut frontier = VecDeque::new();
    frontier.push_back(source);

    while let Some(current) = frontier.pop_front() {
        let next_hops = hops[&current] + 1;

        for &next in adjacency.get(&current).into_iter().flatten() {
            // Only the first (shortest) arrival records a distance and
            // schedules the tensor for expansion.
            hops.entry(next).or_insert_with(|| {
                frontier.push_back(next);
                next_hops
            });
        }
    }

    hops
}
528
529/// Find all paths between two tensors.
530pub fn find_all_paths(graph: &EinsumGraph, from: usize, to: usize) -> Vec<Vec<usize>> {
531    let adjacency = build_tensor_adjacency(graph);
532    let mut paths = Vec::new();
533    let mut current_path = Vec::new();
534    let mut visited = HashSet::new();
535
536    dfs_all_paths(
537        from,
538        to,
539        &adjacency,
540        &mut current_path,
541        &mut visited,
542        &mut paths,
543    );
544
545    paths
546}
547
/// Recursive depth-first step of the path enumeration.
///
/// Extends `path` with `current`. If `current` is the `target`, a copy of
/// `path` is appended to `paths` and the search does not continue past it;
/// otherwise every successor that is not already on the path is explored.
/// Before returning, `current` is removed from `path` and `visited` again so
/// that it can take part in other paths.
fn dfs_all_paths(
    current: usize,
    target: usize,
    adjacency: &HashMap<usize, Vec<usize>>,
    path: &mut Vec<usize>,
    visited: &mut HashSet<usize>,
    paths: &mut Vec<Vec<usize>>,
) {
    path.push(current);
    visited.insert(current);

    if current != target {
        for &next in adjacency.get(&current).into_iter().flatten() {
            // Skipping tensors already on the path keeps every path simple.
            if visited.contains(&next) {
                continue;
            }
            dfs_all_paths(next, target, adjacency, path, visited, paths);
        }
    } else {
        paths.push(path.clone());
    }

    // Backtrack.
    visited.remove(&current);
    path.pop();
}
573
// Unit tests for the graph algorithms in this module. All graph-level tests
// share the tiny fixture built by `create_simple_graph`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{EinsumNode, OpType};

    /// Build the shared fixture: tensors A, B, C and a single einsum node
    /// that consumes A and B and produces C (tensor edges A -> C and B -> C).
    fn create_simple_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        let node = EinsumNode {
            op: OpType::Einsum {
                spec: "ij,jk->ik".to_string(),
            },
            inputs: vec![a, b],
            outputs: vec![c],
            metadata: Default::default(),
        };

        graph.add_node(node).expect("unwrap");
        graph
    }

    // The fixture has no back edges, so no cycle may be reported.
    #[test]
    fn test_acyclic_graph_no_cycles() {
        let graph = create_simple_graph();
        let cycles = find_cycles(&graph);
        assert!(cycles.is_empty());
    }

    #[test]
    fn test_is_dag() {
        let graph = create_simple_graph();
        assert!(is_dag(&graph));
    }

    // Only the length of the ordering is checked, not the order itself.
    #[test]
    fn test_topological_sort() {
        let graph = create_simple_graph();
        let topo = topological_sort(&graph);
        assert!(topo.is_some());
        let order = topo.expect("unwrap");
        assert_eq!(order.len(), 3);
    }

    #[test]
    fn test_strongly_connected_components() {
        let graph = create_simple_graph();
        let sccs = strongly_connected_components(&graph);
        // In a DAG, each node is its own SCC
        assert_eq!(sccs.len(), 3);
    }

    // The longest shortest path in the fixture is a single edge (e.g. A -> C).
    #[test]
    fn test_graph_diameter() {
        let graph = create_simple_graph();
        let diameter = graph_diameter(&graph);
        assert!(diameter.is_some());
        assert!(diameter.expect("unwrap") >= 1);
    }

    #[test]
    fn test_critical_path() {
        let graph = create_simple_graph();
        let weights = HashMap::new(); // All weights = 1
        let critical = critical_path_analysis(&graph, &weights);
        assert!(critical.is_some());
    }

    #[test]
    fn test_find_all_paths() {
        let graph = create_simple_graph();
        // A -> C is a direct edge (B is a second input of the same node, not
        // an intermediate step)
        let paths = find_all_paths(&graph, 0, 2);
        assert!(!paths.is_empty());
    }

    #[test]
    fn test_isomorphism_identical_graphs() {
        let g1 = create_simple_graph();
        let g2 = create_simple_graph();

        let result = are_isomorphic(&g1, &g2);
        assert!(matches!(result, IsomorphismResult::Isomorphic { .. }));
    }

    // Graphs with different tensor/node counts are rejected by the size check.
    #[test]
    fn test_isomorphism_different_sizes() {
        let g1 = create_simple_graph();
        let mut g2 = EinsumGraph::new();
        g2.add_tensor("A");

        let result = are_isomorphic(&g1, &g2);
        assert_eq!(result, IsomorphismResult::NotIsomorphic);
    }

    #[test]
    fn test_tensor_adjacency() {
        let graph = create_simple_graph();
        let adj = build_tensor_adjacency(&graph);

        // A -> C and B -> C
        assert!(adj.contains_key(&0));
        assert!(adj.contains_key(&1));
    }

    // One (in-degree, out-degree) pair per tensor.
    #[test]
    fn test_degree_sequence() {
        let graph = create_simple_graph();
        let deg_seq = compute_degree_sequence(&graph);
        assert_eq!(deg_seq.len(), 3);
    }

    // Diamond 0 -> {1, 2} -> 3, built directly as an adjacency map.
    #[test]
    fn test_bfs_distances() {
        let mut adj = HashMap::new();
        adj.insert(0, vec![1, 2]);
        adj.insert(1, vec![3]);
        adj.insert(2, vec![3]);

        let distances = bfs_distances(&adj, 0);
        assert_eq!(distances[&0], 0);
        assert_eq!(distances[&1], 1);
        assert_eq!(distances[&3], 2);
    }
}