xee_xpath_ast/parser/
pattern.rs

1use chumsky::{input::ValueInput, prelude::*};
2use std::borrow::Cow;
3use xot::xmlname::NameStrInfo;
4
5use xee_xpath_lexer::Token;
6
7use crate::ast::Span;
8use crate::{ast, WithSpan, FN_NAMESPACE};
9use crate::{pattern, Namespaces, ParserError, VariableNames};
10
11use super::axis_node_test::parser_axis_node_test;
12use super::name::parser_name;
13use super::parser_core::parser as xpath_parser;
14use super::primary::parser_primary;
15use super::{parse, tokens};
16
17use super::types::BoxedParser;
18
19#[derive(Clone)]
20pub(crate) struct PatternParserOutput<'a, I>
21where
22    I: ValueInput<'a, Token = Token<'a>, Span = Span>,
23{
24    pub(crate) pattern: BoxedParser<'a, I, pattern::Pattern<ast::ExprS>>,
25}
26
27pub(crate) fn parser<'a, I>() -> PatternParserOutput<'a, I>
28where
29    I: ValueInput<'a, Token = Token<'a>, Span = Span>,
30{
31    let xpath_parser_output = xpath_parser();
32    let expr_single = xpath_parser_output.expr_single_core;
33    let name_output = parser_name();
34    let name = name_output.eqname;
35    let parser_primary_output = parser_primary(name.clone());
36    let literal = parser_primary_output.literal;
37    let var_ref = parser_primary_output.var_ref;
38    let parser_axis_node_test_output =
39        parser_axis_node_test(name.clone(), xpath_parser_output.kind_test);
40    let node_test = parser_axis_node_test_output.node_test;
41    let abbrev_forward_step = parser_axis_node_test_output.abbrev_forward_step;
42
43    // HACK: a bit of repetition here to produce predicate_list, as getting it out
44    // of the xpath parser seems to lead to recursive parser errors
45    let expr = expr_single
46        .clone()
47        .separated_by(just(Token::Comma))
48        .at_least(1)
49        .collect::<Vec<_>>()
50        .map_with(|exprs, extra| ast::Expr(exprs).with_span(extra.span()))
51        .boxed();
52    let predicate = expr
53        .clone()
54        .delimited_by(just(Token::LeftBracket), just(Token::RightBracket))
55        .boxed();
56    let predicate_list = predicate.repeated().collect::<Vec<_>>().boxed();
57
58    let predicate_pattern = (just(Token::Dot).ignore_then(predicate_list.clone()))
59        .map(|predicates| pattern::PredicatePattern { predicates })
60        .boxed();
61
62    let outer_function_name = name.try_map(|name, span| {
63        let name = name.value;
64        if name.namespace() == FN_NAMESPACE || name.namespace().is_empty() {
65            {
66                match name.local_name() {
67                    "doc" => Ok(pattern::OuterFunctionName::Doc),
68                    "id" => Ok(pattern::OuterFunctionName::Id),
69                    "element-with-id" => Ok(pattern::OuterFunctionName::ElementWithId),
70                    "key" => Ok(pattern::OuterFunctionName::Key),
71                    "root" => Ok(pattern::OuterFunctionName::Root),
72                    _ => Err(ParserError::IllegalFunctionInPattern { name, span }),
73                }
74            }
75        } else {
76            Err(ParserError::IllegalFunctionInPattern { name, span })
77        }
78    });
79
80    let argument = var_ref
81        .clone()
82        .map(|var_ref| {
83            if let ast::PrimaryExpr::VarRef(name) = var_ref.value {
84                pattern::Argument::VarRef(name)
85            } else {
86                unreachable!()
87            }
88        })
89        .or(literal.map(|literal| {
90            if let ast::PrimaryExpr::Literal(literal) = literal.value {
91                pattern::Argument::Literal(literal)
92            } else {
93                unreachable!()
94            }
95        }));
96
97    let argument_list = (argument.separated_by(just(Token::Comma)))
98        .at_least(1)
99        .collect::<Vec<_>>()
100        .delimited_by(just(Token::LeftParen), just(Token::RightParen))
101        .boxed();
102
103    let function_call = outer_function_name.then(argument_list).boxed();
104
105    let rooted_var_ref = var_ref.map(|var_ref| {
106        if let ast::PrimaryExpr::VarRef(name) = var_ref.value {
107            pattern::RootExpr::VarRef(name)
108        } else {
109            unreachable!()
110        }
111    });
112
113    let rooted_function_call = function_call
114        .map(|(name, args)| pattern::RootExpr::FunctionCall(pattern::FunctionCall { name, args }));
115
116    let rooted_path_start = rooted_function_call.or(rooted_var_ref).boxed();
117
118    let slash_or_double_slash = just(Token::Slash).or(just(Token::DoubleSlash));
119
120    let expr_pattern = recursive(|expr_pattern| {
121        let parenthesized_expr = expr_pattern
122            .delimited_by(just(Token::LeftParen), just(Token::RightParen))
123            .boxed();
124
125        let postfix_expr = parenthesized_expr.then(predicate_list.clone()).boxed();
126
127        let forward_axis = (just(Token::Child)
128            .or(just(Token::Descendant))
129            .or(just(Token::Attribute))
130            .or(just(Token::Self_))
131            .or(just(Token::DescendantOrSelf))
132            .or(just(Token::Namespace)))
133        .then_ignore(just(Token::DoubleColon))
134        .map(|token| match token {
135            Token::Child => pattern::ForwardAxis::Child,
136            Token::Descendant => pattern::ForwardAxis::Descendant,
137            Token::Attribute => pattern::ForwardAxis::Attribute,
138            Token::Self_ => pattern::ForwardAxis::Self_,
139            Token::DescendantOrSelf => pattern::ForwardAxis::DescendantOrSelf,
140            Token::Namespace => pattern::ForwardAxis::Namespace,
141            _ => unreachable!(),
142        })
143        .boxed();
144
145        let forward_step_axis_node_test = forward_axis.then(node_test);
146        let forward_step_abbrev = abbrev_forward_step.map(|(axis, node_test)| {
147            let axis = match axis {
148                ast::Axis::Attribute => pattern::ForwardAxis::Attribute,
149                ast::Axis::Child => pattern::ForwardAxis::Child,
150                _ => unreachable!(),
151            };
152            (axis, node_test)
153        });
154
155        let forward_step = forward_step_axis_node_test.or(forward_step_abbrev);
156
157        let axis_step = forward_step.then(predicate_list.clone());
158
159        let step_expr = postfix_expr
160            .map(|(expr, predicates)| {
161                pattern::StepExpr::PostfixExpr(pattern::PostfixExpr { expr, predicates })
162            })
163            .or(axis_step.map(|((axis, node_test), predicates)| {
164                pattern::StepExpr::AxisStep(pattern::AxisStep {
165                    forward: axis,
166                    node_test,
167                    predicates,
168                })
169            }))
170            .boxed();
171
172        let relative_path_expr = step_expr
173            .clone()
174            .then(
175                (slash_or_double_slash.then(step_expr))
176                    .repeated()
177                    .collect::<Vec<_>>(),
178            )
179            .map(|(first_step, rest_steps)| {
180                let mut steps = vec![first_step];
181                for (token, step) in rest_steps {
182                    match token {
183                        Token::Slash => {}
184                        Token::DoubleSlash => {
185                            let axis_step = pattern::AxisStep {
186                                forward: pattern::ForwardAxis::DescendantOrSelf,
187                                node_test: ast::NodeTest::KindTest(ast::KindTest::Any),
188                                predicates: vec![],
189                            };
190                            steps.push(pattern::StepExpr::AxisStep(axis_step));
191                        }
192                        _ => unreachable!(),
193                    }
194                    steps.push(step);
195                }
196                steps
197            })
198            .boxed();
199
200        let rooted_path = rooted_path_start
201            .then(predicate_list)
202            .then(
203                (just(Token::Slash)
204                    .or(just(Token::DoubleSlash))
205                    .then(relative_path_expr.clone()))
206                .or_not(),
207            )
208            .map(|((root, predicates), token_relative_steps)| {
209                let steps = if let Some((token, relative_steps)) = token_relative_steps {
210                    match token {
211                        Token::Slash => relative_steps,
212                        Token::DoubleSlash => {
213                            let axis_step = pattern::AxisStep {
214                                forward: pattern::ForwardAxis::DescendantOrSelf,
215                                node_test: ast::NodeTest::KindTest(ast::KindTest::Any),
216                                predicates: vec![],
217                            };
218                            let mut steps = vec![pattern::StepExpr::AxisStep(axis_step)];
219                            steps.extend(relative_steps);
220                            steps
221                        }
222                        _ => unreachable!(),
223                    }
224                } else {
225                    vec![]
226                };
227                pattern::PathExpr {
228                    root: pattern::PathRoot::Rooted { root, predicates },
229                    steps,
230                }
231            });
232        let absolute_slash_path = just(Token::Slash)
233            .ignore_then(relative_path_expr.clone().or_not())
234            .map(|steps| pattern::PathExpr {
235                root: pattern::PathRoot::AbsoluteSlash,
236                steps: steps.unwrap_or_default(),
237            });
238        let absolute_double_slash_path = just(Token::DoubleSlash)
239            .ignore_then(relative_path_expr.clone())
240            .map(|steps| pattern::PathExpr {
241                root: pattern::PathRoot::AbsoluteDoubleSlash,
242                steps,
243            });
244        let relative_path = relative_path_expr.map(|steps| {
245            // shortcut to create an absolute path if that's possible.
246            // The use of parenthesized expr can otherwise turn stuff into
247            // a postfix expr even though it's actually a simple path expr
248            if steps.len() == 1 {
249                if let pattern::StepExpr::PostfixExpr(postfix_expr) = &steps[0] {
250                    if postfix_expr.predicates.is_empty() {
251                        if let pattern::ExprPattern::Path(path_expr) = &postfix_expr.expr {
252                            return path_expr.clone();
253                        }
254                    }
255                }
256            }
257            pattern::PathExpr {
258                root: pattern::PathRoot::Relative,
259                steps,
260            }
261        });
262
263        let path_expr = absolute_slash_path
264            .or(absolute_double_slash_path)
265            .or(relative_path)
266            .or(rooted_path)
267            .boxed();
268
269        let operator = just(Token::Intersect)
270            .or(just(Token::Except))
271            .or(just(Token::Union))
272            .or(just(Token::Pipe))
273            .map(|token| match token {
274                Token::Intersect => pattern::Operator::Intersect,
275                Token::Except => pattern::Operator::Except,
276                Token::Union => pattern::Operator::Union,
277                Token::Pipe => pattern::Operator::Union,
278                _ => unreachable!(),
279            });
280
281        let expr_pattern = (path_expr.clone().map(pattern::ExprPattern::Path))
282            .foldl(
283                operator.then(path_expr.clone()).repeated(),
284                |left, (operator, right)| {
285                    pattern::ExprPattern::BinaryExpr(pattern::BinaryExpr {
286                        operator,
287                        left: Box::new(left),
288                        right: Box::new(pattern::ExprPattern::Path(right)),
289                    })
290                },
291            )
292            .boxed();
293
294        expr_pattern
295    })
296    .boxed();
297
298    let predicate_pattern = predicate_pattern
299        .then_ignore(end())
300        .map(pattern::Pattern::Predicate)
301        .boxed();
302
303    let union_pattern = expr_pattern
304        .then_ignore(end())
305        .map(pattern::Pattern::Expr)
306        .boxed();
307
308    let pattern = predicate_pattern.or(union_pattern).boxed();
309
310    PatternParserOutput { pattern }
311}
312
313impl pattern::Pattern<ast::ExprS> {
314    pub fn parse<'a>(
315        input: &'a str,
316        namespaces: &'a Namespaces,
317        _variable_names: &'a VariableNames,
318    ) -> Result<Self, ParserError> {
319        let pattern = parse(parser().pattern, tokens(input), Cow::Borrowed(namespaces))?;
320        // TODO: do we need to rename variables to unique names? probably
321        Ok(pattern)
322    }
323}
324
#[cfg(test)]
mod tests {
    use ahash::HashSetExt;
    use insta::assert_ron_snapshot;

    use super::*;

    /// Parse `input` as a pattern with default namespaces and an empty set of
    /// variable names.
    ///
    /// The snapshot assertion deliberately stays in each test function
    /// (insta derives the snapshot name from the calling function).
    fn parse_pattern(input: &str) -> Result<pattern::Pattern<ast::ExprS>, ParserError> {
        let namespaces = Namespaces::default();
        let variable_names = VariableNames::new();
        pattern::Pattern::parse(input, &namespaces, &variable_names)
    }

    #[test]
    fn test_predicate_pattern_no_predicates() {
        assert_ron_snapshot!(parse_pattern("."));
    }

    #[test]
    fn test_predicate_pattern_single_predicate() {
        assert_ron_snapshot!(parse_pattern(".[1]"));
    }

    #[test]
    fn test_expr_pattern() {
        assert_ron_snapshot!(parse_pattern("$a | $b"));
    }

    #[test]
    fn test_expr_pattern_rooted_path() {
        assert_ron_snapshot!(parse_pattern("$a/foo"));
    }

    #[test]
    fn test_expr_pattern_absolute_slash() {
        assert_ron_snapshot!(parse_pattern("/foo"));
    }

    #[test]
    fn test_expr_pattern_absolute_double_slash() {
        assert_ron_snapshot!(parse_pattern("//foo"));
    }

    #[test]
    fn test_absolute_slash_without_steps() {
        assert_ron_snapshot!(parse_pattern("/"));
    }

    #[test]
    fn test_absolute_slash_without_steps_in_parenthesis() {
        assert_ron_snapshot!(parse_pattern("(/)"));
    }

    #[test]
    fn test_expr_pattern_relative() {
        assert_ron_snapshot!(parse_pattern("foo"));
    }

    #[test]
    fn test_postfix_expr() {
        assert_ron_snapshot!(parse_pattern("foo[1]"));
    }

    #[test]
    fn test_union() {
        assert_ron_snapshot!(parse_pattern("foo | bar"));
    }

    #[test]
    fn test_intersect() {
        assert_ron_snapshot!(parse_pattern("foo intersect bar"));
    }

    #[test]
    fn test_union_with_intersect() {
        assert_ron_snapshot!(parse_pattern("foo intersect bar | baz"));
    }

    #[test]
    fn test_union_with_union() {
        assert_ron_snapshot!(parse_pattern("foo | (bar | baz)"));
    }

    #[test]
    fn test_intersect_with_union() {
        assert_ron_snapshot!(parse_pattern("foo intersect (bar | baz)"));
    }

    #[test]
    fn test_root_intersect_with_other_path() {
        // The parentheses are required here: without them `intersect` is
        // interpreted as an element name, as per XPath rules.
        assert_ron_snapshot!(parse_pattern("(/) intersect foo"));
    }
}
491        ));
492    }
493}