1use crate::StrError;
2use num_traits::{cast, Num, NumCast};
3use std::ops::{AddAssign, Mul};
4
5pub fn linear_fitting<T>(x: &[T], y: &[T], pass_through_zero: bool) -> Result<(f64, f64), StrError>
46where
47 T: AddAssign + Copy + Mul + Num + NumCast,
48{
49 let nn = x.len();
51 if y.len() != nn {
52 return Err("arrays must have the same lengths");
53 }
54
55 let mut t_sum_x = T::zero();
57 let mut t_sum_y = T::zero();
58 let mut t_sum_xy = T::zero();
59 let mut t_sum_xx = T::zero();
60 for i in 0..nn {
61 t_sum_x += x[i];
62 t_sum_y += y[i];
63 t_sum_xy += x[i] * y[i];
64 t_sum_xx += x[i] * x[i];
65 }
66
67 let sum_x: f64 = cast(t_sum_x).unwrap();
69 let sum_y: f64 = cast(t_sum_y).unwrap();
70 let sum_xy: f64 = cast(t_sum_xy).unwrap();
71 let sum_xx: f64 = cast(t_sum_xx).unwrap();
72
73 let c;
75 let m;
76 let n = nn as f64;
77 if pass_through_zero {
78 if sum_xx == 0.0 {
79 return Ok((0.0, f64::INFINITY));
80 }
81 c = 0.0;
82 m = sum_xy / sum_xx;
83 } else {
84 let den = sum_x * sum_x - n * sum_xx;
85 if den == 0.0 {
86 return Ok((0.0, f64::INFINITY));
87 }
88 c = (sum_x * sum_xy - sum_xx * sum_y) / den;
89 m = (sum_x * sum_y - n * sum_xy) / den;
90 }
91
92 Ok((c, m))
94}
95
#[cfg(test)]
mod tests {
    use super::linear_fitting;
    use crate::approx_eq;

    /// Least-squares slope through the origin for the shared sample data,
    /// i.e. sum(x*y) / sum(x*x) = 77 / 30.
    const ORIGIN_SLOPE: f64 = 2.566666666666667;

    // Runs both fitting modes on the shared sample data
    // x = [1, 2, 3, 4], y = [6, 5, 7, 10] (of any element type) and checks the
    // expected (intercept, slope) pairs.
    macro_rules! check_sample_fit {
        ($xs:expr, $ys:expr) => {{
            assert_eq!(linear_fitting($xs, $ys, false).unwrap(), (3.5, 1.4));

            let (intercept, slope) = linear_fitting($xs, $ys, true).unwrap();
            assert_eq!(intercept, 0.0);
            approx_eq(slope, ORIGIN_SLOPE, 1e-16);
        }};
    }

    #[test]
    fn linear_fitting_handles_errors() {
        let short: [f64; 2] = [1.0, 2.0];
        let long: [f64; 4] = [6.0, 5.0, 7.0, 10.0];
        let outcome = linear_fitting(&short, &long, false);
        assert_eq!(outcome.err(), Some("arrays must have the same lengths"));
    }

    #[test]
    fn linear_fitting_works() {
        // Floating-point samples.
        let xs_f64: [f64; 4] = [1.0, 2.0, 3.0, 4.0];
        let ys_f64: [f64; 4] = [6.0, 5.0, 7.0, 10.0];
        check_sample_fit!(&xs_f64, &ys_f64);

        // Unsigned integer samples.
        let xs_usize: [usize; 4] = [1, 2, 3, 4];
        let ys_usize: [usize; 4] = [6, 5, 7, 10];
        check_sample_fit!(&xs_usize, &ys_usize);

        // Signed integer samples.
        let xs_i32: [i32; 4] = [1, 2, 3, 4];
        let ys_i32: [i32; 4] = [6, 5, 7, 10];
        check_sample_fit!(&xs_i32, &ys_i32);
    }

    #[test]
    fn linear_fitting_handles_division_by_zero() {
        let ys: [f64; 4] = [1.0, 2.0, 3.0, 4.0];

        // Identical x values make the general-fit denominator vanish.
        let constant_x: [f64; 4] = [1.0; 4];
        assert_eq!(
            linear_fitting(&constant_x, &ys, false).unwrap(),
            (0.0, f64::INFINITY)
        );

        // All-zero x values make sum(x*x) vanish in the through-origin fit.
        let zero_x: [f64; 4] = [0.0; 4];
        assert_eq!(
            linear_fitting(&zero_x, &ys, true).unwrap(),
            (0.0, f64::INFINITY)
        );
    }
}