FocalLoss

Struct FocalLoss 

Source
pub struct FocalLoss { /* private fields */ }
Expand description

Focal loss function.

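The focal loss is designed to address class imbalance by down-weighting easy (well-classified) examples so that training focuses on hard examples. For a single example, the focal loss with respect to its true class is defined as: FL(p_t) = -α_t · (1 - p_t)^γ · log(p_t), where p_t is the model’s estimated probability for the true class, α_t is the weighting factor for that class (see `alpha`), and γ is the focusing parameter (see `gamma`).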

§Examples

use scirs2_neural::losses::FocalLoss;
use scirs2_neural::losses::Loss;
use scirs2_core::ndarray::{Array, arr2};

let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);
// Predictions and targets for a 3-class problem
let predictions = arr2(&[
    [0.7, 0.2, 0.1],  // First sample, class probabilities
    [0.3, 0.6, 0.1]   // Second sample, class probabilities
]).into_dyn();
let targets = arr2(&[
    [1.0, 0.0, 0.0],  // First sample, true class is 0
    [0.0, 1.0, 0.0]   // Second sample, true class is 1
]).into_dyn();
// Forward pass to calculate loss
let loss = focal.forward(&predictions, &targets).unwrap();
// Backward pass to calculate gradients
let gradients = focal.backward(&predictions, &targets).unwrap();

Implementations§

Source§

impl FocalLoss

Source

pub fn new(gamma: f64, alpha: Option<f64>, epsilon: f64) -> Self

Create a new focal loss function with a single alpha value for all classes

§Arguments
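  • gamma - Focusing parameter, gamma >= 0. Through the modulating factor (1 - p_t)^γ, a higher gamma down-weights well-classified (easy) examples more strongly, focusing training on hard, misclassified examples; with gamma = 0 the factor equals 1 and has no effect.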
  • alpha - Optional weighting factor, typically between 0 and 1.
  • epsilon - Small value to add to predictions to avoid log(0)
Examples found in repository?
examples/loss_functions_example.rs (line 40)
7fn main() -> Result<(), Box<dyn std::error::Error>> {
8    println!("Loss functions example");
9    // Mean Squared Error example
10    println!("\n--- Mean Squared Error Example ---");
11    let mse = MeanSquaredError::new();
12    // Create sample data for regression
13    let predictions = Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
14    let targets = Array::from_vec(vec![1.5, 1.8, 2.5]).into_dyn();
15    // Calculate loss
16    let loss = mse.forward(&predictions, &targets)?;
17    println!("Predictions: {predictions:?}");
18    println!("Targets: {targets:?}");
19    println!("MSE Loss: {loss:.4}");
20    // Calculate gradients
21    let gradients = mse.backward(&predictions, &targets)?;
22    println!("MSE Gradients: {gradients:?}");
23    // Cross-Entropy Loss example
24    println!("\n--- Cross-Entropy Loss Example ---");
25    let ce = CrossEntropyLoss::new(1e-10);
26    // Create sample data for multi-class classification
27    let predictions = Array::from_shape_vec(IxDyn(&[2, 3]), vec![0.7, 0.2, 0.1, 0.3, 0.6, 0.1])?;
28    let targets = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
29    let loss = ce.forward(&predictions, &targets)?;
30    println!("Predictions (probabilities):");
31    println!("{predictions:?}");
32    println!("Targets (one-hot):");
33    println!("{targets:?}");
34    println!("Cross-Entropy Loss: {loss:.4}");
35    let gradients = ce.backward(&predictions, &targets)?;
36    println!("Cross-Entropy Gradients:");
37    println!("{gradients:?}");
38    // Focal Loss example
39    println!("\n--- Focal Loss Example ---");
40    let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);
41    // Create sample data for imbalanced classification
42    let loss = focal.forward(&predictions, &targets)?;
43    println!("Focal Loss (gamma=2.0, alpha=0.25): {loss:.4}");
44    let _gradients = focal.backward(&predictions, &targets)?;
45    println!("Focal Loss Gradients:");
46    // Contrastive Loss example
47    println!("\n--- Contrastive Loss Example ---");
48    let contrastive = ContrastiveLoss::new(1.0);
49    // Create sample data for similarity learning
50    // Embedding pairs (batch_size x 2 x embedding_dim)
51    let embeddings = Array::from_shape_vec(
52        IxDyn(&[2, 2, 3]),
53        vec![
54            0.1, 0.2, 0.3, // First pair, first embedding
55            0.1, 0.3, 0.3, // First pair, second embedding (similar)
56            0.5, 0.5, 0.5, // Second pair, first embedding
57            0.9, 0.8, 0.7, // Second pair, second embedding (dissimilar)
58        ],
59    )?;
60    // Labels: 1 for similar pairs, 0 for dissimilar
61    let labels = Array::from_shape_vec(IxDyn(&[2, 1]), vec![1.0, 0.0])?;
62    let loss = contrastive.forward(&embeddings, &labels)?;
63    println!("Embeddings (batch_size x 2 x embedding_dim):");
64    println!("{embeddings:?}");
65    println!("Labels (1 for similar, 0 for dissimilar):");
66    println!("{labels:?}");
67    println!("Contrastive Loss (margin=1.0): {loss:.4}");
68    let gradients = contrastive.backward(&embeddings, &labels)?;
69    println!("Contrastive Loss Gradients (first few):");
70    let gradient_slice = gradients.slice(scirs2_core::ndarray::s![0, .., 0]);
71    println!("{gradient_slice:?}");
72    // Triplet Loss example
73    println!("\n--- Triplet Loss Example ---");
74    let triplet = TripletLoss::new(0.5);
75    // Create sample data for triplet learning
76    // Embedding triplets (batch_size x 3 x embedding_dim)
77    let triplet_embeddings = Array3::from_shape_vec(
78        (2, 3, 3),
79        vec![
80            0.1, 0.2, 0.3, // First triplet, anchor
81            0.1, 0.3, 0.3, // First triplet, positive
82            0.5, 0.5, 0.5, // First triplet, negative
83            0.6, 0.6, 0.6, // Second triplet, anchor
84            0.5, 0.6, 0.6, // Second triplet, positive
85            0.1, 0.1, 0.1, // Second triplet, negative
86        ],
87    )?
88    .into_dyn();
89    // Dummy labels (not used by triplet loss)
90    let dummy_labels = Array::zeros(IxDyn(&[2, 1]));
91    let loss = triplet.forward(&triplet_embeddings, &dummy_labels)?;
92    println!("Embeddings (batch_size x 3 x embedding_dim):");
93    println!("  - First dimension: batch size");
94    println!("  - Second dimension: [anchor, positive, negative]");
95    println!("  - Third dimension: embedding components");
96    println!("Triplet Loss (margin=0.5): {loss:.4}");
97    let _gradients = triplet.backward(&triplet_embeddings, &dummy_labels)?;
98    println!("Triplet Loss Gradients (first few):");
99    Ok(())
100}
Source

pub fn with_class_weights( gamma: f64, alpha_perclass: Vec<f64>, epsilon: f64, ) -> Self

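Create a focal loss with class-specific alpha weights.

§Arguments
  • gamma - Focusing parameter, gamma >= 0 (same meaning as in `new`).
  • alpha_perclass - Vector of weighting factors, one per class.
  • epsilon - Small value used to avoid log(0) (same meaning as in `new`).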

Trait Implementations§

Source§

impl Clone for FocalLoss

Source§

fn clone(&self) -> FocalLoss

Returns a duplicate of the value. Read more
1.0.0 · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl Debug for FocalLoss

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more
Source§

impl Default for FocalLoss

Source§

fn default() -> Self

Returns the “default value” for a type. Read more
Source§

impl<F: Float + Debug> Loss<F> for FocalLoss

Source§

fn forward( &self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>, ) -> Result<F>

Calculate the loss between predictions and targets. Read more
Source§

fn backward( &self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>, ) -> Result<Array<F, IxDyn>>

Calculate the gradient of the loss with respect to the predictions. This method computes ∂Loss/∂predictions, which is used in backpropagation to update the model parameters. Read more

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V