/// Loss function trait.
///
/// Implementors provide the loss value (`forward`), the gradient of the loss
/// with respect to the predictions (`backward`), and a `name`.
pub trait Loss {
// Required methods
/// Compute the loss between predictions and targets.
///
/// On success the result is a boxed `ArrayProtocol` object wrapped in
/// `CoreResult`.
// NOTE(review): the shape of the returned array (scalar vs. per-element) and
// the conditions that produce an error are not visible here -- confirm
// against the implementors.
fn forward(
&self,
predictions: &dyn ArrayProtocol,
targets: &dyn ArrayProtocol,
) -> CoreResult<Box<dyn ArrayProtocol>>;
/// Compute the gradient of the loss with respect to predictions.
///
/// Takes the same `predictions` and `targets` arguments as `forward` and
/// returns a boxed `ArrayProtocol` object wrapped in `CoreResult`.
// NOTE(review): presumably the gradient has the same shape as `predictions`;
// verify against the implementors.
fn backward(
&self,
predictions: &dyn ArrayProtocol,
targets: &dyn ArrayProtocol,
) -> CoreResult<Box<dyn ArrayProtocol>>;
/// Returns a string slice borrowed from `self`.
// NOTE(review): presumably the human-readable name of the loss function;
// the intended format is not documented in this view -- confirm.
fn name(&self) -> &str;
}
Expand description
Loss function trait.
Required Methods
fn forward(
&self,
predictions: &dyn ArrayProtocol,
targets: &dyn ArrayProtocol,
) -> CoreResult<Box<dyn ArrayProtocol>>
fn forward( &self, predictions: &dyn ArrayProtocol, targets: &dyn ArrayProtocol, ) -> CoreResult<Box<dyn ArrayProtocol>>
Compute the loss between predictions and targets.
fn backward(
&self,
predictions: &dyn ArrayProtocol,
targets: &dyn ArrayProtocol,
) -> CoreResult<Box<dyn ArrayProtocol>>
fn backward( &self, predictions: &dyn ArrayProtocol, targets: &dyn ArrayProtocol, ) -> CoreResult<Box<dyn ArrayProtocol>>
Compute the gradient of the loss with respect to predictions.