pub struct FocalLoss { /* private fields */ }
Expand description
Focal loss function.
The focal loss is designed to address class imbalance by down-weighting easy examples and focusing training on hard examples. For a single class, the focal loss is defined as $\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)$, where $p_t$ is the model’s estimated probability for the true class, $\alpha_t$ is the class weighting factor, and $\gamma \ge 0$ is the focusing parameter.
§Examples
use scirs2_neural::losses::FocalLoss;
use scirs2_neural::losses::Loss;
use scirs2_core::ndarray::{Array, arr2};
let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);
// Predictions and targets for a 3-class problem
let predictions = arr2(&[
[0.7, 0.2, 0.1], // First sample, class probabilities
[0.3, 0.6, 0.1] // Second sample, class probabilities
]).into_dyn();
let targets = arr2(&[
[1.0, 0.0, 0.0], // First sample, true class is 0
[0.0, 1.0, 0.0] // Second sample, true class is 1
]).into_dyn();
// Forward pass to calculate loss
let loss = focal.forward(&predictions, &targets).unwrap();
// Backward pass to calculate gradients
let gradients = focal.backward(&predictions, &targets).unwrap();
Implementations§
Source§impl FocalLoss
impl FocalLoss
Sourcepub fn new(gamma: f64, alpha: Option<f64>, epsilon: f64) -> Self
pub fn new(gamma: f64, alpha: Option<f64>, epsilon: f64) -> Self
Create a new focal loss function with a single alpha value for all classes.
§Arguments
- `gamma` - Focusing parameter, gamma >= 0. Higher gamma means more focus on misclassified examples.
- `alpha` - Optional weighting factor, typically between 0 and 1.
- `epsilon` - Small value added to predictions to avoid log(0).
Examples found in repository?
examples/loss_functions_example.rs (line 40)
7fn main() -> Result<(), Box<dyn std::error::Error>> {
8 println!("Loss functions example");
9 // Mean Squared Error example
10 println!("\n--- Mean Squared Error Example ---");
11 let mse = MeanSquaredError::new();
12 // Create sample data for regression
13 let predictions = Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
14 let targets = Array::from_vec(vec![1.5, 1.8, 2.5]).into_dyn();
15 // Calculate loss
16 let loss = mse.forward(&predictions, &targets)?;
17 println!("Predictions: {predictions:?}");
18 println!("Targets: {targets:?}");
19 println!("MSE Loss: {loss:.4}");
20 // Calculate gradients
21 let gradients = mse.backward(&predictions, &targets)?;
22 println!("MSE Gradients: {gradients:?}");
23 // Cross-Entropy Loss example
24 println!("\n--- Cross-Entropy Loss Example ---");
25 let ce = CrossEntropyLoss::new(1e-10);
26 // Create sample data for multi-class classification
27 let predictions = Array::from_shape_vec(IxDyn(&[2, 3]), vec![0.7, 0.2, 0.1, 0.3, 0.6, 0.1])?;
28 let targets = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
29 let loss = ce.forward(&predictions, &targets)?;
30 println!("Predictions (probabilities):");
31 println!("{predictions:?}");
32 println!("Targets (one-hot):");
33 println!("{targets:?}");
34 println!("Cross-Entropy Loss: {loss:.4}");
35 let gradients = ce.backward(&predictions, &targets)?;
36 println!("Cross-Entropy Gradients:");
37 println!("{gradients:?}");
38 // Focal Loss example
39 println!("\n--- Focal Loss Example ---");
40 let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);
41 // Create sample data for imbalanced classification
42 let loss = focal.forward(&predictions, &targets)?;
43 println!("Focal Loss (gamma=2.0, alpha=0.25): {loss:.4}");
44 let _gradients = focal.backward(&predictions, &targets)?;
45 println!("Focal Loss Gradients:");
46 // Contrastive Loss example
47 println!("\n--- Contrastive Loss Example ---");
48 let contrastive = ContrastiveLoss::new(1.0);
49 // Create sample data for similarity learning
50 // Embedding pairs (batch_size x 2 x embedding_dim)
51 let embeddings = Array::from_shape_vec(
52 IxDyn(&[2, 2, 3]),
53 vec![
54 0.1, 0.2, 0.3, // First pair, first embedding
55 0.1, 0.3, 0.3, // First pair, second embedding (similar)
56 0.5, 0.5, 0.5, // Second pair, first embedding
57 0.9, 0.8, 0.7, // Second pair, second embedding (dissimilar)
58 ],
59 )?;
60 // Labels: 1 for similar pairs, 0 for dissimilar
61 let labels = Array::from_shape_vec(IxDyn(&[2, 1]), vec![1.0, 0.0])?;
62 let loss = contrastive.forward(&embeddings, &labels)?;
63 println!("Embeddings (batch_size x 2 x embedding_dim):");
64 println!("{embeddings:?}");
65 println!("Labels (1 for similar, 0 for dissimilar):");
66 println!("{labels:?}");
67 println!("Contrastive Loss (margin=1.0): {loss:.4}");
68 let gradients = contrastive.backward(&embeddings, &labels)?;
69 println!("Contrastive Loss Gradients (first few):");
70 let gradient_slice = gradients.slice(scirs2_core::ndarray::s![0, .., 0]);
71 println!("{gradient_slice:?}");
72 // Triplet Loss example
73 println!("\n--- Triplet Loss Example ---");
74 let triplet = TripletLoss::new(0.5);
75 // Create sample data for triplet learning
76 // Embedding triplets (batch_size x 3 x embedding_dim)
77 let triplet_embeddings = Array3::from_shape_vec(
78 (2, 3, 3),
79 vec![
80 0.1, 0.2, 0.3, // First triplet, anchor
81 0.1, 0.3, 0.3, // First triplet, positive
82 0.5, 0.5, 0.5, // First triplet, negative
83 0.6, 0.6, 0.6, // Second triplet, anchor
84 0.5, 0.6, 0.6, // Second triplet, positive
85 0.1, 0.1, 0.1, // Second triplet, negative
86 ],
87 )?
88 .into_dyn();
89 // Dummy labels (not used by triplet loss)
90 let dummy_labels = Array::zeros(IxDyn(&[2, 1]));
91 let loss = triplet.forward(&triplet_embeddings, &dummy_labels)?;
92 println!("Embeddings (batch_size x 3 x embedding_dim):");
93 println!(" - First dimension: batch size");
94 println!(" - Second dimension: [anchor, positive, negative]");
95 println!(" - Third dimension: embedding components");
96 println!("Triplet Loss (margin=0.5): {loss:.4}");
97 let _gradients = triplet.backward(&triplet_embeddings, &dummy_labels)?;
98 println!("Triplet Loss Gradients (first few):");
99 Ok(())
100}
Trait Implementations§
Auto Trait Implementations§
impl Freeze for FocalLoss
impl RefUnwindSafe for FocalLoss
impl Send for FocalLoss
impl Sync for FocalLoss
impl Unpin for FocalLoss
impl UnwindSafe for FocalLoss
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
    T: ?Sized,
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> CloneToUninit for T
where
    T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more