Trait Kernel

pub trait Kernel {
    // Required method
    fn compute(
        &self,
        inputs: &[&Tensor],
        output: &mut Tensor,
    ) -> Result<(), KernelError>;
}

A runtime-executable compute primitive.

A kernel computes output = f(inputs...) for a particular operation. The caller is responsible for allocating output with the correct shape.
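As a minimal sketch of this contract, the program below implements the trait for a hypothetical ReLU kernel (output[i] = max(input[i], 0)). The Tensor and KernelError definitions are stand-ins invented only so the example compiles on its own. The crate's real types, constructors, and error variants are not shown on this page and may differ. The point is the division of labour: the caller allocates output, and the kernel validates it and fills it.

```rust
// Hypothetical stand-ins: the crate's real Tensor and KernelError are not
// shown on this page, so these minimal versions only make the sketch compile.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub shape: Vec<usize>,
    pub data: Vec<f32>,
}

impl Tensor {
    /// Allocates a zero-filled tensor with the given shape.
    pub fn zeros(shape: Vec<usize>) -> Tensor {
        let len: usize = shape.iter().product();
        Tensor { shape, data: vec![0.0; len] }
    }
}

#[derive(Debug, PartialEq)]
pub enum KernelError {
    Arity { expected: usize, got: usize },
    ShapeMismatch,
}

// Same signature as the trait documented on this page.
pub trait Kernel {
    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError>;
}

/// Hypothetical unary kernel: output[i] = max(input[i], 0).
pub struct ReluKernel;

impl Kernel for ReluKernel {
    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError> {
        // This kernel's contract is exactly one input.
        let [x] = inputs else {
            return Err(KernelError::Arity { expected: 1, got: inputs.len() });
        };
        // The caller allocated `output`; the kernel only verifies its shape.
        if x.shape != output.shape {
            return Err(KernelError::ShapeMismatch);
        }
        // f is applied element by element and written into the caller's buffer.
        for (o, &v) in output.data.iter_mut().zip(&x.data) {
            *o = v.max(0.0);
        }
        Ok(())
    }
}

fn main() {
    let x = Tensor { shape: vec![4], data: vec![-1.0, 2.0, -3.0, 4.0] };
    // Caller-side allocation, as the contract requires.
    let mut out = Tensor::zeros(x.shape.clone());
    ReluKernel.compute(&[&x], &mut out).unwrap();
    assert_eq!(out.data, vec![0.0, 2.0, 0.0, 4.0]);
}
```

Because the kernel never allocates, the same output buffer can be reused across calls as long as its shape still matches.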

§Errors

Implementations return KernelError if:

  • The number of inputs does not match the kernel contract,
  • Shapes are incompatible for the operation, or
  • The operation requires a specific rank (e.g., 2-D matrices for matmul) and the input rank is unsupported.
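The three documented failure cases map naturally onto checks at the top of compute. The sketch below is a hypothetical 2-D matrix multiplication kernel that checks arity, then rank, then shape compatibility before doing any arithmetic. The error variants (Arity, UnsupportedRank, ShapeMismatch), the row-major layout, and the Tensor stand-in are assumptions for illustration, not the crate's definitions.

```rust
// Hypothetical stand-ins for the crate's Tensor and KernelError.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub shape: Vec<usize>,
    pub data: Vec<f32>,
}

impl Tensor {
    /// Builds a tensor, panicking if data does not fill the shape exactly.
    pub fn new(shape: Vec<usize>, data: Vec<f32>) -> Tensor {
        assert_eq!(shape.iter().product::<usize>(), data.len());
        Tensor { shape, data }
    }
}

/// One (assumed) variant per documented failure case.
#[derive(Debug, PartialEq)]
pub enum KernelError {
    Arity { expected: usize, got: usize },
    ShapeMismatch { lhs: Vec<usize>, rhs: Vec<usize> },
    UnsupportedRank { rank: usize },
}

pub trait Kernel {
    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError>;
}

/// Hypothetical matmul kernel: (m x k) * (k x n) -> (m x n), row-major data.
pub struct MatMulKernel;

impl Kernel for MatMulKernel {
    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError> {
        // 1. Arity: exactly two operands.
        let [a, b] = inputs else {
            return Err(KernelError::Arity { expected: 2, got: inputs.len() });
        };
        // 2. Rank: only 2-D matrices are supported.
        for t in [a, b] {
            if t.shape.len() != 2 {
                return Err(KernelError::UnsupportedRank { rank: t.shape.len() });
            }
        }
        let (m, k, n) = (a.shape[0], a.shape[1], b.shape[1]);
        // 3. Shapes: inner dimensions must agree, and output must be m x n.
        if b.shape[0] != k {
            return Err(KernelError::ShapeMismatch { lhs: a.shape.clone(), rhs: b.shape.clone() });
        }
        if output.shape != [m, n] {
            return Err(KernelError::ShapeMismatch { lhs: vec![m, n], rhs: output.shape.clone() });
        }
        // Naive triple loop; only reached once every check has passed.
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0f32;
                for p in 0..k {
                    acc += a.data[i * k + p] * b.data[p * n + j];
                }
                output.data[i * n + j] = acc;
            }
        }
        Ok(())
    }
}

fn main() {
    let v = Tensor::new(vec![3], vec![1.0, 2.0, 3.0]);
    let mut out = Tensor::new(vec![1, 1], vec![0.0]);
    // A 1-D operand is rejected before any arithmetic runs.
    let err = MatMulKernel.compute(&[&v, &v], &mut out).unwrap_err();
    println!("{:?}", err);
}
```

Validating everything before the loop means a failed call leaves output untouched, which keeps error handling simple for the caller.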

§Examples

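// Two 2x2 input tensors with the same shape.
let shape = vec![2, 2];
let a = Tensor::from_vec(shape.clone(), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
let b = Tensor::from_vec(shape.clone(), vec![10.0, 20.0, 30.0, 40.0]).unwrap();

// The caller allocates the output with the correct shape; the kernel fills it.
let mut out = Tensor::zeros(shape).unwrap();
AddKernel.compute(&[&a, &b], &mut out).unwrap();
// AddKernel adds its two inputs element-wise.
assert_eq!(out.data(), &[11.0, 22.0, 33.0, 44.0]);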

Required Methods§

fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError>

Computes the kernel's result and writes it into output, which the caller has already allocated.

§Errors

Returns KernelError on invalid input arity, shape incompatibility, or unsupported rank.
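Because compute takes &self and has no generic parameters, Kernel can be used as a trait object, so a caller can pick a kernel at runtime and propagate its KernelError with the ? operator. The registry, the run helper, the RunError type, and the AddKernel body below are all hypothetical, written against stand-in Tensor and KernelError types rather than the crate's own.

```rust
use std::collections::HashMap;

// Hypothetical stand-ins for the crate's Tensor and KernelError.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub shape: Vec<usize>,
    pub data: Vec<f32>,
}

#[derive(Debug, PartialEq)]
pub enum KernelError {
    Arity { expected: usize, got: usize },
    ShapeMismatch,
}

pub trait Kernel {
    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError>;
}

/// Hypothetical element-wise add, written here for illustration only.
pub struct AddKernel;

impl Kernel for AddKernel {
    fn compute(&self, inputs: &[&Tensor], output: &mut Tensor) -> Result<(), KernelError> {
        let [a, b] = inputs else {
            return Err(KernelError::Arity { expected: 2, got: inputs.len() });
        };
        if a.shape != b.shape || a.shape != output.shape {
            return Err(KernelError::ShapeMismatch);
        }
        for ((o, x), y) in output.data.iter_mut().zip(&a.data).zip(&b.data) {
            *o = x + y;
        }
        Ok(())
    }
}

/// Errors from the hypothetical `run` helper: an unknown op name, or a kernel failure.
#[derive(Debug, PartialEq)]
pub enum RunError {
    UnknownOp(String),
    Kernel(KernelError),
}

// Lets `?` convert a KernelError into a RunError.
impl From<KernelError> for RunError {
    fn from(e: KernelError) -> Self {
        RunError::Kernel(e)
    }
}

/// Picks a kernel by name at runtime, allocates the output, and runs it.
pub fn run(
    registry: &HashMap<&str, Box<dyn Kernel>>,
    op: &str,
    inputs: &[&Tensor],
    out_shape: Vec<usize>,
) -> Result<Tensor, RunError> {
    let kernel = registry
        .get(op)
        .ok_or_else(|| RunError::UnknownOp(op.to_string()))?;
    // The caller (here, `run`) owns allocation; the kernel only fills the buffer.
    let len: usize = out_shape.iter().product();
    let mut output = Tensor { shape: out_shape, data: vec![0.0; len] };
    kernel.compute(inputs, &mut output)?;
    Ok(output)
}

fn main() {
    let mut registry: HashMap<&str, Box<dyn Kernel>> = HashMap::new();
    registry.insert("add", Box::new(AddKernel));

    let a = Tensor { shape: vec![2], data: vec![1.0, 2.0] };
    let b = Tensor { shape: vec![2], data: vec![10.0, 20.0] };
    match run(&registry, "add", &[&a, &b], vec![2]) {
        Ok(t) => println!("{:?}", t.data),
        Err(e) => eprintln!("kernel failed: {:?}", e),
    }
}
```

Wrapping KernelError in a caller-level error type keeps the kernel's error contract narrow while still letting higher layers report lookup failures.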

Implementors§