1use crate::error::{CnnError, CnnResult};
29use serde::{Deserialize, Serialize};
30
31#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
33pub enum TripletDistance {
34 Euclidean,
36 SquaredEuclidean,
38 Cosine,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct TripletResult {
45 pub loss: f64,
47 pub positive_distance: f64,
49 pub negative_distance: f64,
51 pub is_hard: bool,
53 pub violates_margin: bool,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct TripletLoss {
75 margin: f64,
77 distance: TripletDistance,
79 soft_margin: bool,
81 l2_regularization: Option<f64>,
83}
84
85impl TripletLoss {
86 pub fn new(margin: f64) -> Self {
97 assert!(margin >= 0.0, "Margin must be non-negative");
98 Self {
99 margin,
100 distance: TripletDistance::SquaredEuclidean,
101 soft_margin: false,
102 l2_regularization: None,
103 }
104 }
105
106 pub fn with_distance(mut self, distance: TripletDistance) -> Self {
108 self.distance = distance;
109 self
110 }
111
112 pub fn with_soft_margin(mut self) -> Self {
116 self.soft_margin = true;
117 self
118 }
119
120 pub fn with_l2_regularization(mut self, weight: f64) -> Self {
122 self.l2_regularization = Some(weight);
123 self
124 }
125
126 pub fn margin(&self) -> f64 {
128 self.margin
129 }
130
131 pub fn distance_metric(&self) -> TripletDistance {
133 self.distance
134 }
135
136 pub fn forward(&self, anchor: &[f64], positive: &[f64], negative: &[f64]) -> f64 {
148 self.forward_detailed(anchor, positive, negative)
149 .map(|r| r.loss)
150 .unwrap_or(0.0)
151 }
152
153 pub fn forward_detailed(
155 &self,
156 anchor: &[f64],
157 positive: &[f64],
158 negative: &[f64],
159 ) -> CnnResult<TripletResult> {
160 if anchor.is_empty() {
162 return Err(CnnError::InvalidInput("anchor cannot be empty".to_string()));
163 }
164
165 let dim = anchor.len();
166 if positive.len() != dim {
167 return Err(CnnError::DimensionMismatch(format!(
168 "positive has dimension {}, expected {}",
169 positive.len(),
170 dim
171 )));
172 }
173 if negative.len() != dim {
174 return Err(CnnError::DimensionMismatch(format!(
175 "negative has dimension {}, expected {}",
176 negative.len(),
177 dim
178 )));
179 }
180
181 for (name, vec) in [("anchor", anchor), ("positive", positive), ("negative", negative)] {
183 if vec.iter().any(|x| x.is_nan() || x.is_infinite()) {
184 return Err(CnnError::InvalidInput(format!(
185 "{} contains NaN or Inf",
186 name
187 )));
188 }
189 }
190
191 let pos_dist = self.compute_distance(anchor, positive);
193 let neg_dist = self.compute_distance(anchor, negative);
194
195 let diff = pos_dist - neg_dist + self.margin;
197 let loss = if self.soft_margin {
198 soft_relu(diff)
199 } else {
200 diff.max(0.0)
201 };
202
203 let loss = if let Some(weight) = self.l2_regularization {
205 let anchor_norm: f64 = anchor.iter().map(|x| x * x).sum();
206 let pos_norm: f64 = positive.iter().map(|x| x * x).sum();
207 let neg_norm: f64 = negative.iter().map(|x| x * x).sum();
208 loss + weight * (anchor_norm + pos_norm + neg_norm) / 3.0
209 } else {
210 loss
211 };
212
213 Ok(TripletResult {
214 loss,
215 positive_distance: pos_dist,
216 negative_distance: neg_dist,
217 is_hard: diff > 0.0,
218 violates_margin: pos_dist + self.margin > neg_dist,
219 })
220 }
221
222 pub fn forward_batch(
234 &self,
235 anchors: &[Vec<f64>],
236 positives: &[Vec<f64>],
237 negatives: &[Vec<f64>],
238 ) -> CnnResult<f64> {
239 if anchors.len() != positives.len() || anchors.len() != negatives.len() {
240 return Err(CnnError::DimensionMismatch(format!(
241 "Batch sizes must match: anchors={}, positives={}, negatives={}",
242 anchors.len(),
243 positives.len(),
244 negatives.len()
245 )));
246 }
247
248 if anchors.is_empty() {
249 return Err(CnnError::InvalidInput("batch cannot be empty".to_string()));
250 }
251
252 let mut total_loss = 0.0;
253 for ((anchor, positive), negative) in anchors.iter().zip(positives).zip(negatives) {
254 total_loss += self.forward(anchor, positive, negative);
255 }
256
257 Ok(total_loss / anchors.len() as f64)
258 }
259
260 pub fn mine_hard_triplets(
273 &self,
274 embeddings: &[Vec<f64>],
275 labels: &[usize],
276 ) -> Vec<(usize, usize, usize)> {
277 if embeddings.len() != labels.len() {
278 return vec![];
279 }
280
281 let n = embeddings.len();
282 let mut triplets = Vec::new();
283
284 let distances = self.compute_distance_matrix(embeddings);
286
287 for anchor_idx in 0..n {
288 let anchor_label = labels[anchor_idx];
289
290 let mut hardest_pos_idx = None;
292 let mut hardest_pos_dist = f64::NEG_INFINITY;
293
294 let mut hardest_neg_idx = None;
296 let mut hardest_neg_dist = f64::INFINITY;
297
298 for other_idx in 0..n {
299 if other_idx == anchor_idx {
300 continue;
301 }
302
303 let dist = distances[anchor_idx][other_idx];
304
305 if labels[other_idx] == anchor_label {
306 if dist > hardest_pos_dist {
308 hardest_pos_dist = dist;
309 hardest_pos_idx = Some(other_idx);
310 }
311 } else {
312 if dist < hardest_neg_dist {
314 hardest_neg_dist = dist;
315 hardest_neg_idx = Some(other_idx);
316 }
317 }
318 }
319
320 if let (Some(pos_idx), Some(neg_idx)) = (hardest_pos_idx, hardest_neg_idx) {
322 if hardest_pos_dist - hardest_neg_dist + self.margin > 0.0 {
323 triplets.push((anchor_idx, pos_idx, neg_idx));
324 }
325 }
326 }
327
328 triplets
329 }
330
331 #[inline]
333 fn compute_distance(&self, a: &[f64], b: &[f64]) -> f64 {
334 match self.distance {
335 TripletDistance::Euclidean => {
336 let sum_sq: f64 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
337 sum_sq.sqrt()
338 }
339 TripletDistance::SquaredEuclidean => {
340 a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum()
341 }
342 TripletDistance::Cosine => {
343 let mut dot = 0.0;
344 let mut norm_a_sq = 0.0;
345 let mut norm_b_sq = 0.0;
346
347 for (x, y) in a.iter().zip(b) {
348 dot += x * y;
349 norm_a_sq += x * x;
350 norm_b_sq += y * y;
351 }
352
353 let norm = (norm_a_sq * norm_b_sq).sqrt();
354 if norm < 1e-8 {
355 1.0 } else {
357 1.0 - dot / norm
358 }
359 }
360 }
361 }
362
363 fn compute_distance_matrix(&self, embeddings: &[Vec<f64>]) -> Vec<Vec<f64>> {
365 let n = embeddings.len();
366 let mut matrix = vec![vec![0.0; n]; n];
367
368 for i in 0..n {
369 for j in (i + 1)..n {
370 let dist = self.compute_distance(&embeddings[i], &embeddings[j]);
371 matrix[i][j] = dist;
372 matrix[j][i] = dist;
373 }
374 }
375
376 matrix
377 }
378}
379
380impl Default for TripletLoss {
381 fn default() -> Self {
382 Self::new(1.0)
383 }
384}
385
/// Smooth approximation of `max(0, x)` (soft-plus): `ln(1 + e^x)`.
///
/// Outside `[-20, 20]` the function is replaced by its asymptotes (`x` for
/// large inputs, `0` for very negative ones), which also keeps `exp` from
/// overflowing. Inside that range `ln_1p` is used rather than `ln(1 + ..)`:
/// for negative `x` the term `e^x` is tiny, and adding it to `1.0` first
/// would discard most of its significant digits.
#[inline]
fn soft_relu(x: f64) -> f64 {
    if x > 20.0 {
        x
    } else if x < -20.0 {
        0.0
    } else {
        x.exp().ln_1p()
    }
}
397
#[cfg(test)]
mod tests {
    use super::*;

    /// The hinge loss is never negative.
    #[test]
    fn test_triplet_basic() {
        let triplet = TripletLoss::new(1.0);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let negative = vec![0.0, 1.0, 0.0];

        let loss = triplet.forward(&anchor, &positive, &negative);
        assert!(loss >= 0.0);
    }

    /// A positive identical to the anchor and a distant negative give zero loss.
    #[test]
    fn test_triplet_zero_loss() {
        let triplet = TripletLoss::new(0.1);

        let anchor = vec![1.0, 0.0];
        let positive = vec![1.0, 0.0];
        let negative = vec![-1.0, 0.0];

        let result = triplet.forward_detailed(&anchor, &positive, &negative).unwrap();
        assert_eq!(result.loss, 0.0);
        assert!(!result.is_hard);
    }

    /// A negative that is closer than the positive is flagged and penalized.
    #[test]
    fn test_triplet_hard() {
        let triplet = TripletLoss::new(1.0);

        let anchor = vec![0.0, 0.0];
        let positive = vec![2.0, 0.0];
        let negative = vec![1.0, 0.0];

        let result = triplet.forward_detailed(&anchor, &positive, &negative).unwrap();
        assert!(result.loss > 0.0);
        assert!(result.is_hard);
        assert!(result.violates_margin);
    }

    /// Spot-checks the Euclidean and cosine metrics on known values.
    #[test]
    fn test_triplet_distances() {
        let triplet_euclidean = TripletLoss::new(0.0).with_distance(TripletDistance::Euclidean);
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        let c = vec![0.0, 0.0];

        let result = triplet_euclidean.forward_detailed(&a, &b, &c).unwrap();
        assert!((result.positive_distance - 5.0).abs() < 1e-6);
        assert!(result.negative_distance.abs() < 1e-6);

        let triplet_cosine = TripletLoss::new(0.0).with_distance(TripletDistance::Cosine);
        let x = vec![1.0, 0.0];
        let y = vec![0.0, 1.0];
        let z = vec![1.0, 0.0];

        let result = triplet_cosine.forward_detailed(&x, &y, &z).unwrap();
        // Orthogonal vectors are at cosine distance 1, identical ones at 0.
        assert!((result.positive_distance - 1.0).abs() < 1e-6);
        assert!(result.negative_distance.abs() < 1e-6);
    }

    /// The soft margin upper-bounds the hinge and stays strictly positive.
    #[test]
    fn test_soft_margin() {
        let hard = TripletLoss::new(1.0);
        let soft = TripletLoss::new(1.0).with_soft_margin();

        let anchor = vec![0.0, 0.0];
        let positive = vec![1.0, 0.0];
        let negative = vec![0.5, 0.0];

        let hard_loss = hard.forward(&anchor, &positive, &negative);
        let soft_loss = soft.forward(&anchor, &positive, &negative);

        assert!(soft_loss >= hard_loss);
        assert!(hard_loss > 0.0);
        assert!(soft_loss > 0.0);
    }

    /// A well-formed batch produces a non-negative mean loss.
    #[test]
    fn test_batch_triplet() {
        let triplet = TripletLoss::new(1.0);

        let anchors = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let positives = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
        let negatives = vec![vec![0.0, 1.0], vec![1.0, 0.0]];

        let loss = triplet.forward_batch(&anchors, &positives, &negatives).unwrap();
        assert!(loss >= 0.0);
    }

    /// Malformed batches are rejected before any loss is computed.
    #[test]
    fn test_batch_shape_errors() {
        let triplet = TripletLoss::new(1.0);

        let anchors = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let positives = vec![vec![0.9, 0.1]];
        let negatives = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        assert!(triplet.forward_batch(&anchors, &positives, &negatives).is_err());

        let empty: Vec<Vec<f64>> = Vec::new();
        assert!(triplet.forward_batch(&empty, &empty, &empty).is_err());
    }

    /// Mining returns exactly the anchors whose hardest pair violates the margin.
    #[test]
    fn test_mine_hard_triplets() {
        let triplet = TripletLoss::new(0.01);

        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.95, 0.05],
            vec![0.9, 0.1],
            vec![0.85, 0.15],
        ];
        let labels = vec![0, 0, 1, 1];

        let hard_triplets = triplet.mine_hard_triplets(&embeddings, &labels);

        for (a, p, n) in &hard_triplets {
            assert_eq!(labels[*a], labels[*p]);
            assert_ne!(labels[*a], labels[*n]);
        }

        // Only the two samples next to the class boundary (indices 1 and 2)
        // have a negative close enough to violate the 0.01 margin.
        assert_eq!(hard_triplets, vec![(1, 0, 2), (2, 3, 1)]);

        // Mismatched inputs yield no triplets.
        assert!(triplet.mine_hard_triplets(&embeddings, &[0, 1]).is_empty());
    }

    /// The L2 penalty strictly increases the loss for non-zero embeddings.
    #[test]
    fn test_l2_regularization() {
        let no_reg = TripletLoss::new(0.0);
        let with_reg = TripletLoss::new(0.0).with_l2_regularization(0.01);

        let anchor = vec![10.0, 0.0];
        let positive = vec![10.0, 0.0];
        let negative = vec![-10.0, 0.0];

        let loss_no_reg = no_reg.forward(&anchor, &positive, &negative);
        let loss_with_reg = with_reg.forward(&anchor, &positive, &negative);

        assert!(loss_with_reg > loss_no_reg);
    }

    /// Invalid single-triplet input is reported by `forward_detailed` and
    /// mapped to a zero loss by `forward`.
    #[test]
    fn test_error_handling() {
        let triplet = TripletLoss::new(1.0);

        // Empty anchor.
        let result = triplet.forward_detailed(&[], &[1.0], &[1.0]);
        assert!(result.is_err());

        // Dimension mismatch.
        let result = triplet.forward_detailed(&[1.0, 2.0], &[1.0], &[1.0, 2.0]);
        assert!(result.is_err());

        // Non-finite values.
        let result = triplet.forward_detailed(&[f64::NAN, 0.0], &[1.0, 0.0], &[0.0, 1.0]);
        assert!(result.is_err());
        let result = triplet.forward_detailed(&[1.0, 0.0], &[1.0, 0.0], &[f64::INFINITY, 1.0]);
        assert!(result.is_err());

        // The convenience wrapper swallows the error.
        assert_eq!(triplet.forward(&[], &[1.0], &[1.0]), 0.0);
    }

    /// Checks the soft-plus value at zero, both asymptotes, and the hinge bound.
    #[test]
    fn test_soft_relu() {
        assert!((soft_relu(0.0) - 2.0_f64.ln()).abs() < 1e-6);
        assert!(soft_relu(-100.0) < 1e-10);
        assert!((soft_relu(100.0) - 100.0).abs() < 1e-6);

        let x = 1.0;
        let y = soft_relu(x);
        assert!(y > x.max(0.0));
    }
}