1use super::state::ShellState;
2
3#[derive(Debug, Clone, PartialEq)]
5pub enum ArithmeticToken {
6 Number(i64),
7 Variable(String),
8 Operator(ArithmeticOperator),
9 LeftParen,
10 RightParen,
11}
12
/// An operator recognised inside a shell arithmetic expression.
///
/// `LogicalNot` and `BitwiseNot` are prefix (unary) operators; every other
/// variant is a binary operator. Binding strength is given by
/// [`ArithmeticOperator::precedence`].
///
/// The enum carries no data, so it is `Copy` and can be compared, hashed and
/// passed around by value freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ArithmeticOperator {
    /// `!` — logical NOT (prefix).
    LogicalNot,
    /// `~` — bitwise complement (prefix).
    BitwiseNot,
    /// `*`
    Multiply,
    /// `/`
    Divide,
    /// `%`
    Modulo,
    /// `+`
    Add,
    /// `-`
    Subtract,
    /// `<<`
    ShiftLeft,
    /// `>>`
    ShiftRight,
    /// `<`
    LessThan,
    /// `<=`
    LessEqual,
    /// `>`
    GreaterThan,
    /// `>=`
    GreaterEqual,
    /// `==`
    Equal,
    /// `!=`
    NotEqual,
    /// `&`
    BitwiseAnd,
    /// `^`
    BitwiseXor,
    /// `|`
    BitwiseOr,
    /// `&&`
    LogicalAnd,
    /// `||`
    LogicalOr,
}
40
41impl ArithmeticOperator {
42 pub fn precedence(&self) -> i32 {
43 match self {
44 ArithmeticOperator::LogicalNot | ArithmeticOperator::BitwiseNot => 100,
45
46 ArithmeticOperator::Multiply
47 | ArithmeticOperator::Divide
48 | ArithmeticOperator::Modulo => 90,
49 ArithmeticOperator::Add | ArithmeticOperator::Subtract => 80,
50 ArithmeticOperator::ShiftLeft | ArithmeticOperator::ShiftRight => 70,
51 ArithmeticOperator::LessThan
52 | ArithmeticOperator::LessEqual
53 | ArithmeticOperator::GreaterThan
54 | ArithmeticOperator::GreaterEqual => 60,
55 ArithmeticOperator::Equal | ArithmeticOperator::NotEqual => 50,
56 ArithmeticOperator::BitwiseAnd => 40,
57 ArithmeticOperator::BitwiseXor => 30,
58 ArithmeticOperator::BitwiseOr => 20,
59 ArithmeticOperator::LogicalAnd => 10,
60 ArithmeticOperator::LogicalOr => 5,
61 }
62 }
63
64 pub fn is_unary(&self) -> bool {
65 matches!(
66 self,
67 ArithmeticOperator::LogicalNot | ArithmeticOperator::BitwiseNot
68 )
69 }
70}
71
/// Failure modes of tokenizing, parsing, or evaluating a shell arithmetic
/// expression.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors
/// directly with `==` / `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ArithmeticError {
    /// Malformed input; the payload is a short human-readable description.
    SyntaxError(String),
    /// `/` or `%` with a zero right-hand operand.
    DivisionByZero,
    /// A `)` without a matching `(`, or a `(` that is never closed.
    UnmatchedParentheses,
    /// The expression was empty or contained only whitespace.
    EmptyExpression,
}
80
81impl std::fmt::Display for ArithmeticError {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 match self {
84 ArithmeticError::SyntaxError(msg) => write!(f, "Syntax error: {}", msg),
85 ArithmeticError::DivisionByZero => write!(f, "Division by zero"),
86 ArithmeticError::UnmatchedParentheses => write!(f, "Unmatched parentheses"),
87 ArithmeticError::EmptyExpression => write!(f, "Empty expression"),
88 }
89 }
90}
91
92pub fn tokenize_expression(expr: &str) -> Result<Vec<ArithmeticToken>, ArithmeticError> {
94 let mut tokens = Vec::new();
95 let mut chars = expr.chars().peekable();
96
97 while let Some(ch) = chars.next() {
98 match ch {
99 ' ' | '\t' | '\n' => continue, '(' => tokens.push(ArithmeticToken::LeftParen),
102 ')' => tokens.push(ArithmeticToken::RightParen),
103
104 '+' => {
105 if let Some(next_ch) = chars.peek()
106 && *next_ch == '+' {
107 return Err(ArithmeticError::SyntaxError("Unexpected ++".to_string()));
108 }
109 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Add));
110 }
111
112 '-' => {
113 if let Some(next_ch) = chars.peek()
114 && *next_ch == '-' {
115 return Err(ArithmeticError::SyntaxError("Unexpected --".to_string()));
116 }
117 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Subtract));
118 }
119
120 '*' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Multiply)),
121 '/' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Divide)),
122 '%' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Modulo)),
123
124 '<' => {
125 if let Some(&next_ch) = chars.peek() {
126 if next_ch == '<' {
127 chars.next();
128 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::ShiftLeft));
129 } else if next_ch == '=' {
130 chars.next();
131 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LessEqual));
132 } else {
133 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LessThan));
134 }
135 } else {
136 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LessThan));
137 }
138 }
139
140 '>' => {
141 if let Some(&next_ch) = chars.peek() {
142 if next_ch == '>' {
143 chars.next();
144 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::ShiftRight));
145 } else if next_ch == '=' {
146 chars.next();
147 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::GreaterEqual));
148 } else {
149 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::GreaterThan));
150 }
151 } else {
152 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::GreaterThan));
153 }
154 }
155
156 '=' => {
157 if let Some(&next_ch) = chars.peek() {
158 if next_ch == '=' {
159 chars.next();
160 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Equal));
161 } else {
162 return Err(ArithmeticError::SyntaxError("Unexpected =".to_string()));
163 }
164 } else {
165 return Err(ArithmeticError::SyntaxError("Unexpected =".to_string()));
166 }
167 }
168
169 '!' => {
170 if let Some(&next_ch) = chars.peek() {
171 if next_ch == '=' {
172 chars.next();
173 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::NotEqual));
174 } else {
175 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalNot));
176 }
177 } else {
178 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalNot));
179 }
180 }
181
182 '&' => {
183 if let Some(&next_ch) = chars.peek() {
184 if next_ch == '&' {
185 chars.next();
186 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalAnd));
187 } else {
188 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseAnd));
189 }
190 } else {
191 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseAnd));
192 }
193 }
194
195 '|' => {
196 if let Some(&next_ch) = chars.peek() {
197 if next_ch == '|' {
198 chars.next();
199 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalOr));
200 } else {
201 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseOr));
202 }
203 } else {
204 tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseOr));
205 }
206 }
207
208 '^' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseXor)),
209 '~' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseNot)),
210
211 '0'..='9' => {
213 let mut num_str = String::new();
214 num_str.push(ch);
215 while let Some(&next_ch) = chars.peek() {
216 if next_ch.is_ascii_digit() {
217 num_str.push(next_ch);
218 chars.next();
219 } else {
220 break;
221 }
222 }
223 match num_str.parse::<i64>() {
224 Ok(num) => tokens.push(ArithmeticToken::Number(num)),
225 Err(_) => {
226 return Err(ArithmeticError::SyntaxError("Invalid number".to_string()));
227 }
228 }
229 }
230
231 'a'..='z' | 'A'..='Z' | '_' => {
233 let mut var_name = String::new();
234 var_name.push(ch);
235 while let Some(&next_ch) = chars.peek() {
236 if next_ch.is_alphanumeric() || next_ch == '_' {
237 var_name.push(next_ch);
238 chars.next();
239 } else {
240 break;
241 }
242 }
243 tokens.push(ArithmeticToken::Variable(var_name));
244 }
245
246 _ => {
247 return Err(ArithmeticError::SyntaxError(format!(
248 "Unexpected character: {}",
249 ch
250 )));
251 }
252 }
253 }
254
255 Ok(tokens)
256}
257
258pub fn parse_to_rpn(tokens: Vec<ArithmeticToken>) -> Result<Vec<ArithmeticToken>, ArithmeticError> {
260 let mut output = Vec::new();
261 let mut operators = Vec::new();
262
263 for token in tokens {
264 match token {
265 ArithmeticToken::Number(_) | ArithmeticToken::Variable(_) => {
266 output.push(token);
267 }
268
269 ArithmeticToken::Operator(op) => {
270 if op.is_unary()
272 && (output.is_empty()
273 || matches!(
274 output.last(),
275 Some(ArithmeticToken::Operator(_) | ArithmeticToken::LeftParen)
276 ))
277 {
278 while !operators.is_empty() {
280 if let Some(ArithmeticToken::Operator(top_op)) = operators.last() {
281 if top_op.precedence() >= op.precedence() && !top_op.is_unary() {
282 output.push(operators.pop().unwrap());
283 } else {
284 break;
285 }
286 } else {
287 break;
288 }
289 }
290 operators.push(ArithmeticToken::Operator(op));
291 } else {
292 while !operators.is_empty() {
294 if let Some(ArithmeticToken::Operator(top_op)) = operators.last() {
295 if (top_op.precedence() > op.precedence())
296 || (top_op.precedence() == op.precedence() && !op.is_unary())
297 {
298 output.push(operators.pop().unwrap());
299 } else {
300 break;
301 }
302 } else {
303 break;
304 }
305 }
306 operators.push(ArithmeticToken::Operator(op));
307 }
308 }
309
310 ArithmeticToken::LeftParen => {
311 operators.push(token);
312 }
313
314 ArithmeticToken::RightParen => {
315 let mut found_left = false;
316 while let Some(op) = operators.pop() {
317 if op == ArithmeticToken::LeftParen {
318 found_left = true;
319 break;
320 } else {
321 output.push(op);
322 }
323 }
324 if !found_left {
325 return Err(ArithmeticError::UnmatchedParentheses);
326 }
327 }
328 }
329 }
330
331 while let Some(op) = operators.pop() {
333 if op == ArithmeticToken::LeftParen {
334 return Err(ArithmeticError::UnmatchedParentheses);
335 }
336 output.push(op);
337 }
338
339 Ok(output)
340}
341
342pub fn evaluate_rpn(
344 rpn_tokens: Vec<ArithmeticToken>,
345 shell_state: &ShellState,
346) -> Result<i64, ArithmeticError> {
347 let mut stack = Vec::new();
348
349 for token in rpn_tokens {
350 match token {
351 ArithmeticToken::Number(num) => {
352 stack.push(num);
353 }
354
355 ArithmeticToken::Variable(var_name) => {
356 if let Some(value) = shell_state.get_var(&var_name) {
357 match value.parse::<i64>() {
358 Ok(num) => stack.push(num),
359 Err(_) => {
360 stack.push(0)
362 }
363 }
364 } else {
365 stack.push(0)
367 }
368 }
369
370 ArithmeticToken::Operator(op) => {
371 if op.is_unary() {
372 if stack.is_empty() {
373 return Err(ArithmeticError::SyntaxError(
374 "Missing operand for unary operator".to_string(),
375 ));
376 }
377 let operand = stack.pop().unwrap();
378 let result = match op {
379 ArithmeticOperator::LogicalNot => !operand,
380 ArithmeticOperator::BitwiseNot => !operand,
381 _ => unreachable!(),
382 };
383 stack.push(result);
384 } else {
385 if stack.len() < 2 {
386 return Err(ArithmeticError::SyntaxError(
387 "Missing operands for binary operator".to_string(),
388 ));
389 }
390 let right = stack.pop().unwrap();
391 let left = stack.pop().unwrap();
392 let result = match op {
393 ArithmeticOperator::Add => left + right,
394 ArithmeticOperator::Subtract => left - right,
395 ArithmeticOperator::Multiply => left * right,
396 ArithmeticOperator::Divide => {
397 if right == 0 {
398 return Err(ArithmeticError::DivisionByZero);
399 }
400 left / right
401 }
402 ArithmeticOperator::Modulo => {
403 if right == 0 {
404 return Err(ArithmeticError::DivisionByZero);
405 }
406 left % right
407 }
408 ArithmeticOperator::ShiftLeft => left << right,
409 ArithmeticOperator::ShiftRight => left >> right,
410 ArithmeticOperator::LessThan => {
411 if left < right {
412 1
413 } else {
414 0
415 }
416 }
417 ArithmeticOperator::LessEqual => {
418 if left <= right {
419 1
420 } else {
421 0
422 }
423 }
424 ArithmeticOperator::GreaterThan => {
425 if left > right {
426 1
427 } else {
428 0
429 }
430 }
431 ArithmeticOperator::GreaterEqual => {
432 if left >= right {
433 1
434 } else {
435 0
436 }
437 }
438 ArithmeticOperator::Equal => {
439 if left == right {
440 1
441 } else {
442 0
443 }
444 }
445 ArithmeticOperator::NotEqual => {
446 if left != right {
447 1
448 } else {
449 0
450 }
451 }
452 ArithmeticOperator::BitwiseAnd => left & right,
453 ArithmeticOperator::BitwiseXor => left ^ right,
454 ArithmeticOperator::BitwiseOr => left | right,
455 ArithmeticOperator::LogicalAnd => {
456 if left != 0 && right != 0 {
457 1
458 } else {
459 0
460 }
461 }
462 ArithmeticOperator::LogicalOr => {
463 if left != 0 || right != 0 {
464 1
465 } else {
466 0
467 }
468 }
469 _ => unreachable!(),
470 };
471 stack.push(result);
472 }
473 }
474
475 ArithmeticToken::LeftParen | ArithmeticToken::RightParen => {
476 return Err(ArithmeticError::SyntaxError(
477 "Unexpected parenthesis in RPN".to_string(),
478 ));
479 }
480 }
481 }
482
483 if stack.len() != 1 {
484 return Err(ArithmeticError::SyntaxError(
485 "Invalid expression".to_string(),
486 ));
487 }
488
489 Ok(stack[0])
490}
491
492pub fn evaluate_arithmetic_expression(
494 expr: &str,
495 shell_state: &ShellState,
496) -> Result<i64, ArithmeticError> {
497 if expr.trim().is_empty() {
498 return Err(ArithmeticError::EmptyExpression);
499 }
500
501 let tokens = tokenize_expression(expr)?;
502 let rpn_tokens = parse_to_rpn(tokens)?;
503 let result = evaluate_rpn(rpn_tokens, shell_state)?;
504
505 Ok(result)
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps an operator in its token form to keep expected vectors short.
    fn op(kind: ArithmeticOperator) -> ArithmeticToken {
        ArithmeticToken::Operator(kind)
    }

    /// Evaluates `expr` against a freshly created shell state.
    fn eval_fresh(expr: &str) -> Result<i64, ArithmeticError> {
        let shell_state = ShellState::new();
        evaluate_arithmetic_expression(expr, &shell_state)
    }

    #[test]
    fn test_tokenize_simple_numbers() {
        let expected = vec![ArithmeticToken::Number(42)];
        assert_eq!(tokenize_expression("42").unwrap(), expected);
    }

    #[test]
    fn test_tokenize_operators() {
        let expected = vec![
            ArithmeticToken::Number(2),
            op(ArithmeticOperator::Add),
            ArithmeticToken::Number(3),
        ];
        assert_eq!(tokenize_expression("2+3").unwrap(), expected);
    }

    #[test]
    fn test_tokenize_parentheses() {
        let expected = vec![
            ArithmeticToken::LeftParen,
            ArithmeticToken::Number(2),
            op(ArithmeticOperator::Add),
            ArithmeticToken::Number(3),
            ArithmeticToken::RightParen,
        ];
        assert_eq!(tokenize_expression("(2+3)").unwrap(), expected);
    }

    #[test]
    fn test_tokenize_variables() {
        let expected = vec![
            ArithmeticToken::Variable("x".to_string()),
            op(ArithmeticOperator::Add),
            ArithmeticToken::Variable("y".to_string()),
        ];
        assert_eq!(tokenize_expression("x+y").unwrap(), expected);
    }

    #[test]
    fn test_evaluate_simple() {
        assert_eq!(eval_fresh("42").unwrap(), 42);
    }

    #[test]
    fn test_evaluate_addition() {
        assert_eq!(eval_fresh("2+3").unwrap(), 5);
    }

    #[test]
    fn test_evaluate_with_precedence() {
        // Multiplication binds tighter than addition: 2 + (3 * 4).
        assert_eq!(eval_fresh("2+3*4").unwrap(), 14);
    }

    #[test]
    fn test_evaluate_with_parentheses() {
        assert_eq!(eval_fresh("(2+3)*4").unwrap(), 20);
    }

    #[test]
    fn test_evaluate_comparison() {
        // Comparisons yield 1 for true and 0 for false.
        assert_eq!(eval_fresh("5>3").unwrap(), 1);
        assert_eq!(eval_fresh("3>5").unwrap(), 0);
    }

    #[test]
    fn test_evaluate_variable() {
        let mut shell_state = ShellState::new();
        shell_state.set_var("x", "10".to_string());
        assert_eq!(evaluate_arithmetic_expression("x + 5", &shell_state).unwrap(), 15);
    }

    #[test]
    fn test_evaluate_division_by_zero() {
        assert!(matches!(eval_fresh("5/0"), Err(ArithmeticError::DivisionByZero)));
    }

    #[test]
    fn test_evaluate_undefined_variable() {
        // An unset variable evaluates to 0.
        assert_eq!(eval_fresh("undefined + 5").unwrap(), 5);
    }
}