Skip to main content

tensor_forge/
graph.rs

1//! Structure for representing ML runtimes via Node and Op intermediate representation.
2
3use crate::node::{Node, NodeId};
4use crate::op::OpKind;
5use std::collections::{BTreeSet, HashMap, HashSet};
6use std::fmt;
7
/// Error types for [`Graph`] construction and validation.
///
/// These errors are returned by graph-building APIs when an operation cannot be
/// represented safely in the current graph.
///
/// `GraphError` is a plain, field-less enum, so it is cheap to copy and can be compared
/// directly with `==` / `assert_eq!`.
///
/// # Examples
/// ```
/// # use tensor_forge::graph::{Graph, GraphError};
/// let mut g = Graph::new();
/// let a = g.input_node(vec![2, 3]);
/// let b = g.input_node(vec![2, 4]);
///
/// // add() requires identical shapes
/// assert!(matches!(g.add(a, b).unwrap_err(), GraphError::ShapeMismatch));
/// assert_eq!(g.add(a, b).unwrap_err(), GraphError::ShapeMismatch);
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GraphError {
    /// Raised when connecting nodes whose tensor shapes are incompatible for the requested op.
    ///
    /// # Examples
    /// - `add(A, B)` requires `shape(A) == shape(B)`.
    /// - `matmul(L, R)` requires `L` and `R` be 2-D and `L.shape[1] == R.shape[0]`.
    ///
    /// This error indicates the graph is not well-typed under the op’s shape rules.
    ShapeMismatch,
    /// Raised when an operation references a [`NodeId`] that does not exist in the graph.
    ///
    /// This typically happens when:
    /// - A `NodeId` was produced by a different [`Graph`] instance, or
    /// - A stale/invalid `NodeId` was stored and reused.
    ///
    /// # Example
    /// ```
    /// # use tensor_forge::graph::{Graph, GraphError};
    /// let mut g1 = Graph::new();
    /// let foreign = g1.input_node(vec![1, 1]);
    ///
    /// let mut g2 = Graph::new();
    /// assert!(matches!(g2.relu(foreign).unwrap_err(), GraphError::InvalidNodeId));
    /// ```
    InvalidNodeId,
    /// Raised when inserting a node whose ID already exists in the graph.
    ///
    /// In this implementation, node IDs are expected to be monotonically increasing and unique.
    /// A collision indicates a serious invariant failure (e.g. ID overflow or a bug in node
    /// allocation), and should be treated as unrecoverable at the application level.
    IdCollision,
    /// Raised when the graph contains a cycle and no valid execution order exists.
    ///
    /// Returned by [`Graph::topo_sort`]. Graphs built only through the public op constructors
    /// cannot contain cycles, because every new op node may only reference nodes that already
    /// exist in the graph.
    CycleDetected,
}
58
59/// A minimal compute-graph container for an ML runtime intermediate representation (IR).
60///
61/// A [`Graph`] owns a set of [`Node`]s indexed by [`NodeId`]. Each node encodes:
62/// - an operation kind ([`OpKind`]),
63/// - a list of input dependencies (by `NodeId`), and
64/// - the inferred output tensor shape.
65///
66/// This type currently supports constructing a graph via:
67/// - [`Graph::input_node`] for source nodes, and
68/// - op constructors like [`Graph::add`], [`Graph::matmul`], and [`Graph::relu`].
69///
70/// Output nodes must be designated explicitly via [`Graph::set_output_node`].
71///
72/// # Examples
73/// ```
74/// # use tensor_forge::graph::Graph;
75/// let mut g = Graph::new();
76/// let x = g.input_node(vec![2, 3]);
77/// let y = g.relu(x).unwrap();
78/// g.set_output_node(y).unwrap();
79/// assert_eq!(g.outputs().len(), 1);
80/// ```
81pub struct Graph {
82    nodes: HashMap<NodeId, Node>,
83    inputs: Vec<NodeId>,
84    outputs: HashSet<NodeId>,
85}
86
87impl Default for Graph {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93impl Graph {
94    /// Private helper method for inserting new node into the graph.
95    fn add_node(&mut self, node: Node) -> Result<NodeId, GraphError> {
96        let node_id = node.id;
97        // Each node is generated to be unique in monotonically increasing order. Collisions
98        // indicate that graph nodes have overflowed.
99        if self.nodes.contains_key(&node_id) {
100            return Err(GraphError::IdCollision);
101        }
102        if node.op == OpKind::Input {
103            self.inputs.push(node_id);
104        }
105        self.nodes.insert(node_id, node);
106        Ok(node_id)
107    }
108
109    /// Creates an empty graph with no nodes, inputs, or outputs.
110    ///
111    /// # Examples
112    /// ```
113    /// # use tensor_forge::graph::Graph;
114    /// let g = Graph::new();
115    /// assert_eq!(g.num_nodes(), 0);
116    /// assert!(g.inputs().is_empty());
117    /// assert!(g.outputs().is_empty());
118    /// ```
119    #[must_use]
120    pub fn new() -> Self {
121        Graph {
122            nodes: HashMap::new(),
123            inputs: Vec::new(),
124            outputs: HashSet::new(),
125        }
126    }
127
128    /// Creates a new input node with the given tensor `shape` and returns its `NodeId`.
129    ///
130    /// Input nodes have no dependencies and an output shape equal to `shape`.
131    ///
132    /// # Panics
133    /// Panics if a node ID collision is detected (an invariant violation indicating too many nodes
134    /// have been allocated or ID generation is broken).
135    ///
136    /// # Examples
137    /// ```
138    /// # use tensor_forge::graph::Graph;
139    /// let mut g = Graph::new();
140    /// let x = g.input_node(vec![2, 3]);
141    /// assert!(g.node(x).is_ok());
142    /// assert_eq!(g.num_nodes(), 1);
143    /// ```
144    pub fn input_node(&mut self, shape: Vec<usize>) -> NodeId {
145        let node = Node::new(OpKind::Input, Vec::new(), shape);
146        self.add_node(node).expect("Node ID collision detected on node creation. Too many nodes may have been allocated. Ensure that Graph operations are single-threaded.")
147    }
148
149    /// Adds a matrix multiplication node `left × right`.
150    ///
151    /// Shape rule (2-D):
152    /// - `left.shape = [m, k]`
153    /// - `right.shape = [k, n]`
154    /// - output shape is `[m, n]`
155    ///
156    /// # Errors
157    ///
158    /// Returns [`GraphError::InvalidNodeId`] if either `left` or `right` does not exist
159    /// in this graph.
160    ///
161    /// Returns [`GraphError::ShapeMismatch`] if the inner dimensions do not match.
162    ///
163    /// # Examples
164    /// ```
165    /// # use tensor_forge::graph::{Graph, GraphError};
166    /// let mut g = Graph::new();
167    /// let a = g.input_node(vec![2, 3]);
168    /// let b = g.input_node(vec![3, 4]);
169    ///
170    /// let c = g.matmul(a, b).unwrap();
171    /// assert!(g.node(c).is_ok());
172    /// assert_eq!(g.num_nodes(), 3);
173    ///
174    /// // Mismatched inner dimension: [2,3] x [2,4] is invalid
175    /// let bad = g.input_node(vec![2, 4]);
176    /// assert!(matches!(g.matmul(a, bad).unwrap_err(), GraphError::ShapeMismatch));
177    /// ```
178    pub fn matmul(&mut self, left: NodeId, right: NodeId) -> Result<NodeId, GraphError> {
179        let left_node = self.node(left)?;
180        let right_node = self.node(right)?;
181        if left_node.shape[1] != right_node.shape[0] {
182            return Err(GraphError::ShapeMismatch);
183        }
184        let shape = vec![left_node.shape[0], right_node.shape[1]];
185        let matmul_node = Node::new(OpKind::MatMul, vec![left_node.id, right_node.id], shape);
186        self.add_node(matmul_node)
187    }
188
189    /// Adds an elementwise addition node `left + right`.
190    ///
191    /// Shape rule:
192    /// - `shape(left) == shape(right)`
193    ///
194    /// # Errors
195    ///
196    /// Returns [`GraphError::InvalidNodeId`] if either input does not exist in this graph.
197    ///
198    /// Returns [`GraphError::ShapeMismatch`] if the shapes differ.
199    ///
200    /// # Examples
201    /// ```
202    /// # use tensor_forge::graph::{Graph, GraphError};
203    /// let mut g = Graph::new();
204    /// let a = g.input_node(vec![2, 3]);
205    /// let b = g.input_node(vec![2, 3]);
206    ///
207    /// let c = g.add(a, b).unwrap();
208    /// assert!(g.node(c).is_ok());
209    ///
210    /// let d = g.input_node(vec![2, 4]);
211    /// assert!(matches!(g.add(a, d).unwrap_err(), GraphError::ShapeMismatch));
212    /// ```
213    pub fn add(&mut self, left: NodeId, right: NodeId) -> Result<NodeId, GraphError> {
214        let left_node = self.node(left)?;
215        let right_node = self.node(right)?;
216        if left_node.shape != right_node.shape {
217            return Err(GraphError::ShapeMismatch);
218        }
219        let addition_node = Node::new(
220            OpKind::Add,
221            vec![left_node.id, right_node.id],
222            left_node.shape.clone(),
223        );
224        self.add_node(addition_node)
225    }
226
227    /// Adds a `ReLU` node `relu(input)`.
228    ///
229    /// `ReLU` preserves shape: `shape(output) == shape(input)`.
230    ///
231    /// # Errors
232    ///
233    /// Returns [`GraphError::InvalidNodeId`] if `input` does not exist in this graph.
234    ///
235    /// # Examples
236    /// ```
237    /// # use tensor_forge::graph::{Graph, GraphError};
238    /// let mut g = Graph::new();
239    /// let x = g.input_node(vec![2, 3]);
240    ///
241    /// let y = g.relu(x).unwrap();
242    /// assert!(g.node(y).is_ok());
243    ///
244    /// // Using a NodeId from another graph is invalid
245    /// let mut other = Graph::new();
246    /// let foreign = other.input_node(vec![2, 3]);
247    /// assert!(matches!(g.relu(foreign).unwrap_err(), GraphError::InvalidNodeId));
248    /// ```
249    pub fn relu(&mut self, input: NodeId) -> Result<NodeId, GraphError> {
250        let input_node = self.node(input)?;
251        let relu_node = Node::new(OpKind::ReLU, vec![input_node.id], input_node.shape.clone());
252        self.add_node(relu_node)
253    }
254
255    /// Marks `node` as an output node.
256    ///
257    /// Graphs must have at least one output node to be meaningful for execution, and may have
258    /// multiple outputs. This method does **not** create a new node or execute anything; it only
259    /// records the provided node ID as an output.
260    ///
261    /// # Errors
262    ///
263    /// Returns [`GraphError::InvalidNodeId`] if `node` does not exist in this graph.
264    ///
265    /// # Examples
266    /// ```
267    /// # use tensor_forge::graph::{Graph, GraphError};
268    /// let mut g = Graph::new();
269    /// let x = g.input_node(vec![2, 3]);
270    /// let y = g.relu(x).expect("No error should occur in the construction of this ReLU");
271    ///
272    /// assert!(g.outputs().is_empty());
273    /// g.set_output_node(y).expect("We are passing a valid output node");
274    /// assert_eq!(g.outputs().contains(&y), true);
275    /// assert_eq!(g.outputs().len(), 1);
276    ///
277    /// // A NodeId from another graph is invalid
278    /// let mut other = Graph::new();
279    /// let foreign = other.input_node(vec![2, 3]);
280    /// assert!(matches!(g.set_output_node(foreign).unwrap_err(), GraphError::InvalidNodeId));
281    /// ```
282    pub fn set_output_node(&mut self, node: NodeId) -> Result<(), GraphError> {
283        let node = self.node(node)?;
284        self.outputs.insert(node.id);
285        Ok(())
286    }
287
288    /// Returns a shared reference to the node with the given `NodeId`.
289    ///
290    /// # Errors
291    ///
292    /// Returns [`GraphError::InvalidNodeId`] if the node is not present in this graph.
293    ///
294    /// # Examples
295    /// ```
296    /// # use tensor_forge::graph::{Graph, GraphError};
297    /// let mut g = Graph::new();
298    /// let x = g.input_node(vec![1, 1]);
299    /// assert!(g.node(x).is_ok());
300    ///
301    /// // A NodeId from another graph is invalid
302    /// let mut other = Graph::new();
303    /// let foreign = other.input_node(vec![1, 1]);
304    /// assert!(matches!(g.node(foreign).unwrap_err(), GraphError::InvalidNodeId));
305    /// ```
306    pub fn node(&self, id: NodeId) -> Result<&Node, GraphError> {
307        match self.nodes.get(&id) {
308            Some(node) => Ok(node),
309            None => Err(GraphError::InvalidNodeId),
310        }
311    }
312
313    /// Returns the total number of nodes stored in this graph.
314    ///
315    /// # Examples
316    /// ```
317    /// # use tensor_forge::graph::Graph;
318    /// let mut g = Graph::new();
319    /// assert_eq!(g.num_nodes(), 0);
320    /// let x = g.input_node(vec![2, 3]);
321    /// let y = g.relu(x).unwrap();
322    /// assert_eq!(g.num_nodes(), 2);
323    /// ```
324    #[must_use]
325    pub fn num_nodes(&self) -> usize {
326        self.nodes.values().len()
327    }
328
329    /// Returns the list of nodes.
330    ///
331    /// Every inserted node is appended to this list
332    /// (including op nodes created by [`Graph::add`], [`Graph::matmul`], and [`Graph::relu`]).
333    ///
334    /// # Examples
335    /// ```
336    /// # use std::collections::HashSet;
337    /// # use tensor_forge::graph::Graph;
338    /// let mut g = Graph::new();
339    /// let a = g.input_node(vec![2, 3]);
340    /// let b = g.input_node(vec![2, 3]);
341    /// let c = g.add(a, b).unwrap();
342    ///
343    /// // Includes both inputs and the derived node.
344    /// for node in g.nodes() {
345    ///     assert!([a, b, c].contains(&node.id));
346    /// }
347    ///
348    /// ```
349    pub fn nodes(&self) -> impl Iterator<Item = &Node> {
350        self.nodes.values()
351    }
352
353    /// Returns the list of nodes recorded as inputs.
354    ///
355    /// # Examples
356    /// ```
357    /// # use tensor_forge::graph::Graph;
358    /// let mut g = Graph::new();
359    /// let a = g.input_node(vec![2, 3]);
360    /// let b = g.input_node(vec![2, 3]);
361    /// let c = g.add(a, b).unwrap();
362    ///
363    /// // Only includes both inputs.
364    /// assert_eq!(g.inputs(), &[a, b]);
365    /// ```
366    #[must_use]
367    pub fn inputs(&self) -> &[NodeId] {
368        &self.inputs
369    }
370
371    /// Returns the list of nodes marked as outputs via [`Graph::set_output_node`].
372    ///
373    /// # Examples
374    /// ```
375    /// # use tensor_forge::graph::Graph;
376    /// let mut g = Graph::new();
377    /// let x = g.input_node(vec![2, 3]);
378    /// let y = g.relu(x).expect("No error should occur in the construction of this ReLU");
379    ///
380    /// assert!(g.outputs().is_empty());
381    /// g.set_output_node(y).expect("We are passing a valid output node");
382    /// assert!(g.outputs().contains(&y));
383    /// assert_eq!(g.outputs().len(), 1);
384    /// ```
385    #[must_use]
386    pub fn outputs(&self) -> &HashSet<NodeId> {
387        &self.outputs
388    }
389
390    /// Computes a deterministic topological execution order (Kahn's Algorithm) of all nodes in the graph.
391    ///
392    /// Topological ordering guarantees that every node appears *after* all of its
393    /// dependencies. This ordering is required for correct execution of the compute graph,
394    /// since kernels must not execute before their input tensors are available.
395    ///
396    /// The returned order includes every node in the graph exactly once.
397    ///
398    /// # Determinism
399    ///
400    /// Determinism is guaranteed by enforcing a stable tie-breaking rule when multiple
401    /// nodes are ready for execution. Nodes with zero remaining dependencies are processed
402    /// in ascending [`NodeId`] order.
403    ///
404    /// This ensures:
405    ///
406    /// - Reproducible execution across runs
407    /// - Independence from hash seed randomization
408    /// - Stable ordering suitable for debugging and testing
409    ///
410    /// # Returns
411    ///
412    /// A vector of [`NodeId`] representing the execution order.
413    ///
414    /// The order satisfies the invariant:
415    ///
416    /// ```text
417    /// For every node N:
418    ///     all inputs(N) appear before N in the returned vector
419    /// ```
420    ///
421    /// # Errors
422    ///
423    /// Returns [`GraphError::CycleDetected`] if the graph contains a cycle. Assuming normal API
424    /// use, Graph methods will not allow cycle creation to ever occur.
425    ///
426    /// Cycles violate compute graph semantics because no valid execution order exists.
427    ///
428    /// # Complexity
429    ///
430    /// Time complexity: **O(V + E)**
431    /// Space complexity: **O(V + E)**
432    ///
433    /// where:
434    ///
435    /// - V = number of nodes
436    /// - E = number of edges (dependencies)
437    ///
438    /// # Examples
439    ///
440    /// ```
441    /// # use tensor_forge::graph::Graph;
442    /// let mut g = Graph::new();
443    ///
444    /// let a = g.input_node(vec![2, 3]);
445    /// let b = g.relu(a).unwrap();
446    /// let c = g.relu(b).unwrap();
447    ///
448    /// let order = g.topo_sort().unwrap();
449    ///
450    /// let pos = |id| order.iter().position(|&x| x == id).unwrap();
451    ///
452    /// assert!(pos(a) < pos(b));
453    /// assert!(pos(b) < pos(c));
454    /// ```
455    pub fn topo_sort(&self) -> Result<Vec<NodeId>, GraphError> {
456        // indegree[v] = number of incoming edges to v (i.e., number of deps v has)
457        // outgoing[u] = list of nodes that depend on u (edges u -> v)
458        let n = self.nodes.len();
459        let mut indegree: HashMap<NodeId, usize> = HashMap::with_capacity(n);
460        let mut outgoing: HashMap<NodeId, Vec<NodeId>> = HashMap::with_capacity(n);
461
462        // Initialize all nodes with indegree 0 so we can safely increment later.
463        for &id in self.nodes.keys() {
464            indegree.insert(id, 0);
465        }
466
467        // Build indegree and outgoing adjacency.
468        for (&id, node) in &self.nodes {
469            for &dep in &node.inputs {
470                // If the graph was constructed via public API this can't happen,
471                // but it keeps topo_sort robust against malformed graphs.
472                if !self.nodes.contains_key(&dep) {
473                    return Err(GraphError::InvalidNodeId);
474                }
475                match indegree.get_mut(&id) {
476                    Some(deg) => *deg += 1,
477                    None => return Err(GraphError::InvalidNodeId),
478                }
479                outgoing.entry(dep).or_default().push(id);
480            }
481        }
482
483        // Deterministic ready set: always pop the smallest NodeId.
484        let mut ready: BTreeSet<NodeId> = indegree
485            .iter()
486            .filter_map(|(&id, &deg)| if deg == 0 { Some(id) } else { None })
487            .collect();
488
489        let mut order: Vec<NodeId> = Vec::with_capacity(n);
490
491        while let Some(&id) = ready.iter().next() {
492            ready.remove(&id);
493            order.push(id);
494
495            if let Some(dependents) = outgoing.get(&id) {
496                for &v in dependents {
497                    let Some(deg) = indegree.get_mut(&v) else {
498                        return Err(GraphError::InvalidNodeId);
499                    };
500
501                    // v had an incoming edge from id; remove it.
502                    if *deg == 0 {
503                        return Err(GraphError::CycleDetected);
504                    }
505                    *deg -= 1;
506
507                    if *deg == 0 {
508                        ready.insert(v);
509                    }
510                }
511            }
512        }
513
514        if order.len() != n {
515            // Some nodes never reached indegree 0 => cycle (or unreachable due to malformed indegrees).
516            return Err(GraphError::CycleDetected);
517        }
518
519        Ok(order)
520    }
521
522    /// Private helper method that allows inserting duplicate node IDs or creating cycles for
523    /// stress-testing the API.
524    ///
525    /// This should not be used in any code other than the unit tests in `graph.rs`.
526    #[cfg(test)]
527    fn add_node_unsafe(&mut self, node: Node) -> NodeId {
528        let node_id = node.id;
529        self.nodes.insert(node_id, node);
530        self.inputs.push(node_id);
531        node_id
532    }
533}
534
535impl fmt::Display for GraphError {
536    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
537        match self {
538            GraphError::ShapeMismatch => {
539                write!(
540                    f,
541                    "Mismatched input and output dimensions for Nodes A and B. dim(Output(A)) must match dim(Output(B))"
542                )
543            }
544            GraphError::InvalidNodeId => {
545                write!(
546                    f,
547                    "Attempted to operate on a Node that does not exist in the graph. Ensure you are only interacting with nodes via Graph::input_node()."
548                )
549            }
550            GraphError::IdCollision => {
551                write!(
552                    f,
553                    "Attempted to add a new node to a graph with an ID that already exists."
554                )
555            }
556            GraphError::CycleDetected => {
557                write!(
558                    f,
559                    "Graph contains a dependency cycle. Execution order cannot be determined."
560                )
561            }
562        }
563    }
564}
565
#[cfg(test)]
mod tests {
    use crate::graph::*;
    use crate::node::*;

    /// Unit test for the internal `IdCollision` path of `add_node`.
    ///
    /// Integration tests cannot cover this: `add_node` is private, and the public API never
    /// produces two nodes with the same ID.
    ///
    /// See `tests/graph_tests.rs` for graph integration tests.
    #[test]
    fn add_node_rejects_duplicate_id() {
        let mut graph = Graph::new();

        let first = Node::new(OpKind::Input, vec![], vec![2, 2]);
        // Reuse `first`'s ID so the second insert collides.
        let duplicate = Node {
            id: first.id,
            op: OpKind::Input,
            inputs: vec![],
            shape: vec![2, 2],
        };

        assert!(graph.add_node(first).is_ok());
        assert!(matches!(
            graph.add_node(duplicate).unwrap_err(),
            GraphError::IdCollision
        ));

        // The rejected insert must leave the graph untouched.
        assert_eq!(graph.num_nodes(), 1);
        assert_eq!(graph.inputs().len(), 1);
    }

    /// The public API only lets a new node reference nodes that already exist, so it cannot
    /// build a cycle. This test wires one up by hand to exercise `topo_sort`'s cycle detection.
    ///
    /// See `tests/graph_tests.rs` for graph integration tests.
    #[test]
    fn topo_sort_rejects_cycles() {
        let mut graph = Graph::new();

        // Create two input nodes, then overwrite both through the test-only
        // `add_node_unsafe` (which skips the ID-collision check that `add_node` enforces)
        // with ReLU nodes that depend on each other: a needs b, and b needs a.
        let a = graph.input_node(vec![1, 1]);
        let b = graph.input_node(vec![1, 1]);

        let a_needs_b = Node {
            id: a,
            op: OpKind::ReLU,
            inputs: vec![b],
            shape: vec![1, 1],
        };
        let b_needs_a = Node {
            id: b,
            op: OpKind::ReLU,
            inputs: vec![a],
            shape: vec![1, 1],
        };

        graph.add_node_unsafe(a_needs_b);
        graph.add_node_unsafe(b_needs_a);

        // The IDs were reused, so the originals were replaced rather than new nodes appended.
        assert_eq!(graph.num_nodes(), 2);

        let err = graph.topo_sort().unwrap_err();
        assert!(matches!(err, GraphError::CycleDetected));
    }
}