wickra_core/indicators/
linreg_channel.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// One reading of the linear regression channel.
///
/// All three values refer to the newest bar of the regression window:
/// `middle` is the fitted line evaluated there, and the two bands sit
/// `multiplier` population standard deviations of the fit residuals above
/// and below it.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LinRegChannelOutput {
    /// Upper band: `middle + multiplier * sigma`.
    pub upper: f64,
    /// Regression line value at the most recent bar of the window.
    pub middle: f64,
    /// Lower band: `middle - multiplier * sigma`.
    pub lower: f64,
}
20
/// Rolling linear regression channel over the last `period` values.
///
/// Each ready update fits an ordinary least-squares line to the window, using
/// `x = 0, 1, ..., period - 1` from the oldest value to the newest, and
/// reports bands around the fitted value at the newest bar.
#[derive(Clone, Debug)]
pub struct LinRegChannel {
    /// Number of values in the regression window (validated `>= 2` by `new`).
    period: usize,
    /// Band width in residual standard deviations (validated finite and `> 0`).
    multiplier: f64,
    /// Most recent values, oldest at the front.
    window: VecDeque<f64>,
    /// Cached `sum(x)` over `x = 0..period`; depends only on `period`.
    sum_x: f64,
    /// Cached `sum(x * x)` over `x = 0..period`; depends only on `period`.
    sum_xx: f64,
}
60
61impl LinRegChannel {
62 pub fn new(period: usize, multiplier: f64) -> Result<Self> {
67 if period < 2 {
68 return Err(Error::InvalidPeriod {
69 message: "linear regression channel needs period >= 2",
70 });
71 }
72 if !multiplier.is_finite() || multiplier <= 0.0 {
73 return Err(Error::NonPositiveMultiplier);
74 }
75 let n = period as f64;
76 Ok(Self {
77 period,
78 multiplier,
79 window: VecDeque::with_capacity(period),
80 sum_x: n * (n - 1.0) / 2.0,
81 sum_xx: (n - 1.0) * n * (2.0 * n - 1.0) / 6.0,
82 })
83 }
84
85 pub const fn period(&self) -> usize {
87 self.period
88 }
89
90 pub const fn multiplier(&self) -> f64 {
92 self.multiplier
93 }
94}
95
96impl Indicator for LinRegChannel {
97 type Input = f64;
98 type Output = LinRegChannelOutput;
99
100 fn update(&mut self, value: f64) -> Option<LinRegChannelOutput> {
101 if !value.is_finite() {
102 return None;
103 }
104 if self.window.len() == self.period {
105 self.window.pop_front();
106 }
107 self.window.push_back(value);
108 if self.window.len() < self.period {
109 return None;
110 }
111 let n = self.period as f64;
117 let mut sum_y = 0.0;
118 let mut sum_xy = 0.0;
119 for (i, &y) in self.window.iter().enumerate() {
120 let x = i as f64;
121 sum_y += y;
122 sum_xy += x * y;
123 }
124 let denom = n * self.sum_xx - self.sum_x * self.sum_x;
125 let slope = (n * sum_xy - self.sum_x * sum_y) / denom;
126 let intercept = (sum_y - slope * self.sum_x) / n;
127
128 let mut sum_sq = 0.0;
130 for (i, &y) in self.window.iter().enumerate() {
131 let fitted = intercept + slope * (i as f64);
132 let r = y - fitted;
133 sum_sq += r * r;
134 }
135 let sigma = (sum_sq / n).sqrt();
136 let middle = intercept + slope * (n - 1.0);
137 Some(LinRegChannelOutput {
138 upper: middle + self.multiplier * sigma,
139 middle,
140 lower: middle - self.multiplier * sigma,
141 })
142 }
143
144 fn reset(&mut self) {
145 self.window.clear();
146 }
147
148 fn warmup_period(&self) -> usize {
149 self.period
150 }
151
152 fn is_ready(&self) -> bool {
153 self.window.len() == self.period
154 }
155
156 fn name(&self) -> &'static str {
157 "LinRegChannel"
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn rejects_period_below_two() {
        assert!(LinRegChannel::new(0, 2.0).is_err());
        assert!(LinRegChannel::new(1, 2.0).is_err());
        assert!(LinRegChannel::new(2, 2.0).is_ok());
    }

    #[test]
    fn rejects_non_positive_multiplier() {
        assert!(matches!(
            LinRegChannel::new(20, 0.0),
            Err(Error::NonPositiveMultiplier)
        ));
        assert!(matches!(
            LinRegChannel::new(20, -1.0),
            Err(Error::NonPositiveMultiplier)
        ));
        assert!(matches!(
            LinRegChannel::new(20, f64::NAN),
            Err(Error::NonPositiveMultiplier)
        ));
        // Infinite multipliers are rejected too: the constructor requires a
        // finite value, not just a positive one.
        assert!(matches!(
            LinRegChannel::new(20, f64::INFINITY),
            Err(Error::NonPositiveMultiplier)
        ));
        assert!(matches!(
            LinRegChannel::new(20, f64::NEG_INFINITY),
            Err(Error::NonPositiveMultiplier)
        ));
    }

    #[test]
    fn accessors_and_metadata() {
        let lc = LinRegChannel::new(20, 2.0).unwrap();
        assert_eq!(lc.period(), 20);
        assert_relative_eq!(lc.multiplier(), 2.0, epsilon = 1e-12);
        assert_eq!(lc.warmup_period(), 20);
        assert_eq!(lc.name(), "LinRegChannel");
    }

    #[test]
    fn emits_only_after_warmup() {
        let mut lc = LinRegChannel::new(3, 2.0).unwrap();
        let out = lc.batch(&[1.0, 2.0, 9.0, 4.0]);
        assert!(out[0].is_none());
        assert!(out[1].is_none());
        assert!(out[2].is_some());
        assert!(out[3].is_some());
    }

    /// Residuals of an exact line are zero, so both bands sit on the middle.
    #[test]
    fn perfect_line_collapses_channel() {
        let prices: Vec<f64> = (0..40).map(|i| 2.0 * f64::from(i) + 5.0).collect();
        let mut lc = LinRegChannel::new(10, 2.0).unwrap();
        for o in lc.batch(&prices).into_iter().flatten() {
            assert_relative_eq!(o.upper, o.middle, epsilon = 1e-9);
            assert_relative_eq!(o.middle, o.lower, epsilon = 1e-9);
        }
    }

    #[test]
    fn constant_series_collapses_channel() {
        let mut lc = LinRegChannel::new(8, 2.0).unwrap();
        let out = lc.batch(&[42.0; 20]);
        let v = out.iter().rev().flatten().next().unwrap();
        assert_relative_eq!(v.middle, 42.0, epsilon = 1e-9);
        assert_relative_eq!(v.upper, 42.0, epsilon = 1e-9);
        assert_relative_eq!(v.lower, 42.0, epsilon = 1e-9);
    }

    #[test]
    fn upper_above_middle_above_lower() {
        let prices: Vec<f64> = (0..80)
            .map(|i| 100.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let mut lc = LinRegChannel::new(20, 2.0).unwrap();
        for o in lc.batch(&prices).into_iter().flatten() {
            assert!(o.upper >= o.middle);
            assert!(o.middle >= o.lower);
        }
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 50.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let mut a = LinRegChannel::new(14, 2.0).unwrap();
        let mut b = LinRegChannel::new(14, 2.0).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn reset_clears_state() {
        let mut lc = LinRegChannel::new(5, 2.0).unwrap();
        lc.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(lc.is_ready());
        lc.reset();
        assert!(!lc.is_ready());
        assert_eq!(lc.update(1.0), None);
    }

    /// Non-finite inputs are skipped: they never enter the window, and they
    /// yield `None` both during and after warm-up.
    #[test]
    fn non_finite_input_is_ignored() {
        let mut lc = LinRegChannel::new(3, 2.0).unwrap();
        assert_eq!(lc.update(1.0), None);
        assert_eq!(lc.update(f64::NAN), None);
        assert_eq!(lc.update(2.0), None);
        assert_eq!(lc.update(f64::INFINITY), None);
        assert!(!lc.is_ready());
        // Window is now [1, 2, 9], identical to `reference_values`.
        let v = lc.update(9.0).unwrap();
        assert_relative_eq!(v.middle, 8.0, epsilon = 1e-9);
        assert_eq!(lc.update(f64::NEG_INFINITY), None);
        assert!(lc.is_ready());
    }

    /// Hand-computed case: period 3, inputs [1, 2, 9] at x = 0, 1, 2.
    /// slope = 4 and intercept = 0, so fitted = [0, 4, 8] and
    /// residuals = [1, -2, 1]. Then sigma = sqrt((1 + 4 + 1) / 3) = sqrt(2),
    /// and middle = fitted value at x = 2 = 8.
    #[test]
    fn reference_values() {
        let mut lc = LinRegChannel::new(3, 2.0).unwrap();
        let out = lc.batch(&[1.0, 2.0, 9.0]);
        let v = out[2].unwrap();
        let s2 = f64::sqrt(2.0);
        assert_relative_eq!(v.middle, 8.0, epsilon = 1e-9);
        assert_relative_eq!(v.upper, 8.0 + 2.0 * s2, epsilon = 1e-9);
        assert_relative_eq!(v.lower, 8.0 - 2.0 * s2, epsilon = 1e-9);
    }
}