Skip to main content

tensorlogic_ir/graph/
transform.rs

1//! Graph transformation and manipulation utilities.
2
3use std::collections::{HashMap, HashSet};
4
5use super::{EinsumGraph, EinsumNode};
6use crate::error::IrError;
7
8/// Visitor trait for traversing graph nodes.
9pub trait GraphVisitor {
10    /// Visit a node in the graph.
11    fn visit_node(&mut self, node_idx: usize, node: &EinsumNode, graph: &EinsumGraph);
12
13    /// Called before visiting all nodes.
14    fn start(&mut self, _graph: &EinsumGraph) {}
15
16    /// Called after visiting all nodes.
17    fn finish(&mut self, _graph: &EinsumGraph) {}
18}
19
20/// Mutable visitor trait for transforming graph nodes.
21pub trait GraphMutVisitor {
22    /// Visit and potentially modify a node.
23    fn visit_node_mut(
24        &mut self,
25        node_idx: usize,
26        node: &mut EinsumNode,
27        graph: &EinsumGraph,
28    ) -> Result<(), IrError>;
29}
30
31impl EinsumGraph {
32    /// Extract a subgraph containing only the specified nodes and their dependencies.
33    pub fn extract_subgraph(&self, node_indices: &[usize]) -> Result<EinsumGraph, IrError> {
34        // Validate node indices
35        for &idx in node_indices {
36            if idx >= self.nodes.len() {
37                return Err(IrError::NodeValidation {
38                    node: idx,
39                    message: format!("Node index {} out of bounds", idx),
40                });
41            }
42        }
43
44        // Collect all nodes reachable from the specified nodes (via dependencies)
45        let mut reachable_nodes = HashSet::new();
46        for &idx in node_indices {
47            self.collect_dependencies(idx, &mut reachable_nodes);
48        }
49
50        // Build index mapping for tensors
51        let mut tensor_map = HashMap::new();
52        let mut new_graph = EinsumGraph::new();
53
54        // Collect all tensors used by reachable nodes (both inputs and outputs)
55        let mut used_tensors = HashSet::new();
56        for &node_idx in &reachable_nodes {
57            let node = &self.nodes[node_idx];
58            for &input_idx in &node.inputs {
59                used_tensors.insert(input_idx);
60            }
61            for &output_idx in &node.outputs {
62                used_tensors.insert(output_idx);
63            }
64        }
65
66        // Add tensors to new graph
67        for &tensor_idx in &used_tensors {
68            let new_idx = new_graph.add_tensor(&self.tensors[tensor_idx]);
69            tensor_map.insert(tensor_idx, new_idx);
70        }
71
72        // Add nodes with remapped tensor indices
73        for &node_idx in &reachable_nodes {
74            let old_node = &self.nodes[node_idx];
75            let new_node = old_node.remap_tensors(&tensor_map)?;
76            new_graph.add_node(new_node)?;
77        }
78
79        // Set outputs (if any of the original outputs are in the subgraph)
80        for &out_idx in &self.outputs {
81            if let Some(&new_idx) = tensor_map.get(&out_idx) {
82                new_graph.add_output(new_idx)?;
83            }
84        }
85
86        Ok(new_graph)
87    }
88
89    /// Collect all nodes that this node depends on (recursively).
90    fn collect_dependencies(&self, node_idx: usize, visited: &mut HashSet<usize>) {
91        if visited.contains(&node_idx) {
92            return;
93        }
94        visited.insert(node_idx);
95
96        let node = &self.nodes[node_idx];
97
98        // Find nodes that produce the input tensors for this node
99        for &input_tensor in &node.inputs {
100            // Find which node produces this input tensor
101            for (idx, other_node) in self.nodes.iter().enumerate() {
102                if idx < node_idx && other_node.produces(input_tensor) {
103                    self.collect_dependencies(idx, visited);
104                }
105            }
106        }
107    }
108
109    /// Merge another graph into this one.
110    ///
111    /// Returns a mapping from old tensor indices to new tensor indices.
112    pub fn merge(&mut self, other: &EinsumGraph) -> Result<HashMap<usize, usize>, IrError> {
113        let mut tensor_map = HashMap::new();
114
115        // Try to reuse existing tensors with the same name
116        for (old_idx, tensor_name) in other.tensors.iter().enumerate() {
117            if let Some(existing_idx) = self.tensors.iter().position(|t| t == tensor_name) {
118                tensor_map.insert(old_idx, existing_idx);
119            } else {
120                let new_idx = self.add_tensor(tensor_name);
121                tensor_map.insert(old_idx, new_idx);
122            }
123        }
124
125        // Add nodes with remapped tensor indices
126        for node in &other.nodes {
127            let new_node = node.remap_tensors(&tensor_map)?;
128            self.add_node(new_node)?;
129        }
130
131        // Add outputs
132        for &out_idx in &other.outputs {
133            if let Some(&new_idx) = tensor_map.get(&out_idx) {
134                if !self.outputs.contains(&new_idx) {
135                    self.add_output(new_idx)?;
136                }
137            }
138        }
139
140        Ok(tensor_map)
141    }
142
143    /// Visit all nodes in the graph using a visitor.
144    pub fn visit<V: GraphVisitor>(&self, visitor: &mut V) {
145        visitor.start(self);
146        for (idx, node) in self.nodes.iter().enumerate() {
147            visitor.visit_node(idx, node, self);
148        }
149        visitor.finish(self);
150    }
151
152    /// Visit all nodes mutably using a mutable visitor.
153    pub fn visit_mut<V: GraphMutVisitor>(&mut self, visitor: &mut V) -> Result<(), IrError> {
154        // We need to clone the graph for the visitor to see the original structure
155        let graph_clone = self.clone();
156
157        for idx in 0..self.nodes.len() {
158            visitor.visit_node_mut(idx, &mut self.nodes[idx], &graph_clone)?;
159        }
160
161        Ok(())
162    }
163
164    /// Apply a rewrite rule to all nodes in the graph.
165    ///
166    /// The rule function takes a node and returns an optional replacement node.
167    pub fn apply_rewrite<F>(&mut self, mut rule: F) -> Result<usize, IrError>
168    where
169        F: FnMut(&EinsumNode) -> Option<EinsumNode>,
170    {
171        let mut rewrites = 0;
172
173        for node in &mut self.nodes {
174            if let Some(new_node) = rule(node) {
175                *node = new_node;
176                rewrites += 1;
177            }
178        }
179
180        Ok(rewrites)
181    }
182
183    /// Get all nodes that depend on a specific tensor (consume it as input).
184    pub fn tensor_consumers(&self, tensor_idx: usize) -> Vec<usize> {
185        self.nodes
186            .iter()
187            .enumerate()
188            .filter(|(_, node)| node.inputs.contains(&tensor_idx))
189            .map(|(idx, _)| idx)
190            .collect()
191    }
192
193    /// Get the node that produces a specific tensor.
194    ///
195    /// Note: In the current graph model, tensors can be produced by at most one node
196    /// or be external inputs. This returns nodes that might output to this tensor
197    /// based on graph topology.
198    pub fn tensor_producer(&self, tensor_idx: usize) -> Option<usize> {
199        // A simple heuristic: find nodes that come before uses of this tensor
200        let consumers = self.tensor_consumers(tensor_idx);
201        if consumers.is_empty() {
202            return None;
203        }
204
205        let min_consumer = consumers.iter().min().copied()?;
206
207        // Find the latest node before min_consumer
208        if min_consumer > 0 {
209            Some(min_consumer - 1)
210        } else {
211            None
212        }
213    }
214
215    /// Check if there's a path from node_from to node_to based on node ordering.
216    pub fn has_path(&self, node_from: usize, node_to: usize) -> bool {
217        // Simple topological ordering: lower indices come before higher indices
218        node_from <= node_to
219    }
220
221    /// Get dependency chain for a node (all nodes it depends on).
222    pub fn dependencies(&self, node_idx: usize) -> HashSet<usize> {
223        let mut deps = HashSet::new();
224        self.collect_dependencies(node_idx, &mut deps);
225        deps.remove(&node_idx); // Remove self
226        deps
227    }
228
229    /// Get number of nodes.
230    pub fn node_count(&self) -> usize {
231        self.nodes.len()
232    }
233
234    /// Get number of tensors.
235    pub fn tensor_count(&self) -> usize {
236        self.tensors.len()
237    }
238}
239
240impl EinsumNode {
241    /// Remap tensor indices using the provided mapping.
242    pub(crate) fn remap_tensors(
243        &self,
244        tensor_map: &HashMap<usize, usize>,
245    ) -> Result<Self, IrError> {
246        let inputs: Vec<usize> = self
247            .inputs
248            .iter()
249            .map(|&idx| {
250                tensor_map
251                    .get(&idx)
252                    .copied()
253                    .ok_or_else(|| IrError::NodeValidation {
254                        node: 0,
255                        message: format!("Input tensor {} not in mapping", idx),
256                    })
257            })
258            .collect::<Result<_, _>>()?;
259
260        let outputs: Vec<usize> = self
261            .outputs
262            .iter()
263            .map(|&idx| {
264                tensor_map
265                    .get(&idx)
266                    .copied()
267                    .ok_or_else(|| IrError::NodeValidation {
268                        node: 0,
269                        message: format!("Output tensor {} not in mapping", idx),
270                    })
271            })
272            .collect::<Result<_, _>>()?;
273
274        Ok(EinsumNode {
275            inputs,
276            outputs,
277            op: self.op.clone(),
278            metadata: self.metadata.clone(),
279        })
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::OpType;

    fn create_test_graph() -> EinsumGraph {
        let mut g = EinsumGraph::new();

        // Tensors: t0, t1, t2, t3, t4, t5, t6 (inputs + intermediate + outputs)
        let t0 = g.add_tensor("t0");
        let t1 = g.add_tensor("t1");
        let t2 = g.add_tensor("t2");
        let _t3 = g.add_tensor("t3");
        let t4 = g.add_tensor("t4"); // output of node 0
        let t5 = g.add_tensor("t5"); // output of node 1
        let t6 = g.add_tensor("t6"); // output of node 2

        // Node 0: uses t0, produces t4
        g.add_node(EinsumNode {
            inputs: vec![t0],
            outputs: vec![t4],
            op: OpType::Einsum {
                spec: "i->i".to_string(),
            },
            metadata: None,
        })
        .unwrap();

        // Node 1: uses t1, produces t5
        g.add_node(EinsumNode {
            inputs: vec![t1],
            outputs: vec![t5],
            op: OpType::Einsum {
                spec: "i->i".to_string(),
            },
            metadata: None,
        })
        .unwrap();

        // Node 2: uses t2, produces t6
        g.add_node(EinsumNode {
            inputs: vec![t2],
            outputs: vec![t6],
            op: OpType::Einsum {
                spec: "i->i".to_string(),
            },
            metadata: None,
        })
        .unwrap();

        g.add_output(t6).unwrap();

        g
    }

    /// Build a linear chain `c0 -> node 0 -> c1 -> node 1 -> ... -> c{len}` whose last
    /// tensor is the only graph output.
    fn create_chain_graph(len: usize) -> EinsumGraph {
        let mut g = EinsumGraph::new();
        let tensors: Vec<usize> = (0..=len)
            .map(|i| g.add_tensor(&format!("c{}", i)))
            .collect();

        for pair in tensors.windows(2) {
            g.add_node(EinsumNode {
                inputs: vec![pair[0]],
                outputs: vec![pair[1]],
                op: OpType::Einsum {
                    spec: "i->i".to_string(),
                },
                metadata: None,
            })
            .unwrap();
        }

        g.add_output(tensors[len]).unwrap();
        g
    }

    #[test]
    fn test_extract_subgraph() {
        let graph = create_test_graph();

        // Extract nodes 0 and 1: they are independent, so nothing else is pulled in.
        let subgraph = graph.extract_subgraph(&[0, 1]).unwrap();

        assert_eq!(subgraph.nodes.len(), 2);
        // t0, t4 (node 0) and t1, t5 (node 1); t2, t3 and t6 are dropped.
        assert_eq!(subgraph.tensors.len(), 4);
        // The only graph output (t6) belongs to node 2, which was not extracted.
        assert!(subgraph.outputs.is_empty());
    }

    #[test]
    fn test_extract_subgraph_rejects_out_of_bounds() {
        let graph = create_test_graph();
        assert!(graph.extract_subgraph(&[0, 3]).is_err());
    }

    #[test]
    fn test_extract_subgraph_pulls_in_dependencies() {
        let graph = create_chain_graph(3);

        // Asking for the last node must bring along the whole chain feeding it.
        let subgraph = graph.extract_subgraph(&[2]).unwrap();

        assert_eq!(subgraph.nodes.len(), 3);
        assert_eq!(subgraph.tensors.len(), 4);
        assert_eq!(subgraph.outputs.len(), 1);
    }

    #[test]
    fn test_merge_graphs() {
        let mut g1 = EinsumGraph::new();
        let t0 = g1.add_tensor("shared");
        let t1 = g1.add_tensor("out1");
        g1.add_node(EinsumNode {
            inputs: vec![t0],
            outputs: vec![t1],
            op: OpType::Einsum {
                spec: "i->i".to_string(),
            },
            metadata: None,
        })
        .unwrap();

        let mut g2 = EinsumGraph::new();
        let t0_2 = g2.add_tensor("shared");
        let t1_2 = g2.add_tensor("out2");
        g2.add_node(EinsumNode {
            inputs: vec![t0_2],
            outputs: vec![t1_2],
            op: OpType::Einsum {
                spec: "i->i".to_string(),
            },
            metadata: None,
        })
        .unwrap();

        let tensor_map = g1.merge(&g2).unwrap();

        // "shared" is reused; "out2" is appended after g1's two existing tensors.
        assert_eq!(tensor_map[&t0_2], t0);
        assert_eq!(tensor_map[&t1_2], 2);
        assert_eq!(g1.tensors.len(), 3);
        assert_eq!(g1.nodes.len(), 2);
    }

    #[test]
    fn test_tensor_consumers() {
        let graph = create_test_graph();

        let consumers = graph.tensor_consumers(1); // t1
        assert_eq!(consumers.len(), 1);
        assert_eq!(consumers[0], 1); // Node 1 consumes t1
    }

    #[test]
    fn test_has_path() {
        let graph = create_test_graph();

        assert!(graph.has_path(0, 2)); // 0 -> 2 (via ordering)
        assert!(graph.has_path(0, 0)); // Same node
        assert!(!graph.has_path(2, 0)); // No backward path
    }

    #[test]
    fn test_visitor_pattern() {
        let graph = create_test_graph();

        struct CountingVisitor {
            count: usize,
        }

        impl GraphVisitor for CountingVisitor {
            fn visit_node(&mut self, _idx: usize, _node: &EinsumNode, _graph: &EinsumGraph) {
                self.count += 1;
            }
        }

        let mut visitor = CountingVisitor { count: 0 };
        graph.visit(&mut visitor);

        assert_eq!(visitor.count, 3);
    }

    #[test]
    fn test_apply_rewrite() {
        let mut graph = create_test_graph();

        // Replace all einsum operations with a different spec
        let rewrites = graph
            .apply_rewrite(|node| {
                if matches!(node.op, OpType::Einsum { .. }) {
                    Some(EinsumNode {
                        inputs: node.inputs.clone(),
                        outputs: node.outputs.clone(),
                        op: OpType::Einsum {
                            spec: "new->spec".to_string(),
                        },
                        metadata: None,
                    })
                } else {
                    None
                }
            })
            .unwrap();

        assert_eq!(rewrites, 3);

        for node in &graph.nodes {
            if let OpType::Einsum { spec } = &node.op {
                assert_eq!(spec, "new->spec");
            }
        }
    }

    #[test]
    fn test_node_count() {
        let graph = create_test_graph();
        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.tensor_count(), 7); // t0-t6 (3 inputs + 1 unused + 3 outputs)
    }

    #[test]
    fn test_dependencies() {
        // Create a graph with actual dependencies
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("t0");
        let t1 = graph.add_tensor("t1"); // output of node 0
        let t2 = graph.add_tensor("t2"); // output of node 1

        // Node 0: produces t1 from t0
        graph
            .add_node(EinsumNode {
                inputs: vec![t0],
                outputs: vec![t1],
                op: OpType::Einsum {
                    spec: "i->i".to_string(),
                },
                metadata: None,
            })
            .unwrap();

        // Node 1: produces t2 from t1 (depends on node 0)
        graph
            .add_node(EinsumNode {
                inputs: vec![t1],
                outputs: vec![t2],
                op: OpType::Einsum {
                    spec: "i->i".to_string(),
                },
                metadata: None,
            })
            .unwrap();

        let deps = graph.dependencies(1);
        // Node 1 depends on node 0 (which produces t1)
        assert!(deps.contains(&0));
        assert_eq!(deps.len(), 1);
    }

    #[test]
    fn test_dependencies_are_transitive() {
        let graph = create_chain_graph(3);

        let deps = graph.dependencies(2);
        assert_eq!(deps.len(), 2);
        assert!(deps.contains(&0));
        assert!(deps.contains(&1));

        // The first node of the chain depends on nothing.
        assert!(graph.dependencies(0).is_empty());
    }
}