1use anyhow::{anyhow, Result};
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::data::datatable::DataValue;
9use crate::sql::parser::ast::{ColumnRef, SqlExpression, WindowSpec};
10use crate::sql::window_context::WindowContext;
11
12pub trait WindowFunction: Send + Sync {
18 fn name(&self) -> &str;
20
21 fn description(&self) -> &str;
23
24 fn signature(&self) -> &str;
26
27 fn compute(
30 &self,
31 context: &WindowContext,
32 row_index: usize,
33 args: &[SqlExpression],
34 evaluator: &mut dyn ExpressionEvaluator,
35 ) -> Result<DataValue>;
36
37 fn transform_window_spec(
40 &self,
41 base_spec: &WindowSpec,
42 args: &[SqlExpression],
43 ) -> Result<WindowSpec> {
44 Ok(base_spec.clone())
46 }
47
48 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
50 Ok(())
51 }
52}
53
54pub trait ExpressionEvaluator {
57 fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue>;
58}
59
60pub struct WindowFunctionRegistry {
62 functions: HashMap<String, Arc<Box<dyn WindowFunction>>>,
63}
64
65impl WindowFunctionRegistry {
66 pub fn new() -> Self {
67 let mut registry = Self {
68 functions: HashMap::new(),
69 };
70 registry.register_builtin_functions();
71 registry
72 }
73
74 pub fn register(&mut self, function: Box<dyn WindowFunction>) {
76 let name = function.name().to_uppercase();
77 self.functions.insert(name, Arc::new(function));
78 }
79
80 pub fn get(&self, name: &str) -> Option<Arc<Box<dyn WindowFunction>>> {
82 self.functions.get(&name.to_uppercase()).cloned()
83 }
84
85 pub fn contains(&self, name: &str) -> bool {
87 self.functions.contains_key(&name.to_uppercase())
88 }
89
90 pub fn list_functions(&self) -> Vec<String> {
92 self.functions.keys().cloned().collect()
93 }
94
95 fn register_builtin_functions(&mut self) {
97 self.register(Box::new(MovingAvgFunction));
99 self.register(Box::new(RollingStddevFunction));
100 self.register(Box::new(CumulativeSumFunction));
101 self.register(Box::new(CumulativeAvgFunction));
102 self.register(Box::new(ZScoreFunction));
103
104 self.register(Box::new(BollingerUpperFunction));
106 self.register(Box::new(BollingerLowerFunction));
107
108 self.register(Box::new(PercentChangeFunction));
110
111 }
113}
114
115struct MovingAvgFunction;
120
121impl WindowFunction for MovingAvgFunction {
122 fn name(&self) -> &str {
123 "MOVING_AVG"
124 }
125
126 fn description(&self) -> &str {
127 "Calculate moving average over specified window size"
128 }
129
130 fn signature(&self) -> &str {
131 "MOVING_AVG(column, window_size)"
132 }
133
134 fn compute(
135 &self,
136 context: &WindowContext,
137 row_index: usize,
138 args: &[SqlExpression],
139 evaluator: &mut dyn ExpressionEvaluator,
140 ) -> Result<DataValue> {
141 let column = match &args[0] {
143 SqlExpression::Column(col) => col,
144 _ => {
145 return Err(anyhow::anyhow!(
146 "MOVING_AVG first argument must be a column"
147 ))
148 }
149 };
150
151 context
154 .get_frame_avg(row_index, &column.name)
155 .ok_or_else(|| anyhow::anyhow!("Failed to compute moving average"))
156 }
157
158 fn transform_window_spec(
159 &self,
160 base_spec: &WindowSpec,
161 args: &[SqlExpression],
162 ) -> Result<WindowSpec> {
163 use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
164
165 let window_size = match &args.get(1) {
167 Some(SqlExpression::NumberLiteral(n)) => n
168 .parse::<i64>()
169 .map_err(|_| anyhow::anyhow!("Invalid window size"))?,
170 _ => return Err(anyhow::anyhow!("MOVING_AVG requires numeric window_size")),
171 };
172
173 let mut spec = base_spec.clone();
175 spec.frame = Some(WindowFrame {
176 unit: FrameUnit::Rows,
177 start: FrameBound::Preceding(window_size - 1),
178 end: None, });
180
181 Ok(spec)
182 }
183
184 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
185 if args.len() != 2 {
186 return Err(anyhow::anyhow!("MOVING_AVG requires exactly 2 arguments"));
187 }
188 Ok(())
189 }
190}
191
192struct RollingStddevFunction;
195
196impl WindowFunction for RollingStddevFunction {
197 fn name(&self) -> &str {
198 "ROLLING_STDDEV"
199 }
200
201 fn description(&self) -> &str {
202 "Calculate rolling standard deviation over specified window"
203 }
204
205 fn signature(&self) -> &str {
206 "ROLLING_STDDEV(column, window_size)"
207 }
208
209 fn compute(
210 &self,
211 context: &WindowContext,
212 row_index: usize,
213 args: &[SqlExpression],
214 evaluator: &mut dyn ExpressionEvaluator,
215 ) -> Result<DataValue> {
216 let column = match &args[0] {
217 SqlExpression::Column(col) => col,
218 _ => {
219 return Err(anyhow::anyhow!(
220 "ROLLING_STDDEV first argument must be a column"
221 ))
222 }
223 };
224
225 context
226 .get_frame_stddev(row_index, &column.name)
227 .ok_or_else(|| anyhow::anyhow!("Failed to compute rolling stddev"))
228 }
229
230 fn transform_window_spec(
231 &self,
232 base_spec: &WindowSpec,
233 args: &[SqlExpression],
234 ) -> Result<WindowSpec> {
235 use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
236
237 let window_size = match &args.get(1) {
238 Some(SqlExpression::NumberLiteral(n)) => n
239 .parse::<i64>()
240 .map_err(|_| anyhow::anyhow!("Invalid window size"))?,
241 _ => {
242 return Err(anyhow::anyhow!(
243 "ROLLING_STDDEV requires numeric window_size"
244 ))
245 }
246 };
247
248 let mut spec = base_spec.clone();
249 spec.frame = Some(WindowFrame {
250 unit: FrameUnit::Rows,
251 start: FrameBound::Preceding(window_size - 1),
252 end: None,
253 });
254
255 Ok(spec)
256 }
257
258 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
259 if args.len() != 2 {
260 return Err(anyhow::anyhow!(
261 "ROLLING_STDDEV requires exactly 2 arguments"
262 ));
263 }
264 Ok(())
265 }
266}
267
268struct CumulativeSumFunction;
271
272impl WindowFunction for CumulativeSumFunction {
273 fn name(&self) -> &str {
274 "CUMULATIVE_SUM"
275 }
276
277 fn description(&self) -> &str {
278 "Calculate cumulative sum from beginning to current row"
279 }
280
281 fn signature(&self) -> &str {
282 "CUMULATIVE_SUM(column)"
283 }
284
285 fn compute(
286 &self,
287 context: &WindowContext,
288 row_index: usize,
289 args: &[SqlExpression],
290 evaluator: &mut dyn ExpressionEvaluator,
291 ) -> Result<DataValue> {
292 let column = match &args[0] {
293 SqlExpression::Column(col) => col,
294 _ => return Err(anyhow::anyhow!("CUMULATIVE_SUM argument must be a column")),
295 };
296
297 context
298 .get_frame_sum(row_index, &column.name)
299 .ok_or_else(|| anyhow::anyhow!("Failed to compute cumulative sum"))
300 }
301
302 fn transform_window_spec(
303 &self,
304 base_spec: &WindowSpec,
305 args: &[SqlExpression],
306 ) -> Result<WindowSpec> {
307 use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
308
309 let mut spec = base_spec.clone();
310 spec.frame = Some(WindowFrame {
311 unit: FrameUnit::Rows,
312 start: FrameBound::UnboundedPreceding,
313 end: None, });
315
316 Ok(spec)
317 }
318
319 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
320 if args.len() != 1 {
321 return Err(anyhow::anyhow!(
322 "CUMULATIVE_SUM requires exactly 1 argument"
323 ));
324 }
325 Ok(())
326 }
327}
328
329struct CumulativeAvgFunction;
332
333impl WindowFunction for CumulativeAvgFunction {
334 fn name(&self) -> &str {
335 "CUMULATIVE_AVG"
336 }
337
338 fn description(&self) -> &str {
339 "Calculate cumulative average from beginning to current row"
340 }
341
342 fn signature(&self) -> &str {
343 "CUMULATIVE_AVG(column)"
344 }
345
346 fn compute(
347 &self,
348 context: &WindowContext,
349 row_index: usize,
350 args: &[SqlExpression],
351 evaluator: &mut dyn ExpressionEvaluator,
352 ) -> Result<DataValue> {
353 let column = match &args[0] {
354 SqlExpression::Column(col) => col,
355 _ => return Err(anyhow::anyhow!("CUMULATIVE_AVG argument must be a column")),
356 };
357
358 context
359 .get_frame_avg(row_index, &column.name)
360 .ok_or_else(|| anyhow::anyhow!("Failed to compute cumulative average"))
361 }
362
363 fn transform_window_spec(
364 &self,
365 base_spec: &WindowSpec,
366 args: &[SqlExpression],
367 ) -> Result<WindowSpec> {
368 use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
369
370 let mut spec = base_spec.clone();
371 spec.frame = Some(WindowFrame {
372 unit: FrameUnit::Rows,
373 start: FrameBound::UnboundedPreceding,
374 end: None,
375 });
376
377 Ok(spec)
378 }
379
380 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
381 if args.len() != 1 {
382 return Err(anyhow::anyhow!(
383 "CUMULATIVE_AVG requires exactly 1 argument"
384 ));
385 }
386 Ok(())
387 }
388}
389
390struct ZScoreFunction;
393
394impl WindowFunction for ZScoreFunction {
395 fn name(&self) -> &str {
396 "Z_SCORE"
397 }
398
399 fn description(&self) -> &str {
400 "Calculate Z-score (standard deviations from mean) over window"
401 }
402
403 fn signature(&self) -> &str {
404 "Z_SCORE(column, window_size)"
405 }
406
407 fn compute(
408 &self,
409 context: &WindowContext,
410 row_index: usize,
411 args: &[SqlExpression],
412 evaluator: &mut dyn ExpressionEvaluator,
413 ) -> Result<DataValue> {
414 let column = match &args[0] {
415 SqlExpression::Column(col) => col,
416 _ => return Err(anyhow::anyhow!("Z_SCORE first argument must be a column")),
417 };
418
419 let current_value = {
421 let source = context.source();
422 let col_idx = source
423 .get_column_index(&column.name)
424 .ok_or_else(|| anyhow::anyhow!("Column {} not found", column))?;
425 source
426 .get_value(row_index, col_idx)
427 .cloned()
428 .unwrap_or(DataValue::Null)
429 };
430
431 let mean = context
433 .get_frame_avg(row_index, &column.name)
434 .unwrap_or(DataValue::Null);
435 let stddev = context
436 .get_frame_stddev(row_index, &column.name)
437 .unwrap_or(DataValue::Null);
438
439 match (current_value, mean, stddev) {
441 (DataValue::Integer(v), DataValue::Float(m), DataValue::Float(s)) if s > 0.0 => {
442 Ok(DataValue::Float((v as f64 - m) / s))
443 }
444 (DataValue::Float(v), DataValue::Float(m), DataValue::Float(s)) if s > 0.0 => {
445 Ok(DataValue::Float((v - m) / s))
446 }
447 _ => Ok(DataValue::Null),
448 }
449 }
450
451 fn transform_window_spec(
452 &self,
453 base_spec: &WindowSpec,
454 args: &[SqlExpression],
455 ) -> Result<WindowSpec> {
456 use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
457
458 let window_size = match &args.get(1) {
459 Some(SqlExpression::NumberLiteral(n)) => n
460 .parse::<i64>()
461 .map_err(|_| anyhow::anyhow!("Invalid window size"))?,
462 _ => return Err(anyhow::anyhow!("Z_SCORE requires numeric window_size")),
463 };
464
465 let mut spec = base_spec.clone();
466 spec.frame = Some(WindowFrame {
467 unit: FrameUnit::Rows,
468 start: FrameBound::Preceding(window_size - 1),
469 end: None,
470 });
471
472 Ok(spec)
473 }
474
475 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
476 if args.len() != 2 {
477 return Err(anyhow::anyhow!("Z_SCORE requires exactly 2 arguments"));
478 }
479 Ok(())
480 }
481}
482
483struct BollingerUpperFunction;
486
487impl WindowFunction for BollingerUpperFunction {
488 fn name(&self) -> &str {
489 "BOLLINGER_UPPER"
490 }
491
492 fn description(&self) -> &str {
493 "Calculate upper Bollinger Band (MA + n*STDDEV)"
494 }
495
496 fn signature(&self) -> &str {
497 "BOLLINGER_UPPER(column, window_size, num_std)"
498 }
499
500 fn compute(
501 &self,
502 context: &WindowContext,
503 row_index: usize,
504 args: &[SqlExpression],
505 _evaluator: &mut dyn ExpressionEvaluator,
506 ) -> Result<DataValue> {
507 let column = match &args[0] {
508 SqlExpression::Column(col) => col,
509 _ => return Err(anyhow!("BOLLINGER_UPPER first argument must be a column")),
510 };
511
512 let num_std = match args.get(2) {
514 Some(SqlExpression::NumberLiteral(n)) => n
515 .parse::<f64>()
516 .map_err(|_| anyhow!("Invalid num_std value"))?,
517 _ => 2.0, };
519
520 let mean = context
522 .get_frame_avg(row_index, &column.name)
523 .unwrap_or(DataValue::Null);
524 let stddev = context
525 .get_frame_stddev(row_index, &column.name)
526 .unwrap_or(DataValue::Null);
527
528 match (mean, stddev) {
530 (DataValue::Float(m), DataValue::Float(s)) => Ok(DataValue::Float(m + (num_std * s))),
531 _ => Ok(DataValue::Null),
532 }
533 }
534
535 fn transform_window_spec(
536 &self,
537 base_spec: &WindowSpec,
538 args: &[SqlExpression],
539 ) -> Result<WindowSpec> {
540 use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
541
542 let window_size = match args.get(1) {
543 Some(SqlExpression::NumberLiteral(n)) => n
544 .parse::<i64>()
545 .map_err(|_| anyhow!("Invalid window size"))?,
546 _ => return Err(anyhow!("BOLLINGER_UPPER requires numeric window_size")),
547 };
548
549 let mut spec = base_spec.clone();
550 spec.frame = Some(WindowFrame {
551 unit: FrameUnit::Rows,
552 start: FrameBound::Preceding(window_size - 1),
553 end: None,
554 });
555
556 Ok(spec)
557 }
558
559 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
560 if args.len() < 2 || args.len() > 3 {
561 return Err(anyhow!("BOLLINGER_UPPER requires 2 or 3 arguments"));
562 }
563 Ok(())
564 }
565}
566
567struct BollingerLowerFunction;
570
571impl WindowFunction for BollingerLowerFunction {
572 fn name(&self) -> &str {
573 "BOLLINGER_LOWER"
574 }
575
576 fn description(&self) -> &str {
577 "Calculate lower Bollinger Band (MA - n*STDDEV)"
578 }
579
580 fn signature(&self) -> &str {
581 "BOLLINGER_LOWER(column, window_size, num_std)"
582 }
583
584 fn compute(
585 &self,
586 context: &WindowContext,
587 row_index: usize,
588 args: &[SqlExpression],
589 _evaluator: &mut dyn ExpressionEvaluator,
590 ) -> Result<DataValue> {
591 let column = match &args[0] {
592 SqlExpression::Column(col) => col,
593 _ => return Err(anyhow!("BOLLINGER_LOWER first argument must be a column")),
594 };
595
596 let num_std = match args.get(2) {
598 Some(SqlExpression::NumberLiteral(n)) => n
599 .parse::<f64>()
600 .map_err(|_| anyhow!("Invalid num_std value"))?,
601 _ => 2.0, };
603
604 let mean = context
606 .get_frame_avg(row_index, &column.name)
607 .unwrap_or(DataValue::Null);
608 let stddev = context
609 .get_frame_stddev(row_index, &column.name)
610 .unwrap_or(DataValue::Null);
611
612 match (mean, stddev) {
614 (DataValue::Float(m), DataValue::Float(s)) => Ok(DataValue::Float(m - (num_std * s))),
615 _ => Ok(DataValue::Null),
616 }
617 }
618
619 fn transform_window_spec(
620 &self,
621 base_spec: &WindowSpec,
622 args: &[SqlExpression],
623 ) -> Result<WindowSpec> {
624 use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
625
626 let window_size = match args.get(1) {
627 Some(SqlExpression::NumberLiteral(n)) => n
628 .parse::<i64>()
629 .map_err(|_| anyhow!("Invalid window size"))?,
630 _ => return Err(anyhow!("BOLLINGER_LOWER requires numeric window_size")),
631 };
632
633 let mut spec = base_spec.clone();
634 spec.frame = Some(WindowFrame {
635 unit: FrameUnit::Rows,
636 start: FrameBound::Preceding(window_size - 1),
637 end: None,
638 });
639
640 Ok(spec)
641 }
642
643 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
644 if args.len() < 2 || args.len() > 3 {
645 return Err(anyhow!("BOLLINGER_LOWER requires 2 or 3 arguments"));
646 }
647 Ok(())
648 }
649}
650
651struct PercentChangeFunction;
655
656impl WindowFunction for PercentChangeFunction {
657 fn name(&self) -> &str {
658 "PERCENT_CHANGE"
659 }
660
661 fn description(&self) -> &str {
662 "Calculate percentage change from N periods ago"
663 }
664
665 fn signature(&self) -> &str {
666 "PERCENT_CHANGE(column, periods)"
667 }
668
669 fn compute(
670 &self,
671 context: &WindowContext,
672 row_index: usize,
673 args: &[SqlExpression],
674 _evaluator: &mut dyn ExpressionEvaluator,
675 ) -> Result<DataValue> {
676 let column = match &args[0] {
677 SqlExpression::Column(col) => col,
678 _ => return Err(anyhow!("PERCENT_CHANGE first argument must be a column")),
679 };
680
681 let periods = match args.get(1) {
683 Some(SqlExpression::NumberLiteral(n)) => n
684 .parse::<i32>()
685 .map_err(|_| anyhow!("Invalid periods value"))?,
686 _ => 1, };
688
689 let current_value = {
691 let source = context.source();
692 let col_idx = source
693 .get_column_index(&column.name)
694 .ok_or_else(|| anyhow!("Column {} not found", column))?;
695 source.get_value(row_index, col_idx).cloned()
696 };
697
698 let previous_value = context.get_offset_value(row_index, -periods, &column.name);
700
701 match (current_value, previous_value) {
703 (Some(DataValue::Float(curr)), Some(DataValue::Float(prev))) if prev != 0.0 => {
704 Ok(DataValue::Float(((curr - prev) / prev) * 100.0))
705 }
706 (Some(DataValue::Integer(curr)), Some(DataValue::Integer(prev))) if prev != 0 => {
707 let curr_f = curr as f64;
708 let prev_f = prev as f64;
709 Ok(DataValue::Float(((curr_f - prev_f) / prev_f) * 100.0))
710 }
711 (Some(DataValue::Float(curr)), Some(DataValue::Integer(prev))) if prev != 0 => {
712 let prev_f = prev as f64;
713 Ok(DataValue::Float(((curr - prev_f) / prev_f) * 100.0))
714 }
715 (Some(DataValue::Integer(curr)), Some(DataValue::Float(prev))) if prev != 0.0 => {
716 let curr_f = curr as f64;
717 Ok(DataValue::Float(((curr_f - prev) / prev) * 100.0))
718 }
719 _ => Ok(DataValue::Null), }
721 }
722
723 fn transform_window_spec(
724 &self,
725 base_spec: &WindowSpec,
726 _args: &[SqlExpression],
727 ) -> Result<WindowSpec> {
728 Ok(base_spec.clone())
731 }
732
733 fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
734 if args.is_empty() || args.len() > 2 {
735 return Err(anyhow!("PERCENT_CHANGE requires 1 or 2 arguments"));
736 }
737 Ok(())
738 }
739}
740
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh registry exposes the core built-in functions.
    #[test]
    fn test_registry_creation() {
        let registry = WindowFunctionRegistry::new();
        for builtin in ["MOVING_AVG", "ROLLING_STDDEV", "CUMULATIVE_SUM"] {
            assert!(registry.contains(builtin));
        }
    }

    /// `MOVING_AVG(close, 20)` installs a frame starting 19 rows back.
    #[test]
    fn test_window_spec_transformation() {
        use crate::sql::parser::ast::{FrameBound, WindowSpec};

        let call_args = [
            SqlExpression::Column(ColumnRef::unquoted("close".to_string())),
            SqlExpression::NumberLiteral("20".to_string()),
        ];
        let empty_spec = WindowSpec {
            partition_by: Vec::new(),
            order_by: Vec::new(),
            frame: None,
        };

        let result = MovingAvgFunction
            .transform_window_spec(&empty_spec, &call_args)
            .unwrap();

        let installed = result.frame;
        assert!(installed.is_some());
        assert_eq!(installed.unwrap().start, FrameBound::Preceding(19));
    }
}