1use anyhow::{anyhow, Result};
2use std::collections::VecDeque;
3
4use super::{ArgCount, FunctionCategory, FunctionSignature, SqlFunction};
5use crate::data::datatable::DataValue;
6
7pub struct ReturnsFunction;
10
11impl SqlFunction for ReturnsFunction {
12 fn signature(&self) -> FunctionSignature {
13 FunctionSignature {
14 name: "RETURNS",
15 category: FunctionCategory::Mathematical,
16 arg_count: ArgCount::Fixed(2),
17 description: "Calculate returns from current and previous price",
18 returns: "FLOAT",
19 examples: vec![
20 "SELECT RETURNS(close, LAG(close) OVER (ORDER BY date)) FROM stocks",
21 "SELECT RETURNS(100, 95)", "SELECT RETURNS(95, 100)", ],
24 }
25 }
26
27 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
28 if args.len() != 2 {
29 return Err(anyhow!(
30 "RETURNS expects exactly 2 arguments: current_price, previous_price"
31 ));
32 }
33
34 let current = match &args[0] {
35 DataValue::Integer(i) => *i as f64,
36 DataValue::Float(f) => *f,
37 DataValue::Null => return Ok(DataValue::Null),
38 _ => return Err(anyhow!("RETURNS expects numeric values")),
39 };
40
41 let previous = match &args[1] {
42 DataValue::Integer(i) => *i as f64,
43 DataValue::Float(f) => *f,
44 DataValue::Null => return Ok(DataValue::Null),
45 _ => return Err(anyhow!("RETURNS expects numeric values")),
46 };
47
48 if previous == 0.0 {
49 return Err(anyhow!("Cannot calculate returns with previous price of 0"));
50 }
51
52 let returns = (current - previous) / previous;
53 Ok(DataValue::Float(returns))
54 }
55}
56
57pub struct LogReturnsFunction;
60
61impl SqlFunction for LogReturnsFunction {
62 fn signature(&self) -> FunctionSignature {
63 FunctionSignature {
64 name: "LOG_RETURNS",
65 category: FunctionCategory::Mathematical,
66 arg_count: ArgCount::Fixed(2),
67 description: "Calculate logarithmic returns from current and previous price",
68 returns: "FLOAT",
69 examples: vec![
70 "SELECT LOG_RETURNS(close, LAG(close) OVER (ORDER BY date)) FROM stocks",
71 "SELECT LOG_RETURNS(100, 95)", ],
73 }
74 }
75
76 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
77 if args.len() != 2 {
78 return Err(anyhow!(
79 "LOG_RETURNS expects exactly 2 arguments: current_price, previous_price"
80 ));
81 }
82
83 let current = match &args[0] {
84 DataValue::Integer(i) => *i as f64,
85 DataValue::Float(f) => *f,
86 DataValue::Null => return Ok(DataValue::Null),
87 _ => return Err(anyhow!("LOG_RETURNS expects numeric values")),
88 };
89
90 let previous = match &args[1] {
91 DataValue::Integer(i) => *i as f64,
92 DataValue::Float(f) => *f,
93 DataValue::Null => return Ok(DataValue::Null),
94 _ => return Err(anyhow!("LOG_RETURNS expects numeric values")),
95 };
96
97 if previous <= 0.0 || current <= 0.0 {
98 return Err(anyhow!(
99 "Cannot calculate log returns with non-positive prices"
100 ));
101 }
102
103 let log_returns = (current / previous).ln();
104 Ok(DataValue::Float(log_returns))
105 }
106}
107
108pub struct VolatilityFunction;
111
112impl SqlFunction for VolatilityFunction {
113 fn signature(&self) -> FunctionSignature {
114 FunctionSignature {
115 name: "VOLATILITY",
116 category: FunctionCategory::Mathematical,
117 arg_count: ArgCount::Variadic,
118 description: "Calculate volatility (standard deviation) of returns",
119 returns: "FLOAT",
120 examples: vec![
121 "SELECT VOLATILITY(0.01, -0.02, 0.015, -0.005, 0.008)",
122 "WITH returns AS (SELECT RETURNS(close, LAG(close) OVER (ORDER BY date)) as r FROM stocks) SELECT VOLATILITY(r) FROM returns",
123 ],
124 }
125 }
126
127 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
128 if args.is_empty() {
129 return Err(anyhow!("VOLATILITY requires at least one value"));
130 }
131
132 let mut values = Vec::new();
133 for arg in args {
134 match arg {
135 DataValue::Integer(i) => values.push(*i as f64),
136 DataValue::Float(f) => values.push(*f),
137 DataValue::Null => continue, _ => return Err(anyhow!("VOLATILITY expects numeric values")),
139 }
140 }
141
142 if values.is_empty() {
143 return Ok(DataValue::Null);
144 }
145
146 if values.len() == 1 {
147 return Ok(DataValue::Float(0.0)); }
149
150 let mean = values.iter().sum::<f64>() / values.len() as f64;
152
153 let variance =
155 values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64; let std_dev = variance.sqrt();
159 Ok(DataValue::Float(std_dev))
160 }
161}
162
163pub struct SharpeRatioFunction;
166
167impl SqlFunction for SharpeRatioFunction {
168 fn signature(&self) -> FunctionSignature {
169 FunctionSignature {
170 name: "SHARPE_RATIO",
171 category: FunctionCategory::Mathematical,
172 arg_count: ArgCount::Fixed(3),
173 description: "Calculate Sharpe ratio: (mean_return - risk_free_rate) / volatility",
174 returns: "FLOAT",
175 examples: vec![
176 "SELECT SHARPE_RATIO(0.08, 0.02, 0.15)", "SELECT SHARPE_RATIO(mean_return, 0.02, volatility) FROM portfolio_stats",
178 ],
179 }
180 }
181
182 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
183 if args.len() != 3 {
184 return Err(anyhow!(
185 "SHARPE_RATIO expects 3 arguments: mean_return, risk_free_rate, volatility"
186 ));
187 }
188
189 let mean_return = match &args[0] {
190 DataValue::Integer(i) => *i as f64,
191 DataValue::Float(f) => *f,
192 DataValue::Null => return Ok(DataValue::Null),
193 _ => return Err(anyhow!("SHARPE_RATIO expects numeric values")),
194 };
195
196 let risk_free_rate = match &args[1] {
197 DataValue::Integer(i) => *i as f64,
198 DataValue::Float(f) => *f,
199 DataValue::Null => 0.0, _ => return Err(anyhow!("SHARPE_RATIO expects numeric values")),
201 };
202
203 let volatility = match &args[2] {
204 DataValue::Integer(i) => *i as f64,
205 DataValue::Float(f) => *f,
206 DataValue::Null => return Ok(DataValue::Null),
207 _ => return Err(anyhow!("SHARPE_RATIO expects numeric values")),
208 };
209
210 if volatility == 0.0 {
211 return Err(anyhow!(
212 "Cannot calculate Sharpe ratio with zero volatility"
213 ));
214 }
215
216 let sharpe = (mean_return - risk_free_rate) / volatility;
217 Ok(DataValue::Float(sharpe))
218 }
219}
220
221pub struct StdDevFunction;
224
225impl SqlFunction for StdDevFunction {
226 fn signature(&self) -> FunctionSignature {
227 FunctionSignature {
228 name: "STDDEV",
229 category: FunctionCategory::Mathematical,
230 arg_count: ArgCount::Variadic,
231 description: "Calculate sample standard deviation",
232 returns: "FLOAT",
233 examples: vec![
234 "SELECT STDDEV(1, 2, 3, 4, 5)", "SELECT STDDEV(returns) OVER (ORDER BY date ROWS 19 PRECEDING) FROM stocks",
236 ],
237 }
238 }
239
240 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
241 VolatilityFunction.evaluate(args)
243 }
244}
245
/// Registers every financial SQL function defined in this file with
/// `registry`.
///
/// Each function is passed to `registry.register` as a boxed value, in this
/// order: `ReturnsFunction`, `LogReturnsFunction`, `VolatilityFunction`,
/// `StdDevFunction`, `SharpeRatioFunction`.
pub fn register_financial_functions(registry: &mut super::FunctionRegistry) {
    registry.register(Box::new(ReturnsFunction));
    registry.register(Box::new(LogReturnsFunction));
    registry.register(Box::new(VolatilityFunction));
    registry.register(Box::new(StdDevFunction));
    registry.register(Box::new(SharpeRatioFunction));
}