pub trait Backend {
// 27 methods in total
// Required methods
fn add(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>;
fn sub(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>;
fn mul(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>;
fn relu(&self, input: &Tensor) -> Tensor;
fn sigmoid(&self, input: &Tensor) -> Tensor;
fn exp(&self, input: &Tensor) -> Tensor;
fn tanh_act(&self, input: &Tensor) -> Tensor;
fn softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>;
fn log_softmax_last_dim(
&self,
input: &Tensor,
) -> Result<Tensor, KernelError>;
fn logsumexp_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>;
fn layer_norm_last_dim(
&self,
input: &Tensor,
params: LayerNormLastDimParams<'_>,
) -> Result<Tensor, KernelError>;
fn max_pool2d_nhwc(
&self,
input: &Tensor,
kernel_h: usize,
kernel_w: usize,
stride_h: usize,
stride_w: usize,
) -> Result<Tensor, KernelError>;
fn avg_pool2d_nhwc(
&self,
input: &Tensor,
kernel_h: usize,
kernel_w: usize,
stride_h: usize,
stride_w: usize,
) -> Result<Tensor, KernelError>;
fn conv2d_nhwc(
&self,
input: &Tensor,
kernel: &Tensor,
bias: Option<&Tensor>,
stride_h: usize,
stride_w: usize,
) -> Result<Tensor, KernelError>;
fn depthwise_conv2d_nhwc(
&self,
input: &Tensor,
kernel: &Tensor,
bias: Option<&Tensor>,
stride_h: usize,
stride_w: usize,
) -> Result<Tensor, KernelError>;
fn separable_conv2d_nhwc(
&self,
input: &Tensor,
params: SeparableConv2dParams<'_>,
stride_h: usize,
stride_w: usize,
) -> Result<Tensor, KernelError>;
fn batch_norm2d_nhwc(
&self,
input: &Tensor,
params: BatchNorm2dParams<'_>,
) -> Result<Tensor, KernelError>;
fn group_norm_nhwc(
&self,
input: &Tensor,
params: GroupNormNhwcParams<'_>,
) -> Result<Tensor, KernelError>;
fn rms_norm_last_dim(
&self,
input: &Tensor,
params: RmsNormLastDimParams<'_>,
) -> Result<Tensor, KernelError>;
fn matmul_2d(
&self,
lhs: &Tensor,
rhs: &Tensor,
) -> Result<Tensor, KernelError>;
// Provided methods
fn neg(&self, input: &Tensor) -> Tensor { ... }
fn div(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> { ... }
fn sqrt(&self, input: &Tensor) -> Tensor { ... }
fn transpose_2d(&self, input: &Tensor) -> Result<Tensor, KernelError> { ... }
fn sum_all(&self, input: &Tensor) -> Tensor { ... }
fn mul_scalar(&self, input: &Tensor, scalar: f32) -> Tensor { ... }
fn reciprocal(&self, input: &Tensor) -> Tensor { ... }
}

Runtime backend contract for core deterministic kernels.
Required Methods
fn add(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>
fn sub(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>
fn mul(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>
fn relu(&self, input: &Tensor) -> Tensor
fn sigmoid(&self, input: &Tensor) -> Tensor
fn exp(&self, input: &Tensor) -> Tensor
fn tanh_act(&self, input: &Tensor) -> Tensor
fn softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>
fn log_softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>
fn logsumexp_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>
fn layer_norm_last_dim(&self, input: &Tensor, params: LayerNormLastDimParams<'_>) -> Result<Tensor, KernelError>
fn max_pool2d_nhwc(&self, input: &Tensor, kernel_h: usize, kernel_w: usize, stride_h: usize, stride_w: usize) -> Result<Tensor, KernelError>
fn avg_pool2d_nhwc(&self, input: &Tensor, kernel_h: usize, kernel_w: usize, stride_h: usize, stride_w: usize) -> Result<Tensor, KernelError>
fn conv2d_nhwc(&self, input: &Tensor, kernel: &Tensor, bias: Option<&Tensor>, stride_h: usize, stride_w: usize) -> Result<Tensor, KernelError>
fn depthwise_conv2d_nhwc(&self, input: &Tensor, kernel: &Tensor, bias: Option<&Tensor>, stride_h: usize, stride_w: usize) -> Result<Tensor, KernelError>
fn separable_conv2d_nhwc(&self, input: &Tensor, params: SeparableConv2dParams<'_>, stride_h: usize, stride_w: usize) -> Result<Tensor, KernelError>
fn batch_norm2d_nhwc(&self, input: &Tensor, params: BatchNorm2dParams<'_>) -> Result<Tensor, KernelError>
fn group_norm_nhwc(&self, input: &Tensor, params: GroupNormNhwcParams<'_>) -> Result<Tensor, KernelError>
fn rms_norm_last_dim(&self, input: &Tensor, params: RmsNormLastDimParams<'_>) -> Result<Tensor, KernelError>
fn matmul_2d(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>
Provided Methods
fn div(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>
Element-wise division with broadcast.
fn transpose_2d(&self, input: &Tensor) -> Result<Tensor, KernelError>
Transpose a 2-D matrix.
fn sum_all(&self, input: &Tensor) -> Tensor
Scalar sum of all elements (returns a scalar tensor).
fn mul_scalar(&self, input: &Tensor, scalar: f32) -> Tensor
Multiply every element by a scalar.
fn reciprocal(&self, input: &Tensor) -> Tensor
Element-wise reciprocal (1/x).