Skip to main content

uni_cypher/grammar/
mod.rs

1pub(crate) mod locy_parser;
2mod locy_walker;
3mod walker;
4
5use crate::ast::{Expr, Query};
6use crate::locy_ast::LocyProgram;
7use pest::Parser;
8use pest_derive::Parser;
9
/// Error type for Cypher parsing failures.
///
/// The `Display` output is exactly the stored message, as configured by the
/// `#[error("{message}")]` attribute below.
#[derive(Debug, thiserror::Error)]
#[error("{message}")]
pub struct ParseError {
    /// Human-readable description of the failure.
    message: String,
}
16
17impl ParseError {
18    pub fn new(message: String) -> Self {
19        Self { message }
20    }
21}
22
/// Cypher parser whose implementation is derived by `pest_derive` from the
/// grammar file named in the `#[grammar = ...]` attribute.
///
/// The `Rule` values used throughout this module (`Rule::query`,
/// `Rule::expression`, ...) refer to rules of that grammar.
#[derive(Parser)]
#[grammar = "grammar/cypher.pest"]
pub struct CypherParser;
26
27pub fn parse(input: &str) -> Result<Query, ParseError> {
28    let pairs = CypherParser::parse(Rule::query, input).map_err(|e| map_pest_error(input, e))?;
29
30    walker::build_query(pairs)
31}
32
33pub fn parse_expression(input: &str) -> Result<Expr, ParseError> {
34    let pairs =
35        CypherParser::parse(Rule::expression, input).map_err(|e| map_pest_error(input, e))?;
36
37    walker::build_expression(pairs.into_iter().next().unwrap())
38}
39
40pub fn parse_locy(input: &str) -> Result<LocyProgram, ParseError> {
41    use locy_parser::LocyParser;
42    use locy_parser::Rule as LocyRule;
43
44    let pairs = LocyParser::parse(LocyRule::locy_query, input)
45        .map_err(|e| map_locy_pest_error(input, e))?;
46
47    locy_walker::build_program(pairs.into_iter().next().unwrap())
48}
49
50fn error_position(e: &pest::error::Error<Rule>) -> usize {
51    match e.location {
52        pest::error::InputLocation::Pos(p) => p,
53        pest::error::InputLocation::Span((s, _)) => s,
54    }
55}
56
/// Finds the byte range `(start, end)` of the token touching `pos`.
///
/// A token is a maximal run of ASCII alphanumerics and the symbols
/// `_ - . # $`. `pos` is clamped to the last byte of `input`. When the byte at
/// `pos` is not a token byte, the token ending immediately before `pos` is
/// used instead. Returns `None` for empty input or when no token touches `pos`.
fn extract_token_span_at(input: &str, pos: usize) -> Option<(usize, usize)> {
    fn is_token_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || matches!(b, b'_' | b'-' | b'.' | b'#' | b'$')
    }

    let bytes = input.as_bytes();
    // `checked_sub` doubles as the empty-input guard.
    let last_index = bytes.len().checked_sub(1)?;
    let clamped = pos.min(last_index);

    // Anchor on a byte that is known to belong to the token.
    let anchor = if is_token_byte(bytes[clamped]) {
        clamped
    } else if clamped > 0 && is_token_byte(bytes[clamped - 1]) {
        clamped - 1
    } else {
        return None;
    };

    // The token starts just after the last non-token byte before the anchor.
    let start = bytes[..anchor]
        .iter()
        .rposition(|&b| !is_token_byte(b))
        .map_or(0, |i| i + 1);

    // The token ends at the first non-token byte at or after the anchor.
    let end = bytes[anchor..]
        .iter()
        .position(|&b| !is_token_byte(b))
        .map_or(bytes.len(), |i| anchor + i);

    Some((start, end))
}
87
/// Reports whether the token at `start..end` sits where a map key would.
///
/// That is the case when the first non-whitespace byte after the token is `:`
/// and the last non-whitespace byte before it is `{` or `,`. Out-of-range
/// offsets yield `false`.
fn is_map_key_like_context(input: &str, start: usize, end: usize) -> bool {
    let bytes = input.as_bytes();
    // `start >= len` also rejects empty input.
    if start >= bytes.len() || end > bytes.len() {
        return false;
    }

    let following = bytes[end..].iter().find(|b| !b.is_ascii_whitespace());
    if following != Some(&b':') {
        return false;
    }

    let preceding = bytes[..start]
        .iter()
        .rev()
        .find(|b| !b.is_ascii_whitespace());
    matches!(preceding, Some(b'{' | b','))
}
112
/// Extracts the relationship bracket segment surrounding `pos`, if any.
///
/// Looks for the last `[` before `pos` (clamped to the input length) and
/// accepts it only when the text before it, ignoring trailing whitespace, ends
/// with `-`, i.e. the `...-[ ... ]-...` relationship shape. The returned slice
/// starts at `[` and runs through the first following `]`, or up to `pos` when
/// the bracket is never closed.
fn relationship_bracket_segment(input: &str, pos: usize) -> Option<&str> {
    let cursor = pos.min(input.len());
    let open = input[..cursor].rfind('[')?;

    // Restrict to relationship patterns: ...-[ ... ]-...
    if !input[..open].trim_end().ends_with('-') {
        return None;
    }

    let close = match input[open..].find(']') {
        Some(offset) => open + offset + 1,
        None => cursor,
    };
    Some(&input[open..close])
}
128
129fn is_invalid_relationship_pattern(input: &str, pos: usize) -> bool {
130    let Some(segment) = relationship_bracket_segment(input, pos) else {
131        return false;
132    };
133    // [:LIKES..] (missing `*`) or [:LIKES*-2] (negative range bound)
134    (segment.contains("..") && !segment.contains('*')) || segment.contains("*-")
135}
136
137fn is_invalid_number_literal(input: &str, pos: usize) -> bool {
138    let Some((start, end)) = extract_token_span_at(input, pos) else {
139        return false;
140    };
141    if is_map_key_like_context(input, start, end) {
142        return false;
143    }
144    let token = &input[start..end];
145
146    let t = token.strip_prefix('-').unwrap_or(token);
147    if !t.as_bytes().first().is_some_and(|b| b.is_ascii_digit()) {
148        return false;
149    }
150
151    let has_only = |digits: &str, valid: fn(&char) -> bool| {
152        digits.is_empty() || !digits.chars().all(|c| valid(&c) || c == '_')
153    };
154
155    if let Some(digits) = t.strip_prefix("0x").or_else(|| t.strip_prefix("0X")) {
156        return has_only(digits, char::is_ascii_hexdigit);
157    }
158    if let Some(digits) = t.strip_prefix("0o").or_else(|| t.strip_prefix("0O")) {
159        return has_only(digits, |c| matches!(c, '0'..='7'));
160    }
161
162    // Decimal-like token with alphabetic suffix/midfix, e.g. 9223372h54775808
163    t.chars().any(|c| c.is_ascii_alphabetic())
164}
165
/// Returns the character at byte offset `pos` when it is one of the Unicode
/// dash/minus look-alikes `—`, `–` or `−`.
///
/// Returns `None` when `pos` is out of range, falls inside a multi-byte
/// character, or points at any other character.
fn invalid_unicode_character(input: &str, pos: usize) -> Option<char> {
    input
        .get(pos..)
        .and_then(|rest| rest.chars().next())
        .filter(|ch| matches!(ch, '—' | '–' | '−'))
}
170
171fn locy_error_position(e: &pest::error::Error<locy_parser::Rule>) -> usize {
172    match e.location {
173        pest::error::InputLocation::Pos(p) => p,
174        pest::error::InputLocation::Span((s, _)) => s,
175    }
176}
177
/// Categorize a Locy parse error based on the text before the error position.
///
/// The text before `pos` is trimmed of trailing whitespace and compared
/// ASCII-case-insensitively, in this order:
/// 1. If it ends with one of the clause keywords `BEST BY`, `ALONG`, `FOLD`,
///    `ASSUME` or `DERIVE`, the matching category is returned.
/// 2. If it contains `CREATE RULE` anywhere (a name or priority may sit
///    between the keyword and the error), `InvalidRuleDefinition` is returned.
/// 3. If it ends with a standalone `QUERY`, `InvalidGoalQuery` is returned.
///
/// A trailing keyword only counts when it is a whole word, so an identifier
/// such as `Scaffold` is not mistaken for a `FOLD` clause.
///
/// Returns `None` when no rule applies, or when `pos` is past the end of
/// `input` or not on a character boundary.
fn locy_context_category(input: &str, pos: usize) -> Option<&'static str> {
    // Keywords that name a category when they are the last word before the error.
    const TRAILING_KEYWORDS: [(&str, &str); 5] = [
        ("BEST BY", "InvalidBestByClause"),
        ("ALONG", "InvalidAlongClause"),
        ("FOLD", "InvalidFoldClause"),
        ("ASSUME", "InvalidAssumeBlock"),
        ("DERIVE", "InvalidDeriveCommand"),
    ];

    // `get` returns `None` for an invalid offset instead of panicking.
    let before = input.get(..pos)?.trim_end();
    // ASCII-only folding: non-ASCII characters must never turn into keyword letters.
    let before_upper = before.to_ascii_uppercase();

    // True when the text ends with `keyword` and the character before it (if
    // any) cannot be part of the same identifier.
    let ends_with_keyword = |keyword: &str| {
        before_upper.strip_suffix(keyword).is_some_and(|rest| {
            !rest
                .chars()
                .next_back()
                .is_some_and(|c| c.is_alphanumeric() || c == '_')
        })
    };

    for (keyword, category) in TRAILING_KEYWORDS {
        if ends_with_keyword(keyword) {
            return Some(category);
        }
    }

    // Check for CREATE RULE (may have name/priority between)
    if before_upper.contains("CREATE RULE") {
        return Some("InvalidRuleDefinition");
    }

    // Standalone QUERY; the CREATE RULE ... YIELD ... QUERY form was handled above.
    if ends_with_keyword("QUERY") {
        return Some("InvalidGoalQuery");
    }

    None
}
208
209fn map_locy_pest_error(input: &str, e: pest::error::Error<locy_parser::Rule>) -> ParseError {
210    let pos = locy_error_position(&e);
211
212    // Reuse input-based heuristics from the Cypher parser
213    if is_invalid_relationship_pattern(input, pos) {
214        return ParseError::new(format!("LocySyntaxError: InvalidRelationshipPattern - {e}"));
215    }
216    if is_invalid_number_literal(input, pos) {
217        return ParseError::new(format!("LocySyntaxError: InvalidNumberLiteral - {e}"));
218    }
219    if let Some(ch) = invalid_unicode_character(input, pos) {
220        return ParseError::new(format!(
221            "LocySyntaxError: InvalidUnicodeCharacter - Invalid character '{ch}'"
222        ));
223    }
224
225    // Locy-specific context categorization
226    if let Some(category) = locy_context_category(input, pos) {
227        return ParseError::new(format!("LocySyntaxError: {category} - {e}"));
228    }
229
230    ParseError::new(format!("LocySyntaxError: {e}"))
231}
232
233fn map_pest_error(input: &str, e: pest::error::Error<Rule>) -> ParseError {
234    let pos = error_position(&e);
235    if is_invalid_relationship_pattern(input, pos) {
236        return ParseError::new(format!("SyntaxError: InvalidRelationshipPattern - {e}"));
237    }
238    if is_invalid_number_literal(input, pos) {
239        return ParseError::new(format!("SyntaxError: InvalidNumberLiteral - {e}"));
240    }
241    if let Some(ch) = invalid_unicode_character(input, pos) {
242        return ParseError::new(format!(
243            "SyntaxError: InvalidUnicodeCharacter - Invalid character '{ch}'"
244        ));
245    }
246
247    ParseError::new(format!("UnexpectedSyntax: {e}"))
248}
249
// Unit tests for the parser entry points and the error-categorization heuristics.
#[cfg(test)]
mod tests {
    use super::*;

    /// Each sample input must be accepted by the grammar rule it is paired with.
    #[test]
    fn test_expression_parsing() {
        let cases = [
            ("1", Rule::integer),
            ("3.14", Rule::float),
            ("'hello'", Rule::string),
            ("n.name", Rule::expression),
            ("1 + 2", Rule::expression),
            ("a AND b OR c", Rule::expression),
        ];

        for (input, rule) in cases {
            let result = CypherParser::parse(rule, input);
            assert!(
                result.is_ok(),
                "Failed to parse '{}' as {:?}: {:?}",
                input,
                rule,
                result.err()
            );
        }
    }

    /// List literals, list comprehensions and pattern comprehensions all parse.
    #[test]
    fn test_list_expressions() {
        // Empty list
        assert!(parse_expression("[]").is_ok());

        // List literal
        assert!(parse_expression("[1, 2, 3]").is_ok());

        // List comprehension
        assert!(parse_expression("[x IN range(1,10) | x * 2]").is_ok());
        assert!(parse_expression("[x IN list WHERE x > 5 | x]").is_ok());

        // Pattern comprehension - THE KEY TEST
        assert!(parse_expression("[(n)-[:KNOWS]->(m) | m.name]").is_ok());
        assert!(parse_expression("[p = (n)-->(m) WHERE m.age > 30 | p]").is_ok());
    }

    /// Bracketed inputs that look like comprehensions but lack `|` still parse.
    #[test]
    fn test_ambiguous_cases() {
        // These caused LR(1) conflicts before
        assert!(parse_expression("[n]").is_ok()); // List with variable
        assert!(parse_expression("[n.name]").is_ok()); // List with property
        assert!(parse_expression("[n IN list]").is_ok()); // Not a comprehension: there is no `|`.
        // A list comprehension must contain `|`, so `[n IN list]` is expected to
        // parse as a list literal holding the single boolean expression
        // `n IN list`. Grammar shape this relies on (as recorded by the author):
        // list_expression = { ... | "[" ~ list_comprehension_body ~ "]" | ... }
        // list_comprehension_body = { identifier ~ IN ~ comprehension_expr ~ ... ~ pipe ~ expression }
        // The missing pipe keeps `list_comprehension_body` from matching, so the
        // input falls through to the list-literal alternative.

        assert!(parse_expression("[(n)]").is_ok()); // Not a pattern comprehension: there is no `|`.
        // `(n)` is read as a parenthesized expression (`primary_expression` ->
        // `(` expression `)`), so `[(n)]` is a list literal with one element.
        // Related shapes, per the author's notes on the grammar:
        // - `[(n)-->(m)]` is a list literal holding a pattern expression
        //   (`pattern_expression` is valid in `boolean_primary`).
        // - `[(n)-->(m) | x]` is a pattern comprehension, because
        //   `pattern_comprehension` requires `|`.
    }

    /// Parses `input` as a query, expecting failure, and returns the error text.
    fn parse_err_msg(input: &str) -> String {
        parse(input).unwrap_err().to_string()
    }

    /// A `..` range without `*` in a relationship is reported as InvalidRelationshipPattern.
    #[test]
    fn test_invalid_relationship_pattern_missing_star_error_code() {
        let msg = parse_err_msg("MATCH (a:A)\nMATCH (a)-[:LIKES..]->(c)\nRETURN c.name");
        assert!(
            msg.contains("InvalidRelationshipPattern"),
            "expected InvalidRelationshipPattern, got: {msg}"
        );
    }

    /// A decimal literal with a letter in the middle is reported as InvalidNumberLiteral.
    #[test]
    fn test_invalid_number_literal_error_code_decimal_alpha() {
        let msg = parse_err_msg("RETURN 9223372h54775808 AS literal");
        assert!(
            msg.contains("InvalidNumberLiteral"),
            "expected InvalidNumberLiteral, got: {msg}"
        );
    }

    /// A bare `0x` prefix with no digits is reported as InvalidNumberLiteral.
    #[test]
    fn test_invalid_number_literal_error_code_hex_prefix_only() {
        let msg = parse_err_msg("RETURN 0x AS literal");
        assert!(
            msg.contains("InvalidNumberLiteral"),
            "expected InvalidNumberLiteral, got: {msg}"
        );
    }

    /// An em dash used in place of `-` is reported as InvalidUnicodeCharacter.
    #[test]
    fn test_invalid_unicode_character_error_code() {
        let msg = parse_err_msg("RETURN 42 — 41");
        assert!(
            msg.contains("InvalidUnicodeCharacter"),
            "expected InvalidUnicodeCharacter, got: {msg}"
        );
    }

    /// A symbol (not a letter) inside a number keeps the generic UnexpectedSyntax code.
    #[test]
    fn test_symbol_in_number_stays_unexpected_syntax() {
        let msg = parse_err_msg("RETURN 9223372#54775808 AS literal");
        assert!(
            msg.contains("UnexpectedSyntax"),
            "expected UnexpectedSyntax, got: {msg}"
        );
    }

    /// A digit-leading token in map-key position is not treated as a number literal.
    #[test]
    fn test_map_key_starting_with_number_stays_unexpected_syntax() {
        let msg = parse_err_msg("RETURN {1B2c3e67:1} AS literal");
        assert!(
            msg.contains("UnexpectedSyntax"),
            "expected UnexpectedSyntax, got: {msg}"
        );
    }

    /// Two unary minuses cancel out.
    #[test]
    fn test_unary_minus_double() {
        use crate::ast::{CypherLiteral, Expr};
        // --5 → Integer(5)
        let expr = parse_expression("--5").expect("--5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(5)));
    }

    /// A single unary minus folds into a negative integer literal.
    #[test]
    fn test_unary_minus_single() {
        use crate::ast::{CypherLiteral, Expr};
        // -5 → Integer(-5)
        let expr = parse_expression("-5").expect("-5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(-5)));
    }

    /// Three unary minuses leave a negative value.
    #[test]
    fn test_unary_minus_triple() {
        use crate::ast::{CypherLiteral, Expr};
        // ---5 → Integer(-5)
        let expr = parse_expression("---5").expect("---5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(-5)));
    }

    /// Unary plus leaves the value unchanged.
    #[test]
    fn test_unary_plus_identity() {
        use crate::ast::{CypherLiteral, Expr};
        // +5 → Integer(5)
        let expr = parse_expression("+5").expect("+5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(5)));
    }

    /// Unary plus applied to a negated value keeps it negative.
    #[test]
    fn test_unary_plus_minus() {
        use crate::ast::{CypherLiteral, Expr};
        // +-5 → Integer(-5)
        let expr = parse_expression("+-5").expect("+-5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(-5)));
    }

    /// Unary minus applied to a plus-prefixed value negates it.
    #[test]
    fn test_unary_minus_plus() {
        use crate::ast::{CypherLiteral, Expr};
        // -+5 → Integer(-5)
        let expr = parse_expression("-+5").expect("-+5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(-5)));
    }

    /// Double negation of 9223372036854775808 does not fit in an i64 and must fail.
    #[test]
    fn test_unary_double_minus_overflow() {
        // --9223372036854775808 → overflow error
        let result = parse_expression("--9223372036854775808");
        assert!(
            result.is_err(),
            "expected overflow error, got: {:?}",
            result
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("IntegerOverflow"),
            "expected IntegerOverflow, got: {msg}"
        );
    }

    /// A single minus in front of 9223372036854775808 yields exactly `i64::MIN`.
    #[test]
    fn test_unary_minus_i64_min() {
        use crate::ast::{CypherLiteral, Expr};
        // -9223372036854775808 → Integer(i64::MIN) (valid)
        let expr = parse_expression("-9223372036854775808").expect("-i64::MIN should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(i64::MIN)));
    }

    /// Chaining `IS NULL` with `IS NOT NULL` is rejected as InvalidPredicateChain.
    #[test]
    fn test_stacked_predicates_is_null_is_not_null() {
        // x IS NULL IS NOT NULL → error
        let result = parse("RETURN x IS NULL IS NOT NULL");
        assert!(
            result.is_err(),
            "expected parse error for stacked IS NULL IS NOT NULL"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("InvalidPredicateChain"),
            "expected InvalidPredicateChain, got: {msg}"
        );
    }

    /// Chaining two `STARTS WITH` predicates is rejected as InvalidPredicateChain.
    #[test]
    fn test_stacked_predicates_starts_with() {
        // x STARTS WITH 'a' STARTS WITH 'b' → error
        let result = parse("RETURN x STARTS WITH 'a' STARTS WITH 'b'");
        assert!(
            result.is_err(),
            "expected parse error for stacked STARTS WITH"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("InvalidPredicateChain"),
            "expected InvalidPredicateChain, got: {msg}"
        );
    }

    /// Chaining two `IN` predicates is rejected as InvalidPredicateChain.
    #[test]
    fn test_stacked_predicates_in() {
        // x IN [1] IN [true] → error
        let result = parse("RETURN x IN [1] IN [true]");
        assert!(result.is_err(), "expected parse error for stacked IN");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("InvalidPredicateChain"),
            "expected InvalidPredicateChain, got: {msg}"
        );
    }

    /// Chaining `CONTAINS` with `ENDS WITH` is rejected as InvalidPredicateChain.
    #[test]
    fn test_stacked_predicates_contains_ends_with() {
        // x CONTAINS 'a' ENDS WITH 'b' → error
        let result = parse("RETURN x CONTAINS 'a' ENDS WITH 'b'");
        assert!(
            result.is_err(),
            "expected parse error for stacked CONTAINS/ENDS WITH"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("InvalidPredicateChain"),
            "expected InvalidPredicateChain, got: {msg}"
        );
    }

    /// Stacked label predicates are not treated as an invalid predicate chain.
    #[test]
    fn test_label_stacking_allowed() {
        // x :Person :Employee → OK (label stacking is valid)
        // Note: label predicates in comparison context are valid
        assert!(
            parse("MATCH (x) WHERE x:Person:Employee RETURN x").is_ok(),
            "label stacking should be allowed"
        );
    }

    /// Chained range comparisons are not treated as an invalid predicate chain.
    #[test]
    fn test_range_chaining_allowed() {
        // 1 < n.num < 3 → OK (required by TCK Comparison3)
        assert!(
            parse("MATCH (n) WHERE 1 < n.num < 3 RETURN n").is_ok(),
            "range chaining 1 < n.num < 3 should be allowed"
        );
    }
}