1use std::collections::{HashMap, HashSet, VecDeque};
36
37use super::{EinsumGraph, EinsumNode};
38use crate::error::IrError;
39
40pub fn canonicalize_graph(graph: &EinsumGraph) -> Result<EinsumGraph, IrError> {
50 if graph.is_empty() {
52 return Ok(graph.clone());
53 }
54
55 graph.validate()?;
57
58 let tensor_order = topological_sort_tensors(graph)?;
60
61 let mut tensor_mapping = HashMap::new();
63 for (new_idx, &old_idx) in tensor_order.iter().enumerate() {
64 tensor_mapping.insert(old_idx, new_idx);
65 }
66
67 let mut canonical = EinsumGraph::new();
69
70 for i in 0..tensor_order.len() {
72 canonical.add_tensor(format!("t{}", i));
73 }
74
75 let sorted_nodes = topological_sort_nodes(graph)?;
77 for node_idx in sorted_nodes {
78 let old_node = &graph.nodes[node_idx];
79 let new_node = remap_node(old_node, &tensor_mapping);
80 canonical.add_node(new_node)?;
81 }
82
83 let mut new_inputs: Vec<usize> = graph
85 .inputs
86 .iter()
87 .map(|&idx| *tensor_mapping.get(&idx).unwrap())
88 .collect();
89 new_inputs.sort_unstable();
90 canonical.inputs = new_inputs;
91
92 let mut new_outputs: Vec<usize> = graph
94 .outputs
95 .iter()
96 .map(|&idx| *tensor_mapping.get(&idx).unwrap())
97 .collect();
98 new_outputs.sort_unstable();
99 canonical.outputs = new_outputs;
100
101 Ok(canonical)
102}
103
104fn topological_sort_tensors(graph: &EinsumGraph) -> Result<Vec<usize>, IrError> {
111 let num_tensors = graph.tensors.len();
112
113 let mut producers: HashMap<usize, usize> = HashMap::new(); let mut dependencies: HashMap<usize, Vec<usize>> = HashMap::new(); for (node_idx, node) in graph.nodes.iter().enumerate() {
118 for &output_tensor in &node.outputs {
119 producers.insert(output_tensor, node_idx);
120 dependencies.insert(output_tensor, node.inputs.clone());
121 }
122 }
123
124 let mut result = Vec::new();
126 let mut visited = HashSet::new();
127 let mut processing = HashSet::new();
128
129 fn visit(
131 tensor_idx: usize,
132 dependencies: &HashMap<usize, Vec<usize>>,
133 visited: &mut HashSet<usize>,
134 processing: &mut HashSet<usize>,
135 result: &mut Vec<usize>,
136 ) -> Result<(), IrError> {
137 if visited.contains(&tensor_idx) {
138 return Ok(());
139 }
140 if processing.contains(&tensor_idx) {
141 return Err(IrError::CyclicGraph);
142 }
143
144 processing.insert(tensor_idx);
145
146 if let Some(deps) = dependencies.get(&tensor_idx) {
148 for &dep in deps {
149 visit(dep, dependencies, visited, processing, result)?;
150 }
151 }
152
153 processing.remove(&tensor_idx);
154 visited.insert(tensor_idx);
155 result.push(tensor_idx);
156
157 Ok(())
158 }
159
160 for tensor_idx in 0..num_tensors {
162 if !visited.contains(&tensor_idx) {
163 visit(
164 tensor_idx,
165 &dependencies,
166 &mut visited,
167 &mut processing,
168 &mut result,
169 )?;
170 }
171 }
172
173 Ok(result)
174}
175
176fn topological_sort_nodes(graph: &EinsumGraph) -> Result<Vec<usize>, IrError> {
180 let num_nodes = graph.nodes.len();
181
182 let mut in_degree = vec![0; num_nodes];
184 let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
185
186 let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
188 for (node_idx, node) in graph.nodes.iter().enumerate() {
189 for &output_tensor in &node.outputs {
190 tensor_producers.insert(output_tensor, node_idx);
191 }
192 }
193
194 for (node_idx, node) in graph.nodes.iter().enumerate() {
196 for &input_tensor in &node.inputs {
197 if let Some(&producer_idx) = tensor_producers.get(&input_tensor) {
198 if producer_idx != node_idx {
199 adjacency[producer_idx].push(node_idx);
200 in_degree[node_idx] += 1;
201 }
202 }
203 }
204 }
205
206 let mut queue = VecDeque::new();
208 for (idx, °ree) in in_degree.iter().enumerate() {
209 if degree == 0 {
210 queue.push_back(idx);
211 }
212 }
213
214 let mut result = Vec::new();
215 while let Some(node_idx) = queue.pop_front() {
216 result.push(node_idx);
217
218 for &neighbor in &adjacency[node_idx] {
219 in_degree[neighbor] -= 1;
220 if in_degree[neighbor] == 0 {
221 queue.push_back(neighbor);
222 }
223 }
224 }
225
226 if result.len() != num_nodes {
227 return Err(IrError::CyclicGraph);
228 }
229
230 Ok(result)
231}
232
233fn remap_node(node: &EinsumNode, tensor_mapping: &HashMap<usize, usize>) -> EinsumNode {
235 let new_inputs = node
236 .inputs
237 .iter()
238 .map(|&idx| *tensor_mapping.get(&idx).unwrap())
239 .collect();
240 let new_outputs = node
241 .outputs
242 .iter()
243 .map(|&idx| *tensor_mapping.get(&idx).unwrap())
244 .collect();
245
246 EinsumNode {
247 op: node.op.clone(),
248 inputs: new_inputs,
249 outputs: new_outputs,
250 metadata: node.metadata.clone(),
251 }
252}
253
254pub fn are_graphs_equivalent(g1: &EinsumGraph, g2: &EinsumGraph) -> bool {
259 if g1.tensors.len() != g2.tensors.len()
261 || g1.nodes.len() != g2.nodes.len()
262 || g1.inputs.len() != g2.inputs.len()
263 || g1.outputs.len() != g2.outputs.len()
264 {
265 return false;
266 }
267
268 match (canonicalize_graph(g1), canonicalize_graph(g2)) {
270 (Ok(c1), Ok(c2)) => c1 == c2,
271 _ => false,
272 }
273}
274
275pub fn canonical_hash(graph: &EinsumGraph) -> Result<u64, IrError> {
279 use std::collections::hash_map::DefaultHasher;
280 use std::hash::{Hash, Hasher};
281
282 let canonical = canonicalize_graph(graph)?;
283
284 let mut hasher = DefaultHasher::new();
285
286 canonical.tensors.len().hash(&mut hasher);
288 canonical.nodes.len().hash(&mut hasher);
289 canonical.inputs.len().hash(&mut hasher);
290 canonical.outputs.len().hash(&mut hasher);
291
292 for tensor in &canonical.tensors {
294 tensor.hash(&mut hasher);
295 }
296
297 for node in &canonical.nodes {
299 match &node.op {
301 super::OpType::Einsum { spec } => {
302 "einsum".hash(&mut hasher);
303 spec.hash(&mut hasher);
304 }
305 super::OpType::ElemUnary { op } => {
306 "elem_unary".hash(&mut hasher);
307 op.hash(&mut hasher);
308 }
309 super::OpType::ElemBinary { op } => {
310 "elem_binary".hash(&mut hasher);
311 op.hash(&mut hasher);
312 }
313 super::OpType::Reduce { op, axes } => {
314 "reduce".hash(&mut hasher);
315 op.hash(&mut hasher);
316 axes.hash(&mut hasher);
317 }
318 }
319
320 node.inputs.hash(&mut hasher);
322 node.outputs.hash(&mut hasher);
323 }
324
325 canonical.inputs.hash(&mut hasher);
327 canonical.outputs.hash(&mut hasher);
328
329 Ok(hasher.finish())
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a two-tensor graph holding a single unary node `op(src) -> dst`.
    fn unary_graph(op: &'static str, src: &'static str, dst: &'static str) -> EinsumGraph {
        let mut g = EinsumGraph::new();
        let input = g.add_tensor(src);
        let output = g.add_tensor(dst);
        g.add_node(EinsumNode::elem_unary(op, input, output))
            .unwrap();
        g
    }

    /// Builds a three-tensor graph holding a single binary node
    /// `op(lhs, rhs) -> out`, with `out` registered as a graph output.
    fn binary_graph_with_output(
        op: &'static str,
        lhs: &'static str,
        rhs: &'static str,
        out: &'static str,
    ) -> EinsumGraph {
        let mut g = EinsumGraph::new();
        let left = g.add_tensor(lhs);
        let right = g.add_tensor(rhs);
        let result = g.add_tensor(out);
        g.add_node(EinsumNode::elem_binary(op, left, right, result))
            .unwrap();
        g.add_output(result).unwrap();
        g
    }

    // An empty graph canonicalizes to an empty graph.
    #[test]
    fn test_empty_graph_canonicalization() {
        let empty = EinsumGraph::new();
        let result = canonicalize_graph(&empty).unwrap();
        assert!(result.is_empty());
    }

    // Tensor names are replaced by position-based names; structure is kept.
    #[test]
    fn test_simple_graph_canonicalization() {
        let mut g = EinsumGraph::new();
        let lhs = g.add_tensor("matrix_A");
        let rhs = g.add_tensor("matrix_B");
        let product = g.add_tensor("result");

        let matmul = EinsumNode::einsum("ij,jk->ik", vec![lhs, rhs], vec![product]);
        g.add_node(matmul).unwrap();
        g.add_output(product).unwrap();

        let result = canonicalize_graph(&g).unwrap();

        assert_eq!(result.tensors, vec!["t0", "t1", "t2"]);

        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.outputs.len(), 1);
    }

    // Two graphs that differ only in tensor names share a canonical form.
    #[test]
    fn test_tensor_reordering() {
        let first = binary_graph_with_output("mul", "A", "B", "C");
        let second = binary_graph_with_output("mul", "X", "Y", "Z");

        let canon_first = canonicalize_graph(&first).unwrap();
        let canon_second = canonicalize_graph(&second).unwrap();

        assert_eq!(canon_first, canon_second);
    }

    // Same operation, different tensor names: equivalent.
    #[test]
    fn test_graph_equivalence() {
        let first = unary_graph("neg", "foo", "bar");
        let second = unary_graph("neg", "different", "names");

        assert!(are_graphs_equivalent(&first, &second));
    }

    // Same shape, different operation: not equivalent.
    #[test]
    fn test_non_equivalent_graphs() {
        let negate = unary_graph("neg", "A", "B");
        let root = unary_graph("sqrt", "X", "Y");

        assert!(!are_graphs_equivalent(&negate, &root));
    }

    // Hashing the same graph twice yields the same value.
    #[test]
    fn test_canonical_hash_consistency() {
        let mut g = EinsumGraph::new();
        let src = g.add_tensor("A");
        let dst = g.add_tensor("B");
        // The same tensor is deliberately used for both operands.
        g.add_node(EinsumNode::elem_binary("add", src, src, dst))
            .unwrap();

        let first_hash = canonical_hash(&g).unwrap();
        let second_hash = canonical_hash(&g).unwrap();

        assert_eq!(first_hash, second_hash);
    }

    // Graphs that differ only in tensor names hash identically.
    #[test]
    fn test_equivalent_graphs_same_hash() {
        let first = unary_graph("exp", "foo", "bar");
        let second = unary_graph("exp", "different", "names");

        let first_hash = canonical_hash(&first).unwrap();
        let second_hash = canonical_hash(&second).unwrap();

        assert_eq!(first_hash, second_hash);
    }

    // A three-node chain keeps its tensor/node counts and gets `t{i}` names.
    #[test]
    fn test_complex_graph_canonicalization() {
        let mut g = EinsumGraph::new();
        let in1 = g.add_tensor("input1");
        let in2 = g.add_tensor("input2");
        let mid1 = g.add_tensor("intermediate1");
        let mid2 = g.add_tensor("intermediate2");
        let out = g.add_tensor("output");

        let steps = vec![
            EinsumNode::elem_binary("mul", in1, in2, mid1),
            EinsumNode::elem_unary("sqrt", mid1, mid2),
            EinsumNode::elem_binary("add", mid2, in1, out),
        ];
        for step in steps {
            g.add_node(step).unwrap();
        }
        g.add_output(out).unwrap();

        let result = canonicalize_graph(&g).unwrap();

        assert_eq!(result.tensors.len(), 5);
        assert_eq!(result.nodes.len(), 3);

        for (position, name) in result.tensors.iter().enumerate() {
            assert_eq!(name, &format!("t{}", position));
        }
    }

    // A two-node chain is scheduled producer first.
    #[test]
    fn test_topological_sort_simple() {
        let mut g = EinsumGraph::new();
        let first = g.add_tensor("A");
        let second = g.add_tensor("B");
        let third = g.add_tensor("C");

        g.add_node(EinsumNode::elem_unary("op1", first, second))
            .unwrap();
        g.add_node(EinsumNode::elem_unary("op2", second, third))
            .unwrap();

        let order = topological_sort_nodes(&g).unwrap();

        assert_eq!(order, vec![0, 1]);
    }

    // Input/output counts survive canonicalization and come out sorted.
    #[test]
    fn test_inputs_outputs_preservation() {
        let mut g = EinsumGraph::new();
        let src_a = g.add_tensor("input1");
        let src_b = g.add_tensor("input2");
        let dst_a = g.add_tensor("output1");
        let dst_b = g.add_tensor("output2");

        g.inputs = vec![src_a, src_b];
        g.outputs = vec![dst_a, dst_b];

        g.add_node(EinsumNode::elem_unary("op1", src_a, dst_a))
            .unwrap();
        g.add_node(EinsumNode::elem_unary("op2", src_b, dst_b))
            .unwrap();

        let result = canonicalize_graph(&g).unwrap();

        assert_eq!(result.inputs.len(), 2);
        assert_eq!(result.outputs.len(), 2);

        let mut expected_inputs = result.inputs.clone();
        expected_inputs.sort_unstable();
        assert_eq!(result.inputs, expected_inputs);

        let mut expected_outputs = result.outputs.clone();
        expected_outputs.sort_unstable();
        assert_eq!(result.outputs, expected_outputs);
    }
}