sari/
evaluator.rs

1use crate::ast::BinaryOp;
2use crate::ast::Node;
3use crate::error::Error;
4
5pub struct Evaluator<'a> {
6    ast: &'a Node,
7}
8
9impl Evaluator<'_> {
10    pub fn new(ast: &Node) -> Evaluator {
11        Evaluator { ast }
12    }
13
14    pub fn eval(&self) -> Result<i32, Error> {
15        fn eval_node(node: &Node) -> Result<i32, Error> {
16            match node {
17                Node::IntLit(value) => Ok(*value),
18                Node::BinaryExpr { op, left, right } => {
19                    let left = eval_node(left)?;
20                    let right = eval_node(right)?;
21
22                    match op {
23                        BinaryOp::Add => Ok(left.wrapping_add(right)),
24                        BinaryOp::Sub => Ok(left.wrapping_sub(right)),
25                        BinaryOp::Mul => Ok(left.wrapping_mul(right)),
26                        BinaryOp::Div => {
27                            if right != 0 {
28                                Ok(left.wrapping_div(right))
29                            } else {
30                                Err(Error::new("division by zero"))
31                            }
32                        }
33                    }
34                }
35            }
36        }
37
38        eval_node(self.ast)
39    }
40}
41
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::helpers::ast::*;

    /// Asserts that evaluating the given AST succeeds with the given value.
    macro_rules! assert_evals {
        ($ast:expr, $value:expr) => {{
            let root = $ast;

            assert_eq!(Evaluator::new(&root).eval(), Ok($value));
        }};
    }

    /// Asserts that evaluating the given AST fails with the given error message.
    macro_rules! assert_does_not_eval {
        ($ast:expr, $message:expr) => {{
            let root = $ast;

            assert_eq!(Evaluator::new(&root).eval(), Err(Error::new($message)));
        }};
    }

    #[test]
    fn evals_int_lit() {
        assert_evals!(int(1), 1);
    }

    #[test]
    fn evals_binary_expr_add() {
        assert_evals!(add(int(1), int(2)), 3);

        // Overflow wraps around in both directions.
        assert_evals!(add(int(i32::MAX), int(1)), i32::MIN);
        assert_evals!(add(int(i32::MIN), int(-1)), i32::MAX);
    }

    #[test]
    fn evals_binary_expr_sub() {
        assert_evals!(sub(int(3), int(2)), 1);

        // Overflow wraps around in both directions.
        assert_evals!(sub(int(i32::MAX), int(-1)), i32::MIN);
        assert_evals!(sub(int(i32::MIN), int(1)), i32::MAX);
    }

    #[test]
    fn evals_binary_expr_mul() {
        assert_evals!(mul(int(2), int(3)), 6);

        // Negating the minimum value overflows and wraps back to itself.
        assert_evals!(mul(int(i32::MIN), int(-1)), i32::MIN);
    }

    #[test]
    fn evals_binary_expr_div() {
        assert_evals!(div(int(6), int(3)), 2);

        // Dividing the minimum value by -1 overflows and wraps back to itself.
        assert_evals!(div(int(i32::MIN), int(-1)), i32::MIN);

        // A zero divisor is reported as an error.
        assert_does_not_eval!(div(int(1), int(0)), "division by zero");
    }

    #[test]
    fn evals_complex_expressions() {
        assert_evals!(mul(add(int(1), int(2)), add(int(3), int(4))), 21);
    }
}