Skip to main content

wickra_core/indicators/
rsi.rs

1//! Relative Strength Index using Wilder's smoothing.
2
3use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// Relative Strength Index (Wilder, 1978).
///
/// Uses Wilder's smoothing (an EMA with `alpha = 1 / period`). The first output
/// is produced after `period + 1` inputs: the seed averages the first `period`
/// gains and losses, and the first emitted RSI corresponds to the input at
/// index `period`.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, Rsi};
///
/// let mut indicator = Rsi::new(3).unwrap();
/// let mut last = None;
/// for i in 0..80 {
///     last = indicator.update(100.0 + f64::from(i));
/// }
/// assert!(last.is_some());
/// ```
#[derive(Debug, Clone)]
pub struct Rsi {
    /// Wilder period: the number of price changes averaged.
    period: usize,
    /// `period - 1` as `f64`, precomputed for the Wilder smoothing step.
    n_minus_1: f64,
    /// `1 / period`, precomputed so the per-tick smoothing multiplies instead of
    /// divides (a reciprocal is hoisted out of the hot path).
    inv_period: f64,
    /// Previous close, valid once `has_prev` is set. Bare `f64` plus a flag
    /// instead of `Option<f64>` to avoid an enum-tag read on every tick.
    prev_close: f64,
    /// Whether `prev_close` holds a real close yet.
    has_prev: bool,
    /// First `period` gains. Wilder seeds the averages with the simple mean of
    /// the first `period` gains/losses, then switches to recursive smoothing.
    seed_buf_gains: Vec<f64>,
    /// First `period` losses (as positive magnitudes), paired with
    /// `seed_buf_gains`.
    seed_buf_losses: Vec<f64>,
    /// Smoothed average gain, valid once `avgs_seeded` is set. Bare `f64`s plus
    /// a flag so the hot recurrence avoids reading two `Option<f64>` tags per
    /// tick.
    avg_gain: f64,
    /// Smoothed average loss, valid once `avgs_seeded` is set.
    avg_loss: f64,
    /// Whether `avg_gain` / `avg_loss` have been seeded.
    avgs_seeded: bool,
    /// Most recently emitted RSI; `None` until the seed completes.
    last_value: Option<f64>,
}
48
49impl Rsi {
50    /// Construct an RSI with the given Wilder period.
51    ///
52    /// # Errors
53    ///
54    /// Returns [`Error::PeriodZero`] if `period == 0`.
55    pub fn new(period: usize) -> Result<Self> {
56        if period == 0 {
57            return Err(Error::PeriodZero);
58        }
59        Ok(Self {
60            period,
61            n_minus_1: (period - 1) as f64,
62            inv_period: 1.0 / period as f64,
63            prev_close: 0.0,
64            has_prev: false,
65            seed_buf_gains: Vec::with_capacity(period),
66            seed_buf_losses: Vec::with_capacity(period),
67            avg_gain: 0.0,
68            avg_loss: 0.0,
69            avgs_seeded: false,
70            last_value: None,
71        })
72    }
73
74    /// Configured period.
75    pub const fn period(&self) -> usize {
76        self.period
77    }
78
79    /// Current value if available.
80    pub const fn value(&self) -> Option<f64> {
81        self.last_value
82    }
83
84    /// Vectorized batch returning one `f64` per input (`NaN` during warmup).
85    ///
86    /// Shadows the generic [`BatchNanExt::batch_nan`](crate::BatchNanExt) blanket
87    /// default. RSI is a recursive (IIR) filter — Wilder smoothing — so it cannot
88    /// be SIMD-vectorized any more than the C peers manage; the win is purely in
89    /// stripping per-tick overhead. For a fresh indicator over an all-finite slice
90    /// long enough to seed (`n > period`) it runs the seed once and then the bare
91    /// smoothing recurrence in a tight loop with no per-tick `is_finite`/`has_prev`/
92    /// `avgs_seeded` branch and no `Option`, using the identical division at the
93    /// seed and `mul_add`/`rsi_from_avgs` afterwards — so it is *bit-for-bit* equal
94    /// to replaying `update`. Shorter or non-fresh/non-finite inputs defer to the
95    /// exact `update` replay.
96    pub fn batch_nan(&mut self, inputs: &[f64]) -> Vec<f64> {
97        let p = self.period;
98        let n = inputs.len();
99        if self.has_prev
100            || self.avgs_seeded
101            || !self.seed_buf_gains.is_empty()
102            || n <= p
103            || !inputs.iter().all(|x| x.is_finite())
104        {
105            return inputs
106                .iter()
107                .map(|&x| self.update(x).unwrap_or(f64::NAN))
108                .collect();
109        }
110
111        // Warmup `[0, p)` is `NaN`; outputs from index `p` on are pushed once each.
112        let mut out = vec![f64::NAN; p];
113        out.reserve(n - p);
114        // Seed from the first `period` diffs (inputs[1..=p]); index 0 only sets the
115        // baseline. Retain the seed gains/losses exactly as `update` leaves them.
116        let mut prev = inputs[0];
117        let (mut sum_gain, mut sum_loss) = (0.0_f64, 0.0_f64);
118        for &x in &inputs[1..=p] {
119            let diff = x - prev;
120            prev = x;
121            let gain = if diff > 0.0 { diff } else { 0.0 };
122            let loss = if diff < 0.0 { -diff } else { 0.0 };
123            self.seed_buf_gains.push(gain);
124            self.seed_buf_losses.push(loss);
125            sum_gain += gain;
126            sum_loss += loss;
127        }
128        let p_f64 = p as f64;
129        let mut ag = sum_gain / p_f64;
130        let mut al = sum_loss / p_f64;
131        out.push(Self::rsi_from_avgs(ag, al));
132
133        // Steady state: Wilder smoothing, reciprocal hoisted, one `rsi_from_avgs`.
134        for &x in &inputs[p + 1..] {
135            let diff = x - prev;
136            prev = x;
137            let gain = if diff > 0.0 { diff } else { 0.0 };
138            let loss = if diff < 0.0 { -diff } else { 0.0 };
139            ag = ag.mul_add(self.n_minus_1, gain) * self.inv_period;
140            al = al.mul_add(self.n_minus_1, loss) * self.inv_period;
141            out.push(Self::rsi_from_avgs(ag, al));
142        }
143
144        // Leave state where a full `update` replay would.
145        self.prev_close = prev;
146        self.has_prev = true;
147        self.avg_gain = ag;
148        self.avg_loss = al;
149        self.avgs_seeded = true;
150        self.last_value = Some(out[n - 1]);
151        out
152    }
153
154    fn rsi_from_avgs(avg_gain: f64, avg_loss: f64) -> f64 {
155        // Algebraically `100 - 100/(1 + ag/al)` collapses to `100·ag/(ag+al)`,
156        // which needs a single division instead of two and removes the separate
157        // `rs` step. Edge cases stay exact: `al == 0, ag > 0` gives `100·ag/ag =
158        // 100`; `ag == 0, al > 0` gives `0`; both zero (no movement) is the
159        // undefined case and returns the neutral 50.
160        let denom = avg_gain + avg_loss;
161        if denom == 0.0 {
162            50.0
163        } else {
164            100.0 * avg_gain / denom
165        }
166    }
167}
168
169impl Indicator for Rsi {
170    type Input = f64;
171    type Output = f64;
172
173    fn update(&mut self, input: f64) -> Option<f64> {
174        if !input.is_finite() {
175            return self.last_value;
176        }
177
178        if !self.has_prev {
179            self.prev_close = input;
180            self.has_prev = true;
181            return None;
182        }
183        let prev = self.prev_close;
184        self.prev_close = input;
185
186        let diff = input - prev;
187        let gain = if diff > 0.0 { diff } else { 0.0 };
188        let loss = if diff < 0.0 { -diff } else { 0.0 };
189
190        if self.avgs_seeded {
191            // Wilder smoothing `(prev·(n-1) + x) / n` with the reciprocal hoisted:
192            // a fused multiply-add then a multiply by `1/n`, no per-tick division.
193            let new_ag = self.avg_gain.mul_add(self.n_minus_1, gain) * self.inv_period;
194            let new_al = self.avg_loss.mul_add(self.n_minus_1, loss) * self.inv_period;
195            self.avg_gain = new_ag;
196            self.avg_loss = new_al;
197            let v = Self::rsi_from_avgs(new_ag, new_al);
198            self.last_value = Some(v);
199            return Some(v);
200        }
201
202        self.seed_buf_gains.push(gain);
203        self.seed_buf_losses.push(loss);
204        if self.seed_buf_gains.len() == self.period {
205            let ag = self.seed_buf_gains.iter().sum::<f64>() / self.period as f64;
206            let al = self.seed_buf_losses.iter().sum::<f64>() / self.period as f64;
207            self.avg_gain = ag;
208            self.avg_loss = al;
209            self.avgs_seeded = true;
210            let v = Self::rsi_from_avgs(ag, al);
211            self.last_value = Some(v);
212            return Some(v);
213        }
214        None
215    }
216
217    fn reset(&mut self) {
218        self.prev_close = 0.0;
219        self.has_prev = false;
220        self.seed_buf_gains.clear();
221        self.seed_buf_losses.clear();
222        self.avg_gain = 0.0;
223        self.avg_loss = 0.0;
224        self.avgs_seeded = false;
225        self.last_value = None;
226    }
227
228    fn warmup_period(&self) -> usize {
229        self.period + 1
230    }
231
232    fn is_ready(&self) -> bool {
233        self.last_value.is_some()
234    }
235
236    fn name(&self) -> &'static str {
237        "RSI"
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Independent reference: Wilder RSI computed straight from the definition.
    fn rsi_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        let mut out = vec![None; prices.len()];
        let mut gains: Vec<f64> = Vec::new();
        let mut losses: Vec<f64> = Vec::new();
        let mut avg_gain: Option<f64> = None;
        let mut avg_loss: Option<f64> = None;
        let rsi_val = |ag: f64, al: f64| -> f64 {
            if al == 0.0 {
                if ag == 0.0 {
                    50.0
                } else {
                    100.0
                }
            } else {
                100.0 - 100.0 / (1.0 + ag / al)
            }
        };
        for i in 1..prices.len() {
            let diff = prices[i] - prices[i - 1];
            let gain = if diff > 0.0 { diff } else { 0.0 };
            let loss = if diff < 0.0 { -diff } else { 0.0 };
            if let (Some(ag), Some(al)) = (avg_gain, avg_loss) {
                let nag = (ag * (n - 1.0) + gain) / n;
                let nal = (al * (n - 1.0) + loss) / n;
                avg_gain = Some(nag);
                avg_loss = Some(nal);
                out[i] = Some(rsi_val(nag, nal));
            } else {
                gains.push(gain);
                losses.push(loss);
                if gains.len() == period {
                    let ag = gains.iter().sum::<f64>() / n;
                    let al = losses.iter().sum::<f64>() / n;
                    avg_gain = Some(ag);
                    avg_loss = Some(al);
                    out[i] = Some(rsi_val(ag, al));
                }
            }
        }
        out
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Rsi::new(0), Err(Error::PeriodZero)));
    }

    /// Cover the const accessors `period` / `value` and the Indicator-impl
    /// `name` body. `warmup_period` is covered already by
    /// `warmup_period_is_period_plus_one`.
    #[test]
    fn accessors_and_metadata() {
        let mut rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.period(), 14);
        assert_eq!(rsi.name(), "RSI");
        assert_eq!(rsi.value(), None);
        for i in 1..=15 {
            rsi.update(100.0 + f64::from(i));
        }
        assert!(rsi.value().is_some());
    }

    /// Cover the `ag == 0` branch (inside `al == 0`) of the test-helper
    /// `rsi_naive`: when both `avg_gain` and `avg_loss` are 0 (a perfectly flat
    /// series), the helper must return the neutral 50.0. The proptest reference
    /// uses random inputs that essentially never hit zero gains AND zero losses
    /// simultaneously, leaving this branch dead in the helper.
    #[test]
    fn naive_helper_flat_series_yields_50() {
        let ks = rsi_naive(&[42.0; 20], 5);
        for r in ks.into_iter().skip(5) {
            assert_eq!(r.expect("ready after period+1 inputs"), 50.0);
        }
    }

    /// Cover the `100.0` branch (`al == 0 && ag != 0`) of the test-helper
    /// `rsi_naive`: strictly increasing prices give `avg_loss == 0` while
    /// `avg_gain > 0`, the textbook overbought saturation case. Random proptest
    /// inputs virtually never satisfy `al == 0 && ag != 0`, so this needs an
    /// explicit monotone series.
    #[test]
    fn naive_helper_monotone_up_yields_100() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let ks = rsi_naive(&prices, 5);
        for r in ks.into_iter().skip(5) {
            assert_eq!(r.expect("ready after period+1 inputs"), 100.0);
        }
    }

    #[test]
    fn warmup_period_is_period_plus_one() {
        let rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.warmup_period(), 15);
    }

    #[test]
    fn first_emission_at_index_period() {
        // RSI(14) needs 14 diffs => 15 inputs before first value.
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        // indices 0..14 -> None, index 14 -> first Some
        for x in &out[..14] {
            assert!(x.is_none());
        }
        assert!(out[14].is_some());
    }

    #[test]
    fn pure_uptrend_yields_rsi_100() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        // All diffs are positive => avg_loss == 0 => RSI == 100
        for v in out.iter().filter_map(|x| x.as_ref()) {
            assert_relative_eq!(*v, 100.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn pure_downtrend_yields_rsi_0() {
        let prices: Vec<f64> = (1..=20).rev().map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for v in out.iter().filter_map(|x| x.as_ref()) {
            assert_relative_eq!(*v, 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn flat_series_yields_rsi_50() {
        let prices = [10.0_f64; 30];
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for v in out.iter().filter_map(|x| x.as_ref()) {
            assert_relative_eq!(*v, 50.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn classic_wilder_textbook_values() {
        // Wilder's original example from "New Concepts in Technical Trading Systems",
        // 14-period RSI. We compute the first value at index 14 and compare to the
        // value Wilder publishes (~70.46).
        // Source: classic textbook table, reproduced in many references (e.g. Investopedia).
        let prices = [
            44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, 46.08, 45.89, 46.03,
            45.61, 46.28, 46.28,
        ];
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        let first = out[14].expect("first RSI emitted at index period");
        assert_relative_eq!(first, 70.464, epsilon = 0.05);
    }

    #[test]
    fn rsi_stays_in_0_100_range() {
        let prices: Vec<f64> = (0..200)
            .map(|i| 100.0 + (f64::from(i) * 0.7).sin() * 10.0)
            .collect();
        let mut rsi = Rsi::new(14).unwrap();
        for x in rsi.batch(&prices).into_iter().flatten() {
            assert!((0.0..=100.0).contains(&x), "RSI out of range: {x}");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut rsi = Rsi::new(5).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 2.0, 4.0, 5.0, 6.0]);
        assert!(rsi.is_ready());
        rsi.reset();
        assert!(!rsi.is_ready());
        assert_eq!(rsi.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=40)
            .map(|i| (f64::from(i) * 0.3).sin() * 5.0 + f64::from(i))
            .collect();
        let mut a = Rsi::new(7).unwrap();
        let mut b = Rsi::new(7).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut rsi = Rsi::new(3).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 4.0]);
        let before = rsi.value();
        assert!(before.is_some());
        assert_eq!(rsi.update(f64::NAN), before);
        assert_eq!(rsi.update(f64::INFINITY), before);
        assert_eq!(rsi.value(), before);
    }

    /// Element-wise equality that treats two `NaN`s as equal (warmup slots).
    fn bits_eq(a: &[f64], b: &[f64]) -> bool {
        a.len() == b.len()
            && a.iter()
                .zip(b)
                .all(|(x, y)| x == y || (x.is_nan() && y.is_nan()))
    }

    /// Reference for `batch_nan`: a fresh indicator fed through `update`.
    fn rsi_replay(period: usize, series: &[f64]) -> Vec<f64> {
        let mut r = Rsi::new(period).unwrap();
        series
            .iter()
            .map(|&x| r.update(x).unwrap_or(f64::NAN))
            .collect()
    }

    #[test]
    fn batch_nan_fast_path_is_bit_identical() {
        let series: Vec<f64> = (0..300)
            .map(|i| (f64::from(i) * 0.3).sin() * 5.0 + f64::from(i) * 0.1 + 100.0)
            .collect();
        let mut rsi = Rsi::new(14).unwrap();
        let got = rsi.batch_nan(&series);
        assert!(bits_eq(&got, &rsi_replay(14, &series)));
        // The fast path must also leave the same state behind as the replay.
        let mut ref_rsi = Rsi::new(14).unwrap();
        for &x in &series {
            ref_rsi.update(x);
        }
        assert_eq!(rsi.update(123.0), ref_rsi.update(123.0));
    }

    #[test]
    fn batch_nan_falls_back_on_non_finite() {
        let series = [10.0, 11.0, 9.0, f64::NAN, 12.0, 13.0, 8.0];
        let mut rsi = Rsi::new(3).unwrap();
        assert!(bits_eq(&rsi.batch_nan(&series), &rsi_replay(3, &series)));
    }

    #[test]
    fn batch_nan_falls_back_when_not_fresh() {
        let mut rsi = Rsi::new(3).unwrap();
        rsi.update(50.0);
        let series = [51.0, 49.0, 52.0, 53.0, 50.0];
        let mut ref_rsi = Rsi::new(3).unwrap();
        ref_rsi.update(50.0);
        let want: Vec<f64> = series
            .iter()
            .map(|&x| ref_rsi.update(x).unwrap_or(f64::NAN))
            .collect();
        assert!(bits_eq(&rsi.batch_nan(&series), &want));
    }

    #[test]
    fn batch_nan_too_short_to_seed_falls_back() {
        // n <= period: routed to the exact replay (cannot seed yet).
        let series = [10.0, 11.0, 12.0];
        let mut rsi = Rsi::new(3).unwrap();
        assert!(bits_eq(&rsi.batch_nan(&series), &rsi_replay(3, &series)));
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn rsi_matches_naive(
            period in 1usize..20,
            prices in proptest::collection::vec(1.0_f64..1000.0, 0..150),
        ) {
            let mut rsi = Rsi::new(period).unwrap();
            let got = rsi.batch(&prices);
            let want = rsi_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() < 1e-7,
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}