1use chumsky::prelude::*;
6use rust_decimal::Decimal;
7use std::str::FromStr;
8
9use crate::ast::{
10 BalancesQuery, BinaryOperator, ColumnDef, CreateTableStmt, Expr, FromClause, FunctionCall,
11 InsertSource, InsertStmt, JournalQuery, Literal, OrderSpec, PrintQuery, Query, SelectQuery,
12 SortDirection, Target, UnaryOperator, WindowFunction, WindowSpec,
13};
14use crate::error::{ParseError, ParseErrorKind};
15use rustledger_core::NaiveDate;
16
/// Input consumed by every parser in this module: the raw query text.
type ParserInput<'a> = &'a str;
/// Parser configuration: failures are reported as chumsky `Rich` errors over `char` tokens.
type ParserExtra<'a> = extra::Err<Rich<'a, char>>;
19
20enum ComparisonSuffix {
22 Between(Expr, Expr),
23 Binary(BinaryOperator, Expr),
24}
25
26pub fn parse(source: &str) -> Result<Query, ParseError> {
32 let (result, errs) = query_parser()
33 .then_ignore(ws())
34 .then_ignore(end())
35 .parse(source)
36 .into_output_errors();
37
38 if let Some(query) = result {
39 Ok(query)
40 } else {
41 let err = errs.first().map(|e| {
42 let kind = if e.found().is_none() {
43 ParseErrorKind::UnexpectedEof
44 } else {
45 ParseErrorKind::SyntaxError(e.to_string())
46 };
47 ParseError::new(kind, e.span().start)
48 });
49 Err(err.unwrap_or_else(|| ParseError::new(ParseErrorKind::UnexpectedEof, 0)))
50 }
51}
52
53fn ws<'a>() -> impl Parser<'a, ParserInput<'a>, (), ParserExtra<'a>> + Clone {
55 one_of(" \t\r\n").repeated().ignored()
56}
57
58fn ws1<'a>() -> impl Parser<'a, ParserInput<'a>, (), ParserExtra<'a>> + Clone {
60 one_of(" \t\r\n").repeated().at_least(1).ignored()
61}
62
63fn kw<'a>(keyword: &'static str) -> impl Parser<'a, ParserInput<'a>, (), ParserExtra<'a>> + Clone {
65 text::ident().try_map(move |s: &str, span| {
66 if s.eq_ignore_ascii_case(keyword) {
67 Ok(())
68 } else {
69 Err(Rich::custom(span, format!("expected keyword '{keyword}'")))
70 }
71 })
72}
73
74fn digits<'a>() -> impl Parser<'a, ParserInput<'a>, &'a str, ParserExtra<'a>> + Clone {
76 one_of("0123456789").repeated().at_least(1).to_slice()
77}
78
79fn query_parser<'a>() -> impl Parser<'a, ParserInput<'a>, Query, ParserExtra<'a>> {
81 ws().ignore_then(choice((
82 create_table_stmt().map(Query::CreateTable),
83 insert_stmt().map(Query::Insert),
84 select_query().map(|sq| Query::Select(Box::new(sq))),
85 journal_query().map(Query::Journal),
86 balances_query().map(Query::Balances),
87 print_query().map(Query::Print),
88 )))
89 .then_ignore(ws())
90 .then_ignore(just(';').or_not())
91}
92
93fn select_query<'a>() -> impl Parser<'a, ParserInput<'a>, SelectQuery, ParserExtra<'a>> {
95 recursive(|select_parser| {
96 let subquery_from = ws1()
98 .ignore_then(kw("FROM"))
99 .ignore_then(ws1())
100 .ignore_then(just('('))
101 .ignore_then(ws())
102 .ignore_then(select_parser)
103 .then_ignore(ws())
104 .then_ignore(just(')'))
105 .map(|sq| Some(FromClause::from_subquery(sq)));
106
107 let table_from = ws1()
110 .ignore_then(kw("FROM"))
111 .ignore_then(ws1())
112 .ignore_then(identifier().try_map(|name, span| {
113 if name.contains(':') {
116 Err(Rich::custom(
117 span,
118 "table names cannot contain ':' - this looks like an account filter expression",
119 ))
120 } else {
121 Ok(name)
122 }
123 }))
124 .then_ignore(
125 ws().then(choice((
127 kw("WHERE").ignored(),
128 kw("GROUP").ignored(),
129 kw("ORDER").ignored(),
130 kw("HAVING").ignored(),
131 kw("LIMIT").ignored(),
132 kw("PIVOT").ignored(),
133 end().ignored(),
134 )))
135 .rewind(),
136 )
137 .map(|name| Some(FromClause::from_table(name)));
138
139 let regular_from = from_clause().map(Some);
141
142 kw("SELECT")
143 .ignore_then(ws1())
144 .ignore_then(
145 kw("DISTINCT")
146 .then_ignore(ws1())
147 .or_not()
148 .map(|d| d.is_some()),
149 )
150 .then(targets())
151 .then(
152 subquery_from
153 .or(table_from)
154 .or(regular_from)
155 .or_not()
156 .map(std::option::Option::flatten),
157 )
158 .then(where_clause().or_not())
159 .then(group_by_clause().or_not())
160 .then(having_clause().or_not())
161 .then(pivot_by_clause().or_not())
162 .then(order_by_clause().or_not())
163 .then(limit_clause().or_not())
164 .map(
165 |(
166 (
167 (
168 (((((distinct, targets), from), where_clause), group_by), having),
169 pivot_by,
170 ),
171 order_by,
172 ),
173 limit,
174 )| {
175 SelectQuery {
176 distinct,
177 targets,
178 from,
179 where_clause,
180 group_by,
181 having,
182 pivot_by,
183 order_by,
184 limit,
185 }
186 },
187 )
188 })
189}
190
191fn from_clause<'a>() -> impl Parser<'a, ParserInput<'a>, FromClause, ParserExtra<'a>> + Clone {
193 ws1()
194 .ignore_then(kw("FROM"))
195 .ignore_then(ws1())
196 .ignore_then(from_modifiers())
197}
198
199fn targets<'a>() -> impl Parser<'a, ParserInput<'a>, Vec<Target>, ParserExtra<'a>> + Clone {
201 target()
202 .separated_by(ws().then(just(',')).then(ws()))
203 .at_least(1)
204 .collect()
205}
206
207fn target<'a>() -> impl Parser<'a, ParserInput<'a>, Target, ParserExtra<'a>> + Clone {
209 expr()
210 .then(
211 ws1()
212 .ignore_then(kw("AS"))
213 .ignore_then(ws1())
214 .ignore_then(identifier())
215 .or_not(),
216 )
217 .map(|(expr, alias)| Target { expr, alias })
218}
219
220fn from_modifiers<'a>() -> impl Parser<'a, ParserInput<'a>, FromClause, ParserExtra<'a>> + Clone {
222 let open_on = kw("OPEN")
223 .ignore_then(ws1())
224 .ignore_then(kw("ON"))
225 .ignore_then(ws1())
226 .ignore_then(date_literal())
227 .then_ignore(ws());
228
229 let close_on = kw("CLOSE")
230 .ignore_then(ws().then(kw("ON")).then(ws()).or_not())
231 .ignore_then(date_literal())
232 .then_ignore(ws());
233
234 let clear = kw("CLEAR").then_ignore(ws());
235
236 open_on
239 .or_not()
240 .then(close_on.or_not())
241 .then(clear.or_not().map(|c| c.is_some()))
242 .then(from_filter().or_not())
243 .map(|(((open_on, close_on), clear), filter)| FromClause {
244 open_on,
245 close_on,
246 clear,
247 filter,
248 subquery: None,
249 table_name: None,
250 })
251}
252
/// Parses the filter expression of a `FROM` clause; currently this is any `expr()`.
fn from_filter<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
    expr()
}
257
258fn where_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
260 ws1()
261 .ignore_then(kw("WHERE"))
262 .ignore_then(ws1())
263 .ignore_then(expr())
264}
265
266fn group_by_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Vec<Expr>, ParserExtra<'a>> + Clone {
268 ws1()
269 .ignore_then(kw("GROUP"))
270 .ignore_then(ws1())
271 .ignore_then(kw("BY"))
272 .ignore_then(ws1())
273 .ignore_then(
274 expr()
275 .separated_by(ws().then(just(',')).then(ws()))
276 .at_least(1)
277 .collect(),
278 )
279}
280
281fn having_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
283 ws1()
284 .ignore_then(kw("HAVING"))
285 .ignore_then(ws1())
286 .ignore_then(expr())
287}
288
289fn pivot_by_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Vec<Expr>, ParserExtra<'a>> + Clone {
291 ws1()
292 .ignore_then(kw("PIVOT"))
293 .ignore_then(ws1())
294 .ignore_then(kw("BY"))
295 .ignore_then(ws1())
296 .ignore_then(
297 expr()
298 .separated_by(ws().then(just(',')).then(ws()))
299 .at_least(1)
300 .collect(),
301 )
302}
303
304fn order_by_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Vec<OrderSpec>, ParserExtra<'a>> + Clone
306{
307 ws1()
308 .ignore_then(kw("ORDER"))
309 .ignore_then(ws1())
310 .ignore_then(kw("BY"))
311 .ignore_then(ws1())
312 .ignore_then(
313 order_spec()
314 .separated_by(ws().then(just(',')).then(ws()))
315 .at_least(1)
316 .collect(),
317 )
318}
319
320fn order_spec<'a>() -> impl Parser<'a, ParserInput<'a>, OrderSpec, ParserExtra<'a>> + Clone {
322 expr()
323 .then(
324 ws1()
325 .ignore_then(choice((
326 kw("ASC").to(SortDirection::Asc),
327 kw("DESC").to(SortDirection::Desc),
328 )))
329 .or_not(),
330 )
331 .map(|(expr, dir)| OrderSpec {
332 expr,
333 direction: dir.unwrap_or_default(),
334 })
335}
336
337fn limit_clause<'a>() -> impl Parser<'a, ParserInput<'a>, u64, ParserExtra<'a>> + Clone {
339 ws1()
340 .ignore_then(kw("LIMIT"))
341 .ignore_then(ws1())
342 .ignore_then(integer())
343 .map(|n| n as u64)
344}
345
346fn journal_query<'a>() -> impl Parser<'a, ParserInput<'a>, JournalQuery, ParserExtra<'a>> + Clone {
348 kw("JOURNAL")
349 .ignore_then(
350 ws1().ignore_then(string_literal()).or_not(),
352 )
353 .then(at_function().or_not())
354 .then(
355 ws1()
356 .ignore_then(kw("FROM"))
357 .ignore_then(ws1())
358 .ignore_then(from_modifiers())
359 .or_not(),
360 )
361 .map(|((account_pattern, at_function), from)| JournalQuery {
362 account_pattern: account_pattern.unwrap_or_default(),
363 at_function,
364 from,
365 })
366}
367
368fn balances_query<'a>() -> impl Parser<'a, ParserInput<'a>, BalancesQuery, ParserExtra<'a>> + Clone
370{
371 kw("BALANCES")
372 .ignore_then(at_function().or_not())
373 .then(
374 ws1()
375 .ignore_then(kw("FROM"))
376 .ignore_then(ws1())
377 .ignore_then(from_modifiers())
378 .or_not(),
379 )
380 .map(|(at_function, from)| BalancesQuery { at_function, from })
381}
382
383fn print_query<'a>() -> impl Parser<'a, ParserInput<'a>, PrintQuery, ParserExtra<'a>> + Clone {
385 kw("PRINT")
386 .ignore_then(
387 ws1()
388 .ignore_then(kw("FROM"))
389 .ignore_then(ws1())
390 .ignore_then(from_modifiers())
391 .or_not(),
392 )
393 .map(|from| PrintQuery { from })
394}
395
396fn create_table_stmt<'a>() -> impl Parser<'a, ParserInput<'a>, CreateTableStmt, ParserExtra<'a>> {
398 let column_def = identifier()
400 .then(ws().ignore_then(identifier()).or_not())
401 .map(|(name, type_hint)| ColumnDef { name, type_hint });
402
403 let column_list = just('(')
404 .ignore_then(ws())
405 .ignore_then(
406 column_def
407 .separated_by(ws().ignore_then(just(',')).then_ignore(ws()))
408 .collect::<Vec<_>>(),
409 )
410 .then_ignore(ws())
411 .then_ignore(just(')'));
412
413 let as_select = ws1()
414 .ignore_then(kw("AS"))
415 .ignore_then(ws1())
416 .ignore_then(select_query())
417 .map(Box::new);
418
419 kw("CREATE")
420 .ignore_then(ws1())
421 .ignore_then(kw("TABLE"))
422 .ignore_then(ws1())
423 .ignore_then(identifier())
424 .then(ws().ignore_then(column_list).or_not())
425 .then(as_select.or_not())
426 .map(|((table_name, columns), as_select)| CreateTableStmt {
427 table_name,
428 columns: columns.unwrap_or_default(),
429 as_select,
430 })
431}
432
433fn insert_stmt<'a>() -> impl Parser<'a, ParserInput<'a>, InsertStmt, ParserExtra<'a>> {
435 let column_list = just('(')
437 .ignore_then(ws())
438 .ignore_then(
439 identifier()
440 .separated_by(ws().ignore_then(just(',')).then_ignore(ws()))
441 .collect::<Vec<_>>(),
442 )
443 .then_ignore(ws())
444 .then_ignore(just(')'));
445
446 let value_row = just('(')
448 .ignore_then(ws())
449 .ignore_then(
450 expr()
451 .separated_by(ws().ignore_then(just(',')).then_ignore(ws()))
452 .collect::<Vec<_>>(),
453 )
454 .then_ignore(ws())
455 .then_ignore(just(')'));
456
457 let values_source = kw("VALUES")
458 .ignore_then(ws())
459 .ignore_then(
460 value_row
461 .separated_by(ws().ignore_then(just(',')).then_ignore(ws()))
462 .collect::<Vec<_>>(),
463 )
464 .map(InsertSource::Values);
465
466 let select_source = select_query().map(|sq| InsertSource::Select(Box::new(sq)));
468
469 let source = choice((values_source, select_source));
470
471 kw("INSERT")
472 .ignore_then(ws1())
473 .ignore_then(kw("INTO"))
474 .ignore_then(ws1())
475 .ignore_then(identifier())
476 .then(ws().ignore_then(column_list).or_not())
477 .then_ignore(ws())
478 .then(source)
479 .map(|((table_name, columns), source)| InsertStmt {
480 table_name,
481 columns,
482 source,
483 })
484}
485
486fn at_function<'a>() -> impl Parser<'a, ParserInput<'a>, String, ParserExtra<'a>> + Clone {
488 ws1()
489 .ignore_then(kw("AT"))
490 .ignore_then(ws1())
491 .ignore_then(identifier())
492}
493
/// Parses a full expression (arithmetic, comparison and boolean logic).
///
/// Precedence, from loosest to tightest binding:
///
/// 1. `OR` (left-associative)
/// 2. `AND` (left-associative)
/// 3. prefix `NOT` (any number, each followed by whitespace)
/// 4. postfix `IS NULL` / `IS NOT NULL`, applied to the result of level 5
/// 5. at most one comparison: `BETWEEN <low> AND <high>`, or one binary
///    operator from `comparison_op()`
/// 6. `+`, `-` (left-associative)
/// 7. `*`, `/`, `%` (left-associative)
/// 8. a single optional prefix `-`
/// 9. primary expressions (see `primary_expr`)
// NOTE(review): the reason for this allow is not visible here; presumably the
// deeply nested combinator types produce a large stack frame. Confirm.
#[allow(clippy::large_stack_frames)]
fn expr<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
    recursive(|expr| {
        let primary = primary_expr(expr.clone());

        // Optional single `-` (negation), then a primary expression.
        let unary = just('-')
            .then_ignore(ws())
            .or_not()
            .then(primary)
            .map(|(neg, e)| {
                if neg.is_some() {
                    Expr::unary(UnaryOperator::Neg, e)
                } else {
                    e
                }
            });

        // `*`, `/`, `%` folded left-to-right over unary operands.
        let multiplicative = unary.clone().foldl(
            ws().ignore_then(choice((
                just('*').to(BinaryOperator::Mul),
                just('/').to(BinaryOperator::Div),
                just('%').to(BinaryOperator::Mod),
            )))
            .then_ignore(ws())
            .then(unary)
            .repeated(),
            |left, (op, right)| Expr::binary(left, op, right),
        );

        // `+`, `-` folded left-to-right over multiplicative operands.
        let additive = multiplicative.clone().foldl(
            ws().ignore_then(choice((
                just('+').to(BinaryOperator::Add),
                just('-').to(BinaryOperator::Sub),
            )))
            .then_ignore(ws())
            .then(multiplicative)
            .repeated(),
            |left, (op, right)| Expr::binary(left, op, right),
        );

        // An additive operand optionally followed by ONE comparison suffix.
        // BETWEEN is tried first; its `AND` is consumed here, so it is not
        // mistaken for the boolean `AND` handled further up.
        let comparison = additive
            .clone()
            .then(
                choice((
                    ws1()
                        .ignore_then(kw("BETWEEN"))
                        .ignore_then(ws1())
                        .ignore_then(additive.clone())
                        .then_ignore(ws1())
                        .then_ignore(kw("AND"))
                        .then_ignore(ws1())
                        .then(additive.clone())
                        .map(|(low, high)| ComparisonSuffix::Between(low, high)),
                    ws()
                        .ignore_then(comparison_op())
                        .then_ignore(ws())
                        .then(additive)
                        .map(|(op, right)| ComparisonSuffix::Binary(op, right)),
                ))
                .or_not(),
            )
            .map(|(left, suffix)| match suffix {
                Some(ComparisonSuffix::Between(low, high)) => Expr::between(left, low, high),
                Some(ComparisonSuffix::Binary(op, right)) => Expr::binary(left, op, right),
                None => left,
            })
            // Optional `IS NULL` / `IS NOT NULL`, wrapping the whole comparison result.
            .then(
                ws1()
                    .ignore_then(kw("IS"))
                    .ignore_then(ws1())
                    .ignore_then(choice((
                        kw("NOT")
                            .ignore_then(ws1())
                            .ignore_then(kw("NULL"))
                            .to(UnaryOperator::IsNotNull),
                        kw("NULL").to(UnaryOperator::IsNull),
                    )))
                    .or_not(),
            )
            .map(|(expr, is_null)| {
                if let Some(op) = is_null {
                    Expr::unary(op, expr)
                } else {
                    expr
                }
            });

        // Zero or more `NOT ` prefixes; each one wraps the operand in a Not node.
        let not_expr = kw("NOT")
            .ignore_then(ws1())
            .repeated()
            .collect::<Vec<_>>()
            .then(comparison)
            .map(|(nots, e)| {
                nots.into_iter()
                    .fold(e, |acc, ()| Expr::unary(UnaryOperator::Not, acc))
            });

        // `AND` binds tighter than `OR`; both fold left-to-right.
        let and_expr = not_expr.clone().foldl(
            ws1()
                .ignore_then(kw("AND"))
                .ignore_then(ws1())
                .ignore_then(not_expr)
                .repeated(),
            |left, right| Expr::binary(left, BinaryOperator::And, right),
        );

        and_expr.clone().foldl(
            ws1()
                .ignore_then(kw("OR"))
                .ignore_then(ws1())
                .ignore_then(and_expr)
                .repeated(),
            |left, right| Expr::binary(left, BinaryOperator::Or, right),
        )
    })
}
621
622fn comparison_op<'a>() -> impl Parser<'a, ParserInput<'a>, BinaryOperator, ParserExtra<'a>> + Clone
624{
625 choice((
626 just("!=").to(BinaryOperator::Ne),
628 just("!~").to(BinaryOperator::NotRegex),
629 just("<=").to(BinaryOperator::Le),
630 just(">=").to(BinaryOperator::Ge),
631 just('=').to(BinaryOperator::Eq),
633 just('<').to(BinaryOperator::Lt),
634 just('>').to(BinaryOperator::Gt),
635 just('~').to(BinaryOperator::Regex),
636 kw("NOT")
638 .ignore_then(ws1())
639 .ignore_then(kw("IN"))
640 .to(BinaryOperator::NotIn),
641 kw("IN").to(BinaryOperator::In),
642 ))
643}
644
645fn primary_expr<'a>(
647 expr: impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone + 'a,
648) -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
649 choice((
650 just('(')
652 .ignore_then(ws())
653 .ignore_then(expr.clone())
654 .then_ignore(ws())
655 .then_ignore(just(')'))
656 .map(|e| Expr::Paren(Box::new(e))),
657 function_call_or_column(expr),
660 literal().map(Expr::Literal),
662 just('*').to(Expr::Wildcard),
664 ))
665}
666
667fn function_call_or_column<'a>(
669 expr: impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone + 'a,
670) -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
671 identifier()
672 .then(
673 ws().ignore_then(just('('))
674 .ignore_then(ws())
675 .ignore_then(function_args(expr))
676 .then_ignore(ws())
677 .then_ignore(just(')'))
678 .or_not(),
679 )
680 .then(
681 ws1()
683 .ignore_then(kw("OVER"))
684 .ignore_then(ws())
685 .ignore_then(just('('))
686 .ignore_then(ws())
687 .ignore_then(window_spec())
688 .then_ignore(ws())
689 .then_ignore(just(')'))
690 .or_not(),
691 )
692 .map(|((name, args), over)| {
693 if let Some(args) = args {
694 if let Some(window_spec) = over {
695 Expr::Window(WindowFunction {
697 name,
698 args,
699 over: window_spec,
700 })
701 } else {
702 Expr::Function(FunctionCall { name, args })
704 }
705 } else {
706 Expr::Column(name)
707 }
708 })
709}
710
711fn window_spec<'a>() -> impl Parser<'a, ParserInput<'a>, WindowSpec, ParserExtra<'a>> + Clone {
713 let partition_by = kw("PARTITION")
714 .ignore_then(ws1())
715 .ignore_then(kw("BY"))
716 .ignore_then(ws1())
717 .ignore_then(
718 simple_arg()
719 .separated_by(ws().then(just(',')).then(ws()))
720 .at_least(1)
721 .collect::<Vec<_>>(),
722 )
723 .then_ignore(ws());
724
725 let window_order_by = kw("ORDER")
726 .ignore_then(ws1())
727 .ignore_then(kw("BY"))
728 .ignore_then(ws1())
729 .ignore_then(
730 window_order_spec()
731 .separated_by(ws().then(just(',')).then(ws()))
732 .at_least(1)
733 .collect::<Vec<_>>(),
734 );
735
736 partition_by
737 .or_not()
738 .then(window_order_by.or_not())
739 .map(|(partition_by, order_by)| WindowSpec {
740 partition_by,
741 order_by,
742 })
743}
744
745fn window_order_spec<'a>() -> impl Parser<'a, ParserInput<'a>, OrderSpec, ParserExtra<'a>> + Clone {
747 simple_arg()
748 .then(
749 ws1()
750 .ignore_then(choice((
751 kw("ASC").to(SortDirection::Asc),
752 kw("DESC").to(SortDirection::Desc),
753 )))
754 .or_not(),
755 )
756 .map(|(expr, dir)| OrderSpec {
757 expr,
758 direction: dir.unwrap_or_default(),
759 })
760}
761
762fn function_args<'a>(
764 expr: impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone + 'a,
765) -> impl Parser<'a, ParserInput<'a>, Vec<Expr>, ParserExtra<'a>> + Clone {
766 expr.separated_by(ws().then(just(',')).then(ws())).collect()
769}
770
771fn simple_arg<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
773 choice((
774 just('*').to(Expr::Wildcard),
775 identifier().map(Expr::Column),
776 literal().map(Expr::Literal),
777 ))
778}
779
780fn literal<'a>() -> impl Parser<'a, ParserInput<'a>, Literal, ParserExtra<'a>> + Clone {
782 choice((
783 kw("TRUE").to(Literal::Boolean(true)),
785 kw("FALSE").to(Literal::Boolean(false)),
786 kw("NULL").to(Literal::Null),
787 date_literal().map(Literal::Date),
789 decimal().map(Literal::Number),
791 string_literal().map(Literal::String),
793 ))
794}
795
796fn identifier<'a>() -> impl Parser<'a, ParserInput<'a>, String, ParserExtra<'a>> + Clone {
798 text::ident().map(|s: &str| s.to_string())
799}
800
801fn string_literal<'a>() -> impl Parser<'a, ParserInput<'a>, String, ParserExtra<'a>> + Clone {
803 let double_quoted = just('"')
805 .ignore_then(
806 none_of("\"\\")
807 .or(just('\\').ignore_then(any()))
808 .repeated()
809 .collect::<String>(),
810 )
811 .then_ignore(just('"'));
812
813 let single_quoted = just('\'')
815 .ignore_then(
816 none_of("'\\")
817 .or(just('\\').ignore_then(any()))
818 .repeated()
819 .collect::<String>(),
820 )
821 .then_ignore(just('\''));
822
823 choice((double_quoted, single_quoted))
824}
825
826fn date_literal<'a>() -> impl Parser<'a, ParserInput<'a>, NaiveDate, ParserExtra<'a>> + Clone {
828 digits()
829 .then_ignore(just('-'))
830 .then(digits())
831 .then_ignore(just('-'))
832 .then(digits())
833 .try_map(|((year, month), day): ((&str, &str), &str), span| {
834 let year: i32 = year
835 .parse()
836 .map_err(|_| Rich::custom(span, "invalid year"))?;
837 let month: u32 = month
838 .parse()
839 .map_err(|_| Rich::custom(span, "invalid month"))?;
840 let day: u32 = day.parse().map_err(|_| Rich::custom(span, "invalid day"))?;
841 NaiveDate::from_ymd_opt(year, month, day)
842 .ok_or_else(|| Rich::custom(span, "invalid date"))
843 })
844}
845
846fn decimal<'a>() -> impl Parser<'a, ParserInput<'a>, Decimal, ParserExtra<'a>> + Clone {
848 just('-')
849 .or_not()
850 .then(digits())
851 .then(just('.').then(digits()).or_not())
852 .try_map(
853 |((neg, int_part), frac_part): ((Option<char>, &str), Option<(char, &str)>), span| {
854 let mut s = String::new();
855 if neg.is_some() {
856 s.push('-');
857 }
858 s.push_str(int_part);
859 if let Some((_, frac)) = frac_part {
860 s.push('.');
861 s.push_str(frac);
862 }
863 Decimal::from_str(&s).map_err(|_| Rich::custom(span, "invalid number"))
864 },
865 )
866}
867
868fn integer<'a>() -> impl Parser<'a, ParserInput<'a>, i64, ParserExtra<'a>> + Clone {
870 digits().try_map(|s: &str, span| {
871 s.parse::<i64>()
872 .map_err(|_| Rich::custom(span, "invalid integer"))
873 })
874}
875
876#[cfg(test)]
877mod tests {
878 use super::*;
879 use rust_decimal_macros::dec;
880
881 #[test]
882 fn test_simple_select() {
883 let query = parse("SELECT * FROM year = 2024").unwrap();
884 match query {
885 Query::Select(sel) => {
886 assert!(!sel.distinct);
887 assert_eq!(sel.targets.len(), 1);
888 assert!(matches!(sel.targets[0].expr, Expr::Wildcard));
889 assert!(sel.from.is_some());
890 }
891 _ => panic!("Expected SELECT query"),
892 }
893 }
894
895 #[test]
896 fn test_select_columns() {
897 let query = parse("SELECT date, account, position").unwrap();
898 match query {
899 Query::Select(sel) => {
900 assert_eq!(sel.targets.len(), 3);
901 assert!(matches!(&sel.targets[0].expr, Expr::Column(c) if c == "date"));
902 assert!(matches!(&sel.targets[1].expr, Expr::Column(c) if c == "account"));
903 assert!(matches!(&sel.targets[2].expr, Expr::Column(c) if c == "position"));
904 }
905 _ => panic!("Expected SELECT query"),
906 }
907 }
908
909 #[test]
910 fn test_select_with_alias() {
911 let query = parse("SELECT SUM(position) AS total").unwrap();
912 match query {
913 Query::Select(sel) => {
914 assert_eq!(sel.targets.len(), 1);
915 assert_eq!(sel.targets[0].alias, Some("total".to_string()));
916 match &sel.targets[0].expr {
917 Expr::Function(f) => {
918 assert_eq!(f.name, "SUM");
919 assert_eq!(f.args.len(), 1);
920 }
921 _ => panic!("Expected function"),
922 }
923 }
924 _ => panic!("Expected SELECT query"),
925 }
926 }
927
928 #[test]
929 fn test_select_distinct() {
930 let query = parse("SELECT DISTINCT account").unwrap();
931 match query {
932 Query::Select(sel) => {
933 assert!(sel.distinct);
934 }
935 _ => panic!("Expected SELECT query"),
936 }
937 }
938
939 #[test]
940 fn test_where_clause() {
941 let query = parse("SELECT * WHERE account ~ \"Expenses:\"").unwrap();
942 match query {
943 Query::Select(sel) => {
944 assert!(sel.where_clause.is_some());
945 match sel.where_clause.unwrap() {
946 Expr::BinaryOp(op) => {
947 assert_eq!(op.op, BinaryOperator::Regex);
948 }
949 _ => panic!("Expected binary op"),
950 }
951 }
952 _ => panic!("Expected SELECT query"),
953 }
954 }
955
956 #[test]
957 fn test_group_by() {
958 let query = parse("SELECT account, SUM(position) GROUP BY account").unwrap();
959 match query {
960 Query::Select(sel) => {
961 assert!(sel.group_by.is_some());
962 assert_eq!(sel.group_by.unwrap().len(), 1);
963 }
964 _ => panic!("Expected SELECT query"),
965 }
966 }
967
968 #[test]
969 fn test_order_by() {
970 let query = parse("SELECT * ORDER BY date DESC, account ASC").unwrap();
971 match query {
972 Query::Select(sel) => {
973 assert!(sel.order_by.is_some());
974 let order = sel.order_by.unwrap();
975 assert_eq!(order.len(), 2);
976 assert_eq!(order[0].direction, SortDirection::Desc);
977 assert_eq!(order[1].direction, SortDirection::Asc);
978 }
979 _ => panic!("Expected SELECT query"),
980 }
981 }
982
983 #[test]
984 fn test_limit() {
985 let query = parse("SELECT * LIMIT 100").unwrap();
986 match query {
987 Query::Select(sel) => {
988 assert_eq!(sel.limit, Some(100));
989 }
990 _ => panic!("Expected SELECT query"),
991 }
992 }
993
994 #[test]
995 fn test_from_open_close_clear() {
996 let query = parse("SELECT * FROM OPEN ON 2024-01-01 CLOSE ON 2024-12-31 CLEAR").unwrap();
997 match query {
998 Query::Select(sel) => {
999 let from = sel.from.unwrap();
1000 assert_eq!(
1001 from.open_on,
1002 Some(NaiveDate::from_ymd_opt(2024, 1, 1).unwrap())
1003 );
1004 assert_eq!(
1005 from.close_on,
1006 Some(NaiveDate::from_ymd_opt(2024, 12, 31).unwrap())
1007 );
1008 assert!(from.clear);
1009 }
1010 _ => panic!("Expected SELECT query"),
1011 }
1012 }
1013
1014 #[test]
1015 fn test_from_year_filter() {
1016 let query = parse("SELECT date, account FROM year = 2024").unwrap();
1017 match query {
1018 Query::Select(sel) => {
1019 let from = sel.from.unwrap();
1020 assert!(from.filter.is_some(), "FROM filter should be present");
1021 match from.filter.unwrap() {
1022 Expr::BinaryOp(op) => {
1023 assert_eq!(op.op, BinaryOperator::Eq);
1024 assert!(matches!(op.left, Expr::Column(ref c) if c == "year"));
1025 match op.right {
1027 Expr::Literal(Literal::Integer(n)) => assert_eq!(n, 2024),
1028 Expr::Literal(Literal::Number(n)) => assert_eq!(n, dec!(2024)),
1029 other => panic!("Expected numeric literal, got {other:?}"),
1030 }
1031 }
1032 other => panic!("Expected BinaryOp, got {other:?}"),
1033 }
1034 }
1035 _ => panic!("Expected SELECT query"),
1036 }
1037 }
1038
1039 #[test]
1040 fn test_journal_query() {
1041 let query = parse("JOURNAL \"Assets:Bank\" AT cost").unwrap();
1042 match query {
1043 Query::Journal(j) => {
1044 assert_eq!(j.account_pattern, "Assets:Bank");
1045 assert_eq!(j.at_function, Some("cost".to_string()));
1046 }
1047 _ => panic!("Expected JOURNAL query"),
1048 }
1049 }
1050
1051 #[test]
1052 fn test_balances_query() {
1053 let query = parse("BALANCES AT units FROM year = 2024").unwrap();
1054 match query {
1055 Query::Balances(b) => {
1056 assert_eq!(b.at_function, Some("units".to_string()));
1057 assert!(b.from.is_some());
1058 }
1059 _ => panic!("Expected BALANCES query"),
1060 }
1061 }
1062
1063 #[test]
1064 fn test_print_query() {
1065 let query = parse("PRINT").unwrap();
1066 assert!(matches!(query, Query::Print(_)));
1067 }
1068
1069 #[test]
1070 fn test_complex_expression() {
1071 let query = parse("SELECT * WHERE date >= 2024-01-01 AND account ~ \"Expenses:\"").unwrap();
1072 match query {
1073 Query::Select(sel) => match sel.where_clause.unwrap() {
1074 Expr::BinaryOp(op) => {
1075 assert_eq!(op.op, BinaryOperator::And);
1076 }
1077 _ => panic!("Expected AND"),
1078 },
1079 _ => panic!("Expected SELECT query"),
1080 }
1081 }
1082
1083 #[test]
1084 fn test_number_literal() {
1085 let query = parse("SELECT * WHERE year = 2024").unwrap();
1086 match query {
1087 Query::Select(sel) => match sel.where_clause.unwrap() {
1088 Expr::BinaryOp(op) => match op.right {
1089 Expr::Literal(Literal::Number(n)) => {
1090 assert_eq!(n, dec!(2024));
1091 }
1092 _ => panic!("Expected number literal"),
1093 },
1094 _ => panic!("Expected binary op"),
1095 },
1096 _ => panic!("Expected SELECT query"),
1097 }
1098 }
1099
1100 #[test]
1101 fn test_semicolon_optional() {
1102 assert!(parse("SELECT *").is_ok());
1103 assert!(parse("SELECT *;").is_ok());
1104 }
1105
1106 #[test]
1107 fn test_subquery_basic() {
1108 let query = parse("SELECT * FROM (SELECT account, position)").unwrap();
1109 match query {
1110 Query::Select(sel) => {
1111 assert!(sel.from.is_some());
1112 let from = sel.from.unwrap();
1113 assert!(from.subquery.is_some());
1114 let subquery = from.subquery.unwrap();
1115 assert_eq!(subquery.targets.len(), 2);
1116 }
1117 _ => panic!("Expected SELECT query"),
1118 }
1119 }
1120
1121 #[test]
1122 fn test_subquery_with_groupby() {
1123 let query = parse(
1124 "SELECT account, total FROM (SELECT account, SUM(position) AS total GROUP BY account)",
1125 )
1126 .unwrap();
1127 match query {
1128 Query::Select(sel) => {
1129 assert_eq!(sel.targets.len(), 2);
1130 let from = sel.from.unwrap();
1131 assert!(from.subquery.is_some());
1132 let subquery = from.subquery.unwrap();
1133 assert!(subquery.group_by.is_some());
1134 }
1135 _ => panic!("Expected SELECT query"),
1136 }
1137 }
1138
1139 #[test]
1140 fn test_subquery_with_outer_where() {
1141 let query =
1142 parse("SELECT * FROM (SELECT * WHERE year = 2024) WHERE account ~ \"Expenses:\"")
1143 .unwrap();
1144 match query {
1145 Query::Select(sel) => {
1146 assert!(sel.where_clause.is_some());
1148 let from = sel.from.unwrap();
1150 let subquery = from.subquery.unwrap();
1151 assert!(subquery.where_clause.is_some());
1152 }
1153 _ => panic!("Expected SELECT query"),
1154 }
1155 }
1156
1157 #[test]
1158 fn test_nested_subquery() {
1159 let query = parse("SELECT * FROM (SELECT * FROM (SELECT account))").unwrap();
1161 match query {
1162 Query::Select(sel) => {
1163 let from = sel.from.unwrap();
1164 let subquery1 = from.subquery.unwrap();
1165 let from2 = subquery1.from.unwrap();
1166 assert!(from2.subquery.is_some());
1167 }
1168 _ => panic!("Expected SELECT query"),
1169 }
1170 }
1171
1172 #[test]
1173 fn test_nested_function_calls() {
1174 let query = parse("SELECT units(sum(position))").unwrap();
1176 match query {
1177 Query::Select(sel) => {
1178 assert_eq!(sel.targets.len(), 1);
1179 match &sel.targets[0].expr {
1180 Expr::Function(outer) => {
1181 assert_eq!(outer.name, "units");
1182 assert_eq!(outer.args.len(), 1);
1183 match &outer.args[0] {
1184 Expr::Function(inner) => {
1185 assert_eq!(inner.name, "sum");
1186 assert_eq!(inner.args.len(), 1);
1187 assert!(
1188 matches!(&inner.args[0], Expr::Column(c) if c == "position")
1189 );
1190 }
1191 _ => panic!("Expected inner function call"),
1192 }
1193 }
1194 _ => panic!("Expected outer function call"),
1195 }
1196 }
1197 _ => panic!("Expected SELECT query"),
1198 }
1199 }
1200
1201 #[test]
1202 fn test_deeply_nested_function_calls() {
1203 let query = parse("SELECT foo(bar(baz(x)))").unwrap();
1205 match query {
1206 Query::Select(sel) => {
1207 assert_eq!(sel.targets.len(), 1);
1208 match &sel.targets[0].expr {
1209 Expr::Function(f1) => {
1210 assert_eq!(f1.name, "foo");
1211 match &f1.args[0] {
1212 Expr::Function(f2) => {
1213 assert_eq!(f2.name, "bar");
1214 match &f2.args[0] {
1215 Expr::Function(f3) => {
1216 assert_eq!(f3.name, "baz");
1217 assert!(matches!(&f3.args[0], Expr::Column(c) if c == "x"));
1218 }
1219 _ => panic!("Expected f3"),
1220 }
1221 }
1222 _ => panic!("Expected f2"),
1223 }
1224 }
1225 _ => panic!("Expected f1"),
1226 }
1227 }
1228 _ => panic!("Expected SELECT query"),
1229 }
1230 }
1231
1232 #[test]
1233 fn test_function_with_arithmetic() {
1234 let query = parse("SELECT sum(amount * 2)").unwrap();
1236 match query {
1237 Query::Select(sel) => match &sel.targets[0].expr {
1238 Expr::Function(f) => {
1239 assert_eq!(f.name, "sum");
1240 assert!(matches!(&f.args[0], Expr::BinaryOp(_)));
1241 }
1242 _ => panic!("Expected function"),
1243 },
1244 _ => panic!("Expected SELECT query"),
1245 }
1246 }
1247
1248 #[test]
1249 fn test_is_null() {
1250 let query = parse("SELECT * WHERE payee IS NULL").unwrap();
1251 match query {
1252 Query::Select(sel) => match sel.where_clause.unwrap() {
1253 Expr::UnaryOp(op) => {
1254 assert_eq!(op.op, UnaryOperator::IsNull);
1255 assert!(matches!(&op.operand, Expr::Column(c) if c == "payee"));
1256 }
1257 _ => panic!("Expected unary op"),
1258 },
1259 _ => panic!("Expected SELECT query"),
1260 }
1261 }
1262
1263 #[test]
1264 fn test_is_not_null() {
1265 let query = parse("SELECT * WHERE payee IS NOT NULL").unwrap();
1266 match query {
1267 Query::Select(sel) => match sel.where_clause.unwrap() {
1268 Expr::UnaryOp(op) => {
1269 assert_eq!(op.op, UnaryOperator::IsNotNull);
1270 assert!(matches!(&op.operand, Expr::Column(c) if c == "payee"));
1271 }
1272 _ => panic!("Expected unary op"),
1273 },
1274 _ => panic!("Expected SELECT query"),
1275 }
1276 }
1277
1278 #[test]
1279 fn test_not_regex() {
1280 let query = parse("SELECT * WHERE account !~ \"Assets:\"").unwrap();
1281 match query {
1282 Query::Select(sel) => match sel.where_clause.unwrap() {
1283 Expr::BinaryOp(op) => {
1284 assert_eq!(op.op, BinaryOperator::NotRegex);
1285 }
1286 _ => panic!("Expected binary op"),
1287 },
1288 _ => panic!("Expected SELECT query"),
1289 }
1290 }
1291
1292 #[test]
1293 fn test_modulo() {
1294 let query = parse("SELECT year % 4").unwrap();
1295 match query {
1296 Query::Select(sel) => match &sel.targets[0].expr {
1297 Expr::BinaryOp(op) => {
1298 assert_eq!(op.op, BinaryOperator::Mod);
1299 }
1300 _ => panic!("Expected binary op"),
1301 },
1302 _ => panic!("Expected SELECT query"),
1303 }
1304 }
1305
1306 #[test]
1307 fn test_between() {
1308 let query = parse("SELECT * WHERE year BETWEEN 2020 AND 2024").unwrap();
1309 match query {
1310 Query::Select(sel) => match sel.where_clause.unwrap() {
1311 Expr::Between { value, low, high } => {
1312 assert!(matches!(*value, Expr::Column(c) if c == "year"));
1313 assert!(matches!(*low, Expr::Literal(Literal::Number(_))));
1314 assert!(matches!(*high, Expr::Literal(Literal::Number(_))));
1315 }
1316 _ => panic!("Expected BETWEEN"),
1317 },
1318 _ => panic!("Expected SELECT query"),
1319 }
1320 }
1321
1322 #[test]
1323 fn test_not_in() {
1324 let query = parse("SELECT * WHERE account NOT IN tags").unwrap();
1325 match query {
1326 Query::Select(sel) => match sel.where_clause.unwrap() {
1327 Expr::BinaryOp(op) => {
1328 assert_eq!(op.op, BinaryOperator::NotIn);
1329 }
1330 _ => panic!("Expected binary op"),
1331 },
1332 _ => panic!("Expected SELECT query"),
1333 }
1334 }
1335
1336 #[test]
1337 fn test_string_arg_function() {
1338 let query = parse("SELECT foo(x)").unwrap();
1340 match query {
1341 Query::Select(sel) => match &sel.targets[0].expr {
1342 Expr::Function(f) => {
1343 assert_eq!(f.name, "foo");
1344 }
1345 _ => panic!("Expected function"),
1346 },
1347 _ => panic!("Expected SELECT query"),
1348 }
1349
1350 let query = parse("SELECT foo('bar')").unwrap();
1352 match query {
1353 Query::Select(sel) => match &sel.targets[0].expr {
1354 Expr::Function(f) => {
1355 assert_eq!(f.name, "foo");
1356 assert!(matches!(&f.args[0], Expr::Literal(Literal::String(s)) if s == "bar"));
1357 }
1358 _ => panic!("Expected function"),
1359 },
1360 _ => panic!("Expected SELECT query"),
1361 }
1362 }
1363
1364 #[test]
1365 fn test_meta_function() {
1366 let query = parse("SELECT meta('category')").unwrap();
1367 match query {
1368 Query::Select(sel) => match &sel.targets[0].expr {
1369 Expr::Function(f) => {
1370 assert_eq!(f.name.to_uppercase(), "META");
1371 assert_eq!(f.args.len(), 1);
1372 assert!(
1373 matches!(&f.args[0], Expr::Literal(Literal::String(s)) if s == "category")
1374 );
1375 }
1376 _ => panic!("Expected function"),
1377 },
1378 _ => panic!("Expected SELECT query"),
1379 }
1380 }
1381
1382 #[test]
1383 fn test_entry_meta_function() {
1384 let query = parse("SELECT entry_meta('source')").unwrap();
1385 match query {
1386 Query::Select(sel) => match &sel.targets[0].expr {
1387 Expr::Function(f) => {
1388 assert_eq!(f.name.to_uppercase(), "ENTRY_META");
1389 assert_eq!(f.args.len(), 1);
1390 }
1391 _ => panic!("Expected function"),
1392 },
1393 _ => panic!("Expected SELECT query"),
1394 }
1395 }
1396
1397 #[test]
1398 fn test_convert_function() {
1399 let query = parse("SELECT convert(position, 'USD')").unwrap();
1400 match query {
1401 Query::Select(sel) => match &sel.targets[0].expr {
1402 Expr::Function(f) => {
1403 assert_eq!(f.name.to_uppercase(), "CONVERT");
1404 assert_eq!(f.args.len(), 2);
1405 }
1406 _ => panic!("Expected function"),
1407 },
1408 _ => panic!("Expected SELECT query"),
1409 }
1410 }
1411
1412 #[test]
1413 fn test_type_cast_functions() {
1414 let query = parse("SELECT int(number)").unwrap();
1416 match query {
1417 Query::Select(sel) => match &sel.targets[0].expr {
1418 Expr::Function(f) => {
1419 assert_eq!(f.name.to_uppercase(), "INT");
1420 assert_eq!(f.args.len(), 1);
1421 }
1422 _ => panic!("Expected function"),
1423 },
1424 _ => panic!("Expected SELECT query"),
1425 }
1426
1427 let query = parse("SELECT decimal('123.45')").unwrap();
1429 match query {
1430 Query::Select(sel) => match &sel.targets[0].expr {
1431 Expr::Function(f) => {
1432 assert_eq!(f.name.to_uppercase(), "DECIMAL");
1433 }
1434 _ => panic!("Expected function"),
1435 },
1436 _ => panic!("Expected SELECT query"),
1437 }
1438
1439 let query = parse("SELECT str(123)").unwrap();
1441 match query {
1442 Query::Select(sel) => match &sel.targets[0].expr {
1443 Expr::Function(f) => {
1444 assert_eq!(f.name.to_uppercase(), "STR");
1445 }
1446 _ => panic!("Expected function"),
1447 },
1448 _ => panic!("Expected SELECT query"),
1449 }
1450
1451 let query = parse("SELECT bool(1)").unwrap();
1453 match query {
1454 Query::Select(sel) => match &sel.targets[0].expr {
1455 Expr::Function(f) => {
1456 assert_eq!(f.name.to_uppercase(), "BOOL");
1457 }
1458 _ => panic!("Expected function"),
1459 },
1460 _ => panic!("Expected SELECT query"),
1461 }
1462 }
1463}