rust_cel_parser/
visitor.rs

1// src/ast/visitor.rs
2
3//! The Visitor pattern for traversing the AST.
4//!
5//! This module provides a `Visitor` trait that can be implemented to perform
6//! analysis, transformations, or other operations on the AST. The traversal
7//! logic is handled for you.
8
9use super::{BinaryOperator, CelType, ComprehensionOp, Expr, Literal, UnaryOperator};
10
11/// A trait for visiting nodes in the AST in an immutable fashion.
12///
13/// The methods in this trait are called during a depth-first traversal of the AST.
14/// The default implementation of each `visit_*` method recursively visits the children
15/// of the node. To customize the behavior, implement this trait and override the
16/// methods for the nodes you are interested in.
17pub trait Visitor<'ast>
18where
19    Self: Sized,
20{
21    // --- Entry Point ---
22
23    /// Visits any `Expr` node. This is the main dispatch method.
24    fn visit_expr(&mut self, expr: &'ast Expr) {
25        walk_expr(self, expr); // Default behavior is to recurse
26    }
27
28    // --- Leaf Nodes ---
29
30    /// Visits a `Literal` node (e.g., `123`, `"hello"`).
31    fn visit_literal(&mut self, _literal: &'ast Literal) {
32        // Base case, no children to visit.
33    }
34
35    /// Visits an `Identifier` node (e.g., `request`).
36    fn visit_identifier(&mut self, _ident: &'ast str) {
37        // Base case.
38    }
39
40    /// Visits a `Type` node (e.g., `int`, `string`).
41    fn visit_type(&mut self, _cel_type: &'ast CelType) {
42        // Base case.
43    }
44
45    // --- Composite Nodes ---
46
47    /// Visits a `UnaryOp` node (e.g., `!true`).
48    fn visit_unary_op(&mut self, op: UnaryOperator, operand: &'ast Expr) {
49        walk_unary_op(self, op, operand);
50    }
51
52    /// Visits a `BinaryOp` node (e.g., `a + b`).
53    fn visit_binary_op(&mut self, op: BinaryOperator, left: &'ast Expr, right: &'ast Expr) {
54        walk_binary_op(self, op, left, right);
55    }
56
57    /// Visits a `Conditional` node (e.g., `cond ? true_br : false_br`).
58    fn visit_conditional(
59        &mut self,
60        cond: &'ast Expr,
61        true_branch: &'ast Expr,
62        false_branch: &'ast Expr,
63    ) {
64        walk_conditional(self, cond, true_branch, false_branch);
65    }
66
67    /// Visits a `List` literal (e.g., `[1, 2, 3]`).
68    fn visit_list(&mut self, elements: &'ast [Expr]) {
69        walk_list(self, elements);
70    }
71
72    /// Visits a `FieldAccess` node (e.g., `a.b`).
73    fn visit_field_access(&mut self, base: &'ast Expr, field: &'ast str) {
74        walk_field_access(self, base, field);
75    }
76
77    /// Visits a `Call` node (e.g., `func(arg1, arg2)`).
78    fn visit_call(&mut self, target: &'ast Expr, args: &'ast [Expr]) {
79        walk_call(self, target, args);
80    }
81
82    /// Visits an `Index` node (e.g., `list[0]`).
83    fn visit_index(&mut self, base: &'ast Expr, index: &'ast Expr) {
84        walk_index(self, base, index);
85    }
86
87    /// Visits a `MapLiteral` node (e.g., `{'key': 'value'}`).
88    fn visit_map_literal(&mut self, entries: &'ast [(Expr, Expr)]) {
89        walk_map_literal(self, entries);
90    }
91
92    /// Visits a `MessageLiteral` node (e.g., `Point{x: 1, y: 2}`).
93    fn visit_message_literal(&mut self, type_name: &'ast str, fields: &'ast [(String, Expr)]) {
94        walk_message_literal(self, type_name, fields);
95    }
96
97    // --- Macros ---
98
99    /// Visits a `Has` macro (e.g., `has(msg.field)`).
100    fn visit_has(&mut self, target: &'ast Expr) {
101        walk_has(self, target);
102    }
103
104    /// Visits a `Comprehension` macro (e.g., `list.all(i, i > 0)`).
105    fn visit_comprehension(
106        &mut self,
107        op: ComprehensionOp,
108        target: &'ast Expr,
109        iter_var: &'ast str,
110        predicate: &'ast Expr,
111    ) {
112        walk_comprehension(self, op, target, iter_var, predicate);
113    }
114
115    /// Visits a `Map` macro (e.g., `items.map(i, i.price)`).
116    fn visit_map(
117        &mut self,
118        target: &'ast Expr,
119        iter_var: &'ast str,
120        filter: Option<&'ast Expr>,
121        transform: &'ast Expr,
122    ) {
123        walk_map(self, target, iter_var, filter, transform);
124    }
125}
126
127// --- Walker Functions ---
128// These functions contain the actual traversal logic. The `Visitor` methods
129// call these by default to provide recursive traversal.
130
131pub fn walk_expr<'ast, V: Visitor<'ast>>(visitor: &mut V, expr: &'ast Expr) {
132    match expr {
133        Expr::Literal(lit) => visitor.visit_literal(lit),
134        Expr::Identifier(s) => visitor.visit_identifier(s),
135        Expr::UnaryOp { op, operand } => visitor.visit_unary_op(*op, operand),
136        Expr::BinaryOp { op, left, right } => visitor.visit_binary_op(*op, left, right),
137        Expr::Conditional {
138            cond,
139            true_branch,
140            false_branch,
141        } => visitor.visit_conditional(cond, true_branch, false_branch),
142        Expr::List { elements } => visitor.visit_list(elements),
143        Expr::FieldAccess { base, field } => visitor.visit_field_access(base, field),
144        Expr::Call { target, args } => visitor.visit_call(target, args),
145        Expr::Index { base, index } => visitor.visit_index(base, index),
146        Expr::MapLiteral { entries } => visitor.visit_map_literal(entries),
147        Expr::MessageLiteral { type_name, fields } => {
148            visitor.visit_message_literal(type_name, fields)
149        }
150        Expr::Has { target } => visitor.visit_has(target),
151        Expr::Comprehension {
152            op,
153            target,
154            iter_var,
155            predicate,
156        } => visitor.visit_comprehension(*op, target, iter_var, predicate),
157        Expr::Map {
158            target,
159            iter_var,
160            filter,
161            transform,
162        } => visitor.visit_map(target, iter_var, filter.as_deref(), transform),
163        Expr::Type(cel_type) => visitor.visit_type(cel_type),
164    }
165}
166
167pub fn walk_unary_op<'ast, V: Visitor<'ast>>(
168    visitor: &mut V,
169    _op: UnaryOperator,
170    operand: &'ast Expr,
171) {
172    visitor.visit_expr(operand);
173}
174
175pub fn walk_binary_op<'ast, V: Visitor<'ast>>(
176    visitor: &mut V,
177    _op: BinaryOperator,
178    left: &'ast Expr,
179    right: &'ast Expr,
180) {
181    visitor.visit_expr(left);
182    visitor.visit_expr(right);
183}
184
185pub fn walk_conditional<'ast, V: Visitor<'ast>>(
186    visitor: &mut V,
187    cond: &'ast Expr,
188    true_branch: &'ast Expr,
189    false_branch: &'ast Expr,
190) {
191    visitor.visit_expr(cond);
192    visitor.visit_expr(true_branch);
193    visitor.visit_expr(false_branch);
194}
195
196pub fn walk_list<'ast, V: Visitor<'ast>>(visitor: &mut V, elements: &'ast [Expr]) {
197    for element in elements {
198        visitor.visit_expr(element);
199    }
200}
201
202pub fn walk_field_access<'ast, V: Visitor<'ast>>(
203    visitor: &mut V,
204    base: &'ast Expr,
205    _field: &'ast str,
206) {
207    visitor.visit_expr(base);
208}
209
210pub fn walk_call<'ast, V: Visitor<'ast>>(visitor: &mut V, target: &'ast Expr, args: &'ast [Expr]) {
211    visitor.visit_expr(target);
212    for arg in args {
213        visitor.visit_expr(arg);
214    }
215}
216
217pub fn walk_index<'ast, V: Visitor<'ast>>(visitor: &mut V, base: &'ast Expr, index: &'ast Expr) {
218    visitor.visit_expr(base);
219    visitor.visit_expr(index);
220}
221
222pub fn walk_map_literal<'ast, V: Visitor<'ast>>(visitor: &mut V, entries: &'ast [(Expr, Expr)]) {
223    for (key, value) in entries {
224        visitor.visit_expr(key);
225        visitor.visit_expr(value);
226    }
227}
228
229pub fn walk_message_literal<'ast, V: Visitor<'ast>>(
230    visitor: &mut V,
231    _type_name: &'ast str,
232    fields: &'ast [(String, Expr)],
233) {
234    for (_name, value) in fields {
235        visitor.visit_expr(value);
236    }
237}
238
239pub fn walk_has<'ast, V: Visitor<'ast>>(visitor: &mut V, target: &'ast Expr) {
240    visitor.visit_expr(target);
241}
242
243pub fn walk_comprehension<'ast, V: Visitor<'ast>>(
244    visitor: &mut V,
245    _op: ComprehensionOp,
246    target: &'ast Expr,
247    _iter_var: &'ast str,
248    predicate: &'ast Expr,
249) {
250    visitor.visit_expr(target);
251    visitor.visit_expr(predicate);
252}
253
254pub fn walk_map<'ast, V: Visitor<'ast>>(
255    visitor: &mut V,
256    target: &'ast Expr,
257    _iter_var: &'ast str,
258    filter: Option<&'ast Expr>,
259    transform: &'ast Expr,
260) {
261    visitor.visit_expr(target);
262    if let Some(filter_expr) = filter {
263        visitor.visit_expr(filter_expr);
264    }
265    visitor.visit_expr(transform);
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_cel_program;
    use std::collections::HashSet;

    /// Gathers the distinct identifier names that appear in an AST.
    struct IdentifierCollector<'a> {
        names: HashSet<&'a str>,
    }

    impl<'ast> Visitor<'ast> for IdentifierCollector<'ast> {
        // Only the identifier hook is overridden; the default hooks take care
        // of descending through every other node kind.
        fn visit_identifier(&mut self, ident: &'ast str) {
            self.names.insert(ident);
        }
    }

    #[test]
    fn test_identifier_collector() {
        let ast = parse_cel_program("request.auth.user + params.id + request.time")
            .expect("test expression should parse");

        let mut collector = IdentifierCollector {
            names: HashSet::new(),
        };
        collector.visit_expr(&ast);

        // Only the two base identifiers are expected: the set deduplicates
        // `request`, and the default field-access walker never forwards field
        // names to `visit_identifier`.
        let expected: HashSet<&str> = ["request", "params"].iter().copied().collect();
        assert_eq!(collector.names, expected);
    }

    #[test]
    fn test_find_specific_function_calls() {
        let ast = parse_cel_program("size(list_a) + other_func(size(list_b))")
            .expect("test expression should parse");

        /// Records the first argument of every call whose target is the
        /// identifier `size`.
        struct SizeCallArgumentCollector<'a> {
            size_args: Vec<&'a Expr>,
        }

        impl<'ast> Visitor<'ast> for SizeCallArgumentCollector<'ast> {
            fn visit_call(&mut self, target: &'ast Expr, args: &'ast [Expr]) {
                // A `size` call with no arguments has nothing to record.
                if let (Some("size"), Some(first_arg)) = (target.as_identifier(), args.first()) {
                    self.size_args.push(first_arg);
                }

                // Overriding `visit_call` replaces the default descent, so it
                // has to be invoked explicitly; otherwise the `size` call
                // nested inside `other_func(...)` would never be reached.
                walk_call(self, target, args);
            }
        }

        let mut collector = SizeCallArgumentCollector { size_args: vec![] };
        // Use the public `accept` method to run the visitor.
        ast.accept(&mut collector);

        assert_eq!(collector.size_args.len(), 2);
        // The default walkers go depth-first, left operand before right, so
        // `size(list_a)` on the left of `+` is recorded before the
        // `size(list_b)` nested inside `other_func`.
        assert_eq!(collector.size_args[0].as_identifier(), Some("list_a"));
        assert_eq!(collector.size_args[1].as_identifier(), Some("list_b"));
    }
}