Skip to main content

shape_ast/parser/expressions/
window.rs

1//! Window function expression parser
2//!
3//! Parses window function expressions like:
4//! - `lag(close, 1) over (partition by symbol order by timestamp)`
5//! - `row_number() over (order by close desc)`
6//! - `sum(volume) over (rows between 5 preceding and current row)`
7
8use crate::ast::{
9    Expr, Literal, OrderByClause, SortDirection, WindowBound, WindowExpr, WindowFrame,
10    WindowFrameType, WindowFunction, WindowSpec,
11};
12use crate::error::{Result, ShapeError, SourceLocation};
13use crate::parser::{Rule, pair_location, pair_span};
14use pest::iterators::Pair;
15
16use super::super::expressions;
17
18/// Parse a window function call expression
19///
20/// Grammar: `window_function_name "(" window_function_args? ")" over_clause`
21pub fn parse_window_function_call(pair: Pair<Rule>) -> Result<Expr> {
22    let span = pair_span(&pair);
23    let pair_loc = pair_location(&pair);
24    let mut inner = pair.into_inner();
25
26    // Parse function name
27    let func_name_pair = inner.next().ok_or_else(|| ShapeError::ParseError {
28        message: "expected window function name".to_string(),
29        location: Some(pair_loc.clone()),
30    })?;
31    let func_name = func_name_pair.as_str().to_lowercase();
32
33    // Parse function arguments (optional)
34    let mut args = Vec::new();
35    let mut over_pair = None;
36
37    for part in inner {
38        match part.as_rule() {
39            Rule::window_function_args => {
40                for arg_pair in part.into_inner() {
41                    if arg_pair.as_rule() == Rule::expression {
42                        args.push(expressions::parse_expression(arg_pair)?);
43                    }
44                }
45            }
46            Rule::over_clause => {
47                over_pair = Some(part);
48            }
49            _ => {}
50        }
51    }
52
53    // Parse OVER clause
54    let over_clause = over_pair.ok_or_else(|| ShapeError::ParseError {
55        message: "window function requires OVER clause".to_string(),
56        location: Some(
57            pair_loc
58                .clone()
59                .with_hint("add OVER (...) after the function call"),
60        ),
61    })?;
62    let window_spec = parse_over_clause(over_clause)?;
63
64    // Build WindowFunction based on name
65    let function = build_window_function(&func_name, args, &pair_loc)?;
66
67    Ok(Expr::WindowExpr(
68        Box::new(WindowExpr {
69            function,
70            over: window_spec,
71        }),
72        span,
73    ))
74}
75
76/// Build a WindowFunction from the parsed name and arguments
77fn build_window_function(
78    name: &str,
79    args: Vec<Expr>,
80    loc: &SourceLocation,
81) -> Result<WindowFunction> {
82    match name {
83        "lag" => {
84            let expr = args.first().cloned().unwrap_or(Expr::Identifier(
85                "close".to_string(),
86                crate::ast::Span::new(0, 0),
87            ));
88            let offset = extract_usize(&args.get(1).cloned()).unwrap_or(1);
89            let default = args.get(2).map(|e| Box::new(e.clone()));
90            Ok(WindowFunction::Lag {
91                expr: Box::new(expr),
92                offset,
93                default,
94            })
95        }
96        "lead" => {
97            let expr = args.first().cloned().unwrap_or(Expr::Identifier(
98                "close".to_string(),
99                crate::ast::Span::new(0, 0),
100            ));
101            let offset = extract_usize(&args.get(1).cloned()).unwrap_or(1);
102            let default = args.get(2).map(|e| Box::new(e.clone()));
103            Ok(WindowFunction::Lead {
104                expr: Box::new(expr),
105                offset,
106                default,
107            })
108        }
109        "row_number" => Ok(WindowFunction::RowNumber),
110        "rank" => Ok(WindowFunction::Rank),
111        "dense_rank" => Ok(WindowFunction::DenseRank),
112        "ntile" => {
113            let n = extract_usize(&args.first().cloned()).unwrap_or(1);
114            Ok(WindowFunction::Ntile(n))
115        }
116        "first_value" => {
117            let expr = args.into_iter().next().ok_or_else(|| ShapeError::ParseError {
118                message: "first_value requires an expression argument".to_string(),
119                location: Some(loc.clone()),
120            })?;
121            Ok(WindowFunction::FirstValue(Box::new(expr)))
122        }
123        "last_value" => {
124            let expr = args.into_iter().next().ok_or_else(|| ShapeError::ParseError {
125                message: "last_value requires an expression argument".to_string(),
126                location: Some(loc.clone()),
127            })?;
128            Ok(WindowFunction::LastValue(Box::new(expr)))
129        }
130        "nth_value" => {
131            let mut iter = args.into_iter();
132            let expr = iter.next().ok_or_else(|| ShapeError::ParseError {
133                message: "nth_value requires an expression argument".to_string(),
134                location: Some(loc.clone()),
135            })?;
136            let n = extract_usize(&iter.next()).unwrap_or(1);
137            Ok(WindowFunction::NthValue(Box::new(expr), n))
138        }
139        "sum" => {
140            let expr = args.into_iter().next().ok_or_else(|| ShapeError::ParseError {
141                message: "sum requires an expression argument".to_string(),
142                location: Some(loc.clone()),
143            })?;
144            Ok(WindowFunction::Sum(Box::new(expr)))
145        }
146        "avg" => {
147            let expr = args.into_iter().next().ok_or_else(|| ShapeError::ParseError {
148                message: "avg requires an expression argument".to_string(),
149                location: Some(loc.clone()),
150            })?;
151            Ok(WindowFunction::Avg(Box::new(expr)))
152        }
153        "min" => {
154            let expr = args.into_iter().next().ok_or_else(|| ShapeError::ParseError {
155                message: "min requires an expression argument".to_string(),
156                location: Some(loc.clone()),
157            })?;
158            Ok(WindowFunction::Min(Box::new(expr)))
159        }
160        "max" => {
161            let expr = args.into_iter().next().ok_or_else(|| ShapeError::ParseError {
162                message: "max requires an expression argument".to_string(),
163                location: Some(loc.clone()),
164            })?;
165            Ok(WindowFunction::Max(Box::new(expr)))
166        }
167        "count" => {
168            let expr = args.into_iter().next().map(Box::new);
169            Ok(WindowFunction::Count(expr))
170        }
171        _ => Err(ShapeError::ParseError {
172            message: format!("unknown window function: '{}'", name),
173            location: Some(
174                loc.clone()
175                    .with_hint("valid functions: lag, lead, row_number, rank, dense_rank, ntile, first_value, last_value, sum, avg, min, max, count"),
176            ),
177        }),
178    }
179}
180
181/// Extract usize from an expression if it's a literal number
182fn extract_usize(expr: &Option<Expr>) -> Option<usize> {
183    match expr {
184        Some(Expr::Literal(Literal::Number(n), _)) => Some(*n as usize),
185        _ => None,
186    }
187}
188
189/// Parse the OVER clause
190///
191/// Grammar: `"over" "(" window_spec? ")"`
192fn parse_over_clause(pair: Pair<Rule>) -> Result<WindowSpec> {
193    let mut partition_by = Vec::new();
194    let mut order_by = None;
195    let mut frame = None;
196
197    // Look for window_spec inside over_clause
198    for inner in pair.into_inner() {
199        if inner.as_rule() == Rule::window_spec {
200            for spec_part in inner.into_inner() {
201                match spec_part.as_rule() {
202                    Rule::partition_by_clause => {
203                        partition_by = parse_partition_by_clause(spec_part)?;
204                    }
205                    Rule::order_by_clause => {
206                        order_by = Some(parse_window_order_by(spec_part)?);
207                    }
208                    Rule::window_frame_clause => {
209                        frame = Some(parse_window_frame_clause(spec_part)?);
210                    }
211                    _ => {}
212                }
213            }
214        }
215    }
216
217    Ok(WindowSpec {
218        partition_by,
219        order_by,
220        frame,
221    })
222}
223
224/// Parse PARTITION BY clause
225///
226/// Grammar: `"partition" "by" expression ("," expression)*`
227fn parse_partition_by_clause(pair: Pair<Rule>) -> Result<Vec<Expr>> {
228    let mut exprs = Vec::new();
229    for inner in pair.into_inner() {
230        if inner.as_rule() == Rule::expression {
231            exprs.push(expressions::parse_expression(inner)?);
232        }
233    }
234    Ok(exprs)
235}
236
237/// Parse ORDER BY clause for window functions
238fn parse_window_order_by(pair: Pair<Rule>) -> Result<OrderByClause> {
239    let mut columns = Vec::new();
240
241    for inner in pair.into_inner() {
242        if inner.as_rule() == Rule::order_by_list {
243            for item in inner.into_inner() {
244                if item.as_rule() == Rule::order_by_item {
245                    let mut item_inner = item.into_inner();
246
247                    // Parse expression
248                    let expr_pair = item_inner.next().ok_or_else(|| ShapeError::ParseError {
249                        message: "expected expression in ORDER BY".to_string(),
250                        location: None,
251                    })?;
252                    let expr = expressions::parse_expression(expr_pair)?;
253
254                    // Parse optional direction
255                    let direction = if let Some(dir_pair) = item_inner.next() {
256                        match dir_pair.as_str().to_lowercase().as_str() {
257                            "desc" => SortDirection::Descending,
258                            _ => SortDirection::Ascending,
259                        }
260                    } else {
261                        SortDirection::Ascending
262                    };
263
264                    columns.push((expr, direction));
265                }
266            }
267        }
268    }
269
270    Ok(OrderByClause { columns })
271}
272
273/// Parse window frame clause
274///
275/// Grammar: `frame_type frame_extent`
276fn parse_window_frame_clause(pair: Pair<Rule>) -> Result<WindowFrame> {
277    let pair_loc = pair_location(&pair);
278    let mut inner = pair.into_inner();
279
280    // Parse frame type (ROWS or RANGE)
281    let frame_type_pair = inner.next().ok_or_else(|| ShapeError::ParseError {
282        message: "expected frame type (ROWS or RANGE)".to_string(),
283        location: Some(pair_loc.clone()),
284    })?;
285    let frame_type = match frame_type_pair.as_str().to_lowercase().as_str() {
286        "rows" => WindowFrameType::Rows,
287        "range" => WindowFrameType::Range,
288        _ => WindowFrameType::Rows,
289    };
290
291    // Parse frame extent
292    let extent_pair = inner.next().ok_or_else(|| ShapeError::ParseError {
293        message: "expected frame extent".to_string(),
294        location: Some(pair_loc),
295    })?;
296    let (start, end) = parse_frame_extent(extent_pair)?;
297
298    Ok(WindowFrame {
299        frame_type,
300        start,
301        end,
302    })
303}
304
305/// Parse frame extent
306///
307/// Grammar: `"between" frame_bound "and" frame_bound | frame_bound`
308fn parse_frame_extent(pair: Pair<Rule>) -> Result<(WindowBound, WindowBound)> {
309    let mut bounds = Vec::new();
310
311    for inner in pair.into_inner() {
312        if inner.as_rule() == Rule::frame_bound {
313            bounds.push(parse_frame_bound(inner)?);
314        }
315    }
316
317    match bounds.len() {
318        1 => {
319            // Single bound means start..CURRENT ROW
320            Ok((bounds.remove(0), WindowBound::CurrentRow))
321        }
322        2 => {
323            // BETWEEN start AND end
324            let end = bounds.remove(1);
325            let start = bounds.remove(0);
326            Ok((start, end))
327        }
328        _ => Ok((WindowBound::UnboundedPreceding, WindowBound::CurrentRow)),
329    }
330}
331
332/// Parse a single frame bound
333///
334/// Grammar: `"unbounded" "preceding" | "current" "row" | integer "preceding" | integer "following" | "unbounded" "following"`
335fn parse_frame_bound(pair: Pair<Rule>) -> Result<WindowBound> {
336    let text = pair.as_str().to_lowercase();
337    let parts: Vec<&str> = text.split_whitespace().collect();
338
339    match parts.as_slice() {
340        ["unbounded", "preceding"] => Ok(WindowBound::UnboundedPreceding),
341        ["unbounded", "following"] => Ok(WindowBound::UnboundedFollowing),
342        ["current", "row"] => Ok(WindowBound::CurrentRow),
343        [n, "preceding"] => {
344            let num = n.parse::<usize>().map_err(|_| ShapeError::ParseError {
345                message: format!("invalid frame bound number: '{}'", n),
346                location: Some(pair_location(&pair)),
347            })?;
348            Ok(WindowBound::Preceding(num))
349        }
350        [n, "following"] => {
351            let num = n.parse::<usize>().map_err(|_| ShapeError::ParseError {
352                message: format!("invalid frame bound number: '{}'", n),
353                location: Some(pair_location(&pair)),
354            })?;
355            Ok(WindowBound::Following(num))
356        }
357        _ => Err(ShapeError::ParseError {
358            message: format!("invalid frame bound: '{}'", text),
359            location: Some(
360                pair_location(&pair)
361                    .with_hint("use: UNBOUNDED PRECEDING, n PRECEDING, CURRENT ROW, n FOLLOWING, or UNBOUNDED FOLLOWING"),
362            ),
363        }),
364    }
365}
366
367/// Parse window function from a regular function call that has an OVER clause
368/// This is called when we detect a function call followed by OVER
369pub fn parse_window_from_function_call(
370    name: String,
371    args: Vec<Expr>,
372    over_pair: Pair<Rule>,
373    span: crate::ast::Span,
374) -> Result<Expr> {
375    let window_spec = parse_over_clause(over_pair)?;
376    let loc = SourceLocation::new(1, 1); // Placeholder location
377
378    let function = build_window_function(&name.to_lowercase(), args, &loc)?;
379
380    Ok(Expr::WindowExpr(
381        Box::new(WindowExpr {
382            function,
383            over: window_spec,
384        }),
385        span,
386    ))
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use pest::Parser;

    /// Run the pest `window_function_call` rule over `input` and lower it to an `Expr`.
    fn parse_window_func(input: &str) -> Result<Expr> {
        let pairs =
            crate::parser::ShapeParser::parse(Rule::window_function_call, input).map_err(|e| {
                ShapeError::ParseError {
                    message: format!("parse error: {}", e),
                    location: None,
                }
            })?;
        let pair = pairs
            .into_iter()
            .next()
            .expect("pest reported success but produced no top-level pair");
        parse_window_function_call(pair)
    }

    /// Parse `input` and return its window expression, failing the test otherwise.
    ///
    /// Unlike an `if let Ok(..)` guard, this panics both on a parse error and on a
    /// successful parse that produced a different `Expr` variant.
    fn expect_window(input: &str) -> WindowExpr {
        match parse_window_func(input) {
            Ok(Expr::WindowExpr(window, _)) => *window,
            Ok(_) => panic!("`{}` did not produce a window expression", input),
            Err(ShapeError::ParseError { message, .. }) => {
                panic!("`{}` failed to parse: {}", input, message)
            }
            Err(_) => panic!("`{}` failed to parse", input),
        }
    }

    #[test]
    fn test_row_number() {
        let w = expect_window("row_number() over ()");
        assert!(matches!(w.function, WindowFunction::RowNumber));
        assert!(w.over.partition_by.is_empty());
        assert!(w.over.order_by.is_none());
        assert!(w.over.frame.is_none());
    }

    #[test]
    fn test_lag_with_args() {
        let w = expect_window("lag(close, 1) over (order by timestamp)");
        assert!(matches!(w.function, WindowFunction::Lag { offset: 1, .. }));
        assert!(w.over.order_by.is_some());
    }

    #[test]
    fn test_lag_with_partition_and_order() {
        let w = expect_window("lag(close, 1) over (partition by symbol order by timestamp)");
        assert!(matches!(w.function, WindowFunction::Lag { offset: 1, .. }));
        assert_eq!(w.over.partition_by.len(), 1);
        assert!(w.over.order_by.is_some());
    }

    #[test]
    fn test_order_by_desc() {
        let w = expect_window("row_number() over (order by close desc)");
        let order_by = w.over.order_by.expect("ORDER BY should be parsed");
        assert_eq!(order_by.columns.len(), 1);
        assert!(matches!(order_by.columns[0].1, SortDirection::Descending));
    }

    #[test]
    fn test_sum_with_partition() {
        let w = expect_window("sum(volume) over (partition by symbol)");
        assert!(matches!(w.function, WindowFunction::Sum(_)));
        assert_eq!(w.over.partition_by.len(), 1);
    }

    #[test]
    fn test_avg_with_frame() {
        let w = expect_window("avg(close) over (rows between 5 preceding and current row)");
        assert!(matches!(w.function, WindowFunction::Avg(_)));
        let frame = w.over.frame.expect("frame clause should be parsed");
        assert!(matches!(frame.frame_type, WindowFrameType::Rows));
        assert!(matches!(frame.start, WindowBound::Preceding(5)));
        assert!(matches!(frame.end, WindowBound::CurrentRow));
    }

    #[test]
    fn test_unknown_function_is_error() {
        assert!(parse_window_func("foo(close) over ()").is_err());
    }
}