Trait BackwardOps 

pub trait BackwardOps: Backend {
    // Provided methods
    fn relu_backward(
        &self,
        upstream: &Tensor,
        forward_input: &Tensor,
    ) -> Result<Tensor, KernelError> { ... }
    fn sigmoid_backward(
        &self,
        upstream: &Tensor,
        forward_output: &Tensor,
    ) -> Result<Tensor, KernelError> { ... }
    fn tanh_backward(
        &self,
        upstream: &Tensor,
        forward_output: &Tensor,
    ) -> Result<Tensor, KernelError> { ... }
    fn exp_backward(
        &self,
        upstream: &Tensor,
        forward_output: &Tensor,
    ) -> Result<Tensor, KernelError> { ... }
    fn reduce_sum_backward(
        &self,
        upstream: &Tensor,
        original_shape: &[usize],
    ) -> Result<Tensor, KernelError> { ... }
    fn matmul_backward(
        &self,
        upstream: &Tensor,
        lhs: &Tensor,
        rhs: &Tensor,
    ) -> Result<(Tensor, Tensor), KernelError> { ... }
    fn add_backward(
        &self,
        upstream: &Tensor,
        _lhs: &Tensor,
        _rhs: &Tensor,
    ) -> Result<(Tensor, Tensor), KernelError> { ... }
    fn sub_backward(
        &self,
        upstream: &Tensor,
        _lhs: &Tensor,
        _rhs: &Tensor,
    ) -> Result<(Tensor, Tensor), KernelError> { ... }
    fn mul_backward(
        &self,
        upstream: &Tensor,
        lhs: &Tensor,
        rhs: &Tensor,
    ) -> Result<(Tensor, Tensor), KernelError> { ... }
    fn conv2d_input_backward(
        &self,
        upstream: &Tensor,
        kernel: &Tensor,
        input_shape: &[usize],
        stride_h: usize,
        stride_w: usize,
    ) -> Result<Tensor, KernelError> { ... }
}

Extension trait for backward-pass operations.

Separated from Backend so that forward-only consumers (e.g. ONNX inference) need not depend on backward-related method signatures. All methods have default CPU implementations, so impl BackwardOps for MyBackend {} is sufficient.
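
A minimal adoption sketch, grounded in the statement above (assumes MyBackend already implements Backend; that impl is elided here):

struct MyBackend;

// MyBackend's `impl Backend` elided; no method bodies are needed for BackwardOps.
impl BackwardOps for MyBackend {}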

Provided Methods

fn relu_backward( &self, upstream: &Tensor, forward_input: &Tensor, ) -> Result<Tensor, KernelError>

ReLU backward: grad_input[i] = upstream[i] * (forward_input[i] > 0 ? 1 : 0).
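
A reference sketch of the element-wise rule over plain f32 slices (illustrative only; not the crate's Tensor API):

fn relu_backward_ref(upstream: &[f32], forward_input: &[f32]) -> Vec<f32> {
    // Gradient passes through where the forward input was positive, else zero.
    upstream
        .iter()
        .zip(forward_input)
        .map(|(&u, &x)| if x > 0.0 { u } else { 0.0 })
        .collect()
}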

fn sigmoid_backward( &self, upstream: &Tensor, forward_output: &Tensor, ) -> Result<Tensor, KernelError>

Sigmoid backward: grad_input[i] = upstream[i] * s[i] * (1 - s[i]) where s = forward output.
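
The same rule over plain slices (illustrative; note it consumes the saved forward output, so exp is never recomputed):

fn sigmoid_backward_ref(upstream: &[f32], forward_output: &[f32]) -> Vec<f32> {
    // sigma'(x) = sigma(x) * (1 - sigma(x)), with sigma(x) = saved forward output s.
    upstream
        .iter()
        .zip(forward_output)
        .map(|(&u, &s)| u * s * (1.0 - s))
        .collect()
}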

fn tanh_backward( &self, upstream: &Tensor, forward_output: &Tensor, ) -> Result<Tensor, KernelError>

Tanh backward: grad_input[i] = upstream[i] * (1 - t[i]^2) where t = forward output.
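
An illustrative slice-level sketch (not the crate's Tensor API):

fn tanh_backward_ref(upstream: &[f32], forward_output: &[f32]) -> Vec<f32> {
    // tanh'(x) = 1 - tanh(x)^2, with tanh(x) = saved forward output t.
    upstream
        .iter()
        .zip(forward_output)
        .map(|(&u, &t)| u * (1.0 - t * t))
        .collect()
}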

fn exp_backward( &self, upstream: &Tensor, forward_output: &Tensor, ) -> Result<Tensor, KernelError>

Exp backward: grad_input[i] = upstream[i] * e[i] where e = forward output.
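
An illustrative slice-level sketch (not the crate's Tensor API):

fn exp_backward_ref(upstream: &[f32], forward_output: &[f32]) -> Vec<f32> {
    // d/dx exp(x) = exp(x), which is exactly the saved forward output e.
    upstream.iter().zip(forward_output).map(|(&u, &e)| u * e).collect()
}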

fn reduce_sum_backward( &self, upstream: &Tensor, original_shape: &[usize], ) -> Result<Tensor, KernelError>

Reduce-sum backward: broadcast scalar gradient to all elements of original_shape.
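
A minimal sketch, assuming the forward pass reduced to a single scalar as the "scalar gradient" wording implies (illustrative only):

fn reduce_sum_backward_ref(upstream: f32, original_shape: &[usize]) -> Vec<f32> {
    // Every input element contributed to the sum with weight 1,
    // so each receives the same scalar gradient.
    let len: usize = original_shape.iter().product();
    vec![upstream; len]
}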

fn matmul_backward( &self, upstream: &Tensor, lhs: &Tensor, rhs: &Tensor, ) -> Result<(Tensor, Tensor), KernelError>

MatMul backward: grad_lhs = upstream @ rhs^T, grad_rhs = lhs^T @ upstream.
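
A naive row-major reference computing both products in one pass (illustrative; dimensions are explicit parameters here, whereas the trait method reads them from the tensors):

fn matmul_backward_ref(
    upstream: &[f32], // m x n
    lhs: &[f32],      // m x k
    rhs: &[f32],      // k x n
    m: usize,
    k: usize,
    n: usize,
) -> (Vec<f32>, Vec<f32>) {
    let mut grad_lhs = vec![0.0f32; m * k];
    let mut grad_rhs = vec![0.0f32; k * n];
    for i in 0..m {
        for j in 0..n {
            let u = upstream[i * n + j];
            for p in 0..k {
                grad_lhs[i * k + p] += u * rhs[p * n + j]; // upstream @ rhs^T
                grad_rhs[p * n + j] += lhs[i * k + p] * u; // lhs^T @ upstream
            }
        }
    }
    (grad_lhs, grad_rhs)
}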

fn add_backward( &self, upstream: &Tensor, _lhs: &Tensor, _rhs: &Tensor, ) -> Result<(Tensor, Tensor), KernelError>

Add backward: gradient passes through unchanged to both operands.
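
Equivalently, over plain slices (the leading underscores in the signature mark the operands as unused):

fn add_backward_ref(upstream: &[f32]) -> (Vec<f32>, Vec<f32>) {
    // d(a + b)/da = d(a + b)/db = 1, so both sides receive upstream unchanged.
    (upstream.to_vec(), upstream.to_vec())
}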

fn sub_backward( &self, upstream: &Tensor, _lhs: &Tensor, _rhs: &Tensor, ) -> Result<(Tensor, Tensor), KernelError>

Sub backward: grad_lhs = upstream, grad_rhs = -upstream.
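
Equivalently, over plain slices (illustrative only):

fn sub_backward_ref(upstream: &[f32]) -> (Vec<f32>, Vec<f32>) {
    // d(a - b)/da = 1, d(a - b)/db = -1.
    (upstream.to_vec(), upstream.iter().map(|&u| -u).collect())
}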

fn mul_backward( &self, upstream: &Tensor, lhs: &Tensor, rhs: &Tensor, ) -> Result<(Tensor, Tensor), KernelError>

Mul backward: grad_lhs = upstream * rhs, grad_rhs = upstream * lhs.
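
An illustrative slice-level sketch of the product rule (not the crate's Tensor API):

fn mul_backward_ref(upstream: &[f32], lhs: &[f32], rhs: &[f32]) -> (Vec<f32>, Vec<f32>) {
    // Each operand's gradient is upstream scaled by the *other* operand.
    let grad_lhs: Vec<f32> = upstream.iter().zip(rhs).map(|(&u, &r)| u * r).collect();
    let grad_rhs: Vec<f32> = upstream.iter().zip(lhs).map(|(&u, &l)| u * l).collect();
    (grad_lhs, grad_rhs)
}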

fn conv2d_input_backward( &self, upstream: &Tensor, kernel: &Tensor, input_shape: &[usize], stride_h: usize, stride_w: usize, ) -> Result<Tensor, KernelError>

Conv2d backward (input gradient): compute dL/dInput from dL/dOutput and weights.

Default CPU implementation via full convolution with flipped kernel.
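
A single-channel, no-padding sketch in the equivalent scatter formulation (for stride 1 this coincides with the full convolution with a flipped kernel; layouts and parameters here are illustrative assumptions, not the crate's Tensor API):

fn conv2d_input_backward_ref(
    upstream: &[f32], // out_h x out_w, row-major
    kernel: &[f32],   // k_h x k_w, row-major
    in_h: usize,
    in_w: usize,
    k_h: usize,
    k_w: usize,
    stride_h: usize,
    stride_w: usize,
) -> Vec<f32> {
    let out_h = (in_h - k_h) / stride_h + 1;
    let out_w = (in_w - k_w) / stride_w + 1;
    let mut grad = vec![0.0f32; in_h * in_w];
    for oh in 0..out_h {
        for ow in 0..out_w {
            let u = upstream[oh * out_w + ow];
            for kh in 0..k_h {
                for kw in 0..k_w {
                    // Scatter upstream back to every input pixel this output read,
                    // weighted by the corresponding kernel element.
                    let iy = oh * stride_h + kh;
                    let ix = ow * stride_w + kw;
                    grad[iy * in_w + ix] += u * kernel[kh * k_w + kw];
                }
            }
        }
    }
    grad
}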

Implementors