LossFunction

Trait LossFunction 

Source
pub trait LossFunction:
    Debug
    + Send
    + Sync {
    // Required methods
    fn loss(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<Float>;
    fn loss_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<Array1<Float>>;
    fn name(&self) -> &'static str;

    // Provided methods
    fn loss_and_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<(Float, Array1<Float>)> { ... }
    fn is_classification(&self) -> bool { ... }
}
Expand description

Trait for loss functions that measure prediction error

Required Methods§

Source

fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float>

Compute the loss value for predictions

Source

fn loss_derivative(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Array1<Float>>

Compute the derivative of the loss with respect to the predictions

Source

fn name(&self) -> &'static str

Get the name of this loss function

Provided Methods§

Source

fn loss_and_derivative(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<(Float, Array1<Float>)>

Compute both the loss and its derivative in one call (often more efficient than calling `loss` and `loss_derivative` separately)

Source

fn is_classification(&self) -> bool

Check whether this is a classification loss (as opposed to a regression loss)

Implementors§