wickra_core/indicators/
tsf.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
8#[derive(Debug, Clone)]
38pub struct Tsf {
39 period: usize,
40 window: VecDeque<f64>,
41 sum_x: f64,
42 denom: f64,
43 sum_y: f64,
44 sum_xy: f64,
45}
46
47impl Tsf {
48 pub fn new(period: usize) -> Result<Self> {
54 if period < 2 {
55 return Err(Error::InvalidPeriod {
56 message: "time series forecast needs period >= 2",
57 });
58 }
59 let n = period as f64;
60 let sum_x = n * (n - 1.0) / 2.0;
61 let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
62 Ok(Self {
63 period,
64 window: VecDeque::with_capacity(period),
65 sum_x,
66 denom: n * sum_xx - sum_x * sum_x,
67 sum_y: 0.0,
68 sum_xy: 0.0,
69 })
70 }
71
72 pub const fn period(&self) -> usize {
74 self.period
75 }
76}
77
78impl Indicator for Tsf {
79 type Input = f64;
80 type Output = f64;
81
82 fn update(&mut self, value: f64) -> Option<f64> {
83 if !value.is_finite() {
84 return None;
85 }
86 if self.window.len() == self.period {
87 let y0 = self.window.pop_front().expect("non-empty");
88 self.sum_xy = self.sum_xy - self.sum_y + y0;
89 self.sum_y -= y0;
90 }
91 let k = self.window.len() as f64;
92 self.window.push_back(value);
93 self.sum_y += value;
94 self.sum_xy += k * value;
95
96 if self.window.len() < self.period {
97 return None;
98 }
99 let n = self.period as f64;
100 let slope = (n * self.sum_xy - self.sum_x * self.sum_y) / self.denom;
101 let intercept = (self.sum_y - slope * self.sum_x) / n;
102 Some(intercept + slope * n)
103 }
104
105 fn reset(&mut self) {
106 self.window.clear();
107 self.sum_y = 0.0;
108 self.sum_xy = 0.0;
109 }
110
111 fn warmup_period(&self) -> usize {
112 self.period
113 }
114
115 fn is_ready(&self) -> bool {
116 self.window.len() == self.period
117 }
118
119 fn name(&self) -> &'static str {
120 "TSF"
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// A period below two cannot define a regression line and must be rejected.
    #[test]
    fn rejects_short_period() {
        let built = Tsf::new(1);
        assert!(matches!(built, Err(Error::InvalidPeriod { .. })));
    }

    /// A freshly built indicator echoes its configuration and is not yet ready.
    #[test]
    fn accessors_report_config() {
        let fresh = Tsf::new(5).unwrap();
        assert_eq!(fresh.period(), 5);
        assert_eq!(fresh.name(), "TSF");
        assert_eq!(fresh.warmup_period(), 5);
        assert!(!fresh.is_ready());
    }

    /// y = 1, 2, 9 at x = 0, 1, 2 fits slope 4 and intercept 0, so the
    /// forecast at x = 3 is 12.
    #[test]
    fn reference_value() {
        let mut ind = Tsf::new(3).unwrap();
        let forecasts: Vec<Option<f64>> = ind.batch(&[1.0, 2.0, 9.0]);
        // The first `period - 1` samples only fill the window.
        assert!(forecasts[..2].iter().all(Option::is_none));
        let last = forecasts[2].unwrap();
        assert_relative_eq!(last, 12.0, epsilon = 1e-9);
        assert!(ind.is_ready());
    }

    /// Once the leading 1.0 slides out, the window 10, 12, 14 lies on a line
    /// of slope 2 whose next point is 16.
    #[test]
    fn forecasts_a_clean_line_one_step_ahead() {
        let mut ind = Tsf::new(3).unwrap();
        let forecasts: Vec<Option<f64>> = ind.batch(&[1.0, 10.0, 12.0, 14.0]);
        let latest = forecasts[3].unwrap();
        assert_relative_eq!(latest, 16.0, epsilon = 1e-9);
    }

    /// `reset` returns the indicator to its unprimed state.
    #[test]
    fn reset_clears_state() {
        let mut ind = Tsf::new(3).unwrap();
        let _ = ind.batch(&[1.0, 2.0, 9.0]);
        assert!(ind.is_ready());
        ind.reset();
        assert!(!ind.is_ready());
        // One sample after a reset is not enough for a forecast.
        assert_eq!(ind.update(1.0), None);
    }
}