// shape_ast/parser/expressions/window.rs

use crate::ast::{
    Expr, Literal, OrderByClause, SortDirection, WindowBound, WindowExpr, WindowFrame,
    WindowFrameType, WindowFunction, WindowSpec,
};
use crate::error::{Result, ShapeError, SourceLocation};
use crate::parser::{Rule, pair_location, pair_span};
use pest::iterators::Pair;

use super::super::expressions;
17
18pub fn parse_window_function_call(pair: Pair<Rule>) -> Result<Expr> {
22 let span = pair_span(&pair);
23 let pair_loc = pair_location(&pair);
24 let mut inner = pair.into_inner();
25
26 let func_name_pair = inner.next().ok_or_else(|| ShapeError::ParseError {
28 message: "expected window function name".to_string(),
29 location: Some(pair_loc.clone()),
30 })?;
31 let func_name = func_name_pair.as_str().to_lowercase();
32
33 let mut args = Vec::new();
35 let mut over_pair = None;
36
37 for part in inner {
38 match part.as_rule() {
39 Rule::window_function_args => {
40 for arg_pair in part.into_inner() {
41 if arg_pair.as_rule() == Rule::expression {
42 args.push(expressions::parse_expression(arg_pair)?);
43 }
44 }
45 }
46 Rule::over_clause => {
47 over_pair = Some(part);
48 }
49 _ => {}
50 }
51 }
52
53 let over_clause = over_pair.ok_or_else(|| ShapeError::ParseError {
55 message: "window function requires OVER clause".to_string(),
56 location: Some(
57 pair_loc
58 .clone()
59 .with_hint("add OVER (...) after the function call"),
60 ),
61 })?;
62 let window_spec = parse_over_clause(over_clause)?;
63
64 let function = build_window_function(&func_name, args, &pair_loc)?;
66
67 Ok(Expr::WindowExpr(
68 Box::new(WindowExpr {
69 function,
70 over: window_spec,
71 }),
72 span,
73 ))
74}
75
76fn build_window_function(
78 name: &str,
79 args: Vec<Expr>,
80 loc: &SourceLocation,
81) -> Result<WindowFunction> {
82 match name {
83 "lag" => {
84 let expr = args.first().cloned().unwrap_or(Expr::Identifier(
85 "close".to_string(),
86 crate::ast::Span::new(0, 0),
87 ));
88 let offset = extract_usize(&args.get(1).cloned()).unwrap_or(1);
89 let default = args.get(2).map(|e| Box::new(e.clone()));
90 Ok(WindowFunction::Lag {
91 expr: Box::new(expr),
92 offset,
93 default,
94 })
95 }
96 "lead" => {
97 let expr = args.first().cloned().unwrap_or(Expr::Identifier(
98 "close".to_string(),
99 crate::ast::Span::new(0, 0),
100 ));
101 let offset = extract_usize(&args.get(1).cloned()).unwrap_or(1);
102 let default = args.get(2).map(|e| Box::new(e.clone()));
103 Ok(WindowFunction::Lead {
104 expr: Box::new(expr),
105 offset,
106 default,
107 })
108 }
109 "row_number" => Ok(WindowFunction::RowNumber),
110 "rank" => Ok(WindowFunction::Rank),
111 "dense_rank" => Ok(WindowFunction::DenseRank),
112 "ntile" => {
113 let n = extract_usize(&args.first().cloned()).unwrap_or(1);
114 Ok(WindowFunction::Ntile(n))
115 }
116 "first_value" => {
117 let expr = args.into_iter().next().ok_or_else(|| ShapeError::ParseError {
118 message: "first_value requires an expression argument".to_string(),
119 location: Some(loc.clone()),
120 })?;
121 Ok(WindowFunction::FirstValue(Box::new(expr)))
122 }
123 "last_value" => {
124 let expr = args.into_iter().next().ok_or_else(|| ShapeError::ParseError {
125 message: "last_value requires an expression argument".to_string(),
126 location: Some(loc.clone()),
127 })?;
128 Ok(WindowFunction::LastValue(Box::new(expr)))
129 }
130 "nth_value" => {
131 let mut iter = args.into_iter();
132 let expr = iter.next().ok_or_else(|| ShapeError::ParseError {
133 message: "nth_value requires an expression argument".to_string(),
134 location: Some(loc.clone()),
135 })?;
136 let n = extract_usize(&iter.next()).unwrap_or(1);
137 Ok(WindowFunction::NthValue(Box::new(expr), n))
138 }
139 "sum" => {
140 let expr = args.into_iter().next().ok_or_else(|| ShapeError::ParseError {
141 message: "sum requires an expression argument".to_string(),
142 location: Some(loc.clone()),
143 })?;
144 Ok(WindowFunction::Sum(Box::new(expr)))
145 }
146 "avg" => {
147 let expr = args.into_iter().next().ok_or_else(|| ShapeError::ParseError {
148 message: "avg requires an expression argument".to_string(),
149 location: Some(loc.clone()),
150 })?;
151 Ok(WindowFunction::Avg(Box::new(expr)))
152 }
153 "min" => {
154 let expr = args.into_iter().next().ok_or_else(|| ShapeError::ParseError {
155 message: "min requires an expression argument".to_string(),
156 location: Some(loc.clone()),
157 })?;
158 Ok(WindowFunction::Min(Box::new(expr)))
159 }
160 "max" => {
161 let expr = args.into_iter().next().ok_or_else(|| ShapeError::ParseError {
162 message: "max requires an expression argument".to_string(),
163 location: Some(loc.clone()),
164 })?;
165 Ok(WindowFunction::Max(Box::new(expr)))
166 }
167 "count" => {
168 let expr = args.into_iter().next().map(Box::new);
169 Ok(WindowFunction::Count(expr))
170 }
171 _ => Err(ShapeError::ParseError {
172 message: format!("unknown window function: '{}'", name),
173 location: Some(
174 loc.clone()
175 .with_hint("valid functions: lag, lead, row_number, rank, dense_rank, ntile, first_value, last_value, sum, avg, min, max, count"),
176 ),
177 }),
178 }
179}
180
181fn extract_usize(expr: &Option<Expr>) -> Option<usize> {
183 match expr {
184 Some(Expr::Literal(Literal::Number(n), _)) => Some(*n as usize),
185 _ => None,
186 }
187}
188
189fn parse_over_clause(pair: Pair<Rule>) -> Result<WindowSpec> {
193 let mut partition_by = Vec::new();
194 let mut order_by = None;
195 let mut frame = None;
196
197 for inner in pair.into_inner() {
199 if inner.as_rule() == Rule::window_spec {
200 for spec_part in inner.into_inner() {
201 match spec_part.as_rule() {
202 Rule::partition_by_clause => {
203 partition_by = parse_partition_by_clause(spec_part)?;
204 }
205 Rule::order_by_clause => {
206 order_by = Some(parse_window_order_by(spec_part)?);
207 }
208 Rule::window_frame_clause => {
209 frame = Some(parse_window_frame_clause(spec_part)?);
210 }
211 _ => {}
212 }
213 }
214 }
215 }
216
217 Ok(WindowSpec {
218 partition_by,
219 order_by,
220 frame,
221 })
222}
223
224fn parse_partition_by_clause(pair: Pair<Rule>) -> Result<Vec<Expr>> {
228 let mut exprs = Vec::new();
229 for inner in pair.into_inner() {
230 if inner.as_rule() == Rule::expression {
231 exprs.push(expressions::parse_expression(inner)?);
232 }
233 }
234 Ok(exprs)
235}
236
237fn parse_window_order_by(pair: Pair<Rule>) -> Result<OrderByClause> {
239 let mut columns = Vec::new();
240
241 for inner in pair.into_inner() {
242 if inner.as_rule() == Rule::order_by_list {
243 for item in inner.into_inner() {
244 if item.as_rule() == Rule::order_by_item {
245 let mut item_inner = item.into_inner();
246
247 let expr_pair = item_inner.next().ok_or_else(|| ShapeError::ParseError {
249 message: "expected expression in ORDER BY".to_string(),
250 location: None,
251 })?;
252 let expr = expressions::parse_expression(expr_pair)?;
253
254 let direction = if let Some(dir_pair) = item_inner.next() {
256 match dir_pair.as_str().to_lowercase().as_str() {
257 "desc" => SortDirection::Descending,
258 _ => SortDirection::Ascending,
259 }
260 } else {
261 SortDirection::Ascending
262 };
263
264 columns.push((expr, direction));
265 }
266 }
267 }
268 }
269
270 Ok(OrderByClause { columns })
271}
272
273fn parse_window_frame_clause(pair: Pair<Rule>) -> Result<WindowFrame> {
277 let pair_loc = pair_location(&pair);
278 let mut inner = pair.into_inner();
279
280 let frame_type_pair = inner.next().ok_or_else(|| ShapeError::ParseError {
282 message: "expected frame type (ROWS or RANGE)".to_string(),
283 location: Some(pair_loc.clone()),
284 })?;
285 let frame_type = match frame_type_pair.as_str().to_lowercase().as_str() {
286 "rows" => WindowFrameType::Rows,
287 "range" => WindowFrameType::Range,
288 _ => WindowFrameType::Rows,
289 };
290
291 let extent_pair = inner.next().ok_or_else(|| ShapeError::ParseError {
293 message: "expected frame extent".to_string(),
294 location: Some(pair_loc),
295 })?;
296 let (start, end) = parse_frame_extent(extent_pair)?;
297
298 Ok(WindowFrame {
299 frame_type,
300 start,
301 end,
302 })
303}
304
305fn parse_frame_extent(pair: Pair<Rule>) -> Result<(WindowBound, WindowBound)> {
309 let mut bounds = Vec::new();
310
311 for inner in pair.into_inner() {
312 if inner.as_rule() == Rule::frame_bound {
313 bounds.push(parse_frame_bound(inner)?);
314 }
315 }
316
317 match bounds.len() {
318 1 => {
319 Ok((bounds.remove(0), WindowBound::CurrentRow))
321 }
322 2 => {
323 let end = bounds.remove(1);
325 let start = bounds.remove(0);
326 Ok((start, end))
327 }
328 _ => Ok((WindowBound::UnboundedPreceding, WindowBound::CurrentRow)),
329 }
330}
331
332fn parse_frame_bound(pair: Pair<Rule>) -> Result<WindowBound> {
336 let text = pair.as_str().to_lowercase();
337 let parts: Vec<&str> = text.split_whitespace().collect();
338
339 match parts.as_slice() {
340 ["unbounded", "preceding"] => Ok(WindowBound::UnboundedPreceding),
341 ["unbounded", "following"] => Ok(WindowBound::UnboundedFollowing),
342 ["current", "row"] => Ok(WindowBound::CurrentRow),
343 [n, "preceding"] => {
344 let num = n.parse::<usize>().map_err(|_| ShapeError::ParseError {
345 message: format!("invalid frame bound number: '{}'", n),
346 location: Some(pair_location(&pair)),
347 })?;
348 Ok(WindowBound::Preceding(num))
349 }
350 [n, "following"] => {
351 let num = n.parse::<usize>().map_err(|_| ShapeError::ParseError {
352 message: format!("invalid frame bound number: '{}'", n),
353 location: Some(pair_location(&pair)),
354 })?;
355 Ok(WindowBound::Following(num))
356 }
357 _ => Err(ShapeError::ParseError {
358 message: format!("invalid frame bound: '{}'", text),
359 location: Some(
360 pair_location(&pair)
361 .with_hint("use: UNBOUNDED PRECEDING, n PRECEDING, CURRENT ROW, n FOLLOWING, or UNBOUNDED FOLLOWING"),
362 ),
363 }),
364 }
365}
366
367pub fn parse_window_from_function_call(
370 name: String,
371 args: Vec<Expr>,
372 over_pair: Pair<Rule>,
373 span: crate::ast::Span,
374) -> Result<Expr> {
375 let window_spec = parse_over_clause(over_pair)?;
376 let loc = SourceLocation::new(1, 1); let function = build_window_function(&name.to_lowercase(), args, &loc)?;
379
380 Ok(Expr::WindowExpr(
381 Box::new(WindowExpr {
382 function,
383 over: window_spec,
384 }),
385 span,
386 ))
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use pest::Parser;

    /// Runs the pest grammar for `window_function_call` on `input`, then the
    /// AST builder, converting grammar failures into `ShapeError`.
    fn parse_window_func(input: &str) -> Result<Expr> {
        let pairs =
            crate::parser::ShapeParser::parse(Rule::window_function_call, input).map_err(|e| {
                ShapeError::ParseError {
                    message: format!("parse error: {}", e),
                    location: None,
                }
            })?;
        let pair = pairs.into_iter().next().unwrap();
        parse_window_function_call(pair)
    }

    /// Parses `input` and unwraps the resulting window expression.
    ///
    /// Any other outcome fails the test loudly. A plain `if let` would
    /// silently skip the assertions.
    fn expect_window_expr(input: &str) -> WindowExpr {
        match parse_window_func(input) {
            Ok(Expr::WindowExpr(w, _)) => *w,
            Ok(_) => panic!("`{}` did not parse to a window expression", input),
            Err(e) => panic!("`{}` failed to parse: {:?}", input, e),
        }
    }

    #[test]
    fn test_row_number() {
        let w = expect_window_expr("row_number() over ()");
        assert!(matches!(w.function, WindowFunction::RowNumber));
    }

    #[test]
    fn test_lag_with_args() {
        let w = expect_window_expr("lag(close, 1) over (order by timestamp)");
        assert!(matches!(w.function, WindowFunction::Lag { offset: 1, .. }));
        assert!(w.over.order_by.is_some());
    }

    #[test]
    fn test_sum_with_partition() {
        let w = expect_window_expr("sum(volume) over (partition by symbol)");
        assert!(matches!(w.function, WindowFunction::Sum(_)));
        assert_eq!(w.over.partition_by.len(), 1);
    }

    #[test]
    fn test_avg_with_frame() {
        let w = expect_window_expr("avg(close) over (rows between 5 preceding and current row)");
        assert!(matches!(w.function, WindowFunction::Avg(_)));
        let frame = w.over.frame.expect("frame clause should be parsed");
        assert!(matches!(frame.frame_type, WindowFrameType::Rows));
        assert!(matches!(frame.start, WindowBound::Preceding(5)));
        assert!(matches!(frame.end, WindowBound::CurrentRow));
    }

    #[test]
    fn test_unknown_function_is_rejected() {
        // Rejected either by the grammar or by `build_window_function`.
        assert!(parse_window_func("foo(x) over ()").is_err());
    }
}