tatk/indicators/linear_regression.rs
1//! Linear Regression (LR / LineReg), creates a best fit line.
2//!
3//! Creates a line that best fits a period of data using the least squares approach.
4
5use crate::traits::{AsValue, InternalValue, Next, Period, Stats};
6use crate::{Buffer, Num, TAError};
7use tatk_derive::{InternalValue, Period};
8
/// Linear Regression (LR / LineReg), creates a best fit line.
///
/// Creates a line that best fits a period of data using the least squares approach.
/// The x positions of the fit are the integers `1..=period`, and the indicator's
/// value is the fitted line evaluated at `x = period`.
#[derive(Debug, InternalValue, Period)]
pub struct LinearRegression {
    /// Size of the period (window) in which data is looked at.
    period: usize,
    /// LR's current value: the fitted line evaluated at `x = period`.
    value: Num,
    /// Raw input (y) values of the current period; the line is fitted to these.
    values: Buffer,
    /// Previously calculated LR values; this is what the `Stats` implementation reports on.
    buffer: Buffer,
    /// Sum of the x positions `1..=period`; fixed once the period is known.
    sum_x: Num,
    /// Sum of the squared x positions `1² + … + period²`; fixed once the period is known.
    sum_x_sq: Num,
    /// Intercept of the most recently fitted line.
    intercept: Num,
    /// Slope of the most recently fitted line.
    slope: Num,
}
31
32impl LinearRegression {
33 /// Creates a new Linear Regression line with the supplied period and initial data.
34 ///
35 /// ### Requirements:
36 ///
37 /// * Period must be greater than 1.
38 /// * Data must have at least `period` elements.
39 ///
40 /// ## Arguments
41 ///
42 /// * `period` - Size of the period / window used.
43 /// * `data` - Array of values to create the LR from.
44 pub fn new(period: usize, data: &[Num]) -> Result<Self, TAError> {
45 // Check we can calculate Linear Regression.
46 if period < 2 {
47 return Err(TAError::InvalidSize(String::from(
48 "period cannot be less than 2 to calculate linear regression",
49 )));
50 } else if data.len() < period {
51 // Make sure we have enough data.
52 return Err(TAError::InvalidData(String::from(
53 "not enough data for period provided",
54 )));
55 }
56
57 // Constants
58 let sum_x: Num = (period * (period + 1)) as Num * 0.5;
59 let sum_x_sq: Num = (period * (period + 1) * (2 * period + 1)) as Num / 6.0;
60
61 // Build the buffer containing the `period` of y values.
62 let mut values: Buffer = match Buffer::from_array(period, &data[..period]) {
63 Ok(value) => value,
64 Err(error) => return Err(error),
65 };
66
67 // Calculate the first value to seed the buffer.
68 let (mut intercept, mut slope) = Self::calculate(period, &values, sum_x, sum_x_sq);
69 let mut value: Num = intercept + (slope * period as Num);
70
71 // Build the buffer to hold old best fit values.
72 let mut buffer: Buffer = match Buffer::from_array(period, &[value]) {
73 Ok(value) => value,
74 Err(error) => return Err(error),
75 };
76
77 // Calculate the remaining best fit values.
78 for y in data[period..].iter() {
79 values.shift(*y);
80 (intercept, slope) = Self::calculate(period, &values, sum_x, sum_x_sq);
81
82 // Calculate new value.
83 value = intercept + (slope * period as Num);
84 buffer.shift(value);
85 }
86
87 Ok(Self {
88 period,
89 value,
90 values,
91 buffer,
92 sum_x,
93 sum_x_sq,
94 intercept,
95 slope,
96 })
97 }
98
99 /// Current and most recent value calculated.
100 pub fn value(&self) -> Num {
101 self.value
102 }
103
104 /// Calculates the intercept and slope for the line.
105 ///
106 /// # Arguments
107 ///
108 /// * `period` - Size of the period / window used.
109 /// * `values` - Last `period` of values to fit a line to.
110 /// * `sum_x` - Constant used, represents the sum of the time portion.
111 /// * `sum_x_sq` - Constant used, represents the square of the sum of the time portion.
112 fn calculate(period: usize, values: &Buffer, sum_x: Num, sum_x_sq: Num) -> (Num, Num) {
113 let sum_y: Num = values.sum();
114 let sum_xy: Num = (1..=period)
115 .zip(values.queue().iter().take(period))
116 .map(|(x, y)| x as Num * y)
117 .sum();
118
119 // Calculate intercept and slope.
120 let period_as: Num = period as Num;
121 let slope = (period_as * sum_xy - sum_x * sum_y) / (period_as * sum_x_sq - sum_x * sum_x);
122 let intercept = (sum_y - slope * sum_x) / period_as;
123
124 return (intercept, slope);
125 }
126
127 /// Predicted value of the dependent variable when all independent variables are set to zero.
128 pub fn intercept(&self) -> Num {
129 self.intercept
130 }
131
132 /// Coefficient associated with the independent variable.
133 pub fn slope(&self) -> Num {
134 self.slope
135 }
136
137 /// Percentage of variance in the dependent variable that can be explained by the independent variable.
138 pub fn r_sq(&self) -> Num {
139 let mean_y: Num = self.values.mean();
140 let predicted_y: Vec<Num> = (1..=self.period())
141 .map(|i| self.intercept() + self.slope() * i as Num)
142 .collect();
143
144 // Sum of Squares Total (sst) and Sum of Squares Residual (ssr).
145 let mut sst: Num = 0.0;
146 let mut ssr: Num = 0.0;
147 for (i, y) in self.values.queue().iter().enumerate() {
148 sst += (y - mean_y).powi(2);
149 ssr += (y - predicted_y[i]).powi(2);
150 }
151
152 1.0 - (ssr / sst)
153 }
154
155 /// Gets the standard deviation for the current line.
156 /// - ±1 stdev, 68%
157 /// - ±2 stdev, 95%
158 /// - ±3 stdev, 99.7%
159 pub fn line_stdev(&self) -> Num {
160 self.values.stdev(true)
161 }
162
163 /// Predicts (forecasts) a future value `distance` away from the current.
164 ///
165 /// # Arguments
166 ///
167 /// * `distance` - How far in the future to predict.
168 pub fn forecast(&self, distance: usize) -> Num {
169 self.intercept() + (self.slope() * (self.period() + distance) as Num)
170 }
171}
172
173impl Next<Num> for LinearRegression {
174 /// Next value for the LR.
175 type Output = Num;
176
177 /// Supply an additional value to recalculate a new LR.
178 ///
179 /// # Arguments
180 ///
181 /// * `value` - New value to add to period.
182 fn next(&mut self, value: Num) -> Self::Output {
183 // Rotate the buffer.
184 self.values.shift(value);
185
186 // Get the intercept and slope.
187 (self.intercept, self.slope) =
188 Self::calculate(self.period(), &self.values, self.sum_x, self.sum_x_sq);
189
190 // Calculate the current value.
191 self.value = self.intercept() + (self.slope() * self.period() as Num);
192 self.buffer.shift(self.value());
193
194 self.value
195 }
196}
197
198impl<T> Next<T> for LinearRegression
199where
200 T: AsValue,
201{
202 /// Next value for the LR.
203 type Output = Num;
204
205 /// Supply an additional value to recalculate a new LR.
206 ///
207 /// # Arguments
208 ///
209 /// * `value` - New value to add to period.
210 fn next(&mut self, value: T) -> Self::Output {
211 self.next(value.as_value())
212 }
213}
214
215impl Stats for LinearRegression {
216 /// Obtains the total sum of the buffer for LR.
217 fn sum(&self) -> Num {
218 self.buffer.sum()
219 }
220
221 /// Mean for the period of the LR.
222 fn mean(&self) -> Num {
223 self.buffer.mean()
224 }
225
226 /// Current variance for the period.
227 ///
228 /// # Arguments
229 ///
230 /// * `is_sample` - If the data is a Sample or Population, default should be True.
231 fn variance(&self, is_sample: bool) -> Num {
232 self.buffer.variance(is_sample)
233 }
234
235 /// Current standard deviation for the period.
236 ///
237 /// # Arguments
238 ///
239 /// * `is_sample` - If the data is a Sample or Population, default should be True.
240 fn stdev(&self, is_sample: bool) -> Num {
241 self.buffer.stdev(is_sample)
242 }
243}