ta_lib_in_rust/indicators/oscillators/
rsi.rs

1use polars::prelude::*;
2
3/// Calculates Relative Strength Index (RSI)
4///
5/// # Arguments
6///
7/// * `df` - DataFrame containing price data
8/// * `window` - RSI calculation period (typically 14)
9/// * `column` - Column name to use for calculations (default "close")
10///
11/// # Returns
12///
13/// * `PolarsResult<Series>` - RSI values as a Series
14pub fn calculate_rsi(df: &DataFrame, window: usize, column: &str) -> PolarsResult<Series> {
15    // Check we have enough data
16    if df.height() < window + 1 {
17        return Err(PolarsError::ComputeError(
18            format!(
19                "Not enough data points for RSI calculation with window size {}",
20                window
21            )
22            .into(),
23        ));
24    }
25
26    // Get price data
27    let close = df.column(column)?.f64()?.clone().into_series();
28
29    // Calculate price changes
30    let prev_close = close.shift(1);
31    let price_diff: Vec<f64> = close
32        .f64()?
33        .iter()
34        .zip(prev_close.f64()?.iter())
35        .map(|(curr, prev)| match (curr, prev) {
36            (Some(c), Some(p)) => c - p,
37            _ => f64::NAN,
38        })
39        .collect();
40
41    // Separate gains and losses
42    let mut gains: Vec<f64> = Vec::with_capacity(df.height());
43    let mut losses: Vec<f64> = Vec::with_capacity(df.height());
44
45    // First value is NaN (no previous value to compare)
46    gains.push(0.0);
47    losses.push(0.0);
48
49    for &diff in &price_diff[1..] {
50        if diff.is_nan() {
51            gains.push(f64::NAN);
52            losses.push(f64::NAN);
53        } else if diff > 0.0 {
54            gains.push(diff);
55            losses.push(0.0);
56        } else {
57            gains.push(0.0);
58            losses.push(diff.abs());
59        }
60    }
61
62    // Calculate RSI using Wilder's smoothing method
63    let mut avg_gain = 0.0;
64    let mut avg_loss = 0.0;
65    let mut rsi: Vec<f64> = Vec::with_capacity(df.height());
66
67    // Fill initial values with NaN
68    for _i in 0..window {
69        rsi.push(f64::NAN);
70    }
71
72    // First average gain/loss is a simple average
73    for i in 1..=window {
74        avg_gain += gains[i];
75        avg_loss += losses[i];
76    }
77    avg_gain /= window as f64;
78    avg_loss /= window as f64;
79
80    // First RSI value
81    let rs = if avg_loss == 0.0 {
82        100.0 // Prevent division by zero
83    } else {
84        avg_gain / avg_loss
85    };
86    let rsi_val = 100.0 - (100.0 / (1.0 + rs));
87    rsi[window - 1] = rsi_val;
88
89    // Calculate smoothed RSI for the rest of the series
90    for i in window + 1..df.height() {
91        // Update using Wilder's smoothing
92        avg_gain = ((avg_gain * (window - 1) as f64) + gains[i]) / window as f64;
93        avg_loss = ((avg_loss * (window - 1) as f64) + losses[i]) / window as f64;
94
95        // Calculate RSI
96        let rs = if avg_loss == 0.0 {
97            100.0 // Prevent division by zero
98        } else {
99            avg_gain / avg_loss
100        };
101        let rsi_val = 100.0 - (100.0 / (1.0 + rs));
102        rsi.push(rsi_val);
103    }
104
105    Ok(Series::new(format!("rsi_{}", window).into(), rsi))
106}