1use anyhow::{anyhow, Result};
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::data::datatable::DataValue;
9use crate::sql::parser::ast::{SqlExpression, WindowSpec};
10use crate::sql::window_context::WindowContext;
11
12pub trait WindowFunction: Send + Sync {
18    fn name(&self) -> &str;
20
21    fn description(&self) -> &str;
23
24    fn signature(&self) -> &str;
26
27    fn compute(
30        &self,
31        context: &WindowContext,
32        row_index: usize,
33        args: &[SqlExpression],
34        _evaluator: &mut dyn ExpressionEvaluator,
35    ) -> Result<DataValue>;
36
37    fn transform_window_spec(
40        &self,
41        base_spec: &WindowSpec,
42        _args: &[SqlExpression],
43    ) -> Result<WindowSpec> {
44        Ok(base_spec.clone())
46    }
47
48    fn validate_args(&self, _args: &[SqlExpression]) -> Result<()> {
50        Ok(())
51    }
52}
53
/// Evaluates SQL expressions on behalf of a [`WindowFunction`].
///
/// A mutable reference to an implementor is passed to
/// [`WindowFunction::compute`].
pub trait ExpressionEvaluator {
    /// Evaluates `expr` for the row at `row_index` and returns the resulting
    /// value, or an error if evaluation fails.
    fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue>;
}
59
60pub struct WindowFunctionRegistry {
62    functions: HashMap<String, Arc<Box<dyn WindowFunction>>>,
63}
64
65impl WindowFunctionRegistry {
66    pub fn new() -> Self {
67        let mut registry = Self {
68            functions: HashMap::new(),
69        };
70        registry.register_builtin_functions();
71        registry
72    }
73
74    pub fn register(&mut self, function: Box<dyn WindowFunction>) {
76        let name = function.name().to_uppercase();
77        self.functions.insert(name, Arc::new(function));
78    }
79
80    pub fn get(&self, name: &str) -> Option<Arc<Box<dyn WindowFunction>>> {
82        self.functions.get(&name.to_uppercase()).cloned()
83    }
84
85    pub fn contains(&self, name: &str) -> bool {
87        self.functions.contains_key(&name.to_uppercase())
88    }
89
90    pub fn list_functions(&self) -> Vec<String> {
92        self.functions.keys().cloned().collect()
93    }
94
95    fn register_builtin_functions(&mut self) {
97        self.register(Box::new(MovingAvgFunction));
99        self.register(Box::new(RollingStddevFunction));
100        self.register(Box::new(CumulativeSumFunction));
101        self.register(Box::new(CumulativeAvgFunction));
102        self.register(Box::new(ZScoreFunction));
103
104        self.register(Box::new(BollingerUpperFunction));
106        self.register(Box::new(BollingerLowerFunction));
107
108        self.register(Box::new(PercentChangeFunction));
110
111        }
113}
114
115struct MovingAvgFunction;
120
121impl WindowFunction for MovingAvgFunction {
122    fn name(&self) -> &str {
123        "MOVING_AVG"
124    }
125
126    fn description(&self) -> &str {
127        "Calculate moving average over specified window size"
128    }
129
130    fn signature(&self) -> &str {
131        "MOVING_AVG(column, window_size)"
132    }
133
134    fn compute(
135        &self,
136        context: &WindowContext,
137        row_index: usize,
138        args: &[SqlExpression],
139        _evaluator: &mut dyn ExpressionEvaluator,
140    ) -> Result<DataValue> {
141        let column = match &args[0] {
143            SqlExpression::Column(col) => col,
144            _ => {
145                return Err(anyhow::anyhow!(
146                    "MOVING_AVG first argument must be a column"
147                ))
148            }
149        };
150
151        context
154            .get_frame_avg(row_index, &column.name)
155            .ok_or_else(|| anyhow::anyhow!("Failed to compute moving average"))
156    }
157
158    fn transform_window_spec(
159        &self,
160        base_spec: &WindowSpec,
161        args: &[SqlExpression],
162    ) -> Result<WindowSpec> {
163        use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
164
165        let window_size = match &args.get(1) {
167            Some(SqlExpression::NumberLiteral(n)) => n
168                .parse::<i64>()
169                .map_err(|_| anyhow::anyhow!("Invalid window size"))?,
170            _ => return Err(anyhow::anyhow!("MOVING_AVG requires numeric window_size")),
171        };
172
173        let mut spec = base_spec.clone();
175        spec.frame = Some(WindowFrame {
176            unit: FrameUnit::Rows,
177            start: FrameBound::Preceding(window_size - 1),
178            end: None, });
180
181        Ok(spec)
182    }
183
184    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
185        if args.len() != 2 {
186            return Err(anyhow::anyhow!("MOVING_AVG requires exactly 2 arguments"));
187        }
188        Ok(())
189    }
190}
191
192struct RollingStddevFunction;
195
196impl WindowFunction for RollingStddevFunction {
197    fn name(&self) -> &str {
198        "ROLLING_STDDEV"
199    }
200
201    fn description(&self) -> &str {
202        "Calculate rolling standard deviation over specified window"
203    }
204
205    fn signature(&self) -> &str {
206        "ROLLING_STDDEV(column, window_size)"
207    }
208
209    fn compute(
210        &self,
211        context: &WindowContext,
212        row_index: usize,
213        args: &[SqlExpression],
214        _evaluator: &mut dyn ExpressionEvaluator,
215    ) -> Result<DataValue> {
216        let column = match &args[0] {
217            SqlExpression::Column(col) => col,
218            _ => {
219                return Err(anyhow::anyhow!(
220                    "ROLLING_STDDEV first argument must be a column"
221                ))
222            }
223        };
224
225        context
226            .get_frame_stddev(row_index, &column.name)
227            .ok_or_else(|| anyhow::anyhow!("Failed to compute rolling stddev"))
228    }
229
230    fn transform_window_spec(
231        &self,
232        base_spec: &WindowSpec,
233        args: &[SqlExpression],
234    ) -> Result<WindowSpec> {
235        use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
236
237        let window_size = match &args.get(1) {
238            Some(SqlExpression::NumberLiteral(n)) => n
239                .parse::<i64>()
240                .map_err(|_| anyhow::anyhow!("Invalid window size"))?,
241            _ => {
242                return Err(anyhow::anyhow!(
243                    "ROLLING_STDDEV requires numeric window_size"
244                ))
245            }
246        };
247
248        let mut spec = base_spec.clone();
249        spec.frame = Some(WindowFrame {
250            unit: FrameUnit::Rows,
251            start: FrameBound::Preceding(window_size - 1),
252            end: None,
253        });
254
255        Ok(spec)
256    }
257
258    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
259        if args.len() != 2 {
260            return Err(anyhow::anyhow!(
261                "ROLLING_STDDEV requires exactly 2 arguments"
262            ));
263        }
264        Ok(())
265    }
266}
267
268struct CumulativeSumFunction;
271
272impl WindowFunction for CumulativeSumFunction {
273    fn name(&self) -> &str {
274        "CUMULATIVE_SUM"
275    }
276
277    fn description(&self) -> &str {
278        "Calculate cumulative sum from beginning to current row"
279    }
280
281    fn signature(&self) -> &str {
282        "CUMULATIVE_SUM(column)"
283    }
284
285    fn compute(
286        &self,
287        context: &WindowContext,
288        row_index: usize,
289        args: &[SqlExpression],
290        _evaluator: &mut dyn ExpressionEvaluator,
291    ) -> Result<DataValue> {
292        let column = match &args[0] {
293            SqlExpression::Column(col) => col,
294            _ => return Err(anyhow::anyhow!("CUMULATIVE_SUM argument must be a column")),
295        };
296
297        context
298            .get_frame_sum(row_index, &column.name)
299            .ok_or_else(|| anyhow::anyhow!("Failed to compute cumulative sum"))
300    }
301
302    fn transform_window_spec(
303        &self,
304        base_spec: &WindowSpec,
305        _args: &[SqlExpression],
306    ) -> Result<WindowSpec> {
307        use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
308
309        let mut spec = base_spec.clone();
310        spec.frame = Some(WindowFrame {
311            unit: FrameUnit::Rows,
312            start: FrameBound::UnboundedPreceding,
313            end: None, });
315
316        Ok(spec)
317    }
318
319    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
320        if args.len() != 1 {
321            return Err(anyhow::anyhow!(
322                "CUMULATIVE_SUM requires exactly 1 argument"
323            ));
324        }
325        Ok(())
326    }
327}
328
329struct CumulativeAvgFunction;
332
333impl WindowFunction for CumulativeAvgFunction {
334    fn name(&self) -> &str {
335        "CUMULATIVE_AVG"
336    }
337
338    fn description(&self) -> &str {
339        "Calculate cumulative average from beginning to current row"
340    }
341
342    fn signature(&self) -> &str {
343        "CUMULATIVE_AVG(column)"
344    }
345
346    fn compute(
347        &self,
348        context: &WindowContext,
349        row_index: usize,
350        args: &[SqlExpression],
351        _evaluator: &mut dyn ExpressionEvaluator,
352    ) -> Result<DataValue> {
353        let column = match &args[0] {
354            SqlExpression::Column(col) => col,
355            _ => return Err(anyhow::anyhow!("CUMULATIVE_AVG argument must be a column")),
356        };
357
358        context
359            .get_frame_avg(row_index, &column.name)
360            .ok_or_else(|| anyhow::anyhow!("Failed to compute cumulative average"))
361    }
362
363    fn transform_window_spec(
364        &self,
365        base_spec: &WindowSpec,
366        _args: &[SqlExpression],
367    ) -> Result<WindowSpec> {
368        use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
369
370        let mut spec = base_spec.clone();
371        spec.frame = Some(WindowFrame {
372            unit: FrameUnit::Rows,
373            start: FrameBound::UnboundedPreceding,
374            end: None,
375        });
376
377        Ok(spec)
378    }
379
380    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
381        if args.len() != 1 {
382            return Err(anyhow::anyhow!(
383                "CUMULATIVE_AVG requires exactly 1 argument"
384            ));
385        }
386        Ok(())
387    }
388}
389
390struct ZScoreFunction;
393
394impl WindowFunction for ZScoreFunction {
395    fn name(&self) -> &str {
396        "Z_SCORE"
397    }
398
399    fn description(&self) -> &str {
400        "Calculate Z-score (standard deviations from mean) over window"
401    }
402
403    fn signature(&self) -> &str {
404        "Z_SCORE(column, window_size)"
405    }
406
407    fn compute(
408        &self,
409        context: &WindowContext,
410        row_index: usize,
411        args: &[SqlExpression],
412        _evaluator: &mut dyn ExpressionEvaluator,
413    ) -> Result<DataValue> {
414        let column = match &args[0] {
415            SqlExpression::Column(col) => col,
416            _ => return Err(anyhow::anyhow!("Z_SCORE first argument must be a column")),
417        };
418
419        let current_value = {
421            let source = context.source();
422            let col_idx = source
423                .get_column_index(&column.name)
424                .ok_or_else(|| anyhow::anyhow!("Column {} not found", column))?;
425            source
426                .get_value(row_index, col_idx)
427                .cloned()
428                .unwrap_or(DataValue::Null)
429        };
430
431        let mean = context
433            .get_frame_avg(row_index, &column.name)
434            .unwrap_or(DataValue::Null);
435        let stddev = context
436            .get_frame_stddev(row_index, &column.name)
437            .unwrap_or(DataValue::Null);
438
439        match (current_value, mean, stddev) {
441            (DataValue::Integer(v), DataValue::Float(m), DataValue::Float(s)) if s > 0.0 => {
442                Ok(DataValue::Float((v as f64 - m) / s))
443            }
444            (DataValue::Float(v), DataValue::Float(m), DataValue::Float(s)) if s > 0.0 => {
445                Ok(DataValue::Float((v - m) / s))
446            }
447            _ => Ok(DataValue::Null),
448        }
449    }
450
451    fn transform_window_spec(
452        &self,
453        base_spec: &WindowSpec,
454        args: &[SqlExpression],
455    ) -> Result<WindowSpec> {
456        use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
457
458        let window_size = match &args.get(1) {
459            Some(SqlExpression::NumberLiteral(n)) => n
460                .parse::<i64>()
461                .map_err(|_| anyhow::anyhow!("Invalid window size"))?,
462            _ => return Err(anyhow::anyhow!("Z_SCORE requires numeric window_size")),
463        };
464
465        let mut spec = base_spec.clone();
466        spec.frame = Some(WindowFrame {
467            unit: FrameUnit::Rows,
468            start: FrameBound::Preceding(window_size - 1),
469            end: None,
470        });
471
472        Ok(spec)
473    }
474
475    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
476        if args.len() != 2 {
477            return Err(anyhow::anyhow!("Z_SCORE requires exactly 2 arguments"));
478        }
479        Ok(())
480    }
481}
482
483struct BollingerUpperFunction;
486
487impl WindowFunction for BollingerUpperFunction {
488    fn name(&self) -> &str {
489        "BOLLINGER_UPPER"
490    }
491
492    fn description(&self) -> &str {
493        "Calculate upper Bollinger Band (MA + n*STDDEV)"
494    }
495
496    fn signature(&self) -> &str {
497        "BOLLINGER_UPPER(column, window_size, num_std)"
498    }
499
500    fn compute(
501        &self,
502        context: &WindowContext,
503        row_index: usize,
504        args: &[SqlExpression],
505        _evaluator: &mut dyn ExpressionEvaluator,
506    ) -> Result<DataValue> {
507        let column = match &args[0] {
508            SqlExpression::Column(col) => col,
509            _ => return Err(anyhow!("BOLLINGER_UPPER first argument must be a column")),
510        };
511
512        let num_std = match args.get(2) {
514            Some(SqlExpression::NumberLiteral(n)) => n
515                .parse::<f64>()
516                .map_err(|_| anyhow!("Invalid num_std value"))?,
517            _ => 2.0, };
519
520        let mean = context
522            .get_frame_avg(row_index, &column.name)
523            .unwrap_or(DataValue::Null);
524        let stddev = context
525            .get_frame_stddev(row_index, &column.name)
526            .unwrap_or(DataValue::Null);
527
528        match (mean, stddev) {
530            (DataValue::Float(m), DataValue::Float(s)) => Ok(DataValue::Float(m + (num_std * s))),
531            _ => Ok(DataValue::Null),
532        }
533    }
534
535    fn transform_window_spec(
536        &self,
537        base_spec: &WindowSpec,
538        args: &[SqlExpression],
539    ) -> Result<WindowSpec> {
540        use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
541
542        let window_size = match args.get(1) {
543            Some(SqlExpression::NumberLiteral(n)) => n
544                .parse::<i64>()
545                .map_err(|_| anyhow!("Invalid window size"))?,
546            _ => return Err(anyhow!("BOLLINGER_UPPER requires numeric window_size")),
547        };
548
549        let mut spec = base_spec.clone();
550        spec.frame = Some(WindowFrame {
551            unit: FrameUnit::Rows,
552            start: FrameBound::Preceding(window_size - 1),
553            end: None,
554        });
555
556        Ok(spec)
557    }
558
559    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
560        if args.len() < 2 || args.len() > 3 {
561            return Err(anyhow!("BOLLINGER_UPPER requires 2 or 3 arguments"));
562        }
563        Ok(())
564    }
565}
566
567struct BollingerLowerFunction;
570
571impl WindowFunction for BollingerLowerFunction {
572    fn name(&self) -> &str {
573        "BOLLINGER_LOWER"
574    }
575
576    fn description(&self) -> &str {
577        "Calculate lower Bollinger Band (MA - n*STDDEV)"
578    }
579
580    fn signature(&self) -> &str {
581        "BOLLINGER_LOWER(column, window_size, num_std)"
582    }
583
584    fn compute(
585        &self,
586        context: &WindowContext,
587        row_index: usize,
588        args: &[SqlExpression],
589        _evaluator: &mut dyn ExpressionEvaluator,
590    ) -> Result<DataValue> {
591        let column = match &args[0] {
592            SqlExpression::Column(col) => col,
593            _ => return Err(anyhow!("BOLLINGER_LOWER first argument must be a column")),
594        };
595
596        let num_std = match args.get(2) {
598            Some(SqlExpression::NumberLiteral(n)) => n
599                .parse::<f64>()
600                .map_err(|_| anyhow!("Invalid num_std value"))?,
601            _ => 2.0, };
603
604        let mean = context
606            .get_frame_avg(row_index, &column.name)
607            .unwrap_or(DataValue::Null);
608        let stddev = context
609            .get_frame_stddev(row_index, &column.name)
610            .unwrap_or(DataValue::Null);
611
612        match (mean, stddev) {
614            (DataValue::Float(m), DataValue::Float(s)) => Ok(DataValue::Float(m - (num_std * s))),
615            _ => Ok(DataValue::Null),
616        }
617    }
618
619    fn transform_window_spec(
620        &self,
621        base_spec: &WindowSpec,
622        args: &[SqlExpression],
623    ) -> Result<WindowSpec> {
624        use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
625
626        let window_size = match args.get(1) {
627            Some(SqlExpression::NumberLiteral(n)) => n
628                .parse::<i64>()
629                .map_err(|_| anyhow!("Invalid window size"))?,
630            _ => return Err(anyhow!("BOLLINGER_LOWER requires numeric window_size")),
631        };
632
633        let mut spec = base_spec.clone();
634        spec.frame = Some(WindowFrame {
635            unit: FrameUnit::Rows,
636            start: FrameBound::Preceding(window_size - 1),
637            end: None,
638        });
639
640        Ok(spec)
641    }
642
643    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
644        if args.len() < 2 || args.len() > 3 {
645            return Err(anyhow!("BOLLINGER_LOWER requires 2 or 3 arguments"));
646        }
647        Ok(())
648    }
649}
650
651struct PercentChangeFunction;
655
656impl WindowFunction for PercentChangeFunction {
657    fn name(&self) -> &str {
658        "PERCENT_CHANGE"
659    }
660
661    fn description(&self) -> &str {
662        "Calculate percentage change from N periods ago"
663    }
664
665    fn signature(&self) -> &str {
666        "PERCENT_CHANGE(column, periods)"
667    }
668
669    fn compute(
670        &self,
671        context: &WindowContext,
672        row_index: usize,
673        args: &[SqlExpression],
674        _evaluator: &mut dyn ExpressionEvaluator,
675    ) -> Result<DataValue> {
676        let column = match &args[0] {
677            SqlExpression::Column(col) => col,
678            _ => return Err(anyhow!("PERCENT_CHANGE first argument must be a column")),
679        };
680
681        let periods = match args.get(1) {
683            Some(SqlExpression::NumberLiteral(n)) => n
684                .parse::<i32>()
685                .map_err(|_| anyhow!("Invalid periods value"))?,
686            _ => 1, };
688
689        let current_value = {
691            let source = context.source();
692            let col_idx = source
693                .get_column_index(&column.name)
694                .ok_or_else(|| anyhow!("Column {} not found", column))?;
695            source.get_value(row_index, col_idx).cloned()
696        };
697
698        let previous_value = context.get_offset_value(row_index, -periods, &column.name);
700
701        match (current_value, previous_value) {
703            (Some(DataValue::Float(curr)), Some(DataValue::Float(prev))) if prev != 0.0 => {
704                Ok(DataValue::Float(((curr - prev) / prev) * 100.0))
705            }
706            (Some(DataValue::Integer(curr)), Some(DataValue::Integer(prev))) if prev != 0 => {
707                let curr_f = curr as f64;
708                let prev_f = prev as f64;
709                Ok(DataValue::Float(((curr_f - prev_f) / prev_f) * 100.0))
710            }
711            (Some(DataValue::Float(curr)), Some(DataValue::Integer(prev))) if prev != 0 => {
712                let prev_f = prev as f64;
713                Ok(DataValue::Float(((curr - prev_f) / prev_f) * 100.0))
714            }
715            (Some(DataValue::Integer(curr)), Some(DataValue::Float(prev))) if prev != 0.0 => {
716                let curr_f = curr as f64;
717                Ok(DataValue::Float(((curr_f - prev) / prev) * 100.0))
718            }
719            _ => Ok(DataValue::Null), }
721    }
722
723    fn transform_window_spec(
724        &self,
725        base_spec: &WindowSpec,
726        _args: &[SqlExpression],
727    ) -> Result<WindowSpec> {
728        Ok(base_spec.clone())
731    }
732
733    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
734        if args.is_empty() || args.len() > 2 {
735            return Err(anyhow!("PERCENT_CHANGE requires 1 or 2 arguments"));
736        }
737        Ok(())
738    }
739}
740
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::ColumnRef;

    /// A freshly built registry exposes the built-in functions by name.
    #[test]
    fn test_registry_creation() {
        let registry = WindowFunctionRegistry::new();
        let expected_builtins = ["MOVING_AVG", "ROLLING_STDDEV", "CUMULATIVE_SUM"];

        for builtin in expected_builtins.iter() {
            assert!(registry.contains(builtin));
        }
    }

    /// `MOVING_AVG(close, 20)` yields a frame that starts 19 rows back.
    #[test]
    fn test_window_spec_transformation() {
        use crate::sql::parser::ast::{FrameBound, WindowSpec};

        let close_column = SqlExpression::Column(ColumnRef::unquoted("close".to_string()));
        let window_size = SqlExpression::NumberLiteral("20".to_string());
        let call_args = vec![close_column, window_size];

        let empty_spec = WindowSpec {
            partition_by: Vec::new(),
            order_by: Vec::new(),
            frame: None,
        };

        let transformed = MovingAvgFunction
            .transform_window_spec(&empty_spec, &call_args)
            .unwrap();

        assert!(transformed.frame.is_some());
        let frame = transformed.frame.unwrap();
        assert_eq!(frame.start, FrameBound::Preceding(19));
    }
}