rustledger_query/
parser.rs

1//! BQL Parser implementation.
2//!
3//! Uses chumsky for parser combinators.
4
5use chumsky::prelude::*;
6use rust_decimal::Decimal;
7use std::str::FromStr;
8
9use crate::ast::{
10    BalancesQuery, BinaryOperator, Expr, FromClause, FunctionCall, JournalQuery, Literal,
11    OrderSpec, PrintQuery, Query, SelectQuery, SortDirection, Target, UnaryOperator,
12    WindowFunction, WindowSpec,
13};
14use crate::error::{ParseError, ParseErrorKind};
15use rustledger_core::NaiveDate;
16
/// Input type shared by every parser in this module: a borrowed string slice.
type ParserInput<'a> = &'a str;
/// Extra parser configuration shared by every parser in this module: errors
/// are reported as `Rich` values over `char` tokens.
type ParserExtra<'a> = extra::Err<Rich<'a, char>>;
19
20/// Parse a BQL query string.
21///
22/// # Errors
23///
24/// Returns a `ParseError` if the query string is malformed.
25pub fn parse(source: &str) -> Result<Query, ParseError> {
26    let (result, errs) = query_parser()
27        .then_ignore(ws())
28        .then_ignore(end())
29        .parse(source)
30        .into_output_errors();
31
32    if let Some(query) = result {
33        Ok(query)
34    } else {
35        let err = errs.first().map(|e| {
36            let kind = if e.found().is_none() {
37                ParseErrorKind::UnexpectedEof
38            } else {
39                ParseErrorKind::SyntaxError(e.to_string())
40            };
41            ParseError::new(kind, e.span().start)
42        });
43        Err(err.unwrap_or_else(|| ParseError::new(ParseErrorKind::UnexpectedEof, 0)))
44    }
45}
46
47/// Parse whitespace (spaces, tabs, newlines).
48fn ws<'a>() -> impl Parser<'a, ParserInput<'a>, (), ParserExtra<'a>> + Clone {
49    one_of(" \t\r\n").repeated().ignored()
50}
51
52/// Parse required whitespace.
53fn ws1<'a>() -> impl Parser<'a, ParserInput<'a>, (), ParserExtra<'a>> + Clone {
54    one_of(" \t\r\n").repeated().at_least(1).ignored()
55}
56
57/// Case-insensitive keyword parser.
58fn kw<'a>(keyword: &'static str) -> impl Parser<'a, ParserInput<'a>, (), ParserExtra<'a>> + Clone {
59    text::keyword(keyword).ignored()
60}
61
62/// Parse digits.
63fn digits<'a>() -> impl Parser<'a, ParserInput<'a>, &'a str, ParserExtra<'a>> + Clone {
64    one_of("0123456789").repeated().at_least(1).to_slice()
65}
66
67/// Parse the main query.
68fn query_parser<'a>() -> impl Parser<'a, ParserInput<'a>, Query, ParserExtra<'a>> {
69    ws().ignore_then(choice((
70        select_query().map(|sq| Query::Select(Box::new(sq))),
71        journal_query().map(Query::Journal),
72        balances_query().map(Query::Balances),
73        print_query().map(Query::Print),
74    )))
75    .then_ignore(ws())
76    .then_ignore(just(';').or_not())
77}
78
79/// Parse a SELECT query with optional subquery support.
80fn select_query<'a>() -> impl Parser<'a, ParserInput<'a>, SelectQuery, ParserExtra<'a>> {
81    recursive(|select_parser| {
82        // Subquery in FROM clause: FROM (SELECT ...)
83        let subquery_from = ws1()
84            .ignore_then(kw("FROM"))
85            .ignore_then(ws1())
86            .ignore_then(just('('))
87            .ignore_then(ws())
88            .ignore_then(select_parser)
89            .then_ignore(ws())
90            .then_ignore(just(')'))
91            .map(|sq| Some(FromClause::from_subquery(sq)));
92
93        // Regular FROM clause
94        let regular_from = from_clause().map(Some);
95
96        kw("SELECT")
97            .ignore_then(ws1())
98            .ignore_then(
99                kw("DISTINCT")
100                    .then_ignore(ws1())
101                    .or_not()
102                    .map(|d| d.is_some()),
103            )
104            .then(targets())
105            .then(
106                subquery_from
107                    .or(regular_from)
108                    .or_not()
109                    .map(std::option::Option::flatten),
110            )
111            .then(where_clause().or_not())
112            .then(group_by_clause().or_not())
113            .then(having_clause().or_not())
114            .then(pivot_by_clause().or_not())
115            .then(order_by_clause().or_not())
116            .then(limit_clause().or_not())
117            .map(
118                |(
119                    (
120                        (
121                            (((((distinct, targets), from), where_clause), group_by), having),
122                            pivot_by,
123                        ),
124                        order_by,
125                    ),
126                    limit,
127                )| {
128                    SelectQuery {
129                        distinct,
130                        targets,
131                        from,
132                        where_clause,
133                        group_by,
134                        having,
135                        pivot_by,
136                        order_by,
137                        limit,
138                    }
139                },
140            )
141    })
142}
143
144/// Parse FROM clause.
145fn from_clause<'a>() -> impl Parser<'a, ParserInput<'a>, FromClause, ParserExtra<'a>> + Clone {
146    ws1()
147        .ignore_then(kw("FROM"))
148        .ignore_then(ws1())
149        .ignore_then(from_modifiers())
150}
151
152/// Parse target expressions.
153fn targets<'a>() -> impl Parser<'a, ParserInput<'a>, Vec<Target>, ParserExtra<'a>> + Clone {
154    target()
155        .separated_by(ws().then(just(',')).then(ws()))
156        .at_least(1)
157        .collect()
158}
159
160/// Parse a single target.
161fn target<'a>() -> impl Parser<'a, ParserInput<'a>, Target, ParserExtra<'a>> + Clone {
162    expr()
163        .then(
164            ws1()
165                .ignore_then(kw("AS"))
166                .ignore_then(ws1())
167                .ignore_then(identifier())
168                .or_not(),
169        )
170        .map(|(expr, alias)| Target { expr, alias })
171}
172
173/// Parse FROM modifiers (OPEN ON, CLOSE ON, CLEAR, filter).
174fn from_modifiers<'a>() -> impl Parser<'a, ParserInput<'a>, FromClause, ParserExtra<'a>> + Clone {
175    let open_on = kw("OPEN")
176        .ignore_then(ws1())
177        .ignore_then(kw("ON"))
178        .ignore_then(ws1())
179        .ignore_then(date_literal())
180        .then_ignore(ws());
181
182    let close_on = kw("CLOSE")
183        .ignore_then(ws().then(kw("ON")).then(ws()).or_not())
184        .ignore_then(date_literal())
185        .then_ignore(ws());
186
187    let clear = kw("CLEAR").then_ignore(ws());
188
189    // Parse modifiers in order: OPEN ON, CLOSE ON, CLEAR, filter
190    open_on
191        .or_not()
192        .then(close_on.or_not())
193        .then(clear.or_not().map(|c| c.is_some()))
194        .then(from_filter().or_not())
195        .map(|(((open_on, close_on), clear), filter)| FromClause {
196            open_on,
197            close_on,
198            clear,
199            filter,
200            subquery: None,
201        })
202}
203
/// Parse FROM filter expression (predicates).
///
/// The filter is an ordinary expression; this simply delegates to [`expr`].
fn from_filter<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
    expr()
}
208
209/// Parse WHERE clause.
210fn where_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
211    ws1()
212        .ignore_then(kw("WHERE"))
213        .ignore_then(ws1())
214        .ignore_then(expr())
215}
216
217/// Parse GROUP BY clause.
218fn group_by_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Vec<Expr>, ParserExtra<'a>> + Clone {
219    ws1()
220        .ignore_then(kw("GROUP"))
221        .ignore_then(ws1())
222        .ignore_then(kw("BY"))
223        .ignore_then(ws1())
224        .ignore_then(
225            expr()
226                .separated_by(ws().then(just(',')).then(ws()))
227                .at_least(1)
228                .collect(),
229        )
230}
231
232/// Parse HAVING clause (filter on aggregated results).
233fn having_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
234    ws1()
235        .ignore_then(kw("HAVING"))
236        .ignore_then(ws1())
237        .ignore_then(expr())
238}
239
240/// Parse PIVOT BY clause (pivot table transformation).
241fn pivot_by_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Vec<Expr>, ParserExtra<'a>> + Clone {
242    ws1()
243        .ignore_then(kw("PIVOT"))
244        .ignore_then(ws1())
245        .ignore_then(kw("BY"))
246        .ignore_then(ws1())
247        .ignore_then(
248            expr()
249                .separated_by(ws().then(just(',')).then(ws()))
250                .at_least(1)
251                .collect(),
252        )
253}
254
255/// Parse ORDER BY clause.
256fn order_by_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Vec<OrderSpec>, ParserExtra<'a>> + Clone
257{
258    ws1()
259        .ignore_then(kw("ORDER"))
260        .ignore_then(ws1())
261        .ignore_then(kw("BY"))
262        .ignore_then(ws1())
263        .ignore_then(
264            order_spec()
265                .separated_by(ws().then(just(',')).then(ws()))
266                .at_least(1)
267                .collect(),
268        )
269}
270
271/// Parse a single ORDER BY spec.
272fn order_spec<'a>() -> impl Parser<'a, ParserInput<'a>, OrderSpec, ParserExtra<'a>> + Clone {
273    expr()
274        .then(
275            ws1()
276                .ignore_then(choice((
277                    kw("ASC").to(SortDirection::Asc),
278                    kw("DESC").to(SortDirection::Desc),
279                )))
280                .or_not(),
281        )
282        .map(|(expr, dir)| OrderSpec {
283            expr,
284            direction: dir.unwrap_or_default(),
285        })
286}
287
288/// Parse LIMIT clause.
289fn limit_clause<'a>() -> impl Parser<'a, ParserInput<'a>, u64, ParserExtra<'a>> + Clone {
290    ws1()
291        .ignore_then(kw("LIMIT"))
292        .ignore_then(ws1())
293        .ignore_then(integer())
294        .map(|n| n as u64)
295}
296
297/// Parse JOURNAL query.
298fn journal_query<'a>() -> impl Parser<'a, ParserInput<'a>, JournalQuery, ParserExtra<'a>> + Clone {
299    kw("JOURNAL")
300        .ignore_then(ws1())
301        .ignore_then(string_literal())
302        .then(at_function().or_not())
303        .then(
304            ws1()
305                .ignore_then(kw("FROM"))
306                .ignore_then(ws1())
307                .ignore_then(from_modifiers())
308                .or_not(),
309        )
310        .map(|((account_pattern, at_function), from)| JournalQuery {
311            account_pattern,
312            at_function,
313            from,
314        })
315}
316
317/// Parse BALANCES query.
318fn balances_query<'a>() -> impl Parser<'a, ParserInput<'a>, BalancesQuery, ParserExtra<'a>> + Clone
319{
320    kw("BALANCES")
321        .ignore_then(at_function().or_not())
322        .then(
323            ws1()
324                .ignore_then(kw("FROM"))
325                .ignore_then(ws1())
326                .ignore_then(from_modifiers())
327                .or_not(),
328        )
329        .map(|(at_function, from)| BalancesQuery { at_function, from })
330}
331
332/// Parse PRINT query.
333fn print_query<'a>() -> impl Parser<'a, ParserInput<'a>, PrintQuery, ParserExtra<'a>> + Clone {
334    kw("PRINT")
335        .ignore_then(
336            ws1()
337                .ignore_then(kw("FROM"))
338                .ignore_then(ws1())
339                .ignore_then(from_modifiers())
340                .or_not(),
341        )
342        .map(|from| PrintQuery { from })
343}
344
345/// Parse AT function (e.g., AT cost, AT units).
346fn at_function<'a>() -> impl Parser<'a, ParserInput<'a>, String, ParserExtra<'a>> + Clone {
347    ws1()
348        .ignore_then(kw("AT"))
349        .ignore_then(ws1())
350        .ignore_then(identifier())
351}
352
/// Parse an expression (with precedence climbing).
///
/// Each level is built on top of the previous one, so the binding strength,
/// from tightest to loosest, is:
///
/// 1. primary expressions (see `primary_expr`)
/// 2. unary minus
/// 3. `*` and `/` (left-associative)
/// 4. `+` and `-` (left-associative)
/// 5. a single optional comparison (see `comparison_op`)
/// 6. `NOT` (may be repeated)
/// 7. `AND` (left-associative)
/// 8. `OR` (left-associative)
///
/// The parser is recursive so that parenthesized sub-expressions can contain
/// full expressions again.
#[allow(clippy::large_stack_frames)]
fn expr<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
    recursive(|expr| {
        let primary = primary_expr(expr.clone());

        // Unary minus: at most one leading `-`, optionally followed by
        // whitespace, applied to a primary expression.
        let unary = just('-')
            .then_ignore(ws())
            .or_not()
            .then(primary)
            .map(|(neg, e)| {
                if neg.is_some() {
                    Expr::unary(UnaryOperator::Neg, e)
                } else {
                    e
                }
            });

        // Multiplicative: * /
        // Whitespace around the operator is optional; folds to the left.
        let multiplicative = unary.clone().foldl(
            ws().ignore_then(choice((
                just('*').to(BinaryOperator::Mul),
                just('/').to(BinaryOperator::Div),
            )))
            .then_ignore(ws())
            .then(unary)
            .repeated(),
            |left, (op, right)| Expr::binary(left, op, right),
        );

        // Additive: + -
        // Whitespace around the operator is optional; folds to the left.
        let additive = multiplicative.clone().foldl(
            ws().ignore_then(choice((
                just('+').to(BinaryOperator::Add),
                just('-').to(BinaryOperator::Sub),
            )))
            .then_ignore(ws())
            .then(multiplicative)
            .repeated(),
            |left, (op, right)| Expr::binary(left, op, right),
        );

        // Comparison: = != < <= > >= ~ IN
        // At most one comparison per level (`.or_not()`, not `.repeated()`),
        // so comparisons do not chain without parentheses.
        let comparison = additive
            .clone()
            .then(
                ws().ignore_then(comparison_op())
                    .then_ignore(ws())
                    .then(additive)
                    .or_not(),
            )
            .map(|(left, rest)| {
                if let Some((op, right)) = rest {
                    Expr::binary(left, op, right)
                } else {
                    left
                }
            });

        // NOT
        // Any number of `NOT` prefixes, each followed by mandatory
        // whitespace; one `UnaryOperator::Not` wrapper is added per prefix.
        let not_expr = kw("NOT")
            .ignore_then(ws1())
            .repeated()
            .collect::<Vec<_>>()
            .then(comparison)
            .map(|(nots, e)| {
                nots.into_iter()
                    .fold(e, |acc, ()| Expr::unary(UnaryOperator::Not, acc))
            });

        // AND
        // Requires whitespace on both sides of the keyword; folds to the left.
        let and_expr = not_expr.clone().foldl(
            ws1()
                .ignore_then(kw("AND"))
                .ignore_then(ws1())
                .ignore_then(not_expr)
                .repeated(),
            |left, right| Expr::binary(left, BinaryOperator::And, right),
        );

        // OR (lowest precedence)
        // Requires whitespace on both sides of the keyword; folds to the left.
        and_expr.clone().foldl(
            ws1()
                .ignore_then(kw("OR"))
                .ignore_then(ws1())
                .ignore_then(and_expr)
                .repeated(),
            |left, right| Expr::binary(left, BinaryOperator::Or, right),
        )
    })
}
445
446/// Parse comparison operators.
447fn comparison_op<'a>() -> impl Parser<'a, ParserInput<'a>, BinaryOperator, ParserExtra<'a>> + Clone
448{
449    choice((
450        just("!=").to(BinaryOperator::Ne),
451        just("<=").to(BinaryOperator::Le),
452        just(">=").to(BinaryOperator::Ge),
453        just('=').to(BinaryOperator::Eq),
454        just('<').to(BinaryOperator::Lt),
455        just('>').to(BinaryOperator::Gt),
456        just('~').to(BinaryOperator::Regex),
457        kw("IN").to(BinaryOperator::In),
458    ))
459}
460
461/// Parse primary expressions.
462fn primary_expr<'a>(
463    expr: impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone + 'a,
464) -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
465    choice((
466        // Parenthesized expression
467        just('(')
468            .ignore_then(ws())
469            .ignore_then(expr)
470            .then_ignore(ws())
471            .then_ignore(just(')'))
472            .map(|e| Expr::Paren(Box::new(e))),
473        // Function call or column reference (must come before wildcard check)
474        function_call_or_column(),
475        // Literals
476        literal().map(Expr::Literal),
477        // Wildcard (fallback if nothing else matched)
478        just('*').to(Expr::Wildcard),
479    ))
480}
481
482/// Parse function call, window function, or column reference.
483fn function_call_or_column<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone
484{
485    identifier()
486        .then(
487            ws().ignore_then(just('('))
488                .ignore_then(ws())
489                .ignore_then(function_args())
490                .then_ignore(ws())
491                .then_ignore(just(')'))
492                .or_not(),
493        )
494        .then(
495            // Optional OVER clause for window functions
496            ws1()
497                .ignore_then(kw("OVER"))
498                .ignore_then(ws())
499                .ignore_then(just('('))
500                .ignore_then(ws())
501                .ignore_then(window_spec())
502                .then_ignore(ws())
503                .then_ignore(just(')'))
504                .or_not(),
505        )
506        .map(|((name, args), over)| {
507            if let Some(args) = args {
508                if let Some(window_spec) = over {
509                    // Window function
510                    Expr::Window(WindowFunction {
511                        name,
512                        args,
513                        over: window_spec,
514                    })
515                } else {
516                    // Regular function
517                    Expr::Function(FunctionCall { name, args })
518                }
519            } else {
520                Expr::Column(name)
521            }
522        })
523}
524
525/// Parse window specification (PARTITION BY and ORDER BY).
526fn window_spec<'a>() -> impl Parser<'a, ParserInput<'a>, WindowSpec, ParserExtra<'a>> + Clone {
527    let partition_by = kw("PARTITION")
528        .ignore_then(ws1())
529        .ignore_then(kw("BY"))
530        .ignore_then(ws1())
531        .ignore_then(
532            simple_arg()
533                .separated_by(ws().then(just(',')).then(ws()))
534                .at_least(1)
535                .collect::<Vec<_>>(),
536        )
537        .then_ignore(ws());
538
539    let window_order_by = kw("ORDER")
540        .ignore_then(ws1())
541        .ignore_then(kw("BY"))
542        .ignore_then(ws1())
543        .ignore_then(
544            window_order_spec()
545                .separated_by(ws().then(just(',')).then(ws()))
546                .at_least(1)
547                .collect::<Vec<_>>(),
548        );
549
550    partition_by
551        .or_not()
552        .then(window_order_by.or_not())
553        .map(|(partition_by, order_by)| WindowSpec {
554            partition_by,
555            order_by,
556        })
557}
558
559/// Parse ORDER BY spec within window (simple version).
560fn window_order_spec<'a>() -> impl Parser<'a, ParserInput<'a>, OrderSpec, ParserExtra<'a>> + Clone {
561    simple_arg()
562        .then(
563            ws1()
564                .ignore_then(choice((
565                    kw("ASC").to(SortDirection::Asc),
566                    kw("DESC").to(SortDirection::Desc),
567                )))
568                .or_not(),
569        )
570        .map(|(expr, dir)| OrderSpec {
571            expr,
572            direction: dir.unwrap_or_default(),
573        })
574}
575
576/// Parse function arguments.
577fn function_args<'a>() -> impl Parser<'a, ParserInput<'a>, Vec<Expr>, ParserExtra<'a>> + Clone {
578    // Allow empty args or comma-separated expressions
579    // Simple version: only allow columns and wildcards as function args (not full expressions)
580    simple_arg()
581        .separated_by(ws().then(just(',')).then(ws()))
582        .collect()
583}
584
585/// Parse a simple function argument (column, wildcard, or literal).
586fn simple_arg<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
587    choice((
588        just('*').to(Expr::Wildcard),
589        identifier().map(Expr::Column),
590        literal().map(Expr::Literal),
591    ))
592}
593
594/// Parse a literal.
595fn literal<'a>() -> impl Parser<'a, ParserInput<'a>, Literal, ParserExtra<'a>> + Clone {
596    choice((
597        // Keywords first
598        kw("TRUE").to(Literal::Boolean(true)),
599        kw("FALSE").to(Literal::Boolean(false)),
600        kw("NULL").to(Literal::Null),
601        // Date literal (must be before number to avoid parsing year as number)
602        date_literal().map(Literal::Date),
603        // Number
604        decimal().map(Literal::Number),
605        // String
606        string_literal().map(Literal::String),
607    ))
608}
609
610/// Parse an identifier (column name, function name).
611fn identifier<'a>() -> impl Parser<'a, ParserInput<'a>, String, ParserExtra<'a>> + Clone {
612    text::ident().map(|s: &str| s.to_string())
613}
614
615/// Parse a string literal.
616fn string_literal<'a>() -> impl Parser<'a, ParserInput<'a>, String, ParserExtra<'a>> + Clone {
617    // Double-quoted string
618    just('"')
619        .ignore_then(
620            none_of("\"\\")
621                .or(just('\\').ignore_then(any()))
622                .repeated()
623                .collect::<String>(),
624        )
625        .then_ignore(just('"'))
626}
627
628/// Parse a date literal (YYYY-MM-DD).
629fn date_literal<'a>() -> impl Parser<'a, ParserInput<'a>, NaiveDate, ParserExtra<'a>> + Clone {
630    digits()
631        .then_ignore(just('-'))
632        .then(digits())
633        .then_ignore(just('-'))
634        .then(digits())
635        .try_map(|((year, month), day): ((&str, &str), &str), span| {
636            let year: i32 = year
637                .parse()
638                .map_err(|_| Rich::custom(span, "invalid year"))?;
639            let month: u32 = month
640                .parse()
641                .map_err(|_| Rich::custom(span, "invalid month"))?;
642            let day: u32 = day.parse().map_err(|_| Rich::custom(span, "invalid day"))?;
643            NaiveDate::from_ymd_opt(year, month, day)
644                .ok_or_else(|| Rich::custom(span, "invalid date"))
645        })
646}
647
648/// Parse a decimal number.
649fn decimal<'a>() -> impl Parser<'a, ParserInput<'a>, Decimal, ParserExtra<'a>> + Clone {
650    just('-')
651        .or_not()
652        .then(digits())
653        .then(just('.').then(digits()).or_not())
654        .try_map(
655            |((neg, int_part), frac_part): ((Option<char>, &str), Option<(char, &str)>), span| {
656                let mut s = String::new();
657                if neg.is_some() {
658                    s.push('-');
659                }
660                s.push_str(int_part);
661                if let Some((_, frac)) = frac_part {
662                    s.push('.');
663                    s.push_str(frac);
664                }
665                Decimal::from_str(&s).map_err(|_| Rich::custom(span, "invalid number"))
666            },
667        )
668}
669
670/// Parse an integer.
671fn integer<'a>() -> impl Parser<'a, ParserInput<'a>, i64, ParserExtra<'a>> + Clone {
672    digits().try_map(|s: &str, span| {
673        s.parse::<i64>()
674            .map_err(|_| Rich::custom(span, "invalid integer"))
675    })
676}
677
// Unit tests for the BQL parser. Each test parses one query string through
// the public `parse` entry point and inspects the resulting AST.
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    // Wildcard target plus a FROM filter expression.
    #[test]
    fn test_simple_select() {
        let query = parse("SELECT * FROM year = 2024").unwrap();
        match query {
            Query::Select(sel) => {
                assert!(!sel.distinct);
                assert_eq!(sel.targets.len(), 1);
                assert!(matches!(sel.targets[0].expr, Expr::Wildcard));
                assert!(sel.from.is_some());
            }
            _ => panic!("Expected SELECT query"),
        }
    }

    // Comma-separated column targets keep their order and names.
    #[test]
    fn test_select_columns() {
        let query = parse("SELECT date, account, position").unwrap();
        match query {
            Query::Select(sel) => {
                assert_eq!(sel.targets.len(), 3);
                assert!(matches!(&sel.targets[0].expr, Expr::Column(c) if c == "date"));
                assert!(matches!(&sel.targets[1].expr, Expr::Column(c) if c == "account"));
                assert!(matches!(&sel.targets[2].expr, Expr::Column(c) if c == "position"));
            }
            _ => panic!("Expected SELECT query"),
        }
    }

    // A function-call target with an `AS` alias.
    #[test]
    fn test_select_with_alias() {
        let query = parse("SELECT SUM(position) AS total").unwrap();
        match query {
            Query::Select(sel) => {
                assert_eq!(sel.targets.len(), 1);
                assert_eq!(sel.targets[0].alias, Some("total".to_string()));
                match &sel.targets[0].expr {
                    Expr::Function(f) => {
                        assert_eq!(f.name, "SUM");
                        assert_eq!(f.args.len(), 1);
                    }
                    _ => panic!("Expected function"),
                }
            }
            _ => panic!("Expected SELECT query"),
        }
    }

    // `DISTINCT` sets the flag on the query.
    #[test]
    fn test_select_distinct() {
        let query = parse("SELECT DISTINCT account").unwrap();
        match query {
            Query::Select(sel) => {
                assert!(sel.distinct);
            }
            _ => panic!("Expected SELECT query"),
        }
    }

    // WHERE with the `~` regex-match operator.
    #[test]
    fn test_where_clause() {
        let query = parse("SELECT * WHERE account ~ \"Expenses:\"").unwrap();
        match query {
            Query::Select(sel) => {
                assert!(sel.where_clause.is_some());
                match sel.where_clause.unwrap() {
                    Expr::BinaryOp(op) => {
                        assert_eq!(op.op, BinaryOperator::Regex);
                    }
                    _ => panic!("Expected binary op"),
                }
            }
            _ => panic!("Expected SELECT query"),
        }
    }

    // GROUP BY with a single key.
    #[test]
    fn test_group_by() {
        let query = parse("SELECT account, SUM(position) GROUP BY account").unwrap();
        match query {
            Query::Select(sel) => {
                assert!(sel.group_by.is_some());
                assert_eq!(sel.group_by.unwrap().len(), 1);
            }
            _ => panic!("Expected SELECT query"),
        }
    }

    // ORDER BY with two specs and explicit directions.
    #[test]
    fn test_order_by() {
        let query = parse("SELECT * ORDER BY date DESC, account ASC").unwrap();
        match query {
            Query::Select(sel) => {
                assert!(sel.order_by.is_some());
                let order = sel.order_by.unwrap();
                assert_eq!(order.len(), 2);
                assert_eq!(order[0].direction, SortDirection::Desc);
                assert_eq!(order[1].direction, SortDirection::Asc);
            }
            _ => panic!("Expected SELECT query"),
        }
    }

    // LIMIT with an integer row count.
    #[test]
    fn test_limit() {
        let query = parse("SELECT * LIMIT 100").unwrap();
        match query {
            Query::Select(sel) => {
                assert_eq!(sel.limit, Some(100));
            }
            _ => panic!("Expected SELECT query"),
        }
    }

    // All three keyword FROM modifiers together, in their required order.
    #[test]
    fn test_from_open_close_clear() {
        let query = parse("SELECT * FROM OPEN ON 2024-01-01 CLOSE ON 2024-12-31 CLEAR").unwrap();
        match query {
            Query::Select(sel) => {
                let from = sel.from.unwrap();
                assert_eq!(
                    from.open_on,
                    Some(NaiveDate::from_ymd_opt(2024, 1, 1).unwrap())
                );
                assert_eq!(
                    from.close_on,
                    Some(NaiveDate::from_ymd_opt(2024, 12, 31).unwrap())
                );
                assert!(from.clear);
            }
            _ => panic!("Expected SELECT query"),
        }
    }

    // JOURNAL with an account pattern and an AT function.
    #[test]
    fn test_journal_query() {
        let query = parse("JOURNAL \"Assets:Bank\" AT cost").unwrap();
        match query {
            Query::Journal(j) => {
                assert_eq!(j.account_pattern, "Assets:Bank");
                assert_eq!(j.at_function, Some("cost".to_string()));
            }
            _ => panic!("Expected JOURNAL query"),
        }
    }

    // BALANCES with an AT function and a FROM filter.
    #[test]
    fn test_balances_query() {
        let query = parse("BALANCES AT units FROM year = 2024").unwrap();
        match query {
            Query::Balances(b) => {
                assert_eq!(b.at_function, Some("units".to_string()));
                assert!(b.from.is_some());
            }
            _ => panic!("Expected BALANCES query"),
        }
    }

    // A bare PRINT with no FROM clause.
    #[test]
    fn test_print_query() {
        let query = parse("PRINT").unwrap();
        assert!(matches!(query, Query::Print(_)));
    }

    // AND binds looser than the comparisons on either side of it.
    #[test]
    fn test_complex_expression() {
        let query = parse("SELECT * WHERE date >= 2024-01-01 AND account ~ \"Expenses:\"").unwrap();
        match query {
            Query::Select(sel) => match sel.where_clause.unwrap() {
                Expr::BinaryOp(op) => {
                    assert_eq!(op.op, BinaryOperator::And);
                }
                _ => panic!("Expected AND"),
            },
            _ => panic!("Expected SELECT query"),
        }
    }

    // A plain integer on the right-hand side parses as a number literal.
    #[test]
    fn test_number_literal() {
        let query = parse("SELECT * WHERE year = 2024").unwrap();
        match query {
            Query::Select(sel) => match sel.where_clause.unwrap() {
                Expr::BinaryOp(op) => match op.right {
                    Expr::Literal(Literal::Number(n)) => {
                        assert_eq!(n, dec!(2024));
                    }
                    _ => panic!("Expected number literal"),
                },
                _ => panic!("Expected binary op"),
            },
            _ => panic!("Expected SELECT query"),
        }
    }

    // The trailing semicolon may be present or absent.
    #[test]
    fn test_semicolon_optional() {
        assert!(parse("SELECT *").is_ok());
        assert!(parse("SELECT *;").is_ok());
    }

    // FROM ( SELECT ... ) produces a subquery in the FROM clause.
    #[test]
    fn test_subquery_basic() {
        let query = parse("SELECT * FROM (SELECT account, position)").unwrap();
        match query {
            Query::Select(sel) => {
                assert!(sel.from.is_some());
                let from = sel.from.unwrap();
                assert!(from.subquery.is_some());
                let subquery = from.subquery.unwrap();
                assert_eq!(subquery.targets.len(), 2);
            }
            _ => panic!("Expected SELECT query"),
        }
    }

    // Clauses inside the subquery belong to the subquery.
    #[test]
    fn test_subquery_with_groupby() {
        let query = parse(
            "SELECT account, total FROM (SELECT account, SUM(position) AS total GROUP BY account)",
        )
        .unwrap();
        match query {
            Query::Select(sel) => {
                assert_eq!(sel.targets.len(), 2);
                let from = sel.from.unwrap();
                assert!(from.subquery.is_some());
                let subquery = from.subquery.unwrap();
                assert!(subquery.group_by.is_some());
            }
            _ => panic!("Expected SELECT query"),
        }
    }

    // Inner and outer WHERE clauses are kept separate.
    #[test]
    fn test_subquery_with_outer_where() {
        let query =
            parse("SELECT * FROM (SELECT * WHERE year = 2024) WHERE account ~ \"Expenses:\"")
                .unwrap();
        match query {
            Query::Select(sel) => {
                // Outer WHERE
                assert!(sel.where_clause.is_some());
                // Subquery with its own WHERE
                let from = sel.from.unwrap();
                let subquery = from.subquery.unwrap();
                assert!(subquery.where_clause.is_some());
            }
            _ => panic!("Expected SELECT query"),
        }
    }

    // Subqueries can nest inside subqueries.
    #[test]
    fn test_nested_subquery() {
        // Two levels of nesting
        let query = parse("SELECT * FROM (SELECT * FROM (SELECT account))").unwrap();
        match query {
            Query::Select(sel) => {
                let from = sel.from.unwrap();
                let subquery1 = from.subquery.unwrap();
                let from2 = subquery1.from.unwrap();
                assert!(from2.subquery.is_some());
            }
            _ => panic!("Expected SELECT query"),
        }
    }
}