Skip to main content

tensorlogic_ir/graph/
canonicalization.rs

//! Graph canonicalization for deterministic comparison and hashing.
//!
//! This module converts computation graphs into a canonical form, which is useful for:
//! - Graph equality comparison
//! - Duplicate graph detection
//! - Common subexpression elimination
//! - Graph hashing and caching
//!
//! # Algorithm
//!
//! The canonicalization process:
//! 1. Compute a topological ordering of tensors
//! 2. Assign canonical names (t0, t1, t2, ...) based on that order
//! 3. Sort nodes in execution order
//! 4. Normalize inputs and outputs (remap them to canonical indices and sort them)
//!
//! Ties between tensors or nodes that do not depend on each other are broken by
//! their original indices. The canonical form therefore ignores tensor names, but it
//! is not independent of the order in which unrelated parts of a graph were built.
//!
//! # Examples
//!
//! ```
//! use tensorlogic_ir::{EinsumGraph, EinsumNode};
//!
//! let mut graph = EinsumGraph::new();
//! let a = graph.add_tensor("foo");
//! let b = graph.add_tensor("bar");
//! let c = graph.add_tensor("baz");
//! graph.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c])).unwrap();
//! graph.add_output(c).unwrap();
//!
//! let canonical = tensorlogic_ir::canonicalize_graph(&graph).unwrap();
//! // Tensors are renamed to t0, t1, t2
//! assert_eq!(canonical.tensors, vec!["t0", "t1", "t2"]);
//! ```
34
35use std::collections::{HashMap, HashSet, VecDeque};
36
37use super::{EinsumGraph, EinsumNode};
38use crate::error::IrError;
39
40/// Canonicalize a computation graph.
41///
42/// This function converts a graph into a canonical form where:
43/// - Tensors are renamed to t0, t1, t2, ... in topological order
44/// - Nodes are sorted in execution order
45/// - Inputs and outputs are sorted consistently
46///
47/// The resulting graph is semantically equivalent to the original but has a
48/// normalized structure that facilitates comparison and hashing.
49pub fn canonicalize_graph(graph: &EinsumGraph) -> Result<EinsumGraph, IrError> {
50    // Empty graph is already canonical
51    if graph.is_empty() {
52        return Ok(graph.clone());
53    }
54
55    // Validate the input graph
56    graph.validate()?;
57
58    // Step 1: Compute topological order of tensors
59    let tensor_order = topological_sort_tensors(graph)?;
60
61    // Step 2: Create mapping from old indices to new canonical indices
62    let mut tensor_mapping = HashMap::new();
63    for (new_idx, &old_idx) in tensor_order.iter().enumerate() {
64        tensor_mapping.insert(old_idx, new_idx);
65    }
66
67    // Step 3: Build canonical graph
68    let mut canonical = EinsumGraph::new();
69
70    // Add tensors with canonical names
71    for i in 0..tensor_order.len() {
72        canonical.add_tensor(format!("t{}", i));
73    }
74
75    // Step 4: Remap and add nodes in topological order
76    let sorted_nodes = topological_sort_nodes(graph)?;
77    for node_idx in sorted_nodes {
78        let old_node = &graph.nodes[node_idx];
79        let new_node = remap_node(old_node, &tensor_mapping);
80        canonical.add_node(new_node)?;
81    }
82
83    // Step 5: Remap and sort inputs
84    let mut new_inputs: Vec<usize> = graph
85        .inputs
86        .iter()
87        .map(|&idx| *tensor_mapping.get(&idx).unwrap())
88        .collect();
89    new_inputs.sort_unstable();
90    canonical.inputs = new_inputs;
91
92    // Step 6: Remap and sort outputs
93    let mut new_outputs: Vec<usize> = graph
94        .outputs
95        .iter()
96        .map(|&idx| *tensor_mapping.get(&idx).unwrap())
97        .collect();
98    new_outputs.sort_unstable();
99    canonical.outputs = new_outputs;
100
101    Ok(canonical)
102}
103
104/// Compute topological ordering of tensors.
105///
106/// Returns a vector of tensor indices in topological order, where:
107/// - Input tensors come first
108/// - Intermediate tensors are ordered by their producers
109/// - Unused tensors come last
110fn topological_sort_tensors(graph: &EinsumGraph) -> Result<Vec<usize>, IrError> {
111    let num_tensors = graph.tensors.len();
112
113    // Track tensor dependencies: which tensors are used to produce each tensor
114    let mut producers: HashMap<usize, usize> = HashMap::new(); // tensor -> node that produces it
115    let mut dependencies: HashMap<usize, Vec<usize>> = HashMap::new(); // tensor -> input tensors
116
117    for (node_idx, node) in graph.nodes.iter().enumerate() {
118        for &output_tensor in &node.outputs {
119            producers.insert(output_tensor, node_idx);
120            dependencies.insert(output_tensor, node.inputs.clone());
121        }
122    }
123
124    // Tensors with no producers are inputs or constants
125    let mut result = Vec::new();
126    let mut visited = HashSet::new();
127    let mut processing = HashSet::new();
128
129    // Helper function for DFS traversal
130    fn visit(
131        tensor_idx: usize,
132        dependencies: &HashMap<usize, Vec<usize>>,
133        visited: &mut HashSet<usize>,
134        processing: &mut HashSet<usize>,
135        result: &mut Vec<usize>,
136    ) -> Result<(), IrError> {
137        if visited.contains(&tensor_idx) {
138            return Ok(());
139        }
140        if processing.contains(&tensor_idx) {
141            return Err(IrError::CyclicGraph);
142        }
143
144        processing.insert(tensor_idx);
145
146        // Visit dependencies first
147        if let Some(deps) = dependencies.get(&tensor_idx) {
148            for &dep in deps {
149                visit(dep, dependencies, visited, processing, result)?;
150            }
151        }
152
153        processing.remove(&tensor_idx);
154        visited.insert(tensor_idx);
155        result.push(tensor_idx);
156
157        Ok(())
158    }
159
160    // Process all tensors
161    for tensor_idx in 0..num_tensors {
162        if !visited.contains(&tensor_idx) {
163            visit(
164                tensor_idx,
165                &dependencies,
166                &mut visited,
167                &mut processing,
168                &mut result,
169            )?;
170        }
171    }
172
173    Ok(result)
174}
175
176/// Compute topological ordering of nodes.
177///
178/// Returns a vector of node indices in execution order.
179fn topological_sort_nodes(graph: &EinsumGraph) -> Result<Vec<usize>, IrError> {
180    let num_nodes = graph.nodes.len();
181
182    // Build dependency graph: which nodes must execute before others
183    let mut in_degree = vec![0; num_nodes];
184    let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
185
186    // Track which node produces each tensor
187    let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
188    for (node_idx, node) in graph.nodes.iter().enumerate() {
189        for &output_tensor in &node.outputs {
190            tensor_producers.insert(output_tensor, node_idx);
191        }
192    }
193
194    // Build edges: if node B uses a tensor produced by node A, then A -> B
195    for (node_idx, node) in graph.nodes.iter().enumerate() {
196        for &input_tensor in &node.inputs {
197            if let Some(&producer_idx) = tensor_producers.get(&input_tensor) {
198                if producer_idx != node_idx {
199                    adjacency[producer_idx].push(node_idx);
200                    in_degree[node_idx] += 1;
201                }
202            }
203        }
204    }
205
206    // Kahn's algorithm for topological sort
207    let mut queue = VecDeque::new();
208    for (idx, &degree) in in_degree.iter().enumerate() {
209        if degree == 0 {
210            queue.push_back(idx);
211        }
212    }
213
214    let mut result = Vec::new();
215    while let Some(node_idx) = queue.pop_front() {
216        result.push(node_idx);
217
218        for &neighbor in &adjacency[node_idx] {
219            in_degree[neighbor] -= 1;
220            if in_degree[neighbor] == 0 {
221                queue.push_back(neighbor);
222            }
223        }
224    }
225
226    if result.len() != num_nodes {
227        return Err(IrError::CyclicGraph);
228    }
229
230    Ok(result)
231}
232
233/// Remap a node's tensor indices using the provided mapping.
234fn remap_node(node: &EinsumNode, tensor_mapping: &HashMap<usize, usize>) -> EinsumNode {
235    let new_inputs = node
236        .inputs
237        .iter()
238        .map(|&idx| *tensor_mapping.get(&idx).unwrap())
239        .collect();
240    let new_outputs = node
241        .outputs
242        .iter()
243        .map(|&idx| *tensor_mapping.get(&idx).unwrap())
244        .collect();
245
246    EinsumNode {
247        op: node.op.clone(),
248        inputs: new_inputs,
249        outputs: new_outputs,
250        metadata: node.metadata.clone(),
251    }
252}
253
254/// Check if two graphs are canonically equivalent.
255///
256/// This is more efficient than canonicalizing both graphs and comparing,
257/// as it can short-circuit on basic structural differences.
258pub fn are_graphs_equivalent(g1: &EinsumGraph, g2: &EinsumGraph) -> bool {
259    // Quick structural checks
260    if g1.tensors.len() != g2.tensors.len()
261        || g1.nodes.len() != g2.nodes.len()
262        || g1.inputs.len() != g2.inputs.len()
263        || g1.outputs.len() != g2.outputs.len()
264    {
265        return false;
266    }
267
268    // Canonicalize and compare
269    match (canonicalize_graph(g1), canonicalize_graph(g2)) {
270        (Ok(c1), Ok(c2)) => c1 == c2,
271        _ => false,
272    }
273}
274
275/// Compute a hash of a graph in canonical form.
276///
277/// This can be used for efficient graph deduplication and caching.
278pub fn canonical_hash(graph: &EinsumGraph) -> Result<u64, IrError> {
279    use std::collections::hash_map::DefaultHasher;
280    use std::hash::{Hash, Hasher};
281
282    let canonical = canonicalize_graph(graph)?;
283
284    let mut hasher = DefaultHasher::new();
285
286    // Hash the structure
287    canonical.tensors.len().hash(&mut hasher);
288    canonical.nodes.len().hash(&mut hasher);
289    canonical.inputs.len().hash(&mut hasher);
290    canonical.outputs.len().hash(&mut hasher);
291
292    // Hash tensor names (should all be t0, t1, t2, ... but hash anyway)
293    for tensor in &canonical.tensors {
294        tensor.hash(&mut hasher);
295    }
296
297    // Hash nodes
298    for node in &canonical.nodes {
299        // Hash operation type
300        match &node.op {
301            super::OpType::Einsum { spec } => {
302                "einsum".hash(&mut hasher);
303                spec.hash(&mut hasher);
304            }
305            super::OpType::ElemUnary { op } => {
306                "elem_unary".hash(&mut hasher);
307                op.hash(&mut hasher);
308            }
309            super::OpType::ElemBinary { op } => {
310                "elem_binary".hash(&mut hasher);
311                op.hash(&mut hasher);
312            }
313            super::OpType::Reduce { op, axes } => {
314                "reduce".hash(&mut hasher);
315                op.hash(&mut hasher);
316                axes.hash(&mut hasher);
317            }
318        }
319
320        // Hash inputs and outputs
321        node.inputs.hash(&mut hasher);
322        node.outputs.hash(&mut hasher);
323    }
324
325    // Hash inputs and outputs
326    canonical.inputs.hash(&mut hasher);
327    canonical.outputs.hash(&mut hasher);
328
329    Ok(hasher.finish())
330}
331
#[cfg(test)]
mod tests {
    // Unit tests for graph canonicalization, equivalence and hashing.
    use super::*;

    #[test]
    fn test_empty_graph_canonicalization() {
        let graph = EinsumGraph::new();
        let canonical = canonicalize_graph(&graph).unwrap();
        assert!(canonical.is_empty());
    }

    #[test]
    fn test_simple_graph_canonicalization() {
        // Build a simple graph: A @ B = C
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("matrix_A");
        let b = graph.add_tensor("matrix_B");
        let c = graph.add_tensor("result");

        graph
            .add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c]))
            .unwrap();
        graph.add_output(c).unwrap();

        let canonical = canonicalize_graph(&graph).unwrap();

        // Check tensor names are canonical
        assert_eq!(canonical.tensors, vec!["t0", "t1", "t2"]);

        // Check structure is preserved
        assert_eq!(canonical.nodes.len(), 1);
        assert_eq!(canonical.outputs.len(), 1);
    }

    #[test]
    fn test_tensor_reordering() {
        // Build two graphs with different tensor orderings but same computation
        let mut g1 = EinsumGraph::new();
        let a1 = g1.add_tensor("A");
        let b1 = g1.add_tensor("B");
        let c1 = g1.add_tensor("C");
        g1.add_node(EinsumNode::elem_binary("mul", a1, b1, c1))
            .unwrap();
        g1.add_output(c1).unwrap();

        let mut g2 = EinsumGraph::new();
        let x2 = g2.add_tensor("X");
        let y2 = g2.add_tensor("Y");
        let z2 = g2.add_tensor("Z");
        g2.add_node(EinsumNode::elem_binary("mul", x2, y2, z2))
            .unwrap();
        g2.add_output(z2).unwrap();

        // Both should canonicalize to the same structure
        let c1 = canonicalize_graph(&g1).unwrap();
        let c2 = canonicalize_graph(&g2).unwrap();

        assert_eq!(c1, c2);
    }

    #[test]
    fn test_graph_equivalence() {
        let mut g1 = EinsumGraph::new();
        let a = g1.add_tensor("foo");
        let b = g1.add_tensor("bar");
        g1.add_node(EinsumNode::elem_unary("neg", a, b)).unwrap();

        let mut g2 = EinsumGraph::new();
        let x = g2.add_tensor("different");
        let y = g2.add_tensor("names");
        g2.add_node(EinsumNode::elem_unary("neg", x, y)).unwrap();

        assert!(are_graphs_equivalent(&g1, &g2));
    }

    #[test]
    fn test_non_equivalent_graphs() {
        let mut g1 = EinsumGraph::new();
        let a = g1.add_tensor("A");
        let b = g1.add_tensor("B");
        g1.add_node(EinsumNode::elem_unary("neg", a, b)).unwrap();

        let mut g2 = EinsumGraph::new();
        let x = g2.add_tensor("X");
        let y = g2.add_tensor("Y");
        g2.add_node(EinsumNode::elem_unary("sqrt", x, y)).unwrap();

        assert!(!are_graphs_equivalent(&g1, &g2));
    }

    #[test]
    fn test_canonical_hash_consistency() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_binary("add", a, a, b))
            .unwrap();

        let hash1 = canonical_hash(&graph).unwrap();
        let hash2 = canonical_hash(&graph).unwrap();

        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_equivalent_graphs_same_hash() {
        let mut g1 = EinsumGraph::new();
        let a1 = g1.add_tensor("foo");
        let b1 = g1.add_tensor("bar");
        g1.add_node(EinsumNode::elem_unary("exp", a1, b1)).unwrap();

        let mut g2 = EinsumGraph::new();
        let a2 = g2.add_tensor("different");
        let b2 = g2.add_tensor("names");
        g2.add_node(EinsumNode::elem_unary("exp", a2, b2)).unwrap();

        let hash1 = canonical_hash(&g1).unwrap();
        let hash2 = canonical_hash(&g2).unwrap();

        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_complex_graph_canonicalization() {
        // Build a multi-node graph
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("input1");
        let b = graph.add_tensor("input2");
        let c = graph.add_tensor("intermediate1");
        let d = graph.add_tensor("intermediate2");
        let e = graph.add_tensor("output");

        graph
            .add_node(EinsumNode::elem_binary("mul", a, b, c))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("sqrt", c, d))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_binary("add", d, a, e))
            .unwrap();
        graph.add_output(e).unwrap();

        let canonical = canonicalize_graph(&graph).unwrap();

        // Verify canonicalization worked
        assert_eq!(canonical.tensors.len(), 5);
        assert_eq!(canonical.nodes.len(), 3);

        // Verify all tensor names are canonical
        for (i, name) in canonical.tensors.iter().enumerate() {
            assert_eq!(name, &format!("t{}", i));
        }
    }

    #[test]
    fn test_topological_sort_simple() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        // A -> B -> C
        graph.add_node(EinsumNode::elem_unary("op1", a, b)).unwrap();
        graph.add_node(EinsumNode::elem_unary("op2", b, c)).unwrap();

        let node_order = topological_sort_nodes(&graph).unwrap();

        // First node should come before second node
        assert_eq!(node_order, vec![0, 1]);
    }

    #[test]
    fn test_inputs_outputs_preservation() {
        let mut graph = EinsumGraph::new();
        let in1 = graph.add_tensor("input1");
        let in2 = graph.add_tensor("input2");
        let out1 = graph.add_tensor("output1");
        let out2 = graph.add_tensor("output2");

        graph.inputs = vec![in1, in2];
        graph.outputs = vec![out1, out2];

        graph
            .add_node(EinsumNode::elem_unary("op1", in1, out1))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("op2", in2, out2))
            .unwrap();

        let canonical = canonicalize_graph(&graph).unwrap();

        // Inputs and outputs should be preserved (but sorted)
        assert_eq!(canonical.inputs.len(), 2);
        assert_eq!(canonical.outputs.len(), 2);

        // They should be sorted
        let mut sorted_inputs = canonical.inputs.clone();
        sorted_inputs.sort_unstable();
        assert_eq!(canonical.inputs, sorted_inputs);

        let mut sorted_outputs = canonical.outputs.clone();
        sorted_outputs.sort_unstable();
        assert_eq!(canonical.outputs, sorted_outputs);
    }

    #[test]
    fn test_declaration_order_independence() {
        // Same computation (C = A * B); in g2 the result tensor is declared first,
        // so the original tensor indices differ between the two graphs.
        let mut g1 = EinsumGraph::new();
        let a = g1.add_tensor("A");
        let b = g1.add_tensor("B");
        let c = g1.add_tensor("C");
        g1.add_node(EinsumNode::elem_binary("mul", a, b, c)).unwrap();
        g1.add_output(c).unwrap();

        let mut g2 = EinsumGraph::new();
        let z = g2.add_tensor("Z");
        let x = g2.add_tensor("X");
        let y = g2.add_tensor("Y");
        g2.add_node(EinsumNode::elem_binary("mul", x, y, z)).unwrap();
        g2.add_output(z).unwrap();

        assert_eq!(
            canonicalize_graph(&g1).unwrap(),
            canonicalize_graph(&g2).unwrap()
        );
        assert!(are_graphs_equivalent(&g1, &g2));
        assert_eq!(canonical_hash(&g1).unwrap(), canonical_hash(&g2).unwrap());
    }

    #[test]
    fn test_dependencies_get_lower_canonical_indices() {
        // Declared output-first: out = exp(mid), mid = neg(inp).
        let mut graph = EinsumGraph::new();
        let out = graph.add_tensor("out");
        let mid = graph.add_tensor("mid");
        let inp = graph.add_tensor("inp");
        graph.add_node(EinsumNode::elem_unary("neg", inp, mid)).unwrap();
        graph.add_node(EinsumNode::elem_unary("exp", mid, out)).unwrap();
        graph.add_output(out).unwrap();

        let canonical = canonicalize_graph(&graph).unwrap();

        // Dependencies are numbered before dependents: inp -> t0, mid -> t1, out -> t2.
        assert_eq!(canonical.tensors, vec!["t0", "t1", "t2"]);
        assert_eq!(canonical.nodes[0].inputs, vec![0]);
        assert_eq!(canonical.nodes[0].outputs, vec![1]);
        assert_eq!(canonical.nodes[1].inputs, vec![1]);
        assert_eq!(canonical.nodes[1].outputs, vec![2]);
        assert_eq!(canonical.outputs, vec![2]);
    }
}