/// Loss function trait for training neural networks.
///
/// Implementors must supply `compute_loss` and `compute_gradient`; the two
/// `compute_batch_*` methods come with default implementations.
///
/// NOTE(review): the default bodies are elided (`{ ... }`) in this listing, so
/// how the batch variants relate to the required methods cannot be confirmed
/// from here.
/// NOTE(review): `Array2` is presumably `ndarray::Array2` — confirm against the
/// crate's imports.
pub trait LossFunction {
// Required methods
/// Reduces a pair of `f64` 2-D arrays (`predictions`, `targets`) to a single
/// scalar `f64` loss value. Both inputs are borrowed, not consumed.
///
/// Assumes `predictions` and `targets` share the same shape — TODO confirm;
/// the axis layout is not visible here.
fn compute_loss(
&self,
predictions: &Array2<f64>,
targets: &Array2<f64>,
) -> f64;
/// Produces a newly owned `Array2<f64>` from the borrowed `predictions` and
/// `targets`.
///
/// Presumably the gradient of the loss with respect to `predictions`, with
/// the same shape as `predictions` — verify against an implementor.
fn compute_gradient(
&self,
predictions: &Array2<f64>,
targets: &Array2<f64>,
) -> Array2<f64>;
// Provided methods
/// Batch counterpart of `compute_loss` with the same parameter and return
/// types; has a default implementation, so implementors may leave it as is.
///
/// NOTE(review): whether the result is summed or averaged over the batch is
/// not visible in this listing — confirm in the source.
fn compute_batch_loss(
&self,
predictions: &Array2<f64>,
targets: &Array2<f64>,
) -> f64 { ... }
/// Batch counterpart of `compute_gradient` with the same parameter and
/// return types; has a default implementation, so implementors may leave it
/// as is.
///
/// NOTE(review): any batch-size scaling applied by the default body is not
/// visible in this listing — confirm in the source.
fn compute_batch_gradient(
&self,
predictions: &Array2<f64>,
targets: &Array2<f64>,
) -> Array2<f64> { ... }
}
Expand description
A trait for loss functions used when training neural networks.