wickra_core/indicators/
trend_strength_index.rs1use std::collections::VecDeque;
5
6use crate::error::{Error, Result};
7use crate::traits::Indicator;
8
/// Rolling trend-strength indicator: the signed coefficient of determination
/// (`r²`) of a least-squares line fitted to the most recent `period` prices
/// against their position in the window.
#[derive(Clone, Debug)]
pub struct TrendStrengthIndex {
    /// Number of samples in the regression window (validated to be >= 2 by `new`).
    period: usize,
    /// The most recent prices; new samples are pushed at the back and the
    /// oldest is popped from the front once the window is full.
    buf: VecDeque<f64>,
}
44
45impl TrendStrengthIndex {
46 pub fn new(period: usize) -> Result<Self> {
53 if period == 0 {
54 return Err(Error::PeriodZero);
55 }
56 if period == 1 {
57 return Err(Error::InvalidPeriod {
58 message: "period must be >= 2 for a regression",
59 });
60 }
61 Ok(Self {
62 period,
63 buf: VecDeque::with_capacity(period),
64 })
65 }
66
67 pub const fn period(&self) -> usize {
69 self.period
70 }
71}
72
73impl Indicator for TrendStrengthIndex {
74 type Input = f64;
75 type Output = f64;
76
77 fn update(&mut self, price: f64) -> Option<f64> {
78 self.buf.push_back(price);
79 if self.buf.len() > self.period {
80 self.buf.pop_front();
81 }
82 if self.buf.len() < self.period {
83 return None;
84 }
85
86 let count = self.period as f64;
87 let mut sum_x = 0.0;
88 let mut sum_xx = 0.0;
89 let mut sum_y = 0.0;
90 let mut sum_yy = 0.0;
91 let mut sum_xy = 0.0;
92 for (idx, &price) in self.buf.iter().enumerate() {
93 let x = idx as f64;
94 sum_x += x;
95 sum_xx += x * x;
96 sum_y += price;
97 sum_yy += price * price;
98 sum_xy += x * price;
99 }
100
101 let cov = count.mul_add(sum_xy, -(sum_x * sum_y));
102 let var_x = count.mul_add(sum_xx, -(sum_x * sum_x));
103 let var_y = count.mul_add(sum_yy, -(sum_y * sum_y));
104 if var_y <= 0.0 {
105 return Some(0.0);
106 }
107 let r2 = (cov * cov) / (var_x * var_y);
108 Some(if cov >= 0.0 { r2 } else { -r2 })
109 }
110
111 fn reset(&mut self) {
112 self.buf.clear();
113 }
114
115 fn warmup_period(&self) -> usize {
116 self.period
117 }
118
119 fn is_ready(&self) -> bool {
120 self.buf.len() >= self.period
121 }
122
123 fn name(&self) -> &'static str {
124 "TrendStrengthIndex"
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Runs `series` through a fresh indicator of window `period` and returns
    /// the final emitted value, panicking if the last output is missing.
    fn final_value(period: usize, series: &[f64]) -> f64 {
        let mut indicator = TrendStrengthIndex::new(period).unwrap();
        indicator.batch(series).pop().flatten().unwrap()
    }

    #[test]
    fn rejects_invalid_period() {
        let zero = TrendStrengthIndex::new(0);
        assert!(matches!(zero, Err(Error::PeriodZero)));
        let one = TrendStrengthIndex::new(1);
        assert!(matches!(one, Err(Error::InvalidPeriod { .. })));
    }

    #[test]
    fn accessors_and_metadata() {
        let indicator = TrendStrengthIndex::new(20).unwrap();
        assert!(!indicator.is_ready());
        assert_eq!(indicator.name(), "TrendStrengthIndex");
        assert_eq!(indicator.warmup_period(), 20);
        assert_eq!(indicator.period(), 20);
    }

    #[test]
    fn warmup_emits_at_period() {
        let mut indicator = TrendStrengthIndex::new(4).unwrap();
        let ramp: Vec<f64> = (0..6).map(f64::from).collect();
        let outputs = indicator.batch(&ramp);
        assert!(outputs[2].is_none());
        assert!(outputs[3].is_some());
    }

    #[test]
    fn perfect_uptrend_is_plus_one() {
        let ramp: Vec<f64> = (0..10).map(f64::from).collect();
        assert_relative_eq!(final_value(10, &ramp), 1.0, epsilon = 1e-9);
    }

    #[test]
    fn perfect_downtrend_is_minus_one() {
        let slide: Vec<f64> = (0..10).map(|i| 100.0 - f64::from(i)).collect();
        assert_relative_eq!(final_value(10, &slide), -1.0, epsilon = 1e-9);
    }

    #[test]
    fn flat_market_returns_zero() {
        assert_relative_eq!(final_value(8, &[42.0; 12]), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn noisy_trend_is_between() {
        // An upward ramp with a +3 kick on every odd bar: trending, but not a
        // perfect line.
        let zigzag: Vec<f64> = (0..12)
            .map(|i| f64::from(i) + if i % 2 == 0 { 0.0 } else { 3.0 })
            .collect();
        let last = final_value(12, &zigzag);
        assert!(last > 0.0 && last < 1.0, "tsi {last} should be in (0, 1)");
    }

    #[test]
    fn reset_clears_state() {
        let mut indicator = TrendStrengthIndex::new(10).unwrap();
        let ramp: Vec<f64> = (0..10).map(f64::from).collect();
        let _ = indicator.batch(&ramp);
        assert!(indicator.is_ready());
        indicator.reset();
        assert!(!indicator.is_ready());
    }

    #[test]
    fn batch_equals_streaming() {
        let wave: Vec<f64> = (0..80)
            .map(|i| 100.0 + (f64::from(i) * 0.2).sin() * 5.0)
            .collect();
        let mut bulk = TrendStrengthIndex::new(15).unwrap();
        let mut stream = TrendStrengthIndex::new(15).unwrap();
        let streamed: Vec<Option<f64>> = wave.iter().map(|&p| stream.update(p)).collect();
        assert_eq!(bulk.batch(&wave), streamed);
    }
}