wickra_core/indicators/
sharpe_ratio.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
8#[derive(Debug, Clone)]
44pub struct SharpeRatio {
45 period: usize,
46 risk_free: f64,
47 window: VecDeque<f64>,
48 sum: f64,
49 sum_sq: f64,
50}
51
52impl SharpeRatio {
53 pub fn new(period: usize, risk_free: f64) -> Result<Self> {
60 if period < 2 {
61 return Err(Error::InvalidPeriod {
62 message: "sharpe ratio needs period >= 2",
63 });
64 }
65 Ok(Self {
66 period,
67 risk_free,
68 window: VecDeque::with_capacity(period),
69 sum: 0.0,
70 sum_sq: 0.0,
71 })
72 }
73
74 pub const fn period(&self) -> usize {
76 self.period
77 }
78
79 pub const fn risk_free(&self) -> f64 {
81 self.risk_free
82 }
83}
84
85impl Indicator for SharpeRatio {
86 type Input = f64;
87 type Output = f64;
88
89 fn update(&mut self, input: f64) -> Option<f64> {
90 if !input.is_finite() {
91 return None;
92 }
93 if self.window.len() == self.period {
94 let old = self.window.pop_front().expect("non-empty");
95 self.sum -= old;
96 self.sum_sq -= old * old;
97 }
98 self.window.push_back(input);
99 self.sum += input;
100 self.sum_sq += input * input;
101 if self.window.len() < self.period {
102 return None;
103 }
104 let n = self.period as f64;
105 let mean = self.sum / n;
106 let var = (self.sum_sq - n * mean * mean).max(0.0) / (n - 1.0);
108 let sd = var.sqrt();
109 if sd == 0.0 {
110 return Some(0.0);
111 }
112 Some((mean - self.risk_free) / sd)
113 }
114
115 fn reset(&mut self) {
116 self.window.clear();
117 self.sum = 0.0;
118 self.sum_sq = 0.0;
119 }
120
121 fn warmup_period(&self) -> usize {
122 self.period
123 }
124
125 fn is_ready(&self) -> bool {
126 self.window.len() == self.period
127 }
128
129 fn name(&self) -> &'static str {
130 "SharpeRatio"
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn rejects_period_less_than_two() {
        assert!(matches!(
            SharpeRatio::new(1, 0.0),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            SharpeRatio::new(0, 0.0),
            Err(Error::InvalidPeriod { .. })
        ));
    }

    #[test]
    fn accessors_and_metadata() {
        let sr = SharpeRatio::new(20, 0.001).unwrap();
        assert_eq!(sr.period(), 20);
        assert_relative_eq!(sr.risk_free(), 0.001, epsilon = 1e-12);
        assert_eq!(sr.name(), "SharpeRatio");
        assert_eq!(sr.warmup_period(), 20);
    }

    #[test]
    fn constant_returns_yield_zero() {
        let mut sr = SharpeRatio::new(5, 0.0).unwrap();
        let out = sr.batch(&[0.01; 10]);
        // The first `period - 1` slots are warm-up. Every later slot must
        // carry a value, otherwise the zero check below would pass vacuously.
        assert!(out[..4].iter().all(Option::is_none));
        assert_eq!(out[4..].iter().filter(|v| v.is_some()).count(), 6);
        for v in out.into_iter().flatten() {
            assert_relative_eq!(v, 0.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn reference_value() {
        let mut sr = SharpeRatio::new(4, 0.0).unwrap();
        // mean = 0.025; sample variance = Σ(x - mean)² / (n - 1)
        //      = (0.015² + 0.005² + 0.005² + 0.015²) / 3 = 0.0005 / 3.
        let out = sr.batch(&[0.01, 0.02, 0.03, 0.04]);
        let expected = 0.025_f64 / (0.000_166_666_666_666_666_67_f64).sqrt();
        assert_relative_eq!(out[3].unwrap(), expected, epsilon = 1e-9);
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut sr = SharpeRatio::new(3, 0.0).unwrap();
        assert_eq!(sr.update(0.01), None);
        // None of these may enter the window: two more finite returns are
        // still needed before the first output.
        assert_eq!(sr.update(f64::NAN), None);
        assert_eq!(sr.update(f64::INFINITY), None);
        assert_eq!(sr.update(f64::NEG_INFINITY), None);
        assert_eq!(sr.update(0.02), None);
        assert!(sr.update(0.03).is_some());
    }

    #[test]
    fn warmup_returns_none() {
        let mut sr = SharpeRatio::new(5, 0.0).unwrap();
        for i in 0..4 {
            assert_eq!(sr.update(f64::from(i) * 0.01), None);
            assert!(!sr.is_ready());
        }
        assert!(sr.update(0.05).is_some());
        assert!(sr.is_ready());
    }

    #[test]
    fn reset_clears_state() {
        let mut sr = SharpeRatio::new(3, 0.0).unwrap();
        sr.batch(&[0.01, 0.02, 0.03]);
        assert!(sr.is_ready());
        sr.reset();
        assert!(!sr.is_ready());
        assert_eq!(sr.update(0.01), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let returns: Vec<f64> = (0..50)
            .map(|i| 0.001 + (f64::from(i) * 0.2).sin() * 0.01)
            .collect();
        let batch = SharpeRatio::new(10, 0.0).unwrap().batch(&returns);
        let mut s = SharpeRatio::new(10, 0.0).unwrap();
        let streamed: Vec<_> = returns.iter().map(|p| s.update(*p)).collect();
        assert_eq!(batch, streamed);
    }
}