Trait ForwardBackward 

pub trait ForwardBackward {
    // Required methods
    fn forward(
        &self,
        ctx: &mut Context,
        inputs: Vec<ArrayView1<'_, f64>>,
    ) -> f64;
    fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>);
}

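This trait defines the interface for differentiable operations: a forward pass that computes a scalar result from one or more one-dimensional input arrays, and a backward pass that propagates the gradient of the loss back through the operation.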
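The following is a minimal sketch of how an operation might implement this interface, using a dot product y = a · b of two vectors. It is not taken from the crate. To stay self-contained it makes three assumptions: a hypothetical Context with `saved` and `grads` fields (the real Context's fields are not documented here), a mirror of the trait that takes `&[f64]` slices in place of ndarray's `ArrayView1<'_, f64>`, and a `grad_output` that holds the single upstream gradient ∂L/∂y of the scalar output.

```rust
// Hypothetical stand-in for the crate's `Context`: it stores values saved
// during the forward pass and the gradients produced by the backward pass.
#[derive(Default, Debug)]
struct Context {
    saved: Vec<Vec<f64>>,
    grads: Vec<Vec<f64>>,
}

// Mirror of `ForwardBackward` that uses `&[f64]` slices instead of
// `ndarray::ArrayView1<'_, f64>` so the sketch needs only the standard library.
trait ForwardBackward {
    fn forward(&self, ctx: &mut Context, inputs: Vec<&[f64]>) -> f64;
    fn backward(&self, ctx: &mut Context, grad_output: &[f64]);
}

/// Hypothetical operation: dot product y = a · b of two equal-length vectors.
struct Dot;

impl ForwardBackward for Dot {
    fn forward(&self, ctx: &mut Context, inputs: Vec<&[f64]>) -> f64 {
        assert_eq!(inputs.len(), 2, "Dot expects exactly two inputs");
        let (a, b) = (inputs[0], inputs[1]);
        assert_eq!(a.len(), b.len(), "inputs must have equal length");
        // Save both operands: each one is needed to form the other's gradient.
        ctx.saved = vec![a.to_vec(), b.to_vec()];
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    fn backward(&self, ctx: &mut Context, grad_output: &[f64]) {
        // Assumption: the output is a scalar, so dL/dy is the first
        // (and only) element of `grad_output`.
        let g = grad_output[0];
        let (a, b) = (&ctx.saved[0], &ctx.saved[1]);
        // Chain rule: dL/da_i = g * b_i and dL/db_i = g * a_i.
        let grad_a: Vec<f64> = b.iter().map(|bi| g * bi).collect();
        let grad_b: Vec<f64> = a.iter().map(|ai| g * ai).collect();
        ctx.grads = vec![grad_a, grad_b];
    }
}

fn main() {
    let a = [1.0, 2.0, 3.0];
    let b = [4.0, 5.0, 6.0];
    let mut ctx = Context::default();
    let y = Dot.forward(&mut ctx, vec![&a[..], &b[..]]);
    Dot.backward(&mut ctx, &[1.0]);
    // Prints: y = 32, grads = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]
    println!("y = {y}, grads = {:?}", ctx.grads);
}
```

Storing the operands during forward is one plausible use of the shared ctx: backward cannot recompute the gradients without them, since ∂L/∂a depends on b and ∂L/∂b depends on a.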

Required Methods

fn forward(&self, ctx: &mut Context, inputs: Vec<ArrayView1<'_, f64>>) -> f64

Computes the forward pass of the operation.

Arguments
  • ctx - A mutable reference to the computation context.
  • inputs - The operands of the forward pass, as a vector of read-only one-dimensional array views (ArrayView1<'_, f64>).
Returns

f64 - The scalar result of the forward pass.

fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>)

Computes the backward pass of the operation, using grad_output to calculate the gradients of the loss with respect to the operation's inputs.

Arguments
  • ctx - A mutable reference to the computation context.
  • grad_output - The gradient of the loss with respect to the output.
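The backward pass is an application of the chain rule. Writing y for the operation's output (with m elements; m = 1 when the output is the scalar returned by forward), x^{(i)} for the i-th input, and L for the loss, grad_output carries ∂L/∂y and the gradient for each input element is the vector-Jacobian product below. Because backward returns nothing, these gradients are presumably recorded in ctx; the signature alone does not say where.

```latex
\frac{\partial L}{\partial x^{(i)}_j}
  = \sum_{k=1}^{m} \frac{\partial L}{\partial y_k}\,\frac{\partial y_k}{\partial x^{(i)}_j},
\qquad\text{and for } m = 1:\quad
\frac{\partial L}{\partial x^{(i)}_j} = \frac{\partial L}{\partial y}\,\frac{\partial y}{\partial x^{(i)}_j}.
```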

Implementors