Skip to main content

ruvector_cnn/contrastive/
triplet.rs

1//! # Triplet Loss
2//!
3//! Implementation of the Triplet Loss for metric learning.
4//!
5//! ## Mathematical Formulation
6//!
7//! ```text
8//! L(a, p, n) = max(0, ||a - p||^2 - ||a - n||^2 + margin)
9//! ```
10//!
11//! Where:
12//! - `a` is the anchor embedding
13//! - `p` is the positive (similar) embedding
14//! - `n` is the negative (dissimilar) embedding
15//! - `margin` is the minimum desired separation
16//!
//! ## Variants
//!
//! - **Squared Euclidean** (default): `||x - y||^2`, which avoids the square root
//! - **Euclidean**: `||x - y||`
//! - **Cosine**: `1 - cos(x, y)`, well suited to normalized embeddings
//! - **Soft margin**: replaces the hinge `max(0, x)` with `ln(1 + exp(x))` for smoother gradients
22//!
23//! ## References
24//!
25//! - "FaceNet: A Unified Embedding for Face Recognition and Clustering"
26//! - "Deep Metric Learning Using Triplet Network"
27
28use crate::error::{CnnError, CnnResult};
29use serde::{Deserialize, Serialize};
30
31/// Distance metric for triplet loss computation.
32#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
33pub enum TripletDistance {
34    /// Euclidean (L2) distance
35    Euclidean,
36    /// Squared Euclidean distance (avoids sqrt, faster)
37    SquaredEuclidean,
38    /// Cosine distance (1 - cosine_similarity)
39    Cosine,
40}
41
42/// Result of triplet loss computation.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct TripletResult {
45    /// The computed loss value
46    pub loss: f64,
47    /// Distance between anchor and positive
48    pub positive_distance: f64,
49    /// Distance between anchor and negative
50    pub negative_distance: f64,
51    /// Whether the triplet is a "hard" triplet (loss > 0)
52    pub is_hard: bool,
53    /// Whether the triplet violates the margin
54    pub violates_margin: bool,
55}
56
57/// Triplet loss for metric learning.
58///
59/// # Example
60///
61/// ```rust
62/// use ruvector_cnn::contrastive::TripletLoss;
63///
64/// let triplet = TripletLoss::new(1.0);
65///
66/// let anchor = vec![1.0, 0.0, 0.0];
67/// let positive = vec![0.9, 0.1, 0.0];  // similar to anchor
68/// let negative = vec![0.0, 1.0, 0.0];  // dissimilar to anchor
69///
70/// let loss = triplet.forward(&anchor, &positive, &negative);
71/// assert!(loss >= 0.0);
72/// ```
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct TripletLoss {
75    /// Margin for the triplet loss (default: 1.0)
76    margin: f64,
77    /// Distance metric to use
78    distance: TripletDistance,
79    /// Whether to use soft margin (log-exp smoothing)
80    soft_margin: bool,
81    /// L2 regularization weight (optional)
82    l2_regularization: Option<f64>,
83}
84
85impl TripletLoss {
86    /// Create a new triplet loss with the specified margin.
87    ///
88    /// # Arguments
89    ///
90    /// * `margin` - The minimum desired separation between positive and negative
91    ///   distances. Typical values: 0.2-2.0
92    ///
93    /// # Panics
94    ///
95    /// Panics if margin is negative.
96    pub fn new(margin: f64) -> Self {
97        assert!(margin >= 0.0, "Margin must be non-negative");
98        Self {
99            margin,
100            distance: TripletDistance::SquaredEuclidean,
101            soft_margin: false,
102            l2_regularization: None,
103        }
104    }
105
106    /// Set the distance metric.
107    pub fn with_distance(mut self, distance: TripletDistance) -> Self {
108        self.distance = distance;
109        self
110    }
111
112    /// Enable soft margin for smoother gradients.
113    ///
114    /// Instead of `max(0, x)`, uses `log(1 + exp(x))`.
115    pub fn with_soft_margin(mut self) -> Self {
116        self.soft_margin = true;
117        self
118    }
119
120    /// Add L2 regularization on embeddings.
121    pub fn with_l2_regularization(mut self, weight: f64) -> Self {
122        self.l2_regularization = Some(weight);
123        self
124    }
125
126    /// Get the margin.
127    pub fn margin(&self) -> f64 {
128        self.margin
129    }
130
131    /// Get the distance metric.
132    pub fn distance_metric(&self) -> TripletDistance {
133        self.distance
134    }
135
136    /// Compute triplet loss for a single triplet.
137    ///
138    /// # Arguments
139    ///
140    /// * `anchor` - The anchor embedding
141    /// * `positive` - The positive (similar) embedding
142    /// * `negative` - The negative (dissimilar) embedding
143    ///
144    /// # Returns
145    ///
146    /// The triplet loss value (non-negative).
147    pub fn forward(&self, anchor: &[f64], positive: &[f64], negative: &[f64]) -> f64 {
148        self.forward_detailed(anchor, positive, negative)
149            .map(|r| r.loss)
150            .unwrap_or(0.0)
151    }
152
153    /// Compute triplet loss with detailed results.
154    pub fn forward_detailed(
155        &self,
156        anchor: &[f64],
157        positive: &[f64],
158        negative: &[f64],
159    ) -> CnnResult<TripletResult> {
160        // Validate inputs
161        if anchor.is_empty() {
162            return Err(CnnError::InvalidInput("anchor cannot be empty".to_string()));
163        }
164
165        let dim = anchor.len();
166        if positive.len() != dim {
167            return Err(CnnError::DimensionMismatch(format!(
168                "positive has dimension {}, expected {}",
169                positive.len(),
170                dim
171            )));
172        }
173        if negative.len() != dim {
174            return Err(CnnError::DimensionMismatch(format!(
175                "negative has dimension {}, expected {}",
176                negative.len(),
177                dim
178            )));
179        }
180
181        // Check for NaN/Inf
182        for (name, vec) in [("anchor", anchor), ("positive", positive), ("negative", negative)] {
183            if vec.iter().any(|x| x.is_nan() || x.is_infinite()) {
184                return Err(CnnError::InvalidInput(format!(
185                    "{} contains NaN or Inf",
186                    name
187                )));
188            }
189        }
190
191        // Compute distances
192        let pos_dist = self.compute_distance(anchor, positive);
193        let neg_dist = self.compute_distance(anchor, negative);
194
195        // Compute loss
196        let diff = pos_dist - neg_dist + self.margin;
197        let loss = if self.soft_margin {
198            soft_relu(diff)
199        } else {
200            diff.max(0.0)
201        };
202
203        // Add L2 regularization if enabled
204        let loss = if let Some(weight) = self.l2_regularization {
205            let anchor_norm: f64 = anchor.iter().map(|x| x * x).sum();
206            let pos_norm: f64 = positive.iter().map(|x| x * x).sum();
207            let neg_norm: f64 = negative.iter().map(|x| x * x).sum();
208            loss + weight * (anchor_norm + pos_norm + neg_norm) / 3.0
209        } else {
210            loss
211        };
212
213        Ok(TripletResult {
214            loss,
215            positive_distance: pos_dist,
216            negative_distance: neg_dist,
217            is_hard: diff > 0.0,
218            violates_margin: pos_dist + self.margin > neg_dist,
219        })
220    }
221
222    /// Compute batch triplet loss.
223    ///
224    /// # Arguments
225    ///
226    /// * `anchors` - Batch of anchor embeddings
227    /// * `positives` - Batch of positive embeddings
228    /// * `negatives` - Batch of negative embeddings
229    ///
230    /// # Returns
231    ///
232    /// Mean triplet loss across the batch.
233    pub fn forward_batch(
234        &self,
235        anchors: &[Vec<f64>],
236        positives: &[Vec<f64>],
237        negatives: &[Vec<f64>],
238    ) -> CnnResult<f64> {
239        if anchors.len() != positives.len() || anchors.len() != negatives.len() {
240            return Err(CnnError::DimensionMismatch(format!(
241                "Batch sizes must match: anchors={}, positives={}, negatives={}",
242                anchors.len(),
243                positives.len(),
244                negatives.len()
245            )));
246        }
247
248        if anchors.is_empty() {
249            return Err(CnnError::InvalidInput("batch cannot be empty".to_string()));
250        }
251
252        let mut total_loss = 0.0;
253        for ((anchor, positive), negative) in anchors.iter().zip(positives).zip(negatives) {
254            total_loss += self.forward(anchor, positive, negative);
255        }
256
257        Ok(total_loss / anchors.len() as f64)
258    }
259
260    /// Mine hard triplets from a batch.
261    ///
262    /// Returns indices of (anchor, positive, negative) triplets where the loss is positive.
263    ///
264    /// # Arguments
265    ///
266    /// * `embeddings` - All embeddings in the batch
267    /// * `labels` - Class labels for each embedding
268    ///
269    /// # Returns
270    ///
271    /// Vector of (anchor_idx, positive_idx, negative_idx) tuples.
272    pub fn mine_hard_triplets(
273        &self,
274        embeddings: &[Vec<f64>],
275        labels: &[usize],
276    ) -> Vec<(usize, usize, usize)> {
277        if embeddings.len() != labels.len() {
278            return vec![];
279        }
280
281        let n = embeddings.len();
282        let mut triplets = Vec::new();
283
284        // Precompute distance matrix
285        let distances = self.compute_distance_matrix(embeddings);
286
287        for anchor_idx in 0..n {
288            let anchor_label = labels[anchor_idx];
289
290            // Find hardest positive (furthest with same label)
291            let mut hardest_pos_idx = None;
292            let mut hardest_pos_dist = f64::NEG_INFINITY;
293
294            // Find hardest negative (closest with different label)
295            let mut hardest_neg_idx = None;
296            let mut hardest_neg_dist = f64::INFINITY;
297
298            for other_idx in 0..n {
299                if other_idx == anchor_idx {
300                    continue;
301                }
302
303                let dist = distances[anchor_idx][other_idx];
304
305                if labels[other_idx] == anchor_label {
306                    // Same class - potential positive
307                    if dist > hardest_pos_dist {
308                        hardest_pos_dist = dist;
309                        hardest_pos_idx = Some(other_idx);
310                    }
311                } else {
312                    // Different class - potential negative
313                    if dist < hardest_neg_dist {
314                        hardest_neg_dist = dist;
315                        hardest_neg_idx = Some(other_idx);
316                    }
317                }
318            }
319
320            // Add triplet if valid and hard
321            if let (Some(pos_idx), Some(neg_idx)) = (hardest_pos_idx, hardest_neg_idx) {
322                if hardest_pos_dist - hardest_neg_dist + self.margin > 0.0 {
323                    triplets.push((anchor_idx, pos_idx, neg_idx));
324                }
325            }
326        }
327
328        triplets
329    }
330
331    /// Compute distance between two vectors.
332    #[inline]
333    fn compute_distance(&self, a: &[f64], b: &[f64]) -> f64 {
334        match self.distance {
335            TripletDistance::Euclidean => {
336                let sum_sq: f64 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
337                sum_sq.sqrt()
338            }
339            TripletDistance::SquaredEuclidean => {
340                a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum()
341            }
342            TripletDistance::Cosine => {
343                let mut dot = 0.0;
344                let mut norm_a_sq = 0.0;
345                let mut norm_b_sq = 0.0;
346
347                for (x, y) in a.iter().zip(b) {
348                    dot += x * y;
349                    norm_a_sq += x * x;
350                    norm_b_sq += y * y;
351                }
352
353                let norm = (norm_a_sq * norm_b_sq).sqrt();
354                if norm < 1e-8 {
355                    1.0 // Maximum distance for zero vectors
356                } else {
357                    1.0 - dot / norm
358                }
359            }
360        }
361    }
362
363    /// Compute pairwise distance matrix.
364    fn compute_distance_matrix(&self, embeddings: &[Vec<f64>]) -> Vec<Vec<f64>> {
365        let n = embeddings.len();
366        let mut matrix = vec![vec![0.0; n]; n];
367
368        for i in 0..n {
369            for j in (i + 1)..n {
370                let dist = self.compute_distance(&embeddings[i], &embeddings[j]);
371                matrix[i][j] = dist;
372                matrix[j][i] = dist;
373            }
374        }
375
376        matrix
377    }
378}
379
380impl Default for TripletLoss {
381    fn default() -> Self {
382        Self::new(1.0)
383    }
384}
385
/// Soft ReLU (softplus): `ln(1 + exp(x))`, a smooth upper bound of `max(0, x)`.
///
/// Evaluated as `max(x, 0) + ln(1 + exp(-|x|))`. This is mathematically identical to the
/// naive form, but the exponent is always `<= 0`, so it never overflows. `ln_1p` keeps full
/// relative precision when `exp(-|x|)` is tiny, so the result stays continuous and positive
/// for very negative `x` instead of collapsing to exactly zero.
#[inline]
fn soft_relu(x: f64) -> f64 {
    // `f64::max` drops a NaN operand, but the `ln_1p` term is NaN too, so NaN still propagates.
    x.max(0.0) + (-x.abs()).exp().ln_1p()
}
397
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_triplet_basic() {
        let triplet = TripletLoss::new(1.0);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let negative = vec![0.0, 1.0, 0.0];

        // d(a, p) = 0.02 and d(a, n) = 2.0, so 0.02 - 2.0 + 1.0 < 0: the margin is satisfied.
        let loss = triplet.forward(&anchor, &positive, &negative);
        assert_eq!(loss, 0.0);
    }

    #[test]
    fn test_defaults() {
        let triplet = TripletLoss::default();
        assert_eq!(triplet.margin(), 1.0);
        assert_eq!(triplet.distance_metric(), TripletDistance::SquaredEuclidean);
    }

    #[test]
    #[should_panic(expected = "Margin must be non-negative")]
    fn test_negative_margin_panics() {
        let _ = TripletLoss::new(-0.5);
    }

    #[test]
    fn test_triplet_zero_loss() {
        let triplet = TripletLoss::new(0.1);

        // Negative is far, positive is close - should be zero loss
        let anchor = vec![1.0, 0.0];
        let positive = vec![1.0, 0.0]; // identical
        let negative = vec![-1.0, 0.0]; // opposite

        let result = triplet.forward_detailed(&anchor, &positive, &negative).unwrap();
        assert_eq!(result.loss, 0.0);
        assert!(!result.is_hard);
        assert!(!result.violates_margin);
    }

    #[test]
    fn test_triplet_hard() {
        let triplet = TripletLoss::new(1.0);

        // Negative is closer than positive - hard triplet
        let anchor = vec![0.0, 0.0];
        let positive = vec![2.0, 0.0];
        let negative = vec![1.0, 0.0];

        // Squared distances 4 and 1: loss = 4 - 1 + 1 = 4.
        let result = triplet.forward_detailed(&anchor, &positive, &negative).unwrap();
        assert!((result.loss - 4.0).abs() < 1e-12);
        assert!((result.positive_distance - 4.0).abs() < 1e-12);
        assert!((result.negative_distance - 1.0).abs() < 1e-12);
        assert!(result.is_hard);
        assert!(result.violates_margin);
    }

    #[test]
    fn test_triplet_distances() {
        // Test Euclidean distance
        let triplet_euclidean = TripletLoss::new(0.0).with_distance(TripletDistance::Euclidean);
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        let c = vec![0.0, 0.0];

        let result = triplet_euclidean.forward_detailed(&a, &b, &c).unwrap();
        assert!((result.positive_distance - 5.0).abs() < 1e-6);
        assert!(result.negative_distance.abs() < 1e-6);

        // Test cosine distance
        let triplet_cosine = TripletLoss::new(0.0).with_distance(TripletDistance::Cosine);
        let x = vec![1.0, 0.0];
        let y = vec![0.0, 1.0];
        let z = vec![1.0, 0.0];

        let result = triplet_cosine.forward_detailed(&x, &y, &z).unwrap();
        assert!((result.positive_distance - 1.0).abs() < 1e-6); // orthogonal = 1
        assert!(result.negative_distance.abs() < 1e-6); // identical = 0
    }

    #[test]
    fn test_cosine_zero_vector() {
        let triplet = TripletLoss::new(0.0).with_distance(TripletDistance::Cosine);
        let zero = vec![0.0, 0.0];

        // A zero anchor has no direction; both distances fall back to 1.0.
        let result = triplet
            .forward_detailed(&zero, &[1.0, 0.0], &[0.0, 1.0])
            .unwrap();
        assert!((result.positive_distance - 1.0).abs() < 1e-12);
        assert!((result.negative_distance - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_soft_margin() {
        let hard = TripletLoss::new(1.0);
        let soft = TripletLoss::new(1.0).with_soft_margin();

        let anchor = vec![0.0, 0.0];
        let positive = vec![1.0, 0.0];
        let negative = vec![0.5, 0.0];

        let hard_loss = hard.forward(&anchor, &positive, &negative);
        let soft_loss = soft.forward(&anchor, &positive, &negative);

        // Hinge argument: 1.0 - 0.25 + 1.0 = 1.75.
        assert!((hard_loss - 1.75).abs() < 1e-12);
        assert!((soft_loss - (1.0 + 1.75_f64.exp()).ln()).abs() < 1e-9);
        // Soft margin is an upper bound of the hard hinge
        assert!(soft_loss >= hard_loss);
    }

    #[test]
    fn test_batch_triplet() {
        let triplet = TripletLoss::new(1.0);

        let anchors = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let positives = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
        let negatives = vec![vec![0.0, 1.0], vec![1.0, 0.0]];

        // Each triplet: 0.02 - 2.0 + 1.0 < 0, so both losses (and the mean) are zero.
        let loss = triplet.forward_batch(&anchors, &positives, &negatives).unwrap();
        assert_eq!(loss, 0.0);

        // Mismatched batch sizes and empty batches are rejected.
        assert!(triplet
            .forward_batch(&anchors, &positives[..1], &negatives)
            .is_err());
        assert!(triplet.forward_batch(&[], &[], &[]).is_err());
    }

    #[test]
    fn test_mine_hard_triplets() {
        let triplet = TripletLoss::new(0.01);

        // Squared distances between neighbouring points are 0.005, two steps apart 0.02.
        let embeddings = vec![
            vec![1.0, 0.0],   // class 0
            vec![0.95, 0.05], // class 0
            vec![0.9, 0.1],   // class 1
            vec![0.85, 0.15], // class 1
        ];
        let labels = vec![0, 0, 1, 1];

        let hard_triplets = triplet.mine_hard_triplets(&embeddings, &labels);

        // Anchors 0 and 3: 0.005 - 0.02 + 0.01 < 0 -> not hard.
        // Anchors 1 and 2: 0.005 - 0.005 + 0.01 > 0 -> hard, with the adjacent
        // cross-class point as the hardest negative.
        assert_eq!(hard_triplets, vec![(1, 0, 2), (2, 3, 1)]);

        for (a, p, n) in &hard_triplets {
            assert_eq!(labels[*a], labels[*p]); // anchor and positive same class
            assert_ne!(labels[*a], labels[*n]); // anchor and negative different class
        }

        // Mismatched label count yields no triplets.
        assert!(triplet
            .mine_hard_triplets(&embeddings, &labels[..3])
            .is_empty());
    }

    #[test]
    fn test_l2_regularization() {
        let no_reg = TripletLoss::new(0.0);
        let with_reg = TripletLoss::new(0.0).with_l2_regularization(0.01);

        let anchor = vec![10.0, 0.0];
        let positive = vec![10.0, 0.0];
        let negative = vec![-10.0, 0.0];

        let loss_no_reg = no_reg.forward(&anchor, &positive, &negative);
        let loss_with_reg = with_reg.forward(&anchor, &positive, &negative);

        // Hinge term is zero; penalty = 0.01 * (100 + 100 + 100) / 3 = 1.0.
        assert_eq!(loss_no_reg, 0.0);
        assert!((loss_with_reg - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_error_handling() {
        let triplet = TripletLoss::new(1.0);

        // Empty anchor
        let result = triplet.forward_detailed(&[], &[1.0], &[1.0]);
        assert!(result.is_err());

        // Dimension mismatch
        let result = triplet.forward_detailed(&[1.0, 2.0], &[1.0], &[1.0, 2.0]);
        assert!(result.is_err());

        // Non-finite values
        assert!(triplet.forward_detailed(&[f64::NAN], &[1.0], &[1.0]).is_err());
        assert!(triplet
            .forward_detailed(&[1.0], &[f64::INFINITY], &[1.0])
            .is_err());

        // `forward` maps invalid input to a zero loss
        assert_eq!(triplet.forward(&[], &[1.0], &[1.0]), 0.0);
    }

    #[test]
    fn test_soft_relu() {
        // Basic cases
        assert!((soft_relu(0.0) - 2.0_f64.ln()).abs() < 1e-6);
        assert!(soft_relu(-100.0) < 1e-10);
        assert!((soft_relu(100.0) - 100.0).abs() < 1e-6);

        // Strictly above max(0, x) for moderate x
        let x = 1.0;
        let y = soft_relu(x);
        assert!(y > x.max(0.0));
    }
}