Loss

Trait Loss 

pub trait Loss: Debug {
    // Required methods
    fn compute(
        &self,
        predictions: &ArrayView<'_, f64, Ix2>,
        targets: &ArrayView<'_, f64, Ix2>,
    ) -> TrainResult<f64>;
    fn gradient(
        &self,
        predictions: &ArrayView<'_, f64, Ix2>,
        targets: &ArrayView<'_, f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>>;
}

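Trait for loss functions: each implementor scores predictions against targets and supplies the gradient of that score with respect to the predictions.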

Required Methods

fn compute(&self, predictions: &ArrayView<'_, f64, Ix2>, targets: &ArrayView<'_, f64, Ix2>) -> TrainResult<f64>

Computes the scalar loss value for the given predictions and targets.

fn gradient(&self, predictions: &ArrayView<'_, f64, Ix2>, targets: &ArrayView<'_, f64, Ix2>) -> TrainResult<Array<f64, Ix2>>

Computes the gradient of the loss with respect to the predictions, returned as a two-dimensional array.
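
To illustrate how the two required methods relate, here is a minimal sketch of a mean squared error loss. It is not this crate's implementation. So that it compiles with the standard library alone, it mirrors the trait's shape with stand-in types: `Matrix` (in place of `ArrayView<'_, f64, Ix2>` and `Array<f64, Ix2>`), `SketchResult` (in place of `TrainResult`, whose error type is not shown here), `LossSketch` (in place of `Loss`), and `MeanSquaredError` are all hypothetical names introduced for this example. Rejecting mismatched shapes and averaging over every element are likewise assumptions.

```rust
use std::fmt::Debug;

/// Hypothetical stand-in for a 2D array: row-major values plus a shape.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    pub rows: usize,
    pub cols: usize,
    pub data: Vec<f64>,
}

impl Matrix {
    /// Builds a matrix, rejecting data whose length does not match the shape.
    pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Result<Self, String> {
        if data.len() != rows * cols {
            return Err(format!("expected {} values, got {}", rows * cols, data.len()));
        }
        Ok(Self { rows, cols, data })
    }
}

/// Hypothetical stand-in for `TrainResult`.
pub type SketchResult<T> = Result<T, String>;

/// Same two-method shape as `Loss`, expressed with the stand-in types.
pub trait LossSketch: Debug {
    fn compute(&self, predictions: &Matrix, targets: &Matrix) -> SketchResult<f64>;
    fn gradient(&self, predictions: &Matrix, targets: &Matrix) -> SketchResult<Matrix>;
}

/// Hypothetical implementor: mean squared error averaged over all elements.
#[derive(Debug, Default)]
pub struct MeanSquaredError;

/// Assumed validation: both inputs must be non-empty and share a shape.
fn check_shapes(predictions: &Matrix, targets: &Matrix) -> SketchResult<()> {
    if predictions.rows != targets.rows || predictions.cols != targets.cols {
        return Err(format!(
            "shape mismatch: {}x{} vs {}x{}",
            predictions.rows, predictions.cols, targets.rows, targets.cols
        ));
    }
    if predictions.data.is_empty() {
        return Err("empty input".to_string());
    }
    Ok(())
}

impl LossSketch for MeanSquaredError {
    /// L = (1 / N) * sum((p - t)^2), where N is the number of elements.
    fn compute(&self, predictions: &Matrix, targets: &Matrix) -> SketchResult<f64> {
        check_shapes(predictions, targets)?;
        let n = predictions.data.len() as f64;
        let sum: f64 = predictions
            .data
            .iter()
            .zip(&targets.data)
            .map(|(p, t)| (p - t).powi(2))
            .sum();
        Ok(sum / n)
    }

    /// dL/dp = 2 * (p - t) / N, one entry per prediction.
    fn gradient(&self, predictions: &Matrix, targets: &Matrix) -> SketchResult<Matrix> {
        check_shapes(predictions, targets)?;
        let n = predictions.data.len() as f64;
        let data = predictions
            .data
            .iter()
            .zip(&targets.data)
            .map(|(p, t)| 2.0 * (p - t) / n)
            .collect();
        Ok(Matrix { rows: predictions.rows, cols: predictions.cols, data })
    }
}
```

In this sketch the gradient has the same shape as the predictions, so each entry pairs with the prediction it differentiates. An implementation of the real trait would take the `ArrayView` arguments directly and return its errors through `TrainResult` instead of `String`.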

Implementors