zen_expression/parser/
unary.rs

1use crate::functions::{ClosureFunction, DeprecatedFunction, FunctionKind, InternalFunction};
2use crate::lexer::{Bracket, ComparisonOperator, Identifier, LogicalOperator, Operator, TokenKind};
3use crate::parser::ast::{AstNodeError, Node};
4use crate::parser::constants::{Associativity, BINARY_OPERATORS, UNARY_OPERATORS};
5use crate::parser::parser::{Parser, ParserContext};
6use crate::parser::unary::UnaryNodeBehaviour::CompareWithReference;
7use crate::parser::{NodeMetadata, ParserResult};
8
/// Marker type that selects the unary-mode implementation of `Parser`
/// (see `impl Parser<'arena, 'token_ref, Unary>` in this file).
#[derive(Debug)]
pub struct Unary;
11
12const ROOT_NODE: Node<'static> = Node::Identifier("$");
13
14impl<'arena, 'token_ref> Parser<'arena, 'token_ref, Unary> {
15    pub fn parse(&self) -> ParserResult<'arena> {
16        let root = self.root_expression();
17
18        ParserResult {
19            root,
20            is_complete: self.is_done(),
21            metadata: self.node_metadata.clone().map(|t| t.into_inner()),
22        }
23    }
24
25    fn root_expression(&self) -> &'arena Node<'arena> {
26        let mut left_node = self.expression_pair();
27
28        while !self.is_done() {
29            let Some(current_token) = self.current() else {
30                break;
31            };
32
33            let join_operator = match &current_token.kind {
34                TokenKind::Operator(Operator::Logical(LogicalOperator::And)) => {
35                    Operator::Logical(LogicalOperator::And)
36                }
37                TokenKind::Operator(Operator::Logical(LogicalOperator::Or))
38                | TokenKind::Operator(Operator::Comma) => Operator::Logical(LogicalOperator::Or),
39                _ => {
40                    return self.error(AstNodeError::Custom {
41                        message: self.bump.alloc_str(
42                            format!("Invalid join operator `{}`", current_token.kind).as_str(),
43                        ),
44                        span: current_token.span,
45                    })
46                }
47            };
48
49            self.next();
50            let right_node = self.expression_pair();
51            left_node = self.node(
52                Node::Binary {
53                    left: left_node,
54                    operator: join_operator,
55                    right: right_node,
56                },
57                |h| NodeMetadata {
58                    span: h.span(left_node, right_node).unwrap_or_default(),
59                },
60            );
61        }
62
63        left_node
64    }
65
66    fn expression_pair(&self) -> &'arena Node<'arena> {
67        let mut left_node = &ROOT_NODE;
68        let current_token = self.current();
69
70        if let Some(TokenKind::Operator(Operator::Comparison(_))) = self.current_kind() {
71            // Skips
72        } else {
73            left_node = self.binary_expression(0, ParserContext::Global);
74        }
75
76        match self.current_kind() {
77            Some(TokenKind::Operator(Operator::Comparison(comparison))) => {
78                self.next();
79                let right_node = self.binary_expression(0, ParserContext::Global);
80                left_node = self.node(
81                    Node::Binary {
82                        left: left_node,
83                        operator: Operator::Comparison(*comparison),
84                        right: right_node,
85                    },
86                    |h| NodeMetadata {
87                        span: (
88                            current_token.map(|t| t.span.0).unwrap_or_default(),
89                            h.metadata(right_node).map(|n| n.span.1).unwrap_or_default(),
90                        ),
91                    },
92                );
93            }
94            _ => {
95                let behaviour = UnaryNodeBehaviour::from(left_node);
96                match behaviour {
97                    CompareWithReference(comparator) => {
98                        left_node = self.node(
99                            Node::Binary {
100                                left: &ROOT_NODE,
101                                operator: Operator::Comparison(comparator),
102                                right: left_node,
103                            },
104                            |h| NodeMetadata {
105                                span: (
106                                    current_token.map(|t| t.span.0).unwrap_or_default(),
107                                    h.metadata(left_node).map(|n| n.span.1).unwrap_or_default(),
108                                ),
109                            },
110                        )
111                    }
112                    UnaryNodeBehaviour::AsBoolean => {
113                        left_node = self.node(
114                            Node::FunctionCall {
115                                kind: FunctionKind::Internal(InternalFunction::Bool),
116                                arguments: self.bump.alloc_slice_clone(&[left_node]),
117                            },
118                            |h| NodeMetadata {
119                                span: (
120                                    current_token.map(|t| t.span.0).unwrap_or_default(),
121                                    h.metadata(left_node).map(|n| n.span.1).unwrap_or_default(),
122                                ),
123                            },
124                        )
125                    }
126                }
127            }
128        }
129
130        left_node
131    }
132
133    fn binary_expression(&self, precedence: u8, ctx: ParserContext) -> &'arena Node<'arena> {
134        let mut node_left = self.unary_expression();
135        let Some(mut token) = self.current() else {
136            return node_left;
137        };
138
139        while let TokenKind::Operator(operator) = &token.kind {
140            if self.is_done() {
141                break;
142            }
143
144            if ctx == ParserContext::Global
145                && matches!(
146                    operator,
147                    Operator::Comma
148                        | Operator::Logical(LogicalOperator::And)
149                        | Operator::Logical(LogicalOperator::Or)
150                )
151            {
152                break;
153            }
154
155            let Some(op) = BINARY_OPERATORS.get(operator) else {
156                break;
157            };
158
159            if op.precedence < precedence {
160                break;
161            }
162
163            self.next();
164            let node_right = match op.associativity {
165                Associativity::Left => {
166                    self.binary_expression(op.precedence + 1, ParserContext::Global)
167                }
168                _ => self.binary_expression(op.precedence, ParserContext::Global),
169            };
170
171            node_left = self.node(
172                Node::Binary {
173                    operator: *operator,
174                    left: node_left,
175                    right: node_right,
176                },
177                |h| NodeMetadata {
178                    span: h.span(node_left, node_right).unwrap_or_default(),
179                },
180            );
181
182            let Some(t) = self.current() else {
183                break;
184            };
185            token = t;
186        }
187
188        if precedence == 0 {
189            if let Some(conditional_node) =
190                self.conditional(node_left, |c| self.binary_expression(0, c))
191            {
192                node_left = conditional_node;
193            }
194        }
195
196        node_left
197    }
198
199    fn unary_expression(&self) -> &'arena Node<'arena> {
200        let Some(token) = self.current() else {
201            return self.literal(|c| self.binary_expression(0, c));
202        };
203
204        if self.depth() > 0 && token.kind == TokenKind::Identifier(Identifier::CallbackReference) {
205            self.next();
206
207            let node = self.node(Node::Pointer, |_| NodeMetadata { span: token.span });
208            return self.with_postfix(node, |c| self.binary_expression(0, c));
209        }
210
211        if let TokenKind::Operator(operator) = &token.kind {
212            let Some(unary_operator) = UNARY_OPERATORS.get(operator) else {
213                return self.error(AstNodeError::UnexpectedToken {
214                    expected: self.bump.alloc_str("UnaryOperator"),
215                    received: self.bump.alloc_str(token.kind.to_string().as_str()),
216                    span: token.span,
217                });
218            };
219
220            self.next();
221            let expr = self.binary_expression(unary_operator.precedence, ParserContext::Global);
222            let node = self.node(
223                Node::Unary {
224                    operator: *operator,
225                    node: expr,
226                },
227                |h| NodeMetadata {
228                    span: (
229                        token.span.0,
230                        h.metadata(expr).map(|n| n.span.1).unwrap_or_default(),
231                    ),
232                },
233            );
234
235            return node;
236        }
237
238        if let Some(interval_node) = self.interval(|c| self.binary_expression(0, c)) {
239            return interval_node;
240        }
241
242        if token.kind == TokenKind::Bracket(Bracket::LeftParenthesis) {
243            let p_start = self.current().map(|s| s.span.0);
244
245            self.next();
246            let binary_node = self.binary_expression(0, ParserContext::Global);
247            if let Some(error_node) = self.expect(TokenKind::Bracket(Bracket::RightParenthesis)) {
248                return error_node;
249            };
250
251            let expr = self.node(Node::Parenthesized(binary_node), |_| NodeMetadata {
252                span: (p_start.unwrap_or_default(), self.prev_token_end()),
253            });
254
255            return self.with_postfix(expr, |c| self.binary_expression(0, c));
256        }
257
258        self.literal(|c| self.binary_expression(0, c))
259    }
260}
261
262/// Dictates the behaviour of nodes in unary mode.
263/// If `CompareWithReference` is set, node will attempt to make the comparison with the reference,
264/// essentially making it (in case of Equal operator) `$ == nodeValue`, or (in case of In operator)
265/// `$ in nodeValue`.
266///
267/// Using `AsBoolean` will cast the nodeValue to boolean and skip comparison with reference ($).
268/// You may still use references in such case directly, e.g. `contains($, 'hello')`.
269///
270/// Rationale behind this is to avoid scenarios where e.g. $ = false and expression is
271/// `contains($, 'needle')`. If we didn't ignore the reference, unary expression will be
272/// reduced to `$ == contains($, 'needle')` which will be truthy when $ does not
273/// contain needle.
274#[derive(Debug, PartialEq)]
275enum UnaryNodeBehaviour {
276    CompareWithReference(ComparisonOperator),
277    AsBoolean,
278}
279
280impl From<&Node<'_>> for UnaryNodeBehaviour {
281    fn from(value: &Node) -> Self {
282        use ComparisonOperator::*;
283        use UnaryNodeBehaviour::*;
284
285        match value {
286            Node::Null => CompareWithReference(Equal),
287            Node::Root => CompareWithReference(Equal),
288            Node::Bool(_) => CompareWithReference(Equal),
289            Node::Number(_) => CompareWithReference(Equal),
290            Node::String(_) => CompareWithReference(Equal),
291            Node::TemplateString(_) => CompareWithReference(Equal),
292            Node::Object(_) => CompareWithReference(Equal),
293            Node::Pointer => AsBoolean,
294            Node::Array(_) => CompareWithReference(In),
295            Node::Identifier(_) => CompareWithReference(Equal),
296            Node::Closure(_) => AsBoolean,
297            Node::Member { .. } => CompareWithReference(Equal),
298            Node::Slice { .. } => CompareWithReference(In),
299            Node::Interval { .. } => CompareWithReference(In),
300            Node::Conditional {
301                on_true, on_false, ..
302            } => {
303                let a = UnaryNodeBehaviour::from(*on_true);
304                let b = UnaryNodeBehaviour::from(*on_false);
305
306                if a == b {
307                    a
308                } else {
309                    CompareWithReference(Equal)
310                }
311            }
312            Node::Unary { node, .. } => UnaryNodeBehaviour::from(*node),
313            Node::Parenthesized(n) => UnaryNodeBehaviour::from(*n),
314            Node::Binary {
315                left,
316                operator,
317                right,
318            } => match operator {
319                Operator::Arithmetic(_) => {
320                    let a = UnaryNodeBehaviour::from(*left);
321                    let b = UnaryNodeBehaviour::from(*right);
322
323                    if a == b {
324                        a
325                    } else {
326                        CompareWithReference(Equal)
327                    }
328                }
329                Operator::Logical(_) => AsBoolean,
330                Operator::Comparison(_) => AsBoolean,
331                Operator::Range => CompareWithReference(In),
332                Operator::Slice => CompareWithReference(In),
333                Operator::Comma => AsBoolean,
334                Operator::Dot => AsBoolean,
335                Operator::QuestionMark => AsBoolean,
336            },
337            Node::FunctionCall { kind, .. } => match kind {
338                FunctionKind::Internal(i) => match i {
339                    InternalFunction::Len => CompareWithReference(Equal),
340                    InternalFunction::Upper => CompareWithReference(Equal),
341                    InternalFunction::Lower => CompareWithReference(Equal),
342                    InternalFunction::Trim => CompareWithReference(Equal),
343                    InternalFunction::Abs => CompareWithReference(Equal),
344                    InternalFunction::Sum => CompareWithReference(Equal),
345                    InternalFunction::Avg => CompareWithReference(Equal),
346                    InternalFunction::Min => CompareWithReference(Equal),
347                    InternalFunction::Max => CompareWithReference(Equal),
348                    InternalFunction::Rand => CompareWithReference(Equal),
349                    InternalFunction::Median => CompareWithReference(Equal),
350                    InternalFunction::Mode => CompareWithReference(Equal),
351                    InternalFunction::Floor => CompareWithReference(Equal),
352                    InternalFunction::Ceil => CompareWithReference(Equal),
353                    InternalFunction::Round => CompareWithReference(Equal),
354                    InternalFunction::String => CompareWithReference(Equal),
355                    InternalFunction::Number => CompareWithReference(Equal),
356                    InternalFunction::Bool => CompareWithReference(Equal),
357                    InternalFunction::Flatten => CompareWithReference(In),
358                    InternalFunction::Extract => CompareWithReference(In),
359                    InternalFunction::Contains => AsBoolean,
360                    InternalFunction::StartsWith => AsBoolean,
361                    InternalFunction::EndsWith => AsBoolean,
362                    InternalFunction::Matches => AsBoolean,
363                    InternalFunction::FuzzyMatch => CompareWithReference(Equal),
364                    InternalFunction::Split => CompareWithReference(In),
365                    InternalFunction::IsNumeric => AsBoolean,
366                    InternalFunction::Keys => CompareWithReference(In),
367                    InternalFunction::Values => CompareWithReference(In),
368                    InternalFunction::Type => CompareWithReference(Equal),
369                },
370                FunctionKind::Deprecated(d) => match d {
371                    DeprecatedFunction::Date => CompareWithReference(Equal),
372                    DeprecatedFunction::Time => CompareWithReference(Equal),
373                    DeprecatedFunction::Duration => CompareWithReference(Equal),
374                    DeprecatedFunction::Year => CompareWithReference(Equal),
375                    DeprecatedFunction::DayOfWeek => CompareWithReference(Equal),
376                    DeprecatedFunction::DayOfMonth => CompareWithReference(Equal),
377                    DeprecatedFunction::DayOfYear => CompareWithReference(Equal),
378                    DeprecatedFunction::WeekOfYear => CompareWithReference(Equal),
379                    DeprecatedFunction::MonthOfYear => CompareWithReference(Equal),
380                    DeprecatedFunction::MonthString => CompareWithReference(Equal),
381                    DeprecatedFunction::DateString => CompareWithReference(Equal),
382                    DeprecatedFunction::WeekdayString => CompareWithReference(Equal),
383                    DeprecatedFunction::StartOf => CompareWithReference(Equal),
384                    DeprecatedFunction::EndOf => CompareWithReference(Equal),
385                },
386                FunctionKind::Closure(c) => match c {
387                    ClosureFunction::All => AsBoolean,
388                    ClosureFunction::Some => AsBoolean,
389                    ClosureFunction::None => AsBoolean,
390                    ClosureFunction::One => AsBoolean,
391                    ClosureFunction::Filter => CompareWithReference(In),
392                    ClosureFunction::Map => CompareWithReference(In),
393                    ClosureFunction::FlatMap => CompareWithReference(In),
394                    ClosureFunction::Count => CompareWithReference(Equal),
395                },
396            },
397            Node::Error { .. } => AsBoolean,
398        }
399    }
400}