pub fn focal_loss(
    graph: &mut Graph,
    prediction: NodeId,
    target: NodeId,
    alpha: f32,
    gamma: f32,
) -> Result<NodeId, ModelError>
Focal loss for imbalanced classification.
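FL = -alpha * (1 - p_t)^gamma * log(p_t), averaged over all elements, where p_t is the predicted probability of the true class: p_t = p when the target is 1 and p_t = 1 - p when the target is 0.
prediction should hold sigmoid probabilities in (0, 1), and target should hold binary labels (0 or 1).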
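As a minimal sketch of the formula above, the function below computes the mean focal loss directly on plain f32 slices instead of graph nodes. focal_loss_reference is a hypothetical helper and is not part of this crate. The epsilon clamp is also an assumption, added here to keep log(p_t) finite; the crate's graph implementation may handle numerical stability differently.

```rust
/// Hypothetical reference helper (not part of this crate's API):
/// mean focal loss over element-wise prediction/target pairs.
fn focal_loss_reference(prediction: &[f32], target: &[f32], alpha: f32, gamma: f32) -> f32 {
    assert_eq!(
        prediction.len(),
        target.len(),
        "prediction and target must have the same length"
    );
    assert!(!prediction.is_empty(), "inputs must be non-empty");

    // Assumption: clamp probabilities away from 0 and 1 so ln(p_t) stays finite.
    let eps = 1e-7_f32;

    let total: f32 = prediction
        .iter()
        .zip(target)
        .map(|(&p, &y)| {
            let p = p.clamp(eps, 1.0 - eps);
            // p_t is the probability assigned to the true class.
            let p_t = if y > 0.5 { p } else { 1.0 - p };
            // FL = -alpha * (1 - p_t)^gamma * log(p_t)
            -alpha * (1.0 - p_t).powf(gamma) * p_t.ln()
        })
        .sum();

    // Average over all elements.
    total / prediction.len() as f32
}

fn main() {
    let prediction = [0.9_f32, 0.2, 0.6];
    let target = [1.0_f32, 0.0, 1.0];
    let fl = focal_loss_reference(&prediction, &target, 0.25, 2.0);
    println!("focal loss = {fl:.6}");
}
```

With gamma = 0 the modulating factor (1 - p_t)^gamma equals 1, so the expression reduces to binary cross-entropy scaled by alpha. Larger gamma shrinks the contribution of well-classified examples (p_t close to 1), which is what lets the loss focus on hard examples in imbalanced data.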