Skip to main content

uni_cypher/grammar/
mod.rs

1pub(crate) mod locy_parser;
2mod locy_walker;
3mod walker;
4
5use crate::ast::{Expr, Query};
6use crate::locy_ast::LocyProgram;
7use pest::Parser;
8use pest_derive::Parser;
9
/// Error type for Cypher parsing failures.
///
/// Its `Display` output is exactly the stored `message` (see the `#[error("{message}")]`
/// attribute below).
#[derive(Debug, thiserror::Error)]
#[error("{message}")]
pub struct ParseError {
    /// Human-readable description of the failure.
    message: String,
}
16
17impl ParseError {
18    pub fn new(message: String) -> Self {
19        Self { message }
20    }
21}
22
/// Pest parser generated from the grammar file `grammar/cypher.pest`.
///
/// Entry points in this module invoke it as `CypherParser::parse(Rule::…, input)`.
#[derive(Parser)]
#[grammar = "grammar/cypher.pest"]
pub struct CypherParser;
26
27pub fn parse(input: &str) -> Result<Query, ParseError> {
28    let pairs = CypherParser::parse(Rule::query, input).map_err(|e| map_pest_error(input, e))?;
29
30    walker::build_query(pairs)
31}
32
33pub fn parse_expression(input: &str) -> Result<Expr, ParseError> {
34    let pairs =
35        CypherParser::parse(Rule::expression, input).map_err(|e| map_pest_error(input, e))?;
36
37    walker::build_expression(pairs.into_iter().next().unwrap())
38}
39
40pub fn parse_locy(input: &str) -> Result<LocyProgram, ParseError> {
41    use locy_parser::LocyParser;
42    use locy_parser::Rule as LocyRule;
43
44    let pairs = LocyParser::parse(LocyRule::locy_query, input)
45        .map_err(|e| map_locy_pest_error(input, e))?;
46
47    locy_walker::build_program(pairs.into_iter().next().unwrap())
48}
49
50/// Returns true if the pest error expects an identifier-like rule at the error position.
51/// Used to gate the reserved-keyword check so it only fires when a keyword is used
52/// where a variable name is expected, not when it appears after unrelated syntax errors.
53fn expects_identifier(e: &pest::error::Error<Rule>) -> bool {
54    use pest::error::ErrorVariant;
55    match &e.variant {
56        ErrorVariant::ParsingError { positives, .. } => positives
57            .iter()
58            .any(|r| matches!(r, Rule::identifier | Rule::identifier_or_keyword)),
59        _ => false,
60    }
61}
62
63fn error_position<R: pest::RuleType>(e: &pest::error::Error<R>) -> usize {
64    match e.location {
65        pest::error::InputLocation::Pos(p) => p,
66        pest::error::InputLocation::Span((s, _)) => s,
67    }
68}
69
/// Finds the byte span `(start, end)` of the token touching `pos`.
///
/// A token is a maximal run of ASCII alphanumerics and `_`, `-`, `.`, `#`, `$`.
/// `pos` is clamped to the last byte of `input`. If the byte at `pos` is not a token
/// byte, the byte just before it is tried, so a position immediately after a token still
/// resolves to that token. Returns `None` for empty input or when neither byte belongs to
/// a token.
fn extract_token_span_at(input: &str, pos: usize) -> Option<(usize, usize)> {
    fn is_token_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || matches!(b, b'_' | b'-' | b'.' | b'#' | b'$')
    }

    let bytes = input.as_bytes();
    // `checked_sub` doubles as the empty-input guard.
    let last = bytes.len().checked_sub(1)?;
    let clamped = pos.min(last);

    // Pick a byte that is known to lie inside the token.
    let anchor = if is_token_byte(bytes[clamped]) {
        clamped
    } else if clamped > 0 && is_token_byte(bytes[clamped - 1]) {
        clamped - 1
    } else {
        return None;
    };

    // Grow left to just after the nearest non-token byte (or the start of input).
    let start = bytes[..anchor]
        .iter()
        .rposition(|&b| !is_token_byte(b))
        .map_or(0, |idx| idx + 1);

    // Grow right up to the nearest non-token byte (or the end of input).
    let end = bytes[anchor..]
        .iter()
        .position(|&b| !is_token_byte(b))
        .map_or(bytes.len(), |offset| anchor + offset);

    Some((start, end))
}
100
/// Reports whether the token at `start..end` is positioned like a map key.
///
/// That means the first non-whitespace byte after the token is `:` and the first
/// non-whitespace byte before it is `{` or `,`. Out-of-range spans yield `false`.
fn is_map_key_like_context(input: &str, start: usize, end: usize) -> bool {
    let bytes = input.as_bytes();
    // `start >= len` also covers empty input.
    if start >= bytes.len() || end > bytes.len() {
        return false;
    }

    let next_significant = bytes[end..]
        .iter()
        .find(|b| !b.is_ascii_whitespace())
        .copied();
    if next_significant != Some(b':') {
        return false;
    }

    let prev_significant = bytes[..start]
        .iter()
        .rev()
        .find(|b| !b.is_ascii_whitespace())
        .copied();
    matches!(prev_significant, Some(b'{' | b','))
}
125
/// Returns the `[...]` segment of a relationship pattern surrounding `pos`.
///
/// The segment starts at the last `[` before `pos` and runs through the next `]`. If no
/// `]` follows, it stops at `pos`. `pos` is clamped to the input length.
///
/// Returns `None` when there is no `[` before `pos`, when the text before that `[`
/// (ignoring trailing whitespace) does not end with `-`, or when `pos` does not fall on a
/// UTF-8 character boundary. The last case previously panicked on the slice.
fn relationship_bracket_segment(input: &str, pos: usize) -> Option<&str> {
    let pos = pos.min(input.len());
    // `get` rejects offsets inside a multi-byte character instead of panicking.
    let before = input.get(..pos)?;
    let start = before.rfind('[')?;

    // Restrict to relationship patterns: ...-[ ... ]-...
    if !before[..start].trim_end().ends_with('-') {
        return None;
    }

    // `start < pos` because `[` was found inside `before`, so the fallback is non-empty.
    let end = input[start..].find(']').map_or(pos, |i| start + i + 1);
    Some(&input[start..end])
}
141
142fn is_invalid_relationship_pattern(input: &str, pos: usize) -> bool {
143    let Some(segment) = relationship_bracket_segment(input, pos) else {
144        return false;
145    };
146    // [:LIKES..] (missing `*`) or [:LIKES*-2] (negative range bound)
147    (segment.contains("..") && !segment.contains('*')) || segment.contains("*-")
148}
149
150fn is_invalid_number_literal(input: &str, pos: usize) -> bool {
151    let Some((start, end)) = extract_token_span_at(input, pos) else {
152        return false;
153    };
154    if is_map_key_like_context(input, start, end) {
155        return false;
156    }
157    let token = &input[start..end];
158
159    let t = token.strip_prefix('-').unwrap_or(token);
160    if !t.as_bytes().first().is_some_and(|b| b.is_ascii_digit()) {
161        return false;
162    }
163
164    let has_only = |digits: &str, valid: fn(&char) -> bool| {
165        digits.is_empty() || !digits.chars().all(|c| valid(&c) || c == '_')
166    };
167
168    if let Some(digits) = t.strip_prefix("0x").or_else(|| t.strip_prefix("0X")) {
169        return has_only(digits, char::is_ascii_hexdigit);
170    }
171    if let Some(digits) = t.strip_prefix("0o").or_else(|| t.strip_prefix("0O")) {
172        return has_only(digits, |c| matches!(c, '0'..='7'));
173    }
174
175    // Decimal-like token with alphabetic suffix/midfix, e.g. 9223372h54775808
176    t.chars().any(|c| c.is_ascii_alphabetic())
177}
178
/// Returns the character starting at byte offset `pos` if it is one of the dash-like
/// characters `—`, `–` or `−`.
///
/// Returns `None` when `pos` is out of range, falls inside a multi-byte character, or
/// points at any other character.
fn invalid_unicode_character(input: &str, pos: usize) -> Option<char> {
    let tail = input.get(pos..)?;
    match tail.chars().next()? {
        found @ ('—' | '–' | '−') => Some(found),
        _ => None,
    }
}
183
/// All Cypher reserved keywords (from `keyword_reserved` in cypher.pest).
/// Stored lowercase for case-insensitive comparison: `reserved_keyword_at` lowercases the
/// candidate token before looking it up here.
const CYPHER_RESERVED_KEYWORDS: &[&str] = &[
    "match",
    "optional",
    "where",
    "create",
    "merge",
    "set",
    "remove",
    "delete",
    "detach",
    "return",
    "with",
    "unwind",
    "union",
    "call",
    "yield",
    "distinct",
    "order",
    "by",
    "asc",
    "desc",
    "skip",
    "limit",
    "as",
    "and",
    "or",
    "xor",
    "not",
    "in",
    "contains",
    "starts",
    "ends",
    "is",
    "null",
    "true",
    "false",
    "case",
    "when",
    "then",
    "else",
    "if",
    "from",
    "to",
    "on",
    "drop",
    "alter",
    "show",
    "over",
    "partition",
    "explain",
    "recursive",
    "valid_at",
    "each",
];
240
/// Additional Locy-only reserved keywords (from `locy_keyword_reserved` in locy.pest).
/// Stored lowercase; `map_locy_pest_error` passes this list to `reserved_keyword_at` as
/// the extra keyword set checked alongside the Cypher keywords.
const LOCY_RESERVED_KEYWORDS: &[&str] = &[
    "rule", "along", "prev", "fold", "best", "derive", "assume", "abduce", "query",
];
245
246/// If the token at the error position is a reserved keyword, return it.
247fn reserved_keyword_at(input: &str, pos: usize, extra_keywords: &[&str]) -> Option<String> {
248    let (start, end) = extract_token_span_at(input, pos)?;
249    let token = &input[start..end];
250    let lower = token.to_lowercase();
251    if CYPHER_RESERVED_KEYWORDS.contains(&lower.as_str())
252        || extra_keywords.contains(&lower.as_str())
253    {
254        Some(token.to_string())
255    } else {
256        None
257    }
258}
259
/// Categorizes a Locy parse error from the text that precedes the error position.
///
/// The text before `pos` is right-trimmed and uppercased, then checked in this order:
/// 1. ends with `BEST BY`, `ALONG`, `FOLD`, `ASSUME` or `DERIVE`;
/// 2. contains `CREATE RULE` anywhere (a name/priority may sit between it and `pos`);
/// 3. ends with `QUERY` (standalone, since rule 2 already claimed `CREATE RULE` inputs).
///
/// `pos` is clamped to the input length, matching the other position-based helpers in
/// this module. Returns `None` when nothing matches or when `pos` does not fall on a
/// UTF-8 character boundary. Out-of-range or mid-character positions previously panicked.
fn locy_context_category(input: &str, pos: usize) -> Option<&'static str> {
    // Clause keywords recognized when they are the last thing before the error.
    const TRAILING_KEYWORDS: [(&str, &str); 5] = [
        ("BEST BY", "InvalidBestByClause"),
        ("ALONG", "InvalidAlongClause"),
        ("FOLD", "InvalidFoldClause"),
        ("ASSUME", "InvalidAssumeBlock"),
        ("DERIVE", "InvalidDeriveCommand"),
    ];

    let before = input.get(..pos.min(input.len()))?.trim_end();
    let before_upper = before.to_uppercase();

    let trailing_match = TRAILING_KEYWORDS
        .iter()
        .find_map(|(keyword, category)| before_upper.ends_with(*keyword).then_some(*category));
    if trailing_match.is_some() {
        return trailing_match;
    }

    // Check for CREATE RULE (may have name/priority between)
    if before_upper.contains("CREATE RULE") {
        return Some("InvalidRuleDefinition");
    }

    // Standalone QUERY: inputs containing CREATE RULE never reach this point.
    if before_upper.ends_with("QUERY") {
        return Some("InvalidGoalQuery");
    }

    None
}
290
291fn map_locy_pest_error(input: &str, e: pest::error::Error<locy_parser::Rule>) -> ParseError {
292    let pos = error_position(&e);
293
294    // Reuse input-based heuristics from the Cypher parser
295    if is_invalid_relationship_pattern(input, pos) {
296        return ParseError::new(format!("LocySyntaxError: InvalidRelationshipPattern - {e}"));
297    }
298    if is_invalid_number_literal(input, pos) {
299        return ParseError::new(format!("LocySyntaxError: InvalidNumberLiteral - {e}"));
300    }
301    if let Some(ch) = invalid_unicode_character(input, pos) {
302        return ParseError::new(format!(
303            "LocySyntaxError: InvalidUnicodeCharacter - Invalid character '{ch}'"
304        ));
305    }
306    if let Some(kw) = reserved_keyword_at(input, pos, LOCY_RESERVED_KEYWORDS) {
307        return ParseError::new(format!(
308            "LocySyntaxError: ReservedKeyword - \"{kw}\" is a reserved keyword \
309             and cannot be used as a variable name. Use backtick-quoting: `{kw}`\n{e}"
310        ));
311    }
312
313    // Locy-specific context categorization
314    if let Some(category) = locy_context_category(input, pos) {
315        return ParseError::new(format!("LocySyntaxError: {category} - {e}"));
316    }
317
318    ParseError::new(format!("LocySyntaxError: {e}"))
319}
320
321fn map_pest_error(input: &str, e: pest::error::Error<Rule>) -> ParseError {
322    let pos = error_position(&e);
323    if is_invalid_relationship_pattern(input, pos) {
324        return ParseError::new(format!("SyntaxError: InvalidRelationshipPattern - {e}"));
325    }
326    if is_invalid_number_literal(input, pos) {
327        return ParseError::new(format!("SyntaxError: InvalidNumberLiteral - {e}"));
328    }
329    if let Some(ch) = invalid_unicode_character(input, pos) {
330        return ParseError::new(format!(
331            "SyntaxError: InvalidUnicodeCharacter - Invalid character '{ch}'"
332        ));
333    }
334    if let Some(kw) = expects_identifier(&e)
335        .then(|| reserved_keyword_at(input, pos, &[]))
336        .flatten()
337    {
338        return ParseError::new(format!(
339            "SyntaxError: ReservedKeyword - \"{kw}\" is a reserved keyword \
340             and cannot be used as a variable name. Use backtick-quoting: `{kw}`\n{e}"
341        ));
342    }
343
344    ParseError::new(format!("UnexpectedSyntax: {e}"))
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{CypherLiteral, Expr};

    /// Parses `input` as a query, expecting failure, and returns the error text.
    fn parse_err_msg(input: &str) -> String {
        parse(input).unwrap_err().to_string()
    }

    /// Asserts that parsing `input` as a query fails with a message containing `code`.
    fn assert_error_code(input: &str, code: &str) {
        let msg = parse_err_msg(input);
        assert!(msg.contains(code), "expected {code}, got: {msg}");
    }

    /// Asserts that `input` parses as an expression equal to the integer literal `expected`.
    /// `label` names the input in the failure message.
    fn assert_integer_literal(input: &str, label: &str, expected: i64) {
        let expr = parse_expression(input)
            .unwrap_or_else(|err| panic!("{label} should parse: {err:?}"));
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(expected)));
    }

    /// Asserts that `query` is rejected with an `InvalidPredicateChain` error.
    /// `what` names the stacked predicate in the failure message.
    fn assert_predicate_chain_rejected(query: &str, what: &str) {
        let result = parse(query);
        assert!(result.is_err(), "expected parse error for stacked {what}");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("InvalidPredicateChain"),
            "expected InvalidPredicateChain, got: {msg}"
        );
    }

    #[test]
    fn test_expression_parsing() {
        let cases = [
            ("1", Rule::integer),
            ("3.14", Rule::float),
            ("'hello'", Rule::string),
            ("n.name", Rule::expression),
            ("1 + 2", Rule::expression),
            ("a AND b OR c", Rule::expression),
        ];

        for (input, rule) in cases {
            if let Err(err) = CypherParser::parse(rule, input) {
                panic!("Failed to parse '{}' as {:?}: {:?}", input, rule, Some(err));
            }
        }
    }

    #[test]
    fn test_list_expressions() {
        // Empty list and plain list literal.
        assert!(parse_expression("[]").is_ok());
        assert!(parse_expression("[1, 2, 3]").is_ok());

        // List comprehensions, without and with a WHERE filter.
        assert!(parse_expression("[x IN range(1,10) | x * 2]").is_ok());
        assert!(parse_expression("[x IN list WHERE x > 5 | x]").is_ok());

        // Pattern comprehensions: the key case this grammar must support.
        assert!(parse_expression("[(n)-[:KNOWS]->(m) | m.name]").is_ok());
        assert!(parse_expression("[p = (n)-->(m) WHERE m.age > 30 | p]").is_ok());
    }

    #[test]
    fn test_ambiguous_cases() {
        // These inputs caused LR(1) conflicts before. Each one is a list literal even
        // though it resembles a comprehension: comprehensions require a `|`, so without
        // one the bracket contents are ordinary expressions.
        assert!(parse_expression("[n]").is_ok()); // list with a variable
        assert!(parse_expression("[n.name]").is_ok()); // list with a property access

        // `n IN list` is a boolean expression, so this is a one-element list literal.
        assert!(parse_expression("[n IN list]").is_ok());

        // `(n)` is a parenthesized expression here, not the start of a pattern
        // comprehension, which would need a pattern followed by `|`.
        assert!(parse_expression("[(n)]").is_ok());
    }

    #[test]
    fn test_invalid_relationship_pattern_missing_star_error_code() {
        assert_error_code(
            "MATCH (a:A)\nMATCH (a)-[:LIKES..]->(c)\nRETURN c.name",
            "InvalidRelationshipPattern",
        );
    }

    #[test]
    fn test_invalid_number_literal_error_code_decimal_alpha() {
        assert_error_code("RETURN 9223372h54775808 AS literal", "InvalidNumberLiteral");
    }

    #[test]
    fn test_invalid_number_literal_error_code_hex_prefix_only() {
        assert_error_code("RETURN 0x AS literal", "InvalidNumberLiteral");
    }

    #[test]
    fn test_invalid_unicode_character_error_code() {
        assert_error_code("RETURN 42 — 41", "InvalidUnicodeCharacter");
    }

    #[test]
    fn test_symbol_in_number_stays_unexpected_syntax() {
        assert_error_code("RETURN 9223372#54775808 AS literal", "UnexpectedSyntax");
    }

    #[test]
    fn test_map_key_starting_with_number_stays_unexpected_syntax() {
        assert_error_code("RETURN {1B2c3e67:1} AS literal", "UnexpectedSyntax");
    }

    #[test]
    fn test_unary_minus_double() {
        // --5 → Integer(5)
        assert_integer_literal("--5", "--5", 5);
    }

    #[test]
    fn test_unary_minus_single() {
        // -5 → Integer(-5)
        assert_integer_literal("-5", "-5", -5);
    }

    #[test]
    fn test_unary_minus_triple() {
        // ---5 → Integer(-5)
        assert_integer_literal("---5", "---5", -5);
    }

    #[test]
    fn test_unary_plus_identity() {
        // +5 → Integer(5)
        assert_integer_literal("+5", "+5", 5);
    }

    #[test]
    fn test_unary_plus_minus() {
        // +-5 → Integer(-5)
        assert_integer_literal("+-5", "+-5", -5);
    }

    #[test]
    fn test_unary_minus_plus() {
        // -+5 → Integer(-5)
        assert_integer_literal("-+5", "-+5", -5);
    }

    #[test]
    fn test_unary_double_minus_overflow() {
        // --9223372036854775808 → overflow error
        let err = match parse_expression("--9223372036854775808") {
            Err(err) => err,
            ok => panic!("expected overflow error, got: {:?}", ok),
        };
        let msg = err.to_string();
        assert!(
            msg.contains("IntegerOverflow"),
            "expected IntegerOverflow, got: {msg}"
        );
    }

    #[test]
    fn test_unary_minus_i64_min() {
        // -9223372036854775808 → Integer(i64::MIN) (valid)
        assert_integer_literal("-9223372036854775808", "-i64::MIN", i64::MIN);
    }

    #[test]
    fn test_stacked_predicates_is_null_is_not_null() {
        // x IS NULL IS NOT NULL → error
        assert_predicate_chain_rejected("RETURN x IS NULL IS NOT NULL", "IS NULL IS NOT NULL");
    }

    #[test]
    fn test_stacked_predicates_starts_with() {
        // x STARTS WITH 'a' STARTS WITH 'b' → error
        assert_predicate_chain_rejected(
            "RETURN x STARTS WITH 'a' STARTS WITH 'b'",
            "STARTS WITH",
        );
    }

    #[test]
    fn test_stacked_predicates_in() {
        // x IN [1] IN [true] → error
        assert_predicate_chain_rejected("RETURN x IN [1] IN [true]", "IN");
    }

    #[test]
    fn test_stacked_predicates_contains_ends_with() {
        // x CONTAINS 'a' ENDS WITH 'b' → error
        assert_predicate_chain_rejected(
            "RETURN x CONTAINS 'a' ENDS WITH 'b'",
            "CONTAINS/ENDS WITH",
        );
    }

    #[test]
    fn test_label_stacking_allowed() {
        // x :Person :Employee → OK (label stacking is valid)
        // Note: label predicates in comparison context are valid
        let outcome = parse("MATCH (x) WHERE x:Person:Employee RETURN x");
        assert!(outcome.is_ok(), "label stacking should be allowed");
    }

    #[test]
    fn test_range_chaining_allowed() {
        // 1 < n.num < 3 → OK (required by TCK Comparison3)
        let outcome = parse("MATCH (n) WHERE 1 < n.num < 3 RETURN n");
        assert!(
            outcome.is_ok(),
            "range chaining 1 < n.num < 3 should be allowed"
        );
    }

    #[test]
    fn test_reserved_keyword_as_variable_name() {
        let msg = parse_err_msg("MATCH (match:N) RETURN match");
        assert!(
            msg.contains("ReservedKeyword"),
            "expected ReservedKeyword, got: {msg}"
        );
        assert!(
            msg.contains("backtick-quoting"),
            "expected backtick hint, got: {msg}"
        );
    }

    #[test]
    fn test_reserved_keyword_return_as_variable() {
        assert_error_code("MATCH (return:N) RETURN return", "ReservedKeyword");
    }

    #[test]
    fn test_non_reserved_keyword_allowed() {
        // `end` was moved to keyword_nonreserved — should parse fine
        let outcome = parse("MATCH (end:N) RETURN end");
        assert!(
            outcome.is_ok(),
            "non-reserved keyword 'end' should be allowed as variable name"
        );
    }

    #[test]
    fn test_backtick_escaped_reserved_keyword() {
        let outcome = parse("MATCH (`match`:N) RETURN `match`");
        assert!(
            outcome.is_ok(),
            "backtick-escaped reserved keyword should be allowed"
        );
    }
}
659            parse("MATCH (`match`:N) RETURN `match`").is_ok(),
660            "backtick-escaped reserved keyword should be allowed"
661        );
662    }
663}