wickra_core/indicators/
r_squared.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
8#[derive(Debug, Clone)]
50pub struct RSquared {
51 period: usize,
52 window: VecDeque<f64>,
53 sum_x: f64,
54 denom: f64,
56 sum_y: f64,
57 sum_xy: f64,
58 sum_y_sq: f64,
59}
60
61impl RSquared {
62 pub fn new(period: usize) -> Result<Self> {
68 if period < 2 {
69 return Err(Error::InvalidPeriod {
70 message: "R² needs period >= 2",
71 });
72 }
73 let n = period as f64;
74 let sum_x = n * (n - 1.0) / 2.0;
75 let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
76 Ok(Self {
77 period,
78 window: VecDeque::with_capacity(period),
79 sum_x,
80 denom: n * sum_xx - sum_x * sum_x,
81 sum_y: 0.0,
82 sum_xy: 0.0,
83 sum_y_sq: 0.0,
84 })
85 }
86
87 pub const fn period(&self) -> usize {
89 self.period
90 }
91}
92
93impl Indicator for RSquared {
94 type Input = f64;
95 type Output = f64;
96
97 fn update(&mut self, value: f64) -> Option<f64> {
98 if !value.is_finite() {
99 return None;
100 }
101 if self.window.len() == self.period {
102 let y0 = self.window.pop_front().expect("non-empty");
103 self.sum_xy = self.sum_xy - self.sum_y + y0;
104 self.sum_y -= y0;
105 self.sum_y_sq -= y0 * y0;
106 }
107 let k = self.window.len() as f64;
108 self.window.push_back(value);
109 self.sum_y += value;
110 self.sum_xy += k * value;
111 self.sum_y_sq += value * value;
112
113 if self.window.len() < self.period {
114 return None;
115 }
116 let n = self.period as f64;
117 let slope = (n * self.sum_xy - self.sum_x * self.sum_y) / self.denom;
118 let mean_y = self.sum_y / n;
119 let ss_total = (self.sum_y_sq - n * mean_y * mean_y).max(0.0);
120 let s_xx = self.denom / n;
121 let ss_explained = slope * slope * s_xx;
122 if ss_total <= 0.0 {
123 return Some(1.0);
125 }
126 Some((ss_explained / ss_total).clamp(0.0, 1.0))
127 }
128
129 fn reset(&mut self) {
130 self.window.clear();
131 self.sum_y = 0.0;
132 self.sum_xy = 0.0;
133 self.sum_y_sq = 0.0;
134 }
135
136 fn warmup_period(&self) -> usize {
137 self.period
138 }
139
140 fn is_ready(&self) -> bool {
141 self.window.len() == self.period
142 }
143
144 fn name(&self) -> &'static str {
145 "RSquared"
146 }
147}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn rejects_period_below_two() {
        assert!(RSquared::new(0).is_err());
        assert!(RSquared::new(1).is_err());
        assert!(RSquared::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let r = RSquared::new(14).unwrap();
        assert_eq!(r.period(), 14);
        assert_eq!(r.warmup_period(), 14);
        assert_eq!(r.name(), "RSquared");
    }

    #[test]
    fn perfect_line_is_one() {
        let prices: Vec<f64> = (0..30).map(|i| 2.0 * f64::from(i) + 5.0).collect();
        let mut r = RSquared::new(10).unwrap();
        for v in r.batch(&prices).into_iter().flatten() {
            assert_relative_eq!(v, 1.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn constant_series_is_one() {
        // A flat window is fitted exactly by a horizontal line.
        let mut r = RSquared::new(5).unwrap();
        for v in r.batch(&[42.0; 20]).into_iter().flatten() {
            assert_relative_eq!(v, 1.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn known_values_on_small_window() {
        // x = 0,1,2 against y = 1,2,4: Sxy = 3, Sxx = 2, Syy = 14/3, so R² = 27/28.
        // Sliding to y = 2,4,0: Sxy = -2, Sxx = 2, Syy = 8, so R² = 1/4.
        let mut r = RSquared::new(3).unwrap();
        assert_eq!(r.update(1.0), None);
        assert_eq!(r.update(2.0), None);
        assert_relative_eq!(r.update(4.0).unwrap(), 27.0 / 28.0, epsilon = 1e-12);
        assert_relative_eq!(r.update(0.0).unwrap(), 0.25, epsilon = 1e-12);
    }

    #[test]
    fn non_finite_input_is_skipped() {
        // Non-finite samples must neither produce output nor enter the window.
        let mut r = RSquared::new(3).unwrap();
        assert_eq!(r.update(f64::NAN), None);
        assert_eq!(r.update(1.0), None);
        assert_eq!(r.update(f64::INFINITY), None);
        assert_eq!(r.update(2.0), None);
        assert_eq!(r.update(f64::NEG_INFINITY), None);
        assert!(!r.is_ready());
        assert_relative_eq!(r.update(4.0).unwrap(), 27.0 / 28.0, epsilon = 1e-12);
    }

    #[test]
    fn output_stays_in_zero_one_range() {
        let prices: Vec<f64> = (0..120)
            .map(|i| 100.0 + (f64::from(i) * 0.4).sin() * 5.0 + (f64::from(i) * 0.07).cos() * 12.0)
            .collect();
        let mut r = RSquared::new(20).unwrap();
        for v in r.batch(&prices).into_iter().flatten() {
            assert!((0.0..=1.0).contains(&v), "R² out of range: {v}");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut r = RSquared::new(5).unwrap();
        r.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(r.is_ready());
        r.reset();
        assert!(!r.is_ready());
        assert_eq!(r.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 50.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let batch = RSquared::new(14).unwrap().batch(&prices);
        let mut b = RSquared::new(14).unwrap();
        let streamed: Vec<_> = prices.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }
}