rust_cel_parser/
ast.rs

1use crate::visitor::Visitor;
2use std::fmt;
3
// --- Literals ---
/// A literal is a primitive value that can be used in an expression.
///
/// Both the `Debug` and the `Display` implementation render the value in a
/// source-like form: strings are quoted, unsigned integers carry a `u`
/// suffix and byte sequences are written as `b"..."`.
#[derive(Clone, PartialEq)]
pub enum Literal {
    /// Signed 64-bit integer.
    Int(i64),
    /// Unsigned 64-bit integer; rendered with a trailing `u`.
    Uint(u64),
    /// 64-bit floating point number.
    Float(f64),
    /// UTF-8 text.
    String(String),
    /// Arbitrary byte sequence.
    Bytes(Vec<u8>),
    /// Boolean value.
    Bool(bool),
    /// The `null` value.
    Null,
}

/// Writes `value` so that a finite float always looks like a float.
///
/// Rust prints integral floats without a fractional part (`1.0` becomes
/// `1`), which would read back as an integer, so `.0` is appended in that
/// case. Non-finite values (`NaN`, `inf`, `-inf`) are written unchanged
/// because appending `.0` to them would only produce nonsense like `NaN.0`.
fn write_float_literal(f: &mut fmt::Formatter<'_>, value: f64) -> fmt::Result {
    let text = value.to_string();
    let already_float_like =
        !value.is_finite() || text.contains('.') || text.contains('e') || text.contains('E');
    if already_float_like {
        f.write_str(&text)
    } else {
        write!(f, "{}.0", text)
    }
}

/// Writes `bytes` as a quoted `b"..."` literal.
///
/// Backslash and double quote are backslash-escaped, printable ASCII
/// (`0x20..=0x7e`) is written verbatim and every other byte becomes a
/// two-digit lowercase `\xNN` escape.
fn write_bytes_literal(f: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result {
    f.write_str("b\"")?;
    for &byte in bytes {
        match byte {
            b'\\' => f.write_str("\\\\")?,
            b'"' => f.write_str("\\\"")?,
            0x20..=0x7e => write!(f, "{}", char::from(byte))?,
            _ => write!(f, "\\x{:02x}", byte)?,
        }
    }
    f.write_str("\"")
}

impl fmt::Debug for Literal {
    /// Renders the literal compactly; strings are escaped with
    /// [`str::escape_debug`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Literal::Int(i) => write!(f, "{}", i),
            Literal::Uint(u) => write!(f, "{}u", u),
            Literal::Float(fl) => write_float_literal(f, *fl),
            Literal::String(s) => write!(f, "\"{}\"", s.escape_debug()),
            Literal::Bytes(b) => write_bytes_literal(f, b),
            Literal::Bool(b) => write!(f, "{}", b),
            Literal::Null => f.write_str("null"),
        }
    }
}

impl fmt::Display for Literal {
    /// Renders the literal in source form.
    ///
    /// Strings escape only `"`, `\`, newline, carriage return and tab; all
    /// other characters are written as-is.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Literal::Int(i) => write!(f, "{}", i),
            Literal::Uint(u) => write!(f, "{}u", u),
            // Keep the fractional part so the text does not turn into an
            // integer literal when it is read back.
            Literal::Float(fl) => write_float_literal(f, *fl),
            Literal::Bool(b) => write!(f, "{}", b),
            Literal::Null => f.write_str("null"),
            Literal::String(s) => {
                f.write_str("\"")?;
                for c in s.chars() {
                    match c {
                        '"' => f.write_str("\\\"")?,
                        '\\' => f.write_str("\\\\")?,
                        '\n' => f.write_str("\\n")?,
                        '\r' => f.write_str("\\r")?,
                        '\t' => f.write_str("\\t")?,
                        _ => write!(f, "{}", c)?,
                    }
                }
                f.write_str("\"")
            }
            Literal::Bytes(b) => write_bytes_literal(f, b),
        }
    }
}
92
/// The kind of comprehension applied to a collection: each variant names one
/// way of iterating over the collection with a predicate.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ComprehensionOp {
    /// The `all` comprehension.
    All,
    /// The `exists` comprehension.
    Exists,
    /// The `exists_one` comprehension.
    ExistsOne,
    /// The `filter` comprehension.
    Filter,
}
101
// --- Operators ---
103
/// A prefix operator that is applied to exactly one operand.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum UnaryOperator {
    /// Written `!`.
    Not,
    /// Written `-`.
    Neg,
}

impl fmt::Display for UnaryOperator {
    /// Writes the operator's source symbol.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match *self {
            UnaryOperator::Not => "!",
            UnaryOperator::Neg => "-",
        };
        f.write_str(symbol)
    }
}
119
/// An infix operator that combines a left and a right operand.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BinaryOperator {
    /// `||`
    Or,
    /// `&&`
    And,
    /// `==`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `in`
    In,
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Rem,
}

impl BinaryOperator {
    /// Returns the binding strength of the operator; a larger value binds
    /// more tightly (multiplicative = 5 down to `||` = 1).
    fn precedence(&self) -> u8 {
        match *self {
            Self::Mul | Self::Div | Self::Rem => 5,
            Self::Add | Self::Sub => 4,
            Self::Eq | Self::Ne | Self::Lt | Self::Le | Self::Gt | Self::Ge | Self::In => 3,
            Self::And => 2,
            Self::Or => 1,
        }
    }
}

impl fmt::Display for BinaryOperator {
    /// Writes the operator's source symbol.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match *self {
            Self::Or => "||",
            Self::And => "&&",
            Self::Eq => "==",
            Self::Ne => "!=",
            Self::Lt => "<",
            Self::Le => "<=",
            Self::Gt => ">",
            Self::Ge => ">=",
            Self::In => "in",
            Self::Add => "+",
            Self::Sub => "-",
            Self::Mul => "*",
            Self::Div => "/",
            Self::Rem => "%",
        })
    }
}
179
/// Names a value type so that a type can itself appear in an expression.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CelType {
    /// Signed integer type.
    Int,
    /// Unsigned integer type.
    Uint,
    /// Double-precision floating point type.
    Double,
    /// Boolean type.
    Bool,
    /// String type.
    String,
    /// Byte-sequence type.
    Bytes,
    /// List type.
    List,
    /// Map type.
    Map,
    /// The type of `null`.
    NullType,
    /// The type of types.
    Type,
}
194
195
196#[derive(PartialEq, Debug, Clone)]
197pub enum Expr {
198    Literal(Literal),
199    Identifier(String),
200    UnaryOp {
201        op: UnaryOperator,
202        operand: Box<Expr>,
203    },
204    BinaryOp {
205        op: BinaryOperator,
206        left: Box<Expr>,
207        right: Box<Expr>,
208    },
209    Conditional {
210        cond: Box<Expr>,
211        true_branch: Box<Expr>,
212        false_branch: Box<Expr>,
213    },
214    List {
215        elements: Vec<Expr>,
216    },
217    FieldAccess {
218        base: Box<Expr>,
219        field: String,
220    },
221    Call {
222        target: Box<Expr>,
223        args: Vec<Expr>,
224    },
225    Index {
226        base: Box<Expr>,
227        index: Box<Expr>,
228    },
229    MapLiteral {
230        entries: Vec<(Expr, Expr)>,
231    },
232    MessageLiteral {
233        type_name: String,
234        fields: Vec<(String, Expr)>,
235    },
236    Has {
237        target: Box<Expr>,
238    },
239    Comprehension {
240        op: ComprehensionOp,
241        target: Box<Expr>,
242        iter_var: String,
243        predicate: Box<Expr>,
244    },
245    Map {
246        target: Box<Expr>,
247        iter_var: String,
248        filter: Option<Box<Expr>>,
249        transform: Box<Expr>,
250    },
251    Type(CelType),
252}
253
// --- Helper Methods ---
255impl Expr {
256    pub fn accept<'ast, V: Visitor<'ast>>(&'ast self, visitor: &mut V) {
257        visitor.visit_expr(self);
258    }
259    // ... (other helpers remain unchanged) ...
260    pub fn is_literal(&self) -> bool {
261        matches!(self, Expr::Literal(_))
262    }
263    pub fn as_literal(&self) -> Option<&Literal> {
264        match self {
265            Expr::Literal(lit) => Some(lit),
266            _ => None,
267        }
268    }
269    pub fn as_identifier(&self) -> Option<&str> {
270        match self {
271            Expr::Identifier(name) => Some(name),
272            _ => None,
273        }
274    }
275    pub fn as_binary_op(&self) -> Option<(BinaryOperator, &Expr, &Expr)> {
276        match self {
277            Expr::BinaryOp { op, left, right } => Some((*op, left, right)),
278            _ => None,
279        }
280    }
281    pub fn as_unary_op(&self) -> Option<(UnaryOperator, &Expr)> {
282        match self {
283            Expr::UnaryOp { op, operand } => Some((*op, operand)),
284            _ => None,
285        }
286    }
287    pub fn as_call(&self) -> Option<(&Expr, &[Expr])> {
288        match self {
289            Expr::Call { target, args } => Some((target, args)),
290            _ => None,
291        }
292    }
293}
294
295impl fmt::Display for Expr {
296    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
297        match self {
298            Expr::Literal(lit) => write!(f, "{}", lit),
299            Expr::Identifier(s) => write!(f, "{}", s),
300            Expr::UnaryOp { op, operand } => {
301                // Add parentheses if the operand is a binary operation
302                if let Expr::BinaryOp { .. } = **operand {
303                    write!(f, "{}({})", op, operand)
304                } else {
305                    write!(f, "{}{}", op, operand)
306                }
307            }
308            Expr::BinaryOp { op, left, right } => {
309                // Helper to format a child expression, adding parentheses if needed.
310                let format_child = |child: &Expr| -> String {
311                    if let Expr::BinaryOp { op: child_op, .. } = child {
312                        if child_op.precedence() < op.precedence() {
313                            return format!("({})", child);
314                        }
315                    }
316                    format!("{}", child)
317                };
318                write!(f, "{} {} {}", format_child(left), op, format_child(right))
319            }
320            Expr::Conditional {
321                cond,
322                true_branch,
323                false_branch,
324            } => {
325                write!(f, "{} ? {} : {}", cond, true_branch, false_branch)
326            }
327            Expr::List { elements } => {
328                write!(f, "[")?;
329                for (i, elem) in elements.iter().enumerate() {
330                    if i > 0 {
331                        write!(f, ", ")?;
332                    }
333                    write!(f, "{}", elem)?;
334                }
335                write!(f, "]")
336            }
337            Expr::MapLiteral { entries } => {
338                write!(f, "{{")?;
339                for (i, (key, value)) in entries.iter().enumerate() {
340                    if i > 0 {
341                        write!(f, ", ")?;
342                    }
343                    write!(f, "{}: {}", key, value)?;
344                }
345                write!(f, "}}")
346            }
347            Expr::FieldAccess { base, field } => write!(f, "{}.{}", base, field),
348            Expr::Index { base, index } => write!(f, "{}[{}]", base, index),
349            Expr::Call { target, args } => {
350                write!(f, "{}(", target)?;
351                for (i, arg) in args.iter().enumerate() {
352                    if i > 0 {
353                        write!(f, ", ")?;
354                    }
355                    write!(f, "{}", arg)?;
356                }
357                write!(f, ")")
358            }
359            Expr::MessageLiteral { type_name, fields } => {
360                write!(f, "{} {{", type_name)?;
361                for (i, (name, value)) in fields.iter().enumerate() {
362                    if i > 0 {
363                        write!(f, ", ")?;
364                    }
365                    write!(f, "{}: {}", name, value)?;
366                }
367                write!(f, "}}")
368            }
369            Expr::Has { target } => write!(f, "has({})", target),
370            // Display for comprehensions can be complex, this is a simplified version.
371            Expr::Comprehension {
372                op,
373                target,
374                iter_var,
375                predicate,
376            } => {
377                let op_str = match op {
378                    ComprehensionOp::All => "all",
379                    ComprehensionOp::Exists => "exists",
380                    ComprehensionOp::ExistsOne => "exists_one",
381                    ComprehensionOp::Filter => "filter",
382                };
383                write!(f, "{}.{}({}, {})", target, op_str, iter_var, predicate)
384            }
385            Expr::Map { .. } => write!(f, "{:?}", self), // Fallback for Map macro
386            Expr::Type(t) => write!(f, "{:?}", t),       // Fallback for Type
387        }
388    }
389}
390
// --- Conversions to Literal ---
392
393impl From<i32> for Literal {
394    fn from(val: i32) -> Self {
395        Literal::Int(val as i64)
396    }
397}
398
399impl From<u32> for Literal {
400    fn from(val: u32) -> Self {
401        Literal::Uint(val as u64)
402    }
403}
404
405impl From<i64> for Literal {
406    fn from(val: i64) -> Self {
407        Literal::Int(val)
408    }
409}
410
411impl From<u64> for Literal {
412    fn from(val: u64) -> Self {
413        Literal::Uint(val)
414    }
415}
416
417impl From<f64> for Literal {
418    fn from(val: f64) -> Self {
419        Literal::Float(val)
420    }
421}
422
423impl From<bool> for Literal {
424    fn from(val: bool) -> Self {
425        Literal::Bool(val)
426    }
427}
428
429impl From<&str> for Literal {
430    fn from(val: &str) -> Self {
431        Literal::String(val.to_string())
432    }
433}
434
435impl From<String> for Literal {
436    fn from(val: String) -> Self {
437        Literal::String(val)
438    }
439}
440
441impl From<Vec<u8>> for Literal {
442    fn from(val: Vec<u8>) -> Self {
443        Literal::Bytes(val)
444    }
445}
446
447impl From<&[u8]> for Literal {
448    fn from(val: &[u8]) -> Self {
449        Literal::Bytes(val.to_vec())
450    }
451}
452
// --- Conversions to Expr ---
454
455impl From<Literal> for Expr {
456    fn from(val: Literal) -> Self {
457        Expr::Literal(val)
458    }
459}
460
461impl From<i32> for Expr {
462    fn from(val: i32) -> Self {
463        Expr::Literal(val.into())
464    }
465}
466
467impl From<u32> for Expr {
468    fn from(val: u32) -> Self {
469        Expr::Literal(val.into())
470    }
471}
472
473impl From<i64> for Expr {
474    fn from(val: i64) -> Self {
475        Expr::Literal(Literal::Int(val))
476    }
477}
478
479impl From<u64> for Expr {
480    fn from(val: u64) -> Self {
481        Expr::Literal(Literal::Uint(val))
482    }
483}
484
485impl From<f64> for Expr {
486    fn from(val: f64) -> Self {
487        Expr::Literal(Literal::Float(val))
488    }
489}
490
491impl From<bool> for Expr {
492    fn from(val: bool) -> Self {
493        Expr::Literal(Literal::Bool(val))
494    }
495}
496
497impl From<&str> for Expr {
498    fn from(val: &str) -> Self {
499        Expr::Literal(Literal::String(val.to_string()))
500    }
501}
502
503impl From<String> for Expr {
504    fn from(val: String) -> Self {
505        Expr::Literal(Literal::String(val))
506    }
507}
508
509impl From<Vec<u8>> for Expr {
510    fn from(val: Vec<u8>) -> Self {
511        Expr::Literal(Literal::Bytes(val))
512    }
513}
514
515impl From<&[u8]> for Expr {
516    fn from(val: &[u8]) -> Self {
517        Expr::Literal(Literal::Bytes(val.to_vec()))
518    }
519}
520
// --- Tests ---
/// Unit tests for the `Expr` accessor helpers and the `From` conversions.
#[cfg(test)]
mod tests {

    use super::*; // Import everything from the parent module (ast.rs)

    #[test]
    fn test_is_literal() {
        let lit_expr = Expr::Literal(Literal::Int(42));
        let non_lit_expr = Expr::Identifier("x".to_string());
        assert!(lit_expr.is_literal());
        assert!(!non_lit_expr.is_literal());
    }

    #[test]
    fn test_as_literal() {
        let lit_expr = Expr::Literal(Literal::Bool(true));
        let non_lit_expr = Expr::Identifier("y".to_string());
        assert_eq!(lit_expr.as_literal(), Some(&Literal::Bool(true)));
        assert_eq!(non_lit_expr.as_literal(), None);
    }

    #[test]
    fn test_as_identifier() {
        let ident_expr = Expr::Identifier("my_var".to_string());
        let non_ident_expr = Expr::Literal(Literal::Int(1));
        assert_eq!(ident_expr.as_identifier(), Some("my_var"));
        assert_eq!(non_ident_expr.as_identifier(), None);
    }

    #[test]
    fn test_as_binary_op() {
        let left = Box::new(Expr::Literal(Literal::Int(1)));
        let right = Box::new(Expr::Literal(Literal::Int(2)));
        let bin_op_expr = Expr::BinaryOp {
            op: BinaryOperator::Add,
            left: left.clone(),
            right: right.clone(),
        };
        let non_bin_op_expr = Expr::Identifier("z".to_string());

        assert_eq!(
            bin_op_expr.as_binary_op(),
            Some((BinaryOperator::Add, left.as_ref(), right.as_ref()))
        );
        assert_eq!(non_bin_op_expr.as_binary_op(), None);
    }

    #[test]
    fn test_as_unary_op() {
        let operand = Box::new(Expr::Literal(Literal::Bool(true)));
        let unary_expr = Expr::UnaryOp {
            op: UnaryOperator::Not,
            operand: operand.clone(),
        };
        let non_unary_expr = Expr::Identifier("a".to_string());

        assert_eq!(
            unary_expr.as_unary_op(),
            Some((UnaryOperator::Not, operand.as_ref()))
        );
        assert_eq!(non_unary_expr.as_unary_op(), None);
    }

    #[test]
    fn test_as_call() {
        let arg1 = Expr::Literal(Literal::Int(10));
        let call_expr = Expr::Call {
            target: Box::new(Expr::Identifier("my_func".to_string())),
            args: vec![arg1.clone()],
        };
        let non_call_expr = Expr::Literal(Literal::Null);

        let (target, args) = call_expr.as_call().unwrap();
        assert_eq!(target, &Expr::Identifier("my_func".to_string()));
        assert_eq!(args, &[arg1]);

        assert_eq!(non_call_expr.as_call(), None);
    }

    // --- Tests for `From` implementations ---
    //
    // Numeric inputs carry explicit type suffixes: an unsuffixed integer
    // literal falls back to `i32`, which would silently exercise a different
    // `From` impl than the one the test is named after.
    mod from_impl_tests {
        use crate::ast::{Expr, Literal};

        #[test]
        fn test_from_i32() {
            let expr: Expr = 42i32.into();
            assert_eq!(expr, Expr::Literal(Literal::Int(42)));
            let lit: Literal = (-100i32).into();
            assert_eq!(lit, Literal::Int(-100));
        }

        #[test]
        fn test_from_u32() {
            let expr: Expr = 7u32.into();
            assert_eq!(expr, Expr::Literal(Literal::Uint(7)));
            let lit: Literal = u32::MAX.into();
            assert_eq!(lit, Literal::Uint(u64::from(u32::MAX)));
        }

        #[test]
        fn test_from_i64() {
            let expr: Expr = 42i64.into();
            assert_eq!(expr, Expr::Literal(Literal::Int(42)));
            let lit: Literal = (-100i64).into();
            assert_eq!(lit, Literal::Int(-100));
        }

        #[test]
        fn test_from_u64() {
            let expr: Expr = 123u64.into();
            assert_eq!(expr, Expr::Literal(Literal::Uint(123)));
            let lit: Literal = 0u64.into();
            assert_eq!(lit, Literal::Uint(0));
        }

        #[test]
        fn test_from_f64() {
            // 2.5 rather than 3.14: the latter trips clippy's approx_constant.
            let expr: Expr = 2.5f64.into();
            assert_eq!(expr, Expr::Literal(Literal::Float(2.5)));
            let lit: Literal = (-1.0e-5f64).into();
            assert_eq!(lit, Literal::Float(-0.00001));
        }

        #[test]
        fn test_from_bool() {
            let expr: Expr = true.into();
            assert_eq!(expr, Expr::Literal(Literal::Bool(true)));
            let lit: Literal = false.into();
            assert_eq!(lit, Literal::Bool(false));
        }

        #[test]
        fn test_from_str_slice() {
            let expr: Expr = "hello".into();
            assert_eq!(expr, Expr::Literal(Literal::String("hello".to_string())));
            let lit: Literal = "world".into();
            assert_eq!(lit, Literal::String("world".to_string()));
        }

        #[test]
        fn test_from_string() {
            let s = String::from("owned");
            let expr: Expr = s.clone().into(); // clone because into() consumes
            assert_eq!(expr, Expr::Literal(Literal::String("owned".to_string())));
            let lit: Literal = s.into();
            assert_eq!(lit, Literal::String("owned".to_string()));
        }

        #[test]
        fn test_from_u8_slice() {
            let bytes: &[u8] = &[0, 1, 255];
            let expr: Expr = bytes.into();
            assert_eq!(expr, Expr::Literal(Literal::Bytes(vec![0, 1, 255])));
            let lit: Literal = bytes.into();
            assert_eq!(lit, Literal::Bytes(vec![0, 1, 255]));
        }

        #[test]
        fn test_from_u8_vec() {
            let bytes_vec: Vec<u8> = vec![10, 20, 30];
            let expr: Expr = bytes_vec.clone().into(); // clone because into() consumes
            assert_eq!(expr, Expr::Literal(Literal::Bytes(vec![10, 20, 30])));
            let lit: Literal = bytes_vec.into();
            assert_eq!(lit, Literal::Bytes(vec![10, 20, 30]));
        }

        #[test]
        fn test_from_literal_to_expr() {
            let lit = Literal::Int(123);
            let expr: Expr = lit.clone().into();
            assert_eq!(expr, Expr::Literal(lit));
        }
    }
}
681
/// Round-trip tests: parse a source string and compare the rendered tree
/// against the expected text.
#[cfg(test)]
mod display_impl_tests {
    use crate::parser::parse_cel_program;

    /// Parses `input` and asserts that rendering the result yields `expected`.
    fn assert_display(input: &str, expected: &str) {
        let rendered = match parse_cel_program(input) {
            Ok(ast) => ast.to_string(),
            Err(e) => panic!("Failed to parse input '{}': {}", input, e),
        };
        assert_eq!(rendered, expected);
    }

    /// Runs `assert_display` over every `(input, expected)` pair in order.
    fn assert_display_all(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_display(input, expected);
        }
    }

    #[test]
    fn test_display_literals() {
        assert_display_all(&[
            ("123", "123"),
            ("456u", "456u"),
            ("true", "true"),
            ("null", "null"),
            ("1.23", "1.23"),
            ("\"hello world\"", "\"hello world\""),
            ("\"quotes \\\" here\"", "\"quotes \\\" here\""),
            ("b\"\\xFF\\x00\"", "b\"\\xff\\x00\""),
        ]);
    }

    #[test]
    fn test_display_simple_binary_op() {
        assert_display_all(&[("1 + 2", "1 + 2"), ("a && b", "a && b")]);
    }

    #[test]
    fn test_display_precedence_no_parens_needed() {
        // The tighter-binding operator sits on the right, so the output
        // needs no parentheses.
        assert_display_all(&[
            ("1 + 2 * 3", "1 + 2 * 3"),
            ("a || b && c", "a || b && c"),
        ]);
    }

    #[test]
    fn test_display_precedence_parens_needed() {
        // A looser-binding operator nested under a tighter one must keep
        // its parentheses.
        assert_display_all(&[
            ("(1 + 2) * 3", "(1 + 2) * 3"),
            ("a && (b || c)", "a && (b || c)"),
        ]);
    }

    #[test]
    fn test_display_left_associativity() {
        // (1 - 2) + 3 should print without parens.
        assert_display("1 - 2 + 3", "1 - 2 + 3");
    }

    #[test]
    fn test_display_complex_expression() {
        // Single-quoted input strings come back double-quoted.
        let source = "request.user.id == 'admin' && resource.acl in user.groups";
        let rendered = "request.user.id == \"admin\" && resource.acl in user.groups";
        assert_display(source, rendered);
    }

    #[test]
    fn test_display_call_and_access() {
        let source = "a.b(c)[d]";
        assert_display(source, source);
    }

    #[test]
    fn test_display_list_and_map() {
        assert_display_all(&[
            ("[1, true, \"three\"]", "[1, true, \"three\"]"),
            ("{'a': 1, 2: false}", "{\"a\": 1, 2: false}"),
        ]);
    }
}
747        assert_display("{'a': 1, 2: false}", "{\"a\": 1, 2: false}");
748    }
749}