Trait LossFunction

pub trait LossFunction<T: FloatBounds + ScalarOperand> {
    // Required methods
    fn compute_loss(
        &self,
        predictions: &Array2<T>,
        targets: &Array2<T>,
    ) -> Result<T, SklearsError>;
    fn compute_gradient(
        &self,
        predictions: &Array2<T>,
        targets: &Array2<T>,
    ) -> Result<Array2<T>, SklearsError>;
}

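Loss function trait used for gradient checking. An implementor supplies both the scalar loss and its analytical gradient with respect to the predictions, so that the analytical gradient can be compared against a numerical (finite-difference) estimate.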
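
The following is a minimal sketch of how an implementor might be exercised by a gradient check. It is not the library's own checker. So that it compiles with the standard library alone, it replaces Array2<T> with a hypothetical row-major Matrix of f64, replaces SklearsError with a hypothetical LossError, and mirrors the trait as a simplified LossFunctionSketch. MeanSquaredError and max_gradient_error are likewise illustrative names. The check perturbs each prediction entry by ±eps and compares the central difference of compute_loss with the matching entry of compute_gradient.

```rust
// Hypothetical stand-in for ndarray::Array2<f64>: a row-major matrix.
#[derive(Clone, Debug)]
struct Matrix {
    rows: usize,
    cols: usize,
    data: Vec<f64>,
}

// Hypothetical stand-in for SklearsError.
#[derive(Debug)]
struct LossError(String);

// Simplified mirror of LossFunction<T>, with T fixed to f64 (assumption for illustration).
trait LossFunctionSketch {
    fn compute_loss(&self, predictions: &Matrix, targets: &Matrix) -> Result<f64, LossError>;
    fn compute_gradient(&self, predictions: &Matrix, targets: &Matrix) -> Result<Matrix, LossError>;
}

// Returns an error when the two matrices do not have the same shape.
fn check_shapes(a: &Matrix, b: &Matrix) -> Result<(), LossError> {
    if a.rows != b.rows || a.cols != b.cols {
        return Err(LossError(format!(
            "shape mismatch: {}x{} vs {}x{}",
            a.rows, a.cols, b.rows, b.cols
        )));
    }
    Ok(())
}

// Example implementor: mean squared error averaged over all entries.
struct MeanSquaredError;

impl LossFunctionSketch for MeanSquaredError {
    fn compute_loss(&self, p: &Matrix, t: &Matrix) -> Result<f64, LossError> {
        check_shapes(p, t)?;
        let n = p.data.len() as f64;
        let sum: f64 = p.data.iter().zip(&t.data).map(|(a, b)| (a - b).powi(2)).sum();
        Ok(sum / n)
    }

    fn compute_gradient(&self, p: &Matrix, t: &Matrix) -> Result<Matrix, LossError> {
        check_shapes(p, t)?;
        let n = p.data.len() as f64;
        // dL/dp_k = 2 (p_k - t_k) / n, same shape as the predictions.
        let data = p.data.iter().zip(&t.data).map(|(a, b)| 2.0 * (a - b) / n).collect();
        Ok(Matrix { rows: p.rows, cols: p.cols, data })
    }
}

/// Largest absolute difference between the analytical gradient and the central
/// finite-difference estimate (L(p + eps) - L(p - eps)) / (2 eps), taken per entry.
fn max_gradient_error<L: LossFunctionSketch>(
    loss: &L,
    p: &Matrix,
    t: &Matrix,
    eps: f64,
) -> Result<f64, LossError> {
    let analytic = loss.compute_gradient(p, t)?;
    let mut max_err = 0.0_f64;
    for k in 0..p.data.len() {
        let mut plus = p.clone();
        plus.data[k] += eps;
        let mut minus = p.clone();
        minus.data[k] -= eps;
        let numeric = (loss.compute_loss(&plus, t)? - loss.compute_loss(&minus, t)?) / (2.0 * eps);
        max_err = max_err.max((numeric - analytic.data[k]).abs());
    }
    Ok(max_err)
}

fn main() {
    let p = Matrix { rows: 2, cols: 2, data: vec![0.5, -1.0, 2.0, 0.25] };
    let t = Matrix { rows: 2, cols: 2, data: vec![0.0, 1.0, 1.5, 0.0] };
    let err = max_gradient_error(&MeanSquaredError, &p, &t, 1e-6).unwrap();
    println!("max abs gradient error: {err:e}");
}
```

Because mean squared error is quadratic, the central difference is exact up to floating-point rounding, so the reported error should be tiny. For non-quadratic losses a small nonzero discrepancy is expected; it shrinks roughly with eps squared until rounding error dominates.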

Required Methods

fn compute_loss(&self, predictions: &Array2<T>, targets: &Array2<T>) -> Result<T, SklearsError>

Computes the scalar loss for the given predictions and targets, returning an SklearsError on failure.
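
As an illustration only (the trait does not prescribe any particular loss), a mean-squared-error implementor would return the scalar below for an n-by-m prediction matrix P and target matrix T. The worked numbers use a hypothetical 2-by-2 example.

```latex
L(P, T) = \frac{1}{nm} \sum_{i=1}^{n} \sum_{j=1}^{m} \left(P_{ij} - T_{ij}\right)^2

P = \begin{pmatrix} 0.5 & -1 \\ 2 & 0.25 \end{pmatrix}, \quad
T = \begin{pmatrix} 0 & 1 \\ 1.5 & 0 \end{pmatrix}
\;\Rightarrow\;
L = \frac{0.5^2 + (-2)^2 + 0.5^2 + 0.25^2}{4} = \frac{4.5625}{4} = 1.140625
```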

fn compute_gradient(&self, predictions: &Array2<T>, targets: &Array2<T>) -> Result<Array2<T>, SklearsError>

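Computes the gradient of the loss with respect to the predictions. The returned array holds one partial derivative per prediction entry, so it has the same shape as predictions.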
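
For a mean-squared-error loss over an n-by-m prediction matrix P and target matrix T (an illustrative choice, not something the trait requires), the first line below is what compute_gradient would return entry by entry. The second line is the central finite-difference estimate a gradient check typically compares it against, where E_ij denotes the matrix with 1 at position (i, j) and 0 elsewhere and ε is a small step size.

```latex
\frac{\partial L}{\partial P_{ij}} = \frac{2}{nm}\left(P_{ij} - T_{ij}\right)

\frac{\partial L}{\partial P_{ij}} \approx
\frac{L(P + \varepsilon E_{ij}, T) - L(P - \varepsilon E_{ij}, T)}{2\varepsilon}
```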

Implementors