1use std::collections::{HashMap, HashSet};
4
5use super::{EinsumGraph, EinsumNode};
6use crate::error::IrError;
7
8pub trait GraphVisitor {
10 fn visit_node(&mut self, node_idx: usize, node: &EinsumNode, graph: &EinsumGraph);
12
13 fn start(&mut self, _graph: &EinsumGraph) {}
15
16 fn finish(&mut self, _graph: &EinsumGraph) {}
18}
19
20pub trait GraphMutVisitor {
22 fn visit_node_mut(
24 &mut self,
25 node_idx: usize,
26 node: &mut EinsumNode,
27 graph: &EinsumGraph,
28 ) -> Result<(), IrError>;
29}
30
31impl EinsumGraph {
32 pub fn extract_subgraph(&self, node_indices: &[usize]) -> Result<EinsumGraph, IrError> {
34 for &idx in node_indices {
36 if idx >= self.nodes.len() {
37 return Err(IrError::NodeValidation {
38 node: idx,
39 message: format!("Node index {} out of bounds", idx),
40 });
41 }
42 }
43
44 let mut reachable_nodes = HashSet::new();
46 for &idx in node_indices {
47 self.collect_dependencies(idx, &mut reachable_nodes);
48 }
49
50 let mut tensor_map = HashMap::new();
52 let mut new_graph = EinsumGraph::new();
53
54 let mut used_tensors = HashSet::new();
56 for &node_idx in &reachable_nodes {
57 let node = &self.nodes[node_idx];
58 for &input_idx in &node.inputs {
59 used_tensors.insert(input_idx);
60 }
61 for &output_idx in &node.outputs {
62 used_tensors.insert(output_idx);
63 }
64 }
65
66 for &tensor_idx in &used_tensors {
68 let new_idx = new_graph.add_tensor(&self.tensors[tensor_idx]);
69 tensor_map.insert(tensor_idx, new_idx);
70 }
71
72 for &node_idx in &reachable_nodes {
74 let old_node = &self.nodes[node_idx];
75 let new_node = old_node.remap_tensors(&tensor_map)?;
76 new_graph.add_node(new_node)?;
77 }
78
79 for &out_idx in &self.outputs {
81 if let Some(&new_idx) = tensor_map.get(&out_idx) {
82 new_graph.add_output(new_idx)?;
83 }
84 }
85
86 Ok(new_graph)
87 }
88
89 fn collect_dependencies(&self, node_idx: usize, visited: &mut HashSet<usize>) {
91 if visited.contains(&node_idx) {
92 return;
93 }
94 visited.insert(node_idx);
95
96 let node = &self.nodes[node_idx];
97
98 for &input_tensor in &node.inputs {
100 for (idx, other_node) in self.nodes.iter().enumerate() {
102 if idx < node_idx && other_node.produces(input_tensor) {
103 self.collect_dependencies(idx, visited);
104 }
105 }
106 }
107 }
108
109 pub fn merge(&mut self, other: &EinsumGraph) -> Result<HashMap<usize, usize>, IrError> {
113 let mut tensor_map = HashMap::new();
114
115 for (old_idx, tensor_name) in other.tensors.iter().enumerate() {
117 if let Some(existing_idx) = self.tensors.iter().position(|t| t == tensor_name) {
118 tensor_map.insert(old_idx, existing_idx);
119 } else {
120 let new_idx = self.add_tensor(tensor_name);
121 tensor_map.insert(old_idx, new_idx);
122 }
123 }
124
125 for node in &other.nodes {
127 let new_node = node.remap_tensors(&tensor_map)?;
128 self.add_node(new_node)?;
129 }
130
131 for &out_idx in &other.outputs {
133 if let Some(&new_idx) = tensor_map.get(&out_idx) {
134 if !self.outputs.contains(&new_idx) {
135 self.add_output(new_idx)?;
136 }
137 }
138 }
139
140 Ok(tensor_map)
141 }
142
143 pub fn visit<V: GraphVisitor>(&self, visitor: &mut V) {
145 visitor.start(self);
146 for (idx, node) in self.nodes.iter().enumerate() {
147 visitor.visit_node(idx, node, self);
148 }
149 visitor.finish(self);
150 }
151
152 pub fn visit_mut<V: GraphMutVisitor>(&mut self, visitor: &mut V) -> Result<(), IrError> {
154 let graph_clone = self.clone();
156
157 for idx in 0..self.nodes.len() {
158 visitor.visit_node_mut(idx, &mut self.nodes[idx], &graph_clone)?;
159 }
160
161 Ok(())
162 }
163
164 pub fn apply_rewrite<F>(&mut self, mut rule: F) -> Result<usize, IrError>
168 where
169 F: FnMut(&EinsumNode) -> Option<EinsumNode>,
170 {
171 let mut rewrites = 0;
172
173 for node in &mut self.nodes {
174 if let Some(new_node) = rule(node) {
175 *node = new_node;
176 rewrites += 1;
177 }
178 }
179
180 Ok(rewrites)
181 }
182
183 pub fn tensor_consumers(&self, tensor_idx: usize) -> Vec<usize> {
185 self.nodes
186 .iter()
187 .enumerate()
188 .filter(|(_, node)| node.inputs.contains(&tensor_idx))
189 .map(|(idx, _)| idx)
190 .collect()
191 }
192
193 pub fn tensor_producer(&self, tensor_idx: usize) -> Option<usize> {
199 let consumers = self.tensor_consumers(tensor_idx);
201 if consumers.is_empty() {
202 return None;
203 }
204
205 let min_consumer = consumers.iter().min().copied()?;
206
207 if min_consumer > 0 {
209 Some(min_consumer - 1)
210 } else {
211 None
212 }
213 }
214
215 pub fn has_path(&self, node_from: usize, node_to: usize) -> bool {
217 node_from <= node_to
219 }
220
221 pub fn dependencies(&self, node_idx: usize) -> HashSet<usize> {
223 let mut deps = HashSet::new();
224 self.collect_dependencies(node_idx, &mut deps);
225 deps.remove(&node_idx); deps
227 }
228
229 pub fn node_count(&self) -> usize {
231 self.nodes.len()
232 }
233
234 pub fn tensor_count(&self) -> usize {
236 self.tensors.len()
237 }
238}
239
240impl EinsumNode {
241 pub(crate) fn remap_tensors(
243 &self,
244 tensor_map: &HashMap<usize, usize>,
245 ) -> Result<Self, IrError> {
246 let inputs: Vec<usize> = self
247 .inputs
248 .iter()
249 .map(|&idx| {
250 tensor_map
251 .get(&idx)
252 .copied()
253 .ok_or_else(|| IrError::NodeValidation {
254 node: 0,
255 message: format!("Input tensor {} not in mapping", idx),
256 })
257 })
258 .collect::<Result<_, _>>()?;
259
260 let outputs: Vec<usize> = self
261 .outputs
262 .iter()
263 .map(|&idx| {
264 tensor_map
265 .get(&idx)
266 .copied()
267 .ok_or_else(|| IrError::NodeValidation {
268 node: 0,
269 message: format!("Output tensor {} not in mapping", idx),
270 })
271 })
272 .collect::<Result<_, _>>()?;
273
274 Ok(EinsumNode {
275 inputs,
276 outputs,
277 op: self.op.clone(),
278 metadata: self.metadata.clone(),
279 })
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::OpType;

    /// Builds a one-input, one-output einsum node with the spec `"i->i"` and
    /// no metadata.
    fn identity_node(input: usize, output: usize) -> EinsumNode {
        EinsumNode {
            inputs: vec![input],
            outputs: vec![output],
            op: OpType::Einsum {
                spec: "i->i".to_string(),
            },
            metadata: None,
        }
    }

    /// Graph with seven tensors (`t0`..`t6`) and three independent nodes:
    /// `t0 -> t4`, `t1 -> t5`, `t2 -> t6`; `t6` is the only output.
    fn create_test_graph() -> EinsumGraph {
        let mut g = EinsumGraph::new();

        // Tensors are registered in name order; `t3` is intentionally unused.
        let [t0, t1, t2, _t3, t4, t5, t6] =
            ["t0", "t1", "t2", "t3", "t4", "t5", "t6"].map(|name| g.add_tensor(name));

        for (input, output) in [(t0, t4), (t1, t5), (t2, t6)] {
            g.add_node(identity_node(input, output)).unwrap();
        }

        g.add_output(t6).unwrap();

        g
    }

    #[test]
    fn test_extract_subgraph() {
        let graph = create_test_graph();

        let subgraph = graph.extract_subgraph(&[0, 1]).unwrap();

        assert_eq!(subgraph.nodes.len(), 2);
        assert!(subgraph.tensors.len() >= 2);
    }

    #[test]
    fn test_merge_graphs() {
        let mut g1 = EinsumGraph::new();
        let shared_1 = g1.add_tensor("shared");
        let out_1 = g1.add_tensor("out1");
        g1.add_node(identity_node(shared_1, out_1)).unwrap();

        let mut g2 = EinsumGraph::new();
        let shared_2 = g2.add_tensor("shared");
        let out_2 = g2.add_tensor("out2");
        g2.add_node(identity_node(shared_2, out_2)).unwrap();

        let tensor_map = g1.merge(&g2).unwrap();

        // "shared" exists in both graphs, so index 0 maps onto index 0.
        assert_eq!(tensor_map[&0], 0);
        assert_eq!(g1.nodes.len(), 2);
    }

    #[test]
    fn test_tensor_consumers() {
        let graph = create_test_graph();

        // Tensor 1 (`t1`) is read only by node 1.
        let consumers = graph.tensor_consumers(1);
        assert_eq!(consumers.len(), 1);
        assert_eq!(consumers[0], 1);
    }

    #[test]
    fn test_has_path() {
        let graph = create_test_graph();

        assert!(graph.has_path(0, 2));
        assert!(graph.has_path(0, 0));
        assert!(!graph.has_path(2, 0));
    }

    #[test]
    fn test_visitor_pattern() {
        struct CountingVisitor {
            count: usize,
        }

        impl GraphVisitor for CountingVisitor {
            fn visit_node(&mut self, _idx: usize, _node: &EinsumNode, _graph: &EinsumGraph) {
                self.count += 1;
            }
        }

        let graph = create_test_graph();
        let mut visitor = CountingVisitor { count: 0 };
        graph.visit(&mut visitor);

        assert_eq!(visitor.count, 3);
    }

    #[test]
    fn test_apply_rewrite() {
        let mut graph = create_test_graph();

        let rewrites = graph
            .apply_rewrite(|node| {
                if matches!(node.op, OpType::Einsum { .. }) {
                    Some(EinsumNode {
                        inputs: node.inputs.clone(),
                        outputs: node.outputs.clone(),
                        op: OpType::Einsum {
                            spec: "new->spec".to_string(),
                        },
                        metadata: None,
                    })
                } else {
                    None
                }
            })
            .unwrap();

        assert_eq!(rewrites, 3);

        for node in &graph.nodes {
            if let OpType::Einsum { spec } = &node.op {
                assert_eq!(spec, "new->spec");
            }
        }
    }

    #[test]
    fn test_node_count() {
        let graph = create_test_graph();

        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.tensor_count(), 7);
    }

    #[test]
    fn test_dependencies() {
        // Chain: t0 -> (node 0) -> t1 -> (node 1) -> t2.
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("t0");
        let t1 = graph.add_tensor("t1");
        let t2 = graph.add_tensor("t2");

        graph.add_node(identity_node(t0, t1)).unwrap();
        graph.add_node(identity_node(t1, t2)).unwrap();

        let deps = graph.dependencies(1);
        assert!(deps.contains(&0));
        assert_eq!(deps.len(), 1);
    }
}