Skip to main content

use_loss/
lib.rs

1#![forbid(unsafe_code)]
2//! Primitive loss and error helpers for optimization.
3//!
4//! # Examples
5//!
6//! ```rust
7//! use use_loss::{absolute_error, mean_squared_error, root_mean_squared_error};
8//!
9//! assert_eq!(absolute_error(4.0, 3.0), 1.0);
10//! assert_eq!(mean_squared_error(&[1.0, 2.0], &[1.0, 4.0]).unwrap(), 2.0);
11//! assert_eq!(root_mean_squared_error(&[1.0, 2.0], &[1.0, 4.0]).unwrap(), 2.0_f64.sqrt());
12//! ```
13
/// Reasons a loss computation can reject its inputs.
///
/// Returned by [`mean_absolute_error`], [`mean_squared_error`] and
/// [`root_mean_squared_error`] when the input slices are unusable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LossError {
    /// At least one of the input slices contained no elements.
    EmptyInput,
    /// The `actual` and `predicted` slices had different lengths.
    MismatchedLengths,
    /// At least one input value was NaN or infinite.
    NonFiniteInput,
}

// A human-readable message so the error can be printed or logged directly.
impl std::fmt::Display for LossError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::EmptyInput => "input slices must not be empty",
            Self::MismatchedLengths => "actual and predicted slices must have the same length",
            Self::NonFiniteInput => "input values must be finite (no NaN or infinity)",
        };
        f.write_str(message)
    }
}

// Lets callers propagate this error with `?` into `Box<dyn Error>` and similar types.
impl std::error::Error for LossError {}
20
/// Returns the absolute difference `|actual - predicted|` between two values.
///
/// NaN inputs propagate to a NaN result.
pub fn absolute_error(actual: f64, predicted: f64) -> f64 {
    let residual = actual - predicted;
    f64::abs(residual)
}
24
/// Returns the squared difference `(actual - predicted)^2` between two values.
pub fn squared_error(actual: f64, predicted: f64) -> f64 {
    // The sign of the gap is irrelevant once squared; IEEE subtraction is
    // sign-symmetric, so this matches `(actual - predicted)^2` exactly.
    let gap = predicted - actual;
    gap * gap
}
29
30pub fn mean_absolute_error(actual: &[f64], predicted: &[f64]) -> Result<f64, LossError> {
31    validate_inputs(actual, predicted)?;
32
33    Ok(actual
34        .iter()
35        .zip(predicted.iter())
36        .map(|(actual_value, predicted_value)| absolute_error(*actual_value, *predicted_value))
37        .sum::<f64>()
38        / actual.len() as f64)
39}
40
41pub fn mean_squared_error(actual: &[f64], predicted: &[f64]) -> Result<f64, LossError> {
42    validate_inputs(actual, predicted)?;
43
44    Ok(actual
45        .iter()
46        .zip(predicted.iter())
47        .map(|(actual_value, predicted_value)| squared_error(*actual_value, *predicted_value))
48        .sum::<f64>()
49        / actual.len() as f64)
50}
51
52pub fn root_mean_squared_error(actual: &[f64], predicted: &[f64]) -> Result<f64, LossError> {
53    Ok(mean_squared_error(actual, predicted)?.sqrt())
54}
55
56fn validate_inputs(actual: &[f64], predicted: &[f64]) -> Result<(), LossError> {
57    if actual.is_empty() || predicted.is_empty() {
58        return Err(LossError::EmptyInput);
59    }
60
61    if actual.len() != predicted.len() {
62        return Err(LossError::MismatchedLengths);
63    }
64
65    if actual
66        .iter()
67        .chain(predicted.iter())
68        .any(|value| !value.is_finite())
69    {
70        return Err(LossError::NonFiniteInput);
71    }
72
73    Ok(())
74}
75
#[cfg(test)]
mod tests {
    use super::{
        LossError, absolute_error, mean_absolute_error, mean_squared_error,
        root_mean_squared_error, squared_error,
    };

    /// Asserts that two floats agree within an absolute tolerance of `1e-10`.
    fn approx_eq(left: f64, right: f64) {
        assert!((left - right).abs() < 1.0e-10, "left={left}, right={right}");
    }

    #[test]
    fn computes_basic_error_terms() {
        assert_eq!(absolute_error(4.0, 3.0), 1.0);
        assert_eq!(squared_error(4.0, 3.0), 1.0);
    }

    #[test]
    fn error_terms_ignore_direction() {
        assert_eq!(absolute_error(3.0, 4.0), absolute_error(4.0, 3.0));
        assert_eq!(squared_error(-1.0, 2.0), 9.0);
        assert_eq!(squared_error(2.0, -1.0), 9.0);
    }

    #[test]
    fn computes_common_loss_functions() {
        let actual = [1.0, 2.0, 3.0];
        let predicted = [1.5, 2.5, 2.0];

        approx_eq(mean_absolute_error(&actual, &predicted).unwrap(), 2.0 / 3.0);
        approx_eq(mean_squared_error(&actual, &predicted).unwrap(), 0.5);
        approx_eq(
            root_mean_squared_error(&actual, &predicted).unwrap(),
            0.5_f64.sqrt(),
        );
    }

    #[test]
    fn handles_single_value_inputs() {
        approx_eq(mean_absolute_error(&[3.0], &[2.0]).unwrap(), 1.0);
        approx_eq(mean_squared_error(&[3.0], &[2.0]).unwrap(), 1.0);
        approx_eq(root_mean_squared_error(&[3.0], &[2.0]).unwrap(), 1.0);
    }

    #[test]
    fn identical_inputs_have_zero_loss() {
        let values = [0.5, -1.0, 2.0];

        assert_eq!(mean_absolute_error(&values, &values), Ok(0.0));
        assert_eq!(mean_squared_error(&values, &values), Ok(0.0));
        assert_eq!(root_mean_squared_error(&values, &values), Ok(0.0));
    }

    #[test]
    fn rejects_invalid_loss_inputs() {
        assert_eq!(mean_absolute_error(&[], &[]), Err(LossError::EmptyInput));
        assert_eq!(
            mean_squared_error(&[1.0], &[1.0, 2.0]),
            Err(LossError::MismatchedLengths)
        );
        assert_eq!(
            root_mean_squared_error(&[1.0, f64::NAN], &[1.0, 2.0]),
            Err(LossError::NonFiniteInput)
        );
    }

    #[test]
    fn empty_check_takes_precedence_over_length_check() {
        // A one-sided empty input also has mismatched lengths; emptiness is reported first.
        assert_eq!(
            mean_absolute_error(&[], &[1.0]),
            Err(LossError::EmptyInput)
        );
        assert_eq!(
            mean_squared_error(&[1.0], &[]),
            Err(LossError::EmptyInput)
        );
    }

    #[test]
    fn rejects_non_finite_values_in_either_slice() {
        assert_eq!(
            mean_absolute_error(&[1.0], &[f64::INFINITY]),
            Err(LossError::NonFiniteInput)
        );
        assert_eq!(
            mean_squared_error(&[f64::NEG_INFINITY], &[1.0]),
            Err(LossError::NonFiniteInput)
        );
        assert_eq!(
            root_mean_squared_error(&[1.0, 2.0], &[f64::NAN, 2.0]),
            Err(LossError::NonFiniteInput)
        );
    }
}