Skip to main content

wickra_core/indicators/
r_squared.rs

1//! Coefficient of determination R² for the rolling OLS fit.
2
3use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Rolling R² — the coefficient of determination of a least-squares trend line.
///
/// The trailing `period` samples are placed at `x = 0, 1, …, period − 1` and an
/// ordinary-least-squares line `y = a + b·x` is fitted through them. The output
/// is the fraction of the window's total variance that this line accounts for:
///
/// ```text
/// denom        = n·Σxx − (Σx)²
/// slope        = (n·Σxy − Σx·Σy) / denom
/// SS_total     = Σy² − n·ȳ²
/// SS_explained = slope² · ( denom / n )
/// R²           = SS_explained / SS_total                  if SS_total > 0
///              = 1                                        otherwise (flat window)
/// ```
///
/// Reading the value:
///
/// * `1.0` — every sample in the window lies on one straight line.
/// * `0.0` — the fitted slope explains none of the variance.
/// * in between — how cleanly the recent data trends, regardless of whether
///   the slope points up or down and how steep it is.
///
/// This makes it a natural trend-quality gate: a trend-following rule might
/// demand `R² > 0.7`, while a mean-reversion rule might look for `R² < 0.3`.
///
/// When every sample in the window is equal, `SS_total` is `0`. The fitted line
/// is then flat as well and fits perfectly, so `1.0` is returned instead of
/// dividing by zero.
///
/// `update` runs in O(1): it maintains the same rolling sums as
/// [`crate::LinearRegression`] plus a running `Σy²`. The result is clamped to
/// `[0, 1]` so that small floating-point cancellation cannot push it outside
/// that range.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, RSquared};
///
/// let mut r2 = RSquared::new(14).unwrap();
/// let mut latest = None;
/// for step in 0..40 {
///     latest = r2.update(f64::from(step));
/// }
/// // A straight ramp is a perfect linear fit.
/// let value = latest.unwrap();
/// assert!((value - 1.0).abs() < 1e-9);
/// ```
#[derive(Debug, Clone)]
pub struct RSquared {
    /// Number of samples in the regression window.
    period: usize,
    /// The most recent samples, oldest at the front (x = 0).
    window: VecDeque<f64>,
    /// `Σx` over `x = 0, 1, …, period − 1`; fixed once `period` is known.
    sum_x: f64,
    /// `n·Σxx − (Σx)²` — OLS denominator, constant in `period`.
    denom: f64,
    /// Running `Σy` over the window.
    sum_y: f64,
    /// Running `Σx·y`, where `x` is each sample's position in the window.
    sum_xy: f64,
    /// Running `Σy²` over the window.
    sum_y_sq: f64,
}
60
61impl RSquared {
62    /// Construct a new rolling R² over `period` inputs.
63    ///
64    /// # Errors
65    /// Returns [`Error::InvalidPeriod`] if `period < 2` — a regression line
66    /// is undefined for fewer than two points.
67    pub fn new(period: usize) -> Result<Self> {
68        if period < 2 {
69            return Err(Error::InvalidPeriod {
70                message: "R² needs period >= 2",
71            });
72        }
73        let n = period as f64;
74        let sum_x = n * (n - 1.0) / 2.0;
75        let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
76        Ok(Self {
77            period,
78            window: VecDeque::with_capacity(period),
79            sum_x,
80            denom: n * sum_xx - sum_x * sum_x,
81            sum_y: 0.0,
82            sum_xy: 0.0,
83            sum_y_sq: 0.0,
84        })
85    }
86
87    /// Configured period.
88    pub const fn period(&self) -> usize {
89        self.period
90    }
91}
92
93impl Indicator for RSquared {
94    type Input = f64;
95    type Output = f64;
96
97    fn update(&mut self, value: f64) -> Option<f64> {
98        if !value.is_finite() {
99            return None;
100        }
101        if self.window.len() == self.period {
102            let y0 = self.window.pop_front().expect("non-empty");
103            self.sum_xy = self.sum_xy - self.sum_y + y0;
104            self.sum_y -= y0;
105            self.sum_y_sq -= y0 * y0;
106        }
107        let k = self.window.len() as f64;
108        self.window.push_back(value);
109        self.sum_y += value;
110        self.sum_xy += k * value;
111        self.sum_y_sq += value * value;
112
113        if self.window.len() < self.period {
114            return None;
115        }
116        let n = self.period as f64;
117        let slope = (n * self.sum_xy - self.sum_x * self.sum_y) / self.denom;
118        let mean_y = self.sum_y / n;
119        let ss_total = (self.sum_y_sq - n * mean_y * mean_y).max(0.0);
120        let s_xx = self.denom / n;
121        let ss_explained = slope * slope * s_xx;
122        if ss_total <= 0.0 {
123            // Flat window: the fit is trivially perfect.
124            return Some(1.0);
125        }
126        Some((ss_explained / ss_total).clamp(0.0, 1.0))
127    }
128
129    fn reset(&mut self) {
130        self.window.clear();
131        self.sum_y = 0.0;
132        self.sum_xy = 0.0;
133        self.sum_y_sq = 0.0;
134    }
135
136    fn warmup_period(&self) -> usize {
137        self.period
138    }
139
140    fn is_ready(&self) -> bool {
141        self.window.len() == self.period
142    }
143
144    fn name(&self) -> &'static str {
145        "RSquared"
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn rejects_period_below_two() {
        assert!(RSquared::new(0).is_err());
        assert!(RSquared::new(1).is_err());
        assert!(RSquared::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let r = RSquared::new(14).unwrap();
        assert_eq!(r.period(), 14);
        assert_eq!(r.warmup_period(), 14);
        assert_eq!(r.name(), "RSquared");
    }

    #[test]
    fn perfect_line_is_one() {
        let prices: Vec<f64> = (0..30).map(|i| 2.0 * f64::from(i) + 5.0).collect();
        let mut r = RSquared::new(10).unwrap();
        for v in r.batch(&prices).into_iter().flatten() {
            assert_relative_eq!(v, 1.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn descending_line_is_one() {
        // R² ignores the sign of the slope: a falling line fits just as well.
        let prices: Vec<f64> = (0..30).map(|i| 100.0 - 3.0 * f64::from(i)).collect();
        let mut r = RSquared::new(10).unwrap();
        for v in r.batch(&prices).into_iter().flatten() {
            assert_relative_eq!(v, 1.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn matches_hand_computed_value_after_eviction() {
        // Window [10, 1, 3]: slope −3.5, SS_explained 24.5, SS_total 134/3,
        // so R² = 73.5 / 134.
        // Window [1, 3, 2] (after 10 is evicted): slope 0.5, SS_explained 0.5,
        // SS_total 2, so R² = 0.25. This exercises the Σxy shift on eviction.
        let mut r = RSquared::new(3).unwrap();
        assert_eq!(r.update(10.0), None);
        assert_eq!(r.update(1.0), None);
        assert_relative_eq!(r.update(3.0).unwrap(), 73.5 / 134.0, epsilon = 1e-12);
        assert_relative_eq!(r.update(2.0).unwrap(), 0.25, epsilon = 1e-12);
    }

    #[test]
    fn constant_series_is_one() {
        // SS_total is zero; the indicator must return 1 instead of NaN.
        let mut r = RSquared::new(5).unwrap();
        for v in r.batch(&[42.0; 20]).into_iter().flatten() {
            assert_relative_eq!(v, 1.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn non_finite_input_is_ignored() {
        let mut with_gaps = RSquared::new(3).unwrap();
        let mut clean = RSquared::new(3).unwrap();
        with_gaps.batch(&[1.0, 3.0, 2.0]);
        clean.batch(&[1.0, 3.0, 2.0]);
        assert_eq!(with_gaps.update(f64::NAN), None);
        assert_eq!(with_gaps.update(f64::INFINITY), None);
        assert!(with_gaps.is_ready());
        // Window [3, 2, 2]: slope −0.5, SS_explained 0.5, SS_total 2/3.
        let a = with_gaps.update(2.0).unwrap();
        let b = clean.update(2.0).unwrap();
        assert_eq!(a, b);
        assert_relative_eq!(a, 0.75, epsilon = 1e-12);
    }

    #[test]
    fn output_stays_in_zero_one_range() {
        let prices: Vec<f64> = (0..120)
            .map(|i| 100.0 + (f64::from(i) * 0.4).sin() * 5.0 + (f64::from(i) * 0.07).cos() * 12.0)
            .collect();
        let mut r = RSquared::new(20).unwrap();
        for v in r.batch(&prices).into_iter().flatten() {
            assert!((0.0..=1.0).contains(&v), "R² out of range: {v}");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut r = RSquared::new(5).unwrap();
        r.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(r.is_ready());
        r.reset();
        assert!(!r.is_ready());
        assert_eq!(r.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 50.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let batch = RSquared::new(14).unwrap().batch(&prices);
        let mut b = RSquared::new(14).unwrap();
        let streamed: Vec<_> = prices.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }
}