1use super::{BinaryOperator, CelType, ComprehensionOp, Expr, Literal, UnaryOperator};
10
11pub trait Visitor<'ast>
18where
19 Self: Sized,
20{
21 fn visit_expr(&mut self, expr: &'ast Expr) {
25 walk_expr(self, expr); }
27
28 fn visit_literal(&mut self, _literal: &'ast Literal) {
32 }
34
35 fn visit_identifier(&mut self, _ident: &'ast str) {
37 }
39
40 fn visit_type(&mut self, _cel_type: &'ast CelType) {
42 }
44
45 fn visit_unary_op(&mut self, op: UnaryOperator, operand: &'ast Expr) {
49 walk_unary_op(self, op, operand);
50 }
51
52 fn visit_binary_op(&mut self, op: BinaryOperator, left: &'ast Expr, right: &'ast Expr) {
54 walk_binary_op(self, op, left, right);
55 }
56
57 fn visit_conditional(
59 &mut self,
60 cond: &'ast Expr,
61 true_branch: &'ast Expr,
62 false_branch: &'ast Expr,
63 ) {
64 walk_conditional(self, cond, true_branch, false_branch);
65 }
66
67 fn visit_list(&mut self, elements: &'ast [Expr]) {
69 walk_list(self, elements);
70 }
71
72 fn visit_field_access(&mut self, base: &'ast Expr, field: &'ast str) {
74 walk_field_access(self, base, field);
75 }
76
77 fn visit_call(&mut self, target: &'ast Expr, args: &'ast [Expr]) {
79 walk_call(self, target, args);
80 }
81
82 fn visit_index(&mut self, base: &'ast Expr, index: &'ast Expr) {
84 walk_index(self, base, index);
85 }
86
87 fn visit_map_literal(&mut self, entries: &'ast [(Expr, Expr)]) {
89 walk_map_literal(self, entries);
90 }
91
92 fn visit_message_literal(&mut self, type_name: &'ast str, fields: &'ast [(String, Expr)]) {
94 walk_message_literal(self, type_name, fields);
95 }
96
97 fn visit_has(&mut self, target: &'ast Expr) {
101 walk_has(self, target);
102 }
103
104 fn visit_comprehension(
106 &mut self,
107 op: ComprehensionOp,
108 target: &'ast Expr,
109 iter_var: &'ast str,
110 predicate: &'ast Expr,
111 ) {
112 walk_comprehension(self, op, target, iter_var, predicate);
113 }
114
115 fn visit_map(
117 &mut self,
118 target: &'ast Expr,
119 iter_var: &'ast str,
120 filter: Option<&'ast Expr>,
121 transform: &'ast Expr,
122 ) {
123 walk_map(self, target, iter_var, filter, transform);
124 }
125}
126
127pub fn walk_expr<'ast, V: Visitor<'ast>>(visitor: &mut V, expr: &'ast Expr) {
132 match expr {
133 Expr::Literal(lit) => visitor.visit_literal(lit),
134 Expr::Identifier(s) => visitor.visit_identifier(s),
135 Expr::UnaryOp { op, operand } => visitor.visit_unary_op(*op, operand),
136 Expr::BinaryOp { op, left, right } => visitor.visit_binary_op(*op, left, right),
137 Expr::Conditional {
138 cond,
139 true_branch,
140 false_branch,
141 } => visitor.visit_conditional(cond, true_branch, false_branch),
142 Expr::List { elements } => visitor.visit_list(elements),
143 Expr::FieldAccess { base, field } => visitor.visit_field_access(base, field),
144 Expr::Call { target, args } => visitor.visit_call(target, args),
145 Expr::Index { base, index } => visitor.visit_index(base, index),
146 Expr::MapLiteral { entries } => visitor.visit_map_literal(entries),
147 Expr::MessageLiteral { type_name, fields } => {
148 visitor.visit_message_literal(type_name, fields)
149 }
150 Expr::Has { target } => visitor.visit_has(target),
151 Expr::Comprehension {
152 op,
153 target,
154 iter_var,
155 predicate,
156 } => visitor.visit_comprehension(*op, target, iter_var, predicate),
157 Expr::Map {
158 target,
159 iter_var,
160 filter,
161 transform,
162 } => visitor.visit_map(target, iter_var, filter.as_deref(), transform),
163 Expr::Type(cel_type) => visitor.visit_type(cel_type),
164 }
165}
166
167pub fn walk_unary_op<'ast, V: Visitor<'ast>>(
168 visitor: &mut V,
169 _op: UnaryOperator,
170 operand: &'ast Expr,
171) {
172 visitor.visit_expr(operand);
173}
174
175pub fn walk_binary_op<'ast, V: Visitor<'ast>>(
176 visitor: &mut V,
177 _op: BinaryOperator,
178 left: &'ast Expr,
179 right: &'ast Expr,
180) {
181 visitor.visit_expr(left);
182 visitor.visit_expr(right);
183}
184
185pub fn walk_conditional<'ast, V: Visitor<'ast>>(
186 visitor: &mut V,
187 cond: &'ast Expr,
188 true_branch: &'ast Expr,
189 false_branch: &'ast Expr,
190) {
191 visitor.visit_expr(cond);
192 visitor.visit_expr(true_branch);
193 visitor.visit_expr(false_branch);
194}
195
196pub fn walk_list<'ast, V: Visitor<'ast>>(visitor: &mut V, elements: &'ast [Expr]) {
197 for element in elements {
198 visitor.visit_expr(element);
199 }
200}
201
202pub fn walk_field_access<'ast, V: Visitor<'ast>>(
203 visitor: &mut V,
204 base: &'ast Expr,
205 _field: &'ast str,
206) {
207 visitor.visit_expr(base);
208}
209
210pub fn walk_call<'ast, V: Visitor<'ast>>(visitor: &mut V, target: &'ast Expr, args: &'ast [Expr]) {
211 visitor.visit_expr(target);
212 for arg in args {
213 visitor.visit_expr(arg);
214 }
215}
216
217pub fn walk_index<'ast, V: Visitor<'ast>>(visitor: &mut V, base: &'ast Expr, index: &'ast Expr) {
218 visitor.visit_expr(base);
219 visitor.visit_expr(index);
220}
221
222pub fn walk_map_literal<'ast, V: Visitor<'ast>>(visitor: &mut V, entries: &'ast [(Expr, Expr)]) {
223 for (key, value) in entries {
224 visitor.visit_expr(key);
225 visitor.visit_expr(value);
226 }
227}
228
229pub fn walk_message_literal<'ast, V: Visitor<'ast>>(
230 visitor: &mut V,
231 _type_name: &'ast str,
232 fields: &'ast [(String, Expr)],
233) {
234 for (_name, value) in fields {
235 visitor.visit_expr(value);
236 }
237}
238
239pub fn walk_has<'ast, V: Visitor<'ast>>(visitor: &mut V, target: &'ast Expr) {
240 visitor.visit_expr(target);
241}
242
243pub fn walk_comprehension<'ast, V: Visitor<'ast>>(
244 visitor: &mut V,
245 _op: ComprehensionOp,
246 target: &'ast Expr,
247 _iter_var: &'ast str,
248 predicate: &'ast Expr,
249) {
250 visitor.visit_expr(target);
251 visitor.visit_expr(predicate);
252}
253
254pub fn walk_map<'ast, V: Visitor<'ast>>(
255 visitor: &mut V,
256 target: &'ast Expr,
257 _iter_var: &'ast str,
258 filter: Option<&'ast Expr>,
259 transform: &'ast Expr,
260) {
261 visitor.visit_expr(target);
262 if let Some(filter_expr) = filter {
263 visitor.visit_expr(filter_expr);
264 }
265 visitor.visit_expr(transform);
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_cel_program;
    use std::collections::HashSet;

    /// Collects every distinct identifier name reached by the default traversal.
    struct IdentifierCollector<'a> {
        // Borrowed straight from the AST, so no allocation per identifier.
        names: HashSet<&'a str>,
    }

    impl<'ast> Visitor<'ast> for IdentifierCollector<'ast> {
        fn visit_identifier(&mut self, ident: &'ast str) {
            self.names.insert(ident);
        }
    }

    #[test]
    fn test_identifier_collector() {
        let ast = parse_cel_program("request.auth.user + params.id + request.time").unwrap();

        let mut collector = IdentifierCollector {
            names: HashSet::new(),
        };
        collector.visit_expr(&ast);

        // Field names (`auth`, `user`, `id`, `time`) must not show up:
        // `walk_field_access` visits only the base expression.
        let expected: HashSet<&str> = ["request", "params"].iter().copied().collect();
        assert_eq!(collector.names, expected);
    }

    #[test]
    fn test_find_specific_function_calls() {
        let ast = parse_cel_program("size(list_a) + other_func(size(list_b))").unwrap();

        /// Records the first argument of every call whose target is the identifier `size`.
        struct SizeCallArgumentCollector<'a> {
            size_args: Vec<&'a Expr>,
        }

        impl<'ast> Visitor<'ast> for SizeCallArgumentCollector<'ast> {
            fn visit_call(&mut self, target: &'ast Expr, args: &'ast [Expr]) {
                if let Some("size") = target.as_identifier() {
                    // `first()` avoids the panic path of `args[0]`.
                    if let Some(first) = args.first() {
                        self.size_args.push(first);
                    }
                }

                // Keep walking so calls nested in arguments
                // (`other_func(size(list_b))`) are also found.
                walk_call(self, target, args);
            }
        }

        let mut collector = SizeCallArgumentCollector { size_args: vec![] };
        ast.accept(&mut collector);

        assert_eq!(collector.size_args.len(), 2);
        assert_eq!(collector.size_args[0].as_identifier(), Some("list_a"));
        assert_eq!(collector.size_args[1].as_identifier(), Some("list_b"));
    }
}
335}