/// Loss functions supported by the training loop.
///
/// Each variant pairs a forward computation ([`Loss::forward`]) with its
/// analytic gradient ([`Loss::backward`]), both dispatching to free
/// functions in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Loss {
    /// Mean squared error: `0.5 * mean((pred - target)^2)`.
    Mse,
    /// Mean absolute error: `mean(|pred - target|)`.
    Mae,
    /// Binary cross-entropy computed directly from logits
    /// (numerically stable; no separate sigmoid in the forward pass).
    BinaryCrossEntropyWithLogits,
    /// Softmax + cross-entropy over one set of class logits.
    SoftmaxCrossEntropy,
}
28
impl Loss {
    /// Validates the loss configuration.
    ///
    /// Currently every variant is valid, so this always returns `Ok(())`;
    /// kept so callers have a stable validation hook.
    pub fn validate(self) -> crate::Result<()> {
        Ok(())
    }

    /// Computes the scalar loss of `pred` against `target` by dispatching
    /// to the matching free function.
    ///
    /// # Panics
    /// Panics if `pred` and `target` differ in length (checked by the
    /// underlying loss functions; `SoftmaxCrossEntropy` also panics on
    /// empty input).
    #[inline]
    pub fn forward(self, pred: &[f32], target: &[f32]) -> f32 {
        match self {
            Loss::Mse => mse(pred, target),
            Loss::Mae => mae(pred, target),
            Loss::BinaryCrossEntropyWithLogits => bce_with_logits(pred, target),
            Loss::SoftmaxCrossEntropy => softmax_cross_entropy(pred, target),
        }
    }

    /// Computes the loss and writes `d(loss)/d(pred)` into `d_pred`,
    /// returning the scalar loss (single fused pass in each backend).
    ///
    /// # Panics
    /// Panics if the three slices differ in length (checked by the
    /// underlying loss functions; `SoftmaxCrossEntropy` also panics on
    /// empty input).
    #[inline]
    pub fn backward(self, pred: &[f32], target: &[f32], d_pred: &mut [f32]) -> f32 {
        match self {
            Loss::Mse => mse_backward(pred, target, d_pred),
            Loss::Mae => mae_backward(pred, target, d_pred),
            Loss::BinaryCrossEntropyWithLogits => bce_with_logits_backward(pred, target, d_pred),
            Loss::SoftmaxCrossEntropy => softmax_cross_entropy_backward(pred, target, d_pred),
        }
    }
}
66
/// Mean squared error with a 1/2 factor: `0.5 * mean((pred - target)^2)`.
///
/// The 1/2 factor makes the gradient exactly `(pred - target) / n`.
/// Returns `0.0` for empty slices.
///
/// # Panics
/// Panics if `pred` and `target` have different lengths.
#[inline]
pub fn mse(pred: &[f32], target: &[f32]) -> f32 {
    assert_eq!(
        pred.len(),
        target.len(),
        "pred len {} does not match target len {}",
        pred.len(),
        target.len()
    );

    if pred.is_empty() {
        return 0.0;
    }

    let inv_n = 1.0 / pred.len() as f32;
    // Fused multiply-add keeps one rounding step per term.
    let sum_sq = pred
        .iter()
        .zip(target.iter())
        .fold(0.0_f32, |acc, (&p, &t)| {
            let diff = p - t;
            diff.mul_add(diff, acc)
        });
    0.5 * sum_sq * inv_n
}
92
/// Computes MSE (see [`mse`]) and its gradient `(pred - target) / n` into
/// `d_pred`, in a single fused pass. Returns the scalar loss; `0.0` on
/// empty input.
///
/// # Panics
/// Panics if the three slices differ in length.
#[inline]
pub fn mse_backward(pred: &[f32], target: &[f32], d_pred: &mut [f32]) -> f32 {
    assert_eq!(
        pred.len(),
        target.len(),
        "pred len {} does not match target len {}",
        pred.len(),
        target.len()
    );
    assert_eq!(
        pred.len(),
        d_pred.len(),
        "pred len {} does not match d_pred len {}",
        pred.len(),
        d_pred.len()
    );

    if pred.is_empty() {
        return 0.0;
    }

    let inv_n = 1.0 / pred.len() as f32;
    let mut sum_sq = 0.0_f32;

    for ((&p, &t), grad) in pred.iter().zip(target.iter()).zip(d_pred.iter_mut()) {
        let diff = p - t;
        sum_sq = diff.mul_add(diff, sum_sq);
        *grad = diff * inv_n;
    }

    0.5 * sum_sq * inv_n
}
131
/// Mean absolute error: `mean(|pred - target|)`.
///
/// Returns `0.0` for empty slices.
///
/// # Panics
/// Panics if `pred` and `target` have different lengths.
#[inline]
pub fn mae(pred: &[f32], target: &[f32]) -> f32 {
    assert_eq!(
        pred.len(),
        target.len(),
        "pred len {} does not match target len {}",
        pred.len(),
        target.len()
    );

    if pred.is_empty() {
        return 0.0;
    }

    let inv_n = 1.0 / pred.len() as f32;
    let sum = pred
        .iter()
        .zip(target.iter())
        .fold(0.0_f32, |acc, (&p, &t)| acc + (p - t).abs());
    sum * inv_n
}
156
/// Computes MAE (see [`mae`]) and its subgradient into `d_pred` in one
/// pass: `sign(pred - target) / n`, with the subgradient at zero chosen
/// as `0` (this also maps NaN differences to `0`). Returns the scalar
/// loss; `0.0` on empty input.
///
/// # Panics
/// Panics if the three slices differ in length.
#[inline]
pub fn mae_backward(pred: &[f32], target: &[f32], d_pred: &mut [f32]) -> f32 {
    assert_eq!(
        pred.len(),
        target.len(),
        "pred len {} does not match target len {}",
        pred.len(),
        target.len()
    );
    assert_eq!(
        pred.len(),
        d_pred.len(),
        "pred len {} does not match d_pred len {}",
        pred.len(),
        d_pred.len()
    );

    if pred.is_empty() {
        return 0.0;
    }

    let inv_n = 1.0 / pred.len() as f32;
    let mut sum = 0.0_f32;
    for ((&p, &t), grad) in pred.iter().zip(target.iter()).zip(d_pred.iter_mut()) {
        let diff = p - t;
        sum += diff.abs();
        // `partial_cmp` is `None` for NaN, which falls through to 0.0 —
        // the same result as the original two-sided comparison.
        *grad = match diff.partial_cmp(&0.0) {
            Some(std::cmp::Ordering::Greater) => inv_n,
            Some(std::cmp::Ordering::Less) => -inv_n,
            _ => 0.0,
        };
    }
    sum * inv_n
}
196
/// Binary cross-entropy computed directly from logits, averaged over
/// elements:
///
/// `mean( max(x, 0) - x*t + ln(1 + exp(-|x|)) )`
///
/// This is the standard numerically stable formulation: the exponential
/// argument is never positive, so it cannot overflow even for extreme
/// logits. `ln_1p` is used instead of `(1.0 + e).ln()` so the log term
/// keeps full precision when `exp(-|x|)` is tiny.
///
/// Returns `0.0` for empty slices.
///
/// # Panics
/// Panics if `logits` and `target` have different lengths.
#[inline]
pub fn bce_with_logits(logits: &[f32], target: &[f32]) -> f32 {
    assert_eq!(
        logits.len(),
        target.len(),
        "pred len {} does not match target len {}",
        logits.len(),
        target.len()
    );

    if logits.is_empty() {
        return 0.0;
    }

    let inv_n = 1.0 / logits.len() as f32;
    let mut sum = 0.0_f32;
    for i in 0..logits.len() {
        let x = logits[i];
        let t = target[i];
        let abs_x = x.abs();
        // ln_1p(e) == ln(1 + e) but accurate for e near 0.
        let loss = x.max(0.0) - x * t + (-abs_x).exp().ln_1p();
        sum += loss;
    }
    sum * inv_n
}
229
/// Computes BCE-with-logits (see [`bce_with_logits`]) and its gradient
/// `(sigmoid(x) - t) / n` into `d_logits`, in a single fused pass.
/// Returns the scalar loss; `0.0` on empty input.
///
/// The one exponential `exp(-|x|)` computed for the loss term is reused
/// to evaluate the sigmoid (instead of recomputing it), using the
/// branch-stable identities `sigmoid(x) = 1/(1+e^-x)` for `x >= 0` and
/// `e^x/(1+e^x)` for `x < 0`. `ln_1p` keeps the log term accurate when
/// `exp(-|x|)` is tiny.
///
/// # Panics
/// Panics if the three slices differ in length.
#[inline]
pub fn bce_with_logits_backward(logits: &[f32], target: &[f32], d_logits: &mut [f32]) -> f32 {
    assert_eq!(
        logits.len(),
        target.len(),
        "pred len {} does not match target len {}",
        logits.len(),
        target.len()
    );
    assert_eq!(
        logits.len(),
        d_logits.len(),
        "pred len {} does not match d_pred len {}",
        logits.len(),
        d_logits.len()
    );

    if logits.is_empty() {
        return 0.0;
    }

    let inv_n = 1.0 / logits.len() as f32;
    let mut sum = 0.0_f32;

    for i in 0..logits.len() {
        let x = logits[i];
        let t = target[i];
        let abs_x = x.abs();
        // e = exp(-|x|) is in (0, 1], so nothing here can overflow.
        let e = (-abs_x).exp();
        sum += x.max(0.0) - x * t + e.ln_1p();

        // Stable sigmoid, reusing `e` (one exp per element total).
        let s = if x >= 0.0 { 1.0 / (1.0 + e) } else { e / (1.0 + e) };
        d_logits[i] = (s - t) * inv_n;
    }

    sum * inv_n
}
270
271#[inline]
275pub fn softmax_cross_entropy(logits: &[f32], target: &[f32]) -> f32 {
276 assert_eq!(
277 logits.len(),
278 target.len(),
279 "pred len {} does not match target len {}",
280 logits.len(),
281 target.len()
282 );
283 assert!(
284 !logits.is_empty(),
285 "softmax_cross_entropy requires at least 1 class"
286 );
287
288 let (log_sum_exp, _max) = log_sum_exp_and_max(logits);
289
290 let mut sum = 0.0_f32;
293 for i in 0..logits.len() {
294 let t = target[i];
295 if t != 0.0 {
296 sum -= t * (logits[i] - log_sum_exp);
297 }
298 }
299
300 sum / logits.len() as f32
302}
303
/// Computes softmax cross-entropy (see [`softmax_cross_entropy`]) and
/// its gradient `(softmax(logits) - target) / K` into `d_logits`.
/// Returns the scalar loss.
///
/// `d_logits` is used as scratch space for the shifted exponentials, and
/// their sum is accumulated in the same pass — unlike calling a separate
/// log-sum-exp helper and then re-summing, which computed the sum of
/// exponentials twice.
///
/// # Panics
/// Panics if the three slices differ in length or if `logits` is empty.
#[inline]
pub fn softmax_cross_entropy_backward(logits: &[f32], target: &[f32], d_logits: &mut [f32]) -> f32 {
    assert_eq!(
        logits.len(),
        target.len(),
        "pred len {} does not match target len {}",
        logits.len(),
        target.len()
    );
    assert_eq!(
        logits.len(),
        d_logits.len(),
        "pred len {} does not match d_pred len {}",
        logits.len(),
        d_logits.len()
    );
    assert!(
        !logits.is_empty(),
        "softmax_cross_entropy_backward requires at least 1 class"
    );

    let k = logits.len();
    let inv_k = 1.0 / k as f32;

    // Max shift keeps every exponential argument <= 0 (no overflow).
    let mut max_logit = logits[0];
    for &x in &logits[1..] {
        if x > max_logit {
            max_logit = x;
        }
    }

    // Single pass: stash exp(x - max) in d_logits and accumulate the sum.
    let mut sum_exp = 0.0_f32;
    for (slot, &x) in d_logits.iter_mut().zip(logits.iter()) {
        let e = (x - max_logit).exp();
        *slot = e;
        sum_exp += e;
    }
    let log_sum_exp = max_logit + sum_exp.ln();
    let inv_sum = 1.0 / sum_exp;

    // Loss: -(1/K) * Σ t_i * (x_i - logsumexp); zero targets skipped.
    let mut loss = 0.0_f32;
    for i in 0..k {
        let t = target[i];
        if t != 0.0 {
            loss -= t * (logits[i] - log_sum_exp);
        }
    }
    loss *= inv_k;

    // Gradient: normalize the stashed exponentials to the softmax and
    // subtract the target, scaled by 1/K.
    for (slot, &t) in d_logits.iter_mut().zip(target.iter()) {
        *slot = (*slot * inv_sum - t) * inv_k;
    }

    loss
}
366
/// Numerically stable logistic sigmoid: the exponential argument is
/// always non-positive, so neither branch can overflow.
#[inline]
fn sigmoid(x: f32) -> f32 {
    if x < 0.0 {
        // sigmoid(x) = e^x / (1 + e^x), with e^x in (0, 1).
        let e = x.exp();
        e / (1.0 + e)
    } else {
        // sigmoid(x) = 1 / (1 + e^-x), with e^-x in (0, 1].
        1.0 / (1.0 + (-x).exp())
    }
}
378
/// Returns `(logsumexp(xs), max(xs))` using the max-shift trick so no
/// exponential can overflow.
///
/// Callers guarantee `xs` is non-empty; indexing `xs[0]` panics otherwise.
#[inline]
fn log_sum_exp_and_max(xs: &[f32]) -> (f32, f32) {
    let max_x = xs
        .iter()
        .skip(1)
        .fold(xs[0], |m, &x| if x > m { x } else { m });
    let sum_exp = xs.iter().fold(0.0_f32, |acc, &x| acc + (x - max_x).exp());
    (max_x + sum_exp.ln(), max_x)
}
393
#[cfg(test)]
mod tests {
    use super::*;

    // Identical pred/target must give exactly zero loss (no rounding).
    #[test]
    fn mse_is_zero_when_equal() {
        let pred = [1.0_f32, -2.0, 0.5];
        let target = pred;
        assert_eq!(mse(&pred, &target), 0.0);
    }

    // diffs = [-1, 2]: loss = 0.5 * (1 + 4) / 2 = 1.25,
    // gradient = diff / n = [-0.5, 1.0].
    #[test]
    fn mse_backward_matches_expected_gradient() {
        let pred = [1.0_f32, 3.0];
        let target = [2.0_f32, 1.0];
        let mut d_pred = [0.0_f32; 2];
        let loss = mse_backward(&pred, &target, &mut d_pred);

        assert!((loss - 1.25).abs() < 1e-6);
        assert!((d_pred[0] - (-0.5)).abs() < 1e-6);
        assert!((d_pred[1] - (1.0)).abs() < 1e-6);
    }

    // Confident, correct predictions at extreme logits must neither
    // overflow (finite) nor incur meaningful loss.
    #[test]
    fn bce_with_logits_is_reasonable_for_extreme_logits() {
        let logits = [100.0_f32, -100.0];
        let target = [1.0_f32, 0.0];
        let loss = bce_with_logits(&logits, &target);
        assert!(loss.is_finite());
        assert!(loss < 1e-3);
    }

    // At x = 0: loss = ln(2) and gradient = sigmoid(0) - t = 0.5 - 1.
    #[test]
    fn bce_with_logits_backward_matches_sigmoid_minus_target() {
        let logits = [0.0_f32];
        let target = [1.0_f32];
        let mut d = [0.0_f32];
        let loss = bce_with_logits_backward(&logits, &target, &mut d);
        assert!((loss - std::f32::consts::LN_2).abs() < 1e-5);
        assert!((d[0] - (-0.5)).abs() < 1e-6);
    }

    // Putting mass on the target class must yield lower loss than
    // putting it on a wrong class.
    #[test]
    fn softmax_cross_entropy_prefers_correct_class() {
        let logits_good = [5.0_f32, 0.0, -1.0];
        let logits_bad = [-1.0_f32, 0.0, 5.0];
        let target = [1.0_f32, 0.0, 0.0];
        let loss_good = softmax_cross_entropy(&logits_good, &target);
        let loss_bad = softmax_cross_entropy(&logits_bad, &target);
        assert!(loss_good < loss_bad);
    }
}