sql_cli/sql/window_functions/mod.rs

1// Window Function Registry
2// Provides a clean API for window computations with syntactic sugar
3
4use anyhow::{anyhow, Result};
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::data::datatable::DataValue;
9use crate::sql::parser::ast::{SqlExpression, WindowSpec};
10use crate::sql::window_context::WindowContext;
11
12// Aggregate window functions module
13mod aggregates;
14use aggregates::*;
15
16/// Window function computation trait
17/// Each window function receives:
18/// - The window context (partitions, ordering, frames)
19/// - The current row index
20/// - Arguments (column names, parameters)
21pub trait WindowFunction: Send + Sync {
22    /// Function name (e.g., "MOVING_AVG")
23    fn name(&self) -> &str;
24
25    /// Description for help system
26    fn description(&self) -> &str;
27
28    /// Signature for documentation (e.g., "MOVING_AVG(column, window_size)")
29    fn signature(&self) -> &str;
30
31    /// Compute the function value for a specific row
32    /// This is called once per row in the result set
33    fn compute(
34        &self,
35        context: &WindowContext,
36        row_index: usize,
37        args: &[SqlExpression],
38        _evaluator: &mut dyn ExpressionEvaluator,
39    ) -> Result<DataValue>;
40
41    /// Optional: Transform/expand the window specification
42    /// This allows functions to modify the window (e.g., MOVING_AVG sets ROWS n PRECEDING)
43    fn transform_window_spec(
44        &self,
45        base_spec: &WindowSpec,
46        _args: &[SqlExpression],
47    ) -> Result<WindowSpec> {
48        // Default: use the base spec unchanged
49        Ok(base_spec.clone())
50    }
51
52    /// Validate arguments at parse time
53    fn validate_args(&self, _args: &[SqlExpression]) -> Result<()> {
54        Ok(())
55    }
56}
57
58/// Expression evaluator trait for evaluating arguments
59/// This allows window functions to evaluate expressions without depending on ArithmeticEvaluator
60pub trait ExpressionEvaluator {
61    fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue>;
62}
63
64/// Registry for window functions
65pub struct WindowFunctionRegistry {
66    functions: HashMap<String, Arc<Box<dyn WindowFunction>>>,
67}
68
69impl WindowFunctionRegistry {
70    pub fn new() -> Self {
71        let mut registry = Self {
72            functions: HashMap::new(),
73        };
74        registry.register_builtin_functions();
75        registry
76    }
77
78    /// Register a window function
79    pub fn register(&mut self, function: Box<dyn WindowFunction>) {
80        let name = function.name().to_uppercase();
81        self.functions.insert(name, Arc::new(function));
82    }
83
84    /// Get a window function by name
85    pub fn get(&self, name: &str) -> Option<Arc<Box<dyn WindowFunction>>> {
86        self.functions.get(&name.to_uppercase()).cloned()
87    }
88
89    /// Check if a function exists
90    pub fn contains(&self, name: &str) -> bool {
91        self.functions.contains_key(&name.to_uppercase())
92    }
93
94    /// List all registered functions
95    pub fn list_functions(&self) -> Vec<String> {
96        self.functions.keys().cloned().collect()
97    }
98
99    /// Register built-in syntactic sugar functions
100    fn register_builtin_functions(&mut self) {
101        // Window aggregate functions that can handle expressions
102        self.register(Box::new(WindowSumFunction));
103        self.register(Box::new(WindowAvgFunction));
104        self.register(Box::new(WindowMinFunction));
105        self.register(Box::new(WindowMaxFunction));
106        self.register(Box::new(WindowCountFunction));
107        self.register(Box::new(WindowStddevFunction));
108        self.register(Box::new(WindowStdevFunction)); // Alias for STDDEV
109        self.register(Box::new(WindowVarianceFunction));
110        self.register(Box::new(WindowVarFunction)); // Alias for VARIANCE
111
112        // Moving average and statistics
113        self.register(Box::new(MovingAvgFunction));
114        self.register(Box::new(RollingStddevFunction));
115        self.register(Box::new(CumulativeSumFunction));
116        self.register(Box::new(CumulativeAvgFunction));
117        self.register(Box::new(ZScoreFunction));
118
119        // Bollinger Bands
120        self.register(Box::new(BollingerUpperFunction));
121        self.register(Box::new(BollingerLowerFunction));
122
123        // Financial calculations
124        self.register(Box::new(PercentChangeFunction));
125
126        // Add more as we implement them
127    }
128}
129
130// ============= Syntactic Sugar Implementations =============
131
132/// MOVING_AVG(column, window_size)
133/// Expands to: AVG(column) OVER (ORDER BY <inherited> ROWS window_size-1 PRECEDING)
134struct MovingAvgFunction;
135
136impl WindowFunction for MovingAvgFunction {
137    fn name(&self) -> &str {
138        "MOVING_AVG"
139    }
140
141    fn description(&self) -> &str {
142        "Calculate moving average over specified window size"
143    }
144
145    fn signature(&self) -> &str {
146        "MOVING_AVG(column, window_size)"
147    }
148
149    fn compute(
150        &self,
151        context: &WindowContext,
152        row_index: usize,
153        args: &[SqlExpression],
154        _evaluator: &mut dyn ExpressionEvaluator,
155    ) -> Result<DataValue> {
156        // Extract column name
157        let column = match &args[0] {
158            SqlExpression::Column(col) => col,
159            _ => {
160                return Err(anyhow::anyhow!(
161                    "MOVING_AVG first argument must be a column"
162                ))
163            }
164        };
165
166        // The window has already been configured by transform_window_spec
167        // Just compute the average over the frame
168        context
169            .get_frame_avg(row_index, &column.name)
170            .ok_or_else(|| anyhow::anyhow!("Failed to compute moving average"))
171    }
172
173    fn transform_window_spec(
174        &self,
175        base_spec: &WindowSpec,
176        args: &[SqlExpression],
177    ) -> Result<WindowSpec> {
178        use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
179
180        // Extract window size from second argument
181        let window_size = match &args.get(1) {
182            Some(SqlExpression::NumberLiteral(n)) => n
183                .parse::<i64>()
184                .map_err(|_| anyhow::anyhow!("Invalid window size"))?,
185            _ => return Err(anyhow::anyhow!("MOVING_AVG requires numeric window_size")),
186        };
187
188        // Create a new spec with ROWS n-1 PRECEDING frame
189        let mut spec = base_spec.clone();
190        spec.frame = Some(WindowFrame {
191            unit: FrameUnit::Rows,
192            start: FrameBound::Preceding(window_size - 1),
193            end: None, // Defaults to CURRENT ROW
194        });
195
196        Ok(spec)
197    }
198
199    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
200        if args.len() != 2 {
201            return Err(anyhow::anyhow!("MOVING_AVG requires exactly 2 arguments"));
202        }
203        Ok(())
204    }
205}
206
207/// ROLLING_STDDEV(column, window_size)
208/// Expands to: STDDEV(column) OVER (ORDER BY <inherited> ROWS window_size-1 PRECEDING)
209struct RollingStddevFunction;
210
211impl WindowFunction for RollingStddevFunction {
212    fn name(&self) -> &str {
213        "ROLLING_STDDEV"
214    }
215
216    fn description(&self) -> &str {
217        "Calculate rolling standard deviation over specified window"
218    }
219
220    fn signature(&self) -> &str {
221        "ROLLING_STDDEV(column, window_size)"
222    }
223
224    fn compute(
225        &self,
226        context: &WindowContext,
227        row_index: usize,
228        args: &[SqlExpression],
229        _evaluator: &mut dyn ExpressionEvaluator,
230    ) -> Result<DataValue> {
231        let column = match &args[0] {
232            SqlExpression::Column(col) => col,
233            _ => {
234                return Err(anyhow::anyhow!(
235                    "ROLLING_STDDEV first argument must be a column"
236                ))
237            }
238        };
239
240        context
241            .get_frame_stddev(row_index, &column.name)
242            .ok_or_else(|| anyhow::anyhow!("Failed to compute rolling stddev"))
243    }
244
245    fn transform_window_spec(
246        &self,
247        base_spec: &WindowSpec,
248        args: &[SqlExpression],
249    ) -> Result<WindowSpec> {
250        use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
251
252        let window_size = match &args.get(1) {
253            Some(SqlExpression::NumberLiteral(n)) => n
254                .parse::<i64>()
255                .map_err(|_| anyhow::anyhow!("Invalid window size"))?,
256            _ => {
257                return Err(anyhow::anyhow!(
258                    "ROLLING_STDDEV requires numeric window_size"
259                ))
260            }
261        };
262
263        let mut spec = base_spec.clone();
264        spec.frame = Some(WindowFrame {
265            unit: FrameUnit::Rows,
266            start: FrameBound::Preceding(window_size - 1),
267            end: None,
268        });
269
270        Ok(spec)
271    }
272
273    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
274        if args.len() != 2 {
275            return Err(anyhow::anyhow!(
276                "ROLLING_STDDEV requires exactly 2 arguments"
277            ));
278        }
279        Ok(())
280    }
281}
282
283/// CUMULATIVE_SUM(column)
284/// Expands to: SUM(column) OVER (ORDER BY <inherited> ROWS UNBOUNDED PRECEDING)
285struct CumulativeSumFunction;
286
287impl WindowFunction for CumulativeSumFunction {
288    fn name(&self) -> &str {
289        "CUMULATIVE_SUM"
290    }
291
292    fn description(&self) -> &str {
293        "Calculate cumulative sum from beginning to current row"
294    }
295
296    fn signature(&self) -> &str {
297        "CUMULATIVE_SUM(column)"
298    }
299
300    fn compute(
301        &self,
302        context: &WindowContext,
303        row_index: usize,
304        args: &[SqlExpression],
305        _evaluator: &mut dyn ExpressionEvaluator,
306    ) -> Result<DataValue> {
307        let column = match &args[0] {
308            SqlExpression::Column(col) => col,
309            _ => return Err(anyhow::anyhow!("CUMULATIVE_SUM argument must be a column")),
310        };
311
312        context
313            .get_frame_sum(row_index, &column.name)
314            .ok_or_else(|| anyhow::anyhow!("Failed to compute cumulative sum"))
315    }
316
317    fn transform_window_spec(
318        &self,
319        base_spec: &WindowSpec,
320        _args: &[SqlExpression],
321    ) -> Result<WindowSpec> {
322        use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
323
324        let mut spec = base_spec.clone();
325        spec.frame = Some(WindowFrame {
326            unit: FrameUnit::Rows,
327            start: FrameBound::UnboundedPreceding,
328            end: None, // CURRENT ROW
329        });
330
331        Ok(spec)
332    }
333
334    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
335        if args.len() != 1 {
336            return Err(anyhow::anyhow!(
337                "CUMULATIVE_SUM requires exactly 1 argument"
338            ));
339        }
340        Ok(())
341    }
342}
343
344/// CUMULATIVE_AVG(column)
345/// Expands to: AVG(column) OVER (ORDER BY <inherited> ROWS UNBOUNDED PRECEDING)
346struct CumulativeAvgFunction;
347
348impl WindowFunction for CumulativeAvgFunction {
349    fn name(&self) -> &str {
350        "CUMULATIVE_AVG"
351    }
352
353    fn description(&self) -> &str {
354        "Calculate cumulative average from beginning to current row"
355    }
356
357    fn signature(&self) -> &str {
358        "CUMULATIVE_AVG(column)"
359    }
360
361    fn compute(
362        &self,
363        context: &WindowContext,
364        row_index: usize,
365        args: &[SqlExpression],
366        _evaluator: &mut dyn ExpressionEvaluator,
367    ) -> Result<DataValue> {
368        let column = match &args[0] {
369            SqlExpression::Column(col) => col,
370            _ => return Err(anyhow::anyhow!("CUMULATIVE_AVG argument must be a column")),
371        };
372
373        context
374            .get_frame_avg(row_index, &column.name)
375            .ok_or_else(|| anyhow::anyhow!("Failed to compute cumulative average"))
376    }
377
378    fn transform_window_spec(
379        &self,
380        base_spec: &WindowSpec,
381        _args: &[SqlExpression],
382    ) -> Result<WindowSpec> {
383        use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
384
385        let mut spec = base_spec.clone();
386        spec.frame = Some(WindowFrame {
387            unit: FrameUnit::Rows,
388            start: FrameBound::UnboundedPreceding,
389            end: None,
390        });
391
392        Ok(spec)
393    }
394
395    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
396        if args.len() != 1 {
397            return Err(anyhow::anyhow!(
398                "CUMULATIVE_AVG requires exactly 1 argument"
399            ));
400        }
401        Ok(())
402    }
403}
404
405/// Z_SCORE(column, window_size)
406/// Calculates: (value - mean) / stddev over the window
407struct ZScoreFunction;
408
409impl WindowFunction for ZScoreFunction {
410    fn name(&self) -> &str {
411        "Z_SCORE"
412    }
413
414    fn description(&self) -> &str {
415        "Calculate Z-score (standard deviations from mean) over window"
416    }
417
418    fn signature(&self) -> &str {
419        "Z_SCORE(column, window_size)"
420    }
421
422    fn compute(
423        &self,
424        context: &WindowContext,
425        row_index: usize,
426        args: &[SqlExpression],
427        _evaluator: &mut dyn ExpressionEvaluator,
428    ) -> Result<DataValue> {
429        let column = match &args[0] {
430            SqlExpression::Column(col) => col,
431            _ => return Err(anyhow::anyhow!("Z_SCORE first argument must be a column")),
432        };
433
434        // Get current value
435        let current_value = {
436            let source = context.source();
437            let col_idx = source
438                .get_column_index(&column.name)
439                .ok_or_else(|| anyhow::anyhow!("Column {} not found", column))?;
440            source
441                .get_value(row_index, col_idx)
442                .cloned()
443                .unwrap_or(DataValue::Null)
444        };
445
446        // Get mean and stddev over the window
447        let mean = context
448            .get_frame_avg(row_index, &column.name)
449            .unwrap_or(DataValue::Null);
450        let stddev = context
451            .get_frame_stddev(row_index, &column.name)
452            .unwrap_or(DataValue::Null);
453
454        // Calculate Z-score
455        match (current_value, mean, stddev) {
456            (DataValue::Integer(v), DataValue::Float(m), DataValue::Float(s)) if s > 0.0 => {
457                Ok(DataValue::Float((v as f64 - m) / s))
458            }
459            (DataValue::Float(v), DataValue::Float(m), DataValue::Float(s)) if s > 0.0 => {
460                Ok(DataValue::Float((v - m) / s))
461            }
462            _ => Ok(DataValue::Null),
463        }
464    }
465
466    fn transform_window_spec(
467        &self,
468        base_spec: &WindowSpec,
469        args: &[SqlExpression],
470    ) -> Result<WindowSpec> {
471        use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
472
473        let window_size = match &args.get(1) {
474            Some(SqlExpression::NumberLiteral(n)) => n
475                .parse::<i64>()
476                .map_err(|_| anyhow::anyhow!("Invalid window size"))?,
477            _ => return Err(anyhow::anyhow!("Z_SCORE requires numeric window_size")),
478        };
479
480        let mut spec = base_spec.clone();
481        spec.frame = Some(WindowFrame {
482            unit: FrameUnit::Rows,
483            start: FrameBound::Preceding(window_size - 1),
484            end: None,
485        });
486
487        Ok(spec)
488    }
489
490    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
491        if args.len() != 2 {
492            return Err(anyhow::anyhow!("Z_SCORE requires exactly 2 arguments"));
493        }
494        Ok(())
495    }
496}
497
498/// BOLLINGER_UPPER(column, window_size, num_std)
499/// Calculates upper Bollinger Band: MA + (num_std * STDDEV)
500struct BollingerUpperFunction;
501
502impl WindowFunction for BollingerUpperFunction {
503    fn name(&self) -> &str {
504        "BOLLINGER_UPPER"
505    }
506
507    fn description(&self) -> &str {
508        "Calculate upper Bollinger Band (MA + n*STDDEV)"
509    }
510
511    fn signature(&self) -> &str {
512        "BOLLINGER_UPPER(column, window_size, num_std)"
513    }
514
515    fn compute(
516        &self,
517        context: &WindowContext,
518        row_index: usize,
519        args: &[SqlExpression],
520        _evaluator: &mut dyn ExpressionEvaluator,
521    ) -> Result<DataValue> {
522        let column = match &args[0] {
523            SqlExpression::Column(col) => col,
524            _ => return Err(anyhow!("BOLLINGER_UPPER first argument must be a column")),
525        };
526
527        // Get num_std from third argument (default 2)
528        let num_std = match args.get(2) {
529            Some(SqlExpression::NumberLiteral(n)) => n
530                .parse::<f64>()
531                .map_err(|_| anyhow!("Invalid num_std value"))?,
532            _ => 2.0, // Default to 2 standard deviations
533        };
534
535        // Get mean and stddev over the window
536        let mean = context
537            .get_frame_avg(row_index, &column.name)
538            .unwrap_or(DataValue::Null);
539        let stddev = context
540            .get_frame_stddev(row_index, &column.name)
541            .unwrap_or(DataValue::Null);
542
543        // Calculate upper band: mean + (num_std * stddev)
544        match (mean, stddev) {
545            (DataValue::Float(m), DataValue::Float(s)) => Ok(DataValue::Float(m + (num_std * s))),
546            _ => Ok(DataValue::Null),
547        }
548    }
549
550    fn transform_window_spec(
551        &self,
552        base_spec: &WindowSpec,
553        args: &[SqlExpression],
554    ) -> Result<WindowSpec> {
555        use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
556
557        let window_size = match args.get(1) {
558            Some(SqlExpression::NumberLiteral(n)) => n
559                .parse::<i64>()
560                .map_err(|_| anyhow!("Invalid window size"))?,
561            _ => return Err(anyhow!("BOLLINGER_UPPER requires numeric window_size")),
562        };
563
564        let mut spec = base_spec.clone();
565        spec.frame = Some(WindowFrame {
566            unit: FrameUnit::Rows,
567            start: FrameBound::Preceding(window_size - 1),
568            end: None,
569        });
570
571        Ok(spec)
572    }
573
574    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
575        if args.len() < 2 || args.len() > 3 {
576            return Err(anyhow!("BOLLINGER_UPPER requires 2 or 3 arguments"));
577        }
578        Ok(())
579    }
580}
581
582/// BOLLINGER_LOWER(column, window_size, num_std)
583/// Calculates lower Bollinger Band: MA - (num_std * STDDEV)
584struct BollingerLowerFunction;
585
586impl WindowFunction for BollingerLowerFunction {
587    fn name(&self) -> &str {
588        "BOLLINGER_LOWER"
589    }
590
591    fn description(&self) -> &str {
592        "Calculate lower Bollinger Band (MA - n*STDDEV)"
593    }
594
595    fn signature(&self) -> &str {
596        "BOLLINGER_LOWER(column, window_size, num_std)"
597    }
598
599    fn compute(
600        &self,
601        context: &WindowContext,
602        row_index: usize,
603        args: &[SqlExpression],
604        _evaluator: &mut dyn ExpressionEvaluator,
605    ) -> Result<DataValue> {
606        let column = match &args[0] {
607            SqlExpression::Column(col) => col,
608            _ => return Err(anyhow!("BOLLINGER_LOWER first argument must be a column")),
609        };
610
611        // Get num_std from third argument (default 2)
612        let num_std = match args.get(2) {
613            Some(SqlExpression::NumberLiteral(n)) => n
614                .parse::<f64>()
615                .map_err(|_| anyhow!("Invalid num_std value"))?,
616            _ => 2.0, // Default to 2 standard deviations
617        };
618
619        // Get mean and stddev over the window
620        let mean = context
621            .get_frame_avg(row_index, &column.name)
622            .unwrap_or(DataValue::Null);
623        let stddev = context
624            .get_frame_stddev(row_index, &column.name)
625            .unwrap_or(DataValue::Null);
626
627        // Calculate lower band: mean - (num_std * stddev)
628        match (mean, stddev) {
629            (DataValue::Float(m), DataValue::Float(s)) => Ok(DataValue::Float(m - (num_std * s))),
630            _ => Ok(DataValue::Null),
631        }
632    }
633
634    fn transform_window_spec(
635        &self,
636        base_spec: &WindowSpec,
637        args: &[SqlExpression],
638    ) -> Result<WindowSpec> {
639        use crate::sql::parser::ast::{FrameBound, FrameUnit, WindowFrame};
640
641        let window_size = match args.get(1) {
642            Some(SqlExpression::NumberLiteral(n)) => n
643                .parse::<i64>()
644                .map_err(|_| anyhow!("Invalid window size"))?,
645            _ => return Err(anyhow!("BOLLINGER_LOWER requires numeric window_size")),
646        };
647
648        let mut spec = base_spec.clone();
649        spec.frame = Some(WindowFrame {
650            unit: FrameUnit::Rows,
651            start: FrameBound::Preceding(window_size - 1),
652            end: None,
653        });
654
655        Ok(spec)
656    }
657
658    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
659        if args.len() < 2 || args.len() > 3 {
660            return Err(anyhow!("BOLLINGER_LOWER requires 2 or 3 arguments"));
661        }
662        Ok(())
663    }
664}
665
666/// PERCENT_CHANGE(column, periods)
667/// Calculates percentage change from N periods ago
668/// Formula: ((current - previous) / previous) * 100
669struct PercentChangeFunction;
670
671impl WindowFunction for PercentChangeFunction {
672    fn name(&self) -> &str {
673        "PERCENT_CHANGE"
674    }
675
676    fn description(&self) -> &str {
677        "Calculate percentage change from N periods ago"
678    }
679
680    fn signature(&self) -> &str {
681        "PERCENT_CHANGE(column, periods)"
682    }
683
684    fn compute(
685        &self,
686        context: &WindowContext,
687        row_index: usize,
688        args: &[SqlExpression],
689        _evaluator: &mut dyn ExpressionEvaluator,
690    ) -> Result<DataValue> {
691        let column = match &args[0] {
692            SqlExpression::Column(col) => col,
693            _ => return Err(anyhow!("PERCENT_CHANGE first argument must be a column")),
694        };
695
696        // Get periods from second argument (default 1)
697        let periods = match args.get(1) {
698            Some(SqlExpression::NumberLiteral(n)) => n
699                .parse::<i32>()
700                .map_err(|_| anyhow!("Invalid periods value"))?,
701            _ => 1, // Default to 1 period
702        };
703
704        // Get current value
705        let current_value = {
706            let source = context.source();
707            let col_idx = source
708                .get_column_index(&column.name)
709                .ok_or_else(|| anyhow!("Column {} not found", column))?;
710            source.get_value(row_index, col_idx).cloned()
711        };
712
713        // Get previous value using offset
714        let previous_value = context.get_offset_value(row_index, -periods, &column.name);
715
716        // Calculate percent change: ((current - previous) / previous) * 100
717        match (current_value, previous_value) {
718            (Some(DataValue::Float(curr)), Some(DataValue::Float(prev))) if prev != 0.0 => {
719                Ok(DataValue::Float(((curr - prev) / prev) * 100.0))
720            }
721            (Some(DataValue::Integer(curr)), Some(DataValue::Integer(prev))) if prev != 0 => {
722                let curr_f = curr as f64;
723                let prev_f = prev as f64;
724                Ok(DataValue::Float(((curr_f - prev_f) / prev_f) * 100.0))
725            }
726            (Some(DataValue::Float(curr)), Some(DataValue::Integer(prev))) if prev != 0 => {
727                let prev_f = prev as f64;
728                Ok(DataValue::Float(((curr - prev_f) / prev_f) * 100.0))
729            }
730            (Some(DataValue::Integer(curr)), Some(DataValue::Float(prev))) if prev != 0.0 => {
731                let curr_f = curr as f64;
732                Ok(DataValue::Float(((curr_f - prev) / prev) * 100.0))
733            }
734            _ => Ok(DataValue::Null), // Return NULL for invalid comparisons or division by zero
735        }
736    }
737
738    fn transform_window_spec(
739        &self,
740        base_spec: &WindowSpec,
741        _args: &[SqlExpression],
742    ) -> Result<WindowSpec> {
743        // PERCENT_CHANGE doesn't need to modify the window frame
744        // It uses LAG internally which works within the partition
745        Ok(base_spec.clone())
746    }
747
748    fn validate_args(&self, args: &[SqlExpression]) -> Result<()> {
749        if args.is_empty() || args.len() > 2 {
750            return Err(anyhow!("PERCENT_CHANGE requires 1 or 2 arguments"));
751        }
752        Ok(())
753    }
754}
755
756// TODO: Add more functions like:
757// - EXPONENTIAL_AVG(column, alpha)
758// - PERCENT_RANK_IN_WINDOW(column, window)
759// - MEDIAN_IN_WINDOW(column, window)
760
761#[cfg(test)]
762mod tests {
763    use super::*;
764    use crate::sql::parser::ast::ColumnRef;
765
766    #[test]
767    fn test_registry_creation() {
768        let registry = WindowFunctionRegistry::new();
769        assert!(registry.contains("MOVING_AVG"));
770        assert!(registry.contains("ROLLING_STDDEV"));
771        assert!(registry.contains("CUMULATIVE_SUM"));
772    }
773
774    #[test]
775    fn test_window_spec_transformation() {
776        use crate::sql::parser::ast::{FrameBound, WindowSpec};
777
778        let func = MovingAvgFunction;
779        let base_spec = WindowSpec {
780            partition_by: vec![],
781            order_by: vec![],
782            frame: None,
783        };
784
785        let args = vec![
786            SqlExpression::Column(ColumnRef::unquoted("close".to_string())),
787            SqlExpression::NumberLiteral("20".to_string()),
788        ];
789
790        let transformed = func.transform_window_spec(&base_spec, &args).unwrap();
791
792        assert!(transformed.frame.is_some());
793        let frame = transformed.frame.unwrap();
794        assert_eq!(frame.start, FrameBound::Preceding(19));
795    }
796}