Skip to main content

custom_kernel/
custom_kernel.rs

1//! Demonstrates how to define and register a custom kernel.
2//!
3//! This example is aimed at users who want to extend the runtime with their own
4//! kernel implementations.
5//!
6//! The important architectural pieces are:
7//!
8//! - `Graph` describes *what* should be computed.
9//! - `Executor` decides *when* each node is executed.
10//! - `KernelRegistry` decides *which kernel* handles each `OpKind`.
11//! - A `Kernel` implementation computes the output for one node from its inputs.
12//!
//! In this example, we register a custom kernel for `OpKind::Add` in an otherwise empty registry, so the executor runs our implementation instead of any default `Add` kernel.
14//!
15//! ## The kernel contract
16//!
17//! A kernel implementation receives:
18//! - `inputs`: already-computed input tensors for the current node
19//! - `output`: a mutable tensor already allocated by the executor
20//!
21//! A kernel is expected to:
22//! 1. validate its input count and any assumptions it cares about,
23//! 2. read from `inputs`,
24//! 3. write the result into `output`,
25//! 4. return `Ok(())` on success or `Err(KernelError)` on failure.
26//!
27//! The executor handles:
28//! - topological ordering,
29//! - dependency resolution,
30//! - output allocation,
31//! - error attribution back to the graph node.
32//!
33//! So kernel authors do **not** need to traverse the graph themselves.
34
35use tensor_forge::executor::Executor;
36use tensor_forge::graph::Graph;
37use tensor_forge::kernel::{Kernel, KernelError};
38use tensor_forge::op::OpKind;
39use tensor_forge::registry::KernelRegistry;
40use tensor_forge::tensor::Tensor;
41
/// User-supplied kernel that performs element-wise addition for `OpKind::Add` nodes.
///
/// The type is stateless: it exists solely so that it can implement [`Kernel`]
/// and be handed to a `KernelRegistry`. Its purpose in this example is to show
/// how a custom kernel plugs into the registry and executor, not to be fast.
struct CustomAddKernel;
47
48impl Kernel for CustomAddKernel {
49    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError> {
50        // Validate arity.
51        //
52        // The Add operation expects exactly two input tensors.
53        if inputs.len() != 2 {
54            return Err(KernelError::InvalidArguments);
55        }
56
57        // Validate shape agreement.
58        //
59        // Graph construction should already guarantee this for well-formed Add nodes,
60        // but kernels may still validate assumptions defensively.
61        let left = inputs[0];
62        let right = inputs[1];
63
64        if left.shape() != right.shape() || left.shape() != output.shape() {
65            return Err(KernelError::InvalidArguments);
66        }
67
68        let left_data = left.data();
69        let right_data = right.data();
70        let output_data = output.data_mut();
71        for i in 0..output_data.len() {
72            output_data[i] = left_data[i] + right_data[i];
73        }
74        Ok(())
75    }
76}
77
78fn main() {
79    // Build a tiny graph:
80    //
81    //   out = add(a, b)
82    //
83    // The graph describes *what* should be computed, not how.
84    let mut graph = Graph::new();
85
86    let a = graph.input_node(vec![2, 2]);
87    let b = graph.input_node(vec![2, 2]);
88    let out = graph
89        .add(a, b)
90        .expect("Adding valid input nodes should succeed");
91
92    graph
93        .set_output_node(out)
94        .expect("Setting output node should succeed");
95
96    // Create a custom registry.
97    //
98    // Start from an empty registry and explicitly register only the kernel(s)
99    // needed by this graph.
100    let mut registry = KernelRegistry::new();
101
102    // Register our custom Add kernel.
103    //
104    // `register(...)` returns the previous mapping if one existed.
105    let old = registry.register(OpKind::Add, Box::new(CustomAddKernel));
106    assert!(
107        old.is_none(),
108        "First Add registration should not replace an existing kernel"
109    );
110
111    // Construct the executor with the custom registry.
112    let exec = Executor::new(registry);
113
114    // Bind runtime inputs.
115    //
116    // These are ordinary tensors supplied for the graph input nodes.
117    let a_tensor = Tensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0])
118        .expect("Tensor construction should succeed");
119    let b_tensor = Tensor::from_vec(vec![2, 2], vec![10.0, 20.0, 30.0, 40.0])
120        .expect("Tensor construction should succeed");
121
122    // Execute the graph.
123    //
124    // During execution:
125    // - the executor validates input bindings,
126    // - walks the graph in topological order,
127    // - sees an `OpKind::Add` node,
128    // - looks up `OpKind::Add` in the registry,
129    // - dispatches to `CustomAddKernel::compute(...)`.
130    let outputs = exec
131        .execute(&graph, vec![(a, a_tensor), (b, b_tensor)])
132        .expect("Execution should succeed");
133
134    let result = outputs
135        .get(&out)
136        .expect("Declared output should be present in executor results");
137
138    println!("Computed output for node {:?}: {:?}", out, result);
139}