rust_mlp/loss.rs

//! Loss functions.
//!
//! These are small, allocation-free helpers intended to be used like:
//!
//! - run `model.forward(...)`
//! - compute `d_output` via a loss (e.g. `mse_backward`)
//! - run `model.backward(...)`
//! - update parameters with an optimizer
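//!
//! A sketch of that loop (`model`, `input`, `target`, and `optimizer` are
//! hypothetical stand-ins here, not types defined by this module):
//!
//! ```ignore
//! let output = model.forward(&input);
//! let mut d_output = vec![0.0_f32; output.len()];
//! let loss = Loss::Mse.backward(&output, &target, &mut d_output);
//! model.backward(&d_output);
//! optimizer.step(&mut model);
//! ```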

/// Supported loss functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Loss {
    /// Mean squared error.
    Mse,
    /// Mean absolute error.
    Mae,
    /// Binary cross-entropy with logits.
    ///
    /// This expects raw logits as predictions and targets in `[0, 1]`.
    /// In most cases you should use an `Identity` activation on the output layer.
    BinaryCrossEntropyWithLogits,
    /// Softmax cross-entropy.
    ///
    /// This expects raw logits as predictions and a one-hot target vector.
    /// In most cases you should use an `Identity` activation on the output layer.
    SoftmaxCrossEntropy,
}

impl Loss {
    /// Validate a loss configuration.
    pub fn validate(self) -> crate::Result<()> {
        // No parameters today.
        Ok(())
    }

    /// Compute a loss value.
    ///
    /// Shape contract: `pred.len() == target.len()`.
    #[inline]
    pub fn forward(self, pred: &[f32], target: &[f32]) -> f32 {
        match self {
            Loss::Mse => mse(pred, target),
            Loss::Mae => mae(pred, target),
            Loss::BinaryCrossEntropyWithLogits => bce_with_logits(pred, target),
            Loss::SoftmaxCrossEntropy => softmax_cross_entropy(pred, target),
        }
    }

    /// Compute loss + gradient w.r.t. `pred`.
    ///
    /// Writes `d_pred = dL/d(pred)` into `d_pred` and returns the loss.
    ///
    /// Shape contract:
    /// - `pred.len() == target.len()`
    /// - `pred.len() == d_pred.len()`
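    ///
    /// A small sketch of a call (doctest left `ignore`d, since this file does
    /// not pin down the crate's public paths):
    ///
    /// ```ignore
    /// let pred = [1.0_f32, 3.0];
    /// let target = [2.0_f32, 1.0];
    /// let mut d_pred = [0.0_f32; 2];
    /// let loss = Loss::Mse.backward(&pred, &target, &mut d_pred);
    /// assert!((loss - 1.25).abs() < 1e-6);
    /// ```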
    #[inline]
    pub fn backward(self, pred: &[f32], target: &[f32], d_pred: &mut [f32]) -> f32 {
        match self {
            Loss::Mse => mse_backward(pred, target, d_pred),
            Loss::Mae => mae_backward(pred, target, d_pred),
            Loss::BinaryCrossEntropyWithLogits => bce_with_logits_backward(pred, target, d_pred),
            Loss::SoftmaxCrossEntropy => softmax_cross_entropy_backward(pred, target, d_pred),
        }
    }
}

/// Mean squared error (MSE) loss.
///
/// Returns `0.5 * mean((pred - target)^2)`.
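///
/// For example, `mse(&[3.0], &[1.0])` is `0.5 * (3 - 1)^2 / 1 = 2.0`.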
#[inline]
pub fn mse(pred: &[f32], target: &[f32]) -> f32 {
    assert_eq!(
        pred.len(),
        target.len(),
        "pred len {} does not match target len {}",
        pred.len(),
        target.len()
    );

    if pred.is_empty() {
        return 0.0;
    }

    let inv_n = 1.0 / pred.len() as f32;
    let mut sum_sq = 0.0_f32;
    for i in 0..pred.len() {
        let diff = pred[i] - target[i];
        sum_sq = diff.mul_add(diff, sum_sq);
    }
    0.5 * sum_sq * inv_n
}

/// MSE loss + gradient w.r.t. `pred`.
///
/// Writes `d_pred = dL/d(pred)` into `d_pred` and returns the loss.
///
/// With `L = 0.5 * mean((pred - target)^2)`, the gradient is:
/// - `d_pred[i] = (pred[i] - target[i]) / N`
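///
/// For example, with `pred = [1.0, 3.0]` and `target = [2.0, 1.0]` (`N = 2`),
/// the gradient is `d_pred = [-0.5, 1.0]` and the loss is
/// `0.5 * (1 + 4) / 2 = 1.25`.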
#[inline]
pub fn mse_backward(pred: &[f32], target: &[f32], d_pred: &mut [f32]) -> f32 {
    assert_eq!(
        pred.len(),
        target.len(),
        "pred len {} does not match target len {}",
        pred.len(),
        target.len()
    );
    assert_eq!(
        pred.len(),
        d_pred.len(),
        "pred len {} does not match d_pred len {}",
        pred.len(),
        d_pred.len()
    );

    if pred.is_empty() {
        return 0.0;
    }

    let inv_n = 1.0 / pred.len() as f32;
    let mut sum_sq = 0.0_f32;

    for i in 0..pred.len() {
        let diff = pred[i] - target[i];
        sum_sq = diff.mul_add(diff, sum_sq);
        d_pred[i] = diff * inv_n;
    }

    0.5 * sum_sq * inv_n
}

/// Mean absolute error (MAE) loss.
///
/// Returns `mean(|pred - target|)`.
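///
/// For example, `mae(&[3.0, 1.0], &[1.0, 2.0])` is `(2 + 1) / 2 = 1.5`.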
#[inline]
pub fn mae(pred: &[f32], target: &[f32]) -> f32 {
    assert_eq!(
        pred.len(),
        target.len(),
        "pred len {} does not match target len {}",
        pred.len(),
        target.len()
    );

    if pred.is_empty() {
        return 0.0;
    }

    let inv_n = 1.0 / pred.len() as f32;
    let mut sum = 0.0_f32;
    for i in 0..pred.len() {
        sum += (pred[i] - target[i]).abs();
    }
    sum * inv_n
}

/// MAE loss + gradient w.r.t. `pred`.
///
/// The gradient is a subgradient at `pred == target`.
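///
/// Concretely, `d_pred[i] = sign(pred[i] - target[i]) / N`, with the
/// subgradient at zero chosen as `0`.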
#[inline]
pub fn mae_backward(pred: &[f32], target: &[f32], d_pred: &mut [f32]) -> f32 {
    assert_eq!(
        pred.len(),
        target.len(),
        "pred len {} does not match target len {}",
        pred.len(),
        target.len()
    );
    assert_eq!(
        pred.len(),
        d_pred.len(),
        "pred len {} does not match d_pred len {}",
        pred.len(),
        d_pred.len()
    );

    if pred.is_empty() {
        return 0.0;
    }

    let inv_n = 1.0 / pred.len() as f32;
    let mut sum = 0.0_f32;
    for i in 0..pred.len() {
        let diff = pred[i] - target[i];
        sum += diff.abs();
        d_pred[i] = if diff > 0.0 {
            inv_n
        } else if diff < 0.0 {
            -inv_n
        } else {
            0.0
        };
    }
    sum * inv_n
}

/// Binary cross-entropy loss with logits.
///
/// Per element (with `t` in `[0, 1]`):
///
/// - `L = max(x, 0) - x * t + ln(1 + exp(-|x|))`
///
/// This is numerically stable for large `|x|`.
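///
/// This is the usual rewrite of `-t * ln(sigmoid(x)) - (1 - t) * ln(1 - sigmoid(x))`
/// using `ln(1 + exp(x)) = max(x, 0) + ln(1 + exp(-|x|))`, so no large positive
/// value is ever exponentiated.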
#[inline]
pub fn bce_with_logits(logits: &[f32], target: &[f32]) -> f32 {
    assert_eq!(
        logits.len(),
        target.len(),
        "logits len {} does not match target len {}",
        logits.len(),
        target.len()
    );

    if logits.is_empty() {
        return 0.0;
    }

    let inv_n = 1.0 / logits.len() as f32;
    let mut sum = 0.0_f32;
    for i in 0..logits.len() {
        let x = logits[i];
        let t = target[i];
        let abs_x = x.abs();
        let loss = x.max(0.0) - x * t + (1.0 + (-abs_x).exp()).ln();
        sum += loss;
    }
    sum * inv_n
}

/// BCE-with-logits loss + gradient w.r.t. logits.
///
/// Gradient: `dL/dx = (sigmoid(x) - t) / N`.
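///
/// For example, at `x = 0` with `t = 1`: `sigmoid(0) = 0.5`, so the
/// per-element gradient is `-0.5 / N` and the per-element loss is `ln(2)`.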
#[inline]
pub fn bce_with_logits_backward(logits: &[f32], target: &[f32], d_logits: &mut [f32]) -> f32 {
    assert_eq!(
        logits.len(),
        target.len(),
        "logits len {} does not match target len {}",
        logits.len(),
        target.len()
    );
    assert_eq!(
        logits.len(),
        d_logits.len(),
        "logits len {} does not match d_logits len {}",
        logits.len(),
        d_logits.len()
    );

    if logits.is_empty() {
        return 0.0;
    }

    let inv_n = 1.0 / logits.len() as f32;
    let mut sum = 0.0_f32;

    for i in 0..logits.len() {
        let x = logits[i];
        let t = target[i];
        let abs_x = x.abs();
        let loss = x.max(0.0) - x * t + (1.0 + (-abs_x).exp()).ln();
        sum += loss;

        let s = sigmoid(x);
        d_logits[i] = (s - t) * inv_n;
    }

    sum * inv_n
}

/// Softmax cross-entropy over a single sample.
///
/// `logits` is a length-K vector. `target` is a one-hot length-K vector.
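///
/// Computed as `-(1/K) * sum_i target[i] * (logits[i] - logsumexp(logits))`,
/// i.e. the usual cross-entropy divided by the class count `K` to match the
/// crate's mean convention.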
#[inline]
pub fn softmax_cross_entropy(logits: &[f32], target: &[f32]) -> f32 {
    assert_eq!(
        logits.len(),
        target.len(),
        "logits len {} does not match target len {}",
        logits.len(),
        target.len()
    );
    assert!(
        !logits.is_empty(),
        "softmax_cross_entropy requires at least 1 class"
    );

    let (log_sum_exp, _max) = log_sum_exp_and_max(logits);

    // Cross-entropy: -sum_i t_i * log softmax_i, where
    // log softmax_i = logits[i] - log_sum_exp.
    let mut sum = 0.0_f32;
    for i in 0..logits.len() {
        let t = target[i];
        if t != 0.0 {
            sum -= t * (logits[i] - log_sum_exp);
        }
    }

    // Mean over classes (matches the crate's "mean over pred.len()" convention).
    sum / logits.len() as f32
}

/// Softmax cross-entropy + gradient w.r.t. logits.
///
/// Writes `d_logits = (softmax(logits) - target) / K`.
///
/// This function is allocation-free: it computes softmax into `d_logits` and then
/// turns it into a gradient in place.
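///
/// Derivation sketch: with `L = -(1/K) * sum_i t_i * log_softmax_i`, the
/// log-softmax Jacobian gives `dL/dx_j = ((sum_i t_i) * softmax_j - t_j) / K`,
/// which reduces to `(softmax_j - t_j) / K` because a one-hot target sums to 1.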
#[inline]
pub fn softmax_cross_entropy_backward(logits: &[f32], target: &[f32], d_logits: &mut [f32]) -> f32 {
    assert_eq!(
        logits.len(),
        target.len(),
        "logits len {} does not match target len {}",
        logits.len(),
        target.len()
    );
    assert_eq!(
        logits.len(),
        d_logits.len(),
        "logits len {} does not match d_logits len {}",
        logits.len(),
        d_logits.len()
    );
    assert!(
        !logits.is_empty(),
        "softmax_cross_entropy_backward requires at least 1 class"
    );

    let k = logits.len();
    let inv_k = 1.0 / k as f32;

    let (log_sum_exp, max_logit) = log_sum_exp_and_max(logits);

    // Softmax into d_logits.
    for i in 0..k {
        d_logits[i] = (logits[i] - max_logit).exp();
    }
    let mut sum_exp = 0.0_f32;
    for &v in d_logits.iter() {
        sum_exp += v;
    }
    let inv_sum = 1.0 / sum_exp;
    for v in d_logits.iter_mut() {
        *v *= inv_sum;
    }

    // Loss.
    let mut loss = 0.0_f32;
    for i in 0..k {
        let t = target[i];
        if t != 0.0 {
            loss -= t * (logits[i] - log_sum_exp);
        }
    }
    loss *= inv_k;

    // Gradient: (softmax - target) / K.
    for i in 0..k {
        d_logits[i] = (d_logits[i] - target[i]) * inv_k;
    }

    loss
}

#[inline]
fn sigmoid(x: f32) -> f32 {
    // Stable sigmoid (duplicated here to keep the loss module self-contained).
    if x >= 0.0 {
        let z = (-x).exp();
        1.0 / (1.0 + z)
    } else {
        let z = x.exp();
        z / (1.0 + z)
    }
}

/// Stable `log(sum(exp(xs)))`, returned together with `max(xs)`.
///
/// Panics (via indexing) on an empty slice; callers assert non-emptiness first.
#[inline]
fn log_sum_exp_and_max(xs: &[f32]) -> (f32, f32) {
    let mut max_x = xs[0];
    for &x in xs.iter().skip(1) {
        if x > max_x {
            max_x = x;
        }
    }
    let mut sum_exp = 0.0_f32;
    for &x in xs {
        sum_exp += (x - max_x).exp();
    }
    (max_x + sum_exp.ln(), max_x)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mse_is_zero_when_equal() {
        let pred = [1.0_f32, -2.0, 0.5];
        let target = pred;
        assert_eq!(mse(&pred, &target), 0.0);
    }

    #[test]
    fn mse_backward_matches_expected_gradient() {
        let pred = [1.0_f32, 3.0];
        let target = [2.0_f32, 1.0];
        let mut d_pred = [0.0_f32; 2];
        let loss = mse_backward(&pred, &target, &mut d_pred);

        // L = 0.5 * mean([(-1)^2, 2^2]) = 0.5 * (1 + 4) / 2 = 1.25
        assert!((loss - 1.25).abs() < 1e-6);
        // dL/dpred = (pred - target) / N
        assert!((d_pred[0] - (-0.5)).abs() < 1e-6);
        assert!((d_pred[1] - 1.0).abs() < 1e-6);
    }
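
    // Hand-checked MAE example; the subgradient at pred == target is 0 here.
    #[test]
    fn mae_backward_matches_sign_subgradient() {
        let pred = [3.0_f32, 1.0, 2.0];
        let target = [1.0_f32, 1.0, 5.0];
        let mut d_pred = [0.0_f32; 3];
        let loss = mae_backward(&pred, &target, &mut d_pred);
        // L = mean([2, 0, 3]) = 5/3
        assert!((loss - 5.0 / 3.0).abs() < 1e-6);
        // d_pred = sign(pred - target) / N = [1/3, 0, -1/3]
        assert!((d_pred[0] - 1.0 / 3.0).abs() < 1e-6);
        assert_eq!(d_pred[1], 0.0);
        assert!((d_pred[2] + 1.0 / 3.0).abs() < 1e-6);
    }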

    #[test]
    fn bce_with_logits_is_reasonable_for_extreme_logits() {
        let logits = [100.0_f32, -100.0];
        let target = [1.0_f32, 0.0];
        let loss = bce_with_logits(&logits, &target);
        assert!(loss.is_finite());
        assert!(loss < 1e-3);
    }

    #[test]
    fn bce_with_logits_backward_matches_sigmoid_minus_target() {
        let logits = [0.0_f32];
        let target = [1.0_f32];
        let mut d = [0.0_f32];
        let loss = bce_with_logits_backward(&logits, &target, &mut d);
        assert!((loss - std::f32::consts::LN_2).abs() < 1e-5);
        // sigmoid(0) - 1 = -0.5
        assert!((d[0] - (-0.5)).abs() < 1e-6);
    }
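
    // The enum dispatch should agree exactly with the free function it wraps.
    #[test]
    fn loss_enum_dispatch_matches_free_functions() {
        let pred = [1.0_f32, 3.0];
        let target = [2.0_f32, 1.0];
        let mut d_enum = [0.0_f32; 2];
        let mut d_free = [0.0_f32; 2];
        let l_enum = Loss::Mse.backward(&pred, &target, &mut d_enum);
        let l_free = mse_backward(&pred, &target, &mut d_free);
        assert_eq!(l_enum, l_free);
        assert_eq!(d_enum, d_free);
    }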

    #[test]
    fn softmax_cross_entropy_prefers_correct_class() {
        let logits_good = [5.0_f32, 0.0, -1.0];
        let logits_bad = [-1.0_f32, 0.0, 5.0];
        let target = [1.0_f32, 0.0, 0.0];
        let loss_good = softmax_cross_entropy(&logits_good, &target);
        let loss_bad = softmax_cross_entropy(&logits_bad, &target);
        assert!(loss_good < loss_bad);
    }
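
    // For a target that sums to 1, (softmax - target) / K must sum to zero,
    // since softmax itself sums to 1.
    #[test]
    fn softmax_cross_entropy_backward_gradient_sums_to_zero() {
        let logits = [2.0_f32, -1.0, 0.5];
        let target = [0.0_f32, 1.0, 0.0];
        let mut d = [0.0_f32; 3];
        let loss = softmax_cross_entropy_backward(&logits, &target, &mut d);
        assert!(loss.is_finite());
        let grad_sum: f32 = d.iter().sum();
        assert!(grad_sum.abs() < 1e-6);
    }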
}