branching_graph/branching_graph.rs
1//! Slightly richer example: a branching graph with ReLU, MatMul, and Add.
2//!
//! Graph:
//!
//! ```text
//! ra  = relu(a)
//! rb  = relu(b)
//! mm  = matmul(ra, rb)
//! out = add(mm, c)
//!
//! c ---------------------------------\
//! a ---- relu(a) ----\                add ---- out
//!                     matmul --------/
//! b ---- relu(b) ----/
//! ```
12//!
13//! This demonstrates:
14//! - multiple input nodes with different shapes,
15//! - shape propagation through several operations,
16//! - intermediate nodes feeding later computation,
17//! - how the executor relies on registered kernels for each op kind.
18//!
19//! This example is useful because it shows that kernels are not
20//! responsible for graph traversal. The executor handles dependency ordering; each
21//! kernel only computes one node once its dependencies are available.
22//!
23//! For kernel authors:
24//!
//! The executor evaluates this graph in dependency order, roughly as follows:
26//! 1. read input bindings for `a`, `b`, `c`
27//! 2. execute ReLU kernel for `ra`
28//! 3. execute ReLU kernel for `rb`
29//! 4. execute MatMul kernel for `mm`
30//! 5. execute Add kernel for `out`
31//!
//! Important consequences:
//! - a kernel does not need to recursively evaluate upstream nodes,
//! - a kernel only consumes already-available input tensors,
//! - input shapes should already be mutually compatible if the graph builder
//!   enforces shape rules, though kernels may still defensively validate runtime inputs.
37//!
38//! A custom MatMul kernel, for example, would typically:
39//! - expect exactly 2 input tensors,
40//! - verify rank/shape assumptions,
41//! - compute matrix multiplication,
42//! - return a new tensor with the node's declared output shape.
43
44use tensor_forge::{Executor, Graph, KernelRegistry, Tensor};
45
46fn main() {
47 let mut graph = Graph::new();
48
49 // Declare graph inputs.
50 //
51 // The shapes here establish the legal runtime tensor shapes:
52 // a: [2, 3]
53 // b: [3, 2]
54 // c: [2, 2]
55 let a = graph.input_node(vec![2, 3]);
56 let b = graph.input_node(vec![3, 2]);
57 let c = graph.input_node(vec![2, 2]);
58
59 // Build intermediate operations.
60 //
61 // `relu(a)` preserves shape [2, 3].
62 let ra = graph.relu(a).expect("Valid ReLU operation should succeed");
63
64 // `relu(b)` preserves shape [3, 2].
65 let rb = graph.relu(b).expect("Valid ReLU operation should succeed");
66
67 // `matmul(ra, rb)` combines [2, 3] x [3, 2] -> [2, 2].
68 //
69 // This is a good example of graph-level validation preventing malformed graphs
70 // before execution ever begins.
71 let mm = graph
72 .matmul(ra, rb)
73 .expect("Valid matmul operation should succeed");
74
75 // `add(mm, c)` adds two [2, 2] tensors and also yields [2, 2].
76 let out = graph
77 .add(mm, c)
78 .expect("Valid add operation should succeed");
79
80 graph
81 .set_output_node(out)
82 .expect("Setting output node should succeed");
83
84 // Bind concrete runtime values.
85 //
86 // These values are only examples; the graph structure is independent of them.
87 // The same graph can be executed many times with different input tensors.
88 let a_tensor = Tensor::from_vec(vec![2, 3], vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0])
89 .expect("Tensor construction should succeed");
90
91 let b_tensor = Tensor::from_vec(vec![3, 2], vec![-7.0, 8.0, 9.0, -10.0, 11.0, 12.0])
92 .expect("Tensor construction should succeed");
93
94 let c_tensor = Tensor::from_vec(vec![2, 2], vec![0.5, 1.5, 2.5, 3.5])
95 .expect("Tensor construction should succeed");
96
97 let exec = Executor::new(KernelRegistry::default());
98
99 let outputs = exec
100 .execute(&graph, vec![(a, a_tensor), (b, b_tensor), (c, c_tensor)])
101 .expect("Execution should succeed");
102
103 let result = outputs
104 .get(&out)
105 .expect("Declared output should be present in executor results");
106
107 println!("Computed output for node {:?}: {:?}", out, result);
108}