Loss functions for neural networks
This module provides common loss functions used in neural networks for training. Loss functions measure how well the model’s predictions match the target values, providing a signal for optimization during training.
§Overview
Loss functions serve two main purposes in neural networks:
- Forward pass: Calculate the scalar loss value between predictions and targets
- Backward pass: Calculate gradients needed for backpropagation
The choice of loss function depends on your task type and requirements.
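As a rough sketch, a loss-function trait pairs exactly these two operations. The signatures and error type below are assumptions for illustration only; see the Loss trait listed under Traits below for the crate's actual definition.
use scirs2_core::ndarray::ArrayD;
// Placeholder error type for this sketch only.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
// Hypothetical shape of a loss-function trait (named LossSketch to
// avoid confusion with the crate's real Loss trait).
pub trait LossSketch {
    // Forward pass: scalar loss between predictions and targets.
    fn forward(&self, predictions: &ArrayD<f64>, targets: &ArrayD<f64>) -> Result<f64>;
    // Backward pass: gradient of the loss with respect to the predictions.
    fn backward(&self, predictions: &ArrayD<f64>, targets: &ArrayD<f64>) -> Result<ArrayD<f64>>;
}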
§Available Loss Functions
- MeanSquaredError: For regression tasks, measures squared differences
- CrossEntropyLoss: For classification tasks, combines softmax and negative log-likelihood
- FocalLoss: For imbalanced classification, focuses on hard examples
- ContrastiveLoss: For learning embeddings, brings similar items closer
- TripletLoss: For metric learning, enforces relative distances
§Examples
§Regression with Mean Squared Error
use scirs2_neural::losses::{Loss, MeanSquaredError};
use scirs2_core::ndarray::Array;
let mse = MeanSquaredError::new();
// Predictions and targets for regression
let predictions = Array::from_vec(vec![1.5, 2.3, 0.8]).into_dyn();
let targets = Array::from_vec(vec![1.0, 2.0, 1.0]).into_dyn();
// Calculate loss
let loss_value = mse.forward(&predictions, &targets)?;
println!("MSE Loss: {:.4}", loss_value);
// Calculate gradients for backpropagation
let gradients = mse.backward(&predictions, &targets)?;
println!("Gradients: {:?}", gradients);
§Classification with Cross Entropy
use scirs2_neural::losses::{Loss, CrossEntropyLoss};
use scirs2_core::ndarray::Array;
let ce_loss = CrossEntropyLoss::default();
// Logits (raw predictions) for 3-class classification
let logits = Array::from_shape_vec((2, 3), vec![
    1.0, 2.0, 0.5, // Sample 1
    0.8, 0.2, 1.5, // Sample 2
])?.into_dyn();
// One-hot encoded targets
let targets = Array::from_shape_vec((2, 3), vec![
    0.0, 1.0, 0.0, // Sample 1: class 1
    0.0, 0.0, 1.0, // Sample 2: class 2
])?.into_dyn();
// Calculate cross-entropy loss
let loss_value = ce_loss.forward(&logits, &targets)?;
println!("Cross-Entropy Loss: {:.4}", loss_value);
// Get gradients for backpropagation
let gradients = ce_loss.backward(&logits, &targets)?;
§Handling Class Imbalance with Focal Loss
use scirs2_neural::losses::{Loss, FocalLoss};
use scirs2_core::ndarray::Array;
// Focal loss with gamma=2.0, alpha=0.25 (typical values)
let focal_loss = FocalLoss::new(2.0, Some(0.25), 1e-10);
let predictions = Array::from_shape_vec((2, 2), vec![
    0.9, 0.1, // High confidence prediction
    0.6, 0.4, // Lower confidence prediction
])?.into_dyn();
let targets = Array::from_shape_vec((2, 2), vec![
    1.0, 0.0, // Correct prediction
    1.0, 0.0, // Incorrect prediction (will get higher loss)
])?.into_dyn();
let loss_value = focal_loss.forward(&predictions, &targets)?;
// Focal loss will be higher for the incorrect, hard example
println!("Focal Loss: {:.4}", loss_value);
§Metric Learning with Triplet Loss
use scirs2_neural::losses::{Loss, TripletLoss};
use scirs2_core::ndarray::Array;
let triplet_loss = TripletLoss::new(1.0); // margin = 1.0
// Embeddings: [anchor, positive, negative]
let anchor = Array::from_vec(vec![1.0, 0.0, 0.0]).into_dyn();
let positive = Array::from_vec(vec![0.9, 0.1, 0.0]).into_dyn(); // Similar to anchor
let negative = Array::from_vec(vec![0.0, 0.0, 1.0]).into_dyn(); // Different from anchor
// Stack into batch format
let embeddings = Array::from_shape_vec((3, 3), vec![
    1.0, 0.0, 0.0, // anchor
    0.9, 0.1, 0.0, // positive
    0.0, 0.0, 1.0, // negative
])?.into_dyn();
// Dummy targets (not used in triplet loss)
let targets = Array::zeros((3,)).into_dyn();
let loss_value = triplet_loss.forward(&embeddings, &targets)?;
println!("Triplet Loss: {:.4}", loss_value);
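For reference, the standard triplet objective is max(0, d(anchor, positive) - d(anchor, negative) + margin): the loss reaches zero once the negative is at least margin farther from the anchor than the positive.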
§Choosing the Right Loss Function
§For Regression:
- MeanSquaredError: Most common, penalizes large errors heavily
- Mean Absolute Error: More robust to outliers (implement if needed)
- Huber Loss: Combines benefits of MSE and MAE (implement if needed)
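Since MAE and Huber are listed as implement-if-needed, here is a minimal stand-alone sketch of both. These are illustrative helper functions, not part of the crate's API:
use scirs2_core::ndarray::ArrayD;
// Mean absolute error: averages |prediction - target|.
fn mean_absolute_error(pred: &ArrayD<f64>, target: &ArrayD<f64>) -> f64 {
    let n = pred.len() as f64;
    pred.iter().zip(target.iter()).map(|(p, t)| (p - t).abs()).sum::<f64>() / n
}
// Huber loss: quadratic for small errors, linear for large ones.
fn huber_loss(pred: &ArrayD<f64>, target: &ArrayD<f64>, delta: f64) -> f64 {
    let n = pred.len() as f64;
    pred.iter()
        .zip(target.iter())
        .map(|(p, t)| {
            let err = (p - t).abs();
            if err <= delta {
                0.5 * err * err // quadratic near zero, like MSE
            } else {
                delta * (err - 0.5 * delta) // linear in the tails, like MAE
            }
        })
        .sum::<f64>()
        / n
}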
§For Classification:
- CrossEntropyLoss: Standard choice for balanced multi-class problems
- FocalLoss: Better for imbalanced datasets, focuses on hard examples
- Binary Cross-Entropy: For binary classification (implement if needed)
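If you need binary cross-entropy before it lands in the crate, a minimal stand-alone sketch might look like the following. It assumes predictions are already probabilities in [0, 1] (e.g. after a sigmoid), and is not part of the crate's API:
use scirs2_core::ndarray::ArrayD;
// Averaged binary cross-entropy; eps guards against log(0).
fn binary_cross_entropy(pred: &ArrayD<f64>, target: &ArrayD<f64>, eps: f64) -> f64 {
    let n = pred.len() as f64;
    pred.iter()
        .zip(target.iter())
        .map(|(p, t)| {
            let p = p.clamp(eps, 1.0 - eps);
            -(t * p.ln() + (1.0 - t) * (1.0 - p).ln())
        })
        .sum::<f64>()
        / n
}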
§For Similarity Learning:
- ContrastiveLoss: For learning embeddings with positive/negative pairs (see the sketch after this list)
- TripletLoss: For learning relative similarities with triplets
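To make the contrastive objective concrete, here is a from-scratch sketch of the standard formulation for a single pair. This is illustrative only and does not use the crate's ContrastiveLoss API; similar pairs are pulled together, dissimilar pairs are pushed beyond the margin:
// y = similar: true for a positive pair, false for a negative pair.
fn contrastive_pair_loss(a: &[f64], b: &[f64], similar: bool, margin: f64) -> f64 {
    // Euclidean distance between the two embeddings.
    let d = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum::<f64>().sqrt();
    if similar {
        d.powi(2) // pull similar pairs together
    } else {
        (margin - d).max(0.0).powi(2) // push dissimilar pairs past the margin
    }
}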
§Training Tips
- Scale your loss: If loss values are too large/small, consider scaling
- Monitor gradients: Use backward() to check gradient magnitudes
- Combine losses: You can combine multiple loss functions for complex objectives (see the sketch after this list)
- Loss scheduling: Consider changing loss function or parameters during training
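As an example of combining objectives, a weighted sum of a data loss and a penalty term might look like the following. The weights and the L2 penalty are made up for illustration, and forward is assumed to return an f64 scalar, as in the examples above:
use scirs2_neural::losses::{Loss, MeanSquaredError};
use scirs2_core::ndarray::Array;
let mse = MeanSquaredError::new();
let predictions = Array::from_vec(vec![1.5, 2.3, 0.8]).into_dyn();
let targets = Array::from_vec(vec![1.0, 2.0, 1.0]).into_dyn();
// Made-up weights for the two objectives.
let data_weight = 1.0;
let reg_weight = 0.01;
let data_loss = mse.forward(&predictions, &targets)?;
// Hypothetical L2 penalty on the raw predictions.
let reg_loss: f64 = predictions.iter().map(|p| p * p).sum();
let total_loss = data_weight * data_loss + reg_weight * reg_loss;
println!("Combined loss: {:.4}", total_loss);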
Structs§
- ContrastiveLoss: Contrastive loss function.
- CrossEntropyLoss: Cross-entropy loss function.
- FocalLoss: Focal loss function.
- MeanSquaredError: Mean squared error loss function.
- TripletLoss: Triplet loss function.
Traits§
- Loss: Trait for loss functions used in neural networks.