Skip to main content

tensorlogic_train/loss/
tripletloss_traits.rs

//! # TripletLoss - Trait Implementations
//!
//! This module contains trait implementations for `TripletLoss`.
//!
//! ## Implemented Traits
//!
//! - `Default`
//! - `Loss`
9
10use crate::{TrainError, TrainResult};
11use scirs2_core::ndarray::{Array, ArrayView, Ix2};
12
13use super::functions::Loss;
14use super::types::TripletLoss;
15
16impl Default for TripletLoss {
17    fn default() -> Self {
18        Self { margin: 1.0 }
19    }
20}
21
22impl Loss for TripletLoss {
23    fn compute(
24        &self,
25        predictions: &ArrayView<f64, Ix2>,
26        _targets: &ArrayView<f64, Ix2>,
27    ) -> TrainResult<f64> {
28        if predictions.ncols() != 2 {
29            return Err(TrainError::LossError(format!(
30                "TripletLoss expects predictions shape [N, 2] (pos_dist, neg_dist), got {:?}",
31                predictions.shape()
32            )));
33        }
34        let mut total_loss = 0.0;
35        let n = predictions.nrows() as f64;
36        for i in 0..predictions.nrows() {
37            let pos_distance = predictions[[i, 0]];
38            let neg_distance = predictions[[i, 1]];
39            let loss = (pos_distance - neg_distance + self.margin).max(0.0);
40            total_loss += loss;
41        }
42        Ok(total_loss / n)
43    }
44    fn gradient(
45        &self,
46        predictions: &ArrayView<f64, Ix2>,
47        _targets: &ArrayView<f64, Ix2>,
48    ) -> TrainResult<Array<f64, Ix2>> {
49        let mut grad = Array::zeros(predictions.raw_dim());
50        let n = predictions.nrows() as f64;
51        for i in 0..predictions.nrows() {
52            let pos_distance = predictions[[i, 0]];
53            let neg_distance = predictions[[i, 1]];
54            if pos_distance - neg_distance + self.margin > 0.0 {
55                grad[[i, 0]] = 1.0 / n;
56                grad[[i, 1]] = -1.0 / n;
57            }
58        }
59        Ok(grad)
60    }
61}