tensor_forge/registry.rs
1//! Kernel registry for runtime dispatch of compute operations.
2//!
3//! A [`KernelRegistry`] maps graph-level operations ([`OpKind`]) to concrete
4//! compute implementations ([`Kernel`]). This allows the execution engine to
5//! resolve an operation at runtime without hardcoding kernel types.
6//!
7//! Kernels are stored as trait objects (`Box<dyn Kernel>`), enabling different
8//! kernel implementations to be registered and overridden dynamically.
9//!
10//! # Examples
11//!
12//! Register and retrieve a kernel:
13//!
14//! ```
15//! use tensor_forge::kernel::AddKernel;
16//! use tensor_forge::op::OpKind;
17//! use tensor_forge::registry::KernelRegistry;
18//! use tensor_forge::tensor::Tensor;
19//!
20//! let mut reg = KernelRegistry::new();
21//! assert!(reg.register(OpKind::Add, Box::new(AddKernel)).is_none());
22//!
23//! let kernel = reg.get(&OpKind::Add).unwrap();
24//!
25//! let shape = vec![1, 4];
26//! let a = Tensor::from_vec(shape.clone(), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
27//! let b = Tensor::from_vec(shape.clone(), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
28//! let mut out = Tensor::zeros(shape).unwrap();
29//!
30//! kernel.compute(&[&a, &b], &mut out).unwrap();
31//! assert_eq!(out.data(), &[11.0, 22.0, 33.0, 44.0]);
32//! ```
33//!
34//! Overwrite an existing kernel:
35//!
36//! ```
37//! use tensor_forge::kernel::{AddKernel, MatMulKernel};
38//! use tensor_forge::op::OpKind;
39//! use tensor_forge::registry::KernelRegistry;
40//!
41//! let mut reg = KernelRegistry::new();
42//!
43//! assert!(reg.register(OpKind::Add, Box::new(AddKernel)).is_none());
44//! assert!(reg.register(OpKind::Add, Box::new(MatMulKernel)).is_some());
45//! ```
46use std::collections::HashMap;
47
48use crate::kernel::{AddKernel, Kernel, MatMulKernel, ReluKernel};
49use crate::op::OpKind;
50
51/// Registry mapping [`OpKind`] to runtime-executable [`Kernel`] implementations.
52///
53/// This type is used by the execution engine to dispatch operations to the
54/// correct kernel at runtime.
55pub struct KernelRegistry {
56 kernels: HashMap<OpKind, Box<dyn Kernel>>,
57}
58
59impl Default for KernelRegistry {
60 /// Creates a registry populated with the built-in kernels.
61 ///
62 /// Registers:
63 /// - [`OpKind::Add`] → [`crate::kernel::AddKernel`]
64 /// - [`OpKind::MatMul`] → [`crate::kernel::MatMulKernel`]
65 /// - [`OpKind::ReLU`] → [`crate::kernel::ReluKernel`]
66 ///
67 /// # Examples
68 ///
69 /// ```
70 /// use tensor_forge::op::OpKind;
71 /// use tensor_forge::registry::KernelRegistry;
72 ///
73 /// let reg = KernelRegistry::default();
74 /// assert!(reg.get(&OpKind::Add).is_some());
75 /// assert!(reg.get(&OpKind::MatMul).is_some());
76 /// assert!(reg.get(&OpKind::ReLU).is_some());
77 /// ```
78 fn default() -> Self {
79 let mut registry = Self::new();
80 let _ = registry.register(OpKind::Add, Box::new(AddKernel));
81 let _ = registry.register(OpKind::MatMul, Box::new(MatMulKernel));
82 let _ = registry.register(OpKind::ReLU, Box::new(ReluKernel));
83 registry
84 }
85}
86
87impl KernelRegistry {
88 /// Creates an empty kernel registry.
89 ///
90 /// # Examples
91 ///
92 /// ```
93 /// use tensor_forge::registry::KernelRegistry;
94 /// use tensor_forge::op::OpKind;
95 ///
96 /// let reg = KernelRegistry::new();
97 ///
98 /// assert!(reg.get(&OpKind::Add).is_none());
99 /// assert!(reg.get(&OpKind::MatMul).is_none());
100 /// assert!(reg.get(&OpKind::ReLU).is_none());
101 /// ```
102 #[must_use]
103 pub fn new() -> Self {
104 Self {
105 kernels: HashMap::new(),
106 }
107 }
108
109 /// Registers `kernel` as the implementation for `op`.
110 ///
111 /// Returns the previously registered kernel if one existed.
112 ///
113 /// # Examples
114 ///
115 /// ```
116 /// use tensor_forge::kernel::{AddKernel, MatMulKernel};
117 /// use tensor_forge::op::OpKind;
118 /// use tensor_forge::registry::KernelRegistry;
119 ///
120 /// let mut reg = KernelRegistry::new();
121 ///
122 /// let old_kernel = reg.register(OpKind::Add, Box::new(AddKernel));
123 /// assert!(old_kernel.is_none());
124 /// assert!(reg.get(&OpKind::Add).is_some());
125 ///
126 /// // Add conflicting mapping
127 /// let old_kernel = reg.register(OpKind::Add, Box::new(MatMulKernel));
128 /// assert!(old_kernel.is_some()); // returns old AddKernel Mapping.
129 /// assert!(reg.get(&OpKind::Add).is_some()); // `OpKind::Add` now maps to MatMulKernel.
130 ///
131 /// ```
132 #[must_use]
133 pub fn register(&mut self, op: OpKind, kernel: Box<dyn Kernel>) -> Option<Box<dyn Kernel>> {
134 self.kernels.insert(op, kernel)
135 }
136
137 /// Returns the kernel registered for `op`, if present.
138 ///
139 /// # Examples
140 ///
141 /// ```
142 /// use tensor_forge::kernel::AddKernel;
143 /// use tensor_forge::op::OpKind;
144 /// use tensor_forge::registry::KernelRegistry;
145 ///
146 /// let mut reg = KernelRegistry::new();
147 /// reg.register(OpKind::Add, Box::new(AddKernel));
148 ///
149 /// assert!(reg.get(&OpKind::Add).is_some());
150 /// assert!(reg.get(&OpKind::MatMul).is_none());
151 /// ```
152 #[must_use]
153 pub fn get(&self, op: &OpKind) -> Option<&dyn Kernel> {
154 self.kernels.get(op).map(Box::as_ref)
155 }
156}