1pub(crate) mod locy_parser;
2mod locy_walker;
3mod walker;
4
5use crate::ast::{Expr, Query};
6use crate::locy_ast::LocyProgram;
7use pest::Parser;
8use pest_derive::Parser;
9
10#[derive(Debug, thiserror::Error)]
12#[error("{message}")]
13pub struct ParseError {
14 message: String,
15}
16
17impl ParseError {
18 pub fn new(message: String) -> Self {
19 Self { message }
20 }
21}
22
/// Pest parser for Cypher, generated from `grammar/cypher.pest`.
///
/// The derive also generates the `Rule` enum used throughout this module
/// (e.g. `Rule::query`, `Rule::expression`).
#[derive(Parser)]
#[grammar = "grammar/cypher.pest"]
pub struct CypherParser;
26
27const MAX_NESTING_DEPTH: u32 = 200;
38
39fn check_nesting_depth(input: &str) -> Result<(), ParseError> {
52 let bytes = input.as_bytes();
53 let mut i = 0usize;
54 let mut depth: i32 = 0;
55 let mut max_depth: i32 = 0;
56
57 while i < bytes.len() {
58 match bytes[i] {
59 quote @ (b'\'' | b'"') => {
60 i += 1;
62 while i < bytes.len() {
63 match bytes[i] {
64 b'\\' => i += 2,
65 c if c == quote => {
66 i += 1;
67 break;
68 }
69 _ => i += 1,
70 }
71 }
72 }
73 b'`' => {
74 i += 1;
76 while i < bytes.len() && bytes[i] != b'`' {
77 i += 1;
78 }
79 i += 1;
80 }
81 b'/' if bytes.get(i + 1) == Some(&b'/') => {
82 i += 2;
83 while i < bytes.len() && bytes[i] != b'\n' {
84 i += 1;
85 }
86 }
87 b'/' if bytes.get(i + 1) == Some(&b'*') => {
88 i += 2;
89 while i < bytes.len() && !(bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/')) {
90 i += 1;
91 }
92 i += 2;
93 }
94 b'(' | b'[' | b'{' => {
95 depth += 1;
96 max_depth = max_depth.max(depth);
97 i += 1;
98 }
99 b')' | b']' | b'}' => {
100 depth = (depth - 1).max(0);
101 i += 1;
102 }
103 b if b.is_ascii_alphabetic() || b == b'_' => {
104 let start = i;
106 while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
107 i += 1;
108 }
109 let word = &input[start..i];
110 if word.eq_ignore_ascii_case("case") {
111 depth += 1;
112 max_depth = max_depth.max(depth);
113 } else if word.eq_ignore_ascii_case("end") {
114 depth = (depth - 1).max(0);
115 }
116 }
117 _ => i += 1,
118 }
119
120 if max_depth as u32 > MAX_NESTING_DEPTH {
121 return Err(ParseError::new(format!(
122 "SyntaxError: NestingTooDeep - query nesting exceeds the maximum \
123 supported depth ({MAX_NESTING_DEPTH})"
124 )));
125 }
126 }
127
128 Ok(())
129}
130
131pub fn parse(input: &str) -> Result<Query, ParseError> {
132 check_nesting_depth(input)?;
133 let pairs = CypherParser::parse(Rule::query, input).map_err(|e| map_pest_error(input, e))?;
134
135 walker::build_query(pairs)
136}
137
138pub fn parse_expression(input: &str) -> Result<Expr, ParseError> {
139 check_nesting_depth(input)?;
140 let pairs =
141 CypherParser::parse(Rule::expression, input).map_err(|e| map_pest_error(input, e))?;
142
143 walker::build_expression(pairs.into_iter().next().unwrap())
144}
145
146pub fn parse_locy(input: &str) -> Result<LocyProgram, ParseError> {
147 use locy_parser::LocyParser;
148 use locy_parser::Rule as LocyRule;
149
150 check_nesting_depth(input)?;
151 let pairs = LocyParser::parse(LocyRule::locy_query, input)
152 .map_err(|e| map_locy_pest_error(input, e))?;
153
154 locy_walker::build_program(pairs.into_iter().next().unwrap())
155}
156
157fn expects_identifier(e: &pest::error::Error<Rule>) -> bool {
161 use pest::error::ErrorVariant;
162 match &e.variant {
163 ErrorVariant::ParsingError { positives, .. } => positives
164 .iter()
165 .any(|r| matches!(r, Rule::identifier | Rule::identifier_or_keyword)),
166 _ => false,
167 }
168}
169
170fn error_position<R: pest::RuleType>(e: &pest::error::Error<R>) -> usize {
171 match e.location {
172 pest::error::InputLocation::Pos(p) => p,
173 pest::error::InputLocation::Span((s, _)) => s,
174 }
175}
176
/// Finds the byte span `(start, end)` of the "token" at or immediately before `pos`.
///
/// Token bytes are ASCII alphanumerics plus `_ - . # $`. `pos` is clamped to the
/// last byte. If the byte at `pos` is not a token byte, the token ending just
/// before it is used instead.
///
/// Returns `None` in these cases:
/// - `input` is empty;
/// - neither the byte at `pos` nor the byte before it is a token byte.
///
/// All token bytes are ASCII, so the returned span always lies on char boundaries.
fn extract_token_span_at(input: &str, pos: usize) -> Option<(usize, usize)> {
    let bytes = input.as_bytes();
    let last = bytes.len().checked_sub(1)?;
    let is_token_char =
        |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'-' | b'.' | b'#' | b'$');

    let clamped = pos.min(last);
    let anchor = if is_token_char(bytes[clamped]) {
        clamped
    } else if clamped > 0 && is_token_char(bytes[clamped - 1]) {
        clamped - 1
    } else {
        return None;
    };

    // Expand left to just after the last non-token byte, right up to the next one.
    let start = bytes[..anchor]
        .iter()
        .rposition(|&b| !is_token_char(b))
        .map_or(0, |idx| idx + 1);
    let end = bytes[anchor..]
        .iter()
        .position(|&b| !is_token_char(b))
        .map_or(bytes.len(), |offset| anchor + offset);

    Some((start, end))
}
207
/// Heuristic: does the span `start..end` look like a key in a map literal?
///
/// It does when both of these hold:
/// - the next non-whitespace byte after the span is `:`;
/// - the previous non-whitespace byte before it is `{` or `,`.
///
/// Out-of-range spans and empty input yield `false`.
fn is_map_key_like_context(input: &str, start: usize, end: usize) -> bool {
    let bytes = input.as_bytes();
    let span_in_bounds = !bytes.is_empty() && start < bytes.len() && end <= bytes.len();
    if !span_in_bounds {
        return false;
    }

    let next_significant = bytes[end..].iter().find(|b| !b.is_ascii_whitespace());
    if next_significant != Some(&b':') {
        return false;
    }

    bytes[..start]
        .iter()
        .rev()
        .find(|b| !b.is_ascii_whitespace())
        .is_some_and(|&b| b == b'{' || b == b',')
}
232
/// Returns the `[...]` segment of a relationship pattern that the error at `pos` falls in.
///
/// The segment starts at the last `[` before `pos`. That bracket only counts when
/// it is preceded (ignoring whitespace) by `-`, as in `-[`. The segment runs
/// through the next `]`, or up to `pos` if no `]` follows.
///
/// Returns `None` when:
/// - there is no such bracket;
/// - `pos` does not fall on a character boundary of `input`.
///
/// `pos` is clamped to `input.len()`.
fn relationship_bracket_segment(input: &str, pos: usize) -> Option<&str> {
    let pos = pos.min(input.len());
    // `get` instead of indexing: a `pos` inside a multi-byte character yields `None`
    // rather than panicking.
    let before = input.get(..pos)?;
    let start = before.rfind('[')?;

    if !input[..start].trim_end().ends_with('-') {
        return None;
    }

    let end = input[start..]
        .find(']')
        .map_or(pos, |offset| start + offset + 1);
    Some(&input[start..end])
}

/// Heuristic: is the error at `pos` caused by a malformed relationship pattern?
///
/// It is when the enclosing `-[...]` segment matches either of these:
/// - a `..` range without a `*` quantifier;
/// - the sequence `*-`.
fn is_invalid_relationship_pattern(input: &str, pos: usize) -> bool {
    relationship_bracket_segment(input, pos).is_some_and(|segment| {
        (segment.contains("..") && !segment.contains('*')) || segment.contains("*-")
    })
}
256
257fn is_invalid_number_literal(input: &str, pos: usize) -> bool {
258 let Some((start, end)) = extract_token_span_at(input, pos) else {
259 return false;
260 };
261 if is_map_key_like_context(input, start, end) {
262 return false;
263 }
264 let token = &input[start..end];
265
266 let t = token.strip_prefix('-').unwrap_or(token);
267 if !t.as_bytes().first().is_some_and(|b| b.is_ascii_digit()) {
268 return false;
269 }
270
271 let has_only = |digits: &str, valid: fn(&char) -> bool| {
272 digits.is_empty() || !digits.chars().all(|c| valid(&c) || c == '_')
273 };
274
275 if let Some(digits) = t.strip_prefix("0x").or_else(|| t.strip_prefix("0X")) {
276 return has_only(digits, char::is_ascii_hexdigit);
277 }
278 if let Some(digits) = t.strip_prefix("0o").or_else(|| t.strip_prefix("0O")) {
279 return has_only(digits, |c| matches!(c, '0'..='7'));
280 }
281
282 t.chars().any(|c| c.is_ascii_alphabetic())
284}
285
/// Returns the character starting at byte offset `pos` if it is a dash look-alike.
///
/// The look-alikes are an em dash, an en dash or the Unicode minus sign, which are
/// often pasted in place of ASCII `-`. Returns `None` when:
/// - `pos` is out of range;
/// - `pos` is not on a char boundary;
/// - the character is anything else.
fn invalid_unicode_character(input: &str, pos: usize) -> Option<char> {
    const LOOKALIKE_DASHES: [char; 3] = ['—', '–', '−'];

    input
        .get(pos..)
        .and_then(|rest| rest.chars().next())
        .filter(|ch| LOOKALIKE_DASHES.contains(ch))
}
290
/// Cypher keywords that `reserved_keyword_at` reports as reserved when the token at
/// a parse-error position matches one of them.
///
/// Entries are lowercase; the token is lowercased before lookup. `end` is not
/// listed, so it remains usable as a variable name.
const CYPHER_RESERVED_KEYWORDS: &[&str] = &[
    "match",
    "optional",
    "where",
    "create",
    "merge",
    "set",
    "remove",
    "delete",
    "detach",
    "return",
    "with",
    "unwind",
    "union",
    "call",
    "yield",
    "distinct",
    "order",
    "by",
    "asc",
    "desc",
    "skip",
    "limit",
    "as",
    "and",
    "or",
    "xor",
    "not",
    "in",
    "contains",
    "starts",
    "ends",
    "is",
    "null",
    "true",
    "false",
    "case",
    "when",
    "then",
    "else",
    "if",
    "from",
    "to",
    "on",
    "drop",
    "alter",
    "show",
    "over",
    "partition",
    "explain",
    "recursive",
    "valid_at",
    "each",
];

/// Additional Locy keywords, passed as `extra_keywords` to `reserved_keyword_at`
/// by `map_locy_pest_error` (lowercase, like the Cypher table).
const LOCY_RESERVED_KEYWORDS: &[&str] = &[
    "rule", "along", "prev", "fold", "best", "derive", "assume", "abduce", "query",
];
352
353fn reserved_keyword_at(input: &str, pos: usize, extra_keywords: &[&str]) -> Option<String> {
355 let (start, end) = extract_token_span_at(input, pos)?;
356 let token = &input[start..end];
357 let lower = token.to_lowercase();
358 if CYPHER_RESERVED_KEYWORDS.contains(&lower.as_str())
359 || extra_keywords.contains(&lower.as_str())
360 {
361 Some(token.to_string())
362 } else {
363 None
364 }
365}
366
/// Picks a Locy-specific error category from the text preceding the error position.
///
/// The text before `pos` is trimmed and upper-cased, then checked in this priority
/// order:
/// 1. a trailing whole word `BEST BY`, `ALONG`, `FOLD`, `ASSUME` or `DERIVE`;
/// 2. any earlier `CREATE RULE`;
/// 3. a trailing whole word `QUERY`.
///
/// "Whole word" means the keyword is not the tail of a longer identifier, so
/// `manifold` does not match `FOLD`. Returns `None` when nothing applies, or when
/// `pos` is out of range or not on a char boundary.
fn locy_context_category(input: &str, pos: usize) -> Option<&'static str> {
    const TRAILING_CLAUSES: [(&str, &str); 5] = [
        ("BEST BY", "InvalidBestByClause"),
        ("ALONG", "InvalidAlongClause"),
        ("FOLD", "InvalidFoldClause"),
        ("ASSUME", "InvalidAssumeBlock"),
        ("DERIVE", "InvalidDeriveCommand"),
    ];

    // `get` avoids a panic on an out-of-range or mid-character `pos`.
    let before = input.get(..pos)?.trim_end().to_uppercase();

    if let Some(&(_, category)) = TRAILING_CLAUSES
        .iter()
        .find(|(keyword, _)| ends_with_word(&before, keyword))
    {
        return Some(category);
    }
    if before.contains("CREATE RULE") {
        return Some("InvalidRuleDefinition");
    }
    // Reaching this point already implies no `CREATE RULE` precedes `pos`.
    ends_with_word(&before, "QUERY").then_some("InvalidGoalQuery")
}

/// `true` when `text` ends with `word` and `word` is not the tail of a longer identifier.
fn ends_with_word(text: &str, word: &str) -> bool {
    text.strip_suffix(word).is_some_and(|head| {
        !head
            .chars()
            .next_back()
            .is_some_and(|c| c.is_alphanumeric() || c == '_')
    })
}
397
398fn map_locy_pest_error(input: &str, e: pest::error::Error<locy_parser::Rule>) -> ParseError {
399 let pos = error_position(&e);
400
401 if is_invalid_relationship_pattern(input, pos) {
403 return ParseError::new(format!("LocySyntaxError: InvalidRelationshipPattern - {e}"));
404 }
405 if is_invalid_number_literal(input, pos) {
406 return ParseError::new(format!("LocySyntaxError: InvalidNumberLiteral - {e}"));
407 }
408 if let Some(ch) = invalid_unicode_character(input, pos) {
409 return ParseError::new(format!(
410 "LocySyntaxError: InvalidUnicodeCharacter - Invalid character '{ch}'"
411 ));
412 }
413 if let Some(kw) = reserved_keyword_at(input, pos, LOCY_RESERVED_KEYWORDS) {
414 return ParseError::new(format!(
415 "LocySyntaxError: ReservedKeyword - \"{kw}\" is a reserved keyword \
416 and cannot be used as a variable name. Use backtick-quoting: `{kw}`\n{e}"
417 ));
418 }
419
420 if let Some(category) = locy_context_category(input, pos) {
422 return ParseError::new(format!("LocySyntaxError: {category} - {e}"));
423 }
424
425 ParseError::new(format!("LocySyntaxError: {e}"))
426}
427
428fn map_pest_error(input: &str, e: pest::error::Error<Rule>) -> ParseError {
429 let pos = error_position(&e);
430 if is_invalid_relationship_pattern(input, pos) {
431 return ParseError::new(format!("SyntaxError: InvalidRelationshipPattern - {e}"));
432 }
433 if is_invalid_number_literal(input, pos) {
434 return ParseError::new(format!("SyntaxError: InvalidNumberLiteral - {e}"));
435 }
436 if let Some(ch) = invalid_unicode_character(input, pos) {
437 return ParseError::new(format!(
438 "SyntaxError: InvalidUnicodeCharacter - Invalid character '{ch}'"
439 ));
440 }
441 if let Some(kw) = expects_identifier(&e)
442 .then(|| reserved_keyword_at(input, pos, &[]))
443 .flatten()
444 {
445 return ParseError::new(format!(
446 "SyntaxError: ReservedKeyword - \"{kw}\" is a reserved keyword \
447 and cannot be used as a variable name. Use backtick-quoting: `{kw}`\n{e}"
448 ));
449 }
450
451 ParseError::new(format!("UnexpectedSyntax: {e}"))
452}
453
/// Unit tests for the parser entry points and the error-classification heuristics.
#[cfg(test)]
mod tests {
    use super::*;

    /// Each snippet must be accepted by the raw pest grammar under the given rule.
    #[test]
    fn test_expression_parsing() {
        let cases = [
            ("1", Rule::integer),
            ("3.14", Rule::float),
            ("'hello'", Rule::string),
            ("n.name", Rule::expression),
            ("1 + 2", Rule::expression),
            ("a AND b OR c", Rule::expression),
        ];

        for (input, rule) in cases {
            let result = CypherParser::parse(rule, input);
            assert!(
                result.is_ok(),
                "Failed to parse '{}' as {:?}: {:?}",
                input,
                rule,
                result.err()
            );
        }
    }

    /// List literals, list comprehensions and pattern comprehensions all parse.
    #[test]
    fn test_list_expressions() {
        // Empty list.
        assert!(parse_expression("[]").is_ok());

        // Plain list literal.
        assert!(parse_expression("[1, 2, 3]").is_ok());

        // List comprehensions, with and without a WHERE filter.
        assert!(parse_expression("[x IN range(1,10) | x * 2]").is_ok());
        assert!(parse_expression("[x IN list WHERE x > 5 | x]").is_ok());

        // Pattern comprehensions, with and without a path variable.
        assert!(parse_expression("[(n)-[:KNOWS]->(m) | m.name]").is_ok());
        assert!(parse_expression("[p = (n)-->(m) WHERE m.age > 30 | p]").is_ok());
    }

    /// Bracketed inputs that could be read as either a list or a comprehension.
    #[test]
    fn test_ambiguous_cases() {
        assert!(parse_expression("[n]").is_ok());
        assert!(parse_expression("[n.name]").is_ok());
        assert!(parse_expression("[n IN list]").is_ok());
        assert!(parse_expression("[(n)]").is_ok());
    }

    /// Parses `input`, expecting failure, and returns the rendered error message.
    /// Panics if parsing succeeds.
    fn parse_err_msg(input: &str) -> String {
        parse(input).unwrap_err().to_string()
    }

    /// A `..` range inside a relationship pattern without `*` is classified.
    #[test]
    fn test_invalid_relationship_pattern_missing_star_error_code() {
        let msg = parse_err_msg("MATCH (a:A)\nMATCH (a)-[:LIKES..]->(c)\nRETURN c.name");
        assert!(
            msg.contains("InvalidRelationshipPattern"),
            "expected InvalidRelationshipPattern, got: {msg}"
        );
    }

    /// A letter embedded in a decimal literal is an invalid number literal.
    #[test]
    fn test_invalid_number_literal_error_code_decimal_alpha() {
        let msg = parse_err_msg("RETURN 9223372h54775808 AS literal");
        assert!(
            msg.contains("InvalidNumberLiteral"),
            "expected InvalidNumberLiteral, got: {msg}"
        );
    }

    /// A bare `0x` prefix with no digits is an invalid number literal.
    #[test]
    fn test_invalid_number_literal_error_code_hex_prefix_only() {
        let msg = parse_err_msg("RETURN 0x AS literal");
        assert!(
            msg.contains("InvalidNumberLiteral"),
            "expected InvalidNumberLiteral, got: {msg}"
        );
    }

    /// An em dash used instead of `-` is reported as an invalid Unicode character.
    #[test]
    fn test_invalid_unicode_character_error_code() {
        let msg = parse_err_msg("RETURN 42 — 41");
        assert!(
            msg.contains("InvalidUnicodeCharacter"),
            "expected InvalidUnicodeCharacter, got: {msg}"
        );
    }

    /// A non-letter symbol inside a number is not treated as a number-literal error.
    #[test]
    fn test_symbol_in_number_stays_unexpected_syntax() {
        let msg = parse_err_msg("RETURN 9223372#54775808 AS literal");
        assert!(
            msg.contains("UnexpectedSyntax"),
            "expected UnexpectedSyntax, got: {msg}"
        );
    }

    /// A digit-leading token in map-key position is not blamed as a number literal.
    #[test]
    fn test_map_key_starting_with_number_stays_unexpected_syntax() {
        let msg = parse_err_msg("RETURN {1B2c3e67:1} AS literal");
        assert!(
            msg.contains("UnexpectedSyntax"),
            "expected UnexpectedSyntax, got: {msg}"
        );
    }

    /// Two negations cancel: `--5` yields the literal 5.
    #[test]
    fn test_unary_minus_double() {
        use crate::ast::{CypherLiteral, Expr};
        let expr = parse_expression("--5").expect("--5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(5)));
    }

    /// A single negation yields the negative literal directly.
    #[test]
    fn test_unary_minus_single() {
        use crate::ast::{CypherLiteral, Expr};
        let expr = parse_expression("-5").expect("-5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(-5)));
    }

    /// Three negations reduce to one.
    #[test]
    fn test_unary_minus_triple() {
        use crate::ast::{CypherLiteral, Expr};
        let expr = parse_expression("---5").expect("---5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(-5)));
    }

    /// Unary plus leaves the literal unchanged.
    #[test]
    fn test_unary_plus_identity() {
        use crate::ast::{CypherLiteral, Expr};
        let expr = parse_expression("+5").expect("+5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(5)));
    }

    /// `+-5` is the negative literal.
    #[test]
    fn test_unary_plus_minus() {
        use crate::ast::{CypherLiteral, Expr};
        let expr = parse_expression("+-5").expect("+-5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(-5)));
    }

    /// `-+5` is the negative literal.
    #[test]
    fn test_unary_minus_plus() {
        use crate::ast::{CypherLiteral, Expr};
        let expr = parse_expression("-+5").expect("-+5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(-5)));
    }

    /// `--9223372036854775808` must be rejected with an IntegerOverflow error.
    #[test]
    fn test_unary_double_minus_overflow() {
        let result = parse_expression("--9223372036854775808");
        assert!(
            result.is_err(),
            "expected overflow error, got: {:?}",
            result
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("IntegerOverflow"),
            "expected IntegerOverflow, got: {msg}"
        );
    }

    /// A single minus on 9223372036854775808 produces exactly `i64::MIN`.
    #[test]
    fn test_unary_minus_i64_min() {
        use crate::ast::{CypherLiteral, Expr};
        let expr = parse_expression("-9223372036854775808").expect("-i64::MIN should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(i64::MIN)));
    }

    /// Chained `IS NULL IS NOT NULL` predicates are rejected.
    #[test]
    fn test_stacked_predicates_is_null_is_not_null() {
        let result = parse("RETURN x IS NULL IS NOT NULL");
        assert!(
            result.is_err(),
            "expected parse error for stacked IS NULL IS NOT NULL"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("InvalidPredicateChain"),
            "expected InvalidPredicateChain, got: {msg}"
        );
    }

    /// Chained `STARTS WITH` predicates are rejected.
    #[test]
    fn test_stacked_predicates_starts_with() {
        let result = parse("RETURN x STARTS WITH 'a' STARTS WITH 'b'");
        assert!(
            result.is_err(),
            "expected parse error for stacked STARTS WITH"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("InvalidPredicateChain"),
            "expected InvalidPredicateChain, got: {msg}"
        );
    }

    /// Chained `IN` predicates are rejected.
    #[test]
    fn test_stacked_predicates_in() {
        let result = parse("RETURN x IN [1] IN [true]");
        assert!(result.is_err(), "expected parse error for stacked IN");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("InvalidPredicateChain"),
            "expected InvalidPredicateChain, got: {msg}"
        );
    }

    /// Mixed `CONTAINS` / `ENDS WITH` chains are rejected.
    #[test]
    fn test_stacked_predicates_contains_ends_with() {
        let result = parse("RETURN x CONTAINS 'a' ENDS WITH 'b'");
        assert!(
            result.is_err(),
            "expected parse error for stacked CONTAINS/ENDS WITH"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("InvalidPredicateChain"),
            "expected InvalidPredicateChain, got: {msg}"
        );
    }

    /// Multiple labels in a WHERE label predicate are not a predicate chain.
    #[test]
    fn test_label_stacking_allowed() {
        assert!(
            parse("MATCH (x) WHERE x:Person:Employee RETURN x").is_ok(),
            "label stacking should be allowed"
        );
    }

    /// Chained comparisons (`a < b < c`) remain valid.
    #[test]
    fn test_range_chaining_allowed() {
        assert!(
            parse("MATCH (n) WHERE 1 < n.num < 3 RETURN n").is_ok(),
            "range chaining 1 < n.num < 3 should be allowed"
        );
    }

    /// Using a reserved keyword as a variable yields a ReservedKeyword error with a hint.
    #[test]
    fn test_reserved_keyword_as_variable_name() {
        let msg = parse_err_msg("MATCH (match:N) RETURN match");
        assert!(
            msg.contains("ReservedKeyword"),
            "expected ReservedKeyword, got: {msg}"
        );
        assert!(
            msg.contains("backtick-quoting"),
            "expected backtick hint, got: {msg}"
        );
    }

    /// Same classification for `return` used as a variable.
    #[test]
    fn test_reserved_keyword_return_as_variable() {
        let msg = parse_err_msg("MATCH (return:N) RETURN return");
        assert!(
            msg.contains("ReservedKeyword"),
            "expected ReservedKeyword, got: {msg}"
        );
    }

    /// `end` is not in the reserved list and works as a variable name.
    #[test]
    fn test_non_reserved_keyword_allowed() {
        assert!(
            parse("MATCH (end:N) RETURN end").is_ok(),
            "non-reserved keyword 'end' should be allowed as variable name"
        );
    }

    /// Backtick-quoting makes a reserved keyword usable as a variable name.
    #[test]
    fn test_backtick_escaped_reserved_keyword() {
        assert!(
            parse("MATCH (`match`:N) RETURN `match`").is_ok(),
            "backtick-escaped reserved keyword should be allowed"
        );
    }
}