1use crate::tokenizer::{MathToken, MathTokenizer};
2use std::cmp::Ordering;
3
/// Associativity of an operator, used to break ties between operators of
/// equal precedence while parsing.
///
/// The type is a plain tag, so it is `Copy` and can be passed by value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Assoc {
    /// Operators of equal precedence group from the left: `a - b - c` is `(a - b) - c`.
    Left,
    /// Operators of equal precedence group from the right: `a ^ b ^ c` is `a ^ (b ^ c)`.
    Right,
    /// The token has no associativity; `precedence` reports this for anything
    /// it does not recognise as an operator.
    None,
}
10
11pub fn precedence(mt: &MathToken) -> (usize, Assoc) {
12 match *mt {
21 MathToken::OParen => (1, Assoc::Left), MathToken::BOp(ref o) if o == "+" => (2, Assoc::Left),
23 MathToken::BOp(ref o) if o == "-" => (2, Assoc::Left),
24 MathToken::BOp(ref o) if o == "*" => (3, Assoc::Left),
25 MathToken::BOp(ref o) if o == "/" => (3, Assoc::Left),
26 MathToken::BOp(ref o) if o == "%" => (3, Assoc::Left),
27 MathToken::UOp(ref o) if o == "-" => (5, Assoc::Right), MathToken::BOp(ref o) if o == "^" => (5, Assoc::Right),
29 MathToken::UOp(ref o) if o == "!" => (6, Assoc::Left), MathToken::Function(_, _) => (7, Assoc::Left),
31 _ => (99, Assoc::None),
32 }
33}
34
35#[derive(PartialEq, Debug, Clone)]
36pub struct RPNExpr(pub Vec<MathToken>);
37
/// Namespace type for the shunting-yard parser; it holds no state and all of
/// its functionality is exposed through associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct ShuntingParser;
39
40impl ShuntingParser {
41 pub fn parse_str(expr: &str) -> Result<RPNExpr, String> {
42 Self::parse(&mut MathTokenizer::new(expr.chars()))
43 }
44
45 pub fn parse(lex: &mut impl Iterator<Item = MathToken>) -> Result<RPNExpr, String> {
46 let mut out = Vec::new();
47 let mut stack = Vec::new();
48 let mut arity = Vec::<usize>::new();
49
50 for token in lex {
51 match token {
52 MathToken::Number(_) => out.push(token),
53 MathToken::Variable(_) => out.push(token),
54 MathToken::OParen => stack.push(token),
55 MathToken::Function(_, _) => {
56 stack.push(token);
57 arity.push(1);
58 }
59 MathToken::Comma | MathToken::CParen => {
60 while !stack.is_empty() && stack.last() != Some(&MathToken::OParen) {
61 out.push(stack.pop().unwrap());
62 }
63 if stack.is_empty() {
64 return Err("Missing Opening Paren".to_string());
65 }
66 if token == MathToken::CParen {
68 stack.pop(); match stack.pop() {
70 Some(MathToken::Function(func, _)) => {
71 out.push(MathToken::Function(func, arity.pop().unwrap()))
72 }
73 Some(other) => stack.push(other),
74 None => (),
75 }
76 } else if let Some(a) = arity.last_mut() {
77 *a += 1;
78 } }
80 MathToken::UOp(_) | MathToken::BOp(_) => {
81 let (prec_rhs, assoc_rhs) = precedence(&token);
82 while !stack.is_empty() {
83 let (prec_lhs, _) = precedence(stack.last().unwrap());
84 match prec_lhs.cmp(&prec_rhs) {
85 Ordering::Greater => out.push(stack.pop().unwrap()),
86 Ordering::Less => break,
87 Ordering::Equal => match assoc_rhs {
88 Assoc::Left => out.push(stack.pop().unwrap()),
89 Assoc::None => return Err("No Associativity".to_string()),
90 Assoc::Right => break,
91 },
92 }
93 }
94 stack.push(token);
95 }
96 MathToken::Unknown(lexeme) => return Err(format!("Bad token: {}", lexeme)),
97 }
98 }
99 while let Some(top) = stack.pop() {
100 match top {
101 MathToken::OParen => return Err("Missing Closing Paren".to_string()),
102 token => out.push(token),
103 }
104 }
105 Ok(RPNExpr(out))
106 }
107}
108
#[cfg(test)]
mod tests {
    use crate::parser::{RPNExpr, ShuntingParser};
    use crate::tokenizer::MathToken;

    // Small constructors that keep the expected token lists readable.

    fn var(name: &str) -> MathToken {
        MathToken::Variable(name.to_string())
    }

    fn bop(op: &str) -> MathToken {
        MathToken::BOp(op.to_string())
    }

    fn uop(op: &str) -> MathToken {
        MathToken::UOp(op.to_string())
    }

    fn func(name: &str, arity: usize) -> MathToken {
        MathToken::Function(name.to_string(), arity)
    }

    /// Operator precedence, unary minus and right-associative `^`.
    #[test]
    fn test_parse1() {
        let rpn = ShuntingParser::parse_str("3+4*2/-(1-5)^2^3").unwrap();
        let expect = vec![
            MathToken::Number(3.0),
            MathToken::Number(4.0),
            MathToken::Number(2.0),
            bop("*"),
            MathToken::Number(1.0),
            MathToken::Number(5.0),
            bop("-"),
            MathToken::Number(2.0),
            MathToken::Number(3.0),
            bop("^"),
            bop("^"),
            uop("-"),
            bop("/"),
            bop("+"),
        ];
        assert_eq!(rpn, RPNExpr(expect));
    }

    /// Function calls, postfix `!` and scientific-notation numbers.
    #[test]
    fn test_parse2() {
        let rpn = ShuntingParser::parse_str("3.4e-2 * sin(x)/(7! % -4) * max(2, x)").unwrap();
        let expect = vec![
            MathToken::Number(3.4e-2),
            var("x"),
            func("sin", 1),
            bop("*"),
            MathToken::Number(7.0),
            uop("!"),
            MathToken::Number(4.0),
            uop("-"),
            bop("%"),
            bop("/"),
            MathToken::Number(2.0),
            var("x"),
            func("max", 2),
            bop("*"),
        ];
        assert_eq!(rpn, RPNExpr(expect));
    }

    /// Nested parentheses inside a single-argument function call.
    #[test]
    fn test_parse3() {
        let rpn = ShuntingParser::parse_str("sqrt(-(1-x^2) / (1 + x^2))").unwrap();
        let expect = vec![
            MathToken::Number(1.0),
            var("x"),
            MathToken::Number(2.0),
            bop("^"),
            bop("-"),
            uop("-"),
            MathToken::Number(1.0),
            var("x"),
            MathToken::Number(2.0),
            bop("^"),
            bop("+"),
            bop("/"),
            func("sqrt", 1),
        ];
        assert_eq!(rpn, RPNExpr(expect));
    }

    /// Unbalanced parentheses are reported with the matching error message.
    #[test]
    fn bad_parse() {
        let cases = [
            ("sqrt(-(1-x^2) / (1 + x^2)", "Missing Closing Paren"),
            ("-(1-x^2) / (1 + x^2))", "Missing Opening Paren"),
            ("max 4, 6, 4)", "Missing Opening Paren"),
        ];
        for &(input, message) in &cases {
            let rpn = ShuntingParser::parse_str(input);
            assert_eq!(rpn, Err(message.to_string()));
        }
    }

    /// Every function token in the output carries the number of arguments
    /// it was called with.
    #[test]
    fn check_arity() {
        use std::collections::HashMap;
        let rpn = ShuntingParser::parse_str("sin(1)+(max(2, gamma(3.5), gcd(24, 8))+sum(i,0,10))")
            .unwrap();
        let expect: HashMap<&str, usize> =
            [("sin", 1), ("max", 3), ("gamma", 1), ("gcd", 2), ("sum", 3)]
                .iter()
                .cloned()
                .collect();

        for token in &rpn.0 {
            if let MathToken::Function(name, arity) = token {
                let expected_arity = expect.get(name.as_str());
                assert_eq!(*expected_arity.unwrap(), *arity);
            }
        }
    }
}