1use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// One complete Bollinger Bands reading.
///
/// The outer bands sit `multiplier * stddev` above and below `middle`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BollingerOutput {
    /// Upper band: `middle + multiplier * stddev`.
    pub upper: f64,
    /// Middle band: arithmetic mean of the samples in the window.
    pub middle: f64,
    /// Lower band: `middle - multiplier * stddev`.
    pub lower: f64,
    /// Population standard deviation of the samples in the window.
    pub stddev: f64,
}
19
/// Streaming Bollinger Bands over a fixed-length sliding window.
///
/// Keeps a ring buffer of the most recent samples together with running sums
/// of the samples and of their squares, so each update costs O(1) apart from
/// the periodic rebuild of those sums from the buffer.
#[derive(Clone, Debug)]
pub struct BollingerBands {
    /// Window length, in samples.
    period: usize,
    /// Distance of the outer bands from the mean, in standard deviations.
    multiplier: f64,
    /// Ring buffer holding the window; always `period` slots long.
    buf: Box<[f64]>,
    /// Slot the next sample is written to (the oldest sample once the window is full).
    head: usize,
    /// Number of samples currently held; stops growing at `period`.
    count: usize,
    /// Running sum of the samples in the window.
    sum: f64,
    /// Running sum of the squared samples in the window.
    sum_sq: f64,
    /// Samples absorbed since `sum` / `sum_sq` were last rebuilt from `buf`.
    updates_since_recompute: usize,
}
63
/// Rebuild cadence for the running sums, measured in window lengths.
///
/// Once `RECOMPUTE_EVERY * period` samples have been absorbed since the last
/// rebuild, `sum` and `sum_sq` are recomputed from the ring buffer instead of
/// being updated incrementally (presumably to bound accumulated
/// floating-point drift on long streams).
const RECOMPUTE_EVERY: usize = 16;
69
70impl BollingerBands {
71 pub fn new(period: usize, multiplier: f64) -> Result<Self> {
78 if period == 0 {
79 return Err(Error::PeriodZero);
80 }
81 if !multiplier.is_finite() || multiplier <= 0.0 {
82 return Err(Error::NonPositiveMultiplier);
83 }
84 Ok(Self {
85 period,
86 multiplier,
87 buf: vec![0.0; period].into_boxed_slice(),
88 head: 0,
89 count: 0,
90 sum: 0.0,
91 sum_sq: 0.0,
92 updates_since_recompute: 0,
93 })
94 }
95
96 pub fn classic() -> Self {
98 Self::new(20, 2.0).expect("classic Bollinger parameters are valid")
99 }
100
101 pub const fn period(&self) -> usize {
103 self.period
104 }
105
106 pub const fn multiplier(&self) -> f64 {
108 self.multiplier
109 }
110
111 pub fn batch_bands(&mut self, inputs: &[f64]) -> Vec<f64> {
125 let p = self.period;
126 let n = inputs.len();
127 if self.count != 0
128 || self.updates_since_recompute != 0
129 || !inputs.iter().all(|x| x.is_finite())
130 {
131 let mut out = vec![f64::NAN; n * 4];
133 for (i, &x) in inputs.iter().enumerate() {
134 if let Some(o) = self.update(x) {
135 out[i * 4] = o.upper;
136 out[i * 4 + 1] = o.middle;
137 out[i * 4 + 2] = o.lower;
138 out[i * 4 + 3] = o.stddev;
139 }
140 }
141 return out;
142 }
143
144 let p_f64 = p as f64;
145 let mult = self.multiplier;
146 let mut out = vec![f64::NAN; n * 4];
149 for (i, &x) in inputs.iter().enumerate() {
150 if self.count == p {
151 let old = self.buf[self.head];
152 self.sum -= old;
153 self.sum_sq -= old * old;
154 self.buf[self.head] = x;
155 self.sum += x;
156 self.sum_sq += x * x;
157 } else {
158 self.buf[self.head] = x;
159 self.sum += x;
160 self.sum_sq += x * x;
161 self.count += 1;
162 }
163 self.head += 1;
164 if self.head == p {
165 self.head = 0;
166 }
167 self.updates_since_recompute += 1;
168 if self.updates_since_recompute >= RECOMPUTE_EVERY * p {
169 let chronological = self.buf[self.head..].iter().chain(&self.buf[..self.head]);
170 self.sum = chronological.clone().copied().sum();
171 self.sum_sq = chronological.map(|&v| v * v).sum();
172 self.updates_since_recompute = 0;
173 }
174 if self.count == p {
175 let mean = self.sum / p_f64;
176 let stddev = (self.sum_sq / p_f64 - mean * mean).max(0.0).sqrt();
177 let band = mult * stddev;
178 out[i * 4] = mean + band;
179 out[i * 4 + 1] = mean;
180 out[i * 4 + 2] = mean - band;
181 out[i * 4 + 3] = stddev;
182 }
183 }
184 out
185 }
186
187 fn current(&self) -> Option<BollingerOutput> {
188 if self.count != self.period {
189 return None;
190 }
191 let n = self.period as f64;
192 let mean = self.sum / n;
193 let var = (self.sum_sq / n - mean * mean).max(0.0);
196 let stddev = var.sqrt();
197 Some(BollingerOutput {
198 upper: mean + self.multiplier * stddev,
199 middle: mean,
200 lower: mean - self.multiplier * stddev,
201 stddev,
202 })
203 }
204}
205
206impl Indicator for BollingerBands {
207 type Input = f64;
208 type Output = BollingerOutput;
209
210 fn update(&mut self, input: f64) -> Option<BollingerOutput> {
211 if !input.is_finite() {
212 return self.current();
213 }
214 if self.count == self.period {
215 let old = self.buf[self.head];
216 self.sum -= old;
217 self.sum_sq -= old * old;
218 self.buf[self.head] = input;
219 self.sum += input;
220 self.sum_sq += input * input;
221 } else {
222 self.buf[self.head] = input;
223 self.sum += input;
224 self.sum_sq += input * input;
225 self.count += 1;
226 }
227 self.head += 1;
228 if self.head == self.period {
229 self.head = 0;
230 }
231 self.updates_since_recompute += 1;
232 if self.updates_since_recompute >= RECOMPUTE_EVERY * self.period {
233 let chronological = self.buf[self.head..].iter().chain(&self.buf[..self.head]);
236 self.sum = chronological.clone().copied().sum();
237 self.sum_sq = chronological.map(|&x| x * x).sum();
238 self.updates_since_recompute = 0;
239 }
240 self.current()
241 }
242
243 fn reset(&mut self) {
244 self.head = 0;
245 self.count = 0;
246 self.sum = 0.0;
247 self.sum_sq = 0.0;
248 self.updates_since_recompute = 0;
249 }
250
251 fn warmup_period(&self) -> usize {
252 self.period
253 }
254
255 fn is_ready(&self) -> bool {
256 self.count == self.period
257 }
258
259 fn name(&self) -> &'static str {
260 "BollingerBands"
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;
    use std::collections::VecDeque;

    /// Reference bands over the last `period` entries of `prices`, computed
    /// from scratch with the two-pass population-variance formula.
    fn naive(prices: &[f64], period: usize, mult: f64) -> BollingerOutput {
        assert!(
            prices.len() >= period,
            "naive requires at least `period` prices"
        );
        let tail = &prices[prices.len() - period..];
        let len = period as f64;
        let middle = tail.iter().sum::<f64>() / len;
        let variance = tail.iter().map(|p| (p - middle).powi(2)).sum::<f64>() / len;
        let stddev = variance.sqrt();
        BollingerOutput {
            upper: middle + mult * stddev,
            middle,
            lower: middle - mult * stddev,
            stddev,
        }
    }

    #[test]
    fn rejects_zero_period() {
        let built = BollingerBands::new(0, 2.0);
        assert!(matches!(built, Err(Error::PeriodZero)));
    }

    #[test]
    fn rejects_non_positive_multiplier() {
        // Zero, negative and NaN multipliers are all refused with the same error.
        for bad in [0.0, -1.0, f64::NAN] {
            assert!(matches!(
                BollingerBands::new(20, bad),
                Err(Error::NonPositiveMultiplier)
            ));
        }
    }

    #[test]
    fn classic_and_accessors_and_metadata() {
        let classic = BollingerBands::classic();
        assert_eq!(classic.period(), 20);
        assert_relative_eq!(classic.multiplier(), 2.0, epsilon = 1e-12);
        assert_eq!(classic.warmup_period(), 20);
        assert_eq!(classic.name(), "BollingerBands");
    }

    #[test]
    fn warmup_returns_none() {
        let mut bb = BollingerBands::new(5, 2.0).unwrap();
        // The first four samples only fill the window; the fifth completes it.
        let produced: Vec<bool> = [1.0, 2.0, 3.0, 4.0, 5.0]
            .iter()
            .map(|&v| bb.update(v).is_some())
            .collect();
        assert_eq!(produced, [false, false, false, false, true]);
    }

    #[test]
    fn constant_series_yields_zero_stddev() {
        let mut bb = BollingerBands::new(10, 2.0).unwrap();
        let readings = bb.batch(&[5.0_f64; 30]);
        let newest = readings.iter().rev().flatten().next().unwrap();
        assert_relative_eq!(newest.middle, 5.0, epsilon = 1e-12);
        assert_relative_eq!(newest.stddev, 0.0, epsilon = 1e-12);
        assert_relative_eq!(newest.upper, 5.0, epsilon = 1e-12);
        assert_relative_eq!(newest.lower, 5.0, epsilon = 1e-12);
    }

    #[test]
    fn matches_naive_definition() {
        let period = 20;
        let prices: Vec<f64> = (1..=60)
            .map(|i| (f64::from(i) * 0.3).sin() * 10.0 + 50.0)
            .collect();
        let mut bb = BollingerBands::new(period, 2.0).unwrap();
        let readings = bb.batch(&prices);
        // Every index from the first full window onwards must agree with the
        // from-scratch computation.
        for end in (period - 1)..prices.len() {
            let got = readings[end].unwrap();
            let want = naive(&prices[..=end], period, 2.0);
            assert_relative_eq!(got.middle, want.middle, epsilon = 1e-9);
            assert_relative_eq!(got.stddev, want.stddev, epsilon = 1e-9);
            assert_relative_eq!(got.upper, want.upper, epsilon = 1e-9);
            assert_relative_eq!(got.lower, want.lower, epsilon = 1e-9);
        }
    }

    #[test]
    fn upper_above_middle_above_lower() {
        let ramp: Vec<f64> = (1..=100).map(f64::from).collect();
        let mut bb = BollingerBands::new(20, 2.0).unwrap();
        for reading in bb.batch(&ramp).into_iter().flatten() {
            assert!(reading.upper >= reading.middle);
            assert!(reading.middle >= reading.lower);
        }
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=50).map(|i| f64::from(i) * 0.7).collect();
        let mut batched_bb = BollingerBands::new(10, 2.0).unwrap();
        let mut streamed_bb = BollingerBands::new(10, 2.0).unwrap();
        let batched = batched_bb.batch(&prices);
        let streamed: Vec<_> = prices.iter().map(|&p| streamed_bb.update(p)).collect();
        assert_eq!(batched, streamed);
    }

    #[test]
    fn reset_clears_state() {
        let mut bb = BollingerBands::new(5, 2.0).unwrap();
        bb.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(bb.is_ready());
        bb.reset();
        assert!(!bb.is_ready());
    }

    #[test]
    fn long_stream_drift_stays_bounded() {
        let period = 20;
        let mult = 2.0;
        let mut bb = BollingerBands::new(period, mult).unwrap();
        let mut recent: VecDeque<f64> = VecDeque::with_capacity(period);
        // Long enough to cross several rebuilds of the running sums.
        let total = 16 * period * 5;
        let mut latest = None;
        for step in 0..total {
            // Alternating large and small values stress cancellation in the sums.
            let sample = if step % 2 == 0 { 1e6 } else { 1.0 };
            latest = bb.update(sample);
            if recent.len() == period {
                recent.pop_front();
            }
            recent.push_back(sample);
        }
        let tail: Vec<f64> = recent.iter().copied().collect();
        let scratch = naive(&tail, period, mult);
        let got = latest.expect("warmed up");
        let middle_err = (got.middle - scratch.middle).abs();
        assert!(
            middle_err < 1e-3,
            "middle drift: got={}, scratch={}",
            got.middle,
            scratch.middle,
        );
        let stddev_err = (got.stddev - scratch.stddev).abs();
        assert!(
            stddev_err < 1e-3,
            "stddev drift: got={}, scratch={}",
            got.stddev,
            scratch.stddev,
        );
    }

    /// Element-wise `==` over two slices of equal length, with NaN treated as
    /// equal to NaN. Being `==`-based, it does not tell `0.0` from `-0.0`.
    fn bits_eq(a: &[f64], b: &[f64]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter()
            .zip(b)
            .all(|(x, y)| (x.is_nan() && y.is_nan()) || x == y)
    }

    /// Streams `series` through a fresh indicator via `update` and flattens
    /// the readings into the `batch_bands` layout (four NaNs where there is
    /// no reading).
    fn bb_replay(period: usize, mult: f64, series: &[f64]) -> Vec<f64> {
        let mut bb = BollingerBands::new(period, mult).unwrap();
        let mut flat = Vec::with_capacity(series.len() * 4);
        for &sample in series {
            let row = bb
                .update(sample)
                .map_or([f64::NAN; 4], |o| [o.upper, o.middle, o.lower, o.stddev]);
            flat.extend_from_slice(&row);
        }
        flat
    }

    #[test]
    fn batch_bands_fast_path_is_bit_identical_with_reseed() {
        // 500 samples with period 20 is enough to cross one rebuild of the sums.
        let series: Vec<f64> = (0..500)
            .map(|i| (f64::from(i) * 0.2).sin() * 10.0 + 50.0)
            .collect();
        let mut bb = BollingerBands::new(20, 2.0).unwrap();
        let flat = bb.batch_bands(&series);
        assert!(bits_eq(&flat, &bb_replay(20, 2.0, &series)));

        // The internal state must also match a purely streamed indicator.
        let mut streamed = BollingerBands::new(20, 2.0).unwrap();
        for &sample in &series {
            streamed.update(sample);
        }
        assert_eq!(bb.update(55.0), streamed.update(55.0));
    }

    #[test]
    fn batch_bands_falls_back_on_non_finite() {
        let series = [1.0, 2.0, 3.0, f64::NAN, 5.0, 6.0, 7.0];
        let mut bb = BollingerBands::new(3, 2.0).unwrap();
        let flat = bb.batch_bands(&series);
        let want = bb_replay(3, 2.0, &series);
        assert!(bits_eq(&flat, &want));
    }

    #[test]
    fn batch_bands_falls_back_when_not_fresh() {
        let series = [1.0, 2.0, 3.0, 4.0];

        let mut bb = BollingerBands::new(3, 2.0).unwrap();
        bb.update(99.0);

        // Reference: an identically pre-seeded indicator fed sample by sample.
        let mut reference = BollingerBands::new(3, 2.0).unwrap();
        reference.update(99.0);
        let mut want = Vec::new();
        for &sample in &series {
            let row = reference
                .update(sample)
                .map_or([f64::NAN; 4], |o| [o.upper, o.middle, o.lower, o.stddev]);
            want.extend_from_slice(&row);
        }

        assert!(bits_eq(&bb.batch_bands(&series), &want));
    }

    #[test]
    fn batch_bands_sub_period_slice_is_all_nan() {
        let series = [1.0, 2.0, 3.0];
        let mut bb = BollingerBands::new(10, 2.0).unwrap();
        let flat = bb.batch_bands(&series);
        assert!(bits_eq(&flat, &bb_replay(10, 2.0, &series)));
        assert_eq!(flat.len(), 12);
        assert!(flat.iter().all(|v| v.is_nan()));
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut bb = BollingerBands::new(5, 2.0).unwrap();
        let ready = bb.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let last = ready.last().unwrap().unwrap();
        // Non-finite samples leave the window untouched and repeat the reading.
        for junk in [f64::NAN, f64::INFINITY] {
            assert_eq!(bb.update(junk).unwrap(), last);
        }
        // The next finite sample evicts 1.0 as if nothing had happened in between.
        let after = bb.update(6.0).unwrap();
        assert_relative_eq!(
            after.middle,
            (2.0 + 3.0 + 4.0 + 5.0 + 6.0) / 5.0,
            epsilon = 1e-12
        );
    }
}