Trait AttackLoss 

pub trait AttackLoss: Send + Sync {
    // Required methods
    fn loss(&self, predictions: &[f64], labels: &[f64]) -> f64;
    fn grad(&self, predictions: &[f64], labels: &[f64]) -> Vec<f64>;
}
A differentiable loss function used by attack algorithms.

Both loss and grad receive the model outputs (logits or probabilities) and the target labels. Implementations must be thread-safe, as the Send + Sync bound requires.
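As an illustration, a minimal sketch of an implementor might look like the following. The trait is repeated from the declaration above so the snippet compiles on its own; MeanSquaredError is a hypothetical type invented for this example and is not known to be part of the crate. The sketch also assumes that predictions and labels have equal length, which the trait itself does not state.

```rust
// Trait repeated from this page so the sketch is self-contained.
pub trait AttackLoss: Send + Sync {
    fn loss(&self, predictions: &[f64], labels: &[f64]) -> f64;
    fn grad(&self, predictions: &[f64], labels: &[f64]) -> Vec<f64>;
}

/// Hypothetical implementor (an assumption for illustration):
/// mean squared error between predictions and labels.
/// A unit struct has no interior state, so it is automatically Send + Sync.
pub struct MeanSquaredError;

impl AttackLoss for MeanSquaredError {
    fn loss(&self, predictions: &[f64], labels: &[f64]) -> f64 {
        // Assumption: both slices describe the same outputs, element by element.
        assert_eq!(predictions.len(), labels.len());
        if predictions.is_empty() {
            return 0.0;
        }
        let n = predictions.len() as f64;
        predictions
            .iter()
            .zip(labels)
            .map(|(p, y)| (p - y).powi(2))
            .sum::<f64>()
            / n
    }

    fn grad(&self, predictions: &[f64], labels: &[f64]) -> Vec<f64> {
        assert_eq!(predictions.len(), labels.len());
        let n = predictions.len() as f64;
        // d/dp_i of (1/n) * sum_j (p_j - y_j)^2 is 2 * (p_i - y_i) / n.
        predictions
            .iter()
            .zip(labels)
            .map(|(p, y)| 2.0 * (p - y) / n)
            .collect()
    }
}
```

Because the trait requires Send + Sync, a value such as MeanSquaredError can be shared across threads, for example behind a std::sync::Arc, without extra synchronization.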

Required Methods

fn loss(&self, predictions: &[f64], labels: &[f64]) -> f64

Compute the scalar loss value.

fn grad(&self, predictions: &[f64], labels: &[f64]) -> Vec<f64>

Compute the gradient of the loss with respect to predictions, with one entry per element of predictions.
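Since grad is expected to be the derivative of loss, one way to sanity-check an implementation is to compare it against central finite differences. The sketch below is only an illustration: max_grad_error and HalfSquaredError are hypothetical names introduced here, not items known to exist in the crate, and the trait is repeated so the snippet stands alone.

```rust
// Trait repeated from this page so the sketch is self-contained.
pub trait AttackLoss: Send + Sync {
    fn loss(&self, predictions: &[f64], labels: &[f64]) -> f64;
    fn grad(&self, predictions: &[f64], labels: &[f64]) -> Vec<f64>;
}

/// Hypothetical helper: largest absolute difference between the analytic
/// gradient returned by `grad` and a central finite-difference estimate
/// built from `loss`, using step size `eps`.
pub fn max_grad_error(
    loss_fn: &dyn AttackLoss,
    predictions: &[f64],
    labels: &[f64],
    eps: f64,
) -> f64 {
    let analytic = loss_fn.grad(predictions, labels);
    assert_eq!(analytic.len(), predictions.len());

    let mut worst: f64 = 0.0;
    for i in 0..predictions.len() {
        // Perturb only the i-th prediction in each direction.
        let mut plus = predictions.to_vec();
        let mut minus = predictions.to_vec();
        plus[i] += eps;
        minus[i] -= eps;

        let numeric =
            (loss_fn.loss(&plus, labels) - loss_fn.loss(&minus, labels)) / (2.0 * eps);
        worst = worst.max((numeric - analytic[i]).abs());
    }
    worst
}

/// Hypothetical implementor used to exercise the check:
/// loss = 0.5 * sum (p - y)^2, so the gradient is simply p - y.
pub struct HalfSquaredError;

impl AttackLoss for HalfSquaredError {
    fn loss(&self, predictions: &[f64], labels: &[f64]) -> f64 {
        0.5 * predictions
            .iter()
            .zip(labels)
            .map(|(p, y)| (p - y).powi(2))
            .sum::<f64>()
    }

    fn grad(&self, predictions: &[f64], labels: &[f64]) -> Vec<f64> {
        predictions.iter().zip(labels).map(|(p, y)| p - y).collect()
    }
}
```

A call such as max_grad_error(&HalfSquaredError, &[0.3, -1.2], &[1.0, 0.0], 1e-5) should return a value close to zero; a large result suggests that loss and grad disagree.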

Implementors