wickra_core/indicators/
tsf.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Time Series Forecast (TSF): a least-squares line fitted to the last
/// `period` samples and extrapolated one step past the newest sample.
///
/// Samples in the window are indexed `x = 0, 1, ..., period - 1` from oldest
/// to newest, and the fitted line is evaluated at `x = period`. The sums that
/// depend only on `x` are fixed by `period`, so only the `y`-dependent sums
/// change as samples arrive.
#[derive(Debug, Clone)]
pub struct Tsf {
    /// Number of samples in the regression window (`new` rejects values below 2).
    period: usize,
    /// Most recent samples; new values are pushed at the back, the oldest is
    /// evicted from the front.
    window: VecDeque<f64>,
    /// Σx over `x = 0..period-1`, i.e. `period * (period - 1) / 2`.
    sum_x: f64,
    /// Least-squares denominator `n·Σx² − (Σx)²`, constant for a given period.
    denom: f64,
    /// Running Σy over the samples currently in the window.
    sum_y: f64,
    /// Running Σ(x·y) over the samples currently in the window.
    sum_xy: f64,
}
46
47impl Tsf {
48 pub fn new(period: usize) -> Result<Self> {
54 if period < 2 {
55 return Err(Error::InvalidPeriod {
56 message: "time series forecast needs period >= 2",
57 });
58 }
59 let n = period as f64;
60 let sum_x = n * (n - 1.0) / 2.0;
61 let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
62 Ok(Self {
63 period,
64 window: VecDeque::with_capacity(period),
65 sum_x,
66 denom: n * sum_xx - sum_x * sum_x,
67 sum_y: 0.0,
68 sum_xy: 0.0,
69 })
70 }
71
72 pub const fn period(&self) -> usize {
74 self.period
75 }
76}
77
78impl Indicator for Tsf {
79 type Input = f64;
80 type Output = f64;
81
82 fn update(&mut self, value: f64) -> Option<f64> {
83 if self.window.len() == self.period {
84 let y0 = self.window.pop_front().expect("non-empty");
85 self.sum_xy = self.sum_xy - self.sum_y + y0;
86 self.sum_y -= y0;
87 }
88 let k = self.window.len() as f64;
89 self.window.push_back(value);
90 self.sum_y += value;
91 self.sum_xy += k * value;
92
93 if self.window.len() < self.period {
94 return None;
95 }
96 let n = self.period as f64;
97 let slope = (n * self.sum_xy - self.sum_x * self.sum_y) / self.denom;
98 let intercept = (self.sum_y - slope * self.sum_x) / n;
99 Some(intercept + slope * n)
100 }
101
102 fn reset(&mut self) {
103 self.window.clear();
104 self.sum_y = 0.0;
105 self.sum_xy = 0.0;
106 }
107
108 fn warmup_period(&self) -> usize {
109 self.period
110 }
111
112 fn is_ready(&self) -> bool {
113 self.window.len() == self.period
114 }
115
116 fn name(&self) -> &'static str {
117 "TSF"
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn rejects_short_period() {
        // A regression line needs at least two points, so period 1 is refused.
        let result = Tsf::new(1);
        assert!(matches!(result, Err(Error::InvalidPeriod { .. })));
    }

    #[test]
    fn accessors_report_config() {
        let fresh = Tsf::new(5).unwrap();
        assert_eq!(fresh.period(), 5);
        assert_eq!(fresh.name(), "TSF");
        assert_eq!(fresh.warmup_period(), 5);
        // Nothing has been fed yet, so the window cannot be full.
        assert!(!fresh.is_ready());
    }

    #[test]
    fn reference_value() {
        // x = 0, 1, 2 with y = 1, 2, 9: Σx = 3, Σy = 12, Σxy = 20, denom = 6.
        // slope = (3·20 − 3·12) / 6 = 4, intercept = (12 − 4·3) / 3 = 0,
        // so the forecast at x = 3 is 0 + 4·3 = 12.
        let mut indicator = Tsf::new(3).unwrap();
        let outputs: Vec<Option<f64>> = indicator.batch(&[1.0, 2.0, 9.0]);
        assert!(outputs[..2].iter().all(Option::is_none));
        assert_relative_eq!(outputs[2].unwrap(), 12.0, epsilon = 1e-9);
        assert!(indicator.is_ready());
    }

    #[test]
    fn forecasts_a_clean_line_one_step_ahead() {
        // The leading 1.0 is evicted; the surviving window 10, 12, 14 is an
        // exact line with slope 2, so the next value is 16.
        let mut indicator = Tsf::new(3).unwrap();
        let outputs: Vec<Option<f64>> = indicator.batch(&[1.0, 10.0, 12.0, 14.0]);
        assert_relative_eq!(outputs[3].unwrap(), 16.0, epsilon = 1e-9);
    }

    #[test]
    fn reset_clears_state() {
        let mut indicator = Tsf::new(3).unwrap();
        let _ = indicator.batch(&[1.0, 2.0, 9.0]);
        assert!(indicator.is_ready());

        indicator.reset();
        assert!(!indicator.is_ready());
        // After a reset the indicator is back in warm-up.
        assert_eq!(indicator.update(1.0), None);
    }
}