Trait Loss 

pub trait Loss {
    // Required methods
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>>;
    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>>;
    fn name(&self) -> &str;
}

Loss function trait.
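A minimal sketch of what an implementor might look like. `ArrayProtocol` and `CoreResult` are not reproduced here, so as an assumption the hypothetical `SliceLoss` trait below replaces `&dyn ArrayProtocol` with `&[f64]`, returns the loss as a plain `f64`, and uses a hypothetical `ShapeMismatch` error in place of the crate's error type. Mean squared error (MSE) is used only as an example; this is not the crate's own implementation. Because every method takes `&self` and none is generic, the trait can be used behind `Box<dyn ...>`, which `main` demonstrates.

```rust
/// Hypothetical error for mismatched input lengths; stands in for the
/// error side of `CoreResult`.
#[derive(Debug, PartialEq)]
struct ShapeMismatch {
    predictions: usize,
    targets: usize,
}

/// Hypothetical simplified mirror of `Loss`: `&[f64]` replaces
/// `&dyn ArrayProtocol`, and the loss is a plain `f64` scalar.
trait SliceLoss {
    /// Compute the loss between predictions and targets.
    fn forward(&self, predictions: &[f64], targets: &[f64]) -> Result<f64, ShapeMismatch>;
    /// Compute the gradient of the loss with respect to predictions.
    fn backward(&self, predictions: &[f64], targets: &[f64]) -> Result<Vec<f64>, ShapeMismatch>;
    /// Get the name of the loss function.
    fn name(&self) -> &str;
}

/// Reject inputs whose lengths differ.
fn check_lengths(predictions: &[f64], targets: &[f64]) -> Result<(), ShapeMismatch> {
    if predictions.len() != targets.len() {
        return Err(ShapeMismatch {
            predictions: predictions.len(),
            targets: targets.len(),
        });
    }
    Ok(())
}

/// Mean squared error: L = (1/n) * sum_i (p_i - t_i)^2.
struct MseLoss;

impl SliceLoss for MseLoss {
    fn forward(&self, predictions: &[f64], targets: &[f64]) -> Result<f64, ShapeMismatch> {
        check_lengths(predictions, targets)?;
        let n = predictions.len() as f64;
        let sum: f64 = predictions
            .iter()
            .zip(targets)
            .map(|(p, t)| (p - t).powi(2))
            .sum();
        Ok(sum / n)
    }

    fn backward(&self, predictions: &[f64], targets: &[f64]) -> Result<Vec<f64>, ShapeMismatch> {
        check_lengths(predictions, targets)?;
        let n = predictions.len() as f64;
        // dL/dp_i = 2 (p_i - t_i) / n
        Ok(predictions
            .iter()
            .zip(targets)
            .map(|(p, t)| 2.0 * (p - t) / n)
            .collect())
    }

    fn name(&self) -> &str {
        "mse"
    }
}

fn main() {
    // Dispatch through a trait object, as `Box<dyn Loss>` would.
    let loss: Box<dyn SliceLoss> = Box::new(MseLoss);
    let (p, t) = ([1.0, 2.0], [0.0, 4.0]);
    println!("{} = {}", loss.name(), loss.forward(&p, &t).unwrap()); // mse = 2.5
    println!("grad = {:?}", loss.backward(&p, &t).unwrap()); // grad = [1.0, -2.0]
    if let Err(e) = loss.forward(&[1.0], &t) {
        println!("length mismatch: {} vs {}", e.predictions, e.targets);
    }
}
```

Returning an error on mismatched lengths mirrors the fallible `CoreResult` return type of the real methods; with empty inputs this sketch's `forward` would return NaN, which a real implementation would likely also want to reject.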

Required Methods

fn forward(&self, predictions: &dyn ArrayProtocol, targets: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>>

Compute the loss between predictions and targets.
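The trait does not prescribe a particular formula. As a worked example, assuming mean squared error (MSE) is the loss being implemented, `forward` over n elements would compute:

```latex
L(\mathbf{p}, \mathbf{t}) = \frac{1}{n} \sum_{i=1}^{n} (p_i - t_i)^2,
\qquad
\mathbf{p} = (1, 2),\ \mathbf{t} = (0, 4)
\;\Rightarrow\;
L = \frac{(1 - 0)^2 + (2 - 4)^2}{2} = \frac{1 + 4}{2} = 2.5
```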

fn backward(&self, predictions: &dyn ArrayProtocol, targets: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>>

Compute the gradient of the loss with respect to predictions.
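Since `backward` must agree with `forward`, one common way to test an implementation is a finite-difference gradient check: nudge each prediction by ±h, re-run `forward`, and compare the central difference (L(p + h·e_i) − L(p − h·e_i)) / 2h with the analytic gradient. The sketch below assumes a hypothetical, infallible `SliceLoss` trait over `f64` slices rather than the crate's `ArrayProtocol`, uses MSE as the example loss, and `max_gradient_error` and `h` are illustrative names only.

```rust
/// Hypothetical simplified loss trait: slices instead of `&dyn ArrayProtocol`,
/// and no error type, for brevity.
trait SliceLoss {
    fn forward(&self, predictions: &[f64], targets: &[f64]) -> f64;
    fn backward(&self, predictions: &[f64], targets: &[f64]) -> Vec<f64>;
}

/// Example loss: mean squared error.
struct MseLoss;

impl SliceLoss for MseLoss {
    fn forward(&self, p: &[f64], t: &[f64]) -> f64 {
        let n = p.len() as f64;
        p.iter().zip(t).map(|(a, b)| (a - b).powi(2)).sum::<f64>() / n
    }

    fn backward(&self, p: &[f64], t: &[f64]) -> Vec<f64> {
        let n = p.len() as f64;
        p.iter().zip(t).map(|(a, b)| 2.0 * (a - b) / n).collect()
    }
}

/// Largest absolute difference between `backward` and a central-difference
/// estimate of each partial derivative of `forward`.
fn max_gradient_error(loss: &dyn SliceLoss, predictions: &[f64], targets: &[f64], h: f64) -> f64 {
    let analytic = loss.backward(predictions, targets);
    let mut shifted = predictions.to_vec();
    let mut max_err: f64 = 0.0;
    for i in 0..predictions.len() {
        let original = shifted[i];
        shifted[i] = original + h;
        let plus = loss.forward(&shifted, targets);
        shifted[i] = original - h;
        let minus = loss.forward(&shifted, targets);
        shifted[i] = original; // restore before perturbing the next coordinate
        let numeric = (plus - minus) / (2.0 * h);
        max_err = max_err.max((numeric - analytic[i]).abs());
    }
    max_err
}

fn main() {
    let p = [0.5, -1.0, 3.0];
    let t = [1.0, 0.0, 2.5];
    let err = max_gradient_error(&MseLoss, &p, &t, 1e-5);
    println!("max gradient error: {err:e}");
}
```

A small error (near floating-point rounding) suggests `backward` is consistent with `forward`; a large one usually points to a missing factor, such as the 2/n in the MSE gradient.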

fn name(&self) -> &str

Get the name of the loss function.

Implementors