Skip to main content

Loss

Trait Loss 

Source
/// Trait for loss functions.
///
/// Implementors supply [`compute`](Loss::compute), which evaluates the loss as a
/// single `f64`, and [`gradient`](Loss::gradient), which returns a 2-D `f64` array
/// of gradients with respect to `predictions`. [`name`](Loss::name) is a provided
/// method with a default implementation (its body is not shown here).
///
/// The `Debug` supertrait means every loss can be formatted with `{:?}`.
pub trait Loss: Debug {
    // Required methods
    /// Computes the loss value for `predictions` against `targets`.
    ///
    /// Both inputs are borrowed 2-D `f64` array views. NOTE(review): whether the
    /// two shapes must match, and which errors the `TrainResult` can carry, is
    /// implementor-defined and not visible here — confirm against implementors.
    fn compute(
        &self,
        predictions: &ArrayView<'_, f64, Ix2>,
        targets: &ArrayView<'_, f64, Ix2>,
    ) -> TrainResult<f64>;
    /// Computes the gradient of the loss with respect to `predictions`.
    ///
    /// Returns an owned 2-D `f64` array. NOTE(review): presumably shaped like
    /// `predictions` — confirm against implementors.
    fn gradient(
        &self,
        predictions: &ArrayView<'_, f64, Ix2>,
        targets: &ArrayView<'_, f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>>;

    // Provided method
    /// Returns the name of the loss function.
    fn name(&self) -> &str { ... }
}
Expand description

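Trait for loss functions: each implementor computes a scalar (`f64`) loss value and its gradient with respect to the predictions.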

Required Methods§

Source

fn compute( &self, predictions: &ArrayView<'_, f64, Ix2>, targets: &ArrayView<'_, f64, Ix2>, ) -> TrainResult<f64>

Computes the loss value for `predictions` against `targets`.

Source

fn gradient( &self, predictions: &ArrayView<'_, f64, Ix2>, targets: &ArrayView<'_, f64, Ix2>, ) -> TrainResult<Array<f64, Ix2>>

Computes the gradient of the loss with respect to `predictions`.

Provided Methods§

Source

fn name(&self) -> &str

Returns the name of the loss function.

Implementors§