Skip to main content

tensorlogic_ir/graph/
mod.rs

1//! Tensor computation graphs (EinsumGraph).
2
3pub mod advanced_algorithms;
4pub mod advanced_analysis;
5pub mod canonicalization;
6pub mod constant_folding;
7pub mod cost_model;
8pub mod dot_export;
9mod einsum_spec;
10mod einsum_spec_display;
11pub mod export;
12pub mod fusion;
13pub mod layout;
14pub mod memory;
15mod node;
16pub mod optimization;
17mod optype;
18pub mod parallel;
19pub mod pattern;
20pub mod pgo;
21pub mod schedule;
22pub mod tiling;
23pub mod transform;
24pub mod validation;
25
26use serde::{Deserialize, Serialize};
27use std::collections::HashMap;
28
29pub use canonicalization::{are_graphs_equivalent, canonical_hash, canonicalize_graph};
30pub use dot_export::{export_to_dot, export_to_dot_with_options, DotExportOptions};
31pub use einsum_spec::EinsumSpec;
32pub use node::EinsumNode;
33pub use optimization::{
34    eliminate_common_subexpressions, eliminate_dead_code, optimize_graph,
35    simplify_identity_operations, OptimizationStats,
36};
37pub use optype::OpType;
38// Public API traits for graph transformation - meant for external use
39#[allow(unused_imports)]
40pub use transform::{GraphMutVisitor, GraphVisitor};
41pub use validation::{
42    validate_graph, GraphValidationStats, ValidationError, ValidationErrorKind, ValidationReport,
43    ValidationWarning, ValidationWarningKind,
44};
45
46use crate::error::IrError;
47use crate::metadata::Metadata;
48
49#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
50pub struct EinsumGraph {
51    pub tensors: Vec<String>,
52    pub nodes: Vec<EinsumNode>,
53    pub inputs: Vec<usize>,
54    pub outputs: Vec<usize>,
55    /// Metadata for tensors (indexed by tensor index)
56    #[serde(default)]
57    pub tensor_metadata: HashMap<usize, Metadata>,
58}
59
60impl EinsumGraph {
61    pub fn new() -> Self {
62        Self::default()
63    }
64
65    pub fn with_capacity(tensor_cap: usize, node_cap: usize) -> Self {
66        EinsumGraph {
67            tensors: Vec::with_capacity(tensor_cap),
68            nodes: Vec::with_capacity(node_cap),
69            inputs: Vec::new(),
70            outputs: Vec::new(),
71            tensor_metadata: HashMap::new(),
72        }
73    }
74
75    pub fn add_tensor(&mut self, name: impl Into<String>) -> usize {
76        let idx = self.tensors.len();
77        self.tensors.push(name.into());
78        idx
79    }
80
81    pub fn add_node(&mut self, node: EinsumNode) -> Result<usize, IrError> {
82        node.validate(self.tensors.len())?;
83        let idx = self.nodes.len();
84        self.nodes.push(node);
85        Ok(idx)
86    }
87
88    pub fn add_input(&mut self, tensor_idx: usize) -> Result<(), IrError> {
89        if tensor_idx >= self.tensors.len() {
90            return Err(IrError::TensorIndexOutOfBounds {
91                index: tensor_idx,
92                max: self.tensors.len() - 1,
93            });
94        }
95        self.inputs.push(tensor_idx);
96        Ok(())
97    }
98
99    pub fn add_output(&mut self, tensor_idx: usize) -> Result<(), IrError> {
100        if tensor_idx >= self.tensors.len() {
101            return Err(IrError::OutputIndexOutOfBounds {
102                index: tensor_idx,
103                max: self.tensors.len() - 1,
104            });
105        }
106        self.outputs.push(tensor_idx);
107        Ok(())
108    }
109
110    pub fn validate(&self) -> Result<(), IrError> {
111        for (idx, node) in self.nodes.iter().enumerate() {
112            node.validate(self.tensors.len())
113                .map_err(|e| IrError::NodeValidation {
114                    node: idx,
115                    message: e.to_string(),
116                })?;
117        }
118
119        for &out_idx in &self.outputs {
120            if out_idx >= self.tensors.len() {
121                return Err(IrError::OutputIndexOutOfBounds {
122                    index: out_idx,
123                    max: self.tensors.len() - 1,
124                });
125            }
126        }
127
128        Ok(())
129    }
130
131    pub fn is_empty(&self) -> bool {
132        self.tensors.is_empty() && self.nodes.is_empty()
133    }
134
135    /// Add metadata for a tensor.
136    pub fn add_tensor_metadata(&mut self, tensor_idx: usize, metadata: Metadata) {
137        self.tensor_metadata.insert(tensor_idx, metadata);
138    }
139
140    /// Get metadata for a tensor if it exists.
141    pub fn get_tensor_metadata(&self, tensor_idx: usize) -> Option<&Metadata> {
142        self.tensor_metadata.get(&tensor_idx)
143    }
144
145    /// Add a tensor with metadata.
146    pub fn add_tensor_with_metadata(
147        &mut self,
148        name: impl Into<String>,
149        metadata: Metadata,
150    ) -> usize {
151        let idx = self.add_tensor(name);
152        self.add_tensor_metadata(idx, metadata);
153        idx
154    }
155}