Skip to main content

wickra_core/indicators/
rsi.rs

1//! Relative Strength Index using Wilder's smoothing.
2
3use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
6/// Relative Strength Index (Wilder, 1978).
7///
8/// Uses Wilder's smoothing (an EMA with `alpha = 1 / period`). The first output
9/// is produced after `period + 1` inputs: the seed averages the first `period`
10/// gains and losses, and the first emitted RSI corresponds to the input at
11/// index `period`.
12///
13/// # Example
14///
15/// ```
16/// use wickra_core::{Indicator, Rsi};
17///
18/// let mut indicator = Rsi::new(3).unwrap();
19/// let mut last = None;
20/// for i in 0..80 {
21///     last = indicator.update(100.0 + f64::from(i));
22/// }
23/// assert!(last.is_some());
24/// ```
25#[derive(Debug, Clone)]
26pub struct Rsi {
27    period: usize,
28    prev_close: Option<f64>,
29    // Wilder seeds with the simple average of the first `period` gains/losses,
30    // then transitions to recursive smoothing.
31    seed_buf_gains: Vec<f64>,
32    seed_buf_losses: Vec<f64>,
33    avg_gain: Option<f64>,
34    avg_loss: Option<f64>,
35    last_value: Option<f64>,
36}
37
38impl Rsi {
39    /// Construct an RSI with the given Wilder period.
40    ///
41    /// # Errors
42    ///
43    /// Returns [`Error::PeriodZero`] if `period == 0`.
44    pub fn new(period: usize) -> Result<Self> {
45        if period == 0 {
46            return Err(Error::PeriodZero);
47        }
48        Ok(Self {
49            period,
50            prev_close: None,
51            seed_buf_gains: Vec::with_capacity(period),
52            seed_buf_losses: Vec::with_capacity(period),
53            avg_gain: None,
54            avg_loss: None,
55            last_value: None,
56        })
57    }
58
59    /// Configured period.
60    pub const fn period(&self) -> usize {
61        self.period
62    }
63
64    /// Current value if available.
65    pub const fn value(&self) -> Option<f64> {
66        self.last_value
67    }
68
69    fn rsi_from_avgs(avg_gain: f64, avg_loss: f64) -> f64 {
70        if avg_loss == 0.0 {
71            if avg_gain == 0.0 {
72                // No movement at all -> RSI undefined; standard convention returns 50.
73                50.0
74            } else {
75                100.0
76            }
77        } else {
78            let rs = avg_gain / avg_loss;
79            100.0 - 100.0 / (1.0 + rs)
80        }
81    }
82}
83
84impl Indicator for Rsi {
85    type Input = f64;
86    type Output = f64;
87
88    fn update(&mut self, input: f64) -> Option<f64> {
89        if !input.is_finite() {
90            return self.last_value;
91        }
92
93        let Some(prev) = self.prev_close else {
94            self.prev_close = Some(input);
95            return None;
96        };
97        self.prev_close = Some(input);
98
99        let diff = input - prev;
100        let gain = if diff > 0.0 { diff } else { 0.0 };
101        let loss = if diff < 0.0 { -diff } else { 0.0 };
102
103        if let (Some(ag), Some(al)) = (self.avg_gain, self.avg_loss) {
104            let n = self.period as f64;
105            let new_ag = (ag * (n - 1.0) + gain) / n;
106            let new_al = (al * (n - 1.0) + loss) / n;
107            self.avg_gain = Some(new_ag);
108            self.avg_loss = Some(new_al);
109            let v = Self::rsi_from_avgs(new_ag, new_al);
110            self.last_value = Some(v);
111            return Some(v);
112        }
113
114        self.seed_buf_gains.push(gain);
115        self.seed_buf_losses.push(loss);
116        if self.seed_buf_gains.len() == self.period {
117            let ag = self.seed_buf_gains.iter().sum::<f64>() / self.period as f64;
118            let al = self.seed_buf_losses.iter().sum::<f64>() / self.period as f64;
119            self.avg_gain = Some(ag);
120            self.avg_loss = Some(al);
121            let v = Self::rsi_from_avgs(ag, al);
122            self.last_value = Some(v);
123            return Some(v);
124        }
125        None
126    }
127
128    fn reset(&mut self) {
129        self.prev_close = None;
130        self.seed_buf_gains.clear();
131        self.seed_buf_losses.clear();
132        self.avg_gain = None;
133        self.avg_loss = None;
134        self.last_value = None;
135    }
136
137    fn warmup_period(&self) -> usize {
138        self.period + 1
139    }
140
141    fn is_ready(&self) -> bool {
142        self.last_value.is_some()
143    }
144
145    fn name(&self) -> &'static str {
146        "RSI"
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Independent reference: Wilder RSI computed straight from the definition.
    ///
    /// It deliberately shares no code with [`Rsi`], so `rsi_matches_naive`
    /// compares two separately written implementations.
    fn rsi_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        let mut out = vec![None; prices.len()];
        let mut gains: Vec<f64> = Vec::new();
        let mut losses: Vec<f64> = Vec::new();
        let mut avg_gain: Option<f64> = None;
        let mut avg_loss: Option<f64> = None;
        let rsi_val = |ag: f64, al: f64| -> f64 {
            if al == 0.0 {
                if ag == 0.0 {
                    50.0
                } else {
                    100.0
                }
            } else {
                100.0 - 100.0 / (1.0 + ag / al)
            }
        };
        for i in 1..prices.len() {
            let diff = prices[i] - prices[i - 1];
            let gain = if diff > 0.0 { diff } else { 0.0 };
            let loss = if diff < 0.0 { -diff } else { 0.0 };
            if let (Some(ag), Some(al)) = (avg_gain, avg_loss) {
                let nag = (ag * (n - 1.0) + gain) / n;
                let nal = (al * (n - 1.0) + loss) / n;
                avg_gain = Some(nag);
                avg_loss = Some(nal);
                out[i] = Some(rsi_val(nag, nal));
            } else {
                gains.push(gain);
                losses.push(loss);
                if gains.len() == period {
                    let ag = gains.iter().sum::<f64>() / n;
                    let al = losses.iter().sum::<f64>() / n;
                    avg_gain = Some(ag);
                    avg_loss = Some(al);
                    out[i] = Some(rsi_val(ag, al));
                }
            }
        }
        out
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Rsi::new(0), Err(Error::PeriodZero)));
    }

    /// Cover the const accessors `period` / `value` and the `Indicator::name`
    /// implementation. `warmup_period` is covered by
    /// `warmup_period_is_period_plus_one`.
    #[test]
    fn accessors_and_metadata() {
        let mut rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.period(), 14);
        assert_eq!(rsi.name(), "RSI");
        assert_eq!(rsi.value(), None);
        for i in 1..=15 {
            rsi.update(100.0 + f64::from(i));
        }
        assert!(rsi.value().is_some());
    }

    /// Cover the "no movement" branch of the `rsi_naive` helper. When both
    /// `avg_gain` and `avg_loss` are 0 (a perfectly flat series), the helper
    /// must return the neutral 50.0. The proptest reference uses random inputs
    /// that essentially never produce zero gains AND zero losses at the same
    /// time, so that branch would otherwise go unexercised.
    #[test]
    fn naive_helper_flat_series_yields_50() {
        let ks = rsi_naive(&[42.0; 20], 5);
        for r in ks.into_iter().skip(5) {
            assert_eq!(r.expect("ready after period+1 inputs"), 50.0);
        }
    }

    /// Cover the saturating `100.0` branch of the `rsi_naive` helper. Strictly
    /// increasing prices give `avg_loss == 0` while `avg_gain > 0`, the
    /// textbook overbought saturation case. Random proptest inputs virtually
    /// never satisfy `al == 0 && ag != 0`, so this needs an explicit monotone
    /// series.
    #[test]
    fn naive_helper_monotone_up_yields_100() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let ks = rsi_naive(&prices, 5);
        for r in ks.into_iter().skip(5) {
            assert_eq!(r.expect("ready after period+1 inputs"), 100.0);
        }
    }

    #[test]
    fn warmup_period_is_period_plus_one() {
        let rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.warmup_period(), 15);
    }

    #[test]
    fn first_emission_at_index_period() {
        // RSI(14) needs 14 diffs => 15 inputs before first value.
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        // indices 0..14 -> None, index 14 -> first Some
        for x in &out[..14] {
            assert!(x.is_none());
        }
        assert!(out[14].is_some());
    }

    #[test]
    fn pure_uptrend_yields_rsi_100() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let values: Vec<f64> = rsi.batch(&prices).into_iter().flatten().collect();
        // 20 inputs with period 14 emit at indices 14..=19; pinning the count
        // keeps the range check below from passing vacuously.
        assert_eq!(values.len(), 6);
        // All diffs are positive => avg_loss == 0 => RSI == 100
        for v in values {
            assert_relative_eq!(v, 100.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn pure_downtrend_yields_rsi_0() {
        let prices: Vec<f64> = (1..=20).rev().map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let values: Vec<f64> = rsi.batch(&prices).into_iter().flatten().collect();
        assert_eq!(values.len(), 6);
        // All diffs are negative => avg_gain == 0 => RSI == 0
        for v in values {
            assert_relative_eq!(v, 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn flat_series_yields_rsi_50() {
        let prices = [10.0_f64; 30];
        let mut rsi = Rsi::new(14).unwrap();
        let values: Vec<f64> = rsi.batch(&prices).into_iter().flatten().collect();
        // 30 inputs with period 14 emit at indices 14..=29.
        assert_eq!(values.len(), 16);
        for v in values {
            assert_relative_eq!(v, 50.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn classic_wilder_textbook_values() {
        // Wilder's original example from "New Concepts in Technical Trading Systems",
        // 14-period RSI. We compute the first value at index 14 and compare to the
        // value Wilder publishes (~70.46).
        // Source: classic textbook table, reproduced in many references (e.g. Investopedia).
        let prices = [
            44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, 46.08, 45.89, 46.03,
            45.61, 46.28, 46.28,
        ];
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        let first = out[14].expect("first RSI emitted at index period");
        assert_relative_eq!(first, 70.464, epsilon = 0.05);
    }

    #[test]
    fn rsi_stays_in_0_100_range() {
        let prices: Vec<f64> = (0..200)
            .map(|i| 100.0 + (f64::from(i) * 0.7).sin() * 10.0)
            .collect();
        let mut rsi = Rsi::new(14).unwrap();
        let values: Vec<f64> = rsi.batch(&prices).into_iter().flatten().collect();
        // 200 inputs with period 14 emit at indices 14..=199.
        assert_eq!(values.len(), 186);
        for x in values {
            assert!((0.0..=100.0).contains(&x), "RSI out of range: {x}");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut rsi = Rsi::new(5).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 2.0, 4.0, 5.0, 6.0]);
        assert!(rsi.is_ready());
        rsi.reset();
        assert!(!rsi.is_ready());
        assert_eq!(rsi.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=40)
            .map(|i| (f64::from(i) * 0.3).sin() * 5.0 + f64::from(i))
            .collect();
        let mut a = Rsi::new(7).unwrap();
        let mut b = Rsi::new(7).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut rsi = Rsi::new(3).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 4.0]);
        let before = rsi.value();
        assert!(before.is_some());
        assert_eq!(rsi.update(f64::NAN), before);
        assert_eq!(rsi.update(f64::INFINITY), before);
        assert_eq!(rsi.value(), before);
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn rsi_matches_naive(
            period in 1usize..20,
            prices in proptest::collection::vec(1.0_f64..1000.0, 0..150),
        ) {
            let mut rsi = Rsi::new(period).unwrap();
            let got = rsi.batch(&prices);
            let want = rsi_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() < 1e-7,
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}