Trait TensorTrait

pub trait TensorTrait {
    // Required methods
    fn forward(
        &self,
        ctx: &mut Context,
        inputs: Vec<ArrayView1<'_, f64>>,
    ) -> f64;
    fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>);
    fn get_value(&self) -> ArrayView1<'_, f64>;
    fn get_grad(&self) -> Option<Array1<f64>>;
}

This trait defines the common interface for tensors in a computational graph.
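
As a concrete illustration of the contract, here is a minimal sketch of one possible implementor. It is a hypothetical `Dot` node whose value is a weight vector `w`. Its `forward` returns `w · x` for a single input `x`, and its `backward` stores `grad_output[0] * x` as the gradient of `w`. Several parts are assumptions, not documented behaviour. `Context` is replaced by a hypothetical struct that only keeps inputs saved during `forward`. `ArrayView1`/`Array1` (presumably from the `ndarray` crate) are approximated by `&[f64]`/`Vec<f64>` so the example builds with the standard library alone. `grad_output` is assumed to hold a single element because `forward` returns a scalar.

```rust
use std::cell::RefCell;

// Assumption: std-only stand-ins for ndarray's ArrayView1 / Array1.
type ArrayView1<'a, T> = &'a [T];
type Array1<T> = Vec<T>;

// Hypothetical context: a stack of inputs saved during `forward`.
#[derive(Default)]
struct Context {
    saved: Vec<Array1<f64>>,
}

trait TensorTrait {
    fn forward(&self, ctx: &mut Context, inputs: Vec<ArrayView1<'_, f64>>) -> f64;
    fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>);
    fn get_value(&self) -> ArrayView1<'_, f64>;
    fn get_grad(&self) -> Option<Array1<f64>>;
}

// Hypothetical node computing w · x, where `w` is the tensor's own value.
struct Dot {
    w: Array1<f64>,
    // `backward` only receives `&self`, so the gradient lives behind a RefCell.
    grad: RefCell<Option<Array1<f64>>>,
}

impl Dot {
    fn new(w: Vec<f64>) -> Self {
        Dot { w, grad: RefCell::new(None) }
    }
}

impl TensorTrait for Dot {
    fn forward(&self, ctx: &mut Context, inputs: Vec<ArrayView1<'_, f64>>) -> f64 {
        let x = inputs[0];
        assert_eq!(x.len(), self.w.len(), "input length must match w");
        ctx.saved.push(x.to_vec()); // keep x for the backward pass
        self.w.iter().zip(x).map(|(w, x)| w * x).sum()
    }

    fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>) {
        let x = ctx.saved.pop().expect("forward must run before backward");
        let g = grad_output[0]; // scalar output, so one upstream gradient
        // d(w · x)/dw = x, scaled by the upstream gradient.
        *self.grad.borrow_mut() = Some(x.iter().map(|xi| g * xi).collect());
    }

    fn get_value(&self) -> ArrayView1<'_, f64> {
        &self.w
    }

    fn get_grad(&self) -> Option<Array1<f64>> {
        self.grad.borrow().clone()
    }
}

fn main() {
    let node = Dot::new(vec![1.0, 2.0, 3.0]);
    let mut ctx = Context::default();
    let x = vec![4.0, 5.0, 6.0];
    let y = node.forward(&mut ctx, vec![x.as_slice()]);
    node.backward(&mut ctx, &[1.0]);
    // prints: y = 32, grad = Some([4.0, 5.0, 6.0])
    println!("y = {y}, grad = {:?}", node.get_grad());
}
```

Keeping the gradient in a `RefCell` is one way to reconcile the `&self` receiver of `backward` with the need to record a result. Storing it in `ctx` is the other option the signatures allow.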

Required Methods

fn forward(&self, ctx: &mut Context, inputs: Vec<ArrayView1<'_, f64>>) -> f64

Computes the forward pass of the tensor.

Arguments
  • ctx - A mutable reference to the computation context.
  • inputs - A vector of input arrays for the forward pass.
Returns

(f64): The scalar result of the forward pass.
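
Because `inputs` is a `Vec` of borrowed views, a caller can pass several arrays without copying them. The node can also stash in `ctx` whatever `backward` will need later. The sketch below implements a hypothetical two-input node computing the sum over i of `w[i] * x[i] * y[i]`, and it checks the number and length of its inputs before using them. It assumes a hypothetical `Context` holding a stack of saved arrays. It also uses `&[f64]`/`Vec<f64>` as std-only stand-ins for `ArrayView1`/`Array1` (presumably from the `ndarray` crate) and treats `grad_output` as a single-element view.

```rust
use std::cell::RefCell;

// Assumption: std-only stand-ins for ndarray's ArrayView1 / Array1.
type ArrayView1<'a, T> = &'a [T];
type Array1<T> = Vec<T>;

// Hypothetical context that records values saved by `forward`.
#[derive(Default)]
struct Context {
    saved: Vec<Array1<f64>>,
}

trait TensorTrait {
    fn forward(&self, ctx: &mut Context, inputs: Vec<ArrayView1<'_, f64>>) -> f64;
    fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>);
    fn get_value(&self) -> ArrayView1<'_, f64>;
    fn get_grad(&self) -> Option<Array1<f64>>;
}

// Hypothetical node: forward([x, y]) = sum_i w[i] * x[i] * y[i].
struct Bilinear {
    w: Array1<f64>,
    grad: RefCell<Option<Array1<f64>>>,
}

impl TensorTrait for Bilinear {
    fn forward(&self, ctx: &mut Context, inputs: Vec<ArrayView1<'_, f64>>) -> f64 {
        // Validate the shape of the call before touching the data.
        assert_eq!(inputs.len(), 2, "Bilinear expects exactly two inputs");
        let (x, y) = (inputs[0], inputs[1]);
        assert!(x.len() == self.w.len() && y.len() == self.w.len());
        // The element-wise product x * y is exactly d(out)/dw, so save it.
        let xy: Array1<f64> = x.iter().zip(y).map(|(a, b)| a * b).collect();
        let out: f64 = self.w.iter().zip(&xy).map(|(w, p)| w * p).sum();
        ctx.saved.push(xy);
        out
    }

    fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>) {
        let xy = ctx.saved.pop().expect("forward must run before backward");
        let g = grad_output[0];
        *self.grad.borrow_mut() = Some(xy.iter().map(|p| g * p).collect());
    }

    fn get_value(&self) -> ArrayView1<'_, f64> {
        &self.w
    }

    fn get_grad(&self) -> Option<Array1<f64>> {
        self.grad.borrow().clone()
    }
}

fn main() {
    let node = Bilinear { w: vec![1.0, 0.5], grad: RefCell::new(None) };
    let x = vec![2.0, 4.0];
    let y = vec![3.0, 2.0];
    let mut ctx = Context::default();
    // Borrow the owned vectors as views; nothing is copied here.
    let out = node.forward(&mut ctx, vec![x.as_slice(), y.as_slice()]);
    println!("{out}"); // 1*2*3 + 0.5*4*2 = 10
}
```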

fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>)

Computes the backward pass of the tensor to calculate gradients.

Arguments
  • ctx - A mutable reference to the computation context.
  • grad_output - The gradient of the loss with respect to this tensor's output.
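
`backward` returns nothing and takes `&self`, so an implementor must put its result somewhere: either in `ctx` or in interior-mutable state on the tensor. The following is a sketch of the second option. None of the following is documented above: it assumes `ctx` acts as a last-in, first-out tape of values saved by `forward`, that gradients accumulate across calls rather than being overwritten, and that `grad_output` has a single element. The hypothetical `Square` node computes `(w · x)^2`. By the chain rule, its weight gradient is `g * 2 * (w · x) * x`. As in a standalone build, `&[f64]`/`Vec<f64>` stand in for `ArrayView1`/`Array1`.

```rust
use std::cell::RefCell;

// Assumption: std-only stand-ins for ndarray's ArrayView1 / Array1.
type ArrayView1<'a, T> = &'a [T];
type Array1<T> = Vec<T>;

// Hypothetical context used as a LIFO tape of (input, w · x) pairs.
#[derive(Default)]
struct Context {
    tape: Vec<(Array1<f64>, f64)>,
}

trait TensorTrait {
    fn forward(&self, ctx: &mut Context, inputs: Vec<ArrayView1<'_, f64>>) -> f64;
    fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>);
    fn get_value(&self) -> ArrayView1<'_, f64>;
    fn get_grad(&self) -> Option<Array1<f64>>;
}

// Hypothetical node: forward([x]) = (w · x)^2.
struct Square {
    w: Array1<f64>,
    grad: RefCell<Option<Array1<f64>>>,
}

impl TensorTrait for Square {
    fn forward(&self, ctx: &mut Context, inputs: Vec<ArrayView1<'_, f64>>) -> f64 {
        let x = inputs[0];
        let z: f64 = self.w.iter().zip(x).map(|(w, x)| w * x).sum();
        ctx.tape.push((x.to_vec(), z));
        z * z
    }

    fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>) {
        // Tape entries are consumed in reverse order of the forward calls.
        let (x, z) = ctx.tape.pop().expect("no saved forward pass");
        let g = grad_output[0];
        // Assumption: accumulate into the existing gradient (start from zeros).
        let mut slot = self.grad.borrow_mut();
        let acc = slot.get_or_insert_with(|| vec![0.0; x.len()]);
        for (a, xi) in acc.iter_mut().zip(&x) {
            // Chain rule: d(z^2)/dw = 2z * dz/dw = 2z * x.
            *a += g * 2.0 * z * xi;
        }
    }

    fn get_value(&self) -> ArrayView1<'_, f64> {
        &self.w
    }

    fn get_grad(&self) -> Option<Array1<f64>> {
        self.grad.borrow().clone()
    }
}

fn main() {
    let node = Square { w: vec![1.0, 2.0], grad: RefCell::new(None) };
    let mut ctx = Context::default();
    let (x1, x2) = (vec![1.0, 1.0], vec![2.0, 0.0]);
    node.forward(&mut ctx, vec![x1.as_slice()]); // z = 3
    node.forward(&mut ctx, vec![x2.as_slice()]); // z = 2
    node.backward(&mut ctx, &[1.0]); // consumes x2 first
    node.backward(&mut ctx, &[1.0]); // then x1
    println!("{:?}", node.get_grad());
}
```

Whether a real implementor accumulates or overwrites is a design decision. Accumulation suits a tensor that appears in several branches of a graph. Overwriting is simpler when each tensor is used once per pass.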

fn get_value(&self) -> ArrayView1<'_, f64>

Gets the value of the tensor.

Returns

(ArrayView1<f64>): A read-only view of the tensor's value, borrowed from the tensor.

fn get_grad(&self) -> Option<Array1<f64>>

Gets the gradient of the tensor.

Returns

(Option<Array1<f64>>): The gradient if one is available, otherwise None.
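
The two accessors differ in ownership. `get_value` returns a view borrowed from the tensor, so the value has to be stored directly in the implementor, because a view into a `RefCell` could not outlive its borrow guard. `get_grad` returns an owned `Option<Array1<f64>>`, which lets the gradient sit behind interior mutability and be cloned out on request. The hypothetical `SumParam` leaf below shows both accessors, plus the `None` result before any backward pass. Its `forward` ignores `inputs` and returns the sum of its own value. `Context` is a hypothetical empty struct here, and `&[f64]`/`Vec<f64>` stand in for `ArrayView1`/`Array1` (presumably from the `ndarray` crate).

```rust
use std::cell::RefCell;

// Assumption: std-only stand-ins for ndarray's ArrayView1 / Array1.
type ArrayView1<'a, T> = &'a [T];
type Array1<T> = Vec<T>;

// Hypothetical, stateless context.
struct Context;

trait TensorTrait {
    fn forward(&self, ctx: &mut Context, inputs: Vec<ArrayView1<'_, f64>>) -> f64;
    fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<'_, f64>);
    fn get_value(&self) -> ArrayView1<'_, f64>;
    fn get_grad(&self) -> Option<Array1<f64>>;
}

// Hypothetical leaf tensor: forward returns sum(value) and ignores `inputs`.
struct SumParam {
    value: Array1<f64>,                 // stored directly, so get_value can borrow it
    grad: RefCell<Option<Array1<f64>>>, // None until backward runs
}

impl TensorTrait for SumParam {
    fn forward(&self, _ctx: &mut Context, _inputs: Vec<ArrayView1<'_, f64>>) -> f64 {
        self.value.iter().sum()
    }

    fn backward(&self, _ctx: &mut Context, grad_output: ArrayView1<'_, f64>) {
        // d(sum(v))/dv_i = 1, so every entry receives the upstream gradient.
        *self.grad.borrow_mut() = Some(vec![grad_output[0]; self.value.len()]);
    }

    fn get_value(&self) -> ArrayView1<'_, f64> {
        &self.value
    }

    fn get_grad(&self) -> Option<Array1<f64>> {
        // Clone out of the RefCell so the caller owns the result.
        self.grad.borrow().clone()
    }
}

fn main() {
    let p = SumParam { value: vec![0.5, 1.5], grad: RefCell::new(None) };
    let mut ctx = Context;
    assert!(p.get_grad().is_none()); // no backward pass yet
    let total = p.forward(&mut ctx, Vec::new());
    p.backward(&mut ctx, &[2.0]);
    // The view is borrowed; copy it (ndarray: `.to_owned()`) to keep a snapshot.
    let snapshot: Vec<f64> = p.get_value().to_vec();
    println!("{total} {snapshot:?} {:?}", p.get_grad());
}
```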

Implementors