Struct FocalLoss

Source
pub struct FocalLoss { /* private fields */ }
Expand description

Focal loss function.

The focal loss is a modified version of cross-entropy that reduces the relative loss for well-classified examples, focusing more on hard, misclassified examples.

The focal loss is defined as: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

where:

  • p_t is the model’s estimated probability for the class with true label t
  • alpha_t is a weighting factor (can be per-class)
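  • gamma is the focusing parameter (gamma >= 0); with gamma = 0 the modulating factor (1 - p_t)^gamma equals 1 and the loss reduces to alpha-weighted cross-entropy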

This is particularly useful for class imbalance problems.

§Examples

use scirs2_neural::losses::FocalLoss;
use scirs2_neural::losses::Loss;
use ndarray::{Array, arr2};

// Create focal loss with gamma=2.0 and alpha=0.25
let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);

// One-hot encoded targets and softmax'd predictions for a 3-class problem
let predictions = arr2(&[
    [0.7, 0.2, 0.1],  // First sample, class probabilities
    [0.3, 0.6, 0.1]   // Second sample, class probabilities
]).into_dyn();

let targets = arr2(&[
    [1.0, 0.0, 0.0],  // First sample, true class is 0
    [0.0, 1.0, 0.0]   // Second sample, true class is 1
]).into_dyn();

// Forward pass to calculate loss
let loss = focal.forward(&predictions, &targets).unwrap();

// Backward pass to calculate gradients
let gradients = focal.backward(&predictions, &targets).unwrap();

Implementations§

Source§

impl FocalLoss

Source

pub fn new(gamma: f64, alpha: Option<f64>, epsilon: f64) -> Self

Create a new focal loss function with a single alpha value for all classes

§Arguments
  • gamma - Focusing parameter, gamma >= 0. Higher gamma means more focus on misclassified examples.
  • alpha - Optional weighting factor, typically between 0 and 1.
  • epsilon - Small value to add to predictions to avoid log(0)
Examples found in repository?
examples/loss_functions_example.rs (line 50)
6fn main() -> Result<(), Box<dyn std::error::Error>> {
7    println!("Loss functions example");
8
9    // Mean Squared Error example
10    println!("\n--- Mean Squared Error Example ---");
11    let mse = MeanSquaredError::new();
12
13    // Create sample data for regression
14    let predictions = Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
15    let targets = Array::from_vec(vec![1.5, 1.8, 2.5]).into_dyn();
16
17    // Calculate loss
18    let loss = mse.forward(&predictions, &targets)?;
19    println!("Predictions: {:?}", predictions);
20    println!("Targets: {:?}", targets);
21    println!("MSE Loss: {:.4}", loss);
22
23    // Calculate gradients
24    let gradients = mse.backward(&predictions, &targets)?;
25    println!("MSE Gradients: {:?}", gradients);
26
27    // Cross-Entropy Loss example
28    println!("\n--- Cross-Entropy Loss Example ---");
29    let ce = CrossEntropyLoss::new(1e-10);
30
31    // Create sample data for multi-class classification
32    let predictions = Array::from_shape_vec(IxDyn(&[2, 3]), vec![0.7, 0.2, 0.1, 0.3, 0.6, 0.1])?;
33    let targets = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
34
35    // Calculate loss
36    let loss = ce.forward(&predictions, &targets)?;
37    println!("Predictions (probabilities):");
38    println!("{:?}", predictions);
39    println!("Targets (one-hot):");
40    println!("{:?}", targets);
41    println!("Cross-Entropy Loss: {:.4}", loss);
42
43    // Calculate gradients
44    let gradients = ce.backward(&predictions, &targets)?;
45    println!("Cross-Entropy Gradients:");
46    println!("{:?}", gradients);
47
48    // Focal Loss example
49    println!("\n--- Focal Loss Example ---");
50    let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);
51
52    // Create sample data for imbalanced classification
53    let predictions = Array::from_shape_vec(IxDyn(&[2, 3]), vec![0.7, 0.2, 0.1, 0.3, 0.6, 0.1])?;
54    let targets = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
55
56    // Calculate loss
57    let loss = focal.forward(&predictions, &targets)?;
58    println!("Predictions (probabilities):");
59    println!("{:?}", predictions);
60    println!("Targets (one-hot):");
61    println!("{:?}", targets);
62    println!("Focal Loss (gamma=2.0, alpha=0.25): {:.4}", loss);
63
64    // Calculate gradients
65    let gradients = focal.backward(&predictions, &targets)?;
66    println!("Focal Loss Gradients:");
67    println!("{:?}", gradients);
68
69    // Contrastive Loss example
70    println!("\n--- Contrastive Loss Example ---");
71    let contrastive = ContrastiveLoss::new(1.0);
72
73    // Create sample data for similarity learning
74    // Embedding pairs (batch_size x 2 x embedding_dim)
75    let embeddings = Array::from_shape_vec(
76        IxDyn(&[2, 2, 3]),
77        vec![
78            0.1, 0.2, 0.3, // First pair, first embedding
79            0.1, 0.3, 0.3, // First pair, second embedding (similar)
80            0.5, 0.5, 0.5, // Second pair, first embedding
81            0.9, 0.8, 0.7, // Second pair, second embedding (dissimilar)
82        ],
83    )?;
84
85    // Labels: 1 for similar pairs, 0 for dissimilar
86    let labels = Array::from_shape_vec(IxDyn(&[2, 1]), vec![1.0, 0.0])?;
87
88    // Calculate loss
89    let loss = contrastive.forward(&embeddings, &labels)?;
90    println!("Embeddings (batch_size x 2 x embedding_dim):");
91    println!("{:?}", embeddings);
92    println!("Labels (1 for similar, 0 for dissimilar):");
93    println!("{:?}", labels);
94    println!("Contrastive Loss (margin=1.0): {:.4}", loss);
95
96    // Calculate gradients
97    let gradients = contrastive.backward(&embeddings, &labels)?;
98    println!("Contrastive Loss Gradients (first few):");
99    println!("{:?}", gradients.slice(ndarray::s![0, .., 0]));
100
101    // Triplet Loss example
102    println!("\n--- Triplet Loss Example ---");
103    let triplet = TripletLoss::new(0.5);
104
105    // Create sample data for triplet learning
106    // Embedding triplets (batch_size x 3 x embedding_dim)
107    let embeddings = Array::from_shape_vec(
108        IxDyn(&[2, 3, 3]),
109        vec![
110            0.1, 0.2, 0.3, // First triplet, anchor
111            0.1, 0.3, 0.3, // First triplet, positive
112            0.5, 0.5, 0.5, // First triplet, negative
113            0.6, 0.6, 0.6, // Second triplet, anchor
114            0.5, 0.6, 0.6, // Second triplet, positive
115            0.1, 0.1, 0.1, // Second triplet, negative
116        ],
117    )?;
118
119    // Dummy labels (not used by triplet loss)
120    let dummy_labels = Array::zeros(IxDyn(&[2, 1]));
121
122    // Calculate loss
123    let loss = triplet.forward(&embeddings, &dummy_labels)?;
124    println!("Embeddings (batch_size x 3 x embedding_dim):");
125    println!("  - First dimension: batch size");
126    println!("  - Second dimension: [anchor, positive, negative]");
127    println!("  - Third dimension: embedding components");
128    println!("{:?}", embeddings);
129    println!("Triplet Loss (margin=0.5): {:.4}", loss);
130
131    // Calculate gradients
132    let gradients = triplet.backward(&embeddings, &dummy_labels)?;
133    println!("Triplet Loss Gradients (first few):");
134    println!("{:?}", gradients.slice(ndarray::s![0, .., 0]));
135
136    Ok(())
137}
Source

pub fn with_class_weights( gamma: f64, alpha_per_class: Vec<f64>, epsilon: f64, ) -> Self

Create a focal loss with class-specific alpha weights

§Arguments
  • gamma - Focusing parameter, gamma >= 0. Higher gamma means more focus on misclassified examples.
  • alpha_per_class - Vector of weighting factors, one per class
  • epsilon - Small value to add to predictions to avoid log(0)

Trait Implementations§

Source§

impl Clone for FocalLoss

Source§

fn clone(&self) -> FocalLoss

Returns a duplicate of the value. Read more
1.0.0 · Source§

const fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl Debug for FocalLoss

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more
Source§

impl Default for FocalLoss

Source§

fn default() -> Self

Returns the “default value” for a type. Read more
Source§

impl<F: Float + Debug> Loss<F> for FocalLoss

Source§

fn forward( &self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>, ) -> Result<F>

Calculate the loss between predictions and targets Read more
Source§

fn backward( &self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>, ) -> Result<Array<F, IxDyn>>

Calculate the gradient of the loss with respect to the predictions Read more

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes an object with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V