1use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
6#[derive(Debug, Clone)]
26pub struct Rsi {
27 period: usize,
28 n_minus_1: f64,
30 inv_period: f64,
33 prev_close: f64,
36 has_prev: bool,
37 seed_buf_gains: Vec<f64>,
40 seed_buf_losses: Vec<f64>,
41 avg_gain: f64,
44 avg_loss: f64,
45 avgs_seeded: bool,
46 last_value: Option<f64>,
47}
48
49impl Rsi {
50 pub fn new(period: usize) -> Result<Self> {
56 if period == 0 {
57 return Err(Error::PeriodZero);
58 }
59 Ok(Self {
60 period,
61 n_minus_1: (period - 1) as f64,
62 inv_period: 1.0 / period as f64,
63 prev_close: 0.0,
64 has_prev: false,
65 seed_buf_gains: Vec::with_capacity(period),
66 seed_buf_losses: Vec::with_capacity(period),
67 avg_gain: 0.0,
68 avg_loss: 0.0,
69 avgs_seeded: false,
70 last_value: None,
71 })
72 }
73
74 pub const fn period(&self) -> usize {
76 self.period
77 }
78
79 pub const fn value(&self) -> Option<f64> {
81 self.last_value
82 }
83
84 fn rsi_from_avgs(avg_gain: f64, avg_loss: f64) -> f64 {
85 let denom = avg_gain + avg_loss;
91 if denom == 0.0 {
92 50.0
93 } else {
94 100.0 * avg_gain / denom
95 }
96 }
97}
98
99impl Indicator for Rsi {
100 type Input = f64;
101 type Output = f64;
102
103 fn update(&mut self, input: f64) -> Option<f64> {
104 if !input.is_finite() {
105 return self.last_value;
106 }
107
108 if !self.has_prev {
109 self.prev_close = input;
110 self.has_prev = true;
111 return None;
112 }
113 let prev = self.prev_close;
114 self.prev_close = input;
115
116 let diff = input - prev;
117 let gain = if diff > 0.0 { diff } else { 0.0 };
118 let loss = if diff < 0.0 { -diff } else { 0.0 };
119
120 if self.avgs_seeded {
121 let new_ag = self.avg_gain.mul_add(self.n_minus_1, gain) * self.inv_period;
124 let new_al = self.avg_loss.mul_add(self.n_minus_1, loss) * self.inv_period;
125 self.avg_gain = new_ag;
126 self.avg_loss = new_al;
127 let v = Self::rsi_from_avgs(new_ag, new_al);
128 self.last_value = Some(v);
129 return Some(v);
130 }
131
132 self.seed_buf_gains.push(gain);
133 self.seed_buf_losses.push(loss);
134 if self.seed_buf_gains.len() == self.period {
135 let ag = self.seed_buf_gains.iter().sum::<f64>() / self.period as f64;
136 let al = self.seed_buf_losses.iter().sum::<f64>() / self.period as f64;
137 self.avg_gain = ag;
138 self.avg_loss = al;
139 self.avgs_seeded = true;
140 let v = Self::rsi_from_avgs(ag, al);
141 self.last_value = Some(v);
142 return Some(v);
143 }
144 None
145 }
146
147 fn reset(&mut self) {
148 self.prev_close = 0.0;
149 self.has_prev = false;
150 self.seed_buf_gains.clear();
151 self.seed_buf_losses.clear();
152 self.avg_gain = 0.0;
153 self.avg_loss = 0.0;
154 self.avgs_seeded = false;
155 self.last_value = None;
156 }
157
158 fn warmup_period(&self) -> usize {
159 self.period + 1
160 }
161
162 fn is_ready(&self) -> bool {
163 self.last_value.is_some()
164 }
165
166 fn name(&self) -> &'static str {
167 "RSI"
168 }
169}
170
#[cfg(test)]
mod tests {
    //! Unit and property tests for [`Rsi`], cross-checked against a
    //! straightforward reference implementation.

    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Reference RSI used to cross-check the streaming implementation.
    ///
    /// It seeds with the arithmetic mean of the first `period` price changes,
    /// then applies Wilder smoothing with plain division and the textbook
    /// `100 - 100 / (1 + RS)` output. Element `i` of the result corresponds to
    /// `prices[i]`; entries before the first emission are `None`.
    fn rsi_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        // Flat (no gains, no losses) -> 50; no losses at all -> 100.
        let rsi_val = |ag: f64, al: f64| -> f64 {
            match (ag == 0.0, al == 0.0) {
                (true, true) => 50.0,
                (false, true) => 100.0,
                (_, false) => 100.0 - 100.0 / (1.0 + ag / al),
            }
        };

        let mut out = vec![None; prices.len()];
        let mut seed_gains: Vec<f64> = Vec::new();
        let mut seed_losses: Vec<f64> = Vec::new();
        let mut avgs: Option<(f64, f64)> = None;

        for (offset, pair) in prices.windows(2).enumerate() {
            // Each window ends at prices[offset + 1], which is where its output goes.
            let idx = offset + 1;
            let diff = pair[1] - pair[0];
            let gain = if diff > 0.0 { diff } else { 0.0 };
            let loss = if diff < 0.0 { -diff } else { 0.0 };
            match avgs {
                Some((ag, al)) => {
                    let nag = (ag * (n - 1.0) + gain) / n;
                    let nal = (al * (n - 1.0) + loss) / n;
                    avgs = Some((nag, nal));
                    out[idx] = Some(rsi_val(nag, nal));
                }
                None => {
                    seed_gains.push(gain);
                    seed_losses.push(loss);
                    if seed_gains.len() == period {
                        let ag = seed_gains.iter().sum::<f64>() / n;
                        let al = seed_losses.iter().sum::<f64>() / n;
                        avgs = Some((ag, al));
                        out[idx] = Some(rsi_val(ag, al));
                    }
                }
            }
        }
        out
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Rsi::new(0), Err(Error::PeriodZero)));
    }

    #[test]
    fn accessors_and_metadata() {
        let mut rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.period(), 14);
        assert_eq!(rsi.name(), "RSI");
        assert_eq!(rsi.value(), None);
        // Fifteen inputs (period + 1) are enough for the first emission.
        for step in 1..=15 {
            rsi.update(100.0 + f64::from(step));
        }
        assert!(rsi.value().is_some());
    }

    #[test]
    fn naive_helper_flat_series_yields_50() {
        let results = rsi_naive(&[42.0; 20], 5);
        results.into_iter().skip(5).for_each(|r| {
            assert_eq!(r.expect("ready after period+1 inputs"), 50.0);
        });
    }

    #[test]
    fn naive_helper_monotone_up_yields_100() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let results = rsi_naive(&prices, 5);
        results.into_iter().skip(5).for_each(|r| {
            assert_eq!(r.expect("ready after period+1 inputs"), 100.0);
        });
    }

    #[test]
    fn warmup_period_is_period_plus_one() {
        assert_eq!(Rsi::new(14).unwrap().warmup_period(), 15);
    }

    #[test]
    fn first_emission_at_index_period() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        // The first 14 inputs (one reference close plus 13 changes) emit nothing.
        assert!(out[..14].iter().all(Option::is_none));
        assert!(out[14].is_some());
    }

    #[test]
    fn pure_uptrend_yields_rsi_100() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        for v in rsi.batch(&prices).into_iter().flatten() {
            assert_relative_eq!(v, 100.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn pure_downtrend_yields_rsi_0() {
        let prices: Vec<f64> = (1..=20).rev().map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        for v in rsi.batch(&prices).into_iter().flatten() {
            assert_relative_eq!(v, 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn flat_series_yields_rsi_50() {
        let prices = [10.0_f64; 30];
        let mut rsi = Rsi::new(14).unwrap();
        for v in rsi.batch(&prices).into_iter().flatten() {
            assert_relative_eq!(v, 50.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn classic_wilder_textbook_values() {
        // 15 closes: one reference close plus 14 changes, i.e. exactly one emission.
        let prices = [
            44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, 46.08, 45.89, 46.03,
            45.61, 46.28, 46.28,
        ];
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        let first = out[14].expect("first RSI emitted at index period");
        assert_relative_eq!(first, 70.464, epsilon = 0.05);
    }

    #[test]
    fn rsi_stays_in_0_100_range() {
        let prices: Vec<f64> = (0..200)
            .map(|i| 100.0 + (f64::from(i) * 0.7).sin() * 10.0)
            .collect();
        let mut rsi = Rsi::new(14).unwrap();
        for x in rsi.batch(&prices).into_iter().flatten() {
            assert!((0.0..=100.0).contains(&x), "RSI out of range: {x}");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut rsi = Rsi::new(5).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 2.0, 4.0, 5.0, 6.0]);
        assert!(rsi.is_ready());
        rsi.reset();
        assert!(!rsi.is_ready());
        // After a reset the next input is again only a reference close.
        assert_eq!(rsi.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=40)
            .map(|i| (f64::from(i) * 0.3).sin() * 5.0 + f64::from(i))
            .collect();
        let mut batched = Rsi::new(7).unwrap();
        let mut streamed = Rsi::new(7).unwrap();
        let from_batch = batched.batch(&prices);
        let from_stream: Vec<_> = prices.iter().map(|p| streamed.update(*p)).collect();
        assert_eq!(from_batch, from_stream);
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut rsi = Rsi::new(3).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 4.0]);
        let before = rsi.value();
        assert!(before.is_some());
        for bad in [f64::NAN, f64::INFINITY] {
            assert_eq!(rsi.update(bad), before);
        }
        assert_eq!(rsi.value(), before);
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn rsi_matches_naive(
            period in 1usize..20,
            prices in proptest::collection::vec(1.0_f64..1000.0, 0..150),
        ) {
            let want = rsi_naive(&prices, period);
            let mut rsi = Rsi::new(period).unwrap();
            let got = rsi.batch(&prices);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(&want) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() < 1e-7,
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}