Skip to main content

tensor_forge/
registry.rs

1//! Kernel registry for runtime dispatch of compute operations.
2//!
3//! A [`KernelRegistry`] maps graph-level operations ([`OpKind`]) to concrete
4//! compute implementations ([`Kernel`]). This allows the execution engine to
5//! resolve an operation at runtime without hardcoding kernel types.
6//!
7//! Kernels are stored as trait objects (`Box<dyn Kernel>`), enabling different
8//! kernel implementations to be registered and overridden dynamically.
9//!
10//! # Examples
11//!
12//! Register and retrieve a kernel:
13//!
14//! ```
15//! use tensor_forge::kernel::AddKernel;
16//! use tensor_forge::op::OpKind;
17//! use tensor_forge::registry::KernelRegistry;
18//! use tensor_forge::tensor::Tensor;
19//!
20//! let mut reg = KernelRegistry::new();
21//! assert!(reg.register(OpKind::Add, Box::new(AddKernel)).is_none());
22//!
23//! let kernel = reg.get(&OpKind::Add).unwrap();
24//!
25//! let shape = vec![1, 4];
26//! let a = Tensor::from_vec(shape.clone(), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
27//! let b = Tensor::from_vec(shape.clone(), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
28//! let mut out = Tensor::zeros(shape).unwrap();
29//!
30//! kernel.compute(&[&a, &b], &mut out).unwrap();
31//! assert_eq!(out.data(), &[11.0, 22.0, 33.0, 44.0]);
32//! ```
33//!
34//! Overwrite an existing kernel:
35//!
36//! ```
37//! use tensor_forge::kernel::{AddKernel, MatMulKernel};
38//! use tensor_forge::op::OpKind;
39//! use tensor_forge::registry::KernelRegistry;
40//!
41//! let mut reg = KernelRegistry::new();
42//!
43//! assert!(reg.register(OpKind::Add, Box::new(AddKernel)).is_none());
44//! assert!(reg.register(OpKind::Add, Box::new(MatMulKernel)).is_some());
45//! ```
46use std::collections::HashMap;
47
48use crate::kernel::{AddKernel, Kernel, MatMulKernel, ReluKernel};
49use crate::op::OpKind;
50
51/// Registry mapping [`OpKind`] to runtime-executable [`Kernel`] implementations.
52///
53/// This type is used by the execution engine to dispatch operations to the
54/// correct kernel at runtime.
55pub struct KernelRegistry {
56    kernels: HashMap<OpKind, Box<dyn Kernel>>,
57}
58
59impl Default for KernelRegistry {
60    /// Creates a registry populated with the built-in kernels.
61    ///
62    /// Registers:
63    /// - [`OpKind::Add`] → [`crate::kernel::AddKernel`]
64    /// - [`OpKind::MatMul`] → [`crate::kernel::MatMulKernel`]
65    /// - [`OpKind::ReLU`] → [`crate::kernel::ReluKernel`]
66    ///
67    /// # Examples
68    ///
69    /// ```
70    /// use tensor_forge::op::OpKind;
71    /// use tensor_forge::registry::KernelRegistry;
72    ///
73    /// let reg = KernelRegistry::default();
74    /// assert!(reg.get(&OpKind::Add).is_some());
75    /// assert!(reg.get(&OpKind::MatMul).is_some());
76    /// assert!(reg.get(&OpKind::ReLU).is_some());
77    /// ```
78    fn default() -> Self {
79        let mut registry = Self::new();
80        let _ = registry.register(OpKind::Add, Box::new(AddKernel));
81        let _ = registry.register(OpKind::MatMul, Box::new(MatMulKernel));
82        let _ = registry.register(OpKind::ReLU, Box::new(ReluKernel));
83        registry
84    }
85}
86
87impl KernelRegistry {
88    /// Creates an empty kernel registry.
89    ///
90    /// # Examples
91    ///
92    /// ```
93    /// use tensor_forge::registry::KernelRegistry;
94    /// use tensor_forge::op::OpKind;
95    ///
96    /// let reg = KernelRegistry::new();
97    ///
98    /// assert!(reg.get(&OpKind::Add).is_none());
99    /// assert!(reg.get(&OpKind::MatMul).is_none());
100    /// assert!(reg.get(&OpKind::ReLU).is_none());
101    /// ```
102    #[must_use]
103    pub fn new() -> Self {
104        Self {
105            kernels: HashMap::new(),
106        }
107    }
108
109    /// Registers `kernel` as the implementation for `op`.
110    ///
111    /// Returns the previously registered kernel if one existed.
112    ///
113    /// # Examples
114    ///
115    /// ```
116    /// use tensor_forge::kernel::{AddKernel, MatMulKernel};
117    /// use tensor_forge::op::OpKind;
118    /// use tensor_forge::registry::KernelRegistry;
119    ///
120    /// let mut reg = KernelRegistry::new();
121    ///
122    /// let old_kernel = reg.register(OpKind::Add, Box::new(AddKernel));
123    /// assert!(old_kernel.is_none());
124    /// assert!(reg.get(&OpKind::Add).is_some());
125    ///
126    /// // Add conflicting mapping
127    /// let old_kernel = reg.register(OpKind::Add, Box::new(MatMulKernel));
128    /// assert!(old_kernel.is_some());             // returns old AddKernel Mapping.
129    /// assert!(reg.get(&OpKind::Add).is_some());  // `OpKind::Add` now maps to MatMulKernel.
130    ///
131    /// ```
132    #[must_use]
133    pub fn register(&mut self, op: OpKind, kernel: Box<dyn Kernel>) -> Option<Box<dyn Kernel>> {
134        self.kernels.insert(op, kernel)
135    }
136
137    /// Returns the kernel registered for `op`, if present.
138    ///
139    /// # Examples
140    ///
141    /// ```
142    /// use tensor_forge::kernel::AddKernel;
143    /// use tensor_forge::op::OpKind;
144    /// use tensor_forge::registry::KernelRegistry;
145    ///
146    /// let mut reg = KernelRegistry::new();
147    /// reg.register(OpKind::Add, Box::new(AddKernel));
148    ///
149    /// assert!(reg.get(&OpKind::Add).is_some());
150    /// assert!(reg.get(&OpKind::MatMul).is_none());
151    /// ```
152    #[must_use]
153    pub fn get(&self, op: &OpKind) -> Option<&dyn Kernel> {
154        self.kernels.get(op).map(Box::as_ref)
155    }
156}