wickra_core/indicators/
linreg_channel.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// One output sample of [`LinRegChannel`].
///
/// `middle` is the fitted regression line evaluated at the newest value in the
/// window; `upper` and `lower` are offset symmetrically from it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LinRegChannelOutput {
    /// Upper band: `middle + multiplier * sigma`.
    pub upper: f64,
    /// Regression line value at the most recent bar of the window.
    pub middle: f64,
    /// Lower band: `middle - multiplier * sigma`.
    pub lower: f64,
}
20
/// Rolling linear-regression channel.
///
/// Over the most recent `period` inputs it fits an ordinary least-squares
/// line, reports that line's value at the newest input as the middle band,
/// and places the upper/lower bands `multiplier` population standard
/// deviations of the residuals above and below it.
#[derive(Debug, Clone)]
pub struct LinRegChannel {
    /// Window length; `LinRegChannel::new` rejects values below 2.
    period: usize,
    /// Band half-width in residual standard deviations (finite, positive).
    multiplier: f64,
    /// Most recent inputs, oldest at the front.
    window: VecDeque<f64>,
    /// Sum of x for x = 0..period; depends only on `period`, so precomputed.
    sum_x: f64,
    /// Sum of x^2 for x = 0..period; precomputed for the same reason.
    sum_xx: f64,
}
60
61impl LinRegChannel {
62 pub fn new(period: usize, multiplier: f64) -> Result<Self> {
67 if period < 2 {
68 return Err(Error::InvalidPeriod {
69 message: "linear regression channel needs period >= 2",
70 });
71 }
72 if !multiplier.is_finite() || multiplier <= 0.0 {
73 return Err(Error::NonPositiveMultiplier);
74 }
75 let n = period as f64;
76 Ok(Self {
77 period,
78 multiplier,
79 window: VecDeque::with_capacity(period),
80 sum_x: n * (n - 1.0) / 2.0,
81 sum_xx: (n - 1.0) * n * (2.0 * n - 1.0) / 6.0,
82 })
83 }
84
85 pub const fn period(&self) -> usize {
87 self.period
88 }
89
90 pub const fn multiplier(&self) -> f64 {
92 self.multiplier
93 }
94}
95
96impl Indicator for LinRegChannel {
97 type Input = f64;
98 type Output = LinRegChannelOutput;
99
100 fn update(&mut self, value: f64) -> Option<LinRegChannelOutput> {
101 if self.window.len() == self.period {
102 self.window.pop_front();
103 }
104 self.window.push_back(value);
105 if self.window.len() < self.period {
106 return None;
107 }
108 let n = self.period as f64;
114 let mut sum_y = 0.0;
115 let mut sum_xy = 0.0;
116 for (i, &y) in self.window.iter().enumerate() {
117 let x = i as f64;
118 sum_y += y;
119 sum_xy += x * y;
120 }
121 let denom = n * self.sum_xx - self.sum_x * self.sum_x;
122 let slope = (n * sum_xy - self.sum_x * sum_y) / denom;
123 let intercept = (sum_y - slope * self.sum_x) / n;
124
125 let mut sum_sq = 0.0;
127 for (i, &y) in self.window.iter().enumerate() {
128 let fitted = intercept + slope * (i as f64);
129 let r = y - fitted;
130 sum_sq += r * r;
131 }
132 let sigma = (sum_sq / n).sqrt();
133 let middle = intercept + slope * (n - 1.0);
134 Some(LinRegChannelOutput {
135 upper: middle + self.multiplier * sigma,
136 middle,
137 lower: middle - self.multiplier * sigma,
138 })
139 }
140
141 fn reset(&mut self) {
142 self.window.clear();
143 }
144
145 fn warmup_period(&self) -> usize {
146 self.period
147 }
148
149 fn is_ready(&self) -> bool {
150 self.window.len() == self.period
151 }
152
153 fn name(&self) -> &'static str {
154 "LinRegChannel"
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Runs `input` through `lc` and keeps only the post-warm-up outputs.
    fn ready_outputs(lc: &mut LinRegChannel, input: &[f64]) -> Vec<LinRegChannelOutput> {
        lc.batch(input).into_iter().flatten().collect()
    }

    #[test]
    fn rejects_period_below_two() {
        for bad_period in [0, 1] {
            assert!(LinRegChannel::new(bad_period, 2.0).is_err());
        }
        assert!(LinRegChannel::new(2, 2.0).is_ok());
    }

    #[test]
    fn rejects_non_positive_multiplier() {
        for bad_multiplier in [0.0, -1.0, f64::NAN] {
            assert!(matches!(
                LinRegChannel::new(20, bad_multiplier),
                Err(Error::NonPositiveMultiplier)
            ));
        }
    }

    #[test]
    fn accessors_and_metadata() {
        let channel = LinRegChannel::new(20, 2.0).unwrap();
        assert_eq!(channel.period(), 20);
        assert_relative_eq!(channel.multiplier(), 2.0, epsilon = 1e-12);
        assert_eq!(channel.warmup_period(), 20);
        assert_eq!(channel.name(), "LinRegChannel");
    }

    #[test]
    fn perfect_line_collapses_channel() {
        // y = 2x + 5 is fitted exactly, so every residual (and sigma) is zero.
        let line: Vec<f64> = (0..40).map(|i| 2.0 * f64::from(i) + 5.0).collect();
        let mut channel = LinRegChannel::new(10, 2.0).unwrap();
        for out in ready_outputs(&mut channel, &line) {
            assert_relative_eq!(out.upper, out.middle, epsilon = 1e-9);
            assert_relative_eq!(out.middle, out.lower, epsilon = 1e-9);
        }
    }

    #[test]
    fn constant_series_collapses_channel() {
        let mut channel = LinRegChannel::new(8, 2.0).unwrap();
        let last = *ready_outputs(&mut channel, &[42.0; 20]).last().unwrap();
        for band in [last.middle, last.upper, last.lower] {
            assert_relative_eq!(band, 42.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn upper_above_middle_above_lower() {
        let wave: Vec<f64> = (0..80)
            .map(|i| 100.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let mut channel = LinRegChannel::new(20, 2.0).unwrap();
        for out in ready_outputs(&mut channel, &wave) {
            assert!(out.upper >= out.middle);
            assert!(out.middle >= out.lower);
        }
    }

    #[test]
    fn batch_equals_streaming() {
        let wave: Vec<f64> = (0..60)
            .map(|i| 50.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let mut batched = LinRegChannel::new(14, 2.0).unwrap();
        let mut streaming = LinRegChannel::new(14, 2.0).unwrap();
        let streamed: Vec<_> = wave.iter().map(|&p| streaming.update(p)).collect();
        assert_eq!(batched.batch(&wave), streamed);
    }

    #[test]
    fn reset_clears_state() {
        let mut channel = LinRegChannel::new(5, 2.0).unwrap();
        channel.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(channel.is_ready());

        channel.reset();
        assert!(!channel.is_ready());
        assert_eq!(channel.update(1.0), None);
    }

    #[test]
    fn reference_values() {
        // Fit over x = 0, 1, 2 of y = 1, 2, 9 gives slope 4 and intercept 0.
        // Residuals are 1, -2, 1, so sigma = sqrt(6 / 3) = sqrt(2) and the
        // line at x = 2 is 8.
        let mut channel = LinRegChannel::new(3, 2.0).unwrap();
        let outputs = channel.batch(&[1.0, 2.0, 9.0]);
        let third = outputs[2].unwrap();
        let root_two = f64::sqrt(2.0);
        assert_relative_eq!(third.middle, 8.0, epsilon = 1e-9);
        assert_relative_eq!(third.upper, 8.0 + 2.0 * root_two, epsilon = 1e-9);
        assert_relative_eq!(third.lower, 8.0 - 2.0 * root_two, epsilon = 1e-9);
    }
}