wickra_core/indicators/
trend_strength_index.rs1use std::collections::VecDeque;
5
6use crate::error::{Error, Result};
7use crate::traits::Indicator;
8
/// Trend Strength Index over a rolling window of prices.
///
/// Each output is the coefficient of determination (`r²`) of a least-squares
/// line fitted to the most recent `period` prices against their position in
/// the window, carrying the sign of that line's slope. Values near `+1` mean a
/// clean linear uptrend, near `-1` a clean linear downtrend, and near `0` no
/// linear trend at all.
#[derive(Debug, Clone)]
pub struct TrendStrengthIndex {
    /// Length of the regression window; `new` rejects anything below 2.
    period: usize,
    /// Most recent finite prices, oldest at the front, trimmed to `period`.
    buf: VecDeque<f64>,
}
44
45impl TrendStrengthIndex {
46 pub fn new(period: usize) -> Result<Self> {
53 if period == 0 {
54 return Err(Error::PeriodZero);
55 }
56 if period == 1 {
57 return Err(Error::InvalidPeriod {
58 message: "period must be >= 2 for a regression",
59 });
60 }
61 Ok(Self {
62 period,
63 buf: VecDeque::with_capacity(period),
64 })
65 }
66
67 pub const fn period(&self) -> usize {
69 self.period
70 }
71}
72
73impl Indicator for TrendStrengthIndex {
74 type Input = f64;
75 type Output = f64;
76
77 fn update(&mut self, price: f64) -> Option<f64> {
78 if !price.is_finite() {
79 return None;
80 }
81 self.buf.push_back(price);
82 if self.buf.len() > self.period {
83 self.buf.pop_front();
84 }
85 if self.buf.len() < self.period {
86 return None;
87 }
88
89 let count = self.period as f64;
90 let mut sum_x = 0.0;
91 let mut sum_xx = 0.0;
92 let mut sum_y = 0.0;
93 let mut sum_yy = 0.0;
94 let mut sum_xy = 0.0;
95 for (idx, &price) in self.buf.iter().enumerate() {
96 let x = idx as f64;
97 sum_x += x;
98 sum_xx += x * x;
99 sum_y += price;
100 sum_yy += price * price;
101 sum_xy += x * price;
102 }
103
104 let cov = count.mul_add(sum_xy, -(sum_x * sum_y));
105 let var_x = count.mul_add(sum_xx, -(sum_x * sum_x));
106 let var_y = count.mul_add(sum_yy, -(sum_y * sum_y));
107 if var_y <= 0.0 {
108 return Some(0.0);
109 }
110 let r2 = (cov * cov) / (var_x * var_y);
111 Some(if cov >= 0.0 { r2 } else { -r2 })
112 }
113
114 fn reset(&mut self) {
115 self.buf.clear();
116 }
117
118 fn warmup_period(&self) -> usize {
119 self.period
120 }
121
122 fn is_ready(&self) -> bool {
123 self.buf.len() >= self.period
124 }
125
126 fn name(&self) -> &'static str {
127 "TrendStrengthIndex"
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// `0.0, 1.0, ..., (len - 1) as f64`: a perfectly linear ascending series.
    fn ramp(len: i32) -> Vec<f64> {
        (0..len).map(f64::from).collect()
    }

    /// Runs `inputs` through a fresh indicator and returns its final output,
    /// panicking if that output is missing.
    fn final_value(period: usize, inputs: &[f64]) -> f64 {
        let mut tsi = TrendStrengthIndex::new(period).unwrap();
        let outputs = tsi.batch(inputs);
        outputs.last().copied().flatten().unwrap()
    }

    #[test]
    fn rejects_invalid_period() {
        let zero = TrendStrengthIndex::new(0);
        assert!(matches!(zero, Err(Error::PeriodZero)));

        let one = TrendStrengthIndex::new(1);
        assert!(matches!(one, Err(Error::InvalidPeriod { .. })));
    }

    #[test]
    fn accessors_and_metadata() {
        let tsi = TrendStrengthIndex::new(20).unwrap();
        assert!(!tsi.is_ready());
        assert_eq!(tsi.name(), "TrendStrengthIndex");
        assert_eq!(tsi.warmup_period(), 20);
        assert_eq!(tsi.period(), 20);
    }

    #[test]
    fn warmup_emits_at_period() {
        let mut tsi = TrendStrengthIndex::new(4).unwrap();
        let out = tsi.batch(&ramp(6));
        assert!(out[2].is_none());
        assert!(out[3].is_some());
    }

    #[test]
    fn perfect_uptrend_is_plus_one() {
        assert_relative_eq!(final_value(10, &ramp(10)), 1.0, epsilon = 1e-9);
    }

    #[test]
    fn perfect_downtrend_is_minus_one() {
        let falling: Vec<f64> = ramp(10).into_iter().map(|x| 100.0 - x).collect();
        assert_relative_eq!(final_value(10, &falling), -1.0, epsilon = 1e-9);
    }

    #[test]
    fn flat_market_returns_zero() {
        assert_relative_eq!(final_value(8, &[42.0; 12]), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn noisy_trend_is_between() {
        // Rising ramp with a +3 bump on every odd bar: trending, but not a straight line.
        let inputs: Vec<f64> = (0..12)
            .map(|i| {
                let bump = if i % 2 == 1 { 3.0 } else { 0.0 };
                f64::from(i) + bump
            })
            .collect();
        let last = final_value(12, &inputs);
        assert!(last > 0.0 && last < 1.0, "tsi {last} should be in (0, 1)");
    }

    #[test]
    fn reset_clears_state() {
        let mut tsi = TrendStrengthIndex::new(10).unwrap();
        tsi.batch(&ramp(10));
        assert!(tsi.is_ready());
        tsi.reset();
        assert!(!tsi.is_ready());
    }

    #[test]
    fn batch_equals_streaming() {
        let inputs: Vec<f64> = (0..80)
            .map(|i| 100.0 + (f64::from(i) * 0.2).sin() * 5.0)
            .collect();
        let mut batched = TrendStrengthIndex::new(15).unwrap();
        let mut streamed = TrendStrengthIndex::new(15).unwrap();
        let expected: Vec<Option<f64>> = inputs
            .iter()
            .copied()
            .map(|x| streamed.update(x))
            .collect();
        assert_eq!(batched.batch(&inputs), expected);
    }
}