Skip to main content

tensorlogic_train/loss/
functions.rs

1//! Loss trait and loss function test suite.
2
3use crate::TrainResult;
4use scirs2_core::ndarray::{Array, ArrayView, Ix2};
5use std::fmt::Debug;
6
7/// Trait for loss functions.
8pub trait Loss: Debug {
9    /// Compute loss value.
10    fn compute(
11        &self,
12        predictions: &ArrayView<f64, Ix2>,
13        targets: &ArrayView<f64, Ix2>,
14    ) -> TrainResult<f64>;
15    /// Compute loss gradient with respect to predictions.
16    fn gradient(
17        &self,
18        predictions: &ArrayView<f64, Ix2>,
19        targets: &ArrayView<f64, Ix2>,
20    ) -> TrainResult<Array<f64, Ix2>>;
21    /// Get the name of the loss function.
22    fn name(&self) -> &str {
23        "unknown"
24    }
25}
#[cfg(test)]
mod tests {
    use super::super::types::{
        BCEWithLogitsLoss, ConstraintViolationLoss, ContrastiveLoss, CrossEntropyLoss, DiceLoss,
        FocalLoss, HingeLoss, HuberLoss, KLDivergenceLoss, MseLoss, PolyLoss, RuleSatisfactionLoss,
        TripletLoss, TverskyLoss,
    };
    use super::*;
    use scirs2_core::ndarray::array;

    /// Runs `loss.compute` on the given arrays and returns the scalar loss.
    ///
    /// Panics if the loss reports an error. `#[track_caller]` makes the panic
    /// location point at the calling test rather than at this helper.
    #[track_caller]
    fn loss_value(
        loss: &impl Loss,
        predictions: &Array<f64, Ix2>,
        targets: &Array<f64, Ix2>,
    ) -> f64 {
        loss.compute(&predictions.view(), &targets.view())
            .expect("loss computation should succeed on well-formed test inputs")
    }

    /// Runs `loss.gradient` on the given arrays, checks that the gradient has
    /// the same shape as `predictions`, and returns it for further assertions.
    ///
    /// Panics if the loss reports an error or the shapes differ.
    #[track_caller]
    fn checked_gradient(
        loss: &impl Loss,
        predictions: &Array<f64, Ix2>,
        targets: &Array<f64, Ix2>,
    ) -> Array<f64, Ix2> {
        let grad = loss
            .gradient(&predictions.view(), &targets.view())
            .expect("gradient computation should succeed on well-formed test inputs");
        assert_eq!(
            grad.shape(),
            predictions.shape(),
            "gradient must have the same shape as the predictions"
        );
        grad
    }

    #[test]
    fn test_cross_entropy_loss() {
        let loss = CrossEntropyLoss::default();
        let predictions = array![[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];

        assert!(loss_value(&loss, &predictions, &targets) > 0.0);
        checked_gradient(&loss, &predictions, &targets);
    }

    #[test]
    fn test_mse_loss() {
        let loss = MseLoss;
        let predictions = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![[1.5, 2.5], [3.5, 4.5]];

        // Every prediction is off by exactly 0.5, i.e. a squared error of 0.25.
        let loss_val = loss_value(&loss, &predictions, &targets);
        assert!((loss_val - 0.25).abs() < 1e-6);
        checked_gradient(&loss, &predictions, &targets);
    }

    #[test]
    fn test_rule_satisfaction_loss() {
        let loss = RuleSatisfactionLoss::default();
        let rule_values = array![[0.9, 0.8], [0.95, 0.85]];
        let targets = array![[1.0, 1.0], [1.0, 1.0]];

        // No rule value reaches its target of 1.0, so the loss must be positive.
        assert!(loss_value(&loss, &rule_values, &targets) > 0.0);
        checked_gradient(&loss, &rule_values, &targets);
    }

    #[test]
    fn test_constraint_violation_loss() {
        let loss = ConstraintViolationLoss::default();
        let constraint_values = array![[0.1, -0.1], [0.2, -0.2]];
        let targets = array![[0.0, 0.0], [0.0, 0.0]];

        assert!(loss_value(&loss, &constraint_values, &targets) > 0.0);
        checked_gradient(&loss, &constraint_values, &targets);
    }

    #[test]
    fn test_focal_loss() {
        let loss = FocalLoss::default();
        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        assert!(loss_value(&loss, &predictions, &targets) >= 0.0);
        checked_gradient(&loss, &predictions, &targets);
    }

    #[test]
    fn test_huber_loss() {
        let loss = HuberLoss::default();
        let predictions = array![[1.0, 3.0], [2.0, 5.0]];
        let targets = array![[1.5, 2.0], [2.5, 4.0]];

        assert!(loss_value(&loss, &predictions, &targets) > 0.0);
        checked_gradient(&loss, &predictions, &targets);
    }

    #[test]
    fn test_bce_with_logits_loss() {
        let loss = BCEWithLogitsLoss;
        let logits = array![[0.5, -0.5], [1.0, -1.0]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];

        assert!(loss_value(&loss, &logits, &targets) >= 0.0);
        checked_gradient(&loss, &logits, &targets);
    }

    #[test]
    fn test_dice_loss() {
        let loss = DiceLoss::default();
        let predictions = array![[0.9, 0.1], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];

        // The test requires the value to stay inside the closed interval [0, 1].
        let loss_val = loss_value(&loss, &predictions, &targets);
        assert!(loss_val >= 0.0);
        assert!(loss_val <= 1.0);
        checked_gradient(&loss, &predictions, &targets);
    }

    #[test]
    fn test_tversky_loss() {
        let loss = TverskyLoss::default();
        let predictions = array![[0.9, 0.1], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];

        // The test requires the value to stay inside the closed interval [0, 1].
        let loss_val = loss_value(&loss, &predictions, &targets);
        assert!(loss_val >= 0.0);
        assert!(loss_val <= 1.0);
        checked_gradient(&loss, &predictions, &targets);
    }

    #[test]
    fn test_contrastive_loss() {
        let loss = ContrastiveLoss::default();
        // NOTE(review): presumably column 0 holds a pair distance and the
        // single target column a similar/dissimilar label - confirm against
        // `ContrastiveLoss`.
        let predictions = array![[0.5, 0.0], [1.5, 0.0], [0.2, 0.0]];
        let targets = array![[1.0], [0.0], [1.0]];

        assert!(loss_value(&loss, &predictions, &targets) >= 0.0);
        let grad = checked_gradient(&loss, &predictions, &targets);
        assert!(grad[[0, 0]] > 0.0);
        assert_eq!(grad[[1, 0]], 0.0);
    }

    #[test]
    fn test_triplet_loss() {
        let loss = TripletLoss::default();
        // NOTE(review): presumably each row is (positive distance, negative
        // distance); row 0 would then already satisfy the margin while row 1
        // violates it - confirm against `TripletLoss`.
        let predictions = array![[0.5, 2.0], [1.0, 0.5], [0.3, 1.5]];
        let targets = array![[0.0], [0.0], [0.0]];

        assert!(loss_value(&loss, &predictions, &targets) >= 0.0);
        let grad = checked_gradient(&loss, &predictions, &targets);
        assert_eq!(grad[[0, 0]], 0.0);
        assert_eq!(grad[[0, 1]], 0.0);
        assert!(grad[[1, 0]] > 0.0);
        assert!(grad[[1, 1]] < 0.0);
    }

    #[test]
    fn test_hinge_loss() {
        let loss = HingeLoss::default();
        let predictions = array![[0.5, -0.5], [2.0, -2.0]];
        let targets = array![[1.0, -1.0], [1.0, -1.0]];

        assert!(loss_value(&loss, &predictions, &targets) >= 0.0);
        let grad = checked_gradient(&loss, &predictions, &targets);
        // The second row is expected to contribute no gradient at all.
        // NOTE(review): presumably because it lies beyond the default margin -
        // confirm against `HingeLoss`.
        assert_eq!(grad[[1, 0]], 0.0);
        assert_eq!(grad[[1, 1]], 0.0);
    }

    #[test]
    fn test_kl_divergence_loss() {
        let loss = KLDivergenceLoss::default();
        let predictions = array![[0.6, 0.4], [0.7, 0.3]];
        let targets = array![[0.5, 0.5], [0.8, 0.2]];

        assert!(loss_value(&loss, &predictions, &targets) >= 0.0);
        checked_gradient(&loss, &predictions, &targets);

        // Identical inputs must yield a loss of (numerically) zero.
        let identical_preds = array![[0.5, 0.5]];
        let identical_targets = array![[0.5, 0.5]];
        let identical_loss = loss_value(&loss, &identical_preds, &identical_targets);
        assert!(identical_loss.abs() < 1e-6);
    }

    #[test]
    fn test_poly_loss() {
        let loss = PolyLoss::default();
        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        let loss_val = loss_value(&loss, &predictions, &targets);
        assert!(loss_val > 0.0);
        checked_gradient(&loss, &predictions, &targets);

        // With default settings the poly loss must not fall below plain
        // cross-entropy on the same inputs.
        let ce_loss = CrossEntropyLoss::default();
        let ce_val = loss_value(&ce_loss, &predictions, &targets);
        assert!(loss_val >= ce_val);
    }

    #[test]
    fn test_poly_loss_custom_coefficient() {
        let predictions = array![[0.8, 0.2]];
        let targets = array![[1.0, 0.0]];

        let loss = PolyLoss::new(2.0);
        let loss_val = loss_value(&loss, &predictions, &targets);
        assert!(loss_val > 0.0);

        // Constructing with 2.0 must give a strictly larger loss than
        // constructing with 0.5 on the same inputs.
        let loss_low_coeff = PolyLoss::new(0.5);
        let loss_val_low = loss_value(&loss_low_coeff, &predictions, &targets);
        assert!(loss_val > loss_val_low);
    }
}