pub trait TensorTrait {
// Required methods
fn forward(
&self,
ctx: &mut Context,
inputs: Vec<ArrayView1<'_, f64>>,
) -> f64;
fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>);
fn get_value(&self) -> ArrayView1<'_, f64>;
fn get_grad(&self) -> Option<Array1<f64>>;
}Expand description
This trait defines the common interface for tensors in a computational graph.
Required Methods
fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>)
Computes the backward pass of the tensor to calculate gradients.
Arguments
- `ctx` – A mutable reference to the computation context.
- `grad_output` – The gradient of the loss with respect to the output.