1use crate::{device::Device, dtype::DType, error::TensorError, shape::Shape, tensor::Tensor};
7use std::collections::HashMap;
8
/// Identifier of a node within a [`Graph`].
///
/// `Graph::add_node` hands these out from a per-graph counter that starts at 0.
pub type NodeId = u64;

/// Identifier of an edge within a [`Graph`].
///
/// `Graph::add_edge` hands these out from its own per-graph counter that starts at 0,
/// independent of the node-id counter.
pub type EdgeId = u64;
14
/// A single node of a [`Graph`]: an operation, variable, placeholder or constant.
#[derive(Clone, Debug)]
pub struct GraphNode {
    /// Id assigned by `Graph::add_node`; also this node's key in the graph's node map.
    pub id: NodeId,
    /// Node name; `Graph::add_node` rejects a name that is already used in the graph.
    pub name: String,
    /// What kind of node this is (see [`NodeType`]).
    pub op_type: NodeType,
    /// Device associated with this node (presumably where it executes — confirm).
    pub device: Device,
    /// Ids of incoming edges (edges whose `to_node` is this node), appended by `Graph::add_edge`.
    pub inputs: Vec<EdgeId>,
    /// Ids of outgoing edges (edges whose `from_node` is this node), appended by `Graph::add_edge`.
    pub outputs: Vec<EdgeId>,
    /// Named, loosely typed attributes supplied when the node was added.
    pub attributes: HashMap<String, AttributeValue>,
}
26
/// The kind of a [`GraphNode`].
#[derive(Clone, Debug, PartialEq)]
pub enum NodeType {
    /// An operation identified by a string (presumably the operation's type name — confirm).
    Operation(String),
    /// A variable with a declared element type and shape.
    Variable { dtype: DType, shape: Shape },
    /// A placeholder with a declared element type and shape
    /// (presumably a graph input supplied at execution time — confirm).
    Placeholder { dtype: DType, shape: Shape },
    /// A constant. The variant itself carries no payload; the value presumably lives
    /// in the node's `attributes` — confirm.
    Constant,
}
39
/// A directed connection from one [`GraphNode`] to another.
#[derive(Clone, Debug)]
pub struct GraphEdge {
    /// Id assigned by `Graph::add_edge`; also this edge's key in the graph's edge map.
    pub id: EdgeId,
    /// Source node; `Graph::add_edge` lists this edge in that node's `outputs`.
    pub from_node: NodeId,
    /// Destination node; `Graph::add_edge` lists this edge in that node's `inputs`.
    pub to_node: NodeId,
    /// Index of the source node's output this edge presumably reads from
    /// (stored as given; `Graph::add_edge` does not validate it).
    pub from_output: usize,
    /// Index of the destination node's input this edge presumably feeds
    /// (stored as given; `Graph::add_edge` does not validate it).
    pub to_input: usize,
    /// Element type recorded for the value carried along this edge.
    pub dtype: DType,
    /// Shape recorded for the value carried along this edge.
    pub shape: Shape,
    /// `true` for a control edge (presumably an ordering-only dependency with no data — confirm).
    pub is_control: bool,
}
52
/// Value of a single entry in a node's attribute map ([`GraphNode::attributes`]).
#[derive(Clone, Debug)]
pub enum AttributeValue {
    /// A string value.
    String(String),
    /// A signed 64-bit integer.
    Int(i64),
    /// A 64-bit floating-point number.
    Float(f64),
    /// A boolean flag.
    Bool(bool),
    /// A list of signed 64-bit integers.
    IntList(Vec<i64>),
    /// A list of 64-bit floating-point numbers.
    FloatList(Vec<f64>),
    /// A tensor shape.
    Shape(Shape),
    /// An `f32` tensor (presumably used for constant data — confirm).
    Tensor(Tensor<f32>),
}
65
/// A directed computation graph of [`GraphNode`]s connected by [`GraphEdge`]s.
#[derive(Debug, Clone)]
pub struct Graph {
    /// All nodes, keyed by their id.
    pub(crate) nodes: HashMap<NodeId, GraphNode>,
    /// All edges, keyed by their id.
    pub(crate) edges: HashMap<EdgeId, GraphEdge>,
    /// Id the next node added via `add_node` will receive.
    pub(crate) next_node_id: NodeId,
    /// Id the next edge added via `add_edge` will receive.
    pub(crate) next_edge_id: EdgeId,
    /// Index from node name to node id. It enforces unique names in `add_node`
    /// and backs `get_node_by_name`.
    pub(crate) name_to_node: HashMap<String, NodeId>,
    /// Cached topological order of node ids. It is reset to `None` by mutating methods
    /// such as `add_node`/`add_edge`, and is presumably recomputed lazily elsewhere
    /// in the crate — confirm.
    pub(crate) topological_order: Option<Vec<NodeId>>,
    /// Modification counter, incremented by mutating methods such as `add_node`/`add_edge`.
    pub(crate) version: u64,
}
77
78impl Graph {
79 pub fn new() -> Self {
81 Self {
82 nodes: HashMap::new(),
83 edges: HashMap::new(),
84 next_node_id: 0,
85 next_edge_id: 0,
86 name_to_node: HashMap::new(),
87 topological_order: None,
88 version: 0,
89 }
90 }
91
92 pub fn add_node(
94 &mut self,
95 name: String,
96 op_type: NodeType,
97 device: Device,
98 attributes: HashMap<String, AttributeValue>,
99 ) -> Result<NodeId, TensorError> {
100 if self.name_to_node.contains_key(&name) {
102 return Err(TensorError::invalid_argument(format!(
103 "Node name '{name}' already exists"
104 )));
105 }
106
107 let node_id = self.next_node_id;
108 self.next_node_id += 1;
109
110 let node = GraphNode {
111 id: node_id,
112 name: name.clone(),
113 op_type,
114 device,
115 inputs: Vec::new(),
116 outputs: Vec::new(),
117 attributes,
118 };
119
120 self.nodes.insert(node_id, node);
121 self.name_to_node.insert(name, node_id);
122 self.topological_order = None; self.version += 1;
124
125 Ok(node_id)
126 }
127
128 #[allow(clippy::too_many_arguments)]
130 pub fn add_edge(
131 &mut self,
132 from_node: NodeId,
133 to_node: NodeId,
134 from_output: usize,
135 to_input: usize,
136 dtype: DType,
137 shape: Shape,
138 is_control: bool,
139 ) -> Result<EdgeId, TensorError> {
140 if !self.nodes.contains_key(&from_node) {
142 return Err(TensorError::invalid_argument(format!(
143 "Source node {from_node} not found"
144 )));
145 }
146 if !self.nodes.contains_key(&to_node) {
147 return Err(TensorError::invalid_argument(format!(
148 "Destination node {to_node} not found"
149 )));
150 }
151
152 if from_node == to_node {
154 return Err(TensorError::invalid_argument(
155 "Self-loops are not allowed".to_string(),
156 ));
157 }
158
159 let edge_id = self.next_edge_id;
160 self.next_edge_id += 1;
161
162 let edge = GraphEdge {
163 id: edge_id,
164 from_node,
165 to_node,
166 from_output,
167 to_input,
168 dtype,
169 shape,
170 is_control,
171 };
172
173 self.edges.insert(edge_id, edge);
174
175 self.nodes
177 .get_mut(&from_node)
178 .expect("Source node must exist after validation")
179 .outputs
180 .push(edge_id);
181 self.nodes
182 .get_mut(&to_node)
183 .expect("Destination node must exist after validation")
184 .inputs
185 .push(edge_id);
186
187 self.topological_order = None; self.version += 1;
189
190 Ok(edge_id)
191 }
192
193 pub fn get_node(&self, node_id: NodeId) -> Option<&GraphNode> {
195 self.nodes.get(&node_id)
196 }
197
198 pub fn get_node_by_name(&self, name: &str) -> Option<&GraphNode> {
200 self.name_to_node
201 .get(name)
202 .and_then(|&id| self.nodes.get(&id))
203 }
204
205 pub fn get_edge(&self, edge_id: EdgeId) -> Option<&GraphEdge> {
207 self.edges.get(&edge_id)
208 }
209
210 pub fn nodes(&self) -> impl Iterator<Item = &GraphNode> {
212 self.nodes.values()
213 }
214
215 pub fn edges(&self) -> impl Iterator<Item = &GraphEdge> {
217 self.edges.values()
218 }
219
220 pub fn get_node_mut(&mut self, node_id: NodeId) -> Option<&mut GraphNode> {
222 self.nodes.get_mut(&node_id)
223 }
224
225 pub fn get_edge_mut(&mut self, edge_id: EdgeId) -> Option<&mut GraphEdge> {
227 self.edges.get_mut(&edge_id)
228 }
229
230 pub fn node_count(&self) -> usize {
232 self.nodes.len()
233 }
234
235 pub fn edge_count(&self) -> usize {
237 self.edges.len()
238 }
239}
240
241impl Default for Graph {
242 fn default() -> Self {
243 Self::new()
244 }
245}