Skip to main content

wickra_core/indicators/
rsi.rs

1//! Relative Strength Index using Wilder's smoothing.
2
3use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// Relative Strength Index, as introduced by J. Welles Wilder (1978).
///
/// Gains and losses are smoothed with Wilder's method, i.e. an exponential moving
/// average whose `alpha` is `1 / period`. Warm-up consumes `period + 1` inputs: the
/// first `period` close-to-close changes are averaged to seed the smoothing, so the
/// earliest RSI value belongs to the input at index `period`.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, Rsi};
///
/// let mut indicator = Rsi::new(3).unwrap();
/// let mut last = None;
/// for i in 0..80 {
///     last = indicator.update(100.0 + f64::from(i));
/// }
/// assert!(last.is_some());
/// ```
#[derive(Debug, Clone)]
pub struct Rsi {
    /// Wilder look-back length (validated as non-zero by [`Rsi::new`]).
    period: usize,
    /// Cached `(period - 1) as f64`, the weight of the old average in each step.
    n_minus_1: f64,
    /// Cached `1 / period`; the per-tick recurrence multiplies by this instead of
    /// dividing by `period`.
    inv_period: f64,
    /// Last accepted close. Only meaningful when `has_prev` is true; stored as a
    /// plain `f64` plus a flag rather than `Option<f64>` to keep the hot path lean.
    prev_close: f64,
    /// Whether `prev_close` holds a real close yet.
    has_prev: bool,
    /// Warm-up buffer of the first `period` gains; their simple mean seeds
    /// `avg_gain` before recursive smoothing takes over.
    seed_buf_gains: Vec<f64>,
    /// Warm-up buffer of the first `period` losses (as non-negative magnitudes);
    /// their simple mean seeds `avg_loss`.
    seed_buf_losses: Vec<f64>,
    /// Smoothed average gain. Only meaningful when `avgs_seeded` is true; kept as
    /// a bare `f64` + flag so the recurrence avoids `Option` tag checks.
    avg_gain: f64,
    /// Smoothed average loss, with the same validity rule as `avg_gain`.
    avg_loss: f64,
    /// Set once the seed averages have been computed.
    avgs_seeded: bool,
    /// Most recently emitted RSI, or `None` before warm-up completes.
    last_value: Option<f64>,
}
48
49impl Rsi {
50    /// Construct an RSI with the given Wilder period.
51    ///
52    /// # Errors
53    ///
54    /// Returns [`Error::PeriodZero`] if `period == 0`.
55    pub fn new(period: usize) -> Result<Self> {
56        if period == 0 {
57            return Err(Error::PeriodZero);
58        }
59        Ok(Self {
60            period,
61            n_minus_1: (period - 1) as f64,
62            inv_period: 1.0 / period as f64,
63            prev_close: 0.0,
64            has_prev: false,
65            seed_buf_gains: Vec::with_capacity(period),
66            seed_buf_losses: Vec::with_capacity(period),
67            avg_gain: 0.0,
68            avg_loss: 0.0,
69            avgs_seeded: false,
70            last_value: None,
71        })
72    }
73
74    /// Configured period.
75    pub const fn period(&self) -> usize {
76        self.period
77    }
78
79    /// Current value if available.
80    pub const fn value(&self) -> Option<f64> {
81        self.last_value
82    }
83
84    fn rsi_from_avgs(avg_gain: f64, avg_loss: f64) -> f64 {
85        // Algebraically `100 - 100/(1 + ag/al)` collapses to `100·ag/(ag+al)`,
86        // which needs a single division instead of two and removes the separate
87        // `rs` step. Edge cases stay exact: `al == 0, ag > 0` gives `100·ag/ag =
88        // 100`; `ag == 0, al > 0` gives `0`; both zero (no movement) is the
89        // undefined case and returns the neutral 50.
90        let denom = avg_gain + avg_loss;
91        if denom == 0.0 {
92            50.0
93        } else {
94            100.0 * avg_gain / denom
95        }
96    }
97}
98
99impl Indicator for Rsi {
100    type Input = f64;
101    type Output = f64;
102
103    fn update(&mut self, input: f64) -> Option<f64> {
104        if !input.is_finite() {
105            return self.last_value;
106        }
107
108        if !self.has_prev {
109            self.prev_close = input;
110            self.has_prev = true;
111            return None;
112        }
113        let prev = self.prev_close;
114        self.prev_close = input;
115
116        let diff = input - prev;
117        let gain = if diff > 0.0 { diff } else { 0.0 };
118        let loss = if diff < 0.0 { -diff } else { 0.0 };
119
120        if self.avgs_seeded {
121            // Wilder smoothing `(prev·(n-1) + x) / n` with the reciprocal hoisted:
122            // a fused multiply-add then a multiply by `1/n`, no per-tick division.
123            let new_ag = self.avg_gain.mul_add(self.n_minus_1, gain) * self.inv_period;
124            let new_al = self.avg_loss.mul_add(self.n_minus_1, loss) * self.inv_period;
125            self.avg_gain = new_ag;
126            self.avg_loss = new_al;
127            let v = Self::rsi_from_avgs(new_ag, new_al);
128            self.last_value = Some(v);
129            return Some(v);
130        }
131
132        self.seed_buf_gains.push(gain);
133        self.seed_buf_losses.push(loss);
134        if self.seed_buf_gains.len() == self.period {
135            let ag = self.seed_buf_gains.iter().sum::<f64>() / self.period as f64;
136            let al = self.seed_buf_losses.iter().sum::<f64>() / self.period as f64;
137            self.avg_gain = ag;
138            self.avg_loss = al;
139            self.avgs_seeded = true;
140            let v = Self::rsi_from_avgs(ag, al);
141            self.last_value = Some(v);
142            return Some(v);
143        }
144        None
145    }
146
147    fn reset(&mut self) {
148        self.prev_close = 0.0;
149        self.has_prev = false;
150        self.seed_buf_gains.clear();
151        self.seed_buf_losses.clear();
152        self.avg_gain = 0.0;
153        self.avg_loss = 0.0;
154        self.avgs_seeded = false;
155        self.last_value = None;
156    }
157
158    fn warmup_period(&self) -> usize {
159        self.period + 1
160    }
161
162    fn is_ready(&self) -> bool {
163        self.last_value.is_some()
164    }
165
166    fn name(&self) -> &'static str {
167        "RSI"
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Independent reference: Wilder RSI computed straight from the definition.
    ///
    /// It deliberately uses the textbook `100 - 100/(1 + RS)` form and plain
    /// divisions, so it shares as little arithmetic as possible with the
    /// production implementation.
    fn rsi_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        let mut out = vec![None; prices.len()];
        let mut gains: Vec<f64> = Vec::new();
        let mut losses: Vec<f64> = Vec::new();
        let mut avg_gain: Option<f64> = None;
        let mut avg_loss: Option<f64> = None;
        let rsi_val = |ag: f64, al: f64| -> f64 {
            if al == 0.0 {
                if ag == 0.0 {
                    50.0
                } else {
                    100.0
                }
            } else {
                100.0 - 100.0 / (1.0 + ag / al)
            }
        };
        for i in 1..prices.len() {
            let diff = prices[i] - prices[i - 1];
            let gain = if diff > 0.0 { diff } else { 0.0 };
            let loss = if diff < 0.0 { -diff } else { 0.0 };
            if let (Some(ag), Some(al)) = (avg_gain, avg_loss) {
                let nag = (ag * (n - 1.0) + gain) / n;
                let nal = (al * (n - 1.0) + loss) / n;
                avg_gain = Some(nag);
                avg_loss = Some(nal);
                out[i] = Some(rsi_val(nag, nal));
            } else {
                gains.push(gain);
                losses.push(loss);
                if gains.len() == period {
                    let ag = gains.iter().sum::<f64>() / n;
                    let al = losses.iter().sum::<f64>() / n;
                    avg_gain = Some(ag);
                    avg_loss = Some(al);
                    out[i] = Some(rsi_val(ag, al));
                }
            }
        }
        out
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Rsi::new(0), Err(Error::PeriodZero)));
    }

    /// Cover the const accessors `Rsi::period` / `Rsi::value` and the
    /// `Indicator::name` implementation. `warmup_period` is covered separately
    /// by `warmup_period_is_period_plus_one`.
    #[test]
    fn accessors_and_metadata() {
        let mut rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.period(), 14);
        assert_eq!(rsi.name(), "RSI");
        assert_eq!(rsi.value(), None);
        for i in 1..=15 {
            rsi.update(100.0 + f64::from(i));
        }
        assert!(rsi.value().is_some());
    }

    /// Cover the `ag == 0` branch inside the `al == 0` case of the test helper
    /// `rsi_naive`. When both `avg_gain` and `avg_loss` are 0 (a perfectly flat
    /// series), the helper must return the neutral 50.0. The proptest reference
    /// uses random inputs that essentially never hit zero gains AND zero losses
    /// simultaneously, leaving this branch dead in the helper.
    #[test]
    fn naive_helper_flat_series_yields_50() {
        let ks = rsi_naive(&[42.0; 20], 5);
        for r in ks.into_iter().skip(5) {
            assert_eq!(r.expect("ready after period+1 inputs"), 50.0);
        }
    }

    /// Cover the saturated `100.0` branch of the test helper `rsi_naive`.
    /// Strictly increasing prices give `avg_loss == 0` while `avg_gain > 0`,
    /// the textbook overbought saturation case. Random proptest inputs
    /// virtually never satisfy `al == 0 && ag != 0`, so this needs an
    /// explicit monotone series.
    #[test]
    fn naive_helper_monotone_up_yields_100() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let ks = rsi_naive(&prices, 5);
        for r in ks.into_iter().skip(5) {
            assert_eq!(r.expect("ready after period+1 inputs"), 100.0);
        }
    }

    #[test]
    fn warmup_period_is_period_plus_one() {
        let rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.warmup_period(), 15);
    }

    #[test]
    fn first_emission_at_index_period() {
        // RSI(14) needs 14 diffs => 15 inputs before first value.
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        // indices 0..14 -> None, index 14 -> first Some
        for x in &out[..14] {
            assert!(x.is_none());
        }
        assert!(out[14].is_some());
    }

    #[test]
    fn pure_uptrend_yields_rsi_100() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        // All diffs are positive => avg_loss == 0 => RSI == 100
        for v in out.iter().filter_map(|x| x.as_ref()) {
            assert_relative_eq!(*v, 100.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn pure_downtrend_yields_rsi_0() {
        let prices: Vec<f64> = (1..=20).rev().map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for v in out.iter().filter_map(|x| x.as_ref()) {
            assert_relative_eq!(*v, 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn flat_series_yields_rsi_50() {
        let prices = [10.0_f64; 30];
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for v in out.iter().filter_map(|x| x.as_ref()) {
            assert_relative_eq!(*v, 50.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn classic_wilder_textbook_values() {
        // Wilder's original example from "New Concepts in Technical Trading Systems",
        // 14-period RSI. We compute the first value at index 14 and compare to the
        // value Wilder publishes (~70.46).
        // Source: classic textbook table, reproduced in many references (e.g. Investopedia).
        // Over the 14 changes the gains sum to 3.34 and the losses to 1.40, so the
        // seed RSI is 100 * 3.34 / (3.34 + 1.40) ~= 70.464.
        let prices = [
            44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, 46.08, 45.89, 46.03,
            45.61, 46.28, 46.28,
        ];
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        let first = out[14].expect("first RSI emitted at index period");
        assert_relative_eq!(first, 70.464, epsilon = 0.05);
    }

    #[test]
    fn rsi_stays_in_0_100_range() {
        let prices: Vec<f64> = (0..200)
            .map(|i| 100.0 + (f64::from(i) * 0.7).sin() * 10.0)
            .collect();
        let mut rsi = Rsi::new(14).unwrap();
        for x in rsi.batch(&prices).into_iter().flatten() {
            assert!((0.0..=100.0).contains(&x), "RSI out of range: {x}");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut rsi = Rsi::new(5).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 2.0, 4.0, 5.0, 6.0]);
        assert!(rsi.is_ready());
        rsi.reset();
        assert!(!rsi.is_ready());
        assert_eq!(rsi.update(1.0), None);
    }

    /// After `reset`, replaying a series must reproduce exactly what a brand-new
    /// instance emits. This ensures no seed-buffer or smoothing state leaks across
    /// a reset.
    #[test]
    fn reset_then_replay_matches_fresh_instance() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 50.0 + (f64::from(i) * 0.45).cos() * 3.0)
            .collect();
        let mut reused = Rsi::new(9).unwrap();
        reused.batch(&prices);
        reused.reset();
        let mut fresh = Rsi::new(9).unwrap();
        assert_eq!(reused.batch(&prices), fresh.batch(&prices));
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=40)
            .map(|i| (f64::from(i) * 0.3).sin() * 5.0 + f64::from(i))
            .collect();
        let mut a = Rsi::new(7).unwrap();
        let mut b = Rsi::new(7).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut rsi = Rsi::new(3).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 4.0]);
        let before = rsi.value();
        assert!(before.is_some());
        assert_eq!(rsi.update(f64::NAN), before);
        assert_eq!(rsi.update(f64::INFINITY), before);
        assert_eq!(rsi.value(), before);
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn rsi_matches_naive(
            period in 1usize..20,
            prices in proptest::collection::vec(1.0_f64..1000.0, 0..150),
        ) {
            let mut rsi = Rsi::new(period).unwrap();
            let got = rsi.batch(&prices);
            let want = rsi_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() < 1e-7,
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}