pub struct FocalLoss { /* private fields */ }
Expand description
Focal loss function.
The focal loss is a modified version of cross-entropy that reduces the relative loss for well-classified examples, focusing more on hard, misclassified examples.
The focal loss is defined as: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
where:
- p_t is the model’s estimated probability for the class with true label t
- alpha_t is a weighting factor (can be per-class)
- gamma is the focusing parameter (gamma >= 0; higher values put more focus on hard, misclassified examples)
This is particularly useful for class imbalance problems.
§Examples
use scirs2_neural::losses::FocalLoss;
use scirs2_neural::losses::Loss;
use ndarray::{Array, arr2};
// Create focal loss with gamma=2.0 and alpha=0.25
let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);
// One-hot encoded targets and softmax'd predictions for a 3-class problem
let predictions = arr2(&[
[0.7, 0.2, 0.1], // First sample, class probabilities
[0.3, 0.6, 0.1] // Second sample, class probabilities
]).into_dyn();
let targets = arr2(&[
[1.0, 0.0, 0.0], // First sample, true class is 0
[0.0, 1.0, 0.0] // Second sample, true class is 1
]).into_dyn();
// Forward pass to calculate loss
let loss = focal.forward(&predictions, &targets).unwrap();
// Backward pass to calculate gradients
let gradients = focal.backward(&predictions, &targets).unwrap();
Implementations§
Source§impl FocalLoss
impl FocalLoss
Sourcepub fn new(gamma: f64, alpha: Option<f64>, epsilon: f64) -> Self
pub fn new(gamma: f64, alpha: Option<f64>, epsilon: f64) -> Self
Create a new focal loss function with a single alpha value for all classes
§Arguments
gamma - Focusing parameter, gamma >= 0. Higher gamma means more focus on misclassified examples.
alpha - Optional weighting factor, typically between 0 and 1.
epsilon - Small value to add to predictions to avoid log(0)
Examples found in repository?
examples/loss_functions_example.rs (line 50)
6fn main() -> Result<(), Box<dyn std::error::Error>> {
7 println!("Loss functions example");
8
9 // Mean Squared Error example
10 println!("\n--- Mean Squared Error Example ---");
11 let mse = MeanSquaredError::new();
12
13 // Create sample data for regression
14 let predictions = Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
15 let targets = Array::from_vec(vec![1.5, 1.8, 2.5]).into_dyn();
16
17 // Calculate loss
18 let loss = mse.forward(&predictions, &targets)?;
19 println!("Predictions: {:?}", predictions);
20 println!("Targets: {:?}", targets);
21 println!("MSE Loss: {:.4}", loss);
22
23 // Calculate gradients
24 let gradients = mse.backward(&predictions, &targets)?;
25 println!("MSE Gradients: {:?}", gradients);
26
27 // Cross-Entropy Loss example
28 println!("\n--- Cross-Entropy Loss Example ---");
29 let ce = CrossEntropyLoss::new(1e-10);
30
31 // Create sample data for multi-class classification
32 let predictions = Array::from_shape_vec(IxDyn(&[2, 3]), vec![0.7, 0.2, 0.1, 0.3, 0.6, 0.1])?;
33 let targets = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
34
35 // Calculate loss
36 let loss = ce.forward(&predictions, &targets)?;
37 println!("Predictions (probabilities):");
38 println!("{:?}", predictions);
39 println!("Targets (one-hot):");
40 println!("{:?}", targets);
41 println!("Cross-Entropy Loss: {:.4}", loss);
42
43 // Calculate gradients
44 let gradients = ce.backward(&predictions, &targets)?;
45 println!("Cross-Entropy Gradients:");
46 println!("{:?}", gradients);
47
48 // Focal Loss example
49 println!("\n--- Focal Loss Example ---");
50 let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);
51
52 // Create sample data for imbalanced classification
53 let predictions = Array::from_shape_vec(IxDyn(&[2, 3]), vec![0.7, 0.2, 0.1, 0.3, 0.6, 0.1])?;
54 let targets = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
55
56 // Calculate loss
57 let loss = focal.forward(&predictions, &targets)?;
58 println!("Predictions (probabilities):");
59 println!("{:?}", predictions);
60 println!("Targets (one-hot):");
61 println!("{:?}", targets);
62 println!("Focal Loss (gamma=2.0, alpha=0.25): {:.4}", loss);
63
64 // Calculate gradients
65 let gradients = focal.backward(&predictions, &targets)?;
66 println!("Focal Loss Gradients:");
67 println!("{:?}", gradients);
68
69 // Contrastive Loss example
70 println!("\n--- Contrastive Loss Example ---");
71 let contrastive = ContrastiveLoss::new(1.0);
72
73 // Create sample data for similarity learning
74 // Embedding pairs (batch_size x 2 x embedding_dim)
75 let embeddings = Array::from_shape_vec(
76 IxDyn(&[2, 2, 3]),
77 vec![
78 0.1, 0.2, 0.3, // First pair, first embedding
79 0.1, 0.3, 0.3, // First pair, second embedding (similar)
80 0.5, 0.5, 0.5, // Second pair, first embedding
81 0.9, 0.8, 0.7, // Second pair, second embedding (dissimilar)
82 ],
83 )?;
84
85 // Labels: 1 for similar pairs, 0 for dissimilar
86 let labels = Array::from_shape_vec(IxDyn(&[2, 1]), vec![1.0, 0.0])?;
87
88 // Calculate loss
89 let loss = contrastive.forward(&embeddings, &labels)?;
90 println!("Embeddings (batch_size x 2 x embedding_dim):");
91 println!("{:?}", embeddings);
92 println!("Labels (1 for similar, 0 for dissimilar):");
93 println!("{:?}", labels);
94 println!("Contrastive Loss (margin=1.0): {:.4}", loss);
95
96 // Calculate gradients
97 let gradients = contrastive.backward(&embeddings, &labels)?;
98 println!("Contrastive Loss Gradients (first few):");
99 println!("{:?}", gradients.slice(ndarray::s![0, .., 0]));
100
101 // Triplet Loss example
102 println!("\n--- Triplet Loss Example ---");
103 let triplet = TripletLoss::new(0.5);
104
105 // Create sample data for triplet learning
106 // Embedding triplets (batch_size x 3 x embedding_dim)
107 let embeddings = Array::from_shape_vec(
108 IxDyn(&[2, 3, 3]),
109 vec![
110 0.1, 0.2, 0.3, // First triplet, anchor
111 0.1, 0.3, 0.3, // First triplet, positive
112 0.5, 0.5, 0.5, // First triplet, negative
113 0.6, 0.6, 0.6, // Second triplet, anchor
114 0.5, 0.6, 0.6, // Second triplet, positive
115 0.1, 0.1, 0.1, // Second triplet, negative
116 ],
117 )?;
118
119 // Dummy labels (not used by triplet loss)
120 let dummy_labels = Array::zeros(IxDyn(&[2, 1]));
121
122 // Calculate loss
123 let loss = triplet.forward(&embeddings, &dummy_labels)?;
124 println!("Embeddings (batch_size x 3 x embedding_dim):");
125 println!(" - First dimension: batch size");
126 println!(" - Second dimension: [anchor, positive, negative]");
127 println!(" - Third dimension: embedding components");
128 println!("{:?}", embeddings);
129 println!("Triplet Loss (margin=0.5): {:.4}", loss);
130
131 // Calculate gradients
132 let gradients = triplet.backward(&embeddings, &dummy_labels)?;
133 println!("Triplet Loss Gradients (first few):");
134 println!("{:?}", gradients.slice(ndarray::s![0, .., 0]));
135
136 Ok(())
137}
Sourcepub fn with_class_weights(
gamma: f64,
alpha_per_class: Vec<f64>,
epsilon: f64,
) -> Self
pub fn with_class_weights( gamma: f64, alpha_per_class: Vec<f64>, epsilon: f64, ) -> Self
Create a focal loss with class-specific alpha weights
§Arguments
gamma - Focusing parameter, gamma >= 0. Higher gamma means more focus on misclassified examples.
alpha_per_class - Vector of weighting factors, one per class
epsilon - Small value to add to predictions to avoid log(0)
Trait Implementations§
Auto Trait Implementations§
impl Freeze for FocalLoss
impl RefUnwindSafe for FocalLoss
impl Send for FocalLoss
impl Sync for FocalLoss
impl Unpin for FocalLoss
impl UnwindSafe for FocalLoss
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> CloneToUninit for T where
T: Clone,
impl<T> CloneToUninit for T where
T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise. Read more