wickra_core/indicators/
linreg_intercept.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Rolling linear-regression intercept over the most recent `period` samples.
///
/// The samples in the window are fitted by ordinary least squares against
/// x = 0, 1, …, period − 1, with the oldest sample at x = 0. The reported value
/// is the fitted line evaluated at x = 0.
///
/// Build one with `LinRegIntercept::new` and feed it through the crate's
/// `Indicator` trait.
#[derive(Clone, Debug)]
pub struct LinRegIntercept {
    /// Window length; `new` rejects anything below 2.
    period: usize,
    /// Buffered finite inputs, oldest first (the oldest sits at x = 0).
    window: VecDeque<f64>,
    /// Σx over x = 0..period, fixed at construction.
    sum_x: f64,
    /// Least-squares slope denominator `n·Σx² − (Σx)²`, fixed at construction.
    denom: f64,
    /// Σy of the samples held in `window`, maintained by `update`.
    sum_y: f64,
    /// Σ(x·y) of the samples held in `window`, maintained by `update`.
    sum_xy: f64,
}
43
44impl LinRegIntercept {
45 pub fn new(period: usize) -> Result<Self> {
51 if period < 2 {
52 return Err(Error::InvalidPeriod {
53 message: "linear regression intercept needs period >= 2",
54 });
55 }
56 let n = period as f64;
57 let sum_x = n * (n - 1.0) / 2.0;
58 let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
59 Ok(Self {
60 period,
61 window: VecDeque::with_capacity(period),
62 sum_x,
63 denom: n * sum_xx - sum_x * sum_x,
64 sum_y: 0.0,
65 sum_xy: 0.0,
66 })
67 }
68
69 pub const fn period(&self) -> usize {
71 self.period
72 }
73}
74
75impl Indicator for LinRegIntercept {
76 type Input = f64;
77 type Output = f64;
78
79 fn update(&mut self, value: f64) -> Option<f64> {
80 if !value.is_finite() {
81 return None;
82 }
83 if self.window.len() == self.period {
84 let y0 = self.window.pop_front().expect("non-empty");
85 self.sum_xy = self.sum_xy - self.sum_y + y0;
86 self.sum_y -= y0;
87 }
88 let k = self.window.len() as f64;
89 self.window.push_back(value);
90 self.sum_y += value;
91 self.sum_xy += k * value;
92
93 if self.window.len() < self.period {
94 return None;
95 }
96 let n = self.period as f64;
97 let slope = (n * self.sum_xy - self.sum_x * self.sum_y) / self.denom;
98 let intercept = (self.sum_y - slope * self.sum_x) / n;
99 Some(intercept)
100 }
101
102 fn reset(&mut self) {
103 self.window.clear();
104 self.sum_y = 0.0;
105 self.sum_xy = 0.0;
106 }
107
108 fn warmup_period(&self) -> usize {
109 self.period
110 }
111
112 fn is_ready(&self) -> bool {
113 self.window.len() == self.period
114 }
115
116 fn name(&self) -> &'static str {
117 "LINEARREG_INTERCEPT"
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Brute-force least-squares intercept at x = 0 for `window`
    /// (x = 0..len, oldest first), used as an independent reference.
    fn naive_intercept(window: &[f64]) -> f64 {
        let n = window.len() as f64;
        let mean_x = (n - 1.0) / 2.0;
        let mean_y = window.iter().sum::<f64>() / n;
        let (mut cov, mut var) = (0.0, 0.0);
        for (i, &y) in window.iter().enumerate() {
            let dx = i as f64 - mean_x;
            cov += dx * (y - mean_y);
            var += dx * dx;
        }
        mean_y - cov / var * mean_x
    }

    #[test]
    fn rejects_short_period() {
        assert!(matches!(
            LinRegIntercept::new(1),
            Err(Error::InvalidPeriod { .. })
        ));
    }

    #[test]
    fn accessors_report_config() {
        let lr = LinRegIntercept::new(5).unwrap();
        assert_eq!(lr.period(), 5);
        assert_eq!(lr.name(), "LINEARREG_INTERCEPT");
        assert_eq!(lr.warmup_period(), 5);
        assert!(!lr.is_ready());
    }

    #[test]
    fn reference_value() {
        let mut lr = LinRegIntercept::new(3).unwrap();
        // Fit through (0,1), (1,2), (2,9): slope 4, intercept 0.
        let out: Vec<Option<f64>> = lr.batch(&[1.0, 2.0, 9.0]);
        assert!(out[0].is_none());
        assert!(out[1].is_none());
        assert_relative_eq!(out[2].unwrap(), 0.0, epsilon = 1e-9);
        assert!(lr.is_ready());
    }

    #[test]
    fn slides_and_tracks_a_shifted_line() {
        let mut lr = LinRegIntercept::new(3).unwrap();
        // After sliding, the window is [10, 12, 14]: slope 2, intercept 10.
        let out: Vec<Option<f64>> = lr.batch(&[1.0, 10.0, 12.0, 14.0]);
        assert_relative_eq!(out[3].unwrap(), 10.0, epsilon = 1e-9);
    }

    #[test]
    fn ignores_non_finite_inputs() {
        let mut lr = LinRegIntercept::new(3).unwrap();
        assert_eq!(lr.update(1.0), None);
        assert_eq!(lr.update(f64::NAN), None);
        assert_eq!(lr.update(f64::INFINITY), None);
        assert_eq!(lr.update(2.0), None);
        assert!(!lr.is_ready());
        // NaN/inf were skipped, so this completes the window [1, 2, 9].
        assert_relative_eq!(lr.update(9.0).unwrap(), 0.0, epsilon = 1e-9);
    }

    #[test]
    fn matches_brute_force_fit_while_sliding() {
        let period = 5;
        let data: Vec<f64> = (0_u32..60)
            .map(|i| {
                let t = f64::from(i);
                100.0 + 0.5 * t + (0.7 * t).sin() * 3.0
            })
            .collect();
        let mut lr = LinRegIntercept::new(period).unwrap();
        for (i, &x) in data.iter().enumerate() {
            let out = lr.update(x);
            if i + 1 < period {
                assert!(out.is_none());
            } else {
                let expected = naive_intercept(&data[i + 1 - period..=i]);
                assert_relative_eq!(out.unwrap(), expected, epsilon = 1e-6);
            }
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut lr = LinRegIntercept::new(3).unwrap();
        let _ = lr.batch(&[1.0, 2.0, 9.0]);
        assert!(lr.is_ready());
        lr.reset();
        assert!(!lr.is_ready());
        assert_eq!(lr.update(1.0), None);
    }
}