wickra_core/indicators/
sma.rs1use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// Simple moving average over the most recent `period` finite inputs.
///
/// Samples are kept in a fixed-size ring buffer and a running sum gives O(1)
/// updates. To bound accumulated floating-point error, the sum is periodically
/// rebuilt from the buffer (see [`RECOMPUTE_EVERY`]).
#[derive(Clone, Debug)]
pub struct Sma {
    /// Window length; the constructor rejects `0`.
    period: usize,
    /// Ring buffer of length `period` holding the samples in the window.
    buf: Box<[f64]>,
    /// Slot the next sample is written to; once the window is full this is
    /// also the slot of the oldest sample.
    head: usize,
    /// Number of samples stored so far, never exceeding `period`.
    count: usize,
    /// Running sum of the samples currently in the window.
    sum: f64,
    /// Accepted updates since `sum` was last rebuilt from `buf`.
    updates_since_recompute: usize,
}

/// The running sum is rebuilt from the buffer after
/// `RECOMPUTE_EVERY * period` accepted updates.
const RECOMPUTE_EVERY: usize = 16;

56impl Sma {
57 pub fn new(period: usize) -> Result<Self> {
63 if period == 0 {
64 return Err(Error::PeriodZero);
65 }
66 Ok(Self {
67 period,
68 buf: vec![0.0; period].into_boxed_slice(),
69 head: 0,
70 count: 0,
71 sum: 0.0,
72 updates_since_recompute: 0,
73 })
74 }
75
76 pub const fn period(&self) -> usize {
78 self.period
79 }
80
81 pub fn value(&self) -> Option<f64> {
83 if self.count == self.period {
84 Some(self.sum / self.period as f64)
85 } else {
86 None
87 }
88 }
89}
90
91impl Indicator for Sma {
92 type Input = f64;
93 type Output = f64;
94
95 fn update(&mut self, input: f64) -> Option<f64> {
96 if !input.is_finite() {
97 return self.value();
98 }
99 if self.count == self.period {
100 self.sum -= self.buf[self.head];
104 self.buf[self.head] = input;
105 self.sum += input;
106 } else {
107 self.buf[self.head] = input;
108 self.sum += input;
109 self.count += 1;
110 }
111 self.head += 1;
113 if self.head == self.period {
114 self.head = 0;
115 }
116 self.updates_since_recompute += 1;
117 if self.updates_since_recompute >= RECOMPUTE_EVERY * self.period {
118 self.sum = self.buf[self.head..]
121 .iter()
122 .chain(&self.buf[..self.head])
123 .copied()
124 .sum();
125 self.updates_since_recompute = 0;
126 }
127 self.value()
128 }
129
130 fn reset(&mut self) {
131 self.head = 0;
132 self.count = 0;
133 self.sum = 0.0;
134 self.updates_since_recompute = 0;
135 }
136
137 fn warmup_period(&self) -> usize {
138 self.period
139 }
140
141 fn is_ready(&self) -> bool {
142 self.count == self.period
143 }
144
145 fn name(&self) -> &'static str {
146 "SMA"
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;
    use std::collections::VecDeque;

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Sma::new(0), Err(Error::PeriodZero)));
    }

    #[test]
    fn accessors_and_metadata() {
        let sma = Sma::new(20).unwrap();
        assert_eq!(sma.period(), 20);
        assert_eq!(sma.warmup_period(), 20);
        assert_eq!(sma.name(), "SMA");
    }

    #[test]
    fn warmup_returns_none() {
        let mut sma = Sma::new(3).unwrap();
        assert_eq!(sma.update(1.0), None);
        assert_eq!(sma.update(2.0), None);
        assert_eq!(sma.update(3.0), Some(2.0));
    }

    #[test]
    fn rolls_window_after_full() {
        let mut sma = Sma::new(3).unwrap();
        let out: Vec<_> = [1.0, 2.0, 3.0, 4.0, 5.0]
            .iter()
            .map(|p| sma.update(*p))
            .collect();
        assert_eq!(out, vec![None, None, Some(2.0), Some(3.0), Some(4.0)]);
    }

    #[test]
    fn period_one_is_pass_through() {
        let mut sma = Sma::new(1).unwrap();
        assert_eq!(sma.update(5.0), Some(5.0));
        assert_eq!(sma.update(10.0), Some(10.0));
    }

    #[test]
    fn ignores_non_finite_input_but_keeps_state() {
        let mut sma = Sma::new(3).unwrap();
        sma.update(1.0);
        sma.update(2.0);
        sma.update(3.0);
        assert_eq!(sma.update(f64::NAN), Some(2.0));
        assert_eq!(sma.update(f64::INFINITY), Some(2.0));
        assert_eq!(sma.update(6.0), Some((2.0 + 3.0 + 6.0) / 3.0));
    }

    #[test]
    fn reset_clears_state() {
        let mut sma = Sma::new(3).unwrap();
        sma.batch(&[1.0, 2.0, 3.0]);
        assert!(sma.is_ready());
        sma.reset();
        assert!(!sma.is_ready());
        assert_eq!(sma.update(10.0), None);
    }

    /// `reset` leaves stale samples in the ring buffer. The average after
    /// refilling must depend only on the post-reset inputs.
    #[test]
    fn reset_discards_previous_window() {
        let mut sma = Sma::new(3).unwrap();
        // Four updates so the ring has wrapped and `head` is mid-buffer.
        sma.batch(&[100.0, 200.0, 300.0, 400.0]);
        sma.reset();
        assert_eq!(sma.update(1.0), None);
        assert_eq!(sma.update(2.0), None);
        assert_eq!(sma.update(3.0), Some(2.0));
        assert_eq!(sma.value(), Some(2.0));
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut a = Sma::new(5).unwrap();
        let batch = a.batch(&prices);
        let mut b = Sma::new(5).unwrap();
        let streamed: Vec<_> = prices.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }

    #[test]
    fn known_reference_values() {
        let mut sma = Sma::new(3).unwrap();
        let out = sma.batch(&[2.0, 4.0, 6.0, 8.0, 10.0]);
        assert_eq!(out[2], Some(4.0));
        assert_eq!(out[3], Some(6.0));
        assert_eq!(out[4], Some(8.0));
    }

    #[test]
    fn constant_series_yields_constant_sma() {
        let mut sma = Sma::new(5).unwrap();
        let v = sma.batch(&[7.0; 10]);
        for x in v.iter().skip(4) {
            assert_relative_eq!(x.unwrap(), 7.0, epsilon = 1e-12);
        }
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(64))]
        #[test]
        fn sma_matches_naive_definition(
            period in 1usize..20,
            prices in proptest::collection::vec(-1000.0_f64..1000.0, 0..200),
        ) {
            let mut sma = Sma::new(period).unwrap();
            let stream: Vec<_> = prices.iter().map(|p| sma.update(*p)).collect();
            for (i, got) in stream.iter().enumerate() {
                if i + 1 < period {
                    proptest::prop_assert!(got.is_none());
                } else {
                    let window = &prices[i + 1 - period..=i];
                    let expected = window.iter().sum::<f64>() / period as f64;
                    let actual = got.expect("ready");
                    proptest::prop_assert!(
                        (actual - expected).abs() < 1e-9,
                        "i={i} actual={actual} expected={expected}"
                    );
                }
            }
        }
    }

    /// Alternating 1e9 / 1.0 inputs stress cancellation in the running sum.
    /// After several full recompute cycles, the streamed value must still match
    /// a from-scratch average of the final window.
    #[test]
    fn long_stream_drift_stays_bounded() {
        let period = 20;
        let mut sma = Sma::new(period).unwrap();
        let mut window: VecDeque<f64> = VecDeque::with_capacity(period);
        let n_updates = RECOMPUTE_EVERY * period * 5;
        for i in 0..n_updates {
            let v = if i % 2 == 0 { 1e9 } else { 1.0 };
            sma.update(v);
            if window.len() == period {
                window.pop_front();
            }
            window.push_back(v);
        }
        let from_scratch: f64 = window.iter().sum::<f64>() / period as f64;
        let got = sma.value().expect("warmed up");
        assert!(
            (got - from_scratch).abs() < 1e-6,
            "SMA drift exceeds 1e-6 over {n_updates} updates: got={got}, scratch={from_scratch}"
        );
    }
}