sql_cli/sql/functions/financial.rs

1use anyhow::{anyhow, Result};
2use std::collections::VecDeque;
3
4use super::{ArgCount, FunctionCategory, FunctionSignature, SqlFunction};
5use crate::data::datatable::DataValue;
6
7/// RETURNS function - calculates simple returns from price series
8/// Returns = (price[t] - price[t-1]) / price[t-1]
9pub struct ReturnsFunction;
10
11impl SqlFunction for ReturnsFunction {
12    fn signature(&self) -> FunctionSignature {
13        FunctionSignature {
14            name: "RETURNS",
15            category: FunctionCategory::Mathematical,
16            arg_count: ArgCount::Fixed(2),
17            description: "Calculate returns from current and previous price",
18            returns: "FLOAT",
19            examples: vec![
20                "SELECT RETURNS(close, LAG(close) OVER (ORDER BY date)) FROM stocks",
21                "SELECT RETURNS(100, 95)", // Returns 0.0526 (5.26% gain)
22                "SELECT RETURNS(95, 100)", // Returns -0.05 (5% loss)
23            ],
24        }
25    }
26
27    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
28        if args.len() != 2 {
29            return Err(anyhow!(
30                "RETURNS expects exactly 2 arguments: current_price, previous_price"
31            ));
32        }
33
34        let current = match &args[0] {
35            DataValue::Integer(i) => *i as f64,
36            DataValue::Float(f) => *f,
37            DataValue::Null => return Ok(DataValue::Null),
38            _ => return Err(anyhow!("RETURNS expects numeric values")),
39        };
40
41        let previous = match &args[1] {
42            DataValue::Integer(i) => *i as f64,
43            DataValue::Float(f) => *f,
44            DataValue::Null => return Ok(DataValue::Null),
45            _ => return Err(anyhow!("RETURNS expects numeric values")),
46        };
47
48        if previous == 0.0 {
49            return Err(anyhow!("Cannot calculate returns with previous price of 0"));
50        }
51
52        let returns = (current - previous) / previous;
53        Ok(DataValue::Float(returns))
54    }
55}
56
57/// LOG_RETURNS function - calculates logarithmic returns
58/// Log Returns = ln(price[t] / price[t-1])
59pub struct LogReturnsFunction;
60
61impl SqlFunction for LogReturnsFunction {
62    fn signature(&self) -> FunctionSignature {
63        FunctionSignature {
64            name: "LOG_RETURNS",
65            category: FunctionCategory::Mathematical,
66            arg_count: ArgCount::Fixed(2),
67            description: "Calculate logarithmic returns from current and previous price",
68            returns: "FLOAT",
69            examples: vec![
70                "SELECT LOG_RETURNS(close, LAG(close) OVER (ORDER BY date)) FROM stocks",
71                "SELECT LOG_RETURNS(100, 95)", // Returns 0.0513 (ln(100/95))
72            ],
73        }
74    }
75
76    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
77        if args.len() != 2 {
78            return Err(anyhow!(
79                "LOG_RETURNS expects exactly 2 arguments: current_price, previous_price"
80            ));
81        }
82
83        let current = match &args[0] {
84            DataValue::Integer(i) => *i as f64,
85            DataValue::Float(f) => *f,
86            DataValue::Null => return Ok(DataValue::Null),
87            _ => return Err(anyhow!("LOG_RETURNS expects numeric values")),
88        };
89
90        let previous = match &args[1] {
91            DataValue::Integer(i) => *i as f64,
92            DataValue::Float(f) => *f,
93            DataValue::Null => return Ok(DataValue::Null),
94            _ => return Err(anyhow!("LOG_RETURNS expects numeric values")),
95        };
96
97        if previous <= 0.0 || current <= 0.0 {
98            return Err(anyhow!(
99                "Cannot calculate log returns with non-positive prices"
100            ));
101        }
102
103        let log_returns = (current / previous).ln();
104        Ok(DataValue::Float(log_returns))
105    }
106}
107
108/// VOLATILITY function - calculates standard deviation of returns
109/// This is a simplified version that takes an array of returns
110pub struct VolatilityFunction;
111
112impl SqlFunction for VolatilityFunction {
113    fn signature(&self) -> FunctionSignature {
114        FunctionSignature {
115            name: "VOLATILITY",
116            category: FunctionCategory::Mathematical,
117            arg_count: ArgCount::Variadic,
118            description: "Calculate volatility (standard deviation) of returns",
119            returns: "FLOAT",
120            examples: vec![
121                "SELECT VOLATILITY(0.01, -0.02, 0.015, -0.005, 0.008)",
122                "WITH returns AS (SELECT RETURNS(close, LAG(close) OVER (ORDER BY date)) as r FROM stocks) SELECT VOLATILITY(r) FROM returns",
123            ],
124        }
125    }
126
127    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
128        if args.is_empty() {
129            return Err(anyhow!("VOLATILITY requires at least one value"));
130        }
131
132        let mut values = Vec::new();
133        for arg in args {
134            match arg {
135                DataValue::Integer(i) => values.push(*i as f64),
136                DataValue::Float(f) => values.push(*f),
137                DataValue::Null => continue, // Skip nulls
138                _ => return Err(anyhow!("VOLATILITY expects numeric values")),
139            }
140        }
141
142        if values.is_empty() {
143            return Ok(DataValue::Null);
144        }
145
146        if values.len() == 1 {
147            return Ok(DataValue::Float(0.0)); // No variation with single value
148        }
149
150        // Calculate mean
151        let mean = values.iter().sum::<f64>() / values.len() as f64;
152
153        // Calculate variance
154        let variance =
155            values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64; // Sample variance (n-1)
156
157        // Standard deviation
158        let std_dev = variance.sqrt();
159        Ok(DataValue::Float(std_dev))
160    }
161}
162
163/// SHARPE_RATIO function - calculates Sharpe ratio
164/// Sharpe = (mean_return - risk_free_rate) / volatility
165pub struct SharpeRatioFunction;
166
167impl SqlFunction for SharpeRatioFunction {
168    fn signature(&self) -> FunctionSignature {
169        FunctionSignature {
170            name: "SHARPE_RATIO",
171            category: FunctionCategory::Mathematical,
172            arg_count: ArgCount::Fixed(3),
173            description: "Calculate Sharpe ratio: (mean_return - risk_free_rate) / volatility",
174            returns: "FLOAT",
175            examples: vec![
176                "SELECT SHARPE_RATIO(0.08, 0.02, 0.15)", // 8% return, 2% risk-free, 15% volatility = 0.4
177                "SELECT SHARPE_RATIO(mean_return, 0.02, volatility) FROM portfolio_stats",
178            ],
179        }
180    }
181
182    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
183        if args.len() != 3 {
184            return Err(anyhow!(
185                "SHARPE_RATIO expects 3 arguments: mean_return, risk_free_rate, volatility"
186            ));
187        }
188
189        let mean_return = match &args[0] {
190            DataValue::Integer(i) => *i as f64,
191            DataValue::Float(f) => *f,
192            DataValue::Null => return Ok(DataValue::Null),
193            _ => return Err(anyhow!("SHARPE_RATIO expects numeric values")),
194        };
195
196        let risk_free_rate = match &args[1] {
197            DataValue::Integer(i) => *i as f64,
198            DataValue::Float(f) => *f,
199            DataValue::Null => 0.0, // Default to 0 if null
200            _ => return Err(anyhow!("SHARPE_RATIO expects numeric values")),
201        };
202
203        let volatility = match &args[2] {
204            DataValue::Integer(i) => *i as f64,
205            DataValue::Float(f) => *f,
206            DataValue::Null => return Ok(DataValue::Null),
207            _ => return Err(anyhow!("SHARPE_RATIO expects numeric values")),
208        };
209
210        if volatility == 0.0 {
211            return Err(anyhow!(
212                "Cannot calculate Sharpe ratio with zero volatility"
213            ));
214        }
215
216        let sharpe = (mean_return - risk_free_rate) / volatility;
217        Ok(DataValue::Float(sharpe))
218    }
219}
220
221/// STDDEV function - calculates standard deviation (sample)
222/// This is an alias for VOLATILITY but more SQL-standard
223pub struct StdDevFunction;
224
225impl SqlFunction for StdDevFunction {
226    fn signature(&self) -> FunctionSignature {
227        FunctionSignature {
228            name: "STDDEV",
229            category: FunctionCategory::Mathematical,
230            arg_count: ArgCount::Variadic,
231            description: "Calculate sample standard deviation",
232            returns: "FLOAT",
233            examples: vec![
234                "SELECT STDDEV(1, 2, 3, 4, 5)", // Returns 1.58
235                "SELECT STDDEV(returns) OVER (ORDER BY date ROWS 19 PRECEDING) FROM stocks",
236            ],
237        }
238    }
239
240    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
241        // Reuse volatility implementation
242        VolatilityFunction.evaluate(args)
243    }
244}
245
/// Register all financial functions defined in this module with `registry`.
///
/// Each function is boxed and handed to `FunctionRegistry::register` in the order below.
pub fn register_financial_functions(registry: &mut super::FunctionRegistry) {
    // Price-to-return conversions.
    registry.register(Box::new(ReturnsFunction));
    registry.register(Box::new(LogReturnsFunction));
    // Dispersion measures; STDDEV delegates its calculation to VOLATILITY.
    registry.register(Box::new(VolatilityFunction));
    registry.register(Box::new(StdDevFunction));
    // Risk-adjusted performance.
    registry.register(Box::new(SharpeRatioFunction));
}