tensor_forge/executor.rs
1//! Execution engine for evaluating compute graphs against a [`KernelRegistry`].
2//!
3//! An [`Executor`] is responsible for:
4//! - validating runtime input bindings,
5//! - traversing a [`Graph`] in topological order,
6//! - dispatching each non-input node to its registered kernel, and
7//! - returning the tensors for all nodes marked as graph outputs.
8//!
9//! Execution is deterministic with respect to graph topology because it relies on
10//! [`Graph::topo_sort`], which returns a stable topological order.
11//!
12//! # Input Binding Model
13//!
14//! Execution requires one tensor binding for every [`OpKind::Input`] node in the graph.
15//! Bindings are passed as a `Vec<(NodeId, Tensor)>`, where:
16//! - the [`NodeId`] identifies an input node in the graph, and
17//! - the [`Tensor`] is the runtime value to supply for that input.
18//!
19//! Bindings are validated before execution begins:
20//! - each bound node must exist in the graph,
21//! - each bound node must be an input node,
22//! - each input node must be bound exactly once, and
23//! - each bound tensor must match the input node’s declared shape.
24//!
25//! # Output Model
26//!
27//! The executor returns a `HashMap<NodeId, Tensor>` containing the computed tensors
28//! for all nodes designated as outputs via [`Graph::set_output_node`].
29//!
30//! Output values are keyed by the output node’s [`NodeId`]. Output ordering is not
31//! part of the API contract.
32//!
33//! # Errors
34//!
35//! Execution may fail for several classes of reasons:
36//! - graph-level failures (for example, invalid topology),
37//! - invalid runtime bindings,
38//! - missing kernel implementations, or
39//! - kernel execution failures at a specific node.
40//!
41//! These are reported via [`ExecutionError`].
42//!
43//! # Examples
44//! ```
45//! # use tensor_forge::executor::Executor;
46//! # use tensor_forge::graph::Graph;
47//! # use tensor_forge::registry::KernelRegistry;
48//! # use tensor_forge::tensor::Tensor;
49//! let mut g = Graph::new();
50//! let a = g.input_node(vec![2, 2]);
51//! let b = g.input_node(vec![2, 2]);
52//! let c = g.add(a, b).expect("Valid add operation should succeed");
53//! g.set_output_node(c).expect("Valid output node should succeed");
54//!
55//! let a_tensor = Tensor::zeros(vec![2, 2]).expect("Tensor allocation should succeed");
56//! let b_tensor = Tensor::zeros(vec![2, 2]).expect("Tensor allocation should succeed");
57//!
58//! let exec = Executor::new(KernelRegistry::default());
59//! let outputs = exec
60//! .execute(&g, vec![(a, a_tensor), (b, b_tensor)])
61//! .expect("Execution should succeed");
62//!
63//! assert!(outputs.contains_key(&c));
64//! ```
65use crate::graph::{Graph, GraphError};
66use crate::kernel::KernelError;
67use crate::node::NodeId;
68use crate::op::OpKind;
69use crate::registry::KernelRegistry;
70use crate::tensor::Tensor;
71use std::collections::{HashMap, HashSet};
72
73/// Executes graphs using kernels registered in a [`KernelRegistry`].
74///
75/// An [`Executor`] owns a kernel registry and uses it to evaluate graph nodes
76/// according to each node’s [`OpKind`]. Non-input nodes are executed in
77/// deterministic topological order, and intermediate tensors are stored internally
78/// until all requested graph outputs have been produced.
79///
80/// # Examples
81/// ```
82/// # use tensor_forge::executor::Executor;
83/// # use tensor_forge::registry::KernelRegistry;
84/// let exec = Executor::new(KernelRegistry::default());
85/// ```
86pub struct Executor {
87 registry: KernelRegistry,
88}
89
90impl Executor {
91 /// Creates a new executor backed by the provided kernel `registry`.
92 ///
93 /// The registry determines which kernel implementation will be used for each
94 /// operation kind encountered during execution.
95 ///
96 /// # Examples
97 /// ```
98 /// # use tensor_forge::executor::Executor;
99 /// # use tensor_forge::registry::KernelRegistry;
100 /// let registry = KernelRegistry::default();
101 /// let exec = Executor::new(registry);
102 /// ```
103 #[must_use]
104 pub fn new(registry: KernelRegistry) -> Self {
105 Self { registry }
106 }
107
108 /// Executes `graph` using the provided input `bindings`.
109 ///
110 /// Each binding is a `(NodeId, Tensor)` pair supplying the runtime value for a
111 /// graph input node. Execution proceeds in deterministic topological order:
112 ///
113 /// 1. Validate the graph topology.
114 /// 2. Validate input bindings.
115 /// 3. Execute every non-input node using the corresponding registered kernel.
116 /// 4. Return a map containing the tensors for all graph output nodes.
117 ///
118 /// The returned map is keyed by output [`NodeId`]. Output order is not part of
119 /// the contract.
120 ///
121 /// # Binding Rules
122 ///
123 /// The `inputs` vector must satisfy all of the following:
124 /// - every bound node must exist in `graph`,
125 /// - every bound node must be an [`OpKind::Input`] node,
126 /// - every graph input node must appear exactly once, and
127 /// - every bound tensor shape must match the input node’s declared shape.
128 ///
129 /// # Errors
130 ///
131 /// Returns:
132 /// - [`ExecutionError::GraphError`] if topological traversal or graph lookup fails,
133 /// - [`ExecutionError::DuplicateBinding`] if the same input node is bound more than once,
134 /// - [`ExecutionError::InvalidBindingNode`] if a binding references a node not in the graph,
135 /// - [`ExecutionError::BindingToNonInputNode`] if a binding targets a non-input node,
136 /// - [`ExecutionError::InputShapeMismatch`] if a bound tensor has the wrong shape,
137 /// - [`ExecutionError::MissingInput`] if any graph input node is not bound,
138 /// - [`ExecutionError::KernelNotFound`] if no kernel is registered for an op,
139 /// - [`ExecutionError::KernelExecutionFailed`] if a kernel returns an error during execution,
140 /// - [`ExecutionError::InternalError`] if an internal invariant is violated.
141 ///
142 /// # Examples
143 /// ```
144 /// # use tensor_forge::executor::Executor;
145 /// # use tensor_forge::graph::Graph;
146 /// # use tensor_forge::registry::KernelRegistry;
147 /// # use tensor_forge::tensor::Tensor;
148 /// let mut g = Graph::new();
149 /// let x = g.input_node(vec![2, 2]);
150 /// let y = g.relu(x).expect("Valid ReLU operation should succeed");
151 /// g.set_output_node(y).expect("Valid output node should succeed");
152 ///
153 /// let x_tensor = Tensor::zeros(vec![2, 2]).expect("Tensor allocation should succeed");
154 ///
155 /// let exec = Executor::new(KernelRegistry::default());
156 /// let outputs = exec
157 /// .execute(&g, vec![(x, x_tensor)])
158 /// .expect("Execution should succeed");
159 ///
160 /// assert!(outputs.contains_key(&y));
161 /// ```
162 pub fn execute(
163 &self,
164 graph: &Graph,
165 inputs: Vec<(NodeId, Tensor)>,
166 ) -> Result<HashMap<NodeId, Tensor>, ExecutionError> {
167 // 0) Discover input nodes (OpKind::Input) and validate bindings.
168 //
169 // NOTE: This uses topo_sort() to traverse all nodes deterministically.
170 // If the graph is cyclic/malformed, this returns a GraphError.
171 let topo = graph.topo_sort().map_err(ExecutionError::GraphError)?;
172
173 let mut input_nodes: Vec<NodeId> = Vec::new();
174 for &id in &topo {
175 let node = graph.node(id).map_err(ExecutionError::GraphError)?;
176 if node.op == OpKind::Input {
177 input_nodes.push(id);
178 }
179 }
180
181 // 1) Validate: duplicate bindings
182 let mut seen: HashSet<NodeId> = HashSet::with_capacity(inputs.len());
183 for (id, _) in &inputs {
184 if !seen.insert(*id) {
185 return Err(ExecutionError::DuplicateBinding { node: *id });
186 }
187 }
188
189 // 2) Validate: binding node exists, is input node, and shape matches
190 //
191 // Also build the runtime value table with the provided inputs.
192 let mut values: HashMap<NodeId, Tensor> =
193 HashMap::with_capacity(graph.num_nodes().max(inputs.len()));
194
195 for (id, t) in inputs {
196 let node = graph
197 .node(id)
198 .map_err(|_| ExecutionError::InvalidBindingNode { node: id })?;
199
200 if node.op != OpKind::Input {
201 return Err(ExecutionError::BindingToNonInputNode {
202 node: id,
203 op: node.op.clone(),
204 });
205 }
206
207 let expected = node.shape.clone();
208 let actual = t.shape().to_vec();
209 if actual != expected {
210 return Err(ExecutionError::InputShapeMismatch {
211 node: id,
212 expected,
213 actual,
214 });
215 }
216
217 // Move the owned tensor into the value table.
218 values.insert(id, t);
219 }
220
221 // 3) Validate: all graph inputs are present in bindings
222 for &input_id in &input_nodes {
223 if !values.contains_key(&input_id) {
224 return Err(ExecutionError::MissingInput { node: input_id });
225 }
226 }
227
228 // 4) Execute in topological order.
229 for &node_id in &topo {
230 let node = graph.node(node_id).map_err(ExecutionError::GraphError)?;
231
232 if node.op == OpKind::Input {
233 // Inputs were already populated by bindings validation.
234 continue;
235 }
236
237 // 4a) Fetch kernel
238 let kernel =
239 self.registry
240 .get(&node.op)
241 .ok_or_else(|| ExecutionError::KernelNotFound {
242 op: node.op.clone(),
243 })?;
244
245 // 4b) Fetch input tensors (by NodeId)
246 let mut input_tensors: Vec<&Tensor> = Vec::with_capacity(node.inputs.len());
247 for &dep in &node.inputs {
248 let t = values.get(&dep).ok_or({
249 // This indicates a bug in topo_sort/executor bookkeeping, or a malformed graph.
250 ExecutionError::InternalError("missing dependency tensor during execution")
251 })?;
252 input_tensors.push(t);
253 }
254
255 // 4c) Allocate output tensor
256 let mut out = Tensor::zeros(node.shape.clone())
257 .map_err(|_| ExecutionError::InternalError("failed to allocate output tensor"))?;
258
259 // 4d) Execute kernel
260 kernel
261 .compute(&input_tensors, &mut out)
262 .map_err(|e: KernelError| ExecutionError::KernelExecutionFailed {
263 node: node_id,
264 op: node.op.clone(),
265 source: e,
266 })?;
267
268 // Store produced tensor for downstream consumers.
269 values.insert(node_id, out);
270 }
271
272 // 5) Collect outputs by output NodeId.
273 let mut outputs: HashMap<NodeId, Tensor> = HashMap::with_capacity(graph.outputs().len());
274 for &out_id in graph.outputs() {
275 let t = values.remove(&out_id).ok_or({
276 ExecutionError::InternalError("output tensor missing after execution")
277 })?;
278 outputs.insert(out_id, t);
279 }
280
281 Ok(outputs)
282 }
283}
284
285impl Default for Executor {
286 /// Creates an executor using [`KernelRegistry::default`].
287 ///
288 /// This is a convenience constructor for the common case where the default
289 /// kernel set is sufficient.
290 ///
291 /// # Examples
292 /// ```
293 /// # use tensor_forge::executor::Executor;
294 /// let exec = Executor::default();
295 /// ```
296 fn default() -> Self {
297 Self {
298 registry: KernelRegistry::default(),
299 }
300 }
301}
302
303/// Errors that may occur while validating bindings or executing a graph.
304///
305/// These errors distinguish between:
306/// - invalid caller-supplied bindings,
307/// - graph-level failures,
308/// - missing kernel implementations, and
309/// - runtime kernel failures.
310///
311/// # Examples
312/// ```
313/// # use tensor_forge::executor::{ExecutionError, Executor};
314/// # use tensor_forge::graph::Graph;
315/// # use tensor_forge::registry::KernelRegistry;
316/// # use tensor_forge::tensor::Tensor;
317/// let mut g = Graph::new();
318/// let x = g.input_node(vec![2, 2]);
319/// g.set_output_node(x).expect("Valid output node should succeed");
320///
321/// let wrong = Tensor::zeros(vec![3, 3]).expect("Tensor allocation should succeed");
322/// let exec = Executor::new(KernelRegistry::default());
323///
324/// let err = exec.execute(&g, vec![(x, wrong)]).unwrap_err();
325/// assert!(matches!(err, ExecutionError::InputShapeMismatch { .. }));
326/// ```
327#[derive(Debug, Clone)]
328pub enum ExecutionError {
329 /// A required input node was not provided in the bindings.
330 ///
331 /// Every [`OpKind::Input`] node in the graph must have exactly one runtime binding.
332 MissingInput {
333 /// The input node that was not bound at execution time.
334 node: NodeId,
335 },
336
337 /// A binding was provided for a node that exists in the graph but is not an input node.
338 ///
339 /// Only input nodes may be bound directly by the caller.
340 BindingToNonInputNode {
341 /// The node that was incorrectly bound.
342 node: NodeId,
343 /// The operation kind of `node`.
344 op: OpKind,
345 },
346
347 /// A binding was provided for a node that does not exist in the graph.
348 ///
349 /// This typically indicates that:
350 /// - a stale [`NodeId`] was reused, or
351 /// - a [`NodeId`] from another graph was supplied.
352 InvalidBindingNode {
353 /// The invalid node identifier supplied in the bindings.
354 node: NodeId,
355 },
356
357 /// Multiple bindings were provided for the same input node.
358 ///
359 /// Input bindings must be unique by [`NodeId`].
360 DuplicateBinding {
361 /// The input node that was bound more than once.
362 node: NodeId,
363 },
364
365 /// A runtime tensor shape did not match the graph input node’s declared shape.
366 ///
367 /// The `expected` shape is taken from the graph node, while `actual` is the
368 /// shape of the caller-provided tensor.
369 InputShapeMismatch {
370 /// The input node whose binding had the wrong shape.
371 node: NodeId,
372 /// The shape declared by the graph for `node`.
373 expected: Vec<usize>,
374 /// The shape of the tensor supplied by the caller.
375 actual: Vec<usize>,
376 },
377
378 /// No kernel implementation was registered for the requested operation.
379 ///
380 /// This prevents execution of any node with the given [`OpKind`].
381 KernelNotFound {
382 /// The operation kind for which no kernel was registered.
383 op: OpKind,
384 },
385
386 /// A kernel returned an error while executing a specific graph node.
387 ///
388 /// This variant preserves:
389 /// - the failing node ID,
390 /// - the operation kind being executed, and
391 /// - the underlying [`KernelError`].
392 KernelExecutionFailed {
393 /// The graph node whose kernel execution failed.
394 node: NodeId,
395 /// The operation kind being executed at `node`.
396 op: OpKind,
397 /// The underlying kernel error returned by the registered kernel.
398 source: KernelError,
399 },
400
401 /// A graph-level failure occurred while preparing execution.
402 ///
403 /// This wraps errors originating from graph traversal or graph validation,
404 /// such as cycle detection or invalid node references.
405 GraphError(GraphError),
406
407 /// An internal executor invariant was violated.
408 ///
409 /// This variant indicates a bug or malformed internal state rather than a
410 /// user-facing validation issue. Under normal operation, callers should not
411 /// be able to trigger this error through the public API alone.
412 InternalError(&'static str),
413}
414
415impl std::fmt::Display for ExecutionError {
416 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
417 match self {
418 ExecutionError::MissingInput { node } => {
419 write!(f, "Missing input binding for node {node:?}")
420 }
421
422 ExecutionError::BindingToNonInputNode { node, op } => {
423 write!(f, "Cannot bind tensor to non-input node {node:?} ({op:?})")
424 }
425
426 ExecutionError::InvalidBindingNode { node } => {
427 write!(f, "Binding provided for invalid node {node:?}")
428 }
429
430 ExecutionError::DuplicateBinding { node } => {
431 write!(f, "Duplicate binding for node {node:?}")
432 }
433
434 ExecutionError::InputShapeMismatch {
435 node,
436 expected,
437 actual,
438 } => write!(
439 f,
440 "Shape mismatch for node {node:?}: expected {expected:?}, got {actual:?}",
441 ),
442
443 ExecutionError::KernelNotFound { op } => {
444 write!(f, "No kernel registered for operation {op:?}")
445 }
446
447 ExecutionError::KernelExecutionFailed { node, op, source } => write!(
448 f,
449 "Kernel execution failed at node {node:?} ({op:?}): {source}",
450 ),
451
452 ExecutionError::GraphError(e) => write!(f, "Graph error during execution: {e}"),
453
454 ExecutionError::InternalError(msg) => write!(f, "Executor internal error: {msg}"),
455 }
456 }
457}