thin_shunting/
parser.rs

1use crate::tokenizer::{MathToken, MathTokenizer};
2use std::cmp::Ordering;
3
/// Associativity of an operator, used to break ties between operators of
/// equal precedence while converting infix to postfix.
///
/// `Copy` because it is a fieldless tag that is compared and passed around
/// by value.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Assoc {
    /// Groups left to right: `a - b - c` is `(a - b) - c`.
    Left,
    /// Groups right to left: `a ^ b ^ c` is `a ^ (b ^ c)`.
    Right,
    /// No defined grouping; two such operators of equal precedence meeting
    /// is treated as a parse error by the parser.
    None,
}
10
11pub fn precedence(mt: &MathToken) -> (usize, Assoc) {
12    // You can play with the relation between exponentiation an unary - by
13    // a. switching order in which the lexer tokenizes, if it tries
14    // operators first then '-' will never be the negative part of number,
15    // else if numbers are tried before operators, - can only be unary
16    // for non-numeric tokens (eg: -(3)).
17    // b. changing the precedence of '-' respect to '^'
18    // If '-' has lower precedence then 2^-3 will fail to evaluate if the
19    // '-' isn't part of the number because ^ will only find 1 operator
20    match *mt {
21        MathToken::OParen => (1, Assoc::Left), // keep at bottom
22        MathToken::BOp(ref o) if o == "+" => (2, Assoc::Left),
23        MathToken::BOp(ref o) if o == "-" => (2, Assoc::Left),
24        MathToken::BOp(ref o) if o == "*" => (3, Assoc::Left),
25        MathToken::BOp(ref o) if o == "/" => (3, Assoc::Left),
26        MathToken::BOp(ref o) if o == "%" => (3, Assoc::Left),
27        MathToken::UOp(ref o) if o == "-" => (5, Assoc::Right), // unary minus
28        MathToken::BOp(ref o) if o == "^" => (5, Assoc::Right),
29        MathToken::UOp(ref o) if o == "!" => (6, Assoc::Left), // factorial
30        MathToken::Function(_, _) => (7, Assoc::Left),
31        _ => (99, Assoc::None),
32    }
33}
34
35#[derive(PartialEq, Debug, Clone)]
36pub struct RPNExpr(pub Vec<MathToken>);
37
/// Stateless entry point for converting infix math expressions to RPN.
///
/// All functionality lives in associated functions; the derives let callers
/// store, copy, and debug-print the marker type like any other value.
#[derive(Debug, Clone, Copy, Default)]
pub struct ShuntingParser;
39
40impl ShuntingParser {
41    pub fn parse_str(expr: &str) -> Result<RPNExpr, String> {
42        Self::parse(&mut MathTokenizer::new(expr.chars()))
43    }
44
45    pub fn parse(lex: &mut impl Iterator<Item = MathToken>) -> Result<RPNExpr, String> {
46        let mut out = Vec::new();
47        let mut stack = Vec::new();
48        let mut arity = Vec::<usize>::new();
49
50        for token in lex {
51            match token {
52                MathToken::Number(_) => out.push(token),
53                MathToken::Variable(_) => out.push(token),
54                MathToken::OParen => stack.push(token),
55                MathToken::Function(_, _) => {
56                    stack.push(token);
57                    arity.push(1);
58                }
59                MathToken::Comma | MathToken::CParen => {
60                    while !stack.is_empty() && stack.last() != Some(&MathToken::OParen) {
61                        out.push(stack.pop().unwrap());
62                    }
63                    if stack.is_empty() {
64                        return Err("Missing Opening Paren".to_string());
65                    }
66                    // end of grouping: check if this is a function call
67                    if token == MathToken::CParen {
68                        stack.pop(); // peel matching OParen
69                        match stack.pop() {
70                            Some(MathToken::Function(func, _)) => {
71                                out.push(MathToken::Function(func, arity.pop().unwrap()))
72                            }
73                            Some(other) => stack.push(other),
74                            None => (),
75                        }
76                    } else if let Some(a) = arity.last_mut() {
77                        *a += 1;
78                    } // Comma
79                }
80                MathToken::UOp(_) | MathToken::BOp(_) => {
81                    let (prec_rhs, assoc_rhs) = precedence(&token);
82                    while !stack.is_empty() {
83                        let (prec_lhs, _) = precedence(stack.last().unwrap());
84                        match prec_lhs.cmp(&prec_rhs) {
85                            Ordering::Greater => out.push(stack.pop().unwrap()),
86                            Ordering::Less => break,
87                            Ordering::Equal => match assoc_rhs {
88                                Assoc::Left => out.push(stack.pop().unwrap()),
89                                Assoc::None => return Err("No Associativity".to_string()),
90                                Assoc::Right => break,
91                            },
92                        }
93                    }
94                    stack.push(token);
95                }
96                MathToken::Unknown(lexeme) => return Err(format!("Bad token: {}", lexeme)),
97            }
98        }
99        while let Some(top) = stack.pop() {
100            match top {
101                MathToken::OParen => return Err("Missing Closing Paren".to_string()),
102                token => out.push(token),
103            }
104        }
105        Ok(RPNExpr(out))
106    }
107}
108
#[cfg(test)]
mod tests {
    use crate::parser::{RPNExpr, ShuntingParser};
    use crate::tokenizer::MathToken;

    #[test]
    fn test_parse1() {
        let rpn = ShuntingParser::parse_str("3+4*2/-(1-5)^2^3").unwrap();
        let expect = vec![
            MathToken::Number(3.0),
            MathToken::Number(4.0),
            MathToken::Number(2.0),
            MathToken::BOp("*".to_string()),
            MathToken::Number(1.0),
            MathToken::Number(5.0),
            MathToken::BOp("-".to_string()),
            MathToken::Number(2.0),
            MathToken::Number(3.0),
            MathToken::BOp("^".to_string()),
            MathToken::BOp("^".to_string()),
            MathToken::UOp("-".to_string()),
            MathToken::BOp("/".to_string()),
            MathToken::BOp("+".to_string()),
        ];
        assert_eq!(rpn, RPNExpr(expect));
    }

    #[test]
    fn test_parse2() {
        let rpn = ShuntingParser::parse_str("3.4e-2 * sin(x)/(7! % -4) * max(2, x)").unwrap();
        let expect = vec![
            MathToken::Number(3.4e-2),
            MathToken::Variable("x".to_string()),
            MathToken::Function("sin".to_string(), 1),
            MathToken::BOp("*".to_string()),
            MathToken::Number(7.0),
            MathToken::UOp("!".to_string()),
            MathToken::Number(4.0),
            MathToken::UOp("-".to_string()),
            MathToken::BOp("%".to_string()),
            MathToken::BOp("/".to_string()),
            MathToken::Number(2.0),
            MathToken::Variable("x".to_string()),
            MathToken::Function("max".to_string(), 2),
            MathToken::BOp("*".to_string()),
        ];
        assert_eq!(rpn, RPNExpr(expect));
    }

    #[test]
    fn test_parse3() {
        let rpn = ShuntingParser::parse_str("sqrt(-(1-x^2) / (1 + x^2))").unwrap();
        let expect = vec![
            MathToken::Number(1.0),
            MathToken::Variable("x".to_string()),
            MathToken::Number(2.0),
            MathToken::BOp("^".to_string()),
            MathToken::BOp("-".to_string()),
            MathToken::UOp("-".to_string()),
            MathToken::Number(1.0),
            MathToken::Variable("x".to_string()),
            MathToken::Number(2.0),
            MathToken::BOp("^".to_string()),
            MathToken::BOp("+".to_string()),
            MathToken::BOp("/".to_string()),
            MathToken::Function("sqrt".to_string(), 1),
        ];
        assert_eq!(rpn, RPNExpr(expect));
    }

    #[test]
    fn bad_parse() {
        let rpn = ShuntingParser::parse_str("sqrt(-(1-x^2) / (1 + x^2)");
        assert_eq!(rpn, Err("Missing Closing Paren".to_string()));

        let rpn = ShuntingParser::parse_str("-(1-x^2) / (1 + x^2))");
        assert_eq!(rpn, Err("Missing Opening Paren".to_string()));

        let rpn = ShuntingParser::parse_str("max 4, 6, 4)");
        assert_eq!(rpn, Err("Missing Opening Paren".to_string()));
    }

    #[test]
    fn check_arity() {
        use std::collections::HashMap;
        let rpn = ShuntingParser::parse_str("sin(1)+(max(2, gamma(3.5), gcd(24, 8))+sum(i,0,10))")
            .unwrap();
        let mut expect: HashMap<&str, usize> = HashMap::new();
        expect.insert("sin", 1);
        expect.insert("max", 3);
        expect.insert("gamma", 1);
        expect.insert("gcd", 2);
        expect.insert("sum", 3);

        for token in rpn.0.iter() {
            if let MathToken::Function(ref func, arity) = *token {
                // `remove` ensures each function is emitted exactly once.
                let expected_arity = expect
                    .remove(&func[..])
                    .unwrap_or_else(|| panic!("unexpected function in RPN: {}", func));
                assert_eq!(expected_arity, arity, "arity of {}", func);
            }
        }
        // Without this, a parser that dropped Function tokens would pass.
        assert!(
            expect.is_empty(),
            "functions missing from RPN: {:?}",
            expect.keys()
        );
    }
}