tatk/indicators/
linear_regression.rs

1//! Linear Regression (LR / LineReg), creates a best fit line.
2//!
3//! Creates a line that best fits a period of data using the least squares approach.
4
5use crate::traits::{AsValue, InternalValue, Next, Period, Stats};
6use crate::{Buffer, Num, TAError};
7use tatk_derive::{InternalValue, Period};
8
9/// Linear Regression (LR / LineReg), creates a best fit line.
10///
11/// Creates a line that best fits a period of data using the least squares approach.
12#[derive(Debug, InternalValue, Period)]
13pub struct LinearRegression {
14    /// Size of the period (window) in which data is looked at.
15    period: usize,
16    /// LR's current value.
17    value: Num,
18    /// Stasis values.
19    values: Buffer,
20    /// Holds all of the current period's values.
21    buffer: Buffer,
22    /// Sum of the X.
23    sum_x: Num,
24    /// Sum of the X, squared.
25    sum_x_sq: Num,
26    /// Intercept of the line.
27    intercept: Num,
28    /// Slope of the line.
29    slope: Num,
30}
31
32impl LinearRegression {
33    /// Creates a new Linear Regression line with the supplied period and initial data.
34    ///
35    /// ### Requirements:
36    ///
37    /// * Period must be greater than 1.
38    /// * Data must have at least `period` elements.
39    ///
40    /// ## Arguments
41    ///
42    /// * `period` - Size of the period / window used.
43    /// * `data` - Array of values to create the LR from.
44    pub fn new(period: usize, data: &[Num]) -> Result<Self, TAError> {
45        // Check we can calculate Linear Regression.
46        if period < 2 {
47            return Err(TAError::InvalidSize(String::from(
48                "period cannot be less than 2 to calculate linear regression",
49            )));
50        } else if data.len() < period {
51            // Make sure we have enough data.
52            return Err(TAError::InvalidData(String::from(
53                "not enough data for period provided",
54            )));
55        }
56
57        // Constants
58        let sum_x: Num = (period * (period + 1)) as Num * 0.5;
59        let sum_x_sq: Num = (period * (period + 1) * (2 * period + 1)) as Num / 6.0;
60
61        // Build the buffer containing the `period` of y values.
62        let mut values: Buffer = match Buffer::from_array(period, &data[..period]) {
63            Ok(value) => value,
64            Err(error) => return Err(error),
65        };
66
67        // Calculate the first value to seed the buffer.
68        let (mut intercept, mut slope) = Self::calculate(period, &values, sum_x, sum_x_sq);
69        let mut value: Num = intercept + (slope * period as Num);
70
71        // Build the buffer to hold old best fit values.
72        let mut buffer: Buffer = match Buffer::from_array(period, &[value]) {
73            Ok(value) => value,
74            Err(error) => return Err(error),
75        };
76
77        // Calculate the remaining best fit values.
78        for y in data[period..].iter() {
79            values.shift(*y);
80            (intercept, slope) = Self::calculate(period, &values, sum_x, sum_x_sq);
81
82            // Calculate new value.
83            value = intercept + (slope * period as Num);
84            buffer.shift(value);
85        }
86
87        Ok(Self {
88            period,
89            value,
90            values,
91            buffer,
92            sum_x,
93            sum_x_sq,
94            intercept,
95            slope,
96        })
97    }
98
99    /// Current and most recent value calculated.
100    pub fn value(&self) -> Num {
101        self.value
102    }
103
104    /// Calculates the intercept and slope for the line.
105    ///
106    /// # Arguments
107    ///
108    /// * `period` - Size of the period / window used.
109    /// * `values` - Last `period` of values to fit a line to.
110    /// * `sum_x` - Constant used, represents the sum of the time portion.
111    /// * `sum_x_sq` - Constant used, represents the square of the sum of the time portion.
112    fn calculate(period: usize, values: &Buffer, sum_x: Num, sum_x_sq: Num) -> (Num, Num) {
113        let sum_y: Num = values.sum();
114        let sum_xy: Num = (1..=period)
115            .zip(values.queue().iter().take(period))
116            .map(|(x, y)| x as Num * y)
117            .sum();
118
119        // Calculate intercept and slope.
120        let period_as: Num = period as Num;
121        let slope = (period_as * sum_xy - sum_x * sum_y) / (period_as * sum_x_sq - sum_x * sum_x);
122        let intercept = (sum_y - slope * sum_x) / period_as;
123
124        return (intercept, slope);
125    }
126
127    /// Predicted value of the dependent variable when all independent variables are set to zero.
128    pub fn intercept(&self) -> Num {
129        self.intercept
130    }
131
132    /// Coefficient associated with the independent variable.
133    pub fn slope(&self) -> Num {
134        self.slope
135    }
136
137    /// Percentage of variance in the dependent variable that can be explained by the independent variable.
138    pub fn r_sq(&self) -> Num {
139        let mean_y: Num = self.values.mean();
140        let predicted_y: Vec<Num> = (1..=self.period())
141            .map(|i| self.intercept() + self.slope() * i as Num)
142            .collect();
143
144        // Sum of Squares Total (sst) and Sum of Squares Residual (ssr).
145        let mut sst: Num = 0.0;
146        let mut ssr: Num = 0.0;
147        for (i, y) in self.values.queue().iter().enumerate() {
148            sst += (y - mean_y).powi(2);
149            ssr += (y - predicted_y[i]).powi(2);
150        }
151
152        1.0 - (ssr / sst)
153    }
154
155    /// Gets the standard deviation for the current line.
156    /// - ±1 stdev, 68%
157    /// - ±2 stdev, 95%
158    /// - ±3 stdev, 99.7%
159    pub fn line_stdev(&self) -> Num {
160        self.values.stdev(true)
161    }
162
163    /// Predicts (forecasts) a future value `distance` away from the current.
164    ///
165    /// # Arguments
166    ///
167    /// * `distance` - How far in the future to predict.
168    pub fn forecast(&self, distance: usize) -> Num {
169        self.intercept() + (self.slope() * (self.period() + distance) as Num)
170    }
171}
172
173impl Next<Num> for LinearRegression {
174    /// Next value for the LR.
175    type Output = Num;
176
177    /// Supply an additional value to recalculate a new LR.
178    ///
179    /// # Arguments
180    ///
181    /// * `value` - New value to add to period.
182    fn next(&mut self, value: Num) -> Self::Output {
183        // Rotate the buffer.
184        self.values.shift(value);
185
186        // Get the intercept and slope.
187        (self.intercept, self.slope) =
188            Self::calculate(self.period(), &self.values, self.sum_x, self.sum_x_sq);
189
190        // Calculate the current value.
191        self.value = self.intercept() + (self.slope() * self.period() as Num);
192        self.buffer.shift(self.value());
193
194        self.value
195    }
196}
197
198impl<T> Next<T> for LinearRegression
199where
200    T: AsValue,
201{
202    /// Next value for the LR.
203    type Output = Num;
204
205    /// Supply an additional value to recalculate a new LR.
206    ///
207    /// # Arguments
208    ///
209    /// * `value` - New value to add to period.
210    fn next(&mut self, value: T) -> Self::Output {
211        self.next(value.as_value())
212    }
213}
214
215impl Stats for LinearRegression {
216    /// Obtains the total sum of the buffer for LR.
217    fn sum(&self) -> Num {
218        self.buffer.sum()
219    }
220
221    /// Mean for the period of the LR.
222    fn mean(&self) -> Num {
223        self.buffer.mean()
224    }
225
226    /// Current variance for the period.
227    ///
228    /// # Arguments
229    ///
230    /// * `is_sample` - If the data is a Sample or Population, default should be True.
231    fn variance(&self, is_sample: bool) -> Num {
232        self.buffer.variance(is_sample)
233    }
234
235    /// Current standard deviation for the period.
236    ///
237    /// # Arguments
238    ///
239    /// * `is_sample` - If the data is a Sample or Population, default should be True.
240    fn stdev(&self, is_sample: bool) -> Num {
241        self.buffer.stdev(is_sample)
242    }
243}