1use crate::token::Token;
2use crate::{lexer::Lexer, measurement::Measurement, value::pow, value::Value};
3use std::fmt;
4use std::panic;
5
6#[derive(Debug)]
12enum S {
13 Atom(Token), Group(Token, Vec<S>), }
16
17impl fmt::Display for S {
18 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19 match self {
20 S::Atom(i) => write!(f, "{}", i),
21 S::Group(head, rest) => {
22 write!(f, "({}", head)?;
23 for s in rest {
24 write!(f, " {}", s)?
25 }
26 write!(f, ")")
27 }
28 }
29 }
30}
31
32fn expr(text: &str) -> S {
33 let mut lexer = Lexer::new(text);
34 expr_bp(&mut lexer, 0)
35}
36
37fn expr_bp(lexer: &mut Lexer, min_bp: u8) -> S {
39 let first_token = lexer.next();
40 let mut lhs = match first_token {
41 Token::PosNum(_) | Token::EulersNum | Token::Pi => S::Atom(first_token),
42 Token::LeftParen => {
43 let lhs = expr_bp(lexer, 0);
44 assert_eq!(lexer.next(), Token::RightParen);
45 lhs
46 }
47 Token::Minus => {
48 let ((), r_bp) = prefix_binding_power(&first_token);
49 let rhs = expr_bp(lexer, r_bp);
50 S::Group(first_token, vec![rhs])
51 }
52 t => panic!("bad token(lhs): {:?}", t),
53 };
54 loop {
55 let token = lexer.peek();
56 let op = match token {
57 Token::Eof => break,
58 Token::Add
59 | Token::Minus
60 | Token::Mul
61 | Token::Div
62 | Token::Caret
63 | Token::RightParen
64 | Token::PlusMinus => token,
65 Token::LeftParen => panic!("Excess left parenthesis \'(\'"),
66 t => panic!("bad token(rhs): {:?}", t),
67 };
68 if let Some((l_bp, r_bp)) = infix_binding_power(&op) {
69 if l_bp < min_bp {
70 break;
71 }
72 lexer.next();
73 let rhs = expr_bp(lexer, r_bp);
74 lhs = S::Group(op, vec![lhs, rhs]);
75 } else {
76 break;
78 }
79 }
80
81 lhs
82}
83
84fn prefix_binding_power(op: &Token) -> ((), u8) {
85 match op {
86 Token::Minus => ((), 9),
87 _ => panic!("bad operator: {:?} (is not a prefix operator)", op),
88 }
89}
90
91fn infix_binding_power(op: &Token) -> Option<(u8, u8)> {
98 let res = match op {
99 Token::Add | Token::Minus => (1, 2),
100 Token::Mul | Token::Div => (3, 4),
101 Token::Caret => (5, 6),
102 Token::PlusMinus => (7, 8),
103 _ => return None,
104 };
105 Some(res)
106}
107
108fn eval_expr(expression: &S) -> Value {
109 match expression {
110 S::Atom(token) => match token {
111 Token::PosNum(x) => Value::PosNumber(x.as_float()),
112 Token::EulersNum => Value::Number(std::f64::consts::E),
113 Token::Pi => Value::Number(std::f64::consts::PI),
114 _ => panic!("bad token(eval atom): {:?}", token),
115 },
116 S::Group(op, sub_expressions) => {
117 match op {
118 Token::Add => {
119 if sub_expressions.len() != 2 {
120 panic!(
121 "bad sub-expressions: {:?}, addition ('+') operator is binary.",
122 sub_expressions
123 )
124 } else {
125 let lhs = eval_expr(&sub_expressions[0]);
126 let rhs = eval_expr(&sub_expressions[1]);
127 lhs + rhs
128 }
129 }
130 Token::Minus => {
131 if sub_expressions.len() == 1 {
132 -eval_expr(&sub_expressions[0])
134 } else if sub_expressions.len() != 2 {
135 panic!(
136 "bad sub-expressions: {:?}, subtraction ('-') operator is binary.",
137 sub_expressions
138 )
139 } else {
140 let lhs = eval_expr(&sub_expressions[0]);
141 let rhs = eval_expr(&sub_expressions[1]);
142 lhs - rhs
143 }
144 }
145 Token::Mul => {
146 if sub_expressions.len() != 2 {
147 panic!(
148 "bad sub-expressions: {:?}, multiplication ('*') operator is binary.",
149 sub_expressions
150 )
151 } else {
152 let lhs = eval_expr(&sub_expressions[0]);
153 let rhs = eval_expr(&sub_expressions[1]);
154 lhs * rhs
155 }
156 }
157 Token::Div => {
158 if sub_expressions.len() != 2 {
159 panic!(
160 "bad sub-expressions: {:?}, division ('/') operator is binary.",
161 sub_expressions
162 )
163 } else {
164 let lhs = eval_expr(&sub_expressions[0]);
165 let rhs = eval_expr(&sub_expressions[1]);
166 lhs / rhs
167 }
168 }
169 Token::Caret => {
170 if sub_expressions.len() != 2 {
171 panic!(
172 "bad sub-expressions: {:?}, exponentiation ('^') operator is binary.",
173 sub_expressions
174 )
175 } else {
176 let lhs = eval_expr(&sub_expressions[0]);
177 let rhs = eval_expr(&sub_expressions[1]);
178 pow(lhs, rhs)
179 }
180 }
181 Token::PlusMinus => {
182 if sub_expressions.len() != 2 {
183 panic!(
184 "bad sub-expressions: {:?}, plus-minus ('±') operator is binary.",
185 sub_expressions
186 )
187 } else {
188 let lhs = eval_expr(&sub_expressions[0]);
189 let rhs = eval_expr(&sub_expressions[1]);
190 let x = match lhs {
191 Value::Number(m) | Value::PosNumber(m) => m,
192 _ => panic!(
193 "left-hand side is not a number! lhs: {:?}",
194 sub_expressions[0]
195 ),
196 };
197 let y = match rhs {
198 Value::PosNumber(m) => m,
199 _ => panic!(
200 "right-hand side is not a positive number! rhs: {:?}",
201 sub_expressions[1]
202 ),
203 };
204 Value::Measurement(Measurement::new(x, y))
205 }
206 }
207 _ => todo!(),
208 }
209 }
210 }
211}
212
213pub fn eval(input: &str) -> Value {
214 let s = expr(input);
215 eval_expr(&s)
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, checks its S-expression rendering against `expected`,
    /// and hands back the tree for further checks.
    fn assert_parses(input: &str, expected: &str) -> S {
        let tree = expr(input);
        assert_eq!(tree.to_string(), expected);
        tree
    }

    /// Parses `input` (checking its rendering against `expected_tree`),
    /// evaluates it, and unwraps the resulting measurement.
    fn eval_measurement(input: &str, expected_tree: &str) -> Measurement {
        let tree = assert_parses(input, expected_tree);
        match eval_expr(&tree) {
            Value::Measurement(m) => m,
            _ => panic!("Error"),
        }
    }

    #[test]
    fn tests() {
        assert_parses("1 + 2 * 3", "(+ 1 (* 2 3))");
        assert_parses("--1 * 2", "(* (- (- 1)) 2)");
        assert_parses("(((0)))", "0");
        assert_parses("1 ± 2 * 3", "(* (± 1 2) 3)");
    }

    #[test]
    fn test_negative() {
        assert_parses("-1.0 ± 2.0", "(± (- 1.0) 2.0)");
    }

    #[test]
    fn test_eval_simple() {
        let tree = assert_parses("1 + 2", "(+ 1 2)");
        if let Value::PosNumber(x) = eval_expr(&tree) {
            assert_eq!(x, 3.0);
        } else {
            panic!("Error");
        }
    }

    #[test]
    fn test_eval_measurement() {
        let m = eval_measurement("1.0 ± 2.0", "(± 1.0 2.0)");
        assert_eq!(m, Measurement::new(1.0, 2.0));
    }

    #[test]
    fn test_eval_measurement_add() {
        let m = eval_measurement("1.0 ± 0.01 + 1.7 ± 0.02", "(+ (± 1.0 0.01) (± 1.7 0.02))");
        assert_eq!(m, Measurement::new(1.0, 0.01) + Measurement::new(1.7, 0.02));
    }

    #[test]
    fn test_eval_measurement_div() {
        let m = eval_measurement("1.0 ± 0.01 / 1.7 ± 0.02", "(/ (± 1.0 0.01) (± 1.7 0.02))");
        assert_eq!(m, Measurement::new(1.0, 0.01) / Measurement::new(1.7, 0.02));
    }

    #[test]
    fn test_eval_measurement_mul() {
        let m = eval_measurement("1.0 ± 0.01 * 1.7 ± 0.02", "(* (± 1.0 0.01) (± 1.7 0.02))");
        assert_eq!(m, Measurement::new(1.0, 0.01) * Measurement::new(1.7, 0.02));
    }

    #[test]
    fn test_valid_parenthesis() {
        assert_parses("(-1.0) ± 2.0", "(± (- 1.0) 2.0)");
    }

    /// A `(` in operator position (here directly after an operand) must be rejected.
    #[test]
    #[should_panic]
    fn test_wrong_parenthesis() {
        expr("-1.0 (± 2.0");
    }
}
307}