wickra_core/indicators/
sma.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
/// Simple moving average over the most recent `period` finite samples.
///
/// The samples live in a FIFO window alongside their running sum, so each
/// update is O(1). To stop floating-point rounding from accumulating over long
/// streams, the running sum is periodically rebuilt from the window contents
/// (every `RECOMPUTE_EVERY * period` updates).
#[derive(Debug, Clone)]
pub struct Sma {
    /// Window length; `Sma::new` guarantees it is non-zero.
    period: usize,
    /// Accepted samples, oldest at the front; holds at most `period` values.
    window: VecDeque<f64>,
    /// Running sum of the values currently held in `window`.
    sum: f64,
    /// Updates applied to `sum` since it was last rebuilt from `window`.
    updates_since_recompute: usize,
}

/// Rebuild interval for the running sum, expressed in multiples of the period:
/// after this many full window turnovers the sum is recomputed from scratch to
/// discard accumulated rounding error.
const RECOMPUTE_EVERY: usize = 16;
51impl Sma {
52 pub fn new(period: usize) -> Result<Self> {
58 if period == 0 {
59 return Err(Error::PeriodZero);
60 }
61 Ok(Self {
62 period,
63 window: VecDeque::with_capacity(period),
64 sum: 0.0,
65 updates_since_recompute: 0,
66 })
67 }
68
69 pub const fn period(&self) -> usize {
71 self.period
72 }
73
74 pub fn value(&self) -> Option<f64> {
76 if self.window.len() == self.period {
77 Some(self.sum / self.period as f64)
78 } else {
79 None
80 }
81 }
82}
83
84impl Indicator for Sma {
85 type Input = f64;
86 type Output = f64;
87
88 fn update(&mut self, input: f64) -> Option<f64> {
89 if !input.is_finite() {
90 return self.value();
91 }
92 if self.window.len() == self.period {
93 let old = self.window.pop_front().expect("window non-empty");
97 self.sum -= old;
98 }
99 self.window.push_back(input);
100 self.sum += input;
101 self.updates_since_recompute += 1;
102 if self.updates_since_recompute >= RECOMPUTE_EVERY * self.period {
103 self.sum = self.window.iter().copied().sum();
104 self.updates_since_recompute = 0;
105 }
106 self.value()
107 }
108
109 fn reset(&mut self) {
110 self.window.clear();
111 self.sum = 0.0;
112 self.updates_since_recompute = 0;
113 }
114
115 fn warmup_period(&self) -> usize {
116 self.period
117 }
118
119 fn is_ready(&self) -> bool {
120 self.window.len() == self.period
121 }
122
123 fn name(&self) -> &'static str {
124 "SMA"
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Sma::new(0), Err(Error::PeriodZero)));
    }

    /// The constructor's period is reported back through both the inherent
    /// accessor and the `Indicator` metadata methods.
    #[test]
    fn accessors_and_metadata() {
        let sma = Sma::new(20).unwrap();
        assert_eq!(sma.period(), 20);
        assert_eq!(sma.warmup_period(), 20);
        assert_eq!(sma.name(), "SMA");
    }

    #[test]
    fn warmup_returns_none() {
        let mut sma = Sma::new(3).unwrap();
        assert_eq!(sma.update(1.0), None);
        assert_eq!(sma.update(2.0), None);
        assert_eq!(sma.update(3.0), Some(2.0));
    }

    #[test]
    fn rolls_window_after_full() {
        let mut sma = Sma::new(3).unwrap();
        let out: Vec<_> = [1.0, 2.0, 3.0, 4.0, 5.0]
            .iter()
            .map(|p| sma.update(*p))
            .collect();
        assert_eq!(out, vec![None, None, Some(2.0), Some(3.0), Some(4.0)]);
    }

    #[test]
    fn period_one_is_pass_through() {
        let mut sma = Sma::new(1).unwrap();
        assert_eq!(sma.update(5.0), Some(5.0));
        assert_eq!(sma.update(10.0), Some(10.0));
    }

    #[test]
    fn ignores_non_finite_input_but_keeps_state() {
        let mut sma = Sma::new(3).unwrap();
        sma.update(1.0);
        sma.update(2.0);
        sma.update(3.0);
        assert_eq!(sma.update(f64::NAN), Some(2.0));
        assert_eq!(sma.update(f64::INFINITY), Some(2.0));
        // The NaN/inf samples never entered the window, so 6.0 evicts the
        // 1.0 pushed first and the window becomes [2, 3, 6].
        assert_eq!(sma.update(6.0), Some((2.0 + 3.0 + 6.0) / 3.0));
    }

    #[test]
    fn reset_clears_state() {
        let mut sma = Sma::new(3).unwrap();
        sma.batch(&[1.0, 2.0, 3.0]);
        assert!(sma.is_ready());
        sma.reset();
        assert!(!sma.is_ready());
        assert_eq!(sma.update(10.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut a = Sma::new(5).unwrap();
        let batch = a.batch(&prices);
        let mut b = Sma::new(5).unwrap();
        let streamed: Vec<_> = prices.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }

    #[test]
    fn known_reference_values() {
        let mut sma = Sma::new(3).unwrap();
        // Means of consecutive 3-sample windows of an arithmetic series.
        let out = sma.batch(&[2.0, 4.0, 6.0, 8.0, 10.0]);
        assert_eq!(out[2], Some(4.0));
        assert_eq!(out[3], Some(6.0));
        assert_eq!(out[4], Some(8.0));
    }

    #[test]
    fn constant_series_yields_constant_sma() {
        let mut sma = Sma::new(5).unwrap();
        let v = sma.batch(&[7.0; 10]);
        for x in v.iter().skip(4) {
            assert_relative_eq!(x.unwrap(), 7.0, epsilon = 1e-12);
        }
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(64))]
        #[test]
        fn sma_matches_naive_definition(
            period in 1usize..20,
            prices in proptest::collection::vec(-1000.0_f64..1000.0, 0..200),
        ) {
            let mut sma = Sma::new(period).unwrap();
            let stream: Vec<_> = prices.iter().map(|p| sma.update(*p)).collect();
            for (i, got) in stream.iter().enumerate() {
                if i + 1 < period {
                    proptest::prop_assert!(got.is_none());
                } else {
                    let window = &prices[i + 1 - period..=i];
                    let expected = window.iter().sum::<f64>() / period as f64;
                    let actual = got.expect("ready");
                    proptest::prop_assert!(
                        (actual - expected).abs() < 1e-9,
                        "i={i} actual={actual} expected={expected}"
                    );
                }
            }
        }
    }

    /// Long-stream check that the incrementally maintained sum stays close to
    /// a from-scratch recomputation of the same window.
    ///
    /// The inputs mix magnitudes (~1e9 and ~1) and carry fractional parts, so
    /// running-sum updates actually round. The comparison runs after *every*
    /// update, including those between the periodic full recomputes, so
    /// accumulated drift is observable. Checking only after a multiple of
    /// `RECOMPUTE_EVERY * period` updates would always land on a freshly
    /// rebuilt sum and could never detect drift.
    #[test]
    fn long_stream_drift_stays_bounded() {
        let period = 20;
        let mut sma = Sma::new(period).unwrap();
        let mut window: VecDeque<f64> = VecDeque::with_capacity(period);
        let n_updates = RECOMPUTE_EVERY * period * 5;
        // One recompute cycle is 2 * RECOMPUTE_EVERY * period = 640 rounded
        // ops at magnitude ~1e10 (half-ulp ~9.5e-7). Together with the naive
        // reference's own rounding, that bounds the error on the mean at
        // roughly 3.2e-5.
        let tolerance = 1e-4;
        for i in 0..n_updates {
            let base = if i % 2 == 0 { 1e9 } else { 1.0 };
            let v = base + i as f64 * 1e-3;
            let got = sma.update(v);
            if window.len() == period {
                window.pop_front();
            }
            window.push_back(v);
            if window.len() < period {
                assert_eq!(got, None, "premature output at update {i}");
                continue;
            }
            let from_scratch = window.iter().sum::<f64>() / period as f64;
            let got = got.expect("warmed up");
            assert!(
                (got - from_scratch).abs() < tolerance,
                "SMA drift exceeds {tolerance} at update {i}/{n_updates}: got={got}, scratch={from_scratch}"
            );
        }
    }
}