wickra_core/indicators/
sine_weighted_ma.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Sine-Weighted Moving Average (SWMA).
///
/// Averages the most recent `period` finite inputs using a half-sine weight
/// profile. The weight of the `k`-th sample in the window is
/// `sin(π · k / (period + 1))`, where `k` is 1-based and the oldest sample
/// comes first. Every sine argument lies strictly inside `(0, π)`, so all
/// weights are positive. They are symmetric about the middle of the window,
/// up to floating-point rounding. The output is the weighted sum divided by
/// the sum of the weights.
///
/// No value is produced until `period` finite inputs have been collected.
/// Non-finite inputs (`NaN`, `±∞`) are ignored entirely.
#[derive(Debug, Clone)]
pub struct SineWeightedMa {
    /// Window length; `new` rejects `0`, so this is always at least 1.
    period: usize,
    /// Most recent finite inputs, oldest at the front; never longer than `period`.
    window: VecDeque<f64>,
    /// Precomputed weights, `weights[i] = sin(π · (i + 1) / (period + 1))`.
    weights: Vec<f64>,
    /// Cached sum of `weights`, used as the normalising divisor.
    weights_total: f64,
}
51
52impl SineWeightedMa {
53 pub fn new(period: usize) -> Result<Self> {
59 if period == 0 {
60 return Err(Error::PeriodZero);
61 }
62 let denom = period as f64 + 1.0;
63 let weights: Vec<f64> = (0..period)
64 .map(|i| (std::f64::consts::PI * (i as f64 + 1.0) / denom).sin())
65 .collect();
66 let weights_total = weights.iter().sum();
67 Ok(Self {
68 period,
69 window: VecDeque::with_capacity(period),
70 weights,
71 weights_total,
72 })
73 }
74
75 pub const fn period(&self) -> usize {
77 self.period
78 }
79
80 pub fn value(&self) -> Option<f64> {
82 if self.window.len() == self.period {
83 let dot: f64 = self
84 .window
85 .iter()
86 .zip(&self.weights)
87 .map(|(v, w)| v * w)
88 .sum();
89 Some(dot / self.weights_total)
90 } else {
91 None
92 }
93 }
94}
95
96impl Indicator for SineWeightedMa {
97 type Input = f64;
98 type Output = f64;
99
100 fn update(&mut self, input: f64) -> Option<f64> {
101 if !input.is_finite() {
102 return self.value();
103 }
104 if self.window.len() == self.period {
105 self.window.pop_front();
106 }
107 self.window.push_back(input);
108 self.value()
109 }
110
111 fn reset(&mut self) {
112 self.window.clear();
113 }
114
115 fn warmup_period(&self) -> usize {
116 self.period
117 }
118
119 fn is_ready(&self) -> bool {
120 self.window.len() == self.period
121 }
122
123 fn name(&self) -> &'static str {
124 "SWMA"
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Reference implementation with no streaming state. It recomputes every
    /// full window from scratch against a locally built weight table.
    fn swma_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let denom = period as f64 + 1.0;
        let weights: Vec<f64> = (0..period)
            .map(|i| (std::f64::consts::PI * (i as f64 + 1.0) / denom).sin())
            .collect();
        let total: f64 = weights.iter().sum();
        prices
            .iter()
            .enumerate()
            .map(|(i, _)| {
                if i + 1 < period {
                    None
                } else {
                    let window = &prices[i + 1 - period..=i];
                    let dot: f64 = window.iter().zip(&weights).map(|(v, w)| v * w).sum();
                    Some(dot / total)
                }
            })
            .collect()
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(SineWeightedMa::new(0), Err(Error::PeriodZero)));
    }

    #[test]
    fn accessors_and_metadata() {
        let swma = SineWeightedMa::new(7).unwrap();
        assert_eq!(swma.period(), 7);
        assert_eq!(swma.warmup_period(), 7);
        assert_eq!(swma.name(), "SWMA");
        // A freshly built indicator has no output yet.
        assert!(!swma.is_ready());
        assert_eq!(swma.value(), None);
    }

    #[test]
    fn weights_are_positive_and_symmetric() {
        // The midpoint property below relies on these two invariants.
        for period in 1..40 {
            let swma = SineWeightedMa::new(period).unwrap();
            assert_eq!(swma.weights.len(), period);
            for (i, w) in swma.weights.iter().enumerate() {
                assert!(*w > 0.0, "weight {i} of period {period} is not positive");
                assert_relative_eq!(*w, swma.weights[period - 1 - i], epsilon = 1e-12);
            }
            let total: f64 = swma.weights.iter().sum();
            assert_relative_eq!(swma.weights_total, total, epsilon = 1e-12);
        }
    }

    #[test]
    fn warmup_returns_none() {
        let mut swma = SineWeightedMa::new(3).unwrap();
        assert_eq!(swma.update(1.0), None);
        assert_eq!(swma.update(2.0), None);
        // SWMA(3) weights are sin(π/4), sin(π/2), sin(3π/4) = s, 1, s.
        let s = std::f64::consts::FRAC_1_SQRT_2;
        let total = s + 1.0 + s;
        let want = (s * 1.0 + 1.0 * 2.0 + s * 3.0) / total;
        assert_relative_eq!(swma.update(3.0).unwrap(), want, epsilon = 1e-12);
    }

    #[test]
    fn symmetric_weights_give_midpoint_on_linear_window() {
        let mut swma = SineWeightedMa::new(5).unwrap();
        // Symmetric weights over an arithmetic progression average to its
        // middle element.
        let v = swma.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_relative_eq!(v[4].unwrap(), 3.0, epsilon = 1e-12);
    }

    #[test]
    fn period_one_is_pass_through() {
        let mut swma = SineWeightedMa::new(1).unwrap();
        assert_relative_eq!(swma.update(5.5).unwrap(), 5.5, epsilon = 1e-12);
        assert_relative_eq!(swma.update(7.5).unwrap(), 7.5, epsilon = 1e-12);
    }

    #[test]
    fn matches_naive_over_inputs() {
        let prices: Vec<f64> = (1..=30).map(|i| f64::from(i) * 1.7 - 5.0).collect();
        let mut swma = SineWeightedMa::new(7).unwrap();
        let got = swma.batch(&prices);
        let want = swma_naive(&prices, 7);
        for (i, (g, w)) in got.iter().zip(want.iter()).enumerate() {
            assert_eq!(g.is_some(), w.is_some(), "warmup mismatch at index {i}");
            if let (Some(a), Some(b)) = (g, w) {
                assert_relative_eq!(*a, *b, epsilon = 1e-9);
            }
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut swma = SineWeightedMa::new(4).unwrap();
        swma.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(swma.is_ready());
        swma.reset();
        assert!(!swma.is_ready());
        assert_eq!(swma.value(), None);
        assert_eq!(swma.update(10.0), None);
        // After a reset the indicator must behave exactly like a fresh one.
        let mut fresh = SineWeightedMa::new(4).unwrap();
        fresh.update(10.0);
        let tail = [20.0, 30.0, 40.0, 50.0];
        assert_eq!(swma.batch(&tail), fresh.batch(&tail));
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=20).map(|i| f64::from(i) * 0.5).collect();
        let mut a = SineWeightedMa::new(5).unwrap();
        let mut b = SineWeightedMa::new(5).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input_but_keeps_state() {
        let mut swma = SineWeightedMa::new(3).unwrap();
        swma.update(1.0);
        swma.update(2.0);
        let ready = swma.update(3.0).expect("SWMA(3) ready after three inputs");
        assert_eq!(swma.update(f64::NAN), Some(ready));
        assert_eq!(swma.update(f64::INFINITY), Some(ready));
        assert_eq!(swma.update(f64::NEG_INFINITY), Some(ready));
        // The skipped inputs never entered the window: it now holds 2, 3, 4.
        let s = std::f64::consts::FRAC_1_SQRT_2;
        let total = s + 1.0 + s;
        let want = (s * 2.0 + 1.0 * 3.0 + s * 4.0) / total;
        assert_relative_eq!(swma.update(4.0).unwrap(), want, epsilon = 1e-12);
    }

    #[test]
    fn non_finite_input_does_not_advance_warmup() {
        let mut swma = SineWeightedMa::new(2).unwrap();
        assert_eq!(swma.update(1.0), None);
        assert_eq!(swma.update(f64::NAN), None);
        assert!(!swma.is_ready());
        // SWMA(2) weights sin(π/3) and sin(2π/3) are equal, so the result is
        // the plain mean of 1 and 3.
        assert_relative_eq!(swma.update(3.0).unwrap(), 2.0, epsilon = 1e-12);
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn proptest_matches_naive(
            period in 1usize..15,
            prices in proptest::collection::vec(-500.0_f64..500.0, 0..120),
        ) {
            let mut swma = SineWeightedMa::new(period).unwrap();
            let got = swma.batch(&prices);
            let want = swma_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() < 1e-7,
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}