Skip to main content

Backend

Trait Backend 

Source
pub trait Backend {
Show 27 methods
    // Required methods
    fn add(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>;
    fn sub(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>;
    fn mul(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>;
    fn relu(&self, input: &Tensor) -> Tensor;
    fn sigmoid(&self, input: &Tensor) -> Tensor;
    fn exp(&self, input: &Tensor) -> Tensor;
    fn tanh_act(&self, input: &Tensor) -> Tensor;
    fn softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>;
    fn log_softmax_last_dim(
        &self,
        input: &Tensor,
    ) -> Result<Tensor, KernelError>;
    fn logsumexp_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>;
    fn layer_norm_last_dim(
        &self,
        input: &Tensor,
        params: LayerNormLastDimParams<'_>,
    ) -> Result<Tensor, KernelError>;
    fn max_pool2d_nhwc(
        &self,
        input: &Tensor,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    fn avg_pool2d_nhwc(
        &self,
        input: &Tensor,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    fn conv2d_nhwc(
        &self,
        input: &Tensor,
        kernel: &Tensor,
        bias: Option<&Tensor>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    fn depthwise_conv2d_nhwc(
        &self,
        input: &Tensor,
        kernel: &Tensor,
        bias: Option<&Tensor>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    fn separable_conv2d_nhwc(
        &self,
        input: &Tensor,
        params: SeparableConv2dParams<'_>,
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError>;
    fn batch_norm2d_nhwc(
        &self,
        input: &Tensor,
        params: BatchNorm2dParams<'_>,
    ) -> Result<Tensor, KernelError>;
    fn group_norm_nhwc(
        &self,
        input: &Tensor,
        params: GroupNormNhwcParams<'_>,
    ) -> Result<Tensor, KernelError>;
    fn rms_norm_last_dim(
        &self,
        input: &Tensor,
        params: RmsNormLastDimParams<'_>,
    ) -> Result<Tensor, KernelError>;
    fn matmul_2d(
        &self,
        lhs: &Tensor,
        rhs: &Tensor,
    ) -> Result<Tensor, KernelError>;

    // Provided methods
    fn neg(&self, input: &Tensor) -> Tensor { ... }
    fn div(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError> { ... }
    fn sqrt(&self, input: &Tensor) -> Tensor { ... }
    fn transpose_2d(&self, input: &Tensor) -> Result<Tensor, KernelError> { ... }
    fn sum_all(&self, input: &Tensor) -> Tensor { ... }
    fn mul_scalar(&self, input: &Tensor, scalar: f32) -> Tensor { ... }
    fn reciprocal(&self, input: &Tensor) -> Tensor { ... }
}
Expand description

Runtime backend contract for core deterministic kernels.

Required Methods§

Source

fn add(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>

Source

fn sub(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>

Source

fn mul(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>

Source

fn relu(&self, input: &Tensor) -> Tensor

Source

fn sigmoid(&self, input: &Tensor) -> Tensor

Source

fn exp(&self, input: &Tensor) -> Tensor

Source

fn tanh_act(&self, input: &Tensor) -> Tensor

Source

fn softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>

Source

fn log_softmax_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>

Source

fn logsumexp_last_dim(&self, input: &Tensor) -> Result<Tensor, KernelError>

Source

fn layer_norm_last_dim(
    &self,
    input: &Tensor,
    params: LayerNormLastDimParams<'_>,
) -> Result<Tensor, KernelError>

Source

fn max_pool2d_nhwc(
    &self,
    input: &Tensor,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
) -> Result<Tensor, KernelError>

Source

fn avg_pool2d_nhwc(
    &self,
    input: &Tensor,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
) -> Result<Tensor, KernelError>

Source

fn conv2d_nhwc(
    &self,
    input: &Tensor,
    kernel: &Tensor,
    bias: Option<&Tensor>,
    stride_h: usize,
    stride_w: usize,
) -> Result<Tensor, KernelError>

Source

fn depthwise_conv2d_nhwc(
    &self,
    input: &Tensor,
    kernel: &Tensor,
    bias: Option<&Tensor>,
    stride_h: usize,
    stride_w: usize,
) -> Result<Tensor, KernelError>

Source

fn separable_conv2d_nhwc(
    &self,
    input: &Tensor,
    params: SeparableConv2dParams<'_>,
    stride_h: usize,
    stride_w: usize,
) -> Result<Tensor, KernelError>

Source

fn batch_norm2d_nhwc(
    &self,
    input: &Tensor,
    params: BatchNorm2dParams<'_>,
) -> Result<Tensor, KernelError>

Source

fn group_norm_nhwc(
    &self,
    input: &Tensor,
    params: GroupNormNhwcParams<'_>,
) -> Result<Tensor, KernelError>

Source

fn rms_norm_last_dim(
    &self,
    input: &Tensor,
    params: RmsNormLastDimParams<'_>,
) -> Result<Tensor, KernelError>

Source

fn matmul_2d(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>

Provided Methods§

Source

fn neg(&self, input: &Tensor) -> Tensor

Element-wise negation.

Source

fn div(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, KernelError>

Element-wise division with broadcasting.

Source

fn sqrt(&self, input: &Tensor) -> Tensor

Element-wise square root.

Source

fn transpose_2d(&self, input: &Tensor) -> Result<Tensor, KernelError>

Transpose a 2-D matrix.

Source

fn sum_all(&self, input: &Tensor) -> Tensor

Sum of all elements, returned as a scalar tensor.

Source

fn mul_scalar(&self, input: &Tensor, scalar: f32) -> Tensor

Multiply every element by a scalar.

Source

fn reciprocal(&self, input: &Tensor) -> Tensor

Element-wise reciprocal (1/x).

Implementors§