1use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// Wilder's Relative Strength Index over a fixed `period`.
///
/// Each finite input is compared with the previous finite input; the positive part of
/// the change counts as a gain and the magnitude of the negative part as a loss. The
/// first `period` changes are averaged with a simple mean to seed the averages, after
/// which they are updated with Wilder's smoothing. The reported value is
/// `100 * avg_gain / (avg_gain + avg_loss)`, or `50` when both averages are zero.
/// The first value is emitted on the `period + 1`-th finite input.
#[derive(Debug, Clone)]
pub struct Rsi {
    /// Smoothing window length; never zero (rejected by `Rsi::new`).
    period: usize,
    /// `period - 1` as `f64`, cached for the Wilder smoothing update.
    n_minus_1: f64,
    /// `1 / period`, cached so the smoothing update multiplies instead of dividing.
    inv_period: f64,
    /// Most recent finite input; meaningful only once `has_prev` is set.
    prev_close: f64,
    /// Whether `prev_close` holds a real input yet.
    has_prev: bool,
    /// Gains of the first `period` changes, averaged to seed `avg_gain`. Only cleared by
    /// `reset`, so a non-empty buffer means warm-up has started.
    seed_buf_gains: Vec<f64>,
    /// Losses (positive magnitudes) of the first `period` changes, averaged to seed
    /// `avg_loss`.
    seed_buf_losses: Vec<f64>,
    /// Smoothed average gain; meaningful only once `avgs_seeded` is set.
    avg_gain: f64,
    /// Smoothed average loss (positive magnitude); meaningful only once `avgs_seeded` is set.
    avg_loss: f64,
    /// Whether the seed averages have been computed and smoothing is active.
    avgs_seeded: bool,
    /// Last emitted RSI value; `None` until warm-up completes.
    last_value: Option<f64>,
}
48
49impl Rsi {
50 pub fn new(period: usize) -> Result<Self> {
56 if period == 0 {
57 return Err(Error::PeriodZero);
58 }
59 Ok(Self {
60 period,
61 n_minus_1: (period - 1) as f64,
62 inv_period: 1.0 / period as f64,
63 prev_close: 0.0,
64 has_prev: false,
65 seed_buf_gains: Vec::with_capacity(period),
66 seed_buf_losses: Vec::with_capacity(period),
67 avg_gain: 0.0,
68 avg_loss: 0.0,
69 avgs_seeded: false,
70 last_value: None,
71 })
72 }
73
74 pub const fn period(&self) -> usize {
76 self.period
77 }
78
79 pub const fn value(&self) -> Option<f64> {
81 self.last_value
82 }
83
84 pub fn batch_nan(&mut self, inputs: &[f64]) -> Vec<f64> {
97 let p = self.period;
98 let n = inputs.len();
99 if self.has_prev
100 || self.avgs_seeded
101 || !self.seed_buf_gains.is_empty()
102 || n <= p
103 || !inputs.iter().all(|x| x.is_finite())
104 {
105 return inputs
106 .iter()
107 .map(|&x| self.update(x).unwrap_or(f64::NAN))
108 .collect();
109 }
110
111 let mut out = vec![f64::NAN; p];
113 out.reserve(n - p);
114 let mut prev = inputs[0];
117 let (mut sum_gain, mut sum_loss) = (0.0_f64, 0.0_f64);
118 for &x in &inputs[1..=p] {
119 let diff = x - prev;
120 prev = x;
121 let gain = if diff > 0.0 { diff } else { 0.0 };
122 let loss = if diff < 0.0 { -diff } else { 0.0 };
123 self.seed_buf_gains.push(gain);
124 self.seed_buf_losses.push(loss);
125 sum_gain += gain;
126 sum_loss += loss;
127 }
128 let p_f64 = p as f64;
129 let mut ag = sum_gain / p_f64;
130 let mut al = sum_loss / p_f64;
131 out.push(Self::rsi_from_avgs(ag, al));
132
133 for &x in &inputs[p + 1..] {
135 let diff = x - prev;
136 prev = x;
137 let gain = if diff > 0.0 { diff } else { 0.0 };
138 let loss = if diff < 0.0 { -diff } else { 0.0 };
139 ag = ag.mul_add(self.n_minus_1, gain) * self.inv_period;
140 al = al.mul_add(self.n_minus_1, loss) * self.inv_period;
141 out.push(Self::rsi_from_avgs(ag, al));
142 }
143
144 self.prev_close = prev;
146 self.has_prev = true;
147 self.avg_gain = ag;
148 self.avg_loss = al;
149 self.avgs_seeded = true;
150 self.last_value = Some(out[n - 1]);
151 out
152 }
153
154 fn rsi_from_avgs(avg_gain: f64, avg_loss: f64) -> f64 {
155 let denom = avg_gain + avg_loss;
161 if denom == 0.0 {
162 50.0
163 } else {
164 100.0 * avg_gain / denom
165 }
166 }
167}
168
169impl Indicator for Rsi {
170 type Input = f64;
171 type Output = f64;
172
173 fn update(&mut self, input: f64) -> Option<f64> {
174 if !input.is_finite() {
175 return self.last_value;
176 }
177
178 if !self.has_prev {
179 self.prev_close = input;
180 self.has_prev = true;
181 return None;
182 }
183 let prev = self.prev_close;
184 self.prev_close = input;
185
186 let diff = input - prev;
187 let gain = if diff > 0.0 { diff } else { 0.0 };
188 let loss = if diff < 0.0 { -diff } else { 0.0 };
189
190 if self.avgs_seeded {
191 let new_ag = self.avg_gain.mul_add(self.n_minus_1, gain) * self.inv_period;
194 let new_al = self.avg_loss.mul_add(self.n_minus_1, loss) * self.inv_period;
195 self.avg_gain = new_ag;
196 self.avg_loss = new_al;
197 let v = Self::rsi_from_avgs(new_ag, new_al);
198 self.last_value = Some(v);
199 return Some(v);
200 }
201
202 self.seed_buf_gains.push(gain);
203 self.seed_buf_losses.push(loss);
204 if self.seed_buf_gains.len() == self.period {
205 let ag = self.seed_buf_gains.iter().sum::<f64>() / self.period as f64;
206 let al = self.seed_buf_losses.iter().sum::<f64>() / self.period as f64;
207 self.avg_gain = ag;
208 self.avg_loss = al;
209 self.avgs_seeded = true;
210 let v = Self::rsi_from_avgs(ag, al);
211 self.last_value = Some(v);
212 return Some(v);
213 }
214 None
215 }
216
217 fn reset(&mut self) {
218 self.prev_close = 0.0;
219 self.has_prev = false;
220 self.seed_buf_gains.clear();
221 self.seed_buf_losses.clear();
222 self.avg_gain = 0.0;
223 self.avg_loss = 0.0;
224 self.avgs_seeded = false;
225 self.last_value = None;
226 }
227
228 fn warmup_period(&self) -> usize {
229 self.period + 1
230 }
231
232 fn is_ready(&self) -> bool {
233 self.last_value.is_some()
234 }
235
236 fn name(&self) -> &'static str {
237 "RSI"
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Straightforward reference RSI used as an oracle.
    ///
    /// Recomputes gains/losses from adjacent prices, seeds with the plain mean of the
    /// first `period` changes and then applies the textbook `(avg * (n - 1) + x) / n`
    /// update. Outputs `None` until the first value, which lands at index `period`.
    fn rsi_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        let mut out = vec![None; prices.len()];
        let mut gains: Vec<f64> = Vec::new();
        let mut losses: Vec<f64> = Vec::new();
        let mut avg_gain: Option<f64> = None;
        let mut avg_loss: Option<f64> = None;
        // Textbook `100 - 100 / (1 + RS)`, with the no-movement case pinned to 50 and the
        // zero-loss case pinned to 100 instead of dividing by zero.
        let rsi_val = |ag: f64, al: f64| -> f64 {
            if al == 0.0 {
                if ag == 0.0 {
                    50.0
                } else {
                    100.0
                }
            } else {
                100.0 - 100.0 / (1.0 + ag / al)
            }
        };
        for i in 1..prices.len() {
            let diff = prices[i] - prices[i - 1];
            let gain = if diff > 0.0 { diff } else { 0.0 };
            let loss = if diff < 0.0 { -diff } else { 0.0 };
            if let (Some(ag), Some(al)) = (avg_gain, avg_loss) {
                let nag = (ag * (n - 1.0) + gain) / n;
                let nal = (al * (n - 1.0) + loss) / n;
                avg_gain = Some(nag);
                avg_loss = Some(nal);
                out[i] = Some(rsi_val(nag, nal));
            } else {
                gains.push(gain);
                losses.push(loss);
                if gains.len() == period {
                    let ag = gains.iter().sum::<f64>() / n;
                    let al = losses.iter().sum::<f64>() / n;
                    avg_gain = Some(ag);
                    avg_loss = Some(al);
                    out[i] = Some(rsi_val(ag, al));
                }
            }
        }
        out
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Rsi::new(0), Err(Error::PeriodZero)));
    }

    /// Accessors report construction parameters and state; 15 rising closes are exactly
    /// enough to warm up a 14-period RSI.
    #[test]
    fn accessors_and_metadata() {
        let mut rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.period(), 14);
        assert_eq!(rsi.name(), "RSI");
        assert_eq!(rsi.value(), None);
        for i in 1..=15 {
            rsi.update(100.0 + f64::from(i));
        }
        assert!(rsi.value().is_some());
    }

    /// Sanity check of the oracle itself: no movement gives the neutral value 50.
    #[test]
    fn naive_helper_flat_series_yields_50() {
        let ks = rsi_naive(&[42.0; 20], 5);
        for r in ks.into_iter().skip(5) {
            assert_eq!(r.expect("ready after period+1 inputs"), 50.0);
        }
    }

    /// Sanity check of the oracle itself: only gains gives 100.
    #[test]
    fn naive_helper_monotone_up_yields_100() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let ks = rsi_naive(&prices, 5);
        for r in ks.into_iter().skip(5) {
            assert_eq!(r.expect("ready after period+1 inputs"), 100.0);
        }
    }

    #[test]
    fn warmup_period_is_period_plus_one() {
        let rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.warmup_period(), 15);
    }

    #[test]
    fn first_emission_at_index_period() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        // Index 0 only primes the previous close; the 14 seed changes end at index 14.
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for x in &out[..14] {
            // Every warm-up slot must be empty.
            assert!(x.is_none());
        }
        assert!(out[14].is_some());
    }

    #[test]
    fn pure_uptrend_yields_rsi_100() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for v in out.iter().filter_map(|x| x.as_ref()) {
            // avg_loss stays exactly 0, so every emitted value is 100.
            assert_relative_eq!(*v, 100.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn pure_downtrend_yields_rsi_0() {
        let prices: Vec<f64> = (1..=20).rev().map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for v in out.iter().filter_map(|x| x.as_ref()) {
            assert_relative_eq!(*v, 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn flat_series_yields_rsi_50() {
        let prices = [10.0_f64; 30];
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for v in out.iter().filter_map(|x| x.as_ref()) {
            assert_relative_eq!(*v, 50.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn classic_wilder_textbook_values() {
        // 15 closes -> 14 changes. Gains sum to 3.34 and losses to 1.40, so the seeded
        // RSI is 100 * 3.34 / (3.34 + 1.40) ~= 70.464.
        let prices = [
            44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, 46.08, 45.89, 46.03,
            45.61, 46.28, 46.28,
        ];
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        let first = out[14].expect("first RSI emitted at index period");
        assert_relative_eq!(first, 70.464, epsilon = 0.05);
    }

    #[test]
    fn rsi_stays_in_0_100_range() {
        let prices: Vec<f64> = (0..200)
            .map(|i| 100.0 + (f64::from(i) * 0.7).sin() * 10.0)
            .collect();
        let mut rsi = Rsi::new(14).unwrap();
        for x in rsi.batch(&prices).into_iter().flatten() {
            assert!((0.0..=100.0).contains(&x), "RSI out of range: {x}");
        }
    }

    /// After `reset` the indicator behaves like a new one: not ready, and the next
    /// input only primes the previous close.
    #[test]
    fn reset_clears_state() {
        let mut rsi = Rsi::new(5).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 2.0, 4.0, 5.0, 6.0]);
        assert!(rsi.is_ready());
        rsi.reset();
        assert!(!rsi.is_ready());
        assert_eq!(rsi.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=40)
            .map(|i| (f64::from(i) * 0.3).sin() * 5.0 + f64::from(i))
            .collect();
        let mut a = Rsi::new(7).unwrap();
        let mut b = Rsi::new(7).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    /// NaN/inf inputs are skipped: they return the last value and leave state unchanged.
    #[test]
    fn ignores_non_finite_input() {
        let mut rsi = Rsi::new(3).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 4.0]);
        let before = rsi.value();
        assert!(before.is_some());
        assert_eq!(rsi.update(f64::NAN), before);
        assert_eq!(rsi.update(f64::INFINITY), before);
        assert_eq!(rsi.value(), before);
    }

    /// Element-wise equality that also treats NaN == NaN. Like `==`, it does not
    /// distinguish +0.0 from -0.0.
    fn bits_eq(a: &[f64], b: &[f64]) -> bool {
        a.len() == b.len()
            && a.iter()
                .zip(b)
                .all(|(x, y)| x == y || (x.is_nan() && y.is_nan()))
    }

    /// Reference for `batch_nan`: streams `series` through a fresh `Rsi`, mapping `None`
    /// to NaN.
    fn rsi_replay(period: usize, series: &[f64]) -> Vec<f64> {
        let mut r = Rsi::new(period).unwrap();
        series
            .iter()
            .map(|&x| r.update(x).unwrap_or(f64::NAN))
            .collect()
    }

    /// The fast path must match streaming output exactly, and must leave state from
    /// which the next `update` continues identically.
    #[test]
    fn batch_nan_fast_path_is_bit_identical() {
        let series: Vec<f64> = (0..300)
            .map(|i| (f64::from(i) * 0.3).sin() * 5.0 + f64::from(i) * 0.1 + 100.0)
            .collect();
        let mut rsi = Rsi::new(14).unwrap();
        let got = rsi.batch_nan(&series);
        assert!(bits_eq(&got, &rsi_replay(14, &series)));
        let mut ref_rsi = Rsi::new(14).unwrap();
        for &x in &series {
            ref_rsi.update(x);
        }
        assert_eq!(rsi.update(123.0), ref_rsi.update(123.0));
    }

    #[test]
    fn batch_nan_falls_back_on_non_finite() {
        let series = [10.0, 11.0, 9.0, f64::NAN, 12.0, 13.0, 8.0];
        let mut rsi = Rsi::new(3).unwrap();
        assert!(bits_eq(&rsi.batch_nan(&series), &rsi_replay(3, &series)));
    }

    /// An indicator that has already seen input must take the streaming path and
    /// continue from its existing state.
    #[test]
    fn batch_nan_falls_back_when_not_fresh() {
        let mut rsi = Rsi::new(3).unwrap();
        rsi.update(50.0);
        let series = [51.0, 49.0, 52.0, 53.0, 50.0];
        let mut ref_rsi = Rsi::new(3).unwrap();
        ref_rsi.update(50.0);
        let want: Vec<f64> = series
            .iter()
            .map(|&x| ref_rsi.update(x).unwrap_or(f64::NAN))
            .collect();
        assert!(bits_eq(&rsi.batch_nan(&series), &want));
    }

    #[test]
    fn batch_nan_too_short_to_seed_falls_back() {
        // `period` inputs give only `period - 1` changes, so nothing can be seeded.
        let series = [10.0, 11.0, 12.0];
        let mut rsi = Rsi::new(3).unwrap();
        assert!(bits_eq(&rsi.batch_nan(&series), &rsi_replay(3, &series)));
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        /// The optimised indicator agrees with the naive oracle on warm-up placement and
        /// (within 1e-7) on every emitted value.
        #[test]
        fn rsi_matches_naive(
            period in 1usize..20,
            prices in proptest::collection::vec(1.0_f64..1000.0, 0..150),
        ) {
            let mut rsi = Rsi::new(period).unwrap();
            let got = rsi.batch(&prices);
            let want = rsi_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() < 1e-7,
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}