1use anyhow::{anyhow, Result};
2
3use super::{ArgCount, FunctionCategory, FunctionSignature, SqlFunction};
4use crate::data::datatable::DataValue;
5
6pub struct ReturnsFunction;
9
10impl SqlFunction for ReturnsFunction {
11 fn signature(&self) -> FunctionSignature {
12 FunctionSignature {
13 name: "RETURNS",
14 category: FunctionCategory::Mathematical,
15 arg_count: ArgCount::Fixed(2),
16 description: "Calculate returns from current and previous price",
17 returns: "FLOAT",
18 examples: vec![
19 "SELECT RETURNS(close, LAG(close) OVER (ORDER BY date)) FROM stocks",
20 "SELECT RETURNS(100, 95)", "SELECT RETURNS(95, 100)", ],
23 }
24 }
25
26 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
27 if args.len() != 2 {
28 return Err(anyhow!(
29 "RETURNS expects exactly 2 arguments: current_price, previous_price"
30 ));
31 }
32
33 let current = match &args[0] {
34 DataValue::Integer(i) => *i as f64,
35 DataValue::Float(f) => *f,
36 DataValue::Null => return Ok(DataValue::Null),
37 _ => return Err(anyhow!("RETURNS expects numeric values")),
38 };
39
40 let previous = match &args[1] {
41 DataValue::Integer(i) => *i as f64,
42 DataValue::Float(f) => *f,
43 DataValue::Null => return Ok(DataValue::Null),
44 _ => return Err(anyhow!("RETURNS expects numeric values")),
45 };
46
47 if previous == 0.0 {
48 return Err(anyhow!("Cannot calculate returns with previous price of 0"));
49 }
50
51 let returns = (current - previous) / previous;
52 Ok(DataValue::Float(returns))
53 }
54}
55
56pub struct LogReturnsFunction;
59
60impl SqlFunction for LogReturnsFunction {
61 fn signature(&self) -> FunctionSignature {
62 FunctionSignature {
63 name: "LOG_RETURNS",
64 category: FunctionCategory::Mathematical,
65 arg_count: ArgCount::Fixed(2),
66 description: "Calculate logarithmic returns from current and previous price",
67 returns: "FLOAT",
68 examples: vec![
69 "SELECT LOG_RETURNS(close, LAG(close) OVER (ORDER BY date)) FROM stocks",
70 "SELECT LOG_RETURNS(100, 95)", ],
72 }
73 }
74
75 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
76 if args.len() != 2 {
77 return Err(anyhow!(
78 "LOG_RETURNS expects exactly 2 arguments: current_price, previous_price"
79 ));
80 }
81
82 let current = match &args[0] {
83 DataValue::Integer(i) => *i as f64,
84 DataValue::Float(f) => *f,
85 DataValue::Null => return Ok(DataValue::Null),
86 _ => return Err(anyhow!("LOG_RETURNS expects numeric values")),
87 };
88
89 let previous = match &args[1] {
90 DataValue::Integer(i) => *i as f64,
91 DataValue::Float(f) => *f,
92 DataValue::Null => return Ok(DataValue::Null),
93 _ => return Err(anyhow!("LOG_RETURNS expects numeric values")),
94 };
95
96 if previous <= 0.0 || current <= 0.0 {
97 return Err(anyhow!(
98 "Cannot calculate log returns with non-positive prices"
99 ));
100 }
101
102 let log_returns = (current / previous).ln();
103 Ok(DataValue::Float(log_returns))
104 }
105}
106
107pub struct VolatilityFunction;
110
111impl SqlFunction for VolatilityFunction {
112 fn signature(&self) -> FunctionSignature {
113 FunctionSignature {
114 name: "VOLATILITY",
115 category: FunctionCategory::Mathematical,
116 arg_count: ArgCount::Variadic,
117 description: "Calculate volatility (standard deviation) of returns",
118 returns: "FLOAT",
119 examples: vec![
120 "SELECT VOLATILITY(0.01, -0.02, 0.015, -0.005, 0.008)",
121 "WITH returns AS (SELECT RETURNS(close, LAG(close) OVER (ORDER BY date)) as r FROM stocks) SELECT VOLATILITY(r) FROM returns",
122 ],
123 }
124 }
125
126 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
127 if args.is_empty() {
128 return Err(anyhow!("VOLATILITY requires at least one value"));
129 }
130
131 let mut values = Vec::new();
132 for arg in args {
133 match arg {
134 DataValue::Integer(i) => values.push(*i as f64),
135 DataValue::Float(f) => values.push(*f),
136 DataValue::Null => continue, _ => return Err(anyhow!("VOLATILITY expects numeric values")),
138 }
139 }
140
141 if values.is_empty() {
142 return Ok(DataValue::Null);
143 }
144
145 if values.len() == 1 {
146 return Ok(DataValue::Float(0.0)); }
148
149 let mean = values.iter().sum::<f64>() / values.len() as f64;
151
152 let variance =
154 values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64; let std_dev = variance.sqrt();
158 Ok(DataValue::Float(std_dev))
159 }
160}
161
162pub struct SharpeRatioFunction;
165
166impl SqlFunction for SharpeRatioFunction {
167 fn signature(&self) -> FunctionSignature {
168 FunctionSignature {
169 name: "SHARPE_RATIO",
170 category: FunctionCategory::Mathematical,
171 arg_count: ArgCount::Fixed(3),
172 description: "Calculate Sharpe ratio: (mean_return - risk_free_rate) / volatility",
173 returns: "FLOAT",
174 examples: vec![
175 "SELECT SHARPE_RATIO(0.08, 0.02, 0.15)", "SELECT SHARPE_RATIO(mean_return, 0.02, volatility) FROM portfolio_stats",
177 ],
178 }
179 }
180
181 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
182 if args.len() != 3 {
183 return Err(anyhow!(
184 "SHARPE_RATIO expects 3 arguments: mean_return, risk_free_rate, volatility"
185 ));
186 }
187
188 let mean_return = match &args[0] {
189 DataValue::Integer(i) => *i as f64,
190 DataValue::Float(f) => *f,
191 DataValue::Null => return Ok(DataValue::Null),
192 _ => return Err(anyhow!("SHARPE_RATIO expects numeric values")),
193 };
194
195 let risk_free_rate = match &args[1] {
196 DataValue::Integer(i) => *i as f64,
197 DataValue::Float(f) => *f,
198 DataValue::Null => 0.0, _ => return Err(anyhow!("SHARPE_RATIO expects numeric values")),
200 };
201
202 let volatility = match &args[2] {
203 DataValue::Integer(i) => *i as f64,
204 DataValue::Float(f) => *f,
205 DataValue::Null => return Ok(DataValue::Null),
206 _ => return Err(anyhow!("SHARPE_RATIO expects numeric values")),
207 };
208
209 if volatility == 0.0 {
210 return Err(anyhow!(
211 "Cannot calculate Sharpe ratio with zero volatility"
212 ));
213 }
214
215 let sharpe = (mean_return - risk_free_rate) / volatility;
216 Ok(DataValue::Float(sharpe))
217 }
218}
219
220pub struct StdDevFunction;
223
224impl SqlFunction for StdDevFunction {
225 fn signature(&self) -> FunctionSignature {
226 FunctionSignature {
227 name: "STDDEV",
228 category: FunctionCategory::Mathematical,
229 arg_count: ArgCount::Variadic,
230 description: "Calculate sample standard deviation",
231 returns: "FLOAT",
232 examples: vec![
233 "SELECT STDDEV(1, 2, 3, 4, 5)", "SELECT STDDEV(returns) OVER (ORDER BY date ROWS 19 PRECEDING) FROM stocks",
235 ],
236 }
237 }
238
239 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
240 VolatilityFunction.evaluate(args)
242 }
243}
244
245pub fn register_financial_functions(registry: &mut super::FunctionRegistry) {
247 registry.register(Box::new(ReturnsFunction));
248 registry.register(Box::new(LogReturnsFunction));
249 registry.register(Box::new(VolatilityFunction));
250 registry.register(Box::new(StdDevFunction));
251 registry.register(Box::new(SharpeRatioFunction));
252}