TripletLoss

Struct TripletLoss 

Source
pub struct TripletLoss { /* private fields */ }
Expand description

Triplet loss function.

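The triplet loss is used to learn embeddings in which each anchor is pulled toward its positive sample and pushed away from its negative sample. This continues until the anchor–negative distance exceeds the anchor–positive distance by at least margin, after which the triplet contributes zero loss. For a single triplet the loss is defined as: L = max(0, d(a, p) − d(a, n) + margin), where a is the anchor, p is the positive, n is the negative, and d(·, ·) is the distance between two embeddings.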

§Examples

use scirs2_neural::losses::TripletLoss;
use scirs2_neural::losses::Loss;
use scirs2_core::ndarray::{Array, arr3};

let triplet = TripletLoss::new(1.0);
// Triplets: (batch_size, 3, embedding_dim) where 3 = [anchor, positive, negative]
let embeddings = arr3(&[
    [   // First triplet
        [0.1, 0.2, 0.3],  // Anchor
        [0.1, 0.3, 0.3],  // Positive (similar to anchor)
        [0.9, 0.8, 0.7],  // Negative (dissimilar to anchor)
    ],
    [   // Second triplet
        [0.5, 0.5, 0.5],  // Anchor
        [0.6, 0.4, 0.5],  // Positive
        [0.1, 0.1, 0.1],  // Negative
    ]
]).into_dyn();
// Triplet loss ignores its targets, but the `Loss` trait signature requires
// them, so a zero-filled placeholder is passed.
let targets = Array::zeros(embeddings.raw_dim());
// Forward pass: the loss is built from max(0, ...) terms, so it is never negative.
let loss = triplet.forward(&embeddings, &targets).unwrap();
assert!(loss >= 0.0);
// Backward pass: the gradient w.r.t. the embeddings has the same shape as the embeddings.
let gradients = triplet.backward(&embeddings, &targets).unwrap();
assert_eq!(gradients.shape(), embeddings.shape());

Implementations§

Source§

impl TripletLoss

Source

pub fn new(margin: f64) -> Self

Create a new triplet loss function

§Arguments
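  • margin - Minimum amount by which the anchor–negative distance d(a, n) must exceed the anchor–positive distance d(a, p) before a triplet stops contributing to the loss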
Examples found in repository?
examples/loss_functions_example.rs (line 74)
7fn main() -> Result<(), Box<dyn std::error::Error>> {
8    println!("Loss functions example");
9    // Mean Squared Error example
10    println!("\n--- Mean Squared Error Example ---");
11    let mse = MeanSquaredError::new();
12    // Create sample data for regression
13    let predictions = Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
14    let targets = Array::from_vec(vec![1.5, 1.8, 2.5]).into_dyn();
15    // Calculate loss
16    let loss = mse.forward(&predictions, &targets)?;
17    println!("Predictions: {predictions:?}");
18    println!("Targets: {targets:?}");
19    println!("MSE Loss: {loss:.4}");
20    // Calculate gradients
21    let gradients = mse.backward(&predictions, &targets)?;
22    println!("MSE Gradients: {gradients:?}");
23    // Cross-Entropy Loss example
24    println!("\n--- Cross-Entropy Loss Example ---");
25    let ce = CrossEntropyLoss::new(1e-10);
26    // Create sample data for multi-class classification
27    let predictions = Array::from_shape_vec(IxDyn(&[2, 3]), vec![0.7, 0.2, 0.1, 0.3, 0.6, 0.1])?;
28    let targets = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
29    let loss = ce.forward(&predictions, &targets)?;
30    println!("Predictions (probabilities):");
31    println!("{predictions:?}");
32    println!("Targets (one-hot):");
33    println!("{targets:?}");
34    println!("Cross-Entropy Loss: {loss:.4}");
35    let gradients = ce.backward(&predictions, &targets)?;
36    println!("Cross-Entropy Gradients:");
37    println!("{gradients:?}");
38    // Focal Loss example
39    println!("\n--- Focal Loss Example ---");
40    let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);
41    // Create sample data for imbalanced classification
42    let loss = focal.forward(&predictions, &targets)?;
43    println!("Focal Loss (gamma=2.0, alpha=0.25): {loss:.4}");
44    let _gradients = focal.backward(&predictions, &targets)?;
45    println!("Focal Loss Gradients:");
46    // Contrastive Loss example
47    println!("\n--- Contrastive Loss Example ---");
48    let contrastive = ContrastiveLoss::new(1.0);
49    // Create sample data for similarity learning
50    // Embedding pairs (batch_size x 2 x embedding_dim)
51    let embeddings = Array::from_shape_vec(
52        IxDyn(&[2, 2, 3]),
53        vec![
54            0.1, 0.2, 0.3, // First pair, first embedding
55            0.1, 0.3, 0.3, // First pair, second embedding (similar)
56            0.5, 0.5, 0.5, // Second pair, first embedding
57            0.9, 0.8, 0.7, // Second pair, second embedding (dissimilar)
58        ],
59    )?;
60    // Labels: 1 for similar pairs, 0 for dissimilar
61    let labels = Array::from_shape_vec(IxDyn(&[2, 1]), vec![1.0, 0.0])?;
62    let loss = contrastive.forward(&embeddings, &labels)?;
63    println!("Embeddings (batch_size x 2 x embedding_dim):");
64    println!("{embeddings:?}");
65    println!("Labels (1 for similar, 0 for dissimilar):");
66    println!("{labels:?}");
67    println!("Contrastive Loss (margin=1.0): {loss:.4}");
68    let gradients = contrastive.backward(&embeddings, &labels)?;
69    println!("Contrastive Loss Gradients (first few):");
70    let gradient_slice = gradients.slice(scirs2_core::ndarray::s![0, .., 0]);
71    println!("{gradient_slice:?}");
72    // Triplet Loss example
73    println!("\n--- Triplet Loss Example ---");
74    let triplet = TripletLoss::new(0.5);
75    // Create sample data for triplet learning
76    // Embedding triplets (batch_size x 3 x embedding_dim)
77    let triplet_embeddings = Array3::from_shape_vec(
78        (2, 3, 3),
79        vec![
80            0.1, 0.2, 0.3, // First triplet, anchor
81            0.1, 0.3, 0.3, // First triplet, positive
82            0.5, 0.5, 0.5, // First triplet, negative
83            0.6, 0.6, 0.6, // Second triplet, anchor
84            0.5, 0.6, 0.6, // Second triplet, positive
85            0.1, 0.1, 0.1, // Second triplet, negative
86        ],
87    )?
88    .into_dyn();
89    // Dummy labels (not used by triplet loss)
90    let dummy_labels = Array::zeros(IxDyn(&[2, 1]));
91    let loss = triplet.forward(&triplet_embeddings, &dummy_labels)?;
92    println!("Embeddings (batch_size x 3 x embedding_dim):");
93    println!("  - First dimension: batch size");
94    println!("  - Second dimension: [anchor, positive, negative]");
95    println!("  - Third dimension: embedding components");
96    println!("Triplet Loss (margin=0.5): {loss:.4}");
97    let _gradients = triplet.backward(&triplet_embeddings, &dummy_labels)?;
98    println!("Triplet Loss Gradients (first few):");
99    Ok(())
100}

Trait Implementations§

Source§

impl Clone for TripletLoss

Source§

fn clone(&self) -> TripletLoss

Returns a duplicate of the value. Read more
1.0.0 · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl Debug for TripletLoss

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more
Source§

impl Default for TripletLoss

Source§

fn default() -> Self

Returns the “default value” for a type. Read more
Source§

impl<F: Float + Debug> Loss<F> for TripletLoss

Source§

fn forward( &self, predictions: &Array<F, IxDyn>, _targets: &Array<F, IxDyn>, ) -> Result<F>

Calculate the loss between predictions and targets. Read more
Source§

fn backward( &self, predictions: &Array<F, IxDyn>, _targets: &Array<F, IxDyn>, ) -> Result<Array<F, IxDyn>>

Calculate the gradient of the loss with respect to the predictions. This method computes ∂Loss/∂predictions, which is used in backpropagation to update the model parameters. Read more
Source§

impl Copy for TripletLoss

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes an object with the given initializer and returns a pointer to it. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V