zen_expression/parser/
unary.rs

1use crate::functions::{
2    ClosureFunction, DateMethod, DeprecatedFunction, FunctionKind, InternalFunction, MethodKind,
3};
4use crate::lexer::{Bracket, ComparisonOperator, Identifier, LogicalOperator, Operator, TokenKind};
5use crate::parser::ast::{AstNodeError, Node};
6use crate::parser::constants::{Associativity, BINARY_OPERATORS, UNARY_OPERATORS};
7use crate::parser::parser::{Parser, ParserContext};
8use crate::parser::unary::UnaryNodeBehaviour::CompareWithReference;
9use crate::parser::{NodeMetadata, ParserResult};
10
/// Marker type that selects the unary-mode parser, i.e. `Parser<'_, '_, Unary>`.
#[derive(Debug)]
pub struct Unary;
13
/// The implicit reference `$` (an identifier node) that unary expressions are compared against.
///
/// It is used as the left-hand side when an expression pair starts with a comparison
/// operator, and when a lone node is rewritten into a comparison with the reference.
const ROOT_NODE: Node<'static> = Node::Identifier("$");
15
/// Unary-mode parsing.
///
/// In unary mode every top-level expression pair is evaluated against the implicit
/// reference `$` (`ROOT_NODE`). A pair that starts with a comparison operator
/// (such as `> 5`) gets `$` as its left-hand side. A lone node is rewritten
/// according to `UnaryNodeBehaviour`. Pairs are joined by `and`, `or` or `,`.
impl<'arena, 'token_ref> Parser<'arena, 'token_ref, Unary> {
    /// Parses the token stream as a unary-mode expression.
    ///
    /// `is_complete` mirrors `self.is_done()` after parsing. `metadata` is mapped out
    /// of `self.node_metadata` (a clone of it, unwrapped via `into_inner`).
    pub fn parse(&self) -> ParserResult<'arena> {
        let root = self.root_expression();

        ParserResult {
            root,
            is_complete: self.is_done(),
            metadata: self.node_metadata.clone().map(|t| t.into_inner()),
        }
    }

    /// Parses one or more expression pairs joined at the top level.
    ///
    /// `and` joins with logical AND. Both `or` and `,` join with logical OR, so
    /// `1, 2` is parsed like `1 or 2`. Pairs are folded left-to-right into nested
    /// `Node::Binary` nodes. Any other token between pairs makes this return
    /// `self.error(..)` with an "Invalid join operator" message.
    fn root_expression(&self) -> &'arena Node<'arena> {
        let mut left_node = self.expression_pair();

        while !self.is_done() {
            let Some(current_token) = self.current() else {
                break;
            };

            // At the top level a comma is sugar for `or`.
            let join_operator = match &current_token.kind {
                TokenKind::Operator(Operator::Logical(LogicalOperator::And)) => {
                    Operator::Logical(LogicalOperator::And)
                }
                TokenKind::Operator(Operator::Logical(LogicalOperator::Or))
                | TokenKind::Operator(Operator::Comma) => Operator::Logical(LogicalOperator::Or),
                _ => {
                    return self.error(AstNodeError::Custom {
                        message: self.bump.alloc_str(
                            format!("Invalid join operator `{}`", current_token.kind).as_str(),
                        ),
                        span: current_token.span,
                    })
                }
            };

            self.next();
            let right_node = self.expression_pair();
            left_node = self.node(
                Node::Binary {
                    left: left_node,
                    operator: join_operator,
                    right: right_node,
                },
                |h| NodeMetadata {
                    span: h.span(left_node, right_node).unwrap_or_default(),
                },
            );
        }

        left_node
    }

    /// Parses a single expression pair and ties it to the reference `$`.
    ///
    /// - If the pair starts with a comparison operator, the left side is `$`.
    /// - If a comparison operator follows the left side, a `Node::Binary` comparison
    ///   with the parsed right side is built.
    /// - Otherwise the lone node is rewritten per `UnaryNodeBehaviour`: either
    ///   `$ <op> node` (`op` being `==` or `in`), or `bool(node)`.
    ///
    /// Every span built here starts at the token that was current on entry.
    fn expression_pair(&self) -> &'arena Node<'arena> {
        let mut left_node = &ROOT_NODE;
        // Captured before parsing so the resulting span covers the whole pair.
        let current_token = self.current();

        if let Some(TokenKind::Operator(Operator::Comparison(_))) = self.current_kind() {
            // Leading comparison operator: keep `$` (ROOT_NODE) as the left side.
        } else {
            left_node = self.binary_expression(0, ParserContext::Global);
        }

        match self.current_kind() {
            Some(TokenKind::Operator(Operator::Comparison(comparison))) => {
                self.next();
                let right_node = self.binary_expression(0, ParserContext::Global);
                left_node = self.node(
                    Node::Binary {
                        left: left_node,
                        operator: Operator::Comparison(*comparison),
                        right: right_node,
                    },
                    |h| NodeMetadata {
                        span: (
                            current_token.map(|t| t.span.0).unwrap_or_default(),
                            h.metadata(right_node).map(|n| n.span.1).unwrap_or_default(),
                        ),
                    },
                );
            }
            _ => {
                let behaviour = UnaryNodeBehaviour::from(left_node);
                match behaviour {
                    // `$ <comparator> node`: the reference goes on the left.
                    CompareWithReference(comparator) => {
                        left_node = self.node(
                            Node::Binary {
                                left: &ROOT_NODE,
                                operator: Operator::Comparison(comparator),
                                right: left_node,
                            },
                            |h| NodeMetadata {
                                span: (
                                    current_token.map(|t| t.span.0).unwrap_or_default(),
                                    h.metadata(left_node).map(|n| n.span.1).unwrap_or_default(),
                                ),
                            },
                        )
                    }
                    // `bool(node)`: no implicit comparison with `$`.
                    UnaryNodeBehaviour::AsBoolean => {
                        left_node = self.node(
                            Node::FunctionCall {
                                kind: FunctionKind::Internal(InternalFunction::Bool),
                                arguments: self.bump.alloc_slice_clone(&[left_node]),
                            },
                            |h| NodeMetadata {
                                span: (
                                    current_token.map(|t| t.span.0).unwrap_or_default(),
                                    h.metadata(left_node).map(|n| n.span.1).unwrap_or_default(),
                                ),
                            },
                        )
                    }
                }
            }
        }

        left_node
    }

    /// Precedence-climbing parser for binary operators.
    ///
    /// The loop consumes only operators found in `BINARY_OPERATORS` whose precedence
    /// is >= `precedence`. For left-associative operators the right side is parsed at
    /// `precedence + 1`; for all others it is parsed at the same precedence. In
    /// `ParserContext::Global` the loop stops at `,`, `and` and `or`, leaving them
    /// to `root_expression`. At `precedence == 0`, `self.conditional` is also tried on
    /// the result (presumably the `? :` ternary, which yields `Node::Conditional`).
    // NOTE(review): with the "stack-protection" feature the function is wrapped by
    // `recursive::recursive` — presumably to survive deep recursion; confirm.
    #[cfg_attr(feature = "stack-protection", recursive::recursive)]
    fn binary_expression(&self, precedence: u8, ctx: ParserContext) -> &'arena Node<'arena> {
        let mut node_left = self.unary_expression();
        let Some(mut token) = self.current() else {
            return node_left;
        };

        while let TokenKind::Operator(operator) = &token.kind {
            if self.is_done() {
                break;
            }

            // Top-level joiners belong to `root_expression`, not to this pair.
            if ctx == ParserContext::Global
                && matches!(
                    operator,
                    Operator::Comma
                        | Operator::Logical(LogicalOperator::And)
                        | Operator::Logical(LogicalOperator::Or)
                )
            {
                break;
            }

            let Some(op) = BINARY_OPERATORS.get(operator) else {
                break;
            };

            if op.precedence < precedence {
                break;
            }

            self.next();
            let node_right = match op.associativity {
                Associativity::Left => {
                    self.binary_expression(op.precedence + 1, ParserContext::Global)
                }
                _ => self.binary_expression(op.precedence, ParserContext::Global),
            };

            node_left = self.node(
                Node::Binary {
                    operator: *operator,
                    left: node_left,
                    right: node_right,
                },
                |h| NodeMetadata {
                    span: h.span(node_left, node_right).unwrap_or_default(),
                },
            );

            let Some(t) = self.current() else {
                break;
            };
            token = t;
        }

        if precedence == 0 {
            if let Some(conditional_node) =
                self.conditional(node_left, &|c| self.binary_expression(0, c))
            {
                node_left = conditional_node;
            }
        }

        node_left
    }

    /// Parses a prefix or primary expression.
    ///
    /// The alternatives are tried in this order:
    /// 1. the callback reference (only while `self.depth() > 0`), which becomes `Node::Pointer`;
    /// 2. a unary operator applied to an operand parsed at that operator's precedence;
    /// 3. an interval;
    /// 4. a parenthesized expression;
    /// 5. a literal.
    ///
    /// Callback references and parenthesized expressions are passed through `self.with_postfix`.
    /// An operator token that is not in `UNARY_OPERATORS` yields an `UnexpectedToken` error.
    fn unary_expression(&self) -> &'arena Node<'arena> {
        let Some(token) = self.current() else {
            // No token left: defer to `literal`.
            return self.literal(&|c| self.binary_expression(0, c));
        };

        if self.depth() > 0 && token.kind == TokenKind::Identifier(Identifier::CallbackReference) {
            self.next();

            let node = self.node(Node::Pointer, |_| NodeMetadata { span: token.span });
            let (n, _) = self.with_postfix(node, &|c| self.binary_expression(0, c));
            return n;
        }

        if let TokenKind::Operator(operator) = &token.kind {
            let Some(unary_operator) = UNARY_OPERATORS.get(operator) else {
                return self.error(AstNodeError::UnexpectedToken {
                    expected: self.bump.alloc_str("UnaryOperator"),
                    received: self.bump.alloc_str(token.kind.to_string().as_str()),
                    span: token.span,
                });
            };

            self.next();
            // The operand only absorbs binary operators binding at least as tightly.
            let expr = self.binary_expression(unary_operator.precedence, ParserContext::Global);
            let node = self.node(
                Node::Unary {
                    operator: *operator,
                    node: expr,
                },
                |h| NodeMetadata {
                    span: (
                        token.span.0,
                        h.metadata(expr).map(|n| n.span.1).unwrap_or_default(),
                    ),
                },
            );

            return node;
        }

        if let Some(interval_node) = self.interval(&|c| self.binary_expression(0, c)) {
            return interval_node;
        }

        if token.kind == TokenKind::Bracket(Bracket::LeftParenthesis) {
            // Start of the `(` token; it is still the current token here.
            let p_start = self.current().map(|s| s.span.0);

            self.next();
            // NOTE(review): Global context stops at a top-level `and`/`or`/`,`, so such an
            // operator inside the parentheses ends the inner expression before `)` is
            // expected — confirm this is intended.
            let binary_node = self.binary_expression(0, ParserContext::Global);
            if let Some(error_node) = self.expect(TokenKind::Bracket(Bracket::RightParenthesis)) {
                return error_node;
            };

            let expr = self.node(Node::Parenthesized(binary_node), |_| NodeMetadata {
                span: (p_start.unwrap_or_default(), self.prev_token_end()),
            });

            let (n, _) = self.with_postfix(expr, &|c| self.binary_expression(0, c));
            return n;
        }

        self.literal(&|c| self.binary_expression(0, c))
    }
}
266
/// Dictates how a standalone node is interpreted in unary mode.
///
/// With `CompareWithReference`, the node is compared against the reference `$` using the
/// given operator. For the `Equal` operator this yields `$ == nodeValue`; for the `In`
/// operator it yields `$ in nodeValue`.
///
/// With `AsBoolean`, the node value is cast to a boolean and no comparison with the
/// reference (`$`) is made. The reference can still be used explicitly in that case,
/// e.g. `contains($, 'hello')`.
///
/// The rationale is to avoid surprising results in cases such as `$ = false` with the
/// expression `contains($, 'needle')`. If the node were compared with the reference, the
/// expression would be reduced to `$ == contains($, 'needle')`. That would be truthy when
/// `$` does not contain the needle.
#[derive(Debug, PartialEq)]
enum UnaryNodeBehaviour {
    /// Rewrite the node as `$ <operator> node`.
    CompareWithReference(ComparisonOperator),
    /// Rewrite the node as `bool(node)`, without an implicit comparison with `$`.
    AsBoolean,
}
284
285impl From<&Node<'_>> for UnaryNodeBehaviour {
286    fn from(value: &Node) -> Self {
287        use ComparisonOperator::*;
288        use UnaryNodeBehaviour::*;
289
290        match value {
291            Node::Null => CompareWithReference(Equal),
292            Node::Root => CompareWithReference(Equal),
293            Node::Bool(_) => CompareWithReference(Equal),
294            Node::Number(_) => CompareWithReference(Equal),
295            Node::String(_) => CompareWithReference(Equal),
296            Node::TemplateString(_) => CompareWithReference(Equal),
297            Node::Object(_) => CompareWithReference(Equal),
298            Node::Assignments { .. } => CompareWithReference(Equal),
299            Node::Pointer => AsBoolean,
300            Node::Array(_) => CompareWithReference(In),
301            Node::Identifier(_) => CompareWithReference(Equal),
302            Node::Closure(_) => AsBoolean,
303            Node::Member { .. } => CompareWithReference(Equal),
304            Node::Slice { .. } => CompareWithReference(In),
305            Node::Interval { .. } => CompareWithReference(In),
306            Node::Conditional {
307                on_true, on_false, ..
308            } => {
309                let a = UnaryNodeBehaviour::from(*on_true);
310                let b = UnaryNodeBehaviour::from(*on_false);
311
312                if a == b {
313                    a
314                } else {
315                    CompareWithReference(Equal)
316                }
317            }
318            Node::Unary { node, .. } => UnaryNodeBehaviour::from(*node),
319            Node::Parenthesized(n) => UnaryNodeBehaviour::from(*n),
320            Node::Binary {
321                left,
322                operator,
323                right,
324            } => match operator {
325                Operator::Arithmetic(_) => {
326                    let a = UnaryNodeBehaviour::from(*left);
327                    let b = UnaryNodeBehaviour::from(*right);
328
329                    if a == b {
330                        a
331                    } else {
332                        CompareWithReference(Equal)
333                    }
334                }
335                Operator::Logical(_) => AsBoolean,
336                Operator::Comparison(_) => AsBoolean,
337                Operator::Range => CompareWithReference(In),
338                Operator::Slice => CompareWithReference(In),
339                Operator::Comma => AsBoolean,
340                Operator::Dot => AsBoolean,
341                Operator::QuestionMark => AsBoolean,
342                Operator::Assign => AsBoolean,
343                Operator::Semi => AsBoolean,
344            },
345            Node::FunctionCall { kind, .. } => match kind {
346                FunctionKind::Internal(i) => match i {
347                    InternalFunction::Len => CompareWithReference(Equal),
348                    InternalFunction::Upper => CompareWithReference(Equal),
349                    InternalFunction::Lower => CompareWithReference(Equal),
350                    InternalFunction::Trim => CompareWithReference(Equal),
351                    InternalFunction::Abs => CompareWithReference(Equal),
352                    InternalFunction::Sum => CompareWithReference(Equal),
353                    InternalFunction::Avg => CompareWithReference(Equal),
354                    InternalFunction::Min => CompareWithReference(Equal),
355                    InternalFunction::Max => CompareWithReference(Equal),
356                    InternalFunction::Rand => CompareWithReference(Equal),
357                    InternalFunction::Median => CompareWithReference(Equal),
358                    InternalFunction::Mode => CompareWithReference(Equal),
359                    InternalFunction::Floor => CompareWithReference(Equal),
360                    InternalFunction::Ceil => CompareWithReference(Equal),
361                    InternalFunction::Round => CompareWithReference(Equal),
362                    InternalFunction::Trunc => CompareWithReference(Equal),
363                    InternalFunction::String => CompareWithReference(Equal),
364                    InternalFunction::Number => CompareWithReference(Equal),
365                    InternalFunction::Bool => CompareWithReference(Equal),
366                    InternalFunction::Flatten => CompareWithReference(In),
367                    InternalFunction::Extract => CompareWithReference(In),
368                    InternalFunction::Contains => AsBoolean,
369                    InternalFunction::StartsWith => AsBoolean,
370                    InternalFunction::EndsWith => AsBoolean,
371                    InternalFunction::Matches => AsBoolean,
372                    InternalFunction::FuzzyMatch => CompareWithReference(Equal),
373                    InternalFunction::Split => CompareWithReference(In),
374                    InternalFunction::IsNumeric => AsBoolean,
375                    InternalFunction::Keys => CompareWithReference(In),
376                    InternalFunction::Values => CompareWithReference(In),
377                    InternalFunction::Type => CompareWithReference(Equal),
378                    InternalFunction::Date => CompareWithReference(Equal),
379                },
380                FunctionKind::Deprecated(d) => match d {
381                    DeprecatedFunction::Date => CompareWithReference(Equal),
382                    DeprecatedFunction::Time => CompareWithReference(Equal),
383                    DeprecatedFunction::Duration => CompareWithReference(Equal),
384                    DeprecatedFunction::Year => CompareWithReference(Equal),
385                    DeprecatedFunction::DayOfWeek => CompareWithReference(Equal),
386                    DeprecatedFunction::DayOfMonth => CompareWithReference(Equal),
387                    DeprecatedFunction::DayOfYear => CompareWithReference(Equal),
388                    DeprecatedFunction::WeekOfYear => CompareWithReference(Equal),
389                    DeprecatedFunction::MonthOfYear => CompareWithReference(Equal),
390                    DeprecatedFunction::MonthString => CompareWithReference(Equal),
391                    DeprecatedFunction::DateString => CompareWithReference(Equal),
392                    DeprecatedFunction::WeekdayString => CompareWithReference(Equal),
393                    DeprecatedFunction::StartOf => CompareWithReference(Equal),
394                    DeprecatedFunction::EndOf => CompareWithReference(Equal),
395                },
396                FunctionKind::Closure(c) => match c {
397                    ClosureFunction::All => AsBoolean,
398                    ClosureFunction::Some => AsBoolean,
399                    ClosureFunction::None => AsBoolean,
400                    ClosureFunction::One => AsBoolean,
401                    ClosureFunction::Filter => CompareWithReference(In),
402                    ClosureFunction::Map => CompareWithReference(In),
403                    ClosureFunction::FlatMap => CompareWithReference(In),
404                    ClosureFunction::Count => CompareWithReference(Equal),
405                },
406            },
407            Node::MethodCall { kind, .. } => match kind {
408                MethodKind::DateMethod(dm) => match dm {
409                    DateMethod::Add => CompareWithReference(Equal),
410                    DateMethod::Sub => CompareWithReference(Equal),
411                    DateMethod::Format => CompareWithReference(Equal),
412                    DateMethod::Month => CompareWithReference(Equal),
413                    DateMethod::Year => CompareWithReference(Equal),
414                    DateMethod::Set => CompareWithReference(Equal),
415                    DateMethod::StartOf => CompareWithReference(Equal),
416                    DateMethod::EndOf => CompareWithReference(Equal),
417                    DateMethod::Diff => CompareWithReference(Equal),
418                    DateMethod::Tz => CompareWithReference(Equal),
419                    DateMethod::Second => CompareWithReference(Equal),
420                    DateMethod::Minute => CompareWithReference(Equal),
421                    DateMethod::Hour => CompareWithReference(Equal),
422                    DateMethod::Day => CompareWithReference(Equal),
423                    DateMethod::DayOfYear => CompareWithReference(Equal),
424                    DateMethod::Week => CompareWithReference(Equal),
425                    DateMethod::Weekday => CompareWithReference(Equal),
426                    DateMethod::Quarter => CompareWithReference(Equal),
427                    DateMethod::Timestamp => CompareWithReference(Equal),
428                    DateMethod::OffsetName => CompareWithReference(Equal),
429                    DateMethod::IsSame => AsBoolean,
430                    DateMethod::IsBefore => AsBoolean,
431                    DateMethod::IsAfter => AsBoolean,
432                    DateMethod::IsSameOrBefore => AsBoolean,
433                    DateMethod::IsSameOrAfter => AsBoolean,
434                    DateMethod::IsValid => AsBoolean,
435                    DateMethod::IsYesterday => AsBoolean,
436                    DateMethod::IsToday => AsBoolean,
437                    DateMethod::IsTomorrow => AsBoolean,
438                    DateMethod::IsLeapYear => AsBoolean,
439                },
440            },
441            Node::Error { .. } => AsBoolean,
442        }
443    }
444}