thin_shunting/
rpneval.rs

1use crate::parser::RPNExpr;
2use crate::tokenizer::MathToken;
3use std::collections::HashMap;
4
// Gate for `eval_fn` arms: yields `$on_ok` when the argument-count check
// `$count_ok` holds, and a "Wrong number of arguments" error otherwise.
// `$on_ok` is only evaluated when the check passes, so it may index into the
// argument list safely.
macro_rules! nargs {
    ($count_ok:expr, $on_ok:expr) => {
        match $count_ok {
            true => $on_ok,
            false => Err("Wrong number of arguments".to_string()),
        }
    };
}
15
/// Evaluation context for RPN expressions: a table mapping variable names to
/// their numeric values.
///
/// The inner map is public, so callers may inspect or edit bindings directly.
/// Two contexts compare equal when they hold the same name/value bindings
/// (note that a binding whose value is NaN never compares equal).
#[derive(Debug, Clone, PartialEq)]
pub struct MathContext(pub HashMap<String, f64>);
18
19impl MathContext {
20    pub fn new() -> MathContext {
21        use std::f64::consts;
22        let mut cx = HashMap::new();
23        cx.insert("pi".to_string(), consts::PI);
24        cx.insert("e".to_string(), consts::E);
25        MathContext(cx)
26    }
27
28    pub fn setvar(&mut self, var: &str, val: f64) {
29        self.0.insert(var.to_string(), val);
30    }
31
32    pub fn eval(&self, rpn: &RPNExpr) -> Result<f64, String> {
33        let mut operands = Vec::new();
34
35        for token in rpn.0.iter() {
36            match *token {
37                MathToken::Number(num) => operands.push(num),
38                MathToken::Variable(ref var) => match self.0.get(var) {
39                    Some(value) => operands.push(*value),
40                    None => return Err(format!("Unknown Variable: {}", var.to_string())),
41                },
42                MathToken::BOp(ref op) => {
43                    let r = operands
44                        .pop()
45                        .ok_or_else(|| "Wrong number of arguments".to_string())?;
46                    let l = operands
47                        .pop()
48                        .ok_or_else(|| "Wrong number of arguments".to_string())?;
49                    match &op[..] {
50                        "+" => operands.push(l + r),
51                        "-" => operands.push(l - r),
52                        "*" => operands.push(l * r),
53                        "/" => operands.push(l / r),
54                        "%" => operands.push(l % r),
55                        "^" => operands.push(l.powf(r)),
56                        _ => return Err(format!("Bad Token: {}", op.clone())),
57                    }
58                }
59                MathToken::UOp(ref op) => {
60                    let o = operands
61                        .pop()
62                        .ok_or_else(|| "Wrong number of arguments".to_string())?;
63                    match &op[..] {
64                        "-" => operands.push(-o),
65                        "!" => operands.push(Self::eval_fn("tgamma", vec![o + 1.0])?),
66                        _ => return Err(format!("Bad Token: {}", op.clone())),
67                    }
68                }
69                MathToken::Function(ref fname, arity) => {
70                    if arity > operands.len() {
71                        return Err("Wrong number of arguments".to_string());
72                    }
73                    let cut = operands.len() - arity;
74                    let args = operands.split_off(cut);
75                    operands.push(Self::eval_fn(fname, args)?)
76                }
77                _ => return Err(format!("Bad Token: {:?}", *token)),
78            }
79        }
80        operands
81            .pop()
82            .ok_or_else(|| "Wrong number of arguments".to_string())
83    }
84
85    fn eval_fn(fname: &str, args: Vec<f64>) -> Result<f64, String> {
86        match fname {
87            "sin" => nargs!(args.len() == 1, Ok(args[0].sin())),
88            "cos" => nargs!(args.len() == 1, Ok(args[0].cos())),
89            "atan2" => nargs!(args.len() == 2, Ok(args[0].atan2(args[1]))),
90            "max" => nargs!(
91                !args.is_empty(),
92                Ok(args[1..].iter().fold(args[0], |a, &item| a.max(item)))
93            ),
94            "min" => nargs!(
95                !args.is_empty(),
96                Ok(args[1..].iter().fold(args[0], |a, &item| a.min(item)))
97            ),
98            "abs" => nargs!(args.len() == 1, Ok(f64::abs(args[0]))),
99            "rand" => nargs!(args.len() == 1, Ok(args[0] * rand::random::<f64>())),
100            // Order is important
101            "nMPr" => nargs!(args.len() == 2, Ok(args[0].powf(args[1]))),
102            // Unknown function
103            _ => Err(format!("Unknown function: {}", fname)),
104        }
105    }
106}
107
108impl Default for MathContext {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::MathContext;
    use crate::parser::ShuntingParser;

    // Asserts that two floats agree to within 1e-10, printing both values
    // when they do not.
    macro_rules! fuzzy_eq {
        ($lhs:expr, $rhs:expr) => {{
            let (lhs, rhs): (f64, f64) = ($lhs, $rhs);
            assert!(
                (lhs - rhs).abs() < 1.0e-10,
                "{} is not within 1e-10 of {}",
                lhs,
                rhs
            )
        }};
    }

    // Operator precedence, right-associative `^` and unary minus.
    #[test]
    fn test_eval1() {
        let expr = ShuntingParser::parse_str("3+4*2/-(1-5)^2^3").unwrap();
        fuzzy_eq!(MathContext::new().eval(&expr).unwrap(), 2.99987792969);
    }

    // Scientific notation, built-in constants, `%` and function calls.
    #[test]
    fn test_eval2() {
        let expr = ShuntingParser::parse_str("3.4e-2 * sin(pi/3)/(541 % -4) * max(2, -7)").unwrap();
        fuzzy_eq!(MathContext::new().eval(&expr).unwrap(), 0.058889727457341);
    }

    // Fractional exponent applied to a parenthesised quotient.
    #[test]
    fn test_eval3() {
        let expr = ShuntingParser::parse_str("(-(1-9^2) / (1 + 6^2))^0.5").unwrap();
        fuzzy_eq!(
            MathContext::new().eval(&expr).unwrap(),
            1.470429244187615496759
        );
    }

    // sin^2 + cos^2 == 1.
    #[test]
    fn test_eval4() {
        let expr = ShuntingParser::parse_str("sin(0.345)^2 + cos(0.345)^2").unwrap();
        fuzzy_eq!(MathContext::new().eval(&expr).unwrap(), 1.0);
    }

    // The constant `e` used as a function argument.
    #[test]
    fn test_eval5() {
        let expr = ShuntingParser::parse_str("sin(e)/cos(e)").unwrap();
        fuzzy_eq!(MathContext::new().eval(&expr).unwrap(), -0.4505495340698074);
    }

    // Parentheses override precedence.
    #[test]
    fn test_eval6() {
        let expr = ShuntingParser::parse_str("(3+4)*3").unwrap();
        fuzzy_eq!(MathContext::new().eval(&expr).unwrap(), 21.0);
    }

    // Variables are resolved through the context; an unbound name is an error.
    #[test]
    fn test_eval7() {
        let expr = ShuntingParser::parse_str("(x+4)*x").unwrap();

        let mut cx = MathContext::new();
        cx.setvar("x", 3.0);
        fuzzy_eq!(cx.eval(&expr).unwrap(), 21.0);

        assert_eq!(
            MathContext::new().eval(&expr),
            Err("Unknown Variable: x".to_string())
        );
    }

    // Interaction of unary minus with exponentiation.
    #[test]
    fn test_eval8() {
        let expr = ShuntingParser::parse_str("2^3").unwrap();
        fuzzy_eq!(MathContext::new().eval(&expr).unwrap(), 8.0);
        let expr = ShuntingParser::parse_str("2^-3").unwrap();
        fuzzy_eq!(MathContext::new().eval(&expr).unwrap(), 0.125);
        let expr = ShuntingParser::parse_str("-2^3").unwrap();
        fuzzy_eq!(MathContext::new().eval(&expr).unwrap(), -8.0);
        let expr = ShuntingParser::parse_str("-2^-3").unwrap();
        fuzzy_eq!(MathContext::new().eval(&expr).unwrap(), -0.125);
    }
}