1use crate::TrainResult;
4use scirs2_core::ndarray::{Array, ArrayView, Ix2};
5use std::fmt::Debug;
6
7pub trait Loss: Debug {
9 fn compute(
11 &self,
12 predictions: &ArrayView<f64, Ix2>,
13 targets: &ArrayView<f64, Ix2>,
14 ) -> TrainResult<f64>;
15 fn gradient(
17 &self,
18 predictions: &ArrayView<f64, Ix2>,
19 targets: &ArrayView<f64, Ix2>,
20 ) -> TrainResult<Array<f64, Ix2>>;
21 fn name(&self) -> &str {
23 "unknown"
24 }
25}
#[cfg(test)]
mod tests {
    use super::super::types::{
        BCEWithLogitsLoss, ConstraintViolationLoss, ContrastiveLoss, CrossEntropyLoss, DiceLoss,
        FocalLoss, HingeLoss, HuberLoss, KLDivergenceLoss, MseLoss, PolyLoss, RuleSatisfactionLoss,
        TripletLoss, TverskyLoss,
    };
    use super::*;
    use scirs2_core::ndarray::array;

    /// Evaluates `loss` on one batch and returns its scalar value.
    ///
    /// # Panics
    ///
    /// Panics, naming the loss and printing the error, if `compute` fails.
    fn loss_value<L: Loss>(
        loss: &L,
        predictions: &Array<f64, Ix2>,
        targets: &Array<f64, Ix2>,
    ) -> f64 {
        loss.compute(&predictions.view(), &targets.view())
            .unwrap_or_else(|err| {
                panic!("{}: compute() returned an error: {:?}", loss.name(), err)
            })
    }

    /// Evaluates both the value and the gradient of `loss` on one batch.
    ///
    /// Also checks the invariant shared by every test below: the gradient must
    /// have the same shape as `predictions`.
    ///
    /// # Panics
    ///
    /// Panics, naming the loss, if `compute` or `gradient` fails or if the
    /// gradient shape differs from the shape of `predictions`.
    fn loss_and_gradient<L: Loss>(
        loss: &L,
        predictions: &Array<f64, Ix2>,
        targets: &Array<f64, Ix2>,
    ) -> (f64, Array<f64, Ix2>) {
        let value = loss_value(loss, predictions, targets);
        let grad = loss
            .gradient(&predictions.view(), &targets.view())
            .unwrap_or_else(|err| {
                panic!("{}: gradient() returned an error: {:?}", loss.name(), err)
            });
        assert_eq!(
            grad.shape(),
            predictions.shape(),
            "{}: gradient shape must match the predictions shape",
            loss.name()
        );
        (value, grad)
    }

    #[test]
    fn test_cross_entropy_loss() {
        let predictions = array![[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let (loss_val, _) =
            loss_and_gradient(&CrossEntropyLoss::default(), &predictions, &targets);
        assert!(loss_val > 0.0, "expected a positive loss, got {}", loss_val);
    }

    #[test]
    fn test_mse_loss() {
        // Every prediction is off by exactly 0.5; the expected loss is 0.25.
        let predictions = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![[1.5, 2.5], [3.5, 4.5]];
        let (loss_val, _) = loss_and_gradient(&MseLoss, &predictions, &targets);
        assert!(
            (loss_val - 0.25).abs() < 1e-6,
            "expected a loss of 0.25, got {}",
            loss_val
        );
    }

    #[test]
    fn test_rule_satisfaction_loss() {
        let rule_values = array![[0.9, 0.8], [0.95, 0.85]];
        let targets = array![[1.0, 1.0], [1.0, 1.0]];
        let (loss_val, _) =
            loss_and_gradient(&RuleSatisfactionLoss::default(), &rule_values, &targets);
        assert!(loss_val > 0.0, "expected a positive loss, got {}", loss_val);
    }

    #[test]
    fn test_constraint_violation_loss() {
        let constraint_values = array![[0.1, -0.1], [0.2, -0.2]];
        let targets = array![[0.0, 0.0], [0.0, 0.0]];
        let (loss_val, _) = loss_and_gradient(
            &ConstraintViolationLoss::default(),
            &constraint_values,
            &targets,
        );
        assert!(loss_val > 0.0, "expected a positive loss, got {}", loss_val);
    }

    #[test]
    fn test_focal_loss() {
        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];
        let (loss_val, _) = loss_and_gradient(&FocalLoss::default(), &predictions, &targets);
        assert!(loss_val >= 0.0, "expected a non-negative loss, got {}", loss_val);
    }

    #[test]
    fn test_huber_loss() {
        let predictions = array![[1.0, 3.0], [2.0, 5.0]];
        let targets = array![[1.5, 2.0], [2.5, 4.0]];
        let (loss_val, _) = loss_and_gradient(&HuberLoss::default(), &predictions, &targets);
        assert!(loss_val > 0.0, "expected a positive loss, got {}", loss_val);
    }

    #[test]
    fn test_bce_with_logits_loss() {
        let logits = array![[0.5, -0.5], [1.0, -1.0]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];
        let (loss_val, _) = loss_and_gradient(&BCEWithLogitsLoss, &logits, &targets);
        assert!(loss_val >= 0.0, "expected a non-negative loss, got {}", loss_val);
    }

    #[test]
    fn test_dice_loss() {
        let predictions = array![[0.9, 0.1], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];
        let (loss_val, _) = loss_and_gradient(&DiceLoss::default(), &predictions, &targets);
        assert!(loss_val >= 0.0, "expected a loss of at least 0, got {}", loss_val);
        assert!(loss_val <= 1.0, "expected a loss of at most 1, got {}", loss_val);
    }

    #[test]
    fn test_tversky_loss() {
        let predictions = array![[0.9, 0.1], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];
        let (loss_val, _) = loss_and_gradient(&TverskyLoss::default(), &predictions, &targets);
        assert!(loss_val >= 0.0, "expected a loss of at least 0, got {}", loss_val);
        assert!(loss_val <= 1.0, "expected a loss of at most 1, got {}", loss_val);
    }

    #[test]
    fn test_contrastive_loss() {
        // Targets have a single column here, unlike the two-column predictions.
        let predictions = array![[0.5, 0.0], [1.5, 0.0], [0.2, 0.0]];
        let targets = array![[1.0], [0.0], [1.0]];
        let (loss_val, grad) =
            loss_and_gradient(&ContrastiveLoss::default(), &predictions, &targets);
        assert!(loss_val >= 0.0, "expected a non-negative loss, got {}", loss_val);
        assert!(
            grad[[0, 0]] > 0.0,
            "row 0 (target 1.0) should have a positive gradient in column 0"
        );
        assert_eq!(
            grad[[1, 0]],
            0.0,
            "row 1 (target 0.0, value 1.5) should have a zero gradient in column 0"
        );
    }

    #[test]
    fn test_triplet_loss() {
        let predictions = array![[0.5, 2.0], [1.0, 0.5], [0.3, 1.5]];
        let targets = array![[0.0], [0.0], [0.0]];
        let (loss_val, grad) = loss_and_gradient(&TripletLoss::default(), &predictions, &targets);
        assert!(loss_val >= 0.0, "expected a non-negative loss, got {}", loss_val);
        // Row 0 is expected to contribute no gradient at all.
        assert_eq!(grad[[0, 0]], 0.0, "row 0, column 0 should have a zero gradient");
        assert_eq!(grad[[0, 1]], 0.0, "row 0, column 1 should have a zero gradient");
        // Row 1 is expected to be pushed in opposite directions per column.
        assert!(grad[[1, 0]] > 0.0, "row 1, column 0 should have a positive gradient");
        assert!(grad[[1, 1]] < 0.0, "row 1, column 1 should have a negative gradient");
    }

    #[test]
    fn test_hinge_loss() {
        let predictions = array![[0.5, -0.5], [2.0, -2.0]];
        let targets = array![[1.0, -1.0], [1.0, -1.0]];
        let (loss_val, grad) = loss_and_gradient(&HingeLoss::default(), &predictions, &targets);
        assert!(loss_val >= 0.0, "expected a non-negative loss, got {}", loss_val);
        // The second row (|prediction| = 2.0, same sign as target) must not
        // receive any gradient.
        assert_eq!(grad[[1, 0]], 0.0, "row 1, column 0 should have a zero gradient");
        assert_eq!(grad[[1, 1]], 0.0, "row 1, column 1 should have a zero gradient");
    }

    #[test]
    fn test_kl_divergence_loss() {
        let loss = KLDivergenceLoss::default();
        let predictions = array![[0.6, 0.4], [0.7, 0.3]];
        let targets = array![[0.5, 0.5], [0.8, 0.2]];
        let (loss_val, _) = loss_and_gradient(&loss, &predictions, &targets);
        assert!(loss_val >= 0.0, "expected a non-negative loss, got {}", loss_val);

        // Identical inputs are expected to give a (numerically) zero loss.
        let identical_preds = array![[0.5, 0.5]];
        let identical_targets = array![[0.5, 0.5]];
        let identical_loss = loss_value(&loss, &identical_preds, &identical_targets);
        assert!(
            identical_loss.abs() < 1e-6,
            "expected ~0 loss for identical inputs, got {}",
            identical_loss
        );
    }

    #[test]
    fn test_poly_loss() {
        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];
        let (loss_val, _) = loss_and_gradient(&PolyLoss::default(), &predictions, &targets);
        assert!(loss_val > 0.0, "expected a positive loss, got {}", loss_val);

        // The default poly loss must not fall below plain cross-entropy on the
        // same batch.
        let ce_val = loss_value(&CrossEntropyLoss::default(), &predictions, &targets);
        assert!(
            loss_val >= ce_val,
            "poly loss {} should be at least the cross-entropy loss {}",
            loss_val,
            ce_val
        );
    }

    #[test]
    fn test_poly_loss_custom_coefficient() {
        let predictions = array![[0.8, 0.2]];
        let targets = array![[1.0, 0.0]];
        let loss_val = loss_value(&PolyLoss::new(2.0), &predictions, &targets);
        assert!(loss_val > 0.0, "expected a positive loss, got {}", loss_val);

        // A larger coefficient must give a strictly larger loss on this batch.
        let loss_val_low = loss_value(&PolyLoss::new(0.5), &predictions, &targets);
        assert!(
            loss_val > loss_val_low,
            "loss with coefficient 2.0 ({}) should exceed loss with coefficient 0.5 ({})",
            loss_val,
            loss_val_low
        );
    }
}