Skip to main content

tensor_forge/
executor.rs

1//! Execution engine for evaluating compute graphs against a [`KernelRegistry`].
2//!
3//! An [`Executor`] is responsible for:
4//! - validating runtime input bindings,
5//! - traversing a [`Graph`] in topological order,
6//! - dispatching each non-input node to its registered kernel, and
7//! - returning the tensors for all nodes marked as graph outputs.
8//!
9//! Execution is deterministic with respect to graph topology because it relies on
10//! [`Graph::topo_sort`], which returns a stable topological order.
11//!
12//! # Input Binding Model
13//!
14//! Execution requires one tensor binding for every [`OpKind::Input`] node in the graph.
15//! Bindings are passed as a `Vec<(NodeId, Tensor)>`, where:
16//! - the [`NodeId`] identifies an input node in the graph, and
17//! - the [`Tensor`] is the runtime value to supply for that input.
18//!
19//! Bindings are validated before execution begins:
20//! - each bound node must exist in the graph,
21//! - each bound node must be an input node,
22//! - each input node must be bound exactly once, and
23//! - each bound tensor must match the input node’s declared shape.
24//!
25//! # Output Model
26//!
27//! The executor returns a `HashMap<NodeId, Tensor>` containing the computed tensors
28//! for all nodes designated as outputs via [`Graph::set_output_node`].
29//!
30//! Output values are keyed by the output node’s [`NodeId`]. Output ordering is not
31//! part of the API contract.
32//!
33//! # Errors
34//!
35//! Execution may fail for several classes of reasons:
36//! - graph-level failures (for example, invalid topology),
37//! - invalid runtime bindings,
38//! - missing kernel implementations, or
39//! - kernel execution failures at a specific node.
40//!
41//! These are reported via [`ExecutionError`].
42//!
43//! # Examples
44//! ```
45//! # use tensor_forge::executor::Executor;
46//! # use tensor_forge::graph::Graph;
47//! # use tensor_forge::registry::KernelRegistry;
48//! # use tensor_forge::tensor::Tensor;
49//! let mut g = Graph::new();
50//! let a = g.input_node(vec![2, 2]);
51//! let b = g.input_node(vec![2, 2]);
52//! let c = g.add(a, b).expect("Valid add operation should succeed");
53//! g.set_output_node(c).expect("Valid output node should succeed");
54//!
55//! let a_tensor = Tensor::zeros(vec![2, 2]).expect("Tensor allocation should succeed");
56//! let b_tensor = Tensor::zeros(vec![2, 2]).expect("Tensor allocation should succeed");
57//!
58//! let exec = Executor::new(KernelRegistry::default());
59//! let outputs = exec
60//!     .execute(&g, vec![(a, a_tensor), (b, b_tensor)])
61//!     .expect("Execution should succeed");
62//!
63//! assert!(outputs.contains_key(&c));
64//! ```
65use crate::graph::{Graph, GraphError};
66use crate::kernel::KernelError;
67use crate::node::NodeId;
68use crate::op::OpKind;
69use crate::registry::KernelRegistry;
70use crate::tensor::Tensor;
71use std::collections::{HashMap, HashSet};
72
73/// Executes graphs using kernels registered in a [`KernelRegistry`].
74///
75/// An [`Executor`] owns a kernel registry and uses it to evaluate graph nodes
76/// according to each node’s [`OpKind`]. Non-input nodes are executed in
77/// deterministic topological order, and intermediate tensors are stored internally
78/// until all requested graph outputs have been produced.
79///
80/// # Examples
81/// ```
82/// # use tensor_forge::executor::Executor;
83/// # use tensor_forge::registry::KernelRegistry;
84/// let exec = Executor::new(KernelRegistry::default());
85/// ```
86pub struct Executor {
87    registry: KernelRegistry,
88}
89
90impl Executor {
91    /// Creates a new executor backed by the provided kernel `registry`.
92    ///
93    /// The registry determines which kernel implementation will be used for each
94    /// operation kind encountered during execution.
95    ///
96    /// # Examples
97    /// ```
98    /// # use tensor_forge::executor::Executor;
99    /// # use tensor_forge::registry::KernelRegistry;
100    /// let registry = KernelRegistry::default();
101    /// let exec = Executor::new(registry);
102    /// ```
103    #[must_use]
104    pub fn new(registry: KernelRegistry) -> Self {
105        Self { registry }
106    }
107
108    /// Executes `graph` using the provided input `bindings`.
109    ///
110    /// Each binding is a `(NodeId, Tensor)` pair supplying the runtime value for a
111    /// graph input node. Execution proceeds in deterministic topological order:
112    ///
113    /// 1. Validate the graph topology.
114    /// 2. Validate input bindings.
115    /// 3. Execute every non-input node using the corresponding registered kernel.
116    /// 4. Return a map containing the tensors for all graph output nodes.
117    ///
118    /// The returned map is keyed by output [`NodeId`]. Output order is not part of
119    /// the contract.
120    ///
121    /// # Binding Rules
122    ///
123    /// The `inputs` vector must satisfy all of the following:
124    /// - every bound node must exist in `graph`,
125    /// - every bound node must be an [`OpKind::Input`] node,
126    /// - every graph input node must appear exactly once, and
127    /// - every bound tensor shape must match the input node’s declared shape.
128    ///
129    /// # Errors
130    ///
131    /// Returns:
132    /// - [`ExecutionError::GraphError`] if topological traversal or graph lookup fails,
133    /// - [`ExecutionError::DuplicateBinding`] if the same input node is bound more than once,
134    /// - [`ExecutionError::InvalidBindingNode`] if a binding references a node not in the graph,
135    /// - [`ExecutionError::BindingToNonInputNode`] if a binding targets a non-input node,
136    /// - [`ExecutionError::InputShapeMismatch`] if a bound tensor has the wrong shape,
137    /// - [`ExecutionError::MissingInput`] if any graph input node is not bound,
138    /// - [`ExecutionError::KernelNotFound`] if no kernel is registered for an op,
139    /// - [`ExecutionError::KernelExecutionFailed`] if a kernel returns an error during execution,
140    /// - [`ExecutionError::InternalError`] if an internal invariant is violated.
141    ///
142    /// # Examples
143    /// ```
144    /// # use tensor_forge::executor::Executor;
145    /// # use tensor_forge::graph::Graph;
146    /// # use tensor_forge::registry::KernelRegistry;
147    /// # use tensor_forge::tensor::Tensor;
148    /// let mut g = Graph::new();
149    /// let x = g.input_node(vec![2, 2]);
150    /// let y = g.relu(x).expect("Valid ReLU operation should succeed");
151    /// g.set_output_node(y).expect("Valid output node should succeed");
152    ///
153    /// let x_tensor = Tensor::zeros(vec![2, 2]).expect("Tensor allocation should succeed");
154    ///
155    /// let exec = Executor::new(KernelRegistry::default());
156    /// let outputs = exec
157    ///     .execute(&g, vec![(x, x_tensor)])
158    ///     .expect("Execution should succeed");
159    ///
160    /// assert!(outputs.contains_key(&y));
161    /// ```
162    pub fn execute(
163        &self,
164        graph: &Graph,
165        inputs: Vec<(NodeId, Tensor)>,
166    ) -> Result<HashMap<NodeId, Tensor>, ExecutionError> {
167        // 0) Discover input nodes (OpKind::Input) and validate bindings.
168        //
169        // NOTE: This uses topo_sort() to traverse all nodes deterministically.
170        // If the graph is cyclic/malformed, this returns a GraphError.
171        let topo = graph.topo_sort().map_err(ExecutionError::GraphError)?;
172
173        let mut input_nodes: Vec<NodeId> = Vec::new();
174        for &id in &topo {
175            let node = graph.node(id).map_err(ExecutionError::GraphError)?;
176            if node.op == OpKind::Input {
177                input_nodes.push(id);
178            }
179        }
180
181        // 1) Validate: duplicate bindings
182        let mut seen: HashSet<NodeId> = HashSet::with_capacity(inputs.len());
183        for (id, _) in &inputs {
184            if !seen.insert(*id) {
185                return Err(ExecutionError::DuplicateBinding { node: *id });
186            }
187        }
188
189        // 2) Validate: binding node exists, is input node, and shape matches
190        //
191        // Also build the runtime value table with the provided inputs.
192        let mut values: HashMap<NodeId, Tensor> =
193            HashMap::with_capacity(graph.num_nodes().max(inputs.len()));
194
195        for (id, t) in inputs {
196            let node = graph
197                .node(id)
198                .map_err(|_| ExecutionError::InvalidBindingNode { node: id })?;
199
200            if node.op != OpKind::Input {
201                return Err(ExecutionError::BindingToNonInputNode {
202                    node: id,
203                    op: node.op.clone(),
204                });
205            }
206
207            let expected = node.shape.clone();
208            let actual = t.shape().to_vec();
209            if actual != expected {
210                return Err(ExecutionError::InputShapeMismatch {
211                    node: id,
212                    expected,
213                    actual,
214                });
215            }
216
217            // Move the owned tensor into the value table.
218            values.insert(id, t);
219        }
220
221        // 3) Validate: all graph inputs are present in bindings
222        for &input_id in &input_nodes {
223            if !values.contains_key(&input_id) {
224                return Err(ExecutionError::MissingInput { node: input_id });
225            }
226        }
227
228        // 4) Execute in topological order.
229        for &node_id in &topo {
230            let node = graph.node(node_id).map_err(ExecutionError::GraphError)?;
231
232            if node.op == OpKind::Input {
233                // Inputs were already populated by bindings validation.
234                continue;
235            }
236
237            // 4a) Fetch kernel
238            let kernel =
239                self.registry
240                    .get(&node.op)
241                    .ok_or_else(|| ExecutionError::KernelNotFound {
242                        op: node.op.clone(),
243                    })?;
244
245            // 4b) Fetch input tensors (by NodeId)
246            let mut input_tensors: Vec<&Tensor> = Vec::with_capacity(node.inputs.len());
247            for &dep in &node.inputs {
248                let t = values.get(&dep).ok_or({
249                    // This indicates a bug in topo_sort/executor bookkeeping, or a malformed graph.
250                    ExecutionError::InternalError("missing dependency tensor during execution")
251                })?;
252                input_tensors.push(t);
253            }
254
255            // 4c) Allocate output tensor
256            let mut out = Tensor::zeros(node.shape.clone())
257                .map_err(|_| ExecutionError::InternalError("failed to allocate output tensor"))?;
258
259            // 4d) Execute kernel
260            kernel
261                .compute(&input_tensors, &mut out)
262                .map_err(|e: KernelError| ExecutionError::KernelExecutionFailed {
263                    node: node_id,
264                    op: node.op.clone(),
265                    source: e,
266                })?;
267
268            // Store produced tensor for downstream consumers.
269            values.insert(node_id, out);
270        }
271
272        // 5) Collect outputs by output NodeId.
273        let mut outputs: HashMap<NodeId, Tensor> = HashMap::with_capacity(graph.outputs().len());
274        for &out_id in graph.outputs() {
275            let t = values.remove(&out_id).ok_or({
276                ExecutionError::InternalError("output tensor missing after execution")
277            })?;
278            outputs.insert(out_id, t);
279        }
280
281        Ok(outputs)
282    }
283}
284
285impl Default for Executor {
286    /// Creates an executor using [`KernelRegistry::default`].
287    ///
288    /// This is a convenience constructor for the common case where the default
289    /// kernel set is sufficient.
290    ///
291    /// # Examples
292    /// ```
293    /// # use tensor_forge::executor::Executor;
294    /// let exec = Executor::default();
295    /// ```
296    fn default() -> Self {
297        Self {
298            registry: KernelRegistry::default(),
299        }
300    }
301}
302
303/// Errors that may occur while validating bindings or executing a graph.
304///
305/// These errors distinguish between:
306/// - invalid caller-supplied bindings,
307/// - graph-level failures,
308/// - missing kernel implementations, and
309/// - runtime kernel failures.
310///
311/// # Examples
312/// ```
313/// # use tensor_forge::executor::{ExecutionError, Executor};
314/// # use tensor_forge::graph::Graph;
315/// # use tensor_forge::registry::KernelRegistry;
316/// # use tensor_forge::tensor::Tensor;
317/// let mut g = Graph::new();
318/// let x = g.input_node(vec![2, 2]);
319/// g.set_output_node(x).expect("Valid output node should succeed");
320///
321/// let wrong = Tensor::zeros(vec![3, 3]).expect("Tensor allocation should succeed");
322/// let exec = Executor::new(KernelRegistry::default());
323///
324/// let err = exec.execute(&g, vec![(x, wrong)]).unwrap_err();
325/// assert!(matches!(err, ExecutionError::InputShapeMismatch { .. }));
326/// ```
327#[derive(Debug, Clone)]
328pub enum ExecutionError {
329    /// A required input node was not provided in the bindings.
330    ///
331    /// Every [`OpKind::Input`] node in the graph must have exactly one runtime binding.
332    MissingInput {
333        /// The input node that was not bound at execution time.
334        node: NodeId,
335    },
336
337    /// A binding was provided for a node that exists in the graph but is not an input node.
338    ///
339    /// Only input nodes may be bound directly by the caller.
340    BindingToNonInputNode {
341        /// The node that was incorrectly bound.
342        node: NodeId,
343        /// The operation kind of `node`.
344        op: OpKind,
345    },
346
347    /// A binding was provided for a node that does not exist in the graph.
348    ///
349    /// This typically indicates that:
350    /// - a stale [`NodeId`] was reused, or
351    /// - a [`NodeId`] from another graph was supplied.
352    InvalidBindingNode {
353        /// The invalid node identifier supplied in the bindings.
354        node: NodeId,
355    },
356
357    /// Multiple bindings were provided for the same input node.
358    ///
359    /// Input bindings must be unique by [`NodeId`].
360    DuplicateBinding {
361        /// The input node that was bound more than once.
362        node: NodeId,
363    },
364
365    /// A runtime tensor shape did not match the graph input node’s declared shape.
366    ///
367    /// The `expected` shape is taken from the graph node, while `actual` is the
368    /// shape of the caller-provided tensor.
369    InputShapeMismatch {
370        /// The input node whose binding had the wrong shape.
371        node: NodeId,
372        /// The shape declared by the graph for `node`.
373        expected: Vec<usize>,
374        /// The shape of the tensor supplied by the caller.
375        actual: Vec<usize>,
376    },
377
378    /// No kernel implementation was registered for the requested operation.
379    ///
380    /// This prevents execution of any node with the given [`OpKind`].
381    KernelNotFound {
382        /// The operation kind for which no kernel was registered.
383        op: OpKind,
384    },
385
386    /// A kernel returned an error while executing a specific graph node.
387    ///
388    /// This variant preserves:
389    /// - the failing node ID,
390    /// - the operation kind being executed, and
391    /// - the underlying [`KernelError`].
392    KernelExecutionFailed {
393        /// The graph node whose kernel execution failed.
394        node: NodeId,
395        /// The operation kind being executed at `node`.
396        op: OpKind,
397        /// The underlying kernel error returned by the registered kernel.
398        source: KernelError,
399    },
400
401    /// A graph-level failure occurred while preparing execution.
402    ///
403    /// This wraps errors originating from graph traversal or graph validation,
404    /// such as cycle detection or invalid node references.
405    GraphError(GraphError),
406
407    /// An internal executor invariant was violated.
408    ///
409    /// This variant indicates a bug or malformed internal state rather than a
410    /// user-facing validation issue. Under normal operation, callers should not
411    /// be able to trigger this error through the public API alone.
412    InternalError(&'static str),
413}
414
415impl std::fmt::Display for ExecutionError {
416    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
417        match self {
418            ExecutionError::MissingInput { node } => {
419                write!(f, "Missing input binding for node {node:?}")
420            }
421
422            ExecutionError::BindingToNonInputNode { node, op } => {
423                write!(f, "Cannot bind tensor to non-input node {node:?} ({op:?})")
424            }
425
426            ExecutionError::InvalidBindingNode { node } => {
427                write!(f, "Binding provided for invalid node {node:?}")
428            }
429
430            ExecutionError::DuplicateBinding { node } => {
431                write!(f, "Duplicate binding for node {node:?}")
432            }
433
434            ExecutionError::InputShapeMismatch {
435                node,
436                expected,
437                actual,
438            } => write!(
439                f,
440                "Shape mismatch for node {node:?}: expected {expected:?}, got {actual:?}",
441            ),
442
443            ExecutionError::KernelNotFound { op } => {
444                write!(f, "No kernel registered for operation {op:?}")
445            }
446
447            ExecutionError::KernelExecutionFailed { node, op, source } => write!(
448                f,
449                "Kernel execution failed at node {node:?} ({op:?}): {source}",
450            ),
451
452            ExecutionError::GraphError(e) => write!(f, "Graph error during execution: {e}"),
453
454            ExecutionError::InternalError(msg) => write!(f, "Executor internal error: {msg}"),
455        }
456    }
457}