// tensor_forge/kernel.rs
1//! Defines runtime-executable compute kernels.
2//!
3//! A *kernel* is a pure compute primitive that reads one or more input [`Tensor`]s and writes
4//! results into a pre-allocated output [`Tensor`]. Kernels do not allocate output storage and
5//! are expected to validate basic argument invariants (arity, shape, rank) before any compute.
6//!
7//! To allocate [`Tensor`]s, see the [`crate::graph::Graph`] API.
8
9use crate::tensor::Tensor;
10use std::fmt;
11use std::iter::zip;
12
/// A runtime-executable compute primitive.
///
/// A kernel computes `output = f(inputs...)` for a particular operation. The caller is
/// responsible for allocating `output` with the correct shape; kernels never allocate
/// output storage themselves (see the module docs).
///
/// # Errors
///
/// Implementations return [`KernelError`] if:
/// - The number of `inputs` does not match the kernel contract,
/// - Shapes are incompatible for the operation, or
/// - The operation requires a specific rank (e.g., 2-D matrices for matmul) and the input
///   rank is unsupported.
///
/// # Examples
/// ```
/// # use tensor_forge::tensor::Tensor;
/// # use tensor_forge::kernel::{Kernel, AddKernel};
/// let shape = vec![2, 2];
/// let a = Tensor::from_vec(shape.clone(), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
/// let b = Tensor::from_vec(shape.clone(), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
///
/// let mut out = Tensor::zeros(shape).unwrap();
/// AddKernel.compute(&[&a, &b], &mut out).unwrap();
/// assert_eq!(out.data(), &[11.0, 22.0, 33.0, 44.0]);
/// ```
pub trait Kernel {
    /// Computes the kernel output in-place.
    ///
    /// `inputs` are read-only operands; the result is written into the pre-allocated
    /// `output` tensor.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError`] on invalid input arity, shape incompatibility, or unsupported rank.
    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError>;
}
46
47/// Validates argument invariants for 2-D matrix multiplication.
48///
49/// Shape rule (rank-2 only):
50/// - `left.shape = [m, n]`
51/// - `right.shape = [n, d]`
52/// - `output.shape` must be `[m, d]`
53///
54/// # Errors
55///
56/// Returns [`KernelError::InvalidRank`] if either input is not rank-2.
57///
58/// Returns [`KernelError::ShapeMismatch`] if inner dimensions do not match or `output` has the
59/// wrong shape.
60fn verify_matmul_arguments(
61 left: &Tensor,
62 right: &Tensor,
63 output: &Tensor,
64) -> Result<(), KernelError> {
65 // N-dimensional matrix multiplications are not supported right now.
66 if left.shape().len() != 2 || right.shape().len() != 2 {
67 return Err(KernelError::InvalidRank);
68 }
69
70 // Input/Output connections should already be verified in Graph construction, but good
71 // to sanity check here.
72 let exp_output_shape = vec![left.shape()[0], right.shape()[1]];
73 if left.shape()[1] != right.shape()[0] || output.shape() != exp_output_shape {
74 return Err(KernelError::ShapeMismatch);
75 }
76 Ok(())
77}
78
/// Naive 2-D matrix multiplication kernel.
///
/// Computes `output = left × right` for rank-2 matrices.
///
/// - `left` shape: `[m, n]`
/// - `right` shape: `[n, d]`
/// - `output` shape: `[m, d]`
///
/// Data is interpreted as row-major contiguous storage.
///
/// # Errors
///
/// Returns [`KernelError::InvalidArguments`] if `inputs.len() != 2`.
///
/// Returns [`KernelError::InvalidRank`] if either input is not rank-2.
///
/// Returns [`KernelError::ShapeMismatch`] if inner dimensions do not match or `output` has the
/// wrong shape.
///
/// # Examples
/// ```
/// # use tensor_forge::tensor::Tensor;
/// # use tensor_forge::kernel::{Kernel, MatMulKernel};
/// let a = Tensor::from_vec(vec![2, 3], vec![1.0, 2.0, 3.0,
///                                           4.0, 5.0, 6.0]).unwrap();
/// let b = Tensor::from_vec(vec![3, 2], vec![7.0, 8.0,
///                                           9.0, 10.0,
///                                           11.0, 12.0]).unwrap();
///
/// let mut out = Tensor::zeros(vec![2, 2]).unwrap();
/// MatMulKernel.compute(&[&a, &b], &mut out).unwrap();
/// assert_eq!(out.data(), &[58.0, 64.0, 139.0, 154.0]);
/// ```
// Stateless marker type: derive the standard traits so callers can freely copy and debug it.
#[derive(Debug, Clone, Copy)]
pub struct MatMulKernel;
113
114impl Kernel for MatMulKernel {
115 fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError> {
116 if inputs.len() != 2 {
117 return Err(KernelError::InvalidArguments);
118 }
119 let (left, right) = (inputs[0], inputs[1]);
120 verify_matmul_arguments(left, right, output)?;
121 // Perform a naive matrix multiplication for alpha testing purposes.
122 //
123 // For an (MxN) x (NxD) Matmul: we compute successive dot products with each row and
124 // column.
125 let n = left.shape()[1];
126 let d = right.shape()[1];
127 // Tensor data is stored in row-major order. Therefore, we can access a full column with an iterator stride
128 let l_data = left.data();
129 let r_data = right.data();
130 for (i, element) in output.data_mut().iter_mut().enumerate() {
131 let row_offset = (i / d) * n;
132 let col_offset = i % d;
133 let row = l_data.iter().skip(row_offset).take(n);
134 let col = r_data.iter().skip(col_offset).step_by(d).take(n);
135 *element = zip(row, col).fold(0.0, |acc, (r, c)| acc + r * c);
136 }
137 Ok(())
138 }
139}
140
/// Elementwise addition kernel.
///
/// Computes `output = left + right` (binary add).
///
/// All tensors must have identical shapes. Addition is performed elementwise in row-major order.
///
/// # Errors
///
/// Returns [`KernelError::InvalidArguments`] if `inputs.len() != 2`.
///
/// Returns [`KernelError::ShapeMismatch`] if `left`, `right`, and `output` do not all share the
/// same shape.
///
/// # Examples
/// ```
/// # use tensor_forge::tensor::Tensor;
/// # use tensor_forge::kernel::{Kernel, AddKernel};
/// let shape = vec![1, 4];
/// let a = Tensor::from_vec(shape.clone(), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
/// let b = Tensor::from_vec(shape.clone(), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
///
/// let mut out = Tensor::zeros(shape).unwrap();
/// AddKernel.compute(&[&a, &b], &mut out).unwrap();
/// assert_eq!(out.data(), &[11.0, 22.0, 33.0, 44.0]);
/// ```
// Stateless marker type: derive the standard traits so callers can freely copy and debug it.
#[derive(Debug, Clone, Copy)]
pub struct AddKernel;
167
168impl Kernel for AddKernel {
169 fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError> {
170 if inputs.len() != 2 {
171 return Err(KernelError::InvalidArguments);
172 }
173 // Input connections should already be verified in Graph construction, but good
174 // to sanity check here.
175 let (left, right) = (inputs[0], inputs[1]);
176 if left.shape() != right.shape() || output.shape() != right.shape() {
177 return Err(KernelError::ShapeMismatch);
178 }
179 zip(output.data_mut().iter_mut(), zip(left.data(), right.data()))
180 .for_each(|(out, (l, r))| *out = l + r);
181 Ok(())
182 }
183}
184
/// Rectified Linear Unit (`ReLU`) activation kernel.
///
/// Computes `output[i] = max(0, input[i])` elementwise.
///
/// NaN handling: if an input element is NaN, the corresponding output element is set to NaN.
///
/// # Errors
///
/// Returns [`KernelError::InvalidArguments`] if `inputs.len() != 1`.
///
/// Returns [`KernelError::ShapeMismatch`] if `input.shape() != output.shape()`.
///
/// # Examples
/// ```
/// # use tensor_forge::tensor::Tensor;
/// # use tensor_forge::kernel::{Kernel, ReluKernel};
/// let shape = vec![1, 5];
/// let x = Tensor::from_vec(shape.clone(), vec![-2.0, -0.0, 0.0, 1.5, 3.0]).unwrap();
///
/// let mut out = Tensor::zeros(shape).unwrap();
/// ReluKernel.compute(&[&x], &mut out).unwrap();
/// assert_eq!(out.data(), &[0.0, 0.0, 0.0, 1.5, 3.0]);
/// ```
// Stateless marker type: derive the standard traits so callers can freely copy and debug it.
#[derive(Debug, Clone, Copy)]
pub struct ReluKernel;
209
210impl Kernel for ReluKernel {
211 fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError> {
212 if inputs.len() != 1 {
213 return Err(KernelError::InvalidArguments);
214 }
215 let input = inputs[0];
216 // Input connections should already be verified in Graph construction, but good
217 // to sanity check here.
218 if input.shape() != output.shape() {
219 return Err(KernelError::ShapeMismatch);
220 }
221 for (output, &input) in zip(output.data_mut().iter_mut(), input.data().iter()) {
222 if input.is_nan() {
223 *output = input;
224 } else {
225 *output = input.max(0_f64);
226 }
227 }
228 Ok(())
229 }
230}
231
/// Kernel-level errors raised during argument validation or execution.
// PartialEq/Eq let callers and tests match on exact error values.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum KernelError {
    /// Wrong number of inputs were provided for the kernel.
    InvalidArguments,
    /// Input/output shapes are incompatible for the requested operation.
    ShapeMismatch,
    /// Input rank is unsupported for the requested operation (e.g., non-2-D matmul).
    InvalidRank,
}

impl fmt::Display for KernelError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            // Bug fix: this message previously described a tensor-construction error
            // ("shape must not contain a zero") instead of an arity error.
            KernelError::InvalidArguments => write!(
                f,
                "Invalid number of input tensors for the selected operation."
            ),
            KernelError::ShapeMismatch => {
                write!(
                    f,
                    "Input/Output tensors have mismatched data size for the selected operation."
                )
            }
            KernelError::InvalidRank => {
                write!(
                    f,
                    "Invalid/Unsupported matrix rank for supported operation."
                )
            }
        }
    }
}

// Standard error-trait integration (Debug + Display are provided above), so KernelError
// can be boxed/propagated as `dyn std::error::Error`.
impl std::error::Error for KernelError {}