pub struct TripletLoss { /* private fields */ }
Expand description
Triplet loss function.
The triplet loss is used to learn embeddings in which the distance between anchor and positive samples is minimized while the distance between anchor and negative samples is maximized. The loss is defined as L = max(0, d(a, p) - d(a, n) + margin), where a is the anchor, p is the positive sample, n is the negative sample, and d is a distance function.
§Examples
use scirs2_neural::losses::TripletLoss;
use scirs2_neural::losses::Loss;
use scirs2_core::ndarray::{Array, arr3};
let triplet = TripletLoss::new(1.0);
// Triplets: (batch_size, 3, embedding_dim) where 3 = [anchor, positive, negative]
let embeddings = arr3(&[
[ // First triplet
[0.1, 0.2, 0.3], // Anchor
[0.1, 0.3, 0.3], // Positive (similar to anchor)
[0.9, 0.8, 0.7], // Negative (dissimilar to anchor)
],
[ // Second triplet
[0.5, 0.5, 0.5], // Anchor
[0.6, 0.4, 0.5], // Positive
[0.1, 0.1, 0.1], // Negative
]
]).into_dyn();
// Targets not used in triplet loss (can be dummy)
let targets = Array::zeros(embeddings.raw_dim());
// Forward pass to calculate loss
let loss = triplet.forward(&embeddings, &targets).unwrap();
// Backward pass to calculate gradients
let gradients = triplet.backward(&embeddings, &targets).unwrap();
Implementations§
Source§impl TripletLoss
impl TripletLoss
Sourcepub fn new(margin: f64) -> Self
pub fn new(margin: f64) -> Self
Create a new triplet loss function
§Arguments
- `margin` - Margin between positive and negative distances
Examples found in repository?
examples/loss_functions_example.rs (line 74)
7fn main() -> Result<(), Box<dyn std::error::Error>> {
8 println!("Loss functions example");
9 // Mean Squared Error example
10 println!("\n--- Mean Squared Error Example ---");
11 let mse = MeanSquaredError::new();
12 // Create sample data for regression
13 let predictions = Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
14 let targets = Array::from_vec(vec![1.5, 1.8, 2.5]).into_dyn();
15 // Calculate loss
16 let loss = mse.forward(&predictions, &targets)?;
17 println!("Predictions: {predictions:?}");
18 println!("Targets: {targets:?}");
19 println!("MSE Loss: {loss:.4}");
20 // Calculate gradients
21 let gradients = mse.backward(&predictions, &targets)?;
22 println!("MSE Gradients: {gradients:?}");
23 // Cross-Entropy Loss example
24 println!("\n--- Cross-Entropy Loss Example ---");
25 let ce = CrossEntropyLoss::new(1e-10);
26 // Create sample data for multi-class classification
27 let predictions = Array::from_shape_vec(IxDyn(&[2, 3]), vec![0.7, 0.2, 0.1, 0.3, 0.6, 0.1])?;
28 let targets = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
29 let loss = ce.forward(&predictions, &targets)?;
30 println!("Predictions (probabilities):");
31 println!("{predictions:?}");
32 println!("Targets (one-hot):");
33 println!("{targets:?}");
34 println!("Cross-Entropy Loss: {loss:.4}");
35 let gradients = ce.backward(&predictions, &targets)?;
36 println!("Cross-Entropy Gradients:");
37 println!("{gradients:?}");
38 // Focal Loss example
39 println!("\n--- Focal Loss Example ---");
40 let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);
41 // Create sample data for imbalanced classification
42 let loss = focal.forward(&predictions, &targets)?;
43 println!("Focal Loss (gamma=2.0, alpha=0.25): {loss:.4}");
44 let _gradients = focal.backward(&predictions, &targets)?;
45 println!("Focal Loss Gradients:");
46 // Contrastive Loss example
47 println!("\n--- Contrastive Loss Example ---");
48 let contrastive = ContrastiveLoss::new(1.0);
49 // Create sample data for similarity learning
50 // Embedding pairs (batch_size x 2 x embedding_dim)
51 let embeddings = Array::from_shape_vec(
52 IxDyn(&[2, 2, 3]),
53 vec![
54 0.1, 0.2, 0.3, // First pair, first embedding
55 0.1, 0.3, 0.3, // First pair, second embedding (similar)
56 0.5, 0.5, 0.5, // Second pair, first embedding
57 0.9, 0.8, 0.7, // Second pair, second embedding (dissimilar)
58 ],
59 )?;
60 // Labels: 1 for similar pairs, 0 for dissimilar
61 let labels = Array::from_shape_vec(IxDyn(&[2, 1]), vec![1.0, 0.0])?;
62 let loss = contrastive.forward(&embeddings, &labels)?;
63 println!("Embeddings (batch_size x 2 x embedding_dim):");
64 println!("{embeddings:?}");
65 println!("Labels (1 for similar, 0 for dissimilar):");
66 println!("{labels:?}");
67 println!("Contrastive Loss (margin=1.0): {loss:.4}");
68 let gradients = contrastive.backward(&embeddings, &labels)?;
69 println!("Contrastive Loss Gradients (first few):");
70 let gradient_slice = gradients.slice(scirs2_core::ndarray::s![0, .., 0]);
71 println!("{gradient_slice:?}");
72 // Triplet Loss example
73 println!("\n--- Triplet Loss Example ---");
74 let triplet = TripletLoss::new(0.5);
75 // Create sample data for triplet learning
76 // Embedding triplets (batch_size x 3 x embedding_dim)
77 let triplet_embeddings = Array3::from_shape_vec(
78 (2, 3, 3),
79 vec![
80 0.1, 0.2, 0.3, // First triplet, anchor
81 0.1, 0.3, 0.3, // First triplet, positive
82 0.5, 0.5, 0.5, // First triplet, negative
83 0.6, 0.6, 0.6, // Second triplet, anchor
84 0.5, 0.6, 0.6, // Second triplet, positive
85 0.1, 0.1, 0.1, // Second triplet, negative
86 ],
87 )?
88 .into_dyn();
89 // Dummy labels (not used by triplet loss)
90 let dummy_labels = Array::zeros(IxDyn(&[2, 1]));
91 let loss = triplet.forward(&triplet_embeddings, &dummy_labels)?;
92 println!("Embeddings (batch_size x 3 x embedding_dim):");
93 println!(" - First dimension: batch size");
94 println!(" - Second dimension: [anchor, positive, negative]");
95 println!(" - Third dimension: embedding components");
96 println!("Triplet Loss (margin=0.5): {loss:.4}");
97 let _gradients = triplet.backward(&triplet_embeddings, &dummy_labels)?;
98 println!("Triplet Loss Gradients (first few):");
99 Ok(())
100}
Trait Implementations§
Source§impl Clone for TripletLoss
impl Clone for TripletLoss
Source§fn clone(&self) -> TripletLoss
fn clone(&self) -> TripletLoss
Returns a duplicate of the value. Read more
1.0.0 · Source§fn clone_from(&mut self, source: &Self)
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from `source`. Read more
Source§impl Debug for TripletLoss
impl Debug for TripletLoss
Source§impl Default for TripletLoss
impl Default for TripletLoss
Source§impl<F: Float + Debug> Loss<F> for TripletLoss
impl<F: Float + Debug> Loss<F> for TripletLoss
impl Copy for TripletLoss
Auto Trait Implementations§
impl Freeze for TripletLoss
impl RefUnwindSafe for TripletLoss
impl Send for TripletLoss
impl Sync for TripletLoss
impl Unpin for TripletLoss
impl UnwindSafe for TripletLoss
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> CloneToUninit for T where
T: Clone,
impl<T> CloneToUninit for T where
T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`.
Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`.
Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more