tensorlogic_ir/graph/mod.rs

1//! Tensor computation graphs (EinsumGraph).
2
3pub mod advanced_analysis;
4pub mod canonicalization;
5pub mod cost_model;
6pub mod dot_export;
7mod einsum_spec;
8mod einsum_spec_display;
9mod node;
10pub mod optimization;
11mod optype;
12pub mod transform;
13pub mod validation;
14
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17
18pub use canonicalization::{are_graphs_equivalent, canonical_hash, canonicalize_graph};
19pub use dot_export::{export_to_dot, export_to_dot_with_options, DotExportOptions};
20pub use einsum_spec::EinsumSpec;
21pub use node::EinsumNode;
22pub use optype::OpType;
23// Public API traits for graph transformation - meant for external use
24#[allow(unused_imports)]
25pub use transform::{GraphMutVisitor, GraphVisitor};
26pub use validation::{
27    validate_graph, GraphValidationStats, ValidationError, ValidationErrorKind, ValidationReport,
28    ValidationWarning, ValidationWarningKind,
29};
30
31use crate::error::IrError;
32use crate::metadata::Metadata;
33
34#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
35pub struct EinsumGraph {
36    pub tensors: Vec<String>,
37    pub nodes: Vec<EinsumNode>,
38    pub inputs: Vec<usize>,
39    pub outputs: Vec<usize>,
40    /// Metadata for tensors (indexed by tensor index)
41    #[serde(default)]
42    pub tensor_metadata: HashMap<usize, Metadata>,
43}
44
45impl EinsumGraph {
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    pub fn with_capacity(tensor_cap: usize, node_cap: usize) -> Self {
51        EinsumGraph {
52            tensors: Vec::with_capacity(tensor_cap),
53            nodes: Vec::with_capacity(node_cap),
54            inputs: Vec::new(),
55            outputs: Vec::new(),
56            tensor_metadata: HashMap::new(),
57        }
58    }
59
60    pub fn add_tensor(&mut self, name: impl Into<String>) -> usize {
61        let idx = self.tensors.len();
62        self.tensors.push(name.into());
63        idx
64    }
65
66    pub fn add_node(&mut self, node: EinsumNode) -> Result<usize, IrError> {
67        node.validate(self.tensors.len())?;
68        let idx = self.nodes.len();
69        self.nodes.push(node);
70        Ok(idx)
71    }
72
73    pub fn add_input(&mut self, tensor_idx: usize) -> Result<(), IrError> {
74        if tensor_idx >= self.tensors.len() {
75            return Err(IrError::TensorIndexOutOfBounds {
76                index: tensor_idx,
77                max: self.tensors.len() - 1,
78            });
79        }
80        self.inputs.push(tensor_idx);
81        Ok(())
82    }
83
84    pub fn add_output(&mut self, tensor_idx: usize) -> Result<(), IrError> {
85        if tensor_idx >= self.tensors.len() {
86            return Err(IrError::OutputIndexOutOfBounds {
87                index: tensor_idx,
88                max: self.tensors.len() - 1,
89            });
90        }
91        self.outputs.push(tensor_idx);
92        Ok(())
93    }
94
95    pub fn validate(&self) -> Result<(), IrError> {
96        for (idx, node) in self.nodes.iter().enumerate() {
97            node.validate(self.tensors.len())
98                .map_err(|e| IrError::NodeValidation {
99                    node: idx,
100                    message: e.to_string(),
101                })?;
102        }
103
104        for &out_idx in &self.outputs {
105            if out_idx >= self.tensors.len() {
106                return Err(IrError::OutputIndexOutOfBounds {
107                    index: out_idx,
108                    max: self.tensors.len() - 1,
109                });
110            }
111        }
112
113        Ok(())
114    }
115
116    pub fn is_empty(&self) -> bool {
117        self.tensors.is_empty() && self.nodes.is_empty()
118    }
119
120    /// Add metadata for a tensor.
121    pub fn add_tensor_metadata(&mut self, tensor_idx: usize, metadata: Metadata) {
122        self.tensor_metadata.insert(tensor_idx, metadata);
123    }
124
125    /// Get metadata for a tensor if it exists.
126    pub fn get_tensor_metadata(&self, tensor_idx: usize) -> Option<&Metadata> {
127        self.tensor_metadata.get(&tensor_idx)
128    }
129
130    /// Add a tensor with metadata.
131    pub fn add_tensor_with_metadata(
132        &mut self,
133        name: impl Into<String>,
134        metadata: Metadata,
135    ) -> usize {
136        let idx = self.add_tensor(name);
137        self.add_tensor_metadata(idx, metadata);
138        idx
139    }
140}