russell_lab/algo/
linear_fitting.rs

1use crate::StrError;
2use num_traits::{cast, Num, NumCast};
3use std::ops::{AddAssign, Mul};
4
5/// Calculates the parameters of a linear model using least squares fitting
6///
7/// # Input
8///
9/// `x` -- the X-data vector with dimension n
10/// `y` -- the Y-data vector with dimension n
11/// `pass_through_zero` -- compute the parameters such that the line passes through zero (c = 0)
12///
13/// # Output
14///
15/// * `(c, m)` -- the y(x=0)=c intersect and the slope m
16///
17/// # Special cases
18///
19/// This function returns `(0.0, f64::INFINITY)` in two situations:
20///
21/// * If `pass_through_zero == True` and `sum(X) == 0`
22/// * If `pass_through_zero == False` and the line is vertical (null denominator)
23///
24/// # Panics
25///
26/// This function may panic if the number type cannot be converted to `f64`.
27///
28/// # Examples
29///
30/// ![Linear fitting](https://raw.githubusercontent.com/cpmech/russell/main/russell_lab/data/figures/algo_linear_fitting_1.svg)
31///
32/// ```
33/// use russell_lab::{approx_eq, linear_fitting, StrError};
34///
35/// fn main() -> Result<(), StrError> {
36///     // model: c is the y value @ x = 0; m is the slope
37///     let x = [0.0, 1.0, 3.0, 5.0];
38///     let y = [1.0, 0.0, 2.0, 4.0];
39///     let (c, m) = linear_fitting(&x, &y, false)?;
40///     approx_eq(c, 0.1864406779661015, 1e-15);
41///     approx_eq(m, 0.6949152542372882, 1e-15);
42///     Ok(())
43/// }
44/// ```
45pub fn linear_fitting<T>(x: &[T], y: &[T], pass_through_zero: bool) -> Result<(f64, f64), StrError>
46where
47    T: AddAssign + Copy + Mul + Num + NumCast,
48{
49    // dimension
50    let nn = x.len();
51    if y.len() != nn {
52        return Err("arrays must have the same lengths");
53    }
54
55    // sums
56    let mut t_sum_x = T::zero();
57    let mut t_sum_y = T::zero();
58    let mut t_sum_xy = T::zero();
59    let mut t_sum_xx = T::zero();
60    for i in 0..nn {
61        t_sum_x += x[i];
62        t_sum_y += y[i];
63        t_sum_xy += x[i] * y[i];
64        t_sum_xx += x[i] * x[i];
65    }
66
67    // cast sums to f64
68    let sum_x: f64 = cast(t_sum_x).unwrap();
69    let sum_y: f64 = cast(t_sum_y).unwrap();
70    let sum_xy: f64 = cast(t_sum_xy).unwrap();
71    let sum_xx: f64 = cast(t_sum_xx).unwrap();
72
73    // calculate parameters
74    let c;
75    let m;
76    let n = nn as f64;
77    if pass_through_zero {
78        if sum_xx == 0.0 {
79            return Ok((0.0, f64::INFINITY));
80        }
81        c = 0.0;
82        m = sum_xy / sum_xx;
83    } else {
84        let den = sum_x * sum_x - n * sum_xx;
85        if den == 0.0 {
86            return Ok((0.0, f64::INFINITY));
87        }
88        c = (sum_x * sum_xy - sum_xx * sum_y) / den;
89        m = (sum_x * sum_y - n * sum_xy) / den;
90    }
91
92    // results
93    Ok((c, m))
94}
95
96////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
97
#[cfg(test)]
mod tests {
    use super::linear_fitting;
    use crate::approx_eq;

    /// Arrays with different lengths must be rejected
    #[test]
    fn linear_fitting_handles_errors() {
        let short = [1.0, 2.0];
        let long = [6.0, 5.0, 7.0, 10.0];
        let outcome = linear_fitting(&short, &long, false);
        assert_eq!(outcome.err(), Some("arrays must have the same lengths"));
    }

    /// The same dataset must yield the same parameters for every supported number type
    #[test]
    fn linear_fitting_works() {
        // f64 (heap)

        let xf = vec![1.0, 2.0, 3.0, 4.0];
        let yf = vec![6.0, 5.0, 7.0, 10.0];

        let (intercept, slope) = linear_fitting(&xf, &yf, false).unwrap();
        assert_eq!((intercept, slope), (3.5, 1.4));

        let (intercept, slope) = linear_fitting(&xf, &yf, true).unwrap();
        assert_eq!(intercept, 0.0);
        approx_eq(slope, 2.566666666666667, 1e-16);

        // usize (stack)

        let xu = [1, 2, 3, 4_usize];
        let yu = [6, 5, 7, 10_usize];

        let (intercept, slope) = linear_fitting(&xu, &yu, false).unwrap();
        assert_eq!((intercept, slope), (3.5, 1.4));

        let (intercept, slope) = linear_fitting(&xu, &yu, true).unwrap();
        assert_eq!(intercept, 0.0);
        approx_eq(slope, 2.566666666666667, 1e-16);

        // i32 (slice)

        let xi = &[1, 2, 3, 4_i32];
        let yi = &[6, 5, 7, 10_i32];

        let (intercept, slope) = linear_fitting(xi, yi, false).unwrap();
        assert_eq!((intercept, slope), (3.5, 1.4));

        let (intercept, slope) = linear_fitting(xi, yi, true).unwrap();
        assert_eq!(intercept, 0.0);
        approx_eq(slope, 2.566666666666667, 1e-16);
    }

    /// Degenerate data must produce the (0, infinity) pair instead of NaN
    #[test]
    fn linear_fitting_handles_division_by_zero() {
        let ordinates = [1.0, 2.0, 3.0, 4.0];

        // vertical line: all abscissas are equal
        let constant_x = [1.0, 1.0, 1.0, 1.0];
        let (intercept, slope) = linear_fitting(&constant_x, &ordinates, false).unwrap();
        assert_eq!(intercept, 0.0);
        assert_eq!(slope, f64::INFINITY);

        // pass-through-zero with all abscissas equal to zero
        let zero_x = [0.0, 0.0, 0.0, 0.0];
        let (intercept, slope) = linear_fitting(&zero_x, &ordinates, true).unwrap();
        assert_eq!(intercept, 0.0);
        assert_eq!(slope, f64::INFINITY);
    }
}
170}