1use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// A single Bollinger Bands reading computed over one full window.
///
/// `upper` and `lower` sit symmetrically around `middle`, offset by the indicator's
/// multiplier times `stddev`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BollingerOutput {
    /// Upper band: `middle + multiplier * stddev`.
    pub upper: f64,
    /// Middle band: the arithmetic mean of the window.
    pub middle: f64,
    /// Lower band: `middle - multiplier * stddev`.
    pub lower: f64,
    /// Population standard deviation of the window (the variance divides by the period).
    pub stddev: f64,
}
19
/// Streaming Bollinger Bands over a fixed-length window of prices.
///
/// The window is kept in a ring buffer alongside running sums of the samples and of their
/// squares, so each update costs O(1). The sums are periodically rebuilt from the buffer to
/// bound accumulated floating-point error.
#[derive(Clone, Debug)]
pub struct BollingerBands {
    /// Number of samples in a full window.
    period: usize,
    /// Band half-width expressed in standard deviations.
    multiplier: f64,
    /// Ring buffer of length `period` holding the most recent samples.
    buf: Box<[f64]>,
    /// Slot the next sample is written to; once the window is full this is also the
    /// position of the oldest sample.
    head: usize,
    /// Number of live samples in `buf`; climbs to `period` and stays there.
    count: usize,
    /// Running sum of the live samples.
    sum: f64,
    /// Running sum of the squares of the live samples.
    sum_sq: f64,
    /// Accepted (finite) updates since the sums were last rebuilt from `buf`.
    updates_since_recompute: usize,
}
63
/// How often, in multiples of the window length, the running sums are rebuilt from the
/// ring buffer to discard floating-point error accumulated by incremental add/subtract.
const RECOMPUTE_EVERY: usize = 16;
69
70impl BollingerBands {
71 pub fn new(period: usize, multiplier: f64) -> Result<Self> {
78 if period == 0 {
79 return Err(Error::PeriodZero);
80 }
81 if !multiplier.is_finite() || multiplier <= 0.0 {
82 return Err(Error::NonPositiveMultiplier);
83 }
84 Ok(Self {
85 period,
86 multiplier,
87 buf: vec![0.0; period].into_boxed_slice(),
88 head: 0,
89 count: 0,
90 sum: 0.0,
91 sum_sq: 0.0,
92 updates_since_recompute: 0,
93 })
94 }
95
96 pub fn classic() -> Self {
98 Self::new(20, 2.0).expect("classic Bollinger parameters are valid")
99 }
100
101 pub const fn period(&self) -> usize {
103 self.period
104 }
105
106 pub const fn multiplier(&self) -> f64 {
108 self.multiplier
109 }
110
111 fn current(&self) -> Option<BollingerOutput> {
112 if self.count != self.period {
113 return None;
114 }
115 let n = self.period as f64;
116 let mean = self.sum / n;
117 let var = (self.sum_sq / n - mean * mean).max(0.0);
120 let stddev = var.sqrt();
121 Some(BollingerOutput {
122 upper: mean + self.multiplier * stddev,
123 middle: mean,
124 lower: mean - self.multiplier * stddev,
125 stddev,
126 })
127 }
128}
129
130impl Indicator for BollingerBands {
131 type Input = f64;
132 type Output = BollingerOutput;
133
134 fn update(&mut self, input: f64) -> Option<BollingerOutput> {
135 if !input.is_finite() {
136 return self.current();
137 }
138 if self.count == self.period {
139 let old = self.buf[self.head];
140 self.sum -= old;
141 self.sum_sq -= old * old;
142 self.buf[self.head] = input;
143 self.sum += input;
144 self.sum_sq += input * input;
145 } else {
146 self.buf[self.head] = input;
147 self.sum += input;
148 self.sum_sq += input * input;
149 self.count += 1;
150 }
151 self.head += 1;
152 if self.head == self.period {
153 self.head = 0;
154 }
155 self.updates_since_recompute += 1;
156 if self.updates_since_recompute >= RECOMPUTE_EVERY * self.period {
157 let chronological = self.buf[self.head..].iter().chain(&self.buf[..self.head]);
160 self.sum = chronological.clone().copied().sum();
161 self.sum_sq = chronological.map(|&x| x * x).sum();
162 self.updates_since_recompute = 0;
163 }
164 self.current()
165 }
166
167 fn reset(&mut self) {
168 self.head = 0;
169 self.count = 0;
170 self.sum = 0.0;
171 self.sum_sq = 0.0;
172 self.updates_since_recompute = 0;
173 }
174
175 fn warmup_period(&self) -> usize {
176 self.period
177 }
178
179 fn is_ready(&self) -> bool {
180 self.count == self.period
181 }
182
183 fn name(&self) -> &'static str {
184 "BollingerBands"
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;
    use std::collections::VecDeque;

    /// From-scratch reference: two-pass population mean and standard deviation over the
    /// last `period` entries of `prices`, with bands at `mean ± mult * stddev`.
    ///
    /// Panics if `prices` holds fewer than `period` values.
    fn naive(prices: &[f64], period: usize, mult: f64) -> BollingerOutput {
        assert!(
            prices.len() >= period,
            "naive requires at least `period` prices"
        );
        let w = &prices[prices.len() - period..];
        let mean = w.iter().sum::<f64>() / period as f64;
        let var = w.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / period as f64;
        let s = var.sqrt();
        BollingerOutput {
            upper: mean + mult * s,
            middle: mean,
            lower: mean - mult * s,
            stddev: s,
        }
    }

    #[test]
    fn rejects_zero_period() {
        assert!(matches!(
            BollingerBands::new(0, 2.0),
            Err(Error::PeriodZero)
        ));
    }

    #[test]
    fn rejects_non_positive_multiplier() {
        assert!(matches!(
            BollingerBands::new(20, 0.0),
            Err(Error::NonPositiveMultiplier)
        ));
        assert!(matches!(
            BollingerBands::new(20, -1.0),
            Err(Error::NonPositiveMultiplier)
        ));
        assert!(matches!(
            BollingerBands::new(20, f64::NAN),
            Err(Error::NonPositiveMultiplier)
        ));
    }

    #[test]
    fn classic_and_accessors_and_metadata() {
        let bb = BollingerBands::classic();
        assert_eq!(bb.period(), 20);
        assert_relative_eq!(bb.multiplier(), 2.0, epsilon = 1e-12);
        assert_eq!(bb.warmup_period(), 20);
        assert_eq!(bb.name(), "BollingerBands");
    }

    #[test]
    fn warmup_returns_none() {
        let mut bb = BollingerBands::new(5, 2.0).unwrap();
        for v in [1.0, 2.0, 3.0, 4.0] {
            assert!(bb.update(v).is_none());
        }
        assert!(bb.update(5.0).is_some());
    }

    #[test]
    fn constant_series_yields_zero_stddev() {
        let mut bb = BollingerBands::new(10, 2.0).unwrap();
        let out = bb.batch(&[5.0_f64; 30]);
        let last = out.iter().rev().flatten().next().unwrap();
        assert_relative_eq!(last.middle, 5.0, epsilon = 1e-12);
        assert_relative_eq!(last.stddev, 0.0, epsilon = 1e-12);
        assert_relative_eq!(last.upper, 5.0, epsilon = 1e-12);
        assert_relative_eq!(last.lower, 5.0, epsilon = 1e-12);
    }

    #[test]
    fn matches_naive_definition() {
        let prices: Vec<f64> = (1..=60)
            .map(|i| (f64::from(i) * 0.3).sin() * 10.0 + 50.0)
            .collect();
        let mut bb = BollingerBands::new(20, 2.0).unwrap();
        let out = bb.batch(&prices);
        for i in 19..prices.len() {
            let got = out[i].unwrap();
            let want = naive(&prices[..=i], 20, 2.0);
            assert_relative_eq!(got.middle, want.middle, epsilon = 1e-9);
            assert_relative_eq!(got.stddev, want.stddev, epsilon = 1e-9);
            assert_relative_eq!(got.upper, want.upper, epsilon = 1e-9);
            assert_relative_eq!(got.lower, want.lower, epsilon = 1e-9);
        }
    }

    #[test]
    fn upper_above_middle_above_lower() {
        let prices: Vec<f64> = (1..=100).map(f64::from).collect();
        let mut bb = BollingerBands::new(20, 2.0).unwrap();
        for o in bb.batch(&prices).into_iter().flatten() {
            assert!(o.upper >= o.middle);
            assert!(o.middle >= o.lower);
        }
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=50).map(|i| f64::from(i) * 0.7).collect();
        let mut a = BollingerBands::new(10, 2.0).unwrap();
        let mut b = BollingerBands::new(10, 2.0).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn reset_clears_state() {
        let mut bb = BollingerBands::new(5, 2.0).unwrap();
        bb.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(bb.is_ready());
        bb.reset();
        assert!(!bb.is_ready());

        // After a reset the indicator must behave exactly like a fresh one: no stale
        // samples or sums may leak into the next window.
        let next = [10.0, 20.0, 30.0, 40.0, 50.0];
        let mut fresh = BollingerBands::new(5, 2.0).unwrap();
        assert_eq!(bb.batch(&next), fresh.batch(&next));
    }

    /// Alternating extreme values stress the incremental add/subtract path.
    ///
    /// The stream stops one update short of a scheduled rebuild. The final reading therefore
    /// comes from `RECOMPUTE_EVERY * period - 1` incremental updates since the last rebuild,
    /// which is the point of maximum accumulated drift, rather than from a fresh rebuild.
    #[test]
    fn long_stream_drift_stays_bounded() {
        let period = 20;
        let mult = 2.0;
        let mut bb = BollingerBands::new(period, mult).unwrap();
        let mut window: VecDeque<f64> = VecDeque::with_capacity(period);
        let n_updates = RECOMPUTE_EVERY * period * 5 - 1;
        let mut last = None;
        for i in 0..n_updates {
            let v = if i % 2 == 0 { 1e6 } else { 1.0 };
            last = bb.update(v);
            if window.len() == period {
                window.pop_front();
            }
            window.push_back(v);
        }
        let scratch = naive(&window.iter().copied().collect::<Vec<_>>(), period, mult);
        let got = last.expect("warmed up");
        assert!(
            (got.middle - scratch.middle).abs() < 1e-3,
            "middle drift: got={}, scratch={}",
            got.middle,
            scratch.middle,
        );
        assert!(
            (got.stddev - scratch.stddev).abs() < 1e-3,
            "stddev drift: got={}, scratch={}",
            got.stddev,
            scratch.stddev,
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut bb = BollingerBands::new(5, 2.0).unwrap();
        let ready = bb.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let last = ready.last().unwrap().unwrap();
        // Every flavour of non-finite input is skipped and echoes the previous reading.
        assert_eq!(bb.update(f64::NAN).unwrap(), last);
        assert_eq!(bb.update(f64::INFINITY).unwrap(), last);
        assert_eq!(bb.update(f64::NEG_INFINITY).unwrap(), last);
        let after = bb.update(6.0).unwrap();
        // The skipped inputs must not occupy a window slot: 1.0 is evicted by 6.0.
        assert_relative_eq!(
            after.middle,
            (2.0 + 3.0 + 4.0 + 5.0 + 6.0) / 5.0,
            epsilon = 1e-12
        );
    }
}