Trait ForwardBackward
pub trait ForwardBackward {
    // Required methods
    fn forward(
        &self,
        ctx: &mut Context,
        inputs: Vec<ArrayView1<'_, f64>>,
    ) -> f64;
    fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>);
}
This trait defines the interface for operations that have both forward and backward passes.
forward: Computes the forward pass of the operation.
§Arguments
ctx - A mutable reference to the computation context.
inputs - A vector of one-dimensional f64 array views, one per input to the forward pass.
§Returns
f64 - The scalar result of the forward pass.
backward: Computes the backward pass of the operation to calculate gradients.
§Arguments
ctx - A mutable reference to the computation context.
grad_output - The gradient of the loss with respect to the output.
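§Example
The following is a minimal sketch of how an operation might implement this trait, not the crate's own code. It makes several assumptions: the real Context type is not shown on this page, so the sketch defines a hypothetical Context with two made-up fields (saved_inputs and input_grads); ArrayView1 is replaced by a slice alias so the sketch builds without the ndarray crate; and Dot and run_example are illustrative names. Because forward returns a single f64, the sketch also assumes grad_output holds exactly one element.

```rust
// Stand-in for ndarray's ArrayView1 so this sketch needs only the standard library.
// With ndarray available, the real type would be used instead.
type ArrayView1<'a, T> = &'a [T];

/// Hypothetical context. The fields are assumptions made for this sketch:
/// forward saves its inputs here, and backward writes input gradients here.
#[derive(Default)]
pub struct Context {
    saved_inputs: Vec<Vec<f64>>,
    input_grads: Vec<Vec<f64>>,
}

/// The trait as documented above, reproduced so the sketch is self-contained.
pub trait ForwardBackward {
    fn forward(&self, ctx: &mut Context, inputs: Vec<ArrayView1<'_, f64>>) -> f64;
    fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>);
}

/// Illustrative operation: the dot product of two equal-length vectors.
pub struct Dot;

impl ForwardBackward for Dot {
    fn forward(&self, ctx: &mut Context, inputs: Vec<ArrayView1<'_, f64>>) -> f64 {
        assert_eq!(inputs.len(), 2, "Dot expects exactly two inputs");
        let (a, b) = (inputs[0], inputs[1]);
        assert_eq!(a.len(), b.len(), "inputs must have equal length");
        // Save copies of the inputs; backward needs them to form the gradients.
        ctx.saved_inputs = vec![a.to_vec(), b.to_vec()];
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>) {
        // Assumption: the output is a scalar, so grad_output has one element.
        let g = grad_output[0];
        let a = &ctx.saved_inputs[0];
        let b = &ctx.saved_inputs[1];
        // For out = sum(a_i * b_i): d(out)/d(a_i) = b_i and d(out)/d(b_i) = a_i.
        let grad_a: Vec<f64> = b.iter().map(|y| g * y).collect();
        let grad_b: Vec<f64> = a.iter().map(|x| g * x).collect();
        ctx.input_grads = vec![grad_a, grad_b];
    }
}

/// Runs one forward and one backward pass and returns the output and gradients.
pub fn run_example() -> (f64, Vec<Vec<f64>>) {
    let mut ctx = Context::default();
    let a = [1.0, 2.0, 3.0];
    let b = [4.0, 5.0, 6.0];
    let out = Dot.forward(&mut ctx, vec![&a[..], &b[..]]);
    // Seed the backward pass with d(loss)/d(out) = 1.
    Dot.backward(&mut ctx, &[1.0]);
    (out, ctx.input_grads)
}
```

In this sketch the context is the only channel between the two passes: forward stores what backward will need, and backward, which returns nothing, leaves its results in the context. Whether the real Context works this way is not stated in the documentation above.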