tensor_forge/graph.rs
1//! Structure for representing ML runtimes via Node and Op intermediate representation.
2
3use crate::node::{Node, NodeId};
4use crate::op::OpKind;
5use std::collections::{BTreeSet, HashMap, HashSet};
6use std::fmt;
7
/// Errors produced while building or validating a [`Graph`].
///
/// Graph-building APIs return these values when a requested operation cannot be
/// represented safely in the current graph.
///
/// The enum is a plain field-less value type, so it implements [`PartialEq`] and [`Eq`]:
/// errors can be compared directly (for example with `assert_eq!`) in addition to being
/// pattern-matched.
///
/// # Examples
/// ```
/// # use tensor_forge::graph::{Graph, GraphError};
/// let mut g = Graph::new();
/// let a = g.input_node(vec![2, 3]);
/// let b = g.input_node(vec![2, 4]);
///
/// // add() requires identical shapes
/// assert!(matches!(g.add(a, b).unwrap_err(), GraphError::ShapeMismatch));
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum GraphError {
    /// Raised when connecting nodes whose tensor shapes are incompatible for the requested op.
    ///
    /// # Examples
    /// - `add(A, B)` requires `shape(A) == shape(B)`.
    /// - `matmul(L, R)` requires `L` and `R` be 2-D and `L.shape[1] == R.shape[0]`.
    ///
    /// This error indicates the graph is not well-typed under the op's shape rules.
    ShapeMismatch,
    /// Raised when an operation references a [`NodeId`] that does not exist in the graph.
    ///
    /// This typically happens when:
    /// - A `NodeId` was produced by a different [`Graph`] instance, or
    /// - A stale/invalid `NodeId` was stored and reused.
    ///
    /// # Example
    /// ```
    /// # use tensor_forge::graph::{Graph, GraphError};
    /// let mut g1 = Graph::new();
    /// let foreign = g1.input_node(vec![1, 1]);
    ///
    /// let mut g2 = Graph::new();
    /// assert!(matches!(g2.relu(foreign).unwrap_err(), GraphError::InvalidNodeId));
    /// ```
    InvalidNodeId,
    /// Raised when inserting a node whose ID already exists in the graph.
    ///
    /// In this implementation, node IDs are expected to be monotonically increasing and unique.
    /// A collision indicates a serious invariant failure (e.g. ID overflow or a bug in node
    /// allocation), and should be treated as unrecoverable at the application level.
    IdCollision,
    /// Raised when the graph contains a cycle and no valid execution order exists.
    CycleDetected,
}
58
59/// A minimal compute-graph container for an ML runtime intermediate representation (IR).
60///
61/// A [`Graph`] owns a set of [`Node`]s indexed by [`NodeId`]. Each node encodes:
62/// - an operation kind ([`OpKind`]),
63/// - a list of input dependencies (by `NodeId`), and
64/// - the inferred output tensor shape.
65///
66/// This type currently supports constructing a graph via:
67/// - [`Graph::input_node`] for source nodes, and
68/// - op constructors like [`Graph::add`], [`Graph::matmul`], and [`Graph::relu`].
69///
70/// Output nodes must be designated explicitly via [`Graph::set_output_node`].
71///
72/// # Examples
73/// ```
74/// # use tensor_forge::graph::Graph;
75/// let mut g = Graph::new();
76/// let x = g.input_node(vec![2, 3]);
77/// let y = g.relu(x).unwrap();
78/// g.set_output_node(y).unwrap();
79/// assert_eq!(g.outputs().len(), 1);
80/// ```
81pub struct Graph {
82 nodes: HashMap<NodeId, Node>,
83 inputs: Vec<NodeId>,
84 outputs: HashSet<NodeId>,
85}
86
87impl Default for Graph {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl Graph {
94 /// Private helper method for inserting new node into the graph.
95 fn add_node(&mut self, node: Node) -> Result<NodeId, GraphError> {
96 let node_id = node.id;
97 // Each node is generated to be unique in monotonically increasing order. Collisions
98 // indicate that graph nodes have overflowed.
99 if self.nodes.contains_key(&node_id) {
100 return Err(GraphError::IdCollision);
101 }
102 if node.op == OpKind::Input {
103 self.inputs.push(node_id);
104 }
105 self.nodes.insert(node_id, node);
106 Ok(node_id)
107 }
108
109 /// Creates an empty graph with no nodes, inputs, or outputs.
110 ///
111 /// # Examples
112 /// ```
113 /// # use tensor_forge::graph::Graph;
114 /// let g = Graph::new();
115 /// assert_eq!(g.num_nodes(), 0);
116 /// assert!(g.inputs().is_empty());
117 /// assert!(g.outputs().is_empty());
118 /// ```
119 #[must_use]
120 pub fn new() -> Self {
121 Graph {
122 nodes: HashMap::new(),
123 inputs: Vec::new(),
124 outputs: HashSet::new(),
125 }
126 }
127
128 /// Creates a new input node with the given tensor `shape` and returns its `NodeId`.
129 ///
130 /// Input nodes have no dependencies and an output shape equal to `shape`.
131 ///
132 /// # Panics
133 /// Panics if a node ID collision is detected (an invariant violation indicating too many nodes
134 /// have been allocated or ID generation is broken).
135 ///
136 /// # Examples
137 /// ```
138 /// # use tensor_forge::graph::Graph;
139 /// let mut g = Graph::new();
140 /// let x = g.input_node(vec![2, 3]);
141 /// assert!(g.node(x).is_ok());
142 /// assert_eq!(g.num_nodes(), 1);
143 /// ```
144 pub fn input_node(&mut self, shape: Vec<usize>) -> NodeId {
145 let node = Node::new(OpKind::Input, Vec::new(), shape);
146 self.add_node(node).expect("Node ID collision detected on node creation. Too many nodes may have been allocated. Ensure that Graph operations are single-threaded.")
147 }
148
149 /// Adds a matrix multiplication node `left × right`.
150 ///
151 /// Shape rule (2-D):
152 /// - `left.shape = [m, k]`
153 /// - `right.shape = [k, n]`
154 /// - output shape is `[m, n]`
155 ///
156 /// # Errors
157 ///
158 /// Returns [`GraphError::InvalidNodeId`] if either `left` or `right` does not exist
159 /// in this graph.
160 ///
161 /// Returns [`GraphError::ShapeMismatch`] if the inner dimensions do not match.
162 ///
163 /// # Examples
164 /// ```
165 /// # use tensor_forge::graph::{Graph, GraphError};
166 /// let mut g = Graph::new();
167 /// let a = g.input_node(vec![2, 3]);
168 /// let b = g.input_node(vec![3, 4]);
169 ///
170 /// let c = g.matmul(a, b).unwrap();
171 /// assert!(g.node(c).is_ok());
172 /// assert_eq!(g.num_nodes(), 3);
173 ///
174 /// // Mismatched inner dimension: [2,3] x [2,4] is invalid
175 /// let bad = g.input_node(vec![2, 4]);
176 /// assert!(matches!(g.matmul(a, bad).unwrap_err(), GraphError::ShapeMismatch));
177 /// ```
178 pub fn matmul(&mut self, left: NodeId, right: NodeId) -> Result<NodeId, GraphError> {
179 let left_node = self.node(left)?;
180 let right_node = self.node(right)?;
181 if left_node.shape[1] != right_node.shape[0] {
182 return Err(GraphError::ShapeMismatch);
183 }
184 let shape = vec![left_node.shape[0], right_node.shape[1]];
185 let matmul_node = Node::new(OpKind::MatMul, vec![left_node.id, right_node.id], shape);
186 self.add_node(matmul_node)
187 }
188
189 /// Adds an elementwise addition node `left + right`.
190 ///
191 /// Shape rule:
192 /// - `shape(left) == shape(right)`
193 ///
194 /// # Errors
195 ///
196 /// Returns [`GraphError::InvalidNodeId`] if either input does not exist in this graph.
197 ///
198 /// Returns [`GraphError::ShapeMismatch`] if the shapes differ.
199 ///
200 /// # Examples
201 /// ```
202 /// # use tensor_forge::graph::{Graph, GraphError};
203 /// let mut g = Graph::new();
204 /// let a = g.input_node(vec![2, 3]);
205 /// let b = g.input_node(vec![2, 3]);
206 ///
207 /// let c = g.add(a, b).unwrap();
208 /// assert!(g.node(c).is_ok());
209 ///
210 /// let d = g.input_node(vec![2, 4]);
211 /// assert!(matches!(g.add(a, d).unwrap_err(), GraphError::ShapeMismatch));
212 /// ```
213 pub fn add(&mut self, left: NodeId, right: NodeId) -> Result<NodeId, GraphError> {
214 let left_node = self.node(left)?;
215 let right_node = self.node(right)?;
216 if left_node.shape != right_node.shape {
217 return Err(GraphError::ShapeMismatch);
218 }
219 let addition_node = Node::new(
220 OpKind::Add,
221 vec![left_node.id, right_node.id],
222 left_node.shape.clone(),
223 );
224 self.add_node(addition_node)
225 }
226
227 /// Adds a `ReLU` node `relu(input)`.
228 ///
229 /// `ReLU` preserves shape: `shape(output) == shape(input)`.
230 ///
231 /// # Errors
232 ///
233 /// Returns [`GraphError::InvalidNodeId`] if `input` does not exist in this graph.
234 ///
235 /// # Examples
236 /// ```
237 /// # use tensor_forge::graph::{Graph, GraphError};
238 /// let mut g = Graph::new();
239 /// let x = g.input_node(vec![2, 3]);
240 ///
241 /// let y = g.relu(x).unwrap();
242 /// assert!(g.node(y).is_ok());
243 ///
244 /// // Using a NodeId from another graph is invalid
245 /// let mut other = Graph::new();
246 /// let foreign = other.input_node(vec![2, 3]);
247 /// assert!(matches!(g.relu(foreign).unwrap_err(), GraphError::InvalidNodeId));
248 /// ```
249 pub fn relu(&mut self, input: NodeId) -> Result<NodeId, GraphError> {
250 let input_node = self.node(input)?;
251 let relu_node = Node::new(OpKind::ReLU, vec![input_node.id], input_node.shape.clone());
252 self.add_node(relu_node)
253 }
254
255 /// Marks `node` as an output node.
256 ///
257 /// Graphs must have at least one output node to be meaningful for execution, and may have
258 /// multiple outputs. This method does **not** create a new node or execute anything; it only
259 /// records the provided node ID as an output.
260 ///
261 /// # Errors
262 ///
263 /// Returns [`GraphError::InvalidNodeId`] if `node` does not exist in this graph.
264 ///
265 /// # Examples
266 /// ```
267 /// # use tensor_forge::graph::{Graph, GraphError};
268 /// let mut g = Graph::new();
269 /// let x = g.input_node(vec![2, 3]);
270 /// let y = g.relu(x).expect("No error should occur in the construction of this ReLU");
271 ///
272 /// assert!(g.outputs().is_empty());
273 /// g.set_output_node(y).expect("We are passing a valid output node");
274 /// assert_eq!(g.outputs().contains(&y), true);
275 /// assert_eq!(g.outputs().len(), 1);
276 ///
277 /// // A NodeId from another graph is invalid
278 /// let mut other = Graph::new();
279 /// let foreign = other.input_node(vec![2, 3]);
280 /// assert!(matches!(g.set_output_node(foreign).unwrap_err(), GraphError::InvalidNodeId));
281 /// ```
282 pub fn set_output_node(&mut self, node: NodeId) -> Result<(), GraphError> {
283 let node = self.node(node)?;
284 self.outputs.insert(node.id);
285 Ok(())
286 }
287
288 /// Returns a shared reference to the node with the given `NodeId`.
289 ///
290 /// # Errors
291 ///
292 /// Returns [`GraphError::InvalidNodeId`] if the node is not present in this graph.
293 ///
294 /// # Examples
295 /// ```
296 /// # use tensor_forge::graph::{Graph, GraphError};
297 /// let mut g = Graph::new();
298 /// let x = g.input_node(vec![1, 1]);
299 /// assert!(g.node(x).is_ok());
300 ///
301 /// // A NodeId from another graph is invalid
302 /// let mut other = Graph::new();
303 /// let foreign = other.input_node(vec![1, 1]);
304 /// assert!(matches!(g.node(foreign).unwrap_err(), GraphError::InvalidNodeId));
305 /// ```
306 pub fn node(&self, id: NodeId) -> Result<&Node, GraphError> {
307 match self.nodes.get(&id) {
308 Some(node) => Ok(node),
309 None => Err(GraphError::InvalidNodeId),
310 }
311 }
312
313 /// Returns the total number of nodes stored in this graph.
314 ///
315 /// # Examples
316 /// ```
317 /// # use tensor_forge::graph::Graph;
318 /// let mut g = Graph::new();
319 /// assert_eq!(g.num_nodes(), 0);
320 /// let x = g.input_node(vec![2, 3]);
321 /// let y = g.relu(x).unwrap();
322 /// assert_eq!(g.num_nodes(), 2);
323 /// ```
324 #[must_use]
325 pub fn num_nodes(&self) -> usize {
326 self.nodes.values().len()
327 }
328
329 /// Returns the list of nodes.
330 ///
331 /// Every inserted node is appended to this list
332 /// (including op nodes created by [`Graph::add`], [`Graph::matmul`], and [`Graph::relu`]).
333 ///
334 /// # Examples
335 /// ```
336 /// # use std::collections::HashSet;
337 /// # use tensor_forge::graph::Graph;
338 /// let mut g = Graph::new();
339 /// let a = g.input_node(vec![2, 3]);
340 /// let b = g.input_node(vec![2, 3]);
341 /// let c = g.add(a, b).unwrap();
342 ///
343 /// // Includes both inputs and the derived node.
344 /// for node in g.nodes() {
345 /// assert!([a, b, c].contains(&node.id));
346 /// }
347 ///
348 /// ```
349 pub fn nodes(&self) -> impl Iterator<Item = &Node> {
350 self.nodes.values()
351 }
352
353 /// Returns the list of nodes recorded as inputs.
354 ///
355 /// # Examples
356 /// ```
357 /// # use tensor_forge::graph::Graph;
358 /// let mut g = Graph::new();
359 /// let a = g.input_node(vec![2, 3]);
360 /// let b = g.input_node(vec![2, 3]);
361 /// let c = g.add(a, b).unwrap();
362 ///
363 /// // Only includes both inputs.
364 /// assert_eq!(g.inputs(), &[a, b]);
365 /// ```
366 #[must_use]
367 pub fn inputs(&self) -> &[NodeId] {
368 &self.inputs
369 }
370
371 /// Returns the list of nodes marked as outputs via [`Graph::set_output_node`].
372 ///
373 /// # Examples
374 /// ```
375 /// # use tensor_forge::graph::Graph;
376 /// let mut g = Graph::new();
377 /// let x = g.input_node(vec![2, 3]);
378 /// let y = g.relu(x).expect("No error should occur in the construction of this ReLU");
379 ///
380 /// assert!(g.outputs().is_empty());
381 /// g.set_output_node(y).expect("We are passing a valid output node");
382 /// assert!(g.outputs().contains(&y));
383 /// assert_eq!(g.outputs().len(), 1);
384 /// ```
385 #[must_use]
386 pub fn outputs(&self) -> &HashSet<NodeId> {
387 &self.outputs
388 }
389
390 /// Computes a deterministic topological execution order (Kahn's Algorithm) of all nodes in the graph.
391 ///
392 /// Topological ordering guarantees that every node appears *after* all of its
393 /// dependencies. This ordering is required for correct execution of the compute graph,
394 /// since kernels must not execute before their input tensors are available.
395 ///
396 /// The returned order includes every node in the graph exactly once.
397 ///
398 /// # Determinism
399 ///
400 /// Determinism is guaranteed by enforcing a stable tie-breaking rule when multiple
401 /// nodes are ready for execution. Nodes with zero remaining dependencies are processed
402 /// in ascending [`NodeId`] order.
403 ///
404 /// This ensures:
405 ///
406 /// - Reproducible execution across runs
407 /// - Independence from hash seed randomization
408 /// - Stable ordering suitable for debugging and testing
409 ///
410 /// # Returns
411 ///
412 /// A vector of [`NodeId`] representing the execution order.
413 ///
414 /// The order satisfies the invariant:
415 ///
416 /// ```text
417 /// For every node N:
418 /// all inputs(N) appear before N in the returned vector
419 /// ```
420 ///
421 /// # Errors
422 ///
423 /// Returns [`GraphError::CycleDetected`] if the graph contains a cycle. Assuming normal API
424 /// use, Graph methods will not allow cycle creation to ever occur.
425 ///
426 /// Cycles violate compute graph semantics because no valid execution order exists.
427 ///
428 /// # Complexity
429 ///
430 /// Time complexity: **O(V + E)**
431 /// Space complexity: **O(V + E)**
432 ///
433 /// where:
434 ///
435 /// - V = number of nodes
436 /// - E = number of edges (dependencies)
437 ///
438 /// # Examples
439 ///
440 /// ```
441 /// # use tensor_forge::graph::Graph;
442 /// let mut g = Graph::new();
443 ///
444 /// let a = g.input_node(vec![2, 3]);
445 /// let b = g.relu(a).unwrap();
446 /// let c = g.relu(b).unwrap();
447 ///
448 /// let order = g.topo_sort().unwrap();
449 ///
450 /// let pos = |id| order.iter().position(|&x| x == id).unwrap();
451 ///
452 /// assert!(pos(a) < pos(b));
453 /// assert!(pos(b) < pos(c));
454 /// ```
455 pub fn topo_sort(&self) -> Result<Vec<NodeId>, GraphError> {
456 // indegree[v] = number of incoming edges to v (i.e., number of deps v has)
457 // outgoing[u] = list of nodes that depend on u (edges u -> v)
458 let n = self.nodes.len();
459 let mut indegree: HashMap<NodeId, usize> = HashMap::with_capacity(n);
460 let mut outgoing: HashMap<NodeId, Vec<NodeId>> = HashMap::with_capacity(n);
461
462 // Initialize all nodes with indegree 0 so we can safely increment later.
463 for &id in self.nodes.keys() {
464 indegree.insert(id, 0);
465 }
466
467 // Build indegree and outgoing adjacency.
468 for (&id, node) in &self.nodes {
469 for &dep in &node.inputs {
470 // If the graph was constructed via public API this can't happen,
471 // but it keeps topo_sort robust against malformed graphs.
472 if !self.nodes.contains_key(&dep) {
473 return Err(GraphError::InvalidNodeId);
474 }
475 match indegree.get_mut(&id) {
476 Some(deg) => *deg += 1,
477 None => return Err(GraphError::InvalidNodeId),
478 }
479 outgoing.entry(dep).or_default().push(id);
480 }
481 }
482
483 // Deterministic ready set: always pop the smallest NodeId.
484 let mut ready: BTreeSet<NodeId> = indegree
485 .iter()
486 .filter_map(|(&id, °)| if deg == 0 { Some(id) } else { None })
487 .collect();
488
489 let mut order: Vec<NodeId> = Vec::with_capacity(n);
490
491 while let Some(&id) = ready.iter().next() {
492 ready.remove(&id);
493 order.push(id);
494
495 if let Some(dependents) = outgoing.get(&id) {
496 for &v in dependents {
497 let Some(deg) = indegree.get_mut(&v) else {
498 return Err(GraphError::InvalidNodeId);
499 };
500
501 // v had an incoming edge from id; remove it.
502 if *deg == 0 {
503 return Err(GraphError::CycleDetected);
504 }
505 *deg -= 1;
506
507 if *deg == 0 {
508 ready.insert(v);
509 }
510 }
511 }
512 }
513
514 if order.len() != n {
515 // Some nodes never reached indegree 0 => cycle (or unreachable due to malformed indegrees).
516 return Err(GraphError::CycleDetected);
517 }
518
519 Ok(order)
520 }
521
/// Test-only helper that stores `node` without any validation, overwriting an existing
/// node with the same ID.
///
/// This makes it possible to build graphs the public API forbids (duplicate IDs, cycles)
/// for stress-testing. It must not be used outside the unit tests in `graph.rs`.
///
/// The input list is kept in step with the stored node: the ID is listed exactly once
/// when the node is an [`OpKind::Input`], and not at all otherwise.
#[cfg(test)]
fn add_node_unsafe(&mut self, node: Node) -> NodeId {
    let node_id = node.id;
    if node.op != OpKind::Input {
        // Overwriting a former input with an op node must drop the stale registration.
        self.inputs.retain(|&existing| existing != node_id);
    } else if !self.inputs.contains(&node_id) {
        self.inputs.push(node_id);
    }
    self.nodes.insert(node_id, node);
    node_id
}
533}
534
535impl fmt::Display for GraphError {
536 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
537 match self {
538 GraphError::ShapeMismatch => {
539 write!(
540 f,
541 "Mismatched input and output dimensions for Nodes A and B. dim(Output(A)) must match dim(Output(B))"
542 )
543 }
544 GraphError::InvalidNodeId => {
545 write!(
546 f,
547 "Attempted to operate on a Node that does not exist in the graph. Ensure you are only interacting with nodes via Graph::input_node()."
548 )
549 }
550 GraphError::IdCollision => {
551 write!(
552 f,
553 "Attempted to add a new node to a graph with an ID that already exists."
554 )
555 }
556 GraphError::CycleDetected => {
557 write!(
558 f,
559 "Graph contains a dependency cycle. Execution order cannot be determined."
560 )
561 }
562 }
563 }
564}
565
#[cfg(test)]
mod tests {
    use crate::graph::*;
    use crate::node::*;

    /// Checks that the private `add_node` helper reports `IdCollision` for a duplicate ID.
    ///
    /// This cannot be covered by integration tests: the public API offers no way to
    /// produce two nodes with the same ID.
    ///
    /// See `tests/graph_tests.rs` for graph integration tests.
    #[test]
    fn add_node_rejects_duplicate_id() {
        let mut graph = Graph::new();

        let original = Node::new(OpKind::Input, vec![], vec![2, 2]);
        // Hand-built twin that reuses the ID of `original`.
        let duplicate = Node {
            id: original.id,
            op: OpKind::Input,
            inputs: vec![],
            shape: vec![2, 2],
        };

        let first_insert = graph.add_node(original);
        assert!(first_insert.is_ok());

        let second_insert = graph.add_node(duplicate);
        assert!(matches!(second_insert, Err(GraphError::IdCollision)));
    }

    /// Checks that `topo_sort` reports `CycleDetected` for a cyclic graph.
    ///
    /// The public API only lets a new node depend on nodes that already exist, so a cycle
    /// cannot be built through it; the test-only `add_node_unsafe` helper is used instead.
    ///
    /// See `tests/graph_tests.rs` for graph integration tests.
    #[test]
    fn topo_sort_rejects_cycles() {
        let mut graph = Graph::new();

        // Start from two ordinary input nodes...
        let first = graph.input_node(vec![1, 1]);
        let second = graph.input_node(vec![1, 1]);

        // ...then overwrite each of them with a node that depends on the other one.
        // Nodes are stored by NodeId, so inserting a node with an existing ID through
        // add_node_unsafe() replaces it. This is only possible inside unit tests
        // (integration tests cannot reach add_node_unsafe).
        for (id, dependency) in [(first, second), (second, first)] {
            graph.add_node_unsafe(Node {
                id,
                op: OpKind::ReLU,
                inputs: vec![dependency],
                shape: vec![1, 1],
            });
        }

        assert!(matches!(
            graph.topo_sort(),
            Err(GraphError::CycleDetected)
        ));
    }
}