// Source path: wickra_core/indicators/rsi.rs

//! Relative Strength Index using Wilder's smoothing.

use crate::error::{Error, Result};
use crate::traits::Indicator;

/// Relative Strength Index (Wilder, 1978).
///
/// Uses Wilder's smoothing (an EMA with `alpha = 1 / period`). The first output
/// is produced after `period + 1` finite inputs: the seed averages the first
/// `period` gains and losses, and the first emitted RSI corresponds to the
/// finite input at index `period`.
///
/// Edge conventions:
/// - if the smoothed average loss is zero, the RSI is `100.0`;
/// - if both the smoothed average gain and loss are zero (a perfectly flat
///   series), the RSI is the neutral `50.0`;
/// - non-finite inputs (`NaN`, `±inf`) are ignored: the internal state is left
///   untouched and the most recently emitted value (if any) is returned again.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, Rsi};
///
/// let mut indicator = Rsi::new(3).unwrap();
/// let mut last = None;
/// for i in 0..80 {
///     last = indicator.update(100.0 + f64::from(i));
/// }
/// // A strictly rising series has no losses, so the RSI saturates at 100.
/// assert_eq!(last, Some(100.0));
/// ```
#[derive(Debug, Clone)]
pub struct Rsi {
    /// Wilder smoothing period; never zero (rejected by [`Rsi::new`]).
    period: usize,
    /// Previous finite input, used to form the next close-to-close difference.
    prev_close: Option<f64>,
    // Wilder seeds with the simple average of the first `period` gains/losses,
    // then transitions to recursive smoothing.
    seed_buf_gains: Vec<f64>,
    seed_buf_losses: Vec<f64>,
    /// Smoothed average gain; `Some` once the seed window has been filled.
    avg_gain: Option<f64>,
    /// Smoothed average loss; `Some` once the seed window has been filled.
    avg_loss: Option<f64>,
    /// Most recently emitted RSI value.
    last_value: Option<f64>,
}

38impl Rsi {
39    /// Construct an RSI with the given Wilder period.
40    ///
41    /// # Errors
42    ///
43    /// Returns [`Error::PeriodZero`] if `period == 0`.
44    pub fn new(period: usize) -> Result<Self> {
45        if period == 0 {
46            return Err(Error::PeriodZero);
47        }
48        Ok(Self {
49            period,
50            prev_close: None,
51            seed_buf_gains: Vec::with_capacity(period),
52            seed_buf_losses: Vec::with_capacity(period),
53            avg_gain: None,
54            avg_loss: None,
55            last_value: None,
56        })
57    }
58
59    /// Configured period.
60    pub const fn period(&self) -> usize {
61        self.period
62    }
63
64    /// Current value if available.
65    pub const fn value(&self) -> Option<f64> {
66        self.last_value
67    }
68
69    fn rsi_from_avgs(avg_gain: f64, avg_loss: f64) -> f64 {
70        if avg_loss == 0.0 {
71            if avg_gain == 0.0 {
72                // No movement at all -> RSI undefined; standard convention returns 50.
73                50.0
74            } else {
75                100.0
76            }
77        } else {
78            let rs = avg_gain / avg_loss;
79            100.0 - 100.0 / (1.0 + rs)
80        }
81    }
82}
83
84impl Indicator for Rsi {
85    type Input = f64;
86    type Output = f64;
87
88    fn update(&mut self, input: f64) -> Option<f64> {
89        if !input.is_finite() {
90            return self.last_value;
91        }
92
93        let Some(prev) = self.prev_close else {
94            self.prev_close = Some(input);
95            return None;
96        };
97        self.prev_close = Some(input);
98
99        let diff = input - prev;
100        let gain = if diff > 0.0 { diff } else { 0.0 };
101        let loss = if diff < 0.0 { -diff } else { 0.0 };
102
103        if let (Some(ag), Some(al)) = (self.avg_gain, self.avg_loss) {
104            let n = self.period as f64;
105            let new_ag = (ag * (n - 1.0) + gain) / n;
106            let new_al = (al * (n - 1.0) + loss) / n;
107            self.avg_gain = Some(new_ag);
108            self.avg_loss = Some(new_al);
109            let v = Self::rsi_from_avgs(new_ag, new_al);
110            self.last_value = Some(v);
111            return Some(v);
112        }
113
114        self.seed_buf_gains.push(gain);
115        self.seed_buf_losses.push(loss);
116        if self.seed_buf_gains.len() == self.period {
117            let ag = self.seed_buf_gains.iter().sum::<f64>() / self.period as f64;
118            let al = self.seed_buf_losses.iter().sum::<f64>() / self.period as f64;
119            self.avg_gain = Some(ag);
120            self.avg_loss = Some(al);
121            let v = Self::rsi_from_avgs(ag, al);
122            self.last_value = Some(v);
123            return Some(v);
124        }
125        None
126    }
127
128    fn reset(&mut self) {
129        self.prev_close = None;
130        self.seed_buf_gains.clear();
131        self.seed_buf_losses.clear();
132        self.avg_gain = None;
133        self.avg_loss = None;
134        self.last_value = None;
135    }
136
137    fn warmup_period(&self) -> usize {
138        self.period + 1
139    }
140
141    fn is_ready(&self) -> bool {
142        self.last_value.is_some()
143    }
144
145    fn name(&self) -> &'static str {
146        "RSI"
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Independent reference: Wilder RSI computed straight from the definition.
    fn rsi_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        let mut out = vec![None; prices.len()];
        let mut gains: Vec<f64> = Vec::new();
        let mut losses: Vec<f64> = Vec::new();
        let mut avg_gain: Option<f64> = None;
        let mut avg_loss: Option<f64> = None;
        let rsi_val = |ag: f64, al: f64| -> f64 {
            if al == 0.0 {
                if ag == 0.0 {
                    50.0
                } else {
                    100.0
                }
            } else {
                100.0 - 100.0 / (1.0 + ag / al)
            }
        };
        for i in 1..prices.len() {
            let diff = prices[i] - prices[i - 1];
            let gain = if diff > 0.0 { diff } else { 0.0 };
            let loss = if diff < 0.0 { -diff } else { 0.0 };
            if let (Some(ag), Some(al)) = (avg_gain, avg_loss) {
                let nag = (ag * (n - 1.0) + gain) / n;
                let nal = (al * (n - 1.0) + loss) / n;
                avg_gain = Some(nag);
                avg_loss = Some(nal);
                out[i] = Some(rsi_val(nag, nal));
            } else {
                gains.push(gain);
                losses.push(loss);
                if gains.len() == period {
                    let ag = gains.iter().sum::<f64>() / n;
                    let al = losses.iter().sum::<f64>() / n;
                    avg_gain = Some(ag);
                    avg_loss = Some(al);
                    out[i] = Some(rsi_val(ag, al));
                }
            }
        }
        out
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Rsi::new(0), Err(Error::PeriodZero)));
    }

    /// Cover the const accessors `period` / `value` and the `name` method of
    /// the `Indicator` impl. `warmup_period` is covered separately by
    /// `warmup_period_is_period_plus_one`.
    #[test]
    fn accessors_and_metadata() {
        let mut rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.period(), 14);
        assert_eq!(rsi.name(), "RSI");
        assert_eq!(rsi.value(), None);
        for i in 1..=15 {
            rsi.update(100.0 + f64::from(i));
        }
        assert!(rsi.value().is_some());
    }

    /// Cover the flat-series branch (`al == 0.0 && ag == 0.0`) of the
    /// test-helper `rsi_naive`: when both averages are 0 the helper must
    /// return the neutral 50.0. The proptest reference uses random inputs that
    /// essentially never produce zero gains AND zero losses at the same time,
    /// so without this test that branch would never run.
    #[test]
    fn naive_helper_flat_series_yields_50() {
        let ks = rsi_naive(&[42.0; 20], 5);
        for r in ks.into_iter().skip(5) {
            assert_eq!(r.expect("ready after period+1 inputs"), 50.0);
        }
    }

    #[test]
    fn warmup_period_is_period_plus_one() {
        let rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.warmup_period(), 15);
    }

    #[test]
    fn first_emission_at_index_period() {
        // RSI(14) needs 14 diffs => 15 inputs before first value.
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        // indices 0..14 -> None, index 14 -> first Some
        for x in &out[..14] {
            assert!(x.is_none());
        }
        assert!(out[14].is_some());
    }

    #[test]
    fn pure_uptrend_yields_rsi_100() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        // All diffs are positive => avg_loss == 0 => RSI == 100
        for v in out.iter().filter_map(|x| x.as_ref()) {
            assert_relative_eq!(*v, 100.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn pure_downtrend_yields_rsi_0() {
        let prices: Vec<f64> = (1..=20).rev().map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for v in out.iter().filter_map(|x| x.as_ref()) {
            assert_relative_eq!(*v, 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn flat_series_yields_rsi_50() {
        let prices = [10.0_f64; 30];
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for v in out.iter().filter_map(|x| x.as_ref()) {
            assert_relative_eq!(*v, 50.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn classic_wilder_textbook_values() {
        // Wilder's original example from "New Concepts in Technical Trading Systems",
        // 14-period RSI. We compute the first value at index 14 and compare to the
        // value Wilder publishes (~70.46): gains sum to 3.34 and losses to 1.40,
        // so RS = (3.34 / 14) / (1.40 / 14) and RSI = 100 - 100 / (1 + RS).
        // Source: classic textbook table, reproduced in many references (e.g. Investopedia).
        let prices = [
            44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, 46.08, 45.89, 46.03,
            45.61, 46.28, 46.28,
        ];
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        let first = out[14].expect("first RSI emitted at index period");
        assert_relative_eq!(first, 70.464, epsilon = 0.05);
    }

    /// Worked example of one recursive Wilder step (period 2). Diffs +1, -1 seed
    /// avg_gain = avg_loss = 0.5 (RSI 50). The next diff, +1, gives
    /// avg_gain = (0.5 * 1 + 1) / 2 = 0.75 and avg_loss = (0.5 * 1 + 0) / 2 = 0.25,
    /// so RS = 3 and RSI = 75. All of these values are exact in binary floating point.
    #[test]
    fn wilder_smoothing_hand_computed() {
        let mut rsi = Rsi::new(2).unwrap();
        assert_eq!(
            rsi.batch(&[1.0, 2.0, 1.0, 2.0]),
            vec![None, None, Some(50.0), Some(75.0)]
        );
    }

    #[test]
    fn rsi_stays_in_0_100_range() {
        let prices: Vec<f64> = (0..200)
            .map(|i| 100.0 + (f64::from(i) * 0.7).sin() * 10.0)
            .collect();
        let mut rsi = Rsi::new(14).unwrap();
        for x in rsi.batch(&prices).into_iter().flatten() {
            assert!((0.0..=100.0).contains(&x), "RSI out of range: {x}");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut rsi = Rsi::new(5).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 2.0, 4.0, 5.0, 6.0]);
        assert!(rsi.is_ready());
        rsi.reset();
        assert!(!rsi.is_ready());
        assert_eq!(rsi.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=40)
            .map(|i| (f64::from(i) * 0.3).sin() * 5.0 + f64::from(i))
            .collect();
        let mut a = Rsi::new(7).unwrap();
        let mut b = Rsi::new(7).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut rsi = Rsi::new(3).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 4.0]);
        let before = rsi.value();
        assert!(before.is_some());
        assert_eq!(rsi.update(f64::NAN), before);
        assert_eq!(rsi.update(f64::INFINITY), before);
        assert_eq!(rsi.value(), before);
    }

    /// A non-finite tick during warm-up must neither count as a difference nor
    /// replace the reference close.
    #[test]
    fn non_finite_input_during_warmup_is_skipped() {
        let mut rsi = Rsi::new(2).unwrap();
        assert_eq!(rsi.update(1.0), None);
        assert_eq!(rsi.update(f64::NAN), None);
        assert_eq!(rsi.update(2.0), None);
        // Two +1 diffs (1 -> 2 -> 3): no losses, so the seeded RSI is exactly 100.
        assert_eq!(rsi.update(3.0), Some(100.0));
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn rsi_matches_naive(
            period in 1usize..20,
            prices in proptest::collection::vec(1.0_f64..1000.0, 0..150),
        ) {
            let mut rsi = Rsi::new(period).unwrap();
            let got = rsi.batch(&prices);
            let want = rsi_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() < 1e-7,
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}
367}