Skip to main content

tensorlogic_train/loss/
contrastiveloss_traits.rs

//! # ContrastiveLoss - Trait Implementations
//!
//! This module contains trait implementations for `ContrastiveLoss`.
//!
//! ## Implemented Traits
//!
//! - `Default`
//! - `Loss`

10use crate::{TrainError, TrainResult};
11use scirs2_core::ndarray::{Array, ArrayView, Ix2};
12
13use super::functions::Loss;
14use super::types::ContrastiveLoss;
15
16impl Default for ContrastiveLoss {
17    fn default() -> Self {
18        Self { margin: 1.0 }
19    }
20}
21
22impl Loss for ContrastiveLoss {
23    fn compute(
24        &self,
25        predictions: &ArrayView<f64, Ix2>,
26        targets: &ArrayView<f64, Ix2>,
27    ) -> TrainResult<f64> {
28        if predictions.ncols() != 2 || targets.ncols() != 1 {
29            return Err(
30                TrainError::LossError(
31                    format!(
32                        "ContrastiveLoss expects predictions shape [N, 2] (distances) and targets shape [N, 1] (labels), got {:?} and {:?}",
33                        predictions.shape(), targets.shape()
34                    ),
35                ),
36            );
37        }
38        let mut total_loss = 0.0;
39        let n = predictions.nrows() as f64;
40        for i in 0..predictions.nrows() {
41            let distance = predictions[[i, 0]];
42            let label = targets[[i, 0]];
43            if label > 0.5 {
44                total_loss += distance * distance;
45            } else {
46                total_loss += (self.margin - distance).max(0.0).powi(2);
47            }
48        }
49        Ok(total_loss / n)
50    }
51    fn gradient(
52        &self,
53        predictions: &ArrayView<f64, Ix2>,
54        targets: &ArrayView<f64, Ix2>,
55    ) -> TrainResult<Array<f64, Ix2>> {
56        let mut grad = Array::zeros(predictions.raw_dim());
57        let n = predictions.nrows() as f64;
58        for i in 0..predictions.nrows() {
59            let distance = predictions[[i, 0]];
60            let label = targets[[i, 0]];
61            if label > 0.5 {
62                grad[[i, 0]] = 2.0 * distance / n;
63            } else {
64                if distance < self.margin {
65                    grad[[i, 0]] = -2.0 * (self.margin - distance) / n;
66                }
67            }
68        }
69        Ok(grad)
70    }
71}