1use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// Relative Strength Index (RSI) computed incrementally with Wilder smoothing.
///
/// The first finite input only serves as a reference close. The next `period`
/// close-to-close changes are split into gains and losses and averaged
/// arithmetically to seed the averages. Every later change updates them with
/// Wilder's recursive smoothing, `avg = (avg * (period - 1) + x) / period`.
/// A window with neither gains nor losses yields `50.0`, and a window with
/// gains but no losses yields `100.0`.
#[derive(Debug, Clone)]
pub struct Rsi {
    // Number of price changes in the averaging window; never 0 (rejected by `new`).
    period: usize,
    // Most recent finite close; `None` until the first finite input arrives.
    prev_close: Option<f64>,
    // Gains collected during warmup; the seed average is taken once `period` exist.
    seed_buf_gains: Vec<f64>,
    // Losses (stored as positive magnitudes) collected alongside `seed_buf_gains`.
    seed_buf_losses: Vec<f64>,
    // Smoothed average gain; `Some` once the seed window is complete.
    avg_gain: Option<f64>,
    // Smoothed average loss; `Some` once the seed window is complete.
    avg_loss: Option<f64>,
    // Most recently emitted RSI; re-returned when a non-finite input is ignored.
    last_value: Option<f64>,
}
37
38impl Rsi {
39 pub fn new(period: usize) -> Result<Self> {
45 if period == 0 {
46 return Err(Error::PeriodZero);
47 }
48 Ok(Self {
49 period,
50 prev_close: None,
51 seed_buf_gains: Vec::with_capacity(period),
52 seed_buf_losses: Vec::with_capacity(period),
53 avg_gain: None,
54 avg_loss: None,
55 last_value: None,
56 })
57 }
58
59 pub const fn period(&self) -> usize {
61 self.period
62 }
63
64 pub const fn value(&self) -> Option<f64> {
66 self.last_value
67 }
68
69 fn rsi_from_avgs(avg_gain: f64, avg_loss: f64) -> f64 {
70 if avg_loss == 0.0 {
71 if avg_gain == 0.0 {
72 50.0
74 } else {
75 100.0
76 }
77 } else {
78 let rs = avg_gain / avg_loss;
79 100.0 - 100.0 / (1.0 + rs)
80 }
81 }
82}
83
84impl Indicator for Rsi {
85 type Input = f64;
86 type Output = f64;
87
88 fn update(&mut self, input: f64) -> Option<f64> {
89 if !input.is_finite() {
90 return self.last_value;
91 }
92
93 let Some(prev) = self.prev_close else {
94 self.prev_close = Some(input);
95 return None;
96 };
97 self.prev_close = Some(input);
98
99 let diff = input - prev;
100 let gain = if diff > 0.0 { diff } else { 0.0 };
101 let loss = if diff < 0.0 { -diff } else { 0.0 };
102
103 if let (Some(ag), Some(al)) = (self.avg_gain, self.avg_loss) {
104 let n = self.period as f64;
105 let new_ag = (ag * (n - 1.0) + gain) / n;
106 let new_al = (al * (n - 1.0) + loss) / n;
107 self.avg_gain = Some(new_ag);
108 self.avg_loss = Some(new_al);
109 let v = Self::rsi_from_avgs(new_ag, new_al);
110 self.last_value = Some(v);
111 return Some(v);
112 }
113
114 self.seed_buf_gains.push(gain);
115 self.seed_buf_losses.push(loss);
116 if self.seed_buf_gains.len() == self.period {
117 let ag = self.seed_buf_gains.iter().sum::<f64>() / self.period as f64;
118 let al = self.seed_buf_losses.iter().sum::<f64>() / self.period as f64;
119 self.avg_gain = Some(ag);
120 self.avg_loss = Some(al);
121 let v = Self::rsi_from_avgs(ag, al);
122 self.last_value = Some(v);
123 return Some(v);
124 }
125 None
126 }
127
128 fn reset(&mut self) {
129 self.prev_close = None;
130 self.seed_buf_gains.clear();
131 self.seed_buf_losses.clear();
132 self.avg_gain = None;
133 self.avg_loss = None;
134 self.last_value = None;
135 }
136
137 fn warmup_period(&self) -> usize {
138 self.period + 1
139 }
140
141 fn is_ready(&self) -> bool {
142 self.last_value.is_some()
143 }
144
145 fn name(&self) -> &'static str {
146 "RSI"
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Reference RSI written as one straight-line loop over `prices`.
    ///
    /// Mirrors the streaming contract: slot `i` holds the RSI after
    /// `prices[..=i]`, the first value appears at index `period`, and earlier
    /// slots are `None`. Assumes every price is finite (no skipping logic).
    fn rsi_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        let mut out = vec![None; prices.len()];
        let mut gains: Vec<f64> = Vec::new();
        let mut losses: Vec<f64> = Vec::new();
        let mut avg_gain: Option<f64> = None;
        let mut avg_loss: Option<f64> = None;
        let rsi_val = |ag: f64, al: f64| -> f64 {
            if al == 0.0 {
                if ag == 0.0 {
                    50.0
                } else {
                    100.0
                }
            } else {
                100.0 - 100.0 / (1.0 + ag / al)
            }
        };
        for i in 1..prices.len() {
            let diff = prices[i] - prices[i - 1];
            let gain = if diff > 0.0 { diff } else { 0.0 };
            let loss = if diff < 0.0 { -diff } else { 0.0 };
            if let (Some(ag), Some(al)) = (avg_gain, avg_loss) {
                let nag = (ag * (n - 1.0) + gain) / n;
                let nal = (al * (n - 1.0) + loss) / n;
                avg_gain = Some(nag);
                avg_loss = Some(nal);
                out[i] = Some(rsi_val(nag, nal));
            } else {
                gains.push(gain);
                losses.push(loss);
                if gains.len() == period {
                    let ag = gains.iter().sum::<f64>() / n;
                    let al = losses.iter().sum::<f64>() / n;
                    avg_gain = Some(ag);
                    avg_loss = Some(al);
                    out[i] = Some(rsi_val(ag, al));
                }
            }
        }
        out
    }

    /// Count of `Some` outputs expected from `len` finite inputs: one per input
    /// from index `period` onward. Value checks assert this count first so that
    /// they cannot pass vacuously on an indicator that never emits.
    fn expected_emissions(len: usize, period: usize) -> usize {
        len.saturating_sub(period)
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Rsi::new(0), Err(Error::PeriodZero)));
    }

    #[test]
    fn accessors_and_metadata() {
        let mut rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.period(), 14);
        assert_eq!(rsi.name(), "RSI");
        assert_eq!(rsi.value(), None);
        for i in 1..=15 {
            rsi.update(100.0 + f64::from(i));
        }
        assert!(rsi.value().is_some());
    }

    #[test]
    fn naive_helper_flat_series_yields_50() {
        let ks = rsi_naive(&[42.0; 20], 5);
        for r in ks.into_iter().skip(5) {
            assert_eq!(r.expect("ready after period+1 inputs"), 50.0);
        }
    }

    #[test]
    fn warmup_period_is_period_plus_one() {
        let rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.warmup_period(), 15);
    }

    #[test]
    fn first_emission_at_index_period() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for x in &out[..14] {
            assert!(x.is_none());
        }
        assert!(out[14].is_some());
    }

    #[test]
    fn pure_uptrend_yields_rsi_100() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        let emitted: Vec<f64> = out.iter().flatten().copied().collect();
        assert_eq!(emitted.len(), expected_emissions(prices.len(), 14));
        for v in emitted {
            assert_relative_eq!(v, 100.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn pure_downtrend_yields_rsi_0() {
        let prices: Vec<f64> = (1..=20).rev().map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        let emitted: Vec<f64> = out.iter().flatten().copied().collect();
        assert_eq!(emitted.len(), expected_emissions(prices.len(), 14));
        for v in emitted {
            assert_relative_eq!(v, 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn flat_series_yields_rsi_50() {
        let prices = [10.0_f64; 30];
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        let emitted: Vec<f64> = out.iter().flatten().copied().collect();
        assert_eq!(emitted.len(), expected_emissions(prices.len(), 14));
        for v in emitted {
            assert_relative_eq!(v, 50.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn classic_wilder_textbook_values() {
        let prices = [
            44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, 46.08, 45.89, 46.03,
            45.61, 46.28, 46.28,
        ];
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        let first = out[14].expect("first RSI emitted at index period");
        assert_relative_eq!(first, 70.464, epsilon = 0.05);
    }

    #[test]
    fn rsi_stays_in_0_100_range() {
        let prices: Vec<f64> = (0..200)
            .map(|i| 100.0 + (f64::from(i) * 0.7).sin() * 10.0)
            .collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        assert_eq!(
            out.iter().flatten().count(),
            expected_emissions(prices.len(), 14)
        );
        for x in out.into_iter().flatten() {
            assert!((0.0..=100.0).contains(&x), "RSI out of range: {x}");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut rsi = Rsi::new(5).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 2.0, 4.0, 5.0, 6.0]);
        assert!(rsi.is_ready());
        rsi.reset();
        assert!(!rsi.is_ready());
        assert_eq!(rsi.update(1.0), None);
    }

    #[test]
    fn reset_then_replay_matches_fresh_instance() {
        let prices = [1.0, 2.0, 3.0, 2.0, 4.0, 5.0, 6.0, 5.5, 7.0];
        let mut reused = Rsi::new(5).unwrap();
        reused.batch(&prices);
        reused.reset();
        let mut fresh = Rsi::new(5).unwrap();
        assert_eq!(reused.batch(&prices), fresh.batch(&prices));
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=40)
            .map(|i| (f64::from(i) * 0.3).sin() * 5.0 + f64::from(i))
            .collect();
        let mut a = Rsi::new(7).unwrap();
        let mut b = Rsi::new(7).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut rsi = Rsi::new(3).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 4.0]);
        let before = rsi.value();
        assert!(before.is_some());
        assert_eq!(rsi.update(f64::NAN), before);
        assert_eq!(rsi.update(f64::INFINITY), before);
        assert_eq!(rsi.value(), before);
    }

    #[test]
    fn non_finite_input_does_not_advance_warmup() {
        let mut rsi = Rsi::new(3).unwrap();
        assert_eq!(rsi.update(1.0), None);
        assert_eq!(rsi.update(f64::NAN), None);
        assert_eq!(rsi.update(2.0), None);
        assert_eq!(rsi.update(f64::NEG_INFINITY), None);
        assert_eq!(rsi.update(3.0), None);
        // Fourth finite input completes the 3-change seed window; all gains.
        assert_eq!(rsi.update(4.0), Some(100.0));
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn rsi_matches_naive(
            period in 1usize..20,
            prices in proptest::collection::vec(1.0_f64..1000.0, 0..150),
        ) {
            let mut rsi = Rsi::new(period).unwrap();
            let got = rsi.batch(&prices);
            let want = rsi_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() < 1e-7,
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}