Trait LossFunction

pub trait LossFunction {
    // Required methods
    fn compute_loss(
        &self,
        predictions: &Array2<f64>,
        targets: &Array2<f64>,
    ) -> f64;
    fn compute_gradient(
        &self,
        predictions: &Array2<f64>,
        targets: &Array2<f64>,
    ) -> Array2<f64>;
}

Loss function trait used when training neural networks.

Required Methods

fn compute_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64

Compute the scalar loss between the predictions and the targets.
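As a worked instance, consider a mean-squared-error implementor that averages over all N × M entries of an N × M prediction matrix P with targets T. This is an assumption: the trait itself does not prescribe any particular loss or reduction.

```latex
L(P, T) = \frac{1}{NM} \sum_{i=1}^{N} \sum_{j=1}^{M} \left(p_{ij} - t_{ij}\right)^2
```

For example, P = [[1, 2], [3, 4]] compared against an all-ones T gives squared differences 0, 1, 4 and 9, so L = 14 / 4 = 3.5.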

fn compute_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64>

Compute the gradient of the loss with respect to the predictions. The result has the same shape as predictions.

Implementors