sql_cli/sql/functions/
financial.rs

1use anyhow::{anyhow, Result};
2
3use super::{ArgCount, FunctionCategory, FunctionSignature, SqlFunction};
4use crate::data::datatable::DataValue;
5
6/// RETURNS function - calculates simple returns from price series
7/// Returns = (price[t] - price[t-1]) / price[t-1]
8pub struct ReturnsFunction;
9
10impl SqlFunction for ReturnsFunction {
11    fn signature(&self) -> FunctionSignature {
12        FunctionSignature {
13            name: "RETURNS",
14            category: FunctionCategory::Mathematical,
15            arg_count: ArgCount::Fixed(2),
16            description: "Calculate returns from current and previous price",
17            returns: "FLOAT",
18            examples: vec![
19                "SELECT RETURNS(close, LAG(close) OVER (ORDER BY date)) FROM stocks",
20                "SELECT RETURNS(100, 95)", // Returns 0.0526 (5.26% gain)
21                "SELECT RETURNS(95, 100)", // Returns -0.05 (5% loss)
22            ],
23        }
24    }
25
26    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
27        if args.len() != 2 {
28            return Err(anyhow!(
29                "RETURNS expects exactly 2 arguments: current_price, previous_price"
30            ));
31        }
32
33        let current = match &args[0] {
34            DataValue::Integer(i) => *i as f64,
35            DataValue::Float(f) => *f,
36            DataValue::Null => return Ok(DataValue::Null),
37            _ => return Err(anyhow!("RETURNS expects numeric values")),
38        };
39
40        let previous = match &args[1] {
41            DataValue::Integer(i) => *i as f64,
42            DataValue::Float(f) => *f,
43            DataValue::Null => return Ok(DataValue::Null),
44            _ => return Err(anyhow!("RETURNS expects numeric values")),
45        };
46
47        if previous == 0.0 {
48            return Err(anyhow!("Cannot calculate returns with previous price of 0"));
49        }
50
51        let returns = (current - previous) / previous;
52        Ok(DataValue::Float(returns))
53    }
54}
55
56/// LOG_RETURNS function - calculates logarithmic returns
57/// Log Returns = ln(price[t] / price[t-1])
58pub struct LogReturnsFunction;
59
60impl SqlFunction for LogReturnsFunction {
61    fn signature(&self) -> FunctionSignature {
62        FunctionSignature {
63            name: "LOG_RETURNS",
64            category: FunctionCategory::Mathematical,
65            arg_count: ArgCount::Fixed(2),
66            description: "Calculate logarithmic returns from current and previous price",
67            returns: "FLOAT",
68            examples: vec![
69                "SELECT LOG_RETURNS(close, LAG(close) OVER (ORDER BY date)) FROM stocks",
70                "SELECT LOG_RETURNS(100, 95)", // Returns 0.0513 (ln(100/95))
71            ],
72        }
73    }
74
75    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
76        if args.len() != 2 {
77            return Err(anyhow!(
78                "LOG_RETURNS expects exactly 2 arguments: current_price, previous_price"
79            ));
80        }
81
82        let current = match &args[0] {
83            DataValue::Integer(i) => *i as f64,
84            DataValue::Float(f) => *f,
85            DataValue::Null => return Ok(DataValue::Null),
86            _ => return Err(anyhow!("LOG_RETURNS expects numeric values")),
87        };
88
89        let previous = match &args[1] {
90            DataValue::Integer(i) => *i as f64,
91            DataValue::Float(f) => *f,
92            DataValue::Null => return Ok(DataValue::Null),
93            _ => return Err(anyhow!("LOG_RETURNS expects numeric values")),
94        };
95
96        if previous <= 0.0 || current <= 0.0 {
97            return Err(anyhow!(
98                "Cannot calculate log returns with non-positive prices"
99            ));
100        }
101
102        let log_returns = (current / previous).ln();
103        Ok(DataValue::Float(log_returns))
104    }
105}
106
107/// VOLATILITY function - calculates standard deviation of returns
108/// This is a simplified version that takes an array of returns
109pub struct VolatilityFunction;
110
111impl SqlFunction for VolatilityFunction {
112    fn signature(&self) -> FunctionSignature {
113        FunctionSignature {
114            name: "VOLATILITY",
115            category: FunctionCategory::Mathematical,
116            arg_count: ArgCount::Variadic,
117            description: "Calculate volatility (standard deviation) of returns",
118            returns: "FLOAT",
119            examples: vec![
120                "SELECT VOLATILITY(0.01, -0.02, 0.015, -0.005, 0.008)",
121                "WITH returns AS (SELECT RETURNS(close, LAG(close) OVER (ORDER BY date)) as r FROM stocks) SELECT VOLATILITY(r) FROM returns",
122            ],
123        }
124    }
125
126    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
127        if args.is_empty() {
128            return Err(anyhow!("VOLATILITY requires at least one value"));
129        }
130
131        let mut values = Vec::new();
132        for arg in args {
133            match arg {
134                DataValue::Integer(i) => values.push(*i as f64),
135                DataValue::Float(f) => values.push(*f),
136                DataValue::Null => continue, // Skip nulls
137                _ => return Err(anyhow!("VOLATILITY expects numeric values")),
138            }
139        }
140
141        if values.is_empty() {
142            return Ok(DataValue::Null);
143        }
144
145        if values.len() == 1 {
146            return Ok(DataValue::Float(0.0)); // No variation with single value
147        }
148
149        // Calculate mean
150        let mean = values.iter().sum::<f64>() / values.len() as f64;
151
152        // Calculate variance
153        let variance =
154            values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64; // Sample variance (n-1)
155
156        // Standard deviation
157        let std_dev = variance.sqrt();
158        Ok(DataValue::Float(std_dev))
159    }
160}
161
162/// SHARPE_RATIO function - calculates Sharpe ratio
163/// Sharpe = (mean_return - risk_free_rate) / volatility
164pub struct SharpeRatioFunction;
165
166impl SqlFunction for SharpeRatioFunction {
167    fn signature(&self) -> FunctionSignature {
168        FunctionSignature {
169            name: "SHARPE_RATIO",
170            category: FunctionCategory::Mathematical,
171            arg_count: ArgCount::Fixed(3),
172            description: "Calculate Sharpe ratio: (mean_return - risk_free_rate) / volatility",
173            returns: "FLOAT",
174            examples: vec![
175                "SELECT SHARPE_RATIO(0.08, 0.02, 0.15)", // 8% return, 2% risk-free, 15% volatility = 0.4
176                "SELECT SHARPE_RATIO(mean_return, 0.02, volatility) FROM portfolio_stats",
177            ],
178        }
179    }
180
181    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
182        if args.len() != 3 {
183            return Err(anyhow!(
184                "SHARPE_RATIO expects 3 arguments: mean_return, risk_free_rate, volatility"
185            ));
186        }
187
188        let mean_return = match &args[0] {
189            DataValue::Integer(i) => *i as f64,
190            DataValue::Float(f) => *f,
191            DataValue::Null => return Ok(DataValue::Null),
192            _ => return Err(anyhow!("SHARPE_RATIO expects numeric values")),
193        };
194
195        let risk_free_rate = match &args[1] {
196            DataValue::Integer(i) => *i as f64,
197            DataValue::Float(f) => *f,
198            DataValue::Null => 0.0, // Default to 0 if null
199            _ => return Err(anyhow!("SHARPE_RATIO expects numeric values")),
200        };
201
202        let volatility = match &args[2] {
203            DataValue::Integer(i) => *i as f64,
204            DataValue::Float(f) => *f,
205            DataValue::Null => return Ok(DataValue::Null),
206            _ => return Err(anyhow!("SHARPE_RATIO expects numeric values")),
207        };
208
209        if volatility == 0.0 {
210            return Err(anyhow!(
211                "Cannot calculate Sharpe ratio with zero volatility"
212            ));
213        }
214
215        let sharpe = (mean_return - risk_free_rate) / volatility;
216        Ok(DataValue::Float(sharpe))
217    }
218}
219
220/// STDDEV function - calculates standard deviation (sample)
221/// This is an alias for VOLATILITY but more SQL-standard
222pub struct StdDevFunction;
223
224impl SqlFunction for StdDevFunction {
225    fn signature(&self) -> FunctionSignature {
226        FunctionSignature {
227            name: "STDDEV",
228            category: FunctionCategory::Mathematical,
229            arg_count: ArgCount::Variadic,
230            description: "Calculate sample standard deviation",
231            returns: "FLOAT",
232            examples: vec![
233                "SELECT STDDEV(1, 2, 3, 4, 5)", // Returns 1.58
234                "SELECT STDDEV(returns) OVER (ORDER BY date ROWS 19 PRECEDING) FROM stocks",
235            ],
236        }
237    }
238
239    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
240        // Reuse volatility implementation
241        VolatilityFunction.evaluate(args)
242    }
243}
244
/// Register all financial functions
///
/// Adds each financial SQL function defined in this module to `registry`, in this order:
/// `RETURNS`, `LOG_RETURNS`, `VOLATILITY`, `STDDEV`, `SHARPE_RATIO`.
// NOTE(review): how `FunctionRegistry::register` handles a name that is already
// registered (e.g. a `STDDEV` from another module) is not visible here — confirm.
pub fn register_financial_functions(registry: &mut super::FunctionRegistry) {
    registry.register(Box::new(ReturnsFunction));
    registry.register(Box::new(LogReturnsFunction));
    registry.register(Box::new(VolatilityFunction));
    // STDDEV shares VOLATILITY's computation (see `StdDevFunction::evaluate`).
    registry.register(Box::new(StdDevFunction));
    registry.register(Box::new(SharpeRatioFunction));
}