1use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
6#[derive(Debug, Clone)]
26pub struct Rsi {
27 period: usize,
28 prev_close: Option<f64>,
29 seed_buf_gains: Vec<f64>,
32 seed_buf_losses: Vec<f64>,
33 avg_gain: Option<f64>,
34 avg_loss: Option<f64>,
35 last_value: Option<f64>,
36}
37
38impl Rsi {
39 pub fn new(period: usize) -> Result<Self> {
45 if period == 0 {
46 return Err(Error::PeriodZero);
47 }
48 Ok(Self {
49 period,
50 prev_close: None,
51 seed_buf_gains: Vec::with_capacity(period),
52 seed_buf_losses: Vec::with_capacity(period),
53 avg_gain: None,
54 avg_loss: None,
55 last_value: None,
56 })
57 }
58
59 pub const fn period(&self) -> usize {
61 self.period
62 }
63
64 pub const fn value(&self) -> Option<f64> {
66 self.last_value
67 }
68
69 fn rsi_from_avgs(avg_gain: f64, avg_loss: f64) -> f64 {
70 if avg_loss == 0.0 {
71 if avg_gain == 0.0 {
72 50.0
74 } else {
75 100.0
76 }
77 } else {
78 let rs = avg_gain / avg_loss;
79 100.0 - 100.0 / (1.0 + rs)
80 }
81 }
82}
83
84impl Indicator for Rsi {
85 type Input = f64;
86 type Output = f64;
87
88 fn update(&mut self, input: f64) -> Option<f64> {
89 if !input.is_finite() {
90 return self.last_value;
91 }
92
93 let Some(prev) = self.prev_close else {
94 self.prev_close = Some(input);
95 return None;
96 };
97 self.prev_close = Some(input);
98
99 let diff = input - prev;
100 let gain = if diff > 0.0 { diff } else { 0.0 };
101 let loss = if diff < 0.0 { -diff } else { 0.0 };
102
103 if let (Some(ag), Some(al)) = (self.avg_gain, self.avg_loss) {
104 let n = self.period as f64;
105 let new_ag = (ag * (n - 1.0) + gain) / n;
106 let new_al = (al * (n - 1.0) + loss) / n;
107 self.avg_gain = Some(new_ag);
108 self.avg_loss = Some(new_al);
109 let v = Self::rsi_from_avgs(new_ag, new_al);
110 self.last_value = Some(v);
111 return Some(v);
112 }
113
114 self.seed_buf_gains.push(gain);
115 self.seed_buf_losses.push(loss);
116 if self.seed_buf_gains.len() == self.period {
117 let ag = self.seed_buf_gains.iter().sum::<f64>() / self.period as f64;
118 let al = self.seed_buf_losses.iter().sum::<f64>() / self.period as f64;
119 self.avg_gain = Some(ag);
120 self.avg_loss = Some(al);
121 let v = Self::rsi_from_avgs(ag, al);
122 self.last_value = Some(v);
123 return Some(v);
124 }
125 None
126 }
127
128 fn reset(&mut self) {
129 self.prev_close = None;
130 self.seed_buf_gains.clear();
131 self.seed_buf_losses.clear();
132 self.avg_gain = None;
133 self.avg_loss = None;
134 self.last_value = None;
135 }
136
137 fn warmup_period(&self) -> usize {
138 self.period + 1
139 }
140
141 fn is_ready(&self) -> bool {
142 self.last_value.is_some()
143 }
144
145 fn name(&self) -> &'static str {
146 "RSI"
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Reference RSI used to cross-check the streaming implementation.
    ///
    /// Buffers the first `period` gains/losses, seeds the averages with their
    /// simple mean, then applies Wilder smoothing. Returns one entry per input
    /// price, `None` while warming up.
    fn rsi_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        let mut out = vec![None; prices.len()];
        let mut gains: Vec<f64> = Vec::new();
        let mut losses: Vec<f64> = Vec::new();
        let mut avg_gain: Option<f64> = None;
        let mut avg_loss: Option<f64> = None;
        let rsi_val = |ag: f64, al: f64| -> f64 {
            if al == 0.0 {
                if ag == 0.0 {
                    50.0
                } else {
                    100.0
                }
            } else {
                100.0 - 100.0 / (1.0 + ag / al)
            }
        };
        for i in 1..prices.len() {
            let diff = prices[i] - prices[i - 1];
            let gain = if diff > 0.0 { diff } else { 0.0 };
            let loss = if diff < 0.0 { -diff } else { 0.0 };
            if let (Some(ag), Some(al)) = (avg_gain, avg_loss) {
                let nag = (ag * (n - 1.0) + gain) / n;
                let nal = (al * (n - 1.0) + loss) / n;
                avg_gain = Some(nag);
                avg_loss = Some(nal);
                out[i] = Some(rsi_val(nag, nal));
            } else {
                gains.push(gain);
                losses.push(loss);
                if gains.len() == period {
                    let ag = gains.iter().sum::<f64>() / n;
                    let al = losses.iter().sum::<f64>() / n;
                    avg_gain = Some(ag);
                    avg_loss = Some(al);
                    out[i] = Some(rsi_val(ag, al));
                }
            }
        }
        out
    }

    /// Asserts that `out` is `None` for exactly the first `period` entries and
    /// `Some` for every entry after that, then returns the emitted values.
    ///
    /// Value-range tests go through this helper so they cannot pass vacuously
    /// when the indicator emits nothing.
    fn emitted(out: &[Option<f64>], period: usize) -> Vec<f64> {
        assert!(
            out.len() > period,
            "series of {} inputs is too short for period {period}",
            out.len()
        );
        assert!(
            out[..period].iter().all(Option::is_none),
            "value emitted during warmup"
        );
        out[period..]
            .iter()
            .map(|x| x.expect("RSI must be emitted after warmup"))
            .collect()
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Rsi::new(0), Err(Error::PeriodZero)));
    }

    /// `period()`, `name()` and `value()` report construction parameters and state.
    #[test]
    fn accessors_and_metadata() {
        let mut rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.period(), 14);
        assert_eq!(rsi.name(), "RSI");
        assert_eq!(rsi.value(), None);
        for i in 1..=15 {
            rsi.update(100.0 + f64::from(i));
        }
        assert!(rsi.value().is_some());
    }

    /// Sanity check of the reference helper itself: a flat series gives 50.
    #[test]
    fn naive_helper_flat_series_yields_50() {
        let ks = rsi_naive(&[42.0; 20], 5);
        for r in ks.into_iter().skip(5) {
            assert_eq!(r.expect("ready after period+1 inputs"), 50.0);
        }
    }

    /// Sanity check of the reference helper itself: a strictly rising series gives 100.
    #[test]
    fn naive_helper_monotone_up_yields_100() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let ks = rsi_naive(&prices, 5);
        for r in ks.into_iter().skip(5) {
            assert_eq!(r.expect("ready after period+1 inputs"), 100.0);
        }
    }

    #[test]
    fn warmup_period_is_period_plus_one() {
        let rsi = Rsi::new(14).unwrap();
        assert_eq!(rsi.warmup_period(), 15);
    }

    #[test]
    fn first_emission_at_index_period() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for x in &out[..14] {
            assert!(x.is_none());
        }
        assert!(out[14].is_some());
    }

    #[test]
    fn pure_uptrend_yields_rsi_100() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for v in emitted(&out, 14) {
            assert_relative_eq!(v, 100.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn pure_downtrend_yields_rsi_0() {
        let prices: Vec<f64> = (1..=20).rev().map(f64::from).collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for v in emitted(&out, 14) {
            assert_relative_eq!(v, 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn flat_series_yields_rsi_50() {
        let prices = [10.0_f64; 30];
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for v in emitted(&out, 14) {
            assert_relative_eq!(v, 50.0, epsilon = 1e-12);
        }
    }

    /// The classic 15-price Wilder example: first value is 100 - 100 / (1 + 0.23857 / 0.1).
    #[test]
    fn classic_wilder_textbook_values() {
        let prices = [
            44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, 46.08, 45.89, 46.03,
            45.61, 46.28, 46.28,
        ];
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        let first = out[14].expect("first RSI emitted at index period");
        assert_relative_eq!(first, 70.464, epsilon = 0.05);
    }

    #[test]
    fn rsi_stays_in_0_100_range() {
        let prices: Vec<f64> = (0..200)
            .map(|i| 100.0 + (f64::from(i) * 0.7).sin() * 10.0)
            .collect();
        let mut rsi = Rsi::new(14).unwrap();
        let out = rsi.batch(&prices);
        for x in emitted(&out, 14) {
            assert!((0.0..=100.0).contains(&x), "RSI out of range: {x}");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut rsi = Rsi::new(5).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 2.0, 4.0, 5.0, 6.0]);
        assert!(rsi.is_ready());
        rsi.reset();
        assert!(!rsi.is_ready());
        assert_eq!(rsi.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=40)
            .map(|i| (f64::from(i) * 0.3).sin() * 5.0 + f64::from(i))
            .collect();
        let mut a = Rsi::new(7).unwrap();
        let mut b = Rsi::new(7).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut rsi = Rsi::new(3).unwrap();
        rsi.batch(&[1.0, 2.0, 3.0, 4.0]);
        let before = rsi.value();
        assert!(before.is_some());
        assert_eq!(rsi.update(f64::NAN), before);
        assert_eq!(rsi.update(f64::INFINITY), before);
        assert_eq!(rsi.value(), before);
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        /// Streaming output matches the buffered reference, including warmup length.
        #[test]
        fn rsi_matches_naive(
            period in 1usize..20,
            prices in proptest::collection::vec(1.0_f64..1000.0, 0..150),
        ) {
            let mut rsi = Rsi::new(period).unwrap();
            let got = rsi.batch(&prices);
            let want = rsi_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() < 1e-7,
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}