1use anyhow::{anyhow, Result};
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::data::datatable::DataValue;
9use crate::sql::parser::ast::{SqlExpression, WindowSpec};
10use crate::sql::window_context::WindowContext;
11
12mod aggregates;
14use aggregates::*;
15
16pub trait WindowFunction: Send + Sync {
22 fn name(&self) -> &str;
24
25 fn description(&self) -> &str;
27
28 fn signature(&self) -> &str;
30
31 fn compute(
34 &self,
35 context: &WindowContext,
36 row_index: usize,
37 args: &[SqlExpression],
38 _evaluator: &mut dyn ExpressionEvaluator,
39 ) -> Result<DataValue>;
40
41 fn transform_window_spec(
44 &self,
45 base_spec: &WindowSpec,
46 _args: &[SqlExpression],
47 ) -> Result<WindowSpec> {
48 Ok(base_spec.clone())
50 }
51
52 fn validate_args(&self, _args: &[SqlExpression]) -> Result<()> {
54 Ok(())
55 }
56}
57
58pub trait ExpressionEvaluator {
61 fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue>;
62}
63
64pub struct WindowFunctionRegistry {
66 functions: HashMap<String, Arc<Box<dyn WindowFunction>>>,
67}
68
69impl WindowFunctionRegistry {
70 pub fn new() -> Self {
71 let mut registry = Self {
72 functions: HashMap::new(),
73 };
74 registry.register_builtin_functions();
75 registry
76 }
77
78 pub fn register(&mut self, function: Box<dyn WindowFunction>) {
80 let name = function.name().to_uppercase();
81 self.functions.insert(name, Arc::new(function));
82 }
83
84 pub fn get(&self, name: &str) -> Option<Arc<Box<dyn WindowFunction>>> {
86 self.functions.get(&name.to_uppercase()).cloned()
87 }
88
89 pub fn contains(&self, name: &str) -> bool {
91 self.functions.contains_key(&name.to_uppercase())
92 }
93
94 pub fn list_functions(&self) -> Vec<String> {
96 self.functions.keys().cloned().collect()
97 }
98
99 fn register_builtin_functions(&mut self) {
101 self.register(Box::new(WindowSumFunction));
103 self.register(Box::new(WindowAvgFunction));
104 self.register(Box::new(WindowMinFunction));
105 self.register(Box::new(WindowMaxFunction));
106 self.register(Box::new(WindowCountFunction));
107 self.register(Box::new(WindowStddevFunction));
108 self.register(Box::new(WindowStdevFunction)); self.register(Box::new(WindowVarianceFunction));
110 self.register(Box::new(WindowVarFunction)); self.register(Box::new(MovingAvgFunction));
114 self.register(Box::new(RollingStddevFunction));
115 self.register(Box::new(CumulativeSumFunction));
116 self.register(Box::new(CumulativeAvgFunction));
117 self.register(Box::new(ZScoreFunction));
118
119 self.register(Box::new(BollingerUpperFunction));
121 self.register(Box::new(BollingerLowerFunction));
122
123 self.register(Box::new(PercentChangeFunction));
125
126 }
128}
129
130struct MovingAvgFunction;
135
136impl WindowFunction for MovingAvgFunction {
137 fn name(&self) -> &str {
138 "MOVING_AVG"
139 }
140
141 fn description(&self) -> &str {
142 "Calculate moving average over specified window size"
143 }
144
145 fn signature(&self) -> &str {
146 "MOVING_AVG(column, window_size)"
147 }
148
149 fn compute(
150 &self,
151 context: &WindowContext,
152 row_index: usize,
153 args: &[SqlExpression],
154 _evaluator: &mut dyn ExpressionEvaluator,
155 ) -> Result<DataValue> {
156 let column = match &args[0] {
158 SqlExpression::Column(col) => col,
159 _ => {
160 return Err(anyhow::anyhow!(
161 "MOVING_AVG first argument must be a column"
162 ))
163 }
164 };
165
166 context
169 .get_frame_avg(row_index, &column.name)
170 .ok_or_else(|| anyhow::anyhow!("Failed to compute moving average"))
171 }
172
173 fn transform_window_spec(
174 &self,
175 base_spec: &WindowSpec,
176 args: &[SqlExpression],
177 ) -> Result<WindowSpec> {
178 use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
179
180 let window_size = match &args.get(1) {
182 Some(SqlExpression::NumberLiteral(n)) => n
183 .parse::<i64>()
184 .map_err(|_| anyhow::anyhow!("Invalid window size"))?,
185 _ => return Err(anyhow::anyhow!("MOVING_AVG requires numeric window_size")),
186 };
187
188 let mut spec = base_spec.clone();
190 spec.frame = Some(WindowFrame {
191 unit: FrameUnit::Rows,
192 start: FrameBound::Preceding(window_size - 1),
193 end: None, });
195
196 Ok(spec)
197 }
198
199 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
200 if args.len() != 2 {
201 return Err(anyhow::anyhow!("MOVING_AVG requires exactly 2 arguments"));
202 }
203 Ok(())
204 }
205}
206
207struct RollingStddevFunction;
210
211impl WindowFunction for RollingStddevFunction {
212 fn name(&self) -> &str {
213 "ROLLING_STDDEV"
214 }
215
216 fn description(&self) -> &str {
217 "Calculate rolling standard deviation over specified window"
218 }
219
220 fn signature(&self) -> &str {
221 "ROLLING_STDDEV(column, window_size)"
222 }
223
224 fn compute(
225 &self,
226 context: &WindowContext,
227 row_index: usize,
228 args: &[SqlExpression],
229 _evaluator: &mut dyn ExpressionEvaluator,
230 ) -> Result<DataValue> {
231 let column = match &args[0] {
232 SqlExpression::Column(col) => col,
233 _ => {
234 return Err(anyhow::anyhow!(
235 "ROLLING_STDDEV first argument must be a column"
236 ))
237 }
238 };
239
240 context
241 .get_frame_stddev(row_index, &column.name)
242 .ok_or_else(|| anyhow::anyhow!("Failed to compute rolling stddev"))
243 }
244
245 fn transform_window_spec(
246 &self,
247 base_spec: &WindowSpec,
248 args: &[SqlExpression],
249 ) -> Result<WindowSpec> {
250 use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
251
252 let window_size = match &args.get(1) {
253 Some(SqlExpression::NumberLiteral(n)) => n
254 .parse::<i64>()
255 .map_err(|_| anyhow::anyhow!("Invalid window size"))?,
256 _ => {
257 return Err(anyhow::anyhow!(
258 "ROLLING_STDDEV requires numeric window_size"
259 ))
260 }
261 };
262
263 let mut spec = base_spec.clone();
264 spec.frame = Some(WindowFrame {
265 unit: FrameUnit::Rows,
266 start: FrameBound::Preceding(window_size - 1),
267 end: None,
268 });
269
270 Ok(spec)
271 }
272
273 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
274 if args.len() != 2 {
275 return Err(anyhow::anyhow!(
276 "ROLLING_STDDEV requires exactly 2 arguments"
277 ));
278 }
279 Ok(())
280 }
281}
282
283struct CumulativeSumFunction;
286
287impl WindowFunction for CumulativeSumFunction {
288 fn name(&self) -> &str {
289 "CUMULATIVE_SUM"
290 }
291
292 fn description(&self) -> &str {
293 "Calculate cumulative sum from beginning to current row"
294 }
295
296 fn signature(&self) -> &str {
297 "CUMULATIVE_SUM(column)"
298 }
299
300 fn compute(
301 &self,
302 context: &WindowContext,
303 row_index: usize,
304 args: &[SqlExpression],
305 _evaluator: &mut dyn ExpressionEvaluator,
306 ) -> Result<DataValue> {
307 let column = match &args[0] {
308 SqlExpression::Column(col) => col,
309 _ => return Err(anyhow::anyhow!("CUMULATIVE_SUM argument must be a column")),
310 };
311
312 context
313 .get_frame_sum(row_index, &column.name)
314 .ok_or_else(|| anyhow::anyhow!("Failed to compute cumulative sum"))
315 }
316
317 fn transform_window_spec(
318 &self,
319 base_spec: &WindowSpec,
320 _args: &[SqlExpression],
321 ) -> Result<WindowSpec> {
322 use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
323
324 let mut spec = base_spec.clone();
325 spec.frame = Some(WindowFrame {
326 unit: FrameUnit::Rows,
327 start: FrameBound::UnboundedPreceding,
328 end: None, });
330
331 Ok(spec)
332 }
333
334 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
335 if args.len() != 1 {
336 return Err(anyhow::anyhow!(
337 "CUMULATIVE_SUM requires exactly 1 argument"
338 ));
339 }
340 Ok(())
341 }
342}
343
344struct CumulativeAvgFunction;
347
348impl WindowFunction for CumulativeAvgFunction {
349 fn name(&self) -> &str {
350 "CUMULATIVE_AVG"
351 }
352
353 fn description(&self) -> &str {
354 "Calculate cumulative average from beginning to current row"
355 }
356
357 fn signature(&self) -> &str {
358 "CUMULATIVE_AVG(column)"
359 }
360
361 fn compute(
362 &self,
363 context: &WindowContext,
364 row_index: usize,
365 args: &[SqlExpression],
366 _evaluator: &mut dyn ExpressionEvaluator,
367 ) -> Result<DataValue> {
368 let column = match &args[0] {
369 SqlExpression::Column(col) => col,
370 _ => return Err(anyhow::anyhow!("CUMULATIVE_AVG argument must be a column")),
371 };
372
373 context
374 .get_frame_avg(row_index, &column.name)
375 .ok_or_else(|| anyhow::anyhow!("Failed to compute cumulative average"))
376 }
377
378 fn transform_window_spec(
379 &self,
380 base_spec: &WindowSpec,
381 _args: &[SqlExpression],
382 ) -> Result<WindowSpec> {
383 use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
384
385 let mut spec = base_spec.clone();
386 spec.frame = Some(WindowFrame {
387 unit: FrameUnit::Rows,
388 start: FrameBound::UnboundedPreceding,
389 end: None,
390 });
391
392 Ok(spec)
393 }
394
395 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
396 if args.len() != 1 {
397 return Err(anyhow::anyhow!(
398 "CUMULATIVE_AVG requires exactly 1 argument"
399 ));
400 }
401 Ok(())
402 }
403}
404
405struct ZScoreFunction;
408
409impl WindowFunction for ZScoreFunction {
410 fn name(&self) -> &str {
411 "Z_SCORE"
412 }
413
414 fn description(&self) -> &str {
415 "Calculate Z-score (standard deviations from mean) over window"
416 }
417
418 fn signature(&self) -> &str {
419 "Z_SCORE(column, window_size)"
420 }
421
422 fn compute(
423 &self,
424 context: &WindowContext,
425 row_index: usize,
426 args: &[SqlExpression],
427 _evaluator: &mut dyn ExpressionEvaluator,
428 ) -> Result<DataValue> {
429 let column = match &args[0] {
430 SqlExpression::Column(col) => col,
431 _ => return Err(anyhow::anyhow!("Z_SCORE first argument must be a column")),
432 };
433
434 let current_value = {
436 let source = context.source();
437 let col_idx = source
438 .get_column_index(&column.name)
439 .ok_or_else(|| anyhow::anyhow!("Column {} not found", column))?;
440 source
441 .get_value(row_index, col_idx)
442 .cloned()
443 .unwrap_or(DataValue::Null)
444 };
445
446 let mean = context
448 .get_frame_avg(row_index, &column.name)
449 .unwrap_or(DataValue::Null);
450 let stddev = context
451 .get_frame_stddev(row_index, &column.name)
452 .unwrap_or(DataValue::Null);
453
454 match (current_value, mean, stddev) {
456 (DataValue::Integer(v), DataValue::Float(m), DataValue::Float(s)) if s > 0.0 => {
457 Ok(DataValue::Float((v as f64 - m) / s))
458 }
459 (DataValue::Float(v), DataValue::Float(m), DataValue::Float(s)) if s > 0.0 => {
460 Ok(DataValue::Float((v - m) / s))
461 }
462 _ => Ok(DataValue::Null),
463 }
464 }
465
466 fn transform_window_spec(
467 &self,
468 base_spec: &WindowSpec,
469 args: &[SqlExpression],
470 ) -> Result<WindowSpec> {
471 use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
472
473 let window_size = match &args.get(1) {
474 Some(SqlExpression::NumberLiteral(n)) => n
475 .parse::<i64>()
476 .map_err(|_| anyhow::anyhow!("Invalid window size"))?,
477 _ => return Err(anyhow::anyhow!("Z_SCORE requires numeric window_size")),
478 };
479
480 let mut spec = base_spec.clone();
481 spec.frame = Some(WindowFrame {
482 unit: FrameUnit::Rows,
483 start: FrameBound::Preceding(window_size - 1),
484 end: None,
485 });
486
487 Ok(spec)
488 }
489
490 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
491 if args.len() != 2 {
492 return Err(anyhow::anyhow!("Z_SCORE requires exactly 2 arguments"));
493 }
494 Ok(())
495 }
496}
497
498struct BollingerUpperFunction;
501
502impl WindowFunction for BollingerUpperFunction {
503 fn name(&self) -> &str {
504 "BOLLINGER_UPPER"
505 }
506
507 fn description(&self) -> &str {
508 "Calculate upper Bollinger Band (MA + n*STDDEV)"
509 }
510
511 fn signature(&self) -> &str {
512 "BOLLINGER_UPPER(column, window_size, num_std)"
513 }
514
515 fn compute(
516 &self,
517 context: &WindowContext,
518 row_index: usize,
519 args: &[SqlExpression],
520 _evaluator: &mut dyn ExpressionEvaluator,
521 ) -> Result<DataValue> {
522 let column = match &args[0] {
523 SqlExpression::Column(col) => col,
524 _ => return Err(anyhow!("BOLLINGER_UPPER first argument must be a column")),
525 };
526
527 let num_std = match args.get(2) {
529 Some(SqlExpression::NumberLiteral(n)) => n
530 .parse::<f64>()
531 .map_err(|_| anyhow!("Invalid num_std value"))?,
532 _ => 2.0, };
534
535 let mean = context
537 .get_frame_avg(row_index, &column.name)
538 .unwrap_or(DataValue::Null);
539 let stddev = context
540 .get_frame_stddev(row_index, &column.name)
541 .unwrap_or(DataValue::Null);
542
543 match (mean, stddev) {
545 (DataValue::Float(m), DataValue::Float(s)) => Ok(DataValue::Float(m + (num_std * s))),
546 _ => Ok(DataValue::Null),
547 }
548 }
549
550 fn transform_window_spec(
551 &self,
552 base_spec: &WindowSpec,
553 args: &[SqlExpression],
554 ) -> Result<WindowSpec> {
555 use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
556
557 let window_size = match args.get(1) {
558 Some(SqlExpression::NumberLiteral(n)) => n
559 .parse::<i64>()
560 .map_err(|_| anyhow!("Invalid window size"))?,
561 _ => return Err(anyhow!("BOLLINGER_UPPER requires numeric window_size")),
562 };
563
564 let mut spec = base_spec.clone();
565 spec.frame = Some(WindowFrame {
566 unit: FrameUnit::Rows,
567 start: FrameBound::Preceding(window_size - 1),
568 end: None,
569 });
570
571 Ok(spec)
572 }
573
574 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
575 if args.len() < 2 || args.len() > 3 {
576 return Err(anyhow!("BOLLINGER_UPPER requires 2 or 3 arguments"));
577 }
578 Ok(())
579 }
580}
581
582struct BollingerLowerFunction;
585
586impl WindowFunction for BollingerLowerFunction {
587 fn name(&self) -> &str {
588 "BOLLINGER_LOWER"
589 }
590
591 fn description(&self) -> &str {
592 "Calculate lower Bollinger Band (MA - n*STDDEV)"
593 }
594
595 fn signature(&self) -> &str {
596 "BOLLINGER_LOWER(column, window_size, num_std)"
597 }
598
599 fn compute(
600 &self,
601 context: &WindowContext,
602 row_index: usize,
603 args: &[SqlExpression],
604 _evaluator: &mut dyn ExpressionEvaluator,
605 ) -> Result<DataValue> {
606 let column = match &args[0] {
607 SqlExpression::Column(col) => col,
608 _ => return Err(anyhow!("BOLLINGER_LOWER first argument must be a column")),
609 };
610
611 let num_std = match args.get(2) {
613 Some(SqlExpression::NumberLiteral(n)) => n
614 .parse::<f64>()
615 .map_err(|_| anyhow!("Invalid num_std value"))?,
616 _ => 2.0, };
618
619 let mean = context
621 .get_frame_avg(row_index, &column.name)
622 .unwrap_or(DataValue::Null);
623 let stddev = context
624 .get_frame_stddev(row_index, &column.name)
625 .unwrap_or(DataValue::Null);
626
627 match (mean, stddev) {
629 (DataValue::Float(m), DataValue::Float(s)) => Ok(DataValue::Float(m - (num_std * s))),
630 _ => Ok(DataValue::Null),
631 }
632 }
633
634 fn transform_window_spec(
635 &self,
636 base_spec: &WindowSpec,
637 args: &[SqlExpression],
638 ) -> Result<WindowSpec> {
639 use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
640
641 let window_size = match args.get(1) {
642 Some(SqlExpression::NumberLiteral(n)) => n
643 .parse::<i64>()
644 .map_err(|_| anyhow!("Invalid window size"))?,
645 _ => return Err(anyhow!("BOLLINGER_LOWER requires numeric window_size")),
646 };
647
648 let mut spec = base_spec.clone();
649 spec.frame = Some(WindowFrame {
650 unit: FrameUnit::Rows,
651 start: FrameBound::Preceding(window_size - 1),
652 end: None,
653 });
654
655 Ok(spec)
656 }
657
658 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
659 if args.len() < 2 || args.len() > 3 {
660 return Err(anyhow!("BOLLINGER_LOWER requires 2 or 3 arguments"));
661 }
662 Ok(())
663 }
664}
665
666struct PercentChangeFunction;
670
671impl WindowFunction for PercentChangeFunction {
672 fn name(&self) -> &str {
673 "PERCENT_CHANGE"
674 }
675
676 fn description(&self) -> &str {
677 "Calculate percentage change from N periods ago"
678 }
679
680 fn signature(&self) -> &str {
681 "PERCENT_CHANGE(column, periods)"
682 }
683
684 fn compute(
685 &self,
686 context: &WindowContext,
687 row_index: usize,
688 args: &[SqlExpression],
689 _evaluator: &mut dyn ExpressionEvaluator,
690 ) -> Result<DataValue> {
691 let column = match &args[0] {
692 SqlExpression::Column(col) => col,
693 _ => return Err(anyhow!("PERCENT_CHANGE first argument must be a column")),
694 };
695
696 let periods = match args.get(1) {
698 Some(SqlExpression::NumberLiteral(n)) => n
699 .parse::<i32>()
700 .map_err(|_| anyhow!("Invalid periods value"))?,
701 _ => 1, };
703
704 let current_value = {
706 let source = context.source();
707 let col_idx = source
708 .get_column_index(&column.name)
709 .ok_or_else(|| anyhow!("Column {} not found", column))?;
710 source.get_value(row_index, col_idx).cloned()
711 };
712
713 let previous_value = context.get_offset_value(row_index, -periods, &column.name);
715
716 match (current_value, previous_value) {
718 (Some(DataValue::Float(curr)), Some(DataValue::Float(prev))) if prev != 0.0 => {
719 Ok(DataValue::Float(((curr - prev) / prev) * 100.0))
720 }
721 (Some(DataValue::Integer(curr)), Some(DataValue::Integer(prev))) if prev != 0 => {
722 let curr_f = curr as f64;
723 let prev_f = prev as f64;
724 Ok(DataValue::Float(((curr_f - prev_f) / prev_f) * 100.0))
725 }
726 (Some(DataValue::Float(curr)), Some(DataValue::Integer(prev))) if prev != 0 => {
727 let prev_f = prev as f64;
728 Ok(DataValue::Float(((curr - prev_f) / prev_f) * 100.0))
729 }
730 (Some(DataValue::Integer(curr)), Some(DataValue::Float(prev))) if prev != 0.0 => {
731 let curr_f = curr as f64;
732 Ok(DataValue::Float(((curr_f - prev) / prev) * 100.0))
733 }
734 _ => Ok(DataValue::Null), }
736 }
737
738 fn transform_window_spec(
739 &self,
740 base_spec: &WindowSpec,
741 _args: &[SqlExpression],
742 ) -> Result<WindowSpec> {
743 Ok(base_spec.clone())
746 }
747
748 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
749 if args.is_empty() || args.len() > 2 {
750 return Err(anyhow!("PERCENT_CHANGE requires 1 or 2 arguments"));
751 }
752 Ok(())
753 }
754}
755
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, FrameBound};

    /// A window specification with no partitioning, ordering or frame.
    fn empty_spec() -> WindowSpec {
        WindowSpec {
            partition_by: vec![],
            order_by: vec![],
            frame: None,
        }
    }

    /// An unquoted column-reference expression.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// A numeric literal expression.
    fn number(text: &str) -> SqlExpression {
        SqlExpression::NumberLiteral(text.to_string())
    }

    #[test]
    fn test_registry_creation() {
        let registry = WindowFunctionRegistry::new();
        assert!(registry.contains("MOVING_AVG"));
        assert!(registry.contains("ROLLING_STDDEV"));
        assert!(registry.contains("CUMULATIVE_SUM"));
    }

    #[test]
    fn test_registry_lookup_is_case_insensitive() {
        let registry = WindowFunctionRegistry::new();
        assert!(registry.contains("moving_avg"));
        let func = registry
            .get("Bollinger_Upper")
            .expect("BOLLINGER_UPPER is a built-in");
        assert_eq!(func.name(), "BOLLINGER_UPPER");
        assert!(!registry.contains("NOT_A_FUNCTION"));
    }

    #[test]
    fn test_window_spec_transformation() {
        let func = MovingAvgFunction;
        let args = vec![column("close"), number("20")];

        let transformed = func.transform_window_spec(&empty_spec(), &args).unwrap();

        let frame = transformed.frame.expect("MOVING_AVG installs a frame");
        assert_eq!(frame.start, FrameBound::Preceding(19));
    }

    #[test]
    fn test_cumulative_frame_starts_unbounded() {
        let args = vec![column("close")];
        let transformed = CumulativeSumFunction
            .transform_window_spec(&empty_spec(), &args)
            .unwrap();
        assert_eq!(
            transformed.frame.unwrap().start,
            FrameBound::UnboundedPreceding
        );
    }

    #[test]
    fn test_moving_avg_rejects_non_numeric_window_size() {
        let args = vec![column("close"), column("n")];
        assert!(MovingAvgFunction
            .transform_window_spec(&empty_spec(), &args)
            .is_err());
    }

    #[test]
    fn test_validate_args_checks_arity() {
        let one = vec![column("close")];
        let two = vec![column("close"), number("5")];

        assert!(MovingAvgFunction.validate_args(&one).is_err());
        assert!(MovingAvgFunction.validate_args(&two).is_ok());
        assert!(CumulativeSumFunction.validate_args(&one).is_ok());
        assert!(CumulativeSumFunction.validate_args(&two).is_err());
        assert!(PercentChangeFunction.validate_args(&[]).is_err());
        assert!(PercentChangeFunction.validate_args(&one).is_ok());
    }
}