// Example: branching_graph.rs
//! Slightly richer example: a branching graph with ReLU, MatMul, and Add.
//!
//! Graph:
//!   ra = relu(a)
//!   rb = relu(b)
//!   mm = matmul(ra, rb)
//!   out = add(mm, c)
//!                             c ---- \
//!   a ---- relu(a) ---- \              out
//!                         matmul --- /
//!   b ---- relu(b) ---- /
//!
//! This demonstrates:
//! - multiple input nodes with different shapes,
//! - shape propagation through several operations,
//! - intermediate nodes feeding later computation,
//! - how the executor relies on registered kernels for each op kind.
//!
//! This example is useful because it shows that kernels are not
//! responsible for graph traversal. The executor handles dependency ordering; each
//! kernel only computes one node once its dependencies are available.
//!
//! For kernel authors:
//!
//! The executor will evaluate this graph in dependency order roughly like:
//!   1. read input bindings for `a`, `b`, `c`
//!   2. execute ReLU kernel for `ra`
//!   3. execute ReLU kernel for `rb`
//!   4. execute MatMul kernel for `mm`
//!   5. execute Add kernel for `out`
//!
//! Important consequence:
//! - a kernel does not need to recursively evaluate upstream nodes,
//! - a kernel only consumes already-available input tensors,
//! - shape compatibility should already be structurally valid if the graph builder
//!   enforces shape rules, though kernels may still defensively validate runtime inputs.
//!
//! A custom MatMul kernel, for example, would typically:
//! - expect exactly 2 input tensors,
//! - verify rank/shape assumptions,
//! - compute matrix multiplication,
//! - return a new tensor with the node's declared output shape.

use tensor_forge::{Executor, Graph, KernelRegistry, Tensor};
45
46fn main() {
47    let mut graph = Graph::new();
48
49    // Declare graph inputs.
50    //
51    // The shapes here establish the legal runtime tensor shapes:
52    //   a: [2, 3]
53    //   b: [3, 2]
54    //   c: [2, 2]
55    let a = graph.input_node(vec![2, 3]);
56    let b = graph.input_node(vec![3, 2]);
57    let c = graph.input_node(vec![2, 2]);
58
59    // Build intermediate operations.
60    //
61    // `relu(a)` preserves shape [2, 3].
62    let ra = graph.relu(a).expect("Valid ReLU operation should succeed");
63
64    // `relu(b)` preserves shape [3, 2].
65    let rb = graph.relu(b).expect("Valid ReLU operation should succeed");
66
67    // `matmul(ra, rb)` combines [2, 3] x [3, 2] -> [2, 2].
68    //
69    // This is a good example of graph-level validation preventing malformed graphs
70    // before execution ever begins.
71    let mm = graph
72        .matmul(ra, rb)
73        .expect("Valid matmul operation should succeed");
74
75    // `add(mm, c)` adds two [2, 2] tensors and also yields [2, 2].
76    let out = graph
77        .add(mm, c)
78        .expect("Valid add operation should succeed");
79
80    graph
81        .set_output_node(out)
82        .expect("Setting output node should succeed");
83
84    // Bind concrete runtime values.
85    //
86    // These values are only examples; the graph structure is independent of them.
87    // The same graph can be executed many times with different input tensors.
88    let a_tensor = Tensor::from_vec(vec![2, 3], vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0])
89        .expect("Tensor construction should succeed");
90
91    let b_tensor = Tensor::from_vec(vec![3, 2], vec![-7.0, 8.0, 9.0, -10.0, 11.0, 12.0])
92        .expect("Tensor construction should succeed");
93
94    let c_tensor = Tensor::from_vec(vec![2, 2], vec![0.5, 1.5, 2.5, 3.5])
95        .expect("Tensor construction should succeed");
96
97    let exec = Executor::new(KernelRegistry::default());
98
99    let outputs = exec
100        .execute(&graph, vec![(a, a_tensor), (b, b_tensor), (c, c_tensor)])
101        .expect("Execution should succeed");
102
103    let result = outputs
104        .get(&out)
105        .expect("Declared output should be present in executor results");
106
107    println!("Computed output for node {:?}: {:?}", out, result);
108}