1use crate::graph::EinsumGraph;
44use serde::{Deserialize, Serialize};
45use std::collections::{HashMap, HashSet, VecDeque};
46
47#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
49pub struct Cycle {
50 pub tensors: Vec<usize>,
52 pub nodes: Vec<usize>,
54}
55
56pub fn find_cycles(graph: &EinsumGraph) -> Vec<Cycle> {
61 let mut cycles = Vec::new();
62 let mut visited = HashSet::new();
63 let mut rec_stack = HashSet::new();
64 let mut path = Vec::new();
65
66 let adjacency = build_tensor_adjacency(graph);
68
69 for tensor_idx in 0..graph.tensors.len() {
70 if !visited.contains(&tensor_idx) {
71 dfs_find_cycles(
72 tensor_idx,
73 &adjacency,
74 &mut visited,
75 &mut rec_stack,
76 &mut path,
77 &mut cycles,
78 );
79 }
80 }
81
82 cycles
83}
84
85fn dfs_find_cycles(
87 tensor: usize,
88 adjacency: &HashMap<usize, Vec<usize>>,
89 visited: &mut HashSet<usize>,
90 rec_stack: &mut HashSet<usize>,
91 path: &mut Vec<usize>,
92 cycles: &mut Vec<Cycle>,
93) {
94 visited.insert(tensor);
95 rec_stack.insert(tensor);
96 path.push(tensor);
97
98 if let Some(neighbors) = adjacency.get(&tensor) {
99 for &neighbor in neighbors {
100 if !visited.contains(&neighbor) {
101 dfs_find_cycles(neighbor, adjacency, visited, rec_stack, path, cycles);
102 } else if rec_stack.contains(&neighbor) {
103 if let Some(cycle_start) = path.iter().position(|&t| t == neighbor) {
105 let cycle_tensors = path[cycle_start..].to_vec();
106 cycles.push(Cycle {
107 tensors: cycle_tensors,
108 nodes: Vec::new(), });
110 }
111 }
112 }
113 }
114
115 path.pop();
116 rec_stack.remove(&tensor);
117}
118
119fn build_tensor_adjacency(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
121 let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
122
123 for node in &graph.nodes {
124 for &input_tensor in &node.inputs {
125 for &output_tensor in &node.outputs {
126 adjacency
127 .entry(input_tensor)
128 .or_default()
129 .push(output_tensor);
130 }
131 }
132 }
133
134 adjacency
135}
136
137#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
139pub struct StronglyConnectedComponent {
140 pub tensors: Vec<usize>,
142 pub nodes: Vec<usize>,
144}
145
146pub fn strongly_connected_components(graph: &EinsumGraph) -> Vec<StronglyConnectedComponent> {
151 let mut tarjan = TarjanSCC::new(graph);
152 tarjan.find_sccs();
153 tarjan.sccs
154}
155
156struct TarjanSCC<'a> {
158 graph: &'a EinsumGraph,
159 adjacency: HashMap<usize, Vec<usize>>,
160 index: usize,
161 indices: HashMap<usize, usize>,
162 lowlinks: HashMap<usize, usize>,
163 on_stack: HashSet<usize>,
164 stack: Vec<usize>,
165 sccs: Vec<StronglyConnectedComponent>,
166}
167
168impl<'a> TarjanSCC<'a> {
169 fn new(graph: &'a EinsumGraph) -> Self {
170 TarjanSCC {
171 graph,
172 adjacency: build_tensor_adjacency(graph),
173 index: 0,
174 indices: HashMap::new(),
175 lowlinks: HashMap::new(),
176 on_stack: HashSet::new(),
177 stack: Vec::new(),
178 sccs: Vec::new(),
179 }
180 }
181
182 fn find_sccs(&mut self) {
183 for tensor_idx in 0..self.graph.tensors.len() {
184 if !self.indices.contains_key(&tensor_idx) {
185 self.strong_connect(tensor_idx);
186 }
187 }
188 }
189
190 fn strong_connect(&mut self, v: usize) {
191 self.indices.insert(v, self.index);
192 self.lowlinks.insert(v, self.index);
193 self.index += 1;
194 self.stack.push(v);
195 self.on_stack.insert(v);
196
197 if let Some(neighbors) = self.adjacency.get(&v).cloned() {
198 for w in neighbors {
199 if !self.indices.contains_key(&w) {
200 self.strong_connect(w);
201 let w_lowlink = *self.lowlinks.get(&w).unwrap();
202 let v_lowlink = *self.lowlinks.get(&v).unwrap();
203 self.lowlinks.insert(v, v_lowlink.min(w_lowlink));
204 } else if self.on_stack.contains(&w) {
205 let w_index = *self.indices.get(&w).unwrap();
206 let v_lowlink = *self.lowlinks.get(&v).unwrap();
207 self.lowlinks.insert(v, v_lowlink.min(w_index));
208 }
209 }
210 }
211
212 if self.lowlinks[&v] == self.indices[&v] {
214 let mut scc_tensors = Vec::new();
215 loop {
216 let w = self.stack.pop().unwrap();
217 self.on_stack.remove(&w);
218 scc_tensors.push(w);
219 if w == v {
220 break;
221 }
222 }
223 self.sccs.push(StronglyConnectedComponent {
224 tensors: scc_tensors,
225 nodes: Vec::new(),
226 });
227 }
228 }
229}
230
231pub fn topological_sort(graph: &EinsumGraph) -> Option<Vec<usize>> {
236 let adjacency = build_tensor_adjacency(graph);
237 let mut in_degree = vec![0; graph.tensors.len()];
238
239 for neighbors in adjacency.values() {
241 for &neighbor in neighbors {
242 in_degree[neighbor] += 1;
243 }
244 }
245
246 let mut queue: VecDeque<usize> = in_degree
248 .iter()
249 .enumerate()
250 .filter(|(_, °)| deg == 0)
251 .map(|(idx, _)| idx)
252 .collect();
253
254 let mut result = Vec::new();
255
256 while let Some(tensor) = queue.pop_front() {
257 result.push(tensor);
258
259 if let Some(neighbors) = adjacency.get(&tensor) {
260 for &neighbor in neighbors {
261 in_degree[neighbor] -= 1;
262 if in_degree[neighbor] == 0 {
263 queue.push_back(neighbor);
264 }
265 }
266 }
267 }
268
269 if result.len() == graph.tensors.len() {
271 Some(result)
272 } else {
273 None
274 }
275}
276
277pub fn is_dag(graph: &EinsumGraph) -> bool {
279 topological_sort(graph).is_some()
280}
281
/// Outcome of comparing two graphs' tensor structure with `are_isomorphic`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IsomorphismResult {
    /// The graphs match. `mapping` sends each tensor index of the first graph to
    /// its counterpart in the second.
    Isomorphic { mapping: HashMap<usize, usize> },
    /// No structure-preserving tensor mapping exists.
    NotIsomorphic,
}
290
291pub fn are_isomorphic(g1: &EinsumGraph, g2: &EinsumGraph) -> IsomorphismResult {
296 if g1.tensors.len() != g2.tensors.len() || g1.nodes.len() != g2.nodes.len() {
298 return IsomorphismResult::NotIsomorphic;
299 }
300
301 let deg1 = compute_degree_sequence(g1);
303 let deg2 = compute_degree_sequence(g2);
304
305 if deg1 != deg2 {
306 return IsomorphismResult::NotIsomorphic;
307 }
308
309 let mut mapping = HashMap::new();
313 if backtrack_isomorphism(g1, g2, &mut mapping, 0) {
314 IsomorphismResult::Isomorphic { mapping }
315 } else {
316 IsomorphismResult::NotIsomorphic
317 }
318}
319
320fn compute_degree_sequence(graph: &EinsumGraph) -> Vec<(usize, usize)> {
322 let mut in_degrees = vec![0; graph.tensors.len()];
323 let mut out_degrees = vec![0; graph.tensors.len()];
324
325 for node in &graph.nodes {
326 for &input in &node.inputs {
327 out_degrees[input] += 1;
328 }
329 for &output in &node.outputs {
330 in_degrees[output] += 1;
331 }
332 }
333
334 let mut degrees: Vec<(usize, usize)> = in_degrees.into_iter().zip(out_degrees).collect();
335
336 degrees.sort_unstable();
337 degrees
338}
339
340fn backtrack_isomorphism(
342 g1: &EinsumGraph,
343 g2: &EinsumGraph,
344 mapping: &mut HashMap<usize, usize>,
345 tensor_idx: usize,
346) -> bool {
347 if tensor_idx >= g1.tensors.len() {
349 return verify_isomorphism(g1, g2, mapping);
350 }
351
352 let mapped_values: HashSet<usize> = mapping.values().copied().collect();
354
355 for candidate in 0..g2.tensors.len() {
356 if !mapped_values.contains(&candidate) {
357 mapping.insert(tensor_idx, candidate);
358
359 if backtrack_isomorphism(g1, g2, mapping, tensor_idx + 1) {
360 return true;
361 }
362
363 mapping.remove(&tensor_idx);
364 }
365 }
366
367 false
368}
369
370fn verify_isomorphism(g1: &EinsumGraph, g2: &EinsumGraph, mapping: &HashMap<usize, usize>) -> bool {
372 let adj1 = build_tensor_adjacency(g1);
374 let adj2 = build_tensor_adjacency(g2);
375
376 for (u, neighbors) in &adj1 {
377 let u_mapped = mapping[u];
378
379 for &v in neighbors {
380 let v_mapped = mapping[&v];
381
382 if let Some(adj2_neighbors) = adj2.get(&u_mapped) {
384 if !adj2_neighbors.contains(&v_mapped) {
385 return false;
386 }
387 } else {
388 return false;
389 }
390 }
391 }
392
393 true
394}
395
396#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
398pub struct CriticalPath {
399 pub tensors: Vec<usize>,
401 pub nodes: Vec<usize>,
403 pub length: f64,
405}
406
407pub fn critical_path_analysis(
412 graph: &EinsumGraph,
413 weights: &HashMap<usize, f64>,
414) -> Option<CriticalPath> {
415 if !is_dag(graph) {
416 return None; }
418
419 let topo_order = topological_sort(graph)?;
420 let adjacency = build_tensor_adjacency(graph);
421
422 let mut distances: HashMap<usize, f64> = HashMap::new();
423 let mut predecessors: HashMap<usize, usize> = HashMap::new();
424
425 for &tensor in &topo_order {
427 distances.insert(tensor, 0.0);
428 }
429
430 for &u in &topo_order {
432 if let Some(neighbors) = adjacency.get(&u) {
433 let u_dist = distances[&u];
434
435 for &v in neighbors {
436 let weight = weights.get(&v).copied().unwrap_or(1.0);
437 let new_dist = u_dist + weight;
438
439 if new_dist > *distances.get(&v).unwrap_or(&0.0) {
440 distances.insert(v, new_dist);
441 predecessors.insert(v, u);
442 }
443 }
444 }
445 }
446
447 let (&end_tensor, &max_dist) = distances
449 .iter()
450 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())?;
451
452 let mut path = Vec::new();
454 let mut current = end_tensor;
455
456 loop {
457 path.push(current);
458 if let Some(&pred) = predecessors.get(¤t) {
459 current = pred;
460 } else {
461 break;
462 }
463 }
464
465 path.reverse();
466
467 Some(CriticalPath {
468 tensors: path,
469 nodes: Vec::new(),
470 length: max_dist,
471 })
472}
473
474pub fn graph_diameter(graph: &EinsumGraph) -> Option<usize> {
476 let adjacency = build_tensor_adjacency(graph);
477 let mut max_distance = 0;
478
479 for start in 0..graph.tensors.len() {
481 let distances = bfs_distances(&adjacency, start);
482 if let Some(&max) = distances.values().max() {
483 max_distance = max_distance.max(max);
484 }
485 }
486
487 Some(max_distance)
488}
489
/// Breadth-first search from `source` over `adjacency`.
///
/// Returns the unweighted hop count to every tensor reachable from `source`,
/// including `source` itself at distance 0. Unreachable tensors are absent from
/// the returned map.
fn bfs_distances(adjacency: &HashMap<usize, Vec<usize>>, source: usize) -> HashMap<usize, usize> {
    use std::collections::hash_map::Entry;

    let mut hops: HashMap<usize, usize> = HashMap::new();
    hops.insert(source, 0);
    let mut frontier = VecDeque::new();
    frontier.push_back(source);

    while let Some(node) = frontier.pop_front() {
        let next_hop = hops[&node] + 1;
        for &succ in adjacency.get(&node).into_iter().flatten() {
            // The first time a tensor is reached is along a shortest path.
            if let Entry::Vacant(slot) = hops.entry(succ) {
                slot.insert(next_hop);
                frontier.push_back(succ);
            }
        }
    }

    hops
}
513
514pub fn find_all_paths(graph: &EinsumGraph, from: usize, to: usize) -> Vec<Vec<usize>> {
516 let adjacency = build_tensor_adjacency(graph);
517 let mut paths = Vec::new();
518 let mut current_path = Vec::new();
519 let mut visited = HashSet::new();
520
521 dfs_all_paths(
522 from,
523 to,
524 &adjacency,
525 &mut current_path,
526 &mut visited,
527 &mut paths,
528 );
529
530 paths
531}
532
/// Recursive helper for `find_all_paths`.
///
/// Appends `current` to `path` and marks it in `visited`. If `current` is the
/// target, a copy of `path` is pushed to `paths` and the search does not go
/// further. Otherwise every successor not already on the path is explored.
/// Successors are used exactly as listed, so duplicate adjacency entries yield
/// duplicate paths.
///
/// On return, `current` has been removed from both `path` and `visited`.
fn dfs_all_paths(
    current: usize,
    target: usize,
    adjacency: &HashMap<usize, Vec<usize>>,
    path: &mut Vec<usize>,
    visited: &mut HashSet<usize>,
    paths: &mut Vec<Vec<usize>>,
) {
    visited.insert(current);
    path.push(current);

    if current == target {
        paths.push(path.to_vec());
    } else {
        for &next in adjacency.get(&current).into_iter().flatten() {
            // Skipping tensors already on the path keeps every path simple.
            if visited.contains(&next) {
                continue;
            }
            dfs_all_paths(next, target, adjacency, path, visited, paths);
        }
    }

    visited.remove(&current);
    path.pop();
}
558
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{EinsumNode, OpType};

    /// Builds `C = einsum("ij,jk->ik", A, B)`, giving tensor edges A -> C and B -> C.
    fn create_simple_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        let node = EinsumNode {
            op: OpType::Einsum {
                spec: "ij,jk->ik".to_string(),
            },
            inputs: vec![a, b],
            outputs: vec![c],
            metadata: Default::default(),
        };

        graph.add_node(node).unwrap();
        graph
    }

    #[test]
    fn test_acyclic_graph_no_cycles() {
        let graph = create_simple_graph();
        assert!(find_cycles(&graph).is_empty());
    }

    #[test]
    fn test_is_dag() {
        let graph = create_simple_graph();
        assert!(is_dag(&graph));
    }

    #[test]
    fn test_topological_sort() {
        let graph = create_simple_graph();
        // Both sources are ready immediately and emitted in index order; C is last.
        assert_eq!(topological_sort(&graph), Some(vec![0, 1, 2]));
    }

    #[test]
    fn test_strongly_connected_components() {
        let graph = create_simple_graph();
        let sccs = strongly_connected_components(&graph);
        assert_eq!(sccs.len(), 3);
        // In a DAG every component is a single tensor, and each tensor appears once.
        assert!(sccs.iter().all(|scc| scc.tensors.len() == 1));
        let mut members: Vec<usize> = sccs
            .iter()
            .flat_map(|scc| scc.tensors.iter().copied())
            .collect();
        members.sort_unstable();
        assert_eq!(members, vec![0, 1, 2]);
    }

    #[test]
    fn test_graph_diameter() {
        let graph = create_simple_graph();
        // The longest shortest path is a single hop (A -> C or B -> C).
        assert_eq!(graph_diameter(&graph), Some(1));
    }

    #[test]
    fn test_critical_path() {
        let graph = create_simple_graph();
        // No explicit weights: each edge contributes the default weight 1.0.
        let weights = HashMap::new();
        let critical =
            critical_path_analysis(&graph, &weights).expect("simple graph is a DAG");
        assert_eq!(critical.tensors, vec![0, 2]);
        assert!((critical.length - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_find_all_paths() {
        let graph = create_simple_graph();
        assert_eq!(find_all_paths(&graph, 0, 2), vec![vec![0, 2]]);
        // Nothing flows back into a source tensor.
        assert!(find_all_paths(&graph, 2, 0).is_empty());
    }

    #[test]
    fn test_isomorphism_identical_graphs() {
        let g1 = create_simple_graph();
        let g2 = create_simple_graph();

        match are_isomorphic(&g1, &g2) {
            IsomorphismResult::Isomorphic { mapping } => {
                assert_eq!(mapping.len(), 3);
                // C is the only tensor with incoming edges, so it must map to itself.
                assert_eq!(mapping[&2], 2);
            }
            IsomorphismResult::NotIsomorphic => {
                panic!("identical graphs must be isomorphic")
            }
        }
    }

    #[test]
    fn test_isomorphism_different_sizes() {
        let g1 = create_simple_graph();
        let mut g2 = EinsumGraph::new();
        g2.add_tensor("A");

        let result = are_isomorphic(&g1, &g2);
        assert_eq!(result, IsomorphismResult::NotIsomorphic);
    }

    #[test]
    fn test_tensor_adjacency() {
        let graph = create_simple_graph();
        let adj = build_tensor_adjacency(&graph);

        assert_eq!(adj.get(&0), Some(&vec![2]));
        assert_eq!(adj.get(&1), Some(&vec![2]));
        // C feeds nothing, so it has no adjacency entry.
        assert!(!adj.contains_key(&2));
    }

    #[test]
    fn test_degree_sequence() {
        let graph = create_simple_graph();
        // Sorted (in, out) pairs: A and B each feed one node; C is produced once.
        assert_eq!(
            compute_degree_sequence(&graph),
            vec![(0, 1), (0, 1), (1, 0)]
        );
    }

    #[test]
    fn test_bfs_distances() {
        let mut adj = HashMap::new();
        adj.insert(0, vec![1, 2]);
        adj.insert(1, vec![3]);
        adj.insert(2, vec![3]);

        let distances = bfs_distances(&adj, 0);
        assert_eq!(distances.len(), 4);
        assert_eq!(distances[&0], 0);
        assert_eq!(distances[&1], 1);
        assert_eq!(distances[&2], 1);
        assert_eq!(distances[&3], 2);

        // Nothing is reachable from a sink except itself.
        let from_sink = bfs_distances(&adj, 3);
        assert_eq!(from_sink.len(), 1);
        assert_eq!(from_sink[&3], 0);
    }
}