pub trait Kernel {
    // Required method
    fn compute(
        &self,
        inputs: &[&Tensor],
        output: &mut Tensor,
    ) -> Result<(), KernelError>;
}
A runtime-executable compute primitive.
A kernel computes output = f(inputs...) for a particular operation. The caller is
responsible for allocating output with the correct shape.
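To make the contract concrete, here is a minimal sketch of implementing Kernel for an element-wise add. The Tensor and KernelError definitions are hypothetical stand-ins (row-major f32 storage; the variant names are assumptions), modeled only on the from_vec, zeros, and data calls used in the example on this page. The real crate's types may differ, and this AddKernel is an illustrative implementation, not necessarily the crate's. The kernel validates arity and shape, then writes into the output the caller already allocated.

```rust
/// Hypothetical stand-in tensor: row-major f32 data plus a shape.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    shape: Vec<usize>,
    data: Vec<f32>,
}

/// Hypothetical error type; the real KernelError's variants are not shown here.
#[derive(Debug, Clone, PartialEq)]
pub enum KernelError {
    /// Wrong number of inputs for this kernel.
    Arity { expected: usize, got: usize },
    /// Shapes do not fit together (or data length disagrees with shape).
    ShapeMismatch,
}

impl Tensor {
    /// Builds a tensor, rejecting data whose length disagrees with the shape.
    pub fn from_vec(shape: Vec<usize>, data: Vec<f32>) -> Result<Self, KernelError> {
        if shape.iter().product::<usize>() != data.len() {
            return Err(KernelError::ShapeMismatch);
        }
        Ok(Tensor { shape, data })
    }

    /// Allocates a zero-filled tensor of the given shape.
    pub fn zeros(shape: Vec<usize>) -> Result<Self, KernelError> {
        let len = shape.iter().product::<usize>();
        Ok(Tensor { shape, data: vec![0.0; len] })
    }

    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    pub fn data(&self) -> &[f32] {
        &self.data
    }

    pub fn data_mut(&mut self) -> &mut [f32] {
        &mut self.data
    }
}

pub trait Kernel {
    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError>;
}

/// Illustrative element-wise add: output[i] = a[i] + b[i].
pub struct AddKernel;

impl Kernel for AddKernel {
    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError> {
        // Contract: exactly two operands.
        let &[a, b] = inputs else {
            return Err(KernelError::Arity { expected: 2, got: inputs.len() });
        };
        // Inputs and the caller-allocated output must all share one shape.
        if a.shape() != b.shape() || a.shape() != output.shape() {
            return Err(KernelError::ShapeMismatch);
        }
        // Write each sum into the existing output buffer; no allocation here.
        for ((o, x), y) in output.data_mut().iter_mut().zip(a.data()).zip(b.data()) {
            *o = x + y;
        }
        Ok(())
    }
}
```

Keeping allocation on the caller's side lets the same output buffer be reused across repeated calls; the kernel only checks that the buffer it was handed has the right shape.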
§Errors
Implementations return KernelError if:
- The number of inputs does not match the kernel contract,
- Shapes are incompatible for the operation, or
- The operation requires a specific rank (e.g., 2-D matrices for matmul) and the input rank is unsupported.
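The three failure cases can be sketched with a hypothetical 2-D matrix multiply. MatMulKernel, the stand-in Tensor with public shape and data fields, and the KernelError variants (Arity, UnsupportedRank, ShapeMismatch) are all assumptions for illustration, not the crate's definitions. The checks run in the order listed: arity first, then rank, then shape compatibility, and nothing is written to output unless all of them pass.

```rust
/// Hypothetical stand-in tensor: row-major f32 data plus a shape.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub shape: Vec<usize>,
    pub data: Vec<f32>,
}

/// Hypothetical variants mirroring the three documented failure cases.
#[derive(Debug, Clone, PartialEq)]
pub enum KernelError {
    Arity { expected: usize, got: usize },
    UnsupportedRank { expected: usize, got: usize },
    ShapeMismatch { lhs: Vec<usize>, rhs: Vec<usize> },
}

pub trait Kernel {
    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError>;
}

/// Returns UnsupportedRank unless `t` is a 2-D matrix.
fn check_rank2(t: &Tensor) -> Result<(), KernelError> {
    match t.shape.len() {
        2 => Ok(()),
        got => Err(KernelError::UnsupportedRank { expected: 2, got }),
    }
}

/// Hypothetical matmul: (m x k) times (k x n) gives (m x n).
pub struct MatMulKernel;

impl Kernel for MatMulKernel {
    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError> {
        // 1. Arity: exactly two operands.
        let &[a, b] = inputs else {
            return Err(KernelError::Arity { expected: 2, got: inputs.len() });
        };
        // 2. Rank: both inputs and the output must be matrices.
        check_rank2(a)?;
        check_rank2(b)?;
        check_rank2(output)?;
        let (m, k) = (a.shape[0], a.shape[1]);
        let (k2, n) = (b.shape[0], b.shape[1]);
        // 3. Shapes: inner dimensions agree and the output is m x n.
        if k != k2 {
            return Err(KernelError::ShapeMismatch { lhs: a.shape.clone(), rhs: b.shape.clone() });
        }
        if output.shape != [m, n] {
            return Err(KernelError::ShapeMismatch { lhs: vec![m, n], rhs: output.shape.clone() });
        }
        // Naive triple loop over row-major storage.
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0;
                for p in 0..k {
                    acc += a.data[i * k + p] * b.data[p * n + j];
                }
                output.data[i * n + j] = acc;
            }
        }
        Ok(())
    }
}
```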
§Examples
// Two 2x2 inputs with matching shapes.
let shape = vec![2, 2];
let a = Tensor::from_vec(shape.clone(), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
let b = Tensor::from_vec(shape.clone(), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
// The caller allocates the output with the correct shape before calling compute.
let mut out = Tensor::zeros(shape).unwrap();
AddKernel.compute(&[&a, &b], &mut out).unwrap();
assert_eq!(out.data(), &[11.0, 22.0, 33.0, 44.0]);
Required Methods
fn compute(
    &self,
    inputs: &[&Tensor],
    output: &mut Tensor,
) -> Result<(), KernelError>
Computes the kernel result and writes it into the caller-allocated output tensor.
§Errors
Returns KernelError on invalid input arity, shape incompatibility, or unsupported rank.
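Because compute takes &self and has no generic parameters, Kernel is object safe, so one plausible way to treat kernels as runtime-selectable primitives is dynamic dispatch through &dyn Kernel or Box<dyn Kernel>. The run helper, MulKernel, ReluKernel, and the stand-in Tensor and KernelError types below are hypothetical. run plays the caller's role: it allocates the output, invokes whichever kernel it receives, and forwards any KernelError with ?.

```rust
/// Hypothetical stand-in tensor: row-major f32 data plus a shape.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub shape: Vec<usize>,
    pub data: Vec<f32>,
}

/// Hypothetical error variants, for illustration only.
#[derive(Debug, Clone, PartialEq)]
pub enum KernelError {
    Arity { expected: usize, got: usize },
    ShapeMismatch,
}

pub trait Kernel {
    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError>;
}

/// Hypothetical element-wise multiply (two inputs).
pub struct MulKernel;

impl Kernel for MulKernel {
    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError> {
        let &[a, b] = inputs else {
            return Err(KernelError::Arity { expected: 2, got: inputs.len() });
        };
        if a.shape != b.shape || a.shape != output.shape {
            return Err(KernelError::ShapeMismatch);
        }
        for ((o, &x), &y) in output.data.iter_mut().zip(&a.data).zip(&b.data) {
            *o = x * y;
        }
        Ok(())
    }
}

/// Hypothetical ReLU (one input): output[i] = max(x[i], 0).
pub struct ReluKernel;

impl Kernel for ReluKernel {
    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError> {
        let &[x] = inputs else {
            return Err(KernelError::Arity { expected: 1, got: inputs.len() });
        };
        if x.shape != output.shape {
            return Err(KernelError::ShapeMismatch);
        }
        for (o, &v) in output.data.iter_mut().zip(&x.data) {
            *o = v.max(0.0);
        }
        Ok(())
    }
}

/// Hypothetical caller-side helper: allocates an output of `out_shape`,
/// runs the kernel via dynamic dispatch, and propagates errors with `?`.
pub fn run(kernel: &dyn Kernel, inputs: &[&Tensor], out_shape: &[usize]) -> Result<Tensor, KernelError> {
    let len: usize = out_shape.iter().product();
    let mut output = Tensor { shape: out_shape.to_vec(), data: vec![0.0; len] };
    kernel.compute(inputs, &mut output)?;
    Ok(output)
}
```

For example, run(&ReluKernel, &[&x], &[3]) followed by run(&MulKernel, &[&relu, &y], &[3]) chains two kernels, while passing a single input to a Box<dyn Kernel> holding MulKernel surfaces an Arity error instead of panicking.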