1use crate::graph::EinsumGraph;
44use serde::{Deserialize, Serialize};
45use std::collections::{HashMap, HashSet, VecDeque};
46
/// A directed cycle found in the tensor-level dependency graph.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Cycle {
    /// Tensor indices along the cycle in DFS traversal order; the last tensor
    /// has an edge back to the first.
    pub tensors: Vec<usize>,
    /// Node indices on the cycle (currently always left empty by `find_cycles`).
    pub nodes: Vec<usize>,
}
55
56pub fn find_cycles(graph: &EinsumGraph) -> Vec<Cycle> {
61 let mut cycles = Vec::new();
62 let mut visited = HashSet::new();
63 let mut rec_stack = HashSet::new();
64 let mut path = Vec::new();
65
66 let adjacency = build_tensor_adjacency(graph);
68
69 for tensor_idx in 0..graph.tensors.len() {
70 if !visited.contains(&tensor_idx) {
71 dfs_find_cycles(
72 tensor_idx,
73 &adjacency,
74 &mut visited,
75 &mut rec_stack,
76 &mut path,
77 &mut cycles,
78 );
79 }
80 }
81
82 cycles
83}
84
85fn dfs_find_cycles(
87 tensor: usize,
88 adjacency: &HashMap<usize, Vec<usize>>,
89 visited: &mut HashSet<usize>,
90 rec_stack: &mut HashSet<usize>,
91 path: &mut Vec<usize>,
92 cycles: &mut Vec<Cycle>,
93) {
94 visited.insert(tensor);
95 rec_stack.insert(tensor);
96 path.push(tensor);
97
98 if let Some(neighbors) = adjacency.get(&tensor) {
99 for &neighbor in neighbors {
100 if !visited.contains(&neighbor) {
101 dfs_find_cycles(neighbor, adjacency, visited, rec_stack, path, cycles);
102 } else if rec_stack.contains(&neighbor) {
103 if let Some(cycle_start) = path.iter().position(|&t| t == neighbor) {
105 let cycle_tensors = path[cycle_start..].to_vec();
106 cycles.push(Cycle {
107 tensors: cycle_tensors,
108 nodes: Vec::new(), });
110 }
111 }
112 }
113 }
114
115 path.pop();
116 rec_stack.remove(&tensor);
117}
118
119fn build_tensor_adjacency(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
121 let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
122
123 for node in &graph.nodes {
124 for &input_tensor in &node.inputs {
125 for &output_tensor in &node.outputs {
126 adjacency
127 .entry(input_tensor)
128 .or_default()
129 .push(output_tensor);
130 }
131 }
132 }
133
134 adjacency
135}
136
/// One strongly connected component of the tensor-level dependency graph.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct StronglyConnectedComponent {
    /// Tensor indices in the component, in the order they were popped off
    /// Tarjan's stack (the component's root tensor comes last).
    pub tensors: Vec<usize>,
    /// Node indices in the component (currently always left empty by `TarjanSCC`).
    pub nodes: Vec<usize>,
}
145
146pub fn strongly_connected_components(graph: &EinsumGraph) -> Vec<StronglyConnectedComponent> {
151 let mut tarjan = TarjanSCC::new(graph);
152 tarjan.find_sccs();
153 tarjan.sccs
154}
155
/// Mutable working state for one run of Tarjan's strongly-connected-components
/// algorithm over the tensor dependency graph.
struct TarjanSCC<'a> {
    /// Graph under analysis; read for its tensor count and to build `adjacency`.
    graph: &'a EinsumGraph,
    /// Tensor successor lists, built once from `graph`.
    adjacency: HashMap<usize, Vec<usize>>,
    /// Next DFS discovery index to hand out.
    index: usize,
    /// Discovery index assigned to each visited tensor.
    indices: HashMap<usize, usize>,
    /// Lowest discovery index known to be reachable from each tensor while it
    /// is on the stack; a tensor whose lowlink equals its own index roots a component.
    lowlinks: HashMap<usize, usize>,
    /// Set mirroring the contents of `stack`, for constant-time membership checks.
    on_stack: HashSet<usize>,
    /// Visited tensors not yet assigned to a component.
    stack: Vec<usize>,
    /// Components completed so far, in the order they were closed.
    sccs: Vec<StronglyConnectedComponent>,
}
167
168impl<'a> TarjanSCC<'a> {
169 fn new(graph: &'a EinsumGraph) -> Self {
170 TarjanSCC {
171 graph,
172 adjacency: build_tensor_adjacency(graph),
173 index: 0,
174 indices: HashMap::new(),
175 lowlinks: HashMap::new(),
176 on_stack: HashSet::new(),
177 stack: Vec::new(),
178 sccs: Vec::new(),
179 }
180 }
181
182 fn find_sccs(&mut self) {
183 for tensor_idx in 0..self.graph.tensors.len() {
184 if !self.indices.contains_key(&tensor_idx) {
185 self.strong_connect(tensor_idx);
186 }
187 }
188 }
189
190 fn strong_connect(&mut self, v: usize) {
191 self.indices.insert(v, self.index);
192 self.lowlinks.insert(v, self.index);
193 self.index += 1;
194 self.stack.push(v);
195 self.on_stack.insert(v);
196
197 if let Some(neighbors) = self.adjacency.get(&v).cloned() {
198 for w in neighbors {
199 if !self.indices.contains_key(&w) {
200 self.strong_connect(w);
201 let w_lowlink = *self
202 .lowlinks
203 .get(&w)
204 .expect("lowlink must exist for visited node");
205 let v_lowlink = *self
206 .lowlinks
207 .get(&v)
208 .expect("lowlink must exist for visited node");
209 self.lowlinks.insert(v, v_lowlink.min(w_lowlink));
210 } else if self.on_stack.contains(&w) {
211 let w_index = *self
212 .indices
213 .get(&w)
214 .expect("index must exist for visited node");
215 let v_lowlink = *self
216 .lowlinks
217 .get(&v)
218 .expect("lowlink must exist for visited node");
219 self.lowlinks.insert(v, v_lowlink.min(w_index));
220 }
221 }
222 }
223
224 if self.lowlinks[&v] == self.indices[&v] {
226 let mut scc_tensors = Vec::new();
227 loop {
228 let w = self
229 .stack
230 .pop()
231 .expect("stack must be non-empty when processing SCC");
232 self.on_stack.remove(&w);
233 scc_tensors.push(w);
234 if w == v {
235 break;
236 }
237 }
238 self.sccs.push(StronglyConnectedComponent {
239 tensors: scc_tensors,
240 nodes: Vec::new(),
241 });
242 }
243 }
244}
245
246pub fn topological_sort(graph: &EinsumGraph) -> Option<Vec<usize>> {
251 let adjacency = build_tensor_adjacency(graph);
252 let mut in_degree = vec![0; graph.tensors.len()];
253
254 for neighbors in adjacency.values() {
256 for &neighbor in neighbors {
257 in_degree[neighbor] += 1;
258 }
259 }
260
261 let mut queue: VecDeque<usize> = in_degree
263 .iter()
264 .enumerate()
265 .filter(|(_, °)| deg == 0)
266 .map(|(idx, _)| idx)
267 .collect();
268
269 let mut result = Vec::new();
270
271 while let Some(tensor) = queue.pop_front() {
272 result.push(tensor);
273
274 if let Some(neighbors) = adjacency.get(&tensor) {
275 for &neighbor in neighbors {
276 in_degree[neighbor] -= 1;
277 if in_degree[neighbor] == 0 {
278 queue.push_back(neighbor);
279 }
280 }
281 }
282 }
283
284 if result.len() == graph.tensors.len() {
286 Some(result)
287 } else {
288 None
289 }
290}
291
292pub fn is_dag(graph: &EinsumGraph) -> bool {
294 topological_sort(graph).is_some()
295}
296
/// Outcome of [`are_isomorphic`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum IsomorphismResult {
    /// A tensor bijection was found. `mapping` sends each tensor index of the
    /// first graph to its counterpart in the second.
    Isomorphic { mapping: HashMap<usize, usize> },
    /// The graphs differ in size or degree profile, or no tensor bijection
    /// passed the edge check.
    NotIsomorphic,
}
305
306pub fn are_isomorphic(g1: &EinsumGraph, g2: &EinsumGraph) -> IsomorphismResult {
311 if g1.tensors.len() != g2.tensors.len() || g1.nodes.len() != g2.nodes.len() {
313 return IsomorphismResult::NotIsomorphic;
314 }
315
316 let deg1 = compute_degree_sequence(g1);
318 let deg2 = compute_degree_sequence(g2);
319
320 if deg1 != deg2 {
321 return IsomorphismResult::NotIsomorphic;
322 }
323
324 let mut mapping = HashMap::new();
328 if backtrack_isomorphism(g1, g2, &mut mapping, 0) {
329 IsomorphismResult::Isomorphic { mapping }
330 } else {
331 IsomorphismResult::NotIsomorphic
332 }
333}
334
335fn compute_degree_sequence(graph: &EinsumGraph) -> Vec<(usize, usize)> {
337 let mut in_degrees = vec![0; graph.tensors.len()];
338 let mut out_degrees = vec![0; graph.tensors.len()];
339
340 for node in &graph.nodes {
341 for &input in &node.inputs {
342 out_degrees[input] += 1;
343 }
344 for &output in &node.outputs {
345 in_degrees[output] += 1;
346 }
347 }
348
349 let mut degrees: Vec<(usize, usize)> = in_degrees.into_iter().zip(out_degrees).collect();
350
351 degrees.sort_unstable();
352 degrees
353}
354
355fn backtrack_isomorphism(
357 g1: &EinsumGraph,
358 g2: &EinsumGraph,
359 mapping: &mut HashMap<usize, usize>,
360 tensor_idx: usize,
361) -> bool {
362 if tensor_idx >= g1.tensors.len() {
364 return verify_isomorphism(g1, g2, mapping);
365 }
366
367 let mapped_values: HashSet<usize> = mapping.values().copied().collect();
369
370 for candidate in 0..g2.tensors.len() {
371 if !mapped_values.contains(&candidate) {
372 mapping.insert(tensor_idx, candidate);
373
374 if backtrack_isomorphism(g1, g2, mapping, tensor_idx + 1) {
375 return true;
376 }
377
378 mapping.remove(&tensor_idx);
379 }
380 }
381
382 false
383}
384
385fn verify_isomorphism(g1: &EinsumGraph, g2: &EinsumGraph, mapping: &HashMap<usize, usize>) -> bool {
387 let adj1 = build_tensor_adjacency(g1);
389 let adj2 = build_tensor_adjacency(g2);
390
391 for (u, neighbors) in &adj1 {
392 let u_mapped = mapping[u];
393
394 for &v in neighbors {
395 let v_mapped = mapping[&v];
396
397 if let Some(adj2_neighbors) = adj2.get(&u_mapped) {
399 if !adj2_neighbors.contains(&v_mapped) {
400 return false;
401 }
402 } else {
403 return false;
404 }
405 }
406 }
407
408 true
409}
410
/// Result of [`critical_path_analysis`]: the heaviest path through the tensor
/// dependency DAG.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CriticalPath {
    /// Tensor indices along the path, from its first tensor to its last.
    pub tensors: Vec<usize>,
    /// Node indices along the path (currently always left empty by `critical_path_analysis`).
    pub nodes: Vec<usize>,
    /// Sum of the weights of every tensor on the path except the first; a
    /// tensor with no supplied weight counts as 1.0.
    pub length: f64,
}
421
422pub fn critical_path_analysis(
427 graph: &EinsumGraph,
428 weights: &HashMap<usize, f64>,
429) -> Option<CriticalPath> {
430 if !is_dag(graph) {
431 return None; }
433
434 let topo_order = topological_sort(graph)?;
435 let adjacency = build_tensor_adjacency(graph);
436
437 let mut distances: HashMap<usize, f64> = HashMap::new();
438 let mut predecessors: HashMap<usize, usize> = HashMap::new();
439
440 for &tensor in &topo_order {
442 distances.insert(tensor, 0.0);
443 }
444
445 for &u in &topo_order {
447 if let Some(neighbors) = adjacency.get(&u) {
448 let u_dist = distances[&u];
449
450 for &v in neighbors {
451 let weight = weights.get(&v).copied().unwrap_or(1.0);
452 let new_dist = u_dist + weight;
453
454 if new_dist > *distances.get(&v).unwrap_or(&0.0) {
455 distances.insert(v, new_dist);
456 predecessors.insert(v, u);
457 }
458 }
459 }
460 }
461
462 let (&end_tensor, &max_dist) = distances
464 .iter()
465 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))?;
466
467 let mut path = Vec::new();
469 let mut current = end_tensor;
470
471 loop {
472 path.push(current);
473 if let Some(&pred) = predecessors.get(¤t) {
474 current = pred;
475 } else {
476 break;
477 }
478 }
479
480 path.reverse();
481
482 Some(CriticalPath {
483 tensors: path,
484 nodes: Vec::new(),
485 length: max_dist,
486 })
487}
488
489pub fn graph_diameter(graph: &EinsumGraph) -> Option<usize> {
491 let adjacency = build_tensor_adjacency(graph);
492 let mut max_distance = 0;
493
494 for start in 0..graph.tensors.len() {
496 let distances = bfs_distances(&adjacency, start);
497 if let Some(&max) = distances.values().max() {
498 max_distance = max_distance.max(max);
499 }
500 }
501
502 Some(max_distance)
503}
504
/// Breadth-first search from `source` over `adjacency`.
///
/// Returns the hop count from `source` to every tensor reachable from it,
/// including `source` itself at distance 0. Unreachable tensors are absent
/// from the map. The search expands one whole frontier layer at a time.
fn bfs_distances(adjacency: &HashMap<usize, Vec<usize>>, source: usize) -> HashMap<usize, usize> {
    let mut dist: HashMap<usize, usize> = HashMap::new();
    dist.insert(source, 0);

    let mut layer: Vec<usize> = vec![source];
    let mut depth: usize = 0;

    while !layer.is_empty() {
        depth += 1;
        let mut next_layer = Vec::new();

        for node in layer {
            for &succ in adjacency.get(&node).into_iter().flatten() {
                // The first time a tensor is seen is via a shortest path.
                if let std::collections::hash_map::Entry::Vacant(slot) = dist.entry(succ) {
                    slot.insert(depth);
                    next_layer.push(succ);
                }
            }
        }

        layer = next_layer;
    }

    dist
}
528
529pub fn find_all_paths(graph: &EinsumGraph, from: usize, to: usize) -> Vec<Vec<usize>> {
531 let adjacency = build_tensor_adjacency(graph);
532 let mut paths = Vec::new();
533 let mut current_path = Vec::new();
534 let mut visited = HashSet::new();
535
536 dfs_all_paths(
537 from,
538 to,
539 &adjacency,
540 &mut current_path,
541 &mut visited,
542 &mut paths,
543 );
544
545 paths
546}
547
/// Backtracking DFS helper for [`find_all_paths`].
///
/// Appends `current` to `path`. If `current` is `target`, a copy of the path
/// is pushed onto `paths` and the search stops there. Otherwise every
/// successor not already on the path is explored. `visited` mirrors the
/// tensors currently in `path`, so only simple paths are produced. Both `path`
/// and `visited` are restored before returning.
fn dfs_all_paths(
    current: usize,
    target: usize,
    adjacency: &HashMap<usize, Vec<usize>>,
    path: &mut Vec<usize>,
    visited: &mut HashSet<usize>,
    paths: &mut Vec<Vec<usize>>,
) {
    path.push(current);
    visited.insert(current);

    if current == target {
        paths.push(path.clone());
    } else if let Some(neighbors) = adjacency.get(&current) {
        for &neighbor in neighbors {
            if !visited.contains(&neighbor) {
                dfs_all_paths(neighbor, target, adjacency, path, visited, paths);
            }
        }
    }

    path.pop();
    visited.remove(&current);
}
573
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{EinsumNode, OpType};

    /// Builds a graph with a single node `C = einsum("ij,jk->ik", A, B)`.
    ///
    /// The tests below assume `add_tensor` hands out indices in insertion
    /// order (A = 0, B = 1, C = 2).
    fn create_simple_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        let node = EinsumNode {
            op: OpType::Einsum {
                spec: "ij,jk->ik".to_string(),
            },
            inputs: vec![a, b],
            outputs: vec![c],
            metadata: Default::default(),
        };

        graph.add_node(node).expect("unwrap");
        graph
    }

    #[test]
    fn test_acyclic_graph_no_cycles() {
        let graph = create_simple_graph();
        let cycles = find_cycles(&graph);
        assert!(cycles.is_empty());
    }

    #[test]
    fn test_is_dag() {
        let graph = create_simple_graph();
        assert!(is_dag(&graph));
    }

    #[test]
    fn test_topological_sort() {
        let graph = create_simple_graph();
        let order = topological_sort(&graph).expect("unwrap");
        assert_eq!(order.len(), 3);
        // C depends on both A and B, so it must come last.
        assert_eq!(order.last(), Some(&2));
    }

    #[test]
    fn test_strongly_connected_components() {
        let graph = create_simple_graph();
        let sccs = strongly_connected_components(&graph);
        assert_eq!(sccs.len(), 3);
        // With no cycles, every component is a single tensor.
        assert!(sccs.iter().all(|scc| scc.tensors.len() == 1));
    }

    #[test]
    fn test_graph_diameter() {
        let graph = create_simple_graph();
        // The longest shortest path is the single hop A -> C (or B -> C).
        assert_eq!(graph_diameter(&graph), Some(1));
    }

    #[test]
    fn test_critical_path() {
        let graph = create_simple_graph();
        let weights = HashMap::new();
        let critical = critical_path_analysis(&graph, &weights).expect("unwrap");
        assert_eq!(critical.tensors.len(), 2);
        assert_eq!(critical.tensors.last(), Some(&2));
        assert!((critical.length - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_find_all_paths() {
        let graph = create_simple_graph();
        let paths = find_all_paths(&graph, 0, 2);
        assert_eq!(paths, vec![vec![0, 2]]);
    }

    #[test]
    fn test_find_all_paths_unreachable() {
        let graph = create_simple_graph();
        // C has no successors, so nothing leads from C back to A.
        assert!(find_all_paths(&graph, 2, 0).is_empty());
    }

    #[test]
    fn test_isomorphism_identical_graphs() {
        let g1 = create_simple_graph();
        let g2 = create_simple_graph();

        let result = are_isomorphic(&g1, &g2);
        assert!(matches!(result, IsomorphismResult::Isomorphic { .. }));
    }

    #[test]
    fn test_isomorphism_different_sizes() {
        let g1 = create_simple_graph();
        let mut g2 = EinsumGraph::new();
        g2.add_tensor("A");

        let result = are_isomorphic(&g1, &g2);
        assert_eq!(result, IsomorphismResult::NotIsomorphic);
    }

    #[test]
    fn test_tensor_adjacency() {
        let graph = create_simple_graph();
        let adj = build_tensor_adjacency(&graph);

        assert_eq!(adj[&0], vec![2]);
        assert_eq!(adj[&1], vec![2]);
        // C has no outgoing edges, so it has no entry.
        assert!(!adj.contains_key(&2));
    }

    #[test]
    fn test_degree_sequence() {
        let graph = create_simple_graph();
        let deg_seq = compute_degree_sequence(&graph);
        // A and B are consumed once and never produced; C is produced once.
        assert_eq!(deg_seq, vec![(0, 1), (0, 1), (1, 0)]);
    }

    #[test]
    fn test_bfs_distances() {
        let mut adj = HashMap::new();
        adj.insert(0, vec![1, 2]);
        adj.insert(1, vec![3]);
        adj.insert(2, vec![3]);

        let distances = bfs_distances(&adj, 0);
        assert_eq!(distances[&0], 0);
        assert_eq!(distances[&1], 1);
        assert_eq!(distances[&3], 2);

        // Nothing reaches tensor 0 from tensor 3.
        let from_sink = bfs_distances(&adj, 3);
        assert_eq!(from_sink.len(), 1);
        assert!(!from_sink.contains_key(&0));
    }

    #[test]
    fn test_dfs_find_cycles_on_raw_adjacency() {
        let mut adj = HashMap::new();
        adj.insert(0, vec![1]);
        adj.insert(1, vec![2]);
        adj.insert(2, vec![0]);

        let mut cycles = Vec::new();
        dfs_find_cycles(
            0,
            &adj,
            &mut HashSet::new(),
            &mut HashSet::new(),
            &mut Vec::new(),
            &mut cycles,
        );
        assert_eq!(cycles.len(), 1);
        assert_eq!(cycles[0].tensors, vec![0, 1, 2]);
    }

    #[test]
    fn test_dfs_all_paths_diamond() {
        let mut adj = HashMap::new();
        adj.insert(0, vec![1, 2]);
        adj.insert(1, vec![3]);
        adj.insert(2, vec![3]);

        let mut paths = Vec::new();
        dfs_all_paths(0, 3, &adj, &mut Vec::new(), &mut HashSet::new(), &mut paths);
        assert_eq!(paths, vec![vec![0, 1, 3], vec![0, 2, 3]]);
    }
}